ultralytics 8.0.49
task, exports and metadata updates (#1197)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com> Co-authored-by: Paul Guerrie <97041392+paulguerrie@users.noreply.github.com>
This commit is contained in:
@ -215,7 +215,7 @@ class Exporter:
|
||||
self.model = model
|
||||
self.file = file
|
||||
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
|
||||
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
|
||||
self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
|
||||
description = f'Ultralytics {self.pretty_name} model ' + f'trained on {Path(self.args.data).name}' \
|
||||
if self.args.data else '(untrained)'
|
||||
self.metadata = {
|
||||
@ -225,6 +225,8 @@ class Exporter:
|
||||
'version': __version__,
|
||||
'stride': int(max(model.stride)),
|
||||
'task': model.task,
|
||||
'batch': self.args.batch,
|
||||
'imgsz': self.imgsz,
|
||||
'names': model.names} # model metadata
|
||||
|
||||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
|
||||
@ -283,8 +285,7 @@ class Exporter:
|
||||
f = self.file.with_suffix('.torchscript')
|
||||
|
||||
ts = torch.jit.trace(self.model, self.im, strict=False)
|
||||
d = {'shape': self.im.shape, 'stride': int(max(self.model.stride)), 'names': self.model.names}
|
||||
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
|
||||
extra_files = {'config.txt': json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
|
||||
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
|
||||
LOGGER.info(f'{prefix} optimizing for mobile...')
|
||||
from torch.utils.mobile_optimizer import optimize_for_mobile
|
||||
@ -429,16 +430,18 @@ class Exporter:
|
||||
classifier_config=classifier_config)
|
||||
bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
|
||||
if bits < 32:
|
||||
if 'kmeans' in mode:
|
||||
check_requirements('scikit-learn') # scikit-learn package required for k-means quantization
|
||||
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
|
||||
if self.args.nms and self.model.task == 'detect':
|
||||
ct_model = self._pipeline_coreml(ct_model)
|
||||
|
||||
m = self.metadata # metadata dict
|
||||
ct_model.short_description = m['description']
|
||||
ct_model.author = m['author']
|
||||
ct_model.license = m['license']
|
||||
ct_model.version = m['version']
|
||||
ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items() if k in ('stride', 'task', 'names')})
|
||||
ct_model.short_description = m.pop('description')
|
||||
ct_model.author = m.pop('author')
|
||||
ct_model.license = m.pop('license')
|
||||
ct_model.version = m.pop('version')
|
||||
ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
|
||||
ct_model.save(str(f))
|
||||
return f, ct_model
|
||||
|
||||
|
@ -8,8 +8,8 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
|
||||
guess_model_task, nn)
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
|
||||
is_git_dir, is_pip_package, yaml_load)
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, ONLINE, RANK, ROOT,
|
||||
callbacks, is_git_dir, is_pip_package, yaml_load)
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update, check_yaml
|
||||
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
|
||||
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||
@ -157,7 +157,7 @@ class YOLO:
|
||||
"""
|
||||
Inform user of ultralytics package update availability
|
||||
"""
|
||||
if is_pip_package():
|
||||
if ONLINE and is_pip_package():
|
||||
check_pip_update()
|
||||
|
||||
def reset(self):
|
||||
|
@ -5,6 +5,7 @@ Ultralytics Results, Boxes and Masks classes for handling inference results
|
||||
Usage: See https://docs.ultralytics.com/predict/
|
||||
"""
|
||||
|
||||
import pprint
|
||||
from copy import deepcopy
|
||||
from functools import lru_cache
|
||||
|
||||
@ -96,10 +97,11 @@ class Results:
|
||||
return len(getattr(self, k))
|
||||
|
||||
def __str__(self):
|
||||
return ''.join(getattr(self, k).__str__() for k in self._keys)
|
||||
attr = {k: v for k, v in vars(self).items() if not isinstance(v, type(self))}
|
||||
return pprint.pformat(attr, indent=2, width=120, depth=10, compact=True)
|
||||
|
||||
def __repr__(self):
|
||||
return ''.join(getattr(self, k).__repr__() for k in self._keys)
|
||||
return self.__str__()
|
||||
|
||||
def __getattr__(self, attr):
|
||||
name = self.__class__.__name__
|
||||
@ -261,7 +263,7 @@ class Boxes:
|
||||
return self.boxes.__str__()
|
||||
|
||||
def __repr__(self):
|
||||
return (f'Ultralytics YOLO {self.__class__} masks\n' + f'type: {type(self.boxes)}\n' +
|
||||
return (f'Ultralytics YOLO {self.__class__.__name__}\n' + f'type: {type(self.boxes)}\n' +
|
||||
f'shape: {self.boxes.shape}\n' + f'dtype: {self.boxes.dtype}\n + {self.boxes.__repr__()}')
|
||||
|
||||
def __getitem__(self, idx):
|
||||
@ -337,7 +339,7 @@ class Masks:
|
||||
return self.masks.__str__()
|
||||
|
||||
def __repr__(self):
|
||||
return (f'Ultralytics YOLO {self.__class__} masks\n' + f'type: {type(self.masks)}\n' +
|
||||
return (f'Ultralytics YOLO {self.__class__.__name__}\n' + f'type: {type(self.masks)}\n' +
|
||||
f'shape: {self.masks.shape}\n' + f'dtype: {self.masks.dtype}\n + {self.masks.__repr__()}')
|
||||
|
||||
def __getitem__(self, idx):
|
||||
|
@ -102,7 +102,7 @@ class BaseValidator:
|
||||
model = model.half() if self.args.half else model.float()
|
||||
self.model = model
|
||||
self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
||||
self.args.plots = trainer.epoch == trainer.epochs - 1 # always plot final epoch
|
||||
self.args.plots = trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
|
||||
model.eval()
|
||||
else:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
Reference in New Issue
Block a user