ultralytics 8.0.67
Pose speeds, Comet and ClearML updates (#1871)
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Victor Sonck <victor.sonck@gmail.com> Co-authored-by: Danny Kim <dh031200@gmail.com>
This commit is contained in:
@ -53,7 +53,6 @@ import platform
|
||||
import subprocess
|
||||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
@ -130,7 +129,7 @@ class Exporter:
|
||||
save_dir (Path): Directory to save results.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initializes the Exporter class.
|
||||
|
||||
@ -139,7 +138,7 @@ class Exporter:
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||
self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
@smart_inference_mode()
|
||||
@ -854,6 +853,12 @@ class Exporter:
|
||||
LOGGER.info(f'{prefix} pipeline success')
|
||||
return model
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
Appends the given callback.
|
||||
"""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
@ -78,7 +78,7 @@ class YOLO:
|
||||
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
||||
|
||||
"""
|
||||
self._reset_callbacks()
|
||||
self.callbacks = callbacks.get_default_callbacks()
|
||||
self.predictor = None # reuse predictor
|
||||
self.model = None # model object
|
||||
self.trainer = None # trainer object
|
||||
@ -238,7 +238,7 @@ class YOLO:
|
||||
overrides['save'] = kwargs.get('save', False) # not save files by default
|
||||
if not self.predictor:
|
||||
self.task = overrides.get('task') or self.task
|
||||
self.predictor = TASK_MAP[self.task][3](overrides=overrides)
|
||||
self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
|
||||
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
||||
else: # only update args if predictor is already setup
|
||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||
@ -277,7 +277,7 @@ class YOLO:
|
||||
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
||||
|
||||
validator = TASK_MAP[self.task][2](args=args)
|
||||
validator = TASK_MAP[self.task][2](args=args, _callbacks=self.callbacks)
|
||||
validator(model=self.model)
|
||||
self.metrics = validator.metrics
|
||||
|
||||
@ -316,7 +316,7 @@ class YOLO:
|
||||
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
if args.batch == DEFAULT_CFG.batch:
|
||||
args.batch = 1 # default to 1 if not modified
|
||||
return Exporter(overrides=args)(model=self.model)
|
||||
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
@ -344,7 +344,7 @@ class YOLO:
|
||||
overrides['resume'] = self.ckpt_path
|
||||
|
||||
self.task = overrides.get('task') or self.task
|
||||
self.trainer = TASK_MAP[self.task][1](overrides=overrides)
|
||||
self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)
|
||||
if not overrides.get('resume'): # manually set model only if not resuming
|
||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||
self.model = self.trainer.model
|
||||
@ -387,19 +387,17 @@ class YOLO:
|
||||
"""
|
||||
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||
|
||||
@staticmethod
|
||||
def add_callback(event: str, func):
|
||||
def add_callback(self, event: str, func):
|
||||
"""
|
||||
Add callback
|
||||
"""
|
||||
callbacks.default_callbacks[event].append(func)
|
||||
self.callbacks[event].append(func)
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
|
||||
return {k: v for k, v in args.items() if k in include}
|
||||
|
||||
@staticmethod
|
||||
def _reset_callbacks():
|
||||
def _reset_callbacks(self):
|
||||
for event in callbacks.default_callbacks.keys():
|
||||
callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||
|
@ -28,7 +28,6 @@ Usage - formats:
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
"""
|
||||
import platform
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
@ -75,7 +74,7 @@ class BasePredictor:
|
||||
data_path (str): Path to data.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initializes the BasePredictor class.
|
||||
|
||||
@ -104,7 +103,7 @@ class BasePredictor:
|
||||
self.data_path = None
|
||||
self.source_type = None
|
||||
self.batch = None
|
||||
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||
self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def preprocess(self, img):
|
||||
@ -283,3 +282,9 @@ class BasePredictor:
|
||||
def run_callbacks(self, event: str):
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
def add_callback(self, event: str, func):
|
||||
"""
|
||||
Add callback
|
||||
"""
|
||||
self.callbacks[event].append(func)
|
||||
|
@ -8,7 +8,6 @@ Usage:
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@ -26,7 +25,7 @@ from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, ONLINE, RANK, ROOT, SETTINGS, TQDM_BAR_FORMAT, __version__,
|
||||
callbacks, colorstr, emojis, yaml_save)
|
||||
callbacks, clean_url, colorstr, emojis, yaml_save)
|
||||
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
|
||||
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
||||
@ -72,7 +71,7 @@ class BaseTrainer:
|
||||
csv (Path): Path to results CSV file.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initializes the BaseTrainer class.
|
||||
|
||||
@ -124,7 +123,7 @@ class BaseTrainer:
|
||||
if 'yaml_file' in self.data:
|
||||
self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
|
||||
except Exception as e:
|
||||
raise RuntimeError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e
|
||||
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
||||
|
||||
self.trainset, self.testset = self.get_dataset(self.data)
|
||||
self.ema = None
|
||||
@ -143,7 +142,7 @@ class BaseTrainer:
|
||||
self.plot_idx = [0, 1, 2]
|
||||
|
||||
# Callbacks
|
||||
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||
self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
|
||||
if RANK in (-1, 0):
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
|
@ -19,7 +19,6 @@ Usage - formats:
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
"""
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@ -55,7 +54,7 @@ class BaseValidator:
|
||||
save_dir (Path): Directory to save results.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
"""
|
||||
Initializes a BaseValidator instance.
|
||||
|
||||
@ -85,7 +84,7 @@ class BaseValidator:
|
||||
if self.args.conf is None:
|
||||
self.args.conf = 0.001 # default conf=0.001
|
||||
|
||||
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||
self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, trainer=None, model=None):
|
||||
@ -195,6 +194,12 @@ class BaseValidator:
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||
return stats
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
Appends the given callback.
|
||||
"""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
Reference in New Issue
Block a user