[rename] - preprocess-batch -> preprocess, preprocess_preds -> postprocess (#42)

single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent 4c68b9dcf6
commit d143ac666f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -54,7 +54,7 @@ class BaseValidator:
self.batch_i = batch_i self.batch_i = batch_i
# pre-process # pre-process
with dt[0]: with dt[0]:
batch = self.preprocess_batch(batch) batch = self.preprocess(batch)
# inference # inference
with dt[1]: with dt[1]:
@ -69,7 +69,7 @@ class BaseValidator:
# pre-process predictions # pre-process predictions
with dt[3]: with dt[3]:
preds = self.preprocess_preds(preds) preds = self.postprocess(preds)
self.update_metrics(preds, batch) self.update_metrics(preds, batch)
@ -89,10 +89,10 @@ class BaseValidator:
return stats return stats
def preprocess_batch(self, batch): def preprocess(self, batch):
return batch return batch
def preprocess_preds(self, preds): def postprocess(self, preds):
return preds return preds
def init_metrics(self): def init_metrics(self):

@ -8,7 +8,7 @@ class ClassificationValidator(BaseValidator):
def init_metrics(self, model): def init_metrics(self, model):
self.correct = torch.tensor([]) self.correct = torch.tensor([])
def preprocess_batch(self, batch): def preprocess(self, batch):
batch["img"] = batch["img"].to(self.device) batch["img"] = batch["img"].to(self.device)
batch["cls"] = batch["cls"].to(self.device) batch["cls"] = batch["cls"].to(self.device)
return batch return batch

@ -28,7 +28,7 @@ class SegmentationValidator(BaseValidator):
self.class_map = None self.class_map = None
self.targets = None self.targets = None
def preprocess_batch(self, batch): def preprocess(self, batch):
batch["img"] = batch["img"].to(self.device, non_blocking=True) batch["img"] = batch["img"].to(self.device, non_blocking=True)
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 225 batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 225
batch["bboxes"] = batch["bboxes"].to(self.device) batch["bboxes"] = batch["bboxes"].to(self.device)
@ -66,7 +66,7 @@ class SegmentationValidator(BaseValidator):
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P", return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
"R", "mAP50", "mAP50-95)") "R", "mAP50", "mAP50-95)")
def preprocess_preds(self, preds): def postprocess(self, preds):
p = ops.non_max_suppression(preds[0], p = ops.non_max_suppression(preds[0],
self.args.conf_thres, self.args.conf_thres,
self.args.iou_thres, self.args.iou_thres,

Loading…
Cancel
Save