mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
remove test for textsnake (#8)
This commit is contained in:
parent
923ee16bf7
commit
c345255e49
@ -316,53 +316,54 @@ def test_dbnet(cfg_file):
|
|||||||
detector.show_result(img, results)
|
detector.show_result(img, results)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
# @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
|
||||||
@pytest.mark.parametrize(
|
# @pytest.mark.parametrize(
|
||||||
'cfg_file', ['textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py'])
|
# 'cfg_file', ['textdet/textsnake/'
|
||||||
def test_textsnake(cfg_file):
|
# 'textsnake_r50_fpn_unet_1200e_ctw1500.py'])
|
||||||
model = _get_detector_cfg(cfg_file)
|
# def test_textsnake(cfg_file):
|
||||||
model['pretrained'] = None
|
# model = _get_detector_cfg(cfg_file)
|
||||||
model['backbone']['norm_cfg']['type'] = 'BN'
|
# model['pretrained'] = None
|
||||||
|
# model['backbone']['norm_cfg']['type'] = 'BN'
|
||||||
|
|
||||||
from mmocr.models import build_detector
|
# from mmocr.models import build_detector
|
||||||
detector = build_detector(model)
|
# detector = build_detector(model)
|
||||||
detector = detector.cuda()
|
# detector = detector.cuda()
|
||||||
input_shape = (1, 3, 64, 64)
|
# input_shape = (1, 3, 64, 64)
|
||||||
num_kernels = 1
|
# num_kernels = 1
|
||||||
mm_inputs = _demo_mm_inputs(num_kernels, input_shape)
|
# mm_inputs = _demo_mm_inputs(num_kernels, input_shape)
|
||||||
|
|
||||||
imgs = mm_inputs.pop('imgs')
|
# imgs = mm_inputs.pop('imgs')
|
||||||
imgs = imgs.cuda()
|
# imgs = imgs.cuda()
|
||||||
img_metas = mm_inputs.pop('img_metas')
|
# img_metas = mm_inputs.pop('img_metas')
|
||||||
gt_text_mask = mm_inputs.pop('gt_text_mask')
|
# gt_text_mask = mm_inputs.pop('gt_text_mask')
|
||||||
gt_center_region_mask = mm_inputs.pop('gt_center_region_mask')
|
# gt_center_region_mask = mm_inputs.pop('gt_center_region_mask')
|
||||||
gt_mask = mm_inputs.pop('gt_mask')
|
# gt_mask = mm_inputs.pop('gt_mask')
|
||||||
gt_radius_map = mm_inputs.pop('gt_radius_map')
|
# gt_radius_map = mm_inputs.pop('gt_radius_map')
|
||||||
gt_sin_map = mm_inputs.pop('gt_sin_map')
|
# gt_sin_map = mm_inputs.pop('gt_sin_map')
|
||||||
gt_cos_map = mm_inputs.pop('gt_cos_map')
|
# gt_cos_map = mm_inputs.pop('gt_cos_map')
|
||||||
|
|
||||||
# Test forward train
|
# # Test forward train
|
||||||
losses = detector.forward(
|
# losses = detector.forward(
|
||||||
imgs,
|
# imgs,
|
||||||
img_metas,
|
# img_metas,
|
||||||
gt_text_mask=gt_text_mask,
|
# gt_text_mask=gt_text_mask,
|
||||||
gt_center_region_mask=gt_center_region_mask,
|
# gt_center_region_mask=gt_center_region_mask,
|
||||||
gt_mask=gt_mask,
|
# gt_mask=gt_mask,
|
||||||
gt_radius_map=gt_radius_map,
|
# gt_radius_map=gt_radius_map,
|
||||||
gt_sin_map=gt_sin_map,
|
# gt_sin_map=gt_sin_map,
|
||||||
gt_cos_map=gt_cos_map)
|
# gt_cos_map=gt_cos_map)
|
||||||
assert isinstance(losses, dict)
|
# assert isinstance(losses, dict)
|
||||||
|
|
||||||
# Test forward test
|
# # Test forward test
|
||||||
with torch.no_grad():
|
# with torch.no_grad():
|
||||||
img_list = [g[None, :] for g in imgs]
|
# img_list = [g[None, :] for g in imgs]
|
||||||
batch_results = []
|
# batch_results = []
|
||||||
for one_img, one_meta in zip(img_list, img_metas):
|
# for one_img, one_meta in zip(img_list, img_metas):
|
||||||
result = detector.forward([one_img], [[one_meta]],
|
# result = detector.forward([one_img], [[one_meta]],
|
||||||
return_loss=False)
|
# return_loss=False)
|
||||||
batch_results.append(result)
|
# batch_results.append(result)
|
||||||
|
|
||||||
# Test show result
|
# # Test show result
|
||||||
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
|
# results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
|
||||||
img = np.random.rand(5, 5)
|
# img = np.random.rand(5, 5)
|
||||||
detector.show_result(img, results)
|
# detector.show_result(img, results)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user