Move loss to task heads (#2825)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
@ -41,7 +41,6 @@ class ClassificationTrainer(BaseTrainer):
|
||||
m.p = self.args.dropout # set dropout
|
||||
for p in model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
|
||||
return model
|
||||
|
||||
def setup_model(self):
|
||||
@ -103,12 +102,6 @@ class ClassificationTrainer(BaseTrainer):
|
||||
self.loss_names = ['loss']
|
||||
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
"""Compute the classification loss between predictions and true labels."""
|
||||
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
|
||||
loss_items = loss.detach()
|
||||
return loss, loss_items
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||
"""
|
||||
Returns a loss dict with labelled training loss items tensor
|
||||
|
Reference in New Issue
Block a user