ultralytics 8.0.58 new SimpleClass, fixes and updates (#1636)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-03-26 22:16:38 +02:00
committed by GitHub
parent ef03e6732a
commit ec10002a4a
30 changed files with 351 additions and 314 deletions

View File

@ -22,11 +22,15 @@ import yaml
from ultralytics import __version__
# Constants
# PyTorch Multi-GPU DDP Constants
RANK = int(os.getenv('RANK', -1))
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
# Other Constants
FILE = Path(__file__).resolve()
ROOT = FILE.parents[2] # YOLO
DEFAULT_CFG_PATH = ROOT / 'yolo/cfg/default.yaml'
RANK = int(os.getenv('RANK', -1))
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode
@ -92,25 +96,59 @@ HELP_MSG = \
"""
# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
torch.set_printoptions(linewidth=320, precision=4, profile='default')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
class SimpleClass:
"""
Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
access methods for easier debugging and usage.
"""
def __str__(self):
"""Return a human-readable string representation of the object."""
attr = []
for a in dir(self):
v = getattr(self, a)
if not callable(v) and not a.startswith('__'):
if isinstance(v, SimpleClass):
# Display only the module and class name for subclasses
s = f'{a}: {v.__module__}.{v.__class__.__name__} object'
else:
s = f'{a}: {repr(v)}'
attr.append(s)
return f'{self.__module__}.{self.__class__.__name__} object with attributes:\n\n' + '\n'.join(attr)
def __repr__(self):
"""Return a machine-readable string representation of the object."""
return self.__str__()
def __getattr__(self, attr):
"""Custom attribute access error message with helpful information."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
class IterableSimpleNamespace(SimpleNamespace):
"""
Iterable SimpleNamespace class to allow SimpleNamespace to be used with dict() and in for loops
Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
enables usage with dict() and for loops.
"""
def __iter__(self):
"""Return an iterator of key-value pairs from the namespace's attributes."""
return iter(vars(self).items())
def __str__(self):
"""Return a human-readable string representation of the object."""
return '\n'.join(f'{k}={v}' for k, v in vars(self).items())
def __getattr__(self, attr):
"""Custom attribute access error message with helpful information."""
name = self.__class__.__name__
raise AttributeError(f"""
'{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics
@ -120,6 +158,7 @@ class IterableSimpleNamespace(SimpleNamespace):
""")
def get(self, key, default=None):
"""Return the value of the specified key if it exists; otherwise, return the default value."""
return getattr(self, key, default)

View File

@ -8,7 +8,7 @@ try:
assert clearml.__version__ # verify package is not directory
assert not TESTS_RUNNING # do not log pytest
except (ImportError, AssertionError):
except (ImportError, AssertionError, AttributeError):
clearml = None

View File

@ -7,7 +7,7 @@ try:
assert not TESTS_RUNNING # do not log pytest
assert comet_ml.__version__ # verify package is not directory
except (ImportError, AssertionError):
except (ImportError, AssertionError, AttributeError):
comet_ml = None

View File

@ -239,7 +239,7 @@ def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
if isinstance(suffix, str):
suffix = (suffix, )
for f in file if isinstance(file, (list, tuple)) else [file]:
s = Path(f).suffix.lower() # file suffix
s = Path(f).suffix.lower().strip() # file suffix
if len(s):
assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}, not {s}'
@ -261,7 +261,7 @@ def check_yolov5u_filename(file: str, verbose: bool = True):
def check_file(file, suffix='', download=True, hard=True):
# Search/download file (if necessary) and return path
check_suffix(file, suffix) # optional
file = str(file) # convert to string
file = str(file).strip() # convert to string and strip spaces
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
if not file or ('://' not in file and Path(file).exists()): # exists ('://' check required in Windows Python<3.10)
return file

View File

@ -11,7 +11,7 @@ import numpy as np
import torch
import torch.nn as nn
from ultralytics.yolo.utils import LOGGER, TryExcept
from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept
# boxes
@ -425,7 +425,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na
return tp, fp, p, r, f1, ap, unique_classes.astype(int)
class Metric:
class Metric(SimpleClass):
"""
Class for computing evaluation metrics for YOLOv8 model.
@ -461,10 +461,6 @@ class Metric:
self.ap_class_index = [] # (nc, )
self.nc = 0
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@property
def ap50(self):
"""AP@0.5 of all classes.
@ -550,7 +546,7 @@ class Metric:
self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results
class DetMetrics:
class DetMetrics(SimpleClass):
"""
This class is a utility class for computing detection metrics such as precision, recall, and mean average precision
(mAP) of an object detection model.
@ -585,10 +581,6 @@ class DetMetrics:
self.box = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, tp, conf, pred_cls, target_cls):
results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir,
names=self.names)[2:]
@ -622,7 +614,7 @@ class DetMetrics:
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
class SegmentMetrics:
class SegmentMetrics(SimpleClass):
"""
Calculates and aggregates detection and segmentation metrics over a given set of classes.
@ -657,10 +649,6 @@ class SegmentMetrics:
self.seg = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, tp_m, tp_b, conf, pred_cls, target_cls):
"""
Processes the detection and segmentation metrics over the given set of predictions.
@ -724,7 +712,7 @@ class SegmentMetrics:
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
class ClassifyMetrics:
class ClassifyMetrics(SimpleClass):
"""
Class for computing classification metrics including top-1 and top-5 accuracy.
@ -747,10 +735,6 @@ class ClassifyMetrics:
self.top5 = 0
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, targets, pred):
# target classes and predicted classes
pred, targets = torch.cat(pred), torch.cat(targets)

View File

@ -295,7 +295,7 @@ def plot_images(images,
for j, box in enumerate(boxes.T.tolist()):
c = classes[j]
color = colors(c)
c = names[c] if names else c
c = names.get(c, c) if names else c
if labels or conf[j] > 0.25: # 0.25 conf thresh
label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
annotator.box_label(box, label, color=color)

View File

@ -8,6 +8,7 @@ import time
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
from typing import Union
import numpy as np
import thop
@ -15,15 +16,10 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, __version__
from ultralytics.yolo.utils.checks import check_version
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
TORCH_1_11 = check_version(torch.__version__, '1.11.0')
TORCH_1_12 = check_version(torch.__version__, '1.12.0')
@ -49,17 +45,6 @@ def smart_inference_mode():
return decorate
def DDP_model(model):
# Model DDP creation with checks
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
if TORCH_1_11:
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
else:
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
def select_device(device='', batch=0, newline=False, verbose=True):
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
@ -141,6 +126,7 @@ def fuse_conv_and_bn(conv, bn):
def fuse_deconv_and_bn(deconv, bn):
# Fuse ConvTranspose2d() and BatchNorm2d() layers
fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
deconv.out_channels,
kernel_size=deconv.kernel_size,
@ -186,14 +172,17 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
def get_num_params(model):
# Return the total number of parameters in a YOLO model
return sum(x.numel() for x in model.parameters())
def get_num_gradients(model):
# Return the total number of parameters with gradients in a YOLO model
return sum(x.numel() for x in model.parameters() if x.requires_grad)
def get_flops(model, imgsz=640):
# Return a YOLO model's FLOPs
try:
model = de_parallel(model)
p = next(model.parameters())
@ -208,6 +197,7 @@ def get_flops(model, imgsz=640):
def initialize_weights(model):
# Initialize model weights to random values
for m in model.modules():
t = type(m)
if t is nn.Conv2d:
@ -239,7 +229,7 @@ def make_divisible(x, divisor):
def copy_attr(a, b, include=(), exclude=()):
# Copy attributes from b to a, options to only include [...] and to exclude [...]
# Copy attributes from 'b' to 'a', options to only include [...] and to exclude [...]
for k, v in b.__dict__.items():
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
continue
@ -322,7 +312,7 @@ class ModelEMA:
copy_attr(self.ema, model, include, exclude)
def strip_optimizer(f='best.pt', s=''):
def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
"""
Strip optimizer from 'f' to finalize training, optionally save as 's'.