diff --git a/requirements.txt b/requirements.txt
index eb7cb41..b8a0555 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,7 +26,7 @@ seaborn>=0.11.0
 
 # Export --------------------------------------
 # coremltools>=5.2  # CoreML export
-# onnx>=1.9.0  # ONNX export
+# onnx>=1.12.0  # ONNX export
 # onnx-simplifier>=0.4.1  # ONNX simplifier
 # nvidia-pyindex  # TensorRT export
 # nvidia-tensorrt  # TensorRT export
diff --git a/ultralytics/tests/data/dataloader/yolodetection.py b/ultralytics/tests/data/dataloader/yolodetection.py
index db6da14..515633f 100644
--- a/ultralytics/tests/data/dataloader/yolodetection.py
+++ b/ultralytics/tests/data/dataloader/yolodetection.py
@@ -55,7 +55,7 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None):
     )
 
 
-@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def test(cfg):
     cfg.task = "detect"
     cfg.mode = "train"
diff --git a/ultralytics/tests/data/dataloader/yolosegment.py b/ultralytics/tests/data/dataloader/yolosegment.py
index cd38ba2..f8cca68 100644
--- a/ultralytics/tests/data/dataloader/yolosegment.py
+++ b/ultralytics/tests/data/dataloader/yolosegment.py
@@ -54,7 +54,7 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None):
     )
 
 
-@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def test(cfg):
     cfg.task = "segment"
     cfg.mode = "train"
diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py
index b1463ce..2ebaa77 100644
--- a/ultralytics/yolo/data/augment.py
+++ b/ultralytics/yolo/data/augment.py
@@ -82,7 +82,7 @@ class BaseMixTransform:
             indexes = [indexes]
 
         # get images information will be used for Mosaic or MixUp
-        mix_labels = [deepcopy(dataset.get_label_info(index)) for index in indexes]
+        mix_labels = [dataset.get_label_info(index) for index in indexes]
 
         if self.pre_transform is not None:
             for i, data in enumerate(mix_labels):
@@ -134,9 +134,8 @@ class Mosaic(BaseMixTransform):
         assert len(labels.get("mix_labels", [])) > 0, "There are no other images for mosaic augment."
         s = self.imgsz
         yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border)  # mosaic center x, y
-        mix_labels = labels["mix_labels"]
         for i in range(4):
-            labels_patch = deepcopy(labels) if i == 0 else deepcopy(mix_labels[i - 1])
+            labels_patch = (labels if i == 0 else labels["mix_labels"][i - 1]).copy()
             # Load image
             img = labels_patch["img"]
             h, w = labels_patch["resized_shape"]
@@ -186,9 +185,8 @@ class Mosaic(BaseMixTransform):
             "ori_shape": mosaic_labels[0]["ori_shape"],
             "resized_shape": (self.imgsz * 2, self.imgsz * 2),
             "im_file": mosaic_labels[0]["im_file"],
-            "cls": np.concatenate(cls, 0)}
-
-        final_labels["instances"] = Instances.concatenate(instances, axis=0)
+            "cls": np.concatenate(cls, 0),
+            "instances": Instances.concatenate(instances, axis=0)}
         final_labels["instances"].clip(self.imgsz * 2, self.imgsz * 2)
         return final_labels
 
@@ -345,7 +343,6 @@ class RandomPerspective:
         Affine images and targets.
 
         Args:
-            img(ndarray): image.
             labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
         """
         img = labels["img"]
@@ -387,7 +384,7 @@ class RandomPerspective:
         return labels
 
     def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
-        # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
+        # Compute box candidates: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
         w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
         w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
         ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
@@ -609,6 +606,7 @@ class Format:
         self.batch_idx = batch_idx  # keep the batch indexes
 
     def __call__(self, labels):
+        labels.pop("dataset", None)
         img = labels["img"]
         h, w = img.shape[:2]
         cls = labels.pop("cls")
@@ -672,10 +670,7 @@ def mosaic_transforms(imgsz, hyp):
         ),])
     return Compose([
         pre_transform,
-        MixUp(
-            pre_transform=pre_transform,
-            p=hyp.mixup,
-        ),
+        MixUp(pre_transform=pre_transform, p=hyp.mixup),
         Albumentations(p=1.0),
        RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
         RandomFlip(direction="vertical", p=hyp.flipud),
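Note on the `deepcopy` → `dict.copy()` changes in augment.py above: `dict.copy()` is shallow, so the patch dict shares the underlying image array and label objects with the original. A minimal standalone sketch (not the repository's code) of why this is safe only when downstream transforms rebind keys rather than mutating values in place:

```python
import numpy as np

labels = {"img": np.zeros((4, 4, 3), dtype=np.uint8), "cls": np.array([0])}
patch = labels.copy()  # shallow: patch["img"] is the same ndarray object

# Rebinding a key leaves the original untouched (what the transforms above do):
patch["img"] = np.ones((4, 4, 3), dtype=np.uint8)
assert labels["img"].sum() == 0

# In-place mutation, by contrast, would leak back into the shared array:
patch2 = labels.copy()
patch2["img"] += 1  # modifies labels["img"] too
assert labels["img"].sum() > 0
```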
""" img = labels["img"] @@ -387,7 +384,7 @@ class RandomPerspective: return labels def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n) - # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio + # Compute box candidates: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio w1, h1 = box1[2] - box1[0], box1[3] - box1[1] w2, h2 = box2[2] - box2[0], box2[3] - box2[1] ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio @@ -609,6 +606,7 @@ class Format: self.batch_idx = batch_idx # keep the batch indexes def __call__(self, labels): + labels.pop("dataset", None) img = labels["img"] h, w = img.shape[:2] cls = labels.pop("cls") @@ -672,10 +670,7 @@ def mosaic_transforms(imgsz, hyp): ),]) return Compose([ pre_transform, - MixUp( - pre_transform=pre_transform, - p=hyp.mixup, - ), + MixUp(pre_transform=pre_transform, p=hyp.mixup), Albumentations(p=1.0), RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v), RandomFlip(direction="vertical", p=hyp.flipud), diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py index 33352e9..e00fac4 100644 --- a/ultralytics/yolo/data/base.py +++ b/ultralytics/yolo/data/base.py @@ -1,4 +1,5 @@ import glob +import math import os from multiprocessing.pool import ThreadPool from pathlib import Path @@ -121,7 +122,7 @@ class BaseDataset(Dataset): r = self.imgsz / max(h0, w0) # ratio if r != 1: # if sizes are not equal interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA - im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp) + im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp) return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized @@ -179,10 +180,7 @@ class BaseDataset(Dataset): def get_label_info(self, index): label = self.labels[index].copy() - img, (h0, w0), (h, w) = self.load_image(index) - label["img"] = img - label["ori_shape"] = (h0, w0) - label["resized_shape"] = (h, w) + label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index) if self.rect: label["rect_shape"] = self.batch_shapes[self.batch[index]] label = self.update_labels_info(label) diff --git a/ultralytics/yolo/data/build.py b/ultralytics/yolo/data/build.py index 669876a..1fa0cb2 100644 --- a/ultralytics/yolo/data/build.py +++ b/ultralytics/yolo/data/build.py @@ -64,7 +64,7 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank label_path=label_path, imgsz=cfg.imgsz, batch_size=batch_size, - augment=True if mode == "train" else False, # augmentation + augment=mode == "train", # augmentation hyp=cfg, # TODO: probably add a get_hyps_from_cfg function rect=cfg.rect if mode == "train" else True, # rectangular batches cache=None if cfg.noval else cfg.get("cache", None), @@ -73,31 +73,25 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank pad=0.0 if mode == "train" else 0.5, prefix=colorstr(f"{mode}: "), use_segments=cfg.task == "segment", - use_keypoints=cfg.task == "keypoint", - ) + use_keypoints=cfg.task == "keypoint") batch_size = min(batch_size, len(dataset)) nd = torch.cuda.device_count() # number of CUDA devices workers = cfg.workers if mode == "train" else cfg.workers * 2 nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers sampler = 
diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py
index 2ad939d..8b47343 100644
--- a/ultralytics/yolo/data/dataset.py
+++ b/ultralytics/yolo/data/dataset.py
@@ -124,13 +124,9 @@ class YOLODataset(BaseDataset):
 
     # TODO: use hyp config to set all these augmentations
     def build_transforms(self, hyp=None):
-        mosaic = self.augment and not self.rect
-        # mosaic = False
         if self.augment:
-            if mosaic:
-                transforms = mosaic_transforms(self.imgsz, hyp)
-            else:
-                transforms = affine_transforms(self.imgsz, hyp)
+            mosaic = self.augment and not self.rect
+            transforms = mosaic_transforms(self.imgsz, hyp) if mosaic else affine_transforms(self.imgsz, hyp)
         else:
             transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz))])
         transforms.append(
@@ -143,7 +139,7 @@ class YOLODataset(BaseDataset):
 
     def update_labels_info(self, label):
         """custom your label format here"""
-        # NOTE: cls is not with bboxes now, since other tasks like classification and semantic segmentation need a independent cls label
+        # NOTE: cls is not with bboxes now, since classification and semantic segmentation need an independent cls label
         # we can make it also support classification and semantic segmentation by add or remove some dict keys there.
         bboxes = label.pop("bboxes")
         segments = label.pop("segments")
@@ -206,7 +202,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
             sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
         else:
             sample = self.torch_transforms(im)
-        return OrderedDict(img=sample, cls=j)
+        return {'img': sample, 'cls': j}
 
     def __len__(self) -> int:
         return len(self.samples)
diff --git a/ultralytics/yolo/data/datasets/coco.yaml b/ultralytics/yolo/data/datasets/coco.yaml
new file mode 100644
index 0000000..57aa9b9
--- /dev/null
+++ b/ultralytics/yolo/data/datasets/coco.yaml
@@ -0,0 +1,113 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+# COCO 2017 dataset http://cocodataset.org by Microsoft
+# Example usage: python train.py --data coco.yaml
+# parent
+# ├── yolov5
+# └── datasets
+#     └── coco  ← downloads here (20.1 GB)
+
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
+path: ../datasets/coco  # dataset root dir
+train: train2017.txt  # train images (relative to 'path') 118287 images
+val: val2017.txt  # val images (relative to 'path') 5000 images
+test: test-dev2017.txt  # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
+
+# Classes
+names:
+  0: person
+  1: bicycle
+  2: car
+  3: motorcycle
+  4: airplane
+  5: bus
+  6: train
+  7: truck
+  8: boat
+  9: traffic light
+  10: fire hydrant
+  11: stop sign
+  12: parking meter
+  13: bench
+  14: bird
+  15: cat
+  16: dog
+  17: horse
+  18: sheep
+  19: cow
+  20: elephant
+  21: bear
+  22: zebra
+  23: giraffe
+  24: backpack
+  25: umbrella
+  26: handbag
+  27: tie
+  28: suitcase
+  29: frisbee
+  30: skis
+  31: snowboard
+  32: sports ball
+  33: kite
+  34: baseball bat
+  35: baseball glove
+  36: skateboard
+  37: surfboard
+  38: tennis racket
+  39: bottle
+  40: wine glass
+  41: cup
+  42: fork
+  43: knife
+  44: spoon
+  45: bowl
+  46: banana
+  47: apple
+  48: sandwich
+  49: orange
+  50: broccoli
+  51: carrot
+  52: hot dog
+  53: pizza
+  54: donut
+  55: cake
+  56: chair
+  57: couch
+  58: potted plant
+  59: bed
+  60: dining table
+  61: toilet
+  62: tv
+  63: laptop
+  64: mouse
+  65: remote
+  66: keyboard
+  67: cell phone
+  68: microwave
+  69: oven
+  70: toaster
+  71: sink
+  72: refrigerator
+  73: book
+  74: clock
+  75: vase
+  76: scissors
+  77: teddy bear
+  78: hair drier
+  79: toothbrush
+
+
+# Download script/URL (optional)
+download: |
+  from utils.general import download, Path
+
+  # Download labels
+  segments = True  # segment or box labels
+  dir = Path(yaml['path'])  # dataset root dir
+  url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
+  urls = [url + ('coco2017labels-segments.zip' if segments else 'coco2017labels.zip')]  # labels
+  download(urls, dir=dir.parent)
+
+  # Download data
+  urls = ['http://images.cocodataset.org/zips/train2017.zip',  # 19G, 118k images
+          'http://images.cocodataset.org/zips/val2017.zip',  # 1G, 5k images
+          'http://images.cocodataset.org/zips/test2017.zip']  # 7G, 41k images (optional)
+  download(urls, dir=dir / 'images', threads=3)
\ No newline at end of file
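The `names` mapping in this new YAML is keyed by integer class index rather than being a flat list. A small sketch of how a consumer might read it (assumes PyYAML and a local copy of the file; the path is illustrative):

```python
import yaml  # PyYAML

with open("ultralytics/yolo/data/datasets/coco.yaml", encoding="utf-8") as f:
    data = yaml.safe_load(f)

names = data["names"]        # {0: 'person', 1: 'bicycle', ...}
print(len(names), names[0])  # 80 person
# The download key is a plain string of Python source for the dataset
# checker to execute, not nested YAML structure:
print(type(data["download"]))  # <class 'str'>
```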
diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py
index bfefc59..3e894e6 100644
--- a/ultralytics/yolo/engine/predictor.py
+++ b/ultralytics/yolo/engine/predictor.py
@@ -29,16 +29,14 @@ import platform
 from pathlib import Path
 
 import cv2
-import torch
 
 from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
 from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS, check_dataset, check_dataset_yaml
-from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr, ops
+from ultralytics.yolo.utils import LOGGER, ROOT, colorstr, ops
 from ultralytics.yolo.utils.checks import check_file, check_imshow
 from ultralytics.yolo.utils.configs import get_config
 from ultralytics.yolo.utils.files import increment_path
 from ultralytics.yolo.utils.modeling.autobackend import AutoBackend
-from ultralytics.yolo.utils.plotting import Annotator
 from ultralytics.yolo.utils.torch_utils import check_imgsz, select_device, smart_inference_mode
 
 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@@ -125,11 +123,7 @@ class BasePredictor:
 
     @smart_inference_mode()
     def __call__(self, source=None, model=None):
-        if not self.done_setup:
-            model = self.setup(source, model)
-        else:
-            model = self.model
-
+        model = self.model if self.done_setup else self.setup(source, model)
         self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
         for batch in self.dataset:
             path, im, im0s, vid_cap, s = batch
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index ac9770a..c1e5352 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -60,7 +60,8 @@ class BaseTrainer:
 
         # device
         self.device = utils.torch_utils.select_device(self.args.device, self.batch_size)
-        self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
+        self.amp = self.device.type != 'cpu'
+        self.scaler = amp.GradScaler(enabled=self.amp)
 
         # Model and Dataloaders.
         self.model = self.args.model
@@ -175,6 +176,10 @@ class BaseTrainer:
         nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
         last_opt_step = -1
         self.trigger_callbacks("on_train_start")
+        self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
+                 f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
+                 f"Logging results to {colorstr('bold', self.save_dir)}\n"
+                 f"Starting training for {self.epochs} epochs...")
         for epoch in range(self.start_epoch, self.epochs):
             self.epoch = epoch
             self.trigger_callbacks("on_train_epoch_start")
@@ -189,8 +194,6 @@ class BaseTrainer:
                 self.optimizer.zero_grad()
             for i, batch in pbar:
                 self.trigger_callbacks("on_train_batch_start")
-                # forward
-                batch = self.preprocess_batch(batch)
 
                 # warmup
                 ni = i + nb * epoch
@@ -204,17 +207,20 @@ class BaseTrainer:
                         if 'momentum' in x:
                             x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
 
-                preds = self.model(batch["img"])
-                self.loss, self.loss_items = self.criterion(preds, batch)
-                if rank != -1:
-                    self.loss *= world_size
-                self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
-                    else self.loss_items
-
-                # backward
+                # Forward
+                with torch.cuda.amp.autocast(self.amp):
+                    batch = self.preprocess_batch(batch)
+                    preds = self.model(batch["img"])
+                    self.loss, self.loss_items = self.criterion(preds, batch)
+                    if rank != -1:
+                        self.loss *= world_size
+                    self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
+                        else self.loss_items
+
+                # Backward
                 self.scaler.scale(self.loss).backward()
 
-                # optimize
+                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                 if ni - last_opt_step >= self.accumulate:
                     self.optimizer_step()
                     last_opt_step = ni
@@ -237,7 +243,7 @@ class BaseTrainer:
             self.scheduler.step()
             self.trigger_callbacks("on_train_epoch_end")
 
-            if rank in [-1, 0]:
+            if rank in {-1, 0}:
                 # validation
                 self.trigger_callbacks('on_val_start')
                 self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
@@ -245,7 +251,7 @@ class BaseTrainer:
                 if not self.args.noval or final_epoch:
                     self.metrics, self.fitness = self.validate()
                 self.trigger_callbacks('on_val_end')
-                log_vals = self.label_loss_items(self.tloss) | self.metrics | lr
+                log_vals = {**self.label_loss_items(self.tloss), **self.metrics, **lr}
                 self.save_metrics(metrics=log_vals)
 
                 # save model
@@ -259,12 +265,13 @@ class BaseTrainer:
 
         # TODO: termination condition
 
-        if rank in [-1, 0]:
+        if rank in {-1, 0}:
             # do the last evaluation with best.pt
+            self.log(f'\n{epoch - self.start_epoch + 1} epochs completed in '
+                     f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
             self.final_eval()
             if self.args.plots:
                 self.plot_metrics()
-            self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)")
             self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
             self.trigger_callbacks('on_train_end')
         dist.destroy_process_group() if world_size > 1 else None
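The trainer change above moves the forward pass and loss under `torch.cuda.amp.autocast`, gated by the same `self.amp` flag used for the `GradScaler`. A self-contained sketch of the pattern with a toy model and data (not the trainer itself):

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
amp = device.type != "cpu"  # same gating as self.amp above
model = nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=amp)

x = torch.randn(4, 10, device=device)
y = torch.randn(4, 1, device=device)

for _ in range(3):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(amp):  # forward + loss in mixed precision
        loss = nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()       # scaled backward to avoid fp16 underflow
    scaler.step(optimizer)              # unscales grads, then steps
    scaler.update()
```

On CPU the scaler and autocast are disabled, so the same loop runs unchanged in full precision, which is why a single flag suffices.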
diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py
index 7be8442..b150614 100644
--- a/ultralytics/yolo/engine/validator.py
+++ b/ultralytics/yolo/engine/validator.py
@@ -1,4 +1,3 @@
-import logging
 from pathlib import Path
 
 import torch
@@ -9,10 +8,9 @@ from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
 from ultralytics.yolo.utils import LOGGER, TQDM_BAR_FORMAT
 from ultralytics.yolo.utils.files import increment_path
-from ultralytics.yolo.utils.modeling import get_model
 from ultralytics.yolo.utils.modeling.autobackend import AutoBackend
 from ultralytics.yolo.utils.ops import Profile
-from ultralytics.yolo.utils.torch_utils import check_imgsz, de_parallel, select_device
+from ultralytics.yolo.utils.torch_utils import check_imgsz, de_parallel, select_device, smart_inference_mode
 
 
 class BaseValidator:
@@ -32,8 +30,9 @@ class BaseValidator:
         self.training = True
         self.speed = None
         self.save_dir = save_dir if save_dir is not None else \
-            increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
+            increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
 
+    @smart_inference_mode()
     def __call__(self, trainer=None, model=None):
         """
         Supports validation of a pre-trained model if passed or a model being trained
@@ -76,35 +75,34 @@ class BaseValidator:
         dt = Profile(), Profile(), Profile(), Profile()
         n_batches = len(self.dataloader)
         desc = self.get_desc()
-        # NOTE: keeping this `not self.training` in tqdm will eliminate pbar after finishing segmantation evaluation during training,
-        # so I removed it, not sure if this will affect classification task cause I saw we use this arg in yolov5/classify/val.py.
+        # NOTE: keeping `not self.training` in tqdm will eliminate pbar after segmentation evaluation during training,
+        # which may affect the classification task since this arg is used in yolov5/classify/val.py.
         # bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
         bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
         self.init_metrics(de_parallel(model))
-        with torch.no_grad():
-            for batch_i, batch in enumerate(bar):
-                self.batch_i = batch_i
-                # pre-process
-                with dt[0]:
-                    batch = self.preprocess(batch)
-
-                # inference
-                with dt[1]:
-                    preds = model(batch["img"])
-
-                # loss
-                with dt[2]:
-                    if self.training:
-                        self.loss += trainer.criterion(preds, batch)[1]
-
-                # pre-process predictions
-                with dt[3]:
-                    preds = self.postprocess(preds)
-
-                self.update_metrics(preds, batch)
-                if self.args.plots and batch_i < 3:
-                    self.plot_val_samples(batch, batch_i)
-                    self.plot_predictions(batch, preds, batch_i)
+        for batch_i, batch in enumerate(bar):
+            self.batch_i = batch_i
+            # pre-process
+            with dt[0]:
+                batch = self.preprocess(batch)
+
+            # inference
+            with dt[1]:
+                preds = model(batch["img"])
+
+            # loss
+            with dt[2]:
+                if self.training:
+                    self.loss += trainer.criterion(preds, batch)[1]
+
+            # pre-process predictions
+            with dt[3]:
+                preds = self.postprocess(preds)
+
+            self.update_metrics(preds, batch)
+            if self.args.plots and batch_i < 3:
+                self.plot_val_samples(batch, batch_i)
+                self.plot_predictions(batch, preds, batch_i)
 
         stats = self.get_stats()
         self.check_stats(stats)
@@ -113,22 +111,21 @@ class BaseValidator:
 
         # calculate speed only once when training
         if not self.training or trainer.epoch == 0:
-            t = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt)  # speeds per image
-            self.speed = t
+            self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt)  # speeds per image
 
-            if not self.training:  # print only at inference
-                self.logger.info(
-                    'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' % t)
+        if not self.training:  # print only at inference
+            self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
+                             self.speed)
 
         if self.training:
             model.float()
 
         # TODO: implement save json
-        return stats | trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val") \
-            if self.training else stats
+        return {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")} \
+            if self.training else stats
 
     def get_dataloader(self, dataset_path, batch_size):
-        raise Exception("get_dataloder function not implemented for this validator")
+        raise NotImplementedError("get_dataloader function not implemented for this validator")
 
     def preprocess(self, batch):
         return batch
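`__call__` now relies on the `@smart_inference_mode()` decorator instead of an explicit `with torch.no_grad():` block, removing one indentation level across the whole loop. A rough sketch of what such a decorator typically does — an assumed implementation; the repository's actual helper lives in torch_utils and may differ:

```python
import functools

import torch


def smart_inference_mode():
    # assumption: prefer torch.inference_mode() (PyTorch >= 1.9), else no_grad()
    def decorate(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            ctx = torch.inference_mode if hasattr(torch, "inference_mode") else torch.no_grad
            with ctx():
                return fn(*args, **kwargs)
        return wrapper
    return decorate


@smart_inference_mode()
def validate(x):
    y = x * 2
    assert not y.requires_grad  # no autograd graph recorded
    return y


print(validate(torch.ones(2, requires_grad=True)))
```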
diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py
index f8f21f4..7d234b4 100644
--- a/ultralytics/yolo/utils/__init__.py
+++ b/ultralytics/yolo/utils/__init__.py
@@ -17,7 +17,7 @@ NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiproces
 AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
 FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf
 VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global verbose mode
-TQDM_BAR_FORMAT = '{l_bar}{bar:10}| {n_fmt}/{total_fmt} {elapsed}'  # tqdm bar format
+TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'  # tqdm bar format
 LOGGING_NAME = 'yolov5'
 
diff --git a/ultralytics/yolo/utils/callbacks/clearml.py b/ultralytics/yolo/utils/callbacks/clearml.py
index 8d4bfe8..946ca0b 100644
--- a/ultralytics/yolo/utils/callbacks/clearml.py
+++ b/ultralytics/yolo/utils/callbacks/clearml.py
@@ -23,9 +23,9 @@ def on_train_start(trainer):
 def on_val_end(trainer):
     if trainer.epoch == 0:
         model_info = {
-            "Inference speed (ms/img)": round(trainer.validator.speed[1], 1),
+            "Parameters": get_num_params(trainer.model),
             "GFLOPs": round(get_flops(trainer.model), 1),
-            "Parameters": get_num_params(trainer.model)}
+            "Inference speed (ms/img)": round(trainer.validator.speed[1], 1)}
         Task.current_task().connect(model_info, name='Model')
 
diff --git a/ultralytics/yolo/utils/callbacks/tb.py b/ultralytics/yolo/utils/callbacks/tb.py
index 5fe4d28..b442424 100644
--- a/ultralytics/yolo/utils/callbacks/tb.py
+++ b/ultralytics/yolo/utils/callbacks/tb.py
@@ -11,8 +11,6 @@ def _log_scalars(scalars, step=0):
 def on_train_start(trainer):
     global writer
     writer = SummaryWriter(str(trainer.save_dir))
-    trainer.console.info(f"Logging results to {trainer.save_dir}\n"
-                         f"Starting training for {trainer.args.epochs} epochs...")
 
 
 def on_batch_end(trainer):
diff --git a/ultralytics/yolo/utils/configs/default.yaml b/ultralytics/yolo/utils/configs/default.yaml
index 5ef0de6..e40cb83 100644
--- a/ultralytics/yolo/utils/configs/default.yaml
+++ b/ultralytics/yolo/utils/configs/default.yaml
@@ -27,12 +27,13 @@ local_rank: -1
 single_cls: False  # train multi-class data as single-class
 image_weights: False  # use weighted image selection for training
 rect: False  # support rectangular training
-cos_lr: False  # Use cosine LR scheduler
+cos_lr: False  # use cosine LR scheduler
+close_mosaic: 10  # disable mosaic for final 10 epochs
 
 # Segmentation
 overlap_mask: True  # masks overlap
 mask_ratio: 4  # mask downsample ratio
 
 # Classification
-dropout: False  # use dropout
+dropout: False  # use dropout
 resume: False
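`close_mosaic` pairs with the build.py change earlier, which falls back to the plain `DataLoader` (whose dataset attributes can be updated between epochs) whenever `close_mosaic` is set. A runnable toy sketch of the mechanism — the flag and method names on the dataset stand-in are assumptions, since the epoch-loop wiring is not part of this diff:

```python
class DummyDataset:
    # stand-in for YOLODataset; attribute/method names are assumed
    def __init__(self):
        self.mosaic = True
        self.build_transforms()

    def build_transforms(self):
        # rebuild the pipeline, dropping Mosaic when disabled
        self.transforms = ["mosaic", "affine"] if self.mosaic else ["affine"]


dataset = DummyDataset()
epochs, close_mosaic = 20, 10

for epoch in range(epochs):
    if close_mosaic and epoch == epochs - close_mosaic:
        # Only a plain DataLoader re-reads dataset attributes between epochs;
        # InfiniteDataLoader workers would keep the old transform pipeline.
        dataset.mosaic = False
        dataset.build_transforms()
print(dataset.transforms)  # ['affine'] for the final close_mosaic epochs
```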
diff --git a/ultralytics/yolo/utils/modeling/__init__.py b/ultralytics/yolo/utils/modeling/__init__.py
index 48a5917..9cedfa2 100644
--- a/ultralytics/yolo/utils/modeling/__init__.py
+++ b/ultralytics/yolo/utils/modeling/__init__.py
@@ -45,7 +45,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
 
 def parse_model(d, ch):  # model_dict, input_channels(3)
     # Parse a YOLOv5 model.yaml dictionary
-    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
+    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<50}{'arguments':<30}")
     nc, gd, gw, act = d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
     if act:
         Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
@@ -87,7 +87,7 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
         t = str(m)[8:-2].replace('__main__.', '')  # module type
         m.np = sum(x.numel() for x in m_.parameters())  # number params
         m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
-        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{m.np:10.0f} {t:<40}{str(args):<30}')  # print
+        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{m.np:10.0f} {t:<50}{str(args):<30}')  # print
         save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
         layers.append(m_)
         if i == 0:
diff --git a/ultralytics/yolo/utils/ops.py b/ultralytics/yolo/utils/ops.py
index eb289f6..75f2975 100644
--- a/ultralytics/yolo/utils/ops.py
+++ b/ultralytics/yolo/utils/ops.py
@@ -63,9 +63,9 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
         gain = ratio_pad[0][0]
         pad = ratio_pad[1]
 
-    boxes[:, [0, 2]] -= pad[0]  # x padding
-    boxes[:, [1, 3]] -= pad[1]  # y padding
-    boxes[:, :4] /= gain
+    boxes[..., [0, 2]] -= pad[0]  # x padding
+    boxes[..., [1, 3]] -= pad[1]  # y padding
+    boxes[..., :4] /= gain
     clip_boxes(boxes, img0_shape)
     return boxes
 
@@ -73,13 +73,13 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
 def clip_boxes(boxes, shape):
     # Clip boxes (xyxy) to image shape (height, width)
     if isinstance(boxes, torch.Tensor):  # faster individually
-        boxes[:, 0].clamp_(0, shape[1])  # x1
-        boxes[:, 1].clamp_(0, shape[0])  # y1
-        boxes[:, 2].clamp_(0, shape[1])  # x2
-        boxes[:, 3].clamp_(0, shape[0])  # y2
+        boxes[..., 0].clamp_(0, shape[1])  # x1
+        boxes[..., 1].clamp_(0, shape[0])  # y1
+        boxes[..., 2].clamp_(0, shape[1])  # x2
+        boxes[..., 3].clamp_(0, shape[0])  # y2
     else:  # np.array (faster grouped)
-        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1])  # x1, x2
-        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0])  # y1, y2
+        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
+        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
 
 
 def make_divisible(x, divisor):
@@ -106,6 +106,9 @@ def non_max_suppression(
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
     """
 
+    # Checks
+    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
+    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
     if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation model, output = (inference_out, loss_out)
         prediction = prediction[0]  # select only inference output
 
@@ -118,10 +121,6 @@ def non_max_suppression(
     mi = 4 + nc  # mask start index
     xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates
 
-    # Checks
-    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
-    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
-
     # Settings
     # min_wh = 2  # (pixels) minimum box width and height
     max_wh = 7680  # (pixels) maximum box width and height
@@ -172,17 +171,13 @@ def non_max_suppression(
         n = x.shape[0]  # number of boxes
         if not n:  # no boxes
             continue
-        elif n > max_nms:  # excess boxes
-            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
-        else:
-            x = x[x[:, 4].argsort(descending=True)]  # sort by confidence
+        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
 
         # Batched NMS
         c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
         boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
         i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
-        if i.shape[0] > max_det:  # limit detections
-            i = i[:max_det]
+        i = i[:max_det]  # limit detections
         if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
             # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
             iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
@@ -244,20 +239,50 @@ def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
 def xyxy2xywh(x):
     # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
-    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
-    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
-    y[:, 2] = x[:, 2] - x[:, 0]  # width
-    y[:, 3] = x[:, 3] - x[:, 1]  # height
+    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
+    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
+    y[..., 2] = x[..., 2] - x[..., 0]  # width
+    y[..., 3] = x[..., 3] - x[..., 1]  # height
     return y
 
 
 def xywh2xyxy(x):
     # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
-    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
-    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
-    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
-    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
+    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
+    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
+    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
+    return y
+
+
+def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
+    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
+    y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
+    y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
+    y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom right y
+    return y
+
+
+def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
+    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
+    if clip:
+        clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
+    y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
+    y[..., 2] = (x[..., 2] - x[..., 0]) / w  # width
+    y[..., 3] = (x[..., 3] - x[..., 1]) / h  # height
+    return y
+
+
+def xyn2xy(x, w=640, h=640, padw=0, padh=0):
+    # Convert normalized segments into pixel segments, shape (n,2)
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[..., 0] = w * x[..., 0] + padw  # top left x
+    y[..., 1] = h * x[..., 1] + padh  # top left y
     return y
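The `[:, i]` → `[..., i]` switch throughout ops.py makes every converter accept batched box tensors, e.g. shape (B, N, 4), not just (N, 4), since `...` always indexes the last axis. A quick standalone check of the behavior, re-implementing `xywh2xyxy` locally rather than importing it:

```python
import numpy as np


def xywh2xyxy(x):
    # same arithmetic as ops.xywh2xyxy; `...` selects the last axis only
    y = np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2
    y[..., 1] = x[..., 1] - x[..., 3] / 2
    y[..., 2] = x[..., 0] + x[..., 2] / 2
    y[..., 3] = x[..., 1] + x[..., 3] / 2
    return y


boxes = np.array([[10.0, 10.0, 4.0, 4.0]])  # (1, 4): one image's boxes
batched = np.stack([boxes, boxes])          # (2, 1, 4): a batch of two
assert (xywh2xyxy(batched)[0] == xywh2xyxy(boxes)).all()
print(xywh2xyxy(boxes))  # [[ 8.  8. 12. 12.]]
```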
diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py
index 08f9031..a61c669 100644
--- a/ultralytics/yolo/utils/torch_utils.py
+++ b/ultralytics/yolo/utils/torch_utils.py
@@ -29,7 +29,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
     # Decorator to make all processes in distributed training wait for each local_master to do something
-    if local_rank not in [-1, 0]:
+    if local_rank not in {-1, 0}:
         dist.barrier(device_ids=[local_rank])
     yield
     if local_rank == 0:
diff --git a/ultralytics/yolo/v8/classify/predict.py b/ultralytics/yolo/v8/classify/predict.py
index a869a7a..9f07b5a 100644
--- a/ultralytics/yolo/v8/classify/predict.py
+++ b/ultralytics/yolo/v8/classify/predict.py
@@ -52,7 +52,7 @@ class ClassificationPredictor(BasePredictor):
         return log_string
 
 
-@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def predict(cfg):
     cfg.model = cfg.model or "squeezenet1_0"
     sz = cfg.imgsz
diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py
index 4f436f4..2b98de6 100644
--- a/ultralytics/yolo/v8/classify/train.py
+++ b/ultralytics/yolo/v8/classify/train.py
@@ -59,7 +59,7 @@ class ClassificationTrainer(BaseTrainer):
         pass
 
 
-@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def train(cfg):
     cfg.model = cfg.model or "resnet18"
     cfg.data = cfg.data or "imagenette160"  # or yolo.ClassificationDataset("mnist")
diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py
index 006459e..9620e7f 100644
--- a/ultralytics/yolo/v8/classify/val.py
+++ b/ultralytics/yolo/v8/classify/val.py
@@ -35,7 +35,7 @@ class ClassificationValidator(BaseValidator):
         return ["top1", "top5"]
 
 
-@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def val(cfg):
     cfg.data = cfg.data or "imagenette160"
     cfg.model = cfg.model or "resnet18"
diff --git a/ultralytics/yolo/v8/detect/predict.py b/ultralytics/yolo/v8/detect/predict.py
index 01c3df3..d537e64 100644
--- a/ultralytics/yolo/v8/detect/predict.py
+++ b/ultralytics/yolo/v8/detect/predict.py
@@ -81,7 +81,7 @@ class DetectionPredictor(BasePredictor):
         return log_string
 
 
-@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def predict(cfg):
     cfg.model = cfg.model or "n.pt"
     sz = cfg.imgsz
diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py
index b361eab..dc31332 100644
--- a/ultralytics/yolo/v8/detect/train.py
+++ b/ultralytics/yolo/v8/detect/train.py
@@ -53,7 +53,9 @@ class DetectionTrainer(BaseTrainer):
                                 args=self.args)
 
     def criterion(self, preds, batch):
-        return Loss(self.model)(preds, batch)
+        if not hasattr(self, 'compute_loss'):
+            self.compute_loss = Loss(de_parallel(self.model))
+        return self.compute_loss(preds, batch)
 
     def label_loss_items(self, loss_items=None, prefix="train"):
         # We should just use named tensors here in future
@@ -61,8 +63,8 @@ class DetectionTrainer(BaseTrainer):
         return dict(zip(keys, loss_items)) if loss_items is not None else keys
 
     def progress_string(self):
-        return ('\n' + '%11s' * 6) % \
-            ('Epoch', 'GPU_mem', *self.loss_names, 'Size')
+        return ('\n' + '%11s' * 7) % \
+            ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
 
     def plot_training_samples(self, batch, ni):
         images = batch["img"]
@@ -79,7 +81,7 @@ class DetectionTrainer(BaseTrainer):
 
 # Criterion class for computing training losses
 class Loss:
 
-    def __init__(self, model):
+    def __init__(self, model):  # model must be de-paralleled
         device = next(model.parameters()).device  # get model device
         h = model.args  # hyperparameters
@@ -90,7 +92,7 @@ class Loss:
         # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
         self.cp, self.cn = smooth_BCE(eps=h.get("label_smoothing", 0.0))  # positive, negative BCE targets
 
-        m = de_parallel(model).model[-1]  # Detect() module
+        m = model.model[-1]  # Detect() module
         self.BCEcls = BCEcls
         self.hyp = h
         self.stride = m.stride  # model strides
@@ -169,12 +171,12 @@ class Loss:
         return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
 
 
-@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def train(cfg):
     cfg.model = cfg.model or "models/yolov8n.yaml"
     cfg.data = cfg.data or "coco128.yaml"  # or yolo.ClassificationDataset("mnist")
-    cfg.imgsz = 160
-    cfg.epochs = 5
+    # cfg.imgsz = 160
+    # cfg.epochs = 5
     trainer = DetectionTrainer(cfg)
     trainer.train()
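`criterion()` now builds `Loss` once and caches it on the trainer instead of reconstructing it (and re-deriving strides, BCE targets, etc.) on every batch, and `de_parallel` is applied at construction so the "model must be de-paralleled" comment holds inside `Loss`. The caching idiom in isolation, with toy classes rather than the trainer:

```python
class Loss:
    def __init__(self, model):  # expensive setup, done once
        print("building Loss")
        self.model = model

    def __call__(self, preds, batch):
        return sum(preds)  # placeholder arithmetic


class Trainer:
    def criterion(self, preds, batch):
        if not hasattr(self, "compute_loss"):     # first call only
            self.compute_loss = Loss(model=None)  # cache on the instance
        return self.compute_loss(preds, batch)


t = Trainer()
t.criterion([1, 2], None)  # prints "building Loss"
t.criterion([3, 4], None)  # reuses the cached instance silently
```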
diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py
index 63bfe2f..d2a32b4 100644
--- a/ultralytics/yolo/v8/detect/val.py
+++ b/ultralytics/yolo/v8/detect/val.py
@@ -119,9 +119,9 @@ class DetectionValidator(BaseValidator):
         if len(stats) and stats[0].any():
             self.metrics.process(*stats)
         self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc)  # number of targets per class
-        metrics = {"fitness": self.metrics.fitness()}
-        metrics |= zip(self.metric_keys, self.metrics.mean_results())
-        return metrics
+        fitness = {"fitness": self.metrics.fitness()}
+        metrics = dict(zip(self.metric_keys, self.metrics.mean_results()))
+        return {**metrics, **fitness}
 
     def print_results(self):
         pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metric_keys)  # print format
@@ -198,7 +198,7 @@ class DetectionValidator(BaseValidator):
                         names=self.names)  # pred
 
 
-@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def val(cfg):
     cfg.data = cfg.data or "coco128.yaml"
     validator = DetectionValidator(args=cfg)
diff --git a/ultralytics/yolo/v8/segment/predict.py b/ultralytics/yolo/v8/segment/predict.py
index aaa4437..4403635 100644
--- a/ultralytics/yolo/v8/segment/predict.py
+++ b/ultralytics/yolo/v8/segment/predict.py
@@ -99,7 +99,7 @@ class SegmentationPredictor(DetectionPredictor):
         return log_string
 
 
-@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def predict(cfg):
     cfg.model = cfg.model or "n.pt"
     sz = cfg.imgsz
diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py
index 5dd0f59..5e7f00b 100644
--- a/ultralytics/yolo/v8/segment/train.py
+++ b/ultralytics/yolo/v8/segment/train.py
@@ -214,8 +214,8 @@ class SegmentationTrainer(DetectionTrainer):
         return dict(zip(keys, loss_items)) if loss_items is not None else keys
 
     def progress_string(self):
-        return ('\n' + '%11s' * 7) % \
-            ('Epoch', 'GPU_mem', *self.loss_names, 'Size')
+        return ('\n' + '%11s' * 8) % \
+            ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
 
     def plot_training_samples(self, batch, ni):
         images = batch["img"]
@@ -230,7 +230,7 @@ class SegmentationTrainer(DetectionTrainer):
         plot_results(file=self.csv, segment=True)  # save results.png
 
 
-@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def train(cfg):
     cfg.model = cfg.model or "models/yolov8n-seg.yaml"
     cfg.data = cfg.data or "coco128-seg.yaml"  # or yolo.ClassificationDataset("mnist")
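Several spots above (trainer.py, validator.py, detect/val.py) replace the dict union operator `|` with `{**a, **b}` unpacking: `|` on dicts requires Python 3.9+, while unpacking works on 3.5+. A quick demonstration:

```python
stats = {"mAP50": 0.5}
loss_items = {"val/box_loss": 0.03}

merged = {**stats, **loss_items}  # works on Python 3.5+
print(merged)  # {'mAP50': 0.5, 'val/box_loss': 0.03}

# The old form raises TypeError before Python 3.9:
# merged = stats | loss_items
```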
diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py
index a8537f8..df36a34 100644
--- a/ultralytics/yolo/v8/segment/val.py
+++ b/ultralytics/yolo/v8/segment/val.py
@@ -211,7 +211,7 @@ class SegmentationValidator(DetectionValidator):
         self.plot_masks.clear()
 
 
-@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def val(cfg):
     cfg.data = cfg.data or "coco128-seg.yaml"
     validator = SegmentationValidator(args=cfg)
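Every `@hydra.main` decorator in this PR wraps `config_path` in `str()`: `DEFAULT_CONFIG` is a `pathlib.Path`, and Hydra expects `config_path` as a string, so passing a `Path` object can fail in some Hydra versions. A hedged sketch of the fix in isolation — assumes hydra-core is installed and a config directory exists at the illustrative path below:

```python
from pathlib import Path

import hydra

DEFAULT_CONFIG = Path(__file__).parent / "configs" / "default.yaml"  # illustrative


@hydra.main(version_base=None,
            config_path=str(DEFAULT_CONFIG.parent),  # str, not Path
            config_name=DEFAULT_CONFIG.name)
def main(cfg):
    print(cfg)


if __name__ == "__main__":
    main()
```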