new check_dataset functions (#43)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -24,7 +24,9 @@ from tqdm import tqdm
|
||||
|
||||
import ultralytics.yolo.utils as utils
|
||||
import ultralytics.yolo.utils.loggers as loggers
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT
|
||||
from ultralytics.yolo.utils.checks import check_file, check_yaml
|
||||
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
||||
from ultralytics.yolo.utils.modeling import get_model
|
||||
|
||||
@ -55,9 +57,14 @@ class BaseTrainer:
|
||||
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
|
||||
|
||||
# Model and Dataloaders.
|
||||
self.trainset, self.testset = self.get_dataset(self.args.data)
|
||||
self.data = self.args.data
|
||||
if self.data.endswith(".yaml"):
|
||||
self.data = check_dataset_yaml(self.data)
|
||||
else:
|
||||
self.data = check_dataset(self.data)
|
||||
self.trainset, self.testset = self.get_dataset(self.data)
|
||||
if self.args.cfg is not None:
|
||||
self.model = self.load_cfg(self.args.cfg)
|
||||
self.model = self.load_cfg(check_file(self.args.cfg))
|
||||
if self.args.model is not None:
|
||||
self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)
|
||||
|
||||
@ -250,10 +257,9 @@ class BaseTrainer:
|
||||
|
||||
def get_dataset(self, data):
|
||||
"""
|
||||
Download the dataset if needed and verify it.
|
||||
Returns train and val split datasets
|
||||
Get train, val path from data dict if it exists. Returns None if data format is not recognized
|
||||
"""
|
||||
pass
|
||||
return data["train"], data["val"]
|
||||
|
||||
def get_model(self, model, pretrained):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user