ultralytics 8.0.133 add torchvision compatibility check (#3703)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher
2023-07-13 01:40:50 +02:00
committed by GitHub
parent 0821ccb618
commit c55a98ab8e
17 changed files with 140 additions and 68 deletions

View File

@@ -171,8 +171,8 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
Args:
custom (Dict): a dictionary of custom configuration options
base (Dict): a dictionary of base configuration options
custom (dict): a dictionary of custom configuration options
base (dict): a dictionary of base configuration options
"""
custom = _handle_deprecation(custom)
base, custom = (set(x.keys()) for x in (base, custom))
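As a rough sketch of the set-based mismatch check above (difflib is an assumption here, not necessarily the helper the real function uses):

import difflib

base = {'epochs', 'imgsz', 'lr0'}   # valid base keys
custom = {'epoch', 'imgsz'}         # user-supplied keys
for k in custom - base:             # keys with no exact match in base
    print(k, '->', difflib.get_close_matches(k, base))  # epoch -> ['epochs']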

View File

@@ -642,7 +642,8 @@ class CopyPaste:
class Albumentations:
# YOLOv8 Albumentations class (optional, only used if package is installed)
"""YOLOv8 Albumentations class (optional, only used if package is installed)"""
def __init__(self, p=1.0):
"""Initialize the transform object for YOLO bbox formatted params."""
self.p = p
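For context, a minimal sketch of the bbox-aware pipeline a class like this builds (assumes albumentations is installed; the transform list is illustrative only):

import albumentations as A

T = A.Compose(
    [A.Blur(p=0.01), A.ToGray(p=0.01)],  # example transforms only
    bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))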
@@ -819,7 +820,7 @@ def classify_albumentations(
std=(1.0, 1.0, 1.0), # IMAGENET_STD
auto_aug=False,
):
# YOLOv8 classification Albumentations (optional, only used if package is installed)
"""YOLOv8 classification Albumentations (optional, only used if package is installed)."""
prefix = colorstr('albumentations: ')
try:
import albumentations as A
@@ -851,7 +852,8 @@ def classify_albumentations(
class ClassifyLetterBox:
# YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
"""YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])"""
def __init__(self, size=(640, 640), auto=False, stride=32):
"""Resizes image and crops it to center with max dimensions 'h' and 'w'."""
super().__init__()
@@ -871,7 +873,8 @@ class ClassifyLetterBox:
class CenterCrop:
# YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
"""YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])"""
def __init__(self, size=640):
"""Converts an image from numpy array to PyTorch tensor."""
super().__init__()
@@ -885,7 +888,8 @@ class CenterCrop:
class ToTensor:
# YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
"""YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])."""
def __init__(self, half=False):
"""Initialize YOLOv8 ToTensor object with optional half-precision support."""
super().__init__()
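A sketch of the conversion a ToTensor step performs, assuming a BGR HWC uint8 numpy image as input:

import numpy as np
import torch

im = np.zeros((640, 640, 3), dtype=np.uint8)                   # HWC, BGR
chw = np.ascontiguousarray(im[..., ::-1].transpose(2, 0, 1))   # BGR->RGB, HWC->CHW
t = torch.from_numpy(chw).float() / 255.0                      # uint8 0-255 -> float 0.0-1.0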

View File

@@ -63,7 +63,7 @@ class _RepeatSampler:
def seed_worker(worker_id): # noqa
# Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
"""Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
worker_seed = torch.initial_seed() % 2 ** 32
np.random.seed(worker_seed)
random.seed(worker_seed)
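A minimal sketch of wiring seed_worker into a DataLoader, following the linked PyTorch randomness notes (the dataset is a stand-in):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(8, 3))  # stand-in dataset
g = torch.Generator()
g.manual_seed(0)                            # fixed base seed; workers derive their own
loader = DataLoader(dataset, batch_size=4, num_workers=2,
                    worker_init_fn=seed_worker, generator=g)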

View File

@@ -29,7 +29,8 @@ class SourceTypes:
class LoadStreams:
# YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
"""YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`."""
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1):
"""Initialize instance variables and check for consistent input stream shapes."""
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
@@ -116,7 +117,8 @@ class LoadStreams:
class LoadScreenshots:
# YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`
"""YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`."""
def __init__(self, source, imgsz=640):
"""source = [screen_number left top width height] (pixels)."""
check_requirements('mss')
@@ -158,7 +160,8 @@ class LoadScreenshots:
class LoadImages:
# YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`
"""YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`."""
def __init__(self, path, imgsz=640, vid_stride=1):
"""Initialize the Dataloader and raise FileNotFoundError if file not found."""
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line

View File

@@ -278,7 +278,7 @@ def check_cls_dataset(dataset: str, split=''):
split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
Returns:
dict: A dictionary containing the following keys:
(dict): A dictionary containing the following keys:
- 'train' (Path): The directory path containing the training set of the dataset.
- 'val' (Path): The directory path containing the validation set of the dataset.
- 'test' (Path): The directory path containing the test set of the dataset.

View File

@@ -213,16 +213,18 @@ class Results(SimpleClass):
assert type(line_width) == int, '`line_width` should be of int type, i.e, line_width=3'
names = self.names
annotator = Annotator(deepcopy(self.orig_img if img is None else img),
line_width,
font_size,
font,
pil,
example=names)
pred_boxes, show_boxes = self.boxes, boxes
pred_masks, show_masks = self.masks, masks
pred_probs, show_probs = self.probs, probs
keypoints = self.keypoints
annotator = Annotator(
deepcopy(self.orig_img if img is None else img),
line_width,
font_size,
font,
pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True
example=names)
# Plot Segment results
if pred_masks and show_masks:
if img_gpu is None:
img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
@@ -231,6 +233,7 @@ class Results(SimpleClass):
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=img_gpu)
# Plot Detect results
if pred_boxes and show_boxes:
for d in reversed(pred_boxes):
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
@@ -238,12 +241,15 @@ class Results(SimpleClass):
label = (f'{name} {conf:.2f}' if conf else name) if labels else None
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
# Plot Classify results
if pred_probs is not None and show_probs:
text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, "
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)
x = round(self.orig_shape[0] * 0.03)
annotator.text([x, x], text, txt_color=(255, 255, 255)) # TODO: allow setting colors
if keypoints is not None:
for k in reversed(keypoints.data):
# Plot Pose results
if self.keypoints is not None:
for k in reversed(self.keypoints.data):
annotator.kpts(k, self.orig_shape, kpt_line=kpt_line)
return annotator.result()
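Usage sketch for the updated plot() behavior (assumes a classification checkpoint and a local image; the weights and filename here are placeholders):

from ultralytics import YOLO

model = YOLO('yolov8n-cls.pt')
results = model('bus.jpg')
im = results[0].plot()  # classify results now default to PIL, top-5 text offset ~3% of image height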

View File

@@ -211,6 +211,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
"""
prefix = colorstr('red', 'bold', 'requirements:')
check_python() # check python version
check_torchvision() # check torch-torchvision compatibility
file = None
if isinstance(requirements, Path): # requirements.txt file
file = requirements.resolve()
@@ -255,6 +256,34 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
return True
def check_torchvision():
"""
Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
to the provided compatibility table based on https://github.com/pytorch/vision#installation. The
compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
Torchvision versions.
"""
import torchvision
# Compatibility table
compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}
# Extract only the major and minor versions
v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2])
if v_torch in compatibility_table:
compatible_versions = compatibility_table[v_torch]
if all(pkg.parse_version(v_torchvision) != pkg.parse_version(v) for v in compatible_versions):
print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
"'pip install -U torch torchvision' to update both.\n"
'For a full compatibility table see https://github.com/pytorch/vision#installation')
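The major.minor extraction above, illustrated on a typical local-build version string (the input is made up):

v = '2.0.1+cu117'
print('.'.join(v.split('+')[0].split('.')[:2]))  # '2.0' -> key into compatibility_table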
def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
"""Check file(s) for acceptable suffix."""
if file and suffix:
@@ -402,7 +431,7 @@ def check_amp(model):
def git_describe(path=ROOT): # path must be a directory
# Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
"""Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
try:
assert (Path(path) / '.git').is_dir()
return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]

View File

@@ -91,7 +91,7 @@ def get_latest_run(search_dir='.'):
def make_dirs(dir='new_dir/'):
# Create folders
"""Create directories."""
dir = Path(dir)
if dir.exists():
shutil.rmtree(dir) # delete dir

View File

@@ -55,12 +55,17 @@ class Profile(contextlib.ContextDecorator):
return time.time()
def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
# a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
# b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
# x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
# x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
def coco80_to_coco91_class():
"""
Converts 80-index (val2014) to 91-index (paper).
For details see https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/.
Example:
a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
"""
return [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
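A quick usage sketch of the mapping (the return list is shown truncated above; ids cited here are from its visible prefix):

coco91 = coco80_to_coco91_class()
print(coco91[0])   # 1  -> 'person' in the 91-class paper numbering
print(coco91[11])  # 13 -> the paper ids skip some numbers, e.g. 12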

View File

@@ -34,7 +34,7 @@ _torch_save = torch.save # copy to avoid recursion errors
def torch_save(*args, **kwargs):
# Use dill (if exists) to serialize the lambda functions where pickle does not do this
"""Use dill (if exists) to serialize the lambda functions where pickle does not do this."""
try:
import dill as pickle
except ImportError:
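Why dill: a minimal demonstration that stdlib pickle refuses lambdas while dill (if installed) handles them:

import pickle

fn = lambda x: x + 1
try:
    pickle.dumps(fn)
except Exception as e:
    print('pickle failed:', type(e).__name__)  # PicklingError

import dill
assert dill.dumps(fn)  # dill serializes the lambda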

View File

@@ -21,7 +21,8 @@ from .ops import clip_boxes, scale_image, xywh2xyxy, xyxy2xywh
class Colors:
# Ultralytics color palette https://ultralytics.com/
"""Ultralytics color palette https://ultralytics.com/."""
def __init__(self):
"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
@@ -48,7 +49,8 @@ colors = Colors() # create instance for 'from utils.plots import colors'
class Annotator:
# YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
"""YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations."""
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
@@ -204,7 +206,14 @@ class Annotator:
self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
# Using `txt_color` for background and draw fg with white color
txt_color = (255, 255, 255)
self.draw.text(xy, text, fill=txt_color, font=self.font)
if '\n' in text:
lines = text.split('\n')
_, h = self.font.getsize(text)
for line in lines:
self.draw.text(xy, line, fill=txt_color, font=self.font)
xy[1] += h
else:
self.draw.text(xy, text, fill=txt_color, font=self.font)
else:
if box_style:
tf = max(self.lw - 1, 1) # font thickness
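A standalone sketch of the per-line vertical stepping added above (note ImageFont.getsize() is deprecated in Pillow>=10, where font.getbbox() is the replacement; the line height is hardcoded here for brevity):

from PIL import Image, ImageDraw, ImageFont

im = Image.new('RGB', (200, 80))
draw = ImageDraw.Draw(im)
font = ImageFont.load_default()
xy = [10, 10]
for line in 'person 0.94\nbus 0.83'.split('\n'):
    draw.text(tuple(xy), line, fill=(255, 255, 255), font=font)
    xy[1] += 11  # advance one line height per drawn line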
@@ -310,7 +319,7 @@ def plot_images(images,
fname='images.jpg',
names=None,
on_plot=None):
# Plot image grid with labels
"""Plot image grid with labels."""
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
if isinstance(cls, torch.Tensor):

View File

@@ -232,7 +232,7 @@ def get_flops(model, imgsz=640):
def get_flops_with_torch_profiler(model, imgsz=640):
# Compute model FLOPs (thop alternative)
"""Compute model FLOPs (thop alternative)."""
model = de_parallel(model)
p = next(model.parameters())
stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride
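A minimal sketch of profiler-based FLOP counting as an alternative to thop (the model and input are stand-ins):

import torch

model = torch.nn.Conv2d(3, 16, 3).eval()
x = torch.empty(1, 3, 640, 640)
with torch.profiler.profile(with_flops=True) as prof:
    model(x)
print(f'{sum(e.flops for e in prof.key_averages()) / 1e9:.2f} GFLOPs')  # forward-pass FLOPs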