Merge pull request #1853 from HydrogenSulfate/multi_optim

Support for training with multiple optimizers: build_optimizer now returns lists of optimizers and learning-rate schedulers, each bound to a scope (the whole model, a sub-module under Arch, or a parameterized loss such as CenterLoss).
Authored by Walter on 2022-04-20 14:33:07 +08:00, committed by GitHub
commit 43a03a0c6b
6 changed files with 149 additions and 60 deletions
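The change lets the Optimizer config describe several optimizers, each tagged with a scope. Per the NOTE added in the optimizer builder below, a scope can be the whole model ("all"), a next-level module under Arch (backbone, neck, head), or a loss layer that owns parameters, such as CenterLoss. Here is a hedged sketch of the parsed structure that the new build_optimizer iterates over; the optimizer names, scheduler names, and values are illustrative, not taken from a shipped config:

# Illustrative only: two scoped optimizers after config parsing.
optim_config = [
    {"Momentum": {"scope": "backbone",            # a module under Arch
                  "lr": {"name": "Cosine", "learning_rate": 0.04},
                  "momentum": 0.9}},
    {"SGD":      {"scope": "CenterLoss",          # a loss layer with parameters
                  "lr": {"name": "Constant", "learning_rate": 0.5}}},
]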

View File

@@ -214,16 +214,19 @@ class Engine(object):
         if self.config["Global"]["pretrained_model"] is not None:
             if self.config["Global"]["pretrained_model"].startswith("http"):
                 load_dygraph_pretrain_from_url(
-                    self.model, self.config["Global"]["pretrained_model"])
+                    [self.model, getattr(self, 'train_loss_func', None)],
+                    self.config["Global"]["pretrained_model"])
             else:
                 load_dygraph_pretrain(
-                    self.model, self.config["Global"]["pretrained_model"])
+                    [self.model, getattr(self, 'train_loss_func', None)],
+                    self.config["Global"]["pretrained_model"])
 
         # build optimizer
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
-                self.config["Optimizer"], self.config["Global"]["epochs"],
-                len(self.train_dataloader), [self.model])
+                self.config, self.config["Global"]["epochs"],
+                len(self.train_dataloader),
+                [self.model, self.train_loss_func])
 
         # for amp training
         if self.amp:
@@ -241,6 +244,11 @@ class Engine(object):
                 optimizers=self.optimizer,
                 level=amp_level,
                 save_dtype='float32')
+            if len(self.train_loss_func.parameters()) > 0:
+                self.train_loss_func = paddle.amp.decorate(
+                    models=self.train_loss_func,
+                    level=amp_level,
+                    save_dtype='float32')
 
         # for distributed
         world_size = dist.get_world_size()
@@ -251,7 +259,10 @@ class Engine(object):
         if self.config["Global"]["distributed"]:
             dist.init_parallel_env()
             self.model = paddle.DataParallel(self.model)
+            if self.mode == 'train' and len(self.train_loss_func.parameters(
+            )) > 0:
+                self.train_loss_func = paddle.DataParallel(
+                    self.train_loss_func)
 
         # build postprocess for infer
         if self.mode == 'infer':
             self.preprocess_func = create_operators(self.config["Infer"][
@@ -279,9 +290,9 @@ class Engine(object):
 
         # global iter counter
         self.global_step = 0
 
-        if self.config["Global"]["checkpoints"] is not None:
-            metric_info = init_model(self.config["Global"], self.model,
-                                     self.optimizer)
+        if self.config.Global.checkpoints is not None:
+            metric_info = init_model(self.config.Global, self.model,
+                                     self.optimizer, self.train_loss_func)
             if metric_info is not None:
                 best_metric.update(metric_info)
@@ -317,7 +328,8 @@ class Engine(object):
                         best_metric,
                         self.output_dir,
                         model_name=self.config["Arch"]["name"],
-                        prefix="best_model")
+                        prefix="best_model",
+                        loss=self.train_loss_func)
                     logger.info("[Eval][Epoch {}][best metric: {}]".format(
                         epoch_id, best_metric["metric"]))
                     logger.scaler(
@@ -336,7 +348,8 @@ class Engine(object):
                      "epoch": epoch_id},
                     self.output_dir,
                     model_name=self.config["Arch"]["name"],
-                    prefix="epoch_{}".format(epoch_id))
+                    prefix="epoch_{}".format(epoch_id),
+                    loss=self.train_loss_func)
 
             # save the latest model
             save_load.save_model(
                 self.model,
@@ -344,7 +357,8 @@ class Engine(object):
                  "epoch": epoch_id},
                 self.output_dir,
                 model_name=self.config["Arch"]["name"],
-                prefix="latest")
+                prefix="latest",
+                loss=self.train_loss_func)
 
         if self.vdl_writer is not None:
             self.vdl_writer.close()

View File

@@ -53,16 +53,22 @@ def train_epoch(engine, epoch_id, print_batch_step):
             out = forward(engine, batch)
             loss_dict = engine.train_loss_func(out, batch[1])
 
-        # step opt and lr
+        # step opt
         if engine.amp:
             scaled = engine.scaler.scale(loss_dict["loss"])
             scaled.backward()
-            engine.scaler.minimize(engine.optimizer, scaled)
+            for i in range(len(engine.optimizer)):
+                engine.scaler.minimize(engine.optimizer[i], scaled)
         else:
             loss_dict["loss"].backward()
-            engine.optimizer.step()
-        engine.optimizer.clear_grad()
-        engine.lr_sch.step()
+            for i in range(len(engine.optimizer)):
+                engine.optimizer[i].step()
+        # clear grad
+        for i in range(len(engine.optimizer)):
+            engine.optimizer[i].clear_grad()
+        # step lr
+        for i in range(len(engine.lr_sch)):
+            engine.lr_sch[i].step()
 
         # below code just for logging
         # update metric_for_logger
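Because engine.optimizer and engine.lr_sch are now lists, the update step becomes the loops shown above. Below is a minimal self-contained sketch of the same pattern in plain PaddlePaddle; the layers and hyperparameters are hypothetical, not PaddleClas code:

# Minimal sketch of the multi-optimizer update pattern, with two parameter groups.
import paddle

backbone = paddle.nn.Linear(8, 4)
head = paddle.nn.Linear(4, 2)
lr_sch = [
    paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.1, T_max=10),
    paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.01, T_max=10),
]
optimizer = [
    paddle.optimizer.Momentum(learning_rate=lr_sch[0], parameters=backbone.parameters()),
    paddle.optimizer.Momentum(learning_rate=lr_sch[1], parameters=head.parameters()),
]

x = paddle.randn([16, 8])
loss = head(backbone(x)).mean()
loss.backward()                      # one backward populates grads for all groups
for opt in optimizer:                # step every optimizer
    opt.step()
for opt in optimizer:                # then clear every optimizer's grads
    opt.clear_grad()
for sch in lr_sch:                   # then step every LR scheduler
    sch.step()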

View File

@@ -38,7 +38,10 @@ def update_loss(trainer, loss_dict, batch_size):
 
 def log_info(trainer, batch_size, epoch_id, iter_id):
-    lr_msg = "lr: {:.5f}".format(trainer.lr_sch.get_lr())
+    lr_msg = ", ".join([
+        "lr_{}: {:.8f}".format(i + 1, lr.get_lr())
+        for i, lr in enumerate(trainer.lr_sch)
+    ])
     metric_msg = ", ".join([
         "{}: {:.5f}".format(key, trainer.output_info[key].avg)
         for key in trainer.output_info
@@ -59,11 +62,12 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
         len(trainer.train_dataloader), lr_msg, metric_msg, time_msg, ips_msg,
         eta_msg))
 
-    logger.scaler(
-        name="lr",
-        value=trainer.lr_sch.get_lr(),
-        step=trainer.global_step,
-        writer=trainer.vdl_writer)
+    for i, lr in enumerate(trainer.lr_sch):
+        logger.scaler(
+            name="lr_{}".format(i + 1),
+            value=lr.get_lr(),
+            step=trainer.global_step,
+            writer=trainer.vdl_writer)
     for key in trainer.output_info:
         logger.scaler(
             name="train_{}".format(key),

View File

@@ -47,6 +47,7 @@ class CombinedLoss(nn.Layer):
                 param.keys())
             self.loss_weight.append(param.pop("weight"))
             self.loss_func.append(eval(name)(**param))
+        self.loss_func = nn.LayerList(self.loss_func)
 
     def __call__(self, input, batch):
         loss_dict = {}

View File

@@ -18,6 +18,7 @@ from __future__ import print_function
 
 import copy
 import paddle
 
+from typing import Dict, List
 from ppcls.utils import logger
@@ -44,29 +45,78 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
 
 # model_list is None in static graph
 def build_optimizer(config, epochs, step_each_epoch, model_list=None):
     config = copy.deepcopy(config)
-    # step1 build lr
-    lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
-    logger.debug("build lr ({}) success..".format(lr))
-    # step2 build regularization
-    if 'regularizer' in config and config['regularizer'] is not None:
-        if 'weight_decay' in config:
-            logger.warning(
-                "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
-            )
-        reg_config = config.pop('regularizer')
-        reg_name = reg_config.pop('name') + 'Decay'
-        reg = getattr(paddle.regularizer, reg_name)(**reg_config)
-        config["weight_decay"] = reg
-        logger.debug("build regularizer ({}) success..".format(reg))
-    # step3 build optimizer
-    optim_name = config.pop('name')
-    if 'clip_norm' in config:
-        clip_norm = config.pop('clip_norm')
-        grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
-    else:
-        grad_clip = None
-    optim = getattr(optimizer, optim_name)(learning_rate=lr,
-                                           grad_clip=grad_clip,
-                                           **config)(model_list=model_list)
-    logger.debug("build optimizer ({}) success..".format(optim))
-    return optim, lr
+    optim_config = config["Optimizer"]
+    if isinstance(optim_config, dict):
+        # convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
+        optim_name = optim_config.pop("name")
+        optim_config: List[Dict[str, Dict]] = [{
+            optim_name: {
+                'scope': "all",
+                **optim_config
+            }
+        }]
+    optim_list = []
+    lr_list = []
+    """NOTE:
+    Currently only support the optimizer scopes below.
+    1. single optimizer config.
+    2. next level under Arch, such as Arch.backbone, Arch.neck, Arch.head.
+    3. loss which has parameters, such as CenterLoss.
+    """
+    for optim_item in optim_config:
+        # optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
+        # step1 build lr
+        optim_name = list(optim_item.keys())[0]  # get optim_name
+        optim_scope = optim_item[optim_name].pop('scope')  # get optim_scope
+        optim_cfg = optim_item[optim_name]  # get optim_cfg
+
+        lr = build_lr_scheduler(optim_cfg.pop('lr'), epochs, step_each_epoch)
+        logger.debug("build lr ({}) for scope ({}) success..".format(
+            lr, optim_scope))
+        # step2 build regularization
+        if 'regularizer' in optim_cfg and optim_cfg['regularizer'] is not None:
+            if 'weight_decay' in optim_cfg:
+                logger.warning(
+                    "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
+                )
+            reg_config = optim_cfg.pop('regularizer')
+            reg_name = reg_config.pop('name') + 'Decay'
+            reg = getattr(paddle.regularizer, reg_name)(**reg_config)
+            optim_cfg["weight_decay"] = reg
+            logger.debug("build regularizer ({}) for scope ({}) success..".
+                         format(reg, optim_scope))
+        # step3 build optimizer
+        if 'clip_norm' in optim_cfg:
+            clip_norm = optim_cfg.pop('clip_norm')
+            grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
+        else:
+            grad_clip = None
+        optim_model = []
+        for i in range(len(model_list)):
+            if len(model_list[i].parameters()) == 0:
+                continue
+            if optim_scope == "all":
+                # optimizer for all
+                optim_model.append(model_list[i])
+            else:
+                if optim_scope.endswith("Loss"):
+                    # optimizer for loss
+                    for m in model_list[i].sublayers(True):
+                        if m.__class__.__name__ == optim_scope:
+                            optim_model.append(m)
                else:
+                    # optimizer for a module in the model, such as backbone, neck, head...
+                    if hasattr(model_list[i], optim_scope):
+                        optim_model.append(getattr(model_list[i], optim_scope))
+        assert len(optim_model) == 1, \
+            "Invalid optim model for optim scope({}), number of optim_model={}".format(optim_scope, len(optim_model))
+        optim = getattr(optimizer, optim_name)(
+            learning_rate=lr, grad_clip=grad_clip,
+            **optim_cfg)(model_list=optim_model)
+        logger.debug("build optimizer ({}) for scope ({}) success..".format(
+            optim, optim_scope))
+        optim_list.append(optim)
+        lr_list.append(lr)
+    return optim_list, lr_list
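The isinstance(optim_config, dict) branch above keeps existing single-optimizer configs working by wrapping them into the new list-of-dicts form under scope "all". A small sketch of that normalization, with illustrative config values:

# Sketch of the backward-compatibility normalization in build_optimizer above.
optim_config = {
    "name": "Momentum",
    "momentum": 0.9,
    "lr": {"name": "Cosine", "learning_rate": 0.1},
    "regularizer": {"name": "L2", "coeff": 1e-4},
}
if isinstance(optim_config, dict):
    optim_name = optim_config.pop("name")
    optim_config = [{optim_name: {"scope": "all", **optim_config}}]
# -> [{"Momentum": {"scope": "all", "momentum": 0.9, "lr": {...}, "regularizer": {...}}}]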

View File

@@ -18,9 +18,6 @@ from __future__ import print_function
 
 import errno
 import os
-import re
-import shutil
-import tempfile
 
 import paddle
 
 from ppcls.utils import logger
@@ -47,10 +44,15 @@ def _mkdir_if_not_exist(path):
 
 def load_dygraph_pretrain(model, path=None):
     if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
-        raise ValueError("Model pretrain path {} does not "
+        raise ValueError("Model pretrain path {}.pdparams does not "
                          "exists.".format(path))
     param_state_dict = paddle.load(path + ".pdparams")
-    model.set_dict(param_state_dict)
+    if isinstance(model, list):
+        for m in model:
+            if hasattr(m, 'set_dict'):
+                m.set_dict(param_state_dict)
+    else:
+        model.set_dict(param_state_dict)
     return
@@ -85,7 +87,7 @@ def load_distillation_model(model, pretrained_model):
                                                              pretrained_model))
 
 
-def init_model(config, net, optimizer=None):
+def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
     """
     load model from checkpoint or pretrained_model
     """
@@ -95,11 +97,15 @@ def init_model(config, net, optimizer=None):
             "Given dir {}.pdparams not exist.".format(checkpoints)
         assert os.path.exists(checkpoints + ".pdopt"), \
             "Given dir {}.pdopt not exist.".format(checkpoints)
-        para_dict = paddle.load(checkpoints + ".pdparams")
+        # load state dict
         opti_dict = paddle.load(checkpoints + ".pdopt")
+        para_dict = paddle.load(checkpoints + ".pdparams")
         metric_dict = paddle.load(checkpoints + ".pdstates")
-        net.set_dict(para_dict)
-        optimizer.set_state_dict(opti_dict)
+        # set state dict
+        net.set_state_dict(para_dict)
+        loss.set_state_dict(para_dict)
+        for i in range(len(optimizer)):
+            optimizer[i].set_state_dict(opti_dict)
         logger.info("Finish load checkpoints from {}".format(checkpoints))
         return metric_dict
@@ -120,7 +126,8 @@ def save_model(net,
                metric_info,
                model_path,
                model_name="",
-               prefix='ppcls'):
+               prefix='ppcls',
+               loss: paddle.nn.Layer=None):
     """
     save model to the target path
     """
@@ -130,7 +137,14 @@ def save_model(net,
     _mkdir_if_not_exist(model_path)
     model_path = os.path.join(model_path, prefix)
 
-    paddle.save(net.state_dict(), model_path + ".pdparams")
-    paddle.save(optimizer.state_dict(), model_path + ".pdopt")
+    params_state_dict = net.state_dict()
+    loss_state_dict = loss.state_dict()
+    keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys())
+    assert len(keys_inter) == 0, \
+        f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
+    params_state_dict.update(loss_state_dict)
+    paddle.save(params_state_dict, model_path + ".pdparams")
+    paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
     paddle.save(metric_info, model_path + ".pdstates")
     logger.info("Already save model in {}".format(model_path))