From 1a167ff31769dbe6eccd717cd4d3d09700f6c4e3 Mon Sep 17 00:00:00 2001 From: gaotongxiao Date: Thu, 14 Jul 2022 11:57:35 +0000 Subject: [PATCH] Migrate tests --- .dev_scripts/covignore.cfg | 11 +- configs/textrecog/crnn/crnn.py | 2 +- configs/textrecog/crnn/crnn_toy_dataset.py | 2 +- configs/textrecog/tps/crnn_tps.py | 2 +- mmocr/evaluation/functional/__init__.py | 4 + .../functional/hmean.py} | 3 +- mmocr/evaluation/functional/hmean_ic13.py | 219 -------- mmocr/evaluation/metrics/hmean_iou_metric.py | 4 +- mmocr/models/textrecog/__init__.py | 1 - .../textrecog/preprocessors/__init__.py | 5 - .../preprocessors/base_preprocessor.py | 12 - .../preprocessors/tps_preprocessor.py | 275 ---------- .../models/textrecog/recognizers/__init__.py | 4 +- mmocr/models/textrecog/recognizers/crnn.py | 2 +- mmocr/utils/__init__.py | 13 +- old_tests/test_core/test_end2end_vis.py | 25 - old_tests/test_metrics/test_eval_utils.py | 225 -------- old_tests/test_models/test_detector.py | 517 ------------------ old_tests/test_models/test_kie_config.py | 131 ----- old_tests/test_models/test_loss.py | 147 ----- old_tests/test_models/test_modules.py | 133 ----- old_tests/test_models/test_ocr_backbone.py | 147 ----- old_tests/test_models/test_ocr_layer.py | 63 --- .../test_models/test_ocr_preprocessor.py | 39 -- old_tests/test_models/test_recog_config.py | 157 ------ .../test_evaluation/functional/test_hmean.py | 23 + ..._hmean_iou.py => test_hmean_iou_metric.py} | 0 .../layers/test_transformer_layers.py | 37 ++ .../modules/test_transformer_module.py | 15 + .../test_textrecog/layers/test_conv_layer.py | 42 ++ .../test_backbones/test_resnet31_ocr.py | 29 + .../test_backbones/test_resnet_abi.py | 31 ++ .../test_backbones/test_very_deep_vgg.py | 19 + 33 files changed, 216 insertions(+), 2123 deletions(-) create mode 100644 mmocr/evaluation/functional/__init__.py rename mmocr/{utils/evaluation_utils.py => evaluation/functional/hmean.py} (92%) delete mode 100644 mmocr/evaluation/functional/hmean_ic13.py delete mode 100644 mmocr/models/textrecog/preprocessors/__init__.py delete mode 100644 mmocr/models/textrecog/preprocessors/base_preprocessor.py delete mode 100644 mmocr/models/textrecog/preprocessors/tps_preprocessor.py delete mode 100644 old_tests/test_core/test_end2end_vis.py delete mode 100644 old_tests/test_metrics/test_eval_utils.py delete mode 100644 old_tests/test_models/test_detector.py delete mode 100644 old_tests/test_models/test_kie_config.py delete mode 100644 old_tests/test_models/test_loss.py delete mode 100644 old_tests/test_models/test_modules.py delete mode 100644 old_tests/test_models/test_ocr_backbone.py delete mode 100644 old_tests/test_models/test_ocr_layer.py delete mode 100644 old_tests/test_models/test_ocr_preprocessor.py delete mode 100644 old_tests/test_models/test_recog_config.py create mode 100644 tests/test_evaluation/functional/test_hmean.py rename tests/test_evaluation/test_metrics/{test_hmean_iou.py => test_hmean_iou_metric.py} (100%) create mode 100644 tests/test_models/test_common/layers/test_transformer_layers.py create mode 100644 tests/test_models/test_common/modules/test_transformer_module.py create mode 100644 tests/test_models/test_textrecog/layers/test_conv_layer.py create mode 100644 tests/test_models/test_textrecog/test_backbones/test_resnet31_ocr.py create mode 100644 tests/test_models/test_textrecog/test_backbones/test_resnet_abi.py create mode 100644 tests/test_models/test_textrecog/test_backbones/test_very_deep_vgg.py diff --git a/.dev_scripts/covignore.cfg b/.dev_scripts/covignore.cfg index a5e7ce2a..213afd4b 100644 --- a/.dev_scripts/covignore.cfg +++ b/.dev_scripts/covignore.cfg @@ -8,17 +8,8 @@ # It will be removed after all models have been refactored mmocr/utils/bbox_utils.py -# It will be removed after all models have been refactored -mmocr/utils/ocr.py -mmocr/utils/evaluation_utils.py - -# Major part is coverd, however, it's hard to cover model's output. +# Major part is covered, however, it's hard to cover model's output. mmocr/models/textdet/detectors/mmdet_wrapper.py -# Cover it by tests seems like an impossible mission -mmocr/models/textdet/postprocessors/drrg_postprocessor.py - -# It will be removed after HmeanIc13Metric -mmocr/evaluation/functional/hmean_ic13.py # It will be removed after KieVisualizer and TextSpotterVisualizer mmocr/visualization/visualize.py diff --git a/configs/textrecog/crnn/crnn.py b/configs/textrecog/crnn/crnn.py index e7cdad9a..ba505968 100644 --- a/configs/textrecog/crnn/crnn.py +++ b/configs/textrecog/crnn/crnn.py @@ -4,7 +4,7 @@ dictionary = dict( with_padding=True) model = dict( - type='CRNNNet', + type='CRNN', preprocessor=None, backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1), encoder=None, diff --git a/configs/textrecog/crnn/crnn_toy_dataset.py b/configs/textrecog/crnn/crnn_toy_dataset.py index f61c68af..b7f5f771 100644 --- a/configs/textrecog/crnn/crnn_toy_dataset.py +++ b/configs/textrecog/crnn/crnn_toy_dataset.py @@ -9,7 +9,7 @@ label_convertor = dict( type='CTCConvertor', dict_type='DICT36', with_unknown=True, lower=True) model = dict( - type='CRNNNet', + type='CRNN', preprocessor=None, backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1), encoder=None, diff --git a/configs/textrecog/tps/crnn_tps.py b/configs/textrecog/tps/crnn_tps.py index 9fc09478..02b6c6ba 100644 --- a/configs/textrecog/tps/crnn_tps.py +++ b/configs/textrecog/tps/crnn_tps.py @@ -3,7 +3,7 @@ label_convertor = dict( type='CTCConvertor', dict_type='DICT36', with_unknown=False, lower=True) model = dict( - type='CRNNNet', + type='CRNN', preprocessor=dict( type='TPSPreprocessor', num_fiducial=20, diff --git a/mmocr/evaluation/functional/__init__.py b/mmocr/evaluation/functional/__init__.py new file mode 100644 index 00000000..6aaf7576 --- /dev/null +++ b/mmocr/evaluation/functional/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hmean import compute_hmean + +__all__ = ['compute_hmean'] diff --git a/mmocr/utils/evaluation_utils.py b/mmocr/evaluation/functional/hmean.py similarity index 92% rename from mmocr/utils/evaluation_utils.py rename to mmocr/evaluation/functional/hmean.py index 821b54c9..d3aabf4c 100644 --- a/mmocr/utils/evaluation_utils.py +++ b/mmocr/evaluation/functional/hmean.py @@ -1,9 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -# TODO check whether to keep these utils after refactoring ic13 metrics def compute_hmean(accum_hit_recall, accum_hit_prec, gt_num, pred_num): - # TODO Add typehints & Test + # TODO Add typehints """Compute hmean given hit number, ground truth number and prediction number. diff --git a/mmocr/evaluation/functional/hmean_ic13.py b/mmocr/evaluation/functional/hmean_ic13.py deleted file mode 100644 index 6313118c..00000000 --- a/mmocr/evaluation/functional/hmean_ic13.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import numpy as np - -import mmocr.utils as utils - - -# TODO replace with HmeanIc13Metric -def compute_recall_precision(gt_polys, pred_polys): - """Compute the recall and the precision matrices between gt and predicted - polygons. - - Args: - gt_polys (list[Polygon]): List of gt polygons. - pred_polys (list[Polygon]): List of predicted polygons. - - Returns: - recall (ndarray): Recall matrix of size gt_num x det_num. - precision (ndarray): Precision matrix of size gt_num x det_num. - """ - assert isinstance(gt_polys, list) - assert isinstance(pred_polys, list) - - gt_num = len(gt_polys) - det_num = len(pred_polys) - sz = [gt_num, det_num] - - recall = np.zeros(sz) - precision = np.zeros(sz) - # compute area recall and precision for each (gt, det) pair - # in one img - for gt_id in range(gt_num): - for pred_id in range(det_num): - gt = gt_polys[gt_id] - det = pred_polys[pred_id] - - inter_area = utils.poly_intersection(det, gt) - gt_area = gt.area - det_area = det.area - if gt_area != 0: - recall[gt_id, pred_id] = inter_area / gt_area - if det_area != 0: - precision[gt_id, pred_id] = inter_area / det_area - - return recall, precision - - -def eval_hmean_ic13(det_boxes, - gt_boxes, - gt_ignored_boxes, - precision_thr=0.4, - recall_thr=0.8, - center_dist_thr=1.0, - one2one_score=1., - one2many_score=0.8, - many2one_score=1.): - """Evaluate hmean of text detection using the icdar2013 standard. - - Args: - det_boxes (list[list[list[float]]]): List of arrays of shape (n, 2k). - Each element is the det_boxes for one img. k>=4. - gt_boxes (list[list[list[float]]]): List of arrays of shape (m, 2k). - Each element is the gt_boxes for one img. k>=4. - gt_ignored_boxes (list[list[list[float]]]): List of arrays of - (l, 2k). Each element is the ignored gt_boxes for one img. k>=4. - precision_thr (float): Precision threshold of the iou of one - (gt_box, det_box) pair. - recall_thr (float): Recall threshold of the iou of one - (gt_box, det_box) pair. - center_dist_thr (float): Distance threshold of one (gt_box, det_box) - center point pair. - one2one_score (float): Reward when one gt matches one det_box. - one2many_score (float): Reward when one gt matches many det_boxes. - many2one_score (float): Reward when many gts match one det_box. - - Returns: - hmean (tuple[dict]): Tuple of dicts which encodes the hmean for - the dataset and all images. - """ - assert utils.is_3dlist(det_boxes) - assert utils.is_3dlist(gt_boxes) - assert utils.is_3dlist(gt_ignored_boxes) - - assert 0 <= precision_thr <= 1 - assert 0 <= recall_thr <= 1 - assert center_dist_thr > 0 - assert 0 <= one2one_score <= 1 - assert 0 <= one2many_score <= 1 - assert 0 <= many2one_score <= 1 - - img_num = len(det_boxes) - assert img_num == len(gt_boxes) - assert img_num == len(gt_ignored_boxes) - - dataset_gt_num = 0 - dataset_pred_num = 0 - dataset_hit_recall = 0.0 - dataset_hit_prec = 0.0 - - img_results = [] - - for i in range(img_num): - gt = gt_boxes[i] - gt_ignored = gt_ignored_boxes[i] - pred = det_boxes[i] - - gt_num = len(gt) - ignored_num = len(gt_ignored) - pred_num = len(pred) - - accum_recall = 0. - accum_precision = 0. - - gt_points = gt + gt_ignored - gt_polys = [utils.poly2shapely(p) for p in gt_points] - gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))] - gt_num = len(gt_polys) - - pred_polys, pred_points, pred_ignored_index = utils.ignore_pred( - pred, gt_ignored_index, gt_polys, precision_thr) - - if pred_num > 0 and gt_num > 0: - - gt_hit = np.zeros(gt_num, np.int8).tolist() - pred_hit = np.zeros(pred_num, np.int8).tolist() - - # compute area recall and precision for each (gt, pred) pair - # in one img. - recall_mat, precision_mat = compute_recall_precision( - gt_polys, pred_polys) - - # match one gt to one pred box. - for gt_id in range(gt_num): - for pred_id in range(pred_num): - if (gt_hit[gt_id] != 0 or pred_hit[pred_id] != 0 - or gt_id in gt_ignored_index - or pred_id in pred_ignored_index): - continue - match = utils.one2one_match_ic13(gt_id, pred_id, - recall_mat, precision_mat, - recall_thr, precision_thr) - - if match: - gt_point = np.array(gt_points[gt_id]) - det_point = np.array(pred_points[pred_id]) - - norm_dist = utils.bbox_center_distance( - det_point, gt_point) - norm_dist /= utils.bbox_diag( - det_point) + utils.bbox_diag(gt_point) - norm_dist *= 2.0 - - if norm_dist < center_dist_thr: - gt_hit[gt_id] = 1 - pred_hit[pred_id] = 1 - accum_recall += one2one_score - accum_precision += one2one_score - - # match one gt to many det boxes. - for gt_id in range(gt_num): - if gt_id in gt_ignored_index: - continue - match, match_det_set = utils.one2many_match_ic13( - gt_id, recall_mat, precision_mat, recall_thr, - precision_thr, gt_hit, pred_hit, pred_ignored_index) - - if match: - gt_hit[gt_id] = 1 - accum_recall += one2many_score - accum_precision += one2many_score * len(match_det_set) - for pred_id in match_det_set: - pred_hit[pred_id] = 1 - - # match many gt to one det box. One pair of (det,gt) are matched - # successfully if their recall, precision, normalized distance - # meet some thresholds. - for pred_id in range(pred_num): - if pred_id in pred_ignored_index: - continue - - match, match_gt_set = utils.many2one_match_ic13( - pred_id, recall_mat, precision_mat, recall_thr, - precision_thr, gt_hit, pred_hit, gt_ignored_index) - - if match: - pred_hit[pred_id] = 1 - accum_recall += many2one_score * len(match_gt_set) - accum_precision += many2one_score - for gt_id in match_gt_set: - gt_hit[gt_id] = 1 - - gt_care_number = gt_num - ignored_num - pred_care_number = pred_num - len(pred_ignored_index) - - r, p, h = utils.compute_hmean(accum_recall, accum_precision, - gt_care_number, pred_care_number) - - img_results.append({'recall': r, 'precision': p, 'hmean': h}) - - dataset_gt_num += gt_care_number - dataset_pred_num += pred_care_number - dataset_hit_recall += accum_recall - dataset_hit_prec += accum_precision - - total_r, total_p, total_h = utils.compute_hmean(dataset_hit_recall, - dataset_hit_prec, - dataset_gt_num, - dataset_pred_num) - - dataset_results = { - 'num_gts': dataset_gt_num, - 'num_dets': dataset_pred_num, - 'num_recall': dataset_hit_recall, - 'num_precision': dataset_hit_prec, - 'recall': total_r, - 'precision': total_p, - 'hmean': total_h - } - - return dataset_results, img_results diff --git a/mmocr/evaluation/metrics/hmean_iou_metric.py b/mmocr/evaluation/metrics/hmean_iou_metric.py index 69018537..42b36893 100644 --- a/mmocr/evaluation/metrics/hmean_iou_metric.py +++ b/mmocr/evaluation/metrics/hmean_iou_metric.py @@ -9,9 +9,9 @@ from scipy.sparse import csr_matrix from scipy.sparse.csgraph import maximum_bipartite_matching from shapely.geometry import Polygon +from mmocr.evaluation.functional import compute_hmean from mmocr.registry import METRICS -from mmocr.utils import (compute_hmean, poly_intersection, poly_iou, - polys2shapely) +from mmocr.utils import poly_intersection, poly_iou, polys2shapely @METRICS.register_module() diff --git a/mmocr/models/textrecog/__init__.py b/mmocr/models/textrecog/__init__.py index c688eded..f3cb906c 100644 --- a/mmocr/models/textrecog/__init__.py +++ b/mmocr/models/textrecog/__init__.py @@ -6,5 +6,4 @@ from .dictionary import * # NOQA from .encoders import * # NOQA from .plugins import * # NOQA from .postprocessors import * # NOQA -from .preprocessors import * # NOQA from .recognizers import * # NOQA diff --git a/mmocr/models/textrecog/preprocessors/__init__.py b/mmocr/models/textrecog/preprocessors/__init__.py deleted file mode 100644 index 57ea828a..00000000 --- a/mmocr/models/textrecog/preprocessors/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .base_preprocessor import BasePreprocessor -from .tps_preprocessor import TPSPreprocessor - -__all__ = ['BasePreprocessor', 'TPSPreprocessor'] diff --git a/mmocr/models/textrecog/preprocessors/base_preprocessor.py b/mmocr/models/textrecog/preprocessors/base_preprocessor.py deleted file mode 100644 index bf6a6520..00000000 --- a/mmocr/models/textrecog/preprocessors/base_preprocessor.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmcv.runner import BaseModule - -from mmocr.registry import MODELS - - -@MODELS.register_module() -class BasePreprocessor(BaseModule): - """Base Preprocessor class for text recognition.""" - - def forward(self, x, **kwargs): - return x diff --git a/mmocr/models/textrecog/preprocessors/tps_preprocessor.py b/mmocr/models/textrecog/preprocessors/tps_preprocessor.py deleted file mode 100644 index e34c28cc..00000000 --- a/mmocr/models/textrecog/preprocessors/tps_preprocessor.py +++ /dev/null @@ -1,275 +0,0 @@ -# Modified from https://github.com/clovaai/deep-text-recognition-benchmark -# -# Licensed under the Apache License, Version 2.0 (the "License");s -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from mmocr.registry import MODELS -from .base_preprocessor import BasePreprocessor - - -@MODELS.register_module() -class TPSPreprocessor(BasePreprocessor): - """Rectification Network of RARE, namely TPS based STN in - https://arxiv.org/pdf/1603.03915.pdf. - - Args: - num_fiducial (int): Number of fiducial points of TPS-STN. - img_size (tuple(int, int)): Size :math:`(H, W)` of the input image. - rectified_img_size (tuple(int, int)): Size :math:`(H_r, W_r)` of - the rectified image. - num_img_channel (int): Number of channels of the input image. - init_cfg (dict or list[dict], optional): Initialization configs. - """ - - def __init__(self, - num_fiducial=20, - img_size=(32, 100), - rectified_img_size=(32, 100), - num_img_channel=1, - init_cfg=None): - super().__init__(init_cfg=init_cfg) - assert isinstance(num_fiducial, int) - assert num_fiducial > 0 - assert isinstance(img_size, tuple) - assert isinstance(rectified_img_size, tuple) - assert isinstance(num_img_channel, int) - - self.num_fiducial = num_fiducial - self.img_size = img_size - self.rectified_img_size = rectified_img_size - self.num_img_channel = num_img_channel - self.LocalizationNetwork = LocalizationNetwork(self.num_fiducial, - self.num_img_channel) - self.GridGenerator = GridGenerator(self.num_fiducial, - self.rectified_img_size) - - def forward(self, batch_img): - """ - Args: - batch_img (Tensor): Images to be rectified with size - :math:`(N, C, H, W)`. - - Returns: - Tensor: Rectified image with size :math:`(N, C, H_r, W_r)`. - """ - batch_C_prime = self.LocalizationNetwork( - batch_img) # batch_size x K x 2 - build_P_prime = self.GridGenerator.build_P_prime( - batch_C_prime, batch_img.device - ) # batch_size x n (= rectified_img_width x rectified_img_height) x 2 - build_P_prime_reshape = build_P_prime.reshape([ - build_P_prime.size(0), self.rectified_img_size[0], - self.rectified_img_size[1], 2 - ]) - - batch_rectified_img = F.grid_sample( - batch_img, - build_P_prime_reshape, - padding_mode='border', - align_corners=True) - - return batch_rectified_img - - -class LocalizationNetwork(nn.Module): - """Localization Network of RARE, which predicts C' (K x 2) from input - (img_width x img_height) - - Args: - num_fiducial (int): Number of fiducial points of TPS-STN. - num_img_channel (int): Number of channels of the input image. - """ - - def __init__(self, num_fiducial, num_img_channel): - super().__init__() - self.num_fiducial = num_fiducial - self.num_img_channel = num_img_channel - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=self.num_img_channel, - out_channels=64, - kernel_size=3, - stride=1, - padding=1, - bias=False), - nn.BatchNorm2d(64), - nn.ReLU(True), - nn.MaxPool2d(2, 2), # batch_size x 64 x img_height/2 x img_width/2 - nn.Conv2d(64, 128, 3, 1, 1, bias=False), - nn.BatchNorm2d(128), - nn.ReLU(True), - nn.MaxPool2d(2, 2), # batch_size x 128 x img_h/4 x img_w/4 - nn.Conv2d(128, 256, 3, 1, 1, bias=False), - nn.BatchNorm2d(256), - nn.ReLU(True), - nn.MaxPool2d(2, 2), # batch_size x 256 x img_h/8 x img_w/8 - nn.Conv2d(256, 512, 3, 1, 1, bias=False), - nn.BatchNorm2d(512), - nn.ReLU(True), - nn.AdaptiveAvgPool2d(1) # batch_size x 512 - ) - - self.localization_fc1 = nn.Sequential( - nn.Linear(512, 256), nn.ReLU(True)) - self.localization_fc2 = nn.Linear(256, self.num_fiducial * 2) - - # Init fc2 in LocalizationNetwork - self.localization_fc2.weight.data.fill_(0) - ctrl_pts_x = np.linspace(-1.0, 1.0, int(num_fiducial / 2)) - ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(num_fiducial / 2)) - ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(num_fiducial / 2)) - ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) - ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) - initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) - self.localization_fc2.bias.data = torch.from_numpy( - initial_bias).float().view(-1) - - def forward(self, batch_img): - """ - Args: - batch_img (Tensor): Batch input image of shape - :math:`(N, C, H, W)`. - - Returns: - Tensor: Predicted coordinates of fiducial points for input batch. - The shape is :math:`(N, F, 2)` where :math:`F` is ``num_fiducial``. - """ - batch_size = batch_img.size(0) - features = self.conv(batch_img).view(batch_size, -1) - batch_C_prime = self.localization_fc2( - self.localization_fc1(features)).view(batch_size, - self.num_fiducial, 2) - return batch_C_prime - - -class GridGenerator(nn.Module): - """Grid Generator of RARE, which produces P_prime by multiplying T with P. - - Args: - num_fiducial (int): Number of fiducial points of TPS-STN. - rectified_img_size (tuple(int, int)): - Size :math:`(H_r, W_r)` of the rectified image. - """ - - def __init__(self, num_fiducial, rectified_img_size): - """Generate P_hat and inv_delta_C for later.""" - super().__init__() - self.eps = 1e-6 - self.rectified_img_height = rectified_img_size[0] - self.rectified_img_width = rectified_img_size[1] - self.num_fiducial = num_fiducial - self.C = self._build_C(self.num_fiducial) # num_fiducial x 2 - self.P = self._build_P(self.rectified_img_width, - self.rectified_img_height) - # for multi-gpu, you need register buffer - self.register_buffer( - 'inv_delta_C', - torch.tensor(self._build_inv_delta_C( - self.num_fiducial, - self.C)).float()) # num_fiducial+3 x num_fiducial+3 - self.register_buffer('P_hat', - torch.tensor( - self._build_P_hat( - self.num_fiducial, self.C, - self.P)).float()) # n x num_fiducial+3 - # for fine-tuning with different image width, - # you may use below instead of self.register_buffer - # self.inv_delta_C = torch.tensor( - # self._build_inv_delta_C( - # self.num_fiducial, - # self.C)).float().cuda() # num_fiducial+3 x num_fiducial+3 - # self.P_hat = torch.tensor( - # self._build_P_hat(self.num_fiducial, self.C, - # self.P)).float().cuda() # n x num_fiducial+3 - - def _build_C(self, num_fiducial): - """Return coordinates of fiducial points in rectified_img; C.""" - ctrl_pts_x = np.linspace(-1.0, 1.0, int(num_fiducial / 2)) - ctrl_pts_y_top = -1 * np.ones(int(num_fiducial / 2)) - ctrl_pts_y_bottom = np.ones(int(num_fiducial / 2)) - ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) - ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) - C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) - return C # num_fiducial x 2 - - def _build_inv_delta_C(self, num_fiducial, C): - """Return inv_delta_C which is needed to calculate T.""" - hat_C = np.zeros((num_fiducial, num_fiducial), dtype=float) - for i in range(0, num_fiducial): - for j in range(i, num_fiducial): - r = np.linalg.norm(C[i] - C[j]) - hat_C[i, j] = r - hat_C[j, i] = r - np.fill_diagonal(hat_C, 1) - hat_C = (hat_C**2) * np.log(hat_C) - # print(C.shape, hat_C.shape) - delta_C = np.concatenate( # num_fiducial+3 x num_fiducial+3 - [ - np.concatenate([np.ones((num_fiducial, 1)), C, hat_C], - axis=1), # num_fiducial x num_fiducial+3 - np.concatenate([np.zeros( - (2, 3)), np.transpose(C)], axis=1), # 2 x num_fiducial+3 - np.concatenate([np.zeros( - (1, 3)), np.ones((1, num_fiducial))], - axis=1) # 1 x num_fiducial+3 - ], - axis=0) - inv_delta_C = np.linalg.inv(delta_C) - return inv_delta_C # num_fiducial+3 x num_fiducial+3 - - def _build_P(self, rectified_img_width, rectified_img_height): - rectified_img_grid_x = ( - np.arange(-rectified_img_width, rectified_img_width, 2) + - 1.0) / rectified_img_width # self.rectified_img_width - rectified_img_grid_y = ( - np.arange(-rectified_img_height, rectified_img_height, 2) + - 1.0) / rectified_img_height # self.rectified_img_height - P = np.stack( # self.rectified_img_w x self.rectified_img_h x 2 - np.meshgrid(rectified_img_grid_x, rectified_img_grid_y), - axis=2) - return P.reshape([ - -1, 2 - ]) # n (= self.rectified_img_width x self.rectified_img_height) x 2 - - def _build_P_hat(self, num_fiducial, C, P): - n = P.shape[ - 0] # n (= self.rectified_img_width x self.rectified_img_height) - P_tile = np.tile(np.expand_dims(P, axis=1), - (1, num_fiducial, - 1)) # n x 2 -> n x 1 x 2 -> n x num_fiducial x 2 - C_tile = np.expand_dims(C, axis=0) # 1 x num_fiducial x 2 - P_diff = P_tile - C_tile # n x num_fiducial x 2 - rbf_norm = np.linalg.norm( - P_diff, ord=2, axis=2, keepdims=False) # n x num_fiducial - rbf = np.multiply(np.square(rbf_norm), - np.log(rbf_norm + self.eps)) # n x num_fiducial - P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) - return P_hat # n x num_fiducial+3 - - def build_P_prime(self, batch_C_prime, device='cuda'): - """Generate Grid from batch_C_prime [batch_size x num_fiducial x 2]""" - batch_size = batch_C_prime.size(0) - batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) - batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) - batch_C_prime_with_zeros = torch.cat( - (batch_C_prime, torch.zeros(batch_size, 3, 2).float().to(device)), - dim=1) # batch_size x num_fiducial+3 x 2 - batch_T = torch.bmm( - batch_inv_delta_C, - batch_C_prime_with_zeros) # batch_size x num_fiducial+3 x 2 - batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 - return batch_P_prime # batch_size x n x 2 diff --git a/mmocr/models/textrecog/recognizers/__init__.py b/mmocr/models/textrecog/recognizers/__init__.py index b9b7835f..dec944e2 100644 --- a/mmocr/models/textrecog/recognizers/__init__.py +++ b/mmocr/models/textrecog/recognizers/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .abinet import ABINet from .base import BaseRecognizer -from .crnn import CRNNNet +from .crnn import CRNN from .encoder_decoder_recognizer import EncoderDecoderRecognizer from .master import MASTER from .nrtr import NRTR @@ -10,6 +10,6 @@ from .sar import SARNet from .satrn import SATRN __all__ = [ - 'BaseRecognizer', 'EncoderDecoderRecognizer', 'CRNNNet', 'SARNet', 'NRTR', + 'BaseRecognizer', 'EncoderDecoderRecognizer', 'CRNN', 'SARNet', 'NRTR', 'RobustScanner', 'SATRN', 'ABINet', 'MASTER' ] diff --git a/mmocr/models/textrecog/recognizers/crnn.py b/mmocr/models/textrecog/recognizers/crnn.py index f9a53f52..61d6853d 100644 --- a/mmocr/models/textrecog/recognizers/crnn.py +++ b/mmocr/models/textrecog/recognizers/crnn.py @@ -4,5 +4,5 @@ from .encoder_decoder_recognizer import EncoderDecoderRecognizer @MODELS.register_module() -class CRNNNet(EncoderDecoderRecognizer): +class CRNN(EncoderDecoderRecognizer): """CTC-loss based recognizer.""" diff --git a/mmocr/utils/__init__.py b/mmocr/utils/__init__.py index d1196116..5c3ee2a5 100644 --- a/mmocr/utils/__init__.py +++ b/mmocr/utils/__init__.py @@ -8,7 +8,6 @@ from .check_argument import (equal_len, is_2dlist, is_3dlist, is_none_or_type, is_type_list, valid_boundary) from .collect_env import collect_env from .data_converter_utils import dump_ocr_data, recog_anno_to_imginfo -from .evaluation_utils import compute_hmean from .fileio import list_from_file, list_to_file from .img_utils import crop_img, warp_img from .mask_utils import fill_hole @@ -39,10 +38,10 @@ __all__ = [ 'poly_iou', 'poly_make_valid', 'poly_union', 'poly2shapely', 'polys2shapely', 'register_all_modules', 'offset_polygon', 'sort_vertex8', 'sort_vertex', 'bbox_center_distance', 'bbox_diag_distance', - 'compute_hmean', 'boundary_iou', 'point_distance', 'points_center', - 'fill_hole', 'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img', - 'warp_img', 'ConfigType', 'DetSampleList', 'RecForwardResults', - 'InitConfigType', 'OptConfigType', 'OptDetSampleList', 'OptInitConfigType', - 'OptMultiConfig', 'OptRecSampleList', 'RecSampleList', 'MultiConfig', - 'OptTensor', 'ColorType', 'OptKIESampleList', 'KIESampleList' + 'boundary_iou', 'point_distance', 'points_center', 'fill_hole', + 'LineJsonParser', 'LineStrParser', 'shapely2poly', 'crop_img', 'warp_img', + 'ConfigType', 'DetSampleList', 'RecForwardResults', 'InitConfigType', + 'OptConfigType', 'OptDetSampleList', 'OptInitConfigType', 'OptMultiConfig', + 'OptRecSampleList', 'RecSampleList', 'MultiConfig', 'OptTensor', + 'ColorType', 'OptKIESampleList', 'KIESampleList' ] diff --git a/old_tests/test_core/test_end2end_vis.py b/old_tests/test_core/test_end2end_vis.py deleted file mode 100644 index 2e7a6812..00000000 --- a/old_tests/test_core/test_end2end_vis.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import numpy as np - -from mmocr.core import det_recog_show_result - - -def test_det_recog_show_result(): - img = np.ones((100, 100, 3), dtype=np.uint8) * 255 - det_recog_res = { - 'result': [{ - 'box': [51, 88, 51, 62, 85, 62, 85, 88], - 'box_score': 0.9417, - 'text': 'hell', - 'text_score': 0.8834 - }] - } - - vis_img = det_recog_show_result(img, det_recog_res) - - assert vis_img.shape[0] == 100 - assert vis_img.shape[1] == 200 - assert vis_img.shape[2] == 3 - - det_recog_res['result'][0]['text'] = '中文' - det_recog_show_result(img, det_recog_res) diff --git a/old_tests/test_metrics/test_eval_utils.py b/old_tests/test_metrics/test_eval_utils.py deleted file mode 100644 index bd2144dc..00000000 --- a/old_tests/test_metrics/test_eval_utils.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -"""Tests the utils of evaluation.""" -import numpy as np -import pytest - -import mmocr.core.evaluation.utils as utils - - -def test_compute_hmean(): - - # test invalid arguments - with pytest.raises(AssertionError): - utils.compute_hmean(0, 0, 0.0, 0) - with pytest.raises(AssertionError): - utils.compute_hmean(0, 0, 0, 0.0) - with pytest.raises(AssertionError): - utils.compute_hmean([1], 0, 0, 0) - with pytest.raises(AssertionError): - utils.compute_hmean(0, [1], 0, 0) - - _, _, hmean = utils.compute_hmean(2, 2, 2, 2) - assert hmean == 1 - - _, _, hmean = utils.compute_hmean(0, 0, 2, 2) - assert hmean == 0 - - -def test_box_center_distance(): - p1 = np.array([1, 1, 3, 3]) - p2 = np.array([2, 2, 4, 2]) - - assert utils.box_center_distance(p1, p2) == 1 - - -def test_box_diag(): - # test unsupported type - with pytest.raises(AssertionError): - utils.box_diag([1, 2]) - with pytest.raises(AssertionError): - utils.box_diag(np.array([1, 2, 3, 4])) - - box = np.array([0, 0, 1, 1, 0, 10, -10, 0]) - - assert utils.box_diag(box) == 10 - - -def test_one2one_match_ic13(): - gt_id = 0 - det_id = 0 - recall_mat = np.array([[1, 0], [0, 0]]) - precision_mat = np.array([[1, 0], [0, 0]]) - recall_thr = 0.5 - precision_thr = 0.5 - # test invalid arguments. - with pytest.raises(AssertionError): - utils.one2one_match_ic13(0.0, det_id, recall_mat, precision_mat, - recall_thr, precision_thr) - with pytest.raises(AssertionError): - utils.one2one_match_ic13(gt_id, 0.0, recall_mat, precision_mat, - recall_thr, precision_thr) - with pytest.raises(AssertionError): - utils.one2one_match_ic13(gt_id, det_id, [0, 0], precision_mat, - recall_thr, precision_thr) - with pytest.raises(AssertionError): - utils.one2one_match_ic13(gt_id, det_id, recall_mat, [0, 0], recall_thr, - precision_thr) - with pytest.raises(AssertionError): - utils.one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, 1.1, - precision_thr) - with pytest.raises(AssertionError): - utils.one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, - recall_thr, 1.1) - - assert utils.one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, - recall_thr, precision_thr) - recall_mat = np.array([[1, 0], [0.6, 0]]) - precision_mat = np.array([[1, 0], [0.6, 0]]) - assert not utils.one2one_match_ic13( - gt_id, det_id, recall_mat, precision_mat, recall_thr, precision_thr) - recall_mat = np.array([[1, 0.6], [0, 0]]) - precision_mat = np.array([[1, 0.6], [0, 0]]) - assert not utils.one2one_match_ic13( - gt_id, det_id, recall_mat, precision_mat, recall_thr, precision_thr) - - -def test_one2many_match_ic13(): - gt_id = 0 - recall_mat = np.array([[1, 0], [0, 0]]) - precision_mat = np.array([[1, 0], [0, 0]]) - recall_thr = 0.5 - precision_thr = 0.5 - gt_match_flag = [0, 0] - det_match_flag = [0, 0] - det_dont_care_index = [] - # test invalid arguments. - with pytest.raises(AssertionError): - gt_id_tmp = 0.0 - utils.one2many_match_ic13(gt_id_tmp, recall_mat, precision_mat, - recall_thr, precision_thr, gt_match_flag, - det_match_flag, det_dont_care_index) - with pytest.raises(AssertionError): - recall_mat_tmp = [1, 0] - utils.one2many_match_ic13(gt_id, recall_mat_tmp, precision_mat, - recall_thr, precision_thr, gt_match_flag, - det_match_flag, det_dont_care_index) - with pytest.raises(AssertionError): - precision_mat_tmp = [1, 0] - utils.one2many_match_ic13(gt_id, recall_mat, precision_mat_tmp, - recall_thr, precision_thr, gt_match_flag, - det_match_flag, det_dont_care_index) - with pytest.raises(AssertionError): - - utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, 1.1, - precision_thr, gt_match_flag, det_match_flag, - det_dont_care_index) - with pytest.raises(AssertionError): - - utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr, - 1.1, gt_match_flag, det_match_flag, - det_dont_care_index) - with pytest.raises(AssertionError): - gt_match_flag_tmp = np.array([0, 1]) - utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr, - precision_thr, gt_match_flag_tmp, - det_match_flag, det_dont_care_index) - with pytest.raises(AssertionError): - det_match_flag_tmp = np.array([0, 1]) - utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr, - precision_thr, gt_match_flag, - det_match_flag_tmp, det_dont_care_index) - with pytest.raises(AssertionError): - det_dont_care_index_tmp = np.array([0, 1]) - utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr, - precision_thr, gt_match_flag, det_match_flag, - det_dont_care_index_tmp) - - # test matched case - - result = utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, - recall_thr, precision_thr, - gt_match_flag, det_match_flag, - det_dont_care_index) - assert result[0] - assert result[1] == [0] - - # test unmatched case - gt_match_flag_tmp = [1, 0] - result = utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, - recall_thr, precision_thr, - gt_match_flag_tmp, det_match_flag, - det_dont_care_index) - assert not result[0] - assert result[1] == [] - - -def test_many2one_match_ic13(): - det_id = 0 - recall_mat = np.array([[1, 0], [0, 0]]) - precision_mat = np.array([[1, 0], [0, 0]]) - recall_thr = 0.5 - precision_thr = 0.5 - gt_match_flag = [0, 0] - det_match_flag = [0, 0] - gt_dont_care_index = [] - # test invalid arguments. - with pytest.raises(AssertionError): - det_id_tmp = 1.0 - utils.many2one_match_ic13(det_id_tmp, recall_mat, precision_mat, - recall_thr, precision_thr, gt_match_flag, - det_match_flag, gt_dont_care_index) - with pytest.raises(AssertionError): - recall_mat_tmp = [[1, 0], [0, 0]] - utils.many2one_match_ic13(det_id, recall_mat_tmp, precision_mat, - recall_thr, precision_thr, gt_match_flag, - det_match_flag, gt_dont_care_index) - with pytest.raises(AssertionError): - precision_mat_tmp = [[1, 0], [0, 0]] - utils.many2one_match_ic13(det_id, recall_mat, precision_mat_tmp, - recall_thr, precision_thr, gt_match_flag, - det_match_flag, gt_dont_care_index) - with pytest.raises(AssertionError): - recall_thr_tmp = 1.1 - utils.many2one_match_ic13(det_id, recall_mat, precision_mat, - recall_thr_tmp, precision_thr, gt_match_flag, - det_match_flag, gt_dont_care_index) - with pytest.raises(AssertionError): - precision_thr_tmp = 1.1 - utils.many2one_match_ic13(det_id, recall_mat, precision_mat, - recall_thr, precision_thr_tmp, gt_match_flag, - det_match_flag, gt_dont_care_index) - with pytest.raises(AssertionError): - gt_match_flag_tmp = np.array([0, 1]) - utils.many2one_match_ic13(det_id, recall_mat, precision_mat, - recall_thr, precision_thr, gt_match_flag_tmp, - det_match_flag, gt_dont_care_index) - with pytest.raises(AssertionError): - det_match_flag_tmp = np.array([0, 1]) - utils.many2one_match_ic13(det_id, recall_mat, precision_mat, - recall_thr, precision_thr, gt_match_flag, - det_match_flag_tmp, gt_dont_care_index) - with pytest.raises(AssertionError): - gt_dont_care_index_tmp = np.array([0, 1]) - utils.many2one_match_ic13(det_id, recall_mat, precision_mat, - recall_thr, precision_thr, gt_match_flag, - det_match_flag, gt_dont_care_index_tmp) - - # test matched cases - - result = utils.many2one_match_ic13(det_id, recall_mat, precision_mat, - recall_thr, precision_thr, - gt_match_flag, det_match_flag, - gt_dont_care_index) - assert result[0] - assert result[1] == [0] - - # test unmatched cases - - gt_dont_care_index = [0] - - result = utils.many2one_match_ic13(det_id, recall_mat, precision_mat, - recall_thr, precision_thr, - gt_match_flag, det_match_flag, - gt_dont_care_index) - assert not result[0] - assert result[1] == [] diff --git a/old_tests/test_models/test_detector.py b/old_tests/test_models/test_detector.py deleted file mode 100644 index 474cd8af..00000000 --- a/old_tests/test_models/test_detector.py +++ /dev/null @@ -1,517 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -"""pytest tests/test_detector.py.""" -import copy -import tempfile -from functools import partial -from os.path import dirname, exists, join - -import numpy as np -import pytest -import torch - -from mmocr.utils import revert_sync_batchnorm - - -def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300), - num_items=None, num_classes=1): # yapf: disable - """Create a superset of inputs needed to run test or train batches. - - Args: - input_shape (tuple): Input batch dimensions. - - num_items (None | list[int]): Specifies the number of boxes - for each batch item. - - num_classes (int): Number of distinct labels a box might have. - """ - from mmdet.core import BitmapMasks - - (N, C, H, W) = input_shape - - rng = np.random.RandomState(0) - - imgs = rng.rand(*input_shape) - - img_metas = [{ - 'img_shape': (H, W, C), - 'ori_shape': (H, W, C), - 'pad_shape': (H, W, C), - 'filename': '.png', - 'scale_factor': np.array([1, 1, 1, 1]), - 'flip': False, - } for _ in range(N)] - - gt_bboxes = [] - gt_labels = [] - gt_masks = [] - gt_kernels = [] - gt_effective_mask = [] - - for batch_idx in range(N): - if num_items is None: - num_boxes = rng.randint(1, 10) - else: - num_boxes = num_items[batch_idx] - - cx, cy, bw, bh = rng.rand(num_boxes, 4).T - - tl_x = ((cx * W) - (W * bw / 2)).clip(0, W) - tl_y = ((cy * H) - (H * bh / 2)).clip(0, H) - br_x = ((cx * W) + (W * bw / 2)).clip(0, W) - br_y = ((cy * H) + (H * bh / 2)).clip(0, H) - - boxes = np.vstack([tl_x, tl_y, br_x, br_y]).T - class_idxs = [0] * num_boxes - - gt_bboxes.append(torch.FloatTensor(boxes)) - gt_labels.append(torch.LongTensor(class_idxs)) - kernels = [] - for kernel_inx in range(num_kernels): - kernel = np.random.rand(H, W) - kernels.append(kernel) - gt_kernels.append(BitmapMasks(kernels, H, W)) - gt_effective_mask.append(BitmapMasks([np.ones((H, W))], H, W)) - - mask = np.random.randint(0, 2, (len(boxes), H, W), dtype=np.uint8) - gt_masks.append(BitmapMasks(mask, H, W)) - - mm_inputs = { - 'imgs': torch.FloatTensor(imgs).requires_grad_(True), - 'img_metas': img_metas, - 'gt_bboxes': gt_bboxes, - 'gt_labels': gt_labels, - 'gt_bboxes_ignore': None, - 'gt_masks': gt_masks, - 'gt_kernels': gt_kernels, - 'gt_mask': gt_effective_mask, - 'gt_thr_mask': gt_effective_mask, - 'gt_text_mask': gt_effective_mask, - 'gt_center_region_mask': gt_effective_mask, - 'gt_radius_map': gt_kernels, - 'gt_sin_map': gt_kernels, - 'gt_cos_map': gt_kernels, - } - return mm_inputs - - -def _get_config_directory(): - """Find the predefined detector config directory.""" - try: - # Assume we are running in the source mmocr repo - repo_dpath = dirname(dirname(dirname(__file__))) - except NameError: - # For IPython development when this __file__ is not defined - import mmocr - repo_dpath = dirname(dirname(mmocr.__file__)) - config_dpath = join(repo_dpath, 'configs') - if not exists(config_dpath): - raise Exception('Cannot find config path') - return config_dpath - - -def _get_config_module(fname): - """Load a configuration as a python module.""" - from mmcv import Config - config_dpath = _get_config_directory() - config_fpath = join(config_dpath, fname) - config_mod = Config.fromfile(config_fpath) - return config_mod - - -def _get_detector_cfg(fname): - """Grab configs necessary to create a detector. - - These are deep copied to allow for safe modification of parameters without - influencing other tests. - """ - config = _get_config_module(fname) - model = copy.deepcopy(config.model) - return model - - -@pytest.mark.parametrize('cfg_file', [ - 'textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py', - 'textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py', - 'textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py' -]) -def test_ocr_mask_rcnn(cfg_file): - model = _get_detector_cfg(cfg_file) - model['pretrained'] = None - - from mmocr.models import build_detector - detector = build_detector(model) - - input_shape = (1, 3, 224, 224) - mm_inputs = _demo_mm_inputs(0, input_shape) - - imgs = mm_inputs.pop('imgs') - img_metas = mm_inputs.pop('img_metas') - gt_labels = mm_inputs.pop('gt_labels') - gt_masks = mm_inputs.pop('gt_masks') - - # Test forward train - gt_bboxes = mm_inputs['gt_bboxes'] - losses = detector.forward( - imgs, - img_metas, - gt_bboxes=gt_bboxes, - gt_labels=gt_labels, - gt_masks=gt_masks) - assert isinstance(losses, dict) - - # Test forward test - with torch.no_grad(): - img_list = [g[None, :] for g in imgs] - batch_results = [] - for one_img, one_meta in zip(img_list, img_metas): - result = detector.forward([one_img], [[one_meta]], - return_loss=False) - batch_results.append(result) - - # Test show_result - - results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} - img = np.random.rand(5, 5) - detector.show_result(img, results) - - -@pytest.mark.parametrize('cfg_file', [ - 'textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py', - 'textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py', - 'textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py' -]) -def test_panet(cfg_file): - model = _get_detector_cfg(cfg_file) - model['pretrained'] = None - - from mmocr.models import build_detector - detector = build_detector(model) - detector = revert_sync_batchnorm(detector) - - input_shape = (1, 3, 224, 224) - num_kernels = 2 - mm_inputs = _demo_mm_inputs(num_kernels, input_shape) - - imgs = mm_inputs.pop('imgs') - img_metas = mm_inputs.pop('img_metas') - gt_kernels = mm_inputs.pop('gt_kernels') - gt_mask = mm_inputs.pop('gt_mask') - - # Test forward train - losses = detector.forward( - imgs, img_metas, gt_kernels=gt_kernels, gt_mask=gt_mask) - assert isinstance(losses, dict) - - # Test forward test - with torch.no_grad(): - img_list = [g[None, :] for g in imgs] - batch_results = [] - for one_img, one_meta in zip(img_list, img_metas): - result = detector.forward([one_img], [[one_meta]], - return_loss=False) - batch_results.append(result) - - # Test onnx export - detector.forward = partial( - detector.simple_test, img_metas=img_metas, rescale=True) - with tempfile.TemporaryDirectory() as tmpdirname: - onnx_path = f'{tmpdirname}/tmp.onnx' - torch.onnx.export( - detector, (img_list[0], ), - onnx_path, - input_names=['input'], - output_names=['output'], - export_params=True, - keep_initializers_as_inputs=False) - - # Test show result - results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} - img = np.random.rand(5, 5) - detector.show_result(img, results) - - -@pytest.mark.parametrize('cfg_file', [ - 'textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py', - 'textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py', - 'textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py' -]) -def test_psenet(cfg_file): - model = _get_detector_cfg(cfg_file) - model['pretrained'] = None - - from mmocr.models import build_detector - detector = build_detector(model) - detector = revert_sync_batchnorm(detector) - - input_shape = (1, 3, 224, 224) - num_kernels = 7 - mm_inputs = _demo_mm_inputs(num_kernels, input_shape) - - imgs = mm_inputs.pop('imgs') - img_metas = mm_inputs.pop('img_metas') - gt_kernels = mm_inputs.pop('gt_kernels') - gt_mask = mm_inputs.pop('gt_mask') - - # Test forward train - losses = detector.forward( - imgs, img_metas, gt_kernels=gt_kernels, gt_mask=gt_mask) - assert isinstance(losses, dict) - - # Test forward test - with torch.no_grad(): - img_list = [g[None, :] for g in imgs] - batch_results = [] - for one_img, one_meta in zip(img_list, img_metas): - result = detector.forward([one_img], [[one_meta]], - return_loss=False) - batch_results.append(result) - - # Test show result - results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} - img = np.random.rand(5, 5) - detector.show_result(img, results) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') -@pytest.mark.parametrize('cfg_file', [ - 'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py', - 'textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py' -]) -def test_dbnet(cfg_file): - model = _get_detector_cfg(cfg_file) - model['pretrained'] = None - - from mmocr.models import build_detector - detector = build_detector(model) - detector = revert_sync_batchnorm(detector) - detector = detector.cuda() - input_shape = (1, 3, 224, 224) - num_kernels = 7 - mm_inputs = _demo_mm_inputs(num_kernels, input_shape) - - imgs = mm_inputs.pop('imgs') - imgs = imgs.cuda() - img_metas = mm_inputs.pop('img_metas') - gt_shrink = mm_inputs.pop('gt_kernels') - gt_shrink_mask = mm_inputs.pop('gt_mask') - gt_thr = mm_inputs.pop('gt_masks') - gt_thr_mask = mm_inputs.pop('gt_thr_mask') - - # Test forward train - losses = detector.forward( - imgs, - img_metas, - gt_shrink=gt_shrink, - gt_shrink_mask=gt_shrink_mask, - gt_thr=gt_thr, - gt_thr_mask=gt_thr_mask) - assert isinstance(losses, dict) - - # Test forward test - with torch.no_grad(): - img_list = [g[None, :] for g in imgs] - batch_results = [] - for one_img, one_meta in zip(img_list, img_metas): - result = detector.forward([one_img], [[one_meta]], - return_loss=False) - batch_results.append(result) - - # Test show result - results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} - img = np.random.rand(5, 5) - detector.show_result(img, results) - - -@pytest.mark.parametrize( - 'cfg_file', - ['textdet/textsnake/' - 'textsnake_r50_fpn_unet_1200e_ctw1500.py']) -def test_textsnake(cfg_file): - model = _get_detector_cfg(cfg_file) - model['pretrained'] = None - - from mmocr.models import build_detector - detector = build_detector(model) - detector = revert_sync_batchnorm(detector) - input_shape = (1, 3, 224, 224) - num_kernels = 1 - mm_inputs = _demo_mm_inputs(num_kernels, input_shape) - - imgs = mm_inputs.pop('imgs') - img_metas = mm_inputs.pop('img_metas') - gt_text_mask = mm_inputs.pop('gt_text_mask') - gt_center_region_mask = mm_inputs.pop('gt_center_region_mask') - gt_mask = mm_inputs.pop('gt_mask') - gt_radius_map = mm_inputs.pop('gt_radius_map') - gt_sin_map = mm_inputs.pop('gt_sin_map') - gt_cos_map = mm_inputs.pop('gt_cos_map') - - # Test forward train - losses = detector.forward( - imgs, - img_metas, - gt_text_mask=gt_text_mask, - gt_center_region_mask=gt_center_region_mask, - gt_mask=gt_mask, - gt_radius_map=gt_radius_map, - gt_sin_map=gt_sin_map, - gt_cos_map=gt_cos_map) - assert isinstance(losses, dict) - - # Test forward test get_boundary - maps = torch.zeros((1, 5, 224, 224), dtype=torch.float) - maps[:, 0:2, :, :] = -10. - maps[:, 0, 60:100, 12:212] = 10. - maps[:, 1, 70:90, 22:202] = 10. - maps[:, 2, 70:90, 22:202] = 0. - maps[:, 3, 70:90, 22:202] = 1. - maps[:, 4, 70:90, 22:202] = 10. - - one_meta = img_metas[0] - result = detector.bbox_head.get_boundary(maps, [one_meta], False) - assert 'boundary_result' in result - assert 'filename' in result - - # Test show result - results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} - img = np.random.rand(5, 5) - detector.show_result(img, results) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') -@pytest.mark.parametrize('cfg_file', [ - 'textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py', - 'textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py' -]) -def test_fcenet(cfg_file): - model = _get_detector_cfg(cfg_file) - model['pretrained'] = None - - from mmocr.models import build_detector - detector = build_detector(model) - detector = revert_sync_batchnorm(detector) - detector = detector.cuda() - - fourier_degree = 5 - input_shape = (1, 3, 256, 256) - (n, c, h, w) = input_shape - - imgs = torch.randn(n, c, h, w).float().cuda() - img_metas = [{ - 'img_shape': (h, w, c), - 'ori_shape': (h, w, c), - 'pad_shape': (h, w, c), - 'filename': '.png', - 'scale_factor': np.array([1, 1, 1, 1]), - 'flip': False, - } for _ in range(n)] - - p3_maps = [] - p4_maps = [] - p5_maps = [] - for _ in range(n): - p3_maps.append( - np.random.random((5 + 4 * fourier_degree, h // 8, w // 8))) - p4_maps.append( - np.random.random((5 + 4 * fourier_degree, h // 16, w // 16))) - p5_maps.append( - np.random.random((5 + 4 * fourier_degree, h // 32, w // 32))) - - # Test forward train - losses = detector.forward( - imgs, img_metas, p3_maps=p3_maps, p4_maps=p4_maps, p5_maps=p5_maps) - assert isinstance(losses, dict) - - # Test forward test - with torch.no_grad(): - img_list = [g[None, :] for g in imgs] - batch_results = [] - for one_img, one_meta in zip(img_list, img_metas): - result = detector.forward([one_img], [[one_meta]], - return_loss=False) - batch_results.append(result) - - # Test show result - results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} - img = np.random.rand(5, 5) - detector.show_result(img, results) - - -@pytest.mark.parametrize( - 'cfg_file', ['textdet/drrg/' - 'drrg_r50_fpn_unet_1200e_ctw1500.py']) -def test_drrg(cfg_file): - model = _get_detector_cfg(cfg_file) - model['pretrained'] = None - - from mmocr.models import build_detector - detector = build_detector(model) - detector = revert_sync_batchnorm(detector) - - input_shape = (1, 3, 224, 224) - num_kernels = 1 - mm_inputs = _demo_mm_inputs(num_kernels, input_shape) - - imgs = mm_inputs.pop('imgs') - img_metas = mm_inputs.pop('img_metas') - gt_text_mask = mm_inputs.pop('gt_text_mask') - gt_center_region_mask = mm_inputs.pop('gt_center_region_mask') - gt_mask = mm_inputs.pop('gt_mask') - gt_top_height_map = mm_inputs.pop('gt_radius_map') - gt_bot_height_map = gt_top_height_map.copy() - gt_sin_map = mm_inputs.pop('gt_sin_map') - gt_cos_map = mm_inputs.pop('gt_cos_map') - num_rois = 32 - x = np.random.randint(4, 224, (num_rois, 1)) - y = np.random.randint(4, 224, (num_rois, 1)) - h = 4 * np.ones((num_rois, 1)) - w = 4 * np.ones((num_rois, 1)) - angle = (np.random.random_sample((num_rois, 1)) * 2 - 1) * np.pi / 2 - cos, sin = np.cos(angle), np.sin(angle) - comp_labels = np.random.randint(1, 3, (num_rois, 1)) - num_rois = num_rois * np.ones((num_rois, 1)) - comp_attribs = np.hstack([num_rois, x, y, h, w, cos, sin, comp_labels]) - gt_comp_attribs = np.expand_dims(comp_attribs.astype(np.float32), axis=0) - - # Test forward train - losses = detector.forward( - imgs, - img_metas, - gt_text_mask=gt_text_mask, - gt_center_region_mask=gt_center_region_mask, - gt_mask=gt_mask, - gt_top_height_map=gt_top_height_map, - gt_bot_height_map=gt_bot_height_map, - gt_sin_map=gt_sin_map, - gt_cos_map=gt_cos_map, - gt_comp_attribs=gt_comp_attribs) - assert isinstance(losses, dict) - - # Test forward test - model['bbox_head']['in_channels'] = 6 - model['bbox_head']['text_region_thr'] = 0.8 - model['bbox_head']['center_region_thr'] = 0.8 - detector = build_detector(model) - maps = torch.zeros((1, 6, 224, 224), dtype=torch.float) - maps[:, 0:2, :, :] = -10. - maps[:, 0, 60:100, 50:170] = 10. - maps[:, 1, 75:85, 60:160] = 10. - maps[:, 2, 75:85, 60:160] = 0. - maps[:, 3, 75:85, 60:160] = 1. - maps[:, 4, 75:85, 60:160] = 10. - maps[:, 5, 75:85, 60:160] = 10. - - with torch.no_grad(): - full_pass_weight = torch.zeros((6, 6, 1, 1)) - for i in range(6): - full_pass_weight[i, i, 0, 0] = 1 - detector.bbox_head.out_conv.weight.data = full_pass_weight - detector.bbox_head.out_conv.bias.data.fill_(0.) - outs = detector.bbox_head.single_test(maps) - boundaries = detector.bbox_head.get_boundary(*outs, img_metas, True) - assert len(boundaries) == 1 - - # Test show result - results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} - img = np.random.rand(5, 5) - detector.show_result(img, results) diff --git a/old_tests/test_models/test_kie_config.py b/old_tests/test_models/test_kie_config.py deleted file mode 100644 index b2b1f351..00000000 --- a/old_tests/test_models/test_kie_config.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -from os.path import dirname, exists, join - -import numpy as np -import pytest -import torch - - -def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300), - num_items=None): # yapf: disable - """Create a superset of inputs needed to run test or train batches. - - Args: - input_shape (tuple): Input batch dimensions. - - num_items (None | list[int]): Specifies the number of boxes - for each batch item. - """ - - (N, C, H, W) = input_shape - rng = np.random.RandomState(0) - imgs = rng.rand(*input_shape) - - img_metas = [{ - 'img_shape': (H, W, C), - 'ori_shape': (H, W, C), - 'pad_shape': (H, W, C), - 'filename': '.png', - } for _ in range(N)] - relations = [torch.randn(10, 10, 5) for _ in range(N)] - texts = [torch.ones(10, 16) for _ in range(N)] - gt_bboxes = [torch.Tensor([[2, 2, 4, 4]]).expand(10, 4) for _ in range(N)] - gt_labels = [torch.ones(10, 11).long() for _ in range(N)] - - mm_inputs = { - 'imgs': torch.FloatTensor(imgs).requires_grad_(True), - 'img_metas': img_metas, - 'relations': relations, - 'texts': texts, - 'gt_bboxes': gt_bboxes, - 'gt_labels': gt_labels - } - return mm_inputs - - -def _get_config_directory(): - """Find the predefined detector config directory.""" - try: - # Assume we are running in the source mmocr repo - repo_dpath = dirname(dirname(dirname(__file__))) - except NameError: - # For IPython development when this __file__ is not defined - import mmocr - repo_dpath = dirname(dirname(mmocr.__file__)) - config_dpath = join(repo_dpath, 'configs') - if not exists(config_dpath): - raise Exception('Cannot find config path') - return config_dpath - - -def _get_config_module(fname): - """Load a configuration as a python module.""" - from mmcv import Config - config_dpath = _get_config_directory() - config_fpath = join(config_dpath, fname) - config_mod = Config.fromfile(config_fpath) - return config_mod - - -def _get_detector_cfg(fname): - """Grab configs necessary to create a detector. - - These are deep copied to allow for safe modification of parameters without - influencing other tests. - """ - config = _get_config_module(fname) - config.model.class_list = None - model = copy.deepcopy(config.model) - return model - - -@pytest.mark.parametrize('cfg_file', [ - 'kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py', - 'kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py' -]) -def test_sdmgr_pipeline(cfg_file): - model = _get_detector_cfg(cfg_file) - - from mmocr.models import build_detector - detector = build_detector(model) - - input_shape = (1, 3, 128, 128) - - mm_inputs = _demo_mm_inputs(0, input_shape) - - imgs = mm_inputs.pop('imgs') - img_metas = mm_inputs.pop('img_metas') - relations = mm_inputs.pop('relations') - texts = mm_inputs.pop('texts') - gt_bboxes = mm_inputs.pop('gt_bboxes') - gt_labels = mm_inputs.pop('gt_labels') - - # Test forward train - losses = detector.forward( - imgs, - img_metas, - relations=relations, - texts=texts, - gt_bboxes=gt_bboxes, - gt_labels=gt_labels) - assert isinstance(losses, dict) - - # Test forward test - with torch.no_grad(): - batch_results = [] - for idx in range(len(img_metas)): - result = detector.forward( - imgs[idx:idx + 1], - None, - return_loss=False, - relations=[relations[idx]], - texts=[texts[idx]], - gt_bboxes=[gt_bboxes[idx]]) - batch_results.append(result) - - # Test show_result - results = {'nodes': torch.randn(1, 3)} - boxes = [[1, 1, 2, 1, 2, 2, 1, 2]] - img = np.random.rand(5, 5, 3) - detector.show_result(img, results, boxes) diff --git a/old_tests/test_models/test_loss.py b/old_tests/test_models/test_loss.py deleted file mode 100644 index 8488c222..00000000 --- a/old_tests/test_models/test_loss.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import numpy as np -import torch -from mmdet.core import BitmapMasks - -import mmocr.models.textdet.module_losses as module_losses - - -def test_panloss(): - panloss = module_losses.PANModuleLoss() - - # test bitmasks2tensor - mask = [[1, 0, 1], [1, 1, 1], [0, 0, 1]] - target = [[1, 0, 1, 0, 0], [1, 1, 1, 0, 0], [0, 0, 1, 0, 0], - [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] - masks = [np.array(mask)] - bitmasks = BitmapMasks(masks, 3, 3) - target_sz = (6, 5) - results = panloss.bitmasks2tensor([bitmasks], target_sz) - assert len(results) == 1 - assert torch.sum(torch.abs(results[0].float() - - torch.Tensor(target))).item() == 0 - - -def test_fcenetloss(): - k = 5 - fcenetloss = module_losses.FCEModuleLoss(fourier_degree=k, num_sample=10) - - input_shape = (1, 3, 64, 64) - (n, c, h, w) = input_shape - - # test ohem - pred = torch.ones((200, 2), dtype=torch.float) - target = torch.ones(200, dtype=torch.long) - target[20:] = 0 - mask = torch.ones(200, dtype=torch.long) - - ohem_loss1 = fcenetloss.ohem(pred, target, mask) - ohem_loss2 = fcenetloss.ohem(pred, target, 1 - mask) - assert isinstance(ohem_loss1, torch.Tensor) - assert isinstance(ohem_loss2, torch.Tensor) - - # test forward - preds = [] - for i in range(n): - scale = 8 * 2**i - pred = [] - pred.append(torch.rand(n, 4, h // scale, w // scale)) - pred.append(torch.rand(n, 4 * k + 2, h // scale, w // scale)) - preds.append(pred) - - p3_maps = [] - p4_maps = [] - p5_maps = [] - for _ in range(n): - p3_maps.append(np.random.random((5 + 4 * k, h // 8, w // 8))) - p4_maps.append(np.random.random((5 + 4 * k, h // 16, w // 16))) - p5_maps.append(np.random.random((5 + 4 * k, h // 32, w // 32))) - - loss = fcenetloss(preds, 0, p3_maps, p4_maps, p5_maps) - assert isinstance(loss, dict) - - -def test_drrgloss(): - drrgloss = module_losses.DRRGModuleLoss() - assert np.allclose(drrgloss.ohem_ratio, 3.0) - - # test balance_bce_loss - pred = torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=torch.float) - target = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.long) - mask = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.long) - bce_loss = drrgloss.balance_bce_loss(pred, target, mask).item() - assert np.allclose(bce_loss, 0) - - # test balance_bce_loss with positive_count equal to zero - pred = torch.ones((16, 16), dtype=torch.float) - target = torch.ones((16, 16), dtype=torch.long) - mask = torch.zeros((16, 16), dtype=torch.long) - bce_loss = drrgloss.balance_bce_loss(pred, target, mask).item() - assert np.allclose(bce_loss, 0) - - # test gcn_loss - gcn_preds = torch.tensor([[0., 1.], [1., 0.]]) - labels = torch.tensor([1, 0], dtype=torch.long) - gcn_loss = drrgloss.gcn_loss((gcn_preds, labels)) - assert gcn_loss.item() - - # test bitmasks2tensor - mask = [[1, 0, 1], [1, 1, 1], [0, 0, 1]] - target = [[1, 0, 1, 0, 0], [1, 1, 1, 0, 0], [0, 0, 1, 0, 0], - [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]] - masks = [np.array(mask)] - bitmasks = BitmapMasks(masks, 3, 3) - target_sz = (6, 5) - results = drrgloss.bitmasks2tensor([bitmasks], target_sz) - assert len(results) == 1 - assert torch.sum(torch.abs(results[0].float() - - torch.Tensor(target))).item() == 0 - - # test forward - target_maps = [BitmapMasks([np.random.randn(20, 20)], 20, 20)] - target_masks = [BitmapMasks([np.ones((20, 20))], 20, 20)] - gt_masks = [BitmapMasks([np.ones((20, 20))], 20, 20)] - preds = (torch.randn((1, 6, 20, 20)), (gcn_preds, labels)) - loss_dict = drrgloss(preds, 1., target_masks, target_masks, gt_masks, - target_maps, target_maps, target_maps, target_maps) - - assert isinstance(loss_dict, dict) - assert 'loss_text' in loss_dict.keys() - assert 'loss_center' in loss_dict.keys() - assert 'loss_height' in loss_dict.keys() - assert 'loss_sin' in loss_dict.keys() - assert 'loss_cos' in loss_dict.keys() - assert 'loss_gcn' in loss_dict.keys() - - # test forward with downsample_ratio less than 1. - target_maps = [BitmapMasks([np.random.randn(40, 40)], 40, 40)] - target_masks = [BitmapMasks([np.ones((40, 40))], 40, 40)] - gt_masks = [BitmapMasks([np.ones((40, 40))], 40, 40)] - preds = (torch.randn((1, 6, 20, 20)), (gcn_preds, labels)) - loss_dict = drrgloss(preds, 0.5, target_masks, target_masks, gt_masks, - target_maps, target_maps, target_maps, target_maps) - - assert isinstance(loss_dict, dict) - - # test forward with blank gt_mask. - target_maps = [BitmapMasks([np.random.randn(20, 20)], 20, 20)] - target_masks = [BitmapMasks([np.ones((20, 20))], 20, 20)] - gt_masks = [BitmapMasks([np.zeros((20, 20))], 20, 20)] - preds = (torch.randn((1, 6, 20, 20)), (gcn_preds, labels)) - loss_dict = drrgloss(preds, 1., target_masks, target_masks, gt_masks, - target_maps, target_maps, target_maps, target_maps) - - assert isinstance(loss_dict, dict) - - -def test_dice_loss(): - pred = torch.Tensor([[[-1000, -1000, -1000], [-1000, -1000, -1000], - [-1000, -1000, -1000]]]) - target = torch.Tensor([[[0, 0, 0], [0, 0, 0], [0, 0, 0]]]) - mask = torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]]]) - - pan_loss = module_losses.PANModuleLoss() - - dice_loss = pan_loss.dice_loss_with_logits(pred, target, mask) - - assert np.allclose(dice_loss.item(), 0) diff --git a/old_tests/test_models/test_modules.py b/old_tests/test_models/test_modules.py deleted file mode 100644 index 9e19ea3b..00000000 --- a/old_tests/test_models/test_modules.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import numpy as np -import torch - -from mmocr.models.textdet.modules import GCN, LocalGraphs, ProposalLocalGraphs -from mmocr.models.textdet.modules.utils import (feature_embedding, - normalize_adjacent_matrix) - - -def test_local_graph_forward_train(): - geo_feat_len = 24 - pooling_h, pooling_w = pooling_out_size = (2, 2) - num_rois = 32 - - local_graph_generator = LocalGraphs((4, 4), 3, geo_feat_len, 1.0, - pooling_out_size, 0.5) - - feature_maps = torch.randn((2, 3, 128, 128), dtype=torch.float) - x = np.random.randint(4, 124, (num_rois, 1)) - y = np.random.randint(4, 124, (num_rois, 1)) - h = 4 * np.ones((num_rois, 1)) - w = 4 * np.ones((num_rois, 1)) - angle = (np.random.random_sample((num_rois, 1)) * 2 - 1) * np.pi / 2 - cos, sin = np.cos(angle), np.sin(angle) - comp_labels = np.random.randint(1, 3, (num_rois, 1)) - num_rois = num_rois * np.ones((num_rois, 1)) - comp_attribs = np.hstack([num_rois, x, y, h, w, cos, sin, comp_labels]) - comp_attribs = comp_attribs.astype(np.float32) - comp_attribs_ = comp_attribs.copy() - comp_attribs = np.stack([comp_attribs, comp_attribs_]) - - (node_feats, adjacent_matrix, knn_inds, - linkage_labels) = local_graph_generator(feature_maps, comp_attribs) - feat_len = geo_feat_len + feature_maps.size()[1] * pooling_h * pooling_w - - assert node_feats.dim() == adjacent_matrix.dim() == 3 - assert node_feats.size()[-1] == feat_len - assert knn_inds.size()[-1] == 4 - assert linkage_labels.size()[-1] == 4 - assert (node_feats.size()[0] == adjacent_matrix.size()[0] == - knn_inds.size()[0] == linkage_labels.size()[0]) - assert (node_feats.size()[1] == adjacent_matrix.size()[1] == - adjacent_matrix.size()[2]) - - -def test_local_graph_forward_test(): - geo_feat_len = 24 - pooling_h, pooling_w = pooling_out_size = (2, 2) - - local_graph_generator = ProposalLocalGraphs( - (4, 4), 2, geo_feat_len, 1., pooling_out_size, 0.1, 3., 6., 1., 0.5, - 0.3, 0.5, 0.5, 2) - - maps = torch.zeros((1, 6, 224, 224), dtype=torch.float) - maps[:, 0:2, :, :] = -10. - maps[:, 0, 60:100, 50:170] = 10. - maps[:, 1, 75:85, 60:160] = 10. - maps[:, 2, 75:85, 60:160] = 0. - maps[:, 3, 75:85, 60:160] = 1. - maps[:, 4, 75:85, 60:160] = 10. - maps[:, 5, 75:85, 60:160] = 10. - feature_maps = torch.randn((2, 6, 224, 224), dtype=torch.float) - feat_len = geo_feat_len + feature_maps.size()[1] * pooling_h * pooling_w - - none_flag, graph_data = local_graph_generator(maps, feature_maps) - (node_feats, adjacent_matrices, knn_inds, local_graphs, - text_comps) = graph_data - - assert none_flag is False - assert text_comps.ndim == 2 - assert text_comps.shape[0] > 0 - assert text_comps.shape[1] == 9 - assert (node_feats.size()[0] == adjacent_matrices.size()[0] == - knn_inds.size()[0] == local_graphs.size()[0] == - text_comps.shape[0]) - assert (node_feats.size()[1] == adjacent_matrices.size()[1] == - adjacent_matrices.size()[2] == local_graphs.size()[1]) - assert node_feats.size()[-1] == feat_len - - # test proposal local graphs with area of center region less than threshold - maps[:, 1, 75:85, 60:160] = -10. - maps[:, 1, 80, 80] = 10. - none_flag, _ = local_graph_generator(maps, feature_maps) - assert none_flag - - # test proposal local graphs with one text component - local_graph_generator = ProposalLocalGraphs( - (4, 4), 2, geo_feat_len, 1., pooling_out_size, 0.1, 8., 20., 1., 0.5, - 0.3, 0.5, 0.5, 2) - maps[:, 1, 78:82, 78:82] = 10. - none_flag, _ = local_graph_generator(maps, feature_maps) - assert none_flag - - # test proposal local graphs with text components out of text region - maps[:, 0, 60:100, 50:170] = -10. - maps[:, 0, 78:82, 78:82] = 10. - none_flag, _ = local_graph_generator(maps, feature_maps) - assert none_flag - - -def test_gcn(): - num_local_graphs = 32 - num_max_graph_nodes = 16 - input_feat_len = 512 - k = 8 - gcn = GCN(input_feat_len) - node_feat = torch.randn( - (num_local_graphs, num_max_graph_nodes, input_feat_len)) - adjacent_matrix = torch.rand( - (num_local_graphs, num_max_graph_nodes, num_max_graph_nodes)) - knn_inds = torch.randint(1, num_max_graph_nodes, (num_local_graphs, k)) - output = gcn(node_feat, adjacent_matrix, knn_inds) - assert output.size() == (num_local_graphs * k, 2) - - -def test_normalize_adjacent_matrix(): - adjacent_matrix = np.random.randint(0, 2, (16, 16)) - normalized_matrix = normalize_adjacent_matrix(adjacent_matrix) - assert normalized_matrix.shape == adjacent_matrix.shape - - -def test_feature_embedding(): - out_feat_len = 48 - - # test without residue dimensions - feats = np.random.randn(10, 8) - embed_feats = feature_embedding(feats, out_feat_len) - assert embed_feats.shape == (10, out_feat_len) - - # test with residue dimensions - feats = np.random.randn(10, 9) - embed_feats = feature_embedding(feats, out_feat_len) - assert embed_feats.shape == (10, out_feat_len) diff --git a/old_tests/test_models/test_ocr_backbone.py b/old_tests/test_models/test_ocr_backbone.py deleted file mode 100644 index 5139458c..00000000 --- a/old_tests/test_models/test_ocr_backbone.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pytest -import torch - -from mmocr.models.textrecog.backbones import (ResNet, ResNet31OCR, ResNetABI, - VeryDeepVgg) - - -def test_resnet31_ocr_backbone(): - """Test resnet backbone.""" - with pytest.raises(AssertionError): - ResNet31OCR(2.5) - - with pytest.raises(AssertionError): - ResNet31OCR(3, layers=5) - - with pytest.raises(AssertionError): - ResNet31OCR(3, channels=5) - - # Test ResNet18 forward - model = ResNet31OCR() - model.init_weights() - model.train() - - imgs = torch.randn(1, 3, 32, 160) - feat = model(imgs) - assert feat.shape == torch.Size([1, 512, 4, 40]) - - -def test_vgg_deep_vgg_ocr_backbone(): - - model = VeryDeepVgg() - model.init_weights() - model.train() - - imgs = torch.randn(1, 3, 32, 160) - feats = model(imgs) - assert feats.shape == torch.Size([1, 512, 1, 41]) - - -def test_resnet_abi(): - """Test resnet backbone.""" - with pytest.raises(AssertionError): - ResNetABI(2.5) - - with pytest.raises(AssertionError): - ResNetABI(3, arch_settings=5) - - with pytest.raises(AssertionError): - ResNetABI(3, stem_channels=None) - - with pytest.raises(AssertionError): - ResNetABI(arch_settings=[3, 4, 6, 6], strides=[1, 2, 1, 2, 1]) - - # Test forwarding - model = ResNetABI() - model.train() - - imgs = torch.randn(1, 3, 32, 160) - feat = model(imgs) - assert feat.shape == torch.Size([1, 512, 8, 40]) - - -def test_resnet(): - """Test all ResNet backbones.""" - - resnet45_aster = ResNet( - in_channels=3, - stem_channels=[64, 128], - block_cfgs=dict(type='BasicBlock', use_conv1x1='True'), - arch_layers=[3, 4, 6, 6, 3], - arch_channels=[32, 64, 128, 256, 512], - strides=[(2, 2), (2, 2), (2, 1), (2, 1), (2, 1)]) - - resnet45_abi = ResNet( - in_channels=3, - stem_channels=32, - block_cfgs=dict(type='BasicBlock', use_conv1x1='True'), - arch_layers=[3, 4, 6, 6, 3], - arch_channels=[32, 64, 128, 256, 512], - strides=[2, 1, 2, 1, 1]) - - resnet_31 = ResNet( - in_channels=3, - stem_channels=[64, 128], - block_cfgs=dict(type='BasicBlock'), - arch_layers=[1, 2, 5, 3], - arch_channels=[256, 256, 512, 512], - strides=[1, 1, 1, 1], - plugins=[ - dict( - cfg=dict(type='Maxpool2d', kernel_size=2, stride=(2, 2)), - stages=(True, True, False, False), - position='before_stage'), - dict( - cfg=dict(type='Maxpool2d', kernel_size=(2, 1), stride=(2, 1)), - stages=(False, False, True, False), - position='before_stage'), - dict( - cfg=dict( - type='ConvModule', - kernel_size=3, - stride=1, - padding=1, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU')), - stages=(True, True, True, True), - position='after_stage') - ]) - - resnet31_master = ResNet( - in_channels=3, - stem_channels=[64, 128], - block_cfgs=dict(type='BasicBlock'), - arch_layers=[1, 2, 5, 3], - arch_channels=[256, 256, 512, 512], - strides=[1, 1, 1, 1], - plugins=[ - dict( - cfg=dict(type='Maxpool2d', kernel_size=2, stride=(2, 2)), - stages=(True, True, False, False), - position='before_stage'), - dict( - cfg=dict(type='Maxpool2d', kernel_size=(2, 1), stride=(2, 1)), - stages=(False, False, True, False), - position='before_stage'), - dict( - cfg=dict(type='GCAModule', ratio=0.0625, n_head=1), - stages=[True, True, True, True], - position='after_stage'), - dict( - cfg=dict( - type='ConvModule', - kernel_size=3, - stride=1, - padding=1, - norm_cfg=dict(type='BN'), - act_cfg=dict(type='ReLU')), - stages=(True, True, True, True), - position='after_stage') - ]) - img = torch.rand(1, 3, 32, 100) - - assert resnet45_aster(img).shape == torch.Size([1, 512, 1, 25]) - assert resnet45_abi(img).shape == torch.Size([1, 512, 8, 25]) - assert resnet_31(img).shape == torch.Size([1, 512, 4, 25]) - assert resnet31_master(img).shape == torch.Size([1, 512, 4, 25]) diff --git a/old_tests/test_models/test_ocr_layer.py b/old_tests/test_models/test_ocr_layer.py deleted file mode 100644 index e4b4a39b..00000000 --- a/old_tests/test_models/test_ocr_layer.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch - -from mmocr.models.common import (PositionalEncoding, TFDecoderLayer, - TFEncoderLayer) -from mmocr.models.textrecog.layers import BasicBlock, Bottleneck -from mmocr.models.textrecog.layers.conv_layer import conv3x3 - - -def test_conv_layer(): - conv3by3 = conv3x3(3, 6) - assert conv3by3.in_channels == 3 - assert conv3by3.out_channels == 6 - assert conv3by3.kernel_size == (3, 3) - - x = torch.rand(1, 64, 224, 224) - # test basic block - basic_block = BasicBlock(64, 64) - assert basic_block.expansion == 1 - - out = basic_block(x) - - assert out.shape == torch.Size([1, 64, 224, 224]) - - # test bottle neck - bottle_neck = Bottleneck(64, 64, downsample=True) - assert bottle_neck.expansion == 4 - - out = bottle_neck(x) - - assert out.shape == torch.Size([1, 256, 224, 224]) - - -def test_transformer_layer(): - # test decoder_layer - decoder_layer = TFDecoderLayer() - in_dec = torch.rand(1, 30, 512) - out_enc = torch.rand(1, 128, 512) - out_dec = decoder_layer(in_dec, out_enc) - assert out_dec.shape == torch.Size([1, 30, 512]) - - decoder_layer = TFDecoderLayer( - operation_order=('self_attn', 'norm', 'enc_dec_attn', 'norm', 'ffn', - 'norm')) - out_dec = decoder_layer(in_dec, out_enc) - assert out_dec.shape == torch.Size([1, 30, 512]) - - # test positional_encoding - pos_encoder = PositionalEncoding() - x = torch.rand(1, 30, 512) - out = pos_encoder(x) - assert out.size() == x.size() - - # test encoder_layer - encoder_layer = TFEncoderLayer() - in_enc = torch.rand(1, 20, 512) - out_enc = encoder_layer(in_enc) - assert out_dec.shape == torch.Size([1, 30, 512]) - - encoder_layer = TFEncoderLayer( - operation_order=('self_attn', 'norm', 'ffn', 'norm')) - out_enc = encoder_layer(in_enc) - assert out_dec.shape == torch.Size([1, 30, 512]) diff --git a/old_tests/test_models/test_ocr_preprocessor.py b/old_tests/test_models/test_ocr_preprocessor.py deleted file mode 100644 index c0a50463..00000000 --- a/old_tests/test_models/test_ocr_preprocessor.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pytest -import torch - -from mmocr.models.textrecog.preprocessors import (BasePreprocessor, - TPSPreprocessor) - - -def test_tps_preprocessor(): - with pytest.raises(AssertionError): - TPSPreprocessor(num_fiducial=-1) - with pytest.raises(AssertionError): - TPSPreprocessor(img_size=32) - with pytest.raises(AssertionError): - TPSPreprocessor(rectified_img_size=100) - with pytest.raises(AssertionError): - TPSPreprocessor(num_img_channel='bgr') - - tps_preprocessor = TPSPreprocessor( - num_fiducial=20, - img_size=(32, 100), - rectified_img_size=(32, 100), - num_img_channel=1) - tps_preprocessor.init_weights() - tps_preprocessor.train() - - batch_img = torch.randn(1, 1, 32, 100) - processed = tps_preprocessor(batch_img) - assert processed.shape == torch.Size([1, 1, 32, 100]) - - -def test_base_preprocessor(): - preprocessor = BasePreprocessor() - preprocessor.init_weights() - preprocessor.train() - - batch_img = torch.randn(1, 1, 32, 100) - processed = preprocessor(batch_img) - assert processed.shape == torch.Size([1, 1, 32, 100]) diff --git a/old_tests/test_models/test_recog_config.py b/old_tests/test_models/test_recog_config.py deleted file mode 100644 index 5084f4ad..00000000 --- a/old_tests/test_models/test_recog_config.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -from os.path import dirname, exists, join - -import numpy as np -import pytest -import torch - - -def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300), - num_items=None): # yapf: disable - """Create a superset of inputs needed to run test or train batches. - - Args: - input_shape (tuple): Input batch dimensions. - - num_items (None | list[int]): Specifies the number of boxes - for each batch item. - """ - - (N, C, H, W) = input_shape - - rng = np.random.RandomState(0) - - imgs = rng.rand(*input_shape) - - img_metas = [{ - 'img_shape': (H, W, C), - 'ori_shape': (H, W, C), - 'resize_shape': (H, W, C), - 'filename': '.png', - 'text': 'hello', - 'valid_ratio': 1.0, - } for _ in range(N)] - - mm_inputs = { - 'imgs': torch.FloatTensor(imgs).requires_grad_(True), - 'img_metas': img_metas - } - return mm_inputs - - -def _demo_gt_kernel_inputs(num_kernels=3, input_shape=(1, 3, 300, 300), - num_items=None): # yapf: disable - """Create a superset of inputs needed to run test or train batches. - - Args: - input_shape (tuple): Input batch dimensions. - - num_items (None | list[int]): Specifies the number of boxes - for each batch item. - """ - from mmdet.core import BitmapMasks - - (N, C, H, W) = input_shape - gt_kernels = [] - - for batch_idx in range(N): - kernels = [] - for kernel_inx in range(num_kernels): - kernel = np.random.rand(H, W) - kernels.append(kernel) - gt_kernels.append(BitmapMasks(kernels, H, W)) - - return gt_kernels - - -def _get_config_directory(): - """Find the predefined detector config directory.""" - try: - # Assume we are running in the source mmocr repo - repo_dpath = dirname(dirname(dirname(__file__))) - except NameError: - # For IPython development when this __file__ is not defined - import mmocr - repo_dpath = dirname(dirname(mmocr.__file__)) - config_dpath = join(repo_dpath, 'configs') - if not exists(config_dpath): - raise Exception('Cannot find config path') - return config_dpath - - -def _get_config_module(fname): - """Load a configuration as a python module.""" - from mmcv import Config - config_dpath = _get_config_directory() - config_fpath = join(config_dpath, fname) - config_mod = Config.fromfile(config_fpath) - return config_mod - - -def _get_detector_cfg(fname): - """Grab configs necessary to create a detector. - - These are deep copied to allow for safe modification of parameters without - influencing other tests. - """ - config = _get_config_module(fname) - model = copy.deepcopy(config.model) - return model - - -@pytest.mark.parametrize('cfg_file', [ - 'textrecog/sar/sar_r31_parallel_decoder_academic.py', - 'textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py', - 'textrecog/sar/sar_r31_sequential_decoder_academic.py', - 'textrecog/crnn/crnn_toy_dataset.py', - 'textrecog/crnn/crnn_academic_dataset.py', - 'textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py', - 'textrecog/nrtr/nrtr_modality_transform_academic.py', - 'textrecog/nrtr/nrtr_modality_transform_toy_dataset.py', - 'textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py', - 'textrecog/robust_scanner/robustscanner_r31_academic.py', - 'textrecog/seg/seg_r31_1by16_fpnocr_academic.py', - 'textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py', - 'textrecog/satrn/satrn_academic.py', 'textrecog/satrn/satrn_small.py', - 'textrecog/tps/crnn_tps_academic_dataset.py' -]) -def test_recognizer_pipeline(cfg_file): - model = _get_detector_cfg(cfg_file) - model['pretrained'] = None - - from mmocr.models import build_detector - detector = build_detector(model) - - input_shape = (1, 3, 32, 160) - if 'crnn' in cfg_file: - input_shape = (1, 1, 32, 160) - mm_inputs = _demo_mm_inputs(0, input_shape) - gt_kernels = None - if 'seg' in cfg_file: - gt_kernels = _demo_gt_kernel_inputs(3, input_shape) - - imgs = mm_inputs.pop('imgs') - img_metas = mm_inputs.pop('img_metas') - - # Test forward train - if 'seg' in cfg_file: - losses = detector.forward(imgs, img_metas, gt_kernels=gt_kernels) - else: - losses = detector.forward(imgs, img_metas) - assert isinstance(losses, dict) - - # Test forward test - with torch.no_grad(): - img_list = [g[None, :] for g in imgs] - batch_results = [] - for one_img, one_meta in zip(img_list, img_metas): - result = detector.forward([one_img], [[one_meta]], - return_loss=False) - batch_results.append(result) - - # Test show_result - - results = {'text': 'hello', 'score': 1.0} - img = np.random.rand(5, 5, 3) - detector.show_result(img, results) diff --git a/tests/test_evaluation/functional/test_hmean.py b/tests/test_evaluation/functional/test_hmean.py new file mode 100644 index 00000000..9fdde228 --- /dev/null +++ b/tests/test_evaluation/functional/test_hmean.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +from mmocr.evaluation.functional import compute_hmean + + +class TestHmean(TestCase): + + def test_compute_hmean(self): + with self.assertRaises(AssertionError): + compute_hmean(0, 0, 0.0, 0) + with self.assertRaises(AssertionError): + compute_hmean(0, 0, 0, 0.0) + with self.assertRaises(AssertionError): + compute_hmean([1], 0, 0, 0) + with self.assertRaises(AssertionError): + compute_hmean(0, [1], 0, 0) + + _, _, hmean = compute_hmean(2, 2, 2, 2) + self.assertEqual(hmean, 1) + + _, _, hmean = compute_hmean(0, 0, 2, 2) + self.assertEqual(hmean, 0) diff --git a/tests/test_evaluation/test_metrics/test_hmean_iou.py b/tests/test_evaluation/test_metrics/test_hmean_iou_metric.py similarity index 100% rename from tests/test_evaluation/test_metrics/test_hmean_iou.py rename to tests/test_evaluation/test_metrics/test_hmean_iou_metric.py diff --git a/tests/test_models/test_common/layers/test_transformer_layers.py b/tests/test_models/test_common/layers/test_transformer_layers.py new file mode 100644 index 00000000..a495c72b --- /dev/null +++ b/tests/test_models/test_common/layers/test_transformer_layers.py @@ -0,0 +1,37 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmocr.models.common.layers.transformer_layers import (TFDecoderLayer, + TFEncoderLayer) + + +class TestTFEncoderLayer(TestCase): + + def test_forward(self): + encoder_layer = TFEncoderLayer() + in_enc = torch.rand(1, 20, 512) + out_enc = encoder_layer(in_enc) + self.assertEqual(out_enc.shape, torch.Size([1, 20, 512])) + + encoder_layer = TFEncoderLayer( + operation_order=('self_attn', 'norm', 'ffn', 'norm')) + out_enc = encoder_layer(in_enc) + self.assertEqual(out_enc.shape, torch.Size([1, 20, 512])) + + +class TestTFDecoderLayer(TestCase): + + def test_forward(self): + decoder_layer = TFDecoderLayer() + in_dec = torch.rand(1, 30, 512) + out_enc = torch.rand(1, 128, 512) + out_dec = decoder_layer(in_dec, out_enc) + self.assertEqual(out_dec.shape, torch.Size([1, 30, 512])) + + decoder_layer = TFDecoderLayer( + operation_order=('self_attn', 'norm', 'enc_dec_attn', 'norm', + 'ffn', 'norm')) + out_dec = decoder_layer(in_dec, out_enc) + self.assertEqual(out_dec.shape, torch.Size([1, 30, 512])) diff --git a/tests/test_models/test_common/modules/test_transformer_module.py b/tests/test_models/test_common/modules/test_transformer_module.py new file mode 100644 index 00000000..84f9140e --- /dev/null +++ b/tests/test_models/test_common/modules/test_transformer_module.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmocr.models.common.modules import PositionalEncoding + + +class TestPositionalEncoding(TestCase): + + def test_forward(self): + pos_encoder = PositionalEncoding() + x = torch.rand(1, 30, 512) + out = pos_encoder(x) + assert out.size() == x.size() diff --git a/tests/test_models/test_textrecog/layers/test_conv_layer.py b/tests/test_models/test_textrecog/layers/test_conv_layer.py new file mode 100644 index 00000000..bf65d86c --- /dev/null +++ b/tests/test_models/test_textrecog/layers/test_conv_layer.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmocr.models.textrecog.layers.conv_layer import (BasicBlock, Bottleneck, + conv1x1, conv3x3) + + +class TestUtils(TestCase): + + def test_conv3x3(self): + conv = conv3x3(3, 6) + self.assertEqual(conv.in_channels, 3) + self.assertEqual(conv.out_channels, 6) + self.assertEqual(conv.kernel_size, (3, 3)) + + def test_conv1x1(self): + conv = conv1x1(3, 6) + self.assertEqual(conv.in_channels, 3) + self.assertEqual(conv.out_channels, 6) + self.assertEqual(conv.kernel_size, (1, 1)) + + +class TestBasicBlock(TestCase): + + def test_forward(self): + x = torch.rand(1, 64, 224, 224) + basic_block = BasicBlock(64, 64) + self.assertEqual(basic_block.expansion, 1) + out = basic_block(x) + self.assertEqual(out.shape, torch.Size([1, 64, 224, 224])) + + +class TestBottleneck(TestCase): + + def test_forward(self): + x = torch.rand(1, 64, 224, 224) + bottle_neck = Bottleneck(64, 64, downsample=True) + self.assertEqual(bottle_neck.expansion, 4) + out = bottle_neck(x) + self.assertEqual(out.shape, torch.Size([1, 256, 224, 224])) diff --git a/tests/test_models/test_textrecog/test_backbones/test_resnet31_ocr.py b/tests/test_models/test_textrecog/test_backbones/test_resnet31_ocr.py new file mode 100644 index 00000000..a6bc566c --- /dev/null +++ b/tests/test_models/test_textrecog/test_backbones/test_resnet31_ocr.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmocr.models.textrecog.backbones import ResNet31OCR + + +class TestResNet31OCR(TestCase): + + def test_forward(self): + """Test resnet backbone.""" + with self.assertRaises(AssertionError): + ResNet31OCR(2.5) + + with self.assertRaises(AssertionError): + ResNet31OCR(3, layers=5) + + with self.assertRaises(AssertionError): + ResNet31OCR(3, channels=5) + + # Test ResNet18 forward + model = ResNet31OCR() + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 32, 160) + feat = model(imgs) + self.assertEqual(feat.shape, torch.Size([1, 512, 4, 40])) diff --git a/tests/test_models/test_textrecog/test_backbones/test_resnet_abi.py b/tests/test_models/test_textrecog/test_backbones/test_resnet_abi.py new file mode 100644 index 00000000..950e9a8a --- /dev/null +++ b/tests/test_models/test_textrecog/test_backbones/test_resnet_abi.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmocr.models.textrecog.backbones import ResNetABI + + +class TestResNetABI(TestCase): + + def test_forward(self): + """Test resnet backbone.""" + with self.assertRaises(AssertionError): + ResNetABI(2.5) + + with self.assertRaises(AssertionError): + ResNetABI(3, arch_settings=5) + + with self.assertRaises(AssertionError): + ResNetABI(3, stem_channels=None) + + with self.assertRaises(AssertionError): + ResNetABI(arch_settings=[3, 4, 6, 6], strides=[1, 2, 1, 2, 1]) + + # Test forwarding + model = ResNetABI() + model.train() + + imgs = torch.randn(1, 3, 32, 160) + feat = model(imgs) + self.assertEqual(feat.shape, torch.Size([1, 512, 8, 40])) diff --git a/tests/test_models/test_textrecog/test_backbones/test_very_deep_vgg.py b/tests/test_models/test_textrecog/test_backbones/test_very_deep_vgg.py new file mode 100644 index 00000000..7b1d7c73 --- /dev/null +++ b/tests/test_models/test_textrecog/test_backbones/test_very_deep_vgg.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmocr.models.textrecog.backbones import VeryDeepVgg + + +class TestVeryDeepVgg(TestCase): + + def test_forward(self): + + model = VeryDeepVgg() + model.init_weights() + model.train() + + imgs = torch.randn(1, 3, 32, 160) + feats = model(imgs) + self.assertEqual(feats.shape, torch.Size([1, 512, 1, 41]))