Qingpeng Li 2025-04-14 11:49:49 +00:00 committed by GitHub
commit 26d804b8e8
7 changed files with 217 additions and 144 deletions

View File

@@ -10,63 +10,101 @@
template <typename T>
__global__ void softmax_focal_loss_forward_cuda_kernel(
const int nthreads, const T* softmax, const int64_t* target,
const T* weight, T* output, const T gamma, const T alpha,
const int num_classes) {
const int nthreads, const T* __restrict__ log_softmax_prob,
const int64_t* __restrict__ target, const T* __restrict__ weight,
T* __restrict__ output,
const T gamma, const T alpha, const int num_classes) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int64_t label = target[index];
T pred = softmax[index * num_classes + label];
const int n = index / num_classes;
const int c = index % num_classes;
if (label >= 0) {
output[index] =
-alpha * pow((T)1. - pred, gamma) * log(max(pred, (T)FLT_MIN));
// focal loss
// FL(p) = - alpha * (1-p)^gamma * log(p)   if curr_class == label
// (for the background class, label == 0, alpha is replaced by 1 - alpha;
//  see alpha_fac below)
//
// note that log_softmax_prob is computed on the Python side
// with the PyTorch API F.log_softmax()
const int64_t label = target[n];
if (c == label) {
const T w = (weight != NULL) ? weight[label] : T(1);
const T alpha_fac = ((label == 0) * (1 - alpha) + (label >= 1) * alpha) * w;
const T log_pred = log_softmax_prob[index];
const T pred = exp(log_pred);
output[index] = -alpha_fac * pow(1 - pred, gamma) * log_pred;
} else {
output[index] = 0;
}
if (weight != NULL) {
output[index] *= weight[label];
}
}
}
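For reference, the nonzero entries written by this forward kernel can be reproduced with a few lines of eager PyTorch. The sketch below is illustrative only and not part of the commit (names such as logits and labels are made up); it assumes weight, when given, is a 1-D per-class tensor, matching weight[label] above.

import torch
import torch.nn.functional as F

def softmax_focal_loss_forward_reference(logits, labels, gamma=2.0,
                                         alpha=0.25, weight=None):
    # logits: (N, C) raw scores; labels: (N,) int64 class indices
    log_prob = F.log_softmax(logits, dim=1)  # same tensor the kernel receives
    log_pt = log_prob.gather(1, labels.view(-1, 1)).squeeze(1)  # log p at the target class
    pt = log_pt.exp()
    # alpha for foreground classes (label >= 1), 1 - alpha for background (label == 0)
    alpha_fac = (labels == 0).to(pt.dtype) * (1 - alpha) + (labels >= 1).to(pt.dtype) * alpha
    if weight is not None:
        alpha_fac = alpha_fac * weight[labels]
    # per-sample focal loss: the value the kernel stores at output[n * num_classes + label]
    return -alpha_fac * (1 - pt).pow(gamma) * log_pt

Summing (or averaging) this per-sample vector corresponds to the reduction applied later on the Python side.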
template <typename T>
__global__ void softmax_focal_loss_backward_cuda1_kernel(
const int nthreads, const T* softmax, const int64_t* target,
const T* weight, T* buff, const T gamma, const T alpha,
const int num_classes) {
__global__ void softmax_focal_loss_backward_cuda_kernel(
const int nthreads, const T* __restrict__ log_softmax_prob,
const int64_t* __restrict__ target, const T* __restrict__ weight,
T* __restrict__ sum_buff_along_class, T* __restrict__ grad_input,
const T gamma, const T alpha, const int num_classes) {
// forward pass:   x --(SM)--> p --(FL)--> loss
// backward pass:  x <-------- p <-------- FL
// index:          j           i
//
// For simplicity, the alpha factor of FL is ignored here.
// dFL/dp = - [ (1-p)^gamma / p
//              - gamma * (1-p)^(gamma-1) * log(p) ]
// dp_i/dx_j = dSM/dx_j
//           = p_i * (1-p_j)   if i == j;
//             p_i * (0-p_j)   if i != j;
//           = p_i * (delta - p_j),  where delta is the Kronecker delta
//
// Replacing the p in dFL/dp with p_i, then
// dFL/dx_j = dFL/dp_i * dp_i/dx_j
//          = - (delta - p_j) * [ (1-p_i)^gamma
//                                - gamma * (1-p_i)^(gamma-1) * log(p_i) * p_i ]
//          = (delta - p_j) * [ - (1-p_i)^gamma
//                              + gamma * (1-p_i)^(gamma-1) * log(p_i) * p_i ]
//
// Let B_i denote [ - (1-p_i)^gamma
//                  + gamma * (1-p_i)^(gamma-1) * log(p_i) * p_i ].
// The index i is summed over all classes at position j,
// since x_j receives gradients from every p_i.
// Then, dFL/dx_j = sum_i{ (delta - p_j) * B_i }
//                = sum_i{ delta*B_i - p_j*B_i }
//                = B_j - (p_j * sum_i{B_i})
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int64_t label = target[index];
T pred = softmax[index * num_classes + label];
// B_i
const int n = index / num_classes;
const int c = index % num_classes;
if (label >= 0) {
buff[index] = alpha * (-pow((T)1. - pred, gamma) +
gamma * pow((T)1. - pred, gamma - 1) * pred *
log(max(pred, (T)FLT_MIN)));
} else {
buff[index] = 0;
}
if (weight != NULL) {
buff[index] *= weight[label];
}
}
}
const int64_t label = target[n];
if (c == label) {
const T w = (weight != NULL) ? weight[label] : T(1);
const T alpha_fac = ((label == 0) * (1 - alpha) + (label >= 1) * alpha) * w;
template <typename T>
__global__ void softmax_focal_loss_backward_cuda2_kernel(
const int nthreads, const T* softmax, const int64_t* target, const T* buff,
T* grad_input, const int num_classes) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int n = index / num_classes;
int c = index % num_classes;
int64_t label = target[n];
const T log_pred = log_softmax_prob[index];
const T pred = exp(log_pred);
const T one_minus_pred = 1 - pred;
if (label >= 0) {
T flag = (label == c ? (T)1. : (T)0.);
grad_input[index] = buff[n] * (flag - softmax[index]);
const T buff = alpha_fac * (
-pow(one_minus_pred, gamma) +
gamma * pow(one_minus_pred, gamma - 1) * log_pred * pred
);
grad_input[index] = buff;
sum_buff_along_class[n] += buff;
} else {
grad_input[index] = 0;
}
}
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// dFL/dx_j
const int n = index / num_classes;
const T pred = exp(log_softmax_prob[index]);
grad_input[index] -= pred * sum_buff_along_class[n];
}
}
#endif // SOFTMAX_FOCAL_LOSS_CUDA_KERNEL_CUH
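To sanity-check the closed form dFL/dx_j = B_j - p_j * sum_i{B_i} that the backward kernel implements, the standalone PyTorch sketch below (not part of the commit; alpha and the class weight are omitted, matching the simplification in the comment) compares it against autograd:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, gamma = 4, 5, 2.0
logits = torch.randn(N, C, dtype=torch.double, requires_grad=True)
labels = torch.randint(0, C, (N,))

log_p = F.log_softmax(logits, dim=1)
log_pt = log_p.gather(1, labels.view(-1, 1)).squeeze(1)
pt = log_pt.exp()
loss = (-(1 - pt).pow(gamma) * log_pt).sum()  # focal loss without alpha
loss.backward()

with torch.no_grad():
    p = log_p.exp()
    # B is nonzero only at the target class, as in the first kernel loop
    B = torch.zeros_like(p)
    B.scatter_(1, labels.view(-1, 1),
               (-(1 - pt).pow(gamma)
                + gamma * (1 - pt).pow(gamma - 1) * log_pt * pt).view(-1, 1))
    grad_closed_form = B - p * B.sum(dim=1, keepdim=True)  # B_j - p_j * sum_i B_i

print(torch.allclose(logits.grad, grad_closed_form))  # expected: True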

View File

@@ -409,13 +409,17 @@ void SigmoidFocalLossBackwardCUDAKernelLauncher(Tensor input, Tensor target,
const float gamma,
const float alpha);
void SoftmaxFocalLossForwardCUDAKernelLauncher(Tensor softmax, Tensor target,
Tensor weight, Tensor output,
void SoftmaxFocalLossForwardCUDAKernelLauncher(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor output,
const float gamma,
const float alpha);
void SoftmaxFocalLossBackwardCUDAKernelLauncher(Tensor softmax, Tensor target,
Tensor weight, Tensor buff,
void SoftmaxFocalLossBackwardCUDAKernelLauncher(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor sum_buff_along_class,
Tensor grad_input,
const float gamma,
const float alpha);
@@ -433,18 +437,26 @@ void sigmoid_focal_loss_backward_cuda(Tensor input, Tensor target,
gamma, alpha);
}
void softmax_focal_loss_forward_cuda(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
SoftmaxFocalLossForwardCUDAKernelLauncher(input, target, weight, output,
gamma, alpha);
void softmax_focal_loss_forward_cuda(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor output,
const float gamma,
const float alpha) {
SoftmaxFocalLossForwardCUDAKernelLauncher(log_softmax_prob, target, weight,
output, gamma, alpha);
}
void softmax_focal_loss_backward_cuda(Tensor input, Tensor target,
Tensor weight, Tensor buff,
Tensor grad_input, float gamma,
float alpha) {
SoftmaxFocalLossBackwardCUDAKernelLauncher(input, target, weight, buff,
grad_input, gamma, alpha);
void softmax_focal_loss_backward_cuda(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor sum_buff_along_class,
Tensor grad_input,
const float gamma,
const float alpha) {
SoftmaxFocalLossBackwardCUDAKernelLauncher(log_softmax_prob, target, weight,
sum_buff_along_class, grad_input,
gamma, alpha);
}
void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
@@ -454,13 +466,20 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
Tensor weight, Tensor grad_input,
float gamma, float alpha);
void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha);
void softmax_focal_loss_forward_impl(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor output,
const float gamma,
const float alpha);
void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
Tensor weight, Tensor buff,
Tensor grad_input, float gamma,
float alpha);
void softmax_focal_loss_backward_impl(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor sum_buff_along_class,
Tensor grad_input,
const float gamma,
const float alpha);
REGISTER_DEVICE_IMPL(sigmoid_focal_loss_forward_impl, CUDA,
sigmoid_focal_loss_forward_cuda);

View File

@@ -47,64 +47,53 @@ void SigmoidFocalLossBackwardCUDAKernelLauncher(Tensor input, Tensor target,
AT_CUDA_CHECK(cudaGetLastError());
}
void SoftmaxFocalLossForwardCUDAKernelLauncher(Tensor softmax, Tensor target,
Tensor weight, Tensor output,
void SoftmaxFocalLossForwardCUDAKernelLauncher(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor output,
const float gamma,
const float alpha) {
int output_size = output.numel();
int num_classes = softmax.size(1);
int num_classes = log_softmax_prob.size(1);
AT_ASSERTM(target.max().item<int64_t>() <= (int64_t)num_classes,
"target label should smaller or equal than num classes");
at::cuda::CUDAGuard device_guard(softmax.device());
at::cuda::CUDAGuard device_guard(log_softmax_prob.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
softmax.scalar_type(), "softmax_focal_loss_forward_cuda_kernel", [&] {
log_softmax_prob.scalar_type(), "softmax_focal_loss_forward_cuda_kernel", [&] {
softmax_focal_loss_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, softmax.data_ptr<scalar_t>(),
target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
output_size,
log_softmax_prob.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
gamma, alpha, num_classes);
});
AT_CUDA_CHECK(cudaGetLastError());
}
void SoftmaxFocalLossBackwardCUDAKernelLauncher(Tensor softmax, Tensor target,
Tensor weight, Tensor buff,
void SoftmaxFocalLossBackwardCUDAKernelLauncher(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor sum_buff_along_class,
Tensor grad_input,
const float gamma,
const float alpha) {
int num_classes = softmax.size(1);
int output_size = grad_input.numel();
int num_classes = log_softmax_prob.size(1);
int output_size = buff.numel();
at::cuda::CUDAGuard device_guard(grad_input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_input.scalar_type(),
"softmax_focal_loss_backward_cuda1_"
"kernel",
[&] {
softmax_focal_loss_backward_cuda1_kernel<scalar_t>
log_softmax_prob.scalar_type(), "softmax_focal_loss_backward_cuda_kernel", [&] {
softmax_focal_loss_backward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, softmax.data_ptr<scalar_t>(),
target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
buff.data_ptr<scalar_t>(), gamma, alpha, num_classes);
});
AT_CUDA_CHECK(cudaGetLastError());
output_size = grad_input.numel();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_input.scalar_type(),
"softmax_focal_loss_backward_cuda2_"
"kernel",
[&] {
softmax_focal_loss_backward_cuda2_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, softmax.data_ptr<scalar_t>(),
target.data_ptr<int64_t>(), buff.data_ptr<scalar_t>(),
grad_input.data_ptr<scalar_t>(), num_classes);
output_size,
log_softmax_prob.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), sum_buff_along_class.data_ptr<scalar_t>(),
grad_input.data_ptr<scalar_t>(),
gamma, alpha, num_classes);
});
AT_CUDA_CHECK(cudaGetLastError());

View File

@@ -29,18 +29,26 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
grad_input, gamma, alpha);
}
void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
DISPATCH_DEVICE_IMPL(softmax_focal_loss_forward_impl, input, target, weight,
output, gamma, alpha);
void softmax_focal_loss_forward_impl(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor output,
const float gamma,
const float alpha) {
DISPATCH_DEVICE_IMPL(softmax_focal_loss_forward_impl, log_softmax_prob,
target, weight, output, gamma, alpha);
}
void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
Tensor weight, Tensor buff,
Tensor grad_input, float gamma,
float alpha) {
DISPATCH_DEVICE_IMPL(softmax_focal_loss_backward_impl, input, target, weight,
buff, grad_input, gamma, alpha);
void softmax_focal_loss_backward_impl(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor sum_buff_along_class,
Tensor grad_input,
const float gamma,
const float alpha) {
DISPATCH_DEVICE_IMPL(softmax_focal_loss_backward_impl, log_softmax_prob,
target, weight, sum_buff_along_class, grad_input,
gamma, alpha);
}
#ifdef MMCV_WITH_DIOPI
@@ -145,14 +153,24 @@ void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
#endif
}
void softmax_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
softmax_focal_loss_forward_impl(input, target, weight, output, gamma, alpha);
void softmax_focal_loss_forward(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor output,
const float gamma,
const float alpha) {
softmax_focal_loss_forward_impl(log_softmax_prob, target, weight,
output, gamma, alpha);
}
void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
Tensor buff, Tensor grad_input, float gamma,
float alpha) {
softmax_focal_loss_backward_impl(input, target, weight, buff, grad_input,
void softmax_focal_loss_backward(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor sum_buff_along_class,
Tensor grad_input,
const float gamma,
const float alpha) {
softmax_focal_loss_backward_impl(log_softmax_prob, target, weight,
sum_buff_along_class, grad_input,
gamma, alpha);
}

View File

@@ -103,12 +103,20 @@ void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
Tensor grad_input, float gamma, float alpha);
void softmax_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha);
void softmax_focal_loss_forward(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor output,
const float gamma,
const float alpha);
void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
Tensor buff, Tensor grad_input, float gamma,
float alpha);
void softmax_focal_loss_backward(const Tensor log_softmax_prob,
const Tensor target,
const Tensor weight,
Tensor sum_buff_along_class,
Tensor grad_input,
const float gamma,
const float alpha);
void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor,
Tensor weight_tensor, Tensor out_tensor, int b,
@@ -566,13 +574,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("weight"), py::arg("grad_input"), py::arg("gamma"),
py::arg("alpha"));
m.def("softmax_focal_loss_forward", &softmax_focal_loss_forward,
"softmax_focal_loss_forward", py::arg("input"), py::arg("target"),
py::arg("weight"), py::arg("output"), py::arg("gamma"),
py::arg("alpha"));
m.def("softmax_focal_loss_backward", &softmax_focal_loss_backward,
"softmax_focal_loss_backward", py::arg("input"), py::arg("target"),
py::arg("weight"), py::arg("buff"), py::arg("grad_input"),
"softmax_focal_loss_forward", py::arg("log_softmax_prob"),
py::arg("target"), py::arg("weight"), py::arg("output"),
py::arg("gamma"), py::arg("alpha"));
m.def("softmax_focal_loss_backward", &softmax_focal_loss_backward,
"softmax_focal_loss_backward", py::arg("log_softmax_prob"),
py::arg("target"), py::arg("weight"), py::arg("sum_buff_along_class"),
py::arg("grad_input"), py::arg("gamma"), py::arg("alpha"));
m.def("three_interpolate_forward", &three_interpolate_forward,
"three_interpolate_forward", py::arg("points_tensor"),
py::arg("idx_tensor"), py::arg("weight_tensor"), py::arg("out_tensor"),

View File

@@ -3,6 +3,7 @@ from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
@@ -132,16 +133,13 @@ class SoftmaxFocalLossFunction(Function):
ctx.alpha = float(alpha)
ctx.reduction = ctx.reduction_dict[reduction]
channel_stats, _ = torch.max(input, dim=1)
input_softmax = input - channel_stats.unsqueeze(1).expand_as(input)
input_softmax.exp_()
# log_softmax for numerical stability
log_softmax_prob = F.log_softmax(input, dim=1)
channel_stats = input_softmax.sum(dim=1)
input_softmax /= channel_stats.unsqueeze(1).expand_as(input)
output = input.new_zeros(input.size())
output = input.new_zeros(input.size(0))
ext_module.softmax_focal_loss_forward(
input_softmax,
log_softmax_prob,
target,
weight,
output,
@@ -152,27 +150,30 @@
output = output.sum() / input.size(0)
elif ctx.reduction == ctx.reduction_dict['sum']:
output = output.sum()
ctx.save_for_backward(input_softmax, target, weight)
ctx.save_for_backward(log_softmax_prob, target, weight)
return output
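The forward pass above replaces the handwritten max-shifted softmax with F.log_softmax. A quick illustration of the numerical motivation (not part of the commit): taking the log of an already-underflowed softmax probability saturates to -inf, while log_softmax evaluates the same quantity stably.

import torch
import torch.nn.functional as F

x = torch.tensor([[0.0, 200.0]])
naive = torch.log(torch.softmax(x, dim=1))  # first entry underflows to 0, log gives -inf
stable = F.log_softmax(x, dim=1)            # roughly [[-200., 0.]]
print(naive)
print(stable)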
@staticmethod
@once_differentiable
def backward(ctx, grad_output: torch.Tensor) -> tuple:
input_softmax, target, weight = ctx.saved_tensors
buff = input_softmax.new_zeros(input_softmax.size(0))
grad_input = input_softmax.new_zeros(input_softmax.size())
log_softmax_prob, target, weight = ctx.saved_tensors
sum_buff_along_class = log_softmax_prob.new_zeros(
log_softmax_prob.size(0))
grad_input = log_softmax_prob.new_zeros(log_softmax_prob.size())
ext_module.softmax_focal_loss_backward(
input_softmax,
log_softmax_prob,
target,
weight,
buff,
sum_buff_along_class,
grad_input,
gamma=ctx.gamma,
alpha=ctx.alpha)
grad_input *= grad_output
if ctx.reduction == ctx.reduction_dict['mean']:
grad_input /= input_softmax.size(0)
grad_input /= log_softmax_prob.size(0)
return grad_input, None, None, None, None, None
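End-to-end usage sketch (not part of the commit). The import path below assumes the usual mmcv layout, where SoftmaxFocalLoss is the nn.Module wrapper around the SoftmaxFocalLossFunction defined above; it requires a CUDA build of the extension.

import torch
from mmcv.ops import SoftmaxFocalLoss  # assumed wrapper module

loss_fn = SoftmaxFocalLoss(gamma=2.0, alpha=0.25)  # per-class weight / reduction are optional
logits = torch.randn(8, 4, device='cuda', requires_grad=True)
labels = torch.randint(0, 4, (8,), device='cuda')

loss = loss_fn(logits, labels)  # dispatches to softmax_focal_loss_forward on the GPU
loss.backward()                 # dispatches to softmax_focal_loss_backward on the GPU
print(loss.item(), logits.grad.shape)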

View File

@@ -21,13 +21,13 @@ inputs = [
([[1e-6, 2e-6, 3e-6], [4e-6, 5e-5, 6e-4], [7e-3, 8e-2, 9e-1]], [1, 2, 0]),
]
softmax_outputs = [(0.00566451, [[-0.00657264, 0.00657264],
[0.00657264, -0.00657264]]),
(0.34956908, [[0.10165970, 0.03739851, -0.13905823],
[0.01227554, -0.10298023, 0.09070466]]),
(0.15754992, [[0.02590877, -0.05181759, 0.02590882],
[0.02589641, 0.02589760, -0.05179400],
[-0.07307514, 0.02234372, 0.05073142]])]
softmax_outputs = [(0.01132904, [[-0.01971794, 0.01971793],
[0.00657264, -0.00657265]]),
(0.34956908, [[0.10165971, 0.03739851, -0.13905823],
[0.01227554, -0.10298022, 0.09070467]]),
(0.30995172, [[0.02590877, -0.05181758, 0.02590882],
[0.02589641, 0.02589760, -0.05179401],
[-0.21922545, 0.06703118, 0.15219429]])]
sigmoid_outputs = [(0.13562961, [[-0.00657264, 0.11185755],
[0.11185755, -0.00657264]]),