`ultralytics 8.0.50` AMP check and YOLOv5u YAMLs (#1263)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Troy <wudashuo@vip.qq.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Huijae Lee <46982469+ZeroAct@users.noreply.github.com>
Glenn Jocher committed f0d8e4718b (parent 3861e6c82a)

@@ -55,7 +55,7 @@ include train, val, and predict.
| mode   | 'train' | YOLO mode, i.e. train, val, predict, or export                                                 |
| resume | False   | resume training from last checkpoint or custom checkpoint if passed as resume=path/to/best.pt |
| model  | null    | path to model file, i.e. yolov8n.pt, yolov8n.yaml                                              |
-| data   | null    | path to data file, i.e. i.e. coco128.yaml                                                      |
+| data   | null    | path to data file, i.e. coco128.yaml                                                           |

### Training
@@ -69,7 +69,7 @@ task.
| Key      | Value | Description                                                                  |
|----------|-------|------------------------------------------------------------------------------|
| model    | null  | path to model file, i.e. yolov8n.pt, yolov8n.yaml                            |
-| data     | null  | path to data file, i.e. i.e. coco128.yaml                                    |
+| data     | null  | path to data file, i.e. coco128.yaml                                         |
| epochs   | 100   | number of epochs to train for                                                |
| patience | 50    | epochs to wait for no observable improvement for early stopping of training |
| batch    | 16    | number of images per batch (-1 for AutoBatch)                                |
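For reference, these keys map one-to-one onto Python API keyword arguments; a minimal training sketch using the documented values (assumes `coco128.yaml` auto-downloads on first use):

```python
from ultralytics import YOLO

# Train with the keys documented above; unspecified keys keep their defaults
model = YOLO('yolov8n.pt')
model.train(data='coco128.yaml', epochs=100, patience=50, batch=16)
```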

@@ -47,7 +47,7 @@ source can be used as a stream and the model argument required for that source.
| CSV       |         | `'sources.csv'`                  | `str`, `Path` | RTSP, RTMP, HTTP |
| video     | &check; | `'vid.mp4'`                      | `str`, `Path` |                  |
| directory | &check; | `'path/'`                        | `str`, `Path` |                  |
-| glob      | &check; | `path/*.jpg'`                    | `str`         | Use `*` operator |
+| glob      | &check; | `'path/*.jpg'`                   | `str`         | Use `*` operator |
| YouTube   | &check; | `'https://youtu.be/Zgi9g1ksQHc'` | `str`         |                  |
| stream    | &check; | `'rtsp://example.com/media.mp4'` | `str`         | RTSP, RTMP, HTTP |
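As a usage sketch, every source type in this table passes through the same `model(...)` call (paths below are placeholders):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
model('vid.mp4')                                    # video file
model('path/*.jpg')                                 # glob pattern
model('https://youtu.be/Zgi9g1ksQHc')               # YouTube URL
model('rtsp://example.com/media.mp4', stream=True)  # RTSP stream, generator output
```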

@@ -49,6 +49,8 @@ def test_predict_dir():

def test_predict_img():
    model = YOLO(MODEL)
+    seg_model = YOLO('yolov8n-seg.pt')
+    cls_model = YOLO('yolov8n-cls.pt')
    im = cv2.imread(str(SOURCE))
    assert len(model(source=Image.open(SOURCE), save=True, verbose=True)) == 1  # PIL
    assert len(model(source=im, save=True, save_txt=True)) == 1  # ndarray
@@ -64,6 +66,18 @@ def test_predict_img():
             np.zeros((320, 640, 3))]  # numpy
    assert len(model(batch)) == len(batch)  # multiple sources in a batch

+    # Test tensor inference
+    im = cv2.imread(str(SOURCE))  # OpenCV
+    t = cv2.resize(im, (32, 32))
+    t = torch.from_numpy(t.transpose((2, 0, 1)))
+    t = torch.stack([t, t, t, t])
+    results = model(t)
+    assert len(results) == t.shape[0]
+    results = seg_model(t)
+    assert len(results) == t.shape[0]
+    results = cls_model(t)
+    assert len(results) == t.shape[0]


def test_predict_grey_and_4ch():
    model = YOLO(MODEL)
@@ -199,3 +213,6 @@ def test_result():
    res = model(SOURCE)
    res[0].plot()
    print(res[0].path)
+
+
+test_predict_img()

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

-__version__ = '8.0.49'
+__version__ = '8.0.50'

from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils.checks import check_yolo as checks

@@ -0,0 +1,55 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [768, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [768]],
   [-1, 1, Conv, [1024, 3, 2]],  # 9-P6/64
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 11
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [768, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 8], 1, Concat, [1]],  # cat backbone P5
   [-1, 3, C3, [768, False]],  # 15
   [-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 19
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 23 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 20], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 26 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 16], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [768, False]],  # 29 (P5/32-large)
   [-1, 1, Conv, [768, 3, 2]],
   [[-1, 12], 1, Concat, [1]],  # cat head P6
   [-1, 3, C3, [1024, False]],  # 32 (P6/64-xlarge)
   [[23, 26, 29, 32], 1, Detect, [nc]],  # Detect(P3, P4, P5, P6)
  ]

@@ -0,0 +1,55 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [768, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [768]],
   [-1, 1, Conv, [1024, 3, 2]],  # 9-P6/64
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 11
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [768, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 8], 1, Concat, [1]],  # cat backbone P5
   [-1, 3, C3, [768, False]],  # 15
   [-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 19
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 23 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 20], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 26 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 16], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [768, False]],  # 29 (P5/32-large)
   [-1, 1, Conv, [768, 3, 2]],
   [[-1, 12], 1, Concat, [1]],  # cat head P6
   [-1, 3, C3, [1024, False]],  # 32 (P6/64-xlarge)
   [[23, 26, 29, 32], 1, Detect, [nc]],  # Detect(P3, P4, P5, P6)
  ]

@@ -0,0 +1,55 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [768, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [768]],
   [-1, 1, Conv, [1024, 3, 2]],  # 9-P6/64
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 11
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [768, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 8], 1, Concat, [1]],  # cat backbone P5
   [-1, 3, C3, [768, False]],  # 15
   [-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 19
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 23 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 20], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 26 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 16], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [768, False]],  # 29 (P5/32-large)
   [-1, 1, Conv, [768, 3, 2]],
   [[-1, 12], 1, Concat, [1]],  # cat head P6
   [-1, 3, C3, [1024, False]],  # 32 (P6/64-xlarge)
   [[23, 26, 29, 32], 1, Detect, [nc]],  # Detect(P3, P4, P5, P6)
  ]

@@ -0,0 +1,55 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [768, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [768]],
   [-1, 1, Conv, [1024, 3, 2]],  # 9-P6/64
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 11
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [768, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 8], 1, Concat, [1]],  # cat backbone P5
   [-1, 3, C3, [768, False]],  # 15
   [-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 19
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 23 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 20], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 26 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 16], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [768, False]],  # 29 (P5/32-large)
   [-1, 1, Conv, [768, 3, 2]],
   [[-1, 12], 1, Concat, [1]],  # cat head P6
   [-1, 3, C3, [1024, False]],  # 32 (P6/64-xlarge)
   [[23, 26, 29, 32], 1, Detect, [nc]],  # Detect(P3, P4, P5, P6)
  ]

@@ -5,7 +5,6 @@ nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]

@@ -0,0 +1,55 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [768, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [768]],
   [-1, 1, Conv, [1024, 3, 2]],  # 9-P6/64
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 11
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [768, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 8], 1, Concat, [1]],  # cat backbone P5
   [-1, 3, C3, [768, False]],  # 15
   [-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 19
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 23 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 20], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 26 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 16], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [768, False]],  # 29 (P5/32-large)
   [-1, 1, Conv, [768, 3, 2]],
   [[-1, 12], 1, Concat, [1]],  # cat head P6
   [-1, 3, C3, [1024, False]],  # 32 (P6/64-xlarge)
   [[23, 26, 29, 32], 1, Detect, [nc]],  # Detect(P3, P4, P5, P6)
  ]
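The five new YAML blocks above are identical apart from `depth_multiple` and `width_multiple`, giving the n/s/m/l/x P6 (stride-64) variants. A minimal sketch of building one from its YAML; the filename is assumed from the standard `yolov5*6u` naming, since file paths are not shown in this view:

```python
from ultralytics import YOLO

# Build an untrained P6 model from its YAML (filename assumed, e.g. yolov5l6u.yaml)
model = YOLO('yolov5l6u.yaml')  # depth_multiple=1.0, width_multiple=1.0
model.info()  # print layer/parameter summary
```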

@@ -27,6 +27,10 @@ def check_class_names(names):
    if isinstance(names, dict):
        if not all(isinstance(k, int) for k in names.keys()):  # convert string keys to int, i.e. '0' to 0
            names = {int(k): v for k, v in names.items()}
+        n = len(names)
+        if max(names.keys()) >= n:
+            raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
+                           f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.')
        if isinstance(names[0], str) and names[0].startswith('n0'):  # imagenet class codes, i.e. 'n01440764'
            map = yaml_load(ROOT / 'datasets/ImageNet.yaml')['map']  # human-readable names
            names = {k: map[v] for k, v in names.items()}
@@ -35,12 +39,14 @@ def check_class_names(names):

class AutoBackend(nn.Module):

-    def _apply_default_class_names(self, data):
-        with contextlib.suppress(Exception):
-            return yaml_load(check_yaml(data))['names']
-        return {i: f'class{i}' for i in range(999)}  # return default if above errors
-
-    def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
+    def __init__(self,
+                 weights='yolov8n.pt',
+                 device=torch.device('cpu'),
+                 dnn=False,
+                 data=None,
+                 fp16=False,
+                 fuse=True,
+                 verbose=True):
        """
        MultiBackend class for python inference on various platforms using Ultralytics YOLO.
@@ -51,6 +57,7 @@ class AutoBackend(nn.Module):
            data (str), (Path): Additional data.yaml file for class names, optional
            fp16 (bool): If True, use half precision. Default: False
            fuse (bool): Whether to fuse the model or not. Default: True
+            verbose (bool): Whether to run in verbose mode or not. Default: True

        Supported formats and their naming conventions:
            | Format | Suffix |
@@ -83,7 +90,7 @@ class AutoBackend(nn.Module):
        # NOTE: special case: in-memory pytorch model
        if nn_module:
            model = weights.to(device)
-            model = model.fuse() if fuse else model
+            model = model.fuse(verbose=verbose) if fuse else model
            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
            stride = max(int(model.stride.max()), 32)  # model stride
            model.half() if fp16 else model.float()
@@ -410,6 +417,12 @@ class AutoBackend(nn.Module):
        for _ in range(2 if self.jit else 1):  #
            self.forward(im)  # warmup

+    @staticmethod
+    def _apply_default_class_names(data):
+        with contextlib.suppress(Exception):
+            return yaml_load(check_yaml(data))['names']
+        return {i: f'class{i}' for i in range(999)}  # return default if above errors
+
    @staticmethod
    def _model_type(p='path/to/model.pt'):
        """

@@ -8,9 +8,7 @@ import thop
import torch
import torch.nn as nn

-from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
-                                    Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
-                                    GhostBottleneck, GhostConv, Segment)
+from ultralytics.nn.modules import *  # noqa: F403
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, colorstr, emojis, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
@@ -87,7 +85,7 @@ class BaseModel(nn.Module):
        if c:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")

-    def fuse(self):
+    def fuse(self, verbose=True):
        """
        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
        computation efficiency.
@@ -105,7 +103,7 @@ class BaseModel(nn.Module):
                m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
                delattr(m, 'bn')  # remove batchnorm
                m.forward = m.forward_fuse  # update forward
-            self.info()
+            self.info(verbose=verbose)
        return self
@@ -130,7 +128,7 @@ class BaseModel(nn.Module):
            verbose (bool): if True, prints out the model information. Defaults to False
            imgsz (int): the size of the image that the model will be trained on. Defaults to 640
        """
-        model_info(self, verbose, imgsz)
+        model_info(self, verbose=verbose, imgsz=imgsz)

    def _apply(self, fn):
        """
""" """
@@ -437,7 +435,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
        ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
-        m = eval(m) if isinstance(m, str) else m  # eval strings
+        m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m]  # get module
        for j, a in enumerate(args):
            # TODO: re-implement with eval() removal if possible
            # args[j] = (locals()[a] if a in locals() else ast.literal_eval(a)) if isinstance(a, str) else a
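The replaced `eval()` call resolves module names explicitly; a minimal standalone sketch of the new lookup logic:

```python
import torch

# 'nn.*' names resolve from torch.nn; bare names like 'Conv' or 'C3' resolve from
# the module globals populated by the new wildcard import of ultralytics.nn.modules
m = 'nn.Upsample'
module = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m]
assert module is torch.nn.Upsample
```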

@@ -61,8 +61,10 @@ CFG_BOOL_KEYS = ('save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', '
                 'v5loader')

# Define valid tasks and modes
-TASKS = 'detect', 'segment', 'classify'
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
+TASKS = 'detect', 'segment', 'classify'
+TASK2DATA = {'detect': 'coco128.yaml', 'segment': 'coco128-seg.yaml', 'classify': 'imagenet100'}
+TASK2MODEL = {'detect': 'yolov8n.pt', 'segment': 'yolov8n-seg.pt', 'classify': 'yolov8n-cls.pt'}


def cfg2dict(cfg):
@@ -274,8 +276,11 @@ def entrypoint(debug=''):
    # Task
    task = overrides.pop('task', None)
-    if task and task not in TASKS:
-        raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
+    if task:
+        if task not in TASKS:
+            raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
+        if 'model' not in overrides:
+            overrides['model'] = TASK2MODEL[task]

    # Model
    model = overrides.pop('model', DEFAULT_CFG.model)
@@ -287,9 +292,10 @@ def entrypoint(debug=''):
        model = YOLO(model, task=task)

    # Task Update
-    if task and task != model.task:
-        LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
-                       f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
+    if task != model.task:
+        if task:
+            LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
+                           f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
        task = model.task

    # Mode
@@ -299,8 +305,7 @@ def entrypoint(debug=''):
            LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
    elif mode in ('train', 'val'):
        if 'data' not in overrides:
-            task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100')
-            overrides['data'] = task2data.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
+            overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
            LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
    elif mode == 'export':
        if 'format' not in overrides:
@@ -322,4 +327,4 @@ def copy_default_cfg():

if __name__ == '__main__':
    # entrypoint(debug='yolo predict model=yolov8n.pt')
-    entrypoint(debug='')
+    entrypoint(debug='yolo train model=yolov8n-seg.pt')
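With `TASK2DATA` and `TASK2MODEL` hoisted to module level, passing only a task now resolves a default model and dataset; a sketch (note this kicks off a real training run):

```python
from ultralytics.yolo.cfg import entrypoint

# Resolves model=yolov8n-seg.pt (TASK2MODEL) and data=coco128-seg.yaml (TASK2DATA)
entrypoint(debug='yolo train task=segment')
```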

@@ -6,7 +6,7 @@ mode: train  # YOLO mode, i.e. train, val, predict, export

# Train settings -------------------------------------------------------------------------------------------------------
model:  # path to model file, i.e. yolov8n.pt, yolov8n.yaml
-data:  # path to data file, i.e. i.e. coco128.yaml
+data:  # path to data file, i.e. coco128.yaml
epochs: 100  # number of epochs to train for
patience: 50  # epochs to wait for no observable improvement for early stopping of training
batch: 16  # number of images per batch (-1 for AutoBatch)

@@ -35,7 +35,8 @@ class BaseDataset(Dataset):
                 batch_size=None,
                 stride=32,
                 pad=0.5,
-                 single_cls=False):
+                 single_cls=False,
+                 classes=None):
        super().__init__()
        self.img_path = img_path
        self.imgsz = imgsz
@@ -45,8 +46,7 @@ class BaseDataset(Dataset):
        self.im_files = self.get_img_files(self.img_path)
        self.labels = self.get_labels()
-        if self.single_cls:
-            self.update_labels(include_class=[])
+        self.update_labels(include_class=classes)  # single_cls and include_class
        self.ni = len(self.labels)
@@ -96,7 +96,7 @@ class BaseDataset(Dataset):
        """include_class, filter labels to include only these classes (optional)"""
        include_class_array = np.array(include_class).reshape(1, -1)
        for i in range(len(self.labels)):
-            if include_class:
+            if include_class is not None:
                cls = self.labels[i]['cls']
                bboxes = self.labels[i]['bboxes']
                segments = self.labels[i]['segments']
@@ -104,7 +104,7 @@ class BaseDataset(Dataset):
                self.labels[i]['cls'] = cls[j]
                self.labels[i]['bboxes'] = bboxes[j]
                if segments:
-                    self.labels[i]['segments'] = segments[j]
+                    self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
            if self.single_cls:
                self.labels[i]['cls'][:, 0] = 0
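This fix matters because `segments` is a plain Python list (one polygon array per box), so NumPy-style boolean indexing `segments[j]` raises a TypeError; the comprehension applies the same keep-mask element-wise. A standalone sketch:

```python
import numpy as np

segments = [np.zeros((5, 2)), np.ones((8, 2)), np.ones((3, 2))]  # one polygon per box
j = np.array([True, False, True])  # boolean keep-mask from the class filter

kept = [segments[si] for si, idx in enumerate(j) if idx]  # element-wise filter
assert len(kept) == 2  # segments[j] on the raw list would raise TypeError
```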

@@ -10,7 +10,7 @@ from PIL import Image
from torch.utils.data import DataLoader, dataloader, distributed

from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots,
-                                                              LoadStreams, SourceTypes, autocast_list)
+                                                              LoadStreams, LoadTensor, SourceTypes, autocast_list)
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils.checks import check_file
@@ -82,7 +82,8 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
                              prefix=colorstr(f'{mode}: '),
                              use_segments=cfg.task == 'segment',
                              use_keypoints=cfg.task == 'keypoint',
-                              names=names)
+                              names=names,
+                              classes=cfg.classes)

    batch = min(batch, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
@@ -133,7 +134,7 @@ def build_classification_dataloader(path,

def check_source(source):
-    webcam, screenshot, from_img, in_memory = False, False, False, False
+    webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
    if isinstance(source, (str, int, Path)):  # int for local usb camera
        source = str(source)
        is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
@@ -149,22 +150,25 @@ def check_source(source):
        from_img = True
    elif isinstance(source, (Image.Image, np.ndarray)):
        from_img = True
+    elif isinstance(source, torch.Tensor):
+        tensor = True
    else:
        raise TypeError('Unsupported image type. See docs for supported types https://docs.ultralytics.com/predict')

-    return source, webcam, screenshot, from_img, in_memory
+    return source, webcam, screenshot, from_img, in_memory, tensor


def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True):
    """
    TODO: docs
    """
-    # source
-    source, webcam, screenshot, from_img, in_memory = check_source(source)
-    source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img)
+    source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
+    source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)

    # Dataloader
-    if in_memory:
+    if tensor:
+        dataset = LoadTensor(source)
+    elif in_memory:
        dataset = source
    elif webcam:
        dataset = LoadStreams(source,

@@ -26,6 +26,7 @@ class SourceTypes:
    webcam: bool = False
    screenshot: bool = False
    from_img: bool = False
+    tensor: bool = False


class LoadStreams:
@@ -329,6 +330,23 @@ class LoadPilAndNumpy:
        return self

+class LoadTensor:
+
+    def __init__(self, imgs) -> None:
+        self.im0 = imgs
+        self.bs = imgs.shape[0]
+
+    def __iter__(self):
+        self.count = 0
+        return self
+
+    def __next__(self):
+        if self.count == 1:
+            raise StopIteration
+        self.count += 1
+        return None, self.im0, self.im0, None, ''  # self.paths, im, self.im0, None, ''
+

def autocast_list(source):
    """
    Merges a list of source of different types into a list of numpy arrays or PIL images
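`LoadTensor` is a one-shot iterator over a pre-batched BCHW tensor, yielding the whole batch exactly once; a minimal usage sketch:

```python
import torch

from ultralytics.yolo.data.dataloaders.stream_loaders import LoadTensor

t = torch.zeros(4, 3, 32, 32)  # pre-batched BCHW input
dataset = LoadTensor(t)
for path, im, im0, cap, s in dataset:  # single iteration, then StopIteration
    assert im.shape[0] == dataset.bs == 4
```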

@@ -539,7 +539,7 @@ class LoadImagesAndLabels(Dataset):
                j = (label[:, 0:1] == include_class_array).any(1)
                self.labels[i] = label[j]
                if segment:
-                    self.segments[i] = segment[j]
+                    self.segments[i] = [segment[si] for si, idx in enumerate(j) if idx]
            if single_cls:  # single-class training, merge all classes into 0
                self.labels[i][:, 0] = 0

@@ -57,12 +57,14 @@ class YOLODataset(BaseDataset):
                 single_cls=False,
                 use_segments=False,
                 use_keypoints=False,
-                 names=None):
+                 names=None,
+                 classes=None):
        self.use_segments = use_segments
        self.use_keypoints = use_keypoints
        self.names = names
        assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
-        super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)
+        super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls,
+                         classes)

    def cache_labels(self, path=Path('./labels.cache')):
        """Cache dataset labels, check images and read shapes.

@@ -16,6 +16,7 @@ import numpy as np
from PIL import ExifTags, Image, ImageOps
from tqdm import tqdm

+from ultralytics.nn.autobackend import check_class_names
from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, colorstr, emojis, yaml_load
from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file
@@ -211,8 +212,7 @@ def check_det_dataset(dataset, autodownload=True):
        raise SyntaxError(
            emojis(f"{dataset} '{k}:' key missing ❌.\n"
                   f"'train', 'val' and 'names' are required in data.yaml files."))
-    if isinstance(data['names'], (list, tuple)):  # old array format
-        data['names'] = dict(enumerate(data['names']))  # convert to dict
+    data['names'] = check_class_names(data['names'])
    data['nc'] = len(data['names'])

    # Resolve paths

@@ -574,7 +574,7 @@ class Exporter:
        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
        if self.args.int8:
-            f = saved_model / (self.file.stem + 'yolov8n_integer_quant.tflite')  # fp32 in/out
+            f = saved_model / (self.file.stem + '_integer_quant.tflite')  # fp32 in/out
        elif self.args.half:
            f = saved_model / (self.file.stem + '_float16.tflite')
        else:
@@ -863,18 +863,6 @@ def export(cfg=DEFAULT_CFG):
    cfg.model = cfg.model or 'yolov8n.yaml'
    cfg.format = cfg.format or 'torchscript'

-    # exporter = Exporter(cfg)
-    #
-    # model = None
-    # if isinstance(cfg.model, (str, Path)):
-    #     if Path(cfg.model).suffix == '.yaml':
-    #         model = DetectionModel(cfg.model)
-    #     elif Path(cfg.model).suffix == '.pt':
-    #         model = attempt_load_weights(cfg.model, fuse=True)
-    #     else:
-    #         TypeError(f'Unsupported model type {cfg.model}')
-    # exporter(model=model)
-
    from ultralytics import YOLO
    model = YOLO(cfg.model)
    model.export(**vars(cfg))

@@ -203,6 +203,8 @@ class YOLO:
        if source is None:
            source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
            LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
+        is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and \
+                 ('predict' in sys.argv or 'mode=predict' in sys.argv)

        overrides = self.overrides.copy()
        overrides['conf'] = 0.25
@@ -213,10 +215,9 @@ class YOLO:
        if not self.predictor:
            self.task = overrides.get('task') or self.task
            self.predictor = TASK_MAP[self.task][3](overrides=overrides)
-            self.predictor.setup_model(model=self.model)
+            self.predictor.setup_model(model=self.model, verbose=is_cli)
        else:  # only update args if predictor is already setup
            self.predictor.args = get_cfg(self.predictor.args, overrides)
-        is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
        return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)

    def track(self, source=None, stream=False, **kwargs):

@@ -183,6 +183,8 @@ class BasePredictor:
                'preprocess': self.dt[0].dt * 1E3 / n,
                'inference': self.dt[1].dt * 1E3 / n,
                'postprocess': self.dt[2].dt * 1E3 / n}
+            if self.source_type.tensor:  # skip write, show and plot operations if input is raw tensor
+                continue
            p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img \
                else (path, im0s.copy())
            p = Path(p)
@@ -218,11 +220,16 @@ class BasePredictor:

        self.run_callbacks('on_predict_end')

-    def setup_model(self, model):
-        device = select_device(self.args.device)
+    def setup_model(self, model, verbose=True):
+        device = select_device(self.args.device, verbose=verbose)
        model = model or self.args.model
        self.args.half &= device.type != 'cpu'  # half precision only supported on CUDA
-        self.model = AutoBackend(model, device=device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
+        self.model = AutoBackend(model,
+                                 device=device,
+                                 dnn=self.args.dnn,
+                                 data=self.args.data,
+                                 fp16=self.args.half,
+                                 verbose=verbose)
        self.device = device
        self.model.eval()

@@ -25,8 +25,8 @@ from tqdm import tqdm
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
-from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks,
-                                    colorstr, emojis, yaml_save)
+from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, ONLINE, RANK, ROOT, SETTINGS, TQDM_BAR_FORMAT, __version__,
+                                    callbacks, colorstr, emojis, yaml_save)
from ultralytics.yolo.utils.autobatch import check_train_batch_size
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
@@ -111,8 +111,6 @@ class BaseTrainer:
        print_args(vars(self.args))

        # Device
-        self.amp = self.device.type != 'cpu'
-        self.scaler = amp.GradScaler(enabled=self.amp)
        if self.device.type == 'cpu':
            self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading
@@ -126,7 +124,7 @@ class BaseTrainer:
            if 'yaml_file' in self.data:
                self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage
        except Exception as e:
-            raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e
+            raise RuntimeError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e
        self.trainset, self.testset = self.get_dataset(self.data)
        self.ema = None
@@ -204,6 +202,8 @@ class BaseTrainer:
        ckpt = self.setup_model()
        self.model = self.model.to(self.device)
        self.set_model_attributes()
+        self.amp = check_amp(self.model)
+        self.scaler = amp.GradScaler(enabled=self.amp)
        if world_size > 1:
            self.model = DDP(self.model, device_ids=[rank])

        # Check imgsz
@@ -597,3 +597,31 @@ class BaseTrainer:
        LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
                    f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
        return optimizer
+
+
+def check_amp(model):
+    # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
+    device = next(model.parameters()).device  # get model device
+    if device.type in ('cpu', 'mps'):
+        return False  # AMP only used on CUDA devices
+
+    def amp_allclose(m, im):
+        # All close FP32 vs AMP results
+        a = m(im, device=device, verbose=False)[0].boxes.boxes  # FP32 inference
+        with torch.cuda.amp.autocast(True):
+            b = m(im, device=device, verbose=False)[0].boxes.boxes  # AMP inference
+        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.1)  # close to 10% absolute tolerance
+
+    f = ROOT / 'assets/bus.jpg'  # image to check
+    im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
+    prefix = colorstr('AMP: ')
+    try:
+        from ultralytics import YOLO
+        LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
+        assert amp_allclose(YOLO('yolov8n.pt'), im)
+        LOGGER.info(f'{prefix}checks passed ✅')
+        return True
+    except AssertionError:
+        LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
+                       f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
+        return False

@@ -236,9 +236,10 @@ def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):

def check_yolov5u_filename(file: str, verbose: bool = True):
    # Replace legacy YOLOv5 filenames with updated YOLOv5u filenames
-    if 'yolov3' in file or 'yolov5' in file and 'u' not in file:
+    if ('yolov3' in file or 'yolov5' in file) and 'u' not in file:
        original_file = file
        file = re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', file)  # i.e. yolov5n.pt -> yolov5nu.pt
+        file = re.sub(r'(.*yolov5([nsmlx])6)\.', '\\1u.', file)  # i.e. yolov5n6.pt -> yolov5n6u.pt
        file = re.sub(r'(.*yolov3(|-tiny|-spp))\.', '\\1u.', file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
        if file != original_file and verbose:
            LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
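The added pattern extends the u-suffix rename to P6 checkpoints; a quick standalone check of the three substitutions:

```python
import re

for file in ('yolov5n.pt', 'yolov5n6.pt', 'yolov3-spp.pt'):
    out = re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', file)     # yolov5n.pt  -> yolov5nu.pt
    out = re.sub(r'(.*yolov5([nsmlx])6)\.', '\\1u.', out)     # yolov5n6.pt -> yolov5n6u.pt
    out = re.sub(r'(.*yolov3(|-tiny|-spp))\.', '\\1u.', out)  # yolov3-spp.pt -> yolov3-sppu.pt
    print(f'{file} -> {out}')
```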

@@ -162,11 +162,13 @@ def fuse_deconv_and_bn(deconv, bn):
    return fuseddconv


-def model_info(model, verbose=False, imgsz=640):
+def model_info(model, detailed=False, verbose=True, imgsz=640):
    # Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
+    if not verbose:
+        return
    n_p = get_num_params(model)
    n_g = get_num_gradients(model)  # number gradients
-    if verbose:
+    if detailed:
        LOGGER.info(
            f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
        for i, (name, p) in enumerate(model.named_parameters()):

@@ -14,14 +14,13 @@ class ClassificationPredictor(BasePredictor):
        return Annotator(img, example=str(self.model.names), pil=True)

    def preprocess(self, img):
-        img = (img if isinstance(img, torch.Tensor) else torch.Tensor(img)).to(self.model.device)
-        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
-        return img
+        img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
+        return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32

-    def postprocess(self, preds, img, orig_img):
+    def postprocess(self, preds, img, orig_imgs):
        results = []
        for i, pred in enumerate(preds):
-            orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img
+            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
            path, _, _, _, _ = self.batch
            img_path = path[i] if isinstance(path, list) else path
            results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))

@@ -14,12 +14,12 @@ class DetectionPredictor(BasePredictor):
        return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))

    def preprocess(self, img):
-        img = torch.from_numpy(img).to(self.model.device)
+        img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        img /= 255  # 0 - 255 to 0.0 - 1.0
        return img

-    def postprocess(self, preds, img, orig_img):
+    def postprocess(self, preds, img, orig_imgs):
        preds = ops.non_max_suppression(preds,
                                        self.args.conf,
                                        self.args.iou,
@@ -29,7 +29,7 @@ class DetectionPredictor(BasePredictor):
        results = []
        for i, pred in enumerate(preds):
-            orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img
+            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
            shape = orig_img.shape
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
            path, _, _, _, _ = self.batch

@@ -10,7 +10,7 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor

class SegmentationPredictor(DetectionPredictor):

-    def postprocess(self, preds, img, orig_img):
+    def postprocess(self, preds, img, orig_imgs):
        # TODO: filter by classes
        p = ops.non_max_suppression(preds[0],
                                    self.args.conf,
@@ -22,7 +22,7 @@ class SegmentationPredictor(DetectionPredictor):
        results = []
        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
        for i, pred in enumerate(p):
-            orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img
+            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
            shape = orig_img.shape
            path, _, _, _, _ = self.batch
            img_path = path[i] if isinstance(path, list) else path
