Fix RMSProp: apply one_dim_param_no_weight_decay even when no_weight_decay_name_list is empty
parent
c351dac67e
commit
b66ee6384b
|
@ -232,16 +232,17 @@ class RMSProp(object):
|
||||||
def __call__(self, model_list):
|
def __call__(self, model_list):
|
||||||
# model_list is None in static graph
|
# model_list is None in static graph
|
||||||
parameters = None
|
parameters = None
|
||||||
if len(self.no_weight_decay_name_list) > 0:
|
if model_list:
|
||||||
params_with_decay = []
|
params_with_decay = []
|
||||||
params_without_decay = []
|
params_without_decay = []
|
||||||
for m in model_list:
|
for m in model_list:
|
||||||
params = [p for n, p in m.named_parameters() \
|
for n, p in m.named_parameters():
|
||||||
if not any(nd in n for nd in self.no_weight_decay_name_list)]
|
if any(nd in n for nd in self.no_weight_decay_name_list) \
|
||||||
params_with_decay.extend(params)
|
or (self.one_dim_param_no_weight_decay and len(p.shape) == 1):
|
||||||
params = [p for n, p in m.named_parameters() \
|
params_without_decay.append(p)
|
||||||
if any(nd in n for nd in self.no_weight_decay_name_list) or (self.one_dim_param_no_weight_decay and len(p.shape) == 1)]
|
else:
|
||||||
params_without_decay.extend(params)
|
params_with_decay.append(p)
|
||||||
|
if params_without_decay:
|
||||||
parameters = [{
|
parameters = [{
|
||||||
"params": params_with_decay,
|
"params": params_with_decay,
|
||||||
"weight_decay": self.weight_decay
|
"weight_decay": self.weight_decay
|
||||||
|
@ -250,8 +251,7 @@ class RMSProp(object):
|
||||||
"weight_decay": 0.0
|
"weight_decay": 0.0
|
||||||
}]
|
}]
|
||||||
else:
|
else:
|
||||||
parameters = sum([m.parameters() for m in model_list],
|
parameters = params_with_decay
|
||||||
[]) if model_list else None
|
|
||||||
opt = optim.RMSProp(
|
opt = optim.RMSProp(
|
||||||
learning_rate=self.learning_rate,
|
learning_rate=self.learning_rate,
|
||||||
momentum=self.momentum,
|
momentum=self.momentum,
|
||||||
|
|
Loading…
Reference in New Issue