ultralytics 8.0.35 TensorRT, ONNX and OpenVINO predict and val (#929)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Pedley <ericpedley@gmail.com>
This commit is contained in:
Glenn Jocher
2023-02-11 21:31:49 +04:00
committed by GitHub
parent d32b339373
commit 977fd8f0b8
15 changed files with 88 additions and 69 deletions

View File

@ -57,15 +57,17 @@ class YOLO:
self.overrides = {} # overrides for trainer object
# Load or create new YOLO model
load_methods = {'.pt': self._load, '.yaml': self._new}
suffix = Path(model).suffix
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
if suffix in load_methods:
{'.pt': self._load, '.yaml': self._new}[suffix](model)
else:
raise NotImplementedError(f"'{suffix}' models not supported. Try a *.pt and *.yaml model, "
"i.e. model='yolov8n.pt' or model='yolov8n.yaml'")
try:
if suffix == '.yaml':
self._new(model)
else:
self._load(model)
except Exception as e:
raise NotImplementedError(f"Unable to load model='{model}'. "
f"As an example try model='yolov8n.pt' or model='yolov8n.yaml'") from e
def __call__(self, source=None, stream=False, **kwargs):
return self.predict(source, stream, **kwargs)
@ -78,13 +80,11 @@ class YOLO:
cfg (str): model configuration file
verbose (bool): display model info on load
"""
cfg = check_yaml(cfg) # check YAML
cfg_dict = yaml_load(cfg, append_filename=True) # model dict
self.cfg = check_yaml(cfg) # check YAML
cfg_dict = yaml_load(self.cfg, append_filename=True) # model dict
self.task = guess_model_task(cfg_dict)
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
self._assign_ops_from_task(self.task)
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
self.model = self.ModelClass(cfg_dict, verbose=verbose) # initialize
self.cfg = cfg
def _load(self, weights: str):
"""
@ -93,13 +93,17 @@ class YOLO:
Args:
weights (str): model checkpoint to be loaded
"""
self.model, self.ckpt = attempt_load_one_weight(weights)
suffix = Path(weights).suffix
if suffix == '.pt':
self.model, self.ckpt = attempt_load_one_weight(weights)
self.task = self.model.args["task"]
self.overrides = self.model.args
self._reset_ckpt_args(self.overrides)
else:
self.model, self.ckpt = weights, None
self.task = guess_model_task(weights)
self.ckpt_path = weights
self.task = self.model.args["task"]
self.overrides = self.model.args
self._reset_ckpt_args(self.overrides)
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
self._assign_ops_from_task(self.task)
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
def reset(self):
"""
@ -166,7 +170,7 @@ class YOLO:
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.data = data or args.data
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
@ -189,7 +193,8 @@ class YOLO:
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
exporter = Exporter(overrides=args)
exporter(model=self.model)
@ -231,8 +236,8 @@ class YOLO:
"""
self.model.to(device)
def _assign_ops_from_task(self, task):
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
def _assign_ops_from_task(self):
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
# warning: eval is unsafe. Use with caution
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))