Rename `img_size` to `imgsz` (#86)

single_channel
Glenn Jocher 2 years ago committed by GitHub
parent ae2443c210
commit 6432afc5f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -90,16 +90,16 @@ jobs:
- name: Test detection
shell: bash # for Windows compatibility
run: |
yolo task=detect mode=train model=yolov5n.yaml data=coco128.yaml epochs=1 img_size=64
yolo task=detect mode=val model=runs/exp/weights/last.pt img_size=64
yolo task=detect mode=train model=yolov5n.yaml data=coco128.yaml epochs=1 imgsz=64
yolo task=detect mode=val model=runs/exp/weights/last.pt imgsz=64
- name: Test segmentation
shell: bash # for Windows compatibility
# TODO: redo val test without hardcoded weights
run: |
yolo task=segment mode=train model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64
yolo task=segment mode=val model=runs/exp2/weights/last.pt data=coco128-seg.yaml img_size=64
yolo task=segment mode=train model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 imgsz=64
yolo task=segment mode=val model=runs/exp2/weights/last.pt data=coco128-seg.yaml imgsz=64
- name: Test classification
shell: bash # for Windows compatibility
run: |
yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 img_size=32
yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 imgsz=32
yolo task=classify mode=val model=runs/exp3/weights/last.pt data=mnist160

@ -21,7 +21,7 @@ Default training settings and hyperparameters for medium-augmentation COCO train
| epochs | 100 | Number of epochs to train |
| workers | 8 | Number of cpu workers used per process. Scales automatically with DDP |
| batch_size | 16 | Batch size of the dataloader |
| img_size | 640 | Image size of data in dataloader |
| imgsz | 640 | Image size of data in dataloader |
| optimizer | SGD | Optimizer used. Supported optimizer are: `Adam`, `SGD`, `RMSProp` |
| single_cls | False | Train on multi-class data as single-class |
| image_weights | False | Use weighted image selection for training |

@ -70,7 +70,7 @@ with open("ultralytics/tests/data/dataloader/hyp_test.yaml") as f:
def test(augment, rect):
dataloader, _ = build_dataloader(
img_path="/d/dataset/COCO/images/val2017",
img_size=640,
imgsz=640,
label_path=None,
cache=False,
hyp=hyp,

@ -36,13 +36,13 @@ def test_visualize_preds():
def test_val():
model = YOLO()
model.load("balloon-segment.pt")
model.val(data="coco128-seg.yaml", img_size=32)
model.val(data="coco128-seg.yaml", imgsz=32)
def test_model_resume():
model = YOLO()
model.new("yolov5n-seg.yaml")
model.train(epochs=1, img_size=32, data="coco128-seg.yaml")
model.train(epochs=1, imgsz=32, data="coco128-seg.yaml")
try:
model.resume(task="segment")
except AssertionError:
@ -52,9 +52,9 @@ def test_model_resume():
def test_model_train_pretrained():
model = YOLO()
model.load("balloon-detect.pt")
model.train(data="coco128.yaml", epochs=1, img_size=32)
model.train(data="coco128.yaml", epochs=1, imgsz=32)
model.new("yolov5n.yaml")
model.train(data="coco128.yaml", epochs=1, img_size=32)
model.train(data="coco128.yaml", epochs=1, imgsz=32)
img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512)
model(img)

@ -114,15 +114,15 @@ class BaseMixTransform:
class Mosaic(BaseMixTransform):
"""Mosaic augmentation.
Args:
img_size (Sequence[int]): Image size after mosaic pipeline of single
imgsz (Sequence[int]): Image size after mosaic pipeline of single
image. The shape order should be (height, width).
Default to (640, 640).
"""
def __init__(self, img_size=640, p=1.0, border=(0, 0)):
def __init__(self, imgsz=640, p=1.0, border=(0, 0)):
assert 0 <= p <= 1.0, "The probability should be in range [0, 1]. " f"got {p}."
super().__init__(pre_transform=None, p=p)
self.img_size = img_size
self.imgsz = imgsz
self.border = border
def get_indexes(self, dataset):
@ -132,7 +132,7 @@ class Mosaic(BaseMixTransform):
mosaic_labels = []
assert labels.get("rect_shape", None) is None, "rect and mosaic is exclusive."
assert len(labels.get("mix_labels", [])) > 0, "There are no other images for mosaic augment."
s = self.img_size
s = self.imgsz
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
mix_labels = labels["mix_labels"]
for i in range(4):
@ -184,12 +184,12 @@ class Mosaic(BaseMixTransform):
instances.append(labels["instances"])
final_labels = {
"ori_shape": mosaic_labels[0]["ori_shape"],
"resized_shape": (self.img_size * 2, self.img_size * 2),
"resized_shape": (self.imgsz * 2, self.imgsz * 2),
"im_file": mosaic_labels[0]["im_file"],
"cls": np.concatenate(cls, 0)}
final_labels["instances"] = Instances.concatenate(instances, axis=0)
final_labels["instances"].clip(self.img_size * 2, self.img_size * 2)
final_labels["instances"].clip(self.imgsz * 2, self.imgsz * 2)
return final_labels
@ -658,9 +658,9 @@ class Format:
return masks, instances, cls
def mosaic_transforms(img_size, hyp):
def mosaic_transforms(imgsz, hyp):
pre_transform = Compose([
Mosaic(img_size=img_size, p=hyp.mosaic, border=[-img_size // 2, -img_size // 2]),
Mosaic(imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
CopyPaste(p=hyp.copy_paste),
RandomPerspective(
degrees=hyp.degrees,
@ -668,7 +668,7 @@ def mosaic_transforms(img_size, hyp):
scale=hyp.scale,
shear=hyp.shear,
perspective=hyp.perspective,
border=[-img_size // 2, -img_size // 2],
border=[-imgsz // 2, -imgsz // 2],
),])
return Compose([
pre_transform,
@ -682,9 +682,9 @@ def mosaic_transforms(img_size, hyp):
RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
def affine_transforms(img_size, hyp):
def affine_transforms(imgsz, hyp):
return Compose([
LetterBox(new_shape=(img_size, img_size)),
LetterBox(new_shape=(imgsz, imgsz)),
RandomPerspective(
degrees=hyp.degrees,
translate=hyp.translate,

@ -24,7 +24,7 @@ class BaseDataset(Dataset):
def __init__(
self,
img_path,
img_size=640,
imgsz=640,
label_path=None,
cache=False,
augment=True,
@ -38,7 +38,7 @@ class BaseDataset(Dataset):
):
super().__init__()
self.img_path = img_path
self.img_size = img_size
self.imgsz = imgsz
self.label_path = label_path
self.augment = augment
self.prefix = prefix
@ -118,7 +118,7 @@ class BaseDataset(Dataset):
im = cv2.imread(f) # BGR
assert im is not None, f"Image Not Found {f}"
h0, w0 = im.shape[:2] # orig hw
r = self.img_size / max(h0, w0) # ratio
r = self.imgsz / max(h0, w0) # ratio
if r != 1: # if sizes are not equal
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
@ -168,7 +168,7 @@ class BaseDataset(Dataset):
elif mini > 1:
shapes[i] = [1, 1 / mini]
self.batch_shapes = np.ceil(np.array(shapes) * self.img_size / self.stride + self.pad).astype(int) * self.stride
self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
self.batch = bi # batch index of image
def __getitem__(self, index):

@ -62,7 +62,7 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
dataset = YOLODataset(
img_path=img_path,
label_path=label_path,
img_size=cfg.img_size,
imgsz=cfg.imgsz,
batch_size=batch_size,
augment=True if mode == "train" else False, # augmentation
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function

@ -18,10 +18,10 @@ from ultralytics.yolo.utils.checks import check_requirements
class LoadStreams:
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
def __init__(self, sources='file.streams', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.mode = 'stream'
self.img_size = img_size
self.imgsz = imgsz
self.stride = stride
self.vid_stride = vid_stride # video frame-rate stride
sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
@ -55,7 +55,7 @@ class LoadStreams:
LOGGER.info('') # newline
# check for common shapes
s = np.stack([LetterBox(img_size, auto, stride=stride)(image=x).shape for x in self.imgs])
s = np.stack([LetterBox(imgsz, auto, stride=stride)(image=x).shape for x in self.imgs])
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
self.auto = auto and self.rect
self.transforms = transforms # optional
@ -92,7 +92,7 @@ class LoadStreams:
if self.transforms:
im = np.stack([self.transforms(x) for x in im0]) # transforms
else:
im = np.stack([LetterBox(self.img_size, self.auto, stride=self.stride)(image=x) for x in im0])
im = np.stack([LetterBox(self.imgsz, self.auto, stride=self.stride)(image=x) for x in im0])
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
im = np.ascontiguousarray(im) # contiguous
@ -104,7 +104,7 @@ class LoadStreams:
class LoadScreenshots:
# YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"`
def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
def __init__(self, source, imgsz=640, stride=32, auto=True, transforms=None):
# source = [screen_number left top width height] (pixels)
check_requirements('mss')
import mss
@ -117,7 +117,7 @@ class LoadScreenshots:
left, top, width, height = (int(x) for x in params)
elif len(params) == 5:
self.screen, left, top, width, height = (int(x) for x in params)
self.img_size = img_size
self.imgsz = imgsz
self.stride = stride
self.transforms = transforms
self.auto = auto
@ -144,7 +144,7 @@ class LoadScreenshots:
if self.transforms:
im = self.transforms(im0) # transforms
else:
im = LetterBox(self.img_size, self.auto, stride=self.stride)(image=im0)
im = LetterBox(self.imgsz, self.auto, stride=self.stride)(image=im0)
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous
self.frame += 1
@ -153,7 +153,7 @@ class LoadScreenshots:
class LoadImages:
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
path = Path(path).read_text().rsplit()
files = []
@ -172,7 +172,7 @@ class LoadImages:
videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
ni, nv = len(images), len(videos)
self.img_size = img_size
self.imgsz = imgsz
self.stride = stride
self.files = images + videos
self.nf = ni + nv # number of files
@ -226,7 +226,7 @@ class LoadImages:
if self.transforms:
im = self.transforms(im0) # transforms
else:
im = LetterBox(self.img_size, self.auto, stride=self.stride)(image=im0)
im = LetterBox(self.imgsz, self.auto, stride=self.stride)(image=im0)
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous

@ -24,7 +24,7 @@ class YOLODataset(BaseDataset):
def __init__(
self,
img_path,
img_size=640,
imgsz=640,
label_path=None,
cache=False,
augment=True,
@ -41,7 +41,7 @@ class YOLODataset(BaseDataset):
self.use_segments = use_segments
self.use_keypoints = use_keypoints
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
super().__init__(img_path, img_size, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
super().__init__(img_path, imgsz, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
single_cls)
def cache_labels(self, path=Path("./labels.cache")):
@ -128,11 +128,11 @@ class YOLODataset(BaseDataset):
# mosaic = False
if self.augment:
if mosaic:
transforms = mosaic_transforms(self.img_size, hyp)
transforms = mosaic_transforms(self.imgsz, hyp)
else:
transforms = affine_transforms(self.img_size, hyp)
transforms = affine_transforms(self.imgsz, hyp)
else:
transforms = Compose([LetterBox(new_shape=(self.img_size, self.img_size))])
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz))])
transforms.append(
Format(bbox_format="xywh",
normalize=True,

@ -14,7 +14,7 @@ class MixAndRectDataset:
def __init__(self, dataset):
self.dataset = dataset
self.img_size = dataset.img_size
self.imgsz = dataset.imgsz
def __len__(self):
return len(self.dataset)

@ -128,50 +128,50 @@ def verify_image_label(args):
return [None, None, None, None, None, nm, nf, ne, nc, msg]
def polygon2mask(img_size, polygons, color=1, downsample_ratio=1):
def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
"""
Args:
img_size (tuple): The image size.
imgsz (tuple): The image size.
polygons (np.ndarray): [N, M], N is the number of polygons,
M is the number of points(Be divided by 2).
"""
mask = np.zeros(img_size, dtype=np.uint8)
mask = np.zeros(imgsz, dtype=np.uint8)
polygons = np.asarray(polygons)
polygons = polygons.astype(np.int32)
shape = polygons.shape
polygons = polygons.reshape(shape[0], -1, 2)
cv2.fillPoly(mask, polygons, color=color)
nh, nw = (img_size[0] // downsample_ratio, img_size[1] // downsample_ratio)
nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
# NOTE: fillPoly firstly then resize is trying the keep the same way
# of loss calculation when mask-ratio=1.
mask = cv2.resize(mask, (nw, nh))
return mask
def polygons2masks(img_size, polygons, color, downsample_ratio=1):
def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
"""
Args:
img_size (tuple): The image size.
imgsz (tuple): The image size.
polygons (list[np.ndarray]): each polygon is [N, M],
N is the number of polygons,
M is the number of points(Be divided by 2).
"""
masks = []
for si in range(len(polygons)):
mask = polygon2mask(img_size, [polygons[si].reshape(-1)], color, downsample_ratio)
mask = polygon2mask(imgsz, [polygons[si].reshape(-1)], color, downsample_ratio)
masks.append(mask)
return np.array(masks)
def polygons2masks_overlap(img_size, segments, downsample_ratio=1):
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
"""Return a (640, 640) overlap mask."""
masks = np.zeros((img_size[0] // downsample_ratio, img_size[1] // downsample_ratio),
masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
dtype=np.int32 if len(segments) > 255 else np.uint8)
areas = []
ms = []
for si in range(len(segments)):
mask = polygon2mask(
img_size,
imgsz,
[segments[si].reshape(-1)],
downsample_ratio=downsample_ratio,
color=1,

@ -111,11 +111,11 @@ class YOLO:
predictor = self.PredictorClass(overrides=kwargs)
# check size type
sz = predictor.args.img_size
sz = predictor.args.imgsz
if type(sz) != int: # recieved listConfig
predictor.args.img_size = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]] # expand
predictor.args.imgsz = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]] # expand
else:
predictor.args.img_size = [sz, sz]
predictor.args.imgsz = [sz, sz]
predictor.setup(model=self.model, source=source)
predictor()

@ -39,7 +39,7 @@ from ultralytics.yolo.utils.configs import get_config
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.modeling.autobackend import AutoBackend
from ultralytics.yolo.utils.plotting import Annotator
from ultralytics.yolo.utils.torch_utils import check_img_size, select_device, smart_inference_mode
from ultralytics.yolo.utils.torch_utils import check_imgsz, select_device, smart_inference_mode
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@ -99,18 +99,18 @@ class BasePredictor:
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half) # NOTE: not passing data
stride, pt = model.stride, model.pt
imgsz = check_img_size(self.args.img_size, s=stride) # check image size
imgsz = check_imgsz(self.args.imgsz, s=stride) # check image size
# Dataloader
bs = 1 # batch_size
if webcam:
self.view_img = check_imshow(warn=True)
self.dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=self.args.vid_stride)
self.dataset = LoadStreams(source, imgsz=imgsz, stride=stride, auto=pt, vid_stride=self.args.vid_stride)
bs = len(self.dataset)
elif screenshot:
self.dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
self.dataset = LoadScreenshots(source, imgsz=imgsz, stride=stride, auto=pt)
else:
self.dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=self.args.vid_stride)
self.dataset = LoadImages(source, imgsz=imgsz, stride=stride, auto=pt, vid_stride=self.args.vid_stride)
self.vid_path, self.vid_writer = [None] * bs, [None] * bs
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup

@ -12,7 +12,7 @@ from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.modeling.autobackend import AutoBackend
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import check_img_size, de_parallel, select_device
from ultralytics.yolo.utils.torch_utils import check_imgsz, de_parallel, select_device
class BaseValidator:
@ -55,7 +55,7 @@ class BaseValidator:
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, fp16=self.args.half)
self.model = model
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
imgsz = check_img_size(self.args.img_size, s=stride)
imgsz = check_imgsz(self.args.imgsz, s=stride)
if engine:
self.args.batch_size = model.batch_size
else:

@ -51,7 +51,7 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
else:
LOGGER.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')
na = m.anchors.numel() // 2 # number of anchors
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
anchors = kmean_anchors(dataset, n=na, imgsz=imgsz, thr=thr, gen=1000, verbose=False)
new_bpr = metric(anchors)[0]
if new_bpr > bpr: # replace anchors
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
@ -64,13 +64,13 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
LOGGER.info(s)
def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
def kmean_anchors(dataset='./data/coco128.yaml', n=9, imgsz=640, thr=4.0, gen=1000, verbose=True):
""" Creates kmeans-evolved anchors from training dataset
Arguments:
dataset: path to data.yaml, or a loaded dataset
n: number of anchors
img_size: image size used for training
imgsz: image size used for training
thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
gen: generations to evolve anchors using genetic algorithm
verbose: print all results
@ -101,7 +101,7 @@ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen
x, best = metric(k, wh0)
bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
f'{PREFIX}n={n}, imgsz={imgsz}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
f'past_thr={x[x > thr].mean():.3f}-mean: '
for x in k:
s += '%i,%i, ' % (round(x[0]), round(x[1]))
@ -116,7 +116,7 @@ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen
dataset = BaseDataset(data_dict['train'], augment=True, rect=True)
# Get label wh
shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
# Filter
@ -135,7 +135,7 @@ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen
assert n == len(k) # kmeans may return fewer points than requested if wh is insufficient or too similar
except Exception:
LOGGER.warning(f'{PREFIX}WARNING ⚠️ switching strategies from kmeans to random init')
k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size # random init
k = np.sort(npr.rand(n * 2)).reshape(n, 2) * imgsz # random init
wh, wh0 = (torch.tensor(x, dtype=torch.float32) for x in (wh, wh0))
k = print_results(k, verbose=False)

@ -10,7 +10,7 @@ model: null # i.e. yolov5s.pt, yolo.yaml
data: null # i.e. coco128.yaml
epochs: 300
batch_size: 16
img_size: 640
imgsz: 640
nosave: False
cache: False # True/ram, disk or False
device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu

@ -51,8 +51,8 @@ class BaseModel(nn.Module):
self.info()
return self
def info(self, verbose=False, img_size=640): # print model information
model_info(self, verbose, img_size)
def info(self, verbose=False, imgsz=640): # print model information
model_info(self, verbose, imgsz)
def _apply(self, fn):
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
@ -117,7 +117,7 @@ class DetectionModel(BaseModel):
return self._forward_once(x, profile, visualize) # single-scale inference, train
def _forward_augment(self, x):
img_size = x.shape[-2:] # height, width
imgsz = x.shape[-2:] # height, width
s = [1, 0.83, 0.67] # scales
f = [None, 3, None] # flips (2-ud, 3-lr)
y = [] # outputs
@ -125,25 +125,25 @@ class DetectionModel(BaseModel):
xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
yi = self._forward_once(xi)[0] # forward
# cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
yi = self._descale_pred(yi, fi, si, img_size)
yi = self._descale_pred(yi, fi, si, imgsz)
y.append(yi)
y = self._clip_augmented(y) # clip augmented tails
return torch.cat(y, 1), None # augmented inference, train
def _descale_pred(self, p, flips, scale, img_size):
def _descale_pred(self, p, flips, scale, imgsz):
# de-scale predictions following augmented inference (inverse operation)
if self.inplace:
p[..., :4] /= scale # de-scale
if flips == 2:
p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
p[..., 1] = imgsz[0] - p[..., 1] # de-flip ud
elif flips == 3:
p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
p[..., 0] = imgsz[1] - p[..., 0] # de-flip lr
else:
x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
if flips == 2:
y = img_size[0] - y # de-flip ud
y = imgsz[0] - y # de-flip ud
elif flips == 3:
x = img_size[1] - x # de-flip lr
x = imgsz[1] - x # de-flip lr
p = torch.cat((x, y, wh, p[..., 4:]), -1)
return p

@ -124,7 +124,7 @@ def fuse_conv_and_bn(conv, bn):
def model_info(model, verbose=False, imgsz=640):
# Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
# Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
n_p = get_num_params(model)
n_g = get_num_gradients(model) # number gradients
if verbose:
@ -185,11 +185,11 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
def check_img_size(imgsz, s=32, floor=0):
def check_imgsz(imgsz, s=32, floor=0):
# Verify image size is a multiple of stride s in each dimension
if isinstance(imgsz, int): # integer i.e. img_size=640
if isinstance(imgsz, int): # integer i.e. imgsz=640
new_size = max(make_divisible(imgsz, int(s)), floor)
else: # list i.e. img_size=[640, 480]
else: # list i.e. imgsz=[640, 480]
imgsz = list(imgsz) # convert to list if tuple
new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
if new_size != imgsz:

@ -55,11 +55,11 @@ class ClassificationPredictor(BasePredictor):
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def predict(cfg):
cfg.model = cfg.model or "squeezenet1_0"
sz = cfg.img_size
sz = cfg.imgsz
if type(sz) != int: # recieved listConfig
cfg.img_size = [sz[0], sz[0]] if len(cfg.img_size) == 1 else [sz[0], sz[1]] # expand
cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]] # expand
else:
cfg.img_size = [sz, sz]
cfg.imgsz = [sz, sz]
predictor = ClassificationPredictor(cfg)
predictor()

@ -36,7 +36,7 @@ class ClassificationTrainer(BaseTrainer):
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
return build_classification_dataloader(path=dataset_path,
imgsz=self.args.img_size,
imgsz=self.args.imgsz,
batch_size=batch_size,
rank=rank)
@ -70,7 +70,7 @@ def train(cfg):
if __name__ == "__main__":
"""
CLI usage:
python ultralytics/yolo/v8/classify/train.py model=resnet18 data=imagenette160 epochs=1 img_size=224
python ultralytics/yolo/v8/classify/train.py model=resnet18 data=imagenette160 epochs=1 imgsz=224
TODO:
Direct cli support, i.e, yolov8 classify_train args.epochs 10

@ -28,7 +28,7 @@ class ClassificationValidator(BaseValidator):
return {"top1": top1, "top5": top5, "fitness": top5}
def get_dataloader(self, dataset_path, batch_size):
return build_classification_dataloader(path=dataset_path, imgsz=self.args.img_size, batch_size=batch_size)
return build_classification_dataloader(path=dataset_path, imgsz=self.args.imgsz, batch_size=batch_size)
@property
def metric_keys(self):

@ -84,11 +84,11 @@ class DetectionPredictor(BasePredictor):
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def predict(cfg):
cfg.model = cfg.model or "n.pt"
sz = cfg.img_size
sz = cfg.imgsz
if type(sz) != int: # recieved listConfig
cfg.img_size = [sz[0], sz[0]] if len(cfg.img_size) == 1 else [sz[0], sz[1]] # expand
cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]] # expand
else:
cfg.img_size = [sz, sz]
cfg.imgsz = [sz, sz]
predictor = DetectionPredictor(cfg)
predictor()

@ -28,7 +28,7 @@ class DetectionTrainer(BaseTrainer):
nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)
self.args.box *= 3 / nl # scale to layers
self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl # scale to image size and layers
self.args.obj *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
self.model.nc = self.data["nc"] # attach number of classes to model
self.model.args = self.args # attach hyperparameters to model
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
@ -223,7 +223,7 @@ def train(cfg):
if __name__ == "__main__":
"""
CLI usage:
python ultralytics/yolo/v8/detect/train.py model=yolov5n.yaml data=coco128 epochs=100 img_size=640
python ultralytics/yolo/v8/detect/train.py model=yolov5n.yaml data=coco128 epochs=100 imgsz=640
TODO:
yolo task=detect mode=train model=yolov5n.yaml data=coco128.yaml epochs=100

@ -102,11 +102,11 @@ class SegmentationPredictor(DetectionPredictor):
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def predict(cfg):
cfg.model = cfg.model or "n.pt"
sz = cfg.img_size
sz = cfg.imgsz
if type(sz) != int: # recieved listConfig
cfg.img_size = [sz[0], sz[0]] if len(cfg.img_size) == 1 else [sz[0], sz[1]] # expand
cfg.imgsz = [sz[0], sz[0]] if len(cfg.imgsz) == 1 else [sz[0], sz[1]] # expand
else:
cfg.img_size = [sz, sz]
cfg.imgsz = [sz, sz]
predictor = SegmentationPredictor(cfg)
predictor()

@ -243,7 +243,7 @@ def train(cfg):
if __name__ == "__main__":
"""
CLI usage:
python ultralytics/yolo/v8/segment/train.py model=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640
python ultralytics/yolo/v8/segment/train.py model=yolov5n-seg.yaml data=coco128-segments epochs=100 imgsz=640
TODO:
Direct cli support, i.e, yolov8 classify_train args.epochs 10

Loading…
Cancel
Save