ultralytics 8.0.100
add Mosaic9() augmentation (#2605)
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Tommy in Tongji <36354458+TommyZihao@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: BIGBOSS-FOX <47949596+BIGBOSS-FOX@users.noreply.github.com>
Co-authored-by: xbkaishui <xxkaishui@gmail.com>

@@ -115,30 +115,42 @@ class BaseMixTransform:


 class Mosaic(BaseMixTransform):
-    """Mosaic augmentation.
-    Args:
-        imgsz (Sequence[int]): Image size after mosaic pipeline of single
-            image. The shape order should be (height, width).
-            Default to (640, 640).
-    """
+    """
+    Mosaic augmentation.
+
+    This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.
+    The augmentation is applied to a dataset with a given probability.
+
+    Attributes:
+        dataset: The dataset on which the mosaic augmentation is applied.
+        imgsz (int, optional): Image size (height and width) after mosaic pipeline of a single image. Default to 640.
+        p (float, optional): Probability of applying the mosaic augmentation. Must be in the range 0-1. Default to 1.0.
+        n (int, optional): The grid size, either 4 (for 2x2) or 9 (for 3x3).
+    """

-    def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
+    def __init__(self, dataset, imgsz=640, p=1.0, n=9):
         """Initializes the object with a dataset, image size, probability, and border."""
-        assert 0 <= p <= 1.0, 'The probability should be in range [0, 1]. ' f'got {p}.'
+        assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.'
+        assert n in (4, 9), 'grid must be equal to 4 or 9.'
         super().__init__(dataset=dataset, p=p)
         self.dataset = dataset
         self.imgsz = imgsz
-        self.border = border
+        self.border = [-imgsz // 2, -imgsz // 2] if n == 4 else [-imgsz, -imgsz]
+        self.n = n

     def get_indexes(self):
-        """Return a list of 3 random indexes from the dataset."""
-        return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]
+        """Return a list of random indexes from the dataset."""
+        return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]

     def _mix_transform(self, labels):
         """Apply mixup transformation to the input image and labels."""
+        assert labels.get('rect_shape', None) is None, 'rect and mosaic are mutually exclusive.'
+        assert len(labels.get('mix_labels', [])), 'There are no other images for mosaic augment.'
+        return self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
+
+    def _mosaic4(self, labels):
+        """Create a 2x2 image mosaic."""
         mosaic_labels = []
-        assert labels.get('rect_shape', None) is None, 'rect and mosaic is exclusive.'
-        assert len(labels.get('mix_labels', [])) > 0, 'There are no other images for mosaic augment.'
         s = self.imgsz
         yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border)  # mosaic center x, y
         for i in range(4):
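
With the new constructor the grid size is selected via n, and the mosaic border is derived from imgsz instead of being passed in. A minimal usage sketch (import path assumed for this 8.0.x layout; a plain list stands in for a real dataset, which is only enough to exercise the constructor and index sampling):

from ultralytics.yolo.data.augment import Mosaic  # assumed module path for this release

dataset = list(range(100))  # stand-in with __len__ only; a real YOLO dataset is required to run the transform
mosaic4 = Mosaic(dataset, imgsz=640, p=1.0, n=4)  # 2x2 grid -> border [-320, -320]
mosaic9 = Mosaic(dataset, imgsz=640, p=1.0, n=9)  # 3x3 grid -> border [-640, -640]
print(len(mosaic9.get_indexes()))  # 8, one extra image per remaining tile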

@@ -172,7 +184,54 @@ class Mosaic(BaseMixTransform):
             final_labels['img'] = img4
         return final_labels

-    def _update_labels(self, labels, padw, padh):
+    def _mosaic9(self, labels):
+        """Create a 3x3 image mosaic."""
+        mosaic_labels = []
+        s = self.imgsz
+        hp, wp = -1, -1  # height, width previous
+        for i in range(9):
+            labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
+            # Load image
+            img = labels_patch['img']
+            h, w = labels_patch.pop('resized_shape')
+
+            # Place img in img9
+            if i == 0:  # center
+                img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
+                h0, w0 = h, w
+                c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
+            elif i == 1:  # top
+                c = s, s - h, s + w, s
+            elif i == 2:  # top right
+                c = s + wp, s - h, s + wp + w, s
+            elif i == 3:  # right
+                c = s + w0, s, s + w0 + w, s + h
+            elif i == 4:  # bottom right
+                c = s + w0, s + hp, s + w0 + w, s + hp + h
+            elif i == 5:  # bottom
+                c = s + w0 - w, s + h0, s + w0, s + h0 + h
+            elif i == 6:  # bottom left
+                c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
+            elif i == 7:  # left
+                c = s - w, s + h0 - h, s, s + h0
+            elif i == 8:  # top left
+                c = s - w, s + h0 - hp - h, s, s + h0 - hp
+
+            padw, padh = c[:2]
+            x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords
+
+            # Image
+            img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:]  # img9[ymin:ymax, xmin:xmax]
+            hp, wp = h, w  # height, width previous for next iteration
+
+            labels_patch = self._update_labels(labels_patch, padw, padh)
+            mosaic_labels.append(labels_patch)
+        final_labels = self._cat_labels(mosaic_labels)
+        final_labels['img'] = img9
+        return final_labels
+
+    @staticmethod
+    def _update_labels(labels, padw, padh):
         """Update labels."""
         nh, nw = labels['img'].shape[:2]
         labels['instances'].convert_bbox(format='xyxy')
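
Each c tuple above is an (xmin, ymin, xmax, ymax) box that places one tile around the central s x s cell of a 3s x 3s canvas, anchored on the base tile size (h0, w0) and the previous tile size (hp, wp). A standalone sketch of just that coordinate logic, assuming equal s x s tiles for simplicity:

def mosaic9_coords(s, h, w, h0, w0, hp, wp, i):
    """Mirror the tile-placement rules from _mosaic9 for tile index i."""
    if i == 0:    return s, s, s + w, s + h                                # center
    elif i == 1:  return s, s - h, s + w, s                                # top
    elif i == 2:  return s + wp, s - h, s + wp + w, s                      # top right
    elif i == 3:  return s + w0, s, s + w0 + w, s + h                      # right
    elif i == 4:  return s + w0, s + hp, s + w0 + w, s + hp + h            # bottom right
    elif i == 5:  return s + w0 - w, s + h0, s + w0, s + h0 + h            # bottom
    elif i == 6:  return s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h  # bottom left
    elif i == 7:  return s - w, s + h0 - h, s, s + h0                      # left
    elif i == 8:  return s - w, s + h0 - hp - h, s, s + h0 - hp            # top left

s = 640
for i in range(9):
    # With equal 640x640 tiles every box lands inside the 1920x1920 canvas, one per grid cell
    print(i, mosaic9_coords(s, s, s, s, s, s, s, i))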

@@ -195,8 +254,9 @@ class Mosaic(BaseMixTransform):
             'resized_shape': (self.imgsz * 2, self.imgsz * 2),
             'cls': np.concatenate(cls, 0),
             'instances': Instances.concatenate(instances, axis=0),
-            'mosaic_border': self.border}
-        final_labels['instances'].clip(self.imgsz * 2, self.imgsz * 2)
+            'mosaic_border': self.border}  # final_labels
+        clip_size = self.imgsz * (2 if self.n == 4 else 3)
+        final_labels['instances'].clip(clip_size, clip_size)
         return final_labels


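
The clipping bound now scales with the grid size rather than being hard-coded to twice imgsz. A quick check, assuming the default imgsz of 640:

imgsz = 640
for n in (4, 9):
    clip_size = imgsz * (2 if n == 4 else 3)
    print(n, clip_size)  # 4 -> 1280, 9 -> 1920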

@@ -695,7 +755,7 @@ class Format:
 def v8_transforms(dataset, imgsz, hyp):
     """Convert images to a size suitable for YOLOv8 training."""
     pre_transform = Compose([
-        Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
+        Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
         CopyPaste(p=hyp.copy_paste),
         RandomPerspective(
             degrees=hyp.degrees,

@@ -201,15 +201,16 @@ class YOLO:
         self.model.load(weights)
         return self

-    def info(self, verbose=True):
+    def info(self, detailed=False, verbose=True):
         """
         Logs model info.

         Args:
+            detailed (bool): Show detailed information about model.
             verbose (bool): Controls verbosity.
         """
         self._check_is_pytorch_model()
-        self.model.info(verbose=verbose)
+        return self.model.info(detailed=detailed, verbose=verbose)

     def fuse(self):
         """Fuse PyTorch Conv2d and BatchNorm2d layers."""
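
model.info() now accepts a detailed flag and returns the underlying summary instead of discarding it. A minimal usage sketch (assumes a yolov8n.pt checkpoint is available locally or can be downloaded):

from ultralytics import YOLO

model = YOLO('yolov8n.pt')        # any PyTorch-format checkpoint
model.info()                      # brief summary, as before
info = model.info(detailed=True)  # per-parameter table; the summary values are now also returned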

@@ -190,17 +190,6 @@ class BaseTrainer:
         else:
             self._do_train(world_size)

-    def _pre_caching_dataset(self):
-        """
-        Caching dataset before training to avoid NCCL timeout.
-        Must be done before DDP initialization.
-        See https://github.com/ultralytics/ultralytics/pull/2549 for details.
-        """
-        if RANK in (-1, 0):
-            LOGGER.info('Pre-caching dataset to avoid NCCL timeout')
-            self.get_dataloader(self.trainset, batch_size=1, rank=RANK, mode='train')
-            self.get_dataloader(self.testset, batch_size=1, rank=-1, mode='val')
-
     def _setup_ddp(self, world_size):
         """Initializes and sets the DistributedDataParallel parameters for training."""
         torch.cuda.set_device(RANK)

@@ -274,7 +263,6 @@ class BaseTrainer:
     def _do_train(self, world_size=1):
         """Train completed, evaluate and plot if specified by arguments."""
         if world_size > 1:
-            self._pre_caching_dataset()
             self._setup_ddp(world_size)

         self._setup_train(world_size)

@@ -233,7 +233,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
             n += 1

     if s and install and AUTOINSTALL:  # check environment variable
-        LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
+        LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
         try:
             assert is_online(), 'AutoUpdate skipped (offline)'
             LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())

@@ -34,7 +34,7 @@ class BboxLoss(nn.Module):

     def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
         """IoU loss."""
-        weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
+        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
         iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
         loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

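
The weight computation swaps torch.masked_select for plain boolean indexing; both select the same elements in the same row-major order. A quick check with arbitrary dummy shapes:

import torch

target_scores = torch.rand(2, 5, 3)  # (batch, anchors, classes), shapes chosen arbitrarily
fg_mask = torch.rand(2, 5) > 0.5     # boolean foreground mask
a = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
b = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
print(torch.equal(a, b))  # True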

@@ -189,6 +189,7 @@ class TaskAlignedAssigner(nn.Module):
         for k in range(self.topk):
+            # Expand topk_idxs for each value of k and add 1 at the specified positions
             count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
         # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
         # filter invalid bboxes
         count_tensor.masked_fill_(count_tensor > 1, 0)

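
The new comment documents what the loop does: each pass scatters 1 at the anchor indexes chosen for that k, so count_tensor ends up holding how many times every anchor was selected, and anchors counted more than once are then zeroed out. A tiny self-contained sketch with toy shapes (bs=1, 2 boxes, topk=2, 5 anchors):

import torch

topk_idxs = torch.tensor([[[1, 3], [0, 3]]])           # (bs, n_boxes, topk)
count_tensor = torch.zeros(1, 2, 5, dtype=torch.int8)  # (bs, n_boxes, n_anchors)
ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8)
for k in range(topk_idxs.shape[-1]):
    count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
print(count_tensor)
# tensor([[[0, 1, 0, 1, 0],
#          [1, 0, 0, 1, 0]]], dtype=torch.int8)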

@@ -165,14 +165,15 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
             f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
         for i, (name, p) in enumerate(model.named_parameters()):
             name = name.replace('module_list.', '')
-            LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g' %
-                        (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
+            LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g %10s' %
+                        (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype))

     flops = get_flops(model, imgsz)
     fused = ' (fused)' if model.is_fused() else ''
     fs = f', {flops:.1f} GFLOPs' if flops else ''
     m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
     LOGGER.info(f'{m} summary{fused}: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}')
     return n_p, flops


 def get_num_params(model):