pull/13255/head
Yothin Saengsakoon 2024-08-06 09:57:18 +07:00
parent f4962acb8e
commit 893229ad5d
3 changed files with 102 additions and 3 deletions

View File

@ -279,6 +279,51 @@ class C3Ghost(C3):
c_ = int(c2 * e) # hidden channels
self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
class CAM(nn.Module):
def __init__(self, channels, r):
super(CAM, self).__init__()
self.channels = channels
self.r = r
self.linear = nn.Sequential(
nn.Linear(in_features=self.channels, out_features=self.channels//self.r, bias=True),
nn.ReLU(inplace=True),
nn.Linear(in_features=self.channels//self.r, out_features=self.channels, bias=True))
def forward(self, x):
max = nn.functional.adaptive_max_pool2d(x, output_size=1)
avg = nn.functional.adaptive_avg_pool2d(x, output_size=1)
b, c, _, _ = x.size()
linear_max = self.linear(max.view(b,c)).view(b, c, 1, 1)
linear_avg = self.linear(avg.view(b,c)).view(b, c, 1, 1)
output = linear_max + linear_avg
output = nn.functional.sigmoid(output) * x
return output
class SAM(nn.Module):
def __init__(self, bias=False):
super(SAM, self).__init__()
self.bias = bias
self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3, dilation=1, bias=self.bias)
def forward(self, x):
max = torch.max(x,1)[0].unsqueeze(1)
avg = torch.mean(x,1).unsqueeze(1)
concat = torch.cat((max,avg), dim=1)
output = self.conv(concat)
output = nn.functional.sigmoid(output) * x
return output
class CBAM(nn.Module):
def __init__(self, channels, r):
super(CBAM, self).__init__()
self.channels = channels
self.r = r
self.sam = SAM(bias=False)
self.cam = CAM(channels=self.channels, r=self.r)
def forward(self, x):
output = self.cam(x)
output = self.sam(output)
return output + x
class SPP(nn.Module):
# Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
@ -1066,9 +1111,6 @@ class Classify(nn.Module):
def __init__(
self, c1, c2, k=1, s=1, p=None, g=1, dropout_p=0.0
): # ch_in, ch_out, kernel, stride, padding, groups, dropout probability
"""Initializes YOLOv5 classification head with convolution, pooling, and dropout layers for input to output
channel transformation.
"""
super().__init__()
c_ = 1280 # efficientnet_b0 size
self.conv = Conv(c1, c_, k, s, autopad(k, p), g)

View File

@ -48,6 +48,7 @@ from models.common import (
GhostBottleneck,
GhostConv,
Proto,
CBAM,
)
from models.experimental import MixConv2d
from utils.autoanchor import check_anchor_order
@ -415,7 +416,9 @@ def parse_model(d, ch):
nn.ConvTranspose2d,
DWConvTranspose2d,
C3x,
CBAM
}:
"""c1 = number previous chanel ,c2 = number output chanel"""
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, ch_mul)
@ -439,6 +442,8 @@ def parse_model(d, ch):
c2 = ch[f] * args[0] ** 2
elif m is Expand:
c2 = ch[f] // args[0] ** 2
elif m is CBAM:
c2 = c1
else:
c2 = ch[f]

View File

@ -0,0 +1,52 @@
# Ultralytics YOLOv5 🚀, AGPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
- [10, 13, 16, 30, 33, 23] # P3/8
- [30, 61, 62, 45, 59, 119] # P4/16
- [116, 90, 156, 198, 373, 326] # P5/32
# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[
[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
]
# YOLOv5 v6.0 head
head: [
[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, "nearest"]],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, "nearest"]],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
[-1, 3, CBAM, [16]], # 18 (CBAM)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 21 (P4/16-medium)
[-1, 3, CBAM, [16]], # 22 (CBAM)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 25 (P5/32-large)
[-1, 3, CBAM, [16]], # 26 (CBAM)
[[18, 22, 26], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]