Codeflow for training with BSDS300 dataset

Signed-off-by: Ranjan Debnath <m22aie245@gmail.com>
Ref: pull/141/head
Parent: cf36476ea3
Commit: 273bd48f21
@@ -39,6 +39,10 @@ def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False):
         img = to_y_channel(img)
         img2 = to_y_channel(img2)
 
+    if isinstance(img, torch.Tensor):
+        img = img.numpy()
+    if isinstance(img2, torch.Tensor):
+        img2 = img2.numpy()
     img = img.astype(np.float64)
     img2 = img2.astype(np.float64)
@@ -119,6 +123,10 @@ def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False):
         img = to_y_channel(img)
         img2 = to_y_channel(img2)
 
+    if isinstance(img, torch.Tensor):
+        img = img.numpy()
+    if isinstance(img2, torch.Tensor):
+        img2 = img2.numpy()
     img = img.astype(np.float64)
     img2 = img2.astype(np.float64)
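The two hunks above add the same guard to calculate_psnr and calculate_ssim: if a torch.Tensor is passed in, it is converted with .numpy() before the float64 cast, so the metrics accept either tensors or numpy arrays. A minimal, self-contained sketch of that pattern (not the project's actual metric code, which also handles crop_border and the Y channel) could look like this:

```
import numpy as np
import torch


def _to_float64(img):
    # Accept a torch.Tensor or an ndarray and return a float64 ndarray,
    # mirroring the isinstance(img, torch.Tensor) guard added above.
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    return np.asarray(img, dtype=np.float64)


def simple_psnr(img, img2, max_val=255.0):
    # Plain PSNR for two images in the same value range (assumed [0, 255] here).
    img, img2 = _to_float64(img), _to_float64(img2)
    mse = np.mean((img - img2) ** 2)
    return float('inf') if mse == 0 else 10.0 * np.log10(max_val ** 2 / mse)


a = np.random.randint(0, 256, (64, 64, 3)).astype(np.float32)
b = np.clip(a + np.random.normal(0, 5, a.shape), 0, 255).astype(np.float32)
print(simple_psnr(a, b))                                      # ndarray inputs
print(simple_psnr(torch.from_numpy(a), torch.from_numpy(b)))  # tensor inputs
```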
@@ -14,7 +14,8 @@ from basicsr.train import parse_options
 from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
                            make_exp_dirs)
 from basicsr.utils.options import dict2str
+import gc
+torch.cuda.empty_cache()
 
 
 def main():
     # parse options, set distributed setting, set ramdom seed

@@ -68,3 +69,4 @@ def main():
 
 if __name__ == '__main__':
     main()
+    gc.collect()
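This hunk adds memory cleanup to what looks like the test script (the file name is not shown in the scraped diff, but it imports parse_options from basicsr.train): torch.cuda.empty_cache() releases cached CUDA allocator blocks when the module is imported, and gc.collect() runs once main() returns. A hedged illustration of the same cleanup pattern, where run_inference and the toy model are placeholders rather than functions from this repository:

```
import gc

import torch


def run_inference(model, batches):
    # Placeholder for the evaluation loop the test script would drive.
    with torch.no_grad():
        return [model(batch).cpu() for batch in batches]


if __name__ == '__main__':
    model = torch.nn.Conv2d(3, 3, 3, padding=1)
    batches = [torch.randn(1, 3, 32, 32) for _ in range(4)]
    run_inference(model, batches)
    # Drop Python references first, then let CUDA return its cached blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```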
@@ -23,6 +23,9 @@ from basicsr.utils import (MessageLogger, check_resume, get_env_info,
                            set_random_seed)
 from basicsr.utils.dist_util import get_dist_info, init_dist
 from basicsr.utils.options import dict2str, parse
+import gc
+
+torch.cuda.empty_cache()
 
 
 def parse_options(is_train=True):

@@ -303,3 +306,4 @@ if __name__ == '__main__':
     import os
     os.environ['GRPC_POLL_STRATEGY']='epoll1'
     main()
+    gc.collect()
@@ -1,10 +1,11 @@
+import sys
+sys.path.append("/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet")
+
 from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
-from .diffjpeg import DiffJPEG
 from .file_client import FileClient
-from .img_process_util import USMSharp, usm_sharp
 from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding
 from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
-from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
+from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt, scandir_SIDD
 from .options import yaml_load
 
 __all__ = [

@@ -35,14 +36,12 @@ __all__ = [
     'mkdir_and_rename',
     'make_exp_dirs',
     'scandir',
+    'scandir_SIDD',
     'check_resume',
     'sizeof_fmt',
-    # diffjpeg
-    'DiffJPEG',
     # img_process_util
-    'USMSharp',
-    'usm_sharp',
     # options
     'yaml_load',
     'padding'
 ]
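Note that the added sys.path.append in this package __init__ (apparently basicsr/utils/__init__.py) hard-codes one machine's checkout under /mnt/d/Work/..., so the import only works on that layout. A more portable sketch, not what the commit does, derives the repository root from the file's own location:

```
import os.path as osp
import sys

# Assuming this file lives at basicsr/utils/__init__.py, the repository root
# is two directories up; resolve it relative to the file instead of
# hard-coding an absolute path.
repo_root = osp.abspath(osp.join(osp.dirname(__file__), '..', '..'))
if repo_root not in sys.path:
    sys.path.append(repo_root)
```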
@@ -0,0 +1,208 @@
import numpy as np
import torch


def rgb2ycbcr(img, y_only=False):
    """Convert a RGB image to YCbCr image.

    This function produces the same results as Matlab's `rgb2ycbcr` function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
    else:
        out_img = np.matmul(
            img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr image.

    The bgr version of rgb2ycbcr.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
    else:
        out_img = np.matmul(
            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def ycbcr2rgb(img):
    """Convert a YCbCr image to RGB image.

    This function produces the same results as Matlab's ycbcr2rgb function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted RGB image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                              [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]  # noqa: E126
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def ycbcr2bgr(img):
    """Convert a YCbCr image to BGR image.

    The bgr version of ycbcr2rgb.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted BGR image. The output image has the same type
            and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
                              [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921]  # noqa: E126
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def _convert_input_type_range(img):
    """Convert the type and range of the input image.

    It converts the input image to np.float32 type and range of [0, 1].
    It is mainly used for pre-processing the input image in colorspace
    conversion functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        (ndarray): The converted image with type of np.float32 and range of
            [0, 1].
    """
    img_type = img.dtype
    img = img.astype(np.float32)
    if img_type == np.float32:
        pass
    elif img_type == np.uint8:
        img /= 255.
    else:
        raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
    return img


def _convert_output_type_range(img, dst_type):
    """Convert the type and range of the image according to dst_type.

    It converts the image to desired type and range. If `dst_type` is np.uint8,
    images will be converted to np.uint8 type with range [0, 255]. If
    `dst_type` is np.float32, it converts the image to np.float32 type with
    range [0, 1].
    It is mainly used for post-processing images in colorspace conversion
    functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The image to be converted with np.float32 type and
            range [0, 255].
        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
            converts the image to np.uint8 type with range [0, 255]. If
            dst_type is np.float32, it converts the image to np.float32 type
            with range [0, 1].

    Returns:
        (ndarray): The converted image with desired type and range.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img /= 255.
    return img.astype(dst_type)


def rgb2ycbcr_pt(img, y_only=False):
    """Convert RGB images to YCbCr images (PyTorch version).

    It implements the ITU-R BT.601 conversion for standard-definition television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    Args:
        img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float.
    """
    if y_only:
        weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
    else:
        weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
        bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias

    out_img = out_img / 255.
    return out_img
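The new color_util module above mirrors BasicSR's colorspace helpers: every converter accepts uint8 images in [0, 255] or float32 images in [0, 1] and returns the same dtype and range, with _convert_input_type_range / _convert_output_type_range doing the normalisation at each end. A small round-trip check, assuming the file is importable as basicsr.utils.color_util:

```
import numpy as np

from basicsr.utils.color_util import rgb2ycbcr, ycbcr2rgb

# Mid-range values avoid wrap-around at the extremes of the uint8 range.
rgb_u8 = np.random.randint(16, 240, (8, 8, 3)).astype(np.uint8)

ycbcr_u8 = rgb2ycbcr(rgb_u8)                                        # uint8 in, uint8 out
y_float = rgb2ycbcr(rgb_u8.astype(np.float32) / 255., y_only=True)  # float32 in [0, 1]

# The BT.601 round trip should reproduce the input up to a small rounding error.
back = ycbcr2rgb(ycbcr_u8)
print(ycbcr_u8.dtype, y_float.dtype,
      np.abs(back.astype(np.int16) - rgb_u8.astype(np.int16)).max())
```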
@@ -21,6 +21,7 @@ def prepare_keys(folder_path, suffix='png'):
         list[str]: Key list.
     """
     print('Reading image path list ...')
+    # import ipdb; ipdb.set_trace()
     img_path_list = sorted(
         list(scandir(folder_path, suffix=suffix, recursive=False)))
     keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)]

@@ -131,3 +132,17 @@ def create_lmdb_for_SIDD():
     img_path_list, keys = prepare_keys(folder_path, 'png')
     make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
     '''
+
+def create_lmdb_for_BSDS300(root_path=''):
+    if root_path != '':
+        folder_path = f'{root_path}/input_crops'
+        lmdb_path = f'{root_path}/input_crops.lmdb'
+
+        img_path_list, keys = prepare_keys(folder_path, 'jpg')
+        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
+
+        folder_path = f'{root_path}/gt_crops'
+        lmdb_path = f'{root_path}/gt_crops.lmdb'
+
+        img_path_list, keys = prepare_keys(folder_path, 'jpg')
+        make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
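The new create_lmdb_for_BSDS300 helper packs the cropped input/gt folders under a given root into two LMDBs next to them; it does nothing when root_path is empty and it expects .jpg patches. A short usage sketch, where the path is only an example:

```
from basicsr.utils.create_lmdb import create_lmdb_for_BSDS300

# Expects <root>/input_crops and <root>/gt_crops to contain .jpg patches;
# writes <root>/input_crops.lmdb and <root>/gt_crops.lmdb beside them.
create_lmdb_for_BSDS300('datasets/BSDS300/processed/train')
```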
@@ -145,3 +145,45 @@ def sizeof_fmt(size, suffix='B'):
             return f'{size:3.1f} {unit}{suffix}'
         size /= 1024.0
     return f'{size:3.1f} Y{suffix}'
+
+
+def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
+    """Scan a directory to find the interested files.
+
+    Args:
+        dir_path (str): Path of the directory.
+        keywords (str | tuple(str), optional): File keywords that we are
+            interested in. Default: None.
+        recursive (bool, optional): If set to True, recursively scan the
+            directory. Default: False.
+        full_path (bool, optional): If set to True, include the dir_path.
+            Default: False.
+
+    Returns:
+        A generator for all the interested files with relative pathes.
+    """
+
+    if (keywords is not None) and not isinstance(keywords, (str, tuple)):
+        raise TypeError('"keywords" must be a string or tuple of strings')
+
+    root = dir_path
+
+    def _scandir(dir_path, keywords, recursive):
+        for entry in os.scandir(dir_path):
+            if not entry.name.startswith('.') and entry.is_file():
+                if full_path:
+                    return_path = entry.path
+                else:
+                    return_path = osp.relpath(entry.path, root)
+
+                if keywords is None:
+                    yield return_path
+                elif return_path.find(keywords) > 0:
+                    yield return_path
+            else:
+                if recursive:
+                    yield from _scandir(
+                        entry.path, keywords=keywords, recursive=recursive)
+                else:
+                    continue
+
+    return _scandir(dir_path, keywords=keywords, recursive=recursive)
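scandir_SIDD walks a directory and yields files whose path contains a keyword substring. Two details visible in the code above are worth noting: recursion into sub-directories only happens when recursive=True, and the match uses return_path.find(keywords) > 0, so a keyword sitting at the very start of the relative path is not matched. A usage sketch mirroring how the new BSDS300 preparation script calls it:

```
from basicsr.utils import scandir_SIDD

# Absolute paths (full_path=True) of every file whose path contains '_NOISY',
# searched recursively under the dataset folder.
noisy_files = list(scandir_SIDD('datasets/BSDS300/Data',
                                keywords='_NOISY', recursive=True, full_path=True))
print(len(noisy_files))
```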
@@ -0,0 +1,88 @@
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py  # noqa: E501


class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, suffix=None):
        if isinstance(suffix, str):
            name = name + '_' + suffix

        assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
                                             f"in '{self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None, suffix=None):
        """
        Register the given object under the the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class, suffix)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj, suffix)

    def get(self, name, suffix='basicsr'):
        ret = self._obj_map.get(name)
        if ret is None:
            ret = self._obj_map.get(name + '_' + suffix)
            print(f'Name {name} is not found, use name: {name}_{suffix}!')
        if ret is None:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return ret

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()


DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
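This is the standard BasicSR name-to-object registry; in newer BasicSR code, components referenced by name in the YAML options are resolved through these registries, and the file is presumably added here so the copied utility modules can import it. A minimal example of the decorator form and the lookup, assuming the new file lives at basicsr/utils/registry.py as in upstream BasicSR (MyToyArch is an illustrative name, not something defined in this repository):

```
from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class MyToyArch:
    def __init__(self, width=64):
        self.width = width


# An option file could now refer to this class by name, e.g. `type: MyToyArch`.
arch_cls = ARCH_REGISTRY.get('MyToyArch')
print(arch_cls(width=32).width)
```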
@@ -1,6 +1,6 @@
 # GENERATED VERSION FILE
-# TIME: Sat Apr 20 12:26:03 2024
-__version__ = '1.2.0+2b4af71'
-__gitsha__ = '2b4af71'
+# TIME: Sat Apr 20 19:01:02 2024
+__version__ = '1.2.0+cf36476'
+__gitsha__ = 'cf36476'
 short_version = '1.2.0'
 version_info = (1, 2, 0)
Binary image files changed (contents not shown in the diff): one 61 KiB image whose filename was not captured, and demo/noisy.png (167 KiB).
@@ -0,0 +1,91 @@
# Reproduce the SIDD dataset results

### 1. Data Preparation

##### Download the train set and place it in ```./datasets/SIDD/Data```:

* [google drive](https://drive.google.com/file/d/1UHjWZzLPGweA9ZczmV8lFSRcIxqiOVJw/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1EnBVjrfFBiXIRPBgjFrifg?pwd=sl6h)
* run ```python scripts/data_preparation/sidd.py``` to crop the train image pairs to 512x512 patches and convert the data into lmdb format.

##### Download the evaluation data (in lmdb format) and place it in ```./datasets/SIDD/val/```:

* [google drive](https://drive.google.com/file/d/1gZx_K2vmiHalRNOb1aj93KuUQ2guOlLp/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1I9N5fDa4SNP0nuHEy6k-rw?pwd=59d7)
* it should look like ```./datasets/SIDD/val/input_crops.lmdb``` and ```./datasets/SIDD/val/gt_crops.lmdb```

### 2. Training

* NAFNet-SIDD-width32:

```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/options/train/BSDS300/NAFNet-width32.yml --launcher pytorch
```

* NAFNet-SIDD-width64:

```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/options/train/BSDS300/NAFNet-width64.yml --launcher pytorch
```

* Baseline-SIDD-width32:

```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/SIDD/Baseline-width32.yml --launcher pytorch
```

* Baseline-SIDD-width64:

```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/SIDD/Baseline-width64.yml --launcher pytorch
```

* 8 gpus by default. Set ```--nproc_per_node``` to the number of gpus for distributed validation.

### 3. Evaluation

##### Download the pretrained models to ```./experiments/pretrained_models/```

* **NAFNet-SIDD-width32**: [google drive](https://drive.google.com/file/d/1lsByk21Xw-6aW7epCwOQxvm6HYCQZPHZ/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1Xses38SWl-7wuyuhaGNhaw?pwd=um97)
* **NAFNet-SIDD-width64**: [google drive](https://drive.google.com/file/d/14Fht1QQJ2gMlk4N1ERCRuElg8JfjrWWR/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/198kYyVSrY_xZF0jGv9U0sQ?pwd=dton)
* **Baseline-SIDD-width32**: [google drive](https://drive.google.com/file/d/1NhqVcqkDcYvYgF_P4BOOfo9tuTcKDuhW/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1wkskmCRKhXq6dGa6Ns8D0A?pwd=0rin)
* **Baseline-SIDD-width64**: [google drive](https://drive.google.com/file/d/1wQ1HHHPhSp70_ledMBZhDhIGjZQs16wO/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1ivruGfSRGfWq5AEB8qc7YQ?pwd=t9w8)

##### Testing on SIDD dataset

* NAFNet-SIDD-width32:

```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/NAFNet-width32.yml --launcher pytorch
```

* NAFNet-SIDD-width64:

```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/NAFNet-width64.yml --launcher pytorch
```

* Baseline-SIDD-width32:

```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/Baseline-width32.yml --launcher pytorch
```

* Baseline-SIDD-width64:

```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/Baseline-width64.yml --launcher pytorch
```

* Test on a single gpu by default. Set ```--nproc_per_node``` to the number of gpus for distributed validation.
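Since the option files added in this commit set num_gpu: 1, the torch.distributed.launch wrapper above is not strictly required; assuming the default launcher handling in basicsr/train.py, a plain single-GPU run would look like:

```
python basicsr/train.py -opt options/train/BSDS300/NAFNet-width64.yml
```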
@@ -0,0 +1,60 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFNet-BSDS-width64-test
model_type: ImageRestorationModel
scale: 1
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 10

# dataset and data loader settings
datasets:

  val:
    name: BSDS_val
    type: PairedImageDataset

    dataroot_gt: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/gt_crops.lmdb
    dataroot_lq: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/input_crops.lmdb

    io_backend:
      type: lmdb

# network structures
network_g:
  type: NAFNet
  width: 64
  enc_blk_nums: [2, 2, 4, 8]
  middle_blk_num: 12
  dec_blk_nums: [2, 2, 2, 2]

# path
path:
  pretrain_network_g: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/experiments/NAFNet-BSDS-width64/models/net_g_latest.pth
  strict_load_g: true
  resume_state: ~

# validation settings
val:
  save_img: true
  grids: false
  use_image: false

  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      crop_border: 0
      test_y_channel: false
    ssim:
      type: calculate_ssim
      crop_border: 0
      test_y_channel: false

# dist training settings
dist_params:
  backend: nccl
  port: 29500
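This option file evaluates a trained NAFNet width-64 checkpoint (net_g_latest.pth from the NAFNet-BSDS-width64 experiment) on the BSDS300 validation LMDBs and reports PSNR and SSIM. Assuming it is saved as options/test/BSDS300/NAFNet-width64.yml (the scraped diff does not show the file's path), it would be run the same way as the SIDD test configs above:

```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/BSDS300/NAFNet-width64.yml --launcher pytorch
```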
@@ -0,0 +1,109 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
# general settings
name: NAFNet-BSDS-width64
model_type: ImageRestorationModel
scale: 1
num_gpu: 1
manual_seed: 10

datasets:
  train:
    name: BSDS
    type: PairedImageDataset
    dataroot_gt: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/train/gt_crops.lmdb
    dataroot_lq: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/train/input_crops.lmdb

    filename_tmpl: '{}'
    io_backend:
      type: lmdb

    gt_size: 256
    use_flip: false
    use_rot: false

    # data loader
    use_shuffle: true
    num_worker_per_gpu: 1
    batch_size_per_gpu: 1
    dataset_enlarge_ratio: 1
    prefetch_mode: ~

  val:
    name: BSDS_val
    type: PairedImageDataset
    dataroot_gt: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/gt_crops.lmdb
    dataroot_lq: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/input_crops.lmdb
    io_backend:
      type: lmdb


network_g:
  type: NAFNet
  width: 64
  enc_blk_nums: [2, 2, 4, 8]
  middle_blk_num: 12
  dec_blk_nums: [2, 2, 2, 2]

# path
path:
  pretrain_network_g: ~
  strict_load_g: true
  resume_state: ~

# training settings
train:
  optim_g:
    type: AdamW
    lr: !!float 1e-3
    weight_decay: 0.
    betas: [0.9, 0.9]

  scheduler:
    type: TrueCosineAnnealingLR
    T_max: 400000
    eta_min: !!float 1e-7

  # total_iter: 200000
  total_iter: 123076
  warmup_iter: -1 # no warm up

  # losses
  pixel_opt:
    type: PSNRLoss
    loss_weight: 1
    reduction: mean

# validation settings
val:
  val_freq: !!float 2e4
  save_img: false
  use_image: false

  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      crop_border: 0
      test_y_channel: false
    ssim:
      type: calculate_ssim
      crop_border: 0
      test_y_channel: false

# logging settings
logger:
  print_freq: 200
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500
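Note that the scheduler's T_max (400000) is larger than total_iter (123076), so the cosine decay will not reach eta_min within this run. One way to sanity-check an option file like this before launching training is to parse it with the YAML helper exported from basicsr.utils.options and inspect a few fields; the file path below is an assumption about where the config lives:

```
from basicsr.utils.options import dict2str, yaml_load

opt = yaml_load('options/train/BSDS300/NAFNet-width64.yml')
# e.g. confirm the experiment name, iteration budget and batch size
print(opt['name'], opt['train']['total_iter'], opt['datasets']['train']['batch_size_per_gpu'])
print(dict2str(opt))
```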
@@ -28,8 +28,8 @@ datasets:
 
     # data loader
     use_shuffle: true
-    num_worker_per_gpu: 8
-    batch_size_per_gpu: 8
+    num_worker_per_gpu: 1
+    batch_size_per_gpu: 1
     dataset_enlarge_ratio: 1
     prefetch_mode: ~
@@ -0,0 +1,133 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import cv2
import numpy as np
import os
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm

from basicsr.utils import scandir_SIDD
from basicsr.utils.create_lmdb import create_lmdb_for_BSDS300
import sys
sys.path.append("/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/")


def main():
    opt = {}
    opt['n_thread'] = 20
    opt['compression_level'] = 3

    opt['input_folder'] = '/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/Data'
    opt['save_folder'] = '/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/input_crops'
    opt['crop_size'] = 512
    opt['crop_size'] = 30
    opt['step'] = 384
    opt['thresh_size'] = 0
    opt['keywords'] = '_NOISY'
    extract_subimages(opt)

    opt['save_folder'] = '/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/gt_crops'
    opt['keywords'] = '_GT'
    extract_subimages(opt)

    create_lmdb_for_BSDS300('/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val')


def extract_subimages(opt):
    """Crop images to subimages.

    Args:
        opt (dict): Configuration dict. It contains:
            input_folder (str): Path to the input folder.
            save_folder (str): Path to save folder.
            n_thread (int): Thread number.
    """
    input_folder = opt['input_folder']
    save_folder = opt['save_folder']
    if not osp.exists(save_folder):
        os.makedirs(save_folder)
        print(f'mkdir {save_folder} ...')
    else:
        print(f'Folder {save_folder} already exists. Exit.')
        # sys.exit(1)

    img_list = list(scandir_SIDD(input_folder, keywords=opt['keywords'], recursive=True, full_path=True))

    pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
    pool = Pool(opt['n_thread'])
    for path in img_list:
        pool.apply_async(
            worker, args=(path, opt), callback=lambda arg: pbar.update(1))
    pool.close()
    pool.join()
    pbar.close()
    print('All processes done.')


def worker(path, opt):
    """Worker for each process.

    Args:
        path (str): Image path.
        opt (dict): Configuration dict. It contains:
            crop_size (int): Crop size.
            step (int): Step for overlapped sliding window.
            thresh_size (int): Threshold size. Patches whose size is lower
                than thresh_size will be dropped.
            save_folder (str): Path to save folder.
            compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.

    Returns:
        process_info (str): Process information displayed in progress bar.
    """
    crop_size = opt['crop_size']
    step = opt['step']
    thresh_size = opt['thresh_size']
    img_name, extension = osp.splitext(osp.basename(path))

    img_name = img_name.replace(opt['keywords'], '')

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    resized_img = cv2.resize(img, (512, 512))

    index = 0
    cv2.imwrite(
        osp.join(opt['save_folder'],
                 f'{img_name}_s{index:03d}{extension}'), resized_img,
        [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    process_info = f'Processing {img_name} ...'
    return process_info

    if img.ndim == 2:
        h, w = img.shape
    elif img.ndim == 3:
        h, w, c = img.shape
    else:
        raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}')

    h_space = np.arange(0, h - crop_size + 1, step)
    if h - (h_space[-1] + crop_size) > thresh_size:
        h_space = np.append(h_space, h - crop_size)
    w_space = np.arange(0, w - crop_size + 1, step)
    if w - (w_space[-1] + crop_size) > thresh_size:
        w_space = np.append(w_space, w - crop_size)

    index = 0
    for x in h_space:
        for y in w_space:
            index += 1
            cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
            cropped_img = np.ascontiguousarray(cropped_img)
            cv2.imwrite(
                osp.join(opt['save_folder'],
                         f'{img_name}_s{index:03d}{extension}'), cropped_img,
                [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    process_info = f'Processing {img_name} ...'
    return process_info


if __name__ == '__main__':
    main()
    # ... make sidd to lmdb
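Two things stand out in the new preparation script: opt['crop_size'] is set to 512 and then immediately overwritten with 30, so only 30 is effective, and worker() returns right after writing the 512x512 resized copy, which makes the sliding-window cropping block below that return unreachable. As committed, the script therefore resizes each BSDS300 image to 512x512 rather than extracting overlapping patches. For reference, the crop grid that the unreachable block would produce can be sketched on its own:

```
import numpy as np


def crop_grid(h, w, crop_size, step, thresh_size):
    # Top-left corners of an overlapped sliding window, mirroring the logic in
    # worker(): a trailing window is appended whenever the leftover border is
    # larger than thresh_size.
    h_space = np.arange(0, h - crop_size + 1, step)
    if h - (h_space[-1] + crop_size) > thresh_size:
        h_space = np.append(h_space, h - crop_size)
    w_space = np.arange(0, w - crop_size + 1, step)
    if w - (w_space[-1] + crop_size) > thresh_size:
        w_space = np.append(w_space, w - crop_size)
    return [(int(x), int(y)) for x in h_space for y in w_space]


# e.g. a 321x481 image with the commit's effective crop_size=30 and step=384
print(crop_grid(321, 481, crop_size=30, step=384, thresh_size=0))
```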
@@ -14,7 +14,8 @@ from tqdm import tqdm
 
 from basicsr.utils import scandir_SIDD
 from basicsr.utils.create_lmdb import create_lmdb_for_SIDD
+import sys
+sys.path.append("/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/")
 
 
 def main():
     opt = {}