General console printout updates (#48)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2022-11-19 16:08:16 +01:00
committed by GitHub
parent 8530e3fae0
commit 27d6545117
12 changed files with 81 additions and 105 deletions

View File

@ -665,7 +665,7 @@ def mosaic_transforms(img_size, hyp):
perspective=hyp.perspective,
border=[-img_size // 2, -img_size // 2],
),])
transforms = Compose([
return Compose([
pre_transform,
MixUp(
pre_transform=pre_transform,
@ -674,13 +674,11 @@ def mosaic_transforms(img_size, hyp):
Albumentations(p=1.0),
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
RandomFlip(direction="vertical", p=hyp.flipud),
RandomFlip(direction="horizontal", p=hyp.fliplr),])
return transforms
RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
def affine_transforms(img_size, hyp):
# rect, randomperspective, albumentation, hsv, flipud, fliplr
transforms = Compose([
return Compose([
LetterBox(new_shape=(img_size, img_size)),
RandomPerspective(
degrees=hyp.degrees,
@ -693,11 +691,10 @@ def affine_transforms(img_size, hyp):
Albumentations(p=1.0),
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
RandomFlip(direction="vertical", p=hyp.flipud),
RandomFlip(direction="horizontal", p=hyp.fliplr),])
return transforms
RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
# Classification augmentations -------------------------------------------------------------------------------------------
# Classification augmentations -----------------------------------------------------------------------------------------
def classify_transforms(size=224):
# Transforms to apply if albumentations not installed
assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"

View File

@ -9,8 +9,8 @@ import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm
from ..utils import NUM_THREADS
from .utils import BAR_FORMAT, HELP_URL, IMG_FORMATS, LOCAL_RANK
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
from .utils import HELP_URL, IMG_FORMATS, LOCAL_RANK
class BaseDataset(Dataset):
@ -18,7 +18,7 @@ class BaseDataset(Dataset):
Args:
img_path (str): image path.
pipeline (dict): a dict of image transforms.
label_path (str): label path, this can also be a ann_file or other custom label path.
label_path (str): label path, this can also be an ann_file or other custom label path.
"""
def __init__(
@ -131,7 +131,7 @@ class BaseDataset(Dataset):
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
fcn = self.cache_images_to_disk if self.cache == "disk" else self.load_image
results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni))
pbar = tqdm(enumerate(results), total=self.ni, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0)
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar:
if self.cache == "disk":
gb += self.npy_files[i].stat().st_size

View File

@ -6,10 +6,10 @@ from typing import OrderedDict
import torchvision
from tqdm import tqdm
from ..utils import NUM_THREADS
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
from .augment import *
from .base import BaseDataset
from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
class YOLODataset(BaseDataset):
@ -40,7 +40,7 @@ class YOLODataset(BaseDataset):
):
self.use_segments = use_segments
self.use_keypoints = use_keypoints
assert not (self.use_segments and self.use_keypoints), "We can't use both of segmentation and pose."
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
super().__init__(img_path, img_size, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
single_cls)
@ -48,14 +48,14 @@ class YOLODataset(BaseDataset):
# Cache dataset labels, check images and read shapes
x = {"labels": []}
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{self.prefix}Scanning '{path.parent / path.stem}' images and labels..."
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
with Pool(NUM_THREADS) as pool:
pbar = tqdm(
pool.imap(verify_image_label,
zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
desc=desc,
total=len(self.im_files),
bar_format=BAR_FORMAT,
bar_format=TQDM_BAR_FORMAT,
)
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f
@ -76,7 +76,7 @@ class YOLODataset(BaseDataset):
))
if msg:
msgs.append(msg)
pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
pbar.close()
if msgs:
@ -109,8 +109,8 @@ class YOLODataset(BaseDataset):
# Display cache
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
if exists and LOCAL_RANK in {-1, 0}:
d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
if cache["msgs"]:
LOGGER.info("\n".join(cache["msgs"])) # display warnings
assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"

View File

@ -22,7 +22,6 @@ from ..utils.ops import segments2boxes
HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv" # include video suffixes
BAR_FORMAT = "{l_bar}{bar:10}{r_bar}{bar:-10b}" # tqdm bar format
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders