General console printout updates (#48)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -665,7 +665,7 @@ def mosaic_transforms(img_size, hyp):
|
||||
perspective=hyp.perspective,
|
||||
border=[-img_size // 2, -img_size // 2],
|
||||
),])
|
||||
transforms = Compose([
|
||||
return Compose([
|
||||
pre_transform,
|
||||
MixUp(
|
||||
pre_transform=pre_transform,
|
||||
@ -674,13 +674,11 @@ def mosaic_transforms(img_size, hyp):
|
||||
Albumentations(p=1.0),
|
||||
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
||||
RandomFlip(direction="vertical", p=hyp.flipud),
|
||||
RandomFlip(direction="horizontal", p=hyp.fliplr),])
|
||||
return transforms
|
||||
RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
|
||||
|
||||
|
||||
def affine_transforms(img_size, hyp):
|
||||
# rect, randomperspective, albumentation, hsv, flipud, fliplr
|
||||
transforms = Compose([
|
||||
return Compose([
|
||||
LetterBox(new_shape=(img_size, img_size)),
|
||||
RandomPerspective(
|
||||
degrees=hyp.degrees,
|
||||
@ -693,11 +691,10 @@ def affine_transforms(img_size, hyp):
|
||||
Albumentations(p=1.0),
|
||||
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
||||
RandomFlip(direction="vertical", p=hyp.flipud),
|
||||
RandomFlip(direction="horizontal", p=hyp.fliplr),])
|
||||
return transforms
|
||||
RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
|
||||
|
||||
|
||||
# Classification augmentations -------------------------------------------------------------------------------------------
|
||||
# Classification augmentations -----------------------------------------------------------------------------------------
|
||||
def classify_transforms(size=224):
|
||||
# Transforms to apply if albumentations not installed
|
||||
assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"
|
||||
|
@ -9,8 +9,8 @@ import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import NUM_THREADS
|
||||
from .utils import BAR_FORMAT, HELP_URL, IMG_FORMATS, LOCAL_RANK
|
||||
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
|
||||
from .utils import HELP_URL, IMG_FORMATS, LOCAL_RANK
|
||||
|
||||
|
||||
class BaseDataset(Dataset):
|
||||
@ -18,7 +18,7 @@ class BaseDataset(Dataset):
|
||||
Args:
|
||||
img_path (str): image path.
|
||||
pipeline (dict): a dict of image transforms.
|
||||
label_path (str): label path, this can also be a ann_file or other custom label path.
|
||||
label_path (str): label path, this can also be an ann_file or other custom label path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -131,7 +131,7 @@ class BaseDataset(Dataset):
|
||||
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
|
||||
fcn = self.cache_images_to_disk if self.cache == "disk" else self.load_image
|
||||
results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni))
|
||||
pbar = tqdm(enumerate(results), total=self.ni, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if self.cache == "disk":
|
||||
gb += self.npy_files[i].stat().st_size
|
||||
|
@ -6,10 +6,10 @@ from typing import OrderedDict
|
||||
import torchvision
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import NUM_THREADS
|
||||
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
|
||||
from .augment import *
|
||||
from .base import BaseDataset
|
||||
from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
|
||||
from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
|
||||
|
||||
|
||||
class YOLODataset(BaseDataset):
|
||||
@ -40,7 +40,7 @@ class YOLODataset(BaseDataset):
|
||||
):
|
||||
self.use_segments = use_segments
|
||||
self.use_keypoints = use_keypoints
|
||||
assert not (self.use_segments and self.use_keypoints), "We can't use both of segmentation and pose."
|
||||
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
|
||||
super().__init__(img_path, img_size, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
|
||||
single_cls)
|
||||
|
||||
@ -48,14 +48,14 @@ class YOLODataset(BaseDataset):
|
||||
# Cache dataset labels, check images and read shapes
|
||||
x = {"labels": []}
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f"{self.prefix}Scanning '{path.parent / path.stem}' images and labels..."
|
||||
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
||||
with Pool(NUM_THREADS) as pool:
|
||||
pbar = tqdm(
|
||||
pool.imap(verify_image_label,
|
||||
zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
|
||||
desc=desc,
|
||||
total=len(self.im_files),
|
||||
bar_format=BAR_FORMAT,
|
||||
bar_format=TQDM_BAR_FORMAT,
|
||||
)
|
||||
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
||||
nm += nm_f
|
||||
@ -76,7 +76,7 @@ class YOLODataset(BaseDataset):
|
||||
))
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
|
||||
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
|
||||
pbar.close()
|
||||
if msgs:
|
||||
@ -109,8 +109,8 @@ class YOLODataset(BaseDataset):
|
||||
# Display cache
|
||||
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
|
||||
if exists and LOCAL_RANK in {-1, 0}:
|
||||
d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
|
||||
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
|
||||
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
|
||||
if cache["msgs"]:
|
||||
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
||||
assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"
|
||||
|
@ -22,7 +22,6 @@ from ..utils.ops import segments2boxes
|
||||
HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
|
||||
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes
|
||||
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv" # include video suffixes
|
||||
BAR_FORMAT = "{l_bar}{bar:10}{r_bar}{bar:-10b}" # tqdm bar format
|
||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
|
||||
|
Reference in New Issue
Block a user