performance test with SyncBN
parent
87466446e3
commit
a5884eaa48
data
layers/sync_batchnorm
solver
tools
@@ -0,0 +1,66 @@
MODEL:
  PRETRAIN_PATH: '/home/zbc/.torch/models/resnet50-19c8e357.pth'
  METRIC_LOSS_TYPE: 'triplet_center'
  IF_LABELSMOOTH: 'on'
  IF_WITH_CENTER: 'yes'




INPUT:
  SIZE_TRAIN: [256, 128]
  SIZE_TEST: [256, 128]
  PROB: 0.5 # random horizontal flip
  RE_PROB: 0.5 # random erasing
  PADDING: 10

DATASETS:
  NAMES: ('market1501')

DATALOADER:
  SAMPLER: 'softmax_triplet'
  NUM_INSTANCE: 4
  NUM_WORKERS: 8

SOLVER:
  OPTIMIZER_NAME: 'Adam'
  MAX_EPOCHS: 120
  BASE_LR: 0.000175

  CLUSTER_MARGIN: 0.3

  CENTER_LR: 0.5
  CENTER_LOSS_WEIGHT: 0.0005

  RANGE_K: 2
  RANGE_MARGIN: 0.3
  RANGE_ALPHA: 0
  RANGE_BETA: 1
  RANGE_LOSS_WEIGHT: 1

  BIAS_LR_FACTOR: 1
  WEIGHT_DECAY: 0.0005
  WEIGHT_DECAY_BIAS: 0.0005
  IMS_PER_BATCH: 32

  STEPS: [40, 70]
  GAMMA: 0.1

  WARMUP_FACTOR: 0.01
  WARMUP_ITERS: 10
  WARMUP_METHOD: 'linear'

  CHECKPOINT_PERIOD: 40
  LOG_PERIOD: 20
  EVAL_PERIOD: 40

TEST:
  IMS_PER_BATCH: 128
  RE_RANKING: 'no'
  WEIGHT: "path"
  NECK_FEAT: 'after'
  FEAT_NORM: 'yes'

OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005"
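For reference, a config like the one above is consumed by merging it into the project's yacs-based default config node, matching the `cfg.merge_from_file` / `cfg.merge_from_list` calls in the train.py hunks further down. A minimal sketch, assuming the repo exposes the node as `config.cfg` and that the YAML path shown is illustrative:

from config import cfg  # assumed location of the yacs CfgNode

cfg.merge_from_file('configs/softmax_triplet_with_center.yml')   # hypothetical path to the YAML above
cfg.merge_from_list(['SOLVER.IMS_PER_BATCH', '64'])              # command-line style override
cfg.freeze()
print(cfg.SOLVER.BASE_LR, cfg.DATALOADER.SAMPLER)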
@@ -12,7 +12,7 @@ from .samplers import RandomIdentitySampler, RandomIdentitySampler_alignedreid
 from .transforms import build_transforms


-def make_data_loader(cfg):
+def make_data_loader(cfg, num_gpus=1):
     train_transforms = build_transforms(cfg, is_train=True)
     val_transforms = build_transforms(cfg, is_train=False)
     num_workers = cfg.DATALOADER.NUM_WORKERS

@@ -26,13 +26,15 @@ def make_data_loader(cfg):
     train_set = ImageDataset(dataset.train, train_transforms)
     if cfg.DATALOADER.SAMPLER == 'softmax':
         train_loader = DataLoader(
-            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
+            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH * num_gpus, shuffle=True, num_workers=num_workers,
             collate_fn=train_collate_fn
         )
     else:
         train_loader = DataLoader(
-            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
-            sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
+            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH * num_gpus,
+            sampler=RandomIdentitySampler(dataset.train,
+                                          cfg.SOLVER.IMS_PER_BATCH * num_gpus,
+                                          cfg.DATALOADER.NUM_INSTANCE),
             # sampler=RandomIdentitySampler_alignedreid(dataset.train, cfg.DATALOADER.NUM_INSTANCE),      # new add by gu
             num_workers=num_workers, collate_fn=train_collate_fn
         )
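The effect of the new `num_gpus` argument, roughly: the DataLoader now yields the global batch, so each replica under `nn.DataParallel` still receives `IMS_PER_BATCH` images after scattering. A hedged arithmetic sketch using the config values above (variable names are illustrative, not from the repo):

ims_per_batch = 32          # SOLVER.IMS_PER_BATCH, per-GPU batch size
num_instance = 4            # DATALOADER.NUM_INSTANCE, images per identity drawn by RandomIdentitySampler
for num_gpus in (1, 2, 4):
    global_batch = ims_per_batch * num_gpus        # what the DataLoader yields per step
    ids_per_batch = global_batch // num_instance   # identities per global batch
    per_gpu_batch = global_batch // num_gpus       # what each replica sees after scatter
    print(num_gpus, global_batch, ids_per_batch, per_gpu_batch)
# e.g. 4 GPUs -> 128 images per step, 32 identities, 32 images per replica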
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
# File : __init__.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .batchnorm import convert_model
from .replicate import DataParallelWithCallback, patch_replication_callback
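As a quick orientation before the implementation files, the package is typically used by converting an existing model after wrapping it in `nn.DataParallel`, in line with the `convert_model` docstring below. A minimal sketch, assuming a 2-GPU machine and the `layers.sync_batchnorm` import path introduced in this commit:

import torch
import torch.nn as nn
import torchvision

from layers.sync_batchnorm import convert_model

model = torchvision.models.resnet50(pretrained=False)
model = nn.DataParallel(model, device_ids=[0, 1])
# Replaces every BatchNorm*d with SynchronizedBatchNorm*d and rewraps
# the DataParallel container in DataParallelWithCallback.
model = convert_model(model).cuda()
out = model(torch.randn(64, 3, 256, 128).cuda())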
@@ -0,0 +1,361 @@
# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from .comm import SyncMaster
from .replicate import DataParallelWithCallback

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d',
           'SynchronizedBatchNorm3d', 'convert_model']


def _sum_ft(tensor):
    """sum over the first and last dimension"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using the same "device order" makes the ReduceAdd operation faster.
        # Thanks to: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)


def convert_model(module):
    """Traverse the input module and its children recursively
    and replace all instances of torch.nn.modules.batchnorm.BatchNorm*N*d
    with SynchronizedBatchNorm*N*d.

    Args:
        module: the input module to be converted to a SyncBN model

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod)
        return mod

    mod = module
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                        torch.nn.modules.batchnorm.BatchNorm2d,
                                        torch.nn.modules.batchnorm.BatchNorm3d],
                                       [SynchronizedBatchNorm1d,
                                        SynchronizedBatchNorm2d,
                                        SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod
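The synchronization above only ever ships per-device (sum, ssum, count) triples; the master recombines them exactly as if the whole batch had lived on one device. A small CPU-only sanity sketch of that identity (not part of the diff, shapes chosen arbitrarily):

import torch

x = torch.randn(4, 8, 16, 16)          # pretend global batch, C = 8
chunks = x.chunk(2, dim=0)             # pretend two devices

# per-"device" partial statistics, as in _ChildMessage(sum, ssum, sum_size)
sums = [c.sum(dim=(0, 2, 3)) for c in chunks]
ssums = [(c ** 2).sum(dim=(0, 2, 3)) for c in chunks]
sizes = [c.numel() // c.size(1) for c in chunks]

# master-side recombination, as in _compute_mean_std
size = sum(sizes)
mean = sum(sums) / size
sumvar = sum(ssums) - sum(sums) * mean
bias_var = sumvar / size

assert torch.allclose(mean, x.mean(dim=(0, 2, 3)), atol=1e-5)
assert torch.allclose(bias_var, x.var(dim=(0, 2, 3), unbiased=False), atol=1e-5)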
@@ -0,0 +1,74 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : batchnorm_reimpl.py
# Author : acgtyrant
# Date : 11/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import torch
import torch.nn as nn
import torch.nn.init as init

__all__ = ['BatchNorm2dReimpl']


class BatchNorm2dReimpl(nn.Module):
    """
    A re-implementation of batch normalization, used for testing the numerical
    stability.

    Author: acgtyrant
    See also:
    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_running_stats(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input_):
        batchsize, channels, height, width = input_.size()
        numel = batchsize * height * width
        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
        sum_ = input_.sum(1)
        sum_of_square = input_.pow(2).sum(1)
        mean = sum_ / numel
        sumvar = sum_of_square - sum_ * mean

        self.running_mean = (
            (1 - self.momentum) * self.running_mean
            + self.momentum * mean.detach()
        )
        unbias_var = sumvar / (numel - 1)
        self.running_var = (
            (1 - self.momentum) * self.running_var
            + self.momentum * unbias_var.detach()
        )

        bias_var = sumvar / numel
        inv_std = 1 / (bias_var + self.eps).pow(0.5)
        output = (
            (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
            self.weight.unsqueeze(1) + self.bias.unsqueeze(1))

        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
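A sketch of how this re-implementation is meant to be exercised, comparing it against torch.nn.BatchNorm2d on random data in training mode (the import path, shapes, and tolerance are assumptions, not from this commit):

import torch
import torch.nn as nn

from layers.sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl

torch.manual_seed(0)
x = torch.randn(16, 8, 14, 14)

ref = nn.BatchNorm2d(8, eps=1e-5, momentum=0.1)
re = BatchNorm2dReimpl(8, eps=1e-5, momentum=0.1)

# copy the affine parameters so both layers start identical
re.weight.data.copy_(ref.weight.data)
re.bias.data.copy_(ref.bias.data)

print(torch.allclose(ref(x), re(x), atol=1e-5))                       # expected: True
print(torch.allclose(ref.running_mean, re.running_mean, atol=1e-5))   # running stats also agree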
@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as data parallel triggers a callback on each module, all slave devices should
      call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
      collected and passed to a registered callback.
    - After receiving the messages, the master device gathers the information and determines the message to be
      passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
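To make the master/slave handshake concrete, here is a toy, threads-only sketch of the protocol: one slave pipe sends a number, the callback replies with the global sum to every copy. The callback and the reduction are made up for illustration; only SyncMaster's behaviour is taken from the file above.

import threading

from layers.sync_batchnorm.comm import SyncMaster

def master_callback(intermediates):
    # intermediates is a list of (copy_id, message); reply with the global sum to every copy
    total = sum(msg for _, msg in intermediates)
    return [(copy_id, total) for copy_id, _ in intermediates]

master = SyncMaster(master_callback)
pipe = master.register_slave(1)          # one slave copy, id 1

replies = {}

def slave():
    replies[1] = pipe.run_slave(10)      # blocks until the master has run the callback

t = threading.Thread(target=slave)
t.start()
replies[0] = master.run_master(5)        # collects the slave message, runs the callback
t.join()

print(replies)                           # both entries are 15, the reduced value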
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# File : replicate.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Note that, as all copies are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest
import torch


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, x, y):
        adiff = float((x - y).abs().max())
        if (y == 0).all():
            rdiff = 'NaN'
        else:
            rdiff = float((adiff / y).abs().max())

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y), message)
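A tiny illustration of how this helper might be used from a test module; the test case itself is made up for illustration and is not part of this commit:

import unittest
import torch

from layers.sync_batchnorm.unittest import TorchTestCase

class TestToyIdentity(TorchTestCase):
    def test_identity(self):
        x = torch.randn(4, 3)
        self.assertTensorClose(x, x.clone())

if __name__ == '__main__':
    unittest.main()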
@@ -13,6 +13,9 @@ def make_optimizer(cfg, model, num_gpus=1):
        if not value.requires_grad:
            continue
        # linear scaling rule, https://arxiv.org/abs/1706.02677
        # lr and batch_size will both be multiplied by the number of GPUs
        # if you want to test the same batch size on multiple GPUs and on one GPU,
        # you will need to manually adjust the learning rate and batch size in the config files
        lr = cfg.SOLVER.BASE_LR * num_gpus
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in key:
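Worked through with the config above, the linear scaling rule is just the config values multiplied out; a quick sketch for reference:

BASE_LR = 0.000175      # SOLVER.BASE_LR from the config
IMS_PER_BATCH = 32      # per-GPU batch size

for num_gpus in (1, 2, 4):
    print(num_gpus, BASE_LR * num_gpus, IMS_PER_BATCH * num_gpus)
# 1 GPU  -> lr 0.000175, global batch  32
# 2 GPUs -> lr 0.00035,  global batch  64
# 4 GPUs -> lr 0.0007,   global batch 128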
@@ -21,16 +21,15 @@ from solver import make_optimizer, make_optimizer_with_center, WarmupMultiStepLR

 from utils.logger import setup_logger

-sys.path.append('..')
-from sync_batchnorm import convert_model
+from layers.sync_batchnorm import convert_model


 def train(cfg, num_gpus):
     # prepare dataset
-    train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
+    train_loader, val_loader, num_query, num_classes = make_data_loader(cfg, num_gpus)

     # prepare model
-    model = build_model(cfg, num_classes, num_gpus)
+    model = build_model(cfg, num_classes)
     if num_gpus > 1:
         model = torch.nn.DataParallel(model)
         # if using multiple gpus, convert the model to use SyncBN

@@ -91,8 +90,6 @@ def main():

     args = parser.parse_args()

-    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
-
     if args.config_file != "":
         cfg.merge_from_file(args.config_file)
         cfg.merge_from_list(args.opts)

@@ -102,6 +99,10 @@ def main():
     if output_dir and not os.path.exists(output_dir):
         os.makedirs(output_dir)

+    if cfg.MODEL.DEVICE == "cuda":
+        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID    # new add by gu
+        num_gpus = torch.cuda.device_count()
+
     logger = setup_logger("reid_baseline", output_dir, 0)
     logger.info("Using {} GPUS".format(num_gpus))
     logger.info(args)

@@ -113,9 +114,6 @@ def main():
     logger.info(config_str)
     logger.info("Running with config:\n{}".format(cfg))

-    if cfg.MODEL.DEVICE == "cuda":
-        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID    # new add by gu
-        num_gpus = torch.cuda.device_count()
     cudnn.benchmark = True
     train(cfg, num_gpus)

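For orientation only: with these changes the GPU count comes from cfg.MODEL.DEVICE_ID via CUDA_VISIBLE_DEVICES rather than from WORLD_SIZE. A hedged sketch of that flow, where the device string is an illustrative value and the convert_model call is presumed to follow the DataParallel wrap shown in the truncated hunk above:

import os
import torch

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'   # illustrative stand-in for cfg.MODEL.DEVICE_ID
num_gpus = torch.cuda.device_count()             # 4 when four GPUs are visible
# train(cfg, num_gpus) then scales the loader to IMS_PER_BATCH * num_gpus,
# wraps the model in torch.nn.DataParallel when num_gpus > 1,
# and converts its BatchNorm layers to SyncBN via convert_model.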