mirror of https://github.com/open-mmlab/mmcv.git
update ca_forward_kernel (#1144)
parent
7b150fab34
commit
94818ad136
|
@ -24,8 +24,8 @@ void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f,
|
|||
dim3 threads(32, 32);
|
||||
int d1 = (w + threads.x - 1) / threads.x;
|
||||
int d2 = (h + threads.y - 1) / threads.y;
|
||||
int d3 = h + w;
|
||||
dim3 blocks(d1, d2, d3);
|
||||
int d3 = h + w - 1;
|
||||
dim3 blocks(d1, d2, d3 * n);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_forward", [&] {
|
||||
ca_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
||||
|
@ -53,7 +53,7 @@ void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
|
|||
dim3 threads(32, 32);
|
||||
int d1 = (w + threads.x - 1) / threads.x;
|
||||
int d2 = (h + threads.y - 1) / threads.y;
|
||||
int d3 = c;
|
||||
int d3 = c * n;
|
||||
dim3 blocks(d1, d2, d3);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_backward_kernel_t", [&] {
|
||||
|
@ -90,7 +90,7 @@ void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
|
|||
dim3 threads(32, 32);
|
||||
int d1 = (w + threads.x - 1) / threads.x;
|
||||
int d2 = (h + threads.y - 1) / threads.y;
|
||||
int d3 = c;
|
||||
int d3 = c * n;
|
||||
dim3 blocks(d1, d2, d3);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_forward", [&] {
|
||||
|
@ -119,8 +119,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
|
|||
dim3 threads(32, 32);
|
||||
int d1 = (w + threads.x - 1) / threads.x;
|
||||
int d2 = (h + threads.y - 1) / threads.y;
|
||||
int d3 = h + w;
|
||||
dim3 blocks(d1, d2, d3);
|
||||
int d3 = h + w - 1;
|
||||
dim3 blocks(d1, d2, d3 * n);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
weight.scalar_type(), "ca_map_backward_kernel_w", [&] {
|
||||
|
@ -130,7 +130,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
|
|||
g.contiguous().data_ptr<scalar_t>(),
|
||||
dw.contiguous().data_ptr<scalar_t>(), n, c, h, w);
|
||||
});
|
||||
|
||||
d3 = c * n;
|
||||
blocks = dim3(d1, d2, d3);
|
||||
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_backward_kernel_g", [&] {
|
||||
ca_map_backward_kernel_g<scalar_t><<<blocks, threads, 0, stream>>>(
|
||||
dout.contiguous().data_ptr<scalar_t>(),
|
||||
|
|
Loading…
Reference in New Issue