from __future__ import absolute_import
from __future__ import print_function

import torch
import torch.nn as nn


AVAI_OPTIMS = ['adam', 'amsgrad', 'sgd', 'rmsprop']

def build_optimizer(
    model,
    optim='adam',
    lr=0.0003,
    weight_decay=5e-04,
    momentum=0.9,
    sgd_dampening=0,
    sgd_nesterov=False,
    rmsprop_alpha=0.99,
    adam_beta1=0.9,
    adam_beta2=0.99,
    staged_lr=False,
    new_layers=None,
    base_lr_mult=0.1
):
    """Build an optimizer from torch.optim for the given model.

    When ``staged_lr`` is True, parameters of the child modules listed in
    ``new_layers`` are trained with the full ``lr``, while all remaining
    (base) parameters use ``lr * base_lr_mult``.
    """
    if optim not in AVAI_OPTIMS:
        raise ValueError('Unsupported optim: {}. Must be one of {}'.format(optim, AVAI_OPTIMS))

    if not isinstance(model, nn.Module):
        raise TypeError('model given to build_optimizer must be an instance of nn.Module')

    if staged_lr:
        if isinstance(new_layers, str):
            new_layers = [new_layers]
        if new_layers is None:
            # Without any new layers, every parameter is treated as a base parameter.
            new_layers = []

        if isinstance(model, nn.DataParallel):
            model = model.module

        base_params = []
        base_layers = []
        new_params = []

        # Child modules named in new_layers keep the full lr; everything else
        # is collected as a base layer trained with lr * base_lr_mult.
        for name, module in model.named_children():
            if name in new_layers:
                new_params += [p for p in module.parameters()]
            else:
                base_params += [p for p in module.parameters()]
                base_layers.append(name)

        param_groups = [
            {'params': base_params, 'lr': lr * base_lr_mult},
            {'params': new_params},
        ]

        print('Use staged learning rate')
        print('Base layers (lr*{}): {}'.format(base_lr_mult, base_layers))
        print('New layers (lr): {}'.format(new_layers))

    else:
        param_groups = model.parameters()

    print('Initializing optimizer: {}'.format(optim))

    if optim == 'adam':
        optimizer = torch.optim.Adam(
            param_groups,
            lr=lr,
            weight_decay=weight_decay,
            betas=(adam_beta1, adam_beta2),
        )

    elif optim == 'amsgrad':
        optimizer = torch.optim.Adam(
            param_groups,
            lr=lr,
            weight_decay=weight_decay,
            betas=(adam_beta1, adam_beta2),
            amsgrad=True,
        )

    elif optim == 'sgd':
        optimizer = torch.optim.SGD(
            param_groups,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            dampening=sgd_dampening,
            nesterov=sgd_nesterov,
        )

    elif optim == 'rmsprop':
        optimizer = torch.optim.RMSprop(
            param_groups,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            alpha=rmsprop_alpha,
        )

    return optimizer
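

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): builds an optimizer
# for a toy two-layer model, once with a single learning rate and once with a
# staged learning rate where only the 'classifier' child uses the full lr.
# The ToyNet class and the 'classifier' layer name are illustrative
# assumptions, not something defined by this file.
# ----------------------------------------------------------------------------
if __name__ == '__main__':

    class ToyNet(nn.Module):

        def __init__(self):
            super(ToyNet, self).__init__()
            self.backbone = nn.Linear(16, 8)
            self.classifier = nn.Linear(8, 4)

        def forward(self, x):
            return self.classifier(self.backbone(x))

    net = ToyNet()

    # Plain Adam over all parameters.
    opt = build_optimizer(net, optim='adam', lr=0.0003)

    # Staged lr: 'classifier' trains at lr, 'backbone' at lr * base_lr_mult.
    opt_staged = build_optimizer(
        net,
        optim='sgd',
        lr=0.01,
        staged_lr=True,
        new_layers='classifier',
        base_lr_mult=0.1
    )
    print(opt_staged)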