diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py index ff50b52..a4bc790 100644 --- a/ultralytics/engine/validator.py +++ b/ultralytics/engine/validator.py @@ -157,7 +157,7 @@ class BaseValidator: self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch) model.eval() - model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup + model.warmup(imgsz=(1 if pt else self.args.batch, 1, imgsz, imgsz)) # warmup dt = Profile(), Profile(), Profile(), Profile() n_batches = len(self.dataloader)