diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index ab59658..7dfedf9 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -126,6 +126,7 @@ class YOLO: """ overrides = self.overrides.copy() overrides.update(kwargs) + overrides["mode"] = "predict" predictor = self.PredictorClass(overrides=overrides) # check size type @@ -151,6 +152,7 @@ class YOLO: overrides = self.overrides.copy() overrides.update(kwargs) + overrides["mode"] = "val" args = get_config(config=DEFAULT_CONFIG, overrides=overrides) args.data = data or args.data args.task = self.task diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index a1aefc0..bcd939f 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -46,9 +46,9 @@ class BasePredictor: def __init__(self, config=DEFAULT_CONFIG, overrides={}): self.args = get_config(config, overrides) - project = overrides.get("project") or self.args.task - name = overrides.get("name") or self.args.mode - self.save_dir = increment_path(Path("runs") / project / name, exist_ok=self.args.exist_ok) + project = self.args.project or f"runs/{self.args.task}" + name = self.args.name or f"{self.args.mode}" + self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok) (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) self.done_setup = False diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index d0524ea..c437e83 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -24,7 +24,7 @@ import ultralytics.yolo.utils as utils import ultralytics.yolo.utils.callbacks as callbacks from ultralytics import __version__ from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml -from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr +from ultralytics.yolo.utils import LOGGER, RANK, ROOT, TQDM_BAR_FORMAT, colorstr from ultralytics.yolo.utils.checks import check_file, print_args from ultralytics.yolo.utils.configs import get_config from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command @@ -32,7 +32,6 @@ from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_ya from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml" -RANK = int(os.getenv('RANK', -1)) class BaseTrainer: @@ -48,9 +47,9 @@ class BaseTrainer: self.callbacks = defaultdict(list) # dirs - project = overrides.get("project") or self.args.task - name = overrides.get("name") or self.args.mode - self.save_dir = increment_path(Path("runs") / project / name, exist_ok=self.args.exist_ok) + project = self.args.project or f"runs/{self.args.task}" + name = self.args.name or f"{self.args.mode}" + self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK == -1 else True) self.wdir = self.save_dir / 'weights' # weights dir self.wdir.mkdir(parents=True, exist_ok=True) # make dir self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index f8fcd42..ecbd273 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -8,7 +8,7 @@ from tqdm import tqdm from ultralytics.nn.autobackend import AutoBackend from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG -from ultralytics.yolo.utils import LOGGER, TQDM_BAR_FORMAT +from ultralytics.yolo.utils import LOGGER, RANK, TQDM_BAR_FORMAT from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.ops import Profile from ultralytics.yolo.utils.torch_utils import check_imgsz, de_parallel, select_device, smart_inference_mode @@ -32,9 +32,10 @@ class BaseValidator: self.speed = None self.jdict = None - project = self.args.project if self.args.project != "runs/train" else self.args.task - name = self.args.name if self.args.name != "exp" else self.args.mode - self.save_dir = increment_path(Path("runs") / project / name, exist_ok=self.args.exist_ok) + project = self.args.project or f"runs/{self.args.task}" + name = self.args.name or f"{self.args.mode}" + self.save_dir = save_dir or increment_path(Path(project) / name, + exist_ok=self.args.exist_ok if RANK == -1 else True) (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) @smart_inference_mode() diff --git a/ultralytics/yolo/utils/configs/default.yaml b/ultralytics/yolo/utils/configs/default.yaml index b3f97bc..3865870 100644 --- a/ultralytics/yolo/utils/configs/default.yaml +++ b/ultralytics/yolo/utils/configs/default.yaml @@ -15,8 +15,8 @@ nosave: False cache: False # True/ram, disk or False device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu workers: 8 -project: 'runs/train' -name: 'exp' +project: null +name: null exist_ok: False pretrained: False optimizer: 'SGD' # choices=['SGD', 'Adam', 'AdamW', 'RMSProp'] diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py index 95666c3..7e95ba1 100644 --- a/ultralytics/yolo/v8/classify/train.py +++ b/ultralytics/yolo/v8/classify/train.py @@ -45,7 +45,7 @@ class ClassificationTrainer(BaseTrainer): return batch def get_validator(self): - return v8.classify.ClassificationValidator(self.test_loader, self.device, logger=self.console) + return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console) def criterion(self, preds, batch): loss = torch.nn.functional.cross_entropy(preds, batch["cls"]) diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py index a8ce7f3..7c95a77 100644 --- a/ultralytics/yolo/v8/detect/val.py +++ b/ultralytics/yolo/v8/detect/val.py @@ -171,7 +171,7 @@ class DetectionValidator(BaseValidator): pad=0.5, rect=self.args.rect, workers=self.args.workers, - prefix=colorstr(f'{val}: '), + prefix=colorstr(f'{self.args.mode}: '), shuffle=False, seed=self.args.seed)[0] if self.args.v5loader else \ build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, mode="val")[0]