`ultralytics 8.0.20` CLI `yolo` simplifications, DDP and ONNX fixes (#608)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sid Prabhakaran <s2siddhu@gmail.com>
Glenn Jocher 2 years ago committed by GitHub
parent 59d4335664
commit 15b3b0365a

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
 
-__version__ = "8.0.19"
+__version__ = "8.0.20"
 
 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils import ops

@@ -1,14 +1,19 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
 
 import os
+import platform
 import shutil
+import sys
 import threading
 import time
+from pathlib import Path
 from random import random
 
 import requests
 
-from ultralytics.yolo.utils import DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, TryExcept, colorstr, emojis
+from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, TryExcept, colorstr, emojis,
+                                    get_git_origin_url, is_colab, is_docker, is_git_dir, is_github_actions_ci,
+                                    is_jupyter, is_kaggle, is_pip_package, is_pytest_running)
 
 PREFIX = colorstr('Ultralytics: ')
 HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
@@ -131,8 +136,31 @@ def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post
         return func(*args, **kwargs)
 
 
-@TryExcept(verbose=False)
-def traces(cfg, all_keys=False, traces_sample_rate=0.0):
+class Traces:
+
+    def __init__(self):
+        """
+        Initialize Traces for error tracking and reporting if tests are not currently running.
+        """
+        from ultralytics import __version__
+        env = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
+            'Docker' if is_docker() else platform.system()
+        self.rate_limit = 3.0  # rate limit (seconds)
+        self.t = time.time()  # rate limit timer (seconds)
+        self.metadata = {
+            "sys_argv_name": Path(sys.argv[0]).name,
+            "install": 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
+            "python": platform.python_version(),
+            "release": __version__,
+            "environment": env}
+        self.enabled = SETTINGS['sync'] and \
+                       RANK in {-1, 0} and \
+                       not is_pytest_running() and \
+                       not is_github_actions_ci() and \
+                       (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
+
+    @TryExcept(verbose=False)
+    def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0):
         """
         Sync traces data if enabled in the global settings
@@ -141,11 +169,24 @@ def traces(cfg, all_keys=False, traces_sample_rate=0.0):
             all_keys (bool): Sync all items, not just non-default values.
             traces_sample_rate (float): Fraction of traces captured from 0.0 to 1.0
         """
-    if SETTINGS['sync'] and RANK in {-1, 0} and (random() < traces_sample_rate):
+        t = time.time()  # current time
+        if self.enabled and random() < traces_sample_rate and (t - self.t) > self.rate_limit:
+            self.t = t  # reset rate limit timer
             cfg = vars(cfg)  # convert type from IterableSimpleNamespace to dict
-        if not all_keys:
-            cfg = {k: v for k, v in cfg.items() if v != DEFAULT_CFG_DICT.get(k, None)}  # retain non-default values
-        cfg['uuid'] = SETTINGS['uuid']  # add the device UUID to the configuration data
+            if not all_keys:  # filter cfg
+                include_keys = {'task', 'mode'}  # always include
+                cfg = {k: v for k, v in cfg.items() if v != DEFAULT_CFG_DICT.get(k, None) or k in include_keys}
+            trace = {'uuid': SETTINGS['uuid'], 'cfg': cfg, 'metadata': self.metadata}
 
             # Send a request to the HUB API to sync analytics
-        smart_request(f'{HUB_API_ROOT}/v1/usage/anonymous', json=cfg, headers=None, code=3, retry=0, verbose=False)
+            smart_request(f'{HUB_API_ROOT}/v1/usage/anonymous',
+                          json=trace,
+                          headers=None,
+                          code=3,
+                          retry=0,
+                          verbose=False)
+
+
+# Run below code on hub/utils init -------------------------------------------------------------------------------------
+traces = Traces()
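
Note on usage: the module-level `traces = Traces()` singleton replaces the old `traces()` function, so existing callers keep the same call signature. A minimal sketch of how it is invoked (this mirrors the HUB callback hunk further down; the call is a no-op unless SETTINGS['sync'] is enabled and the 3-second rate limit has elapsed):

    # Minimal sketch, assuming SETTINGS['sync'] is enabled
    from ultralytics.hub.utils import traces        # module-level singleton created above
    from ultralytics.yolo.utils import DEFAULT_CFG  # IterableSimpleNamespace of default args

    traces(DEFAULT_CFG, traces_sample_rate=1.0)     # rate-limited, anonymous usage sync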

@@ -472,10 +472,13 @@ def guess_model_task(model):
     Raises:
         SyntaxError: If the task of the model could not be determined.
     """
-    cfg, task = None, None
+    cfg = None
     if isinstance(model, dict):
         cfg = model
     elif isinstance(model, nn.Module):  # PyTorch model
+        for x in 'model.args', 'model.model.args', 'model.model.model.args':
+            with contextlib.suppress(Exception):
+                return eval(x)['task']
         for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
             with contextlib.suppress(Exception):
                 cfg = eval(x)
@@ -485,25 +488,22 @@ def guess_model_task(model):
     if cfg:
         m = cfg["head"][-1][-2].lower()  # output module name
         if m in ["classify", "classifier", "cls", "fc"]:
-            task = "classify"
+            return "classify"
         if m in ["detect"]:
-            task = "detect"
+            return "detect"
         if m in ["segment"]:
-            task = "segment"
+            return "segment"
 
     # Guess from PyTorch model
-    if task is None and isinstance(model, nn.Module):
+    if isinstance(model, nn.Module):
         for m in model.modules():
             if isinstance(m, Detect):
-                task = "detect"
+                return "detect"
             elif isinstance(m, Segment):
-                task = "segment"
+                return "segment"
             elif isinstance(m, Classify):
-                task = "classify"
+                return "classify"
 
     # Unable to determine task from model
-    if task is None:
-        raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
-                          "i.e. 'task=detect', 'task=segment' or 'task=classify'.")
-    else:
-        return task
+    raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
+                      "i.e. 'task=detect', 'task=segment' or 'task=classify'.")
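
With the early returns, task inference now short-circuits as soon as `model.args['task']` is readable, then falls back to YAML head inspection and finally to isinstance checks on the modules. A rough sketch of the dict fallback (the import path `ultralytics.nn.tasks` is assumed from the surrounding context):

    # Sketch: a minimal YAML-style cfg dict is enough for the head-name fallback
    from ultralytics.nn.tasks import guess_model_task  # assumed module for this helper

    cfg = {"head": [[-1, 1, "Segment", []]]}  # second-to-last field of the last head entry names the output module
    print(guess_model_task(cfg))              # -> "segment"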

@@ -8,8 +8,8 @@ from pathlib import Path
 from types import SimpleNamespace
 from typing import Dict, List, Union
 
-from ultralytics import __version__, yolo
-from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, USER_CONFIG_DIR,
+from ultralytics import __version__
+from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, ROOT, USER_CONFIG_DIR,
                                     IterableSimpleNamespace, colorstr, yaml_load, yaml_print)
 from ultralytics.yolo.utils.checks import check_yolo
@@ -211,30 +211,42 @@ def entrypoint(debug=False):
         else:
             raise argument_error(a)
 
-    cfg = get_cfg(DEFAULT_CFG_DICT, overrides)  # create CFG instance
-
-    # Checks error catch
-    if cfg.mode == 'checks':
-        LOGGER.warning(
-            "WARNING ⚠️ 'yolo mode=checks' is deprecated and will be removed in the future. Use 'yolo checks' instead.")
+    # Mode
+    mode = overrides.pop('mode', None)
+    model = overrides.pop('model', None)
+    if mode == 'checks':
+        LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
         check_yolo()
         return
+    elif mode is None:
+        mode = DEFAULT_CFG_DICT['mode'] or 'predict'
+        LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
 
-    # Mapping from task to module
-    module = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}.get(cfg.task)
-    if not module:
-        raise SyntaxError(f"yolo task={cfg.task} is invalid. Valid tasks are: {', '.join(tasks)}\n{CLI_HELP_MSG}")
+    # Model
+    if model is None:
+        model = DEFAULT_CFG_DICT['model'] or 'yolov8n.pt'
+        LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
+    from ultralytics.yolo.engine.model import YOLO
+    model = YOLO(model)
+    task = model.task
 
-    # Mapping from mode to function
-    func = {
-        "train": module.train,
-        "val": module.val,
-        "predict": module.predict,
-        "export": yolo.engine.exporter.export}.get(cfg.mode)
-    if not func:
-        raise SyntaxError(f"yolo mode={cfg.mode} is invalid. Valid modes are: {', '.join(modes)}\n{CLI_HELP_MSG}")
-
-    func(cfg)
+    # Task
+    if mode == 'predict' and 'source' not in overrides:
+        overrides['source'] = DEFAULT_CFG_DICT['source'] or ROOT / "assets" if (ROOT / "assets").exists() \
+            else "https://ultralytics.com/images/bus.jpg"
+        LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
+    elif mode in ('train', 'val'):
+        if 'data' not in overrides:
+            overrides['data'] = DEFAULT_CFG_DICT['data'] or 'mnist160' if task == 'classify' \
+                else 'coco128-seg.yaml' if task == 'segment' else 'coco128.yaml'
+            LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
+    elif mode == 'export':
+        if 'format' not in overrides:
+            overrides['format'] = DEFAULT_CFG_DICT['format'] or 'torchscript'
+            LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
+
+    # Run command in python
+    getattr(model, mode)(verbose=True, **overrides)
 
 
 # Special modes --------------------------------------------------------------------------------------------------------
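
In effect, a bare CLI invocation no longer fails on the task-to-module mapping: missing mode/model/source/data/format are filled from DEFAULT_CFG_DICT with a warning, a YOLO object is built, and the mode is dispatched via getattr. The Python equivalent of the default predict path is roughly:

    # Rough Python equivalent of the CLI default path (predict with the fallback model and source)
    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")  # default when 'model' is missing
    model.predict(source="https://ultralytics.com/images/bus.jpg", verbose=True)
    # train/val/export dispatch the same way, e.g. model.train(data="coco128.yaml", verbose=True)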

@@ -1,26 +1,26 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
 # Default training settings and hyperparameters for medium-augmentation COCO training
 
-task: "detect"  # inference task, i.e. detect, segment, classify
-mode: "train"  # YOLO mode, i.e. train, val, predict, export
+task: detect  # inference task, i.e. detect, segment, classify
+mode: train  # YOLO mode, i.e. train, val, predict, export
 
 # Train settings -------------------------------------------------------------------------------------------------------
-model: null  # path to model file, i.e. yolov8n.pt, yolov8n.yaml
-data: null  # path to data file, i.e. coco128.yaml
+model:  # path to model file, i.e. yolov8n.pt, yolov8n.yaml
+data:  # path to data file, i.e. coco128.yaml
 epochs: 100  # number of epochs to train for
 patience: 50  # epochs to wait for no observable improvement for early stopping of training
 batch: 16  # number of images per batch (-1 for AutoBatch)
 imgsz: 640  # size of input images as integer or w,h
 save: True  # save train checkpoints and predict results
 cache: False  # True/ram, disk or False. Use cache for data loading
-device: null  # device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
+device:  # device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
 workers: 8  # number of worker threads for data loading (per RANK if DDP)
-project: null  # project name
-name: null  # experiment name
+project:  # project name
+name:  # experiment name
 exist_ok: False  # whether to overwrite existing experiment
 pretrained: False  # whether to use a pretrained model
-optimizer: 'SGD'  # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
-verbose: False  # whether to print verbose output
+optimizer: SGD  # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
+verbose: True  # whether to print verbose output
 seed: 0  # random seed for reproducibility
 deterministic: True  # whether to enable deterministic mode
 single_cls: False  # train multi-class data as single-class
@@ -39,7 +39,7 @@ dropout: 0.0  # use dropout regularization (classify train only)
 val: True  # validate/test during training
 save_json: False  # save results to JSON file
 save_hybrid: False  # save hybrid version of labels (labels + additional predictions)
-conf: null  # object confidence threshold for detection (default 0.25 predict, 0.001 val)
+conf:  # object confidence threshold for detection (default 0.25 predict, 0.001 val)
 iou: 0.7  # intersection over union (IoU) threshold for NMS
 max_det: 300  # maximum number of detections per image
 half: False  # use half precision (FP16)
@@ -47,7 +47,7 @@ dnn: False  # use OpenCV DNN for ONNX inference
 plots: True  # save plots during train/val
 
 # Prediction settings --------------------------------------------------------------------------------------------------
-source: null  # source directory for images or videos
+source:  # source directory for images or videos
 show: False  # show results if possible
 save_txt: False  # save results as .txt file
 save_conf: False  # save results with confidence scores
@@ -59,7 +59,7 @@ line_thickness: 3  # bounding box thickness (pixels)
 visualize: False  # visualize model features
 augment: False  # apply image augmentation to prediction sources
 agnostic_nms: False  # class-agnostic NMS
-classes: null  # filter results by class, i.e. class=0, or class=[0,2,3]
+classes:  # filter results by class, i.e. class=0, or class=[0,2,3]
 retina_masks: False  # use high-resolution segmentation masks
 boxes: True  # Show boxes in segmentation predictions
@@ -70,7 +70,7 @@ optimize: False  # TorchScript: optimize for mobile
 int8: False  # CoreML/TF INT8 quantization
 dynamic: False  # ONNX/TF/TensorRT: dynamic axes
 simplify: False  # ONNX: simplify model
-opset: 17  # ONNX: opset version
+opset:  # ONNX: opset version (optional)
 workspace: 4  # TensorRT: workspace size (GB)
 nms: False  # CoreML: add NMS
@@ -103,7 +103,7 @@ mixup: 0.0  # image mixup (probability)
 copy_paste: 0.0  # segment copy-paste (probability)
 
 # Custom config.yaml ---------------------------------------------------------------------------------------------------
-cfg: null  # for overriding defaults.yaml
+cfg:  # for overriding defaults.yaml
 
 # Debug, do not modify -------------------------------------------------------------------------------------------------
 v5loader: False  # use legacy YOLOv5 dataloader

@@ -116,6 +116,9 @@ class IterableSimpleNamespace(SimpleNamespace):
 # Default configuration
 with open(DEFAULT_CFG_PATH, errors='ignore') as f:
     DEFAULT_CFG_DICT = yaml.safe_load(f)
+    for k, v in DEFAULT_CFG_DICT.items():
+        if isinstance(v, str) and v.lower() == 'none':
+            DEFAULT_CFG_DICT[k] = None
 DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
 DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
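
Since default.yaml now leaves optional keys blank, yaml.safe_load already yields None for them; the added loop additionally maps any literal 'none' strings to real None values so `or`-style defaulting works downstream. A small standalone sketch of that normalization:

    # Standalone sketch of the 'none' -> None normalization applied after loading the defaults
    import yaml

    cfg = yaml.safe_load("model:\ndevice: none\nepochs: 100\n")  # blank -> None, 'none' loads as a string
    for k, v in cfg.items():
        if isinstance(v, str) and v.lower() == 'none':
            cfg[k] = None
    assert cfg['model'] is None and cfg['device'] is None and cfg['epochs'] == 100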
@@ -448,13 +451,13 @@ def set_sentry():
     """
 
     def before_send(event, hint):
-        oss = 'colab' if is_colab() else 'kaggle' if is_kaggle() else 'jupyter' if is_jupyter() else \
-            'docker' if is_docker() else platform.system()
+        env = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
+            'Docker' if is_docker() else platform.system()
         event['tags'] = {
             "sys_argv": sys.argv[0],
             "sys_argv_name": Path(sys.argv[0]).name,
             "install": 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
-            "os": oss}
+            "os": env}
         return event
 
     if SETTINGS['sync'] and \
@@ -529,7 +532,7 @@ def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
     yaml_save(file, SETTINGS)
 
 
-# Run below code on utils init -----------------------------------------------------------------------------------------
+# Run below code on yolo/utils init ------------------------------------------------------------------------------------
 
 # Set logger
 set_logging(LOGGING_NAME)  # run before defining LOGGER

@@ -48,19 +48,19 @@ def on_train_end(trainer):
 def on_train_start(trainer):
-    traces(trainer.args, traces_sample_rate=0.0)
+    traces(trainer.args, traces_sample_rate=1.0)
 
 
 def on_val_start(validator):
-    traces(validator.args, traces_sample_rate=0.0)
+    traces(validator.args, traces_sample_rate=1.0)
 
 
 def on_predict_start(predictor):
-    traces(predictor.args, traces_sample_rate=0.0)
+    traces(predictor.args, traces_sample_rate=1.0)
 
 
 def on_export_start(exporter):
-    traces(exporter.args, traces_sample_rate=0.0)
+    traces(exporter.args, traces_sample_rate=1.0)
 
 
 callbacks = {

@@ -31,7 +31,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
     # Decorator to make all processes in distributed training wait for each local_master to do something
-    initialized = torch.distributed.is_initialized()  # prevent 'Default process group has not been initialized' errors
+    initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
     if initialized and local_rank not in {-1, 0}:
         dist.barrier(device_ids=[local_rank])
     yield
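
This context manager is the usual DDP pattern for letting the first process do one-time work (e.g. dataset download or caching) while the other ranks wait at a barrier; the added is_available() guard avoids calling is_initialized() on PyTorch builds without distributed support. A hedged usage sketch (the import path and the prepare_dataset helper are assumptions for illustration):

    # Sketch: rank -1/0 runs the body first, remaining DDP ranks wait at the barrier
    import os
    from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first  # assumed location of this helper

    LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))
    with torch_distributed_zero_first(LOCAL_RANK):
        dataset = prepare_dataset()  # hypothetical one-time setup done by the first process only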

@@ -1,4 +1,5 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
+import sys
 
 import torch
 
@@ -63,11 +64,17 @@ class ClassificationPredictor(BasePredictor):
         return log_string
 
 
-def predict(cfg=DEFAULT_CFG):
-    cfg.model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
-    cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
+def predict(cfg=DEFAULT_CFG, use_python=False):
+    model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
+    source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
         else "https://ultralytics.com/images/bus.jpg"
-    predictor = ClassificationPredictor(cfg)
-    predictor.predict_cli()
+
+    args = dict(model=model, source=source, verbose=True)
+    if use_python:
+        from ultralytics import YOLO
+        YOLO(model)(**args)
+    else:
+        predictor = ClassificationPredictor(args)
+        predictor.predict_cli()
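
Each task helper now builds a plain args dict and either routes through the Python API (use_python=True) or constructs the predictor directly (the CLI path). The use_python branch is roughly equivalent to calling YOLO directly:

    # Sketch of the use_python branch for classification prediction
    from ultralytics import YOLO

    YOLO("yolov8n-cls.pt")(source="https://ultralytics.com/images/bus.jpg", verbose=True)
    # use_python=False instead builds ClassificationPredictor(args) and calls predictor.predict_cli()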

@@ -1,4 +1,5 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
+import sys
 
 import torch
 import torchvision
 
@@ -135,22 +136,18 @@ class ClassificationTrainer(BaseTrainer):
         # self.run_callbacks('on_fit_epoch_end')
 
 
-def train(cfg=DEFAULT_CFG):
-    cfg.model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
-    cfg.data = cfg.data or "mnist160"  # or yolo.ClassificationDataset("mnist")
-
-    # Reproduce ImageNet results
-    # cfg.lr0 = 0.1
-    # cfg.weight_decay = 5e-5
-    # cfg.label_smoothing = 0.1
-    # cfg.warmup_epochs = 0.0
-    cfg.device = cfg.device if cfg.device is not None else ''
-    # trainer = ClassificationTrainer(cfg)
-    # trainer.train()
-    from ultralytics import YOLO
-    model = YOLO(cfg.model)
-    model.train(**vars(cfg))
+def train(cfg=DEFAULT_CFG, use_python=False):
+    model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
+    data = cfg.data or "mnist160"  # or yolo.ClassificationDataset("mnist")
+    device = cfg.device if cfg.device is not None else ''
+
+    args = dict(model=model, data=data, device=device, verbose=True)
+    if use_python:
+        from ultralytics import YOLO
+        YOLO(model).train(**args)
+    else:
+        trainer = ClassificationTrainer(args)
+        trainer.train()
 
 
 if __name__ == "__main__":
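
The commented-out ImageNet-reproduction overrides are dropped; the helper now just forwards model/data/device into one of the two paths. In Python terms the use_python branch is simply:

    # Sketch of the use_python branch for classification training (defaults from this hunk)
    from ultralytics import YOLO

    YOLO("yolov8n-cls.pt").train(data="mnist160", device='', verbose=True)  # device='' lets the trainer auto-select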

@@ -1,4 +1,5 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
+import sys
 
 from ultralytics.yolo.data import build_classification_dataloader
 from ultralytics.yolo.engine.validator import BaseValidator
 
@@ -45,11 +46,17 @@ class ClassificationValidator(BaseValidator):
         self.logger.info(pf % ("all", self.metrics.top1, self.metrics.top5))
 
 
-def val(cfg=DEFAULT_CFG):
-    cfg.model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
-    cfg.data = cfg.data or "mnist160"
-    validator = ClassificationValidator(args=cfg)
-    validator(model=cfg.model)
+def val(cfg=DEFAULT_CFG, use_python=False):
+    model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
+    data = cfg.data or "mnist160"
+
+    args = dict(model=model, data=data, verbose=True)
+    if use_python:
+        from ultralytics import YOLO
+        YOLO(model).val(**args)
+    else:
+        validator = ClassificationValidator(args=args)
+        validator(model=args['model'])
 
 
 if __name__ == "__main__":

@@ -1,4 +1,5 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
+import sys
 
 import torch
 
@@ -81,11 +82,17 @@ class DetectionPredictor(BasePredictor):
         return log_string
 
 
-def predict(cfg=DEFAULT_CFG):
-    cfg.model = cfg.model or "yolov8n.pt"
-    cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
+def predict(cfg=DEFAULT_CFG, use_python=False):
+    model = cfg.model or "yolov8n.pt"
+    source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
        else "https://ultralytics.com/images/bus.jpg"
-    predictor = DetectionPredictor(cfg)
-    predictor.predict_cli()
+
+    args = dict(model=model, source=source, verbose=True)
+    if use_python:
+        from ultralytics import YOLO
+        YOLO(model)(**args)
+    else:
+        predictor = DetectionPredictor(args)
+        predictor.predict_cli()

@@ -1,5 +1,5 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
-
+import sys
 from copy import copy
 
 import torch
 
@@ -194,15 +194,18 @@ class Loss:
         return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
 
 
-def train(cfg=DEFAULT_CFG):
-    cfg.model = cfg.model or "yolov8n.pt"
-    cfg.data = cfg.data or "coco128.yaml"  # or yolo.ClassificationDataset("mnist")
-    cfg.device = cfg.device if cfg.device is not None else ''
-    # trainer = DetectionTrainer(cfg)
-    # trainer.train()
-    from ultralytics import YOLO
-    model = YOLO(cfg.model)
-    model.train(**vars(cfg))
+def train(cfg=DEFAULT_CFG, use_python=False):
+    model = cfg.model or "yolov8n.pt"
+    data = cfg.data or "coco128.yaml"  # or yolo.ClassificationDataset("mnist")
+    device = cfg.device if cfg.device is not None else ''
+
+    args = dict(model=model, data=data, device=device, verbose=True)
+    if use_python:
+        from ultralytics import YOLO
+        YOLO(model).train(**args)
+    else:
+        trainer = DetectionTrainer(args)
+        trainer.train()
 
 
 if __name__ == "__main__":

@@ -1,6 +1,7 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
 import os
+import sys
 from pathlib import Path
 
 import numpy as np
 
@@ -232,11 +233,17 @@ class DetectionValidator(BaseValidator):
         return stats
 
 
-def val(cfg=DEFAULT_CFG):
-    cfg.model = cfg.model or "yolov8n.pt"
-    cfg.data = cfg.data or "coco128.yaml"
-    validator = DetectionValidator(args=cfg)
-    validator(model=cfg.model)
+def val(cfg=DEFAULT_CFG, use_python=False):
+    model = cfg.model or "yolov8n.pt"
+    data = cfg.data or "coco128.yaml"
+
+    args = dict(model=model, data=data, verbose=True)
+    if use_python:
+        from ultralytics import YOLO
+        YOLO(model).val(**args)
+    else:
+        validator = DetectionValidator(args=args)
+        validator(model=args['model'])
 
 
 if __name__ == "__main__":
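
Validation follows the same split: use_python defers to YOLO(model).val, otherwise a DetectionValidator is built from the args dict. A minimal Python-side sketch with the defaults from this hunk:

    # Sketch of the use_python branch for detection validation
    from ultralytics import YOLO

    metrics = YOLO("yolov8n.pt").val(data="coco128.yaml", verbose=True)  # returns the validation metrics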

@@ -1,5 +1,7 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
+import sys
+
 import torch
 
 from ultralytics.yolo.engine.results import Results
 
@@ -98,11 +100,17 @@ class SegmentationPredictor(DetectionPredictor):
         return log_string
 
 
-def predict(cfg=DEFAULT_CFG):
-    cfg.model = cfg.model or "yolov8n-seg.pt"
-    cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
+def predict(cfg=DEFAULT_CFG, use_python=False):
+    model = cfg.model or "yolov8n-seg.pt"
+    source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
        else "https://ultralytics.com/images/bus.jpg"
-    predictor = SegmentationPredictor(cfg)
-    predictor.predict_cli()
+
+    args = dict(model=model, source=source, verbose=True)
+    if use_python:
+        from ultralytics import YOLO
+        YOLO(model)(**args)
+    else:
+        predictor = SegmentationPredictor(args)
+        predictor.predict_cli()

@@ -1,5 +1,5 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
-
+import sys
 from copy import copy
 
 import torch
 
@@ -140,15 +140,18 @@ class SegLoss(Loss):
         return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
 
 
-def train(cfg=DEFAULT_CFG):
-    cfg.model = cfg.model or "yolov8n-seg.pt"
-    cfg.data = cfg.data or "coco128-seg.yaml"  # or yolo.ClassificationDataset("mnist")
-    cfg.device = cfg.device if cfg.device is not None else ''
-    # trainer = SegmentationTrainer(cfg)
-    # trainer.train()
-    from ultralytics import YOLO
-    model = YOLO(cfg.model)
-    model.train(**vars(cfg))
+def train(cfg=DEFAULT_CFG, use_python=False):
+    model = cfg.model or "yolov8n-seg.pt"
+    data = cfg.data or "coco128-seg.yaml"  # or yolo.ClassificationDataset("mnist")
+    device = cfg.device if cfg.device is not None else ''
+
+    args = dict(model=model, data=data, device=device, verbose=True)
+    if use_python:
+        from ultralytics import YOLO
+        YOLO(model).train(**args)
+    else:
+        trainer = SegmentationTrainer(args)
+        trainer.train()
 
 
 if __name__ == "__main__":

@@ -1,6 +1,7 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
 import os
+import sys
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
 
@@ -242,10 +243,17 @@ class SegmentationValidator(DetectionValidator):
         return stats
 
 
-def val(cfg=DEFAULT_CFG):
-    cfg.data = cfg.data or "coco128-seg.yaml"
-    validator = SegmentationValidator(args=cfg)
-    validator(model=cfg.model)
+def val(cfg=DEFAULT_CFG, use_python=False):
+    model = cfg.model or "yolov8n-seg.pt"
+    data = cfg.data or "coco128-seg.yaml"
+
+    args = dict(model=model, data=data, verbose=True)
+    if use_python:
+        from ultralytics import YOLO
+        YOLO(model).val(**args)
+    else:
+        validator = SegmentationValidator(args=args)
+        validator(model=args['model'])
 
 
 if __name__ == "__main__":
