Improve tests coverage and speed (#4340)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -9,7 +9,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.init import constant_, xavier_uniform_
|
||||
|
||||
from ultralytics.utils.tal import dist2bbox, make_anchors
|
||||
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, make_anchors
|
||||
|
||||
from .block import DFL, Proto
|
||||
from .conv import Conv
|
||||
@ -267,9 +267,9 @@ class RTDETRDecoder(nn.Module):
|
||||
def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
|
||||
anchors = []
|
||||
for i, (h, w) in enumerate(shapes):
|
||||
grid_y, grid_x = torch.meshgrid(torch.arange(end=h, dtype=dtype, device=device),
|
||||
torch.arange(end=w, dtype=dtype, device=device),
|
||||
indexing='ij')
|
||||
sy = torch.arange(end=h, dtype=dtype, device=device)
|
||||
sx = torch.arange(end=w, dtype=dtype, device=device)
|
||||
grid_y, grid_x = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
|
||||
grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2)
|
||||
|
||||
valid_WH = torch.tensor([h, w], dtype=dtype, device=device)
|
||||
|
@ -22,6 +22,10 @@ class TransformerEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False):
|
||||
super().__init__()
|
||||
from ...utils.torch_utils import TORCH_1_9
|
||||
if not TORCH_1_9:
|
||||
raise ModuleNotFoundError(
|
||||
'TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True).')
|
||||
self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
|
||||
# Implementation of Feedforward model
|
||||
self.fc1 = nn.Linear(c1, cm)
|
||||
|
Reference in New Issue
Block a user