From d69a1e8046d2dcc9d89befdcd3bf2d8da0f67627 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 16 Jun 2023 19:59:46 +0200 Subject: [PATCH] Update save_dir on new predict args (#3215) --- ultralytics/yolo/engine/model.py | 2 ++ ultralytics/yolo/engine/predictor.py | 12 +++++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index de3dc8d..6d3168b 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -250,6 +250,8 @@ class YOLO: self.predictor.setup_model(model=self.model, verbose=is_cli) else: # only update args if predictor is already setup self.predictor.args = get_cfg(self.predictor.args, overrides) + if 'project' in overrides or 'name' in overrides: + self.predictor.save_dir = self.predictor.get_save_dir() return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) def track(self, source=None, stream=False, persist=False, **kwargs): diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 32fea52..5a37878 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -65,14 +65,13 @@ class BasePredictor: Attributes: args (SimpleNamespace): Configuration for the predictor. save_dir (Path): Directory to save results. - done_setup (bool): Whether the predictor has finished setup. + done_warmup (bool): Whether the predictor has finished setup. model (nn.Module): Model used for prediction. data (dict): Data configuration. device (torch.device): Device used for prediction. dataset (Dataset): Dataset used for prediction. vid_path (str): Path to video file. vid_writer (cv2.VideoWriter): Video writer for saving video output. - annotator (Annotator): Annotator used for prediction. data_path (str): Path to data. """ @@ -85,9 +84,7 @@ class BasePredictor: overrides (dict, optional): Configuration overrides. Defaults to None. """ self.args = get_cfg(cfg, overrides) - project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task - name = self.args.name or f'{self.args.mode}' - self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok) + self.save_dir = self.get_save_dir() if self.args.conf is None: self.args.conf = 0.25 # default conf=0.25 self.done_warmup = False @@ -108,6 +105,11 @@ class BasePredictor: self.callbacks = _callbacks or callbacks.get_default_callbacks() callbacks.add_integration_callbacks(self) + def get_save_dir(self): + project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task + name = self.args.name or f'{self.args.mode}' + return increment_path(Path(project) / name, exist_ok=self.args.exist_ok) + def preprocess(self, im): """Prepares input image before inference.