diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml index 59a607a..89183e8 100644 --- a/.github/workflows/cla.yml +++ b/.github/workflows/cla.yml @@ -13,6 +13,7 @@ on: jobs: CLA: + if: github.repository == 'ultralytics/ultralytics' runs-on: ubuntu-latest steps: - name: "CLA Assistant" diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index 817530c..b441beb 100644 --- a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -7,7 +7,7 @@ import requests from ultralytics import __version__ from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request -from ultralytics.yolo.utils import LOGGER, is_colab, threaded +from ultralytics.yolo.utils import is_colab, threaded AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local' diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index 661a6fc..6f68da2 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -32,21 +32,21 @@ class AutoBackend(nn.Module): fp16 (bool): If True, use half precision. Default: False fuse (bool): Whether to fuse the model or not. Default: True - Supported formats and their usage: - Platform | Weights Format - -----------------------|------------------ - PyTorch | *.pt - TorchScript | *.torchscript - ONNX Runtime | *.onnx - ONNX OpenCV DNN | *.onnx --dnn - OpenVINO | *.xml - CoreML | *.mlmodel - TensorRT | *.engine - TensorFlow SavedModel | *_saved_model - TensorFlow GraphDef | *.pb - TensorFlow Lite | *.tflite - TensorFlow Edge TPU | *_edgetpu.tflite - PaddlePaddle | *_paddle_model + Supported formats and their naming conventions: + | Format | Suffix | + |-----------------------|------------------| + | PyTorch | *.pt | + | TorchScript | *.torchscript | + | ONNX Runtime | *.onnx | + | ONNX OpenCV DNN | *.onnx --dnn | + | OpenVINO | *.xml | + | CoreML | *.mlmodel | + | TensorRT | *.engine | + | TensorFlow SavedModel | *_saved_model | + | TensorFlow GraphDef | *.pb | + | TensorFlow Lite | *.tflite | + | TensorFlow Edge TPU | *_edgetpu.tflite | + | PaddlePaddle | *_paddle_model | """ super().__init__() w = str(weights[0] if isinstance(weights, list) else weights) @@ -357,7 +357,7 @@ class AutoBackend(nn.Module): This function takes a path to a model file and returns the model type Args: - p: path to the model file. Defaults to path/to/model.pt + p: path to the model file. Defaults to path/to/model.pt """ # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle] @@ -374,12 +374,11 @@ class AutoBackend(nn.Module): @staticmethod def _load_metadata(f=Path('path/to/meta.yaml')): """ - > Loads the metadata from a yaml file + Loads the metadata from a yaml file Args: - f: The path to the metadata file. + f: The path to the metadata file. 
""" - from ultralytics.yolo.utils.files import yaml_load # Load metadata from meta.yaml if it exists if f.exists(): diff --git a/ultralytics/nn/modules.py b/ultralytics/nn/modules.py index 9cba6ff..7847593 100644 --- a/ultralytics/nn/modules.py +++ b/ultralytics/nn/modules.py @@ -5,28 +5,11 @@ Common modules import math import warnings -from copy import copy -from pathlib import Path -import cv2 -import numpy as np -import pandas as pd -import requests import torch import torch.nn as nn -from PIL import Image, ImageOps -from torch.cuda import amp - -from ultralytics.nn.autobackend import AutoBackend -from ultralytics.yolo.data.augment import LetterBox -from ultralytics.yolo.utils import LOGGER, colorstr -from ultralytics.yolo.utils.files import increment_path -from ultralytics.yolo.utils.ops import Profile, make_divisible, non_max_suppression, scale_boxes, xyxy2xywh -from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box -from ultralytics.yolo.utils.tal import dist2bbox, make_anchors -from ultralytics.yolo.utils.torch_utils import copy_attr, smart_inference_mode -# from utils.plots import feature_visualization TODO +from ultralytics.yolo.utils.tal import dist2bbox, make_anchors def autopad(k, p=None, d=1): # kernel, padding, dilation @@ -365,216 +348,6 @@ class Concat(nn.Module): return torch.cat(x, self.d) -class AutoShape(nn.Module): - # YOLOv8 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS - conf = 0.25 # NMS confidence threshold - iou = 0.45 # NMS IoU threshold - agnostic = False # NMS class-agnostic - multi_label = False # NMS multiple labels per box - classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs - max_det = 1000 # maximum number of detections per image - amp = False # Automatic Mixed Precision (AMP) inference - - def __init__(self, model, verbose=True): - super().__init__() - if verbose: - LOGGER.info('Adding AutoShape... ') - copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes - self.dmb = isinstance(model, AutoBackend) # DetectMultiBackend() instance - self.pt = not self.dmb or model.pt # PyTorch model - self.model = model.eval() - if self.pt: - m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect() - m.inplace = False # Detect.inplace=False for safe multithread inference - m.export = True # do not output loss values - - def _apply(self, fn): - # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers - self = super()._apply(fn) - if self.pt: - m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect() - m.stride = fn(m.stride) - m.grid = list(map(fn, m.grid)) - if isinstance(m.anchor_grid, list): - m.anchor_grid = list(map(fn, m.anchor_grid)) - return self - - @smart_inference_mode() - def forward(self, ims, size=640, augment=False, profile=False): - # Inference from various sources. For size(height=640, width=1280), RGB images example inputs are: - # file: ims = 'data/images/zidane.jpg' # str or PosixPath - # URI: = 'https://ultralytics.com/images/zidane.jpg' - # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3) - # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3) - # numpy: = np.zeros((640,1280,3)) # HWC - # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values) - # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] 
# list of images - - dt = (Profile(), Profile(), Profile()) - with dt[0]: - if isinstance(size, int): # expand - size = (size, size) - p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param - autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference - if isinstance(ims, torch.Tensor): # torch - with amp.autocast(autocast): - return self.model(ims.to(p.device).type_as(p), augment=augment) # inference - - # Pre-process - n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images - shape0, shape1, files = [], [], [] # image and inference shapes, filenames - for i, im in enumerate(ims): - f = f'image{i}' # filename - if isinstance(im, (str, Path)): # filename or uri - im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im - im = np.asarray(ImageOps.exif_transpose(im)) - elif isinstance(im, Image.Image): # PIL Image - im, f = np.asarray(ImageOps.exif_transpose(im)), getattr(im, 'filename', f) or f - files.append(Path(f).with_suffix('.jpg').name) - if im.shape[0] < 5: # image in CHW - im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) - im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input - s = im.shape[:2] # HWC - shape0.append(s) # image shape - g = max(size) / max(s) # gain - shape1.append([y * g for y in s]) - ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update - shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size # inf shape - x = [LetterBox(shape1, auto=False)(image=im)["img"] for im in ims] # pad - x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW - x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32 - - with amp.autocast(autocast): - # Inference - with dt[1]: - y = self.model(x, augment=augment) # forward - - # Post-process - with dt[2]: - y = non_max_suppression(y if self.dmb else y[0], - self.conf, - self.iou, - self.classes, - self.agnostic, - self.multi_label, - max_det=self.max_det) # NMS - for i in range(n): - scale_boxes(shape1, y[i][:, :4], shape0[i]) - - return Detections(ims, y, files, dt, self.names, x.shape) - - -class Detections: - # YOLOv8 detections class for inference results - def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None): - super().__init__() - d = pred[0].device # device - gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations - self.ims = ims # list of images as numpy arrays - self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls) - self.names = names # class names - self.files = files # image filenames - self.times = times # profiling times - self.xyxy = pred # xyxy pixels - self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels - self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized - self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized - self.n = len(self.pred) # number of images (batch size) - self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms) - self.s = tuple(shape) # inference BCHW shape - - def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')): - s, crops = '', [] - for i, (im, pred) in enumerate(zip(self.ims, self.pred)): - s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string - if 
pred.shape[0]: - for c in pred[:, -1].unique(): - n = (pred[:, -1] == c).sum() # detections per class - s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string - s = s.rstrip(', ') - if show or save or render or crop: - annotator = Annotator(im, example=str(self.names)) - for *box, conf, cls in reversed(pred): # xyxy, confidence, class - label = f'{self.names[int(cls)]} {conf:.2f}' - if crop: - file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None - crops.append({ - 'box': box, - 'conf': conf, - 'cls': cls, - 'label': label, - 'im': save_one_box(box, im, file=file, save=save)}) - else: # all others - annotator.box_label(box, label if labels else '', color=colors(cls)) - im = annotator.im - else: - s += '(no detections)' - - im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np - if show: - im.show(self.files[i]) # show - if save: - f = self.files[i] - im.save(save_dir / f) # save - if i == self.n - 1: - LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}") - if render: - self.ims[i] = np.asarray(im) - if pprint: - s = s.lstrip('\n') - return f'{s}\nSpeed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t - if crop: - if save: - LOGGER.info(f'Saved results to {save_dir}\n') - return crops - - def show(self, labels=True): - self._run(show=True, labels=labels) # show results - - def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False): - save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir - self._run(save=True, labels=labels, save_dir=save_dir) # save results - - def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False): - save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None - return self._run(crop=True, save=save, save_dir=save_dir) # crop results - - def render(self, labels=True): - self._run(render=True, labels=labels) # render results - return self.ims - - def pandas(self): - # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0]) - new = copy(self) # return copy - ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns - cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns - for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]): - a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update - setattr(new, k, [pd.DataFrame(x, columns=c) for x in a]) - return new - - def tolist(self): - # return a list of Detections objects, i.e. 
'for result in results.tolist():' - r = range(self.n) # iterable - x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r] - # for d in x: - # for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']: - # setattr(d, k, getattr(d, k)[0]) # pop out of list - return x - - def print(self): - LOGGER.info(self.__str__()) - - def __len__(self): # override len(results) - return self.n - - def __str__(self): # override print(results) - return self._run(pprint=True) # print results - - def __repr__(self): - return f'YOLOv8 {self.__class__} instance\n' + self.__str__() - - class Proto(nn.Module): # YOLOv8 mask Proto module for segmentation models def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks diff --git a/ultralytics/nn/results.py b/ultralytics/nn/results.py new file mode 100644 index 0000000..30c6110 --- /dev/null +++ b/ultralytics/nn/results.py @@ -0,0 +1,234 @@ +# Ultralytics YOLO 🚀, GPL-3.0 license +""" +Inference results modules +""" + +from copy import copy +from pathlib import Path + +import cv2 +import numpy as np +import pandas as pd +import requests +import torch +import torch.nn as nn +from PIL import Image, ImageOps +from torch.cuda import amp + +from ultralytics.nn.autobackend import AutoBackend +from ultralytics.yolo.data.augment import LetterBox +from ultralytics.yolo.utils import LOGGER, colorstr +from ultralytics.yolo.utils.files import increment_path +from ultralytics.yolo.utils.ops import Profile, make_divisible, non_max_suppression, scale_boxes, xyxy2xywh +from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box +from ultralytics.yolo.utils.torch_utils import copy_attr, smart_inference_mode + + +class AutoShape(nn.Module): + # YOLOv8 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS + conf = 0.25 # NMS confidence threshold + iou = 0.45 # NMS IoU threshold + agnostic = False # NMS class-agnostic + multi_label = False # NMS multiple labels per box + classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs + max_det = 1000 # maximum number of detections per image + amp = False # Automatic Mixed Precision (AMP) inference + + def __init__(self, model, verbose=True): + super().__init__() + if verbose: + LOGGER.info('Adding AutoShape... ') + copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes + self.dmb = isinstance(model, AutoBackend) # DetectMultiBackend() instance + self.pt = not self.dmb or model.pt # PyTorch model + self.model = model.eval() + if self.pt: + m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect() + m.inplace = False # Detect.inplace=False for safe multithread inference + m.export = True # do not output loss values + + def _apply(self, fn): + # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers + self = super()._apply(fn) + if self.pt: + m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect() + m.stride = fn(m.stride) + m.grid = list(map(fn, m.grid)) + if isinstance(m.anchor_grid, list): + m.anchor_grid = list(map(fn, m.anchor_grid)) + return self + + @smart_inference_mode() + def forward(self, ims, size=640, augment=False, profile=False): + # Inference from various sources. 
For size(height=640, width=1280), RGB images example inputs are: + # file: ims = 'data/images/zidane.jpg' # str or PosixPath + # URI: = 'https://ultralytics.com/images/zidane.jpg' + # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3) + # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3) + # numpy: = np.zeros((640,1280,3)) # HWC + # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values) + # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images + + dt = (Profile(), Profile(), Profile()) + with dt[0]: + if isinstance(size, int): # expand + size = (size, size) + p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param + autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference + if isinstance(ims, torch.Tensor): # torch + with amp.autocast(autocast): + return self.model(ims.to(p.device).type_as(p), augment=augment) # inference + + # Pre-process + n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images + shape0, shape1, files = [], [], [] # image and inference shapes, filenames + for i, im in enumerate(ims): + f = f'image{i}' # filename + if isinstance(im, (str, Path)): # filename or uri + im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im + im = np.asarray(ImageOps.exif_transpose(im)) + elif isinstance(im, Image.Image): # PIL Image + im, f = np.asarray(ImageOps.exif_transpose(im)), getattr(im, 'filename', f) or f + files.append(Path(f).with_suffix('.jpg').name) + if im.shape[0] < 5: # image in CHW + im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) + im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input + s = im.shape[:2] # HWC + shape0.append(s) # image shape + g = max(size) / max(s) # gain + shape1.append([y * g for y in s]) + ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update + shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size # inf shape + x = [LetterBox(shape1, auto=False)(image=im)["img"] for im in ims] # pad + x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW + x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32 + + with amp.autocast(autocast): + # Inference + with dt[1]: + y = self.model(x, augment=augment) # forward + + # Post-process + with dt[2]: + y = non_max_suppression(y if self.dmb else y[0], + self.conf, + self.iou, + self.classes, + self.agnostic, + self.multi_label, + max_det=self.max_det) # NMS + for i in range(n): + scale_boxes(shape1, y[i][:, :4], shape0[i]) + + return Detections(ims, y, files, dt, self.names, x.shape) + + +class Detections: + # YOLOv8 detections class for inference results + def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None): + super().__init__() + d = pred[0].device # device + gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations + self.ims = ims # list of images as numpy arrays + self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls) + self.names = names # class names + self.files = files # image filenames + self.times = times # profiling times + self.xyxy = pred # xyxy pixels + self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels + self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized + self.xywhn = 
[x / g for x, g in zip(self.xywh, gn)] # xywh normalized + self.n = len(self.pred) # number of images (batch size) + self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms) + self.s = tuple(shape) # inference BCHW shape + + def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')): + s, crops = '', [] + for i, (im, pred) in enumerate(zip(self.ims, self.pred)): + s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string + if pred.shape[0]: + for c in pred[:, -1].unique(): + n = (pred[:, -1] == c).sum() # detections per class + s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string + s = s.rstrip(', ') + if show or save or render or crop: + annotator = Annotator(im, example=str(self.names)) + for *box, conf, cls in reversed(pred): # xyxy, confidence, class + label = f'{self.names[int(cls)]} {conf:.2f}' + if crop: + file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None + crops.append({ + 'box': box, + 'conf': conf, + 'cls': cls, + 'label': label, + 'im': save_one_box(box, im, file=file, save=save)}) + else: # all others + annotator.box_label(box, label if labels else '', color=colors(cls)) + im = annotator.im + else: + s += '(no detections)' + + im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np + if show: + im.show(self.files[i]) # show + if save: + f = self.files[i] + im.save(save_dir / f) # save + if i == self.n - 1: + LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}") + if render: + self.ims[i] = np.asarray(im) + if pprint: + s = s.lstrip('\n') + return f'{s}\nSpeed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t + if crop: + if save: + LOGGER.info(f'Saved results to {save_dir}\n') + return crops + + def show(self, labels=True): + self._run(show=True, labels=labels) # show results + + def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False): + save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir + self._run(save=True, labels=labels, save_dir=save_dir) # save results + + def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False): + save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None + return self._run(crop=True, save=save, save_dir=save_dir) # crop results + + def render(self, labels=True): + self._run(render=True, labels=labels) # render results + return self.ims + + def pandas(self): + # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0]) + new = copy(self) # return copy + ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns + cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns + for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]): + a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update + setattr(new, k, [pd.DataFrame(x, columns=c) for x in a]) + return new + + def tolist(self): + # return a list of Detections objects, i.e. 
'for result in results.tolist():' + r = range(self.n) # iterable + x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r] + # for d in x: + # for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']: + # setattr(d, k, getattr(d, k)[0]) # pop out of list + return x + + def print(self): + LOGGER.info(self.__str__()) + + def __len__(self): # override len(results) + return self.n + + def __str__(self): # override print(results) + return self._run(pprint=True) # print results + + def __repr__(self): + return f'YOLOv8 {self.__class__} instance\n' + self.__str__() diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index b676294..c3ec574 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -57,7 +57,7 @@ class BaseModel(nn.Module): x = m(x) # run y.append(x if m.i in self.save else None) # save output if visualize: - pass + LOGGER.info('visualize feature not yet supported') # TODO: feature_visualization(x, m.type, m.i, save_dir=visualize) return x @@ -106,8 +106,8 @@ class BaseModel(nn.Module): Prints model information Args: - verbose (bool): if True, prints out the model information. Defaults to False - imgsz (int): the size of the image that the model will be trained on. Defaults to 640 + verbose (bool): if True, prints out the model information. Defaults to False + imgsz (int): the size of the image that the model will be trained on. Defaults to 640 """ model_info(self, verbose, imgsz) @@ -117,10 +117,10 @@ class BaseModel(nn.Module): parameters or registered buffers Args: - fn: the function to apply to the model + fn: the function to apply to the model Returns: - A model that is a Detect() object. + A model that is a Detect() object. """ self = super()._apply(fn) m = self.model[-1] # Detect() @@ -135,7 +135,7 @@ class BaseModel(nn.Module): This function loads the weights of the model from a file Args: - weights (str): The weights to load into the model. + weights (str): The weights to load into the model. """ # Force all tasks to implement this function raise NotImplementedError("This function needs to be implemented by derived classes!") diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index f55394a..18a5313 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -32,7 +32,7 @@ class YOLO: def __init__(self, model='yolov8n.yaml', type="v8") -> None: """ - > Initializes the YOLO object. + Initializes the YOLO object. Args: model (str, Path): model to load or create @@ -59,7 +59,7 @@ class YOLO: def _new(self, cfg: str, verbose=True): """ - > Initializes a new model and infers the task type from the model definitions. + Initializes a new model and infers the task type from the model definitions. Args: cfg (str): model configuration file @@ -75,7 +75,7 @@ class YOLO: def _load(self, weights: str): """ - > Initializes a new model and infers the task type from the model head. + Initializes a new model and infers the task type from the model head. Args: weights (str): model checkpoint to be loaded @@ -90,7 +90,7 @@ class YOLO: def reset(self): """ - > Resets the model modules. + Resets the model modules. """ for m in self.model.modules(): if hasattr(m, 'reset_parameters'): @@ -100,7 +100,7 @@ class YOLO: def info(self, verbose=False): """ - > Logs model info. + Logs model info. Args: verbose (bool): Controls verbosity. 
@@ -133,7 +133,7 @@ class YOLO: @smart_inference_mode() def val(self, data=None, **kwargs): """ - > Validate a model on a given dataset . + Validate a model on a given dataset. Args: data (str): The dataset to validate on. Accepts all formats accepted by yolo @@ -152,7 +152,7 @@ class YOLO: @smart_inference_mode() def export(self, **kwargs): """ - > Export model. + Export model. Args: **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs @@ -168,7 +168,7 @@ class YOLO: def train(self, **kwargs): """ - > Trains the model on a given dataset. + Trains the model on a given dataset. Args: **kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section. @@ -197,7 +197,7 @@ class YOLO: def to(self, device): """ - > Sends the model to the given device. + Sends the model to the given device. Args: device (str): device diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 58495c1..1ad0021 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -89,7 +89,7 @@ class BasePredictor: self.vid_path, self.vid_writer = None, None self.annotator = None self.data_path = None - self.output = dict() + self.output = {} self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks callbacks.add_integration_callbacks(self) @@ -216,7 +216,7 @@ class BasePredictor: self.run_callbacks("on_predict_end") def predict_cli(self, source=None, model=None, return_outputs=False): - # as __call__ is a genertor now so have to treat it like a genertor + # as __call__ is a generator now so have to treat it like a generator for _ in (self.__call__(source, model, return_outputs)): pass diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index d107291..5c6f262 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -40,7 +40,7 @@ class BaseTrainer: """ BaseTrainer - > A base class for creating trainers. + A base class for creating trainers. Attributes: args (OmegaConf): Configuration for the trainer. @@ -75,7 +75,7 @@ class BaseTrainer: def __init__(self, config=DEFAULT_CONFIG, overrides=None): """ - > Initializes the BaseTrainer class. + Initializes the BaseTrainer class. Args: config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. @@ -149,13 +149,13 @@ class BaseTrainer: def add_callback(self, event: str, callback): """ - > Appends the given callback. + Appends the given callback. """ self.callbacks[event].append(callback) def set_callback(self, event: str, callback): """ - > Overrides the existing callbacks with the given callback. + Overrides the existing callbacks with the given callback. """ self.callbacks[event] = [callback] @@ -194,7 +194,7 @@ class BaseTrainer: def _setup_train(self, rank, world_size): """ - > Builds dataloaders and optimizer on correct rank process. + Builds dataloaders and optimizer on correct rank process. """ # model self.run_callbacks("on_pretrain_routine_start") @@ -383,13 +383,13 @@ class BaseTrainer: def get_dataset(self, data): """ - > Get train, val path from data dict if it exists. Returns None if data format is not recognized. + Get train, val path from data dict if it exists. Returns None if data format is not recognized. """ return data["train"], data.get("val") or data.get("test") def setup_model(self): """ - > load/create/download model for any task. 
+ Load/create/download model for any task. """ if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed return @@ -415,13 +415,13 @@ class BaseTrainer: def preprocess_batch(self, batch): """ - > Allows custom preprocessing model inputs and ground truths depending on task type. + Allows custom preprocessing of model inputs and ground truths depending on task type. """ return batch def validate(self): """ - > Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key. + Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key. """ metrics = self.validator(self) fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found @@ -431,7 +431,7 @@ class BaseTrainer: def log(self, text, rank=-1): """ - > Logs the given text to given ranks process if provided, otherwise logs to all ranks. + Logs the given text to given ranks process if provided, otherwise logs to all ranks. Args" text (str): text to log @@ -449,13 +449,13 @@ class BaseTrainer: def get_dataloader(self, dataset_path, batch_size=16, rank=0): """ - > Returns dataloader derived from torch.data.Dataloader. + Returns dataloader derived from torch.utils.data.DataLoader. """ raise NotImplementedError("get_dataloader function not implemented in trainer") def criterion(self, preds, batch): """ - > Returns loss and individual loss items as Tensor. + Returns loss and individual loss items as Tensor. """ raise NotImplementedError("criterion function not implemented in trainer") @@ -543,7 +543,7 @@ class BaseTrainer: @staticmethod def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5): """ - > Builds an optimizer with the specified parameters and parameter groups. + Builds an optimizer with the specified parameters and parameter groups. Args: model (nn.Module): model to optimize diff --git a/ultralytics/yolo/utils/callbacks/comet.py b/ultralytics/yolo/utils/callbacks/comet.py index 7133cbb..0f6d4f2 100644 --- a/ultralytics/yolo/utils/callbacks/comet.py +++ b/ultralytics/yolo/utils/callbacks/comet.py @@ -10,7 +10,7 @@ except (ModuleNotFoundError, ImportError): def on_pretrain_routine_start(trainer): - experiment = comet_ml.Experiment(project_name=trainer.args.project or "YOLOv8",) + experiment = comet_ml.Experiment(project_name=trainer.args.project or "YOLOv8") experiment.log_parameters(dict(trainer.args)) diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py index f2bfc53..263845c 100644 --- a/ultralytics/yolo/utils/downloads.py +++ b/ultralytics/yolo/utils/downloads.py @@ -12,7 +12,7 @@ from zipfile import ZipFile import requests import torch -from ultralytics.yolo.utils import LOGGER +from ultralytics.yolo.utils import LOGGER, SETTINGS def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''): @@ -59,7 +59,11 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'): return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets file = Path(str(file).strip().replace("'", '')) - if not file.exists(): + if file.exists(): + return str(file) + elif (SETTINGS['weights_dir'] / file).exists(): + return str(SETTINGS['weights_dir'] / file) + else: # URL specified name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc. 
if str(file).startswith(('http:/', 'https:/')): # download @@ -94,7 +98,7 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'): min_bytes=1E5, error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}') - return str(file) + return str(file) def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1, retry=3): diff --git a/ultralytics/yolo/v8/classify/predict.py b/ultralytics/yolo/v8/classify/predict.py index e5a88f4..f4f8cfe 100644 --- a/ultralytics/yolo/v8/classify/predict.py +++ b/ultralytics/yolo/v8/classify/predict.py @@ -58,10 +58,9 @@ class ClassificationPredictor(BasePredictor): @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def predict(cfg): - cfg.model = cfg.model or "squeezenet1_0" + cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18" cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" - predictor = ClassificationPredictor(cfg) predictor.predict_cli() diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py index 237ae8e..aca9703 100644 --- a/ultralytics/yolo/v8/classify/train.py +++ b/ultralytics/yolo/v8/classify/train.py @@ -136,7 +136,7 @@ class ClassificationTrainer(BaseTrainer): @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def train(cfg): - cfg.model = cfg.model or "yolov8n-cls.yaml" # or "resnet18" + cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18" cfg.data = cfg.data or "mnist160" # or yolo.ClassificationDataset("mnist") cfg.lr0 = 0.1 cfg.weight_decay = 5e-5 @@ -151,10 +151,4 @@ def train(cfg): if __name__ == "__main__": - """ - yolo task=classify mode=train model=yolov8n-cls.pt data=mnist160 epochs=10 imgsz=32 - yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160 imgsz=32 - yolo task=classify mode=predict model=runs/classify/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg - yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript - """ train() diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py index a9895d4..f33ecdf 100644 --- a/ultralytics/yolo/v8/classify/val.py +++ b/ultralytics/yolo/v8/classify/val.py @@ -48,8 +48,8 @@ class ClassificationValidator(BaseValidator): @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def val(cfg): + cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18" cfg.data = cfg.data or "imagenette160" - cfg.model = cfg.model or "resnet18" validator = ClassificationValidator(args=cfg) validator(model=cfg.model) diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py index 817bca9..ea817a9 100644 --- a/ultralytics/yolo/v8/detect/train.py +++ b/ultralytics/yolo/v8/detect/train.py @@ -197,7 +197,7 @@ class Loss: @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def train(cfg): - cfg.model = cfg.model or "yolov8n.yaml" + cfg.model = cfg.model or "yolov8n.pt" cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist") cfg.device = cfg.device if cfg.device is not None else '' # trainer = DetectionTrainer(cfg) @@ -208,11 +208,4 @@ def train(cfg): if __name__ == "__main__": - """ - CLI usage: - python ultralytics/yolo/v8/detect/train.py model=yolov8n.yaml 
data=coco128 epochs=100 imgsz=640 - - TODO: - yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100 - """ train() diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py index 7eac253..6ec867d 100644 --- a/ultralytics/yolo/v8/detect/val.py +++ b/ultralytics/yolo/v8/detect/val.py @@ -234,6 +234,7 @@ class DetectionValidator(BaseValidator): @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def val(cfg): + cfg.model = cfg.model or "yolov8n.pt" cfg.data = cfg.data or "coco128.yaml" validator = DetectionValidator(args=cfg) validator(model=cfg.model) diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index be41f16..845f3dc 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -143,7 +143,7 @@ class SegLoss(Loss): @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def train(cfg): - cfg.model = cfg.model or "yolov8n-seg.yaml" + cfg.model = cfg.model or "yolov8n-seg.pt" cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist") cfg.device = cfg.device if cfg.device is not None else '' # trainer = SegmentationTrainer(cfg) @@ -154,11 +154,4 @@ def train(cfg): if __name__ == "__main__": - """ - CLI usage: - python ultralytics/yolo/v8/segment/train.py model=yolov8n-seg.yaml data=coco128-segments epochs=100 imgsz=640 - - TODO: - Direct cli support, i.e, yolov8 classify_train args.epochs 10 - """ train() diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py index 32b8a9e..762f540 100644 --- a/ultralytics/yolo/v8/segment/val.py +++ b/ultralytics/yolo/v8/segment/val.py @@ -114,8 +114,9 @@ class SegmentationValidator(DetectionValidator): masks=True) if self.args.plots: self.confusion_matrix.process_batch(predn, labelsn) - self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, - 5], cls.squeeze(-1))) # conf, pcls, tcls + + # Append correct_masks, correct_bboxes, pconf, pcls, tcls + self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) if self.args.plots and self.batch_i < 3:
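Note on the downloads.py change above: attempt_download() now resolves weights in a fixed order — an existing local path first, then SETTINGS['weights_dir'], and only then the GitHub release-asset download. A minimal sketch of that lookup order under illustrative names (resolve_weights, the weights_dir default, and the download callback are stand-ins for this note, not library API):

from pathlib import Path

def resolve_weights(file, weights_dir=Path('weights'), download=None):
    # Mirrors the new attempt_download() lookup order:
    # 1. an existing local path wins,
    # 2. then a copy in the shared weights directory,
    # 3. otherwise defer to a download step (callback is a stand-in)
    file = Path(str(file).strip().replace("'", ''))
    if file.exists():
        return str(file)
    if (weights_dir / file).exists():
        return str(weights_dir / file)
    return download(file) if download is not None else str(file)

With a shared weights/yolov8n.pt in place, resolve_weights('yolov8n.pt') returns the shared copy instead of re-downloading, which is the case the new SETTINGS['weights_dir'] branch covers.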