`ultralytics 8.0.19` seg/det dataset warning and DDP-cls/seg fixes (#595)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: 曾逸夫(Zeng Yifu) <41098760+Zengyf-CVer@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 936414c615
commit 520825c4b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,7 +14,7 @@ ci:
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0 rev: v4.4.0
hooks: hooks:
# - id: end-of-file-fixer # - id: end-of-file-fixer
- id: trailing-whitespace - id: trailing-whitespace
@ -25,14 +25,14 @@ repos:
- id: check-docstring-first - id: check-docstring-first
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v2.37.3 rev: v3.3.1
hooks: hooks:
- id: pyupgrade - id: pyupgrade
name: Upgrade code name: Upgrade code
args: [ --py37-plus ] args: [ --py37-plus ]
- repo: https://github.com/PyCQA/isort - repo: https://github.com/PyCQA/isort
rev: 5.10.1 rev: 5.11.4
hooks: hooks:
- id: isort - id: isort
name: Sort imports name: Sort imports
@ -59,6 +59,13 @@ repos:
- id: flake8 - id: flake8
name: PEP8 name: PEP8
- repo: https://github.com/codespell-project/codespell
rev: v2.2.2
hooks:
- id: codespell
args:
- --ignore-words-list=crate,nd
#- repo: https://github.com/asottile/yesqa #- repo: https://github.com/asottile/yesqa
# rev: v1.4.0 # rev: v1.4.0
# hooks: # hooks:

@ -183,7 +183,7 @@ Default arguments can be overriden by simply passing them as arguments in the CL
You can override the `default.yaml` config file entirely by passing a new file with the `cfg` arguments, You can override the `default.yaml` config file entirely by passing a new file with the `cfg` arguments,
i.e. `cfg=custom.yaml`. i.e. `cfg=custom.yaml`.
To do this first create a copy of `default.yaml` in your current working dir with the `yolo copy-config` command. To do this first create a copy of `default.yaml` in your current working dir with the `yolo copy-cfg` command.
This will create `default_copy.yaml`, which you can then pass as `cfg=default_copy.yaml` along with any additional args, This will create `default_copy.yaml`, which you can then pass as `cfg=default_copy.yaml` along with any additional args,
like `imgsz=320` in this example: like `imgsz=320` in this example:
@ -192,6 +192,6 @@ like `imgsz=320` in this example:
=== "CLI" === "CLI"
```bash ```bash
yolo copy-config yolo copy-cfg
yolo cfg=default_copy.yaml imgsz=320 yolo cfg=default_copy.yaml imgsz=320
``` ```

@ -638,11 +638,11 @@
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"# Load YOLOv8n-cls, train it on imagenette160 for 3 epochs and predict an image with it\n", "# Load YOLOv8n-cls, train it on mnist160 for 3 epochs and predict an image with it\n",
"from ultralytics import YOLO\n", "from ultralytics import YOLO\n",
"\n", "\n",
"model = YOLO('yolov8n-cls.pt') # load a pretrained YOLOv8n classification model\n", "model = YOLO('yolov8n-cls.pt') # load a pretrained YOLOv8n classification model\n",
"model.train(data='imagenette160', epochs=3) # train the model\n", "model.train(data='mnist160', epochs=3) # train the model\n",
"model('https://ultralytics.com/images/bus.jpg') # predict on an image" "model('https://ultralytics.com/images/bus.jpg') # predict on an image"
], ],
"metadata": { "metadata": {

@ -3,13 +3,13 @@
from pathlib import Path from pathlib import Path
from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, ROOT, SETTINGS from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, SETTINGS
from ultralytics.yolo.v8 import classify, detect, segment from ultralytics.yolo.v8 import classify, detect, segment
CFG_DET = 'yolov8n.yaml' CFG_DET = 'yolov8n.yaml'
CFG_SEG = 'yolov8n-seg.yaml' CFG_SEG = 'yolov8n-seg.yaml'
CFG_CLS = 'squeezenet1_0' CFG_CLS = 'squeezenet1_0'
CFG = get_cfg(DEFAULT_CFG_PATH) CFG = get_cfg(DEFAULT_CFG)
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n' MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
SOURCE = ROOT / "assets" SOURCE = ROOT / "assets"

@ -313,13 +313,39 @@ class ClassificationModel(BaseModel):
# Functions ------------------------------------------------------------------------------------------------------------ # Functions ------------------------------------------------------------------------------------------------------------
def torch_safe_load(weight):
"""
This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it
catches the error, logs a warning message, and attempts to install the missing module via the check_requirements()
function. After installation, the function again attempts to load the model using torch.load().
Args:
weight (str): The file path of the PyTorch model.
Returns:
The loaded PyTorch model.
"""
from ultralytics.yolo.utils.downloads import attempt_download
file = attempt_download(weight) # search online if missing locally
try:
return torch.load(file, map_location='cpu') # load
except ModuleNotFoundError as e:
if e.name == 'omegaconf': # e.name is missing module name
LOGGER.warning(f"WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements."
f"\nAutoInstall will run now for {e.name} but this feature will be removed in the future."
f"\nRecommend fixes are to train a new model using updated ultraltyics package or to "
f"download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
check_requirements(e.name) # install missing module
return torch.load(file, map_location='cpu') # load
def attempt_load_weights(weights, device=None, inplace=True, fuse=False): def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
from ultralytics.yolo.utils.downloads import attempt_download
model = Ensemble() model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]: for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location='cpu') # load ckpt = torch_safe_load(w) # load ckpt
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
@ -355,18 +381,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
# Loads a single model weights # Loads a single model weights
from ultralytics.yolo.utils.downloads import attempt_download ckpt = torch_safe_load(weight) # load ckpt
weight = attempt_download(weight)
try:
ckpt = torch.load(weight, map_location='cpu') # load
except ModuleNotFoundError:
LOGGER.warning(f"WARNING ⚠️ {weight} is deprecated as it requires omegaconf, which is now removed from "
"ultralytics requirements.\nAutoInstall will occur now but this feature will be removed for "
"omegaconf models in the future.\nPlease train a new model or download updated models "
"from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
check_requirements('omegaconf')
ckpt = torch.load(weight, map_location='cpu') # load
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model

@ -611,6 +611,8 @@ class LoadImagesAndLabels(Dataset):
def cache_labels(self, path=Path('./labels.cache'), prefix=''): def cache_labels(self, path=Path('./labels.cache'), prefix=''):
# Cache dataset labels, check images and read shapes # Cache dataset labels, check images and read shapes
if path.exists():
path.unlink() # remove *.cache file if exists
x = {} # dict x = {} # dict
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{prefix}Scanning {path.parent / path.stem}..." desc = f"{prefix}Scanning {path.parent / path.stem}..."

@ -47,6 +47,8 @@ class YOLODataset(BaseDataset):
def cache_labels(self, path=Path("./labels.cache")): def cache_labels(self, path=Path("./labels.cache")):
# Cache dataset labels, check images and read shapes # Cache dataset labels, check images and read shapes
if path.exists():
path.unlink() # remove *.cache file if exists
x = {"labels": []} x = {"labels": []}
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{self.prefix}Scanning {path.parent / path.stem}..." desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
@ -85,7 +87,7 @@ class YOLODataset(BaseDataset):
x["results"] = nf, nm, ne, nc, len(self.im_files) x["results"] = nf, nm, ne, nc, len(self.im_files)
x["msgs"] = msgs # warnings x["msgs"] = msgs # warnings
x["version"] = self.cache_version # cache version x["version"] = self.cache_version # cache version
self.im_files = [lb["im_file"] for lb in x["labels"]] self.im_files = [lb["im_file"] for lb in x["labels"]] # update im_files
if is_dir_writeable(path.parent): if is_dir_writeable(path.parent):
np.save(str(path), x) # save cache for next time np.save(str(path), x) # save cache for next time
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
@ -116,6 +118,17 @@ class YOLODataset(BaseDataset):
# Read cache # Read cache
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items [cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
labels = cache["labels"] labels = cache["labels"]
# Check if the dataset is all boxes or all segments
len_boxes = sum(len(lb["bboxes"]) for lb in labels)
len_segments = sum(len(lb["segments"]) for lb in labels)
if len_segments and len_boxes != len_segments:
LOGGER.warning(
f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.")
for lb in labels:
lb["segments"] = []
nl = len(np.concatenate([label["cls"] for label in labels], 0)) # number of labels nl = len(np.concatenate([label["cls"] for label in labels], 0)) # number of labels
assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}" assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}"
return labels return labels

@ -14,7 +14,7 @@ import numpy as np
import torch import torch
from PIL import ExifTags, Image, ImageOps from PIL import ExifTags, Image, ImageOps
from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, yaml_load from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, emojis, yaml_load
from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
from ultralytics.yolo.utils.downloads import download from ultralytics.yolo.utils.downloads import download
from ultralytics.yolo.utils.files import unzip_file from ultralytics.yolo.utils.files import unzip_file
@ -202,7 +202,10 @@ def check_det_dataset(dataset, autodownload=True):
# Checks # Checks
for k in 'train', 'val', 'names': for k in 'train', 'val', 'names':
assert k in data, f"data.yaml '{k}:' field missing ❌" if k not in data:
raise SyntaxError(
emojis(f"{dataset} '{k}:' key missing ❌.\n"
f"'train', 'val' and 'names' are required in data.yaml files."))
if isinstance(data['names'], (list, tuple)): # old array format if isinstance(data['names'], (list, tuple)): # old array format
data['names'] = dict(enumerate(data['names'])) # convert to dict data['names'] = dict(enumerate(data['names'])) # convert to dict
data['nc'] = len(data['names']) data['nc'] = len(data['names'])

@ -388,7 +388,7 @@ class Exporter:
@try_export @try_export
def _export_engine(self, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): def _export_engine(self, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
# YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt # YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt
assert self.im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `device==0`' assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'"
try: try:
import tensorrt as trt # noqa import tensorrt as trt # noqa
except ImportError: except ImportError:

@ -53,7 +53,12 @@ class YOLO:
self.overrides = {} # overrides for trainer object self.overrides = {} # overrides for trainer object
# Load or create new YOLO model # Load or create new YOLO model
{'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model) load_methods = {'.pt': self._load, '.yaml': self._new}
suffix = Path(model).suffix
if suffix in load_methods:
{'.pt': self._load, '.yaml': self._new}[suffix](model)
else:
raise NotImplementedError(f"'{suffix}' model loading not implemented")
def __call__(self, source=None, stream=False, verbose=False, **kwargs): def __call__(self, source=None, stream=False, verbose=False, **kwargs):
return self.predict(source, stream, verbose, **kwargs) return self.predict(source, stream, verbose, **kwargs)

@ -35,7 +35,7 @@ from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, SETTINGS, callbacks, colorstr, ops from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
@ -61,12 +61,12 @@ class BasePredictor:
data_path (str): Path to data. data_path (str): Path to data.
""" """
def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None): def __init__(self, cfg=DEFAULT_CFG, overrides=None):
""" """
Initializes the BasePredictor class. Initializes the BasePredictor class.
Args: Args:
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
overrides (dict, optional): Configuration overrides. Defaults to None. overrides (dict, optional): Configuration overrides. Defaults to None.
""" """
self.args = get_cfg(cfg, overrides) self.args = get_cfg(cfg, overrides)

@ -24,8 +24,8 @@ from ultralytics import __version__
from ultralytics.nn.tasks import attempt_load_one_weight from ultralytics.nn.tasks import attempt_load_one_weight
from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis,
emojis, yaml_save) yaml_save)
from ultralytics.yolo.utils.autobatch import check_train_batch_size from ultralytics.yolo.utils.autobatch import check_train_batch_size
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
@ -71,12 +71,12 @@ class BaseTrainer:
csv (Path): Path to results CSV file. csv (Path): Path to results CSV file.
""" """
def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None): def __init__(self, cfg=DEFAULT_CFG, overrides=None):
""" """
Initializes the BaseTrainer class. Initializes the BaseTrainer class.
Args: Args:
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
overrides (dict, optional): Configuration overrides. Defaults to None. overrides (dict, optional): Configuration overrides. Defaults to None.
""" """
self.args = get_cfg(cfg, overrides) self.args = get_cfg(cfg, overrides)

@ -10,7 +10,7 @@ from tqdm import tqdm
from ultralytics.nn.autobackend import AutoBackend from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis
from ultralytics.yolo.utils.checks import check_imgsz from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.ops import Profile from ultralytics.yolo.utils.ops import Profile
@ -52,7 +52,7 @@ class BaseValidator:
self.dataloader = dataloader self.dataloader = dataloader
self.pbar = pbar self.pbar = pbar
self.logger = logger or LOGGER self.logger = logger or LOGGER
self.args = args or get_cfg(DEFAULT_CFG_PATH) self.args = args or get_cfg(DEFAULT_CFG)
self.model = None self.model = None
self.data = None self.data = None
self.device = None self.device = None

@ -127,8 +127,7 @@ def is_colab():
Returns: Returns:
bool: True if running inside a Colab notebook, False otherwise. bool: True if running inside a Colab notebook, False otherwise.
""" """
# Check if the 'google.colab' module is present in sys.modules return 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_BACKEND_VERSION' in os.environ
return 'google.colab' in sys.modules
def is_kaggle(): def is_kaggle():

@ -224,7 +224,7 @@ def check_file(file, suffix=''):
for d in 'models', 'yolo/data': # search directories for d in 'models', 'yolo/data': # search directories
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
if not files: if not files:
raise FileNotFoundError(f"{file} does not exist") raise FileNotFoundError(f"'{file}' does not exist")
elif len(files) > 1: elif len(files) > 1:
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}") raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
return files[0] # return file return files[0] # return file

@ -10,17 +10,14 @@ from . import USER_CONFIG_DIR
def find_free_network_port() -> int: def find_free_network_port() -> int:
# https://github.com/Lightning-AI/lightning/blob/master/src/lightning_lite/plugins/environments/lightning.py
"""Finds a free port on localhost. """Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real main node but have to set the It is useful in single-node training when we don't want to connect to a real main node but have to set the
`MASTER_PORT` environment variable. `MASTER_PORT` environment variable.
""" """
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0)) s.bind(('127.0.0.1', 0))
port = s.getsockname()[1] return s.getsockname()[1] # port
s.close()
return port
def generate_ddp_file(trainer): def generate_ddp_file(trainer):

@ -91,12 +91,10 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required) file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
if name in assets: if name in assets:
url3 = 'https://drive.google.com/drive/folders/1EFQTEUeXWSFww0luse2jB9M1QNZQGwNl' # backup gdrive mirror safe_download(file,
safe_download(
file,
url=f'https://github.com/{repo}/releases/download/{tag}/{name}', url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
min_bytes=1E5, min_bytes=1E5,
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}') error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag}')
return str(file) return str(file)

@ -58,7 +58,7 @@ def DDP_model(model):
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
def select_device(device='', batch_size=0, newline=False): def select_device(device='', batch=0, newline=False):
# device = None or 'cpu' or 0 or '0' or '0,1,2,3' # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
ver = git_describe() or ultralytics.__version__ # git commit or pip package version ver = git_describe() or ultralytics.__version__ # git commit or pip package version
s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} ' s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
@ -71,14 +71,15 @@ def select_device(device='', batch_size=0, newline=False):
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
elif device: # non-cpu device requested elif device: # non-cpu device requested
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available() os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \ if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
f"Invalid CUDA 'device={device}' requested, use 'device=cpu' or pass valid CUDA device(s)" raise ValueError(f"Invalid CUDA 'device={device}' requested, use 'device=cpu' or pass valid CUDA device(s)")
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7 devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
n = len(devices) # device count n = len(devices) # device count
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}' raise ValueError(f'batch={batch} is not multiple of GPU count {n}.\n'
f'Try batch={batch // n} or batch={batch // n + 1}')
space = ' ' * (len(s) + 1) space = ' ' * (len(s) + 1)
for i, d in enumerate(devices): for i, d in enumerate(devices):
p = torch.cuda.get_device_properties(i) p = torch.cuda.get_device_properties(i)

@ -13,11 +13,11 @@ from ultralytics.yolo.utils.torch_utils import strip_optimizer
class ClassificationTrainer(BaseTrainer): class ClassificationTrainer(BaseTrainer):
def __init__(self, config=DEFAULT_CFG, overrides=None): def __init__(self, cfg=DEFAULT_CFG, overrides=None):
if overrides is None: if overrides is None:
overrides = {} overrides = {}
overrides["task"] = "classify" overrides["task"] = "classify"
super().__init__(config, overrides) super().__init__(cfg, overrides)
def set_model_attributes(self): def set_model_attributes(self):
self.model.names = self.data["names"] self.model.names = self.data["names"]

@ -47,7 +47,7 @@ class ClassificationValidator(BaseValidator):
def val(cfg=DEFAULT_CFG): def val(cfg=DEFAULT_CFG):
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18" cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
cfg.data = cfg.data or "imagenette160" cfg.data = cfg.data or "mnist160"
validator = ClassificationValidator(args=cfg) validator = ClassificationValidator(args=cfg)
validator(model=cfg.model) validator(model=cfg.model)

@ -18,11 +18,11 @@ from ultralytics.yolo.v8.detect.train import Loss
# BaseTrainer python usage # BaseTrainer python usage
class SegmentationTrainer(v8.detect.DetectionTrainer): class SegmentationTrainer(v8.detect.DetectionTrainer):
def __init__(self, config=DEFAULT_CFG, overrides=None): def __init__(self, cfg=DEFAULT_CFG, overrides=None):
if overrides is None: if overrides is None:
overrides = {} overrides = {}
overrides["task"] = "segment" overrides["task"] = "segment"
super().__init__(config, overrides) super().__init__(cfg, overrides)
def get_model(self, cfg=None, weights=None, verbose=True): def get_model(self, cfg=None, weights=None, verbose=True):
model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose) model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)

Loading…
Cancel
Save