General trainer cleanup (#147)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2023-01-07 19:25:48 +05:30
committed by GitHub
parent f8a13c49a0
commit 0e5a7ae623
8 changed files with 196 additions and 60 deletions

View File

@ -55,10 +55,11 @@ class DetectionTrainer(BaseTrainer):
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
self.model.names = self.data["names"]
def load_model(self, model_cfg=None, weights=None, verbose=True):
model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
def get_model(self, cfg=None, weights=None, verbose=True):
model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
if weights:
model.load(weights, verbose)
model.load(model)
return model
def get_validator(self):