diff --git a/tests/test_models/test_detector.py b/tests/test_models/test_detector.py index 4a65bbc5..a5a255c1 100644 --- a/tests/test_models/test_detector.py +++ b/tests/test_models/test_detector.py @@ -316,54 +316,55 @@ def test_dbnet(cfg_file): detector.show_result(img, results) -# @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') -# @pytest.mark.parametrize( -# 'cfg_file', ['textdet/textsnake/' -# 'textsnake_r50_fpn_unet_1200e_ctw1500.py']) -# def test_textsnake(cfg_file): -# model = _get_detector_cfg(cfg_file) -# model['pretrained'] = None -# model['backbone']['norm_cfg']['type'] = 'BN' +@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') +@pytest.mark.parametrize( + 'cfg_file', + ['textdet/textsnake/' + 'textsnake_r50_fpn_unet_1200e_ctw1500.py']) +def test_textsnake(cfg_file): + model = _get_detector_cfg(cfg_file) + model['pretrained'] = None + model['backbone']['norm_cfg']['type'] = 'BN' -# from mmocr.models import build_detector -# detector = build_detector(model) -# detector = detector.cuda() -# input_shape = (1, 3, 64, 64) -# num_kernels = 1 -# mm_inputs = _demo_mm_inputs(num_kernels, input_shape) + from mmocr.models import build_detector + detector = build_detector(model) + detector = detector.cuda() + input_shape = (1, 3, 64, 64) + num_kernels = 1 + mm_inputs = _demo_mm_inputs(num_kernels, input_shape) -# imgs = mm_inputs.pop('imgs') -# imgs = imgs.cuda() -# img_metas = mm_inputs.pop('img_metas') -# gt_text_mask = mm_inputs.pop('gt_text_mask') -# gt_center_region_mask = mm_inputs.pop('gt_center_region_mask') -# gt_mask = mm_inputs.pop('gt_mask') -# gt_radius_map = mm_inputs.pop('gt_radius_map') -# gt_sin_map = mm_inputs.pop('gt_sin_map') -# gt_cos_map = mm_inputs.pop('gt_cos_map') + imgs = mm_inputs.pop('imgs') + imgs = imgs.cuda() + img_metas = mm_inputs.pop('img_metas') + gt_text_mask = mm_inputs.pop('gt_text_mask') + gt_center_region_mask = mm_inputs.pop('gt_center_region_mask') + gt_mask = mm_inputs.pop('gt_mask') + gt_radius_map = mm_inputs.pop('gt_radius_map') + gt_sin_map = mm_inputs.pop('gt_sin_map') + gt_cos_map = mm_inputs.pop('gt_cos_map') -# # Test forward train -# losses = detector.forward( -# imgs, -# img_metas, -# gt_text_mask=gt_text_mask, -# gt_center_region_mask=gt_center_region_mask, -# gt_mask=gt_mask, -# gt_radius_map=gt_radius_map, -# gt_sin_map=gt_sin_map, -# gt_cos_map=gt_cos_map) -# assert isinstance(losses, dict) + # Test forward train + losses = detector.forward( + imgs, + img_metas, + gt_text_mask=gt_text_mask, + gt_center_region_mask=gt_center_region_mask, + gt_mask=gt_mask, + gt_radius_map=gt_radius_map, + gt_sin_map=gt_sin_map, + gt_cos_map=gt_cos_map) + assert isinstance(losses, dict) -# # Test forward test -# with torch.no_grad(): -# img_list = [g[None, :] for g in imgs] -# batch_results = [] -# for one_img, one_meta in zip(img_list, img_metas): -# result = detector.forward([one_img], [[one_meta]], -# return_loss=False) -# batch_results.append(result) + # Test forward test + # with torch.no_grad(): + # img_list = [g[None, :] for g in imgs] + # batch_results = [] + # for one_img, one_meta in zip(img_list, img_metas): + # result = detector.forward([one_img], [[one_meta]], + # return_loss=False) + # batch_results.append(result) -# # Test show result -# results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} -# img = np.random.rand(5, 5) -# detector.show_result(img, results) + # Test show result + results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} + img = np.random.rand(5, 5) + detector.show_result(img, results)