Update common.py
parent 0fc870c8bc · commit 7573b62a50

models/common.py (+180 lines)

@@ -446,4 +446,182 @@ class HarDBlock2(nn.Module):
        outs_.append(xin)
        out = torch.cat(outs_, 1)
        return out

class ConvSig(nn.Module):
    # Standard convolution with a sigmoid activation
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(ConvSig, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.act = nn.Sigmoid() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.conv(x))

    def fuseforward(self, x):
        return self.act(self.conv(x))
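# A shape-check sketch for ConvSig (illustrative only, not part of the diff);
# it assumes autopad() and the torch/nn imports defined earlier in
# models/common.py:
#
#   x = torch.randn(2, 64, 32, 32)
#   gate = ConvSig(64, 64)(x)  # (2, 64, 32, 32), every value in (0, 1)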
|
class ConvSqu(nn.Module):
    # Standard convolution with a Mish activation
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(ConvSqu, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.act = Mish() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.conv(x))

    def fuseforward(self, x):
        return self.act(self.conv(x))

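# Mish here is assumed to be the activation class defined elsewhere in
# models/common.py; a minimal stand-in, if one were needed:
#
#   class Mish(nn.Module):
#       def forward(self, x):
#           return x * torch.tanh(torch.nn.functional.softplus(x))
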
'''
# NOTE: the classes below (SE, SAM, DNL, GC) are wrapped in a string
# literal, i.e. committed in a disabled state.
class SE(nn.Module):
    # Squeeze-and-excitation block in https://arxiv.org/abs/1709.01507
    def __init__(self, c1, c2, n=1, shortcut=True, g=8, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(SE, self).__init__()
        c_ = int(c2)  # hidden channels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cs = ConvSqu(c1, c1 // g, 1, 1)
        self.cvsig = ConvSig(c1 // g, c1, 1, 1)

    def forward(self, x):
        return x * self.cvsig(self.cs(self.avg_pool(x))).expand_as(x)

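# Data flow through SE, traced for illustration (assuming the block were
# enabled, with the default g=8):
#
#   x:                 (1, 64, 32, 32)
#   avg_pool(x):       (1, 64, 1, 1)    # squeeze
#   cs(...):           (1, 8, 1, 1)     # reduce c1 -> c1 // g, Mish
#   cvsig(...):        (1, 64, 1, 1)    # expand back, sigmoid in (0, 1)
#   x * expand_as(x):  (1, 64, 32, 32)  # channel-wise reweighting
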
class SAM(nn.Module):
    # SAM block in YOLOv4 (https://arxiv.org/abs/2004.10934)
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(SAM, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cvsig = ConvSig(c1, c1, 1, 1)

    def forward(self, x):
        return x * self.cvsig(x)

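# SAM as used here is YOLOv4's pointwise variant of spatial attention:
# every activation is gated by a sigmoid of a 1x1 conv over the same
# feature map, so shapes are unchanged:
#
#   x:             (1, 128, 16, 16)
#   cvsig(x):      (1, 128, 16, 16)  # gate in (0, 1)
#   x * cvsig(x):  (1, 128, 16, 16)
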
class DNL(nn.Module):
    # Disentangled Non-Local block in https://arxiv.org/abs/2006.06668
    def __init__(self, c1, c2, k=3, s=1):
        super(DNL, self).__init__()
        c_ = int(c1)  # hidden channels

        self.conv_query = nn.Conv2d(c1, c_, kernel_size=1)
        self.conv_key = nn.Conv2d(c1, c_, kernel_size=1)

        self.conv_value = nn.Conv2d(c1, c1, kernel_size=1, bias=False)
        self.conv_out = None

        self.scale = math.sqrt(c_)
        self.temperature = 0.05

        self.softmax = nn.Softmax(dim=2)

        self.gamma = nn.Parameter(torch.zeros(1))

        self.conv_mask = nn.Conv2d(c1, 1, kernel_size=1)

        self.cv = Conv(c1, c2, k, s)

    def forward(self, x):
        # [N, C, H, W]
        residual = x
        input_x = x

        # [N, C', H, W]
        query = self.conv_query(x)
        key = self.conv_key(input_x)
        value = self.conv_value(input_x)

        # [N, C', H x W]
        query = query.view(query.size(0), query.size(1), -1)
        key = key.view(key.size(0), key.size(1), -1)
        value = value.view(value.size(0), value.size(1), -1)

        # channel whitening: subtract per-channel means from key and query
        key_mean = key.mean(2).unsqueeze(2)
        query_mean = query.mean(2).unsqueeze(2)
        key -= key_mean
        query -= query_mean

        # [N, H x W, H x W] pairwise similarity
        sim_map = torch.bmm(query.transpose(1, 2), key)
        sim_map = sim_map / self.scale
        sim_map = sim_map / self.temperature
        sim_map = self.softmax(sim_map)

        # [N, H x W, C']
        out_sim = torch.bmm(sim_map, value.transpose(1, 2))

        # [N, C', H x W] -> [N, C', H, W]
        out_sim = out_sim.transpose(1, 2)
        out_sim = out_sim.view(out_sim.size(0), out_sim.size(1), *x.size()[2:])
        out_sim = self.gamma * out_sim

        # unary (global-context) branch: [N, 1, H, W] -> [N, 1, H x W]
        mask = self.conv_mask(input_x)
        mask = mask.view(mask.size(0), mask.size(1), -1)
        mask = self.softmax(mask)

        # [N, C, 1, 1]
        out_gc = torch.bmm(value, mask.permute(0, 2, 1)).unsqueeze(-1)
        out_sim = out_sim + out_gc

        return self.cv(out_sim + residual)

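# Cost note (illustrative): sim_map is (N, H*W, H*W), so DNL's memory grows
# quadratically with spatial resolution. If the block were enabled:
#
#   x = torch.randn(1, 256, 20, 20)   # 400 spatial positions
#   y = DNL(256, 256)(x)              # sim_map: (1, 400, 400)
#   assert y.shape == (1, 256, 20, 20)
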
class GC(nn.Module):
    # Global Context block in https://arxiv.org/abs/1904.11492
    def __init__(self, c1, c2, k=3, s=1):
        super(GC, self).__init__()
        c_ = int(c1)  # hidden channels

        self.channel_add_conv = nn.Sequential(
            nn.Conv2d(c1, c_, kernel_size=1),
            nn.LayerNorm([c_, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(c_, c1, kernel_size=1))

        self.conv_mask = nn.Conv2d(c_, 1, kernel_size=1)
        self.softmax = nn.Softmax(dim=2)

        self.cv = Conv(c1, c2, k, s)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()

        input_x = x
        # [N, C, H * W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]
        context_mask = self.conv_mask(x)
        # [N, 1, H * W]
        context_mask = context_mask.view(batch, 1, height * width)
        context_mask = self.softmax(context_mask)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(-1)
        # [N, 1, C, 1]
        context = torch.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = context.view(batch, channel, 1, 1)

        return context

    def forward(self, x):
        return self.cv(x + self.channel_add_conv(self.spatial_pool(x)))
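# Unlike DNL, GC pools the whole map into one (N, C, 1, 1) context vector,
# so its cost is linear in H*W. If the block were enabled:
#
#   x = torch.randn(1, 256, 20, 20)
#   ctx = GC(256, 256).spatial_pool(x)   # (1, 256, 1, 1)
#   y = GC(256, 256)(x)                  # (1, 256, 20, 20)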
'''