Simplify augmentations (#93)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -30,8 +30,8 @@ class DetectionTrainer(BaseTrainer):
|
||||
def set_model_attributes(self):
|
||||
nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)
|
||||
self.args.box *= 3 / nl # scale to layers
|
||||
self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
|
||||
self.args.obj *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
|
||||
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
|
||||
self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
|
||||
self.model.nc = self.data["nc"] # attach number of classes to model
|
||||
self.model.args = self.args # attach hyperparameters to model
|
||||
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
||||
@ -85,14 +85,11 @@ class Loss:
|
||||
device = next(model.parameters()).device # get model device
|
||||
h = model.args # hyperparameters
|
||||
|
||||
# Define criteria
|
||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h["cls_pw"]], device=device), reduction='none')
|
||||
|
||||
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
|
||||
self.cp, self.cn = smooth_BCE(eps=h.get("label_smoothing", 0.0)) # positive, negative BCE targets
|
||||
|
||||
m = model.model[-1] # Detect() module
|
||||
self.BCEcls = BCEcls
|
||||
self.bce = nn.BCEWithLogitsLoss(reduction='none')
|
||||
self.hyp = h
|
||||
self.stride = m.stride # model strides
|
||||
self.nc = m.nc # number of classes
|
||||
@ -156,7 +153,7 @@ class Loss:
|
||||
|
||||
# cls loss
|
||||
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
|
||||
loss[1] = self.BCEcls(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
|
||||
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
|
||||
|
||||
# bbox loss
|
||||
if fg_mask.sum():
|
||||
|
Reference in New Issue
Block a user