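# Stereo image super-resolution network assembled from NAFNet components
# (NAFSSR-style): two weight-shared NAFNet branches process the left and
# right views, exchanging information through cross-view attention fusion.
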
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from basicsr.models.archs.NAFNet_arch import LayerNorm2d, NAFBlock
from basicsr.models.archs.arch_util import MySequential
from basicsr.models.archs.local_arch import Local_Base

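
# Row-wise cross-view attention scores. For a rectified stereo pair,
# corresponding points lie on the same image row, so each left-view column
# is scored against every right-view column of that row, giving a
# (B, H, W, W) tensor: e.g. inputs of shape (2, 16, 30, 90) produce scores
# of shape (2, 30, 90, 90).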
class GenerateRelations(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.scale = c ** -0.5  # score scaling by 1/sqrt(channels)

        self.norm_l = LayerNorm2d(c)
        self.norm_r = LayerNorm2d(c)

        self.l_proj = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
        self.r_proj = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)

    def forward(self, lfeats, rfeats):
        lfeats = self.l_proj(self.norm_l(lfeats))
        rfeats = self.r_proj(self.norm_r(rfeats))

        x = lfeats.permute(0, 2, 3, 1)  # B, H, W, c
        y = rfeats.permute(0, 2, 1, 3)  # B, H, c, W

        z = torch.matmul(x, y)  # B, H, W, W

        return self.scale * z

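
# Bidirectional cross-view fusion: each view attends to the other using
# the row-wise relation scores and adds the gathered features back as a
# residual. `beta` and `gamma` are zero-initialized, so fusion starts as
# an identity mapping and its strength is learned during training.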
class FusionModule(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.relation_generator = GenerateRelations(c)
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

        self.l_proj = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
        self.r_proj = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)

    def forward(self, lfeats, rfeats):
        relations = self.relation_generator(lfeats, rfeats)  # B, H, W, W

        lfeats_projected = self.l_proj(lfeats).permute(0, 2, 3, 1)  # B, H, W, c
        rfeats_projected = self.r_proj(rfeats).permute(0, 2, 3, 1)  # B, H, W, c

        # Softmax over the last dim gives, for every left position, a
        # weighting over right positions (and vice versa on the
        # transposed relations).
        lresidual = torch.matmul(torch.softmax(relations, dim=-1), rfeats_projected)  # B, H, W, c
        rresidual = torch.matmul(torch.softmax(relations.permute(0, 1, 3, 2), dim=-1), lfeats_projected)  # B, H, W, c

        lresidual = lresidual.permute(0, 3, 1, 2) * self.beta  # back to B, C, H, W
        rresidual = rresidual.permute(0, 3, 1, 2) * self.gamma
        return lfeats + lresidual, rfeats + rresidual

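
# Stochastic depth for multi-input modules: during training the wrapped
# block is skipped with probability `drop_rate`; when it does run, its
# residual is rescaled by 1/(1 - drop_rate) so the expected training-time
# output matches the block's deterministic output:
#   E[out] = p*x + (1-p) * (x + (f(x) - x) / (1-p)) = f(x)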
class DropPath(nn.Module):
    def __init__(self, drop_rate, module):
        super().__init__()
        self.drop_rate = drop_rate
        self.module = module

    def forward(self, *feats):
        if self.training and np.random.rand() < self.drop_rate:
            return feats  # block dropped: inputs pass through unchanged

        new_feats = self.module(*feats)
        factor = 1. / (1 - self.drop_rate) if self.training else 1.

        if self.training and factor != 1.:
            new_feats = tuple([x + factor * (new_x - x) for x, new_x in zip(feats, new_feats)])
        return new_feats

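
# One stage of the stereo body: a weight-shared NAFBlock applied to each
# view independently, optionally followed by cross-view fusion.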
class NAFBlockSR(nn.Module):
    def __init__(self, c, fusion=False, drop_out_rate=0.):
        super().__init__()
        self.blk = NAFBlock(c, drop_out_rate=drop_out_rate)
        self.fusion = FusionModule(c) if fusion else None

    def forward(self, *feats):
        feats = tuple([self.blk(x) for x in feats])
        if self.fusion is not None:
            feats = self.fusion(*feats)
        return feats

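
# Full network: a 3x3 intro conv lifts each view to `width` channels, a
# stack of NAFBlockSR stages (with fusion enabled for block indices in
# [fusion_from, fusion_to]) processes the views, and a conv + PixelShuffle
# head upsamples by `up_scale`. A bilinearly upsampled copy of the input is
# added as a global residual, so the body only models the missing details.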
class NAFNetSR(nn.Module):
    def __init__(self, img_channel=3, width=16, num_blks=1, drop_path_rate=0.,
                 drop_out_rate=0., fusion_from=-1, fusion_to=-1, dual=True, up_scale=4):
        super().__init__()
        self.dual = dual  # dual=True: both stereo views stacked along the channel dim
        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3,
                               padding=1, stride=1, groups=1, bias=True)
        self.body = MySequential(
            *[DropPath(
                drop_path_rate,
                NAFBlockSR(
                    width,
                    fusion=(fusion_from <= i <= fusion_to),
                    drop_out_rate=drop_out_rate
                )) for i in range(num_blks)]
        )

        self.up = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=img_channel * up_scale**2,
                      kernel_size=3, padding=1, stride=1, groups=1, bias=True),
            nn.PixelShuffle(up_scale)
        )
        self.up_scale = up_scale

    def forward(self, inp):
        inp_hr = F.interpolate(inp, scale_factor=self.up_scale, mode='bilinear')
        if self.dual:
            inp = inp.chunk(2, dim=1)  # split the stacked channels into (left, right)
        else:
            inp = (inp, )
        feats = [self.intro(x) for x in inp]
        feats = self.body(*feats)
        out = torch.cat([self.up(x) for x in feats], dim=1)
        out = out + inp_hr  # global residual over the bilinearly upsampled input
        return out

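
# Inference-time variant built on Local_Base from the NAFNet codebase:
# `convert` adapts the model to inputs larger than the training crops (in
# the reference implementation this swaps size-sensitive global pooling for
# local windowed statistics); `base_size` follows the reference convention
# of 1.5x the training spatial size.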
class NAFNetSRLocal(Local_Base, NAFNetSR):
    def __init__(self, *args, train_size=(1, 6, 64, 64), fast_imp=False, **kwargs):
        Local_Base.__init__(self)
        NAFNetSR.__init__(self, *args, **kwargs)

        N, C, H, W = train_size
        base_size = (int(H * 1.5), int(W * 1.5))

        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)

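
# Quick self-test: build a model, run a forward pass on a random (stereo)
# batch, and report parameter count and MACs via ptflops.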
if __name__ == '__main__':
    img_channel = 3
    num_blks = 64
    width = 96
    # num_blks = 32
    # width = 64
    # num_blks = 16
    # width = 48
    dual = True
    # fusion_from = 0
    # fusion_to = num_blks
    fusion_from = 0
    fusion_to = 1000  # larger than num_blks, so every block fuses
    droppath = 0.1
    train_size = (1, 6, 30, 90)

    net = NAFNetSRLocal(up_scale=2, train_size=train_size, fast_imp=True,
                        img_channel=img_channel, width=width, num_blks=num_blks,
                        dual=dual, fusion_from=fusion_from, fusion_to=fusion_to,
                        drop_path_rate=droppath)
    # net = NAFNetSR(img_channel=img_channel, width=width, num_blks=num_blks, dual=dual,
    #                fusion_from=fusion_from, fusion_to=fusion_to, drop_path_rate=droppath)

    c = 6 if dual else 3

    a = torch.randn((2, c, 24, 23))
    b = net(a)
    print(b.shape)

    # inp_shape = (6, 128, 128)
    inp_shape = (c, 64, 64)
    # inp_shape = (6, 256, 96)

    from ptflops import get_model_complexity_info

    FLOPS = 0
    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)

    # params = float(params[:-4])
    print(params)
    # ptflops returns a string such as '12.34 GMac'; strip the unit suffix
    macs = float(macs[:-4]) + FLOPS / 10 ** 9

    print('mac', macs, params, 'fusion from .. to ', fusion_from, fusion_to)

    # from basicsr.models.archs.arch_util import measure_inference_speed
    # net = net.cuda()
    # data = torch.randn((1, 6, 128, 128)).cuda()
    # measure_inference_speed(net, (data,))