ultralytics 8.0.44 export and task fixes (#1088)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-02-24 03:11:25 +01:00
committed by GitHub
parent fe61018975
commit 3ea659411b
32 changed files with 439 additions and 480 deletions

View File

@ -7,7 +7,7 @@ from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
from ultralytics.yolo import v8
from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import DEFAULT_CFG, RANK
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer
@ -64,6 +64,7 @@ class ClassificationTrainer(BaseTrainer):
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else:
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
return # dont return ckpt. Classification doesn't support resume
@ -93,7 +94,7 @@ class ClassificationTrainer(BaseTrainer):
def get_validator(self):
self.loss_names = ['loss']
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
def criterion(self, preds, batch):
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
@ -132,11 +133,12 @@ class ClassificationTrainer(BaseTrainer):
strip_optimizer(f) # strip optimizers
# TODO: validate best.pt after training completes
# if f is self.best:
# self.console.info(f'\nValidating {f}...')
# LOGGER.info(f'\nValidating {f}...')
# self.validator.args.save_json = True
# self.metrics = self.validator(model=f)
# self.metrics.pop('fitness', None)
# self.run_callbacks('on_fit_epoch_end')
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
def train(cfg=DEFAULT_CFG, use_python=False):