ultralytics 8.0.54
TFLite export improvements and fixes (#1447)
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -411,12 +411,12 @@ class Detect(nn.Module):
             self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
             self.shape = shape
 
-        if self.export and self.format == 'edgetpu':  # FlexSplitV ops issue
-            x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
+        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
+        if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV ops
             box = x_cat[:, :self.reg_max * 4]
             cls = x_cat[:, self.reg_max * 4:]
         else:
-            box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
+            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
         dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
         y = torch.cat((dbox, cls.sigmoid()), 1)
         return y if self.export else (y, x)
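The Detect.forward change above moves the concatenation out of the Edge-TPU-only branch and widens the export check to all TensorFlow-based formats ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'), so box and class channels are separated by slicing rather than torch.split during those exports. A minimal sketch (shapes and values assumed, not taken from the diff) of why the slice path and the split path are interchangeable:

# Minimal sketch: slicing and torch.split give identical box/cls tensors;
# the slice form simply avoids the split op that TFLite lowers to FlexSplitV.
import torch

reg_max, nc = 16, 80               # assumed YOLOv8 defaults
no = reg_max * 4 + nc              # channels per anchor point
x_cat = torch.randn(1, no, 8400)   # (batch, no, num_anchors), dummy tensor

box_a = x_cat[:, :reg_max * 4]     # slice path (TF export branch)
cls_a = x_cat[:, reg_max * 4:]

box_b, cls_b = x_cat.split((reg_max * 4, nc), 1)  # split path (default branch)

assert torch.equal(box_a, box_b) and torch.equal(cls_a, cls_b)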
@@ -11,8 +11,8 @@ import torch.nn as nn
 
 from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
                                     Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
                                     GhostBottleneck, GhostConv, Segment)
-from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, colorstr, emojis, yaml_load
-from ultralytics.yolo.utils.checks import check_requirements, check_yaml
+from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
+from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
 from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
                                                 intersect_dicts, make_divisible, model_info, scale_img, time_sync)
@@ -151,15 +151,19 @@ class BaseModel(nn.Module):
             m.strides = fn(m.strides)
         return self
 
-    def load(self, weights):
-        """
-        This function loads the weights of the model from a file
+    def load(self, weights, verbose=True):
+        """Load the weights into the model.
 
         Args:
-            weights (str): The weights to load into the model.
+            weights (dict) or (torch.nn.Module): The pre-trained weights to be loaded.
+            verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
         """
-        # Force all tasks to implement this function
-        raise NotImplementedError('This function needs to be implemented by derived classes!')
+        model = weights['model'] if isinstance(weights, dict) else weights  # torchvision models are not dicts
+        csd = model.float().state_dict()  # checkpoint state_dict as FP32
+        csd = intersect_dicts(csd, self.state_dict())  # intersect
+        self.load_state_dict(csd, strict=False)  # load
+        if verbose:
+            LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
 
 
 class DetectionModel(BaseModel):
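The new BaseModel.load above centralizes the weight-transfer logic (take the checkpoint's state_dict as FP32, keep only the keys that intersect with the current model, load non-strictly, optionally log the transfer count), which is why the two hunks that follow delete the per-task load overrides in DetectionModel and ClassificationModel. An illustrative stand-in for the same intersect-then-load pattern on plain torch.nn modules (the dict comprehension below mimics, but is not, ultralytics' intersect_dicts helper):

# Stand-in illustration of BaseModel.load's pattern, not the ultralytics code.
import torch.nn as nn

donor = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))   # "pretrained" source
target = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3))  # new head, different shape

csd = donor.float().state_dict()                          # checkpoint state_dict as FP32
tsd = target.state_dict()
csd = {k: v for k, v in csd.items() if k in tsd and v.shape == tsd[k].shape}  # "intersect"
target.load_state_dict(csd, strict=False)                 # load only matching tensors
print(f'Transferred {len(csd)}/{len(tsd)} items from pretrained weights')     # -> 2/4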
@@ -234,13 +238,6 @@ class DetectionModel(BaseModel):
             y[-1] = y[-1][..., i:]  # small
         return y
 
-    def load(self, weights, verbose=True):
-        csd = weights.float().state_dict()  # checkpoint state_dict as FP32
-        csd = intersect_dicts(csd, self.state_dict())  # intersect
-        self.load_state_dict(csd, strict=False)  # load
-        if verbose and RANK == -1:
-            LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
-
 
 class SegmentationModel(DetectionModel):
     # YOLOv8 segmentation model
@@ -293,12 +290,6 @@ class ClassificationModel(BaseModel):
         self.names = {i: f'{i}' for i in range(self.yaml['nc'])}  # default names dict
         self.info()
 
-    def load(self, weights):
-        model = weights['model'] if isinstance(weights, dict) else weights  # torchvision models are not dicts
-        csd = model.float().state_dict()
-        csd = intersect_dicts(csd, self.state_dict())  # intersect
-        self.load_state_dict(csd, strict=False)  # load
-
     @staticmethod
     def reshape_outputs(model, nc):
         # Update a TorchVision classification model to class count 'n' if required
@@ -338,6 +329,7 @@ def torch_safe_load(weight):
     """
     from ultralytics.yolo.utils.downloads import attempt_download_asset
 
+    check_suffix(file=weight, suffix='.pt')
     file = attempt_download_asset(weight)  # search online if missing locally
     try:
         return torch.load(file, map_location='cpu'), file  # load
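The single added line above, check_suffix(file=weight, suffix='.pt'), rejects files that are not .pt checkpoints before any download or torch.load attempt. A hedged sketch of the kind of guard this implies (illustrative only, not the ultralytics implementation of check_suffix):

# Illustrative suffix guard with assumed behavior; not ultralytics' check_suffix.
from pathlib import Path

def ensure_suffix(file: str, suffix: str = '.pt') -> None:
    """Raise early with a clear message if the file extension is unexpected."""
    actual = Path(file).suffix.lower()
    if actual != suffix:
        raise ValueError(f"'{file}' has unsupported suffix '{actual}', expected '{suffix}'")

ensure_suffix('yolov8n.pt')       # passes
# ensure_suffix('yolov8n.onnx')   # would raise ValueError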