ultralytics 8.0.100
add Mosaic9() augmentation (#2605)
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Tommy in Tongji <36354458+TommyZihao@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: BIGBOSS-FOX <47949596+BIGBOSS-FOX@users.noreply.github.com> Co-authored-by: xbkaishui <xxkaishui@gmail.com>
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = '8.0.99'
|
||||
__version__ = '8.0.100'
|
||||
|
||||
from ultralytics.hub import start
|
||||
from ultralytics.vit.rtdetr import RTDETR
|
||||
|
@ -177,13 +177,13 @@ class C2f(nn.Module):
|
||||
self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass of a YOLOv5 CSPDarknet backbone layer."""
|
||||
"""Forward pass through C2f layer."""
|
||||
y = list(self.cv1(x).chunk(2, 1))
|
||||
y.extend(m(y[-1]) for m in self.m)
|
||||
return self.cv2(torch.cat(y, 1))
|
||||
|
||||
def forward_split(self, x):
|
||||
"""Applies spatial attention to module's input."""
|
||||
"""Forward pass using split() instead of chunk()."""
|
||||
y = list(self.cv1(x).split((self.c, self.c), 1))
|
||||
y.extend(m(y[-1]) for m in self.m)
|
||||
return self.cv2(torch.cat(y, 1))
|
||||
|
@ -126,7 +126,7 @@ class BaseModel(nn.Module):
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
||||
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
||||
|
||||
def info(self, verbose=True, imgsz=640):
|
||||
def info(self, detailed=False, verbose=True, imgsz=640):
|
||||
"""
|
||||
Prints model information
|
||||
|
||||
@ -134,7 +134,7 @@ class BaseModel(nn.Module):
|
||||
verbose (bool): if True, prints out the model information. Defaults to True
|
||||
imgsz (int): the size of the image that the model will be trained on. Defaults to 640
|
||||
"""
|
||||
model_info(self, verbose=verbose, imgsz=imgsz)
|
||||
return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
|
||||
|
||||
def _apply(self, fn):
|
||||
"""
|
||||
|
@ -181,7 +181,7 @@ class BYTETracker:
|
||||
def update(self, results, img=None):
|
||||
"""Updates object tracker with new detections and returns tracked object bounding boxes."""
|
||||
self.frame_id += 1
|
||||
activated_starcks = []
|
||||
activated_stracks = []
|
||||
refind_stracks = []
|
||||
lost_stracks = []
|
||||
removed_stracks = []
|
||||
@ -230,7 +230,7 @@ class BYTETracker:
|
||||
det = detections[idet]
|
||||
if track.state == TrackState.Tracked:
|
||||
track.update(det, self.frame_id)
|
||||
activated_starcks.append(track)
|
||||
activated_stracks.append(track)
|
||||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
@ -246,7 +246,7 @@ class BYTETracker:
|
||||
det = detections_second[idet]
|
||||
if track.state == TrackState.Tracked:
|
||||
track.update(det, self.frame_id)
|
||||
activated_starcks.append(track)
|
||||
activated_stracks.append(track)
|
||||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
@ -262,7 +262,7 @@ class BYTETracker:
|
||||
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
||||
for itracked, idet in matches:
|
||||
unconfirmed[itracked].update(detections[idet], self.frame_id)
|
||||
activated_starcks.append(unconfirmed[itracked])
|
||||
activated_stracks.append(unconfirmed[itracked])
|
||||
for it in u_unconfirmed:
|
||||
track = unconfirmed[it]
|
||||
track.mark_removed()
|
||||
@ -273,7 +273,7 @@ class BYTETracker:
|
||||
if track.score < self.args.new_track_thresh:
|
||||
continue
|
||||
track.activate(self.kalman_filter, self.frame_id)
|
||||
activated_starcks.append(track)
|
||||
activated_stracks.append(track)
|
||||
# Step 5: Update state
|
||||
for track in self.lost_stracks:
|
||||
if self.frame_id - track.end_frame > self.max_time_lost:
|
||||
@ -281,7 +281,7 @@ class BYTETracker:
|
||||
removed_stracks.append(track)
|
||||
|
||||
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
|
||||
self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_starcks)
|
||||
self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_stracks)
|
||||
self.tracked_stracks = self.joint_stracks(self.tracked_stracks, refind_stracks)
|
||||
self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks)
|
||||
self.lost_stracks.extend(lost_stracks)
|
||||
|
@ -8,7 +8,7 @@ from pathlib import Path
|
||||
from ultralytics.nn.tasks import DetectionModel, attempt_load_one_weight, yaml_model_load
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
|
||||
from ultralytics.yolo.utils.checks import check_imgsz
|
||||
from ultralytics.yolo.utils.torch_utils import model_info
|
||||
|
||||
@ -47,7 +47,7 @@ class RTDETR:
|
||||
self.task = self.model.args['task']
|
||||
|
||||
@smart_inference_mode()
|
||||
def predict(self, source, stream=False, **kwargs):
|
||||
def predict(self, source=None, stream=False, **kwargs):
|
||||
"""
|
||||
Perform prediction using the YOLO model.
|
||||
|
||||
@ -61,6 +61,9 @@ class RTDETR:
|
||||
Returns:
|
||||
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
||||
"""
|
||||
if source is None:
|
||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
||||
overrides = dict(conf=0.25, task='detect', mode='predict')
|
||||
overrides.update(kwargs) # prefer kwargs
|
||||
if not self.predictor:
|
||||
|
@ -114,7 +114,11 @@ sam_model_map = {
|
||||
|
||||
def build_sam(ckpt='sam_b.pt'):
|
||||
"""Build a SAM model specified by ckpt."""
|
||||
model_builder = sam_model_map.get(ckpt)
|
||||
model_builder = None
|
||||
for k in sam_model_map.keys():
|
||||
if ckpt.endswith(k):
|
||||
model_builder = sam_model_map.get(k)
|
||||
|
||||
if not model_builder:
|
||||
raise FileNotFoundError(f'{ckpt} is not a supported sam model. Available models are: \n {sam_model_map.keys()}')
|
||||
|
||||
|
@ -9,7 +9,7 @@ from .predict import Predictor
|
||||
class SAM:
|
||||
|
||||
def __init__(self, model='sam_b.pt') -> None:
|
||||
if model and not (model.endswith('.pt') or model.endswith('.pth')):
|
||||
if model and not model.endswith('.pt') and not model.endswith('.pth'):
|
||||
# Should raise AssertionError instead?
|
||||
raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint')
|
||||
self.model = build_sam(model)
|
||||
|
@ -115,30 +115,42 @@ class BaseMixTransform:
|
||||
|
||||
|
||||
class Mosaic(BaseMixTransform):
|
||||
"""Mosaic augmentation.
|
||||
Args:
|
||||
imgsz (Sequence[int]): Image size after mosaic pipeline of single
|
||||
image. The shape order should be (height, width).
|
||||
Default to (640, 640).
|
||||
"""
|
||||
Mosaic augmentation.
|
||||
|
||||
This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.
|
||||
The augmentation is applied to a dataset with a given probability.
|
||||
|
||||
Attributes:
|
||||
dataset: The dataset on which the mosaic augmentation is applied.
|
||||
imgsz (int, optional): Image size (height and width) after mosaic pipeline of a single image. Default to 640.
|
||||
p (float, optional): Probability of applying the mosaic augmentation. Must be in the range 0-1. Default to 1.0.
|
||||
n (int, optional): The grid size, either 4 (for 2x2) or 9 (for 3x3). Default to 9.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
|
||||
def __init__(self, dataset, imgsz=640, p=1.0, n=9):
|
||||
"""Initializes the object with a dataset, image size, probability, and mosaic grid size (the border is derived from it)."""
|
||||
assert 0 <= p <= 1.0, 'The probability should be in range [0, 1]. ' f'got {p}.'
|
||||
assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.'
|
||||
assert n in (4, 9), 'grid must be equal to 4 or 9.'
|
||||
super().__init__(dataset=dataset, p=p)
|
||||
self.dataset = dataset
|
||||
self.imgsz = imgsz
|
||||
self.border = border
|
||||
self.border = [-imgsz // 2, -imgsz // 2] if n == 4 else [-imgsz, -imgsz]
|
||||
self.n = n
|
||||
|
||||
def get_indexes(self):
|
||||
"""Return a list of 3 random indexes from the dataset."""
|
||||
return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]
|
||||
"""Return a list of random indexes from the dataset."""
|
||||
return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
"""Apply mosaic augmentation (2x2 or 3x3, depending on self.n) to the input image and labels."""
|
||||
assert labels.get('rect_shape', None) is None, 'rect and mosaic are mutually exclusive.'
|
||||
assert len(labels.get('mix_labels', [])), 'There are no other images for mosaic augment.'
|
||||
return self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
|
||||
|
||||
def _mosaic4(self, labels):
|
||||
"""Create a 2x2 image mosaic."""
|
||||
mosaic_labels = []
|
||||
assert labels.get('rect_shape', None) is None, 'rect and mosaic is exclusive.'
|
||||
assert len(labels.get('mix_labels', [])) > 0, 'There are no other images for mosaic augment.'
|
||||
s = self.imgsz
|
||||
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center y, x
|
||||
for i in range(4):
|
||||
@ -172,7 +184,54 @@ class Mosaic(BaseMixTransform):
|
||||
final_labels['img'] = img4
|
||||
return final_labels
|
||||
|
||||
def _update_labels(self, labels, padw, padh):
|
||||
def _mosaic9(self, labels):
|
||||
"""Create a 3x3 image mosaic."""
|
||||
mosaic_labels = []
|
||||
s = self.imgsz
|
||||
hp, wp = -1, -1 # height, width previous
|
||||
for i in range(9):
|
||||
labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
|
||||
# Load image
|
||||
img = labels_patch['img']
|
||||
h, w = labels_patch.pop('resized_shape')
|
||||
|
||||
# Place img in img9
|
||||
if i == 0: # center
|
||||
img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 9 tiles
|
||||
h0, w0 = h, w
|
||||
c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
|
||||
elif i == 1: # top
|
||||
c = s, s - h, s + w, s
|
||||
elif i == 2: # top right
|
||||
c = s + wp, s - h, s + wp + w, s
|
||||
elif i == 3: # right
|
||||
c = s + w0, s, s + w0 + w, s + h
|
||||
elif i == 4: # bottom right
|
||||
c = s + w0, s + hp, s + w0 + w, s + hp + h
|
||||
elif i == 5: # bottom
|
||||
c = s + w0 - w, s + h0, s + w0, s + h0 + h
|
||||
elif i == 6: # bottom left
|
||||
c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
|
||||
elif i == 7: # left
|
||||
c = s - w, s + h0 - h, s, s + h0
|
||||
elif i == 8: # top left
|
||||
c = s - w, s + h0 - hp - h, s, s + h0 - hp
|
||||
|
||||
padw, padh = c[:2]
|
||||
x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
|
||||
|
||||
# Image
|
||||
img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax]
|
||||
hp, wp = h, w # height, width previous for next iteration
|
||||
|
||||
labels_patch = self._update_labels(labels_patch, padw, padh)
|
||||
mosaic_labels.append(labels_patch)
|
||||
final_labels = self._cat_labels(mosaic_labels)
|
||||
final_labels['img'] = img9
|
||||
return final_labels
|
||||
|
||||
@staticmethod
|
||||
def _update_labels(labels, padw, padh):
|
||||
"""Update labels."""
|
||||
nh, nw = labels['img'].shape[:2]
|
||||
labels['instances'].convert_bbox(format='xyxy')
|
||||
@ -195,8 +254,9 @@ class Mosaic(BaseMixTransform):
|
||||
'resized_shape': (self.imgsz * 2, self.imgsz * 2),
|
||||
'cls': np.concatenate(cls, 0),
|
||||
'instances': Instances.concatenate(instances, axis=0),
|
||||
'mosaic_border': self.border}
|
||||
final_labels['instances'].clip(self.imgsz * 2, self.imgsz * 2)
|
||||
'mosaic_border': self.border} # final_labels
|
||||
clip_size = self.imgsz * (2 if self.n == 4 else 3)
|
||||
final_labels['instances'].clip(clip_size, clip_size)
|
||||
return final_labels
|
||||
|
||||
|
||||
@ -695,7 +755,7 @@ class Format:
|
||||
def v8_transforms(dataset, imgsz, hyp):
|
||||
"""Convert images to a size suitable for YOLOv8 training."""
|
||||
pre_transform = Compose([
|
||||
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
|
||||
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
|
||||
CopyPaste(p=hyp.copy_paste),
|
||||
RandomPerspective(
|
||||
degrees=hyp.degrees,
|
||||
|
@ -201,15 +201,16 @@ class YOLO:
|
||||
self.model.load(weights)
|
||||
return self
|
||||
|
||||
def info(self, verbose=True):
|
||||
def info(self, detailed=False, verbose=True):
|
||||
"""
|
||||
Logs model info.
|
||||
|
||||
Args:
|
||||
detailed (bool): Show detailed information about model.
|
||||
verbose (bool): Controls verbosity.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
self.model.info(verbose=verbose)
|
||||
return self.model.info(detailed=detailed, verbose=verbose)
|
||||
|
||||
def fuse(self):
|
||||
"""Fuse PyTorch Conv2d and BatchNorm2d layers."""
|
||||
|
@ -190,17 +190,6 @@ class BaseTrainer:
|
||||
else:
|
||||
self._do_train(world_size)
|
||||
|
||||
def _pre_caching_dataset(self):
|
||||
"""
|
||||
Caching dataset before training to avoid NCCL timeout.
|
||||
Must be done before DDP initialization.
|
||||
See https://github.com/ultralytics/ultralytics/pull/2549 for details.
|
||||
"""
|
||||
if RANK in (-1, 0):
|
||||
LOGGER.info('Pre-caching dataset to avoid NCCL timeout')
|
||||
self.get_dataloader(self.trainset, batch_size=1, rank=RANK, mode='train')
|
||||
self.get_dataloader(self.testset, batch_size=1, rank=-1, mode='val')
|
||||
|
||||
def _setup_ddp(self, world_size):
|
||||
"""Initializes and sets the DistributedDataParallel parameters for training."""
|
||||
torch.cuda.set_device(RANK)
|
||||
@ -274,7 +263,6 @@ class BaseTrainer:
|
||||
def _do_train(self, world_size=1):
|
||||
"""Train completed, evaluate and plot if specified by arguments."""
|
||||
if world_size > 1:
|
||||
self._pre_caching_dataset()
|
||||
self._setup_ddp(world_size)
|
||||
|
||||
self._setup_train(world_size)
|
||||
|
@ -233,7 +233,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
|
||||
n += 1
|
||||
|
||||
if s and install and AUTOINSTALL: # check environment variable
|
||||
LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
|
||||
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
|
||||
try:
|
||||
assert is_online(), 'AutoUpdate skipped (offline)'
|
||||
LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())
|
||||
|
@ -34,7 +34,7 @@ class BboxLoss(nn.Module):
|
||||
|
||||
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
|
||||
"""IoU loss."""
|
||||
weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
|
||||
weight = target_scores.sum(-1)[fg_mask]
|
||||
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
|
||||
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
|
||||
|
||||
|
@ -189,6 +189,7 @@ class TaskAlignedAssigner(nn.Module):
|
||||
for k in range(self.topk):
|
||||
# Expand topk_idxs for each value of k and add 1 at the specified positions
|
||||
count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
|
||||
# count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
|
||||
# filter invalid bboxes
|
||||
count_tensor.masked_fill_(count_tensor > 1, 0)
|
||||
|
||||
|
@ -165,14 +165,15 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
|
||||
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
|
||||
for i, (name, p) in enumerate(model.named_parameters()):
|
||||
name = name.replace('module_list.', '')
|
||||
LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g' %
|
||||
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
|
||||
LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g %10s' %
|
||||
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype))
|
||||
|
||||
flops = get_flops(model, imgsz)
|
||||
fused = ' (fused)' if model.is_fused() else ''
|
||||
fs = f', {flops:.1f} GFLOPs' if flops else ''
|
||||
m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
|
||||
LOGGER.info(f'{m} summary{fused}: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}')
|
||||
return n_p, flops
|
||||
|
||||
|
||||
def get_num_params(model):
|
||||
|
Reference in New Issue
Block a user