ultralytics 8.0.50 AMP check and YOLOv5u YAMLs (#1263)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Troy <wudashuo@vip.qq.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Huijae Lee <46982469+ZeroAct@users.noreply.github.com>
Authored by Glenn Jocher on 2023-03-06 11:39:26 +01:00, committed by GitHub
parent 3861e6c82a
commit f0d8e4718b
29 changed files with 440 additions and 83 deletions

View File

@@ -35,7 +35,8 @@ class BaseDataset(Dataset):
                  batch_size=None,
                  stride=32,
                  pad=0.5,
-                 single_cls=False):
+                 single_cls=False,
+                 classes=None):
         super().__init__()
         self.img_path = img_path
         self.imgsz = imgsz
@@ -45,8 +46,7 @@ class BaseDataset(Dataset):
         self.im_files = self.get_img_files(self.img_path)
         self.labels = self.get_labels()
-        if self.single_cls:
-            self.update_labels(include_class=[])
+        self.update_labels(include_class=classes)  # single_cls and include_class
         self.ni = len(self.labels)
@@ -96,7 +96,7 @@ class BaseDataset(Dataset):
         """include_class, filter labels to include only these classes (optional)"""
         include_class_array = np.array(include_class).reshape(1, -1)
         for i in range(len(self.labels)):
-            if include_class:
+            if include_class is not None:
                 cls = self.labels[i]['cls']
                 bboxes = self.labels[i]['bboxes']
                 segments = self.labels[i]['segments']
@@ -104,7 +104,7 @@ class BaseDataset(Dataset):
                 self.labels[i]['cls'] = cls[j]
                 self.labels[i]['bboxes'] = bboxes[j]
                 if segments:
-                    self.labels[i]['segments'] = segments[j]
+                    self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
             if self.single_cls:
                 self.labels[i]['cls'][:, 0] = 0
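
Note on the segments line above: cls and bboxes are numpy arrays, so the boolean mask j indexes them directly, but segments is a plain Python list holding one polygon array per instance, which is why the filtering needs an explicit list comprehension. A minimal, self-contained sketch of the pattern, with invented values purely for illustration:

import numpy as np

# Illustrative stand-ins for self.labels[i]['cls' / 'bboxes' / 'segments'] in BaseDataset
include_class = [0, 2]                                  # class ids to keep
cls = np.array([[0], [1], [2]])                         # (n, 1) per-instance class ids
bboxes = np.zeros((3, 4))                               # (n, 4) per-instance boxes
segments = [np.zeros((5, 2)) for _ in range(3)]         # list of (k, 2) polygons, one per instance

include_class_array = np.array(include_class).reshape(1, -1)
j = (cls == include_class_array).any(1)                 # boolean mask over instances -> [True, False, True]

cls, bboxes = cls[j], bboxes[j]                         # numpy arrays accept the mask directly
segments = [segments[si] for si, idx in enumerate(j) if idx]  # a list needs the comprehension
print(cls.ravel().tolist(), len(segments))              # [0, 2] 2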

View File

@@ -10,7 +10,7 @@ from PIL import Image
 from torch.utils.data import DataLoader, dataloader, distributed

 from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots,
-                                                              LoadStreams, SourceTypes, autocast_list)
+                                                              LoadStreams, LoadTensor, SourceTypes, autocast_list)
 from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
 from ultralytics.yolo.utils.checks import check_file
@@ -82,7 +82,8 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
                           prefix=colorstr(f'{mode}: '),
                           use_segments=cfg.task == 'segment',
                           use_keypoints=cfg.task == 'keypoint',
-                          names=names)
+                          names=names,
+                          classes=cfg.classes)

     batch = min(batch, len(dataset))
     nd = torch.cuda.device_count()  # number of CUDA devices
@@ -133,7 +134,7 @@ def build_classification_dataloader(path,
 def check_source(source):
-    webcam, screenshot, from_img, in_memory = False, False, False, False
+    webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
     if isinstance(source, (str, int, Path)):  # int for local usb camera
         source = str(source)
         is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
@@ -149,22 +150,25 @@ def check_source(source):
         from_img = True
     elif isinstance(source, (Image.Image, np.ndarray)):
         from_img = True
+    elif isinstance(source, torch.Tensor):
+        tensor = True
     else:
         raise TypeError('Unsupported image type. See docs for supported types https://docs.ultralytics.com/predict')

-    return source, webcam, screenshot, from_img, in_memory
+    return source, webcam, screenshot, from_img, in_memory, tensor


 def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True):
     """
     TODO: docs
     """
     # source
-    source, webcam, screenshot, from_img, in_memory = check_source(source)
-    source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img)
+    source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
+    source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)

     # Dataloader
-    if in_memory:
+    if tensor:
+        dataset = LoadTensor(source)
+    elif in_memory:
         dataset = source
     elif webcam:
         dataset = LoadStreams(source,
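
Taken together, the check_source() and load_inference_source() changes let a torch.Tensor be passed as an in-memory batch source. A rough usage sketch only, assuming the standard YOLO entry point, a BCHW float batch, and that the predictor's preprocess step accepts tensors directly (none of which is shown in this excerpt):

import torch
from ultralytics import YOLO  # high-level entry point; not part of this diff

model = YOLO('yolov8n.pt')
batch = torch.rand(4, 3, 640, 640)      # assumed BCHW float batch of 4 images
results = model.predict(source=batch)   # check_source() sets tensor=True -> LoadTensor wraps the batch
print(len(results))                     # expected: one Results object per image in the batch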

View File

@@ -26,6 +26,7 @@ class SourceTypes:
     webcam: bool = False
     screenshot: bool = False
     from_img: bool = False
+    tensor: bool = False


 class LoadStreams:
@@ -329,6 +330,23 @@ class LoadPilAndNumpy:
         return self


+class LoadTensor:
+
+    def __init__(self, imgs) -> None:
+        self.im0 = imgs
+        self.bs = imgs.shape[0]
+
+    def __iter__(self):
+        self.count = 0
+        return self
+
+    def __next__(self):
+        if self.count == 1:
+            raise StopIteration
+        self.count += 1
+        return None, self.im0, self.im0, None, ''  # self.paths, im, self.im0, None, ''
+
+
 def autocast_list(source):
     """
     Merges a list of source of different types into a list of numpy arrays or PIL images
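
The new LoadTensor iterator yields the whole tensor batch exactly once (bs is taken from imgs.shape[0]) and then raises StopIteration, matching the 5-tuple contract of the other loaders. A small sketch of that contract in isolation:

import torch
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadTensor  # module path as in this diff

imgs = torch.zeros(2, 3, 640, 640)      # fake 2-image batch
loader = LoadTensor(imgs)
print(loader.bs)                        # 2, taken from imgs.shape[0]
for path, im, im0, cap, s in loader:    # same 5-tuple as the other loaders
    print(path, im.shape)               # None torch.Size([2, 3, 640, 640]); the loop body runs once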

View File

@@ -539,7 +539,7 @@ class LoadImagesAndLabels(Dataset):
                 j = (label[:, 0:1] == include_class_array).any(1)
                 self.labels[i] = label[j]
                 if segment:
-                    self.segments[i] = segment[j]
+                    self.segments[i] = [segment[si] for si, idx in enumerate(j) if idx]
             if single_cls:  # single-class training, merge all classes into 0
                 self.labels[i][:, 0] = 0

View File

@@ -57,12 +57,14 @@ class YOLODataset(BaseDataset):
                  single_cls=False,
                  use_segments=False,
                  use_keypoints=False,
-                 names=None):
+                 names=None,
+                 classes=None):
         self.use_segments = use_segments
         self.use_keypoints = use_keypoints
         self.names = names
         assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
-        super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)
+        super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls,
+                         classes)

     def cache_labels(self, path=Path('./labels.cache')):
         """Cache dataset labels, check images and read shapes.

View File

@@ -16,6 +16,7 @@ import numpy as np
 from PIL import ExifTags, Image, ImageOps
 from tqdm import tqdm

+from ultralytics.nn.autobackend import check_class_names
 from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, colorstr, emojis, yaml_load
 from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
 from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file
@@ -211,8 +212,7 @@ def check_det_dataset(dataset, autodownload=True):
             raise SyntaxError(
                 emojis(f"{dataset} '{k}:' key missing ❌.\n"
                        f"'train', 'val' and 'names' are required in data.yaml files."))
-    if isinstance(data['names'], (list, tuple)):  # old array format
-        data['names'] = dict(enumerate(data['names']))  # convert to dict
+    data['names'] = check_class_names(data['names'])
     data['nc'] = len(data['names'])

     # Resolve paths
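
The inline list-to-dict conversion is replaced by the shared check_class_names() helper. A hedged sketch of roughly what such a normalization does, as an illustration of the idea rather than the actual implementation in ultralytics.nn.autobackend:

def check_class_names_sketch(names):
    # Accept the old array format and normalize to an index -> name dict.
    if isinstance(names, (list, tuple)):                # e.g. ['person', 'bicycle']
        names = dict(enumerate(names))                  # -> {0: 'person', 1: 'bicycle'}
    names = {int(k): str(v) for k, v in names.items()}  # coerce keys/values to canonical types
    if max(names) >= len(names):                        # keys should form a 0..n-1 index range
        raise KeyError(f'class keys {sorted(names)} do not form a 0..{len(names) - 1} range')
    return names

print(check_class_names_sketch(['person', 'bicycle']))  # {0: 'person', 1: 'bicycle'}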