Update common.py

pull/182/head
Kin-Yiu, Wong 2021-01-19 15:00:28 +08:00 committed by GitHub
parent 0fc870c8bc
commit 7573b62a50
1 changed file with 179 additions and 1 deletion


@@ -446,4 +446,182 @@ class HarDBlock2(nn.Module):
        outs_.append(xin)
        out = torch.cat(outs_, 1)
        return out


class ConvSig(nn.Module):
    # Standard convolution followed by a sigmoid gate
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(ConvSig, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.act = nn.Sigmoid() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.conv(x))

    def fuseforward(self, x):
        return self.act(self.conv(x))


class ConvSqu(nn.Module):
    # Standard convolution followed by a Mish activation
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(ConvSqu, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.act = Mish() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.conv(x))

    def fuseforward(self, x):
        return self.act(self.conv(x))
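
Both wrappers mirror the existing Conv block but swap the activation: ConvSig gates its output through a sigmoid, ConvSqu uses Mish. A quick smoke test (a minimal sketch, not part of the commit; assumes autopad and Mish from earlier in common.py are in scope):

    import torch

    x = torch.randn(2, 64, 32, 32)    # [N, C, H, W]
    gate = ConvSig(64, 64)            # 1x1 conv -> sigmoid, values in (0, 1)
    feat = ConvSqu(64, 128)           # 1x1 conv -> Mish
    print(gate(x).shape)              # torch.Size([2, 64, 32, 32])
    print(feat(x).shape)              # torch.Size([2, 128, 32, 32])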
'''
class SE(nn.Module):
    # Squeeze-and-excitation block in https://arxiv.org/abs/1709.01507
    def __init__(self, c1, c2, n=1, shortcut=True, g=8, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(SE, self).__init__()
        c_ = int(c2)  # hidden channels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cs = ConvSqu(c1, c1 // g, 1, 1)
        self.cvsig = ConvSig(c1 // g, c1, 1, 1)

    def forward(self, x):
        return x * self.cvsig(self.cs(self.avg_pool(x))).expand_as(x)
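
SE is the standard squeeze-and-excitation pattern: global-average-pool to [N, C, 1, 1], bottleneck the channels by the group factor g, then broadcast a sigmoid gate back over the feature map. The same computation in plain tensor ops (a sketch; the random gate stands in for cvsig(cs(avg_pool(x)))):

    import torch
    import torch.nn as nn

    x = torch.randn(2, 64, 32, 32)
    squeeze = nn.AdaptiveAvgPool2d(1)(x)             # [2, 64, 1, 1] per-channel stats
    gate = torch.sigmoid(torch.randn(2, 64, 1, 1))   # stand-in for the two 1x1 convs
    out = x * gate.expand_as(x)                      # channel-wise reweighting, shape kept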

class SAM(nn.Module):
    # SAM block in yolov4
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(SAM, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cvsig = ConvSig(c1, c1, 1, 1)

    def forward(self, x):
        return x * self.cvsig(x)
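
Where SE emits one weight per channel, this pointwise SAM variant keeps the full H x W resolution, so every position gets its own gate. A toy comparison (random gates stand in for the learned 1x1 convs):

    import torch

    x = torch.randn(2, 64, 32, 32)
    se_gate = torch.sigmoid(torch.randn(2, 64, 1, 1))      # SE: per channel only
    sam_gate = torch.sigmoid(torch.randn(2, 64, 32, 32))   # SAM: per channel and position
    print((x * se_gate).shape, (x * sam_gate).shape)       # both preserve x's shape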

class DNL(nn.Module):
    # Disentangled Non-Local block in https://arxiv.org/abs/2006.06668
    def __init__(self, c1, c2, k=3, s=1):
        super(DNL, self).__init__()
        c_ = int(c1)  # hidden channels

        self.conv_query = nn.Conv2d(c1, c_, kernel_size=1)
        self.conv_key = nn.Conv2d(c1, c_, kernel_size=1)
        self.conv_value = nn.Conv2d(c1, c1, kernel_size=1, bias=False)
        self.conv_out = None
        self.scale = math.sqrt(c_)
        self.temperature = 0.05
        self.softmax = nn.Softmax(dim=2)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.conv_mask = nn.Conv2d(c1, 1, kernel_size=1)
        self.cv = Conv(c1, c2, k, s)

    def forward(self, x):
        # [N, C, H, W]
        residual = x
        input_x = x
        # [N, C', H, W]
        query = self.conv_query(x)
        key = self.conv_key(input_x)
        value = self.conv_value(input_x)
        # [N, C', H x W]
        query = query.view(query.size(0), query.size(1), -1)
        key = key.view(key.size(0), key.size(1), -1)
        value = value.view(value.size(0), value.size(1), -1)
        # channel whitening
        key_mean = key.mean(2).unsqueeze(2)
        query_mean = query.mean(2).unsqueeze(2)
        key -= key_mean
        query -= query_mean
        # [N, H x W, H x W]
        sim_map = torch.bmm(query.transpose(1, 2), key)
        sim_map = sim_map / self.scale
        sim_map = sim_map / self.temperature
        sim_map = self.softmax(sim_map)
        # [N, H x W, C']
        out_sim = torch.bmm(sim_map, value.transpose(1, 2))
        # [N, C', H x W]
        out_sim = out_sim.transpose(1, 2)
        # [N, C', H, W]
        out_sim = out_sim.view(out_sim.size(0), out_sim.size(1), *x.size()[2:])
        out_sim = self.gamma * out_sim
        # [N, 1, H, W]
        mask = self.conv_mask(input_x)
        # [N, 1, H x W]
        mask = mask.view(mask.size(0), mask.size(1), -1)
        mask = self.softmax(mask)
        # [N, C, 1, 1]
        out_gc = torch.bmm(value, mask.permute(0, 2, 1)).unsqueeze(-1)
        out_sim = out_sim + out_gc
        return self.cv(out_sim + residual)
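
Per the DNL paper, subtracting the query/key means ("whitening") disentangles the pairwise attention term from the unary one, which is modeled separately through conv_mask. A shape walkthrough with plain ops under assumed sizes (N=2, C'=64, H=W=16):

    import torch

    N, C, H, W = 2, 64, 16, 16
    q, k, v = (torch.randn(N, C, H * W) for _ in range(3))
    q = q - q.mean(2, keepdim=True)   # whitening removes the unary component
    k = k - k.mean(2, keepdim=True)   # from the pairwise similarity
    sim = torch.softmax(torch.bmm(q.transpose(1, 2), k) / (C ** 0.5 * 0.05), dim=2)
    pairwise = torch.bmm(sim, v.transpose(1, 2)).transpose(1, 2)   # [N, C, HW]
    mask = torch.softmax(torch.randn(N, 1, H * W), dim=2)          # stand-in for conv_mask
    unary = torch.bmm(v, mask.transpose(1, 2))                     # [N, C, 1] global context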

class GC(nn.Module):
    # global context block in https://arxiv.org/abs/1904.11492
    def __init__(self, c1, c2, k=3, s=1):
        super(GC, self).__init__()
        c_ = int(c1)  # hidden channels

        self.channel_add_conv = nn.Sequential(
            nn.Conv2d(c1, c_, kernel_size=1),
            nn.LayerNorm([c_, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(c_, c1, kernel_size=1))
        self.conv_mask = nn.Conv2d(c_, 1, kernel_size=1)
        self.softmax = nn.Softmax(dim=2)
        self.cv = Conv(c1, c2, k, s)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        input_x = x
        # [N, C, H * W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]
        context_mask = self.conv_mask(x)
        # [N, 1, H * W]
        context_mask = context_mask.view(batch, 1, height * width)
        context_mask = self.softmax(context_mask)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(-1)
        # [N, 1, C, 1]
        context = torch.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = context.view(batch, channel, 1, 1)
        return context

    def forward(self, x):
        return self.cv(x + self.channel_add_conv(self.spatial_pool(x)))
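
spatial_pool is attention pooling: conv_mask scores every position, a softmax over H * W turns the scores into weights, and the weighted sum yields one context vector per channel, which channel_add_conv transforms before the residual add. The pooling step in plain ops (assumed shapes):

    import torch

    N, C, H, W = 2, 64, 16, 16
    x = torch.randn(N, C, H, W)
    weights = torch.softmax(torch.randn(N, 1, H * W), dim=2)   # stand-in for conv_mask
    flat = x.view(N, C, H * W).unsqueeze(1)                    # [N, 1, C, HW]
    context = torch.matmul(flat, weights.unsqueeze(-1))        # [N, 1, C, 1]
    context = context.view(N, C, 1, 1)                         # one vector per channel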
'''