Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00

Update batchnorm freezing to handle NormAct variants, add GroupNorm1Act, and update BatchNormAct2d to match the num_batches_tracked tracing change from upstream PyTorch

parent a2c14c2064
commit e520553e3e
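Before the diff, a quick usage sketch of what the freezing change enables (a minimal example, assuming a timm build that includes this commit): BatchNormAct2d layers are now converted to the new FrozenBatchNormAct2d rather than being skipped, and the conversion is reversible. The toy model below is illustrative only.

import torch
from timm.layers import BatchNormAct2d, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d

# Toy model; names and shapes are illustrative only.
net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), BatchNormAct2d(8))

frozen = freeze_batch_norm_2d(net)          # converts children in place and returns the module
assert isinstance(frozen[1], FrozenBatchNormAct2d)
frozen(torch.randn(2, 3, 8, 8))             # stats and affine params are fixed; act/drop are preserved

restored = unfreeze_batch_norm_2d(frozen)   # reverse conversion back to BatchNormAct2d
assert isinstance(restored[1], BatchNormAct2d)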
timm/layers/__init__.py

@@ -29,7 +29,8 @@ from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
-from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
+from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
+    SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_embed import PatchEmbed, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
timm/layers/norm_act.py

@@ -17,6 +17,7 @@ from typing import Union, List, Optional, Any
 import torch
 from torch import nn as nn
 from torch.nn import functional as F
+from torchvision.ops.misc import FrozenBatchNorm2d
 
 from .create_act import get_act_layer
 from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
@@ -77,7 +78,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
         if self.training and self.track_running_stats:
             # TODO: if statement only here to tell the jit to skip emitting this when it is None
             if self.num_batches_tracked is not None:  # type: ignore[has-type]
-                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
+                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                 if self.momentum is None:  # use cumulative moving average
                     exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                 else:  # use exponential moving average
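The num_batches_tracked change above mirrors the upstream PyTorch edit referenced in the commit title: incrementing the buffer in place keeps the same tensor object instead of rebinding the attribute to a freshly allocated result on every step, which is what the tracing-related change upstream was about. A small standalone illustration with a stock BatchNorm2d:

import torch

bn = torch.nn.BatchNorm2d(8)
buf = bn.num_batches_tracked

bn.num_batches_tracked.add_(1)                       # in-place: same tensor object, value updated
assert bn.num_batches_tracked is buf and buf.item() == 1

bn.num_batches_tracked = bn.num_batches_tracked + 1  # out-of-place: rebinds the buffer to a new tensor
assert bn.num_batches_tracked is not buf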
@@ -169,6 +170,159 @@ def convert_sync_batchnorm(module, process_group=None):
     return module_output
 
 
+class FrozenBatchNormAct2d(torch.nn.Module):
+    """
+    BatchNormAct2d where the batch statistics and the affine parameters are fixed
+
+    Args:
+        num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
+        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
+    """
+
+    def __init__(
+            self,
+            num_features: int,
+            eps: float = 1e-5,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+    ):
+        super().__init__()
+        self.eps = eps
+        self.register_buffer("weight", torch.ones(num_features))
+        self.register_buffer("bias", torch.zeros(num_features))
+        self.register_buffer("running_mean", torch.zeros(num_features))
+        self.register_buffer("running_var", torch.ones(num_features))
+
+        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
+        act_layer = get_act_layer(act_layer)  # string -> nn.Module
+        if act_layer is not None and apply_act:
+            act_args = dict(inplace=True) if inplace else {}
+            self.act = act_layer(**act_args)
+        else:
+            self.act = nn.Identity()
+
+    def _load_from_state_dict(
+            self,
+            state_dict: dict,
+            prefix: str,
+            local_metadata: dict,
+            strict: bool,
+            missing_keys: List[str],
+            unexpected_keys: List[str],
+            error_msgs: List[str],
+    ):
+        num_batches_tracked_key = prefix + "num_batches_tracked"
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # move reshapes to the beginning
+        # to make it fuser-friendly
+        w = self.weight.reshape(1, -1, 1, 1)
+        b = self.bias.reshape(1, -1, 1, 1)
+        rv = self.running_var.reshape(1, -1, 1, 1)
+        rm = self.running_mean.reshape(1, -1, 1, 1)
+        scale = w * (rv + self.eps).rsqrt()
+        bias = b - rm * scale
+        x = x * scale + bias
+        x = self.act(self.drop(x))
+        return x
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps}, act={self.act})"
+
+
+def freeze_batch_norm_2d(module):
+    """
+    Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers
+    of provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively.
+
+    Args:
+        module (torch.nn.Module): Any PyTorch module.
+
+    Returns:
+        torch.nn.Module: Resulting module
+
+    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
+    """
+    res = module
+    if isinstance(module, (BatchNormAct2d, SyncBatchNormAct)):
+        res = FrozenBatchNormAct2d(module.num_features)
+        res.num_features = module.num_features
+        res.affine = module.affine
+        if module.affine:
+            res.weight.data = module.weight.data.clone().detach()
+            res.bias.data = module.bias.data.clone().detach()
+        res.running_mean.data = module.running_mean.data
+        res.running_var.data = module.running_var.data
+        res.eps = module.eps
+        res.drop = module.drop
+        res.act = module.act
+    elif isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
+        res = FrozenBatchNorm2d(module.num_features)
+        res.num_features = module.num_features
+        res.affine = module.affine
+        if module.affine:
+            res.weight.data = module.weight.data.clone().detach()
+            res.bias.data = module.bias.data.clone().detach()
+        res.running_mean.data = module.running_mean.data
+        res.running_var.data = module.running_var.data
+        res.eps = module.eps
+    else:
+        for name, child in module.named_children():
+            new_child = freeze_batch_norm_2d(child)
+            if new_child is not child:
+                res.add_module(name, new_child)
+    return res
+
+
+def unfreeze_batch_norm_2d(module):
+    """
+    Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance
+    of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
+    recursively and submodules are converted in place.
+
+    Args:
+        module (torch.nn.Module): Any PyTorch module.
+
+    Returns:
+        torch.nn.Module: Resulting module
+
+    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
+    """
+    res = module
+    if isinstance(module, FrozenBatchNormAct2d):
+        res = BatchNormAct2d(module.num_features)
+        if module.affine:
+            res.weight.data = module.weight.data.clone().detach()
+            res.bias.data = module.bias.data.clone().detach()
+        res.running_mean.data = module.running_mean.data
+        res.running_var.data = module.running_var.data
+        res.eps = module.eps
+        res.drop = module.drop
+        res.act = module.act
+    elif isinstance(module, FrozenBatchNorm2d):
+        res = torch.nn.BatchNorm2d(module.num_features)
+        if module.affine:
+            res.weight.data = module.weight.data.clone().detach()
+            res.bias.data = module.bias.data.clone().detach()
+        res.running_mean.data = module.running_mean.data
+        res.running_var.data = module.running_var.data
+        res.eps = module.eps
+    else:
+        for name, child in module.named_children():
+            new_child = unfreeze_batch_norm_2d(child)
+            if new_child is not child:
+                res.add_module(name, new_child)
+    return res
+
+
 def _num_groups(num_channels, num_groups, group_size):
     if group_size:
         assert num_channels % group_size == 0
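The scale/bias fold in FrozenBatchNormAct2d.forward above is the standard inference-time batchnorm identity: scale = weight / sqrt(running_var + eps) and shift = bias - running_mean * scale. A small self-contained check against F.batch_norm in eval mode (plain PyTorch, no timm required; the tensor shapes below are arbitrary):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)
weight, bias = torch.rand(8) + 0.5, torch.randn(8)
running_mean, running_var = torch.randn(8), torch.rand(8) + 0.5
eps = 1e-5

# Fold BN into an elementwise affine op, as in FrozenBatchNormAct2d.forward.
scale = (weight * (running_var + eps).rsqrt()).reshape(1, -1, 1, 1)
shift = bias.reshape(1, -1, 1, 1) - running_mean.reshape(1, -1, 1, 1) * scale
y_folded = x * scale + shift

# Reference: batch norm in eval mode (uses the running stats).
y_ref = F.batch_norm(x, running_mean, running_var, weight, bias, training=False, eps=eps)
torch.testing.assert_close(y_folded, y_ref)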
@@ -179,10 +333,54 @@ def _num_groups(num_channels, num_groups, group_size):
 class GroupNormAct(nn.GroupNorm):
     # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
     def __init__(
-            self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
+            self,
+            num_channels,
+            num_groups=32,
+            eps=1e-5,
+            affine=True,
+            group_size=None,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+    ):
         super(GroupNormAct, self).__init__(
-            _num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine)
+            _num_groups(num_channels, num_groups, group_size),
+            num_channels,
+            eps=eps,
+            affine=affine,
+        )
+        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
+        act_layer = get_act_layer(act_layer)  # string -> nn.Module
+        if act_layer is not None and apply_act:
+            act_args = dict(inplace=True) if inplace else {}
+            self.act = act_layer(**act_args)
+        else:
+            self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()
+
+    def forward(self, x):
+        if self._fast_norm:
+            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        else:
+            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        x = self.drop(x)
+        x = self.act(x)
+        return x
+
+
+class GroupNorm1Act(nn.GroupNorm):
+    def __init__(
+            self,
+            num_channels,
+            eps=1e-5,
+            affine=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+    ):
+        super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
         act_layer = get_act_layer(act_layer)  # string -> nn.Module
         if act_layer is not None and apply_act:
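GroupNorm1Act, added above, is GroupNorm pinned to a single group with the usual norm + act fusion, so each sample is normalized over all of (C, H, W) before the activation. A hedged usage sketch, assuming GroupNorm1Act is exported from timm.layers as the __init__.py hunk earlier indicates:

import torch
from timm.layers import GroupNorm1Act

norm_act = GroupNorm1Act(64, act_layer='silu')   # act_layer may be a string, resolved via get_act_layer
x = torch.randn(2, 64, 7, 7)
y = norm_act(x)

# Same normalization as a plain single-group GroupNorm, with the activation applied afterwards.
ref = torch.nn.functional.silu(torch.nn.GroupNorm(1, 64)(x))
torch.testing.assert_close(y, ref)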
@@ -204,8 +402,15 @@ class GroupNormAct(nn.GroupNorm):
 
 class LayerNormAct(nn.LayerNorm):
     def __init__(
-            self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
+            self,
+            normalization_shape: Union[int, List[int], torch.Size],
+            eps=1e-5,
+            affine=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+    ):
         super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
         act_layer = get_act_layer(act_layer)  # string -> nn.Module
@@ -228,8 +433,15 @@ class LayerNormAct(nn.LayerNorm):
 
 class LayerNormAct2d(nn.LayerNorm):
     def __init__(
-            self, num_channels, eps=1e-5, affine=True,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
+            self,
+            num_channels,
+            eps=1e-5,
+            affine=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+    ):
         super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
         act_layer = get_act_layer(act_layer)  # string -> nn.Module
timm/utils/model.py

@@ -7,6 +7,8 @@ import fnmatch
 import torch
 from torchvision.ops.misc import FrozenBatchNorm2d
 
+from timm.layers import BatchNormAct2d, SyncBatchNormAct, FrozenBatchNormAct2d,\
+    freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .model_ema import ModelEma
 
 
@@ -100,70 +102,6 @@ def extract_spp_stats(
     return hook.stats
 
 
-def freeze_batch_norm_2d(module):
-    """
-    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
-    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
-    returned. Otherwise, the module is walked recursively and submodules are converted in place.
-
-    Args:
-        module (torch.nn.Module): Any PyTorch module.
-
-    Returns:
-        torch.nn.Module: Resulting module
-
-    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
-    """
-    res = module
-    if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
-        res = FrozenBatchNorm2d(module.num_features)
-        res.num_features = module.num_features
-        res.affine = module.affine
-        if module.affine:
-            res.weight.data = module.weight.data.clone().detach()
-            res.bias.data = module.bias.data.clone().detach()
-        res.running_mean.data = module.running_mean.data
-        res.running_var.data = module.running_var.data
-        res.eps = module.eps
-    else:
-        for name, child in module.named_children():
-            new_child = freeze_batch_norm_2d(child)
-            if new_child is not child:
-                res.add_module(name, new_child)
-    return res
-
-
-def unfreeze_batch_norm_2d(module):
-    """
-    Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance
-    of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
-    recursively and submodules are converted in place.
-
-    Args:
-        module (torch.nn.Module): Any PyTorch module.
-
-    Returns:
-        torch.nn.Module: Resulting module
-
-    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
-    """
-    res = module
-    if isinstance(module, FrozenBatchNorm2d):
-        res = torch.nn.BatchNorm2d(module.num_features)
-        if module.affine:
-            res.weight.data = module.weight.data.clone().detach()
-            res.bias.data = module.bias.data.clone().detach()
-        res.running_mean.data = module.running_mean.data
-        res.running_var.data = module.running_var.data
-        res.eps = module.eps
-    else:
-        for name, child in module.named_children():
-            new_child = unfreeze_batch_norm_2d(child)
-            if new_child is not child:
-                res.add_module(name, new_child)
-    return res
-
-
 def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'):
     """
     Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
@@ -179,7 +117,12 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
     """
    assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"'
 
-    if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
+    if isinstance(root_module, (
+            torch.nn.modules.batchnorm.BatchNorm2d,
+            torch.nn.modules.batchnorm.SyncBatchNorm,
+            BatchNormAct2d,
+            SyncBatchNormAct,
+    )):
         # Raise assertion here because we can't convert it in place
         raise AssertionError(
             "You have provided a batch norm layer as the `root module`. Please use "
@@ -213,13 +156,18 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
                 # It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't
                 # convert it in place, but will return the converted result. In this case `res` holds the converted
                 # result and we may try to re-assign the named module
-                if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
+                if isinstance(m, (
+                        torch.nn.modules.batchnorm.BatchNorm2d,
+                        torch.nn.modules.batchnorm.SyncBatchNorm,
+                        BatchNormAct2d,
+                        SyncBatchNormAct,
+                )):
                     _add_submodule(root_module, n, res)
             # Unfreeze batch norm
             else:
                 res = unfreeze_batch_norm_2d(m)
                 # Ditto. See note above in mode == 'freeze' branch
-                if isinstance(m, FrozenBatchNorm2d):
+                if isinstance(m, (FrozenBatchNorm2d, FrozenBatchNormAct2d)):
                     _add_submodule(root_module, n, res)
 
 
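Taken together, the changes to _freeze_unfreeze above mean that a frozen submodule which is itself a BatchNormAct2d (or SyncBatchNormAct) is now re-assigned to its frozen counterpart rather than left untouched, and the unfreeze path recognizes FrozenBatchNormAct2d. A hedged sketch, assuming the public timm.utils.freeze / timm.utils.unfreeze helpers continue to wrap _freeze_unfreeze:

import torch
from timm.layers import BatchNormAct2d, FrozenBatchNormAct2d
from timm.utils import freeze, unfreeze   # assumed public wrappers around _freeze_unfreeze

net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), BatchNormAct2d(8))

freeze(net)   # freezes parameters and, by default, batchnorm running stats
assert not any(p.requires_grad for p in net.parameters())
assert isinstance(net[1], FrozenBatchNormAct2d)   # previously the NormAct layer was left as-is

unfreeze(net)
assert all(p.requires_grad for p in net.parameters())
assert isinstance(net[1], BatchNormAct2d)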