`ultralytics 8.0.18` new python callbacks and minor fixes (#580)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jeroen Rombouts <36196499+jarombouts@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Branch: single_channel
Authored by Ayush Chaurasia 2 years ago, committed by GitHub
parent e9ab157330
commit 936414c615

@@ -108,7 +108,7 @@ yolo predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg"
 #### Python
 
 YOLOv8 may also be used directly in a Python environment, and accepts the
-same [arguments](https://docs.ultralytics.com/config/) as in the CLI example above:
+same [arguments](https://docs.ultralytics.com/cfg/) as in the CLI example above:
 
 ```python
 from ultralytics import YOLO

@@ -70,7 +70,7 @@ YOLOv8 can be used directly in the Command Line Interface (CLI) with the `yolo` command:
 yolo predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg"
 ```
-`yolo` can be used for a variety of tasks and modes and accepts additional arguments, e.g. `imgsz=640`. See the complete list of available `yolo` [arguments](https://docs.ultralytics.com/config/) in the YOLOv8 [Docs](https://docs.ultralytics.com).
+`yolo` can be used for a variety of tasks and modes and accepts additional arguments, e.g. `imgsz=640`. See the complete list of available `yolo` [arguments](https://docs.ultralytics.com/cfg/) in the YOLOv8 [Docs](https://docs.ultralytics.com).
 ```bash
 yolo task=detect mode=train model=yolov8n.pt args...
@@ -79,7 +79,7 @@ yolo task=detect mode=train model=yolov8n.pt args...
 export yolov8n.pt format=onnx args...
 ```
-YOLOv8 can also be used directly in a Python environment, and accepts the same [arguments](https://docs.ultralytics.com/config/) as in the CLI example above:
+YOLOv8 can also be used directly in a Python environment, and accepts the same [arguments](https://docs.ultralytics.com/cfg/) as in the CLI example above:
 ```python
 from ultralytics import YOLO

@@ -167,7 +167,7 @@ Default arguments can be overriden by simply passing them as arguments in the CL
 === "Example 2"
     Predict a YouTube video using a pretrained segmentation model at image size 320:
     ```bash
-    yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
+    yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
     ```
 === "Example 3"

@@ -101,7 +101,7 @@
 "source": [
     "# 1. Predict\n",
     "\n",
-    "YOLOv8 may be used directly in the Command Line Interface (CLI) with a `yolo` command for a variety of tasks and modes and accepts additional arguments, i.e. `imgsz=640`. See a full list of available `yolo` [arguments](https://docs.ultralytics.com/config/) in the YOLOv8 [Docs](https://docs.ultralytics.com).\n"
+    "YOLOv8 may be used directly in the Command Line Interface (CLI) with a `yolo` command for a variety of tasks and modes and accepts additional arguments, i.e. `imgsz=640`. See a full list of available `yolo` [arguments](https://docs.ultralytics.com/cfg/) in the YOLOv8 [Docs](https://docs.ultralytics.com).\n"
 ]
 },
 {

@@ -127,7 +127,3 @@ def test_workflow():
     model.val()
     model.predict(SOURCE)
     model.export(format="onnx", opset=12)  # export a model to ONNX format
-
-
-if __name__ == "__main__":
-    test_predict_img()

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
 
-__version__ = "8.0.17"
+__version__ = "8.0.18"
 
 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils import ops

@@ -24,7 +24,7 @@ yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100
 ```
 They may also be used directly in a Python environment, and accepts the same
-[arguments](https://docs.ultralytics.com/config/) as in the CLI example above:
+[arguments](https://docs.ultralytics.com/cfg/) as in the CLI example above:
 ```python
 from ultralytics import YOLO

@@ -222,7 +222,8 @@ class AutoBackend(nn.Module):
             nhwc = model.runtime.startswith("tensorflow")
         '''
         else:
-            raise NotImplementedError(f'ERROR: {w} is not a supported format')
+            raise NotImplementedError(f"ERROR: '{w}' is not a supported format. For supported formats see "
+                                      f"https://docs.ultralytics.com/reference/nn/")
 
         # class names
         if 'names' not in locals():

@@ -28,7 +28,7 @@ CLI_HELP_MSG = \
         yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
 
     2. Predict a YouTube video using a pretrained segmentation model at image size 320:
-        yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
+        yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
 
     3. Val a pretrained detection model at batch-size 1 and image size 640:
         yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
@@ -126,13 +126,13 @@ def merge_equals_args(args: List[str]) -> List[str]:
     """
     new_args = []
     for i, arg in enumerate(args):
-        if arg == '=' and 0 < i < len(args) - 1:
+        if arg == '=' and 0 < i < len(args) - 1:  # merge ['arg', '=', 'val']
             new_args[-1] += f"={args[i + 1]}"
             del args[i + 1]
-        elif arg.endswith('=') and i < len(args) - 1:
+        elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]:  # merge ['arg=', 'val']
            new_args.append(f"{arg}{args[i + 1]}")
            del args[i + 1]
-        elif arg.startswith('=') and i > 0:
+        elif arg.startswith('=') and i > 0:  # merge ['arg', '=val']
            new_args[-1] += arg
        else:
            new_args.append(arg)
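
For context, the three token shapes the parser now merges are `['arg', '=', 'val']`, `['arg=', 'val']` and `['arg', '=val']`; the new `'=' not in args[i + 1]` guard keeps a trailing-`=` token from swallowing a following `key=value` pair. A standalone sketch of the behavior (the function body mirrors the hunk above):

```python
def merge_equals_args(args):
    # mirrors the updated logic above; mutates the list in place while iterating
    new_args = []
    for i, arg in enumerate(args):
        if arg == '=' and 0 < i < len(args) - 1:  # merge ['arg', '=', 'val']
            new_args[-1] += f"={args[i + 1]}"
            del args[i + 1]
        elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]:  # merge ['arg=', 'val']
            new_args.append(f"{arg}{args[i + 1]}")
            del args[i + 1]
        elif arg.startswith('=') and i > 0:  # merge ['arg', '=val']
            new_args[-1] += arg
        else:
            new_args.append(arg)
    return new_args

print(merge_equals_args(['imgsz', '=', '640']))  # ['imgsz=640']
print(merge_equals_args(['imgsz=', '640']))      # ['imgsz=640']
print(merge_equals_args(['imgsz', '=640']))      # ['imgsz=640']
print(merge_equals_args(['imgsz=', 'k=v']))      # ['imgsz=', 'k=v'] - the new guard declines this merge
```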
@@ -178,7 +178,7 @@ def entrypoint(debug=False):
         if '=' in a:
             try:
                 re.sub(r' *= *', '=', a)  # remove spaces around equals sign
-                k, v = a.split('=')
+                k, v = a.split('=', 1)  # split on first '=' sign
                 if k == 'cfg':  # custom.yaml passed
                     LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {v}")
                     overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
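
`maxsplit=1` matters because override values may themselves contain `=` characters (YouTube URLs, query strings); a quick illustration:

```python
a = 'source=https://youtu.be/Zgi9g1ksQHc?t=5'
k, v = a.split('=', 1)  # split on the first '=' only
print(k, v)             # source https://youtu.be/Zgi9g1ksQHc?t=5

# the old a.split('=') returned 3 items here and raised:
# ValueError: too many values to unpack (expected 2)
```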

@@ -59,8 +59,9 @@ line_thickness: 3  # bounding box thickness (pixels)
 visualize: False  # visualize model features
 augment: False  # apply image augmentation to prediction sources
 agnostic_nms: False  # class-agnostic NMS
-retina_masks: False  # use high-resolution segmentation masks
 classes: null  # filter results by class, i.e. class=0, or class=[0,2,3]
+retina_masks: False  # use high-resolution segmentation masks
+boxes: True  # Show boxes in segmentation predictions
 
 # Export settings ------------------------------------------------------------------------------------------------------
 format: torchscript  # format to export to
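
The new `boxes` setting makes box outlines optional on segmentation output. A minimal sketch, assuming `boxes` can be passed as a per-call override like the other prediction settings:

```python
from ultralytics import YOLO

model = YOLO('yolov8n-seg.pt')
model.predict('https://ultralytics.com/images/bus.jpg', boxes=False)  # draw masks only, no box outlines
```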

@@ -28,7 +28,7 @@ from PIL import ExifTags, Image, ImageOps
 from torch.utils.data import DataLoader, Dataset, dataloader, distributed
 from tqdm import tqdm
 
-from ultralytics.yolo.data.utils import check_dataset, unzip_file
+from ultralytics.yolo.data.utils import check_det_dataset, unzip_file
 from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
                                     is_kaggle)
 from ultralytics.yolo.utils.checks import check_requirements, check_yaml
@@ -1061,7 +1061,7 @@ class HUBDatasetStats():
         except Exception as e:
             raise Exception("error/HUB/dataset_stats/yaml_load") from e
 
-        check_dataset(data, autodownload)  # download dataset if missing
+        check_det_dataset(data, autodownload)  # download dataset if missing
         self.hub_dir = Path(data['path'] + '-hub')
         self.im_dir = self.hub_dir / 'images'
         self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images

@@ -185,7 +185,7 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
     return masks, index
 
 
-def check_dataset_yaml(dataset, autodownload=True):
+def check_det_dataset(dataset, autodownload=True):
     # Download, check and/or unzip dataset if not found locally
     data = check_file(dataset)
@@ -254,7 +254,7 @@ def check_dataset_yaml(dataset, autodownload=True):
     return data  # dictionary
 
 
-def check_dataset(dataset: str):
+def check_cls_dataset(dataset: str):
     """
     Check a classification dataset such as Imagenet.

@@ -69,31 +69,25 @@ from ultralytics.nn.modules import Detect, Segment
 from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel
 from ultralytics.yolo.cfg import get_cfg
 from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
-from ultralytics.yolo.data.utils import check_dataset
+from ultralytics.yolo.data.utils import check_det_dataset
 from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
 from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
 from ultralytics.yolo.utils.files import file_size
 from ultralytics.yolo.utils.ops import Profile
-from ultralytics.yolo.utils.torch_utils import guess_task_from_head, select_device, smart_inference_mode
+from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, select_device, smart_inference_mode
 
 MACOS = platform.system() == 'Darwin'  # macOS environment
 
 
 def export_formats():
     # YOLOv8 export formats
-    x = [
-        ['PyTorch', '-', '.pt', True, True],
-        ['TorchScript', 'torchscript', '.torchscript', True, True],
-        ['ONNX', 'onnx', '.onnx', True, True],
-        ['OpenVINO', 'openvino', '_openvino_model', True, False],
-        ['TensorRT', 'engine', '.engine', False, True],
-        ['CoreML', 'coreml', '.mlmodel', True, False],
-        ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
-        ['TensorFlow GraphDef', 'pb', '.pb', True, True],
-        ['TensorFlow Lite', 'tflite', '.tflite', True, False],
-        ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
-        ['TensorFlow.js', 'tfjs', '_web_model', False, False],
-        ['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
+    x = [['PyTorch', '-', '.pt', True, True], ['TorchScript', 'torchscript', '.torchscript', True, True],
+         ['ONNX', 'onnx', '.onnx', True, True], ['OpenVINO', 'openvino', '_openvino_model', True, False],
+         ['TensorRT', 'engine', '.engine', False, True], ['CoreML', 'coreml', '.mlmodel', True, False],
+         ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
+         ['TensorFlow GraphDef', 'pb', '.pb', True, True], ['TensorFlow Lite', 'tflite', '.tflite', True, False],
+         ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
+         ['TensorFlow.js', 'tfjs', '_web_model', False, False], ['PaddlePaddle', 'paddle', '_paddle_model', True, True]]
     return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
@@ -135,7 +129,7 @@ class Exporter:
             overrides (dict, optional): Configuration overrides. Defaults to None.
         """
         self.args = get_cfg(cfg, overrides)
-        self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()})  # add callbacks
+        self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()})  # add callbacks
         callbacks.add_integration_callbacks(self)
 
     @smart_inference_mode()
@@ -241,7 +235,7 @@ class Exporter:
         # Finish
         f = [str(x) for x in f if x]  # filter out '' and None
         if any(f):
-            task = guess_task_from_head(model.yaml["head"][-1][-2])
+            task = guess_task_from_model_yaml(model)
             s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
             LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
                         f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
@@ -570,7 +564,7 @@ class Exporter:
                 if n >= n_images:
                     break
 
-        dataset = LoadImages(check_dataset(check_yaml(data))['train'], imgsz=imgsz, auto=False)
+        dataset = LoadImages(check_det_dataset(check_yaml(data))['train'], imgsz=imgsz, auto=False)
         converter.representative_dataset = lambda: representative_dataset_gen(dataset, n_images=100)
         converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
         converter.target_spec.supported_types = []

@@ -6,9 +6,9 @@ from ultralytics import yolo  # noqa
 from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
 from ultralytics.yolo.cfg import get_cfg
 from ultralytics.yolo.engine.exporter import Exporter
-from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, yaml_load
+from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, yaml_load
 from ultralytics.yolo.utils.checks import check_yaml
-from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
+from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, smart_inference_mode
 
 # Map head to model, trainer, validator, and predictor classes
 MODEL_MAP = {
@@ -68,7 +68,7 @@ class YOLO:
         """
         cfg = check_yaml(cfg)  # check YAML
         cfg_dict = yaml_load(cfg, append_filename=True)  # model dict
-        self.task = guess_task_from_head(cfg_dict["head"][-1][-2])
+        self.task = guess_task_from_model_yaml(cfg_dict)
         self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
             self._guess_ops_from_task(self.task)
         self.model = self.ModelClass(cfg_dict, verbose=verbose)  # initialize
@@ -228,6 +228,12 @@ class YOLO:
         """
         return self.model.names
 
+    def add_callback(self, event: str, func):
+        """
+        Add callback
+        """
+        callbacks.default_callbacks[event].append(func)
+
     @staticmethod
     def _reset_ckpt_args(args):
         args.pop("project", None)

@@ -88,7 +88,7 @@ class BasePredictor:
         self.vid_path, self.vid_writer = None, None
         self.annotator = None
         self.data_path = None
-        self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()})  # add callbacks
+        self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()})  # add callbacks
         callbacks.add_integration_callbacks(self)
 
     def preprocess(self, img):
@@ -172,16 +172,17 @@ class BasePredictor:
         # setup source. Run every time predict is called
         self.setup_source(source)
 
         # check if save_dir/ label file exists
-        if self.args.save:
+        if self.args.save or self.args.save_txt:
             (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
 
         # warmup model
         if not self.done_warmup:
             self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.bs, 3, *self.imgsz))
             self.done_warmup = True
 
-        self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
+        self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
         for batch in self.dataset:
             self.run_callbacks("on_predict_batch_start")
+            self.batch = batch
             path, im, im0s, vid_cap, s = batch
             visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
             with self.dt[0]:
@@ -195,13 +196,13 @@ class BasePredictor:
             # postprocess
             with self.dt[2]:
-                results = self.postprocess(preds, im, im0s, self.classes)
+                self.results = self.postprocess(preds, im, im0s, self.classes)
 
             for i in range(len(im)):
                 p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
                 p = Path(p)
 
                 if verbose or self.args.save or self.args.save_txt or self.args.show:
-                    s += self.write_results(i, results, (p, im, im0))
+                    s += self.write_results(i, self.results, (p, im, im0))
 
                 if self.args.show:
                     self.show(p)
@@ -209,22 +210,21 @@ class BasePredictor:
                 if self.args.save:
                     self.save_preds(vid_cap, i, str(self.save_dir / p.name))
 
-            yield from results
+            self.run_callbacks("on_predict_batch_end")
+            yield from self.results
 
             # Print time (inference-only)
             if verbose:
                 LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
 
-            self.run_callbacks("on_predict_batch_end")
-
         # Print results
         if verbose and self.seen:
             t = tuple(x.t / self.seen * 1E3 for x in self.dt)  # speeds per image
             LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape '
                         f'{(1, 3, *self.imgsz)}' % t)
 
         if self.args.save_txt or self.args.save:
-            s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" \
-                if self.args.save_txt else ''
+            nl = len(list(self.save_dir.glob('labels/*.txt')))  # number of labels
+            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
             LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
 
         self.run_callbacks("on_predict_end")

@@ -20,19 +20,18 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import lr_scheduler
 from tqdm import tqdm
 
-import ultralytics.yolo.utils as utils
 from ultralytics import __version__
 from ultralytics.nn.tasks import attempt_load_one_weight
 from ultralytics.yolo.cfg import get_cfg
-from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
+from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
-                                    yaml_save)
+                                    emojis, yaml_save)
 from ultralytics.yolo.utils.autobatch import check_train_batch_size
 from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
 from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.yolo.utils.files import get_latest_run, increment_path
 from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle,
-                                                strip_optimizer)
+                                                select_device, strip_optimizer)
 
 
 class BaseTrainer:
@@ -81,7 +80,7 @@ class BaseTrainer:
             overrides (dict, optional): Configuration overrides. Defaults to None.
         """
         self.args = get_cfg(cfg, overrides)
-        self.device = utils.torch_utils.select_device(self.args.device, self.args.batch)
+        self.device = select_device(self.args.device, self.args.batch)
         self.check_resume()
         self.console = LOGGER
         self.validator = None
@@ -120,9 +119,11 @@ class BaseTrainer:
         self.model = self.args.model
         self.data = self.args.data
         if self.data.endswith(".yaml"):
-            self.data = check_dataset_yaml(self.data)
+            self.data = check_det_dataset(self.data)
+        elif self.args.task == 'classify':
+            self.data = check_cls_dataset(self.data)
         else:
-            self.data = check_dataset(self.data)
+            raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌"))
         self.trainset, self.testset = self.get_dataset(self.data)
         self.ema = None
@@ -140,7 +141,7 @@ class BaseTrainer:
         self.plot_idx = [0, 1, 2]
 
         # Callbacks
-        self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()})  # add callbacks
+        self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()})  # add callbacks
         if RANK in {0, -1}:
             callbacks.add_integration_callbacks(self)

@@ -9,8 +9,8 @@ from tqdm import tqdm
 
 from ultralytics.nn.autobackend import AutoBackend
 from ultralytics.yolo.cfg import get_cfg
-from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
-from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks
+from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
+from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis
 from ultralytics.yolo.utils.checks import check_imgsz
 from ultralytics.yolo.utils.files import increment_path
 from ultralytics.yolo.utils.ops import Profile
@@ -70,7 +70,7 @@ class BaseValidator:
         if self.args.conf is None:
             self.args.conf = 0.001  # default conf=0.001
 
-        self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()})  # add callbacks
+        self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()})  # add callbacks
 
     @smart_inference_mode()
     def __call__(self, trainer=None, model=None):
@@ -109,9 +109,11 @@ class BaseValidator:
                             f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
 
             if isinstance(self.args.data, str) and self.args.data.endswith(".yaml"):
-                self.data = check_dataset_yaml(self.args.data)
+                self.data = check_det_dataset(self.args.data)
+            elif self.args.task == 'classify':
+                self.data = check_cls_dataset(self.args.data)
             else:
-                self.data = check_dataset(self.args.data)
+                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌"))
 
             if self.device.type == 'cpu':
                 self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading

@@ -68,7 +68,7 @@ HELP_MSG = \
         yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
 
     - Predict a YouTube video using a pretrained segmentation model at image size 320:
-        yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
+        yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
 
     - Val a pretrained detection model at batch-size 1 and image size 640:
         yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
@@ -109,6 +109,9 @@ class IterableSimpleNamespace(SimpleNamespace):
     def __str__(self):
         return '\n'.join(f"{k}={v}" for k, v in vars(self).items())
 
+    def get(self, key, default=None):
+        return getattr(self, key, default)
+
 
 # Default configuration
 with open(DEFAULT_CFG_PATH, errors='ignore') as f:
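
The new `get` gives the config namespace a dict-style accessor, so downstream code can probe optional keys (such as the new `boxes` setting) without `getattr` boilerplate. An illustrative standalone equivalent:

```python
from types import SimpleNamespace

class IterableSimpleNamespace(SimpleNamespace):
    def get(self, key, default=None):
        return getattr(self, key, default)

cfg = IterableSimpleNamespace(imgsz=640)
print(cfg.get('imgsz'))        # 640
print(cfg.get('boxes', True))  # True (falls back to the default)
```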

@@ -106,36 +106,36 @@ def on_export_end(exporter):
 default_callbacks = {
     # Run in trainer
-    'on_pretrain_routine_start': on_pretrain_routine_start,
-    'on_pretrain_routine_end': on_pretrain_routine_end,
-    'on_train_start': on_train_start,
-    'on_train_epoch_start': on_train_epoch_start,
-    'on_train_batch_start': on_train_batch_start,
-    'optimizer_step': optimizer_step,
-    'on_before_zero_grad': on_before_zero_grad,
-    'on_train_batch_end': on_train_batch_end,
-    'on_train_epoch_end': on_train_epoch_end,
-    'on_fit_epoch_end': on_fit_epoch_end,  # fit = train + val
-    'on_model_save': on_model_save,
-    'on_train_end': on_train_end,
-    'on_params_update': on_params_update,
-    'teardown': teardown,
+    'on_pretrain_routine_start': [on_pretrain_routine_start],
+    'on_pretrain_routine_end': [on_pretrain_routine_end],
+    'on_train_start': [on_train_start],
+    'on_train_epoch_start': [on_train_epoch_start],
+    'on_train_batch_start': [on_train_batch_start],
+    'optimizer_step': [optimizer_step],
+    'on_before_zero_grad': [on_before_zero_grad],
+    'on_train_batch_end': [on_train_batch_end],
+    'on_train_epoch_end': [on_train_epoch_end],
+    'on_fit_epoch_end': [on_fit_epoch_end],  # fit = train + val
+    'on_model_save': [on_model_save],
+    'on_train_end': [on_train_end],
+    'on_params_update': [on_params_update],
+    'teardown': [teardown],
 
     # Run in validator
-    'on_val_start': on_val_start,
-    'on_val_batch_start': on_val_batch_start,
-    'on_val_batch_end': on_val_batch_end,
-    'on_val_end': on_val_end,
+    'on_val_start': [on_val_start],
+    'on_val_batch_start': [on_val_batch_start],
+    'on_val_batch_end': [on_val_batch_end],
+    'on_val_end': [on_val_end],
 
     # Run in predictor
-    'on_predict_start': on_predict_start,
-    'on_predict_batch_start': on_predict_batch_start,
-    'on_predict_batch_end': on_predict_batch_end,
-    'on_predict_end': on_predict_end,
+    'on_predict_start': [on_predict_start],
+    'on_predict_batch_start': [on_predict_batch_start],
+    'on_predict_batch_end': [on_predict_batch_end],
+    'on_predict_end': [on_predict_end],
 
     # Run in exporter
-    'on_export_start': on_export_start,
-    'on_export_end': on_export_end}
+    'on_export_start': [on_export_start],
+    'on_export_end': [on_export_end]}
 
 
 def add_integration_callbacks(instance):
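
Each event now maps to a list of handlers rather than a single function, so several callbacks can subscribe to one event; consumers copy the mapping with `{k: v for ...}` (see the Trainer/Validator/Predictor/Exporter hunks above) and dispatch every entry. A condensed sketch of the pattern:

```python
from collections import defaultdict

def on_train_start(trainer):
    print('default hook')

default_callbacks = {'on_train_start': [on_train_start]}  # list-valued, as above

cbs = defaultdict(list, {k: v for k, v in default_callbacks.items()})
cbs['on_train_start'].append(lambda trainer: print('user hook'))  # second subscriber

for callback in cbs['on_train_start']:  # run_callbacks-style dispatch
    callback(trainer=None)
```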

@@ -307,18 +307,20 @@ def strip_optimizer(f='best.pt', s=''):
     LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
 
 
-def guess_task_from_head(head):
-    task = None
-    if head.lower() in ["classify", "classifier", "cls", "fc"]:
-        task = "classify"
-    if head.lower() in ["detect"]:
-        task = "detect"
-    if head.lower() in ["segment"]:
-        task = "segment"
-
-    if not task:
-        raise SyntaxError("task or model not recognized! Please refer the docs at : ")  # TODO: add docs links
+def guess_task_from_model_yaml(model):
+    try:
+        cfg = model if isinstance(model, dict) else model.yaml  # model cfg dict
+        m = cfg["head"][-1][-2].lower()  # output module name
+        task = None
+        if m in ["classify", "classifier", "cls", "fc"]:
+            task = "classify"
+        if m in ["detect"]:
+            task = "detect"
+        if m in ["segment"]:
+            task = "segment"
+    except Exception as e:
+        raise SyntaxError('Unknown task. Define task explicitly, i.e. task=detect when running your command. '
+                          'Valid tasks are detect, segment, classify.') from e
     return task
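
The renamed helper leans on the model YAML layout, where each `head` row is `[from, repeats, module, args]`, so `["head"][-1][-2]` is the module name of the final layer. An illustrative fragment (a hypothetical minimal cfg, not a full model file):

```python
cfg = {'head': [[[15, 18, 21], 1, 'Detect', ['nc']]]}  # last head row: [from, repeats, module, args]
print(cfg['head'][-1][-2].lower())  # 'detect' -> task = 'detect'
```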
@@ -374,14 +376,36 @@ def profile(input, ops, n=10, device=None):
 class EarlyStopping:
-    # early stopper
+    """
+    Early stopping class that stops training when a specified number of epochs have passed without improvement.
+    """
 
     def __init__(self, patience=30):
+        """
+        Initialize early stopping object
+
+        Args:
+            patience (int, optional): Number of epochs to wait after fitness stops improving before stopping. Default is 30.
+        """
         self.best_fitness = 0.0  # i.e. mAP
         self.best_epoch = 0
         self.patience = patience or float('inf')  # epochs to wait after fitness stops improving to stop
         self.possible_stop = False  # possible stop may occur next epoch
 
     def __call__(self, epoch, fitness):
+        """
+        Check whether to stop training
+
+        Args:
+            epoch (int): Current epoch of training
+            fitness (float): Fitness value of current epoch
+
+        Returns:
+            bool: True if training should stop, False otherwise
+        """
+        if fitness is None:  # check if fitness=None (happens when val=False)
+            return False
+
         if fitness >= self.best_fitness:  # >= 0 to allow for early zero-fitness stage of training
             self.best_epoch = epoch
             self.best_fitness = fitness
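
A usage sketch for the newly documented stopper; this assumes the unchanged remainder of `__call__` keeps the usual rule of stopping once `epoch - best_epoch >= patience`:

```python
from ultralytics.yolo.utils.torch_utils import EarlyStopping

stopper = EarlyStopping(patience=3)
for epoch, fitness in enumerate([0.10, 0.50, 0.40, 0.30, 0.20, 0.10]):
    if stopper(epoch, fitness):  # best fitness is at epoch 1, so this trips at epoch 4
        print(f'early stop at epoch {epoch}')
        break
```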

@@ -10,6 +10,7 @@ class ClassificationValidator(BaseValidator):
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
         super().__init__(dataloader, save_dir, pbar, logger, args)
+        self.args.task = 'classify'
         self.metrics = ClassifyMetrics()
 
     def get_desc(self):

@@ -20,6 +20,7 @@ class DetectionValidator(BaseValidator):
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
         super().__init__(dataloader, save_dir, pbar, logger, args)
+        self.args.task = 'detect'
         self.data_dict = yaml_load(check_file(self.args.data), append_filename=True) if self.args.data else None
         self.is_coco = False
         self.class_map = None

@@ -87,7 +87,7 @@ class SegmentationPredictor(DetectionPredictor):
                 c = int(cls)  # integer class
                 label = None if self.args.hide_labels else (
                     self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
-                self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
+                self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.boxes else None
             if self.args.save_crop:
                 imc = im0.copy()
                 save_one_box(d.xyxy,

@@ -19,7 +19,7 @@ class SegmentationValidator(DetectionValidator):
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
         super().__init__(dataloader, save_dir, pbar, logger, args)
-        self.args.task = "segment"
+        self.args.task = 'segment'
         self.metrics = SegmentMetrics(save_dir=self.save_dir)
 
     def preprocess(self, batch):
