README and Docs updates with A100 TensorRT times (#270)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -84,6 +84,7 @@ class BaseTrainer:
         if overrides is None:
             overrides = {}
         self.args = get_config(config, overrides)
+        self.device = utils.torch_utils.select_device(self.args.device, self.args.batch)
         self.check_resume()
         self.console = LOGGER
         self.validator = None
@@ -113,7 +114,6 @@ class BaseTrainer:
         print_args(dict(self.args))

         # Device
-        self.device = utils.torch_utils.select_device(self.args.device, self.batch_size)
         self.amp = self.device.type != 'cpu'
         self.scaler = amp.GradScaler(enabled=self.amp)
         if self.device.type == 'cpu':
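Read together, the first two hunks move device selection out of the training-setup path and into `BaseTrainer.__init__`, so `self.device` is resolved once from the parsed config (the argument also changes from `self.batch_size` to `self.args.batch`). Below is a minimal sketch of what a `select_device`-style helper resolves, assuming Ultralytics-like device specs; `select_device_sketch` is a hypothetical stand-in, not the library's actual implementation:

```python
import os
import torch

def select_device_sketch(device=None, batch=16):
    # Hypothetical stand-in for utils.torch_utils.select_device().
    # Resolves a spec like None, '', 'cpu', 0, '0' or '0,1' to a torch.device,
    # masking CUDA via CUDA_VISIBLE_DEVICES the way such helpers typically do.
    spec = '' if device is None else str(device).strip().lower()
    if spec == 'cpu':
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() -> False
        return torch.device('cpu')
    if spec:
        os.environ['CUDA_VISIBLE_DEVICES'] = spec  # expose only the requested GPUs
    if not torch.cuda.is_available():
        return torch.device('cpu')
    # Multi-GPU specs like '0,1' run the main process on the first visible GPU;
    # a real helper would also verify that batch divides evenly across the GPUs.
    return torch.device('cuda:0')
```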
@@ -164,7 +164,15 @@ class BaseTrainer:
             callback(self)

     def train(self):
-        world_size = torch.cuda.device_count()
+        # Allow device='', device=None on Multi-GPU systems to default to device=0
+        if isinstance(self.args.device, int) or self.args.device:  # i.e. device=0 or device=[0,1,2,3]
+            world_size = torch.cuda.device_count()
+        elif torch.cuda.is_available():  # i.e. device=None or device=''
+            world_size = 1  # default to device 0
+        else:  # i.e. device='cpu' or 'mps'
+            world_size = 0
+
+        # Run subprocess if DDP training, else train normally
         if world_size > 1 and "LOCAL_RANK" not in os.environ:
             command = generate_ddp_command(world_size, self)
             try:
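The third hunk replaces the unconditional `world_size = torch.cuda.device_count()` with a device-aware resolution, and `world_size > 1` then gates the DDP subprocess path via `generate_ddp_command`. Here is a self-contained sketch of that resolution with illustrative names; note that `device='cpu'` is a truthy string, so in the committed code that case presumably relies on the earlier `select_device` call in `__init__` having masked CUDA, which makes `device_count()` return 0:

```python
import torch

def resolve_world_size(device):
    # Illustrative mirror of the world_size logic in BaseTrainer.train() above.
    # Assumes device selection already ran in __init__, so CUDA_VISIBLE_DEVICES
    # reflects the request and torch.cuda.device_count() is 0 for device='cpu'.
    if isinstance(device, int) or device:  # i.e. device=0 or device=[0, 1, 2, 3]
        return torch.cuda.device_count()
    if torch.cuda.is_available():  # i.e. device=None or device=''
        return 1  # default to single-GPU training on device 0
    return 0  # i.e. device='cpu' or 'mps'

# world_size > 1 (with no LOCAL_RANK inherited from a launcher) selects the
# DDP subprocess path; world_size <= 1 trains in the current process.
if __name__ == '__main__':
    for spec in (0, [0, 1], '', None):
        print(repr(spec), '->', resolve_world_size(spec))
```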