ultralytics 8.0.50 AMP check and YOLOv5u YAMLs (#1263)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Troy <wudashuo@vip.qq.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Huijae Lee <46982469+ZeroAct@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-03-06 11:39:26 +01:00
committed by GitHub
parent 3861e6c82a
commit f0d8e4718b
29 changed files with 440 additions and 83 deletions

View File

@ -61,8 +61,10 @@ CFG_BOOL_KEYS = ('save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', '
'v5loader')
# Define valid tasks and modes
TASKS = 'detect', 'segment', 'classify'
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
TASKS = 'detect', 'segment', 'classify'
TASK2DATA = {'detect': 'coco128.yaml', 'segment': 'coco128-seg.yaml', 'classify': 'imagenet100'}
TASK2MODEL = {'detect': 'yolov8n.pt', 'segment': 'yolov8n-seg.pt', 'classify': 'yolov8n-cls.pt'}
def cfg2dict(cfg):
@ -274,8 +276,11 @@ def entrypoint(debug=''):
# Task
task = overrides.pop('task', None)
if task and task not in TASKS:
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
if task:
if task not in TASKS:
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
if 'model' not in overrides:
overrides['model'] = TASK2MODEL[task]
# Model
model = overrides.pop('model', DEFAULT_CFG.model)
@ -287,9 +292,10 @@ def entrypoint(debug=''):
model = YOLO(model, task=task)
# Task Update
if task and task != model.task:
LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
if task != model.task:
if task:
LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
task = model.task
# Mode
@ -299,8 +305,7 @@ def entrypoint(debug=''):
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
elif mode in ('train', 'val'):
if 'data' not in overrides:
task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100')
overrides['data'] = task2data.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
elif mode == 'export':
if 'format' not in overrides:
@ -322,4 +327,4 @@ def copy_default_cfg():
if __name__ == '__main__':
# entrypoint(debug='yolo predict model=yolov8n.pt')
entrypoint(debug='')
entrypoint(debug='yolo train model=yolov8n-seg.pt')

View File

@ -6,7 +6,7 @@ mode: train # YOLO mode, i.e. train, val, predict, export
# Train settings -------------------------------------------------------------------------------------------------------
model: # path to model file, i.e. yolov8n.pt, yolov8n.yaml
data: # path to data file, i.e. i.e. coco128.yaml
data: # path to data file, i.e. coco128.yaml
epochs: 100 # number of epochs to train for
patience: 50 # epochs to wait for no observable improvement for early stopping of training
batch: 16 # number of images per batch (-1 for AutoBatch)

View File

@ -35,7 +35,8 @@ class BaseDataset(Dataset):
batch_size=None,
stride=32,
pad=0.5,
single_cls=False):
single_cls=False,
classes=None):
super().__init__()
self.img_path = img_path
self.imgsz = imgsz
@ -45,8 +46,7 @@ class BaseDataset(Dataset):
self.im_files = self.get_img_files(self.img_path)
self.labels = self.get_labels()
if self.single_cls:
self.update_labels(include_class=[])
self.update_labels(include_class=classes) # single_cls and include_class
self.ni = len(self.labels)
@ -96,7 +96,7 @@ class BaseDataset(Dataset):
"""include_class, filter labels to include only these classes (optional)"""
include_class_array = np.array(include_class).reshape(1, -1)
for i in range(len(self.labels)):
if include_class:
if include_class is not None:
cls = self.labels[i]['cls']
bboxes = self.labels[i]['bboxes']
segments = self.labels[i]['segments']
@ -104,7 +104,7 @@ class BaseDataset(Dataset):
self.labels[i]['cls'] = cls[j]
self.labels[i]['bboxes'] = bboxes[j]
if segments:
self.labels[i]['segments'] = segments[j]
self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
if self.single_cls:
self.labels[i]['cls'][:, 0] = 0

View File

@ -10,7 +10,7 @@ from PIL import Image
from torch.utils.data import DataLoader, dataloader, distributed
from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots,
LoadStreams, SourceTypes, autocast_list)
LoadStreams, LoadTensor, SourceTypes, autocast_list)
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils.checks import check_file
@ -82,7 +82,8 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
prefix=colorstr(f'{mode}: '),
use_segments=cfg.task == 'segment',
use_keypoints=cfg.task == 'keypoint',
names=names)
names=names,
classes=cfg.classes)
batch = min(batch, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
@ -133,7 +134,7 @@ def build_classification_dataloader(path,
def check_source(source):
webcam, screenshot, from_img, in_memory = False, False, False, False
webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
if isinstance(source, (str, int, Path)): # int for local usb camera
source = str(source)
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
@ -149,22 +150,25 @@ def check_source(source):
from_img = True
elif isinstance(source, (Image.Image, np.ndarray)):
from_img = True
elif isinstance(source, torch.Tensor):
tensor = True
else:
raise TypeError('Unsupported image type. See docs for supported types https://docs.ultralytics.com/predict')
return source, webcam, screenshot, from_img, in_memory
return source, webcam, screenshot, from_img, in_memory, tensor
def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True):
"""
TODO: docs
"""
# source
source, webcam, screenshot, from_img, in_memory = check_source(source)
source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img)
source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
# Dataloader
if in_memory:
if tensor:
dataset = LoadTensor(source)
elif in_memory:
dataset = source
elif webcam:
dataset = LoadStreams(source,

View File

@ -26,6 +26,7 @@ class SourceTypes:
webcam: bool = False
screenshot: bool = False
from_img: bool = False
tensor: bool = False
class LoadStreams:
@ -329,6 +330,23 @@ class LoadPilAndNumpy:
return self
class LoadTensor:
def __init__(self, imgs) -> None:
self.im0 = imgs
self.bs = imgs.shape[0]
def __iter__(self):
self.count = 0
return self
def __next__(self):
if self.count == 1:
raise StopIteration
self.count += 1
return None, self.im0, self.im0, None, '' # self.paths, im, self.im0, None, ''
def autocast_list(source):
"""
Merges a list of source of different types into a list of numpy arrays or PIL images

View File

@ -539,7 +539,7 @@ class LoadImagesAndLabels(Dataset):
j = (label[:, 0:1] == include_class_array).any(1)
self.labels[i] = label[j]
if segment:
self.segments[i] = segment[j]
self.segments[i] = [segment[si] for si, idx in enumerate(j) if idx]
if single_cls: # single-class training, merge all classes into 0
self.labels[i][:, 0] = 0

View File

@ -57,12 +57,14 @@ class YOLODataset(BaseDataset):
single_cls=False,
use_segments=False,
use_keypoints=False,
names=None):
names=None,
classes=None):
self.use_segments = use_segments
self.use_keypoints = use_keypoints
self.names = names
assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)
super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls,
classes)
def cache_labels(self, path=Path('./labels.cache')):
"""Cache dataset labels, check images and read shapes.

View File

@ -16,6 +16,7 @@ import numpy as np
from PIL import ExifTags, Image, ImageOps
from tqdm import tqdm
from ultralytics.nn.autobackend import check_class_names
from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, colorstr, emojis, yaml_load
from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file
@ -211,8 +212,7 @@ def check_det_dataset(dataset, autodownload=True):
raise SyntaxError(
emojis(f"{dataset} '{k}:' key missing ❌.\n"
f"'train', 'val' and 'names' are required in data.yaml files."))
if isinstance(data['names'], (list, tuple)): # old array format
data['names'] = dict(enumerate(data['names'])) # convert to dict
data['names'] = check_class_names(data['names'])
data['nc'] = len(data['names'])
# Resolve paths

View File

@ -574,7 +574,7 @@ class Exporter:
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
if self.args.int8:
f = saved_model / (self.file.stem + 'yolov8n_integer_quant.tflite') # fp32 in/out
f = saved_model / (self.file.stem + '_integer_quant.tflite') # fp32 in/out
elif self.args.half:
f = saved_model / (self.file.stem + '_float16.tflite')
else:
@ -863,18 +863,6 @@ def export(cfg=DEFAULT_CFG):
cfg.model = cfg.model or 'yolov8n.yaml'
cfg.format = cfg.format or 'torchscript'
# exporter = Exporter(cfg)
#
# model = None
# if isinstance(cfg.model, (str, Path)):
# if Path(cfg.model).suffix == '.yaml':
# model = DetectionModel(cfg.model)
# elif Path(cfg.model).suffix == '.pt':
# model = attempt_load_weights(cfg.model, fuse=True)
# else:
# TypeError(f'Unsupported model type {cfg.model}')
# exporter(model=model)
from ultralytics import YOLO
model = YOLO(cfg.model)
model.export(**vars(cfg))

View File

@ -203,6 +203,8 @@ class YOLO:
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and \
('predict' in sys.argv or 'mode=predict' in sys.argv)
overrides = self.overrides.copy()
overrides['conf'] = 0.25
@ -213,10 +215,9 @@ class YOLO:
if not self.predictor:
self.task = overrides.get('task') or self.task
self.predictor = TASK_MAP[self.task][3](overrides=overrides)
self.predictor.setup_model(model=self.model)
self.predictor.setup_model(model=self.model, verbose=is_cli)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
def track(self, source=None, stream=False, **kwargs):

View File

@ -183,6 +183,8 @@ class BasePredictor:
'preprocess': self.dt[0].dt * 1E3 / n,
'inference': self.dt[1].dt * 1E3 / n,
'postprocess': self.dt[2].dt * 1E3 / n}
if self.source_type.tensor: # skip write, show and plot operations if input is raw tensor
continue
p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img \
else (path, im0s.copy())
p = Path(p)
@ -218,11 +220,16 @@ class BasePredictor:
self.run_callbacks('on_predict_end')
def setup_model(self, model):
device = select_device(self.args.device)
def setup_model(self, model, verbose=True):
device = select_device(self.args.device, verbose=verbose)
model = model or self.args.model
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
self.model = AutoBackend(model, device=device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
self.model = AutoBackend(model,
device=device,
dnn=self.args.dnn,
data=self.args.data,
fp16=self.args.half,
verbose=verbose)
self.device = device
self.model.eval()

View File

@ -25,8 +25,8 @@ from tqdm import tqdm
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks,
colorstr, emojis, yaml_save)
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, ONLINE, RANK, ROOT, SETTINGS, TQDM_BAR_FORMAT, __version__,
callbacks, colorstr, emojis, yaml_save)
from ultralytics.yolo.utils.autobatch import check_train_batch_size
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
@ -111,8 +111,6 @@ class BaseTrainer:
print_args(vars(self.args))
# Device
self.amp = self.device.type != 'cpu'
self.scaler = amp.GradScaler(enabled=self.amp)
if self.device.type == 'cpu':
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
@ -126,7 +124,7 @@ class BaseTrainer:
if 'yaml_file' in self.data:
self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
except Exception as e:
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e
raise RuntimeError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e
self.trainset, self.testset = self.get_dataset(self.data)
self.ema = None
@ -204,6 +202,8 @@ class BaseTrainer:
ckpt = self.setup_model()
self.model = self.model.to(self.device)
self.set_model_attributes()
self.amp = check_amp(self.model)
self.scaler = amp.GradScaler(enabled=self.amp)
if world_size > 1:
self.model = DDP(self.model, device_ids=[rank])
# Check imgsz
@ -597,3 +597,31 @@ class BaseTrainer:
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
return optimizer
def check_amp(model):
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
device = next(model.parameters()).device # get model device
if device.type in ('cpu', 'mps'):
return False # AMP only used on CUDA devices
def amp_allclose(m, im):
# All close FP32 vs AMP results
a = m(im, device=device, verbose=False)[0].boxes.boxes # FP32 inference
with torch.cuda.amp.autocast(True):
b = m(im, device=device, verbose=False)[0].boxes.boxes # AMP inference
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.1) # close to 10% absolute tolerance
f = ROOT / 'assets/bus.jpg' # image to check
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
prefix = colorstr('AMP: ')
try:
from ultralytics import YOLO
LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
assert amp_allclose(YOLO('yolov8n.pt'), im)
LOGGER.info(f'{prefix}checks passed ✅')
return True
except AssertionError:
LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
return False

View File

@ -236,9 +236,10 @@ def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
def check_yolov5u_filename(file: str, verbose: bool = True):
# Replace legacy YOLOv5 filenames with updated YOLOv5u filenames
if 'yolov3' in file or 'yolov5' in file and 'u' not in file:
if ('yolov3' in file or 'yolov5' in file) and 'u' not in file:
original_file = file
file = re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', file) # i.e. yolov5n.pt -> yolov5nu.pt
file = re.sub(r'(.*yolov5([nsmlx])6)\.', '\\1u.', file) # i.e. yolov5n6.pt -> yolov5n6u.pt
file = re.sub(r'(.*yolov3(|-tiny|-spp))\.', '\\1u.', file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
if file != original_file and verbose:
LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "

View File

@ -162,11 +162,13 @@ def fuse_deconv_and_bn(deconv, bn):
return fuseddconv
def model_info(model, verbose=False, imgsz=640):
def model_info(model, detailed=False, verbose=True, imgsz=640):
# Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
if not verbose:
return
n_p = get_num_params(model)
n_g = get_num_gradients(model) # number gradients
if verbose:
if detailed:
LOGGER.info(
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
for i, (name, p) in enumerate(model.named_parameters()):

View File

@ -14,14 +14,13 @@ class ClassificationPredictor(BasePredictor):
return Annotator(img, example=str(self.model.names), pil=True)
def preprocess(self, img):
img = (img if isinstance(img, torch.Tensor) else torch.Tensor(img)).to(self.model.device)
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
return img
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
def postprocess(self, preds, img, orig_img):
def postprocess(self, preds, img, orig_imgs):
results = []
for i, pred in enumerate(preds):
orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
path, _, _, _, _ = self.batch
img_path = path[i] if isinstance(path, list) else path
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))

View File

@ -14,12 +14,12 @@ class DetectionPredictor(BasePredictor):
return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))
def preprocess(self, img):
img = torch.from_numpy(img).to(self.model.device)
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
img /= 255 # 0 - 255 to 0.0 - 1.0
return img
def postprocess(self, preds, img, orig_img):
def postprocess(self, preds, img, orig_imgs):
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
@ -29,7 +29,7 @@ class DetectionPredictor(BasePredictor):
results = []
for i, pred in enumerate(preds):
orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
shape = orig_img.shape
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
path, _, _, _, _ = self.batch

View File

@ -10,7 +10,7 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor
class SegmentationPredictor(DetectionPredictor):
def postprocess(self, preds, img, orig_img):
def postprocess(self, preds, img, orig_imgs):
# TODO: filter by classes
p = ops.non_max_suppression(preds[0],
self.args.conf,
@ -22,7 +22,7 @@ class SegmentationPredictor(DetectionPredictor):
results = []
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
for i, pred in enumerate(p):
orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
shape = orig_img.shape
path, _, _, _, _ = self.batch
img_path = path[i] if isinstance(path, list) else path