# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
'''
NAFSSR: Stereo Image Super-Resolution Using NAFNet

@InProceedings{Chu2022NAFSSR,
  author    = {Xiaojie Chu and Liangyu Chen and Wenqing Yu},
  title     = {NAFSSR: Stereo Image Super-Resolution Using NAFNet},
  booktitle = {CVPRW},
  year      = {2022},
}
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from basicsr.models.archs.NAFNet_arch import LayerNorm2d, NAFBlock
from basicsr.models.archs.arch_util import MySequential
from basicsr.models.archs.local_arch import Local_Base


class SCAM(nn.Module):
    '''
    Stereo Cross Attention Module (SCAM)
    '''
    def __init__(self, c):
        super().__init__()
        self.scale = c ** -0.5

        self.norm_l = LayerNorm2d(c)
        self.norm_r = LayerNorm2d(c)
        self.l_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
        self.r_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

        self.l_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
        self.r_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)

    def forward(self, x_l, x_r):
        Q_l = self.l_proj1(self.norm_l(x_l)).permute(0, 2, 3, 1)  # B, H, W, c
        Q_r_T = self.r_proj1(self.norm_r(x_r)).permute(0, 2, 1, 3)  # B, H, c, W (transposed)

        V_l = self.l_proj2(x_l).permute(0, 2, 3, 1)  # B, H, W, c
        V_r = self.r_proj2(x_r).permute(0, 2, 3, 1)  # B, H, W, c

        # (B, H, W, c) x (B, H, c, W) -> (B, H, W, W)
        attention = torch.matmul(Q_l, Q_r_T) * self.scale

        F_r2l = torch.matmul(torch.softmax(attention, dim=-1), V_r)  # B, H, W, c
        F_l2r = torch.matmul(torch.softmax(attention.permute(0, 1, 3, 2), dim=-1), V_l)  # B, H, W, c

        # scale the cross-attended features with zero-initialized per-channel weights
        F_r2l = F_r2l.permute(0, 3, 1, 2) * self.beta
        F_l2r = F_l2r.permute(0, 3, 1, 2) * self.gamma
        return x_l + F_r2l, x_r + F_l2r


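# Shape sanity sketch for SCAM (illustrative; `_scam_shape_demo` is not part
# of the original file). Attention is computed per row along the width
# (epipolar) dimension, so the left and right feature maps must share H and W.
def _scam_shape_demo():
    scam = SCAM(c=48)
    x_l = torch.randn(2, 48, 30, 90)  # left-view features,  (B, C, H, W)
    x_r = torch.randn(2, 48, 30, 90)  # right-view features, (B, C, H, W)
    out_l, out_r = scam(x_l, x_r)
    assert out_l.shape == x_l.shape and out_r.shape == x_r.shape

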
class DropPath(nn.Module):
    '''
    Stochastic depth: randomly skips the wrapped module during training and
    rescales surviving residuals so the expected output is unchanged.
    '''
    def __init__(self, drop_rate, module):
        super().__init__()
        self.drop_rate = drop_rate
        self.module = module

    def forward(self, *feats):
        if self.training and np.random.rand() < self.drop_rate:
            return feats

        new_feats = self.module(*feats)
        factor = 1. / (1 - self.drop_rate) if self.training else 1.

        if self.training and factor != 1.:
            new_feats = tuple([x + factor * (new_x - x) for x, new_x in zip(feats, new_feats)])
        return new_feats


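# Quick sanity sketch (illustrative; `_drop_path_demo` and the nn.Identity
# stand-in are not part of the original file): in eval mode the wrapped
# module always runs and no rescaling is applied.
def _drop_path_demo():
    dp = DropPath(drop_rate=0.5, module=nn.Identity())
    dp.eval()
    x = torch.randn(1, 4, 8, 8)
    y = dp(x)  # eval mode: returns the module output unchanged
    assert torch.equal(y, x)

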
class NAFBlockSR(nn.Module):
    '''
    NAFBlock for Super-Resolution
    '''
    def __init__(self, c, fusion=False, drop_out_rate=0.):
        super().__init__()
        self.blk = NAFBlock(c, drop_out_rate=drop_out_rate)
        self.fusion = SCAM(c) if fusion else None

    def forward(self, *feats):
        feats = tuple([self.blk(x) for x in feats])
        if self.fusion:
            feats = self.fusion(*feats)
        return feats


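# Minimal sketch (illustrative; `_nafblock_sr_demo` is not part of the
# original file): the same NAFBlock weights are applied to each view, and
# with fusion=True the two views then exchange information through SCAM.
def _nafblock_sr_demo():
    blk = NAFBlockSR(c=16, fusion=True)
    x_l, x_r = torch.randn(1, 16, 8, 8), torch.randn(1, 16, 8, 8)
    out_l, out_r = blk(x_l, x_r)
    assert out_l.shape == x_l.shape and out_r.shape == x_r.shape

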
class NAFNetSR(nn.Module):
    '''
    NAFNet for Super-Resolution
    '''
    def __init__(self, up_scale=4, width=48, num_blks=16, img_channel=3,
                 drop_path_rate=0., drop_out_rate=0., fusion_from=-1, fusion_to=-1, dual=False):
        super().__init__()
        self.dual = dual  # dual input for stereo SR (left view, right view)
        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3,
                               padding=1, stride=1, groups=1, bias=True)
        self.body = MySequential(
            *[DropPath(
                drop_path_rate,
                NAFBlockSR(
                    width,
                    fusion=(fusion_from <= i <= fusion_to),
                    drop_out_rate=drop_out_rate
                )) for i in range(num_blks)]
        )

        self.up = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=img_channel * up_scale**2, kernel_size=3,
                      padding=1, stride=1, groups=1, bias=True),
            nn.PixelShuffle(up_scale)
        )
        self.up_scale = up_scale

    def forward(self, inp):
        inp_hr = F.interpolate(inp, scale_factor=self.up_scale, mode='bilinear')
        if self.dual:
            inp = inp.chunk(2, dim=1)
        else:
            inp = (inp,)
        feats = [self.intro(x) for x in inp]
        feats = self.body(*feats)
        out = torch.cat([self.up(x) for x in feats], dim=1)
        out = out + inp_hr
        return out


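# End-to-end shape sketch (illustrative; `_nafnet_sr_demo` is not part of the
# original file): with dual=True the two RGB views are stacked along the
# channel axis, so a (B, 6, H, W) stereo pair is upsampled to
# (B, 6, H * up_scale, W * up_scale), with a bilinear skip connection.
def _nafnet_sr_demo():
    net = NAFNetSR(up_scale=2, width=16, num_blks=2, dual=True,
                   fusion_from=0, fusion_to=1)
    x = torch.randn(1, 6, 30, 90)  # left/right RGB concatenated on dim 1
    with torch.no_grad():
        y = net(x)
    assert y.shape == (1, 6, 60, 180)

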
class NAFSSR(Local_Base, NAFNetSR):
    def __init__(self, *args, train_size=(1, 6, 30, 90), fast_imp=False, fusion_from=-1, fusion_to=1000, **kwargs):
        Local_Base.__init__(self)
        NAFNetSR.__init__(self, *args, img_channel=3, fusion_from=fusion_from, fusion_to=fusion_to, dual=True, **kwargs)

        N, C, H, W = train_size
        base_size = (int(H * 1.5), int(W * 1.5))

        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)


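# Usage sketch (illustrative; `_nafssr_demo` is not part of the original
# file). Local_Base.convert() swaps global pooling for local, test-time
# statistics sized from `train_size` (TLC), which lets the converted model
# run on inputs larger than the training crops.
def _nafssr_demo():
    net = NAFSSR(up_scale=4, width=48, num_blks=16,
                 train_size=(1, 6, 30, 90), fast_imp=True)
    x = torch.randn(1, 6, 30, 90)  # a stereo pair at the training crop size
    with torch.no_grad():
        y = net(x)
    assert y.shape == (1, 6, 120, 360)

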
if __name__ == '__main__':
    num_blks = 128
    width = 128
    droppath = 0.1
    train_size = (1, 6, 30, 90)

    net = NAFSSR(up_scale=2, train_size=train_size, fast_imp=True, width=width,
                 num_blks=num_blks, drop_path_rate=droppath)

    inp_shape = (6, 64, 64)

    from ptflops import get_model_complexity_info
    FLOPS = 0
    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)

    # params = float(params[:-4])
    print(params)
    # ptflops reports strings such as '12.34 GMac'; drop the unit suffix
    macs = float(macs[:-4]) + FLOPS / 10 ** 9

    print('mac', macs, params)

    # from basicsr.models.archs.arch_util import measure_inference_speed
    # net = net.cuda()
    # data = torch.randn((1, 6, 128, 128)).cuda()
    # measure_inference_speed(net, (data,))