mirror of https://github.com/open-mmlab/mmocr.git
parent
6992923768
commit
7f4a1eecdc
|
@ -171,6 +171,7 @@ Supported algorithms:
|
|||
<summary>Text Spotting</summary>
|
||||
|
||||
- [x] [ABCNet](projects/ABCNet/README.md) (CVPR'2020)
|
||||
- [x] [ABCNetV2](projects/ABCNet/README_V2.md) (TPAMI'2021)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -171,6 +171,7 @@ pip3 install -e .
|
|||
<summary>端对端 OCR</summary>
|
||||
|
||||
- [x] [ABCNet](projects/ABCNet/README.md) (CVPR'2020)
|
||||
- [x] [ABCNetV2](projects/ABCNet/README_V2.md) (TPAMI'2021)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -44,14 +44,6 @@ ln -s ${DataPath} $PYTHONPATH
|
|||
New-Item -ItemType SymbolicLink -Path $env:PYTHONPATH -Name data -Target ${DataPath}
|
||||
```
|
||||
|
||||
As of now, `BezierAlign` is not yet supported by MMCV, and we will use third-party MMCV with the implementation of `BezierAlign`. You will need to install it from the source code as follows:
|
||||
|
||||
```bash
|
||||
git clone -b lkk/bezier_align https://github.com/Harold-lkk/mmcv.git
|
||||
cd mmcv
|
||||
MMCV_WITH_OPS=1 MAX_JOBS=8 python setup.py develop
|
||||
```
|
||||
|
||||
### Training commands
|
||||
|
||||
In the current directory, run the following command to train the model:
|
||||
|
@ -95,7 +87,6 @@ If you find ABCNet useful in your research or applications, please cite ABCNet w
|
|||
booktitle = {Proc. IEEE Conf. Computer Vision and Pattern Recognition (CVPR)},
|
||||
year = {2020}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
## Checklist
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
# ABCNet: Real-time Scene Text Spotting with Adaptive Bezier-Curve Network
|
||||
|
||||
<div>
|
||||
<a href="https://arxiv.org/abs/2105.03620">[arXiv paper]</a>
|
||||
<a href="https://ieeexplore.ieee.org/document/9525302">[TPAMI paper]</a>
|
||||
</div>
|
||||
|
||||
## Description
|
||||
|
||||
This is an implementation of [ABCNetV2](https://github.com/aim-uofa/AdelaiDet) based on [MMOCR](https://github.com/open-mmlab/mmocr/tree/dev-1.x), [MMCV](https://github.com/open-mmlab/mmcv), and [MMEngine](https://github.com/open-mmlab/mmengine).
|
||||
|
||||
**ABCNetV2** contributions are four-fold: 1) For the first time, we adaptively fit arbitrarily-shaped text by a parameterized Bezier curve, which, compared with segmentation-based methods, can not only provide structured output but also controllable representation. 2) We design a novel BezierAlign layer for extracting accurate convolution features of a text instance of arbitrary shapes, significantly improving the precision of recognition over previous methods. 3) Different from previous methods, which often suffer from complex post-processing and sensitive hyper-parameters, our ABCNet v2 maintains a simple pipeline with the only post-processing non-maximum suppression (NMS). 4) As the performance of text recognition closely depends on feature alignment, ABCNet v2 further adopts a simple yet effective coordinate convolution to encode the position of the convolutional filters, which leads to a considerable improvement with negligible computation overhead. Comprehensive experiments conducted on various bilingual (English and Chinese) benchmark datasets demonstrate that ABCNet v2 can achieve state-of-the-art performance while maintaining very high efficiency.
|
||||
|
||||
<center>
|
||||
<img src="https://user-images.githubusercontent.com/24622904/213096846-2557e0ac-ca18-4c4f-88c1-569107f48f9b.png">
|
||||
</center>
|
||||
|
||||
## Usage
|
||||
|
||||
<!-- For a typical model, this section should contain the commands for training and testing. You are also suggested to dump your environment specification to env.yml by `conda env export > env.yml`. -->
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.7
|
||||
- PyTorch 1.6 or higher
|
||||
- [MIM](https://github.com/open-mmlab/mim)
|
||||
- [MMOCR](https://github.com/open-mmlab/mmocr)
|
||||
|
||||
All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. In `ABCNet/` root directory, run the following line to add the current directory to `PYTHONPATH`:
|
||||
|
||||
```shell
|
||||
# Linux
|
||||
export PYTHONPATH=`pwd`:$PYTHONPATH
|
||||
# Windows PowerShell
|
||||
$env:PYTHONPATH=Get-Location
|
||||
```
|
||||
|
||||
if the data is not in `ABCNet/`, you can link the data into `ABCNet/`:
|
||||
|
||||
```shell
|
||||
# Linux
|
||||
ln -s ${DataPath} $PYTHONPATH
|
||||
# Windows PowerShell
|
||||
New-Item -ItemType SymbolicLink -Path $env:PYTHONPATH -Name data -Target ${DataPath}
|
||||
```
|
||||
|
||||
### Testing commands
|
||||
|
||||
In the current directory, run the following command to test the model:
|
||||
|
||||
```bash
|
||||
mim test mmocr config/abcnet_v2/abcnet-v2_resnet50_bifpn_500e_icdar2015.py --work-dir work_dirs/ --checkpoint ${CHECKPOINT_PATH}
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
Here we provide the baseline version of ABCNet with ResNet50 backbone.
|
||||
|
||||
To find more variants, please visit the [official model zoo](https://github.com/aim-uofa/AdelaiDet/blob/master/configs/BAText/README.md).
|
||||
|
||||
| Name | Pretrained Model | E2E-None-Hmean | det-Hmean | Download |
|
||||
| :-------------------: | :--------------: | :------------: | :-------: | :------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| v2-icdar2015-finetune | SynthText | 0.6628 | 0.8886 | [model](https://download.openmmlab.com/mmocr/textspotting/abcnet-v2/abcnet-v2_resnet50_bifpn/abcnet-v2_resnet50_bifpn_500e_icdar2015-5e4cc7ed.pth) |
|
||||
|
||||
## Citation
|
||||
|
||||
If you find ABCNetV2 useful in your research or applications, please cite ABCNetV2 with the following BibTeX entry.
|
||||
|
||||
```BibTeX
|
||||
@ARTICLE{9525302,
|
||||
author={Liu, Yuliang and Shen, Chunhua and Jin, Lianwen and He, Tong and Chen, Peng and Liu, Chongyu and Chen, Hao},
|
||||
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
|
||||
title={ABCNet v2: Adaptive Bezier-Curve Network for Real-time End-to-end Text Spotting},
|
||||
year={2021},
|
||||
volume={},
|
||||
number={},
|
||||
pages={1-1},
|
||||
doi={10.1109/TPAMI.2021.3107437}}
|
||||
```
|
||||
|
||||
## Checklist
|
||||
|
||||
<!-- Here is a checklist illustrating a usual development workflow of a successful project, and also serves as an overview of this project's progress. The PIC (person in charge) or contributors of this project should check all the items that they believe have been finished, which will further be verified by codebase maintainers via a PR.
|
||||
|
||||
OpenMMLab's maintainer will review the code to ensure the project's quality. Reaching the first milestone means that this project suffices the minimum requirement of being merged into 'projects/'. But this project is only eligible to become a part of the core package upon attaining the last milestone.
|
||||
|
||||
Note that keeping this section up-to-date is crucial not only for this project's developers but the entire community, since there might be some other contributors joining this project and deciding their starting point from this list. It also helps maintainers accurately estimate time and effort on further code polishing, if needed.
|
||||
|
||||
A project does not necessarily have to be finished in a single PR, but it's essential for the project to at least reach the first milestone in its very first PR. -->
|
||||
|
||||
- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
|
||||
|
||||
- [x] Finish the code
|
||||
|
||||
<!-- The code's design shall follow existing interfaces and convention. For example, each model component should be registered into `mmocr.registry.MODELS` and configurable via a config file. -->
|
||||
|
||||
- [x] Basic docstrings & proper citation
|
||||
|
||||
<!-- Each major object should contain a docstring, describing its functionality and arguments. If you have adapted the code from other open-source projects, don't forget to cite the source project in docstring and make sure your behavior is not against its license. Typically, we do not accept any code snippet under GPL license. [A Short Guide to Open Source Licenses](https://medium.com/nationwide-technology/a-short-guide-to-open-source-licenses-cf5b1c329edd) -->
|
||||
|
||||
- [x] Test-time correctness
|
||||
|
||||
<!-- If you are reproducing the result from a paper, make sure your model's inference-time performance matches that in the original paper. The weights usually could be obtained by simply renaming the keys in the official pre-trained weights. This test could be skipped though, if you are able to prove the training-time correctness and check the second milestone. -->
|
||||
|
||||
- [x] A full README
|
||||
|
||||
<!-- As this template does. -->
|
||||
|
||||
- [ ] Milestone 2: Indicates a successful model implementation.
|
||||
|
||||
- [ ] Training-time correctness
|
||||
|
||||
<!-- If you are reproducing the result from a paper, checking this item means that you should have trained your model from scratch based on the original paper's specification and verified that the final result matches the report within a minor error range. -->
|
||||
|
||||
- [ ] Milestone 3: Good to be a part of our core package!
|
||||
|
||||
- [ ] Type hints and docstrings
|
||||
|
||||
<!-- Ideally *all* the methods should have [type hints](https://www.pythontutorial.net/python-basics/python-type-hints/) and [docstrings](https://google.github.io/styleguide/pyguide.html#381-docstrings). [Example](https://github.com/open-mmlab/mmocr/blob/76637a290507f151215d299707c57cea5120976e/mmocr/utils/polygon_utils.py#L80-L96) -->
|
||||
|
||||
- [ ] Unit tests
|
||||
|
||||
<!-- Unit tests for each module are required. [Example](https://github.com/open-mmlab/mmocr/blob/76637a290507f151215d299707c57cea5120976e/tests/test_utils/test_polygon_utils.py#L97-L106) -->
|
||||
|
||||
- [ ] Code polishing
|
||||
|
||||
<!-- Refactor your code according to reviewer's comment. -->
|
||||
|
||||
- [ ] Metafile.yml
|
||||
|
||||
<!-- It will be parsed by MIM and Inferencer. [Example](https://github.com/open-mmlab/mmocr/blob/1.x/configs/textdet/dbnet/metafile.yml) -->
|
||||
|
||||
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
|
||||
|
||||
<!-- In particular, you may have to refactor this README into a standard one. [Example](/configs/textdet/dbnet/README.md) -->
|
||||
|
||||
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
|
|
@ -9,11 +9,13 @@ from .abcnet_rec_backbone import ABCNetRecBackbone
|
|||
from .abcnet_rec_decoder import ABCNetRecDecoder
|
||||
from .abcnet_rec_encoder import ABCNetRecEncoder
|
||||
from .bezier_roi_extractor import BezierRoIExtractor
|
||||
from .bifpn import BiFPN
|
||||
from .coordinate_head import CoordinateHead
|
||||
from .rec_roi_head import RecRoIHead
|
||||
|
||||
__all__ = [
|
||||
'ABCNetDetHead', 'ABCNetDetPostprocessor', 'ABCNetRecBackbone',
|
||||
'ABCNetRecDecoder', 'ABCNetRecEncoder', 'ABCNet', 'ABCNetRec',
|
||||
'BezierRoIExtractor', 'RecRoIHead', 'ABCNetPostprocessor',
|
||||
'ABCNetDetModuleLoss'
|
||||
'ABCNetDetModuleLoss', 'BiFPN', 'CoordinateHead'
|
||||
]
|
||||
|
|
|
@ -46,8 +46,7 @@ class ABCNetRecBackbone(BaseModule):
|
|||
stride=(2, 1),
|
||||
bias='auto',
|
||||
norm_cfg=dict(type='GN', num_groups=32),
|
||||
act_cfg=dict(type='ReLU')),
|
||||
nn.AvgPool2d(kernel_size=(2, 1), stride=1))
|
||||
act_cfg=dict(type='ReLU')), nn.AdaptiveAvgPool2d((1, None)))
|
||||
|
||||
def forward(self, x):
|
||||
return self.convs(x)
|
||||
|
|
|
@ -0,0 +1,242 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
from mmocr.utils import ConfigType, MultiConfig, OptConfigType
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BiFPN(BaseModule):
|
||||
"""illustration of a minimal bifpn unit P7_0 ------------------------->
|
||||
P7_2 -------->
|
||||
|
||||
|-------------| ↑ ↓ |
|
||||
P6_0 ---------> P6_1 ---------> P6_2 -------->
|
||||
|-------------|--------------↑ ↑ ↓ | P5_0
|
||||
---------> P5_1 ---------> P5_2 --------> |-------------|--------------↑
|
||||
↑ ↓ | P4_0 ---------> P4_1 ---------> P4_2
|
||||
--------> |-------------|--------------↑ ↑
|
||||
|--------------↓ | P3_0 -------------------------> P3_2 -------->
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: List[int],
|
||||
out_channels: int,
|
||||
num_outs: int,
|
||||
repeat_times: int = 2,
|
||||
start_level: int = 0,
|
||||
end_level: int = -1,
|
||||
add_extra_convs: bool = False,
|
||||
relu_before_extra_convs: bool = False,
|
||||
no_norm_on_lateral: bool = False,
|
||||
conv_cfg: OptConfigType = None,
|
||||
norm_cfg: OptConfigType = None,
|
||||
act_cfg: OptConfigType = None,
|
||||
laterial_conv1x1: bool = False,
|
||||
upsample_cfg: ConfigType = dict(mode='nearest'),
|
||||
pool_cfg: ConfigType = dict(),
|
||||
init_cfg: MultiConfig = dict(
|
||||
type='Xavier', layer='Conv2d', distribution='uniform')):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert isinstance(in_channels, list)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_ins = len(in_channels)
|
||||
self.num_outs = num_outs
|
||||
self.relu_before_extra_convs = relu_before_extra_convs
|
||||
self.no_norm_on_lateral = no_norm_on_lateral
|
||||
self.upsample_cfg = upsample_cfg.copy()
|
||||
self.repeat_times = repeat_times
|
||||
if end_level == -1 or end_level == self.num_ins - 1:
|
||||
self.backbone_end_level = self.num_ins
|
||||
assert num_outs >= self.num_ins - start_level
|
||||
else:
|
||||
# if end_level is not the last level, no extra level is allowed
|
||||
self.backbone_end_level = end_level + 1
|
||||
assert end_level < self.num_ins
|
||||
assert num_outs == end_level - start_level + 1
|
||||
self.start_level = start_level
|
||||
self.end_level = end_level
|
||||
self.add_extra_convs = add_extra_convs
|
||||
|
||||
self.lateral_convs = nn.ModuleList()
|
||||
self.extra_convs = nn.ModuleList()
|
||||
self.bifpn_convs = nn.ModuleList()
|
||||
for i in range(self.start_level, self.backbone_end_level):
|
||||
if in_channels[i] == out_channels:
|
||||
l_conv = nn.Identity()
|
||||
else:
|
||||
l_conv = ConvModule(
|
||||
in_channels[i],
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
bias=True,
|
||||
act_cfg=act_cfg,
|
||||
inplace=False)
|
||||
self.lateral_convs.append(l_conv)
|
||||
|
||||
for _ in range(repeat_times):
|
||||
self.bifpn_convs.append(
|
||||
BiFPNLayer(
|
||||
channels=out_channels,
|
||||
levels=num_outs,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
pool_cfg=pool_cfg))
|
||||
|
||||
# add extra conv layers (e.g., RetinaNet)
|
||||
extra_levels = num_outs - self.backbone_end_level + self.start_level
|
||||
if add_extra_convs and extra_levels >= 1:
|
||||
for i in range(extra_levels):
|
||||
if i == 0:
|
||||
in_channels = self.in_channels[self.backbone_end_level - 1]
|
||||
else:
|
||||
in_channels = out_channels
|
||||
if in_channels == out_channels:
|
||||
extra_fpn_conv = nn.MaxPool2d(
|
||||
kernel_size=3, stride=2, padding=1)
|
||||
else:
|
||||
extra_fpn_conv = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
||||
self.extra_convs.append(extra_fpn_conv)
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
def extra_convs(inputs, extra_convs):
|
||||
outputs = list()
|
||||
for extra_conv in extra_convs:
|
||||
inputs = extra_conv(inputs)
|
||||
outputs.append(inputs)
|
||||
return outputs
|
||||
|
||||
assert len(inputs) == len(self.in_channels)
|
||||
|
||||
# build laterals
|
||||
laterals = [
|
||||
lateral_conv(inputs[i + self.start_level])
|
||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||
]
|
||||
if self.num_outs > len(laterals) and self.add_extra_convs:
|
||||
extra_source = inputs[self.backbone_end_level - 1]
|
||||
for extra_conv in self.extra_convs:
|
||||
extra_source = extra_conv(extra_source)
|
||||
laterals.append(extra_source)
|
||||
|
||||
for bifpn_module in self.bifpn_convs:
|
||||
laterals = bifpn_module(laterals)
|
||||
outs = laterals
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * x.sigmoid()
|
||||
|
||||
|
||||
class BiFPNLayer(BaseModule):
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
levels,
|
||||
init=0.5,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=None,
|
||||
upsample_cfg=None,
|
||||
pool_cfg=None,
|
||||
eps=0.0001,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.act_cfg = act_cfg
|
||||
self.upsample_cfg = upsample_cfg
|
||||
self.pool_cfg = pool_cfg
|
||||
self.eps = eps
|
||||
self.levels = levels
|
||||
self.bifpn_convs = nn.ModuleList()
|
||||
# weighted
|
||||
self.weight_two_nodes = nn.Parameter(
|
||||
torch.Tensor(2, levels).fill_(init))
|
||||
self.weight_three_nodes = nn.Parameter(
|
||||
torch.Tensor(3, levels - 2).fill_(init))
|
||||
self.relu = nn.ReLU()
|
||||
for _ in range(2):
|
||||
for _ in range(self.levels - 1): # 1,2,3
|
||||
fpn_conv = nn.Sequential(
|
||||
ConvModule(
|
||||
channels,
|
||||
channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
inplace=False))
|
||||
self.bifpn_convs.append(fpn_conv)
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == self.levels
|
||||
# build top-down and down-top path with stack
|
||||
levels = self.levels
|
||||
# w relu
|
||||
w1 = self.relu(self.weight_two_nodes)
|
||||
w1 /= torch.sum(w1, dim=0) + self.eps # normalize
|
||||
w2 = self.relu(self.weight_three_nodes)
|
||||
# w2 /= torch.sum(w2, dim=0) + self.eps # normalize
|
||||
# build top-down
|
||||
idx_bifpn = 0
|
||||
pathtd = inputs
|
||||
inputs_clone = []
|
||||
for in_tensor in inputs:
|
||||
inputs_clone.append(in_tensor.clone())
|
||||
|
||||
for i in range(levels - 1, 0, -1):
|
||||
_, _, h, w = pathtd[i - 1].shape
|
||||
# pathtd[i - 1] = (
|
||||
# w1[0, i - 1] * pathtd[i - 1] + w1[1, i - 1] *
|
||||
# F.interpolate(pathtd[i], size=(h, w), mode='nearest')) / (
|
||||
# w1[0, i - 1] + w1[1, i - 1] + self.eps)
|
||||
pathtd[i -
|
||||
1] = w1[0, i -
|
||||
1] * pathtd[i - 1] + w1[1, i - 1] * F.interpolate(
|
||||
pathtd[i], size=(h, w), mode='nearest')
|
||||
pathtd[i - 1] = swish(pathtd[i - 1])
|
||||
pathtd[i - 1] = self.bifpn_convs[idx_bifpn](pathtd[i - 1])
|
||||
idx_bifpn = idx_bifpn + 1
|
||||
# build down-top
|
||||
for i in range(0, levels - 2, 1):
|
||||
tmp_path = torch.stack([
|
||||
inputs_clone[i + 1], pathtd[i + 1],
|
||||
F.max_pool2d(pathtd[i], kernel_size=3, stride=2, padding=1)
|
||||
],
|
||||
dim=-1)
|
||||
norm_weight = w2[:, i] / (w2[:, i].sum() + self.eps)
|
||||
pathtd[i + 1] = (norm_weight * tmp_path).sum(dim=-1)
|
||||
# pathtd[i + 1] = w2[0, i] * inputs_clone[i + 1]
|
||||
# + w2[1, i] * pathtd[
|
||||
# i + 1] + w2[2, i] * F.max_pool2d(
|
||||
# pathtd[i], kernel_size=3, stride=2, padding=1)
|
||||
pathtd[i + 1] = swish(pathtd[i + 1])
|
||||
pathtd[i + 1] = self.bifpn_convs[idx_bifpn](pathtd[i + 1])
|
||||
idx_bifpn = idx_bifpn + 1
|
||||
|
||||
pathtd[levels - 1] = w1[0, levels - 1] * pathtd[levels - 1] + w1[
|
||||
1, levels - 1] * F.max_pool2d(
|
||||
pathtd[levels - 2], kernel_size=3, stride=2, padding=1)
|
||||
pathtd[levels - 1] = swish(pathtd[levels - 1])
|
||||
pathtd[levels - 1] = self.bifpn_convs[idx_bifpn](pathtd[levels - 1])
|
||||
return pathtd
|
|
@ -0,0 +1,56 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CoordinateHead(BaseModule):
|
||||
|
||||
def __init__(self,
|
||||
in_channel=256,
|
||||
conv_num=4,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
mask_convs = list()
|
||||
for i in range(conv_num):
|
||||
if i == 0:
|
||||
mask_conv = ConvModule(
|
||||
in_channels=in_channel + 2, # 2 for coord
|
||||
out_channels=in_channel,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
else:
|
||||
mask_conv = ConvModule(
|
||||
in_channels=in_channel,
|
||||
out_channels=in_channel,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
mask_convs.append(mask_conv)
|
||||
self.mask_convs = nn.Sequential(*mask_convs)
|
||||
|
||||
def forward(self, features):
|
||||
coord_features = list()
|
||||
for feature in features:
|
||||
x_range = torch.linspace(
|
||||
-1, 1, feature.shape[-1], device=feature.device)
|
||||
y_range = torch.linspace(
|
||||
-1, 1, feature.shape[-2], device=feature.device)
|
||||
y, x = torch.meshgrid(y_range, x_range)
|
||||
y = y.expand([feature.shape[0], 1, -1, -1])
|
||||
x = x.expand([feature.shape[0], 1, -1, -1])
|
||||
coord = torch.cat([x, y], 1)
|
||||
feature_with_coord = torch.cat([feature, coord], dim=1)
|
||||
feature_with_coord = self.mask_convs(feature_with_coord)
|
||||
feature_with_coord = feature_with_coord + feature
|
||||
coord_features.append(feature_with_coord)
|
||||
return coord_features
|
|
@ -15,6 +15,7 @@ class RecRoIHead(BaseRoIHead):
|
|||
"""Simplest base roi head including one bbox head and one mask head."""
|
||||
|
||||
def __init__(self,
|
||||
neck=None,
|
||||
sampler: OptMultiConfig = None,
|
||||
roi_extractor: OptMultiConfig = None,
|
||||
rec_head: OptMultiConfig = None,
|
||||
|
@ -22,6 +23,8 @@ class RecRoIHead(BaseRoIHead):
|
|||
super().__init__(init_cfg)
|
||||
if sampler is not None:
|
||||
self.sampler = TASK_UTILS.build(sampler)
|
||||
if neck is not None:
|
||||
self.neck = MODELS.build(neck)
|
||||
self.roi_extractor = MODELS.build(roi_extractor)
|
||||
self.rec_head = MODELS.build(rec_head)
|
||||
|
||||
|
@ -54,7 +57,8 @@ class RecRoIHead(BaseRoIHead):
|
|||
|
||||
def predict(self, inputs: Tuple[Tensor],
|
||||
data_samples: DetSampleList) -> RecSampleList:
|
||||
|
||||
if hasattr(self, 'neck') and self.neck is not None:
|
||||
inputs = self.neck(inputs)
|
||||
pred_instances = [ds.pred_instances for ds in data_samples]
|
||||
bbox_feats = self.roi_extractor(inputs, pred_instances)
|
||||
if bbox_feats.size(0) == 0:
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
file_client_args = dict(backend='disk')
|
||||
num_classes = 1
|
||||
strides = [8, 16, 32, 64, 128]
|
||||
bbox_coder = dict(type='mmdet.DistancePointBBoxCoder')
|
||||
with_bezier = True
|
||||
norm_on_bbox = True
|
||||
use_sigmoid_cls = True
|
||||
|
||||
dictionary = dict(
|
||||
type='Dictionary',
|
||||
dict_file='{{ fileDirname }}/../../dicts/abcnet.txt',
|
||||
with_start=False,
|
||||
with_end=False,
|
||||
same_start_end=False,
|
||||
with_padding=True,
|
||||
with_unknown=True)
|
||||
|
||||
model = dict(
|
||||
type='ABCNet',
|
||||
data_preprocessor=dict(
|
||||
type='TextDetDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53][::-1],
|
||||
std=[1, 1, 1],
|
||||
bgr_to_rgb=False,
|
||||
pad_size_divisor=32),
|
||||
backbone=dict(
|
||||
type='mmdet.ResNet',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
frozen_stages=1,
|
||||
norm_cfg=dict(type='BN', requires_grad=False),
|
||||
norm_eval=True,
|
||||
style='caffe',
|
||||
init_cfg=dict(
|
||||
type='Pretrained',
|
||||
checkpoint='open-mmlab://detectron2/resnet50_caffe')),
|
||||
neck=dict(
|
||||
type='BiFPN',
|
||||
in_channels=[256, 512, 1024, 2048],
|
||||
out_channels=256,
|
||||
start_level=0,
|
||||
add_extra_convs=True, # use P5
|
||||
norm_cfg=dict(type='BN'),
|
||||
num_outs=6,
|
||||
relu_before_extra_convs=True),
|
||||
det_head=dict(
|
||||
type='ABCNetDetHead',
|
||||
num_classes=num_classes,
|
||||
in_channels=256,
|
||||
stacked_convs=4,
|
||||
feat_channels=256,
|
||||
strides=strides,
|
||||
norm_on_bbox=norm_on_bbox,
|
||||
use_sigmoid_cls=use_sigmoid_cls,
|
||||
centerness_on_reg=True,
|
||||
dcn_on_last_conv=False,
|
||||
conv_bias=True,
|
||||
use_scale=False,
|
||||
with_bezier=with_bezier,
|
||||
init_cfg=dict(
|
||||
type='Normal',
|
||||
layer='Conv2d',
|
||||
std=0.01,
|
||||
override=dict(
|
||||
type='Normal',
|
||||
name='conv_cls',
|
||||
std=0.01,
|
||||
bias=-4.59511985013459), # -log((1-p)/p) where p=0.01
|
||||
),
|
||||
module_loss=None,
|
||||
postprocessor=dict(
|
||||
type='ABCNetDetPostprocessor',
|
||||
# rescale_fields=['polygons', 'bboxes'],
|
||||
use_sigmoid_cls=use_sigmoid_cls,
|
||||
strides=[8, 16, 32, 64, 128],
|
||||
bbox_coder=dict(type='mmdet.DistancePointBBoxCoder'),
|
||||
with_bezier=True,
|
||||
test_cfg=dict(
|
||||
# rescale_fields=['polygon', 'bboxes', 'bezier'],
|
||||
nms_pre=1000,
|
||||
nms=dict(type='nms', iou_threshold=0.4),
|
||||
score_thr=0.3))),
|
||||
roi_head=dict(
|
||||
type='RecRoIHead',
|
||||
neck=dict(type='CoordinateHead'),
|
||||
roi_extractor=dict(
|
||||
type='BezierRoIExtractor',
|
||||
roi_layer=dict(
|
||||
type='BezierAlign', output_size=(16, 64), sampling_ratio=1.0),
|
||||
out_channels=256,
|
||||
featmap_strides=[4, 8, 16]),
|
||||
rec_head=dict(
|
||||
type='ABCNetRec',
|
||||
backbone=dict(type='ABCNetRecBackbone'),
|
||||
encoder=dict(type='ABCNetRecEncoder'),
|
||||
decoder=dict(
|
||||
type='ABCNetRecDecoder',
|
||||
dictionary=dictionary,
|
||||
postprocessor=dict(type='AttentionPostprocessor'),
|
||||
max_seq_len=25))),
|
||||
postprocessor=dict(
|
||||
type='ABCNetPostprocessor',
|
||||
rescale_fields=['polygons', 'bboxes', 'beziers'],
|
||||
))
|
||||
|
||||
test_pipeline = [
|
||||
dict(
|
||||
type='LoadImageFromFile',
|
||||
file_client_args=file_client_args,
|
||||
color_type='color_ignore_orientation'),
|
||||
dict(type='Resize', scale=(2000, 4000), keep_ratio=True, backend='pillow'),
|
||||
dict(
|
||||
type='LoadOCRAnnotations',
|
||||
with_polygon=True,
|
||||
with_bbox=True,
|
||||
with_label=True,
|
||||
with_text=True),
|
||||
dict(
|
||||
type='PackTextDetInputs',
|
||||
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
|
||||
]
|
|
@ -0,0 +1,23 @@
|
|||
_base_ = [
|
||||
'_base_abcnet-v2_resnet50_bifpn.py',
|
||||
'../_base_/datasets/icdar2015.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# dataset settings
|
||||
icdar2015_textspotting_test = _base_.icdar2015_textspotting_test
|
||||
icdar2015_textspotting_test.pipeline = _base_.test_pipeline
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=icdar2015_textspotting_test)
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_cfg = dict(type='ValLoop')
|
||||
test_cfg = dict(type='TestLoop')
|
||||
|
||||
custom_imports = dict(imports=['abcnet'], allow_failed_imports=False)
|
Loading…
Reference in New Issue