General console printout updates (#48)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2022-11-19 16:08:16 +01:00
committed by GitHub
parent 8530e3fae0
commit 27d6545117
12 changed files with 81 additions and 105 deletions

View File

@ -1,7 +1,3 @@
import subprocess
import time
from pathlib import Path
import hydra
import torch
import torch.nn as nn
@ -10,7 +6,6 @@ import torch.nn.functional as F
from ultralytics.yolo import v8
from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
from ultralytics.yolo.utils.anchors import check_anchors
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
@ -24,7 +19,7 @@ class SegmentationTrainer(BaseTrainer):
# TODO: manage splits differently
# calculate stride - check if model is initialized
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
loader = build_dataloader(
return build_dataloader(
img_path=dataset_path,
img_size=self.args.img_size,
batch_size=batch_size,
@ -38,18 +33,16 @@ class SegmentationTrainer(BaseTrainer):
shuffle=self.args.shuffle,
use_segments=True,
)[0]
return loader
def preprocess_batch(self, batch):
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
return batch
def load_model(self, model_cfg, weights, data):
model = SegmentationModel(model_cfg if model_cfg else weights["model"].yaml,
model = SegmentationModel(model_cfg or weights["model"].yaml,
ch=3,
nc=data["nc"],
anchors=self.args.get("anchors"))
check_anchors(model, self.args.anchor_t, self.args.img_size)
if weights:
model.load(weights)
return model