From f8e32c4c1394cded1554d1df09fbaf8e7c19ad17 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 14 Jan 2023 17:39:50 +0100 Subject: [PATCH] General `ultralytics==8.0.6` updates (#351) Co-authored-by: Dzmitry Plashchynski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/ci.yaml | 12 +++++----- docker/Dockerfile | 2 +- tests/test_engine.py | 20 ++++++++--------- ultralytics/hub/session.py | 30 ++++++++++++------------- ultralytics/yolo/configs/__init__.py | 6 ++--- ultralytics/yolo/data/utils.py | 22 ++++++++---------- ultralytics/yolo/engine/predictor.py | 2 +- ultralytics/yolo/engine/trainer.py | 8 ++++--- ultralytics/yolo/utils/__init__.py | 25 ++++++++++++--------- ultralytics/yolo/utils/callbacks/hub.py | 2 -- ultralytics/yolo/utils/metrics.py | 2 +- ultralytics/yolo/v8/classify/train.py | 7 +++--- ultralytics/yolo/v8/detect/train.py | 8 +++---- ultralytics/yolo/v8/segment/predict.py | 3 +-- ultralytics/yolo/v8/segment/train.py | 7 +++--- ultralytics/yolo/v8/segment/val.py | 3 +-- 16 files changed, 79 insertions(+), 80 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 123d9d7..c1b726a 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -84,22 +84,22 @@ jobs: - name: Test detection shell: bash # for Windows compatibility run: | - yolo task=detect mode=train model=yolov8n.yaml data=coco8.yaml epochs=1 imgsz=32 - yolo task=detect mode=val model=runs/detect/train/weights/last.pt imgsz=32 + yolo task=detect mode=train data=coco8.yaml model=yolov8n.yaml epochs=1 imgsz=32 + yolo task=detect mode=val data=coco8.yaml model=runs/detect/train/weights/last.pt imgsz=32 yolo task=detect mode=predict model=runs/detect/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg yolo mode=export model=runs/detect/train/weights/last.pt imgsz=32 format=torchscript - name: Test segmentation shell: bash # for Windows compatibility run: | - yolo task=segment mode=train model=yolov8n-seg.yaml data=coco8-seg.yaml epochs=1 imgsz=32 - yolo task=segment mode=val model=runs/segment/train/weights/last.pt data=coco8-seg.yaml imgsz=32 + yolo task=segment mode=train data=coco8-seg.yaml model=yolov8n-seg.yaml epochs=1 imgsz=32 + yolo task=segment mode=val data=coco8-seg.yaml model=runs/segment/train/weights/last.pt imgsz=32 yolo task=segment mode=predict model=runs/segment/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg yolo mode=export model=runs/segment/train/weights/last.pt imgsz=32 format=torchscript - name: Test classification shell: bash # for Windows compatibility run: | - yolo task=classify mode=train model=yolov8n-cls.yaml data=mnist160 epochs=1 imgsz=32 - yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160 imgsz=32 + yolo task=classify mode=train data=mnist160 model=yolov8n-cls.yaml epochs=1 imgsz=32 + yolo task=classify mode=val data=mnist160 model=runs/classify/train/weights/last.pt imgsz=32 yolo task=classify mode=predict model=runs/classify/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript - name: Pytest tests diff --git a/docker/Dockerfile b/docker/Dockerfile index 99f219d..accd2aa 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -52,7 +52,7 @@ ENV OMP_NUM_THREADS=1 # t=ultralytics/ultralytics:latest tnew=ultralytics/ultralytics:v6.2 && sudo docker pull $t && sudo docker tag $t $tnew && sudo docker push $tnew # Clean up -# docker system prune -a --volumes +# sudo docker system prune -a --volumes # Update Ubuntu drivers # https://www.maketecheasier.com/install-nvidia-drivers-ubuntu/ diff --git a/tests/test_engine.py b/tests/test_engine.py index c74dc7a..e33811f 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,13 +1,16 @@ # Ultralytics YOLO 🚀, GPL-3.0 license +from pathlib import Path + from ultralytics.yolo.configs import get_config -from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT +from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, SETTINGS from ultralytics.yolo.v8 import classify, detect, segment CFG_DET = 'yolov8n.yaml' CFG_SEG = 'yolov8n-seg.yaml' CFG_CLS = 'squeezenet1_0' CFG = get_config(DEFAULT_CONFIG) +MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n' SOURCE = ROOT / "assets" @@ -18,15 +21,14 @@ def test_detect(): # Trainer trainer = detect.DetectionTrainer(overrides=overrides) trainer.train() - trained_model = trainer.best # Validator val = detect.DetectionValidator(args=CFG) - val(model=trained_model) + val(model=trainer.best) # validate best.pt # Predictor pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]}) - result = pred(source=SOURCE, model="yolov8n.pt", return_outputs=True) + result = pred(source=SOURCE, model=f"{MODEL}.pt", return_outputs=True) assert len(list(result)), "predictor test failed" overrides["resume"] = trainer.last @@ -49,15 +51,14 @@ def test_segment(): # trainer trainer = segment.SegmentationTrainer(overrides=overrides) trainer.train() - trained_model = trainer.best # Validator val = segment.SegmentationValidator(args=CFG) - val(model=trained_model) + val(model=trainer.best) # validate best.pt # Predictor pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]}) - result = pred(source=SOURCE, model="yolov8n-seg.pt", return_outputs=True) + result = pred(source=SOURCE, model=f"{MODEL}-seg.pt", return_outputs=True) assert len(list(result)) == 2, "predictor test failed" # Test resume @@ -82,13 +83,12 @@ def test_classify(): # Trainer trainer = classify.ClassificationTrainer(overrides=overrides) trainer.train() - trained_model = trainer.best # Validator val = classify.ClassificationValidator(args=CFG) - val(model=trained_model) + val(model=trainer.best) # Predictor pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]}) - result = pred(source=SOURCE, model=trained_model, return_outputs=True) + result = pred(source=SOURCE, model=trainer.best, return_outputs=True) assert len(list(result)) == 2, "predictor test failed" diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index 58d268f..817530c 100644 --- a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -1,7 +1,5 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -import signal -import sys from pathlib import Path from time import sleep @@ -15,19 +13,21 @@ AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__versio session = None - -def signal_handler(signum, frame): - """ Confirm exit """ - global hub_logger - LOGGER.info(f'Signal received. {signum} {frame}') - if isinstance(session, HubTrainingSession): - hub_logger.alive = False - del hub_logger - sys.exit(signum) - - -signal.signal(signal.SIGTERM, signal_handler) -signal.signal(signal.SIGINT, signal_handler) +# Causing problems in tests (non-authenticated) +# import signal +# import sys +# def signal_handler(signum, frame): +# """ Confirm exit """ +# global hub_logger +# LOGGER.info(f'Signal received. {signum} {frame}') +# if isinstance(session, HubTrainingSession): +# hub_logger.alive = False +# del hub_logger +# sys.exit(signum) +# +# +# signal.signal(signal.SIGTERM, signal_handler) +# signal.signal(signal.SIGINT, signal_handler) class HubTrainingSession: diff --git a/ultralytics/yolo/configs/__init__.py b/ultralytics/yolo/configs/__init__.py index 2809303..1eda761 100644 --- a/ultralytics/yolo/configs/__init__.py +++ b/ultralytics/yolo/configs/__init__.py @@ -8,13 +8,13 @@ from omegaconf import DictConfig, OmegaConf from ultralytics.yolo.configs.hydra_patch import check_config_mismatch -def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = None): +def get_config(config: Union[str, Path, DictConfig], overrides: Union[str, Dict] = None): """ Load and merge configuration data from a file or dictionary. Args: - config (Union[str, DictConfig]): Configuration data in the form of a file name or a DictConfig object. - overrides (Union[str, Dict], optional): Overrides in the form of a file name or a dictionary. Default is None. + config (str) or (Path) or (DictConfig): Configuration data in the form of a file name or a DictConfig object. + overrides (str) or(Dict), optional: Overrides in the form of a file name or a dictionary. Default is None. Returns: OmegaConf.Namespace: Training arguments namespace. diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py index 744fa0d..c5cf4ac 100644 --- a/ultralytics/yolo/data/utils.py +++ b/ultralytics/yolo/data/utils.py @@ -14,12 +14,11 @@ import numpy as np import torch from PIL import ExifTags, Image, ImageOps -from ultralytics.yolo.utils import LOGGER, ROOT, colorstr, yaml_load +from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, yaml_load from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii from ultralytics.yolo.utils.downloads import download from ultralytics.yolo.utils.files import unzip_file - -from ..utils.ops import segments2boxes +from ultralytics.yolo.utils.ops import segments2boxes HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data" IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes @@ -173,12 +172,7 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1): areas = [] ms = [] for si in range(len(segments)): - mask = polygon2mask( - imgsz, - [segments[si].reshape(-1)], - downsample_ratio=downsample_ratio, - color=1, - ) + mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1) ms.append(mask) areas.append(mask.sum()) areas = np.asarray(areas) @@ -194,13 +188,14 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1): def check_dataset_yaml(data, autodownload=True): # Download, check and/or unzip dataset if not found locally data = check_file(data) - DATASETS_DIR = (Path.cwd() / "../datasets").resolve() # TODO: handle global dataset dir + # Download (optional) extract_dir = '' if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)): download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1) data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml')) extract_dir, autodownload = data.parent, False + # Read yaml (optional) if isinstance(data, (str, Path)): data = yaml_load(data, append_filename=True) # dictionary @@ -215,7 +210,7 @@ def check_dataset_yaml(data, autodownload=True): # Resolve paths path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.' if not path.is_absolute(): - path = (Path.cwd() / path).resolve() + path = (DATASETS_DIR / path).resolve() data['path'] = path # download scripts for k in 'train', 'val', 'test': if data.get(k): # prepend path @@ -253,6 +248,7 @@ def check_dataset_yaml(data, autodownload=True): s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌" LOGGER.info(f"Dataset download {s}") check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts + return data # dictionary @@ -274,12 +270,12 @@ def check_dataset(dataset: str): 'nc': Number of classes in the dataset 'names': List of class names in the dataset """ - data_dir = (Path.cwd() / "datasets" / dataset).resolve() + data_dir = (DATASETS_DIR / dataset).resolve() if not data_dir.is_dir(): LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...') t = time.time() if dataset == 'imagenet': - subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) + subprocess.run(f"bash {ROOT / 'yolo/data/scripts/get_imagenet.sh'}", shell=True, check=True) else: url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' download(url, dir=data_dir.parent) diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 13b3e5a..58495c1 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -240,7 +240,7 @@ class BasePredictor: if isinstance(self.vid_writer[idx], cv2.VideoWriter): self.vid_writer[idx].release() # release previous video writer if vid_cap: # video - fps = vid_cap.get(cv2.CAP_PROP_FPS) + fps = int(vid_cap.get(cv2.CAP_PROP_FPS)) # integer required, floats produce error in MP4 codec w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) else: # stream diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 727a50f..d107291 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -506,10 +506,12 @@ class BaseTrainer: def check_resume(self): resume = self.args.resume if resume: - last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run()) + last = Path(check_file(resume) if isinstance(resume, (str, Path)) else get_latest_run()) args_yaml = last.parent.parent / 'args.yaml' # train options yaml - if args_yaml.is_file(): - args = get_config(args_yaml) # replace + assert args_yaml.is_file(), \ + FileNotFoundError('Resume checkpoint f{last} not found. ' + 'Please pass a valid checkpoint to resume from, i.e. yolo resume=path/to/last.pt') + args = get_config(args_yaml) # replace args.model, resume = str(last), True # reinstate self.args = args self.resume = resume diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index 9e8417f..ea2867a 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -187,7 +187,7 @@ def get_git_root_dir(): """ try: output = subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True) - return Path(output.stdout.strip().decode('utf-8')).parent # parent/.git + return Path(output.stdout.strip().decode('utf-8')).parent.resolve() # parent/.git except subprocess.CalledProcessError: return None @@ -348,16 +348,18 @@ def yaml_load(file='data.yaml', append_filename=False): return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f) -def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'): +def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.0'): """ - Loads a global settings YAML file or creates one with default values if it does not exist. + Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist. Args: - file (Path): Path to the settings YAML file. Defaults to 'settings.yaml' in the USER_CONFIG_DIR. + file (Path): Path to the Ultralytics settings YAML file. Defaults to 'settings.yaml' in the USER_CONFIG_DIR. + version (str): Settings version. If min settings version not met, new default settings will be saved. Returns: dict: Dictionary of settings key-value pairs. """ + from ultralytics.yolo.utils.checks import check_version from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first root = get_git_root_dir() or Path('') # not is_pip_package() @@ -366,7 +368,8 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'): 'weights_dir': str(root / 'weights'), # default weights directory. 'runs_dir': str(root / 'runs'), # default runs directory. 'sync': True, # sync analytics to help with YOLO development - 'uuid': uuid.getnode()} # device UUID to align analytics + 'uuid': uuid.getnode(), # device UUID to align analytics + 'settings_version': version} # Ultralytics settings version with torch_distributed_zero_first(RANK): if not file.exists(): @@ -375,12 +378,14 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'): settings = yaml_load(file) # Check that settings keys and types match defaults - correct = settings.keys() == defaults.keys() and \ - all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) + correct = settings.keys() == defaults.keys() \ + and all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) \ + and check_version(settings['settings_version'], version) if not correct: - LOGGER.warning('WARNING ⚠️ Different global settings detected, resetting to defaults. ' - 'This may be due to an ultralytics package update. ' - f'View and update your global settings directly in {file}') + LOGGER.warning('WARNING ⚠️ Ultralytics settings reset to defaults. ' + '\nThis is normal and may be due to a recent ultralytics package update, ' + 'but may have overwritten previous settings. ' + f"\nYou may view and update settings directly in '{file}'") settings = defaults # merge **defaults with **settings (prefer **settings) yaml_save(file, settings) # save updated defaults diff --git a/ultralytics/yolo/utils/callbacks/hub.py b/ultralytics/yolo/utils/callbacks/hub.py index 47a7e54..2f9163e 100644 --- a/ultralytics/yolo/utils/callbacks/hub.py +++ b/ultralytics/yolo/utils/callbacks/hub.py @@ -3,8 +3,6 @@ import json from time import time -import torch - from ultralytics.hub.utils import PREFIX, sync_analytics from ultralytics.yolo.utils import LOGGER diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py index 35a973a..cd8a88a 100644 --- a/ultralytics/yolo/utils/metrics.py +++ b/ultralytics/yolo/utils/metrics.py @@ -252,7 +252,7 @@ class ConfusionMatrix: vmin=0.0, xticklabels=ticklabels, yticklabels=ticklabels).set_facecolor((1, 1, 1)) - ax.set_ylabel('True') + ax.set_xlabel('True') ax.set_ylabel('Predicted') ax.set_title('Confusion Matrix') fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py index 33375a1..237ae8e 100644 --- a/ultralytics/yolo/v8/classify/train.py +++ b/ultralytics/yolo/v8/classify/train.py @@ -113,11 +113,10 @@ class ClassificationTrainer(BaseTrainer): """ # Not needed for classification but necessary for segmentation & detection keys = [f"{prefix}/{x}" for x in self.loss_names] - if loss_items is not None: - loss_items = [round(float(loss_items), 5)] - return dict(zip(keys, loss_items)) - else: + if loss_items is None: return keys + loss_items = [round(float(loss_items), 5)] + return dict(zip(keys, loss_items)) def resume_training(self, ckpt): pass diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py index 88dff1a..817bca9 100644 --- a/ultralytics/yolo/v8/detect/train.py +++ b/ultralytics/yolo/v8/detect/train.py @@ -48,14 +48,14 @@ class DetectionTrainer(BaseTrainer): return batch def set_model_attributes(self): - nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps) - self.args.box *= 3 / nl # scale to layers + # nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps) + # self.args.box *= 3 / nl # scale to layers # self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers - self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers + # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers self.model.nc = self.data["nc"] # attach number of classes to model + self.model.names = self.data["names"] # attach class names to model self.model.args = self.args # attach hyperparameters to model # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc - self.model.names = self.data["names"] def get_model(self, cfg=None, weights=None, verbose=True): model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose) diff --git a/ultralytics/yolo/v8/segment/predict.py b/ultralytics/yolo/v8/segment/predict.py index a1e5b22..47f3021 100644 --- a/ultralytics/yolo/v8/segment/predict.py +++ b/ultralytics/yolo/v8/segment/predict.py @@ -6,8 +6,7 @@ import torch from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops from ultralytics.yolo.utils.checks import check_imgsz from ultralytics.yolo.utils.plotting import colors, save_one_box - -from ..detect.predict import DetectionPredictor +from ultralytics.yolo.v8.detect.predict import DetectionPredictor class SegmentationPredictor(DetectionPredictor): diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index 95e6438..be41f16 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -13,14 +13,15 @@ from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh from ultralytics.yolo.utils.plotting import plot_images, plot_results from ultralytics.yolo.utils.tal import make_anchors from ultralytics.yolo.utils.torch_utils import de_parallel - -from ..detect.train import Loss +from ultralytics.yolo.v8.detect.train import Loss # BaseTrainer python usage class SegmentationTrainer(v8.detect.DetectionTrainer): - def __init__(self, config=DEFAULT_CONFIG, overrides={}): + def __init__(self, config=DEFAULT_CONFIG, overrides=None): + if overrides is None: + overrides = {} overrides["task"] = "segment" super().__init__(config, overrides) diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py index 99f612e..32b8a9e 100644 --- a/ultralytics/yolo/v8/segment/val.py +++ b/ultralytics/yolo/v8/segment/val.py @@ -13,8 +13,7 @@ from ultralytics.yolo.utils import DEFAULT_CONFIG, NUM_THREADS, ops from ultralytics.yolo.utils.checks import check_requirements from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou from ultralytics.yolo.utils.plotting import output_to_target, plot_images - -from ..detect import DetectionValidator +from ultralytics.yolo.v8.detect import DetectionValidator class SegmentationValidator(DetectionValidator):