ultralytics 8.0.18
new python callbacks and minor fixes (#580)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jeroen Rombouts <36196499+jarombouts@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -108,7 +108,7 @@ yolo predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg"
|
|||||||
#### Python
|
#### Python
|
||||||
|
|
||||||
YOLOv8 may also be used directly in a Python environment, and accepts the
|
YOLOv8 may also be used directly in a Python environment, and accepts the
|
||||||
same [arguments](https://docs.ultralytics.com/config/) as in the CLI example above:
|
same [arguments](https://docs.ultralytics.com/cfg/) as in the CLI example above:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
@ -70,7 +70,7 @@ YOLOv8 可以直接在命令行界面(CLI)中使用 `yolo` 命令运行:
|
|||||||
yolo predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg"
|
yolo predict model=yolov8n.pt source="https://ultralytics.com/images/bus.jpg"
|
||||||
```
|
```
|
||||||
|
|
||||||
`yolo`可以用于各种任务和模式,并接受额外的参数,例如 `imgsz=640`。参见 YOLOv8 [文档](https://docs.ultralytics.com)中可用`yolo`[参数](https://docs.ultralytics.com/config/)的完整列表。
|
`yolo`可以用于各种任务和模式,并接受额外的参数,例如 `imgsz=640`。参见 YOLOv8 [文档](https://docs.ultralytics.com)中可用`yolo`[参数](https://docs.ultralytics.com/cfg/)的完整列表。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
yolo task=detect mode=train model=yolov8n.pt args...
|
yolo task=detect mode=train model=yolov8n.pt args...
|
||||||
@ -79,7 +79,7 @@ yolo task=detect mode=train model=yolov8n.pt args...
|
|||||||
export yolov8n.pt format=onnx args...
|
export yolov8n.pt format=onnx args...
|
||||||
```
|
```
|
||||||
|
|
||||||
YOLOv8 也可以在 Python 环境中直接使用,并接受与上面 CLI 例子中相同的[参数](https://docs.ultralytics.com/config/):
|
YOLOv8 也可以在 Python 环境中直接使用,并接受与上面 CLI 例子中相同的[参数](https://docs.ultralytics.com/cfg/):
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
@ -167,7 +167,7 @@ Default arguments can be overriden by simply passing them as arguments in the CL
|
|||||||
=== "Example 2"
|
=== "Example 2"
|
||||||
Predict a YouTube video using a pretrained segmentation model at image size 320:
|
Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||||
```bash
|
```bash
|
||||||
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
|
yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "Example 3"
|
=== "Example 3"
|
||||||
|
@ -101,7 +101,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# 1. Predict\n",
|
"# 1. Predict\n",
|
||||||
"\n",
|
"\n",
|
||||||
"YOLOv8 may be used directly in the Command Line Interface (CLI) with a `yolo` command for a variety of tasks and modes and accepts additional arguments, i.e. `imgsz=640`. See a full list of available `yolo` [arguments](https://docs.ultralytics.com/config/) in the YOLOv8 [Docs](https://docs.ultralytics.com).\n"
|
"YOLOv8 may be used directly in the Command Line Interface (CLI) with a `yolo` command for a variety of tasks and modes and accepts additional arguments, i.e. `imgsz=640`. See a full list of available `yolo` [arguments](https://docs.ultralytics.com/cfg/) in the YOLOv8 [Docs](https://docs.ultralytics.com).\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -127,7 +127,3 @@ def test_workflow():
|
|||||||
model.val()
|
model.val()
|
||||||
model.predict(SOURCE)
|
model.predict(SOURCE)
|
||||||
model.export(format="onnx", opset=12) # export a model to ONNX format
|
model.export(format="onnx", opset=12) # export a model to ONNX format
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_predict_img()
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.0.17"
|
__version__ = "8.0.18"
|
||||||
|
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.engine.model import YOLO
|
||||||
from ultralytics.yolo.utils import ops
|
from ultralytics.yolo.utils import ops
|
||||||
|
@ -24,7 +24,7 @@ yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100
|
|||||||
```
|
```
|
||||||
|
|
||||||
They may also be used directly in a Python environment, and accepts the same
|
They may also be used directly in a Python environment, and accepts the same
|
||||||
[arguments](https://docs.ultralytics.com/config/) as in the CLI example above:
|
[arguments](https://docs.ultralytics.com/cfg/) as in the CLI example above:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
@ -222,7 +222,8 @@ class AutoBackend(nn.Module):
|
|||||||
nhwc = model.runtime.startswith("tensorflow")
|
nhwc = model.runtime.startswith("tensorflow")
|
||||||
'''
|
'''
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'ERROR: {w} is not a supported format')
|
raise NotImplementedError(f"ERROR: '{w}' is not a supported format. For supported formats see "
|
||||||
|
f"https://docs.ultralytics.com/reference/nn/")
|
||||||
|
|
||||||
# class names
|
# class names
|
||||||
if 'names' not in locals():
|
if 'names' not in locals():
|
||||||
|
@ -28,7 +28,7 @@ CLI_HELP_MSG = \
|
|||||||
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||||
|
|
||||||
2. Predict a YouTube video using a pretrained segmentation model at image size 320:
|
2. Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||||
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
|
yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
|
||||||
|
|
||||||
3. Val a pretrained detection model at batch-size 1 and image size 640:
|
3. Val a pretrained detection model at batch-size 1 and image size 640:
|
||||||
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||||
@ -126,13 +126,13 @@ def merge_equals_args(args: List[str]) -> List[str]:
|
|||||||
"""
|
"""
|
||||||
new_args = []
|
new_args = []
|
||||||
for i, arg in enumerate(args):
|
for i, arg in enumerate(args):
|
||||||
if arg == '=' and 0 < i < len(args) - 1:
|
if arg == '=' and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
|
||||||
new_args[-1] += f"={args[i + 1]}"
|
new_args[-1] += f"={args[i + 1]}"
|
||||||
del args[i + 1]
|
del args[i + 1]
|
||||||
elif arg.endswith('=') and i < len(args) - 1:
|
elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]: # merge ['arg=', 'val']
|
||||||
new_args.append(f"{arg}{args[i + 1]}")
|
new_args.append(f"{arg}{args[i + 1]}")
|
||||||
del args[i + 1]
|
del args[i + 1]
|
||||||
elif arg.startswith('=') and i > 0:
|
elif arg.startswith('=') and i > 0: # merge ['arg', '=val']
|
||||||
new_args[-1] += arg
|
new_args[-1] += arg
|
||||||
else:
|
else:
|
||||||
new_args.append(arg)
|
new_args.append(arg)
|
||||||
@ -178,7 +178,7 @@ def entrypoint(debug=False):
|
|||||||
if '=' in a:
|
if '=' in a:
|
||||||
try:
|
try:
|
||||||
re.sub(r' *= *', '=', a) # remove spaces around equals sign
|
re.sub(r' *= *', '=', a) # remove spaces around equals sign
|
||||||
k, v = a.split('=')
|
k, v = a.split('=', 1) # split on first '=' sign
|
||||||
if k == 'cfg': # custom.yaml passed
|
if k == 'cfg': # custom.yaml passed
|
||||||
LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {v}")
|
LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {v}")
|
||||||
overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
|
overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
|
||||||
|
@ -59,8 +59,9 @@ line_thickness: 3 # bounding box thickness (pixels)
|
|||||||
visualize: False # visualize model features
|
visualize: False # visualize model features
|
||||||
augment: False # apply image augmentation to prediction sources
|
augment: False # apply image augmentation to prediction sources
|
||||||
agnostic_nms: False # class-agnostic NMS
|
agnostic_nms: False # class-agnostic NMS
|
||||||
retina_masks: False # use high-resolution segmentation masks
|
|
||||||
classes: null # filter results by class, i.e. class=0, or class=[0,2,3]
|
classes: null # filter results by class, i.e. class=0, or class=[0,2,3]
|
||||||
|
retina_masks: False # use high-resolution segmentation masks
|
||||||
|
boxes: True # Show boxes in segmentation predictions
|
||||||
|
|
||||||
# Export settings ------------------------------------------------------------------------------------------------------
|
# Export settings ------------------------------------------------------------------------------------------------------
|
||||||
format: torchscript # format to export to
|
format: torchscript # format to export to
|
||||||
|
@ -28,7 +28,7 @@ from PIL import ExifTags, Image, ImageOps
|
|||||||
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
|
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ultralytics.yolo.data.utils import check_dataset, unzip_file
|
from ultralytics.yolo.data.utils import check_det_dataset, unzip_file
|
||||||
from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
|
from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
|
||||||
is_kaggle)
|
is_kaggle)
|
||||||
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
||||||
@ -1061,7 +1061,7 @@ class HUBDatasetStats():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception("error/HUB/dataset_stats/yaml_load") from e
|
raise Exception("error/HUB/dataset_stats/yaml_load") from e
|
||||||
|
|
||||||
check_dataset(data, autodownload) # download dataset if missing
|
check_det_dataset(data, autodownload) # download dataset if missing
|
||||||
self.hub_dir = Path(data['path'] + '-hub')
|
self.hub_dir = Path(data['path'] + '-hub')
|
||||||
self.im_dir = self.hub_dir / 'images'
|
self.im_dir = self.hub_dir / 'images'
|
||||||
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
|
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
|
||||||
|
@ -185,7 +185,7 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
|
|||||||
return masks, index
|
return masks, index
|
||||||
|
|
||||||
|
|
||||||
def check_dataset_yaml(dataset, autodownload=True):
|
def check_det_dataset(dataset, autodownload=True):
|
||||||
# Download, check and/or unzip dataset if not found locally
|
# Download, check and/or unzip dataset if not found locally
|
||||||
data = check_file(dataset)
|
data = check_file(dataset)
|
||||||
|
|
||||||
@ -254,7 +254,7 @@ def check_dataset_yaml(dataset, autodownload=True):
|
|||||||
return data # dictionary
|
return data # dictionary
|
||||||
|
|
||||||
|
|
||||||
def check_dataset(dataset: str):
|
def check_cls_dataset(dataset: str):
|
||||||
"""
|
"""
|
||||||
Check a classification dataset such as Imagenet.
|
Check a classification dataset such as Imagenet.
|
||||||
|
|
||||||
|
@ -69,31 +69,25 @@ from ultralytics.nn.modules import Detect, Segment
|
|||||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel
|
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel
|
||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
||||||
from ultralytics.yolo.data.utils import check_dataset
|
from ultralytics.yolo.data.utils import check_det_dataset
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
|
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
|
||||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
||||||
from ultralytics.yolo.utils.files import file_size
|
from ultralytics.yolo.utils.files import file_size
|
||||||
from ultralytics.yolo.utils.ops import Profile
|
from ultralytics.yolo.utils.ops import Profile
|
||||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, select_device, smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, select_device, smart_inference_mode
|
||||||
|
|
||||||
MACOS = platform.system() == 'Darwin' # macOS environment
|
MACOS = platform.system() == 'Darwin' # macOS environment
|
||||||
|
|
||||||
|
|
||||||
def export_formats():
|
def export_formats():
|
||||||
# YOLOv8 export formats
|
# YOLOv8 export formats
|
||||||
x = [
|
x = [['PyTorch', '-', '.pt', True, True], ['TorchScript', 'torchscript', '.torchscript', True, True],
|
||||||
['PyTorch', '-', '.pt', True, True],
|
['ONNX', 'onnx', '.onnx', True, True], ['OpenVINO', 'openvino', '_openvino_model', True, False],
|
||||||
['TorchScript', 'torchscript', '.torchscript', True, True],
|
['TensorRT', 'engine', '.engine', False, True], ['CoreML', 'coreml', '.mlmodel', True, False],
|
||||||
['ONNX', 'onnx', '.onnx', True, True],
|
|
||||||
['OpenVINO', 'openvino', '_openvino_model', True, False],
|
|
||||||
['TensorRT', 'engine', '.engine', False, True],
|
|
||||||
['CoreML', 'coreml', '.mlmodel', True, False],
|
|
||||||
['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
|
['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
|
||||||
['TensorFlow GraphDef', 'pb', '.pb', True, True],
|
['TensorFlow GraphDef', 'pb', '.pb', True, True], ['TensorFlow Lite', 'tflite', '.tflite', True, False],
|
||||||
['TensorFlow Lite', 'tflite', '.tflite', True, False],
|
|
||||||
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
|
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
|
||||||
['TensorFlow.js', 'tfjs', '_web_model', False, False],
|
['TensorFlow.js', 'tfjs', '_web_model', False, False], ['PaddlePaddle', 'paddle', '_paddle_model', True, True]]
|
||||||
['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
|
|
||||||
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
||||||
|
|
||||||
|
|
||||||
@ -135,7 +129,7 @@ class Exporter:
|
|||||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.args = get_cfg(cfg, overrides)
|
self.args = get_cfg(cfg, overrides)
|
||||||
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
@ -241,7 +235,7 @@ class Exporter:
|
|||||||
# Finish
|
# Finish
|
||||||
f = [str(x) for x in f if x] # filter out '' and None
|
f = [str(x) for x in f if x] # filter out '' and None
|
||||||
if any(f):
|
if any(f):
|
||||||
task = guess_task_from_head(model.yaml["head"][-1][-2])
|
task = guess_task_from_model_yaml(model)
|
||||||
s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
|
s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
|
||||||
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
|
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
|
||||||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
||||||
@ -570,7 +564,7 @@ class Exporter:
|
|||||||
if n >= n_images:
|
if n >= n_images:
|
||||||
break
|
break
|
||||||
|
|
||||||
dataset = LoadImages(check_dataset(check_yaml(data))['train'], imgsz=imgsz, auto=False)
|
dataset = LoadImages(check_det_dataset(check_yaml(data))['train'], imgsz=imgsz, auto=False)
|
||||||
converter.representative_dataset = lambda: representative_dataset_gen(dataset, n_images=100)
|
converter.representative_dataset = lambda: representative_dataset_gen(dataset, n_images=100)
|
||||||
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
||||||
converter.target_spec.supported_types = []
|
converter.target_spec.supported_types = []
|
||||||
|
@ -6,9 +6,9 @@ from ultralytics import yolo # noqa
|
|||||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
|
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
|
||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.engine.exporter import Exporter
|
from ultralytics.yolo.engine.exporter import Exporter
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_yaml
|
from ultralytics.yolo.utils.checks import check_yaml
|
||||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, smart_inference_mode
|
||||||
|
|
||||||
# Map head to model, trainer, validator, and predictor classes
|
# Map head to model, trainer, validator, and predictor classes
|
||||||
MODEL_MAP = {
|
MODEL_MAP = {
|
||||||
@ -68,7 +68,7 @@ class YOLO:
|
|||||||
"""
|
"""
|
||||||
cfg = check_yaml(cfg) # check YAML
|
cfg = check_yaml(cfg) # check YAML
|
||||||
cfg_dict = yaml_load(cfg, append_filename=True) # model dict
|
cfg_dict = yaml_load(cfg, append_filename=True) # model dict
|
||||||
self.task = guess_task_from_head(cfg_dict["head"][-1][-2])
|
self.task = guess_task_from_model_yaml(cfg_dict)
|
||||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
||||||
self._guess_ops_from_task(self.task)
|
self._guess_ops_from_task(self.task)
|
||||||
self.model = self.ModelClass(cfg_dict, verbose=verbose) # initialize
|
self.model = self.ModelClass(cfg_dict, verbose=verbose) # initialize
|
||||||
@ -228,6 +228,12 @@ class YOLO:
|
|||||||
"""
|
"""
|
||||||
return self.model.names
|
return self.model.names
|
||||||
|
|
||||||
|
def add_callback(self, event: str, func):
|
||||||
|
"""
|
||||||
|
Add callback
|
||||||
|
"""
|
||||||
|
callbacks.default_callbacks[event].append(func)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reset_ckpt_args(args):
|
def _reset_ckpt_args(args):
|
||||||
args.pop("project", None)
|
args.pop("project", None)
|
||||||
|
@ -88,7 +88,7 @@ class BasePredictor:
|
|||||||
self.vid_path, self.vid_writer = None, None
|
self.vid_path, self.vid_writer = None, None
|
||||||
self.annotator = None
|
self.annotator = None
|
||||||
self.data_path = None
|
self.data_path = None
|
||||||
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
def preprocess(self, img):
|
def preprocess(self, img):
|
||||||
@ -172,16 +172,17 @@ class BasePredictor:
|
|||||||
# setup source. Run every time predict is called
|
# setup source. Run every time predict is called
|
||||||
self.setup_source(source)
|
self.setup_source(source)
|
||||||
# check if save_dir/ label file exists
|
# check if save_dir/ label file exists
|
||||||
if self.args.save:
|
if self.args.save or self.args.save_txt:
|
||||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||||
# warmup model
|
# warmup model
|
||||||
if not self.done_warmup:
|
if not self.done_warmup:
|
||||||
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.bs, 3, *self.imgsz))
|
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.bs, 3, *self.imgsz))
|
||||||
self.done_warmup = True
|
self.done_warmup = True
|
||||||
|
|
||||||
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
|
self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
|
||||||
for batch in self.dataset:
|
for batch in self.dataset:
|
||||||
self.run_callbacks("on_predict_batch_start")
|
self.run_callbacks("on_predict_batch_start")
|
||||||
|
self.batch = batch
|
||||||
path, im, im0s, vid_cap, s = batch
|
path, im, im0s, vid_cap, s = batch
|
||||||
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
|
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
|
||||||
with self.dt[0]:
|
with self.dt[0]:
|
||||||
@ -195,13 +196,13 @@ class BasePredictor:
|
|||||||
|
|
||||||
# postprocess
|
# postprocess
|
||||||
with self.dt[2]:
|
with self.dt[2]:
|
||||||
results = self.postprocess(preds, im, im0s, self.classes)
|
self.results = self.postprocess(preds, im, im0s, self.classes)
|
||||||
for i in range(len(im)):
|
for i in range(len(im)):
|
||||||
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
|
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
|
||||||
p = Path(p)
|
p = Path(p)
|
||||||
|
|
||||||
if verbose or self.args.save or self.args.save_txt or self.args.show:
|
if verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||||
s += self.write_results(i, results, (p, im, im0))
|
s += self.write_results(i, self.results, (p, im, im0))
|
||||||
|
|
||||||
if self.args.show:
|
if self.args.show:
|
||||||
self.show(p)
|
self.show(p)
|
||||||
@ -209,22 +210,21 @@ class BasePredictor:
|
|||||||
if self.args.save:
|
if self.args.save:
|
||||||
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
||||||
|
|
||||||
yield from results
|
self.run_callbacks("on_predict_batch_end")
|
||||||
|
yield from self.results
|
||||||
|
|
||||||
# Print time (inference-only)
|
# Print time (inference-only)
|
||||||
if verbose:
|
if verbose:
|
||||||
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
|
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
|
||||||
|
|
||||||
self.run_callbacks("on_predict_batch_end")
|
|
||||||
|
|
||||||
# Print results
|
# Print results
|
||||||
if verbose and self.seen:
|
if verbose and self.seen:
|
||||||
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
|
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
|
||||||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape '
|
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape '
|
||||||
f'{(1, 3, *self.imgsz)}' % t)
|
f'{(1, 3, *self.imgsz)}' % t)
|
||||||
if self.args.save_txt or self.args.save:
|
if self.args.save_txt or self.args.save:
|
||||||
s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" \
|
nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels
|
||||||
if self.args.save_txt else ''
|
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
|
||||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
||||||
|
|
||||||
self.run_callbacks("on_predict_end")
|
self.run_callbacks("on_predict_end")
|
||||||
|
@ -20,19 +20,18 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
|||||||
from torch.optim import lr_scheduler
|
from torch.optim import lr_scheduler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import ultralytics.yolo.utils as utils
|
|
||||||
from ultralytics import __version__
|
from ultralytics import __version__
|
||||||
from ultralytics.nn.tasks import attempt_load_one_weight
|
from ultralytics.nn.tasks import attempt_load_one_weight
|
||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||||
from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
||||||
yaml_save)
|
emojis, yaml_save)
|
||||||
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
||||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
|
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
|
||||||
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
||||||
from ultralytics.yolo.utils.files import get_latest_run, increment_path
|
from ultralytics.yolo.utils.files import get_latest_run, increment_path
|
||||||
from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle,
|
from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle,
|
||||||
strip_optimizer)
|
select_device, strip_optimizer)
|
||||||
|
|
||||||
|
|
||||||
class BaseTrainer:
|
class BaseTrainer:
|
||||||
@ -81,7 +80,7 @@ class BaseTrainer:
|
|||||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.args = get_cfg(cfg, overrides)
|
self.args = get_cfg(cfg, overrides)
|
||||||
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch)
|
self.device = select_device(self.args.device, self.args.batch)
|
||||||
self.check_resume()
|
self.check_resume()
|
||||||
self.console = LOGGER
|
self.console = LOGGER
|
||||||
self.validator = None
|
self.validator = None
|
||||||
@ -120,9 +119,11 @@ class BaseTrainer:
|
|||||||
self.model = self.args.model
|
self.model = self.args.model
|
||||||
self.data = self.args.data
|
self.data = self.args.data
|
||||||
if self.data.endswith(".yaml"):
|
if self.data.endswith(".yaml"):
|
||||||
self.data = check_dataset_yaml(self.data)
|
self.data = check_det_dataset(self.data)
|
||||||
|
elif self.args.task == 'classify':
|
||||||
|
self.data = check_cls_dataset(self.data)
|
||||||
else:
|
else:
|
||||||
self.data = check_dataset(self.data)
|
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌"))
|
||||||
self.trainset, self.testset = self.get_dataset(self.data)
|
self.trainset, self.testset = self.get_dataset(self.data)
|
||||||
self.ema = None
|
self.ema = None
|
||||||
|
|
||||||
@ -140,7 +141,7 @@ class BaseTrainer:
|
|||||||
self.plot_idx = [0, 1, 2]
|
self.plot_idx = [0, 1, 2]
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
||||||
if RANK in {0, -1}:
|
if RANK in {0, -1}:
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
|
@ -9,8 +9,8 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from ultralytics.nn.autobackend import AutoBackend
|
from ultralytics.nn.autobackend import AutoBackend
|
||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks
|
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis
|
||||||
from ultralytics.yolo.utils.checks import check_imgsz
|
from ultralytics.yolo.utils.checks import check_imgsz
|
||||||
from ultralytics.yolo.utils.files import increment_path
|
from ultralytics.yolo.utils.files import increment_path
|
||||||
from ultralytics.yolo.utils.ops import Profile
|
from ultralytics.yolo.utils.ops import Profile
|
||||||
@ -70,7 +70,7 @@ class BaseValidator:
|
|||||||
if self.args.conf is None:
|
if self.args.conf is None:
|
||||||
self.args.conf = 0.001 # default conf=0.001
|
self.args.conf = 0.001 # default conf=0.001
|
||||||
|
|
||||||
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def __call__(self, trainer=None, model=None):
|
def __call__(self, trainer=None, model=None):
|
||||||
@ -109,9 +109,11 @@ class BaseValidator:
|
|||||||
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||||
|
|
||||||
if isinstance(self.args.data, str) and self.args.data.endswith(".yaml"):
|
if isinstance(self.args.data, str) and self.args.data.endswith(".yaml"):
|
||||||
self.data = check_dataset_yaml(self.args.data)
|
self.data = check_det_dataset(self.args.data)
|
||||||
|
elif self.args.task == 'classify':
|
||||||
|
self.data = check_cls_dataset(self.args.data)
|
||||||
else:
|
else:
|
||||||
self.data = check_dataset(self.args.data)
|
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌"))
|
||||||
|
|
||||||
if self.device.type == 'cpu':
|
if self.device.type == 'cpu':
|
||||||
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
||||||
|
@ -68,7 +68,7 @@ HELP_MSG = \
|
|||||||
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||||
|
|
||||||
- Predict a YouTube video using a pretrained segmentation model at image size 320:
|
- Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||||
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
|
yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
|
||||||
|
|
||||||
- Val a pretrained detection model at batch-size 1 and image size 640:
|
- Val a pretrained detection model at batch-size 1 and image size 640:
|
||||||
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||||
@ -109,6 +109,9 @@ class IterableSimpleNamespace(SimpleNamespace):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return '\n'.join(f"{k}={v}" for k, v in vars(self).items())
|
return '\n'.join(f"{k}={v}" for k, v in vars(self).items())
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
return getattr(self, key, default)
|
||||||
|
|
||||||
|
|
||||||
# Default configuration
|
# Default configuration
|
||||||
with open(DEFAULT_CFG_PATH, errors='ignore') as f:
|
with open(DEFAULT_CFG_PATH, errors='ignore') as f:
|
||||||
|
@ -106,36 +106,36 @@ def on_export_end(exporter):
|
|||||||
|
|
||||||
default_callbacks = {
|
default_callbacks = {
|
||||||
# Run in trainer
|
# Run in trainer
|
||||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
'on_pretrain_routine_start': [on_pretrain_routine_start],
|
||||||
'on_pretrain_routine_end': on_pretrain_routine_end,
|
'on_pretrain_routine_end': [on_pretrain_routine_end],
|
||||||
'on_train_start': on_train_start,
|
'on_train_start': [on_train_start],
|
||||||
'on_train_epoch_start': on_train_epoch_start,
|
'on_train_epoch_start': [on_train_epoch_start],
|
||||||
'on_train_batch_start': on_train_batch_start,
|
'on_train_batch_start': [on_train_batch_start],
|
||||||
'optimizer_step': optimizer_step,
|
'optimizer_step': [optimizer_step],
|
||||||
'on_before_zero_grad': on_before_zero_grad,
|
'on_before_zero_grad': [on_before_zero_grad],
|
||||||
'on_train_batch_end': on_train_batch_end,
|
'on_train_batch_end': [on_train_batch_end],
|
||||||
'on_train_epoch_end': on_train_epoch_end,
|
'on_train_epoch_end': [on_train_epoch_end],
|
||||||
'on_fit_epoch_end': on_fit_epoch_end, # fit = train + val
|
'on_fit_epoch_end': [on_fit_epoch_end], # fit = train + val
|
||||||
'on_model_save': on_model_save,
|
'on_model_save': [on_model_save],
|
||||||
'on_train_end': on_train_end,
|
'on_train_end': [on_train_end],
|
||||||
'on_params_update': on_params_update,
|
'on_params_update': [on_params_update],
|
||||||
'teardown': teardown,
|
'teardown': [teardown],
|
||||||
|
|
||||||
# Run in validator
|
# Run in validator
|
||||||
'on_val_start': on_val_start,
|
'on_val_start': [on_val_start],
|
||||||
'on_val_batch_start': on_val_batch_start,
|
'on_val_batch_start': [on_val_batch_start],
|
||||||
'on_val_batch_end': on_val_batch_end,
|
'on_val_batch_end': [on_val_batch_end],
|
||||||
'on_val_end': on_val_end,
|
'on_val_end': [on_val_end],
|
||||||
|
|
||||||
# Run in predictor
|
# Run in predictor
|
||||||
'on_predict_start': on_predict_start,
|
'on_predict_start': [on_predict_start],
|
||||||
'on_predict_batch_start': on_predict_batch_start,
|
'on_predict_batch_start': [on_predict_batch_start],
|
||||||
'on_predict_batch_end': on_predict_batch_end,
|
'on_predict_batch_end': [on_predict_batch_end],
|
||||||
'on_predict_end': on_predict_end,
|
'on_predict_end': [on_predict_end],
|
||||||
|
|
||||||
# Run in exporter
|
# Run in exporter
|
||||||
'on_export_start': on_export_start,
|
'on_export_start': [on_export_start],
|
||||||
'on_export_end': on_export_end}
|
'on_export_end': [on_export_end]}
|
||||||
|
|
||||||
|
|
||||||
def add_integration_callbacks(instance):
|
def add_integration_callbacks(instance):
|
||||||
|
@ -307,18 +307,20 @@ def strip_optimizer(f='best.pt', s=''):
|
|||||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||||
|
|
||||||
|
|
||||||
def guess_task_from_head(head):
|
def guess_task_from_model_yaml(model):
|
||||||
|
try:
|
||||||
|
cfg = model if isinstance(model, dict) else model.yaml # model cfg dict
|
||||||
|
m = cfg["head"][-1][-2].lower() # output module name
|
||||||
task = None
|
task = None
|
||||||
if head.lower() in ["classify", "classifier", "cls", "fc"]:
|
if m in ["classify", "classifier", "cls", "fc"]:
|
||||||
task = "classify"
|
task = "classify"
|
||||||
if head.lower() in ["detect"]:
|
if m in ["detect"]:
|
||||||
task = "detect"
|
task = "detect"
|
||||||
if head.lower() in ["segment"]:
|
if m in ["segment"]:
|
||||||
task = "segment"
|
task = "segment"
|
||||||
|
except Exception as e:
|
||||||
if not task:
|
raise SyntaxError('Unknown task. Define task explicitly, i.e. task=detect when running your command. '
|
||||||
raise SyntaxError("task or model not recognized! Please refer the docs at : ") # TODO: add docs links
|
'Valid tasks are detect, segment, classify.') from e
|
||||||
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
|
||||||
@ -374,14 +376,36 @@ def profile(input, ops, n=10, device=None):
|
|||||||
|
|
||||||
|
|
||||||
class EarlyStopping:
|
class EarlyStopping:
|
||||||
# early stopper
|
"""
|
||||||
|
Early stopping class that stops training when a specified number of epochs have passed without improvement.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, patience=30):
|
def __init__(self, patience=30):
|
||||||
|
"""
|
||||||
|
Initialize early stopping object
|
||||||
|
|
||||||
|
Args:
|
||||||
|
patience (int, optional): Number of epochs to wait after fitness stops improving before stopping. Default is 30.
|
||||||
|
"""
|
||||||
self.best_fitness = 0.0 # i.e. mAP
|
self.best_fitness = 0.0 # i.e. mAP
|
||||||
self.best_epoch = 0
|
self.best_epoch = 0
|
||||||
self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
|
self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
|
||||||
self.possible_stop = False # possible stop may occur next epoch
|
self.possible_stop = False # possible stop may occur next epoch
|
||||||
|
|
||||||
def __call__(self, epoch, fitness):
|
def __call__(self, epoch, fitness):
|
||||||
|
"""
|
||||||
|
Check whether to stop training
|
||||||
|
|
||||||
|
Args:
|
||||||
|
epoch (int): Current epoch of training
|
||||||
|
fitness (float): Fitness value of current epoch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if training should stop, False otherwise
|
||||||
|
"""
|
||||||
|
if fitness is None: # check if fitness=None (happens when val=False)
|
||||||
|
return False
|
||||||
|
|
||||||
if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
|
if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
|
||||||
self.best_epoch = epoch
|
self.best_epoch = epoch
|
||||||
self.best_fitness = fitness
|
self.best_fitness = fitness
|
||||||
|
@ -10,6 +10,7 @@ class ClassificationValidator(BaseValidator):
|
|||||||
|
|
||||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
||||||
super().__init__(dataloader, save_dir, pbar, logger, args)
|
super().__init__(dataloader, save_dir, pbar, logger, args)
|
||||||
|
self.args.task = 'classify'
|
||||||
self.metrics = ClassifyMetrics()
|
self.metrics = ClassifyMetrics()
|
||||||
|
|
||||||
def get_desc(self):
|
def get_desc(self):
|
||||||
|
@ -20,6 +20,7 @@ class DetectionValidator(BaseValidator):
|
|||||||
|
|
||||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
||||||
super().__init__(dataloader, save_dir, pbar, logger, args)
|
super().__init__(dataloader, save_dir, pbar, logger, args)
|
||||||
|
self.args.task = 'detect'
|
||||||
self.data_dict = yaml_load(check_file(self.args.data), append_filename=True) if self.args.data else None
|
self.data_dict = yaml_load(check_file(self.args.data), append_filename=True) if self.args.data else None
|
||||||
self.is_coco = False
|
self.is_coco = False
|
||||||
self.class_map = None
|
self.class_map = None
|
||||||
|
@ -87,7 +87,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|||||||
c = int(cls) # integer class
|
c = int(cls) # integer class
|
||||||
label = None if self.args.hide_labels else (
|
label = None if self.args.hide_labels else (
|
||||||
self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
|
self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
|
||||||
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.boxes else None
|
||||||
if self.args.save_crop:
|
if self.args.save_crop:
|
||||||
imc = im0.copy()
|
imc = im0.copy()
|
||||||
save_one_box(d.xyxy,
|
save_one_box(d.xyxy,
|
||||||
|
@ -19,7 +19,7 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
|
|
||||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
||||||
super().__init__(dataloader, save_dir, pbar, logger, args)
|
super().__init__(dataloader, save_dir, pbar, logger, args)
|
||||||
self.args.task = "segment"
|
self.args.task = 'segment'
|
||||||
self.metrics = SegmentMetrics(save_dir=self.save_dir)
|
self.metrics = SegmentMetrics(save_dir=self.save_dir)
|
||||||
|
|
||||||
def preprocess(self, batch):
|
def preprocess(self, batch):
|
||||||
|
Reference in New Issue
Block a user