ultralytics 8.0.44 export and task fixes (#1088)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-02-24 03:11:25 +01:00
committed by GitHub
parent fe61018975
commit 3ea659411b
32 changed files with 439 additions and 480 deletions

View File

@ -66,10 +66,7 @@ class DetectionTrainer(BaseTrainer):
def get_validator(self):
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
return v8.detect.DetectionValidator(self.test_loader,
save_dir=self.save_dir,
logger=self.console,
args=copy(self.args))
return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def criterion(self, preds, batch):
if not hasattr(self, 'compute_loss'):

View File

@ -9,7 +9,7 @@ import torch
from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import DEFAULT_CFG, colorstr, ops
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, colorstr, ops
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@ -18,8 +18,8 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
class DetectionValidator(BaseValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args)
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
super().__init__(dataloader, save_dir, pbar, args)
self.args.task = 'detect'
self.is_coco = False
self.class_map = None
@ -112,7 +112,7 @@ class DetectionValidator(BaseValidator):
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def finalize_metrics(self, *args, **kwargs):
self.metrics.speed = dict(zip(self.metrics.speed.keys(), self.speed))
self.metrics.speed = self.speed
def get_stats(self):
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
@ -123,15 +123,15 @@ class DetectionValidator(BaseValidator):
def print_results(self):
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
self.logger.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
if self.nt_per_class.sum() == 0:
self.logger.warning(
LOGGER.warning(
f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
# Print results per class
if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
for i, c in enumerate(self.metrics.ap_class_index):
self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
if self.args.plots:
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
@ -212,7 +212,7 @@ class DetectionValidator(BaseValidator):
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa
@ -230,7 +230,7 @@ class DetectionValidator(BaseValidator):
eval.summarize()
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
self.logger.warning(f'pycocotools unable to run: {e}')
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats