ultralytics 8.0.59 new MLFlow and feature updates (#1720)

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: St. HeMeow <sheng.heyang@gmail.com>
Co-authored-by: Danny Kim <imbird0312@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Torge Kummerow <CySlider@users.noreply.github.com>
Co-authored-by: dankernel <dkdkernel@gmail.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: Roshanlal <roshanlaladchitre103@gmail.com>
Co-authored-by: Lorenzo Mammana <lorenzo.mammana@orobix.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
This commit is contained in:
Glenn Jocher
2023-03-31 20:33:02 +02:00
committed by GitHub
parent ccb6419835
commit e7876e1ba9
29 changed files with 326 additions and 160 deletions

View File

@ -11,14 +11,28 @@ from typing import Dict, List, Union
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR,
IterableSimpleNamespace, __version__, checks, colorstr, yaml_load, yaml_print)
# Define valid tasks and modes
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
TASKS = 'detect', 'segment', 'classify', 'pose'
TASK2DATA = {
'detect': 'coco128.yaml',
'segment': 'coco128-seg.yaml',
'pose': 'coco128-pose.yaml',
'classify': 'imagenet100'}
TASK2MODEL = {
'detect': 'yolov8n.pt',
'segment': 'yolov8n-seg.pt',
'pose': 'yolov8n-pose.yaml',
'classify': 'yolov8n-cls.pt'} # temp
CLI_HELP_MSG = \
f"""
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
yolo TASK MODE ARGS
Where TASK (optional) is one of [detect, segment, classify]
MODE (required) is one of [train, val, predict, export, track]
Where TASK (optional) is one of {TASKS}
MODE (required) is one of {MODES}
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'
@ -59,12 +73,6 @@ CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', '
'save_conf', 'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms',
'retina_masks', 'boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader')
# Define valid tasks and modes
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
TASKS = 'detect', 'segment', 'classify'
TASK2DATA = {'detect': 'coco128.yaml', 'segment': 'coco128-seg.yaml', 'classify': 'imagenet100'}
TASK2MODEL = {'detect': 'yolov8n.pt', 'segment': 'yolov8n-seg.pt', 'classify': 'yolov8n-cls.pt'}
def cfg2dict(cfg):
"""

View File

@ -26,7 +26,7 @@ seed: 0 # random seed for reproducibility
deterministic: True # whether to enable deterministic mode
single_cls: False # train multi-class data as single-class
image_weights: False # use weighted image selection for training
rect: False # support rectangular training if mode='train', support rectangular evaluation if mode='val'
rect: False # rectangular training if mode='train' or rectangular validation if mode='val'
cos_lr: False # use cosine learning rate scheduler
close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
resume: False # resume training from last checkpoint

View File

@ -278,6 +278,7 @@ def check_cls_dataset(dataset: str):
data (dict): A dictionary containing the following keys and values:
'train': Path object for the directory containing the training set of the dataset
'val': Path object for the directory containing the validation set of the dataset
'test': Path object for the directory containing the test set of the dataset
'nc': Number of classes in the dataset
'names': List of class names in the dataset
"""
@ -293,11 +294,12 @@ def check_cls_dataset(dataset: str):
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
LOGGER.info(s)
train_set = data_dir / 'train'
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
names = dict(enumerate(sorted(names)))
return {'train': train_set, 'val': test_set, 'nc': nc, 'names': names}
return {'train': train_set, 'val': val_set, 'test': test_set, 'nc': nc, 'names': names}
class HUBDatasetStats():

View File

@ -230,8 +230,9 @@ class YOLO:
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
def track(self, source=None, stream=False, **kwargs):
from ultralytics.tracker import register_tracker
register_tracker(self)
if not hasattr(self.predictor, 'trackers'):
from ultralytics.tracker import register_tracker
register_tracker(self)
# ByteTrack-based method needs low confidence predictions as input
conf = kwargs.get('conf') or 0.1
kwargs['conf'] = conf

View File

@ -3,16 +3,16 @@
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
Usage - sources:
$ yolo mode=predict model=yolov8n.pt --source 0 # webcam
img.jpg # image
vid.mp4 # video
screen # screenshot
path/ # directory
list.txt # list of images
list.streams # list of streams
'path/*.jpg' # glob
'https://youtu.be/Zgi9g1ksQHc' # YouTube
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
$ yolo mode=predict model=yolov8n.pt source=0 # webcam
img.jpg # image
vid.mp4 # video
screen # screenshot
path/ # directory
list.txt # list of images
list.streams # list of streams
'path/*.jpg' # glob
'https://youtu.be/Zgi9g1ksQHc' # YouTube
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
Usage - formats:
$ yolo mode=predict model=yolov8n.pt # PyTorch

View File

@ -14,6 +14,7 @@ import torchvision.transforms.functional as F
from ultralytics.yolo.utils import LOGGER, SimpleClass, ops
from ultralytics.yolo.utils.plotting import Annotator, colors
from ultralytics.yolo.utils.torch_utils import TORCHVISION_0_10
class Results(SimpleClass):
@ -129,7 +130,10 @@ class Results(SimpleClass):
if masks is not None:
im = torch.as_tensor(annotator.im, dtype=torch.float16, device=masks.data.device).permute(2, 0, 1).flip(0)
im = F.resize(im.contiguous(), masks.data.shape[1:]) / 255
if TORCHVISION_0_10:
im = F.resize(im.contiguous(), masks.data.shape[1:], antialias=True) / 255
else:
im = F.resize(im.contiguous(), masks.data.shape[1:]) / 255
annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im)
if probs is not None:
@ -259,7 +263,8 @@ class Masks(SimpleClass):
orig_shape (tuple): Original image size, in the format (height, width).
Properties:
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection.
xy (list): A list of segments (pixels) which includes x, y segments of each detection.
xyn (list): A list of segments (normalized) which includes x, y segments of each detection.
Methods:
cpu(): Returns a copy of the masks tensor on CPU memory.
@ -272,13 +277,28 @@ class Masks(SimpleClass):
self.masks = masks # N, h, w
self.orig_shape = orig_shape
def segments(self):
# Segments-deprecated (normalized)
LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and "
"'Masks.xy' for segments (pixels) instead.")
return self.xyn
@property
@lru_cache(maxsize=1)
def segments(self):
def xyn(self):
# Segments (normalized)
return [
ops.scale_segments(self.masks.shape[1:], x, self.orig_shape, normalize=True)
for x in ops.masks2segments(self.masks)]
@property
@lru_cache(maxsize=1)
def xy(self):
# Segments (pixels)
return [
ops.scale_segments(self.masks.shape[1:], x, self.orig_shape, normalize=False)
for x in ops.masks2segments(self.masks)]
@property
def shape(self):
return self.masks.shape

View File

@ -370,6 +370,7 @@ class BaseTrainer:
self.epoch_time = tnow - self.epoch_time_start
self.epoch_time_start = tnow
self.run_callbacks('on_fit_epoch_end')
torch.cuda.empty_cache() # clears GPU vRAM at end of epoch, can help with out of memory errors
# Early Stopping
if RANK != -1: # if DDP training

View File

@ -484,7 +484,7 @@ def get_user_config_dir(sub_dir='Ultralytics'):
return path
USER_CONFIG_DIR = os.getenv('YOLO_CONFIG_DIR', get_user_config_dir()) # Ultralytics settings dir
USER_CONFIG_DIR = Path(os.getenv('YOLO_CONFIG_DIR', get_user_config_dir())) # Ultralytics settings dir
def emojis(string=''):

View File

@ -48,8 +48,6 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, hal
for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU)
emoji, filename = '', None # export defaults
try:
if model.task == 'classify':
assert i != 11, 'paddle cls exports coming soon'
assert i != 9 or LINUX, 'Edge TPU export only supported on Linux'
if i == 10:
assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux'

View File

@ -147,9 +147,10 @@ def add_integration_callbacks(instance):
from .clearml import callbacks as clearml_callbacks
from .comet import callbacks as comet_callbacks
from .hub import callbacks as hub_callbacks
from .mlflow import callbacks as mf_callbacks
from .tensorboard import callbacks as tb_callbacks
for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks:
for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks, mf_callbacks:
for k, v in x.items():
if v not in instance.callbacks[k]: # prevent duplicate callbacks addition
instance.callbacks[k].append(v) # callback[name].append(func)

View File

@ -6,9 +6,9 @@ try:
import clearml
from clearml import Task
assert clearml.__version__ # verify package is not directory
assert hasattr(clearml, '__version__') # verify package is not directory
assert not TESTS_RUNNING # do not log pytest
except (ImportError, AssertionError, AttributeError):
except (ImportError, AssertionError):
clearml = None

View File

@ -6,8 +6,8 @@ try:
import comet_ml
assert not TESTS_RUNNING # do not log pytest
assert comet_ml.__version__ # verify package is not directory
except (ImportError, AssertionError, AttributeError):
assert hasattr(comet_ml, '__version__') # verify package is not directory
except (ImportError, AssertionError):
comet_ml = None

View File

@ -0,0 +1,75 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import os
import re
from pathlib import Path
from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING, colorstr
try:
import mlflow
assert not TESTS_RUNNING # do not log pytest
assert hasattr(mlflow, '__version__') # verify package is not directory
except (ImportError, AssertionError):
mlflow = None
def on_pretrain_routine_end(trainer):
global mlflow, run, run_id, experiment_name
if os.environ.get('MLFLOW_TRACKING_URI') is None:
mlflow = None
if mlflow:
mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000"
mlflow.set_tracking_uri(mlflow_location)
experiment_name = trainer.args.project or 'YOLOv8'
experiment = mlflow.get_experiment_by_name(experiment_name)
if experiment is None:
mlflow.create_experiment(experiment_name)
mlflow.set_experiment(experiment_name)
prefix = colorstr('MLFlow: ')
try:
run, active_run = mlflow, mlflow.start_run() if mlflow else None
if active_run is not None:
run_id = active_run.info.run_id
LOGGER.info(f'{prefix}Using run_id({run_id}) at {mlflow_location}')
except Exception as err:
LOGGER.error(f'{prefix}Failing init - {repr(err)}')
LOGGER.warning(f'{prefix}Continuing without Mlflow')
run = None
run.log_params(vars(trainer.model.args))
def on_fit_epoch_end(trainer):
if mlflow:
metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()}
run.log_metrics(metrics=metrics_dict, step=trainer.epoch)
def on_model_save(trainer):
if mlflow:
run.log_artifact(trainer.last)
def on_train_end(trainer):
if mlflow:
root_dir = Path(__file__).resolve().parents[3]
run.log_artifact(trainer.best)
model_uri = f'runs:/{run_id}/'
run.register_model(model_uri, experiment_name)
run.pyfunc.log_model(artifact_path=experiment_name,
code_path=[str(root_dir)],
artifacts={'model_path': str(trainer.save_dir)},
python_model=run.pyfunc.PythonModel())
callbacks = {
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_model_save': on_model_save,
'on_train_end': on_train_end} if mlflow else {}

View File

@ -16,10 +16,12 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, __version__
from ultralytics.yolo.utils.checks import check_version
TORCHVISION_0_10 = check_version(torchvision.__version__, '0.10.0')
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
TORCH_1_11 = check_version(torch.__version__, '1.11.0')
TORCH_1_12 = check_version(torch.__version__, '1.12.0')