ultralytics 8.0.58
new SimpleClass, fixes and updates (#1636)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
@ -174,7 +174,12 @@ class BaseTrainer:
|
||||
|
||||
# Run subprocess if DDP training, else train normally
|
||||
if world_size > 1 and 'LOCAL_RANK' not in os.environ:
|
||||
cmd, file = generate_ddp_command(world_size, self) # security vulnerability in Snyk scans
|
||||
# Argument checks
|
||||
if self.args.rect:
|
||||
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting rect=False")
|
||||
self.args.rect = False
|
||||
# Command
|
||||
cmd, file = generate_ddp_command(world_size, self)
|
||||
try:
|
||||
LOGGER.info(f'Running DDP command {cmd}')
|
||||
subprocess.run(cmd, check=True)
|
||||
@ -183,17 +188,15 @@ class BaseTrainer:
|
||||
finally:
|
||||
ddp_cleanup(self, str(file))
|
||||
else:
|
||||
self._do_train(RANK, world_size)
|
||||
self._do_train(world_size)
|
||||
|
||||
def _setup_ddp(self, rank, world_size):
|
||||
# os.environ['MASTER_ADDR'] = 'localhost'
|
||||
# os.environ['MASTER_PORT'] = '9020'
|
||||
torch.cuda.set_device(rank)
|
||||
self.device = torch.device('cuda', rank)
|
||||
LOGGER.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
||||
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size)
|
||||
def _setup_ddp(self, world_size):
|
||||
torch.cuda.set_device(RANK)
|
||||
self.device = torch.device('cuda', RANK)
|
||||
LOGGER.info(f'DDP settings: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
||||
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=RANK, world_size=world_size)
|
||||
|
||||
def _setup_train(self, rank, world_size):
|
||||
def _setup_train(self, world_size):
|
||||
"""
|
||||
Builds dataloaders and optimizer on correct rank process.
|
||||
"""
|
||||
@ -213,7 +216,7 @@ class BaseTrainer:
|
||||
self.amp = bool(self.amp) # as boolean
|
||||
self.scaler = amp.GradScaler(enabled=self.amp)
|
||||
if world_size > 1:
|
||||
self.model = DDP(self.model, device_ids=[rank])
|
||||
self.model = DDP(self.model, device_ids=[RANK])
|
||||
# Check imgsz
|
||||
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
|
||||
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
|
||||
@ -243,8 +246,8 @@ class BaseTrainer:
|
||||
|
||||
# dataloaders
|
||||
batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
|
||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode='train')
|
||||
if rank in (-1, 0):
|
||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
|
||||
if RANK in (-1, 0):
|
||||
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
|
||||
self.validator = self.get_validator()
|
||||
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
|
||||
@ -256,11 +259,11 @@ class BaseTrainer:
|
||||
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
||||
self.run_callbacks('on_pretrain_routine_end')
|
||||
|
||||
def _do_train(self, rank=-1, world_size=1):
|
||||
def _do_train(self, world_size=1):
|
||||
if world_size > 1:
|
||||
self._setup_ddp(rank, world_size)
|
||||
self._setup_ddp(world_size)
|
||||
|
||||
self._setup_train(rank, world_size)
|
||||
self._setup_train(world_size)
|
||||
|
||||
self.epoch_time = None
|
||||
self.epoch_time_start = time.time()
|
||||
@ -280,7 +283,7 @@ class BaseTrainer:
|
||||
self.epoch = epoch
|
||||
self.run_callbacks('on_train_epoch_start')
|
||||
self.model.train()
|
||||
if rank != -1:
|
||||
if RANK != -1:
|
||||
self.train_loader.sampler.set_epoch(epoch)
|
||||
pbar = enumerate(self.train_loader)
|
||||
# Update dataloader attributes (optional)
|
||||
@ -291,7 +294,7 @@ class BaseTrainer:
|
||||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||
|
||||
if rank in (-1, 0):
|
||||
if RANK in (-1, 0):
|
||||
LOGGER.info(self.progress_string())
|
||||
pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
|
||||
self.tloss = None
|
||||
@ -315,7 +318,7 @@ class BaseTrainer:
|
||||
batch = self.preprocess_batch(batch)
|
||||
preds = self.model(batch['img'])
|
||||
self.loss, self.loss_items = self.criterion(preds, batch)
|
||||
if rank != -1:
|
||||
if RANK != -1:
|
||||
self.loss *= world_size
|
||||
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
|
||||
else self.loss_items
|
||||
@ -332,7 +335,7 @@ class BaseTrainer:
|
||||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
|
||||
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
|
||||
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
||||
if rank in (-1, 0):
|
||||
if RANK in (-1, 0):
|
||||
pbar.set_description(
|
||||
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
|
||||
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
|
||||
@ -347,7 +350,7 @@ class BaseTrainer:
|
||||
self.scheduler.step()
|
||||
self.run_callbacks('on_train_epoch_end')
|
||||
|
||||
if rank in (-1, 0):
|
||||
if RANK in (-1, 0):
|
||||
|
||||
# Validation
|
||||
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
||||
@ -377,7 +380,7 @@ class BaseTrainer:
|
||||
if self.stop:
|
||||
break # must break all DDP ranks
|
||||
|
||||
if rank in (-1, 0):
|
||||
if RANK in (-1, 0):
|
||||
# Do final val with best.pt
|
||||
LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
|
||||
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
|
||||
@ -408,7 +411,8 @@ class BaseTrainer:
|
||||
torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')
|
||||
del ckpt
|
||||
|
||||
def get_dataset(self, data):
|
||||
@staticmethod
|
||||
def get_dataset(data):
|
||||
"""
|
||||
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
||||
"""
|
||||
|
Reference in New Issue
Block a user