Merge pull request #1 from littletomatodonkey/me/add_pdemo

fix convert weight
pull/1925/head
cuicheng01 2022-05-17 21:32:04 +08:00 committed by GitHub
commit 1989b66044
3 changed files with 29 additions and 43 deletions


@@ -344,15 +344,15 @@ class Engine(object):
             if self.use_dali:
                 self.train_dataloader.reset()
-            metric_msg = ", ".join([
-                self.output_info[key].avg_info for key in self.output_info
-            ])
+            metric_msg = ", ".join(
+                [self.output_info[key].avg_info for key in self.output_info])
             logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                 epoch_id, self.config["Global"]["epochs"], metric_msg))
             self.output_info.clear()
             # eval model and save model if possible
-            start_eval_epoch = self.config["Global"].get("start_eval_epoch", 0) - 1
+            start_eval_epoch = self.config["Global"].get("start_eval_epoch",
+                                                         0) - 1
             if self.config["Global"][
                     "eval_during_train"] and epoch_id % self.config["Global"][
                     "eval_interval"] == 0 and epoch_id > start_eval_epoch:
@@ -367,7 +367,8 @@ class Engine(object):
                         self.output_dir,
                         model_name=self.config["Arch"]["name"],
                         prefix="best_model",
-                        loss=self.train_loss_func)
+                        loss=self.train_loss_func,
+                        save_student_model=True)
                 logger.info("[Eval][Epoch {}][best metric: {}]".format(
                     epoch_id, best_metric["metric"]))
                 logger.scaler(
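
Net effect of these two hunks: the statement reflows are cosmetic, while the new save_student_model=True argument makes the best-model checkpoint also emit a student-only weights file whenever in-train evaluation finds a new best metric. A minimal standalone sketch of the eval-trigger condition, using hypothetical config values rather than anything from this PR:

    # Sketch of the eval scheduling above, outside the Engine class.
    # Config values here are hypothetical.
    config = {"Global": {"epochs": 10, "eval_during_train": True,
                         "eval_interval": 2, "start_eval_epoch": 5}}

    start_eval_epoch = config["Global"].get("start_eval_epoch", 0) - 1
    for epoch_id in range(1, config["Global"]["epochs"] + 1):
        if config["Global"]["eval_during_train"] and epoch_id % config[
                "Global"]["eval_interval"] == 0 and epoch_id > start_eval_epoch:
            # Fires at epochs 6, 8 and 10 here; a new best metric then calls
            # save_model(..., save_student_model=True).
            print("eval at epoch", epoch_id)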


@@ -1,31 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import paddle
-
-
-def convert_distill_weights(distill_weights_path, student_weights_path):
-    assert os.path.exists(distill_weights_path), \
-        "Given distill_weights_path {} not exist.".format(distill_weights_path)
-    # Load teacher and student weights
-    all_params = paddle.load(distill_weights_path)
-    # Extract student weights
-    s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
-    # Save student weights
-    paddle.save(s_params, student_weights_path)
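
The converter above is deleted because its extraction step moves into save_model (next file). For converting an already-saved distillation checkpoint by hand, the removed function reduces to the following sketch; both paths are placeholders, not files this PR creates:

    import os

    import paddle

    distill_weights_path = "output/best_model.pdparams"          # placeholder
    student_weights_path = "output/best_model_student.pdparams"  # placeholder

    assert os.path.exists(distill_weights_path), \
        "Given distill_weights_path {} not exist.".format(distill_weights_path)
    # Load the combined teacher + student checkpoint.
    all_params = paddle.load(distill_weights_path)
    # Keep only the "Student."-prefixed entries, stripping the prefix.
    s_params = {key[len("Student."):]: all_params[key]
                for key in all_params if "Student." in key}
    paddle.save(s_params, student_weights_path)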


@@ -42,6 +42,14 @@ def _mkdir_if_not_exist(path):
         raise OSError('Failed to mkdir {}'.format(path))
 
 
+def _extract_student_weights(all_params, student_prefix="Student."):
+    s_params = {
+        key[len(student_prefix):]: all_params[key]
+        for key in all_params if student_prefix in key
+    }
+    return s_params
+
+
 def load_dygraph_pretrain(model, path=None):
     if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
         raise ValueError("Model pretrain path {}.pdparams does not "
@@ -117,7 +125,7 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
         else:  # common load
             load_dygraph_pretrain(net, path=pretrained_model)
             logger.info("Finish load pretrained model from {}".format(
                 pretrained_model))
 
 
 def save_model(net,
@@ -126,7 +134,8 @@ def save_model(net,
                model_path,
                model_name="",
                prefix='ppcls',
-               loss: paddle.nn.Layer=None):
+               loss: paddle.nn.Layer=None,
+               save_student_model=False):
     """
     save model to the target path
     """
@@ -137,11 +146,18 @@ def save_model(net,
     model_path = os.path.join(model_path, prefix)
 
     params_state_dict = net.state_dict()
-    loss_state_dict = loss.state_dict()
-    keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys())
-    assert len(keys_inter) == 0, \
-        f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
-    params_state_dict.update(loss_state_dict)
+    if loss is not None:
+        loss_state_dict = loss.state_dict()
+        keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys(
+        ))
+        assert len(keys_inter) == 0, \
+            f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
+        params_state_dict.update(loss_state_dict)
+
+    if save_student_model:
+        s_params = _extract_student_weights(params_state_dict)
+        if len(s_params) > 0:
+            paddle.save(s_params, model_path + "_student.pdparams")
 
     paddle.save(params_state_dict, model_path + ".pdparams")
     paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
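
With the flag in place, save_model(..., save_student_model=True) writes <prefix>_student.pdparams next to the regular <prefix>.pdparams, so no separate conversion tool is needed. A toy illustration of what _extract_student_weights returns; the state dict below is a stand-in, not PaddleClas code:

    import paddle

    def _extract_student_weights(all_params, student_prefix="Student."):
        # Same dict comprehension as in the hunk above.
        return {key[len(student_prefix):]: all_params[key]
                for key in all_params if student_prefix in key}

    # Stand-in parameter names mimicking a distillation model.
    all_params = {
        "Student.fc.weight": paddle.zeros([4, 2]),
        "Student.fc.bias": paddle.zeros([2]),
        "Teacher.fc.weight": paddle.zeros([4, 2]),
    }
    s_params = _extract_student_weights(all_params)
    print(sorted(s_params))  # ['fc.bias', 'fc.weight']; teacher keys are dropped

Note the membership test is student_prefix in key, so the helper assumes the prefix only ever occurs at the start of a parameter name.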