mirror of https://github.com/open-mmlab/mmocr.git
[Fix] Fix image export in test.py for KIE models (#486)
* Fix image export in test.py for sdmgr model * fix pretrainedpull/509/head
parent
38bdc10f22
commit
8c72d80164
|
@ -22,7 +22,10 @@ test_pipeline = [
|
|||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size_divisor=32),
|
||||
dict(type='KIEFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'])
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img', 'relations', 'texts', 'gt_bboxes'],
|
||||
meta_keys=['img_norm_cfg', 'img_shape', 'ori_filename'])
|
||||
]
|
||||
|
||||
dataset_type = 'KIEDataset'
|
||||
|
|
|
@ -35,7 +35,8 @@ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
|
|||
f'but got {type(config)}')
|
||||
if cfg_options is not None:
|
||||
config.merge_from_dict(cfg_options)
|
||||
config.model.pretrained = None
|
||||
if config.model.get('pretrained'):
|
||||
config.model.pretrained = None
|
||||
config.model.train_cfg = None
|
||||
model = build_detector(config.model, test_cfg=config.get('test_cfg'))
|
||||
if checkpoint is not None:
|
||||
|
|
|
@ -2,16 +2,19 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os
|
||||
import os.path as osp
|
||||
import warnings
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
from mmcv import Config, DictAction
|
||||
from mmcv.cnn import fuse_conv_bn
|
||||
from mmcv.image import tensor2imgs
|
||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
|
||||
wrap_fp16_model)
|
||||
from mmdet.apis import multi_gpu_test, single_gpu_test
|
||||
from mmdet.apis import multi_gpu_test
|
||||
from mmdet.core import encode_mask_results
|
||||
from mmdet.datasets import replace_ImageToTensor
|
||||
|
||||
from mmocr.apis.inference import disable_text_recog_aug_test
|
||||
|
@ -103,6 +106,85 @@ def parse_args():
|
|||
return args
|
||||
|
||||
|
||||
def single_gpu_test(model,
|
||||
data_loader,
|
||||
show=False,
|
||||
out_dir=None,
|
||||
is_kie=False,
|
||||
show_score_thr=0.3):
|
||||
model.eval()
|
||||
results = []
|
||||
dataset = data_loader.dataset
|
||||
prog_bar = mmcv.ProgressBar(len(dataset))
|
||||
for i, data in enumerate(data_loader):
|
||||
with torch.no_grad():
|
||||
result = model(return_loss=False, rescale=True, **data)
|
||||
|
||||
batch_size = len(result)
|
||||
if show or out_dir:
|
||||
if is_kie:
|
||||
img_tensor = data['img'].data[0]
|
||||
if img_tensor.shape[0] != 1:
|
||||
raise KeyError('Visualizing KIE outputs in batches is'
|
||||
'currently not supported.')
|
||||
gt_bboxes = data['gt_bboxes'].data[0]
|
||||
img_metas = data['img_metas'].data[0]
|
||||
imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
|
||||
for i, img in enumerate(imgs):
|
||||
h, w, _ = img_metas[i]['img_shape']
|
||||
img_show = img[:h, :w, :]
|
||||
if out_dir:
|
||||
out_file = osp.join(out_dir,
|
||||
img_metas[i]['ori_filename'])
|
||||
else:
|
||||
out_file = None
|
||||
|
||||
model.module.show_result(
|
||||
img_show,
|
||||
result[i],
|
||||
gt_bboxes[i],
|
||||
show=show,
|
||||
out_file=out_file)
|
||||
else:
|
||||
if batch_size == 1 and isinstance(data['img'][0],
|
||||
torch.Tensor):
|
||||
img_tensor = data['img'][0]
|
||||
else:
|
||||
img_tensor = data['img'][0].data[0]
|
||||
img_metas = data['img_metas'][0].data[0]
|
||||
imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
|
||||
assert len(imgs) == len(img_metas)
|
||||
|
||||
for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
|
||||
h, w, _ = img_meta['img_shape']
|
||||
img_show = img[:h, :w, :]
|
||||
|
||||
ori_h, ori_w = img_meta['ori_shape'][:-1]
|
||||
img_show = mmcv.imresize(img_show, (ori_w, ori_h))
|
||||
|
||||
if out_dir:
|
||||
out_file = osp.join(out_dir, img_meta['ori_filename'])
|
||||
else:
|
||||
out_file = None
|
||||
|
||||
model.module.show_result(
|
||||
img_show,
|
||||
result[i],
|
||||
show=show,
|
||||
out_file=out_file,
|
||||
score_thr=show_score_thr)
|
||||
|
||||
# encode mask results
|
||||
if isinstance(result[0], tuple):
|
||||
result = [(bbox_results, encode_mask_results(mask_results))
|
||||
for bbox_results, mask_results in result]
|
||||
results.extend(result)
|
||||
|
||||
for _ in range(batch_size):
|
||||
prog_bar.update()
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
|
@ -209,8 +291,9 @@ def main():
|
|||
|
||||
if not distributed:
|
||||
model = MMDataParallel(model, device_ids=[0])
|
||||
is_kie = cfg.model.type in ['SDMGR']
|
||||
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
|
||||
args.show_score_thr)
|
||||
is_kie, args.show_score_thr)
|
||||
else:
|
||||
model = MMDistributedDataParallel(
|
||||
model.cuda(),
|
||||
|
|
Loading…
Reference in New Issue