Return processed outputs from predictor (#161)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
@ -121,11 +121,12 @@ class YOLO:
|
||||
overrides["conf"] = 0.25
|
||||
overrides.update(kwargs)
|
||||
overrides["mode"] = "predict"
|
||||
overrides["save"] = kwargs.get("save", False) # not save files by default
|
||||
predictor = self.PredictorClass(overrides=overrides)
|
||||
|
||||
predictor.args.imgsz = check_imgsz(predictor.args.imgsz, min_dim=2) # check image size
|
||||
predictor.setup(model=self.model, source=source)
|
||||
predictor()
|
||||
return predictor()
|
||||
|
||||
@smart_inference_mode()
|
||||
def val(self, data=None, **kwargs):
|
||||
|
@ -76,7 +76,8 @@ class BasePredictor:
|
||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||
name = self.args.name or f"{self.args.mode}"
|
||||
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
if self.args.save:
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
if self.args.conf is None:
|
||||
self.args.conf = 0.25 # default conf=0.25
|
||||
self.done_setup = False
|
||||
@ -149,7 +150,9 @@ class BasePredictor:
|
||||
def __call__(self, source=None, model=None):
|
||||
self.run_callbacks("on_predict_start")
|
||||
model = self.model if self.done_setup else self.setup(source, model)
|
||||
model.eval()
|
||||
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
|
||||
self.all_outputs = []
|
||||
for batch in self.dataset:
|
||||
self.run_callbacks("on_predict_batch_start")
|
||||
path, im, im0s, vid_cap, s = batch
|
||||
@ -194,6 +197,7 @@ class BasePredictor:
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
||||
|
||||
self.run_callbacks("on_predict_end")
|
||||
return self.all_outputs
|
||||
|
||||
def show(self, p):
|
||||
im0 = self.annotator.result()
|
||||
|
Reference in New Issue
Block a user