add random degradation image dataset, remain psf degradation addition

pull/105/head
huangcaohui 2023-05-07 22:58:45 +08:00
parent a8d01a7e3b
commit ccffddb550
4 changed files with 166 additions and 3 deletions

View File

@ -287,6 +287,33 @@ def paths_from_lmdb(folder):
return paths
def paths_from_meta_info_file(folder, meta_info_file, filename_tmpl): # note: add new
"""Generate paths from folder.
Args:
folder (str): Folder path.
meta_info_file (str): Path to the meta information file.
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Usually the filename_tmpl is
for files in the input folder.
Returns:
list[str]: Returned path list.
"""
with open(meta_info_file, 'r') as fin:
input_names = [line.split(' ')[0] for line in fin]
paths = []
for input in input_names:
basename, ext = osp.splitext(osp.basename(input))
name = f'{filename_tmpl.format(basename)}{ext}'
path = osp.join(folder, name)
paths.append(path)
return paths
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
"""Generate Gaussian kernel used in `duf_downsample`.

View File

@ -89,7 +89,7 @@ class PairedImageDataset(data.Dataset):
img_bytes = self.file_client.get(gt_path, 'gt')
try:
img_gt = imfrombytes(img_bytes, float32=True)
except:
except Exception:
raise Exception("gt path {} not working".format(gt_path))
lq_path = self.paths[index]['lq_path']
@ -97,7 +97,7 @@ class PairedImageDataset(data.Dataset):
img_bytes = self.file_client.get(lq_path, 'lq')
try:
img_lq = imfrombytes(img_bytes, float32=True)
except:
except Exception:
raise Exception("lq path {} not working".format(lq_path))

View File

@ -0,0 +1,136 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from basicsr.data.data_util import (paths_from_folder,
paths_from_meta_info_file,
paths_from_lmdb)
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor, padding
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[3]))
from psf import PSF
class RandomDegradationImageDataset(data.Dataset):
"""Single image dataset for image restoration. Using random degradation for
HQ image to obtain LQ images.
Read HQ (High Quality, e.g. HR (High Resolution), blurry, noisy, etc) only.
There are three modes:
1. 'lmdb': Use lmdb files.
If opt['io_backend'] == lmdb.
2. 'meta_info_file': Use meta information file to generate paths.
If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
3. 'folder': Scan folders to generate paths.
The rest.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq, but not using.
meta_info_file (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Default: '{}'.
gt_size (int): Cropped patched size for gt patches.
use_flip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h
and w for implementation).
scale (bool): Scale, which will be added automatically.
phase (str): 'train' or 'val'.
"""
def __init__(self, opt):
super(RandomDegradationImageDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.gt_folder = opt['dataroot_gt']
if 'filename_tmpl' in opt:
self.filename_tmpl = opt['filename_tmpl']
else:
self.filename_tmpl = '{}'
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.gt_folder]
self.io_backend_opt['client_keys'] = ['gt']
self.paths = paths_from_lmdb(self.gt_folder)
elif 'meta_info_file' in self.opt and self.opt[
'meta_info_file'] is not None:
self.paths = paths_from_meta_info_file(
self.gt_folder, self.opt['meta_info_file'], self.filename_tmpl)
else:
self.paths = paths_from_folder(self.gt_folder)
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(
self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
# Load gt images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]
# print('gt path,', gt_path)
img_bytes = self.file_client.get(gt_path, 'gt')
try:
img_gt = imfrombytes(img_bytes, float32=True)
except Exception:
raise Exception("gt path {} not working".format(gt_path))
# lq_path = self.paths[index]['lq_path']
# # print(', lq path', lq_path)
# img_bytes = self.file_client.get(lq_path, 'lq')
# try:
# img_lq = imfrombytes(img_bytes, float32=True)
# except:
# raise Exception("lq path {} not working".format(lq_path))
img_lq = None
# augmentation for training
if self.opt['phase'] == 'train':
gt_size = self.opt['gt_size']
# padding
img_gt, img_lq = padding(img_gt, img_lq, gt_size)
# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale,
gt_path)
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'],
self.opt['use_rot'])
# TODO: color space transform
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq],
bgr2rgb=True,
float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
normalize(img_gt, self.mean, self.std, inplace=True)
return {
'lq': img_lq,
'gt': img_gt,
'gt_path': gt_path
}
def __len__(self):
return len(self.paths)

View File

@ -386,7 +386,7 @@ class ImageRestorationModel(BaseModel):
# self.dist_validation(*args, **kwargs)
def nondist_validation(self, dataloader, current_iter, tb_logger,
save_img, rgb2bgr, use_image):
save_img, rgb2bgr, use_image): # note: add new here
dataset_name = dataloader.dataset.opt['name']
with_metrics = self.opt['val'].get('metrics') is not None
if with_metrics: