Model enhancement (#75)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2022-12-14 14:33:31 +05:30
committed by GitHub
parent d85b44f259
commit eb5adf4e0b
3 changed files with 176 additions and 31 deletions

View File

@ -37,15 +37,23 @@ class AutoBackend(nn.Module):
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
nn_module = isinstance(weights, torch.nn.Module)
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
fp16 &= pt or jit or onnx or engine # FP16
fp16 &= pt or jit or onnx or engine or nn_module # FP16
nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
stride = 32 # default stride
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
if not (pt or triton):
if not (pt or triton or nn_module):
w = attempt_download(w) # download if not local
if pt: # PyTorch
# NOTE: special case: in-memory pytorch model
if nn_module:
model = weights.to(device)
model = model.fuse() if fuse else model
names = model.module.names if hasattr(model, 'module') else model.names # get class names
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
elif pt: # PyTorch
model = attempt_load_weights(weights if isinstance(weights, list) else w,
device=device,
inplace=True,
@ -215,7 +223,7 @@ class AutoBackend(nn.Module):
if self.nhwc:
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.pt: # PyTorch
if self.pt or self.nn_module: # PyTorch
y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
elif self.jit: # TorchScript
y = self.model(im)
@ -294,7 +302,7 @@ class AutoBackend(nn.Module):
def warmup(self, imgsz=(1, 3, 640, 640)):
# Warmup model by running inference once
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
for _ in range(2 if self.jit else 1): #
@ -306,7 +314,7 @@ class AutoBackend(nn.Module):
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
from ultralytics.yolo.engine.exporter import export_formats
sf = list(export_formats().Suffix) # export suffixes
if not is_url(p, check=False):
if not is_url(p, check=False) and not isinstance(p, str):
check_suffix(p, sf) # checks
url = urlparse(p) # if url may be Triton inference server
types = [s in Path(p).name for s in sf]