Update .pre-commit-config.yaml (#1026)

This commit updates the pre-commit hook configuration; the hunks below show its repo-wide effect: double-quoted string literals are normalized to single quotes, while strings that themselves contain a single quote keep their double quotes (the behavior of a quote-normalization hook such as pre-commit-hooks' double-quote-string-fixer).
@@ -127,11 +127,11 @@ class AutoBackend(nn.Module):
             w = next(Path(w).glob('*.xml'))  # get *.xml file from *_openvino_model dir
             network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
             if network.get_parameters()[0].get_layout().empty:
-                network.get_parameters()[0].set_layout(Layout("NCHW"))
+                network.get_parameters()[0].set_layout(Layout('NCHW'))
             batch_dim = get_batch(network)
             if batch_dim.is_static:
                 batch_size = batch_dim.get_length()
-            executable_network = ie.compile_model(network, device_name="CPU")  # device_name="MYRIAD" for Intel NCS2
+            executable_network = ie.compile_model(network, device_name='CPU')  # device_name="MYRIAD" for Intel NCS2
         elif engine:  # TensorRT
             LOGGER.info(f'Loading {w} for TensorRT inference...')
             import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
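As context for this hunk: `ie` is an openvino.runtime Core instance created earlier in __init__. Below is a minimal, self-contained sketch of the same load-and-compile flow; the model directory name is hypothetical, and openvino>=2022.1 is assumed.

# Sketch of the OpenVINO load path above (assumes openvino>=2022.1 installed)
from pathlib import Path

from openvino.runtime import Core, Layout

w = next(Path('yolov8n_openvino_model').glob('*.xml'))  # hypothetical *_openvino_model dir
ie = Core()
network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
if network.get_parameters()[0].get_layout().empty:
    network.get_parameters()[0].set_layout(Layout('NCHW'))  # assume NCHW when no layout is set
executable_network = ie.compile_model(network, device_name='CPU')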
@@ -184,7 +184,7 @@ class AutoBackend(nn.Module):
             import tensorflow as tf

             def wrap_frozen_graph(gd, inputs, outputs):
-                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
+                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), [])  # wrapped
                 ge = x.graph.as_graph_element
                 return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

@@ -198,7 +198,7 @@ class AutoBackend(nn.Module):
             gd = tf.Graph().as_graph_def()  # TF GraphDef
             with open(w, 'rb') as f:
                 gd.ParseFromString(f.read())
-            frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
+            frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
         elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
             try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                 from tflite_runtime.interpreter import Interpreter, load_delegate
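The wrap_frozen_graph helper in these two hunks turns a serialized GraphDef into a callable pruned function. A hedged end-to-end sketch of how it gets used: the file name, output tensor name, and input shape are illustrative, and the repo's gd_outputs helper (which discovers output names) is replaced here by a hard-coded name.

# Sketch: load a frozen TF graph and call it like a function (assumes tensorflow>=2)
import numpy as np
import tensorflow as tf

def wrap_frozen_graph(gd, inputs, outputs):
    x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), [])  # wrapped
    ge = x.graph.as_graph_element
    return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))

gd = tf.Graph().as_graph_def()
with open('model.pb', 'rb') as f:  # hypothetical frozen-graph file
    gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs='Identity:0')  # names are model-specific
im = np.zeros((1, 640, 640, 3), np.float32)  # example input; real shape depends on the graph
y = frozen_func(x=tf.constant(im))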
@@ -220,9 +220,9 @@ class AutoBackend(nn.Module):
             output_details = interpreter.get_output_details()  # outputs
             # load metadata
             with contextlib.suppress(zipfile.BadZipFile):
-                with zipfile.ZipFile(w, "r") as model:
+                with zipfile.ZipFile(w, 'r') as model:
                     meta_file = model.namelist()[0]
-                    meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
+                    meta = ast.literal_eval(model.read(meta_file).decode('utf-8'))
                     stride, names = int(meta['stride']), meta['names']
         elif tfjs:  # TF.js
             raise NotImplementedError('YOLOv8 TF.js inference is not supported')
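The metadata read above works because the exporter appends a small zip archive to the .tflite file, and contextlib.suppress makes the read a no-op for models without one. The same logic as a standalone helper; the file name and defaults are illustrative.

# Sketch: recover exporter-embedded metadata from a .tflite file (stdlib only)
import ast
import contextlib
import zipfile

def read_tflite_meta(w):
    # Return the embedded dict, or {} when the file carries no zip payload
    with contextlib.suppress(zipfile.BadZipFile):
        with zipfile.ZipFile(w, 'r') as model:
            meta_file = model.namelist()[0]
            return ast.literal_eval(model.read(meta_file).decode('utf-8'))
    return {}

meta = read_tflite_meta('yolov8n.tflite')  # hypothetical path
stride, names = int(meta.get('stride', 32)), meta.get('names', {})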
@@ -251,8 +251,8 @@ class AutoBackend(nn.Module):
         else:
             from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_TABLE
             raise TypeError(f"model='{w}' is not a supported model format. "
-                            "See https://docs.ultralytics.com/tasks/detection/#export for help."
-                            f"\n\n{EXPORT_FORMATS_TABLE}")
+                            'See https://docs.ultralytics.com/tasks/detection/#export for help.'
+                            f'\n\n{EXPORT_FORMATS_TABLE}')

         # Load external metadata YAML
         if xml or saved_model or paddle:
@@ -410,5 +410,5 @@ class AutoBackend(nn.Module):
         url = urlparse(p)  # if url may be Triton inference server
         types = [s in Path(p).name for s in sf]
         types[8] &= not types[9]  # tflite &= not edgetpu
-        triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
+        triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
         return types + [triton]
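In the hunk above, `triton` is set when the path matches no known export suffix but parses as an http/grpc URL with a network location, i.e. a Triton Inference Server endpoint. A quick demonstration of that urlparse check; the endpoints are hypothetical.

# Sketch: the URL test behind the triton flag (stdlib only)
from urllib.parse import urlparse

def looks_like_triton(p):
    url = urlparse(p)
    return any(s in url.scheme for s in ['http', 'grpc']) and bool(url.netloc)

print(looks_like_triton('http://localhost:8000/yolov8n'))  # True
print(looks_like_triton('grpc://localhost:8001/yolov8n'))  # True
print(looks_like_triton('yolov8n.onnx'))                   # False: no scheme or netloc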
@@ -99,7 +99,7 @@ class AutoShape(nn.Module):
                 shape1.append([y * g for y in s])
                 ims[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
             shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size  # inf shape
-            x = [LetterBox(shape1, auto=False)(image=im)["img"] for im in ims]  # pad
+            x = [LetterBox(shape1, auto=False)(image=im)['img'] for im in ims]  # pad
             x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
             x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32

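The AutoShape hunk above letterboxes every image to one inference shape, stacks them into a batch, reorders BHWC to BCHW, and rescales uint8 pixels to 0-1 floats. The layout-and-dtype part in isolation, assuming the images already share a shape so no letterboxing is needed:

# Sketch: BHWC -> BCHW stacking and uint8 -> float normalization
import numpy as np
import torch

ims = [np.random.randint(0, 255, (640, 480, 3), dtype=np.uint8) for _ in range(2)]  # two HWC images
x = np.ascontiguousarray(np.array(ims).transpose((0, 3, 1, 2)))  # stack, then BHWC -> BCHW
x = torch.from_numpy(x).float() / 255  # uint8 -> fp32 in [0, 1]
print(x.shape)  # torch.Size([2, 3, 640, 480])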
@@ -160,7 +160,7 @@ class BaseModel(nn.Module):
             weights (str): The weights to load into the model.
         """
         # Force all tasks to implement this function
-        raise NotImplementedError("This function needs to be implemented by derived classes!")
+        raise NotImplementedError('This function needs to be implemented by derived classes!')


 class DetectionModel(BaseModel):
@@ -249,7 +249,7 @@ class SegmentationModel(DetectionModel):
         super().__init__(cfg, ch, nc, verbose)

     def _forward_augment(self, x):
-        raise NotImplementedError("WARNING ⚠️ SegmentationModel has not supported augment inference yet!")
+        raise NotImplementedError('WARNING ⚠️ SegmentationModel has not supported augment inference yet!')


 class ClassificationModel(BaseModel):
@@ -292,7 +292,7 @@ class ClassificationModel(BaseModel):
         self.info()

     def load(self, weights):
-        model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
+        model = weights['model'] if isinstance(weights, dict) else weights  # torchvision models are not dicts
         csd = model.float().state_dict()
         csd = intersect_dicts(csd, self.state_dict())  # intersect
         self.load_state_dict(csd, strict=False)  # load
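ClassificationModel.load above keeps only checkpoint tensors whose names and shapes match the current model, and strict=False tolerates everything else. Assuming intersect_dicts behaves roughly as below (the real helper also supports an exclude list), a minimal sketch:

# Sketch: intersect two state_dicts on matching keys and tensor shapes
import torch.nn as nn

def intersect_dicts(da, db):
    return {k: v for k, v in da.items() if k in db and v.shape == db[k].shape}

a = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2)).state_dict()
b = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3)).state_dict()
print(sorted(intersect_dicts(a, b)))  # ['0.bias', '0.weight']: only the matching first layer survives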
@@ -341,10 +341,10 @@ def torch_safe_load(weight):
         return torch.load(file, map_location='cpu')  # load
     except ModuleNotFoundError as e:
         if e.name == 'omegaconf':  # e.name is missing module name
-            LOGGER.warning(f"WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements."
-                           f"\nAutoInstall will run now for {e.name} but this feature will be removed in the future."
-                           f"\nRecommend fixes are to train a new model using updated ultralytics package or to "
-                           f"download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
+            LOGGER.warning(f'WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements.'
+                           f'\nAutoInstall will run now for {e.name} but this feature will be removed in the future.'
+                           f'\nRecommend fixes are to train a new model using updated ultralytics package or to '
+                           f'download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0')
         if e.name != 'models':
             check_requirements(e.name)  # install missing module
         return torch.load(file, map_location='cpu')  # load
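torch_safe_load above exists so that checkpoints referencing a module that is no longer shipped (omegaconf in older exports) trigger an install-and-retry instead of a hard crash. Its control flow reduced to a skeleton; a plain pip subprocess stands in for the repo's check_requirements helper:

# Sketch: load -> install the missing module -> retry once
import subprocess
import sys

import torch

def safe_load(file):
    try:
        return torch.load(file, map_location='cpu')
    except ModuleNotFoundError as e:  # e.name is the missing module's name
        subprocess.run([sys.executable, '-m', 'pip', 'install', e.name], check=True)
        return torch.load(file, map_location='cpu')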
@@ -489,13 +489,13 @@ def guess_model_task(model):

     def cfg2task(cfg):
         # Guess from YAML dictionary
-        m = cfg["head"][-1][-2].lower()  # output module name
-        if m in ["classify", "classifier", "cls", "fc"]:
-            return "classify"
-        if m in ["detect"]:
-            return "detect"
-        if m in ["segment"]:
-            return "segment"
+        m = cfg['head'][-1][-2].lower()  # output module name
+        if m in ['classify', 'classifier', 'cls', 'fc']:
+            return 'classify'
+        if m in ['detect']:
+            return 'detect'
+        if m in ['segment']:
+            return 'segment'

     # Guess from model cfg
     if isinstance(model, dict):
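cfg2task above inspects the last entry of the YAML 'head' list; each entry has the form [from, repeats, module, args], so [-1][-2] is the name of the final output module. A worked example on a hypothetical minimal cfg:

# Sketch: how cfg2task reads the output module name out of a model YAML dict
cfg = {'head': [[[15, 18, 21], 1, 'Detect', ['nc']]]}  # hypothetical single-entry head
m = cfg['head'][-1][-2].lower()
print(m)  # detect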
@@ -513,22 +513,22 @@ def guess_model_task(model):

         for m in model.modules():
             if isinstance(m, Detect):
-                return "detect"
+                return 'detect'
             elif isinstance(m, Segment):
-                return "segment"
+                return 'segment'
             elif isinstance(m, Classify):
-                return "classify"
+                return 'classify'

     # Guess from model filename
     if isinstance(model, (str, Path)):
         model = Path(model).stem
         if '-seg' in model:
-            return "segment"
+            return 'segment'
         elif '-cls' in model:
-            return "classify"
+            return 'classify'
         else:
-            return "detect"
+            return 'detect'

     # Unable to determine task from model
-    raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
+    raise SyntaxError('YOLO is unable to automatically guess model task. Explicitly define task for your model, '
                       "i.e. 'task=detect', 'task=segment' or 'task=classify'.")
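The filename fallback above relies on Ultralytics' naming convention, where segmentation and classification checkpoints carry '-seg' and '-cls' stem suffixes. A compact demonstration of the same checks:

# Sketch: the stem-suffix task guess (stdlib only)
from pathlib import Path

def task_from_name(model):
    stem = Path(model).stem
    return 'segment' if '-seg' in stem else 'classify' if '-cls' in stem else 'detect'

print(task_from_name('yolov8n-seg.pt'))  # segment
print(task_from_name('yolov8n-cls.pt'))  # classify
print(task_from_name('yolov8n.pt'))      # detect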