From 47f1cb3ef489862ec17c56133d729f546698fece Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Thu, 17 Nov 2022 06:44:02 -0600 Subject: [PATCH] Fix some cuda training issues of segmentation (#46) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/yolo/engine/trainer.py | 11 +++++++---- ultralytics/yolo/engine/validator.py | 14 +++++++++----- ultralytics/yolo/v8/classify/val.py | 5 +++-- ultralytics/yolo/v8/segment/train.py | 4 ++-- ultralytics/yolo/v8/segment/val.py | 25 +++++++++++++++++-------- 5 files changed, 38 insertions(+), 21 deletions(-) diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index ec82738..a99826f 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -142,7 +142,7 @@ class BaseTrainer: self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank) if rank in {0, -1}: print(" Creating testloader rank :", rank) - self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank) + self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=-1) self.validator = self.get_validator() print("created testloader :", rank) self.console.info(self.progress_string()) @@ -150,6 +150,8 @@ class BaseTrainer: def _do_train(self, rank, world_size): if world_size > 1: self._setup_ddp(rank, world_size) + else: + self.model = self.model.to(self.device) # callback hook. before_train self._setup_train(rank) @@ -192,8 +194,8 @@ class BaseTrainer: losses = tloss if loss_len > 1 else torch.unsqueeze(tloss, 0) if rank in {-1, 0}: pbar.set_description( - (" {} " + "{:.3f} " * (2 + loss_len)).format(f'{epoch + 1}/{self.args.epochs}', mem, *losses, - batch["img"].shape[-1])) + (" {} " + "{:.3f} " * (1 + loss_len) + ' {} ').format(f'{epoch + 1}/{self.args.epochs}', mem, + *losses, batch["img"].shape[-1])) if rank in [-1, 0]: # validation @@ -286,7 +288,8 @@ class BaseTrainer: "fitness" metric. """ self.metrics = self.validator(self) - self.fitness = self.metrics.get("fitness") or (-self.loss) # use loss as fitness measure if not found + self.fitness = self.metrics.get("fitness", + -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found if not self.best_fitness or self.best_fitness < self.fitness: self.best_fitness = self.fitness diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index f7e6b75..b5ff0ab 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -6,7 +6,7 @@ from tqdm import tqdm from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG from ultralytics.yolo.utils.ops import Profile -from ultralytics.yolo.utils.torch_utils import select_device +from ultralytics.yolo.utils.torch_utils import de_parallel, select_device class BaseValidator: @@ -36,7 +36,9 @@ class BaseValidator: if training: model = trainer.model self.args.half &= self.device.type != 'cpu' - model = model.half() if self.args.half else model + # NOTE: half() inference in evaluation will make training stuck, + # so I comment it out for now, I think we can reuse half mode after we add EMA. + # model = model.half() if self.args.half else model else: # TODO: handle this when detectMultiBackend is supported # model = DetectMultiBacked(model) pass @@ -48,8 +50,8 @@ class BaseValidator: n_batches = len(self.dataloader) desc = self.get_desc() bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') - self.init_metrics(model) - with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'): + self.init_metrics(de_parallel(model)) + with torch.no_grad(): for batch_i, batch in enumerate(bar): self.batch_i = batch_i # pre-process @@ -58,7 +60,7 @@ class BaseValidator: # inference with dt[1]: - preds = model(batch["img"]) + preds = model(batch["img"].float()) # TODO: remember to add native augmentation support when implementing model, like: # preds, train_out = model(im, augment=augment) @@ -85,6 +87,8 @@ class BaseValidator: self.logger.info( 'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t) + if self.training: + model.float() # TODO: implement save json return stats diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py index c0420e8..9fcfc6e 100644 --- a/ultralytics/yolo/v8/classify/val.py +++ b/ultralytics/yolo/v8/classify/val.py @@ -6,10 +6,11 @@ from ultralytics.yolo.engine.validator import BaseValidator class ClassificationValidator(BaseValidator): def init_metrics(self, model): - self.correct = torch.tensor([]) + self.correct = torch.tensor([], device=next(model.parameters()).device) def preprocess(self, batch): - batch["img"] = batch["img"].to(self.device) + batch["img"] = batch["img"].to(self.device, non_blocking=True) + batch["img"] = batch["img"].half() if self.args.half else batch["img"].float() batch["cls"] = batch["cls"].to(self.device) return batch diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index 10c3522..b83bcbf 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -23,7 +23,7 @@ class SegmentationTrainer(BaseTrainer): def get_dataloader(self, dataset_path, batch_size, rank=0): # TODO: manage splits differently # calculate stride - check if model is initialized - gs = max(int(self.model.stride.max() if self.model else 0), 32) + gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) loader = build_dataloader( img_path=dataset_path, img_size=self.args.img_size, @@ -220,7 +220,7 @@ class SegmentationTrainer(BaseTrainer): mxyxy = xywh2xyxy(xywhn[i] * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)) for bi in b.unique(): j = b == bi # matching index - if True: + if self.args.overlap_mask: mask_gti = torch.where(masks[bi][None] == tidxs[i][j].view(-1, 1, 1), 1.0, 0.0) else: mask_gti = masks[tidxs[i]][j] diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py index c09d730..f4a526f 100644 --- a/ultralytics/yolo/v8/segment/val.py +++ b/ultralytics/yolo/v8/segment/val.py @@ -30,11 +30,13 @@ class SegmentationValidator(BaseValidator): def preprocess(self, batch): batch["img"] = batch["img"].to(self.device, non_blocking=True) - batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 225 - batch["bboxes"] = batch["bboxes"].to(self.device) + batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255 batch["masks"] = batch["masks"].to(self.device).float() self.nb, _, self.height, self.width = batch["img"].shape # batch size, channels, height, width self.targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) + self.targets = self.targets.to(self.device) + height, width = batch["img"].shape[2:] + self.targets[:, 2:] *= torch.tensor((width, height, width, height), device=self.device) # to pixels self.lb = [self.targets[self.targets[:, 0] == i, 1:] for i in range(self.nb)] if self.args.save_hybrid else [] # for autolabelling @@ -75,7 +77,7 @@ class SegmentationValidator(BaseValidator): agnostic=self.args.single_cls, max_det=self.args.max_det, nm=self.nm) - return (p, preds[0], preds[2]) + return (p, preds[1], preds[2]) def update_metrics(self, preds, batch): # Metrics @@ -83,7 +85,7 @@ class SegmentationValidator(BaseValidator): for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): labels = self.targets[self.targets[:, 0] == si, 1:] nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions - shape = Path(batch["im_file"][si]) + shape = batch["shape"][si] # path = batch["shape"][si][0] correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init @@ -106,22 +108,29 @@ class SegmentationValidator(BaseValidator): if self.args.single_cls: pred[:, 5] = 0 predn = pred.clone() - ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape, batch["shape"][si][1]) # native-space pred + ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape) # native-space pred # Evaluate if nl: tbox = ops.xywh2xyxy(labels[:, 1:5]) # target boxes - ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape, batch["shapes"][si][1]) # native-space labels + ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape) # native-space labels labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels correct_bboxes = self._process_batch(predn, labelsn, self.iouv) - correct_masks = self._process_batch(predn, labelsn, self.iouv, pred_masks, gt_masks, masks=True) + # TODO: maybe remove these `self.` arguments as they already are member variable + correct_masks = self._process_batch(predn, + labelsn, + self.iouv, + pred_masks, + gt_masks, + overlap=self.args.overlap_mask, + masks=True) if self.args.plots: self.confusion_matrix.process_batch(predn, labelsn) self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0])) # (conf, pcls, tcls) pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) - if self.plots and self.batch_i < 3: + if self.args.plots and self.batch_i < 3: plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot # TODO: Save/log