Return processed outputs from predictor (#161)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
Ayush Chaurasia
2023-01-10 00:10:44 +05:30
committed by GitHub
parent cb4801888e
commit 6e5638c128
7 changed files with 23 additions and 10 deletions

View File

@ -121,11 +121,12 @@ class YOLO:
overrides["conf"] = 0.25
overrides.update(kwargs)
overrides["mode"] = "predict"
overrides["save"] = kwargs.get("save", False) # not save files by default
predictor = self.PredictorClass(overrides=overrides)
predictor.args.imgsz = check_imgsz(predictor.args.imgsz, min_dim=2) # check image size
predictor.setup(model=self.model, source=source)
predictor()
return predictor()
@smart_inference_mode()
def val(self, data=None, **kwargs):

View File

@ -76,7 +76,8 @@ class BasePredictor:
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f"{self.args.mode}"
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
if self.args.save:
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
if self.args.conf is None:
self.args.conf = 0.25 # default conf=0.25
self.done_setup = False
@ -149,7 +150,9 @@ class BasePredictor:
def __call__(self, source=None, model=None):
self.run_callbacks("on_predict_start")
model = self.model if self.done_setup else self.setup(source, model)
model.eval()
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
self.all_outputs = []
for batch in self.dataset:
self.run_callbacks("on_predict_batch_start")
path, im, im0s, vid_cap, s = batch
@ -194,6 +197,7 @@ class BasePredictor:
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
self.run_callbacks("on_predict_end")
return self.all_outputs
def show(self, p):
im0 = self.annotator.result()