Integration of v8 segmentation (#107)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing
2022-12-28 23:01:38 +08:00
committed by GitHub
parent 384f0ef1c6
commit 8406b49b49
16 changed files with 422 additions and 224 deletions

View File

@ -9,11 +9,10 @@ from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
from ultralytics.yolo.utils import colorstr
from ultralytics.yolo.utils.loss import BboxLoss
from ultralytics.yolo.utils.metrics import smooth_BCE
from ultralytics.yolo.utils.ops import xywh2xyxy
from ultralytics.yolo.utils.plotting import plot_images, plot_results
from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
from ultralytics.yolo.utils.torch_utils import de_parallel, strip_optimizer
from ultralytics.yolo.utils.torch_utils import de_parallel
# BaseTrainer python usage
@ -78,7 +77,8 @@ class DetectionTrainer(BaseTrainer):
return dict(zip(keys, loss_items)) if loss_items is not None else keys
def progress_string(self):
return ('\n' + '%11s' * 7) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
return ('\n' + '%11s' *
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
def plot_training_samples(self, batch, ni):
plot_images(images=batch["img"],
@ -100,15 +100,13 @@ class Loss:
device = next(model.parameters()).device # get model device
h = model.args # hyperparameters
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
self.cp, self.cn = smooth_BCE(eps=h.get("label_smoothing", 0.0)) # positive, negative BCE targets
m = model.model[-1] # Detect() module
self.bce = nn.BCEWithLogitsLoss(reduction='none')
self.hyp = h
self.stride = m.stride # model strides
self.nc = m.nc # number of classes
self.nl = m.nl # number of layers
self.no = m.no
self.reg_max = m.reg_max
self.device = device
self.use_dfl = m.reg_max > 1
@ -141,12 +139,15 @@ class Loss:
def __call__(self, preds, batch):
loss = torch.zeros(3, device=self.device) # box, cls, dfl
feats, pred_distri, pred_scores = preds if len(preds) == 3 else preds[1]
feats = preds[1] if isinstance(preds, tuple) else preds
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1)
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
batch_size, grid_size = pred_scores.shape[:2]
batch_size = pred_scores.shape[0]
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
@ -159,7 +160,7 @@ class Loss:
# pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
target_labels, target_bboxes, target_scores, fg_mask = self.assigner(
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)