Merge pull request #1317 from rwightman/fixes-syncbn_pretrain_cfg_resolve
Fix SyncBatchNorm for BatchNormAc2d, improve resolve_pretrained_cfg behaviour, other mix fixes.pull/1322/head
commit
beef62e7ab
|
@ -61,7 +61,7 @@ from .xcit import *
|
|||
from .factory import create_model, parse_model_name, safe_model_name
|
||||
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
|
||||
from .layers import TestTimePoolHead, apply_test_time_pool
|
||||
from .layers import convert_splitbn_model
|
||||
from .layers import convert_splitbn_model, convert_sync_batchnorm
|
||||
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
|
||||
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
|
||||
is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value
|
||||
|
|
|
@ -455,18 +455,26 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
|
|||
filter_kwargs(kwargs, names=kwargs_filter)
|
||||
|
||||
|
||||
def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None):
|
||||
def resolve_pretrained_cfg(variant: str, pretrained_cfg=None):
|
||||
if pretrained_cfg and isinstance(pretrained_cfg, dict):
|
||||
# highest priority, pretrained_cfg available and passed explicitly
|
||||
# highest priority, pretrained_cfg available and passed as arg
|
||||
return deepcopy(pretrained_cfg)
|
||||
if kwargs and 'pretrained_cfg' in kwargs:
|
||||
# next highest, pretrained_cfg in a kwargs dict, pop and return
|
||||
pretrained_cfg = kwargs.pop('pretrained_cfg', {})
|
||||
if pretrained_cfg:
|
||||
return deepcopy(pretrained_cfg)
|
||||
# lookup pretrained cfg in model registry by variant
|
||||
# fallback to looking up pretrained cfg in model registry by variant identifier
|
||||
pretrained_cfg = get_pretrained_cfg(variant)
|
||||
assert pretrained_cfg
|
||||
if not pretrained_cfg:
|
||||
_logger.warning(
|
||||
f"No pretrained configuration specified for {variant} model. Using a default."
|
||||
f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
|
||||
pretrained_cfg = dict(
|
||||
url='',
|
||||
num_classes=1000,
|
||||
input_size=(3, 224, 224),
|
||||
pool_size=None,
|
||||
crop_pct=.9,
|
||||
interpolation='bicubic',
|
||||
first_conv='',
|
||||
classifier='',
|
||||
)
|
||||
return pretrained_cfg
|
||||
|
||||
|
||||
|
|
|
@ -428,7 +428,7 @@ class InceptionV3Aux(InceptionV3):
|
|||
|
||||
|
||||
def _create_inception_v3(variant, pretrained=False, **kwargs):
|
||||
pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
|
||||
pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
|
||||
aux_logits = kwargs.pop('aux_logits', False)
|
||||
if aux_logits:
|
||||
assert not kwargs.pop('features_only', False)
|
||||
|
|
|
@ -26,7 +26,7 @@ from .mixed_conv2d import MixedConv2d
|
|||
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
|
||||
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
|
||||
from .norm import GroupNorm, LayerNorm2d
|
||||
from .norm_act import BatchNormAct2d, GroupNormAct
|
||||
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
|
||||
from .padding import get_padding, get_same_padding, pad_same
|
||||
from .patch_embed import PatchEmbed
|
||||
from .pool2d_same import AvgPool2dSame, create_pool2d
|
||||
|
|
|
@ -164,3 +164,6 @@ class DropPath(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
||||
|
||||
def extra_repr(self):
|
||||
return f'drop_prob={round(self.drop_prob,3):0.3f}'
|
||||
|
|
|
@ -256,8 +256,9 @@ class EvoNorm2dS0a(EvoNorm2dS0):
|
|||
class EvoNorm2dS1(nn.Module):
|
||||
def __init__(
|
||||
self, num_features, groups=32, group_size=None,
|
||||
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
|
||||
apply_act=True, act_layer=None, eps=1e-5, **_):
|
||||
super().__init__()
|
||||
act_layer = act_layer or nn.SiLU
|
||||
self.apply_act = apply_act # apply activation (non-linearity)
|
||||
if act_layer is not None and apply_act:
|
||||
self.act = create_act_layer(act_layer)
|
||||
|
@ -290,7 +291,7 @@ class EvoNorm2dS1(nn.Module):
|
|||
class EvoNorm2dS1a(EvoNorm2dS1):
|
||||
def __init__(
|
||||
self, num_features, groups=32, group_size=None,
|
||||
apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
|
||||
apply_act=True, act_layer=None, eps=1e-3, **_):
|
||||
super().__init__(
|
||||
num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
|
||||
|
||||
|
@ -305,8 +306,9 @@ class EvoNorm2dS1a(EvoNorm2dS1):
|
|||
class EvoNorm2dS2(nn.Module):
|
||||
def __init__(
|
||||
self, num_features, groups=32, group_size=None,
|
||||
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
|
||||
apply_act=True, act_layer=None, eps=1e-5, **_):
|
||||
super().__init__()
|
||||
act_layer = act_layer or nn.SiLU
|
||||
self.apply_act = apply_act # apply activation (non-linearity)
|
||||
if act_layer is not None and apply_act:
|
||||
self.act = create_act_layer(act_layer)
|
||||
|
@ -338,7 +340,7 @@ class EvoNorm2dS2(nn.Module):
|
|||
class EvoNorm2dS2a(EvoNorm2dS2):
|
||||
def __init__(
|
||||
self, num_features, groups=32, group_size=None,
|
||||
apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
|
||||
apply_act=True, act_layer=None, eps=1e-3, **_):
|
||||
super().__init__(
|
||||
num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
""" Normalization + Activation Layers
|
||||
"""
|
||||
from typing import Union, List
|
||||
from typing import Union, List, Optional, Any
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
|
@ -18,10 +18,29 @@ class BatchNormAct2d(nn.BatchNorm2d):
|
|||
instead of composing it as a .bn member.
|
||||
"""
|
||||
def __init__(
|
||||
self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
|
||||
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
|
||||
super(BatchNormAct2d, self).__init__(
|
||||
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
|
||||
self,
|
||||
num_features,
|
||||
eps=1e-5,
|
||||
momentum=0.1,
|
||||
affine=True,
|
||||
track_running_stats=True,
|
||||
apply_act=True,
|
||||
act_layer=nn.ReLU,
|
||||
inplace=True,
|
||||
drop_layer=None,
|
||||
device=None,
|
||||
dtype=None
|
||||
):
|
||||
try:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super(BatchNormAct2d, self).__init__(
|
||||
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
|
||||
**factory_kwargs
|
||||
)
|
||||
except TypeError:
|
||||
# NOTE for backwards compat with old PyTorch w/o factory device/dtype support
|
||||
super(BatchNormAct2d, self).__init__(
|
||||
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
|
||||
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
|
||||
act_layer = get_act_layer(act_layer) # string -> nn.Module
|
||||
if act_layer is not None and apply_act:
|
||||
|
@ -81,6 +100,62 @@ class BatchNormAct2d(nn.BatchNorm2d):
|
|||
return x
|
||||
|
||||
|
||||
class SyncBatchNormAct(nn.SyncBatchNorm):
|
||||
# Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
|
||||
# This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
|
||||
# but ONLY when used in conjunction with the timm conversion function below.
|
||||
# Do not create this module directly or use the PyTorch conversion function.
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = super().forward(x) # SyncBN doesn't work with torchscript anyways, so this is fine
|
||||
if hasattr(self, "drop"):
|
||||
x = self.drop(x)
|
||||
if hasattr(self, "act"):
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
|
||||
def convert_sync_batchnorm(module, process_group=None):
|
||||
# convert both BatchNorm and BatchNormAct layers to Synchronized variants
|
||||
module_output = module
|
||||
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
|
||||
if isinstance(module, BatchNormAct2d):
|
||||
# convert timm norm + act layer
|
||||
module_output = SyncBatchNormAct(
|
||||
module.num_features,
|
||||
module.eps,
|
||||
module.momentum,
|
||||
module.affine,
|
||||
module.track_running_stats,
|
||||
process_group=process_group,
|
||||
)
|
||||
# set act and drop attr from the original module
|
||||
module_output.act = module.act
|
||||
module_output.drop = module.drop
|
||||
else:
|
||||
# convert standard BatchNorm layers
|
||||
module_output = torch.nn.SyncBatchNorm(
|
||||
module.num_features,
|
||||
module.eps,
|
||||
module.momentum,
|
||||
module.affine,
|
||||
module.track_running_stats,
|
||||
process_group,
|
||||
)
|
||||
if module.affine:
|
||||
with torch.no_grad():
|
||||
module_output.weight = module.weight
|
||||
module_output.bias = module.bias
|
||||
module_output.running_mean = module.running_mean
|
||||
module_output.running_var = module.running_var
|
||||
module_output.num_batches_tracked = module.num_batches_tracked
|
||||
if hasattr(module, "qconfig"):
|
||||
module_output.qconfig = module.qconfig
|
||||
for name, child in module.named_children():
|
||||
module_output.add_module(name, convert_sync_batchnorm(child, process_group))
|
||||
del module
|
||||
return module_output
|
||||
|
||||
|
||||
def _num_groups(num_channels, num_groups, group_size):
|
||||
if group_size:
|
||||
assert num_channels % group_size == 0
|
||||
|
|
|
@ -633,7 +633,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
|
|||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
|
||||
pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
|
||||
pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
|
||||
model = build_model_with_cfg(
|
||||
VisionTransformer, variant, pretrained,
|
||||
pretrained_cfg=pretrained_cfg,
|
||||
|
|
|
@ -16,7 +16,7 @@ import torch.nn.functional as F
|
|||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply
|
||||
from .helpers import build_model_with_cfg, named_apply
|
||||
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
|
||||
from .registry import register_model
|
||||
|
||||
|
|
20
train.py
20
train.py
|
@ -15,10 +15,9 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
|
|||
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
|
||||
"""
|
||||
import argparse
|
||||
import time
|
||||
import yaml
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
|
@ -26,14 +25,15 @@ from datetime import datetime
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.utils
|
||||
import yaml
|
||||
from torch.nn.parallel import DistributedDataParallel as NativeDDP
|
||||
|
||||
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
|
||||
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
|
||||
convert_splitbn_model, model_parameters
|
||||
from timm import utils
|
||||
from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\
|
||||
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
|
||||
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
|
||||
LabelSmoothingCrossEntropy
|
||||
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
|
||||
convert_splitbn_model, convert_sync_batchnorm, model_parameters
|
||||
from timm.optim import create_optimizer_v2, optimizer_kwargs
|
||||
from timm.scheduler import create_scheduler
|
||||
from timm.utils import ApexScaler, NativeScaler
|
||||
|
@ -438,12 +438,14 @@ def main():
|
|||
|
||||
# setup synchronized BatchNorm for distributed training
|
||||
if args.distributed and args.sync_bn:
|
||||
args.dist_bn = '' # disable dist_bn when sync BN active
|
||||
assert not args.split_bn
|
||||
if has_apex and use_amp == 'apex':
|
||||
# Apex SyncBN preferred unless native amp is activated
|
||||
# Apex SyncBN used with Apex AMP
|
||||
# WARNING this won't currently work with models using BatchNormAct2d
|
||||
model = convert_syncbn_model(model)
|
||||
else:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
model = convert_sync_batchnorm(model)
|
||||
if args.local_rank == 0:
|
||||
_logger.info(
|
||||
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
|
||||
|
|
Loading…
Reference in New Issue