ultralytics 8.0.70 minor fixes and improvements (#1892)

Co-authored-by: feicccccccc <49809204+feicccccccc@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Author: Glenn Jocher
Committed by: GitHub
Date: 2023-04-08 00:27:33 +02:00
Parent: c2cd3fd20e
Commit: c38b17a0d8
17 changed files with 71 additions and 90 deletions


@@ -99,7 +99,7 @@ class BasePredictor:
self.device = None
self.dataset = None
self.vid_path, self.vid_writer = None, None
- self.annotator = None
+ self.plotted_img = None
self.data_path = None
self.source_type = None
self.batch = None
@@ -109,9 +109,6 @@ class BasePredictor:
def preprocess(self, img):
pass
- def get_annotator(self, img):
- raise NotImplementedError('get_annotator function needs to be implemented')
def write_results(self, results, batch, print_string):
raise NotImplementedError('print_results function needs to be implemented')
@@ -208,10 +205,10 @@ class BasePredictor:
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
s += self.write_results(i, self.results, (p, im, im0))
- if self.args.show:
+ if self.args.show and self.plotted_img is not None:
self.show(p)
- if self.args.save:
+ if self.args.save and self.plotted_img is not None:
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
self.run_callbacks('on_predict_batch_end')
yield from self.results
@@ -251,7 +248,7 @@ class BasePredictor:
self.model.eval()
def show(self, p):
- im0 = self.annotator.result()
+ im0 = self.plotted_img
if platform.system() == 'Linux' and p not in self.windows:
self.windows.append(p)
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
@@ -260,7 +257,7 @@ class BasePredictor:
cv2.waitKey(500 if self.batch[4].startswith('image') else 1) # 1 millisecond
def save_preds(self, vid_cap, idx, save_path):
- im0 = self.annotator.result()
+ im0 = self.plotted_img
# save imgs
if self.dataset.mode == 'image':
cv2.imwrite(save_path, im0)
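Since get_annotator() is removed and show()/save_preds() now read self.plotted_img directly, a predictor subclass is expected to populate self.plotted_img inside write_results(). Below is a minimal sketch of that contract, assuming the write_results(idx, results, batch) call signature used in stream_inference() above and the ultralytics.yolo.engine.predictor module path of this release; the SketchPredictor class and its body are illustrative, not the library's actual predictor code.

from ultralytics.yolo.engine.predictor import BasePredictor

class SketchPredictor(BasePredictor):
    # Illustrative subclass: render the frame once and hand it to show()/save_preds()
    # through self.plotted_img instead of a separate Annotator object.
    def write_results(self, idx, results, batch):
        p, im, im0 = batch
        result = results[idx]
        if self.args.save or self.args.show:
            self.plotted_img = result.plot()  # annotated BGR numpy array
        return f'{len(result.boxes)} detections, '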


@@ -10,11 +10,10 @@ from functools import lru_cache
import numpy as np
import torch
- import torchvision.transforms.functional as F
+ from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.utils import LOGGER, SimpleClass, deprecation_warn, ops
from ultralytics.yolo.utils.plotting import Annotator, colors
- from ultralytics.yolo.utils.torch_utils import TORCHVISION_0_10
class BaseTensor(SimpleClass):
@@ -160,6 +159,7 @@ class Results(SimpleClass):
pil=False,
example='abc',
img=None,
+ img_gpu=None,
kpt_line=True,
labels=True,
boxes=True,
@@ -178,6 +178,7 @@
pil (bool): Whether to return the image as a PIL Image.
example (str): An example string to display. Useful for indicating the expected format of the output.
img (numpy.ndarray): Plot to another image. if not, plot to original image.
+ img_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
kpt_line (bool): Whether to draw lines connecting keypoints.
labels (bool): Whether to plot the label of bounding boxes.
boxes (bool): Whether to plot the bounding boxes.
@@ -185,7 +186,7 @@
probs (bool): Whether to plot classification probability
Returns:
- (None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned.
+ (numpy.ndarray): A numpy array of the annotated image.
"""
# Deprecation warn TODO: remove in 8.2
if 'show_conf' in kwargs:
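With the return type changed from an optional PIL image to a numpy array, the plotted result can be consumed directly by OpenCV or saved without an extra conversion. A short usage sketch; the weights file and image path are placeholders, not part of this diff.

import cv2
from ultralytics import YOLO

model = YOLO('yolov8n.pt')        # placeholder weights
results = model('bus.jpg')        # placeholder source image
im = results[0].plot()            # annotated image as a BGR numpy array
cv2.imwrite('annotated.jpg', im)  # no PIL round trip needed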
@@ -200,6 +201,13 @@
pred_probs, show_probs = self.probs, probs
names = self.names
keypoints = self.keypoints
+ if pred_masks and show_masks:
+ if img_gpu is None:
+ img = LetterBox(pred_masks.shape[1:])(image=annotator.im)
+ img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.masks.device).permute(
+ 2, 0, 1).flip(0).contiguous() / 255
+ annotator.masks(pred_masks.data, colors=[colors(x, True) for x in pred_boxes.cls], im_gpu=img_gpu)
if pred_boxes and show_boxes:
for d in reversed(pred_boxes):
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
@@ -207,15 +215,6 @@
label = (f'{name} {conf:.2f}' if conf else name) if labels else None
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
- if pred_masks and show_masks:
- im = torch.as_tensor(annotator.im, dtype=torch.float16, device=pred_masks.data.device).permute(2, 0,
- 1).flip(0)
- if TORCHVISION_0_10:
- im = F.resize(im.contiguous(), pred_masks.data.shape[1:], antialias=True) / 255
- else:
- im = F.resize(im.contiguous(), pred_masks.data.shape[1:]) / 255
- annotator.masks(pred_masks.data, colors=[colors(x, True) for x in pred_boxes.cls], im_gpu=im)
if pred_probs is not None and show_probs:
n5 = min(len(names), 5)
top5i = pred_probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices
@@ -226,7 +225,7 @@
for k in reversed(keypoints):
annotator.kpts(k, self.orig_shape, kpt_line=kpt_line)
- return np.asarray(annotator.im) if annotator.pil else annotator.im
+ return annotator.result()
class Boxes(BaseTensor):
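The img_gpu argument exists so a caller that already holds the normalized input tensor on the model's device (for example a segmentation predictor keeping its preprocessed batch) can pass it straight to plot() and skip the LetterBox resize and CPU-to-GPU copy performed by the fallback branch added above. A hedged sketch of that fast path; 'preprocessed' is a placeholder name for such a tensor and is not taken from the library.

# 'preprocessed' stands in for the letterboxed, 0-1 normalized image tensor that a
# predictor keeps on the model's device during inference (see the docstring shape above).
annotated = results[0].plot(img_gpu=preprocessed)  # masks are blended on the GPU tensor

When img_gpu is omitted, plot() rebuilds an equivalent tensor from annotator.im, which is slower but keeps the call self-contained.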