ultralytics 8.0.94
HUBDatasetStats() Segment and Pose support (#2450)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: JF Chen <k-2feng@hotmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
@@ -24,7 +24,7 @@ TASK2MODEL = {
     'detect': 'yolov8n.pt',
     'segment': 'yolov8n-seg.pt',
     'classify': 'yolov8n-cls.pt',
-    'pose': 'yolov8n-pose.yaml'}
+    'pose': 'yolov8n-pose.pt'}
 
 CLI_HELP_MSG = \
     f"""
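Note: the hunk above swaps the default pose model from a YAML config ('yolov8n-pose.yaml') to pretrained weights ('yolov8n-pose.pt'), matching the other tasks. A minimal sketch of how a task-to-model map like this is typically consumed; default_model is a hypothetical helper for illustration, not part of this commit:

TASK2MODEL = {
    'detect': 'yolov8n.pt',
    'segment': 'yolov8n-seg.pt',
    'classify': 'yolov8n-cls.pt',
    'pose': 'yolov8n-pose.pt'}


def default_model(task):
    """Hypothetical helper: resolve the default weights for a task."""
    return TASK2MODEL.get(task, 'yolov8n.pt')  # fall back to detect weights


print(default_model('pose'))  # yolov8n-pose.pt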
@@ -15,7 +15,7 @@ import psutil
 from torch.utils.data import Dataset
 from tqdm import tqdm
 
-from ..utils import LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT
+from ..utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT
 from .utils import HELP_URL, IMG_FORMATS
 
 
@@ -51,7 +51,7 @@ class BaseDataset(Dataset):
                  imgsz=640,
                  cache=False,
                  augment=True,
-                 hyp=None,
+                 hyp=DEFAULT_CFG,
                  prefix='',
                  rect=False,
                  batch_size=None,
@@ -71,7 +71,7 @@ def seed_worker(worker_id):  # noqa
 
 def build_yolo_dataset(cfg, img_path, batch, data_info, mode='train', rect=False, stride=32):
     """Build YOLO Dataset"""
-    dataset = YOLODataset(
+    return YOLODataset(
         img_path=img_path,
         imgsz=cfg.imgsz,
         batch_size=batch,
@@ -87,7 +87,6 @@ def build_yolo_dataset(cfg, img_path, batch, data_info, mode='train', rect=False
         use_keypoints=cfg.task == 'pose',
         classes=cfg.classes,
         data=data_info)
-    return dataset
 
 
 def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
@@ -209,7 +209,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         album_transform: Albumentations transforms, used if installed
     """
 
-    def __init__(self, root, augment, imgsz, cache=False):
+    def __init__(self, root, augment=False, imgsz=224, cache=False):
         """Initialize YOLO object with root, image size, augmentations, and cache settings"""
         super().__init__(root=root)
         self.torch_transforms = classify_transforms(imgsz)
@@ -310,17 +310,19 @@ class HUBDatasetStats():
 
     Arguments
         path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
+        task: Dataset task. Options are 'detect', 'segment', 'pose', 'classify'.
         autodownload: Attempt to download dataset if not found locally
 
     Usage
         from ultralytics.yolo.data.utils import HUBDatasetStats
-        stats = HUBDatasetStats('coco128.yaml', autodownload=True)  # usage 1
-        stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco6.zip')  # usage 2
+        stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8.zip', task='detect')  # detect dataset
+        stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-seg.zip', task='segment')  # segment dataset
+        stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-pose.zip', task='pose')  # pose dataset
         stats.get_json(save=False)
         stats.process_images()
     """
 
-    def __init__(self, path='coco128.yaml', autodownload=False):
+    def __init__(self, path='coco128.yaml', task='detect', autodownload=False):
         """Initialize class."""
         zipped, data_dir, yaml_path = self._unzip(Path(path))
         try:
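Note: with the new task argument, HUBDatasetStats selects the matching label decoder per dataset type. A hedged end-to-end sketch (the archive name is illustrative; outputs land under the class's '<dataset>-hub' directory convention):

from ultralytics.yolo.data.utils import HUBDatasetStats

stats = HUBDatasetStats('coco8-pose.zip', task='pose')  # task must match the dataset labels
stats.get_json(save=True)  # dataset statistics, optionally saved as stats.json
stats.process_images()     # compressed image previews for HUB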
@@ -336,6 +338,7 @@ class HUBDatasetStats():
         self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images
         self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())}  # statistics dictionary
         self.data = data
+        self.task = task  # detect, segment, pose, classify
 
     @staticmethod
     def _find_yaml(dir):
@@ -352,11 +355,10 @@ class HUBDatasetStats():
         """Unzip data.zip."""
         if not str(path).endswith('.zip'):  # path is data.yaml
             return False, None, path
-        assert Path(path).is_file(), f'Error unzipping {path}, file not found'
-        unzip_file(path, path=path.parent)
-        dir = path.with_suffix('')  # dataset directory == zip name
-        assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
-        return True, str(dir), self._find_yaml(dir)  # zipped, data_dir, yaml_path
+        unzip_dir = unzip_file(path, path=path.parent)
+        assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \
+                                   f'path/to/abc.zip MUST unzip to path/to/abc/'
+        return True, str(unzip_dir), self._find_yaml(unzip_dir)  # zipped, data_dir, yaml_path
 
     def _hub_ops(self, f):
         """Saves a compressed image for HUB previews."""
@@ -364,20 +366,33 @@ class HUBDatasetStats():
 
     def get_json(self, save=False, verbose=False):
         """Return dataset JSON for Ultralytics HUB."""
-        # from ultralytics.yolo.data import YOLODataset
-        from ultralytics.yolo.data.dataloaders.v5loader import LoadImagesAndLabels
+        from ultralytics.yolo.data import YOLODataset  # ClassificationDataset
 
         def _round(labels):
-            """Update labels to integer class and 6 decimal place floats."""
-            return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
+            """Update labels to integer class and 4 decimal place floats."""
+            if self.task == 'detect':
+                coordinates = labels['bboxes']
+            elif self.task == 'segment':
+                coordinates = [x.flatten() for x in labels['segments']]
+            elif self.task == 'pose':
+                n = labels['keypoints'].shape[0]
+                coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1)
+            else:
+                raise ValueError('Undefined dataset task.')
+            zipped = zip(labels['cls'], coordinates)
+            return [[int(c), *(round(float(x), 4) for x in points)] for c, points in zipped]
 
         for split in 'train', 'val', 'test':
             if self.data.get(split) is None:
                 self.stats[split] = None  # i.e. no test set
                 continue
-            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
+
+            dataset = YOLODataset(img_path=self.data[split],
+                                  data=self.data,
+                                  use_segments=self.task == 'segment',
+                                  use_keypoints=self.task == 'pose')
             x = np.array([
-                np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
+                np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc'])
                 for label in tqdm(dataset.labels, total=len(dataset), desc='Statistics')])  # shape(128x80)
             self.stats[split] = {
                 'instance_stats': {
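Note: _round now flattens each per-task label dict into [cls, coords...] rows: bboxes for detect, flattened polygons for segment, and bbox plus keypoints for pose. A standalone toy example mirroring the pose branch above (numbers are made up):

import numpy as np

labels = {
    'cls': np.array([0]),  # one instance of class 0
    'bboxes': np.array([[0.5, 0.5, 0.2, 0.3]]),  # xywh, normalized
    'keypoints': np.array([[[0.4, 0.4, 2.0], [0.6, 0.6, 2.0]]])}  # 2 kpts, (x, y, visibility)

n = labels['keypoints'].shape[0]
coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1)
rows = [[int(c), *(round(float(x), 4) for x in points)]
        for c, points in zip(labels['cls'], coordinates)]
print(rows)  # [[0, 0.5, 0.5, 0.2, 0.3, 0.4, 0.4, 2.0, 0.6, 0.6, 2.0]]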
@@ -388,7 +403,7 @@ class HUBDatasetStats():
                 'unlabelled': int(np.all(x == 0, 1).sum()),
                 'per_class': (x > 0).sum(0).tolist()},
             'labels': [{
-                str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
+                Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]}
 
         # Save, print and return
         if save:
@@ -402,13 +417,12 @@ class HUBDatasetStats():
 
     def process_images(self):
         """Compress images for Ultralytics HUB."""
-        # from ultralytics.yolo.data import YOLODataset
-        from ultralytics.yolo.data.dataloaders.v5loader import LoadImagesAndLabels
+        from ultralytics.yolo.data import YOLODataset  # ClassificationDataset
 
         for split in 'train', 'val', 'test':
             if self.data.get(split) is None:
                 continue
-            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
+            dataset = YOLODataset(img_path=self.data[split], data=self.data)
             with ThreadPool(NUM_THREADS) as pool:
                 for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
                     pass
@@ -37,26 +37,39 @@ def is_url(url, check=True):
 
 def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
     """
-    Unzip a *.zip file to path/, excluding files containing strings in exclude list
-    Replaces: ZipFile(file).extractall(path=path)
+    Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.
+
+    If the zipfile does not contain a single top-level directory, the function will create a new
+    directory with the same name as the zipfile (without the extension) to extract its contents.
+    If a path is not provided, the function will use the parent directory of the zipfile as the default path.
+
+    Args:
+        file (str): The path to the zipfile to be extracted.
+        path (str, optional): The path to extract the zipfile to. Defaults to None.
+        exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
+
+    Raises:
+        BadZipFile: If the provided file does not exist or is not a valid zipfile.
+
+    Returns:
+        (Path): The path to the directory where the zipfile was extracted.
     """
+    if not (Path(file).exists() and is_zipfile(file)):
+        raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.")
     if path is None:
         path = Path(file).parent  # default path
 
     with ZipFile(file) as zipObj:
-        for i, f in enumerate(zipObj.namelist()):  # list all archived filenames in the zip
-            # If zip does not expand into a directory create a new directory to expand into
-            if i == 0:
-                info = zipObj.getinfo(f)
-                if info.file_size > 0 or not info.filename.endswith('/'):  # element is a file and not a directory
-                    path = Path(path) / Path(file).stem  # define new unzip directory
-                    unzip_dir = path
-                else:
-                    unzip_dir = f
-            if all(x not in f for x in exclude):
-                zipObj.extract(f, path=path)
-        return unzip_dir  # return unzip dir
+        file_list = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
+        top_level_dirs = {Path(f).parts[0] for f in file_list}
+
+        if len(top_level_dirs) > 1 or not file_list[0].endswith('/'):
+            path = Path(path) / Path(file).stem  # define new unzip directory
+
+        for f in file_list:
+            zipObj.extract(f, path=path)
+
+        return path  # return unzip dir
 
 
 def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, hard=True):
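Note: unzip_file now raises BadZipFile on a missing or corrupt archive, only creates a '<zipname>/' directory when the archive lacks a single top-level folder, and returns the extraction directory. A brief usage sketch (archive name illustrative):

from ultralytics.yolo.utils.downloads import unzip_file

unzip_dir = unzip_file('coco8.zip')  # path defaults to the zip's parent directory
print(unzip_dir)  # the directory the contents were extracted into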
@@ -318,7 +318,7 @@ class ConfusionMatrix:
         nc, nn = self.nc, len(names)  # number of classes, names
         sn.set(font_scale=1.0 if nc < 50 else 0.8)  # for label size
         labels = (0 < nn < 99) and (nn == nc)  # apply names to ticklabels
-        ticklabels = (names + ['background']) if labels else 'auto'
+        ticklabels = (list(names) + ['background']) if labels else 'auto'
         with warnings.catch_warnings():
             warnings.simplefilter('ignore')  # suppress empty matrix RuntimeWarning: All-NaN slice encountered
             sn.heatmap(array,
@@ -332,10 +332,11 @@ class ConfusionMatrix:
                        vmin=0.0,
                        xticklabels=ticklabels,
                        yticklabels=ticklabels).set_facecolor((1, 1, 1))
+        title = 'Confusion Matrix' + ' Normalized' * normalize
         ax.set_xlabel('True')
         ax.set_ylabel('Predicted')
-        ax.set_title('Confusion Matrix')
-        fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
+        ax.set_title(title)
+        fig.savefig(Path(save_dir) / f'{title.lower().replace(" ", "_")}.png', dpi=250)
         plt.close(fig)
 
     def print(self):
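Note: deriving the title from the normalize flag means the paired plot() calls added in the validator hunks below write two distinct files instead of overwriting one. A quick check of the resulting file names:

for normalize in True, False:
    title = 'Confusion Matrix' + ' Normalized' * normalize
    print(f'{title.lower().replace(" ", "_")}.png')
# confusion_matrix_normalized.png
# confusion_matrix.png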
@@ -38,5 +38,5 @@ default_space = {
 task_metric_map = {
     'detect': 'metrics/mAP50-95(B)',
     'segment': 'metrics/mAP50-95(M)',
-    'classify': 'top1_acc',
-    'pose': None}
+    'classify': 'metrics/accuracy_top1',
+    'pose': 'metrics/mAP50-95(P)'}
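Note: the old entries ('top1_acc', None) did not match the metric keys the validators actually report, so tuning classify and pose models presumably had no objective to optimize. A trivial lookup sketch:

task_metric_map = {
    'detect': 'metrics/mAP50-95(B)',
    'segment': 'metrics/mAP50-95(M)',
    'classify': 'metrics/accuracy_top1',
    'pose': 'metrics/mAP50-95(P)'}

print(task_metric_map['pose'])  # metric to maximize when tuning a pose model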
@@ -72,9 +72,8 @@ class ClassificationTrainer(BaseTrainer):
 
         return  # dont return ckpt. Classification doesn't support resume
 
-    def build_dataset(self, img_path, mode='train'):
-        dataset = ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=mode == 'train')
-        return dataset
+    def build_dataset(self, img_path, mode='train', batch=None):
+        return ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=mode == 'train')
 
     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
         """Returns PyTorch DataLoader with transforms to preprocess images for inference."""
@@ -46,7 +46,8 @@ class ClassificationValidator(BaseValidator):
         """Finalizes metrics of the model such as confusion_matrix and speed."""
         self.confusion_matrix.process_cls_preds(self.pred, self.targets)
         if self.args.plots:
-            self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
+            for normalize in True, False:
+                self.confusion_matrix.plot(save_dir=self.save_dir, names=self.names.values(), normalize=normalize)
         self.metrics.speed = self.speed
         self.metrics.confusion_matrix = self.confusion_matrix
 
@@ -32,7 +32,7 @@ class DetectionTrainer(BaseTrainer):
         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs)
 
-    def get_dataloader(self, dataset_path, batch_size, rank=0, mode='train'):
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
         """TODO: manage splits differently."""
         # Calculate stride - check if model is initialized
         if self.args.v5loader:
@@ -62,8 +62,7 @@ class DetectionTrainer(BaseTrainer):
             LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
             shuffle = False
         workers = self.args.workers if mode == 'train' else self.args.workers * 2
-        dataloader = build_dataloader(dataset, batch_size, workers, shuffle, rank)
-        return dataloader
+        return build_dataloader(dataset, batch_size, workers, shuffle, rank)  # return dataloader
 
     def preprocess_batch(self, batch):
         """Preprocesses a batch of images by scaling and converting to float."""
@@ -144,7 +144,8 @@ class DetectionValidator(BaseValidator):
                 LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
 
         if self.args.plots:
-            self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
+            for normalize in True, False:
+                self.confusion_matrix.plot(save_dir=self.save_dir, names=self.names.values(), normalize=normalize)
 
     def _process_batch(self, detections, labels):
         """