diff --git a/docs/modes/predict.md b/docs/modes/predict.md
index 30f8743..6db2a9f 100644
--- a/docs/modes/predict.md
+++ b/docs/modes/predict.md
@@ -216,19 +216,19 @@ masks, classification logits, etc.) found in the results object
        res_plotted = res[0].plot()
        cv2.imshow("result", res_plotted)
        ```
-| Argument | Description |
-| ----------- | ------------- |
-| `conf (bool)` | Whether to plot the detection confidence score. |
-| `line_width (float, optional)` | The line width of the bounding boxes. If None, it is scaled to the image size. |
-| `font_size (float, optional)` | The font size of the text. If None, it is scaled to the image size. |
-| `font (str)` | The font to use for the text. |
-| `pil (bool)` | Whether to return the image as a PIL Image. |
-| `example (str)` | An example string to display. Useful for indicating the expected format of the output. |
-| `img (numpy.ndarray)` | Plot to another image. if not, plot to original image. |
-| `labels (bool)` | Whether to plot the label of bounding boxes. |
-| `boxes (bool)` | Whether to plot the bounding boxes. |
-| `masks (bool)` | Whether to plot the masks. |
-| `probs (bool)` | Whether to plot classification probability. |
+| Argument                       | Description                                                                             |
+|--------------------------------|-----------------------------------------------------------------------------------------|
+| `conf (bool)`                  | Whether to plot the detection confidence score.                                         |
+| `line_width (float, optional)` | The line width of the bounding boxes. If None, it is scaled to the image size.          |
+| `font_size (float, optional)`  | The font size of the text. If None, it is scaled to the image size.                     |
+| `font (str)`                   | The font to use for the text.                                                           |
+| `pil (bool)`                   | Whether to use PIL for image plotting.                                                  |
+| `example (str)`                | An example string to display. Useful for indicating the expected format of the output.  |
+| `img (numpy.ndarray)`          | Plot to another image. If not provided, plot to the original image.                     |
+| `labels (bool)`                | Whether to plot the labels of bounding boxes.                                           |
+| `boxes (bool)`                 | Whether to plot the bounding boxes.                                                     |
+| `masks (bool)`                 | Whether to plot the masks.                                                              |
+| `probs (bool)`                 | Whether to plot classification probabilities.                                           |

## Streaming Source `for`-loop

diff --git a/docs/tasks/pose.md b/docs/tasks/pose.md
index 0807b25..fb56208 100644
--- a/docs/tasks/pose.md
+++ b/docs/tasks/pose.md
@@ -56,19 +56,19 @@ Train a YOLOv8-pose model on the COCO128-pose dataset.
model = YOLO('yolov8n-pose.yaml').load('yolov8n-pose.pt') # build from YAML and transfer weights # Train the model - model.train(data='coco128-pose.yaml', epochs=100, imgsz=640) + model.train(data='coco8-pose.yaml', epochs=100, imgsz=640) ``` === "CLI" ```bash # Build a new model from YAML and start training from scratch - yolo pose train data=coco128-pose.yaml model=yolov8n-pose.yaml epochs=100 imgsz=640 + yolo pose train data=coco8-pose.yaml model=yolov8n-pose.yaml epochs=100 imgsz=640 # Start training from a pretrained *.pt model - yolo pose train data=coco128-pose.yaml model=yolov8n-pose.pt epochs=100 imgsz=640 + yolo pose train data=coco8-pose.yaml model=yolov8n-pose.pt epochs=100 imgsz=640 # Build a new model from YAML, transfer pretrained weights to it and start training - yolo pose train data=coco128-pose.yaml model=yolov8n-pose.yaml pretrained=yolov8n-pose.pt epochs=100 imgsz=640 + yolo pose train data=coco8-pose.yaml model=yolov8n-pose.yaml pretrained=yolov8n-pose.pt epochs=100 imgsz=640 ``` ## Val diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 08a183e..f8689ac 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -__version__ = '8.0.69' +__version__ = '8.0.70' from ultralytics.hub import start from ultralytics.yolo.engine.model import YOLO diff --git a/ultralytics/hub/__init__.py b/ultralytics/hub/__init__.py index a907ad7..83203dc 100644 --- a/ultralytics/hub/__init__.py +++ b/ultralytics/hub/__init__.py @@ -18,7 +18,7 @@ def login(api_key=''): from ultralytics import hub hub.login('API_KEY') """ - Auth(api_key) + Auth(api_key, verbose=True) def logout(): @@ -82,7 +82,7 @@ def export_model(model_id='', format='torchscript'): def get_export(model_id='', format='torchscript'): # Get an exported model dictionary with download URL - assert format in export_fmts_hub, f"Unsupported export format '{format}', valid formats are {export_fmts_hub}" + assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}" r = requests.post('https://api.ultralytics.com/get-export', json={ 'apiKey': Auth().api_key, diff --git a/ultralytics/hub/auth.py b/ultralytics/hub/auth.py index 10ee4df..9a7acbe 100644 --- a/ultralytics/hub/auth.py +++ b/ultralytics/hub/auth.py @@ -11,7 +11,7 @@ API_KEY_URL = 'https://hub.ultralytics.com/settings?tab=api+keys' class Auth: id_token = api_key = model_key = False - def __init__(self, api_key='', verbose=True): + def __init__(self, api_key='', verbose=False): """ Initialize the Auth class with an optional API key. 
diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index b90858e..0f93f50 100644 --- a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -58,7 +58,7 @@ class HUBTrainingSession: raise ValueError(f'Invalid HUBTrainingSession input: {url}') # Authorize - auth = Auth(key, verbose=False) + auth = Auth(key) self.agent_id = None # identifies which instance is communicating with server self.model_id = model_id self.model_url = f'https://hub.ultralytics.com/models/{model_id}' diff --git a/ultralytics/yolo/cfg/__init__.py b/ultralytics/yolo/cfg/__init__.py index 7f3f8a9..91424cc 100644 --- a/ultralytics/yolo/cfg/__init__.py +++ b/ultralytics/yolo/cfg/__init__.py @@ -19,7 +19,7 @@ TASK2DATA = { 'detect': 'coco128.yaml', 'segment': 'coco128-seg.yaml', 'classify': 'imagenet100', - 'pose': 'coco128-pose.yaml'} + 'pose': 'coco8-pose.yaml'} TASK2MODEL = { 'detect': 'yolov8n.pt', 'segment': 'yolov8n-seg.pt', diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index d7ee374..f480027 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -99,7 +99,7 @@ class BasePredictor: self.device = None self.dataset = None self.vid_path, self.vid_writer = None, None - self.annotator = None + self.plotted_img = None self.data_path = None self.source_type = None self.batch = None @@ -109,9 +109,6 @@ class BasePredictor: def preprocess(self, img): pass - def get_annotator(self, img): - raise NotImplementedError('get_annotator function needs to be implemented') - def write_results(self, results, batch, print_string): raise NotImplementedError('print_results function needs to be implemented') @@ -208,10 +205,10 @@ class BasePredictor: if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: s += self.write_results(i, self.results, (p, im, im0)) - if self.args.show: + if self.args.show and self.plotted_img is not None: self.show(p) - if self.args.save: + if self.args.save and self.plotted_img is not None: self.save_preds(vid_cap, i, str(self.save_dir / p.name)) self.run_callbacks('on_predict_batch_end') yield from self.results @@ -251,7 +248,7 @@ class BasePredictor: self.model.eval() def show(self, p): - im0 = self.annotator.result() + im0 = self.plotted_img if platform.system() == 'Linux' and p not in self.windows: self.windows.append(p) cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) @@ -260,7 +257,7 @@ class BasePredictor: cv2.waitKey(500 if self.batch[4].startswith('image') else 1) # 1 millisecond def save_preds(self, vid_cap, idx, save_path): - im0 = self.annotator.result() + im0 = self.plotted_img # save imgs if self.dataset.mode == 'image': cv2.imwrite(save_path, im0) diff --git a/ultralytics/yolo/engine/results.py b/ultralytics/yolo/engine/results.py index 8fbcb4b..746859c 100644 --- a/ultralytics/yolo/engine/results.py +++ b/ultralytics/yolo/engine/results.py @@ -10,11 +10,10 @@ from functools import lru_cache import numpy as np import torch -import torchvision.transforms.functional as F +from ultralytics.yolo.data.augment import LetterBox from ultralytics.yolo.utils import LOGGER, SimpleClass, deprecation_warn, ops from ultralytics.yolo.utils.plotting import Annotator, colors -from ultralytics.yolo.utils.torch_utils import TORCHVISION_0_10 class BaseTensor(SimpleClass): @@ -160,6 +159,7 @@ class Results(SimpleClass): pil=False, example='abc', img=None, + img_gpu=None, kpt_line=True, labels=True, boxes=True, @@ -178,6 +178,7 @@ class 
Results(SimpleClass):
            pil (bool): Whether to return the image as a PIL Image.
            example (str): An example string to display. Useful for indicating the expected format of the output.
            img (numpy.ndarray): Plot to another image. if not, plot to original image.
+            img_gpu (torch.Tensor): Normalized image on the GPU with shape (1, 3, 640, 640), used for faster mask plotting.
            kpt_line (bool): Whether to draw lines connecting keypoints.
            labels (bool): Whether to plot the label of bounding boxes.
            boxes (bool): Whether to plot the bounding boxes.
@@ -185,7 +186,7 @@ class Results(SimpleClass):
            probs (bool): Whether to plot classification probability

        Returns:
-            (None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned.
+            (numpy.ndarray): A numpy array of the annotated image.
        """
        # Deprecation warn TODO: remove in 8.2
        if 'show_conf' in kwargs:
@@ -200,6 +201,13 @@ class Results(SimpleClass):
        pred_probs, show_probs = self.probs, probs
        names = self.names
        keypoints = self.keypoints
+        if pred_masks and show_masks:
+            if img_gpu is None:
+                img = LetterBox(pred_masks.shape[1:])(image=annotator.im)
+                img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.masks.device).permute(
+                    2, 0, 1).flip(0).contiguous() / 255
+            annotator.masks(pred_masks.data, colors=[colors(x, True) for x in pred_boxes.cls], im_gpu=img_gpu)
+
        if pred_boxes and show_boxes:
            for d in reversed(pred_boxes):
                c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
@@ -207,15 +215,6 @@ class Results(SimpleClass):
                label = (f'{name} {conf:.2f}' if conf else name) if labels else None
                annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))

-        if pred_masks and show_masks:
-            im = torch.as_tensor(annotator.im, dtype=torch.float16, device=pred_masks.data.device).permute(2, 0, 1).flip(0)
-            if TORCHVISION_0_10:
-                im = F.resize(im.contiguous(), pred_masks.data.shape[1:], antialias=True) / 255
-            else:
-                im = F.resize(im.contiguous(), pred_masks.data.shape[1:]) / 255
-            annotator.masks(pred_masks.data, colors=[colors(x, True) for x in pred_boxes.cls], im_gpu=im)
-
        if pred_probs is not None and show_probs:
            n5 = min(len(names), 5)
            top5i = pred_probs.argsort(0, descending=True)[:n5].tolist()  # top 5 indices
@@ -226,7 +225,7 @@ class Results(SimpleClass):
            for k in reversed(keypoints):
                annotator.kpts(k, self.orig_shape, kpt_line=kpt_line)

-        return np.asarray(annotator.im) if annotator.pil else annotator.im
+        return annotator.result()


 class Boxes(BaseTensor):
diff --git a/ultralytics/yolo/utils/ops.py b/ultralytics/yolo/utils/ops.py
index c879d8d..26b6fe8 100644
--- a/ultralytics/yolo/utils/ops.py
+++ b/ultralytics/yolo/utils/ops.py
@@ -300,13 +300,12 @@ def clip_coords(coords, shape):
    coords[..., 1] = coords[..., 1].clip(0, shape[0])  # y


-def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
+def scale_image(masks, im0_shape, ratio_pad=None):
    """
    Takes a mask, and resizes it to the original image size

    Args:
-        im1_shape (tuple): model input shape, [h, w]
-        masks (torch.Tensor): [h, w, num]
+        masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3].
        im0_shape (tuple): the original image shape
        ratio_pad (tuple): the ratio of the padding to the original image.

@@ -314,10 +313,14 @@ def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
        masks (torch.Tensor): The masks that are being returned.
""" # Rescale coordinates (xyxy) from im1_shape to im0_shape + im1_shape = masks.shape + if im1_shape[:2] == im0_shape[:2]: + return masks if ratio_pad is None: # calculate from im0_shape gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding else: + gain = ratio_pad[0][0] pad = ratio_pad[1] top, left = int(pad[1]), int(pad[0]) # y, x bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0]) @@ -329,9 +332,9 @@ def scale_image(im1_shape, masks, im0_shape, ratio_pad=None): # masks = F.interpolate(masks[None], im0_shape[:2], mode='bilinear', align_corners=False)[0] # masks = masks.permute(1, 2, 0).contiguous() masks = cv2.resize(masks, (im0_shape[1], im0_shape[0])) - if len(masks.shape) == 2: masks = masks[:, :, None] + return masks diff --git a/ultralytics/yolo/utils/plotting.py b/ultralytics/yolo/utils/plotting.py index b139f97..6338822 100644 --- a/ultralytics/yolo/utils/plotting.py +++ b/ultralytics/yolo/utils/plotting.py @@ -138,7 +138,7 @@ class Annotator: im_gpu = im_gpu * inv_alph_masks[-1] + mcs im_mask = (im_gpu * 255) im_mask_np = im_mask.byte().cpu().numpy() - self.im[:] = im_mask_np if retina_masks else scale_image(im_gpu.shape, im_mask_np, self.im.shape) + self.im[:] = im_mask_np if retina_masks else scale_image(im_mask_np, self.im.shape) if self.pil: # convert im back to PIL and update draw self.fromarray(self.im) @@ -165,11 +165,11 @@ class Annotator: conf = k[2] if conf < 0.5: continue - cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1) + cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA) if kpt_line: ndim = kpts.shape[-1] - for sk_id, sk in enumerate(self.skeleton): + for i, sk in enumerate(self.skeleton): pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1])) pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1])) if ndim == 3: @@ -181,7 +181,7 @@ class Annotator: continue if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0: continue - cv2.line(self.im, pos1, pos2, [int(x) for x in self.limb_color[sk_id]], thickness=2) + cv2.line(self.im, pos1, pos2, [int(x) for x in self.limb_color[i]], thickness=2, lineType=cv2.LINE_AA) if self.pil: # convert im back to PIL and update draw self.fromarray(self.im) diff --git a/ultralytics/yolo/v8/classify/predict.py b/ultralytics/yolo/v8/classify/predict.py index 790fcee..e869dfd 100644 --- a/ultralytics/yolo/v8/classify/predict.py +++ b/ultralytics/yolo/v8/classify/predict.py @@ -44,7 +44,6 @@ class ClassificationPredictor(BasePredictor): # save_path = str(self.save_dir / p.name) # im.jpg self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') log_string += '%gx%g ' % im.shape[2:] # print string - self.annotator = self.get_annotator(im0) result = results[idx] if len(result) == 0: @@ -56,10 +55,10 @@ class ClassificationPredictor(BasePredictor): log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, " # write - text = '\n'.join(f'{prob[j]:.2f} {self.model.names[j]}' for j in top5i) if self.args.save or self.args.show: # Add bbox to image - self.annotator.text((32, 32), text, txt_color=(255, 255, 255)) + self.plotted_img = result.plot() if self.args.save_txt: # Write to file + text = '\n'.join(f'{prob[j]:.2f} {self.model.names[j]}' for j in top5i) with open(f'{self.txt_path}.txt', 'a') as f: f.write(text + 
'\n') diff --git a/ultralytics/yolo/v8/detect/predict.py b/ultralytics/yolo/v8/detect/predict.py index a54b6e7..4a9621a 100644 --- a/ultralytics/yolo/v8/detect/predict.py +++ b/ultralytics/yolo/v8/detect/predict.py @@ -5,14 +5,11 @@ import torch from ultralytics.yolo.engine.predictor import BasePredictor from ultralytics.yolo.engine.results import Results from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops -from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box +from ultralytics.yolo.utils.plotting import save_one_box class DetectionPredictor(BasePredictor): - def get_annotator(self, img): - return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names)) - def preprocess(self, img): img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 @@ -52,15 +49,18 @@ class DetectionPredictor(BasePredictor): self.data_path = p self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') log_string += '%gx%g ' % im.shape[2:] # print string - self.annotator = self.get_annotator(im0) - det = results[idx].boxes # TODO: make boxes inherit from tensors - if len(det) == 0: + result = results[idx] # TODO: make boxes inherit from tensors + if len(result) == 0: return f'{log_string}(no detections), ' + det = result.boxes for c in det.cls.unique(): n = (det.cls == c).sum() # detections per class log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, " + if self.args.save or self.args.show: # Add bbox to image + self.plotted_img = result.plot(line_width=self.args.line_thickness) + # write for d in reversed(det): c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) @@ -68,10 +68,6 @@ class DetectionPredictor(BasePredictor): line = (c, *d.xywhn.view(-1)) + (conf, ) * self.args.save_conf + (() if id is None else (id, )) with open(f'{self.txt_path}.txt', 'a') as f: f.write(('%g ' * len(line)).rstrip() % line + '\n') - if self.args.save or self.args.show: # Add bbox to image - name = ('' if id is None else f'id:{id} ') + self.model.names[c] - label = (f'{name} {conf:.2f}' if self.args.show_conf else name) if self.args.show_labels else None - self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.save_crop: save_one_box(d.xyxy, imc, diff --git a/ultralytics/yolo/v8/pose/predict.py b/ultralytics/yolo/v8/pose/predict.py index c121f80..06daa96 100644 --- a/ultralytics/yolo/v8/pose/predict.py +++ b/ultralytics/yolo/v8/pose/predict.py @@ -2,7 +2,7 @@ from ultralytics.yolo.engine.results import Results from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops -from ultralytics.yolo.utils.plotting import colors, save_one_box +from ultralytics.yolo.utils.plotting import save_one_box from ultralytics.yolo.v8.detect.predict import DetectionPredictor @@ -49,33 +49,27 @@ class PosePredictor(DetectionPredictor): self.data_path = p self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') log_string += '%gx%g ' % im.shape[2:] # print string - self.annotator = self.get_annotator(im0) - det = results[idx].boxes # TODO: make boxes inherit from tensors - if len(det) == 0: + result = results[idx] # TODO: make boxes inherit from tensors + if len(result) == 0: return f'{log_string}(no detections), ' + det = result.boxes for c in det.cls.unique(): n = (det.cls == c).sum() # detections per class log_string += 
f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, " - kpts = reversed(results[idx].keypoints) - for k in kpts: - self.annotator.kpts(k, shape=results[idx].orig_shape) + if self.args.save or self.args.show: # Add bbox to image + self.plotted_img = result.plot(line_width=self.args.line_thickness, boxes=self.args.boxes) # write for j, d in enumerate(reversed(det)): c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) if self.args.save_txt: # Write to file - kpt = (kpts[j][:, :2] / d.orig_shape[[1, 0]]).reshape(-1).tolist() + kpt = (result[j].keypoints[:, :2] / d.orig_shape[[1, 0]]).reshape(-1).tolist() box = d.xywhn.view(-1).tolist() line = (c, *box, *kpt) + (conf, ) * self.args.save_conf + (() if id is None else (id, )) with open(f'{self.txt_path}.txt', 'a') as f: f.write(('%g ' * len(line)).rstrip() % line + '\n') - if self.args.save or self.args.show: # Add bbox to image - name = ('' if id is None else f'id:{id} ') + self.model.names[c] - label = (f'{name} {conf:.2f}' if self.args.show_conf else name) if self.args.show_labels else None - if self.args.boxes: - self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.save_crop: save_one_box(d.xyxy, imc, diff --git a/ultralytics/yolo/v8/pose/val.py b/ultralytics/yolo/v8/pose/val.py index 1834e48..4a2e2c6 100644 --- a/ultralytics/yolo/v8/pose/val.py +++ b/ultralytics/yolo/v8/pose/val.py @@ -198,7 +198,7 @@ class PoseValidator(DetectionValidator): def val(cfg=DEFAULT_CFG, use_python=False): model = cfg.model or 'yolov8n-pose.pt' - data = cfg.data or 'coco128-pose.yaml' + data = cfg.data or 'coco8-pose.yaml' args = dict(model=model, data=data) if use_python: diff --git a/ultralytics/yolo/v8/segment/predict.py b/ultralytics/yolo/v8/segment/predict.py index 66b35f7..969b978 100644 --- a/ultralytics/yolo/v8/segment/predict.py +++ b/ultralytics/yolo/v8/segment/predict.py @@ -4,7 +4,7 @@ import torch from ultralytics.yolo.engine.results import Results from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops -from ultralytics.yolo.utils.plotting import colors, save_one_box +from ultralytics.yolo.utils.plotting import save_one_box from ultralytics.yolo.v8.detect.predict import DetectionPredictor @@ -56,7 +56,6 @@ class SegmentationPredictor(DetectionPredictor): self.data_path = p self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') log_string += '%gx%g ' % im.shape[2:] # print string - self.annotator = self.get_annotator(im0) result = results[idx] if len(result) == 0: @@ -72,7 +71,7 @@ class SegmentationPredictor(DetectionPredictor): if self.args.save or self.args.show: im_gpu = torch.as_tensor(im0, dtype=torch.float16, device=mask.masks.device).permute( 2, 0, 1).flip(0).contiguous() / 255 if self.args.retina_masks else im[idx] - self.annotator.masks(masks=mask.masks, colors=[colors(x, True) for x in det.cls], im_gpu=im_gpu) + self.plotted_img = result.plot(line_width=self.args.line_thickness, im_gpu=im_gpu, boxes=self.args.boxes) # Write results for j, d in enumerate(reversed(det)): @@ -82,11 +81,6 @@ class SegmentationPredictor(DetectionPredictor): line = (c, *seg) + (conf, ) * self.args.save_conf + (() if id is None else (id, )) with open(f'{self.txt_path}.txt', 'a') as f: f.write(('%g ' * len(line)).rstrip() % line + '\n') - if self.args.save or self.args.show: # Add bbox to image - name = ('' if id is None else f'id:{id} ') + self.model.names[c] - label = (f'{name} {conf:.2f}' if self.args.show_conf else name) if 
self.args.show_labels else None - if self.args.boxes: - self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.save_crop: save_one_box(d.xyxy, imc, diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py index 2beefcf..c3736ef 100644 --- a/ultralytics/yolo/v8/segment/val.py +++ b/ultralytics/yolo/v8/segment/val.py @@ -111,8 +111,7 @@ class SegmentationValidator(DetectionValidator): # Save if self.args.save_json: - pred_masks = ops.scale_image(batch['img'][si].shape[1:], - pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), + pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape, ratio_pad=batch['ratio_pad'][si]) self.pred_to_json(predn, batch['im_file'][si], pred_masks)
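Reviewer note: with this change, `Results.plot()` composites boxes, masks, keypoints and classification probabilities itself and returns the annotated image as a numpy array, so the task predictors simply store `self.plotted_img` instead of driving an `Annotator` by hand. A minimal usage sketch of the updated API (the `yolov8n-seg.pt` weights and the `bus.jpg` image path are illustrative assumptions):

```python
import cv2

from ultralytics import YOLO

model = YOLO('yolov8n-seg.pt')  # assumed weights for illustration
results = model('bus.jpg')      # assumed local image path

annotated = results[0].plot()   # now returns a numpy.ndarray (BGR), per the updated Returns docstring
cv2.imshow('result', annotated)
cv2.waitKey(0)
```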
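Similarly, `ops.scale_image` now infers the model-space shape from `masks.shape`, so call sites drop the leading `im1_shape` argument (see the `Annotator.masks` and `SegmentationValidator` hunks above). A small sketch of the new signature, assuming channel-last masks already letterboxed to a 640x640 model space:

```python
import numpy as np

from ultralytics.yolo.utils import ops

masks = np.zeros((640, 640, 1), dtype=np.float32)  # padded model-space masks, (h, w, num)
scaled = ops.scale_image(masks, (1080, 810))       # im1_shape is read from masks.shape
print(scaled.shape)                                # -> (1080, 810, 1)
```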