diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index 7334e98..f7e6b75 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -54,7 +54,7 @@ class BaseValidator: self.batch_i = batch_i # pre-process with dt[0]: - batch = self.preprocess_batch(batch) + batch = self.preprocess(batch) # inference with dt[1]: @@ -69,7 +69,7 @@ class BaseValidator: # pre-process predictions with dt[3]: - preds = self.preprocess_preds(preds) + preds = self.postprocess(preds) self.update_metrics(preds, batch) @@ -89,10 +89,10 @@ class BaseValidator: return stats - def preprocess_batch(self, batch): + def preprocess(self, batch): return batch - def preprocess_preds(self, preds): + def postprocess(self, preds): return preds def init_metrics(self): diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py index f24ae7f..c0420e8 100644 --- a/ultralytics/yolo/v8/classify/val.py +++ b/ultralytics/yolo/v8/classify/val.py @@ -8,7 +8,7 @@ class ClassificationValidator(BaseValidator): def init_metrics(self, model): self.correct = torch.tensor([]) - def preprocess_batch(self, batch): + def preprocess(self, batch): batch["img"] = batch["img"].to(self.device) batch["cls"] = batch["cls"].to(self.device) return batch diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py index 40aa7ab..0078e76 100644 --- a/ultralytics/yolo/v8/segment/val.py +++ b/ultralytics/yolo/v8/segment/val.py @@ -28,7 +28,7 @@ class SegmentationValidator(BaseValidator): self.class_map = None self.targets = None - def preprocess_batch(self, batch): + def preprocess(self, batch): batch["img"] = batch["img"].to(self.device, non_blocking=True) batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 225 batch["bboxes"] = batch["bboxes"].to(self.device) @@ -66,7 +66,7 @@ class SegmentationValidator(BaseValidator): return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P", "R", "mAP50", "mAP50-95)") - def preprocess_preds(self, preds): + def postprocess(self, preds): p = ops.non_max_suppression(preds[0], self.args.conf_thres, self.args.iou_thres,