Codeflow for training with BSDS300 dataset

Signed-off-by: Ranjan Debnath <m22aie245@gmail.com>
pull/141/head
Ranjan Debnath 2024-04-22 23:35:27 +05:30
parent cf36476ea3
commit 273bd48f21
17 changed files with 774 additions and 14 deletions

View File

@ -39,6 +39,10 @@ def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=Fal
img = to_y_channel(img)
img2 = to_y_channel(img2)
if isinstance(img, torch.Tensor):
img = img.numpy()
if isinstance(img2, torch.Tensor):
img2 = img2.numpy()
img = img.astype(np.float64)
img2 = img2.astype(np.float64)
@ -119,6 +123,10 @@ def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=Fal
img = to_y_channel(img)
img2 = to_y_channel(img2)
if isinstance(img, torch.Tensor):
img = img.numpy()
if isinstance(img2, torch.Tensor):
img2 = img2.numpy()
img = img.astype(np.float64)
img2 = img2.astype(np.float64)

View File

@ -14,7 +14,8 @@ from basicsr.train import parse_options
from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
make_exp_dirs)
from basicsr.utils.options import dict2str
import gc
torch.cuda.empty_cache()
def main():
# parse options, set distributed setting, set ramdom seed
@ -68,3 +69,4 @@ def main():
if __name__ == '__main__':
main()
gc.collect()

View File

@ -23,6 +23,9 @@ from basicsr.utils import (MessageLogger, check_resume, get_env_info,
set_random_seed)
from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import dict2str, parse
import gc
torch.cuda.empty_cache()
def parse_options(is_train=True):
@ -303,3 +306,4 @@ if __name__ == '__main__':
import os
os.environ['GRPC_POLL_STRATEGY']='epoll1'
main()
gc.collect()

View File

@ -1,10 +1,11 @@
import sys
sys.path.append("/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet")
from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
from .diffjpeg import DiffJPEG
from .file_client import FileClient
from .img_process_util import USMSharp, usm_sharp
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt, scandir_SIDD
from .options import yaml_load
__all__ = [
@ -35,14 +36,12 @@ __all__ = [
'mkdir_and_rename',
'make_exp_dirs',
'scandir',
'scandir_SIDD',
'check_resume',
'sizeof_fmt',
# diffjpeg
'DiffJPEG',
# img_process_util
'USMSharp',
'usm_sharp',
# options
'yaml_load',
'padding'
]

View File

@ -0,0 +1,208 @@
import numpy as np
import torch
def rgb2ycbcr(img, y_only=False):
    """Transform an RGB image into the YCbCr colour space.

    The coefficients follow ITU-R BT.601 (standard-definition television), so
    the output matches Matlab's `rgb2ycbcr`. See
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    This is not the same as cv2.cvtColor `RGB <-> YCrCb`, which implements the
    JPEG variant described at
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): RGB image, either np.uint8 in [0, 255] or np.float32
            in [0, 1].
        y_only (bool): Return just the luma (Y) channel. Default: False.

    Returns:
        ndarray: YCbCr image (or Y channel) with the dtype and value range of
            the input.
    """
    src_dtype = img.dtype
    unit_img = _convert_input_type_range(img)

    if y_only:
        # Luma only: weighted sum over the channel axis plus the 16 offset.
        luma = np.dot(unit_img, [65.481, 128.553, 24.966]) + 16.0
        return _convert_output_type_range(luma, src_dtype)

    coeffs = [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]
    offsets = [16, 128, 128]
    ycbcr = np.matmul(unit_img, coeffs) + offsets
    return _convert_output_type_range(ycbcr, src_dtype)
def bgr2ycbcr(img, y_only=False):
    """Transform a BGR image into the YCbCr colour space.

    BGR-ordered counterpart of rgb2ycbcr: the same ITU-R BT.601 coefficients
    are used with the rows for the blue and red channels swapped. See
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    This is not the same as cv2.cvtColor `BGR <-> YCrCb`, which implements the
    JPEG variant described at
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): BGR image, either np.uint8 in [0, 255] or np.float32
            in [0, 1].
        y_only (bool): Return just the luma (Y) channel. Default: False.

    Returns:
        ndarray: YCbCr image (or Y channel) with the dtype and value range of
            the input.
    """
    src_dtype = img.dtype
    unit_img = _convert_input_type_range(img)

    if y_only:
        # Luma only: weighted sum over the channel axis plus the 16 offset.
        luma = np.dot(unit_img, [24.966, 128.553, 65.481]) + 16.0
        return _convert_output_type_range(luma, src_dtype)

    coeffs = [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]
    offsets = [16, 128, 128]
    ycbcr = np.matmul(unit_img, coeffs) + offsets
    return _convert_output_type_range(ycbcr, src_dtype)
def ycbcr2rgb(img):
    """Transform a YCbCr image back into RGB.

    Inverse ITU-R BT.601 conversion, matching Matlab's ycbcr2rgb. See
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    This is not the same as cv2.cvtColor `YCrCb <-> RGB`, which implements the
    JPEG variant described at
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): YCbCr image, either np.uint8 in [0, 255] or np.float32
            in [0, 1].

    Returns:
        ndarray: RGB image with the dtype and value range of the input.
    """
    src_dtype = img.dtype
    # The inverse matrix below is expressed for inputs on the [0, 255] scale.
    scaled = _convert_input_type_range(img) * 255
    inverse = [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
               [0.00625893, -0.00318811, 0]]
    offsets = [-222.921, 135.576, -276.836]
    rgb = np.matmul(scaled, inverse) * 255.0 + offsets
    return _convert_output_type_range(rgb, src_dtype)
def ycbcr2bgr(img):
    """Transform a YCbCr image back into BGR.

    BGR-ordered counterpart of ycbcr2rgb: the same inverse ITU-R BT.601
    conversion with the output columns for blue and red swapped. See
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    This is not the same as cv2.cvtColor `YCrCb <-> BGR`, which implements the
    JPEG variant described at
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): YCbCr image, either np.uint8 in [0, 255] or np.float32
            in [0, 1].

    Returns:
        ndarray: BGR image with the dtype and value range of the input.
    """
    src_dtype = img.dtype
    # The inverse matrix below is expressed for inputs on the [0, 255] scale.
    scaled = _convert_input_type_range(img) * 255
    inverse = [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
               [0, -0.00318811, 0.00625893]]
    offsets = [-276.836, 135.576, -222.921]
    bgr = np.matmul(scaled, inverse) * 255.0 + offsets
    return _convert_output_type_range(bgr, src_dtype)
def _convert_input_type_range(img):
"""Convert the type and range of the input image.
It converts the input image to np.float32 type and range of [0, 1].
It is mainly used for pre-processing the input image in colorspace
conversion functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
(ndarray): The converted image with type of np.float32 and range of
[0, 1].
"""
img_type = img.dtype
img = img.astype(np.float32)
if img_type == np.float32:
pass
elif img_type == np.uint8:
img /= 255.
else:
raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
return img
def _convert_output_type_range(img, dst_type):
"""Convert the type and range of the image according to dst_type.
It converts the image to desired type and range. If `dst_type` is np.uint8,
images will be converted to np.uint8 type with range [0, 255]. If
`dst_type` is np.float32, it converts the image to np.float32 type with
range [0, 1].
It is mainly used for post-processing images in colorspace conversion
functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The image to be converted with np.float32 type and
range [0, 255].
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
converts the image to np.uint8 type with range [0, 255]. If
dst_type is np.float32, it converts the image to np.float32 type
with range [0, 1].
Returns:
(ndarray): The converted image with desired type and range.
"""
if dst_type not in (np.uint8, np.float32):
raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
if dst_type == np.uint8:
img = img.round()
else:
img /= 255.
return img.astype(dst_type)
def rgb2ycbcr_pt(img, y_only=False):
    """PyTorch version of the RGB to YCbCr conversion.

    Uses the ITU-R BT.601 coefficients for standard-definition television. See
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    Args:
        img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
        y_only (bool): Return just the luma (Y) channel. Default: False.

    Returns:
        (Tensor): Converted images with shape (n, 3/1, h, w), the range [0, 1], float.
    """
    # Move channels last so the colour matrix applies along the final axis.
    channels_last = img.permute(0, 2, 3, 1)
    if y_only:
        weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
        projected = torch.matmul(channels_last, weight).permute(0, 3, 1, 2)
        shifted = projected + 16.0
    else:
        weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
        bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
        projected = torch.matmul(channels_last, weight).permute(0, 3, 1, 2)
        shifted = projected + bias
    return shifted / 255.

View File

@ -21,6 +21,7 @@ def prepare_keys(folder_path, suffix='png'):
list[str]: Key list.
"""
print('Reading image path list ...')
# import ipdb; ipdb.set_trace()
img_path_list = sorted(
list(scandir(folder_path, suffix=suffix, recursive=False)))
keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)]
@ -131,3 +132,17 @@ def create_lmdb_for_SIDD():
img_path_list, keys = prepare_keys(folder_path, 'png')
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
'''
def create_lmdb_for_BSDS300(root_path=''):
    """Build the input and ground-truth LMDB databases for BSDS300 crops.

    For each of `<root_path>/input_crops` and `<root_path>/gt_crops` (in that
    order) the `.jpg` files are listed with prepare_keys and written to a
    sibling `<name>.lmdb` with make_lmdb_from_imgs.

    Args:
        root_path (str): Directory holding the `input_crops` and `gt_crops`
            folders. When it is the empty string nothing is done.
            Default: ''.
    """
    if root_path == '':
        return

    for subset in ('input_crops', 'gt_crops'):
        folder_path = f'{root_path}/{subset}'
        lmdb_path = f'{root_path}/{subset}.lmdb'
        img_path_list, keys = prepare_keys(folder_path, 'jpg')
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)

View File

@ -145,3 +145,45 @@ def sizeof_fmt(size, suffix='B'):
return f'{size:3.1f} {unit}{suffix}'
size /= 1024.0
return f'{size:3.1f} Y{suffix}'
def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
    """Scan a directory for files whose path contains a keyword.

    Hidden files (names starting with '.') are always skipped. The keyword is
    searched in the yielded path (relative or full, depending on `full_path`),
    so a matching directory name also selects the files below it.

    Args:
        dir_path (str): Path of the directory.
        keywords (str | tuple(str), optional): Substring(s) to look for. With
            a tuple, a file is kept when any of the keywords occurs in its
            path. None keeps every file. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            sub-directories. Default: False.
        full_path (bool, optional): If set to True, yield paths that include
            dir_path; otherwise yield paths relative to dir_path.
            Default: False.

    Returns:
        A generator over the selected file paths.

    Raises:
        TypeError: If `keywords` is neither None, a string nor a tuple.
    """
    if (keywords is not None) and not isinstance(keywords, (str, tuple)):
        raise TypeError('"keywords" must be a string or tuple of strings')

    # Normalise to a tuple so a single keyword and several share one code path
    # (str.find does not accept a tuple).
    wanted = (keywords, ) if isinstance(keywords, str) else keywords
    root = dir_path

    def _is_wanted(candidate):
        # `in` also accepts a match at index 0, which `find(...) > 0` missed.
        return wanted is None or any(key in candidate for key in wanted)

    def _scandir(current_dir):
        for entry in os.scandir(current_dir):
            if entry.is_file():
                if entry.name.startswith('.'):
                    continue
                return_path = entry.path if full_path else osp.relpath(entry.path, root)
                if _is_wanted(return_path):
                    yield return_path
            elif recursive and entry.is_dir():
                # Only descend into real directories; hidden files must not be
                # handed to os.scandir.
                yield from _scandir(entry.path)

    return _scandir(dir_path)

View File

@ -0,0 +1,88 @@
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
class Registry():
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
"""
def __init__(self, name):
"""
Args:
name (str): the name of this registry
"""
self._name = name
self._obj_map = {}
def _do_register(self, name, obj, suffix=None):
if isinstance(suffix, str):
name = name + '_' + suffix
assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
f"in '{self._name}' registry!")
self._obj_map[name] = obj
def register(self, obj=None, suffix=None):
"""
Register the given object under the the name `obj.__name__`.
Can be used as either a decorator or not.
See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class):
name = func_or_class.__name__
self._do_register(name, func_or_class, suffix)
return func_or_class
return deco
# used as a function call
name = obj.__name__
self._do_register(name, obj, suffix)
def get(self, name, suffix='basicsr'):
ret = self._obj_map.get(name)
if ret is None:
ret = self._obj_map.get(name + '_' + suffix)
print(f'Name {name} is not found, use name: {name}_{suffix}!')
if ret is None:
raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
return ret
def __contains__(self, name):
return name in self._obj_map
def __iter__(self):
return iter(self._obj_map.items())
def keys(self):
return self._obj_map.keys()
# Module-level registries, one per kind of pluggable component. The string
# argument is the registry name that appears in its error messages.
DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')

View File

@ -1,6 +1,6 @@
# GENERATED VERSION FILE
# TIME: Sat Apr 20 12:26:03 2024
__version__ = '1.2.0+2b4af71'
__gitsha__ = '2b4af71'
# TIME: Sat Apr 20 19:01:02 2024
__version__ = '1.2.0+cf36476'
__gitsha__ = 'cf36476'
short_version = '1.2.0'
version_info = (1, 2, 0)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 61 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 167 KiB

91
docs/BSDS300.md 100644
View File

@ -0,0 +1,91 @@
# Reproduce the SIDD dataset results
### 1. Data Preparation
##### Download the train set and place it in ```./datasets/SIDD/Data```:
* [google drive](https://drive.google.com/file/d/1UHjWZzLPGweA9ZczmV8lFSRcIxqiOVJw/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1EnBVjrfFBiXIRPBgjFrifg?pwd=sl6h),
* ```python scripts/data_preparation/sidd.py``` to crop the train image pairs to 512x512 patches and make the data into lmdb format.
##### Download the evaluation data (in lmdb format) and place it in ```./datasets/SIDD/val/```:
* [google drive](https://drive.google.com/file/d/1gZx_K2vmiHalRNOb1aj93KuUQ2guOlLp/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1I9N5fDa4SNP0nuHEy6k-rw?pwd=59d7),
* it should be like ```./datasets/SIDD/val/input_crops.lmdb``` and ```./datasets/SIDD/val/gt_crops.lmdb```
### 2. Training
* NAFNet-SIDD-width32:
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/options/train/BSDS300/NAFNet-width32.yml --launcher pytorch
```
* NAFNet-SIDD-width64:
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/options/train/BSDS300/NAFNet-width64.yml --launcher pytorch
```
* Baseline-SIDD-width32:
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/SIDD/Baseline-width32.yml --launcher pytorch
```
* Baseline-SIDD-width64:
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/SIDD/Baseline-width64.yml --launcher pytorch
```
* 8 gpus by default. Set ```--nproc_per_node``` to # of gpus for distributed validation.
### 3. Evaluation
##### Download the pretrained models to ```./experiments/pretrained_models/```
* **NAFNet-SIDD-width32**: [google drive](https://drive.google.com/file/d/1lsByk21Xw-6aW7epCwOQxvm6HYCQZPHZ/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1Xses38SWl-7wuyuhaGNhaw?pwd=um97)
* **NAFNet-SIDD-width64**: [google drive](https://drive.google.com/file/d/14Fht1QQJ2gMlk4N1ERCRuElg8JfjrWWR/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/198kYyVSrY_xZF0jGv9U0sQ?pwd=dton)
* **Baseline-SIDD-width32**: [google drive](https://drive.google.com/file/d/1NhqVcqkDcYvYgF_P4BOOfo9tuTcKDuhW/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1wkskmCRKhXq6dGa6Ns8D0A?pwd=0rin)
* **Baseline-SIDD-width64**: [google drive](https://drive.google.com/file/d/1wQ1HHHPhSp70_ledMBZhDhIGjZQs16wO/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1ivruGfSRGfWq5AEB8qc7YQ?pwd=t9w8)
##### Testing on SIDD dataset
* NAFNet-SIDD-width32:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/NAFNet-width32.yml --launcher pytorch
```
* NAFNet-SIDD-width64:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/NAFNet-width64.yml --launcher pytorch
```
* Baseline-SIDD-width32:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/Baseline-width32.yml --launcher pytorch
```
* Baseline-SIDD-width64:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/Baseline-width64.yml --launcher pytorch
```
* Testing uses a single GPU by default. Set ```--nproc_per_node``` to the number of GPUs for distributed validation.

View File

@ -0,0 +1,60 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFNet-BSDS-width64-test
model_type: ImageRestorationModel
scale: 1
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 10
# dataset and data loader settings
datasets:
val:
name: BSDS_val
type: PairedImageDataset
dataroot_gt: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/gt_crops.lmdb
dataroot_lq: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/input_crops.lmdb
io_backend:
type: lmdb
# network structures
network_g:
type: NAFNet
width: 64
enc_blk_nums: [2, 2, 4, 8]
middle_blk_num: 12
dec_blk_nums: [2, 2, 2, 2]
# path
path:
pretrain_network_g: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/experiments/NAFNet-BSDS-width64/models/net_g_latest.pth
strict_load_g: true
resume_state: ~
# validation settings
val:
save_img: true
grids: false
use_image: false
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_ssim
crop_border: 0
test_y_channel: false
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -0,0 +1,109 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFNet-BSDS-width64
model_type: ImageRestorationModel
scale: 1
num_gpu: 1
manual_seed: 10
datasets:
train:
name: BSDS
type: PairedImageDataset
dataroot_gt: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/train/gt_crops.lmdb
dataroot_lq: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/train/input_crops.lmdb
filename_tmpl: '{}'
io_backend:
type: lmdb
gt_size: 256
use_flip: false
use_rot: false
# data loader
use_shuffle: true
num_worker_per_gpu: 1
batch_size_per_gpu: 1
dataset_enlarge_ratio: 1
prefetch_mode: ~
val:
name: BSDS_val
type: PairedImageDataset
dataroot_gt: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/gt_crops.lmdb
dataroot_lq: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/input_crops.lmdb
io_backend:
type: lmdb
network_g:
type: NAFNet
width: 64
enc_blk_nums: [2, 2, 4, 8]
middle_blk_num: 12
dec_blk_nums: [2, 2, 2, 2]
# path
path:
pretrain_network_g: ~
strict_load_g: true
resume_state: ~
# training settings
train:
optim_g:
type: AdamW
lr: !!float 1e-3
weight_decay: 0.
betas: [0.9, 0.9]
scheduler:
type: TrueCosineAnnealingLR
T_max: 400000
eta_min: !!float 1e-7
# total_iter: 200000
total_iter: 123076
warmup_iter: -1 # no warm up
# losses
pixel_opt:
type: PSNRLoss
loss_weight: 1
reduction: mean
# validation settings
val:
val_freq: !!float 2e4
save_img: false
use_image: false
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_ssim
crop_border: 0
test_y_channel: false
# logging settings
logger:
print_freq: 200
save_checkpoint_freq: !!float 5e3
use_tb_logger: true
wandb:
project: ~
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29500

View File

@ -28,8 +28,8 @@ datasets:
# data loader
use_shuffle: true
num_worker_per_gpu: 8
batch_size_per_gpu: 8
num_worker_per_gpu: 1
batch_size_per_gpu: 1
dataset_enlarge_ratio: 1
prefetch_mode: ~

View File

@ -0,0 +1,133 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import cv2
import numpy as np
import os
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm
from basicsr.utils import scandir_SIDD
from basicsr.utils.create_lmdb import create_lmdb_for_BSDS300
import sys
sys.path.append("/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/")
def main():
    """Prepare the BSDS300 validation data.

    Runs extract_subimages once for the noisy inputs (files containing
    '_NOISY') and once for the ground truth (files containing '_GT'), then
    packs both output folders into LMDB databases.
    """
    project_root = '/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet'
    val_root = f'{project_root}/datasets/BSDS300/processed/val'

    opt = {
        'n_thread': 20,
        'compression_level': 3,
        'input_folder': f'{project_root}/datasets/BSDS300/Data',
        'crop_size': 30,
        'step': 384,
        'thresh_size': 0,
    }

    # The same options are reused for both passes; only the destination folder
    # and the filename keyword change.
    for subfolder, keyword in (('input_crops', '_NOISY'), ('gt_crops', '_GT')):
        opt['save_folder'] = f'{val_root}/{subfolder}'
        opt['keywords'] = keyword
        extract_subimages(opt)

    create_lmdb_for_BSDS300(val_root)
def extract_subimages(opt):
    """Process every matching image of the input folder in a process pool.

    The input folder is scanned recursively for files whose path contains
    opt['keywords']; each file is handed to `worker`. A worker that raises is
    reported on stdout instead of being dropped silently, and still advances
    the progress bar.

    Args:
        opt (dict): Configuration dict. It contains:
            input_folder (str): Path to the input folder.
            save_folder (str): Path to save folder. Created when missing; an
                existing folder is reused.
            keywords (str): Substring selecting the files to process.
            n_thread (int): Number of worker processes in the pool.
            Remaining entries are read by `worker`.
    """
    input_folder = opt['input_folder']
    save_folder = opt['save_folder']
    if not osp.exists(save_folder):
        os.makedirs(save_folder)
        print(f'mkdir {save_folder} ...')
    else:
        # The script deliberately keeps going here; say so rather than "Exit".
        print(f'Folder {save_folder} already exists. Existing files may be overwritten.')

    img_list = list(scandir_SIDD(input_folder, keywords=opt['keywords'], recursive=True, full_path=True))

    pbar = tqdm(total=len(img_list), unit='image', desc='Extract')

    def _report_failure(exc, failed_path):
        # Without an error_callback, apply_async discards worker exceptions.
        print(f'Failed to process {failed_path}: {exc!r}')
        pbar.update(1)

    pool = Pool(opt['n_thread'])
    for path in img_list:
        pool.apply_async(
            worker,
            args=(path, opt),
            callback=lambda arg: pbar.update(1),
            # Bind the current path as a default so each callback reports its own file.
            error_callback=lambda exc, failed_path=path: _report_failure(exc, failed_path))
    pool.close()
    pool.join()
    pbar.close()
    print('All processes done.')
def worker(path, opt):
    """Worker for each process: resize one image to 512x512 and save it.

    The image is written to opt['save_folder'] as `<name>_s000<ext>`, where
    `<name>` is the input file stem with opt['keywords'] removed. The
    sliding-window cropping that used to follow was unreachable (it came after
    an unconditional return) and has been dropped, so `crop_size`, `step` and
    `thresh_size` are no longer read.

    Args:
        path (str): Image path.
        opt (dict): Configuration dict. It contains:
            keywords (str): Substring stripped from the file stem.
            save_folder (str): Path to save folder.
            compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.

    Returns:
        process_info (str): Process information displayed in progress bar.

    Raises:
        ValueError: If the image cannot be read.
    """
    img_name, extension = osp.splitext(osp.basename(path))
    img_name = img_name.replace(opt['keywords'], '')

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure with None; fail with a clear message
        # instead of an opaque error from cv2.resize.
        raise ValueError(f'Failed to read image: {path}')

    resized_img = cv2.resize(img, (512, 512))
    # Exactly one output per input, so the sub-image index is always 0.
    index = 0
    cv2.imwrite(
        osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), resized_img,
        [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    process_info = f'Processing {img_name} ...'
    return process_info
# Script entry point: run the extraction only when executed directly.
if __name__ == '__main__':
    main()
    # NOTE(review): the LMDB databases are already created inside main() via
    # create_lmdb_for_BSDS300, so no separate conversion step follows here.

View File

@ -14,7 +14,8 @@ from tqdm import tqdm
from basicsr.utils import scandir_SIDD
from basicsr.utils.create_lmdb import create_lmdb_for_SIDD
import sys
sys.path.append("/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/")
def main():
opt = {}