mirror of https://github.com/open-mmlab/mmcv.git
Merge 5698ac7914 into f19d3e771c
commit 26d804b8e8
@@ -10,63 +10,101 @@
 template <typename T>
 __global__ void softmax_focal_loss_forward_cuda_kernel(
-    const int nthreads, const T* softmax, const int64_t* target,
-    const T* weight, T* output, const T gamma, const T alpha,
-    const int num_classes) {
+    const int nthreads, const T* __restrict__ log_softmax_prob,
+    const int64_t* __restrict__ target, const T* __restrict__ weight,
+    T* __restrict__ output,
+    const T gamma, const T alpha, const int num_classes) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int64_t label = target[index];
-    T pred = softmax[index * num_classes + label];
+    const int n = index / num_classes;
+    const int c = index % num_classes;

-    if (label >= 0) {
-      output[index] =
-          -alpha * pow((T)1. - pred, gamma) * log(max(pred, (T)FLT_MIN));
+    // focal loss
+    // FL(p) = - alpha * (1-p)^gamma * log(p)   if curr_class == label
+    //
+    // note that log_softmax_prob is calculated in Python part
+    // by using PyTorch API F.log_softmax()
+    const int64_t label = target[n];
+    if (c == label) {
+      const T w = (weight != NULL) ? weight[label] : T(1);
+      const T alpha_fac = ((label == 0) * (1 - alpha) + (label >= 1) * alpha) * w;
+
+      const T log_pred = log_softmax_prob[index];
+      const T pred = exp(log_pred);
+
+      output[index] = -alpha_fac * pow(1 - pred, gamma) * log_pred;
     } else {
       output[index] = 0;
     }
-    if (weight != NULL) {
-      output[index] *= weight[label];
-    }
   }
 }

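For orientation, the per-element computation described by the kernel comment above can be reproduced on the Python side. The sketch below is illustrative only and is not part of the patch; the function name softmax_focal_loss_reference and the (N, C) output layout are assumptions chosen to mirror the kernel's behaviour, including its alpha_fac convention ((1 - alpha) for class 0, alpha for classes >= 1).

import torch
import torch.nn.functional as F


def softmax_focal_loss_reference(logits, target, gamma=2.0, alpha=0.25,
                                 weight=None):
    """Hypothetical pure-PyTorch mirror of the forward kernel above.

    Returns an (N, C) tensor in which only the target-class column is
    non-zero, matching the per-(n, c) layout the CUDA kernel writes.
    """
    # log_softmax is computed on the Python side, as the kernel comment notes.
    log_prob = F.log_softmax(logits, dim=1)
    idx = torch.arange(logits.size(0), device=logits.device)

    log_p = log_prob[idx, target]        # log p of the labelled class
    p = log_p.exp()

    # alpha_fac convention from the kernel:
    # (1 - alpha) for class 0, alpha for every other class.
    alpha_fac = torch.full_like(log_p, alpha)
    alpha_fac[target == 0] = 1 - alpha
    if weight is not None:
        alpha_fac = alpha_fac * weight[target]

    out = logits.new_zeros(logits.size())
    out[idx, target] = -alpha_fac * (1 - p).pow(gamma) * log_p
    return out

Summing this output over all entries and dividing by N reproduces the 'mean' reduction that stays in Python.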
 template <typename T>
-__global__ void softmax_focal_loss_backward_cuda1_kernel(
-    const int nthreads, const T* softmax, const int64_t* target,
-    const T* weight, T* buff, const T gamma, const T alpha,
-    const int num_classes) {
+__global__ void softmax_focal_loss_backward_cuda_kernel(
+    const int nthreads, const T* __restrict__ log_softmax_prob,
+    const int64_t* __restrict__ target, const T* __restrict__ weight,
+    T* __restrict__ sum_buff_along_class, T* __restrict__ grad_input,
+    const T gamma, const T alpha, const int num_classes) {
+  // forward node:  x ----> p ----> FL
+  // func:            SM       FL
+  //
+  // backward node: x <---- p <---- FL
+  // index:         j       i       FL
+  //
+  // For simplicity, the alpha of FL is ignored here
+  // dFL/dp = - [((1-p)^gamma) / p
+  //             - gamma * (1-p)^(gamma-1) * log(p)]
+  // dp_i/dx_j = dSM/dx_j
+  //           = p_i * (1-p_j)   i==j;
+  //             p_i * (0-p_j)   i!=j;
+  //           = p_i * (delta - p_j)   where delta is Kronecker delta
+  //
+  // Replacing the p of dFL/dp with p_i, then
+  // dFL/dx_j = dFL/dp_i * dp_i/dx_j
+  //          = - (delta - p_j) * [ (1-p_i)^gamma
+  //                                - gamma * (1-p_i)^(gamma-1) * log(p) * p_i]
+  //          = (delta - p_j) * [- (1-p_i)^gamma +
+  //                             gamma * (1-p_i)^(gamma-1) * log(p) * p_i]
+  //
+  // Let B_i denote [- (1-p_i)^gamma +
+  //                 gamma * (1-p_i)^(gamma-1) * log(p) * p_i],
+  // and indices {i} is summed for all classes at index j
+  // since x_j received all the gradients from {p_i}.
+  // Then, dFL/dx_j = sum_i{ (delta - p_j) * B_i }
+  //                = sum_i{ delta*B_i - p_j*B_i }
+  //                = B_j - (p_j * sum_i{B_i})
+
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int64_t label = target[index];
-    T pred = softmax[index * num_classes + label];
+    // B_i
+    const int n = index / num_classes;
+    const int c = index % num_classes;

-    if (label >= 0) {
-      buff[index] = alpha * (-pow((T)1. - pred, gamma) +
-                             gamma * pow((T)1. - pred, gamma - 1) * pred *
-                                 log(max(pred, (T)FLT_MIN)));
-    } else {
-      buff[index] = 0;
-    }
-    if (weight != NULL) {
-      buff[index] *= weight[label];
-    }
-  }
-}
+    const int64_t label = target[n];
+    if (c == label) {
+      const T w = (weight != NULL) ? weight[label] : T(1);
+      const T alpha_fac = ((label == 0) * (1 - alpha) + (label >= 1) * alpha) * w;

-template <typename T>
-__global__ void softmax_focal_loss_backward_cuda2_kernel(
-    const int nthreads, const T* softmax, const int64_t* target, const T* buff,
-    T* grad_input, const int num_classes) {
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int n = index / num_classes;
-    int c = index % num_classes;
-    int64_t label = target[n];
+      const T log_pred = log_softmax_prob[index];
+      const T pred = exp(log_pred);
+      const T one_minus_pred = 1 - pred;

-    if (label >= 0) {
-      T flag = (label == c ? (T)1. : (T)0.);
-      grad_input[index] = buff[n] * (flag - softmax[index]);
+      const T buff = alpha_fac * (
+          -pow(one_minus_pred, gamma) +
+          gamma * pow(one_minus_pred, gamma - 1) * log_pred * pred
+      );
+      grad_input[index] = buff;
+      sum_buff_along_class[n] += buff;
     } else {
       grad_input[index] = 0;
     }
   }

+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // dFL/dx_j
+    const int n = index / num_classes;
+
+    const T pred = exp(log_softmax_prob[index]);
+    grad_input[index] -= pred * sum_buff_along_class[n];
+  }
 }

 #endif  // SOFTMAX_FOCAL_LOSS_CUDA_KERNEL_CUH
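Restating the ASCII derivation from the comment above in LaTeX (alpha and the per-class weight enter only as a constant factor on each B_i, which the kernel folds into alpha_fac):

\[
\frac{\partial \mathrm{FL}}{\partial p} = -\left[\frac{(1-p)^{\gamma}}{p} - \gamma\,(1-p)^{\gamma-1}\log p\right],
\qquad
\frac{\partial p_i}{\partial x_j} = p_i\,(\delta_{ij} - p_j),
\]
\[
B_i = -(1-p_i)^{\gamma} + \gamma\,(1-p_i)^{\gamma-1}\,p_i\,\log p_i,
\qquad
\frac{\partial \mathrm{FL}}{\partial x_j} = \sum_i (\delta_{ij} - p_j)\,B_i = B_j - p_j\sum_i B_i.
\]

The first CUDA_1D_KERNEL_LOOP writes the scaled B_j into grad_input and accumulates the per-sample sum of B_i into sum_buff_along_class (here only the labelled class contributes a non-zero B_i, so that sum has a single term); the second loop subtracts p_j * sum_buff_along_class[n], completing B_j - p_j * sum_i B_i.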
@@ -409,13 +409,17 @@ void SigmoidFocalLossBackwardCUDAKernelLauncher(Tensor input, Tensor target,
                                                 const float gamma,
                                                 const float alpha);

-void SoftmaxFocalLossForwardCUDAKernelLauncher(Tensor softmax, Tensor target,
-                                               Tensor weight, Tensor output,
+void SoftmaxFocalLossForwardCUDAKernelLauncher(const Tensor log_softmax_prob,
+                                               const Tensor target,
+                                               const Tensor weight,
+                                               Tensor output,
                                                const float gamma,
                                                const float alpha);

-void SoftmaxFocalLossBackwardCUDAKernelLauncher(Tensor softmax, Tensor target,
-                                                Tensor weight, Tensor buff,
+void SoftmaxFocalLossBackwardCUDAKernelLauncher(const Tensor log_softmax_prob,
+                                                const Tensor target,
+                                                const Tensor weight,
+                                                Tensor sum_buff_along_class,
                                                 Tensor grad_input,
                                                 const float gamma,
                                                 const float alpha);

@@ -433,18 +437,26 @@ void sigmoid_focal_loss_backward_cuda(Tensor input, Tensor target,
                                              gamma, alpha);
 }

-void softmax_focal_loss_forward_cuda(Tensor input, Tensor target, Tensor weight,
-                                     Tensor output, float gamma, float alpha) {
-  SoftmaxFocalLossForwardCUDAKernelLauncher(input, target, weight, output,
-                                            gamma, alpha);
+void softmax_focal_loss_forward_cuda(const Tensor log_softmax_prob,
+                                     const Tensor target,
+                                     const Tensor weight,
+                                     Tensor output,
+                                     const float gamma,
+                                     const float alpha) {
+  SoftmaxFocalLossForwardCUDAKernelLauncher(log_softmax_prob, target, weight,
+                                            output, gamma, alpha);
 }

-void softmax_focal_loss_backward_cuda(Tensor input, Tensor target,
-                                      Tensor weight, Tensor buff,
-                                      Tensor grad_input, float gamma,
-                                      float alpha) {
-  SoftmaxFocalLossBackwardCUDAKernelLauncher(input, target, weight, buff,
-                                             grad_input, gamma, alpha);
+void softmax_focal_loss_backward_cuda(const Tensor log_softmax_prob,
+                                      const Tensor target,
+                                      const Tensor weight,
+                                      Tensor sum_buff_along_class,
+                                      Tensor grad_input,
+                                      const float gamma,
+                                      const float alpha) {
+  SoftmaxFocalLossBackwardCUDAKernelLauncher(log_softmax_prob, target, weight,
+                                             sum_buff_along_class, grad_input,
+                                             gamma, alpha);
 }

 void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,

@@ -454,13 +466,20 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
                                       Tensor weight, Tensor grad_input,
                                       float gamma, float alpha);

-void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
-                                     Tensor output, float gamma, float alpha);
+void softmax_focal_loss_forward_impl(const Tensor log_softmax_prob,
+                                     const Tensor target,
+                                     const Tensor weight,
+                                     Tensor output,
+                                     const float gamma,
+                                     const float alpha);

-void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
-                                      Tensor weight, Tensor buff,
-                                      Tensor grad_input, float gamma,
-                                      float alpha);
+void softmax_focal_loss_backward_impl(const Tensor log_softmax_prob,
+                                      const Tensor target,
+                                      const Tensor weight,
+                                      Tensor sum_buff_along_class,
+                                      Tensor grad_input,
+                                      const float gamma,
+                                      const float alpha);

 REGISTER_DEVICE_IMPL(sigmoid_focal_loss_forward_impl, CUDA,
                      sigmoid_focal_loss_forward_cuda);

@@ -47,64 +47,53 @@ void SigmoidFocalLossBackwardCUDAKernelLauncher(Tensor input, Tensor target,
   AT_CUDA_CHECK(cudaGetLastError());
 }

-void SoftmaxFocalLossForwardCUDAKernelLauncher(Tensor softmax, Tensor target,
-                                               Tensor weight, Tensor output,
+void SoftmaxFocalLossForwardCUDAKernelLauncher(const Tensor log_softmax_prob,
+                                               const Tensor target,
+                                               const Tensor weight,
+                                               Tensor output,
                                                const float gamma,
                                                const float alpha) {
   int output_size = output.numel();
-  int num_classes = softmax.size(1);
+  int num_classes = log_softmax_prob.size(1);

   AT_ASSERTM(target.max().item<int64_t>() <= (int64_t)num_classes,
              "target label should smaller or equal than num classes");
-  at::cuda::CUDAGuard device_guard(softmax.device());
+  at::cuda::CUDAGuard device_guard(log_softmax_prob.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      softmax.scalar_type(), "softmax_focal_loss_forward_cuda_kernel", [&] {
+      log_softmax_prob.scalar_type(), "softmax_focal_loss_forward_cuda_kernel", [&] {
         softmax_focal_loss_forward_cuda_kernel<scalar_t>
             <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, softmax.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
-                output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
+                output_size,
+                log_softmax_prob.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
+                weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
+                gamma, alpha, num_classes);
       });

   AT_CUDA_CHECK(cudaGetLastError());
 }

-void SoftmaxFocalLossBackwardCUDAKernelLauncher(Tensor softmax, Tensor target,
-                                                Tensor weight, Tensor buff,
+void SoftmaxFocalLossBackwardCUDAKernelLauncher(const Tensor log_softmax_prob,
+                                                const Tensor target,
+                                                const Tensor weight,
+                                                Tensor sum_buff_along_class,
                                                 Tensor grad_input,
                                                 const float gamma,
                                                 const float alpha) {
-  int num_classes = softmax.size(1);
+  int output_size = grad_input.numel();
+  int num_classes = log_softmax_prob.size(1);

-  int output_size = buff.numel();
   at::cuda::CUDAGuard device_guard(grad_input.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad_input.scalar_type(),
-      "softmax_focal_loss_backward_cuda1_"
-      "kernel",
-      [&] {
-        softmax_focal_loss_backward_cuda1_kernel<scalar_t>
+      log_softmax_prob.scalar_type(), "softmax_focal_loss_backward_cuda_kernel", [&] {
+        softmax_focal_loss_backward_cuda_kernel<scalar_t>
             <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, softmax.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
-                buff.data_ptr<scalar_t>(), gamma, alpha, num_classes);
-      });
-
-  AT_CUDA_CHECK(cudaGetLastError());
-
-  output_size = grad_input.numel();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad_input.scalar_type(),
-      "softmax_focal_loss_backward_cuda2_"
-      "kernel",
-      [&] {
-        softmax_focal_loss_backward_cuda2_kernel<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, softmax.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), buff.data_ptr<scalar_t>(),
-                grad_input.data_ptr<scalar_t>(), num_classes);
+                output_size,
+                log_softmax_prob.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
+                weight.data_ptr<scalar_t>(), sum_buff_along_class.data_ptr<scalar_t>(),
+                grad_input.data_ptr<scalar_t>(),
+                gamma, alpha, num_classes);
       });

   AT_CUDA_CHECK(cudaGetLastError());

@@ -29,18 +29,26 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
                        grad_input, gamma, alpha);
 }

-void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
-                                     Tensor output, float gamma, float alpha) {
-  DISPATCH_DEVICE_IMPL(softmax_focal_loss_forward_impl, input, target, weight,
-                       output, gamma, alpha);
+void softmax_focal_loss_forward_impl(const Tensor log_softmax_prob,
+                                     const Tensor target,
+                                     const Tensor weight,
+                                     Tensor output,
+                                     const float gamma,
+                                     const float alpha) {
+  DISPATCH_DEVICE_IMPL(softmax_focal_loss_forward_impl, log_softmax_prob,
+                       target, weight, output, gamma, alpha);
 }

-void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
-                                      Tensor weight, Tensor buff,
-                                      Tensor grad_input, float gamma,
-                                      float alpha) {
-  DISPATCH_DEVICE_IMPL(softmax_focal_loss_backward_impl, input, target, weight,
-                       buff, grad_input, gamma, alpha);
+void softmax_focal_loss_backward_impl(const Tensor log_softmax_prob,
+                                      const Tensor target,
+                                      const Tensor weight,
+                                      Tensor sum_buff_along_class,
+                                      Tensor grad_input,
+                                      const float gamma,
+                                      const float alpha) {
+  DISPATCH_DEVICE_IMPL(softmax_focal_loss_backward_impl, log_softmax_prob,
+                       target, weight, sum_buff_along_class, grad_input,
+                       gamma, alpha);
 }

 #ifdef MMCV_WITH_DIOPI

@@ -145,14 +153,24 @@ void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
 #endif
 }

-void softmax_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
-                                Tensor output, float gamma, float alpha) {
-  softmax_focal_loss_forward_impl(input, target, weight, output, gamma, alpha);
+void softmax_focal_loss_forward(const Tensor log_softmax_prob,
+                                const Tensor target,
+                                const Tensor weight,
+                                Tensor output,
+                                const float gamma,
+                                const float alpha) {
+  softmax_focal_loss_forward_impl(log_softmax_prob, target, weight,
+                                  output, gamma, alpha);
 }

-void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
-                                 Tensor buff, Tensor grad_input, float gamma,
-                                 float alpha) {
-  softmax_focal_loss_backward_impl(input, target, weight, buff, grad_input,
+void softmax_focal_loss_backward(const Tensor log_softmax_prob,
+                                 const Tensor target,
+                                 const Tensor weight,
+                                 Tensor sum_buff_along_class,
+                                 Tensor grad_input,
+                                 const float gamma,
+                                 const float alpha) {
+  softmax_focal_loss_backward_impl(log_softmax_prob, target, weight,
+                                   sum_buff_along_class, grad_input,
+                                   gamma, alpha);
 }

@@ -103,12 +103,20 @@ void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
 void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
                                  Tensor grad_input, float gamma, float alpha);

-void softmax_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
-                                Tensor output, float gamma, float alpha);
+void softmax_focal_loss_forward(const Tensor log_softmax_prob,
+                                const Tensor target,
+                                const Tensor weight,
+                                Tensor output,
+                                const float gamma,
+                                const float alpha);

-void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
-                                 Tensor buff, Tensor grad_input, float gamma,
-                                 float alpha);
+void softmax_focal_loss_backward(const Tensor log_softmax_prob,
+                                 const Tensor target,
+                                 const Tensor weight,
+                                 Tensor sum_buff_along_class,
+                                 Tensor grad_input,
+                                 const float gamma,
+                                 const float alpha);

 void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor,
                                Tensor weight_tensor, Tensor out_tensor, int b,

@@ -566,13 +574,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("weight"), py::arg("grad_input"), py::arg("gamma"),
         py::arg("alpha"));
   m.def("softmax_focal_loss_forward", &softmax_focal_loss_forward,
-        "softmax_focal_loss_forward", py::arg("input"), py::arg("target"),
-        py::arg("weight"), py::arg("output"), py::arg("gamma"),
-        py::arg("alpha"));
-  m.def("softmax_focal_loss_backward", &softmax_focal_loss_backward,
-        "softmax_focal_loss_backward", py::arg("input"), py::arg("target"),
-        py::arg("weight"), py::arg("buff"), py::arg("grad_input"),
+        "softmax_focal_loss_forward", py::arg("log_softmax_prob"),
+        py::arg("target"), py::arg("weight"), py::arg("output"),
+        py::arg("gamma"), py::arg("alpha"));
+  m.def("softmax_focal_loss_backward", &softmax_focal_loss_backward,
+        "softmax_focal_loss_backward", py::arg("log_softmax_prob"),
+        py::arg("target"), py::arg("weight"), py::arg("sum_buff_along_class"),
+        py::arg("grad_input"), py::arg("gamma"), py::arg("alpha"));
   m.def("three_interpolate_forward", &three_interpolate_forward,
         "three_interpolate_forward", py::arg("points_tensor"),
         py::arg("idx_tensor"), py::arg("weight_tensor"), py::arg("out_tensor"),

@@ -3,6 +3,7 @@ from typing import Optional, Union

 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable

@@ -132,16 +133,13 @@ class SoftmaxFocalLossFunction(Function):
         ctx.alpha = float(alpha)
         ctx.reduction = ctx.reduction_dict[reduction]

-        channel_stats, _ = torch.max(input, dim=1)
-        input_softmax = input - channel_stats.unsqueeze(1).expand_as(input)
-        input_softmax.exp_()
+        # log_softmax for numerical stability
+        log_softmax_prob = F.log_softmax(input, dim=1)

-        channel_stats = input_softmax.sum(dim=1)
-        input_softmax /= channel_stats.unsqueeze(1).expand_as(input)
-        output = input.new_zeros(input.size(0))
+        output = input.new_zeros(input.size())

         ext_module.softmax_focal_loss_forward(
-            input_softmax,
+            log_softmax_prob,
             target,
             weight,
             output,

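The "# log_softmax for numerical stability" comment is the motivation for this change: the old path exponentiated max-shifted logits and normalized them into probabilities, and the old kernel then had to clamp with FLT_MIN before taking the log, whereas F.log_softmax stays in log space via the log-sum-exp trick. A small illustration (the values are made up; this snippet is not part of the patch):

import torch
import torch.nn.functional as F

# Extreme logits: the probability of the second class underflows to 0 in
# float32, so taking log() of the softmax output produces -inf, while
# log_softmax returns the exact finite value.
logits = torch.tensor([[100.0, -100.0]])

prob = torch.softmax(logits, dim=1)
print(torch.log(prob))            # tensor([[0., -inf]])

log_prob = F.log_softmax(logits, dim=1)
print(log_prob)                   # tensor([[0., -200.]])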
@@ -152,27 +150,30 @@
             output = output.sum() / input.size(0)
         elif ctx.reduction == ctx.reduction_dict['sum']:
             output = output.sum()
-        ctx.save_for_backward(input_softmax, target, weight)
+        ctx.save_for_backward(log_softmax_prob, target, weight)
         return output

     @staticmethod
     @once_differentiable
     def backward(ctx, grad_output: torch.Tensor) -> tuple:
-        input_softmax, target, weight = ctx.saved_tensors
-        buff = input_softmax.new_zeros(input_softmax.size(0))
-        grad_input = input_softmax.new_zeros(input_softmax.size())
+        log_softmax_prob, target, weight = ctx.saved_tensors
+
+        sum_buff_along_class = log_softmax_prob.new_zeros(
+            log_softmax_prob.size(0))
+        grad_input = log_softmax_prob.new_zeros(log_softmax_prob.size())
+
         ext_module.softmax_focal_loss_backward(
-            input_softmax,
+            log_softmax_prob,
             target,
             weight,
-            buff,
+            sum_buff_along_class,
             grad_input,
             gamma=ctx.gamma,
             alpha=ctx.alpha)

         grad_input *= grad_output
         if ctx.reduction == ctx.reduction_dict['mean']:
-            grad_input /= input_softmax.size(0)
+            grad_input /= log_softmax_prob.size(0)
         return grad_input, None, None, None, None, None

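As a sanity check on the closed-form gradient the kernels implement (the scaled B_j minus p_j times the per-sample sum of B_i), the analytic formula can be compared against autograd on a pure-PyTorch focal loss. This is a standalone sketch under the same alpha convention as the kernels; none of the names below exist in mmcv, and the tolerance is illustrative:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 4, 3
gamma, alpha = 2.0, 0.25
logits = torch.randn(N, C, requires_grad=True)
target = torch.randint(0, C, (N,))
idx = torch.arange(N)

# Reference: sum of per-sample focal terms, differentiated by autograd.
log_prob = F.log_softmax(logits, dim=1)
log_p = log_prob[idx, target]
p_t = log_p.exp()
alpha_fac = torch.full((N,), alpha)
alpha_fac[target == 0] = 1 - alpha
loss = (-alpha_fac * (1 - p_t).pow(gamma) * log_p).sum()
loss.backward()

# Closed-form gradient, mirroring the two kernel passes.
with torch.no_grad():
    prob = log_prob.exp()
    grad = torch.zeros_like(logits)
    # First pass: B_j (only the labelled column is non-zero) ...
    B = alpha_fac * (-(1 - p_t).pow(gamma)
                     + gamma * (1 - p_t).pow(gamma - 1) * log_p * p_t)
    grad[idx, target] = B
    sum_B = grad.sum(dim=1)          # ... and the per-sample sum of B_i.
    # Second pass: subtract p_j * sum_i B_i.
    grad = grad - prob * sum_B.unsqueeze(1)

print(torch.allclose(grad, logits.grad, atol=1e-5))   # expected: True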
@@ -21,13 +21,13 @@ inputs = [
     ([[1e-6, 2e-6, 3e-6], [4e-6, 5e-5, 6e-4], [7e-3, 8e-2, 9e-1]], [1, 2, 0]),
 ]

-softmax_outputs = [(0.00566451, [[-0.00657264, 0.00657264],
-                                 [0.00657264, -0.00657264]]),
-                   (0.34956908, [[0.10165970, 0.03739851, -0.13905823],
-                                 [0.01227554, -0.10298023, 0.09070466]]),
-                   (0.15754992, [[0.02590877, -0.05181759, 0.02590882],
-                                 [0.02589641, 0.02589760, -0.05179400],
-                                 [-0.07307514, 0.02234372, 0.05073142]])]
+softmax_outputs = [(0.01132904, [[-0.01971794, 0.01971793],
+                                 [0.00657264, -0.00657265]]),
+                   (0.34956908, [[0.10165971, 0.03739851, -0.13905823],
+                                 [0.01227554, -0.10298022, 0.09070467]]),
+                   (0.30995172, [[0.02590877, -0.05181758, 0.02590882],
+                                 [0.02589641, 0.02589760, -0.05179401],
+                                 [-0.21922545, 0.06703118, 0.15219429]])]

 sigmoid_outputs = [(0.13562961, [[-0.00657264, 0.11185755],
                                  [0.11185755, -0.00657264]]),