update segment training (#57)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -578,8 +578,8 @@ class Albumentations:
|
||||
# TODO: add supports of segments and keypoints
|
||||
if self.transform and random.random() < self.p:
|
||||
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
|
||||
labels["img"] = new["image"]
|
||||
labels["cls"] = np.array(new["class_labels"])
|
||||
labels["img"] = new["image"]
|
||||
labels["cls"] = np.array(new["class_labels"])
|
||||
labels["instances"].update(bboxes=bboxes)
|
||||
return labels
|
||||
|
||||
@ -635,7 +635,7 @@ class Format:
|
||||
def _format_img(self, img):
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
|
||||
img = torch.from_numpy(img)
|
||||
return img
|
||||
|
||||
|
@ -151,7 +151,7 @@ class BaseDataset(Dataset):
|
||||
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
|
||||
nb = bi[-1] + 1 # number of batches
|
||||
|
||||
s = np.array([x["shape"] for x in self.labels]) # hw
|
||||
s = np.array([x.pop("shape") for x in self.labels]) # hw
|
||||
ar = s[:, 0] / s[:, 1] # aspect ratio
|
||||
irect = ar.argsort()
|
||||
self.im_files = [self.im_files[i] for i in irect]
|
||||
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, dataloader, distributed
|
||||
|
||||
from ..utils import LOGGER
|
||||
from ..utils import LOGGER, colorstr
|
||||
from ..utils.torch_utils import torch_distributed_zero_first
|
||||
from .dataset import ClassificationDataset, YOLODataset
|
||||
from .utils import PIN_MEMORY, RANK
|
||||
@ -52,53 +52,36 @@ def seed_worker(worker_id):
|
||||
random.seed(worker_seed)
|
||||
|
||||
|
||||
# TODO: we can inject most args from a config file
|
||||
def build_dataloader(
|
||||
img_path,
|
||||
img_size, #
|
||||
batch_size, #
|
||||
single_cls=False, #
|
||||
hyp=None, #
|
||||
augment=False,
|
||||
cache=False, #
|
||||
image_weights=False, #
|
||||
stride=32,
|
||||
label_path=None,
|
||||
pad=0.0,
|
||||
rect=False,
|
||||
rank=-1,
|
||||
workers=8,
|
||||
prefix="",
|
||||
shuffle=False,
|
||||
use_segments=False,
|
||||
use_keypoints=False,
|
||||
):
|
||||
if rect and shuffle:
|
||||
def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank=-1, mode="train"):
|
||||
assert mode in ["train", "val"]
|
||||
shuffle = mode == "train"
|
||||
if cfg.rect and shuffle:
|
||||
LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||
shuffle = False
|
||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = YOLODataset(
|
||||
img_path=img_path,
|
||||
img_size=img_size,
|
||||
batch_size=batch_size,
|
||||
label_path=label_path,
|
||||
augment=augment, # augmentation
|
||||
hyp=hyp,
|
||||
rect=rect, # rectangular batches
|
||||
cache=cache,
|
||||
single_cls=single_cls,
|
||||
img_size=cfg.img_size,
|
||||
batch_size=batch_size,
|
||||
augment=True if mode == "train" else False, # augmentation
|
||||
hyp=cfg.get("augment_hyp", None),
|
||||
rect=cfg.rect if mode == "train" else True, # rectangular batches
|
||||
cache=None if cfg.noval else cfg.get("cache", None),
|
||||
single_cls=cfg.get("single_cls", False),
|
||||
stride=int(stride),
|
||||
pad=pad,
|
||||
prefix=prefix,
|
||||
use_segments=use_segments,
|
||||
use_keypoints=use_keypoints,
|
||||
pad=0.0 if mode == "train" else 0.5,
|
||||
prefix=colorstr(f"{mode}: "),
|
||||
use_segments=cfg.task == "segment",
|
||||
use_keypoints=cfg.task == "keypoint",
|
||||
)
|
||||
|
||||
batch_size = min(batch_size, len(dataset))
|
||||
nd = torch.cuda.device_count() # number of CUDA devices
|
||||
workers = cfg.workers if mode == "train" else cfg.workers * 2
|
||||
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||
loader = DataLoader if cfg.image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(6148914691236517205 + RANK)
|
||||
return (
|
||||
@ -118,6 +101,7 @@ def build_dataloader(
|
||||
|
||||
|
||||
# build classification
|
||||
# TODO: using cfg like `build_dataloader`
|
||||
def build_classification_dataloader(path,
|
||||
imgsz=224,
|
||||
batch_size=16,
|
||||
|
Reference in New Issue
Block a user