diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 55a1413..8e72e85 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -32,9 +32,14 @@ jobs: # key: ${{ runner.os }}-Benchmarks-${{ hashFiles('requirements.txt') }} # restore-keys: ${{ runner.os }}-Benchmarks- - name: Install requirements + shell: bash # for Windows compatibility run: | python -m pip install --upgrade pip wheel - pip install -e . coremltools openvino-dev tensorflow-cpu paddlepaddle x2paddle --extra-index-url https://download.pytorch.org/whl/cpu + if [ "${{ matrix.os }}" == "macos-latest" ]; then + pip install -e . coremltools openvino-dev tensorflow-macos --extra-index-url https://download.pytorch.org/whl/cpu + else + pip install -e . coremltools openvino-dev tensorflow-cpu paddlepaddle x2paddle --extra-index-url https://download.pytorch.org/whl/cpu + fi yolo export format=tflite - name: Check environment run: | @@ -94,6 +99,7 @@ jobs: key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }} restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-pip- - name: Install requirements + shell: bash # for Windows compatibility run: | python -m pip install --upgrade pip wheel if [ "${{ matrix.torch }}" == "1.8.0" ]; then @@ -101,7 +107,6 @@ jobs: else pip install -e '.[export]' pytest --extra-index-url https://download.pytorch.org/whl/cpu fi - shell: bash # for Windows compatibility - name: Check environment run: | echo "RUNNER_OS is ${{ runner.os }}" diff --git a/examples/tutorial.ipynb b/examples/tutorial.ipynb index 9bf0355..c7e91b0 100644 --- a/examples/tutorial.ipynb +++ b/examples/tutorial.ipynb @@ -78,13 +78,6 @@ } ] }, - { - "cell_type": "markdown", - "source": [], - "metadata": { - "id": "ZOwTlorPd8-D" - } - }, { "cell_type": "markdown", "metadata": { diff --git a/tests/test_cli.py b/tests/test_cli.py index 57ac031..e4c1896 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -3,7 +3,7 @@ import subprocess from pathlib import Path -from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS +from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS, checks MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n' CFG = 'yolov8n' @@ -49,9 +49,10 @@ def test_val_classify(): # Predict checks ------------------------------------------------------------------------------------------------------- def test_predict_detect(): run(f"yolo predict model={MODEL}.pt source={ROOT / 'assets'} imgsz=32") - run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32') - run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32') - run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32') + if checks.check_online(): + run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32') + run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32') + run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32') def test_predict_segment(): diff --git a/tests/test_python.py b/tests/test_python.py index 34b00de..3ae8077 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -9,7 +9,7 @@ from PIL import Image from ultralytics import YOLO from ultralytics.yolo.data.build import load_inference_source -from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS +from ultralytics.yolo.utils import LINUX, ROOT, 
SETTINGS, checks MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt' CFG = 'yolov8n.yaml' @@ -49,28 +49,20 @@ def test_predict_dir(): def test_predict_img(): model = YOLO(MODEL) - output = model(source=Image.open(SOURCE), save=True, verbose=True) # PIL - assert len(output) == 1, 'predict test failed' - img = cv2.imread(str(SOURCE)) - output = model(source=img, save=True, save_txt=True) # ndarray - assert len(output) == 1, 'predict test failed' - output = model(source=[img, img], save=True, save_txt=True) # batch - assert len(output) == 2, 'predict test failed' - output = model(source=[img, img], save=True, stream=True) # stream - assert len(list(output)) == 2, 'predict test failed' - tens = torch.zeros(320, 640, 3) - output = model(tens.numpy()) - assert len(output) == 1, 'predict test failed' - # test multiple source - imgs = [ - SOURCE, # filename + im = cv2.imread(str(SOURCE)) + assert len(model(source=Image.open(SOURCE), save=True, verbose=True)) == 1 # PIL + assert len(model(source=im, save=True, save_txt=True)) == 1 # ndarray + assert len(model(source=[im, im], save=True, save_txt=True)) == 2 # batch + assert len(list(model(source=[im, im], save=True, stream=True))) == 2 # stream + assert len(model(torch.zeros(320, 640, 3).numpy())) == 1 # tensor to numpy + batch = [ + str(SOURCE), # filename Path(SOURCE), # Path - 'https://ultralytics.com/images/zidane.jpg', # URI + 'https://ultralytics.com/images/zidane.jpg' if checks.check_online() else SOURCE, # URI cv2.imread(str(SOURCE)), # OpenCV Image.open(SOURCE), # PIL np.zeros((320, 640, 3))] # numpy - output = model(imgs) - assert len(output) == 6, 'predict test failed!' + assert len(model(batch)) == len(batch) # multiple sources in a batch def test_predict_grey_and_4ch(): @@ -85,6 +77,11 @@ def test_val(): model.val(data='coco8.yaml', imgsz=32) +def test_val_scratch(): + model = YOLO(CFG) + model.val(data='coco8.yaml', imgsz=32) + + def test_train_scratch(): model = YOLO(CFG) model.train(data='coco8.yaml', epochs=1, imgsz=32) @@ -103,6 +100,12 @@ def test_export_torchscript(): YOLO(f)(SOURCE) # exported model inference +def test_export_torchscript_scratch(): + model = YOLO(CFG) + f = model.export(format='torchscript') + YOLO(f)(SOURCE) # exported model inference + + def test_export_onnx(): model = YOLO(MODEL) f = model.export(format='onnx') diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index cefc173..f03064e 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -__version__ = '8.0.43' +__version__ = '8.0.44' from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.utils.checks import check_yolo as checks diff --git a/ultralytics/hub/__init__.py b/ultralytics/hub/__init__.py index ed33d2c..6b51ce8 100644 --- a/ultralytics/hub/__init__.py +++ b/ultralytics/hub/__init__.py @@ -15,7 +15,7 @@ EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ['ultralytics_tflite', 'ultralytics_c def start(key=''): """ - Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY') + Start training models with Ultralytics HUB. 
Usage: from ultralytics.hub import start; start('API_KEY') """ auth = Auth(key) try: @@ -30,9 +30,9 @@ def start(key=''): session = HubTrainingSession(model_id=model_id, auth=auth) session.check_disk_space() - trainer = YOLO(session.input_file) - session.register_callbacks(trainer) - trainer.train(**session.train_args) + model = YOLO(session.input_file) + session.register_callbacks(model) + model.train(**session.train_args) except Exception as e: LOGGER.warning(f'{PREFIX}{e}') @@ -93,6 +93,5 @@ def get_export(key='', format='torchscript'): return r.json() -# temp. For checking if __name__ == '__main__': start() diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index 70d9ea6..777dc4b 100644 --- a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -26,6 +26,7 @@ class HubTrainingSession: self._timers = {} # rate limit timers (seconds) self._metrics_queue = {} # metrics queue self.model = self._get_model() + self.alive = True self._start_heartbeat() # start heartbeats self._register_signal_handlers() @@ -52,37 +53,6 @@ class HubTrainingSession: payload = {'metrics': self._metrics_queue.copy(), 'type': 'metrics'} smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2) - def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False): - # Upload a model to HUB - file = None - if Path(weights).is_file(): - with open(weights, 'rb') as f: - file = f.read() - if final: - smart_request( - f'{self.api_url}/upload', - data={ - 'epoch': epoch, - 'type': 'final', - 'map': map}, - files={'best.pt': file}, - headers=self.auth_header, - retry=10, - timeout=3600, - code=4, - ) - else: - smart_request( - f'{self.api_url}/upload', - data={ - 'epoch': epoch, - 'type': 'epoch', - 'isBest': bool(is_best)}, - headers=self.auth_header, - files={'last.pt': file}, - code=3, - ) - def _get_model(self): # Returns model from database by id api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}' @@ -151,7 +121,7 @@ class HubTrainingSession: model_info = { 'model/parameters': get_num_params(trainer.model), 'model/GFLOPs': round(get_flops(trainer.model), 3), - 'model/speed(ms)': round(trainer.validator.speed[1], 3)} + 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)} all_plots = {**all_plots, **model_info} self._metrics_queue[trainer.epoch] = json.dumps(all_plots) if time() - self._timers['metrics'] > self._rate_limits['metrics']: @@ -169,52 +139,45 @@ class HubTrainingSession: def on_train_end(self, trainer): # Upload final model and metrics with exponential standoff - LOGGER.info(f'{PREFIX}Training completed successfully ✅') - LOGGER.info(f'{PREFIX}Uploading final {self.model_id}') + LOGGER.info(f'{PREFIX}Training completed successfully ✅\n' + f'{PREFIX}Uploading final {self.model_id}') - # hack for fetching mAP - mAP = trainer.metrics.get('metrics/mAP50-95(B)', 0) - self._upload_model(trainer.epoch, trainer.best, map=mAP, final=True) # results[3] is mAP0.5:0.95 + self._upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True) self.alive = False # stop heartbeats LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀') def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False): # Upload a model to HUB - file = None if Path(weights).is_file(): with open(weights, 'rb') as f: file = f.read() - file_param = {'best.pt' if final else 'last.pt': file} - endpoint = f'{self.api_url}/upload' + else: + LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload 
failed. Missing model {weights}.') + file = None data = {'epoch': epoch} if final: data.update({'type': 'final', 'map': map}) else: data.update({'type': 'epoch', 'isBest': bool(is_best)}) - smart_request( - endpoint, - data=data, - files=file_param, - headers=self.auth_header, - retry=10 if final else None, - timeout=3600 if final else None, - code=4 if final else 3, - ) + smart_request(f'{self.api_url}/upload', + data=data, + files={'best.pt' if final else 'last.pt': file}, + headers=self.auth_header, + retry=10 if final else None, + timeout=3600 if final else None, + code=4 if final else 3) @threaded def _start_heartbeat(self): - self.alive = True while self.alive: - r = smart_request( - f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}', - json={ - 'agent': AGENT_NAME, - 'agentId': self.agent_id}, - headers=self.auth_header, - retry=0, - code=5, - thread=False, - ) + r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}', + json={ + 'agent': AGENT_NAME, + 'agentId': self.agent_id}, + headers=self.auth_header, + retry=0, + code=5, + thread=False) self.agent_id = r.json().get('data', {}).get('agentId', None) sleep(self._rate_limits['heartbeat']) diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index 5b91a96..7656859 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -181,7 +181,6 @@ class AutoBackend(nn.Module): import tensorflow as tf keras = False # assume TF1 saved_model model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w) - w = Path(w) / 'metadata.yaml' elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...') import tensorflow as tf @@ -258,8 +257,9 @@ class AutoBackend(nn.Module): f'\n\n{EXPORT_FORMATS_TABLE}') # Load external metadata YAML + w = Path(w) if xml or saved_model or paddle: - metadata = Path(w).parent / 'metadata.yaml' + metadata = (w if saved_model else w.parents[1] if paddle else w.parent) / 'metadata.yaml' if metadata.exists(): metadata = yaml_load(metadata) stride, names = int(metadata['stride']), metadata['names'] # load metadata diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 29ad140..9c09d61 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -287,6 +287,7 @@ class ClassificationModel(BaseModel): LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") self.yaml['nc'] = nc # override yaml value self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist + self.stride = torch.Tensor([1]) # no stride constraints self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict self.info() @@ -520,14 +521,15 @@ def guess_model_task(model): # Guess from model filename if isinstance(model, (str, Path)): - model = Path(model).stem - if '-seg' in model: + model = Path(model) + if '-seg' in model.stem or 'segment' in model.parts: return 'segment' - elif '-cls' in model: + elif '-cls' in model.stem or 'classify' in model.parts: return 'classify' - else: + elif 'detect' in model.parts: return 'detect' # Unable to determine task from model - raise SyntaxError('YOLO is unable to automatically guess model task. Explicitly define task for your model, ' - "i.e. 'task=detect', 'task=segment' or 'task=classify'.") + LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. " + "Explicitly define task for your model, i.e. 
'task=detect', 'task=segment' or 'task=classify'.") + return 'detect' # assume detect diff --git a/ultralytics/yolo/cfg/__init__.py b/ultralytics/yolo/cfg/__init__.py index 3e6bc60..5fdf8a3 100644 --- a/ultralytics/yolo/cfg/__init__.py +++ b/ultralytics/yolo/cfg/__init__.py @@ -47,11 +47,12 @@ CLI_HELP_MSG = \ GitHub: https://github.com/ultralytics/ultralytics """ -CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'} +# Define keys for arg type checks +CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'fl_gamma'} CFG_FRACTION_KEYS = { - 'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'fl_gamma', - 'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic', - 'mixup', 'copy_paste', 'conf', 'iou'} + 'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'label_smoothing', + 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic', 'mixup', 'copy_paste', + 'conf', 'iou'} # fractional floats limited to 0.0 - 1.0 CFG_INT_KEYS = { 'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride', 'line_thickness', 'workspace', 'nbs', 'save_period'} @@ -224,7 +225,7 @@ def entrypoint(debug=''): assert v, f"missing '{k}' value" if k == 'cfg': # custom.yaml passed LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}') - overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'} + overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'} else: if v.lower() == 'none': v = None @@ -255,7 +256,6 @@ def entrypoint(debug=''): check_cfg_mismatch(full_args_dict, {a: ''}) # Defaults - task2model = dict(detect='yolov8n.pt', segment='yolov8n-seg.pt', classify='yolov8n-cls.pt') task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100') # Mode @@ -272,27 +272,28 @@ def entrypoint(debug=''): # Model model = overrides.pop('model', DEFAULT_CFG.model) - task = overrides.pop('task', None) if model is None: - model = task2model.get(task, 'yolov8n.pt') + model = 'yolov8n.pt' LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.") from ultralytics.yolo.engine.model import YOLO overrides['model'] = model model = YOLO(model) # Task - if task and task != model.task: - LOGGER.warning(f"WARNING ⚠️ 'task={task}' conflicts with {model.task} model {overrides['model']}. " - f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.") - task = model.task - overrides['task'] = task + # if task and task != model.task: + # LOGGER.warning(f"WARNING ⚠️ 'task={task}' conflicts with {model.task} model {overrides['model']}. " + # f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.") + overrides['task'] = overrides.get('task', model.task) + model.task = overrides['task'] + + # Mode if mode in {'predict', 'track'} and 'source' not in overrides: overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \ else 'https://ultralytics.com/images/bus.jpg' LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.") elif mode in ('train', 'val'): if 'data' not in overrides: - overrides['data'] = task2data.get(task, DEFAULT_CFG.data) + overrides['data'] = task2data.get(overrides['task'], DEFAULT_CFG.data) LOGGER.warning(f"WARNING ⚠️ 'data' is missing. 
Using {model.task} default 'data={overrides['data']}'.") elif mode == 'export': if 'format' not in overrides: diff --git a/ultralytics/yolo/data/dataloaders/v5loader.py b/ultralytics/yolo/data/dataloaders/v5loader.py index 8ea4582..8039357 100644 --- a/ultralytics/yolo/data/dataloaders/v5loader.py +++ b/ultralytics/yolo/data/dataloaders/v5loader.py @@ -6,7 +6,6 @@ Dataloaders and dataset utils import contextlib import glob import hashlib -import json import math import os import random @@ -27,11 +26,9 @@ from PIL import ExifTags, Image, ImageOps from torch.utils.data import DataLoader, Dataset, dataloader, distributed from tqdm import tqdm -from ultralytics.yolo.data.utils import check_det_dataset from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable, - is_kaggle, yaml_load) -from ultralytics.yolo.utils.checks import check_requirements, check_yaml -from ultralytics.yolo.utils.downloads import unzip_file + is_kaggle) +from ultralytics.yolo.utils.checks import check_requirements from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first @@ -1037,127 +1034,6 @@ def verify_image_label(args): return [None, None, None, None, nm, nf, ne, nc, msg] -class HUBDatasetStats(): - """ Class for generating HUB dataset JSON and `-hub` dataset directory - - Arguments - path: Path to data.yaml or data.zip (with data.yaml inside data.zip) - autodownload: Attempt to download dataset if not found locally - - Usage - from ultralytics.yolo.data.dataloaders.v5loader import HUBDatasetStats - stats = HUBDatasetStats('coco128.yaml', autodownload=True) # usage 1 - stats = HUBDatasetStats('path/to/coco128.zip') # usage 2 - stats.get_json(save=False) - stats.process_images() - """ - - def __init__(self, path='coco128.yaml', autodownload=False): - # Initialize class - zipped, data_dir, yaml_path = self._unzip(Path(path)) - # try: - # data = yaml_load(check_yaml(yaml_path)) # data dict - # if zipped: - # data['path'] = data_dir - # except Exception as e: - # raise Exception('error/HUB/dataset_stats/yaml_load') from e - - data = check_det_dataset(yaml_path, autodownload) # download dataset if missing - self.hub_dir = Path(str(data['path']) + '-hub') - self.im_dir = self.hub_dir / 'images' - self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images - self.stats = {'nc': data['nc'], 'names': list(data['names'].values())} # statistics dictionary - self.data = data - - @staticmethod - def _find_yaml(dir): - # Return data.yaml file - files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive - assert files, f'No *.yaml file found in {dir}' - if len(files) > 1: - files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name - assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed' - assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}' - return files[0] - - def _unzip(self, path): - # Unzip data.zip - if not str(path).endswith('.zip'): # path is data.yaml - return False, None, path - assert Path(path).is_file(), f'Error unzipping {path}, file not found' - unzip_file(path, path=path.parent) - dir = path.with_suffix('') # dataset directory == zip name - assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. 
path/to/abc.zip MUST unzip to path/to/abc/' - return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path - - def _hub_ops(self, f, max_dim=1920): - # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing - f_new = self.im_dir / Path(f).name # dataset-hub image filename - try: # use PIL - im = Image.open(f) - r = max_dim / max(im.height, im.width) # ratio - if r < 1.0: # image too large - im = im.resize((int(im.width * r), int(im.height * r))) - im.save(f_new, 'JPEG', quality=50, optimize=True) # save - except Exception as e: # use OpenCV - LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}') - im = cv2.imread(f) - im_height, im_width = im.shape[:2] - r = max_dim / max(im_height, im_width) # ratio - if r < 1.0: # image too large - im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA) - cv2.imwrite(str(f_new), im) - - def get_json(self, save=False, verbose=False): - # Return dataset JSON for Ultralytics HUB - def _round(labels): - # Update labels to integer class and 6 decimal place floats - return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels] - - for split in 'train', 'val', 'test': - if self.data.get(split) is None: - self.stats[split] = None # i.e. no test set - continue - dataset = LoadImagesAndLabels(self.data[split]) # load dataset - x = np.array([ - np.bincount(label[:, 0].astype(int), minlength=self.data['nc']) - for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')]) # shape(128x80) - self.stats[split] = { - 'instance_stats': { - 'total': int(x.sum()), - 'per_class': x.sum(0).tolist()}, - 'image_stats': { - 'total': dataset.n, - 'unlabelled': int(np.all(x == 0, 1).sum()), - 'per_class': (x > 0).sum(0).tolist()}, - 'labels': [{ - str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]} - - # Save, print and return - if save: - stats_path = self.hub_dir / 'stats.json' - LOGGER.info(f'Saving {stats_path.resolve()}...') - with open(stats_path, 'w') as f: - json.dump(self.stats, f) # save stats.json - if verbose: - LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False)) - return self.stats - - def process_images(self): - # Compress images for Ultralytics HUB - for split in 'train', 'val', 'test': - if self.data.get(split) is None: - continue - dataset = LoadImagesAndLabels(self.data[split]) # load dataset - desc = f'{split} images' - total = dataset.n - with ThreadPool(NUM_THREADS) as pool: - for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc): - pass - LOGGER.info(f'Done. 
All images saved to {self.im_dir}') - return self.im_dir - - # Classification dataloaders ------------------------------------------------------------------------------------------- class ClassificationDataset(torchvision.datasets.ImageFolder): """ diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py index eacd3ae..fe23ec3 100644 --- a/ultralytics/yolo/data/utils.py +++ b/ultralytics/yolo/data/utils.py @@ -2,9 +2,11 @@ import contextlib import hashlib +import json import os import subprocess import time +from multiprocessing.pool import ThreadPool from pathlib import Path from tarfile import is_tarfile from zipfile import is_zipfile @@ -12,10 +14,11 @@ from zipfile import is_zipfile import cv2 import numpy as np from PIL import ExifTags, Image, ImageOps +from tqdm import tqdm -from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, emojis, yaml_load +from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, colorstr, emojis, yaml_load from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii -from ultralytics.yolo.utils.downloads import download, safe_download +from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file from ultralytics.yolo.utils.ops import segments2boxes HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data' @@ -290,3 +293,128 @@ def check_cls_dataset(dataset: str): names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list names = dict(enumerate(sorted(names))) return {'train': train_set, 'val': test_set, 'nc': nc, 'names': names} + + +class HUBDatasetStats(): + """ Class for generating HUB dataset JSON and `-hub` dataset directory + + Arguments + path: Path to data.yaml or data.zip (with data.yaml inside data.zip) + autodownload: Attempt to download dataset if not found locally + + Usage + from ultralytics.yolo.data.utils import HUBDatasetStats + stats = HUBDatasetStats('coco128.yaml', autodownload=True) # usage 1 + stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco6.zip') # usage 2 + stats.get_json(save=False) + stats.process_images() + """ + + def __init__(self, path='coco128.yaml', autodownload=False): + # Initialize class + zipped, data_dir, yaml_path = self._unzip(Path(path)) + try: + # data = yaml_load(check_yaml(yaml_path)) # data dict + data = check_det_dataset(yaml_path, autodownload) # data dict + if zipped: + data['path'] = data_dir + except Exception as e: + raise Exception('error/HUB/dataset_stats/yaml_load') from e + + self.hub_dir = Path(str(data['path']) + '-hub') + self.im_dir = self.hub_dir / 'images' + self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images + self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary + self.data = data + + @staticmethod + def _find_yaml(dir): + # Return data.yaml file + files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive + assert files, f'No *.yaml file found in {dir}' + if len(files) > 1: + files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name + assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed' + assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}' + return files[0] + + def _unzip(self, path): + # Unzip data.zip + if not str(path).endswith('.zip'): # path is data.yaml + return False, None, path + assert Path(path).is_file(), f'Error unzipping 
{path}, file not found' + unzip_file(path, path=path.parent) + dir = path.with_suffix('') # dataset directory == zip name + assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/' + return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path + + def _hub_ops(self, f, max_dim=1920): + # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing + f_new = self.im_dir / Path(f).name # dataset-hub image filename + try: # use PIL + im = Image.open(f) + r = max_dim / max(im.height, im.width) # ratio + if r < 1.0: # image too large + im = im.resize((int(im.width * r), int(im.height * r))) + im.save(f_new, 'JPEG', quality=50, optimize=True) # save + except Exception as e: # use OpenCV + LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}') + im = cv2.imread(f) + im_height, im_width = im.shape[:2] + r = max_dim / max(im_height, im_width) # ratio + if r < 1.0: # image too large + im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA) + cv2.imwrite(str(f_new), im) + + def get_json(self, save=False, verbose=False): + # Return dataset JSON for Ultralytics HUB + # from ultralytics.yolo.data import YOLODataset + from ultralytics.yolo.data.dataloaders.v5loader import LoadImagesAndLabels + + def _round(labels): + # Update labels to integer class and 6 decimal place floats + return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels] + + for split in 'train', 'val', 'test': + if self.data.get(split) is None: + self.stats[split] = None # i.e. no test set + continue + dataset = LoadImagesAndLabels(self.data[split]) # load dataset + x = np.array([ + np.bincount(label[:, 0].astype(int), minlength=self.data['nc']) + for label in tqdm(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80) + self.stats[split] = { + 'instance_stats': { + 'total': int(x.sum()), + 'per_class': x.sum(0).tolist()}, + 'image_stats': { + 'total': len(dataset), + 'unlabelled': int(np.all(x == 0, 1).sum()), + 'per_class': (x > 0).sum(0).tolist()}, + 'labels': [{ + str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]} + + # Save, print and return + if save: + stats_path = self.hub_dir / 'stats.json' + LOGGER.info(f'Saving {stats_path.resolve()}...') + with open(stats_path, 'w') as f: + json.dump(self.stats, f) # save stats.json + if verbose: + LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False)) + return self.stats + + def process_images(self): + # Compress images for Ultralytics HUB + # from ultralytics.yolo.data import YOLODataset + from ultralytics.yolo.data.dataloaders.v5loader import LoadImagesAndLabels + + for split in 'train', 'val', 'test': + if self.data.get(split) is None: + continue + dataset = LoadImagesAndLabels(self.data[split]) # load dataset + with ThreadPool(NUM_THREADS) as pool: + for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'): + pass + LOGGER.info(f'Done. 
All images saved to {self.im_dir}') + return self.im_dir diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py index f77ab01..242cecd 100644 --- a/ultralytics/yolo/engine/exporter.py +++ b/ultralytics/yolo/engine/exporter.py @@ -208,12 +208,15 @@ class Exporter: self.file = file self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y) self.pretty_name = self.file.stem.replace('yolo', 'YOLO') + description = f'Ultralytics {self.pretty_name} model' + f'trained on {Path(self.args.data).name}' \ + if self.args.data else '(untrained)' self.metadata = { - 'description': f'Ultralytics {self.pretty_name} model trained on {Path(self.args.data).name}', + 'description': description, 'author': 'Ultralytics', 'license': 'GPL-3.0 https://ultralytics.com/license', 'version': __version__, 'stride': int(max(model.stride)), + 'task': model.task, 'names': model.names} # model metadata LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and " diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 33135c2..c9620dc 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -9,76 +9,72 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat guess_model_task, nn) from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.engine.exporter import Exporter -from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, yaml_load +from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, yaml_load from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS from ultralytics.yolo.utils.torch_utils import smart_inference_mode # Map head to model, trainer, validator, and predictor classes -MODEL_MAP = { +TASK_MAP = { 'classify': [ - ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator', - 'yolo.TYPE.classify.ClassificationPredictor'], + ClassificationModel, yolo.v8.classify.ClassificationTrainer, yolo.v8.classify.ClassificationValidator, + yolo.v8.classify.ClassificationPredictor], 'detect': [ - DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator', - 'yolo.TYPE.detect.DetectionPredictor'], + DetectionModel, yolo.v8.detect.DetectionTrainer, yolo.v8.detect.DetectionValidator, + yolo.v8.detect.DetectionPredictor], 'segment': [ - SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator', - 'yolo.TYPE.segment.SegmentationPredictor']} + SegmentationModel, yolo.v8.segment.SegmentationTrainer, yolo.v8.segment.SegmentationValidator, + yolo.v8.segment.SegmentationPredictor]} class YOLO: """ - YOLO (You Only Look Once) object detection model. - - Args: - model (str, Path): Path to the model file to load or create. - type (str): Type/version of models to use. Defaults to "v8". - - Attributes: - type (str): Type/version of models being used. - ModelClass (Any): Model class. - TrainerClass (Any): Trainer class. - ValidatorClass (Any): Validator class. - PredictorClass (Any): Predictor class. - predictor (Any): Predictor object. - model (Any): Model object. - trainer (Any): Trainer object. - task (str): Type of model task. - ckpt (Any): Checkpoint object if model loaded from *.pt file. 
- cfg (str): Model configuration if loaded from *.yaml file. - ckpt_path (str): Checkpoint file path. - overrides (dict): Overrides for trainer object. - metrics_data (Any): Data for metrics. - - Methods: - __call__(): Alias for predict method. - _new(cfg, verbose=True): Initializes a new model and infers the task type from the model definitions. - _load(weights): Initializes a new model and infers the task type from the model head. - _check_is_pytorch_model(): Raises TypeError if model is not a PyTorch model. - reset(): Resets the model modules. - info(verbose=False): Logs model info. - fuse(): Fuse model for faster inference. - predict(source=None, stream=False, **kwargs): Perform prediction using the YOLO model. - - Returns: - list(ultralytics.yolo.engine.results.Results): The prediction results. - """ + YOLO (You Only Look Once) object detection model. + + Args: + model (str, Path): Path to the model file to load or create. + + Attributes: + predictor (Any): The predictor object. + model (Any): The model object. + trainer (Any): The trainer object. + task (str): The type of model task. + ckpt (Any): The checkpoint object if the model loaded from *.pt file. + cfg (str): The model configuration if loaded from *.yaml file. + ckpt_path (str): The checkpoint file path. + overrides (dict): Overrides for the trainer object. + metrics_data (Any): The data for metrics. + + Methods: + __call__(source=None, stream=False, **kwargs): + Alias for the predict method. + _new(cfg:str, verbose:bool=True) -> None: + Initializes a new model and infers the task type from the model definitions. + _load(weights:str, task:str='') -> None: + Initializes a new model and infers the task type from the model head. + _check_is_pytorch_model() -> None: + Raises TypeError if the model is not a PyTorch model. + reset() -> None: + Resets the model modules. + info(verbose:bool=False) -> None: + Logs the model info. + fuse() -> None: + Fuses the model for faster inference. + predict(source=None, stream=False, **kwargs) -> List[ultralytics.yolo.engine.results.Results]: + Performs prediction using the YOLO model. + + Returns: + list[ultralytics.yolo.engine.results.Results]: The prediction results. + """ - def __init__(self, model='yolov8n.pt', type='v8') -> None: + def __init__(self, model='yolov8n.pt') -> None: """ Initializes the YOLO model. Args: model (str, Path): model to load or create - type (str): Type/version of models to use. Defaults to "v8". """ self._reset_callbacks() - self.type = type - self.ModelClass = None # model class - self.TrainerClass = None # trainer class - self.ValidatorClass = None # validator class - self.PredictorClass = None # predictor class self.predictor = None # reuse predictor self.model = None # model object self.trainer = None # trainer object @@ -101,6 +97,10 @@ class YOLO: def __call__(self, source=None, stream=False, **kwargs): return self.predict(source, stream, **kwargs) + def __getattr__(self, attr): + name = self.__class__.__name__ + raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") + def _new(self, cfg: str, verbose=True): """ Initializes a new model and infers the task type from the model definitions. 
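The model.py hunks here replace the old `MODEL_MAP` of string literals (which were `eval()`'d after substituting a `TYPE` placeholder with `self.type`) with a `TASK_MAP` that stores the classes themselves, which is what lets `_assign_ops_from_task` and the `type='v8'` constructor argument be deleted below. A minimal, self-contained sketch of the new dispatch pattern, using hypothetical stand-in classes rather than the real Ultralytics ones:

```python
# Sketch of the class-based TASK_MAP dispatch adopted in this diff.
# The stand-in classes below are placeholders for illustration only.
class DetectionModel: pass
class DetectionTrainer: pass
class DetectionValidator: pass
class DetectionPredictor: pass

TASK_MAP = {
    'detect': [DetectionModel, DetectionTrainer, DetectionValidator, DetectionPredictor],
}

task = 'detect'
model = TASK_MAP[task][0]()      # index 0: model class, as used in _new()
predictor = TASK_MAP[task][3]()  # index 3: predictor class, as used in predict()
```

Storing classes directly avoids `eval()` on assembled strings and makes the mapping importable and statically checkable.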
@@ -112,11 +112,15 @@ class YOLO: self.cfg = check_yaml(cfg) # check YAML cfg_dict = yaml_load(self.cfg, append_filename=True) # model dict self.task = guess_model_task(cfg_dict) - self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task() - self.model = self.ModelClass(cfg_dict, verbose=verbose and RANK == -1) # initialize + self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model self.overrides['model'] = self.cfg - def _load(self, weights: str): + # Below added to allow export from yamls + args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args + self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model + self.model.task = self.task + + def _load(self, weights: str, task=''): """ Initializes a new model and infers the task type from the model head. @@ -127,8 +131,7 @@ class YOLO: if suffix == '.pt': self.model, self.ckpt = attempt_load_one_weight(weights) self.task = self.model.args['task'] - self.overrides = self.model.args - self._reset_ckpt_args(self.overrides) + self.overrides = self.model.args = self._reset_ckpt_args(self.model.args) self.ckpt_path = self.model.pt_path else: weights = check_file(weights) @@ -136,7 +139,6 @@ class YOLO: self.task = guess_model_task(weights) self.ckpt_path = weights self.overrides['model'] = weights - self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task() def _check_is_pytorch_model(self): """ @@ -189,12 +191,13 @@ class YOLO: """ overrides = self.overrides.copy() overrides['conf'] = 0.25 - overrides.update(kwargs) + overrides.update(kwargs) # prefer kwargs overrides['mode'] = kwargs.get('mode', 'predict') assert overrides['mode'] in ['track', 'predict'] overrides['save'] = kwargs.get('save', False) # not save files by default if not self.predictor: - self.predictor = self.PredictorClass(overrides=overrides) + self.task = overrides.get('task') or self.task + self.predictor = TASK_MAP[self.task][3](overrides=overrides) self.predictor.setup_model(model=self.model) else: # only update args if predictor is already setup self.predictor.args = get_cfg(self.predictor.args, overrides) @@ -226,12 +229,15 @@ class YOLO: overrides['mode'] = 'val' args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) args.data = data or args.data - args.task = self.task + if 'task' in overrides: + self.task = args.task + else: + args.task = self.task if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)): args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed args.imgsz = check_imgsz(args.imgsz, max_dim=1) - validator = self.ValidatorClass(args=args) + validator = TASK_MAP[self.task][2](args=args) validator(model=self.model) self.metrics_data = validator.metrics @@ -267,8 +273,7 @@ class YOLO: args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed if args.batch == DEFAULT_CFG.batch: args.batch = 1 # default to 1 if not modified - exporter = Exporter(overrides=args) - return exporter(model=self.model) + return Exporter(overrides=args)(model=self.model) def train(self, **kwargs): """ @@ -282,15 +287,15 @@ class YOLO: overrides.update(kwargs) if kwargs.get('cfg'): LOGGER.info(f"cfg file passed. 
Overriding default params with {kwargs['cfg']}.") - overrides = yaml_load(check_yaml(kwargs['cfg']), append_filename=True) - overrides['task'] = self.task + overrides = yaml_load(check_yaml(kwargs['cfg'])) overrides['mode'] = 'train' if not overrides.get('data'): raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'") if overrides.get('resume'): overrides['resume'] = self.ckpt_path - self.trainer = self.TrainerClass(overrides=overrides) + self.task = overrides.get('task') or self.task + self.trainer = TASK_MAP[self.task][1](overrides=overrides) if not overrides.get('resume'): # manually set model only if not resuming self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml) self.model = self.trainer.model @@ -311,13 +316,6 @@ class YOLO: self._check_is_pytorch_model() self.model.to(device) - def _assign_ops_from_task(self): - model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task] - trainer_class = eval(train_lit.replace('TYPE', f'{self.type}')) - validator_class = eval(val_lit.replace('TYPE', f'{self.type}')) - predictor_class = eval(pred_lit.replace('TYPE', f'{self.type}')) - return model_class, trainer_class, validator_class, predictor_class - @property def names(self): """ @@ -357,9 +355,8 @@ class YOLO: @staticmethod def _reset_ckpt_args(args): - for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \ - 'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset', 'simplify': - args.pop(arg, None) + include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model + return {k: v for k, v in args.items() if k in include} @staticmethod def _reset_callbacks(): diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index f318cac..2b56421 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -108,7 +108,6 @@ class BasePredictor: def postprocess(self, preds, img, orig_img): return preds - @smart_inference_mode() def __call__(self, source=None, model=None, stream=False): if stream: return self.stream_inference(source, model) @@ -136,6 +135,7 @@ class BasePredictor: self.source_type = self.dataset.source_type self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs + @smart_inference_mode() def stream_inference(self, source=None, model=None): if self.args.verbose: LOGGER.info('') @@ -161,12 +161,14 @@ class BasePredictor: self.batch = batch path, im, im0s, vid_cap, s = batch visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False + + # preprocess with self.dt[0]: im = self.preprocess(im) if len(im.shape) == 3: im = im[None] # expand for batch dim - # Inference + # inference with self.dt[1]: preds = self.model(im, augment=self.args.augment, visualize=visualize) diff --git a/ultralytics/yolo/engine/results.py b/ultralytics/yolo/engine/results.py index f60f2cc..a39e5c4 100644 --- a/ultralytics/yolo/engine/results.py +++ b/ultralytics/yolo/engine/results.py @@ -18,29 +18,33 @@ from ultralytics.yolo.utils.plotting import Annotator, colors class Results: """ - A class for storing and manipulating inference results. + A class for storing and manipulating inference results. - Args: - boxes (Boxes, optional): A Boxes object containing the detection bounding boxes. - masks (Masks, optional): A Masks object containing the detection masks. 
- probs (torch.Tensor, optional): A tensor containing the detection class probabilities. - orig_img (tuple, optional): Original image size. - - Attributes: - boxes (Boxes, optional): A Boxes object containing the detection bounding boxes. - masks (Masks, optional): A Masks object containing the detection masks. - probs (torch.Tensor, optional): A tensor containing the detection class probabilities. - orig_img (tuple, optional): Original image size. - data (torch.Tensor): The raw masks tensor + Args: + orig_img (numpy.ndarray): The original image as a numpy array. + path (str): The path to the image file. + names (List[str]): A list of class names. + boxes (List[List[float]], optional): A list of bounding box coordinates for each detection. + masks (numpy.ndarray, optional): A 3D numpy array of detection masks, where each mask is a binary image. + probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class. - """ + Attributes: + orig_img (numpy.ndarray): The original image as a numpy array. + orig_shape (tuple): The original image shape in (height, width) format. + boxes (Boxes, optional): A Boxes object containing the detection bounding boxes. + masks (Masks, optional): A Masks object containing the detection masks. + probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class. + names (List[str]): A list of class names. + path (str): The path to the image file. + _keys (tuple): A tuple of attribute names for non-empty attributes. + """ def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None) -> None: self.orig_img = orig_img self.orig_shape = orig_img.shape[:2] - self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes - self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks - self.probs = probs if probs is not None else None + self.boxes = Boxes(boxes.cpu(), self.orig_shape) if boxes is not None else None # native size boxes + self.masks = Masks(masks.cpu(), self.orig_shape) if masks is not None else None # native size or imgsz masks + self.probs = probs.cpu() if probs is not None else None self.names = names self.path = path self._keys = (k for k in ('boxes', 'masks', 'probs') if getattr(self, k) is not None) @@ -99,24 +103,22 @@ class Results: def __getattr__(self, attr): name = self.__class__.__name__ - raise AttributeError(f""" - '{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are: - - Attributes: - boxes (Boxes, optional): A Boxes object containing the detection bounding boxes. - masks (Masks, optional): A Masks object containing the detection masks. - probs (torch.Tensor, optional): A tensor containing the detection class probabilities. - orig_shape (tuple, optional): Original image size. - """) + raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") def plot(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'): """ - Plots the given result on an input RGB image. Accepts cv2(numpy) or PIL Image + Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image. Args: - show_conf (bool): Show confidence - line_width (Float): The line width of boxes. Automatically scaled to img size if not provided - font_size (Float): The font size of . 
Automatically scaled to img size if not provided + show_conf (bool): Whether to show the detection confidence score. + line_width (float, optional): The line width of the bounding boxes. If None, it is automatically scaled to the image size. + font_size (float, optional): The font size of the text. If None, it is automatically scaled to the image size. + font (str): The font to use for the text. + pil (bool): Whether to return the image as a PIL Image. + example (str): An example string to display in the plot. Useful for indicating the expected format of the output. + + Returns: + None or PIL Image: If `pil` is True, the image will be returned as a PIL Image. Otherwise, nothing is returned. """ img = deepcopy(self.orig_img) annotator = Annotator(img, line_width, font_size, font, pil, example) @@ -157,15 +159,24 @@ class Boxes: boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes, with shape (num_boxes, 6). orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width). + is_track (bool): True if the boxes also include track IDs, False otherwise. Properties: xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format. conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes. cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes. + id (torch.Tensor) or (numpy.ndarray): The track IDs of the boxes (if available). xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format. xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size. xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size. data (torch.Tensor): The raw bboxes tensor + + Methods: + cpu(): Move the object to CPU memory. + numpy(): Convert the object to a numpy array. + cuda(): Move the object to CUDA memory. + to(*args, **kwargs): Move the object to the specified device. + pandas(): Convert the object to a pandas DataFrame (not yet implemented). """ def __init__(self, boxes, orig_shape) -> None: @@ -257,22 +268,7 @@ class Boxes: def __getattr__(self, attr): name = self.__class__.__name__ - raise AttributeError(f""" - '{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are: - - Attributes: - boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes, - with shape (num_boxes, 6). - orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width). - - Properties: - xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format. - conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes. - cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes. - xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format. - xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size. - xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size. - """) + raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") class Masks: @@ -288,7 +284,18 @@ class Masks: orig_shape (tuple): Original image size, in the format (height, width). Properties: - segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks. + segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection masks. 
+ + Methods: + cpu(): Returns a copy of the masks tensor on CPU memory. + numpy(): Returns a copy of the masks tensor as a numpy array. + cuda(): Returns a copy of the masks tensor on GPU memory. + to(): Returns a copy of the masks tensor with the specified device and dtype. + __len__(): Returns the number of masks in the tensor. + __str__(): Returns a string representation of the masks tensor. + __repr__(): Returns a detailed string representation of the masks tensor. + __getitem__(): Returns a new Masks object with the masks at the specified index. + __getattr__(): Raises an AttributeError with a list of valid attributes and properties. """ def __init__(self, masks, orig_shape) -> None: @@ -337,13 +344,4 @@ class Masks: def __getattr__(self, attr): name = self.__class__.__name__ - raise AttributeError(f""" - '{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are: - - Attributes: - masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width). - orig_shape (tuple): Original image size, in the format (height, width). - - Properties: - segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks. - """) + raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 9573471..1dc38ff 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -44,7 +44,6 @@ class BaseTrainer: Attributes: args (SimpleNamespace): Configuration for the trainer. check_resume (method): Method to check if training should be resumed from a saved checkpoint. - console (logging.Logger): Logger instance. validator (BaseValidator): Validator instance. model (nn.Module): Model instance. callbacks (defaultdict): Dictionary of callbacks. 
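The results.py hunks above give `Results`, `Boxes` and `Masks` one shared `__getattr__` that embeds the class docstring in the error message, replacing three hand-maintained attribute listings. A short runnable sketch of the pattern, with a hypothetical trimmed-down class:

```python
# Sketch of the unified __getattr__ error pattern from this diff: unknown
# attribute access raises an AttributeError that appends the class docstring,
# so the valid-attribute listing can never drift out of sync with the docs.
class Boxes:
    """
    Attributes:
        xyxy: the boxes in xyxy format
        conf: the confidence values of the boxes
    """

    def __getattr__(self, attr):
        name = self.__class__.__name__
        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

try:
    Boxes().xyzzy
except AttributeError as e:
    print(e)  # the message ends with the docstring listing the valid attributes
```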
@@ -84,7 +83,6 @@ class BaseTrainer: self.args = get_cfg(cfg, overrides) self.device = select_device(self.args.device, self.args.batch) self.check_resume() - self.console = LOGGER self.validator = None self.model = None self.metrics = None @@ -180,11 +178,12 @@ class BaseTrainer: if world_size > 1 and 'LOCAL_RANK' not in os.environ: cmd, file = generate_ddp_command(world_size, self) # security vulnerability in Snyk scans try: + LOGGER.info(f'Running DDP command {cmd}') subprocess.run(cmd, check=True) except Exception as e: - self.console.warning(e) + LOGGER.warning(e) finally: - ddp_cleanup(self, file) + ddp_cleanup(self, str(file)) else: self._do_train(RANK, world_size) @@ -193,7 +192,7 @@ class BaseTrainer: # os.environ['MASTER_PORT'] = '9020' torch.cuda.set_device(rank) self.device = torch.device('cuda', rank) - self.console.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}') + LOGGER.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}') dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size) def _setup_train(self, rank, world_size): @@ -262,10 +261,10 @@ class BaseTrainer: nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations last_opt_step = -1 self.run_callbacks('on_train_start') - self.log(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n' - f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n' - f"Logging results to {colorstr('bold', self.save_dir)}\n" - f'Starting training for {self.epochs} epochs...') + LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n' + f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n' + f"Logging results to {colorstr('bold', self.save_dir)}\n" + f'Starting training for {self.epochs} epochs...') if self.args.close_mosaic: base_idx = (self.epochs - self.args.close_mosaic) * nb self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2]) @@ -278,14 +277,14 @@ class BaseTrainer: pbar = enumerate(self.train_loader) # Update dataloader attributes (optional) if epoch == (self.epochs - self.args.close_mosaic): - self.console.info('Closing dataloader mosaic') + LOGGER.info('Closing dataloader mosaic') if hasattr(self.train_loader.dataset, 'mosaic'): self.train_loader.dataset.mosaic = False if hasattr(self.train_loader.dataset, 'close_mosaic'): self.train_loader.dataset.close_mosaic(hyp=self.args) if rank in {-1, 0}: - self.console.info(self.progress_string()) + LOGGER.info(self.progress_string()) pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT) self.tloss = None self.optimizer.zero_grad() @@ -372,12 +371,11 @@ class BaseTrainer: if rank in {-1, 0}: # Do final val with best.pt - self.log(f'\n{epoch - self.start_epoch + 1} epochs completed in ' - f'{(time.time() - self.train_time_start) / 3600:.3f} hours.') + LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in ' + f'{(time.time() - self.train_time_start) / 3600:.3f} hours.') self.final_eval() if self.args.plots: self.plot_metrics() - self.log(f"Results saved to {colorstr('bold', self.save_dir)}") self.run_callbacks('on_train_end') torch.cuda.empty_cache() self.run_callbacks('teardown') @@ -450,18 +448,6 @@ class BaseTrainer: self.best_fitness = fitness return metrics, fitness - def log(self, text, rank=-1): - """ - Logs the given text to given ranks process if provided, otherwise logs to all ranks. 
- - Args" - text (str): text to log - rank (List[Int]): process rank - - """ - if rank in {-1, 0}: - self.console.info(text) - def get_model(self, cfg=None, weights=None, verbose=True): raise NotImplementedError("This task trainer doesn't support loading cfg files") @@ -521,7 +507,7 @@ class BaseTrainer: if f.exists(): strip_optimizer(f) # strip optimizers if f is self.best: - self.console.info(f'\nValidating {f}...') + LOGGER.info(f'\nValidating {f}...') self.metrics = self.validator(model=f) self.metrics.pop('fitness', None) self.run_callbacks('on_fit_epoch_end') @@ -564,7 +550,7 @@ class BaseTrainer: self.best_fitness = best_fitness self.start_epoch = start_epoch if start_epoch > (self.epochs - self.args.close_mosaic): - self.console.info('Closing dataloader mosaic') + LOGGER.info('Closing dataloader mosaic') if hasattr(self.train_loader.dataset, 'mosaic'): self.train_loader.dataset.mosaic = False if hasattr(self.train_loader.dataset, 'close_mosaic'): diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index b2a5394..3f1a5ec 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -44,7 +44,6 @@ class BaseValidator: Attributes: dataloader (DataLoader): Dataloader to use for validation. pbar (tqdm): Progress bar to update during validation. - logger (logging.Logger): Logger to use for validation. args (SimpleNamespace): Configuration for the validator. model (nn.Module): Model to validate. data (dict): Data dictionary. @@ -56,7 +55,7 @@ class BaseValidator: save_dir (Path): Directory to save results. """ - def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None): + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None): """ Initializes a BaseValidator instance. 
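The validator hunks below change `self.speed` from a positional tuple to a dict keyed by pipeline stage, which is what lets the HUB, ClearML and Comet callbacks elsewhere in this diff read `speed['inference']` instead of the brittle `speed[1]`. A hedged sketch of the new shape, with made-up timing numbers:

```python
# Per-image speeds as a stage-keyed dict (illustrative numbers only).
totals = {'preprocess': 0.13, 'inference': 0.96, 'loss': 0.0, 'postprocess': 0.19}  # seconds over a split
num_images = 8
speed = {k: t / num_images * 1e3 for k, t in totals.items()}  # ms per image per stage

# Downstream consumers now index by name...
inference_ms = round(speed['inference'], 3)

# ...and the summary line formats the ordered values, as in BaseValidator:
print('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image'
      % tuple(speed.values()))
```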
diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py
index b2a5394..3f1a5ec 100644
--- a/ultralytics/yolo/engine/validator.py
+++ b/ultralytics/yolo/engine/validator.py
@@ -44,7 +44,6 @@ class BaseValidator:
     Attributes:
         dataloader (DataLoader): Dataloader to use for validation.
         pbar (tqdm): Progress bar to update during validation.
-        logger (logging.Logger): Logger to use for validation.
         args (SimpleNamespace): Configuration for the validator.
         model (nn.Module): Model to validate.
         data (dict): Data dictionary.
@@ -56,7 +55,7 @@ class BaseValidator:
         save_dir (Path): Directory to save results.
     """
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
         """
         Initializes a BaseValidator instance.
@@ -69,14 +68,13 @@ class BaseValidator:
         """
         self.dataloader = dataloader
         self.pbar = pbar
-        self.logger = logger or LOGGER
         self.args = args or get_cfg(DEFAULT_CFG)
         self.model = None
         self.data = None
         self.device = None
         self.batch_i = None
         self.training = True
-        self.speed = None
+        self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
         self.jdict = None
 
         project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
@@ -123,14 +121,14 @@ class BaseValidator:
             self.device = model.device
             if not pt and not jit:
                 self.args.batch = 1  # export.py models default to batch-size 1
-                self.logger.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
+                LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
 
             if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
                 self.data = check_det_dataset(self.args.data)
             elif self.args.task == 'classify':
                 self.data = check_cls_dataset(self.args.data)
             else:
-                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌"))
+                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
 
             if self.device.type == 'cpu':
                 self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
@@ -179,7 +177,7 @@ class BaseValidator:
         stats = self.get_stats()
         self.check_stats(stats)
         self.print_results()
-        self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt)  # speeds per image
+        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
         self.finalize_metrics()
         self.run_callbacks('on_val_end')
         if self.training:
@@ -187,11 +185,11 @@
             results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
             return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
         else:
-            self.logger.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
-                             self.speed)
+            LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
+                        tuple(self.speed.values()))
             if self.args.save_json and self.jdict:
                 with open(str(self.save_dir / 'predictions.json'), 'w') as f:
-                    self.logger.info(f'Saving {f.name}...')
+                    LOGGER.info(f'Saving {f.name}...')
                     json.dump(self.jdict, f)  # flatten and save
                 stats = self.eval_json(stats)  # update stats
             if self.args.plots or self.args.save_json:
diff --git a/ultralytics/yolo/utils/benchmarks.py b/ultralytics/yolo/utils/benchmarks.py
index eb0c440..e0f4e92 100644
--- a/ultralytics/yolo/utils/benchmarks.py
+++ b/ultralytics/yolo/utils/benchmarks.py
@@ -60,12 +60,12 @@ def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
 
             # Export
             if format == '-':
-                filename = model.ckpt_path
+                filename = model.ckpt_path or model.cfg
                 export = model  # PyTorch format
             else:
                 filename = model.export(imgsz=imgsz, format=format, half=half, device=device)  # all others
                 export = YOLO(filename)
-                assert suffix in str(filename), 'export failed'
+            assert suffix in str(filename), 'export failed'
 
             # Predict
             if not (ROOT / 'assets/bus.jpg').exists():
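The validator change above replaces the positional `speed` tuple with a dict keyed by pipeline stage, so downstream consumers (the ClearML and Comet callbacks below) can read `speed['inference']` instead of the magic index `speed[1]`. A rough sketch of the timing flow under that change; `Profiler` is a stand-in for the real profiling helper, not an Ultralytics API:

```python
# Sketch of the keyed per-image timing the validator now produces (assumed names).
import time


class Profiler:
    """Tiny context-manager timer that accumulates elapsed seconds in .t."""

    def __init__(self):
        self.t = 0.0

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        self.t += time.perf_counter() - self.start


speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
dt = [Profiler() for _ in speed]  # one timer per stage
n = 8  # images in the (toy) dataset

for _ in range(n):
    for p in dt:
        with p:
            time.sleep(0.001)  # stand-in for real per-stage work

# Same zip trick as the diff: milliseconds per image, keyed by stage name
speed = dict(zip(speed.keys(), (p.t / n * 1E3 for p in dt)))
print('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
      tuple(speed.values()))
print(round(speed['inference'], 3))  # what the logger callbacks now read by name
```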
diff --git a/ultralytics/yolo/utils/callbacks/clearml.py b/ultralytics/yolo/utils/callbacks/clearml.py
index bd2d5e2..533da4a 100644
--- a/ultralytics/yolo/utils/callbacks/clearml.py
+++ b/ultralytics/yolo/utils/callbacks/clearml.py
@@ -29,7 +29,7 @@ def on_pretrain_routine_start(trainer):
                          auto_connect_frameworks={'pytorch': False})
         task.connect(vars(trainer.args), name='General')
     except Exception as e:
-        LOGGER.warning(f'WARNING ⚠️ ClearML not initialized correctly, not logging this run. {e}')
+        LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')
 
 
 def on_train_epoch_end(trainer):
@@ -41,9 +41,9 @@
 def on_fit_epoch_end(trainer):
     task = Task.current_task()
     if task and trainer.epoch == 0:
         model_info = {
-            'Parameters': get_num_params(trainer.model),
-            'GFLOPs': round(get_flops(trainer.model), 3),
-            'Inference speed (ms/img)': round(trainer.validator.speed[1], 3)}
+            'model/parameters': get_num_params(trainer.model),
+            'model/GFLOPs': round(get_flops(trainer.model), 3),
+            'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
         task.connect(model_info, name='Model')
diff --git a/ultralytics/yolo/utils/callbacks/comet.py b/ultralytics/yolo/utils/callbacks/comet.py
index d33871c..81c6afe 100644
--- a/ultralytics/yolo/utils/callbacks/comet.py
+++ b/ultralytics/yolo/utils/callbacks/comet.py
@@ -16,7 +16,7 @@ def on_pretrain_routine_start(trainer):
         experiment = comet_ml.Experiment(project_name=trainer.args.project or 'YOLOv8')
         experiment.log_parameters(vars(trainer.args))
     except Exception as e:
-        LOGGER.warning(f'WARNING ⚠️ Comet not initialized correctly, not logging this run. {e}')
+        LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')
 
 
 def on_train_epoch_end(trainer):
@@ -36,7 +36,7 @@ def on_fit_epoch_end(trainer):
         model_info = {
             'model/parameters': get_num_params(trainer.model),
             'model/GFLOPs': round(get_flops(trainer.model), 3),
-            'model/speed(ms)': round(trainer.validator.speed[1], 3)}
+            'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
         experiment.log_metrics(model_info, step=trainer.epoch + 1)
diff --git a/ultralytics/yolo/utils/callbacks/tensorboard.py b/ultralytics/yolo/utils/callbacks/tensorboard.py
index aafd1b8..612409a 100644
--- a/ultralytics/yolo/utils/callbacks/tensorboard.py
+++ b/ultralytics/yolo/utils/callbacks/tensorboard.py
@@ -2,17 +2,24 @@
 
 from torch.utils.tensorboard import SummaryWriter
 
+from ultralytics.yolo.utils import LOGGER
+
 writer = None  # TensorBoard SummaryWriter instance
 
 
 def _log_scalars(scalars, step=0):
-    for k, v in scalars.items():
-        writer.add_scalar(k, v, step)
+    if writer:
+        for k, v in scalars.items():
+            writer.add_scalar(k, v, step)
 
 
 def on_pretrain_routine_start(trainer):
     global writer
-    writer = SummaryWriter(str(trainer.save_dir))
+    try:
+        writer = SummaryWriter(str(trainer.save_dir))
+    except Exception as e:
+        writer = None  # TensorBoard SummaryWriter instance
+        LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
 
 
 def on_batch_end(trainer):
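The TensorBoard change is a degrade-gracefully pattern: initialization can fail at import/setup time, so the writer is created inside try/except and every logging site gates on the handle. A sketch with a generic writer; `FlakyWriter` is a hypothetical stand-in, not `SummaryWriter`:

```python
# Sketch of the guarded-writer pattern the tensorboard callback adopts.
writer = None  # module-level handle; may legitimately stay None


class FlakyWriter:
    """Stand-in writer whose construction fails, like a broken TB install."""

    def __init__(self):
        raise OSError('simulated init failure')


def on_pretrain_routine_start(save_dir):
    global writer
    try:
        writer = FlakyWriter()
    except Exception as e:
        writer = None
        print(f'WARNING: writer not initialized correctly, not logging this run. {e}')


def log_scalars(scalars, step=0):
    if writer:  # every call site checks the guard, so training never crashes
        for k, v in scalars.items():
            writer.add_scalar(k, v, step)


on_pretrain_routine_start('runs/demo')
log_scalars({'loss': 0.5}, step=1)  # silently skipped: writer is None
```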
diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py
index aaf6f25..0ee5952 100644
--- a/ultralytics/yolo/utils/checks.py
+++ b/ultralytics/yolo/utils/checks.py
@@ -254,7 +254,7 @@ def check_file(file, suffix='', download=True):
         return file
     else:  # search
         files = []
-        for d in 'models', 'datasets', 'tracker/cfg':  # search directories
+        for d in 'models', 'datasets', 'tracker/cfg', 'yolo/cfg':  # search directories
             files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file
         if not files:
             raise FileNotFoundError(f"'{file}' does not exist")
diff --git a/ultralytics/yolo/utils/dist.py b/ultralytics/yolo/utils/dist.py
index 2be30a8..9c91540 100644
--- a/ultralytics/yolo/utils/dist.py
+++ b/ultralytics/yolo/utils/dist.py
@@ -51,10 +51,9 @@ def generate_ddp_command(world_size, trainer):
     file = generate_ddp_file(trainer)  # if argv[0].endswith('yolo') else os.path.abspath(argv[0])
 
     # Build command
-    torch_distributed_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
-    cmd = [
-        sys.executable, '-m', torch_distributed_cmd, '--nproc_per_node', f'{world_size}', '--master_port',
-        f'{find_free_network_port()}', file] + args
+    dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
+    port = find_free_network_port()
+    cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
     return cmd, file
diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py
index 45380ee..53770b8 100644
--- a/ultralytics/yolo/utils/downloads.py
+++ b/ultralytics/yolo/utils/downloads.py
@@ -12,7 +12,7 @@ import requests
 import torch
 from tqdm import tqdm
 
-from ultralytics.yolo.utils import LOGGER
+from ultralytics.yolo.utils import LOGGER, checks
 
 GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \
                      [f'yolov5{size}u.pt' for size in 'nsmlx'] + \
@@ -87,7 +87,7 @@
         try:
             if curl or i > 0:  # curl download with retry, continue
                 s = 'sS' * (not progress)  # silent
-                r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '9', '-C', '-']).returncode
+                r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '3', '-C', '-']).returncode
                 assert r == 0, f'Curl return value {r}'
             else:  # urllib download
                 method = 'torch'
@@ -112,8 +112,10 @@
                 break  # success
             f.unlink()  # remove partial downloads
         except Exception as e:
-            if i >= retry:
-                raise ConnectionError(f'❌ Download failure for {url}') from e
+            if i == 0 and not checks.check_online():
+                raise ConnectionError(f'❌ Download failure for {url}. Environment is not online.') from e
+            elif i >= retry:
+                raise ConnectionError(f'❌ Download failure for {url}. Retry limit reached.') from e
             LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
 
     if unzip and f.exists() and f.suffix in {'.zip', '.tar', '.gz'}:
diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py
index 9b8d497..ec03d1c 100644
--- a/ultralytics/yolo/v8/classify/train.py
+++ b/ultralytics/yolo/v8/classify/train.py
@@ -7,7 +7,7 @@ from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
 from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_classification_dataloader
 from ultralytics.yolo.engine.trainer import BaseTrainer
-from ultralytics.yolo.utils import DEFAULT_CFG, RANK
+from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
 from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer
 
 
@@ -64,6 +64,7 @@ class ClassificationTrainer(BaseTrainer):
             self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
         else:
             FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
+        ClassificationModel.reshape_outputs(self.model, self.data['nc'])
 
         return  # dont return ckpt. Classification doesn't support resume
 
@@ -93,7 +94,7 @@ class ClassificationTrainer(BaseTrainer):
 
     def get_validator(self):
         self.loss_names = ['loss']
-        return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
+        return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
 
     def criterion(self, preds, batch):
         loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
@@ -132,11 +133,12 @@ class ClassificationTrainer(BaseTrainer):
                 strip_optimizer(f)  # strip optimizers
                 # TODO: validate best.pt after training completes
                 # if f is self.best:
-                #     self.console.info(f'\nValidating {f}...')
+                #     LOGGER.info(f'\nValidating {f}...')
                 #     self.validator.args.save_json = True
                 #     self.metrics = self.validator(model=f)
                 #     self.metrics.pop('fitness', None)
                 #     self.run_callbacks('on_fit_epoch_end')
+        LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
 
 
 def train(cfg=DEFAULT_CFG, use_python=False):
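The downloads change implements a fail-fast retry policy: abort immediately with a clear message when the environment is offline (first failure plus no connectivity), otherwise retry up to the limit. A sketch under those assumptions; `check_online` and `fetch` are hypothetical stand-ins for the real helpers:

```python
# Sketch of the retry policy in safe_download (assumed helper names).
import socket


def check_online():
    """Best-effort connectivity probe; assumes TCP reachability is enough."""
    try:
        socket.create_connection(('1.1.1.1', 443), timeout=2).close()
        return True
    except OSError:
        return False


def fetch(url):
    raise IOError('simulated network error')  # stand-in for the real download


def safe_download(url, retry=3):
    for i in range(retry + 1):
        try:
            return fetch(url)
        except Exception as e:
            if i == 0 and not check_online():
                # Offline: retrying is pointless, so fail fast with the reason
                raise ConnectionError(f'Download failure for {url}. Environment is not online.') from e
            elif i >= retry:
                raise ConnectionError(f'Download failure for {url}. Retry limit reached.') from e
            print(f'Download failure, retrying {i + 1}/{retry} {url}...')


try:
    safe_download('https://example.com/weights.pt')
except ConnectionError as e:
    print(e)
```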
diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py
index 01cf309..30fd621 100644
--- a/ultralytics/yolo/v8/classify/val.py
+++ b/ultralytics/yolo/v8/classify/val.py
@@ -2,14 +2,14 @@
 
 from ultralytics.yolo.data import build_classification_dataloader
 from ultralytics.yolo.engine.validator import BaseValidator
-from ultralytics.yolo.utils import DEFAULT_CFG
+from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER
 from ultralytics.yolo.utils.metrics import ClassifyMetrics
 
 
 class ClassificationValidator(BaseValidator):
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
-        super().__init__(dataloader, save_dir, pbar, logger, args)
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
+        super().__init__(dataloader, save_dir, pbar, args)
         self.args.task = 'classify'
         self.metrics = ClassifyMetrics()
 
@@ -31,7 +31,7 @@ class ClassificationValidator(BaseValidator):
         self.targets.append(batch['cls'])
 
     def finalize_metrics(self, *args, **kwargs):
-        self.metrics.speed = dict(zip(self.metrics.speed.keys(), self.speed))
+        self.metrics.speed = self.speed
 
     def get_stats(self):
         self.metrics.process(self.targets, self.pred)
@@ -45,7 +45,7 @@ class ClassificationValidator(BaseValidator):
 
     def print_results(self):
         pf = '%22s' + '%11.3g' * len(self.metrics.keys)  # print format
-        self.logger.info(pf % ('all', self.metrics.top1, self.metrics.top5))
+        LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
 
 
 def val(cfg=DEFAULT_CFG, use_python=False):
diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py
index 874d489..98ab7ae 100644
--- a/ultralytics/yolo/v8/detect/train.py
+++ b/ultralytics/yolo/v8/detect/train.py
@@ -66,10 +66,7 @@ class DetectionTrainer(BaseTrainer):
 
     def get_validator(self):
         self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
-        return v8.detect.DetectionValidator(self.test_loader,
-                                            save_dir=self.save_dir,
-                                            logger=self.console,
-                                            args=copy(self.args))
+        return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
 
     def criterion(self, preds, batch):
         if not hasattr(self, 'compute_loss'):
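The validators build their console tables with printf-style format strings: `'%22s' + '%11.3g' * len(keys)` right-aligns one column per metric. A small illustration with made-up numbers, not real model results:

```python
# Classification table: one string column plus one %11.3g column per metric
keys = ['top1', 'top5']
pf = '%22s' + '%11.3g' * len(keys)  # print format
print(pf % ('all', 0.713, 0.914))

# Detection adds two integer columns for image and instance counts
pf_det = '%22s' + '%11i' * 2 + '%11.3g' * 4
print(pf_det % ('all', 128, 929, 0.64, 0.537, 0.605, 0.446))
```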
diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py
index 974ed76..116c630 100644
--- a/ultralytics/yolo/v8/detect/val.py
+++ b/ultralytics/yolo/v8/detect/val.py
@@ -9,7 +9,7 @@ import torch
 from ultralytics.yolo.data import build_dataloader
 from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
 from ultralytics.yolo.engine.validator import BaseValidator
-from ultralytics.yolo.utils import DEFAULT_CFG, colorstr, ops
+from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, colorstr, ops
 from ultralytics.yolo.utils.checks import check_requirements
 from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
 from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@@ -18,8 +18,8 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
 
 class DetectionValidator(BaseValidator):
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
-        super().__init__(dataloader, save_dir, pbar, logger, args)
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
+        super().__init__(dataloader, save_dir, pbar, args)
         self.args.task = 'detect'
         self.is_coco = False
         self.class_map = None
@@ -112,7 +112,7 @@ class DetectionValidator(BaseValidator):
             # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
 
     def finalize_metrics(self, *args, **kwargs):
-        self.metrics.speed = dict(zip(self.metrics.speed.keys(), self.speed))
+        self.metrics.speed = self.speed
 
     def get_stats(self):
         stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)]  # to numpy
@@ -123,15 +123,15 @@
 
     def print_results(self):
         pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys)  # print format
-        self.logger.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
+        LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
         if self.nt_per_class.sum() == 0:
-            self.logger.warning(
+            LOGGER.warning(
                 f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
 
         # Print results per class
         if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
             for i, c in enumerate(self.metrics.ap_class_index):
-                self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
+                LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
 
         if self.args.plots:
             self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
@@ -212,7 +212,7 @@ class DetectionValidator(BaseValidator):
         if self.args.save_json and self.is_coco and len(self.jdict):
             anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations
             pred_json = self.save_dir / 'predictions.json'  # predictions
-            self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
+            LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                 check_requirements('pycocotools>=2.0.6')
                 from pycocotools.coco import COCO  # noqa
@@ -230,7 +230,7 @@ class DetectionValidator(BaseValidator):
                 eval.summarize()
                 stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2]  # update mAP50-95 and mAP50
             except Exception as e:
-                self.logger.warning(f'pycocotools unable to run: {e}')
+                LOGGER.warning(f'pycocotools unable to run: {e}')
         return stats
diff --git a/ultralytics/yolo/v8/segment/predict.py b/ultralytics/yolo/v8/segment/predict.py
index 84adde7..51110e2 100644
--- a/ultralytics/yolo/v8/segment/predict.py
+++ b/ultralytics/yolo/v8/segment/predict.py
@@ -68,11 +68,10 @@ class SegmentationPredictor(DetectionPredictor):
             log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "
 
         # Mask plotting
-        self.annotator.masks(
-            mask.masks,
-            colors=[colors(x, True) for x in det.cls],
-            im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(self.device).permute(2, 0, 1).flip(0).contiguous() /
-            255 if self.args.retina_masks else im[idx])
+        if self.args.save or self.args.show:
+            im_gpu = torch.as_tensor(im0, dtype=torch.float16, device=mask.masks.device).permute(
+                2, 0, 1).flip(0).contiguous() / 255 if self.args.retina_masks else im[idx]
+            self.annotator.masks(masks=mask.masks, colors=[colors(x, True) for x in det.cls], im_gpu=im_gpu)
 
         # Write results
         for j, d in enumerate(reversed(det)):
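The predictor change above skips mask rendering entirely unless results are saved or shown, and builds `im_gpu` on the same device as the masks rather than via a separate `.to(self.device)` transfer. A torch-only sketch of that logic; `Args`, `plot_masks`, and the toy shapes are illustrative, not predictor internals:

```python
# Sketch of the gated mask-plotting logic (assumed names and shapes).
import torch


class Args:
    save, show, retina_masks = True, False, False


def plot_masks(im0, im_batch, idx, masks, args):
    if args.save or args.show:  # avoid GPU work when nobody will see the plot
        if args.retina_masks:
            # HWC uint8 BGR frame -> CHW float RGB in [0, 1], on the mask device
            im_gpu = torch.as_tensor(im0, dtype=torch.float16, device=masks.device).permute(
                2, 0, 1).flip(0).contiguous() / 255
        else:
            im_gpu = im_batch[idx]  # already-preprocessed letterboxed image
        return im_gpu  # would be handed to annotator.masks(...)
    return None


im0 = torch.zeros(480, 640, 3, dtype=torch.uint8).numpy()  # original frame
im_batch = torch.zeros(1, 3, 640, 640)                     # preprocessed batch
masks = torch.zeros(2, 640, 640)                           # predicted masks
print(plot_masks(im0, im_batch, 0, masks, Args()).shape)   # torch.Size([3, 640, 640])
```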
diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py
index cfbe74a..46b7ede 100644
--- a/ultralytics/yolo/v8/segment/train.py
+++ b/ultralytics/yolo/v8/segment/train.py
@@ -32,10 +32,7 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
 
     def get_validator(self):
         self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
-        return v8.segment.SegmentationValidator(self.test_loader,
-                                                save_dir=self.save_dir,
-                                                logger=self.console,
-                                                args=copy(self.args))
+        return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
 
     def criterion(self, preds, batch):
         if not hasattr(self, 'compute_loss'):
@@ -86,10 +83,6 @@ class SegLoss(Loss):
         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
         mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
 
-        masks = batch['masks'].to(self.device).float()
-        if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
-            masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
-
         # pboxes
         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
 
@@ -103,10 +96,15 @@ class SegLoss(Loss):
         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
         loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
 
-        # bbox loss
         if fg_mask.sum():
+            # bbox loss
             loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
                                               target_scores, target_scores_sum, fg_mask)
+            # masks loss
+            masks = batch['masks'].to(self.device).float()
+            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
+                masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
+
             for i in range(batch_size):
                 if fg_mask[i].sum():
                     mask_idx = target_gt_idx[i][fg_mask[i]]
@@ -121,9 +119,9 @@ class SegLoss(Loss):
                                                         marea)  # seg loss
             # WARNING: Uncomment lines below in case of Multi-GPU DDP unused gradient errors
             # else:
-            #     loss[1] += proto.sum() * 0
+            #     loss[1] += proto.sum() * 0 + pred_masks.sum() * 0
         # else:
-        #     loss[1] += proto.sum() * 0
+        #     loss[1] += proto.sum() * 0 + pred_masks.sum() * 0
 
         loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.box / batch_size  # seg gain
diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py
index 65ee4b6..26759f1 100644
--- a/ultralytics/yolo/v8/segment/val.py
+++ b/ultralytics/yolo/v8/segment/val.py
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 
-from ultralytics.yolo.utils import DEFAULT_CFG, NUM_THREADS, ops
+from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, NUM_THREADS, ops
 from ultralytics.yolo.utils.checks import check_requirements
 from ultralytics.yolo.utils.metrics import SegmentMetrics, box_iou, mask_iou
 from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@@ -16,8 +16,8 @@ from ultralytics.yolo.v8.detect import DetectionValidator
 
 class SegmentationValidator(DetectionValidator):
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
-        super().__init__(dataloader, save_dir, pbar, logger, args)
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
+        super().__init__(dataloader, save_dir, pbar, args)
         self.args.task = 'segment'
         self.metrics = SegmentMetrics(save_dir=self.save_dir)
 
@@ -120,7 +120,7 @@ class SegmentationValidator(DetectionValidator):
             # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
 
     def finalize_metrics(self, *args, **kwargs):
-        self.metrics.speed = dict(zip(self.metrics.speed.keys(), self.speed))
+        self.metrics.speed = self.speed
 
     def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
         """
@@ -207,7 +207,7 @@ class SegmentationValidator(DetectionValidator):
         if self.args.save_json and self.is_coco and len(self.jdict):
             anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations
             pred_json = self.save_dir / 'predictions.json'  # predictions
-            self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
+            LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                 check_requirements('pycocotools>=2.0.6')
                 from pycocotools.coco import COCO  # noqa
@@ -228,7 +228,7 @@ class SegmentationValidator(DetectionValidator):
                 stats[self.metrics.keys[idx + 1]], stats[
                     self.metrics.keys[idx]] = eval.stats[:2]  # update mAP50-95 and mAP50
             except Exception as e:
-                self.logger.warning(f'pycocotools unable to run: {e}')
+                LOGGER.warning(f'pycocotools unable to run: {e}')
         return stats
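The SegLoss refactor above moves the ground-truth mask preparation inside the `if fg_mask.sum():` guard, so the downsampling interpolation only runs when at least one anchor was assigned a foreground object. A short sketch of just that step, with illustrative toy shapes rather than real training tensors:

```python
# Sketch of the guarded GT-mask downsampling in the segmentation loss.
import torch
import torch.nn.functional as F

batch_masks = torch.rand(2, 640, 640)   # GT masks at image resolution (toy)
mask_h, mask_w = 160, 160               # prototype resolution (1/4 scale)
fg_mask = torch.tensor([True, False])   # per-sample foreground flags (toy)

if fg_mask.sum():  # skip all mask work on background-only batches
    masks = batch_masks.float()
    if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
        masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
    print(masks.shape)  # torch.Size([2, 160, 160])
```

Nearest-neighbor interpolation keeps the masks binary; the commented DDP lines in the diff exist so that, if uncommented, `proto` and `pred_masks` still contribute zero-valued gradients and avoid unused-parameter errors under Multi-GPU DDP.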