From 6a38af589028fe38cd8c165598519ce1f8d8ed45 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Thu, 28 Jan 2021 11:03:58 +0800 Subject: [PATCH] fix starnet export --- tools/export_model.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tools/export_model.py b/tools/export_model.py index b7d61a59f..a2428bf72 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -47,14 +47,18 @@ def main(): char_num = len(getattr(post_process_class, 'character')) config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - init_model(config, model, logger) + # init_model(config, model, logger) model.eval() save_path = '{}/inference'.format(config['Global']['save_inference_dir']) - infer_shape = [3, -1, -1] + infer_shape = [3, -1, -1] if config['Architecture']['model_type'] == "rec": - infer_shape = [3, 32, -1] + infer_shape = [3, 32, -1] + if 'Transform' in config['Architecture'] and config['Architecture'][ + 'Transform'] is not None and config['Architecture'][ + 'Transform']['name'] == 'TPS': + infer_shape[-1] = 100 model = to_static( model,