ultralytics 8.0.14 Hydra removal fixes and cleanup (#542)

Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kamlesh Kumar <patelkamleshpatel364@gmail.com>
This commit is contained in:
Glenn Jocher
2023-01-21 21:22:40 +01:00
committed by GitHub
parent cc3be0e223
commit d9a0fba251
30 changed files with 339 additions and 301 deletions

View File

@ -18,7 +18,7 @@ class ClassificationPredictor(BasePredictor):
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
return img
def postprocess(self, preds, img, orig_img):
def postprocess(self, preds, img, orig_img, classes=None):
results = []
for i, pred in enumerate(preds):
shape = orig_img[i].shape if isinstance(orig_img, list) else orig_img.shape

View File

@ -19,12 +19,13 @@ class DetectionPredictor(BasePredictor):
img /= 255 # 0 - 255 to 0.0 - 1.0
return img
def postprocess(self, preds, img, orig_img):
def postprocess(self, preds, img, orig_img, classes=None):
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det)
max_det=self.args.max_det,
classes=self.args.classes)
results = []
for i, pred in enumerate(preds):

View File

@ -10,14 +10,15 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor
class SegmentationPredictor(DetectionPredictor):
def postprocess(self, preds, img, orig_img):
def postprocess(self, preds, img, orig_img, classes=None):
# TODO: filter by classes
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nm=32)
nm=32,
classes=self.args.classes)
results = []
proto = preds[1][-1]
for i, pred in enumerate(p):