ultralytics 8.0.49 task, exports and metadata updates (#1197)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
Co-authored-by: Paul Guerrie <97041392+paulguerrie@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-03-01 21:16:09 -08:00
committed by GitHub
parent 74e4c94806
commit 3861e6c82a
20 changed files with 111 additions and 101 deletions

View File

@ -273,7 +273,7 @@ def entrypoint(debug=''):
return
# Task
task = overrides.get('task')
task = overrides.pop('task', None)
if task and task not in TASKS:
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
@ -289,9 +289,8 @@ def entrypoint(debug=''):
# Task Update
if task and task != model.task:
LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
f'This may produce errors.')
task = task or model.task
overrides['task'] = task
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
task = model.task
# Mode
if mode in {'predict', 'track'} and 'source' not in overrides:

View File

@ -54,7 +54,7 @@ class _RepeatSampler:
yield from iter(self.sampler)
def seed_worker(worker_id):
def seed_worker(worker_id): # noqa
# Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
worker_seed = torch.initial_seed() % 2 ** 32
np.random.seed(worker_seed)
@ -134,7 +134,7 @@ def build_classification_dataloader(path,
def check_source(source):
webcam, screenshot, from_img, in_memory = False, False, False, False
if isinstance(source, (str, int, Path)): # int for local usb carame
if isinstance(source, (str, int, Path)): # int for local usb camera
source = str(source)
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
@ -147,11 +147,10 @@ def check_source(source):
elif isinstance(source, (list, tuple)):
source = autocast_list(source) # convert all list elements to PIL or np arrays
from_img = True
elif isinstance(source, ((Image.Image, np.ndarray))):
elif isinstance(source, (Image.Image, np.ndarray)):
from_img = True
else:
raise Exception(
'Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict')
raise TypeError('Unsupported image type. See docs for supported types https://docs.ultralytics.com/predict')
return source, webcam, screenshot, from_img, in_memory

View File

@ -215,7 +215,7 @@ class Exporter:
self.model = model
self.file = file
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
description = f'Ultralytics {self.pretty_name} model ' + f'trained on {Path(self.args.data).name}' \
if self.args.data else '(untrained)'
self.metadata = {
@ -225,6 +225,8 @@ class Exporter:
'version': __version__,
'stride': int(max(model.stride)),
'task': model.task,
'batch': self.args.batch,
'imgsz': self.imgsz,
'names': model.names} # model metadata
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
@ -283,8 +285,7 @@ class Exporter:
f = self.file.with_suffix('.torchscript')
ts = torch.jit.trace(self.model, self.im, strict=False)
d = {'shape': self.im.shape, 'stride': int(max(self.model.stride)), 'names': self.model.names}
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
extra_files = {'config.txt': json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
LOGGER.info(f'{prefix} optimizing for mobile...')
from torch.utils.mobile_optimizer import optimize_for_mobile
@ -429,16 +430,18 @@ class Exporter:
classifier_config=classifier_config)
bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
if bits < 32:
if 'kmeans' in mode:
check_requirements('scikit-learn') # scikit-learn package required for k-means quantization
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
if self.args.nms and self.model.task == 'detect':
ct_model = self._pipeline_coreml(ct_model)
m = self.metadata # metadata dict
ct_model.short_description = m['description']
ct_model.author = m['author']
ct_model.license = m['license']
ct_model.version = m['version']
ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items() if k in ('stride', 'task', 'names')})
ct_model.short_description = m.pop('description')
ct_model.author = m.pop('author')
ct_model.license = m.pop('license')
ct_model.version = m.pop('version')
ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
ct_model.save(str(f))
return f, ct_model

View File

@ -8,8 +8,8 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
guess_model_task, nn)
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
is_git_dir, is_pip_package, yaml_load)
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, ONLINE, RANK, ROOT,
callbacks, is_git_dir, is_pip_package, yaml_load)
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update, check_yaml
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
@ -157,7 +157,7 @@ class YOLO:
"""
Inform user of ultralytics package update availability
"""
if is_pip_package():
if ONLINE and is_pip_package():
check_pip_update()
def reset(self):

View File

@ -5,6 +5,7 @@ Ultralytics Results, Boxes and Masks classes for handling inference results
Usage: See https://docs.ultralytics.com/predict/
"""
import pprint
from copy import deepcopy
from functools import lru_cache
@ -96,10 +97,11 @@ class Results:
return len(getattr(self, k))
def __str__(self):
return ''.join(getattr(self, k).__str__() for k in self._keys)
attr = {k: v for k, v in vars(self).items() if not isinstance(v, type(self))}
return pprint.pformat(attr, indent=2, width=120, depth=10, compact=True)
def __repr__(self):
return ''.join(getattr(self, k).__repr__() for k in self._keys)
return self.__str__()
def __getattr__(self, attr):
name = self.__class__.__name__
@ -261,7 +263,7 @@ class Boxes:
return self.boxes.__str__()
def __repr__(self):
return (f'Ultralytics YOLO {self.__class__} masks\n' + f'type: {type(self.boxes)}\n' +
return (f'Ultralytics YOLO {self.__class__.__name__}\n' + f'type: {type(self.boxes)}\n' +
f'shape: {self.boxes.shape}\n' + f'dtype: {self.boxes.dtype}\n + {self.boxes.__repr__()}')
def __getitem__(self, idx):
@ -337,7 +339,7 @@ class Masks:
return self.masks.__str__()
def __repr__(self):
return (f'Ultralytics YOLO {self.__class__} masks\n' + f'type: {type(self.masks)}\n' +
return (f'Ultralytics YOLO {self.__class__.__name__}\n' + f'type: {type(self.masks)}\n' +
f'shape: {self.masks.shape}\n' + f'dtype: {self.masks.dtype}\n + {self.masks.__repr__()}')
def __getitem__(self, idx):

View File

@ -102,7 +102,7 @@ class BaseValidator:
model = model.half() if self.args.half else model.float()
self.model = model
self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
self.args.plots = trainer.epoch == trainer.epochs - 1 # always plot final epoch
self.args.plots = trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
model.eval()
else:
callbacks.add_integration_callbacks(self)

View File

@ -45,9 +45,10 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, hal
y = []
t0 = time.time()
for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU)
emoji = '' # indicates export failure
emoji, filename = '', None # export defaults
try:
assert i != 11, 'paddle exports coming soon'
if model.task == 'classify':
assert i != 11, 'paddle cls exports coming soon'
assert i != 9 or LINUX, 'Edge TPU export only supported on Linux'
if 'cpu' in device.type:
assert cpu, 'inference not supported on CPU'
@ -86,7 +87,7 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, hal
if hard_fail:
assert type(e) is AssertionError, f'Benchmark hard_fail for {name}: {e}'
LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}')
y.append([name, emoji, None, None, None]) # mAP, t_inference
y.append([name, emoji, round(file_size(filename), 1), None, None]) # mAP, t_inference
# Print results
check_yolo(device=device) # print system info

View File

@ -70,14 +70,14 @@ def file_date(path=__file__):
def file_size(path):
# Return file/dir size (MB)
mb = 1 << 20 # bytes to MiB (1024 ** 2)
path = Path(path)
if path.is_file():
return path.stat().st_size / mb
elif path.is_dir():
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
else:
return 0.0
if isinstance(path, (str, Path)):
mb = 1 << 20 # bytes to MiB (1024 ** 2)
path = Path(path)
if path.is_file():
return path.stat().st_size / mb
elif path.is_dir():
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
return 0.0
def url2file(url):

View File

@ -77,11 +77,18 @@ class SegLoss(Loss):
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# targets
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
try:
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
except RuntimeError as e:
raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
"i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
"correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
'as an example.\nSee https://docs.ultralytics.com/tasks/segmentation/ for help.') from e
# pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)