mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
convert fp16 params to fp32 when params is fp16 format
This commit is contained in:
parent
d0efcc74c9
commit
c6738f4c53
@ -53,6 +53,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
|
||||
checkpoints = global_config.get('checkpoints')
|
||||
pretrained_model = global_config.get('pretrained_model')
|
||||
best_model_dict = {}
|
||||
is_float16 = False
|
||||
|
||||
if model_type == 'vqa':
|
||||
checkpoints = config['Architecture']['Backbone']['checkpoints']
|
||||
@ -90,7 +91,6 @@ def load_model(config, model, optimizer=None, model_type='det'):
|
||||
params = paddle.load(checkpoints + '.pdparams')
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
is_float16 = False
|
||||
for key, value in state_dict.items():
|
||||
if key not in params:
|
||||
logger.warning("{} not in loaded params {} !".format(
|
||||
@ -107,7 +107,6 @@ def load_model(config, model, optimizer=None, model_type='det'):
|
||||
"The shape of model params {} {} not matched with loaded params shape {} !".
|
||||
format(key, value.shape, pre_value.shape))
|
||||
model.set_state_dict(new_state_dict)
|
||||
|
||||
if is_float16:
|
||||
logger.info(
|
||||
"The parameter type is float16, which is converted to float32 when loading"
|
||||
@ -130,9 +129,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
|
||||
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
|
||||
logger.info("resume from {}".format(checkpoints))
|
||||
elif pretrained_model:
|
||||
load_pretrained_params(model, pretrained_model)
|
||||
is_float16 = load_pretrained_params(model, pretrained_model)
|
||||
else:
|
||||
logger.info('train from scratch')
|
||||
best_model_dict['is_float16'] = is_float16
|
||||
return best_model_dict
|
||||
|
||||
|
||||
@ -167,7 +167,7 @@ def load_pretrained_params(model, path):
|
||||
"The parameter type is float16, which is converted to float32 when loading"
|
||||
)
|
||||
logger.info("load pretrain successful from {}".format(path))
|
||||
return model
|
||||
return is_float16
|
||||
|
||||
|
||||
def save_model(model,
|
||||
|
Loading…
x
Reference in New Issue
Block a user