diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py index 4d0f9f5..e1382cd 100644 --- a/ultralytics/engine/validator.py +++ b/ultralytics/engine/validator.py @@ -95,6 +95,7 @@ class BaseValidator: if trainer is passed (trainer gets priority). """ self.training = trainer is not None + augment = self.args.augment and (not self.training) if self.training: self.device = trainer.device self.data = trainer.data @@ -159,7 +160,7 @@ class BaseValidator: # Inference with dt[1]: - preds = model(batch['img'], augment=self.args.augment) + preds = model(batch['img'], augment=augment) # Loss with dt[2]: