diff --git a/.eggs/README.txt b/.eggs/README.txt new file mode 100644 index 0000000..5d01668 --- /dev/null +++ b/.eggs/README.txt @@ -0,0 +1,6 @@ +This directory contains eggs that were downloaded by setuptools to build, test, and run plug-ins. + +This directory caches those eggs to prevent repeated downloads. + +However, it is safe to delete this directory. + diff --git a/.gitignore b/.gitignore index 032d2dd..a40fb55 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,6 @@ logs/ *__pycache__* *.sh datasets -basicsr.egg-info \ No newline at end of file +basicsr.egg-info +.eggs/* +build/* \ No newline at end of file diff --git a/basicsr/metrics/__init__.py b/basicsr/metrics/__init__.py index 83fc13d..4fb044a 100644 --- a/basicsr/metrics/__init__.py +++ b/basicsr/metrics/__init__.py @@ -1,10 +1,20 @@ -# ------------------------------------------------------------------------ -# Copyright (c) 2022 megvii-model. All Rights Reserved. -# ------------------------------------------------------------------------ -# Modified from BasicSR (https://github.com/xinntao/BasicSR) -# Copyright 2018-2020 BasicSR Authors -# ------------------------------------------------------------------------ -from .niqe import calculate_niqe -from .psnr_ssim import calculate_psnr, calculate_ssim, calculate_ssim_left, calculate_psnr_left, calculate_skimage_ssim, calculate_skimage_ssim_left +from copy import deepcopy -__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe', 'calculate_ssim_left', 'calculate_psnr_left', 'calculate_skimage_ssim', 'calculate_skimage_ssim_left'] +from basicsr.utils.registry import METRIC_REGISTRY +from .niqe import calculate_niqe +from .psnr_ssim import calculate_psnr, calculate_ssim + +__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] + + +def calculate_metric(data, opt): + """Calculate metric from data and options. + + Args: + opt (dict): Configuration. It must contain: + type (str): Model type. + """ + opt = deepcopy(opt) + metric_type = opt.pop('type') + metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) + return metric diff --git a/basicsr/metrics/fid.py b/basicsr/metrics/fid.py index 102ba91..1b0ba6d 100644 --- a/basicsr/metrics/fid.py +++ b/basicsr/metrics/fid.py @@ -1,35 +1,22 @@ -# ------------------------------------------------------------------------ -# Copyright (c) 2022 megvii-model. All Rights Reserved. -# ------------------------------------------------------------------------ -# Modified from BasicSR (https://github.com/xinntao/BasicSR) -# Copyright 2018-2020 BasicSR Authors -# ------------------------------------------------------------------------ import numpy as np import torch import torch.nn as nn from scipy import linalg from tqdm import tqdm -from basicsr.models.archs.inception import InceptionV3 +from basicsr.archs.inception import InceptionV3 -def load_patched_inception_v3(device='cuda', - resize_input=True, - normalize_input=False): +def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False): # we may not resize the input, but in [rosinality/stylegan2-pytorch] it # does resize the input. - inception = InceptionV3([3], - resize_input=resize_input, - normalize_input=normalize_input) + inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input) inception = nn.DataParallel(inception).eval().to(device) return inception @torch.no_grad() -def extract_inception_features(data_generator, - inception, - len_generator=None, - device='cuda'): +def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'): """Extract inception features. Args: @@ -63,33 +50,27 @@ def extract_inception_features(data_generator, def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): """Numpy implementation of the Frechet Distance. - The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) - and X_2 ~ N(mu_2, C_2) is - d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is: + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). Stable version by Dougal J. Sutherland. Args: mu1 (np.array): The sample mean over activations. - sigma1 (np.array): The covariance matrix over activations for - generated samples. - mu2 (np.array): The sample mean over activations, precalculated on an - representative data set. - sigma2 (np.array): The covariance matrix over activations, - precalculated on an representative data set. + sigma1 (np.array): The covariance matrix over activations for generated samples. + mu2 (np.array): The sample mean over activations, precalculated on an representative data set. + sigma2 (np.array): The covariance matrix over activations, precalculated on an representative data set. Returns: float: The Frechet Distance. """ assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' - assert sigma1.shape == sigma2.shape, ( - 'Two covariances have different dimensions') + assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions') cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) # Product might be almost singular if not np.isfinite(cov_sqrt).all(): - print('Product of cov matrices is singular. Adding {eps} to diagonal ' - 'of cov estimates') + print('Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates') offset = np.eye(sigma1.shape[0]) * eps cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) diff --git a/basicsr/metrics/metric_util.py b/basicsr/metrics/metric_util.py index 9764769..2a27c70 100644 --- a/basicsr/metrics/metric_util.py +++ b/basicsr/metrics/metric_util.py @@ -1,12 +1,6 @@ -# ------------------------------------------------------------------------ -# Copyright (c) 2022 megvii-model. All Rights Reserved. -# ------------------------------------------------------------------------ -# Modified from BasicSR (https://github.com/xinntao/BasicSR) -# Copyright 2018-2020 BasicSR Authors -# ------------------------------------------------------------------------ import numpy as np -from basicsr.utils.matlab_functions import bgr2ycbcr +from basicsr.utils import bgr2ycbcr def reorder_image(img, input_order='HWC'): @@ -27,9 +21,7 @@ def reorder_image(img, input_order='HWC'): """ if input_order not in ['HWC', 'CHW']: - raise ValueError( - f'Wrong input_order {input_order}. Supported input_orders are ' - "'HWC' and 'CHW'") + raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") if len(img.shape) == 2: img = img[..., None] if input_order == 'CHW': diff --git a/basicsr/metrics/niqe.py b/basicsr/metrics/niqe.py index 6bf2019..e3c1467 100644 --- a/basicsr/metrics/niqe.py +++ b/basicsr/metrics/niqe.py @@ -1,20 +1,17 @@ -# ------------------------------------------------------------------------ -# Copyright (c) 2022 megvii-model. All Rights Reserved. -# ------------------------------------------------------------------------ -# Modified from BasicSR (https://github.com/xinntao/BasicSR) -# Copyright 2018-2020 BasicSR Authors -# ------------------------------------------------------------------------ import cv2 import math import numpy as np -from scipy.ndimage.filters import convolve +import os +from scipy.ndimage import convolve from scipy.special import gamma from basicsr.metrics.metric_util import reorder_image, to_y_channel +from basicsr.utils.matlab_functions import imresize +from basicsr.utils.registry import METRIC_REGISTRY def estimate_aggd_param(block): - """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) paramters. + """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters. Args: block (ndarray): 2D Image block. @@ -26,15 +23,13 @@ def estimate_aggd_param(block): block = block.flatten() gam = np.arange(0.2, 10.001, 0.001) # len = 9801 gam_reciprocal = np.reciprocal(gam) - r_gam = np.square(gamma(gam_reciprocal * 2)) / ( - gamma(gam_reciprocal) * gamma(gam_reciprocal * 3)) + r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3)) left_std = np.sqrt(np.mean(block[block < 0]**2)) right_std = np.sqrt(np.mean(block[block > 0]**2)) gammahat = left_std / right_std rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2) - rhatnorm = (rhat * (gammahat**3 + 1) * - (gammahat + 1)) / ((gammahat**2 + 1)**2) + rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2) array_position = np.argmin((r_gam - rhatnorm)**2) alpha = gam[array_position] @@ -70,22 +65,18 @@ def compute_feature(block): return feat -def niqe(img, - mu_pris_param, - cov_pris_param, - gaussian_window, - block_size_h=96, - block_size_w=96): +def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96): """Calculate NIQE (Natural Image Quality Evaluator) metric. - Ref: Making a "Completely Blind" Image Quality Analyzer. + ``Paper: Making a "Completely Blind" Image Quality Analyzer`` + This implementation could produce almost the same results as the official MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip Note that we do not include block overlap height and width, since they are always 0 in the official implementation. - For good performance, it is advisable by the official implemtation to + For good performance, it is advisable by the official implementation to divide the distorted image in to the same size patched as used for the construction of multivariate Gaussian model. @@ -104,8 +95,7 @@ def niqe(img, block_size_w (int): Width of the blocks in to which image is divided. Default: 96 (the official recommended value). """ - assert img.ndim == 2, ( - 'Input image must be a gray or Y (of YCbCr) image with shape (h, w).') + assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).') # crop image h, w = img.shape num_block_h = math.floor(h / block_size_h) @@ -115,10 +105,7 @@ def niqe(img, distparam = [] # dist param is actually the multiscale features for scale in (1, 2): # perform on two scales (1, 2) mu = convolve(img, gaussian_window, mode='nearest') - sigma = np.sqrt( - np.abs( - convolve(np.square(img), gaussian_window, mode='nearest') - - np.square(mu))) + sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu))) # normalize, as in Eq. 1 in the paper img_nomalized = (img - mu) / (sigma + 1) @@ -126,21 +113,14 @@ def niqe(img, for idx_w in range(num_block_w): for idx_h in range(num_block_h): # process ecah block - block = img_nomalized[idx_h * block_size_h // - scale:(idx_h + 1) * block_size_h // - scale, idx_w * block_size_w // - scale:(idx_w + 1) * block_size_w // - scale] + block = img_nomalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale, + idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale] feat.append(compute_feature(block)) distparam.append(np.array(feat)) - # TODO: matlab bicubic downsample with anti-aliasing - # for simplicity, now we use opencv instead, which will result in - # a slight difference. + if scale == 1: - h, w = img.shape - img = cv2.resize( - img / 255., (w // 2, h // 2), interpolation=cv2.INTER_LINEAR) + img = imresize(img / 255., scale=0.5, antialiasing=True) img = img * 255. distparam = np.concatenate(distparam, axis=1) @@ -154,20 +134,25 @@ def niqe(img, # compute niqe quality, Eq. 10 in the paper invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2) quality = np.matmul( - np.matmul((mu_pris_param - mu_distparam), invcov_param), - np.transpose((mu_pris_param - mu_distparam))) - quality = np.sqrt(quality) + np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam))) + quality = np.sqrt(quality) + quality = float(np.squeeze(quality)) return quality -def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'): +@METRIC_REGISTRY.register() +def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', **kwargs): """Calculate NIQE (Natural Image Quality Evaluator) metric. - Ref: Making a "Completely Blind" Image Quality Analyzer. + ``Paper: Making a "Completely Blind" Image Quality Analyzer`` + This implementation could produce almost the same results as the official MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip + > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296) + > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296) + We use the official params estimated from the pristine dataset. We use the recommended block size (96, 96) without overlaps. @@ -181,15 +166,15 @@ def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'): pixels are not involved in the metric calculation. input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'. Default: 'HWC'. - convert_to (str): Whether coverted to 'y' (of MATLAB YCbCr) or 'gray'. + convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'. Default: 'y'. Returns: float: NIQE result. """ - + ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) # we use the official params estimated from the pristine dataset. - niqe_pris_params = np.load('basicsr/metrics/niqe_pris_params.npz') + niqe_pris_params = np.load(os.path.join(ROOT_DIR, 'niqe_pris_params.npz')) mu_pris_param = niqe_pris_params['mu_pris_param'] cov_pris_param = niqe_pris_params['cov_pris_param'] gaussian_window = niqe_pris_params['gaussian_window'] @@ -206,6 +191,9 @@ def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'): if crop_border != 0: img = img[crop_border:-crop_border, crop_border:-crop_border] + # round is necessary for being consistent with MATLAB's result + img = img.round() + niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window) return niqe_result diff --git a/basicsr/metrics/psnr_ssim.py b/basicsr/metrics/psnr_ssim.py index 317a7b5..b1a95ae 100644 --- a/basicsr/metrics/psnr_ssim.py +++ b/basicsr/metrics/psnr_ssim.py @@ -1,263 +1,95 @@ -# ------------------------------------------------------------------------ -# Copyright (c) 2022 megvii-model. All Rights Reserved. -# ------------------------------------------------------------------------ -# modified from https://github.com/mayorx/matlab_ssim_pytorch_implementation/blob/main/calc_ssim.py -# ------------------------------------------------------------------------ -# Modified from BasicSR (https://github.com/xinntao/BasicSR) -# Copyright 2018-2020 BasicSR Authors -# ------------------------------------------------------------------------ import cv2 import numpy as np +import torch +import torch.nn.functional as F from basicsr.metrics.metric_util import reorder_image, to_y_channel -from skimage.metrics import structural_similarity -import torch +from basicsr.utils.color_util import rgb2ycbcr_pt +from basicsr.utils.registry import METRIC_REGISTRY -def calculate_psnr(img1, - img2, - crop_border, - input_order='HWC', - test_y_channel=False): + +@METRIC_REGISTRY.register() +def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): """Calculate PSNR (Peak Signal-to-Noise Ratio). - Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio Args: - img1 (ndarray/tensor): Images with range [0, 255]/[0, 1]. - img2 (ndarray/tensor): Images with range [0, 255]/[0, 1]. - crop_border (int): Cropped pixels in each edge of an image. These - pixels are not involved in the PSNR calculation. - input_order (str): Whether the input order is 'HWC' or 'CHW'. - Default: 'HWC'. + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. test_y_channel (bool): Test on Y channel of YCbCr. Default: False. Returns: - float: psnr result. + float: PSNR result. """ - assert img1.shape == img2.shape, ( - f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') if input_order not in ['HWC', 'CHW']: - raise ValueError( - f'Wrong input_order {input_order}. Supported input_orders are ' - '"HWC" and "CHW"') - if type(img1) == torch.Tensor: - if len(img1.shape) == 4: - img1 = img1.squeeze(0) - img1 = img1.detach().cpu().numpy().transpose(1,2,0) - if type(img2) == torch.Tensor: - if len(img2.shape) == 4: - img2 = img2.squeeze(0) - img2 = img2.detach().cpu().numpy().transpose(1,2,0) - - img1 = reorder_image(img1, input_order=input_order) + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) img2 = reorder_image(img2, input_order=input_order) - img1 = img1.astype(np.float64) - img2 = img2.astype(np.float64) if crop_border != 0: - img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] - - def _psnr(img1, img2): - if test_y_channel: - img1 = to_y_channel(img1) - img2 = to_y_channel(img2) - mse = np.mean((img1 - img2)**2) - if mse == 0: - return float('inf') - max_value = 1. if img1.max() <= 1 else 255. - return 20. * np.log10(max_value / np.sqrt(mse)) - - if img1.ndim == 3 and img1.shape[2] == 6: - l1, r1 = img1[:,:,:3], img1[:,:,3:] - l2, r2 = img2[:,:,:3], img2[:,:,3:] - return (_psnr(l1, l2) + _psnr(r1, r2))/2 - else: - return _psnr(img1, img2) + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) -def calculate_psnr_left(img1, - img2, - crop_border, - input_order='HWC', - test_y_channel=False): - assert input_order == 'HWC' - assert crop_border == 0 - - img1 = img1[:,64:,:3] - img2 = img2[:,64:,:3] - return calculate_psnr(img1=img1, img2=img2, crop_border=0, input_order=input_order, test_y_channel=test_y_channel) - -def _ssim(img1, img2, max_value): - """Calculate SSIM (structural similarity) for one channel images. - - It is called by func:`calculate_ssim`. - - Args: - img1 (ndarray): Images with range [0, 255] with order 'HWC'. - img2 (ndarray): Images with range [0, 255] with order 'HWC'. - - Returns: - float: ssim result. - """ - - C1 = (0.01 * max_value)**2 - C2 = (0.03 * max_value)**2 - - img1 = img1.astype(np.float64) - img2 = img2.astype(np.float64) - kernel = cv2.getGaussianKernel(11, 1.5) - window = np.outer(kernel, kernel.transpose()) - - mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] - mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] - mu1_sq = mu1**2 - mu2_sq = mu2**2 - mu1_mu2 = mu1 * mu2 - sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq - sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq - sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 - - ssim_map = ((2 * mu1_mu2 + C1) * - (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * - (sigma1_sq + sigma2_sq + C2)) - return ssim_map.mean() - -def prepare_for_ssim(img, k): - import torch - with torch.no_grad(): - img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float() - conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k//2, padding_mode='reflect') - conv.weight.requires_grad = False - conv.weight[:, :, :, :] = 1. / (k * k) - - img = conv(img) - - img = img.squeeze(0).squeeze(0) - img = img[0::k, 0::k] - return img.detach().cpu().numpy() - -def prepare_for_ssim_rgb(img, k): - import torch - with torch.no_grad(): - img = torch.from_numpy(img).float() #HxWx3 - - conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect') - conv.weight.requires_grad = False - conv.weight[:, :, :, :] = 1. / (k * k) - - new_img = [] - - for i in range(3): - new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k]) - - return torch.stack(new_img, dim=2).detach().cpu().numpy() - -def _3d_gaussian_calculator(img, conv3d): - out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) - return out - -def _generate_3d_gaussian_kernel(): - kernel = cv2.getGaussianKernel(11, 1.5) - window = np.outer(kernel, kernel.transpose()) - kernel_3 = cv2.getGaussianKernel(11, 1.5) - kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0)) - conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate') - conv3d.weight.requires_grad = False - conv3d.weight[0, 0, :, :, :] = kernel - return conv3d - -def _ssim_3d(img1, img2, max_value): - assert len(img1.shape) == 3 and len(img2.shape) == 3 - """Calculate SSIM (structural similarity) for one channel images. - - It is called by func:`calculate_ssim`. - - Args: - img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. - img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. - - Returns: - float: ssim result. - """ - C1 = (0.01 * max_value) ** 2 - C2 = (0.03 * max_value) ** 2 - img1 = img1.astype(np.float64) + if isinstance(img, torch.Tensor): + img = img.numpy() + if isinstance(img2, torch.Tensor): + img2 = img2.numpy() + img = img.astype(np.float64) img2 = img2.astype(np.float64) - kernel = _generate_3d_gaussian_kernel().cuda() - - img1 = torch.tensor(img1).float().cuda() - img2 = torch.tensor(img2).float().cuda() + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 10. * np.log10(255. * 255. / mse) - mu1 = _3d_gaussian_calculator(img1, kernel) - mu2 = _3d_gaussian_calculator(img2, kernel) +@METRIC_REGISTRY.register() +def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version). - mu1_sq = mu1 ** 2 - mu2_sq = mu2 ** 2 - mu1_mu2 = mu1 * mu2 - sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq - sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq - sigma12 = _3d_gaussian_calculator(img1*img2, kernel) - mu1_mu2 - - ssim_map = ((2 * mu1_mu2 + C1) * - (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * - (sigma1_sq + sigma2_sq + C2)) - return float(ssim_map.mean()) - -def _ssim_cly(img1, img2): - assert len(img1.shape) == 2 and len(img2.shape) == 2 - """Calculate SSIM (structural similarity) for one channel images. - - It is called by func:`calculate_ssim`. + Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio Args: - img1 (ndarray): Images with range [0, 255] with order 'HWC'. - img2 (ndarray): Images with range [0, 255] with order 'HWC'. + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. Returns: - float: ssim result. + float: PSNR result. """ - C1 = (0.01 * 255)**2 - C2 = (0.03 * 255)**2 - img1 = img1.astype(np.float64) - img2 = img2.astype(np.float64) + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') - kernel = cv2.getGaussianKernel(11, 1.5) - # print(kernel) - window = np.outer(kernel, kernel.transpose()) + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] - bt = cv2.BORDER_REPLICATE + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) - mu1 = cv2.filter2D(img1, -1, window, borderType=bt) - mu2 = cv2.filter2D(img2, -1, window,borderType=bt) + img = img.to(torch.float64) + img2 = img2.to(torch.float64) - mu1_sq = mu1**2 - mu2_sq = mu2**2 - mu1_mu2 = mu1 * mu2 - sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq - sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq - sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2 - - ssim_map = ((2 * mu1_mu2 + C1) * - (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * - (sigma1_sq + sigma2_sq + C2)) - return ssim_map.mean() + mse = torch.mean((img - img2)**2, dim=[1, 2, 3]) + return 10. * torch.log10(1. / (mse + 1e-8)) -def calculate_ssim(img1, - img2, - crop_border, - input_order='HWC', - test_y_channel=False, - ssim3d=True): +@METRIC_REGISTRY.register() +def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs): """Calculate SSIM (structural similarity). - Ref: - Image quality assessment: From error visibility to structural similarity + ``Paper: Image quality assessment: From error visibility to structural similarity`` The results are the same as that of the official released MATLAB code in https://ece.uwaterloo.ca/~z70wang/research/ssim/. @@ -266,93 +98,142 @@ def calculate_ssim(img1, averaged. Args: - img1 (ndarray): Images with range [0, 255]. + img (ndarray): Images with range [0, 255]. img2 (ndarray): Images with range [0, 255]. - crop_border (int): Cropped pixels in each edge of an image. These - pixels are not involved in the SSIM calculation. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. test_y_channel (bool): Test on Y channel of YCbCr. Default: False. Returns: - float: ssim result. + float: SSIM result. """ - assert img1.shape == img2.shape, ( - f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') if input_order not in ['HWC', 'CHW']: - raise ValueError( - f'Wrong input_order {input_order}. Supported input_orders are ' - '"HWC" and "CHW"') - - if type(img1) == torch.Tensor: - if len(img1.shape) == 4: - img1 = img1.squeeze(0) - img1 = img1.detach().cpu().numpy().transpose(1,2,0) - if type(img2) == torch.Tensor: - if len(img2.shape) == 4: - img2 = img2.squeeze(0) - img2 = img2.detach().cpu().numpy().transpose(1,2,0) - - img1 = reorder_image(img1, input_order=input_order) + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"') + img = reorder_image(img, input_order=input_order) img2 = reorder_image(img2, input_order=input_order) - img1 = img1.astype(np.float64) - img2 = img2.astype(np.float64) - if crop_border != 0: - img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] - def _cal_ssim(img1, img2): - if test_y_channel: - img1 = to_y_channel(img1) - img2 = to_y_channel(img2) - return _ssim_cly(img1[..., 0], img2[..., 0]) + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) - ssims = [] - # ssims_before = [] + if isinstance(img, torch.Tensor): + img = img.numpy() + if isinstance(img2, torch.Tensor): + img2 = img2.numpy() + img = img.astype(np.float64) + img2 = img2.astype(np.float64) - # skimage_before = skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True) - # print('.._skimage', - # skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True)) - max_value = 1 if img1.max() <= 1 else 255 - with torch.no_grad(): - final_ssim = _ssim_3d(img1, img2, max_value) if ssim3d else _ssim(img1, img2, max_value) - ssims.append(final_ssim) + ssims = [] + for i in range(img.shape[2]): + ssims.append(_ssim(img[..., i], img2[..., i])) + return np.array(ssims).mean() - # for i in range(img1.shape[2]): - # ssims_before.append(_ssim(img1, img2)) - # print('..ssim mean , new {:.4f} and before {:.4f} .... skimage before {:.4f}'.format(np.array(ssims).mean(), np.array(ssims_before).mean(), skimage_before)) - # ssims.append(skimage.metrics.structural_similarity(img1[..., i], img2[..., i], multichannel=False)) +@METRIC_REGISTRY.register() +def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs): + """Calculate SSIM (structural similarity) (PyTorch version). - return np.array(ssims).mean() + ``Paper: Image quality assessment: From error visibility to structural similarity`` - if img1.ndim == 3 and img1.shape[2] == 6: - l1, r1 = img1[:,:,:3], img1[:,:,3:] - l2, r2 = img2[:,:,:3], img2[:,:,3:] - return (_cal_ssim(l1, l2) + _cal_ssim(r1, r2))/2 - else: - return _cal_ssim(img1, img2) + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. -def calculate_ssim_left(img1, - img2, - crop_border, - input_order='HWC', - test_y_channel=False, - ssim3d=True): - assert input_order == 'HWC' - assert crop_border == 0 + For three-channel images, SSIM is calculated for each channel and then + averaged. - img1 = img1[:,64:,:3] - img2 = img2[:,64:,:3] - return calculate_ssim(img1=img1, img2=img2, crop_border=0, input_order=input_order, test_y_channel=test_y_channel, ssim3d=ssim3d) + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. -def calculate_skimage_ssim(img1, img2): - return structural_similarity(img1, img2, multichannel=True) + Returns: + float: SSIM result. + """ -def calculate_skimage_ssim_left(img1, img2): - img1 = img1[:,64:,:3] - img2 = img2[:,64:,:3] - return calculate_skimage_ssim(img1=img1, img2=img2) + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + ssim = _ssim_pth(img * 255., img2 * 255.) + return ssim + + +def _ssim(img, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: SSIM result. + """ + + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] # valid mode for window size 11 + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) + return ssim_map.mean() + + +def _ssim_pth(img, img2): + """Calculate SSIM (structural similarity) (PyTorch version). + + It is called by func:`calculate_ssim_pt`. + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + + Returns: + float: SSIM result. + """ + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device) + + mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1]) # valid mode + mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1]) # valid mode + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq + sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2 + + cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2) + ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map + return ssim_map.mean([1, 2, 3]) diff --git a/basicsr/models/ops/dcn/__init__.py b/basicsr/models/ops/dcn/__init__.py new file mode 100644 index 0000000..32e3592 --- /dev/null +++ b/basicsr/models/ops/dcn/__init__.py @@ -0,0 +1,7 @@ +from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, + modulated_deform_conv) + +__all__ = [ + 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', + 'modulated_deform_conv' +] diff --git a/basicsr/models/ops/dcn/deform_conv.py b/basicsr/models/ops/dcn/deform_conv.py new file mode 100644 index 0000000..6268ca8 --- /dev/null +++ b/basicsr/models/ops/dcn/deform_conv.py @@ -0,0 +1,379 @@ +import math +import os +import torch +from torch import nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn import functional as F +from torch.nn.modules.utils import _pair, _single + +BASICSR_JIT = os.getenv('BASICSR_JIT') +if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + deform_conv_ext = load( + 'deform_conv', + sources=[ + os.path.join(module_path, 'src', 'deform_conv_ext.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'), + ], + ) +else: + try: + from . import deform_conv_ext + except ImportError: + pass + # avoid annoying print output + # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' + # '1. compile with BASICSR_EXT=True. or\n ' + # '2. set BASICSR_JIT=True during running') + + +class DeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=64): + if input is not None and input.dim() != 4: + raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.') + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + deform_conv_ext.deform_conv_forward(input, weight, + offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input, + grad_offset, weight, ctx.bufs_[0], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight, + ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], + ctx.padding[1], ctx.padding[0], ctx.dilation[1], + ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1, + cur_im2col_step) + + return (grad_input, grad_offset, grad_weight, None, None, None, None, None) + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})') + return output_size + + +class ModulatedDeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad: + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output, + ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, + grad_output, weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1 + width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply + + +class DeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False): + super(DeformConv, self).__init__() + + assert not bias + assert in_channels % groups == 0, f'in_channels {in_channels} is not divisible by groups {groups}' + assert out_channels % groups == 0, f'out_channels {out_channels} is not divisible by groups {groups}' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + def forward(self, x, offset): + # To fix an assert error in deform_conv_cuda.cpp:128 + # input image is smaller than kernel + input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1]) + if input_pad: + pad_h = max(self.kernel_size[0] - x.size(2), 0) + pad_w = max(self.kernel_size[1] - x.size(3), 0) + x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + if input_pad: + out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous() + return out + + +class DeformConvPack(DeformConv): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(DeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + + +class ModulatedDeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) + + +class ModulatedDeformConvPack(ModulatedDeformConv): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_weights() + + def init_weights(self): + super(ModulatedDeformConvPack, self).init_weights() + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + out = self.conv_offset(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) diff --git a/basicsr/models/ops/dcn/src/deform_conv_cuda.cpp b/basicsr/models/ops/dcn/src/deform_conv_cuda.cpp new file mode 100644 index 0000000..b465c49 --- /dev/null +++ b/basicsr/models/ops/dcn/src/deform_conv_cuda.cpp @@ -0,0 +1,685 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, + at::Tensor weight, int kH, int kW, int dH, int dW, int padH, + int padW, int dilationH, int dilationW, int group, + int deformable_group) { + TORCH_CHECK(weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", kH, + kW); + + TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, + kW, weight.size(2), weight.size(3)); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + TORCH_CHECK((gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = + at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, + outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} diff --git a/basicsr/models/ops/dcn/src/deform_conv_cuda_kernel.cu b/basicsr/models/ops/dcn/src/deform_conv_cuda_kernel.cu new file mode 100644 index 0000000..98752dc --- /dev/null +++ b/basicsr/models/ops/dcn/src/deform_conv_cuda_kernel.cu @@ -0,0 +1,867 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +#include +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; + +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +template +__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const scalar_t map_h = i * dilation_h + offset_h; + //const scalar_t map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +void deformable_im2col( + const at::Tensor data_im, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, deformable_group, + height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *data_col, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +void deformable_col2im( + const at::Tensor data_col, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im) +{ + + // todo: make sure parallel_imgs is passed in correctly + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, + const scalar_t *data_im, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, scalar_t *grad_offset) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_); + })); +} + +template +__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset_, grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} diff --git a/basicsr/models/ops/dcn/src/deform_conv_ext.cpp b/basicsr/models/ops/dcn/src/deform_conv_ext.cpp new file mode 100644 index 0000000..41c6df6 --- /dev/null +++ b/basicsr/models/ops/dcn/src/deform_conv_ext.cpp @@ -0,0 +1,164 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +#define WITH_CUDA // always use cuda +#ifdef WITH_CUDA +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step); + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias); +#endif + +int deform_conv_forward(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_forward_cuda(input, weight, offset, output, columns, + ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, + deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_input(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_input_cuda(input, offset, gradOutput, + gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, + dilationW, dilationH, group, deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_parameters( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_parameters_cuda(input, offset, gradOutput, + gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, + dilationH, group, deformable_group, scale, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_forward(input, weight, bias, ones, + offset, mask, output, columns, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + deformable_group, with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_backward(input, weight, bias, ones, + offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, + grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, + with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_conv_forward", &deform_conv_forward, + "deform forward"); + m.def("deform_conv_backward_input", &deform_conv_backward_input, + "deform_conv_backward_input"); + m.def("deform_conv_backward_parameters", + &deform_conv_backward_parameters, + "deform_conv_backward_parameters"); + m.def("modulated_deform_conv_forward", + &modulated_deform_conv_forward, + "modulated deform conv forward"); + m.def("modulated_deform_conv_backward", + &modulated_deform_conv_backward, + "modulated deform conv backward"); +} diff --git a/basicsr/models/ops/fused_act/__init__.py b/basicsr/models/ops/fused_act/__init__.py new file mode 100644 index 0000000..241dc07 --- /dev/null +++ b/basicsr/models/ops/fused_act/__init__.py @@ -0,0 +1,3 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu + +__all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] diff --git a/basicsr/models/ops/fused_act/fused_act.py b/basicsr/models/ops/fused_act/fused_act.py new file mode 100644 index 0000000..88edc44 --- /dev/null +++ b/basicsr/models/ops/fused_act/fused_act.py @@ -0,0 +1,95 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 + +import os +import torch +from torch import nn +from torch.autograd import Function + +BASICSR_JIT = os.getenv('BASICSR_JIT') +if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + fused_act_ext = load( + 'fused', + sources=[ + os.path.join(module_path, 'src', 'fused_bias_act.cpp'), + os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), + ], + ) +else: + try: + from . import fused_act_ext + except ImportError: + pass + # avoid annoying print output + # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' + # '1. compile with BASICSR_EXT=True. or\n ' + # '2. set BASICSR_JIT=True during running') + + +class FusedLeakyReLUFunctionBackward(Function): + + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, + ctx.scale) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + + def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/basicsr/models/ops/fused_act/src/fused_bias_act.cpp b/basicsr/models/ops/fused_act/src/fused_bias_act.cpp new file mode 100644 index 0000000..85ed0a7 --- /dev/null +++ b/basicsr/models/ops/fused_act/src/fused_bias_act.cpp @@ -0,0 +1,26 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} diff --git a/basicsr/models/ops/fused_act/src/fused_bias_act_kernel.cu b/basicsr/models/ops/fused_act/src/fused_bias_act_kernel.cu new file mode 100644 index 0000000..54c7ff5 --- /dev/null +++ b/basicsr/models/ops/fused_act/src/fused_bias_act_kernel.cu @@ -0,0 +1,100 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} diff --git a/basicsr/models/ops/upfirdn2d/__init__.py b/basicsr/models/ops/upfirdn2d/__init__.py new file mode 100644 index 0000000..397e85b --- /dev/null +++ b/basicsr/models/ops/upfirdn2d/__init__.py @@ -0,0 +1,3 @@ +from .upfirdn2d import upfirdn2d + +__all__ = ['upfirdn2d'] diff --git a/basicsr/models/ops/upfirdn2d/src/upfirdn2d.cpp b/basicsr/models/ops/upfirdn2d/src/upfirdn2d.cpp new file mode 100644 index 0000000..43d0b67 --- /dev/null +++ b/basicsr/models/ops/upfirdn2d/src/upfirdn2d.cpp @@ -0,0 +1,24 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} diff --git a/basicsr/models/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/basicsr/models/ops/upfirdn2d/src/upfirdn2d_kernel.cu new file mode 100644 index 0000000..8870063 --- /dev/null +++ b/basicsr/models/ops/upfirdn2d/src/upfirdn2d_kernel.cu @@ -0,0 +1,370 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w > 0) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } else { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 4; + block_size = dim3(4, 32, 1); + grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, + (p.out_w - 1) / (p.loop_x * block_size.y) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 2: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 3: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 4: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 5: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 6: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + default: + upfirdn2d_kernel_large<<>>( + out.data_ptr(), x.data_ptr(), + k.data_ptr(), p); + } + }); + + return out; +} diff --git a/basicsr/models/ops/upfirdn2d/upfirdn2d.py b/basicsr/models/ops/upfirdn2d/upfirdn2d.py new file mode 100644 index 0000000..d6122d5 --- /dev/null +++ b/basicsr/models/ops/upfirdn2d/upfirdn2d.py @@ -0,0 +1,192 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 + +import os +import torch +from torch.autograd import Function +from torch.nn import functional as F + +BASICSR_JIT = os.getenv('BASICSR_JIT') +if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + upfirdn2d_ext = load( + 'upfirdn2d', + sources=[ + os.path.join(module_path, 'src', 'upfirdn2d.cpp'), + os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), + ], + ) +else: + try: + from . import upfirdn2d_ext + except ImportError: + pass + # avoid annoying print output + # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' + # '1. compile with BASICSR_EXT=True. or\n ' + # '2. set BASICSR_JIT=True during running') + + +class UpFirDn2dBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_ext.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_ext.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], + # ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + _, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == 'cpu': + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + else: + out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/basicsr/test.py b/basicsr/test.py index bf04219..5666756 100644 --- a/basicsr/test.py +++ b/basicsr/test.py @@ -14,7 +14,8 @@ from basicsr.train import parse_options from basicsr.utils import (get_env_info, get_root_logger, get_time_str, make_exp_dirs) from basicsr.utils.options import dict2str - +import gc +torch.cuda.empty_cache() def main(): # parse options, set distributed setting, set ramdom seed @@ -68,3 +69,4 @@ def main(): if __name__ == '__main__': main() + gc.collect() \ No newline at end of file diff --git a/basicsr/train.py b/basicsr/train.py index 9cc8f2a..1f257ef 100644 --- a/basicsr/train.py +++ b/basicsr/train.py @@ -23,6 +23,9 @@ from basicsr.utils import (MessageLogger, check_resume, get_env_info, set_random_seed) from basicsr.utils.dist_util import get_dist_info, init_dist from basicsr.utils.options import dict2str, parse +import gc + +torch.cuda.empty_cache() def parse_options(is_train=True): @@ -303,3 +306,4 @@ if __name__ == '__main__': import os os.environ['GRPC_POLL_STRATEGY']='epoll1' main() + gc.collect() diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py index c2a5e1e..483b27e 100644 --- a/basicsr/utils/__init__.py +++ b/basicsr/utils/__init__.py @@ -1,18 +1,20 @@ -# ------------------------------------------------------------------------ -# Copyright (c) 2022 megvii-model. All Rights Reserved. -# ------------------------------------------------------------------------ -# Modified from BasicSR (https://github.com/xinntao/BasicSR) -# Copyright 2018-2020 BasicSR Authors -# ------------------------------------------------------------------------ +import sys +sys.path.append("/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet") + +from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb from .file_client import FileClient from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding -from .logger import (MessageLogger, get_env_info, get_root_logger, - init_tb_logger, init_wandb_logger) -from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, - scandir, scandir_SIDD, set_random_seed, sizeof_fmt) -from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k) +from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger +from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt, scandir_SIDD +from .options import yaml_load __all__ = [ + # color_util.py + 'bgr2ycbcr', + 'rgb2ycbcr', + 'rgb2ycbcr_pt', + 'ycbcr2bgr', + 'ycbcr2rgb', # file_client.py 'FileClient', # img_util.py @@ -23,6 +25,7 @@ __all__ = [ 'crop_border', # logger.py 'MessageLogger', + 'AvgTimer', 'init_tb_logger', 'init_wandb_logger', 'get_root_logger', @@ -36,8 +39,9 @@ __all__ = [ 'scandir_SIDD', 'check_resume', 'sizeof_fmt', - 'padding', - 'create_lmdb_for_reds', - 'create_lmdb_for_gopro', - 'create_lmdb_for_rain13k', + # img_process_util + # options + 'yaml_load', + 'padding' ] + diff --git a/basicsr/utils/color_util.py b/basicsr/utils/color_util.py new file mode 100644 index 0000000..8b7676f --- /dev/null +++ b/basicsr/utils/color_util.py @@ -0,0 +1,208 @@ +import numpy as np +import torch + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + conversion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace conversion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) + + +def rgb2ycbcr_pt(img, y_only=False): + """Convert RGB images to YCbCr images (PyTorch version). + + It implements the ITU-R BT.601 conversion for standard-definition television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + Args: + img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float. + """ + if y_only: + weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0 + else: + weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img) + bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias + + out_img = out_img / 255. + return out_img diff --git a/basicsr/utils/create_lmdb.py b/basicsr/utils/create_lmdb.py index 6194203..1dba6da 100644 --- a/basicsr/utils/create_lmdb.py +++ b/basicsr/utils/create_lmdb.py @@ -21,6 +21,7 @@ def prepare_keys(folder_path, suffix='png'): list[str]: Key list. """ print('Reading image path list ...') + # import ipdb; ipdb.set_trace() img_path_list = sorted( list(scandir(folder_path, suffix=suffix, recursive=False))) keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)] @@ -131,3 +132,17 @@ def create_lmdb_for_SIDD(): img_path_list, keys = prepare_keys(folder_path, 'png') make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) ''' + +def create_lmdb_for_BSDS300(root_path=''): + if root_path != '': + folder_path = f'{root_path}/input_crops' + lmdb_path = f'{root_path}/input_crops.lmdb' + + img_path_list, keys = prepare_keys(folder_path, 'jpg') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) + + folder_path = f'{root_path}/gt_crops' + lmdb_path = f'{root_path}/gt_crops.lmdb' + + img_path_list, keys = prepare_keys(folder_path, 'jpg') + make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) \ No newline at end of file diff --git a/basicsr/utils/download_util.py b/basicsr/utils/download_util.py index 34ebc2e..01ce81b 100644 --- a/basicsr/utils/download_util.py +++ b/basicsr/utils/download_util.py @@ -5,8 +5,11 @@ # Copyright 2018-2020 BasicSR Authors # ------------------------------------------------------------------------ import math +import os import requests +from torch.hub import download_url_to_file, get_dir from tqdm import tqdm +from urllib.parse import urlparse from .misc import sizeof_fmt @@ -14,8 +17,7 @@ from .misc import sizeof_fmt def download_file_from_google_drive(file_id, save_path): """Download files from google drive. - Ref: - https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive Args: file_id (str): File id. @@ -33,11 +35,9 @@ def download_file_from_google_drive(file_id, save_path): response = session.get(URL, params=params, stream=True) # get file size - response_file_size = session.get( - URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) if 'Content-Range' in response_file_size.headers: - file_size = int( - response_file_size.headers['Content-Range'].split('/')[1]) + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) else: file_size = None @@ -51,10 +51,7 @@ def get_confirm_token(response): return None -def save_response_content(response, - destination, - file_size=None, - chunk_size=32768): +def save_response_content(response, destination, file_size=None, chunk_size=32768): if file_size is not None: pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') @@ -68,9 +65,40 @@ def save_response_content(response, downloaded_size += chunk_size if pbar is not None: pbar.update(1) - pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' - f'/ {readable_file_size}') + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') if chunk: # filter out keep-alive new chunks f.write(chunk) if pbar is not None: pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + + Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file diff --git a/basicsr/utils/file_client.py b/basicsr/utils/file_client.py index 9a0785f..c300287 100644 --- a/basicsr/utils/file_client.py +++ b/basicsr/utils/file_client.py @@ -42,13 +42,11 @@ class MemcachedBackend(BaseStorageBackend): try: import mc except ImportError: - raise ImportError( - 'Please install memcached to enable MemcachedBackend.') + raise ImportError('Please install memcached to enable MemcachedBackend.') self.server_list_cfg = server_list_cfg self.client_cfg = client_cfg - self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, - self.client_cfg) + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) # mc.pyvector servers as a point which points to a memory cache self._mc_buffer = mc.pyvector() @@ -99,13 +97,7 @@ class LmdbBackend(BaseStorageBackend): _client (list): A list of several lmdb envs. """ - def __init__(self, - db_paths, - client_keys='default', - readonly=True, - lock=False, - readahead=False, - **kwargs): + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): try: import lmdb except ImportError: @@ -118,32 +110,22 @@ class LmdbBackend(BaseStorageBackend): self.db_paths = [str(v) for v in db_paths] elif isinstance(db_paths, str): self.db_paths = [str(db_paths)] - assert len(client_keys) == len(self.db_paths), ( - 'client_keys and db_paths should have the same length, ' - f'but received {len(client_keys)} and {len(self.db_paths)}.') + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') self._client = {} - for client, path in zip(client_keys, self.db_paths): - self._client[client] = lmdb.open( - path, - readonly=readonly, - lock=lock, - readahead=readahead, - map_size=8*1024*10485760, - # max_readers=1, - **kwargs) + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) def get(self, filepath, client_key): """Get values according to the filepath from one lmdb named client_key. Args: filepath (str | obj:`Path`): Here, filepath is the lmdb key. - client_key (str): Used for distinguishing differnet lmdb envs. + client_key (str): Used for distinguishing different lmdb envs. """ filepath = str(filepath) - assert client_key in self._client, (f'client_key {client_key} is not ' - 'in lmdb clients.') + assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.') client = self._client[client_key] with client.begin(write=False) as txn: value_buf = txn.get(filepath.encode('ascii')) @@ -174,9 +156,8 @@ class FileClient(object): def __init__(self, backend='disk', **kwargs): if backend not in self._backends: - raise ValueError( - f'Backend {backend} is not supported. Currently supported ones' - f' are {list(self._backends.keys())}') + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') self.backend = backend self.client = self._backends[backend](**kwargs) diff --git a/basicsr/utils/flow_util.py b/basicsr/utils/flow_util.py index d03c1d4..ed80266 100644 --- a/basicsr/utils/flow_util.py +++ b/basicsr/utils/flow_util.py @@ -27,8 +27,7 @@ def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): assert concat_axis in [0, 1] cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) if cat_flow.ndim != 2: - raise IOError(f'{flow_path} is not a valid quantized flow file, ' - f'its dimension is {cat_flow.ndim}.') + raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') assert cat_flow.shape[concat_axis] % 2 == 0 dx, dy = np.split(cat_flow, 2, axis=concat_axis) flow = dequantize_flow(dx, dy, *args, **kwargs) @@ -40,8 +39,7 @@ def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): raise IOError(f'Invalid flow file: {flow_path}') else: if header != 'PIEH': - raise IOError(f'Invalid flow file: {flow_path}, ' - 'header does not contain PIEH') + raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH') w = np.fromfile(f, np.int32, 1).squeeze() h = np.fromfile(f, np.int32, 1).squeeze() @@ -77,8 +75,8 @@ def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): assert concat_axis in [0, 1] dx, dy = quantize_flow(flow, *args, **kwargs) dxdy = np.concatenate((dx, dy), axis=concat_axis) - os.makedirs(filename, exist_ok=True) - cv2.imwrite(dxdy, filename) + os.makedirs(os.path.dirname(filename), exist_ok=True) + cv2.imwrite(filename, dxdy) def quantize_flow(flow, max_val=0.02, norm=True): @@ -103,9 +101,7 @@ def quantize_flow(flow, max_val=0.02, norm=True): dx = dx / w # avoid inplace operations dy = dy / h # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. - flow_comps = [ - quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy] - ] + flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] return tuple(flow_comps) @@ -147,15 +143,12 @@ def quantize(arr, min_val, max_val, levels, dtype=np.int64): tuple: Quantized array. """ if not (isinstance(levels, int) and levels > 1): - raise ValueError( - f'levels must be a positive integer, but got {levels}') + raise ValueError(f'levels must be a positive integer, but got {levels}') if min_val >= max_val: - raise ValueError( - f'min_val ({min_val}) must be smaller than max_val ({max_val})') + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') arr = np.clip(arr, min_val, max_val) - min_val - quantized_arr = np.minimum( - np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) + quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) return quantized_arr @@ -174,13 +167,10 @@ def dequantize(arr, min_val, max_val, levels, dtype=np.float64): tuple: Dequantized array. """ if not (isinstance(levels, int) and levels > 1): - raise ValueError( - f'levels must be a positive integer, but got {levels}') + raise ValueError(f'levels must be a positive integer, but got {levels}') if min_val >= max_val: - raise ValueError( - f'min_val ({min_val}) must be smaller than max_val ({max_val})') + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') - dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - - min_val) / levels + min_val + dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val return dequantized_arr diff --git a/basicsr/utils/img_util.py b/basicsr/utils/img_util.py index 6374861..e8ca09f 100644 --- a/basicsr/utils/img_util.py +++ b/basicsr/utils/img_util.py @@ -27,6 +27,8 @@ def img2tensor(imgs, bgr2rgb=True, float32=True): def _totensor(img, bgr2rgb, float32): if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = torch.from_numpy(img.transpose(2, 0, 1)) if float32: @@ -60,11 +62,8 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of shape (H x W). The channel order is BGR. """ - if not (torch.is_tensor(tensor) or - (isinstance(tensor, list) - and all(torch.is_tensor(t) for t in tensor))): - raise TypeError( - f'tensor or list of tensors expected, got {type(tensor)}') + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') if torch.is_tensor(tensor): tensor = [tensor] @@ -75,9 +74,7 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): n_dim = _tensor.dim() if n_dim == 4: - img_np = make_grid( - _tensor, nrow=int(math.sqrt(_tensor.size(0))), - normalize=False).numpy() + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() img_np = img_np.transpose(1, 2, 0) if rgb2bgr: img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) @@ -86,14 +83,13 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): img_np = img_np.transpose(1, 2, 0) if img_np.shape[2] == 1: # gray image img_np = np.squeeze(img_np, axis=2) - elif img_np.shape[2] == 3: + else: if rgb2bgr: img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) elif n_dim == 2: img_np = _tensor.numpy() else: - raise TypeError('Only support 4D, 3D or 2D tensor. ' - f'But received with dimension: {n_dim}') + raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') if out_type == np.uint8: # Unlike MATLAB, numpy.unit8() WILL NOT round by default. img_np = (img_np * 255.0).round() @@ -104,6 +100,23 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): return result +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + def imfrombytes(content, flag='color', float32=False): """Read an image from bytes. @@ -118,31 +131,12 @@ def imfrombytes(content, flag='color', float32=False): ndarray: Loaded image array. """ img_np = np.frombuffer(content, np.uint8) - imread_flags = { - 'color': cv2.IMREAD_COLOR, - 'grayscale': cv2.IMREAD_GRAYSCALE, - 'unchanged': cv2.IMREAD_UNCHANGED - } - if img_np is None: - raise Exception('None .. !!!') + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} img = cv2.imdecode(img_np, imread_flags[flag]) if float32: img = img.astype(np.float32) / 255. return img -def padding(img_lq, img_gt, gt_size): - h, w, _ = img_lq.shape - - h_pad = max(0, gt_size - h) - w_pad = max(0, gt_size - w) - - if h_pad == 0 and w_pad == 0: - return img_lq, img_gt - - img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) - img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) - # print('img_lq', img_lq.shape, img_gt.shape) - return img_lq, img_gt def imwrite(img, file_path, params=None, auto_mkdir=True): """Write image to file. @@ -160,7 +154,9 @@ def imwrite(img, file_path, params=None, auto_mkdir=True): if auto_mkdir: dir_name = os.path.abspath(os.path.dirname(file_path)) os.makedirs(dir_name, exist_ok=True) - return cv2.imwrite(file_path, img, params) + ok = cv2.imwrite(file_path, img, params) + if not ok: + raise IOError('Failed in writing images.') def crop_border(imgs, crop_border): @@ -177,10 +173,21 @@ def crop_border(imgs, crop_border): return imgs else: if isinstance(imgs, list): - return [ - v[crop_border:-crop_border, crop_border:-crop_border, ...] - for v in imgs - ] + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] else: - return imgs[crop_border:-crop_border, crop_border:-crop_border, - ...] + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] + + +def padding(img_lq, img_gt, gt_size): + h, w, _ = img_lq.shape + + h_pad = max(0, gt_size - h) + w_pad = max(0, gt_size - w) + + if h_pad == 0 and w_pad == 0: + return img_lq, img_gt + + img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + # print('img_lq', img_lq.shape, img_gt.shape) + return img_lq, img_gt diff --git a/basicsr/utils/lmdb_util.py b/basicsr/utils/lmdb_util.py index 3cec7d0..e83f8c5 100644 --- a/basicsr/utils/lmdb_util.py +++ b/basicsr/utils/lmdb_util.py @@ -24,10 +24,13 @@ def make_lmdb_from_imgs(data_path, """Make lmdb from images. Contents of lmdb. The file structure is: - example.lmdb - ├── data.mdb - ├── lock.mdb - ├── meta_info.txt + + :: + + example.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt The data.mdb and lock.mdb are standard lmdb files and you can refer to https://lmdb.readthedocs.io/en/release/ for more details. @@ -64,11 +67,10 @@ def make_lmdb_from_imgs(data_path, estimated size from images. Default: None """ - assert len(img_path_list) == len(keys), ( - 'img_path_list and keys should have the same length, ' - f'but got {len(img_path_list)} and {len(keys)}') + assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' + f'but got {len(img_path_list)} and {len(keys)}') print(f'Create lmdb for {data_path}, save to {lmdb_path}...') - print(f'Total images: {len(img_path_list)}') + print(f'Totoal images: {len(img_path_list)}') if not lmdb_path.endswith('.lmdb'): raise ValueError("lmdb_path must end with '.lmdb'.") if osp.exists(lmdb_path): @@ -90,10 +92,7 @@ def make_lmdb_from_imgs(data_path, pool = Pool(n_thread) for path, key in zip(img_path_list, keys): - pool.apply_async( - read_img_worker, - args=(osp.join(data_path, path), key, compress_level), - callback=callback) + pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) pool.close() pool.join() pbar.close() @@ -102,10 +101,8 @@ def make_lmdb_from_imgs(data_path, # create lmdb environment if map_size is None: # obtain data size for one image - img = cv2.imread( - osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) - _, img_byte = cv2.imencode( - '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) data_size_per_img = img_byte.nbytes print('Data size per image is: ', data_size_per_img) data_size = data_size_per_img * len(img_path_list) @@ -125,8 +122,7 @@ def make_lmdb_from_imgs(data_path, img_byte = dataset[key] h, w, c = shapes[key] else: - _, img_byte, img_shape = read_img_worker( - osp.join(data_path, path), key, compress_level) + _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) h, w, c = img_shape txn.put(key_byte, img_byte) @@ -162,8 +158,7 @@ def read_img_worker(path, key, compress_level): c = 1 else: h, w, c = img.shape - _, img_byte = cv2.imencode('.png', img, - [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) return (key, img_byte, (h, w, c)) @@ -178,11 +173,7 @@ class LmdbMaker(): compress_level (int): Compress level when encoding images. Default: 1. """ - def __init__(self, - lmdb_path, - map_size=1024**4, - batch=5000, - compress_level=1): + def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): if not lmdb_path.endswith('.lmdb'): raise ValueError("lmdb_path must end with '.lmdb'.") if osp.exists(lmdb_path): diff --git a/basicsr/utils/logger.py b/basicsr/utils/logger.py index 704e78c..d18fee0 100644 --- a/basicsr/utils/logger.py +++ b/basicsr/utils/logger.py @@ -10,6 +10,43 @@ import time from .dist_util import get_dist_info, master_only +initialized_logger = {} + + +class AvgTimer(): + + def __init__(self, window=200): + self.window = window # average window + self.current_time = 0 + self.total_time = 0 + self.count = 0 + self.avg_time = 0 + self.start() + + def start(self): + self.start_time = self.tic = time.time() + + def record(self): + self.count += 1 + self.toc = time.time() + self.current_time = self.toc - self.tic + self.total_time += self.current_time + # calculate average time + self.avg_time = self.total_time / self.count + + # reset + if self.count > self.window: + self.count = 0 + self.total_time = 0 + + self.tic = time.time() + + def get_current_time(self): + return self.current_time + + def get_avg_time(self): + return self.avg_time + class MessageLogger(): """Message logger for printing. @@ -34,6 +71,9 @@ class MessageLogger(): self.start_time = time.time() self.logger = get_root_logger() + def reset_start_time(self): + self.start_time = time.time() + @master_only def __call__(self, log_vars): """Format logging message. @@ -50,11 +90,9 @@ class MessageLogger(): # epoch, iter, learning rates epoch = log_vars.pop('epoch') current_iter = log_vars.pop('iter') - total_iter = log_vars.pop('total_iter') lrs = log_vars.pop('lrs') - message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' - f'iter:{current_iter:8,d}, lr:(') + message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(') for v in lrs: message += f'{v:.3e},' message += ')] ' @@ -76,17 +114,10 @@ class MessageLogger(): message += f'{k}: {v:.4e} ' # tensorboard logger if self.use_tb_logger and 'debug' not in self.exp_name: - normed_step = 10000 * (current_iter / total_iter) - normed_step = int(normed_step) - if k.startswith('l_'): - self.tb_logger.add_scalar(f'losses/{k}', v, normed_step) - elif k.startswith('m_'): - self.tb_logger.add_scalar(f'metrics/{k}', v, normed_step) + self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) else: - assert 1 == 0 - # else: - # self.tb_logger.add_scalar(k, v, current_iter) + self.tb_logger.add_scalar(k, v, current_iter) self.logger.info(message) @@ -101,7 +132,7 @@ def init_tb_logger(log_dir): def init_wandb_logger(opt): """We now only use wandb to sync tensorboard log.""" import wandb - logger = logging.getLogger('basicsr') + logger = get_root_logger() project = opt['logger']['wandb']['project'] resume_id = opt['logger']['wandb'].get('resume_id') @@ -113,20 +144,12 @@ def init_wandb_logger(opt): wandb_id = wandb.util.generate_id() resume = 'never' - wandb.init( - id=wandb_id, - resume=resume, - name=opt['name'], - config=opt, - project=project, - sync_tensorboard=True) + wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') -def get_root_logger(logger_name='basicsr', - log_level=logging.INFO, - log_file=None): +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): """Get the root logger. The logger will be initialized if it has not been initialized. By default a @@ -146,20 +169,25 @@ def get_root_logger(logger_name='basicsr', """ logger = logging.getLogger(logger_name) # if the logger has been initialized, just return it - if logger.hasHandlers(): + if logger_name in initialized_logger: return logger format_str = '%(asctime)s %(levelname)s: %(message)s' - logging.basicConfig(format=format_str, level=log_level) + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False rank, _ = get_dist_info() if rank != 0: logger.setLevel('ERROR') elif log_file is not None: + logger.setLevel(log_level) + # add file handler file_handler = logging.FileHandler(log_file, 'w') file_handler.setFormatter(logging.Formatter(format_str)) file_handler.setLevel(log_level) logger.addHandler(file_handler) - + initialized_logger[logger_name] = True return logger diff --git a/basicsr/utils/matlab_functions.py b/basicsr/utils/matlab_functions.py index 7657cf4..e01e719 100644 --- a/basicsr/utils/matlab_functions.py +++ b/basicsr/utils/matlab_functions.py @@ -15,13 +15,11 @@ def cubic(x): absx2 = absx**2 absx3 = absx**3 return (1.5 * absx3 - 2.5 * absx2 + 1) * ( - (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + - 2) * (((absx > 1) * - (absx <= 2)).type_as(absx)) + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) -def calculate_weights_indices(in_length, out_length, scale, kernel, - kernel_width, antialiasing): +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): """Calculate weights and indices, used for imresize function. Args: @@ -56,8 +54,8 @@ def calculate_weights_indices(in_length, out_length, scale, kernel, # The indices of the input pixels involved in computing the k-th output # pixel are in row k of the indices matrix. - indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace( - 0, p - 1, p).view(1, p).expand(out_length, p) + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( + out_length, p) # The weights used to compute the k-th output pixel are in row k of the # weights matrix. @@ -109,11 +107,18 @@ def imresize(img, scale, antialiasing=True): Returns: Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. """ + squeeze_flag = False if type(img).__module__ == np.__name__: # numpy type numpy_type = True + if img.ndim == 2: + img = img[:, :, None] + squeeze_flag = True img = torch.from_numpy(img.transpose(2, 0, 1)).float() else: numpy_type = False + if img.ndim == 2: + img = img.unsqueeze(0) + squeeze_flag = True in_c, in_h, in_w = img.size() out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) @@ -121,10 +126,10 @@ def imresize(img, scale, antialiasing=True): kernel = 'cubic' # get weights and indices - weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices( - in_h, out_h, scale, kernel, kernel_width, antialiasing) - weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices( - in_w, out_w, scale, kernel, kernel_width, antialiasing) + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width, + antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width, + antialiasing) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) @@ -145,8 +150,7 @@ def imresize(img, scale, antialiasing=True): for i in range(out_h): idx = int(indices_h[i][0]) for j in range(in_c): - out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose( - 0, 1).mv(weights_h[i]) + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) # process W dimension # symmetric copying @@ -168,200 +172,13 @@ def imresize(img, scale, antialiasing=True): for i in range(out_w): idx = int(indices_w[i][0]) for j in range(in_c): - out_2[j, :, i] = out_1_aug[j, :, - idx:idx + kernel_width].mv(weights_w[i]) + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) + if squeeze_flag: + out_2 = out_2.squeeze(0) if numpy_type: - out_2 = out_2.numpy().transpose(1, 2, 0) + out_2 = out_2.numpy() + if not squeeze_flag: + out_2 = out_2.transpose(1, 2, 0) + return out_2 - - -def rgb2ycbcr(img, y_only=False): - """Convert a RGB image to YCbCr image. - - This function produces the same results as Matlab's `rgb2ycbcr` function. - It implements the ITU-R BT.601 conversion for standard-definition - television. See more details in - https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. - - It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. - In OpenCV, it implements a JPEG conversion. See more details in - https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. - - Args: - img (ndarray): The input image. It accepts: - 1. np.uint8 type with range [0, 255]; - 2. np.float32 type with range [0, 1]. - y_only (bool): Whether to only return Y channel. Default: False. - - Returns: - ndarray: The converted YCbCr image. The output image has the same type - and range as input image. - """ - img_type = img.dtype - img = _convert_input_type_range(img) - if y_only: - out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 - else: - out_img = np.matmul( - img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], - [24.966, 112.0, -18.214]]) + [16, 128, 128] - out_img = _convert_output_type_range(out_img, img_type) - return out_img - - -def bgr2ycbcr(img, y_only=False): - """Convert a BGR image to YCbCr image. - - The bgr version of rgb2ycbcr. - It implements the ITU-R BT.601 conversion for standard-definition - television. See more details in - https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. - - It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. - In OpenCV, it implements a JPEG conversion. See more details in - https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. - - Args: - img (ndarray): The input image. It accepts: - 1. np.uint8 type with range [0, 255]; - 2. np.float32 type with range [0, 1]. - y_only (bool): Whether to only return Y channel. Default: False. - - Returns: - ndarray: The converted YCbCr image. The output image has the same type - and range as input image. - """ - img_type = img.dtype - img = _convert_input_type_range(img) - if y_only: - out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 - else: - out_img = np.matmul( - img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], - [65.481, -37.797, 112.0]]) + [16, 128, 128] - out_img = _convert_output_type_range(out_img, img_type) - return out_img - - -def ycbcr2rgb(img): - """Convert a YCbCr image to RGB image. - - This function produces the same results as Matlab's ycbcr2rgb function. - It implements the ITU-R BT.601 conversion for standard-definition - television. See more details in - https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. - - It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. - In OpenCV, it implements a JPEG conversion. See more details in - https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. - - Args: - img (ndarray): The input image. It accepts: - 1. np.uint8 type with range [0, 255]; - 2. np.float32 type with range [0, 1]. - - Returns: - ndarray: The converted RGB image. The output image has the same type - and range as input image. - """ - img_type = img.dtype - img = _convert_input_type_range(img) * 255 - out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], - [0, -0.00153632, 0.00791071], - [0.00625893, -0.00318811, 0]]) * 255.0 + [ - -222.921, 135.576, -276.836 - ] # noqa: E126 - out_img = _convert_output_type_range(out_img, img_type) - return out_img - - -def ycbcr2bgr(img): - """Convert a YCbCr image to BGR image. - - The bgr version of ycbcr2rgb. - It implements the ITU-R BT.601 conversion for standard-definition - television. See more details in - https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. - - It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. - In OpenCV, it implements a JPEG conversion. See more details in - https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. - - Args: - img (ndarray): The input image. It accepts: - 1. np.uint8 type with range [0, 255]; - 2. np.float32 type with range [0, 1]. - - Returns: - ndarray: The converted BGR image. The output image has the same type - and range as input image. - """ - img_type = img.dtype - img = _convert_input_type_range(img) * 255 - out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], - [0.00791071, -0.00153632, 0], - [0, -0.00318811, 0.00625893]]) * 255.0 + [ - -276.836, 135.576, -222.921 - ] # noqa: E126 - out_img = _convert_output_type_range(out_img, img_type) - return out_img - - -def _convert_input_type_range(img): - """Convert the type and range of the input image. - - It converts the input image to np.float32 type and range of [0, 1]. - It is mainly used for pre-processing the input image in colorspace - convertion functions such as rgb2ycbcr and ycbcr2rgb. - - Args: - img (ndarray): The input image. It accepts: - 1. np.uint8 type with range [0, 255]; - 2. np.float32 type with range [0, 1]. - - Returns: - (ndarray): The converted image with type of np.float32 and range of - [0, 1]. - """ - img_type = img.dtype - img = img.astype(np.float32) - if img_type == np.float32: - pass - elif img_type == np.uint8: - img /= 255. - else: - raise TypeError('The img type should be np.float32 or np.uint8, ' - f'but got {img_type}') - return img - - -def _convert_output_type_range(img, dst_type): - """Convert the type and range of the image according to dst_type. - - It converts the image to desired type and range. If `dst_type` is np.uint8, - images will be converted to np.uint8 type with range [0, 255]. If - `dst_type` is np.float32, it converts the image to np.float32 type with - range [0, 1]. - It is mainly used for post-processing images in colorspace convertion - functions such as rgb2ycbcr and ycbcr2rgb. - - Args: - img (ndarray): The image to be converted with np.float32 type and - range [0, 255]. - dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it - converts the image to np.uint8 type with range [0, 255]. If - dst_type is np.float32, it converts the image to np.float32 type - with range [0, 1]. - - Returns: - (ndarray): The converted image with desired type and range. - """ - if dst_type not in (np.uint8, np.float32): - raise TypeError('The dst_type should be np.float32 or np.uint8, ' - f'but got {dst_type}') - if dst_type == np.uint8: - img = img.round() - else: - img /= 255. - return img.astype(dst_type) diff --git a/basicsr/utils/misc.py b/basicsr/utils/misc.py index 848b12b..32f55f9 100644 --- a/basicsr/utils/misc.py +++ b/basicsr/utils/misc.py @@ -12,7 +12,6 @@ import torch from os import path as osp from .dist_util import master_only -from .logger import get_root_logger def set_random_seed(seed): @@ -50,9 +49,9 @@ def make_exp_dirs(opt): else: mkdir_and_rename(path_opt.pop('results_root')) for key, path in path_opt.items(): - if ('strict_load' not in key) and ('pretrain_network' - not in key) and ('resume' - not in key): + if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): + continue + else: os.makedirs(path, exist_ok=True) @@ -69,7 +68,7 @@ def scandir(dir_path, suffix=None, recursive=False, full_path=False): Default: False. Returns: - A generator for all the interested files with relative pathes. + A generator for all the interested files with relative paths. """ if (suffix is not None) and not isinstance(suffix, (str, tuple)): @@ -91,13 +90,62 @@ def scandir(dir_path, suffix=None, recursive=False, full_path=False): yield return_path else: if recursive: - yield from _scandir( - entry.path, suffix=suffix, recursive=recursive) + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) else: continue return _scandir(dir_path, suffix=suffix, recursive=recursive) + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. + resume_iter (int): Resume iteration. + """ + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + print('pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + if opt['path'].get('ignore_resume_networks') is None or (network + not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + print(f"Set {name} to {opt['path'][name]}") + + # change param_key to params in resume + param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] + for param_key in param_keys: + if opt['path'][param_key] == 'params_ema': + opt['path'][param_key] = 'params' + print(f'Set {param_key} to params') + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formatted file size. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' + def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False): """Scan a directory to find the interested files. @@ -138,49 +186,4 @@ def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False): else: continue - return _scandir(dir_path, keywords=keywords, recursive=recursive) - -def check_resume(opt, resume_iter): - """Check resume states and pretrain_network paths. - - Args: - opt (dict): Options. - resume_iter (int): Resume iteration. - """ - logger = get_root_logger() - if opt['path']['resume_state']: - # get all the networks - networks = [key for key in opt.keys() if key.startswith('network_')] - flag_pretrain = False - for network in networks: - if opt['path'].get(f'pretrain_{network}') is not None: - flag_pretrain = True - if flag_pretrain: - logger.warning( - 'pretrain_network path will be ignored during resuming.') - # set pretrained model paths - for network in networks: - name = f'pretrain_{network}' - basename = network.replace('network_', '') - if opt['path'].get('ignore_resume_networks') is None or ( - basename not in opt['path']['ignore_resume_networks']): - opt['path'][name] = osp.join( - opt['path']['models'], f'net_{basename}_{resume_iter}.pth') - logger.info(f"Set {name} to {opt['path'][name]}") - - -def sizeof_fmt(size, suffix='B'): - """Get human readable file size. - - Args: - size (int): File size. - suffix (str): Suffix. Default: 'B'. - - Return: - str: Formated file siz. - """ - for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: - if abs(size) < 1024.0: - return f'{size:3.1f} {unit}{suffix}' - size /= 1024.0 - return f'{size:3.1f} Y{suffix}' + return _scandir(dir_path, keywords=keywords, recursive=recursive) \ No newline at end of file diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py index 2f93a6c..3aab8ed 100644 --- a/basicsr/utils/options.py +++ b/basicsr/utils/options.py @@ -4,16 +4,23 @@ # Modified from BasicSR (https://github.com/xinntao/BasicSR) # Copyright 2018-2020 BasicSR Authors # ------------------------------------------------------------------------ +import argparse +import os +import random +import torch import yaml from collections import OrderedDict from os import path as osp +from basicsr.utils import set_random_seed +from basicsr.utils.dist_util import get_dist_info, init_dist, master_only + def ordered_yaml(): """Support OrderedDict for yaml. Returns: - yaml Loader and Dumper. + tuple: yaml Loader and Dumper. """ try: from yaml import CDumper as Dumper @@ -34,6 +41,189 @@ def ordered_yaml(): return Loader, Dumper +def yaml_load(f): + """Load yaml file or string. + + Args: + f (str): File path or a python string. + + Returns: + dict: Loaded dict. + """ + if os.path.isfile(f): + with open(f, 'r') as f: + return yaml.load(f, Loader=ordered_yaml()[0]) + else: + return yaml.load(f, Loader=ordered_yaml()[0]) + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. + """ + msg = '\n' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_level * 2) + k + ':[' + msg += dict2str(v, indent_level + 1) + msg += ' ' * (indent_level * 2) + ']\n' + else: + msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' + return msg + + +def _postprocess_yml_value(value): + # None + if value == '~' or value.lower() == 'none': + return None + # bool + if value.lower() == 'true': + return True + elif value.lower() == 'false': + return False + # !!float number + if value.startswith('!!float'): + return float(value.replace('!!float', '')) + # number + if value.isdigit(): + return int(value) + elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: + return float(value) + # list + if value.startswith('['): + return eval(value) + # str + return value + + +def parse_options(root_path, is_train=True): + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') + parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') + parser.add_argument('--auto_resume', action='store_true') + parser.add_argument('--debug', action='store_true') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999') + args = parser.parse_args() + + # parse yml to dict + opt = yaml_load(args.opt) + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + opt['rank'], opt['world_size'] = get_dist_info() + + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) + + # force to update yml options + if args.force_yml is not None: + for entry in args.force_yml: + # now do not support creating new keys + keys, value = entry.split('=') + keys, value = keys.strip(), value.strip() + value = _postprocess_yml_value(value) + eval_str = 'opt' + for key in keys.split(':'): + eval_str += f'["{key}"]' + eval_str += '=value' + # using exec function + exec(eval_str) + + opt['auto_resume'] = args.auto_resume + opt['is_train'] = is_train + + # debug setting + if args.debug and not opt['name'].startswith('debug'): + opt['name'] = 'debug_' + opt['name'] + + if opt['num_gpu'] == 'auto': + opt['num_gpu'] = torch.cuda.device_count() + + # datasets + for phase, dataset in opt['datasets'].items(): + # for multiple datasets, e.g., val_1, val_2; test_1, test_2 + phase = phase.split('_')[0] + dataset['phase'] = phase + if 'scale' in opt: + dataset['scale'] = opt['scale'] + if dataset.get('dataroot_gt') is not None: + dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) + if dataset.get('dataroot_lq') is not None: + dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) + + # paths + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) + + if is_train: + experiments_root = opt['path'].get('experiments_root') + if experiments_root is None: + experiments_root = osp.join(root_path, 'experiments') + experiments_root = osp.join(experiments_root, opt['name']) + + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_states'] = osp.join(experiments_root, 'training_states') + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, 'visualization') + + # change some options for debug mode + if 'debug' in opt['name']: + if 'val' in opt: + opt['val']['val_freq'] = 8 + opt['logger']['print_freq'] = 1 + opt['logger']['save_checkpoint_freq'] = 8 + else: # test + results_root = opt['path'].get('results_root') + if results_root is None: + results_root = osp.join(root_path, 'results') + results_root = osp.join(results_root, opt['name']) + + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + opt['path']['visualization'] = osp.join(results_root, 'visualization') + + return opt, args + + +@master_only +def copy_opt_file(opt_file, experiments_root): + # copy the yml file to the experiment root + import sys + import time + from shutil import copyfile + cmd = ' '.join(sys.argv) + filename = osp.join(experiments_root, osp.basename(opt_file)) + copyfile(opt_file, filename) + + with open(filename, 'r+') as f: + lines = f.readlines() + lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') + f.seek(0) + f.writelines(lines) + + def parse(opt_path, is_train=True): """Parse option file. @@ -94,24 +284,3 @@ def parse(opt_path, is_train=True): opt['path']['visualization'] = osp.join(results_root, 'visualization') return opt - - -def dict2str(opt, indent_level=1): - """dict to string for printing options. - - Args: - opt (dict): Option dict. - indent_level (int): Indent level. Default: 1. - - Return: - (str): Option string for printing. - """ - msg = '\n' - for k, v in opt.items(): - if isinstance(v, dict): - msg += ' ' * (indent_level * 2) + k + ':[' - msg += dict2str(v, indent_level + 1) - msg += ' ' * (indent_level * 2) + ']\n' - else: - msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' - return msg diff --git a/basicsr/utils/registry.py b/basicsr/utils/registry.py new file mode 100644 index 0000000..1745e94 --- /dev/null +++ b/basicsr/utils/registry.py @@ -0,0 +1,88 @@ +# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 + + +class Registry(): + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + .. code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj, suffix=None): + if isinstance(suffix, str): + name = name + '_' + suffix + + assert (name not in self._obj_map), (f"An object named '{name}' was already registered " + f"in '{self._name}' registry!") + self._obj_map[name] = obj + + def register(self, obj=None, suffix=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. + See docstring of this class for usage. + """ + if obj is None: + # used as a decorator + def deco(func_or_class): + name = func_or_class.__name__ + self._do_register(name, func_or_class, suffix) + return func_or_class + + return deco + + # used as a function call + name = obj.__name__ + self._do_register(name, obj, suffix) + + def get(self, name, suffix='basicsr'): + ret = self._obj_map.get(name) + if ret is None: + ret = self._obj_map.get(name + '_' + suffix) + print(f'Name {name} is not found, use name: {name}_{suffix}!') + if ret is None: + raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") + return ret + + def __contains__(self, name): + return name in self._obj_map + + def __iter__(self): + return iter(self._obj_map.items()) + + def keys(self): + return self._obj_map.keys() + + +DATASET_REGISTRY = Registry('dataset') +ARCH_REGISTRY = Registry('arch') +MODEL_REGISTRY = Registry('model') +LOSS_REGISTRY = Registry('loss') +METRIC_REGISTRY = Registry('metric') diff --git a/basicsr/version.py b/basicsr/version.py index 2c4fdd3..9fcc27a 100644 --- a/basicsr/version.py +++ b/basicsr/version.py @@ -1,5 +1,6 @@ # GENERATED VERSION FILE -# TIME: Mon Apr 18 21:35:20 2022 -__version__ = '1.2.0+386ca20' +# TIME: Sat Apr 20 19:01:02 2024 +__version__ = '1.2.0+cf36476' +__gitsha__ = 'cf36476' short_version = '1.2.0' version_info = (1, 2, 0) diff --git a/demo/denoise_img.png b/demo/denoise_img.png deleted file mode 100644 index 2f74439..0000000 Binary files a/demo/denoise_img.png and /dev/null differ diff --git a/demo/noisy.png b/demo/noisy.png deleted file mode 100644 index f709a31..0000000 Binary files a/demo/noisy.png and /dev/null differ diff --git a/docs/BSDS300.md b/docs/BSDS300.md new file mode 100644 index 0000000..084abad --- /dev/null +++ b/docs/BSDS300.md @@ -0,0 +1,91 @@ +# reproduce the SIDD dataset results + + + +### 1. Data Preparation + +##### Download the train set and place it in ```./datasets/SIDD/Data```: + +* [google drive](https://drive.google.com/file/d/1UHjWZzLPGweA9ZczmV8lFSRcIxqiOVJw/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1EnBVjrfFBiXIRPBgjFrifg?pwd=sl6h), +* ```python scripts/data_preparation/sidd.py``` to crop the train image pairs to 512x512 patches and make the data into lmdb format. + +##### Download the evaluation data (in lmdb format) and place it in ```./datasets/SIDD/val/```: + + * [google drive](https://drive.google.com/file/d/1gZx_K2vmiHalRNOb1aj93KuUQ2guOlLp/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1I9N5fDa4SNP0nuHEy6k-rw?pwd=59d7), + * it should be like ```./datasets/SIDD/val/input_crops.lmdb``` and ```./datasets/SIDD/val/gt_crops.lmdb``` + + + +### 2. Training + +* NAFNet-SIDD-width32: + + ``` + python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/options/train/BSDS300/NAFNet-width32.yml --launcher pytorch + ``` + +* NAFNet-SIDD-width64: + + ``` + python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/options/train/BSDS300/NAFNet-width64.yml --launcher pytorch + ``` + +* Baseline-SIDD-width32: + + ``` + python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/SIDD/Baseline-width32.yml --launcher pytorch + ``` + +* Baseline-SIDD-width64: + + ``` + python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/SIDD/Baseline-width64.yml --launcher pytorch + ``` + +* 8 gpus by default. Set ```--nproc_per_node``` to # of gpus for distributed validation. + + + + +### 3. Evaluation + + +##### Download the pretrain model in ```./experiments/pretrained_models/``` + + * **NAFNet-SIDD-width32**: [google drive](https://drive.google.com/file/d/1lsByk21Xw-6aW7epCwOQxvm6HYCQZPHZ/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1Xses38SWl-7wuyuhaGNhaw?pwd=um97) + + * **NAFNet-SIDD-width64**: [google drive](https://drive.google.com/file/d/14Fht1QQJ2gMlk4N1ERCRuElg8JfjrWWR/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/198kYyVSrY_xZF0jGv9U0sQ?pwd=dton) + + * **Baseline-SIDD-width32**: [google drive](https://drive.google.com/file/d/1NhqVcqkDcYvYgF_P4BOOfo9tuTcKDuhW/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1wkskmCRKhXq6dGa6Ns8D0A?pwd=0rin) + + * **Baseline-SIDD-width64**: [google drive](https://drive.google.com/file/d/1wQ1HHHPhSp70_ledMBZhDhIGjZQs16wO/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/1ivruGfSRGfWq5AEB8qc7YQ?pwd=t9w8) + + +##### Testing on SIDD dataset + + * NAFNet-SIDD-width32: + +``` +python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/NAFNet-width32.yml --launcher pytorch +``` + + * NAFNet-SIDD-width64: + +``` +python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/NAFNet-width64.yml --launcher pytorch +``` + + * Baseline-SIDD-width32: + +``` +python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/Baseline-width32.yml --launcher pytorch +``` + + * Baseline-SIDD-width64: + +``` +python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 basicsr/test.py -opt ./options/test/SIDD/Baseline-width64.yml --launcher pytorch +``` + +* Test by a single gpu by default. Set ```--nproc_per_node``` to # of gpus for distributed validation. + diff --git a/options/test/BSDS300/NAFNet-width64.yml b/options/test/BSDS300/NAFNet-width64.yml new file mode 100644 index 0000000..6a938e9 --- /dev/null +++ b/options/test/BSDS300/NAFNet-width64.yml @@ -0,0 +1,60 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFNet-BSDS-width64-test +model_type: ImageRestorationModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 10 + +# dataset and data loader settings +datasets: + + val: + name: BSDS_val + type: PairedImageDataset + + dataroot_gt: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/gt_crops.lmdb + dataroot_lq: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/input_crops.lmdb + + io_backend: + type: lmdb + +# network structures +network_g: + type: NAFNet + width: 64 + enc_blk_nums: [2, 2, 4, 8] + middle_blk_num: 12 + dec_blk_nums: [2, 2, 2, 2] + +# path +path: + pretrain_network_g: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/experiments/NAFNet-BSDS-width64/models/net_g_latest.pth + strict_load_g: true + resume_state: ~ + +# validation settings +val: + save_img: true + grids: false + use_image: false + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_ssim + crop_border: 0 + test_y_channel: false + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/BSDS300/NAFNet-width64.yml b/options/train/BSDS300/NAFNet-width64.yml new file mode 100644 index 0000000..7e8e1de --- /dev/null +++ b/options/train/BSDS300/NAFNet-width64.yml @@ -0,0 +1,109 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFNet-BSDS-width64 +model_type: ImageRestorationModel +scale: 1 +num_gpu: 1 +manual_seed: 10 + +datasets: + train: + name: BSDS + type: PairedImageDataset + dataroot_gt: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/train/gt_crops.lmdb + dataroot_lq: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/train/input_crops.lmdb + + filename_tmpl: '{}' + io_backend: + type: lmdb + + gt_size: 256 + use_flip: false + use_rot: false + + # data loader + use_shuffle: true + num_worker_per_gpu: 1 + batch_size_per_gpu: 1 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: BSDS_val + type: PairedImageDataset + dataroot_gt: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/gt_crops.lmdb + dataroot_lq: /mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/input_crops.lmdb + io_backend: + type: lmdb + + +network_g: + type: NAFNet + width: 64 + enc_blk_nums: [2, 2, 4, 8] + middle_blk_num: 12 + dec_blk_nums: [2, 2, 2, 2] + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + optim_g: + type: AdamW + lr: !!float 1e-3 + weight_decay: 0. + betas: [0.9, 0.9] + + scheduler: + type: TrueCosineAnnealingLR + T_max: 400000 + eta_min: !!float 1e-7 + + # total_iter: 200000 + total_iter: 123076 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: PSNRLoss + loss_weight: 1 + reduction: mean + +# validation settings +val: + val_freq: !!float 2e4 + save_img: false + use_image: false + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_ssim + crop_border: 0 + test_y_channel: false + +# logging settings +logger: + print_freq: 200 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 diff --git a/options/train/SIDD/NAFNet-width64.yml b/options/train/SIDD/NAFNet-width64.yml index 21d0a85..61e08f4 100644 --- a/options/train/SIDD/NAFNet-width64.yml +++ b/options/train/SIDD/NAFNet-width64.yml @@ -28,8 +28,8 @@ datasets: # data loader use_shuffle: true - num_worker_per_gpu: 8 - batch_size_per_gpu: 8 + num_worker_per_gpu: 1 + batch_size_per_gpu: 1 dataset_enlarge_ratio: 1 prefetch_mode: ~ diff --git a/scripts/data_preparation/bsds_300.py b/scripts/data_preparation/bsds_300.py new file mode 100644 index 0000000..126c8bc --- /dev/null +++ b/scripts/data_preparation/bsds_300.py @@ -0,0 +1,133 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +import cv2 +import numpy as np +import os +import sys +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm + +from basicsr.utils import scandir_SIDD +from basicsr.utils.create_lmdb import create_lmdb_for_BSDS300 +import sys +sys.path.append("/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/") + +def main(): + opt = {} + opt['n_thread'] = 20 + opt['compression_level'] = 3 + + opt['input_folder'] = '/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/Data' + opt['save_folder'] = '/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/input_crops' + opt['crop_size'] = 512 + opt['crop_size'] = 30 + opt['step'] = 384 + opt['thresh_size'] = 0 + opt['keywords'] = '_NOISY' + extract_subimages(opt) + + opt['save_folder'] = '/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val/gt_crops' + opt['keywords'] = '_GT' + extract_subimages(opt) + + create_lmdb_for_BSDS300('/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/datasets/BSDS300/processed/val') + + +def extract_subimages(opt): + """Crop images to subimages. + Args: + opt (dict): Configuration dict. It contains: + input_folder (str): Path to the input folder. + save_folder (str): Path to save folder. + n_thread (int): Thread number. + """ + input_folder = opt['input_folder'] + save_folder = opt['save_folder'] + if not osp.exists(save_folder): + os.makedirs(save_folder) + print(f'mkdir {save_folder} ...') + else: + print(f'Folder {save_folder} already exists. Exit.') + # sys.exit(1) + + img_list = list(scandir_SIDD(input_folder, keywords=opt['keywords'], recursive=True, full_path=True)) + + pbar = tqdm(total=len(img_list), unit='image', desc='Extract') + pool = Pool(opt['n_thread']) + for path in img_list: + pool.apply_async( + worker, args=(path, opt), callback=lambda arg: pbar.update(1)) + pool.close() + pool.join() + pbar.close() + print('All processes done.') + + +def worker(path, opt): + """Worker for each process. + Args: + path (str): Image path. + opt (dict): Configuration dict. It contains: + crop_size (int): Crop size. + step (int): Step for overlapped sliding window. + thresh_size (int): Threshold size. Patches whose size is lower + than thresh_size will be dropped. + save_folder (str): Path to save folder. + compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION. + Returns: + process_info (str): Process information displayed in progress bar. + """ + crop_size = opt['crop_size'] + step = opt['step'] + thresh_size = opt['thresh_size'] + img_name, extension = osp.splitext(osp.basename(path)) + + img_name = img_name.replace(opt['keywords'], '') + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + resized_img = cv2.resize(img, (512, 512)) + + index = 0 + cv2.imwrite( + osp.join(opt['save_folder'], + f'{img_name}_s{index:03d}{extension}'), resized_img, + [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) + process_info = f'Processing {img_name} ...' + return process_info + + if img.ndim == 2: + h, w = img.shape + elif img.ndim == 3: + h, w, c = img.shape + else: + raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}') + + h_space = np.arange(0, h - crop_size + 1, step) + if h - (h_space[-1] + crop_size) > thresh_size: + h_space = np.append(h_space, h - crop_size) + w_space = np.arange(0, w - crop_size + 1, step) + if w - (w_space[-1] + crop_size) > thresh_size: + w_space = np.append(w_space, w - crop_size) + + index = 0 + for x in h_space: + for y in w_space: + index += 1 + cropped_img = img[x:x + crop_size, y:y + crop_size, ...] + cropped_img = np.ascontiguousarray(cropped_img) + cv2.imwrite( + osp.join(opt['save_folder'], + f'{img_name}_s{index:03d}{extension}'), cropped_img, + [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) + process_info = f'Processing {img_name} ...' + return process_info + + +if __name__ == '__main__': + main() + # ... make sidd to lmdb \ No newline at end of file diff --git a/scripts/data_preparation/sidd.py b/scripts/data_preparation/sidd.py index d8e61f3..4a4da1d 100644 --- a/scripts/data_preparation/sidd.py +++ b/scripts/data_preparation/sidd.py @@ -14,7 +14,8 @@ from tqdm import tqdm from basicsr.utils import scandir_SIDD from basicsr.utils.create_lmdb import create_lmdb_for_SIDD - +import sys +sys.path.append("/mnt/d/Work/IIT-Jodhpur/semester3/CV/project/NAFNet/") def main(): opt = {} diff --git a/setup.py b/setup.py index ab13069..d673a7c 100644 --- a/setup.py +++ b/setup.py @@ -71,6 +71,7 @@ def write_version_py(): content = """# GENERATED VERSION FILE # TIME: {} __version__ = '{}' +__gitsha__ = '{}' short_version = '{}' version_info = ({}) """ @@ -81,7 +82,7 @@ version_info = ({}) [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) VERSION = SHORT_VERSION + '+' + sha - version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION, + version_file_str = content.format(time.asctime(), VERSION, sha, SHORT_VERSION, VERSION_INFO) with open(version_file, 'w') as f: f.write(version_file_str)