support for multi optimizer case
parent a5a1c19273
commit 24abea151a
@@ -214,16 +214,19 @@ class Engine(object):
         if self.config["Global"]["pretrained_model"] is not None:
             if self.config["Global"]["pretrained_model"].startswith("http"):
                 load_dygraph_pretrain_from_url(
-                    self.model, self.config["Global"]["pretrained_model"])
+                    [self.model, self.train_loss_func],
+                    self.config["Global"]["pretrained_model"])
             else:
                 load_dygraph_pretrain(
-                    self.model, self.config["Global"]["pretrained_model"])
+                    [self.model, self.train_loss_func],
+                    self.config["Global"]["pretrained_model"])

         # build optimizer
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
                 self.config["Optimizer"], self.config["Global"]["epochs"],
-                len(self.train_dataloader), [self.model])
+                len(self.train_dataloader),
+                [self.model, self.train_loss_func])

         # for amp training
         if self.amp:

@@ -241,6 +244,11 @@ class Engine(object):
                 optimizers=self.optimizer,
                 level=amp_level,
                 save_dtype='float32')
+            if len(self.train_loss_func.parameters()) > 0:
+                self.train_loss_func = paddle.amp.decorate(
+                    models=self.train_loss_func,
+                    level=amp_level,
+                    save_dtype='float32')

         # for distributed
         world_size = dist.get_world_size()

@@ -251,7 +259,9 @@ class Engine(object):
         if self.config["Global"]["distributed"]:
             dist.init_parallel_env()
             self.model = paddle.DataParallel(self.model)
-
+            if len(self.train_loss_func.parameters()) > 0:
+                self.train_loss_func = paddle.DataParallel(
+                    self.train_loss_func)
         # build postprocess for infer
         if self.mode == 'infer':
             self.preprocess_func = create_operators(self.config["Infer"][

@@ -279,9 +289,9 @@ class Engine(object):
         # global iter counter
         self.global_step = 0

-        if self.config["Global"]["checkpoints"] is not None:
-            metric_info = init_model(self.config["Global"], self.model,
-                                     self.optimizer)
+        if self.config.Global.checkpoints is not None:
+            metric_info = init_model(self.config.Global, self.model,
+                                     self.optimizer, self.train_loss_func)
             if metric_info is not None:
                 best_metric.update(metric_info)

@@ -317,7 +327,8 @@ class Engine(object):
                     best_metric,
                     self.output_dir,
                     model_name=self.config["Arch"]["name"],
-                    prefix="best_model")
+                    prefix="best_model",
+                    loss=self.train_loss_func)
             logger.info("[Eval][Epoch {}][best metric: {}]".format(
                 epoch_id, best_metric["metric"]))
             logger.scaler(

@@ -336,7 +347,8 @@ class Engine(object):
                         "epoch": epoch_id},
                     self.output_dir,
                     model_name=self.config["Arch"]["name"],
-                    prefix="epoch_{}".format(epoch_id))
+                    prefix="epoch_{}".format(epoch_id),
+                    loss=self.train_loss_func)
             # save the latest model
             save_load.save_model(
                 self.model,

@@ -344,7 +356,8 @@ class Engine(object):
                     "epoch": epoch_id},
                 self.output_dir,
                 model_name=self.config["Arch"]["name"],
-                prefix="latest")
+                prefix="latest",
+                loss=self.train_loss_func)

         if self.vdl_writer is not None:
             self.vdl_writer.close()
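Note (annotation, not part of the diff): every engine change above is gated on the same check, `len(self.train_loss_func.parameters()) > 0` — the loss is only pretrained-loaded, optimized, AMP-decorated, wrapped in DataParallel, and checkpointed when it actually owns trainable parameters. A minimal sketch of that distinction, with invented class names:

import paddle.nn as nn

class PlainCELoss(nn.Layer):
    # no trainable state: the engine's new guards all skip this one
    def forward(self, logits, label):
        return nn.functional.cross_entropy(logits, label)

class HeadedLoss(nn.Layer):
    # carries a trainable head, so it must be trained and saved like the model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, feat, label):
        return nn.functional.cross_entropy(self.fc(feat), label)

for loss_fn in (PlainCELoss(), HeadedLoss()):
    print(type(loss_fn).__name__, len(loss_fn.parameters()) > 0)
# PlainCELoss False / HeadedLoss True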
@@ -53,16 +53,22 @@ def train_epoch(engine, epoch_id, print_batch_step):
         out = forward(engine, batch)
         loss_dict = engine.train_loss_func(out, batch[1])

-        # step opt and lr
+        # step opt
         if engine.amp:
             scaled = engine.scaler.scale(loss_dict["loss"])
             scaled.backward()
-            engine.scaler.minimize(engine.optimizer, scaled)
+            for i in range(len(engine.optimizer)):
+                engine.scaler.minimize(engine.optimizer[i], scaled)
         else:
             loss_dict["loss"].backward()
-            engine.optimizer.step()
-        engine.optimizer.clear_grad()
-        engine.lr_sch.step()
+            for i in range(len(engine.optimizer)):
+                engine.optimizer[i].step()
+        # clear grad
+        for i in range(len(engine.optimizer)):
+            engine.optimizer[i].clear_grad()
+        # step lr
+        for i in range(len(engine.lr_sch)):
+            engine.lr_sch[i].step()

         # below code just for logging
         # update metric_for_logger
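With `engine.optimizer` now a list, a single backward pass fills gradients for every parameter group, and each optimizer then steps only the group it was built over. A self-contained sketch of the same pattern (the toy layers `backbone` and `loss_head` are stand-ins for engine.model and engine.train_loss_func, not from the commit):

import paddle

backbone = paddle.nn.Linear(8, 4)
loss_head = paddle.nn.Linear(4, 1)  # stand-in for a loss that owns parameters

optimizers = [
    paddle.optimizer.SGD(learning_rate=0.1, parameters=backbone.parameters()),
    paddle.optimizer.SGD(learning_rate=0.01, parameters=loss_head.parameters()),
]

x = paddle.randn([2, 8])
loss = loss_head(backbone(x)).mean()
loss.backward()         # one backward populates grads for both groups
for opt in optimizers:  # each optimizer updates only its own parameters
    opt.step()
for opt in optimizers:
    opt.clear_grad()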
@@ -38,7 +38,10 @@ def update_loss(trainer, loss_dict, batch_size):


 def log_info(trainer, batch_size, epoch_id, iter_id):
-    lr_msg = "lr: {:.5f}".format(trainer.lr_sch.get_lr())
+    lr_msg = ", ".join([
+        "lr_{}: {:.8f}".format(i + 1, lr.get_lr())
+        for i, lr in enumerate(trainer.lr_sch)
+    ])
     metric_msg = ", ".join([
         "{}: {:.5f}".format(key, trainer.output_info[key].avg)
         for key in trainer.output_info

@@ -59,11 +62,12 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
         len(trainer.train_dataloader), lr_msg, metric_msg, time_msg, ips_msg,
         eta_msg))

-    logger.scaler(
-        name="lr",
-        value=trainer.lr_sch.get_lr(),
-        step=trainer.global_step,
-        writer=trainer.vdl_writer)
+    for i, lr in enumerate(trainer.lr_sch):
+        logger.scaler(
+            name="lr_{}".format(i + 1),
+            value=lr.get_lr(),
+            step=trainer.global_step,
+            writer=trainer.vdl_writer)
     for key in trainer.output_info:
         logger.scaler(
             name="train_{}".format(key),
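For illustration, here is what the new per-scheduler `lr_msg` join produces when two schedulers are active; the scheduler choices below are assumptions, not from the commit:

import paddle

lr_sch = [
    paddle.optimizer.lr.PiecewiseDecay(boundaries=[10], values=[0.1, 0.01]),
    paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.001, T_max=10),
]
lr_msg = ", ".join([
    "lr_{}: {:.8f}".format(i + 1, lr.get_lr()) for i, lr in enumerate(lr_sch)
])
print(lr_msg)  # lr_1: 0.10000000, lr_2: 0.00100000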
@@ -47,6 +47,7 @@ class CombinedLoss(nn.Layer):
                 param.keys())
             self.loss_weight.append(param.pop("weight"))
             self.loss_func.append(eval(name)(**param))
+        self.loss_func = nn.LayerList(self.loss_func)

     def __call__(self, input, batch):
         loss_dict = {}
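The `nn.LayerList` wrap is what makes the rest of the commit work: sublayers held in a plain Python list are not registered, so their parameters never appear in `parameters()` or `state_dict()` and could be neither optimized nor checkpointed. A small demonstration with a hypothetical `Holder` layer:

import paddle.nn as nn

class Holder(nn.Layer):
    def __init__(self, register):
        super().__init__()
        subs = [nn.Linear(2, 2)]
        # plain list: sublayers invisible; LayerList: sublayers registered
        self.subs = nn.LayerList(subs) if register else subs

print(len(Holder(register=False).parameters()))  # 0 -> params untracked
print(len(Holder(register=True).parameters()))   # 2 -> weight + bias tracked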
@@ -18,6 +18,7 @@ from __future__ import print_function

 import copy
 import paddle
+from typing import Dict, List

 from ppcls.utils import logger


@@ -44,29 +45,57 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 # model_list is None in static graph
 def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     config = copy.deepcopy(config)
-    # step1 build lr
-    lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
-    logger.debug("build lr ({}) success..".format(lr))
-    # step2 build regularization
-    if 'regularizer' in config and config['regularizer'] is not None:
-        if 'weight_decay' in config:
-            logger.warning(
-                "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
-            )
-        reg_config = config.pop('regularizer')
-        reg_name = reg_config.pop('name') + 'Decay'
-        reg = getattr(paddle.regularizer, reg_name)(**reg_config)
-        config["weight_decay"] = reg
-        logger.debug("build regularizer ({}) success..".format(reg))
-    # step3 build optimizer
-    optim_name = config.pop('name')
-    if 'clip_norm' in config:
-        clip_norm = config.pop('clip_norm')
-        grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
-    else:
-        grad_clip = None
-    optim = getattr(optimizer, optim_name)(learning_rate=lr,
-                                           grad_clip=grad_clip,
-                                           **config)(model_list=model_list)
-    logger.debug("build optimizer ({}) success..".format(optim))
-    return optim, lr
+    if isinstance(config, dict):
+        # convert to [{optim_name1: {scope: xxx, **optim_cfg}}, {optim_name2: {scope: xxx, **optim_cfg}}, ...]
+        optim_name = config.Optimizer.pop('name')
+        config: List[Dict[str, Dict]] = [{
+            optim_name: {
+                'scope': config.Arch.name,
+                **config.Optimizer
+            }
+        }]
+    optim_list = []
+    lr_list = []
+    for optim_item in config:
+        # optim_cfg = {optim_name1: {scope: xxx, **optim_cfg}}
+        # step1 build lr
+        optim_name = list(optim_item.keys())[0]  # get optim_name1
+        optim_scope = optim_item[optim_name].pop('scope')  # get scope
+        optim_cfg = optim_item[optim_name]  # get optim_cfg
+
+        lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
+        logger.debug("build lr ({}) for scope ({}) success..".format(
+            lr, optim_scope))
+        # step2 build regularization
+        if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
+            if 'weight_decay' in optim_cfg:
+                logger.warning(
+                    "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
+                )
+            reg_config = optim_cfg.pop('regularizer')
+            reg_name = reg_config.pop('name') + 'Decay'
+            reg = getattr(paddle.regularizer, reg_name)(**reg_config)
+            optim_cfg["weight_decay"] = reg
+            logger.debug("build regularizer ({}) success..".format(reg))
+        # step3 build optimizer
+        if 'clip_norm' in optim_cfg:
+            clip_norm = optim_cfg.pop('clip_norm')
+            grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
+        else:
+            grad_clip = None
+        optim_model = []
+        for i in range(len(model_list)):
+            class_name = model_list[i].__class__.__name__
+            if class_name == optim_scope:
+                optim_model.append(model_list[i])
+        assert len(optim_model) == 1 and len(optim_model[0].parameters()) > 0, \
+            f"Invalid optim model for optim scope({optim_scope}), number of optim_model={len(optim_model)}, and number of optim_model's params={len(optim_model[0].parameters())}"
+        optim = getattr(optimizer, optim_name)(
+            learning_rate=lr, grad_clip=grad_clip,
+            **optim_cfg)(model_list=optim_model)
+        logger.debug("build optimizer ({}) for scope ({}) success..".format(
+            optim, optim_scope))
+        optim_list.append(optim)
+        lr_list.append(lr)
+    return optim_list, lr_list
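A sketch of the normalization `build_optimizer` now performs, with hypothetical config values: a legacy single-optimizer dict is wrapped into the list-of-dicts form, and each entry's `scope` is later matched against `model.__class__.__name__` to select the one module that optimizer trains:

# hypothetical single-optimizer config and architecture name
single = {
    "name": "Momentum",
    "momentum": 0.9,
    "lr": {"name": "Cosine", "learning_rate": 0.04},
}
arch_name = "RecModel"

optim_name = single.pop("name")
normalized = [{optim_name: {"scope": arch_name, **single}}]

def select_by_scope(model_list, optim_scope):
    # each optimizer entry trains exactly the model whose class name matches
    matched = [m for m in model_list if m.__class__.__name__ == optim_scope]
    assert len(matched) == 1, \
        "optim scope ({}) must match exactly one model".format(optim_scope)
    return matched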
@@ -18,9 +18,6 @@ from __future__ import print_function

 import errno
 import os
-import re
-import shutil
-import tempfile

 import paddle
 from ppcls.utils import logger

@@ -47,10 +44,14 @@ def _mkdir_if_not_exist(path):

 def load_dygraph_pretrain(model, path=None):
     if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
-        raise ValueError("Model pretrain path {} does not "
+        raise ValueError("Model pretrain path {}.pdparams does not "
                          "exists.".format(path))
     param_state_dict = paddle.load(path + ".pdparams")
-    model.set_dict(param_state_dict)
+    if isinstance(model, list):
+        for m in model:
+            m.set_dict(param_state_dict)
+    else:
+        model.set_dict(param_state_dict)
     return


@@ -85,7 +86,7 @@ def load_distillation_model(model, pretrained_model):
                                                          pretrained_model))


-def init_model(config, net, optimizer=None):
+def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
     """
     load model from checkpoint or pretrained_model
     """

@@ -95,11 +96,15 @@ def init_model(config, net, optimizer=None):
             "Given dir {}.pdparams not exist.".format(checkpoints)
         assert os.path.exists(checkpoints + ".pdopt"), \
             "Given dir {}.pdopt not exist.".format(checkpoints)
-        para_dict = paddle.load(checkpoints + ".pdparams")
+        # load state dict
         opti_dict = paddle.load(checkpoints + ".pdopt")
+        para_dict = paddle.load(checkpoints + ".pdparams")
         metric_dict = paddle.load(checkpoints + ".pdstates")
-        net.set_dict(para_dict)
-        optimizer.set_state_dict(opti_dict)
+        # set state dict
+        net.set_state_dict(para_dict)
+        loss.set_state_dict(para_dict)
+        for i in range(len(optimizer)):
+            optimizer[i].set_state_dict(opti_dict)
         logger.info("Finish load checkpoints from {}".format(checkpoints))
         return metric_dict

@@ -120,7 +125,8 @@ def save_model(net,
                metric_info,
                model_path,
                model_name="",
-               prefix='ppcls'):
+               prefix='ppcls',
+               loss: paddle.nn.Layer=None):
     """
     save model to the target path
     """

@@ -130,7 +136,14 @@ def save_model(net,
     _mkdir_if_not_exist(model_path)
     model_path = os.path.join(model_path, prefix)

-    paddle.save(net.state_dict(), model_path + ".pdparams")
-    paddle.save(optimizer.state_dict(), model_path + ".pdopt")
+    params_state_dict = net.state_dict()
+    loss_state_dict = loss.state_dict()
+    keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys())
+    assert len(keys_inter) == 0, \
+        f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
+    params_state_dict.update(loss_state_dict)
+
+    paddle.save(params_state_dict, model_path + ".pdparams")
+    paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
     paddle.save(metric_info, model_path + ".pdstates")
     logger.info("Already save model in {}".format(model_path))
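A sketch of the checkpoint layout this change produces: model and loss weights are merged into one .pdparams dict after an overlap check, optimizer states are saved as a list, and on restore each layer picks out only its own keys. This assumes Paddle's `set_state_dict` tolerates extra keys in the supplied dict, which is exactly what `init_model` above relies on; class names are invented:

import os
import tempfile

import paddle
import paddle.nn as nn

class Backbone(nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

class LossHead(nn.Layer):
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(2, 1)

model, loss = Backbone(), LossHead()
merged = model.state_dict()
loss_sd = loss.state_dict()
assert not (set(merged) & set(loss_sd)), "model/loss keys must not collide"
merged.update(loss_sd)

path = os.path.join(tempfile.mkdtemp(), "ckpt.pdparams")
paddle.save(merged, path)

restored = paddle.load(path)
model.set_state_dict(restored)  # extra (loss) keys are ignored on load
loss.set_state_dict(restored)   # extra (model) keys are ignored on load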