ultralytics 8.0.100 add Mosaic9() augmentation (#2605)

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Tommy in Tongji <36354458+TommyZihao@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: BIGBOSS-FOX <47949596+BIGBOSS-FOX@users.noreply.github.com>
Co-authored-by: xbkaishui <xxkaishui@gmail.com>
This commit is contained in:
Glenn Jocher
2023-05-14 20:43:35 +02:00
committed by GitHub
parent db1c5885d5
commit dce4efce48
23 changed files with 351 additions and 64 deletions

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.99'
__version__ = '8.0.100'
from ultralytics.hub import start
from ultralytics.vit.rtdetr import RTDETR

View File

@ -177,13 +177,13 @@ class C2f(nn.Module):
self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
def forward(self, x):
"""Forward pass of a YOLOv5 CSPDarknet backbone layer."""
"""Forward pass through C2f layer."""
y = list(self.cv1(x).chunk(2, 1))
y.extend(m(y[-1]) for m in self.m)
return self.cv2(torch.cat(y, 1))
def forward_split(self, x):
"""Applies spatial attention to module's input."""
"""Forward pass using split() instead of chunk()."""
y = list(self.cv1(x).split((self.c, self.c), 1))
y.extend(m(y[-1]) for m in self.m)
return self.cv2(torch.cat(y, 1))

View File

@ -126,7 +126,7 @@ class BaseModel(nn.Module):
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
def info(self, verbose=True, imgsz=640):
def info(self, detailed=False, verbose=True, imgsz=640):
"""
Prints model information
@ -134,7 +134,7 @@ class BaseModel(nn.Module):
verbose (bool): if True, prints out the model information. Defaults to True
imgsz (int): the size of the image that the model will be trained on. Defaults to 640
"""
model_info(self, verbose=verbose, imgsz=imgsz)
return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
def _apply(self, fn):
"""

View File

@ -181,7 +181,7 @@ class BYTETracker:
def update(self, results, img=None):
"""Updates object tracker with new detections and returns tracked object bounding boxes."""
self.frame_id += 1
activated_starcks = []
activated_stracks = []
refind_stracks = []
lost_stracks = []
removed_stracks = []
@ -230,7 +230,7 @@ class BYTETracker:
det = detections[idet]
if track.state == TrackState.Tracked:
track.update(det, self.frame_id)
activated_starcks.append(track)
activated_stracks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
@ -246,7 +246,7 @@ class BYTETracker:
det = detections_second[idet]
if track.state == TrackState.Tracked:
track.update(det, self.frame_id)
activated_starcks.append(track)
activated_stracks.append(track)
else:
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track)
@ -262,7 +262,7 @@ class BYTETracker:
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_id)
activated_starcks.append(unconfirmed[itracked])
activated_stracks.append(unconfirmed[itracked])
for it in u_unconfirmed:
track = unconfirmed[it]
track.mark_removed()
@ -273,7 +273,7 @@ class BYTETracker:
if track.score < self.args.new_track_thresh:
continue
track.activate(self.kalman_filter, self.frame_id)
activated_starcks.append(track)
activated_stracks.append(track)
# Step 5: Update state
for track in self.lost_stracks:
if self.frame_id - track.end_frame > self.max_time_lost:
@ -281,7 +281,7 @@ class BYTETracker:
removed_stracks.append(track)
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_starcks)
self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_stracks)
self.tracked_stracks = self.joint_stracks(self.tracked_stracks, refind_stracks)
self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks)

View File

@ -8,7 +8,7 @@ from pathlib import Path
from ultralytics.nn.tasks import DetectionModel, attempt_load_one_weight, yaml_model_load
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.torch_utils import model_info
@ -47,7 +47,7 @@ class RTDETR:
self.task = self.model.args['task']
@smart_inference_mode()
def predict(self, source, stream=False, **kwargs):
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
@ -61,6 +61,9 @@ class RTDETR:
Returns:
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = dict(conf=0.25, task='detect', mode='predict')
overrides.update(kwargs) # prefer kwargs
if not self.predictor:

View File

@ -114,7 +114,11 @@ sam_model_map = {
def build_sam(ckpt='sam_b.pt'):
"""Build a SAM model specified by ckpt."""
model_builder = sam_model_map.get(ckpt)
model_builder = None
for k in sam_model_map.keys():
if ckpt.endswith(k):
model_builder = sam_model_map.get(k)
if not model_builder:
raise FileNotFoundError(f'{ckpt} is not a supported sam model. Available models are: \n {sam_model_map.keys()}')

View File

@ -9,7 +9,7 @@ from .predict import Predictor
class SAM:
def __init__(self, model='sam_b.pt') -> None:
if model and not (model.endswith('.pt') or model.endswith('.pth')):
if model and not model.endswith('.pt') and not model.endswith('.pth'):
# Should raise AssertionError instead?
raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint')
self.model = build_sam(model)

View File

@ -115,30 +115,42 @@ class BaseMixTransform:
class Mosaic(BaseMixTransform):
"""Mosaic augmentation.
Args:
imgsz (Sequence[int]): Image size after mosaic pipeline of single
image. The shape order should be (height, width).
Default to (640, 640).
"""
Mosaic augmentation.
This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.
The augmentation is applied to a dataset with a given probability.
Attributes:
dataset: The dataset on which the mosaic augmentation is applied.
imgsz (int, optional): Image size (height and width) after mosaic pipeline of a single image. Default to 640.
p (float, optional): Probability of applying the mosaic augmentation. Must be in the range 0-1. Default to 1.0.
n (int, optional): The grid size, either 4 (for 2x2) or 9 (for 3x3).
"""
def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
def __init__(self, dataset, imgsz=640, p=1.0, n=9):
"""Initializes the object with a dataset, image size, probability, and border."""
assert 0 <= p <= 1.0, 'The probability should be in range [0, 1]. ' f'got {p}.'
assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.'
assert n in (4, 9), 'grid must be equal to 4 or 9.'
super().__init__(dataset=dataset, p=p)
self.dataset = dataset
self.imgsz = imgsz
self.border = border
self.border = [-imgsz // 2, -imgsz // 2] if n == 4 else [-imgsz, -imgsz]
self.n = n
def get_indexes(self):
"""Return a list of 3 random indexes from the dataset."""
return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]
"""Return a list of random indexes from the dataset."""
return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]
def _mix_transform(self, labels):
    """
    Check that the sample can be mosaicked, then build the mosaic selected by ``self.n``.

    Args:
        labels (dict): Sample to augment. It must not carry a 'rect_shape' value and must provide a
            non-empty 'mix_labels' collection holding the other samples of the mosaic.

    Returns:
        The result of self._mosaic4(labels) when self.n == 4, otherwise the result of self._mosaic9(labels).

    Raises:
        AssertionError: If 'rect_shape' is set or 'mix_labels' is missing or empty.
    """
    rect_shape = labels.get('rect_shape')
    assert rect_shape is None, 'rect and mosaic are mutually exclusive.'

    companions = labels.get('mix_labels', [])
    assert len(companions), 'There are no other images for mosaic augment.'

    if self.n == 4:
        return self._mosaic4(labels)
    return self._mosaic9(labels)
def _mosaic4(self, labels):
"""Create a 2x2 image mosaic."""
mosaic_labels = []
assert labels.get('rect_shape', None) is None, 'rect and mosaic is exclusive.'
assert len(labels.get('mix_labels', [])) > 0, 'There are no other images for mosaic augment.'
s = self.imgsz
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
for i in range(4):
@ -172,7 +184,54 @@ class Mosaic(BaseMixTransform):
final_labels['img'] = img4
return final_labels
def _update_labels(self, labels, padw, padh):
def _mosaic9(self, labels):
    """
    Assemble a 3x3 mosaic from the primary sample and eight companion samples.

    The primary image is pasted at offset (imgsz, imgsz) of a (3 * imgsz, 3 * imgsz) canvas. The companions
    taken from labels['mix_labels'] are then placed around it in the order top, top-right, right,
    bottom-right, bottom, bottom-left, left, top-left.

    Args:
        labels (dict): Primary sample. Must provide 'img', 'resized_shape' and a 'mix_labels' sequence with
            at least eight further samples exposing the same keys. 'resized_shape' is popped from each sample.

    Returns:
        The value returned by self._cat_labels() for the nine updated samples, with 'img' set to the canvas.
    """
    tile = self.imgsz
    placed = []
    prev_h, prev_w = -1, -1  # size of the tile handled in the previous iteration
    for k in range(9):
        if k == 0:
            sample = labels
        else:
            sample = labels['mix_labels'][k - 1]

        picture = sample['img']
        h, w = sample.pop('resized_shape')

        # Top-left corner (left, top) of this tile on the canvas; its box is (left, top, left + w, top + h)
        if k == 0:  # center
            img9 = np.full((tile * 3, tile * 3, picture.shape[2]), 114, dtype=np.uint8)  # blank 3x3-tile canvas
            h0, w0 = h, w  # size of the center tile, reused when placing tiles 3-8
            left, top = tile, tile
        elif k == 1:  # top
            left, top = tile, tile - h
        elif k == 2:  # top right
            left, top = tile + prev_w, tile - h
        elif k == 3:  # right
            left, top = tile + w0, tile
        elif k == 4:  # bottom right
            left, top = tile + w0, tile + prev_h
        elif k == 5:  # bottom
            left, top = tile + w0 - w, tile + h0
        elif k == 6:  # bottom left
            left, top = tile + w0 - prev_w - w, tile + h0
        elif k == 7:  # left
            left, top = tile - w, tile + h0 - h
        elif k == 8:  # top left
            left, top = tile - w, tile + h0 - prev_h - h

        # Clamp the box at the canvas origin; the source image is cropped by the same amount
        x1, y1, x2, y2 = (max(v, 0) for v in (left, top, left + w, top + h))
        img9[y1:y2, x1:x2] = picture[y1 - top:, x1 - left:]  # img9[ymin:ymax, xmin:xmax]
        prev_h, prev_w = h, w  # remembered for the next tile

        # Shift this sample's annotations by its placement offset on the canvas
        placed.append(self._update_labels(sample, left, top))

    final_labels = self._cat_labels(placed)
    final_labels['img'] = img9
    return final_labels
@staticmethod
def _update_labels(labels, padw, padh):
"""Update labels."""
nh, nw = labels['img'].shape[:2]
labels['instances'].convert_bbox(format='xyxy')
@ -195,8 +254,9 @@ class Mosaic(BaseMixTransform):
'resized_shape': (self.imgsz * 2, self.imgsz * 2),
'cls': np.concatenate(cls, 0),
'instances': Instances.concatenate(instances, axis=0),
'mosaic_border': self.border}
final_labels['instances'].clip(self.imgsz * 2, self.imgsz * 2)
'mosaic_border': self.border} # final_labels
clip_size = self.imgsz * (2 if self.n == 4 else 3)
final_labels['instances'].clip(clip_size, clip_size)
return final_labels
@ -695,7 +755,7 @@ class Format:
def v8_transforms(dataset, imgsz, hyp):
"""Convert images to a size suitable for YOLOv8 training."""
pre_transform = Compose([
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
CopyPaste(p=hyp.copy_paste),
RandomPerspective(
degrees=hyp.degrees,

View File

@ -201,15 +201,16 @@ class YOLO:
self.model.load(weights)
return self
def info(self, verbose=True):
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
self._check_is_pytorch_model()
self.model.info(verbose=verbose)
return self.model.info(detailed=detailed, verbose=verbose)
def fuse(self):
"""Fuse PyTorch Conv2d and BatchNorm2d layers."""

View File

@ -190,17 +190,6 @@ class BaseTrainer:
else:
self._do_train(world_size)
def _pre_caching_dataset(self):
"""
Caching dataset before training to avoid NCCL timeout.
Must be done before DDP initialization.
See https://github.com/ultralytics/ultralytics/pull/2549 for details.
"""
if RANK in (-1, 0):
LOGGER.info('Pre-caching dataset to avoid NCCL timeout')
self.get_dataloader(self.trainset, batch_size=1, rank=RANK, mode='train')
self.get_dataloader(self.testset, batch_size=1, rank=-1, mode='val')
def _setup_ddp(self, world_size):
"""Initializes and sets the DistributedDataParallel parameters for training."""
torch.cuda.set_device(RANK)
@ -274,7 +263,6 @@ class BaseTrainer:
def _do_train(self, world_size=1):
"""Train completed, evaluate and plot if specified by arguments."""
if world_size > 1:
self._pre_caching_dataset()
self._setup_ddp(world_size)
self._setup_train(world_size)

View File

@ -233,7 +233,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
n += 1
if s and install and AUTOINSTALL: # check environment variable
LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
try:
assert is_online(), 'AutoUpdate skipped (offline)'
LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())

View File

@ -34,7 +34,7 @@ class BboxLoss(nn.Module):
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
"""IoU loss."""
weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
weight = target_scores.sum(-1)[fg_mask]
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

View File

@ -189,6 +189,7 @@ class TaskAlignedAssigner(nn.Module):
for k in range(self.topk):
# Expand topk_idxs for each value of k and add 1 at the specified positions
count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
# count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
# filter invalid bboxes
count_tensor.masked_fill_(count_tensor > 1, 0)

View File

@ -165,14 +165,15 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
for i, (name, p) in enumerate(model.named_parameters()):
name = name.replace('module_list.', '')
LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g' %
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g %10s' %
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype))
flops = get_flops(model, imgsz)
fused = ' (fused)' if model.is_fused() else ''
fs = f', {flops:.1f} GFLOPs' if flops else ''
m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
LOGGER.info(f'{m} summary{fused}: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}')
return n_p, flops
def get_num_params(model):