mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Small post-merge tweak for freeze/unfreeze, add to __init__ for utils
This commit is contained in:
parent
5ca72dcc75
commit
e5da481073
@ -7,7 +7,7 @@ from .jit import set_jit_legacy
|
|||||||
from .log import setup_default_logging, FormatterNoInfo
|
from .log import setup_default_logging, FormatterNoInfo
|
||||||
from .metrics import AverageMeter, accuracy
|
from .metrics import AverageMeter, accuracy
|
||||||
from .misc import natural_key, add_bool_arg
|
from .misc import natural_key, add_bool_arg
|
||||||
from .model import unwrap_model, get_state_dict
|
from .model import unwrap_model, get_state_dict, freeze, unfreeze
|
||||||
from .model_ema import ModelEma, ModelEmaV2
|
from .model_ema import ModelEma, ModelEmaV2
|
||||||
from .random import random_seed
|
from .random import random_seed
|
||||||
from .summary import update_summary, get_outdir
|
from .summary import update_summary, get_outdir
|
||||||
|
@ -194,7 +194,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
|
|||||||
for n, m in zip(named_modules, submodules):
|
for n, m in zip(named_modules, submodules):
|
||||||
# (Un)freeze parameters
|
# (Un)freeze parameters
|
||||||
for p in m.parameters():
|
for p in m.parameters():
|
||||||
p.requires_grad = (False if mode == 'freeze' else True)
|
p.requires_grad = False if mode == 'freeze' else True
|
||||||
if include_bn_running_stats:
|
if include_bn_running_stats:
|
||||||
# Helper to add submodule specified as a named_module
|
# Helper to add submodule specified as a named_module
|
||||||
def _add_submodule(module, name, submodule):
|
def _add_submodule(module, name, submodule):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user