fix starnet export
parent
25bf92295f
commit
6a38af5890
|
@ -47,14 +47,18 @@ def main():
|
|||
char_num = len(getattr(post_process_class, 'character'))
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
model = build_model(config['Architecture'])
|
||||
init_model(config, model, logger)
|
||||
# init_model(config, model, logger)
|
||||
model.eval()
|
||||
|
||||
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
||||
|
||||
infer_shape = [3, -1, -1]
|
||||
infer_shape = [3, -1, -1]
|
||||
if config['Architecture']['model_type'] == "rec":
|
||||
infer_shape = [3, 32, -1]
|
||||
infer_shape = [3, 32, -1]
|
||||
if 'Transform' in config['Architecture'] and config['Architecture'][
|
||||
'Transform'] is not None and config['Architecture'][
|
||||
'Transform']['name'] == 'TPS':
|
||||
infer_shape[-1] = 100
|
||||
|
||||
model = to_static(
|
||||
model,
|
||||
|
|
Loading…
Reference in New Issue