mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Remove layer-decay print
This commit is contained in:
parent
e069249a2d
commit
33e30f8c8b
@ -1,7 +1,7 @@
|
|||||||
""" Optimizer Factory w/ Custom Weight Decay
|
""" Optimizer Factory w/ Custom Weight Decay
|
||||||
Hacked together by / Copyright 2021 Ross Wightman
|
Hacked together by / Copyright 2021 Ross Wightman
|
||||||
"""
|
"""
|
||||||
import json
|
import logging
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from typing import Optional, Callable, Tuple
|
from typing import Optional, Callable, Tuple
|
||||||
|
|
||||||
@ -31,6 +31,8 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
has_apex = False
|
has_apex = False
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def param_groups_weight_decay(
|
def param_groups_weight_decay(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -92,6 +94,7 @@ def param_groups_layer_decay(
|
|||||||
no_weight_decay_list: Tuple[str] = (),
|
no_weight_decay_list: Tuple[str] = (),
|
||||||
layer_decay: float = .75,
|
layer_decay: float = .75,
|
||||||
end_layer_decay: Optional[float] = None,
|
end_layer_decay: Optional[float] = None,
|
||||||
|
verbose: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Parameter groups for layer-wise lr decay & weight decay
|
Parameter groups for layer-wise lr decay & weight decay
|
||||||
@ -142,8 +145,9 @@ def param_groups_layer_decay(
|
|||||||
param_group_names[group_name]["param_names"].append(name)
|
param_group_names[group_name]["param_names"].append(name)
|
||||||
param_groups[group_name]["params"].append(param)
|
param_groups[group_name]["params"].append(param)
|
||||||
|
|
||||||
# FIXME temporary output to debug new feature
|
if verbose:
|
||||||
print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
|
import json
|
||||||
|
_logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
|
||||||
|
|
||||||
return list(param_groups.values())
|
return list(param_groups.values())
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user