Codeflow for training with BSDS300 dataset
iSigned-off-by: Ranjan Debnath <m22aie245@gmail.com>pull/141/head
parent
cf36476ea3
commit
273bd48f21
|
@ -39,6 +39,10 @@ def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=Fal
|
|||
img = to_y_channel(img)
|
||||
img2 = to_y_channel(img2)
|
||||
|
||||
if isinstance(img, torch.Tensor):
|
||||
img = img.numpy()
|
||||
if isinstance(img2, torch.Tensor):
|
||||
img2 = img2.numpy()
|
||||
img = img.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
|
||||
|
@ -119,6 +123,10 @@ def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=Fal
|
|||
img = to_y_channel(img)
|
||||
img2 = to_y_channel(img2)
|
||||
|
||||
if isinstance(img, torch.Tensor):
|
||||
img = img.numpy()
|
||||
if isinstance(img2, torch.Tensor):
|
||||
img2 = img2.numpy()
|
||||
img = img.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
|
||||
|
|
|
@ -14,7 +14,8 @@ from basicsr.train import parse_options
|
|||
from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
|
||||
make_exp_dirs)
|
||||
from basicsr.utils.options import dict2str
|
||||
|
||||
import gc
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def main():
|
||||
# parse options, set distributed setting, set random seed
|
||||
|
@ -68,3 +69,4 @@ def main():
|
|||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
gc.collect()
|
|
@ -23,6 +23,9 @@ from basicsr.utils import (MessageLogger, check_resume, get_env_info,
|
|||
set_random_seed)
|
||||
from basicsr.utils.dist_util import get_dist_info, init_dist
|
||||
from basicsr.utils.options import dict2str, parse
|
||||
import gc
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def parse_options(is_train=True):
|
||||
|
@ -303,3 +306,4 @@ if __name__ == '__main__':
|
|||
import os
|
||||
os.environ['GRPC_POLL_STRATEGY']='epoll1'
|
||||
main()
|
||||
gc.collect()
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import sys
|
||||
sys.path.append("/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet")
|
||||
|
||||
from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
|
||||
from .diffjpeg import DiffJPEG
|
||||
from .file_client import FileClient
|
||||
from .img_process_util import USMSharp, usm_sharp
|
||||
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding
|
||||
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
|
||||
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
|
||||
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt, scandir_SIDD
|
||||
from .options import yaml_load
|
||||
|
||||
__all__ = [
|
||||
|
@ -35,14 +36,12 @@ __all__ = [
|
|||
'mkdir_and_rename',
|
||||
'make_exp_dirs',
|
||||
'scandir',
|
||||
'scandir_SIDD',
|
||||
'check_resume',
|
||||
'sizeof_fmt',
|
||||
# diffjpeg
|
||||
'DiffJPEG',
|
||||
# img_process_util
|
||||
'USMSharp',
|
||||
'usm_sharp',
|
||||
# options
|
||||
'yaml_load',
|
||||
'padding'
|
||||
]
|
||||
|
||||
|
|
|
@ -0,0 +1,208 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def rgb2ycbcr(img, y_only=False):
    """Convert an RGB image to YCbCr with the ITU-R BT.601 coefficients.

    The output matches Matlab's `rgb2ycbcr` (standard-definition television
    conversion, see
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion).

    This is not the same as `cv2.cvtColor` with `RGB <-> YCrCb`: OpenCV
    implements the JPEG conversion instead, see
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): RGB input image. Accepted forms:
            1. np.uint8 with range [0, 255];
            2. np.float32 with range [0, 1].
        y_only (bool): If True, only the Y channel is returned.
            Default: False.

    Returns:
        ndarray: The YCbCr image (or its Y channel), with the same dtype and
            value range as the input.
    """
    src_dtype = img.dtype
    unit_img = _convert_input_type_range(img)
    if y_only:
        # Luma only: weighted sum over the trailing (channel) axis.
        luma_weights = [65.481, 128.553, 24.966]
        converted = np.dot(unit_img, luma_weights) + 16.0
    else:
        # Rows correspond to R, G, B; columns to Y, Cb, Cr.
        transform = [
            [65.481, -37.797, 112.0],
            [128.553, -74.203, -93.786],
            [24.966, 112.0, -18.214],
        ]
        offsets = [16, 128, 128]
        converted = np.matmul(unit_img, transform) + offsets
    return _convert_output_type_range(converted, src_dtype)
|
||||
|
||||
|
||||
def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr with the ITU-R BT.601 coefficients.

    BGR counterpart of `rgb2ycbcr` (standard-definition television
    conversion, see
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion).

    This is not the same as `cv2.cvtColor` with `BGR <-> YCrCb`: OpenCV
    implements the JPEG conversion instead, see
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): BGR input image. Accepted forms:
            1. np.uint8 with range [0, 255];
            2. np.float32 with range [0, 1].
        y_only (bool): If True, only the Y channel is returned.
            Default: False.

    Returns:
        ndarray: The YCbCr image (or its Y channel), with the same dtype and
            value range as the input.
    """
    src_dtype = img.dtype
    unit_img = _convert_input_type_range(img)
    if y_only:
        # Same luma weights as the RGB version, listed in B, G, R order.
        luma_weights = [24.966, 128.553, 65.481]
        converted = np.dot(unit_img, luma_weights) + 16.0
    else:
        # Rows correspond to B, G, R; columns to Y, Cb, Cr.
        transform = [
            [24.966, 112.0, -18.214],
            [128.553, -74.203, -93.786],
            [65.481, -37.797, 112.0],
        ]
        offsets = [16, 128, 128]
        converted = np.matmul(unit_img, transform) + offsets
    return _convert_output_type_range(converted, src_dtype)
|
||||
|
||||
|
||||
def ycbcr2rgb(img):
    """Convert a YCbCr image to RGB with the ITU-R BT.601 coefficients.

    The output matches Matlab's `ycbcr2rgb` (standard-definition television
    conversion, see
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion).

    This is not the same as `cv2.cvtColor` with `YCrCb <-> RGB`: OpenCV
    implements the JPEG conversion instead, see
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): YCbCr input image. Accepted forms:
            1. np.uint8 with range [0, 255];
            2. np.float32 with range [0, 1].

    Returns:
        ndarray: The RGB image, with the same dtype and value range as the
            input.
    """
    src_dtype = img.dtype
    # The inverse coefficients below operate on values scaled to [0, 255].
    scaled = _convert_input_type_range(img) * 255
    # Rows correspond to Y, Cb, Cr; columns to R, G, B.
    inverse = [
        [0.00456621, 0.00456621, 0.00456621],
        [0, -0.00153632, 0.00791071],
        [0.00625893, -0.00318811, 0],
    ]
    offsets = [-222.921, 135.576, -276.836]
    converted = np.matmul(scaled, inverse) * 255.0 + offsets
    return _convert_output_type_range(converted, src_dtype)
|
||||
|
||||
|
||||
def ycbcr2bgr(img):
    """Convert a YCbCr image to BGR with the ITU-R BT.601 coefficients.

    BGR counterpart of `ycbcr2rgb` (standard-definition television
    conversion, see
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion).

    This is not the same as `cv2.cvtColor` with `YCrCb <-> BGR`: OpenCV
    implements the JPEG conversion instead, see
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): YCbCr input image. Accepted forms:
            1. np.uint8 with range [0, 255];
            2. np.float32 with range [0, 1].

    Returns:
        ndarray: The BGR image, with the same dtype and value range as the
            input.
    """
    src_dtype = img.dtype
    # The inverse coefficients below operate on values scaled to [0, 255].
    scaled = _convert_input_type_range(img) * 255
    # Rows correspond to Y, Cb, Cr; columns to B, G, R.
    inverse = [
        [0.00456621, 0.00456621, 0.00456621],
        [0.00791071, -0.00153632, 0],
        [0, -0.00318811, 0.00625893],
    ]
    offsets = [-276.836, 135.576, -222.921]
    converted = np.matmul(scaled, inverse) * 255.0 + offsets
    return _convert_output_type_range(converted, src_dtype)
|
||||
|
||||
|
||||
def _convert_input_type_range(img):
    """Bring an input image to np.float32 in the range [0, 1].

    Pre-processing step shared by the colorspace conversion functions such
    as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The input image. Accepted forms:
            1. np.uint8 with range [0, 255];
            2. np.float32 with range [0, 1].

    Returns:
        (ndarray): A new np.float32 array with range [0, 1].

    Raises:
        TypeError: If the input dtype is neither np.float32 nor np.uint8.
    """
    src_dtype = img.dtype
    # astype() copies, so the in-place division below never touches the caller's array.
    as_float = img.astype(np.float32)
    if src_dtype == np.uint8:
        as_float /= 255.
        return as_float
    if src_dtype != np.float32:
        raise TypeError(f'The img type should be np.float32 or np.uint8, but got {src_dtype}')
    return as_float
|
||||
|
||||
|
||||
def _convert_output_type_range(img, dst_type):
    """Convert the type and range of the image according to dst_type.

    Post-processing step shared by the colorspace conversion functions such
    as rgb2ycbcr and ycbcr2rgb. The input array is never modified; a new
    array is returned.

    Args:
        img (ndarray): The image to be converted, with values in the range
            [0, 255].
        dst_type (np.uint8 | np.float32): Target type.
            If np.uint8, values are rounded and the range stays [0, 255].
            If np.float32, values are divided by 255 so the range becomes
            [0, 1].

    Returns:
        (ndarray): The converted image with the desired type and range.

    Raises:
        TypeError: If dst_type is neither np.uint8 nor np.float32.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        # Out-of-place division: `img /= 255.` would silently rescale the
        # caller's array and fails outright for integer-dtype inputs.
        img = img / 255.
    return img.astype(dst_type)
|
||||
|
||||
|
||||
def rgb2ycbcr_pt(img, y_only=False):
    """Convert RGB images to YCbCr images (PyTorch version).

    Uses the ITU-R BT.601 coefficients for standard-definition television, see
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    Args:
        img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float.
    """
    # Move channels last so matmul contracts over the colour axis.
    channels_last = img.permute(0, 2, 3, 1)
    if y_only:
        coeffs = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
        offset = 16.0
    else:
        coeffs = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                               [24.966, 112.0, -18.214]]).to(img)
        offset = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)

    # Back to (n, c, h, w), add the per-channel offset, then rescale to [0, 1].
    projected = torch.matmul(channels_last, coeffs).permute(0, 3, 1, 2) + offset
    return projected / 255.
|
|
@ -21,6 +21,7 @@ def prepare_keys(folder_path, suffix='png'):
|
|||
list[str]: Key list.
|
||||
"""
|
||||
print('Reading image path list ...')
|
||||
# import ipdb; ipdb.set_trace()
|
||||
img_path_list = sorted(
|
||||
list(scandir(folder_path, suffix=suffix, recursive=False)))
|
||||
keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)]
|
||||
|
@ -131,3 +132,17 @@ def create_lmdb_for_SIDD():
|
|||
img_path_list, keys = prepare_keys(folder_path, 'png')
|
||||
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
|
||||
'''
|
||||
|
||||
def create_lmdb_for_BSDS300(root_path=''):
    """Build the input and ground-truth LMDB databases for BSDS300 crops.

    For each of the ``input_crops`` and ``gt_crops`` sub-folders of
    ``root_path``, the ``jpg`` files are listed with ``prepare_keys`` and
    written by ``make_lmdb_from_imgs`` to a sibling ``<name>.lmdb``.

    Args:
        root_path (str): Directory containing the ``input_crops`` and
            ``gt_crops`` folders. Nothing is done when it is empty.
            Default: ''.
    """
    if root_path == '':
        return

    # Noisy inputs first, then ground truth -- same order as the two LMDBs
    # are expected to be produced.
    for subset in ('input_crops', 'gt_crops'):
        folder_path = f'{root_path}/{subset}'
        lmdb_path = f'{root_path}/{subset}.lmdb'

        img_path_list, keys = prepare_keys(folder_path, 'jpg')
        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
|
|
@ -145,3 +145,45 @@ def sizeof_fmt(size, suffix='B'):
|
|||
return f'{size:3.1f} {unit}{suffix}'
|
||||
size /= 1024.0
|
||||
return f'{size:3.1f} Y{suffix}'
|
||||
|
||||
def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
    """Scan a directory to find the files whose path contains a keyword.

    Hidden files (names starting with '.') are never returned.

    Args:
        dir_path (str): Path of the directory.
        keywords (str | tuple(str), optional): Keyword(s) to look for in the
            returned path. A file is kept if any keyword occurs anywhere in
            that path, including at its very start. None keeps every file.
            Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            sub-directories. Default: False.
        full_path (bool, optional): If set to True, yielded paths include
            dir_path; otherwise they are relative to it. Default: False.

    Returns:
        A generator over the paths of all matching files.

    Raises:
        TypeError: If keywords is not None, a string, or a tuple.
    """

    if (keywords is not None) and not isinstance(keywords, (str, tuple)):
        raise TypeError('"keywords" must be a string or tuple of strings')

    # Normalise to a tuple so one keyword and several keywords share a code
    # path (str.find() does not accept a tuple).
    wanted = (keywords, ) if isinstance(keywords, str) else keywords
    root = dir_path

    def _scandir(current_dir):
        for entry in os.scandir(current_dir):
            if not entry.name.startswith('.') and entry.is_file():
                return_path = entry.path if full_path else osp.relpath(entry.path, root)
                # Substring test rather than `find(...) > 0`, which wrongly
                # rejected a keyword located at index 0.
                if wanted is None or any(key in return_path for key in wanted):
                    yield return_path
            elif recursive and entry.is_dir():
                # Only descend into real directories; hidden *files* also
                # reach this branch and must not be passed to os.scandir().
                yield from _scandir(entry.path)

    return _scandir(dir_path)
|
|
@ -0,0 +1,88 @@
|
|||
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
|
||||
|
||||
|
||||
class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, suffix=None):
        """Store ``obj`` under ``name`` (or ``name_suffix`` if a suffix is given).

        Raises:
            AssertionError: If the resulting name is already registered.
        """
        if isinstance(suffix, str):
            name = name + '_' + suffix

        # Raised explicitly (instead of via `assert`) so the duplicate check
        # still runs under `python -O`; the exception type is unchanged.
        if name in self._obj_map:
            raise AssertionError(f"An object named '{name}' was already registered "
                                 f"in '{self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None, suffix=None):
        """
        Register the given object under the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.

        Args:
            obj (callable | type, optional): Object to register. When None,
                a decorator is returned instead. Default: None.
            suffix (str, optional): Appended to the name as ``name_suffix``.
                Default: None.

        Returns:
            The decorator when ``obj`` is None, otherwise ``obj`` itself, so
            that a bare ``@registry.register`` does not replace the decorated
            object with None.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class, suffix)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj, suffix)
        return obj

    def get(self, name, suffix='basicsr'):
        """Look up ``name``, falling back to ``name_suffix``.

        Args:
            name (str): Registered name to look up.
            suffix (str): Suffix tried when ``name`` itself is not found.
                Default: 'basicsr'.

        Returns:
            The registered object.

        Raises:
            KeyError: If neither ``name`` nor ``name_suffix`` is registered.
        """
        ret = self._obj_map.get(name)
        if ret is None:
            ret = self._obj_map.get(name + '_' + suffix)
            if ret is None:
                raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
            # Only announce the fallback when it actually resolved.
            print(f'Name {name} is not found, use name: {name}_{suffix}!')
        return ret

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()
|
||||
|
||||
|
||||
# Module-level registries, one per kind of pluggable component. Each is an
# independent name -> object mapping; objects are added with
# `<REGISTRY>.register()` and looked up with `<REGISTRY>.get(name)`.
DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
|
|
@ -1,6 +1,6 @@
|
|||
# GENERATED VERSION FILE
|
||||
# TIME: Sat Apr 20 12:26:03 2024
|
||||
__version__ = '1.2.0+2b4af71'
|
||||
__gitsha__ = '2b4af71'
|
||||
# TIME: Sat Apr 20 19:01:02 2024
|
||||
__version__ = '1.2.0+cf36476'
|
||||
__gitsha__ = 'cf36476'
|
||||
short_version = '1.2.0'
|
||||
version_info = (1, 2, 0)
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 61 KiB |
BIN
demo/noisy.png
BIN
demo/noisy.png
Binary file not shown.
Before Width: | Height: | Size: 167 KiB |
|
@ -0,0 +1,91 @@
|
|||
# reproduce the SIDD dataset results
|
||||
|
||||
|
||||
|
||||
### 1. Data Preparation
|
||||
|
||||
##### Download the train set and place it in ```./datasets/SIDD/Data```:
|
||||
|
||||
* [google drive](https://drive.google.com/file/d/1UHjWZzLPGweA9ZczmV8lFSRcIxqiOVJw/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1EnBVjrfFBiXIRPBgjFrifg?pwd=sl6h),
|
||||
* ```python scripts/data_preparation/sidd.py``` to crop the train image pairs to 512x512 patches and make the data into lmdb format.
|
||||
|
||||
##### Download the evaluation data (in lmdb format) and place it in ```./datasets/SIDD/val/```:
|
||||
|
||||
* [google drive](https://drive.google.com/file/d/1gZx_K2vmiHalRNOb1aj93KuUQ2guOlLp/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1I9N5fDa4SNP0nuHEy6k-rw?pwd=59d7),
|
||||
* it should be like ```./datasets/SIDD/val/input_crops.lmdb``` and ```./datasets/SIDD/val/gt_crops.lmdb```
|
||||
|
||||
|
||||
|
||||
### 2. Training
|
||||
|
||||
* NAFNet-SIDD-width32:
|
||||
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/options/train/BSDS300/NAFNet-width32.yml --launcher pytorch
|
||||
```
|
||||
|
||||
* NAFNet-SIDD-width64:
|
||||
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/options/train/BSDS300/NAFNet-width64.yml --launcher pytorch
|
||||
```
|
||||
|
||||
* Baseline-SIDD-width32:
|
||||
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/SIDD/Baseline-width32.yml --launcher pytorch
|
||||
```
|
||||
|
||||
* Baseline-SIDD-width64:
|
||||
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/SIDD/Baseline-width64.yml --launcher pytorch
|
||||
```
|
||||
|
||||
* Training uses 8 GPUs by default. Set ```--nproc_per_node``` to the number of GPUs available for distributed training.
|
||||
|
||||
|
||||
|
||||
|
||||
### 3. Evaluation
|
||||
|
||||
|
||||
##### Download the pretrain model in ```./experiments/pretrained_models/```
|
||||
|
||||
* **NAFNet-SIDD-width32**: [google drive](https://drive.google.com/file/d/1lsByk21Xw-6aW7epCwOQxvm6HYCQZPHZ/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1Xses38SWl-7wuyuhaGNhaw?pwd=um97)
|
||||
|
||||
* **NAFNet-SIDD-width64**: [google drive](https://drive.google.com/file/d/14Fht1QQJ2gMlk4N1ERCRuElg8JfjrWWR/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/198kYyVSrY_xZF0jGv9U0sQ?pwd=dton)
|
||||
|
||||
* **Baseline-SIDD-width32**: [google drive](https://drive.google.com/file/d/1NhqVcqkDcYvYgF_P4BOOfo9tuTcKDuhW/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1wkskmCRKhXq6dGa6Ns8D0A?pwd=0rin)
|
||||
|
||||
* **Baseline-SIDD-width64**: [google drive](https://drive.google.com/file/d/1wQ1HHHPhSp70_ledMBZhDhIGjZQs16wO/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1ivruGfSRGfWq5AEB8qc7YQ?pwd=t9w8)
|
||||
|
||||
|
||||
##### Testing on SIDD dataset
|
||||
|
||||
* NAFNet-SIDD-width32:
|
||||
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/NAFNet-width32.yml --launcher pytorch
|
||||
```
|
||||
|
||||
* NAFNet-SIDD-width64:
|
||||
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/NAFNet-width64.yml --launcher pytorch
|
||||
```
|
||||
|
||||
* Baseline-SIDD-width32:
|
||||
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/Baseline-width32.yml --launcher pytorch
|
||||
```
|
||||
|
||||
* Baseline-SIDD-width64:
|
||||
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/Baseline-width64.yml --launcher pytorch
|
||||
```
|
||||
|
||||
* Testing runs on a single GPU by default. Set ```--nproc_per_node``` to the number of GPUs for distributed validation.
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
# ------------------------------------------------------------------------
|
||||
# Copyright (c) 2022 megvii-model. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
|
||||
# Copyright 2018-2020 BasicSR Authors
|
||||
# ------------------------------------------------------------------------
|
||||
# general settings
|
||||
name: NAFNet-BSDS-width64-test
|
||||
model_type: ImageRestorationModel
|
||||
scale: 1
|
||||
num_gpu: 1 # set num_gpu: 0 for cpu mode
|
||||
manual_seed: 10
|
||||
|
||||
# dataset and data loader settings
|
||||
datasets:
|
||||
|
||||
val:
|
||||
name: BSDS_val
|
||||
type: PairedImageDataset
|
||||
|
||||
dataroot_gt: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/gt_crops.lmdb
|
||||
dataroot_lq: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/input_crops.lmdb
|
||||
|
||||
io_backend:
|
||||
type: lmdb
|
||||
|
||||
# network structures
|
||||
network_g:
|
||||
type: NAFNet
|
||||
width: 64
|
||||
enc_blk_nums: [2, 2, 4, 8]
|
||||
middle_blk_num: 12
|
||||
dec_blk_nums: [2, 2, 2, 2]
|
||||
|
||||
# path
|
||||
path:
|
||||
pretrain_network_g: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/experiments/NAFNet-BSDS-width64/models/net_g_latest.pth
|
||||
strict_load_g: true
|
||||
resume_state: ~
|
||||
|
||||
# validation settings
|
||||
val:
|
||||
save_img: true
|
||||
grids: false
|
||||
use_image: false
|
||||
|
||||
metrics:
|
||||
psnr: # metric name, can be arbitrary
|
||||
type: calculate_psnr
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
ssim:
|
||||
type: calculate_ssim
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
|
||||
# dist training settings
|
||||
dist_params:
|
||||
backend: nccl
|
||||
port: 29500
|
|
@ -0,0 +1,109 @@
|
|||
# ------------------------------------------------------------------------
|
||||
# Copyright (c) 2022 megvii-model. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
|
||||
# Copyright 2018-2020 BasicSR Authors
|
||||
# ------------------------------------------------------------------------
|
||||
# general settings
|
||||
name: NAFNet-BSDS-width64
|
||||
model_type: ImageRestorationModel
|
||||
scale: 1
|
||||
num_gpu: 1
|
||||
manual_seed: 10
|
||||
|
||||
datasets:
|
||||
train:
|
||||
name: BSDS
|
||||
type: PairedImageDataset
|
||||
dataroot_gt: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/train/gt_crops.lmdb
|
||||
dataroot_lq: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/train/input_crops.lmdb
|
||||
|
||||
filename_tmpl: '{}'
|
||||
io_backend:
|
||||
type: lmdb
|
||||
|
||||
gt_size: 256
|
||||
use_flip: false
|
||||
use_rot: false
|
||||
|
||||
# data loader
|
||||
use_shuffle: true
|
||||
num_worker_per_gpu: 1
|
||||
batch_size_per_gpu: 1
|
||||
dataset_enlarge_ratio: 1
|
||||
prefetch_mode: ~
|
||||
|
||||
val:
|
||||
name: BSDS_val
|
||||
type: PairedImageDataset
|
||||
dataroot_gt: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/gt_crops.lmdb
|
||||
dataroot_lq: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/input_crops.lmdb
|
||||
io_backend:
|
||||
type: lmdb
|
||||
|
||||
|
||||
network_g:
|
||||
type: NAFNet
|
||||
width: 64
|
||||
enc_blk_nums: [2, 2, 4, 8]
|
||||
middle_blk_num: 12
|
||||
dec_blk_nums: [2, 2, 2, 2]
|
||||
|
||||
# path
|
||||
path:
|
||||
pretrain_network_g: ~
|
||||
strict_load_g: true
|
||||
resume_state: ~
|
||||
|
||||
# training settings
|
||||
train:
|
||||
optim_g:
|
||||
type: AdamW
|
||||
lr: !!float 1e-3
|
||||
weight_decay: 0.
|
||||
betas: [0.9, 0.9]
|
||||
|
||||
scheduler:
|
||||
type: TrueCosineAnnealingLR
|
||||
T_max: 400000
|
||||
eta_min: !!float 1e-7
|
||||
|
||||
# total_iter: 200000
|
||||
total_iter: 123076
|
||||
warmup_iter: -1 # no warm up
|
||||
|
||||
# losses
|
||||
pixel_opt:
|
||||
type: PSNRLoss
|
||||
loss_weight: 1
|
||||
reduction: mean
|
||||
|
||||
# validation settings
|
||||
val:
|
||||
val_freq: !!float 2e4
|
||||
save_img: false
|
||||
use_image: false
|
||||
|
||||
metrics:
|
||||
psnr: # metric name, can be arbitrary
|
||||
type: calculate_psnr
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
ssim:
|
||||
type: calculate_ssim
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
|
||||
# logging settings
|
||||
logger:
|
||||
print_freq: 200
|
||||
save_checkpoint_freq: !!float 5e3
|
||||
use_tb_logger: true
|
||||
wandb:
|
||||
project: ~
|
||||
resume_id: ~
|
||||
|
||||
# dist training settings
|
||||
dist_params:
|
||||
backend: nccl
|
||||
port: 29500
|
|
@ -28,8 +28,8 @@ datasets:
|
|||
|
||||
# data loader
|
||||
use_shuffle: true
|
||||
num_worker_per_gpu: 8
|
||||
batch_size_per_gpu: 8
|
||||
num_worker_per_gpu: 1
|
||||
batch_size_per_gpu: 1
|
||||
dataset_enlarge_ratio: 1
|
||||
prefetch_mode: ~
|
||||
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
# ------------------------------------------------------------------------
|
||||
# Copyright (c) 2022 megvii-model. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
|
||||
# Copyright 2018-2020 BasicSR Authors
|
||||
# ------------------------------------------------------------------------
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
from multiprocessing import Pool
|
||||
from os import path as osp
|
||||
from tqdm import tqdm
|
||||
|
||||
from basicsr.utils import scandir_SIDD
|
||||
from basicsr.utils.create_lmdb import create_lmdb_for_BSDS300
|
||||
import sys
|
||||
sys.path.append("/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/")
|
||||
|
||||
def main():
    """Prepare the BSDS300 validation split.

    Runs ``extract_subimages`` once for the noisy inputs (files matching
    ``_NOISY``) and once for the ground truth (files matching ``_GT``), then
    packs both output folders into LMDB databases with
    ``create_lmdb_for_BSDS300``.
    """
    # Options common to both extraction passes.
    shared = {
        'n_thread': 20,
        'compression_level': 3,
        'input_folder': '/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/Data',
        'crop_size': 30,
        'step': 384,
        'thresh_size': 0,
    }

    # (output folder, filename keyword) for each pass, inputs first.
    passes = (
        ('/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/input_crops', '_NOISY'),
        ('/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/gt_crops', '_GT'),
    )
    for save_folder, keyword in passes:
        extract_subimages({**shared, 'save_folder': save_folder, 'keywords': keyword})

    create_lmdb_for_BSDS300('/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val')
|
||||
|
||||
|
||||
def extract_subimages(opt):
    """Process every matching image of the input folder in a worker pool.

    Files are discovered recursively with ``scandir_SIDD`` using
    ``opt['keywords']`` and each one is handed to ``worker``.

    Args:
        opt (dict): Configuration dict. It contains:
            input_folder (str): Path to the input folder.
            save_folder (str): Path to save folder. Created if missing; an
                existing folder is reused.
            keywords (str): Keyword a file path must contain to be processed.
            n_thread (int): Number of worker processes.
            Further keys are read by ``worker``.
    """
    input_folder = opt['input_folder']
    save_folder = opt['save_folder']
    if osp.exists(save_folder):
        # Processing deliberately continues here (the former sys.exit is
        # disabled), so the message no longer claims to exit.
        print(f'Folder {save_folder} already exists. Existing files may be overwritten.')
    else:
        os.makedirs(save_folder)
        print(f'mkdir {save_folder} ...')

    img_list = list(scandir_SIDD(input_folder, keywords=opt['keywords'], recursive=True, full_path=True))

    pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
    failures = []

    def _on_error(exc):
        # apply_async() drops worker exceptions unless an error callback is
        # supplied; record and report them instead of failing silently.
        failures.append(exc)
        tqdm.write(f'Worker failed: {exc!r}')

    with Pool(opt['n_thread']) as pool:
        for path in img_list:
            pool.apply_async(
                worker, args=(path, opt), callback=lambda arg: pbar.update(1), error_callback=_on_error)
        # close() + join() wait for all queued tasks before the context
        # manager terminates the pool.
        pool.close()
        pool.join()
    pbar.close()
    if failures:
        print(f'{len(failures)} of {len(img_list)} images failed.')
    print('All processes done.')
|
||||
|
||||
|
||||
def worker(path, opt):
    """Worker for each process: resize one image to 512x512 and save it.

    The image is written to ``opt['save_folder']`` as
    ``<name>_s000<ext>``, where ``<name>`` is the source file name with
    ``opt['keywords']`` removed. The sliding-window cropping code that used
    to follow the first ``return`` was unreachable and has been removed, so
    ``crop_size``, ``step`` and ``thresh_size`` are no longer read.

    Args:
        path (str): Image path.
        opt (dict): Configuration dict. It contains:
            keywords (str): Substring stripped from the file name.
            save_folder (str): Path to save folder.
            compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.

    Returns:
        process_info (str): Process information displayed in progress bar.

    Raises:
        IOError: If the image cannot be read.
    """
    img_name, extension = osp.splitext(osp.basename(path))
    img_name = img_name.replace(opt['keywords'], '')

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None; fail with a clear
        # message instead of an opaque error inside cv2.resize.
        raise IOError(f'Failed to read image: {path}')
    resized_img = cv2.resize(img, (512, 512))

    # A single output per source image, hence the fixed index 0.
    index = 0
    cv2.imwrite(
        osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), resized_img,
        [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    process_info = f'Processing {img_name} ...'
    return process_info
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
# ... make sidd to lmdb
|
|
@ -14,7 +14,8 @@ from tqdm import tqdm
|
|||
|
||||
from basicsr.utils import scandir_SIDD
|
||||
from basicsr.utils.create_lmdb import create_lmdb_for_SIDD
|
||||
|
||||
import sys
|
||||
sys.path.append("/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/")
|
||||
|
||||
def main():
|
||||
opt = {}
|
||||
|
|
Loading…
Reference in New Issue