ultralytics 8.0.54
TFLite export improvements and fixes (#1447)
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -188,7 +188,7 @@ class Exporter:
|
||||
m.dynamic = self.args.dynamic
|
||||
m.export = True
|
||||
m.format = self.args.format
|
||||
elif isinstance(m, C2f) and not edgetpu:
|
||||
elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
|
||||
# EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
|
||||
m.forward = m.forward_split
|
||||
|
||||
|
@ -8,8 +8,8 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
|
||||
guess_model_task, nn)
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, ONLINE, RANK, ROOT,
|
||||
callbacks, is_git_dir, is_pip_package, yaml_load)
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
|
||||
is_git_dir, yaml_load)
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
|
||||
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
|
||||
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||
@ -153,16 +153,10 @@ class YOLO:
|
||||
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
|
||||
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
|
||||
|
||||
def _check_pip_update(self):
|
||||
@smart_inference_mode()
|
||||
def reset_weights(self):
|
||||
"""
|
||||
Inform user of ultralytics package update availability
|
||||
"""
|
||||
if ONLINE and is_pip_package():
|
||||
check_pip_update_available()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the model modules.
|
||||
Resets the model modules parameters to randomly initialized values, losing all training information.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
for m in self.model.modules():
|
||||
@ -170,6 +164,18 @@ class YOLO:
|
||||
m.reset_parameters()
|
||||
for p in self.model.parameters():
|
||||
p.requires_grad = True
|
||||
return self
|
||||
|
||||
@smart_inference_mode()
|
||||
def load(self, weights='yolov8n.pt'):
|
||||
"""
|
||||
Transfers parameters with matching names and shapes from 'weights' to model.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
if isinstance(weights, (str, Path)):
|
||||
weights, self.ckpt = attempt_load_one_weight(weights)
|
||||
self.model.load(weights)
|
||||
return self
|
||||
|
||||
def info(self, verbose=False):
|
||||
"""
|
||||
@ -299,7 +305,7 @@ class YOLO:
|
||||
**kwargs (Any): Any number of arguments representing the training configuration.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
self._check_pip_update()
|
||||
check_pip_update_available()
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
if kwargs.get('cfg'):
|
||||
|
@ -48,7 +48,7 @@ class Results:
|
||||
self.probs = probs if probs is not None else None
|
||||
self.names = names
|
||||
self.path = path
|
||||
self._keys = [k for k in ('boxes', 'masks', 'probs') if getattr(self, k) is not None]
|
||||
self._keys = ('boxes', 'masks', 'probs')
|
||||
|
||||
def pandas(self):
|
||||
pass
|
||||
@ -56,7 +56,7 @@ class Results:
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for k in self._keys:
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k)[idx])
|
||||
return r
|
||||
|
||||
@ -70,30 +70,30 @@ class Results:
|
||||
|
||||
def cpu(self):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for k in self._keys:
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).cpu())
|
||||
return r
|
||||
|
||||
def numpy(self):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for k in self._keys:
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).numpy())
|
||||
return r
|
||||
|
||||
def cuda(self):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for k in self._keys:
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).cuda())
|
||||
return r
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for k in self._keys:
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).to(*args, **kwargs))
|
||||
return r
|
||||
|
||||
def __len__(self):
|
||||
for k in self._keys:
|
||||
for k in self.keys:
|
||||
return len(getattr(self, k))
|
||||
|
||||
def __str__(self):
|
||||
@ -107,6 +107,10 @@ class Results:
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
return [k for k in self._keys if getattr(self, k) is not None]
|
||||
|
||||
def plot(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
||||
"""
|
||||
Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
|
||||
|
Reference in New Issue
Block a user