ultralytics 8.0.65
YOLOv8 Pose models (#1347)
Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mert Can Demir <validatedev@gmail.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Fabian Greavu <fabiangreavu@gmail.com> Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com> Co-authored-by: Eric Pedley <ericpedley@gmail.com> Co-authored-by: JustasBart <40023722+JustasBart@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Aarni Koskela <akx@iki.fi> Co-authored-by: Sergio Sanchez <sergio.ssm.97@gmail.com> Co-authored-by: Bogdan Gheorghe <112427971+bogdan-galileo@users.noreply.github.com> Co-authored-by: Jaap van de Loosdrecht <jaap@vdlmv.nl> Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com> Co-authored-by: nerdyespresso <106761627+nerdyespresso@users.noreply.github.com> Co-authored-by: Farid Inawan <frdteknikelektro@gmail.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Alexander Duda <Alexander.Duda@me.com> Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com> Co-authored-by: Snyk bot <snyk-bot@snyk.io> Co-authored-by: majid nasiri <majnasai@gmail.com>
This commit is contained in:
@ -209,8 +209,8 @@ class Exporter:
|
||||
self.file = file
|
||||
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
|
||||
self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
|
||||
description = f'Ultralytics {self.pretty_name} model ' + f'trained on {Path(self.args.data).name}' \
|
||||
if self.args.data else '(untrained)'
|
||||
trained_on = f'trained on {Path(self.args.data).name}' if self.args.data else '(untrained)'
|
||||
description = f'Ultralytics {self.pretty_name} model {trained_on}'
|
||||
self.metadata = {
|
||||
'description': description,
|
||||
'author': 'Ultralytics',
|
||||
@ -221,6 +221,8 @@ class Exporter:
|
||||
'batch': self.args.batch,
|
||||
'imgsz': self.imgsz,
|
||||
'names': model.names} # model metadata
|
||||
if model.task == 'pose':
|
||||
self.metadata['kpt_shape'] = model.kpt_shape
|
||||
|
||||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
|
||||
f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')
|
||||
@ -295,7 +297,8 @@ class Exporter:
|
||||
check_requirements(requirements)
|
||||
import onnx # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
|
||||
opset_version = self.args.opset or get_latest_opset()
|
||||
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...')
|
||||
f = str(self.file.with_suffix('.onnx'))
|
||||
|
||||
output_names = ['output0', 'output1'] if isinstance(self.model, SegmentationModel) else ['output0']
|
||||
@ -313,7 +316,7 @@ class Exporter:
|
||||
self.im.cpu() if dynamic else self.im,
|
||||
f,
|
||||
verbose=False,
|
||||
opset_version=self.args.opset or get_latest_opset(),
|
||||
opset_version=opset_version,
|
||||
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
|
||||
input_names=['images'],
|
||||
output_names=output_names,
|
||||
@ -377,7 +380,6 @@ class Exporter:
|
||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def _export_coreml(self, prefix=colorstr('CoreML:')):
|
||||
# YOLOv8 CoreML export
|
||||
check_requirements('coremltools>=6.0')
|
||||
@ -410,8 +412,8 @@ class Exporter:
|
||||
model = self.model
|
||||
elif self.model.task == 'detect':
|
||||
model = iOSDetectModel(self.model, self.im) if self.args.nms else self.model
|
||||
elif self.model.task == 'segment':
|
||||
# TODO CoreML Segmentation model pipelining
|
||||
else:
|
||||
# TODO CoreML Segment and Pose model pipelining
|
||||
model = self.model
|
||||
|
||||
ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model
|
||||
|
@ -5,8 +5,8 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from ultralytics import yolo # noqa
|
||||
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
|
||||
guess_model_task, nn, yaml_model_load)
|
||||
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel,
|
||||
attempt_load_one_weight, guess_model_task, nn, yaml_model_load)
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
|
||||
@ -25,7 +25,8 @@ TASK_MAP = {
|
||||
yolo.v8.detect.DetectionPredictor],
|
||||
'segment': [
|
||||
SegmentationModel, yolo.v8.segment.SegmentationTrainer, yolo.v8.segment.SegmentationValidator,
|
||||
yolo.v8.segment.SegmentationPredictor]}
|
||||
yolo.v8.segment.SegmentationPredictor],
|
||||
'pose': [PoseModel, yolo.v8.pose.PoseTrainer, yolo.v8.pose.PoseValidator, yolo.v8.pose.PosePredictor]}
|
||||
|
||||
|
||||
class YOLO:
|
||||
@ -195,7 +196,7 @@ class YOLO:
|
||||
self.model.load(weights)
|
||||
return self
|
||||
|
||||
def info(self, verbose=False):
|
||||
def info(self, verbose=True):
|
||||
"""
|
||||
Logs model info.
|
||||
|
||||
|
@ -246,6 +246,7 @@ class BasePredictor:
|
||||
dnn=self.args.dnn,
|
||||
data=self.args.data,
|
||||
fp16=self.args.half,
|
||||
fuse=True,
|
||||
verbose=verbose)
|
||||
self.device = device
|
||||
self.model.eval()
|
||||
|
@ -17,6 +17,53 @@ from ultralytics.yolo.utils.plotting import Annotator, colors
|
||||
from ultralytics.yolo.utils.torch_utils import TORCHVISION_0_10
|
||||
|
||||
|
||||
class BaseTensor(SimpleClass):
|
||||
"""
|
||||
|
||||
Attributes:
|
||||
tensor (torch.Tensor): A tensor.
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Methods:
|
||||
cpu(): Returns a copy of the tensor on CPU memory.
|
||||
numpy(): Returns a copy of the tensor as a numpy array.
|
||||
cuda(): Returns a copy of the tensor on GPU memory.
|
||||
to(): Returns a copy of the tensor with the specified device and dtype.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor, orig_shape) -> None:
|
||||
super().__init__()
|
||||
assert isinstance(tensor, torch.Tensor)
|
||||
self.tensor = tensor
|
||||
self.orig_shape = orig_shape
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.data.shape
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self.tensor
|
||||
|
||||
def cpu(self):
|
||||
return self.__class__(self.data.cpu(), self.orig_shape)
|
||||
|
||||
def numpy(self):
|
||||
return self.__class__(self.data.numpy(), self.orig_shape)
|
||||
|
||||
def cuda(self):
|
||||
return self.__class__(self.data.cuda(), self.orig_shape)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
return self.__class__(self.data.to(*args, **kwargs), self.orig_shape)
|
||||
|
||||
def __len__(self): # override len(results)
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.__class__(self.data[idx], self.orig_shape)
|
||||
|
||||
|
||||
class Results(SimpleClass):
|
||||
"""
|
||||
A class for storing and manipulating inference results.
|
||||
@ -40,22 +87,23 @@ class Results(SimpleClass):
|
||||
_keys (tuple): A tuple of attribute names for non-empty attributes.
|
||||
"""
|
||||
|
||||
def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None) -> None:
|
||||
def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None) -> None:
|
||||
self.orig_img = orig_img
|
||||
self.orig_shape = orig_img.shape[:2]
|
||||
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
|
||||
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
|
||||
self.probs = probs if probs is not None else None
|
||||
self.keypoints = keypoints if keypoints is not None else None
|
||||
self.names = names
|
||||
self.path = path
|
||||
self._keys = ('boxes', 'masks', 'probs')
|
||||
self._keys = ('boxes', 'masks', 'probs', 'keypoints')
|
||||
|
||||
def pandas(self):
|
||||
pass
|
||||
# TODO masks.pandas + boxes.pandas + cls.pandas
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k)[idx])
|
||||
return r
|
||||
@ -69,25 +117,25 @@ class Results(SimpleClass):
|
||||
self.probs = probs
|
||||
|
||||
def cpu(self):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).cpu())
|
||||
return r
|
||||
|
||||
def numpy(self):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).numpy())
|
||||
return r
|
||||
|
||||
def cuda(self):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).cuda())
|
||||
return r
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).to(*args, **kwargs))
|
||||
return r
|
||||
@ -96,6 +144,9 @@ class Results(SimpleClass):
|
||||
for k in self.keys:
|
||||
return len(getattr(self, k))
|
||||
|
||||
def new(self):
|
||||
return Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
return [k for k in self._keys if getattr(self, k) is not None]
|
||||
@ -109,6 +160,7 @@ class Results(SimpleClass):
|
||||
pil=False,
|
||||
example='abc',
|
||||
img=None,
|
||||
kpt_line=True,
|
||||
labels=True,
|
||||
boxes=True,
|
||||
masks=True,
|
||||
@ -126,6 +178,7 @@ class Results(SimpleClass):
|
||||
pil (bool): Whether to return the image as a PIL Image.
|
||||
example (str): An example string to display. Useful for indicating the expected format of the output.
|
||||
img (numpy.ndarray): Plot to another image. if not, plot to original image.
|
||||
kpt_line (bool): Whether to draw lines connecting keypoints.
|
||||
labels (bool): Whether to plot the label of bounding boxes.
|
||||
boxes (bool): Whether to plot the bounding boxes.
|
||||
masks (bool): Whether to plot the masks.
|
||||
@ -146,11 +199,12 @@ class Results(SimpleClass):
|
||||
pred_masks, show_masks = self.masks, masks
|
||||
pred_probs, show_probs = self.probs, probs
|
||||
names = self.names
|
||||
keypoints = self.keypoints
|
||||
if pred_boxes and show_boxes:
|
||||
for d in reversed(pred_boxes):
|
||||
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
|
||||
name = ('' if id is None else f'id:{id} ') + names[c]
|
||||
label = (name if not conf else f'{name} {conf:.2f}') if labels else None
|
||||
label = (f'{name} {conf:.2f}' if conf else name) if labels else None
|
||||
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||
|
||||
if pred_masks and show_masks:
|
||||
@ -168,10 +222,14 @@ class Results(SimpleClass):
|
||||
text = f"{', '.join(f'{names[j] if names else j} {pred_probs[j]:.2f}' for j in top5i)}, "
|
||||
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
|
||||
|
||||
if keypoints is not None:
|
||||
for k in reversed(keypoints):
|
||||
annotator.kpts(k, self.orig_shape, kpt_line=kpt_line)
|
||||
|
||||
return np.asarray(annotator.im) if annotator.pil else annotator.im
|
||||
|
||||
|
||||
class Boxes(SimpleClass):
|
||||
class Boxes(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating detection boxes.
|
||||
|
||||
@ -246,37 +304,15 @@ class Boxes(SimpleClass):
|
||||
def xywhn(self):
|
||||
return self.xywh / self.orig_shape[[1, 0, 1, 0]]
|
||||
|
||||
def cpu(self):
|
||||
return Boxes(self.boxes.cpu(), self.orig_shape)
|
||||
|
||||
def numpy(self):
|
||||
return Boxes(self.boxes.numpy(), self.orig_shape)
|
||||
|
||||
def cuda(self):
|
||||
return Boxes(self.boxes.cuda(), self.orig_shape)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
return Boxes(self.boxes.to(*args, **kwargs), self.orig_shape)
|
||||
|
||||
def pandas(self):
|
||||
LOGGER.info('results.pandas() method not yet implemented')
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.boxes.shape
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self.boxes
|
||||
|
||||
def __len__(self): # override len(results)
|
||||
return len(self.boxes)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return Boxes(self.boxes[idx], self.orig_shape)
|
||||
|
||||
|
||||
class Masks(SimpleClass):
|
||||
class Masks(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating detection masks.
|
||||
|
||||
@ -316,7 +352,7 @@ class Masks(SimpleClass):
|
||||
def xyn(self):
|
||||
# Segments (normalized)
|
||||
return [
|
||||
ops.scale_segments(self.masks.shape[1:], x, self.orig_shape, normalize=True)
|
||||
ops.scale_coords(self.masks.shape[1:], x, self.orig_shape, normalize=True)
|
||||
for x in ops.masks2segments(self.masks)]
|
||||
|
||||
@property
|
||||
@ -324,31 +360,9 @@ class Masks(SimpleClass):
|
||||
def xy(self):
|
||||
# Segments (pixels)
|
||||
return [
|
||||
ops.scale_segments(self.masks.shape[1:], x, self.orig_shape, normalize=False)
|
||||
ops.scale_coords(self.masks.shape[1:], x, self.orig_shape, normalize=False)
|
||||
for x in ops.masks2segments(self.masks)]
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.masks.shape
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self.masks
|
||||
|
||||
def cpu(self):
|
||||
return Masks(self.masks.cpu(), self.orig_shape)
|
||||
|
||||
def numpy(self):
|
||||
return Masks(self.masks.numpy(), self.orig_shape)
|
||||
|
||||
def cuda(self):
|
||||
return Masks(self.masks.cuda(), self.orig_shape)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
return Masks(self.masks.to(*args, **kwargs), self.orig_shape)
|
||||
|
||||
def __len__(self): # override len(results)
|
||||
return len(self.masks)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return Masks(self.masks[idx], self.orig_shape)
|
||||
|
Reference in New Issue
Block a user