Fix model re-fuse() in inference loops (#466)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher
2023-01-18 20:32:36 +01:00
committed by GitHub
parent cc3c774bde
commit a86218b767
22 changed files with 135 additions and 66 deletions

View File

@ -7,7 +7,7 @@ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, Segmentati
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
# Map head to model, trainer, validator, and predictor classes
@ -43,6 +43,7 @@ class YOLO:
self.TrainerClass = None # trainer class
self.ValidatorClass = None # validator class
self.PredictorClass = None # predictor class
self.predictor = None # reuse predictor
self.model = None # model object
self.trainer = None # trainer object
self.task = None # task type
@ -131,11 +132,12 @@ class YOLO:
overrides.update(kwargs)
overrides["mode"] = "predict"
overrides["save"] = kwargs.get("save", False) # not save files by default
predictor = self.PredictorClass(overrides=overrides)
predictor.args.imgsz = check_imgsz(predictor.args.imgsz, min_dim=2) # check image size
predictor.setup(model=self.model, source=source)
return predictor(stream=stream, verbose=verbose)
if not self.predictor:
self.predictor = self.PredictorClass(overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_config(self.predictor.args, overrides)
return self.predictor(source=source, stream=stream, verbose=verbose)
@smart_inference_mode()
def val(self, data=None, **kwargs):
@ -170,6 +172,7 @@ class YOLO:
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
args.task = self.task
print(args)
exporter = Exporter(overrides=args)
exporter(model=self.model)
@ -224,10 +227,14 @@ class YOLO:
def _reset_ckpt_args(args):
args.pop("project", None)
args.pop("name", None)
args.pop("exist_ok", None)
args.pop("resume", None)
args.pop("batch", None)
args.pop("epochs", None)
args.pop("cache", None)
args.pop("save_json", None)
args.pop("half", None)
args.pop("v5loader", None)
# set device to '' to prevent from auto DDP usage
args["device"] = ''

View File

@ -76,15 +76,15 @@ class BasePredictor:
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f"{self.args.mode}"
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
if self.args.save:
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
if self.args.conf is None:
self.args.conf = 0.25 # default conf=0.25
self.done_setup = False
self.done_warmup = False
# Usable if setup is done
self.model = None
self.data = self.args.data # data_dict
self.bs = None
self.imgsz = None
self.device = None
self.dataset = None
self.vid_path, self.vid_writer = None, None
@ -105,11 +105,13 @@ class BasePredictor:
def postprocess(self, preds, img, orig_img):
return preds
def setup(self, source=None, model=None):
def setup_source(self, source=None):
if not self.model:
raise Exception("setup model before setting up source!")
# source
source, webcam, screenshot, from_img = self.check_source(source)
# model
stride, pt = self.setup_model(model)
stride, pt = self.model.stride, self.model.pt
imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2) # check image size
# Dataloader
@ -143,14 +145,12 @@ class BasePredictor:
transforms=getattr(self.model.model, 'transforms', None),
vid_stride=self.args.vid_stride)
self.vid_path, self.vid_writer = [None] * bs, [None] * bs
self.model.warmup(imgsz=(1 if pt or self.model.triton else bs, 3, *imgsz)) # warmup
self.webcam = webcam
self.screenshot = screenshot
self.from_img = from_img
self.imgsz = imgsz
self.done_setup = True
return model
self.bs = bs
@smart_inference_mode()
def __call__(self, source=None, model=None, verbose=False, stream=False):
@ -167,8 +167,20 @@ class BasePredictor:
def stream_inference(self, source=None, model=None, verbose=False):
self.run_callbacks("on_predict_start")
if not self.done_setup:
self.setup(source, model)
# setup model
if not self.model:
self.setup_model(model)
# setup source. Run every time predict is called
self.setup_source(source)
# check if save_dir/ label file exists
if self.args.save:
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
# warmup model
if not self.done_warmup:
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.bs, 3, *self.imgsz))
self.done_warmup = True
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
for batch in self.dataset:
self.run_callbacks("on_predict_batch_start")
@ -223,11 +235,9 @@ class BasePredictor:
device = select_device(self.args.device)
model = model or self.args.model
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)
self.model = model
self.model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)
self.device = device
self.model.eval()
return model.stride, model.pt
def check_source(self, source):
source = source if source is not None else self.args.source

View File

@ -85,11 +85,11 @@ class Results:
def __repr__(self):
s = f'Ultralytics YOLO {self.__class__} instance\n' # string
if self.boxes:
if self.boxes is not None:
s = s + self.boxes.__repr__() + '\n'
if self.masks:
if self.masks is not None:
s = s + self.masks.__repr__() + '\n'
if self.probs:
if self.probs is not None:
s = s + self.probs.__repr__()
s += f'original size: {self.orig_shape}\n'

View File

@ -205,7 +205,7 @@ class BaseTrainer:
self.model = DDP(self.model, device_ids=[rank])
# Check imgsz
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs * 2)
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs)
# Batch size
if self.batch_size == -1:
if RANK == -1: # single-GPU only, estimate best batch size