Freeze/unfreeze functionality finalized. Tests added
parent 0cb8ea432c
commit 65c3d78b96
@@ -0,0 +1,60 @@
+from torch.nn.modules.batchnorm import BatchNorm2d
+from torchvision.ops.misc import FrozenBatchNorm2d
+
+import timm
+from timm.utils.model import freeze, unfreeze
+
+
+def test_freeze_unfreeze():
+    model = timm.create_model('resnet18')
+
+    # Freeze all
+    freeze(model)
+    # Check top level module
+    assert model.fc.weight.requires_grad == False
+    # Check submodule
+    assert model.layer1[0].conv1.weight.requires_grad == False
+    # Check BN
+    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
+
+    # Unfreeze all
+    unfreeze(model)
+    # Check top level module
+    assert model.fc.weight.requires_grad == True
+    # Check submodule
+    assert model.layer1[0].conv1.weight.requires_grad == True
+    # Check BN
+    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
+
+    # Freeze some
+    freeze(model, ['layer1', 'layer2.0'])
+    # Check frozen
+    assert model.layer1[0].conv1.weight.requires_grad == False
+    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
+    assert model.layer2[0].conv1.weight.requires_grad == False
+    # Check not frozen
+    assert model.layer3[0].conv1.weight.requires_grad == True
+    assert isinstance(model.layer3[0].bn1, BatchNorm2d)
+    assert model.layer2[1].conv1.weight.requires_grad == True
+
+    # Unfreeze some
+    unfreeze(model, ['layer1', 'layer2.0'])
+    # Check not frozen
+    assert model.layer1[0].conv1.weight.requires_grad == True
+    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
+    assert model.layer2[0].conv1.weight.requires_grad == True
+
+    # Freeze BN
+    # From root
+    freeze(model, ['layer1.0.bn1'])
+    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
+    # From direct parent
+    freeze(model.layer1[0], ['bn1'])
+    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
+
+    # Unfreeze BN
+    unfreeze(model, ['layer1.0.bn1'])
+    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
+    # From direct parent
+    unfreeze(model.layer1[0], ['bn1'])
+    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
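For context on how the tested API is meant to be used end to end, here is a minimal fine-tuning sketch against the signatures introduced in this commit (the `num_classes` kwarg is standard timm; everything else comes straight from the diff):

import timm
from timm.utils.model import freeze, unfreeze

# Typical setup: freeze the whole backbone, then re-enable just the head
model = timm.create_model('resnet18', num_classes=10)
freeze(model)
unfreeze(model, ['fc'])

assert not model.conv1.weight.requires_grad  # backbone frozen (BN layers are now FrozenBatchNorm2d)
assert model.fc.weight.requires_grad         # classifier head trainable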
@@ -3,7 +3,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchvision


 class GroupNorm(nn.GroupNorm):
@@ -23,42 +22,3 @@ class LayerNorm2d(nn.LayerNorm):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return F.layer_norm(
             x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
-
-
-class FrozenBatchNorm2d(torchvision.ops.misc.FrozenBatchNorm2d):
-    """
-    BatchNorm2d where the batch statistics and the affine parameters are fixed.
-
-    Inherits from torchvision while adding the `convert_frozen_batchnorm` from
-    https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py
-    """
-
-    @classmethod
-    def convert_frozen_batchnorm(cls, module):
-        """
-        Converts all BatchNorm layers of provided module into FrozenBatchNorm. If `module` is a type of BatchNorm, it
-        converts it into FrozenBatchNorm. Otherwise, the module is walked recursively and BatchNorm type layers are
-        converted in place.
-
-        Args:
-            module (torch.nn.Module): Any PyTorch module. It doesn't have to be a BatchNorm variant in itself.
-
-        Returns:
-            torch.nn.Module: Resulting module
-        """
-        res = module
-        if isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
-            res = cls(module.num_features)
-            if module.affine:
-                res.weight.data = module.weight.data.clone().detach()
-                res.bias.data = module.bias.data.clone().detach()
-            res.running_mean.data = module.running_mean.data
-            res.running_var.data = module.running_var.data
-            res.eps = module.eps
-        else:
-            for name, child in module.named_children():
-                new_child = cls.convert_frozen_batchnorm(child)
-                if new_child is not child:
-                    res.add_module(name, new_child)
-        return res
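The `FrozenBatchNorm2d` subclass above is removed in favor of using `torchvision.ops.misc.FrozenBatchNorm2d` directly, with the conversion logic reworked as `freeze_batch_norm_2d`/`unfreeze_batch_norm_2d` in `timm.utils.model` (next hunk). The behavioral point of a frozen layer is that its statistics are plain buffers that forward never updates; a small sketch, assuming only stock `torch`/`torchvision` behavior:

import torch
from torch import nn
from torchvision.ops.misc import FrozenBatchNorm2d

x = torch.randn(4, 8, 16, 16)

bn = nn.BatchNorm2d(8).train()
before = bn.running_mean.clone()
bn(x)
assert not torch.equal(bn.running_mean, before)  # train-mode forward updates running stats

frozen = FrozenBatchNorm2d(8)
before = frozen.running_mean.clone()
frozen(x)
assert torch.equal(frozen.running_mean, before)  # frozen stats never change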
@@ -4,16 +4,12 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
-from logging import root
-from typing import Sequence
 import re
 import warnings
 
 import torch
 import fnmatch
 
-from torch.nn.modules import module
+from torchvision.ops.misc import FrozenBatchNorm2d
 
 from .model_ema import ModelEma
-from timm.models.layers.norm import FrozenBatchNorm2d
 
 
 def unwrap_model(model):

@@ -99,55 +95,172 @@ def extract_spp_stats(model,
     hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
     _ = model(x)
     return hook.stats
 
 
-def freeze(modules, root_module=None, include_bn_running_stats=True, mode=True):
+def freeze_batch_norm_2d(module):
+    """
+    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
+    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
+    returned. Otherwise, the module is walked recursively and submodules are converted in place.
+
+    Args:
+        module (torch.nn.Module): Any PyTorch module.
+
+    Returns:
+        torch.nn.Module: Resulting module
+    """
+    res = module
+    if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
+        res = FrozenBatchNorm2d(module.num_features)
+        res.num_features = module.num_features
+        res.affine = module.affine
+        if module.affine:
+            res.weight.data = module.weight.data.clone().detach()
+            res.bias.data = module.bias.data.clone().detach()
+        res.running_mean.data = module.running_mean.data
+        res.running_var.data = module.running_var.data
+        res.eps = module.eps
+    else:
+        for name, child in module.named_children():
+            new_child = freeze_batch_norm_2d(child)
+            if new_child is not child:
+                res.add_module(name, new_child)
+    return res
+
+
+def unfreeze_batch_norm_2d(module):
+    """
+    Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself an instance
+    of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
+    recursively and submodules are converted in place.
+
+    Args:
+        module (torch.nn.Module): Any PyTorch module.
+
+    Returns:
+        torch.nn.Module: Resulting module
+    """
+    res = module
+    if isinstance(module, FrozenBatchNorm2d):
+        res = torch.nn.BatchNorm2d(module.num_features)
+        if module.affine:
+            res.weight.data = module.weight.data.clone().detach()
+            res.bias.data = module.bias.data.clone().detach()
+        res.running_mean.data = module.running_mean.data
+        res.running_var.data = module.running_var.data
+        res.eps = module.eps
+    else:
+        for name, child in module.named_children():
+            new_child = unfreeze_batch_norm_2d(child)
+            if new_child is not child:
+                res.add_module(name, new_child)
+    return res
+
+
+def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'):
     """
     Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
     done in place.
     Args:
-        modules (nn.Module or list[nn.Module] or str or list[str]): List of modules for which the parameters will be
-            (un)frozen. If a string or strings are provided these will be interpreted according to the named modules
-            of the provided ``root_module``.
-        root_module (nn.Module, optional): Root module relative to which named modules (accessible via
-            ``root_module.named_modules()``) are referenced. Must be provided if the `modules` argument is specified
-            with a string or strings. Defaults to `None`.
-        include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm layers.
+        root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced.
+        submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided
+            as named modules relative to the root module (accessible via `root_module.named_modules()`). An empty
+            list means that the whole root module will be (un)frozen. Defaults to `[]`.
+        include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm 2d layers.
             Defaults to `True`.
-        mode (bool): Whether to freeze (`True`) or unfreeze (`False`). Defaults to `True`.
-
-    TODO before finalizing PR: Implement unfreezing of batch norm
+        mode (str): Whether to freeze ("freeze") or unfreeze ("unfreeze"). Defaults to `"freeze"`.
     """
-    if not isinstance(modules, Sequence):
-        modules = [modules]
+    assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"'
 
-    if isinstance(modules[0], str):
-        assert root_module is not None, \
-            "When providing strings for the `modules` argument, a `root_module` must be provided"
-        module_names = modules
-        modules = [root_module.get_submodule(m) for m in module_names]
+    if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
+        # Raise assertion here because we can't convert it in place
+        raise AssertionError(
+            "You have provided a batch norm layer as the `root_module`. Please use "
+            "`timm.utils.model.freeze_batch_norm_2d` or `timm.utils.model.unfreeze_batch_norm_2d` instead.")
 
-    for n, m in zip(module_names, modules):
+    if isinstance(submodules, str):
+        submodules = [submodules]
+
+    named_modules = submodules
+    submodules = [root_module.get_submodule(m) for m in submodules]
+
+    if not len(submodules):
+        named_modules, submodules = list(zip(*root_module.named_children()))
+
+    for n, m in zip(named_modules, submodules):
         # (Un)freeze parameters
         for p in m.parameters():
-            p.requires_grad = (not mode)
+            p.requires_grad = False if mode == 'freeze' else True
         if include_bn_running_stats:
-            res = FrozenBatchNorm2d.convert_frozen_batchnorm(m)
-            # It's possible that `m` is a type of BatchNorm in itself, in which case
-            # `FrozenBatchNorm2d.convert_frozen_batchnorm` won't convert it in place, but will return the converted
-            # result. In this case `res` holds the converted result and we may try to re-assign the named module
-            if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
-                if module_names is not None and root_module is not None:
-                    root_module.add_module(n, res)
-                else:
-                    raise RuntimeError(
-                        "Could not freeze batch norm statistics due to a technical limitation. Hint: Try calling "
-                        "`freeze` with a list of module names while providing a `root_module` argument.")
+            # Helper to add submodule specified as a named module
+            def _add_submodule(module, name, submodule):
+                split = name.rsplit('.', 1)
+                if len(split) > 1:
+                    module.get_submodule(split[0]).add_module(split[1], submodule)
+                else:
+                    module.add_module(name, submodule)
+            # Freeze batch norm
+            if mode == 'freeze':
+                res = freeze_batch_norm_2d(m)
+                # It's possible that `m` is a type of BatchNorm in itself, in which case `freeze_batch_norm_2d` won't
+                # convert it in place, but will return the converted result. In this case `res` holds the converted
+                # result and we may try to re-assign the named module
+                if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
+                    _add_submodule(root_module, n, res)
+            # Unfreeze batch norm
+            else:
+                res = unfreeze_batch_norm_2d(m)
+                # Ditto. See note above in the mode == 'freeze' branch
+                if isinstance(m, FrozenBatchNorm2d):
+                    _add_submodule(root_module, n, res)
 
 
-def unfreeze(modules, root_module=None, include_bn_running_stats=True):
+def freeze(root_module, submodules=[], include_bn_running_stats=True):
     """
-    Idiomatic convenience function to call `freeze` with `mode == False`. See docstring of `freeze` for further
-    information.
+    Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.
+    Args:
+        root_module (nn.Module): Root module relative to which `submodules` are referenced.
+        submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as
+            named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list
+            means that the whole root module will be frozen. Defaults to `[]`.
+        include_bn_running_stats (bool): Whether to also freeze the running statistics of `BatchNorm2d` and
+            `SyncBatchNorm` layers. These will be converted to `FrozenBatchNorm2d` in place. Hint: During fine tuning,
+            it's good practice to freeze batch norm stats. And note that these are different to the affine parameters
+            which are just normal PyTorch parameters. Defaults to `True`.
+
+    Hint: If you want to freeze batch norm ONLY, use `timm.utils.model.freeze_batch_norm_2d`.
+
+    Examples::
+
+        >>> model = timm.create_model('resnet18')
+        >>> # Freeze up to and including layer2
+        >>> submodules = [n for n, _ in model.named_children()]
+        >>> print(submodules)
+        ['conv1', 'bn1', 'act1', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'global_pool', 'fc']
+        >>> freeze(model, submodules[:submodules.index('layer2') + 1])
+        >>> # Check for yourself that it works as expected
+        >>> print(model.layer2[0].conv1.weight.requires_grad)
+        False
+        >>> print(model.layer3[0].conv1.weight.requires_grad)
+        True
+        >>> # Unfreeze
+        >>> unfreeze(model)
     """
-    freeze(modules, root_module=root_module, include_bn_running_stats=include_bn_running_stats, mode=False)
+    _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="freeze")
+
+
+def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
+    """
+    Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in
+    place.
+    Args:
+        root_module (nn.Module): Root module relative to which `submodules` are referenced.
+        submodules (list[str]): List of submodules for which the parameters will be unfrozen. They are to be provided
+            as named modules relative to the root module (accessible via `root_module.named_modules()`). An empty
+            list means that the whole root module will be unfrozen. Defaults to `[]`.
+        include_bn_running_stats (bool): Whether to also unfreeze the running statistics of `FrozenBatchNorm2d`
+            layers. These will be converted to `BatchNorm2d` in place. Defaults to `True`.
+
+    See example in docstring for `freeze`.
+    """
+    _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze")
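Finally, the two conversion helpers introduced above also work standalone and round-trip, preserving the learned statistics. A minimal sketch on a toy `nn.Sequential` (the toy model and the `0.5` fill are illustrative only, not from the commit; assumes this commit's `timm.utils.model` is installed):

import torch
from torch import nn
from torchvision.ops.misc import FrozenBatchNorm2d
from timm.utils.model import freeze_batch_norm_2d, unfreeze_batch_norm_2d

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net[1].running_mean.fill_(0.5)  # stand-in for stats learned during pretraining

net = freeze_batch_norm_2d(net)    # submodule converted in place: BatchNorm2d -> FrozenBatchNorm2d
assert isinstance(net[1], FrozenBatchNorm2d)
assert torch.all(net[1].running_mean == 0.5)  # stats carried over

net = unfreeze_batch_norm_2d(net)  # and back: FrozenBatchNorm2d -> BatchNorm2d
assert isinstance(net[1], nn.BatchNorm2d)
assert torch.all(net[1].running_mean == 0.5)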