Fix RMSProp: apply one_dim_param_no_weight_decay even when no_weight_decay_name_list is empty

pull/2784/head
Yang Nie 2023-04-17 14:25:59 +08:00 committed by Tingquan Gao
parent c351dac67e
commit b66ee6384b
1 changed file with 17 additions and 17 deletions

View File

@@ -232,16 +232,17 @@ class RMSProp(object):
def __call__(self, model_list): def __call__(self, model_list):
# model_list is None in static graph # model_list is None in static graph
parameters = None parameters = None
if len(self.no_weight_decay_name_list) > 0: if model_list:
params_with_decay = [] params_with_decay = []
params_without_decay = [] params_without_decay = []
for m in model_list: for m in model_list:
params = [p for n, p in m.named_parameters() \ for n, p in m.named_parameters():
if not any(nd in n for nd in self.no_weight_decay_name_list)] if any(nd in n for nd in self.no_weight_decay_name_list) \
params_with_decay.extend(params) or (self.one_dim_param_no_weight_decay and len(p.shape) == 1):
params = [p for n, p in m.named_parameters() \ params_without_decay.append(p)
if any(nd in n for nd in self.no_weight_decay_name_list) or (self.one_dim_param_no_weight_decay and len(p.shape) == 1)] else:
params_without_decay.extend(params) params_with_decay.append(p)
if params_without_decay:
parameters = [{ parameters = [{
"params": params_with_decay, "params": params_with_decay,
"weight_decay": self.weight_decay "weight_decay": self.weight_decay
@@ -250,8 +251,7 @@ class RMSProp(object):
"weight_decay": 0.0 "weight_decay": 0.0
}] }]
else: else:
parameters = sum([m.parameters() for m in model_list], parameters = params_with_decay
[]) if model_list else None
opt = optim.RMSProp( opt = optim.RMSProp(
learning_rate=self.learning_rate, learning_rate=self.learning_rate,
momentum=self.momentum, momentum=self.momentum,