`ultralytics 8.0.133` add `torchvision` compatibility check (#3703)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 1 year ago committed by GitHub
parent 0821ccb618
commit c55a98ab8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -43,6 +43,11 @@ keywords: YOLO, Ultralytics, Utils, Checks, image sizing, version updates, font
### ::: ultralytics.yolo.utils.checks.check_requirements ### ::: ultralytics.yolo.utils.checks.check_requirements
<br><br> <br><br>
## check_torchvision
---
### ::: ultralytics.yolo.utils.checks.check_torchvision
<br><br>
## check_suffix ## check_suffix
--- ---
### ::: ultralytics.yolo.utils.checks.check_suffix ### ::: ultralytics.yolo.utils.checks.check_suffix

@ -66,7 +66,7 @@
"import ultralytics\n", "import ultralytics\n",
"ultralytics.checks()" "ultralytics.checks()"
], ],
"execution_count": 1, "execution_count": null,
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
@ -102,7 +102,7 @@
"# Run inference on an image with YOLOv8n\n", "# Run inference on an image with YOLOv8n\n",
"!yolo predict model=yolov8n.pt source='https://ultralytics.com/images/zidane.jpg'" "!yolo predict model=yolov8n.pt source='https://ultralytics.com/images/zidane.jpg'"
], ],
"execution_count": 2, "execution_count": null,
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
@ -169,7 +169,7 @@
"# Validate YOLOv8n on COCO128 val\n", "# Validate YOLOv8n on COCO128 val\n",
"!yolo val model=yolov8n.pt data=coco128.yaml" "!yolo val model=yolov8n.pt data=coco128.yaml"
], ],
"execution_count": 3, "execution_count": null,
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
@ -293,7 +293,7 @@
"# Train YOLOv8n on COCO128 for 3 epochs\n", "# Train YOLOv8n on COCO128 for 3 epochs\n",
"!yolo train model=yolov8n.pt data=coco128.yaml epochs=3 imgsz=640" "!yolo train model=yolov8n.pt data=coco128.yaml epochs=3 imgsz=640"
], ],
"execution_count": 4, "execution_count": null,
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
@ -454,8 +454,8 @@
"- 💡 ProTip: Export to [TensorRT](https://developer.nvidia.com/tensorrt) for up to 5x GPU speedup.\n", "- 💡 ProTip: Export to [TensorRT](https://developer.nvidia.com/tensorrt) for up to 5x GPU speedup.\n",
"\n", "\n",
"\n", "\n",
"| Format | `format=` | Model |\n", "| Format | `format` Argument | Model |\n",
"|----------------------------------------------------------------------------|--------------------|---------------------------|\n", "|----------------------------------------------------------------------------|-------------------|---------------------------|\n",
"| [PyTorch](https://pytorch.org/) | - | `yolov8n.pt` |\n", "| [PyTorch](https://pytorch.org/) | - | `yolov8n.pt` |\n",
"| [TorchScript](https://pytorch.org/docs/stable/jit.html) | `torchscript` | `yolov8n.torchscript` |\n", "| [TorchScript](https://pytorch.org/docs/stable/jit.html) | `torchscript` | `yolov8n.torchscript` |\n",
"| [ONNX](https://onnx.ai/) | `onnx` | `yolov8n.onnx` |\n", "| [ONNX](https://onnx.ai/) | `onnx` | `yolov8n.onnx` |\n",
@ -468,7 +468,7 @@
"| [TensorFlow Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n_edgetpu.tflite` |\n", "| [TensorFlow Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n_edgetpu.tflite` |\n",
"| [TensorFlow.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n_web_model/` |\n", "| [TensorFlow.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n_web_model/` |\n",
"| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` |\n", "| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` |\n",
"\n" "| [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n_ncnn_model/` |\n"
], ],
"metadata": { "metadata": {
"id": "nPZZeNrLCQG6" "id": "nPZZeNrLCQG6"
@ -486,7 +486,7 @@
"id": "CYIjW4igCjqD", "id": "CYIjW4igCjqD",
"outputId": "fc41bf7a-0ea2-41a6-9ec5-dd0455af43bc" "outputId": "fc41bf7a-0ea2-41a6-9ec5-dd0455af43bc"
}, },
"execution_count": 5, "execution_count": null,
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
@ -533,7 +533,7 @@
"results = model.train(data='coco128.yaml', epochs=3) # train the model\n", "results = model.train(data='coco128.yaml', epochs=3) # train the model\n",
"results = model.val() # evaluate model performance on the validation set\n", "results = model.val() # evaluate model performance on the validation set\n",
"results = model('https://ultralytics.com/images/bus.jpg') # predict on an image\n", "results = model('https://ultralytics.com/images/bus.jpg') # predict on an image\n",
"success = model.export(format='onnx') # export the model to ONNX format" "results = model.export(format='onnx') # export the model to ONNX format"
], ],
"metadata": { "metadata": {
"id": "bpF9-vS_DAaf" "id": "bpF9-vS_DAaf"
@ -677,9 +677,8 @@
"cell_type": "code", "cell_type": "code",
"source": [ "source": [
"# Git clone and run tests on updates branch\n", "# Git clone and run tests on updates branch\n",
"!git clone https://github.com/ultralytics/ultralytics -b updates\n", "!git clone https://github.com/ultralytics/ultralytics -b main\n",
"%pip install -qe ultralytics\n", "%pip install -qe ultralytics"
"!pytest ultralytics/tests"
], ],
"metadata": { "metadata": {
"id": "uRKlwxSJdhd1" "id": "uRKlwxSJdhd1"
@ -687,6 +686,18 @@
"execution_count": null, "execution_count": null,
"outputs": [] "outputs": []
}, },
{
"cell_type": "code",
"source": [
"# Run tests (Git clone only)\n",
"!pytest ultralytics/tests"
],
"metadata": {
"id": "GtPlh7mcCGZX"
},
"execution_count": null,
"outputs": []
},
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "source": [

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.132' __version__ = '8.0.133'
from ultralytics.hub import start from ultralytics.hub import start
from ultralytics.vit.rtdetr import RTDETR from ultralytics.vit.rtdetr import RTDETR

@ -25,11 +25,11 @@ class HUBTrainingSession:
model_id (str): Identifier for the YOLOv5 model being trained. model_id (str): Identifier for the YOLOv5 model being trained.
model_url (str): URL for the model in Ultralytics HUB. model_url (str): URL for the model in Ultralytics HUB.
api_url (str): API URL for the model in Ultralytics HUB. api_url (str): API URL for the model in Ultralytics HUB.
auth_header (Dict): Authentication header for the Ultralytics HUB API requests. auth_header (dict): Authentication header for the Ultralytics HUB API requests.
rate_limits (Dict): Rate limits for different API calls (in seconds). rate_limits (dict): Rate limits for different API calls (in seconds).
timers (Dict): Timers for rate limiting. timers (dict): Timers for rate limiting.
metrics_queue (Dict): Queue for the model's metrics. metrics_queue (dict): Queue for the model's metrics.
model (Dict): Model data fetched from Ultralytics HUB. model (dict): Model data fetched from Ultralytics HUB.
alive (bool): Indicates if the heartbeat loop is active. alive (bool): Indicates if the heartbeat loop is active.
""" """

@ -601,7 +601,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
# Parse a YOLO model.yaml dictionary into a PyTorch model """Parse a YOLO model.yaml dictionary into a PyTorch model."""
import ast import ast
# Args # Args

@ -171,8 +171,8 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program. If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
Args: Args:
custom (Dict): a dictionary of custom configuration options custom (dict): a dictionary of custom configuration options
base (Dict): a dictionary of base configuration options base (dict): a dictionary of base configuration options
""" """
custom = _handle_deprecation(custom) custom = _handle_deprecation(custom)
base, custom = (set(x.keys()) for x in (base, custom)) base, custom = (set(x.keys()) for x in (base, custom))

@ -642,7 +642,8 @@ class CopyPaste:
class Albumentations: class Albumentations:
# YOLOv8 Albumentations class (optional, only used if package is installed) """YOLOv8 Albumentations class (optional, only used if package is installed)"""
def __init__(self, p=1.0): def __init__(self, p=1.0):
"""Initialize the transform object for YOLO bbox formatted params.""" """Initialize the transform object for YOLO bbox formatted params."""
self.p = p self.p = p
@ -819,7 +820,7 @@ def classify_albumentations(
std=(1.0, 1.0, 1.0), # IMAGENET_STD std=(1.0, 1.0, 1.0), # IMAGENET_STD
auto_aug=False, auto_aug=False,
): ):
# YOLOv8 classification Albumentations (optional, only used if package is installed) """YOLOv8 classification Albumentations (optional, only used if package is installed)."""
prefix = colorstr('albumentations: ') prefix = colorstr('albumentations: ')
try: try:
import albumentations as A import albumentations as A
@ -851,7 +852,8 @@ def classify_albumentations(
class ClassifyLetterBox: class ClassifyLetterBox:
# YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()]) """YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])"""
def __init__(self, size=(640, 640), auto=False, stride=32): def __init__(self, size=(640, 640), auto=False, stride=32):
"""Resizes image and crops it to center with max dimensions 'h' and 'w'.""" """Resizes image and crops it to center with max dimensions 'h' and 'w'."""
super().__init__() super().__init__()
@ -871,7 +873,8 @@ class ClassifyLetterBox:
class CenterCrop: class CenterCrop:
# YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()]) """YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])"""
def __init__(self, size=640): def __init__(self, size=640):
"""Converts an image from numpy array to PyTorch tensor.""" """Converts an image from numpy array to PyTorch tensor."""
super().__init__() super().__init__()
@ -885,7 +888,8 @@ class CenterCrop:
class ToTensor: class ToTensor:
# YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()]) """YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])."""
def __init__(self, half=False): def __init__(self, half=False):
"""Initialize YOLOv8 ToTensor object with optional half-precision support.""" """Initialize YOLOv8 ToTensor object with optional half-precision support."""
super().__init__() super().__init__()

@ -63,7 +63,7 @@ class _RepeatSampler:
def seed_worker(worker_id): # noqa def seed_worker(worker_id): # noqa
# Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
worker_seed = torch.initial_seed() % 2 ** 32 worker_seed = torch.initial_seed() % 2 ** 32
np.random.seed(worker_seed) np.random.seed(worker_seed)
random.seed(worker_seed) random.seed(worker_seed)

@ -29,7 +29,8 @@ class SourceTypes:
class LoadStreams: class LoadStreams:
# YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams` """YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`."""
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1): def __init__(self, sources='file.streams', imgsz=640, vid_stride=1):
"""Initialize instance variables and check for consistent input stream shapes.""" """Initialize instance variables and check for consistent input stream shapes."""
torch.backends.cudnn.benchmark = True # faster for fixed-size inference torch.backends.cudnn.benchmark = True # faster for fixed-size inference
@ -116,7 +117,8 @@ class LoadStreams:
class LoadScreenshots: class LoadScreenshots:
# YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen` """YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`."""
def __init__(self, source, imgsz=640): def __init__(self, source, imgsz=640):
"""source = [screen_number left top width height] (pixels).""" """source = [screen_number left top width height] (pixels)."""
check_requirements('mss') check_requirements('mss')
@ -158,7 +160,8 @@ class LoadScreenshots:
class LoadImages: class LoadImages:
# YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4` """YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`."""
def __init__(self, path, imgsz=640, vid_stride=1): def __init__(self, path, imgsz=640, vid_stride=1):
"""Initialize the Dataloader and raise FileNotFoundError if file not found.""" """Initialize the Dataloader and raise FileNotFoundError if file not found."""
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line

@ -278,7 +278,7 @@ def check_cls_dataset(dataset: str, split=''):
split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''. split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
Returns: Returns:
dict: A dictionary containing the following keys: (dict): A dictionary containing the following keys:
- 'train' (Path): The directory path containing the training set of the dataset. - 'train' (Path): The directory path containing the training set of the dataset.
- 'val' (Path): The directory path containing the validation set of the dataset. - 'val' (Path): The directory path containing the validation set of the dataset.
- 'test' (Path): The directory path containing the test set of the dataset. - 'test' (Path): The directory path containing the test set of the dataset.

@ -213,16 +213,18 @@ class Results(SimpleClass):
assert type(line_width) == int, '`line_width` should be of int type, i.e, line_width=3' assert type(line_width) == int, '`line_width` should be of int type, i.e, line_width=3'
names = self.names names = self.names
annotator = Annotator(deepcopy(self.orig_img if img is None else img), pred_boxes, show_boxes = self.boxes, boxes
pred_masks, show_masks = self.masks, masks
pred_probs, show_probs = self.probs, probs
annotator = Annotator(
deepcopy(self.orig_img if img is None else img),
line_width, line_width,
font_size, font_size,
font, font,
pil, pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True
example=names) example=names)
pred_boxes, show_boxes = self.boxes, boxes
pred_masks, show_masks = self.masks, masks # Plot Segment results
pred_probs, show_probs = self.probs, probs
keypoints = self.keypoints
if pred_masks and show_masks: if pred_masks and show_masks:
if img_gpu is None: if img_gpu is None:
img = LetterBox(pred_masks.shape[1:])(image=annotator.result()) img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
@ -231,6 +233,7 @@ class Results(SimpleClass):
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks)) idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=img_gpu) annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=img_gpu)
# Plot Detect results
if pred_boxes and show_boxes: if pred_boxes and show_boxes:
for d in reversed(pred_boxes): for d in reversed(pred_boxes):
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item()) c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
@ -238,12 +241,15 @@ class Results(SimpleClass):
label = (f'{name} {conf:.2f}' if conf else name) if labels else None label = (f'{name} {conf:.2f}' if conf else name) if labels else None
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
# Plot Classify results
if pred_probs is not None and show_probs: if pred_probs is not None and show_probs:
text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, " text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors x = round(self.orig_shape[0] * 0.03)
annotator.text([x, x], text, txt_color=(255, 255, 255)) # TODO: allow setting colors
if keypoints is not None: # Plot Pose results
for k in reversed(keypoints.data): if self.keypoints is not None:
for k in reversed(self.keypoints.data):
annotator.kpts(k, self.orig_shape, kpt_line=kpt_line) annotator.kpts(k, self.orig_shape, kpt_line=kpt_line)
return annotator.result() return annotator.result()

@ -211,6 +211,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
""" """
prefix = colorstr('red', 'bold', 'requirements:') prefix = colorstr('red', 'bold', 'requirements:')
check_python() # check python version check_python() # check python version
check_torchvision() # check torch-torchvision compatibility
file = None file = None
if isinstance(requirements, Path): # requirements.txt file if isinstance(requirements, Path): # requirements.txt file
file = requirements.resolve() file = requirements.resolve()
@ -255,6 +256,34 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
return True return True
def check_torchvision():
"""
Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
to the provided compatibility table based on https://github.com/pytorch/vision#installation. The
compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
Torchvision versions.
"""
import torchvision
# Compatibility table
compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}
# Extract only the major and minor versions
v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2])
if v_torch in compatibility_table:
compatible_versions = compatibility_table[v_torch]
if all(pkg.parse_version(v_torchvision) != pkg.parse_version(v) for v in compatible_versions):
print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
"'pip install -U torch torchvision' to update both.\n"
'For a full compatibility table see https://github.com/pytorch/vision#installation')
def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''): def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
"""Check file(s) for acceptable suffix.""" """Check file(s) for acceptable suffix."""
if file and suffix: if file and suffix:
@ -402,7 +431,7 @@ def check_amp(model):
def git_describe(path=ROOT): # path must be a directory def git_describe(path=ROOT): # path must be a directory
# Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe """Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
try: try:
assert (Path(path) / '.git').is_dir() assert (Path(path) / '.git').is_dir()
return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1] return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]

@ -91,7 +91,7 @@ def get_latest_run(search_dir='.'):
def make_dirs(dir='new_dir/'): def make_dirs(dir='new_dir/'):
# Create folders """Create directories."""
dir = Path(dir) dir = Path(dir)
if dir.exists(): if dir.exists():
shutil.rmtree(dir) # delete dir shutil.rmtree(dir) # delete dir

@ -55,12 +55,17 @@ class Profile(contextlib.ContextDecorator):
return time.time() return time.time()
def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper) def coco80_to_coco91_class(): #
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/ """
# a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n') Converts 80-index (val2014) to 91-index (paper).
# b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n') For details see https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/.
# x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
# x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet Example:
a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
"""
return [ return [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,

@ -34,7 +34,7 @@ _torch_save = torch.save # copy to avoid recursion errors
def torch_save(*args, **kwargs): def torch_save(*args, **kwargs):
# Use dill (if exists) to serialize the lambda functions where pickle does not do this """Use dill (if exists) to serialize the lambda functions where pickle does not do this."""
try: try:
import dill as pickle import dill as pickle
except ImportError: except ImportError:

@ -21,7 +21,8 @@ from .ops import clip_boxes, scale_image, xywh2xyxy, xyxy2xywh
class Colors: class Colors:
# Ultralytics color palette https://ultralytics.com/ """Ultralytics color palette https://ultralytics.com/."""
def __init__(self): def __init__(self):
"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values().""" """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB', hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
@ -48,7 +49,8 @@ colors = Colors() # create instance for 'from utils.plots import colors'
class Annotator: class Annotator:
# YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations """YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations."""
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'): def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs.""" """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.' assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
@ -204,6 +206,13 @@ class Annotator:
self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color) self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
# Using `txt_color` for background and draw fg with white color # Using `txt_color` for background and draw fg with white color
txt_color = (255, 255, 255) txt_color = (255, 255, 255)
if '\n' in text:
lines = text.split('\n')
_, h = self.font.getsize(text)
for line in lines:
self.draw.text(xy, line, fill=txt_color, font=self.font)
xy[1] += h
else:
self.draw.text(xy, text, fill=txt_color, font=self.font) self.draw.text(xy, text, fill=txt_color, font=self.font)
else: else:
if box_style: if box_style:
@ -310,7 +319,7 @@ def plot_images(images,
fname='images.jpg', fname='images.jpg',
names=None, names=None,
on_plot=None): on_plot=None):
# Plot image grid with labels """Plot image grid with labels."""
if isinstance(images, torch.Tensor): if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy() images = images.cpu().float().numpy()
if isinstance(cls, torch.Tensor): if isinstance(cls, torch.Tensor):

@ -232,7 +232,7 @@ def get_flops(model, imgsz=640):
def get_flops_with_torch_profiler(model, imgsz=640): def get_flops_with_torch_profiler(model, imgsz=640):
# Compute model FLOPs (thop alternative) """Compute model FLOPs (thop alternative)."""
model = de_parallel(model) model = de_parallel(model)
p = next(model.parameters()) p = next(model.parameters())
stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride

Loading…
Cancel
Save