ultralytics 8.0.47 Docker and reformat updates (#1153)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher authored on 2023-02-25 22:49:19 -08:00; committed by GitHub
parent d4be4cb24b
commit a58f766f94
41 changed files with 224 additions and 201 deletions

View File

@@ -5,12 +5,5 @@ from .build import build_classification_dataloader, build_dataloader, load_infer
 from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
 from .dataset_wrappers import MixAndRectDataset
-__all__ = [
-'BaseDataset',
-'ClassificationDataset',
-'MixAndRectDataset',
-'SemanticDataset',
-'YOLODataset',
-'build_classification_dataloader',
-'build_dataloader',
-'load_inference_source',]
+__all__ = ('BaseDataset', 'ClassificationDataset', 'MixAndRectDataset', 'SemanticDataset', 'YOLODataset',
+'build_classification_dataloader', 'build_dataloader', 'load_inference_source')
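
The change above swaps a multi-line __all__ list for a tuple; behaviour is identical because a wildcard import just iterates over whatever sequence of names __all__ holds. A minimal, self-contained sketch of that mechanism (fake_data_pkg and its attributes are made up for illustration, not part of this commit):

import types

mod = types.ModuleType('fake_data_pkg')             # stand-in for a package like ultralytics.yolo.data
mod.BaseDataset = object
mod.build_dataloader = lambda *args, **kwargs: None
mod._private_helper = lambda: None
mod.__all__ = ('BaseDataset', 'build_dataloader')   # a tuple works exactly like a list here

# Names that `from fake_data_pkg import *` would bind in the caller's namespace:
exported = {name: getattr(mod, name) for name in mod.__all__}
print(sorted(exported))                             # ['BaseDataset', 'build_dataloader']
assert '_private_helper' not in exported            # names outside __all__ stay unexported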

View File

@@ -564,7 +564,7 @@ class Albumentations:
 A.CLAHE(p=0.01),
 A.RandomBrightnessContrast(p=0.0),
 A.RandomGamma(p=0.0),
-A.ImageCompression(quality_lower=75, p=0.0),] # transforms
+A.ImageCompression(quality_lower=75, p=0.0)] # transforms
 self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
 LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
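
The bbox_params argument in the context line above is what keeps boxes and class labels synchronized with the image transforms. A hedged usage sketch of such a pipeline (the single RandomBrightnessContrast transform and the dummy data are assumptions for illustration, not the repository's transform list):

import albumentations as A
import numpy as np

# 'yolo' box format: normalized (x_center, y_center, width, height) in (0, 1].
transform = A.Compose(
    [A.RandomBrightnessContrast(p=1.0)],
    bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

image = np.zeros((640, 640, 3), dtype=np.uint8)     # dummy HWC image
bboxes = [(0.5, 0.5, 0.2, 0.3)]                     # one normalized box
out = transform(image=image, bboxes=bboxes, class_labels=[0])
print(out['image'].shape, out['bboxes'], out['class_labels'])
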
@@ -671,14 +671,14 @@ def v8_transforms(dataset, imgsz, hyp):
 shear=hyp.shear,
 perspective=hyp.perspective,
 pre_transform=LetterBox(new_shape=(imgsz, imgsz)),
-),])
+)])
 return Compose([
 pre_transform,
 MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
 Albumentations(p=1.0),
 RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
 RandomFlip(direction='vertical', p=hyp.flipud),
-RandomFlip(direction='horizontal', p=hyp.fliplr),]) # transforms
+RandomFlip(direction='horizontal', p=hyp.fliplr)]) # transforms
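
Compose in the hunk above chains callables that each take and return the same sample dict. A simplified sketch of that pattern only, not the library's actual class:

class MiniCompose:
    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, sample):
        for t in self.transforms:       # apply each transform in order
            sample = t(sample)
        return sample

# Toy dict-to-dict transforms standing in for MixUp, RandomHSV, RandomFlip, etc.
double = lambda s: {**s, 'value': s['value'] * 2}
add_one = lambda s: {**s, 'value': s['value'] + 1}

pipeline = MiniCompose([double, add_one])
print(pipeline({'value': 3}))           # {'value': 7}
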
# Classification augmentations -----------------------------------------------------------------------------------------
@@ -719,8 +719,8 @@ def classify_albumentations(
 if vflip > 0:
 T += [A.VerticalFlip(p=vflip)]
 if jitter > 0:
-color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, saturation, 0 hue
-T += [A.ColorJitter(*color_jitter, 0)]
+jitter = float(jitter)
+T += [A.ColorJitter(jitter, jitter, jitter, 0)] # brightness, contrast, saturation, 0 hue
 else: # Use fixed crop for eval set (reproducibility)
 T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
 T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
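
Both classify_albumentations hunks (here and in the second copy further down) make the same simplification: the jitter value is passed to ColorJitter explicitly instead of via an unpacked 3-tuple. A quick sketch showing the two spellings configure the transform identically (jitter=0.4 is an assumed value, not from the diff):

import albumentations as A

jitter = 0.4                                     # assumed example value
color_jitter = (float(jitter),) * 3              # old style: repeated 3-tuple
old_style = A.ColorJitter(*color_jitter, 0)      # brightness, contrast, saturation, hue=0
new_style = A.ColorJitter(jitter, jitter, jitter, 0)

print(old_style)                                 # repr shows the resolved parameter ranges
print(new_style)                                 # identical to the line above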

View File

@@ -24,20 +24,18 @@ class BaseDataset(Dataset):
 label_path (str): label path, this can also be an ann_file or other custom label path.
 """
-def __init__(
-self,
-img_path,
-imgsz=640,
-cache=False,
-augment=True,
-hyp=None,
-prefix='',
-rect=False,
-batch_size=None,
-stride=32,
-pad=0.5,
-single_cls=False,
-):
+def __init__(self,
+img_path,
+imgsz=640,
+cache=False,
+augment=True,
+hyp=None,
+prefix='',
+rect=False,
+batch_size=None,
+stride=32,
+pad=0.5,
+single_cls=False):
 super().__init__()
 self.img_path = img_path
 self.imgsz = imgsz
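
The BaseDataset hunk above only reflows the constructor signature; no defaults change. For context on what the rect, stride and pad arguments typically feed into, here is a hedged sketch of the rectangular-batching idea (rect_batch_shape is a hypothetical helper, not the class's actual code):

import numpy as np

def rect_batch_shape(aspect_ratios, imgsz=640, stride=32, pad=0.5):
    """Pick one letterbox (h, w) for a batch of images with similar aspect
    ratios (h/w), rounded up to a multiple of the model stride."""
    ar = np.asarray(aspect_ratios)
    shape = [1, 1]                        # relative (h, w) scale for this batch
    if ar.max() < 1:                      # all wide images: shrink height
        shape = [ar.max(), 1]
    elif ar.min() > 1:                    # all tall images: shrink width
        shape = [1, 1 / ar.min()]
    return np.ceil(np.array(shape) * imgsz / stride + pad).astype(int) * stride

print(rect_batch_shape([0.75, 0.8]))      # -> [544 672]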

View File

@@ -335,8 +335,8 @@ def classify_albumentations(
 if vflip > 0:
 T += [A.VerticalFlip(p=vflip)]
 if jitter > 0:
-color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, satuaration, 0 hue
-T += [A.ColorJitter(*color_jitter, 0)]
+jitter = float(jitter)
+T += [A.ColorJitter(jitter, jitter, jitter, 0)] # brightness, contrast, satuaration, 0 hue
 else: # Use fixed crop for eval set (reproducibility)
 T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
 T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
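
For completeness, a usage sketch of the eval-time branch above (the crop size and the ImageNet mean/std are assumed example values; the dummy image is not from the repository):

import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

size = 224                                       # assumed eval crop size
T = A.Compose([
    A.SmallestMaxSize(max_size=size),
    A.CenterCrop(height=size, width=size),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # ImageNet stats
    ToTensorV2()])                               # numpy HWC -> torch CHW

img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # dummy image
tensor = T(image=img)['image']
print(tensor.shape)                              # torch.Size([3, 224, 224])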

View File

@@ -4,13 +4,16 @@ from itertools import repeat
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
+import cv2
+import numpy as np
+import torch
 import torchvision
 from tqdm import tqdm
 from ..utils import NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
-from .augment import *
+from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
 from .base import BaseDataset
-from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
+from .utils import HELP_URL, LOCAL_RANK, LOGGER, get_hash, img2label_paths, verify_image_label
 class YOLODataset(BaseDataset):