class CAM(nn.Module):
    """Channel attention module of CBAM https://arxiv.org/abs/1807.06521.

    Global max- and average-pooled channel descriptors are passed through one
    shared two-layer MLP, summed, squashed with a sigmoid and used to rescale
    every channel of the input.

    Args:
        channels: Number of channels of the input feature map.
        r: Reduction ratio of the MLP bottleneck (hidden width is
            ``channels // r``, but never less than 1).
    """

    def __init__(self, channels, r):
        super().__init__()
        self.channels = channels
        self.r = r
        # Clamp so that r > channels does not build a zero-width bottleneck,
        # which would make the gate independent of the input.
        hidden = max(self.channels // self.r, 1)
        self.linear = nn.Sequential(
            nn.Linear(in_features=self.channels, out_features=hidden, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=hidden, out_features=self.channels, bias=True),
        )

    def forward(self, x):
        """Return ``x`` scaled per channel by the attention gate; ``x`` is a 4-D (b, c, h, w) tensor."""
        b, c, _, _ = x.size()
        pooled_max = nn.functional.adaptive_max_pool2d(x, output_size=1).view(b, c)
        pooled_avg = nn.functional.adaptive_avg_pool2d(x, output_size=1).view(b, c)
        # The same MLP (shared weights) processes both descriptors.
        gate = self.linear(pooled_max) + self.linear(pooled_avg)
        # torch.sigmoid replaces the deprecated nn.functional.sigmoid alias.
        return torch.sigmoid(gate).view(b, c, 1, 1) * x


class SAM(nn.Module):
    """Spatial attention module of CBAM https://arxiv.org/abs/1807.06521.

    The channel-wise max and mean maps are concatenated, reduced to a single
    map by a 7x7 convolution, squashed with a sigmoid and used to rescale
    every spatial position of the input.

    Args:
        bias: Whether the 7x7 convolution has a bias term.
    """

    def __init__(self, bias=False):
        super().__init__()
        self.bias = bias
        self.conv = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=7,
            stride=1,
            padding=3,  # keeps the spatial size unchanged for a 7x7 kernel
            dilation=1,
            bias=self.bias,
        )

    def forward(self, x):
        """Return ``x`` scaled per spatial location by the attention gate; ``x`` is a 4-D (b, c, h, w) tensor."""
        channel_max = torch.max(x, dim=1, keepdim=True)[0]
        channel_avg = torch.mean(x, dim=1, keepdim=True)
        descriptor = torch.cat((channel_max, channel_avg), dim=1)
        return torch.sigmoid(self.conv(descriptor)) * x


class CBAM(nn.Module):
    """Convolutional Block Attention Module https://arxiv.org/abs/1807.06521.

    Applies channel attention followed by spatial attention and adds the
    result to the input (residual connection), so the output has the same
    shape as the input.

    Args:
        channels: Number of channels of the input feature map.
        r: Reduction ratio forwarded to the channel attention MLP.
    """

    def __init__(self, channels, r):
        super().__init__()
        self.channels = channels
        self.r = r
        # Registration order (sam, then cam) is kept so state_dict layout is unchanged.
        self.sam = SAM(bias=False)
        self.cam = CAM(channels=self.channels, r=self.r)

    def forward(self, x):
        """Return ``sam(cam(x)) + x``."""
        refined = self.sam(self.cam(x))
        return refined + x
ch_out, kernel, stride, padding, groups, dropout probability - """Initializes YOLOv5 classification head with convolution, pooling, and dropout layers for input to output - channel transformation. - """ super().__init__() c_ = 1280 # efficientnet_b0 size self.conv = Conv(c1, c_, k, s, autopad(k, p), g) diff --git a/models/yolo.py b/models/yolo.py index d89c5da01..d99847b8e 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -48,6 +48,7 @@ from models.common import ( GhostBottleneck, GhostConv, Proto, + CBAM, ) from models.experimental import MixConv2d from utils.autoanchor import check_anchor_order @@ -415,7 +416,9 @@ def parse_model(d, ch): nn.ConvTranspose2d, DWConvTranspose2d, C3x, + CBAM }: + """c1 = number previous chanel ,c2 = number output chanel""" c1, c2 = ch[f], args[0] if c2 != no: # if not output c2 = make_divisible(c2 * gw, ch_mul) @@ -439,6 +442,8 @@ def parse_model(d, ch): c2 = ch[f] * args[0] ** 2 elif m is Expand: c2 = ch[f] // args[0] ** 2 + elif m is CBAM: + c2 = c1 else: c2 = ch[f] diff --git a/models/yolov5s_CBAM.yaml b/models/yolov5s_CBAM.yaml new file mode 100644 index 000000000..111c74911 --- /dev/null +++ b/models/yolov5s_CBAM.yaml @@ -0,0 +1,52 @@ +# Ultralytics YOLOv5 🚀, AGPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.50 # layer channel multiple +anchors: + - [10, 13, 16, 30, 33, 23] # P3/8 + - [30, 61, 62, 45, 59, 119] # P4/16 + - [116, 90, 156, 198, 373, 326] # P5/32 + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [ + [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 9 + ] + +# YOLOv5 v6.0 head +head: [ + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 
"nearest"]], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, "nearest"]], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 17 (P3/8-small) + [-1, 3, CBAM, [16]], # 18 (CBAM) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 21 (P4/16-medium) + [-1, 3, CBAM, [16]], # 22 (CBAM) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 25 (P5/32-large) + [-1, 3, CBAM, [16]], # 26 (CBAM) + + [[18, 22, 26], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ]