ultralytics 8.0.47 Docker and reformat updates (#1153)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher
2023-02-25 22:49:19 -08:00
committed by GitHub
parent d4be4cb24b
commit a58f766f94
41 changed files with 224 additions and 201 deletions

View File

@ -50,7 +50,6 @@ TensorFlow.js:
import json
import os
import platform
import re
import subprocess
import time
import warnings
@ -90,9 +89,9 @@ def export_formats():
['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
['TensorFlow GraphDef', 'pb', '.pb', True, True],
['TensorFlow Lite', 'tflite', '.tflite', True, False],
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
['TensorFlow.js', 'tfjs', '_web_model', False, False],
['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
['TensorFlow.js', 'tfjs', '_web_model', True, False],
['PaddlePaddle', 'paddle', '_paddle_model', True, True], ]
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
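
A minimal sketch of querying the updated capability flags, assuming the ultralytics.yolo.engine.exporter module path from the 8.0.x package layout; illustrative only.

from ultralytics.yolo.engine.exporter import export_formats  # assumed module path

df = export_formats()        # pandas DataFrame with columns Format, Argument, Suffix, CPU, GPU
print(df[df['CPU']])         # formats flagged CPU-capable, e.g. Edge TPU and TF.js after this change
print(list(df['Argument']))  # export arguments, e.g. 'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'
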
@ -100,6 +99,15 @@ EXPORT_FORMATS_LIST = list(export_formats()['Argument'][1:])
EXPORT_FORMATS_TABLE = str(export_formats())
def gd_outputs(gd):
# TensorFlow GraphDef model output node names
name_list, input_list = [], []
for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
name_list.append(node.name)
input_list.extend(node.input)
return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
def try_export(inner_func):
# YOLOv8 export decorator, i.e. @try_export
inner_args = get_default_args(inner_func)
@ -164,10 +172,10 @@ class Exporter:
# Checks
model.names = check_class_names(model.names)
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
if model.task == 'classify':
self.args.nms = self.args.agnostic_nms = False
if self.args.optimize:
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
if edgetpu and not LINUX:
raise SystemError('Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/')
# Input
im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
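
A minimal usage sketch that exercises the checks above, assuming the public YOLO API and the standard yolov8n.pt asset.

from ultralytics import YOLO

model = YOLO('yolov8n.pt')
model.export(format='torchscript', optimize=True, device='cpu')  # optimize asserts a CPU device
model.export(format='edgetpu')  # raises SystemError on non-Linux hosts per the check above
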
@ -208,7 +216,7 @@ class Exporter:
self.file = file
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
description = f'Ultralytics {self.pretty_name} model' + f'trained on {Path(self.args.data).name}' \
description = f'Ultralytics {self.pretty_name} model ' + f'trained on {Path(self.args.data).name}' \
if self.args.data else '(untrained)'
self.metadata = {
'description': description,
@ -239,8 +247,7 @@ class Exporter:
'Please consider contributing to the effort if you have TF expertise. Thank you!')
nms = False
self.args.int8 |= edgetpu
f[5], s_model = self._export_saved_model(nms=nms or self.args.agnostic_nms or tfjs,
agnostic_nms=self.args.agnostic_nms or tfjs)
f[5], s_model = self._export_saved_model()
if pb or tfjs: # pb prerequisite to tfjs
f[6], _ = self._export_pb(s_model)
if tflite:
@ -386,7 +393,7 @@ class Exporter:
check_requirements('coremltools>=6.0')
import coremltools as ct # noqa
class iOSModel(torch.nn.Module):
class iOSDetectModel(torch.nn.Module):
# Wrap an Ultralytics YOLO model for iOS export
def __init__(self, model, im):
super().__init__()
@ -405,29 +412,36 @@ class Exporter:
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
f = self.file.with_suffix('.mlmodel')
bias = [0.0, 0.0, 0.0]
scale = 1 / 255
classifier_config = None
if self.model.task == 'classify':
bias = [-x for x in IMAGENET_MEAN]
scale = 1 / 255 / (sum(IMAGENET_STD) / 3)
classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
else:
bias = [0.0, 0.0, 0.0]
scale = 1 / 255
classifier_config = None
model = iOSModel(self.model, self.im).eval() if self.args.nms else self.model
ts = torch.jit.trace(model, self.im, strict=False) # TorchScript model
model = self.model
elif self.model.task == 'detect':
model = iOSDetectModel(self.model, self.im) if self.args.nms else self.model
elif self.model.task == 'segment':
# TODO CoreML Segmentation model pipelining
model = self.model
ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model
ct_model = ct.convert(ts,
inputs=[ct.ImageType('image', shape=self.im.shape, scale=scale, bias=bias)],
classifier_config=classifier_config)
bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
if bits < 32:
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
if self.args.nms:
if self.args.nms and self.model.task == 'detect':
ct_model = self._pipeline_coreml(ct_model)
ct_model.short_description = self.metadata['description']
ct_model.author = self.metadata['author']
ct_model.license = self.metadata['license']
ct_model.version = self.metadata['version']
m = self.metadata # metadata dict
ct_model.short_description = m['description']
ct_model.author = m['author']
ct_model.license = m['license']
ct_model.version = m['version']
ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items() if k in ('stride', 'task', 'names')})
ct_model.save(str(f))
return f, ct_model
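
A minimal sketch of reading the new metadata back from an exported model, assuming an existing yolov8n.mlmodel and the same coremltools dependency checked above.

import coremltools as ct

ct_model = ct.models.MLModel('yolov8n.mlmodel')  # path is illustrative
print(ct_model.short_description)                # metadata['description'] written during export
print(dict(ct_model.user_defined_metadata))      # 'stride', 'task' and 'names' stored as strings
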
@ -497,14 +511,7 @@ class Exporter:
return f, None
@try_export
def _export_saved_model(self,
nms=False,
agnostic_nms=False,
topk_per_class=100,
topk_all=100,
iou_thres=0.45,
conf_thres=0.25,
prefix=colorstr('TensorFlow SavedModel:')):
def _export_saved_model(self, prefix=colorstr('TensorFlow SavedModel:')):
# YOLOv8 TensorFlow SavedModel export
try:
@ -562,6 +569,9 @@ class Exporter:
@try_export
def _export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
# YOLOv8 TensorFlow Lite export
import tensorflow as tf # noqa
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
if self.args.int8:
f = saved_model / (self.file.stem + 'yolov8n_integer_quant.tflite') # fp32 in/out
@ -572,9 +582,6 @@ class Exporter:
return str(f), None # noqa
# OLD VERSION BELOW ---------------------------------------------------------------
import tensorflow as tf # noqa
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
batch_size, ch, *imgsz = list(self.im.shape) # BCHW
f = str(self.file).replace(self.file.suffix, '-fp16.tflite')
@ -619,7 +626,9 @@ class Exporter:
LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
for c in (
'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
# 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -', # errors
'wget --no-check-certificate -q -O - https://packages.cloud.google.com/apt/doc/apt-key.gpg | '
'sudo apt-key add -',
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | ' # no comma
'sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
'sudo apt-get update',
@ -639,30 +648,36 @@ class Exporter:
def _export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
# YOLOv8 TensorFlow.js export
check_requirements('tensorflowjs')
import tensorflow as tf
import tensorflowjs as tfjs # noqa
LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
f = str(self.file).replace(self.file.suffix, '_web_model') # js dir
f_pb = self.file.with_suffix('.pb') # *.pb path
f_json = Path(f) / 'model.json' # *.json path
cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
gd = tf.Graph().as_graph_def() # TF GraphDef
with open(f_pb, 'rb') as file:
gd.ParseFromString(file.read())
outputs = ','.join(gd_outputs(gd))
LOGGER.info(f'\n{prefix} output node names: {outputs}')
cmd = f'tensorflowjs_converter --input_format=tf_frozen_model --output_node_names={outputs} {f_pb} {f}'
subprocess.run(cmd.split(), check=True)
with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
subst = re.sub(
r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}}}',
r'{"outputs": {"Identity": {"name": "Identity"}, '
r'"Identity_1": {"name": "Identity_1"}, '
r'"Identity_2": {"name": "Identity_2"}, '
r'"Identity_3": {"name": "Identity_3"}}}',
f_json.read_text(),
)
j.write(subst)
# f_json = Path(f) / 'model.json' # *.json path
# with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
# subst = re.sub(
# r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
# r'"Identity.?.?": {"name": "Identity.?.?"}, '
# r'"Identity.?.?": {"name": "Identity.?.?"}, '
# r'"Identity.?.?": {"name": "Identity.?.?"}}}',
# r'{"outputs": {"Identity": {"name": "Identity"}, '
# r'"Identity_1": {"name": "Identity_1"}, '
# r'"Identity_2": {"name": "Identity_2"}, '
# r'"Identity_3": {"name": "Identity_3"}}}',
# f_json.read_text(),
# )
# j.write(subst)
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
return f, None
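
A minimal standalone sketch of the gd_outputs() helper introduced above, assuming a frozen yolov8n.pb from the GraphDef export step and the module path shown.

import tensorflow as tf
from ultralytics.yolo.engine.exporter import gd_outputs  # assumed module path

gd = tf.Graph().as_graph_def()       # empty GraphDef protobuf
with open('yolov8n.pb', 'rb') as f:  # frozen graph produced by _export_pb()
    gd.ParseFromString(f.read())
print(gd_outputs(gd))                # nodes never consumed as inputs, NoOp excluded, e.g. ['Identity:0']
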
@ -680,7 +695,7 @@ class Exporter:
model_meta.license = self.metadata['license']
# Label file
tmp_file = file.parent / 'temp_meta.txt'
tmp_file = Path(file).parent / 'temp_meta.txt'
with open(tmp_file, 'w') as f:
f.write(str(self.metadata))
@ -718,7 +733,7 @@ class Exporter:
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()
populator = _metadata.MetadataPopulator.with_model_file(file)
populator = _metadata.MetadataPopulator.with_model_file(str(file))
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files([str(tmp_file)])
populator.populate()
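
A minimal sketch of verifying the populated TFLite metadata, assuming the tflite_support package behind the _metadata alias used above and an exported yolov8n.tflite.

from tflite_support import metadata

displayer = metadata.MetadataDisplayer.with_model_file('yolov8n.tflite')  # path is illustrative
print(displayer.get_metadata_json())                 # description/author/license written above
print(displayer.get_packed_associated_file_list())   # should include the temp_meta.txt label file
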

View File

@ -2,7 +2,6 @@
import sys
from pathlib import Path
from typing import List
from ultralytics import yolo # noqa
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
@ -68,7 +67,7 @@ class YOLO:
list(ultralytics.yolo.engine.results.Results): The prediction results.
"""
def __init__(self, model='yolov8n.pt') -> None:
def __init__(self, model='yolov8n.pt', task=None) -> None:
"""
Initializes the YOLO model.
@ -91,9 +90,9 @@ class YOLO:
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
if suffix == '.yaml':
self._new(model)
self._new(model, task)
else:
self._load(model)
self._load(model, task)
def __call__(self, source=None, stream=False, **kwargs):
return self.predict(source, stream, **kwargs)
@ -102,17 +101,18 @@ class YOLO:
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def _new(self, cfg: str, verbose=True):
def _new(self, cfg: str, task=None, verbose=True):
"""
Initializes a new model and infers the task type from the model definitions.
Args:
cfg (str): model configuration file
task (str) or (None): model task
verbose (bool): display model info on load
"""
self.cfg = check_yaml(cfg) # check YAML
cfg_dict = yaml_load(self.cfg, append_filename=True) # model dict
self.task = guess_model_task(cfg_dict)
self.task = task or guess_model_task(cfg_dict)
self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model
self.overrides['model'] = self.cfg
@ -121,12 +121,13 @@ class YOLO:
self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
self.model.task = self.task
def _load(self, weights: str, task=''):
def _load(self, weights: str, task=None):
"""
Initializes a new model and infers the task type from the model head.
Args:
weights (str): model checkpoint to be loaded
task (str) or (None): model task
"""
suffix = Path(weights).suffix
if suffix == '.pt':
@ -137,7 +138,7 @@ class YOLO:
else:
weights = check_file(weights)
self.model, self.ckpt = weights, None
self.task = guess_model_task(weights)
self.task = task or guess_model_task(weights)
self.ckpt_path = weights
self.overrides['model'] = weights
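
A minimal usage sketch of the new task argument; the model files are the standard Ultralytics assets, and an explicit task simply bypasses guess_model_task().

from ultralytics import YOLO

model_a = YOLO('yolov8n.pt')                        # task still inferred from the checkpoint
model_b = YOLO('yolov8n-seg.yaml', task='segment')  # task passed through to _new() instead of being guessed
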

View File

@ -32,7 +32,6 @@ from collections import defaultdict
from pathlib import Path
import cv2
import torch
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.cfg import get_cfg

View File

@ -242,7 +242,7 @@ class BaseTrainer:
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
self.ema = ModelEMA(self.model)
if self.args.plots:
if self.args.plots and not self.args.v5loader:
self.plot_training_labels()
self.resume_training(ckpt)
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
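
A minimal usage sketch of the condition above, assuming the v5loader flag and the coco128.yaml dataset shipped with this release; with the v5 dataloader enabled, training-label plotting is now skipped.

from ultralytics import YOLO

YOLO('yolov8n.pt').train(data='coco128.yaml', epochs=1, v5loader=True)  # plot_training_labels() is not called
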