new check_dataset functions (#43)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2022-11-12 19:32:41 +05:30
committed by GitHub
parent d143ac666f
commit 1f3aad86c1
13 changed files with 336 additions and 62 deletions

View File

@ -24,7 +24,9 @@ from tqdm import tqdm
import ultralytics.yolo.utils as utils
import ultralytics.yolo.utils.loggers as loggers
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import LOGGER, ROOT
from ultralytics.yolo.utils.checks import check_file, check_yaml
from ultralytics.yolo.utils.files import increment_path, save_yaml
from ultralytics.yolo.utils.modeling import get_model
@ -55,9 +57,14 @@ class BaseTrainer:
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
# Model and Dataloaders.
self.trainset, self.testset = self.get_dataset(self.args.data)
self.data = self.args.data
if self.data.endswith(".yaml"):
self.data = check_dataset_yaml(self.data)
else:
self.data = check_dataset(self.data)
self.trainset, self.testset = self.get_dataset(self.data)
if self.args.cfg is not None:
self.model = self.load_cfg(self.args.cfg)
self.model = self.load_cfg(check_file(self.args.cfg))
if self.args.model is not None:
self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)
@ -250,10 +257,9 @@ class BaseTrainer:
def get_dataset(self, data):
"""
Download the dataset if needed and verify it.
Returns train and val split datasets
Get train, val path from data dict if it exists. Returns None if data format is not recognized
"""
pass
return data["train"], data["val"]
def get_model(self, model, pretrained):
"""