Add CoreML iOS updates (#121)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent fec13ec773
commit c9f3e469cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,7 +25,7 @@ pandas>=1.1.4
seaborn>=0.11.0
# Export --------------------------------------
# coremltools>=5.2 # CoreML export
# coremltools>=6.0 # CoreML export
# onnx>=1.12.0 # ONNX export
# onnx-simplifier>=0.4.1 # ONNX simplifier
# nvidia-pyindex # TensorRT export

@ -89,7 +89,7 @@ class DetectionModel(BaseModel):
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
self.yaml['nc'] = nc # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch], verbose=verbose) # model, savelist
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
self.inplace = self.yaml.get('inplace', True)
# Build strides

@ -73,6 +73,7 @@ dynamic: False # ONNX/TF/TensorRT: dynamic axes
simplify: False # ONNX: simplify model
opset: 17 # ONNX: opset version
workspace: 4 # TensorRT: workspace size (GB)
nms: False # CoreML: add NMS
# Hyperparameters ------------------------------------------------------------------------------------------------------
lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)

@ -64,6 +64,7 @@ import numpy as np
import pandas as pd
import torch
import ultralytics
from ultralytics.nn.modules import Detect, Segment
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
from ultralytics.yolo.configs import get_config
@ -73,7 +74,7 @@ from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, get_default
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
from ultralytics.yolo.utils.files import file_size, increment_path, yaml_save
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, select_device, smart_inference_mode
MACOS = platform.system() == 'Darwin' # macOS environment
@ -119,7 +120,7 @@ class Exporter:
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
self.args = get_config(config, overrides)
project = self.args.project or f"runs/{self.args.task}"
name = self.args.name or f"{self.args.mode}"
name = self.args.name or "exp" # hardcode mode as export doesn't require it
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
self.save_dir.mkdir(parents=True, exist_ok=True)
self.imgsz = self.args.imgsz
@ -136,22 +137,20 @@ class Exporter:
# Load PyTorch model
self.device = select_device(self.args.device)
if self.args.half:
assert self.device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'
if self.device.type == 'cpu' or not coreml:
LOGGER.info('half=True only compatible with GPU or CoreML export, i.e. use device=0 or format=coreml')
self.args.half = False
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
# Checks
if isinstance(self.imgsz, int):
self.imgsz = [self.imgsz]
self.imgsz *= 2 if len(self.imgsz) == 1 else 1 # expand
self.imgsz = check_imgsz(self.imgsz, stride=model.stride, min_dim=2) # check image size
if self.args.optimize:
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
# Input
self.args.batch_size = 1 # TODO: resolve this issue, default 16 not fit for export
gs = int(max(model.stride)) # grid size (max stride)
imgsz = [check_imgsz(x, gs) for x in self.imgsz] # verify img_size are gs-multiples
im = torch.zeros(self.args.batch_size, 3, *imgsz).to(self.device) # image size(1,3,320,192) BCHW iDetection
file = Path(Path(model.yaml['yaml_file']).name)
im = torch.zeros(self.args.batch_size, 3, *self.imgsz).to(self.device)
file = Path(getattr(model, 'yaml_file', None) or Path(model.yaml['yaml_file']).name)
# Update model
model = deepcopy(model)
@ -182,7 +181,9 @@ class Exporter:
self.im = im
self.model = model
self.file = file
self.output_shape = tuple(y.shape)
self.metadata = {'stride': int(max(model.stride)), 'names': model.names} # model metadata
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
# Exports
f = [''] * len(fmts) # exported filenames
@ -202,7 +203,7 @@ class Exporter:
f[5], s_model = self._export_saved_model(nms=nms or self.args.agnostic_nms or tfjs,
agnostic_nms=self.args.agnostic_nms or tfjs)
if pb or tfjs: # pb prerequisite to tfjs
f[6], _ = self._export_pb(s_model,)
f[6], _ = self._export_pb(s_model)
if tflite or edgetpu:
f[7], _ = self._export_tflite(s_model,
int8=self.args.int8 or edgetpu,
@ -220,11 +221,8 @@ class Exporter:
# Finish
f = [str(x) for x in f if x] # filter out '' and None
if any(f):
cls, det, seg = (isinstance(model, x)
for x in (ClassificationModel, DetectionModel, SegmentationModel)) # type
det &= not seg # segmentation models inherit from SegmentationModel(DetectionModel)
task = guess_task_from_head(model.yaml["head"][-1][-2])
s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
task = 'detect' if det else 'segment' if seg else 'classify' if cls else ''
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f"\nPredict: yolo task={task} mode=predict model={f[-1]} {s}"
@ -337,13 +335,30 @@ class Exporter:
@try_export
def _export_coreml(self, prefix=colorstr('CoreML:')):
# YOLOv5 CoreML export
check_requirements('coremltools')
check_requirements('coremltools>=6.0')
import coremltools as ct # noqa
class iOSModel(torch.nn.Module):
# Wrap an Ultralytics YOLO model for iOS export
def __init__(self, model, im):
super().__init__()
b, c, h, w = im.shape # batch, channel, height, width
self.model = model
self.nc = len(model.names) # number of classes
if w == h:
self.normalize = 1.0 / w # scalar
else:
self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)
def forward(self, x):
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
f = self.file.with_suffix('.mlmodel')
ts = torch.jit.trace(self.model, self.im, strict=False) # TorchScript model
model = iOSModel(self.model, self.im) if self.args.nms else self.model
ts = torch.jit.trace(model, self.im, strict=False) # TorchScript model
ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=self.im.shape, scale=1 / 255, bias=[0, 0, 0])])
bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
if bits < 32:
@ -351,6 +366,9 @@ class Exporter:
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
else:
LOGGER.info(f'{prefix} quantization only supported on macOS, skipping...')
if self.args.nms:
ct_model = self._pipeline_coreml(ct_model)
ct_model.save(str(f))
return f, ct_model
@ -525,8 +543,10 @@ class Exporter:
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
for c in (
'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | ' # no comma
'sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
'sudo apt-get update',
'sudo apt-get install edgetpu-compiler'):
subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
@ -597,6 +617,127 @@ class Exporter:
populator.populate()
tmp_file.unlink()
def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
# YOLOv5 CoreML pipeline
import coremltools as ct # noqa
LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
batch_size, ch, h, w = list(self.im.shape) # BCHW
# Output shapes
spec = model.get_spec()
out0, out1 = iter(spec.description.output)
if MACOS:
from PIL import Image
img = Image.new('RGB', (w, h)) # img(192 width, 320 height)
# img = torch.zeros((*opt.img_size, 3)).numpy() # img size(320,192,3) iDetection
out = model.predict({'image': img})
out0_shape = out[out0.name].shape
out1_shape = out[out1.name].shape
else: # linux and windows can not run model.predict(), get sizes from pytorch output y
out0_shape = self.output_shape[1], self.output_shape[2] - 5 # (3780, 80)
out1_shape = self.output_shape[1], 4 # (3780, 4)
# Checks
names = self.metadata['names']
nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
na, nc = out0_shape
# na, nc = out0.type.multiArrayType.shape # number anchors, classes
assert len(names) == nc, f'{len(names)} names found for nc={nc}' # check
# Define output shapes (missing)
out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80)
out1.type.multiArrayType.shape[:] = out1_shape # (3780, 4)
# spec.neuralNetwork.preprocessing[0].featureName = '0'
# Flexible input shapes
# from coremltools.models.neural_network import flexible_shape_utils
# s = [] # shapes
# s.append(flexible_shape_utils.NeuralNetworkImageSize(320, 192))
# s.append(flexible_shape_utils.NeuralNetworkImageSize(640, 384)) # (height, width)
# flexible_shape_utils.add_enumerated_image_sizes(spec, feature_name='image', sizes=s)
# r = flexible_shape_utils.NeuralNetworkImageSizeRange() # shape ranges
# r.add_height_range((192, 640))
# r.add_width_range((192, 640))
# flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)
# Print
print(spec.description)
# Model from spec
model = ct.models.MLModel(spec)
# 3. Create NMS protobuf
nms_spec = ct.proto.Model_pb2.Model()
nms_spec.specificationVersion = 5
for i in range(2):
decoder_output = model._spec.description.output[i].SerializeToString()
nms_spec.description.input.add()
nms_spec.description.input[i].ParseFromString(decoder_output)
nms_spec.description.output.add()
nms_spec.description.output[i].ParseFromString(decoder_output)
nms_spec.description.output[0].name = 'confidence'
nms_spec.description.output[1].name = 'coordinates'
output_sizes = [nc, 4]
for i in range(2):
ma_type = nms_spec.description.output[i].type.multiArrayType
ma_type.shapeRange.sizeRanges.add()
ma_type.shapeRange.sizeRanges[0].lowerBound = 0
ma_type.shapeRange.sizeRanges[0].upperBound = -1
ma_type.shapeRange.sizeRanges.add()
ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
del ma_type.shape[:]
nms = nms_spec.nonMaximumSuppression
nms.confidenceInputFeatureName = out0.name # 1x507x80
nms.coordinatesInputFeatureName = out1.name # 1x507x4
nms.confidenceOutputFeatureName = 'confidence'
nms.coordinatesOutputFeatureName = 'coordinates'
nms.iouThresholdInputFeatureName = 'iouThreshold'
nms.confidenceThresholdInputFeatureName = 'confidenceThreshold'
nms.iouThreshold = 0.45
nms.confidenceThreshold = 0.25
nms.pickTop.perClass = True
nms.stringClassLabels.vector.extend(names.values())
nms_model = ct.models.MLModel(nms_spec)
# 4. Pipeline models together
pipeline = ct.models.pipeline.Pipeline(input_features=[('image', ct.models.datatypes.Array(3, ny, nx)),
('iouThreshold', ct.models.datatypes.Double()),
('confidenceThreshold', ct.models.datatypes.Double())],
output_features=['confidence', 'coordinates'])
pipeline.add_model(model)
pipeline.add_model(nms_model)
# Correct datatypes
pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())
# Update metadata
pipeline.spec.specificationVersion = 5
pipeline.spec.description.metadata.versionString = f'Ultralytics YOLOv{ultralytics.__version__}'
pipeline.spec.description.metadata.shortDescription = f'Ultralytics {self.pretty_name} CoreML model'
pipeline.spec.description.metadata.author = 'Ultralytics (https://ultralytics.com)'
pipeline.spec.description.metadata.license = 'GPL-3.0 license (https://ultralytics.com/license)'
pipeline.spec.description.metadata.userDefined.update({
'IoU threshold': str(nms.iouThreshold),
'Confidence threshold': str(nms.confidenceThreshold)})
# Save the model
model = ct.models.MLModel(pipeline.spec)
model.input_description['image'] = 'Input image'
model.input_description['iouThreshold'] = f'(optional) IOU threshold override (default: {nms.iouThreshold})'
model.input_description['confidenceThreshold'] = \
f'(optional) Confidence threshold override (default: {nms.confidenceThreshold})'
model.output_description['confidence'] = 'Boxes × Class confidence (see user-defined metadata "classes")'
model.output_description['coordinates'] = 'Boxes × [x, y, width, height] (relative to image size)'
LOGGER.info(f'{prefix} pipeline success')
return model
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def export(cfg):

@ -1,5 +1,3 @@
from pathlib import Path
import torch
from ultralytics import yolo # noqa required for python usage
@ -7,9 +5,9 @@ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, Segmentati
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
from ultralytics.yolo.utils.files import yaml_load
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
# map head: [model, trainer, validator, predictor]
MODEL_MAP = {
@ -63,7 +61,7 @@ class YOLO:
cfg = check_yaml(cfg) # check YAML
cfg_dict = yaml_load(cfg) # model dict
obj = cls(init_key=cls.__init_key)
obj.task = obj._guess_task_from_head(cfg_dict["head"][-1][-2])
obj.task = guess_task_from_head(cfg_dict["head"][-1][-2])
obj.ModelClass, obj.TrainerClass, obj.ValidatorClass, obj.PredictorClass = obj._guess_ops_from_task(obj.task)
obj.model = obj.ModelClass(cfg_dict, verbose=verbose) # initialize
obj.cfg = cfg
@ -132,13 +130,7 @@ class YOLO:
overrides["mode"] = "predict"
predictor = self.PredictorClass(overrides=overrides)
# check size type
sz = predictor.args.imgsz
if type(sz) != int: # received listConfig
predictor.args.imgsz = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]] # expand
else:
predictor.args.imgsz = [sz, sz]
predictor.args.imgsz = check_imgsz(predictor.args.imgsz, min_dim=2) # check image size
predictor.setup(model=self.model, source=source)
predictor()
@ -179,7 +171,7 @@ class YOLO:
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
args.task = self.task
exporter = Exporter(overrides=overrides)
exporter = Exporter(overrides=args)
exporter(model=self.model)
def train(self, **kwargs):
@ -230,21 +222,6 @@ class YOLO:
self.trainer.train()
@staticmethod
def _guess_task_from_head(head):
task = None
if head.lower() in ["classify", "classifier", "cls", "fc"]:
task = "classify"
if head.lower() in ["detect"]:
task = "detect"
if head.lower() in ["segment"]:
task = "segment"
if not task:
raise SyntaxError("task or model not recognized! Please refer the docs at : ") # TODO: add docs links
return task
def to(self, device):
self.model.to(device)

@ -35,9 +35,9 @@ from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, ops
from ultralytics.yolo.utils.checks import check_file, check_imshow
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import check_imgsz, select_device, smart_inference_mode
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
class BasePredictor:
@ -90,7 +90,7 @@ class BasePredictor:
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)
stride, pt = model.stride, model.pt
imgsz = check_imgsz(self.args.imgsz, s=stride) # check image size
imgsz = check_imgsz(self.args.imgsz, stride=stride) # check image size
# Dataloader
bs = 1 # batch_size

@ -14,7 +14,7 @@ import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from omegaconf import OmegaConf
from omegaconf import OmegaConf # noqa
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import lr_scheduler

@ -2,15 +2,16 @@ import json
from pathlib import Path
import torch
from omegaconf import OmegaConf
from omegaconf import OmegaConf # noqa
from tqdm import tqdm
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import check_imgsz, de_parallel, select_device, smart_inference_mode
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device, smart_inference_mode
class BaseValidator:
@ -60,7 +61,7 @@ class BaseValidator:
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, fp16=self.args.half)
self.model = model
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
imgsz = check_imgsz(self.args.imgsz, s=stride)
imgsz = check_imgsz(self.args.imgsz, stride=stride)
if engine:
self.args.batch_size = model.batch_size
else:

@ -22,16 +22,26 @@ def is_ascii(s=''):
return len(s.encode().decode('ascii', 'ignore')) == len(s)
def check_imgsz(imgsz, s=32, floor=0):
def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
# Verify image size is a multiple of stride s in each dimension
if isinstance(imgsz, int): # integer i.e. img_size=640
new_size = max(make_divisible(imgsz, int(s)), floor)
else: # list i.e. img_size=[640, 480]
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
if isinstance(imgsz, int): # integer i.e. imgsz=640
sz = max(make_divisible(imgsz, stride), floor)
else: # list i.e. imgsz=[640, 480]
imgsz = list(imgsz) # convert to list if tuple
new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
if new_size != imgsz:
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
return new_size
sz = [max(make_divisible(x, stride), floor) for x in imgsz]
if sz != imgsz:
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {stride}, updating to {sz}')
# Check dims
if min_dim == 2:
if isinstance(imgsz, int):
sz = [sz, sz]
elif len(sz) == 1:
sz = [sz[0], sz[0]]
return sz
def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False):

@ -185,18 +185,6 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
def check_imgsz(imgsz, s=32, floor=0):
# Verify image size is a multiple of stride s in each dimension
if isinstance(imgsz, int): # integer i.e. imgsz=640
new_size = max(make_divisible(imgsz, int(s)), floor)
else: # list i.e. imgsz=[640, 480]
imgsz = list(imgsz) # convert to list if tuple
new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
if new_size != imgsz:
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
return new_size
def make_divisible(x, divisor):
# Returns nearest x divisible by divisor
if isinstance(divisor, torch.Tensor):
@ -293,3 +281,18 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # filesize
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
def guess_task_from_head(head):
task = None
if head.lower() in ["classify", "classifier", "cls", "fc"]:
task = "classify"
if head.lower() in ["detect"]:
task = "detect"
if head.lower() in ["segment"]:
task = "segment"
if not task:
raise SyntaxError("task or model not recognized! Please refer the docs at : ") # TODO: add docs links
return task

@ -3,6 +3,7 @@ import torch
from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.utils import DEFAULT_CONFIG
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.plotting import Annotator
@ -54,11 +55,7 @@ class ClassificationPredictor(BasePredictor):
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def predict(cfg):
cfg.model = cfg.model or "squeezenet1_0"
sz = cfg.imgsz
if type(sz) != int: # received listConfig
cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]] # expand
else:
cfg.imgsz = [sz, sz]
cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size
predictor = ClassificationPredictor(cfg)
predictor()

@ -3,6 +3,7 @@ import torch
from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.utils import DEFAULT_CONFIG, ops
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
@ -83,11 +84,7 @@ class DetectionPredictor(BasePredictor):
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def predict(cfg):
cfg.model = cfg.model or "n.pt"
sz = cfg.imgsz
if type(sz) != int: # received listConfig
cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]] # expand
else:
cfg.imgsz = [sz, sz]
cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size
predictor = DetectionPredictor(cfg)
predictor()

@ -2,6 +2,7 @@ import hydra
import torch
from ultralytics.yolo.utils import DEFAULT_CONFIG, ops
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.plotting import colors, save_one_box
from ..detect.predict import DetectionPredictor
@ -96,11 +97,7 @@ class SegmentationPredictor(DetectionPredictor):
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def predict(cfg):
cfg.model = cfg.model or "n.pt"
sz = cfg.imgsz
if type(sz) != int: # received listConfig
cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]] # expand
else:
cfg.imgsz = [sz, sz]
cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size
predictor = SegmentationPredictor(cfg)
predictor()

Loading…
Cancel
Save