From 813a191876ced801330b10c208e331e77585fc0b Mon Sep 17 00:00:00 2001 From: "yuanman.ym" Date: Thu, 16 Jun 2022 19:24:37 +0800 Subject: [PATCH] Add sailfish for fully sharded data parallel training --- easycv/core/sailfish/__init__.py | 36 ++ easycv/core/sailfish/activation.py | 52 +++ easycv/core/sailfish/function.py | 269 ++++++++++++++ easycv/core/sailfish/linear.py | 151 ++++++++ easycv/core/sailfish/loss.py | 206 +++++++++++ easycv/core/sailfish/util.py | 172 +++++++++ tests/core/sailfish/__init__.py | 0 tests/core/sailfish/test_arcface.py | 528 ++++++++++++++++++++++++++++ tests/core/sailfish/test_linear.py | 314 +++++++++++++++++ 9 files changed, 1728 insertions(+) create mode 100644 easycv/core/sailfish/__init__.py create mode 100644 easycv/core/sailfish/activation.py create mode 100644 easycv/core/sailfish/function.py create mode 100644 easycv/core/sailfish/linear.py create mode 100644 easycv/core/sailfish/loss.py create mode 100644 easycv/core/sailfish/util.py create mode 100644 tests/core/sailfish/__init__.py create mode 100644 tests/core/sailfish/test_arcface.py create mode 100644 tests/core/sailfish/test_linear.py diff --git a/easycv/core/sailfish/__init__.py b/easycv/core/sailfish/__init__.py new file mode 100644 index 00000000..4493087e --- /dev/null +++ b/easycv/core/sailfish/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2019 Alibaba Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""PyTorch extension for distributed training."""
+
+from __future__ import absolute_import, division, print_function
+
+from easycv.core.sailfish.activation import LogSoftmax  # noqa: F401
+from easycv.core.sailfish.linear import ArcFaceLinear  # noqa: F401
+from easycv.core.sailfish.linear import Linear  # noqa: F401
+from easycv.core.sailfish.loss import ArcMarginLoss  # noqa: F401
+from easycv.core.sailfish.loss import CrossEntropyLoss  # noqa: F401
+from easycv.core.sailfish.loss import FocalLoss  # noqa: F401
+from easycv.core.sailfish.loss import NLLLoss  # noqa: F401
+from easycv.core.sailfish.loss import SoftmaxLoss  # noqa: F401
+from easycv.core.sailfish.util import BiasUniformInitializer  # noqa: F401
+from easycv.core.sailfish.util import DistributedParallel  # noqa: F401
+from easycv.core.sailfish.util import KaimingUniformInitializer  # noqa: F401
+from easycv.core.sailfish.util import ModelParallel  # noqa: F401
+from easycv.core.sailfish.util import NormalInitializer  # noqa: F401
+from easycv.core.sailfish.util import OnesInitializer  # noqa: F401
+from easycv.core.sailfish.util import ParameterInitializer  # noqa: F401
+from easycv.core.sailfish.util import RenormUniformInitializer  # noqa: F401
+from easycv.core.sailfish.util import XavierUniformInitializer  # noqa: F401
+from easycv.core.sailfish.util import ZerosInitializer  # noqa: F401
diff --git a/easycv/core/sailfish/activation.py b/easycv/core/sailfish/activation.py
new file mode 100644
index 00000000..7275f407
--- /dev/null
+++ b/easycv/core/sailfish/activation.py
@@ -0,0 +1,52 @@
+# Copyright 2019 Alibaba Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Activation modules."""
+
+from __future__ import absolute_import, division, print_function
+
+import torch
+
+from easycv.core.sailfish.util import ModelParallel
+
+
+class LogSoftmax(torch.nn.Module):
+    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
+    input Tensor, rescaling it so that the elements of the
+    n-dimensional output Tensor lie in the range (-inf, 0].
+
+    Shape:
+        - Input: :math:`(*)` where `*` means any number of additional
+          dimensions
+        - Output: :math:`(*)`, same shape as the input
+
+    Returns:
+        a Tensor of the same dimension and shape as the input with
+        values in the range (-inf, 0].
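+
+    When ``parallel`` is a :class:`ModelParallel` instance, the input is
+    taken to be this rank's shard of the logits along the class dimension;
+    the log-softmax is then computed globally by all-reducing the per-row
+    maximum and the sum of exponentials, with ``epsilon`` added inside the
+    logarithm for numerical stability.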
+ + Examples:: + >>> m = LogSoftmax() + >>> input = torch.randn(2, 3) + >>> output = m(input) + """ + + def __init__(self, epsilon=0, parallel=None): + super(LogSoftmax, self).__init__() + self.epsilon = epsilon + self.parallel = parallel + + def forward(self, logits): # pylint: disable=arguments-differ + if isinstance(self.parallel, ModelParallel): + return self.parallel.log_softmax(logits, epsilon=self.epsilon) + return torch.nn.functional.log_softmax(logits, _stacklevel=5) diff --git a/easycv/core/sailfish/function.py b/easycv/core/sailfish/function.py new file mode 100644 index 00000000..1c992ecd --- /dev/null +++ b/easycv/core/sailfish/function.py @@ -0,0 +1,269 @@ +# Copyright 2019 Alibaba Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functions.""" + +from __future__ import absolute_import, division, print_function + +import torch + + +class _Cat(torch.autograd.Function): + """Concat inputs.""" + + @staticmethod + def forward(ctx, inputs, dim, rank, world_size): # noqa: E501 # pylint: disable=arguments-differ + r"""Cat is defined as: + .. math:: + \text{all_cat}(x_i) = \bigoplus_j x_j + """ + ctx.dim = dim + ctx.rank = rank + ctx.world_size = world_size + all_inputs = [ + torch.zeros( + inputs.size(), dtype=inputs.dtype, device=inputs.device) + for _ in range(world_size) + ] + torch.distributed.all_gather(all_inputs, inputs) + output = torch.cat(all_inputs, dim=dim) + output.requires_grad_() + return output + + @staticmethod + def backward(ctx, grad_output): # pylint: disable=arguments-differ + r"""Gradient of Cat is defined as: + .. math:: + \nabla \text{all_cat}(x_i) = \text{split}(\nabla x_i) + """ + grad_input = grad_output.clone() + torch.distributed.all_reduce(grad_input) + grad_input_dim_size = grad_input.size()[ctx.dim] + assert grad_input_dim_size % ctx.world_size == 0 + split_size = grad_input_dim_size // ctx.world_size + grad_input_splits = torch.split(grad_input, split_size, dim=ctx.dim) + return grad_input_splits[ctx.rank], None, None, None + + +def all_cat(inputs, dim=0, rank=0, world_size=1): + return _Cat.apply(inputs, dim, rank, world_size) + + +class _Sum(torch.autograd.Function): + """Sum inputs.""" + + @staticmethod + def forward(_, inputs): # pylint: disable=arguments-differ + r"""Sum is defined as: + .. math:: + \text{all_sum}(x_i) = \sum_j x_j + """ + inputs_sum = inputs.clone() + torch.distributed.all_reduce(inputs_sum) + inputs_sum.requires_grad_() + return inputs_sum + + @staticmethod + def backward(_, grad_output): # pylint: disable=arguments-differ + r"""Gradient of Sum is defined as: + .. 
math:: + \nabla \text{all_sum}(x_i) = \sum_j\nabla x_j + """ + grad_input = grad_output.clone() + torch.distributed.all_reduce(grad_input) + return grad_input + + +def all_sum(inputs): + return _Sum.apply(inputs) + + +class _LogSoftmax(torch.autograd.Function): + """Compute log softmax of logits.""" + + @staticmethod + def forward(ctx, logits, epsilon): # pylint: disable=arguments-differ + r"""LogSoftmax is defined as: + .. math:: + \log(\text{softmax}(x_i)) + = \log\left(\frac{\text{e}^{x_i}}{\sum_j\text{e}^{x_j}}\right) + = x_i - \log\sum_j\text{e}^{x_j} + + For numerical stability, it subtracts the maximum value for every logits: + .. math:: + \log(\text{softmax}(x_i)) + = \hat{x_i} - \log\sum_j\text{e}^{\hat{x_j}}, + \hat{x} = x - \max_j{x_j} + """ + ctx.logits_dtype = logits.dtype + logits_max = torch.max(logits, dim=1).values + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX) + logits = logits - logits_max.view(-1, 1) + logits_exp = torch.exp(logits) + logits_exp_sum = torch.sum(logits_exp, dim=1) + torch.distributed.all_reduce(logits_exp_sum) + logits_exp_sum_log = torch.log(logits_exp_sum + epsilon) + prob_log = logits - logits_exp_sum_log.view(-1, 1) + ctx.save_for_backward(prob_log) + prob_log.requires_grad_() + return prob_log + + @staticmethod + def backward(ctx, grad_output): # pylint: disable=arguments-differ + r"""Gradient of LogSoftmax is defined as: + .. math:: + \nabla\log(\text{softmax}(x_i)) + = \nabla x_i - \text{softmax}(x_i) \sum_j \nabla x_j + """ + grad_output_sum = torch.sum(grad_output, dim=1) + torch.distributed.all_reduce(grad_output_sum) + prob_log, = ctx.saved_tensors + grad_input = torch.exp(prob_log) * grad_output_sum.view(-1, 1) + grad_input = grad_output - grad_input + grad_input = grad_input.type(dtype=ctx.logits_dtype) + return grad_input, None + + +def all_log_softmax(logits, epsilon=1e-8): + return _LogSoftmax.apply(logits, epsilon) + + +class _NLLLoss(torch.autograd.Function): + """calculate NLLLoss from mask.""" + + @staticmethod + def forward(ctx, inputs, correct_mask): # pylint: disable=arguments-differ + ctx.inputs_size = inputs.size() + ctx.save_for_backward(correct_mask) + loss = torch.sum(inputs * correct_mask) / -ctx.inputs_size[0] + torch.distributed.all_reduce(loss) + loss.requires_grad_() + return loss + + @staticmethod + def backward(ctx, grad_output): # pylint: disable=arguments-differ + correct_mask, = ctx.saved_tensors + grad_input = grad_output.repeat(*ctx.inputs_size) + grad_input = grad_input * correct_mask / -ctx.inputs_size[0] + return grad_input, None + + +def all_nll_loss(inputs, correct_mask): + return _NLLLoss.apply(inputs, correct_mask) + + +def shard_target_and_mask(target, output_features, rank=0): + target_shard_begin = output_features * rank + target_shard_end = output_features * (rank + 1) + target_shard_lmask = torch.ge(target, target_shard_begin) + target_shard_rmask = torch.lt(target, target_shard_end) + target_mask = (target_shard_lmask * target_shard_rmask).to(target.device) + target_shard = (target - target_shard_begin) * target_mask.long() + return target_shard, target_mask + + +def shard_correct_mask(target, inputs, rank=0): + """Get correct mask of inputs.""" + inputs_size = inputs.size() + target_shard_begin = inputs_size[1] * rank + target_shard_end = inputs_size[1] * (rank + 1) + target_shard_lmask = torch.ge(target, target_shard_begin) + target_shard_rmask = torch.lt(target, target_shard_end) + target_mask = (target_shard_lmask * target_shard_rmask).to(target.device) + 
target_shard = (target - target_shard_begin) * target_mask.long() + mask = torch.zeros(inputs_size, device=inputs.device, dtype=inputs.dtype) + mask.scatter_(1, target_shard.view(-1, 1).long(), 1) + mask.masked_fill_((~target_mask).view(-1, 1).expand(*inputs.size()), 0) + return mask + + +def shard_correct_predictions(target, logits, world_size=1): + r"""Calculate correct predictions for logits.""" + shard_max_logits, shard_expected_class = torch.max(logits, dim=1) + all_max_logits = [ + torch.zeros( + shard_max_logits.size(), + dtype=shard_max_logits.dtype, + device=shard_max_logits.device) for _ in range(world_size) + ] + torch.distributed.all_gather(all_max_logits, shard_max_logits) + all_max_logits = torch.cat([t.view(-1, 1) for t in all_max_logits], dim=1) + rank_pred = torch.max(all_max_logits, dim=1)[1].view(-1, 1) + all_shard_pred = [ + torch.zeros( + shard_expected_class.size(), + dtype=shard_expected_class.dtype, + device=shard_expected_class.device) for _ in range(world_size) + ] + torch.distributed.all_gather(all_shard_pred, shard_expected_class) + all_shard_pred = torch.cat([t.view(-1, 1) for t in all_shard_pred], dim=1) + all_shard_pred_mask = torch.zeros( + all_shard_pred.size(), + device=all_shard_pred.device, + dtype=all_shard_pred.dtype) + all_shard_pred_mask.scatter_(1, rank_pred.long(), 1) + shard_pred = torch.sum( + all_shard_pred * all_shard_pred_mask, dim=1).view(-1, 1) + pred = shard_pred + rank_pred * logits.size()[1] + return (pred == target.data.view_as(pred)).sum().item() + + +def shard_topk_correct_predictions(target, logits, k, world_size=1): + r"""Calculate correct predictions for logits.""" + # Step 1: Compute top-k of shard logits. + logits_topk, logits_topk_idx = torch.topk(logits, k, dim=1) + all_logits_topk = [ + torch.zeros( + logits_topk.size(), + dtype=logits_topk.dtype, + device=logits_topk.device) for _ in range(world_size) + ] + torch.distributed.all_gather(all_logits_topk, logits_topk) + all_logits_topk = torch.cat([t.view(-1, k) for t in all_logits_topk], + dim=1) + + all_logits_topk_idx = [ + torch.zeros( + logits_topk_idx.size(), + dtype=logits_topk_idx.dtype, + device=logits_topk_idx.device) for _ in range(world_size) + ] + torch.distributed.all_gather(all_logits_topk_idx, logits_topk_idx) + all_logits_topk_idx = torch.cat( + [t.view(-1, k) for t in all_logits_topk_idx], dim=1) + + # Step 2: Compute global top-k indices. + _, all_logits_topk_topk_idx = torch.topk(all_logits_topk, k, dim=1) + all_logits_topk_topk_idx = all_logits_topk_topk_idx.view(-1, k) + all_logits_topk_mask = torch.zeros( + all_logits_topk_idx.size(), + device=all_logits_topk_idx.device, + dtype=all_logits_topk_idx.dtype) + all_logits_topk_mask.scatter_(1, all_logits_topk_topk_idx.long(), 1) + batch_size, shard_num_classes = logits.size() + all_logits_topk_base = torch.cat([ + torch.ones([batch_size, k], + device=all_logits_topk_idx.device, + dtype=all_logits_topk_idx.dtype) * p * shard_num_classes + for p in range(world_size) + ], + dim=1) + all_logits_topk_gidx = all_logits_topk_base + all_logits_topk_idx + + # Step 3: Compute predictions and check. + pred = torch.masked_select(all_logits_topk_gidx, + all_logits_topk_mask.type(torch.bool)).view( + batch_size, k) + return (pred == target.view(-1, 1)).sum().item() diff --git a/easycv/core/sailfish/linear.py b/easycv/core/sailfish/linear.py new file mode 100644 index 00000000..6386dab6 --- /dev/null +++ b/easycv/core/sailfish/linear.py @@ -0,0 +1,151 @@ +# Copyright 2019 Alibaba Inc. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Linear modules."""
+
+from __future__ import absolute_import, division, print_function
+import math
+
+import torch
+
+from easycv.core.sailfish.util import (BiasUniformInitializer,
+                                       KaimingUniformInitializer,
+                                       ModelParallel, RenormUniformInitializer)
+
+
+class Linear(torch.nn.Module):
+    r"""Applies a linear transformation to the incoming data.
+    """
+
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 weight_initializer=None,
+                 bias_initializer=None,
+                 parallel=None):
+        super(Linear, self).__init__()
+        if isinstance(parallel, ModelParallel):
+            if out_features % parallel.world_size != 0:
+                raise ValueError(
+                    'out_features must be divisible by parallel.world_size')
+            self.out_features = out_features // parallel.world_size
+        else:
+            self.out_features = out_features
+        self.in_features = in_features
+        self.weight = torch.nn.Parameter(
+            torch.Tensor(self.out_features, self.in_features))
+        if bias:
+            self.bias = torch.nn.Parameter(torch.Tensor(self.out_features))
+        else:
+            self.register_parameter('bias', None)
+        self.weight_initializer = weight_initializer
+        if weight_initializer is None:
+            self.weight_initializer = KaimingUniformInitializer(math.sqrt(5))
+        self.bias_initializer = bias_initializer
+        if bias_initializer is None:
+            self.bias_initializer = BiasUniformInitializer(self.in_features)
+        self.reset_parameters()
+        self.parallel = parallel
+
+    def reset_parameters(self):
+        r"""Reset parameters."""
+        self.weight_initializer(self.weight)
+        if self.bias is not None:
+            self.bias_initializer(self.bias)
+
+    def forward(self, features):  # pylint: disable=arguments-differ
+        features = features.type(dtype=self.weight.dtype)
+        return torch.nn.functional.linear(features, self.weight, self.bias)
+
+
+class ArcFaceLinear(torch.nn.Module):
+    r"""Applies an ArcFace transformation to the incoming data.
+    See https://arxiv.org/abs/1801.05599 .
+    """
+
+    def __init__(
+            self,
+            in_features,
+            out_features,
+            margin=0.5,
+            scale=64.0,  # See normface https://arxiv.org/abs/1704.06369
+            fast_phi=False,
+            epsilon=0,
+            weight_initializer=None,
+            l2_norm=False,
+            parallel=None):
+        super(ArcFaceLinear, self).__init__()
+        if isinstance(parallel, ModelParallel):
+            if out_features % parallel.world_size != 0:
+                raise ValueError(
+                    'out_features must be divisible by parallel.world_size')
+            self.out_features = out_features // parallel.world_size
+        else:
+            self.out_features = out_features
+        self.in_features = in_features
+        self.margin = margin  # Angular margin penalty.
+        self.scale = scale  # Radius of hypersphere.
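+        # fast_phi switches to the simpler fallback (phi only where
+        # cos(theta) > 0, often called the "easy margin" variant), while
+        # epsilon pads 1 - cos^2(theta) inside the sqrt below so that
+        # sin(theta) never sees a negative argument.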
+ self.fast_phi = fast_phi + self.epsilon = epsilon + self.weight_initializer = weight_initializer + if weight_initializer is None: + self.weight_initializer = RenormUniformInitializer() + self.l2_norm = l2_norm + self.parallel = parallel + + self.weight = torch.nn.Parameter( + torch.Tensor(self.out_features, self.in_features)) + self.reset_parameters() + + self._cos_margin = math.cos(margin) + self._sin_margin = math.sin(margin) + self._threshold = math.cos(math.pi - margin) + self._min = math.sin(math.pi - margin) * self.margin + + def reset_parameters(self): + r"""Reset parameters.""" + self.weight_initializer(self.weight) + + def forward(self, features, target): # pylint: disable=arguments-differ + r"""Compute ::math`\phi = \cos(\theta + margin)` and logits.""" + # (N, E) x (E, C) -> (N, C) + features = features.type(dtype=self.weight.dtype) + if self.l2_norm: + features_norm = torch.norm(features, 2, 1, True) + features = torch.div(features, features_norm) + weight_norm = torch.norm(self.weight, 2, 0, True) + weight = torch.div(self.weight, weight_norm) + else: + features = torch.nn.functional.normalize(features) + weight = torch.nn.functional.normalize(self.weight) + cosine = torch.nn.functional.linear(features, weight) + cosine = cosine.clamp(-1, 1) # for numerical stability + sine = torch.sqrt(1. + self.epsilon - cosine * cosine) + phi = cosine * self._cos_margin - sine * self._sin_margin + phi = phi.type(dtype=cosine.dtype) + if self.fast_phi: + phi = torch.where(cosine > 0, phi, cosine) + else: + phi = torch.where(cosine > self._threshold, phi, + cosine - self._min) + if isinstance(self.parallel, ModelParallel): + mask = self.parallel.correct_mask(target, cosine) + else: + mask = torch.zeros( + cosine.size(), device=cosine.device, dtype=cosine.dtype) + mask.scatter_(1, target.view(-1, 1).long(), 1) + logits = (mask * phi) + ((1.0 - mask) * cosine) + logits *= self.scale + return logits diff --git a/easycv/core/sailfish/loss.py b/easycv/core/sailfish/loss.py new file mode 100644 index 00000000..f1dd2cb4 --- /dev/null +++ b/easycv/core/sailfish/loss.py @@ -0,0 +1,206 @@ +# Copyright 2019 Alibaba Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Loss modules.""" + +from __future__ import absolute_import, division, print_function + +import torch + +from easycv.core.sailfish.activation import LogSoftmax +from easycv.core.sailfish.linear import ArcFaceLinear, Linear +from easycv.core.sailfish.util import ModelParallel + + +class NLLLoss(torch.nn.Module): + r"""The negative log likelihood loss for log probabilities. It is + useful to train a classification problem with `C` classes. + + The `input` given through a forward call is expected to contain + log-probabilities of each class. `input` has to be a Tensor of size either + :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` + with :math:`K \geq 1` for the `K`-dimensional case (described later). 
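+
+    When constructed with a :class:`ModelParallel` instance, ``logprob`` is
+    expected to be this rank's shard of the log-probabilities (as produced
+    by the distributed :class:`LogSoftmax`), and the resulting loss is
+    accumulated across all ranks.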
+
+    Obtaining log-probabilities in a neural network is easily achieved by
+    adding a `LogSoftmax` layer in the last layer of your network.
+    You may use `CrossEntropyLoss` instead, if you prefer not to add an
+    extra layer.
+
+    The `target` that this loss expects should be a class index in the range
+    :math:`[0, C-1]` where `C = number\_classes`.
+
+    NLLLoss is defined as:
+    .. math::
+        \ell(x, y) = -\frac{1}{N}\sum_{n=1}^N L_{n}
+
+    Args:
+        focal: whether to apply the focal loss modulation to the result.
+        focal_gamma: the focusing parameter of the focal modulation.
+        parallel: optional parallelism handle; if it is a `ModelParallel`
+            instance, the loss is computed over class shards.
+
+    Shape:
+        - Input: :math:`(\frac{N}{P}, C)` where `C = number of classes`, or
+          :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
+          in the case of `K`-dimensional loss.
+        - Target: :math:`(\frac{N}{P}, 1)` where each value is
+          :math:`0 \leq \text{targets}[i] \leq C-1`, or
+          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
+          K-dimensional loss.
+        - Output: scalar.
+          If :attr:`reduction` is ``'none'``, then the same size as the target:
+          :math:`(N)`, or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
+          in the case of K-dimensional loss.
+    Examples::
+        >>> m = LogSoftmax(...)
+        >>> loss = NLLLoss(...)
+        >>> # input is of size N x C = 3 x 5
+        >>> input = torch.randn(3, 5, requires_grad=True)
+        >>> # each element in target has to have 0 <= value < C
+        >>> target = torch.tensor([1, 0, 4])
+        >>> output = loss(m(input), target)
+        >>> output.backward()
+    """
+
+    def __init__(self, focal=False, focal_gamma=0, parallel=None):
+        super(NLLLoss, self).__init__()
+        self.focal = focal
+        self.focal_gamma = focal_gamma
+        self.parallel = parallel
+
+    def forward(self, logprob, target):  # pylint: disable=arguments-differ
+        """Compute negative log likelihood loss from log-probs and target."""
+        if isinstance(self.parallel, ModelParallel):
+            with torch.no_grad():
+                mask = self.parallel.correct_mask(target, logprob)
+            loss = self.parallel.nll_loss(logprob, mask)
+            if self.focal:
+                loss_exp = torch.exp(-loss)
+                loss = (1 - loss_exp)**self.focal_gamma * loss
+            return loss
+        loss = torch.nn.functional.nll_loss(logprob, target)
+        if self.focal:
+            loss_exp = torch.exp(-loss)
+            loss = (1 - loss_exp)**self.focal_gamma * loss
+        return loss
+
+
+class CrossEntropyLoss(torch.nn.Module):
+    r"""This criterion combines :func:`LogSoftmax` and
+    :func:`NLLLoss` in one single class.
+    """
+
+    def __init__(self, epsilon=0, parallel=None):
+        super(CrossEntropyLoss, self).__init__()
+        self._log_softmax = LogSoftmax(epsilon=epsilon, parallel=parallel)
+        self._nll_loss = NLLLoss(parallel=parallel)
+
+    def forward(self, logits, target):  # pylint: disable=arguments-differ
+        # 1. Compute log-probabilities for current shard:
+        logprob = self._log_softmax(logits)
+
+        # 2. Compute NLL loss for all shards.
+        return self._nll_loss(logprob, target)
+
+
+class FocalLoss(torch.nn.Module):
+    r"""This criterion combines :func:`LogSoftmax` and
+    a focal-modulated :func:`NLLLoss` in one single class.
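+
+    With focusing parameter :math:`\gamma`, the mean negative log-likelihood
+    :math:`L` is rescaled to :math:`(1 - e^{-L})^{\gamma} \cdot L`, a
+    focal-style rescaling (cf. https://arxiv.org/abs/1708.02002).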
+ """ + + def __init__(self, gamma=0, epsilon=1e-8, parallel=None): + super(FocalLoss, self).__init__() + self._log_softmax = LogSoftmax(epsilon=epsilon, parallel=parallel) + self._nll_loss = NLLLoss( + focal=True, focal_gamma=gamma, parallel=parallel) + + def forward(self, logits, target): # pylint: disable=arguments-differ + logprob = self._log_softmax(logits) + return self._nll_loss(logprob, target) + + +class SoftmaxLoss(torch.nn.Module): + r"""This criterion combines :func:`Linear` and + :func:`CrossEntropyLoss` in one single class. + """ + + def __init__(self, + in_features, + out_features, + bias=True, + epsilon=0, + weight_initializer=None, + bias_initializer=None, + parallel=None): + super(SoftmaxLoss, self).__init__() + self._linear = Linear( + in_features, + out_features, + bias=bias, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + parallel=parallel) + self._log_softmax = LogSoftmax(epsilon=epsilon, parallel=parallel) + self._nll_loss = NLLLoss(parallel=parallel) + self._parallel = parallel + + def forward(self, features, target): # pylint: disable=arguments-differ + if isinstance(self._parallel, ModelParallel): + features = self._parallel.gather(features) + logits = self._linear(features.squeeze()) + logprob = self._log_softmax(logits) + if isinstance(self._parallel, ModelParallel): + target = self._parallel.gather_target(target) + return self._nll_loss(logprob, target) + + +class ArcMarginLoss(torch.nn.Module): + r"""This criterion combines :func:`ArcFaceLinear` and + :func:`CrossEntropyLoss` in one single class. + """ + + def __init__( + self, + in_features, + out_features, + margin=0.5, + scale=64.0, # See normface https://arxiv.org/abs/1704.06369 + fast_phi=False, + epsilon=0, + weight_initializer=None, + l2_norm=False, + parallel=None): + super(ArcMarginLoss, self).__init__() + self._linear = ArcFaceLinear( + in_features, + out_features, + margin=margin, + scale=scale, + l2_norm=l2_norm, + fast_phi=fast_phi, + epsilon=epsilon, + weight_initializer=weight_initializer, + parallel=parallel) + self._log_softmax = LogSoftmax(epsilon=epsilon, parallel=parallel) + self._nll_loss = NLLLoss(parallel=parallel) + self._parallel = parallel + + def forward(self, features, target): # pylint: disable=arguments-differ + if isinstance(self._parallel, ModelParallel): + features = self._parallel.gather(features) + target = self._parallel.gather_target(target) + logits = self._linear(features.squeeze(), target) + logprob = self._log_softmax(logits) + return self._nll_loss(logprob, target) diff --git a/easycv/core/sailfish/util.py b/easycv/core/sailfish/util.py new file mode 100644 index 00000000..57155cbc --- /dev/null +++ b/easycv/core/sailfish/util.py @@ -0,0 +1,172 @@ +# Copyright 2019 Alibaba Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utility modules.""" + +from __future__ import absolute_import, division, print_function +import math + +import torch + +from easycv.core.sailfish.function import (all_cat, all_log_softmax, + all_nll_loss, all_sum, + shard_correct_mask, + shard_correct_predictions, + shard_target_and_mask, + shard_topk_correct_predictions) + + +class DistributedParallel: + """Base class of parallelism.""" + + def __init__(self, rank, world_size): + self._rank = rank + self._world_size = world_size + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def correct_mask(self, target, inputs): + mask = torch.zeros( + inputs.size(), device=inputs.device, dtype=inputs.dtype) + mask.scatter_(1, target.view(-1, 1).long(), 1) + return mask + + def correct_predictions(self, target, logits, k=1): + if k == 1: + pred = torch.max(logits, dim=1)[1] + return (pred == target.view(-1, 1)).sum().item() + pred = torch.topk(logits, k, dim=1)[1] + return (pred == target.view(-1, 1)).sum().item() + + def xavier_uniform_(self, weight, gain=1.): + return torch.nn.init.xavier_uniform_(weight, gain=gain) + + +class ModelParallel(DistributedParallel): + """All-to-All Model Parallelism.""" + + def gather(self, inputs, dim=0, requires_grad=True): + if requires_grad: + return all_cat( + inputs, dim=dim, rank=self.rank, world_size=self.world_size) + all_inputs = [ + torch.zeros( + inputs.size(), dtype=inputs.dtype, device=inputs.device) + for _ in range(self.world_size) + ] + torch.distributed.all_gather(all_inputs, inputs) + return torch.cat(all_inputs, dim=dim) + + def gather_target(self, target): + return self.gather(target, requires_grad=False) + + def reduce_sum(self, inputs): + return all_sum(inputs) + + def log_softmax(self, logits, epsilon=1e-8): + return all_log_softmax(logits, epsilon=epsilon) + + def nll_loss(self, inputs, correct_mask): + return all_nll_loss(inputs, correct_mask) + + def correct_mask(self, target, inputs): + return shard_correct_mask(target, inputs, rank=self.rank) + + def target_and_mask(self, target, output_features): + return shard_target_and_mask(target, output_features, rank=self.rank) + + def correct_predictions(self, target, logits, k=1): + if k == 1: + return shard_correct_predictions( + target, logits, world_size=self.world_size) + return shard_topk_correct_predictions( + target, logits, k, world_size=self.world_size) + + +class ParameterInitializer: + r"""Base class for parameter initializer.""" + + def __call__(self, param, shard_rank=0, num_shards=1): + raise NotImplementedError + + +class ZerosInitializer(ParameterInitializer): + + def __call__(self, param, parallel=None): + torch.nn.init.zeros_(param) + + +class OnesInitializer(ParameterInitializer): + + def __call__(self, param, parallel=None): + torch.nn.init.ones_(param) + + +class XavierUniformInitializer(ParameterInitializer): + + def __init__(self, gain=1.): + self.gain = gain + + def __call__(self, param, parallel=None): + if isinstance(parallel, ModelParallel): + if param.dim() != 2: + raise ValueError( + 'param with dimensions other than 2 not supported') + r = param.size(1) + param.size(0) * parallel.world_size + a = self.gain * math.sqrt(3.0) * math.sqrt(2.0 / float(r)) + torch.nn.init.uniform_(param, -a, a) + else: + torch.nn.init.xavier_uniform_(param, gain=self.gain) + + +class KaimingUniformInitializer(ParameterInitializer): + + def __init__(self, bound): + self.bound = bound + + 
def __call__(self, param, parallel=None): + torch.nn.init.kaiming_uniform_(param, a=self.bound) + + +class BiasUniformInitializer(ParameterInitializer): + + def __init__(self, weight_in_features): + self.weight_in_features = weight_in_features + + def __call__(self, param, parallel=None): + a = 1 / math.sqrt(float(self.weight_in_features)) + torch.nn.init.uniform_(param, -a, a) + + +class RenormUniformInitializer(ParameterInitializer): + + def __init__(self, maxnorm=1e-5, scale=1e5): + self.maxnorm = maxnorm + self.scale = scale + + def __call__(self, param, parallel=None): + param.data.uniform_(-1, 1).renorm_( + 2, 0, maxnorm=self.maxnorm).mul_(self.scale) + + +class NormalInitializer(ParameterInitializer): + + def __call__(self, param, parallel=None): + torch.nn.init.normal_(param) diff --git a/tests/core/sailfish/__init__.py b/tests/core/sailfish/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/sailfish/test_arcface.py b/tests/core/sailfish/test_arcface.py new file mode 100644 index 00000000..fdeb5f43 --- /dev/null +++ b/tests/core/sailfish/test_arcface.py @@ -0,0 +1,528 @@ +# Copyright 2019 Alibaba Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ArcFaceLinear module tests.""" + +import os +import random +import unittest + +import numpy as np +import torch + +from easycv.core import sailfish + + +def mp_vs_ddp_main(gpu, gpus_per_worker): + r"""Model parallel vs. 
DDP""" + torch.cuda.set_device(gpu) + torch.distributed.init_process_group( + 'nccl', rank=gpu, world_size=gpus_per_worker) + try: + num_steps = 5 + freeze_num_steps = 5 + learning_rate = 0.1 + batch_size = 2 + image_size = 5 + emb_size = 3 + num_classes = 8 + margin_m = 0.5 + margin_s = 64 + momentum = 0.9 + model_parallel = sailfish.ModelParallel(gpu, gpus_per_worker) + + zeros_init = sailfish.ZerosInitializer() + + # baseline + torch.manual_seed(42) + random.seed(42) + baseline_fe = torch.nn.Linear(image_size, emb_size).cuda() + baseline_fe = torch.nn.parallel.DistributedDataParallel( + baseline_fe, device_ids=[gpu]) + baseline_fe_params = list(baseline_fe.parameters()) + baseline_fc = sailfish.ArcFaceLinear( + emb_size, + num_classes, + margin=margin_m, + scale=margin_s, + weight_initializer=zeros_init).cuda() + baseline_fc = torch.nn.parallel.DistributedDataParallel( + baseline_fc, device_ids=[gpu]) + baseline_fc_params = list(baseline_fc.parameters()) + baseline_criterion = torch.nn.CrossEntropyLoss().cuda() + baseline_optimizer = torch.optim.SGD( + [{ + 'params': baseline_fe.parameters() + }, { + 'params': baseline_fc.parameters() + }], + lr=learning_rate, + momentum=momentum) + baseline_fe.train() + baseline_fc.train() + baseline_criterion.train() + + # hybrid parallelism + torch.manual_seed(42) + random.seed(42) + fe = torch.nn.Linear(image_size, emb_size).cuda() + fe = torch.nn.parallel.DistributedDataParallel(fe, device_ids=[gpu]) + fe_params = list(fe.parameters()) + fc = sailfish.ArcFaceLinear( + emb_size, + num_classes, + margin=margin_m, + scale=margin_s, + weight_initializer=zeros_init, + parallel=model_parallel).cuda() + fc_params = list(fc.parameters()) + criterion = sailfish.CrossEntropyLoss(parallel=model_parallel).cuda() + optimizer = torch.optim.SGD([{ + 'params': fe.parameters() + }, { + 'params': fc.parameters() + }], + lr=learning_rate, + momentum=momentum) + fe.train() + fc.train() + criterion.train() + + for step in range(num_steps): + # baseline + torch.manual_seed(42 * step + gpu) + random.seed(42 * step + gpu) + baseline_data = torch.randn([batch_size, image_size]).cuda() + baseline_label = torch.as_tensor([ + random.randint(0, num_classes - 1) for _ in range(batch_size) + ]).cuda() + baseline_features = baseline_fe(baseline_data) + baseline_logits = baseline_fc(baseline_features, baseline_label) + baseline_loss = baseline_criterion(baseline_logits, baseline_label) + baseline_loss = model_parallel.reduce_sum(baseline_loss) + baseline_loss = baseline_loss / gpus_per_worker + baseline_optimizer.zero_grad() + baseline_loss.backward() + baseline_optimizer.step() + + # hybrid parallelism + torch.manual_seed(42 * step + gpu) + random.seed(42 * step + gpu) + data = torch.randn([batch_size, image_size]).cuda() + label = torch.as_tensor([ + random.randint(0, num_classes - 1) for _ in range(batch_size) + ]).cuda() + features = fe(data) + all_features = model_parallel.gather(features) + all_label = model_parallel.gather_target(label) + shard_logits = fc(all_features, all_label) + loss = criterion(shard_logits, all_label) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # eval + torch.manual_seed(42 * step + gpu) + random.seed(42 * step + gpu) + with torch.no_grad(): + gathered_logits = model_parallel.gather(shard_logits, dim=1) + gathered_baseline_logits = model_parallel.gather( + baseline_logits, dim=0) + logits_norm_val = torch.norm(gathered_logits).item() + baseline_logits_norm_val = torch.norm( + gathered_baseline_logits).item() + 
np.testing.assert_allclose( + logits_norm_val, + baseline_logits_norm_val, + rtol=1e-5, + atol=1e-4, + err_msg='logits at gpu {} step {}'.format(gpu, step)) + + loss_val = loss.cpu().detach().numpy() + baseline_loss_val = baseline_loss.cpu().detach().numpy() + np.testing.assert_allclose( + loss_val, + baseline_loss_val, + rtol=1e-5, + atol=1e-4, + err_msg='loss at gpu {} step {}'.format(gpu, step)) + + fc_grad = model_parallel.gather(fc_params[0].grad) + baseline_fc_grad = baseline_fc_params[0].grad + np.testing.assert_allclose( + fc_grad.cpu().detach().numpy(), + baseline_fc_grad.cpu().detach().numpy(), + rtol=1e-5, + atol=1e-4, + err_msg='fc grad at gpu {} step {}'.format(gpu, step)) + + fe_weight = fe_params[0] + baseline_fe_weight = baseline_fe_params[0] + np.testing.assert_allclose( + fe_weight.cpu().detach().numpy(), + baseline_fe_weight.cpu().detach().numpy(), + rtol=1e-5, + atol=1e-4, + err_msg='fe weight at gpu {} step {}'.format(gpu, step)) + + for p in baseline_fe.parameters(): + p.requires_grad = False + for p in fe.parameters(): + p.requires_grad = False + for step in range(freeze_num_steps): + # baseline + torch.manual_seed(100 * step + gpu) + random.seed(100 * step + gpu) + baseline_data = torch.randn([batch_size, image_size]).cuda() + baseline_label = torch.as_tensor([ + random.randint(0, num_classes - 1) for _ in range(batch_size) + ]).cuda() + baseline_features = baseline_fe(baseline_data) + baseline_logits = baseline_fc(baseline_features, baseline_label) + baseline_loss = baseline_criterion(baseline_logits, baseline_label) + baseline_loss = model_parallel.reduce_sum(baseline_loss) + baseline_loss = baseline_loss / gpus_per_worker + baseline_optimizer.zero_grad() + baseline_loss.backward() + baseline_optimizer.step() + + # hybrid parallelism + torch.manual_seed(100 * step + gpu) + random.seed(100 * step + gpu) + data = torch.randn([batch_size, image_size]).cuda() + label = torch.as_tensor([ + random.randint(0, num_classes - 1) for _ in range(batch_size) + ]).cuda() + features = fe(data) + all_features = model_parallel.gather(features) + all_label = model_parallel.gather_target(label) + shard_logits = fc(all_features, all_label) + loss = criterion(shard_logits, all_label) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # eval + torch.manual_seed(100 * step + gpu) + random.seed(100 * step + gpu) + with torch.no_grad(): + gathered_logits = model_parallel.gather(shard_logits, dim=1) + gathered_baseline_logits = model_parallel.gather( + baseline_logits, dim=0) + logits_norm_val = torch.norm(gathered_logits).item() + baseline_logits_norm_val = torch.norm( + gathered_baseline_logits).item() + np.testing.assert_allclose( + logits_norm_val, + baseline_logits_norm_val, + rtol=1e-5, + atol=1e-4, + err_msg='freeze logits at gpu {} step {}'.format( + gpu, step)) + + loss_val = loss.cpu().detach().numpy() + baseline_loss_val = baseline_loss.cpu().detach().numpy() + np.testing.assert_allclose( + loss_val, + baseline_loss_val, + rtol=1e-5, + atol=1e-4, + err_msg='freeze loss at gpu {} step {}'.format(gpu, step)) + + fc_grad = model_parallel.gather(fc_params[0].grad) + baseline_fc_grad = baseline_fc_params[0].grad + np.testing.assert_allclose( + fc_grad.cpu().detach().numpy(), + baseline_fc_grad.cpu().detach().numpy(), + rtol=1e-5, + atol=1e-4, + err_msg='freeze fc grad at gpu {} step {}'.format( + gpu, step)) + + fe_weight = fe_params[0] + baseline_fe_weight = baseline_fe_params[0] + np.testing.assert_allclose( + fe_weight.cpu().detach().numpy(), + 
baseline_fe_weight.cpu().detach().numpy(), + rtol=1e-5, + atol=1e-4, + err_msg='freeze fe weight at gpu {} step {}'.format( + gpu, step)) + + finally: + torch.distributed.destroy_process_group() + + +def mp_main(gpu, + gpus_per_worker, + results, + num_steps=1, + batch_size=1, + num_classes=8): + r"""Model parallel""" + torch.cuda.set_device(gpu) + torch.distributed.init_process_group( + 'nccl', rank=gpu, world_size=gpus_per_worker) + zeros_init = sailfish.ZerosInitializer() + try: + emb_size = 3 + learning_rate = 0.1 + margin_m = 0.5 + margin_s = 64 + momentum = 0.9 + image_size = 6 + model_parallel = sailfish.ModelParallel(gpu, gpus_per_worker) + + # hybrid parallelism + torch.manual_seed(42) + random.seed(42) + fe = torch.nn.Linear(image_size, emb_size).cuda() + fc = sailfish.ArcFaceLinear( + emb_size, + num_classes, + margin=margin_m, + scale=margin_s, + weight_initializer=zeros_init, + parallel=model_parallel).cuda() + fc_params = list(fc.parameters()) + criterion = sailfish.CrossEntropyLoss(parallel=model_parallel).cuda() + optimizer = torch.optim.SGD( + fc.parameters(), lr=learning_rate, momentum=momentum) + fc.train() + criterion.train() + + for step in range(num_steps): + baseline = results[step] + torch.manual_seed(42 * step + gpu) + random.seed(42 * step + gpu) + data = torch.randn([batch_size, image_size]).cuda() + features = fe(data) + label = torch.as_tensor([ + random.randint(0, num_classes - 1) for _ in range(batch_size) + ]).cuda() + all_features = model_parallel.gather(features) + all_label = model_parallel.gather_target(label) + torch.manual_seed(42 * step) + random.seed(42 * step) + np.testing.assert_equal( + list(all_features.size()), + baseline['features/size'], + err_msg='Wrong features size at gpu {} step {}'.format( + gpu, step)) + np.testing.assert_allclose( + torch.norm(all_features).item(), + baseline['features/norm'], + rtol=1e-5, + err_msg='Wrong features norm at gpu {} step {}'.format( + gpu, step)) + shard_logits = fc(all_features, all_label) + loss = criterion(shard_logits, all_label) + np.testing.assert_allclose( + loss.item(), + baseline['loss'], + rtol=1e-5, + err_msg='Wrong loss at gpu {} step {}'.format(gpu, step)) + optimizer.zero_grad() + loss.backward() + optimizer.step() + fc_grad = model_parallel.gather(fc_params[0].grad) + np.testing.assert_allclose( + torch.norm(fc_grad).item(), + baseline['logits/grad/norm'], + rtol=1e-5, + err_msg='Wrong logits grad at gpu {} step {}'.format( + gpu, step)) + + finally: + torch.distributed.destroy_process_group() + + +def baseline_main(gpus_per_worker, num_steps=1, batch_size=1, num_classes=8): + r"""run on 1 GPU""" + emb_size = 3 + learning_rate = 0.1 + momentum = 0.9 + image_size = 6 + + zeros_init = sailfish.ZerosInitializer() + + # hybrid parallelism + torch.manual_seed(42) + random.seed(42) + fe = torch.nn.Linear(image_size, emb_size).cuda() + fc = sailfish.ArcFaceLinear( + emb_size, num_classes, weight_initializer=zeros_init).cuda() + fc_params = list(fc.parameters()) + criterion = torch.nn.CrossEntropyLoss().cuda() + optimizer = torch.optim.SGD( + fc.parameters(), lr=learning_rate, momentum=momentum) + fc.train() + criterion.train() + + results = [] + for step in range(num_steps): + result_item = {} + features_list = [] + label_list = [] + for gpu in range(gpus_per_worker): + torch.manual_seed(42 * step + gpu) + random.seed(42 * step + gpu) + features_list.append( + fe(torch.randn([batch_size, image_size]).cuda())) + label_list.append( + torch.as_tensor([ + random.randint(0, num_classes - 1) + for _ in 
range(batch_size) + ]).cuda()) + all_features = torch.cat(features_list) + all_label = torch.cat(label_list) + torch.manual_seed(42 * step) + random.seed(42 * step) + result_item['features/size'] = list(all_features.size()) + result_item['features/norm'] = torch.norm(all_features).item() + logits = fc(all_features, all_label) + loss = criterion(logits, all_label) + result_item['loss'] = loss.item() + optimizer.zero_grad() + loss.backward() + optimizer.step() + result_item['logits/grad/norm'] = torch.norm(fc_params[0].grad).item() + results.append(result_item) + return results + + +class TestArcFaceLinear(unittest.TestCase): + r"""Test sailfish.ArcFaceLinear.""" + + def test_no_parallel(self): + r"""Test sailfish.ArcFaceLinear without parallel.""" + in_features = 1 + out_features = 2 + margin_m = random.random() + margin_s = random.random() + + features = torch.randn([1, in_features]) + label = torch.as_tensor([random.randint(0, out_features - 1)]) + + torch.manual_seed(42) + random.seed(42) + baseline = sailfish.ArcFaceLinear( + in_features, out_features, margin=margin_m, scale=margin_s) + baseline_optimizer = torch.optim.SGD(baseline.parameters(), lr=1.) + baseline.train() + baseline_logits = baseline(features, label) + baseline_loss = torch.sum(baseline_logits) + baseline_optimizer.zero_grad() + baseline_loss.backward() + baseline_optimizer.step() + + torch.manual_seed(42) + random.seed(42) + fc = sailfish.ArcFaceLinear( + in_features, out_features, margin=margin_m, scale=margin_s) + optimizer = torch.optim.SGD(fc.parameters(), lr=1.) + fc.train() + logits = fc(features, label) + loss = torch.sum(logits) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + np.testing.assert_allclose( + logits.detach().numpy(), + baseline_logits.detach().numpy(), + err_msg='logits not equal to baseline') + np.testing.assert_allclose( + [p.detach().numpy() for p in baseline.parameters()], + [p.detach().numpy() for p in fc.parameters()], + err_msg='parameters not equal to baseline') + + def test_mp(self): + r"""Test sailfish.ArcFaceLinear on 1 GPU.""" + in_features = 1 + out_features = 2 + + features = torch.randn([1, in_features]) + label = torch.as_tensor([random.randint(0, out_features - 1)]) + + torch.manual_seed(42) + random.seed(42) + baseline = sailfish.ArcFaceLinear(in_features, out_features) + baseline_optimizer = torch.optim.SGD(baseline.parameters(), lr=1.) + baseline.train() + baseline_logits = baseline(features, label) + baseline_loss = torch.sum(baseline_logits) + baseline_optimizer.zero_grad() + baseline_loss.backward() + baseline_optimizer.step() + + torch.manual_seed(42) + random.seed(42) + model_parallel = sailfish.ModelParallel(0, 1) + fc = sailfish.ArcFaceLinear( + in_features, out_features, parallel=model_parallel) + optimizer = torch.optim.SGD(fc.parameters(), lr=1.) 
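+        # ModelParallel(0, 1) keeps the whole class dimension on a single
+        # shard, so this path is expected to reproduce the non-parallel
+        # baseline above.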
+ fc.train() + logits = fc(features, label) + loss = torch.sum(logits) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + np.testing.assert_allclose( + logits.detach().numpy(), + baseline_logits.detach().numpy(), + err_msg='logits not equal to baseline') + np.testing.assert_allclose( + [p.detach().numpy() for p in baseline.parameters()], + [p.detach().numpy() for p in fc.parameters()], + err_msg='parameters not equal to baseline') + + def cant_test_mp_vs_ddp(self): + r"""Test sailfish.ArcFaceLinear with model parallel.""" + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '24601' + os.environ['WORLD_SIZE'] = '1' + os.environ['RANK'] = '0' + + gpus_per_worker = torch.cuda.device_count() + torch.multiprocessing.spawn( + mp_vs_ddp_main, + args=(gpus_per_worker, ), + nprocs=gpus_per_worker, + join=True) + + def test_mp_vs_1gpu(self): + r"""Test sailfish.ArcFaceLinear with model parallel.""" + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '24601' + os.environ['WORLD_SIZE'] = '1' + os.environ['RANK'] = '0' + + gpus_per_worker = torch.cuda.device_count() + num_steps = 5 + batch_size = 1 + num_classes = gpus_per_worker + results = baseline_main(gpus_per_worker, num_steps, batch_size, + num_classes) + torch.multiprocessing.spawn( + mp_main, + args=(gpus_per_worker, results, num_steps, batch_size, + num_classes), + nprocs=gpus_per_worker, + join=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/core/sailfish/test_linear.py b/tests/core/sailfish/test_linear.py new file mode 100644 index 00000000..8bf1023c --- /dev/null +++ b/tests/core/sailfish/test_linear.py @@ -0,0 +1,314 @@ +# Copyright 2019 Alibaba Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear module tests.""" + +import math +import os +import random +import unittest + +import numpy as np +import torch + +from easycv.core import sailfish + + +class MockLinear(torch.nn.Module): + r"""Applies a linear transformation to the incoming data. 
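+
+    Single-process reference implementation whose results the sharded
+    ``sailfish.Linear`` is checked against in the tests below.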
+ """ + + def __init__(self, + in_features, + out_features, + bias=True, + weight_initializer=None, + bias_initializer=None, + parallel=None): + super(MockLinear, self).__init__() + self.out_features = out_features + self.in_features = in_features + self.weight = torch.nn.Parameter( + torch.Tensor(self.out_features, self.in_features)) + if bias: + self.bias = torch.nn.Parameter(torch.Tensor(self.out_features)) + else: + self.register_parameter('bias', None) + self.weight_initializer = weight_initializer + if weight_initializer is None: + self.weight_initializer = sailfish.KaimingUniformInitializer( + math.sqrt(5)) + self.bias_initializer = bias_initializer + if bias_initializer is None: + self.bias_initializer = sailfish.BiasUniformInitializer( + self.in_features) + self.reset_parameters() + self.parallel = parallel + + def reset_parameters(self): + r"""Reset parameters.""" + self.weight_initializer(self.weight) + if self.bias is not None: + self.bias_initializer(self.bias) + + def forward(self, features): # pylint: disable=arguments-differ + features = features.type(dtype=self.weight.dtype) + return torch.nn.functional.linear(features, self.weight, self.bias) + + +def _run_baseline_train_main(gpus_per_worker, num_steps, batch_size, + in_features, out_features, bias, lr): + r"""Run baseline on 1 GPU.""" + torch.manual_seed(42) + random.seed(42) + fc = MockLinear( + in_features, + out_features, + bias=bias, + weight_initializer=sailfish.ZerosInitializer(), + bias_initializer=sailfish.OnesInitializer()).cuda() + criterion = torch.nn.CrossEntropyLoss().cuda() + optimizer = torch.optim.SGD(fc.parameters(), lr=lr) + fc.train() + criterion.train() + + results = [] + for step in range(num_steps): + result = {} + features_list = [] + label_list = [] + for gpu in range(gpus_per_worker): + torch.manual_seed(42 * step + gpu) + random.seed(42 * step + gpu) + features_list.append(torch.randn([batch_size, in_features]).cuda()) + label_list.append( + torch.as_tensor([ + random.randint(0, out_features - 1) + for _ in range(batch_size) + ]).cuda()) + features = torch.cat(features_list) + label = torch.cat(label_list) + + torch.manual_seed(42 * step) + random.seed(42 * step) + logits = fc(features) + loss = criterion(logits, label) + result['loss'] = loss.item() + optimizer.zero_grad() + loss.backward() + optimizer.step() + result['grads/norm'] = [ + torch.norm(p.grad).item() for p in fc.parameters() + ] + results.append(result) + return results + + +def _run_mp_train_main(gpu, gpus_per_worker, baseline_steps, num_steps, + batch_size, in_features, out_features, bias, lr): + r"""Run MP and validate results.""" + torch.cuda.set_device(gpu) + torch.distributed.init_process_group( + 'nccl', rank=gpu, world_size=gpus_per_worker) + + torch.manual_seed(42) + random.seed(42) + model_parallel = sailfish.ModelParallel(gpu, gpus_per_worker) + fc = sailfish.Linear( + in_features, + out_features, + bias=bias, + weight_initializer=sailfish.ZerosInitializer(), + bias_initializer=sailfish.OnesInitializer(), + parallel=model_parallel).cuda() + criterion = sailfish.CrossEntropyLoss(parallel=model_parallel).cuda() + optimizer = torch.optim.SGD(fc.parameters(), lr=lr) + fc.train() + criterion.train() + + for step in range(num_steps): + torch.manual_seed(42 * step + gpu) + random.seed(42 * step + gpu) + features = torch.randn([batch_size, in_features]).cuda() + features = model_parallel.gather(features) + label = torch.as_tensor([ + random.randint(0, out_features - 1) for _ in range(batch_size) + ]).cuda() + label = 
model_parallel.gather_target(label) + + torch.manual_seed(42 * step) + random.seed(42 * step) + logits = fc(features) + loss = criterion(logits, label) + np.testing.assert_allclose( + loss.item(), + baseline_steps[step]['loss'], + rtol=1e-5, + err_msg='Wrong loss at gpu {} step {}'.format(gpu, step)) + optimizer.zero_grad() + loss.backward() + optimizer.step() + grad_norms = [ + torch.norm(model_parallel.gather(p.grad)).item() + for p in fc.parameters() + ] + np.testing.assert_allclose( + grad_norms, + baseline_steps[step]['grads/norm'], + rtol=1e-5, + err_msg='Wrong grads norm at gpu {} step {}'.format(gpu, step)) + + +class TestLinear(unittest.TestCase): + r"""Test sailfish.Linear.""" + + def _run_baseline_train(self, batch_size, in_features, out_features, bias, + lr): + r"""Run baseline without parallel.""" + result = {} + features = torch.randn([batch_size, in_features]) + fc = torch.nn.Linear(in_features, out_features, bias=bias) + optimizer = torch.optim.SGD(fc.parameters(), lr=lr) + fc.train() + logits = fc(features) + loss = torch.sum(logits) + result['loss'] = loss.item() + optimizer.zero_grad() + loss.backward() + optimizer.step() + result['grads/norm'] = [ + torch.norm(p.grad).item() for p in fc.parameters() + ] + return result + + def _run_mp_no_parallel_train(self, batch_size, in_features, out_features, + bias, lr): + r"""Run MP without parallel.""" + result = {} + features = torch.randn([batch_size, in_features]) + fc = sailfish.Linear(in_features, out_features, bias=bias) + optimizer = torch.optim.SGD(fc.parameters(), lr=lr) + fc.train() + logits = fc(features) + loss = torch.sum(logits) + result['loss'] = loss.item() + optimizer.zero_grad() + loss.backward() + optimizer.step() + result['grads/norm'] = [ + torch.norm(p.grad).item() for p in fc.parameters() + ] + return result + + def _run_mp_1gpu_train(self, batch_size, in_features, out_features, bias, + lr): + r"""Run MP on 1 GPU.""" + result = {} + features = torch.randn([batch_size, in_features]) + model_parallel = sailfish.ModelParallel(0, 1) + fc = sailfish.Linear( + in_features, out_features, bias=bias, parallel=model_parallel) + optimizer = torch.optim.SGD(fc.parameters(), lr=lr) + fc.train() + logits = fc(features) + loss = torch.sum(logits) + result['loss'] = loss.item() + optimizer.zero_grad() + loss.backward() + optimizer.step() + result['grads/norm'] = [ + torch.norm(p.grad).item() for p in fc.parameters() + ] + return result + + def test_no_parallel(self): + r"""Test sailfish.Linear without parallel.""" + batch_size = 3 + in_features = 4 + out_features = 5 + bias = False + lr = 0.1 + + for step in range(5): + torch.manual_seed(42 + step) + random.seed(42 + step) + baseline = self._run_baseline_train(batch_size, in_features, + out_features, bias, lr) + + torch.manual_seed(42 + step) + random.seed(42 + step) + rc = self._run_mp_no_parallel_train(batch_size, in_features, + out_features, bias, lr) + + np.testing.assert_allclose( + rc['loss'], baseline['loss'], err_msg='loss not equal') + np.testing.assert_allclose( + rc['grads/norm'], + baseline['grads/norm'], + err_msg='norm of grads not equal') + + def test_mp(self): + r"""Test sailfish.Linear with model parallel on 1 GPU.""" + batch_size = 2 + in_features = 7 + out_features = 4 + bias = True + lr = 0.6 + + for step in range(5): + torch.manual_seed(100 + step) + random.seed(100 + step) + baseline = self._run_baseline_train(batch_size, in_features, + out_features, bias, lr) + + torch.manual_seed(100 + step) + random.seed(100 + step) + rc = 
self._run_mp_1gpu_train(batch_size, in_features, out_features,
+                                         bias, lr)
+
+            np.testing.assert_allclose(
+                rc['loss'], baseline['loss'], err_msg='loss not equal')
+            np.testing.assert_allclose(
+                rc['grads/norm'],
+                baseline['grads/norm'],
+                err_msg='norm of grads not equal')
+
+    def test_mp_vs_1gpu(self):
+        r"""Test sailfish.Linear with model parallel against a 1-GPU baseline."""
+        gpus_per_worker = torch.cuda.device_count()
+        num_steps = 5
+        batch_size = 2
+        in_features = 3
+        out_features = gpus_per_worker
+        bias = True
+        lr = 0.6
+
+        baseline_steps = _run_baseline_train_main(gpus_per_worker, num_steps,
+                                                  batch_size, in_features,
+                                                  out_features, bias, lr)
+
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '24601'
+        os.environ['WORLD_SIZE'] = '1'
+        os.environ['RANK'] = '0'
+        torch.multiprocessing.spawn(
+            _run_mp_train_main,
+            args=(gpus_per_worker, baseline_steps, num_steps, batch_size,
+                  in_features, out_features, bias, lr),
+            nprocs=gpus_per_worker,
+            join=True)
+
+
+if __name__ == '__main__':
+    unittest.main()
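Usage sketch (illustrative only, not part of the patch above): one way to wire
the sharded classification head into a training step. It assumes
torch.distributed is already initialized with one process per GPU, and that
`backbone`, `images`, `labels`, `emb_size` and `num_classes` are placeholders
supplied by the surrounding training code; `num_classes` must be divisible by
the world size.

    import torch
    from easycv.core import sailfish

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    parallel = sailfish.ModelParallel(rank, world_size)

    # Each rank owns num_classes // world_size columns of the classifier.
    head = sailfish.ArcMarginLoss(
        emb_size, num_classes, margin=0.5, scale=64.0,
        parallel=parallel).cuda()
    optimizer = torch.optim.SGD(head.parameters(), lr=0.1, momentum=0.9)

    features = backbone(images.cuda())    # [local_batch, emb_size]
    loss = head(features, labels.cuda())  # gathers features/labels across ranks
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()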