From 520825c4b242443632144c7a6fa2531a685ea9ee Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 24 Jan 2023 23:22:02 +0100 Subject: [PATCH] `ultralytics 8.0.19` seg/det dataset warning and DDP-cls/seg fixes (#595) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: 曾逸夫(Zeng Yifu) <41098760+Zengyf-CVer@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> --- .pre-commit-config.yaml | 15 +++++-- docs/cli.md | 4 +- examples/tutorial.ipynb | 4 +- tests/test_engine.py | 4 +- ultralytics/nn/tasks.py | 43 +++++++++++++------ ultralytics/yolo/data/dataloaders/v5loader.py | 2 + ultralytics/yolo/data/dataset.py | 15 ++++++- ultralytics/yolo/data/utils.py | 7 ++- ultralytics/yolo/engine/exporter.py | 2 +- ultralytics/yolo/engine/model.py | 7 ++- ultralytics/yolo/engine/predictor.py | 6 +-- ultralytics/yolo/engine/trainer.py | 8 ++-- ultralytics/yolo/engine/validator.py | 4 +- ultralytics/yolo/utils/__init__.py | 3 +- ultralytics/yolo/utils/checks.py | 2 +- ultralytics/yolo/utils/dist.py | 9 ++-- ultralytics/yolo/utils/downloads.py | 10 ++--- ultralytics/yolo/utils/torch_utils.py | 11 ++--- ultralytics/yolo/v8/classify/train.py | 4 +- ultralytics/yolo/v8/classify/val.py | 2 +- ultralytics/yolo/v8/segment/train.py | 4 +- 21 files changed, 103 insertions(+), 63 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 480127f..54903db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: # - id: end-of-file-fixer - id: trailing-whitespace @@ -25,14 +25,14 @@ repos: - id: check-docstring-first - repo: https://github.com/asottile/pyupgrade - rev: v2.37.3 + rev: v3.3.1 hooks: - id: pyupgrade name: Upgrade code args: [ --py37-plus ] - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.11.4 hooks: - id: isort name: Sort imports @@ -59,7 +59,14 @@ repos: - id: flake8 name: PEP8 + - repo: https://github.com/codespell-project/codespell + rev: v2.2.2 + hooks: + - id: codespell + args: + - --ignore-words-list=crate,nd + #- repo: https://github.com/asottile/yesqa # rev: v1.4.0 # hooks: - # - id: yesqa \ No newline at end of file + # - id: yesqa diff --git a/docs/cli.md b/docs/cli.md index 0cf7d9a..fa111cd 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -183,7 +183,7 @@ Default arguments can be overriden by simply passing them as arguments in the CL You can override the `default.yaml` config file entirely by passing a new file with the `cfg` arguments, i.e. `cfg=custom.yaml`. -To do this first create a copy of `default.yaml` in your current working dir with the `yolo copy-config` command. +To do this first create a copy of `default.yaml` in your current working dir with the `yolo copy-cfg` command. This will create `default_copy.yaml`, which you can then pass as `cfg=default_copy.yaml` along with any additional args, like `imgsz=320` in this example: @@ -192,6 +192,6 @@ like `imgsz=320` in this example: === "CLI" ```bash - yolo copy-config + yolo copy-cfg yolo cfg=default_copy.yaml imgsz=320 ``` \ No newline at end of file diff --git a/examples/tutorial.ipynb b/examples/tutorial.ipynb index dbeb931..057e93e 100644 --- a/examples/tutorial.ipynb +++ b/examples/tutorial.ipynb @@ -638,11 +638,11 @@ { "cell_type": "code", "source": [ - "# Load YOLOv8n-cls, train it on imagenette160 for 3 epochs and predict an image with it\n", + "# Load YOLOv8n-cls, train it on mnist160 for 3 epochs and predict an image with it\n", "from ultralytics import YOLO\n", "\n", "model = YOLO('yolov8n-cls.pt') # load a pretrained YOLOv8n classification model\n", - "model.train(data='imagenette160', epochs=3) # train the model\n", + "model.train(data='mnist160', epochs=3) # train the model\n", "model('https://ultralytics.com/images/bus.jpg') # predict on an image" ], "metadata": { diff --git a/tests/test_engine.py b/tests/test_engine.py index eea6bb9..d6e04da 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -3,13 +3,13 @@ from pathlib import Path from ultralytics.yolo.cfg import get_cfg -from ultralytics.yolo.utils import DEFAULT_CFG_PATH, ROOT, SETTINGS +from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, SETTINGS from ultralytics.yolo.v8 import classify, detect, segment CFG_DET = 'yolov8n.yaml' CFG_SEG = 'yolov8n-seg.yaml' CFG_CLS = 'squeezenet1_0' -CFG = get_cfg(DEFAULT_CFG_PATH) +CFG = get_cfg(DEFAULT_CFG) MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n' SOURCE = ROOT / "assets" diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 119c476..68e6f66 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -313,13 +313,39 @@ class ClassificationModel(BaseModel): # Functions ------------------------------------------------------------------------------------------------------------ +def torch_safe_load(weight): + """ + This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it + catches the error, logs a warning message, and attempts to install the missing module via the check_requirements() + function. After installation, the function again attempts to load the model using torch.load(). + + Args: + weight (str): The file path of the PyTorch model. + + Returns: + The loaded PyTorch model. + """ + from ultralytics.yolo.utils.downloads import attempt_download + + file = attempt_download(weight) # search online if missing locally + try: + return torch.load(file, map_location='cpu') # load + except ModuleNotFoundError as e: + if e.name == 'omegaconf': # e.name is missing module name + LOGGER.warning(f"WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements." + f"\nAutoInstall will run now for {e.name} but this feature will be removed in the future." + f"\nRecommend fixes are to train a new model using updated ultraltyics package or to " + f"download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0") + check_requirements(e.name) # install missing module + return torch.load(file, map_location='cpu') # load + + def attempt_load_weights(weights, device=None, inplace=True, fuse=False): # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a - from ultralytics.yolo.utils.downloads import attempt_download model = Ensemble() for w in weights if isinstance(weights, list) else [weights]: - ckpt = torch.load(attempt_download(w), map_location='cpu') # load + ckpt = torch_safe_load(w) # load ckpt args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model @@ -355,18 +381,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False): def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): # Loads a single model weights - from ultralytics.yolo.utils.downloads import attempt_download - - weight = attempt_download(weight) - try: - ckpt = torch.load(weight, map_location='cpu') # load - except ModuleNotFoundError: - LOGGER.warning(f"WARNING ⚠️ {weight} is deprecated as it requires omegaconf, which is now removed from " - "ultralytics requirements.\nAutoInstall will occur now but this feature will be removed for " - "omegaconf models in the future.\nPlease train a new model or download updated models " - "from https://github.com/ultralytics/assets/releases/tag/v0.0.0") - check_requirements('omegaconf') - ckpt = torch.load(weight, map_location='cpu') # load + ckpt = torch_safe_load(weight) # load ckpt args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model diff --git a/ultralytics/yolo/data/dataloaders/v5loader.py b/ultralytics/yolo/data/dataloaders/v5loader.py index 9590564..a44c8e6 100644 --- a/ultralytics/yolo/data/dataloaders/v5loader.py +++ b/ultralytics/yolo/data/dataloaders/v5loader.py @@ -611,6 +611,8 @@ class LoadImagesAndLabels(Dataset): def cache_labels(self, path=Path('./labels.cache'), prefix=''): # Cache dataset labels, check images and read shapes + if path.exists(): + path.unlink() # remove *.cache file if exists x = {} # dict nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages desc = f"{prefix}Scanning {path.parent / path.stem}..." diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py index bff01e9..57d8478 100644 --- a/ultralytics/yolo/data/dataset.py +++ b/ultralytics/yolo/data/dataset.py @@ -47,6 +47,8 @@ class YOLODataset(BaseDataset): def cache_labels(self, path=Path("./labels.cache")): # Cache dataset labels, check images and read shapes + if path.exists(): + path.unlink() # remove *.cache file if exists x = {"labels": []} nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages desc = f"{self.prefix}Scanning {path.parent / path.stem}..." @@ -85,7 +87,7 @@ class YOLODataset(BaseDataset): x["results"] = nf, nm, ne, nc, len(self.im_files) x["msgs"] = msgs # warnings x["version"] = self.cache_version # cache version - self.im_files = [lb["im_file"] for lb in x["labels"]] + self.im_files = [lb["im_file"] for lb in x["labels"]] # update im_files if is_dir_writeable(path.parent): np.save(str(path), x) # save cache for next time path.with_suffix(".cache.npy").rename(path) # remove .npy suffix @@ -116,6 +118,17 @@ class YOLODataset(BaseDataset): # Read cache [cache.pop(k) for k in ("hash", "version", "msgs")] # remove items labels = cache["labels"] + + # Check if the dataset is all boxes or all segments + len_boxes = sum(len(lb["bboxes"]) for lb in labels) + len_segments = sum(len(lb["segments"]) for lb in labels) + if len_segments and len_boxes != len_segments: + LOGGER.warning( + f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, " + f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. " + "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.") + for lb in labels: + lb["segments"] = [] nl = len(np.concatenate([label["cls"] for label in labels], 0)) # number of labels assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}" return labels diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py index 018eea7..598fa6f 100644 --- a/ultralytics/yolo/data/utils.py +++ b/ultralytics/yolo/data/utils.py @@ -14,7 +14,7 @@ import numpy as np import torch from PIL import ExifTags, Image, ImageOps -from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, yaml_load +from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, emojis, yaml_load from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii from ultralytics.yolo.utils.downloads import download from ultralytics.yolo.utils.files import unzip_file @@ -202,7 +202,10 @@ def check_det_dataset(dataset, autodownload=True): # Checks for k in 'train', 'val', 'names': - assert k in data, f"data.yaml '{k}:' field missing ❌" + if k not in data: + raise SyntaxError( + emojis(f"{dataset} '{k}:' key missing ❌.\n" + f"'train', 'val' and 'names' are required in data.yaml files.")) if isinstance(data['names'], (list, tuple)): # old array format data['names'] = dict(enumerate(data['names'])) # convert to dict data['nc'] = len(data['names']) diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py index 2a73c01..698a908 100644 --- a/ultralytics/yolo/engine/exporter.py +++ b/ultralytics/yolo/engine/exporter.py @@ -388,7 +388,7 @@ class Exporter: @try_export def _export_engine(self, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): # YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt - assert self.im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `device==0`' + assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'" try: import tensorrt as trt # noqa except ImportError: diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index aea15d2..ef65898 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -53,7 +53,12 @@ class YOLO: self.overrides = {} # overrides for trainer object # Load or create new YOLO model - {'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model) + load_methods = {'.pt': self._load, '.yaml': self._new} + suffix = Path(model).suffix + if suffix in load_methods: + {'.pt': self._load, '.yaml': self._new}[suffix](model) + else: + raise NotImplementedError(f"'{suffix}' model loading not implemented") def __call__(self, source=None, stream=False, verbose=False, **kwargs): return self.predict(source, stream, verbose, **kwargs) diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index f90632b..6c917df 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -35,7 +35,7 @@ from ultralytics.nn.autobackend import AutoBackend from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS -from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, SETTINGS, callbacks, colorstr, ops +from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode @@ -61,12 +61,12 @@ class BasePredictor: data_path (str): Path to data. """ - def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None): + def __init__(self, cfg=DEFAULT_CFG, overrides=None): """ Initializes the BasePredictor class. Args: - cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. + cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG. overrides (dict, optional): Configuration overrides. Defaults to None. """ self.args = get_cfg(cfg, overrides) diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index ce21fce..c540893 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -24,8 +24,8 @@ from ultralytics import __version__ from ultralytics.nn.tasks import attempt_load_one_weight from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset -from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, - emojis, yaml_save) +from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis, + yaml_save) from ultralytics.yolo.utils.autobatch import check_train_batch_size from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command @@ -71,12 +71,12 @@ class BaseTrainer: csv (Path): Path to results CSV file. """ - def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None): + def __init__(self, cfg=DEFAULT_CFG, overrides=None): """ Initializes the BaseTrainer class. Args: - cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. + cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG. overrides (dict, optional): Configuration overrides. Defaults to None. """ self.args = get_cfg(cfg, overrides) diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index 99bd96b..7ca280e 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -10,7 +10,7 @@ from tqdm import tqdm from ultralytics.nn.autobackend import AutoBackend from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset -from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis +from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis from ultralytics.yolo.utils.checks import check_imgsz from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.ops import Profile @@ -52,7 +52,7 @@ class BaseValidator: self.dataloader = dataloader self.pbar = pbar self.logger = logger or LOGGER - self.args = args or get_cfg(DEFAULT_CFG_PATH) + self.args = args or get_cfg(DEFAULT_CFG) self.model = None self.data = None self.device = None diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index bb1636f..bf447e1 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -127,8 +127,7 @@ def is_colab(): Returns: bool: True if running inside a Colab notebook, False otherwise. """ - # Check if the 'google.colab' module is present in sys.modules - return 'google.colab' in sys.modules + return 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_BACKEND_VERSION' in os.environ def is_kaggle(): diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py index 1ae495f..abfbeed 100644 --- a/ultralytics/yolo/utils/checks.py +++ b/ultralytics/yolo/utils/checks.py @@ -224,7 +224,7 @@ def check_file(file, suffix=''): for d in 'models', 'yolo/data': # search directories files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file if not files: - raise FileNotFoundError(f"{file} does not exist") + raise FileNotFoundError(f"'{file}' does not exist") elif len(files) > 1: raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}") return files[0] # return file diff --git a/ultralytics/yolo/utils/dist.py b/ultralytics/yolo/utils/dist.py index d380535..2150967 100644 --- a/ultralytics/yolo/utils/dist.py +++ b/ultralytics/yolo/utils/dist.py @@ -10,17 +10,14 @@ from . import USER_CONFIG_DIR def find_free_network_port() -> int: - # https://github.com/Lightning-AI/lightning/blob/master/src/lightning_lite/plugins/environments/lightning.py """Finds a free port on localhost. It is useful in single-node training when we don't want to connect to a real main node but have to set the `MASTER_PORT` environment variable. """ - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("", 0)) - port = s.getsockname()[1] - s.close() - return port + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('127.0.0.1', 0)) + return s.getsockname()[1] # port def generate_ddp_file(trainer): diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py index 61f38b5..6b990fc 100644 --- a/ultralytics/yolo/utils/downloads.py +++ b/ultralytics/yolo/utils/downloads.py @@ -91,12 +91,10 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'): file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required) if name in assets: - url3 = 'https://drive.google.com/drive/folders/1EFQTEUeXWSFww0luse2jB9M1QNZQGwNl' # backup gdrive mirror - safe_download( - file, - url=f'https://github.com/{repo}/releases/download/{tag}/{name}', - min_bytes=1E5, - error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}') + safe_download(file, + url=f'https://github.com/{repo}/releases/download/{tag}/{name}', + min_bytes=1E5, + error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag}') return str(file) diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py index 1960440..e8ea557 100644 --- a/ultralytics/yolo/utils/torch_utils.py +++ b/ultralytics/yolo/utils/torch_utils.py @@ -58,7 +58,7 @@ def DDP_model(model): return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) -def select_device(device='', batch_size=0, newline=False): +def select_device(device='', batch=0, newline=False): # device = None or 'cpu' or 0 or '0' or '0,1,2,3' ver = git_describe() or ultralytics.__version__ # git commit or pip package version s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} ' @@ -71,14 +71,15 @@ def select_device(device='', batch_size=0, newline=False): os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False elif device: # non-cpu device requested os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available() - assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \ - f"Invalid CUDA 'device={device}' requested, use 'device=cpu' or pass valid CUDA device(s)" + if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))): + raise ValueError(f"Invalid CUDA 'device={device}' requested, use 'device=cpu' or pass valid CUDA device(s)") if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7 n = len(devices) # device count - if n > 1 and batch_size > 0: # check batch_size is divisible by device_count - assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}' + if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count + raise ValueError(f'batch={batch} is not multiple of GPU count {n}.\n' + f'Try batch={batch // n} or batch={batch // n + 1}') space = ' ' * (len(s) + 1) for i, d in enumerate(devices): p = torch.cuda.get_device_properties(i) diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py index e51579a..2e2a6b9 100644 --- a/ultralytics/yolo/v8/classify/train.py +++ b/ultralytics/yolo/v8/classify/train.py @@ -13,11 +13,11 @@ from ultralytics.yolo.utils.torch_utils import strip_optimizer class ClassificationTrainer(BaseTrainer): - def __init__(self, config=DEFAULT_CFG, overrides=None): + def __init__(self, cfg=DEFAULT_CFG, overrides=None): if overrides is None: overrides = {} overrides["task"] = "classify" - super().__init__(config, overrides) + super().__init__(cfg, overrides) def set_model_attributes(self): self.model.names = self.data["names"] diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py index a478795..b28b9ed 100644 --- a/ultralytics/yolo/v8/classify/val.py +++ b/ultralytics/yolo/v8/classify/val.py @@ -47,7 +47,7 @@ class ClassificationValidator(BaseValidator): def val(cfg=DEFAULT_CFG): cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18" - cfg.data = cfg.data or "imagenette160" + cfg.data = cfg.data or "mnist160" validator = ClassificationValidator(args=cfg) validator(model=cfg.model) diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index c528593..fbd96f8 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -18,11 +18,11 @@ from ultralytics.yolo.v8.detect.train import Loss # BaseTrainer python usage class SegmentationTrainer(v8.detect.DetectionTrainer): - def __init__(self, config=DEFAULT_CFG, overrides=None): + def __init__(self, cfg=DEFAULT_CFG, overrides=None): if overrides is None: overrides = {} overrides["task"] = "segment" - super().__init__(config, overrides) + super().__init__(cfg, overrides) def get_model(self, cfg=None, weights=None, verbose=True): model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)