ultralytics 8.0.62
HUB Syntax updates and fixes (#1795)
Co-authored-by: Danny Kim <imbird0312@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: MagicCodess <32194768+MagicCodess@users.noreply.github.com> Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Amjad Alsharafi <26300843+Amjad50@users.noreply.github.com>
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
__version__ = '8.0.61'
|
||||
__version__ = '8.0.62'
|
||||
|
||||
from ultralytics.hub import start
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
|
@ -9,8 +9,8 @@ from types import SimpleNamespace
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR,
|
||||
IterableSimpleNamespace, __version__, checks, colorstr, get_settings, yaml_load,
|
||||
yaml_print)
|
||||
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn,
|
||||
get_settings, yaml_load, yaml_print)
|
||||
|
||||
# Define valid tasks and modes
|
||||
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
||||
@ -71,7 +71,7 @@ CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic'
|
||||
'line_thickness', 'workspace', 'nbs', 'save_period')
|
||||
CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr',
|
||||
'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt',
|
||||
'save_conf', 'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms',
|
||||
'save_conf', 'save_crop', 'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms',
|
||||
'retina_masks', 'boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader')
|
||||
|
||||
|
||||
@ -140,6 +140,22 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
|
||||
return IterableSimpleNamespace(**cfg)
|
||||
|
||||
|
||||
def _handle_deprecation(custom):
|
||||
"""
|
||||
Hardcoded function to handle deprecated config keys
|
||||
"""
|
||||
|
||||
for key in custom.copy().keys():
|
||||
if key == 'hide_labels':
|
||||
deprecation_warn(key, 'show_labels')
|
||||
custom['show_labels'] = custom.pop('hide_labels') == 'False'
|
||||
if key == 'hide_conf':
|
||||
deprecation_warn(key, 'show_conf')
|
||||
custom['show_conf'] = custom.pop('hide_conf') == 'False'
|
||||
|
||||
return custom
|
||||
|
||||
|
||||
def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
|
||||
"""
|
||||
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
|
||||
@ -149,6 +165,7 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
|
||||
- custom (Dict): a dictionary of custom configuration options
|
||||
- base (Dict): a dictionary of base configuration options
|
||||
"""
|
||||
custom = _handle_deprecation(custom)
|
||||
base, custom = (set(x.keys()) for x in (base, custom))
|
||||
mismatched = [x for x in custom if x not in base]
|
||||
if mismatched:
|
||||
|
@ -55,8 +55,8 @@ show: False # show results if possible
|
||||
save_txt: False # save results as .txt file
|
||||
save_conf: False # save results with confidence scores
|
||||
save_crop: False # save cropped images with results
|
||||
hide_labels: False # hide labels
|
||||
hide_conf: False # hide confidence scores
|
||||
show_labels: True # show object labels in plots
|
||||
show_conf: True # show object confidence scores in plots
|
||||
vid_stride: 1 # video frame-rate stride
|
||||
line_thickness: 3 # bounding box thickness (pixels)
|
||||
visualize: False # visualize model features
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@ -77,7 +78,7 @@ class YOLO:
|
||||
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
||||
|
||||
"""
|
||||
self._reset_callbacks()
|
||||
self.callbacks = deepcopy(callbacks.default_callbacks)
|
||||
self.predictor = None # reuse predictor
|
||||
self.model = None # model object
|
||||
self.trainer = None # trainer object
|
||||
@ -91,7 +92,7 @@ class YOLO:
|
||||
model = str(model).strip() # strip spaces
|
||||
|
||||
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
||||
if model.startswith('https://hub.ultralytics.com/models/'):
|
||||
if self.is_hub_model(model):
|
||||
from ultralytics.hub.session import HUBTrainingSession
|
||||
self.session = HUBTrainingSession(model)
|
||||
model = self.session.model_file
|
||||
@ -112,6 +113,13 @@ class YOLO:
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
@staticmethod
|
||||
def is_hub_model(model):
|
||||
return any((
|
||||
model.startswith('https://hub.ultralytics.com/models/'),
|
||||
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
|
||||
(len(model) == 20 and not Path(model).exists() and not any(x in model for x in './\\')))) # MODELID
|
||||
|
||||
def _new(self, cfg: str, task=None, verbose=True):
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
@ -220,8 +228,7 @@ class YOLO:
|
||||
if source is None:
|
||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
||||
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and \
|
||||
('predict' in sys.argv or 'mode=predict' in sys.argv)
|
||||
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
|
||||
|
||||
overrides = self.overrides.copy()
|
||||
overrides['conf'] = 0.25
|
||||
@ -231,7 +238,7 @@ class YOLO:
|
||||
overrides['save'] = kwargs.get('save', False) # not save files by default
|
||||
if not self.predictor:
|
||||
self.task = overrides.get('task') or self.task
|
||||
self.predictor = TASK_MAP[self.task][3](overrides=overrides)
|
||||
self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
|
||||
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
||||
else: # only update args if predictor is already setup
|
||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||
@ -380,19 +387,17 @@ class YOLO:
|
||||
"""
|
||||
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||
|
||||
@staticmethod
|
||||
def add_callback(event: str, func):
|
||||
def add_callback(self, event: str, func):
|
||||
"""
|
||||
Add callback
|
||||
"""
|
||||
callbacks.default_callbacks[event].append(func)
|
||||
self.callbacks[event].append(func)
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
|
||||
return {k: v for k, v in args.items() if k in include}
|
||||
|
||||
@staticmethod
|
||||
def _reset_callbacks():
|
||||
def _reset_callbacks(self):
|
||||
for event in callbacks.default_callbacks.keys():
|
||||
callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||
|
@ -75,7 +75,7 @@ class BasePredictor:
|
||||
data_path (str): Path to data.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initializes the BasePredictor class.
|
||||
|
||||
@ -104,7 +104,7 @@ class BasePredictor:
|
||||
self.data_path = None
|
||||
self.source_type = None
|
||||
self.batch = None
|
||||
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||
self.callbacks = defaultdict(list, _callbacks) if _callbacks else defaultdict(list, callbacks.default_callbacks)
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def preprocess(self, img):
|
||||
|
@ -12,7 +12,7 @@ import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, SimpleClass, ops
|
||||
from ultralytics.yolo.utils import LOGGER, SimpleClass, deprecation_warn, ops
|
||||
from ultralytics.yolo.utils.plotting import Annotator, colors
|
||||
from ultralytics.yolo.utils.torch_utils import TORCHVISION_0_10
|
||||
|
||||
@ -65,7 +65,7 @@ class Results(SimpleClass):
|
||||
self.boxes = Boxes(boxes, self.orig_shape)
|
||||
if masks is not None:
|
||||
self.masks = Masks(masks, self.orig_shape)
|
||||
if boxes is not None:
|
||||
if probs is not None:
|
||||
self.probs = probs
|
||||
|
||||
def cpu(self):
|
||||
@ -100,46 +100,72 @@ class Results(SimpleClass):
|
||||
def keys(self):
|
||||
return [k for k in self._keys if getattr(self, k) is not None]
|
||||
|
||||
def plot(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
||||
def plot(
|
||||
self,
|
||||
conf=True,
|
||||
line_width=None,
|
||||
font_size=None,
|
||||
font='Arial.ttf',
|
||||
pil=False,
|
||||
example='abc',
|
||||
img=None,
|
||||
labels=True,
|
||||
boxes=True,
|
||||
masks=True,
|
||||
probs=True,
|
||||
**kwargs # deprecated args TODO: remove support in 8.2
|
||||
):
|
||||
"""
|
||||
Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
|
||||
|
||||
Args:
|
||||
show_conf (bool): Whether to show the detection confidence score.
|
||||
conf (bool): Whether to plot the detection confidence score.
|
||||
line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
|
||||
font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
|
||||
font (str): The font to use for the text.
|
||||
pil (bool): Whether to return the image as a PIL Image.
|
||||
example (str): An example string to display. Useful for indicating the expected format of the output.
|
||||
img (numpy.ndarray): Plot to another image. if not, plot to original image.
|
||||
labels (bool): Whether to plot the label of bounding boxes.
|
||||
boxes (bool): Whether to plot the bounding boxes.
|
||||
masks (bool): Whether to plot the masks.
|
||||
probs (bool): Whether to plot classification probability
|
||||
|
||||
Returns:
|
||||
(None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned.
|
||||
"""
|
||||
annotator = Annotator(deepcopy(self.orig_img), line_width, font_size, font, pil, example)
|
||||
boxes = self.boxes
|
||||
masks = self.masks
|
||||
probs = self.probs
|
||||
# Deprecation warn TODO: remove in 8.2
|
||||
if 'show_conf' in kwargs:
|
||||
deprecation_warn('show_conf', 'conf')
|
||||
conf = kwargs['show_conf']
|
||||
assert type(conf) == bool, '`show_conf` should be of boolean type, i.e, show_conf=True/False'
|
||||
|
||||
annotator = Annotator(deepcopy(self.orig_img if img is None else img), line_width, font_size, font, pil,
|
||||
example)
|
||||
pred_boxes, show_boxes = self.boxes, boxes
|
||||
pred_masks, show_masks = self.masks, masks
|
||||
pred_probs, show_probs = self.probs, probs
|
||||
names = self.names
|
||||
hide_labels, hide_conf = False, not show_conf
|
||||
if boxes is not None:
|
||||
for d in reversed(boxes):
|
||||
c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
|
||||
if pred_boxes and show_boxes:
|
||||
for d in reversed(pred_boxes):
|
||||
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
|
||||
name = ('' if id is None else f'id:{id} ') + names[c]
|
||||
label = None if hide_labels else (name if hide_conf else f'{name} {conf:.2f}')
|
||||
label = (name if not conf else f'{name} {conf:.2f}') if labels else None
|
||||
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||
|
||||
if masks is not None:
|
||||
im = torch.as_tensor(annotator.im, dtype=torch.float16, device=masks.data.device).permute(2, 0, 1).flip(0)
|
||||
if pred_masks and show_masks:
|
||||
im = torch.as_tensor(annotator.im, dtype=torch.float16, device=pred_masks.data.device).permute(2, 0,
|
||||
1).flip(0)
|
||||
if TORCHVISION_0_10:
|
||||
im = F.resize(im.contiguous(), masks.data.shape[1:], antialias=True) / 255
|
||||
im = F.resize(im.contiguous(), pred_masks.data.shape[1:], antialias=True) / 255
|
||||
else:
|
||||
im = F.resize(im.contiguous(), masks.data.shape[1:]) / 255
|
||||
annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im)
|
||||
im = F.resize(im.contiguous(), pred_masks.data.shape[1:]) / 255
|
||||
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in pred_boxes.cls], im_gpu=im)
|
||||
|
||||
if probs is not None:
|
||||
if pred_probs is not None and show_probs:
|
||||
n5 = min(len(names), 5)
|
||||
top5i = probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices
|
||||
text = f"{', '.join(f'{names[j] if names else j} {probs[j]:.2f}' for j in top5i)}, "
|
||||
top5i = pred_probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices
|
||||
text = f"{', '.join(f'{names[j] if names else j} {pred_probs[j]:.2f}' for j in top5i)}, "
|
||||
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
|
||||
|
||||
return np.asarray(annotator.im) if annotator.pil else annotator.im
|
||||
|
@ -624,7 +624,8 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.3'):
|
||||
|
||||
# Check that settings keys and types match defaults
|
||||
correct = \
|
||||
settings.keys() == defaults.keys() \
|
||||
settings \
|
||||
and settings.keys() == defaults.keys() \
|
||||
and all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) \
|
||||
and check_version(settings['settings_version'], version)
|
||||
if not correct:
|
||||
@ -646,6 +647,14 @@ def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
|
||||
yaml_save(file, SETTINGS)
|
||||
|
||||
|
||||
def deprecation_warn(arg, new_arg, version=None):
|
||||
if not version:
|
||||
version = float(__version__[0:3]) + 0.2 # deprecate after 2nd major release
|
||||
LOGGER.warning(
|
||||
f'WARNING: `{arg}` is deprecated and will be removed in upcoming major release {version}. Use `{new_arg}` instead'
|
||||
)
|
||||
|
||||
|
||||
# Run below code on yolo/utils init ------------------------------------------------------------------------------------
|
||||
|
||||
# Check first-install steps
|
||||
|
@ -70,7 +70,7 @@ class DetectionPredictor(BasePredictor):
|
||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||
if self.args.save or self.args.show: # Add bbox to image
|
||||
name = ('' if id is None else f'id:{id} ') + self.model.names[c]
|
||||
label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
|
||||
label = (f'{name} {conf:.2f}' if self.args.show_conf else name) if self.args.show_labels else None
|
||||
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||
if self.args.save_crop:
|
||||
save_one_box(d.xyxy,
|
||||
|
@ -84,7 +84,7 @@ class SegmentationPredictor(DetectionPredictor):
|
||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||
if self.args.save or self.args.show: # Add bbox to image
|
||||
name = ('' if id is None else f'id:{id} ') + self.model.names[c]
|
||||
label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
|
||||
label = (f'{name} {conf:.2f}' if self.args.show_conf else name) if self.args.show_labels else None
|
||||
if self.args.boxes:
|
||||
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||
if self.args.save_crop:
|
||||
|
Reference in New Issue
Block a user