General trainer cleanup (#147)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2023-01-07 19:25:48 +05:30
committed by GitHub
parent f8a13c49a0
commit 0e5a7ae623
8 changed files with 196 additions and 60 deletions

View File

@ -18,10 +18,15 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
# BaseTrainer python usage
class SegmentationTrainer(v8.detect.DetectionTrainer):
def load_model(self, model_cfg=None, weights=None, verbose=True):
model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
overrides["task"] = "segment"
super().__init__(config, overrides)
def get_model(self, cfg=None, weights=None, verbose=True):
model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
if weights:
model.load(weights, verbose)
model.load(weights)
return model
def get_validator(self):

View File

@ -19,6 +19,7 @@ class SegmentationValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args)
self.args.task = "segment"
self.metrics = SegmentMetrics(save_dir=self.save_dir, plot=self.args.plots)
def preprocess(self, batch):