Add initial model interface (#30)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2022-10-26 01:21:15 +05:30
committed by GitHub
parent 7b560f7861
commit 1054819a59
12 changed files with 220 additions and 109 deletions

View File

@ -8,7 +8,8 @@ from ultralytics.yolo.utils import LOGGER
from ultralytics.yolo.utils.anchors import check_anchor_order
from ultralytics.yolo.utils.modeling import parse_model
from ultralytics.yolo.utils.modeling.modules import *
from ultralytics.yolo.utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, time_sync
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts, model_info,
scale_img, time_sync)
class BaseModel(nn.Module):
@ -67,6 +68,10 @@ class BaseModel(nn.Module):
m.anchor_grid = list(map(fn, m.anchor_grid))
return self
def load(self, weights):
# Force all tasks implement this function
raise NotImplementedError("This function needs to be implemented by derived classes!")
class DetectionModel(BaseModel):
# YOLO detection model
@ -166,6 +171,12 @@ class DetectionModel(BaseModel):
b.data[:, 5:5 + m.nc] += math.log(0.6 / (m.nc - 0.99999)) if cf is None else torch.log(cf / cf.sum()) # cls
mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
def load(self, weights):
ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = intersect_state_dicts(csd, self.state_dict()) # intersect
self.load_state_dict(csd, strict=False) # load
class SegmentationModel(DetectionModel):
# YOLOv5 segmentation model
@ -197,3 +208,9 @@ class ClassificationModel(BaseModel):
def _from_yaml(self, cfg):
# Create a YOLOv5 classification model from a *.yaml file
self.model = None
def load(self, weights):
ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = intersect_state_dicts(csd, self.state_dict()) # intersect
self.load_state_dict(csd, strict=False) # load