Rename img_size to imgsz (#86)

This commit is contained in:
Glenn Jocher
2022-12-24 00:39:09 +01:00
committed by GitHub
parent ae2443c210
commit 6432afc5f9
25 changed files with 98 additions and 98 deletions

View File

@ -114,15 +114,15 @@ class BaseMixTransform:
class Mosaic(BaseMixTransform):
"""Mosaic augmentation.
Args:
img_size (Sequence[int]): Image size after mosaic pipeline of single
imgsz (Sequence[int]): Image size after mosaic pipeline of single
image. The shape order should be (height, width).
Default to (640, 640).
"""
def __init__(self, img_size=640, p=1.0, border=(0, 0)):
def __init__(self, imgsz=640, p=1.0, border=(0, 0)):
assert 0 <= p <= 1.0, "The probability should be in range [0, 1]. " f"got {p}."
super().__init__(pre_transform=None, p=p)
self.img_size = img_size
self.imgsz = imgsz
self.border = border
def get_indexes(self, dataset):
@ -132,7 +132,7 @@ class Mosaic(BaseMixTransform):
mosaic_labels = []
assert labels.get("rect_shape", None) is None, "rect and mosaic is exclusive."
assert len(labels.get("mix_labels", [])) > 0, "There are no other images for mosaic augment."
s = self.img_size
s = self.imgsz
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
mix_labels = labels["mix_labels"]
for i in range(4):
@ -184,12 +184,12 @@ class Mosaic(BaseMixTransform):
instances.append(labels["instances"])
final_labels = {
"ori_shape": mosaic_labels[0]["ori_shape"],
"resized_shape": (self.img_size * 2, self.img_size * 2),
"resized_shape": (self.imgsz * 2, self.imgsz * 2),
"im_file": mosaic_labels[0]["im_file"],
"cls": np.concatenate(cls, 0)}
final_labels["instances"] = Instances.concatenate(instances, axis=0)
final_labels["instances"].clip(self.img_size * 2, self.img_size * 2)
final_labels["instances"].clip(self.imgsz * 2, self.imgsz * 2)
return final_labels
@ -658,9 +658,9 @@ class Format:
return masks, instances, cls
def mosaic_transforms(img_size, hyp):
def mosaic_transforms(imgsz, hyp):
pre_transform = Compose([
Mosaic(img_size=img_size, p=hyp.mosaic, border=[-img_size // 2, -img_size // 2]),
Mosaic(imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
CopyPaste(p=hyp.copy_paste),
RandomPerspective(
degrees=hyp.degrees,
@ -668,7 +668,7 @@ def mosaic_transforms(img_size, hyp):
scale=hyp.scale,
shear=hyp.shear,
perspective=hyp.perspective,
border=[-img_size // 2, -img_size // 2],
border=[-imgsz // 2, -imgsz // 2],
),])
return Compose([
pre_transform,
@ -682,9 +682,9 @@ def mosaic_transforms(img_size, hyp):
RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
def affine_transforms(img_size, hyp):
def affine_transforms(imgsz, hyp):
return Compose([
LetterBox(new_shape=(img_size, img_size)),
LetterBox(new_shape=(imgsz, imgsz)),
RandomPerspective(
degrees=hyp.degrees,
translate=hyp.translate,

View File

@ -24,7 +24,7 @@ class BaseDataset(Dataset):
def __init__(
self,
img_path,
img_size=640,
imgsz=640,
label_path=None,
cache=False,
augment=True,
@ -38,7 +38,7 @@ class BaseDataset(Dataset):
):
super().__init__()
self.img_path = img_path
self.img_size = img_size
self.imgsz = imgsz
self.label_path = label_path
self.augment = augment
self.prefix = prefix
@ -118,7 +118,7 @@ class BaseDataset(Dataset):
im = cv2.imread(f) # BGR
assert im is not None, f"Image Not Found {f}"
h0, w0 = im.shape[:2] # orig hw
r = self.img_size / max(h0, w0) # ratio
r = self.imgsz / max(h0, w0) # ratio
if r != 1: # if sizes are not equal
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
@ -168,7 +168,7 @@ class BaseDataset(Dataset):
elif mini > 1:
shapes[i] = [1, 1 / mini]
self.batch_shapes = np.ceil(np.array(shapes) * self.img_size / self.stride + self.pad).astype(int) * self.stride
self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
self.batch = bi # batch index of image
def __getitem__(self, index):

View File

@ -62,7 +62,7 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
dataset = YOLODataset(
img_path=img_path,
label_path=label_path,
img_size=cfg.img_size,
imgsz=cfg.imgsz,
batch_size=batch_size,
augment=True if mode == "train" else False, # augmentation
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function

View File

@ -18,10 +18,10 @@ from ultralytics.yolo.utils.checks import check_requirements
class LoadStreams:
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
def __init__(self, sources='file.streams', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.mode = 'stream'
self.img_size = img_size
self.imgsz = imgsz
self.stride = stride
self.vid_stride = vid_stride # video frame-rate stride
sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
@ -55,7 +55,7 @@ class LoadStreams:
LOGGER.info('') # newline
# check for common shapes
s = np.stack([LetterBox(img_size, auto, stride=stride)(image=x).shape for x in self.imgs])
s = np.stack([LetterBox(imgsz, auto, stride=stride)(image=x).shape for x in self.imgs])
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
self.auto = auto and self.rect
self.transforms = transforms # optional
@ -92,7 +92,7 @@ class LoadStreams:
if self.transforms:
im = np.stack([self.transforms(x) for x in im0]) # transforms
else:
im = np.stack([LetterBox(self.img_size, self.auto, stride=self.stride)(image=x) for x in im0])
im = np.stack([LetterBox(self.imgsz, self.auto, stride=self.stride)(image=x) for x in im0])
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
im = np.ascontiguousarray(im) # contiguous
@ -104,7 +104,7 @@ class LoadStreams:
class LoadScreenshots:
# YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"`
def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
def __init__(self, source, imgsz=640, stride=32, auto=True, transforms=None):
# source = [screen_number left top width height] (pixels)
check_requirements('mss')
import mss
@ -117,7 +117,7 @@ class LoadScreenshots:
left, top, width, height = (int(x) for x in params)
elif len(params) == 5:
self.screen, left, top, width, height = (int(x) for x in params)
self.img_size = img_size
self.imgsz = imgsz
self.stride = stride
self.transforms = transforms
self.auto = auto
@ -144,7 +144,7 @@ class LoadScreenshots:
if self.transforms:
im = self.transforms(im0) # transforms
else:
im = LetterBox(self.img_size, self.auto, stride=self.stride)(image=im0)
im = LetterBox(self.imgsz, self.auto, stride=self.stride)(image=im0)
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous
self.frame += 1
@ -153,7 +153,7 @@ class LoadScreenshots:
class LoadImages:
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
path = Path(path).read_text().rsplit()
files = []
@ -172,7 +172,7 @@ class LoadImages:
videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
ni, nv = len(images), len(videos)
self.img_size = img_size
self.imgsz = imgsz
self.stride = stride
self.files = images + videos
self.nf = ni + nv # number of files
@ -226,7 +226,7 @@ class LoadImages:
if self.transforms:
im = self.transforms(im0) # transforms
else:
im = LetterBox(self.img_size, self.auto, stride=self.stride)(image=im0)
im = LetterBox(self.imgsz, self.auto, stride=self.stride)(image=im0)
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous

View File

@ -24,7 +24,7 @@ class YOLODataset(BaseDataset):
def __init__(
self,
img_path,
img_size=640,
imgsz=640,
label_path=None,
cache=False,
augment=True,
@ -41,7 +41,7 @@ class YOLODataset(BaseDataset):
self.use_segments = use_segments
self.use_keypoints = use_keypoints
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
super().__init__(img_path, img_size, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
super().__init__(img_path, imgsz, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
single_cls)
def cache_labels(self, path=Path("./labels.cache")):
@ -128,11 +128,11 @@ class YOLODataset(BaseDataset):
# mosaic = False
if self.augment:
if mosaic:
transforms = mosaic_transforms(self.img_size, hyp)
transforms = mosaic_transforms(self.imgsz, hyp)
else:
transforms = affine_transforms(self.img_size, hyp)
transforms = affine_transforms(self.imgsz, hyp)
else:
transforms = Compose([LetterBox(new_shape=(self.img_size, self.img_size))])
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz))])
transforms.append(
Format(bbox_format="xywh",
normalize=True,

View File

@ -14,7 +14,7 @@ class MixAndRectDataset:
def __init__(self, dataset):
self.dataset = dataset
self.img_size = dataset.img_size
self.imgsz = dataset.imgsz
def __len__(self):
return len(self.dataset)

View File

@ -128,50 +128,50 @@ def verify_image_label(args):
return [None, None, None, None, None, nm, nf, ne, nc, msg]
def polygon2mask(img_size, polygons, color=1, downsample_ratio=1):
def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
"""
Args:
img_size (tuple): The image size.
imgsz (tuple): The image size.
polygons (np.ndarray): [N, M], N is the number of polygons,
M is the number of points(Be divided by 2).
"""
mask = np.zeros(img_size, dtype=np.uint8)
mask = np.zeros(imgsz, dtype=np.uint8)
polygons = np.asarray(polygons)
polygons = polygons.astype(np.int32)
shape = polygons.shape
polygons = polygons.reshape(shape[0], -1, 2)
cv2.fillPoly(mask, polygons, color=color)
nh, nw = (img_size[0] // downsample_ratio, img_size[1] // downsample_ratio)
nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
# NOTE: fillPoly firstly then resize is trying the keep the same way
# of loss calculation when mask-ratio=1.
mask = cv2.resize(mask, (nw, nh))
return mask
def polygons2masks(img_size, polygons, color, downsample_ratio=1):
def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
"""
Args:
img_size (tuple): The image size.
imgsz (tuple): The image size.
polygons (list[np.ndarray]): each polygon is [N, M],
N is the number of polygons,
M is the number of points(Be divided by 2).
"""
masks = []
for si in range(len(polygons)):
mask = polygon2mask(img_size, [polygons[si].reshape(-1)], color, downsample_ratio)
mask = polygon2mask(imgsz, [polygons[si].reshape(-1)], color, downsample_ratio)
masks.append(mask)
return np.array(masks)
def polygons2masks_overlap(img_size, segments, downsample_ratio=1):
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
"""Return a (640, 640) overlap mask."""
masks = np.zeros((img_size[0] // downsample_ratio, img_size[1] // downsample_ratio),
masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
dtype=np.int32 if len(segments) > 255 else np.uint8)
areas = []
ms = []
for si in range(len(segments)):
mask = polygon2mask(
img_size,
imgsz,
[segments[si].reshape(-1)],
downsample_ratio=downsample_ratio,
color=1,