ultralytics 8.0.67 Pose speeds, Comet and ClearML updates (#1871)

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Victor Sonck <victor.sonck@gmail.com>
Co-authored-by: Danny Kim <dh031200@gmail.com>
This commit is contained in:
Glenn Jocher
2023-04-06 19:07:10 +02:00
committed by GitHub
parent 1cb92d7f42
commit 2725545090
28 changed files with 547 additions and 146 deletions

View File

@ -53,7 +53,6 @@ import platform
import subprocess
import time
import warnings
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
@ -130,7 +129,7 @@ class Exporter:
save_dir (Path): Directory to save results.
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initializes the Exporter class.
@ -139,7 +138,7 @@ class Exporter:
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_cfg(cfg, overrides)
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
callbacks.add_integration_callbacks(self)
@smart_inference_mode()
@ -854,6 +853,12 @@ class Exporter:
LOGGER.info(f'{prefix} pipeline success')
return model
def add_callback(self, event: str, callback):
"""
Appends the given callback.
"""
self.callbacks[event].append(callback)
def run_callbacks(self, event: str):
for callback in self.callbacks.get(event, []):
callback(self)

View File

@ -78,7 +78,7 @@ class YOLO:
task (Any, optional): Task type for the YOLO model. Defaults to None.
"""
self._reset_callbacks()
self.callbacks = callbacks.get_default_callbacks()
self.predictor = None # reuse predictor
self.model = None # model object
self.trainer = None # trainer object
@ -238,7 +238,7 @@ class YOLO:
overrides['save'] = kwargs.get('save', False) # not save files by default
if not self.predictor:
self.task = overrides.get('task') or self.task
self.predictor = TASK_MAP[self.task][3](overrides=overrides)
self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
self.predictor.setup_model(model=self.model, verbose=is_cli)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
@ -277,7 +277,7 @@ class YOLO:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
validator = TASK_MAP[self.task][2](args=args)
validator = TASK_MAP[self.task][2](args=args, _callbacks=self.callbacks)
validator(model=self.model)
self.metrics = validator.metrics
@ -316,7 +316,7 @@ class YOLO:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args)(model=self.model)
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
def train(self, **kwargs):
"""
@ -344,7 +344,7 @@ class YOLO:
overrides['resume'] = self.ckpt_path
self.task = overrides.get('task') or self.task
self.trainer = TASK_MAP[self.task][1](overrides=overrides)
self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
@ -387,19 +387,17 @@ class YOLO:
"""
return self.model.transforms if hasattr(self.model, 'transforms') else None
@staticmethod
def add_callback(event: str, func):
def add_callback(self, event: str, func):
"""
Add callback
"""
callbacks.default_callbacks[event].append(func)
self.callbacks[event].append(func)
@staticmethod
def _reset_ckpt_args(args):
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
return {k: v for k, v in args.items() if k in include}
@staticmethod
def _reset_callbacks():
def _reset_callbacks(self):
for event in callbacks.default_callbacks.keys():
callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]]
self.callbacks[event] = [callbacks.default_callbacks[event][0]]

View File

@ -28,7 +28,6 @@ Usage - formats:
yolov8n_paddle_model # PaddlePaddle
"""
import platform
from collections import defaultdict
from pathlib import Path
import cv2
@ -75,7 +74,7 @@ class BasePredictor:
data_path (str): Path to data.
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initializes the BasePredictor class.
@ -104,7 +103,7 @@ class BasePredictor:
self.data_path = None
self.source_type = None
self.batch = None
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
callbacks.add_integration_callbacks(self)
def preprocess(self, img):
@ -283,3 +282,9 @@ class BasePredictor:
def run_callbacks(self, event: str):
for callback in self.callbacks.get(event, []):
callback(self)
def add_callback(self, event: str, func):
"""
Add callback
"""
self.callbacks[event].append(func)

View File

@ -8,7 +8,6 @@ Usage:
import os
import subprocess
import time
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from pathlib import Path
@ -26,7 +25,7 @@ from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, ONLINE, RANK, ROOT, SETTINGS, TQDM_BAR_FORMAT, __version__,
callbacks, colorstr, emojis, yaml_save)
callbacks, clean_url, colorstr, emojis, yaml_save)
from ultralytics.yolo.utils.autobatch import check_train_batch_size
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
@ -72,7 +71,7 @@ class BaseTrainer:
csv (Path): Path to results CSV file.
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initializes the BaseTrainer class.
@ -124,7 +123,7 @@ class BaseTrainer:
if 'yaml_file' in self.data:
self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
except Exception as e:
raise RuntimeError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
self.trainset, self.testset = self.get_dataset(self.data)
self.ema = None
@ -143,7 +142,7 @@ class BaseTrainer:
self.plot_idx = [0, 1, 2]
# Callbacks
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
if RANK in (-1, 0):
callbacks.add_integration_callbacks(self)

View File

@ -19,7 +19,6 @@ Usage - formats:
yolov8n_paddle_model # PaddlePaddle
"""
import json
from collections import defaultdict
from pathlib import Path
import torch
@ -55,7 +54,7 @@ class BaseValidator:
save_dir (Path): Directory to save results.
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""
Initializes a BaseValidator instance.
@ -85,7 +84,7 @@ class BaseValidator:
if self.args.conf is None:
self.args.conf = 0.001 # default conf=0.001
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
@smart_inference_mode()
def __call__(self, trainer=None, model=None):
@ -195,6 +194,12 @@ class BaseValidator:
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
return stats
def add_callback(self, event: str, callback):
"""
Appends the given callback.
"""
self.callbacks[event].append(callback)
def run_callbacks(self, event: str):
for callback in self.callbacks.get(event, []):
callback(self)