|
|
|
@ -11,8 +11,8 @@ import torch.nn as nn
|
|
|
|
|
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
|
|
|
|
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
|
|
|
|
GhostBottleneck, GhostConv, Segment)
|
|
|
|
|
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, colorstr, emojis, yaml_load
|
|
|
|
|
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
|
|
|
|
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
|
|
|
|
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
|
|
|
|
|
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
|
|
|
|
|
intersect_dicts, make_divisible, model_info, scale_img, time_sync)
|
|
|
|
|
|
|
|
|
@ -151,15 +151,19 @@ class BaseModel(nn.Module):
|
|
|
|
|
m.strides = fn(m.strides)
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def load(self, weights):
|
|
|
|
|
"""
|
|
|
|
|
This function loads the weights of the model from a file
|
|
|
|
|
def load(self, weights, verbose=True):
|
|
|
|
|
"""Load the weights into the model.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
weights (str): The weights to load into the model.
|
|
|
|
|
weights (dict) or (torch.nn.Module): The pre-trained weights to be loaded.
|
|
|
|
|
verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
|
|
|
|
|
"""
|
|
|
|
|
# Force all tasks to implement this function
|
|
|
|
|
raise NotImplementedError('This function needs to be implemented by derived classes!')
|
|
|
|
|
model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts
|
|
|
|
|
csd = model.float().state_dict() # checkpoint state_dict as FP32
|
|
|
|
|
csd = intersect_dicts(csd, self.state_dict()) # intersect
|
|
|
|
|
self.load_state_dict(csd, strict=False) # load
|
|
|
|
|
if verbose:
|
|
|
|
|
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DetectionModel(BaseModel):
|
|
|
|
@ -234,13 +238,6 @@ class DetectionModel(BaseModel):
|
|
|
|
|
y[-1] = y[-1][..., i:] # small
|
|
|
|
|
return y
|
|
|
|
|
|
|
|
|
|
def load(self, weights, verbose=True):
|
|
|
|
|
csd = weights.float().state_dict() # checkpoint state_dict as FP32
|
|
|
|
|
csd = intersect_dicts(csd, self.state_dict()) # intersect
|
|
|
|
|
self.load_state_dict(csd, strict=False) # load
|
|
|
|
|
if verbose and RANK == -1:
|
|
|
|
|
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SegmentationModel(DetectionModel):
|
|
|
|
|
# YOLOv8 segmentation model
|
|
|
|
@ -293,12 +290,6 @@ class ClassificationModel(BaseModel):
|
|
|
|
|
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
|
|
|
|
|
self.info()
|
|
|
|
|
|
|
|
|
|
def load(self, weights):
|
|
|
|
|
model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts
|
|
|
|
|
csd = model.float().state_dict()
|
|
|
|
|
csd = intersect_dicts(csd, self.state_dict()) # intersect
|
|
|
|
|
self.load_state_dict(csd, strict=False) # load
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def reshape_outputs(model, nc):
|
|
|
|
|
# Update a TorchVision classification model to class count 'n' if required
|
|
|
|
@ -338,6 +329,7 @@ def torch_safe_load(weight):
|
|
|
|
|
"""
|
|
|
|
|
from ultralytics.yolo.utils.downloads import attempt_download_asset
|
|
|
|
|
|
|
|
|
|
check_suffix(file=weight, suffix='.pt')
|
|
|
|
|
file = attempt_download_asset(weight) # search online if missing locally
|
|
|
|
|
try:
|
|
|
|
|
return torch.load(file, map_location='cpu'), file # load
|
|
|
|
|