ultralytics 8.0.133 add torchvision compatibility check (#3703)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-07-13 01:40:50 +02:00
committed by GitHub
parent 0821ccb618
commit c55a98ab8e
17 changed files with 140 additions and 68 deletions

View File

@ -211,6 +211,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
"""
prefix = colorstr('red', 'bold', 'requirements:')
check_python() # check python version
check_torchvision() # check torch-torchvision compatibility
file = None
if isinstance(requirements, Path): # requirements.txt file
file = requirements.resolve()
@ -255,6 +256,34 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
return True
def check_torchvision():
"""
Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
to the provided compatibility table based on https://github.com/pytorch/vision#installation. The
compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
Torchvision versions.
"""
import torchvision
# Compatibility table
compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}
# Extract only the major and minor versions
v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2])
if v_torch in compatibility_table:
compatible_versions = compatibility_table[v_torch]
if all(pkg.parse_version(v_torchvision) != pkg.parse_version(v) for v in compatible_versions):
print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
"'pip install -U torch torchvision' to update both.\n"
'For a full compatibility table see https://github.com/pytorch/vision#installation')
def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
"""Check file(s) for acceptable suffix."""
if file and suffix:
@ -402,7 +431,7 @@ def check_amp(model):
def git_describe(path=ROOT): # path must be a directory
# Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
"""Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
try:
assert (Path(path) / '.git').is_dir()
return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]

View File

@ -91,7 +91,7 @@ def get_latest_run(search_dir='.'):
def make_dirs(dir='new_dir/'):
# Create folders
"""Create directories."""
dir = Path(dir)
if dir.exists():
shutil.rmtree(dir) # delete dir

View File

@ -55,12 +55,17 @@ class Profile(contextlib.ContextDecorator):
return time.time()
def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
# a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
# b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
# x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
# x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
def coco80_to_coco91_class(): #
"""
Converts 80-index (val2014) to 91-index (paper).
For details see https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/.
Example:
a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
"""
return [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,

View File

@ -34,7 +34,7 @@ _torch_save = torch.save # copy to avoid recursion errors
def torch_save(*args, **kwargs):
# Use dill (if exists) to serialize the lambda functions where pickle does not do this
"""Use dill (if exists) to serialize the lambda functions where pickle does not do this."""
try:
import dill as pickle
except ImportError:

View File

@ -21,7 +21,8 @@ from .ops import clip_boxes, scale_image, xywh2xyxy, xyxy2xywh
class Colors:
# Ultralytics color palette https://ultralytics.com/
"""Ultralytics color palette https://ultralytics.com/."""
def __init__(self):
"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
@ -48,7 +49,8 @@ colors = Colors() # create instance for 'from utils.plots import colors'
class Annotator:
# YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
"""YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations."""
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
@ -204,7 +206,14 @@ class Annotator:
self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
# Using `txt_color` for background and draw fg with white color
txt_color = (255, 255, 255)
self.draw.text(xy, text, fill=txt_color, font=self.font)
if '\n' in text:
lines = text.split('\n')
_, h = self.font.getsize(text)
for line in lines:
self.draw.text(xy, line, fill=txt_color, font=self.font)
xy[1] += h
else:
self.draw.text(xy, text, fill=txt_color, font=self.font)
else:
if box_style:
tf = max(self.lw - 1, 1) # font thickness
@ -310,7 +319,7 @@ def plot_images(images,
fname='images.jpg',
names=None,
on_plot=None):
# Plot image grid with labels
"""Plot image grid with labels."""
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
if isinstance(cls, torch.Tensor):

View File

@ -232,7 +232,7 @@ def get_flops(model, imgsz=640):
def get_flops_with_torch_profiler(model, imgsz=640):
# Compute model FLOPs (thop alternative)
"""Compute model FLOPs (thop alternative)."""
model = de_parallel(model)
p = next(model.parameters())
stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride