ultralytics 8.0.18 new python callbacks and minor fixes (#580)
				
					
				
			Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jeroen Rombouts <36196499+jarombouts@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
		| @ -108,7 +108,7 @@ yolo predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg" | ||||
| #### Python | ||||
|  | ||||
| YOLOv8 may also be used directly in a Python environment, and accepts the | ||||
| same [arguments](https://docs.ultralytics.com/config/) as in the CLI example above: | ||||
| same [arguments](https://docs.ultralytics.com/cfg/) as in the CLI example above: | ||||
|  | ||||
| ```python | ||||
| from ultralytics import YOLO | ||||
|  | ||||
| @ -70,7 +70,7 @@ YOLOv8 可以直接在命令行界面(CLI)中使用 `yolo` 命令运行: | ||||
| yolo predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg" | ||||
| ``` | ||||
|  | ||||
| `yolo`可以用于各种任务和模式,并接受额外的参数,例如 `imgsz=640`。参见 YOLOv8 [文档](https://docs.ultralytics.com)中可用`yolo`[参数](https://docs.ultralytics.com/config/)的完整列表。 | ||||
| `yolo`可以用于各种任务和模式,并接受额外的参数,例如 `imgsz=640`。参见 YOLOv8 [文档](https://docs.ultralytics.com)中可用`yolo`[参数](https://docs.ultralytics.com/cfg/)的完整列表。 | ||||
|  | ||||
| ```bash | ||||
| yolo task=detect    mode=train    model=yolov8n.pt        args... | ||||
| @ -79,7 +79,7 @@ yolo task=detect    mode=train    model=yolov8n.pt        args... | ||||
|                          export         yolov8n.pt        format=onnx  args... | ||||
| ``` | ||||
|  | ||||
| YOLOv8 也可以在 Python 环境中直接使用,并接受与上面 CLI 例子中相同的[参数](https://docs.ultralytics.com/config/): | ||||
| YOLOv8 也可以在 Python 环境中直接使用,并接受与上面 CLI 例子中相同的[参数](https://docs.ultralytics.com/cfg/): | ||||
|  | ||||
| ```python | ||||
| from ultralytics import YOLO | ||||
|  | ||||
| @ -167,7 +167,7 @@ Default arguments can be overriden by simply passing them as arguments in the CL | ||||
|     === "Example 2" | ||||
|         Predict a YouTube video using a pretrained segmentation model at image size 320: | ||||
|         ```bash | ||||
|         yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320 | ||||
|         yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320 | ||||
|         ``` | ||||
|  | ||||
|     === "Example 3" | ||||
|  | ||||
| @ -101,7 +101,7 @@ | ||||
|       "source": [ | ||||
|         "# 1. Predict\n", | ||||
|         "\n", | ||||
|         "YOLOv8 may be used directly in the Command Line Interface (CLI) with a `yolo` command for a variety of tasks and modes and accepts additional arguments, i.e. `imgsz=640`. See a full list of available `yolo` [arguments](https://docs.ultralytics.com/config/) in the YOLOv8 [Docs](https://docs.ultralytics.com).\n" | ||||
|         "YOLOv8 may be used directly in the Command Line Interface (CLI) with a `yolo` command for a variety of tasks and modes and accepts additional arguments, i.e. `imgsz=640`. See a full list of available `yolo` [arguments](https://docs.ultralytics.com/cfg/) in the YOLOv8 [Docs](https://docs.ultralytics.com).\n" | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|  | ||||
| @ -127,7 +127,3 @@ def test_workflow(): | ||||
|     model.val() | ||||
|     model.predict(SOURCE) | ||||
|     model.export(format="onnx", opset=12)  # export a model to ONNX format | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     test_predict_img() | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
|  | ||||
| __version__ = "8.0.17" | ||||
| __version__ = "8.0.18" | ||||
|  | ||||
| from ultralytics.yolo.engine.model import YOLO | ||||
| from ultralytics.yolo.utils import ops | ||||
|  | ||||
| @ -24,7 +24,7 @@ yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100 | ||||
| ``` | ||||
|  | ||||
| They may also be used directly in a Python environment, and accepts the same | ||||
| [arguments](https://docs.ultralytics.com/config/) as in the CLI example above: | ||||
| [arguments](https://docs.ultralytics.com/cfg/) as in the CLI example above: | ||||
|  | ||||
| ```python | ||||
| from ultralytics import YOLO | ||||
|  | ||||
| @ -222,7 +222,8 @@ class AutoBackend(nn.Module): | ||||
|             nhwc = model.runtime.startswith("tensorflow") | ||||
|             ''' | ||||
|         else: | ||||
|             raise NotImplementedError(f'ERROR: {w} is not a supported format') | ||||
|             raise NotImplementedError(f"ERROR: '{w}' is not a supported format. For supported formats see " | ||||
|                                       f"https://docs.ultralytics.com/reference/nn/") | ||||
|  | ||||
|         # class names | ||||
|         if 'names' not in locals(): | ||||
|  | ||||
| @ -28,7 +28,7 @@ CLI_HELP_MSG = \ | ||||
|         yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01 | ||||
|  | ||||
|     2. Predict a YouTube video using a pretrained segmentation model at image size 320: | ||||
|         yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320 | ||||
|         yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320 | ||||
|  | ||||
|     3. Val a pretrained detection model at batch-size 1 and image size 640: | ||||
|         yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640 | ||||
| @ -126,13 +126,13 @@ def merge_equals_args(args: List[str]) -> List[str]: | ||||
|     """ | ||||
|     new_args = [] | ||||
|     for i, arg in enumerate(args): | ||||
|         if arg == '=' and 0 < i < len(args) - 1: | ||||
|         if arg == '=' and 0 < i < len(args) - 1:  # merge ['arg', '=', 'val'] | ||||
|             new_args[-1] += f"={args[i + 1]}" | ||||
|             del args[i + 1] | ||||
|         elif arg.endswith('=') and i < len(args) - 1: | ||||
|         elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]:  # merge ['arg=', 'val'] | ||||
|             new_args.append(f"{arg}{args[i + 1]}") | ||||
|             del args[i + 1] | ||||
|         elif arg.startswith('=') and i > 0: | ||||
|         elif arg.startswith('=') and i > 0:  # merge ['arg', '=val'] | ||||
|             new_args[-1] += arg | ||||
|         else: | ||||
|             new_args.append(arg) | ||||
| @ -178,7 +178,7 @@ def entrypoint(debug=False): | ||||
|         if '=' in a: | ||||
|             try: | ||||
|                 re.sub(r' *= *', '=', a)  # remove spaces around equals sign | ||||
|                 k, v = a.split('=') | ||||
|                 k, v = a.split('=', 1)  # split on first '=' sign | ||||
|                 if k == 'cfg':  # custom.yaml passed | ||||
|                     LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {v}") | ||||
|                     overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'} | ||||
|  | ||||
| @ -59,8 +59,9 @@ line_thickness: 3  # bounding box thickness (pixels) | ||||
| visualize: False  # visualize model features | ||||
| augment: False  # apply image augmentation to prediction sources | ||||
| agnostic_nms: False  # class-agnostic NMS | ||||
| retina_masks: False  # use high-resolution segmentation masks | ||||
| classes: null  # filter results by class, i.e. class=0, or class=[0,2,3] | ||||
| retina_masks: False  # use high-resolution segmentation masks | ||||
| boxes: True # Show boxes in segmentation predictions | ||||
|  | ||||
| # Export settings ------------------------------------------------------------------------------------------------------ | ||||
| format: torchscript  # format to export to | ||||
|  | ||||
| @ -28,7 +28,7 @@ from PIL import ExifTags, Image, ImageOps | ||||
| from torch.utils.data import DataLoader, Dataset, dataloader, distributed | ||||
| from tqdm import tqdm | ||||
|  | ||||
| from ultralytics.yolo.data.utils import check_dataset, unzip_file | ||||
| from ultralytics.yolo.data.utils import check_det_dataset, unzip_file | ||||
| from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable, | ||||
|                                     is_kaggle) | ||||
| from ultralytics.yolo.utils.checks import check_requirements, check_yaml | ||||
| @ -1061,7 +1061,7 @@ class HUBDatasetStats(): | ||||
|         except Exception as e: | ||||
|             raise Exception("error/HUB/dataset_stats/yaml_load") from e | ||||
|  | ||||
|         check_dataset(data, autodownload)  # download dataset if missing | ||||
|         check_det_dataset(data, autodownload)  # download dataset if missing | ||||
|         self.hub_dir = Path(data['path'] + '-hub') | ||||
|         self.im_dir = self.hub_dir / 'images' | ||||
|         self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images | ||||
|  | ||||
| @ -185,7 +185,7 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1): | ||||
|     return masks, index | ||||
|  | ||||
|  | ||||
| def check_dataset_yaml(dataset, autodownload=True): | ||||
| def check_det_dataset(dataset, autodownload=True): | ||||
|     # Download, check and/or unzip dataset if not found locally | ||||
|     data = check_file(dataset) | ||||
|  | ||||
| @ -254,7 +254,7 @@ def check_dataset_yaml(dataset, autodownload=True): | ||||
|     return data  # dictionary | ||||
|  | ||||
|  | ||||
| def check_dataset(dataset: str): | ||||
| def check_cls_dataset(dataset: str): | ||||
|     """ | ||||
|     Check a classification dataset such as Imagenet. | ||||
|  | ||||
|  | ||||
| @ -69,31 +69,25 @@ from ultralytics.nn.modules import Detect, Segment | ||||
| from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel | ||||
| from ultralytics.yolo.cfg import get_cfg | ||||
| from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages | ||||
| from ultralytics.yolo.data.utils import check_dataset | ||||
| from ultralytics.yolo.data.utils import check_det_dataset | ||||
| from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get_default_args, yaml_save | ||||
| from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml | ||||
| from ultralytics.yolo.utils.files import file_size | ||||
| from ultralytics.yolo.utils.ops import Profile | ||||
| from ultralytics.yolo.utils.torch_utils import guess_task_from_head, select_device, smart_inference_mode | ||||
| from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, select_device, smart_inference_mode | ||||
|  | ||||
| MACOS = platform.system() == 'Darwin'  # macOS environment | ||||
|  | ||||
|  | ||||
| def export_formats(): | ||||
|     # YOLOv8 export formats | ||||
|     x = [ | ||||
|         ['PyTorch', '-', '.pt', True, True], | ||||
|         ['TorchScript', 'torchscript', '.torchscript', True, True], | ||||
|         ['ONNX', 'onnx', '.onnx', True, True], | ||||
|         ['OpenVINO', 'openvino', '_openvino_model', True, False], | ||||
|         ['TensorRT', 'engine', '.engine', False, True], | ||||
|         ['CoreML', 'coreml', '.mlmodel', True, False], | ||||
|         ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True], | ||||
|         ['TensorFlow GraphDef', 'pb', '.pb', True, True], | ||||
|         ['TensorFlow Lite', 'tflite', '.tflite', True, False], | ||||
|         ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False], | ||||
|         ['TensorFlow.js', 'tfjs', '_web_model', False, False], | ||||
|         ['PaddlePaddle', 'paddle', '_paddle_model', True, True],] | ||||
|     x = [['PyTorch', '-', '.pt', True, True], ['TorchScript', 'torchscript', '.torchscript', True, True], | ||||
|          ['ONNX', 'onnx', '.onnx', True, True], ['OpenVINO', 'openvino', '_openvino_model', True, False], | ||||
|          ['TensorRT', 'engine', '.engine', False, True], ['CoreML', 'coreml', '.mlmodel', True, False], | ||||
|          ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True], | ||||
|          ['TensorFlow GraphDef', 'pb', '.pb', True, True], ['TensorFlow Lite', 'tflite', '.tflite', True, False], | ||||
|          ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False], | ||||
|          ['TensorFlow.js', 'tfjs', '_web_model', False, False], ['PaddlePaddle', 'paddle', '_paddle_model', True, True]] | ||||
|     return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU']) | ||||
|  | ||||
|  | ||||
| @ -135,7 +129,7 @@ class Exporter: | ||||
|             overrides (dict, optional): Configuration overrides. Defaults to None. | ||||
|         """ | ||||
|         self.args = get_cfg(cfg, overrides) | ||||
|         self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()})  # add callbacks | ||||
|         self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()})  # add callbacks | ||||
|         callbacks.add_integration_callbacks(self) | ||||
|  | ||||
|     @smart_inference_mode() | ||||
| @ -241,7 +235,7 @@ class Exporter: | ||||
|         # Finish | ||||
|         f = [str(x) for x in f if x]  # filter out '' and None | ||||
|         if any(f): | ||||
|             task = guess_task_from_head(model.yaml["head"][-1][-2]) | ||||
|             task = guess_task_from_model_yaml(model) | ||||
|             s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models" | ||||
|             LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)' | ||||
|                         f"\nResults saved to {colorstr('bold', file.parent.resolve())}" | ||||
| @ -570,7 +564,7 @@ class Exporter: | ||||
|                     if n >= n_images: | ||||
|                         break | ||||
|  | ||||
|             dataset = LoadImages(check_dataset(check_yaml(data))['train'], imgsz=imgsz, auto=False) | ||||
|             dataset = LoadImages(check_det_dataset(check_yaml(data))['train'], imgsz=imgsz, auto=False) | ||||
|             converter.representative_dataset = lambda: representative_dataset_gen(dataset, n_images=100) | ||||
|             converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] | ||||
|             converter.target_spec.supported_types = [] | ||||
|  | ||||
| @ -6,9 +6,9 @@ from ultralytics import yolo  # noqa | ||||
| from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight | ||||
| from ultralytics.yolo.cfg import get_cfg | ||||
| from ultralytics.yolo.engine.exporter import Exporter | ||||
| from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, yaml_load | ||||
| from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, yaml_load | ||||
| from ultralytics.yolo.utils.checks import check_yaml | ||||
| from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode | ||||
| from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, smart_inference_mode | ||||
|  | ||||
| # Map head to model, trainer, validator, and predictor classes | ||||
| MODEL_MAP = { | ||||
| @ -68,7 +68,7 @@ class YOLO: | ||||
|         """ | ||||
|         cfg = check_yaml(cfg)  # check YAML | ||||
|         cfg_dict = yaml_load(cfg, append_filename=True)  # model dict | ||||
|         self.task = guess_task_from_head(cfg_dict["head"][-1][-2]) | ||||
|         self.task = guess_task_from_model_yaml(cfg_dict) | ||||
|         self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \ | ||||
|             self._guess_ops_from_task(self.task) | ||||
|         self.model = self.ModelClass(cfg_dict, verbose=verbose)  # initialize | ||||
| @ -228,6 +228,12 @@ class YOLO: | ||||
|         """ | ||||
|         return self.model.names | ||||
|  | ||||
|     def add_callback(self, event: str, func): | ||||
|         """ | ||||
|         Add callback | ||||
|         """ | ||||
|         callbacks.default_callbacks[event].append(func) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _reset_ckpt_args(args): | ||||
|         args.pop("project", None) | ||||
|  | ||||
| @ -88,7 +88,7 @@ class BasePredictor: | ||||
|         self.vid_path, self.vid_writer = None, None | ||||
|         self.annotator = None | ||||
|         self.data_path = None | ||||
|         self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()})  # add callbacks | ||||
|         self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()})  # add callbacks | ||||
|         callbacks.add_integration_callbacks(self) | ||||
|  | ||||
|     def preprocess(self, img): | ||||
| @ -172,16 +172,17 @@ class BasePredictor: | ||||
|         # setup source. Run every time predict is called | ||||
|         self.setup_source(source) | ||||
|         # check if save_dir/ label file exists | ||||
|         if self.args.save: | ||||
|         if self.args.save or self.args.save_txt: | ||||
|             (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) | ||||
|         # warmup model | ||||
|         if not self.done_warmup: | ||||
|             self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.bs, 3, *self.imgsz)) | ||||
|             self.done_warmup = True | ||||
|  | ||||
|         self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()) | ||||
|         self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None | ||||
|         for batch in self.dataset: | ||||
|             self.run_callbacks("on_predict_batch_start") | ||||
|             self.batch = batch | ||||
|             path, im, im0s, vid_cap, s = batch | ||||
|             visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False | ||||
|             with self.dt[0]: | ||||
| @ -195,13 +196,13 @@ class BasePredictor: | ||||
|  | ||||
|             # postprocess | ||||
|             with self.dt[2]: | ||||
|                 results = self.postprocess(preds, im, im0s, self.classes) | ||||
|                 self.results = self.postprocess(preds, im, im0s, self.classes) | ||||
|             for i in range(len(im)): | ||||
|                 p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s) | ||||
|                 p = Path(p) | ||||
|  | ||||
|                 if verbose or self.args.save or self.args.save_txt or self.args.show: | ||||
|                     s += self.write_results(i, results, (p, im, im0)) | ||||
|                     s += self.write_results(i, self.results, (p, im, im0)) | ||||
|  | ||||
|                 if self.args.show: | ||||
|                     self.show(p) | ||||
| @ -209,22 +210,21 @@ class BasePredictor: | ||||
|                 if self.args.save: | ||||
|                     self.save_preds(vid_cap, i, str(self.save_dir / p.name)) | ||||
|  | ||||
|             yield from results | ||||
|             self.run_callbacks("on_predict_batch_end") | ||||
|             yield from self.results | ||||
|  | ||||
|             # Print time (inference-only) | ||||
|             if verbose: | ||||
|                 LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms") | ||||
|  | ||||
|             self.run_callbacks("on_predict_batch_end") | ||||
|  | ||||
|         # Print results | ||||
|         if verbose and self.seen: | ||||
|             t = tuple(x.t / self.seen * 1E3 for x in self.dt)  # speeds per image | ||||
|             LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape ' | ||||
|                         f'{(1, 3, *self.imgsz)}' % t) | ||||
|         if self.args.save_txt or self.args.save: | ||||
|             s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" \ | ||||
|                 if self.args.save_txt else '' | ||||
|             nl = len(list(self.save_dir.glob('labels/*.txt')))  # number of labels | ||||
|             s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else '' | ||||
|             LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") | ||||
|  | ||||
|         self.run_callbacks("on_predict_end") | ||||
|  | ||||
| @ -20,19 +20,18 @@ from torch.nn.parallel import DistributedDataParallel as DDP | ||||
| from torch.optim import lr_scheduler | ||||
| from tqdm import tqdm | ||||
|  | ||||
| import ultralytics.yolo.utils as utils | ||||
| from ultralytics import __version__ | ||||
| from ultralytics.nn.tasks import attempt_load_one_weight | ||||
| from ultralytics.yolo.cfg import get_cfg | ||||
| from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml | ||||
| from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset | ||||
| from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, | ||||
|                                     yaml_save) | ||||
|                                     emojis, yaml_save) | ||||
| from ultralytics.yolo.utils.autobatch import check_train_batch_size | ||||
| from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args | ||||
| from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command | ||||
| from ultralytics.yolo.utils.files import get_latest_run, increment_path | ||||
| from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, | ||||
|                                                 strip_optimizer) | ||||
|                                                 select_device, strip_optimizer) | ||||
|  | ||||
|  | ||||
| class BaseTrainer: | ||||
| @ -81,7 +80,7 @@ class BaseTrainer: | ||||
|             overrides (dict, optional): Configuration overrides. Defaults to None. | ||||
|         """ | ||||
|         self.args = get_cfg(cfg, overrides) | ||||
|         self.device = utils.torch_utils.select_device(self.args.device, self.args.batch) | ||||
|         self.device = select_device(self.args.device, self.args.batch) | ||||
|         self.check_resume() | ||||
|         self.console = LOGGER | ||||
|         self.validator = None | ||||
| @ -120,9 +119,11 @@ class BaseTrainer: | ||||
|         self.model = self.args.model | ||||
|         self.data = self.args.data | ||||
|         if self.data.endswith(".yaml"): | ||||
|             self.data = check_dataset_yaml(self.data) | ||||
|             self.data = check_det_dataset(self.data) | ||||
|         elif self.args.task == 'classify': | ||||
|             self.data = check_cls_dataset(self.data) | ||||
|         else: | ||||
|             self.data = check_dataset(self.data) | ||||
|             raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌")) | ||||
|         self.trainset, self.testset = self.get_dataset(self.data) | ||||
|         self.ema = None | ||||
|  | ||||
| @ -140,7 +141,7 @@ class BaseTrainer: | ||||
|         self.plot_idx = [0, 1, 2] | ||||
|  | ||||
|         # Callbacks | ||||
|         self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()})  # add callbacks | ||||
|         self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()})  # add callbacks | ||||
|         if RANK in {0, -1}: | ||||
|             callbacks.add_integration_callbacks(self) | ||||
|  | ||||
|  | ||||
| @ -9,8 +9,8 @@ from tqdm import tqdm | ||||
|  | ||||
| from ultralytics.nn.autobackend import AutoBackend | ||||
| from ultralytics.yolo.cfg import get_cfg | ||||
| from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml | ||||
| from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks | ||||
| from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset | ||||
| from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis | ||||
| from ultralytics.yolo.utils.checks import check_imgsz | ||||
| from ultralytics.yolo.utils.files import increment_path | ||||
| from ultralytics.yolo.utils.ops import Profile | ||||
| @ -70,7 +70,7 @@ class BaseValidator: | ||||
|         if self.args.conf is None: | ||||
|             self.args.conf = 0.001  # default conf=0.001 | ||||
|  | ||||
|         self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()})  # add callbacks | ||||
|         self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()})  # add callbacks | ||||
|  | ||||
|     @smart_inference_mode() | ||||
|     def __call__(self, trainer=None, model=None): | ||||
| @ -109,9 +109,11 @@ class BaseValidator: | ||||
|                         f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') | ||||
|  | ||||
|             if isinstance(self.args.data, str) and self.args.data.endswith(".yaml"): | ||||
|                 self.data = check_dataset_yaml(self.args.data) | ||||
|                 self.data = check_det_dataset(self.args.data) | ||||
|             elif self.args.task == 'classify': | ||||
|                 self.data = check_cls_dataset(self.args.data) | ||||
|             else: | ||||
|                 self.data = check_dataset(self.args.data) | ||||
|                 raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌")) | ||||
|  | ||||
|             if self.device.type == 'cpu': | ||||
|                 self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading | ||||
|  | ||||
| @ -68,7 +68,7 @@ HELP_MSG = \ | ||||
|             yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01 | ||||
|  | ||||
|         - Predict a YouTube video using a pretrained segmentation model at image size 320: | ||||
|             yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320 | ||||
|             yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320 | ||||
|  | ||||
|         - Val a pretrained detection model at batch-size 1 and image size 640: | ||||
|             yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640 | ||||
| @ -109,6 +109,9 @@ class IterableSimpleNamespace(SimpleNamespace): | ||||
|     def __str__(self): | ||||
|         return '\n'.join(f"{k}={v}" for k, v in vars(self).items()) | ||||
|  | ||||
|     def get(self, key, default=None): | ||||
|         return getattr(self, key, default) | ||||
|  | ||||
|  | ||||
| # Default configuration | ||||
| with open(DEFAULT_CFG_PATH, errors='ignore') as f: | ||||
|  | ||||
| @ -106,36 +106,36 @@ def on_export_end(exporter): | ||||
|  | ||||
| default_callbacks = { | ||||
|     # Run in trainer | ||||
|     'on_pretrain_routine_start': on_pretrain_routine_start, | ||||
|     'on_pretrain_routine_end': on_pretrain_routine_end, | ||||
|     'on_train_start': on_train_start, | ||||
|     'on_train_epoch_start': on_train_epoch_start, | ||||
|     'on_train_batch_start': on_train_batch_start, | ||||
|     'optimizer_step': optimizer_step, | ||||
|     'on_before_zero_grad': on_before_zero_grad, | ||||
|     'on_train_batch_end': on_train_batch_end, | ||||
|     'on_train_epoch_end': on_train_epoch_end, | ||||
|     'on_fit_epoch_end': on_fit_epoch_end,  # fit = train + val | ||||
|     'on_model_save': on_model_save, | ||||
|     'on_train_end': on_train_end, | ||||
|     'on_params_update': on_params_update, | ||||
|     'teardown': teardown, | ||||
|     'on_pretrain_routine_start': [on_pretrain_routine_start], | ||||
|     'on_pretrain_routine_end': [on_pretrain_routine_end], | ||||
|     'on_train_start': [on_train_start], | ||||
|     'on_train_epoch_start': [on_train_epoch_start], | ||||
|     'on_train_batch_start': [on_train_batch_start], | ||||
|     'optimizer_step': [optimizer_step], | ||||
|     'on_before_zero_grad': [on_before_zero_grad], | ||||
|     'on_train_batch_end': [on_train_batch_end], | ||||
|     'on_train_epoch_end': [on_train_epoch_end], | ||||
|     'on_fit_epoch_end': [on_fit_epoch_end],  # fit = train + val | ||||
|     'on_model_save': [on_model_save], | ||||
|     'on_train_end': [on_train_end], | ||||
|     'on_params_update': [on_params_update], | ||||
|     'teardown': [teardown], | ||||
|  | ||||
|     # Run in validator | ||||
|     'on_val_start': on_val_start, | ||||
|     'on_val_batch_start': on_val_batch_start, | ||||
|     'on_val_batch_end': on_val_batch_end, | ||||
|     'on_val_end': on_val_end, | ||||
|     'on_val_start': [on_val_start], | ||||
|     'on_val_batch_start': [on_val_batch_start], | ||||
|     'on_val_batch_end': [on_val_batch_end], | ||||
|     'on_val_end': [on_val_end], | ||||
|  | ||||
|     # Run in predictor | ||||
|     'on_predict_start': on_predict_start, | ||||
|     'on_predict_batch_start': on_predict_batch_start, | ||||
|     'on_predict_batch_end': on_predict_batch_end, | ||||
|     'on_predict_end': on_predict_end, | ||||
|     'on_predict_start': [on_predict_start], | ||||
|     'on_predict_batch_start': [on_predict_batch_start], | ||||
|     'on_predict_batch_end': [on_predict_batch_end], | ||||
|     'on_predict_end': [on_predict_end], | ||||
|  | ||||
|     # Run in exporter | ||||
|     'on_export_start': on_export_start, | ||||
|     'on_export_end': on_export_end} | ||||
|     'on_export_start': [on_export_start], | ||||
|     'on_export_end': [on_export_end]} | ||||
|  | ||||
|  | ||||
| def add_integration_callbacks(instance): | ||||
|  | ||||
| @ -307,18 +307,20 @@ def strip_optimizer(f='best.pt', s=''): | ||||
|     LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB") | ||||
|  | ||||
|  | ||||
| def guess_task_from_head(head): | ||||
|     task = None | ||||
|     if head.lower() in ["classify", "classifier", "cls", "fc"]: | ||||
|         task = "classify" | ||||
|     if head.lower() in ["detect"]: | ||||
|         task = "detect" | ||||
|     if head.lower() in ["segment"]: | ||||
|         task = "segment" | ||||
|  | ||||
|     if not task: | ||||
|         raise SyntaxError("task or model not recognized! Please refer the docs at : ")  # TODO: add docs links | ||||
|  | ||||
| def guess_task_from_model_yaml(model): | ||||
|     try: | ||||
|         cfg = model if isinstance(model, dict) else model.yaml  # model cfg dict | ||||
|         m = cfg["head"][-1][-2].lower()  # output module name | ||||
|         task = None | ||||
|         if m in ["classify", "classifier", "cls", "fc"]: | ||||
|             task = "classify" | ||||
|         if m in ["detect"]: | ||||
|             task = "detect" | ||||
|         if m in ["segment"]: | ||||
|             task = "segment" | ||||
|     except Exception as e: | ||||
|         raise SyntaxError('Unknown task. Define task explicitly, i.e. task=detect when running your command. ' | ||||
|                           'Valid tasks are detect, segment, classify.') from e | ||||
|     return task | ||||
|  | ||||
|  | ||||
| @ -374,14 +376,36 @@ def profile(input, ops, n=10, device=None): | ||||
|  | ||||
|  | ||||
| class EarlyStopping: | ||||
|     # early stopper | ||||
|     """ | ||||
|     Early stopping class that stops training when a specified number of epochs have passed without improvement. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, patience=30): | ||||
|         """ | ||||
|         Initialize early stopping object | ||||
|  | ||||
|         Args: | ||||
|             patience (int, optional): Number of epochs to wait after fitness stops improving before stopping. Default is 30. | ||||
|         """ | ||||
|         self.best_fitness = 0.0  # i.e. mAP | ||||
|         self.best_epoch = 0 | ||||
|         self.patience = patience or float('inf')  # epochs to wait after fitness stops improving to stop | ||||
|         self.possible_stop = False  # possible stop may occur next epoch | ||||
|  | ||||
|     def __call__(self, epoch, fitness): | ||||
|         """ | ||||
|         Check whether to stop training | ||||
|  | ||||
|         Args: | ||||
|             epoch (int): Current epoch of training | ||||
|             fitness (float): Fitness value of current epoch | ||||
|  | ||||
|         Returns: | ||||
|             bool: True if training should stop, False otherwise | ||||
|         """ | ||||
|         if fitness is None:  # check if fitness=None (happens when val=False) | ||||
|             return False | ||||
|  | ||||
|         if fitness >= self.best_fitness:  # >= 0 to allow for early zero-fitness stage of training | ||||
|             self.best_epoch = epoch | ||||
|             self.best_fitness = fitness | ||||
|  | ||||
| @ -10,6 +10,7 @@ class ClassificationValidator(BaseValidator): | ||||
|  | ||||
|     def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None): | ||||
|         super().__init__(dataloader, save_dir, pbar, logger, args) | ||||
|         self.args.task = 'classify' | ||||
|         self.metrics = ClassifyMetrics() | ||||
|  | ||||
|     def get_desc(self): | ||||
|  | ||||
| @ -20,6 +20,7 @@ class DetectionValidator(BaseValidator): | ||||
|  | ||||
|     def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None): | ||||
|         super().__init__(dataloader, save_dir, pbar, logger, args) | ||||
|         self.args.task = 'detect' | ||||
|         self.data_dict = yaml_load(check_file(self.args.data), append_filename=True) if self.args.data else None | ||||
|         self.is_coco = False | ||||
|         self.class_map = None | ||||
|  | ||||
| @ -87,7 +87,7 @@ class SegmentationPredictor(DetectionPredictor): | ||||
|                 c = int(cls)  # integer class | ||||
|                 label = None if self.args.hide_labels else ( | ||||
|                     self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}') | ||||
|                 self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) | ||||
|                 self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.boxes else None | ||||
|             if self.args.save_crop: | ||||
|                 imc = im0.copy() | ||||
|                 save_one_box(d.xyxy, | ||||
|  | ||||
| @ -19,7 +19,7 @@ class SegmentationValidator(DetectionValidator): | ||||
|  | ||||
|     def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None): | ||||
|         super().__init__(dataloader, save_dir, pbar, logger, args) | ||||
|         self.args.task = "segment" | ||||
|         self.args.task = 'segment' | ||||
|         self.metrics = SegmentMetrics(save_dir=self.save_dir) | ||||
|  | ||||
|     def preprocess(self, batch): | ||||
|  | ||||
		Reference in New Issue
	
	Block a user