ultralytics 8.0.35
TensorRT, ONNX and OpenVINO predict and val (#929)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Pedley <ericpedley@gmail.com>
This commit is contained in:
@ -161,8 +161,6 @@ class Exporter:
|
||||
|
||||
# Checks
|
||||
model.names = check_class_names(model.names)
|
||||
# if self.args.batch == model.args['batch_size']: # user has not modified training batch_size
|
||||
self.args.batch = 1
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
||||
if model.task == 'classify':
|
||||
self.args.nms = self.args.agnostic_nms = False
|
||||
|
@ -57,15 +57,17 @@ class YOLO:
|
||||
self.overrides = {} # overrides for trainer object
|
||||
|
||||
# Load or create new YOLO model
|
||||
load_methods = {'.pt': self._load, '.yaml': self._new}
|
||||
suffix = Path(model).suffix
|
||||
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
|
||||
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
if suffix in load_methods:
|
||||
{'.pt': self._load, '.yaml': self._new}[suffix](model)
|
||||
else:
|
||||
raise NotImplementedError(f"'{suffix}' models not supported. Try a *.pt and *.yaml model, "
|
||||
"i.e. model='yolov8n.pt' or model='yolov8n.yaml'")
|
||||
try:
|
||||
if suffix == '.yaml':
|
||||
self._new(model)
|
||||
else:
|
||||
self._load(model)
|
||||
except Exception as e:
|
||||
raise NotImplementedError(f"Unable to load model='{model}'. "
|
||||
f"As an example try model='yolov8n.pt' or model='yolov8n.yaml'") from e
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
return self.predict(source, stream, **kwargs)
|
||||
@ -78,13 +80,11 @@ class YOLO:
|
||||
cfg (str): model configuration file
|
||||
verbose (bool): display model info on load
|
||||
"""
|
||||
cfg = check_yaml(cfg) # check YAML
|
||||
cfg_dict = yaml_load(cfg, append_filename=True) # model dict
|
||||
self.cfg = check_yaml(cfg) # check YAML
|
||||
cfg_dict = yaml_load(self.cfg, append_filename=True) # model dict
|
||||
self.task = guess_model_task(cfg_dict)
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
||||
self._assign_ops_from_task(self.task)
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
|
||||
self.model = self.ModelClass(cfg_dict, verbose=verbose) # initialize
|
||||
self.cfg = cfg
|
||||
|
||||
def _load(self, weights: str):
|
||||
"""
|
||||
@ -93,13 +93,17 @@ class YOLO:
|
||||
Args:
|
||||
weights (str): model checkpoint to be loaded
|
||||
"""
|
||||
self.model, self.ckpt = attempt_load_one_weight(weights)
|
||||
suffix = Path(weights).suffix
|
||||
if suffix == '.pt':
|
||||
self.model, self.ckpt = attempt_load_one_weight(weights)
|
||||
self.task = self.model.args["task"]
|
||||
self.overrides = self.model.args
|
||||
self._reset_ckpt_args(self.overrides)
|
||||
else:
|
||||
self.model, self.ckpt = weights, None
|
||||
self.task = guess_model_task(weights)
|
||||
self.ckpt_path = weights
|
||||
self.task = self.model.args["task"]
|
||||
self.overrides = self.model.args
|
||||
self._reset_ckpt_args(self.overrides)
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
||||
self._assign_ops_from_task(self.task)
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
@ -166,7 +170,7 @@ class YOLO:
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.data = data or args.data
|
||||
args.task = self.task
|
||||
if args.imgsz == DEFAULT_CFG.imgsz:
|
||||
if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
|
||||
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
||||
|
||||
@ -189,7 +193,8 @@ class YOLO:
|
||||
args.task = self.task
|
||||
if args.imgsz == DEFAULT_CFG.imgsz:
|
||||
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
|
||||
if args.batch == DEFAULT_CFG.batch:
|
||||
args.batch = 1 # default to 1 if not modified
|
||||
exporter = Exporter(overrides=args)
|
||||
exporter(model=self.model)
|
||||
|
||||
@ -231,8 +236,8 @@ class YOLO:
|
||||
"""
|
||||
self.model.to(device)
|
||||
|
||||
def _assign_ops_from_task(self, task):
|
||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
|
||||
def _assign_ops_from_task(self):
|
||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
|
||||
# warning: eval is unsafe. Use with caution
|
||||
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
|
||||
validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
|
||||
|
@ -146,7 +146,7 @@ class BasePredictor:
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
# warmup model
|
||||
if not self.done_warmup:
|
||||
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.bs, 3, *self.imgsz))
|
||||
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
|
||||
self.done_warmup = True
|
||||
|
||||
self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
|
||||
@ -218,7 +218,7 @@ class BasePredictor:
|
||||
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
|
||||
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
|
||||
cv2.imshow(str(p), im0)
|
||||
cv2.waitKey(1) # 1 millisecond
|
||||
cv2.waitKey(500 if self.batch[4].startswith('image') else 1) # 1 millisecond
|
||||
|
||||
def save_preds(self, vid_cap, idx, save_path):
|
||||
im0 = self.annotator.result()
|
||||
|
@ -95,7 +95,7 @@ class BaseValidator:
|
||||
assert model is not None, "Either trainer or model is needed for validation"
|
||||
self.device = select_device(self.args.device, self.args.batch)
|
||||
self.args.half &= self.device.type != 'cpu'
|
||||
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, fp16=self.args.half)
|
||||
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
|
||||
self.model = model
|
||||
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
|
||||
imgsz = check_imgsz(self.args.imgsz, stride=stride)
|
||||
|
Reference in New Issue
Block a user