General console printout updates (#48)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 8530e3fae0
commit 27d6545117
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -665,7 +665,7 @@ def mosaic_transforms(img_size, hyp):
perspective=hyp.perspective, perspective=hyp.perspective,
border=[-img_size // 2, -img_size // 2], border=[-img_size // 2, -img_size // 2],
),]) ),])
transforms = Compose([ return Compose([
pre_transform, pre_transform,
MixUp( MixUp(
pre_transform=pre_transform, pre_transform=pre_transform,
@ -674,13 +674,11 @@ def mosaic_transforms(img_size, hyp):
Albumentations(p=1.0), Albumentations(p=1.0),
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v), RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
RandomFlip(direction="vertical", p=hyp.flipud), RandomFlip(direction="vertical", p=hyp.flipud),
RandomFlip(direction="horizontal", p=hyp.fliplr),]) RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
return transforms
def affine_transforms(img_size, hyp): def affine_transforms(img_size, hyp):
# rect, randomperspective, albumentation, hsv, flipud, fliplr return Compose([
transforms = Compose([
LetterBox(new_shape=(img_size, img_size)), LetterBox(new_shape=(img_size, img_size)),
RandomPerspective( RandomPerspective(
degrees=hyp.degrees, degrees=hyp.degrees,
@ -693,11 +691,10 @@ def affine_transforms(img_size, hyp):
Albumentations(p=1.0), Albumentations(p=1.0),
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v), RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
RandomFlip(direction="vertical", p=hyp.flipud), RandomFlip(direction="vertical", p=hyp.flipud),
RandomFlip(direction="horizontal", p=hyp.fliplr),]) RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
return transforms
# Classification augmentations ------------------------------------------------------------------------------------------- # Classification augmentations -----------------------------------------------------------------------------------------
def classify_transforms(size=224): def classify_transforms(size=224):
# Transforms to apply if albumentations not installed # Transforms to apply if albumentations not installed
assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)" assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"

@ -9,8 +9,8 @@ import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
from tqdm import tqdm from tqdm import tqdm
from ..utils import NUM_THREADS from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
from .utils import BAR_FORMAT, HELP_URL, IMG_FORMATS, LOCAL_RANK from .utils import HELP_URL, IMG_FORMATS, LOCAL_RANK
class BaseDataset(Dataset): class BaseDataset(Dataset):
@ -18,7 +18,7 @@ class BaseDataset(Dataset):
Args: Args:
img_path (str): image path. img_path (str): image path.
pipeline (dict): a dict of image transforms. pipeline (dict): a dict of image transforms.
label_path (str): label path, this can also be a ann_file or other custom label path. label_path (str): label path, this can also be an ann_file or other custom label path.
""" """
def __init__( def __init__(
@ -131,7 +131,7 @@ class BaseDataset(Dataset):
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
fcn = self.cache_images_to_disk if self.cache == "disk" else self.load_image fcn = self.cache_images_to_disk if self.cache == "disk" else self.load_image
results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni)) results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni))
pbar = tqdm(enumerate(results), total=self.ni, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0) pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar: for i, x in pbar:
if self.cache == "disk": if self.cache == "disk":
gb += self.npy_files[i].stat().st_size gb += self.npy_files[i].stat().st_size

@ -6,10 +6,10 @@ from typing import OrderedDict
import torchvision import torchvision
from tqdm import tqdm from tqdm import tqdm
from ..utils import NUM_THREADS from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
from .augment import * from .augment import *
from .base import BaseDataset from .base import BaseDataset
from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
class YOLODataset(BaseDataset): class YOLODataset(BaseDataset):
@ -40,7 +40,7 @@ class YOLODataset(BaseDataset):
): ):
self.use_segments = use_segments self.use_segments = use_segments
self.use_keypoints = use_keypoints self.use_keypoints = use_keypoints
assert not (self.use_segments and self.use_keypoints), "We can't use both of segmentation and pose." assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
super().__init__(img_path, img_size, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad, super().__init__(img_path, img_size, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
single_cls) single_cls)
@ -48,14 +48,14 @@ class YOLODataset(BaseDataset):
# Cache dataset labels, check images and read shapes # Cache dataset labels, check images and read shapes
x = {"labels": []} x = {"labels": []}
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{self.prefix}Scanning '{path.parent / path.stem}' images and labels..." desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
with Pool(NUM_THREADS) as pool: with Pool(NUM_THREADS) as pool:
pbar = tqdm( pbar = tqdm(
pool.imap(verify_image_label, pool.imap(verify_image_label,
zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))), zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
desc=desc, desc=desc,
total=len(self.im_files), total=len(self.im_files),
bar_format=BAR_FORMAT, bar_format=TQDM_BAR_FORMAT,
) )
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar: for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f nm += nm_f
@ -76,7 +76,7 @@ class YOLODataset(BaseDataset):
)) ))
if msg: if msg:
msgs.append(msg) msgs.append(msg)
pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt" pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
pbar.close() pbar.close()
if msgs: if msgs:
@ -109,8 +109,8 @@ class YOLODataset(BaseDataset):
# Display cache # Display cache
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
if exists and LOCAL_RANK in {-1, 0}: if exists and LOCAL_RANK in {-1, 0}:
d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt" d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
if cache["msgs"]: if cache["msgs"]:
LOGGER.info("\n".join(cache["msgs"])) # display warnings LOGGER.info("\n".join(cache["msgs"])) # display warnings
assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}" assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"

@ -22,7 +22,6 @@ from ..utils.ops import segments2boxes
HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data" HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv" # include video suffixes VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv" # include video suffixes
BAR_FORMAT = "{l_bar}{bar:10}{r_bar}{bar:-10b}" # tqdm bar format
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1)) RANK = int(os.getenv('RANK', -1))
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders

@ -25,8 +25,8 @@ from tqdm import tqdm
import ultralytics.yolo.utils as utils import ultralytics.yolo.utils as utils
import ultralytics.yolo.utils.loggers as loggers import ultralytics.yolo.utils.loggers as loggers
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import LOGGER, ROOT from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
from ultralytics.yolo.utils.checks import check_file, check_yaml from ultralytics.yolo.utils.checks import print_args
from ultralytics.yolo.utils.files import increment_path, save_yaml from ultralytics.yolo.utils.files import increment_path, save_yaml
from ultralytics.yolo.utils.modeling import get_model from ultralytics.yolo.utils.modeling import get_model
@ -41,19 +41,17 @@ class BaseTrainer:
self.validator = None self.validator = None
self.model = None self.model = None
self.callbacks = defaultdict(list) self.callbacks = defaultdict(list)
self.console.info(f"Training config: \n args: \n {self.args}") # to debug
# Directories
self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok) self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
self.wdir = self.save_dir / 'weights' self.wdir = self.save_dir / 'weights' # weights dir
self.wdir.mkdir(parents=True, exist_ok=True) # make dir self.wdir.mkdir(parents=True, exist_ok=True) # make dir
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
print_args(dict(self.args))
# Save run settings # Save run settings
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
# device # device
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size) self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size)
self.console.info(f"running on device {self.device}")
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu') self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
# Model and Dataloaders. # Model and Dataloaders.
@ -64,7 +62,7 @@ class BaseTrainer:
self.data = check_dataset(self.data) self.data = check_dataset(self.data)
self.trainset, self.testset = self.get_dataset(self.data) self.trainset, self.testset = self.get_dataset(self.data)
if self.args.model: if self.args.model:
self.model = self.get_model(self.args.model, self.data) self.model = self.get_model(self.args.model)
# epoch level metrics # epoch level metrics
self.metrics = {} # handle metrics returned by validator self.metrics = {} # handle metrics returned by validator
@ -115,7 +113,7 @@ class BaseTrainer:
if world_size > 1: if world_size > 1:
mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True) mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
else: else:
self._do_train(-1, 1) self._do_train()
def _setup_ddp(self, rank, world_size): def _setup_ddp(self, rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_ADDR'] = 'localhost'
@ -147,7 +145,7 @@ class BaseTrainer:
print("created testloader :", rank) print("created testloader :", rank)
self.console.info(self.progress_string()) self.console.info(self.progress_string())
def _do_train(self, rank, world_size): def _do_train(self, rank=-1, world_size=1):
if world_size > 1: if world_size > 1:
self._setup_ddp(rank, world_size) self._setup_ddp(rank, world_size)
else: else:
@ -165,9 +163,7 @@ class BaseTrainer:
self.model.train() self.model.train()
pbar = enumerate(self.train_loader) pbar = enumerate(self.train_loader)
if rank in {-1, 0}: if rank in {-1, 0}:
pbar = tqdm(enumerate(self.train_loader), pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT)
total=len(self.train_loader),
bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
tloss = None tloss = None
for i, batch in pbar: for i, batch in pbar:
# img, label (classification)/ img, targets, paths, _, masks(detection) # img, label (classification)/ img, targets, paths, _, masks(detection)
@ -249,18 +245,14 @@ class BaseTrainer:
""" """
return data["train"], data["val"] return data["train"], data["val"]
def get_model(self, model: str, data: Dict): def get_model(self, model: Union[str, Path]):
""" """
load/create/download model for any task load/create/download model for any task
""" """
pretrained = False pretrained = not str(model).endswith(".yaml")
if not str(model).endswith(".yaml"): return self.load_model(model_cfg=None if pretrained else model,
pretrained = True weights=get_model(model) if pretrained else None,
weights = get_model(model) # rename this to something less confusing? data=self.data) # model
model = self.load_model(model_cfg=model if not pretrained else None,
weights=weights if pretrained else None,
data=self.data)
return model
def load_model(self, model_cfg, weights, data): def load_model(self, model_cfg, weights, data):
raise NotImplementedError("This task trainer doesn't support loading cfg files") raise NotImplementedError("This task trainer doesn't support loading cfg files")

@ -5,6 +5,7 @@ from omegaconf import OmegaConf
from tqdm import tqdm from tqdm import tqdm
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils import TQDM_BAR_FORMAT
from ultralytics.yolo.utils.ops import Profile from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device from ultralytics.yolo.utils.torch_utils import de_parallel, select_device
@ -49,7 +50,7 @@ class BaseValidator:
loss = 0 loss = 0
n_batches = len(self.dataloader) n_batches = len(self.dataloader)
desc = self.get_desc() desc = self.get_desc()
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format=TQDM_BAR_FORMAT)
self.init_metrics(de_parallel(model)) self.init_metrics(de_parallel(model))
with torch.no_grad(): with torch.no_grad():
for batch_i, batch in enumerate(bar): for batch_i, batch in enumerate(bar):

@ -14,6 +14,7 @@ NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiproces
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
TQDM_BAR_FORMAT = '{l_bar}{bar:10}| {n_fmt}/{total_fmt} {elapsed}' # tqdm bar format
LOGGING_NAME = 'yolov5' LOGGING_NAME = 'yolov5'

@ -1,9 +1,10 @@
import glob import glob
import inspect
import platform import platform
import sys
import urllib import urllib
from pathlib import Path from pathlib import Path
from subprocess import check_output from subprocess import check_output
from typing import Optional
import pkg_resources as pkg import pkg_resources as pkg
import torch import torch
@ -128,3 +129,27 @@ def check_file(file, suffix=''):
def check_yaml(file, suffix=('.yaml', '.yml')): def check_yaml(file, suffix=('.yaml', '.yml')):
# Search/download YAML file (if necessary) and return path, checking suffix # Search/download YAML file (if necessary) and return path, checking suffix
return check_file(file, suffix) return check_file(file, suffix)
def git_describe(path=ROOT): # path must be a directory
# Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
try:
assert (Path(path) / '.git').is_dir()
return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
except Exception:
return ''
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
# Print function arguments (optional args dict)
x = inspect.currentframe().f_back # previous frame
file, _, func, _, _ = inspect.getframeinfo(x)
if args is None: # get args automatically
args, _, _, frm = inspect.getargvalues(x)
args = {k: v for k, v in frm.items() if k in args}
try:
file = Path(file).resolve().relative_to(ROOT).with_suffix('')
except ValueError:
file = Path(file).stem
s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))

@ -1,5 +1,6 @@
import contextlib import contextlib
import os import os
from datetime import datetime
from pathlib import Path from pathlib import Path
from zipfile import ZipFile from zipfile import ZipFile
@ -61,3 +62,15 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
for f in zipObj.namelist(): # list all archived filenames in the zip for f in zipObj.namelist(): # list all archived filenames in the zip
if all(x not in f for x in exclude): if all(x not in f for x in exclude):
zipObj.extract(f, path=path) zipObj.extract(f, path=path)
def file_age(path=__file__):
# Return days since last file update
dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
return dt.days # + dt.seconds / 86400 # fractional days
def file_date(path=__file__):
# Return human-readable file modification date, i.e. '2021-3-26'
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
return f'{t.year}-{t.month}-{t.day}'

@ -1,5 +1,6 @@
import math import math
import os import os
import platform
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
@ -12,7 +13,9 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
import ultralytics
from ultralytics.yolo.utils import LOGGER from ultralytics.yolo.utils import LOGGER
from ultralytics.yolo.utils.checks import git_describe
from .checks import check_version from .checks import check_version
@ -44,8 +47,8 @@ def DDP_model(model):
def select_device(device='', batch_size=0, newline=True): def select_device(device='', batch_size=0, newline=True):
# device = None or 'cpu' or 0 or '0' or '0,1,2,3' # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
# s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} ' ver = git_describe() or ultralytics.__version__ # git commit or pip package version
s = f'YOLOv5 🚀 torch-{torch.__version__} ' s = f'Ultralytics YOLO 🚀 {ver} Python-{platform.python_version()} torch-{torch.__version__} '
device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0' device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
cpu = device == 'cpu' cpu = device == 'cpu'
mps = device == 'mps' # Apple Metal Performance Shaders (MPS) mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
@ -75,7 +78,7 @@ def select_device(device='', batch_size=0, newline=True):
if not newline: if not newline:
s = s.rstrip() s = s.rstrip()
print(s) LOGGER.info(s)
return torch.device(arg) return torch.device(arg)

@ -1,7 +1,3 @@
import subprocess
import time
from pathlib import Path
import hydra import hydra
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -10,7 +6,6 @@ import torch.nn.functional as F
from ultralytics.yolo import v8 from ultralytics.yolo import v8
from ultralytics.yolo.data import build_dataloader from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
from ultralytics.yolo.utils.anchors import check_anchors
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
from ultralytics.yolo.utils.modeling.tasks import SegmentationModel from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
@ -24,7 +19,7 @@ class SegmentationTrainer(BaseTrainer):
# TODO: manage splits differently # TODO: manage splits differently
# calculate stride - check if model is initialized # calculate stride - check if model is initialized
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
loader = build_dataloader( return build_dataloader(
img_path=dataset_path, img_path=dataset_path,
img_size=self.args.img_size, img_size=self.args.img_size,
batch_size=batch_size, batch_size=batch_size,
@ -38,18 +33,16 @@ class SegmentationTrainer(BaseTrainer):
shuffle=self.args.shuffle, shuffle=self.args.shuffle,
use_segments=True, use_segments=True,
)[0] )[0]
return loader
def preprocess_batch(self, batch): def preprocess_batch(self, batch):
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
return batch return batch
def load_model(self, model_cfg, weights, data): def load_model(self, model_cfg, weights, data):
model = SegmentationModel(model_cfg if model_cfg else weights["model"].yaml, model = SegmentationModel(model_cfg or weights["model"].yaml,
ch=3, ch=3,
nc=data["nc"], nc=data["nc"],
anchors=self.args.get("anchors")) anchors=self.args.get("anchors"))
check_anchors(model, self.args.anchor_t, self.args.img_size)
if weights: if weights:
model.load(weights) model.load(weights)
return model return model

@ -1,48 +0,0 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
]
# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
[[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5)
]
Loading…
Cancel
Save