diff --git a/models/common.py b/models/common.py
index b1e060b..401de2f 100644
--- a/models/common.py
+++ b/models/common.py
@@ -446,4 +446,182 @@ class HarDBlock2(nn.Module):
             outs_.append(xin)
         out = torch.cat(outs_, 1)
-        return out
\ No newline at end of file
+        return out
+
+
+class ConvSig(nn.Module):
+    # Standard convolution with a sigmoid activation
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
+        super(ConvSig, self).__init__()
+        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
+        self.act = nn.Sigmoid() if act else nn.Identity()
+
+    def forward(self, x):
+        return self.act(self.conv(x))
+
+    def fuseforward(self, x):
+        return self.act(self.conv(x))
+
+
+class ConvSqu(nn.Module):
+    # Standard convolution with a Mish activation
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
+        super(ConvSqu, self).__init__()
+        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
+        self.act = Mish() if act else nn.Identity()
+
+    def forward(self, x):
+        return self.act(self.conv(x))
+
+    def fuseforward(self, x):
+        return self.act(self.conv(x))
+
+
+'''
+class SE(nn.Module):
+    # Squeeze-and-excitation block in https://arxiv.org/abs/1709.01507
+    def __init__(self, c1, c2, n=1, shortcut=True, g=8, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
+        super(SE, self).__init__()
+        c_ = int(c2)  # hidden channels
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.cs = ConvSqu(c1, c1 // g, 1, 1)
+        self.cvsig = ConvSig(c1 // g, c1, 1, 1)
+
+    def forward(self, x):
+        return x * self.cvsig(self.cs(self.avg_pool(x))).expand_as(x)
+
+
+class SAM(nn.Module):
+    # SAM block in yolov4
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
+        super(SAM, self).__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cvsig = ConvSig(c1, c1, 1, 1)
+
+    def forward(self, x):
+        return x * self.cvsig(x)
+
+
+class DNL(nn.Module):
+    # Disentangled Non-Local block in https://arxiv.org/abs/2006.06668
+    def __init__(self, c1, c2, k=3, s=1):
+        super(DNL, self).__init__()
+        c_ = int(c1)  # hidden channels
+
+        self.conv_query = nn.Conv2d(c1, c_, kernel_size=1)
+        self.conv_key = nn.Conv2d(c1, c_, kernel_size=1)
+
+        self.conv_value = nn.Conv2d(c1, c1, kernel_size=1, bias=False)
+        self.conv_out = None
+
+        self.scale = math.sqrt(c_)
+        self.temperature = 0.05
+
+        self.softmax = nn.Softmax(dim=2)
+
+        self.gamma = nn.Parameter(torch.zeros(1))
+
+        self.conv_mask = nn.Conv2d(c1, 1, kernel_size=1)
+
+        self.cv = Conv(c1, c2, k, s)
+
+    def forward(self, x):
+        # [N, C, H, W]
+        residual = x
+
+        # [N, C, H', W']
+        input_x = x
+
+        # [N, C', H, W]
+        query = self.conv_query(x)
+
+        # [N, C', H', W']
+        key = self.conv_key(input_x)
+        value = self.conv_value(input_x)
+
+        # [N, C', H x W]
+        query = query.view(query.size(0), query.size(1), -1)
+
+        # [N, C', H' x W']
+        key = key.view(key.size(0), key.size(1), -1)
+        value = value.view(value.size(0), value.size(1), -1)
+
+        # channel whitening
+        key_mean = key.mean(2).unsqueeze(2)
+        query_mean = query.mean(2).unsqueeze(2)
+        key -= key_mean
+        query -= query_mean
+
+        # [N, H x W, H' x W']
+        sim_map = torch.bmm(query.transpose(1, 2), key)
+        sim_map = sim_map / self.scale
+        sim_map = sim_map / self.temperature
+        sim_map = self.softmax(sim_map)
+
+        # [N, H x W, C']
+        out_sim = torch.bmm(sim_map, value.transpose(1, 2))
+
+        # [N, C', H x W]
+        out_sim = out_sim.transpose(1, 2)
+
+        # [N, C', H, W]
+        out_sim = out_sim.view(out_sim.size(0), out_sim.size(1), *x.size()[2:])
+        out_sim = self.gamma * out_sim
+
+        # [N, 1, H', W']
+        mask = self.conv_mask(input_x)
+        # [N, 1, H' x W']
+        mask = mask.view(mask.size(0), mask.size(1), -1)
+        mask = self.softmax(mask)
+        # [N, C, 1, 1]
+        out_gc = torch.bmm(value, mask.permute(0, 2, 1)).unsqueeze(-1)
+        out_sim = out_sim + out_gc
+
+        return self.cv(out_sim + residual)
+
+
+class GC(nn.Module):
+    # Global context block in https://arxiv.org/abs/1904.11492
+    def __init__(self, c1, c2, k=3, s=1):
+        super(GC, self).__init__()
+        c_ = int(c1)  # hidden channels
+
+        self.channel_add_conv = nn.Sequential(
+            nn.Conv2d(c1, c_, kernel_size=1),
+            nn.LayerNorm([c_, 1, 1]),
+            nn.ReLU(inplace=True),  # yapf: disable
+            nn.Conv2d(c_, c1, kernel_size=1))
+
+        self.conv_mask = nn.Conv2d(c_, 1, kernel_size=1)
+        self.softmax = nn.Softmax(dim=2)
+
+        self.cv = Conv(c1, c2, k, s)
+
+    def spatial_pool(self, x):
+        batch, channel, height, width = x.size()
+
+        input_x = x
+        # [N, C, H * W]
+        input_x = input_x.view(batch, channel, height * width)
+        # [N, 1, C, H * W]
+        input_x = input_x.unsqueeze(1)
+        # [N, 1, H, W]
+        context_mask = self.conv_mask(x)
+        # [N, 1, H * W]
+        context_mask = context_mask.view(batch, 1, height * width)
+        # [N, 1, H * W]
+        context_mask = self.softmax(context_mask)
+        # [N, 1, H * W, 1]
+        context_mask = context_mask.unsqueeze(-1)
+        # [N, 1, C, 1]
+        context = torch.matmul(input_x, context_mask)
+        # [N, C, 1, 1]
+        context = context.view(batch, channel, 1, 1)
+
+        return context
+
+    def forward(self, x):
+        return self.cv(x + self.channel_add_conv(self.spatial_pool(x)))
+'''
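
For reference, a quick shape check of the two modules this diff actually enables (the SE, SAM, DNL, and GC blocks above remain commented out). This is a minimal sketch, assuming the patched models/common.py with its existing autopad and Mish helpers is importable; the input shapes are illustrative only:

```python
import torch
from models.common import ConvSig, ConvSqu  # assumes this patch has been applied

x = torch.randn(1, 64, 32, 32)      # dummy feature map: batch 1, 64 channels, 32x32

attn = ConvSig(64, 64, k=1)(x)      # 1x1 conv + sigmoid -> attention-style map in (0, 1)
feat = ConvSqu(64, 128, k=3)(x)     # 3x3 conv (autopad keeps 32x32) + Mish, 64 -> 128 channels
print(attn.shape, feat.shape)       # torch.Size([1, 64, 32, 32]) torch.Size([1, 128, 32, 32])
```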