|
|
|
@ -76,15 +76,15 @@ class BasePredictor:
|
|
|
|
|
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
|
|
|
|
name = self.args.name or f"{self.args.mode}"
|
|
|
|
|
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
|
|
|
|
if self.args.save:
|
|
|
|
|
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
|
|
|
|
if self.args.conf is None:
|
|
|
|
|
self.args.conf = 0.25 # default conf=0.25
|
|
|
|
|
self.done_setup = False
|
|
|
|
|
self.done_warmup = False
|
|
|
|
|
|
|
|
|
|
# Usable if setup is done
|
|
|
|
|
self.model = None
|
|
|
|
|
self.data = self.args.data # data_dict
|
|
|
|
|
self.bs = None
|
|
|
|
|
self.imgsz = None
|
|
|
|
|
self.device = None
|
|
|
|
|
self.dataset = None
|
|
|
|
|
self.vid_path, self.vid_writer = None, None
|
|
|
|
@ -105,11 +105,13 @@ class BasePredictor:
|
|
|
|
|
def postprocess(self, preds, img, orig_img):
|
|
|
|
|
return preds
|
|
|
|
|
|
|
|
|
|
def setup(self, source=None, model=None):
|
|
|
|
|
def setup_source(self, source=None):
|
|
|
|
|
if not self.model:
|
|
|
|
|
raise Exception("setup model before setting up source!")
|
|
|
|
|
# source
|
|
|
|
|
source, webcam, screenshot, from_img = self.check_source(source)
|
|
|
|
|
# model
|
|
|
|
|
stride, pt = self.setup_model(model)
|
|
|
|
|
stride, pt = self.model.stride, self.model.pt
|
|
|
|
|
imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2) # check image size
|
|
|
|
|
|
|
|
|
|
# Dataloader
|
|
|
|
@ -143,14 +145,12 @@ class BasePredictor:
|
|
|
|
|
transforms=getattr(self.model.model, 'transforms', None),
|
|
|
|
|
vid_stride=self.args.vid_stride)
|
|
|
|
|
self.vid_path, self.vid_writer = [None] * bs, [None] * bs
|
|
|
|
|
self.model.warmup(imgsz=(1 if pt or self.model.triton else bs, 3, *imgsz)) # warmup
|
|
|
|
|
|
|
|
|
|
self.webcam = webcam
|
|
|
|
|
self.screenshot = screenshot
|
|
|
|
|
self.from_img = from_img
|
|
|
|
|
self.imgsz = imgsz
|
|
|
|
|
self.done_setup = True
|
|
|
|
|
return model
|
|
|
|
|
self.bs = bs
|
|
|
|
|
|
|
|
|
|
@smart_inference_mode()
|
|
|
|
|
def __call__(self, source=None, model=None, verbose=False, stream=False):
|
|
|
|
@ -167,8 +167,20 @@ class BasePredictor:
|
|
|
|
|
|
|
|
|
|
def stream_inference(self, source=None, model=None, verbose=False):
|
|
|
|
|
self.run_callbacks("on_predict_start")
|
|
|
|
|
if not self.done_setup:
|
|
|
|
|
self.setup(source, model)
|
|
|
|
|
|
|
|
|
|
# setup model
|
|
|
|
|
if not self.model:
|
|
|
|
|
self.setup_model(model)
|
|
|
|
|
# setup source. Run every time predict is called
|
|
|
|
|
self.setup_source(source)
|
|
|
|
|
# check if save_dir/ label file exists
|
|
|
|
|
if self.args.save:
|
|
|
|
|
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
|
|
|
|
# warmup model
|
|
|
|
|
if not self.done_warmup:
|
|
|
|
|
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.bs, 3, *self.imgsz))
|
|
|
|
|
self.done_warmup = True
|
|
|
|
|
|
|
|
|
|
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
|
|
|
|
|
for batch in self.dataset:
|
|
|
|
|
self.run_callbacks("on_predict_batch_start")
|
|
|
|
@ -223,11 +235,9 @@ class BasePredictor:
|
|
|
|
|
device = select_device(self.args.device)
|
|
|
|
|
model = model or self.args.model
|
|
|
|
|
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
|
|
|
|
|
model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)
|
|
|
|
|
self.model = model
|
|
|
|
|
self.model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)
|
|
|
|
|
self.device = device
|
|
|
|
|
self.model.eval()
|
|
|
|
|
return model.stride, model.pt
|
|
|
|
|
|
|
|
|
|
def check_source(self, source):
|
|
|
|
|
source = source if source is not None else self.args.source
|
|
|
|
|