Update metrics names (#85)

This commit is contained in:
Glenn Jocher
2022-12-24 02:32:24 +01:00
committed by GitHub
parent 6432afc5f9
commit 248d54ca03
9 changed files with 30 additions and 36 deletions

View File

@ -82,6 +82,7 @@ class BaseTrainer:
self.fitness = None
self.loss = None
self.tloss = None
self.loss_names = None
self.csv = self.save_dir / 'results.csv'
for callback, func in callbacks.default_callbacks.items():
@ -106,7 +107,7 @@ class BaseTrainer:
def train(self):
world_size = torch.cuda.device_count()
if world_size > 1 and not ("LOCAL_RANK" in os.environ):
if world_size > 1 and "LOCAL_RANK" not in os.environ:
command = generate_ddp_command(world_size, self)
subprocess.Popen(command)
ddp_cleanup(command, self)
@ -154,11 +155,9 @@ class BaseTrainer:
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
if rank in {0, -1}:
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
validator = self.get_validator()
# init metric, for plot_results
metric_keys = validator.metric_keys + self.label_loss_items(prefix="val")
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
self.validator = validator
self.validator = self.get_validator()
# metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
# self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
self.ema = ModelEMA(self.model)
def _do_train(self, rank=-1, world_size=1):