Fix RMSProp handling of one_dim_param_no_weight_decay
parent
c351dac67e
commit
b66ee6384b
|
@ -232,26 +232,26 @@ class RMSProp(object):
|
|||
def __call__(self, model_list):
|
||||
# model_list is None in static graph
|
||||
parameters = None
|
||||
if len(self.no_weight_decay_name_list) > 0:
|
||||
if model_list:
|
||||
params_with_decay = []
|
||||
params_without_decay = []
|
||||
for m in model_list:
|
||||
params = [p for n, p in m.named_parameters() \
|
||||
if not any(nd in n for nd in self.no_weight_decay_name_list)]
|
||||
params_with_decay.extend(params)
|
||||
params = [p for n, p in m.named_parameters() \
|
||||
if any(nd in n for nd in self.no_weight_decay_name_list) or (self.one_dim_param_no_weight_decay and len(p.shape) == 1)]
|
||||
params_without_decay.extend(params)
|
||||
parameters = [{
|
||||
"params": params_with_decay,
|
||||
"weight_decay": self.weight_decay
|
||||
}, {
|
||||
"params": params_without_decay,
|
||||
"weight_decay": 0.0
|
||||
}]
|
||||
else:
|
||||
parameters = sum([m.parameters() for m in model_list],
|
||||
[]) if model_list else None
|
||||
for n, p in m.named_parameters():
|
||||
if any(nd in n for nd in self.no_weight_decay_name_list) \
|
||||
or (self.one_dim_param_no_weight_decay and len(p.shape) == 1):
|
||||
params_without_decay.append(p)
|
||||
else:
|
||||
params_with_decay.append(p)
|
||||
if params_without_decay:
|
||||
parameters = [{
|
||||
"params": params_with_decay,
|
||||
"weight_decay": self.weight_decay
|
||||
}, {
|
||||
"params": params_without_decay,
|
||||
"weight_decay": 0.0
|
||||
}]
|
||||
else:
|
||||
parameters = params_with_decay
|
||||
opt = optim.RMSProp(
|
||||
learning_rate=self.learning_rate,
|
||||
momentum=self.momentum,
|
||||
|
|
Loading…
Reference in New Issue