Add best.pt val and COCO pycocotools val (#98)
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -22,6 +22,7 @@ from tqdm import tqdm
|
||||
|
||||
import ultralytics.yolo.utils as utils
|
||||
import ultralytics.yolo.utils.callbacks as callbacks
|
||||
from ultralytics import __version__
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
|
||||
from ultralytics.yolo.utils.checks import check_file, print_args
|
||||
@ -52,7 +53,8 @@ class BaseTrainer:
|
||||
self.batch_size = self.args.batch_size
|
||||
self.epochs = self.args.epochs
|
||||
self.start_epoch = 0
|
||||
print_args(dict(self.args))
|
||||
if RANK == -1:
|
||||
print_args(dict(self.args))
|
||||
|
||||
# Save run settings
|
||||
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
|
||||
@ -109,7 +111,6 @@ class BaseTrainer:
|
||||
world_size = torch.cuda.device_count()
|
||||
if world_size > 1 and "LOCAL_RANK" not in os.environ:
|
||||
command = generate_ddp_command(world_size, self)
|
||||
print('DDP command: ', command)
|
||||
try:
|
||||
subprocess.run(command)
|
||||
except Exception as e:
|
||||
@ -124,7 +125,7 @@ class BaseTrainer:
|
||||
# os.environ['MASTER_PORT'] = '9020'
|
||||
torch.cuda.set_device(rank)
|
||||
self.device = torch.device('cuda', rank)
|
||||
self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
|
||||
self.console.info(f"DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}")
|
||||
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
|
||||
|
||||
def _setup_train(self, rank, world_size):
|
||||
@ -259,8 +260,7 @@ class BaseTrainer:
|
||||
if not self.args.noval or final_epoch:
|
||||
self.metrics, self.fitness = self.validate()
|
||||
self.trigger_callbacks('on_val_end')
|
||||
log_vals = {**self.label_loss_items(self.tloss), **self.metrics, **lr}
|
||||
self.save_metrics(metrics=log_vals)
|
||||
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **lr})
|
||||
|
||||
# save model
|
||||
if (not self.args.nosave) or (epoch + 1 == self.epochs):
|
||||
@ -282,7 +282,6 @@ class BaseTrainer:
|
||||
self.plot_metrics()
|
||||
self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||
self.trigger_callbacks('on_train_end')
|
||||
dist.destroy_process_group() if world_size > 1 else None
|
||||
torch.cuda.empty_cache()
|
||||
self.trigger_callbacks('teardown')
|
||||
|
||||
@ -295,7 +294,8 @@ class BaseTrainer:
|
||||
'updates': self.ema.updates,
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'train_args': self.args,
|
||||
'date': datetime.now().isoformat()}
|
||||
'date': datetime.now().isoformat(),
|
||||
'version': __version__}
|
||||
|
||||
# Save last, best and delete
|
||||
torch.save(ckpt, self.last)
|
||||
@ -365,7 +365,7 @@ class BaseTrainer:
|
||||
if rank in {-1, 0}:
|
||||
self.console.info(text)
|
||||
|
||||
def load_model(self, model_cfg, weights):
|
||||
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||
|
||||
def get_validator(self):
|
||||
@ -417,12 +417,14 @@ class BaseTrainer:
|
||||
pass
|
||||
|
||||
def final_eval(self):
|
||||
# TODO: need standalone evaluator to do this
|
||||
for f in self.last, self.best:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
if f is self.best:
|
||||
self.console.info(f'\nValidating {f}...')
|
||||
self.metrics = self.validator(model=f)
|
||||
self.metrics.pop('fitness', None)
|
||||
self.trigger_callbacks('on_val_end')
|
||||
|
||||
def check_resume(self):
|
||||
resume = self.args.resume
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@ -29,6 +30,7 @@ class BaseValidator:
|
||||
self.batch_i = None
|
||||
self.training = True
|
||||
self.speed = None
|
||||
self.jdict = None
|
||||
self.save_dir = save_dir if save_dir is not None else \
|
||||
increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
|
||||
|
||||
@ -65,11 +67,12 @@ class BaseValidator:
|
||||
self.logger.info(
|
||||
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||
|
||||
if self.args.data.endswith(".yaml"):
|
||||
if isinstance(self.args.data, str) and self.args.data.endswith(".yaml"):
|
||||
data = check_dataset_yaml(self.args.data)
|
||||
else:
|
||||
data = check_dataset(self.args.data)
|
||||
self.dataloader = self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size)
|
||||
self.dataloader = self.get_dataloader(data.get("val") or data.set("test"),
|
||||
self.args.batch_size) if not self.dataloader else self.dataloader
|
||||
|
||||
model.eval()
|
||||
|
||||
@ -81,6 +84,7 @@ class BaseValidator:
|
||||
# bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
|
||||
bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
|
||||
self.init_metrics(de_parallel(model))
|
||||
self.jdict = [] # empty before each val
|
||||
for batch_i, batch in enumerate(bar):
|
||||
self.batch_i = batch_i
|
||||
# pre-process
|
||||
@ -105,25 +109,26 @@ class BaseValidator:
|
||||
self.plot_val_samples(batch, batch_i)
|
||||
self.plot_predictions(batch, preds, batch_i)
|
||||
|
||||
if self.args.save_json:
|
||||
self.pred_to_json(preds, batch)
|
||||
|
||||
stats = self.get_stats()
|
||||
self.check_stats(stats)
|
||||
|
||||
self.print_results()
|
||||
|
||||
# calculate speed only once when training
|
||||
if not self.training or trainer.epoch == 0:
|
||||
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
|
||||
|
||||
if not self.training: # print only at inference
|
||||
self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
|
||||
self.speed)
|
||||
|
||||
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
|
||||
if self.training:
|
||||
model.float()
|
||||
# TODO: implement save json
|
||||
return {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
|
||||
else:
|
||||
self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
|
||||
self.speed)
|
||||
if self.args.save_json and self.jdict:
|
||||
with open(str(self.save_dir / "predictions.json"), 'w') as f:
|
||||
self.logger.info(f"Saving {f.name}...")
|
||||
json.dump(self.jdict, f) # flatten and save
|
||||
|
||||
return {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")} \
|
||||
if self.training else stats
|
||||
self.eval_json()
|
||||
return stats
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
raise NotImplementedError("get_dataloader function not implemented for this validator")
|
||||
@ -162,3 +167,9 @@ class BaseValidator:
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
pass
|
||||
|
||||
def pred_to_json(self, preds, batch):
|
||||
pass
|
||||
|
||||
def eval_json(self):
|
||||
pass
|
||||
|
Reference in New Issue
Block a user