Add sailfish for fully sharded data parallel training

pull/97/head
yuanman.ym 2022-06-16 19:24:37 +08:00
parent 5110de7635
commit 813a191876
9 changed files with 1728 additions and 0 deletions

View File

@@ -0,0 +1,36 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pytorch extension for distributed training."""
from __future__ import absolute_import, division, print_function
from easycv.core.sailfish.activation import LogSoftmax # noqa: F401
from easycv.core.sailfish.linear import ArcFaceLinear # noqa: F401
from easycv.core.sailfish.linear import Linear # noqa: F401
from easycv.core.sailfish.loss import ArcMarginLoss # noqa: F401
from easycv.core.sailfish.loss import CrossEntropyLoss # noqa: F401
from easycv.core.sailfish.loss import FocalLoss # noqa: F401
from easycv.core.sailfish.loss import NLLLoss # noqa: F401
from easycv.core.sailfish.loss import SoftmaxLoss # noqa: F401
from easycv.core.sailfish.util import BiasUniformInitializer # noqa: F401
from easycv.core.sailfish.util import DistributedParallel # noqa: F401
from easycv.core.sailfish.util import KaimingUniformInitializer # noqa: F401
from easycv.core.sailfish.util import ModelParallel # noqa: F401
from easycv.core.sailfish.util import NormalInitializer # noqa: F401
from easycv.core.sailfish.util import OnesInitializer # noqa: F401
from easycv.core.sailfish.util import ParameterInitializer # noqa: F401
from easycv.core.sailfish.util import RenormUniformInitializer # noqa: F401
from easycv.core.sailfish.util import XavierUniformInitializer # noqa: F401
from easycv.core.sailfish.util import ZerosInitializer # noqa: F401
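
A minimal end-to-end sketch of how the package above is meant to be wired up (illustrative, not part of this commit; it assumes torch.distributed is already initialized with one GPU per rank, and all sizes are arbitrary):

import torch
from easycv.core import sailfish

rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
parallel = sailfish.ModelParallel(rank, world_size)

# The classification head is sharded over ranks: each rank owns
# out_features // world_size columns of the ArcFace weight.
head = sailfish.ArcMarginLoss(
    in_features=512,      # embedding size (assumed)
    out_features=100000,  # total classes (assumed; must be divisible by world_size)
    margin=0.5,
    scale=64.0,
    parallel=parallel).cuda()

features = torch.randn(32, 512).cuda()           # per-rank embeddings
target = torch.randint(0, 100000, (32,)).cuda()  # per-rank labels
loss = head(features, target)  # gathers features/labels across ranks internally
loss.backward()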

View File

@@ -0,0 +1,52 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Activation modules."""
from __future__ import absolute_import, division, print_function
import torch
from easycv.core.sailfish.util import ModelParallel
class LogSoftmax(torch.nn.Module):
r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
input Tensor rescaling them so that the elements of the
n-dimensional output Tensor lie in the range (0, 1].
Shape:
- Input: :math:`(*)` where `*` means, any number of additional
dimensions
- Output: :math:`(*)`, same shape as the input
Returns:
a Tensor of the same dimension and shape as the input with
values in the range (-inf,0].
Examples::
>>> m = LogSoftmax()
>>> input = torch.randn(2, 3)
>>> output = m(input)
"""
def __init__(self, epsilon=0, parallel=None):
super(LogSoftmax, self).__init__()
self.epsilon = epsilon
self.parallel = parallel
def forward(self, logits): # pylint: disable=arguments-differ
if isinstance(self.parallel, ModelParallel):
return self.parallel.log_softmax(logits, epsilon=self.epsilon)
        return torch.nn.functional.log_softmax(logits, dim=1)
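
A short sketch of the two code paths above (illustrative; the fallback path needs no process group, while the ModelParallel path expects each rank to hold one shard of the class dimension):

import torch
from easycv.core.sailfish.activation import LogSoftmax

# Fallback: with no parallel object this behaves like the stock log_softmax.
m = LogSoftmax()
x = torch.randn(2, 3)
out = m(x)  # same shape as x, values in (-inf, 0]

# Sharded: pass parallel=ModelParallel(rank, world_size); each rank then feeds
# its shard of the logits and the max/sum reductions run across all ranks.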

View File

@@ -0,0 +1,269 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions."""
from __future__ import absolute_import, division, print_function
import torch
class _Cat(torch.autograd.Function):
"""Concat inputs."""
@staticmethod
def forward(ctx, inputs, dim, rank, world_size): # noqa: E501 # pylint: disable=arguments-differ
r"""Cat is defined as:
.. math::
\text{all_cat}(x_i) = \bigoplus_j x_j
"""
ctx.dim = dim
ctx.rank = rank
ctx.world_size = world_size
all_inputs = [
torch.zeros(
inputs.size(), dtype=inputs.dtype, device=inputs.device)
for _ in range(world_size)
]
torch.distributed.all_gather(all_inputs, inputs)
output = torch.cat(all_inputs, dim=dim)
output.requires_grad_()
return output
@staticmethod
def backward(ctx, grad_output): # pylint: disable=arguments-differ
r"""Gradient of Cat is defined as:
.. math::
\nabla \text{all_cat}(x_i) = \text{split}(\nabla x_i)
"""
grad_input = grad_output.clone()
torch.distributed.all_reduce(grad_input)
grad_input_dim_size = grad_input.size()[ctx.dim]
assert grad_input_dim_size % ctx.world_size == 0
split_size = grad_input_dim_size // ctx.world_size
grad_input_splits = torch.split(grad_input, split_size, dim=ctx.dim)
return grad_input_splits[ctx.rank], None, None, None
def all_cat(inputs, dim=0, rank=0, world_size=1):
return _Cat.apply(inputs, dim, rank, world_size)
class _Sum(torch.autograd.Function):
"""Sum inputs."""
@staticmethod
def forward(_, inputs): # pylint: disable=arguments-differ
r"""Sum is defined as:
.. math::
\text{all_sum}(x_i) = \sum_j x_j
"""
inputs_sum = inputs.clone()
torch.distributed.all_reduce(inputs_sum)
inputs_sum.requires_grad_()
return inputs_sum
@staticmethod
def backward(_, grad_output): # pylint: disable=arguments-differ
r"""Gradient of Sum is defined as:
.. math::
\nabla \text{all_sum}(x_i) = \sum_j\nabla x_j
"""
grad_input = grad_output.clone()
torch.distributed.all_reduce(grad_input)
return grad_input
def all_sum(inputs):
return _Sum.apply(inputs)
class _LogSoftmax(torch.autograd.Function):
"""Compute log softmax of logits."""
@staticmethod
def forward(ctx, logits, epsilon): # pylint: disable=arguments-differ
r"""LogSoftmax is defined as:
.. math::
\log(\text{softmax}(x_i))
= \log\left(\frac{\text{e}^{x_i}}{\sum_j\text{e}^{x_j}}\right)
= x_i - \log\sum_j\text{e}^{x_j}
        For numerical stability, the maximum logit is subtracted from every logit:
.. math::
\log(\text{softmax}(x_i))
= \hat{x_i} - \log\sum_j\text{e}^{\hat{x_j}},
\hat{x} = x - \max_j{x_j}
"""
ctx.logits_dtype = logits.dtype
logits_max = torch.max(logits, dim=1).values
torch.distributed.all_reduce(
logits_max, op=torch.distributed.ReduceOp.MAX)
logits = logits - logits_max.view(-1, 1)
logits_exp = torch.exp(logits)
logits_exp_sum = torch.sum(logits_exp, dim=1)
torch.distributed.all_reduce(logits_exp_sum)
logits_exp_sum_log = torch.log(logits_exp_sum + epsilon)
prob_log = logits - logits_exp_sum_log.view(-1, 1)
ctx.save_for_backward(prob_log)
prob_log.requires_grad_()
return prob_log
@staticmethod
def backward(ctx, grad_output): # pylint: disable=arguments-differ
r"""Gradient of LogSoftmax is defined as:
.. math::
\nabla\log(\text{softmax}(x_i))
= \nabla x_i - \text{softmax}(x_i) \sum_j \nabla x_j
"""
grad_output_sum = torch.sum(grad_output, dim=1)
torch.distributed.all_reduce(grad_output_sum)
prob_log, = ctx.saved_tensors
grad_input = torch.exp(prob_log) * grad_output_sum.view(-1, 1)
grad_input = grad_output - grad_input
grad_input = grad_input.type(dtype=ctx.logits_dtype)
return grad_input, None
def all_log_softmax(logits, epsilon=1e-8):
return _LogSoftmax.apply(logits, epsilon)
class _NLLLoss(torch.autograd.Function):
"""calculate NLLLoss from mask."""
@staticmethod
def forward(ctx, inputs, correct_mask): # pylint: disable=arguments-differ
ctx.inputs_size = inputs.size()
ctx.save_for_backward(correct_mask)
loss = torch.sum(inputs * correct_mask) / -ctx.inputs_size[0]
torch.distributed.all_reduce(loss)
loss.requires_grad_()
return loss
@staticmethod
def backward(ctx, grad_output): # pylint: disable=arguments-differ
correct_mask, = ctx.saved_tensors
grad_input = grad_output.repeat(*ctx.inputs_size)
grad_input = grad_input * correct_mask / -ctx.inputs_size[0]
return grad_input, None
def all_nll_loss(inputs, correct_mask):
return _NLLLoss.apply(inputs, correct_mask)
def shard_target_and_mask(target, output_features, rank=0):
target_shard_begin = output_features * rank
target_shard_end = output_features * (rank + 1)
target_shard_lmask = torch.ge(target, target_shard_begin)
target_shard_rmask = torch.lt(target, target_shard_end)
target_mask = (target_shard_lmask * target_shard_rmask).to(target.device)
target_shard = (target - target_shard_begin) * target_mask.long()
return target_shard, target_mask
def shard_correct_mask(target, inputs, rank=0):
"""Get correct mask of inputs."""
inputs_size = inputs.size()
target_shard_begin = inputs_size[1] * rank
target_shard_end = inputs_size[1] * (rank + 1)
target_shard_lmask = torch.ge(target, target_shard_begin)
target_shard_rmask = torch.lt(target, target_shard_end)
target_mask = (target_shard_lmask * target_shard_rmask).to(target.device)
target_shard = (target - target_shard_begin) * target_mask.long()
mask = torch.zeros(inputs_size, device=inputs.device, dtype=inputs.dtype)
mask.scatter_(1, target_shard.view(-1, 1).long(), 1)
mask.masked_fill_((~target_mask).view(-1, 1).expand(*inputs.size()), 0)
return mask
def shard_correct_predictions(target, logits, world_size=1):
r"""Calculate correct predictions for logits."""
shard_max_logits, shard_expected_class = torch.max(logits, dim=1)
all_max_logits = [
torch.zeros(
shard_max_logits.size(),
dtype=shard_max_logits.dtype,
device=shard_max_logits.device) for _ in range(world_size)
]
torch.distributed.all_gather(all_max_logits, shard_max_logits)
all_max_logits = torch.cat([t.view(-1, 1) for t in all_max_logits], dim=1)
rank_pred = torch.max(all_max_logits, dim=1)[1].view(-1, 1)
all_shard_pred = [
torch.zeros(
shard_expected_class.size(),
dtype=shard_expected_class.dtype,
device=shard_expected_class.device) for _ in range(world_size)
]
torch.distributed.all_gather(all_shard_pred, shard_expected_class)
all_shard_pred = torch.cat([t.view(-1, 1) for t in all_shard_pred], dim=1)
all_shard_pred_mask = torch.zeros(
all_shard_pred.size(),
device=all_shard_pred.device,
dtype=all_shard_pred.dtype)
all_shard_pred_mask.scatter_(1, rank_pred.long(), 1)
shard_pred = torch.sum(
all_shard_pred * all_shard_pred_mask, dim=1).view(-1, 1)
pred = shard_pred + rank_pred * logits.size()[1]
return (pred == target.data.view_as(pred)).sum().item()
def shard_topk_correct_predictions(target, logits, k, world_size=1):
r"""Calculate correct predictions for logits."""
# Step 1: Compute top-k of shard logits.
logits_topk, logits_topk_idx = torch.topk(logits, k, dim=1)
all_logits_topk = [
torch.zeros(
logits_topk.size(),
dtype=logits_topk.dtype,
device=logits_topk.device) for _ in range(world_size)
]
torch.distributed.all_gather(all_logits_topk, logits_topk)
all_logits_topk = torch.cat([t.view(-1, k) for t in all_logits_topk],
dim=1)
all_logits_topk_idx = [
torch.zeros(
logits_topk_idx.size(),
dtype=logits_topk_idx.dtype,
device=logits_topk_idx.device) for _ in range(world_size)
]
torch.distributed.all_gather(all_logits_topk_idx, logits_topk_idx)
all_logits_topk_idx = torch.cat(
[t.view(-1, k) for t in all_logits_topk_idx], dim=1)
# Step 2: Compute global top-k indices.
_, all_logits_topk_topk_idx = torch.topk(all_logits_topk, k, dim=1)
all_logits_topk_topk_idx = all_logits_topk_topk_idx.view(-1, k)
all_logits_topk_mask = torch.zeros(
all_logits_topk_idx.size(),
device=all_logits_topk_idx.device,
dtype=all_logits_topk_idx.dtype)
all_logits_topk_mask.scatter_(1, all_logits_topk_topk_idx.long(), 1)
batch_size, shard_num_classes = logits.size()
all_logits_topk_base = torch.cat([
torch.ones([batch_size, k],
device=all_logits_topk_idx.device,
dtype=all_logits_topk_idx.dtype) * p * shard_num_classes
for p in range(world_size)
],
dim=1)
all_logits_topk_gidx = all_logits_topk_base + all_logits_topk_idx
# Step 3: Compute predictions and check.
pred = torch.masked_select(all_logits_topk_gidx,
all_logits_topk_mask.type(torch.bool)).view(
batch_size, k)
return (pred == target.view(-1, 1)).sum().item()
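
The autograd functions above all call into torch.distributed, so they need an initialized process group even on a single worker; a minimal single-process sketch (illustrative, assuming the gloo backend is available):

import torch
import torch.distributed as dist
from easycv.core.sailfish.function import all_cat, all_log_softmax, all_sum

# A world of size 1, so every collective is effectively a no-op.
dist.init_process_group(
    'gloo', init_method='tcp://127.0.0.1:29500', rank=0, world_size=1)

logits = torch.randn(4, 8, requires_grad=True)
prob_log = all_log_softmax(logits)  # equals log_softmax(logits, dim=1) here
prob_log.sum().backward()           # backward also all-reduces the gradients

gathered = all_cat(logits.detach(), dim=0, rank=0, world_size=1)  # == logits
total = all_sum(torch.tensor([3.0]))                              # == tensor([3.])

dist.destroy_process_group()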

View File

@@ -0,0 +1,151 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linear modules."""
from __future__ import absolute_import, division, print_function
import math
import torch
from easycv.core.sailfish.util import (BiasUniformInitializer,
KaimingUniformInitializer,
ModelParallel, RenormUniformInitializer)
class Linear(torch.nn.Module):
r"""Applies a linear transformation to the incoming data.
"""
def __init__(self,
in_features,
out_features,
bias=True,
weight_initializer=None,
bias_initializer=None,
parallel=None):
super(Linear, self).__init__()
if isinstance(parallel, ModelParallel):
if out_features % parallel.world_size != 0:
raise ValueError(
                    'out_features must be divisible by parallel.world_size')
self.out_features = out_features // parallel.world_size
else:
self.out_features = out_features
self.in_features = in_features
self.weight = torch.nn.Parameter(
torch.Tensor(self.out_features, self.in_features))
if bias:
self.bias = torch.nn.Parameter(torch.Tensor(self.out_features))
else:
self.register_parameter('bias', None)
self.weight_initializer = weight_initializer
if weight_initializer is None:
self.weight_initializer = KaimingUniformInitializer(math.sqrt(5))
self.bias_initializer = bias_initializer
if bias_initializer is None:
self.bias_initializer = BiasUniformInitializer(self.in_features)
self.reset_parameters()
self.parallel = parallel
def reset_parameters(self):
r"""Reset parameter."""
self.weight_initializer(self.weight)
if self.bias is not None:
self.bias_initializer(self.bias)
def forward(self, features): # pylint: disable=arguments-differ
features = features.type(dtype=self.weight.dtype)
return torch.nn.functional.linear(features, self.weight, self.bias)
class ArcFaceLinear(torch.nn.Module):
r"""Applies a ArcFace transformation to the incoming data.
See https://arxiv.org/abs/1801.05599 .
"""
def __init__(
self,
in_features,
out_features,
margin=0.5,
scale=64.0, # See normface https://arxiv.org/abs/1704.06369
fast_phi=False,
epsilon=0,
weight_initializer=None,
l2_norm=False,
parallel=None):
super(ArcFaceLinear, self).__init__()
if isinstance(parallel, ModelParallel):
if out_features % parallel.world_size != 0:
raise ValueError(
                    'out_features must be divisible by parallel.world_size')
self.out_features = out_features // parallel.world_size
else:
self.out_features = out_features
self.in_features = in_features
self.margin = margin # Angular margin penalty.
        self.scale = scale  # Radius of the hypersphere.
self.fast_phi = fast_phi
self.epsilon = epsilon
self.weight_initializer = weight_initializer
if weight_initializer is None:
self.weight_initializer = RenormUniformInitializer()
self.l2_norm = l2_norm
self.parallel = parallel
self.weight = torch.nn.Parameter(
torch.Tensor(self.out_features, self.in_features))
self.reset_parameters()
self._cos_margin = math.cos(margin)
self._sin_margin = math.sin(margin)
self._threshold = math.cos(math.pi - margin)
self._min = math.sin(math.pi - margin) * self.margin
def reset_parameters(self):
r"""Reset parameters."""
self.weight_initializer(self.weight)
def forward(self, features, target): # pylint: disable=arguments-differ
r"""Compute ::math`\phi = \cos(\theta + margin)` and logits."""
# (N, E) x (E, C) -> (N, C)
features = features.type(dtype=self.weight.dtype)
if self.l2_norm:
features_norm = torch.norm(features, 2, 1, True)
features = torch.div(features, features_norm)
weight_norm = torch.norm(self.weight, 2, 0, True)
weight = torch.div(self.weight, weight_norm)
else:
features = torch.nn.functional.normalize(features)
weight = torch.nn.functional.normalize(self.weight)
cosine = torch.nn.functional.linear(features, weight)
cosine = cosine.clamp(-1, 1) # for numerical stability
sine = torch.sqrt(1. + self.epsilon - cosine * cosine)
phi = cosine * self._cos_margin - sine * self._sin_margin
phi = phi.type(dtype=cosine.dtype)
if self.fast_phi:
phi = torch.where(cosine > 0, phi, cosine)
else:
phi = torch.where(cosine > self._threshold, phi,
cosine - self._min)
if isinstance(self.parallel, ModelParallel):
mask = self.parallel.correct_mask(target, cosine)
else:
mask = torch.zeros(
cosine.size(), device=cosine.device, dtype=cosine.dtype)
mask.scatter_(1, target.view(-1, 1).long(), 1)
logits = (mask * phi) + ((1.0 - mask) * cosine)
logits *= self.scale
return logits
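
A small single-process sketch of the ArcFaceLinear module above, used without model parallelism (illustrative; sizes are arbitrary):

import torch
from easycv.core.sailfish.linear import ArcFaceLinear

fc = ArcFaceLinear(in_features=128, out_features=10, margin=0.5, scale=64.0)
features = torch.randn(4, 128)
target = torch.randint(0, 10, (4,))

# Each output row holds scale * cos(theta + margin) at the target column
# and scale * cos(theta) elsewhere.
logits = fc(features, target)
loss = torch.nn.functional.cross_entropy(logits, target)
loss.backward()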

View File

@@ -0,0 +1,206 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loss modules."""
from __future__ import absolute_import, division, print_function
import torch
from easycv.core.sailfish.activation import LogSoftmax
from easycv.core.sailfish.linear import ArcFaceLinear, Linear
from easycv.core.sailfish.util import ModelParallel
class NLLLoss(torch.nn.Module):
r"""The negative log likelihood loss for log probabilities. It is
useful to train a classification problem with `C` classes.
The `input` given through a forward call is expected to contain
log-probabilities of each class. `input` has to be a Tensor of size either
:math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)`
with :math:`K \geq 1` for the `K`-dimensional case (described later).
Obtaining log-probabilities in a neural network is easily achieved by
adding a `LogSoftmax` layer in the last layer of your network.
You may use `CrossEntropyLoss` instead, if you prefer not to add an
extra layer.
The `target` that this loss expects should be a class index in the range
:math:`[0, C-1]` where `C = number\_classes`.
NLLLoss is defined as:
.. math::
        \ell(x, y) = -\frac{1}{N}\sum_{n=1}^{N} x_{n, y_n}
Args:
        focal: whether to use the FocalLoss formulation.
        focal_gamma: the focusing parameter of FocalLoss.
        parallel: optional ModelParallel instance; if given, the loss is
            computed over sharded log-probabilities.
Shape:
- Input: :math:`(\frac{N}{P}, C)` where `C = number of classes`, or
:math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
in the case of `K`-dimensional loss.
- Target: :math:`(\frac{N}{P}, 1)` where each value is
:math:`0 \leq \text{targets}[i] \leq C-1`, or
:math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
K-dimensional loss.
- Output: scalar.
If :attr:`reduction` is ``'none'``, then the same size as the target:
:math:`(N)`, or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in
the case of K-dimensional loss.
Examples::
>>> m = LogSoftmax(...)
>>> loss = NLLLoss(...)
>>> # input is of size N x C = 3 x 5
>>> input = torch.randn(3, 5, requires_grad=True)
>>> # each element in target has to have 0 <= value < C
>>> target = torch.tensor([1, 0, 4])
>>> output = loss(m(input), target)
>>> output.backward()
"""
def __init__(self, focal=False, focal_gamma=0, parallel=None):
super(NLLLoss, self).__init__()
self.focal = focal
self.focal_gamma = focal_gamma
self.parallel = parallel
def forward(self, logprob, target): # pylint: disable=arguments-differ
"""Compute negative log likelihood loss from log-probs and target."""
if isinstance(self.parallel, ModelParallel):
with torch.no_grad():
mask = self.parallel.correct_mask(target, logprob)
loss = self.parallel.nll_loss(logprob, mask)
if self.focal:
loss_exp = torch.exp(-loss)
loss = (1 - loss_exp)**self.focal_gamma * loss
return loss
loss = torch.nn.functional.nll_loss(logprob, target)
if self.focal:
loss_exp = torch.exp(-loss)
loss = (1 - loss_exp)**self.focal_gamma * loss
return loss
class CrossEntropyLoss(torch.nn.Module):
r"""This criterion combines :func:`LogSoftmax` and
:func:`NLLLoss` in one single class.
"""
def __init__(self, epsilon=0, parallel=None):
super(CrossEntropyLoss, self).__init__()
self._log_softmax = LogSoftmax(epsilon=epsilon, parallel=parallel)
self._nll_loss = NLLLoss(parallel=parallel)
def forward(self, logits, target): # pylint: disable=arguments-differ
# 1. Compute log-probabilities for current shard:
logprob = self._log_softmax(logits)
# 2. Compute NLL loss for all shards.
return self._nll_loss(logprob, target)
class FocalLoss(torch.nn.Module):
r"""This criterion combines :func:`LogSoftmax` and
:func:`NLLLoss` in one single class.
"""
def __init__(self, gamma=0, epsilon=1e-8, parallel=None):
super(FocalLoss, self).__init__()
self._log_softmax = LogSoftmax(epsilon=epsilon, parallel=parallel)
self._nll_loss = NLLLoss(
focal=True, focal_gamma=gamma, parallel=parallel)
def forward(self, logits, target): # pylint: disable=arguments-differ
logprob = self._log_softmax(logits)
return self._nll_loss(logprob, target)
class SoftmaxLoss(torch.nn.Module):
r"""This criterion combines :func:`Linear` and
:func:`CrossEntropyLoss` in one single class.
"""
def __init__(self,
in_features,
out_features,
bias=True,
epsilon=0,
weight_initializer=None,
bias_initializer=None,
parallel=None):
super(SoftmaxLoss, self).__init__()
self._linear = Linear(
in_features,
out_features,
bias=bias,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
parallel=parallel)
self._log_softmax = LogSoftmax(epsilon=epsilon, parallel=parallel)
self._nll_loss = NLLLoss(parallel=parallel)
self._parallel = parallel
def forward(self, features, target): # pylint: disable=arguments-differ
if isinstance(self._parallel, ModelParallel):
features = self._parallel.gather(features)
logits = self._linear(features.squeeze())
logprob = self._log_softmax(logits)
if isinstance(self._parallel, ModelParallel):
target = self._parallel.gather_target(target)
return self._nll_loss(logprob, target)
class ArcMarginLoss(torch.nn.Module):
r"""This criterion combines :func:`ArcFaceLinear` and
:func:`CrossEntropyLoss` in one single class.
"""
def __init__(
self,
in_features,
out_features,
margin=0.5,
scale=64.0, # See normface https://arxiv.org/abs/1704.06369
fast_phi=False,
epsilon=0,
weight_initializer=None,
l2_norm=False,
parallel=None):
super(ArcMarginLoss, self).__init__()
self._linear = ArcFaceLinear(
in_features,
out_features,
margin=margin,
scale=scale,
l2_norm=l2_norm,
fast_phi=fast_phi,
epsilon=epsilon,
weight_initializer=weight_initializer,
parallel=parallel)
self._log_softmax = LogSoftmax(epsilon=epsilon, parallel=parallel)
self._nll_loss = NLLLoss(parallel=parallel)
self._parallel = parallel
def forward(self, features, target): # pylint: disable=arguments-differ
if isinstance(self._parallel, ModelParallel):
features = self._parallel.gather(features)
target = self._parallel.gather_target(target)
logits = self._linear(features.squeeze(), target)
logprob = self._log_softmax(logits)
return self._nll_loss(logprob, target)
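
A minimal single-process sketch of the loss modules above (illustrative; with no parallel object every step runs locally, and ArcMarginLoss is called the same way):

import torch
from easycv.core.sailfish.loss import SoftmaxLoss

criterion = SoftmaxLoss(in_features=64, out_features=10)
features = torch.randn(8, 64)
target = torch.randint(0, 10, (8,))

loss = criterion(features, target)  # Linear -> LogSoftmax -> NLLLoss
loss.backward()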

View File

@@ -0,0 +1,172 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility modules."""
from __future__ import absolute_import, division, print_function
import math
import torch
from easycv.core.sailfish.function import (all_cat, all_log_softmax,
all_nll_loss, all_sum,
shard_correct_mask,
shard_correct_predictions,
shard_target_and_mask,
shard_topk_correct_predictions)
class DistributedParallel:
"""Base class of parallelism."""
def __init__(self, rank, world_size):
self._rank = rank
self._world_size = world_size
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def correct_mask(self, target, inputs):
mask = torch.zeros(
inputs.size(), device=inputs.device, dtype=inputs.dtype)
mask.scatter_(1, target.view(-1, 1).long(), 1)
return mask
def correct_predictions(self, target, logits, k=1):
if k == 1:
            pred = torch.max(logits, dim=1)[1].view(-1, 1)
            return (pred == target.view(-1, 1)).sum().item()
pred = torch.topk(logits, k, dim=1)[1]
return (pred == target.view(-1, 1)).sum().item()
def xavier_uniform_(self, weight, gain=1.):
return torch.nn.init.xavier_uniform_(weight, gain=gain)
class ModelParallel(DistributedParallel):
"""All-to-All Model Parallelism."""
def gather(self, inputs, dim=0, requires_grad=True):
if requires_grad:
return all_cat(
inputs, dim=dim, rank=self.rank, world_size=self.world_size)
all_inputs = [
torch.zeros(
inputs.size(), dtype=inputs.dtype, device=inputs.device)
for _ in range(self.world_size)
]
torch.distributed.all_gather(all_inputs, inputs)
return torch.cat(all_inputs, dim=dim)
def gather_target(self, target):
return self.gather(target, requires_grad=False)
def reduce_sum(self, inputs):
return all_sum(inputs)
def log_softmax(self, logits, epsilon=1e-8):
return all_log_softmax(logits, epsilon=epsilon)
def nll_loss(self, inputs, correct_mask):
return all_nll_loss(inputs, correct_mask)
def correct_mask(self, target, inputs):
return shard_correct_mask(target, inputs, rank=self.rank)
def target_and_mask(self, target, output_features):
return shard_target_and_mask(target, output_features, rank=self.rank)
def correct_predictions(self, target, logits, k=1):
if k == 1:
return shard_correct_predictions(
target, logits, world_size=self.world_size)
return shard_topk_correct_predictions(
target, logits, k, world_size=self.world_size)
class ParameterInitializer:
r"""Base class for parameter initializer."""
    def __call__(self, param, parallel=None):
raise NotImplementedError
class ZerosInitializer(ParameterInitializer):
def __call__(self, param, parallel=None):
torch.nn.init.zeros_(param)
class OnesInitializer(ParameterInitializer):
def __call__(self, param, parallel=None):
torch.nn.init.ones_(param)
class XavierUniformInitializer(ParameterInitializer):
def __init__(self, gain=1.):
self.gain = gain
def __call__(self, param, parallel=None):
if isinstance(parallel, ModelParallel):
if param.dim() != 2:
raise ValueError(
'param with dimensions other than 2 not supported')
r = param.size(1) + param.size(0) * parallel.world_size
a = self.gain * math.sqrt(3.0) * math.sqrt(2.0 / float(r))
torch.nn.init.uniform_(param, -a, a)
else:
torch.nn.init.xavier_uniform_(param, gain=self.gain)
class KaimingUniformInitializer(ParameterInitializer):
def __init__(self, bound):
self.bound = bound
def __call__(self, param, parallel=None):
torch.nn.init.kaiming_uniform_(param, a=self.bound)
class BiasUniformInitializer(ParameterInitializer):
def __init__(self, weight_in_features):
self.weight_in_features = weight_in_features
def __call__(self, param, parallel=None):
a = 1 / math.sqrt(float(self.weight_in_features))
torch.nn.init.uniform_(param, -a, a)
class RenormUniformInitializer(ParameterInitializer):
def __init__(self, maxnorm=1e-5, scale=1e5):
self.maxnorm = maxnorm
self.scale = scale
def __call__(self, param, parallel=None):
param.data.uniform_(-1, 1).renorm_(
2, 0, maxnorm=self.maxnorm).mul_(self.scale)
class NormalInitializer(ParameterInitializer):
def __call__(self, param, parallel=None):
torch.nn.init.normal_(param)
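
A short sketch of the initializers and the non-sharded DistributedParallel helpers above, used standalone (illustrative):

import torch
from easycv.core.sailfish.util import (DistributedParallel,
                                       XavierUniformInitializer,
                                       ZerosInitializer)

weight = torch.empty(10, 4)
bias = torch.empty(10)
XavierUniformInitializer(gain=1.0)(weight)  # plain xavier_uniform_ when parallel is None
ZerosInitializer()(bias)

dp = DistributedParallel(rank=0, world_size=1)
logits = torch.randn(5, 10)
target = torch.randint(0, 10, (5,))
mask = dp.correct_mask(target, logits)                 # one-hot (5, 10) mask of targets
num_correct = dp.correct_predictions(target, logits)   # top-1 hits in the batch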

View File

View File

@@ -0,0 +1,528 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ArcFaceLinear module tests."""
import os
import random
import unittest
import numpy as np
import torch
from easycv.core import sailfish
def mp_vs_ddp_main(gpu, gpus_per_worker):
r"""Model parallel vs. DDP"""
torch.cuda.set_device(gpu)
torch.distributed.init_process_group(
'nccl', rank=gpu, world_size=gpus_per_worker)
try:
num_steps = 5
freeze_num_steps = 5
learning_rate = 0.1
batch_size = 2
image_size = 5
emb_size = 3
num_classes = 8
margin_m = 0.5
margin_s = 64
momentum = 0.9
model_parallel = sailfish.ModelParallel(gpu, gpus_per_worker)
zeros_init = sailfish.ZerosInitializer()
# baseline
torch.manual_seed(42)
random.seed(42)
baseline_fe = torch.nn.Linear(image_size, emb_size).cuda()
baseline_fe = torch.nn.parallel.DistributedDataParallel(
baseline_fe, device_ids=[gpu])
baseline_fe_params = list(baseline_fe.parameters())
baseline_fc = sailfish.ArcFaceLinear(
emb_size,
num_classes,
margin=margin_m,
scale=margin_s,
weight_initializer=zeros_init).cuda()
baseline_fc = torch.nn.parallel.DistributedDataParallel(
baseline_fc, device_ids=[gpu])
baseline_fc_params = list(baseline_fc.parameters())
baseline_criterion = torch.nn.CrossEntropyLoss().cuda()
baseline_optimizer = torch.optim.SGD(
[{
'params': baseline_fe.parameters()
}, {
'params': baseline_fc.parameters()
}],
lr=learning_rate,
momentum=momentum)
baseline_fe.train()
baseline_fc.train()
baseline_criterion.train()
# hybrid parallelism
torch.manual_seed(42)
random.seed(42)
fe = torch.nn.Linear(image_size, emb_size).cuda()
fe = torch.nn.parallel.DistributedDataParallel(fe, device_ids=[gpu])
fe_params = list(fe.parameters())
fc = sailfish.ArcFaceLinear(
emb_size,
num_classes,
margin=margin_m,
scale=margin_s,
weight_initializer=zeros_init,
parallel=model_parallel).cuda()
fc_params = list(fc.parameters())
criterion = sailfish.CrossEntropyLoss(parallel=model_parallel).cuda()
optimizer = torch.optim.SGD([{
'params': fe.parameters()
}, {
'params': fc.parameters()
}],
lr=learning_rate,
momentum=momentum)
fe.train()
fc.train()
criterion.train()
for step in range(num_steps):
# baseline
torch.manual_seed(42 * step + gpu)
random.seed(42 * step + gpu)
baseline_data = torch.randn([batch_size, image_size]).cuda()
baseline_label = torch.as_tensor([
random.randint(0, num_classes - 1) for _ in range(batch_size)
]).cuda()
baseline_features = baseline_fe(baseline_data)
baseline_logits = baseline_fc(baseline_features, baseline_label)
baseline_loss = baseline_criterion(baseline_logits, baseline_label)
baseline_loss = model_parallel.reduce_sum(baseline_loss)
baseline_loss = baseline_loss / gpus_per_worker
baseline_optimizer.zero_grad()
baseline_loss.backward()
baseline_optimizer.step()
# hybrid parallelism
torch.manual_seed(42 * step + gpu)
random.seed(42 * step + gpu)
data = torch.randn([batch_size, image_size]).cuda()
label = torch.as_tensor([
random.randint(0, num_classes - 1) for _ in range(batch_size)
]).cuda()
features = fe(data)
all_features = model_parallel.gather(features)
all_label = model_parallel.gather_target(label)
shard_logits = fc(all_features, all_label)
loss = criterion(shard_logits, all_label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# eval
torch.manual_seed(42 * step + gpu)
random.seed(42 * step + gpu)
with torch.no_grad():
gathered_logits = model_parallel.gather(shard_logits, dim=1)
gathered_baseline_logits = model_parallel.gather(
baseline_logits, dim=0)
logits_norm_val = torch.norm(gathered_logits).item()
baseline_logits_norm_val = torch.norm(
gathered_baseline_logits).item()
np.testing.assert_allclose(
logits_norm_val,
baseline_logits_norm_val,
rtol=1e-5,
atol=1e-4,
err_msg='logits at gpu {} step {}'.format(gpu, step))
loss_val = loss.cpu().detach().numpy()
baseline_loss_val = baseline_loss.cpu().detach().numpy()
np.testing.assert_allclose(
loss_val,
baseline_loss_val,
rtol=1e-5,
atol=1e-4,
err_msg='loss at gpu {} step {}'.format(gpu, step))
fc_grad = model_parallel.gather(fc_params[0].grad)
baseline_fc_grad = baseline_fc_params[0].grad
np.testing.assert_allclose(
fc_grad.cpu().detach().numpy(),
baseline_fc_grad.cpu().detach().numpy(),
rtol=1e-5,
atol=1e-4,
err_msg='fc grad at gpu {} step {}'.format(gpu, step))
fe_weight = fe_params[0]
baseline_fe_weight = baseline_fe_params[0]
np.testing.assert_allclose(
fe_weight.cpu().detach().numpy(),
baseline_fe_weight.cpu().detach().numpy(),
rtol=1e-5,
atol=1e-4,
err_msg='fe weight at gpu {} step {}'.format(gpu, step))
for p in baseline_fe.parameters():
p.requires_grad = False
for p in fe.parameters():
p.requires_grad = False
for step in range(freeze_num_steps):
# baseline
torch.manual_seed(100 * step + gpu)
random.seed(100 * step + gpu)
baseline_data = torch.randn([batch_size, image_size]).cuda()
baseline_label = torch.as_tensor([
random.randint(0, num_classes - 1) for _ in range(batch_size)
]).cuda()
baseline_features = baseline_fe(baseline_data)
baseline_logits = baseline_fc(baseline_features, baseline_label)
baseline_loss = baseline_criterion(baseline_logits, baseline_label)
baseline_loss = model_parallel.reduce_sum(baseline_loss)
baseline_loss = baseline_loss / gpus_per_worker
baseline_optimizer.zero_grad()
baseline_loss.backward()
baseline_optimizer.step()
# hybrid parallelism
torch.manual_seed(100 * step + gpu)
random.seed(100 * step + gpu)
data = torch.randn([batch_size, image_size]).cuda()
label = torch.as_tensor([
random.randint(0, num_classes - 1) for _ in range(batch_size)
]).cuda()
features = fe(data)
all_features = model_parallel.gather(features)
all_label = model_parallel.gather_target(label)
shard_logits = fc(all_features, all_label)
loss = criterion(shard_logits, all_label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# eval
torch.manual_seed(100 * step + gpu)
random.seed(100 * step + gpu)
with torch.no_grad():
gathered_logits = model_parallel.gather(shard_logits, dim=1)
gathered_baseline_logits = model_parallel.gather(
baseline_logits, dim=0)
logits_norm_val = torch.norm(gathered_logits).item()
baseline_logits_norm_val = torch.norm(
gathered_baseline_logits).item()
np.testing.assert_allclose(
logits_norm_val,
baseline_logits_norm_val,
rtol=1e-5,
atol=1e-4,
err_msg='freeze logits at gpu {} step {}'.format(
gpu, step))
loss_val = loss.cpu().detach().numpy()
baseline_loss_val = baseline_loss.cpu().detach().numpy()
np.testing.assert_allclose(
loss_val,
baseline_loss_val,
rtol=1e-5,
atol=1e-4,
err_msg='freeze loss at gpu {} step {}'.format(gpu, step))
fc_grad = model_parallel.gather(fc_params[0].grad)
baseline_fc_grad = baseline_fc_params[0].grad
np.testing.assert_allclose(
fc_grad.cpu().detach().numpy(),
baseline_fc_grad.cpu().detach().numpy(),
rtol=1e-5,
atol=1e-4,
err_msg='freeze fc grad at gpu {} step {}'.format(
gpu, step))
fe_weight = fe_params[0]
baseline_fe_weight = baseline_fe_params[0]
np.testing.assert_allclose(
fe_weight.cpu().detach().numpy(),
baseline_fe_weight.cpu().detach().numpy(),
rtol=1e-5,
atol=1e-4,
err_msg='freeze fe weight at gpu {} step {}'.format(
gpu, step))
finally:
torch.distributed.destroy_process_group()
def mp_main(gpu,
gpus_per_worker,
results,
num_steps=1,
batch_size=1,
num_classes=8):
r"""Model parallel"""
torch.cuda.set_device(gpu)
torch.distributed.init_process_group(
'nccl', rank=gpu, world_size=gpus_per_worker)
zeros_init = sailfish.ZerosInitializer()
try:
emb_size = 3
learning_rate = 0.1
margin_m = 0.5
margin_s = 64
momentum = 0.9
image_size = 6
model_parallel = sailfish.ModelParallel(gpu, gpus_per_worker)
# hybrid parallelism
torch.manual_seed(42)
random.seed(42)
fe = torch.nn.Linear(image_size, emb_size).cuda()
fc = sailfish.ArcFaceLinear(
emb_size,
num_classes,
margin=margin_m,
scale=margin_s,
weight_initializer=zeros_init,
parallel=model_parallel).cuda()
fc_params = list(fc.parameters())
criterion = sailfish.CrossEntropyLoss(parallel=model_parallel).cuda()
optimizer = torch.optim.SGD(
fc.parameters(), lr=learning_rate, momentum=momentum)
fc.train()
criterion.train()
for step in range(num_steps):
baseline = results[step]
torch.manual_seed(42 * step + gpu)
random.seed(42 * step + gpu)
data = torch.randn([batch_size, image_size]).cuda()
features = fe(data)
label = torch.as_tensor([
random.randint(0, num_classes - 1) for _ in range(batch_size)
]).cuda()
all_features = model_parallel.gather(features)
all_label = model_parallel.gather_target(label)
torch.manual_seed(42 * step)
random.seed(42 * step)
np.testing.assert_equal(
list(all_features.size()),
baseline['features/size'],
err_msg='Wrong features size at gpu {} step {}'.format(
gpu, step))
np.testing.assert_allclose(
torch.norm(all_features).item(),
baseline['features/norm'],
rtol=1e-5,
err_msg='Wrong features norm at gpu {} step {}'.format(
gpu, step))
shard_logits = fc(all_features, all_label)
loss = criterion(shard_logits, all_label)
np.testing.assert_allclose(
loss.item(),
baseline['loss'],
rtol=1e-5,
err_msg='Wrong loss at gpu {} step {}'.format(gpu, step))
optimizer.zero_grad()
loss.backward()
optimizer.step()
fc_grad = model_parallel.gather(fc_params[0].grad)
np.testing.assert_allclose(
torch.norm(fc_grad).item(),
baseline['logits/grad/norm'],
rtol=1e-5,
err_msg='Wrong logits grad at gpu {} step {}'.format(
gpu, step))
finally:
torch.distributed.destroy_process_group()
def baseline_main(gpus_per_worker, num_steps=1, batch_size=1, num_classes=8):
r"""run on 1 GPU"""
emb_size = 3
learning_rate = 0.1
momentum = 0.9
image_size = 6
zeros_init = sailfish.ZerosInitializer()
    # single-process baseline
torch.manual_seed(42)
random.seed(42)
fe = torch.nn.Linear(image_size, emb_size).cuda()
fc = sailfish.ArcFaceLinear(
emb_size, num_classes, weight_initializer=zeros_init).cuda()
fc_params = list(fc.parameters())
criterion = torch.nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(
fc.parameters(), lr=learning_rate, momentum=momentum)
fc.train()
criterion.train()
results = []
for step in range(num_steps):
result_item = {}
features_list = []
label_list = []
for gpu in range(gpus_per_worker):
torch.manual_seed(42 * step + gpu)
random.seed(42 * step + gpu)
features_list.append(
fe(torch.randn([batch_size, image_size]).cuda()))
label_list.append(
torch.as_tensor([
random.randint(0, num_classes - 1)
for _ in range(batch_size)
]).cuda())
all_features = torch.cat(features_list)
all_label = torch.cat(label_list)
torch.manual_seed(42 * step)
random.seed(42 * step)
result_item['features/size'] = list(all_features.size())
result_item['features/norm'] = torch.norm(all_features).item()
logits = fc(all_features, all_label)
loss = criterion(logits, all_label)
result_item['loss'] = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
result_item['logits/grad/norm'] = torch.norm(fc_params[0].grad).item()
results.append(result_item)
return results
class TestArcFaceLinear(unittest.TestCase):
r"""Test sailfish.ArcFaceLinear."""
def test_no_parallel(self):
r"""Test sailfish.ArcFaceLinear without parallel."""
in_features = 1
out_features = 2
margin_m = random.random()
margin_s = random.random()
features = torch.randn([1, in_features])
label = torch.as_tensor([random.randint(0, out_features - 1)])
torch.manual_seed(42)
random.seed(42)
baseline = sailfish.ArcFaceLinear(
in_features, out_features, margin=margin_m, scale=margin_s)
baseline_optimizer = torch.optim.SGD(baseline.parameters(), lr=1.)
baseline.train()
baseline_logits = baseline(features, label)
baseline_loss = torch.sum(baseline_logits)
baseline_optimizer.zero_grad()
baseline_loss.backward()
baseline_optimizer.step()
torch.manual_seed(42)
random.seed(42)
fc = sailfish.ArcFaceLinear(
in_features, out_features, margin=margin_m, scale=margin_s)
optimizer = torch.optim.SGD(fc.parameters(), lr=1.)
fc.train()
logits = fc(features, label)
loss = torch.sum(logits)
optimizer.zero_grad()
loss.backward()
optimizer.step()
np.testing.assert_allclose(
logits.detach().numpy(),
baseline_logits.detach().numpy(),
err_msg='logits not equal to baseline')
np.testing.assert_allclose(
[p.detach().numpy() for p in baseline.parameters()],
[p.detach().numpy() for p in fc.parameters()],
err_msg='parameters not equal to baseline')
def test_mp(self):
r"""Test sailfish.ArcFaceLinear on 1 GPU."""
in_features = 1
out_features = 2
features = torch.randn([1, in_features])
label = torch.as_tensor([random.randint(0, out_features - 1)])
torch.manual_seed(42)
random.seed(42)
baseline = sailfish.ArcFaceLinear(in_features, out_features)
baseline_optimizer = torch.optim.SGD(baseline.parameters(), lr=1.)
baseline.train()
baseline_logits = baseline(features, label)
baseline_loss = torch.sum(baseline_logits)
baseline_optimizer.zero_grad()
baseline_loss.backward()
baseline_optimizer.step()
torch.manual_seed(42)
random.seed(42)
model_parallel = sailfish.ModelParallel(0, 1)
fc = sailfish.ArcFaceLinear(
in_features, out_features, parallel=model_parallel)
optimizer = torch.optim.SGD(fc.parameters(), lr=1.)
fc.train()
logits = fc(features, label)
loss = torch.sum(logits)
optimizer.zero_grad()
loss.backward()
optimizer.step()
np.testing.assert_allclose(
logits.detach().numpy(),
baseline_logits.detach().numpy(),
err_msg='logits not equal to baseline')
np.testing.assert_allclose(
[p.detach().numpy() for p in baseline.parameters()],
[p.detach().numpy() for p in fc.parameters()],
err_msg='parameters not equal to baseline')
def cant_test_mp_vs_ddp(self):
r"""Test sailfish.ArcFaceLinear with model parallel."""
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '24601'
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'
gpus_per_worker = torch.cuda.device_count()
torch.multiprocessing.spawn(
mp_vs_ddp_main,
args=(gpus_per_worker, ),
nprocs=gpus_per_worker,
join=True)
def test_mp_vs_1gpu(self):
r"""Test sailfish.ArcFaceLinear with model parallel."""
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '24601'
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'
gpus_per_worker = torch.cuda.device_count()
num_steps = 5
batch_size = 1
num_classes = gpus_per_worker
results = baseline_main(gpus_per_worker, num_steps, batch_size,
num_classes)
torch.multiprocessing.spawn(
mp_main,
args=(gpus_per_worker, results, num_steps, batch_size,
num_classes),
nprocs=gpus_per_worker,
join=True)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,314 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linear module tests."""
import math
import os
import random
import unittest
import numpy as np
import torch
from easycv.core import sailfish
class MockLinear(torch.nn.Module):
r"""Applies a linear transformation to the incoming data.
"""
def __init__(self,
in_features,
out_features,
bias=True,
weight_initializer=None,
bias_initializer=None,
parallel=None):
super(MockLinear, self).__init__()
self.out_features = out_features
self.in_features = in_features
self.weight = torch.nn.Parameter(
torch.Tensor(self.out_features, self.in_features))
if bias:
self.bias = torch.nn.Parameter(torch.Tensor(self.out_features))
else:
self.register_parameter('bias', None)
self.weight_initializer = weight_initializer
if weight_initializer is None:
self.weight_initializer = sailfish.KaimingUniformInitializer(
math.sqrt(5))
self.bias_initializer = bias_initializer
if bias_initializer is None:
self.bias_initializer = sailfish.BiasUniformInitializer(
self.in_features)
self.reset_parameters()
self.parallel = parallel
def reset_parameters(self):
r"""Reset parameters."""
self.weight_initializer(self.weight)
if self.bias is not None:
self.bias_initializer(self.bias)
def forward(self, features): # pylint: disable=arguments-differ
features = features.type(dtype=self.weight.dtype)
return torch.nn.functional.linear(features, self.weight, self.bias)
def _run_baseline_train_main(gpus_per_worker, num_steps, batch_size,
in_features, out_features, bias, lr):
r"""Run baseline on 1 GPU."""
torch.manual_seed(42)
random.seed(42)
fc = MockLinear(
in_features,
out_features,
bias=bias,
weight_initializer=sailfish.ZerosInitializer(),
bias_initializer=sailfish.OnesInitializer()).cuda()
criterion = torch.nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(fc.parameters(), lr=lr)
fc.train()
criterion.train()
results = []
for step in range(num_steps):
result = {}
features_list = []
label_list = []
for gpu in range(gpus_per_worker):
torch.manual_seed(42 * step + gpu)
random.seed(42 * step + gpu)
features_list.append(torch.randn([batch_size, in_features]).cuda())
label_list.append(
torch.as_tensor([
random.randint(0, out_features - 1)
for _ in range(batch_size)
]).cuda())
features = torch.cat(features_list)
label = torch.cat(label_list)
torch.manual_seed(42 * step)
random.seed(42 * step)
logits = fc(features)
loss = criterion(logits, label)
result['loss'] = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
result['grads/norm'] = [
torch.norm(p.grad).item() for p in fc.parameters()
]
results.append(result)
return results
def _run_mp_train_main(gpu, gpus_per_worker, baseline_steps, num_steps,
batch_size, in_features, out_features, bias, lr):
r"""Run MP and validate results."""
torch.cuda.set_device(gpu)
torch.distributed.init_process_group(
'nccl', rank=gpu, world_size=gpus_per_worker)
torch.manual_seed(42)
random.seed(42)
model_parallel = sailfish.ModelParallel(gpu, gpus_per_worker)
fc = sailfish.Linear(
in_features,
out_features,
bias=bias,
weight_initializer=sailfish.ZerosInitializer(),
bias_initializer=sailfish.OnesInitializer(),
parallel=model_parallel).cuda()
criterion = sailfish.CrossEntropyLoss(parallel=model_parallel).cuda()
optimizer = torch.optim.SGD(fc.parameters(), lr=lr)
fc.train()
criterion.train()
for step in range(num_steps):
torch.manual_seed(42 * step + gpu)
random.seed(42 * step + gpu)
features = torch.randn([batch_size, in_features]).cuda()
features = model_parallel.gather(features)
label = torch.as_tensor([
random.randint(0, out_features - 1) for _ in range(batch_size)
]).cuda()
label = model_parallel.gather_target(label)
torch.manual_seed(42 * step)
random.seed(42 * step)
logits = fc(features)
loss = criterion(logits, label)
np.testing.assert_allclose(
loss.item(),
baseline_steps[step]['loss'],
rtol=1e-5,
err_msg='Wrong loss at gpu {} step {}'.format(gpu, step))
optimizer.zero_grad()
loss.backward()
optimizer.step()
grad_norms = [
torch.norm(model_parallel.gather(p.grad)).item()
for p in fc.parameters()
]
np.testing.assert_allclose(
grad_norms,
baseline_steps[step]['grads/norm'],
rtol=1e-5,
err_msg='Wrong grads norm at gpu {} step {}'.format(gpu, step))
class TestLinear(unittest.TestCase):
r"""Test sailfish.Linear."""
def _run_baseline_train(self, batch_size, in_features, out_features, bias,
lr):
r"""Run baseline without parallel."""
result = {}
features = torch.randn([batch_size, in_features])
fc = torch.nn.Linear(in_features, out_features, bias=bias)
optimizer = torch.optim.SGD(fc.parameters(), lr=lr)
fc.train()
logits = fc(features)
loss = torch.sum(logits)
result['loss'] = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
result['grads/norm'] = [
torch.norm(p.grad).item() for p in fc.parameters()
]
return result
def _run_mp_no_parallel_train(self, batch_size, in_features, out_features,
bias, lr):
r"""Run MP without parallel."""
result = {}
features = torch.randn([batch_size, in_features])
fc = sailfish.Linear(in_features, out_features, bias=bias)
optimizer = torch.optim.SGD(fc.parameters(), lr=lr)
fc.train()
logits = fc(features)
loss = torch.sum(logits)
result['loss'] = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
result['grads/norm'] = [
torch.norm(p.grad).item() for p in fc.parameters()
]
return result
def _run_mp_1gpu_train(self, batch_size, in_features, out_features, bias,
lr):
r"""Run MP on 1 GPU."""
result = {}
features = torch.randn([batch_size, in_features])
model_parallel = sailfish.ModelParallel(0, 1)
fc = sailfish.Linear(
in_features, out_features, bias=bias, parallel=model_parallel)
optimizer = torch.optim.SGD(fc.parameters(), lr=lr)
fc.train()
logits = fc(features)
loss = torch.sum(logits)
result['loss'] = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
result['grads/norm'] = [
torch.norm(p.grad).item() for p in fc.parameters()
]
return result
def test_no_parallel(self):
r"""Test sailfish.Linear without parallel."""
batch_size = 3
in_features = 4
out_features = 5
bias = False
lr = 0.1
for step in range(5):
torch.manual_seed(42 + step)
random.seed(42 + step)
baseline = self._run_baseline_train(batch_size, in_features,
out_features, bias, lr)
torch.manual_seed(42 + step)
random.seed(42 + step)
rc = self._run_mp_no_parallel_train(batch_size, in_features,
out_features, bias, lr)
np.testing.assert_allclose(
rc['loss'], baseline['loss'], err_msg='loss not equal')
np.testing.assert_allclose(
rc['grads/norm'],
baseline['grads/norm'],
err_msg='norm of grads not equal')
def test_mp(self):
r"""Test sailfish.Linear with model parallel on 1 GPU."""
batch_size = 2
in_features = 7
out_features = 4
bias = True
lr = 0.6
for step in range(5):
torch.manual_seed(100 + step)
random.seed(100 + step)
baseline = self._run_baseline_train(batch_size, in_features,
out_features, bias, lr)
torch.manual_seed(100 + step)
random.seed(100 + step)
rc = self._run_mp_1gpu_train(batch_size, in_features, out_features,
bias, lr)
np.testing.assert_allclose(
rc['loss'], baseline['loss'], err_msg='loss not equal')
np.testing.assert_allclose(
rc['grads/norm'],
baseline['grads/norm'],
err_msg='norm of grads not equal')
def test_mp_vs_1gpu(self):
r"""Test sailfish.ArcFaceLinear with model parallel."""
gpus_per_worker = torch.cuda.device_count()
num_steps = 5
batch_size = 2
in_features = 3
out_features = gpus_per_worker
bias = True
lr = 0.6
baseline_steps = _run_baseline_train_main(gpus_per_worker, num_steps,
batch_size, in_features,
out_features, bias, lr)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '24601'
os.environ['WORLD_SIZE'] = '1'
os.environ['RANK'] = '0'
torch.multiprocessing.spawn(
_run_mp_train_main,
args=(gpus_per_worker, baseline_steps, num_steps, batch_size,
in_features, out_features, bias, lr),
nprocs=gpus_per_worker,
join=True)
if __name__ == '__main__':
unittest.main()