Add EMA and model checkpointing (#49)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
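The commit title covers two training utilities: an exponential moving average (EMA) of the model weights and periodic checkpointing. The hunks below only show the segmentation trainer and import changes, so the following is a minimal sketch of the general EMA-plus-checkpoint pattern in plain PyTorch; the class and function names, the decay constant, and the checkpoint keys are illustrative assumptions, not the implementation added by this commit.

```python
# Minimal sketch of the EMA + checkpoint pattern, assuming a plain PyTorch model.
# Names, the decay constant, and the checkpoint keys are illustrative assumptions.
import copy

import torch
import torch.nn as nn


class ModelEMA:
    """Keep a shadow copy of the model whose weights decay toward the live weights."""

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.ema = copy.deepcopy(model).eval()  # shadow model, used for val/inference
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: nn.Module):
        # ema_w = decay * ema_w + (1 - decay) * live_w, applied to every float tensor
        live = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:
                v.mul_(self.decay).add_(live[k].detach(), alpha=1.0 - self.decay)


def save_checkpoint(path, epoch, model, ema, optimizer):
    # Persist enough state to resume training or run inference from the EMA weights.
    torch.save(
        {
            "epoch": epoch,
            "model": model.state_dict(),
            "ema": ema.ema.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        path,
    )
```

In a training loop this would be used by calling `ema.update(model)` after each optimizer step and handing `ema.ema` to the validator, since the averaged weights usually evaluate more stably than the live ones.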
@@ -159,11 +159,11 @@ class SegmentationTrainer(BaseTrainer):

        return tcls, tbox, indices, anch, tidxs, xywhn

        if self.model.training:
            if len(preds) == 2:  # eval
                p, proto, = preds
            else:
                p, proto, train_out = preds
                p = train_out
        else:  # len(3) train
            _, proto, p = preds

        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        masks = batch["masks"]
        targets, masks = targets.to(self.device), masks.to(self.device).float()
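For context on the hunk above: the loss expects one row per labeled instance, laid out as [batch_idx, cls, x, y, w, h] with normalized xywh boxes. A standalone sketch with synthetic tensors (shapes and values are made up):

```python
# Synthetic illustration of the target layout assembled in the hunk above.
import torch

batch = {
    "batch_idx": torch.tensor([0.0, 0.0, 1.0]),  # which image in the mini-batch each instance belongs to
    "cls": torch.tensor([3.0, 17.0, 3.0]),       # class ids
    "bboxes": torch.rand(3, 4),                  # normalized xywh boxes
}

targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
print(targets.shape)  # torch.Size([3, 6]) -> one [batch_idx, cls, x, y, w, h] row per instance
```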
@@ -1,5 +1,4 @@
import os
from pathlib import Path

import numpy as np
import torch