ultralytics 8.0.133 add torchvision compatibility check (#3703)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -211,6 +211,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
     """
     prefix = colorstr('red', 'bold', 'requirements:')
     check_python()  # check python version
+    check_torchvision()  # check torch-torchvision compatibility
     file = None
     if isinstance(requirements, Path):  # requirements.txt file
         file = requirements.resolve()
@@ -255,6 +256,34 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
     return True


+def check_torchvision():
+    """
+    Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
+
+    This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
+    to the provided compatibility table based on https://github.com/pytorch/vision#installation. The
+    compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
+    Torchvision versions.
+    """
+
+    import torchvision
+
+    # Compatibility table
+    compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}
+
+    # Extract only the major and minor versions
+    v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
+    v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2])
+
+    if v_torch in compatibility_table:
+        compatible_versions = compatibility_table[v_torch]
+        if all(pkg.parse_version(v_torchvision) != pkg.parse_version(v) for v in compatible_versions):
+            print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
+                  f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
+                  "'pip install -U torch torchvision' to update both.\n"
+                  'For a full compatibility table see https://github.com/pytorch/vision#installation')
+
+
 def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
     """Check file(s) for acceptable suffix."""
     if file and suffix:
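For context on the hunk above, here is a minimal standalone sketch (not part of the diff) of how the parsed 'major.minor' strings drive the table lookup. It assumes `pkg` refers to `pkg_resources`, as the surrounding checks module appears to use, and the version strings are illustrative.

# Illustrative only: mirrors the parsing/lookup logic in check_torchvision() above.
import pkg_resources as pkg

compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}

v_torch = '.'.join('2.0.1+cu118'.split('+')[0].split('.')[:2])       # -> '2.0'
v_torchvision = '.'.join('0.14.1+cpu'.split('+')[0].split('.')[:2])  # -> '0.14'

if v_torch in compatibility_table:
    compatible = compatibility_table[v_torch]
    incompatible = all(pkg.parse_version(v_torchvision) != pkg.parse_version(v) for v in compatible)
    print(incompatible)  # True -> check_torchvision() would print the WARNING message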
@@ -402,7 +431,7 @@ def check_amp(model):


 def git_describe(path=ROOT):  # path must be a directory
-    # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
+    """Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
     try:
         assert (Path(path) / '.git').is_dir()
         return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
@@ -91,7 +91,7 @@ def get_latest_run(search_dir='.'):


 def make_dirs(dir='new_dir/'):
-    # Create folders
+    """Create directories."""
     dir = Path(dir)
     if dir.exists():
         shutil.rmtree(dir)  # delete dir
@@ -55,12 +55,17 @@ class Profile(contextlib.ContextDecorator):
         return time.time()


-def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
-    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
-    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
-    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
-    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
-    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
+def coco80_to_coco91_class():  #
+    """
+    Converts 80-index (val2014) to 91-index (paper).
+    For details see https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/.
+
+    Example:
+        a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
+        b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
+        x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
+        x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
+    """
     return [
         1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
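A small usage sketch for the remapping (not from the commit); the import path is an assumption based on the hunk context, and the predicted class indices are made up for illustration.

# Illustrative: remap 80-index model outputs to 91-index COCO paper IDs,
# e.g. when writing predictions to COCO-format JSON. Import path assumed.
from ultralytics.yolo.utils.ops import coco80_to_coco91_class

class_map = coco80_to_coco91_class()
pred_classes = [0, 39, 56]                       # person, bottle, chair in 80-index order
coco_ids = [class_map[c] for c in pred_classes]  # -> [1, 44, 62] in 91-index numbering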
@@ -34,7 +34,7 @@ _torch_save = torch.save  # copy to avoid recursion errors


 def torch_save(*args, **kwargs):
-    # Use dill (if exists) to serialize the lambda functions where pickle does not do this
+    """Use dill (if exists) to serialize the lambda functions where pickle does not do this."""
     try:
         import dill as pickle
     except ImportError:
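The docstring above describes a fallback: prefer dill when it is installed (it can pickle lambdas), otherwise use the standard pickle module. A standalone sketch of that pattern follows; the hunk is truncated before the actual save call, so the wrapper name and the forwarding of the pickle module are illustrative assumptions.

# Illustrative fallback pattern; `save_with_dill` is a hypothetical wrapper name.
try:
    import dill as pickle  # dill can serialize lambdas, which plain pickle cannot
except ImportError:
    import pickle

import torch


def save_with_dill(obj, f):
    # Forward the chosen pickle module to torch.save so lambda-containing
    # checkpoints can be serialized when dill is available.
    return torch.save(obj, f, pickle_module=pickle)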
@@ -21,7 +21,8 @@ from .ops import clip_boxes, scale_image, xywh2xyxy, xyxy2xywh


 class Colors:
-    # Ultralytics color palette https://ultralytics.com/
+    """Ultralytics color palette https://ultralytics.com/."""
+
     def __init__(self):
         """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
         hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
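For readers unfamiliar with the class, a brief usage sketch (not part of the commit); the module-level `colors` instance is the one mentioned in the next hunk header, and the import path is an assumption.

# Illustrative: look up palette colors by class index. Import path assumed.
from ultralytics.yolo.utils.plotting import colors

rgb = colors(2)        # (R, G, B) tuple for class index 2
bgr = colors(2, True)  # same color in BGR order, convenient for OpenCV drawing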
@@ -48,7 +49,8 @@ colors = Colors()  # create instance for 'from utils.plots import colors'


 class Annotator:
-    # YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
+    """YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations."""
+
     def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
         """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
         assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
@@ -204,7 +206,14 @@ class Annotator:
                 self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
                 # Using `txt_color` for background and draw fg with white color
                 txt_color = (255, 255, 255)
-            self.draw.text(xy, text, fill=txt_color, font=self.font)
+            if '\n' in text:
+                lines = text.split('\n')
+                _, h = self.font.getsize(text)
+                for line in lines:
+                    self.draw.text(xy, line, fill=txt_color, font=self.font)
+                    xy[1] += h
+            else:
+                self.draw.text(xy, text, fill=txt_color, font=self.font)
         else:
             if box_style:
                 tf = max(self.lw - 1, 1)  # font thickness
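The change above teaches Annotator.text to render multi-line strings: split on '\n', draw each line, and step y down by one line height. A minimal PIL-only sketch of the same idea (illustrative and independent of Annotator; it measures line height with ImageDraw.textbbox rather than the font.getsize call used in the hunk):

# Illustrative multi-line text drawing with PIL, mirroring the logic added above.
from PIL import Image, ImageDraw, ImageFont

im = Image.new('RGB', (320, 120), (255, 255, 255))
draw = ImageDraw.Draw(im)
font = ImageFont.load_default()

xy = [10, 10]
text = 'person 0.91\nid 17'                    # example label with an embedded newline
h = draw.textbbox((0, 0), 'Ag', font=font)[3]  # approximate line height in pixels
for line in text.split('\n'):
    draw.text(tuple(xy), line, fill=(0, 0, 0), font=font)
    xy[1] += h                                 # move the anchor down one line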
@@ -310,7 +319,7 @@ def plot_images(images,
                 fname='images.jpg',
                 names=None,
                 on_plot=None):
-    # Plot image grid with labels
+    """Plot image grid with labels."""
     if isinstance(images, torch.Tensor):
         images = images.cpu().float().numpy()
     if isinstance(cls, torch.Tensor):
@@ -232,7 +232,7 @@ def get_flops(model, imgsz=640):


 def get_flops_with_torch_profiler(model, imgsz=640):
-    # Compute model FLOPs (thop alternative)
+    """Compute model FLOPs (thop alternative)."""
     model = de_parallel(model)
     p = next(model.parameters())
     stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2  # max stride