ultralytics 8.0.58 new SimpleClass, fixes and updates (#1636)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-03-26 22:16:38 +02:00
committed by GitHub
parent ef03e6732a
commit ec10002a4a
30 changed files with 351 additions and 314 deletions

View File

@ -12,8 +12,8 @@ import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
from .utils import HELP_URL, IMG_FORMATS, LOCAL_RANK
from ..utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT
from .utils import HELP_URL, IMG_FORMATS
class BaseDataset(Dataset):

View File

@ -14,10 +14,10 @@ from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImage
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils.checks import check_file
from ..utils import LOGGER, colorstr
from ..utils import LOGGER, RANK, colorstr
from ..utils.torch_utils import torch_distributed_zero_first
from .dataset import ClassificationDataset, YOLODataset
from .utils import PIN_MEMORY, RANK
from .utils import PIN_MEMORY
class InfiniteDataLoader(dataloader.DataLoader):

View File

@ -335,6 +335,7 @@ class LoadTensor:
def __init__(self, imgs) -> None:
self.im0 = imgs
self.bs = imgs.shape[0]
self.mode = 'image'
def __iter__(self):
self.count = 0
@ -346,6 +347,9 @@ class LoadTensor:
self.count += 1
return None, self.im0, self.im0, None, '' # self.paths, im, self.im0, None, ''
def __len__(self):
return self.bs
def autocast_list(source):
"""

View File

@ -10,10 +10,10 @@ import torch
import torchvision
from tqdm import tqdm
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
from ..utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
from .base import BaseDataset
from .utils import HELP_URL, LOCAL_RANK, LOGGER, get_hash, img2label_paths, verify_image_label
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label
class YOLODataset(BaseDataset):

View File

@ -25,8 +25,6 @@ from ultralytics.yolo.utils.ops import segments2boxes
HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm' # video suffixes
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation

View File

@ -2,6 +2,7 @@
import sys
from pathlib import Path
from typing import Union
from ultralytics import yolo # noqa
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
@ -67,7 +68,7 @@ class YOLO:
list(ultralytics.yolo.engine.results.Results): The prediction results.
"""
def __init__(self, model='yolov8n.pt', task=None, session=None) -> None:
def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None, session=None) -> None:
"""
Initializes the YOLO model.
@ -87,6 +88,7 @@ class YOLO:
self.session = session # HUB session
# Load or create new YOLO model
model = str(model).strip() # strip spaces
suffix = Path(model).suffix
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt

View File

@ -42,6 +42,18 @@ from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
STREAM_WARNING = """
WARNING ⚠️ stream/video/webcam/dir predict source will accumulate results in RAM unless `stream=True` is passed,
causing potential out-of-memory errors for large sources or long-running streams/videos.
Usage:
results = model(source=..., stream=True) # generator of Results objects
for r in results:
boxes = r.boxes # Boxes object for bbox outputs
masks = r.masks # Masks object for segment masks outputs
probs = r.probs # Class probabilities for classification outputs
"""
class BasePredictor:
"""
@ -108,6 +120,7 @@ class BasePredictor:
return preds
def __call__(self, source=None, model=None, stream=False):
self.stream = stream
if stream:
return self.stream_inference(source, model)
else:
@ -132,6 +145,10 @@ class BasePredictor:
stride=self.model.stride,
auto=self.model.pt)
self.source_type = self.dataset.source_type
if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams
len(self.dataset) > 1000 or # images
any(getattr(self.dataset, 'video_flag', [False]))): # videos
LOGGER.warning(STREAM_WARNING)
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
@smart_inference_mode()

View File

@ -5,7 +5,6 @@ Ultralytics Results, Boxes and Masks classes for handling inference results
Usage: See https://docs.ultralytics.com/modes/predict/
"""
import pprint
from copy import deepcopy
from functools import lru_cache
@ -13,11 +12,11 @@ import numpy as np
import torch
import torchvision.transforms.functional as F
from ultralytics.yolo.utils import LOGGER, ops
from ultralytics.yolo.utils import LOGGER, SimpleClass, ops
from ultralytics.yolo.utils.plotting import Annotator, colors
class Results:
class Results(SimpleClass):
"""
A class for storing and manipulating inference results.
@ -96,17 +95,6 @@ class Results:
for k in self.keys:
return len(getattr(self, k))
def __str__(self):
attr = {k: v for k, v in vars(self).items() if not isinstance(v, type(self))}
return pprint.pformat(attr, indent=2, width=120, depth=10, compact=True)
def __repr__(self):
return self.__str__()
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@property
def keys(self):
return [k for k in self._keys if getattr(self, k) is not None]
@ -153,7 +141,7 @@ class Results:
return np.asarray(annotator.im) if annotator.pil else annotator.im
class Boxes:
class Boxes(SimpleClass):
"""
A class for storing and manipulating detection boxes.
@ -242,15 +230,6 @@ class Boxes:
def pandas(self):
LOGGER.info('results.pandas() method not yet implemented')
'''
new = copy(self) # return copy
ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
return new
'''
@property
def shape(self):
@ -263,25 +242,11 @@ class Boxes:
def __len__(self): # override len(results)
return len(self.boxes)
def __str__(self):
return self.boxes.__str__()
def __repr__(self):
return (f'{self.__class__.__module__}.{self.__class__.__name__}\n'
f'type: {self.boxes.__class__.__module__}.{self.boxes.__class__.__name__}\n'
f'shape: {self.boxes.shape}\n'
f'dtype: {self.boxes.dtype}\n'
f'{self.boxes.__repr__()}')
def __getitem__(self, idx):
return Boxes(self.boxes[idx], self.orig_shape)
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
class Masks:
class Masks(SimpleClass):
"""
A class for storing and manipulating detection masks.
@ -301,11 +266,6 @@ class Masks:
numpy(): Returns a copy of the masks tensor as a numpy array.
cuda(): Returns a copy of the masks tensor on GPU memory.
to(): Returns a copy of the masks tensor with the specified device and dtype.
__len__(): Returns the number of masks in the tensor.
__str__(): Returns a string representation of the masks tensor.
__repr__(): Returns a detailed string representation of the masks tensor.
__getitem__(): Returns a new Masks object with the masks at the specified index.
__getattr__(): Raises an AttributeError with a list of valid attributes and properties.
"""
def __init__(self, masks, orig_shape) -> None:
@ -342,19 +302,5 @@ class Masks:
def __len__(self): # override len(results)
return len(self.masks)
def __str__(self):
return self.masks.__str__()
def __repr__(self):
return (f'{self.__class__.__module__}.{self.__class__.__name__}\n'
f'type: {self.masks.__class__.__module__}.{self.masks.__class__.__name__}\n'
f'shape: {self.masks.shape}\n'
f'dtype: {self.masks.dtype}\n'
f'{self.masks.__repr__()}')
def __getitem__(self, idx):
return Masks(self.masks[idx], self.orig_shape)
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

View File

@ -174,7 +174,12 @@ class BaseTrainer:
# Run subprocess if DDP training, else train normally
if world_size > 1 and 'LOCAL_RANK' not in os.environ:
cmd, file = generate_ddp_command(world_size, self) # security vulnerability in Snyk scans
# Argument checks
if self.args.rect:
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting rect=False")
self.args.rect = False
# Command
cmd, file = generate_ddp_command(world_size, self)
try:
LOGGER.info(f'Running DDP command {cmd}')
subprocess.run(cmd, check=True)
@ -183,17 +188,15 @@ class BaseTrainer:
finally:
ddp_cleanup(self, str(file))
else:
self._do_train(RANK, world_size)
self._do_train(world_size)
def _setup_ddp(self, rank, world_size):
# os.environ['MASTER_ADDR'] = 'localhost'
# os.environ['MASTER_PORT'] = '9020'
torch.cuda.set_device(rank)
self.device = torch.device('cuda', rank)
LOGGER.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size)
def _setup_ddp(self, world_size):
torch.cuda.set_device(RANK)
self.device = torch.device('cuda', RANK)
LOGGER.info(f'DDP settings: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=RANK, world_size=world_size)
def _setup_train(self, rank, world_size):
def _setup_train(self, world_size):
"""
Builds dataloaders and optimizer on correct rank process.
"""
@ -213,7 +216,7 @@ class BaseTrainer:
self.amp = bool(self.amp) # as boolean
self.scaler = amp.GradScaler(enabled=self.amp)
if world_size > 1:
self.model = DDP(self.model, device_ids=[rank])
self.model = DDP(self.model, device_ids=[RANK])
# Check imgsz
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
@ -243,8 +246,8 @@ class BaseTrainer:
# dataloaders
batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode='train')
if rank in (-1, 0):
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
if RANK in (-1, 0):
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
self.validator = self.get_validator()
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
@ -256,11 +259,11 @@ class BaseTrainer:
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
self.run_callbacks('on_pretrain_routine_end')
def _do_train(self, rank=-1, world_size=1):
def _do_train(self, world_size=1):
if world_size > 1:
self._setup_ddp(rank, world_size)
self._setup_ddp(world_size)
self._setup_train(rank, world_size)
self._setup_train(world_size)
self.epoch_time = None
self.epoch_time_start = time.time()
@ -280,7 +283,7 @@ class BaseTrainer:
self.epoch = epoch
self.run_callbacks('on_train_epoch_start')
self.model.train()
if rank != -1:
if RANK != -1:
self.train_loader.sampler.set_epoch(epoch)
pbar = enumerate(self.train_loader)
# Update dataloader attributes (optional)
@ -291,7 +294,7 @@ class BaseTrainer:
if hasattr(self.train_loader.dataset, 'close_mosaic'):
self.train_loader.dataset.close_mosaic(hyp=self.args)
if rank in (-1, 0):
if RANK in (-1, 0):
LOGGER.info(self.progress_string())
pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
self.tloss = None
@ -315,7 +318,7 @@ class BaseTrainer:
batch = self.preprocess_batch(batch)
preds = self.model(batch['img'])
self.loss, self.loss_items = self.criterion(preds, batch)
if rank != -1:
if RANK != -1:
self.loss *= world_size
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
else self.loss_items
@ -332,7 +335,7 @@ class BaseTrainer:
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
if rank in (-1, 0):
if RANK in (-1, 0):
pbar.set_description(
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
@ -347,7 +350,7 @@ class BaseTrainer:
self.scheduler.step()
self.run_callbacks('on_train_epoch_end')
if rank in (-1, 0):
if RANK in (-1, 0):
# Validation
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
@ -377,7 +380,7 @@ class BaseTrainer:
if self.stop:
break # must break all DDP ranks
if rank in (-1, 0):
if RANK in (-1, 0):
# Do final val with best.pt
LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
@ -408,7 +411,8 @@ class BaseTrainer:
torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')
del ckpt
def get_dataset(self, data):
@staticmethod
def get_dataset(data):
"""
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
"""

View File

@ -22,11 +22,15 @@ import yaml
from ultralytics import __version__
# Constants
# PyTorch Multi-GPU DDP Constants
RANK = int(os.getenv('RANK', -1))
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
# Other Constants
FILE = Path(__file__).resolve()
ROOT = FILE.parents[2] # YOLO
DEFAULT_CFG_PATH = ROOT / 'yolo/cfg/default.yaml'
RANK = int(os.getenv('RANK', -1))
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode
@ -92,25 +96,59 @@ HELP_MSG = \
"""
# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
torch.set_printoptions(linewidth=320, precision=4, profile='default')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
class SimpleClass:
"""
Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
access methods for easier debugging and usage.
"""
def __str__(self):
"""Return a human-readable string representation of the object."""
attr = []
for a in dir(self):
v = getattr(self, a)
if not callable(v) and not a.startswith('__'):
if isinstance(v, SimpleClass):
# Display only the module and class name for subclasses
s = f'{a}: {v.__module__}.{v.__class__.__name__} object'
else:
s = f'{a}: {repr(v)}'
attr.append(s)
return f'{self.__module__}.{self.__class__.__name__} object with attributes:\n\n' + '\n'.join(attr)
def __repr__(self):
"""Return a machine-readable string representation of the object."""
return self.__str__()
def __getattr__(self, attr):
"""Custom attribute access error message with helpful information."""
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
class IterableSimpleNamespace(SimpleNamespace):
"""
Iterable SimpleNamespace class to allow SimpleNamespace to be used with dict() and in for loops
Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
enables usage with dict() and for loops.
"""
def __iter__(self):
"""Return an iterator of key-value pairs from the namespace's attributes."""
return iter(vars(self).items())
def __str__(self):
"""Return a human-readable string representation of the object."""
return '\n'.join(f'{k}={v}' for k, v in vars(self).items())
def __getattr__(self, attr):
"""Custom attribute access error message with helpful information."""
name = self.__class__.__name__
raise AttributeError(f"""
'{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics
@ -120,6 +158,7 @@ class IterableSimpleNamespace(SimpleNamespace):
""")
def get(self, key, default=None):
"""Return the value of the specified key if it exists; otherwise, return the default value."""
return getattr(self, key, default)

View File

@ -8,7 +8,7 @@ try:
assert clearml.__version__ # verify package is not directory
assert not TESTS_RUNNING # do not log pytest
except (ImportError, AssertionError):
except (ImportError, AssertionError, AttributeError):
clearml = None

View File

@ -7,7 +7,7 @@ try:
assert not TESTS_RUNNING # do not log pytest
assert comet_ml.__version__ # verify package is not directory
except (ImportError, AssertionError):
except (ImportError, AssertionError, AttributeError):
comet_ml = None

View File

@ -239,7 +239,7 @@ def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
if isinstance(suffix, str):
suffix = (suffix, )
for f in file if isinstance(file, (list, tuple)) else [file]:
s = Path(f).suffix.lower() # file suffix
s = Path(f).suffix.lower().strip() # file suffix
if len(s):
assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}, not {s}'
@ -261,7 +261,7 @@ def check_yolov5u_filename(file: str, verbose: bool = True):
def check_file(file, suffix='', download=True, hard=True):
# Search/download file (if necessary) and return path
check_suffix(file, suffix) # optional
file = str(file) # convert to string
file = str(file).strip() # convert to string and strip spaces
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
if not file or ('://' not in file and Path(file).exists()): # exists ('://' check required in Windows Python<3.10)
return file

View File

@ -11,7 +11,7 @@ import numpy as np
import torch
import torch.nn as nn
from ultralytics.yolo.utils import LOGGER, TryExcept
from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept
# boxes
@ -425,7 +425,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na
return tp, fp, p, r, f1, ap, unique_classes.astype(int)
class Metric:
class Metric(SimpleClass):
"""
Class for computing evaluation metrics for YOLOv8 model.
@ -461,10 +461,6 @@ class Metric:
self.ap_class_index = [] # (nc, )
self.nc = 0
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@property
def ap50(self):
"""AP@0.5 of all classes.
@ -550,7 +546,7 @@ class Metric:
self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results
class DetMetrics:
class DetMetrics(SimpleClass):
"""
This class is a utility class for computing detection metrics such as precision, recall, and mean average precision
(mAP) of an object detection model.
@ -585,10 +581,6 @@ class DetMetrics:
self.box = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, tp, conf, pred_cls, target_cls):
results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir,
names=self.names)[2:]
@ -622,7 +614,7 @@ class DetMetrics:
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
class SegmentMetrics:
class SegmentMetrics(SimpleClass):
"""
Calculates and aggregates detection and segmentation metrics over a given set of classes.
@ -657,10 +649,6 @@ class SegmentMetrics:
self.seg = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, tp_m, tp_b, conf, pred_cls, target_cls):
"""
Processes the detection and segmentation metrics over the given set of predictions.
@ -724,7 +712,7 @@ class SegmentMetrics:
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
class ClassifyMetrics:
class ClassifyMetrics(SimpleClass):
"""
Class for computing classification metrics including top-1 and top-5 accuracy.
@ -747,10 +735,6 @@ class ClassifyMetrics:
self.top5 = 0
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, targets, pred):
# target classes and predicted classes
pred, targets = torch.cat(pred), torch.cat(targets)

View File

@ -295,7 +295,7 @@ def plot_images(images,
for j, box in enumerate(boxes.T.tolist()):
c = classes[j]
color = colors(c)
c = names[c] if names else c
c = names.get(c, c) if names else c
if labels or conf[j] > 0.25: # 0.25 conf thresh
label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
annotator.box_label(box, label, color=color)

View File

@ -8,6 +8,7 @@ import time
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
from typing import Union
import numpy as np
import thop
@ -15,15 +16,10 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, __version__
from ultralytics.yolo.utils.checks import check_version
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
TORCH_1_11 = check_version(torch.__version__, '1.11.0')
TORCH_1_12 = check_version(torch.__version__, '1.12.0')
@ -49,17 +45,6 @@ def smart_inference_mode():
return decorate
def DDP_model(model):
# Model DDP creation with checks
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
if TORCH_1_11:
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
else:
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
def select_device(device='', batch=0, newline=False, verbose=True):
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
@ -141,6 +126,7 @@ def fuse_conv_and_bn(conv, bn):
def fuse_deconv_and_bn(deconv, bn):
# Fuse ConvTranspose2d() and BatchNorm2d() layers
fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
deconv.out_channels,
kernel_size=deconv.kernel_size,
@ -186,14 +172,17 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
def get_num_params(model):
# Return the total number of parameters in a YOLO model
return sum(x.numel() for x in model.parameters())
def get_num_gradients(model):
# Return the total number of parameters with gradients in a YOLO model
return sum(x.numel() for x in model.parameters() if x.requires_grad)
def get_flops(model, imgsz=640):
# Return a YOLO model's FLOPs
try:
model = de_parallel(model)
p = next(model.parameters())
@ -208,6 +197,7 @@ def get_flops(model, imgsz=640):
def initialize_weights(model):
# Initialize model weights to random values
for m in model.modules():
t = type(m)
if t is nn.Conv2d:
@ -239,7 +229,7 @@ def make_divisible(x, divisor):
def copy_attr(a, b, include=(), exclude=()):
# Copy attributes from b to a, options to only include [...] and to exclude [...]
# Copy attributes from 'b' to 'a', options to only include [...] and to exclude [...]
for k, v in b.__dict__.items():
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
continue
@ -322,7 +312,7 @@ class ModelEMA:
copy_attr(self.ema, model, include, exclude)
def strip_optimizer(f='best.pt', s=''):
def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
"""
Strip optimizer from 'f' to finalize training, optionally save as 's'.

View File

@ -126,11 +126,11 @@ class SegLoss(Loss):
# WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss[1] += proto.sum() * 0 + pred_masks.sum() * 0
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss[1] += proto.sum() * 0 + pred_masks.sum() * 0
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.box / batch_size # seg gain