Added the necessary code changes and merged with the repo https://github.com/XPixelGroup/BasicSR to support inference via the commands below:

python basicsr/demo.py -opt options/test/SIDD/NAFNet-width64.yml --input_path ./demo/noisy.png --output_path ./demo/denoise_img.png
python basicsr/demo.py -opt options/test/REDS/NAFNet-width64.yml --input_path ./demo/blurry.jpg --output_path ./demo/deblur_img.png

Signed-off-by: Ranjan Debnath <ranjandebnath.rd@gmail.com>
parent 2b4af71ebe
commit cf36476ea3
@@ -0,0 +1,6 @@
This directory contains eggs that were downloaded by setuptools to build, test, and run plug-ins.

This directory caches those eggs to prevent repeated downloads.

However, it is safe to delete this directory.
@@ -6,4 +6,6 @@ logs/
*__pycache__*
*.sh
datasets
basicsr.egg-info
.eggs/*
build/*
@@ -1,10 +1,20 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
from .niqe import calculate_niqe
from .psnr_ssim import calculate_psnr, calculate_ssim, calculate_ssim_left, calculate_psnr_left, calculate_skimage_ssim, calculate_skimage_ssim_left
from copy import deepcopy

__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe', 'calculate_ssim_left', 'calculate_psnr_left', 'calculate_skimage_ssim', 'calculate_skimage_ssim_left']
from basicsr.utils.registry import METRIC_REGISTRY
from .niqe import calculate_niqe
from .psnr_ssim import calculate_psnr, calculate_ssim

__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']


def calculate_metric(data, opt):
    """Calculate metric from data and options.

    Args:
        opt (dict): Configuration. It must contain:
            type (str): Model type.
    """
    opt = deepcopy(opt)
    metric_type = opt.pop('type')
    metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
    return metric
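For context, a minimal usage sketch of the registry-driven `calculate_metric` helper added above (not part of the diff; the image arrays and option values here are hypothetical, mirroring the `metrics` section of the YAML test configs):

```python
# Hypothetical example: dispatching a registered metric through calculate_metric.
import numpy as np
from basicsr.metrics import calculate_metric

gt = np.full((64, 64, 3), 128, dtype=np.float64)   # stand-in ground truth
restored = gt.copy()                               # stand-in network output

# 'type' selects the function from METRIC_REGISTRY; the rest become its kwargs.
opt = {'type': 'calculate_psnr', 'crop_border': 0, 'input_order': 'HWC'}
psnr = calculate_metric({'img': restored, 'img2': gt}, opt)
print(psnr)  # inf, since the images are identical
```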
@@ -1,35 +1,22 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from scipy import linalg
from tqdm import tqdm

from basicsr.models.archs.inception import InceptionV3
from basicsr.archs.inception import InceptionV3


def load_patched_inception_v3(device='cuda',
                              resize_input=True,
                              normalize_input=False):
def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False):
    # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
    # does resize the input.
    inception = InceptionV3([3],
                            resize_input=resize_input,
                            normalize_input=normalize_input)
    inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input)
    inception = nn.DataParallel(inception).eval().to(device)
    return inception


@torch.no_grad()
def extract_inception_features(data_generator,
                               inception,
                               len_generator=None,
                               device='cuda'):
def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'):
    """Extract inception features.

    Args:
@@ -63,33 +50,27 @@ def extract_inception_features(data_generator,
def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is:
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Args:
        mu1 (np.array): The sample mean over activations.
        sigma1 (np.array): The covariance matrix over activations for
            generated samples.
        mu2 (np.array): The sample mean over activations, precalculated on a
            representative data set.
        sigma2 (np.array): The covariance matrix over activations,
            precalculated on a representative data set.
        sigma1 (np.array): The covariance matrix over activations for generated samples.
        mu2 (np.array): The sample mean over activations, precalculated on a representative data set.
        sigma2 (np.array): The covariance matrix over activations, precalculated on a representative data set.

    Returns:
        float: The Frechet Distance.
    """
    assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, (
        'Two covariances have different dimensions')
    assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions')

    cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)

    # Product might be almost singular
    if not np.isfinite(cov_sqrt).all():
        print('Product of cov matrices is singular. Adding {eps} to diagonal '
              'of cov estimates')
        print(f'Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates')
        offset = np.eye(sigma1.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
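As a sanity check on `calculate_fid`, a small sketch assuming both sets of activation statistics are already available as numpy arrays (the feature dimension is shrunk from Inception's 2048 to keep the example fast; identical statistics should give a distance of ~0):

```python
# Hypothetical example: FID between two identical Gaussian statistics is ~0.
import numpy as np
from basicsr.metrics.fid import calculate_fid

rng = np.random.default_rng(0)
feats = rng.normal(size=(500, 64))        # stand-in for Inception activations
mu = feats.mean(axis=0)                   # sample mean
sigma = np.cov(feats, rowvar=False)       # sample covariance

print(calculate_fid(mu, sigma, mu.copy(), sigma.copy()))  # close to 0.0
```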
@@ -1,12 +1,6 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import numpy as np

from basicsr.utils.matlab_functions import bgr2ycbcr
from basicsr.utils import bgr2ycbcr


def reorder_image(img, input_order='HWC'):
@@ -27,9 +21,7 @@ def reorder_image(img, input_order='HWC'):
    """

    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            "'HWC' and 'CHW'")
        raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
    if len(img.shape) == 2:
        img = img[..., None]
    if input_order == 'CHW':
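A quick sketch of how `reorder_image` is used by the metrics (not part of the diff): it normalizes the layout so downstream code can always assume HWC:

```python
# Hypothetical example: converting a channel-first array to HWC.
import numpy as np
from basicsr.metrics.metric_util import reorder_image

chw = np.zeros((3, 32, 48))                       # CHW input
hwc = reorder_image(chw, input_order='CHW')
print(hwc.shape)                                  # (32, 48, 3)
print(reorder_image(np.zeros((32, 48))).shape)    # 2D input gains a channel: (32, 48, 1)
```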
@@ -1,20 +1,17 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import cv2
import math
import numpy as np
from scipy.ndimage.filters import convolve
import os
from scipy.ndimage import convolve
from scipy.special import gamma

from basicsr.metrics.metric_util import reorder_image, to_y_channel
from basicsr.utils.matlab_functions import imresize
from basicsr.utils.registry import METRIC_REGISTRY


def estimate_aggd_param(block):
    """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) paramters.
    """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.

    Args:
        block (ndarray): 2D Image block.
@@ -26,15 +23,13 @@ def estimate_aggd_param(block):
    block = block.flatten()
    gam = np.arange(0.2, 10.001, 0.001)  # len = 9801
    gam_reciprocal = np.reciprocal(gam)
    r_gam = np.square(gamma(gam_reciprocal * 2)) / (
        gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
    r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))

    left_std = np.sqrt(np.mean(block[block < 0]**2))
    right_std = np.sqrt(np.mean(block[block > 0]**2))
    gammahat = left_std / right_std
    rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
    rhatnorm = (rhat * (gammahat**3 + 1) *
                (gammahat + 1)) / ((gammahat**2 + 1)**2)
    rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2)
    array_position = np.argmin((r_gam - rhatnorm)**2)

    alpha = gam[array_position]
@@ -70,22 +65,18 @@ def compute_feature(block):
    return feat


def niqe(img,
         mu_pris_param,
         cov_pris_param,
         gaussian_window,
         block_size_h=96,
         block_size_w=96):
def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96):
    """Calculate NIQE (Natural Image Quality Evaluator) metric.

    Ref: Making a "Completely Blind" Image Quality Analyzer.
    ``Paper: Making a "Completely Blind" Image Quality Analyzer``

    This implementation could produce almost the same results as the official
    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip

    Note that we do not include block overlap height and width, since they are
    always 0 in the official implementation.

    For good performance, it is advisable by the official implemtation to
    For good performance, it is advisable by the official implementation to
    divide the distorted image into the same size patches as used for the
    construction of the multivariate Gaussian model.
@@ -104,8 +95,7 @@ def niqe(img,
        block_size_w (int): Width of the blocks into which the image is divided.
            Default: 96 (the official recommended value).
    """
    assert img.ndim == 2, (
        'Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
    assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
    # crop image
    h, w = img.shape
    num_block_h = math.floor(h / block_size_h)
@@ -115,10 +105,7 @@ def niqe(img,
    distparam = []  # dist param is actually the multiscale features
    for scale in (1, 2):  # perform on two scales (1, 2)
        mu = convolve(img, gaussian_window, mode='nearest')
        sigma = np.sqrt(
            np.abs(
                convolve(np.square(img), gaussian_window, mode='nearest') -
                np.square(mu)))
        sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu)))
        # normalize, as in Eq. 1 in the paper
        img_nomalized = (img - mu) / (sigma + 1)
@@ -126,21 +113,14 @@ def niqe(img,
        for idx_w in range(num_block_w):
            for idx_h in range(num_block_h):
                # process each block
                block = img_nomalized[idx_h * block_size_h //
                                      scale:(idx_h + 1) * block_size_h //
                                      scale, idx_w * block_size_w //
                                      scale:(idx_w + 1) * block_size_w //
                                      scale]
                block = img_nomalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
                                      idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale]
                feat.append(compute_feature(block))

        distparam.append(np.array(feat))
        # TODO: matlab bicubic downsample with anti-aliasing
        # for simplicity, now we use opencv instead, which will result in
        # a slight difference.

        if scale == 1:
            h, w = img.shape
            img = cv2.resize(
                img / 255., (w // 2, h // 2), interpolation=cv2.INTER_LINEAR)
            img = imresize(img / 255., scale=0.5, antialiasing=True)
            img = img * 255.

    distparam = np.concatenate(distparam, axis=1)
@@ -154,20 +134,25 @@ def niqe(img,
    # compute niqe quality, Eq. 10 in the paper
    invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
    quality = np.matmul(
        np.matmul((mu_pris_param - mu_distparam), invcov_param),
        np.transpose((mu_pris_param - mu_distparam)))
    quality = np.sqrt(quality)
        np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam)))

    quality = np.sqrt(quality)
    quality = float(np.squeeze(quality))
    return quality


def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'):
@METRIC_REGISTRY.register()
def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', **kwargs):
    """Calculate NIQE (Natural Image Quality Evaluator) metric.

    Ref: Making a "Completely Blind" Image Quality Analyzer.
    ``Paper: Making a "Completely Blind" Image Quality Analyzer``

    This implementation could produce almost the same results as the official
    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip

    > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296)
    > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296)

    We use the official params estimated from the pristine dataset.
    We use the recommended block size (96, 96) without overlaps.
@@ -181,15 +166,15 @@ def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'):
            pixels are not involved in the metric calculation.
        input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
            Default: 'HWC'.
        convert_to (str): Whether coverted to 'y' (of MATLAB YCbCr) or 'gray'.
        convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
            Default: 'y'.

    Returns:
        float: NIQE result.
    """

    ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
    # we use the official params estimated from the pristine dataset.
    niqe_pris_params = np.load('basicsr/metrics/niqe_pris_params.npz')
    niqe_pris_params = np.load(os.path.join(ROOT_DIR, 'niqe_pris_params.npz'))
    mu_pris_param = niqe_pris_params['mu_pris_param']
    cov_pris_param = niqe_pris_params['cov_pris_param']
    gaussian_window = niqe_pris_params['gaussian_window']
@@ -206,6 +191,9 @@ def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'):
    if crop_border != 0:
        img = img[crop_border:-crop_border, crop_border:-crop_border]

    # round is necessary for being consistent with MATLAB's result
    img = img.round()

    niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)

    return niqe_result
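A usage sketch for the registered NIQE metric (not part of the diff; the input is synthetic, and NIQE scores are only meaningful on natural images, so treat the value as illustrative):

```python
# Hypothetical example: NIQE is no-reference, so only the distorted image is needed.
import numpy as np
from basicsr.metrics import calculate_niqe

img = np.random.randint(0, 256, (256, 256, 3)).astype(np.float64)  # BGR, [0, 255]
score = calculate_niqe(img, crop_border=0, input_order='HWC', convert_to='y')
print(score)  # lower is better
```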
@@ -1,263 +1,91 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# modified from https://github.com/mayorx/matlab_ssim_pytorch_implementation/blob/main/calc_ssim.py
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import cv2
import numpy as np
import torch
import torch.nn.functional as F

from basicsr.metrics.metric_util import reorder_image, to_y_channel
from skimage.metrics import structural_similarity
import torch
from basicsr.utils.color_util import rgb2ycbcr_pt
from basicsr.utils.registry import METRIC_REGISTRY


def calculate_psnr(img1,
                   img2,
                   crop_border,
                   input_order='HWC',
                   test_y_channel=False):

@METRIC_REGISTRY.register()
def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray/tensor): Images with range [0, 255]/[0, 1].
        img2 (ndarray/tensor): Images with range [0, 255]/[0, 1].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        img (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: psnr result.
        float: PSNR result.
    """

    assert img1.shape == img2.shape, (
        f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')
    if type(img1) == torch.Tensor:
        if len(img1.shape) == 4:
            img1 = img1.squeeze(0)
        img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
    if type(img2) == torch.Tensor:
        if len(img2.shape) == 4:
            img2 = img2.squeeze(0)
        img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)

    img1 = reorder_image(img1, input_order=input_order)
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
    img = reorder_image(img, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    def _psnr(img1, img2):
        if test_y_channel:
            img1 = to_y_channel(img1)
            img2 = to_y_channel(img2)

        mse = np.mean((img1 - img2)**2)
        if mse == 0:
            return float('inf')
        max_value = 1. if img1.max() <= 1 else 255.
        return 20. * np.log10(max_value / np.sqrt(mse))

    if img1.ndim == 3 and img1.shape[2] == 6:
        l1, r1 = img1[:, :, :3], img1[:, :, 3:]
        l2, r2 = img2[:, :, :3], img2[:, :, 3:]
        return (_psnr(l1, l2) + _psnr(r1, r2)) / 2
    else:
        return _psnr(img1, img2)
    if test_y_channel:
        img = to_y_channel(img)
        img2 = to_y_channel(img2)


def calculate_psnr_left(img1,
                        img2,
                        crop_border,
                        input_order='HWC',
                        test_y_channel=False):
    assert input_order == 'HWC'
    assert crop_border == 0

    img1 = img1[:, 64:, :3]
    img2 = img2[:, 64:, :3]
    return calculate_psnr(img1=img1, img2=img2, crop_border=0, input_order=input_order, test_y_channel=test_y_channel)


def _ssim(img1, img2, max_value):
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`.

    Args:
        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255] with order 'HWC'.

    Returns:
        float: ssim result.
    """

    C1 = (0.01 * max_value)**2
    C2 = (0.03 * max_value)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) *
                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                       (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def prepare_for_ssim(img, k):
    import torch
    with torch.no_grad():
        img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
        conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect')
        conv.weight.requires_grad = False
        conv.weight[:, :, :, :] = 1. / (k * k)

        img = conv(img)

        img = img.squeeze(0).squeeze(0)
        img = img[0::k, 0::k]
    return img.detach().cpu().numpy()


def prepare_for_ssim_rgb(img, k):
    import torch
    with torch.no_grad():
        img = torch.from_numpy(img).float()  # HxWx3

        conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect')
        conv.weight.requires_grad = False
        conv.weight[:, :, :, :] = 1. / (k * k)

        new_img = []

        for i in range(3):
            new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k])

    return torch.stack(new_img, dim=2).detach().cpu().numpy()


def _3d_gaussian_calculator(img, conv3d):
    out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
    return out


def _generate_3d_gaussian_kernel():
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())
    kernel_3 = cv2.getGaussianKernel(11, 1.5)
    kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0))
    conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate')
    conv3d.weight.requires_grad = False
    conv3d.weight[0, 0, :, :, :] = kernel
    return conv3d


def _ssim_3d(img1, img2, max_value):
    assert len(img1.shape) == 3 and len(img2.shape) == 3
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`.

    Args:
        img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.

    Returns:
        float: ssim result.
    """
    C1 = (0.01 * max_value) ** 2
    C2 = (0.03 * max_value) ** 2
    img1 = img1.astype(np.float64)
    img = img.astype(np.float64)
    img2 = img2.astype(np.float64)

    kernel = _generate_3d_gaussian_kernel().cuda()

    img1 = torch.tensor(img1).float().cuda()
    img2 = torch.tensor(img2).float().cuda()
    mse = np.mean((img - img2)**2)
    if mse == 0:
        return float('inf')
    return 10. * np.log10(255. * 255. / mse)


    mu1 = _3d_gaussian_calculator(img1, kernel)
    mu2 = _3d_gaussian_calculator(img2, kernel)

@METRIC_REGISTRY.register()
def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
    """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version).

    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq
    sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq
    sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) *
                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                       (sigma1_sq + sigma2_sq + C2))
    return float(ssim_map.mean())


def _ssim_cly(img1, img2):
    assert len(img1.shape) == 2 and len(img2.shape) == 2
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`.
    Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255] with order 'HWC'.
        img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: ssim result.
        float: PSNR result.
    """

    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')

    kernel = cv2.getGaussianKernel(11, 1.5)
    # print(kernel)
    window = np.outer(kernel, kernel.transpose())
    if crop_border != 0:
        img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
        img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]

    bt = cv2.BORDER_REPLICATE
    if test_y_channel:
        img = rgb2ycbcr_pt(img, y_only=True)
        img2 = rgb2ycbcr_pt(img2, y_only=True)

    mu1 = cv2.filter2D(img1, -1, window, borderType=bt)
    mu2 = cv2.filter2D(img2, -1, window, borderType=bt)
    img = img.to(torch.float64)
    img2 = img2.to(torch.float64)

    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) *
                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                       (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()
    mse = torch.mean((img - img2)**2, dim=[1, 2, 3])
    return 10. * torch.log10(1. / (mse + 1e-8))


def calculate_ssim(img1,
                   img2,
                   crop_border,
                   input_order='HWC',
                   test_y_channel=False,
                   ssim3d=True):
@METRIC_REGISTRY.register()
def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity
    ``Paper: Image quality assessment: From error visibility to structural similarity``

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.
@@ -266,93 +94,138 @@ def calculate_ssim(img1,
        averaged.

    Args:
        img1 (ndarray): Images with range [0, 255].
        img (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the SSIM calculation.
        crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: ssim result.
        float: SSIM result.
    """

    assert img1.shape == img2.shape, (
        f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')

    if type(img1) == torch.Tensor:
        if len(img1.shape) == 4:
            img1 = img1.squeeze(0)
        img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
    if type(img2) == torch.Tensor:
        if len(img2.shape) == 4:
            img2 = img2.squeeze(0)
        img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)

    img1 = reorder_image(img1, input_order=input_order)
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
    img = reorder_image(img, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    def _cal_ssim(img1, img2):
        if test_y_channel:
            img1 = to_y_channel(img1)
            img2 = to_y_channel(img2)
            return _ssim_cly(img1[..., 0], img2[..., 0])
    if test_y_channel:
        img = to_y_channel(img)
        img2 = to_y_channel(img2)

        ssims = []
        # ssims_before = []
    img = img.astype(np.float64)
    img2 = img2.astype(np.float64)

        # skimage_before = skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True)
        # print('.._skimage',
        #       skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True))
        max_value = 1 if img1.max() <= 1 else 255
        with torch.no_grad():
            final_ssim = _ssim_3d(img1, img2, max_value) if ssim3d else _ssim(img1, img2, max_value)
            ssims.append(final_ssim)
    ssims = []
    for i in range(img.shape[2]):
        ssims.append(_ssim(img[..., i], img2[..., i]))
    return np.array(ssims).mean()

        # for i in range(img1.shape[2]):
        #     ssims_before.append(_ssim(img1, img2))

        # print('..ssim mean , new {:.4f} and before {:.4f} .... skimage before {:.4f}'.format(np.array(ssims).mean(), np.array(ssims_before).mean(), skimage_before))
        #     ssims.append(skimage.metrics.structural_similarity(img1[..., i], img2[..., i], multichannel=False))
@METRIC_REGISTRY.register()
def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
    """Calculate SSIM (structural similarity) (PyTorch version).

        return np.array(ssims).mean()
    ``Paper: Image quality assessment: From error visibility to structural similarity``

    if img1.ndim == 3 and img1.shape[2] == 6:
        l1, r1 = img1[:, :, :3], img1[:, :, 3:]
        l2, r2 = img2[:, :, :3], img2[:, :, 3:]
        return (_cal_ssim(l1, l2) + _cal_ssim(r1, r2)) / 2
    else:
        return _cal_ssim(img1, img2)
    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.


def calculate_ssim_left(img1,
                        img2,
                        crop_border,
                        input_order='HWC',
                        test_y_channel=False,
                        ssim3d=True):
    assert input_order == 'HWC'
    assert crop_border == 0
    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    img1 = img1[:, 64:, :3]
    img2 = img2[:, 64:, :3]
    return calculate_ssim(img1=img1, img2=img2, crop_border=0, input_order=input_order, test_y_channel=test_y_channel, ssim3d=ssim3d)
    Args:
        img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.


def calculate_skimage_ssim(img1, img2):
    return structural_similarity(img1, img2, multichannel=True)
    Returns:
        float: SSIM result.
    """


def calculate_skimage_ssim_left(img1, img2):
    img1 = img1[:, 64:, :3]
    img2 = img2[:, 64:, :3]
    return calculate_skimage_ssim(img1=img1, img2=img2)
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')

    if crop_border != 0:
        img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
        img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]

    if test_y_channel:
        img = rgb2ycbcr_pt(img, y_only=True)
        img2 = rgb2ycbcr_pt(img2, y_only=True)

    img = img.to(torch.float64)
    img2 = img2.to(torch.float64)

    ssim = _ssim_pth(img * 255., img2 * 255.)
    return ssim


def _ssim(img, img2):
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`.

    Args:
        img (ndarray): Images with range [0, 255] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255] with order 'HWC'.

    Returns:
        float: SSIM result.
    """

    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5]  # valid mode for window size 11
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
    return ssim_map.mean()


def _ssim_pth(img, img2):
    """Calculate SSIM (structural similarity) (PyTorch version).

    It is called by func:`calculate_ssim_pt`.

    Args:
        img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
        img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).

    Returns:
        float: SSIM result.
    """
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())
    window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device)

    mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1])  # valid mode
    mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1])  # valid mode
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2
    sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq
    sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2

    cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2)
    ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map
    return ssim_map.mean([1, 2, 3])
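To round out the metric changes, a minimal sketch of the numpy and PyTorch variants side by side (not part of the diff; shapes and noise levels are hypothetical):

```python
# Hypothetical example: numpy PSNR/SSIM on HWC arrays vs. the _pt variants on NCHW tensors.
import numpy as np
import torch
from basicsr.metrics import calculate_psnr, calculate_ssim
from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt

gt = np.random.randint(0, 256, (128, 128, 3)).astype(np.float64)
noisy = np.clip(gt + np.random.normal(0, 5, gt.shape), 0, 255)

print(calculate_psnr(noisy, gt, crop_border=4, test_y_channel=True))
print(calculate_ssim(noisy, gt, crop_border=4, test_y_channel=True))

# The PyTorch versions expect (n, c, h, w) tensors in [0, 1].
gt_t = torch.from_numpy(gt.transpose(2, 0, 1).copy()).unsqueeze(0) / 255.
noisy_t = torch.from_numpy(noisy.transpose(2, 0, 1).copy()).unsqueeze(0) / 255.
print(calculate_psnr_pt(noisy_t, gt_t, crop_border=4))
print(calculate_ssim_pt(noisy_t, gt_t, crop_border=4))
```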
@@ -0,0 +1,7 @@
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
                          modulated_deform_conv)

__all__ = [
    'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
    'modulated_deform_conv'
]
@@ -0,0 +1,379 @@
import math
import os
import torch
from torch import nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn import functional as F
from torch.nn.modules.utils import _pair, _single

BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
    from torch.utils.cpp_extension import load
    module_path = os.path.dirname(__file__)
    deform_conv_ext = load(
        'deform_conv',
        sources=[
            os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
            os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
            os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
        ],
    )
else:
    try:
        from . import deform_conv_ext
    except ImportError:
        pass
        # avoid annoying print output
        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
        #       '1. compile with BASICSR_EXT=True. or\n '
        #       '2. set BASICSR_JIT=True during running')


class DeformConvFunction(Function):

    @staticmethod
    def forward(ctx,
                input,
                offset,
                weight,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1,
                im2col_step=64):
        if input is not None and input.dim() != 4:
            raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.')
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.im2col_step = im2col_step

        ctx.save_for_backward(input, offset, weight)

        output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))

        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones

        if not input.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
            deform_conv_ext.deform_conv_forward(input, weight,
                                                offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
                                                weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
                                                ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
                                                ctx.deformable_groups, cur_im2col_step)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, offset, weight = ctx.saved_tensors

        grad_input = grad_offset = grad_weight = None

        if not grad_output.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                grad_input = torch.zeros_like(input)
                grad_offset = torch.zeros_like(offset)
                deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
                                                           grad_offset, weight, ctx.bufs_[0], weight.size(3),
                                                           weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
                                                           ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
                                                           ctx.deformable_groups, cur_im2col_step)

            if ctx.needs_input_grad[2]:
                grad_weight = torch.zeros_like(weight)
                deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
                                                                ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
                                                                weight.size(2), ctx.stride[1], ctx.stride[0],
                                                                ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                                                                ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
                                                                cur_im2col_step)

        return (grad_input, grad_offset, grad_weight, None, None, None, None, None)

    @staticmethod
    def _output_size(input, weight, padding, dilation, stride):
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = padding[d]
            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})')
        return output_size


class ModulatedDeformConvFunction(Function):

    @staticmethod
    def forward(ctx,
                input,
                offset,
                mask,
                weight,
                bias=None,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1):
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.with_bias = bias is not None
        if not ctx.with_bias:
            bias = input.new_empty(1)  # fake tensor
        if not input.is_cuda:
            raise NotImplementedError
        if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad:
            ctx.save_for_backward(input, offset, mask, weight, bias)
        output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
                                                      ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
                                                      ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
                                                      ctx.groups, ctx.deformable_groups, ctx.with_bias)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        if not grad_output.is_cuda:
            raise NotImplementedError
        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_mask = torch.zeros_like(mask)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
                                                       grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
                                                       grad_output, weight.shape[2], weight.shape[3], ctx.stride,
                                                       ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
                                                       ctx.groups, ctx.deformable_groups, ctx.with_bias)
        if not ctx.with_bias:
            grad_bias = None

        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)

    @staticmethod
    def _infer_shape(ctx, input, weight):
        n = input.size(0)
        channels_out = weight.size(0)
        height, width = input.shape[2:4]
        kernel_h, kernel_w = weight.shape[2:4]
        height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
        width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
        return n, channels_out, height_out, width_out


deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply


class DeformConv(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=False):
        super(DeformConv, self).__init__()

        assert not bias
        assert in_channels % groups == 0, f'in_channels {in_channels} is not divisible by groups {groups}'
        assert out_channels % groups == 0, f'out_channels {out_channels} is not divisible by groups {groups}'

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups
        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))

        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x, offset):
        # To fix an assert error in deform_conv_cuda.cpp:128
        # input image is smaller than kernel
        input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
        if input_pad:
            pad_h = max(self.kernel_size[0] - x.size(2), 0)
            pad_w = max(self.kernel_size[1] - x.size(3), 0)
            x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
            offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
        out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
                          self.deformable_groups)
        if input_pad:
            out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
        return out


class DeformConvPack(DeformConv):
    """A Deformable Conv Encapsulation that acts as normal Conv layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
    """

    _version = 2

    def __init__(self, *args, **kwargs):
        super(DeformConvPack, self).__init__(*args, **kwargs)

        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_offset()

    def init_offset(self):
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        offset = self.conv_offset(x)
        return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
                           self.deformable_groups)


class ModulatedDeformConv(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias
        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.init_weights()

    def init_weights(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x, offset, mask):
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)


class ModulatedDeformConvPack(ModulatedDeformConv):
    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
    """

    _version = 2

    def __init__(self, *args, **kwargs):
        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)

        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_weights()

    def init_weights(self):
        super(ModulatedDeformConvPack, self).init_weights()
        if hasattr(self, 'conv_offset'):
            self.conv_offset.weight.data.zero_()
            self.conv_offset.bias.data.zero_()

    def forward(self, x):
        out = self.conv_offset(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)
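A usage sketch for the packed modulated variant (not part of the diff): it predicts its own offsets and modulation mask from the input, so it drops in where an `nn.Conv2d` would go. This assumes the package path `basicsr.ops.dcn` and a successfully compiled `deform_conv_ext` extension; the op is CUDA-only:

```python
# Hypothetical example: ModulatedDeformConvPack as a drop-in Conv2d replacement (requires CUDA).
import torch
from basicsr.ops.dcn import ModulatedDeformConvPack

dcn = ModulatedDeformConvPack(64, 64, kernel_size=3, padding=1).cuda()
x = torch.randn(1, 64, 32, 32, device='cuda')
y = dcn(x)        # offsets and sigmoid mask come from dcn.conv_offset internally
print(y.shape)    # torch.Size([1, 64, 32, 32])
```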
@ -0,0 +1,685 @@
|
|||
// modify from
|
||||
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/DeviceGuard.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
|
||||
const int channels, const int height, const int width,
|
||||
const int ksize_h, const int ksize_w, const int pad_h,
|
||||
const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int parallel_imgs, const int deformable_group,
|
||||
at::Tensor data_col);
|
||||
|
||||
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
|
||||
const int channels, const int height, const int width,
|
||||
const int ksize_h, const int ksize_w, const int pad_h,
|
||||
const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int parallel_imgs, const int deformable_group,
|
||||
at::Tensor grad_im);
|
||||
|
||||
void deformable_col2im_coord(
|
||||
const at::Tensor data_col, const at::Tensor data_im,
|
||||
const at::Tensor data_offset, const int channels, const int height,
|
||||
const int width, const int ksize_h, const int ksize_w, const int pad_h,
|
||||
const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
||||
const int deformable_group, at::Tensor grad_offset);
|
||||
|
||||
void modulated_deformable_im2col_cuda(
|
||||
const at::Tensor data_im, const at::Tensor data_offset,
|
||||
const at::Tensor data_mask, const int batch_size, const int channels,
|
||||
const int height_im, const int width_im, const int height_col,
|
||||
const int width_col, const int kernel_h, const int kenerl_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int deformable_group,
|
||||
at::Tensor data_col);
|
||||
|
||||
void modulated_deformable_col2im_cuda(
|
||||
const at::Tensor data_col, const at::Tensor data_offset,
|
||||
const at::Tensor data_mask, const int batch_size, const int channels,
|
||||
const int height_im, const int width_im, const int height_col,
|
||||
const int width_col, const int kernel_h, const int kenerl_w,
|
||||
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w, const int deformable_group,
|
||||
at::Tensor grad_im);
|
||||
|
||||
void modulated_deformable_col2im_coord_cuda(
|
||||
const at::Tensor data_col, const at::Tensor data_im,
|
||||
const at::Tensor data_offset, const at::Tensor data_mask,
|
||||
const int batch_size, const int channels, const int height_im,
|
||||
const int width_im, const int height_col, const int width_col,
|
||||
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w, const int dilation_h,
|
||||
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
|
||||
at::Tensor grad_mask);

void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
                 at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
                 int padW, int dilationH, int dilationW, int group,
                 int deformable_group) {
  TORCH_CHECK(weight.ndimension() == 4,
              "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
              "but got: %s",
              weight.ndimension());

  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  TORCH_CHECK(kW > 0 && kH > 0,
              "kernel size should be greater than zero, but got kH: %d kW: %d",
              kH, kW);

  TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
              "kernel size should be consistent with weight, ",
              "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d",
              kH, kW, weight.size(2), weight.size(3));

  TORCH_CHECK(dW > 0 && dH > 0,
              "stride should be greater than zero, but got dH: %d dW: %d", dH,
              dW);

  TORCH_CHECK(
      dilationW > 0 && dilationH > 0,
      "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
      dilationH, dilationW);

  int ndim = input.ndimension();
  int dimf = 0;
  int dimh = 1;
  int dimw = 2;

  if (ndim == 4) {
    dimf++;
    dimh++;
    dimw++;
  }

  TORCH_CHECK(ndim == 3 || ndim == 4,
              "3D or 4D input tensor expected but got: %s", ndim);

  long nInputPlane = weight.size(1) * group;
  long inputHeight = input.size(dimh);
  long inputWidth = input.size(dimw);
  long nOutputPlane = weight.size(0);
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;

  TORCH_CHECK(nInputPlane % deformable_group == 0,
              "input channels must divide deformable group size");

  if (outputWidth < 1 || outputHeight < 1)
    AT_ERROR(
        "Given input size: (%ld x %ld x %ld). "
        "Calculated output size: (%ld x %ld x %ld). Output size is too small",
        nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
        outputWidth);

  TORCH_CHECK(input.size(1) == nInputPlane,
              "invalid number of input planes, expected: %d, but got: %d",
              nInputPlane, input.size(1));

  TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
              "input image is smaller than kernel");

  TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
              "invalid spatial size of offset, expected height: %d width: %d, but "
              "got height: %d width: %d",
              outputHeight, outputWidth, offset.size(2), offset.size(3));

  TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
              "invalid number of channels of offset");

  if (gradOutput != NULL) {
    TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
                "invalid number of gradOutput planes, expected: %d, but got: %d",
                nOutputPlane, gradOutput->size(dimf));

    TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
                 gradOutput->size(dimw) == outputWidth),
                "invalid size of gradOutput, expected height: %d width: %d , but "
                "got height: %d width: %d",
                outputHeight, outputWidth, gradOutput->size(dimh),
                gradOutput->size(dimw));
  }
}
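
// Example (illustrative): with inputHeight = 64, kH = 3, padH = 1, dH = 1 and
// dilationH = 1, the formula above gives outputHeight = (64 + 2 - 3) / 1 + 1
// = 64, i.e. a 3x3 kernel with padding 1 and stride 1 preserves spatial size.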

int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
                             at::Tensor offset, at::Tensor output,
                             at::Tensor columns, at::Tensor ones, int kW,
                             int kH, int dW, int dH, int padW, int padH,
                             int dilationW, int dilationH, int group,
                             int deformable_group, int im2col_step) {
  // todo: resize columns to include im2col: done
  // todo: add im2col_step as input
  // todo: add new output buffer and transpose it to output (or directly
  // transpose output) todo: possibly change data indexing because of
  // parallel_imgs

  shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
              dilationH, dilationW, group, deformable_group);
  at::DeviceGuard guard(input.device());

  input = input.contiguous();
  offset = offset.contiguous();
  weight = weight.contiguous();

  int batch = 1;
  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input.unsqueeze_(0);
    offset.unsqueeze_(0);
  }

  // todo: assert batchsize dividable by im2col_step

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = weight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");

  output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
                        outputHeight, outputWidth});
  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
    ones = at::ones({outputHeight, outputWidth}, input.options());
  }

  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  at::Tensor output_buffer =
      at::zeros({batchSize / im2col_step, nOutputPlane,
                 im2col_step * outputHeight, outputWidth},
                output.options());

  output_buffer = output_buffer.view(
      {output_buffer.size(0), group, output_buffer.size(1) / group,
       output_buffer.size(2), output_buffer.size(3)});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns);

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});

    for (int g = 0; g < group; g++) {
      output_buffer[elt][g] = output_buffer[elt][g]
                                  .flatten(1)
                                  .addmm_(weight[g].flatten(1), columns[g])
                                  .view_as(output_buffer[elt][g]);
    }

    // restore flat shapes so the grouped views above stay valid on the next
    // iteration of the elt loop (same pattern as in
    // modulated_deform_conv_cuda_forward below)
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});
  }

  output_buffer = output_buffer.view(
      {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
       output_buffer.size(3), output_buffer.size(4)});

  output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
                                      im2col_step, outputHeight, outputWidth});
  output_buffer.transpose_(1, 2);
  output.copy_(output_buffer);
  output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    output = output.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
  }

  return 1;
}
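
// The forward pass above is the classic im2col + GEMM scheme: deformable_im2col
// gathers bilinearly sampled patches into `columns` of shape
// [nInputPlane * kH * kW, im2col_step * outputHeight * outputWidth], and the
// per-group addmm_ multiplies the flattened weights against it. im2col_step
// images share one GEMM call to amortize kernel-launch overhead.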

int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
                                    at::Tensor gradOutput, at::Tensor gradInput,
                                    at::Tensor gradOffset, at::Tensor weight,
                                    at::Tensor columns, int kW, int kH, int dW,
                                    int dH, int padW, int padH, int dilationW,
                                    int dilationH, int group,
                                    int deformable_group, int im2col_step) {
  shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
              dilationH, dilationW, group, deformable_group);
  at::DeviceGuard guard(input.device());

  input = input.contiguous();
  offset = offset.contiguous();
  gradOutput = gradOutput.contiguous();
  weight = weight.contiguous();

  int batch = 1;

  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input = input.view({1, input.size(0), input.size(1), input.size(2)});
    offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
    gradOutput = gradOutput.view(
        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
  }

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = weight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  // change order of grad output
  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
                                nOutputPlane, outputHeight, outputWidth});
  gradOutput.transpose_(1, 2);

  gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
                              inputHeight, inputWidth});
  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
                                deformable_group * 2 * kH * kW, outputHeight,
                                outputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    // divide into groups
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});
    gradOutput = gradOutput.view(
        {gradOutput.size(0), group, gradOutput.size(1) / group,
         gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});

    for (int g = 0; g < group; g++) {
      columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
                                     gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    // restore weight to its flat shape as well, so the grouped view above
    // remains valid on the next iteration of the elt loop
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});
    gradOutput = gradOutput.view(
        {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
         gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});

    deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
                            inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
                            dilationH, dilationW, im2col_step, deformable_group,
                            gradOffset[elt]);

    deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, gradInput[elt]);
  }

  gradOutput.transpose_(1, 2);
  gradOutput =
      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  gradOffset = gradOffset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
    gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
    gradOffset =
        gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
  }

  return 1;
}
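
// The input backward mirrors the forward: a GEMM against the transposed
// weights rebuilds `columns`, deformable_col2im_coord reduces it to offset
// gradients, and deformable_col2im scatters it back onto gradInput with
// atomicAdd, since several sampling points can hit the same input pixel.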

int deform_conv_backward_parameters_cuda(
    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
    at::Tensor gradWeight, // at::Tensor gradBias,
    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
    int padW, int padH, int dilationW, int dilationH, int group,
    int deformable_group, float scale, int im2col_step) {
  // todo: transpose and reshape outGrad
  // todo: reshape columns
  // todo: add im2col_step as input

  shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
              padW, dilationH, dilationW, group, deformable_group);
  at::DeviceGuard guard(input.device());

  input = input.contiguous();
  offset = offset.contiguous();
  gradOutput = gradOutput.contiguous();

  int batch = 1;

  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input = input.view(
        at::IntList({1, input.size(0), input.size(1), input.size(2)}));
    gradOutput = gradOutput.view(
        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
  }

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = gradWeight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");

  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
                                nOutputPlane, outputHeight, outputWidth});
  gradOutput.transpose_(1, 2);

  at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
  gradOutputBuffer =
      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
                             outputHeight, outputWidth});
  gradOutputBuffer.copy_(gradOutput);
  gradOutputBuffer =
      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
                             im2col_step * outputHeight, outputWidth});

  gradOutput.transpose_(1, 2);
  gradOutput =
      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns);

    // divide into group
    gradOutputBuffer = gradOutputBuffer.view(
        {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
         gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    gradWeight =
        gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
                         gradWeight.size(2), gradWeight.size(3)});

    for (int g = 0; g < group; g++) {
      gradWeight[g] = gradWeight[g]
                          .flatten(1)
                          .addmm_(gradOutputBuffer[elt][g].flatten(1),
                                  columns[g].transpose(1, 0), 1.0, scale)
                          .view_as(gradWeight[g]);
    }
    gradOutputBuffer = gradOutputBuffer.view(
        {gradOutputBuffer.size(0),
         gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
         gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
                                  gradWeight.size(2), gradWeight.size(3),
                                  gradWeight.size(4)});
  }

  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
  }

  return 1;
}

void modulated_deform_conv_cuda_forward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
    const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int group, const int deformable_group,
    const bool with_bias) {
  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
             kernel_h_, kernel_w_, kernel_h, kernel_w);
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
             channels, channels_kernel * group);

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  // resize output
  output = output.view({batch, channels_out, height_out, width_out}).zero_();
  // resize temporary columns
  columns =
      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
                input.options());

  output = output.view({output.size(0), group, output.size(1) / group,
                        output.size(2), output.size(3)});

  for (int b = 0; b < batch; b++) {
    modulated_deformable_im2col_cuda(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    // divide into group
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});

    for (int g = 0; g < group; g++) {
      output[b][g] = output[b][g]
                         .flatten(1)
                         .addmm_(weight[g].flatten(1), columns[g])
                         .view_as(output[b][g]);
    }

    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
  }

  output = output.view({output.size(0), output.size(1) * output.size(2),
                        output.size(3), output.size(4)});

  if (with_bias) {
    output += bias.view({1, bias.size(0), 1, 1});
  }
}
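
// Note that, unlike deform_conv_forward_cuda above, the modulated forward
// processes a single image per im2col/GEMM pair (the hard-coded batch
// argument 1), trading launch overhead for simpler buffer bookkeeping.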

void modulated_deform_conv_cuda_backward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor columns,
    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);
  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
             kernel_h_, kernel_w_, kernel_h, kernel_w);
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
             channels, channels_kernel * group);

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  grad_input = grad_input.view({batch, channels, height, width});
  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
                      input.options());

  grad_output =
      grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
                        grad_output.size(2), grad_output.size(3)});

  for (int b = 0; b < batch; b++) {
    // divide into group
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});

    for (int g = 0; g < group; g++) {
      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
                        grad_output[b][g].flatten(1), 0.0f, 1.0f);
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});

    // gradient w.r.t. input coordinate data
    modulated_deformable_col2im_coord_cuda(
        columns, input[b], offset[b], mask[b], 1, channels, height, width,
        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
        grad_mask[b]);
    // gradient w.r.t. input data
    modulated_deformable_col2im_cuda(
        columns, offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, grad_input[b]);

    // gradient w.r.t. weight, dWeight should accumulate across the batch and
    // group
    modulated_deformable_im2col_cuda(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
                                    grad_weight.size(1), grad_weight.size(2),
                                    grad_weight.size(3)});
    if (with_bias)
      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});

    for (int g = 0; g < group; g++) {
      grad_weight[g] =
          grad_weight[g]
              .flatten(1)
              .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
              .view_as(grad_weight[g]);
      if (with_bias) {
        grad_bias[g] =
            grad_bias[g]
                .view({-1, 1})
                .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
                .view(-1);
      }
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
                                    grad_weight.size(2), grad_weight.size(3),
                                    grad_weight.size(4)});
    if (with_bias)
      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
  }
  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
                                  grad_output.size(2), grad_output.size(3),
                                  grad_output.size(4)});
}
@@ -0,0 +1,867 @@
/*!
 ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
 *
 * COPYRIGHT
 *
 * All contributions by the University of California:
 * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
 * All rights reserved.
 *
 * All other contributions:
 * Copyright (c) 2014-2017, the respective contributors
 * All rights reserved.
 *
 * Caffe uses a shared copyright model: each contributor holds copyright over
 * their contributions to Caffe. The project versioning records all such
 * contribution and copyright details. If a contributor wants to further mark
 * their specific copyright on a particular contribution, they should indicate
 * their copyright solely in the commit message of the change when it is
 * committed.
 *
 * LICENSE
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * CONTRIBUTION AGREEMENT
 *
 * By contributing to the BVLC/caffe repository through pull-request, comment,
 * or otherwise, the contributor releases their content to the
 * license and copyright terms herein.
 *
 ***************** END Caffe Copyright Notice and Disclaimer ********************
 *
 * Copyright (c) 2018 Microsoft
 * Licensed under The MIT License [see LICENSE for details]
 * \file modulated_deformable_im2col.cuh
 * \brief Function definitions of converting an image to
 * column matrix based on kernel, padding, dilation, and offset.
 * These functions are mainly used in deformable convolution operators.
 * \ref: https://arxiv.org/abs/1703.06211
 * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
 */

// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#include <stdio.h>
#include <math.h>
#include <float.h>

using namespace at;

#define CUDA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)
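
// CUDA_KERNEL_LOOP is the standard grid-stride loop: each thread starts at its
// global index and advances by the total number of launched threads, so a grid
// capped at kMaxGridNum blocks still covers all n elements.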

const int CUDA_NUM_THREADS = 1024;
const int kMaxGridNum = 65535;

inline int GET_BLOCKS(const int N)
{
  return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
}
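
// Example (illustrative): for N = 2500 work items, GET_BLOCKS returns
// min(65535, (2500 + 1023) / 1024) = 3 blocks of 1024 threads each.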

template <typename scalar_t>
__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
                                               const int height, const int width, scalar_t h, scalar_t w)
{
  int h_low = floor(h);
  int w_low = floor(w);
  int h_high = h_low + 1;
  int w_high = w_low + 1;

  scalar_t lh = h - h_low;
  scalar_t lw = w - w_low;
  scalar_t hh = 1 - lh, hw = 1 - lw;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
    v1 = bottom_data[h_low * data_width + w_low];
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
    v2 = bottom_data[h_low * data_width + w_high];
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
    v3 = bottom_data[h_high * data_width + w_low];
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
    v4 = bottom_data[h_high * data_width + w_high];

  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
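
// The four bilinear weights w1..w4 always sum to 1; corners falling outside
// the image contribute v = 0. Example (illustrative): for (h, w) = (1.25, 2.5)
// the corners (1,2), (1,3), (2,2), (2,3) receive weights 0.375, 0.375,
// 0.125, 0.125.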

template <typename scalar_t>
__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
                                        const int h, const int w, const int height, const int width)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    // empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  scalar_t weight = 0;
  if (h == argmax_h_low && w == argmax_w_low)
    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
  if (h == argmax_h_low && w == argmax_w_high)
    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
  if (h == argmax_h_high && w == argmax_w_low)
    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
  if (h == argmax_h_high && w == argmax_w_high)
    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
  return weight;
}

template <typename scalar_t>
__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
                                          const int height, const int width, const scalar_t *im_data,
                                          const int data_width, const int bp_dir)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    // empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  scalar_t weight = 0;

  if (bp_dir == 0)
  {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }
  else if (bp_dir == 1)
  {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }

  return weight;
}

template <typename scalar_t>
__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
                                             const int height, const int width, const int kernel_h, const int kernel_w,
                                             const int pad_h, const int pad_w, const int stride_h, const int stride_w,
                                             const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
                                             const int batch_size, const int num_channels, const int deformable_group,
                                             const int height_col, const int width_col,
                                             scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // index of the output matrix
    const int w_col = index % width_col;
    const int h_col = (index / width_col) % height_col;
    const int b_col = (index / width_col / height_col) % batch_size;
    const int c_im = (index / width_col / height_col) / batch_size;
    const int c_col = c_im * kernel_h * kernel_w;

    // compute deformable group index
    const int deformable_group_index = c_im / channel_per_deformable_group;

    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;
    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;

    for (int i = 0; i < kernel_h; ++i)
    {
      for (int j = 0; j < kernel_w; ++j)
      {
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
        scalar_t val = static_cast<scalar_t>(0);
        const scalar_t h_im = h_in + i * dilation_h + offset_h;
        const scalar_t w_im = w_in + j * dilation_w + offset_w;
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
        {
          //const scalar_t map_h = i * dilation_h + offset_h;
          //const scalar_t map_w = j * dilation_w + offset_w;
          //const int cur_height = height - h_in;
          //const int cur_width = width - w_in;
          //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
          val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
        }
        *data_col_ptr = val;
        data_col_ptr += batch_size * height_col * width_col;
      }
    }
  }
}
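
// One thread handles one (channel, image, output pixel) triple and writes its
// kernel_h * kernel_w samples strided by batch_size * height_col * width_col,
// which yields exactly the [C * kH * kW, N * outH * outW] column layout that
// the grouped GEMM in deform_conv_forward_cuda consumes.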

void deformable_im2col(
    const at::Tensor data_im, const at::Tensor data_offset, const int channels,
    const int height, const int width, const int ksize_h, const int ksize_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int parallel_imgs,
    const int deformable_group, at::Tensor data_col)
{
  // num_axes should be smaller than block size
  // todo: check parallel_imgs is correctly passed in
  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  int num_kernels = channels * height_col * width_col * parallel_imgs;
  int channel_per_deformable_group = channels / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();

        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
            channel_per_deformable_group, parallel_imgs, channels, deformable_group,
            height_col, width_col, data_col_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
  }
}
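
// AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the lambda for float,
// double and half with scalar_t bound accordingly, and the kernel is launched
// on the current PyTorch CUDA stream instead of the default stream.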

template <typename scalar_t>
__global__ void deformable_col2im_gpu_kernel(
    const int n, const scalar_t *data_col, const scalar_t *data_offset,
    const int channels, const int height, const int width,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size, const int deformable_group,
    const int height_col, const int width_col,
    scalar_t *grad_im)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    const int j = (index / width_col / height_col / batch_size) % kernel_w;
    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
    // compute the start and end of the output

    const int deformable_group_index = c / channel_per_deformable_group;

    int w_out = index % width_col;
    int h_out = (index / width_col) % height_col;
    int b = (index / width_col / height_col) % batch_size;
    int w_in = w_out * stride_w - pad_w;
    int h_in = h_out * stride_h - pad_h;

    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
                                                        2 * kernel_h * kernel_w * height_col * width_col;
    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;

    const scalar_t cur_top_grad = data_col[index];
    const int cur_h = (int)cur_inv_h_data;
    const int cur_w = (int)cur_inv_w_data;
    for (int dy = -2; dy <= 2; dy++)
    {
      for (int dx = -2; dx <= 2; dx++)
      {
        if (cur_h + dy >= 0 && cur_h + dy < height &&
            cur_w + dx >= 0 && cur_w + dx < width &&
            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
            abs(cur_inv_w_data - (cur_w + dx)) < 1)
        {
          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
          scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
        }
      }
    }
  }
}
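
// The 5x5 dy/dx window is conservative: the |coordinate - neighbour| < 1 test
// keeps only the (at most four) integer neighbours with non-zero bilinear
// weight, and atomicAdd makes the scatter safe when threads collide on the
// same grad_im element.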

void deformable_col2im(
    const at::Tensor data_col, const at::Tensor data_offset, const int channels,
    const int height, const int width, const int ksize_h,
    const int ksize_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int parallel_imgs, const int deformable_group,
    at::Tensor grad_im)
{
  // todo: make sure parallel_imgs is passed in correctly
  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
  int channel_per_deformable_group = channels / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();

        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
            ksize_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group,
            parallel_imgs, deformable_group, height_col, width_col, grad_im_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
  }
}

template <typename scalar_t>
__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
                                                   const scalar_t *data_im, const scalar_t *data_offset,
                                                   const int channels, const int height, const int width,
                                                   const int kernel_h, const int kernel_w,
                                                   const int pad_h, const int pad_w,
                                                   const int stride_h, const int stride_w,
                                                   const int dilation_h, const int dilation_w,
                                                   const int channel_per_deformable_group,
                                                   const int batch_size, const int offset_channels, const int deformable_group,
                                                   const int height_col, const int width_col, scalar_t *grad_offset)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    scalar_t val = 0;
    int w = index % width_col;
    int h = (index / width_col) % height_col;
    int c = (index / width_col / height_col) % offset_channels;
    int b = (index / width_col / height_col) / offset_channels;
    // compute the start and end of the output

    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
    const int col_step = kernel_h * kernel_w;
    int cnt = 0;
    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
                                                  batch_size * width_col * height_col;
    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
                                                channel_per_deformable_group / kernel_h / kernel_w * height * width;
    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
                                                        kernel_h * kernel_w * height_col * width_col;

    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;

    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
    {
      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
      const int bp_dir = offset_c % 2;

      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
      int w_out = col_pos % width_col;
      int h_out = (col_pos / width_col) % height_col;
      int w_in = w_out * stride_w - pad_w;
      int h_in = h_out * stride_h - pad_h;
      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
      scalar_t inv_h = h_in + i * dilation_h + offset_h;
      scalar_t inv_w = w_in + j * dilation_w + offset_w;
      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
      {
        inv_h = inv_w = -2;
      }
      const scalar_t weight = get_coordinate_weight(
          inv_h, inv_w,
          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
      val += weight * data_col_ptr[col_pos];
      cnt += 1;
    }

    grad_offset[index] = val;
  }
}

void deformable_col2im_coord(
    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
    const int channels, const int height, const int width, const int ksize_h,
    const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
    const int stride_w, const int dilation_h, const int dilation_w,
    const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
{
  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
  int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();

        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
            ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group,
            parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
            height_col, width_col, grad_offset_);
      }));
}

template <typename scalar_t>
__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
                                         const int height, const int width, scalar_t h, scalar_t w)
{
  int h_low = floor(h);
  int w_low = floor(w);
  int h_high = h_low + 1;
  int w_high = w_low + 1;

  scalar_t lh = h - h_low;
  scalar_t lw = w - w_low;
  scalar_t hh = 1 - lh, hw = 1 - lw;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
    v1 = bottom_data[h_low * data_width + w_low];
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
    v2 = bottom_data[h_low * data_width + w_high];
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
    v3 = bottom_data[h_high * data_width + w_low];
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
    v4 = bottom_data[h_high * data_width + w_high];

  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}

template <typename scalar_t>
__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
                                             const int h, const int w, const int height, const int width)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    // empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  scalar_t weight = 0;
  if (h == argmax_h_low && w == argmax_w_low)
    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
  if (h == argmax_h_low && w == argmax_w_high)
    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
  if (h == argmax_h_high && w == argmax_w_low)
    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
  if (h == argmax_h_high && w == argmax_w_high)
    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
  return weight;
}

template <typename scalar_t>
__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
                                               const int height, const int width, const scalar_t *im_data,
                                               const int data_width, const int bp_dir)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    // empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  scalar_t weight = 0;

  if (bp_dir == 0)
  {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }
  else if (bp_dir == 1)
  {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }

  return weight;
}
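
// The dmcn_* helpers above duplicate the plain deformable_* helpers earlier in
// this file; they are presumably kept separate so the modulated (DCNv2)
// kernels below can diverge without touching the DCNv1 path.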

template <typename scalar_t>
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
                                                       const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
                                                       const int height, const int width, const int kernel_h, const int kernel_w,
                                                       const int pad_h, const int pad_w,
                                                       const int stride_h, const int stride_w,
                                                       const int dilation_h, const int dilation_w,
                                                       const int channel_per_deformable_group,
                                                       const int batch_size, const int num_channels, const int deformable_group,
                                                       const int height_col, const int width_col,
                                                       scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // index of the output matrix
    const int w_col = index % width_col;
    const int h_col = (index / width_col) % height_col;
    const int b_col = (index / width_col / height_col) % batch_size;
    const int c_im = (index / width_col / height_col) / batch_size;
    const int c_col = c_im * kernel_h * kernel_w;

    // compute deformable group index
    const int deformable_group_index = c_im / channel_per_deformable_group;

    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;

    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;

    const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

    for (int i = 0; i < kernel_h; ++i)
    {
      for (int j = 0; j < kernel_w; ++j)
      {
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
        const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
        scalar_t val = static_cast<scalar_t>(0);
        const scalar_t h_im = h_in + i * dilation_h + offset_h;
        const scalar_t w_im = w_in + j * dilation_w + offset_w;
        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
        {
          //const float map_h = i * dilation_h + offset_h;
          //const float map_w = j * dilation_w + offset_w;
          //const int cur_height = height - h_in;
          //const int cur_width = width - w_in;
          //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
          val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
        }
        *data_col_ptr = val * mask;
        data_col_ptr += batch_size * height_col * width_col;
        //data_col_ptr += height_col * width_col;
      }
    }
  }
}
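
// Relative to deformable_im2col_gpu_kernel, the only functional change is the
// modulation: each bilinear sample is scaled by the learned scalar `mask`
// (DCNv2) before being written into the column buffer.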

template <typename scalar_t>
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
                                                       const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
                                                       const int channels, const int height, const int width,
                                                       const int kernel_h, const int kernel_w,
                                                       const int pad_h, const int pad_w,
                                                       const int stride_h, const int stride_w,
                                                       const int dilation_h, const int dilation_w,
                                                       const int channel_per_deformable_group,
                                                       const int batch_size, const int deformable_group,
                                                       const int height_col, const int width_col,
                                                       scalar_t *grad_im)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    const int j = (index / width_col / height_col / batch_size) % kernel_w;
    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
    // compute the start and end of the output

    const int deformable_group_index = c / channel_per_deformable_group;

    int w_out = index % width_col;
    int h_out = (index / width_col) % height_col;
    int b = (index / width_col / height_col) % batch_size;
    int w_in = w_out * stride_w - pad_w;
    int h_in = h_out * stride_h - pad_h;

    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
    const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;

    const scalar_t cur_top_grad = data_col[index] * mask;
    const int cur_h = (int)cur_inv_h_data;
    const int cur_w = (int)cur_inv_w_data;
    for (int dy = -2; dy <= 2; dy++)
    {
      for (int dx = -2; dx <= 2; dx++)
      {
        if (cur_h + dy >= 0 && cur_h + dy < height &&
            cur_w + dx >= 0 && cur_w + dx < width &&
            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
            abs(cur_inv_w_data - (cur_w + dx)) < 1)
        {
          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
          scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
        }
      }
    }
  }
}

template <typename scalar_t>
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
                                                             const scalar_t *data_col, const scalar_t *data_im,
                                                             const scalar_t *data_offset, const scalar_t *data_mask,
                                                             const int channels, const int height, const int width,
                                                             const int kernel_h, const int kernel_w,
                                                             const int pad_h, const int pad_w,
                                                             const int stride_h, const int stride_w,
                                                             const int dilation_h, const int dilation_w,
                                                             const int channel_per_deformable_group,
                                                             const int batch_size, const int offset_channels, const int deformable_group,
                                                             const int height_col, const int width_col,
                                                             scalar_t *grad_offset, scalar_t *grad_mask)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    scalar_t val = 0, mval = 0;
    int w = index % width_col;
    int h = (index / width_col) % height_col;
    int c = (index / width_col / height_col) % offset_channels;
    int b = (index / width_col / height_col) / offset_channels;
    // compute the start and end of the output

    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
    const int col_step = kernel_h * kernel_w;
    int cnt = 0;
    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;

    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
    {
      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
      const int bp_dir = offset_c % 2;

      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
      int w_out = col_pos % width_col;
      int h_out = (col_pos / width_col) % height_col;
      int w_in = w_out * stride_w - pad_w;
      int h_in = h_out * stride_h - pad_h;
      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
      const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
      scalar_t inv_h = h_in + i * dilation_h + offset_h;
      scalar_t inv_w = w_in + j * dilation_w + offset_w;
      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
      {
        inv_h = inv_w = -2;
      }
      else
      {
        mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
      }
      const scalar_t weight = dmcn_get_coordinate_weight(
          inv_h, inv_w,
          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
      val += weight * data_col_ptr[col_pos] * mask;
      cnt += 1;
    }
    // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
    grad_offset[index] = val;
    if (offset_c % 2 == 0)
      // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
|
||||
}
|
||||
}
|
||||
|
||||
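// Host-side launchers: each maps one CUDA thread per work item (num_kernels)
// and sizes the grid with GET_BLOCKS / CUDA_NUM_THREADS, checking for launch
// errors afterwards with cudaGetLastError.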
void modulated_deformable_im2col_cuda(
    const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group, at::Tensor data_col)
{
  // num_axes should be smaller than block size
  const int channel_per_deformable_group = channels / deformable_group;
  const int num_kernels = channels * batch_size * height_col * width_col;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();

        modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kernel_w,
            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
            batch_size, channels, deformable_group, height_col, width_col, data_col_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }
}

void modulated_deformable_col2im_cuda(
    const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group, at::Tensor grad_im)
{

  const int channel_per_deformable_group = channels / deformable_group;
  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();

        modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group,
            batch_size, deformable_group, height_col, width_col, grad_im_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }
}

void modulated_deformable_col2im_coord_cuda(
    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group,
    at::Tensor grad_offset, at::Tensor grad_mask)
{
  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
        scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();

        modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group,
            batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
            grad_offset_, grad_mask_);
      }));
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
  }
}

@ -0,0 +1,164 @@
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c

#include <torch/extension.h>
#include <ATen/DeviceGuard.h>

#include <cmath>
#include <vector>

#define WITH_CUDA // always use cuda
#ifdef WITH_CUDA
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
                             at::Tensor offset, at::Tensor output,
                             at::Tensor columns, at::Tensor ones, int kW,
                             int kH, int dW, int dH, int padW, int padH,
                             int dilationW, int dilationH, int group,
                             int deformable_group, int im2col_step);

int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
                                    at::Tensor gradOutput, at::Tensor gradInput,
                                    at::Tensor gradOffset, at::Tensor weight,
                                    at::Tensor columns, int kW, int kH, int dW,
                                    int dH, int padW, int padH, int dilationW,
                                    int dilationH, int group,
                                    int deformable_group, int im2col_step);

int deform_conv_backward_parameters_cuda(
    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
    at::Tensor gradWeight, // at::Tensor gradBias,
    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
    int padW, int padH, int dilationW, int dilationH, int group,
    int deformable_group, float scale, int im2col_step);

void modulated_deform_conv_cuda_forward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
    const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int group, const int deformable_group,
    const bool with_bias);

void modulated_deform_conv_cuda_backward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor columns,
    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias);
#endif

int deform_conv_forward(at::Tensor input, at::Tensor weight,
                        at::Tensor offset, at::Tensor output,
                        at::Tensor columns, at::Tensor ones, int kW,
                        int kH, int dW, int dH, int padW, int padH,
                        int dilationW, int dilationH, int group,
                        int deformable_group, int im2col_step) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return deform_conv_forward_cuda(input, weight, offset, output, columns,
        ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
        deformable_group, im2col_step);
#else
    AT_ERROR("deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("deform conv is not implemented on CPU");
}

int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
                               at::Tensor gradOutput, at::Tensor gradInput,
                               at::Tensor gradOffset, at::Tensor weight,
                               at::Tensor columns, int kW, int kH, int dW,
                               int dH, int padW, int padH, int dilationW,
                               int dilationH, int group,
                               int deformable_group, int im2col_step) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return deform_conv_backward_input_cuda(input, offset, gradOutput,
        gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
        dilationW, dilationH, group, deformable_group, im2col_step);
#else
    AT_ERROR("deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("deform conv is not implemented on CPU");
}

int deform_conv_backward_parameters(
    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
    at::Tensor gradWeight, // at::Tensor gradBias,
    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
    int padW, int padH, int dilationW, int dilationH, int group,
    int deformable_group, float scale, int im2col_step) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
        gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
        dilationH, group, deformable_group, scale, im2col_step);
#else
    AT_ERROR("deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("deform conv is not implemented on CPU");
}

void modulated_deform_conv_forward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
    const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int group, const int deformable_group,
    const bool with_bias) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
        offset, mask, output, columns, kernel_h, kernel_w, stride_h,
        stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
        deformable_group, with_bias);
#else
    AT_ERROR("modulated deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("modulated deform conv is not implemented on CPU");
}

void modulated_deform_conv_backward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor columns,
    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
        offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
        grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
        pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
        with_bias);
#else
    AT_ERROR("modulated deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("modulated deform conv is not implemented on CPU");
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("deform_conv_forward", &deform_conv_forward,
        "deform forward");
  m.def("deform_conv_backward_input", &deform_conv_backward_input,
        "deform_conv_backward_input");
  m.def("deform_conv_backward_parameters",
        &deform_conv_backward_parameters,
        "deform_conv_backward_parameters");
  m.def("modulated_deform_conv_forward",
        &modulated_deform_conv_forward,
        "modulated deform conv forward");
  m.def("modulated_deform_conv_backward",
        &modulated_deform_conv_backward,
        "modulated deform conv backward");
}

@ -0,0 +1,3 @@
from .fused_act import FusedLeakyReLU, fused_leaky_relu

__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']

@ -0,0 +1,95 @@
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501

import os
import torch
from torch import nn
from torch.autograd import Function

BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
    from torch.utils.cpp_extension import load
    module_path = os.path.dirname(__file__)
    fused_act_ext = load(
        'fused',
        sources=[
            os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
            os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
        ],
    )
else:
    try:
        from . import fused_act_ext
    except ImportError:
        pass
        # avoid annoying print output
        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
        #       '1. compile with BASICSR_EXT=True. or\n '
        #       '2. set BASICSR_JIT=True during running')


class FusedLeakyReLUFunctionBackward(Function):

    @staticmethod
    def forward(ctx, grad_output, out, negative_slope, scale):
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)

        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        grad_bias = grad_input.sum(dim).detach()

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
                                                    ctx.scale)

        return gradgrad_out, None, None, None


class FusedLeakyReLUFunction(Function):

    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)
        out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)

        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):

    def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
        super().__init__()

        self.bias = nn.Parameter(torch.zeros(channel))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
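A minimal usage sketch, assuming the CUDA extension compiled and an import path of basicsr.ops.fused_act (both are assumptions and may differ in this fork):

# FusedLeakyReLU fuses bias-add, leaky ReLU, and rescaling by `scale`
# into one kernel; the module form owns the per-channel bias parameter.
import torch
from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu

act = FusedLeakyReLU(channel=64).cuda()
x = torch.randn(2, 64, 32, 32, device='cuda')
y = act(x)                          # module form
y2 = fused_leaky_relu(x, act.bias)  # functional form, same computation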
@ -0,0 +1,26 @@
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
#include <torch/extension.h>


torch::Tensor fused_bias_act_op(const torch::Tensor& input,
                                const torch::Tensor& bias,
                                const torch::Tensor& refer,
                                int act, int grad, float alpha, float scale);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor& input,
                             const torch::Tensor& bias,
                             const torch::Tensor& refer,
                             int act, int grad, float alpha, float scale) {
  CHECK_CUDA(input);
  CHECK_CUDA(bias);

  return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}

@ -0,0 +1,100 @@
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

#include <cuda.h>
#include <cuda_runtime.h>


template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  scalar_t zero = 0.0;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      x += p_b[(xi / step_b) % size_b];
    }

    scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;

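    // act * 10 + grad selects the operation: act 1 = identity, act 3 =
    // leaky ReLU; grad 0 = forward value, grad 1 = first derivative (gated
    // by `ref`, the saved forward output), grad 2 = second derivative
    // (zero for both piecewise-linear activations).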
    switch (act * 10 + grad) {
      default:
      case 10: y = x; break;
      case 11: y = x; break;
      case 12: y = 0.0; break;

      case 30: y = (x > 0.0) ? x : x * alpha; break;
      case 31: y = (ref > 0.0) ? x : x * alpha; break;
      case 32: y = 0.0; break;
    }

    out[xi] = y * scale;
  }
}


torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  int use_bias = b.numel() ? 1 : 0;
  int use_ref = ref.numel() ? 1 : 0;

  int size_x = x.numel();
  int size_b = b.numel();
  int step_b = 1;

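  // step_b becomes the product of the dimensions after dim 1, so the bias
  // index (xi / step_b) % size_b selects the channel of element xi in an
  // NCHW layout.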
  for (int i = 1 + 1; i < x.dim(); i++) {
    step_b *= x.size(i);
  }

  int loop_x = 4;
  int block_size = 4 * 32;
  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  auto y = torch::empty_like(x);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
    fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
        y.data_ptr<scalar_t>(),
        x.data_ptr<scalar_t>(),
        b.data_ptr<scalar_t>(),
        ref.data_ptr<scalar_t>(),
        act,
        grad,
        alpha,
        scale,
        loop_x,
        size_x,
        step_b,
        size_b,
        use_bias,
        use_ref
    );
  });

  return y;
}

@ -0,0 +1,3 @@
from .upfirdn2d import upfirdn2d

__all__ = ['upfirdn2d']

@ -0,0 +1,24 @@
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
#include <torch/extension.h>


torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
                           int up_x, int up_y, int down_x, int down_y,
                           int pad_x0, int pad_x1, int pad_y0, int pad_y1);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y,
                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
  CHECK_CUDA(input);
  CHECK_CUDA(kernel);

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}

@ -0,0 +1,370 @@
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

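// Integer division rounding toward negative infinity (plain C++ `/`
// truncates toward zero); needed because padding can make upsampled
// coordinates negative.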
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  if (c * b > a) {
    c--;
  }

  return c;
}

struct UpFirDn2DKernelParams {
  int up_x;
  int up_y;
  int down_x;
  int down_y;
  int pad_x0;
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim;
  int in_h;
  int in_w;
  int minor_dim;
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major;
  int loop_x;
};

template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel,
                                       const UpFirDn2DKernelParams p) {
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h ||
      major_idx_base >= p.major_dim) {
    return;
  }

  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base;
         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                 minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }

        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
          minor_idx] = v;
    }
  }
}

template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

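  // Per-axis output length of upfirdn: upsample by `up`, pad, correlate with
  // a length-`kernel` filter, keep every `down`-th sample; the expression
  // equals (in * up + pad0 + pad1 - kernel) / down + 1.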
  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

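  // Mode dispatch: common (up, down, kernel size) combinations below get a
  // template-specialized tiled kernel that caches the filter taps and an
  // input tile in shared memory; anything else falls back to
  // upfirdn2d_kernel_large.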
  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
      case 1:
        upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                   x.data_ptr<scalar_t>(),
                                                   k.data_ptr<scalar_t>(), p);

        break;

      case 2:
        upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                   x.data_ptr<scalar_t>(),
                                                   k.data_ptr<scalar_t>(), p);

        break;

      case 3:
        upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                   x.data_ptr<scalar_t>(),
                                                   k.data_ptr<scalar_t>(), p);

        break;

      case 4:
        upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                   x.data_ptr<scalar_t>(),
                                                   k.data_ptr<scalar_t>(), p);

        break;

      case 5:
        upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                   x.data_ptr<scalar_t>(),
                                                   k.data_ptr<scalar_t>(), p);

        break;

      case 6:
        upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
            <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                   x.data_ptr<scalar_t>(),
                                                   k.data_ptr<scalar_t>(), p);

        break;

      default:
        upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
            out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}

@ -0,0 +1,192 @@
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501

import os
import torch
from torch.autograd import Function
from torch.nn import functional as F

BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
    from torch.utils.cpp_extension import load
    module_path = os.path.dirname(__file__)
    upfirdn2d_ext = load(
        'upfirdn2d',
        sources=[
            os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
            os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
        ],
    )
else:
    try:
        from . import upfirdn2d_ext
    except ImportError:
        pass
        # avoid annoying print output
        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
        #       '1. compile with BASICSR_EXT=True. or\n '
        #       '2. set BASICSR_JIT=True during running')


class UpFirDn2dBackward(Function):

    @staticmethod
    def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        grad_input = upfirdn2d_ext.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_ext.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
        #                                  ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])

        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):

    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        _, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            grad_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    if input.device.type == 'cpu':
        out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
    else:
        out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))

    return out


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)
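A minimal usage sketch, assuming an import path of basicsr.ops.upfirdn2d (an assumption; it may differ in this fork). upfirdn2d upsamples, applies a 2D FIR filter, then downsamples, the resampling primitive behind StyleGAN2's blur layers; on CPU tensors it takes the native PyTorch path, so no compiled extension is needed:

import torch
from basicsr.ops.upfirdn2d import upfirdn2d

x = torch.randn(1, 3, 64, 64)                  # NCHW input, CPU path
k = torch.tensor([1., 3., 3., 1.])
k = torch.outer(k, k)
k = k / k.sum()                                # normalized binomial blur kernel
y = upfirdn2d(x, k, up=2, down=1, pad=(1, 2))  # 2x upsample + blur -> (1, 3, 128, 128)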
@ -1,18 +1,19 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
from .diffjpeg import DiffJPEG
from .file_client import FileClient
from .img_process_util import USMSharp, usm_sharp
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding
from .logger import (MessageLogger, get_env_info, get_root_logger,
                     init_tb_logger, init_wandb_logger)
from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
                   scandir, scandir_SIDD, set_random_seed, sizeof_fmt)
from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k)
from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
from .options import yaml_load

__all__ = [
    # color_util.py
    'bgr2ycbcr',
    'rgb2ycbcr',
    'rgb2ycbcr_pt',
    'ycbcr2bgr',
    'ycbcr2rgb',
    # file_client.py
    'FileClient',
    # img_util.py

@ -23,6 +24,7 @@ __all__ = [
    'crop_border',
    # logger.py
    'MessageLogger',
    'AvgTimer',
    'init_tb_logger',
    'init_wandb_logger',
    'get_root_logger',

@ -33,11 +35,14 @@ __all__ = [
    'mkdir_and_rename',
    'make_exp_dirs',
    'scandir',
    'scandir_SIDD',
    'check_resume',
    'sizeof_fmt',
    'padding',
    'create_lmdb_for_reds',
    'create_lmdb_for_gopro',
    'create_lmdb_for_rain13k',
    # diffjpeg
    'DiffJPEG',
    # img_process_util
    'USMSharp',
    'usm_sharp',
    # options
    'yaml_load',
    'padding'
]

@ -5,8 +5,11 @@
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import math
import os
import requests
from torch.hub import download_url_to_file, get_dir
from tqdm import tqdm
from urllib.parse import urlparse

from .misc import sizeof_fmt


@ -14,8 +17,7 @@ from .misc import sizeof_fmt
def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
    Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive

    Args:
        file_id (str): File id.

@ -33,11 +35,9 @@ def download_file_from_google_drive(file_id, save_path):
    response = session.get(URL, params=params, stream=True)

    # get file size
    response_file_size = session.get(
        URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    if 'Content-Range' in response_file_size.headers:
        file_size = int(
            response_file_size.headers['Content-Range'].split('/')[1])
        file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
    else:
        file_size = None


@ -51,10 +51,7 @@ def get_confirm_token(response):
    return None


def save_response_content(response,
                          destination,
                          file_size=None,
                          chunk_size=32768):
def save_response_content(response, destination, file_size=None, chunk_size=32768):
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')


@ -68,9 +65,40 @@
            downloaded_size += chunk_size
            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
                                     f'/ {readable_file_size}')
                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
        if pbar is not None:
            pbar.close()


def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, will download models if necessary.

    Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file

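A one-line usage sketch with a hypothetical URL (not a real checkpoint location):

model_path = load_file_from_url('https://example.com/models/NAFNet-width64.pth')  # cached under hub_dir/checkpoints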
@ -42,13 +42,11 @@ class MemcachedBackend(BaseStorageBackend):
        try:
            import mc
        except ImportError:
            raise ImportError(
                'Please install memcached to enable MemcachedBackend.')
            raise ImportError('Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
                                                      self.client_cfg)
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
        # mc.pyvector servers as a point which points to a memory cache
        self._mc_buffer = mc.pyvector()


@ -99,13 +97,7 @@ class LmdbBackend(BaseStorageBackend):
        _client (list): A list of several lmdb envs.
    """

    def __init__(self,
                 db_paths,
                 client_keys='default',
                 readonly=True,
                 lock=False,
                 readahead=False,
                 **kwargs):
    def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
        try:
            import lmdb
        except ImportError:

@ -118,32 +110,22 @@ class LmdbBackend(BaseStorageBackend):
            self.db_paths = [str(v) for v in db_paths]
        elif isinstance(db_paths, str):
            self.db_paths = [str(db_paths)]
        assert len(client_keys) == len(self.db_paths), (
            'client_keys and db_paths should have the same length, '
            f'but received {len(client_keys)} and {len(self.db_paths)}.')
        assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
                                                        f'but received {len(client_keys)} and {len(self.db_paths)}.')

        self._client = {}

        for client, path in zip(client_keys, self.db_paths):
            self._client[client] = lmdb.open(
                path,
                readonly=readonly,
                lock=lock,
                readahead=readahead,
                map_size=8*1024*10485760,
                # max_readers=1,
                **kwargs)
            self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)

    def get(self, filepath, client_key):
        """Get values according to the filepath from one lmdb named client_key.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
            client_key (str): Used for distinguishing differnet lmdb envs.
            client_key (str): Used for distinguishing different lmdb envs.
        """
        filepath = str(filepath)
        assert client_key in self._client, (f'client_key {client_key} is not '
                                            'in lmdb clients.')
        assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
        client = self._client[client_key]
        with client.begin(write=False) as txn:
            value_buf = txn.get(filepath.encode('ascii'))


@ -174,9 +156,8 @@ class FileClient(object):

    def __init__(self, backend='disk', **kwargs):
        if backend not in self._backends:
            raise ValueError(
                f'Backend {backend} is not supported. Currently supported ones'
                f' are {list(self._backends.keys())}')
            raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
                             f' are {list(self._backends.keys())}')
        self.backend = backend
        self.client = self._backends[backend](**kwargs)

@ -27,8 +27,7 @@ def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
|
|||
assert concat_axis in [0, 1]
|
||||
        cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
        if cat_flow.ndim != 2:
            raise IOError(f'{flow_path} is not a valid quantized flow file, '
                          f'its dimension is {cat_flow.ndim}.')
            raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
        assert cat_flow.shape[concat_axis] % 2 == 0
        dx, dy = np.split(cat_flow, 2, axis=concat_axis)
        flow = dequantize_flow(dx, dy, *args, **kwargs)

@ -40,8 +39,7 @@ def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
            raise IOError(f'Invalid flow file: {flow_path}')
        else:
            if header != 'PIEH':
                raise IOError(f'Invalid flow file: {flow_path}, '
                              'header does not contain PIEH')
                raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH')

            w = np.fromfile(f, np.int32, 1).squeeze()
            h = np.fromfile(f, np.int32, 1).squeeze()

@ -77,8 +75,8 @@ def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
        assert concat_axis in [0, 1]
        dx, dy = quantize_flow(flow, *args, **kwargs)
        dxdy = np.concatenate((dx, dy), axis=concat_axis)
        os.makedirs(filename, exist_ok=True)
        cv2.imwrite(dxdy, filename)
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        cv2.imwrite(filename, dxdy)


def quantize_flow(flow, max_val=0.02, norm=True):

@ -103,9 +101,7 @@ def quantize_flow(flow, max_val=0.02, norm=True):
        dx = dx / w  # avoid inplace operations
        dy = dy / h
    # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
    flow_comps = [
        quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
    ]
    flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
    return tuple(flow_comps)


@ -147,15 +143,12 @@ def quantize(arr, min_val, max_val, levels, dtype=np.int64):
        tuple: Quantized array.
    """
    if not (isinstance(levels, int) and levels > 1):
        raise ValueError(
            f'levels must be a positive integer, but got {levels}')
        raise ValueError(f'levels must be a positive integer, but got {levels}')
    if min_val >= max_val:
        raise ValueError(
            f'min_val ({min_val}) must be smaller than max_val ({max_val})')
        raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    arr = np.clip(arr, min_val, max_val) - min_val
    quantized_arr = np.minimum(
        np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
    quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)

    return quantized_arr


@ -174,13 +167,10 @@ def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
        tuple: Dequantized array.
    """
    if not (isinstance(levels, int) and levels > 1):
        raise ValueError(
            f'levels must be a positive integer, but got {levels}')
        raise ValueError(f'levels must be a positive integer, but got {levels}')
    if min_val >= max_val:
        raise ValueError(
            f'min_val ({min_val}) must be smaller than max_val ({max_val})')
        raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
                                                   min_val) / levels + min_val
    dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val

    return dequantized_arr
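Note: the quantize/dequantize pair above buckets values in [min_val, max_val] into `levels` bins and maps each bin index back to its bin centre (the `+ 0.5`); using 255 (odd) levels keeps an exact zero after the round trip. A minimal round-trip sketch, mirroring the two functions above in standalone NumPy rather than importing them:

import numpy as np

def quantize(arr, min_val, max_val, levels, dtype=np.int64):
    # clip to [min_val, max_val], then bucket into `levels` bins
    arr = np.clip(arr, min_val, max_val) - min_val
    return np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)

def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
    # map each bin index back to its bin centre
    return (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val

dx = np.array([-0.02, 0.0, 0.013])
q = quantize(dx, -0.02, 0.02, 255, np.uint8)
r = dequantize(q, -0.02, 0.02, 255)
assert r[1] == 0.0  # zero survives exactly with 255 levels
assert np.all(np.abs(r - dx) <= 0.04 / 255)  # error bounded by one bin width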
@ -27,6 +27,8 @@ def img2tensor(imgs, bgr2rgb=True, float32=True):

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:

@ -60,11 +62,8 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
        (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
        shape (H x W). The channel order is BGR.
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list)
             and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(
            f'tensor or list of tensors expected, got {type(tensor)}')
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]

@ -75,9 +74,7 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):

        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(
                _tensor, nrow=int(math.sqrt(_tensor.size(0))),
                normalize=False).numpy()
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

@ -86,14 +83,13 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            elif img_np.shape[2] == 3:
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError('Only support 4D, 3D or 2D tensor. '
                            f'But received with dimension: {n_dim}')
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()

@ -104,6 +100,23 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    return result


def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
    """This implementation is slightly faster than tensor2img.
    It now only supports torch tensor with shape (1, c, h, w).

    Args:
        tensor (Tensor): Now only support torch tensor with (1, c, h, w).
        rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
        min_max (tuple[int]): min and max values for clamp.
    """
    output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
    output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
    output = output.type(torch.uint8).cpu().numpy()
    if rgb2bgr:
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
    return output


def imfrombytes(content, flag='color', float32=False):
    """Read an image from bytes.


@ -118,31 +131,12 @@ def imfrombytes(content, flag='color', float32=False):
        ndarray: Loaded image array.
    """
    img_np = np.frombuffer(content, np.uint8)
    imread_flags = {
        'color': cv2.IMREAD_COLOR,
        'grayscale': cv2.IMREAD_GRAYSCALE,
        'unchanged': cv2.IMREAD_UNCHANGED
    }
    if img_np is None:
        raise Exception('None .. !!!')
    imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
    img = cv2.imdecode(img_np, imread_flags[flag])
    if float32:
        img = img.astype(np.float32) / 255.
    return img

def padding(img_lq, img_gt, gt_size):
    h, w, _ = img_lq.shape

    h_pad = max(0, gt_size - h)
    w_pad = max(0, gt_size - w)

    if h_pad == 0 and w_pad == 0:
        return img_lq, img_gt

    img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
    img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
    # print('img_lq', img_lq.shape, img_gt.shape)
    return img_lq, img_gt

def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write image to file.

@ -160,7 +154,9 @@ def imwrite(img, file_path, params=None, auto_mkdir=True):
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    return cv2.imwrite(file_path, img, params)
    ok = cv2.imwrite(file_path, img, params)
    if not ok:
        raise IOError('Failed in writing images.')


def crop_border(imgs, crop_border):

@ -177,10 +173,21 @@ def crop_border(imgs, crop_border):
        return imgs
    else:
        if isinstance(imgs, list):
            return [
                v[crop_border:-crop_border, crop_border:-crop_border, ...]
                for v in imgs
            ]
            return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
        else:
            return imgs[crop_border:-crop_border, crop_border:-crop_border,
                        ...]
            return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]


def padding(img_lq, img_gt, gt_size):
    h, w, _ = img_lq.shape

    h_pad = max(0, gt_size - h)
    w_pad = max(0, gt_size - w)

    if h_pad == 0 and w_pad == 0:
        return img_lq, img_gt

    img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
    img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
    # print('img_lq', img_lq.shape, img_gt.shape)
    return img_lq, img_gt
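Note: img2tensor and tensor2img are inverses up to uint8 quantization (HWC BGR -> CHW RGB float tensor -> HWC BGR uint8). A small round-trip sketch, assuming the two functions live in basicsr.utils.img_util as in upstream BasicSR:

import numpy as np
from basicsr.utils.img_util import img2tensor, tensor2img

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)          # HWC, BGR
t = img2tensor(img.astype(np.float32) / 255., bgr2rgb=True)           # (3, 32, 32), RGB, [0, 1]
back = tensor2img(t, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)) # (32, 32, 3), BGR
assert np.array_equal(back, img)  # /255. then round(*255) restores uint8 exactly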
@ -24,10 +24,13 @@ def make_lmdb_from_imgs(data_path,
    """Make lmdb from images.

    Contents of lmdb. The file structure is:
        example.lmdb
        ├── data.mdb
        ├── lock.mdb
        ├── meta_info.txt

    ::

        example.lmdb
        ├── data.mdb
        ├── lock.mdb
        ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

@ -64,11 +67,10 @@ def make_lmdb_from_imgs(data_path,
        estimated size from images. Default: None
    """

    assert len(img_path_list) == len(keys), (
        'img_path_list and keys should have the same length, '
        f'but got {len(img_path_list)} and {len(keys)}')
    assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
                                             f'but got {len(img_path_list)} and {len(keys)}')
    print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
    print(f'Total images: {len(img_path_list)}')
    print(f'Totoal images: {len(img_path_list)}')
    if not lmdb_path.endswith('.lmdb'):
        raise ValueError("lmdb_path must end with '.lmdb'.")
    if osp.exists(lmdb_path):

@ -90,10 +92,7 @@ def make_lmdb_from_imgs(data_path,

    pool = Pool(n_thread)
    for path, key in zip(img_path_list, keys):
        pool.apply_async(
            read_img_worker,
            args=(osp.join(data_path, path), key, compress_level),
            callback=callback)
        pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
    pool.close()
    pool.join()
    pbar.close()

@ -102,10 +101,8 @@ def make_lmdb_from_imgs(data_path,
    # create lmdb environment
    if map_size is None:
        # obtain data size for one image
        img = cv2.imread(
            osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
        _, img_byte = cv2.imencode(
            '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
        img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
        data_size_per_img = img_byte.nbytes
        print('Data size per image is: ', data_size_per_img)
        data_size = data_size_per_img * len(img_path_list)

@ -125,8 +122,7 @@ def make_lmdb_from_imgs(data_path,
        img_byte = dataset[key]
        h, w, c = shapes[key]
    else:
        _, img_byte, img_shape = read_img_worker(
            osp.join(data_path, path), key, compress_level)
        _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
        h, w, c = img_shape

    txn.put(key_byte, img_byte)

@ -162,8 +158,7 @@ def read_img_worker(path, key, compress_level):
        c = 1
    else:
        h, w, c = img.shape
    _, img_byte = cv2.imencode('.png', img,
                               [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
    return (key, img_byte, (h, w, c))


@ -178,11 +173,7 @@ class LmdbMaker():
    compress_level (int): Compress level when encoding images. Default: 1.
    """

    def __init__(self,
                 lmdb_path,
                 map_size=1024**4,
                 batch=5000,
                 compress_level=1):
    def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
        if not lmdb_path.endswith('.lmdb'):
            raise ValueError("lmdb_path must end with '.lmdb'.")
        if osp.exists(lmdb_path):
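Note: the pattern above (PNG-encode each image, estimate map_size from one sample, then txn.put the bytes under an ASCII key) can be reproduced with the lmdb package directly. A minimal sketch; the file names and key are illustrative, not part of the repo:

import cv2
import lmdb

imgs = {'sample_s001': cv2.imread('sample.png', cv2.IMREAD_UNCHANGED)}  # hypothetical input
_, probe = cv2.imencode('.png', next(iter(imgs.values())), [cv2.IMWRITE_PNG_COMPRESSION, 1])
env = lmdb.open('demo.lmdb', map_size=probe.nbytes * len(imgs) * 10)  # generous headroom

with env.begin(write=True) as txn:
    for key, img in imgs.items():
        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, 1])
        txn.put(key.encode('ascii'), img_byte.tobytes())  # same key/value layout as make_lmdb_from_imgs
env.close()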
@ -10,6 +10,43 @@ import time

from .dist_util import get_dist_info, master_only

initialized_logger = {}


class AvgTimer():

    def __init__(self, window=200):
        self.window = window  # average window
        self.current_time = 0
        self.total_time = 0
        self.count = 0
        self.avg_time = 0
        self.start()

    def start(self):
        self.start_time = self.tic = time.time()

    def record(self):
        self.count += 1
        self.toc = time.time()
        self.current_time = self.toc - self.tic
        self.total_time += self.current_time
        # calculate average time
        self.avg_time = self.total_time / self.count

        # reset
        if self.count > self.window:
            self.count = 0
            self.total_time = 0

        self.tic = time.time()

    def get_current_time(self):
        return self.current_time

    def get_avg_time(self):
        return self.avg_time

|
||||
"""Message logger for printing.
|
||||
|
@ -34,6 +71,9 @@ class MessageLogger():
|
|||
self.start_time = time.time()
|
||||
self.logger = get_root_logger()
|
||||
|
||||
def reset_start_time(self):
|
||||
self.start_time = time.time()
|
||||
|
||||
@master_only
|
||||
def __call__(self, log_vars):
|
||||
"""Format logging message.
|
||||
|
@ -50,11 +90,9 @@ class MessageLogger():
|
|||
# epoch, iter, learning rates
|
||||
epoch = log_vars.pop('epoch')
|
||||
current_iter = log_vars.pop('iter')
|
||||
total_iter = log_vars.pop('total_iter')
|
||||
lrs = log_vars.pop('lrs')
|
||||
|
||||
message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
|
||||
f'iter:{current_iter:8,d}, lr:(')
|
||||
message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
|
||||
for v in lrs:
|
||||
message += f'{v:.3e},'
|
||||
message += ')] '
|
||||
|
@ -76,17 +114,10 @@ class MessageLogger():
|
|||
message += f'{k}: {v:.4e} '
|
||||
# tensorboard logger
|
||||
if self.use_tb_logger and 'debug' not in self.exp_name:
|
||||
normed_step = 10000 * (current_iter / total_iter)
|
||||
normed_step = int(normed_step)
|
||||
|
||||
if k.startswith('l_'):
|
||||
self.tb_logger.add_scalar(f'losses/{k}', v, normed_step)
|
||||
elif k.startswith('m_'):
|
||||
self.tb_logger.add_scalar(f'metrics/{k}', v, normed_step)
|
||||
self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
|
||||
else:
|
||||
assert 1 == 0
|
||||
# else:
|
||||
# self.tb_logger.add_scalar(k, v, current_iter)
|
||||
self.tb_logger.add_scalar(k, v, current_iter)
|
||||
self.logger.info(message)
|
||||
|
||||
|
||||
|
@ -101,7 +132,7 @@ def init_tb_logger(log_dir):
|
|||
def init_wandb_logger(opt):
|
||||
"""We now only use wandb to sync tensorboard log."""
|
||||
import wandb
|
||||
logger = logging.getLogger('basicsr')
|
||||
logger = get_root_logger()
|
||||
|
||||
project = opt['logger']['wandb']['project']
|
||||
resume_id = opt['logger']['wandb'].get('resume_id')
|
||||
|
@ -113,20 +144,12 @@ def init_wandb_logger(opt):
|
|||
wandb_id = wandb.util.generate_id()
|
||||
resume = 'never'
|
||||
|
||||
wandb.init(
|
||||
id=wandb_id,
|
||||
resume=resume,
|
||||
name=opt['name'],
|
||||
config=opt,
|
||||
project=project,
|
||||
sync_tensorboard=True)
|
||||
wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
|
||||
|
||||
logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
|
||||
|
||||
|
||||
def get_root_logger(logger_name='basicsr',
|
||||
log_level=logging.INFO,
|
||||
log_file=None):
|
||||
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
|
||||
"""Get the root logger.
|
||||
|
||||
The logger will be initialized if it has not been initialized. By default a
|
||||
|
@ -146,20 +169,25 @@ def get_root_logger(logger_name='basicsr',
|
|||
"""
|
||||
logger = logging.getLogger(logger_name)
|
||||
# if the logger has been initialized, just return it
|
||||
if logger.hasHandlers():
|
||||
if logger_name in initialized_logger:
|
||||
return logger
|
||||
|
||||
format_str = '%(asctime)s %(levelname)s: %(message)s'
|
||||
logging.basicConfig(format=format_str, level=log_level)
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(logging.Formatter(format_str))
|
||||
logger.addHandler(stream_handler)
|
||||
logger.propagate = False
|
||||
rank, _ = get_dist_info()
|
||||
if rank != 0:
|
||||
logger.setLevel('ERROR')
|
||||
elif log_file is not None:
|
||||
logger.setLevel(log_level)
|
||||
# add file handler
|
||||
file_handler = logging.FileHandler(log_file, 'w')
|
||||
file_handler.setFormatter(logging.Formatter(format_str))
|
||||
file_handler.setLevel(log_level)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
initialized_logger[logger_name] = True
|
||||
return logger
|
||||
|
||||
|
||||
|
|
|
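Note: get_root_logger now tracks initialization in the module-level initialized_logger dict instead of probing handler state, so repeated calls are cheap and idempotent. A usage sketch (import path as in upstream BasicSR; the log file name is illustrative):

import logging
from basicsr.utils.logger import get_root_logger

logger = get_root_logger(log_level=logging.INFO, log_file='train.log')  # first call attaches handlers
logger.info('training started')
same = get_root_logger()  # later calls return the already-initialized logger
assert same is logger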
@ -15,13 +15,11 @@ def cubic(x):
    absx2 = absx**2
    absx3 = absx**3
    return (1.5 * absx3 - 2.5 * absx2 + 1) * (
        (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx +
                                      2) * (((absx > 1) *
                                             (absx <= 2)).type_as(absx))
    return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
                                                                                                                           (absx <= 2)).type_as(absx))


def calculate_weights_indices(in_length, out_length, scale, kernel,
                              kernel_width, antialiasing):
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Calculate weights and indices, used for imresize function.

    Args:

@ -56,8 +54,8 @@ def calculate_weights_indices(in_length, out_length, scale, kernel,

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(
        0, p - 1, p).view(1, p).expand(out_length, p)
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
        out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.

@ -109,11 +107,18 @@ def imresize(img, scale, antialiasing=True):
    Returns:
        Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
    """
    squeeze_flag = False
    if type(img).__module__ == np.__name__:  # numpy type
        numpy_type = True
        if img.ndim == 2:
            img = img[:, :, None]
            squeeze_flag = True
        img = torch.from_numpy(img.transpose(2, 0, 1)).float()
    else:
        numpy_type = False
        if img.ndim == 2:
            img = img.unsqueeze(0)
            squeeze_flag = True

    in_c, in_h, in_w = img.size()
    out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)

@ -121,10 +126,10 @@ def imresize(img, scale, antialiasing=True):
    kernel = 'cubic'

    # get weights and indices
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(
        in_h, out_h, scale, kernel, kernel_width, antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(
        in_w, out_w, scale, kernel, kernel_width, antialiasing)
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
                                                                             antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
                                                                             antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)

@ -145,8 +150,7 @@ def imresize(img, scale, antialiasing=True):
    for i in range(out_h):
        idx = int(indices_h[i][0])
        for j in range(in_c):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(
                0, 1).mv(weights_h[i])
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])

    # process W dimension
    # symmetric copying

@ -168,200 +172,13 @@ def imresize(img, scale, antialiasing=True):
    for i in range(out_w):
        idx = int(indices_w[i][0])
        for j in range(in_c):
            out_2[j, :, i] = out_1_aug[j, :,
                                       idx:idx + kernel_width].mv(weights_w[i])
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])

    if squeeze_flag:
        out_2 = out_2.squeeze(0)
    if numpy_type:
        out_2 = out_2.numpy().transpose(1, 2, 0)
        out_2 = out_2.numpy()
        if not squeeze_flag:
            out_2 = out_2.transpose(1, 2, 0)

    return out_2
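Note: imresize now accepts both numpy arrays and tensors, and handles 2D grayscale inputs via the new squeeze_flag path. A usage sketch, assuming the module path basicsr.utils.matlab_functions as in upstream BasicSR:

import numpy as np
from basicsr.utils.matlab_functions import imresize

img = np.random.rand(64, 48, 3).astype(np.float32)   # HWC, values in [0, 1]
small = imresize(img, scale=0.5, antialiasing=True)  # -> (32, 24, 3) numpy array
gray = np.random.rand(64, 48).astype(np.float32)
small_gray = imresize(gray, scale=0.5)               # 2D in, 2D out via squeeze_flag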
def rgb2ycbcr(img, y_only=False):
    """Convert a RGB image to YCbCr image.

    This function produces the same results as Matlab's `rgb2ycbcr` function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
        and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
    else:
        out_img = np.matmul(
            img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                  [24.966, 112.0, -18.214]]) + [16, 128, 128]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr image.

    The bgr version of rgb2ycbcr.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
        and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
    else:
        out_img = np.matmul(
            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                  [65.481, -37.797, 112.0]]) + [16, 128, 128]
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def ycbcr2rgb(img):
    """Convert a YCbCr image to RGB image.

    This function produces the same results as Matlab's ycbcr2rgb function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted RGB image. The output image has the same type
        and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
                              [0, -0.00153632, 0.00791071],
                              [0.00625893, -0.00318811, 0]]) * 255.0 + [
                                  -222.921, 135.576, -276.836
                              ]  # noqa: E126
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def ycbcr2bgr(img):
    """Convert a YCbCr image to BGR image.

    The bgr version of ycbcr2rgb.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted BGR image. The output image has the same type
        and range as input image.
    """
    img_type = img.dtype
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
                              [0.00791071, -0.00153632, 0],
                              [0, -0.00318811, 0.00625893]]) * 255.0 + [
                                  -276.836, 135.576, -222.921
                              ]  # noqa: E126
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img


def _convert_input_type_range(img):
    """Convert the type and range of the input image.

    It converts the input image to np.float32 type and range of [0, 1].
    It is mainly used for pre-processing the input image in colorspace
    convertion functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        (ndarray): The converted image with type of np.float32 and range of
        [0, 1].
    """
    img_type = img.dtype
    img = img.astype(np.float32)
    if img_type == np.float32:
        pass
    elif img_type == np.uint8:
        img /= 255.
    else:
        raise TypeError('The img type should be np.float32 or np.uint8, '
                        f'but got {img_type}')
    return img


def _convert_output_type_range(img, dst_type):
    """Convert the type and range of the image according to dst_type.

    It converts the image to desired type and range. If `dst_type` is np.uint8,
    images will be converted to np.uint8 type with range [0, 255]. If
    `dst_type` is np.float32, it converts the image to np.float32 type with
    range [0, 1].
    It is mainly used for post-processing images in colorspace convertion
    functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The image to be converted with np.float32 type and
            range [0, 255].
        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
            converts the image to np.uint8 type with range [0, 255]. If
            dst_type is np.float32, it converts the image to np.float32 type
            with range [0, 1].

    Returns:
        (ndarray): The converted image with desired type and range.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError('The dst_type should be np.float32 or np.uint8, '
                        f'but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img /= 255.
    return img.astype(dst_type)
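Note: the BT.601 coefficients above give Y = 65.481 R + 128.553 G + 24.966 B + 16 for inputs scaled to [0, 1]. A quick numeric check, assuming basicsr.utils.matlab_functions still exports rgb2ycbcr at this revision:

import numpy as np
from basicsr.utils.matlab_functions import rgb2ycbcr

white = np.ones((1, 1, 3), dtype=np.float32)  # RGB = (1, 1, 1)
y = rgb2ycbcr(white, y_only=True)
# 65.481 + 128.553 + 24.966 + 16 = 235, scaled back to [0, 1] for float input
assert np.isclose(y, 235.0 / 255.0, atol=1e-4).all()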
@ -12,7 +12,6 @@ import torch
from os import path as osp

from .dist_util import master_only
from .logger import get_root_logger


def set_random_seed(seed):

@ -50,9 +49,9 @@ def make_exp_dirs(opt):
    else:
        mkdir_and_rename(path_opt.pop('results_root'))
    for key, path in path_opt.items():
        if ('strict_load' not in key) and ('pretrain_network'
                                           not in key) and ('resume'
                                                            not in key):
        if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key):
            continue
        else:
            os.makedirs(path, exist_ok=True)


@ -69,7 +68,7 @@ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
        Default: False.

    Returns:
        A generator for all the interested files with relative pathes.
        A generator for all the interested files with relative paths.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):

@ -91,54 +90,12 @@ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
                yield return_path
            else:
                if recursive:
                    yield from _scandir(
                        entry.path, suffix=suffix, recursive=recursive)
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, suffix=suffix, recursive=recursive)

def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str): Path of the directory.
        keywords (str | tuple(str), optional): File keywords that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative pathes.
    """

    if (keywords is not None) and not isinstance(keywords, (str, tuple)):
        raise TypeError('"keywords" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, keywords, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if keywords is None:
                    yield return_path
                elif return_path.find(keywords) > 0:
                    yield return_path
            else:
                if recursive:
                    yield from _scandir(
                        entry.path, keywords=keywords, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, keywords=keywords, recursive=recursive)

def check_resume(opt, resume_iter):
    """Check resume states and pretrain_network paths.

@ -147,7 +104,6 @@ def check_resume(opt, resume_iter):
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()
    if opt['path']['resume_state']:
        # get all the networks
        networks = [key for key in opt.keys() if key.startswith('network_')]

@ -156,17 +112,22 @@ def check_resume(opt, resume_iter):
            if opt['path'].get(f'pretrain_{network}') is not None:
                flag_pretrain = True
        if flag_pretrain:
            logger.warning(
                'pretrain_network path will be ignored during resuming.')
            print('pretrain_network path will be ignored during resuming.')
        # set pretrained model paths
        for network in networks:
            name = f'pretrain_{network}'
            basename = network.replace('network_', '')
            if opt['path'].get('ignore_resume_networks') is None or (
                    basename not in opt['path']['ignore_resume_networks']):
                opt['path'][name] = osp.join(
                    opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
                logger.info(f"Set {name} to {opt['path'][name]}")
            if opt['path'].get('ignore_resume_networks') is None or (network
                                                                     not in opt['path']['ignore_resume_networks']):
                opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
                print(f"Set {name} to {opt['path'][name]}")

        # change param_key to params in resume
        param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
        for param_key in param_keys:
            if opt['path'][param_key] == 'params_ema':
                opt['path'][param_key] = 'params'
                print(f'Set {param_key} to params')


def sizeof_fmt(size, suffix='B'):

@ -177,7 +138,7 @@ def sizeof_fmt(size, suffix='B'):
        suffix (str): Suffix. Default: 'B'.

    Return:
        str: Formated file siz.
        str: Formatted file size.
    """
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        if abs(size) < 1024.0:
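Note: scandir returns a lazy generator of paths relative to dir_path, optionally filtered by suffix and recursing into subdirectories. A usage sketch (the directory name is illustrative; the import path assumes basicsr/utils/misc.py as in upstream BasicSR):

from basicsr.utils.misc import scandir

# collect all PNGs under ./datasets/demo, as paths relative to that root
paths = sorted(scandir('./datasets/demo', suffix='.png', recursive=True))
for p in paths:
    print(p)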
@ -4,16 +4,23 @@
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import argparse
import os
import random
import torch
import yaml
from collections import OrderedDict
from os import path as osp

from basicsr.utils import set_random_seed
from basicsr.utils.dist_util import get_dist_info, init_dist, master_only


def ordered_yaml():
    """Support OrderedDict for yaml.

    Returns:
        yaml Loader and Dumper.
        tuple: yaml Loader and Dumper.
    """
    try:
        from yaml import CDumper as Dumper

@ -34,6 +41,189 @@ def ordered_yaml():
    return Loader, Dumper


def yaml_load(f):
    """Load yaml file or string.

    Args:
        f (str): File path or a python string.

    Returns:
        dict: Loaded dict.
    """
    if os.path.isfile(f):
        with open(f, 'r') as f:
            return yaml.load(f, Loader=ordered_yaml()[0])
    else:
        return yaml.load(f, Loader=ordered_yaml()[0])

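Note: yaml_load accepts either a file path or a raw YAML string, so option snippets can be parsed inline. A sketch (the YAML content is illustrative; the import path assumes basicsr/utils/options.py as in upstream BasicSR):

from basicsr.utils.options import yaml_load

opt = yaml_load('name: NAFNet-width64\nnum_gpu: 1\nmanual_seed: 10')
print(opt['name'])  # -> NAFNet-width64
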
def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Return:
        (str): Option string for printing.
    """
    msg = '\n'
    for k, v in opt.items():
        if isinstance(v, dict):
            msg += ' ' * (indent_level * 2) + k + ':['
            msg += dict2str(v, indent_level + 1)
            msg += ' ' * (indent_level * 2) + ']\n'
        else:
            msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
    return msg


def _postprocess_yml_value(value):
    # None
    if value == '~' or value.lower() == 'none':
        return None
    # bool
    if value.lower() == 'true':
        return True
    elif value.lower() == 'false':
        return False
    # !!float number
    if value.startswith('!!float'):
        return float(value.replace('!!float', ''))
    # number
    if value.isdigit():
        return int(value)
    elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
        return float(value)
    # list
    if value.startswith('['):
        return eval(value)
    # str
    return value

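Note: _postprocess_yml_value re-types the string that arrives from --force_yml: None markers, booleans, !!float literals, ints, floats, lists, and finally plain strings. A few doctest-style inputs, given the function above in scope:

assert _postprocess_yml_value('~') is None
assert _postprocess_yml_value('true') is True
assert _postprocess_yml_value('!!float 1e-3') == 1e-3
assert _postprocess_yml_value('250000') == 250000
assert _postprocess_yml_value('0.999') == 0.999
assert _postprocess_yml_value('[1, 1, 1, 28]') == [1, 1, 1, 28]
assert _postprocess_yml_value('CosineAnnealingLR') == 'CosineAnnealingLR'
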
def parse_options(root_path, is_train=True):
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
    args = parser.parse_args()

    # parse yml to dict
    opt = yaml_load(args.opt)

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    # force to update yml options
    if args.force_yml is not None:
        for entry in args.force_yml:
            # now do not support creating new keys
            keys, value = entry.split('=')
            keys, value = keys.strip(), value.strip()
            value = _postprocess_yml_value(value)
            eval_str = 'opt'
            for key in keys.split(':'):
                eval_str += f'["{key}"]'
            eval_str += '=value'
            # using exec function
            exec(eval_str)

    opt['auto_resume'] = args.auto_resume
    opt['is_train'] = is_train

    # debug setting
    if args.debug and not opt['name'].startswith('debug'):
        opt['name'] = 'debug_' + opt['name']

    if opt['num_gpu'] == 'auto':
        opt['num_gpu'] = torch.cuda.device_count()

    # datasets
    for phase, dataset in opt['datasets'].items():
        # for multiple datasets, e.g., val_1, val_2; test_1, test_2
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        if dataset.get('dataroot_gt') is not None:
            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
        if dataset.get('dataroot_lq') is not None:
            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])

    # paths
    for key, val in opt['path'].items():
        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
            opt['path'][key] = osp.expanduser(val)

    if is_train:
        experiments_root = opt['path'].get('experiments_root')
        if experiments_root is None:
            experiments_root = osp.join(root_path, 'experiments')
        experiments_root = osp.join(experiments_root, opt['name'])

        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
        opt['path']['log'] = experiments_root
        opt['path']['visualization'] = osp.join(experiments_root, 'visualization')

        # change some options for debug mode
        if 'debug' in opt['name']:
            if 'val' in opt:
                opt['val']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        results_root = opt['path'].get('results_root')
        if results_root is None:
            results_root = osp.join(root_path, 'results')
        results_root = osp.join(results_root, opt['name'])

        opt['path']['results_root'] = results_root
        opt['path']['log'] = results_root
        opt['path']['visualization'] = osp.join(results_root, 'visualization')

    return opt, args


@master_only
def copy_opt_file(opt_file, experiments_root):
    # copy the yml file to the experiment root
    import sys
    import time
    from shutil import copyfile
    cmd = ' '.join(sys.argv)
    filename = osp.join(experiments_root, osp.basename(opt_file))
    copyfile(opt_file, filename)

    with open(filename, 'r+') as f:
        lines = f.readlines()
        lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
        f.seek(0)
        f.writelines(lines)


def parse(opt_path, is_train=True):
    """Parse option file.


@ -94,24 +284,3 @@ def parse(opt_path, is_train=True):
    opt['path']['visualization'] = osp.join(results_root, 'visualization')

    return opt


def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Return:
        (str): Option string for printing.
    """
    msg = '\n'
    for k, v in opt.items():
        if isinstance(v, dict):
            msg += ' ' * (indent_level * 2) + k + ':['
            msg += dict2str(v, indent_level + 1)
            msg += ' ' * (indent_level * 2) + ']\n'
        else:
            msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
    return msg
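Note: the --force_yml loop turns each colon-separated entry into a nested assignment via exec, so 'train:ema_decay=0.999' becomes opt["train"]["ema_decay"] = 0.999. A standalone sketch of that expansion (opt and entry are illustrative):

opt = {'train': {'ema_decay': 0.9}}
entry = 'train:ema_decay=0.999'

keys, value = (s.strip() for s in entry.split('='))
value = float(value)  # _postprocess_yml_value would do this re-typing
eval_str = 'opt'
for key in keys.split(':'):
    eval_str += f'["{key}"]'
exec(eval_str + ' = value')  # executes opt["train"]["ema_decay"] = value
assert opt['train']['ema_decay'] == 0.999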
@ -1,5 +1,6 @@
# GENERATED VERSION FILE
# TIME: Mon Apr 18 21:35:20 2022
__version__ = '1.2.0+386ca20'
# TIME: Sat Apr 20 12:26:03 2024
__version__ = '1.2.0+2b4af71'
__gitsha__ = '2b4af71'
short_version = '1.2.0'
version_info = (1, 2, 0)

setup.py

@ -71,6 +71,7 @@ def write_version_py():
    content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
short_version = '{}'
version_info = ({})
"""

@ -81,7 +82,7 @@ version_info = ({})
        [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
    VERSION = SHORT_VERSION + '+' + sha

    version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION,
    version_file_str = content.format(time.asctime(), VERSION, sha, SHORT_VERSION,
                                      VERSION_INFO)
    with open(version_file, 'w') as f:
        f.write(version_file_str)