Add initial model interface (#30)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -8,7 +8,8 @@ from ultralytics.yolo.utils import LOGGER
|
||||
from ultralytics.yolo.utils.anchors import check_anchor_order
|
||||
from ultralytics.yolo.utils.modeling import parse_model
|
||||
from ultralytics.yolo.utils.modeling.modules import *
|
||||
from ultralytics.yolo.utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, time_sync
|
||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts, model_info,
|
||||
scale_img, time_sync)
|
||||
|
||||
|
||||
class BaseModel(nn.Module):
|
||||
@ -67,6 +68,10 @@ class BaseModel(nn.Module):
|
||||
m.anchor_grid = list(map(fn, m.anchor_grid))
|
||||
return self
|
||||
|
||||
def load(self, weights):
    """
    Abstract weight-loading hook that every task-specific model is required to override.

    Args:
        weights: Checkpoint source for the overriding implementation; not used here.

    Raises:
        NotImplementedError: Always, because the base class carries no loading logic.
    """
    message = "This function needs to be implemented by derived classes!"
    raise NotImplementedError(message)
|
||||
|
||||
|
||||
class DetectionModel(BaseModel):
|
||||
# YOLO detection model
|
||||
@ -166,6 +171,12 @@ class DetectionModel(BaseModel):
|
||||
b.data[:, 5:5 + m.nc] += math.log(0.6 / (m.nc - 0.99999)) if cf is None else torch.log(cf / cf.sum()) # cls
|
||||
mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
|
||||
|
||||
def load(self, weights):
    """
    Load the overlapping weights of a saved checkpoint into this model.

    Args:
        weights: Checkpoint location handed directly to torch.load(). The loaded object is indexed with
            'model', so it must be a mapping that stores a module under that key.

    NOTE(review): torch.load() unpickles the file, so only trusted checkpoints should be passed here.
    """
    checkpoint = torch.load(weights, map_location='cpu')  # deserialize on CPU to avoid CUDA memory leak
    fp32_model = checkpoint['model'].float()  # stored model cast to FP32
    source_state = fp32_model.state_dict()
    # Keep only what intersect_state_dicts() reports as shared with this model (exact rule lives in that helper)
    shared_state = intersect_state_dicts(source_state, self.state_dict())
    self.load_state_dict(shared_state, strict=False)  # non-strict: missing/unexpected keys are tolerated
|
||||
|
||||
|
||||
class SegmentationModel(DetectionModel):
|
||||
# YOLOv5 segmentation model
|
||||
@ -197,3 +208,9 @@ class ClassificationModel(BaseModel):
|
||||
def _from_yaml(self, cfg):
|
||||
# Create a YOLOv5 classification model from a *.yaml file
|
||||
self.model = None
|
||||
|
||||
def load(self, weights):
    """
    Load the overlapping weights of a saved checkpoint into this classification model.

    Args:
        weights: Checkpoint location handed directly to torch.load(). The loaded object is indexed with
            'model', so it must be a mapping that stores a module under that key.

    NOTE(review): torch.load() unpickles the file, so only trusted checkpoints should be passed here.
    """
    checkpoint = torch.load(weights, map_location='cpu')  # deserialize on CPU to avoid CUDA memory leak
    fp32_model = checkpoint['model'].float()  # stored model cast to FP32
    source_state = fp32_model.state_dict()
    # Keep only what intersect_state_dicts() reports as shared with this model (exact rule lives in that helper)
    shared_state = intersect_state_dicts(source_state, self.state_dict())
    self.load_state_dict(shared_state, strict=False)  # non-strict: missing/unexpected keys are tolerated
|
||||
|
Reference in New Issue
Block a user