Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into tttt
commit
3bed2e1f22
|
@ -53,6 +53,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
|
|||
checkpoints = global_config.get('checkpoints')
|
||||
pretrained_model = global_config.get('pretrained_model')
|
||||
best_model_dict = {}
|
||||
is_float16 = False
|
||||
|
||||
if model_type == 'vqa':
|
||||
# NOTE: for vqa model, resume training is not supported now
|
||||
|
@ -100,6 +101,9 @@ def load_model(config, model, optimizer=None, model_type='det'):
|
|||
key, params.keys()))
|
||||
continue
|
||||
pre_value = params[key]
|
||||
if pre_value.dtype == paddle.float16:
|
||||
pre_value = pre_value.astype(paddle.float32)
|
||||
is_float16 = True
|
||||
if list(value.shape) == list(pre_value.shape):
|
||||
new_state_dict[key] = pre_value
|
||||
else:
|
||||
|
@ -107,7 +111,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
|
|||
"The shape of model params {} {} not matched with loaded params shape {} !".
|
||||
format(key, value.shape, pre_value.shape))
|
||||
model.set_state_dict(new_state_dict)
|
||||
|
||||
if is_float16:
|
||||
logger.info(
|
||||
"The parameter type is float16, which is converted to float32 when loading"
|
||||
)
|
||||
if optimizer is not None:
|
||||
if os.path.exists(checkpoints + '.pdopt'):
|
||||
optim_dict = paddle.load(checkpoints + '.pdopt')
|
||||
|
@ -126,9 +133,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
|
|||
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
|
||||
logger.info("resume from {}".format(checkpoints))
|
||||
elif pretrained_model:
|
||||
load_pretrained_params(model, pretrained_model)
|
||||
is_float16 = load_pretrained_params(model, pretrained_model)
|
||||
else:
|
||||
logger.info('train from scratch')
|
||||
best_model_dict['is_float16'] = is_float16
|
||||
return best_model_dict
|
||||
|
||||
|
||||
|
@ -142,19 +150,28 @@ def load_pretrained_params(model, path):
|
|||
params = paddle.load(path + '.pdparams')
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
is_float16 = False
|
||||
for k1 in params.keys():
|
||||
if k1 not in state_dict.keys():
|
||||
logger.warning("The pretrained params {} not in model".format(k1))
|
||||
else:
|
||||
if params[k1].dtype == paddle.float16:
|
||||
params[k1] = params[k1].astype(paddle.float32)
|
||||
is_float16 = True
|
||||
if list(state_dict[k1].shape) == list(params[k1].shape):
|
||||
new_state_dict[k1] = params[k1]
|
||||
else:
|
||||
logger.warning(
|
||||
"The shape of model params {} {} not matched with loaded params {} {} !".
|
||||
format(k1, state_dict[k1].shape, k1, params[k1].shape))
|
||||
|
||||
model.set_state_dict(new_state_dict)
|
||||
if is_float16:
|
||||
logger.info(
|
||||
"The parameter type is float16, which is converted to float32 when loading"
|
||||
)
|
||||
logger.info("load pretrain successful from {}".format(path))
|
||||
return model
|
||||
return is_float16
|
||||
|
||||
|
||||
def save_model(model,
|
||||
|
|
|
@ -6,7 +6,7 @@ Global.use_gpu:True|True
|
|||
Global.auto_cast:fp32
|
||||
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
|
||||
Global.save_model_dir:./output/
|
||||
Train.loader.batch_size_per_card:lite_train_lite_infer=8|whole_train_whole_infer=8
|
||||
Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
|
||||
Architecture.Backbone.checkpoints:null
|
||||
train_model_name:latest
|
||||
train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
|
||||
|
|
|
@ -160,18 +160,18 @@ def to_float32(preds):
|
|||
for k in preds:
|
||||
if isinstance(preds[k], dict) or isinstance(preds[k], list):
|
||||
preds[k] = to_float32(preds[k])
|
||||
elif isinstance(preds[k], paddle.Tensor):
|
||||
preds[k] = preds[k].astype(paddle.float32)
|
||||
else:
|
||||
preds[k] = paddle.to_tensor(preds[k], dtype='float32')
|
||||
elif isinstance(preds, list):
|
||||
for k in range(len(preds)):
|
||||
if isinstance(preds[k], dict):
|
||||
preds[k] = to_float32(preds[k])
|
||||
elif isinstance(preds[k], list):
|
||||
preds[k] = to_float32(preds[k])
|
||||
elif isinstance(preds[k], paddle.Tensor):
|
||||
preds[k] = preds[k].astype(paddle.float32)
|
||||
elif isinstance(preds[k], paddle.Tensor):
|
||||
preds = preds.astype(paddle.float32)
|
||||
else:
|
||||
preds[k] = paddle.to_tensor(preds[k], dtype='float32')
|
||||
else:
|
||||
preds = paddle.to_tensor(preds, dtype='float32')
|
||||
return preds
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue