ultralytics 8.0.70
minor fixes and improvements (#1892)
Co-authored-by: feicccccccc <49809204+feicccccccc@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
@ -99,7 +99,7 @@ class BasePredictor:
|
||||
self.device = None
|
||||
self.dataset = None
|
||||
self.vid_path, self.vid_writer = None, None
|
||||
self.annotator = None
|
||||
self.plotted_img = None
|
||||
self.data_path = None
|
||||
self.source_type = None
|
||||
self.batch = None
|
||||
@ -109,9 +109,6 @@ class BasePredictor:
|
||||
def preprocess(self, img):
|
||||
pass
|
||||
|
||||
def get_annotator(self, img):
|
||||
raise NotImplementedError('get_annotator function needs to be implemented')
|
||||
|
||||
def write_results(self, results, batch, print_string):
|
||||
raise NotImplementedError('print_results function needs to be implemented')
|
||||
|
||||
@ -208,10 +205,10 @@ class BasePredictor:
|
||||
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||
s += self.write_results(i, self.results, (p, im, im0))
|
||||
|
||||
if self.args.show:
|
||||
if self.args.show and self.plotted_img is not None:
|
||||
self.show(p)
|
||||
|
||||
if self.args.save:
|
||||
if self.args.save and self.plotted_img is not None:
|
||||
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
||||
self.run_callbacks('on_predict_batch_end')
|
||||
yield from self.results
|
||||
@ -251,7 +248,7 @@ class BasePredictor:
|
||||
self.model.eval()
|
||||
|
||||
def show(self, p):
|
||||
im0 = self.annotator.result()
|
||||
im0 = self.plotted_img
|
||||
if platform.system() == 'Linux' and p not in self.windows:
|
||||
self.windows.append(p)
|
||||
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
|
||||
@ -260,7 +257,7 @@ class BasePredictor:
|
||||
cv2.waitKey(500 if self.batch[4].startswith('image') else 1) # 1 millisecond
|
||||
|
||||
def save_preds(self, vid_cap, idx, save_path):
|
||||
im0 = self.annotator.result()
|
||||
im0 = self.plotted_img
|
||||
# save imgs
|
||||
if self.dataset.mode == 'image':
|
||||
cv2.imwrite(save_path, im0)
|
||||
|
@ -10,11 +10,10 @@ from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
from ultralytics.yolo.data.augment import LetterBox
|
||||
from ultralytics.yolo.utils import LOGGER, SimpleClass, deprecation_warn, ops
|
||||
from ultralytics.yolo.utils.plotting import Annotator, colors
|
||||
from ultralytics.yolo.utils.torch_utils import TORCHVISION_0_10
|
||||
|
||||
|
||||
class BaseTensor(SimpleClass):
|
||||
@ -160,6 +159,7 @@ class Results(SimpleClass):
|
||||
pil=False,
|
||||
example='abc',
|
||||
img=None,
|
||||
img_gpu=None,
|
||||
kpt_line=True,
|
||||
labels=True,
|
||||
boxes=True,
|
||||
@ -178,6 +178,7 @@ class Results(SimpleClass):
|
||||
pil (bool): Whether to return the image as a PIL Image.
|
||||
example (str): An example string to display. Useful for indicating the expected format of the output.
|
||||
img (numpy.ndarray): Plot to another image. if not, plot to original image.
|
||||
img_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
|
||||
kpt_line (bool): Whether to draw lines connecting keypoints.
|
||||
labels (bool): Whether to plot the label of bounding boxes.
|
||||
boxes (bool): Whether to plot the bounding boxes.
|
||||
@ -185,7 +186,7 @@ class Results(SimpleClass):
|
||||
probs (bool): Whether to plot classification probability
|
||||
|
||||
Returns:
|
||||
(None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned.
|
||||
(numpy.ndarray): A numpy array of the annotated image.
|
||||
"""
|
||||
# Deprecation warn TODO: remove in 8.2
|
||||
if 'show_conf' in kwargs:
|
||||
@ -200,6 +201,13 @@ class Results(SimpleClass):
|
||||
pred_probs, show_probs = self.probs, probs
|
||||
names = self.names
|
||||
keypoints = self.keypoints
|
||||
if pred_masks and show_masks:
|
||||
if img_gpu is None:
|
||||
img = LetterBox(pred_masks.shape[1:])(image=annotator.im)
|
||||
img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.masks.device).permute(
|
||||
2, 0, 1).flip(0).contiguous() / 255
|
||||
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in pred_boxes.cls], im_gpu=img_gpu)
|
||||
|
||||
if pred_boxes and show_boxes:
|
||||
for d in reversed(pred_boxes):
|
||||
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
|
||||
@ -207,15 +215,6 @@ class Results(SimpleClass):
|
||||
label = (f'{name} {conf:.2f}' if conf else name) if labels else None
|
||||
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||
|
||||
if pred_masks and show_masks:
|
||||
im = torch.as_tensor(annotator.im, dtype=torch.float16, device=pred_masks.data.device).permute(2, 0,
|
||||
1).flip(0)
|
||||
if TORCHVISION_0_10:
|
||||
im = F.resize(im.contiguous(), pred_masks.data.shape[1:], antialias=True) / 255
|
||||
else:
|
||||
im = F.resize(im.contiguous(), pred_masks.data.shape[1:]) / 255
|
||||
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in pred_boxes.cls], im_gpu=im)
|
||||
|
||||
if pred_probs is not None and show_probs:
|
||||
n5 = min(len(names), 5)
|
||||
top5i = pred_probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices
|
||||
@ -226,7 +225,7 @@ class Results(SimpleClass):
|
||||
for k in reversed(keypoints):
|
||||
annotator.kpts(k, self.orig_shape, kpt_line=kpt_line)
|
||||
|
||||
return np.asarray(annotator.im) if annotator.pil else annotator.im
|
||||
return annotator.result()
|
||||
|
||||
|
||||
class Boxes(BaseTensor):
|
||||
|
Reference in New Issue
Block a user