Add EMA and model checkpointing (#49)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia
2022-11-19 23:37:26 +05:30
committed by GitHub
parent 27d6545117
commit 4291b9c31c
6 changed files with 55 additions and 21 deletions

View File

@ -159,11 +159,11 @@ class SegmentationTrainer(BaseTrainer):
return tcls, tbox, indices, anch, tidxs, xywhn
if self.model.training:
if len(preds) == 2: # eval
p, proto, = preds
else:
p, proto, train_out = preds
p = train_out
else: # len(3) train
_, proto, p = preds
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
masks = batch["masks"]
targets, masks = targets.to(self.device), masks.to(self.device).float()

View File

@ -1,5 +1,4 @@
import os
from pathlib import Path
import numpy as np
import torch