update segment training (#57)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -9,6 +9,9 @@ from ultralytics.yolo.utils.modeling.tasks import ClassificationModel
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def set_model_attributes(self):
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def load_model(self, model_cfg, weights, data):
|
||||
# TODO: why treat clf models as unique. We should have clf yamls?
|
||||
if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision
|
||||
@ -18,7 +21,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
ClassificationModel.reshape_outputs(model, data["nc"])
|
||||
return model
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=None, rank=0):
|
||||
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
|
||||
return build_classification_dataloader(path=dataset_path,
|
||||
imgsz=self.args.img_size,
|
||||
batch_size=batch_size,
|
||||
|
@ -23,3 +23,7 @@ class ClassificationValidator(BaseValidator):
|
||||
acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1) # (top1, top5) accuracy
|
||||
top1, top5 = acc.mean(0).tolist()
|
||||
return {"top1": top1, "top5": top5, "fitness": top5}
|
||||
|
||||
@property
|
||||
def metric_keys(self):
|
||||
return ["top1", "top5"]
|
||||
|
Reference in New Issue
Block a user