mirror of https://github.com/alibaba/EasyCV.git
Add sailfish for fully sharded data parallel training
parent 5110de7635
commit 813a191876
@@ -0,0 +1,36 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pytorch extension for distributed training."""

from __future__ import absolute_import, division, print_function

from easycv.core.sailfish.activation import LogSoftmax  # noqa: F401
from easycv.core.sailfish.linear import ArcFaceLinear  # noqa: F401
from easycv.core.sailfish.linear import Linear  # noqa: F401
from easycv.core.sailfish.loss import ArcMarginLoss  # noqa: F401
from easycv.core.sailfish.loss import CrossEntropyLoss  # noqa: F401
from easycv.core.sailfish.loss import FocalLoss  # noqa: F401
from easycv.core.sailfish.loss import NLLLoss  # noqa: F401
from easycv.core.sailfish.loss import SoftmaxLoss  # noqa: F401
from easycv.core.sailfish.util import BiasUniformInitializer  # noqa: F401
from easycv.core.sailfish.util import DistributedParallel  # noqa: F401
from easycv.core.sailfish.util import KaimingUniformInitializer  # noqa: F401
from easycv.core.sailfish.util import ModelParallel  # noqa: F401
from easycv.core.sailfish.util import NormalInitializer  # noqa: F401
from easycv.core.sailfish.util import OnesInitializer  # noqa: F401
from easycv.core.sailfish.util import ParameterInitializer  # noqa: F401
from easycv.core.sailfish.util import RenormUniformInitializer  # noqa: F401
from easycv.core.sailfish.util import XavierUniformInitializer  # noqa: F401
from easycv.core.sailfish.util import ZerosInitializer  # noqa: F401
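
The exports above are meant to be composed into a model-parallel classification head: a ModelParallel handle identifies the current rank, the sharded layers take it through their `parallel` argument, and the loss modules gather features and targets across ranks internally. The sketch below is a minimal, assumed usage: it spins up a throwaway CPU "gloo" process group of world size 1 (any free port works) purely to exercise the code paths, whereas real training uses one NCCL process per GPU exactly as in the unit tests further down in this commit.

import torch
import torch.distributed as dist

from easycv.core import sailfish

dist.init_process_group(
    'gloo', init_method='tcp://127.0.0.1:29500', rank=0, world_size=1)
parallel = sailfish.ModelParallel(0, 1)
backbone = torch.nn.Linear(16, 8)  # stand-in feature extractor
head = sailfish.ArcMarginLoss(8, 10, parallel=parallel)

images = torch.randn(4, 16)
labels = torch.randint(0, 10, (4, ))
loss = head(backbone(images), labels)  # gathers features/labels, shards the classifier
loss.backward()
dist.destroy_process_group()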
@@ -0,0 +1,52 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Activation modules."""

from __future__ import absolute_import, division, print_function

import torch

from easycv.core.sailfish.util import ModelParallel


class LogSoftmax(torch.nn.Module):
    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an
    n-dimensional input Tensor, rescaling it so that the elements of the
    n-dimensional output Tensor lie in the range (-inf, 0].

    Shape:
        - Input: :math:`(*)` where `*` means any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range (-inf, 0].

    Examples::

        >>> m = LogSoftmax()
        >>> input = torch.randn(2, 3)
        >>> output = m(input)
    """

    def __init__(self, epsilon=0, parallel=None):
        super(LogSoftmax, self).__init__()
        self.epsilon = epsilon
        self.parallel = parallel

    def forward(self, logits):  # pylint: disable=arguments-differ
        if isinstance(self.parallel, ModelParallel):
            return self.parallel.log_softmax(logits, epsilon=self.epsilon)
        return torch.nn.functional.log_softmax(logits, _stacklevel=5)
@@ -0,0 +1,269 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions."""

from __future__ import absolute_import, division, print_function

import torch


class _Cat(torch.autograd.Function):
    """Concat inputs."""

    @staticmethod
    def forward(ctx, inputs, dim, rank, world_size):  # noqa: E501 # pylint: disable=arguments-differ
        r"""Cat is defined as:
        .. math::
            \text{all_cat}(x_i) = \bigoplus_j x_j
        """
        ctx.dim = dim
        ctx.rank = rank
        ctx.world_size = world_size
        all_inputs = [
            torch.zeros(
                inputs.size(), dtype=inputs.dtype, device=inputs.device)
            for _ in range(world_size)
        ]
        torch.distributed.all_gather(all_inputs, inputs)
        output = torch.cat(all_inputs, dim=dim)
        output.requires_grad_()
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=arguments-differ
        r"""Gradient of Cat is defined as:
        .. math::
            \nabla \text{all_cat}(x_i) = \text{split}(\nabla x_i)
        """
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input)
        grad_input_dim_size = grad_input.size()[ctx.dim]
        assert grad_input_dim_size % ctx.world_size == 0
        split_size = grad_input_dim_size // ctx.world_size
        grad_input_splits = torch.split(grad_input, split_size, dim=ctx.dim)
        return grad_input_splits[ctx.rank], None, None, None


def all_cat(inputs, dim=0, rank=0, world_size=1):
    return _Cat.apply(inputs, dim, rank, world_size)


class _Sum(torch.autograd.Function):
    """Sum inputs."""

    @staticmethod
    def forward(_, inputs):  # pylint: disable=arguments-differ
        r"""Sum is defined as:
        .. math::
            \text{all_sum}(x_i) = \sum_j x_j
        """
        inputs_sum = inputs.clone()
        torch.distributed.all_reduce(inputs_sum)
        inputs_sum.requires_grad_()
        return inputs_sum

    @staticmethod
    def backward(_, grad_output):  # pylint: disable=arguments-differ
        r"""Gradient of Sum is defined as:
        .. math::
            \nabla \text{all_sum}(x_i) = \sum_j\nabla x_j
        """
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input)
        return grad_input


def all_sum(inputs):
    return _Sum.apply(inputs)


class _LogSoftmax(torch.autograd.Function):
    """Compute log softmax of logits."""

    @staticmethod
    def forward(ctx, logits, epsilon):  # pylint: disable=arguments-differ
        r"""LogSoftmax is defined as:
        .. math::
            \log(\text{softmax}(x_i))
            = \log\left(\frac{\text{e}^{x_i}}{\sum_j\text{e}^{x_j}}\right)
            = x_i - \log\sum_j\text{e}^{x_j}

        For numerical stability, it subtracts the maximum value for every logits:
        .. math::
            \log(\text{softmax}(x_i))
            = \hat{x_i} - \log\sum_j\text{e}^{\hat{x_j}},
            \hat{x} = x - \max_j{x_j}
        """
        ctx.logits_dtype = logits.dtype
        logits_max = torch.max(logits, dim=1).values
        torch.distributed.all_reduce(
            logits_max, op=torch.distributed.ReduceOp.MAX)
        logits = logits - logits_max.view(-1, 1)
        logits_exp = torch.exp(logits)
        logits_exp_sum = torch.sum(logits_exp, dim=1)
        torch.distributed.all_reduce(logits_exp_sum)
        logits_exp_sum_log = torch.log(logits_exp_sum + epsilon)
        prob_log = logits - logits_exp_sum_log.view(-1, 1)
        ctx.save_for_backward(prob_log)
        prob_log.requires_grad_()
        return prob_log

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=arguments-differ
        r"""Gradient of LogSoftmax is defined as:
        .. math::
            \nabla\log(\text{softmax}(x_i))
            = \nabla x_i - \text{softmax}(x_i) \sum_j \nabla x_j
        """
        grad_output_sum = torch.sum(grad_output, dim=1)
        torch.distributed.all_reduce(grad_output_sum)
        prob_log, = ctx.saved_tensors
        grad_input = torch.exp(prob_log) * grad_output_sum.view(-1, 1)
        grad_input = grad_output - grad_input
        grad_input = grad_input.type(dtype=ctx.logits_dtype)
        return grad_input, None


def all_log_softmax(logits, epsilon=1e-8):
    return _LogSoftmax.apply(logits, epsilon)


class _NLLLoss(torch.autograd.Function):
    """calculate NLLLoss from mask."""

    @staticmethod
    def forward(ctx, inputs, correct_mask):  # pylint: disable=arguments-differ
        ctx.inputs_size = inputs.size()
        ctx.save_for_backward(correct_mask)
        loss = torch.sum(inputs * correct_mask) / -ctx.inputs_size[0]
        torch.distributed.all_reduce(loss)
        loss.requires_grad_()
        return loss

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=arguments-differ
        correct_mask, = ctx.saved_tensors
        grad_input = grad_output.repeat(*ctx.inputs_size)
        grad_input = grad_input * correct_mask / -ctx.inputs_size[0]
        return grad_input, None


def all_nll_loss(inputs, correct_mask):
    return _NLLLoss.apply(inputs, correct_mask)


def shard_target_and_mask(target, output_features, rank=0):
    target_shard_begin = output_features * rank
    target_shard_end = output_features * (rank + 1)
    target_shard_lmask = torch.ge(target, target_shard_begin)
    target_shard_rmask = torch.lt(target, target_shard_end)
    target_mask = (target_shard_lmask * target_shard_rmask).to(target.device)
    target_shard = (target - target_shard_begin) * target_mask.long()
    return target_shard, target_mask


def shard_correct_mask(target, inputs, rank=0):
    """Get correct mask of inputs."""
    inputs_size = inputs.size()
    target_shard_begin = inputs_size[1] * rank
    target_shard_end = inputs_size[1] * (rank + 1)
    target_shard_lmask = torch.ge(target, target_shard_begin)
    target_shard_rmask = torch.lt(target, target_shard_end)
    target_mask = (target_shard_lmask * target_shard_rmask).to(target.device)
    target_shard = (target - target_shard_begin) * target_mask.long()
    mask = torch.zeros(inputs_size, device=inputs.device, dtype=inputs.dtype)
    mask.scatter_(1, target_shard.view(-1, 1).long(), 1)
    mask.masked_fill_((~target_mask).view(-1, 1).expand(*inputs.size()), 0)
    return mask


def shard_correct_predictions(target, logits, world_size=1):
    r"""Calculate correct predictions for logits."""
    shard_max_logits, shard_expected_class = torch.max(logits, dim=1)
    all_max_logits = [
        torch.zeros(
            shard_max_logits.size(),
            dtype=shard_max_logits.dtype,
            device=shard_max_logits.device) for _ in range(world_size)
    ]
    torch.distributed.all_gather(all_max_logits, shard_max_logits)
    all_max_logits = torch.cat([t.view(-1, 1) for t in all_max_logits], dim=1)
    rank_pred = torch.max(all_max_logits, dim=1)[1].view(-1, 1)
    all_shard_pred = [
        torch.zeros(
            shard_expected_class.size(),
            dtype=shard_expected_class.dtype,
            device=shard_expected_class.device) for _ in range(world_size)
    ]
    torch.distributed.all_gather(all_shard_pred, shard_expected_class)
    all_shard_pred = torch.cat([t.view(-1, 1) for t in all_shard_pred], dim=1)
    all_shard_pred_mask = torch.zeros(
        all_shard_pred.size(),
        device=all_shard_pred.device,
        dtype=all_shard_pred.dtype)
    all_shard_pred_mask.scatter_(1, rank_pred.long(), 1)
    shard_pred = torch.sum(
        all_shard_pred * all_shard_pred_mask, dim=1).view(-1, 1)
    pred = shard_pred + rank_pred * logits.size()[1]
    return (pred == target.data.view_as(pred)).sum().item()


def shard_topk_correct_predictions(target, logits, k, world_size=1):
    r"""Calculate correct predictions for logits."""
    # Step 1: Compute top-k of shard logits.
    logits_topk, logits_topk_idx = torch.topk(logits, k, dim=1)
    all_logits_topk = [
        torch.zeros(
            logits_topk.size(),
            dtype=logits_topk.dtype,
            device=logits_topk.device) for _ in range(world_size)
    ]
    torch.distributed.all_gather(all_logits_topk, logits_topk)
    all_logits_topk = torch.cat([t.view(-1, k) for t in all_logits_topk],
                                dim=1)

    all_logits_topk_idx = [
        torch.zeros(
            logits_topk_idx.size(),
            dtype=logits_topk_idx.dtype,
            device=logits_topk_idx.device) for _ in range(world_size)
    ]
    torch.distributed.all_gather(all_logits_topk_idx, logits_topk_idx)
    all_logits_topk_idx = torch.cat(
        [t.view(-1, k) for t in all_logits_topk_idx], dim=1)

    # Step 2: Compute global top-k indices.
    _, all_logits_topk_topk_idx = torch.topk(all_logits_topk, k, dim=1)
    all_logits_topk_topk_idx = all_logits_topk_topk_idx.view(-1, k)
    all_logits_topk_mask = torch.zeros(
        all_logits_topk_idx.size(),
        device=all_logits_topk_idx.device,
        dtype=all_logits_topk_idx.dtype)
    all_logits_topk_mask.scatter_(1, all_logits_topk_topk_idx.long(), 1)
    batch_size, shard_num_classes = logits.size()
    all_logits_topk_base = torch.cat([
        torch.ones([batch_size, k],
                   device=all_logits_topk_idx.device,
                   dtype=all_logits_topk_idx.dtype) * p * shard_num_classes
        for p in range(world_size)
    ],
                                     dim=1)
    all_logits_topk_gidx = all_logits_topk_base + all_logits_topk_idx

    # Step 3: Compute predictions and check.
    pred = torch.masked_select(all_logits_topk_gidx,
                               all_logits_topk_mask.type(torch.bool)).view(
                                   batch_size, k)
    return (pred == target.view(-1, 1)).sum().item()
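
All of the collectives above assume an already-initialized torch.distributed process group. A quick single-process sanity check is sketched below, assuming a CPU "gloo" group of world size 1 (the port number is arbitrary): with one rank, all_cat is an identity and all_log_softmax matches the plain log_softmax up to the 1e-8 epsilon added inside the log.

import numpy as np
import torch
import torch.distributed as dist

from easycv.core.sailfish.function import all_cat, all_log_softmax

dist.init_process_group(
    'gloo', init_method='tcp://127.0.0.1:29500', rank=0, world_size=1)
x = torch.randn(4, 8)
np.testing.assert_allclose(
    all_cat(x, dim=0, rank=0, world_size=1).detach().numpy(), x.numpy())
np.testing.assert_allclose(
    all_log_softmax(x).detach().numpy(),
    torch.nn.functional.log_softmax(x, dim=1).numpy(),
    rtol=1e-5,
    atol=1e-6)
dist.destroy_process_group()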
@@ -0,0 +1,151 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linear modules."""

from __future__ import absolute_import, division, print_function
import math

import torch

from easycv.core.sailfish.util import (BiasUniformInitializer,
                                       KaimingUniformInitializer,
                                       ModelParallel, RenormUniformInitializer)


class Linear(torch.nn.Module):
    r"""Applies a linear transformation to the incoming data.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 weight_initializer=None,
                 bias_initializer=None,
                 parallel=None):
        super(Linear, self).__init__()
        if isinstance(parallel, ModelParallel):
            if out_features % parallel.world_size != 0:
                raise ValueError(
                    'out_features must be divisible by parallel.world_size')
            self.out_features = out_features // parallel.world_size
        else:
            self.out_features = out_features
        self.in_features = in_features
        self.weight = torch.nn.Parameter(
            torch.Tensor(self.out_features, self.in_features))
        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(self.out_features))
        else:
            self.register_parameter('bias', None)
        self.weight_initializer = weight_initializer
        if weight_initializer is None:
            self.weight_initializer = KaimingUniformInitializer(math.sqrt(5))
        self.bias_initializer = bias_initializer
        if bias_initializer is None:
            self.bias_initializer = BiasUniformInitializer(self.in_features)
        self.reset_parameters()
        self.parallel = parallel

    def reset_parameters(self):
        r"""Reset parameters."""
        self.weight_initializer(self.weight)
        if self.bias is not None:
            self.bias_initializer(self.bias)

    def forward(self, features):  # pylint: disable=arguments-differ
        features = features.type(dtype=self.weight.dtype)
        return torch.nn.functional.linear(features, self.weight, self.bias)


class ArcFaceLinear(torch.nn.Module):
    r"""Applies an ArcFace transformation to the incoming data.
    See https://arxiv.org/abs/1801.05599 .
    """

    def __init__(
            self,
            in_features,
            out_features,
            margin=0.5,
            scale=64.0,  # See normface https://arxiv.org/abs/1704.06369
            fast_phi=False,
            epsilon=0,
            weight_initializer=None,
            l2_norm=False,
            parallel=None):
        super(ArcFaceLinear, self).__init__()
        if isinstance(parallel, ModelParallel):
            if out_features % parallel.world_size != 0:
                raise ValueError(
                    'out_features must be divisible by parallel.world_size')
            self.out_features = out_features // parallel.world_size
        else:
            self.out_features = out_features
        self.in_features = in_features
        self.margin = margin  # Angular margin penalty.
        self.scale = scale  # Radius of the hypersphere.
        self.fast_phi = fast_phi
        self.epsilon = epsilon
        self.weight_initializer = weight_initializer
        if weight_initializer is None:
            self.weight_initializer = RenormUniformInitializer()
        self.l2_norm = l2_norm
        self.parallel = parallel

        self.weight = torch.nn.Parameter(
            torch.Tensor(self.out_features, self.in_features))
        self.reset_parameters()

        self._cos_margin = math.cos(margin)
        self._sin_margin = math.sin(margin)
        self._threshold = math.cos(math.pi - margin)
        self._min = math.sin(math.pi - margin) * self.margin

    def reset_parameters(self):
        r"""Reset parameters."""
        self.weight_initializer(self.weight)

    def forward(self, features, target):  # pylint: disable=arguments-differ
        r"""Compute :math:`\phi = \cos(\theta + margin)` and logits."""
        # (N, E) x (E, C) -> (N, C)
        features = features.type(dtype=self.weight.dtype)
        if self.l2_norm:
            features_norm = torch.norm(features, 2, 1, True)
            features = torch.div(features, features_norm)
            weight_norm = torch.norm(self.weight, 2, 0, True)
            weight = torch.div(self.weight, weight_norm)
        else:
            features = torch.nn.functional.normalize(features)
            weight = torch.nn.functional.normalize(self.weight)
        cosine = torch.nn.functional.linear(features, weight)
        cosine = cosine.clamp(-1, 1)  # for numerical stability
        sine = torch.sqrt(1. + self.epsilon - cosine * cosine)
        phi = cosine * self._cos_margin - sine * self._sin_margin
        phi = phi.type(dtype=cosine.dtype)
        if self.fast_phi:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self._threshold, phi,
                              cosine - self._min)
        if isinstance(self.parallel, ModelParallel):
            mask = self.parallel.correct_mask(target, cosine)
        else:
            mask = torch.zeros(
                cosine.size(), device=cosine.device, dtype=cosine.dtype)
            mask.scatter_(1, target.view(-1, 1).long(), 1)
        logits = (mask * phi) + ((1.0 - mask) * cosine)
        logits *= self.scale
        return logits
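
Both layers above shard their output dimension when a ModelParallel handle is passed, so each rank only stores and updates out_features / world_size rows of the full weight matrix. The sharding happens entirely in __init__, which the sketch below (with assumed sizes 512 -> 1000 over 4 ranks) can verify without any process group.

import torch

from easycv.core import sailfish

mp = sailfish.ModelParallel(0, 4)
fc = sailfish.Linear(512, 1000, parallel=mp)
assert fc.weight.shape == (1000 // 4, 512)  # 250 rows of the 1000-class weight
assert fc.bias.shape == (250, )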
@@ -0,0 +1,206 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loss modules."""

from __future__ import absolute_import, division, print_function

import torch

from easycv.core.sailfish.activation import LogSoftmax
from easycv.core.sailfish.linear import ArcFaceLinear, Linear
from easycv.core.sailfish.util import ModelParallel


class NLLLoss(torch.nn.Module):
    r"""The negative log likelihood loss for log probabilities. It is
    useful to train a classification problem with `C` classes.

    The `input` given through a forward call is expected to contain
    log-probabilities of each class. `input` has to be a Tensor of size either
    :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)`
    with :math:`K \geq 1` for the `K`-dimensional case (described later).

    Obtaining log-probabilities in a neural network is easily achieved by
    adding a `LogSoftmax` layer in the last layer of your network.
    You may use `CrossEntropyLoss` instead, if you prefer not to add an
    extra layer.

    The `target` that this loss expects should be a class index in the range
    :math:`[0, C-1]` where `C = number\_classes`.

    NLLLoss is defined as:
    .. math::
        \ell(x, y) = -\frac{1}{N}\sum_{n=1}^N L_{n}

    Args:
        num_classes: total number of classes.
        focal: whether to use the FocalLoss implementation.
        focal_gamma: the focusing parameter of FocalLoss.
        rank: rank of current replica.
        world_size: size of replicas.

    Shape:
        - Input: :math:`(\frac{N}{P}, C)` where `C = number of classes`, or
          :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
          in the case of `K`-dimensional loss.
        - Target: :math:`(\frac{N}{P}, 1)` where each value is
          :math:`0 \leq \text{targets}[i] \leq C-1`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
          K-dimensional loss.
        - Output: scalar.
          If :attr:`reduction` is ``'none'``, then the same size as the target:
          :math:`(N)`, or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in
          the case of K-dimensional loss.

    Examples::

        >>> m = LogSoftmax(...)
        >>> loss = NLLLoss(...)
        >>> # input is of size N x C = 3 x 5
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> # each element in target has to have 0 <= value < C
        >>> target = torch.tensor([1, 0, 4])
        >>> output = loss(m(input), target)
        >>> output.backward()
    """

    def __init__(self, focal=False, focal_gamma=0, parallel=None):
        super(NLLLoss, self).__init__()
        self.focal = focal
        self.focal_gamma = focal_gamma
        self.parallel = parallel

    def forward(self, logprob, target):  # pylint: disable=arguments-differ
        """Compute negative log likelihood loss from log-probs and target."""
        if isinstance(self.parallel, ModelParallel):
            with torch.no_grad():
                mask = self.parallel.correct_mask(target, logprob)
            loss = self.parallel.nll_loss(logprob, mask)
            if self.focal:
                loss_exp = torch.exp(-loss)
                loss = (1 - loss_exp)**self.focal_gamma * loss
            return loss
        loss = torch.nn.functional.nll_loss(logprob, target)
        if self.focal:
            loss_exp = torch.exp(-loss)
            loss = (1 - loss_exp)**self.focal_gamma * loss
        return loss


class CrossEntropyLoss(torch.nn.Module):
    r"""This criterion combines :func:`LogSoftmax` and
    :func:`NLLLoss` in one single class.
    """

    def __init__(self, epsilon=0, parallel=None):
        super(CrossEntropyLoss, self).__init__()
        self._log_softmax = LogSoftmax(epsilon=epsilon, parallel=parallel)
        self._nll_loss = NLLLoss(parallel=parallel)

    def forward(self, logits, target):  # pylint: disable=arguments-differ
        # 1. Compute log-probabilities for current shard:
        logprob = self._log_softmax(logits)

        # 2. Compute NLL loss for all shards.
        return self._nll_loss(logprob, target)


class FocalLoss(torch.nn.Module):
    r"""This criterion combines :func:`LogSoftmax` and
    :func:`NLLLoss` in one single class.
    """

    def __init__(self, gamma=0, epsilon=1e-8, parallel=None):
        super(FocalLoss, self).__init__()
        self._log_softmax = LogSoftmax(epsilon=epsilon, parallel=parallel)
        self._nll_loss = NLLLoss(
            focal=True, focal_gamma=gamma, parallel=parallel)

    def forward(self, logits, target):  # pylint: disable=arguments-differ
        logprob = self._log_softmax(logits)
        return self._nll_loss(logprob, target)


class SoftmaxLoss(torch.nn.Module):
    r"""This criterion combines :func:`Linear` and
    :func:`CrossEntropyLoss` in one single class.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 epsilon=0,
                 weight_initializer=None,
                 bias_initializer=None,
                 parallel=None):
        super(SoftmaxLoss, self).__init__()
        self._linear = Linear(
            in_features,
            out_features,
            bias=bias,
            weight_initializer=weight_initializer,
            bias_initializer=bias_initializer,
            parallel=parallel)
        self._log_softmax = LogSoftmax(epsilon=epsilon, parallel=parallel)
        self._nll_loss = NLLLoss(parallel=parallel)
        self._parallel = parallel

    def forward(self, features, target):  # pylint: disable=arguments-differ
        if isinstance(self._parallel, ModelParallel):
            features = self._parallel.gather(features)
        logits = self._linear(features.squeeze())
        logprob = self._log_softmax(logits)
        if isinstance(self._parallel, ModelParallel):
            target = self._parallel.gather_target(target)
        return self._nll_loss(logprob, target)


class ArcMarginLoss(torch.nn.Module):
    r"""This criterion combines :func:`ArcFaceLinear` and
    :func:`CrossEntropyLoss` in one single class.
    """

    def __init__(
            self,
            in_features,
            out_features,
            margin=0.5,
            scale=64.0,  # See normface https://arxiv.org/abs/1704.06369
            fast_phi=False,
            epsilon=0,
            weight_initializer=None,
            l2_norm=False,
            parallel=None):
        super(ArcMarginLoss, self).__init__()
        self._linear = ArcFaceLinear(
            in_features,
            out_features,
            margin=margin,
            scale=scale,
            l2_norm=l2_norm,
            fast_phi=fast_phi,
            epsilon=epsilon,
            weight_initializer=weight_initializer,
            parallel=parallel)
        self._log_softmax = LogSoftmax(epsilon=epsilon, parallel=parallel)
        self._nll_loss = NLLLoss(parallel=parallel)
        self._parallel = parallel

    def forward(self, features, target):  # pylint: disable=arguments-differ
        if isinstance(self._parallel, ModelParallel):
            features = self._parallel.gather(features)
            target = self._parallel.gather_target(target)
        logits = self._linear(features.squeeze(), target)
        logprob = self._log_softmax(logits)
        return self._nll_loss(logprob, target)
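
Note that the focal option in NLLLoss above modulates the already-reduced scalar loss by (1 - exp(-loss)) ** gamma, which shrinks the loss most when it is already small. A tiny numeric sketch with assumed values (gamma = 2):

import torch

loss = torch.tensor(0.25)
gamma = 2.0
focal_loss = (1 - torch.exp(-loss))**gamma * loss
print(float(focal_loss))  # ~0.0122, versus the unmodulated 0.25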
@@ -0,0 +1,172 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility modules."""

from __future__ import absolute_import, division, print_function
import math

import torch

from easycv.core.sailfish.function import (all_cat, all_log_softmax,
                                           all_nll_loss, all_sum,
                                           shard_correct_mask,
                                           shard_correct_predictions,
                                           shard_target_and_mask,
                                           shard_topk_correct_predictions)


class DistributedParallel:
    """Base class of parallelism."""

    def __init__(self, rank, world_size):
        self._rank = rank
        self._world_size = world_size

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def correct_mask(self, target, inputs):
        mask = torch.zeros(
            inputs.size(), device=inputs.device, dtype=inputs.dtype)
        mask.scatter_(1, target.view(-1, 1).long(), 1)
        return mask

    def correct_predictions(self, target, logits, k=1):
        if k == 1:
            pred = torch.max(logits, dim=1)[1]
            return (pred == target.view(-1, 1)).sum().item()
        pred = torch.topk(logits, k, dim=1)[1]
        return (pred == target.view(-1, 1)).sum().item()

    def xavier_uniform_(self, weight, gain=1.):
        return torch.nn.init.xavier_uniform_(weight, gain=gain)


class ModelParallel(DistributedParallel):
    """All-to-All Model Parallelism."""

    def gather(self, inputs, dim=0, requires_grad=True):
        if requires_grad:
            return all_cat(
                inputs, dim=dim, rank=self.rank, world_size=self.world_size)
        all_inputs = [
            torch.zeros(
                inputs.size(), dtype=inputs.dtype, device=inputs.device)
            for _ in range(self.world_size)
        ]
        torch.distributed.all_gather(all_inputs, inputs)
        return torch.cat(all_inputs, dim=dim)

    def gather_target(self, target):
        return self.gather(target, requires_grad=False)

    def reduce_sum(self, inputs):
        return all_sum(inputs)

    def log_softmax(self, logits, epsilon=1e-8):
        return all_log_softmax(logits, epsilon=epsilon)

    def nll_loss(self, inputs, correct_mask):
        return all_nll_loss(inputs, correct_mask)

    def correct_mask(self, target, inputs):
        return shard_correct_mask(target, inputs, rank=self.rank)

    def target_and_mask(self, target, output_features):
        return shard_target_and_mask(target, output_features, rank=self.rank)

    def correct_predictions(self, target, logits, k=1):
        if k == 1:
            return shard_correct_predictions(
                target, logits, world_size=self.world_size)
        return shard_topk_correct_predictions(
            target, logits, k, world_size=self.world_size)


class ParameterInitializer:
    r"""Base class for parameter initializer."""

    def __call__(self, param, shard_rank=0, num_shards=1):
        raise NotImplementedError


class ZerosInitializer(ParameterInitializer):

    def __call__(self, param, parallel=None):
        torch.nn.init.zeros_(param)


class OnesInitializer(ParameterInitializer):

    def __call__(self, param, parallel=None):
        torch.nn.init.ones_(param)


class XavierUniformInitializer(ParameterInitializer):

    def __init__(self, gain=1.):
        self.gain = gain

    def __call__(self, param, parallel=None):
        if isinstance(parallel, ModelParallel):
            if param.dim() != 2:
                raise ValueError(
                    'param with dimensions other than 2 not supported')
            r = param.size(1) + param.size(0) * parallel.world_size
            a = self.gain * math.sqrt(3.0) * math.sqrt(2.0 / float(r))
            torch.nn.init.uniform_(param, -a, a)
        else:
            torch.nn.init.xavier_uniform_(param, gain=self.gain)


class KaimingUniformInitializer(ParameterInitializer):

    def __init__(self, bound):
        self.bound = bound

    def __call__(self, param, parallel=None):
        torch.nn.init.kaiming_uniform_(param, a=self.bound)


class BiasUniformInitializer(ParameterInitializer):

    def __init__(self, weight_in_features):
        self.weight_in_features = weight_in_features

    def __call__(self, param, parallel=None):
        a = 1 / math.sqrt(float(self.weight_in_features))
        torch.nn.init.uniform_(param, -a, a)


class RenormUniformInitializer(ParameterInitializer):

    def __init__(self, maxnorm=1e-5, scale=1e5):
        self.maxnorm = maxnorm
        self.scale = scale

    def __call__(self, param, parallel=None):
        param.data.uniform_(-1, 1).renorm_(
            2, 0, maxnorm=self.maxnorm).mul_(self.scale)


class NormalInitializer(ParameterInitializer):

    def __call__(self, param, parallel=None):
        torch.nn.init.normal_(param)
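
The shard_* helpers wrapped by ModelParallel above re-index global class ids into the local label space of a rank. The sketch below assumes rank 1 of a 2-way split with 4 classes per shard (so rank 1 owns global classes 4..7); it needs no process group because these helpers are purely local.

import torch

from easycv.core.sailfish.function import (shard_correct_mask,
                                           shard_target_and_mask)

target = torch.tensor([1, 5, 7])  # global class ids for a batch of 3
target_shard, owned = shard_target_and_mask(target, 4, rank=1)
print(target_shard.tolist())  # [0, 1, 3]: local column ids (0 is a dummy when not owned)
print(owned.tolist())  # [False, True, True]
mask = shard_correct_mask(target, torch.zeros(3, 4), rank=1)
print(mask.tolist())  # one-hot rows only for the samples this rank owns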
@@ -0,0 +1,528 @@
# Copyright 2019 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ArcFaceLinear module tests."""

import os
import random
import unittest

import numpy as np
import torch

from easycv.core import sailfish


def mp_vs_ddp_main(gpu, gpus_per_worker):
    r"""Model parallel vs. DDP"""
    torch.cuda.set_device(gpu)
    torch.distributed.init_process_group(
        'nccl', rank=gpu, world_size=gpus_per_worker)
    try:
        num_steps = 5
        freeze_num_steps = 5
        learning_rate = 0.1
        batch_size = 2
        image_size = 5
        emb_size = 3
        num_classes = 8
        margin_m = 0.5
        margin_s = 64
        momentum = 0.9
        model_parallel = sailfish.ModelParallel(gpu, gpus_per_worker)

        zeros_init = sailfish.ZerosInitializer()

        # baseline
        torch.manual_seed(42)
        random.seed(42)
        baseline_fe = torch.nn.Linear(image_size, emb_size).cuda()
        baseline_fe = torch.nn.parallel.DistributedDataParallel(
            baseline_fe, device_ids=[gpu])
        baseline_fe_params = list(baseline_fe.parameters())
        baseline_fc = sailfish.ArcFaceLinear(
            emb_size,
            num_classes,
            margin=margin_m,
            scale=margin_s,
            weight_initializer=zeros_init).cuda()
        baseline_fc = torch.nn.parallel.DistributedDataParallel(
            baseline_fc, device_ids=[gpu])
        baseline_fc_params = list(baseline_fc.parameters())
        baseline_criterion = torch.nn.CrossEntropyLoss().cuda()
        baseline_optimizer = torch.optim.SGD(
            [{
                'params': baseline_fe.parameters()
            }, {
                'params': baseline_fc.parameters()
            }],
            lr=learning_rate,
            momentum=momentum)
        baseline_fe.train()
        baseline_fc.train()
        baseline_criterion.train()

        # hybrid parallelism
        torch.manual_seed(42)
        random.seed(42)
        fe = torch.nn.Linear(image_size, emb_size).cuda()
        fe = torch.nn.parallel.DistributedDataParallel(fe, device_ids=[gpu])
        fe_params = list(fe.parameters())
        fc = sailfish.ArcFaceLinear(
            emb_size,
            num_classes,
            margin=margin_m,
            scale=margin_s,
            weight_initializer=zeros_init,
            parallel=model_parallel).cuda()
        fc_params = list(fc.parameters())
        criterion = sailfish.CrossEntropyLoss(parallel=model_parallel).cuda()
        optimizer = torch.optim.SGD([{
            'params': fe.parameters()
        }, {
            'params': fc.parameters()
        }],
                                    lr=learning_rate,
                                    momentum=momentum)
        fe.train()
        fc.train()
        criterion.train()

        for step in range(num_steps):
            # baseline
            torch.manual_seed(42 * step + gpu)
            random.seed(42 * step + gpu)
            baseline_data = torch.randn([batch_size, image_size]).cuda()
            baseline_label = torch.as_tensor([
                random.randint(0, num_classes - 1) for _ in range(batch_size)
            ]).cuda()
            baseline_features = baseline_fe(baseline_data)
            baseline_logits = baseline_fc(baseline_features, baseline_label)
            baseline_loss = baseline_criterion(baseline_logits, baseline_label)
            baseline_loss = model_parallel.reduce_sum(baseline_loss)
            baseline_loss = baseline_loss / gpus_per_worker
            baseline_optimizer.zero_grad()
            baseline_loss.backward()
            baseline_optimizer.step()

            # hybrid parallelism
            torch.manual_seed(42 * step + gpu)
            random.seed(42 * step + gpu)
            data = torch.randn([batch_size, image_size]).cuda()
            label = torch.as_tensor([
                random.randint(0, num_classes - 1) for _ in range(batch_size)
            ]).cuda()
            features = fe(data)
            all_features = model_parallel.gather(features)
            all_label = model_parallel.gather_target(label)
            shard_logits = fc(all_features, all_label)
            loss = criterion(shard_logits, all_label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # eval
            torch.manual_seed(42 * step + gpu)
            random.seed(42 * step + gpu)
            with torch.no_grad():
                gathered_logits = model_parallel.gather(shard_logits, dim=1)
                gathered_baseline_logits = model_parallel.gather(
                    baseline_logits, dim=0)
                logits_norm_val = torch.norm(gathered_logits).item()
                baseline_logits_norm_val = torch.norm(
                    gathered_baseline_logits).item()
                np.testing.assert_allclose(
                    logits_norm_val,
                    baseline_logits_norm_val,
                    rtol=1e-5,
                    atol=1e-4,
                    err_msg='logits at gpu {} step {}'.format(gpu, step))

                loss_val = loss.cpu().detach().numpy()
                baseline_loss_val = baseline_loss.cpu().detach().numpy()
                np.testing.assert_allclose(
                    loss_val,
                    baseline_loss_val,
                    rtol=1e-5,
                    atol=1e-4,
                    err_msg='loss at gpu {} step {}'.format(gpu, step))

                fc_grad = model_parallel.gather(fc_params[0].grad)
                baseline_fc_grad = baseline_fc_params[0].grad
                np.testing.assert_allclose(
                    fc_grad.cpu().detach().numpy(),
                    baseline_fc_grad.cpu().detach().numpy(),
                    rtol=1e-5,
                    atol=1e-4,
                    err_msg='fc grad at gpu {} step {}'.format(gpu, step))

                fe_weight = fe_params[0]
                baseline_fe_weight = baseline_fe_params[0]
                np.testing.assert_allclose(
                    fe_weight.cpu().detach().numpy(),
                    baseline_fe_weight.cpu().detach().numpy(),
                    rtol=1e-5,
                    atol=1e-4,
                    err_msg='fe weight at gpu {} step {}'.format(gpu, step))

        for p in baseline_fe.parameters():
            p.requires_grad = False
        for p in fe.parameters():
            p.requires_grad = False
        for step in range(freeze_num_steps):
            # baseline
            torch.manual_seed(100 * step + gpu)
            random.seed(100 * step + gpu)
            baseline_data = torch.randn([batch_size, image_size]).cuda()
            baseline_label = torch.as_tensor([
                random.randint(0, num_classes - 1) for _ in range(batch_size)
            ]).cuda()
            baseline_features = baseline_fe(baseline_data)
            baseline_logits = baseline_fc(baseline_features, baseline_label)
            baseline_loss = baseline_criterion(baseline_logits, baseline_label)
            baseline_loss = model_parallel.reduce_sum(baseline_loss)
            baseline_loss = baseline_loss / gpus_per_worker
            baseline_optimizer.zero_grad()
            baseline_loss.backward()
            baseline_optimizer.step()

            # hybrid parallelism
            torch.manual_seed(100 * step + gpu)
            random.seed(100 * step + gpu)
            data = torch.randn([batch_size, image_size]).cuda()
            label = torch.as_tensor([
                random.randint(0, num_classes - 1) for _ in range(batch_size)
            ]).cuda()
            features = fe(data)
            all_features = model_parallel.gather(features)
            all_label = model_parallel.gather_target(label)
            shard_logits = fc(all_features, all_label)
            loss = criterion(shard_logits, all_label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # eval
            torch.manual_seed(100 * step + gpu)
            random.seed(100 * step + gpu)
            with torch.no_grad():
                gathered_logits = model_parallel.gather(shard_logits, dim=1)
                gathered_baseline_logits = model_parallel.gather(
                    baseline_logits, dim=0)
                logits_norm_val = torch.norm(gathered_logits).item()
                baseline_logits_norm_val = torch.norm(
                    gathered_baseline_logits).item()
                np.testing.assert_allclose(
                    logits_norm_val,
                    baseline_logits_norm_val,
                    rtol=1e-5,
                    atol=1e-4,
                    err_msg='freeze logits at gpu {} step {}'.format(
                        gpu, step))

                loss_val = loss.cpu().detach().numpy()
                baseline_loss_val = baseline_loss.cpu().detach().numpy()
                np.testing.assert_allclose(
                    loss_val,
                    baseline_loss_val,
                    rtol=1e-5,
                    atol=1e-4,
                    err_msg='freeze loss at gpu {} step {}'.format(gpu, step))

                fc_grad = model_parallel.gather(fc_params[0].grad)
                baseline_fc_grad = baseline_fc_params[0].grad
                np.testing.assert_allclose(
                    fc_grad.cpu().detach().numpy(),
                    baseline_fc_grad.cpu().detach().numpy(),
                    rtol=1e-5,
                    atol=1e-4,
                    err_msg='freeze fc grad at gpu {} step {}'.format(
                        gpu, step))

                fe_weight = fe_params[0]
                baseline_fe_weight = baseline_fe_params[0]
                np.testing.assert_allclose(
                    fe_weight.cpu().detach().numpy(),
                    baseline_fe_weight.cpu().detach().numpy(),
                    rtol=1e-5,
                    atol=1e-4,
                    err_msg='freeze fe weight at gpu {} step {}'.format(
                        gpu, step))

    finally:
        torch.distributed.destroy_process_group()


def mp_main(gpu,
            gpus_per_worker,
            results,
            num_steps=1,
            batch_size=1,
            num_classes=8):
    r"""Model parallel"""
    torch.cuda.set_device(gpu)
    torch.distributed.init_process_group(
        'nccl', rank=gpu, world_size=gpus_per_worker)
    zeros_init = sailfish.ZerosInitializer()
    try:
        emb_size = 3
        learning_rate = 0.1
        margin_m = 0.5
        margin_s = 64
        momentum = 0.9
        image_size = 6
        model_parallel = sailfish.ModelParallel(gpu, gpus_per_worker)

        # hybrid parallelism
        torch.manual_seed(42)
        random.seed(42)
        fe = torch.nn.Linear(image_size, emb_size).cuda()
        fc = sailfish.ArcFaceLinear(
            emb_size,
            num_classes,
            margin=margin_m,
            scale=margin_s,
            weight_initializer=zeros_init,
            parallel=model_parallel).cuda()
        fc_params = list(fc.parameters())
        criterion = sailfish.CrossEntropyLoss(parallel=model_parallel).cuda()
        optimizer = torch.optim.SGD(
            fc.parameters(), lr=learning_rate, momentum=momentum)
        fc.train()
        criterion.train()

        for step in range(num_steps):
            baseline = results[step]
            torch.manual_seed(42 * step + gpu)
            random.seed(42 * step + gpu)
            data = torch.randn([batch_size, image_size]).cuda()
            features = fe(data)
            label = torch.as_tensor([
                random.randint(0, num_classes - 1) for _ in range(batch_size)
            ]).cuda()
            all_features = model_parallel.gather(features)
            all_label = model_parallel.gather_target(label)
            torch.manual_seed(42 * step)
            random.seed(42 * step)
            np.testing.assert_equal(
                list(all_features.size()),
                baseline['features/size'],
                err_msg='Wrong features size at gpu {} step {}'.format(
                    gpu, step))
            np.testing.assert_allclose(
                torch.norm(all_features).item(),
                baseline['features/norm'],
                rtol=1e-5,
                err_msg='Wrong features norm at gpu {} step {}'.format(
                    gpu, step))
            shard_logits = fc(all_features, all_label)
            loss = criterion(shard_logits, all_label)
            np.testing.assert_allclose(
                loss.item(),
                baseline['loss'],
                rtol=1e-5,
                err_msg='Wrong loss at gpu {} step {}'.format(gpu, step))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            fc_grad = model_parallel.gather(fc_params[0].grad)
            np.testing.assert_allclose(
                torch.norm(fc_grad).item(),
                baseline['logits/grad/norm'],
                rtol=1e-5,
                err_msg='Wrong logits grad at gpu {} step {}'.format(
                    gpu, step))

    finally:
        torch.distributed.destroy_process_group()


def baseline_main(gpus_per_worker, num_steps=1, batch_size=1, num_classes=8):
    r"""run on 1 GPU"""
    emb_size = 3
    learning_rate = 0.1
    momentum = 0.9
    image_size = 6

    zeros_init = sailfish.ZerosInitializer()

    # hybrid parallelism
    torch.manual_seed(42)
    random.seed(42)
    fe = torch.nn.Linear(image_size, emb_size).cuda()
    fc = sailfish.ArcFaceLinear(
        emb_size, num_classes, weight_initializer=zeros_init).cuda()
    fc_params = list(fc.parameters())
    criterion = torch.nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        fc.parameters(), lr=learning_rate, momentum=momentum)
    fc.train()
    criterion.train()

    results = []
    for step in range(num_steps):
        result_item = {}
        features_list = []
        label_list = []
        for gpu in range(gpus_per_worker):
            torch.manual_seed(42 * step + gpu)
            random.seed(42 * step + gpu)
            features_list.append(
                fe(torch.randn([batch_size, image_size]).cuda()))
            label_list.append(
                torch.as_tensor([
                    random.randint(0, num_classes - 1)
                    for _ in range(batch_size)
                ]).cuda())
        all_features = torch.cat(features_list)
        all_label = torch.cat(label_list)
        torch.manual_seed(42 * step)
        random.seed(42 * step)
        result_item['features/size'] = list(all_features.size())
        result_item['features/norm'] = torch.norm(all_features).item()
        logits = fc(all_features, all_label)
        loss = criterion(logits, all_label)
        result_item['loss'] = loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        result_item['logits/grad/norm'] = torch.norm(fc_params[0].grad).item()
        results.append(result_item)
    return results


class TestArcFaceLinear(unittest.TestCase):
    r"""Test sailfish.ArcFaceLinear."""

    def test_no_parallel(self):
        r"""Test sailfish.ArcFaceLinear without parallel."""
        in_features = 1
        out_features = 2
        margin_m = random.random()
        margin_s = random.random()

        features = torch.randn([1, in_features])
        label = torch.as_tensor([random.randint(0, out_features - 1)])

        torch.manual_seed(42)
        random.seed(42)
        baseline = sailfish.ArcFaceLinear(
            in_features, out_features, margin=margin_m, scale=margin_s)
        baseline_optimizer = torch.optim.SGD(baseline.parameters(), lr=1.)
        baseline.train()
        baseline_logits = baseline(features, label)
        baseline_loss = torch.sum(baseline_logits)
        baseline_optimizer.zero_grad()
        baseline_loss.backward()
        baseline_optimizer.step()

        torch.manual_seed(42)
        random.seed(42)
        fc = sailfish.ArcFaceLinear(
            in_features, out_features, margin=margin_m, scale=margin_s)
        optimizer = torch.optim.SGD(fc.parameters(), lr=1.)
        fc.train()
        logits = fc(features, label)
        loss = torch.sum(logits)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        np.testing.assert_allclose(
            logits.detach().numpy(),
            baseline_logits.detach().numpy(),
            err_msg='logits not equal to baseline')
        np.testing.assert_allclose(
            [p.detach().numpy() for p in baseline.parameters()],
            [p.detach().numpy() for p in fc.parameters()],
            err_msg='parameters not equal to baseline')

    def test_mp(self):
        r"""Test sailfish.ArcFaceLinear on 1 GPU."""
        in_features = 1
        out_features = 2

        features = torch.randn([1, in_features])
        label = torch.as_tensor([random.randint(0, out_features - 1)])

        torch.manual_seed(42)
        random.seed(42)
        baseline = sailfish.ArcFaceLinear(in_features, out_features)
        baseline_optimizer = torch.optim.SGD(baseline.parameters(), lr=1.)
        baseline.train()
        baseline_logits = baseline(features, label)
        baseline_loss = torch.sum(baseline_logits)
        baseline_optimizer.zero_grad()
        baseline_loss.backward()
        baseline_optimizer.step()

        torch.manual_seed(42)
        random.seed(42)
        model_parallel = sailfish.ModelParallel(0, 1)
        fc = sailfish.ArcFaceLinear(
            in_features, out_features, parallel=model_parallel)
        optimizer = torch.optim.SGD(fc.parameters(), lr=1.)
        fc.train()
        logits = fc(features, label)
        loss = torch.sum(logits)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        np.testing.assert_allclose(
            logits.detach().numpy(),
            baseline_logits.detach().numpy(),
            err_msg='logits not equal to baseline')
        np.testing.assert_allclose(
            [p.detach().numpy() for p in baseline.parameters()],
            [p.detach().numpy() for p in fc.parameters()],
            err_msg='parameters not equal to baseline')

    def cant_test_mp_vs_ddp(self):
        r"""Test sailfish.ArcFaceLinear with model parallel."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '24601'
        os.environ['WORLD_SIZE'] = '1'
        os.environ['RANK'] = '0'

        gpus_per_worker = torch.cuda.device_count()
        torch.multiprocessing.spawn(
            mp_vs_ddp_main,
            args=(gpus_per_worker, ),
            nprocs=gpus_per_worker,
            join=True)

    def test_mp_vs_1gpu(self):
        r"""Test sailfish.ArcFaceLinear with model parallel."""
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '24601'
        os.environ['WORLD_SIZE'] = '1'
        os.environ['RANK'] = '0'

        gpus_per_worker = torch.cuda.device_count()
        num_steps = 5
        batch_size = 1
        num_classes = gpus_per_worker
        results = baseline_main(gpus_per_worker, num_steps, batch_size,
                                num_classes)
        torch.multiprocessing.spawn(
            mp_main,
            args=(gpus_per_worker, results, num_steps, batch_size,
                  num_classes),
            nprocs=gpus_per_worker,
            join=True)


if __name__ == '__main__':
    unittest.main()
@ -0,0 +1,314 @@
|
|||
# Copyright 2019 Alibaba Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Linear module tests."""
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from easycv.core import sailfish
|
||||
|
||||
|
||||
class MockLinear(torch.nn.Module):
|
||||
r"""Applies a linear transformation to the incoming data.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features,
|
||||
out_features,
|
||||
bias=True,
|
||||
weight_initializer=None,
|
||||
bias_initializer=None,
|
||||
parallel=None):
|
||||
super(MockLinear, self).__init__()
|
||||
self.out_features = out_features
|
||||
self.in_features = in_features
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.Tensor(self.out_features, self.in_features))
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.Tensor(self.out_features))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.weight_initializer = weight_initializer
|
||||
if weight_initializer is None:
|
||||
self.weight_initializer = sailfish.KaimingUniformInitializer(
|
||||
math.sqrt(5))
|
||||
self.bias_initializer = bias_initializer
|
||||
if bias_initializer is None:
|
||||
self.bias_initializer = sailfish.BiasUniformInitializer(
|
||||
self.in_features)
|
||||
self.reset_parameters()
|
||||
self.parallel = parallel
|
||||
|
||||
def reset_parameters(self):
|
||||
r"""Reset parameters."""
|
||||
self.weight_initializer(self.weight)
|
||||
if self.bias is not None:
|
||||
self.bias_initializer(self.bias)
|
||||
|
||||
def forward(self, features): # pylint: disable=arguments-differ
|
||||
features = features.type(dtype=self.weight.dtype)
|
||||
return torch.nn.functional.linear(features, self.weight, self.bias)
|
||||
|
||||
|
||||
def _run_baseline_train_main(gpus_per_worker, num_steps, batch_size,
                             in_features, out_features, bias, lr):
    r"""Run baseline on 1 GPU."""
    torch.manual_seed(42)
    random.seed(42)
    fc = MockLinear(
        in_features,
        out_features,
        bias=bias,
        weight_initializer=sailfish.ZerosInitializer(),
        bias_initializer=sailfish.OnesInitializer()).cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(fc.parameters(), lr=lr)
    fc.train()
    criterion.train()

    results = []
    for step in range(num_steps):
        result = {}
        features_list = []
        label_list = []
        for gpu in range(gpus_per_worker):
            torch.manual_seed(42 * step + gpu)
            random.seed(42 * step + gpu)
            features_list.append(torch.randn([batch_size, in_features]).cuda())
            label_list.append(
                torch.as_tensor([
                    random.randint(0, out_features - 1)
                    for _ in range(batch_size)
                ]).cuda())
        features = torch.cat(features_list)
        label = torch.cat(label_list)

        torch.manual_seed(42 * step)
        random.seed(42 * step)
        logits = fc(features)
        loss = criterion(logits, label)
        result['loss'] = loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        result['grads/norm'] = [
            torch.norm(p.grad).item() for p in fc.parameters()
        ]
        results.append(result)
    return results


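# Model-parallel worker: one process per GPU joins an NCCL process group,
# builds a sharded sailfish.Linear plus sailfish.CrossEntropyLoss, and checks
# every step against the single-GPU baseline results.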
def _run_mp_train_main(gpu, gpus_per_worker, baseline_steps, num_steps,
                       batch_size, in_features, out_features, bias, lr):
    r"""Run MP and validate results."""
    torch.cuda.set_device(gpu)
    torch.distributed.init_process_group(
        'nccl', rank=gpu, world_size=gpus_per_worker)

    torch.manual_seed(42)
    random.seed(42)
    model_parallel = sailfish.ModelParallel(gpu, gpus_per_worker)
    fc = sailfish.Linear(
        in_features,
        out_features,
        bias=bias,
        weight_initializer=sailfish.ZerosInitializer(),
        bias_initializer=sailfish.OnesInitializer(),
        parallel=model_parallel).cuda()
    criterion = sailfish.CrossEntropyLoss(parallel=model_parallel).cuda()
    optimizer = torch.optim.SGD(fc.parameters(), lr=lr)
    fc.train()
    criterion.train()

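    # Each step draws the same per-rank minibatch as the baseline (seeds are
    # aligned), gathers features and labels across ranks to rebuild the global
    # batch, and asserts loss and gathered gradient norms match the baseline.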
    for step in range(num_steps):
        torch.manual_seed(42 * step + gpu)
        random.seed(42 * step + gpu)
        features = torch.randn([batch_size, in_features]).cuda()
        features = model_parallel.gather(features)
        label = torch.as_tensor([
            random.randint(0, out_features - 1) for _ in range(batch_size)
        ]).cuda()
        label = model_parallel.gather_target(label)

        torch.manual_seed(42 * step)
        random.seed(42 * step)
        logits = fc(features)
        loss = criterion(logits, label)
        np.testing.assert_allclose(
            loss.item(),
            baseline_steps[step]['loss'],
            rtol=1e-5,
            err_msg='Wrong loss at gpu {} step {}'.format(gpu, step))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        grad_norms = [
            torch.norm(model_parallel.gather(p.grad)).item()
            for p in fc.parameters()
        ]
        np.testing.assert_allclose(
            grad_norms,
            baseline_steps[step]['grads/norm'],
            rtol=1e-5,
            err_msg='Wrong grads norm at gpu {} step {}'.format(gpu, step))


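# TestLinear covers three configurations: sailfish.Linear without a parallel
# context, with a single-process ModelParallel context, and sharded across
# all visible GPUs via torch.multiprocessing.spawn.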
class TestLinear(unittest.TestCase):
    r"""Test sailfish.Linear."""

    def _run_baseline_train(self, batch_size, in_features, out_features, bias,
                            lr):
        r"""Run baseline without parallel."""
        result = {}
        features = torch.randn([batch_size, in_features])
        fc = torch.nn.Linear(in_features, out_features, bias=bias)
        optimizer = torch.optim.SGD(fc.parameters(), lr=lr)
        fc.train()
        logits = fc(features)
        loss = torch.sum(logits)
        result['loss'] = loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        result['grads/norm'] = [
            torch.norm(p.grad).item() for p in fc.parameters()
        ]
        return result

    def _run_mp_no_parallel_train(self, batch_size, in_features, out_features,
                                  bias, lr):
        r"""Run MP without parallel."""
        result = {}
        features = torch.randn([batch_size, in_features])
        fc = sailfish.Linear(in_features, out_features, bias=bias)
        optimizer = torch.optim.SGD(fc.parameters(), lr=lr)
        fc.train()
        logits = fc(features)
        loss = torch.sum(logits)
        result['loss'] = loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        result['grads/norm'] = [
            torch.norm(p.grad).item() for p in fc.parameters()
        ]
        return result

    def _run_mp_1gpu_train(self, batch_size, in_features, out_features, bias,
                           lr):
        r"""Run MP on 1 GPU."""
        result = {}
        features = torch.randn([batch_size, in_features])
        model_parallel = sailfish.ModelParallel(0, 1)
        fc = sailfish.Linear(
            in_features, out_features, bias=bias, parallel=model_parallel)
        optimizer = torch.optim.SGD(fc.parameters(), lr=lr)
        fc.train()
        logits = fc(features)
        loss = torch.sum(logits)
        result['loss'] = loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        result['grads/norm'] = [
            torch.norm(p.grad).item() for p in fc.parameters()
        ]
        return result

    def test_no_parallel(self):
        r"""Test sailfish.Linear without parallel."""
        batch_size = 3
        in_features = 4
        out_features = 5
        bias = False
        lr = 0.1

        for step in range(5):
            torch.manual_seed(42 + step)
            random.seed(42 + step)
            baseline = self._run_baseline_train(batch_size, in_features,
                                                out_features, bias, lr)

            torch.manual_seed(42 + step)
            random.seed(42 + step)
            rc = self._run_mp_no_parallel_train(batch_size, in_features,
                                                out_features, bias, lr)

            np.testing.assert_allclose(
                rc['loss'], baseline['loss'], err_msg='loss not equal')
            np.testing.assert_allclose(
                rc['grads/norm'],
                baseline['grads/norm'],
                err_msg='norm of grads not equal')

    def test_mp(self):
        r"""Test sailfish.Linear with model parallel on 1 GPU."""
        batch_size = 2
        in_features = 7
        out_features = 4
        bias = True
        lr = 0.6

        for step in range(5):
            torch.manual_seed(100 + step)
            random.seed(100 + step)
            baseline = self._run_baseline_train(batch_size, in_features,
                                                out_features, bias, lr)

            torch.manual_seed(100 + step)
            random.seed(100 + step)
            rc = self._run_mp_1gpu_train(batch_size, in_features,
                                         out_features, bias, lr)

            np.testing.assert_allclose(
                rc['loss'], baseline['loss'], err_msg='loss not equal')
            np.testing.assert_allclose(
                rc['grads/norm'],
                baseline['grads/norm'],
                err_msg='norm of grads not equal')

    def test_mp_vs_1gpu(self):
        r"""Test sailfish.Linear with model parallel across GPUs."""
        gpus_per_worker = torch.cuda.device_count()
        num_steps = 5
        batch_size = 2
        in_features = 3
        out_features = gpus_per_worker
        bias = True
        lr = 0.6

        baseline_steps = _run_baseline_train_main(gpus_per_worker, num_steps,
                                                  batch_size, in_features,
                                                  out_features, bias, lr)

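        # Single-node rendezvous for init_process_group; one process per GPU
        # runs _run_mp_train_main and validates against baseline_steps.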
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '24601'
        os.environ['WORLD_SIZE'] = '1'
        os.environ['RANK'] = '0'
        torch.multiprocessing.spawn(
            _run_mp_train_main,
            args=(gpus_per_worker, baseline_steps, num_steps, batch_size,
                  in_features, out_features, bias, lr),
            nprocs=gpus_per_worker,
            join=True)


if __name__ == '__main__':
    unittest.main()