Update save_dir on new predict args (#3215)

single_channel
Glenn Jocher 1 year ago committed by GitHub
parent 431fad6834
commit d69a1e8046
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -250,6 +250,8 @@ class YOLO:
self.predictor.setup_model(model=self.model, verbose=is_cli) self.predictor.setup_model(model=self.model, verbose=is_cli)
else: # only update args if predictor is already setup else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides) self.predictor.args = get_cfg(self.predictor.args, overrides)
if 'project' in overrides or 'name' in overrides:
self.predictor.save_dir = self.predictor.get_save_dir()
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
def track(self, source=None, stream=False, persist=False, **kwargs): def track(self, source=None, stream=False, persist=False, **kwargs):

@ -65,14 +65,13 @@ class BasePredictor:
Attributes: Attributes:
args (SimpleNamespace): Configuration for the predictor. args (SimpleNamespace): Configuration for the predictor.
save_dir (Path): Directory to save results. save_dir (Path): Directory to save results.
done_setup (bool): Whether the predictor has finished setup. done_warmup (bool): Whether the predictor has finished setup.
model (nn.Module): Model used for prediction. model (nn.Module): Model used for prediction.
data (dict): Data configuration. data (dict): Data configuration.
device (torch.device): Device used for prediction. device (torch.device): Device used for prediction.
dataset (Dataset): Dataset used for prediction. dataset (Dataset): Dataset used for prediction.
vid_path (str): Path to video file. vid_path (str): Path to video file.
vid_writer (cv2.VideoWriter): Video writer for saving video output. vid_writer (cv2.VideoWriter): Video writer for saving video output.
annotator (Annotator): Annotator used for prediction.
data_path (str): Path to data. data_path (str): Path to data.
""" """
@ -85,9 +84,7 @@ class BasePredictor:
overrides (dict, optional): Configuration overrides. Defaults to None. overrides (dict, optional): Configuration overrides. Defaults to None.
""" """
self.args = get_cfg(cfg, overrides) self.args = get_cfg(cfg, overrides)
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task self.save_dir = self.get_save_dir()
name = self.args.name or f'{self.args.mode}'
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
if self.args.conf is None: if self.args.conf is None:
self.args.conf = 0.25 # default conf=0.25 self.args.conf = 0.25 # default conf=0.25
self.done_warmup = False self.done_warmup = False
@ -108,6 +105,11 @@ class BasePredictor:
self.callbacks = _callbacks or callbacks.get_default_callbacks() self.callbacks = _callbacks or callbacks.get_default_callbacks()
callbacks.add_integration_callbacks(self) callbacks.add_integration_callbacks(self)
def get_save_dir(self):
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f'{self.args.mode}'
return increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
def preprocess(self, im): def preprocess(self, im):
"""Prepares input image before inference. """Prepares input image before inference.

Loading…
Cancel
Save