remove test for textsnake (#8)

This commit is contained in:
Hongbin Sun 2021-04-08 11:07:22 +08:00 committed by GitHub
parent 923ee16bf7
commit c345255e49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -316,53 +316,54 @@ def test_dbnet(cfg_file):
detector.show_result(img, results) detector.show_result(img, results)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') # @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.parametrize( # @pytest.mark.parametrize(
'cfg_file', ['textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py']) # 'cfg_file', ['textdet/textsnake/'
def test_textsnake(cfg_file): # 'textsnake_r50_fpn_unet_1200e_ctw1500.py'])
model = _get_detector_cfg(cfg_file) # def test_textsnake(cfg_file):
model['pretrained'] = None # model = _get_detector_cfg(cfg_file)
model['backbone']['norm_cfg']['type'] = 'BN' # model['pretrained'] = None
# model['backbone']['norm_cfg']['type'] = 'BN'
from mmocr.models import build_detector # from mmocr.models import build_detector
detector = build_detector(model) # detector = build_detector(model)
detector = detector.cuda() # detector = detector.cuda()
input_shape = (1, 3, 64, 64) # input_shape = (1, 3, 64, 64)
num_kernels = 1 # num_kernels = 1
mm_inputs = _demo_mm_inputs(num_kernels, input_shape) # mm_inputs = _demo_mm_inputs(num_kernels, input_shape)
imgs = mm_inputs.pop('imgs') # imgs = mm_inputs.pop('imgs')
imgs = imgs.cuda() # imgs = imgs.cuda()
img_metas = mm_inputs.pop('img_metas') # img_metas = mm_inputs.pop('img_metas')
gt_text_mask = mm_inputs.pop('gt_text_mask') # gt_text_mask = mm_inputs.pop('gt_text_mask')
gt_center_region_mask = mm_inputs.pop('gt_center_region_mask') # gt_center_region_mask = mm_inputs.pop('gt_center_region_mask')
gt_mask = mm_inputs.pop('gt_mask') # gt_mask = mm_inputs.pop('gt_mask')
gt_radius_map = mm_inputs.pop('gt_radius_map') # gt_radius_map = mm_inputs.pop('gt_radius_map')
gt_sin_map = mm_inputs.pop('gt_sin_map') # gt_sin_map = mm_inputs.pop('gt_sin_map')
gt_cos_map = mm_inputs.pop('gt_cos_map') # gt_cos_map = mm_inputs.pop('gt_cos_map')
# Test forward train # # Test forward train
losses = detector.forward( # losses = detector.forward(
imgs, # imgs,
img_metas, # img_metas,
gt_text_mask=gt_text_mask, # gt_text_mask=gt_text_mask,
gt_center_region_mask=gt_center_region_mask, # gt_center_region_mask=gt_center_region_mask,
gt_mask=gt_mask, # gt_mask=gt_mask,
gt_radius_map=gt_radius_map, # gt_radius_map=gt_radius_map,
gt_sin_map=gt_sin_map, # gt_sin_map=gt_sin_map,
gt_cos_map=gt_cos_map) # gt_cos_map=gt_cos_map)
assert isinstance(losses, dict) # assert isinstance(losses, dict)
# Test forward test # # Test forward test
with torch.no_grad(): # with torch.no_grad():
img_list = [g[None, :] for g in imgs] # img_list = [g[None, :] for g in imgs]
batch_results = [] # batch_results = []
for one_img, one_meta in zip(img_list, img_metas): # for one_img, one_meta in zip(img_list, img_metas):
result = detector.forward([one_img], [[one_meta]], # result = detector.forward([one_img], [[one_meta]],
return_loss=False) # return_loss=False)
batch_results.append(result) # batch_results.append(result)
# Test show result # # Test show result
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} # results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
img = np.random.rand(5, 5) # img = np.random.rand(5, 5)
detector.show_result(img, results) # detector.show_result(img, results)