Auto-format by https://ultralytics.com/actions

parent 90e5c5845e
commit 9ec3e9ca2f

models/lsknet.py
@@ -1,19 +1,20 @@
+from functools import partial
+
 import torch
 import torch.nn as nn
+from torch.nn.modules.utils import _pair as to_2tuple
+
 # from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
 # trunc_normal_init)
 # from ..builder import ROTATED_BACKBONES
 # from mmcv.runner import BaseModule
 from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-import math
-from functools import partial
-import warnings
-from timm.models.layers import DropPath, to_2tuple
-from torch.nn.modules.utils import _pair as to_2tuple
 
 # from mmcv.cnn import build_norm_layer
+
+
 class Mlp(nn.Module):
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
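Note: after this reformat the module still binds `to_2tuple` twice (torch's `_pair` early, timm's helper later), so the later import wins. That is benign, since both helpers expand a scalar to a pair; a quick check:

# Both bindings of to_2tuple behave the same, so the shadowing is harmless.
from torch.nn.modules.utils import _pair
from timm.models.layers import to_2tuple

assert _pair(7) == to_2tuple(7) == (7, 7)  # scalars become pairs
assert _pair((3, 5)) == to_2tuple((3, 5)) == (3, 5)  # tuples pass through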
@@ -38,10 +39,10 @@ class LSKblock(nn.Module):
         super().__init__()
         self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
         self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
-        self.conv1 = nn.Conv2d(dim, dim//2, 1)
-        self.conv2 = nn.Conv2d(dim, dim//2, 1)
+        self.conv1 = nn.Conv2d(dim, dim // 2, 1)
+        self.conv2 = nn.Conv2d(dim, dim // 2, 1)
         self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
-        self.conv = nn.Conv2d(dim//2, dim, 1)
+        self.conv = nn.Conv2d(dim // 2, dim, 1)
 
     def forward(self, x):
         attn1 = self.conv0(x)
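Note: the two depthwise convs above decompose a large kernel: `conv_spatial` has an effective kernel of (7 - 1) * 3 + 1 = 19, and stacking it after the 5x5 `conv0` gives a 5 + 19 - 1 = 23 pixel receptive field, while paddings of 2 and 9 keep the spatial size. A runnable sanity check (dim and input size arbitrary):

import torch
import torch.nn as nn

dim = 16
conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)  # 5x5 depthwise
conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)  # 19x19 effective
x = torch.randn(1, dim, 32, 32)
print(conv_spatial(conv0(x)).shape)  # torch.Size([1, 16, 32, 32]) -- size preserved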
@@ -55,12 +56,11 @@ class LSKblock(nn.Module):
         max_attn, _ = torch.max(attn, dim=1, keepdim=True)
         agg = torch.cat([avg_attn, max_attn], dim=1)
         sig = self.conv_squeeze(agg).sigmoid()
-        attn = attn1 * sig[:,0,:,:].unsqueeze(1) + attn2 * sig[:,1,:,:].unsqueeze(1)
+        attn = attn1 * sig[:, 0, :, :].unsqueeze(1) + attn2 * sig[:, 1, :, :].unsqueeze(1)
         attn = self.conv(attn)
         return x * attn
 
 
-
 class Attention(nn.Module):
     def __init__(self, d_model):
         super().__init__()
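Note: the hunk above is the selective-kernel fusion: the two branch outputs are pooled channel-wise (mean and max), squeezed to two gating maps, and mixed through a sigmoid gate. A shape walk-through with stand-in tensors (the start of forward(), where the branches are reduced to dim // 2 channels by conv1/conv2, sits outside this hunk):

import torch

N, dim, H, W = 1, 16, 32, 32
attn1 = torch.randn(N, dim // 2, H, W)  # stand-in for the first branch (dim // 2 channels per conv1)
attn2 = torch.randn(N, dim // 2, H, W)  # stand-in for the second branch (dim // 2 channels per conv2)
attn = torch.cat([attn1, attn2], dim=1)  # (N, dim, H, W)
avg_attn = torch.mean(attn, dim=1, keepdim=True)  # (N, 1, H, W)
max_attn, _ = torch.max(attn, dim=1, keepdim=True)  # (N, 1, H, W)
agg = torch.cat([avg_attn, max_attn], dim=1)  # (N, 2, H, W)
sig = torch.sigmoid(torch.randn(N, 2, H, W))  # stand-in for conv_squeeze(agg)
attn = attn1 * sig[:, 0, :, :].unsqueeze(1) + attn2 * sig[:, 1, :, :].unsqueeze(1)
print(attn.shape)  # torch.Size([1, 8, 32, 32]); self.conv then restores dim channels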
@@ -71,17 +71,17 @@ class Attention(nn.Module):
         self.proj_2 = nn.Conv2d(d_model, d_model, 1)
 
     def forward(self, x):
-        shorcut = x.clone()
+        shortcut = x.clone()
         x = self.proj_1(x)
         x = self.activation(x)
         x = self.spatial_gating_unit(x)
         x = self.proj_2(x)
-        x = x + shorcut
+        x = x + shortcut
         return x
 
 
 class Block(nn.Module):
-    def __init__(self, dim, mlp_ratio=4., drop=0.,drop_path=0., act_layer=nn.GELU, norm_cfg=None):
+    def __init__(self, dim, mlp_ratio=4.0, drop=0.0, drop_path=0.0, act_layer=nn.GELU, norm_cfg=None):
         super().__init__()
         # if norm_cfg:
         #     self.norm1 = torch.nn.SyncBatchNorm(norm_cfg, dim)[1]
@@ -93,15 +93,13 @@ class Block(nn.Module):
         self.norm2 = nn.BatchNorm2d(dim)
 
         self.attn = Attention(dim)
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
         # self.drop_path = nn.Identity()
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
         layer_scale_init_value = 1e-2
-        self.layer_scale_1 = nn.Parameter(
-            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
-        self.layer_scale_2 = nn.Parameter(
-            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+        self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+        self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
 
     def forward(self, x):
         x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
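Note: the layer-scale reflow above is behavior-preserving: a (dim,)-shaped parameter initialized at 1e-2 is reshaped by the two unsqueeze(-1) calls to (dim, 1, 1) so it broadcasts over NCHW and each channel of the residual branch gets its own learnable gain. A minimal sketch:

import torch

dim = 16
layer_scale_1 = 1e-2 * torch.ones(dim)  # as in Block.__init__
branch = torch.randn(2, dim, 32, 32)  # e.g. attn(norm1(x))
scaled = layer_scale_1.unsqueeze(-1).unsqueeze(-1) * branch  # (dim, 1, 1) * (N, dim, H, W)
print(scaled.shape)  # torch.Size([2, 16, 32, 32]); the branch starts scaled down 100x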
@@ -110,21 +108,20 @@ class Block(nn.Module):
 
 
 class OverlapPatchEmbed(nn.Module):
-    """ Image to Patch Embedding
-    """
+    """Image to Patch Embedding."""
 
     def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768, norm_cfg=None):
         super().__init__()
         patch_size = to_2tuple(patch_size)
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
-                              padding=(patch_size[0] // 2, patch_size[1] // 2))
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2)
+        )
         # if norm_cfg:
         #     self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
         # else:
         #     self.norm = nn.BatchNorm2d(embed_dim)
         self.norm = nn.BatchNorm2d(embed_dim)
 
-
     def forward(self, x):
         x = self.proj(x)
         _, _, H, W = x.shape
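Note: with padding fixed at k // 2, OverlapPatchEmbed keeps out = floor((H + 2 * (k // 2) - k) / s) + 1, i.e. roughly H / s, so the 7x7/stride-4 defaults quarter the resolution while adjacent patches overlap. A quick check with the defaults:

import torch
import torch.nn as nn

proj = nn.Conv2d(3, 768, kernel_size=7, stride=4, padding=(3, 3))  # the defaults above
x = torch.randn(1, 3, 224, 224)
print(proj(x).shape)  # torch.Size([1, 768, 56, 56]) -- 224 / 4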
@@ -133,21 +130,29 @@ class OverlapPatchEmbed(nn.Module):
 
 
 class LSKNet(nn.Module):
-    def __init__(self, in_chans=3, embed_dims=[64, 128, 256, 512],
-                 mlp_ratios=[8, 8, 4, 4],
-                 drop_rate=0.,
-                 drop_path_rate=0.,
-                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
-                 depths=[3, 4, 6, 3],
-                 num_stages=4,
-                 norm_cfg=None):
+    def __init__(
+        self,
+        in_chans=3,
+        embed_dims=[64, 128, 256, 512],
+        mlp_ratios=[8, 8, 4, 4],
+        drop_rate=0.0,
+        drop_path_rate=0.0,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        depths=[3, 4, 6, 3],
+        num_stages=4,
+        norm_cfg=None,
+    ):
         super().__init__()
         ###
-        embed_dims=[16,]
-        depths=[4,]
-        drop_rate=0.1
-        drop_path_rate=0.1
-        num_stages=1
+        embed_dims = [
+            16,
+        ]
+        depths = [
+            4,
+        ]
+        drop_rate = 0.1
+        drop_path_rate = 0.1
+        num_stages = 1
 
         self.depths = depths
         self.num_stages = num_stages
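Note: the block under `###` overwrites the constructor arguments, so whatever a caller passes, the model is built as a single stage with 16-dim embeddings, depth 4, and 0.1 drop rates; the auto-format keeps that behavior and only reflows the literals. Illustration (run inside this module):

model = LSKNet(embed_dims=[64, 128, 256, 512], depths=[3, 4, 6, 3], num_stages=4)
print(model.num_stages)  # 1 -- the argument was discarded by the override
print(model.depths)  # [4]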
@@ -156,14 +161,26 @@ class LSKNet(nn.Module):
         cur = 0
 
         for i in range(num_stages):
-            patch_embed = OverlapPatchEmbed(patch_size=3 if i == 0 else 3,
-                                            stride=2 if i == 0 else 2,
-                                            in_chans=in_chans if i == 0 else embed_dims[i - 1],
-                                            embed_dim=embed_dims[i], norm_cfg=norm_cfg)
+            patch_embed = OverlapPatchEmbed(
+                patch_size=3 if i == 0 else 3,
+                stride=2 if i == 0 else 2,
+                in_chans=in_chans if i == 0 else embed_dims[i - 1],
+                embed_dim=embed_dims[i],
+                norm_cfg=norm_cfg,
+            )
 
-            block = nn.ModuleList([Block(
-                dim=embed_dims[i], mlp_ratio=mlp_ratios[i], drop=drop_rate, drop_path=dpr[cur + j],norm_cfg=norm_cfg)
-                for j in range(depths[i])])
+            block = nn.ModuleList(
+                [
+                    Block(
+                        dim=embed_dims[i],
+                        mlp_ratio=mlp_ratios[i],
+                        drop=drop_rate,
+                        drop_path=dpr[cur + j],
+                        norm_cfg=norm_cfg,
+                    )
+                    for j in range(depths[i])
+                ]
+            )
             norm = norm_layer(embed_dims[i])
             cur += depths[i]
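Note: `patch_size=3 if i == 0 else 3` and `stride=2 if i == 0 else 2` are degenerate conditionals (both branches are equal), left as-is by the formatter. The `dpr[cur + j]` indexing with `cur += depths[i]` assigns one drop-path rate per block across all stages; `dpr` itself is defined above this hunk and not shown, but in this model family it is typically a linear ramp, e.g. (an assumption, not part of the diff):

import torch

drop_path_rate, depths = 0.1, [3, 4, 6, 3]
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # assumed schedule
cur = 0
for i, d in enumerate(depths):
    print(i, [round(dpr[cur + j], 3) for j in range(d)])  # later blocks drop more often
    cur += d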
@@ -204,7 +221,7 @@ class DWConv(nn.Module):
         return x
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # model = LSKNet(embed_dims=[32, 64, 160, 256], depths=[3, 3, 5, 2], drop_rate=0.1, drop_path_rate=0.1)
     # model = LSKNet(embed_dims=[64, 128, 256, 512], depths=[2, 2, 4, 2], drop_rate=0.1, drop_path_rate=0.1)
     # model = LSKNet(embed_dims=[64,], depths=[4,], drop_rate=0.1, drop_path_rate=0.1, num_stages=1)
@@ -212,6 +229,3 @@ if __name__ == '__main__':
     inputs = torch.randn((1, 3, 640, 640))
     for i in model(inputs):
         print(i.size())
-
-
-
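Note: the `model = ...` line falls between the two hunks; assuming it instantiates the LSKNet above (which the overrides force to a single 16-dim, stride-2 stage), the smoke test should print one feature map:

model = LSKNet()
inputs = torch.randn((1, 3, 640, 640))
for i in model(inputs):
    print(i.size())  # expected: torch.Size([1, 16, 320, 320]) -- 640 / 2, embed_dims[0] == 16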
@@ -49,9 +49,8 @@ from models.common import (
     GhostConv,
     Proto,
 )
-from models.lsknet import LSKNet
-
 from models.experimental import MixConv2d
+from models.lsknet import LSKNet
 from utils.autoanchor import check_anchor_order
 from utils.general import LOGGER, check_version, check_yaml, colorstr, make_divisible, print_args
 from utils.plots import feature_visualization