[Fix] Fix image export in test.py for KIE models (#486)

* Fix image export in test.py for the SDMGR model

* Fix `pretrained` handling in `init_detector`
Tong Gao 2021-09-18 18:24:55 +08:00 committed by GitHub
parent 38bdc10f22
commit 8c72d80164
3 changed files with 91 additions and 4 deletions


@@ -22,7 +22,10 @@ test_pipeline = [
     dict(type='Normalize', **img_norm_cfg),
     dict(type='Pad', size_divisor=32),
     dict(type='KIEFormatBundle'),
-    dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'])
+    dict(
+        type='Collect',
+        keys=['img', 'relations', 'texts', 'gt_bboxes'],
+        meta_keys=['img_norm_cfg', 'img_shape', 'ori_filename'])
 ]
 dataset_type = 'KIEDataset'
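The three meta keys added to `Collect` are exactly what the export path in `tools/test.py` (further down in this diff) consumes: `img_norm_cfg` lets `tensor2imgs` undo the `Normalize` step, `img_shape` crops away the border introduced by `Pad`, and `ori_filename` names the file written under `--show-dir`. A minimal sketch of that consumption, with placeholder values for the normalization constants, shapes, and file name:

import torch
from mmcv.image import tensor2imgs

# Placeholder meta dict of the kind Collect now attaches to each sample.
meta = dict(
    img_norm_cfg=dict(mean=[123.675, 116.28, 103.53],
                      std=[58.395, 57.12, 57.375], to_rgb=True),
    img_shape=(448, 448, 3),                 # shape before Pad rounded it up
    ori_filename='wildreceipt/sample.jpeg')  # hypothetical file name

img_tensor = torch.zeros(1, 3, 480, 480)     # normalized, padded batch of one
img = tensor2imgs(img_tensor, **meta['img_norm_cfg'])[0]  # undo Normalize
h, w, _ = meta['img_shape']
img_show = img[:h, :w, :]                    # strip the Pad border
# The visualisation is then written to osp.join(out_dir, meta['ori_filename']).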


@@ -35,7 +35,8 @@ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
                         f'but got {type(config)}')
     if cfg_options is not None:
         config.merge_from_dict(cfg_options)
-    config.model.pretrained = None
+    if config.model.get('pretrained'):
+        config.model.pretrained = None
     config.model.train_cfg = None
     model = build_detector(config.model, test_cfg=config.get('test_cfg'))
     if checkpoint is not None:
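The guard matters because not every model config declares a `pretrained` field; assigning the attribute unconditionally injects `pretrained=None` into such configs, which presumably breaks `build_detector` for model classes (such as the KIE models) that do not accept that argument. With the `.get()` check the field is only cleared when the config already defines it. A small sketch with hypothetical configs:

from mmcv import Config

# Hypothetical config without a 'pretrained' key (e.g. a KIE model): left untouched.
kie_cfg = Config(dict(model=dict(type='SDMGR')))
if kie_cfg.model.get('pretrained'):
    kie_cfg.model.pretrained = None           # never reached
print('pretrained' in kie_cfg.model)           # False, so build_detector sees no extra kwarg

# Hypothetical config that does point at pretrained weights: cleared for testing,
# since the checkpoint loaded afterwards already contains the trained weights.
det_cfg = Config(dict(model=dict(type='DBNet',
                                 pretrained='torchvision://resnet18')))
if det_cfg.model.get('pretrained'):
    det_cfg.model.pretrained = None
print(det_cfg.model.pretrained)                # None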


@@ -2,16 +2,19 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
 import os
+import os.path as osp
 import warnings
 import mmcv
 import torch
 from mmcv import Config, DictAction
 from mmcv.cnn import fuse_conv_bn
+from mmcv.image import tensor2imgs
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                          wrap_fp16_model)
-from mmdet.apis import multi_gpu_test, single_gpu_test
+from mmdet.apis import multi_gpu_test
+from mmdet.core import encode_mask_results
 from mmdet.datasets import replace_ImageToTensor
 from mmocr.apis.inference import disable_text_recog_aug_test
@@ -103,6 +106,85 @@ def parse_args():
     return args
+
+
+def single_gpu_test(model,
+                    data_loader,
+                    show=False,
+                    out_dir=None,
+                    is_kie=False,
+                    show_score_thr=0.3):
+    model.eval()
+    results = []
+    dataset = data_loader.dataset
+    prog_bar = mmcv.ProgressBar(len(dataset))
+    for i, data in enumerate(data_loader):
+        with torch.no_grad():
+            result = model(return_loss=False, rescale=True, **data)
+        batch_size = len(result)
+        if show or out_dir:
+            if is_kie:
+                img_tensor = data['img'].data[0]
+                if img_tensor.shape[0] != 1:
+                    raise KeyError('Visualizing KIE outputs in batches is '
+                                   'currently not supported.')
+                gt_bboxes = data['gt_bboxes'].data[0]
+                img_metas = data['img_metas'].data[0]
+                imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+                for i, img in enumerate(imgs):
+                    h, w, _ = img_metas[i]['img_shape']
+                    img_show = img[:h, :w, :]
+                    if out_dir:
+                        out_file = osp.join(out_dir,
+                                            img_metas[i]['ori_filename'])
+                    else:
+                        out_file = None
+                    model.module.show_result(
+                        img_show,
+                        result[i],
+                        gt_bboxes[i],
+                        show=show,
+                        out_file=out_file)
+            else:
+                if batch_size == 1 and isinstance(data['img'][0],
+                                                  torch.Tensor):
+                    img_tensor = data['img'][0]
+                else:
+                    img_tensor = data['img'][0].data[0]
+                img_metas = data['img_metas'][0].data[0]
+                imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+                assert len(imgs) == len(img_metas)
+                for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
+                    h, w, _ = img_meta['img_shape']
+                    img_show = img[:h, :w, :]
+                    ori_h, ori_w = img_meta['ori_shape'][:-1]
+                    img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+                    if out_dir:
+                        out_file = osp.join(out_dir, img_meta['ori_filename'])
+                    else:
+                        out_file = None
+                    model.module.show_result(
+                        img_show,
+                        result[i],
+                        show=show,
+                        out_file=out_file,
+                        score_thr=show_score_thr)
+        # encode mask results
+        if isinstance(result[0], tuple):
+            result = [(bbox_results, encode_mask_results(mask_results))
+                      for bbox_results, mask_results in result]
+        results.extend(result)
+        for _ in range(batch_size):
+            prog_bar.update()
+    return results
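For reference, a hand-built toy batch in the shape the `is_kie` branch above expects; in a real run these fields come from the KIE test pipeline and the dataloader's collate, so the exact layout, and the box format in particular, is an assumption for illustration only.

import torch
from mmcv.parallel import DataContainer as DC

meta = dict(
    img_norm_cfg=dict(mean=[0., 0., 0.], std=[1., 1., 1.], to_rgb=True),
    img_shape=(32, 32, 3),
    ori_filename='toy.jpg')
data = dict(
    img=DC([torch.zeros(1, 3, 32, 32)]),   # single image: batched KIE viz is unsupported
    img_metas=DC([[meta]], cpu_only=True),
    gt_bboxes=DC([[torch.zeros(4, 4)]]))   # four dummy boxes; format depends on the dataset

img_tensor = data['img'].data[0]            # (1, 3, H, W), as read by the KIE branch
img_metas = data['img_metas'].data[0]       # one meta dict per image
gt_bboxes = data['gt_bboxes'].data[0]       # one box tensor per image
assert img_tensor.shape[0] == 1             # mirrors the batch-size check above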
 def main():
     args = parse_args()
@@ -209,8 +291,9 @@ def main():
     if not distributed:
         model = MMDataParallel(model, device_ids=[0])
+        is_kie = cfg.model.type in ['SDMGR']
         outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
-                                  args.show_score_thr)
+                                  is_kie, args.show_score_thr)
     else:
         model = MMDistributedDataParallel(
             model.cuda(),