2022-07-08 15:59:56 +00:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
2022-07-13 11:52:02 +00:00
|
|
|
import copy
|
2022-07-08 15:59:56 +00:00
|
|
|
import unittest
|
2022-07-13 11:52:02 +00:00
|
|
|
from os.path import dirname, exists, join
|
2022-07-08 15:59:56 +00:00
|
|
|
from unittest import mock
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
2022-07-13 11:52:02 +00:00
|
|
|
from mmengine import Config, ConfigDict
|
2022-07-08 15:59:56 +00:00
|
|
|
|
|
|
|
from mmocr.registry import MODELS
|
2022-07-13 11:52:02 +00:00
|
|
|
from mmocr.testing.data import create_dummy_textdet_inputs
|
2022-07-08 15:59:56 +00:00
|
|
|
from mmocr.utils import register_all_modules
|
|
|
|
|
|
|
|
|
2022-07-13 11:52:02 +00:00
|
|
|
class TestDRRG(unittest.TestCase):
    """Unit tests for the DRRG text detector.

    ``setUp`` rebuilds the model from
    ``textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py`` (via a deep-copied
    config) and regenerates dummy text-detection inputs before every test,
    so each test works on its own fresh objects.
    """

    def setUp(self):
        cfg_path = 'textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py'
        self.model_cfg = self._get_detector_cfg(cfg_path)
        # Called before MODELS.build, presumably so the registry can resolve
        # the types named in the config.
        register_all_modules()
        self.model = MODELS.build(self.model_cfg)
        self.inputs = create_dummy_textdet_inputs(input_shape=(1, 3, 224, 224))

    def _get_comp_attribs(self):
        """Build random ground-truth text-component attributes.

        Returns:
            np.ndarray: A float32 array of shape ``(1, 32, 8)``. Each of the
            32 rows is ``[num_rois, x, y, h, w, cos, sin, comp_label]`` where
            ``num_rois`` is 32 on every row, ``x``/``y`` are integers in
            ``[4, 224)``, ``h`` and ``w`` are 4, ``cos``/``sin`` come from an
            angle drawn from ``[-pi/2, pi/2)`` and ``comp_label`` is 1 or 2.
        """
        num_rois = 32
        x = np.random.randint(4, 224, (num_rois, 1))
        y = np.random.randint(4, 224, (num_rois, 1))
        h = 4 * np.ones((num_rois, 1))
        w = 4 * np.ones((num_rois, 1))
        angle = (np.random.random_sample((num_rois, 1)) * 2 - 1) * np.pi / 2
        cos, sin = np.cos(angle), np.sin(angle)
        comp_labels = np.random.randint(1, 3, (num_rois, 1))
        # The first column repeats the component count on every row; keep it
        # in its own name instead of rebinding the integer ``num_rois``.
        num_rois_col = num_rois * np.ones((num_rois, 1))
        comp_attribs = np.hstack(
            [num_rois_col, x, y, h, w, cos, sin, comp_labels])
        # Prepend a batch dimension of size 1.
        gt_comp_attribs = np.expand_dims(
            comp_attribs.astype(np.float32), axis=0)
        return gt_comp_attribs

    def _get_drrg_inputs(self):
        """Split the dummy inputs into model inputs and DRRG targets.

        Returns:
            tuple: ``(imgs, data_samples, targets)`` where ``targets`` is
            ``(gt_text_mask, gt_center_region_mask, gt_mask,
            gt_top_height_map, gt_bot_height_map, gt_sin_map, gt_cos_map,
            gt_comp_attribs)``. Both height maps come from the dummy
            ``gt_radius_map``; the bottom one is a ``.copy()`` of it.
        """
        imgs = self.inputs['imgs']
        data_samples = self.inputs['data_samples']
        gt_text_mask = self.inputs['gt_text_mask']
        gt_center_region_mask = self.inputs['gt_center_region_mask']
        gt_mask = self.inputs['gt_mask']
        gt_top_height_map = self.inputs['gt_radius_map']
        gt_bot_height_map = gt_top_height_map.copy()
        gt_sin_map = self.inputs['gt_sin_map']
        gt_cos_map = self.inputs['gt_cos_map']
        gt_comp_attribs = self._get_comp_attribs()
        return imgs, data_samples, (gt_text_mask, gt_center_region_mask,
                                    gt_mask, gt_top_height_map,
                                    gt_bot_height_map, gt_sin_map, gt_cos_map,
                                    gt_comp_attribs)

    @mock.patch(
        'mmocr.models.textdet.module_losses.drrg_module_loss.DRRGModuleLoss.'
        'get_targets')
    def test_loss(self, mock_get_targets):
        """Loss mode returns a dict when target generation is mocked out."""
        imgs, data_samples, targets = self._get_drrg_inputs()
        # Bypass real target generation and feed the hand-built targets.
        mock_get_targets.return_value = targets
        losses = self.model(imgs, data_samples, mode='loss')
        self.assertIsInstance(losses, dict)

    @mock.patch('mmocr.models.textdet.detectors.drrg.DRRG.extract_feat')
    def test_predict(self, mock_extract_feat):
        """Predict mode yields polygons and float scores for a synthetic map."""
        # Deep copy: the edits below mutate the nested ``det_head`` dict,
        # which a shallow ``.copy()`` would share with ``self.model_cfg``.
        model_cfg = copy.deepcopy(self.model_cfg)
        model_cfg['det_head']['in_channels'] = 6
        model_cfg['det_head']['text_region_thr'] = 0.8
        model_cfg['det_head']['center_region_thr'] = 0.8
        model = MODELS.build(model_cfg)
        imgs, data_samples, _ = self._get_drrg_inputs()

        # Synthetic 6-channel feature map returned by the mocked
        # ``extract_feat``.
        # NOTE(review): presumably channel 0 is the text-region score,
        # channel 1 the center-region score and channels 2-5 geometry maps;
        # confirm against the DRRG head implementation.
        maps = torch.zeros((1, 6, 224, 224), dtype=torch.float)
        maps[:, 0:2, :, :] = -10.
        maps[:, 0, 60:100, 50:170] = 10.
        maps[:, 1, 75:85, 60:160] = 10.
        maps[:, 2, 75:85, 60:160] = 0.
        maps[:, 3, 75:85, 60:160] = 1.
        maps[:, 4, 75:85, 60:160] = 10.
        maps[:, 5, 75:85, 60:160] = 10.
        mock_extract_feat.return_value = maps

        with torch.no_grad():
            # Identity 1x1 weights and zero bias make ``out_conv`` pass its
            # input through unchanged.
            model.det_head.out_conv.weight.data = torch.eye(6).reshape(
                6, 6, 1, 1)
            model.det_head.out_conv.bias.data.fill_(0.)
            results = model(imgs, data_samples, mode='predict')

        pred_instances = results[0].pred_instances
        self.assertIn('polygons', pred_instances)
        self.assertIn('scores', pred_instances)
        self.assertIsInstance(pred_instances['scores'], torch.FloatTensor)

    def _get_config_directory(self):
        """Find the predefined detector config directory.

        Returns:
            str: Path of the ``configs`` directory at the repository root.

        Raises:
            FileNotFoundError: If that directory does not exist.
        """
        try:
            # Assume we are running in the source mmocr repo: five dirname
            # calls walk from this test file up to the repository root.
            repo_dpath = dirname(dirname(dirname(dirname(dirname(__file__)))))
        except NameError:
            # For IPython development when __file__ is not defined.
            # ``mmocr.__file__`` is the package's ``__init__.py``
            # (``<repo>/mmocr/__init__.py``), so the repository root is only
            # two levels up.
            import mmocr
            repo_dpath = dirname(dirname(mmocr.__file__))
        config_dpath = join(repo_dpath, 'configs')
        if not exists(config_dpath):
            raise FileNotFoundError('Cannot find config path')
        return config_dpath

    def _get_config_module(self, fname: str) -> 'ConfigDict':
        """Load a config file relative to the ``configs`` directory.

        Args:
            fname (str): Config path relative to the ``configs`` directory.

        Returns:
            The object returned by ``Config.fromfile`` for that path.
        """
        config_dpath = self._get_config_directory()
        config_fpath = join(config_dpath, fname)
        config_mod = Config.fromfile(config_fpath)
        return config_mod

    def _get_detector_cfg(self, fname: str) -> 'ConfigDict':
        """Grab the configs necessary to create a detector.

        The ``model`` section is deep copied so tests can modify parameters
        without influencing other tests.

        Args:
            fname (str): Config path relative to the ``configs`` directory.

        Returns:
            A deep copy of the loaded config's ``model`` entry.
        """
        config = self._get_config_module(fname)
        model = copy.deepcopy(config.model)
        return model
|