[Feature]: parrots add parrots/fused_bias & upfirdn2d (#989)

pull/997/head
Wang Xiaolin 2021-04-27 13:49:50 +08:00 committed by GitHub
parent 1738d50c17
commit 0fc19b4692
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 720 additions and 49 deletions

View File

@ -0,0 +1,26 @@
// Modified from
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &refer, int act,
int grad, float alpha, float scale);
#endif
torch::Tensor fused_bias_leakyrelu(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &refer, int act,
int grad, float alpha, float scale) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA(input);
CHECK_CUDA(bias);
return fused_bias_leakyrelu_op(input, bias, refer, act, grad, alpha, scale);
#else
AT_ERROR("Fused bias leakyrelu is not compiled with GPU support");
#endif
}

View File

@ -0,0 +1,109 @@
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(
scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
const scalar_t *p_ref, int act, int grad, scalar_t alpha, scalar_t scale,
int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
scalar_t zero = 0.0;
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
loop_idx++, xi += blockDim.x) {
scalar_t x = p_x[xi];
if (use_bias) {
x += p_b[(xi / step_b) % size_b];
}
scalar_t ref = use_ref ? p_ref[xi] : zero;
scalar_t y;
// act = 1: linear layer
// act = 3: leaky relu layer
// grad = 0: direct forward path
// grad = 1: first order deviation
// grad = 2: second order deviation
switch (act * 10 + grad) {
default:
case 10:
y = x;
break;
case 11:
y = x;
break;
case 12:
y = 0.0;
break;
case 30:
y = (x > 0.0) ? x : x * alpha;
break;
case 31:
y = (ref > 0.0) ? x : x * alpha;
break;
case 32:
y = 0.0;
break;
}
out[xi] = y * scale;
}
}
torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &refer, int act,
int grad, float alpha, float scale) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
auto x = input.contiguous();
auto b = bias.contiguous();
auto ref = refer.contiguous();
int use_bias = b.numel() ? 1 : 0;
int use_ref = ref.numel() ? 1 : 0;
int size_x = x.numel();
int size_b = b.numel();
int step_b = 1;
for (int i = 1 + 1; i < x.dim(); i++) {
step_b *= x.size(i);
}
int loop_x = 4;
int block_size = 4 * 32;
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
auto y = torch::empty_like(x);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x.scalar_type(), "fused_bias_act_kernel", [&] {
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
});
return y;
}

View File

@ -0,0 +1,40 @@
#include <torch/extension.h>
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
using namespace at;
using namespace parrots;
torch::Tensor fused_bias_leakyrelu(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &refer, int act,
int grad, float alpha, float scale);
void fused_bias_leakyrelu_parrots(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
OperatorBase::out_list_t &outs) {
int act, grad;
float alpha, scale;
SSAttrs(attr)
.get<int>("act", act)
.get<int>("grad", grad)
.get<float>("alpha", alpha)
.get<float>("scale", scale)
.done();
const auto &input = buildATensor(ctx, ins[0]);
const auto &bias = buildATensor(ctx, ins[1]);
const auto &refer = buildATensor(ctx, ins[2]);
auto out = fused_bias_leakyrelu(input, bias, refer, act, grad, alpha, scale);
updateDArray(ctx, out, outs[0]);
}
PARROTS_EXTENSION_REGISTER(fused_bias_leakyrelu)
.attr("act")
.attr("grad")
.attr("alpha")
.attr("scale")
.input(3)
.output(1)
.apply(fused_bias_leakyrelu_parrots)
.done();

View File

@ -0,0 +1,25 @@
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
const torch::Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1);
#endif
torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
int up_x, int up_y, int down_x, int down_y, int pad_x0,
int pad_x1, int pad_y0, int pad_y1) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA(input);
CHECK_CUDA(kernel);
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
pad_y0, pad_y1);
#else
AT_ERROR("UpFirDn2d is not compiled with GPU support");
#endif
}

View File

@ -0,0 +1,370 @@
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
int c = a / b;
if (c * b > a) {
c--;
}
return c;
}
struct UpFirDn2DKernelParams {
int up_x;
int up_y;
int down_x;
int down_y;
int pad_x0;
int pad_x1;
int pad_y0;
int pad_y1;
int major_dim;
int in_h;
int in_w;
int minor_dim;
int kernel_h;
int kernel_w;
int out_h;
int out_w;
int loop_major;
int loop_x;
};
template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
int out_y = minor_idx / p.minor_dim;
minor_idx -= out_y * p.minor_dim;
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
int major_idx_base = blockIdx.z * p.loop_major;
if (out_x_base >= p.out_w || out_y >= p.out_h ||
major_idx_base >= p.major_dim) {
return;
}
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major && major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, out_x = out_x_base;
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
const scalar_t *x_p =
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
minor_idx];
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
int x_px = p.minor_dim;
int k_px = -p.up_x;
int x_py = p.in_w * p.minor_dim;
int k_py = -p.up_y * p.kernel_w;
scalar_t v = 0.0f;
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
x_p += x_px;
k_p += k_px;
}
x_p += x_py - w * x_px;
k_p += k_py - w * k_px;
}
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
__shared__ volatile float sk[kernel_h][kernel_w];
__shared__ volatile float sx[tile_in_h][tile_in_w];
int minor_idx = blockIdx.x;
int tile_out_y = minor_idx / p.minor_dim;
minor_idx -= tile_out_y * p.minor_dim;
tile_out_y *= tile_out_h;
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
int major_idx_base = blockIdx.z * p.loop_major;
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
major_idx_base >= p.major_dim) {
return;
}
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
tap_idx += blockDim.x) {
int ky = tap_idx / kernel_w;
int kx = tap_idx - ky * kernel_w;
scalar_t v = 0.0;
if (kx < p.kernel_w & ky < p.kernel_h) {
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
}
sk[ky][kx] = v;
}
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major & major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, tile_out_x = tile_out_x_base;
loop_x < p.loop_x & tile_out_x < p.out_w;
loop_x++, tile_out_x += tile_out_w) {
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
int tile_in_x = floor_div(tile_mid_x, up_x);
int tile_in_y = floor_div(tile_mid_y, up_y);
__syncthreads();
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
in_idx += blockDim.x) {
int rel_in_y = in_idx / tile_in_w;
int rel_in_x = in_idx - rel_in_y * tile_in_w;
int in_x = rel_in_x + tile_in_x;
int in_y = rel_in_y + tile_in_y;
scalar_t v = 0.0;
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
p.minor_dim +
minor_idx];
}
sx[rel_in_y][rel_in_x] = v;
}
__syncthreads();
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
out_idx += blockDim.x) {
int rel_out_y = out_idx / tile_out_w;
int rel_out_x = out_idx - rel_out_y * tile_out_w;
int out_x = rel_out_x + tile_out_x;
int out_y = rel_out_y + tile_out_y;
int mid_x = tile_mid_x + rel_out_x * down_x;
int mid_y = tile_mid_y + rel_out_y * down_y;
int in_x = floor_div(mid_x, up_x);
int in_y = floor_div(mid_y, up_y);
int rel_in_x = in_x - tile_in_x;
int rel_in_y = in_y - tile_in_y;
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
scalar_t v = 0.0;
#pragma unroll
for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
for (int x = 0; x < kernel_w / up_x; x++)
v += sx[rel_in_y + y][rel_in_x + x] *
sk[kernel_y + y * up_y][kernel_x + x * up_x];
if (out_x < p.out_w & out_y < p.out_h) {
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
}
}
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
const torch::Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
UpFirDn2DKernelParams p;
auto x = input.contiguous();
auto k = kernel.contiguous();
p.major_dim = x.size(0);
p.in_h = x.size(1);
p.in_w = x.size(2);
p.minor_dim = x.size(3);
p.kernel_h = k.size(0);
p.kernel_w = k.size(1);
p.up_x = up_x;
p.up_y = up_y;
p.down_x = down_x;
p.down_y = down_y;
p.pad_x0 = pad_x0;
p.pad_x1 = pad_x1;
p.pad_y0 = pad_y0;
p.pad_y1 = pad_y1;
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
p.down_y;
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
p.down_x;
auto out =
at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
int mode = -1;
int tile_out_h = -1;
int tile_out_w = -1;
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 1;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 3 && p.kernel_w <= 3) {
mode = 2;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 3;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 4;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 5;
tile_out_h = 8;
tile_out_w = 32;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 6;
tile_out_h = 8;
tile_out_w = 32;
}
dim3 block_size;
dim3 grid_size;
if (tile_out_h > 0 && tile_out_w > 0) {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 1;
block_size = dim3(32 * 8, 1, 1);
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
(p.major_dim - 1) / p.loop_major + 1);
} else {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 4;
block_size = dim3(4, 32, 1);
grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
(p.out_w - 1) / (p.loop_x * block_size.y) + 1,
(p.major_dim - 1) / p.loop_major + 1);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
switch (mode) {
case 1:
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 2:
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 3:
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 4:
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 5:
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 6:
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
default:
upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
}
});
return out;
}

View File

@ -0,0 +1,46 @@
#include <torch/extension.h>
#include <parrots/compute/aten.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/ssattrs.hpp>
using namespace at;
using namespace parrots;
torch::Tensor upfirdn2d(const Tensor &input, const Tensor &kernel, int up_x,
int up_y, int down_x, int down_y, int pad_x0,
int pad_x1, int pad_y0, int pad_y1);
void upfirdn2d_parrots(CudaContext &ctx, const SSElement &attr,
const OperatorBase::in_list_t &ins,
OperatorBase::out_list_t &outs) {
int up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1;
const auto &input = buildATensor(ctx, ins[0]);
const auto &kernel = buildATensor(ctx, ins[1]);
SSAttrs(attr)
.get("up_x", up_x)
.get("up_y", up_y)
.get("down_x", down_x)
.get("down_y", down_y)
.get("pad_x0", pad_x0)
.get("pad_x1", pad_x1)
.get("pad_y0", pad_y0)
.get("pad_y1", pad_y1)
.done();
auto out = upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0,
pad_x1, pad_y0, pad_y1);
updateDArray(ctx, out, outs[0]);
}
PARROTS_EXTENSION_REGISTER(upfirdn2d)
.attr("up_x")
.attr("up_y")
.attr("down_x")
.attr("down_y")
.attr("pad_x0")
.attr("pad_x1")
.attr("pad_y0")
.attr("pad_y1")
.input(2)
.output(1)
.apply(upfirdn2d_parrots)
.done();

View File

@ -25,9 +25,14 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
empty = grad_output.new_empty(0)
grad_input = ext_module.fused_bias_leakyrelu(grad_output, empty, out,
3, 1, negative_slope,
scale)
grad_input = ext_module.fused_bias_leakyrelu(
grad_output,
empty,
out,
act=3,
grad=1,
alpha=negative_slope,
scale=scale)
dim = [0]
@ -46,8 +51,13 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
# the first part is zero. Thus, we direct consider the second part
# which is similar with the first order deviation in implementation.
gradgrad_out = ext_module.fused_bias_leakyrelu(
gradgrad_input, gradgrad_bias.to(out.dtype), out, 3, 1,
ctx.negative_slope, ctx.scale)
gradgrad_input,
gradgrad_bias,
out,
act=3,
grad=1,
alpha=ctx.negative_slope,
scale=ctx.scale)
return gradgrad_out, None, None, None
@ -57,8 +67,15 @@ class FusedBiasLeakyReLUFunction(Function):
@staticmethod
def forward(ctx, input, bias, negative_slope, scale):
empty = input.new_empty(0)
out = ext_module.fused_bias_leakyrelu(input, bias, empty, 3, 0,
negative_slope, scale)
out = ext_module.fused_bias_leakyrelu(
input,
bias,
empty,
act=3,
grad=0,
alpha=negative_slope,
scale=scale)
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale

View File

@ -24,15 +24,14 @@ class UpFirDn2dBackward(Function):
grad_input = upfirdn2d_ext.upfirdn2d(
grad_output,
grad_kernel,
down_x,
down_y,
up_x,
up_y,
g_pad_x0,
g_pad_x1,
g_pad_y0,
g_pad_y1,
)
up_x=down_x,
up_y=down_y,
down_x=up_x,
down_y=up_y,
pad_x0=g_pad_x0,
pad_x1=g_pad_x1,
pad_y0=g_pad_y0,
pad_y1=g_pad_y1)
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
in_size[3])
@ -63,15 +62,14 @@ class UpFirDn2dBackward(Function):
gradgrad_out = upfirdn2d_ext.upfirdn2d(
gradgrad_input,
kernel,
ctx.up_x,
ctx.up_y,
ctx.down_x,
ctx.down_y,
ctx.pad_x0,
ctx.pad_x1,
ctx.pad_y0,
ctx.pad_y1,
)
up_x=ctx.up_x,
up_y=ctx.up_y,
down_x=ctx.down_x,
down_y=ctx.down_y,
pad_x0=ctx.pad_x0,
pad_x1=ctx.pad_x1,
pad_y0=ctx.pad_y0,
pad_y1=ctx.pad_y1)
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
# ctx.out_size[1], ctx.in_size[3])
gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],
@ -111,8 +109,17 @@ class UpFirDn2d(Function):
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x,
down_y, pad_x0, pad_x1, pad_y0, pad_y1)
out = upfirdn2d_ext.upfirdn2d(
input,
kernel,
up_x=up_x,
up_y=up_y,
down_x=down_x,
down_y=down_y,
pad_x0=pad_x0,
pad_x1=pad_x1,
pad_y0=pad_y0,
pad_y1=pad_y1)
# out = out.view(major, out_h, out_w, minor)
out = out.view(-1, channel, out_h, out_w)

View File

@ -19,7 +19,7 @@ else:
'nms', 'softnms', 'nms_match', 'nms_rotated', 'top_pool_forward',
'top_pool_backward', 'bottom_pool_forward', 'bottom_pool_backward',
'left_pool_forward', 'left_pool_backward', 'right_pool_forward',
'right_pool_backward'
'right_pool_backward', 'fused_bias_leakyrelu', 'upfirdn2d'
]
def load_ext(name, funcs):

View File

@ -1,6 +1,12 @@
import pytest
import torch
from torch.autograd import gradcheck, gradgradcheck
_USING_PARROTS = True
try:
from parrots.autograd import gradcheck
except ImportError:
from torch.autograd import gradcheck, gradgradcheck
_USING_PARROTS = False
class TestFusedBiasLeakyReLU(object):
@ -16,13 +22,22 @@ class TestFusedBiasLeakyReLU(object):
def test_gradient(self):
from mmcv.ops import FusedBiasLeakyReLU
gradcheck(
FusedBiasLeakyReLU(2).cuda(),
self.input_tensor,
eps=1e-4,
atol=1e-3)
if _USING_PARROTS:
gradcheck(
FusedBiasLeakyReLU(2).cuda(),
self.input_tensor,
delta=1e-4,
pt_atol=1e-3)
else:
gradcheck(
FusedBiasLeakyReLU(2).cuda(),
self.input_tensor,
eps=1e-4,
atol=1e-3)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
not torch.cuda.is_available() or _USING_PARROTS,
reason='requires cuda')
def test_gradgradient(self):
from mmcv.ops import FusedBiasLeakyReLU

View File

@ -1,6 +1,12 @@
import pytest
import torch
from torch.autograd import gradcheck, gradgradcheck
_USING_PARROTS = True
try:
from parrots.autograd import gradcheck
except ImportError:
from torch.autograd import gradcheck, gradgradcheck
_USING_PARROTS = False
class TestUpFirDn2d(object):
@ -25,17 +31,27 @@ class TestUpFirDn2d(object):
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
def test_upfirdn2d(self):
from mmcv.ops import upfirdn2d
if _USING_PARROTS:
gradcheck(
upfirdn2d,
(self.input_tensor.cuda(),
self.kernel.type_as(
self.input_tensor).cuda(), self.factor, 1, self.pad),
delta=1e-4,
pt_atol=1e-3)
else:
gradcheck(
upfirdn2d,
(self.input_tensor.cuda(),
self.kernel.type_as(
self.input_tensor).cuda(), self.factor, 1, self.pad),
eps=1e-4,
atol=1e-3)
gradcheck(
upfirdn2d,
(self.input_tensor.cuda(), self.kernel.type_as(
self.input_tensor).cuda(), self.factor, 1, self.pad),
eps=1e-4,
atol=1e-3)
gradgradcheck(
upfirdn2d,
(self.input_tensor.cuda(), self.kernel.type_as(
self.input_tensor).cuda(), self.factor, 1, self.pad),
eps=1e-4,
atol=1e-3)
gradgradcheck(
upfirdn2d,
(self.input_tensor.cuda(),
self.kernel.type_as(
self.input_tensor).cuda(), self.factor, 1, self.pad),
eps=1e-4,
atol=1e-3)