ultralytics 8.0.47
Docker and reformat updates (#1153)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -136,7 +136,7 @@ class AutoBackend(nn.Module):
|
||||
batch_dim = get_batch(network)
|
||||
if batch_dim.is_static:
|
||||
batch_size = batch_dim.get_length()
|
||||
executable_network = ie.compile_model(network, device_name='CPU') # device_name="MYRIAD" for Intel NCS2
|
||||
executable_network = ie.compile_model(network, device_name='CPU') # device_name="MYRIAD" for NCS2
|
||||
elif engine: # TensorRT
|
||||
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
||||
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
||||
@ -176,6 +176,8 @@ class AutoBackend(nn.Module):
|
||||
LOGGER.info(f'Loading {w} for CoreML inference...')
|
||||
import coremltools as ct
|
||||
model = ct.models.MLModel(w)
|
||||
names, stride, task = (model.user_defined_metadata.get(k) for k in ('names', 'stride', 'task'))
|
||||
names, stride = eval(names), int(stride)
|
||||
elif saved_model: # TF SavedModel
|
||||
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
|
||||
import tensorflow as tf
|
||||
@ -185,18 +187,13 @@ class AutoBackend(nn.Module):
|
||||
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
|
||||
import tensorflow as tf
|
||||
|
||||
from ultralytics.yolo.engine.exporter import gd_outputs
|
||||
|
||||
def wrap_frozen_graph(gd, inputs, outputs):
|
||||
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped
|
||||
ge = x.graph.as_graph_element
|
||||
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
|
||||
|
||||
def gd_outputs(gd):
|
||||
name_list, input_list = [], []
|
||||
for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
|
||||
name_list.append(node.name)
|
||||
input_list.extend(node.input)
|
||||
return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
|
||||
|
||||
gd = tf.Graph().as_graph_def() # TF GraphDef
|
||||
with open(w, 'rb') as f:
|
||||
gd.ParseFromString(f.read())
|
||||
@ -319,10 +316,17 @@ class AutoBackend(nn.Module):
|
||||
self.context.execute_v2(list(self.binding_addrs.values()))
|
||||
y = [self.bindings[x].data for x in sorted(self.output_names)]
|
||||
elif self.coreml: # CoreML
|
||||
im = im.cpu().numpy()
|
||||
im = Image.fromarray((im[0] * 255).astype('uint8'))
|
||||
im = im[0].cpu().numpy()
|
||||
if self.task == 'classify':
|
||||
from ultralytics.yolo.data.utils import IMAGENET_MEAN, IMAGENET_STD
|
||||
|
||||
# im_pil = Image.fromarray(((im / 6 + 0.5) * 255).astype('uint8'))
|
||||
for i in range(3):
|
||||
im[..., i] *= IMAGENET_STD[i]
|
||||
im[..., i] += IMAGENET_MEAN[i]
|
||||
im_pil = Image.fromarray((im * 255).astype('uint8'))
|
||||
# im = im.resize((192, 320), Image.ANTIALIAS)
|
||||
y = self.model.predict({'image': im}) # coordinates are xywh normalized
|
||||
y = self.model.predict({'image': im_pil}) # coordinates are xywh normalized
|
||||
if 'confidence' in y:
|
||||
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
|
||||
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
|
||||
|
@ -11,7 +11,7 @@ import torch.nn as nn
|
||||
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
||||
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
||||
GhostBottleneck, GhostConv, Segment)
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, colorstr, yaml_load
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, colorstr, emojis, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
|
||||
intersect_dicts, make_divisible, model_info, scale_img, time_sync)
|
||||
@ -76,7 +76,7 @@ class BaseModel(nn.Module):
|
||||
None
|
||||
"""
|
||||
c = m == self.model[-1] # is final layer, copy input as inplace fix
|
||||
o = thop.profile(m, inputs=(x.clone() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
||||
o = thop.profile(m, inputs=[x.clone() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
||||
t = time_sync()
|
||||
for _ in range(10):
|
||||
m(x.clone() if c else x)
|
||||
@ -339,14 +339,20 @@ def torch_safe_load(weight):
|
||||
file = attempt_download_asset(weight) # search online if missing locally
|
||||
try:
|
||||
return torch.load(file, map_location='cpu'), file # load
|
||||
except ModuleNotFoundError as e:
|
||||
if e.name == 'omegaconf': # e.name is missing module name
|
||||
LOGGER.warning(f'WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements.'
|
||||
f'\nAutoInstall will run now for {e.name} but this feature will be removed in the future.'
|
||||
f'\nRecommend fixes are to train a new model using updated ultralytics package or to '
|
||||
f'download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0')
|
||||
if e.name != 'models':
|
||||
check_requirements(e.name) # install missing module
|
||||
except ModuleNotFoundError as e: # e.name is missing module name
|
||||
if e.name == 'models':
|
||||
raise TypeError(
|
||||
emojis(f'ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained '
|
||||
f'with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with '
|
||||
f'YOLOv8 at https://github.com/ultralytics/ultralytics.'
|
||||
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
||||
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from e
|
||||
LOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
|
||||
f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
|
||||
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
||||
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
|
||||
check_requirements(e.name) # install missing module
|
||||
|
||||
return torch.load(file, map_location='cpu'), file # load
|
||||
|
||||
|
||||
@ -437,22 +443,21 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||
args[j] = eval(a) if isinstance(a, str) else a # eval strings
|
||||
|
||||
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
|
||||
if m in {
|
||||
Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
|
||||
BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
|
||||
if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
|
||||
BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x):
|
||||
c1, c2 = ch[f], args[0]
|
||||
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
|
||||
c2 = make_divisible(c2 * gw, 8)
|
||||
|
||||
args = [c1, c2, *args[1:]]
|
||||
if m in {BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x}:
|
||||
if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x):
|
||||
args.insert(2, n) # number of repeats
|
||||
n = 1
|
||||
elif m is nn.BatchNorm2d:
|
||||
args = [ch[f]]
|
||||
elif m is Concat:
|
||||
c2 = sum(ch[x] for x in f)
|
||||
elif m in {Detect, Segment}:
|
||||
elif m in (Detect, Segment):
|
||||
args.append([ch[x] for x in f])
|
||||
if m is Segment:
|
||||
args[2] = make_divisible(args[2] * gw, 8)
|
||||
@ -490,11 +495,11 @@ def guess_model_task(model):
|
||||
def cfg2task(cfg):
|
||||
# Guess from YAML dictionary
|
||||
m = cfg['head'][-1][-2].lower() # output module name
|
||||
if m in ['classify', 'classifier', 'cls', 'fc']:
|
||||
if m in ('classify', 'classifier', 'cls', 'fc'):
|
||||
return 'classify'
|
||||
if m in ['detect']:
|
||||
if m == 'detect':
|
||||
return 'detect'
|
||||
if m in ['segment']:
|
||||
if m == 'segment':
|
||||
return 'segment'
|
||||
|
||||
# Guess from model cfg
|
||||
|
Reference in New Issue
Block a user