mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Support BatchNormAct2d for sync-bn use. Fix #1254
This commit is contained in:
parent 7cedc8d474
commit 879df47c0a
timm/models/__init__.py
@@ -61,7 +61,7 @@ from .xcit import *
 from .factory import create_model, parse_model_name, safe_model_name
 from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
-from .layers import convert_splitbn_model
+from .layers import convert_splitbn_model, convert_sync_batchnorm
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
 from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
     is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value
timm/models/layers/__init__.py
@@ -26,7 +26,7 @@ from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, LayerNorm2d
-from .norm_act import BatchNormAct2d, GroupNormAct
+from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_embed import PatchEmbed
 from .pool2d_same import AvgPool2dSame, create_pool2d
timm/models/layers/norm_act.py
@@ -1,10 +1,15 @@
 """ Normalization + Activation Layers
 """
-from typing import Union, List
+from typing import Union, List, Optional, Any

 import torch
 from torch import nn as nn
 from torch.nn import functional as F
+try:
+    from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm
+    FULL_SYNC_BN = True
+except ImportError:
+    FULL_SYNC_BN = False

 from .trace_utils import _assert
 from .create_act import get_act_layer
@@ -18,10 +23,29 @@ class BatchNormAct2d(nn.BatchNorm2d):
     instead of composing it as a .bn member.
     """
     def __init__(
-            self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
-        super(BatchNormAct2d, self).__init__(
-            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+            self,
+            num_features,
+            eps=1e-5,
+            momentum=0.1,
+            affine=True,
+            track_running_stats=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+            device=None,
+            dtype=None
+    ):
+        try:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
+                **factory_kwargs
+            )
+        except TypeError:
+            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
         act_layer = get_act_layer(act_layer)  # string -> nn.Module
         if act_layer is not None and apply_act:
@@ -81,6 +105,62 @@ class BatchNormAct2d(nn.BatchNorm2d):
         return x


+class SyncBatchNormAct(nn.SyncBatchNorm):
+    # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
+    # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
+    # but ONLY when used in conjunction with the timm conversion function below.
+    # Do not create this module directly or use the PyTorch conversion function.
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = super().forward(x)  # SyncBN doesn't work with torchscript anyways, so this is fine
+        if hasattr(self, "drop"):
+            x = self.drop(x)
+        if hasattr(self, "act"):
+            x = self.act(x)
+        return x
+
+
+def convert_sync_batchnorm(module, process_group=None):
+    # convert both BatchNorm and BatchNormAct layers to Synchronized variants
+    module_output = module
+    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+        if isinstance(module, BatchNormAct2d):
+            # convert timm norm + act layer
+            module_output = SyncBatchNormAct(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group=process_group,
+            )
+            # set act and drop attr from the original module
+            module_output.act = module.act
+            module_output.drop = module.drop
+        else:
+            # convert standard BatchNorm layers
+            module_output = torch.nn.SyncBatchNorm(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group,
+            )
+        if module.affine:
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        if hasattr(module, "qconfig"):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, convert_sync_batchnorm(child, process_group))
+    del module
+    return module_output
+
+
 def _num_groups(num_channels, num_groups, group_size):
     if group_size:
         assert num_channels % group_size == 0
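For orientation, a minimal usage sketch of the new helper (not part of the diff; the model name, backend, and torchrun launch are illustrative assumptions): convert_sync_batchnorm walks the module tree, turning BatchNormAct2d into SyncBatchNormAct and plain BatchNorm layers into torch.nn.SyncBatchNorm, so it can be applied to any timm model before wrapping it in DistributedDataParallel.

import torch
import torch.distributed as dist
from timm.models import create_model, convert_sync_batchnorm

# assumes the script is launched with torchrun so rank/world-size env vars are set
dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# 'resnet50' is just a placeholder; models built with BatchNormAct2d get
# SyncBatchNormAct, plain BatchNorm2d layers become torch.nn.SyncBatchNorm
model = create_model('resnet50', pretrained=False).cuda()
model = convert_sync_batchnorm(model)

model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[torch.cuda.current_device()])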
train.py (19 changed lines)
@@ -15,10 +15,9 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
-import time
-import yaml
-import os
 import logging
+import os
+import time
 from collections import OrderedDict
 from contextlib import suppress
 from datetime import datetime
@@ -26,14 +25,15 @@ from datetime import datetime
 import torch
 import torch.nn as nn
 import torchvision.utils
+import yaml
 from torch.nn.parallel import DistributedDataParallel as NativeDDP

-from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
-    convert_splitbn_model, model_parameters
-from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy,\
+from timm import utils
+from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
     LabelSmoothingCrossEntropy
+from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
+    convert_splitbn_model, convert_sync_batchnorm, model_parameters
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler
@@ -440,10 +440,11 @@ def main():
     if args.distributed and args.sync_bn:
         assert not args.split_bn
         if has_apex and use_amp == 'apex':
-            # Apex SyncBN preferred unless native amp is activated
+            # Apex SyncBN used with Apex AMP
+            # WARNING this won't currently work with models using BatchNormAct2d
             model = convert_syncbn_model(model)
         else:
-            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            model = convert_sync_batchnorm(model)
         if args.local_rank == 0:
             _logger.info(
                 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
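To make the WARNING above concrete, a small sketch that is not part of the commit (the toy Sequential is purely illustrative): the stock torch.nn.SyncBatchNorm converter treats BatchNormAct2d as an ordinary BatchNorm and swaps in a plain SyncBatchNorm, silently dropping the fused activation, whereas the timm convert_sync_batchnorm introduced here produces SyncBatchNormAct and carries the .act/.drop submodules over.

import torch
from timm.models.layers import BatchNormAct2d, convert_sync_batchnorm

# a tiny module using timm's fused norm + act layer (default act_layer is ReLU)
m = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), BatchNormAct2d(8))

# PyTorch's converter: BatchNormAct2d -> plain SyncBatchNorm, the fused ReLU is lost
pt = torch.nn.SyncBatchNorm.convert_sync_batchnorm(m)
print(type(pt[1]).__name__)       # SyncBatchNorm

# timm's converter (this commit): BatchNormAct2d -> SyncBatchNormAct, act/drop preserved
tm = convert_sync_batchnorm(m)
print(type(tm[1]).__name__)       # SyncBatchNormAct
print(type(tm[1].act).__name__)   # ReLU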