Check PyTorch model status for all YOLO
methods (#945)
Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
__version__ = "8.0.35"
|
||||
__version__ = "8.0.36"
|
||||
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.utils.checks import check_yolo as checks
|
||||
|
@ -5,12 +5,12 @@ import requests
|
||||
from ultralytics.hub.auth import Auth
|
||||
from ultralytics.hub.session import HubTrainingSession
|
||||
from ultralytics.hub.utils import split_key
|
||||
from ultralytics.yolo.engine.exporter import export_formats
|
||||
from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_LIST
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.utils import LOGGER, PREFIX, emojis
|
||||
|
||||
# Define all export formats
|
||||
EXPORT_FORMATS = list(export_formats()['Argument'][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
|
||||
EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ["ultralytics_tflite", "ultralytics_coreml"]
|
||||
|
||||
|
||||
def start(key=""):
|
||||
@ -69,7 +69,7 @@ def reset_model(key=""):
|
||||
|
||||
def export_model(key="", format="torchscript"):
|
||||
# Export a model to all formats
|
||||
assert format in EXPORT_FORMATS, f"Unsupported export format '{format}' passed, valid formats are {EXPORT_FORMATS}"
|
||||
assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
|
||||
api_key, model_id = split_key(key)
|
||||
r = requests.post("https://api.ultralytics.com/export",
|
||||
json={
|
||||
@ -82,7 +82,7 @@ def export_model(key="", format="torchscript"):
|
||||
|
||||
def get_export(key="", format="torchscript"):
|
||||
# Get an exported model dictionary with download URL
|
||||
assert format in EXPORT_FORMATS, f"Unsupported export format '{format}' passed, valid formats are {EXPORT_FORMATS}"
|
||||
assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
|
||||
api_key, model_id = split_key(key)
|
||||
r = requests.post("https://api.ultralytics.com/get-export",
|
||||
json={
|
||||
|
@ -193,7 +193,7 @@ class AutoBackend(nn.Module):
|
||||
from tflite_runtime.interpreter import Interpreter, load_delegate
|
||||
except ImportError:
|
||||
import tensorflow as tf
|
||||
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
|
||||
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
|
||||
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
|
||||
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
|
||||
delegate = {
|
||||
@ -232,8 +232,10 @@ class AutoBackend(nn.Module):
|
||||
nhwc = model.runtime.startswith("tensorflow")
|
||||
'''
|
||||
else:
|
||||
raise NotImplementedError(f"ERROR: '{w}' is not a supported format. For supported formats see "
|
||||
f"https://docs.ultralytics.com/reference/nn/")
|
||||
from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_TABLE
|
||||
raise TypeError(f"model='{w}' is not a supported model format. "
|
||||
"See https://docs.ultralytics.com/tasks/detection/#export for help."
|
||||
f"\n\n{EXPORT_FORMATS_TABLE}")
|
||||
|
||||
# class names
|
||||
if 'names' not in locals(): # names missing
|
||||
|
@ -356,7 +356,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
|
||||
# Model compatibility updates
|
||||
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
||||
model.args = args # attach args to model
|
||||
model.pt_path = weights # attach *.pt file path to model
|
||||
model.task = guess_model_task(model)
|
||||
if not hasattr(model, 'stride'):
|
||||
|
@ -12,8 +12,8 @@ from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_P
|
||||
IterableSimpleNamespace, __version__, checks, colorstr, yaml_load, yaml_print)
|
||||
|
||||
CLI_HELP_MSG = \
|
||||
"""
|
||||
YOLOv8 'yolo' CLI commands use the following syntax:
|
||||
f"""
|
||||
Arguments received: {str(['yolo'] + sys.argv[1:])}. Note that Ultralytics 'yolo' commands use the following syntax:
|
||||
|
||||
yolo TASK MODE ARGS
|
||||
|
||||
@ -64,9 +64,7 @@ CFG_BOOL_KEYS = {
|
||||
|
||||
def cfg2dict(cfg):
|
||||
"""
|
||||
Convert a configuration object to a dictionary.
|
||||
|
||||
This function converts a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
|
||||
Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
|
||||
|
||||
Inputs:
|
||||
cfg (str) or (Path) or (SimpleNamespace): Configuration object to be converted to a dictionary.
|
||||
@ -143,8 +141,9 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
|
||||
if mismatched:
|
||||
string = ''
|
||||
for x in mismatched:
|
||||
matches = get_close_matches(x, base)
|
||||
match_str = f"Similar arguments are {matches}." if matches else ''
|
||||
matches = get_close_matches(x, base) # key list
|
||||
matches = [f"{k}={DEFAULT_CFG_DICT[k]}" if DEFAULT_CFG_DICT[k] is not None else k for k in matches] # k=v
|
||||
match_str = f"Similar arguments are i.e. {matches}." if matches else ''
|
||||
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
|
||||
raise SyntaxError(string + CLI_HELP_MSG) from e
|
||||
|
||||
@ -265,7 +264,7 @@ def entrypoint(debug=''):
|
||||
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
|
||||
elif mode not in modes:
|
||||
if mode != 'checks':
|
||||
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {modes}.")
|
||||
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {modes}.\n{CLI_HELP_MSG}")
|
||||
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
|
||||
checks.check_yolo()
|
||||
return
|
||||
|
@ -682,7 +682,8 @@ def v8_transforms(dataset, imgsz, hyp):
|
||||
# Classification augmentations -----------------------------------------------------------------------------------------
|
||||
def classify_transforms(size=224):
|
||||
# Transforms to apply if albumentations not installed
|
||||
assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"
|
||||
if not isinstance(size, int):
|
||||
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
|
||||
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||
|
||||
|
@ -48,7 +48,6 @@ TensorFlow.js:
|
||||
$ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model
|
||||
$ npm start
|
||||
"""
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
@ -74,7 +73,7 @@ from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, __version__, callbacks,
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
||||
from ultralytics.yolo.utils.files import file_size
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode, get_latest_opset
|
||||
|
||||
MACOS = platform.system() == 'Darwin' # macOS environment
|
||||
|
||||
@ -97,6 +96,10 @@ def export_formats():
|
||||
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
||||
|
||||
|
||||
EXPORT_FORMATS_LIST = list(export_formats()['Argument'][1:])
|
||||
EXPORT_FORMATS_TABLE = str(export_formats())
|
||||
|
||||
|
||||
def try_export(inner_func):
|
||||
# YOLOv8 export decorator, i..e @try_export
|
||||
inner_args = get_default_args(inner_func)
|
||||
@ -244,7 +247,7 @@ class Exporter:
|
||||
agnostic_nms=self.args.agnostic_nms)
|
||||
if edgetpu:
|
||||
f[8], _ = self._export_edgetpu()
|
||||
self._add_tflite_metadata(f[8] or f[7], num_outputs=len(self.output_shape))
|
||||
self._add_tflite_metadata(f[8] or f[7])
|
||||
if tfjs:
|
||||
f[9], _ = self._export_tfjs()
|
||||
if paddle: # PaddlePaddle
|
||||
@ -253,11 +256,11 @@ class Exporter:
|
||||
# Finish
|
||||
f = [str(x) for x in f if x] # filter out '' and None
|
||||
if any(f):
|
||||
s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
|
||||
f = str(Path(f[-1]))
|
||||
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
|
||||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
||||
f"\nPredict: yolo task={model.task} mode=predict model={f[-1]} {s}"
|
||||
f"\nValidate: yolo task={model.task} mode=val model={f[-1]} {s}"
|
||||
f"\nPredict: yolo task={model.task} mode=predict model={f}"
|
||||
f"\nValidate: yolo task={model.task} mode=val model={f}"
|
||||
f"\nVisualize: https://netron.app")
|
||||
|
||||
self.run_callbacks("on_export_end")
|
||||
@ -304,7 +307,7 @@ class Exporter:
|
||||
self.im.cpu() if dynamic else self.im,
|
||||
f,
|
||||
verbose=False,
|
||||
opset_version=self.args.opset,
|
||||
opset_version=self.args.opset or get_latest_opset(),
|
||||
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
|
||||
input_names=['images'],
|
||||
output_names=output_names,
|
||||
@ -507,6 +510,10 @@ class Exporter:
|
||||
# Export to TF SavedModel
|
||||
subprocess.run(f'onnx2tf -i {onnx} --output_signaturedefs -o {f}', shell=True)
|
||||
|
||||
# Add TFLite metadata
|
||||
for tflite_file in Path(f).rglob('*.tflite'):
|
||||
self._add_tflite_metadata(tflite_file)
|
||||
|
||||
# Load saved_model
|
||||
keras_model = tf.saved_model.load(f, tags=None, options=None)
|
||||
|
||||
@ -661,44 +668,47 @@ class Exporter:
|
||||
r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
|
||||
r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
||||
r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
||||
r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
|
||||
r'"Identity.?.?": {"name": "Identity.?.?"}}}',
|
||||
r'{"outputs": {"Identity": {"name": "Identity"}, '
|
||||
r'"Identity_1": {"name": "Identity_1"}, '
|
||||
r'"Identity_2": {"name": "Identity_2"}, '
|
||||
r'"Identity_3": {"name": "Identity_3"}}}', f_json.read_text())
|
||||
r'"Identity_3": {"name": "Identity_3"}}}',
|
||||
f_json.read_text(),
|
||||
)
|
||||
j.write(subst)
|
||||
return f, None
|
||||
|
||||
def _add_tflite_metadata(self, file, num_outputs):
|
||||
def _add_tflite_metadata(self, file):
|
||||
# Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
|
||||
with contextlib.suppress(ImportError):
|
||||
# check_requirements('tflite_support')
|
||||
from tflite_support import flatbuffers # noqa
|
||||
from tflite_support import metadata as _metadata # noqa
|
||||
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
|
||||
check_requirements('tflite_support')
|
||||
|
||||
tmp_file = Path('/tmp/meta.txt')
|
||||
with open(tmp_file, 'w') as meta_f:
|
||||
meta_f.write(str(self.metadata))
|
||||
from tflite_support import flatbuffers # noqa
|
||||
from tflite_support import metadata as _metadata # noqa
|
||||
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
|
||||
|
||||
model_meta = _metadata_fb.ModelMetadataT()
|
||||
label_file = _metadata_fb.AssociatedFileT()
|
||||
label_file.name = tmp_file.name
|
||||
model_meta.associatedFiles = [label_file]
|
||||
tmp_file = Path('/tmp/meta.txt')
|
||||
with open(tmp_file, 'w') as meta_f:
|
||||
meta_f.write(str(self.metadata))
|
||||
|
||||
subgraph = _metadata_fb.SubGraphMetadataT()
|
||||
subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
|
||||
subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
|
||||
model_meta.subgraphMetadata = [subgraph]
|
||||
model_meta = _metadata_fb.ModelMetadataT()
|
||||
label_file = _metadata_fb.AssociatedFileT()
|
||||
label_file.name = tmp_file.name
|
||||
model_meta.associatedFiles = [label_file]
|
||||
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
metadata_buf = b.Output()
|
||||
subgraph = _metadata_fb.SubGraphMetadataT()
|
||||
subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
|
||||
subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * len(self.output_shape)
|
||||
model_meta.subgraphMetadata = [subgraph]
|
||||
|
||||
populator = _metadata.MetadataPopulator.with_model_file(file)
|
||||
populator.load_metadata_buffer(metadata_buf)
|
||||
populator.load_associated_files([str(tmp_file)])
|
||||
populator.populate()
|
||||
tmp_file.unlink()
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
metadata_buf = b.Output()
|
||||
|
||||
populator = _metadata.MetadataPopulator.with_model_file(file)
|
||||
populator.load_metadata_buffer(metadata_buf)
|
||||
populator.load_associated_files([str(tmp_file)])
|
||||
populator.populate()
|
||||
tmp_file.unlink()
|
||||
|
||||
def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
|
||||
# YOLOv8 CoreML pipeline
|
||||
|
@ -6,11 +6,11 @@ from typing import List
|
||||
|
||||
from ultralytics import yolo # noqa
|
||||
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
|
||||
guess_model_task)
|
||||
guess_model_task, nn)
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, callbacks, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
|
||||
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
|
||||
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||
|
||||
@ -55,19 +55,16 @@ class YOLO:
|
||||
self.cfg = None # if loaded from *.yaml
|
||||
self.ckpt_path = None
|
||||
self.overrides = {} # overrides for trainer object
|
||||
self.metrics_data = None
|
||||
|
||||
# Load or create new YOLO model
|
||||
suffix = Path(model).suffix
|
||||
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
|
||||
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
try:
|
||||
if suffix == '.yaml':
|
||||
self._new(model)
|
||||
else:
|
||||
self._load(model)
|
||||
except Exception as e:
|
||||
raise NotImplementedError(f"Unable to load model='{model}'. "
|
||||
f"As an example try model='yolov8n.pt' or model='yolov8n.yaml'") from e
|
||||
if suffix == '.yaml':
|
||||
self._new(model)
|
||||
else:
|
||||
self._load(model)
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
return self.predict(source, stream, **kwargs)
|
||||
@ -100,15 +97,27 @@ class YOLO:
|
||||
self.overrides = self.model.args
|
||||
self._reset_ckpt_args(self.overrides)
|
||||
else:
|
||||
check_file(weights)
|
||||
self.model, self.ckpt = weights, None
|
||||
self.task = guess_model_task(weights)
|
||||
self.ckpt_path = weights
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
|
||||
|
||||
def _check_is_pytorch_model(self):
|
||||
"""
|
||||
Raises TypeError is model is not a PyTorch model
|
||||
"""
|
||||
if not isinstance(self.model, nn.Module):
|
||||
raise TypeError(f"model='{self.model}' must be a PyTorch model, but is a different type. PyTorch models "
|
||||
f"can be used to train, val, predict and export, i.e. "
|
||||
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
|
||||
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the model modules.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
@ -122,9 +131,11 @@ class YOLO:
|
||||
Args:
|
||||
verbose (bool): Controls verbosity.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
self.model.info(verbose=verbose)
|
||||
|
||||
def fuse(self):
|
||||
self._check_is_pytorch_model()
|
||||
self.model.fuse()
|
||||
|
||||
def predict(self, source=None, stream=False, **kwargs):
|
||||
@ -176,6 +187,8 @@ class YOLO:
|
||||
|
||||
validator = self.ValidatorClass(args=args)
|
||||
validator(model=self.model)
|
||||
self.metrics_data = validator.metrics
|
||||
|
||||
return validator.metrics
|
||||
|
||||
@smart_inference_mode()
|
||||
@ -186,7 +199,7 @@ class YOLO:
|
||||
Args:
|
||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
||||
"""
|
||||
|
||||
self._check_is_pytorch_model()
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
@ -196,7 +209,7 @@ class YOLO:
|
||||
if args.batch == DEFAULT_CFG.batch:
|
||||
args.batch = 1 # default to 1 if not modified
|
||||
exporter = Exporter(overrides=args)
|
||||
exporter(model=self.model)
|
||||
return exporter(model=self.model)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
@ -205,6 +218,7 @@ class YOLO:
|
||||
Args:
|
||||
**kwargs (Any): Any number of arguments representing the training configuration.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
if kwargs.get("cfg"):
|
||||
@ -226,6 +240,7 @@ class YOLO:
|
||||
if RANK in {0, -1}:
|
||||
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
|
||||
self.overrides = self.model.args
|
||||
self.metrics_data = self.trainer.validator.metrics
|
||||
|
||||
def to(self, device):
|
||||
"""
|
||||
@ -234,15 +249,14 @@ class YOLO:
|
||||
Args:
|
||||
device (str): device
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
self.model.to(device)
|
||||
|
||||
def _assign_ops_from_task(self):
|
||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
|
||||
# warning: eval is unsafe. Use with caution
|
||||
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
|
||||
validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
|
||||
predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))
|
||||
|
||||
return model_class, trainer_class, validator_class, predictor_class
|
||||
|
||||
@property
|
||||
@ -250,7 +264,7 @@ class YOLO:
|
||||
"""
|
||||
Returns class names of the loaded model.
|
||||
"""
|
||||
return self.model.names
|
||||
return self.model.names if hasattr(self.model, 'names') else None
|
||||
|
||||
@property
|
||||
def transforms(self):
|
||||
@ -259,6 +273,16 @@ class YOLO:
|
||||
"""
|
||||
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||
|
||||
@property
|
||||
def metrics(self):
|
||||
"""
|
||||
Returns metrics if computed
|
||||
"""
|
||||
if not self.metrics_data:
|
||||
LOGGER.info("No metrics data found! Run training or validation operation first.")
|
||||
|
||||
return self.metrics_data
|
||||
|
||||
@staticmethod
|
||||
def add_callback(event: str, func):
|
||||
"""
|
||||
@ -269,5 +293,5 @@ class YOLO:
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
|
||||
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots':
|
||||
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset':
|
||||
args.pop(arg, None)
|
||||
|
@ -35,6 +35,7 @@ import torch
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data import load_inference_source
|
||||
from ultralytics.yolo.data.augment import classify_transforms
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
@ -121,8 +122,12 @@ class BasePredictor:
|
||||
|
||||
def setup_source(self, source):
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
||||
if self.args.task == 'classify':
|
||||
transforms = getattr(self.model.model, 'transforms', classify_transforms(self.imgsz[0]))
|
||||
else: # predict, segment
|
||||
transforms = None
|
||||
self.dataset = load_inference_source(source=source,
|
||||
transforms=getattr(self.model.model, 'transforms', None),
|
||||
transforms=transforms,
|
||||
imgsz=self.imgsz,
|
||||
vid_stride=self.args.vid_stride,
|
||||
stride=self.model.stride,
|
||||
|
@ -217,19 +217,18 @@ class BaseTrainer:
|
||||
|
||||
# Optimizer
|
||||
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
|
||||
self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
||||
weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
||||
self.optimizer = self.build_optimizer(model=self.model,
|
||||
name=self.args.optimizer,
|
||||
lr=self.args.lr0,
|
||||
momentum=self.args.momentum,
|
||||
decay=self.args.weight_decay)
|
||||
decay=weight_decay)
|
||||
# Scheduler
|
||||
if self.args.cos_lr:
|
||||
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
|
||||
else:
|
||||
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
||||
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
||||
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
|
||||
|
||||
# dataloaders
|
||||
@ -242,6 +241,7 @@ class BaseTrainer:
|
||||
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
||||
self.ema = ModelEMA(self.model)
|
||||
self.resume_training(ckpt)
|
||||
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
||||
self.run_callbacks("on_pretrain_routine_end")
|
||||
|
||||
def _do_train(self, rank=-1, world_size=1):
|
||||
@ -555,6 +555,12 @@ class BaseTrainer:
|
||||
self.epochs += ckpt['epoch'] # finetune additional epochs
|
||||
self.best_fitness = best_fitness
|
||||
self.start_epoch = start_epoch
|
||||
if start_epoch > (self.epochs - self.args.close_mosaic):
|
||||
self.console.info("Closing dataloader mosaic")
|
||||
if hasattr(self.train_loader.dataset, 'mosaic'):
|
||||
self.train_loader.dataset.mosaic = False
|
||||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||
|
||||
@staticmethod
|
||||
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
||||
|
@ -234,17 +234,17 @@ def check_yolov5u_filename(file: str):
|
||||
return file
|
||||
|
||||
|
||||
def check_file(file, suffix=''):
|
||||
def check_file(file, suffix='', download=True):
|
||||
# Search/download file (if necessary) and return path
|
||||
check_suffix(file, suffix) # optional
|
||||
file = str(file) # convert to string
|
||||
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
|
||||
if not file or ('://' not in file and Path(file).is_file()): # exists ('://' check required in Windows Python<3.10)
|
||||
if not file or ('://' not in file and Path(file).exists()): # exists ('://' check required in Windows Python<3.10)
|
||||
return file
|
||||
elif file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')): # download
|
||||
elif download and file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')): # download
|
||||
url = file # warning: Pathlib turns :// -> :/
|
||||
file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
|
||||
if Path(file).is_file():
|
||||
if Path(file).exists():
|
||||
LOGGER.info(f'Found {url} locally at {file}') # file already exists
|
||||
else:
|
||||
downloads.safe_download(url=url, file=file, unzip=False)
|
||||
|
@ -44,11 +44,17 @@ def generate_ddp_file(trainer):
|
||||
|
||||
def generate_ddp_command(world_size, trainer):
|
||||
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
||||
file = generate_ddp_file(trainer) if sys.argv[0].endswith('yolo') else os.path.abspath(sys.argv[0])
|
||||
|
||||
# Get file and args (do not use sys.argv due to security vulnerability)
|
||||
exclude_args = ['save_dir']
|
||||
args = [f"{k}={v}" for k, v in vars(trainer.args).items() if k not in exclude_args]
|
||||
file = generate_ddp_file(trainer) # if argv[0].endswith('yolo') else os.path.abspath(argv[0])
|
||||
|
||||
# Build command
|
||||
torch_distributed_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
|
||||
cmd = [
|
||||
sys.executable, "-m", torch_distributed_cmd, "--nproc_per_node", f"{world_size}", "--master_port",
|
||||
f"{find_free_network_port()}", file] + sys.argv[1:]
|
||||
f"{find_free_network_port()}", file] + args
|
||||
return cmd, file
|
||||
|
||||
|
||||
|
@ -242,6 +242,11 @@ def copy_attr(a, b, include=(), exclude=()):
|
||||
setattr(a, k, v)
|
||||
|
||||
|
||||
def get_latest_opset():
|
||||
# Return max supported ONNX opset by this version of torch
|
||||
return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) # opset
|
||||
|
||||
|
||||
def intersect_dicts(da, db, exclude=()):
|
||||
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
|
||||
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
|
||||
|
Reference in New Issue
Block a user