ultralytics 8.0.14 Hydra removal fixes and cleanup (#542)

Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kamlesh Kumar <patelkamleshpatel364@gmail.com>
This commit is contained in:
Glenn Jocher
2023-01-21 21:22:40 +01:00
committed by GitHub
parent cc3be0e223
commit d9a0fba251
30 changed files with 339 additions and 301 deletions

View File

@ -28,7 +28,7 @@ CLI_HELP_MSG = \
Where TASK (optional) is one of [detect, segment, classify]
MODE (required) is one of [train, val, predict, export]
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
For a full list of available ARGS see https://docs.ultralytics.com/config.
For a full list of available ARGS see https://docs.ultralytics.com/cfg.
Train a detection model for 10 epochs with an initial learning_rate of 0.01
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
@ -48,7 +48,7 @@ CLI_HELP_MSG = \
yolo checks
yolo version
yolo settings
yolo copy-config
yolo copy-cfg
Docs: https://docs.ultralytics.com/cli
Community: https://community.ultralytics.com
@ -56,6 +56,15 @@ CLI_HELP_MSG = \
"""
class UltralyticsCFG(SimpleNamespace):
"""
UltralyticsCFG iterable SimpleNamespace class to allow SimpleNamespace to be used with dict() and in for loops
"""
def __iter__(self):
return iter(vars(self).items())
def cfg2dict(cfg):
"""
Convert a configuration object to a dictionary.
@ -75,30 +84,30 @@ def cfg2dict(cfg):
return cfg
def get_config(config: Union[str, Path, Dict, SimpleNamespace], overrides: Dict = None):
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace], overrides: Dict = None):
"""
Load and merge configuration data from a file or dictionary.
Args:
config (str) or (Path) or (Dict) or (SimpleNamespace): Configuration data.
cfg (str) or (Path) or (Dict) or (SimpleNamespace): Configuration data.
overrides (str) or (Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.
Returns:
(SimpleNamespace): Training arguments namespace.
"""
config = cfg2dict(config)
cfg = cfg2dict(cfg)
# Merge overrides
if overrides:
overrides = cfg2dict(overrides)
check_config_mismatch(config, overrides)
config = {**config, **overrides} # merge config and overrides dicts (prefer overrides)
check_cfg_mismatch(cfg, overrides)
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
# Return instance
return SimpleNamespace(**config)
return UltralyticsCFG(**cfg)
def check_config_mismatch(base: Dict, custom: Dict):
def check_cfg_mismatch(base: Dict, custom: Dict):
"""
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
@ -127,8 +136,8 @@ def entrypoint(debug=False):
- running special modes like 'checks'
- passing overrides to the package's configuration
It uses the package's default config and initializes it using the passed overrides.
Then it calls the CLI function with the composed config
It uses the package's default cfg and initializes it using the passed overrides.
Then it calls the CLI function with the composed cfg
"""
if debug:
args = ['train', 'predict', 'model=yolov8n.pt'] # for testing
@ -149,7 +158,7 @@ def entrypoint(debug=False):
'checks': checks.check_yolo,
'version': lambda: LOGGER.info(__version__),
'settings': print_settings,
'copy-config': copy_default_config}
'copy-cfg': copy_default_config}
overrides = {} # basic overrides, i.e. imgsz=320
defaults = yaml_load(DEFAULT_CFG_PATH)
@ -190,7 +199,7 @@ def entrypoint(debug=False):
f"https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/configs/default.yaml"
f"\n{CLI_HELP_MSG}")
cfg = get_config(defaults, overrides) # create CFG instance
cfg = get_cfg(defaults, overrides) # create CFG instance
# Mapping from task to module
module = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}.get(cfg.task)
@ -214,7 +223,7 @@ def copy_default_config():
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
shutil.copy2(DEFAULT_CFG_PATH, new_file)
LOGGER.info(f"{PREFIX}{DEFAULT_CFG_PATH} copied to {new_file}\n"
f"Usage for running YOLO with this new custom config:\nyolo cfg={new_file} args...")
f"Usage for running YOLO with this new custom cfg:\nyolo cfg={new_file} args...")
if __name__ == '__main__':

View File

@ -1,20 +1,20 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
# Default training settings and hyperparameters for medium-augmentation COCO training
task: "detect" # choices=['detect', 'segment', 'classify', 'init'] # init is a special case. Specify task to run.
mode: "train" # choices=['train', 'val', 'predict'] # mode to run task in.
task: "detect" # inference task, i.e. detect, segment, classify
mode: "train" # YOLO mode, i.e. train, val, predict, export
# Train settings -------------------------------------------------------------------------------------------------------
model: null # i.e. yolov8n.pt, yolov8n.yaml. Path to model file
data: null # i.e. coco128.yaml. Path to data file
model: null # path to model file, i.e. yolov8n.pt, yolov8n.yaml
data: null # path to data file, i.e. i.e. coco128.yaml
epochs: 100 # number of epochs to train for
patience: 50 # epochs to wait for no observable improvement for early stopping of training
batch: 16 # number of images per batch
imgsz: 640 # size of input images
save: True # save checkpoints
batch: 16 # number of images per batch (-1 for AutoBatch)
imgsz: 640 # size of input images as integer or w,h
save: True # save train checkpoints and predict results
cache: False # True/ram, disk or False. Use cache for data loading
device: null # cuda device, i.e. 0 or 0,1,2,3 or cpu. Device to run on
workers: 8 # number of worker threads for data loading
device: null # device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
workers: 8 # number of worker threads for data loading (per RANK if DDP)
project: null # project name
name: null # experiment name
exist_ok: False # whether to overwrite existing experiment
@ -30,10 +30,10 @@ cos_lr: False # use cosine learning rate scheduler
close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
resume: False # resume training from last checkpoint
# Segmentation
overlap_mask: True # masks should overlap during training
mask_ratio: 4 # mask downsample ratio
overlap_mask: True # masks should overlap during training (segment train only)
mask_ratio: 4 # mask downsample ratio (segment train only)
# Classification
dropout: 0.0 # use dropout regularization
dropout: 0.0 # use dropout regularization (classify train only)
# Val/Test settings ----------------------------------------------------------------------------------------------------
val: True # validate/test during training
@ -44,7 +44,7 @@ iou: 0.7 # intersection over union (IoU) threshold for NMS
max_det: 300 # maximum number of detections per image
half: False # use half precision (FP16)
dnn: False # use OpenCV DNN for ONNX inference
plots: True # show plots during training
plots: True # save plots during train/val
# Prediction settings --------------------------------------------------------------------------------------------------
source: null # source directory for images or videos
@ -56,10 +56,11 @@ hide_labels: False # hide labels
hide_conf: False # hide confidence scores
vid_stride: 1 # video frame-rate stride
line_thickness: 3 # bounding box thickness (pixels)
visualize: False # visualize results
augment: False # apply data augmentation to images
visualize: False # visualize model features
augment: False # apply image augmentation to prediction sources
agnostic_nms: False # class-agnostic NMS
retina_masks: False # use retina masks for object detection
retina_masks: False # use high-resolution segmentation masks
classes: null # filter results by class, i.e. class=0, or class=[0,2,3]
# Export settings ------------------------------------------------------------------------------------------------------
format: torchscript # format to export to
@ -73,8 +74,8 @@ workspace: 4 # TensorRT: workspace size (GB)
nms: False # CoreML: add NMS
# Hyperparameters ------------------------------------------------------------------------------------------------------
lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
lr0: 0.01 # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.0005 # optimizer weight decay 5e-4
warmup_epochs: 3.0 # warmup epochs (fractions ok)
@ -84,7 +85,7 @@ box: 7.5 # box loss gain
cls: 0.5 # cls loss gain (scale with pixels)
dfl: 1.5 # dfl loss gain
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
label_smoothing: 0.0
label_smoothing: 0.0 # label smoothing (fraction)
nbs: 64 # nominal batch size
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)

View File

@ -615,7 +615,7 @@ class LoadImagesAndLabels(Dataset):
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{prefix}Scanning {path.parent / path.stem}..."
total = len(self.im_files)
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix)))
pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:

View File

@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
from itertools import repeat
from multiprocessing.pool import Pool, ThreadPool
from multiprocessing.pool import ThreadPool
from pathlib import Path
import torchvision
@ -51,7 +51,7 @@ class YOLODataset(BaseDataset):
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
total = len(self.im_files)
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(func=verify_image_label,
iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
repeat(self.use_keypoints)))

View File

@ -67,7 +67,7 @@ import torch
import ultralytics
from ultralytics.nn.modules import Detect, Segment
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
from ultralytics.yolo.data.utils import check_dataset
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
@ -134,7 +134,7 @@ class Exporter:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_config(config, overrides)
self.args = get_cfg(config, overrides)
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
callbacks.add_integration_callbacks(self)

View File

@ -4,7 +4,7 @@ from pathlib import Path
from ultralytics import yolo # noqa
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, yaml_load
from ultralytics.yolo.utils.checks import check_yaml
@ -136,7 +136,7 @@ class YOLO:
self.predictor = self.PredictorClass(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_config(self.predictor.args, overrides)
self.predictor.args = get_cfg(self.predictor.args, overrides)
return self.predictor(source=source, stream=stream, verbose=verbose)
@smart_inference_mode()
@ -151,7 +151,7 @@ class YOLO:
overrides = self.overrides.copy()
overrides.update(kwargs)
overrides["mode"] = "val"
args = get_config(config=DEFAULT_CFG_PATH, overrides=overrides)
args = get_cfg(cfg=DEFAULT_CFG_PATH, overrides=overrides)
args.data = data or args.data
args.task = self.task
@ -169,7 +169,7 @@ class YOLO:
overrides = self.overrides.copy()
overrides.update(kwargs)
args = get_config(config=DEFAULT_CFG_PATH, overrides=overrides)
args = get_cfg(cfg=DEFAULT_CFG_PATH, overrides=overrides)
args.task = self.task
print(args)
@ -201,7 +201,7 @@ class YOLO:
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.train()
# update model and configs after training
# update model and cfg after training
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
self.overrides = self.model.args

View File

@ -33,7 +33,7 @@ from pathlib import Path
import cv2
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, SETTINGS, callbacks, colorstr, ops
@ -70,7 +70,7 @@ class BasePredictor:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_config(config, overrides)
self.args = get_cfg(config, overrides)
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f"{self.args.mode}"
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
@ -84,6 +84,7 @@ class BasePredictor:
self.bs = None
self.imgsz = None
self.device = None
self.classes = self.args.classes
self.dataset = None
self.vid_path, self.vid_writer = None, None
self.annotator = None
@ -100,7 +101,7 @@ class BasePredictor:
def write_results(self, results, batch, print_string):
raise NotImplementedError("print_results function needs to be implemented")
def postprocess(self, preds, img, orig_img):
def postprocess(self, preds, img, orig_img, classes=None):
return preds
def setup_source(self, source=None):
@ -195,7 +196,7 @@ class BasePredictor:
# postprocess
with self.dt[2]:
results = self.postprocess(preds, im, im0s)
results = self.postprocess(preds, im, im0s, self.classes)
for i in range(len(im)):
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
p = Path(p)

View File

@ -21,6 +21,8 @@ class Results:
masks (Masks, optional): A Masks object containing the detection masks.
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
orig_shape (tuple, optional): Original image size.
data (torch.Tensor): The raw masks tensor
"""
def __init__(self, boxes=None, masks=None, probs=None, orig_shape=None) -> None:
@ -81,19 +83,20 @@ class Results:
return len(getattr(self, item))
def __str__(self):
return self.__repr__()
str_out = ""
for item in self.comp:
if getattr(self, item) is None:
continue
str_out = str_out + getattr(self, item).__str__()
return str_out
def __repr__(self):
s = f'Ultralytics YOLO {self.__class__} instance\n' # string
if self.boxes is not None:
s = s + self.boxes.__repr__() + '\n'
if self.masks is not None:
s = s + self.masks.__repr__() + '\n'
if self.probs is not None:
s = s + self.probs.__repr__()
s += f'original size: {self.orig_shape}\n'
return s
str_out = ""
for item in self.comp:
if getattr(self, item) is None:
continue
str_out = str_out + getattr(self, item).__repr__()
return str_out
def __getattr__(self, attr):
name = self.__class__.__name__
@ -129,6 +132,7 @@ class Boxes:
xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format.
xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size.
xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size.
data (torch.Tensor): The raw bboxes tensor
"""
def __init__(self, boxes, orig_shape) -> None:
@ -198,15 +202,19 @@ class Boxes:
def shape(self):
return self.boxes.shape
@property
def data(self):
return self.boxes
def __len__(self): # override len(results)
return len(self.boxes)
def __str__(self):
return self.__repr__()
return self.boxes.__str__()
def __repr__(self):
return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.boxes)}\n" +
f"shape: {self.boxes.shape}\n" + f"dtype: {self.boxes.dtype}")
f"shape: {self.boxes.shape}\n" + f"dtype: {self.boxes.dtype}\n + {self.boxes.__repr__()}")
def __getitem__(self, idx):
boxes = self.boxes[idx]
@ -257,12 +265,16 @@ class Masks:
def segments(self):
return [
ops.scale_segments(self.masks.shape[1:], x, self.orig_shape, normalize=True)
for x in reversed(ops.masks2segments(self.masks))]
for x in ops.masks2segments(self.masks)]
@property
def shape(self):
return self.masks.shape
@property
def data(self):
return self.masks
def cpu(self):
masks = self.masks.cpu()
return Masks(masks, self.orig_shape)
@ -283,11 +295,11 @@ class Masks:
return len(self.masks)
def __str__(self):
return self.__repr__()
return self.masks.__str__()
def __repr__(self):
return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.masks)}\n" +
f"shape: {self.masks.shape}\n" + f"dtype: {self.masks.dtype}")
f"shape: {self.masks.shape}\n" + f"dtype: {self.masks.dtype}\n + {self.masks.__repr__()}")
def __getitem__(self, idx):
masks = self.masks[idx]

View File

@ -23,7 +23,7 @@ from tqdm import tqdm
import ultralytics.yolo.utils as utils
from ultralytics import __version__
from ultralytics.nn.tasks import attempt_load_one_weight
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
yaml_save)
@ -79,7 +79,7 @@ class BaseTrainer:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_config(config, overrides)
self.args = get_cfg(config, overrides)
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch)
self.check_resume()
self.console = LOGGER
@ -509,7 +509,7 @@ class BaseTrainer:
assert args_yaml.is_file(), \
FileNotFoundError('Resume checkpoint f{last} not found. '
'Please pass a valid checkpoint to resume from, i.e. yolo resume=path/to/last.pt')
args = get_config(args_yaml) # replace
args = get_cfg(args_yaml) # replace
args.model, resume = str(last), True # reinstate
self.args = args
self.resume = resume

View File

@ -8,7 +8,7 @@ import torch
from tqdm import tqdm
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks
from ultralytics.yolo.utils.checks import check_imgsz
@ -52,7 +52,7 @@ class BaseValidator:
self.dataloader = dataloader
self.pbar = pbar
self.logger = logger or LOGGER
self.args = args or get_config(DEFAULT_CFG_PATH)
self.args = args or get_cfg(DEFAULT_CFG_PATH)
self.model = None
self.data = None
self.device = None

View File

@ -23,7 +23,7 @@ import yaml
# Constants
FILE = Path(__file__).resolve()
ROOT = FILE.parents[2] # YOLO
DEFAULT_CFG_PATH = ROOT / "yolo/configs/default.yaml"
DEFAULT_CFG_PATH = ROOT / "yolo/cfg/default.yaml"
RANK = int(os.getenv('RANK', -1))
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode

View File

@ -26,7 +26,7 @@ def on_pretrain_routine_start(trainer):
output_uri=True,
reuse_last_task_id=False,
auto_connect_frameworks={'pytorch': False})
task.connect(dict(trainer.args), name='General')
task.connect(vars(trainer.args), name='General')
def on_train_epoch_end(trainer):

View File

@ -11,7 +11,7 @@ except (ModuleNotFoundError, ImportError):
def on_pretrain_routine_start(trainer):
experiment = comet_ml.Experiment(project_name=trainer.args.project or "YOLOv8")
experiment.log_parameters(dict(trainer.args))
experiment.log_parameters(vars(trainer.args))
def on_train_epoch_end(trainer):

View File

@ -137,9 +137,10 @@ def model_info(model, verbose=False, imgsz=640):
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
flops = get_flops(model, imgsz)
fused = ' (fused)' if model.is_fused() else ''
fs = f', {flops:.1f} GFLOPs' if flops else ''
m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
LOGGER.info(f"{m} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
LOGGER.info(f"{m} summary{fused}: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
def get_num_params(model):

View File

@ -18,7 +18,7 @@ class ClassificationPredictor(BasePredictor):
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
return img
def postprocess(self, preds, img, orig_img):
def postprocess(self, preds, img, orig_img, classes=None):
results = []
for i, pred in enumerate(preds):
shape = orig_img[i].shape if isinstance(orig_img, list) else orig_img.shape

View File

@ -19,12 +19,13 @@ class DetectionPredictor(BasePredictor):
img /= 255 # 0 - 255 to 0.0 - 1.0
return img
def postprocess(self, preds, img, orig_img):
def postprocess(self, preds, img, orig_img, classes=None):
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det)
max_det=self.args.max_det,
classes=self.args.classes)
results = []
for i, pred in enumerate(preds):

View File

@ -10,14 +10,15 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor
class SegmentationPredictor(DetectionPredictor):
def postprocess(self, preds, img, orig_img):
def postprocess(self, preds, img, orig_img, classes=None):
# TODO: filter by classes
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nm=32)
nm=32,
classes=self.args.classes)
results = []
proto = preds[1][-1]
for i, pred in enumerate(p):