Start export implementation (#110)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2022-12-29 14:17:14 +01:00
committed by GitHub
parent c1b38428bc
commit 92dad1c1b5
32 changed files with 827 additions and 222 deletions

View File

@ -19,7 +19,7 @@ from ultralytics.yolo.utils.ops import xywh2xyxy
class AutoBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
# Usage:
# PyTorch: weights = *.pt
# TorchScript: *.torchscript

View File

@ -6,12 +6,12 @@ import thop
import torch
import torch.nn as nn
import torchvision
import yaml
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
GhostBottleneck, GhostConv, Segment)
from ultralytics.yolo.utils import LOGGER, colorstr
from ultralytics.yolo.utils.files import yaml_load
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts,
make_divisible, model_info, scale_img, time_sync)
@ -78,14 +78,9 @@ class BaseModel(nn.Module):
class DetectionModel(BaseModel):
# YOLOv5 detection model
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
super().__init__()
if isinstance(cfg, dict):
self.yaml = cfg # model dict
else: # is *.yaml
self.yaml_file = Path(cfg).name
with open(cfg, encoding='ascii', errors='ignore') as f:
self.yaml = yaml.safe_load(f) # model dict
self.yaml = cfg if isinstance(cfg, dict) else yaml_load(cfg) # cfg dict
# Define model
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
@ -163,7 +158,7 @@ class DetectionModel(BaseModel):
class SegmentationModel(DetectionModel):
# YOLOv5 segmentation model
def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None, verbose=True):
def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
super().__init__(cfg, ch, nc, verbose)