mirror of https://github.com/open-mmlab/mmocr.git

[Refactor] Refactor TPS (#1240)

* [Docs] Limit extension versions (#1209)
* loss
* fix
* [update] limit extension versions
* refactor tps
* Update mmocr/models/textrecog/preprocessors/tps_preprocessor.py
* Update mmocr/models/textrecog/preprocessors/tps_preprocessor.py
* refine

Co-authored-by: Tong Gao <gaotongxiao@gmail.com>
parent b8c445b04f
commit a12c215e85
mmocr/models/textrecog/preprocessors/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .tps_preprocessor import STN, TPStransform

__all__ = ['TPStransform', 'STN']
mmocr/models/textrecog/preprocessors/base.py
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.model import BaseModule

from mmocr.registry import MODELS


@MODELS.register_module()
class BasePreprocessor(BaseModule):
    """Base preprocessor class for text recognition."""

    def forward(self, x, **kwargs):
        return x
mmocr/models/textrecog/preprocessors/tps_preprocessor.py
@@ -0,0 +1,272 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmocr.registry import MODELS
from .base import BasePreprocessor

class TPStransform(nn.Module):
    """Implement the TPS transform.

    This was partially adapted from https://github.com/ayumiymk/aster.pytorch

    Args:
        output_image_size (tuple[int, int]): The size of the output image.
            Defaults to (32, 100).
        num_control_points (int): The number of control points. Defaults
            to 20.
        margins (tuple[float, float]): The margins for control points to the
            top and bottom sides of the image. Defaults to (0.05, 0.05).
    """

    def __init__(self,
                 output_image_size: Tuple[int, int] = (32, 100),
                 num_control_points: int = 20,
                 margins: Tuple[float, float] = (0.05, 0.05)) -> None:
        super().__init__()
        self.output_image_size = output_image_size
        self.num_control_points = num_control_points
        self.margins = margins
        self.target_height, self.target_width = output_image_size

        # build output control points
        target_control_points = self._build_output_control_points(
            num_control_points, margins)
        N = num_control_points

        # create padded kernel matrix
        forward_kernel = torch.zeros(N + 3, N + 3)
        target_control_partial_repr = self._compute_partial_repr(
            target_control_points, target_control_points)
        forward_kernel[:N, :N].copy_(target_control_partial_repr)
        forward_kernel[:N, -3].fill_(1)
        forward_kernel[-3, :N].fill_(1)
        forward_kernel[:N, -2:].copy_(target_control_points)
        forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))

        # compute inverse matrix
        inverse_kernel = torch.inverse(forward_kernel)

        # create target coordinate matrix of normalized (x, y) pairs
        HW = self.target_height * self.target_width
        tgt_coord = list(
            itertools.product(
                range(self.target_height), range(self.target_width)))
        tgt_coord = torch.Tensor(tgt_coord)
        Y, X = tgt_coord.split(1, dim=1)
        Y = Y / (self.target_height - 1)
        X = X / (self.target_width - 1)
        tgt_coord = torch.cat([X, Y], dim=1)
        tgt_coord_partial_repr = self._compute_partial_repr(
            tgt_coord, target_control_points)
        tgt_coord_repr = torch.cat(
            [tgt_coord_partial_repr,
             torch.ones(HW, 1), tgt_coord], dim=1)

        # register precomputed matrices as buffers so they move with the
        # module across devices
        self.register_buffer('inverse_kernel', inverse_kernel)
        self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', tgt_coord_repr)
        self.register_buffer('target_control_points', target_control_points)

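For reference, `forward_kernel` above is the standard thin-plate-spline system matrix: with N target control points c_i, it has the block form [[K, 1, C], [1^T, 0, 0], [C^T, 0, 0]], where K[i][j] = U(||c_i - c_j||) is the radial basis computed by `_compute_partial_repr`. A minimal standalone sketch of that construction (four toy control points, illustrative only, not part of the module):

import torch

C = torch.tensor([[0., 0.], [1., 0.], [0., 1.], [1., 1.]])  # toy control points
N = C.size(0)
d2 = ((C[:, None, :] - C[None, :, :]) ** 2).sum(-1)      # squared distances
K = torch.where(d2 > 0, 0.5 * d2 * torch.log(d2), d2)    # U(r) = r^2 log r
kernel = torch.zeros(N + 3, N + 3)
kernel[:N, :N] = K
kernel[:N, -3] = 1
kernel[-3, :N] = 1
kernel[:N, -2:] = C
kernel[-2:, :N] = C.t()
# Solving kernel @ params = [source_points; 0] yields the TPS coefficients;
# the module inverts the kernel once and reuses it for every batch.
params = torch.inverse(kernel) @ torch.cat([C, torch.zeros(3, 2)])
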
    def forward(self, input: torch.Tensor,
                source_control_points: torch.Tensor) -> torch.Tensor:
        """Forward function of the TPS block.

        Args:
            input (Tensor): The input image of shape (N, C, H, W).
            source_control_points (Tensor): The control points of the source
                image of shape (N, self.num_control_points, 2).

        Returns:
            Tensor: The output image after the TPS transform.
        """
        assert source_control_points.ndimension() == 3
        assert source_control_points.size(1) == self.num_control_points
        assert source_control_points.size(2) == 2
        batch_size = source_control_points.size(0)

        # solve for the per-sample TPS coefficients
        Y = torch.cat([
            source_control_points,
            self.padding_matrix.expand(batch_size, 3, 2)
        ], 1)
        mapping_matrix = torch.matmul(self.inverse_kernel, Y)
        source_coordinate = torch.matmul(self.target_coordinate_repr,
                                         mapping_matrix)

        # map the [0, 1] coordinates to grid_sample's [-1, 1] convention
        grid = source_coordinate.view(-1, self.target_height,
                                      self.target_width, 2)
        grid = torch.clamp(grid, 0, 1)
        grid = 2.0 * grid - 1.0
        output_maps = self._grid_sample(input, grid, canvas=None)
        return output_maps

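A note on the clamp-and-rescale step above: `F.grid_sample` consumes (x, y) coordinates in [-1, 1], with -1/+1 mapping onto the border pixels when `align_corners=True`, which is why the solved source coordinates are clamped to [0, 1] and then mapped through `2.0 * grid - 1.0`. A quick standalone check of that convention (toy tensor, not part of the module):

import torch
import torch.nn.functional as F

x = torch.arange(12.).view(1, 1, 3, 4)
ys, xs = torch.meshgrid(
    torch.linspace(-1, 1, 3), torch.linspace(-1, 1, 4), indexing='ij')
grid = torch.stack([xs, ys], dim=-1)[None]  # (1, 3, 4, 2), (x, y) order
# An identity grid reproduces the input exactly under align_corners=True.
assert torch.allclose(F.grid_sample(x, grid, align_corners=True), x)
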
    def _grid_sample(self,
                     input: torch.Tensor,
                     grid: torch.Tensor,
                     canvas: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Sample the input image at the given grid.

        Args:
            input (Tensor): The input image.
            grid (Tensor): The grid to sample the input image with.
            canvas (Optional[Tensor]): The canvas used to fill the regions
                that the sampling grid maps outside the input image.

        Returns:
            Tensor: The sampled image.
        """
        output = F.grid_sample(input, grid, align_corners=True)
        if canvas is None:
            return output
        else:
            # blend with the canvas where the grid leaves the input image
            input_mask = input.data.new(input.size()).fill_(1)
            output_mask = F.grid_sample(input_mask, grid, align_corners=True)
            padded_output = output * output_mask + canvas * (1 - output_mask)
            return padded_output

    def _compute_partial_repr(self, input_points: torch.Tensor,
                              control_points: torch.Tensor) -> torch.Tensor:
        """Compute the partial representation matrix.

        Args:
            input_points (Tensor): The input points.
            control_points (Tensor): The control points.

        Returns:
            Tensor: The partial representation matrix.
        """
        N = input_points.size(0)
        M = control_points.size(0)
        pairwise_diff = input_points.view(N, 1, 2) - control_points.view(
            1, M, 2)
        pairwise_diff_square = pairwise_diff * pairwise_diff
        pairwise_dist = pairwise_diff_square[:, :,
                                             0] + pairwise_diff_square[:, :, 1]
        # U(r) = r^2 log r, written as 0.5 * r^2 * log(r^2) since
        # pairwise_dist holds squared distances
        repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
        # 0 * log(0) produces NaN, and NaN != NaN, so this zeroes them out
        mask = repr_matrix != repr_matrix
        repr_matrix.masked_fill_(mask, 0)
        return repr_matrix

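Two details of `_compute_partial_repr` are easy to miss: `pairwise_dist` already holds squared distances, so `0.5 * r2 * log(r2)` equals the usual TPS basis r^2 log r, and the mask line exploits NaN being the only value unequal to itself to zero the 0 * log(0) entries. A standalone check of the identity (toy values only):

import torch

r2 = torch.tensor([0.0, 0.25, 1.0, 4.0])    # squared distances r^2
u = 0.5 * r2 * torch.log(r2)                # 0.5 r^2 log(r^2) == r^2 log r
u.masked_fill_(u != u, 0)                   # NaN != NaN zeroes the r = 0 entry
expected = torch.where(r2 > 0, r2 * torch.log(r2.sqrt()), r2)
assert torch.allclose(u, expected)
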
    # output_ctrl_pts are specified according to our task.
    def _build_output_control_points(self, num_control_points: int,
                                     margins: Tuple[float, float]
                                     ) -> torch.Tensor:
        """Build the output control points.

        The output points are fixed along the top and bottom sides of the
        image.

        Args:
            num_control_points (int): The number of control points.
            margins (Tuple[float, float]): The margins for control points to
                the top and bottom sides of the image.

        Returns:
            Tensor: The output control points.
        """
        margin_x, margin_y = margins
        num_ctrl_pts_per_side = num_control_points // 2
        ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x,
                                 num_ctrl_pts_per_side)
        ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
        ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom],
                                             axis=0)
        output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
        return output_ctrl_pts

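With the pieces above in place, `TPStransform` can be exercised on its own; random control points stand in for a localization network's predictions (a minimal sketch mirroring the unit test at the bottom of this commit):

import torch

tps = TPStransform(output_image_size=(32, 100), num_control_points=20)
image = torch.rand(2, 3, 32, 64)    # (N, C, H, W)
points = torch.rand(2, 20, 2)       # normalized (x, y) source control points
warped = tps(image, points)
assert warped.shape == (2, 3, 32, 100)
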
@MODELS.register_module()
class STN(BasePreprocessor):
    """Implement the STN module in `ASTER: An Attentional Scene Text
    Recognizer with Flexible Rectification
    <https://ieeexplore.ieee.org/abstract/document/8395027/>`_.

    Args:
        in_channels (int): The number of input channels.
        resized_image_size (tuple[int, int]): The size the input image is
            resized to before the localization network, which gives a better
            rectification result. Defaults to (32, 64).
        output_image_size (tuple[int, int]): The size of the output image for
            TPS. Defaults to (32, 100).
        num_control_points (int): The number of control points. Defaults
            to 20.
        margins (tuple[float, float]): The margins for control points to the
            top and bottom sides of the image for TPS. Defaults to
            (0.05, 0.05).
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 in_channels: int,
                 resized_image_size: Tuple[int, int] = (32, 64),
                 output_image_size: Tuple[int, int] = (32, 100),
                 num_control_points: int = 20,
                 margins: Tuple[float, float] = (0.05, 0.05),
                 init_cfg: Optional[Union[Dict, List[Dict]]] = [
                     dict(type='Xavier', layer='Conv2d'),
                     dict(type='Constant', val=1, layer='BatchNorm2d'),
                 ]):
        super().__init__(init_cfg=init_cfg)
        self.resized_image_size = resized_image_size
        self.num_control_points = num_control_points
        self.tps = TPStransform(output_image_size, num_control_points,
                                margins)

        # localization network: six conv blocks with five 2x downsamplings
        self.stn_convnet = nn.Sequential(
            ConvModule(in_channels, 32, 3, 1, 1, norm_cfg=dict(type='BN')),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvModule(32, 64, 3, 1, 1, norm_cfg=dict(type='BN')),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvModule(64, 128, 3, 1, 1, norm_cfg=dict(type='BN')),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvModule(128, 256, 3, 1, 1, norm_cfg=dict(type='BN')),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvModule(256, 256, 3, 1, 1, norm_cfg=dict(type='BN')),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvModule(256, 256, 3, 1, 1, norm_cfg=dict(type='BN')),
        )

        # regress num_control_points (x, y) pairs from the pooled features
        self.stn_fc1 = nn.Sequential(
            nn.Linear(2 * 256, 512), nn.BatchNorm1d(512),
            nn.ReLU(inplace=True))
        self.stn_fc2 = nn.Linear(512, num_control_points * 2)
        self.init_stn(self.stn_fc2)

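A note on the `2 * 256` input width of `stn_fc1`: with the default `resized_image_size=(32, 64)`, the five stride-2 max-pools divide both spatial dimensions by 32, leaving a 256-channel map of spatial size (1, 2), i.e. 512 features once flattened. The arithmetic, spelled out:

h, w = 32, 64                  # resized_image_size
for _ in range(5):             # five MaxPool2d(kernel_size=2, stride=2)
    h, w = h // 2, w // 2
assert (h, w) == (1, 2)
assert 256 * h * w == 2 * 256  # flattened size feeding stn_fc1
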
    def init_stn(self, stn_fc2: nn.Linear) -> None:
        """Initialize the output linear layer of stn, so that the initial
        source points lie along the top and bottom sides of the image, which
        helps optimization.

        Args:
            stn_fc2 (nn.Linear): The output linear layer of stn.
        """
        margin = 0.01
        sampling_num_per_side = int(self.num_control_points / 2)
        ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side)
        ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
        ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin)
        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
        ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom],
                                     axis=0).astype(np.float32)
        # zero weights plus a fixed bias yield a near-identity initial warp
        stn_fc2.weight.data.zero_()
        stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1)

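Because `init_stn` zeroes the weights and writes the fixed layout into the bias, the first forward pass predicts the same source points for every sample, so the initial TPS warp is close to an identity mapping. A standalone sketch of that property (hypothetical stand-in layer matching the default sizes, not the module itself):

import torch
import torch.nn as nn

fc2 = nn.Linear(512, 40)        # num_control_points * 2 with the default 20
fc2.weight.data.zero_()
fc2.bias.data = torch.rand(40)  # stands in for the fixed control-point layout
out = fc2(0.1 * torch.randn(2, 512))
# With zero weights, every sample predicts exactly the bias layout.
assert torch.allclose(out, fc2.bias.expand(2, 40))
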
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Forward function of STN.

        Args:
            img (Tensor): The input image tensor.

        Returns:
            Tensor: The rectified image tensor.
        """
        # predict control points from a downsampled copy of the image
        resize_img = F.interpolate(
            img, self.resized_image_size, mode='bilinear', align_corners=True)
        points = self.stn_convnet(resize_img)
        batch_size, _, _, _ = points.size()
        points = points.view(batch_size, -1)
        img_feat = self.stn_fc1(points)
        points = self.stn_fc2(0.1 * img_feat)
        points = points.view(-1, self.num_control_points, 2)

        # rectify the full-resolution image with the predicted points
        transformed_image = self.tps(img, points)
        return transformed_image
@@ -0,0 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmocr.models.textrecog.preprocessors import STN, TPStransform


class TestTPS(TestCase):

    def test_tps_transform(self):
        tps = TPStransform(output_image_size=(32, 100), num_control_points=20)
        image = torch.rand(2, 3, 32, 64)
        control_points = torch.rand(2, 20, 2)
        transformed_image = tps(image, control_points)
        self.assertEqual(transformed_image.shape, (2, 3, 32, 100))

    def test_stn(self):
        stn = STN(
            in_channels=3,
            resized_image_size=(32, 64),
            output_image_size=(32, 100),
            num_control_points=20)
        image = torch.rand(2, 3, 64, 256)
        transformed_image = stn(image)
        self.assertEqual(transformed_image.shape, (2, 3, 32, 100))