|
|
@ -82,6 +82,7 @@ class BaseTrainer:
|
|
|
|
self.fitness = None
|
|
|
|
self.fitness = None
|
|
|
|
self.loss = None
|
|
|
|
self.loss = None
|
|
|
|
self.tloss = None
|
|
|
|
self.tloss = None
|
|
|
|
|
|
|
|
self.loss_names = None
|
|
|
|
self.csv = self.save_dir / 'results.csv'
|
|
|
|
self.csv = self.save_dir / 'results.csv'
|
|
|
|
|
|
|
|
|
|
|
|
for callback, func in callbacks.default_callbacks.items():
|
|
|
|
for callback, func in callbacks.default_callbacks.items():
|
|
|
@ -106,7 +107,7 @@ class BaseTrainer:
|
|
|
|
|
|
|
|
|
|
|
|
def train(self):
|
|
|
|
def train(self):
|
|
|
|
world_size = torch.cuda.device_count()
|
|
|
|
world_size = torch.cuda.device_count()
|
|
|
|
if world_size > 1 and not ("LOCAL_RANK" in os.environ):
|
|
|
|
if world_size > 1 and "LOCAL_RANK" not in os.environ:
|
|
|
|
command = generate_ddp_command(world_size, self)
|
|
|
|
command = generate_ddp_command(world_size, self)
|
|
|
|
subprocess.Popen(command)
|
|
|
|
subprocess.Popen(command)
|
|
|
|
ddp_cleanup(command, self)
|
|
|
|
ddp_cleanup(command, self)
|
|
|
@ -154,11 +155,9 @@ class BaseTrainer:
|
|
|
|
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
|
|
|
|
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
|
|
|
|
if rank in {0, -1}:
|
|
|
|
if rank in {0, -1}:
|
|
|
|
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
|
|
|
|
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
|
|
|
|
validator = self.get_validator()
|
|
|
|
self.validator = self.get_validator()
|
|
|
|
# init metric, for plot_results
|
|
|
|
# metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
|
|
|
|
metric_keys = validator.metric_keys + self.label_loss_items(prefix="val")
|
|
|
|
# self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
|
|
|
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
|
|
|
|
|
|
|
|
self.validator = validator
|
|
|
|
|
|
|
|
self.ema = ModelEMA(self.model)
|
|
|
|
self.ema = ModelEMA(self.model)
|
|
|
|
|
|
|
|
|
|
|
|
def _do_train(self, rank=-1, world_size=1):
|
|
|
|
def _do_train(self, rank=-1, world_size=1):
|
|
|
|