fix save load
parent
6f631e4340
commit
afafb8f41d
|
@ -344,15 +344,15 @@ class Engine(object):
|
|||
|
||||
if self.use_dali:
|
||||
self.train_dataloader.reset()
|
||||
metric_msg = ", ".join([
|
||||
self.output_info[key].avg_info for key in self.output_info
|
||||
])
|
||||
metric_msg = ", ".join(
|
||||
[self.output_info[key].avg_info for key in self.output_info])
|
||||
logger.info("[Train][Epoch {}/{}][Avg]{}".format(
|
||||
epoch_id, self.config["Global"]["epochs"], metric_msg))
|
||||
self.output_info.clear()
|
||||
|
||||
# eval model and save model if possible
|
||||
start_eval_epoch = self.config["Global"].get("start_eval_epoch", 0) - 1
|
||||
start_eval_epoch = self.config["Global"].get("start_eval_epoch",
|
||||
0) - 1
|
||||
if self.config["Global"][
|
||||
"eval_during_train"] and epoch_id % self.config["Global"][
|
||||
"eval_interval"] == 0 and epoch_id > start_eval_epoch:
|
||||
|
@ -367,7 +367,8 @@ class Engine(object):
|
|||
self.output_dir,
|
||||
model_name=self.config["Arch"]["name"],
|
||||
prefix="best_model",
|
||||
loss=self.train_loss_func)
|
||||
loss=self.train_loss_func,
|
||||
save_student_model=True)
|
||||
logger.info("[Eval][Epoch {}][best metric: {}]".format(
|
||||
epoch_id, best_metric["metric"]))
|
||||
logger.scaler(
|
||||
|
|
|
@ -1,42 +0,0 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# BUG FIX: `==` was a comparison (a no-op expression statement), so
# `__all__` was never actually assigned; use `=` to define the public API.
__all__ = ["extract_subnet_weights"]
|
||||
|
||||
import os
|
||||
import paddle
|
||||
|
||||
|
||||
def extract_subnet_weights(distill_weights_path,
                           student_weights_path,
                           student_name="Student"):
    """Extract the student sub-network weights from a distillation checkpoint.

    Args:
        distill_weights_path (str): path to the combined teacher+student
            checkpoint saved by ``paddle.save``.
        student_weights_path (str): destination path for the extracted
            student-only weights.
        student_name (str): top-level prefix of the student parameters in
            the checkpoint (default ``"Student"``).

    Raises:
        AssertionError: if ``distill_weights_path`` does not exist or no
            student parameter is found under ``student_name``.
    """
    assert os.path.exists(distill_weights_path), \
        "Given distill_weights_path {} not exist.".format(distill_weights_path)
    # Load teacher and student weights
    all_params = paddle.load(distill_weights_path)
    # Extract student weights. Use startswith rather than a substring test:
    # `prefix in key` would also match keys like "backbone.Student.x" and
    # then strip the wrong leading characters.
    student_prefix = student_name + "."
    s_params = {
        key[len(student_prefix):]: all_params[key]
        for key in all_params if key.startswith(student_prefix)
    }
    assert len(
        s_params
    ) > 0, f"extracted params length must be > 0 but got {len(s_params)}"
    # Save subnet weights
    paddle.save(s_params, student_weights_path)
|
|
@ -42,6 +42,14 @@ def _mkdir_if_not_exist(path):
|
|||
raise OSError('Failed to mkdir {}'.format(path))
|
||||
|
||||
|
||||
def _extract_student_weights(all_params, student_prefix="Student."):
|
||||
s_params = {
|
||||
key[len(student_prefix):]: all_params[key]
|
||||
for key in all_params if student_prefix in key
|
||||
}
|
||||
return s_params
|
||||
|
||||
|
||||
def load_dygraph_pretrain(model, path=None):
|
||||
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
|
||||
raise ValueError("Model pretrain path {}.pdparams does not "
|
||||
|
@ -117,7 +125,7 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
|
|||
else: # common load
|
||||
load_dygraph_pretrain(net, path=pretrained_model)
|
||||
logger.info("Finish load pretrained model from {}".format(
|
||||
pretrained_model))
|
||||
pretrained_model))
|
||||
|
||||
|
||||
def save_model(net,
|
||||
|
@ -126,7 +134,8 @@ def save_model(net,
|
|||
model_path,
|
||||
model_name="",
|
||||
prefix='ppcls',
|
||||
loss: paddle.nn.Layer=None):
|
||||
loss: paddle.nn.Layer=None,
|
||||
save_student_model=False):
|
||||
"""
|
||||
save model to the target path
|
||||
"""
|
||||
|
@ -137,11 +146,18 @@ def save_model(net,
|
|||
model_path = os.path.join(model_path, prefix)
|
||||
|
||||
params_state_dict = net.state_dict()
|
||||
loss_state_dict = loss.state_dict()
|
||||
keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys())
|
||||
assert len(keys_inter) == 0, \
|
||||
f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
|
||||
params_state_dict.update(loss_state_dict)
|
||||
if loss is not None:
|
||||
loss_state_dict = loss.state_dict()
|
||||
keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys(
|
||||
))
|
||||
assert len(keys_inter) == 0, \
|
||||
f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
|
||||
params_state_dict.update(loss_state_dict)
|
||||
|
||||
if save_student_model:
|
||||
s_params = _extract_student_weights(params_state_dict)
|
||||
if len(s_params) > 0:
|
||||
paddle.save(s_params, model_path + "_student.pdparams")
|
||||
|
||||
paddle.save(params_state_dict, model_path + ".pdparams")
|
||||
paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
|
||||
|
|
Loading…
Reference in New Issue