ultralytics 8.0.58 new SimpleClass, fixes and updates (#1636)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-03-26 22:16:38 +02:00
committed by GitHub
parent ef03e6732a
commit ec10002a4a
30 changed files with 351 additions and 314 deletions

View File

@ -2,6 +2,7 @@
import sys
from pathlib import Path
from typing import Union
from ultralytics import yolo # noqa
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
@ -67,7 +68,7 @@ class YOLO:
list(ultralytics.yolo.engine.results.Results): The prediction results.
"""
def __init__(self, model='yolov8n.pt', task=None, session=None) -> None:
def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None, session=None) -> None:
"""
Initializes the YOLO model.
@ -87,6 +88,7 @@ class YOLO:
self.session = session # HUB session
# Load or create new YOLO model
model = str(model).strip() # strip spaces
suffix = Path(model).suffix
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt

View File

@ -42,6 +42,18 @@ from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
STREAM_WARNING = """
WARNING ⚠️ stream/video/webcam/dir predict source will accumulate results in RAM unless `stream=True` is passed,
causing potential out-of-memory errors for large sources or long-running streams/videos.
Usage:
results = model(source=..., stream=True) # generator of Results objects
for r in results:
boxes = r.boxes # Boxes object for bbox outputs
masks = r.masks # Masks object for segment masks outputs
probs = r.probs # Class probabilities for classification outputs
"""
class BasePredictor:
"""
@ -108,6 +120,7 @@ class BasePredictor:
return preds
def __call__(self, source=None, model=None, stream=False):
self.stream = stream
if stream:
return self.stream_inference(source, model)
else:
@ -132,6 +145,10 @@ class BasePredictor:
stride=self.model.stride,
auto=self.model.pt)
self.source_type = self.dataset.source_type
if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams
len(self.dataset) > 1000 or # images
any(getattr(self.dataset, 'video_flag', [False]))): # videos
LOGGER.warning(STREAM_WARNING)
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
@smart_inference_mode()

View File

@ -5,7 +5,6 @@ Ultralytics Results, Boxes and Masks classes for handling inference results
Usage: See https://docs.ultralytics.com/modes/predict/
"""
import pprint
from copy import deepcopy
from functools import lru_cache
@ -13,11 +12,11 @@ import numpy as np
import torch
import torchvision.transforms.functional as F
from ultralytics.yolo.utils import LOGGER, ops
from ultralytics.yolo.utils import LOGGER, SimpleClass, ops
from ultralytics.yolo.utils.plotting import Annotator, colors
class Results:
class Results(SimpleClass):
"""
A class for storing and manipulating inference results.
@ -96,17 +95,6 @@ class Results:
for k in self.keys:
return len(getattr(self, k))
def __str__(self):
attr = {k: v for k, v in vars(self).items() if not isinstance(v, type(self))}
return pprint.pformat(attr, indent=2, width=120, depth=10, compact=True)
def __repr__(self):
return self.__str__()
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@property
def keys(self):
return [k for k in self._keys if getattr(self, k) is not None]
@ -153,7 +141,7 @@ class Results:
return np.asarray(annotator.im) if annotator.pil else annotator.im
class Boxes:
class Boxes(SimpleClass):
"""
A class for storing and manipulating detection boxes.
@ -242,15 +230,6 @@ class Boxes:
def pandas(self):
LOGGER.info('results.pandas() method not yet implemented')
'''
new = copy(self) # return copy
ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
return new
'''
@property
def shape(self):
@ -263,25 +242,11 @@ class Boxes:
def __len__(self): # override len(results)
return len(self.boxes)
def __str__(self):
return self.boxes.__str__()
def __repr__(self):
return (f'{self.__class__.__module__}.{self.__class__.__name__}\n'
f'type: {self.boxes.__class__.__module__}.{self.boxes.__class__.__name__}\n'
f'shape: {self.boxes.shape}\n'
f'dtype: {self.boxes.dtype}\n'
f'{self.boxes.__repr__()}')
def __getitem__(self, idx):
return Boxes(self.boxes[idx], self.orig_shape)
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
class Masks:
class Masks(SimpleClass):
"""
A class for storing and manipulating detection masks.
@ -301,11 +266,6 @@ class Masks:
numpy(): Returns a copy of the masks tensor as a numpy array.
cuda(): Returns a copy of the masks tensor on GPU memory.
to(): Returns a copy of the masks tensor with the specified device and dtype.
__len__(): Returns the number of masks in the tensor.
__str__(): Returns a string representation of the masks tensor.
__repr__(): Returns a detailed string representation of the masks tensor.
__getitem__(): Returns a new Masks object with the masks at the specified index.
__getattr__(): Raises an AttributeError with a list of valid attributes and properties.
"""
def __init__(self, masks, orig_shape) -> None:
@ -342,19 +302,5 @@ class Masks:
def __len__(self): # override len(results)
return len(self.masks)
def __str__(self):
return self.masks.__str__()
def __repr__(self):
return (f'{self.__class__.__module__}.{self.__class__.__name__}\n'
f'type: {self.masks.__class__.__module__}.{self.masks.__class__.__name__}\n'
f'shape: {self.masks.shape}\n'
f'dtype: {self.masks.dtype}\n'
f'{self.masks.__repr__()}')
def __getitem__(self, idx):
return Masks(self.masks[idx], self.orig_shape)
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

View File

@ -174,7 +174,12 @@ class BaseTrainer:
# Run subprocess if DDP training, else train normally
if world_size > 1 and 'LOCAL_RANK' not in os.environ:
cmd, file = generate_ddp_command(world_size, self) # security vulnerability in Snyk scans
# Argument checks
if self.args.rect:
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting rect=False")
self.args.rect = False
# Command
cmd, file = generate_ddp_command(world_size, self)
try:
LOGGER.info(f'Running DDP command {cmd}')
subprocess.run(cmd, check=True)
@ -183,17 +188,15 @@ class BaseTrainer:
finally:
ddp_cleanup(self, str(file))
else:
self._do_train(RANK, world_size)
self._do_train(world_size)
def _setup_ddp(self, rank, world_size):
# os.environ['MASTER_ADDR'] = 'localhost'
# os.environ['MASTER_PORT'] = '9020'
torch.cuda.set_device(rank)
self.device = torch.device('cuda', rank)
LOGGER.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size)
def _setup_ddp(self, world_size):
torch.cuda.set_device(RANK)
self.device = torch.device('cuda', RANK)
LOGGER.info(f'DDP settings: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=RANK, world_size=world_size)
def _setup_train(self, rank, world_size):
def _setup_train(self, world_size):
"""
Builds dataloaders and optimizer on correct rank process.
"""
@ -213,7 +216,7 @@ class BaseTrainer:
self.amp = bool(self.amp) # as boolean
self.scaler = amp.GradScaler(enabled=self.amp)
if world_size > 1:
self.model = DDP(self.model, device_ids=[rank])
self.model = DDP(self.model, device_ids=[RANK])
# Check imgsz
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
@ -243,8 +246,8 @@ class BaseTrainer:
# dataloaders
batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode='train')
if rank in (-1, 0):
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
if RANK in (-1, 0):
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
self.validator = self.get_validator()
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
@ -256,11 +259,11 @@ class BaseTrainer:
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
self.run_callbacks('on_pretrain_routine_end')
def _do_train(self, rank=-1, world_size=1):
def _do_train(self, world_size=1):
if world_size > 1:
self._setup_ddp(rank, world_size)
self._setup_ddp(world_size)
self._setup_train(rank, world_size)
self._setup_train(world_size)
self.epoch_time = None
self.epoch_time_start = time.time()
@ -280,7 +283,7 @@ class BaseTrainer:
self.epoch = epoch
self.run_callbacks('on_train_epoch_start')
self.model.train()
if rank != -1:
if RANK != -1:
self.train_loader.sampler.set_epoch(epoch)
pbar = enumerate(self.train_loader)
# Update dataloader attributes (optional)
@ -291,7 +294,7 @@ class BaseTrainer:
if hasattr(self.train_loader.dataset, 'close_mosaic'):
self.train_loader.dataset.close_mosaic(hyp=self.args)
if rank in (-1, 0):
if RANK in (-1, 0):
LOGGER.info(self.progress_string())
pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
self.tloss = None
@ -315,7 +318,7 @@ class BaseTrainer:
batch = self.preprocess_batch(batch)
preds = self.model(batch['img'])
self.loss, self.loss_items = self.criterion(preds, batch)
if rank != -1:
if RANK != -1:
self.loss *= world_size
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
else self.loss_items
@ -332,7 +335,7 @@ class BaseTrainer:
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
if rank in (-1, 0):
if RANK in (-1, 0):
pbar.set_description(
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
@ -347,7 +350,7 @@ class BaseTrainer:
self.scheduler.step()
self.run_callbacks('on_train_epoch_end')
if rank in (-1, 0):
if RANK in (-1, 0):
# Validation
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
@ -377,7 +380,7 @@ class BaseTrainer:
if self.stop:
break # must break all DDP ranks
if rank in (-1, 0):
if RANK in (-1, 0):
# Do final val with best.pt
LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
@ -408,7 +411,8 @@ class BaseTrainer:
torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')
del ckpt
def get_dataset(self, data):
@staticmethod
def get_dataset(data):
"""
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
"""