fix starnet export

pull/1850/head
WenmuZhou 2021-01-28 11:03:58 +08:00
parent 25bf92295f
commit 6a38af5890
1 changed files with 7 additions and 3 deletions

View File

@ -47,14 +47,18 @@ def main():
char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
init_model(config, model, logger)
# init_model(config, model, logger)
model.eval()
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
infer_shape = [3, -1, -1]
infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec":
infer_shape = [3, 32, -1]
infer_shape = [3, 32, -1]
if 'Transform' in config['Architecture'] and config['Architecture'][
'Transform'] is not None and config['Architecture'][
'Transform']['name'] == 'TPS':
infer_shape[-1] = 100
model = to_static(
model,