Revert loss head PR (#2873)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia
2023-05-28 19:45:41 +05:30
committed by GitHub
parent 6391c60089
commit 527a97759b
8 changed files with 335 additions and 327 deletions

View File

@ -41,6 +41,7 @@ class ClassificationTrainer(BaseTrainer):
m.p = self.args.dropout # set dropout
for p in model.parameters():
p.requires_grad = True # for training
return model
def setup_model(self):
@ -102,6 +103,12 @@ class ClassificationTrainer(BaseTrainer):
self.loss_names = ['loss']
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
def criterion(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
loss_items = loss.detach()
return loss, loss_items
def label_loss_items(self, loss_items=None, prefix='train'):
"""
Returns a loss dict with labelled training loss items tensor