documentation and dbnet related code

pull/2/head
quincylin1 2021-04-03 00:41:23 +08:00
parent 3ec7bd4934
commit b031934129
23 changed files with 2258 additions and 0 deletions

20
docs/Makefile 100644

@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

15
docs/api.rst 100644

@ -0,0 +1,15 @@
API Reference
=============
mmocr.apis
-------------
.. automodule:: mmocr.apis
:members:
mmocr.core
-------------
evaluation
^^^^^^^^^^
.. automodule:: mmocr.core.evaluation
:members:


@ -0,0 +1 @@
## Changelog


@ -0,0 +1,76 @@
# Contributor Covenant Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at chenkaidev@gmail.com. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

83
docs/conf.py 100644

@ -0,0 +1,83 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
import os
import subprocess
import sys
sys.path.insert(0, os.path.abspath('..'))
# -- Project information -----------------------------------------------------
project = 'MMOCR'
copyright = '2020-2030, OpenMMLab'
author = 'OpenMMLab'
# The full version, including alpha/beta/rc tags
release = '0.1.0'
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'recommonmark',
'sphinx_markdown_tables',
]
autodoc_mock_imports = ['torch', 'torchvision', 'mmcv', 'mmocr.version']
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = {
'.rst': 'restructuredtext',
'.md': 'markdown',
}
# The master toctree document.
master_doc = 'index'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
master_doc = 'index'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
def builder_inited_handler(app):
subprocess.run(['./merge_docs.sh'])
subprocess.run(['./stats.py'])
def setup(app):
app.connect('builder-inited', builder_inited_handler)


@ -0,0 +1,134 @@
# Contributing to MMOCR
All kinds of contributions are welcome, including but not limited to the following.
- Fixes (typo, bugs)
- New features and components
## Workflow
This document describes the fork & merge request workflow that should be used when contributing to **MMOCR**.
The official public [repository](https://github.com/open-mmlab/mmocr) holds two branches with an infinite lifetime only:
+ master
+ develop
The *master* branch is the main branch where the source code of **HEAD** always reflects a *production-ready state*.
The *develop* branch is the branch where the source code of **HEAD** always reflects a state with the latest development changes for the next release.
Feature branches are used to develop new features for the upcoming or a distant future release.
![](res/git-workflow-master-develop.png)
All new contributors to **MMOCR** should follow these steps:
### Step 1: creating a Fork
1. Fork the repo on GitHub or GitLab to your personal account. Click the `Fork` button on the [project page](https://github.com/open-mmlab/mmocr).
2. Clone your new forked repo to your computer.
```
git clone https://github.com/<your name>/mmocr.git
```
3. Add the official repo as an upstream:
```
git remote add upstream https://github.com/open-mmlab/mmocr.git
```
### Step 2: develop a new feature
#### Step 2.1: keeping your fork up to date
Whenever you want to update your fork with the latest upstream changes, you need to fetch the upstream repo's branches and latest commits to bring them into your repository:
```
# Fetch from upstream remote
git fetch upstream
# Update your master branch
git checkout master
git rebase upstream/master
git push origin master
# Update your develop branch
git checkout develop
git rebase upstream/develop
git push origin develop
```
#### Step 2.2: creating a feature branch
```
git checkout -b <your_new_feature> develop
```
Till now, your fork has three branches as follows:
![](res/git-workflow-feature.png)
#### Step 2.3: develop and test <your_new_feature>
Develop your new feature and test it to make sure it works well.
Please run
```
pre-commit run --all-files
pytest tests
```
and fix all failures before every git commit.
#### Step 2.4: prepare to PR
##### Merge official repo updates to your fork
```
# fetch from upstream remote. i.e., the official repo
git fetch upstream
# update the develop branch of your fork
git checkout develop
git rebase upstream/develop
git push origin develop
# update the <your_new_feature> branch
git checkout <your_new_feature>
git rebase develop
# resolve conflicts if any and test again
```
##### Push the <your_new_feature> branch to your remote forked repo
```
git checkout <your_new_feature>
git push origin <your_new_feature>
```
#### Step 2.5: send PR
Go to the page for your fork on GitHub, select your new feature branch, and click the pull request button to integrate your feature branch into the upstream remote's develop branch.
#### Step 2.6: review code
#### Step 2.7: revise <your_new_feature> (optional)
If the PR is not accepted, please follow Steps 2.1, 2.3, 2.4 and 2.5 until it is accepted.
#### Step 2.8: delete the <your_new_feature> branch once your PR is accepted.
```
git branch -d <your_new_feature>
git push origin :<your_new_feature>
```
## Code style
### Python
We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.
We use the following tools for linting and formatting:
- [flake8](http://flake8.pycqa.org/en/latest/): linter
- [yapf](https://github.com/google/yapf): formatter
- [isort](https://github.com/timothycrosley/isort): sort imports
>Before you create a PR, make sure that your code lints and is formatted by yapf.
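For example, the checks can be run locally as follows (an illustrative invocation; the pre-commit hooks configured in the repository remain the single source of truth for which checks actually run):
```
flake8 mmocr tests
yapf -r -i mmocr tests
pre-commit run --all-files
```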
### C++ and CUDA
We follow the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).

182
docs/datasets.md 100644

@ -0,0 +1,182 @@
# Datasets Preparation
This page lists the datasets which are commonly used in text detection, text recognition and key information extraction, and their download links.
## Text Detection
**The structure of the text detection dataset directory is organized as follows.**
```
├── ctw1500
│   ├── imgs
│   ├── instances_test.json
│   └── instances_training.json
├── icdar2015
│   ├── imgs
│   ├── instances_test.json
│   └── instances_training.json
├── icdar2017
│   ├── imgs
│   ├── instances_training.json
│   └── instances_val.json
├── synthtext
│   ├── imgs
│   ├── instances_training.json
│   ├── instances_training.txt
│   └── instances_training.lmdb
```
| Dataset | | Images | | | Annotation Files | | | Note | |
|:---------:|:-:|:--------------------------:|:-:|:--------------------------------------------:|:---------------------------------------:|:----------------------------------------:|:-:|:----:|---|
| | | | | training | validation | testing | | | |
| CTW1500 | | [link](https://github.com/Yuliang-Liu/Curve-Text-Detector) | | [instances_training.json](https://download.openmmlab.com/mmocr/data/ctw1500/instances_training.json) | - | [instances_test.json](https://download.openmmlab.com/mmocr/data/ctw1500/instances_test.json) | | | |
| ICDAR2015 | | [link](https://rrc.cvc.uab.es/?ch=4&com=downloads) | | [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_training.json) | - | [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_test.json) | | | |
| ICDAR2017 | | [link](https://rrc.cvc.uab.es/?ch=8&com=downloads) | | [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2017/instances_training.json) | [instances_val.json](https://openmmlab) | [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2017/instances_test.json) | | | |
| Synthtext | | [link](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | | [instances_training.json](https://download.openmmlab.com/mmocr/data/synthtext/instances_training.json) [instances_training.txt](https://download.openmmlab.com/mmocr/data/synthtext/instances_training.txt)|-| | | |
- For `icdar2015`:
- Step1: Download `ch4_training_images.zip` and `ch4_test_images.zip` from this [link](https://rrc.cvc.uab.es/?ch=4&com=downloads)
- Step2: Download [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_training.json) and [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_test.json)
- Step3:
```bash
mkdir icdar2015 && cd icdar2015
mv /path/to/instances_training.json .
mv /path/to/instances_test.json .
mkdir imgs && cd imgs
ln -s /path/to/ch4_training_images training
ln -s /path/to/ch4_test_images test
```
## Text Recognition
**The structure of the text recognition dataset directory is organized as follows.**
```
├── mixture
│   ├── coco_text
│ │ ├── train_label.txt
│ │ ├── train_words
│   ├── icdar_2011
│ │ ├── training_label.txt
│ │ ├── Challenge1_Training_Task3_Images_GT
│   ├── icdar_2013
│ │ ├── train_label.txt
│ │ ├── test_label_1015.txt
│ │ ├── test_label_1095.txt
│ │ ├── Challenge2_Training_Task3_Images_GT
│ │ ├── Challenge2_Test_Task3_Images
│   ├── icdar_2015
│ │ ├── train_label.txt
│ │ ├── test_label.txt
│ │ ├── ch4_training_word_images_gt
│ │ ├── ch4_test_word_images_gt
│   ├── III5K
│ │ ├── train_label.txt
│ │ ├── test_label.txt
│ │ ├── train
│ │ ├── test
│   ├── ct80
│ │ ├── test_label.txt
│ │ ├── image
│   ├── svt
│ │ ├── test_label.txt
│ │ ├── image
│   ├── svtp
│ │ ├── test_label.txt
│ │ ├── image
│   ├── Synth90k
│ │ ├── shuffle_labels.txt
│ │ ├── mnt
│   ├── SynthText
│ │ ├── shuffle_labels.txt
│ │ ├── instances_train.txt
│ │ ├── synthtext
│   ├── SynthAdd
│ │ ├── label.txt
│ │ ├── SynthText_Add
```
| Dataset | | images | annotation file | annotation file | Note |
|:----------:|:-:|:---------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------------------:|:----:|
|| | |training | test | |
| coco_text ||[link](https://rrc.cvc.uab.es/?ch=5&com=downloads) |[train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/coco_text/train_label.txt) |- | |
| icdar_2011 ||[link](http://www.cvc.uab.es/icdar2011competition/?com=downloads) |[train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/train_label.txt) |- | |
| icdar_2013 | | [link](https://rrc.cvc.uab.es/?ch=2&com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/train_label.txt) | [test_label_1015.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/test_label_1015.txt) | |
| icdar_2015 | | [link](https://rrc.cvc.uab.es/?ch=4&com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/train_label.txt) | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/test_label.txt) | |
| IIIT5K | | [link](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/train_label.txt) | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/test_label.txt) | |
| ct80 | | - |-|[test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/ct80/test_label.txt)||
| svt | | [link](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svt/test_label.txt) | |
| svtp | | - | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svtp/test_label.txt) | |
| Synth90k | | [link](https://www.robots.ox.ac.uk/~vgg/data/text/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/shuffle_labels.txt) | - | |
| SynthText | | [link](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) &#124; [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) | - | |
| SynthAdd | | [link](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/SynthText_Add.zip) | [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt)|- | |
- For `icdar_2013`:
- Step1: Download `Challenge2_Test_Task3_Images.zip` and `Challenge2_Training_Task3_Images_GT.zip` from this [link](https://rrc.cvc.uab.es/?ch=2&com=downloads)
- Step2: Download [test_label_1015.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/test_label_1015.txt) and [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/train_label.txt)
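- Step3 (a suggested arrangement; paths are placeholders, and it simply mirrors the `icdar_2013` directory structure listed above):
```bash
mkdir icdar_2013 && cd icdar_2013
mv /path/to/train_label.txt .
mv /path/to/test_label_1015.txt .
unzip /path/to/Challenge2_Training_Task3_Images_GT.zip -d Challenge2_Training_Task3_Images_GT
unzip /path/to/Challenge2_Test_Task3_Images.zip -d Challenge2_Test_Task3_Images
# create soft link
cd /path/to/mmocr/data/mixture
ln -s /path/to/icdar_2013 icdar_2013
```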
- For `icdar_2015`:
- Step1: Download `ch4_training_word_images_gt.zip` and `ch4_test_word_images_gt.zip` from this [link](https://rrc.cvc.uab.es/?ch=4&com=downloads)
- Step2: Download [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/train_label.txt) and [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/test_label.txt)
- For `IIIT5K`:
- Step1: Download `IIIT5K-Word_V3.0.tar.gz` from this [link](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html)
- Step2: Download [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/train_label.txt) and [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/test_label.txt)
- For `svt`:
- Step1: Download `svt.zip` from this [link](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset)
- Step2: Download [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svt/test_label.txt)
- For `ct80`:
- Step1: Download [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/ct80/test_label.txt)
- For `svtp`:
- Step1: Download [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svtp/test_label.txt)
- For `coco_text`:
- Step1: Download from this [link](https://rrc.cvc.uab.es/?ch=5&com=downloads)
- Step2: Download [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/coco_text/train_label.txt)
- For `Syn90k`:
- Step1: Download `mjsynth.tar.gz` from this [link](https://www.robots.ox.ac.uk/~vgg/data/text/)
- Step2: Download [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/shuffle_labels.txt)
- Step3:
```bash
mkdir Syn90k && cd Syn90k
mv /path/to/mjsynth.tar.gz .
tar -xzf mjsynth.tar.gz
mv /path/to/shuffle_labels.txt .
# create soft link
cd /path/to/mmocr/data/mixture
ln -s /path/to/Syn90k Syn90k
```
- For `SynthText`:
- Step1: Download `SynthText.zip` from this [link](https://www.robots.ox.ac.uk/~vgg/data/scenetext/)
- Step2: Download [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt)
- Step3: Download [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt)
- Step4:
```bash
unzip SynthText.zip
cd SynthText
mv /path/to/shuffle_labels.txt .
# create soft link
cd /path/to/mmocr/data/mixture
ln -s /path/to/SynthText SynthText
```
- For `SynthAdd`:
- Step1: Download `SynthText_Add.zip` from this [link](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/SynthText_Add.zip)
- Step2: Download [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt)
- Step3:
```bash
mkdir SynthAdd && cd SynthAdd
mv /path/to/SynthText_Add.zip .
unzip SynthText_Add.zip
mv /path/to/label.txt .
# create soft link
cd /path/to/mmocr/data/mixture
ln -s /path/to/SynthAdd SynthAdd
```


@ -0,0 +1,319 @@
# Getting Started
This page provides basic tutorials on the usage of MMOCR.
For the installation instructions, please see [INSTALL.md](INSTALL.md).
## Inference with Pretrained Models
We provide testing scripts to evaluate a full dataset, as well as some task-specific image demos.
### Test a Single Image
You can use the following command to test a single image with one GPU.
```shell
python demo/image_demo.py ${TEST_IMG} ${CONFIG_FILE} ${CHECKPOINT_FILE} ${SAVE_PATH} [--imshow] [--device ${GPU_ID}]
```
If `--imshow` is specified, the demo will also show the image with OpenCV. For example:
```shell
python demo/image_demo.py demo/demo_text_det.jpg configs/xxx.py xxx.pth demo/demo_text_det_pred.jpg
```
The predicted result will be saved as `demo/demo_text_det_pred.jpg`.
### Test Multiple Images
```shell
# for text detection
sh tools/test_imgs.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${IMG_ROOT_PATH} ${IMG_LIST} ${RESULTS_DIR}
# for text recognition
sh tools/ocr_test_imgs.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${IMG_ROOT_PATH} ${IMG_LIST} ${RESULTS_DIR}
```
It will save both the prediction results and visualized images to `${RESULTS_DIR}`.
### Test a Dataset
MMOCR implements **distributed** testing with `MMDistributedDataParallel`. (Please refer to [datasets.md](datasets.md) to prepare your datasets)
#### Test with Single/Multiple GPUs
You can use the following command to test a dataset with single/multiple GPUs.
```shell
./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--eval ${EVAL_METRIC}]
```
For example,
```shell
./tools/dist_test.sh configs/example_config.py work_dirs/example_exp/example_model_20200202.pth 1 --eval hmean-iou
```
##### Optional Arguments
- `--eval`: Specify the evaluation metric. For text detection, the metric should be either 'hmean-ic13' or 'hmean-iou'. For text recognition, the metric should be 'acc'.
#### Test with Slurm
If you run MMOCR on a cluster managed with [Slurm](https://slurm.schedmd.com/), you can use the script `slurm_test.sh`.
```shell
[GPUS=${GPUS}] ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${CHECKPOINT_FILE} [--eval ${EVAL_METRIC}]
```
Here is an example of using 8 GPUs to test an example model on the 'dev' partition with job name 'test_job'.
```shell
GPUS=8 ./tools/slurm_test.sh dev test_job configs/example_config.py work_dirs/example_exp/example_model_20200202.pth --eval hmean-iou
```
You can check [slurm_test.sh](https://github.com/open-mmlab/mmocr/blob/master/tools/slurm_test.sh) for full arguments and environment variables.
##### Optional Arguments
- `--eval`: Specify the evaluation metric. For text detection, the metric should be either 'hmean-ic13' or 'hmean-iou'. For text recognition, the metric should be 'acc'.
## Train a Model
MMOCR implements **distributed** training with `MMDistributedDataParallel`. (Please refer to [datasets.md](datasets.md) to prepare your datasets)
All outputs (log files and checkpoints) will be saved to a working directory specified by `work_dir` in the config file.
By default, we evaluate the model on the validation set after several iterations. You can change the evaluation interval by adding the interval argument in the training config as follows:
```python
evaluation = dict(interval=1, by_epoch=True) # This evaluates the model per epoch.
```
### Train with Single/Multiple GPUs
```shell
./tools/dist_train.sh ${CONFIG_FILE} ${WORK_DIR} ${GPU_NUM} [optional arguments]
```
Optional Arguments:
- `--no-validate` (**not suggested**): By default, the codebase will perform evaluation at every k-th iteration during training. To disable this behavior, use `--no-validate`.
#### Train with a Toy Dataset
We provide a toy dataset under `tests/data`, so you can train a toy model directly before the full academic datasets are prepared.
For example, to train a text recognition model with the `seg` method on the toy dataset:
```
./tools/dist_train.sh configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py work_dirs/seg 1
```
And to train a text recognition model with the `sar` method on the toy dataset:
```
./tools/dist_train.sh configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py work_dirs/sar 1
```
### Train with Slurm
If you run MMOCR on a cluster managed with [Slurm](https://slurm.schedmd.com/), you can use the script `slurm_train.sh`.
```shell
[GPUS=${GPUS}] ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR}
```
Here is an example of using 8 GPUs to train a text detection model on the dev partition.
```shell
GPUS=8 ./tools/slurm_train.sh dev psenet-ic15 configs/textdet/psenet/psenet_r50_fpnf_sbn_1x_icdar2015.py /nfs/xxxx/psenet-ic15
```
You can check [slurm_train.sh](https://github.com/open-mmlab/mmocr/blob/master/tools/slurm_train.sh) for full arguments and environment variables.
### Launch Multiple Jobs on a Single Machine
If you launch multiple jobs on a single machine, e.g., 2 jobs of 4-GPU training on a machine with 8 GPUs,
you need to specify different ports (29500 by default) for each job to avoid communication conflicts.
If you use `dist_train.sh` to launch training jobs, you can set the ports in the command shell.
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4
CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4
```
If you launch training jobs with Slurm, you need to modify the config files to set different communication ports.
In `config1.py`,
```python
dist_params = dict(backend='nccl', port=29500)
```
In `config2.py`,
```python
dist_params = dict(backend='nccl', port=29501)
```
Then you can launch two jobs with `config1.py` and `config2.py`.
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR}
CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR}
```
## Useful Tools
We provide numerous useful tools under the `mmocr/tools` directory.
### Publish a Model
Before you upload a model to AWS, you may want to
(1) convert the model weights to CPU tensors, (2) delete the optimizer states and
(3) compute the hash of the checkpoint file and append the hash id to the filename.
```shell
python tools/publish_model.py ${INPUT_FILENAME} ${OUTPUT_FILENAME}
```
E.g.,
```shell
python tools/publish_model.py work_dirs/psenet/latest.pth psenet_r50_fpnf_sbn_1x_20190801.pth
```
The final output filename will be `psenet_r50_fpnf_sbn_1x_20190801-{hash id}.pth`.
## Customized Settings
### Flexible Dataset
To support the tasks of `text detection`, `text recognition` and `key information extraction`, we have designed a new type of dataset which consists of `loader` and `parser` to load and parse different types of annotation files.
- **loader**: Load the annotation file. There are two types of loader, `HardDiskLoader` and `LmdbLoader`
- `HardDiskLoader`: Load `txt` format annotation file from hard disk to memory.
- `LmdbLoader`: Load an `lmdb`-format annotation file with the lmdb backend, which is very useful for **extremely large** annotation files: it avoids the out-of-memory problem when ten or more GPUs are used, since each GPU starts multiple processes that would otherwise load the whole annotation file into memory.
- **parser**: Parse the annotation file line by line and return the result in `dict` format. There are two types of parser, `LineStrParser` and `LineJsonParser`.
- `LineStrParser`: Parse one line of the annotation file as a string and split it into several parts by a `separator`. It can be used for tasks with simple annotation files, such as text recognition, where each line of the annotation file contains only the `filename` and `label` attributes.
- `LineJsonParser`: Parse one line of the annotation file as a JSON string and use `json.loads` to convert it to a `dict`. It can be used for tasks with complex annotation files, such as text detection, where each line of the annotation file contains multiple attributes (e.g. `filename`, `height`, `width`, `box`, `segmentation`, `iscrowd`, `category_id`, etc.).
Here we show some examples of using different combinations of `loader` and `parser`.
#### Encoder-Decoder-Based Text Recognition Task
```python
dataset_type = 'OCRDataset'
img_prefix = 'tests/data/ocr_toy_dataset/imgs'
train_anno_file = 'tests/data/ocr_toy_dataset/label.txt'
train = dict(
type=dataset_type,
img_prefix=img_prefix,
ann_file=train_anno_file,
loader=dict(
type='HardDiskLoader',
repeat=10,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=train_pipeline,
test_mode=False)
```
You can check the content of the annotation file in `tests/data/ocr_toy_dataset/label.txt`.
The combination of `HardDiskLoader` and `LineStrParser` will return a dict for each file by calling `__getitem__`: `{'filename': '1223731.jpg', 'text': 'GRAND'}`.
##### Optional Arguments:
- `repeat`: The number of times each line in the annotation file is repeated. For example, if the annotation file has `10` lines, setting `repeat=10` yields an effective dataset of size `100`.
If the annotation file is extremely large, you can convert it from txt format to lmdb format with the following command:
```shell
python tools/data_converter/txt2lmdb.py -i ann_file.txt -o ann_file.lmdb
```
After that, you can use `LmdbLoader` in your dataset config as below.
```python
img_prefix = 'tests/data/ocr_toy_dataset/imgs'
train_anno_file = 'tests/data/ocr_toy_dataset/label.lmdb'
train = dict(
type=dataset_type,
img_prefix=img_prefix,
ann_file=train_anno_file,
loader=dict(
type='LmdbLoader',
repeat=10,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=train_pipeline,
test_mode=False)
```
#### Segmentation-Based Text Recognition Task
```python
prefix = 'tests/data/ocr_char_ann_toy_dataset/'
train = dict(
type='OCRSegDataset',
img_prefix=prefix + 'imgs',
ann_file=prefix + 'instances_train.txt',
loader=dict(
type='HardDiskLoader',
repeat=10,
parser=dict(
type='LineJsonParser',
keys=['file_name', 'annotations', 'text'])),
pipeline=train_pipeline,
test_mode=True)
```
You can check the content of the annotation file in `tests/data/ocr_char_ann_toy_dataset/instances_train.txt`.
The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for each file by calling `__getitem__` each time:
```python
{"file_name": "resort_88_101_1.png", "annotations": [{"char_text": "F", "char_box": [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]}, {"char_text": "r", "char_box": [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0]}, {"char_text": "o", "char_box": [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0]}, {"char_text": "m", "char_box": [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0]}, {"char_text": ":", "char_box": [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0]}], "text": "From:"}
```
#### Text Detection Task
```python
dataset_type = 'TextDetDataset'
img_prefix = 'tests/data/toy_dataset/imgs'
test_anno_file = 'tests/data/toy_dataset/instances_test.txt'
test = dict(
type=dataset_type,
img_prefix=img_prefix,
ann_file=test_anno_file,
loader=dict(
type='HardDiskLoader',
repeat=4,
parser=dict(
type='LineJsonParser',
keys=['file_name', 'height', 'width', 'annotations'])),
pipeline=test_pipeline,
test_mode=True)
```
The results are generated in the same way as the segmentation-based text recognition task above.
You can check the content of the annotation file in `tests/data/toy_dataset/instances_test.txt`.
The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for each file by calling `__getitem__`:
```python
{"file_name": "test/img_10.jpg", "height": 720, "width": 1280, "annotations": [{"iscrowd": 1, "category_id": 1, "bbox": [260.0, 138.0, 24.0, 20.0], "segmentation": [[261, 138, 284, 140, 279, 158, 260, 158]]}, {"iscrowd": 0, "category_id": 1, "bbox": [288.0, 138.0, 129.0, 23.0], "segmentation": [[288, 138, 417, 140, 416, 161, 290, 157]]}, {"iscrowd": 0, "category_id": 1, "bbox": [743.0, 145.0, 37.0, 18.0], "segmentation": [[743, 145, 779, 146, 780, 163, 746, 163]]}, {"iscrowd": 0, "category_id": 1, "bbox": [783.0, 129.0, 50.0, 26.0], "segmentation": [[783, 129, 831, 132, 833, 155, 785, 153]]}, {"iscrowd": 1, "category_id": 1, "bbox": [831.0, 133.0, 43.0, 23.0], "segmentation": [[831, 133, 870, 135, 874, 156, 835, 155]]}, {"iscrowd": 1, "category_id": 1, "bbox": [159.0, 204.0, 72.0, 15.0], "segmentation": [[159, 205, 230, 204, 231, 218, 159, 219]]}, {"iscrowd": 1, "category_id": 1, "bbox": [785.0, 158.0, 75.0, 21.0], "segmentation": [[785, 158, 856, 158, 860, 178, 787, 179]]}, {"iscrowd": 1, "category_id": 1, "bbox": [1011.0, 157.0, 68.0, 16.0], "segmentation": [[1011, 157, 1079, 160, 1076, 173, 1011, 170]]}]}
```
### COCO-like Dataset
For text detection, you can also use an annotation file in a COCO format that is defined in [mmdet](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/coco.py):
```python
dataset_type = 'IcdarDataset'
prefix = 'tests/data/toy_dataset/'
test=dict(
type=dataset_type,
ann_file=prefix + 'instances_test.json',
img_prefix=prefix + 'imgs',
pipeline=test_pipeline)
```
You can check the content of the annotation file in `tests/data/toy_dataset/instances_test.json`
- The icdar2015/2017 annotations have to be converted into the COCO format using `tools/data_converter/icdar_converter.py`:
```shell
python tools/data_converter/icdar_converter.py ${src_root_path} -o ${out_path} -d ${data_type} --split-list training validation test
```
- The ctw1500 annotations have to be converted into the COCO format using `tools/data_converter/ctw1500_converter.py`:
```shell
python tools/data_converter/ctw1500_converter.py ${src_root_path} -o ${out_path} --split-list training test
```

38
docs/index.rst 100644

@ -0,0 +1,38 @@
Welcome to MMOCR's documentation!
=======================================
.. toctree::
:maxdepth: 2
:caption: Get Started
install.md
getting_started.md
technical_details.md
contributing.md
.. toctree::
:maxdepth: 2
:caption: Model Zoo
modelzoo.md
textdet_models.md
textrecog_models.md
kie_models.md
.. toctree::
:maxdepth: 2
:caption: Notes
changelog.md
faq.md
.. toctree::
:caption: API Reference
api.rst
Indices and tables
==================
* :ref:`genindex`
* :ref:`search`

231
docs/install.md 100644

@ -0,0 +1,231 @@
# Installation
## Prerequisites
- Linux (Windows is not officially supported)
- Python 3.7
- PyTorch 1.5
- torchvision 0.6.0
- CUDA 10.1
- NCCL 2
- GCC 5.4.0 or higher
- [mmcv](https://github.com/open-mmlab/mmcv) 1.2.6
We have tested the following versions of OS and software:
- OS: Ubuntu 16.04
- CUDA: 10.1
- GCC(G++): 5.4.0
- mmcv 1.2.6
- PyTorch 1.5
- torchvision 0.6.0
MMOCR depends on PyTorch and mmdetection v2.9.0.
## Step-by-Step Installation Instructions
a. Create a conda virtual environment and activate it.
```shell
conda create -n open-mmlab python=3.7 -y
conda activate open-mmlab
```
b. Install PyTorch and torchvision following the [official instructions](https://pytorch.org/), e.g.,
```shell
conda install pytorch==1.5.0 torchvision==0.6.0 cudatoolkit=10.1 -c pytorch
```
Note: Make sure that your compilation CUDA version and runtime CUDA version match.
You can check the supported CUDA version for precompiled packages on the [PyTorch website](https://pytorch.org/).
`E.g. 1` If you have CUDA 10.1 installed under `/usr/local/cuda` and would like to install
PyTorch 1.5, you need to install the prebuilt PyTorch with CUDA 10.1.
```shell
conda install pytorch cudatoolkit=10.1 torchvision -c pytorch
```
`E.g. 2` If you have CUDA 9.2 installed under `/usr/local/cuda` and would like to install
PyTorch 1.3.1, you need to install the prebuilt PyTorch with CUDA 9.2.
```shell
conda install pytorch=1.3.1 cudatoolkit=9.2 torchvision=0.4.2 -c pytorch
```
If you build PyTorch from source instead of installing the prebuilt package,
you can use more CUDA versions such as 9.0.
c. Create a folder called `code` and clone the mmcv repository into it.
```shell
mkdir code
cd code
git clone https://github.com/open-mmlab/mmcv.git
cd mmcv
git checkout -b v1.2.6 v1.2.6
pip install -r requirements.txt
MMCV_WITH_OPS=1 pip install -v -e .
```
d. Clone the mmdetection repository into `code`. The mmdetection repo is separate from the mmcv repo in `code`.
```shell
cd ..
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
git checkout -b v2.9.0 v2.9.0
pip install -r requirements.txt
pip install -v -e .
export PYTHONPATH=$(pwd):$PYTHONPATH
```
Note that we have tested mmdetection v2.9.0 only. Other versions might be incompatible.
e. Clone the mmocr repository into `code`. The mmocr repo is separate from the mmcv and mmdetection repos in `code`.
```shell
cd ..
git clone git@gitlab.sz.sensetime.com:kuangzhh/mmocr.git
cd mmocr
```
f. Install build requirements and then install MMOCR.
```shell
pip install -r requirements.txt
pip install -v -e . # or "python setup.py build_ext --inplace"
export PYTHONPATH=$(pwd):$PYTHONPATH
```
## Full Set-up Script
Here is the full script for setting up mmocr with conda.
```shell
conda create -n open-mmlab python=3.7 -y
conda activate open-mmlab
# install PyTorch 1.5 and torchvision 0.6 prebuilt with CUDA 10.1
conda install pytorch==1.5.0 torchvision==0.6.0 cudatoolkit=10.1 -c pytorch
# install mmcv
mkdir code
cd code
git clone https://github.com/open-mmlab/mmcv.git
cd mmcv # code/mmcv
git checkout -b v1.2.6 v1.2.6
pip install -r requirements.txt
MMCV_WITH_OPS=1 pip install -v -e .
# install mmdetection
cd .. # exit to code
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection # code/mmdetection
git checkout -b v2.9.0 v2.9.0
pip install -r requirements.txt
pip install -v -e .
export PYTHONPATH=$(pwd):$PYTHONPATH
# install mmocr
cd ..
git clone git@gitlab.sz.sensetime.com:kuangzhh/mmocr.git
cd mmocr # code/mmocr
pip install -r requirements.txt
pip install -v -e . # or "python setup.py build_ext --inplace"
export PYTHONPATH=$(pwd):$PYTHONPATH
```
## Another option: Docker Image
We provide a [Dockerfile](https://github.com/open-mmlab/mmocr/blob/master/docker/Dockerfile) to build an image.
```shell
# build an image with PyTorch 1.5, CUDA 10.1
docker build -t mmocr docker/
```
Run it with
```shell
docker run --gpus all --shm-size=8g -it -v {DATA_DIR}:/mmocr/data mmocr
```
## Prepare Datasets
It is recommended to symlink the dataset root to `mmocr/data`. Please refer to [datasets.md](datasets.md) to prepare your datasets.
If your folder structure is different, you may need to change the corresponding paths in config files.
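For example, assuming your prepared datasets live under `/path/to/datasets` (a placeholder path), the symlinks could be created as follows:
```shell
mkdir -p mmocr/data
ln -s /path/to/datasets/icdar2015 mmocr/data/icdar2015
ln -s /path/to/datasets/icdar2017 mmocr/data/icdar2017
ln -s /path/to/datasets/synthtext mmocr/data/synthtext
```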
The `mmocr` folder is organized as follows:
```
mmocr
├── configs
│ ├── _base_
│ ├── textdet
│ └── textrecog
├── data
│ ├── icdar2015
│ ├── icdar2017
│ └── synthtext
├── demo
│ ├── demo_text_det.jpg
│ ├── demo_text_recog.jpg
│ ├── image_demo.py
│ └── webcam_demo.py
├── docs
│ ├── CHANGELOG.md
│ ├── CODE_OF_CONDUCT.md
│ ├── conf.py
│ ├── CONTRIBUTING.md
│ ├── GETTING_STARTED.md
│ ├── index.rst
│ ├── INSTALL.md
│ ├── make.bat
│ ├── Makefile
│ ├── MODEL_ZOO.md
│ ├── requirements.txt
│ ├── res
│ └── TECHNICAL_DETAILS.md
├── mmocr
│ ├── core
│ ├── datasets
│ ├── __init__.py
│ ├── models
│ ├── utils
│ └── version.py
├── README.md
├── requirements
│ ├── build.txt
│ ├── optional.txt
│ ├── runtime.txt
│ └── tests.txt
├── requirements.txt
├── resources
│ ├── illustration.jpg
│ └── mmocr-logo.jpg
├── setup.cfg
├── setup.py
├── tests
│ ├── data
│ ├── test_dataset
│ ├── test_metrics
│ ├── test_models
│ ├── test_tools
│ └── test_utils
└── tools
├── data_converter
├── dist_train.sh
├── dist_test.sh
├── ocr_test_imgs.py
├── ocr_test_imgs.sh
├── publish_model.py
├── slurm_test.sh
├── slurm_train.sh
├── test_imgs.py
├── test_imgs.sh
├── test.py
└── train.py
```
The icdar2017 official annotations can be converted into the COCO format that mmocr supports using `code/mmocr/tools/data_converter/icdar_converter.py`.
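For reference, the invocation mirrors the converter usage shown in the getting-started guide (placeholders kept as-is); icdar2017 ships training and validation splits:
```shell
python tools/data_converter/icdar_converter.py ${src_root_path} -o ${out_path} -d ${data_type} --split-list training validation
```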

36
docs/make.bat 100644

@ -0,0 +1,36 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

10
docs/merge_docs.sh 100755

@ -0,0 +1,10 @@
#!/usr/bin/env bash
sed -i '$a\\n' ../configs/kie/*/*.md
sed -i '$a\\n' ../configs/textdet/*/*.md
sed -i '$a\\n' ../configs/textrecog/*/*.md
# gather models
cat ../configs/kie/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Kie Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmediting/tree/master/=g' >kie_models.md
cat ../configs/textdet/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Text Detection Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmediting/tree/master/=g' >textdet_models.md
cat ../configs/textrecog/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Text Recognition Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmediting/tree/master/=g' >textrecog_models.md


@ -0,0 +1,4 @@
recommonmark
sphinx
sphinx_markdown_tables
sphinx_rtd_theme

Binary file not shown (new image, 24 KiB).

Binary file not shown (new image, 17 KiB).

94
docs/stats.py 100755

@ -0,0 +1,94 @@
#!/usr/bin/env python
import functools as func
import glob
import re
from os.path import basename, splitext
import numpy as np
import titlecase
def anchor(name):
return re.sub(r'-+', '-', re.sub(r'[^a-zA-Z0-9]', '-',
name.strip().lower())).strip('-')
# Count algorithms
files = sorted(glob.glob('*_models.md'))
# files = sorted(glob.glob('docs/*_models.md'))
stats = []
for f in files:
with open(f, 'r') as content_file:
content = content_file.read()
# title
title = content.split('\n')[0].replace('#', '')
# count papers
papers = set((papertype, titlecase.titlecase(paper.lower().strip()))
for (papertype, paper) in re.findall(
r'\n\s*\[([A-Z]+?)\]\s*\n.*?\btitle\s*=\s*{(.*?)}',
content, re.DOTALL))
# paper links
revcontent = '\n'.join(list(reversed(content.splitlines())))
paperlinks = {}
for _, p in papers:
print(p)
q = p.replace('\\', '\\\\').replace('?', '\\?')
paperlinks[p] = ' '.join(
(f'[⇨]({splitext(basename(f))[0]}.html#{anchor(paperlink)})'
for paperlink in re.findall(
rf'\btitle\s*=\s*{{\s*{q}\s*}}.*?\n## (.*?)\s*[,;]?\s*\n',
revcontent, re.DOTALL | re.IGNORECASE)))
print(' ', paperlinks[p])
paperlist = '\n'.join(
sorted(f' - [{t}] {x} ({paperlinks[x]})' for t, x in papers))
# count configs
configs = set(x.lower().strip()
for x in re.findall(r'https.*configs/.*\.py', content))
# count ckpts
ckpts = set(x.lower().strip()
for x in re.findall(r'https://download.*\.pth', content)
if 'mmaction' in x)
statsmsg = f"""
## [{title}]({f})
* Number of checkpoints: {len(ckpts)}
* Number of configs: {len(configs)}
* Number of papers: {len(papers)}
{paperlist}
"""
stats.append((papers, configs, ckpts, statsmsg))
allpapers = func.reduce(lambda a, b: a.union(b), [p for p, _, _, _ in stats])
allconfigs = func.reduce(lambda a, b: a.union(b), [c for _, c, _, _ in stats])
allckpts = func.reduce(lambda a, b: a.union(b), [c for _, _, c, _ in stats])
msglist = '\n'.join(x for _, _, _, x in stats)
papertypes, papercounts = np.unique([t for t, _ in allpapers],
return_counts=True)
countstr = '\n'.join(
[f' - {t}: {c}' for t, c in zip(papertypes, papercounts)])
modelzoo = f"""
# Overview
* Number of checkpoints: {len(allckpts)}
* Number of configs: {len(allconfigs)}
* Number of papers: {len(allpapers)}
{countstr}
For supported datasets, see [datasets overview](datasets.md).
{msglist}
"""
with open('modelzoo.md', 'w') as f:
f.write(modelzoo)


@ -0,0 +1,226 @@
# Technical Details
In this section, we will introduce the main units of training a detector:
data pipeline, model and iteration pipeline.
## Data pipeline
Following typical conventions, we use `Dataset` and `DataLoader` for data loading
with multiple workers. `Dataset` returns a dict of data items corresponding
to the arguments of the model's forward method.
Since the data in object detection may not be the same size (image size, gt bbox size, etc.),
we introduce a new `DataContainer` type in MMCV to help collect and distribute
data of different sizes.
See [here](https://github.com/open-mmlab/mmcv/blob/master/mmcv/parallel/data_container.py) for more details.
The data preparation pipeline and the dataset are decoupled. Usually a dataset
defines how to process the annotations, and a data pipeline defines all the steps to prepare a data dict.
A pipeline consists of a sequence of operations. Each operation takes a dict as input and also outputs a dict for the next transform.
We present a classical pipeline in the following figure. The blue blocks are pipeline operations. As the pipeline runs, each operator can add new keys (marked in green) to the result dict or update the existing keys (marked in orange).
![pipeline figure](../demo/data_pipeline.png)
The operations are categorized into data loading, pre-processing, formatting and test-time augmentation.
Here is a pipeline example for Faster R-CNN.
```python
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
```
For each operation, we list the related dict fields that are added/updated/removed.
### Data loading
`LoadImageFromFile`
- add: img, img_shape, ori_shape
`LoadAnnotations`
- add: gt_bboxes, gt_bboxes_ignore, gt_labels, gt_masks, gt_semantic_seg, bbox_fields, mask_fields
`LoadProposals`
- add: proposals
### Pre-processing
`Resize`
- add: scale, scale_idx, pad_shape, scale_factor, keep_ratio
- update: img, img_shape, *bbox_fields, *mask_fields, *seg_fields
`RandomFlip`
- add: flip
- update: img, *bbox_fields, *mask_fields, *seg_fields
`Pad`
- add: pad_fixed_size, pad_size_divisor
- update: img, pad_shape, *mask_fields, *seg_fields
`RandomCrop`
- update: img, pad_shape, gt_bboxes, gt_labels, gt_masks, *bbox_fields
`Normalize`
- add: img_norm_cfg
- update: img
`SegRescale`
- update: gt_semantic_seg
`PhotoMetricDistortion`
- update: img
`Expand`
- update: img, gt_bboxes
`MinIoURandomCrop`
- update: img, gt_bboxes, gt_labels
`Corrupt`
- update: img
### Formatting
`ToTensor`
- update: specified by `keys`.
`ImageToTensor`
- update: specified by `keys`.
`Transpose`
- update: specified by `keys`.
`ToDataContainer`
- update: specified by `fields`.
`DefaultFormatBundle`
- update: img, proposals, gt_bboxes, gt_bboxes_ignore, gt_labels, gt_masks, gt_semantic_seg
`Collect`
- add: img_meta (the keys of img_meta is specified by `meta_keys`)
- remove: all other keys except for those specified by `keys`
### Test time augmentation
`MultiScaleFlipAug`
## Model
In MMDetection, model components are basically categorized into 4 types.
- backbone: usually an FCN network to extract feature maps, e.g., ResNet.
- neck: the part between backbones and heads, e.g., FPN, ASPP.
- head: the part for specific tasks, e.g., bbox prediction and mask prediction.
- roi extractor: the part for extracting features from feature maps, e.g., RoI Align.
We also implement some general detection pipelines with the above components,
such as `SingleStageDetector` and `TwoStageDetector`.
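As an abridged illustration (not a complete, buildable config; many required fields are omitted and the values are placeholders), these component types are wired together in a detector config roughly like this:
```python
model = dict(
    type='FasterRCNN',  # a TwoStageDetector
    backbone=dict(type='ResNet', depth=50),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(type='RPNHead', in_channels=256),
    roi_head=dict(  # wraps the roi extractor and the task-specific bbox head
        type='StandardRoIHead',
        bbox_roi_extractor=dict(type='SingleRoIExtractor'),
        bbox_head=dict(type='Shared2FCBBoxHead')))
```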
### Build a model with basic components
Following some basic pipelines (e.g., two-stage detectors), the model structure
can be customized through config files with little effort.
If we want to implement some new components, e.g., the path aggregation
FPN structure in [Path Aggregation Network for Instance Segmentation](https://arxiv.org/abs/1803.01534), there are three things to do.
1. Create a new file `mmdet/models/necks/pafpn.py`.
```python
from ..registry import NECKS
@NECKS.register
class PAFPN(nn.Module):
def __init__(self,
in_channels,
out_channels,
num_outs,
start_level=0,
end_level=-1,
add_extra_convs=False):
pass
def forward(self, inputs):
# implementation is ignored
pass
```
2. Import the module in `mmdet/models/necks/__init__.py`.
```python
from .pafpn import PAFPN
```
3. Modify the config file from
```python
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5)
```
to
```python
neck=dict(
type='PAFPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5)
```
We will release more components (backbones, necks, heads) for research purposes.
### Write a new model
To write a new detection pipeline, you need to inherit from `BaseDetector`,
which defines the following abstract methods.
- `extract_feat()`: given an image batch of shape (n, c, h, w), extract the feature map(s).
- `forward_train()`: forward method of the training mode
- `simple_test()`: single scale testing without augmentation
- `aug_test()`: testing with augmentation (multi-scale, flip, etc.)
[TwoStageDetector](https://github.com/hellock/mmdetection/blob/master/mmdet/models/detectors/two_stage.py)
is a good example which shows how to do that.
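To make the interface concrete, here is a bare-bones skeleton (an illustrative sketch only; the class name is made up and the registry call follows mmdetection v2.x conventions, not an actual MMOCR model):
```python
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors import BaseDetector


@DETECTORS.register_module()
class MyToyDetector(BaseDetector):
    """Skeleton only: every method below must be filled in."""

    def extract_feat(self, imgs):
        # (n, c, h, w) image batch -> feature map(s)
        pass

    def forward_train(self, imgs, img_metas, **kwargs):
        # training-mode forward; should return a dict of losses
        pass

    def simple_test(self, img, img_metas, **kwargs):
        # single-scale testing without augmentation
        pass

    def aug_test(self, imgs, img_metas, **kwargs):
        # testing with augmentation (multi-scale, flip, etc.)
        pass
```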
## Iteration pipeline
We adopt distributed training for both single-machine and multi-machine setups.
Suppose the server has 8 GPUs: 8 processes will be started, and each process runs on a single GPU.
Each process keeps an isolated model, data loader, and optimizer.
Model parameters are only synchronized once at the beginning.
After a forward and backward pass, gradients will be allreduced among all GPUs,
and the optimizer will update model parameters.
Since the gradients are allreduced, the model parameters stay the same across all processes after each iteration.
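The per-process loop below sketches this flow with plain PyTorch `DistributedDataParallel` (a simplified sketch; MMOCR actually wires this up through `MMDistributedDataParallel` and its training runner, and the loss-dict handling here is an assumption):
```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def train_one_process(model, data_loader, optimizer, local_rank):
    # one process per GPU; parameters are broadcast once when DDP wraps the model
    dist.init_process_group(backend='nccl')
    model = DDP(model.cuda(local_rank), device_ids=[local_rank])
    for data in data_loader:       # each process loads its own data shard
        losses = model(**data)     # forward pass on this GPU only
        loss = sum(losses.values())
        optimizer.zero_grad()
        loss.backward()            # gradients are allreduced across GPUs here
        optimizer.step()           # hence identical parameter updates everywhere
```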
## Other information
For more information, please refer to our [technical report](https://arxiv.org/abs/1906.07155).


@ -0,0 +1,272 @@
import cv2
import imgaug
import imgaug.augmenters as iaa
import numpy as np
from mmdet.core.mask import PolygonMasks
from mmdet.datasets.builder import PIPELINES
class AugmenterBuilder:
"""Build imgaug object according ImgAug argmentations."""
def __init__(self):
pass
def build(self, args, root=True):
if args is None:
return None
elif isinstance(args, (int, float, str)):
return args
elif isinstance(args, list):
if root:
sequence = [self.build(value, root=False) for value in args]
return iaa.Sequential(sequence)
arg_list = [self.to_tuple_if_list(a) for a in args[1:]]
return getattr(iaa, args[0])(*arg_list)
elif isinstance(args, dict):
if 'cls' in args:
cls = getattr(iaa, args['cls'])
return cls(
**{
k: self.to_tuple_if_list(v)
for k, v in args.items() if not k == 'cls'
})
else:
return {
key: self.build(value, root=False)
for key, value in args.items()
}
else:
raise RuntimeError('unknown augmenter arg: ' + str(args))
def to_tuple_if_list(self, obj):
if isinstance(obj, list):
return tuple(obj)
return obj
@PIPELINES.register_module()
class ImgAug:
"""A wrapper to use imgaug https://github.com/aleju/imgaug.
Args:
args ([list[list|dict]]): The augmentation list. For details, please
refer to the imgaug documentation. Take args=[['Fliplr', 0.5],
dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]] as an
example. These args horizontally flip images with probability 0.5,
followed by a random rotation with angles in range [-10, 10], and
resizing with an independent scale in range [0.5, 3.0] for each
side of the images.
"""
def __init__(self, args=None):
self.augmenter_args = args
self.augmenter = AugmenterBuilder().build(self.augmenter_args)
def __call__(self, results):
# img is bgr
image = results['img']
aug = None
shape = image.shape
if self.augmenter:
aug = self.augmenter.to_deterministic()
results['img'] = aug.augment_image(image)
results['img_shape'] = results['img'].shape
results['flip'] = 'unknown' # it's unknown
results['flip_direction'] = 'unknown' # it's unknown
target_shape = results['img_shape']
self.may_augment_annotation(aug, shape, target_shape, results)
return results
def may_augment_annotation(self, aug, shape, target_shape, results):
if aug is None:
return results
for key in results['mask_fields']:
# augment polygon mask
masks = []
for mask in results[key]:
masks.append(
[self.may_augment_poly(aug, shape, target_shape, mask[0])])
if len(masks) > 0:
results[key] = PolygonMasks(masks, *target_shape[:2])
for key in results['bbox_fields']:
# augment bbox
bboxes = []
for bbox in results[key]:
bbox = self.may_augment_poly(aug, shape, target_shape, bbox)
bboxes.append(bbox)
results[key] = np.zeros(0)
if len(bboxes) > 0:
results[key] = np.stack(bboxes)
return results
def may_augment_poly(self, aug, img_shape, target_shape, poly):
# poly n x 2
poly = poly.reshape(-1, 2)
keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
keypoints = aug.augment_keypoints(
[imgaug.KeypointsOnImage(keypoints, shape=img_shape)])[0].keypoints
poly = [[p.x, p.y] for p in keypoints]
poly = np.array(poly).flatten()
return poly
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str
@PIPELINES.register_module()
class EastRandomCrop:
def __init__(self,
target_size=(640, 640),
max_tries=10,
min_crop_side_ratio=0.1):
self.target_size = target_size
self.max_tries = max_tries
self.min_crop_side_ratio = min_crop_side_ratio
def __call__(self, results):
# sampling crop
# crop image, boxes, masks
img = results['img']
crop_x, crop_y, crop_w, crop_h = self.crop_area(
img, results['gt_masks'])
scale_w = self.target_size[0] / crop_w
scale_h = self.target_size[1] / crop_h
scale = min(scale_w, scale_h)
h = int(crop_h * scale)
w = int(crop_w * scale)
padimg = np.zeros(
(self.target_size[1], self.target_size[0], img.shape[2]),
img.dtype)
padimg[:h, :w] = cv2.resize(
img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
# for bboxes
for key in results['bbox_fields']:
lines = []
for box in results[key]:
box = box.reshape(2, 2)
poly = ((box - (crop_x, crop_y)) * scale)
if not self.is_poly_outside_rect(poly, 0, 0, w, h):
lines.append(poly.flatten())
results[key] = np.array(lines)
# for masks
for key in results['mask_fields']:
polys = []
polys_label = []
for poly in results[key]:
poly = np.array(poly).reshape(-1, 2)
poly = ((poly - (crop_x, crop_y)) * scale)
if not self.is_poly_outside_rect(poly, 0, 0, w, h):
polys.append([poly])
polys_label.append(0)
results[key] = PolygonMasks(polys, *self.target_size)
if key == 'gt_masks':
results['gt_labels'] = polys_label
results['img'] = padimg
results['img_shape'] = padimg.shape
return results
def is_poly_in_rect(self, poly, x, y, w, h):
poly = np.array(poly)
if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
return False
if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
return False
return True
def is_poly_outside_rect(self, poly, x, y, w, h):
poly = np.array(poly).reshape(-1, 2)
if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
return True
if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
return True
return False
def split_regions(self, axis):
regions = []
min_axis = 0
for i in range(1, axis.shape[0]):
if axis[i] != axis[i - 1] + 1:
region = axis[min_axis:i]
min_axis = i
regions.append(region)
return regions
def random_select(self, axis, max_size):
xx = np.random.choice(axis, size=2)
xmin = np.min(xx)
xmax = np.max(xx)
xmin = np.clip(xmin, 0, max_size - 1)
xmax = np.clip(xmax, 0, max_size - 1)
return xmin, xmax
def region_wise_random_select(self, regions, max_size):
selected_index = list(np.random.choice(len(regions), 2))
selected_values = []
for index in selected_index:
axis = regions[index]
xx = int(np.random.choice(axis, size=1))
selected_values.append(xx)
xmin = min(selected_values)
xmax = max(selected_values)
return xmin, xmax
def crop_area(self, img, polys):
h, w, _ = img.shape
h_array = np.zeros(h, dtype=np.int32)
w_array = np.zeros(w, dtype=np.int32)
for points in polys:
points = np.round(
points, decimals=0).astype(np.int32).reshape(-1, 2)
minx = np.min(points[:, 0])
maxx = np.max(points[:, 0])
w_array[minx:maxx] = 1
miny = np.min(points[:, 1])
maxy = np.max(points[:, 1])
h_array[miny:maxy] = 1
# ensure the cropped area does not cross a text region
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return 0, 0, w, h
h_regions = self.split_regions(h_axis)
w_regions = self.split_regions(w_axis)
for i in range(self.max_tries):
if len(w_regions) > 1:
xmin, xmax = self.region_wise_random_select(w_regions, w)
else:
xmin, xmax = self.random_select(w_axis, w)
if len(h_regions) > 1:
ymin, ymax = self.region_wise_random_select(h_regions, h)
else:
ymin, ymax = self.random_select(h_axis, h)
if xmax - xmin < self.min_crop_side_ratio * w or \
ymax - ymin < self.min_crop_side_ratio * h:
# area too small
continue
num_poly_in_rect = 0
for poly in polys:
if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
ymax - ymin):
num_poly_in_rect += 1
break
if num_poly_in_rect > 0:
return xmin, ymin, xmax - xmin, ymax - ymin
return 0, 0, w, h
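# Hedged usage sketch (added for illustration, not part of the original file):
# since EastRandomCrop is registered in PIPELINES, it can be plugged into a
# data pipeline config as a dict entry; the surrounding transform names below
# are assumptions taken from typical mmdet pipelines, not from this change.
#
#     train_pipeline = [
#         dict(type='LoadImageFromFile'),
#         dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
#         dict(
#             type='EastRandomCrop',
#             target_size=(640, 640),
#             max_tries=10,
#             min_crop_side_ratio=0.1),
#     ]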

View File

@ -0,0 +1,238 @@
import cv2
import numpy as np
import pyclipper
from shapely.geometry import Polygon
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES
from . import BaseTextDetTargets
@PIPELINES.register_module()
class DBNetTargets(BaseTextDetTargets):
"""Generate gt shrinked text, gt threshold map, and their effective region
masks to learn DBNet: Real-time Scene Text Detection with Differentiable
Binarization [https://arxiv.org/abs/1911.08947]. This was partially adapted
from https://github.com/MhLiao/DB.
Args:
shrink_ratio (float): The area shrinked ratio between text
kernels and their text masks.
thr_min (float): The minimum value of the threshold map.
thr_max (float): The maximum value of the threshold map.
min_short_size (int): The minimum size of polygon below which
the polygon is invalid.
"""
def __init__(self,
shrink_ratio=0.4,
thr_min=0.3,
thr_max=0.7,
min_short_size=8):
super().__init__()
self.shrink_ratio = shrink_ratio
self.thr_min = thr_min
self.thr_max = thr_max
self.min_short_size = min_short_size
def find_invalid(self, results):
"""Find invalid polygons.
Args:
results (dict): The dict containing gt_masks.
Returns:
ignore_tags (list[bool]): Indicators of which polygons should be ignored.
"""
texts = results['gt_masks'].masks
ignore_tags = [False] * len(texts)
for inx, text in enumerate(texts):
if self.invalid_polygon(text[0]):
ignore_tags[inx] = True
return ignore_tags
def invalid_polygon(self, poly):
"""Judge the input polygon is invalid or not. It is invalid if its area
smaller than 1 or the shorter side of its minimum bounding box smaller
than min_short_size.
Args:
poly (ndarray): The polygon boundary point sequence.
Returns:
True/False (bool): Whether the polygon is invalid.
"""
area = self.polygon_area(poly)
if abs(area) < 1:
return True
short_size = min(self.polygon_size(poly))
if short_size < self.min_short_size:
return True
return False
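# Added note (for clarity, not part of the original file): with the default
# min_short_size of 8, a 20 x 5 px text box is marked invalid because its
# shorter bounding-box side (5) is below the threshold, even though its area
# (100) easily passes the |area| >= 1 check.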
def ignore_texts(self, results, ignore_tags):
"""Ignore gt masks and gt_labels while padding gt_masks_ignore in
results given ignore_tags.
Args:
results (dict): Result for one image.
ignore_tags (list[int]): Indicate whether to ignore its
corresponding ground truth text.
Returns:
results (dict): Results after filtering.
"""
flag_len = len(ignore_tags)
assert flag_len == len(results['gt_masks'].masks)
assert flag_len == len(results['gt_labels'])
results['gt_masks_ignore'].masks += [
mask for i, mask in enumerate(results['gt_masks'].masks)
if ignore_tags[i]
]
results['gt_masks'].masks = [
mask for i, mask in enumerate(results['gt_masks'].masks)
if not ignore_tags[i]
]
results['gt_labels'] = np.array([
mask for i, mask in enumerate(results['gt_labels'])
if not ignore_tags[i]
])
return results
def generate_thr_map(self, img_size, polygons):
"""Generate threshold map.
Args:
img_size (tuple(int)): The image size (h, w).
polygons (list[ndarray]): The polygon list.
Returns:
thr_map (ndarray): The generated threshold map.
thr_mask (ndarray): The effective mask of threshold map.
"""
thr_map = np.zeros(img_size, dtype=np.float32)
thr_mask = np.zeros(img_size, dtype=np.uint8)
for polygon in polygons:
self.draw_border_map(polygon[0], thr_map, mask=thr_mask)
thr_map = thr_map * (self.thr_max - self.thr_min) + self.thr_min
return thr_map, thr_mask
def draw_border_map(self, polygon, canvas, mask):
"""Generate threshold map for one polygon.
Args:
polygon (ndarray): The polygon boundary ndarray.
canvas (ndarray): The canvas on which the threshold map is drawn.
mask (ndarray): The effective mask of the threshold map.
"""
polygon = polygon.reshape(-1, 2)
assert polygon.ndim == 2
assert polygon.shape[1] == 2
polygon_shape = Polygon(polygon)
distance = polygon_shape.area * \
(1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
subject = [tuple(p) for p in polygon]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND,
pyclipper.ET_CLOSEDPOLYGON)
padded_polygon = padding.Execute(distance)
if len(padded_polygon) > 0:
padded_polygon = np.array(padded_polygon[0])
else:
print(f'padding {polygon} with {distance} gets {padded_polygon}')
padded_polygon = polygon.copy().astype(np.int32)
cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
x_min = padded_polygon[:, 0].min()
x_max = padded_polygon[:, 0].max()
y_min = padded_polygon[:, 1].min()
y_max = padded_polygon[:, 1].max()
width = x_max - x_min + 1
height = y_max - y_min + 1
polygon[:, 0] = polygon[:, 0] - x_min
polygon[:, 1] = polygon[:, 1] - y_min
xs = np.broadcast_to(
np.linspace(0, width - 1, num=width).reshape(1, width),
(height, width))
ys = np.broadcast_to(
np.linspace(0, height - 1, num=height).reshape(height, 1),
(height, width))
distance_map = np.zeros((polygon.shape[0], height, width),
dtype=np.float32)
for i in range(polygon.shape[0]):
j = (i + 1) % polygon.shape[0]
absolute_distance = self.point2line(xs, ys, polygon[i], polygon[j])
distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
distance_map = distance_map.min(axis=0)
x_min_valid = min(max(0, x_min), canvas.shape[1] - 1)
x_max_valid = min(max(0, x_max), canvas.shape[1] - 1)
y_min_valid = min(max(0, y_min), canvas.shape[0] - 1)
y_max_valid = min(max(0, y_max), canvas.shape[0] - 1)
canvas[y_min_valid:y_max_valid + 1,
x_min_valid:x_max_valid + 1] = np.fmax(
1 - distance_map[y_min_valid - y_min:y_max_valid - y_max +
height, x_min_valid - x_min:x_max_valid -
x_max + width],
canvas[y_min_valid:y_max_valid + 1,
x_min_valid:x_max_valid + 1])
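# Added note (for clarity, not part of the original file): the offset
# distance computed above follows the DB paper,
#     D = A * (1 - r^2) / L,
# with A the polygon area, L its perimeter and r the shrink ratio. For a
# 100 x 20 rectangle and r = 0.4, D = 2000 * (1 - 0.16) / 240 = 7.0, so the
# polygon is dilated outwards by 7 px before the normalized distance map is
# blended into the canvas with np.fmax.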
def generate_targets(self, results):
"""Generate the gt targets for DBNet.
Args:
results (dict): The input result dictionary.
Returns:
results (dict): The output result dictionary.
"""
assert isinstance(results, dict)
polygons = results['gt_masks'].masks
if 'bbox_fields' in results:
results['bbox_fields'].clear()
ignore_tags = self.find_invalid(results)
h, w, _ = results['img_shape']
gt_shrink, ignore_tags = self.generate_kernels((h, w),
polygons,
self.shrink_ratio,
ignore_tags=ignore_tags)
results = self.ignore_texts(results, ignore_tags)
# polygons and polygons_ignore reassignment.
polygons = results['gt_masks'].masks
polygons_ignore = results['gt_masks_ignore'].masks
gt_shrink_mask = self.generate_effective_mask((h, w), polygons_ignore)
gt_thr, gt_thr_mask = self.generate_thr_map((h, w), polygons)
results['mask_fields'].clear() # rm gt_masks encoded by polygons
results.pop('gt_labels', None)
results.pop('gt_masks', None)
results.pop('gt_bboxes', None)
results.pop('gt_bboxes_ignore', None)
mapping = {
'gt_shrink': gt_shrink,
'gt_shrink_mask': gt_shrink_mask,
'gt_thr': gt_thr,
'gt_thr_mask': gt_thr_mask
}
for key, value in mapping.items():
value = value if isinstance(value, list) else [value]
results[key] = BitmapMasks(value, h, w)
results['mask_fields'].append(key)
return results
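# Summary sketch (added for clarity, not part of the original file): after
# generate_targets runs, the result dict is expected to expose four
# BitmapMasks entries registered in results['mask_fields'], e.g.
#
#     results['gt_shrink']       # shrunk text kernel map
#     results['gt_shrink_mask']  # effective region mask of the kernel
#     results['gt_thr']          # threshold map in [thr_min, thr_max]
#     results['gt_thr_mask']     # effective region mask of the threshold map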

View File

@ -0,0 +1,86 @@
import torch
import torch.nn as nn
from mmdet.models.builder import HEADS, build_loss
from .head_mixin import HeadMixin
@HEADS.register_module()
class DBHead(HeadMixin, nn.Module):
"""The class for DBNet head.
This was partially adapted from https://github.com/MhLiao/DB
"""
def __init__(self,
in_channels,
with_bias=False,
decoding_type='db',
text_repr_type='poly',
downsample_ratio=1.0,
loss=dict(type='DBLoss'),
train_cfg=None,
test_cfg=None):
"""Initialization.
Args:
in_channels (int): The number of input channels of the db head.
with_bias (bool): Whether to use bias in the first convolution of the
binarization branch.
decoding_type (str): The type of decoder for dbnet.
text_repr_type (str): Boundary encoding type, 'poly' or 'quad'.
downsample_ratio (float): The downsample ratio of the ground truths.
loss (dict): The config of loss for dbnet.
"""
super().__init__()
assert isinstance(in_channels, int)
self.in_channels = in_channels
self.text_repr_type = text_repr_type
self.loss_module = build_loss(loss)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.downsample_ratio = downsample_ratio
self.decoding_type = decoding_type
self.binarize = nn.Sequential(
nn.Conv2d(
in_channels, in_channels // 4, 3, bias=with_bias, padding=1),
nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())
self.threshold = self._init_thr(in_channels)
def init_weights(self):
self.binarize.apply(self.init_class_parameters)
self.threshold.apply(self.init_class_parameters)
def init_class_parameters(self, m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.kaiming_normal_(m.weight.data)
elif classname.find('BatchNorm') != -1:
m.weight.data.fill_(1.)
m.bias.data.fill_(1e-4)
def diff_binarize(self, prob_map, thr_map, k):
return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
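# Added note (for clarity, not part of the original file): diff_binarize
# implements the approximate binarization of the DB paper,
#     B_hat = 1 / (1 + exp(-k * (P - T))),
# where P is the probability map, T the threshold map and k an amplification
# factor (k=50 in forward below), making the hard binarization step
# differentiable during training.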
def forward(self, inputs):
prob_map = self.binarize(inputs)
thr_map = self.threshold(inputs)
binary_map = self.diff_binarize(prob_map, thr_map, k=50)
outputs = torch.cat((prob_map, thr_map, binary_map), dim=1)
return (outputs, )
def _init_thr(self, inner_channels, bias=False):
in_channels = inner_channels
seq = nn.Sequential(
nn.Conv2d(
in_channels, inner_channels // 4, 3, padding=1, bias=bias),
nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid())
return seq
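# Hedged shape sketch (added for illustration, not part of the original file):
# both branches upsample by 4x through two stride-2 transposed convolutions,
# so a 1/4-resolution feature map is restored to full resolution; the channel
# and spatial sizes below are assumptions chosen for the example.
#
#     import torch
#     head = DBHead(in_channels=256)
#     feats = torch.rand(1, 256, 160, 160)      # e.g. 1/4 of a 640 x 640 image
#     outputs = head(feats)[0]
#     assert outputs.shape == (1, 3, 640, 640)  # prob, thr and binary maps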

View File

@ -0,0 +1,23 @@
from mmdet.models.builder import DETECTORS
from . import SingleStageTextDetector, TextDetectorMixin
@DETECTORS.register_module()
class DBNet(TextDetectorMixin, SingleStageTextDetector):
"""The class for implementing DBNet text detector: Real-time Scene Text
Detection with Differentiable Binarization.
[https://arxiv.org/abs/1911.08947].
"""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None,
show_score=False):
SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
train_cfg, test_cfg, pretrained)
TextDetectorMixin.__init__(self, show_score)
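# Hedged config sketch (added for illustration, not part of the original
# file): a DBNet model is typically assembled from a config dict; the
# backbone and neck entries below are assumptions for illustration only, and
# just 'DBNet', 'DBHead' and 'DBLoss' come from the files in this change.
#
#     model = dict(
#         type='DBNet',
#         backbone=dict(type='ResNet', depth=18),
#         neck=dict(type='FPNC', in_channels=[64, 128, 256, 512]),
#         bbox_head=dict(
#             type='DBHead',
#             in_channels=256,
#             loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True)),
#         train_cfg=None,
#         test_cfg=None)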

View File

@ -0,0 +1,169 @@
import torch
import torch.nn.functional as F
from torch import nn
from mmdet.models.builder import LOSSES
from mmocr.core.visualize import show_feature # noqa F401
from mmocr.models.common.losses.dice_loss import DiceLoss
@LOSSES.register_module()
class DBLoss(nn.Module):
"""The class for implementing DBNet loss.
This is partially adapted from https://github.com/MhLiao/DB.
"""
def __init__(self,
alpha=1,
beta=1,
reduction='mean',
negative_ratio=3.0,
eps=1e-6,
bbce_loss=False):
"""Initialization.
Args:
alpha (float): The coefficient of the probability map loss.
beta (float): The coefficient of the threshold map loss.
reduction (str): The way to reduce the loss.
negative_ratio (float): The maximum ratio of negative samples to
positive samples kept in the balanced bce loss.
eps (float): A small constant added to denominators to avoid
division by zero.
bbce_loss (bool): Whether to use balanced bce for probability loss.
If False, dice loss will be used instead.
"""
super().__init__()
assert reduction in ['mean',
'sum'], "reduction must be in ['mean', 'sum']"
self.alpha = alpha
self.beta = beta
self.reduction = reduction
self.negative_ratio = negative_ratio
self.eps = eps
self.bbce_loss = bbce_loss
self.dice_loss = DiceLoss(eps=eps)
def bitmasks2tensor(self, bitmasks, target_sz):
"""Convert Bitmasks to tensor.
Args:
bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
the mask for one image.
target_sz (tuple(int, int)): The target spatial size (H, W) to
which each mask is zero-padded.
Returns:
result_tensors (list[tensor]): The list of kernel tensors. Each
element is for one kernel level.
"""
assert isinstance(bitmasks, list)
assert isinstance(target_sz, tuple)
batch_size = len(bitmasks)
num_levels = len(bitmasks[0])
result_tensors = []
for level_inx in range(num_levels):
kernel = []
for batch_inx in range(batch_size):
mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
mask_sz = mask.shape
pad = [
0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
]
mask = F.pad(mask, pad, mode='constant', value=0)
kernel.append(mask)
kernel = torch.stack(kernel)
result_tensors.append(kernel)
return result_tensors
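# Added note (for clarity, not part of the original file): for one image
# whose rescaled mask is 152 x 160 and a target size of (160, 160), the pad
# computed above is [0, 0, 0, 8] (left, right, top, bottom), so F.pad
# zero-fills the bottom 8 rows and every mask in the batch can be stacked
# into a single tensor of shape (batch_size, 160, 160).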
def balance_bce_loss(self, pred, gt, mask):
positive = (gt * mask)
negative = ((1 - gt) * mask)
positive_count = int(positive.float().sum())
negative_count = min(
int(negative.float().sum()),
int(positive_count * self.negative_ratio))
assert gt.max() <= 1 and gt.min() >= 0
assert pred.max() <= 1 and pred.min() >= 0
loss = F.binary_cross_entropy(pred, gt, reduction='none')
positive_loss = loss * positive.float()
negative_loss = loss * negative.float()
negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
positive_count + negative_count + self.eps)
return balance_loss
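# Added note (for clarity, not part of the original file): balance_bce_loss
# performs online hard negative mining: every positive pixel contributes,
# while only the hardest negatives (largest BCE values, selected with
# torch.topk) are kept, capped at negative_ratio times the number of
# positives, before the summed loss is normalized by the number of selected
# pixels.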
def l1_thr_loss(self, pred, gt, mask):
thr_loss = torch.abs((pred - gt) * mask).sum() / (
mask.sum() + self.eps)
return thr_loss
def forward(self, preds, downsample_ratio, gt_shrink, gt_shrink_mask,
gt_thr, gt_thr_mask):
"""Compute DBNet loss.
Args:
preds (tuple(tensor)): The output from DBHead; its first element is
a tensor of size Nx3xHxW.
downsample_ratio (float): The downsample ratio for the
ground truths.
gt_shrink (list[BitmapMasks]): The mask list with each element
being the shrunk text mask for one img.
gt_shrink_mask (list[BitmapMasks]): The effective mask list with
each element being the effective mask of the shrunk text for one img.
gt_thr (list[BitmapMasks]): The mask list with each element
being the threshold text mask for one img.
gt_thr_mask (list[BitmapMasks]): The effective mask list with
each element being the threshold effective mask for one img.
Returns:
results (dict): The dict of dbnet losses with keys loss_prob,
loss_db and loss_thr.
"""
assert isinstance(downsample_ratio, float)
assert isinstance(gt_shrink, list)
assert isinstance(gt_shrink_mask, list)
assert isinstance(gt_thr, list)
assert isinstance(gt_thr_mask, list)
preds = preds[0]
pred_prob = preds[:, 0, :, :]
pred_thr = preds[:, 1, :, :]
pred_db = preds[:, 2, :, :]
feature_sz = preds.size()
keys = ['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask']
gt = {'gt_shrink': gt_shrink, 'gt_shrink_mask': gt_shrink_mask,
'gt_thr': gt_thr, 'gt_thr_mask': gt_thr_mask}
for k in keys:
gt[k] = [item.rescale(downsample_ratio) for item in gt[k]]
gt[k] = self.bitmasks2tensor(gt[k], feature_sz[2:])
gt[k] = [item.to(preds.device) for item in gt[k]]
gt['gt_shrink'][0] = (gt['gt_shrink'][0] > 0).float()
if self.bbce_loss:
loss_prob = self.balance_bce_loss(pred_prob, gt['gt_shrink'][0],
gt['gt_shrink_mask'][0])
else:
loss_prob = self.dice_loss(pred_prob, gt['gt_shrink'][0],
gt['gt_shrink_mask'][0])
loss_db = self.dice_loss(pred_db, gt['gt_shrink'][0],
gt['gt_shrink_mask'][0])
loss_thr = self.l1_thr_loss(pred_thr, gt['gt_thr'][0],
gt['gt_thr_mask'][0])
results = dict(
loss_prob=self.alpha * loss_prob,
loss_db=loss_db,
loss_thr=self.beta * loss_thr)
return results
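# Added note (for clarity, not part of the original file): the overall
# objective is the sum of the three weighted terms returned above,
#     L = alpha * L_prob + L_db + beta * L_thr,
# with alpha and beta already folded into loss_prob and loss_thr; the
# training loop is expected to sum the values of the returned dict.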

@ -0,0 +1 @@
Subproject commit 4c96c2754d6785fa19663f3c62e54470ec185862