diff --git a/.gitignore b/.gitignore index 8d979a0..032d2dd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,9 @@ .DS_Store .idea/* +experiments +logs/ +*results* +*__pycache__* +*.sh +datasets +basicsr.egg-info \ No newline at end of file diff --git a/basicsr/data/paired_image_SR_LR_dataset.py b/basicsr/data/paired_image_SR_LR_dataset.py index 110fa82..fe34495 100644 --- a/basicsr/data/paired_image_SR_LR_dataset.py +++ b/basicsr/data/paired_image_SR_LR_dataset.py @@ -183,6 +183,9 @@ class PairedImageSRLRDataset(data.Dataset): class PairedStereoImageDataset(data.Dataset): + ''' + Paired dataset for stereo SR (Flickr1024, KITTI, Middlebury) + ''' def __init__(self, opt): super(PairedStereoImageDataset, self).__init__() self.opt = opt @@ -255,14 +258,6 @@ class PairedStereoImageDataset(data.Dataset): gt_size = int(self.opt['gt_size']) gt_size_h, gt_size_w = gt_size, gt_size - - if 'flip_LR' in self.opt and self.opt['flip_LR']: - if np.random.rand() < 0.5: - img_gt = img_gt[:, :, [3, 4, 5, 0, 1, 2]] - img_lq = img_lq[:, :, [3, 4, 5, 0, 1, 2]] - - # img_gt, img_lq - if 'flip_RGB' in self.opt and self.opt['flip_RGB']: idx = [ [0, 1, 2, 3, 4, 5], @@ -276,42 +271,6 @@ class PairedStereoImageDataset(data.Dataset): img_gt = img_gt[:, :, idx] img_lq = img_lq[:, :, idx] - if 'inverse_RGB' in self.opt and self.opt['inverse_RGB']: - for i in range(3): - if np.random.rand() < 0.5: - img_gt[:, :, i] = 1 - img_gt[:, :, i] - img_gt[:, :, i+3] = 1 - img_gt[:, :, i+3] - img_lq[:, :, i] = 1 - img_lq[:, :, i] - img_lq[:, :, i+3] = 1 - img_lq[:, :, i+3] - - if 'naive_inverse_RGB' in self.opt and self.opt['naive_inverse_RGB']: - # for i in range(3): - if np.random.rand() < 0.5: - img_gt = 1 - img_gt - img_lq = 1 - img_lq - # img_gt[:, :, i] = 1 - img_gt[:, :, i] - # img_gt[:, :, i+3] = 1 - img_gt[:, :, i+3] - # img_lq[:, :, i] = 1 - img_lq[:, :, i] - # img_lq[:, :, i+3] = 1 - img_lq[:, :, i+3] - - if 'random_offset' in self.opt and self.opt['random_offset'] > 0: - # if np.random.rand() < 0.9: - S = int(self.opt['random_offset']) - - offsets = int(np.random.rand() * (S+1)) #1~S - s2, s4 = 0, 0 - - if np.random.rand() < 0.5: - s2 = offsets - else: - s4 = offsets - - _, w, _ = img_lq.shape - - img_lq = np.concatenate([img_lq[:, s2:w-s4, :3], img_lq[:, s4:w-s2, 3:]], axis=-1) - img_gt = np.concatenate( - [img_gt[:, 4 * s2:4*w-4 * s4, :3], img_gt[:, 4 * s4:4*w-4 * s2, 3:]], axis=-1) - # random crop img_gt, img_lq = img_gt.copy(), img_lq.copy() img_gt, img_lq = paired_random_crop_hw(img_gt, img_lq, gt_size_h, gt_size_w, scale, diff --git a/basicsr/models/archs/NAFNetSR_arch.py b/basicsr/models/archs/NAFSSR_arch.py similarity index 54% rename from basicsr/models/archs/NAFNetSR_arch.py rename to basicsr/models/archs/NAFSSR_arch.py index acf54d3..d189b0e 100644 --- a/basicsr/models/archs/NAFNetSR_arch.py +++ b/basicsr/models/archs/NAFSSR_arch.py @@ -22,56 +22,42 @@ from basicsr.models.archs.NAFNet_arch import LayerNorm2d, NAFBlock from basicsr.models.archs.arch_util import MySequential from basicsr.models.archs.local_arch import Local_Base -class GenerateRelations(nn.Module): +class SCAM(nn.Module): + ''' + Stereo Cross Attention Module (SCAM) + ''' def __init__(self, c): super().__init__() self.scale = c ** -0.5 self.norm_l = LayerNorm2d(c) self.norm_r = LayerNorm2d(c) - - self.l_proj = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0) - self.r_proj = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0) - - def forward(self, lfeats, rfeats): - B, C, H, W = lfeats.shape - - lfeats = lfeats.view(B, C, H, W) - rfeats = rfeats.view(B, C, H, W) - - lfeats, rfeats = self.l_proj(self.norm_l(lfeats)), self.r_proj(self.norm_r(rfeats)) - - x = lfeats.permute(0, 2, 3, 1) #B H W c - y = rfeats.permute(0, 2, 1, 3) #B H c W - - z = torch.matmul(x, y) #B H W W - - return self.scale * z - -class FusionModule(nn.Module): - def __init__(self, c): - super().__init__() - self.relation_generator = GenerateRelations(c) + self.l_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0) + self.r_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0) + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) - self.l_proj = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0) - self.r_proj = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0) + self.l_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0) + self.r_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0) - def forward(self, lfeats, rfeats): - B, C, H, W = lfeats.shape + def forward(self, x_l, x_r): + Q_l = self.l_proj1(self.norm_l(x_l)).permute(0, 2, 3, 1) # B, H, W, c + Q_r_T = self.r_proj1(self.norm_r(x_r)).permute(0, 2, 1, 3) # B, H, c, W (transposed) - relations = self.relation_generator(lfeats, rfeats) # B, H, W, W + V_l = self.l_proj2(x_l).permute(0, 2, 3, 1) # B, H, W, c + V_r = self.r_proj2(x_r).permute(0, 2, 3, 1) # B, H, W, c - lfeats_projected = self.l_proj(lfeats.view(B, C, H, W)).permute(0, 2, 3, 1) # B, H, W, c - rfeats_projected = self.r_proj(rfeats.view(B, C, H, W)).permute(0, 2, 3, 1) # B, H, W, c + # (B, H, W, c) x (B, H, c, W) -> (B, H, W, W) + attention = torch.matmul(Q_l, Q_r_T) * self.scale - lresidual = torch.matmul(torch.softmax(relations, dim=-1), rfeats_projected) #B, H, W, c - rresidual = torch.matmul(torch.softmax(relations.permute(0, 1, 3, 2), dim=-1), lfeats_projected) #B, H, W, c + F_r2l = torch.matmul(torch.softmax(attention, dim=-1), V_r) #B, H, W, c + F_l2r = torch.matmul(torch.softmax(attention.permute(0, 1, 3, 2), dim=-1), V_l) #B, H, W, c - lresidual = lresidual.permute(0, 3, 1, 2).view(B, C, H, W) * self.beta - rresidual = rresidual.permute(0, 3, 1, 2).view(B, C, H, W) * self.gamma - return lfeats + lresidual, rfeats + rresidual + # scale + F_r2l = F_r2l.permute(0, 3, 1, 2) * self.beta + F_l2r = F_l2r.permute(0, 3, 1, 2) * self.gamma + return x_l + F_r2l, x_r + F_l2r class DropPath(nn.Module): def __init__(self, drop_rate, module): @@ -91,10 +77,13 @@ class DropPath(nn.Module): return new_feats class NAFBlockSR(nn.Module): - def __init__(self, c, fusion=False, drop_out_rate=0.): + ''' + NAFBlock for Super-Resolution + ''' + def __init__(self, c, fusion=False, drop_out_rate=0.): super().__init__() self.blk = NAFBlock(c, drop_out_rate=drop_out_rate) - self.fusion = FusionModule(c) if fusion else None + self.fusion = SCAM(c) if fusion else None def forward(self, *feats): feats = tuple([self.blk(x) for x in feats]) @@ -102,11 +91,13 @@ class NAFBlockSR(nn.Module): feats = self.fusion(*feats) return feats - class NAFNetSR(nn.Module): - def __init__(self, img_channel=3, width=16, num_blks=1, drop_path_rate=0., drop_out_rate=0., fusion_from=-1, fusion_to=-1, dual=True, up_scale=4): + ''' + NAFNet for Super-Resolution + ''' + def __init__(self, up_scale=4, width=48, num_blks=16, img_channel=3, drop_path_rate=0., drop_out_rate=0., fusion_from=-1, fusion_to=-1, dual=False): super().__init__() - self.dual = dual + self.dual = dual # dual input for stereo SR (left view, right view) self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, bias=True) self.body = MySequential( @@ -137,11 +128,10 @@ class NAFNetSR(nn.Module): out = out + inp_hr return out - -class NAFNetSRLocal(Local_Base, NAFNetSR): - def __init__(self, *args, train_size=(1, 6, 64, 64), fast_imp=False, **kwargs): +class NAFSSR(Local_Base, NAFNetSR): + def __init__(self, *args, train_size=(1, 6, 30, 90), fast_imp=False, fusion_from=-1, fusion_to=1000, **kwargs): Local_Base.__init__(self) - NAFNetSR.__init__(self, *args, **kwargs) + NAFNetSR.__init__(self, *args, img_channel=3, fusion_from=fusion_from, fusion_to=fusion_to, dual=True, **kwargs) N, C, H, W = train_size base_size = (int(H * 1.5), int(W * 1.5)) @@ -151,41 +141,14 @@ class NAFNetSRLocal(Local_Base, NAFNetSR): self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp) if __name__ == '__main__': - img_channel = 3 - num_blks = 64 - width = 96 - # num_blks = 32 - # width = 64 - # num_blks = 16 - # width = 48 - dual=True - # fusion_from = 0 - # fusion_to = num_blks - fusion_from = 0 - fusion_to = 1000 + num_blks = 128 + width = 128 droppath=0.1 train_size = (1, 6, 30, 90) - net = NAFNetSRLocal(up_scale=2,train_size=train_size, fast_imp=True, img_channel=img_channel, width=width, num_blks=num_blks, dual=dual, - fusion_from=fusion_from, - fusion_to=fusion_to, drop_path_rate=droppath) - # net = NAFNetSR(img_channel=img_channel, width=width, num_blks=num_blks, dual=dual, - # fusion_from=fusion_from, - # fusion_to=fusion_to, drop_path_rate=droppath) + net = NAFSSR(up_scale=2,train_size=train_size, fast_imp=True, width=width, num_blks=num_blks, drop_path_rate=droppath) - c = 6 if dual else 3 - - a = torch.randn((2, c, 24, 23)) - - b = net(a) - - print(b.shape) - - # inp_shape = (6, 128, 128) - - inp_shape = (c, 64, 64) - - # inp_shape = (6, 256, 96) + inp_shape = (6, 64, 64) from ptflops import get_model_complexity_info FLOPS = 0 @@ -195,7 +158,7 @@ if __name__ == '__main__': print(params) macs = float(macs[:-4]) + FLOPS / 10 ** 9 - print('mac', macs, params, 'fusion from .. to ', fusion_from, fusion_to) + print('mac', macs, params) # from basicsr.models.archs.arch_util import measure_inference_speed # net = net.cuda() diff --git a/basicsr/models/image_restoration_model.py b/basicsr/models/image_restoration_model.py index 0c5c0ee..1eec564 100644 --- a/basicsr/models/image_restoration_model.py +++ b/basicsr/models/image_restoration_model.py @@ -300,7 +300,7 @@ class ImageRestorationModel(BaseModel): R_img = sr_img[:, :, 3:] # visual_dir = osp.join('visual_results', dataset_name, self.opt['name']) - visual_dir = osp.join('visual_results', self.opt['name'], dataset_name) + visual_dir = osp.join(self.opt['path']['visualization'], dataset_name) imwrite(L_img, osp.join(visual_dir, f'{img_name}_L.png')) imwrite(R_img, osp.join(visual_dir, f'{img_name}_R.png')) diff --git a/basicsr/version.py b/basicsr/version.py index f67ca82..2c4fdd3 100644 --- a/basicsr/version.py +++ b/basicsr/version.py @@ -1,5 +1,5 @@ # GENERATED VERSION FILE -# TIME: Fri Apr 1 17:46:10 2022 -__version__ = '1.2.0+e41bf19' +# TIME: Mon Apr 18 21:35:20 2022 +__version__ = '1.2.0+386ca20' short_version = '1.2.0' version_info = (1, 2, 0) diff --git a/docs/StereoSR.md b/docs/StereoSR.md new file mode 100644 index 0000000..352afb0 --- /dev/null +++ b/docs/StereoSR.md @@ -0,0 +1,139 @@ +## NAFSSR: Stereo Image Super-Resolution Using NAFNet + +The official pytorch implementation of the paper **[NAFSSR: Stereo Image Super-Resolution Using NAFNet]()** + +#### Xiaojie Chu\*, Liangyu Chen\*, Wenqing Yu + +>This paper proposes a simple baseline named NAFSSR for stereo image super-resolution. We use a stack of NAFNet's Block (NAFBlock) for intra-view feature extraction and combine it with Stereo Cross Attention Modules (SCAM) for cross-view feature interaction. + + + +>NAFSSR outperforms the state-of-the-art methods on the KITTI 2012, KITTI 2015, Middlebury, and Flickr1024 datasets. With NAFSSR, we won **1st place** in the [NTIRE 2022 Stereo Image Super-resolution Challenge](https://codalab.lisn.upsaclay.fr/competitions/1598). + +

+ +

+ + + + +# Reproduce the Stereo SR Results + +## 1. Data Preparation +Follow previous works, our models are trained with Flickr1024 and Middlebury datasets, which is exactly the same as iPASSR. Please visit their homepage and follow their instructions to download and prepare the datasets. + +#### Download and prepare the train set and place it in ```./datasets/StereoSR``` + +#### Download and prepare the evaluation data and place it in ```./datasets/StereoSR/test``` + +The structure of `datasets` directory should be like +``` + datasets + ├── StereoSR + │ ├── patches_x2 + │ │ ├── 000001 + │ │ ├── 000002 + │ │ ├── ... + │ │ ├── 298142 + │ │ └── 298143 + │ ├── patches_x4 + │ │ ├── 000001 + │ │ ├── 000002 + │ │ ├── ... + │ │ ├── 049019 + │ │ └── 049020 + | ├── test + │ | ├── Flickr1024 + │ │ │ ├── hr + │ │ │ ├── lr_x2 + │ │ │ └── lr_x4 + │ | ├── KITTI2012 + │ │ │ ├── hr + │ │ │ ├── lr_x2 + │ │ │ └── lr_x4 + │ | ├── KITTI2015 + │ │ │ ├── hr + │ │ │ ├── lr_x2 + │ │ │ └── lr_x4 + │ │ └── Middlebury + │ │ ├── hr + │ │ ├── lr_x2 + │ │ └── lr_x4 +``` + +## 2. Evaluation + + +#### Download the pretrain model in ```./experiments/pretrained_models/``` + +| name | scale |#Params|PSNR|SSIM| pretrained models | configs | +|:----:|:----:|:----:|:----:|:----:|:----:|-----:| +|NAFSSR-T|x4|0.46M|23.69|0.7384|[gdrive](https://drive.google.com/file/d/1owfYG1KTXFMl4wHpUZefWAcVlBpLohe5/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1yC5XzJcL5peC1YuW3MkFMA?pwd=5j1u)|[train](../options/test/NAFSSR/NAFSSR-T_4x.yml) \| [test](../options/test/NAFSSR/NAFSSR-T_4x.yml)| +|NAFSSR-S|x4|1.56M|23.88|0.7468|[gdrive](https://drive.google.com/file/d/1RpfS2lemsgetIQwBwkZpZwLBJfOTDCU5/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1XvwM5KVhNsKAxWbxU85SFA?pwd=n5au)|[train](../options/test/NAFSSR/NAFSSR-S_4x.yml) \| [test](../options/test/NAFSSR/NAFSSR-S_4x.yml)| +|NAFSSR-B|x4|6.80M|24.07|0.7551|[gdrive](https://drive.google.com/file/d/1Su0OTp66_NsXUbqTAIi1msvsp0G5WVxp/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/18tVlH-QIVtvDC1LM2oPatw?pwd=3up5)|[train](../options/test/NAFSSR/NAFSSR-B_4x.yml) \| [test](../options/test/NAFSSR/NAFSSR-B_4x.yml)| +|NAFSSR-L|x4|23.83M|24.17|0.7589|[gdrive](https://drive.google.com/file/d/1TIdQhPtBrZb2wrBdAp9l8NHINLeExOwb/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1P8ioEuI1gwydA2Avr3nUvw?pwd=qs7a)|[train](../options/test/NAFSSR/NAFSSR-L_4x.yml) \| [test](../options/test/NAFSSR/NAFSSR-L_4x.yml)| +|NAFSSR-T|x2|0.46M|28.94|0.9128|[gdrive](https://drive.google.com/file/d/1sBivtt5KaFMjMhBwyajYy1uemIEFQBvW/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1RDW923v0e0G_eYvTF8I7Dg?pwd=utgs)|[train](../options/test/NAFSSR/NAFSSR-T_2x.yml) \| [test](../options/test/NAFSSR/NAFSSR-T_2x.yml)| +|NAFSSR-S|x2|1.56M|29.19|0.9160|[gdrive](https://drive.google.com/file/d/1caVrp3fFSpwiU8RPGXe-zDD1tU110yxA/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1qJmQzv-YV1V9raim57pTIQ?pwd=t2or)|[train](../options/test/NAFSSR/NAFSSR-S_2x.yml) \| [test](../options/test/NAFSSR/NAFSSR-S_2x.yml)| +|NAFSSR-B|x2|6.80M|29.54|0.7551|[gdrive](https://drive.google.com/file/d/1gOfDTfyCaff_xNm86u8sN_3MtAZvnQEk/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1IkadqW1uWx9xM5K2ETJ9-A?pwd=pv1f)|[train](../options/test/NAFSSR/NAFSSR-B_2x.yml) \| [test](../options/test/NAFSSR/NAFSSR-B_2x.yml)| +|NAFSSR-L|x2|23.79M|29.68|0.9221|[gdrive](https://drive.google.com/file/d/1SZ6bQVYTVS_AXedBEr-_mBCC-qGYHLmf/view?usp=sharing) \| [baidu](https://pan.baidu.com/s/1GS6YQSSECH8hAKhvzw6GyQ?pwd=2v3v)|[train](../options/test/NAFSSR/NAFSSR-L_2x.yml) \| [test](../options/test/NAFSSR/NAFSSR-L_2x.yml)| + +*PSNR/SSIM are evaluate on Flickr1024 test set.* + + +### Testing on KITTI2012, KITTI2015, Middlebury, Flickr1024 datasets + + * NAFSSR-T for 4x SR: + +``` +python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/NAFSSR/NAFSSR-T_x4.yml --launcher pytorch +``` + + * NAFSSR-S for 4x SR: + +``` +python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/NAFSSR/NAFSSR-S_x4.yml --launcher pytorch +``` + + * NAFSSR-B for 4x SR: + +``` +python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/NAFSSR/NAFSSR-B_x4.yml --launcher pytorch +``` + + * NAFSSR-L for 4x SR: + +``` +python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/NAFSSR/NAFSSR-L_x4.yml --launcher pytorch +``` + + * NAFSSR-L for 2x SR: + +``` +python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/NAFSSR/NAFSSR-L_x2.yml --launcher pytorch +``` + +* Test by a single gpu by default. Set ```--nproc_per_node``` to # of gpus for distributed validation. + + + + +## 3. Training + +* NAFNet-B for 4x SR: + + ``` + python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/NAFSSR/NAFSSR-B_x4.yml --launcher pytorch + ``` + +* NAFNet-S for 4x SR: + + ``` + python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/NAFSSR/NAFSSR-S_x4.yml --launcher pytorch + ``` + +* 8 gpus by default. Set ```--nproc_per_node``` to # of gpus for distributed validation. + + + + + diff --git a/figures/NAFSSR_arch.jpg b/figures/NAFSSR_arch.jpg new file mode 100644 index 0000000..0ad34ad Binary files /dev/null and b/figures/NAFSSR_arch.jpg differ diff --git a/figures/NAFSSR_params.jpg b/figures/NAFSSR_params.jpg new file mode 100644 index 0000000..452e2a7 Binary files /dev/null and b/figures/NAFSSR_params.jpg differ diff --git a/options/test/NAFSSR/NAFSSR-B_2x.yml b/options/test/NAFSSR/NAFSSR-B_2x.yml new file mode 100644 index 0000000..30bfb70 --- /dev/null +++ b/options/test/NAFSSR/NAFSSR-B_2x.yml @@ -0,0 +1,85 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFSSR-B_2x +model_type: ImageRestorationModel +scale: 2 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 10 + +# dataset and data loader settings +datasets: + test0: + name: KITTI2012 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2012/hr + dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x2 + io_backend: + type: disk + + test1: + name: KITTI2015 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2015/hr + dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x2 + io_backend: + type: disk + + test2: + name: Middlebury + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Middlebury/hr + dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x2 + io_backend: + type: disk + + test3: + name: Flickr1024 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 2 + width: 96 + num_blks: 64 + + +# path +path: + pretrain_network_g: experiments/pretrained_models/NAFSSR-B_2x.pth + strict_load_g: true + resume_state: ~ + +# validation settings +val: + save_img: true + grids: false + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + # psnr_left: # metric name, can be arbitrary + # type: calculate_psnr_left + # crop_border: 0 + # test_y_channel: false + # ssim_left: + # type: calculate_skimage_ssim_left + + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/test/NAFSSR/NAFSSR-B_4x.yml b/options/test/NAFSSR/NAFSSR-B_4x.yml new file mode 100644 index 0000000..2f18e43 --- /dev/null +++ b/options/test/NAFSSR/NAFSSR-B_4x.yml @@ -0,0 +1,85 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFSSR-B_4x +model_type: ImageRestorationModel +scale: 4 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 10 + +# dataset and data loader settings +datasets: + test0: + name: KITTI2012 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2012/hr + dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x4 + io_backend: + type: disk + + test1: + name: KITTI2015 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2015/hr + dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x4 + io_backend: + type: disk + + test2: + name: Middlebury + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Middlebury/hr + dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x4 + io_backend: + type: disk + + test3: + name: Flickr1024 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 4 + width: 96 + num_blks: 64 + + +# path +path: + pretrain_network_g: experiments/pretrained_models/NAFSSR-B_4x.pth + strict_load_g: true + resume_state: ~ + +# validation settings +val: + save_img: true + grids: false + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + # psnr_left: # metric name, can be arbitrary + # type: calculate_psnr_left + # crop_border: 0 + # test_y_channel: false + # ssim_left: + # type: calculate_skimage_ssim_left + + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/test/NAFSSR/NAFSSR-L_2x.yml b/options/test/NAFSSR/NAFSSR-L_2x.yml new file mode 100644 index 0000000..6f94ca1 --- /dev/null +++ b/options/test/NAFSSR/NAFSSR-L_2x.yml @@ -0,0 +1,85 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFSSR-L_2x +model_type: ImageRestorationModel +scale: 2 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 10 + +# dataset and data loader settings +datasets: + test0: + name: KITTI2012 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2012/hr + dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x2 + io_backend: + type: disk + + test1: + name: KITTI2015 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2015/hr + dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x2 + io_backend: + type: disk + + test2: + name: Middlebury + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Middlebury/hr + dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x2 + io_backend: + type: disk + + test3: + name: Flickr1024 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 2 + width: 128 + num_blks: 128 + + +# path +path: + pretrain_network_g: experiments/pretrained_models/NAFSSR-L_2x.pth + strict_load_g: true + resume_state: ~ + +# validation settings +val: + save_img: true + grids: false + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + # psnr_left: # metric name, can be arbitrary + # type: calculate_psnr_left + # crop_border: 0 + # test_y_channel: false + # ssim_left: + # type: calculate_skimage_ssim_left + + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/test/NAFSSR/NAFSSR-L_4x.yml b/options/test/NAFSSR/NAFSSR-L_4x.yml new file mode 100644 index 0000000..79d6b27 --- /dev/null +++ b/options/test/NAFSSR/NAFSSR-L_4x.yml @@ -0,0 +1,85 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFSSR-L_4x +model_type: ImageRestorationModel +scale: 4 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 10 + +# dataset and data loader settings +datasets: + test0: + name: KITTI2012 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2012/hr + dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x4 + io_backend: + type: disk + + test1: + name: KITTI2015 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2015/hr + dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x4 + io_backend: + type: disk + + test2: + name: Middlebury + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Middlebury/hr + dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x4 + io_backend: + type: disk + + test3: + name: Flickr1024 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 4 + width: 128 + num_blks: 128 + + +# path +path: + pretrain_network_g: experiments/pretrained_models/NAFSSR-L_4x.pth + strict_load_g: true + resume_state: ~ + +# validation settings +val: + save_img: true + grids: false + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + # psnr_left: # metric name, can be arbitrary + # type: calculate_psnr_left + # crop_border: 0 + # test_y_channel: false + # ssim_left: + # type: calculate_skimage_ssim_left + + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/test/NAFSSR/NAFSSR-S_2x.yml b/options/test/NAFSSR/NAFSSR-S_2x.yml new file mode 100644 index 0000000..30bd6ed --- /dev/null +++ b/options/test/NAFSSR/NAFSSR-S_2x.yml @@ -0,0 +1,85 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFSSR-S_2x +model_type: ImageRestorationModel +scale: 2 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 10 + +# dataset and data loader settings +datasets: + test0: + name: KITTI2012 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2012/hr + dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x2 + io_backend: + type: disk + + test1: + name: KITTI2015 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2015/hr + dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x2 + io_backend: + type: disk + + test2: + name: Middlebury + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Middlebury/hr + dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x2 + io_backend: + type: disk + + test3: + name: Flickr1024 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 2 + width: 64 + num_blks: 32 + + +# path +path: + pretrain_network_g: experiments/pretrained_models/NAFSSR-S_2x.pth + strict_load_g: true + resume_state: ~ + +# validation settings +val: + save_img: true + grids: false + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + # psnr_left: # metric name, can be arbitrary + # type: calculate_psnr_left + # crop_border: 0 + # test_y_channel: false + # ssim_left: + # type: calculate_skimage_ssim_left + + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/test/NAFSSR/NAFSSR-S_4x.yml b/options/test/NAFSSR/NAFSSR-S_4x.yml new file mode 100644 index 0000000..2b70a41 --- /dev/null +++ b/options/test/NAFSSR/NAFSSR-S_4x.yml @@ -0,0 +1,85 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFSSR-S_4x +model_type: ImageRestorationModel +scale: 4 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 10 + +# dataset and data loader settings +datasets: + test0: + name: KITTI2012 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2012/hr + dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x4 + io_backend: + type: disk + + test1: + name: KITTI2015 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2015/hr + dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x4 + io_backend: + type: disk + + test2: + name: Middlebury + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Middlebury/hr + dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x4 + io_backend: + type: disk + + test3: + name: Flickr1024 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 4 + width: 64 + num_blks: 32 + + +# path +path: + pretrain_network_g: experiments/pretrained_models/NAFSSR-S_4x.pth + strict_load_g: true + resume_state: ~ + +# validation settings +val: + save_img: true + grids: false + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + # psnr_left: # metric name, can be arbitrary + # type: calculate_psnr_left + # crop_border: 0 + # test_y_channel: false + # ssim_left: + # type: calculate_skimage_ssim_left + + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/test/NAFSSR/NAFSSR-T_2x.yml b/options/test/NAFSSR/NAFSSR-T_2x.yml new file mode 100644 index 0000000..77b548a --- /dev/null +++ b/options/test/NAFSSR/NAFSSR-T_2x.yml @@ -0,0 +1,85 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFSSR-T_2x +model_type: ImageRestorationModel +scale: 2 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 10 + +# dataset and data loader settings +datasets: + test0: + name: KITTI2012 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2012/hr + dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x2 + io_backend: + type: disk + + test1: + name: KITTI2015 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2015/hr + dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x2 + io_backend: + type: disk + + test2: + name: Middlebury + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Middlebury/hr + dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x2 + io_backend: + type: disk + + test3: + name: Flickr1024 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 2 + width: 48 + num_blks: 16 + + +# path +path: + pretrain_network_g: experiments/pretrained_models/NAFSSR-T_2x.pth + strict_load_g: true + resume_state: ~ + +# validation settings +val: + save_img: true + grids: false + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + # psnr_left: # metric name, can be arbitrary + # type: calculate_psnr_left + # crop_border: 0 + # test_y_channel: false + # ssim_left: + # type: calculate_skimage_ssim_left + + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/test/NAFSSR/NAFSSR-T_4x.yml b/options/test/NAFSSR/NAFSSR-T_4x.yml new file mode 100644 index 0000000..bba812a --- /dev/null +++ b/options/test/NAFSSR/NAFSSR-T_4x.yml @@ -0,0 +1,85 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFSSR-T_4x +model_type: ImageRestorationModel +scale: 4 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 10 + +# dataset and data loader settings +datasets: + test0: + name: KITTI2012 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2012/hr + dataroot_lq: datasets/StereoSR/test/KITTI2012/lr_x4 + io_backend: + type: disk + + test1: + name: KITTI2015 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/KITTI2015/hr + dataroot_lq: datasets/StereoSR/test/KITTI2015/lr_x4 + io_backend: + type: disk + + test2: + name: Middlebury + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Middlebury/hr + dataroot_lq: datasets/StereoSR/test/Middlebury/lr_x4 + io_backend: + type: disk + + test3: + name: Flickr1024 + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 4 + width: 48 + num_blks: 16 + + +# path +path: + pretrain_network_g: experiments/pretrained_models/NAFSSR-T_4x.pth + strict_load_g: true + resume_state: ~ + +# validation settings +val: + save_img: true + grids: false + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + # psnr_left: # metric name, can be arbitrary + # type: calculate_psnr_left + # crop_border: 0 + # test_y_channel: false + # ssim_left: + # type: calculate_skimage_ssim_left + + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/NAFSSR/NAFSSR-B_x2.yml b/options/train/NAFSSR/NAFSSR-B_x2.yml new file mode 100644 index 0000000..b4f0c7f --- /dev/null +++ b/options/train/NAFSSR/NAFSSR-B_x2.yml @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFNetSR-B_x2 +model_type: ImageRestorationModel +scale: 2 +num_gpu: 8 +manual_seed: 10 + +datasets: + train: + name: Flickr1024-sr-train + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/patches_x2/ + dataroot_lq: datasets/StereoSR/patches_x2/ + io_backend: + type: disk + + gt_size_h: 60 + gt_size_w: 180 + use_hflip: true + use_vflip: true + use_rot: false + flip_RGB: true + + # data loader + use_shuffle: true + num_worker_per_gpu: 4 + batch_size_per_gpu: 4 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: Flickr1024-sr-test + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 2 + width: 96 + num_blks: 64 + drop_path_rate: 0.2 + train_size: [1, 6, 30, 90] + drop_out_rate: 0. + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + optim_g: + type: AdamW + lr: !!float 3e-3 + weight_decay: !!float 0 + betas: [0.9, 0.9] + + scheduler: + type: TrueCosineAnnealingLR + T_max: 100000 + eta_min: !!float 1e-7 + + total_iter: 100000 + warmup_iter: -1 # no warm up + mixup: false + + # losses + pixel_opt: + type: MSELoss + loss_weight: 1. + reduction: mean + +# validation settings +val: + val_freq: !!float 2e4 + save_img: false + trans_num: 1 + + max_minibatch: 1 + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + +# logging settings +logger: + print_freq: 200 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/NAFSSR/NAFSSR-B_x4.yml b/options/train/NAFSSR/NAFSSR-B_x4.yml new file mode 100644 index 0000000..e975323 --- /dev/null +++ b/options/train/NAFSSR/NAFSSR-B_x4.yml @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFNetSR-B_x4 +model_type: ImageRestorationModel +scale: 4 +num_gpu: 8 +manual_seed: 10 + +datasets: + train: + name: Flickr1024-sr-train + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/patches_x4/ + dataroot_lq: datasets/StereoSR/patches_x4/ + io_backend: + type: disk + + gt_size_h: 120 + gt_size_w: 360 + use_hflip: true + use_vflip: true + use_rot: false + flip_RGB: true + + # data loader + use_shuffle: true + num_worker_per_gpu: 4 + batch_size_per_gpu: 4 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: Flickr1024-sr-test + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 4 + width: 96 + num_blks: 64 + drop_path_rate: 0.2 + train_size: [1, 6, 30, 90] + drop_out_rate: 0. + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + optim_g: + type: AdamW + lr: !!float 3e-3 + weight_decay: !!float 0 + betas: [0.9, 0.9] + + scheduler: + type: TrueCosineAnnealingLR + T_max: 100000 + eta_min: !!float 1e-7 + + total_iter: 100000 + warmup_iter: -1 # no warm up + mixup: false + + # losses + pixel_opt: + type: MSELoss + loss_weight: 1. + reduction: mean + +# validation settings +val: + val_freq: !!float 2e4 + save_img: false + trans_num: 1 + + max_minibatch: 1 + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + +# logging settings +logger: + print_freq: 200 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/NAFSSR/NAFSSR-L_x2.yml b/options/train/NAFSSR/NAFSSR-L_x2.yml new file mode 100644 index 0000000..2f417f9 --- /dev/null +++ b/options/train/NAFSSR/NAFSSR-L_x2.yml @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFNetSR-L_x2 +model_type: ImageRestorationModel +scale: 2 +num_gpu: 8 +manual_seed: 10 + +datasets: + train: + name: Flickr1024-sr-train + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/patches_x2/ + dataroot_lq: datasets/StereoSR/patches_x2/ + io_backend: + type: disk + + gt_size_h: 60 + gt_size_w: 180 + use_hflip: true + use_vflip: true + use_rot: false + flip_RGB: true + + # data loader + use_shuffle: true + num_worker_per_gpu: 4 + batch_size_per_gpu: 4 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: Flickr1024-sr-test + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 2 + width: 128 + num_blks: 128 + drop_path_rate: 0.3 + train_size: [1, 6, 30, 90] + drop_out_rate: 0. + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + optim_g: + type: AdamW + lr: !!float 3e-3 + weight_decay: !!float 0 + betas: [0.9, 0.9] + + scheduler: + type: TrueCosineAnnealingLR + T_max: 100000 + eta_min: !!float 1e-7 + + total_iter: 100000 + warmup_iter: -1 # no warm up + mixup: false + + # losses + pixel_opt: + type: MSELoss + loss_weight: 1. + reduction: mean + +# validation settings +val: + val_freq: !!float 2e4 + save_img: false + trans_num: 1 + + max_minibatch: 1 + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + +# logging settings +logger: + print_freq: 200 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/NAFSSR/NAFSSR-L_x4.yml b/options/train/NAFSSR/NAFSSR-L_x4.yml new file mode 100644 index 0000000..f5b5a46 --- /dev/null +++ b/options/train/NAFSSR/NAFSSR-L_x4.yml @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFNetSR-L_x4 +model_type: ImageRestorationModel +scale: 4 +num_gpu: 8 +manual_seed: 10 + +datasets: + train: + name: Flickr1024-sr-train + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/patches_x4/ + dataroot_lq: datasets/StereoSR/patches_x4/ + io_backend: + type: disk + + gt_size_h: 120 + gt_size_w: 360 + use_hflip: true + use_vflip: true + use_rot: false + flip_RGB: true + + # data loader + use_shuffle: true + num_worker_per_gpu: 4 + batch_size_per_gpu: 4 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: Flickr1024-sr-test + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 4 + width: 128 + num_blks: 128 + drop_path_rate: 0.3 + train_size: [1, 6, 30, 90] + drop_out_rate: 0. + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + optim_g: + type: AdamW + lr: !!float 3e-3 + weight_decay: !!float 0 + betas: [0.9, 0.9] + + scheduler: + type: TrueCosineAnnealingLR + T_max: 100000 + eta_min: !!float 1e-7 + + total_iter: 100000 + warmup_iter: -1 # no warm up + mixup: false + + # losses + pixel_opt: + type: MSELoss + loss_weight: 1. + reduction: mean + +# validation settings +val: + val_freq: !!float 2e4 + save_img: false + trans_num: 1 + + max_minibatch: 1 + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + +# logging settings +logger: + print_freq: 200 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/NAFSSR/NAFSSR-S_x2.yml b/options/train/NAFSSR/NAFSSR-S_x2.yml new file mode 100644 index 0000000..37a8365 --- /dev/null +++ b/options/train/NAFSSR/NAFSSR-S_x2.yml @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFNetSR-S_x2 +model_type: ImageRestorationModel +scale: 2 +num_gpu: 8 +manual_seed: 10 + +datasets: + train: + name: Flickr1024-sr-train + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/patches_x2/ + dataroot_lq: datasets/StereoSR/patches_x2/ + io_backend: + type: disk + + gt_size_h: 60 + gt_size_w: 180 + use_hflip: true + use_vflip: true + use_rot: false + flip_RGB: true + + # data loader + use_shuffle: true + num_worker_per_gpu: 4 + batch_size_per_gpu: 4 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: Flickr1024-sr-test + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 2 + width: 64 + num_blks: 32 + drop_path_rate: 0.1 + train_size: [1, 6, 30, 90] + drop_out_rate: 0. + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + optim_g: + type: AdamW + lr: !!float 3e-3 + weight_decay: !!float 0 + betas: [0.9, 0.9] + + scheduler: + type: TrueCosineAnnealingLR + T_max: 100000 + eta_min: !!float 1e-7 + + total_iter: 100000 + warmup_iter: -1 # no warm up + mixup: false + + # losses + pixel_opt: + type: MSELoss + loss_weight: 1. + reduction: mean + +# validation settings +val: + val_freq: !!float 2e4 + save_img: false + trans_num: 1 + + max_minibatch: 1 + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + +# logging settings +logger: + print_freq: 200 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/NAFSSR/NAFSSR-S_x4.yml b/options/train/NAFSSR/NAFSSR-S_x4.yml new file mode 100644 index 0000000..ff1a6cc --- /dev/null +++ b/options/train/NAFSSR/NAFSSR-S_x4.yml @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFNetSR-S_x4 +model_type: ImageRestorationModel +scale: 4 +num_gpu: 8 +manual_seed: 10 + +datasets: + train: + name: Flickr1024-sr-train + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/patches_x4/ + dataroot_lq: datasets/StereoSR/patches_x4/ + io_backend: + type: disk + + gt_size_h: 120 + gt_size_w: 360 + use_hflip: true + use_vflip: true + use_rot: false + flip_RGB: true + + # data loader + use_shuffle: true + num_worker_per_gpu: 4 + batch_size_per_gpu: 4 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: Flickr1024-sr-test + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 4 + width: 64 + num_blks: 32 + drop_path_rate: 0.1 + train_size: [1, 6, 30, 90] + drop_out_rate: 0. + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + optim_g: + type: AdamW + lr: !!float 3e-3 + weight_decay: !!float 0 + betas: [0.9, 0.9] + + scheduler: + type: TrueCosineAnnealingLR + T_max: 100000 + eta_min: !!float 1e-7 + + total_iter: 100000 + warmup_iter: -1 # no warm up + mixup: false + + # losses + pixel_opt: + type: MSELoss + loss_weight: 1. + reduction: mean + +# validation settings +val: + val_freq: !!float 2e4 + save_img: false + trans_num: 1 + + max_minibatch: 1 + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + +# logging settings +logger: + print_freq: 200 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/NAFSSR/NAFSSR-T_x2.yml b/options/train/NAFSSR/NAFSSR-T_x2.yml new file mode 100644 index 0000000..7d97d08 --- /dev/null +++ b/options/train/NAFSSR/NAFSSR-T_x2.yml @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFNetSR-T_x2 +model_type: ImageRestorationModel +scale: 2 +num_gpu: 8 +manual_seed: 10 + +datasets: + train: + name: Flickr1024-sr-train + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/patches_x2/ + dataroot_lq: datasets/StereoSR/patches_x2/ + io_backend: + type: disk + + gt_size_h: 60 + gt_size_w: 180 + use_hflip: true + use_vflip: true + use_rot: false + flip_RGB: true + + # data loader + use_shuffle: true + num_worker_per_gpu: 4 + batch_size_per_gpu: 4 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: Flickr1024-sr-test + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x2 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 2 + width: 48 + num_blks: 16 + drop_path_rate: 0. + train_size: [1, 6, 30, 90] + drop_out_rate: 0. + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + optim_g: + type: AdamW + lr: !!float 3e-3 + weight_decay: !!float 0 + betas: [0.9, 0.9] + + scheduler: + type: TrueCosineAnnealingLR + T_max: 400000 + eta_min: !!float 1e-7 + + total_iter: 400000 + warmup_iter: -1 # no warm up + mixup: false + + # losses + pixel_opt: + type: MSELoss + loss_weight: 1. + reduction: mean + +# validation settings +val: + val_freq: !!float 2e4 + save_img: false + trans_num: 1 + + max_minibatch: 1 + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + +# logging settings +logger: + print_freq: 200 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/NAFSSR/NAFSSR-T_x4.yml b/options/train/NAFSSR/NAFSSR-T_x4.yml new file mode 100644 index 0000000..76f68dc --- /dev/null +++ b/options/train/NAFSSR/NAFSSR-T_x4.yml @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFNetSR-T_x4 +model_type: ImageRestorationModel +scale: 4 +num_gpu: 8 +manual_seed: 10 + +datasets: + train: + name: Flickr1024-sr-train + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/patches_x4/ + dataroot_lq: datasets/StereoSR/patches_x4/ + io_backend: + type: disk + + gt_size_h: 120 + gt_size_w: 360 + use_hflip: true + use_vflip: true + use_rot: false + flip_RGB: true + + # data loader + use_shuffle: true + num_worker_per_gpu: 4 + batch_size_per_gpu: 4 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: Flickr1024-sr-test + type: PairedStereoImageDataset + dataroot_gt: datasets/StereoSR/test/Flickr1024/hr + dataroot_lq: datasets/StereoSR/test/Flickr1024/lr_x4 + io_backend: + type: disk + +# network structures +network_g: + type: NAFSSR + up_scale: 4 + width: 48 + num_blks: 16 + drop_path_rate: 0. + train_size: [1, 6, 30, 90] + drop_out_rate: 0. + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + optim_g: + type: AdamW + lr: !!float 3e-3 + weight_decay: !!float 0 + betas: [0.9, 0.9] + + scheduler: + type: TrueCosineAnnealingLR + T_max: 400000 + eta_min: !!float 1e-7 + + total_iter: 400000 + warmup_iter: -1 # no warm up + mixup: false + + # losses + pixel_opt: + type: MSELoss + loss_weight: 1. + reduction: mean + +# validation settings +val: + val_freq: !!float 2e4 + save_img: false + trans_num: 1 + + max_minibatch: 1 + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_skimage_ssim + +# logging settings +logger: + print_freq: 200 + save_checkpoint_freq: !!float 1e4 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/readme.md b/readme.md index 3b2d97c..f9d1d61 100644 --- a/readme.md +++ b/readme.md @@ -57,6 +57,8 @@ python setup.py develop --no_cuda_ext |NAFNet-SIDD-width32|SIDD|39.9672|0.9599|[gdrive](https://drive.google.com/file/d/1lsByk21Xw-6aW7epCwOQxvm6HYCQZPHZ/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1Xses38SWl-7wuyuhaGNhaw) (提取码: um97)|[train](./options/train/SIDD/NAFNet-width32.yml) \| [test](./options/test/SIDD/NAFNet-width32.yml)| |NAFNet-SIDD-width64|SIDD|40.3045|0.9614|[gdrive](https://drive.google.com/file/d/14Fht1QQJ2gMlk4N1ERCRuElg8JfjrWWR/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/198kYyVSrY_xZF0jGv9U0sQ) (提取码: dton)|[train](./options/train/SIDD/NAFNet-width64.yml) \| [test](./options/test/SIDD/NAFNet-width64.yml)| |NAFNet-REDS-width64|REDS|29.0903|0.8671|[gdrive](https://drive.google.com/file/d/14D4V4raNYIOhETfcuuLI3bGLB-OYIv6X/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1vg89ccbpIxg3mK9IONBfGg) (提取码: 9fas)|[train](./options/train/REDS/NAFNet-width64.yml) \| [test](./options/test/REDS/NAFNet-width64.yml)| +|NAFSSR-L_4x|Flickr1024|24.17|0.7589|[gdrive](https://drive.google.com/file/d/1TIdQhPtBrZb2wrBdAp9l8NHINLeExOwb/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/198kYyVSrY_xZF0jGv9U0sQ) (提取码: dton)|[train](../options/test/NAFSSR/NAFSSR-L_4x.yml) \| [test](../options/test/NAFSSR/NAFSSR-L_4x.yml)| +|NAFSSR-L_2x|Flickr1024|29.68|0.9221|[gdrive](https://drive.google.com/file/d/1SZ6bQVYTVS_AXedBEr-_mBCC-qGYHLmf/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/198kYyVSrY_xZF0jGv9U0sQ) (提取码: dton)|[train](./options/test/NAFSSR/NAFSSR-L_2x.yml) \| [test](./options/test/NAFSSR/NAFSSR-L_2x.yml)| ### Image Restoration Tasks @@ -65,7 +67,7 @@ python setup.py develop --no_cuda_ext | Image Deblurring | GoPro | [link](./docs/GoPro.md) | [gdrive](https://drive.google.com/file/d/1S8u4TqQP6eHI81F9yoVR0be-DLh4cNgb/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1yNYQhznChafsbcfHO44aHQ) (提取码: 96ii) | | Image Denoising | SIDD | [link](./docs/SIDD.md) | [gdrive](https://drive.google.com/file/d/1rbBYD64bfvbHOrN3HByNg0vz6gHQq7Np/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1wIubY6SeXRfZHpp6bAojqQ) (提取码: hu4t) | | Image Deblurring with JPEG artifacts | REDS | [link](./docs/REDS.md) | [gdrive](https://drive.google.com/file/d/1FwHWYPXdPtUkPqckpz-WBitpVyPuXFRi/view?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/17T30w5xAtBQQ2P3wawLiVA) (提取码: put5) | - +| Stereo Image Super-Resolution | Flickr1024+Middlebury | [link](./docs/StereoSR.md) | [gdrive](https://drive.google.com/drive/folders/1lTKe2TU7F-KcU-oaF8jqgoUwIMb6RW0w?usp=sharing) \| [百度网盘](https://pan.baidu.com/s/1kov6ivrSFy1FuToCATbyrA?pwd=q263 ) | ### Citations