fix optimizer and regularizer

parent 94a8f50ae7
commit b8a7d186d7
@@ -19,36 +19,15 @@ from __future__ import print_function
 import sys
-import math
 
 import paddle.fluid as fluid
-import paddle.fluid.layers.ops as ops
-from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
+from paddle.optimizer.lr_scheduler import LinearLrWarmup
+from paddle.optimizer.lr_scheduler import PiecewiseLR
+from paddle.optimizer.lr_scheduler import CosineAnnealingLR
+from paddle.optimizer.lr_scheduler import ExponentialLR
 
 __all__ = ['LearningRateBuilder']
 
 
-class Linear(object):
-    """
-    Linear learning rate decay
-
-    Args:
-        lr(float): initial learning rate
-        steps(int): total decay steps
-        end_lr(float): end learning rate, default: 0.0.
-    """
-
-    def __init__(self, lr, steps, end_lr=0.0, **kwargs):
-        super(Linear, self).__init__()
-        self.lr = lr
-        self.steps = steps
-        self.end_lr = end_lr
-
-    def __call__(self):
-        learning_rate = fluid.layers.polynomial_decay(
-            self.lr, self.steps, self.end_lr, power=1)
-        return learning_rate
-
-
-class Cosine(object):
+class Cosine(CosineAnnealingLR):
     """
     Cosine learning rate decay
     lr = 0.05 * (math.cos(epoch * (math.pi / epochs)) + 1)
 
@@ -60,20 +39,14 @@ class Cosine(object):
     """
 
     def __init__(self, lr, step_each_epoch, epochs, **kwargs):
-        super(Cosine, self).__init__()
-        self.lr = lr
-        self.step_each_epoch = step_each_epoch
-        self.epochs = epochs
+        super(Cosine, self).__init__(
+            learning_rate=lr,
+            T_max=step_each_epoch * epochs, )
 
-    def __call__(self):
-        learning_rate = fluid.layers.cosine_decay(
-            learning_rate=self.lr,
-            step_each_epoch=self.step_each_epoch,
-            epochs=self.epochs)
-        return learning_rate
+        self.update_specified = False
 
 
-class Piecewise(object):
+class Piecewise(PiecewiseLR):
     """
     Piecewise learning rate decay
 
@@ -85,16 +58,15 @@ class Piecewise(object):
     """
 
     def __init__(self, lr, step_each_epoch, decay_epochs, gamma=0.1, **kwargs):
-        super(Piecewise, self).__init__()
-        self.bd = [step_each_epoch * e for e in decay_epochs]
-        self.lr = [lr * (gamma**i) for i in range(len(self.bd) + 1)]
+        boundaries = [step_each_epoch * e for e in decay_epochs]
+        lr_values = [lr * (gamma**i) for i in range(len(boundaries) + 1)]
+        super(Piecewise, self).__init__(
+            boundaries=boundaries, values=lr_values)
 
-    def __call__(self):
-        learning_rate = fluid.layers.piecewise_decay(self.bd, self.lr)
-        return learning_rate
+        self.update_specified = False
 
 
-class CosineWarmup(object):
+class CosineWarmup(LinearLrWarmup):
    """
    Cosine learning rate decay with warmup
    [0, warmup_epoch): linear warmup
 
@@ -108,28 +80,23 @@ class CosineWarmup(object):
     """
 
     def __init__(self, lr, step_each_epoch, epochs, warmup_epoch=5, **kwargs):
-        super(CosineWarmup, self).__init__()
-        self.lr = lr
-        self.step_each_epoch = step_each_epoch
-        self.epochs = epochs
-        self.warmup_epoch = warmup_epoch
+        assert epochs > warmup_epoch, "total epoch({}) should be larger than warmup_epoch({}) in CosineWarmup.".format(
+            epochs, warmup_epoch)
+        warmup_step = warmup_epoch * step_each_epoch
+        start_lr = 0.0
+        end_lr = lr
+        lr_sch = Cosine(lr, step_each_epoch, epochs - warmup_epoch)
 
-    def __call__(self):
-        learning_rate = fluid.layers.cosine_decay(
-            learning_rate=self.lr,
-            step_each_epoch=self.step_each_epoch,
-            epochs=self.epochs)
+        super(CosineWarmup, self).__init__(
+            learning_rate=lr_sch,
+            warmup_steps=warmup_step,
+            start_lr=start_lr,
+            end_lr=end_lr)
 
-        learning_rate = fluid.layers.linear_lr_warmup(
-            learning_rate,
-            warmup_steps=self.warmup_epoch * self.step_each_epoch,
-            start_lr=0.0,
-            end_lr=self.lr)
-
-        return learning_rate
+        self.update_specified = False
 
 
-class ExponentialWarmup(object):
+class ExponentialWarmup(LinearLrWarmup):
     """
     Exponential learning rate decay with warmup
     [0, warmup_epoch): linear warmup
 
@@ -150,27 +117,22 @@ class ExponentialWarmup(object):
                  decay_rate=0.97,
                  warmup_epoch=5,
                  **kwargs):
-        super(ExponentialWarmup, self).__init__()
-        self.lr = lr
+        warmup_step = warmup_epoch * step_each_epoch
+        start_lr = 0.0
+        end_lr = lr
+        lr_sch = ExponentialLR(lr, decay_rate)
+
+        super(ExponentialWarmup, self).__init__(
+            learning_rate=lr_sch,
+            warmup_steps=warmup_step,
+            start_lr=start_lr,
+            end_lr=end_lr)
+
+        # NOTE: hac method to update exponential lr scheduler
+        self.update_specified = True
+        self.update_start_step = warmup_step
+        self.update_step_interval = int(decay_epochs * step_each_epoch)
         self.step_each_epoch = step_each_epoch
-        self.decay_epochs = decay_epochs
-        self.decay_rate = decay_rate
-        self.warmup_epoch = warmup_epoch
-
-    def __call__(self):
-        learning_rate = fluid.layers.exponential_decay(
-            learning_rate=self.lr,
-            decay_steps=self.decay_epochs * self.step_each_epoch,
-            decay_rate=self.decay_rate,
-            staircase=False)
-
-        learning_rate = fluid.layers.linear_lr_warmup(
-            learning_rate,
-            warmup_steps=self.warmup_epoch * self.step_each_epoch,
-            start_lr=0.0,
-            end_lr=self.lr)
-
-        return learning_rate
 
 
 class LearningRateBuilder():
 
@@ -193,5 +155,5 @@ class LearningRateBuilder():
 
     def __call__(self):
         mod = sys.modules[__name__]
-        lr = getattr(mod, self.function)(**self.params)()
+        lr = getattr(mod, self.function)(**self.params)
         return lr

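The hunks above replace the hand-rolled fluid schedules with subclasses of the paddle 2.0 lr_scheduler classes. For reference, the boundary/value arithmetic the new Piecewise.__init__ performs can be replayed in plain Python; the helper name and the numbers below are illustrative only, not part of the commit.

from bisect import bisect_right


def piecewise_lr(global_step, lr, step_each_epoch, decay_epochs, gamma=0.1):
    # Same construction as Piecewise.__init__ above: epoch boundaries become
    # step boundaries, and the base lr is scaled by gamma**i per segment.
    boundaries = [step_each_epoch * e for e in decay_epochs]
    lr_values = [lr * (gamma ** i) for i in range(len(boundaries) + 1)]
    return lr_values[bisect_right(boundaries, global_step)]


# With 1000 steps per epoch and decay_epochs=[30, 60]:
# 0.1 before step 30000, 0.01 before step 60000, 0.001 afterwards.
print([piecewise_lr(s, 0.1, 1000, [30, 60]) for s in (0, 29999, 30000, 60000)])
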
@@ -18,7 +18,7 @@ from __future__ import print_function
 import sys
 
-import paddle.fluid as fluid
+import paddle
 
 __all__ = ['OptimizerBuilder']
 
 
@@ -33,11 +33,10 @@ class L1Decay(object):
 
     def __init__(self, factor=0.0):
         super(L1Decay, self).__init__()
-        self.regularization_coeff = factor
+        self.factor = factor
 
     def __call__(self):
-        reg = fluid.regularizer.L1Decay(
-            regularization_coeff=self.regularization_coeff)
+        reg = paddle.regularizer.L1Decay(self.factor)
         return reg
 
 
@@ -51,11 +50,10 @@ class L2Decay(object):
 
     def __init__(self, factor=0.0):
         super(L2Decay, self).__init__()
-        self.regularization_coeff = factor
+        self.factor = factor
 
     def __call__(self):
-        reg = fluid.regularizer.L2Decay(
-            regularization_coeff=self.regularization_coeff)
+        reg = paddle.regularizer.L2Decay(self.factor)
         return reg
 
 
@@ -83,11 +81,11 @@ class Momentum(object):
         self.regularization = regularization
 
     def __call__(self):
-        opt = fluid.optimizer.Momentum(
+        opt = paddle.optimizer.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
-            parameter_list=self.parameter_list,
-            regularization=self.regularization)
+            parameters=self.parameter_list,
+            weight_decay=self.regularization)
         return opt
 
 
@@ -121,13 +119,13 @@ class RMSProp(object):
         self.regularization = regularization
 
     def __call__(self):
-        opt = fluid.optimizer.RMSProp(
+        opt = paddle.optimizer.RMSProp(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
             rho=self.rho,
             epsilon=self.epsilon,
-            parameter_list=self.parameter_list,
-            regularization=self.regularization)
+            parameters=self.parameter_list,
+            weight_decay=self.regularization)
         return opt

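These optimizer hunks swap the fluid keyword names (parameter_list / regularization) for the paddle 2.x ones (parameters / weight_decay). A minimal sketch of how the rebuilt Momentum and L2Decay compose under that API, using a throwaway layer and hyper-parameters that are not from the commit:

import paddle

model = paddle.nn.Linear(8, 2)  # stand-in model, illustration only
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.1,
    momentum=0.9,
    parameters=model.parameters(),
    weight_decay=paddle.regularizer.L2Decay(1e-4))
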
@@ -19,7 +19,9 @@ from __future__ import print_function
 import os
 import sys
 
-import paddle.fluid as fluid
+import paddle
+# TODO: need to be fixed in the future.
+from paddle.fluid import is_compiled_with_cuda
 
 from ppcls.modeling import get_architectures
 from ppcls.modeling import similar_architectures
@@ -33,10 +35,9 @@ def check_version():
     """
     err = "PaddlePaddle version 1.8.0 or higher is required, " \
           "or a suitable develop version is satisfied as well. \n" \
-          "Please make sure the version is good with your code." \
-
+          "Please make sure the version is good with your code."
     try:
-        fluid.require_version('1.8.0')
+        paddle.utils.require_version('0.0.0')
     except Exception:
         logger.error(err)
         sys.exit(1)
@@ -50,7 +51,7 @@ def check_gpu():
           "install paddlepaddle-gpu to run model on GPU."
 
     try:
-        assert fluid.is_compiled_with_cuda()
+        assert is_compiled_with_cuda()
     except AssertionError:
         logger.error(err)
         sys.exit(1)

@@ -22,7 +22,8 @@ import re
 import shutil
 import tempfile
 
-import paddle.fluid as fluid
+import paddle
+from paddle.io import load_program_state
 
 from ppcls.utils import logger
 
@@ -50,7 +51,7 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):
         raise ValueError("Model pretrain path {} does not "
                          "exists.".format(path))
     if load_static_weights:
-        pre_state_dict = fluid.load_program_state(path)
+        pre_state_dict = load_program_state(path)
         param_state_dict = {}
         model_dict = model.state_dict()
         for key in model_dict.keys():
@@ -64,7 +65,7 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):
         model.set_dict(param_state_dict)
         return
 
-    param_state_dict, optim_state_dict = fluid.load_dygraph(path)
+    param_state_dict, optim_state_dict = paddle.load(path)
     model.set_dict(param_state_dict)
     return
 
@@ -105,7 +106,7 @@ def init_model(config, net, optimizer=None):
         "Given dir {}.pdparams not exist.".format(checkpoints)
     assert os.path.exists(checkpoints + ".pdopt"), \
         "Given dir {}.pdopt not exist.".format(checkpoints)
-    para_dict, opti_dict = fluid.dygraph.load_dygraph(checkpoints)
+    para_dict, opti_dict = paddle(checkpoints)
     net.set_dict(para_dict)
     optimizer.set_dict(opti_dict)
     logger.info(
@@ -141,8 +142,8 @@ def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
     _mkdir_if_not_exist(model_path)
     model_prefix = os.path.join(model_path, prefix)
 
-    fluid.dygraph.save_dygraph(net.state_dict(), model_prefix)
-    fluid.dygraph.save_dygraph(optimizer.state_dict(), model_prefix)
+    paddle.save(net.state_dict(), model_prefix)
+    paddle.save(optimizer.state_dict(), model_prefix)
     logger.info(
         logger.coloring("Already save model in {}".format(model_path),
                         "HEADER"))

@@ -69,8 +69,6 @@ def create_model(architecture, classes_num):
     """
     name = architecture["name"]
     params = architecture.get("params", {})
-    print(name)
-    print(params)
     return architectures.__dict__[name](class_dim=classes_num, **params)
 
 
@@ -237,7 +235,7 @@ def create_optimizer(config, parameter_list=None):
     # create optimizer instance
     opt_config = config['OPTIMIZER']
     opt = OptimizerBuilder(**opt_config)
-    return opt(lr, parameter_list)
+    return opt(lr, parameter_list), lr
 
 
 def create_feeds(batch, use_mix):
@@ -253,7 +251,13 @@ def create_feeds(batch, use_mix):
     return feeds
 
 
-def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
+def run(dataloader,
+        config,
+        net,
+        optimizer=None,
+        lr_scheduler=None,
+        epoch=0,
+        mode='train'):
     """
     Feed data to the model and fetch the measures and loss
 
@@ -302,6 +306,17 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
         metric_list['lr'].update(
             optimizer._global_learning_rate().numpy()[0], batch_size)
 
+        if lr_scheduler is not None:
+            if lr_scheduler.update_specified:
+                curr_global_counter = lr_scheduler.step_each_epoch * epoch + idx
+                update = max(
+                    0, curr_global_counter - lr_scheduler.update_start_step
+                ) % lr_scheduler.update_step_interval == 0
+                if update:
+                    lr_scheduler.step()
+            else:
+                lr_scheduler.step()
+
         for name, fetch in fetchs.items():
             metric_list[name].update(fetch.numpy()[0], batch_size)
         metric_list['batch_time'].update(time.time() - tic)

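The new block in run() steps the scheduler once per iteration, except for schedulers flagged with update_specified (the ExponentialWarmup case above), which are stepped on a fixed interval of global steps. The condition can be replayed in plain Python; the helper name and numbers are illustrative, not part of the commit.

def should_step(idx, epoch, step_each_epoch, update_start_step, update_step_interval):
    # Mirrors the update condition added to run() above.
    curr_global_counter = step_each_epoch * epoch + idx
    return max(0, curr_global_counter - update_start_step) % update_step_interval == 0


# 100 steps per epoch, warmup ending at global step 500, decay every 240 steps:
# fires at 500, 740, 980, ... (and on every step before 500, where the clamp
# keeps the counter at zero and the warmup wrapper still owns the lr).
print([s for s in range(450, 1000)
       if should_step(s % 100, s // 100, 100, 500, 240)])
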
@@ -23,15 +23,16 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
 sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
 
-import program
-from ppcls.utils import logger
-from ppcls.utils.save_load import init_model, save_model
-from ppcls.utils.config import get_config
-from ppcls.data import Reader
-
 import paddle
 from paddle.distributed import ParallelEnv
 
+from ppcls.data import Reader
+from ppcls.utils.config import get_config
+from ppcls.utils.save_load import init_model, save_model
+from ppcls.utils import logger
+import program
+
 
 def parse_args():
     parser = argparse.ArgumentParser("PaddleClas train script")
     parser.add_argument(
@@ -67,7 +68,7 @@ def main(args):
 
     net = program.create_model(config.ARCHITECTURE, config.classes_num)
 
-    optimizer = program.create_optimizer(
+    optimizer, lr_scheduler = program.create_optimizer(
         config, parameter_list=net.parameters())
 
     if config["use_data_parallel"]:
@@ -90,8 +91,8 @@ def main(args):
     for epoch_id in range(config.epochs):
         net.train()
         # 1. train with train dataset
-        program.run(train_dataloader, config, net, optimizer, epoch_id,
-                    'train')
+        program.run(train_dataloader, config, net, optimizer, lr_scheduler,
+                    epoch_id, 'train')
 
         if not config["use_data_parallel"] or ParallelEnv().local_rank == 0:
             # 2. validate with validate dataset