add NAFSSR

pull/8/head
chuxiaojie 2022-04-19 15:50:14 +08:00
parent 386ca20359
commit f2809712e3
25 changed files with 1771 additions and 125 deletions

7
.gitignore vendored
View File

@ -1,2 +1,9 @@
.DS_Store
.idea/*
experiments
logs/
*results*
*__pycache__*
*.sh
datasets
basicsr.egg-info

View File

@ -183,6 +183,9 @@ class PairedImageSRLRDataset(data.Dataset):
class PairedStereoImageDataset(data.Dataset):
'''
Paired dataset for stereo SR (Flickr1024, KITTI, Middlebury)
'''
def __init__(self, opt):
super(PairedStereoImageDataset, self).__init__()
self.opt = opt
@ -255,14 +258,6 @@ class PairedStereoImageDataset(data.Dataset):
gt_size = int(self.opt['gt_size'])
gt_size_h, gt_size_w = gt_size, gt_size
if 'flip_LR' in self.opt and self.opt['flip_LR']:
if np.random.rand() < 0.5:
img_gt = img_gt[:, :, [3, 4, 5, 0, 1, 2]]
img_lq = img_lq[:, :, [3, 4, 5, 0, 1, 2]]
# img_gt, img_lq
if 'flip_RGB' in self.opt and self.opt['flip_RGB']:
idx = [
[0, 1, 2, 3, 4, 5],
@ -276,42 +271,6 @@ class PairedStereoImageDataset(data.Dataset):
img_gt = img_gt[:, :, idx]
img_lq = img_lq[:, :, idx]
if 'inverse_RGB' in self.opt and self.opt['inverse_RGB']:
for i in range(3):
if np.random.rand() < 0.5:
img_gt[:, :, i] = 1 - img_gt[:, :, i]
img_gt[:, :, i+3] = 1 - img_gt[:, :, i+3]
img_lq[:, :, i] = 1 - img_lq[:, :, i]
img_lq[:, :, i+3] = 1 - img_lq[:, :, i+3]
if 'naive_inverse_RGB' in self.opt and self.opt['naive_inverse_RGB']:
# for i in range(3):
if np.random.rand() < 0.5:
img_gt = 1 - img_gt
img_lq = 1 - img_lq
# img_gt[:, :, i] = 1 - img_gt[:, :, i]
# img_gt[:, :, i+3] = 1 - img_gt[:, :, i+3]
# img_lq[:, :, i] = 1 - img_lq[:, :, i]
# img_lq[:, :, i+3] = 1 - img_lq[:, :, i+3]
if 'random_offset' in self.opt and self.opt['random_offset'] > 0:
# if np.random.rand() < 0.9:
S = int(self.opt['random_offset'])
offsets = int(np.random.rand() * (S+1)) #1~S
s2, s4 = 0, 0
if np.random.rand() < 0.5:
s2 = offsets
else:
s4 = offsets
_, w, _ = img_lq.shape
img_lq = np.concatenate([img_lq[:, s2:w-s4, :3], img_lq[:, s4:w-s2, 3:]], axis=-1)
img_gt = np.concatenate(
[img_gt[:, 4 * s2:4*w-4 * s4, :3], img_gt[:, 4 * s4:4*w-4 * s2, 3:]], axis=-1)
# random crop
img_gt, img_lq = img_gt.copy(), img_lq.copy()
img_gt, img_lq = paired_random_crop_hw(img_gt, img_lq, gt_size_h, gt_size_w, scale,

View File

@ -22,56 +22,42 @@ from basicsr.models.archs.NAFNet_arch import LayerNorm2d, NAFBlock
from basicsr.models.archs.arch_util import MySequential
from basicsr.models.archs.local_arch import Local_Base
class GenerateRelations(nn.Module):
class SCAM(nn.Module):
'''
Stereo Cross Attention Module (SCAM)
'''
def __init__(self, c):
super().__init__()
self.scale = c ** -0.5
self.norm_l = LayerNorm2d(c)
self.norm_r = LayerNorm2d(c)
self.l_proj = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
self.r_proj = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
def forward(self, lfeats, rfeats):
B, C, H, W = lfeats.shape
lfeats = lfeats.view(B, C, H, W)
rfeats = rfeats.view(B, C, H, W)
lfeats, rfeats = self.l_proj(self.norm_l(lfeats)), self.r_proj(self.norm_r(rfeats))
x = lfeats.permute(0, 2, 3, 1) #B H W c
y = rfeats.permute(0, 2, 1, 3) #B H c W
z = torch.matmul(x, y) #B H W W
return self.scale * z
class FusionModule(nn.Module):
def __init__(self, c):
super().__init__()
self.relation_generator = GenerateRelations(c)
self.l_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
self.r_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
self.l_proj = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
self.r_proj = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
self.l_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
self.r_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
def forward(self, lfeats, rfeats):
B, C, H, W = lfeats.shape
def forward(self, x_l, x_r):
Q_l = self.l_proj1(self.norm_l(x_l)).permute(0, 2, 3, 1) # B, H, W, c
Q_r_T = self.r_proj1(self.norm_r(x_r)).permute(0, 2, 1, 3) # B, H, c, W (transposed)
relations = self.relation_generator(lfeats, rfeats) # B, H, W, W
V_l = self.l_proj2(x_l).permute(0, 2, 3, 1) # B, H, W, c
V_r = self.r_proj2(x_r).permute(0, 2, 3, 1) # B, H, W, c
lfeats_projected = self.l_proj(lfeats.view(B, C, H, W)).permute(0, 2, 3, 1) # B, H, W, c
rfeats_projected = self.r_proj(rfeats.view(B, C, H, W)).permute(0, 2, 3, 1) # B, H, W, c
# (B, H, W, c) x (B, H, c, W) -> (B, H, W, W)
attention = torch.matmul(Q_l, Q_r_T) * self.scale
lresidual = torch.matmul(torch.softmax(relations, dim=-1), rfeats_projected) #B, H, W, c
rresidual = torch.matmul(torch.softmax(relations.permute(0, 1, 3, 2), dim=-1), lfeats_projected) #B, H, W, c
F_r2l = torch.matmul(torch.softmax(attention, dim=-1), V_r) #B, H, W, c
F_l2r = torch.matmul(torch.softmax(attention.permute(0, 1, 3, 2), dim=-1), V_l) #B, H, W, c
lresidual = lresidual.permute(0, 3, 1, 2).view(B, C, H, W) * self.beta
rresidual = rresidual.permute(0, 3, 1, 2).view(B, C, H, W) * self.gamma
return lfeats + lresidual, rfeats + rresidual
# scale
F_r2l = F_r2l.permute(0, 3, 1, 2) * self.beta
F_l2r = F_l2r.permute(0, 3, 1, 2) * self.gamma
return x_l + F_r2l, x_r + F_l2r
class DropPath(nn.Module):
def __init__(self, drop_rate, module):
@ -91,10 +77,13 @@ class DropPath(nn.Module):
return new_feats
class NAFBlockSR(nn.Module):
def __init__(self, c, fusion=False, drop_out_rate=0.):
'''
NAFBlock for Super-Resolution
'''
def __init__(self, c, fusion=False, drop_out_rate=0.):
super().__init__()
self.blk = NAFBlock(c, drop_out_rate=drop_out_rate)
self.fusion = FusionModule(c) if fusion else None
self.fusion = SCAM(c) if fusion else None
def forward(self, *feats):
feats = tuple([self.blk(x) for x in feats])
@ -102,11 +91,13 @@ class NAFBlockSR(nn.Module):
feats = self.fusion(*feats)
return feats
class NAFNetSR(nn.Module):
def __init__(self, img_channel=3, width=16, num_blks=1, drop_path_rate=0., drop_out_rate=0., fusion_from=-1, fusion_to=-1, dual=True, up_scale=4):
'''
NAFNet for Super-Resolution
'''
def __init__(self, up_scale=4, width=48, num_blks=16, img_channel=3, drop_path_rate=0., drop_out_rate=0., fusion_from=-1, fusion_to=-1, dual=False):
super().__init__()
self.dual = dual
self.dual = dual # dual input for stereo SR (left view, right view)
self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
bias=True)
self.body = MySequential(
@ -137,11 +128,10 @@ class NAFNetSR(nn.Module):
out = out + inp_hr
return out
class NAFNetSRLocal(Local_Base, NAFNetSR):
def __init__(self, *args, train_size=(1, 6, 64, 64), fast_imp=False, **kwargs):
class NAFSSR(Local_Base, NAFNetSR):
def __init__(self, *args, train_size=(1, 6, 30, 90), fast_imp=False, fusion_from=-1, fusion_to=1000, **kwargs):
Local_Base.__init__(self)
NAFNetSR.__init__(self, *args, **kwargs)
NAFNetSR.__init__(self, *args, img_channel=3, fusion_from=fusion_from, fusion_to=fusion_to, dual=True, **kwargs)
N, C, H, W = train_size
base_size = (int(H * 1.5), int(W * 1.5))
@ -151,41 +141,14 @@ class NAFNetSRLocal(Local_Base, NAFNetSR):
self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
if __name__ == '__main__':
img_channel = 3
num_blks = 64
width = 96
# num_blks = 32
# width = 64
# num_blks = 16
# width = 48
dual=True
# fusion_from = 0
# fusion_to = num_blks
fusion_from = 0
fusion_to = 1000
num_blks = 128
width = 128
droppath=0.1
train_size = (1, 6, 30, 90)
net = NAFNetSRLocal(up_scale=2,train_size=train_size, fast_imp=True, img_channel=img_channel, width=width, num_blks=num_blks, dual=dual,
fusion_from=fusion_from,
fusion_to=fusion_to, drop_path_rate=droppath)
# net = NAFNetSR(img_channel=img_channel, width=width, num_blks=num_blks, dual=dual,
# fusion_from=fusion_from,
# fusion_to=fusion_to, drop_path_rate=droppath)
net = NAFSSR(up_scale=2,train_size=train_size, fast_imp=True, width=width, num_blks=num_blks, drop_path_rate=droppath)
c = 6 if dual else 3
a = torch.randn((2, c, 24, 23))
b = net(a)
print(b.shape)
# inp_shape = (6, 128, 128)
inp_shape = (c, 64, 64)
# inp_shape = (6, 256, 96)
inp_shape = (6, 64, 64)
from ptflops import get_model_complexity_info
FLOPS = 0
@ -195,7 +158,7 @@ if __name__ == '__main__':
print(params)
macs = float(macs[:-4]) + FLOPS / 10 ** 9
print('mac', macs, params, 'fusion from .. to ', fusion_from, fusion_to)
print('mac', macs, params)
# from basicsr.models.archs.arch_util import measure_inference_speed
# net = net.cuda()

View File

@ -300,7 +300,7 @@ class ImageRestorationModel(BaseModel):
R_img = sr_img[:, :, 3:]
# visual_dir = osp.join('visual_results', dataset_name, self.opt['name'])
visual_dir = osp.join('visual_results', self.opt['name'], dataset_name)
visual_dir = osp.join(self.opt['path']['visualization'], dataset_name)
imwrite(L_img, osp.join(visual_dir, f'{img_name}_L.png'))
imwrite(R_img, osp.join(visual_dir, f'{img_name}_R.png'))

View File

@ -1,5 +1,5 @@
# GENERATED VERSION FILE
# TIME: Fri Apr 1 17:46:10 2022
__version__ = '1.2.0+e41bf19'
# TIME: Mon Apr 18 21:35:20 2022
__version__ = '1.2.0+386ca20'
short_version = '1.2.0'
version_info = (1, 2, 0)

139
docs/StereoSR.md 100644
View File

@ -0,0 +1,139 @@
## NAFSSR: Stereo Image Super-Resolution Using NAFNet
The official PyTorch implementation of the paper **[NAFSSR: Stereo Image Super-Resolution Using NAFNet]()**
#### Xiaojie Chu\*, Liangyu Chen\*, Wenqing Yu
>This paper proposes a simple baseline named NAFSSR for stereo image super-resolution. We use a stack of NAFNet's Block (NAFBlock) for intra-view feature extraction and combine it with Stereo Cross Attention Modules (SCAM) for cross-view feature interaction.
<img src=../figures/NAFSSR_arch.jpg>
>NAFSSR outperforms the state-of-the-art methods on the KITTI 2012, KITTI 2015, Middlebury, and Flickr1024 datasets. With NAFSSR, we won **1st place** in the [NTIRE 2022 Stereo Image Super-resolution Challenge](https://codalab.lisn.upsaclay.fr/competitions/1598).
<p align="center">
<img src=../figures/NAFSSR_params.jpg width=70%>
</p>
# Reproduce the Stereo SR Results
## 1. Data Preparation
Following previous works, our models are trained on the Flickr1024 and Middlebury datasets, exactly the same training data as <a href="https://github.com/YingqianWang/iPASSR">iPASSR</a>. Please visit their homepage and follow their instructions to download and prepare the datasets.
#### Download and prepare the train set and place it in ```./datasets/StereoSR```
#### Download and prepare the evaluation data and place it in ```./datasets/StereoSR/test```
The structure of the `datasets` directory should look like this:
```
datasets
├── StereoSR
│ ├── patches_x2
│ │ ├── 000001
│ │ ├── 000002
│ │ ├── ...
│ │ ├── 298142
│ │ └── 298143
│ ├── patches_x4
│ │ ├── 000001
│ │ ├── 000002
│ │ ├── ...
│ │ ├── 049019
│ │ └── 049020
│ ├── test
│ │ ├── Flickr1024
│ │ │ ├── hr
│ │ │ ├── lr_x2
│ │ │ └── lr_x4
│ │ ├── KITTI2012
│ │ │ ├── hr
│ │ │ ├── lr_x2
│ │ │ └── lr_x4
│ │ ├── KITTI2015
│ │ │ ├── hr
│ │ │ ├── lr_x2
│ │ │ └── lr_x4
│ │ └── Middlebury
│ │ ├── hr
│ │ ├── lr_x2
│ │ └── lr_x4
```
## 2. Evaluation
#### Download the pretrained models to ```./experiments/pretrained_models/```
| name | scale |#Params|PSNR|SSIM| pretrained models | configs |
|:----:|:----:|:----:|:----:|:----:|:----:|-----:|
|NAFSSR-T|x4|0.46M|23.69|0.7384|[gdrive](https://drive.google.com/file/d/1owfYG1KTXFMl4wHpUZefWAcVlBpLohe5/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1yC5XzJcL5peC1YuW3MkFMA?pwd=5j1u)|[train](../options/train/NAFSSR/NAFSSR-T_x4.yml) \| [test](../options/test/NAFSSR/NAFSSR-T_4x.yml)|
|NAFSSR-S|x4|1.56M|23.88|0.7468|[gdrive](https://drive.google.com/file/d/1RpfS2lemsgetIQwBwkZpZwLBJfOTDCU5/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1XvwM5KVhNsKAxWbxU85SFA?pwd=n5au)|[train](../options/train/NAFSSR/NAFSSR-S_x4.yml) \| [test](../options/test/NAFSSR/NAFSSR-S_4x.yml)|
|NAFSSR-B|x4|6.80M|24.07|0.7551|[gdrive](https://drive.google.com/file/d/1Su0OTp66_NsXUbqTAIi1msvsp0G5WVxp/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/18tVlH-QIVtvDC1LM2oPatw?pwd=3up5)|[train](../options/train/NAFSSR/NAFSSR-B_x4.yml) \| [test](../options/test/NAFSSR/NAFSSR-B_4x.yml)|
|NAFSSR-L|x4|23.83M|24.17|0.7589|[gdrive](https://drive.google.com/file/d/1TIdQhPtBrZb2wrBdAp9l8NHINLeExOwb/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1P8ioEuI1gwydA2Avr3nUvw?pwd=qs7a)|[train](../options/train/NAFSSR/NAFSSR-L_x4.yml) \| [test](../options/test/NAFSSR/NAFSSR-L_4x.yml)|
|NAFSSR-T|x2|0.46M|28.94|0.9128|[gdrive](https://drive.google.com/file/d/1sBivtt5KaFMjMhBwyajYy1uemIEFQBvW/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1RDW923v0e0G_eYvTF8I7Dg?pwd=utgs)|[train](../options/train/NAFSSR/NAFSSR-T_x2.yml) \| [test](../options/test/NAFSSR/NAFSSR-T_2x.yml)|
|NAFSSR-S|x2|1.56M|29.19|0.9160|[gdrive](https://drive.google.com/file/d/1caVrp3fFSpwiU8RPGXe-zDD1tU110yxA/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1qJmQzv-YV1V9raim57pTIQ?pwd=t2or)|[train](../options/train/NAFSSR/NAFSSR-S_x2.yml) \| [test](../options/test/NAFSSR/NAFSSR-S_2x.yml)|
|NAFSSR-B|x2|6.80M|29.54|0.7551|[gdrive](https://drive.google.com/file/d/1gOfDTfyCaff_xNm86u8sN_3MtAZvnQEk/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1IkadqW1uWx9xM5K2ETJ9-A?pwd=pv1f)|[train](../options/train/NAFSSR/NAFSSR-B_x2.yml) \| [test](../options/test/NAFSSR/NAFSSR-B_2x.yml)|
|NAFSSR-L|x2|23.79M|29.68|0.9221|[gdrive](https://drive.google.com/file/d/1SZ6bQVYTVS_AXedBEr-_mBCC-qGYHLmf/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1GS6YQSSECH8hAKhvzw6GyQ?pwd=2v3v)|[train](../options/train/NAFSSR/NAFSSR-L_x2.yml) \| [test](../options/test/NAFSSR/NAFSSR-L_2x.yml)|
*PSNR/SSIM are evaluated on the Flickr1024 test set.*
### Testing on KITTI2012, KITTI2015, Middlebury, Flickr1024 datasets
* NAFSSR-T for 4x SR:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/NAFSSR/NAFSSR-T_4x.yml --launcher pytorch
```
* NAFSSR-S for 4x SR:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/NAFSSR/NAFSSR-S_4x.yml --launcher pytorch
```
* NAFSSR-B for 4x SR:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/NAFSSR/NAFSSR-B_4x.yml --launcher pytorch
```
* NAFSSR-L for 4x SR:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/NAFSSR/NAFSSR-L_4x.yml --launcher pytorch
```
* NAFSSR-L for 2x SR:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/NAFSSR/NAFSSR-L_2x.yml --launcher pytorch
```
* Tests run on a single GPU by default. Set ```--nproc_per_node``` to the number of GPUs for distributed validation.
## 3. Training
* NAFSSR-B for 4x SR:
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/NAFSSR/NAFSSR-B_x4.yml --launcher pytorch
```
* NAFSSR-S for 4x SR:
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/NAFSSR/NAFSSR-S_x4.yml --launcher pytorch
```
* Training uses 8 GPUs by default. Set ```--nproc_per_node``` to the number of GPUs for distributed training.

Binary file not shown.

After

Width:  |  Height:  |  Size: 132 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 403 KiB

View File

@ -0,0 +1,85 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFSSR-B_2x
model_type: ImageRestorationModel
scale: 2
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 10
# dataset and data loader settings
datasets:
test0:
name: KITTI2012
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2012/hr
dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x2
io_backend:
type: disk
test1:
name: KITTI2015
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2015/hr
dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x2
io_backend:
type: disk
test2:
name: Middlebury
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Middlebury/hr
dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x2
io_backend:
type: disk
test3:
name: Flickr1024
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 2
width: 96
num_blks: 64
# path
path:
pretrain_network_g: experiments/pretrained_models/NAFSSR-B_2x.pth
strict_load_g: true
resume_state: ~
# validation settings
val:
save_img: true
grids: false
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# psnr_left: # metric name, can be arbitrary
# type: calculate_psnr_left
# crop_border: 0
# test_y_channel: false
# ssim_left:
# type: calculate_skimage_ssim_left
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,85 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFSSR-B_4x
model_type: ImageRestorationModel
scale: 4
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 10
# dataset and data loader settings
datasets:
test0:
name: KITTI2012
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2012/hr
dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x4
io_backend:
type: disk
test1:
name: KITTI2015
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2015/hr
dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x4
io_backend:
type: disk
test2:
name: Middlebury
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Middlebury/hr
dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x4
io_backend:
type: disk
test3:
name: Flickr1024
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 4
width: 96
num_blks: 64
# path
path:
pretrain_network_g: experiments/pretrained_models/NAFSSR-B_4x.pth
strict_load_g: true
resume_state: ~
# validation settings
val:
save_img: true
grids: false
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# psnr_left: # metric name, can be arbitrary
# type: calculate_psnr_left
# crop_border: 0
# test_y_channel: false
# ssim_left:
# type: calculate_skimage_ssim_left
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,85 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFSSR-L_2x
model_type: ImageRestorationModel
scale: 2
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 10
# dataset and data loader settings
datasets:
test0:
name: KITTI2012
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2012/hr
dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x2
io_backend:
type: disk
test1:
name: KITTI2015
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2015/hr
dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x2
io_backend:
type: disk
test2:
name: Middlebury
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Middlebury/hr
dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x2
io_backend:
type: disk
test3:
name: Flickr1024
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 2
width: 128
num_blks: 128
# path
path:
pretrain_network_g: experiments/pretrained_models/NAFSSR-L_2x.pth
strict_load_g: true
resume_state: ~
# validation settings
val:
save_img: true
grids: false
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# psnr_left: # metric name, can be arbitrary
# type: calculate_psnr_left
# crop_border: 0
# test_y_channel: false
# ssim_left:
# type: calculate_skimage_ssim_left
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,85 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFSSR-L_4x
model_type: ImageRestorationModel
scale: 4
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 10
# dataset and data loader settings
datasets:
test0:
name: KITTI2012
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2012/hr
dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x4
io_backend:
type: disk
test1:
name: KITTI2015
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2015/hr
dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x4
io_backend:
type: disk
test2:
name: Middlebury
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Middlebury/hr
dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x4
io_backend:
type: disk
test3:
name: Flickr1024
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 4
width: 128
num_blks: 128
# path
path:
pretrain_network_g: experiments/pretrained_models/NAFSSR-L_4x.pth
strict_load_g: true
resume_state: ~
# validation settings
val:
save_img: true
grids: false
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# psnr_left: # metric name, can be arbitrary
# type: calculate_psnr_left
# crop_border: 0
# test_y_channel: false
# ssim_left:
# type: calculate_skimage_ssim_left
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,85 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFSSR-S_2x
model_type: ImageRestorationModel
scale: 2
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 10
# dataset and data loader settings
datasets:
test0:
name: KITTI2012
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2012/hr
dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x2
io_backend:
type: disk
test1:
name: KITTI2015
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2015/hr
dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x2
io_backend:
type: disk
test2:
name: Middlebury
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Middlebury/hr
dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x2
io_backend:
type: disk
test3:
name: Flickr1024
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 2
width: 64
num_blks: 32
# path
path:
pretrain_network_g: experiments/pretrained_models/NAFSSR-S_2x.pth
strict_load_g: true
resume_state: ~
# validation settings
val:
save_img: true
grids: false
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# psnr_left: # metric name, can be arbitrary
# type: calculate_psnr_left
# crop_border: 0
# test_y_channel: false
# ssim_left:
# type: calculate_skimage_ssim_left
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,85 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFSSR-S_4x
model_type: ImageRestorationModel
scale: 4
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 10
# dataset and data loader settings
datasets:
test0:
name: KITTI2012
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2012/hr
dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x4
io_backend:
type: disk
test1:
name: KITTI2015
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2015/hr
dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x4
io_backend:
type: disk
test2:
name: Middlebury
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Middlebury/hr
dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x4
io_backend:
type: disk
test3:
name: Flickr1024
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 4
width: 64
num_blks: 32
# path
path:
pretrain_network_g: experiments/pretrained_models/NAFSSR-S_4x.pth
strict_load_g: true
resume_state: ~
# validation settings
val:
save_img: true
grids: false
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# psnr_left: # metric name, can be arbitrary
# type: calculate_psnr_left
# crop_border: 0
# test_y_channel: false
# ssim_left:
# type: calculate_skimage_ssim_left
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,85 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFSSR-T_2x
model_type: ImageRestorationModel
scale: 2
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 10
# dataset and data loader settings
datasets:
test0:
name: KITTI2012
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2012/hr
dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x2
io_backend:
type: disk
test1:
name: KITTI2015
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2015/hr
dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x2
io_backend:
type: disk
test2:
name: Middlebury
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Middlebury/hr
dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x2
io_backend:
type: disk
test3:
name: Flickr1024
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 2
width: 48
num_blks: 16
# path
path:
pretrain_network_g: experiments/pretrained_models/NAFSSR-T_2x.pth
strict_load_g: true
resume_state: ~
# validation settings
val:
save_img: true
grids: false
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# psnr_left: # metric name, can be arbitrary
# type: calculate_psnr_left
# crop_border: 0
# test_y_channel: false
# ssim_left:
# type: calculate_skimage_ssim_left
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,85 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFSSR-T_4x
model_type: ImageRestorationModel
scale: 4
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 10
# dataset and data loader settings
datasets:
test0:
name: KITTI2012
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2012/hr
dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x4
io_backend:
type: disk
test1:
name: KITTI2015
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/KITTI2015/hr
dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x4
io_backend:
type: disk
test2:
name: Middlebury
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Middlebury/hr
dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x4
io_backend:
type: disk
test3:
name: Flickr1024
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 4
width: 48
num_blks: 16
# path
path:
pretrain_network_g: experiments/pretrained_models/NAFSSR-T_4x.pth
strict_load_g: true
resume_state: ~
# validation settings
val:
save_img: true
grids: false
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# psnr_left: # metric name, can be arbitrary
# type: calculate_psnr_left
# crop_border: 0
# test_y_channel: false
# ssim_left:
# type: calculate_skimage_ssim_left
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,112 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFNetSR-B_x2
model_type: ImageRestorationModel
scale: 2
num_gpu: 8
manual_seed: 10
datasets:
train:
name: Flickr1024-sr-train
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/patches_x2/
dataroot_lq: datasets/StereoSR/patches_x2/
io_backend:
type: disk
gt_size_h: 60
gt_size_w: 180
use_hflip: true
use_vflip: true
use_rot: false
flip_RGB: true
# data loader
use_shuffle: true
num_worker_per_gpu: 4
batch_size_per_gpu: 4
dataset_enlarge_ratio: 1
prefetch_mode: ~
val:
name: Flickr1024-sr-test
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 2
width: 96
num_blks: 64
drop_path_rate: 0.2
train_size: [1, 6, 30, 90]
drop_out_rate: 0.
# path
path:
pretrain_network_g: ~
strict_load_g: true
resume_state: ~
# training settings
train:
optim_g:
type: AdamW
lr: !!float 3e-3
weight_decay: !!float 0
betas: [0.9, 0.9]
scheduler:
type: TrueCosineAnnealingLR
T_max: 100000
eta_min: !!float 1e-7
total_iter: 100000
warmup_iter: -1 # no warm up
mixup: false
# losses
pixel_opt:
type: MSELoss
loss_weight: 1.
reduction: mean
# validation settings
val:
val_freq: !!float 2e4
save_img: false
trans_num: 1
max_minibatch: 1
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# logging settings
logger:
print_freq: 200
save_checkpoint_freq: !!float 1e4
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,112 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFNetSR-B_x4
model_type: ImageRestorationModel
scale: 4
num_gpu: 8
manual_seed: 10
datasets:
train:
name: Flickr1024-sr-train
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/patches_x4/
dataroot_lq: datasets/StereoSR/patches_x4/
io_backend:
type: disk
gt_size_h: 120
gt_size_w: 360
use_hflip: true
use_vflip: true
use_rot: false
flip_RGB: true
# data loader
use_shuffle: true
num_worker_per_gpu: 4
batch_size_per_gpu: 4
dataset_enlarge_ratio: 1
prefetch_mode: ~
val:
name: Flickr1024-sr-test
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 4
width: 96
num_blks: 64
drop_path_rate: 0.2
train_size: [1, 6, 30, 90]
drop_out_rate: 0.
# path
path:
pretrain_network_g: ~
strict_load_g: true
resume_state: ~
# training settings
train:
optim_g:
type: AdamW
lr: !!float 3e-3
weight_decay: !!float 0
betas: [0.9, 0.9]
scheduler:
type: TrueCosineAnnealingLR
T_max: 100000
eta_min: !!float 1e-7
total_iter: 100000
warmup_iter: -1 # no warm up
mixup: false
# losses
pixel_opt:
type: MSELoss
loss_weight: 1.
reduction: mean
# validation settings
val:
val_freq: !!float 2e4
save_img: false
trans_num: 1
max_minibatch: 1
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# logging settings
logger:
print_freq: 200
save_checkpoint_freq: !!float 1e4
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,112 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFNetSR-L_x2
model_type: ImageRestorationModel
scale: 2
num_gpu: 8
manual_seed: 10
datasets:
train:
name: Flickr1024-sr-train
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/patches_x2/
dataroot_lq: datasets/StereoSR/patches_x2/
io_backend:
type: disk
gt_size_h: 60
gt_size_w: 180
use_hflip: true
use_vflip: true
use_rot: false
flip_RGB: true
# data loader
use_shuffle: true
num_worker_per_gpu: 4
batch_size_per_gpu: 4
dataset_enlarge_ratio: 1
prefetch_mode: ~
val:
name: Flickr1024-sr-test
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 2
width: 128
num_blks: 128
drop_path_rate: 0.3
train_size: [1, 6, 30, 90]
drop_out_rate: 0.
# path
path:
pretrain_network_g: ~
strict_load_g: true
resume_state: ~
# training settings
train:
optim_g:
type: AdamW
lr: !!float 3e-3
weight_decay: !!float 0
betas: [0.9, 0.9]
scheduler:
type: TrueCosineAnnealingLR
T_max: 100000
eta_min: !!float 1e-7
total_iter: 100000
warmup_iter: -1 # no warm up
mixup: false
# losses
pixel_opt:
type: MSELoss
loss_weight: 1.
reduction: mean
# validation settings
val:
val_freq: !!float 2e4
save_img: false
trans_num: 1
max_minibatch: 1
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# logging settings
logger:
print_freq: 200
save_checkpoint_freq: !!float 1e4
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,112 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFNetSR-L_x4
model_type: ImageRestorationModel
scale: 4
num_gpu: 8
manual_seed: 10
datasets:
train:
name: Flickr1024-sr-train
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/patches_x4/
dataroot_lq: datasets/StereoSR/patches_x4/
io_backend:
type: disk
gt_size_h: 120
gt_size_w: 360
use_hflip: true
use_vflip: true
use_rot: false
flip_RGB: true
# data loader
use_shuffle: true
num_worker_per_gpu: 4
batch_size_per_gpu: 4
dataset_enlarge_ratio: 1
prefetch_mode: ~
val:
name: Flickr1024-sr-test
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 4
width: 128
num_blks: 128
drop_path_rate: 0.3
train_size: [1, 6, 30, 90]
drop_out_rate: 0.
# path
path:
pretrain_network_g: ~
strict_load_g: true
resume_state: ~
# training settings
train:
optim_g:
type: AdamW
lr: !!float 3e-3
weight_decay: !!float 0
betas: [0.9, 0.9]
scheduler:
type: TrueCosineAnnealingLR
T_max: 100000
eta_min: !!float 1e-7
total_iter: 100000
warmup_iter: -1 # no warm up
mixup: false
# losses
pixel_opt:
type: MSELoss
loss_weight: 1.
reduction: mean
# validation settings
val:
val_freq: !!float 2e4
save_img: false
trans_num: 1
max_minibatch: 1
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# logging settings
logger:
print_freq: 200
save_checkpoint_freq: !!float 1e4
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,112 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFNetSR-S_x2
model_type: ImageRestorationModel
scale: 2
num_gpu: 8
manual_seed: 10
datasets:
train:
name: Flickr1024-sr-train
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/patches_x2/
dataroot_lq: datasets/StereoSR/patches_x2/
io_backend:
type: disk
gt_size_h: 60
gt_size_w: 180
use_hflip: true
use_vflip: true
use_rot: false
flip_RGB: true
# data loader
use_shuffle: true
num_worker_per_gpu: 4
batch_size_per_gpu: 4
dataset_enlarge_ratio: 1
prefetch_mode: ~
val:
name: Flickr1024-sr-test
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 2
width: 64
num_blks: 32
drop_path_rate: 0.1
train_size: [1, 6, 30, 90]
drop_out_rate: 0.
# path
path:
pretrain_network_g: ~
strict_load_g: true
resume_state: ~
# training settings
train:
optim_g:
type: AdamW
lr: !!float 3e-3
weight_decay: !!float 0
betas: [0.9, 0.9]
scheduler:
type: TrueCosineAnnealingLR
T_max: 100000
eta_min: !!float 1e-7
total_iter: 100000
warmup_iter: -1 # no warm up
mixup: false
# losses
pixel_opt:
type: MSELoss
loss_weight: 1.
reduction: mean
# validation settings
val:
val_freq: !!float 2e4
save_img: false
trans_num: 1
max_minibatch: 1
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# logging settings
logger:
print_freq: 200
save_checkpoint_freq: !!float 1e4
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,112 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFNetSR-S_x4
model_type: ImageRestorationModel
scale: 4
num_gpu: 8
manual_seed: 10
datasets:
train:
name: Flickr1024-sr-train
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/patches_x4/
dataroot_lq: datasets/StereoSR/patches_x4/
io_backend:
type: disk
gt_size_h: 120
gt_size_w: 360
use_hflip: true
use_vflip: true
use_rot: false
flip_RGB: true
# data loader
use_shuffle: true
num_worker_per_gpu: 4
batch_size_per_gpu: 4
dataset_enlarge_ratio: 1
prefetch_mode: ~
val:
name: Flickr1024-sr-test
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 4
width: 64
num_blks: 32
drop_path_rate: 0.1
train_size: [1, 6, 30, 90]
drop_out_rate: 0.
# path
path:
pretrain_network_g: ~
strict_load_g: true
resume_state: ~
# training settings
train:
optim_g:
type: AdamW
lr: !!float 3e-3
weight_decay: !!float 0
betas: [0.9, 0.9]
scheduler:
type: TrueCosineAnnealingLR
T_max: 100000
eta_min: !!float 1e-7
total_iter: 100000
warmup_iter: -1 # no warm up
mixup: false
# losses
pixel_opt:
type: MSELoss
loss_weight: 1.
reduction: mean
# validation settings
val:
val_freq: !!float 2e4
save_img: false
trans_num: 1
max_minibatch: 1
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# logging settings
logger:
print_freq: 200
save_checkpoint_freq: !!float 1e4
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,112 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFNetSR-T_x2
model_type: ImageRestorationModel
scale: 2
num_gpu: 8
manual_seed: 10
datasets:
train:
name: Flickr1024-sr-train
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/patches_x2/
dataroot_lq: datasets/StereoSR/patches_x2/
io_backend:
type: disk
gt_size_h: 60
gt_size_w: 180
use_hflip: true
use_vflip: true
use_rot: false
flip_RGB: true
# data loader
use_shuffle: true
num_worker_per_gpu: 4
batch_size_per_gpu: 4
dataset_enlarge_ratio: 1
prefetch_mode: ~
val:
name: Flickr1024-sr-test
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 2
width: 48
num_blks: 16
drop_path_rate: 0.
train_size: [1, 6, 30, 90]
drop_out_rate: 0.
# path
path:
pretrain_network_g: ~
strict_load_g: true
resume_state: ~
# training settings
train:
optim_g:
type: AdamW
lr: !!float 3e-3
weight_decay: !!float 0
betas: [0.9, 0.9]
scheduler:
type: TrueCosineAnnealingLR
T_max: 400000
eta_min: !!float 1e-7
total_iter: 400000
warmup_iter: -1 # no warm up
mixup: false
# losses
pixel_opt:
type: MSELoss
loss_weight: 1.
reduction: mean
# validation settings
val:
val_freq: !!float 2e4
save_img: false
trans_num: 1
max_minibatch: 1
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# logging settings
logger:
print_freq: 200
save_checkpoint_freq: !!float 1e4
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,112 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFNetSR-T_x4
model_type: ImageRestorationModel
scale: 4
num_gpu: 8
manual_seed: 10
datasets:
train:
name: Flickr1024-sr-train
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/patches_x4/
dataroot_lq: datasets/StereoSR/patches_x4/
io_backend:
type: disk
gt_size_h: 120
gt_size_w: 360
use_hflip: true
use_vflip: true
use_rot: false
flip_RGB: true
# data loader
use_shuffle: true
num_worker_per_gpu: 4
batch_size_per_gpu: 4
dataset_enlarge_ratio: 1
prefetch_mode: ~
val:
name: Flickr1024-sr-test
type: PairedStereoImageDataset
dataroot_gt: datasets/StereoSR/test/Flickr1024/hr
dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4
io_backend:
type: disk
# network structures
network_g:
type: NAFSSR
up_scale: 4
width: 48
num_blks: 16
drop_path_rate: 0.
train_size: [1, 6, 30, 90]
drop_out_rate: 0.
# path
path:
pretrain_network_g: ~
strict_load_g: true
resume_state: ~
# training settings
train:
optim_g:
type: AdamW
lr: !!float 3e-3
weight_decay: !!float 0
betas: [0.9, 0.9]
scheduler:
type: TrueCosineAnnealingLR
T_max: 400000
eta_min: !!float 1e-7
total_iter: 400000
warmup_iter: -1 # no warm up
mixup: false
# losses
pixel_opt:
type: MSELoss
loss_weight: 1.
reduction: mean
# validation settings
val:
val_freq: !!float 2e4
save_img: false
trans_num: 1
max_minibatch: 1
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim
# logging settings
logger:
print_freq: 200
save_checkpoint_freq: !!float 1e4
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -57,6 +57,8 @@ python setup.py develop --no_cuda_ext
|NAFNet-SIDD-width32|SIDD|39.9672|0.9599|[gdrive](https://drive.google.com/file/d/1lsByk21Xw-6aW7epCwOQxvm6HYCQZPHZ/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1Xses38SWl-7wuyuhaGNhaw) (提取码: um97)|[train](./options/train/SIDD/NAFNet-width32.yml) \| [test](./options/test/SIDD/NAFNet-width32.yml)|
|NAFNet-SIDD-width64|SIDD|40.3045|0.9614|[gdrive](https://drive.google.com/file/d/14Fht1QQJ2gMlk4N1ERCRuElg8JfjrWWR/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/198kYyVSrY_xZF0jGv9U0sQ) (提取码: dton)|[train](./options/train/SIDD/NAFNet-width64.yml) \| [test](./options/test/SIDD/NAFNet-width64.yml)|
|NAFNet-REDS-width64|REDS|29.0903|0.8671|[gdrive](https://drive.google.com/file/d/14D4V4raNYIOhETfcuuLI3bGLB-OYIv6X/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1vg89ccbpIxg3mK9IONBfGg) (提取码: 9fas)|[train](./options/train/REDS/NAFNet-width64.yml) \| [test](./options/test/REDS/NAFNet-width64.yml)|
|NAFSSR-L_4x|Flickr1024|24.17|0.7589|[gdrive](https://drive.google.com/file/d/1TIdQhPtBrZb2wrBdAp9l8NHINLeExOwb/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/198kYyVSrY_xZF0jGv9U0sQ) (提取码: dton)|[train](./options/test/NAFSSR/NAFSSR-L_4x.yml) \| [test](./options/test/NAFSSR/NAFSSR-L_4x.yml)|
|NAFSSR-L_2x|Flickr1024|29.68|0.9221|[gdrive](https://drive.google.com/file/d/1SZ6bQVYTVS_AXedBEr-_mBCC-qGYHLmf/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/198kYyVSrY_xZF0jGv9U0sQ) (提取码: dton)|[train](./options/test/NAFSSR/NAFSSR-L_2x.yml) \| [test](./options/test/NAFSSR/NAFSSR-L_2x.yml)|
### Image Restoration Tasks
@ -65,7 +67,7 @@ python setup.py develop --no_cuda_ext
| Image Deblurring | GoPro | [link](./docs/GoPro.md) | [gdrive](https://drive.google.com/file/d/1S8u4TqQP6eHI81F9yoVR0be-DLh4cNgb/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1yNYQhznChafsbcfHO44aHQ) (提取码: 96ii) |
| Image Denoising | SIDD | [link](./docs/SIDD.md) | [gdrive](https://drive.google.com/file/d/1rbBYD64bfvbHOrN3HByNg0vz6gHQq7Np/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1wIubY6SeXRfZHpp6bAojqQ) (提取码: hu4t) |
| Image Deblurring with JPEG artifacts | REDS | [link](./docs/REDS.md) | [gdrive](https://drive.google.com/file/d/1FwHWYPXdPtUkPqckpz-WBitpVyPuXFRi/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/17T30w5xAtBQQ2P3wawLiVA) (提取码: put5) |
| Stereo Image Super-Resolution | Flickr1024+Middlebury | [link](./docs/StereoSR.md) | [gdrive](https://drive.google.com/drive/folders/1lTKe2TU7F-KcU-oaF8jqgoUwIMb6RW0w?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1kov6ivrSFy1FuToCATbyrA?pwd=q263) |
### Citations