From e7a94c79c594cdf550c2e42876f9e44dc819651c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 5 Apr 2023 17:52:53 +0200 Subject: [PATCH] Revert 1783 fix callbacks by reference (#1847) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/__init__.py | 2 +- ultralytics/yolo/engine/model.py | 21 +++++++++++---------- ultralytics/yolo/engine/predictor.py | 4 ++-- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index efc7f0c..35d3cb6 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -__version__ = '8.0.63' +__version__ = '8.0.64' from ultralytics.hub import start from ultralytics.yolo.engine.model import YOLO diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index f96aac9..a67166d 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -1,7 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license import sys -from copy import deepcopy from pathlib import Path from typing import Union @@ -78,7 +77,7 @@ class YOLO: task (Any, optional): Task type for the YOLO model. Defaults to None. """ - self.callbacks = deepcopy(callbacks.default_callbacks) + self._reset_callbacks() self.predictor = None # reuse predictor self.model = None # model object self.trainer = None # trainer object @@ -118,7 +117,7 @@ class YOLO: return any(( model.startswith('https://hub.ultralytics.com/models/'), [len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID - (len(model) == 20 and not Path(model).exists() and not any(x in model for x in './\\')))) # MODELID + len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID def _new(self, cfg: str, task=None, verbose=True): """ @@ -228,8 +227,8 @@ class YOLO: if source is None: source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") - is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and ( - ('predict' in sys.argv or 'mode=predict' in sys.argv) or ('track' in sys.argv or 'mode=track' in sys.argv)) + is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any( + x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track')) overrides = self.overrides.copy() overrides['conf'] = 0.25 overrides.update(kwargs) # prefer kwargs @@ -238,7 +237,7 @@ class YOLO: overrides['save'] = kwargs.get('save', False) # not save files by default if not self.predictor: self.task = overrides.get('task') or self.task - self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks) + self.predictor = TASK_MAP[self.task][3](overrides=overrides) self.predictor.setup_model(model=self.model, verbose=is_cli) else: # only update args if predictor is already setup self.predictor.args = get_cfg(self.predictor.args, overrides) @@ -387,17 +386,19 @@ class YOLO: """ return self.model.transforms if hasattr(self.model, 'transforms') else None - def add_callback(self, event: str, func): + @staticmethod + def add_callback(event: str, func): """ Add callback """ - self.callbacks[event].append(func) + callbacks.default_callbacks[event].append(func) @staticmethod def _reset_ckpt_args(args): include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model return {k: v for k, v in args.items() if k in include} - def _reset_callbacks(self): + @staticmethod + def _reset_callbacks(): for event in callbacks.default_callbacks.keys(): - self.callbacks[event] = [callbacks.default_callbacks[event][0]] + callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]] diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index cb24faf..82905ca 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -75,7 +75,7 @@ class BasePredictor: data_path (str): Path to data. """ - def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + def __init__(self, cfg=DEFAULT_CFG, overrides=None): """ Initializes the BasePredictor class. @@ -104,7 +104,7 @@ class BasePredictor: self.data_path = None self.source_type = None self.batch = None - self.callbacks = defaultdict(list, _callbacks) if _callbacks else defaultdict(list, callbacks.default_callbacks) + self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks callbacks.add_integration_callbacks(self) def preprocess(self, img):