ultralytics 8.0.80
single-line docstring fixes (#2060)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -21,11 +21,11 @@ from ultralytics.yolo.utils.ops import xywh2xyxy
|
||||
|
||||
|
||||
def check_class_names(names):
|
||||
# Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts.
|
||||
"""Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts."""
|
||||
if isinstance(names, list): # names is a list
|
||||
names = dict(enumerate(names)) # convert to dict
|
||||
if isinstance(names, dict):
|
||||
# convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
|
||||
# Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
|
||||
names = {int(k): str(v) for k, v in names.items()}
|
||||
n = len(names)
|
||||
if max(names.keys()) >= n:
|
||||
@ -229,7 +229,7 @@ class AutoBackend(nn.Module):
|
||||
interpreter.allocate_tensors() # allocate
|
||||
input_details = interpreter.get_input_details() # inputs
|
||||
output_details = interpreter.get_output_details() # outputs
|
||||
# load metadata
|
||||
# Load metadata
|
||||
with contextlib.suppress(zipfile.BadZipFile):
|
||||
with zipfile.ZipFile(w, 'r') as model:
|
||||
meta_file = model.namelist()[0]
|
||||
|
@ -24,7 +24,7 @@ from ultralytics.yolo.utils.torch_utils import copy_attr, smart_inference_mode
|
||||
|
||||
|
||||
class AutoShape(nn.Module):
|
||||
# YOLOv8 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
|
||||
"""YOLOv8 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS."""
|
||||
conf = 0.25 # NMS confidence threshold
|
||||
iou = 0.45 # NMS IoU threshold
|
||||
agnostic = False # NMS class-agnostic
|
||||
@ -47,7 +47,7 @@ class AutoShape(nn.Module):
|
||||
m.export = True # do not output loss values
|
||||
|
||||
def _apply(self, fn):
|
||||
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
||||
"""Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers."""
|
||||
self = super()._apply(fn)
|
||||
if self.pt:
|
||||
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
||||
@ -59,7 +59,7 @@ class AutoShape(nn.Module):
|
||||
|
||||
@smart_inference_mode()
|
||||
def forward(self, ims, size=640, augment=False, profile=False):
|
||||
# Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
|
||||
"""Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:."""
|
||||
# file: ims = 'data/images/zidane.jpg' # str or PosixPath
|
||||
# URI: = 'https://ultralytics.com/images/zidane.jpg'
|
||||
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
|
||||
@ -202,7 +202,7 @@ class Detections:
|
||||
return self.ims
|
||||
|
||||
def pandas(self):
|
||||
# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
|
||||
"""Return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])."""
|
||||
import pandas
|
||||
new = copy(self) # return copy
|
||||
ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
|
||||
@ -213,7 +213,7 @@ class Detections:
|
||||
return new
|
||||
|
||||
def tolist(self):
|
||||
# return a list of Detections objects, i.e. 'for result in results.tolist():'
|
||||
"""Return a list of Detections objects, i.e. 'for result in results.tolist():'."""
|
||||
r = range(self.n) # iterable
|
||||
x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
|
||||
# for d in x:
|
||||
|
@ -12,7 +12,7 @@ from ultralytics.yolo.utils.tal import dist2bbox, make_anchors
|
||||
|
||||
|
||||
def autopad(k, p=None, d=1): # kernel, padding, dilation
|
||||
# Pad to 'same' shape outputs
|
||||
"""Pad to 'same' shape outputs."""
|
||||
if d > 1:
|
||||
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
|
||||
if p is None:
|
||||
@ -21,7 +21,7 @@ def autopad(k, p=None, d=1): # kernel, padding, dilation
|
||||
|
||||
|
||||
class Conv(nn.Module):
|
||||
# Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
|
||||
"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
|
||||
default_act = nn.SiLU() # default activation
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
|
||||
@ -38,19 +38,21 @@ class Conv(nn.Module):
|
||||
|
||||
|
||||
class DWConv(Conv):
|
||||
# Depth-wise convolution
|
||||
"""Depth-wise convolution."""
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
|
||||
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
|
||||
|
||||
|
||||
class DWConvTranspose2d(nn.ConvTranspose2d):
|
||||
# Depth-wise transpose convolution
|
||||
"""Depth-wise transpose convolution."""
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
|
||||
super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
|
||||
|
||||
|
||||
class ConvTranspose(nn.Module):
|
||||
# Convolution transpose 2d layer
|
||||
"""Convolution transpose 2d layer."""
|
||||
default_act = nn.SiLU() # default activation
|
||||
|
||||
def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
|
||||
@ -67,8 +69,11 @@ class ConvTranspose(nn.Module):
|
||||
|
||||
|
||||
class DFL(nn.Module):
|
||||
# Integral module of Distribution Focal Loss (DFL)
|
||||
# Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
|
||||
"""
|
||||
Integral module of Distribution Focal Loss (DFL).
|
||||
Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
|
||||
"""
|
||||
|
||||
def __init__(self, c1=16):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
|
||||
@ -83,7 +88,8 @@ class DFL(nn.Module):
|
||||
|
||||
|
||||
class TransformerLayer(nn.Module):
|
||||
# Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
|
||||
"""Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)."""
|
||||
|
||||
def __init__(self, c, num_heads):
|
||||
super().__init__()
|
||||
self.q = nn.Linear(c, c, bias=False)
|
||||
@ -100,7 +106,8 @@ class TransformerLayer(nn.Module):
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
# Vision Transformer https://arxiv.org/abs/2010.11929
|
||||
"""Vision Transformer https://arxiv.org/abs/2010.11929."""
|
||||
|
||||
def __init__(self, c1, c2, num_heads, num_layers):
|
||||
super().__init__()
|
||||
self.conv = None
|
||||
@ -119,7 +126,8 @@ class TransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
# Standard bottleneck
|
||||
"""Standard bottleneck."""
|
||||
|
||||
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, groups, kernels, expand
|
||||
super().__init__()
|
||||
c_ = int(c2 * e) # hidden channels
|
||||
@ -132,7 +140,8 @@ class Bottleneck(nn.Module):
|
||||
|
||||
|
||||
class BottleneckCSP(nn.Module):
|
||||
# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
|
||||
"""CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks."""
|
||||
|
||||
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
||||
super().__init__()
|
||||
c_ = int(c2 * e) # hidden channels
|
||||
@ -151,7 +160,8 @@ class BottleneckCSP(nn.Module):
|
||||
|
||||
|
||||
class C3(nn.Module):
|
||||
# CSP Bottleneck with 3 convolutions
|
||||
"""CSP Bottleneck with 3 convolutions."""
|
||||
|
||||
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
||||
super().__init__()
|
||||
c_ = int(c2 * e) # hidden channels
|
||||
@ -165,7 +175,8 @@ class C3(nn.Module):
|
||||
|
||||
|
||||
class C2(nn.Module):
|
||||
# CSP Bottleneck with 2 convolutions
|
||||
"""CSP Bottleneck with 2 convolutions."""
|
||||
|
||||
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
||||
super().__init__()
|
||||
self.c = int(c2 * e) # hidden channels
|
||||
@ -180,7 +191,8 @@ class C2(nn.Module):
|
||||
|
||||
|
||||
class C2f(nn.Module):
|
||||
# CSP Bottleneck with 2 convolutions
|
||||
"""CSP Bottleneck with 2 convolutions."""
|
||||
|
||||
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
||||
super().__init__()
|
||||
self.c = int(c2 * e) # hidden channels
|
||||
@ -200,7 +212,8 @@ class C2f(nn.Module):
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
# Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet
|
||||
"""Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""
|
||||
|
||||
def __init__(self, channels: int) -> None:
|
||||
super().__init__()
|
||||
self.pool = nn.AdaptiveAvgPool2d(1)
|
||||
@ -212,7 +225,8 @@ class ChannelAttention(nn.Module):
|
||||
|
||||
|
||||
class SpatialAttention(nn.Module):
|
||||
# Spatial-attention module
|
||||
"""Spatial-attention module."""
|
||||
|
||||
def __init__(self, kernel_size=7):
|
||||
super().__init__()
|
||||
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
|
||||
@ -225,7 +239,8 @@ class SpatialAttention(nn.Module):
|
||||
|
||||
|
||||
class CBAM(nn.Module):
|
||||
# Convolutional Block Attention Module
|
||||
"""Convolutional Block Attention Module."""
|
||||
|
||||
def __init__(self, c1, kernel_size=7): # ch_in, kernels
|
||||
super().__init__()
|
||||
self.channel_attention = ChannelAttention(c1)
|
||||
@ -236,7 +251,8 @@ class CBAM(nn.Module):
|
||||
|
||||
|
||||
class C1(nn.Module):
|
||||
# CSP Bottleneck with 1 convolution
|
||||
"""CSP Bottleneck with 1 convolution."""
|
||||
|
||||
def __init__(self, c1, c2, n=1): # ch_in, ch_out, number
|
||||
super().__init__()
|
||||
self.cv1 = Conv(c1, c2, 1, 1)
|
||||
@ -248,7 +264,8 @@ class C1(nn.Module):
|
||||
|
||||
|
||||
class C3x(C3):
|
||||
# C3 module with cross-convolutions
|
||||
"""C3 module with cross-convolutions."""
|
||||
|
||||
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
||||
super().__init__(c1, c2, n, shortcut, g, e)
|
||||
self.c_ = int(c2 * e)
|
||||
@ -256,7 +273,8 @@ class C3x(C3):
|
||||
|
||||
|
||||
class C3TR(C3):
|
||||
# C3 module with TransformerBlock()
|
||||
"""C3 module with TransformerBlock()."""
|
||||
|
||||
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
||||
super().__init__(c1, c2, n, shortcut, g, e)
|
||||
c_ = int(c2 * e)
|
||||
@ -264,7 +282,8 @@ class C3TR(C3):
|
||||
|
||||
|
||||
class C3Ghost(C3):
|
||||
# C3 module with GhostBottleneck()
|
||||
"""C3 module with GhostBottleneck()."""
|
||||
|
||||
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
|
||||
super().__init__(c1, c2, n, shortcut, g, e)
|
||||
c_ = int(c2 * e) # hidden channels
|
||||
@ -272,7 +291,8 @@ class C3Ghost(C3):
|
||||
|
||||
|
||||
class SPP(nn.Module):
|
||||
# Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
|
||||
"""Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""
|
||||
|
||||
def __init__(self, c1, c2, k=(5, 9, 13)):
|
||||
super().__init__()
|
||||
c_ = c1 // 2 # hidden channels
|
||||
@ -286,7 +306,8 @@ class SPP(nn.Module):
|
||||
|
||||
|
||||
class SPPF(nn.Module):
|
||||
# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
|
||||
"""Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""
|
||||
|
||||
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
|
||||
super().__init__()
|
||||
c_ = c1 // 2 # hidden channels
|
||||
@ -302,7 +323,8 @@ class SPPF(nn.Module):
|
||||
|
||||
|
||||
class Focus(nn.Module):
|
||||
# Focus wh information into c-space
|
||||
"""Focus wh information into c-space."""
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
|
||||
super().__init__()
|
||||
self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)
|
||||
@ -314,7 +336,8 @@ class Focus(nn.Module):
|
||||
|
||||
|
||||
class GhostConv(nn.Module):
|
||||
# Ghost Convolution https://github.com/huawei-noah/ghostnet
|
||||
"""Ghost Convolution https://github.com/huawei-noah/ghostnet."""
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
|
||||
super().__init__()
|
||||
c_ = c2 // 2 # hidden channels
|
||||
@ -327,7 +350,8 @@ class GhostConv(nn.Module):
|
||||
|
||||
|
||||
class GhostBottleneck(nn.Module):
|
||||
# Ghost Bottleneck https://github.com/huawei-noah/ghostnet
|
||||
"""Ghost Bottleneck https://github.com/huawei-noah/ghostnet."""
|
||||
|
||||
def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
|
||||
super().__init__()
|
||||
c_ = c2 // 2
|
||||
@ -343,7 +367,8 @@ class GhostBottleneck(nn.Module):
|
||||
|
||||
|
||||
class Concat(nn.Module):
|
||||
# Concatenate a list of tensors along dimension
|
||||
"""Concatenate a list of tensors along dimension."""
|
||||
|
||||
def __init__(self, dimension=1):
|
||||
super().__init__()
|
||||
self.d = dimension
|
||||
@ -353,7 +378,8 @@ class Concat(nn.Module):
|
||||
|
||||
|
||||
class Proto(nn.Module):
|
||||
# YOLOv8 mask Proto module for segmentation models
|
||||
"""YOLOv8 mask Proto module for segmentation models."""
|
||||
|
||||
def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
|
||||
super().__init__()
|
||||
self.cv1 = Conv(c1, c_, k=3)
|
||||
@ -366,7 +392,8 @@ class Proto(nn.Module):
|
||||
|
||||
|
||||
class Ensemble(nn.ModuleList):
|
||||
# Ensemble of models
|
||||
"""Ensemble of models."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@ -382,7 +409,7 @@ class Ensemble(nn.ModuleList):
|
||||
|
||||
|
||||
class Detect(nn.Module):
|
||||
# YOLOv8 Detect head for detection models
|
||||
"""YOLOv8 Detect head for detection models."""
|
||||
dynamic = False # force grid reconstruction
|
||||
export = False # export mode
|
||||
shape = None
|
||||
@ -423,7 +450,7 @@ class Detect(nn.Module):
|
||||
return y if self.export else (y, x)
|
||||
|
||||
def bias_init(self):
|
||||
# Initialize Detect() biases, WARNING: requires stride availability
|
||||
"""Initialize Detect() biases, WARNING: requires stride availability."""
|
||||
m = self # self.model[-1] # Detect() module
|
||||
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
|
||||
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
|
||||
@ -433,7 +460,8 @@ class Detect(nn.Module):
|
||||
|
||||
|
||||
class Segment(Detect):
|
||||
# YOLOv8 Segment head for segmentation models
|
||||
"""YOLOv8 Segment head for segmentation models."""
|
||||
|
||||
def __init__(self, nc=80, nm=32, npr=256, ch=()):
|
||||
super().__init__(nc, ch)
|
||||
self.nm = nm # number of masks
|
||||
@ -456,7 +484,8 @@ class Segment(Detect):
|
||||
|
||||
|
||||
class Pose(Detect):
|
||||
# YOLOv8 Pose head for keypoints models
|
||||
"""YOLOv8 Pose head for keypoints models."""
|
||||
|
||||
def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
|
||||
super().__init__(nc, ch)
|
||||
self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
||||
@ -486,7 +515,8 @@ class Pose(Detect):
|
||||
|
||||
|
||||
class Classify(nn.Module):
|
||||
# YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)
|
||||
"""YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
|
||||
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
|
||||
super().__init__()
|
||||
c_ = 1280 # efficientnet_b0 size
|
||||
|
@ -167,7 +167,8 @@ class BaseModel(nn.Module):
|
||||
|
||||
|
||||
class DetectionModel(BaseModel):
|
||||
# YOLOv8 detection model
|
||||
"""YOLOv8 detection model."""
|
||||
|
||||
def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
|
||||
super().__init__()
|
||||
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
||||
@ -218,7 +219,7 @@ class DetectionModel(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def _descale_pred(p, flips, scale, img_size, dim=1):
|
||||
# de-scale predictions following augmented inference (inverse operation)
|
||||
"""De-scale predictions following augmented inference (inverse operation)."""
|
||||
p[:, :4] /= scale # de-scale
|
||||
x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
|
||||
if flips == 2:
|
||||
@ -228,7 +229,7 @@ class DetectionModel(BaseModel):
|
||||
return torch.cat((x, y, wh, cls), dim)
|
||||
|
||||
def _clip_augmented(self, y):
|
||||
# Clip YOLOv5 augmented inference tails
|
||||
"""Clip YOLOv5 augmented inference tails."""
|
||||
nl = self.model[-1].nl # number of detection layers (P3-P5)
|
||||
g = sum(4 ** x for x in range(nl)) # grid points
|
||||
e = 1 # exclude layer count
|
||||
@ -240,7 +241,8 @@ class DetectionModel(BaseModel):
|
||||
|
||||
|
||||
class SegmentationModel(DetectionModel):
|
||||
# YOLOv8 segmentation model
|
||||
"""YOLOv8 segmentation model."""
|
||||
|
||||
def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
@ -249,7 +251,8 @@ class SegmentationModel(DetectionModel):
|
||||
|
||||
|
||||
class PoseModel(DetectionModel):
|
||||
# YOLOv8 pose model
|
||||
"""YOLOv8 pose model."""
|
||||
|
||||
def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
|
||||
if not isinstance(cfg, dict):
|
||||
cfg = yaml_model_load(cfg) # load model YAML
|
||||
@ -260,7 +263,8 @@ class PoseModel(DetectionModel):
|
||||
|
||||
|
||||
class ClassificationModel(BaseModel):
|
||||
# YOLOv8 classification model
|
||||
"""YOLOv8 classification model."""
|
||||
|
||||
def __init__(self,
|
||||
cfg=None,
|
||||
model=None,
|
||||
@ -272,7 +276,7 @@ class ClassificationModel(BaseModel):
|
||||
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
|
||||
|
||||
def _from_detection_model(self, model, nc=1000, cutoff=10):
|
||||
# Create a YOLOv5 classification model from a YOLOv5 detection model
|
||||
"""Create a YOLOv5 classification model from a YOLOv5 detection model."""
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
if isinstance(model, AutoBackend):
|
||||
model = model.model # unwrap DetectMultiBackend
|
||||
@ -304,7 +308,7 @@ class ClassificationModel(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def reshape_outputs(model, nc):
|
||||
# Update a TorchVision classification model to class count 'n' if required
|
||||
"""Update a TorchVision classification model to class count 'n' if required."""
|
||||
name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
|
||||
if isinstance(m, Classify): # YOLO Classify() head
|
||||
if m.linear.out_features != nc:
|
||||
@ -363,7 +367,7 @@ def torch_safe_load(weight):
|
||||
|
||||
|
||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||
"""Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
|
||||
|
||||
ensemble = Ensemble()
|
||||
for w in weights if isinstance(weights, list) else [weights]:
|
||||
@ -403,7 +407,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
|
||||
|
||||
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
||||
# Loads a single model weights
|
||||
"""Loads a single model weights."""
|
||||
ckpt, weight = torch_safe_load(weight) # load ckpt
|
||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
@ -546,7 +550,7 @@ def guess_model_task(model):
|
||||
"""
|
||||
|
||||
def cfg2task(cfg):
|
||||
# Guess from YAML dictionary
|
||||
"""Guess from YAML dictionary."""
|
||||
m = cfg['head'][-1][-2].lower() # output module name
|
||||
if m in ('classify', 'classifier', 'cls', 'fc'):
|
||||
return 'classify'
|
||||
|
Reference in New Issue
Block a user