Fix model re-fuse() in inference loops (#466)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Glenn Jocher 2 years ago committed by GitHub
parent cc3c774bde
commit a86218b767

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
-__version__ = "8.0.8"
+__version__ = "8.0.9"
 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils import ops

@@ -63,7 +63,8 @@ class BaseModel(nn.Module):
     def _profile_one_layer(self, m, x, dt):
         """
-        Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results to the provided list.
+        Profile the computation time and FLOPs of a single layer of the model on a given input.
+        Appends the results to the provided list.
         Args:
             m (nn.Module): The layer to be profiled.
@@ -74,10 +75,10 @@ class BaseModel(nn.Module):
             None
         """
         c = m == self.model[-1]  # is final layer, copy input as inplace fix
-        o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
+        o = thop.profile(m, inputs=(x.clone() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
         t = time_sync()
         for _ in range(10):
-            m(x.copy() if c else x)
+            m(x.clone() if c else x)
         dt.append((time_sync() - t) * 100)
         if m == self.model[0]:
             LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")
@@ -87,20 +88,36 @@ class BaseModel(nn.Module):
     def fuse(self):
         """
-        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the computation efficiency.
+        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
+        computation efficiency.
         Returns:
             (nn.Module): The fused model is returned.
         """
-        LOGGER.info('Fusing layers... ')
-        for m in self.model.modules():
-            if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
-                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
-                delattr(m, 'bn')  # remove batchnorm
-                m.forward = m.forward_fuse  # update forward
-        self.info()
+        if not self.is_fused():
+            LOGGER.info('Fusing... ')
+            for m in self.model.modules():
+                if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
+                    m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
+                    delattr(m, 'bn')  # remove batchnorm
+                    m.forward = m.forward_fuse  # update forward
+            self.info()
         return self
+
+    def is_fused(self, thresh=10):
+        """
+        Check if the model has less than a certain threshold of BatchNorm layers.
+
+        Args:
+            thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
+
+        Returns:
+            bool: True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
+        """
+        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
+        return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model
+
     def info(self, verbose=False, imgsz=640):
         """
         Prints model information
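This is the core of the re-fuse fix: fuse() is now effectively idempotent. After the first pass the Conv+BatchNorm pairs are merged, is_fused() then sees fewer than `thresh` normalization layers, and subsequent calls skip the loop (and the logging) entirely, so fusing a model inside an inference loop no longer redoes the work on every iteration. A rough sketch of the intended behaviour, assuming `model` is any BaseModel instance and `images` is an iterable of inputs:

    model.fuse()           # first call: merges Conv2d+BatchNorm2d pairs and prints model info
    for im in images:      # inference loop
        model.fuse()       # later calls: is_fused() is True, so this is a cheap no-op
        preds = model(im)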

@@ -1,7 +1,9 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
 import argparse
+import re
 import shutil
+import sys
 from pathlib import Path
 from ultralytics import __version__, yolo
@@ -17,7 +19,7 @@ CLI_HELP_MSG = \
         pip install ultralytics
-    2. Train, Val, Predict and Export using 'yolo' commands of the form:
+    2. Train, Val, Predict and Export using 'yolo' commands:
         yolo TASK MODE ARGS
@@ -97,9 +99,14 @@ def entrypoint():
     It uses the package's default config and initializes it using the passed overrides.
     Then it calls the CLI function with the composed config
     """
+    if len(sys.argv) == 1:  # no arguments passed
+        LOGGER.info(CLI_HELP_MSG)
+        return
+
     parser = argparse.ArgumentParser(description='YOLO parser')
     parser.add_argument('args', type=str, nargs='+', help='YOLO args')
     args = parser.parse_args().args
+    args = re.sub(r'\s*=\s*', '=', ' '.join(args)).split(' ')  # remove whitespaces around = sign
     tasks = 'detect', 'segment', 'classify'
     modes = 'train', 'val', 'predict', 'export'
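The added re.sub line lets users put spaces around '=' in CLI overrides: the argument list is joined, whitespace around '=' is stripped, and the string is re-split so every key=value pair becomes a single token. A small illustration (the argument values are made up):

    import re

    args = ['detect', 'train', 'model', '=', 'yolov8n.pt', 'imgsz=640']
    args = re.sub(r'\s*=\s*', '=', ' '.join(args)).split(' ')
    print(args)  # ['detect', 'train', 'model=yolov8n.pt', 'imgsz=640']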

@@ -8,7 +8,7 @@ mode: "train" # choices=['train', 'val', 'predict'] # mode to run task in.
 model: null # i.e. yolov8n.pt, yolov8n.yaml. Path to model file
 data: null # i.e. coco128.yaml. Path to data file
 epochs: 100 # number of epochs to train for
-patience: 50 # TODO: epochs to wait for no observable improvement for early stopping of training
+patience: 50 # epochs to wait for no observable improvement for early stopping of training
 batch: 16 # number of images per batch
 imgsz: 640 # size of input images
 save: True # save checkpoints

@ -28,10 +28,9 @@ names:
# Download script/URL (optional) --------------------------------------------------------------------------------------- # Download script/URL (optional) ---------------------------------------------------------------------------------------
download: | download: |
import json import json
from tqdm import tqdm from tqdm import tqdm
from utils.general import download, Path from ultralytics.yolo.utils.downloads import download
from pathlib import Path
def argoverse2yolo(set): def argoverse2yolo(set):
labels = {} labels = {}

@@ -32,8 +32,8 @@ names:
 # Download script/URL (optional) ---------------------------------------------------------------------------------------
 download: |
-  from utils.general import download, Path
+  from ultralytics.yolo.utils.downloads import download
+  from pathlib import Path
   # Download
   dir = Path(yaml['path'])  # dataset root dir

@@ -386,7 +386,12 @@ names:
 download: |
   from tqdm import tqdm
-  from utils.general import Path, check_requirements, download, np, xyxy2xywhn
+  from ultralytics.yolo.utils.checks import check_requirements
+  from ultralytics.yolo.utils.downloads import download
+  from ultralytics.yolo.utils.ops import xyxy2xywhn
+  import numpy as np
+  from pathlib import Path
   check_requirements(('pycocotools>=2.0',))
   from pycocotools.coco import COCO

@@ -21,9 +21,14 @@ names:
 # Download script/URL (optional) ---------------------------------------------------------------------------------------
 download: |
   import shutil
+  from pathlib import Path
+  import numpy as np
+  import pandas as pd
   from tqdm import tqdm
-  from utils.general import np, pd, Path, download, xyxy2xywh
+  from ultralytics.yolo.utils.downloads import download
+  from ultralytics.yolo.utils.ops import xyxy2xywh
   # Download
   dir = Path(yaml['path'])  # dataset root dir

@@ -48,8 +48,8 @@ download: |
   import xml.etree.ElementTree as ET
   from tqdm import tqdm
-  from utils.general import download, Path
+  from ultralytics.yolo.utils.downloads import download
+  from pathlib import Path
   def convert_label(path, lb_path, year, image_id):
       def convert_box(size, box):

@@ -29,7 +29,10 @@ names:
 # Download script/URL (optional) ---------------------------------------------------------------------------------------
 download: |
-  from utils.general import download, os, Path
+  import os
+  from pathlib import Path
+
+  from ultralytics.yolo.utils.downloads import download
   def visdrone2yolo(dir):
     from PIL import Image

@@ -99,7 +99,9 @@ names:
 # Download script/URL (optional)
 download: |
-  from utils.general import download, Path
+  from ultralytics.yolo.utils.downloads import download
+  from pathlib import Path
   # Download labels
   segments = True  # segment or box labels
   dir = Path(yaml['path'])  # dataset root dir

@@ -87,8 +87,8 @@ download: |
   from PIL import Image
   from tqdm import tqdm
-  from utils.dataloaders import autosplit
-  from utils.general import download, xyxy2xywhn
+  from ultralytics.yolo.data.dataloaders.v5loader import autosplit
+  from ultralytics.yolo.utils.ops import xyxy2xywhn
   def convert_labels(fname=Path('xView/xView_train.geojson')):

@@ -7,7 +7,7 @@ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, Segmentati
 from ultralytics.yolo.configs import get_config
 from ultralytics.yolo.engine.exporter import Exporter
 from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
-from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
+from ultralytics.yolo.utils.checks import check_yaml
 from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
 # Map head to model, trainer, validator, and predictor classes
@@ -43,6 +43,7 @@ class YOLO:
         self.TrainerClass = None  # trainer class
         self.ValidatorClass = None  # validator class
         self.PredictorClass = None  # predictor class
+        self.predictor = None  # reuse predictor
         self.model = None  # model object
         self.trainer = None  # trainer object
         self.task = None  # task type
@@ -131,11 +132,12 @@ class YOLO:
         overrides.update(kwargs)
         overrides["mode"] = "predict"
         overrides["save"] = kwargs.get("save", False)  # not save files by default
-        predictor = self.PredictorClass(overrides=overrides)
-        predictor.args.imgsz = check_imgsz(predictor.args.imgsz, min_dim=2)  # check image size
-        predictor.setup(model=self.model, source=source)
-        return predictor(stream=stream, verbose=verbose)
+        if not self.predictor:
+            self.predictor = self.PredictorClass(overrides=overrides)
+            self.predictor.setup_model(model=self.model)
+        else:  # only update args if predictor is already setup
+            self.predictor.args = get_config(self.predictor.args, overrides)
+        return self.predictor(source=source, stream=stream, verbose=verbose)

     @smart_inference_mode()
     def val(self, data=None, **kwargs):
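Because the predictor is now cached on the YOLO object, the expensive work (building the AutoBackend, loading weights, fusing) happens only on the first predict() call; later calls merely merge the new overrides into the existing args and reuse the backend. A rough usage sketch, assuming a local yolov8n.pt and some hypothetical image files:

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    for path in ["img1.jpg", "img2.jpg", "img3.jpg"]:      # hypothetical sources
        results = model.predict(source=path, conf=0.5)     # predictor built once, then reused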
@@ -170,6 +172,7 @@ class YOLO:
         args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
         args.task = self.task
+        print(args)
         exporter = Exporter(overrides=args)
         exporter(model=self.model)
@@ -224,10 +227,14 @@ class YOLO:
     def _reset_ckpt_args(args):
         args.pop("project", None)
         args.pop("name", None)
+        args.pop("exist_ok", None)
+        args.pop("resume", None)
         args.pop("batch", None)
         args.pop("epochs", None)
         args.pop("cache", None)
         args.pop("save_json", None)
+        args.pop("half", None)
+        args.pop("v5loader", None)

         # set device to '' to prevent from auto DDP usage
         args["device"] = ''

@@ -76,15 +76,15 @@ class BasePredictor:
         project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
         name = self.args.name or f"{self.args.mode}"
         self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
-        if self.args.save:
-            (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
         if self.args.conf is None:
             self.args.conf = 0.25  # default conf=0.25
-        self.done_setup = False
+        self.done_warmup = False

         # Usable if setup is done
         self.model = None
         self.data = self.args.data  # data_dict
+        self.bs = None
+        self.imgsz = None
         self.device = None
         self.dataset = None
         self.vid_path, self.vid_writer = None, None
@@ -105,11 +105,13 @@ class BasePredictor:
     def postprocess(self, preds, img, orig_img):
         return preds

-    def setup(self, source=None, model=None):
+    def setup_source(self, source=None):
+        if not self.model:
+            raise Exception("setup model before setting up source!")
         # source
         source, webcam, screenshot, from_img = self.check_source(source)
         # model
-        stride, pt = self.setup_model(model)
+        stride, pt = self.model.stride, self.model.pt
         imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2)  # check image size

         # Dataloader
@@ -143,14 +145,12 @@ class BasePredictor:
                                          transforms=getattr(self.model.model, 'transforms', None),
                                          vid_stride=self.args.vid_stride)
         self.vid_path, self.vid_writer = [None] * bs, [None] * bs
-        self.model.warmup(imgsz=(1 if pt or self.model.triton else bs, 3, *imgsz))  # warmup
         self.webcam = webcam
         self.screenshot = screenshot
         self.from_img = from_img
         self.imgsz = imgsz
-        self.done_setup = True
-        return model
+        self.bs = bs

     @smart_inference_mode()
     def __call__(self, source=None, model=None, verbose=False, stream=False):
@@ -167,8 +167,20 @@ class BasePredictor:
     def stream_inference(self, source=None, model=None, verbose=False):
         self.run_callbacks("on_predict_start")
-        if not self.done_setup:
-            self.setup(source, model)  # setup model
+        if not self.model:
+            self.setup_model(model)
+        # setup source. Run every time predict is called
+        self.setup_source(source)
+        # check if save_dir/ label file exists
+        if self.args.save:
+            (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
+        # warmup model
+        if not self.done_warmup:
+            self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.bs, 3, *self.imgsz))
+            self.done_warmup = True
         self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
         for batch in self.dataset:
             self.run_callbacks("on_predict_batch_start")
@@ -223,11 +235,9 @@ class BasePredictor:
         device = select_device(self.args.device)
         model = model or self.args.model
         self.args.half &= device.type != 'cpu'  # half precision only supported on CUDA
-        model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)
-        self.model = model
+        self.model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)
         self.device = device
         self.model.eval()
-        return model.stride, model.pt

     def check_source(self, source):
         source = source if source is not None else self.args.source

@@ -85,11 +85,11 @@ class Results:
     def __repr__(self):
         s = f'Ultralytics YOLO {self.__class__} instance\n'  # string
-        if self.boxes:
+        if self.boxes is not None:
             s = s + self.boxes.__repr__() + '\n'
-        if self.masks:
+        if self.masks is not None:
             s = s + self.masks.__repr__() + '\n'
-        if self.probs:
+        if self.probs is not None:
             s = s + self.probs.__repr__()
         s += f'original size: {self.orig_shape}\n'
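Switching from truthiness to `is not None` avoids two problems: a multi-element tensor cannot be coerced to bool at all, and an empty-but-present result would otherwise be silently skipped. A minimal illustration:

    import torch

    t = torch.zeros(2, 4)
    print(t is not None)   # True: the explicit check now used in __repr__
    try:
        bool(t)            # truthiness of a multi-element tensor is ambiguous
    except RuntimeError as e:
        print(e)           # "Boolean value of Tensor with more than one element is ambiguous"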

@@ -205,7 +205,7 @@ class BaseTrainer:
             self.model = DDP(self.model, device_ids=[rank])
         # Check imgsz
         gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32)  # grid size (max stride)
-        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs * 2)
+        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs)
         # Batch size
         if self.batch_size == -1:
             if RANK == -1:  # single-GPU only, estimate best batch size

@@ -372,7 +372,14 @@ def set_sentry(dsn=None):
         import sentry_sdk  # noqa
         import ultralytics

-        sentry_sdk.init(dsn=dsn, traces_sample_rate=1.0, release=ultralytics.__version__, debug=False)
+        sentry_sdk.init(
+            dsn=dsn,
+            debug=False,
+            traces_sample_rate=1.0,
+            release=ultralytics.__version__,
+            send_default_pii=True,
+            environment='production',  # 'dev' or 'production'
+            ignore_errors=[KeyboardInterrupt, torch.cuda.OutOfMemoryError])

 def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):

@@ -5,7 +5,7 @@ import torch
 from ultralytics.yolo.engine.predictor import BasePredictor
 from ultralytics.yolo.engine.results import Results
-from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT
+from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory
 from ultralytics.yolo.utils.plotting import Annotator
@@ -67,7 +67,8 @@ class ClassificationPredictor(BasePredictor):
 @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def predict(cfg):
     cfg.model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
-    cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
+    cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
+        else "https://ultralytics.com/images/bus.jpg"
     predictor = ClassificationPredictor(cfg)
     predictor.predict_cli()
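The new default-source line chains two conditional expressions, which Python groups from the right: an explicit source wins, otherwise a git checkout falls back to the bundled assets folder, and a pip install falls back to a hosted sample image. An equivalent parenthesised sketch (function and values are illustrative):

    URL = "https://ultralytics.com/images/bus.jpg"

    def pick_source(source, in_git_repo):
        # same as: source if source is not None else ("assets" if in_git_repo else URL)
        return source if source is not None else "assets" if in_git_repo else URL

    print(pick_source(None, True))       # assets
    print(pick_source(None, False))      # https://ultralytics.com/images/bus.jpg
    print(pick_source("my.jpg", False))  # my.jpg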

@@ -140,10 +140,13 @@ class ClassificationTrainer(BaseTrainer):
 def train(cfg):
     cfg.model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
     cfg.data = cfg.data or "mnist160"  # or yolo.ClassificationDataset("mnist")
-    cfg.lr0 = 0.1
-    cfg.weight_decay = 5e-5  # Reproduce ImageNet results
-    cfg.label_smoothing = 0.1
-    cfg.warmup_epochs = 0.0
+    # cfg.lr0 = 0.1
+    # cfg.weight_decay = 5e-5
+    # cfg.label_smoothing = 0.1
+    # cfg.warmup_epochs = 0.0
     cfg.device = cfg.device if cfg.device is not None else ''
     # trainer = ClassificationTrainer(cfg)
     # trainer.train()

@@ -5,7 +5,7 @@ import torch
 from ultralytics.yolo.engine.predictor import BasePredictor
 from ultralytics.yolo.engine.results import Results
-from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
+from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory, ops
 from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
@@ -84,7 +84,8 @@ class DetectionPredictor(BasePredictor):
 @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def predict(cfg):
     cfg.model = cfg.model or "yolov8n.pt"
-    cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
+    cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
+        else "https://ultralytics.com/images/bus.jpg"
     predictor = DetectionPredictor(cfg)
     predictor.predict_cli()

@@ -4,7 +4,7 @@ import hydra
 import torch
 from ultralytics.yolo.engine.results import Results
-from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
+from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory, ops
 from ultralytics.yolo.utils.plotting import colors, save_one_box
 from ultralytics.yolo.v8.detect.predict import DetectionPredictor
@@ -101,8 +101,8 @@ class SegmentationPredictor(DetectionPredictor):
 @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def predict(cfg):
     cfg.model = cfg.model or "yolov8n-seg.pt"
-    cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
+    cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
+        else "https://ultralytics.com/images/bus.jpg"
     predictor = SegmentationPredictor(cfg)
     predictor.predict_cli()

@@ -45,6 +45,7 @@ class SegmentationValidator(DetectionValidator):
         self.jdict = []
         self.stats = []
         if self.args.save_json:
+            check_requirements('pycocotools>=2.0.6')
             self.process = ops.process_mask_upsample  # more accurate
         else:
             self.process = ops.process_mask  # faster
@@ -189,8 +190,9 @@ class SegmentationValidator(DetectionValidator):
             self.plot_masks.clear()

     def pred_to_json(self, predn, filename, pred_masks):
-        # Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
-        from pycocotools.mask import encode
+        # Save one JSON result
+        # Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
+        from pycocotools.mask import encode  # noqa

         def single_encode(x):
             rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
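single_encode() turns one HxW binary mask into COCO run-length encoding: pycocotools' encode() expects a Fortran-ordered uint8 array with a trailing channel axis and returns a dict whose counts field is bytes, which must be decoded before JSON serialization. A minimal sketch with a dummy mask (shape and region are made up):

    import numpy as np
    from pycocotools.mask import encode

    mask = np.zeros((4, 6), dtype=bool)
    mask[1:3, 2:5] = True                            # dummy foreground region
    rle = encode(np.asarray(mask[:, :, None], order="F", dtype="uint8"))[0]
    rle["counts"] = rle["counts"].decode("utf-8")    # bytes -> str so the result is JSON-serializable
    print(rle)                                       # {'size': [4, 6], 'counts': '...'}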
