update segment training (#57)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Laughing
2022-11-29 05:30:08 -06:00
committed by GitHub
parent d0b0fe2592
commit 3a241e4cea
14 changed files with 460 additions and 144 deletions

View File

@ -578,8 +578,8 @@ class Albumentations:
# TODO: add supports of segments and keypoints
if self.transform and random.random() < self.p:
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
labels["img"] = new["image"]
labels["cls"] = np.array(new["class_labels"])
labels["img"] = new["image"]
labels["cls"] = np.array(new["class_labels"])
labels["instances"].update(bboxes=bboxes)
return labels
@ -635,7 +635,7 @@ class Format:
def _format_img(self, img):
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
img = torch.from_numpy(img)
return img

View File

@ -151,7 +151,7 @@ class BaseDataset(Dataset):
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
nb = bi[-1] + 1 # number of batches
s = np.array([x["shape"] for x in self.labels]) # hw
s = np.array([x.pop("shape") for x in self.labels]) # hw
ar = s[:, 0] / s[:, 1] # aspect ratio
irect = ar.argsort()
self.im_files = [self.im_files[i] for i in irect]

View File

@ -5,7 +5,7 @@ import numpy as np
import torch
from torch.utils.data import DataLoader, dataloader, distributed
from ..utils import LOGGER
from ..utils import LOGGER, colorstr
from ..utils.torch_utils import torch_distributed_zero_first
from .dataset import ClassificationDataset, YOLODataset
from .utils import PIN_MEMORY, RANK
@ -52,53 +52,36 @@ def seed_worker(worker_id):
random.seed(worker_seed)
# TODO: we can inject most args from a config file
def build_dataloader(
img_path,
img_size, #
batch_size, #
single_cls=False, #
hyp=None, #
augment=False,
cache=False, #
image_weights=False, #
stride=32,
label_path=None,
pad=0.0,
rect=False,
rank=-1,
workers=8,
prefix="",
shuffle=False,
use_segments=False,
use_keypoints=False,
):
if rect and shuffle:
def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank=-1, mode="train"):
assert mode in ["train", "val"]
shuffle = mode == "train"
if cfg.rect and shuffle:
LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = YOLODataset(
img_path=img_path,
img_size=img_size,
batch_size=batch_size,
label_path=label_path,
augment=augment, # augmentation
hyp=hyp,
rect=rect, # rectangular batches
cache=cache,
single_cls=single_cls,
img_size=cfg.img_size,
batch_size=batch_size,
augment=True if mode == "train" else False, # augmentation
hyp=cfg.get("augment_hyp", None),
rect=cfg.rect if mode == "train" else True, # rectangular batches
cache=None if cfg.noval else cfg.get("cache", None),
single_cls=cfg.get("single_cls", False),
stride=int(stride),
pad=pad,
prefix=prefix,
use_segments=use_segments,
use_keypoints=use_keypoints,
pad=0.0 if mode == "train" else 0.5,
prefix=colorstr(f"{mode}: "),
use_segments=cfg.task == "segment",
use_keypoints=cfg.task == "keypoint",
)
batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
workers = cfg.workers if mode == "train" else cfg.workers * 2
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
loader = DataLoader if cfg.image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
return (
@ -118,6 +101,7 @@ def build_dataloader(
# build classification
# TODO: using cfg like `build_dataloader`
def build_classification_dataloader(path,
imgsz=224,
batch_size=16,