# EasyCV/easycv/core/sailfish/function.py
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions."""
from __future__ import absolute_import, division, print_function
import torch


class _Cat(torch.autograd.Function):
    """Concat inputs."""

    @staticmethod
    def forward(ctx, inputs, dim, rank, world_size):  # noqa: E501 # pylint: disable=arguments-differ
        r"""Cat is defined as:

        .. math::
            \text{all_cat}(x_i) = \bigoplus_j x_j
        """
        ctx.dim = dim
        ctx.rank = rank
        ctx.world_size = world_size
        all_inputs = [
            torch.zeros(
                inputs.size(), dtype=inputs.dtype, device=inputs.device)
            for _ in range(world_size)
        ]
        torch.distributed.all_gather(all_inputs, inputs)
        output = torch.cat(all_inputs, dim=dim)
        output.requires_grad_()
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=arguments-differ
        r"""Gradient of Cat is defined as:

        .. math::
            \nabla \text{all_cat}(x_i) = \text{split}(\nabla x_i)
        """
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input)
        grad_input_dim_size = grad_input.size()[ctx.dim]
        assert grad_input_dim_size % ctx.world_size == 0
        split_size = grad_input_dim_size // ctx.world_size
        grad_input_splits = torch.split(grad_input, split_size, dim=ctx.dim)
        return grad_input_splits[ctx.rank], None, None, None


def all_cat(inputs, dim=0, rank=0, world_size=1):
    """Concat ``inputs`` from all workers along dimension ``dim``."""
    return _Cat.apply(inputs, dim, rank, world_size)
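
# Usage sketch for ``all_cat`` (illustrative, not part of the original
# module): each worker contributes its local shard and receives the
# concatenation of all shards; the backward pass hands every worker the
# slice of the gradient that belongs to its own input. Assumes an
# initialized ``torch.distributed`` process group; names and shapes below
# are hypothetical.
#
#   rank = torch.distributed.get_rank()
#   world_size = torch.distributed.get_world_size()
#   local_feats = torch.randn(8, 128, device='cuda', requires_grad=True)
#   global_feats = all_cat(local_feats, dim=0, rank=rank,
#                          world_size=world_size)
#   # global_feats.size() == (8 * world_size, 128)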


class _Sum(torch.autograd.Function):
    """Sum inputs."""

    @staticmethod
    def forward(_, inputs):  # pylint: disable=arguments-differ
        r"""Sum is defined as:

        .. math::
            \text{all_sum}(x_i) = \sum_j x_j
        """
        inputs_sum = inputs.clone()
        torch.distributed.all_reduce(inputs_sum)
        inputs_sum.requires_grad_()
        return inputs_sum

    @staticmethod
    def backward(_, grad_output):  # pylint: disable=arguments-differ
        r"""Gradient of Sum is defined as:

        .. math::
            \nabla \text{all_sum}(x_i) = \sum_j\nabla x_j
        """
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input)
        return grad_input


def all_sum(inputs):
    """Sum ``inputs`` across all workers, element-wise."""
    return _Sum.apply(inputs)
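
# Usage sketch for ``all_sum`` (illustrative): an element-wise all-reduce
# whose gradient is also all-reduced, e.g. for accumulating a partial term
# computed on every worker. ``partial_term`` is a hypothetical local tensor
# and an initialized process group is assumed.
#
#   total = all_sum(partial_term)   # same summed value on every worker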


class _LogSoftmax(torch.autograd.Function):
    """Compute log softmax of logits."""

    @staticmethod
    def forward(ctx, logits, epsilon):  # pylint: disable=arguments-differ
        r"""LogSoftmax is defined as:

        .. math::
            \log(\text{softmax}(x_i))
            = \log\left(\frac{\text{e}^{x_i}}{\sum_j\text{e}^{x_j}}\right)
            = x_i - \log\sum_j\text{e}^{x_j}

        For numerical stability, it first subtracts the per-row maximum of
        the logits:

        .. math::
            \log(\text{softmax}(x_i))
            = \hat{x_i} - \log\sum_j\text{e}^{\hat{x_j}},
            \hat{x} = x - \max_j{x_j}
        """
        ctx.logits_dtype = logits.dtype
        logits_max = torch.max(logits, dim=1).values
        torch.distributed.all_reduce(
            logits_max, op=torch.distributed.ReduceOp.MAX)
        logits = logits - logits_max.view(-1, 1)
        logits_exp = torch.exp(logits)
        logits_exp_sum = torch.sum(logits_exp, dim=1)
        torch.distributed.all_reduce(logits_exp_sum)
        logits_exp_sum_log = torch.log(logits_exp_sum + epsilon)
        prob_log = logits - logits_exp_sum_log.view(-1, 1)
        ctx.save_for_backward(prob_log)
        prob_log.requires_grad_()
        return prob_log

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=arguments-differ
        r"""Gradient of LogSoftmax is defined as:

        .. math::
            \nabla\log(\text{softmax}(x_i))
            = \nabla x_i - \text{softmax}(x_i) \sum_j \nabla x_j
        """
        grad_output_sum = torch.sum(grad_output, dim=1)
        torch.distributed.all_reduce(grad_output_sum)
        prob_log, = ctx.saved_tensors
        grad_input = torch.exp(prob_log) * grad_output_sum.view(-1, 1)
        grad_input = grad_output - grad_input
        grad_input = grad_input.type(dtype=ctx.logits_dtype)
        return grad_input, None


def all_log_softmax(logits, epsilon=1e-8):
    """Log softmax over class-sharded logits."""
    return _LogSoftmax.apply(logits, epsilon)
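
# Usage sketch for ``all_log_softmax`` (illustrative): each worker holds the
# logits of its own class shard, and the exp-sum is all-reduced so the
# normalization runs over the union of all shards. Shapes are hypothetical.
#
#   shard_logits = torch.randn(32, num_classes // world_size, device='cuda')
#   log_probs = all_log_softmax(shard_logits)  # normalized over all classes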


class _NLLLoss(torch.autograd.Function):
    """Calculate NLLLoss from mask."""

    @staticmethod
    def forward(ctx, inputs, correct_mask):  # pylint: disable=arguments-differ
        ctx.inputs_size = inputs.size()
        ctx.save_for_backward(correct_mask)
        loss = torch.sum(inputs * correct_mask) / -ctx.inputs_size[0]
        torch.distributed.all_reduce(loss)
        loss.requires_grad_()
        return loss

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=arguments-differ
        correct_mask, = ctx.saved_tensors
        grad_input = grad_output.repeat(*ctx.inputs_size)
        grad_input = grad_input * correct_mask / -ctx.inputs_size[0]
        return grad_input, None


def all_nll_loss(inputs, correct_mask):
    """NLL loss of class-sharded log-probabilities given a correct mask."""
    return _NLLLoss.apply(inputs, correct_mask)
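
# Usage sketch for ``all_nll_loss`` (illustrative): combined with
# ``shard_correct_mask`` below, it selects each sample's target
# log-probability from whichever shard owns that class and averages over the
# batch, matching a non-sharded cross-entropy. ``log_probs`` and ``target``
# are hypothetical tensors.
#
#   mask = shard_correct_mask(target, log_probs, rank=rank)
#   loss = all_nll_loss(log_probs, mask)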


def shard_target_and_mask(target, output_features, rank=0):
    """Shard ``target`` to local class indices and build a validity mask."""
    target_shard_begin = output_features * rank
    target_shard_end = output_features * (rank + 1)
    target_shard_lmask = torch.ge(target, target_shard_begin)
    target_shard_rmask = torch.lt(target, target_shard_end)
    target_mask = (target_shard_lmask * target_shard_rmask).to(target.device)
    target_shard = (target - target_shard_begin) * target_mask.long()
    return target_shard, target_mask
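
# Worked example for ``shard_target_and_mask`` (illustrative values): with
# ``output_features=10`` and ``rank=1`` this worker owns classes [10, 20),
# so label 13 maps to local index 3 with a ``True`` mask, while label 5
# belongs to rank 0 and is masked out (its local index is forced to 0).
#
#   target = torch.tensor([13, 5])
#   target_shard, target_mask = shard_target_and_mask(target, 10, rank=1)
#   # target_shard -> tensor([3, 0]); target_mask -> tensor([True, False])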


def shard_correct_mask(target, inputs, rank=0):
    """Get correct mask of inputs."""
    inputs_size = inputs.size()
    target_shard_begin = inputs_size[1] * rank
    target_shard_end = inputs_size[1] * (rank + 1)
    target_shard_lmask = torch.ge(target, target_shard_begin)
    target_shard_rmask = torch.lt(target, target_shard_end)
    target_mask = (target_shard_lmask * target_shard_rmask).to(target.device)
    target_shard = (target - target_shard_begin) * target_mask.long()
    mask = torch.zeros(inputs_size, device=inputs.device, dtype=inputs.dtype)
    mask.scatter_(1, target_shard.view(-1, 1).long(), 1)
    mask.masked_fill_((~target_mask).view(-1, 1).expand(*inputs.size()), 0)
    return mask
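
# Usage sketch for ``shard_correct_mask`` (illustrative): builds a one-hot
# mask over this worker's class shard, leaving all-zero rows for samples
# whose target class lives on another worker. Shapes are hypothetical.
#
#   log_probs = all_log_softmax(shard_logits)   # (batch, shard_classes)
#   mask = shard_correct_mask(target, log_probs, rank=rank)
#   # mask[i, j] == 1 iff target[i] == rank * shard_classes + j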


def shard_correct_predictions(target, logits, world_size=1):
    r"""Count correct top-1 predictions for class-sharded logits."""
    # Local top-1 logit and class on this shard.
    shard_max_logits, shard_expected_class = torch.max(logits, dim=1)
    # All-gather the per-shard maxima to find which shard wins globally.
    all_max_logits = [
        torch.zeros(
            shard_max_logits.size(),
            dtype=shard_max_logits.dtype,
            device=shard_max_logits.device) for _ in range(world_size)
    ]
    torch.distributed.all_gather(all_max_logits, shard_max_logits)
    all_max_logits = torch.cat([t.view(-1, 1) for t in all_max_logits], dim=1)
    rank_pred = torch.max(all_max_logits, dim=1)[1].view(-1, 1)
    # All-gather each shard's local argmax class.
    all_shard_pred = [
        torch.zeros(
            shard_expected_class.size(),
            dtype=shard_expected_class.dtype,
            device=shard_expected_class.device) for _ in range(world_size)
    ]
    torch.distributed.all_gather(all_shard_pred, shard_expected_class)
    all_shard_pred = torch.cat([t.view(-1, 1) for t in all_shard_pred], dim=1)
    all_shard_pred_mask = torch.zeros(
        all_shard_pred.size(),
        device=all_shard_pred.device,
        dtype=all_shard_pred.dtype)
    all_shard_pred_mask.scatter_(1, rank_pred.long(), 1)
    shard_pred = torch.sum(
        all_shard_pred * all_shard_pred_mask, dim=1).view(-1, 1)
    # Convert the winning shard's local class index to a global index.
    pred = shard_pred + rank_pred * logits.size()[1]
    return (pred == target.data.view_as(pred)).sum().item()
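
# Usage sketch for ``shard_correct_predictions`` (illustrative): every worker
# passes its own logits shard together with the same global ``target``; the
# returned count of correct top-1 predictions is identical on all workers
# because it is computed from all-gathered per-shard maxima.
#
#   num_correct = shard_correct_predictions(target, shard_logits,
#                                           world_size=world_size)
#   top1_acc = num_correct / target.size(0)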


def shard_topk_correct_predictions(target, logits, k, world_size=1):
    r"""Count correct top-k predictions for class-sharded logits."""
    # Step 1: Compute top-k of shard logits.
    logits_topk, logits_topk_idx = torch.topk(logits, k, dim=1)
    all_logits_topk = [
        torch.zeros(
            logits_topk.size(),
            dtype=logits_topk.dtype,
            device=logits_topk.device) for _ in range(world_size)
    ]
    torch.distributed.all_gather(all_logits_topk, logits_topk)
    all_logits_topk = torch.cat([t.view(-1, k) for t in all_logits_topk],
                                dim=1)
    all_logits_topk_idx = [
        torch.zeros(
            logits_topk_idx.size(),
            dtype=logits_topk_idx.dtype,
            device=logits_topk_idx.device) for _ in range(world_size)
    ]
    torch.distributed.all_gather(all_logits_topk_idx, logits_topk_idx)
    all_logits_topk_idx = torch.cat(
        [t.view(-1, k) for t in all_logits_topk_idx], dim=1)

    # Step 2: Compute global top-k indices.
    _, all_logits_topk_topk_idx = torch.topk(all_logits_topk, k, dim=1)
    all_logits_topk_topk_idx = all_logits_topk_topk_idx.view(-1, k)
    all_logits_topk_mask = torch.zeros(
        all_logits_topk_idx.size(),
        device=all_logits_topk_idx.device,
        dtype=all_logits_topk_idx.dtype)
    all_logits_topk_mask.scatter_(1, all_logits_topk_topk_idx.long(), 1)
    batch_size, shard_num_classes = logits.size()
    all_logits_topk_base = torch.cat([
        torch.ones([batch_size, k],
                   device=all_logits_topk_idx.device,
                   dtype=all_logits_topk_idx.dtype) * p * shard_num_classes
        for p in range(world_size)
    ], dim=1)
    all_logits_topk_gidx = all_logits_topk_base + all_logits_topk_idx

    # Step 3: Compute predictions and check.
    pred = torch.masked_select(
        all_logits_topk_gidx,
        all_logits_topk_mask.type(torch.bool)).view(batch_size, k)
    return (pred == target.view(-1, 1)).sum().item()
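

# End-to-end sketch (illustrative, not part of the original module): one
# plausible way the primitives above compose into a class-sharded softmax
# cross-entropy with top-1 accuracy. ``shard_logits`` is this worker's slice
# of the classifier output; all names are hypothetical and an initialized
# ``torch.distributed`` process group is assumed.
#
#   rank = torch.distributed.get_rank()
#   world_size = torch.distributed.get_world_size()
#   log_probs = all_log_softmax(shard_logits)
#   mask = shard_correct_mask(target, log_probs, rank=rank)
#   loss = all_nll_loss(log_probs, mask)
#   correct = shard_correct_predictions(target, shard_logits,
#                                       world_size=world_size)
#   loss.backward()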