ultralytics 8.0.136
refactor and simplify package (#3748)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -16,10 +16,10 @@ import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.yolo.utils import ARM64, LINUX, LOGGER, ROOT, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version, check_yaml
|
||||
from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
|
||||
from ultralytics.yolo.utils.ops import xywh2xyxy
|
||||
from ultralytics.utils import ARM64, LINUX, LOGGER, ROOT, yaml_load
|
||||
from ultralytics.utils.checks import check_requirements, check_suffix, check_version, check_yaml
|
||||
from ultralytics.utils.downloads import attempt_download_asset, is_url
|
||||
from ultralytics.utils.ops import xywh2xyxy
|
||||
|
||||
|
||||
def check_class_names(names):
|
||||
@ -34,7 +34,7 @@ def check_class_names(names):
|
||||
raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
|
||||
f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.')
|
||||
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
|
||||
map = yaml_load(ROOT / 'datasets/ImageNet.yaml')['map'] # human-readable names
|
||||
map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map'] # human-readable names
|
||||
names = {k: map[v] for k, v in names.items()}
|
||||
return names
|
||||
|
||||
@ -210,7 +210,7 @@ class AutoBackend(nn.Module):
|
||||
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
|
||||
import tensorflow as tf
|
||||
|
||||
from ultralytics.yolo.engine.exporter import gd_outputs
|
||||
from ultralytics.engine.exporter import gd_outputs
|
||||
|
||||
def wrap_frozen_graph(gd, inputs, outputs):
|
||||
"""Wrap frozen graphs for deployment."""
|
||||
@ -284,7 +284,7 @@ class AutoBackend(nn.Module):
|
||||
"""
|
||||
raise NotImplementedError('Triton Inference Server is not currently supported.')
|
||||
else:
|
||||
from ultralytics.yolo.engine.exporter import export_formats
|
||||
from ultralytics.engine.exporter import export_formats
|
||||
raise TypeError(f"model='{w}' is not a supported model format. "
|
||||
'See https://docs.ultralytics.com/modes/predict for help.'
|
||||
f'\n\n{export_formats()}')
|
||||
@ -476,7 +476,7 @@ class AutoBackend(nn.Module):
|
||||
"""
|
||||
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
|
||||
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
|
||||
from ultralytics.yolo.engine.exporter import export_formats
|
||||
from ultralytics.engine.exporter import export_formats
|
||||
sf = list(export_formats().Suffix) # export suffixes
|
||||
if not is_url(p, check=False) and not isinstance(p, str):
|
||||
check_suffix(p, sf) # checks
|
||||
|
@ -9,7 +9,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.init import constant_, xavier_uniform_
|
||||
|
||||
from ultralytics.yolo.utils.tal import dist2bbox, make_anchors
|
||||
from ultralytics.utils.tal import dist2bbox, make_anchors
|
||||
|
||||
from .block import DFL, Proto
|
||||
from .conv import Conv
|
||||
@ -219,7 +219,7 @@ class RTDETRDecoder(nn.Module):
|
||||
self._reset_parameters()
|
||||
|
||||
def forward(self, x, batch=None):
|
||||
from ultralytics.vit.utils.ops import get_cdn_group
|
||||
from ultralytics.models.utils.ops import get_cdn_group
|
||||
|
||||
# input projection and embedding
|
||||
feats, shapes = self._get_encoder_input(x)
|
||||
|
@ -11,12 +11,12 @@ from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottlenec
|
||||
Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
|
||||
Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
|
||||
RTDETRDecoder, Segment)
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
|
||||
from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
|
||||
from ultralytics.yolo.utils.plotting import feature_visualization
|
||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
|
||||
intersect_dicts, make_divisible, model_info, scale_img, time_sync)
|
||||
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
||||
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
|
||||
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
|
||||
from ultralytics.utils.plotting import feature_visualization
|
||||
from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts,
|
||||
make_divisible, model_info, scale_img, time_sync)
|
||||
|
||||
try:
|
||||
import thop
|
||||
@ -412,7 +412,7 @@ class RTDETRDetectionModel(DetectionModel):
|
||||
|
||||
def init_criterion(self):
|
||||
"""Compute the classification loss between predictions and true labels."""
|
||||
from ultralytics.vit.utils.loss import RTDETRDetectionLoss
|
||||
from ultralytics.models.utils.loss import RTDETRDetectionLoss
|
||||
|
||||
return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
|
||||
|
||||
@ -498,6 +498,45 @@ class Ensemble(nn.ModuleList):
|
||||
# Functions ------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temporary_modules(modules=None):
|
||||
"""
|
||||
Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
|
||||
|
||||
This function can be used to change the module paths during runtime. It's useful when refactoring code,
|
||||
where you've moved a module from one location to another, but you still want to support the old import
|
||||
paths for backwards compatibility.
|
||||
|
||||
Args:
|
||||
modules (dict, optional): A dictionary mapping old module paths to new module paths.
|
||||
|
||||
Example:
|
||||
with temporary_modules({'old.module.path': 'new.module.path'}):
|
||||
import old.module.path # this will now import new.module.path
|
||||
|
||||
Note:
|
||||
The changes are only in effect inside the context manager and are undone once the context manager exits.
|
||||
Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
|
||||
applications or libraries. Use this function with caution.
|
||||
"""
|
||||
if not modules:
|
||||
modules = {}
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
try:
|
||||
# Set modules in sys.modules under their old name
|
||||
for old, new in modules.items():
|
||||
sys.modules[old] = importlib.import_module(new)
|
||||
|
||||
yield
|
||||
finally:
|
||||
# Remove the temporary module paths
|
||||
for old in modules:
|
||||
if old in sys.modules:
|
||||
del sys.modules[old]
|
||||
|
||||
|
||||
def torch_safe_load(weight):
|
||||
"""
|
||||
This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
|
||||
@ -510,12 +549,17 @@ def torch_safe_load(weight):
|
||||
Returns:
|
||||
(dict): The loaded PyTorch model.
|
||||
"""
|
||||
from ultralytics.yolo.utils.downloads import attempt_download_asset
|
||||
from ultralytics.utils.downloads import attempt_download_asset
|
||||
|
||||
check_suffix(file=weight, suffix='.pt')
|
||||
file = attempt_download_asset(weight) # search online if missing locally
|
||||
try:
|
||||
return torch.load(file, map_location='cpu'), file # load
|
||||
with temporary_modules({
|
||||
'ultralytics.yolo.utils': 'ultralytics.utils',
|
||||
'ultralytics.yolo.v8': 'ultralytics.models.yolo',
|
||||
'ultralytics.yolo.data': 'ultralytics.data'}): # for legacy 8.0 Classify and Pose models
|
||||
return torch.load(file, map_location='cpu'), file # load
|
||||
|
||||
except ModuleNotFoundError as e: # e.name is missing module name
|
||||
if e.name == 'models':
|
||||
raise TypeError(
|
||||
|
Reference in New Issue
Block a user