diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 00000000..a7ee6382
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,3 @@
+[run]
+omit =
+ */__init__.py
diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md
new file mode 100644
index 00000000..efd43057
--- /dev/null
+++ b/.github/CODE_OF_CONDUCT.md
@@ -0,0 +1,76 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to making participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+ address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies both within project spaces and in public spaces
+when an individual is representing the project or its community. Examples of
+representing a project or community include using an official project e-mail
+address, posting via an official social media account, or acting as an appointed
+representative at an online or offline event. Representation of a project may be
+further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at chenkaidev@gmail.com. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
new file mode 100644
index 00000000..4a6fe184
--- /dev/null
+++ b/.github/CONTRIBUTING.md
@@ -0,0 +1 @@
+We appreciate all contributions to improve MMOCR. Please refer to [contributing.md](/docs/contributing.md) for more details about the contributing guidelines.
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
new file mode 100644
index 00000000..f9b07021
--- /dev/null
+++ b/.github/workflows/build.yml
@@ -0,0 +1,68 @@
+name: build
+
+on:
+ push:
+ branches:
+ - master
+
+ pull_request:
+
+jobs:
+ lint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python 3.7
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.7
+ - name: Install pre-commit hook
+ run: |
+ pip install pre-commit
+ pre-commit install
+ - name: Linting
+ run: pre-commit run --all-files
+ - name: Check docstring coverage
+ run: |
+ pip install interrogate
+ interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 50 mmocr
+
+ build_cpu:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: [3.7]
+ torch: [1.5.0, 1.6.0, 1.7.0]
+ include:
+ - torch: 1.5.0
+ torchvision: 0.6.0
+ - torch: 1.6.0
+ torchvision: 0.7.0
+ - torch: 1.7.0
+ torchvision: 0.8.1
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Upgrade pip
+ run: pip install pip --upgrade
+ - name: Install Pillow
+ run: pip install Pillow==6.2.2
+ if: ${{matrix.torchvision == '0.4.1'}}
+ - name: Install PyTorch
+ run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ - name: Install MMCV
+ run: pip install mmcv-full==1.3.0 -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html
+ - name: Install MMDet
+ run: pip install git+https://github.com/open-mmlab/mmdetection/
+ - name: Install other dependencies
+ run: pip install -r requirements.txt
+ - name: Build and install
+ run: rm -rf .eggs && pip install -e .
+ - name: Run unittests and generate coverage report
+ run: |
+ coverage run --branch --source mmocr -m pytest tests/
+ coverage xml
+ coverage report -m
diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml
new file mode 100644
index 00000000..af71b058
--- /dev/null
+++ b/.github/workflows/publish-to-pypi.yml
@@ -0,0 +1,20 @@
+name: deploy
+
+on: push
+
+jobs:
+ build-n-publish:
+ runs-on: ubuntu-latest
+ if: startsWith(github.event.ref, 'refs/tags')
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python 3.7
+ uses: actions/setup-python@v1
+ with:
+ python-version: 3.7
+ - name: Build MMOCR
+ run: python setup.py sdist
+ - name: Publish distribution to PyPI
+ run: |
+ pip install twine
+ twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..52720295
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,138 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+*.ipynb
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# cython generated cpp
+!data/dict
+data/*
+.vscode
+.idea
+
+# custom
+*.pkl
+*.pkl.json
+*.log.json
+work_dirs/
+exps/
+*~
+show_dir/
+
+# Pytorch
+*.pth
+
+# demo
+!tests/data
+tests/results
+
+# temp files
+.DS_Store
+
+checkpoints
+
+htmlcov
+*.swp
+log.txt
+workspace.code-workspace
+results
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..38324d51
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,36 @@
+exclude: ^tests/data/
+repos:
+ - repo: https://gitlab.com/pycqa/flake8
+ rev: 3.8.1
+ hooks:
+ - id: flake8
+ - repo: https://github.com/asottile/seed-isort-config
+ rev: v2.2.0
+ hooks:
+ - id: seed-isort-config
+ - repo: https://github.com/timothycrosley/isort
+ rev: 4.3.21
+ hooks:
+ - id: isort
+ - repo: https://github.com/pre-commit/mirrors-yapf
+ rev: v0.30.0
+ hooks:
+ - id: yapf
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v3.1.0
+ hooks:
+ - id: trailing-whitespace
+ - id: check-yaml
+ - id: end-of-file-fixer
+ - id: requirements-txt-fixer
+ - id: double-quote-string-fixer
+ - id: check-merge-conflict
+ - id: fix-encoding-pragma
+ args: ["--remove"]
+ - id: mixed-line-ending
+ args: ["--fix=lf"]
+ - repo: https://github.com/myint/docformatter
+ rev: v1.3.1
+ hooks:
+ - id: docformatter
+ args: ["--in-place", "--wrap-descriptions", "79"]
diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 00000000..d7a39be8
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,621 @@
+[MASTER]
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code.
+extension-pkg-whitelist=
+
+# Specify a score threshold to be exceeded before program exits with error.
+fail-under=10.0
+
+# Add files or directories to the blacklist. They should be base names, not
+# paths.
+ignore=CVS,configs
+
+# Add files or directories matching the regex patterns to the blacklist. The
+# regex matches against base names, not paths.
+ignore-patterns=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
+# number of processors available to use.
+jobs=1
+
+# Control the amount of potential inferred values when inferring a single
+# object. This can help the performance when dealing with large functions or
+# complex, nested conditions.
+limit-inference-results=100
+
+# List of plugins (as comma separated values of python module names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages.
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
+confidence=
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once). You can also use "--disable=all" to
+# disable everything first and then reenable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use "--disable=all --enable=classes
+# --disable=W".
+disable=print-statement,
+ parameter-unpacking,
+ unpacking-in-except,
+ old-raise-syntax,
+ backtick,
+ long-suffix,
+ old-ne-operator,
+ old-octal-literal,
+ import-star-module-level,
+ non-ascii-bytes-literal,
+ raw-checker-failed,
+ bad-inline-option,
+ locally-disabled,
+ file-ignored,
+ suppressed-message,
+ useless-suppression,
+ deprecated-pragma,
+ use-symbolic-message-instead,
+ apply-builtin,
+ basestring-builtin,
+ buffer-builtin,
+ cmp-builtin,
+ coerce-builtin,
+ execfile-builtin,
+ file-builtin,
+ long-builtin,
+ raw_input-builtin,
+ reduce-builtin,
+ standarderror-builtin,
+ unicode-builtin,
+ xrange-builtin,
+ coerce-method,
+ delslice-method,
+ getslice-method,
+ setslice-method,
+ no-absolute-import,
+ old-division,
+ dict-iter-method,
+ dict-view-method,
+ next-method-called,
+ metaclass-assignment,
+ indexing-exception,
+ raising-string,
+ reload-builtin,
+ oct-method,
+ hex-method,
+ nonzero-method,
+ cmp-method,
+ input-builtin,
+ round-builtin,
+ intern-builtin,
+ unichr-builtin,
+ map-builtin-not-iterating,
+ zip-builtin-not-iterating,
+ range-builtin-not-iterating,
+ filter-builtin-not-iterating,
+ using-cmp-argument,
+ eq-without-hash,
+ div-method,
+ idiv-method,
+ rdiv-method,
+ exception-message-attribute,
+ invalid-str-codec,
+ sys-max-int,
+ bad-python3-import,
+ deprecated-string-function,
+ deprecated-str-translate-call,
+ deprecated-itertools-function,
+ deprecated-types-field,
+ next-method-defined,
+ dict-items-not-iterating,
+ dict-keys-not-iterating,
+ dict-values-not-iterating,
+ deprecated-operator-function,
+ deprecated-urllib-function,
+ xreadlines-attribute,
+ deprecated-sys-function,
+ exception-escape,
+ comprehension-escape,
+ no-member,
+ invalid-name,
+ too-many-branches,
+ wrong-import-order,
+ too-many-arguments,
+ missing-function-docstring,
+ missing-module-docstring,
+ too-many-locals,
+ too-few-public-methods,
+ abstract-method,
+ broad-except,
+ too-many-nested-blocks,
+ too-many-instance-attributes,
+ missing-class-docstring,
+ duplicate-code,
+ not-callable,
+ protected-access,
+ dangerous-default-value,
+ no-name-in-module,
+ logging-fstring-interpolation,
+ super-init-not-called,
+ redefined-builtin,
+ attribute-defined-outside-init,
+ arguments-differ,
+ cyclic-import,
+ bad-super-call,
+ too-many-statements
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=c-extension-no-member
+
+
+[REPORTS]
+
+# Python expression which should return a score less than or equal to 10. You
+# have access to the variables 'error', 'warning', 'refactor', and 'convention'
+# which contain the number of messages in each category, as well as 'statement'
+# which is the total number of statements analyzed. This score is used by the
+# global evaluation report (RP0004).
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details.
+#msg-template=
+
+# Set the output format. Available formats are text, parseable, colorized, json
+# and msvs (visual studio). You can also give a reporter class, e.g.
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages.
+reports=no
+
+# Activate the evaluation score.
+score=yes
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never returns. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=sys.exit
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=
+
+# Tells whether missing members accessed in mixin class should be ignored. A
+# mixin class is detected if its name ends with "mixin" (case insensitive).
+ignore-mixin-members=yes
+
+# Tells whether to warn about missing members when the owner of the attribute
+# is inferred to be None.
+ignore-none=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis). It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+# List of decorators that change the signature of a decorated function.
+signature-mutators=
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes.
+max-spelling-suggestions=4
+
+# Spelling dictionary name. Available dictionaries: none. To make it work,
+# install the python-enchant package.
+spelling-dict=
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains the private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to the private dictionary (see the
+# --spelling-private-dict-file option) instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[LOGGING]
+
+# The type of string formatting that logging methods do. `old` means using %
+# formatting, `new` is for `{}` formatting.
+logging-format-style=old
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format.
+logging-modules=logging
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid defining new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+ _cb
+
+# A regular expression matching the name of dummy variables (i.e. expected to
+# not be used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored. Default to name
+# with leading underscore.
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
+# tab).
+indent-string=' '
+
+# Maximum number of characters on a single line.
+max-line-length=100
+
+# Maximum number of lines in a module.
+max-module-lines=1000
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[STRING]
+
+# This flag controls whether inconsistent-quotes generates a warning when the
+# character used as a quote delimiter is used inconsistently within a module.
+check-quote-consistency=no
+
+# This flag controls whether the implicit-str-concat should generate a warning
+# on implicit string concatenation in sequences defined over several lines.
+check-str-concat-over-line-jumps=no
+
+
+[SIMILARITIES]
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+ XXX,
+ TODO
+
+# Regular expression of note tags to take in consideration.
+#notes-rgx=
+
+
+[BASIC]
+
+# Naming style matching correct argument names.
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style.
+#argument-rgx=
+
+# Naming style matching correct attribute names.
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style.
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma.
+bad-names=foo,
+ bar,
+ baz,
+ toto,
+ tutu,
+ tata
+
+# Bad variable names regexes, separated by a comma. If names match any regex,
+# they will always be refused
+bad-names-rgxs=
+
+# Naming style matching correct class attribute names.
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style.
+#class-attribute-rgx=
+
+# Naming style matching correct class names.
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-
+# style.
+#class-rgx=
+
+# Naming style matching correct constant names.
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style.
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names.
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style.
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma.
+good-names=i,
+ j,
+ k,
+ ex,
+ Run,
+ _,
+ x,
+ y,
+ w,
+ h,
+ a,
+ b
+
+# Good variable names regexes, separated by a comma. If names match any regex,
+# they will always be accepted
+good-names-rgxs=
+
+# Include a hint for the correct naming format with invalid-name.
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names.
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style.
+#inlinevar-rgx=
+
+# Naming style matching correct method names.
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style.
+#method-rgx=
+
+# Naming style matching correct module names.
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style.
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+# These decorators are taken in consideration only for invalid-name.
+property-classes=abc.abstractproperty
+
+# Naming style matching correct variable names.
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style.
+#variable-rgx=
+
+
+[DESIGN]
+
+# Maximum number of arguments for function / method.
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in an if statement (see R0916).
+max-bool-expr=5
+
+# Maximum number of branch for function / method body.
+max-branches=12
+
+# Maximum number of locals for function / method body.
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body.
+max-returns=6
+
+# Maximum number of statements in function / method body.
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[IMPORTS]
+
+# List of modules that can be imported at any level, not just the top level
+# one.
+allow-any-import-level=
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Deprecated modules which should not be used, separated by a comma.
+deprecated-modules=optparse,tkinter.tix
+
+# Create a graph of external dependencies in the given file (report RP0402 must
+# not be disabled).
+ext-import-graph=
+
+# Create a graph of every (i.e. internal and external) dependencies in the
+# given file (report RP0402 must not be disabled).
+import-graph=
+
+# Create a graph of internal dependencies in the given file (report RP0402 must
+# not be disabled).
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+# Couples of modules and preferred modules, separated by a comma.
+preferred-modules=
+
+
+[CLASSES]
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+ __new__,
+ setUp,
+ __post_init__
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+ _fields,
+ _replace,
+ _source,
+ _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=cls
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when being caught. Defaults to
+# "BaseException, Exception".
+overgeneral-exceptions=BaseException,
+ Exception
diff --git a/.readthedocs.yml b/.readthedocs.yml
new file mode 100644
index 00000000..73ea4cb7
--- /dev/null
+++ b/.readthedocs.yml
@@ -0,0 +1,7 @@
+version: 2
+
+python:
+ version: 3.7
+ install:
+ - requirements: requirements/docs.txt
+ - requirements: requirements/readthedocs.txt
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 00000000..3076a437
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,203 @@
+Copyright (c) MMOCR Authors. All rights reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2021 MMOCR Authors. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index c553ca6a..96f13097 100644
--- a/README.md
+++ b/README.md
@@ -1,17 +1,17 @@
-
+
## Introduction
-[](https://github.com/open-mmlab/mmediting/actions)
-[](https://mmediting.readthedocs.io/en/latest/?badge=latest)
-[](https://codecov.io/gh/open-mmlab/mmediting)
-[](https://github.com/open-mmlab/mmediting/blob/master/LICENSE)
+[](https://github.com/open-mmlab/mmocr/actions)
+[](https://mmocr.readthedocs.io/en/latest/?badge=latest)
+[](https://codecov.io/gh/open-mmlab/mmocr)
+[](https://github.com/open-mmlab/mmocr/blob/master/LICENSE)
MMOCR is an open-source toolbox based on PyTorch and mmdetection for text detection, text recognition, and the corresponding downstream tasks including key information extraction. It is part of the open-mmlab project developed by [Multimedia Laboratory, CUHK](http://mmlab.ie.cuhk.edu.hk/).
-The master branch works with **PyTorch 1.5**.
+The master branch works with **PyTorch 1.5+**.
Documentation: https://mmocr.readthedocs.io/en/latest/.
@@ -31,7 +31,7 @@ Documentation: https://mmocr.readthedocs.io/en/latest/.
- **Modular Design**
- The modular design of MMOCR enables users to define their own optimizers, data preprocessors, and model components such as backbones, necks and heads as well as losses. Please refer to [GETTING_STARTED.md](docs/GETTING_STARTED.md) for how to construct a customized model.
+ The modular design of MMOCR enables users to define their own optimizers, data preprocessors, and model components such as backbones, necks and heads as well as losses. Please refer to [getting_started.md](docs/getting_started.md) for how to construct a customized model.
- **Numerous Utilities**
@@ -43,24 +43,24 @@ This project is released under the [Apache 2.0 license](LICENSE).
## Changelog
-v1.0 was released on 31/03/2021.
+v1.0 was released on 07/04/2021.
## Benchmark and Model Zoo
-Please refer to [MODEL_ZOO.md](MODEL_ZOO.md) for more details.
+Please refer to [modelzoo.md](modelzoo.md) for more details.
## Installation
-Please refer to [INSTALL.md](docs/INSTALL.md) for installation.
+Please refer to [install.md](docs/install.md) for installation.
## Get Started
-Please see [GETTING_STARTED.md](docs/GETTING_STARTED.md) for the basic usage of MMOCR.
+Please see [getting_started.md](docs/getting_started.md) for the basic usage of MMOCR.
## Contributing
-We appreciate all contributions to improve MMOCR. Please refer to [CONTRIBUTING.md](docs/CONTRIBUTING.md) for the contributing guidelines.
+We appreciate all contributions to improve MMOCR. Please refer to [contributing.md](docs/contributing.md) for the contributing guidelines.
## Acknowledgement
diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py
new file mode 100644
index 00000000..949e800b
--- /dev/null
+++ b/configs/_base_/default_runtime.py
@@ -0,0 +1,14 @@
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=5,
+ hooks=[
+ dict(type='TextLoggerHook')
+
+ ])
+# yapf:enable
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
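
The runtime fragment above, together with the dataset and model fragments added later in this patch, is meant to be composed through MMCV's `_base_` config inheritance; this is the mechanism the README's "Modular Design" bullet points at. Below is a minimal sketch of a downstream experiment config, assuming the fragment paths added in this patch; the file location and the override value are illustrative, not files shipped by the patch.

```python
# configs/textrecog/crnn/crnn_toy_dataset.py  (hypothetical location, for illustration)
# Compose an experiment from the `_base_` fragments added in this patch.
_base_ = [
    '../../_base_/default_runtime.py',
    '../../_base_/recog_models/crnn.py',
    '../../_base_/recog_datasets/toy_dataset.py',
]

# Keys defined here override what is inherited from the `_base_` files,
# e.g. halving the batch size set in toy_dataset.py (samples_per_gpu=16).
data = dict(samples_per_gpu=8)
```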
diff --git a/configs/_base_/det_dataset/toy_dataset.py b/configs/_base_/det_dataset/toy_dataset.py
new file mode 100644
index 00000000..0b87d55f
--- /dev/null
+++ b/configs/_base_/det_dataset/toy_dataset.py
@@ -0,0 +1,97 @@
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+train_cfg = None
+test_cfg = None
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='LoadTextAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(
+ type='ScaleAspectJitter',
+ img_scale=[(3000, 640)],
+ ratio_range=(0.7, 1.3),
+ aspect_ratio_range=(0.9, 1.1),
+ multiscale_mode='value',
+ keep_ratio=False),
+    # shrink_ratio values go from large to small; the first one must be 1.0
+ dict(type='PANetTargets', shrink_ratio=(1.0, 0.7)),
+ dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
+ dict(type='RandomRotateTextDet'),
+ dict(
+ type='RandomCropInstances',
+ target_size=(640, 640),
+ instance_key='gt_kernels'),
+ dict(type='Pad', size_divisor=32),
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_kernels', 'gt_mask'],
+ visualize=dict(flag=False, boundary_key='gt_kernels')),
+ dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(3000, 640),
+ flip=False,
+ transforms=[
+ dict(type='Resize', img_scale=(3000, 640), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+
+dataset_type = 'TextDetDataset'
+img_prefix = 'tests/data/toy_dataset/imgs'
+train_anno_file = 'tests/data/toy_dataset/instances_test.txt'
+train1 = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=train_anno_file,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=4,
+ parser=dict(
+ type='LineJsonParser',
+ keys=['file_name', 'height', 'width', 'annotations'])),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+data_root = 'tests/data/toy_dataset'
+train2 = dict(
+ type='IcdarDataset',
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=train_pipeline)
+
+test_anno_file = 'tests/data/toy_dataset/instances_test.txt'
+test = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=test_anno_file,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineJsonParser',
+ keys=['file_name', 'height', 'width', 'annotations'])),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ train=dict(type='ConcatDataset', datasets=[train1, train2]),
+ val=dict(type='ConcatDataset', datasets=[test]),
+ test=dict(type='ConcatDataset', datasets=[test]))
+
+evaluation = dict(interval=1, metric='hmean-iou')
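
Configs such as the toy detection dataset above are plain Python modules, but they are normally read through MMCV's `Config` loader rather than imported directly. A small usage sketch, assuming `mmcv` is installed and the snippet is run from the repository root:

```python
from mmcv import Config

# Load the text-detection toy dataset config added in this patch.
cfg = Config.fromfile('configs/_base_/det_dataset/toy_dataset.py')

# Module-level variables become attributes on the Config object.
print(cfg.data.samples_per_gpu)   # 2
print(cfg.data.train.type)        # 'ConcatDataset'
print(cfg.evaluation.metric)      # 'hmean-iou'
print(len(cfg.train_pipeline))    # number of transforms in the training pipeline
```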
diff --git a/configs/_base_/models/ocr_mask_rcnn_r50_fpn_ohem.py b/configs/_base_/models/ocr_mask_rcnn_r50_fpn_ohem.py
new file mode 100644
index 00000000..348778d8
--- /dev/null
+++ b/configs/_base_/models/ocr_mask_rcnn_r50_fpn_ohem.py
@@ -0,0 +1,126 @@
+# model settings
+model = dict(
+ type='OCRMaskRCNN',
+ pretrained='torchvision://resnet50',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=1,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ style='pytorch'),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=5),
+ rpn_head=dict(
+ type='RPNHead',
+ in_channels=256,
+ feat_channels=256,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ scales=[4],
+ ratios=[0.17, 0.44, 1.13, 2.90, 7.46],
+ strides=[4, 8, 16, 32, 64]),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+ roi_head=dict(
+ type='StandardRoIHead',
+ bbox_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ bbox_head=dict(
+ type='Shared2FCBBoxHead',
+ in_channels=256,
+ fc_out_channels=1024,
+ roi_feat_size=7,
+ num_classes=1,
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[0., 0., 0., 0.],
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
+ reg_class_agnostic=False,
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+ mask_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ mask_head=dict(
+ type='FCNMaskHead',
+ num_convs=4,
+ in_channels=256,
+ conv_out_channels=256,
+ num_classes=1,
+ loss_mask=dict(
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
+
+ # model training and testing settings
+ train_cfg=dict(
+ rpn=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.7,
+ neg_iou_thr=0.3,
+ min_pos_iou=0.3,
+ match_low_quality=True,
+ ignore_iof_thr=-1,
+ gpu_assign_thr=50),
+ sampler=dict(
+ type='RandomSampler',
+ num=256,
+ pos_fraction=0.5,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=False),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ rpn_proposal=dict(
+ nms_across_levels=False,
+ nms_pre=2000,
+ nms_post=1000,
+ max_num=1000,
+ nms_thr=0.7,
+ min_bbox_size=0),
+ rcnn=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.5,
+ min_pos_iou=0.5,
+ match_low_quality=True,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='OHEMSampler',
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True),
+ mask_size=28,
+ pos_weight=-1,
+ debug=False)),
+ test_cfg=dict(
+ rpn=dict(
+ nms_across_levels=False,
+ nms_pre=1000,
+ nms_post=1000,
+ max_num=1000,
+ nms_thr=0.7,
+ min_bbox_size=0),
+ rcnn=dict(
+ score_thr=0.05,
+ nms=dict(type='nms', iou_threshold=0.5),
+ max_per_img=100,
+ mask_thr_binary=0.5)))
diff --git a/configs/_base_/models/ocr_mask_rcnn_r50_fpn_ohem_poly.py b/configs/_base_/models/ocr_mask_rcnn_r50_fpn_ohem_poly.py
new file mode 100644
index 00000000..3ef9a55e
--- /dev/null
+++ b/configs/_base_/models/ocr_mask_rcnn_r50_fpn_ohem_poly.py
@@ -0,0 +1,126 @@
+# model settings
+model = dict(
+ type='OCRMaskRCNN',
+ text_repr_type='poly',
+ pretrained='torchvision://resnet50',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=1,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ style='pytorch'),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=5),
+ rpn_head=dict(
+ type='RPNHead',
+ in_channels=256,
+ feat_channels=256,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ scales=[4],
+ ratios=[0.17, 0.44, 1.13, 2.90, 7.46],
+ strides=[4, 8, 16, 32, 64]),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+ roi_head=dict(
+ type='StandardRoIHead',
+ bbox_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(type='RoIAlign', output_size=7, sample_num=0),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ bbox_head=dict(
+ type='Shared2FCBBoxHead',
+ in_channels=256,
+ fc_out_channels=1024,
+ roi_feat_size=7,
+ num_classes=80,
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[0., 0., 0., 0.],
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
+ reg_class_agnostic=False,
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+ mask_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(type='RoIAlign', output_size=14, sample_num=0),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ mask_head=dict(
+ type='FCNMaskHead',
+ num_convs=4,
+ in_channels=256,
+ conv_out_channels=256,
+ num_classes=80,
+ loss_mask=dict(
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
+ # model training and testing settings
+ train_cfg=dict(
+ rpn=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.7,
+ neg_iou_thr=0.3,
+ min_pos_iou=0.3,
+ match_low_quality=True,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='RandomSampler',
+ num=256,
+ pos_fraction=0.5,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=False),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ rpn_proposal=dict(
+ nms_across_levels=False,
+ nms_pre=2000,
+ nms_post=1000,
+ max_num=1000,
+ nms_thr=0.7,
+ min_bbox_size=0),
+ rcnn=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.5,
+ min_pos_iou=0.5,
+ match_low_quality=True,
+ ignore_iof_thr=-1,
+ gpu_assign_thr=50),
+ sampler=dict(
+ type='OHEMSampler',
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True),
+ mask_size=28,
+ pos_weight=-1,
+ debug=False)),
+ test_cfg=dict(
+ rpn=dict(
+ nms_across_levels=False,
+ nms_pre=1000,
+ nms_post=1000,
+ max_num=1000,
+ nms_thr=0.7,
+ min_bbox_size=0),
+ rcnn=dict(
+ score_thr=0.05,
+ nms=dict(type='nms', iou_threshold=0.5),
+ max_per_img=100,
+ mask_thr_binary=0.5)))
diff --git a/configs/_base_/recog_datasets/seg_toy_dataset.py b/configs/_base_/recog_datasets/seg_toy_dataset.py
new file mode 100644
index 00000000..d6c49dbf
--- /dev/null
+++ b/configs/_base_/recog_datasets/seg_toy_dataset.py
@@ -0,0 +1,96 @@
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+gt_label_convertor = dict(
+ type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='RandomPaddingOCR',
+ max_ratio=[0.15, 0.2, 0.15, 0.2],
+ box_type='char_quads'),
+ dict(type='OpencvToPil'),
+ dict(
+ type='RandomRotateImageBox',
+ min_angle=-17,
+ max_angle=17,
+ box_type='char_quads'),
+ dict(type='PilToOpencv'),
+ dict(
+ type='ResizeOCR',
+ height=64,
+ min_width=64,
+ max_width=512,
+ keep_aspect_ratio=True),
+ dict(
+ type='OCRSegTargets',
+ label_convertor=gt_label_convertor,
+ box_type='char_quads'),
+ dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+ dict(type='ToTensorOCR'),
+ dict(type='FancyPCA'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_kernels'],
+ visualize=dict(flag=False, boundary_key=None),
+ call_super=False),
+ dict(
+ type='Collect',
+ keys=['img', 'gt_kernels'],
+ meta_keys=['filename', 'ori_shape', 'img_shape'])
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=64,
+ min_width=64,
+ max_width=None,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(type='CustomFormatBundle', call_super=False),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=['filename', 'ori_shape', 'img_shape'])
+]
+
+prefix = 'tests/data/ocr_char_ann_toy_dataset/'
+train = dict(
+ type='OCRSegDataset',
+ img_prefix=prefix + 'imgs',
+ ann_file=prefix + 'instances_train.txt',
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=100,
+ parser=dict(
+ type='LineJsonParser', keys=['file_name', 'annotations', 'text'])),
+ pipeline=train_pipeline,
+ test_mode=True)
+
+test = dict(
+ type='OCRDataset',
+ img_prefix=prefix + 'imgs',
+ ann_file=prefix + 'instances_test.txt',
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=1,
+ train=dict(type='ConcatDataset', datasets=[train]),
+ val=dict(type='ConcatDataset', datasets=[test]),
+ test=dict(type='ConcatDataset', datasets=[test]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/_base_/recog_datasets/toy_dataset.py b/configs/_base_/recog_datasets/toy_dataset.py
new file mode 100755
index 00000000..83848863
--- /dev/null
+++ b/configs/_base_/recog_datasets/toy_dataset.py
@@ -0,0 +1,99 @@
+img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=32,
+ min_width=32,
+ max_width=160,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+ ]),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiRotateAugOCR',
+ rotate_degrees=[0, 90, 270],
+ transforms=[
+ dict(
+ type='ResizeOCR',
+ height=32,
+ min_width=32,
+ max_width=160,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+ ]),
+ ])
+]
+
+dataset_type = 'OCRDataset'
+img_prefix = 'tests/data/ocr_toy_dataset/imgs'
+train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt'
+train1 = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=train_anno_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=100,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
+train2 = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=train_anno_file2,
+ loader=dict(
+ type='LmdbLoader',
+ repeat=100,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
+test = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=test_anno_file1,
+ loader=dict(
+ type='LmdbLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=2,
+ train=dict(type='ConcatDataset', datasets=[train1, train2]),
+ val=dict(type='ConcatDataset', datasets=[test]),
+ test=dict(type='ConcatDataset', datasets=[test]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/_base_/recog_models/crnn.py b/configs/_base_/recog_models/crnn.py
new file mode 100644
index 00000000..6b98c3d9
--- /dev/null
+++ b/configs/_base_/recog_models/crnn.py
@@ -0,0 +1,11 @@
+label_convertor = dict(
+ type='CTCConvertor', dict_type='DICT90', with_unknown=False)
+
+model = dict(
+ type='CRNNNet',
+ preprocessor=None,
+ backbone=dict(type='VeryDeepVgg', leakyRelu=False),
+ encoder=None,
+ decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
+ loss=dict(type='CTCLoss', flatten=False),
+ label_convertor=label_convertor)
diff --git a/configs/_base_/recog_models/nrtr.py b/configs/_base_/recog_models/nrtr.py
new file mode 100644
index 00000000..40657578
--- /dev/null
+++ b/configs/_base_/recog_models/nrtr.py
@@ -0,0 +1,11 @@
+label_convertor = dict(
+ type='AttnConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+
+model = dict(
+ type='NRTR',
+ backbone=dict(type='NRTRModalityTransform'),
+ encoder=dict(type='TFEncoder'),
+ decoder=dict(type='TFDecoder'),
+ loss=dict(type='TFLoss'),
+ label_convertor=label_convertor,
+ max_seq_len=40)
diff --git a/configs/_base_/recog_models/robust_scanner.py b/configs/_base_/recog_models/robust_scanner.py
new file mode 100644
index 00000000..4cc2fa10
--- /dev/null
+++ b/configs/_base_/recog_models/robust_scanner.py
@@ -0,0 +1,24 @@
+label_convertor = dict(
+ type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+
+hybrid_decoder = dict(type='SequenceAttentionDecoder')
+
+position_decoder = dict(type='PositionAttentionDecoder')
+
+model = dict(
+ type='RobustScanner',
+ backbone=dict(type='ResNet31OCR'),
+ encoder=dict(
+ type='ChannelReductionEncoder',
+ in_channels=512,
+ out_channels=128,
+ ),
+ decoder=dict(
+ type='RobustScannerDecoder',
+ dim_input=512,
+ dim_model=128,
+ hybrid_decoder=hybrid_decoder,
+ position_decoder=position_decoder),
+ loss=dict(type='SARLoss'),
+ label_convertor=label_convertor,
+ max_seq_len=30)
diff --git a/configs/_base_/recog_models/sar.py b/configs/_base_/recog_models/sar.py
new file mode 100755
index 00000000..8438d9b9
--- /dev/null
+++ b/configs/_base_/recog_models/sar.py
@@ -0,0 +1,24 @@
+label_convertor = dict(
+ type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+
+model = dict(
+ type='SARNet',
+ backbone=dict(type='ResNet31OCR'),
+ encoder=dict(
+ type='SAREncoder',
+ enc_bi_rnn=False,
+ enc_do_rnn=0.1,
+ enc_gru=False,
+ ),
+ decoder=dict(
+ type='ParallelSARDecoder',
+ enc_bi_rnn=False,
+ dec_bi_rnn=False,
+ dec_do_rnn=0,
+ dec_gru=False,
+ pred_dropout=0.1,
+ d_k=512,
+ pred_concat=True),
+ loss=dict(type='SARLoss'),
+ label_convertor=label_convertor,
+ max_seq_len=30)
diff --git a/configs/_base_/recog_models/transformer.py b/configs/_base_/recog_models/transformer.py
new file mode 100644
index 00000000..476643fa
--- /dev/null
+++ b/configs/_base_/recog_models/transformer.py
@@ -0,0 +1,11 @@
+label_convertor = dict(
+ type='AttnConvertor', dict_type='DICT90', with_unknown=False)
+
+model = dict(
+ type='TransformerNet',
+ backbone=dict(type='ResNet31OCR'),
+ encoder=dict(type='TFEncoder'),
+ decoder=dict(type='TFDecoder'),
+ loss=dict(type='TFLoss'),
+ label_convertor=label_convertor,
+ max_seq_len=40)
diff --git a/configs/_base_/runtime_10e.py b/configs/_base_/runtime_10e.py
new file mode 100644
index 00000000..393c5f14
--- /dev/null
+++ b/configs/_base_/runtime_10e.py
@@ -0,0 +1,14 @@
+checkpoint_config = dict(interval=10)
+# yapf:disable
+log_config = dict(
+ interval=5,
+ hooks=[
+ dict(type='TextLoggerHook')
+ # dict(type='TensorboardLoggerHook')
+ ])
+# yapf:enable
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/_base_/schedules/schedule_1200e.py b/configs/_base_/schedules/schedule_1200e.py
new file mode 100644
index 00000000..31e00920
--- /dev/null
+++ b/configs/_base_/schedules/schedule_1200e.py
@@ -0,0 +1,6 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.007, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-7, by_epoch=True)
+total_epochs = 1200
diff --git a/configs/_base_/schedules/schedule_160e.py b/configs/_base_/schedules/schedule_160e.py
new file mode 100644
index 00000000..0958701a
--- /dev/null
+++ b/configs/_base_/schedules/schedule_160e.py
@@ -0,0 +1,11 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.08, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=0.001,
+ step=[80, 128])
+total_epochs = 160
diff --git a/configs/_base_/schedules/schedule_1x.py b/configs/_base_/schedules/schedule_1x.py
new file mode 100644
index 00000000..12694c87
--- /dev/null
+++ b/configs/_base_/schedules/schedule_1x.py
@@ -0,0 +1,11 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=0.001,
+ step=[8, 11])
+total_epochs = 12
diff --git a/configs/_base_/schedules/schedule_20e.py b/configs/_base_/schedules/schedule_20e.py
new file mode 100644
index 00000000..e6ca2b24
--- /dev/null
+++ b/configs/_base_/schedules/schedule_20e.py
@@ -0,0 +1,4 @@
+_base_ = './schedule_1x.py'
+# learning policy
+lr_config = dict(step=[16, 19])
+total_epochs = 20
diff --git a/configs/_base_/schedules/schedule_2e.py b/configs/_base_/schedules/schedule_2e.py
new file mode 100644
index 00000000..110ef8c5
--- /dev/null
+++ b/configs/_base_/schedules/schedule_2e.py
@@ -0,0 +1,6 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.007, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=True)
+total_epochs = 2
diff --git a/configs/_base_/schedules/schedule_2x.py b/configs/_base_/schedules/schedule_2x.py
new file mode 100644
index 00000000..72b4135c
--- /dev/null
+++ b/configs/_base_/schedules/schedule_2x.py
@@ -0,0 +1,4 @@
+_base_ = './schedule_1x.py'
+# learning policy
+lr_config = dict(step=[16, 22])
+total_epochs = 24
diff --git a/configs/_base_/schedules/schedule_adadelta_16e.py b/configs/_base_/schedules/schedule_adadelta_16e.py
new file mode 100644
index 00000000..f8cc8b9b
--- /dev/null
+++ b/configs/_base_/schedules/schedule_adadelta_16e.py
@@ -0,0 +1,6 @@
+# optimizer
+optimizer = dict(type='Adadelta', lr=1.0)
+optimizer_config = dict(grad_clip=dict(max_norm=0.5))
+# learning policy
+lr_config = dict(policy='step', step=[8, 10, 12])
+total_epochs = 16
diff --git a/configs/_base_/schedules/schedule_adadelta_8e.py b/configs/_base_/schedules/schedule_adadelta_8e.py
new file mode 100644
index 00000000..2b4a8444
--- /dev/null
+++ b/configs/_base_/schedules/schedule_adadelta_8e.py
@@ -0,0 +1,6 @@
+# optimizer
+optimizer = dict(type='Adadelta', lr=1.0)
+optimizer_config = dict(grad_clip=dict(max_norm=0.5))
+# learning policy
+lr_config = dict(policy='step', step=[4, 6, 7])
+total_epochs = 8
diff --git a/configs/_base_/schedules/schedule_adam_1e.py b/configs/_base_/schedules/schedule_adam_1e.py
new file mode 100644
index 00000000..b4b13379
--- /dev/null
+++ b/configs/_base_/schedules/schedule_adam_1e.py
@@ -0,0 +1,6 @@
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='poly', power=0.9)
+total_epochs = 1
diff --git a/configs/_base_/schedules/schedule_adam_600e.py b/configs/_base_/schedules/schedule_adam_600e.py
new file mode 100644
index 00000000..e946603e
--- /dev/null
+++ b/configs/_base_/schedules/schedule_adam_600e.py
@@ -0,0 +1,6 @@
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='poly', power=0.9)
+total_epochs = 600
diff --git a/configs/_base_/schedules/schedule_sgd_600e.py b/configs/_base_/schedules/schedule_sgd_600e.py
new file mode 100644
index 00000000..9a605291
--- /dev/null
+++ b/configs/_base_/schedules/schedule_sgd_600e.py
@@ -0,0 +1,6 @@
+# optimizer
+optimizer = dict(type='SGD', lr=1e-3, momentum=0.99, weight_decay=5e-4)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[200, 400])
+total_epochs = 600
diff --git a/configs/kie/sdmgr/README.md b/configs/kie/sdmgr/README.md
new file mode 100644
index 00000000..d0de5641
--- /dev/null
+++ b/configs/kie/sdmgr/README.md
@@ -0,0 +1,25 @@
+# Spatial Dual-Modality Graph Reasoning for Key Information Extraction
+
+## Introduction
+
+[ALGORITHM]
+
+```bibtex
+@misc{sun2021spatial,
+ title={Spatial Dual-Modality Graph Reasoning for Key Information Extraction},
+ author={Hongbin Sun and Zhanghui Kuang and Xiaoyu Yue and Chenhao Lin and Wayne Zhang},
+ year={2021},
+ eprint={2103.14470},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
+
+## Results and models
+
+### WildReceipt
+
+| Method | Modality | Macro F1-Score | Download |
+| :--------------------------------------------------------------------: | :--------------: | :------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| [sdmgr_unet16](/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py) | Visual + Textual | 0.876 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/sdmgr_unet16_60e_wildreceipt_20210405-16a47642.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/20210405_104508.log.json) |
+| [sdmgr_novisual](/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py) | Textual | 0.864 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/sdmgr_novisual_60e_wildreceipt_20210405-07bc26ad.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/20210405_141138.log.json) |
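+
+These configs are plain mmcv config files, so they can also be loaded and adjusted programmatically. A minimal sketch (assuming `mmcv` is installed and the checkpoint above has been downloaded to a local path of your choosing):
+
+```python
+from mmcv import Config
+
+# Load the full-modality SDMGR config shipped in this directory.
+cfg = Config.fromfile('configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py')
+print(cfg.model.type, cfg.total_epochs)  # SDMGR 60
+
+# Hypothetical local checkpoint path; start from the released weights.
+cfg.load_from = 'checkpoints/sdmgr_unet16_60e_wildreceipt_20210405-16a47642.pth'
+cfg.data.samples_per_gpu = 2  # shrink the batch for a smaller GPU
+```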
diff --git a/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py b/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py
new file mode 100644
index 00000000..6568909e
--- /dev/null
+++ b/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py
@@ -0,0 +1,93 @@
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+max_scale, min_scale = 1024, 512
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='KIEFormatBundle'),
+ dict(
+ type='Collect',
+ keys=['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='KIEFormatBundle'),
+ dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'])
+]
+
+dataset_type = 'KIEDataset'
+data_root = 'data/wildreceipt'
+
+loader = dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineJsonParser',
+ keys=['file_name', 'height', 'width', 'annotations']))
+
+train = dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/train.txt',
+ pipeline=train_pipeline,
+ img_prefix=data_root,
+ loader=loader,
+ dict_file=f'{data_root}/dict.txt',
+ test_mode=False)
+test = dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/test.txt',
+ pipeline=test_pipeline,
+ img_prefix=data_root,
+ loader=loader,
+ dict_file=f'{data_root}/dict.txt',
+ test_mode=True)
+
+data = dict(
+ samples_per_gpu=4, workers_per_gpu=0, train=train, val=test, test=test)
+
+evaluation = dict(
+ interval=1,
+ metric='macro_f1',
+ metric_options=dict(
+ macro_f1=dict(
+ ignores=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25])))
+
+model = dict(
+ type='SDMGR',
+ backbone=dict(type='UNet', base_channels=16),
+ bbox_head=dict(
+ type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26),
+ visual_modality=False,
+ train_cfg=None,
+ test_cfg=None,
+ class_list=f'{data_root}/class_list.txt')
+
+optimizer = dict(type='Adam', weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1,
+ warmup_ratio=1,
+ step=[40, 50])
+total_epochs = 60
+
+checkpoint_config = dict(interval=1)
+log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+
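+# Allow DDP to handle parameters that receive no gradients (e.g. the unused
+# visual branch, since visual_modality=False in this config).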
+find_unused_parameters = True
diff --git a/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py b/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py
new file mode 100644
index 00000000..afc38393
--- /dev/null
+++ b/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py
@@ -0,0 +1,93 @@
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+max_scale, min_scale = 1024, 512
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='KIEFormatBundle'),
+ dict(
+ type='Collect',
+ keys=['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations'),
+ dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='KIEFormatBundle'),
+ dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'])
+]
+
+dataset_type = 'KIEDataset'
+data_root = 'data/wildreceipt'
+
+loader = dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineJsonParser',
+ keys=['file_name', 'height', 'width', 'annotations']))
+
+train = dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/train.txt',
+ pipeline=train_pipeline,
+ img_prefix=data_root,
+ loader=loader,
+ dict_file=f'{data_root}/dict.txt',
+ test_mode=False)
+test = dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/test.txt',
+ pipeline=test_pipeline,
+ img_prefix=data_root,
+ loader=loader,
+ dict_file=f'{data_root}/dict.txt',
+ test_mode=True)
+
+data = dict(
+ samples_per_gpu=4, workers_per_gpu=0, train=train, val=test, test=test)
+
+evaluation = dict(
+ interval=1,
+ metric='macro_f1',
+ metric_options=dict(
+ macro_f1=dict(
+ ignores=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25])))
+
+model = dict(
+ type='SDMGR',
+ backbone=dict(type='UNet', base_channels=16),
+ bbox_head=dict(
+ type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26),
+ visual_modality=True,
+ train_cfg=None,
+ test_cfg=None,
+ class_list=f'{data_root}/class_list.txt')
+
+optimizer = dict(type='Adam', weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=1,
+ warmup_ratio=1,
+ step=[40, 50])
+total_epochs = 60
+
+checkpoint_config = dict(interval=1)
+log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+
+find_unused_parameters = True
diff --git a/configs/textdet/dbnet/README.md b/configs/textdet/dbnet/README.md
new file mode 100644
index 00000000..c8b6a094
--- /dev/null
+++ b/configs/textdet/dbnet/README.md
@@ -0,0 +1,28 @@
+# Real-time Scene Text Detection with Differentiable Binarization
+
+## Introduction
+
+[ALGORITHM]
+
+```bibtex
+@article{Liao_Wan_Yao_Chen_Bai_2020,
+ title={Real-Time Scene Text Detection with Differentiable Binarization},
+ journal={Proceedings of the AAAI Conference on Artificial Intelligence},
+ author={Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang},
+ year={2020},
+ pages={11474-11481}}
+```
+
+## Results and models
+
+### ICDAR2015
+
+| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| [DBNet](/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py) | ImageNet | ICDAR2015 Train | ICDAR2015 Test | 1200 | 736 | 0.731 | 0.871 | 0.795 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.log.json) |
+| [DBNet](/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py) | [Synthtext](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_2e_synthtext_20210325-aa96e477.pth) | ICDAR2015 Train | ICDAR2015 Test | 1200 | 1024 | 0.796 | 0.866 | 0.830 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth.log.json) |
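+
+Both configs pull their schedule and runtime settings from `../../_base_/schedules/schedule_1200e.py` and `../../_base_/runtime_10e.py` through `_base_`; mmcv merges those files into the final config at load time. A small sketch of inspecting the merged result (paths assume the repository root as working directory):
+
+```python
+from mmcv import Config
+
+cfg = Config.fromfile('configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py')
+
+# Inherited from the _base_ files:
+print(cfg.total_epochs)      # 1200, from schedule_1200e.py
+print(cfg.lr_config.policy)  # 'poly', from schedule_1200e.py
+print(cfg.log_level)         # 'INFO', from runtime_10e.py
+
+# Defined in this config itself:
+print(cfg.model.backbone.depth)  # 18
+```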
diff --git a/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py b/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py
new file mode 100644
index 00000000..790355fc
--- /dev/null
+++ b/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py
@@ -0,0 +1,96 @@
+_base_ = [
+ '../../_base_/schedules/schedule_1200e.py', '../../_base_/runtime_10e.py'
+]
+model = dict(
+ type='DBNet',
+ pretrained='torchvision://resnet18',
+ backbone=dict(
+ type='ResNet',
+ depth=18,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ style='caffe'),
+ neck=dict(
+ type='FPNC', in_channels=[64, 128, 256, 512], lateral_channels=256),
+ bbox_head=dict(
+ type='DBHead',
+ text_repr_type='quad',
+ in_channels=256,
+ loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True)),
+ train_cfg=None,
+ test_cfg=None)
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/icdar2015/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# To visualize images, uncomment the img_norm_cfg below.
+# img_norm_cfg = dict(mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='LoadTextAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ # img aug
+ dict(
+ type='ImgAug',
+ args=[['Fliplr', 0.5],
+ dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]),
+ # random crop
+ dict(type='EastRandomCrop', target_size=(640, 640)),
+ dict(type='DBNetTargets', shrink_ratio=0.4),
+ dict(type='Pad', size_divisor=32),
+    # To visualize images and ground truths, set flag=True in the visualize dict below.
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'],
+ visualize=dict(flag=False, boundary_key='gt_shrink')),
+ dict(
+ type='Collect',
+ keys=['img', 'gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 736),
+ flip=False,
+ transforms=[
+ dict(type='Resize', img_scale=(2944, 736), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=8,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_training.json',
+        # for debugging, uncomment select_first_k to use only the first k images
+ # select_first_k=200,
+ img_prefix=data_root + '/imgs',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ # select_first_k=100,
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ # select_first_k=100,
+ pipeline=test_pipeline))
+evaluation = dict(interval=100, metric='hmean-iou')
diff --git a/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py b/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py
new file mode 100644
index 00000000..f1ccec51
--- /dev/null
+++ b/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py
@@ -0,0 +1,105 @@
+_base_ = [
+ '../../_base_/schedules/schedule_1200e.py', '../../_base_/runtime_10e.py'
+]
+load_from = 'checkpoints/textdet/dbnet/res50dcnv2_synthtext.pth'
+
+model = dict(
+ type='DBNet',
+ pretrained='torchvision://resnet50',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ style='caffe',
+ dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
+ stage_with_dcn=(False, True, True, True)),
+ neck=dict(
+ type='FPNC', in_channels=[256, 512, 1024, 2048], lateral_channels=256),
+ bbox_head=dict(
+ type='DBHead',
+ text_repr_type='quad',
+ in_channels=256,
+ loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True)),
+ train_cfg=None,
+ test_cfg=None)
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/icdar2015/'
+# img_norm_cfg = dict(
+# mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# normalization values below are taken from the official DBNet code
+img_norm_cfg = dict(
+ mean=[122.67891434, 116.66876762, 104.00698793],
+ std=[255, 255, 255],
+ to_rgb=False)
+# To visualize images, uncomment the img_norm_cfg below.
+# img_norm_cfg = dict(mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='LoadTextAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ # img aug
+ dict(
+ type='ImgAug',
+ args=[['Fliplr', 0.5],
+ dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]),
+ # random crop
+ dict(type='EastRandomCrop', target_size=(640, 640)),
+ dict(type='DBNetTargets', shrink_ratio=0.4),
+ dict(type='Pad', size_divisor=32),
+    # To visualize images and ground truths, set flag=True in the visualize dict below.
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'],
+ visualize=dict(flag=False, boundary_key='gt_shrink')),
+ dict(
+ type='Collect',
+ keys=['img', 'gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(4068, 1024),
+ flip=False,
+ transforms=[
+ dict(type='Resize', img_scale=(4068, 1024), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_training.json',
+        # for debugging, uncomment select_first_k to use only the first k images
+ # select_first_k=200,
+ img_prefix=data_root + '/imgs',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ # select_first_k=100,
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ # select_first_k=100,
+ pipeline=test_pipeline))
+evaluation = dict(interval=100, metric='hmean-iou')
diff --git a/configs/textdet/maskrcnn/README.md b/configs/textdet/maskrcnn/README.md
new file mode 100644
index 00000000..520a9606
--- /dev/null
+++ b/configs/textdet/maskrcnn/README.md
@@ -0,0 +1,35 @@
+# Mask R-CNN
+
+## Introduction
+
+[ALGORITHM]
+
+```bibtex
+@article{pmtd,
+ author={Jingchao Liu and Xuebo Liu and Jie Sheng and Ding Liang and Xin Li and Qingjie Liu},
+ title={Pyramid Mask Text Detector},
+ journal={CoRR},
+ volume={abs/1903.11800},
+ year={2019}
+}
+```
+
+## Results and models
+
+### CTW1500
+
+| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
+| :---------------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| [MaskRCNN](/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 160 | 1600 | 0.753 | 0.712 | 0.732 | [model](https://download.openmmlab.com/mmocr/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.log.json) |
+
+### ICDAR2015
+
+| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
+| :-----------------------------------------------------------------------: | :--------------: | :-------------: | :------------: | :-----: | :-------: | :----: | :-------: | :---: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| [MaskRCNN](/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py) | ImageNet | ICDAR2015 Train | ICDAR2015 Test | 160 | 1920 | 0.783 | 0.872 | 0.825 | [model](https://download.openmmlab.com/mmocr/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.log.json) |
+
+### ICDAR2017
+
+| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
+| :-----------------------------------------------------------------------: | :--------------: | :-------------: | :-----------: | :-----: | :-------: | :----: | :-------: | :---: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| [MaskRCNN](/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py) | ImageNet | ICDAR2017 Train | ICDAR2017 Val | 160 | 1600 | 0.754 | 0.827 | 0.789 | [model](https://download.openmmlab.com/mmocr/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.log.json) |
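+
+The three configs share the same training pipeline and differ mainly in the `_base_` model file, dataset paths, and test scale. As a sketch (assuming `mmcv` is available), individual fields can be overridden without editing the files, using dotted keys for nested options:
+
+```python
+from mmcv import Config
+
+cfg = Config.fromfile('configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py')
+
+# Override nested keys with dotted names; the values below are only examples.
+cfg.merge_from_dict({
+    'data.samples_per_gpu': 4,
+    'evaluation.interval': 20,
+})
+print(cfg.data.samples_per_gpu, cfg.evaluation.interval)  # 4 20
+```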
diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py
new file mode 100644
index 00000000..d821bb7f
--- /dev/null
+++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py
@@ -0,0 +1,67 @@
+_base_ = [
+ '../../_base_/models/ocr_mask_rcnn_r50_fpn_ohem_poly.py',
+ '../../_base_/schedules/schedule_160e.py', '../../_base_/runtime_10e.py'
+]
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/ctw1500/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# img_norm_cfg = dict(mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+ dict(
+ type='ScaleAspectJitter',
+ img_scale=None,
+ keep_ratio=False,
+ resize_type='indep_sample_in_range',
+ scale_range=(640, 2560)),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(
+ type='RandomCropInstances',
+ target_size=(640, 640),
+ mask_type='union_all',
+ instance_key='gt_masks'),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+        # resize the longer side to 1600
+ img_scale=(1600, 1600),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ # no flip
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_training.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ # select_first_k=1,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ # select_first_k=1,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline))
+evaluation = dict(interval=10, metric='hmean-iou')
diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py
new file mode 100644
index 00000000..e2d8be68
--- /dev/null
+++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py
@@ -0,0 +1,66 @@
+_base_ = [
+ '../../_base_/models/ocr_mask_rcnn_r50_fpn_ohem.py',
+ '../../_base_/schedules/schedule_160e.py', '../../_base_/runtime_10e.py'
+]
+dataset_type = 'IcdarDataset'
+data_root = 'data/icdar2015/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# img_norm_cfg = dict(mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+ dict(
+ type='ScaleAspectJitter',
+ img_scale=None,
+ keep_ratio=False,
+ resize_type='indep_sample_in_range',
+ scale_range=(640, 2560)),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(
+ type='RandomCropInstances',
+ target_size=(640, 640),
+ mask_type='union_all',
+ instance_key='gt_masks'),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+        # resize the longer side to 1920
+ img_scale=(1920, 1920),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ # no flip
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_training.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ # select_first_k=1,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ # select_first_k=1,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline))
+evaluation = dict(interval=10, metric='hmean-iou')
diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py
new file mode 100644
index 00000000..8b948e09
--- /dev/null
+++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py
@@ -0,0 +1,67 @@
+_base_ = [
+ '../../_base_/models/ocr_mask_rcnn_r50_fpn_ohem.py',
+ '../../_base_/schedules/schedule_160e.py', '../../_base_/runtime_10e.py'
+]
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/icdar2017/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# img_norm_cfg = dict(mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+ dict(
+ type='ScaleAspectJitter',
+ img_scale=None,
+ keep_ratio=False,
+ resize_type='indep_sample_in_range',
+ scale_range=(640, 2560)),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(
+ type='RandomCropInstances',
+ target_size=(640, 640),
+ mask_type='union_all',
+ instance_key='gt_masks'),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+        # resize the longer side to 1600
+ img_scale=(1600, 1600),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ # no flip
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_training.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ # select_first_k=1,
+ ann_file=data_root + '/instances_val.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ # select_first_k=1,
+ ann_file=data_root + '/instances_val.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline))
+evaluation = dict(interval=10, metric='hmean-iou')
diff --git a/configs/textdet/panet/README.md b/configs/textdet/panet/README.md
new file mode 100644
index 00000000..b4290bf6
--- /dev/null
+++ b/configs/textdet/panet/README.md
@@ -0,0 +1,35 @@
+# Efficient and Accurate Arbitrary-Shaped Text Detection with Pixel Aggregation Network
+
+## Introduction
+
+[ALGORITHM]
+
+```bibtex
+@inproceedings{WangXSZWLYS19,
+ author={Wenhai Wang and Enze Xie and Xiaoge Song and Yuhang Zang and Wenjia Wang and Tong Lu and Gang Yu and Chunhua Shen},
+ title={Efficient and Accurate Arbitrary-Shaped Text Detection With Pixel Aggregation Network},
+ booktitle={ICCV},
+ pages={8439--8448},
+ year={2019}
+ }
+```
+
+## Results and models
+
+### CTW1500
+
+| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
+| :----------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| [PANet](/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 600 | 640 | 0.790 | 0.838 | 0.813 | [model](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.log.json) |
+
+### ICDAR2015
+
+| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
+| :------------------------------------------------------------------: | :--------------: | :-------------: | :------------: | :-----: | :-------: | :----: | :-------: | :---: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| [PANet](/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py) | ImageNet | ICDAR2015 Train | ICDAR2015 Test | 600 | 736 | 0.734 | 0.856 | 0.791 | [model](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.log.json) |
+
+### ICDAR2017
+
+| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
+| :------------------------------------------------------------------: | :--------------: | :-------------: | :-----------: | :-----: | :-------: | :----: | :-------: | :---: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| [PANet](/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py) | ImageNet | ICDAR2017 Train | ICDAR2017 Val | 600 | 800 | 0.604 | 0.812 | 0.693 | [model](https://download.openmmlab.com/mmocr/textdet/panet/panet_r50_fpem_ffm_sbn_600e_icdar2017_20210219-b4877a4f.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/panet/panet_r50_fpem_ffm_sbn_600e_icdar2017_20210219-b4877a4f.log.json) |
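+
+The three configs differ mainly in backbone depth, crop/test sizes, and dataset paths. One way to derive a new variant is to load an existing config, adjust it, and write out the fully merged result; a rough sketch (the output filename is arbitrary):
+
+```python
+from mmcv import Config
+
+cfg = Config.fromfile('configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py')
+
+# Derive a shorter schedule; note that dataset paths were expanded at parse
+# time, so they must be overridden explicitly if they need to change.
+cfg.total_epochs = 100
+cfg.evaluation = dict(interval=5, metric='hmean-iou')
+
+# Dump the merged config (with _base_ files resolved) to a standalone file.
+with open('panet_r18_fpem_ffm_100e_ctw1500.py', 'w') as f:
+    f.write(cfg.pretty_text)
+```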
diff --git a/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py b/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py
new file mode 100644
index 00000000..58d1d22b
--- /dev/null
+++ b/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py
@@ -0,0 +1,104 @@
+_base_ = [
+ '../../_base_/schedules/schedule_adam_600e.py',
+ '../../_base_/runtime_10e.py'
+]
+model = dict(
+ type='PANet',
+ pretrained='torchvision://resnet18',
+ backbone=dict(
+ type='ResNet',
+ depth=18,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
+ norm_eval=True,
+ style='caffe'),
+ neck=dict(type='FPEM_FFM', in_channels=[64, 128, 256, 512]),
+ bbox_head=dict(
+ type='PANHead',
+ text_repr_type='poly',
+ in_channels=[128, 128, 128, 128],
+ out_channels=6,
+ loss=dict(type='PANLoss')),
+ train_cfg=None,
+ test_cfg=None)
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/ctw1500/'
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# To visualize images, uncomment the img_norm_cfg below.
+# img_norm_cfg = dict(
+# mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='LoadTextAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(
+ type='ScaleAspectJitter',
+ img_scale=[(3000, 640)],
+ ratio_range=(0.7, 1.3),
+ aspect_ratio_range=(0.9, 1.1),
+ multiscale_mode='value',
+ keep_ratio=False),
+    # shrink_ratio goes from large to small; the first value must be 1.0
+ dict(type='PANetTargets', shrink_ratio=(1.0, 0.7)),
+ dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
+ dict(type='RandomRotateTextDet'),
+ dict(
+ type='RandomCropInstances',
+ target_size=(640, 640),
+ instance_key='gt_kernels'),
+ dict(type='Pad', size_divisor=32),
+    # To visualize images and ground truths, set flag=True in the visualize dict below.
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_kernels', 'gt_mask'],
+ visualize=dict(flag=False, boundary_key='gt_kernels')),
+ dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(3000, 640),
+ flip=False,
+ transforms=[
+ dict(type='Resize', img_scale=(3000, 640), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_training.json',
+        # for debugging, uncomment select_first_k to use only the first k images
+ # select_first_k=200,
+ img_prefix=data_root + '/imgs',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ # select_first_k=100,
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ # select_first_k=100,
+ pipeline=test_pipeline))
+evaluation = dict(interval=10, metric='hmean-iou')
diff --git a/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py b/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py
new file mode 100644
index 00000000..de12c26c
--- /dev/null
+++ b/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py
@@ -0,0 +1,102 @@
+_base_ = [
+ '../../_base_/schedules/schedule_adam_600e.py',
+ '../../_base_/runtime_10e.py'
+]
+model = dict(
+ type='PANet',
+ pretrained='torchvision://resnet18',
+ backbone=dict(
+ type='ResNet',
+ depth=18,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
+ norm_eval=True,
+ style='caffe'),
+ neck=dict(type='FPEM_FFM', in_channels=[64, 128, 256, 512]),
+ bbox_head=dict(
+ type='PANHead',
+ text_repr_type='quad',
+ in_channels=[128, 128, 128, 128],
+ out_channels=6,
+ loss=dict(type='PANLoss')),
+ train_cfg=None,
+ test_cfg=None)
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/icdar2015/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# To visualize images, uncomment the img_norm_cfg below.
+# img_norm_cfg = dict(
+# mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='LoadTextAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(
+ type='ScaleAspectJitter',
+ img_scale=[(3000, 736)],
+ ratio_range=(0.7, 1.3),
+ aspect_ratio_range=(0.9, 1.1),
+ multiscale_mode='value',
+ keep_ratio=False),
+ dict(type='PANetTargets', shrink_ratio=(1.0, 0.5), max_shrink=20),
+ dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
+ dict(type='RandomRotateTextDet'),
+ dict(
+ type='RandomCropInstances',
+ target_size=(736, 736),
+ instance_key='gt_kernels'),
+ dict(type='Pad', size_divisor=32),
+    # To visualize images and ground truths, set flag=True in the visualize dict below.
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_kernels', 'gt_mask'],
+ visualize=dict(flag=False, boundary_key='gt_kernels')),
+ dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 736),
+ flip=False,
+ transforms=[
+ dict(type='Resize', img_scale=(1333, 736), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=2,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_training.json',
+        # for debugging, uncomment select_first_k to use only the first k images
+ # select_first_k=200,
+ img_prefix=data_root + '/imgs',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ # select_first_k=100,
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ # select_first_k=100,
+ pipeline=test_pipeline))
+evaluation = dict(interval=10, metric='hmean-iou')
diff --git a/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py b/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py
new file mode 100644
index 00000000..933f0bae
--- /dev/null
+++ b/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py
@@ -0,0 +1,93 @@
+_base_ = [
+ '../../_base_/schedules/schedule_adam_600e.py',
+ '../../_base_/runtime_10e.py'
+]
+model = dict(
+ type='PANet',
+ pretrained='torchvision://resnet50',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=1,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ style='caffe'),
+ neck=dict(type='FPEM_FFM', in_channels=[256, 512, 1024, 2048]),
+ bbox_head=dict(
+ type='PANHead',
+ in_channels=[128, 128, 128, 128],
+ out_channels=6,
+ loss=dict(type='PANLoss', speedup_bbox_thr=32)),
+ train_cfg=None,
+ test_cfg=None)
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/icdar2017/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='LoadTextAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(
+ type='ScaleAspectJitter',
+ img_scale=[(3000, 800)],
+ ratio_range=(0.7, 1.3),
+ aspect_ratio_range=(0.9, 1.1),
+ multiscale_mode='value',
+ keep_ratio=False),
+ dict(type='PANetTargets', shrink_ratio=(1.0, 0.5)),
+ dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
+ dict(type='RandomRotateTextDet'),
+ dict(
+ type='RandomCropInstances',
+ target_size=(800, 800),
+ instance_key='gt_kernels'),
+ dict(type='Pad', size_divisor=32),
+    # To visualize images and ground truths, set flag=True in the visualize dict below.
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_kernels', 'gt_mask'],
+ visualize=dict(flag=False, boundary_key='gt_kernels')),
+ dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_training.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_val.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_val.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline))
+evaluation = dict(interval=10, metric='hmean-iou')
diff --git a/configs/textdet/psenet/README.md b/configs/textdet/psenet/README.md
new file mode 100644
index 00000000..9995b4d8
--- /dev/null
+++ b/configs/textdet/psenet/README.md
@@ -0,0 +1,29 @@
+# PSENet
+
+## Introduction
+
+[ALGORITHM]
+
+```bibtex
+@article{li2018shape,
+ title={Shape robust text detection with progressive scale expansion network},
+ author={Li, Xiang and Wang, Wenhai and Hou, Wenbo and Liu, Ruo-Ze and Lu, Tong and Yang, Jian},
+ journal={arXiv preprint arXiv:1806.02559},
+ year={2018}
+}
+```
+
+## Results and models
+
+### CTW1500
+
+| Method | Backbone | Extra Data | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
+| :------------------------------------------------------------------: | :------: | :--------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| [PSENet-4s](/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py) | ResNet50 | - | CTW1500 Train | CTW1500 Test | 600 | 1280 | 0.728 | 0.849 | 0.784 | [model](https://download.openmmlab.com/mmocr/textdet/psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/psenet/20210401_215421.log.json) |
+
+### ICDAR2015
+
+| Method | Backbone | Extra Data | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
+| :--------------------------------------------------------------------: | :------: | :---------------------------------------------------------------------------------------------------------------------------------------: | :----------: | :-------: | :-----: | :-------: | :----: | :-------: | :---: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| [PSENet-4s](/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py) | ResNet50 | - | IC15 Train | IC15 Test | 600 | 2240 | 0.784 | 0.831 | 0.807 | [model](https://download.openmmlab.com/mmocr/textdet/psenet/psenet_r50_fpnf_600e_icdar2015-c6131f0d.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/psenet/20210331_214145.log.json) |
+| [PSENet-4s](/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py) | ResNet50 | pretrained on IC17 MLT ([model](https://download.openmmlab.com/mmocr/textdet/psenet/psenet_r50_fpnf_600e_icdar2017_as_pretrain-3bd6056c.pth)) | IC15 Train | IC15 Test | 600 | 2240 | 0.834 | 0.861 | 0.847 | [model](https://download.openmmlab.com/mmocr/textdet/psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth) \| [log]() |
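+
+The second ICDAR2015 row is obtained by fine-tuning from the IC17 MLT checkpoint linked in the table rather than training from scratch. A sketch of wiring that up through the config (the local checkpoint path is only an example):
+
+```python
+from mmcv import Config
+
+cfg = Config.fromfile('configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py')
+
+# Initialise from the IC17 MLT pretrained weights downloaded from the table above.
+cfg.load_from = 'checkpoints/psenet_r50_fpnf_600e_icdar2017_as_pretrain-3bd6056c.pth'
+print(cfg.total_epochs, cfg.evaluation.metric)  # 600 hmean-iou
+```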
diff --git a/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py b/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py
new file mode 100644
index 00000000..e1f5fd04
--- /dev/null
+++ b/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py
@@ -0,0 +1,108 @@
+_base_ = ['../../_base_/default_runtime.py']
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-4)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[200, 400])
+total_epochs = 600
+
+model = dict(
+ type='PSENet',
+ pretrained='torchvision://resnet50',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
+ norm_eval=True,
+ style='caffe'),
+ neck=dict(
+ type='FPNF',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ fusion_type='concat'),
+ bbox_head=dict(
+ type='PSEHead',
+ text_repr_type='poly',
+ in_channels=[256],
+ out_channels=7,
+ loss=dict(type='PSELoss')),
+ train_cfg=None,
+ test_cfg=None)
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/ctw1500/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='LoadTextAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(
+ type='ScaleAspectJitter',
+ img_scale=[(3000, 736)],
+ ratio_range=(0.5, 3),
+ aspect_ratio_range=(1, 1),
+ multiscale_mode='value',
+ long_size_bound=1280,
+ short_size_bound=640,
+ resize_type='long_short_bound',
+ keep_ratio=False),
+ dict(type='PSENetTargets'),
+ dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
+ dict(type='RandomRotateTextDet'),
+ dict(
+ type='RandomCropInstances',
+ target_size=(640, 640),
+ instance_key='gt_kernels'),
+ dict(type='Pad', size_divisor=32),
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_kernels', 'gt_mask'],
+ visualize=dict(flag=False, boundary_key='gt_kernels')),
+ dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1280, 1280),
+ flip=False,
+ transforms=[
+ dict(type='Resize', img_scale=(1280, 1280), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_training.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline))
+
+evaluation = dict(interval=10, metric='hmean-iou')
diff --git a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py
new file mode 100644
index 00000000..5eb7538c
--- /dev/null
+++ b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py
@@ -0,0 +1,108 @@
+_base_ = ['../../_base_/runtime_10e.py']
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-4)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[200, 400])
+total_epochs = 600
+
+model = dict(
+ type='PSENet',
+ pretrained='torchvision://resnet50',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
+ norm_eval=True,
+ style='caffe'),
+ neck=dict(
+ type='FPNF',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ fusion_type='concat'),
+ bbox_head=dict(
+ type='PSEHead',
+ text_repr_type='quad',
+ in_channels=[256],
+ out_channels=7,
+ loss=dict(type='PSELoss')),
+ train_cfg=None,
+ test_cfg=None)
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/icdar2015/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='LoadTextAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(
+ type='ScaleAspectJitter',
+ img_scale=[(3000, 736)], # unused
+ ratio_range=(0.5, 3),
+ aspect_ratio_range=(1, 1),
+ multiscale_mode='value',
+ long_size_bound=1280,
+ short_size_bound=640,
+ resize_type='long_short_bound',
+ keep_ratio=False),
+ dict(type='PSENetTargets'),
+ dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
+ dict(type='RandomRotateTextDet'),
+ dict(
+ type='RandomCropInstances',
+ target_size=(640, 640),
+ instance_key='gt_kernels'),
+ dict(type='Pad', size_divisor=32),
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_kernels', 'gt_mask'],
+ visualize=dict(flag=False, boundary_key='gt_kernels')),
+ dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2240, 2200),
+ flip=False,
+ transforms=[
+ dict(type='Resize', img_scale=(2240, 2200), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=2,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_training.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline))
+
+evaluation = dict(interval=10, metric='hmean-iou')
diff --git a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py
new file mode 100644
index 00000000..3c20c4cc
--- /dev/null
+++ b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py
@@ -0,0 +1,103 @@
+_base_ = [
+ '../../_base_/schedules/schedule_sgd_600e.py',
+ '../../_base_/runtime_10e.py'
+]
+model = dict(
+ type='PSENet',
+ pretrained='torchvision://resnet50',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ style='caffe'),
+ neck=dict(
+ type='FPNF',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ fusion_type='concat'),
+ bbox_head=dict(
+ type='PSEHead',
+ text_repr_type='quad',
+ in_channels=[256],
+ out_channels=7,
+ loss=dict(type='PSELoss')),
+ train_cfg=None,
+ test_cfg=None)
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/icdar2017/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='LoadTextAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(
+ type='ScaleAspectJitter',
+ img_scale=[(3000, 736)],
+ ratio_range=(0.5, 3),
+ aspect_ratio_range=(1, 1),
+ multiscale_mode='value',
+ long_size_bound=1280,
+ short_size_bound=640,
+ resize_type='long_short_bound',
+ keep_ratio=False),
+ dict(type='PSENetTargets'),
+ dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
+ dict(type='RandomRotateTextDet'),
+ dict(
+ type='RandomCropInstances',
+ target_size=(640, 640),
+ instance_key='gt_kernels'),
+ dict(type='Pad', size_divisor=32),
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_kernels', 'gt_mask'],
+ visualize=dict(flag=False, boundary_key='gt_kernels')),
+ dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2240, 2200),
+ flip=False,
+ transforms=[
+ dict(type='Resize', img_scale=(2240, 2200), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_training.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_val.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_val.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline))
+
+evaluation = dict(interval=10, metric='hmean-iou')
diff --git a/configs/textdet/textsnake/README.md b/configs/textdet/textsnake/README.md
new file mode 100644
index 00000000..8e761f4c
--- /dev/null
+++ b/configs/textdet/textsnake/README.md
@@ -0,0 +1,23 @@
+# TextSnake
+
+## Introduction
+
+[ALGORITHM]
+
+```bibtex
+@inproceedings{long2018textsnake,
+ title={TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes},
+ author={Long, Shangbang and Ruan, Jiaqiang and Zhang, Wenjie and He, Xin and Wu, Wenhao and Yao, Cong},
+ booktitle={ECCV},
+  pages={20--36},
+ year={2018}
+}
+```
+
+## Results and models
+
+### CTW1500
+
+| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
+| :----------------------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :--------------------------------------------------------------------------------------------------------------------------: |
+| [TextSnake](/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1200 | 736 | 0.795 | 0.840 | 0.817 | [model](https://download.openmmlab.com/mmocr/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth) \| [log]() |
diff --git a/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py b/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py
new file mode 100644
index 00000000..dba03cd1
--- /dev/null
+++ b/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py
@@ -0,0 +1,113 @@
+_base_ = [
+ '../../_base_/schedules/schedule_1200e.py',
+ '../../_base_/default_runtime.py'
+]
+model = dict(
+ type='TextSnake',
+ pretrained='torchvision://resnet50',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ style='caffe'),
+ neck=dict(
+ type='FPN_UNET', in_channels=[256, 512, 1024, 2048], out_channels=32),
+ bbox_head=dict(
+ type='TextSnakeHead',
+ in_channels=32,
+ text_repr_type='poly',
+ loss=dict(type='TextSnakeLoss')),
+ train_cfg=None,
+ test_cfg=None)
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/ctw1500/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='LoadTextAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ poly2mask=False),
+ dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(
+ type='RandomCropPolyInstances',
+ instance_key='gt_masks',
+ crop_ratio=0.65,
+ min_side_ratio=0.3),
+ dict(
+ type='RandomRotatePolyInstances',
+ rotate_ratio=0.5,
+ max_angle=20,
+ pad_with_fixed_color=False),
+ dict(
+ type='ScaleAspectJitter',
+ img_scale=[(3000, 736)], # unused
+ ratio_range=(0.7, 1.3),
+ aspect_ratio_range=(0.9, 1.1),
+ multiscale_mode='value',
+ long_size_bound=800,
+ short_size_bound=480,
+ resize_type='long_short_bound',
+ keep_ratio=False),
+ dict(type='SquareResizePad', target_size=800, pad_ratio=0.6),
+ dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
+ dict(type='TextSnakeTargets'),
+ dict(type='Pad', size_divisor=32),
+ dict(
+ type='CustomFormatBundle',
+ keys=[
+ 'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
+ 'gt_radius_map', 'gt_sin_map', 'gt_cos_map'
+ ],
+ visualize=dict(flag=False, boundary_key='gt_text_mask')),
+ dict(
+ type='Collect',
+ keys=[
+ 'img', 'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
+ 'gt_radius_map', 'gt_sin_map', 'gt_cos_map'
+ ])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 736),
+ flip=False,
+ transforms=[
+ dict(type='Resize', img_scale=(1333, 736), keep_ratio=True),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_training.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + '/instances_test.json',
+ img_prefix=data_root + '/imgs',
+ pipeline=test_pipeline))
+
+evaluation = dict(interval=10, metric='hmean-iou')
diff --git a/configs/textrecog/crnn/README.md b/configs/textrecog/crnn/README.md
new file mode 100644
index 00000000..489cc64a
--- /dev/null
+++ b/configs/textrecog/crnn/README.md
@@ -0,0 +1,37 @@
+# An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition
+
+## Introduction
+
+[ALGORITHM]
+
+```bibtex
+@article{shi2016end,
+ title={An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition},
+ author={Shi, Baoguang and Bai, Xiang and Yao, Cong},
+ journal={IEEE transactions on pattern analysis and machine intelligence},
+ year={2016}
+}
+```
+
+## Dataset
+
+### Train Dataset
+
+| trainset | instance_num | repeat_num | note |
+| :------: | :----------: | :--------: | :---: |
+| Syn90k | 8919273 | 1 | synth |
+
+### Test Dataset
+
+| testset | instance_num | note |
+| :-----: | :----------: | :-----: |
+| IIIT5K | 3000 | regular |
+| SVT | 647 | regular |
+| IC13 | 1015 | regular |
+
+## Results and Models
+
+| methods | | Regular Text | | | | Irregular Text | | download |
+| :-----: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :------------------: |
+| | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 | |
+| CRNN | 80.5 | 81.5 | 86.5 | | - | - | - | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_academic-a723a1c5.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/20210326_111035.log.json) |
diff --git a/configs/textrecog/crnn/crnn_academic_dataset.py b/configs/textrecog/crnn/crnn_academic_dataset.py
new file mode 100644
index 00000000..75701698
--- /dev/null
+++ b/configs/textrecog/crnn/crnn_academic_dataset.py
@@ -0,0 +1,138 @@
+_base_ = []
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=1,
+ hooks=[
+ dict(type='TextLoggerHook')
+
+ ])
+# yapf:enable
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+
+# model
+label_convertor = dict(
+ type='CTCConvertor', dict_type='DICT36', with_unknown=False, lower=True)
+
+model = dict(
+ type='CRNNNet',
+ preprocessor=None,
+ backbone=dict(type='VeryDeepVgg', leakyRelu=False, input_channels=1),
+ encoder=None,
+ decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
+ loss=dict(type='CTCLoss'),
+ label_convertor=label_convertor,
+ pretrained=None)
+
+train_cfg = None
+test_cfg = None
+
+# optimizer
+optimizer = dict(type='Adadelta', lr=1.0)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[])
+total_epochs = 5
+
+# data
+img_norm_cfg = dict(mean=[0.5], std=[0.5])
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=32,
+ min_width=100,
+ max_width=100,
+ keep_aspect_ratio=False),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+ ]),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=32,
+ min_width=4,
+ max_width=None,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=['filename', 'ori_shape', 'img_shape', 'valid_ratio']),
+]
+
+dataset_type = 'OCRDataset'
+
+train_img_prefix = 'data/mixture/Syn90k/mnt/ramdisk/max/90kDICT32px'
+train_ann_file = 'data/mixture/Syn90k/label.lmdb'
+
+train1 = dict(
+ type=dataset_type,
+ img_prefix=train_img_prefix,
+ ann_file=train_ann_file,
+ loader=dict(
+ type='LmdbLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+test_prefix = 'data/mixture/'
+test_img_prefix1 = test_prefix + 'icdar_2013/'
+test_img_prefix2 = test_prefix + 'IIIT5K/'
+test_img_prefix3 = test_prefix + 'svt/'
+
+test_ann_file1 = test_prefix + 'icdar_2013/test_label_1015.txt'
+test_ann_file2 = test_prefix + 'IIIT5K/test_label.txt'
+test_ann_file3 = test_prefix + 'svt/test_label.txt'
+
+test1 = dict(
+ type=dataset_type,
+ img_prefix=test_img_prefix1,
+ ann_file=test_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
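+# test2 and test3 reuse test1's loader and pipeline via shallow copies; only the image prefix and annotation file differ.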
+test2 = {key: value for key, value in test1.items()}
+test2['img_prefix'] = test_img_prefix2
+test2['ann_file'] = test_ann_file2
+
+test3 = {key: value for key, value in test1.items()}
+test3['img_prefix'] = test_img_prefix3
+test3['ann_file'] = test_ann_file3
+
+data = dict(
+ samples_per_gpu=64,
+ workers_per_gpu=4,
+ train=dict(type='ConcatDataset', datasets=[train1]),
+ val=dict(type='ConcatDataset', datasets=[test1, test2, test3]),
+ test=dict(type='ConcatDataset', datasets=[test1, test2, test3]))
+
+evaluation = dict(interval=1, metric='acc')
+
+cudnn_benchmark = True
diff --git a/configs/textrecog/crnn/crnn_toy_dataset.py b/configs/textrecog/crnn/crnn_toy_dataset.py
new file mode 100644
index 00000000..76854024
--- /dev/null
+++ b/configs/textrecog/crnn/crnn_toy_dataset.py
@@ -0,0 +1,6 @@
+_base_ = [
+ '../../_base_/schedules/schedule_adadelta_8e.py',
+ '../../_base_/default_runtime.py',
+ '../../_base_/recog_datasets/toy_dataset.py',
+ '../../_base_/recog_models/crnn.py'
+]
diff --git a/configs/textrecog/nrtr/README.md b/configs/textrecog/nrtr/README.md
new file mode 100644
index 00000000..7d018559
--- /dev/null
+++ b/configs/textrecog/nrtr/README.md
@@ -0,0 +1,61 @@
+# NRTR
+
+## Introduction
+
+[ALGORITHM]
+
+```bibtex
+@inproceedings{sheng2019nrtr,
+ title={NRTR: A no-recurrence sequence-to-sequence model for scene text recognition},
+ author={Sheng, Fenfen and Chen, Zhineng and Xu, Bo},
+ booktitle={2019 International Conference on Document Analysis and Recognition (ICDAR)},
+ pages={781--786},
+ year={2019},
+ organization={IEEE}
+}
+```
+
+[BACKBONE]
+
+```bibtex
+@inproceedings{li2019show,
+ title={Show, attend and read: A simple and strong baseline for irregular text recognition},
+ author={Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu},
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+ volume={33},
+ number={01},
+ pages={8610--8617},
+ year={2019}
+}
+```
+
+## Dataset
+
+### Train Dataset
+
+| trainset | instance_num | repeat_num | source |
+| :--------: | :----------: | :--------: | :----------------------: |
+| SynthText | 7266686 | 1 | synth |
+| Syn90k | 8919273 | 1 | synth |
+
+### Test Dataset
+
+| testset | instance_num | type |
+| :-----: | :----------: | :-------------------------: |
+| IIIT5K | 3000 | regular |
+| SVT | 647 | regular |
+| IC13 | 1015 | regular |
+| IC15 | 2077 | irregular |
+| SVTP | 645 | irregular |
+| CT80 | 288 | irregular |
+
+## Results and Models
+
+| Methods | Backbone | | Regular Text | | | | Irregular Text | | download |
+| :-------: | :---------: | :----: | :----: | :--: | :-: | :--: | :------: | :--: | :-----: |
+| | | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 | |
+| [NRTR](/configs/textrecog/nrtr/nrtr_r31_academic.py) | R31-1/16-1/8 | 93.9 | 90.0 | 93.5 | | 74.5 | 78.5 | 86.5 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_academic_20210406-954db95e.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20210406_010150.log.json) |
+
+**Notes:**
+
+- `R31-1/16-1/8` means that the height of the feature map from the backbone is 1/16 of the input image height, and the width is 1/8 of the input image width (see the sketch below).
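+
+As a quick illustration of the naming (a minimal sketch; the 32x160 input below is only an example size, not a requirement of the config):
+
+```python
+# An `R31-1/16-1/8` backbone maps an H x W input to a feature map of
+# height H/16 and width W/8, e.g. a 32 x 160 crop becomes 2 x 20.
+input_h, input_w = 32, 160
+feat_h, feat_w = input_h // 16, input_w // 8
+print((feat_h, feat_w))  # (2, 20)
+```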
diff --git a/configs/textrecog/nrtr/nrtr_modality_toy.py b/configs/textrecog/nrtr/nrtr_modality_toy.py
new file mode 100644
index 00000000..e8201c6f
--- /dev/null
+++ b/configs/textrecog/nrtr/nrtr_modality_toy.py
@@ -0,0 +1,112 @@
+_base_ = [
+ '../../_base_/default_runtime.py',
+ '../../_base_/recog_models/nrtr.py',
+]
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 6
+
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+ dict(
+ type='ResizeOCR',
+ height=32,
+ min_width=32,
+ max_width=100,
+ keep_aspect_ratio=False),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+ ]),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiRotateAugOCR',
+ rotate_degrees=[0, 90, 270],
+ transforms=[
+ dict(
+ type='ResizeOCR',
+ height=32,
+ min_width=32,
+ max_width=100,
+ keep_aspect_ratio=False),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+ ]),
+ ])
+]
+
+dataset_type = 'OCRDataset'
+img_prefix = 'tests/data/ocr_toy_dataset/imgs'
+train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt'
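+# train1 reads the toy labels from the plain-text file via HardDiskLoader; train2 below reads the same toy data from LMDB via LmdbLoader.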
+train1 = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=train_anno_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=100,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
+train2 = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=train_anno_file2,
+ loader=dict(
+ type='LmdbLoader',
+ repeat=100,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
+test = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=test_anno_file1,
+ loader=dict(
+ type='LmdbLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=2,
+ train=dict(type='ConcatDataset', datasets=[train1, train2]),
+ val=dict(type='ConcatDataset', datasets=[test]),
+ test=dict(type='ConcatDataset', datasets=[test]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/textrecog/nrtr/nrtr_r31_academic.py b/configs/textrecog/nrtr/nrtr_r31_academic.py
new file mode 100644
index 00000000..9299583b
--- /dev/null
+++ b/configs/textrecog/nrtr/nrtr_r31_academic.py
@@ -0,0 +1,163 @@
+_base_ = [
+ '../../_base_/default_runtime.py', '../../_base_/recog_models/nrtr.py'
+]
+
+label_convertor = dict(
+ type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+
+model = dict(
+ type='NRTR',
+ backbone=dict(
+ type='ResNet31OCR',
+ layers=[1, 2, 5, 3],
+ channels=[32, 64, 128, 256, 512, 512],
+ stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)),
+ last_stage_pool=True),
+ encoder=dict(type='TFEncoder'),
+ decoder=dict(type='TFDecoder'),
+ loss=dict(type='TFLoss'),
+ label_convertor=label_convertor,
+ max_seq_len=40)
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 6
+
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=32,
+ min_width=32,
+ max_width=160,
+ keep_aspect_ratio=True,
+ width_downsample_ratio=0.25),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+ ]),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiRotateAugOCR',
+ rotate_degrees=[0, 90, 270],
+ transforms=[
+ dict(
+ type='ResizeOCR',
+ height=32,
+ min_width=32,
+ max_width=160,
+ keep_aspect_ratio=True,
+ width_downsample_ratio=0.25),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+ ]),
+ ])
+]
+
+dataset_type = 'OCRDataset'
+
+train_prefix = 'data/mixture/'
+
+train_img_prefix1 = train_prefix + \
+ 'SynthText/synthtext/SynthText_patch_horizontal'
+train_img_prefix2 = train_prefix + 'Syn90k/mnt/ramdisk/max/90kDICT32px'
+
+train_ann_file1 = train_prefix + 'SynthText/label.lmdb'
+train_ann_file2 = train_prefix + 'Syn90k/label.lmdb'
+
+train1 = dict(
+ type=dataset_type,
+ img_prefix=train_img_prefix1,
+ ann_file=train_ann_file1,
+ loader=dict(
+ type='LmdbLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train2 = {key: value for key, value in train1.items()}
+train2['img_prefix'] = train_img_prefix2
+train2['ann_file'] = train_ann_file2
+
+test_prefix = 'data/mixture/'
+test_img_prefix1 = test_prefix + 'IIIT5K/'
+test_img_prefix2 = test_prefix + 'svt/'
+test_img_prefix3 = test_prefix + 'icdar_2013/'
+test_img_prefix4 = test_prefix + 'icdar_2015/'
+test_img_prefix5 = test_prefix + 'svtp/'
+test_img_prefix6 = test_prefix + 'ct80/'
+
+test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
+test_ann_file2 = test_prefix + 'svt/test_label.txt'
+test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
+test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
+test_ann_file5 = test_prefix + 'svtp/test_label.txt'
+test_ann_file6 = test_prefix + 'ct80/test_label.txt'
+
+test1 = dict(
+ type=dataset_type,
+ img_prefix=test_img_prefix1,
+ ann_file=test_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+test2 = {key: value for key, value in test1.items()}
+test2['img_prefix'] = test_img_prefix2
+test2['ann_file'] = test_ann_file2
+
+test3 = {key: value for key, value in test1.items()}
+test3['img_prefix'] = test_img_prefix3
+test3['ann_file'] = test_ann_file3
+
+test4 = {key: value for key, value in test1.items()}
+test4['img_prefix'] = test_img_prefix4
+test4['ann_file'] = test_ann_file4
+
+test5 = {key: value for key, value in test1.items()}
+test5['img_prefix'] = test_img_prefix5
+test5['ann_file'] = test_ann_file5
+
+test6 = {key: value for key, value in test1.items()}
+test6['img_prefix'] = test_img_prefix6
+test6['ann_file'] = test_ann_file6
+
+data = dict(
+ samples_per_gpu=128,
+ workers_per_gpu=4,
+ train=dict(type='ConcatDataset', datasets=[train1, train2]),
+ val=dict(
+ type='ConcatDataset',
+ datasets=[test1, test2, test3, test4, test5, test6]),
+ test=dict(
+ type='ConcatDataset',
+ datasets=[test1, test2, test3, test4, test5, test6]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/textrecog/robust_scanner/README.md b/configs/textrecog/robust_scanner/README.md
new file mode 100644
index 00000000..01e42971
--- /dev/null
+++ b/configs/textrecog/robust_scanner/README.md
@@ -0,0 +1,51 @@
+# RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition
+
+## Introduction
+
+[ALGORITHM]
+
+```bibtex
+@inproceedings{yue2020robustscanner,
+ title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition},
+ author={Yue, Xiaoyu and Kuang, Zhanghui and Lin, Chenhao and Sun, Hongbin and Zhang, Wayne},
+ booktitle={European Conference on Computer Vision},
+ year={2020}
+}
+```
+
+## Dataset
+
+### Train Dataset
+
+| trainset | instance_num | repeat_num | source |
+| :--------: | :----------: | :--------: | :----------------------: |
+| icdar_2011 | 3567 | 20 | real |
+| icdar_2013 | 848 | 20 | real |
+| icdar_2015 | 4468 | 20 | real |
+| coco_text | 42142 | 20 | real |
+| IIIT5K | 2000 | 20 | real |
+| SynthText | 2400000 | 1 | synth |
+| SynthAdd | 1216889 | 1 | synth, 1.6m in [[1]](#1) |
+| Syn90k | 2400000 | 1 | synth |
+
+### Test Dataset
+
+| testset | instance_num | type |
+| :-----: | :----------: | :-------------------------: |
+| IIIT5K | 3000 | regular |
+| SVT | 647 | regular |
+| IC13 | 1015 | regular |
+| IC15 | 2077 | irregular |
+| SVTP | 645 | irregular, 639 in [[1]](#1) |
+| CT80 | 288 | irregular |
+
+## Results and Models
+
+| Methods | GPUs | | Regular Text | | | | Irregular Text | | download |
+| :-----------------------------------------------------------------: | :---------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| | | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
+| [RobustScanner](/configs/textrecog/robust_scanner/robustscanner_r31_academic.py) | 16 | 95.1 | 89.2 | 93.1 | | 77.8 | 80.3 | 90.3 | [model](https://download.openmmlab.com/mmocr/textrecog/robustscanner/robustscanner_r31_academic-5f05874f.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/robustscanner/20210401_170932.log.json) |
+
+## References
+
+[1] Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu. Show, attend and read: A simple and strong baseline for irregular text recognition. In AAAI 2019.
diff --git a/configs/textrecog/robust_scanner/robust_scanner_toy_dataset.py b/configs/textrecog/robust_scanner/robust_scanner_toy_dataset.py
new file mode 100644
index 00000000..6fc6c125
--- /dev/null
+++ b/configs/textrecog/robust_scanner/robust_scanner_toy_dataset.py
@@ -0,0 +1,12 @@
+_base_ = [
+ '../../_base_/default_runtime.py',
+ '../../_base_/recog_models/robust_scanner.py',
+ '../../_base_/recog_datasets/toy_dataset.py'
+]
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 6
diff --git a/configs/textrecog/robust_scanner/robustscanner_r31_academic.py b/configs/textrecog/robust_scanner/robustscanner_r31_academic.py
new file mode 100644
index 00000000..e90dd75d
--- /dev/null
+++ b/configs/textrecog/robust_scanner/robustscanner_r31_academic.py
@@ -0,0 +1,197 @@
+_base_ = [
+ '../../_base_/default_runtime.py',
+ '../../_base_/recog_models/robust_scanner.py'
+]
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True,
+ width_downsample_ratio=0.25),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+ ]),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiRotateAugOCR',
+ rotate_degrees=[0, 90, 270],
+ transforms=[
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True,
+ width_downsample_ratio=0.25),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+ ]),
+ ])
+]
+
+dataset_type = 'OCRDataset'
+
+train_prefix = 'data/mixture/'
+
+train_img_prefix1 = train_prefix + 'icdar_2011'
+train_img_prefix2 = train_prefix + 'icdar_2013'
+train_img_prefix3 = train_prefix + 'icdar_2015'
+train_img_prefix4 = train_prefix + 'coco_text'
+train_img_prefix5 = train_prefix + 'III5K'
+train_img_prefix6 = train_prefix + 'SynthText_Add'
+train_img_prefix7 = train_prefix + 'SynthText'
+train_img_prefix8 = train_prefix + 'Syn90k'
+
+train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt'
+train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt'
+train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt'
+train_ann_file4 = train_prefix + 'coco_text/train_label.txt'
+train_ann_file5 = train_prefix + 'III5K/train_label.txt'
+train_ann_file6 = train_prefix + 'SynthText_Add/label.txt'
+train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt'
+train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt'
+
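+# The real datasets (train1-train5) are repeated 20 times per epoch, while the synthetic sets (train6-train8) are used once, matching the repeat_num column in the README.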
+train1 = dict(
+ type=dataset_type,
+ img_prefix=train_img_prefix1,
+ ann_file=train_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=20,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train2 = {key: value for key, value in train1.items()}
+train2['img_prefix'] = train_img_prefix2
+train2['ann_file'] = train_ann_file2
+
+train3 = {key: value for key, value in train1.items()}
+train3['img_prefix'] = train_img_prefix3
+train3['ann_file'] = train_ann_file3
+
+train4 = {key: value for key, value in train1.items()}
+train4['img_prefix'] = train_img_prefix4
+train4['ann_file'] = train_ann_file4
+
+train5 = {key: value for key, value in train1.items()}
+train5['img_prefix'] = train_img_prefix5
+train5['ann_file'] = train_ann_file5
+
+train6 = dict(
+ type=dataset_type,
+ img_prefix=train_img_prefix6,
+ ann_file=train_ann_file6,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train7 = {key: value for key, value in train6.items()}
+train7['img_prefix'] = train_img_prefix7
+train7['ann_file'] = train_ann_file7
+
+train8 = {key: value for key, value in train6.items()}
+train8['img_prefix'] = train_img_prefix8
+train8['ann_file'] = train_ann_file8
+
+test_prefix = 'data/mixture/'
+test_img_prefix1 = test_prefix + 'IIIT5K/'
+test_img_prefix2 = test_prefix + 'svt/'
+test_img_prefix3 = test_prefix + 'icdar_2013/'
+test_img_prefix4 = test_prefix + 'icdar_2015/'
+test_img_prefix5 = test_prefix + 'svtp/'
+test_img_prefix6 = test_prefix + 'ct80/'
+
+test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
+test_ann_file2 = test_prefix + 'svt/test_label.txt'
+test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
+test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
+test_ann_file5 = test_prefix + 'svtp/test_label.txt'
+test_ann_file6 = test_prefix + 'ct80/test_label.txt'
+
+test1 = dict(
+ type=dataset_type,
+ img_prefix=test_img_prefix1,
+ ann_file=test_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+test2 = {key: value for key, value in test1.items()}
+test2['img_prefix'] = test_img_prefix2
+test2['ann_file'] = test_ann_file2
+
+test3 = {key: value for key, value in test1.items()}
+test3['img_prefix'] = test_img_prefix3
+test3['ann_file'] = test_ann_file3
+
+test4 = {key: value for key, value in test1.items()}
+test4['img_prefix'] = test_img_prefix4
+test4['ann_file'] = test_ann_file4
+
+test5 = {key: value for key, value in test1.items()}
+test5['img_prefix'] = test_img_prefix5
+test5['ann_file'] = test_ann_file5
+
+test6 = {key: value for key, value in test1.items()}
+test6['img_prefix'] = test_img_prefix6
+test6['ann_file'] = test_ann_file6
+
+data = dict(
+ samples_per_gpu=64,
+ workers_per_gpu=2,
+ train=dict(
+ type='ConcatDataset',
+ datasets=[
+ train1, train2, train3, train4, train5, train6, train7, train8
+ ]),
+ val=dict(
+ type='ConcatDataset',
+ datasets=[test1, test2, test3, test4, test5, test6]),
+ test=dict(
+ type='ConcatDataset',
+ datasets=[test1, test2, test3, test4, test5, test6]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/textrecog/sar/README.md b/configs/textrecog/sar/README.md
new file mode 100644
index 00000000..8854ae04
--- /dev/null
+++ b/configs/textrecog/sar/README.md
@@ -0,0 +1,67 @@
+# Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition
+
+## Introduction
+
+[ALGORITHM]
+
+```bibtex
+@inproceedings{li2019show,
+ title={Show, attend and read: A simple and strong baseline for irregular text recognition},
+ author={Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu},
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+ volume={33},
+ number={01},
+ pages={8610--8617},
+ year={2019}
+}
+```
+
+## Dataset
+
+### Train Dataset
+
+| trainset | instance_num | repeat_num | source |
+| :--------: | :----------: | :--------: | :----------------------: |
+| icdar_2011 | 3567 | 20 | real |
+| icdar_2013 | 848 | 20 | real |
+| icdar_2015 | 4468 | 20 | real |
+| coco_text | 42142 | 20 | real |
+| IIIT5K | 2000 | 20 | real |
+| SynthText | 2400000 | 1 | synth |
+| SynthAdd | 1216889 | 1 | synth, 1.6m in [[1]](#1) |
+| Syn90k | 2400000 | 1 | synth |
+
+### Test Dataset
+
+| testset | instance_num | type |
+| :-----: | :----------: | :-------------------------: |
+| IIIT5K | 3000 | regular |
+| SVT | 647 | regular |
+| IC13 | 1015 | regular |
+| IC15 | 2077 | irregular |
+| SVTP | 645 | irregular, 639 in [[1]](#1) |
+| CT80 | 288 | irregular |
+
+## Results and Models
+
+| Methods | Backbone | Decoder | | Regular Text | | | | Irregular Text | | download |
+| :-----------------------------------------------------------------: | :---------: | :------------------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| | | | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
+| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 95.0 | 89.6 | 93.7 | | 79.0 | 82.2 | 88.9 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210327_154129.log.json) |
+| [SAR](/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 95.2 | 88.7 | 92.4 | | 78.2 | 81.9 | 89.6 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_sequential_decoder_academic-d06c9a8e.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210330_105728.log.json) |
+
+**Notes:**
+
+- `R31-1/8-1/4` means that the height of the feature map from the backbone is 1/8 of the input image height, and the width is 1/4 of the input image width.
+- We did not use beam search during decoding.
+- We implemented two kinds of decoders, namely `ParallelSARDecoder` and `SequentialSARDecoder` (see the snippet after this list).
+ - `ParallelSARDecoder`: Parallel decoding during training with `LSTM` layer. It would be faster.
+ - `SequentialSARDecoder`: Sequential Decoding during training with `LSTMCell`. It would be easier to understand.
+- For the training dataset:
+ - We did not construct distinct data groups (20 groups in [[1]](#1)) to train the model group-by-group since it would render model training too complicated.
+ - Instead, we randomly selected `2.4m` patches from `Syn90k`, `2.4m` from `SynthText` and `1.2m` from `SynthAdd`, and grouped all data together. See [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_academic.py) for details.
+- We used 48 GPUs with `total_batch_size = 64 * 48` in the experiment above to speed up training, while keeping the `initial lr = 1e-3` unchanged.
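+
+Switching between the two decoders only changes the `decoder` entry of the model config; the values below are copied from the academic configs in this PR (a sketch for illustration, not an additional released config):
+
+```python
+# Only the `type` key differs between the parallel and sequential variants.
+decoder = dict(
+    type='ParallelSARDecoder',  # or 'SequentialSARDecoder'
+    enc_bi_rnn=False,
+    dec_bi_rnn=False,
+    dec_do_rnn=0,
+    dec_gru=False,
+    pred_dropout=0.1,
+    d_k=512,
+    pred_concat=True)
+```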
+
+## References
+
+[1] Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu. Show, attend and read: A simple and strong baseline for irregular text recognition. In AAAI 2019.
diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py
new file mode 100644
index 00000000..4e405227
--- /dev/null
+++ b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py
@@ -0,0 +1,219 @@
+_base_ = ['../../_base_/default_runtime.py']
+
+label_convertor = dict(
+ type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+
+model = dict(
+ type='SARNet',
+ backbone=dict(type='ResNet31OCR'),
+ encoder=dict(
+ type='SAREncoder',
+ enc_bi_rnn=False,
+ enc_do_rnn=0.1,
+ enc_gru=False,
+ ),
+ decoder=dict(
+ type='ParallelSARDecoder',
+ enc_bi_rnn=False,
+ dec_bi_rnn=False,
+ dec_do_rnn=0,
+ dec_gru=False,
+ pred_dropout=0.1,
+ d_k=512,
+ pred_concat=True),
+ loss=dict(type='SARLoss'),
+ label_convertor=label_convertor,
+ max_seq_len=30)
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True,
+ width_downsample_ratio=0.25),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+ ]),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiRotateAugOCR',
+ rotate_degrees=[0, 90, 270],
+ transforms=[
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True,
+ width_downsample_ratio=0.25),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+ ]),
+ ])
+]
+
+dataset_type = 'OCRDataset'
+
+train_prefix = 'data/mixture/'
+
+train_img_prefix1 = train_prefix + 'icdar_2011'
+train_img_prefix2 = train_prefix + 'icdar_2013'
+train_img_prefix3 = train_prefix + 'icdar_2015'
+train_img_prefix4 = train_prefix + 'coco_text'
+train_img_prefix5 = train_prefix + 'III5K'
+train_img_prefix6 = train_prefix + 'SynthText_Add'
+train_img_prefix7 = train_prefix + 'SynthText'
+train_img_prefix8 = train_prefix + 'Syn90k'
+
+train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt'
+train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt'
+train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt'
+train_ann_file4 = train_prefix + 'coco_text/train_label.txt'
+train_ann_file5 = train_prefix + 'III5K/train_label.txt'
+train_ann_file6 = train_prefix + 'SynthText_Add/label.txt'
+train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt'
+train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt'
+
+train1 = dict(
+ type=dataset_type,
+ img_prefix=train_img_prefix1,
+ ann_file=train_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=20,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train2 = {key: value for key, value in train1.items()}
+train2['img_prefix'] = train_img_prefix2
+train2['ann_file'] = train_ann_file2
+
+train3 = {key: value for key, value in train1.items()}
+train3['img_prefix'] = train_img_prefix3
+train3['ann_file'] = train_ann_file3
+
+train4 = {key: value for key, value in train1.items()}
+train4['img_prefix'] = train_img_prefix4
+train4['ann_file'] = train_ann_file4
+
+train5 = {key: value for key, value in train1.items()}
+train5['img_prefix'] = train_img_prefix5
+train5['ann_file'] = train_ann_file5
+
+train6 = dict(
+ type=dataset_type,
+ img_prefix=train_img_prefix6,
+ ann_file=train_ann_file6,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train7 = {key: value for key, value in train6.items()}
+train7['img_prefix'] = train_img_prefix7
+train7['ann_file'] = train_ann_file7
+
+train8 = {key: value for key, value in train6.items()}
+train8['img_prefix'] = train_img_prefix8
+train8['ann_file'] = train_ann_file8
+
+test_prefix = 'data/mixture/'
+test_img_prefix1 = test_prefix + 'IIIT5K/'
+test_img_prefix2 = test_prefix + 'svt/'
+test_img_prefix3 = test_prefix + 'icdar_2013/'
+test_img_prefix4 = test_prefix + 'icdar_2015/'
+test_img_prefix5 = test_prefix + 'svtp/'
+test_img_prefix6 = test_prefix + 'ct80/'
+
+test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
+test_ann_file2 = test_prefix + 'svt/test_label.txt'
+test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
+test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
+test_ann_file5 = test_prefix + 'svtp/test_label.txt'
+test_ann_file6 = test_prefix + 'ct80/test_label.txt'
+
+test1 = dict(
+ type=dataset_type,
+ img_prefix=test_img_prefix1,
+ ann_file=test_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+test2 = {key: value for key, value in test1.items()}
+test2['img_prefix'] = test_img_prefix2
+test2['ann_file'] = test_ann_file2
+
+test3 = {key: value for key, value in test1.items()}
+test3['img_prefix'] = test_img_prefix3
+test3['ann_file'] = test_ann_file3
+
+test4 = {key: value for key, value in test1.items()}
+test4['img_prefix'] = test_img_prefix4
+test4['ann_file'] = test_ann_file4
+
+test5 = {key: value for key, value in test1.items()}
+test5['img_prefix'] = test_img_prefix5
+test5['ann_file'] = test_ann_file5
+
+test6 = {key: value for key, value in test1.items()}
+test6['img_prefix'] = test_img_prefix6
+test6['ann_file'] = test_ann_file6
+
+data = dict(
+ samples_per_gpu=64,
+ workers_per_gpu=2,
+ train=dict(
+ type='ConcatDataset',
+ datasets=[
+ train1, train2, train3, train4, train5, train6, train7, train8
+ ]),
+ val=dict(
+ type='ConcatDataset',
+ datasets=[test1, test2, test3, test4, test5, test6]),
+ test=dict(
+ type='ConcatDataset',
+ datasets=[test1, test2, test3, test4, test5, test6]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py b/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py
new file mode 100755
index 00000000..0c8b53e2
--- /dev/null
+++ b/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py
@@ -0,0 +1,110 @@
+_base_ = [
+ '../../_base_/default_runtime.py', '../../_base_/recog_models/sar.py'
+]
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+ ]),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiRotateAugOCR',
+ rotate_degrees=[0, 90, 270],
+ transforms=[
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+ ]),
+ ])
+]
+
+dataset_type = 'OCRDataset'
+img_prefix = 'tests/data/ocr_toy_dataset/imgs'
+train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt'
+train1 = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=train_anno_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=100,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
+train2 = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=train_anno_file2,
+ loader=dict(
+ type='LmdbLoader',
+ repeat=100,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
+test = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=test_anno_file1,
+ loader=dict(
+ type='LmdbLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=2,
+ train=dict(type='ConcatDataset', datasets=[train1, train2]),
+ val=dict(type='ConcatDataset', datasets=[test]),
+ test=dict(type='ConcatDataset', datasets=[test]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py
new file mode 100644
index 00000000..6fa00dd7
--- /dev/null
+++ b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py
@@ -0,0 +1,219 @@
+_base_ = ['../../_base_/default_runtime.py']
+
+label_convertor = dict(
+ type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+
+model = dict(
+ type='SARNet',
+ backbone=dict(type='ResNet31OCR'),
+ encoder=dict(
+ type='SAREncoder',
+ enc_bi_rnn=False,
+ enc_do_rnn=0.1,
+ enc_gru=False,
+ ),
+ decoder=dict(
+ type='SequentialSARDecoder',
+ enc_bi_rnn=False,
+ dec_bi_rnn=False,
+ dec_do_rnn=0,
+ dec_gru=False,
+ pred_dropout=0.1,
+ d_k=512,
+ pred_concat=True),
+ loss=dict(type='SARLoss'),
+ label_convertor=label_convertor,
+ max_seq_len=30)
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-3)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True,
+ width_downsample_ratio=0.25),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+ ]),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiRotateAugOCR',
+ rotate_degrees=[0, 90, 270],
+ transforms=[
+ dict(
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True,
+ width_downsample_ratio=0.25),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+ ]),
+ ])
+]
+
+dataset_type = 'OCRDataset'
+
+train_prefix = 'data/mixture/'
+
+train_img_prefix1 = train_prefix + 'icdar_2011'
+train_img_prefix2 = train_prefix + 'icdar_2013'
+train_img_prefix3 = train_prefix + 'icdar_2015'
+train_img_prefix4 = train_prefix + 'coco_text'
+train_img_prefix5 = train_prefix + 'III5K'
+train_img_prefix6 = train_prefix + 'SynthText_Add'
+train_img_prefix7 = train_prefix + 'SynthText'
+train_img_prefix8 = train_prefix + 'Syn90k'
+
+train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt'
+train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt'
+train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt'
+train_ann_file4 = train_prefix + 'coco_text/train_label.txt'
+train_ann_file5 = train_prefix + 'III5K/train_label.txt'
+train_ann_file6 = train_prefix + 'SynthText_Add/label.txt'
+train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt'
+train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt'
+
+train1 = dict(
+ type=dataset_type,
+ img_prefix=train_img_prefix1,
+ ann_file=train_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=20,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train2 = {key: value for key, value in train1.items()}
+train2['img_prefix'] = train_img_prefix2
+train2['ann_file'] = train_ann_file2
+
+train3 = {key: value for key, value in train1.items()}
+train3['img_prefix'] = train_img_prefix3
+train3['ann_file'] = train_ann_file3
+
+train4 = {key: value for key, value in train1.items()}
+train4['img_prefix'] = train_img_prefix4
+train4['ann_file'] = train_ann_file4
+
+train5 = {key: value for key, value in train1.items()}
+train5['img_prefix'] = train_img_prefix5
+train5['ann_file'] = train_ann_file5
+
+train6 = dict(
+ type=dataset_type,
+ img_prefix=train_img_prefix6,
+ ann_file=train_ann_file6,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+train7 = {key: value for key, value in train6.items()}
+train7['img_prefix'] = train_img_prefix7
+train7['ann_file'] = train_ann_file7
+
+train8 = {key: value for key, value in train6.items()}
+train8['img_prefix'] = train_img_prefix8
+train8['ann_file'] = train_ann_file8
+
+test_prefix = 'data/mixture/'
+test_img_prefix1 = test_prefix + 'IIIT5K/'
+test_img_prefix2 = test_prefix + 'svt/'
+test_img_prefix3 = test_prefix + 'icdar_2013/'
+test_img_prefix4 = test_prefix + 'icdar_2015/'
+test_img_prefix5 = test_prefix + 'svtp/'
+test_img_prefix6 = test_prefix + 'ct80/'
+
+test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
+test_ann_file2 = test_prefix + 'svt/test_label.txt'
+test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
+test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
+test_ann_file5 = test_prefix + 'svtp/test_label.txt'
+test_ann_file6 = test_prefix + 'ct80/test_label.txt'
+
+test1 = dict(
+ type=dataset_type,
+ img_prefix=test_img_prefix1,
+ ann_file=test_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+test2 = {key: value for key, value in test1.items()}
+test2['img_prefix'] = test_img_prefix2
+test2['ann_file'] = test_ann_file2
+
+test3 = {key: value for key, value in test1.items()}
+test3['img_prefix'] = test_img_prefix3
+test3['ann_file'] = test_ann_file3
+
+test4 = {key: value for key, value in test1.items()}
+test4['img_prefix'] = test_img_prefix4
+test4['ann_file'] = test_ann_file4
+
+test5 = {key: value for key, value in test1.items()}
+test5['img_prefix'] = test_img_prefix5
+test5['ann_file'] = test_ann_file5
+
+test6 = {key: value for key, value in test1.items()}
+test6['img_prefix'] = test_img_prefix6
+test6['ann_file'] = test_ann_file6
+
+data = dict(
+ samples_per_gpu=64,
+ workers_per_gpu=2,
+ train=dict(
+ type='ConcatDataset',
+ datasets=[
+ train1, train2, train3, train4, train5, train6, train7, train8
+ ]),
+ val=dict(
+ type='ConcatDataset',
+ datasets=[test1, test2, test3, test4, test5, test6]),
+ test=dict(
+ type='ConcatDataset',
+ datasets=[test1, test2, test3, test4, test5, test6]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/textrecog/seg/README.md b/configs/textrecog/seg/README.md
new file mode 100644
index 00000000..28a96425
--- /dev/null
+++ b/configs/textrecog/seg/README.md
@@ -0,0 +1,43 @@
+# SegOCR Simple Baseline.
+
+## Introduction
+
+[ALGORITHM]
+
+```bibtex
+@unpublished{key,
+ title={SegOCR Simple Baseline.},
+ author={},
+ note={Unpublished Manuscript},
+ year={2021}
+}
+```
+
+## Dataset
+
+### Train Dataset
+
+| trainset | instance_num | repeat_num | source |
+| :-------: | :----------: | :--------: | :----: |
+| SynthText | 7266686 | 1 | synth |
+
+### Test Dataset
+
+| testset | instance_num | type |
+| :-----: | :----------: | :-------: |
+| IIIT5K | 3000 | regular |
+| SVT | 647 | regular |
+| IC13 | 1015 | regular |
+| CT80 | 288 | irregular |
+
+## Results and Models
+
+| Backbone | Neck | Head | | | Regular Text | | | Irregular Text | download |
+| :-------------: | :-----: | :-----: | :------: | :-----: | :----: | :-----: | :-----: | :-----: | :-----: |
+| | | | | IIIT5K | SVT | IC13 | | CT80 | |
+| R31-1/16 | FPNOCR | 1x | | 90.9 | 81.8 | 90.7 | | 80.9 | [model](https://download.openmmlab.com/mmocr/textrecog/seg/seg_r31_1by16_fpnocr_academic-72235b11.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/seg/20210325_112835.log.json) |
+
+**Notes:**
+
+- `R31-1/16` means that the size (both height and width) of the feature map from the backbone is 1/16 of the input image.
+- `1x` means that the size (both height and width) of the feature map from the head is the same as the input image.
diff --git a/configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py b/configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py
new file mode 100644
index 00000000..8a568f8e
--- /dev/null
+++ b/configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py
@@ -0,0 +1,160 @@
+_base_ = ['../../_base_/default_runtime.py']
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-4)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+label_convertor = dict(
+ type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+
+model = dict(
+ type='SegRecognizer',
+ backbone=dict(
+ type='ResNet31OCR',
+ layers=[1, 2, 5, 3],
+ channels=[32, 64, 128, 256, 512, 512],
+ out_indices=[0, 1, 2, 3],
+ stage4_pool_cfg=dict(kernel_size=2, stride=2),
+ last_stage_pool=True),
+ neck=dict(
+ type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256),
+ head=dict(
+ type='SegHead',
+ in_channels=256,
+ upsample_param=dict(scale_factor=2.0, mode='nearest')),
+ loss=dict(
+ type='SegLoss', seg_downsample_ratio=1.0, seg_with_loss_weight=True),
+ label_convertor=label_convertor)
+
+find_unused_parameters = True
+
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+gt_label_convertor = dict(
+ type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='RandomPaddingOCR',
+ max_ratio=[0.15, 0.2, 0.15, 0.2],
+ box_type='char_quads'),
+ dict(type='OpencvToPil'),
+ dict(
+ type='RandomRotateImageBox',
+ min_angle=-17,
+ max_angle=17,
+ box_type='char_quads'),
+ dict(type='PilToOpencv'),
+ dict(
+ type='ResizeOCR',
+ height=64,
+ min_width=64,
+ max_width=512,
+ keep_aspect_ratio=True),
+ dict(
+ type='OCRSegTargets',
+ label_convertor=gt_label_convertor,
+ box_type='char_quads'),
+ dict(type='RandomRotateTextDet', rotate_ratio=0.5, max_angle=15),
+ dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+ dict(type='ToTensorOCR'),
+ dict(type='FancyPCA'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='CustomFormatBundle',
+ keys=['gt_kernels'],
+ visualize=dict(flag=False, boundary_key=None),
+ call_super=False),
+ dict(
+ type='Collect',
+ keys=['img', 'gt_kernels'],
+ meta_keys=['filename', 'ori_shape', 'img_shape'])
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='ResizeOCR',
+ height=64,
+ min_width=64,
+ max_width=None,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(type='CustomFormatBundle', call_super=False),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=['filename', 'ori_shape', 'img_shape'])
+]
+
+train_img_root = 'data/mixture/'
+
+train_img_prefix = train_img_root + 'SynthText'
+
+train_ann_file = train_img_root + 'SynthText/instances_train.txt'
+
+train = dict(
+ type='OCRSegDataset',
+ img_prefix=train_img_prefix,
+ ann_file=train_ann_file,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineJsonParser', keys=['file_name', 'annotations', 'text'])),
+ pipeline=train_pipeline,
+ test_mode=False)
+
+dataset_type = 'OCRDataset'
+test_prefix = 'data/mixture/'
+
+test_img_prefix1 = test_prefix + 'IIIT5K/'
+test_img_prefix2 = test_prefix + 'svt/'
+test_img_prefix3 = test_prefix + 'icdar_2013/'
+test_img_prefix4 = test_prefix + 'ct80/'
+
+test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
+test_ann_file2 = test_prefix + 'svt/test_label.txt'
+test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
+test_ann_file4 = test_prefix + 'ct80/test_label.txt'
+
+test1 = dict(
+ type=dataset_type,
+ img_prefix=test_img_prefix1,
+ ann_file=test_ann_file1,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=1,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=test_pipeline,
+ test_mode=True)
+
+test2 = {key: value for key, value in test1.items()}
+test2['img_prefix'] = test_img_prefix2
+test2['ann_file'] = test_ann_file2
+
+test3 = {key: value for key, value in test1.items()}
+test3['img_prefix'] = test_img_prefix3
+test3['ann_file'] = test_ann_file3
+
+test4 = {key: value for key, value in test1.items()}
+test4['img_prefix'] = test_img_prefix4
+test4['ann_file'] = test_ann_file4
+
+data = dict(
+ samples_per_gpu=16,
+ workers_per_gpu=2,
+ train=dict(type='ConcatDataset', datasets=[train]),
+ val=dict(type='ConcatDataset', datasets=[test1, test2, test3, test4]),
+ test=dict(type='ConcatDataset', datasets=[test1, test2, test3, test4]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py b/configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py
new file mode 100644
index 00000000..63b3d08c
--- /dev/null
+++ b/configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py
@@ -0,0 +1,35 @@
+_base_ = [
+ '../../_base_/default_runtime.py',
+ '../../_base_/recog_datasets/seg_toy_dataset.py'
+]
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-4)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[3, 4])
+total_epochs = 5
+
+label_convertor = dict(
+ type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+
+model = dict(
+ type='SegRecognizer',
+ backbone=dict(
+ type='ResNet31OCR',
+ layers=[1, 2, 5, 3],
+ channels=[32, 64, 128, 256, 512, 512],
+ out_indices=[0, 1, 2, 3],
+ stage4_pool_cfg=dict(kernel_size=2, stride=2),
+ last_stage_pool=True),
+ neck=dict(
+ type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256),
+ head=dict(
+ type='SegHead',
+ in_channels=256,
+ upsample_param=dict(scale_factor=2.0, mode='nearest')),
+ loss=dict(
+ type='SegLoss', seg_downsample_ratio=1.0, seg_with_loss_weight=False),
+ label_convertor=label_convertor)
+
+find_unused_parameters = True
diff --git a/demo/demo_text_det.jpg b/demo/demo_text_det.jpg
new file mode 100644
index 00000000..d23de3cd
Binary files /dev/null and b/demo/demo_text_det.jpg differ
diff --git a/demo/demo_text_recog.jpg b/demo/demo_text_recog.jpg
new file mode 100644
index 00000000..d9915983
Binary files /dev/null and b/demo/demo_text_recog.jpg differ
diff --git a/demo/image_demo.py b/demo/image_demo.py
new file mode 100644
index 00000000..dbccf784
--- /dev/null
+++ b/demo/image_demo.py
@@ -0,0 +1,44 @@
+from argparse import ArgumentParser
+
+import mmcv
+
+from mmdet.apis import init_detector
+from mmocr.apis.inference import model_inference
+from mmocr.datasets import build_dataset # noqa: F401
+from mmocr.models import build_detector # noqa: F401
+
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument('img', help='Image file.')
+ parser.add_argument('config', help='Config file.')
+ parser.add_argument('checkpoint', help='Checkpoint file.')
+ parser.add_argument('save_path', help='Path to save visualized image.')
+ parser.add_argument(
+ '--device', default='cuda:0', help='Device used for inference.')
+ parser.add_argument(
+ '--imshow',
+ action='store_true',
+        help='Whether to show the image with OpenCV.')
+ args = parser.parse_args()
+
+ # build the model from a config file and a checkpoint file
+ model = init_detector(args.config, args.checkpoint, device=args.device)
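+    # if the test set is a ConcatDataset, reuse its first sub-dataset's pipeline for single-image inference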
+ if model.cfg.data.test['type'] == 'ConcatDataset':
+ model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
+ 0].pipeline
+
+ # test a single image
+ result = model_inference(model, args.img)
+ print(f'result: {result}')
+
+ # show the results
+ img = model.show_result(args.img, result, out_file=None, show=False)
+
+ mmcv.imwrite(img, args.save_path)
+ if args.imshow:
+ mmcv.imshow(img, 'predicted results')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/demo/webcam_demo.py b/demo/webcam_demo.py
new file mode 100644
index 00000000..45e23c5b
--- /dev/null
+++ b/demo/webcam_demo.py
@@ -0,0 +1,52 @@
+import argparse
+
+import cv2
+import torch
+
+from mmdet.apis import init_detector
+from mmocr.apis import model_inference
+from mmocr.datasets import build_dataset # noqa: F401
+from mmocr.models import build_detector # noqa: F401
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='MMOCR webcam demo.')
+ parser.add_argument('config', help='Test config file path.')
+ parser.add_argument('checkpoint', help='Checkpoint file.')
+ parser.add_argument(
+ '--device', type=str, default='cuda:0', help='CPU/CUDA device option.')
+ parser.add_argument(
+ '--camera-id', type=int, default=0, help='Camera device id.')
+ parser.add_argument(
+ '--score-thr', type=float, default=0.5, help='Bbox score threshold.')
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ device = torch.device(args.device)
+
+ model = init_detector(args.config, args.checkpoint, device=device)
+ if model.cfg.data.test['type'] == 'ConcatDataset':
+ model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
+ 0].pipeline
+
+ camera = cv2.VideoCapture(args.camera_id)
+
+ print('Press "Esc", "q" or "Q" to exit.')
+ while True:
+ ret_val, img = camera.read()
+ result = model_inference(model, img)
+
+ ch = cv2.waitKey(1)
+ if ch == 27 or ch == ord('q') or ch == ord('Q'):
+ break
+
+ model.show_result(
+ img, result, score_thr=args.score_thr, wait_time=1, show=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 00000000..afd8fe55
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,28 @@
+ARG PYTORCH="1.5"
+ARG CUDA="10.1"
+ARG CUDNN="7"
+
+FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
+
+ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX"
+ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
+ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
+
+RUN apt-get update && apt-get install -y git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+
+RUN conda clean --all
+RUN pip install mmcv-full==1.2.6+torch1.5.0+cu101 -f https://download.openmmlab.com/mmcv/dist/index.html
+
+RUN git clone https://github.com/open-mmlab/mmdetection.git /mmdet
+WORKDIR /mmdet
+RUN git checkout -b v2.9.0 v2.9.0
+RUN pip install -r requirements.txt
+RUN pip install .
+
+RUN git clone https://github.com/open-mmlab/mmocr.git /mmocr
+WORKDIR /mmocr
+ENV FORCE_CUDA="1"
+RUN pip install -r requirements.txt
+RUN pip install --no-cache-dir -e .
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 00000000..d4bb2cbb
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = .
+BUILDDIR = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/api.rst b/docs/api.rst
new file mode 100644
index 00000000..a23ab961
--- /dev/null
+++ b/docs/api.rst
@@ -0,0 +1,15 @@
+API Reference
+=============
+
+mmocr.apis
+-------------
+.. automodule:: mmocr.apis
+ :members:
+
+mmocr.core
+-------------
+
+evaluation
+^^^^^^^^^^
+.. automodule:: mmocr.core.evaluation
+ :members:
diff --git a/docs/changelog.md b/docs/changelog.md
new file mode 100644
index 00000000..8a802039
--- /dev/null
+++ b/docs/changelog.md
@@ -0,0 +1 @@
+## Changelog
diff --git a/docs/code_of_conduct.md b/docs/code_of_conduct.md
new file mode 100644
index 00000000..89039925
--- /dev/null
+++ b/docs/code_of_conduct.md
@@ -0,0 +1,93 @@
+
+# Contributor Covenant Code of Conduct
+
+
+- [Contributor Covenant Code of Conduct](#contributor-covenant-code-of-conduct)
+ - [Our Pledge](#our-pledge)
+ - [Our Standards](#our-standards)
+ - [Our Responsibilities](#our-responsibilities)
+ - [Scope](#scope)
+ - [Enforcement](#enforcement)
+ - [Attribution](#attribution)
+
+
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to making participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+ address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+
+## Scope
+
+This Code of Conduct applies both within project spaces and in public spaces
+when an individual is representing the project or its community. Examples of
+representing a project or community include using an official project e-mail
+address, posting via an official social media account, or acting as an appointed
+representative at an online or offline event. Representation of a project may be
+further defined and clarified by project maintainers.
+
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at chenkaidev@gmail.com. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 00000000..e709591e
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,83 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+
+import os
+import subprocess
+import sys
+
+sys.path.insert(0, os.path.abspath('..'))
+
+# -- Project information -----------------------------------------------------
+
+project = 'MMOCR'
+copyright = '2020-2030, OpenMMLab'
+author = 'OpenMMLab'
+
+# The full version, including alpha/beta/rc tags
+release = '0.1.0'
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.napoleon',
+ 'sphinx.ext.viewcode',
+ 'recommonmark',
+ 'sphinx_markdown_tables',
+]
+
+autodoc_mock_imports = ['torch', 'torchvision', 'mmcv', 'mmocr.version']
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# The suffix(es) of source filenames.
+# You can specify multiple suffix as a list of string:
+#
+source_suffix = {
+ '.rst': 'restructuredtext',
+ '.md': 'markdown',
+}
+
+# The master toctree document.
+master_doc = 'index'
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'sphinx_rtd_theme'
+
+master_doc = 'index'
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = []
+
+
+def builder_inited_handler(app):
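+    # merge the per-model README files into the docs and regenerate model zoo statistics before Sphinx builds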
+ subprocess.run(['./merge_docs.sh'])
+ subprocess.run(['./stats.py'])
+
+
+def setup(app):
+ app.connect('builder-inited', builder_inited_handler)
diff --git a/docs/contributing.md b/docs/contributing.md
new file mode 100644
index 00000000..2ac726b1
--- /dev/null
+++ b/docs/contributing.md
@@ -0,0 +1,187 @@
+
+# Contributing to mmocr
+
+All kinds of contributions are welcome, including but not limited to the following.
+
+- Fixes (typo, bugs)
+- New features and components
+- Enhancement like function speedup
+
+
+- [Contributing to mmocr](#contributing-to-mmocr)
+ - [Workflow](#workflow)
+ - [Step 1: Create a Fork](#step-1-create-a-fork)
+ - [Step 2: Develop a new feature](#step-2-develop-a-new-feature)
+ - [Step 2.1: Keep your fork up to date](#step-21-keep-your-fork-up-to-date)
+ - [Step 2.2: Create a feature branch](#step-22-create-a-feature-branch)
+ - [Create an issue on github](#create-an-issue-on-github)
+ - [Create branch](#create-branch)
+  - [Step 2.3: Develop and test <your_new_feature>](#step-23-develop-and-test-your_new_feature)
+ - [Step 2.4: Prepare to Pull Request](#step-24-prepare-to-pull-request)
+ - [Merge official repo updates to your fork](#merge-official-repo-updates-to-your-fork)
+  - [Push <your_new_feature> branch to your remote forked repo](#push-your_new_feature-branch-to-your-remote-forked-repo)
+ - [Step 2.5: Create a Pull Request](#step-25-create-a-pull-request)
+ - [Step 2.6: Review code](#step-26-review-code)
+  - [Step 2.7: Revise <your_new_feature> (optional)](#step-27-revise-your_new_feature--optional)
+  - [Step 2.8: Delete <your_new_feature> branch if your PR is accepted.](#step-28-delete-your_new_feature-branch-if-your-pr-is-accepted)
+ - [Code style](#code-style)
+ - [Python](#python)
+ - [C++ and CUDA](#c-and-cuda)
+
+
+
+## Workflow
+
+This document describes the fork & merge request workflow that should be used when contributing to **MMOCR**.
+
+The official public [repository](https://github.com/open-mmlab/mmocr) holds two branches with an infinite lifetime only:
++ master
++ develop
+
+The *master* branch is the main branch where the source code of **HEAD** always reflects a *production-ready state*.
+
+The *develop* branch is the branch where the source code of **HEAD** always reflects a state with the latest development changes for the next release.
+
+Feature branches are used to develop new features for the upcoming or a distant future release.
+
+![master-develop workflow](res/git-workflow-master-develop.png)
+
+All new contributors to **MMOCR** should follow the steps below:
+
+
+### Step 1: Create a Fork
+
+1. Fork the repo on GitHub or GitLab to your personal account. Click the `Fork` button on the [project page](https://github.com/open-mmlab/mmocr).
+
+2. Clone your new forked repo to your computer.
+```
+git clone https://github.com/<your_name>/mmocr.git
+```
+3. Add the official repo as an upstream:
+```
+git remote add upstream https://github.com/open-mmlab/mmocr.git
+```
+
+
+### Step 2: Develop a new feature
+
+
+#### Step 2.1: Keep your fork up to date
+
+Whenever you want to update your fork with the latest upstream changes, you need to fetch the upstream repo's branches and latest commits to bring them into your repository:
+
+```
+# Fetch from upstream remote
+git fetch upstream
+
+# Update your master branch
+git checkout master
+git rebase upstream/master
+git push origin master
+
+# Update your develop branch
+git checkout develop
+git rebase upstream/develop
+git push origin develop
+```
+
+
+#### Step 2.2: Create a feature branch
+
+##### Create an issue on [github](https://github.com/open-mmlab/mmocr)
+- The title of the issue should be one of the following formats: `[Feature]: xxx`, `[Fix]: xxx`, `[Enhance]: xxx`, `[Refactor]: xxx`.
+- More details can be written in comments.
+
+
+##### Create branch
+```
+git checkout -b feature/iss_<index> develop
+# <index> is the issue number from above
+```
+By now, your fork has three branches as follows:
+
+![feature branch workflow](res/git-workflow-feature.png)
+
+
+#### Step 2.3: Develop and test <your_new_feature>
+
+Develop your new feature and test it to make sure it works well.
+
+Please run
+```
+pre-commit run --all-files
+pytest tests
+```
+and fix all failures before every git commit.
+```
+git commit -m "fix #: "
+```
+**Note:**
+- `<issue_number>` is the number of the [issue](#step-22-create-a-feature-branch) created above.
+
+
+#### Step 2.4: Prepare to Pull Request
+- Make sure to link your pull request to the related issue. Please refer to the [instruction](https://docs.github.com/en/github/managing-your-work-on-github/linking-a-pull-request-to-an-issue)
+
+
+
+##### Merge official repo updates to your fork
+
+```
+# fetch from upstream remote. i.e., the official repo
+git fetch upstream
+
+# update the develop branch of your fork
+git checkout develop
+git rebase upstream/develop
+git push origin develop
+
+# update the <your_new_feature> branch
+git checkout <your_new_feature>
+git rebase develop
+# solve conflicts if any and test
+```
+
+
+##### Push <your_new_feature> branch to your remote forked repo
+```
+git checkout <your_new_feature>
+git push origin <your_new_feature>
+```
+
+#### Step 2.5: Create a Pull Request
+
+Go to the page for your fork on GitHub, select your new feature branch, and click the pull request button to integrate your feature branch into the upstream remote’s develop branch.
+
+
+#### Step 2.6: Review code
+
+
+
+#### Step 2.7: Revise <your_new_feature> (optional)
+If your PR is not accepted, please follow Steps 2.1, 2.3, 2.4 and 2.5 until it is accepted.
+
+
+#### Step 2.8: Delete <your_new_feature> branch if your PR is accepted.
+```
+git branch -d <your_new_feature>
+git push origin :<your_new_feature>
+```
+
+
+## Code style
+
+
+### Python
+We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.
+
+We use the following tools for linting and formatting:
+- [flake8](http://flake8.pycqa.org/en/latest/): linter
+- [yapf](https://github.com/google/yapf): formatter
+- [isort](https://github.com/timothycrosley/isort): sort imports
+
+>Before you create a PR, make sure that your code passes the lint checks and is formatted by yapf.
+
+
+### C++ and CUDA
+We follow the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
diff --git a/docs/datasets.md b/docs/datasets.md
new file mode 100644
index 00000000..da9e06ba
--- /dev/null
+++ b/docs/datasets.md
@@ -0,0 +1,208 @@
+
+# Datasets Preparation
+This page lists the datasets which are commonly used in text detection, text recognition and key information extraction, and their download links.
+
+
+- [Datasets Preparation](#datasets-preparation)
+ - [Text Detection](#text-detection)
+ - [Text Recognition](#text-recognition)
+ - [Key Information Extraction](#key-information-extraction)
+
+
+
+## Text Detection
+**The structure of the text detection dataset directory is organized as follows.**
+```
+├── ctw1500
+│  ├── imgs
+│  ├── instances_test.json
+│  └── instances_training.json
+├── icdar2015
+│  ├── imgs
+│  ├── instances_test.json
+│  └── instances_training.json
+├── icdar2017
+│  ├── imgs
+│  ├── instances_training.json
+│  └── instances_val.json
+├── synthtext
+│  ├── imgs
+│  └── instances_training.lmdb
+```
+| Dataset | Images | Annotation Files (training) | Annotation Files (validation) | Annotation Files (testing) | Note |
+|:---------:|:------:|:---------------------------:|:-----------------------------:|:---------------------------:|:----:|
+| CTW1500 | [homepage](https://github.com/Yuliang-Liu/Curve-Text-Detector) | [instances_training.json](https://download.openmmlab.com/mmocr/data/ctw1500/instances_training.json) | - | [instances_test.json](https://download.openmmlab.com/mmocr/data/ctw1500/instances_test.json) | |
+| ICDAR2015 | [homepage](https://rrc.cvc.uab.es/?ch=4&com=downloads) | [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_training.json) | - | [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_test.json) | |
+| ICDAR2017 | [homepage](https://rrc.cvc.uab.es/?ch=8&com=downloads) \| [renamed_imgs](https://download.openmmlab.com/mmocr/data/icdar2017/renamed_imgs.tar) | [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2017/instances_training.json) | [instances_val.json](https://download.openmmlab.com/mmocr/data/icdar2017/instances_val.json) | [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2017/instances_test.json) | |
+| Synthtext | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | [instances_training.lmdb](https://download.openmmlab.com/mmocr/data/synthtext/instances_training.lmdb) | - | - | |
+
+- For `icdar2015`:
+ - Step1: Download `ch4_training_images.zip` and `ch4_test_images.zip` from [homepage](https://rrc.cvc.uab.es/?ch=4&com=downloads)
+ - Step2: Download [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_training.json) and [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_test.json)
+ - Step3:
+ ```bash
+ mkdir icdar2015 && cd icdar2015
+ mv /path/to/instances_training.json .
+ mv /path/to/instances_test.json .
+
+ mkdir imgs && cd imgs
+ ln -s /path/to/ch4_training_images training
+ ln -s /path/to/ch4_test_images test
+ ```
+- For `icdar2017`:
+  - To avoid the rotation effect that occurs when loading `jpg` images with OpenCV, we provide re-saved images in `png` format in [renamed_images](https://download.openmmlab.com/mmocr/data/icdar2017/renamed_imgs.tar). You can copy these images to `imgs`.
+
+
+## Text Recognition
+**The structure of the text recognition dataset directory is organized as follows.**
+
+```
+├── mixture
+│  ├── coco_text
+│ │ ├── train_label.txt
+│ │ ├── train_words
+│  ├── icdar_2011
+│ │ ├── training_label.txt
+│ │ ├── Challenge1_Training_Task3_Images_GT
+│  ├── icdar_2013
+│ │ ├── train_label.txt
+│ │ ├── test_label_1015.txt
+│ │ ├── test_label_1095.txt
+│ │ ├── Challenge2_Training_Task3_Images_GT
+│ │ ├── Challenge2_Test_Task3_Images
+│  ├── icdar_2015
+│ │ ├── train_label.txt
+│ │ ├── test_label.txt
+│ │ ├── ch4_training_word_images_gt
+│ │ ├── ch4_test_word_images_gt
+│  ├── III5K
+│ │ ├── train_label.txt
+│ │ ├── test_label.txt
+│ │ ├── train
+│ │ ├── test
+│  ├── ct80
+│ │ ├── test_label.txt
+│ │ ├── image
+│  ├── svt
+│ │ ├── test_label.txt
+│ │ ├── image
+│  ├── svtp
+│ │ ├── test_label.txt
+│ │ ├── image
+│  ├── Synth90k
+│ │ ├── shuffle_labels.txt
+│ │ ├── label.lmdb
+│ │ ├── mnt
+│  ├── SynthText
+│ │ ├── shuffle_labels.txt
+│ │ ├── instances_train.txt
+│ │ ├── label.lmdb
+│ │ ├── synthtext
+│  ├── SynthAdd
+│ │ ├── label.txt
+│ │ ├── SynthText_Add
+
+```
+| Dataset | | Images | Annotation file (training) | Annotation file (test) | Note |
+|:----------:|:-:|:---------:|:--------------------------:|:----------------------:|:----:|
+| coco_text ||[homepage](https://rrc.cvc.uab.es/?ch=5&com=downloads) |[train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/coco_text/train_label.txt) |- | |
+| icdar_2011 ||[homepage](http://www.cvc.uab.es/icdar2011competition/?com=downloads) |[train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/train_label.txt) |- | |
+| icdar_2013 | | [homepage](https://rrc.cvc.uab.es/?ch=2&com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/train_label.txt) | [test_label_1015.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/test_label_1015.txt) | |
+| icdar_2015 | | [homepage](https://rrc.cvc.uab.es/?ch=4&com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/train_label.txt) | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/test_label.txt) | |
+| IIIT5K | | [homepage](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/train_label.txt) | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/test_label.txt) | |
+| ct80 | | - |-|[test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/ct80/test_label.txt)||
+| svt | | [homepage](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svt/test_label.txt) | |
+| svtp | | - | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svtp/test_label.txt) | |
+| Synth90k | | [homepage](https://www.robots.ox.ac.uk/~vgg/data/text/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/shuffle_labels.txt) \| [label.lmdb](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/label.lmdb) | - | |
+| SynthText | | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) \| [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) \| [label.lmdb](https://download.openmmlab.com/mmocr/data/mixture/SynthText/label.lmdb) | - | |
+| SynthAdd | | [SynthText_Add.zip](https://pan.baidu.com/s/1uV0LtoNmcxbO-0YA7Ch4dg) (code:627x) | [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt)|- | |
+
+- For `icdar_2013`:
+ - Step1: Download `Challenge2_Test_Task3_Images.zip` and `Challenge2_Training_Task3_Images_GT.zip` from [homepage](https://rrc.cvc.uab.es/?ch=2&com=downloads)
+ - Step2: Download [test_label_1015.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/test_label_1015.txt) and [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/train_label.txt)
+- For `icdar_2015`:
+ - Step1: Download `ch4_training_word_images_gt.zip` and `ch4_test_word_images_gt.zip` from [homepage](https://rrc.cvc.uab.es/?ch=4&com=downloads)
+ - Step2: Download [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/train_label.txt) and [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/test_label.txt)
+- For `IIIT5K`:
+ - Step1: Download `IIIT5K-Word_V3.0.tar.gz` from [homepage](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html)
+ - Step2: Download [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/train_label.txt) and [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/test_label.txt)
+- For `svt`:
+  - Step1: Download `svt.zip` from [homepage](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset)
+ - Step2: Download [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svt/test_label.txt)
+- For `ct80`:
+ - Step1: Download [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/ct80/test_label.txt)
+- For `svtp`:
+ - Step1: Download [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svtp/test_label.txt)
+- For `coco_text`:
+ - Step1: Download from [homepage](https://rrc.cvc.uab.es/?ch=5&com=downloads)
+ - Step2: Download [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/coco_text/train_label.txt)
+
+- For `Syn90k`:
+ - Step1: Download `mjsynth.tar.gz` from [homepage](https://www.robots.ox.ac.uk/~vgg/data/text/)
+ - Step2: Download [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/shuffle_labels.txt)
+ - Step3:
+ ```bash
+ mkdir Syn90k && cd Syn90k
+
+ mv /path/to/mjsynth.tar.gz .
+
+ tar -xzf mjsynth.tar.gz
+
+ mv /path/to/shuffle_labels.txt .
+
+ # create soft link
+ cd /path/to/mmocr/data/mixture
+
+ ln -s /path/to/Syn90k Syn90k
+ ```
+- For `SynthText`:
+ - Step1: Download `SynthText.zip` from [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/)
+ - Step2: Download [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt)
+ - Step3: Download [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt)
+ - Step4:
+ ```bash
+ unzip SynthText.zip
+
+ cd SynthText
+
+ mv /path/to/shuffle_labels.txt .
+
+ # create soft link
+ cd /path/to/mmocr/data/mixture
+
+ ln -s /path/to/SynthText SynthText
+ ```
+- For `SynthAdd`:
+  - Step1: Download `SynthText_Add.zip` from [SynthAdd](https://pan.baidu.com/s/1uV0LtoNmcxbO-0YA7Ch4dg) (code:627x)
+ - Step2: Download [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt)
+ - Step3:
+ ```bash
+ mkdir SynthAdd && cd SynthAdd
+
+ mv /path/to/SynthText_Add.zip .
+
+ unzip SynthText_Add.zip
+
+ mv /path/to/label.txt .
+
+ # create soft link
+ cd /path/to/mmocr/data/mixture
+
+ ln -s /path/to/SynthAdd SynthAdd
+ ```
+
+
+## Key Information Extraction
+**The structure of the key information extraction dataset directory is organized as follows.**
+```
+└── wildreceipt
+ ├── anno_files
+ ├── class_list.txt
+ ├── dict.txt
+ ├── image_files
+ ├── test.txt
+ └── train.txt
+```
+- Download [wildreceipt.tar](https://download.openmmlab.com/mmocr/data/wildreceipt.tar)
diff --git a/docs/getting_started.md b/docs/getting_started.md
new file mode 100644
index 00000000..20c45a92
--- /dev/null
+++ b/docs/getting_started.md
@@ -0,0 +1,369 @@
+
+# Getting Started
+
+This page provides basic tutorials on the usage of MMOCR.
+For the installation instructions, please see [install.md](install.md).
+
+
+- [Getting Started](#getting-started)
+ - [Inference with Pretrained Models](#inference-with-pretrained-models)
+ - [Test a Single Image](#test-a-single-image)
+ - [Test Multiple Images](#test-multiple-images)
+ - [Test a Dataset](#test-a-dataset)
+ - [Test with Single/Multiple GPUs](#test-with-singlemultiple-gpus)
+ - [Optional Arguments](#optional-arguments)
+ - [Test with Slurm](#test-with-slurm)
+ - [Optional Arguments](#optional-arguments-1)
+ - [Train a Model](#train-a-model)
+ - [Train with Single/Multiple GPUs](#train-with-singlemultiple-gpus)
+ - [Train with Toy Dataset.](#train-with-toy-dataset)
+ - [Train with Slurm](#train-with-slurm)
+ - [Launch Multiple Jobs on a Single Machine](#launch-multiple-jobs-on-a-single-machine)
+ - [Useful Tools](#useful-tools)
+ - [Publish a Model](#publish-a-model)
+ - [Customized Settings](#customized-settings)
+ - [Flexible Dataset](#flexible-dataset)
+ - [Encoder-Decoder-Based Text Recognition Task](#encoder-decoder-based-text-recognition-task)
+ - [Optional Arguments:](#optional-arguments-2)
+ - [Segmentation-Based Text Recognition Task](#segmentation-based-text-recognition-task)
+ - [Text Detection Task](#text-detection-task)
+ - [COCO-like Dataset](#coco-like-dataset)
+
+
+
+
+## Inference with Pretrained Models
+
+We provide testing scripts to evaluate a full dataset, as well as some task-specific image demos.
+
+
+### Test a Single Image
+
+You can use the following command to test a single image with one GPU.
+
+```shell
+python demo/image_demo.py ${TEST_IMG} ${CONFIG_FILE} ${CHECKPOINT_FILE} ${SAVE_PATH} [--imshow] [--device ${GPU_ID}]
+```
+
+If `--imshow` is specified, the demo will also show the image with OpenCV. For example:
+
+```shell
+python demo/image_demo.py demo/demo_text_det.jpg configs/xxx.py xxx.pth demo/demo_text_det_pred.jpg
+```
+
+The predicted result will be saved as `demo/demo_text_det_pred.jpg`.
+
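+The same thing can be done from the Python API that `demo/image_demo.py` uses under the hood. A minimal sketch (the config and checkpoint paths are placeholders, replace them with real files):
+
+```python
+from mmdet.apis import init_detector
+from mmocr.apis import model_inference
+
+# placeholder paths: use your own config and downloaded checkpoint
+model = init_detector('configs/xxx.py', 'xxx.pth', device='cuda:0')
+
+# a ConcatDataset test set needs the pipeline of its first sub-dataset
+if model.cfg.data.test['type'] == 'ConcatDataset':
+    model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][0].pipeline
+
+result = model_inference(model, 'demo/demo_text_det.jpg')
+print(result)
+```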
+
+### Test Multiple Images
+
+```shell
+# for text detection
+sh tools/test_imgs.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${IMG_ROOT_PATH} ${IMG_LIST} ${RESULTS_DIR}
+
+# for text recognition
+sh tools/ocr_test_imgs.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${IMG_ROOT_PATH} ${IMG_LIST} ${RESULTS_DIR}
+```
+It will save both the prediction results and visualized images to `${RESULTS_DIR}`.
+
+
+### Test a Dataset
+
+MMOCR implements **distributed** testing with `MMDistributedDataParallel`. (Please refer to [datasets.md](datasets.md) to prepare your datasets)
+
+
+#### Test with Single/Multiple GPUs
+
+You can use the following command to test a dataset with single/multiple GPUs.
+
+```shell
+./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--eval ${EVAL_METRIC}]
+```
+For example,
+
+```shell
+./tools/dist_test.sh configs/example_config.py work_dirs/example_exp/example_model_20200202.pth 1 --eval hmean-iou
+```
+
+##### Optional Arguments
+
+- `--eval`: Specify the evaluation metric. For text detection, the metric should be either 'hmean-ic13' or 'hmean-iou'. For text recognition, the metric should be 'acc'.
+
+
+#### Test with Slurm
+
+If you run MMOCR on a cluster managed with [Slurm](https://slurm.schedmd.com/), you can use the script `slurm_test.sh`.
+
+```shell
+[GPUS=${GPUS}] ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${CHECKPOINT_FILE} [--eval ${EVAL_METRIC}]
+```
+Here is an example of using 8 GPUs to test an example model on the 'dev' partition with job name 'test_job'.
+
+```shell
+GPUS=8 ./tools/slurm_test.sh dev test_job configs/example_config.py work_dirs/example_exp/example_model_20200202.pth --eval hmean-iou
+```
+
+You can check [slurm_test.sh](https://github.com/open-mmlab/mmocr/blob/master/tools/slurm_test.sh) for full arguments and environment variables.
+
+
+
+##### Optional Arguments
+
+- `--eval`: Specify the evaluation metric. For text detection, the metric should be either 'hmean-ic13' or 'hmean-iou'. For text recognition, the metric should be 'acc'.
+
+
+
+## Train a Model
+
+MMOCR implements **distributed** training with `MMDistributedDataParallel`. (Please refer to [datasets.md](datasets.md) to prepare your datasets)
+
+All outputs (log files and checkpoints) will be saved to a working directory specified by `work_dir` in the config file.
+
+By default, we evaluate the model on the validation set at regular intervals during training. You can change the evaluation interval by adding the `interval` argument in the training config as follows:
+```python
+evaluation = dict(interval=1, by_epoch=True) # This evaluates the model per epoch.
+```
+
+
+
+### Train with Single/Multiple GPUs
+
+```shell
+./tools/dist_train.sh ${CONFIG_FILE} ${WORK_DIR} ${GPU_NUM} [optional arguments]
+```
+
+Optional Arguments:
+
+- `--no-validate` (**not suggested**): By default, the codebase will perform evaluation at every k-th iteration during training. To disable this behavior, use `--no-validate`.
+
+
+#### Train with Toy Dataset.
+We provide a toy dataset under `tests/data`, so you can train a toy model directly before the full academic datasets are prepared.
+
+For example, train a text recognition task with `seg` method and toy dataset,
+```
+./tools/dist_train.sh configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py work_dirs/seg 1
+```
+
+And train a text recognition task with `sar` method and toy dataset,
+```
+./tools/dist_train.sh configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py work_dirs/sar 1
+```
+
+
+### Train with Slurm
+
+If you run MMOCR on a cluster managed with [Slurm](https://slurm.schedmd.com/), you can use the script `slurm_train.sh`.
+
+```shell
+[GPUS=${GPUS}] ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR}
+```
+
+Here is an example of using 8 GPUs to train a text detection model on the dev partition.
+
+```shell
+GPUS=8 ./tools/slurm_train.sh dev psenet-ic15 configs/textdet/psenet/psenet_r50_fpnf_sbn_1x_icdar2015.py /nfs/xxxx/psenet-ic15
+```
+
+You can check [slurm_train.sh](https://github.com/open-mmlab/mmocr/blob/master/tools/slurm_train.sh) for full arguments and environment variables.
+
+
+### Launch Multiple Jobs on a Single Machine
+
+If you launch multiple jobs on a single machine, e.g., 2 jobs of 4-GPU training on a machine with 8 GPUs,
+you need to specify different ports (29500 by default) for each job to avoid communication conflicts.
+
+If you use `dist_train.sh` to launch training jobs, you can set the ports in the command shell.
+
+```shell
+CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} ${WORK_DIR} 4
+CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} ${WORK_DIR} 4
+```
+
+If you launch training jobs with Slurm, you need to modify the config files to set different communication ports.
+
+In `config1.py`,
+```python
+dist_params = dict(backend='nccl', port=29500)
+```
+
+In `config2.py`,
+```python
+dist_params = dict(backend='nccl', port=29501)
+```
+
+Then you can launch two jobs with `config1.py` and `config2.py`.
+
+```shell
+CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR}
+CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR}
+```
+
+
+
+## Useful Tools
+
+We provide numerous useful tools under the `mmocr/tools` directory.
+
+
+### Publish a Model
+
+Before you upload a model to AWS, you may want to
+(1) convert the model weights to CPU tensors, (2) delete the optimizer states and
+(3) compute the hash of the checkpoint file and append the hash id to the filename.
+
+```shell
+python tools/publish_model.py ${INPUT_FILENAME} ${OUTPUT_FILENAME}
+```
+
+E.g.,
+
+```shell
+python tools/publish_model.py work_dirs/psenet/latest.pth psenet_r50_fpnf_sbn_1x_20190801.pth
+```
+
+The final output filename will be `psenet_r50_fpnf_sbn_1x_20190801-{hash id}.pth`.
+
+
+## Customized Settings
+
+
+### Flexible Dataset
+To support the tasks of `text detection`, `text recognition` and `key information extraction`, we have designed a new type of dataset which consists of `loader` and `parser` to load and parse different types of annotation files.
+- **loader**: Load the annotation file. There are two types of loader, `HardDiskLoader` and `LmdbLoader`
+ - `HardDiskLoader`: Load `txt` format annotation file from hard disk to memory.
+  - `LmdbLoader`: Load `lmdb` format annotation files with the lmdb backend. It is very useful for **extremely large** annotation files, since it avoids the out-of-memory problem that arises when ten or more GPUs are used and each GPU starts multiple worker processes that would otherwise each load the whole annotation file into memory.
+- **parser**: Parse the annotation file line-by-line and return with `dict` format. There are two types of parser, `LineStrParser` and `LineJsonParser`.
+ - `LineStrParser`: Parse one line in ann file while treating it as a string and separating it to several parts by a `separator`. It can be used on tasks with simple annotation files such as text recognition where each line of the annotation files contains the `filename` and `label` attribute only.
+ - `LineJsonParser`: Parse one line in ann file while treating it as a json-string and using `json.loads` to convert it to `dict`. It can be used on tasks with complex annotation files such as text detection where each line of the annotation files contains multiple attributes (e.g. `filename`, `height`, `width`, `box`, `segmentation`, `iscrowd`, `category_id`, etc.).
+
+Here we show some examples of using different combinations of `loader` and `parser`.
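+
+Before the full dataset configs, here is a rough sketch of what the two parsers conceptually do with a single annotation line (an illustration of the idea only, not the actual implementation):
+
+```python
+import json
+
+# LineStrParser: split a plain-text line by `separator` and pick fields by `keys_idx`
+line = '1223731.jpg GRAND'
+keys, keys_idx, separator = ['filename', 'text'], [0, 1], ' '
+parts = line.split(separator)
+str_sample = {key: parts[idx] for key, idx in zip(keys, keys_idx)}
+# str_sample == {'filename': '1223731.jpg', 'text': 'GRAND'}
+
+# LineJsonParser: decode a json line and keep the requested keys
+json_line = '{"file_name": "resort_88_101_1.png", "annotations": [], "text": "From:"}'
+ann = json.loads(json_line)
+json_sample = {key: ann[key] for key in ['file_name', 'annotations', 'text']}
+```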
+
+
+#### Encoder-Decoder-Based Text Recognition Task
+```python
+dataset_type = 'OCRDataset'
+img_prefix = 'tests/data/ocr_toy_dataset/imgs'
+train_anno_file = 'tests/data/ocr_toy_dataset/label.txt'
+train = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=train_anno_file,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=10,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+```
+You can check the content of the annotation file in `tests/data/ocr_toy_dataset/label.txt`.
+The combination of `HardDiskLoader` and `LineStrParser` will return a dict for each file by calling `__getitem__`: `{'filename': '1223731.jpg', 'text': 'GRAND'}`.
+
+
+##### Optional Arguments:
+
+- `repeat`: The number of times the annotation lines are repeated. For example, if there are `10` lines in the annotation file, setting `repeat=10` makes the dataset yield `100` samples per epoch, as sketched below.
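+
+A rough illustration of the effect (illustration only, not the actual loader code):
+
+```python
+# pretend these are the 10 lines of the annotation file
+lines = [f'{i:07d}.jpg WORD{i}' for i in range(10)]
+repeat = 10
+samples = lines * repeat
+assert len(samples) == 100  # the dataset now yields 100 samples per epoch
+```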
+
+If the annotation file is extremely large, you can convert it from `txt` format to `lmdb` format with the following command:
+```shell
+python tools/data_converter/txt2lmdb.py -i ann_file.txt -o ann_file.lmdb
+```
+
+After that, you can use `LmdbLoader` in the dataset config as below.
+```python
+img_prefix = 'tests/data/ocr_toy_dataset/imgs'
+train_anno_file = 'tests/data/ocr_toy_dataset/label.lmdb'
+train = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=train_anno_file,
+ loader=dict(
+ type='LmdbLoader',
+ repeat=10,
+ parser=dict(
+ type='LineStrParser',
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' ')),
+ pipeline=train_pipeline,
+ test_mode=False)
+```
+
+
+#### Segmentation-Based Text Recognition Task
+```python
+prefix = 'tests/data/ocr_char_ann_toy_dataset/'
+train = dict(
+ type='OCRSegDataset',
+ img_prefix=prefix + 'imgs',
+ ann_file=prefix + 'instances_train.txt',
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=10,
+ parser=dict(
+ type='LineJsonParser',
+ keys=['file_name', 'annotations', 'text'])),
+ pipeline=train_pipeline,
+ test_mode=True)
+```
+You can check the content of the annotation file in `tests/data/ocr_char_ann_toy_dataset/instances_train.txt`.
+The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for each file by calling `__getitem__` each time:
+```python
+{"file_name": "resort_88_101_1.png", "annotations": [{"char_text": "F", "char_box": [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]}, {"char_text": "r", "char_box": [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0]}, {"char_text": "o", "char_box": [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0]}, {"char_text": "m", "char_box": [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0]}, {"char_text": ":", "char_box": [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0]}], "text": "From:"}
+```
+
+
+#### Text Detection Task
+```python
+dataset_type = 'TextDetDataset'
+img_prefix = 'tests/data/toy_dataset/imgs'
+test_anno_file = 'tests/data/toy_dataset/instances_test.txt'
+test = dict(
+ type=dataset_type,
+ img_prefix=img_prefix,
+ ann_file=test_anno_file,
+ loader=dict(
+ type='HardDiskLoader',
+ repeat=4,
+ parser=dict(
+ type='LineJsonParser',
+ keys=['file_name', 'height', 'width', 'annotations'])),
+ pipeline=test_pipeline,
+ test_mode=True)
+```
+The results are generated in the same way as the segmentation-based text recognition task above.
+You can check the content of the annotation file in `tests/data/toy_dataset/instances_test.txt`.
+The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for each file by calling `__getitem__`:
+```python
+{"file_name": "test/img_10.jpg", "height": 720, "width": 1280, "annotations": [{"iscrowd": 1, "category_id": 1, "bbox": [260.0, 138.0, 24.0, 20.0], "segmentation": [[261, 138, 284, 140, 279, 158, 260, 158]]}, {"iscrowd": 0, "category_id": 1, "bbox": [288.0, 138.0, 129.0, 23.0], "segmentation": [[288, 138, 417, 140, 416, 161, 290, 157]]}, {"iscrowd": 0, "category_id": 1, "bbox": [743.0, 145.0, 37.0, 18.0], "segmentation": [[743, 145, 779, 146, 780, 163, 746, 163]]}, {"iscrowd": 0, "category_id": 1, "bbox": [783.0, 129.0, 50.0, 26.0], "segmentation": [[783, 129, 831, 132, 833, 155, 785, 153]]}, {"iscrowd": 1, "category_id": 1, "bbox": [831.0, 133.0, 43.0, 23.0], "segmentation": [[831, 133, 870, 135, 874, 156, 835, 155]]}, {"iscrowd": 1, "category_id": 1, "bbox": [159.0, 204.0, 72.0, 15.0], "segmentation": [[159, 205, 230, 204, 231, 218, 159, 219]]}, {"iscrowd": 1, "category_id": 1, "bbox": [785.0, 158.0, 75.0, 21.0], "segmentation": [[785, 158, 856, 158, 860, 178, 787, 179]]}, {"iscrowd": 1, "category_id": 1, "bbox": [1011.0, 157.0, 68.0, 16.0], "segmentation": [[1011, 157, 1079, 160, 1076, 173, 1011, 170]]}]}
+```
+
+
+
+### COCO-like Dataset
+For text detection, you can also use an annotation file in a COCO format that is defined in [mmdet](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/coco.py):
+```python
+dataset_type = 'IcdarDataset'
+prefix = 'tests/data/toy_dataset/'
+test=dict(
+ type=dataset_type,
+ ann_file=prefix + 'instances_test.json',
+ img_prefix=prefix + 'imgs',
+ pipeline=test_pipeline)
+```
+You can check the content of the annotation file in `tests/data/toy_dataset/instances_test.json`
+- The icdar2015/2017 annotations have to be converted into the COCO format using `tools/data_converter/icdar_converter.py`:
+
+ ```shell
+ python tools/data_converter/icdar_converter.py ${src_root_path} -o ${out_path} -d ${data_type} --split-list training validation test
+ ```
+
+- The ctw1500 annotations have to be converted into the COCO format using `tools/data_converter/ctw1500_converter.py`:
+
+ ```shell
+ python tools/data_converter/ctw1500_converter.py ${src_root_path} -o ${out_path} --split-list training test
+ ```
diff --git a/docs/index.rst b/docs/index.rst
new file mode 100644
index 00000000..8242dffc
--- /dev/null
+++ b/docs/index.rst
@@ -0,0 +1,38 @@
+Welcome to MMOCR's documentation!
+=======================================
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Get Started
+
+ install.md
+ getting_started.md
+ technical_details.md
+ contributing.md
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Model Zoo
+
+ modelzoo.md
+ textdet_models.md
+ textrecog_models.md
+ kie_models.md
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Notes
+
+ changelog.md
+ faq.md
+
+.. toctree::
+ :caption: API Reference
+
+ api.rst
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`search`
diff --git a/docs/install.md b/docs/install.md
new file mode 100644
index 00000000..78d1b0be
--- /dev/null
+++ b/docs/install.md
@@ -0,0 +1,249 @@
+
+# Installation
+
+
+- [Installation](#installation)
+ - [Prerequisites](#prerequisites)
+ - [Step-by-Step Installation Instructions](#step-by-step-installation-instructions)
+ - [Full Set-up Script](#full-set-up-script)
+ - [Another option: Docker Image](#another-option-docker-image)
+ - [Prepare Datasets](#prepare-datasets)
+
+
+
+## Prerequisites
+
+- Linux (Windows is not officially supported)
+- Python 3.7
+- PyTorch 1.5 or higher
+- torchvision 0.6.0
+- CUDA 10.1
+- NCCL 2
+- GCC 5.4.0 or higher
+- [mmcv](https://github.com/open-mmlab/mmcv) 1.2.6
+
+We have tested the following versions of OS and software:
+
+- OS: Ubuntu 16.04
+- CUDA: 10.1
+- GCC(G++): 5.4.0
+- mmcv 1.2.6
+- PyTorch 1.5
+- torchvision 0.6.0
+
+MMOCR depends on PyTorch and MMDetection v2.9.0.
+
+
+## Step-by-Step Installation Instructions
+
+a. Create a conda virtual environment and activate it.
+
+```shell
+conda create -n open-mmlab python=3.7 -y
+conda activate open-mmlab
+```
+
+b. Install PyTorch and torchvision following the [official instructions](https://pytorch.org/), e.g.,
+
+```shell
+conda install pytorch==1.5.0 torchvision==0.6.0 cudatoolkit=10.1 -c pytorch
+```
+Note: Make sure that your compilation CUDA version and runtime CUDA version match.
+You can check the supported CUDA version for precompiled packages on the [PyTorch website](https://pytorch.org/).
+
+`E.g. 1` If you have CUDA 10.1 installed under `/usr/local/cuda` and would like to install
+PyTorch 1.5, you need to install the prebuilt PyTorch with CUDA 10.1.
+
+```shell
+conda install pytorch cudatoolkit=10.1 torchvision -c pytorch
+```
+
+`E.g. 2` If you have CUDA 9.2 installed under `/usr/local/cuda` and would like to install
+PyTorch 1.3.1., you need to install the prebuilt PyTorch with CUDA 9.2.
+
+```shell
+conda install pytorch=1.3.1 cudatoolkit=9.2 torchvision=0.4.2 -c pytorch
+```
+
+If you build PyTorch from source instead of installing the prebuilt package,
+you can use more CUDA versions such as 9.0.
+
+c. Create a folder called `code` and clone the mmcv repository into it.
+
+```shell
+mkdir code
+cd code
+git clone https://github.com/open-mmlab/mmcv.git
+cd mmcv
+git checkout -b v1.2.6 v1.2.6
+pip install -r requirements.txt
+MMCV_WITH_OPS=1 pip install -v -e .
+```
+
+d. Clone the mmdetection repository into the `code` folder. The mmdetection repo sits alongside (not inside) the mmcv repo in `code`.
+
+```shell
+cd ..
+git clone https://github.com/open-mmlab/mmdetection.git
+cd mmdetection
+git checkout -b v2.9.0 v2.9.0
+pip install -r requirements.txt
+pip install -v -e .
+export PYTHONPATH=$(pwd):$PYTHONPATH
+```
+
+Note that we have tested mmdetection v2.9.0 only. Other versions might be incompatible.
+
+e. Clone the mmocr repository into the `code` folder. The mmocr repo sits alongside the mmcv and mmdetection repos in `code`.
+
+```shell
+cd ..
+git clone https://github.com/open-mmlab/mmocr.git
+cd mmocr
+```
+
+f. Install build requirements and then install MMOCR.
+
+```shell
+pip install -r requirements.txt
+pip install -v -e . # or "python setup.py build_ext --inplace"
+export PYTHONPATH=$(pwd):$PYTHONPATH
+```
+
+
+## Full Set-up Script
+
+Here is the full script for setting up mmocr with conda.
+
+```shell
+conda create -n open-mmlab python=3.7 -y
+conda activate open-mmlab
+
+# install PyTorch 1.5.0 and torchvision 0.6.0 prebuilt with CUDA 10.1
+conda install pytorch==1.5.0 torchvision==0.6.0 cudatoolkit=10.1 -c pytorch
+
+# install mmcv
+mkdir code
+cd code
+git clone https://github.com/open-mmlab/mmcv.git
+cd mmcv # code/mmcv
+git checkout -b v1.2.6 v1.2.6
+pip install -r requirements.txt
+MMCV_WITH_OPS=1 pip install -v -e .
+
+# install mmdetection
+cd .. # exit to code
+git clone https://github.com/open-mmlab/mmdetection.git
+cd mmdetection # code/mmdetection
+git checkout -b v2.9.0 v2.9.0
+pip install -r requirements.txt
+pip install -v -e .
+export PYTHONPATH=$(pwd):$PYTHONPATH
+
+# install mmocr
+cd ..
+git clone https://github.com/open-mmlab/mmocr.git
+cd mmocr # code/mmocr
+
+pip install -r requirements.txt
+pip install -v -e . # or "python setup.py build_ext --inplace"
+export PYTHONPATH=$(pwd):$PYTHONPATH
+```
+
+
+## Another option: Docker Image
+
+We provide a [Dockerfile](https://github.com/open-mmlab/mmocr/blob/master/docker/Dockerfile) to build an image.
+
+```shell
+# build an image with PyTorch 1.5, CUDA 10.1
+docker build -t mmocr docker/
+```
+
+Run it with
+
+```shell
+docker run --gpus all --shm-size=8g -it -v {DATA_DIR}:/mmocr/data mmocr
+```
+
+
+## Prepare Datasets
+
+It is recommended to symlink the dataset root to `mmocr/data`. Please refer to [datasets.md](datasets.md) to prepare your datasets.
+If your folder structure is different, you may need to change the corresponding paths in config files.
+
+The `mmocr` folder is organized as follows:
+```
+mmocr
+.
+├── configs
+│  ├── _base_
+│  ├── kie
+│  ├── textdet
+│  └── textrecog
+├── demo
+│  ├── demo_text_det.jpg
+│  ├── demo_text_recog.jpg
+│  ├── image_demo.py
+│  └── webcam_demo.py
+├── docs
+│  ├── api.rst
+│  ├── changelog.md
+│  ├── code_of_conduct.md
+│  ├── conf.py
+│  ├── contributing.md
+│  ├── datasets.md
+│  ├── getting_started.md
+│  ├── index.rst
+│  ├── install.md
+│  ├── make.bat
+│  ├── Makefile
+│  ├── merge_docs.sh
+│  ├── requirements.txt
+│  ├── res
+│  ├── stats.py
+│  └── technical_details.md
+├── LICENSE
+├── mmocr
+│  ├── apis
+│  ├── core
+│  ├── datasets
+│  ├── __init__.py
+│  ├── models
+│  ├── utils
+│  └── version.py
+├── README.md
+├── requirements
+│  ├── build.txt
+│  ├── docs.txt
+│  ├── optional.txt
+│  ├── readthedocs.txt
+│  ├── runtime.txt
+│  └── tests.txt
+├── requirements.txt
+├── resources
+│  ├── illustration.jpg
+│  └── mmocr-logo.png
+├── setup.cfg
+├── setup.py
+├── tests
+│  ├── data
+│  ├── test_dataset
+│  ├── test_metrics
+│  ├── test_models
+│  ├── test_tools
+│  └── test_utils
+└── tools
+ ├── data
+ ├── dist_test.sh
+ ├── dist_train.sh
+ ├── ocr_test_imgs.py
+ ├── ocr_test_imgs.sh
+ ├── publish_model.py
+ ├── slurm_test.sh
+ ├── slurm_train.sh
+ ├── test_imgs.py
+ ├── test_imgs.sh
+ ├── test.py
+ └── train.py
+```
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 00000000..8a3a0e25
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,36 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=.
+set BUILDDIR=_build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/merge_docs.sh b/docs/merge_docs.sh
new file mode 100755
index 00000000..2482fb74
--- /dev/null
+++ b/docs/merge_docs.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+
+sed -i '$a\\n' ../configs/kie/*/*.md
+sed -i '$a\\n' ../configs/textdet/*/*.md
+sed -i '$a\\n' ../configs/textrecog/*/*.md
+
+# gather models
+cat ../configs/kie/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Kie Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >kie_models.md
+cat ../configs/textdet/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Text Detection Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >textdet_models.md
+cat ../configs/textrecog/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Text Recognition Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >textrecog_models.md
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 00000000..89fbf86c
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,4 @@
+recommonmark
+sphinx
+sphinx_markdown_tables
+sphinx_rtd_theme
diff --git a/docs/res/git-workflow-feature.png b/docs/res/git-workflow-feature.png
new file mode 100644
index 00000000..4d9f9083
Binary files /dev/null and b/docs/res/git-workflow-feature.png differ
diff --git a/docs/res/git-workflow-master-develop.png b/docs/res/git-workflow-master-develop.png
new file mode 100644
index 00000000..624111c8
Binary files /dev/null and b/docs/res/git-workflow-master-develop.png differ
diff --git a/docs/stats.py b/docs/stats.py
new file mode 100755
index 00000000..ef337ba5
--- /dev/null
+++ b/docs/stats.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+import functools as func
+import glob
+import re
+from os.path import basename, splitext
+
+import numpy as np
+import titlecase
+
+
+def anchor(name):
+ return re.sub(r'-+', '-', re.sub(r'[^a-zA-Z0-9]', '-',
+ name.strip().lower())).strip('-')
+
+
+# Count algorithms
+
+files = sorted(glob.glob('*_models.md'))
+# files = sorted(glob.glob('docs/*_models.md'))
+
+stats = []
+
+for f in files:
+ with open(f, 'r') as content_file:
+ content = content_file.read()
+
+ # title
+ title = content.split('\n')[0].replace('#', '')
+
+ # count papers
+ papers = set((papertype, titlecase.titlecase(paper.lower().strip()))
+ for (papertype, paper) in re.findall(
+ r'\n\s*\[([A-Z]+?)\]\s*\n.*?\btitle\s*=\s*{(.*?)}',
+ content, re.DOTALL))
+ # paper links
+ revcontent = '\n'.join(list(reversed(content.splitlines())))
+ paperlinks = {}
+ for _, p in papers:
+ print(p)
+ q = p.replace('\\', '\\\\').replace('?', '\\?')
+ paperlinks[p] = ' '.join(
+ (f'[⇨]({splitext(basename(f))[0]}.html#{anchor(paperlink)})'
+ for paperlink in re.findall(
+ rf'\btitle\s*=\s*{{\s*{q}\s*}}.*?\n## (.*?)\s*[,;]?\s*\n',
+ revcontent, re.DOTALL | re.IGNORECASE)))
+ print(' ', paperlinks[p])
+ paperlist = '\n'.join(
+ sorted(f' - [{t}] {x} ({paperlinks[x]})' for t, x in papers))
+ # count configs
+ configs = set(x.lower().strip()
+ for x in re.findall(r'https.*configs/.*\.py', content))
+
+ # count ckpts
+ ckpts = set(x.lower().strip()
+ for x in re.findall(r'https://download.*\.pth', content)
+ if 'mmocr' in x)
+
+ statsmsg = f"""
+## [{title}]({f})
+
+* Number of checkpoints: {len(ckpts)}
+* Number of configs: {len(configs)}
+* Number of papers: {len(papers)}
+{paperlist}
+
+ """
+
+ stats.append((papers, configs, ckpts, statsmsg))
+
+allpapers = func.reduce(lambda a, b: a.union(b), [p for p, _, _, _ in stats])
+allconfigs = func.reduce(lambda a, b: a.union(b), [c for _, c, _, _ in stats])
+allckpts = func.reduce(lambda a, b: a.union(b), [c for _, _, c, _ in stats])
+msglist = '\n'.join(x for _, _, _, x in stats)
+
+papertypes, papercounts = np.unique([t for t, _ in allpapers],
+ return_counts=True)
+countstr = '\n'.join(
+ [f' - {t}: {c}' for t, c in zip(papertypes, papercounts)])
+
+modelzoo = f"""
+# Overview
+
+* Number of checkpoints: {len(allckpts)}
+* Number of configs: {len(allconfigs)}
+* Number of papers: {len(allpapers)}
+{countstr}
+
+For supported datasets, see [datasets overview](datasets.md).
+
+{msglist}
+"""
+
+with open('modelzoo.md', 'w') as f:
+ f.write(modelzoo)
diff --git a/mmocr/__init__.py b/mmocr/__init__.py
new file mode 100644
index 00000000..1c4f7e8f
--- /dev/null
+++ b/mmocr/__init__.py
@@ -0,0 +1,3 @@
+from .version import __version__, short_version
+
+__all__ = ['__version__', 'short_version']
diff --git a/mmocr/apis/__init__.py b/mmocr/apis/__init__.py
new file mode 100644
index 00000000..1c4b014c
--- /dev/null
+++ b/mmocr/apis/__init__.py
@@ -0,0 +1,3 @@
+from .inference import model_inference
+
+__all__ = ['model_inference']
diff --git a/mmocr/apis/inference.py b/mmocr/apis/inference.py
new file mode 100644
index 00000000..615682ce
--- /dev/null
+++ b/mmocr/apis/inference.py
@@ -0,0 +1,43 @@
+import torch
+from mmcv.ops import RoIPool
+from mmcv.parallel import collate, scatter
+
+from mmdet.datasets.pipelines import Compose
+
+
+def model_inference(model, img):
+ """Inference image(s) with the detector.
+
+ Args:
+ model (nn.Module): The loaded detector.
+ img (str): Image file path.
+
+ Returns:
+ result (dict): Detection results.
+ """
+ assert isinstance(img, str)
+
+ cfg = model.cfg
+ device = next(model.parameters()).device # model device
+ data = dict(img_info=dict(filename=img), img_prefix=None)
+ # build the data pipeline
+ test_pipeline = Compose(cfg.data.test.pipeline)
+ data = test_pipeline(data)
+ data = collate([data], samples_per_gpu=1)
+
+ # process img_metas
+ data['img_metas'] = data['img_metas'][0].data
+
+ if next(model.parameters()).is_cuda:
+ # scatter to specified GPU
+ data = scatter(data, [device])[0]
+ else:
+ for m in model.modules():
+ assert not isinstance(
+ m, RoIPool
+ ), 'CPU inference with RoIPool is not supported currently.'
+
+ # forward the model
+ with torch.no_grad():
+ result = model(return_loss=False, rescale=True, **data)[0]
+ return result
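+
+# Example usage (a minimal sketch; the config and checkpoint paths below are
+# placeholders, and `init_detector` comes from mmdet.apis rather than this
+# module):
+#
+#   from mmdet.apis import init_detector
+#   model = init_detector('configs/textdet/some_cfg.py', 'some_ckpt.pth',
+#                         device='cuda:0')
+#   result = model_inference(model, 'demo/demo_text_det.jpg')
+#   print(result['boundary_result'])  # list of [x1, y1, ..., xk, yk, score]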
diff --git a/mmocr/core/__init__.py b/mmocr/core/__init__.py
new file mode 100644
index 00000000..cb717d4b
--- /dev/null
+++ b/mmocr/core/__init__.py
@@ -0,0 +1,3 @@
+from .evaluation import * # noqa: F401, F403
+from .mask import * # noqa: F401, F403
+from .visualize import * # noqa: F401, F403
diff --git a/mmocr/core/evaluation/__init__.py b/mmocr/core/evaluation/__init__.py
new file mode 100644
index 00000000..493e894e
--- /dev/null
+++ b/mmocr/core/evaluation/__init__.py
@@ -0,0 +1,10 @@
+from .hmean import eval_hmean
+from .hmean_ic13 import eval_hmean_ic13
+from .hmean_iou import eval_hmean_iou
+from .kie_metric import compute_f1_score
+from .ocr_metric import eval_ocr_metric
+
+__all__ = [
+ 'eval_hmean_ic13', 'eval_hmean_iou', 'eval_ocr_metric', 'eval_hmean',
+ 'compute_f1_score'
+]
diff --git a/mmocr/core/evaluation/hmean.py b/mmocr/core/evaluation/hmean.py
new file mode 100644
index 00000000..bbb7d679
--- /dev/null
+++ b/mmocr/core/evaluation/hmean.py
@@ -0,0 +1,149 @@
+from operator import itemgetter
+
+import mmcv
+from mmcv.utils import print_log
+
+import mmocr.utils as utils
+from mmocr.core.evaluation import hmean_ic13, hmean_iou
+from mmocr.core.evaluation.utils import (filter_2dlist_result,
+ select_top_boundary)
+from mmocr.core.mask import extract_boundary
+
+
+def output_ranklist(img_results, img_infos, out_file):
+ """Output the worst results for debugging.
+
+ Args:
+ img_results (list[dict]): Image result list.
+ img_infos (list[dict]): Image information list.
+ out_file (str): The output file path.
+
+ Returns:
+ sorted_results (list[dict]): Image results sorted by hmean.
+ """
+ assert utils.is_type_list(img_results, dict)
+ assert utils.is_type_list(img_infos, dict)
+ assert isinstance(out_file, str)
+ assert out_file.endswith('json')
+
+ sorted_results = []
+ for inx, result in enumerate(img_results):
+ name = img_infos[inx]['file_name']
+ img_result = result
+ img_result['file_name'] = name
+ sorted_results.append(img_result)
+ sorted_results = sorted(
+ sorted_results, key=itemgetter('hmean'), reverse=False)
+
+ mmcv.dump(sorted_results, file=out_file)
+
+ return sorted_results
+
+
+def get_gt_masks(ann_infos):
+ """Get ground truth masks and ignored masks.
+
+ Args:
+ ann_infos (list[dict]): Each dict contains the annotation
+ info of one image, with the following keys:
+ masks, masks_ignore.
+ Returns:
+ gt_masks (list[list[list[int]]]): Ground truth masks.
+ gt_masks_ignore (list[list[list[int]]]): Ignored masks.
+ """
+ assert utils.is_type_list(ann_infos, dict)
+
+ gt_masks = []
+ gt_masks_ignore = []
+ for ann_info in ann_infos:
+ masks = ann_info['masks']
+ mask_gt = []
+ for mask in masks:
+ assert len(mask[0]) >= 8 and len(mask[0]) % 2 == 0
+ mask_gt.append(mask[0])
+ gt_masks.append(mask_gt)
+
+ masks_ignore = ann_info['masks_ignore']
+ mask_gt_ignore = []
+ for mask_ignore in masks_ignore:
+ assert len(mask_ignore[0]) >= 8 and len(mask_ignore[0]) % 2 == 0
+ mask_gt_ignore.append(mask_ignore[0])
+ gt_masks_ignore.append(mask_gt_ignore)
+
+ return gt_masks, gt_masks_ignore
+
+
+def eval_hmean(results,
+ img_infos,
+ ann_infos,
+ metrics={'hmean-iou'},
+ score_thr=0.3,
+ rank_list=None,
+ logger=None,
+ **kwargs):
+ """Evaluation in hmean metric.
+
+ Args:
+ results (list[dict]): Each dict corresponds to one image,
+ containing the following keys: boundary_result
+ img_infos (list[dict]): Each dict corresponds to one image,
+ containing the following keys: filename, height, width
+ ann_infos (list[dict]): Each dict corresponds to one image,
+ containing the following keys: masks, masks_ignore
+ metrics (set{str}): Hmean metric set, should be one or all of
+ {'hmean-iou', 'hmean-ic13'}.
+ score_thr (float): Score threshold of the prediction map.
+ rank_list (str): Json file path used to dump per-image results
+ sorted by hmean for debugging. Skipped if None.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ Returns:
+ dict[str: float]
+ """
+ assert utils.is_type_list(results, dict)
+ assert utils.is_type_list(img_infos, dict)
+ assert utils.is_type_list(ann_infos, dict)
+ assert len(results) == len(img_infos) == len(ann_infos)
+ assert isinstance(metrics, set)
+
+ gts, gts_ignore = get_gt_masks(ann_infos)
+
+ preds = []
+ pred_scores = []
+ for result in results:
+ _, texts, scores = extract_boundary(result)
+ if len(texts) > 0:
+ assert utils.valid_boundary(texts[0], False)
+ valid_texts, valid_text_scores = filter_2dlist_result(
+ texts, scores, score_thr)
+ preds.append(valid_texts)
+ pred_scores.append(valid_text_scores)
+
+ eval_results = {}
+ for metric in metrics:
+ msg = f'Evaluating {metric}...'
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+ best_result = dict(hmean=-1)
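+ # sweep the boundary score threshold from 0.3 to 0.9 in steps of 0.1
+ # and keep the threshold that gives the best hmean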
+ for iter in range(3, 10):
+ thr = iter * 0.1
+ top_preds = select_top_boundary(preds, pred_scores, thr)
+ if metric == 'hmean-iou':
+ result, img_result = hmean_iou.eval_hmean_iou(
+ top_preds, gts, gts_ignore)
+ elif metric == 'hmean-ic13':
+ result, img_result = hmean_ic13.eval_hmean_ic13(
+ top_preds, gts, gts_ignore)
+ else:
+ raise NotImplementedError
+ if rank_list is not None:
+ output_ranklist(img_result, img_infos, rank_list)
+
+ print_log(
+ 'thr {0:.1f}, recall:{1[recall]:.3f}, '
+ 'precision: {1[precision]:.3f}, '
+ 'hmean:{1[hmean]:.3f}'.format(thr, result),
+ logger=logger)
+ if result['hmean'] > best_result['hmean']:
+ best_result = result
+ eval_results[metric + ':recall'] = best_result['recall']
+ eval_results[metric + ':precision'] = best_result['precision']
+ eval_results[metric + ':hmean'] = best_result['hmean']
+ return eval_results
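+
+# Example input format (a minimal sketch of the expected structures):
+#
+#   results = [dict(boundary_result=[[0, 0, 10, 0, 10, 10, 0, 10, 0.9]])]
+#   img_infos = [dict(file_name='img_1.jpg', height=100, width=100)]
+#   ann_infos = [dict(masks=[[[0, 0, 10, 0, 10, 10, 0, 10]]],
+#                     masks_ignore=[])]
+#   eval_hmean(results, img_infos, ann_infos)
+#   # -> {'hmean-iou:recall': 1.0, 'hmean-iou:precision': 1.0,
+#   #     'hmean-iou:hmean': 1.0}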
diff --git a/mmocr/core/evaluation/hmean_ic13.py b/mmocr/core/evaluation/hmean_ic13.py
new file mode 100644
index 00000000..d3c69467
--- /dev/null
+++ b/mmocr/core/evaluation/hmean_ic13.py
@@ -0,0 +1,216 @@
+import numpy as np
+
+import mmocr.utils as utils
+from . import utils as eval_utils
+
+
+def compute_recall_precision(gt_polys, pred_polys):
+ """Compute the recall and the precision matrices between gt and predicted
+ polygons.
+
+ Args:
+ gt_polys (list[Polygon]): List of gt polygons.
+ pred_polys (list[Polygon]): List of predicted polygons.
+
+ Returns:
+ recall (ndarray): Recall matrix of size gt_num x det_num.
+ precision (ndarray): Precision matrix of size gt_num x det_num.
+ """
+ assert isinstance(gt_polys, list)
+ assert isinstance(pred_polys, list)
+
+ gt_num = len(gt_polys)
+ det_num = len(pred_polys)
+ sz = [gt_num, det_num]
+
+ recall = np.zeros(sz)
+ precision = np.zeros(sz)
+ # compute area recall and precision for each (gt, det) pair
+ # in one img
+ for gt_id in range(gt_num):
+ for pred_id in range(det_num):
+ gt = gt_polys[gt_id]
+ det = pred_polys[pred_id]
+
+ inter_area, _ = eval_utils.poly_intersection(det, gt)
+ gt_area = gt.area()
+ det_area = det.area()
+ if gt_area != 0:
+ recall[gt_id, pred_id] = inter_area / gt_area
+ if det_area != 0:
+ precision[gt_id, pred_id] = inter_area / det_area
+
+ return recall, precision
+
+
+def eval_hmean_ic13(det_boxes,
+ gt_boxes,
+ gt_ignored_boxes,
+ precision_thr=0.4,
+ recall_thr=0.8,
+ center_dist_thr=1.0,
+ one2one_score=1.,
+ one2many_score=0.8,
+ many2one_score=1.):
+ """Evalute hmean of text detection using the icdar2013 standard.
+
+ Args:
+ det_boxes (list[list[list[float]]]): List of arrays of shape (n, 2k).
+ Each element is the det_boxes for one img. k>=4.
+ gt_boxes (list[list[list[float]]]): List of arrays of shape (m, 2k).
+ Each element is the gt_boxes for one img. k>=4.
+ gt_ignored_boxes (list[list[list[float]]]): List of arrays of
+ (l, 2k). Each element is the ignored gt_boxes for one img. k>=4.
+ precision_thr (float): Precision threshold of the iou of one
+ (gt_box, det_box) pair.
+ recall_thr (float): Recall threshold of the iou of one
+ (gt_box, det_box) pair.
+ center_dist_thr (float): Distance threshold of one (gt_box, det_box)
+ center point pair.
+ one2one_score (float): Reward when one gt matches one det_box.
+ one2many_score (float): Reward when one gt matches many det_boxes.
+ many2one_score (float): Reward when many gts match one det_box.
+
+ Returns:
+ hmean (tuple[dict]): Tuple of dicts which encodes the hmean for
+ the dataset and all images.
+ """
+ assert utils.is_3dlist(det_boxes)
+ assert utils.is_3dlist(gt_boxes)
+ assert utils.is_3dlist(gt_ignored_boxes)
+
+ assert 0 <= precision_thr <= 1
+ assert 0 <= recall_thr <= 1
+ assert center_dist_thr > 0
+ assert 0 <= one2one_score <= 1
+ assert 0 <= one2many_score <= 1
+ assert 0 <= many2one_score <= 1
+
+ img_num = len(det_boxes)
+ assert img_num == len(gt_boxes)
+ assert img_num == len(gt_ignored_boxes)
+
+ dataset_gt_num = 0
+ dataset_pred_num = 0
+ dataset_hit_recall = 0.0
+ dataset_hit_prec = 0.0
+
+ img_results = []
+
+ for i in range(img_num):
+ gt = gt_boxes[i]
+ gt_ignored = gt_ignored_boxes[i]
+ pred = det_boxes[i]
+
+ gt_num = len(gt)
+ ignored_num = len(gt_ignored)
+ pred_num = len(pred)
+
+ accum_recall = 0.
+ accum_precision = 0.
+
+ gt_points = gt + gt_ignored
+ gt_polys = [eval_utils.points2polygon(p) for p in gt_points]
+ gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))]
+ gt_num = len(gt_polys)
+
+ pred_polys, pred_points, pred_ignored_index = eval_utils.ignore_pred(
+ pred, gt_ignored_index, gt_polys, precision_thr)
+
+ if pred_num > 0 and gt_num > 0:
+
+ gt_hit = np.zeros(gt_num, np.int8).tolist()
+ pred_hit = np.zeros(pred_num, np.int8).tolist()
+
+ # compute area recall and precision for each (gt, pred) pair
+ # in one img.
+ recall_mat, precision_mat = compute_recall_precision(
+ gt_polys, pred_polys)
+
+ # match one gt to one pred box.
+ for gt_id in range(gt_num):
+ for pred_id in range(pred_num):
+ if (gt_hit[gt_id] != 0 or pred_hit[pred_id] != 0
+ or gt_id in gt_ignored_index
+ or pred_id in pred_ignored_index):
+ continue
+ match = eval_utils.one2one_match_ic13(
+ gt_id, pred_id, recall_mat, precision_mat, recall_thr,
+ precision_thr)
+
+ if match:
+ gt_point = np.array(gt_points[gt_id])
+ det_point = np.array(pred_points[pred_id])
+
+ norm_dist = eval_utils.box_center_distance(
+ det_point, gt_point)
+ norm_dist /= eval_utils.box_diag(
+ det_point) + eval_utils.box_diag(gt_point)
+ norm_dist *= 2.0
+
+ if norm_dist < center_dist_thr:
+ gt_hit[gt_id] = 1
+ pred_hit[pred_id] = 1
+ accum_recall += one2one_score
+ accum_precision += one2one_score
+
+ # match one gt to many det boxes.
+ for gt_id in range(gt_num):
+ if gt_id in gt_ignored_index:
+ continue
+ match, match_det_set = eval_utils.one2many_match_ic13(
+ gt_id, recall_mat, precision_mat, recall_thr,
+ precision_thr, gt_hit, pred_hit, pred_ignored_index)
+
+ if match:
+ gt_hit[gt_id] = 1
+ accum_recall += one2many_score
+ accum_precision += one2many_score * len(match_det_set)
+ for pred_id in match_det_set:
+ pred_hit[pred_id] = 1
+
+ # match many gts to one det box. The det box is matched if each
+ # candidate gt is covered with recall above recall_thr and the
+ # accumulated precision over those gts exceeds precision_thr.
+ for pred_id in range(pred_num):
+ if pred_id in pred_ignored_index:
+ continue
+
+ match, match_gt_set = eval_utils.many2one_match_ic13(
+ pred_id, recall_mat, precision_mat, recall_thr,
+ precision_thr, gt_hit, pred_hit, gt_ignored_index)
+
+ if match:
+ pred_hit[pred_id] = 1
+ accum_recall += many2one_score * len(match_gt_set)
+ accum_precision += many2one_score
+ for gt_id in match_gt_set:
+ gt_hit[gt_id] = 1
+
+ gt_care_number = gt_num - ignored_num
+ pred_care_number = pred_num - len(pred_ignored_index)
+
+ r, p, h = eval_utils.compute_hmean(accum_recall, accum_precision,
+ gt_care_number, pred_care_number)
+
+ img_results.append({'recall': r, 'precision': p, 'hmean': h})
+
+ dataset_gt_num += gt_care_number
+ dataset_pred_num += pred_care_number
+ dataset_hit_recall += accum_recall
+ dataset_hit_prec += accum_precision
+
+ total_r, total_p, total_h = eval_utils.compute_hmean(
+ dataset_hit_recall, dataset_hit_prec, dataset_gt_num, dataset_pred_num)
+
+ dataset_results = {
+ 'num_gts': dataset_gt_num,
+ 'num_dets': dataset_pred_num,
+ 'num_recall': dataset_hit_recall,
+ 'num_precision': dataset_hit_prec,
+ 'recall': total_r,
+ 'precision': total_p,
+ 'hmean': total_h
+ }
+
+ return dataset_results, img_results
diff --git a/mmocr/core/evaluation/hmean_iou.py b/mmocr/core/evaluation/hmean_iou.py
new file mode 100644
index 00000000..8ad0363f
--- /dev/null
+++ b/mmocr/core/evaluation/hmean_iou.py
@@ -0,0 +1,116 @@
+import numpy as np
+
+import mmocr.utils as utils
+from . import utils as eval_utils
+
+
+def eval_hmean_iou(pred_boxes,
+ gt_boxes,
+ gt_ignored_boxes,
+ iou_thr=0.5,
+ precision_thr=0.5):
+ """Evalute hmean of text detection using IOU standard.
+
+ Args:
+ pred_boxes (list[list[list[float]]]): Text boxes for an img list. Each
+ box has 2k (>=8) values.
+ gt_boxes (list[list[list[float]]]): Ground truth text boxes for an img
+ list. Each box has 2k (>=8) values.
+ gt_ignored_boxes (list[list[list[float]]]): Ignored ground truth text
+ boxes for an img list. Each box has 2k (>=8) values.
+ iou_thr (float): Iou threshold when one (gt_box, det_box) pair is
+ matched.
+ precision_thr (float): Precision threshold when one (gt_box, det_box)
+ pair is matched.
+
+ Returns:
+ hmean (tuple[dict]): Tuple of dicts indicates the hmean for the dataset
+ and all images.
+ """
+ assert utils.is_3dlist(pred_boxes)
+ assert utils.is_3dlist(gt_boxes)
+ assert utils.is_3dlist(gt_ignored_boxes)
+ assert 0 <= iou_thr <= 1
+ assert 0 <= precision_thr <= 1
+
+ img_num = len(pred_boxes)
+ assert img_num == len(gt_boxes)
+ assert img_num == len(gt_ignored_boxes)
+
+ dataset_gt_num = 0
+ dataset_pred_num = 0
+ dataset_hit_num = 0
+
+ img_results = []
+
+ for i in range(img_num):
+ gt = gt_boxes[i]
+ gt_ignored = gt_ignored_boxes[i]
+ pred = pred_boxes[i]
+
+ gt_num = len(gt)
+ gt_ignored_num = len(gt_ignored)
+ pred_num = len(pred)
+
+ hit_num = 0
+
+ # get gt polygons.
+ gt_all = gt + gt_ignored
+ gt_polys = [eval_utils.points2polygon(p) for p in gt_all]
+ gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))]
+ gt_num = len(gt_polys)
+ pred_polys, _, pred_ignored_index = eval_utils.ignore_pred(
+ pred, gt_ignored_index, gt_polys, precision_thr)
+
+ # match.
+ if gt_num > 0 and pred_num > 0:
+ sz = [gt_num, pred_num]
+ iou_mat = np.zeros(sz)
+
+ gt_hit = np.zeros(gt_num, np.int8)
+ pred_hit = np.zeros(pred_num, np.int8)
+
+ for gt_id in range(gt_num):
+ for pred_id in range(pred_num):
+ gt_pol = gt_polys[gt_id]
+ det_pol = pred_polys[pred_id]
+
+ iou_mat[gt_id,
+ pred_id] = eval_utils.poly_iou(det_pol, gt_pol)
+
+ for gt_id in range(gt_num):
+ for pred_id in range(pred_num):
+ if (gt_hit[gt_id] != 0 or pred_hit[pred_id] != 0
+ or gt_id in gt_ignored_index
+ or pred_id in pred_ignored_index):
+ continue
+ if iou_mat[gt_id, pred_id] > iou_thr:
+ gt_hit[gt_id] = 1
+ pred_hit[pred_id] = 1
+ hit_num += 1
+
+ gt_care_number = gt_num - gt_ignored_num
+ pred_care_number = pred_num - len(pred_ignored_index)
+
+ r, p, h = eval_utils.compute_hmean(hit_num, hit_num, gt_care_number,
+ pred_care_number)
+
+ img_results.append({'recall': r, 'precision': p, 'hmean': h})
+
+ dataset_hit_num += hit_num
+ dataset_gt_num += gt_care_number
+ dataset_pred_num += pred_care_number
+
+ dataset_r, dataset_p, dataset_h = eval_utils.compute_hmean(
+ dataset_hit_num, dataset_hit_num, dataset_gt_num, dataset_pred_num)
+
+ dataset_results = {
+ 'num_gts': dataset_gt_num,
+ 'num_dets': dataset_pred_num,
+ 'num_match': dataset_hit_num,
+ 'recall': dataset_r,
+ 'precision': dataset_p,
+ 'hmean': dataset_h
+ }
+
+ return dataset_results, img_results
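+
+# Example (a minimal sketch): a single image whose only prediction exactly
+# matches the only ground truth box yields a perfect score.
+#
+#   box = [0, 0, 10, 0, 10, 10, 0, 10]
+#   dataset_res, img_res = eval_hmean_iou([[box]], [[box]], [[]])
+#   # dataset_res['hmean'] == 1.0, img_res[0]['hmean'] == 1.0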
diff --git a/mmocr/core/evaluation/kie_metric.py b/mmocr/core/evaluation/kie_metric.py
new file mode 100644
index 00000000..00dc2387
--- /dev/null
+++ b/mmocr/core/evaluation/kie_metric.py
@@ -0,0 +1,27 @@
+import torch
+
+
+def compute_f1_score(preds, gts, ignores=[]):
+ """Compute the F1-score of prediction.
+
+ Args:
+ preds (Tensor): The predicted probability NxC map
+ with N and C being the sample number and class
+ number respectively.
+ gts (Tensor): The ground truth vector of size N.
+ ignores (list): The index set of classes that are ignored when
+ reporting results.
+ Note: all samples still participate in the computation.
+
+ Returns:
+ ndarray: The F1-scores of the valid (non-ignored) classes.
+ """
+ C = preds.size(1)
+ classes = torch.LongTensor(sorted(set(range(C)) - set(ignores)))
+ hist = torch.bincount(
+ gts * C + preds.argmax(1), minlength=C**2).view(C, C).float()
+ diag = torch.diag(hist)
+ recalls = diag / hist.sum(1).clamp(min=1)
+ precisions = diag / hist.sum(0).clamp(min=1)
+ f1 = 2 * recalls * precisions / (recalls + precisions).clamp(min=1e-8)
+ return f1[classes].cpu().numpy()
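+
+# Example (a minimal sketch):
+#
+#   >>> import torch
+#   >>> preds = torch.Tensor([[0.9, 0.1], [0.2, 0.8]])
+#   >>> gts = torch.LongTensor([0, 1])
+#   >>> compute_f1_score(preds, gts)
+#   array([1., 1.], dtype=float32)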
diff --git a/mmocr/core/evaluation/ocr_metric.py b/mmocr/core/evaluation/ocr_metric.py
new file mode 100644
index 00000000..5c5124f0
--- /dev/null
+++ b/mmocr/core/evaluation/ocr_metric.py
@@ -0,0 +1,133 @@
+import re
+from difflib import SequenceMatcher
+
+import Levenshtein
+
+
+def cal_true_positive_char(pred, gt):
+ """Calculate correct character number in prediction.
+
+ Args:
+ pred (str): Prediction text.
+ gt (str): Ground truth text.
+
+ Returns:
+ true_positive_char_num (int): The true positive number.
+ """
+
+ all_opt = SequenceMatcher(None, pred, gt)
+ true_positive_char_num = 0
+ for opt, _, _, s2, e2 in all_opt.get_opcodes():
+ if opt == 'equal':
+ true_positive_char_num += (e2 - s2)
+ else:
+ pass
+ return true_positive_char_num
+
+
+def count_matches(pred_texts, gt_texts):
+ """Count the various match number for metric calculation.
+
+ Args:
+ pred_texts (list[str]): Predicted text string.
+ gt_texts (list[str]): Ground truth text string.
+
+ Returns:
+ match_res (dict[str: int]): Match numbers used for
+ metric calculation.
+ """
+ match_res = {
+ 'gt_char_num': 0,
+ 'pred_char_num': 0,
+ 'true_positive_char_num': 0,
+ 'gt_word_num': 0,
+ 'match_word_num': 0,
+ 'match_word_ignore_case': 0,
+ 'match_word_ignore_case_symbol': 0
+ }
+ # keep only letters, digits and CJK characters when ignoring symbols
+ comp = re.compile('[^A-Za-z0-9\u4e00-\u9fa5]')
+ norm_ed_sum = 0.0
+ for pred_text, gt_text in zip(pred_texts, gt_texts):
+ if gt_text == pred_text:
+ match_res['match_word_num'] += 1
+ gt_text_lower = gt_text.lower()
+ pred_text_lower = pred_text.lower()
+ if gt_text_lower == pred_text_lower:
+ match_res['match_word_ignore_case'] += 1
+ gt_text_lower_ignore = comp.sub('', gt_text_lower)
+ pred_text_lower_ignore = comp.sub('', pred_text_lower)
+ if gt_text_lower_ignore == pred_text_lower_ignore:
+ match_res['match_word_ignore_case_symbol'] += 1
+ match_res['gt_word_num'] += 1
+
+ # normalized edit distance
+ edit_dist = Levenshtein.distance(pred_text_lower_ignore,
+ gt_text_lower_ignore)
+ norm_ed = float(edit_dist) / max(1, len(gt_text_lower_ignore),
+ len(pred_text_lower_ignore))
+ norm_ed_sum += norm_ed
+
+ # number to calculate char level recall & precision
+ match_res['gt_char_num'] += len(gt_text_lower_ignore)
+ match_res['pred_char_num'] += len(pred_text_lower_ignore)
+ true_positive_char_num = cal_true_positive_char(
+ pred_text_lower_ignore, gt_text_lower_ignore)
+ match_res['true_positive_char_num'] += true_positive_char_num
+
+ normalized_edit_distance = norm_ed_sum / max(1, len(gt_texts))
+ match_res['ned'] = normalized_edit_distance
+
+ return match_res
+
+
+def eval_ocr_metric(pred_texts, gt_texts):
+ """Evaluate the text recognition performance with metric: word accuracy and
+ 1-N.E.D. See https://rrc.cvc.uab.es/?ch=14&com=tasks for details.
+
+ Args:
+ pred_texts (list[str]): Text strings of prediction.
+ gt_texts (list[str]): Text strings of ground truth.
+
+ Returns:
+ eval_res (dict[str: float]): Metric dict for text recognition, include:
+ - word_acc: Accuracy in word level.
+ - word_acc_ignore_case: Accuracy in word level, ignore letter case.
+ - word_acc_ignore_case_symbol: Accuracy in word level, ignore
+ letter case and symbol. (default metric for
+ academic evaluation)
+ - char_recall: Recall in character level, ignore
+ letter case and symbol.
+ - char_precision: Precision in character level, ignore
+ letter case and symbol.
+ - 1-N.E.D: 1 - normalized_edit_distance.
+ """
+ assert isinstance(pred_texts, list)
+ assert isinstance(gt_texts, list)
+ assert len(pred_texts) == len(gt_texts)
+
+ match_res = count_matches(pred_texts, gt_texts)
+ eps = 1e-8
+ char_recall = 1.0 * match_res['true_positive_char_num'] / (
+ eps + match_res['gt_char_num'])
+ char_precision = 1.0 * match_res['true_positive_char_num'] / (
+ eps + match_res['pred_char_num'])
+ word_acc = 1.0 * match_res['match_word_num'] / (
+ eps + match_res['gt_word_num'])
+ word_acc_ignore_case = 1.0 * match_res['match_word_ignore_case'] / (
+ eps + match_res['gt_word_num'])
+ word_acc_ignore_case_symbol = 1.0 * match_res[
+ 'match_word_ignore_case_symbol'] / (
+ eps + match_res['gt_word_num'])
+
+ eval_res = {}
+ eval_res['word_acc'] = word_acc
+ eval_res['word_acc_ignore_case'] = word_acc_ignore_case
+ eval_res['word_acc_ignore_case_symbol'] = word_acc_ignore_case_symbol
+ eval_res['char_recall'] = char_recall
+ eval_res['char_precision'] = char_precision
+ eval_res['1-N.E.D'] = 1.0 - match_res['ned']
+
+ for key, value in eval_res.items():
+ eval_res[key] = float('{:.4f}'.format(value))
+
+ return eval_res
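+
+# Example (a minimal sketch):
+#
+#   >>> eval_ocr_metric(['Hello!'], ['hello'])
+#   {'word_acc': 0.0, 'word_acc_ignore_case': 0.0,
+#    'word_acc_ignore_case_symbol': 1.0, 'char_recall': 1.0,
+#    'char_precision': 1.0, '1-N.E.D': 1.0}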
diff --git a/mmocr/core/evaluation/utils.py b/mmocr/core/evaluation/utils.py
new file mode 100644
index 00000000..8569fde8
--- /dev/null
+++ b/mmocr/core/evaluation/utils.py
@@ -0,0 +1,496 @@
+import numpy as np
+import Polygon as plg
+
+import mmocr.utils as utils
+
+
+def ignore_pred(pred_boxes, gt_ignored_index, gt_polys, precision_thr):
+ """Ignore the predicted box if it hits any ignored ground truth.
+
+ Args:
+ pred_boxes (list[ndarray or list]): The predicted boxes of one image.
+ gt_ignored_index (list[int]): The ignored ground truth index list.
+ gt_polys (list[Polygon]): The polygon list of one image.
+ precision_thr (float): The precision threshold.
+
+ Returns:
+ pred_polys (list[Polygon]): The predicted polygon list.
+ pred_points (list[list]): The predicted box list represented
+ by point sequences.
+ pred_ignored_index (list[int]): The ignored text index list.
+ """
+
+ assert isinstance(pred_boxes, list)
+ assert isinstance(gt_ignored_index, list)
+ assert isinstance(gt_polys, list)
+ assert 0 <= precision_thr <= 1
+
+ pred_polys = []
+ pred_points = []
+ pred_ignored_index = []
+
+ gt_ignored_num = len(gt_ignored_index)
+ # get detection polygons
+ for box_id, box in enumerate(pred_boxes):
+ poly = points2polygon(box)
+ pred_polys.append(poly)
+ pred_points.append(box)
+
+ if gt_ignored_num < 1:
+ continue
+
+ # ignore the current detection box
+ # if its overlap with any ignored gt > precision_thr
+ for ignored_box_id in gt_ignored_index:
+ ignored_box = gt_polys[ignored_box_id]
+ inter_area, _ = poly_intersection(poly, ignored_box)
+ area = poly.area()
+ precision = 0 if area == 0 else inter_area / area
+ if precision > precision_thr:
+ pred_ignored_index.append(box_id)
+ break
+
+ return pred_polys, pred_points, pred_ignored_index
+
+
+def compute_hmean(accum_hit_recall, accum_hit_prec, gt_num, pred_num):
+ """Compute hmean given hit number, ground truth number and prediction
+ number.
+
+ Args:
+ accum_hit_recall (int|float): Accumulated hits for computing recall.
+ accum_hit_prec (int|float): Accumulated hits for computing precision.
+ gt_num (int): Ground truth number.
+ pred_num (int): Prediction number.
+
+ Returns:
+ recall (float): The recall value.
+ precision (float): The precision value.
+ hmean (float): The hmean value.
+ """
+
+ assert isinstance(accum_hit_recall, (float, int))
+ assert isinstance(accum_hit_prec, (float, int))
+
+ assert isinstance(gt_num, int)
+ assert isinstance(pred_num, int)
+ assert accum_hit_recall >= 0.0
+ assert accum_hit_prec >= 0.0
+ assert gt_num >= 0.0
+ assert pred_num >= 0.0
+
+ if gt_num == 0:
+ recall = 1.0
+ precision = 0.0 if pred_num > 0 else 1.0
+ else:
+ recall = float(accum_hit_recall) / gt_num
+ precision = 0.0 if pred_num == 0 else float(accum_hit_prec) / pred_num
+
+ denom = recall + precision
+
+ hmean = 0.0 if denom == 0 else (2.0 * precision * recall / denom)
+
+ return recall, precision, hmean
+
+
+def box2polygon(box):
+ """Convert box to polygon.
+
+ Args:
+ box (ndarray or list): A ndarray or a list of shape (4)
+ that indicates 2 points.
+
+ Returns:
+ polygon (Polygon): A polygon object.
+ """
+ if isinstance(box, list):
+ box = np.array(box)
+
+ assert isinstance(box, np.ndarray)
+ assert box.size == 4
+ boundary = np.array(
+ [box[0], box[1], box[2], box[1], box[2], box[3], box[0], box[3]])
+
+ point_mat = boundary.reshape([-1, 2])
+ return plg.Polygon(point_mat)
+
+
+def points2polygon(points):
+ """Convert k points to 1 polygon.
+
+ Args:
+ points (ndarray or list): A ndarray or a list of shape (2k)
+ that indicates k points.
+
+ Returns:
+ polygon (Polygon): A polygon object.
+ """
+ if isinstance(points, list):
+ points = np.array(points)
+
+ assert isinstance(points, np.ndarray)
+ assert (points.size % 2 == 0) and (points.size >= 8)
+
+ point_mat = points.reshape([-1, 2])
+ return plg.Polygon(point_mat)
+
+
+def poly_intersection(poly_det, poly_gt):
+ """Calculate the intersection area between two polygon.
+
+ Args:
+ poly_det (Polygon): A polygon predicted by detector.
+ poly_gt (Polygon): A gt polygon.
+
+ Returns:
+ intersection_area (float): The intersection area between two polygons.
+ """
+ assert isinstance(poly_det, plg.Polygon)
+ assert isinstance(poly_gt, plg.Polygon)
+
+ poly_inter = poly_det & poly_gt
+ if len(poly_inter) == 0:
+ return 0, poly_inter
+ return poly_inter.area(), poly_inter
+
+
+def poly_union(poly_det, poly_gt):
+ """Calculate the union area between two polygon.
+
+ Args:
+ poly_det (Polygon): A polygon predicted by detector.
+ poly_gt (Polygon): A gt polygon.
+
+ Returns:
+ union_area (float): The union area between two polygons.
+ """
+ assert isinstance(poly_det, plg.Polygon)
+ assert isinstance(poly_gt, plg.Polygon)
+
+ area_det = poly_det.area()
+ area_gt = poly_gt.area()
+ area_inters, _ = poly_intersection(poly_det, poly_gt)
+ return area_det + area_gt - area_inters
+
+
+def boundary_iou(src, target):
+ """Calculate the IOU between two boundaries.
+
+ Args:
+ src (list): Source boundary.
+ target (list): Target boundary.
+
+ Returns:
+ iou (float): The iou between two boundaries.
+ """
+ assert utils.valid_boundary(src, False)
+ assert utils.valid_boundary(target, False)
+ src_poly = points2polygon(src)
+ target_poly = points2polygon(target)
+
+ return poly_iou(src_poly, target_poly)
+
+
+def poly_iou(poly_det, poly_gt):
+ """Calculate the IOU between two polygons.
+
+ Args:
+ poly_det (Polygon): A polygon predicted by detector.
+ poly_gt (Polygon): A gt polygon.
+
+ Returns:
+ iou (float): The IOU between two polygons.
+ """
+ assert isinstance(poly_det, plg.Polygon)
+ assert isinstance(poly_gt, plg.Polygon)
+ area_inters, _ = poly_intersection(poly_det, poly_gt)
+
+ return area_inters / poly_union(poly_det, poly_gt)
+
+
+def one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, recall_thr,
+ precision_thr):
+ """One-to-One match gt and det with icdar2013 standards.
+
+ Args:
+ gt_id (int): The ground truth id index.
+ det_id (int): The detection result id index.
+ recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
+ being the recall ratio of gt i to det j.
+ precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
+ being the precision ratio of gt i to det j.
+ recall_thr (float): The recall threshold.
+ precision_thr (float): The precision threshold.
+ Returns:
+ True|False: Whether the gt and det are matched.
+ """
+ assert isinstance(gt_id, int)
+ assert isinstance(det_id, int)
+ assert isinstance(recall_mat, np.ndarray)
+ assert isinstance(precision_mat, np.ndarray)
+ assert 0 <= recall_thr <= 1
+ assert 0 <= precision_thr <= 1
+
+ cont = 0
+ for i in range(recall_mat.shape[1]):
+ if recall_mat[gt_id,
+ i] > recall_thr and precision_mat[gt_id,
+ i] > precision_thr:
+ cont += 1
+ if cont != 1:
+ return False
+
+ cont = 0
+ for i in range(recall_mat.shape[0]):
+ if recall_mat[i, det_id] > recall_thr and precision_mat[
+ i, det_id] > precision_thr:
+ cont += 1
+ if cont != 1:
+ return False
+
+ if recall_mat[gt_id, det_id] > recall_thr and precision_mat[
+ gt_id, det_id] > precision_thr:
+ return True
+
+ return False
+
+
+def one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr,
+ precision_thr, gt_match_flag, det_match_flag,
+ det_ignored_index):
+ """One-to-Many match gt and detections with icdar2013 standards.
+
+ Args:
+ gt_id (int): gt index.
+ recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
+ being the recall ratio of gt i to det j.
+ precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
+ being the precision ratio of gt i to det j.
+ recall_thr (float): The recall threshold.
+ precision_thr (float): The precision threshold.
+ gt_match_flag (list): A list indicating whether each gt has already
+ been matched.
+ det_match_flag (list): A list indicating whether each detection box
+ has already been matched.
+ det_ignored_index (list): A list of detection box indices to be
+ ignored.
+
+ Returns:
+ tuple (True|False, list): The first indicates the gt is matched or not;
+ the second is the matched detection ids.
+ """
+ assert isinstance(gt_id, int)
+ assert isinstance(recall_mat, np.ndarray)
+ assert isinstance(precision_mat, np.ndarray)
+ assert 0 <= recall_thr <= 1
+ assert 0 <= precision_thr <= 1
+
+ assert isinstance(gt_match_flag, list)
+ assert isinstance(det_match_flag, list)
+ assert isinstance(det_ignored_index, list)
+
+ many_sum = 0.
+ det_ids = []
+ for det_id in range(recall_mat.shape[1]):
+ if gt_match_flag[gt_id] == 0 and det_match_flag[
+ det_id] == 0 and det_id not in det_ignored_index:
+ if precision_mat[gt_id, det_id] >= precision_thr:
+ many_sum += recall_mat[gt_id, det_id]
+ det_ids.append(det_id)
+ if many_sum >= recall_thr:
+ return True, det_ids
+ return False, []
+
+
+def many2one_match_ic13(det_id, recall_mat, precision_mat, recall_thr,
+ precision_thr, gt_match_flag, det_match_flag,
+ gt_ignored_index):
+ """Many-to-One match gt and detections with icdar2013 standards.
+
+ Args:
+ det_id (int): Detection index.
+ recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
+ being the recall ratio of gt i to det j.
+ precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
+ being the precision ratio of gt i to det j.
+ recall_thr (float): The recall threshold.
+ precision_thr (float): The precision threshold.
+ gt_match_flag (list): A list indicating whether each gt has already
+ been matched.
+ det_match_flag (list): A list indicating whether each detection box
+ has already been matched.
+ gt_ignored_index (list): A list of gt box indices to be ignored.
+
+ Returns:
+ tuple (True|False, list): The first indicates the detection is matched
+ or not; the second is the matched gt ids.
+ """
+ assert isinstance(det_id, int)
+ assert isinstance(recall_mat, np.ndarray)
+ assert isinstance(precision_mat, np.ndarray)
+ assert 0 <= recall_thr <= 1
+ assert 0 <= precision_thr <= 1
+
+ assert isinstance(gt_match_flag, list)
+ assert isinstance(det_match_flag, list)
+ assert isinstance(gt_ignored_index, list)
+ many_sum = 0.
+ gt_ids = []
+ for gt_id in range(recall_mat.shape[0]):
+ if gt_match_flag[gt_id] == 0 and det_match_flag[
+ det_id] == 0 and gt_id not in gt_ignored_index:
+ if recall_mat[gt_id, det_id] >= recall_thr:
+ many_sum += precision_mat[gt_id, det_id]
+ gt_ids.append(gt_id)
+ if many_sum >= precision_thr:
+ return True, gt_ids
+ return False, []
+
+
+def points_center(points):
+
+ assert isinstance(points, np.ndarray)
+ assert points.size % 2 == 0
+
+ points = points.reshape([-1, 2])
+ return np.mean(points, axis=0)
+
+
+def point_distance(p1, p2):
+ assert isinstance(p1, np.ndarray)
+ assert isinstance(p2, np.ndarray)
+
+ assert p1.size == 2
+ assert p2.size == 2
+
+ dist = np.square(p2 - p1)
+ dist = np.sum(dist)
+ dist = np.sqrt(dist)
+ return dist
+
+
+def box_center_distance(b1, b2):
+ assert isinstance(b1, np.ndarray)
+ assert isinstance(b2, np.ndarray)
+ return point_distance(points_center(b1), points_center(b2))
+
+
+def box_diag(box):
+ assert isinstance(box, np.ndarray)
+ assert box.size == 8
+
+ return point_distance(box[0:2], box[4:6])
+
+
+def filter_2dlist_result(results, scores, score_thr):
+ """Find out detected results whose score > score_thr.
+
+ Args:
+ results (list[list[float]]): The result list.
+ scores (list): The score list.
+ score_thr (float): The score threshold.
+ Returns:
+ valid_results (list[list[float]]): The valid results.
+ valid_score (list[float]): The scores which correspond to the valid
+ results.
+ """
+ assert isinstance(results, list)
+ assert len(results) == len(scores)
+ assert isinstance(score_thr, float)
+ assert 0 <= score_thr <= 1
+
+ inds = np.array(scores) > score_thr
+ valid_results = [results[inx] for inx in np.where(inds)[0].tolist()]
+ valid_scores = [scores[inx] for inx in np.where(inds)[0].tolist()]
+ return valid_results, valid_scores
+
+
+def filter_result(results, scores, score_thr):
+ """Find out detected results whose score > score_thr.
+
+ Args:
+ results (ndarray): The results matrix of shape (n, k).
+ scores (ndarray): The score vector of shape (n,).
+ score_thr (float): The score threshold.
+ Returns:
+ valid_results (ndarray): The valid results of shape (m,k) with m<=n.
+ valid_score (ndarray): The scores which correspond to the
+ valid results.
+ """
+ assert results.ndim == 2
+ assert scores.shape[0] == results.shape[0]
+ assert isinstance(score_thr, float)
+ assert 0 <= score_thr <= 1
+
+ inds = scores > score_thr
+ valid_results = results[inds, :]
+ valid_scores = scores[inds]
+ return valid_results, valid_scores
+
+
+def select_top_boundary(boundaries_list, scores_list, score_thr):
+ """Select poly boundaries with scores >= score_thr.
+
+ Args:
+ boundaries_list (list[list[list[float]]]): List of boundaries.
+ The 1st, 2nd, and 3rd indices are for image, text and
+ vertex, respectively.
+ scores_list (list(list[float])): List of lists of scores.
+ score_thr (float): The score threshold to filter out bboxes.
+
+ Returns:
+ selected_boundaries (list[list[list[float]]]): List of boundaries.
+ The 1st, 2nd, and 3rd indices are for image, text and vertex,
+ respectively.
+ """
+ assert isinstance(boundaries_list, list)
+ assert isinstance(scores_list, list)
+ assert isinstance(score_thr, float)
+ assert len(boundaries_list) == len(scores_list)
+ assert 0 <= score_thr <= 1
+
+ selected_boundaries = []
+ for boundary, scores in zip(boundaries_list, scores_list):
+ if len(scores) > 0:
+ assert len(scores) == len(boundary)
+ inds = [
+ iter for iter in range(len(scores))
+ if scores[iter] >= score_thr
+ ]
+ selected_boundaries.append([boundary[i] for i in inds])
+ else:
+ selected_boundaries.append(boundary)
+ return selected_boundaries
+
+
+def select_bboxes_via_score(bboxes_list, scores_list, score_thr):
+ """Select bboxes with scores >= score_thr.
+
+ Args:
+ bboxes_list (list[ndarray]): List of bboxes. Each element is ndarray of
+ shape (n,8)
+ scores_list (list(list[float])): List of lists of scores.
+ score_thr (float): The score threshold to filter out bboxes.
+
+ Returns:
+ selected_bboxes (list[ndarray]): List of bboxes. Each element is
+ ndarray of shape (m,8) with m<=n.
+ """
+ assert isinstance(bboxes_list, list)
+ assert isinstance(scores_list, list)
+ assert isinstance(score_thr, float)
+ assert len(bboxes_list) == len(scores_list)
+ assert 0 <= score_thr <= 1
+
+ selected_bboxes = []
+ for bboxes, scores in zip(bboxes_list, scores_list):
+ if len(scores) > 0:
+ assert len(scores) == bboxes.shape[0]
+ inds = [
+ iter for iter in range(len(scores))
+ if scores[iter] >= score_thr
+ ]
+ selected_bboxes.append(bboxes[inds, :])
+ else:
+ selected_bboxes.append(bboxes)
+ return selected_bboxes
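+
+# Example (a minimal sketch of the polygon helpers defined above):
+#
+#   >>> a = points2polygon([0, 0, 1, 0, 1, 1, 0, 1])
+#   >>> b = points2polygon([0.5, 0.5, 1.5, 0.5, 1.5, 1.5, 0.5, 1.5])
+#   >>> round(poly_iou(a, b), 4)  # intersection 0.25 / union 1.75
+#   0.1429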
diff --git a/mmocr/core/mask.py b/mmocr/core/mask.py
new file mode 100644
index 00000000..c9b46d19
--- /dev/null
+++ b/mmocr/core/mask.py
@@ -0,0 +1,101 @@
+import cv2
+import numpy as np
+
+import mmocr.utils as utils
+
+
+def points2boundary(points, text_repr_type, text_score=None, min_width=-1):
+ """Convert a text mask represented by point coordinates sequence into a
+ text boundary.
+
+ Args:
+ points (ndarray): Mask index of size (n, 2).
+ text_repr_type (str): Text instance encoding type
+ ('quad' for quadrangle or 'poly' for polygon).
+ text_score (float): Text score.
+ min_width (int): Minimum bounding box width to be converted. Only
+ used when text_repr_type is 'quad'.
+
+ Returns:
+ boundary (list[float]): The text boundary point coordinates (x, y)
+ list. Return None if no text boundary found.
+ """
+ assert isinstance(points, np.ndarray)
+ assert points.shape[1] == 2
+ assert text_repr_type in ['quad', 'poly']
+ assert text_score is None or 0 <= text_score <= 1
+
+ if text_repr_type == 'quad':
+ rect = cv2.minAreaRect(points)
+ vertices = cv2.boxPoints(rect)
+ boundary = []
+ if min(rect[1]) > min_width:
+ boundary = [p for p in vertices.flatten().tolist()]
+
+ elif text_repr_type == 'poly':
+
+ height = np.max(points[:, 1]) + 10
+ width = np.max(points[:, 0]) + 10
+
+ mask = np.zeros((height, width), np.uint8)
+ mask[points[:, 1], points[:, 0]] = 255
+
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
+ cv2.CHAIN_APPROX_SIMPLE)
+ boundary = list(contours[0].flatten().tolist())
+
+ if text_score is not None:
+ boundary = boundary + [text_score]
+ if len(boundary) < 8:
+ return None
+
+ return boundary
+
+
+def seg2boundary(seg, text_repr_type, text_score=None):
+ """Convert a segmentation mask to a text boundary.
+
+ Args:
+ seg (ndarray): The segmentation mask.
+ text_repr_type (str): Text instance encoding type
+ ('quad' for quadrangle or 'poly' for polygon).
+ text_score (float): The text score.
+
+ Returns:
+ boundary (list): The text boundary. Return None if no text found.
+ """
+ assert isinstance(seg, np.ndarray)
+ assert isinstance(text_repr_type, str)
+ assert text_score is None or 0 <= text_score <= 1
+
+ points = np.where(seg)
+ # x, y order
+ points = np.concatenate([points[1], points[0]]).reshape(2, -1).transpose()
+ boundary = None
+ if len(points) != 0:
+ boundary = points2boundary(points, text_repr_type, text_score)
+
+ return boundary
+
+
+def extract_boundary(result):
+ """Extract boundaries and their scores from result.
+
+ Args:
+ result (dict): The detection result with the key 'boundary_result'
+ of one image.
+
+ Returns:
+ boundaries_with_scores (list[list[float]]): The boundary and score
+ list.
+ boundaries (list[list[float]]): The boundary list.
+ scores (list[float]): The boundary score list.
+ """
+ assert isinstance(result, dict)
+ assert 'boundary_result' in result.keys()
+
+ boundaries_with_scores = result['boundary_result']
+ assert utils.is_2dlist(boundaries_with_scores)
+
+ boundaries = [b[:-1] for b in boundaries_with_scores]
+ scores = [b[-1] for b in boundaries_with_scores]
+
+ return (boundaries_with_scores, boundaries, scores)
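+
+# Example (a minimal sketch): a square blob of mask points converted to a
+# quadrangle boundary with its score appended.
+#
+#   >>> import numpy as np
+#   >>> pts = np.array([[0, 0], [0, 10], [10, 10], [10, 0]], dtype=np.int32)
+#   >>> points2boundary(pts, 'quad', text_score=0.9)
+#   # -> [x1, y1, x2, y2, x3, y3, x4, y4, 0.9] (the four corners from
+#   #    cv2.minAreaRect, followed by the score)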
diff --git a/mmocr/core/visualize.py b/mmocr/core/visualize.py
new file mode 100644
index 00000000..f1965970
--- /dev/null
+++ b/mmocr/core/visualize.py
@@ -0,0 +1,419 @@
+import math
+import warnings
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+
+import mmocr.utils as utils
+
+
+def overlay_mask_img(img, mask):
+ """Draw mask boundaries on image for visualization.
+
+ Args:
+ img (ndarray): The input image.
+ mask (ndarray): The instance mask.
+
+ Returns:
+ img (ndarray): The output image with instance boundaries on it.
+ """
+ assert isinstance(img, np.ndarray)
+ assert isinstance(mask, np.ndarray)
+
+ contours, _ = cv2.findContours(
+ mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+ cv2.drawContours(img, contours, -1, (0, 255, 0), 1)
+
+ return img
+
+
+def show_feature(features, names, to_uint8, out_file=None):
+ """Visualize a list of feature maps.
+
+ Args:
+ features (list(ndarray)): The feature map list.
+ names (list(str)): The visualized title list.
+ to_uint8 (list(1|0)): The list indicating whether to convert
+ feature maps to uint8.
+ out_file (str): The output file name. If set to None,
+ the output image will be shown without saving.
+ """
+ assert utils.is_ndarray_list(features)
+ assert utils.is_type_list(names, str)
+ assert utils.is_type_list(to_uint8, int)
+ assert utils.is_none_or_type(out_file, str)
+ assert utils.equal_len(features, names, to_uint8)
+
+ num = len(features)
+ row = col = math.ceil(math.sqrt(num))
+
+ for i, (f, n) in enumerate(zip(features, names)):
+ plt.subplot(row, col, i + 1)
+ plt.title(n)
+ if to_uint8[i]:
+ f = f.astype(np.uint8)
+ plt.imshow(f)
+ if out_file is None:
+ plt.show()
+ else:
+ plt.savefig(out_file)
+
+
+def show_img_boundary(img, boundary):
+ """Show image and instance boundaires.
+
+ Args:
+ img (ndarray): The input image.
+ boundary (list[float or int]): The input boundary.
+ """
+ assert isinstance(img, np.ndarray)
+ assert utils.is_type_list(boundary, int) or utils.is_type_list(
+ boundary, float)
+
+ cv2.polylines(
+ img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)],
+ True,
+ color=(0, 255, 0),
+ thickness=1)
+ plt.imshow(img)
+ plt.show()
+
+
+def show_pred_gt(preds,
+ gts,
+ show=False,
+ win_name='',
+ wait_time=0,
+ out_file=None):
+ """Show detection and ground truth for one image.
+
+ Args:
+ preds (list[list[float]]): The detection boundary list.
+ gts (list[list[float]]): The ground truth boundary list.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): The value of waitKey param.
+ out_file (str): The filename of the output.
+ """
+ assert utils.is_2dlist(preds)
+ assert utils.is_2dlist(gts)
+ assert isinstance(show, bool)
+ assert isinstance(win_name, str)
+ assert isinstance(wait_time, int)
+ assert utils.is_none_or_type(out_file, str)
+
+ p_xy = [p for boundary in preds for p in boundary]
+ gt_xy = [g for gt in gts for g in gt]
+
+ max_xy = np.max(np.array(p_xy + gt_xy).reshape(-1, 2), axis=0)
+
+ width = int(max_xy[0]) + 100
+ height = int(max_xy[1]) + 100
+
+ img = np.ones((height, width, 3), np.uint8) * 255
+ pred_color = mmcv.color_val('red')
+ gt_color = mmcv.color_val('blue')
+ thickness = 1
+
+ for boundary in preds:
+ cv2.polylines(
+ img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)],
+ True,
+ color=pred_color,
+ thickness=thickness)
+ for gt in gts:
+ cv2.polylines(
+ img, [np.array(gt).astype(np.int32).reshape(-1, 1, 2)],
+ True,
+ color=gt_color,
+ thickness=thickness)
+ if show:
+ mmcv.imshow(img, win_name, wait_time)
+ if out_file is not None:
+ mmcv.imwrite(img, out_file)
+
+ return img
+
+
+def imshow_pred_boundary(img,
+ boundaries_with_scores,
+ labels,
+ score_thr=0,
+ boundary_color='blue',
+ text_color='blue',
+ thickness=1,
+ font_scale=0.5,
+ show=True,
+ win_name='',
+ wait_time=0,
+ out_file=None,
+ show_score=False):
+ """Draw boundaries and class labels (with scores) on an image.
+
+ Args:
+ img (str or ndarray): The image to be displayed.
+ boundaries_with_scores (list[list[float]]): Boundaries with scores.
+ labels (list[int]): Labels of boundaries.
+ score_thr (float): Minimum score of boundaries to be shown.
+ boundary_color (str or tuple or :obj:`Color`): Color of boundaries.
+ text_color (str or tuple or :obj:`Color`): Color of texts.
+ thickness (int): Thickness of lines.
+ font_scale (float): Font scales of texts.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ out_file (str or None): The filename of the output.
+ show_score (bool): Whether to show text instance score.
+ """
+ assert isinstance(img, (str, np.ndarray))
+ assert utils.is_2dlist(boundaries_with_scores)
+ assert utils.is_type_list(labels, int)
+ assert utils.equal_len(boundaries_with_scores, labels)
+ if len(boundaries_with_scores) == 0:
+ warnings.warn('0 text found in ' + (out_file or 'image'))
+ return
+
+ utils.valid_boundary(boundaries_with_scores[0])
+ img = mmcv.imread(img)
+
+ scores = np.array([b[-1] for b in boundaries_with_scores])
+ inds = scores > score_thr
+ boundaries = [boundaries_with_scores[i][:-1] for i in np.where(inds)[0]]
+ scores = [scores[i] for i in np.where(inds)[0]]
+ labels = [labels[i] for i in np.where(inds)[0]]
+
+ boundary_color = mmcv.color_val(boundary_color)
+ text_color = mmcv.color_val(text_color)
+
+ for boundary, score, label in zip(boundaries, scores, labels):
+ boundary_int = np.array(boundary).astype(np.int32)
+
+ cv2.polylines(
+ img, [boundary_int.reshape(-1, 1, 2)],
+ True,
+ color=boundary_color,
+ thickness=thickness)
+
+ if show_score:
+ label_text = f'{score:.02f}'
+ cv2.putText(img, label_text,
+ (boundary_int[0], boundary_int[1] - 2),
+ cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
+ if show:
+ mmcv.imshow(img, win_name, wait_time)
+ if out_file is not None:
+ mmcv.imwrite(img, out_file)
+
+ return img
+
+
+def imshow_text_char_boundary(img,
+ text_quads,
+ boundaries,
+ char_quads,
+ chars,
+ show=False,
+ thickness=1,
+ font_scale=0.5,
+ win_name='',
+ wait_time=-1,
+ out_file=None):
+ """Draw text boxes and char boxes on img.
+
+ Args:
+ img (str or ndarray): The img to be displayed.
+ text_quads (list[list[int|float]]): The text boxes.
+ boundaries (list[list[int|float]]): The boundary list.
+ char_quads (list[list[list[int|float]]]): A 2d list of char boxes.
+ char_quads[i] is for the ith text, and char_quads[i][j] is the jth
+ char of the ith text.
+ chars (list[list[char]]): The string for each text box.
+ thickness (int): Thickness of lines.
+ font_scale (float): Font scales of texts.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ out_file (str or None): The filename of the output.
+ """
+ assert isinstance(img, (np.ndarray, str))
+ assert utils.is_2dlist(text_quads)
+ assert utils.is_2dlist(boundaries)
+ assert utils.is_3dlist(char_quads)
+ assert utils.is_2dlist(chars)
+ assert utils.equal_len(text_quads, char_quads, boundaries)
+
+ img = mmcv.imread(img)
+ char_color = [mmcv.color_val('blue'), mmcv.color_val('green')]
+ text_color = mmcv.color_val('red')
+ text_inx = 0
+ for text_box, boundary, char_box, txt in zip(text_quads, boundaries,
+ char_quads, chars):
+ text_box = np.array(text_box)
+ boundary = np.array(boundary)
+
+ text_box = text_box.reshape(-1, 2).astype(np.int32)
+ cv2.polylines(
+ img, [text_box.reshape(-1, 1, 2)],
+ True,
+ color=text_color,
+ thickness=thickness)
+ if boundary.shape[0] > 0:
+ cv2.polylines(
+ img, [boundary.reshape(-1, 1, 2)],
+ True,
+ color=text_color,
+ thickness=thickness)
+
+ for b in char_box:
+ b = np.array(b)
+ c = char_color[text_inx % 2]
+ b = b.astype(np.int32)
+ cv2.polylines(
+ img, [b.reshape(-1, 1, 2)], True, color=c, thickness=thickness)
+
+ label_text = ''.join(txt)
+ cv2.putText(img, label_text, (text_box[0, 0], text_box[0, 1] - 2),
+ cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
+ text_inx = text_inx + 1
+
+ if show:
+ mmcv.imshow(img, win_name, wait_time)
+ if out_file is not None:
+ mmcv.imwrite(img, out_file)
+
+ return img
+
+
+def tile_image(images):
+ """Combined multiple images to one vertically.
+
+ Args:
+ images (list[np.ndarray]): Images to be combined.
+
+ Returns:
+ vis_img (np.ndarray): The vertically stacked image.
+ """
+ assert isinstance(images, list)
+ assert len(images) > 0
+
+ for i, _ in enumerate(images):
+ if len(images[i].shape) == 2:
+ images[i] = cv2.cvtColor(images[i], cv2.COLOR_GRAY2BGR)
+
+ widths = [img.shape[1] for img in images]
+ heights = [img.shape[0] for img in images]
+ h, w = sum(heights), max(widths)
+ vis_img = np.zeros((h, w, 3), dtype=np.uint8)
+
+ offset_y = 0
+ for image in images:
+ img_h, img_w = image.shape[:2]
+ vis_img[offset_y:(offset_y + img_h), 0:img_w, :] = image
+ offset_y += img_h
+
+ return vis_img
+
+
+def imshow_text_label(img,
+ pred_label,
+ gt_label,
+ show=False,
+ win_name='',
+ wait_time=-1,
+ out_file=None):
+ """Draw predicted texts and ground truth texts on images.
+
+ Args:
+ img (str or np.ndarray): Image filename or loaded image.
+ pred_label (str): Predicted texts.
+ gt_label (str): Ground truth texts.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ out_file (str): The filename of the output.
+ """
+ assert isinstance(img, (np.ndarray, str))
+ assert isinstance(pred_label, str)
+ assert isinstance(gt_label, str)
+ assert isinstance(show, bool)
+ assert isinstance(win_name, str)
+ assert isinstance(wait_time, int)
+
+ img = mmcv.imread(img)
+
+ src_h, src_w = img.shape[:2]
+ resize_height = 64
+ resize_width = int(1.0 * src_w / src_h * resize_height)
+ img = cv2.resize(img, (resize_width, resize_height))
+ h, w = img.shape[:2]
+ pred_img = np.ones((h, w, 3), dtype=np.uint8) * 255
+ gt_img = np.ones((h, w, 3), dtype=np.uint8) * 255
+
+ cv2.putText(pred_img, pred_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9,
+ (0, 0, 255), 2)
+ images = [pred_img, img]
+
+ if gt_label != '':
+ cv2.putText(gt_img, gt_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9,
+ (255, 0, 0), 2)
+ images.append(gt_img)
+
+ img = tile_image(images)
+
+ if show:
+ mmcv.imshow(img, win_name, wait_time)
+ if out_file is not None:
+ mmcv.imwrite(img, out_file)
+
+ return img
+
+
+def imshow_edge_node(img,
+ result,
+ boxes,
+ idx_to_cls={},
+ show=False,
+ win_name='',
+ wait_time=-1,
+ out_file=None):
+ """Draw text boxes and predicted KIE node labels (with scores) on an
+ image.
+
+ Args:
+ img (str or ndarray): The image to be displayed.
+ result (dict): The KIE result with key 'nodes', a tensor of node
+ classification scores of shape (box_num, class_num).
+ boxes (list): List of text boxes, each as [x1, y1, x2, y2].
+ idx_to_cls (dict): Mapping from node label index (as str) to
+ class name.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ out_file (str or None): The filename of the output.
+ """
+
+ img = mmcv.imread(img)
+ h, w = img.shape[:2]
+
+ pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255
+ max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
+ node_pred_label = max_idx.numpy().tolist()
+ node_pred_score = max_value.numpy().tolist()
+
+ for i, box in enumerate(boxes):
+ new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
+ [box[0], box[3]]]
+ Pts = np.array([new_box], np.int32)
+ cv2.polylines(
+ img, [Pts.reshape((-1, 1, 2))],
+ True,
+ color=(255, 255, 0),
+ thickness=1)
+ x_min = int(min([point[0] for point in new_box]))
+ y_min = int(min([point[1] for point in new_box]))
+
+ pred_label = str(node_pred_label[i])
+ if pred_label in idx_to_cls:
+ pred_label = idx_to_cls[pred_label]
+ pred_score = '{:.2f}'.format(node_pred_score[i])
+ text = pred_label + '(' + pred_score + ')'
+ cv2.putText(pred_img, text, (x_min * 2, y_min),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
+
+ vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
+ vis_img[:, :w] = img
+ vis_img[:, w:] = pred_img
+
+ if show:
+ mmcv.imshow(vis_img, win_name, wait_time)
+ if out_file is not None:
+ mmcv.imwrite(vis_img, out_file)
+
+ return vis_img
diff --git a/mmocr/datasets/__init__.py b/mmocr/datasets/__init__.py
new file mode 100644
index 00000000..1dbba21c
--- /dev/null
+++ b/mmocr/datasets/__init__.py
@@ -0,0 +1,15 @@
+from mmdet.datasets.builder import DATASETS, build_dataloader, build_dataset
+from .base_dataset import BaseDataset
+from .icdar_dataset import IcdarDataset
+from .kie_dataset import KIEDataset
+from .ocr_dataset import OCRDataset
+from .ocr_seg_dataset import OCRSegDataset
+from .pipelines import CustomFormatBundle, DBNetTargets
+from .text_det_dataset import TextDetDataset
+from .utils import * # noqa: F401,F403
+
+__all__ = [
+ 'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
+ 'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
+ 'DBNetTargets', 'OCRSegDataset', 'KIEDataset'
+]
diff --git a/mmocr/datasets/base_dataset.py b/mmocr/datasets/base_dataset.py
new file mode 100644
index 00000000..f5c8f9d8
--- /dev/null
+++ b/mmocr/datasets/base_dataset.py
@@ -0,0 +1,166 @@
+import numpy as np
+from mmcv.utils import print_log
+from torch.utils.data import Dataset
+
+from mmdet.datasets.builder import DATASETS
+from mmdet.datasets.pipelines import Compose
+from mmocr.datasets.builder import build_loader
+
+
+@DATASETS.register_module()
+class BaseDataset(Dataset):
+ """Custom dataset for text detection, text recognition, and their
+ downstream tasks.
+
+ 1. The text detection annotation format is as follows:
+ The `annotations` field is optional for testing.
+ (The example below is one line of the annotation file,
+ with its JSON string converted to a dict for readability.)
+
+ {
+ "file_name": "sample.jpg",
+ "height": 1080,
+ "width": 960,
+ "annotations":
+ [
+ {
+ "iscrowd": 0,
+ "category_id": 1,
+ "bbox": [357.0, 667.0, 804.0, 100.0],
+ "segmentation": [[361, 667, 710, 670,
+ 72, 767, 357, 763]]
+ }
+ ]
+ }
+
+ 2. The two text recognition annotation formats are as follows:
+ The `x1,y1,x2,y2,x3,y3,x4,y4` field is used for online crop
+ augmentation during training.
+
+ format1: sample.jpg hello
+ format2: sample.jpg 20 20 100 20 100 40 20 40 hello
+
+ Args:
+ ann_file (str): Annotation file path.
+ pipeline (list[dict]): Processing pipeline.
+ loader (dict): Dictionary to construct loader
+ to load annotation infos.
+ img_prefix (str, optional): Image prefix to generate full
+ image path.
+ test_mode (bool, optional): If set True, try...except will
+ be turned off in __getitem__.
+ """
+
+ def __init__(self,
+ ann_file,
+ loader,
+ pipeline,
+ img_prefix='',
+ test_mode=False):
+ super().__init__()
+ self.test_mode = test_mode
+ self.img_prefix = img_prefix
+ self.ann_file = ann_file
+ # load annotations
+ loader.update(ann_file=ann_file)
+ self.data_infos = build_loader(loader)
+ # processing pipeline
+ self.pipeline = Compose(pipeline)
+ # set group flag and CLASSES; they carry no real meaning for text
+ # detection and recognition but are expected by the mmdet data pipeline
+ self._set_group_flag()
+ self.CLASSES = 0
+
+ def __len__(self):
+ return len(self.data_infos)
+
+ def _set_group_flag(self):
+ """Set flag."""
+ self.flag = np.zeros(len(self), dtype=np.uint8)
+
+ def pre_pipeline(self, results):
+ """Prepare results dict for pipeline."""
+ results['img_prefix'] = self.img_prefix
+
+ def prepare_train_img(self, index):
+ """Get training data and annotations from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys
+ introduced by pipeline.
+ """
+ img_info = self.data_infos[index]
+ results = dict(img_info=img_info)
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ def prepare_test_img(self, index):
+ """Get testing data from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Testing data after pipeline with new keys introduced by
+ pipeline.
+ """
+ return self.prepare_train_img(index)
+
+ def _log_error_index(self, index):
+ """Logging data info of bad index."""
+ try:
+ data_info = self.data_infos[index]
+ img_prefix = self.img_prefix
+ print_log(f'Warning: skip broken file {data_info} '
+ f'with img_prefix {img_prefix}')
+ except Exception as e:
+ print_log(f'load index {index} with error {e}')
+
+ def _get_next_index(self, index):
+ """Get next index from dataset."""
+ self._log_error_index(index)
+ index = (index + 1) % len(self)
+ return index
+
+ def __getitem__(self, index):
+ """Get training/test data from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Training/test data.
+ """
+ if self.test_mode:
+ return self.prepare_test_img(index)
+
+ while True:
+ try:
+ data = self.prepare_train_img(index)
+ if data is None:
+ raise Exception('prepared train data empty')
+ break
+ except Exception as e:
+ print_log(f'prepare index {index} with error {e}')
+ index = self._get_next_index(index)
+ return data
+
+ def format_results(self, results, **kwargs):
+ """Placeholder to format result to dataset-specific output."""
+ pass
+
+ def evaluate(self, results, metric=None, logger=None, **kwargs):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ Returns:
+ dict[str: float]
+ """
+ raise NotImplementedError
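+
+
+ # Hedged config sketch (not part of the original patch) of a concrete
+ # dataset built on BaseDataset; the loader/parser type names and keys are
+ # assumptions for illustration only:
+ #
+ # dataset = dict(
+ #     type='OCRDataset',
+ #     ann_file='data/train_label.txt',
+ #     img_prefix='data/imgs',
+ #     loader=dict(
+ #         type='HardDiskLoader',
+ #         parser=dict(type='LineStrParser', keys=['filename', 'text'])),
+ #     pipeline=train_pipeline,
+ #     test_mode=False)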
diff --git a/mmocr/datasets/builder.py b/mmocr/datasets/builder.py
new file mode 100644
index 00000000..e7bcf423
--- /dev/null
+++ b/mmocr/datasets/builder.py
@@ -0,0 +1,14 @@
+from mmcv.utils import Registry, build_from_cfg
+
+LOADERS = Registry('loader')
+PARSERS = Registry('parser')
+
+
+def build_loader(cfg):
+ """Build anno file loader."""
+ return build_from_cfg(cfg, LOADERS)
+
+
+def build_parser(cfg):
+ """Build anno file parser."""
+ return build_from_cfg(cfg, PARSERS)
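+
+
+ # A minimal, hypothetical usage sketch (not part of the original patch);
+ # the 'HardDiskLoader' and 'LineJsonParser' type names are assumptions
+ # used for illustration only:
+ #
+ # loader_cfg = dict(
+ #     type='HardDiskLoader',
+ #     repeat=1,
+ #     parser=dict(
+ #         type='LineJsonParser',
+ #         keys=['file_name', 'height', 'width', 'annotations']))
+ # data_infos = build_loader(loader_cfg)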
diff --git a/mmocr/datasets/icdar_dataset.py b/mmocr/datasets/icdar_dataset.py
new file mode 100644
index 00000000..34a97584
--- /dev/null
+++ b/mmocr/datasets/icdar_dataset.py
@@ -0,0 +1,158 @@
+import numpy as np
+from pycocotools.coco import COCO
+
+import mmocr.utils as utils
+from mmdet.datasets.builder import DATASETS
+from mmdet.datasets.coco import CocoDataset
+from mmocr.core.evaluation.hmean import eval_hmean
+
+
+@DATASETS.register_module()
+class IcdarDataset(CocoDataset):
+ CLASSES = ('text', )
+
+ def __init__(self,
+ ann_file,
+ pipeline,
+ classes=None,
+ data_root=None,
+ img_prefix='',
+ seg_prefix=None,
+ proposal_file=None,
+ test_mode=False,
+ filter_empty_gt=True,
+ select_first_k=-1):
+ # select first k images for fast debugging.
+ self.select_first_k = select_first_k
+
+ super().__init__(ann_file, pipeline, classes, data_root, img_prefix,
+ seg_prefix, proposal_file, test_mode, filter_empty_gt)
+
+ def load_annotations(self, ann_file):
+ """Load annotation from COCO style annotation file.
+
+ Args:
+ ann_file (str): Path of annotation file.
+
+ Returns:
+ list[dict]: Annotation info from COCO api.
+ """
+
+ self.coco = COCO(ann_file)
+ self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.img_ids = self.coco.get_img_ids()
+ data_infos = []
+
+ count = 0
+ for i in self.img_ids:
+ info = self.coco.load_imgs([i])[0]
+ info['filename'] = info['file_name']
+ data_infos.append(info)
+ count += 1
+ if 0 < self.select_first_k <= count:
+ break
+ return data_infos
+
+ def _parse_ann_info(self, img_info, ann_info):
+ """Parse bbox and mask annotation.
+
+ Args:
+ ann_info (list[dict]): Annotation info of an image.
+
+ Returns:
+ dict: A dict containing the following keys: bboxes, bboxes_ignore,
+ labels, masks, masks_ignore, seg_map. "masks" and
+ "masks_ignore" are represented by polygon boundary
+ point sequences.
+ """
+ gt_bboxes = []
+ gt_labels = []
+ gt_bboxes_ignore = []
+ gt_masks_ignore = []
+ gt_masks_ann = []
+
+ for ann in ann_info:
+ if ann.get('ignore', False):
+ continue
+ x1, y1, w, h = ann['bbox']
+ if ann['area'] <= 0 or w < 1 or h < 1:
+ continue
+ if ann['category_id'] not in self.cat_ids:
+ continue
+ bbox = [x1, y1, x1 + w, y1 + h]
+ if ann.get('iscrowd', False):
+ gt_bboxes_ignore.append(bbox)
+ gt_masks_ignore.append(ann.get('segmentation', None))
+
+ else:
+ gt_bboxes.append(bbox)
+ gt_labels.append(self.cat2label[ann['category_id']])
+ gt_masks_ann.append(ann.get('segmentation', None))
+ if gt_bboxes:
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ else:
+ gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+ gt_labels = np.array([], dtype=np.int64)
+
+ if gt_bboxes_ignore:
+ gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+ else:
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+ seg_map = img_info['filename'].replace('jpg', 'png')
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks_ignore=gt_masks_ignore,
+ masks=gt_masks_ann,
+ seg_map=seg_map)
+
+ return ann
+
+ def evaluate(self,
+ results,
+ metric='hmean-iou',
+ logger=None,
+ score_thr=0.3,
+ rank_list=None,
+ **kwargs):
+ """Evaluate the hmean metric.
+
+ Args:
+ results (list[dict]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ rank_list (str): Path of a json file used to save the per-image
+ evaluation results after ranking.
+ Returns:
+ dict[dict[str: float]]: The evaluation results.
+ """
+ assert utils.is_type_list(results, dict)
+
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['hmean-iou', 'hmean-ic13']
+ metrics = set(metrics) & set(allowed_metrics)
+
+ img_infos = []
+ ann_infos = []
+ for i in range(len(self)):
+ img_info = {'filename': self.data_infos[i]['file_name']}
+ img_infos.append(img_info)
+ ann_infos.append(self.get_ann_info(i))
+
+ eval_results = eval_hmean(
+ results,
+ img_infos,
+ ann_infos,
+ metrics=metrics,
+ score_thr=score_thr,
+ logger=logger,
+ rank_list=rank_list)
+
+ return eval_results
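+
+
+ # Hedged usage sketch (not part of the original patch): assuming `dataset`
+ # is a built IcdarDataset and `det_results` is the list of per-image result
+ # dicts produced by a text detector, hmean evaluation could look like:
+ #
+ # metrics = dataset.evaluate(det_results, metric='hmean-iou', score_thr=0.3)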
diff --git a/mmocr/datasets/kie_dataset.py b/mmocr/datasets/kie_dataset.py
new file mode 100644
index 00000000..79868a68
--- /dev/null
+++ b/mmocr/datasets/kie_dataset.py
@@ -0,0 +1,218 @@
+import copy
+from os import path as osp
+
+import numpy as np
+import torch
+
+import mmocr.utils as utils
+from mmdet.datasets.builder import DATASETS
+from mmocr.core import compute_f1_score
+from mmocr.datasets.base_dataset import BaseDataset
+from mmocr.datasets.pipelines.crop import sort_vertex
+
+
+@DATASETS.register_module()
+class KIEDataset(BaseDataset):
+ """
+ Args:
+ ann_file (str): Annotation file path.
+ pipeline (list[dict]): Processing pipeline.
+ loader (dict): Dictionary to construct loader
+ to load annotation infos.
+ img_prefix (str, optional): Image prefix to generate full
+ image path.
+ test_mode (bool, optional): If True, the try-except that skips
+ broken samples in __getitem__ is disabled.
+ dict_file (str): Character dict file path.
+ norm (float): Norm to map value from one range to another.
+ """
+
+ def __init__(self,
+ ann_file,
+ loader,
+ dict_file,
+ img_prefix='',
+ pipeline=None,
+ norm=10.,
+ directed=False,
+ test_mode=True,
+ **kwargs):
+ super().__init__(
+ ann_file,
+ loader,
+ pipeline,
+ img_prefix=img_prefix,
+ test_mode=test_mode)
+ assert osp.exists(dict_file)
+
+ self.norm = norm
+ self.directed = directed
+
+ self.dict = {'': 0}
+ with open(dict_file, 'r') as fr:
+ for idx, line in enumerate(fr, 1):
+ self.dict[line.strip()] = idx
+
+ def pre_pipeline(self, results):
+ results['img_prefix'] = self.img_prefix
+ results['bbox_fields'] = []
+
+ def _parse_anno_info(self, annotations):
+ """Parse annotations of boxes, texts and labels for one image.
+ Args:
+ annotations (list[dict]): Annotations of one image, where
+ each dict is for one character.
+
+ Returns:
+ dict: A dict containing the following keys:
+
+ - bboxes (np.ndarray): Bbox in one image with shape:
+ box_num * 4.
+ - relations (np.ndarray): Relations between bbox with shape:
+ box_num * box_num * D.
+ - texts (np.ndarray): Text index with shape:
+ box_num * text_max_len.
+ - labels (np.ndarray): Box Labels with shape:
+ box_num * (box_num + 1).
+ """
+
+ assert utils.is_type_list(annotations, dict)
+ assert 'box' in annotations[0]
+ assert 'text' in annotations[0]
+ assert 'label' in annotations[0]
+
+ boxes, texts, text_inds, labels, edges = [], [], [], [], []
+ for ann in annotations:
+ box = ann['box']
+ x_list, y_list = box[0:8:2], box[1:9:2]
+ sorted_x_list, sorted_y_list = sort_vertex(x_list, y_list)
+ sorted_box = []
+ for x, y in zip(sorted_x_list, sorted_y_list):
+ sorted_box.append(x)
+ sorted_box.append(y)
+ boxes.append(sorted_box)
+ text = ann['text']
+ texts.append(ann['text'])
+ text_ind = [self.dict[c] for c in text if c in self.dict]
+ text_inds.append(text_ind)
+ labels.append(ann['label'])
+ edges.append(ann.get('edge', 0))
+
+ ann_infos = dict(
+ boxes=boxes,
+ texts=texts,
+ text_inds=text_inds,
+ edges=edges,
+ labels=labels)
+
+ return self.list_to_numpy(ann_infos)
+
+ def prepare_train_img(self, index):
+ """Get training data and annotations from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys
+ introduced by pipeline.
+ """
+ img_ann_info = self.data_infos[index]
+ img_info = {
+ 'filename': img_ann_info['file_name'],
+ 'height': img_ann_info['height'],
+ 'width': img_ann_info['width']
+ }
+ ann_info = self._parse_anno_info(img_ann_info['annotations'])
+ results = dict(img_info=img_info, ann_info=ann_info)
+
+ self.pre_pipeline(results)
+
+ return self.pipeline(results)
+
+ def evaluate(self,
+ results,
+ metric='macro_f1',
+ metric_options=dict(macro_f1=dict(ignores=[])),
+ **kwargs):
+ # allow some kwargs to pass through
+ assert set(kwargs).issubset(['logger'])
+
+ # Protect ``metric_options`` since it uses mutable value as default
+ metric_options = copy.deepcopy(metric_options)
+
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['macro_f1']
+ for m in metrics:
+ if m not in allowed_metrics:
+ raise KeyError(f'metric {m} is not supported')
+
+ return self.compute_macro_f1(results, **metric_options['macro_f1'])
+
+ def compute_macro_f1(self, results, ignores=None):
+ ignores = ignores if ignores is not None else []
+ node_preds = []
+ node_gts = []
+ for idx, result in enumerate(results):
+ node_preds.append(result['nodes'])
+ box_ann_infos = self.data_infos[idx]['annotations']
+ node_gt = [box_ann_info['label'] for box_ann_info in box_ann_infos]
+ node_gts.append(torch.Tensor(node_gt))
+
+ node_preds = torch.cat(node_preds)
+ node_gts = torch.cat(node_gts).int().to(node_preds.device)
+
+ node_f1s = compute_f1_score(node_preds, node_gts, ignores)
+
+ return {
+ 'macro_f1': node_f1s.mean(),
+ }
+
+ def list_to_numpy(self, ann_infos):
+ """Convert bboxes, relations, texts and labels to ndarray."""
+ boxes, text_inds = ann_infos['boxes'], ann_infos['text_inds']
+ boxes = np.array(boxes, np.int32)
+ relations, bboxes = self.compute_relation(boxes)
+
+ labels = ann_infos.get('labels', None)
+ if labels is not None:
+ labels = np.array(labels, np.int32)
+ edges = ann_infos.get('edges', None)
+ if edges is not None:
+ labels = labels[:, None]
+ edges = np.array(edges)
+ edges = (edges[:, None] == edges[None, :]).astype(np.int32)
+ if self.directed:
+ edges = (edges & labels == 1).astype(np.int32)
+ np.fill_diagonal(edges, -1)
+ labels = np.concatenate([labels, edges], -1)
+ padded_text_inds = self.pad_text_indices(text_inds)
+
+ return dict(
+ bboxes=bboxes,
+ relations=relations,
+ texts=padded_text_inds,
+ labels=labels)
+
+ def pad_text_indices(self, text_inds):
+ """Pad text index to same length."""
+ max_len = max([len(text_ind) for text_ind in text_inds])
+ padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
+ for idx, text_ind in enumerate(text_inds):
+ padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
+ return padded_text_inds
+
+ def compute_relation(self, boxes):
+ """Compute relation between every two boxes."""
+ x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
+ x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
+ ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
+ dxs = (x1s[:, 0][None] - x1s) / self.norm
+ dys = (y1s[:, 0][None] - y1s) / self.norm
+ xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
+ whs = ws / hs + np.zeros_like(xhhs)
+ relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
+ bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
+ return relations, bboxes
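+
+
+ # Illustrative sketch (an assumption, not part of the original patch) of
+ # what compute_relation produces for two quadrangles given as
+ # [x1, y1, x2, y2, x3, y3, x4, y4] with vertex 0 at left-top:
+ #
+ # boxes = np.array([[0, 0, 10, 0, 10, 5, 0, 5],
+ #                   [20, 0, 40, 0, 40, 10, 20, 10]], np.int32)
+ # relations, bboxes = kie_dataset.compute_relation(boxes)
+ # # relations: (2, 2, 5) array of normalized offsets and size ratios
+ # # per box pair; bboxes: (2, 4) array of [x1, y1, x2, y2] per box.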
diff --git a/mmocr/datasets/ocr_dataset.py b/mmocr/datasets/ocr_dataset.py
new file mode 100644
index 00000000..4ec6d962
--- /dev/null
+++ b/mmocr/datasets/ocr_dataset.py
@@ -0,0 +1,34 @@
+from mmdet.datasets.builder import DATASETS
+from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
+from mmocr.datasets.base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class OCRDataset(BaseDataset):
+
+ def pre_pipeline(self, results):
+ results['img_prefix'] = self.img_prefix
+ results['text'] = results['img_info']['text']
+
+ def evaluate(self, results, metric='acc', logger=None, **kwargs):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ Returns:
+ dict[str: float]
+ """
+ gt_texts = []
+ pred_texts = []
+ for i in range(len(self)):
+ item_info = self.data_infos[i]
+ text = item_info['text']
+ gt_texts.append(text)
+ pred_texts.append(results[i]['text'])
+
+ eval_results = eval_ocr_metric(pred_texts, gt_texts)
+
+ return eval_results
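+
+
+ # Hedged usage sketch (not part of the original patch): recognition results
+ # are expected as a list of dicts carrying a 'text' field, e.g.
+ # metrics = dataset.evaluate([{'text': 'hello'}, ...], metric='acc')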
diff --git a/mmocr/datasets/ocr_seg_dataset.py b/mmocr/datasets/ocr_seg_dataset.py
new file mode 100644
index 00000000..638de163
--- /dev/null
+++ b/mmocr/datasets/ocr_seg_dataset.py
@@ -0,0 +1,89 @@
+import mmocr.utils as utils
+from mmdet.datasets.builder import DATASETS
+from mmocr.datasets.ocr_dataset import OCRDataset
+
+
+@DATASETS.register_module()
+class OCRSegDataset(OCRDataset):
+
+ def pre_pipeline(self, results):
+ results['img_prefix'] = self.img_prefix
+
+ def _parse_anno_info(self, annotations):
+ """Parse char boxes annotations.
+ Args:
+ annotations (list[dict]): Annotations of one image, where
+ each dict is for one character.
+
+ Returns:
+ dict: A dict containing the following keys:
+
+ - chars (list[str]): List of character strings.
+ - char_rects (list[list[float]]): List of char box, with each
+ in style of rectangle: [x_min, y_min, x_max, y_max].
+ - char_quads (list[list[float]]): List of char box, with each
+ in style of quadrangle: [x1, y1, x2, y2, x3, y3, x4, y4].
+ """
+
+ assert utils.is_type_list(annotations, dict)
+ assert 'char_box' in annotations[0]
+ assert 'char_text' in annotations[0]
+ assert len(annotations[0]['char_box']) in [4, 8]
+
+ chars, char_rects, char_quads = [], [], []
+ for ann in annotations:
+ char_box = ann['char_box']
+ if len(char_box) == 4:
+ char_box_type = ann.get('char_box_type', 'xyxy')
+ if char_box_type == 'xyxy':
+ char_rects.append(char_box)
+ char_quads.append([
+ char_box[0], char_box[1], char_box[2], char_box[1],
+ char_box[2], char_box[3], char_box[0], char_box[3]
+ ])
+ elif char_box_type == 'xywh':
+ x1, y1, w, h = char_box
+ x2 = x1 + w
+ y2 = y1 + h
+ char_rects.append([x1, y1, x2, y2])
+ char_quads.append([x1, y1, x2, y1, x2, y2, x1, y2])
+ else:
+ raise ValueError(f'invalid char_box_type {char_box_type}')
+ elif len(char_box) == 8:
+ x_list, y_list = [], []
+ for i in range(4):
+ x_list.append(char_box[2 * i])
+ y_list.append(char_box[2 * i + 1])
+ x_max, x_min = max(x_list), min(x_list)
+ y_max, y_min = max(y_list), min(y_list)
+ char_rects.append([x_min, y_min, x_max, y_max])
+ char_quads.append(char_box)
+ else:
+ raise Exception(
+ f'invalid num in char box: {len(char_box)} not in (4, 8)')
+ chars.append(ann['char_text'])
+
+ ann = dict(chars=chars, char_rects=char_rects, char_quads=char_quads)
+
+ return ann
+
+ def prepare_train_img(self, index):
+ """Get training data and annotations from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys
+ introduced by pipeline.
+ """
+ img_ann_info = self.data_infos[index]
+ img_info = {
+ 'filename': img_ann_info['file_name'],
+ }
+ ann_info = self._parse_anno_info(img_ann_info['annotations'])
+ results = dict(img_info=img_info, ann_info=ann_info)
+
+ self.pre_pipeline(results)
+
+ return self.pipeline(results)
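+
+
+ # Illustrative annotation sketch (an assumption, not part of the original
+ # patch): each line of the annotation file is expected to provide
+ # character-level boxes and texts roughly like:
+ #
+ # {
+ #     "file_name": "sample.jpg",
+ #     "annotations": [
+ #         {"char_text": "F", "char_box": [11, 5, 45, 5, 45, 40, 11, 40]},
+ #         {"char_text": "r", "char_box": [48, 5, 80, 5, 80, 40, 48, 40]}
+ #     ]
+ # }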
diff --git a/mmocr/datasets/pipelines/__init__.py b/mmocr/datasets/pipelines/__init__.py
new file mode 100644
index 00000000..4eca60a3
--- /dev/null
+++ b/mmocr/datasets/pipelines/__init__.py
@@ -0,0 +1,26 @@
+from .box_utils import sort_vertex
+from .custom_format_bundle import CustomFormatBundle
+from .dbnet_transforms import EastRandomCrop, ImgAug
+from .kie_transforms import KIEFormatBundle
+from .loading import LoadTextAnnotations
+from .ocr_seg_targets import OCRSegTargets
+from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
+ OpencvToPil, PilToOpencv, RandomPaddingOCR,
+ RandomRotateImageBox, ResizeOCR, ToTensorOCR)
+from .test_time_aug import MultiRotateAugOCR
+from .textdet_targets import DBNetTargets, PANetTargets, TextSnakeTargets
+from .transforms import (ColorJitter, RandomCropInstances,
+ RandomCropPolyInstances, RandomRotatePolyInstances,
+ RandomRotateTextDet, ScaleAspectJitter,
+ SquareResizePad)
+
+__all__ = [
+ 'LoadTextAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR',
+ 'ToTensorOCR', 'CustomFormatBundle', 'DBNetTargets', 'PANetTargets',
+ 'ColorJitter', 'RandomCropInstances', 'RandomRotateTextDet',
+ 'ScaleAspectJitter', 'MultiRotateAugOCR', 'OCRSegTargets', 'FancyPCA',
+ 'RandomCropPolyInstances', 'RandomRotatePolyInstances', 'RandomPaddingOCR',
+ 'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil',
+ 'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets',
+ 'sort_vertex'
+]
diff --git a/mmocr/datasets/pipelines/box_utils.py b/mmocr/datasets/pipelines/box_utils.py
new file mode 100644
index 00000000..23af4817
--- /dev/null
+++ b/mmocr/datasets/pipelines/box_utils.py
@@ -0,0 +1,83 @@
+import numpy as np
+from shapely.geometry import LineString, Point, Polygon
+
+import mmocr.utils as utils
+
+
+def sort_vertex(points_x, points_y):
+ """Sort box vertices in clockwise order from left-top first.
+
+ Args:
+ points_x (list[float]): x of four vertices.
+ points_y (list[float]): y of four vertices.
+ Returns:
+ sorted_points_x (list[float]): x of sorted four vertices.
+ sorted_points_y (list[float]): y of sorted four vertices.
+ """
+ assert utils.is_type_list(points_x, float) or utils.is_type_list(
+ points_x, int)
+ assert utils.is_type_list(points_y, float) or utils.is_type_list(
+ points_y, int)
+ assert len(points_x) == 4
+ assert len(points_y) == 4
+
+ x = np.array(points_x)
+ y = np.array(points_y)
+ center_x = np.sum(x) * 0.25
+ center_y = np.sum(y) * 0.25
+
+ x_arr = np.array(x - center_x)
+ y_arr = np.array(y - center_y)
+
+ angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
+ sort_idx = np.argsort(angle)
+
+ sorted_points_x, sorted_points_y = [], []
+ for i in range(4):
+ sorted_points_x.append(points_x[sort_idx[i]])
+ sorted_points_y.append(points_y[sort_idx[i]])
+
+ return convert_canonical(sorted_points_x, sorted_points_y)
+
+
+def convert_canonical(points_x, points_y):
+ """Make left-top be first.
+
+ Args:
+ points_x (list[float]): x of four vertices.
+ points_y (list[float]): y of four vertices.
+ Returns:
+ sorted_points_x (list[float]): x of sorted four vertices.
+ sorted_points_y (list[float]): y of sorted four vertices.
+ """
+ assert utils.is_type_list(points_x, float) or utils.is_type_list(
+ points_x, int)
+ assert utils.is_type_list(points_y, float) or utils.is_type_list(
+ points_y, int)
+ assert len(points_x) == 4
+ assert len(points_y) == 4
+
+ points = [Point(points_x[i], points_y[i]) for i in range(4)]
+
+ polygon = Polygon([(p.x, p.y) for p in points])
+ min_x, min_y, _, _ = polygon.bounds
+ points_to_lefttop = [
+ LineString([points[i], Point(min_x, min_y)]) for i in range(4)
+ ]
+ distances = np.array([line.length for line in points_to_lefttop])
+ sort_dist_idx = np.argsort(distances)
+ lefttop_idx = sort_dist_idx[0]
+
+ if lefttop_idx == 0:
+ point_orders = [0, 1, 2, 3]
+ elif lefttop_idx == 1:
+ point_orders = [1, 2, 3, 0]
+ elif lefttop_idx == 2:
+ point_orders = [2, 3, 0, 1]
+ else:
+ point_orders = [3, 0, 1, 2]
+
+ sorted_points_x = [points_x[i] for i in point_orders]
+ sorted_points_y = [points_y[j] for j in point_orders]
+
+ return sorted_points_x, sorted_points_y
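+
+
+ # Quick illustrative check (not part of the original patch): a quadrangle
+ # listed from the bottom-right corner is reordered clockwise (in image
+ # coordinates) starting from the left-top vertex:
+ #
+ # xs, ys = sort_vertex([10, 0, 0, 10], [10, 10, 0, 0])
+ # # xs == [0, 10, 10, 0], ys == [0, 0, 10, 10]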
diff --git a/mmocr/datasets/pipelines/crop.py b/mmocr/datasets/pipelines/crop.py
new file mode 100644
index 00000000..eea0ffbb
--- /dev/null
+++ b/mmocr/datasets/pipelines/crop.py
@@ -0,0 +1,107 @@
+import cv2
+import numpy as np
+from shapely.geometry import LineString, Point
+
+import mmocr.utils as utils
+from .box_utils import sort_vertex
+
+
+def box_jitter(points_x, points_y, jitter_ratio_x=0.5, jitter_ratio_y=0.1):
+ """Jitter on the coordinates of bounding box.
+
+ Args:
+ points_x (list[float | int]): List of x for four vertices.
+ points_y (list[float | int]): List of y for four vertices.
+ jitter_ratio_x (float): Horizontal jitter ratio relative to the height.
+ jitter_ratio_y (float): Vertical jitter ratio relative to the height.
+ """
+ assert len(points_x) == 4
+ assert len(points_y) == 4
+ assert isinstance(jitter_ratio_x, float)
+ assert isinstance(jitter_ratio_y, float)
+ assert 0 <= jitter_ratio_x < 1
+ assert 0 <= jitter_ratio_y < 1
+
+ points = [Point(points_x[i], points_y[i]) for i in range(4)]
+ line_list = [
+ LineString([points[i], points[i + 1 if i < 3 else 0]])
+ for i in range(4)
+ ]
+
+ tmp_h = max(line_list[1].length, line_list[3].length)
+
+ for i in range(4):
+ jitter_pixel_x = (np.random.rand() - 0.5) * 2 * jitter_ratio_x * tmp_h
+ jitter_pixel_y = (np.random.rand() - 0.5) * 2 * jitter_ratio_y * tmp_h
+ points_x[i] += jitter_pixel_x
+ points_y[i] += jitter_pixel_y
+
+
+def warp_img(src_img,
+ box,
+ jitter_flag=False,
+ jitter_ratio_x=0.5,
+ jitter_ratio_y=0.1):
+ """Crop box area from image using opencv warpPerspective w/o box jitter.
+
+ Args:
+ src_img (np.array): Image before cropping.
+ box (list[float | int]): Coordinates of quadrangle.
+ """
+ assert utils.is_type_list(box, float) or utils.is_type_list(box, int)
+ assert len(box) == 8
+
+ h, w = src_img.shape[:2]
+ points_x = [min(max(x, 0), w) for x in box[0:8:2]]
+ points_y = [min(max(y, 0), h) for y in box[1:9:2]]
+
+ points_x, points_y = sort_vertex(points_x, points_y)
+
+ if jitter_flag:
+ box_jitter(
+ points_x,
+ points_y,
+ jitter_ratio_x=jitter_ratio_x,
+ jitter_ratio_y=jitter_ratio_y)
+
+ points = [Point(points_x[i], points_y[i]) for i in range(4)]
+ edges = [
+ LineString([points[i], points[i + 1 if i < 3 else 0]])
+ for i in range(4)
+ ]
+
+ pts1 = np.float32([[points[i].x, points[i].y] for i in range(4)])
+ box_width = max(edges[0].length, edges[2].length)
+ box_height = max(edges[1].length, edges[3].length)
+
+ pts2 = np.float32([[0, 0], [box_width, 0], [box_width, box_height],
+ [0, box_height]])
+ M = cv2.getPerspectiveTransform(pts1, pts2)
+ dst_img = cv2.warpPerspective(src_img, M,
+ (int(box_width), int(box_height)))
+
+ return dst_img
+
+
+def crop_img(src_img, box):
+ """Crop box area to rectangle.
+
+ Args:
+ src_img (np.array): Image before crop.
+ box (list[float | int]): Points of quadrangle.
+ """
+ assert utils.is_type_list(box, float) or utils.is_type_list(box, int)
+ assert len(box) == 8
+
+ h, w = src_img.shape[:2]
+ points_x = [min(max(x, 0), w) for x in box[0:8:2]]
+ points_y = [min(max(y, 0), h) for y in box[1:9:2]]
+
+ left = int(min(points_x))
+ top = int(min(points_y))
+ right = int(max(points_x))
+ bottom = int(max(points_y))
+
+ dst_img = src_img[top:bottom, left:right]
+
+ return dst_img
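+
+
+ # Hedged usage sketch (not part of the original patch): both helpers take
+ # an 8-value quadrangle [x1, y1, ..., x4, y4]; warp_img rectifies the quad
+ # with a perspective transform, crop_img only takes its bounding rectangle.
+ #
+ # box = [100., 100., 300., 110., 295., 160., 96., 150.]
+ # patch_warped = warp_img(src_img, box)  # rectified text patch
+ # patch_rect = crop_img(src_img, box)    # axis-aligned crop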
diff --git a/mmocr/datasets/pipelines/custom_format_bundle.py b/mmocr/datasets/pipelines/custom_format_bundle.py
new file mode 100644
index 00000000..276cde69
--- /dev/null
+++ b/mmocr/datasets/pipelines/custom_format_bundle.py
@@ -0,0 +1,65 @@
+import numpy as np
+from mmcv.parallel import DataContainer as DC
+
+from mmdet.datasets.builder import PIPELINES
+from mmdet.datasets.pipelines.formating import DefaultFormatBundle
+from mmocr.core.visualize import overlay_mask_img, show_feature
+
+
+@PIPELINES.register_module()
+class CustomFormatBundle(DefaultFormatBundle):
+ """Custom formatting bundle.
+
+ It formats common fields such as 'img' and 'proposals' as done in
+ DefaultFormatBundle, while other fields such as 'gt_kernels' and
+ 'gt_effective_region_mask' will be formatted to DC as follows:
+
+ - gt_kernels: to DataContainer (cpu_only=True)
+ - gt_effective_mask: to DataContainer (cpu_only=True)
+
+ Args:
+ keys (list[str]): Fields to be formatted to DC only.
+ call_super (bool): If True, format common fields
+ by DefaultFormatBundle, else format fields in keys above only.
+ visualize (dict): If flag=True, visualize gt mask for debugging.
+ """
+
+ def __init__(self,
+ keys=[],
+ call_super=True,
+ visualize=dict(flag=False, boundary_key=None)):
+
+ super().__init__()
+ self.visualize = visualize
+ self.keys = keys
+ self.call_super = call_super
+
+ def __call__(self, results):
+
+ if self.visualize['flag']:
+ img = results['img'].astype(np.uint8)
+ boundary_key = self.visualize['boundary_key']
+ if boundary_key is not None:
+ img = overlay_mask_img(img, results[boundary_key].masks[0])
+
+ features = [img]
+ names = ['img']
+ to_uint8 = [1]
+
+ for k in results['mask_fields']:
+ for i in range(len(results[k].masks)):
+ features.append(results[k].masks[i])
+ names.append(k + str(i))
+ to_uint8.append(0)
+ show_feature(features, names, to_uint8)
+
+ if self.call_super:
+ results = super().__call__(results)
+
+ for k in self.keys:
+ results[k] = DC(results[k], cpu_only=True)
+
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__
diff --git a/mmocr/datasets/pipelines/dbnet_transforms.py b/mmocr/datasets/pipelines/dbnet_transforms.py
new file mode 100644
index 00000000..9ec7936c
--- /dev/null
+++ b/mmocr/datasets/pipelines/dbnet_transforms.py
@@ -0,0 +1,272 @@
+import cv2
+import imgaug
+import imgaug.augmenters as iaa
+import numpy as np
+
+from mmdet.core.mask import PolygonMasks
+from mmdet.datasets.builder import PIPELINES
+
+
+class AugmenterBuilder:
+ """Build imgaug object according ImgAug argmentations."""
+
+ def __init__(self):
+ pass
+
+ def build(self, args, root=True):
+ if args is None:
+ return None
+ elif isinstance(args, (int, float, str)):
+ return args
+ elif isinstance(args, list):
+ if root:
+ sequence = [self.build(value, root=False) for value in args]
+ return iaa.Sequential(sequence)
+ arg_list = [self.to_tuple_if_list(a) for a in args[1:]]
+ return getattr(iaa, args[0])(*arg_list)
+ elif isinstance(args, dict):
+ if 'cls' in args:
+ cls = getattr(iaa, args['cls'])
+ return cls(
+ **{
+ k: self.to_tuple_if_list(v)
+ for k, v in args.items() if not k == 'cls'
+ })
+ else:
+ return {
+ key: self.build(value, root=False)
+ for key, value in args.items()
+ }
+ else:
+ raise RuntimeError('unknown augmenter arg: ' + str(args))
+
+ def to_tuple_if_list(self, obj):
+ if isinstance(obj, list):
+ return tuple(obj)
+ return obj
+
+
+@PIPELINES.register_module()
+class ImgAug:
+ """A wrapper to use imgaug https://github.com/aleju/imgaug.
+
+ Args:
+ args (list[list | dict]): The augmentation list. For details, please
+ refer to the imgaug documentation. Take args=[['Fliplr', 0.5],
+ dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]] as an
+ example: it horizontally flips images with probability 0.5,
+ followed by a random rotation with angle in range [-10, 10], and
+ a resize with an independent scale in range [0.5, 3.0] for each
+ side of the image.
+ """
+
+ def __init__(self, args=None):
+ self.augmenter_args = args
+ self.augmenter = AugmenterBuilder().build(self.augmenter_args)
+
+ def __call__(self, results):
+ # img is bgr
+ image = results['img']
+ aug = None
+ shape = image.shape
+
+ if self.augmenter:
+ aug = self.augmenter.to_deterministic()
+ results['img'] = aug.augment_image(image)
+ results['img_shape'] = results['img'].shape
+ results['flip'] = 'unknown' # it's unknown
+ results['flip_direction'] = 'unknown' # it's unknown
+ target_shape = results['img_shape']
+
+ self.may_augment_annotation(aug, shape, target_shape, results)
+
+ return results
+
+ def may_augment_annotation(self, aug, shape, target_shape, results):
+ if aug is None:
+ return results
+ for key in results['mask_fields']:
+ # augment polygon mask
+ masks = []
+ for mask in results[key]:
+ masks.append(
+ [self.may_augment_poly(aug, shape, target_shape, mask[0])])
+ if len(masks) > 0:
+ results[key] = PolygonMasks(masks, *target_shape[:2])
+
+ for key in results['bbox_fields']:
+ # augment bbox
+ bboxes = []
+ for bbox in results[key]:
+ bbox = self.may_augment_poly(aug, shape, target_shape, bbox)
+ bboxes.append(bbox)
+ results[key] = np.zeros(0)
+ if len(bboxes) > 0:
+ results[key] = np.stack(bboxes)
+
+ return results
+
+ def may_augment_poly(self, aug, img_shape, target_shape, poly):
+ # poly n x 2
+ poly = poly.reshape(-1, 2)
+ keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
+ keypoints = aug.augment_keypoints(
+ [imgaug.KeypointsOnImage(keypoints, shape=img_shape)])[0].keypoints
+ poly = [[p.x, p.y] for p in keypoints]
+ poly = np.array(poly).flatten()
+ return poly
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class EastRandomCrop:
+
+ def __init__(self,
+ target_size=(640, 640),
+ max_tries=10,
+ min_crop_side_ratio=0.1):
+ self.target_size = target_size
+ self.max_tries = max_tries
+ self.min_crop_side_ratio = min_crop_side_ratio
+
+ def __call__(self, results):
+ # sampling crop
+ # crop image, boxes, masks
+ img = results['img']
+ crop_x, crop_y, crop_w, crop_h = self.crop_area(
+ img, results['gt_masks'])
+ scale_w = self.target_size[0] / crop_w
+ scale_h = self.target_size[1] / crop_h
+ scale = min(scale_w, scale_h)
+ h = int(crop_h * scale)
+ w = int(crop_w * scale)
+ padded_img = np.zeros(
+ (self.target_size[1], self.target_size[0], img.shape[2]),
+ img.dtype)
+ padded_img[:h, :w] = cv2.resize(
+ img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
+
+ # for bboxes
+ for key in results['bbox_fields']:
+ lines = []
+ for box in results[key]:
+ box = box.reshape(2, 2)
+ poly = ((box - (crop_x, crop_y)) * scale)
+ if not self.is_poly_outside_rect(poly, 0, 0, w, h):
+ lines.append(poly.flatten())
+ results[key] = np.array(lines)
+ # for masks
+ for key in results['mask_fields']:
+ polys = []
+ polys_label = []
+ for poly in results[key]:
+ poly = np.array(poly).reshape(-1, 2)
+ poly = ((poly - (crop_x, crop_y)) * scale)
+ if not self.is_poly_outside_rect(poly, 0, 0, w, h):
+ polys.append([poly])
+ polys_label.append(0)
+ results[key] = PolygonMasks(polys, *self.target_size)
+ if key == 'gt_masks':
+ results['gt_labels'] = polys_label
+
+ results['img'] = padded_img
+ results['img_shape'] = padded_img.shape
+
+ return results
+
+ def is_poly_in_rect(self, poly, x, y, w, h):
+ poly = np.array(poly)
+ if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
+ return False
+ if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
+ return False
+ return True
+
+ def is_poly_outside_rect(self, poly, x, y, w, h):
+ poly = np.array(poly).reshape(-1, 2)
+ if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
+ return True
+ if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
+ return True
+ return False
+
+ def split_regions(self, axis):
+ regions = []
+ min_axis = 0
+ for i in range(1, axis.shape[0]):
+ if axis[i] != axis[i - 1] + 1:
+ region = axis[min_axis:i]
+ min_axis = i
+ regions.append(region)
+ return regions
+
+ def random_select(self, axis, max_size):
+ xx = np.random.choice(axis, size=2)
+ xmin = np.min(xx)
+ xmax = np.max(xx)
+ xmin = np.clip(xmin, 0, max_size - 1)
+ xmax = np.clip(xmax, 0, max_size - 1)
+ return xmin, xmax
+
+ def region_wise_random_select(self, regions, max_size):
+ selected_index = list(np.random.choice(len(regions), 2))
+ selected_values = []
+ for index in selected_index:
+ axis = regions[index]
+ xx = int(np.random.choice(axis, size=1))
+ selected_values.append(xx)
+ xmin = min(selected_values)
+ xmax = max(selected_values)
+ return xmin, xmax
+
+ def crop_area(self, img, polys):
+ h, w, _ = img.shape
+ h_array = np.zeros(h, dtype=np.int32)
+ w_array = np.zeros(w, dtype=np.int32)
+ for points in polys:
+ points = np.round(
+ points, decimals=0).astype(np.int32).reshape(-1, 2)
+ min_x = np.min(points[:, 0])
+ max_x = np.max(points[:, 0])
+ w_array[min_x:max_x] = 1
+ min_y = np.min(points[:, 1])
+ max_y = np.max(points[:, 1])
+ h_array[min_y:max_y] = 1
+ # ensure the cropped area not across a text
+ h_axis = np.where(h_array == 0)[0]
+ w_axis = np.where(w_array == 0)[0]
+
+ if len(h_axis) == 0 or len(w_axis) == 0:
+ return 0, 0, w, h
+
+ h_regions = self.split_regions(h_axis)
+ w_regions = self.split_regions(w_axis)
+
+ for i in range(self.max_tries):
+ if len(w_regions) > 1:
+ xmin, xmax = self.region_wise_random_select(w_regions, w)
+ else:
+ xmin, xmax = self.random_select(w_axis, w)
+ if len(h_regions) > 1:
+ ymin, ymax = self.region_wise_random_select(h_regions, h)
+ else:
+ ymin, ymax = self.random_select(h_axis, h)
+
+ if (xmax - xmin < self.min_crop_side_ratio * w
+ or ymax - ymin < self.min_crop_side_ratio * h):
+ # area too small
+ continue
+ num_poly_in_rect = 0
+ for poly in polys:
+ if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
+ ymax - ymin):
+ num_poly_in_rect += 1
+ break
+
+ if num_poly_in_rect > 0:
+ return xmin, ymin, xmax - xmin, ymax - ymin
+
+ return 0, 0, w, h
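+
+
+ # Hedged pipeline sketch (illustrative only, not part of the original
+ # patch): ImgAug and EastRandomCrop would typically be chained in a text
+ # detection training pipeline along the lines of:
+ #
+ # train_pipeline = [
+ #     dict(type='ImgAug', args=[['Fliplr', 0.5],
+ #                               dict(cls='Affine', rotate=[-10, 10]),
+ #                               ['Resize', [0.5, 3.0]]]),
+ #     dict(type='EastRandomCrop', target_size=(640, 640)),
+ # ]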
diff --git a/mmocr/datasets/pipelines/kie_transforms.py b/mmocr/datasets/pipelines/kie_transforms.py
new file mode 100644
index 00000000..0787f07f
--- /dev/null
+++ b/mmocr/datasets/pipelines/kie_transforms.py
@@ -0,0 +1,55 @@
+import numpy as np
+from mmcv.parallel import DataContainer as DC
+
+from mmdet.datasets.builder import PIPELINES
+from mmdet.datasets.pipelines.formating import DefaultFormatBundle, to_tensor
+
+
+@PIPELINES.register_module()
+class KIEFormatBundle(DefaultFormatBundle):
+ """Key information extraction formatting bundle.
+
+ Based on the DefaultFormatBundle, it simplifies the pipeline of formatting
+ common fields, including "img", "proposals", "gt_bboxes", "gt_labels",
+ "gt_masks", "gt_semantic_seg", "relations" and "texts".
+ These fields are formatted as follows.
+
+ - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)
+ - proposals: (1) to tensor, (2) to DataContainer
+ - gt_bboxes: (1) to tensor, (2) to DataContainer
+ - gt_bboxes_ignore: (1) to tensor, (2) to DataContainer
+ - gt_labels: (1) to tensor, (2) to DataContainer
+ - gt_masks: (1) to tensor, (2) to DataContainer (cpu_only=True)
+ - gt_semantic_seg: (1) unsqueeze dim-0 (2) to tensor,
+ (3) to DataContainer (stack=True)
+ - relations: (1) scale, (2) to tensor, (3) to DataContainer
+ - texts: (1) to tensor, (2) to DataContainer
+ """
+
+ def __call__(self, results):
+ """Call function to transform and format common fields in results.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data that is formatted with
+ default bundle.
+ """
+ super().__call__(results)
+ if 'ann_info' in results:
+ for key in ['relations', 'texts']:
+ value = results['ann_info'][key]
+ if key == 'relations' and 'scale_factor' in results:
+ scale_factor = results['scale_factor']
+ if isinstance(scale_factor, float):
+ sx = sy = scale_factor
+ else:
+ sx, sy = results['scale_factor'][:2]
+ r = sx / sy
+ value = value * np.array([sx, sy, r, 1, r])[None, None]
+ results[key] = DC(to_tensor(value))
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__
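+
+
+ # Hedged usage sketch (not part of the original patch): KIEFormatBundle is
+ # meant to take the place of DefaultFormatBundle in a KIE pipeline so that
+ # 'relations' and 'texts' from ann_info are collected as tensors, e.g.:
+ #
+ # pipeline = [
+ #     ...,  # loading / resizing transforms
+ #     dict(type='KIEFormatBundle'),
+ #     dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes']),
+ # ]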
diff --git a/mmocr/datasets/pipelines/loading.py b/mmocr/datasets/pipelines/loading.py
new file mode 100644
index 00000000..5c3cda6e
--- /dev/null
+++ b/mmocr/datasets/pipelines/loading.py
@@ -0,0 +1,68 @@
+import numpy as np
+
+from mmdet.core import BitmapMasks, PolygonMasks
+from mmdet.datasets.builder import PIPELINES
+from mmdet.datasets.pipelines.loading import LoadAnnotations
+
+
+@PIPELINES.register_module()
+class LoadTextAnnotations(LoadAnnotations):
+
+ def __init__(self,
+ with_bbox=True,
+ with_label=True,
+ with_mask=False,
+ with_seg=False,
+ poly2mask=True):
+ super().__init__(
+ with_bbox=with_bbox,
+ with_label=with_label,
+ with_mask=with_mask,
+ with_seg=with_seg,
+ poly2mask=poly2mask)
+
+ def process_polygons(self, polygons):
+ """Convert polygons to list of ndarray and filter invalid polygons.
+
+ Args:
+ polygons (list[list]): Polygons of one instance.
+
+ Returns:
+ list[numpy.ndarray]: Processed polygons.
+ """
+
+ polygons = [np.array(p).astype(np.float32) for p in polygons]
+ valid_polygons = []
+ for polygon in polygons:
+ if len(polygon) % 2 == 0 and len(polygon) >= 6:
+ valid_polygons.append(polygon)
+ return valid_polygons
+
+ def _load_masks(self, results):
+ ann_info = results['ann_info']
+ h, w = results['img_info']['height'], results['img_info']['width']
+ gt_masks = ann_info['masks']
+ if self.poly2mask:
+ gt_masks = BitmapMasks(
+ [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
+ else:
+ gt_masks = PolygonMasks(
+ [self.process_polygons(polygons) for polygons in gt_masks], h,
+ w)
+ gt_masks_ignore = ann_info.get('masks_ignore', None)
+ if gt_masks_ignore is not None:
+ if self.poly2mask:
+ gt_masks_ignore = BitmapMasks(
+ [self._poly2mask(mask, h, w) for mask in gt_masks_ignore],
+ h, w)
+ else:
+ gt_masks_ignore = PolygonMasks([
+ self.process_polygons(polygons)
+ for polygons in gt_masks_ignore
+ ], h, w)
+ results['gt_masks_ignore'] = gt_masks_ignore
+ results['mask_fields'].append('gt_masks_ignore')
+
+ results['gt_masks'] = gt_masks
+ results['mask_fields'].append('gt_masks')
+ return results
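+
+
+ # Hedged config sketch (not part of the original patch): LoadTextAnnotations
+ # mirrors mmdet's LoadAnnotations but additionally fills 'gt_masks_ignore';
+ # keeping polygon masks (poly2mask=False) is a common choice for text
+ # detection targets:
+ #
+ # dict(type='LoadTextAnnotations',
+ #      with_bbox=True,
+ #      with_mask=True,
+ #      poly2mask=False)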
diff --git a/mmocr/datasets/pipelines/ocr_seg_targets.py b/mmocr/datasets/pipelines/ocr_seg_targets.py
new file mode 100644
index 00000000..cbd9b869
--- /dev/null
+++ b/mmocr/datasets/pipelines/ocr_seg_targets.py
@@ -0,0 +1,201 @@
+import cv2
+import numpy as np
+
+import mmocr.utils.check_argument as check_argument
+from mmdet.core import BitmapMasks
+from mmdet.datasets.builder import PIPELINES
+from mmocr.models.builder import build_convertor
+
+
+@PIPELINES.register_module()
+class OCRSegTargets:
+ """Generate gt shrinked kernels for segmentation based OCR framework.
+
+ Args:
+ label_convertor (dict): Dictionary to construct label_convertor
+ to convert char to index.
+ attn_shrink_ratio (float): The area shrink ratio
+ between attention kernels and gt text masks.
+ seg_shrink_ratio (float): The area shrink ratio
+ between segmentation kernels and gt text masks.
+ box_type (str): Character box type, should be either
+ 'char_rects' or 'char_quads', with 'char_rects'
+ for rectangle with ``xyxy`` style and 'char_quads'
+ for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
+ """
+
+ def __init__(self,
+ label_convertor=None,
+ attn_shrink_ratio=0.5,
+ seg_shrink_ratio=0.25,
+ box_type='char_rects',
+ pad_val=255):
+
+ assert isinstance(attn_shrink_ratio, float)
+ assert isinstance(seg_shrink_ratio, float)
+ assert 0. < attn_shrink_ratio < 1.0
+ assert 0. < seg_shrink_ratio < 1.0
+ assert label_convertor is not None
+ assert box_type in ('char_rects', 'char_quads')
+
+ self.attn_shrink_ratio = attn_shrink_ratio
+ self.seg_shrink_ratio = seg_shrink_ratio
+ self.label_convertor = build_convertor(label_convertor)
+ self.box_type = box_type
+ self.pad_val = pad_val
+
+ def shrink_char_quad(self, char_quad, shrink_ratio):
+ """Shrink char box in style of quadrangle.
+
+ Args:
+ char_quad (list[float]): Char box with format
+ [x1, y1, x2, y2, x3, y3, x4, y4].
+ shrink_ratio (float): The area shrink ratio
+ between gt kernels and gt text masks.
+ """
+ points = [[char_quad[0], char_quad[1]], [char_quad[2], char_quad[3]],
+ [char_quad[4], char_quad[5]], [char_quad[6], char_quad[7]]]
+ shrink_points = []
+ for p_idx, point in enumerate(points):
+ p1 = points[(p_idx + 3) % 4]
+ p2 = points[(p_idx + 1) % 4]
+
+ dist1 = self.l2_dist_two_points(p1, point)
+ dist2 = self.l2_dist_two_points(p2, point)
+ min_dist = min(dist1, dist2)
+
+ v1 = [p1[0] - point[0], p1[1] - point[1]]
+ v2 = [p2[0] - point[0], p2[1] - point[1]]
+
+ temp_dist1 = (shrink_ratio * min_dist /
+ dist1) if min_dist != 0 else 0.
+ temp_dist2 = (shrink_ratio * min_dist /
+ dist2) if min_dist != 0 else 0.
+
+ v1 = [temp * temp_dist1 for temp in v1]
+ v2 = [temp * temp_dist2 for temp in v2]
+
+ shrink_point = [
+ round(point[0] + v1[0] + v2[0]),
+ round(point[1] + v1[1] + v2[1])
+ ]
+ shrink_points.append(shrink_point)
+
+ poly = np.array(shrink_points)
+
+ return poly
+
+ def shrink_char_rect(self, char_rect, shrink_ratio):
+ """Shrink char box in style of rectangle.
+
+ Args:
+ char_rect (list[float]): Char box with format
+ [x_min, y_min, x_max, y_max].
+ shrink_ratio (float): The area shrink ratio
+ between gt kernels and gt text masks.
+ """
+ x_min, y_min, x_max, y_max = char_rect
+ w = x_max - x_min
+ h = y_max - y_min
+ x_min_s = round((x_min + x_max - w * shrink_ratio) / 2)
+ y_min_s = round((y_min + y_max - h * shrink_ratio) / 2)
+ x_max_s = round((x_min + x_max + w * shrink_ratio) / 2)
+ y_max_s = round((y_min + y_max + h * shrink_ratio) / 2)
+ poly = np.array([[x_min_s, y_min_s], [x_max_s, y_min_s],
+ [x_max_s, y_max_s], [x_min_s, y_max_s]])
+
+ return poly
+
+ def generate_kernels(self,
+ resize_shape,
+ pad_shape,
+ char_boxes,
+ char_inds,
+ shrink_ratio=0.5,
+ binary=True):
+ """Generate char instance kernels for one shrink ratio.
+
+ Args:
+ resize_shape (tuple(int, int)): Image size (height, width)
+ after resizing.
+ pad_shape (tuple(int, int)): Image size (height, width)
+ after padding.
+ char_boxes (list[list[float]]): The list of char polygons.
+ char_inds (list[int]): List of char indexes.
+ shrink_ratio (float): The shrink ratio of kernel.
+ binary (bool): If True, return binary ndarray
+ containing 0 & 1 only.
+ Returns:
+ char_kernel (ndarray): The text kernel mask of (height, width).
+ """
+ assert isinstance(resize_shape, tuple)
+ assert isinstance(pad_shape, tuple)
+ assert check_argument.is_2dlist(char_boxes)
+ assert check_argument.is_type_list(char_inds, int)
+ assert isinstance(shrink_ratio, float)
+ assert isinstance(binary, bool)
+
+ char_kernel = np.zeros(pad_shape, dtype=np.int32)
+ char_kernel[:resize_shape[0], resize_shape[1]:] = self.pad_val
+
+ for i, char_box in enumerate(char_boxes):
+ if self.box_type == 'char_rects':
+ poly = self.shrink_char_rect(char_box, shrink_ratio)
+ elif self.box_type == 'char_quads':
+ poly = self.shrink_char_quad(char_box, shrink_ratio)
+
+ fill_value = 1 if binary else char_inds[i]
+ cv2.fillConvexPoly(char_kernel, poly.astype(np.int32),
+ (fill_value))
+
+ return char_kernel
+
+ def l2_dist_two_points(self, p1, p2):
+ return ((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5
+
+ def __call__(self, results):
+ img_shape = results['img_shape']
+ resize_shape = results['resize_shape']
+
+ h_scale = 1.0 * resize_shape[0] / img_shape[0]
+ w_scale = 1.0 * resize_shape[1] / img_shape[1]
+
+ char_boxes, char_inds = [], []
+ char_num = len(results['ann_info'][self.box_type])
+ for i in range(char_num):
+ char_box = results['ann_info'][self.box_type][i]
+ num_points = 2 if self.box_type == 'char_rects' else 4
+ for j in range(num_points):
+ char_box[j * 2] = round(char_box[j * 2] * w_scale)
+ char_box[j * 2 + 1] = round(char_box[j * 2 + 1] * h_scale)
+ char_boxes.append(char_box)
+ char = results['ann_info']['chars'][i]
+ char_ind = self.label_convertor.str2idx([char])[0][0]
+ char_inds.append(char_ind)
+
+ resize_shape = tuple(results['resize_shape'][:2])
+ pad_shape = tuple(results['pad_shape'][:2])
+ binary_target = self.generate_kernels(
+ resize_shape,
+ pad_shape,
+ char_boxes,
+ char_inds,
+ shrink_ratio=self.attn_shrink_ratio,
+ binary=True)
+
+ seg_target = self.generate_kernels(
+ resize_shape,
+ pad_shape,
+ char_boxes,
+ char_inds,
+ shrink_ratio=self.seg_shrink_ratio,
+ binary=False)
+
+ mask = np.ones(pad_shape, dtype=np.int32)
+ mask[:resize_shape[0], resize_shape[1]:] = 0
+
+ results['gt_kernels'] = BitmapMasks([binary_target, seg_target, mask],
+ pad_shape[0], pad_shape[1])
+ results['mask_fields'] = ['gt_kernels']
+
+ return results
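+
+
+ # Hedged config sketch (illustrative only; the convertor type name below is
+ # an assumption, not defined in this patch):
+ #
+ # dict(type='OCRSegTargets',
+ #      label_convertor=dict(type='SegConvertor', lower=True),
+ #      attn_shrink_ratio=0.5,
+ #      seg_shrink_ratio=0.25,
+ #      box_type='char_quads')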
diff --git a/mmocr/datasets/pipelines/ocr_transforms.py b/mmocr/datasets/pipelines/ocr_transforms.py
new file mode 100644
index 00000000..5b7370e5
--- /dev/null
+++ b/mmocr/datasets/pipelines/ocr_transforms.py
@@ -0,0 +1,446 @@
+import math
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from mmcv.runner.dist_utils import get_dist_info
+from PIL import Image
+from shapely.geometry import Polygon
+from shapely.geometry import box as shapely_box
+
+import mmocr.utils as utils
+from mmdet.datasets.builder import PIPELINES
+from mmocr.datasets.pipelines.crop import warp_img
+
+
+@PIPELINES.register_module()
+class ResizeOCR:
+ """Image resizing and padding for OCR.
+
+ Args:
+ height (int | tuple(int)): Image height after resizing.
+ min_width (None | int | tuple(int)): Image minimum width
+ after resizing.
+ max_width (None | int | tuple(int)): Image maximum width
+ after resizing.
+ keep_aspect_ratio (bool): If True, keep the image aspect ratio
+ during resizing; otherwise resize to the fixed size
+ (height, max_width).
+ img_pad_value (int): Scalar to fill padding area.
+ width_downsample_ratio (float): Downsample ratio in horizontal
+ direction from input image to output feature.
+ """
+
+ def __init__(self,
+ height,
+ min_width=None,
+ max_width=None,
+ keep_aspect_ratio=True,
+ img_pad_value=0,
+ width_downsample_ratio=1.0 / 16):
+ assert isinstance(height, (int, tuple))
+ assert utils.is_none_or_type(min_width, (int, tuple))
+ assert utils.is_none_or_type(max_width, (int, tuple))
+ if not keep_aspect_ratio:
+ assert max_width is not None, ('"max_width" must be assigned '
+ 'if "keep_aspect_ratio" is False')
+ assert isinstance(img_pad_value, int)
+ if isinstance(height, tuple):
+ assert isinstance(min_width, tuple)
+ assert isinstance(max_width, tuple)
+ assert len(height) == len(min_width) == len(max_width)
+
+ self.height = height
+ self.min_width = min_width
+ self.max_width = max_width
+ self.keep_aspect_ratio = keep_aspect_ratio
+ self.img_pad_value = img_pad_value
+ self.width_downsample_ratio = width_downsample_ratio
+
+ def __call__(self, results):
+ rank, _ = get_dist_info()
+ if isinstance(self.height, int):
+ dst_height = self.height
+ dst_min_width = self.min_width
+ dst_max_width = self.max_width
+ else:
+ """Multi-scale resize used in distributed training.
+
+ Choose one (height, width) pair for one rank id.
+ """
+ idx = rank % len(self.height)
+ dst_height = self.height[idx]
+ dst_min_width = self.min_width[idx]
+ dst_max_width = self.max_width[idx]
+
+ img_shape = results['img_shape']
+ ori_height, ori_width = img_shape[:2]
+ valid_ratio = 1.0
+ resize_shape = list(img_shape)
+ pad_shape = list(img_shape)
+
+ if self.keep_aspect_ratio:
+ new_width = math.ceil(float(dst_height) / ori_height * ori_width)
+ width_divisor = int(1 / self.width_downsample_ratio)
+ # make sure new_width is an integral multiple of width_divisor.
+ if new_width % width_divisor != 0:
+ new_width = round(new_width / width_divisor) * width_divisor
+ if dst_min_width is not None:
+ new_width = max(dst_min_width, new_width)
+ if dst_max_width is not None:
+ valid_ratio = min(1.0, 1.0 * new_width / dst_max_width)
+ resize_width = min(dst_max_width, new_width)
+ img_resize = cv2.resize(results['img'],
+ (resize_width, dst_height))
+ resize_shape = img_resize.shape
+ pad_shape = img_resize.shape
+ if new_width < dst_max_width:
+ img_resize = mmcv.impad(
+ img_resize,
+ shape=(dst_height, dst_max_width),
+ pad_val=self.img_pad_value)
+ pad_shape = img_resize.shape
+ else:
+ img_resize = cv2.resize(results['img'],
+ (new_width, dst_height))
+ resize_shape = img_resize.shape
+ pad_shape = img_resize.shape
+ else:
+ img_resize = cv2.resize(results['img'],
+ (dst_max_width, dst_height))
+ resize_shape = img_resize.shape
+ pad_shape = img_resize.shape
+
+ results['img'] = img_resize
+ results['resize_shape'] = resize_shape
+ results['pad_shape'] = pad_shape
+ results['valid_ratio'] = valid_ratio
+
+ return results
+
+
+@PIPELINES.register_module()
+class ToTensorOCR:
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor."""
+
+ def __init__(self):
+ pass
+
+ def __call__(self, results):
+ results['img'] = TF.to_tensor(results['img'].copy())
+
+ return results
+
+
+@PIPELINES.register_module()
+class NormalizeOCR:
+ """Normalize a tensor image with mean and standard deviation."""
+
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, results):
+ results['img'] = TF.normalize(results['img'], self.mean, self.std)
+
+ return results
+
+
+@PIPELINES.register_module()
+class OnlineCropOCR:
+ """Crop text areas from whole image with bounding box jitter. If no bbox is
+ given, return directly.
+
+ Args:
+ box_keys (list[str]): Keys in results which correspond to RoI bbox.
+ jitter_prob (float): The probability of box jitter.
+ max_jitter_ratio_x (float): Maximum horizontal jitter ratio
+ relative to height.
+ max_jitter_ratio_y (float): Maximum vertical jitter ratio
+ relative to height.
+ """
+
+ def __init__(self,
+ box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'],
+ jitter_prob=0.5,
+ max_jitter_ratio_x=0.05,
+ max_jitter_ratio_y=0.02):
+ assert utils.is_type_list(box_keys, str)
+ assert 0 <= jitter_prob <= 1
+ assert 0 <= max_jitter_ratio_x <= 1
+ assert 0 <= max_jitter_ratio_y <= 1
+
+ self.box_keys = box_keys
+ self.jitter_prob = jitter_prob
+ self.max_jitter_ratio_x = max_jitter_ratio_x
+ self.max_jitter_ratio_y = max_jitter_ratio_y
+
+ def __call__(self, results):
+
+ if 'img_info' not in results:
+ return results
+
+ crop_flag = True
+ box = []
+ for key in self.box_keys:
+ if key not in results['img_info']:
+ crop_flag = False
+ break
+
+ box.append(float(results['img_info'][key]))
+
+ if not crop_flag:
+ return results
+
+ jitter_flag = np.random.random() < self.jitter_prob
+
+ kwargs = dict(
+ jitter_flag=jitter_flag,
+ jitter_ratio_x=self.max_jitter_ratio_x,
+ jitter_ratio_y=self.max_jitter_ratio_y)
+ crop_img = warp_img(results['img'], box, **kwargs)
+
+ results['img'] = crop_img
+ results['img_shape'] = crop_img.shape
+
+ return results
+
+
+@PIPELINES.register_module()
+class FancyPCA:
+ """Implementation of PCA based image augmentation, proposed in the paper
+ ``Imagenet Classification With Deep Convolutional Neural Networks``.
+
+ It alters the intensities of RGB values along the principal components of
+ ImageNet dataset.
+ """
+
+ def __init__(self, eig_vec=None, eig_val=None):
+ if eig_vec is None:
+ eig_vec = torch.Tensor([
+ [-0.5675, +0.7192, +0.4009],
+ [-0.5808, -0.0045, -0.8140],
+ [-0.5836, -0.6948, +0.4203],
+ ]).t()
+ if eig_val is None:
+ eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]])
+ self.eig_val = eig_val # 1*3
+ self.eig_vec = eig_vec # 3*3
+
+ def pca(self, tensor):
+ assert tensor.size(0) == 3
+ alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1
+ reconst = torch.mm(self.eig_val * alpha, self.eig_vec)
+ tensor = tensor + reconst.view(3, 1, 1)
+
+ return tensor
+
+ def __call__(self, results):
+ img = results['img']
+ tensor = self.pca(img)
+ results['img'] = tensor
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomPaddingOCR:
+ """Pad the given image on all sides, as well as modify the coordinates of
+ character bounding box in image.
+
+ Args:
+ max_ratio (list[float]): Maximum padding ratios
+ [left, top, right, bottom] relative to image width or height.
+ box_type (None|str): Character box type. If not None,
+ should be either 'char_rects' or 'char_quads', with
+ 'char_rects' for rectangle with ``xyxy`` style and
+ 'char_quads' for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
+ """
+
+ def __init__(self, max_ratio=None, box_type=None):
+ if max_ratio is None:
+ max_ratio = [0.1, 0.2, 0.1, 0.2]
+ else:
+ assert utils.is_type_list(max_ratio, float)
+ assert len(max_ratio) == 4
+ assert box_type is None or box_type in ('char_rects', 'char_quads')
+
+ self.max_ratio = max_ratio
+ self.box_type = box_type
+
+ def __call__(self, results):
+
+ img_shape = results['img_shape']
+ ori_height, ori_width = img_shape[:2]
+
+ random_padding_left = round(
+ np.random.uniform(0, self.max_ratio[0]) * ori_width)
+ random_padding_top = round(
+ np.random.uniform(0, self.max_ratio[1]) * ori_height)
+ random_padding_right = round(
+ np.random.uniform(0, self.max_ratio[2]) * ori_width)
+ random_padding_bottom = round(
+ np.random.uniform(0, self.max_ratio[3]) * ori_height)
+
+ img = np.copy(results['img'])
+ img = cv2.copyMakeBorder(img, random_padding_top,
+ random_padding_bottom, random_padding_left,
+ random_padding_right, cv2.BORDER_REPLICATE)
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ if self.box_type is not None:
+ num_points = 2 if self.box_type == 'char_rects' else 4
+ char_num = len(results['ann_info'][self.box_type])
+ for i in range(char_num):
+ for j in range(num_points):
+ results['ann_info'][self.box_type][i][
+ j * 2] += random_padding_left
+ results['ann_info'][self.box_type][i][
+ j * 2 + 1] += random_padding_top
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomRotateImageBox:
+ """Rotate augmentation for segmentation based text recognition.
+
+ Args:
+ min_angle (int): Minimum rotation angle for image and box.
+ max_angle (int): Maximum rotation angle for image and box.
+ box_type (str): Character box type, should be either
+ 'char_rects' or 'char_quads', with 'char_rects'
+ for rectangle with ``xyxy`` style and 'char_quads'
+ for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
+ """
+
+ def __init__(self, min_angle=-10, max_angle=10, box_type='char_quads'):
+ assert box_type in ('char_rects', 'char_quads')
+
+ self.min_angle = min_angle
+ self.max_angle = max_angle
+ self.box_type = box_type
+
+ def __call__(self, results):
+ in_img = results['img']
+ in_chars = results['ann_info']['chars']
+ in_boxes = results['ann_info'][self.box_type]
+
+ img_width, img_height = in_img.size
+ rotate_center = [img_width / 2., img_height / 2.]
+
+ tan_temp_max_angle = rotate_center[1] / rotate_center[0]
+ temp_max_angle = np.arctan(tan_temp_max_angle) * 180. / np.pi
+
+ random_angle = np.random.uniform(
+ max(self.min_angle, -temp_max_angle),
+ min(self.max_angle, temp_max_angle))
+ random_angle_radian = random_angle * np.pi / 180.
+
+ img_box = shapely_box(0, 0, img_width, img_height)
+
+ out_img = TF.rotate(
+ in_img,
+ random_angle,
+ resample=False,
+ expand=False,
+ center=rotate_center)
+
+ out_boxes, out_chars = self.rotate_bbox(in_boxes, in_chars,
+ random_angle_radian,
+ rotate_center, img_box)
+
+ results['img'] = out_img
+ results['ann_info']['chars'] = out_chars
+ results['ann_info'][self.box_type] = out_boxes
+
+ return results
+
+ @staticmethod
+ def rotate_bbox(boxes, chars, angle, center, img_box):
+ out_boxes = []
+ out_chars = []
+ for idx, bbox in enumerate(boxes):
+ temp_bbox = []
+ for i in range(len(bbox) // 2):
+ point = [bbox[2 * i], bbox[2 * i + 1]]
+ temp_bbox.append(
+ RandomRotateImageBox.rotate_point(point, angle, center))
+ poly_temp_bbox = Polygon(temp_bbox).buffer(0)
+ if poly_temp_bbox.is_valid:
+ if img_box.intersects(poly_temp_bbox) and (
+ not img_box.touches(poly_temp_bbox)):
+ temp_bbox_area = poly_temp_bbox.area
+
+ intersect_area = img_box.intersection(poly_temp_bbox).area
+ intersect_ratio = intersect_area / temp_bbox_area
+
+ if intersect_ratio >= 0.7:
+ out_box = []
+ for p in temp_bbox:
+ out_box.extend(p)
+ out_boxes.append(out_box)
+ out_chars.append(chars[idx])
+
+ return out_boxes, out_chars
+
+ @staticmethod
+ def rotate_point(point, angle, center):
+ cos_theta = math.cos(-angle)
+ sin_theta = math.sin(-angle)
+ c_x = center[0]
+ c_y = center[1]
+ new_x = (point[0] - c_x) * cos_theta - (point[1] -
+ c_y) * sin_theta + c_x
+ new_y = (point[0] - c_x) * sin_theta + (point[1] -
+ c_y) * cos_theta + c_y
+
+ return [new_x, new_y]
+
+
+@PIPELINES.register_module()
+class OpencvToPil:
+ """Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb)."""
+
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, results):
+ img = results['img'][..., ::-1]
+ img = Image.fromarray(img)
+ results['img'] = img
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class PilToOpencv:
+ """Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr)."""
+
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, results):
+ img = np.asarray(results['img'])
+ img = img[..., ::-1]
+ results['img'] = img
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
diff --git a/mmocr/datasets/pipelines/test_time_aug.py b/mmocr/datasets/pipelines/test_time_aug.py
new file mode 100644
index 00000000..5c8c1a60
--- /dev/null
+++ b/mmocr/datasets/pipelines/test_time_aug.py
@@ -0,0 +1,108 @@
+import mmcv
+import numpy as np
+
+from mmdet.datasets.builder import PIPELINES
+from mmdet.datasets.pipelines.compose import Compose
+
+
+@PIPELINES.register_module()
+class MultiRotateAugOCR:
+ """Test-time augmentation with multiple rotations in the case that
+ img_height > img_width.
+
+ An example configuration is as follows:
+
+ .. code-block::
+
+ rotate_degrees=[0, 90, 270],
+ transforms=[
+ dict(
+ type='ResizeOCR',
+ height=32,
+ min_width=32,
+ max_width=160,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+ ]),
+ ]
+
+ After MultiRotateAugOCR with above configuration, the results are wrapped
+ into lists of the same length as follows:
+
+ .. code-block::
+
+ dict(
+ img=[...],
+ img_shape=[...]
+ ...
+ )
+
+ Args:
+ transforms (list[dict]): Transformation applied for each augmentation.
+ rotate_degrees (list[int] | None): Degrees of anti-clockwise rotation.
+ force_rotate (bool): If True, rotate the image by 'rotate_degrees'
+ regardless of its aspect ratio.
+ """
+
+ def __init__(self, transforms, rotate_degrees=None, force_rotate=False):
+ self.transforms = Compose(transforms)
+ self.force_rotate = force_rotate
+ if rotate_degrees is not None:
+ self.rotate_degrees = rotate_degrees if isinstance(
+ rotate_degrees, list) else [rotate_degrees]
+ assert mmcv.is_list_of(self.rotate_degrees, int)
+ for degree in self.rotate_degrees:
+ assert 0 <= degree < 360
+ assert degree % 90 == 0
+ if 0 not in self.rotate_degrees:
+ self.rotate_degrees.append(0)
+ else:
+ self.rotate_degrees = [0]
+
+ def __call__(self, results):
+ """Call function to apply test time augment transformation to results.
+
+ Args:
+ results (dict): Result dict contains the data to be transformed.
+
+ Returns:
+ dict[str: list]: The augmented data, where each value is wrapped
+ into a list.
+ """
+ img_shape = results['img_shape']
+ ori_height, ori_width = img_shape[:2]
+ if not self.force_rotate and ori_height <= ori_width:
+ rotate_degrees = [0]
+ else:
+ rotate_degrees = self.rotate_degrees
+ aug_data = []
+ for degree in set(rotate_degrees):
+ _results = results.copy()
+ if degree == 0:
+ pass
+ elif degree == 90:
+ _results['img'] = np.rot90(_results['img'], 1)
+ elif degree == 180:
+ _results['img'] = np.rot90(_results['img'], 2)
+ elif degree == 270:
+ _results['img'] = np.rot90(_results['img'], 3)
+ data = self.transforms(_results)
+ aug_data.append(data)
+ # list of dict to dict of list
+ aug_data_dict = {key: [] for key in aug_data[0]}
+ for data in aug_data:
+ for key, val in data.items():
+ aug_data_dict[key].append(val)
+ return aug_data_dict
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(transforms={self.transforms}, '
+ repr_str += f'rotate_degrees={self.rotate_degrees})'
+ return repr_str
diff --git a/mmocr/datasets/pipelines/textdet_targets/__init__.py b/mmocr/datasets/pipelines/textdet_targets/__init__.py
new file mode 100644
index 00000000..1565e924
--- /dev/null
+++ b/mmocr/datasets/pipelines/textdet_targets/__init__.py
@@ -0,0 +1,10 @@
+from .base_textdet_targets import BaseTextDetTargets
+from .dbnet_targets import DBNetTargets
+from .panet_targets import PANetTargets
+from .psenet_targets import PSENetTargets
+from .textsnake_targets import TextSnakeTargets
+
+__all__ = [
+ 'BaseTextDetTargets', 'PANetTargets', 'PSENetTargets', 'DBNetTargets',
+ 'TextSnakeTargets'
+]
diff --git a/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py b/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py
new file mode 100644
index 00000000..183743f8
--- /dev/null
+++ b/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py
@@ -0,0 +1,168 @@
+import sys
+
+import cv2
+import numpy as np
+import Polygon as plg
+import pyclipper
+from mmcv.utils import print_log
+
+import mmocr.utils.check_argument as check_argument
+
+
+class BaseTextDetTargets:
+ """Generate text detector ground truths."""
+
+ def __init__(self):
+ pass
+
+ def point2line(self, xs, ys, point_1, point_2):
+ """Compute the distance from point to a line. This is adapted from
+ https://github.com/MhLiao/DB.
+
+ Args:
+ xs (ndarray): The x coordinates of size hxw.
+ ys (ndarray): The y coordinates of size hxw.
+ point_1 (ndarray): The first point with shape 1x2.
+ point_2 (ndarray): The second point with shape 1x2.
+
+ Returns:
+ result (ndarray): The distance matrix of size hxw.
+ """
+ # consider a triangle with edges a, b, c, where edge c connects
+ # point_1 and point_2
+ # a^2
+ a_square = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
+ # b^2
+ b_square = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
+ # c^2
+ c_square = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] -
+ point_2[1])
+ # -cosC=(c^2-a^2-b^2)/2(ab)
+ neg_cos_c = (
+ (c_square - a_square - b_square) /
+ (np.finfo(np.float32).eps + 2 * np.sqrt(a_square * b_square)))
+ # sinC^2=1-cosC^2
+ square_sin = 1 - np.square(neg_cos_c)
+ square_sin = np.nan_to_num(square_sin)
+ # distance=a*b*sinC/c=a*h/c=2*area/c
+ result = np.sqrt(a_square * b_square * square_sin /
+ (np.finfo(np.float32).eps + c_square))
+ # set result to minimum edge if C < pi/2
+ result[neg_cos_c < 0] = np.sqrt(np.fmin(a_square, b_square))[neg_cos_c < 0]
+ return result
+
+ if len(padded_polygon) > 0:
+ padded_polygon = np.array(padded_polygon[0])
+ else:
+ print(f'padding {polygon} with {distance} gets {padded_polygon}')
+ padded_polygon = polygon.copy().astype(np.int32)
+ cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
+
+ x_min = padded_polygon[:, 0].min()
+ x_max = padded_polygon[:, 0].max()
+ y_min = padded_polygon[:, 1].min()
+ y_max = padded_polygon[:, 1].max()
+ width = x_max - x_min + 1
+ height = y_max - y_min + 1
+
+ polygon[:, 0] = polygon[:, 0] - x_min
+ polygon[:, 1] = polygon[:, 1] - y_min
+
+ xs = np.broadcast_to(
+ np.linspace(0, width - 1, num=width).reshape(1, width),
+ (height, width))
+ ys = np.broadcast_to(
+ np.linspace(0, height - 1, num=height).reshape(height, 1),
+ (height, width))
+
+ distance_map = np.zeros((polygon.shape[0], height, width),
+ dtype=np.float32)
+ for i in range(polygon.shape[0]):
+ j = (i + 1) % polygon.shape[0]
+ absolute_distance = self.point2line(xs, ys, polygon[i], polygon[j])
+ distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
+ distance_map = distance_map.min(axis=0)
+
+ x_min_valid = min(max(0, x_min), canvas.shape[1] - 1)
+ x_max_valid = min(max(0, x_max), canvas.shape[1] - 1)
+ y_min_valid = min(max(0, y_min), canvas.shape[0] - 1)
+ y_max_valid = min(max(0, y_max), canvas.shape[0] - 1)
+ canvas[y_min_valid:y_max_valid + 1,
+ x_min_valid:x_max_valid + 1] = np.fmax(
+ 1 - distance_map[y_min_valid - y_min:y_max_valid - y_max +
+ height, x_min_valid - x_min:x_max_valid -
+ x_max + width],
+ canvas[y_min_valid:y_max_valid + 1,
+ x_min_valid:x_max_valid + 1])
+
+ def generate_targets(self, results):
+ """Generate the gt targets for DBNet.
+
+ Args:
+ results (dict): The input result dictionary.
+
+ Returns:
+ results (dict): The output result dictionary.
+ """
+ assert isinstance(results, dict)
+ polygons = results['gt_masks'].masks
+ if 'bbox_fields' in results:
+ results['bbox_fields'].clear()
+ ignore_tags = self.find_invalid(results)
+ h, w, _ = results['img_shape']
+
+ gt_shrink, ignore_tags = self.generate_kernels((h, w),
+ polygons,
+ self.shrink_ratio,
+ ignore_tags=ignore_tags)
+
+ results = self.ignore_texts(results, ignore_tags)
+
+ # polygons and polygons_ignore reassignment.
+ polygons = results['gt_masks'].masks
+ polygons_ignore = results['gt_masks_ignore'].masks
+
+ gt_shrink_mask = self.generate_effective_mask((h, w), polygons_ignore)
+
+ gt_thr, gt_thr_mask = self.generate_thr_map((h, w), polygons)
+
+ results['mask_fields'].clear() # rm gt_masks encoded by polygons
+ results.pop('gt_labels', None)
+ results.pop('gt_masks', None)
+ results.pop('gt_bboxes', None)
+ results.pop('gt_bboxes_ignore', None)
+
+ mapping = {
+ 'gt_shrink': gt_shrink,
+ 'gt_shrink_mask': gt_shrink_mask,
+ 'gt_thr': gt_thr,
+ 'gt_thr_mask': gt_thr_mask
+ }
+ for key, value in mapping.items():
+ value = value if isinstance(value, list) else [value]
+ results[key] = BitmapMasks(value, h, w)
+ results['mask_fields'].append(key)
+
+ return results
diff --git a/mmocr/datasets/pipelines/textdet_targets/panet_targets.py b/mmocr/datasets/pipelines/textdet_targets/panet_targets.py
new file mode 100644
index 00000000..11b85283
--- /dev/null
+++ b/mmocr/datasets/pipelines/textdet_targets/panet_targets.py
@@ -0,0 +1,63 @@
+from mmdet.core import BitmapMasks
+from mmdet.datasets.builder import PIPELINES
+from . import BaseTextDetTargets
+
+
+@PIPELINES.register_module()
+class PANetTargets(BaseTextDetTargets):
+ """Generate the ground truths for PANet: Efficient and Accurate Arbitrary-
+ Shaped Text Detection with Pixel Aggregation Network.
+
+ [https://arxiv.org/abs/1908.05900]. This code is partially adapted from
+ https://github.com/WenmuZhou/PAN.pytorch.
+
+ Args:
+ shrink_ratio (tuple[float]): The ratios for shrinking text instances.
+ max_shrink (int): The maximum shrink distance.
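+
+ A config sketch as it would appear in a detection pipeline; the values
+ shown are the defaults above:
+
+ .. code-block::
+
+     dict(
+         type='PANetTargets',
+         shrink_ratio=(1.0, 0.5),
+         max_shrink=20)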
+ """
+
+ def __init__(self, shrink_ratio=(1.0, 0.5), max_shrink=20):
+ self.shrink_ratio = shrink_ratio
+ self.max_shrink = max_shrink
+
+ def generate_targets(self, results):
+ """Generate the gt targets for PANet.
+
+ Args:
+ results (dict): The input result dictionary.
+
+ Returns:
+ results (dict): The output result dictionary.
+ """
+
+ assert isinstance(results, dict)
+
+ polygon_masks = results['gt_masks'].masks
+ polygon_masks_ignore = results['gt_masks_ignore'].masks
+
+ h, w, _ = results['img_shape']
+ gt_kernels = []
+ for ratio in self.shrink_ratio:
+ mask, _ = self.generate_kernels((h, w),
+ polygon_masks,
+ ratio,
+ max_shrink=self.max_shrink,
+ ignore_tags=None)
+ gt_kernels.append(mask)
+ gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
+
+ results['mask_fields'].clear() # rm gt_masks encoded by polygons
+ if 'bbox_fields' in results:
+ results['bbox_fields'].clear()
+ results.pop('gt_labels', None)
+ results.pop('gt_masks', None)
+ results.pop('gt_bboxes', None)
+ results.pop('gt_bboxes_ignore', None)
+
+ mapping = {'gt_kernels': gt_kernels, 'gt_mask': gt_mask}
+ for key, value in mapping.items():
+ value = value if isinstance(value, list) else [value]
+ results[key] = BitmapMasks(value, h, w)
+ results['mask_fields'].append(key)
+
+ return results
diff --git a/mmocr/datasets/pipelines/textdet_targets/psenet_targets.py b/mmocr/datasets/pipelines/textdet_targets/psenet_targets.py
new file mode 100644
index 00000000..d78b7fd7
--- /dev/null
+++ b/mmocr/datasets/pipelines/textdet_targets/psenet_targets.py
@@ -0,0 +1,21 @@
+from mmdet.datasets.builder import PIPELINES
+from . import PANetTargets
+
+
+@PIPELINES.register_module()
+class PSENetTargets(PANetTargets):
+ """Generate the ground truth targets of PSENet: Shape robust text detection
+ with progressive scale expansion network.
+
+ [https://arxiv.org/abs/1903.12473]. This code is partially adapted from
+ https://github.com/whai362/PSENet.
+
+ Args:
+ shrink_ratio (tuple[float]): The ratios for shrinking text instances.
+ max_shrink (int): The maximum shrink distance.
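+
+ A config sketch using the defaults above:
+
+ .. code-block::
+
+     dict(
+         type='PSENetTargets',
+         shrink_ratio=(1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4),
+         max_shrink=20)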
+ """
+
+ def __init__(self,
+ shrink_ratio=(1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4),
+ max_shrink=20):
+ super().__init__(shrink_ratio=shrink_ratio, max_shrink=max_shrink)
diff --git a/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py b/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py
new file mode 100644
index 00000000..eeebf5bd
--- /dev/null
+++ b/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py
@@ -0,0 +1,454 @@
+import cv2
+import numpy as np
+from numpy.linalg import norm
+
+import mmocr.utils.check_argument as check_argument
+from mmdet.core import BitmapMasks
+from mmdet.datasets.builder import PIPELINES
+from . import BaseTextDetTargets
+
+
+@PIPELINES.register_module()
+class TextSnakeTargets(BaseTextDetTargets):
+ """Generate the ground truth targets of TextSnake: TextSnake: A Flexible
+ Representation for Detecting Text of Arbitrary Shapes.
+
+ [https://arxiv.org/abs/1807.01544]. This was partially adapted from
+ https://github.com/princewang1994/TextSnake.pytorch.
+
+ Args:
+ orientation_thr (float): The threshold for distinguishing between
+ head edge and tail edge among the horizontal and vertical edges
+ of a quadrangle.
+ resample_step (float): The step size for resampling the top and
+ bottom sidelines of a text instance.
+ center_region_shrink_ratio (float): The shrink ratio of the text
+ center region.
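+
+ A config sketch using the defaults above:
+
+ .. code-block::
+
+     dict(
+         type='TextSnakeTargets',
+         orientation_thr=2.0,
+         resample_step=4.0,
+         center_region_shrink_ratio=0.3)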
+ """
+
+ def __init__(self,
+ orientation_thr=2.0,
+ resample_step=4.0,
+ center_region_shrink_ratio=0.3):
+
+ super().__init__()
+ self.orientation_thr = orientation_thr
+ self.resample_step = resample_step
+ self.center_region_shrink_ratio = center_region_shrink_ratio
+
+ def vector_angle(self, vec1, vec2):
+ if vec1.ndim > 1:
+ unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
+ else:
+ unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
+ if vec2.ndim > 1:
+ unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
+ else:
+ unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
+ return np.arccos(
+ np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+
+ def vector_slope(self, vec):
+ assert len(vec) == 2
+ return abs(vec[1] / (vec[0] + 1e-8))
+
+ def vector_sin(self, vec):
+ assert len(vec) == 2
+ return vec[1] / (norm(vec) + 1e-8)
+
+ def vector_cos(self, vec):
+ assert len(vec) == 2
+ return vec[0] / (norm(vec) + 1e-8)
+
+ def find_head_tail(self, points, orientation_thr):
+ """Find the head edge and tail edge of a text polygon.
+
+ Args:
+ points (ndarray): The points composing a text polygon.
+ orientation_thr (float): The threshold for distinguishing between
+ head edge and tail edge among the horizontal and vertical edges
+ of a quadrangle.
+
+ Returns:
+ head_inds (list): The indexes of two points composing head edge.
+ tail_inds (list): The indexes of two points composing tail edge.
+ """
+
+ assert points.ndim == 2
+ assert points.shape[0] >= 4
+ assert points.shape[1] == 2
+ assert isinstance(orientation_thr, float)
+
+ if len(points) > 4:
+ pad_points = np.vstack([points, points[0]])
+ edge_vec = pad_points[1:] - pad_points[:-1]
+
+ theta_sum = []
+
+ for i, edge_vec1 in enumerate(edge_vec):
+ adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
+ adjacent_edge_vec = edge_vec[adjacent_ind]
+ temp_theta_sum = np.sum(
+ self.vector_angle(edge_vec1, adjacent_edge_vec))
+ theta_sum.append(temp_theta_sum)
+ theta_sum = np.array(theta_sum)
+ head_start, tail_start = np.argsort(theta_sum)[::-1][0:2]
+
+ if (abs(head_start - tail_start) < 2
+ or abs(head_start - tail_start) > 12):
+ tail_start = (head_start + len(points) // 2) % len(points)
+ head_end = (head_start + 1) % len(points)
+ tail_end = (tail_start + 1) % len(points)
+
+ if head_end > tail_end:
+ head_start, tail_start = tail_start, head_start
+ head_end, tail_end = tail_end, head_end
+ head_inds = [head_start, head_end]
+ tail_inds = [tail_start, tail_end]
+ else:
+ if self.vector_slope(points[1] - points[0]) + self.vector_slope(
+ points[3] - points[2]) < self.vector_slope(
+ points[2] - points[1]) + self.vector_slope(points[0] -
+ points[3]):
+ horizontal_edge_inds = [[0, 1], [2, 3]]
+ vertical_edge_inds = [[3, 0], [1, 2]]
+ else:
+ horizontal_edge_inds = [[3, 0], [1, 2]]
+ vertical_edge_inds = [[0, 1], [2, 3]]
+
+ vertical_len_sum = norm(points[vertical_edge_inds[0][0]] -
+ points[vertical_edge_inds[0][1]]) + norm(
+ points[vertical_edge_inds[1][0]] -
+ points[vertical_edge_inds[1][1]])
+ horizontal_len_sum = norm(
+ points[horizontal_edge_inds[0][0]] -
+ points[horizontal_edge_inds[0][1]]) + norm(
+ points[horizontal_edge_inds[1][0]] -
+ points[horizontal_edge_inds[1][1]])
+
+ if vertical_len_sum > horizontal_len_sum * orientation_thr:
+ head_inds = horizontal_edge_inds[0]
+ tail_inds = horizontal_edge_inds[1]
+ else:
+ head_inds = vertical_edge_inds[0]
+ tail_inds = vertical_edge_inds[1]
+
+ return head_inds, tail_inds
+
+ def reorder_poly_edge(self, points):
+ """Get the respective points composing head edge, tail edge, top
+ sideline and bottom sideline.
+
+ Args:
+ points (ndarray): The points composing a text polygon.
+
+ Returns:
+ head_edge (ndarray): The two points composing the head edge of text
+ polygon.
+ tail_edge (ndarray): The two points composing the tail edge of text
+ polygon.
+ top_sideline (ndarray): The points composing top curved sideline of
+ text polygon.
+ bot_sideline (ndarray): The points composing bottom curved sideline
+ of text polygon.
+ """
+
+ assert points.ndim == 2
+ assert points.shape[0] >= 4
+ assert points.shape[1] == 2
+
+ head_inds, tail_inds = self.find_head_tail(points,
+ self.orientation_thr)
+ head_edge, tail_edge = points[head_inds], points[tail_inds]
+
+ pad_points = np.vstack([points, points])
+ if tail_inds[1] < 1:
+ tail_inds[1] = len(points)
+ sideline1 = pad_points[head_inds[1]:tail_inds[1]]
+ sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
+ sideline_mean_shift = np.mean(
+ sideline1, axis=0) - np.mean(
+ sideline2, axis=0)
+
+ if sideline_mean_shift[1] > 0:
+ top_sideline, bot_sideline = sideline2, sideline1
+ else:
+ top_sideline, bot_sideline = sideline1, sideline2
+
+ return head_edge, tail_edge, top_sideline, bot_sideline
+
+ def resample_line(self, line, n):
+ """Resample n points on a line.
+
+ Args:
+ line (ndarray): The points composing a line.
+ n (int): The resampled points number.
+
+ Returns:
+ resampled_line (ndarray): The points composing the resampled line.
+ """
+
+ assert line.ndim == 2
+ assert line.shape[0] >= 2
+ assert line.shape[1] == 2
+ assert isinstance(n, int)
+
+ length_list = [
+ norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
+ ]
+ total_length = sum(length_list)
+ length_cumsum = np.cumsum([0.0] + length_list)
+ delta_length = total_length / (float(n) + 1e-8)
+
+ current_edge_ind = 0
+ resampled_line = [line[0]]
+
+ for i in range(1, n):
+ current_line_len = i * delta_length
+
+ while current_line_len >= length_cumsum[current_edge_ind + 1]:
+ current_edge_ind += 1
+ current_edge_end_shift = current_line_len - length_cumsum[
+ current_edge_ind]
+ end_shift_ratio = current_edge_end_shift / length_list[
+ current_edge_ind]
+ current_point = line[current_edge_ind] + (
+ line[current_edge_ind + 1] -
+ line[current_edge_ind]) * end_shift_ratio
+ resampled_line.append(current_point)
+
+ resampled_line.append(line[-1])
+ resampled_line = np.array(resampled_line)
+
+ return resampled_line
+
+ def resample_sidelines(self, sideline1, sideline2, resample_step):
+ """Resample two sidelines to be of the same points number according to
+ step size.
+
+ Args:
+ sideline1 (ndarray): The points composing a sideline of a text
+ polygon.
+ sideline2 (ndarray): The points composing another sideline of a
+ text polygon.
+ resample_step (float): The resampled step size.
+
+ Returns:
+ resampled_line1 (ndarray): The resampled line 1.
+ resampled_line2 (ndarray): The resampled line 2.
+ """
+
+ assert sideline1.ndim == sideline2.ndim == 2
+ assert sideline1.shape[1] == sideline2.shape[1] == 2
+ assert sideline1.shape[0] >= 2
+ assert sideline2.shape[0] >= 2
+ assert isinstance(resample_step, float)
+
+ length1 = sum([
+ norm(sideline1[i + 1] - sideline1[i])
+ for i in range(len(sideline1) - 1)
+ ])
+ length2 = sum([
+ norm(sideline2[i + 1] - sideline2[i])
+ for i in range(len(sideline2) - 1)
+ ])
+
+ total_length = (length1 + length2) / 2
+ resample_point_num = int(float(total_length) / resample_step)
+
+ resampled_line1 = self.resample_line(sideline1, resample_point_num)
+ resampled_line2 = self.resample_line(sideline2, resample_point_num)
+
+ return resampled_line1, resampled_line2
+
+ def draw_center_region_maps(self, top_line, bot_line, center_line,
+ center_region_mask, radius_map, sin_map,
+ cos_map, region_shrink_ratio):
+ """Draw attributes on text center region.
+
+ Args:
+ top_line (ndarray): The points composing top curved sideline of
+ text polygon.
+ bot_line (ndarray): The points composing bottom curved sideline
+ of text polygon.
+ center_line (ndarray): The points composing the center line of text
+ instance.
+ center_region_mask (ndarray): The text center region mask.
+ radius_map (ndarray): The map where the distance from point to
+ sidelines will be drawn on for each pixel in text center
+ region.
+ sin_map (ndarray): The map where vector_sin(theta) will be drawn
+ on text center regions. Theta is the angle between tangent
+ line and vector (1, 0).
+ cos_map (ndarray): The map where vector_cos(theta) will be drawn on
+ text center regions. Theta is the angle between tangent line
+ and vector (1, 0).
+ region_shrink_ratio (float): The shrink ratio of text center.
+ """
+
+ assert top_line.shape == bot_line.shape == center_line.shape
+ assert (center_region_mask.shape == radius_map.shape == sin_map.shape
+ == cos_map.shape)
+ assert isinstance(region_shrink_ratio, float)
+ for i in range(0, len(center_line) - 1):
+
+ top_mid_point = (top_line[i] + top_line[i + 1]) / 2
+ bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2
+ radius = norm(top_mid_point - bot_mid_point) / 2
+
+ text_direction = center_line[i + 1] - center_line[i]
+ sin_theta = self.vector_sin(text_direction)
+ cos_theta = self.vector_cos(text_direction)
+
+ pnt_tl = center_line[i] + (top_line[i] -
+ center_line[i]) * region_shrink_ratio
+ pnt_tr = center_line[i + 1] + (
+ top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
+ pnt_br = center_line[i + 1] + (
+ bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
+ pnt_bl = center_line[i] + (bot_line[i] -
+ center_line[i]) * region_shrink_ratio
+ current_center_box = np.vstack([pnt_tl, pnt_tr, pnt_br,
+ pnt_bl]).astype(np.int32)
+
+ cv2.fillPoly(center_region_mask, [current_center_box], color=1)
+ cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
+ cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)
+ cv2.fillPoly(radius_map, [current_center_box], color=radius)
+
+ def generate_center_mask_attrib_maps(self, img_size, text_polys):
+ """Generate text center region mask and geometric attribute maps.
+
+ Args:
+ img_size (tuple): The image size of (height, width).
+ text_polys (list[list[ndarray]]): The list of text polygons.
+
+ Returns:
+ center_region_mask (ndarray): The text center region mask.
+ radius_map (ndarray): The map of text radii, i.e. half the local
+ distance between the top and bottom sidelines, for each pixel in
+ the text center region.
+ sin_map (ndarray): The sin(theta) map where theta is the angle
+ between the local text direction and vector (1, 0).
+ cos_map (ndarray): The cos(theta) map where theta is the angle
+ between the local text direction and vector (1, 0).
+ """
+
+ assert isinstance(img_size, tuple)
+ assert check_argument.is_2dlist(text_polys)
+
+ h, w = img_size
+
+ center_region_mask = np.zeros((h, w), np.uint8)
+ radius_map = np.zeros((h, w), dtype=np.float32)
+ sin_map = np.zeros((h, w), dtype=np.float32)
+ cos_map = np.zeros((h, w), dtype=np.float32)
+
+ for poly in text_polys:
+ assert len(poly) == 1
+ text_instance = [[poly[0][i], poly[0][i + 1]]
+ for i in range(0, len(poly[0]), 2)]
+ polygon_points = np.array(
+ text_instance, dtype=np.int32).reshape(-1, 2)
+
+ _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
+ resampled_top_line, resampled_bot_line = self.resample_sidelines(
+ top_line, bot_line, self.resample_step)
+ resampled_bot_line = resampled_bot_line[::-1]
+ center_line = (resampled_top_line + resampled_bot_line) / 2
+
+ if self.vector_slope(center_line[-1] - center_line[0]) > 0.9:
+ if (center_line[-1] - center_line[0])[1] < 0:
+ center_line = center_line[::-1]
+ resampled_top_line = resampled_top_line[::-1]
+ resampled_bot_line = resampled_bot_line[::-1]
+ else:
+ if (center_line[-1] - center_line[0])[0] < 0:
+ center_line = center_line[::-1]
+ resampled_top_line = resampled_top_line[::-1]
+ resampled_bot_line = resampled_bot_line[::-1]
+
+ line_head_shrink_len = norm(resampled_top_line[0] -
+ resampled_bot_line[0]) / 4.0
+ line_tail_shrink_len = norm(resampled_top_line[-1] -
+ resampled_bot_line[-1]) / 4.0
+ head_shrink_num = int(line_head_shrink_len // self.resample_step)
+ tail_shrink_num = int(line_tail_shrink_len // self.resample_step)
+
+ if len(center_line) > head_shrink_num + tail_shrink_num + 2:
+ center_line = center_line[head_shrink_num:len(center_line) -
+ tail_shrink_num]
+ resampled_top_line = resampled_top_line[
+ head_shrink_num:len(resampled_top_line) - tail_shrink_num]
+ resampled_bot_line = resampled_bot_line[
+ head_shrink_num:len(resampled_bot_line) - tail_shrink_num]
+
+ self.draw_center_region_maps(resampled_top_line,
+ resampled_bot_line, center_line,
+ center_region_mask, radius_map,
+ sin_map, cos_map,
+ self.center_region_shrink_ratio)
+
+ return center_region_mask, radius_map, sin_map, cos_map
+
+ def generate_text_region_mask(self, img_size, text_polys):
+ """Generate text center region mask and geometry attribute maps.
+
+ Args:
+ img_size (tuple): The image size (height, width).
+ text_polys (list[list[ndarray]]): The list of text polygons.
+
+ Returns:
+ text_region_mask (ndarray): The text region mask.
+ """
+
+ assert isinstance(img_size, tuple)
+ assert check_argument.is_2dlist(text_polys)
+
+ h, w = img_size
+ text_region_mask = np.zeros((h, w), dtype=np.uint8)
+
+ for poly in text_polys:
+ assert len(poly) == 1
+ text_instance = [[poly[0][i], poly[0][i + 1]]
+ for i in range(0, len(poly[0]), 2)]
+ polygon = np.array(
+ text_instance, dtype=np.int32).reshape((1, -1, 2))
+ cv2.fillPoly(text_region_mask, polygon, 1)
+
+ return text_region_mask
+
+ def generate_targets(self, results):
+ """Generate the gt targets for TextSnake.
+
+ Args:
+ results (dict): The input result dictionary.
+
+ Returns:
+ results (dict): The output result dictionary.
+ """
+
+ assert isinstance(results, dict)
+
+ polygon_masks = results['gt_masks'].masks
+ polygon_masks_ignore = results['gt_masks_ignore'].masks
+
+ h, w, _ = results['img_shape']
+
+ gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
+ gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
+
+ (gt_center_region_mask, gt_radius_map, gt_sin_map,
+ gt_cos_map) = self.generate_center_mask_attrib_maps((h, w),
+ polygon_masks)
+
+ results['mask_fields'].clear() # rm gt_masks encoded by polygons
+ mapping = {
+ 'gt_text_mask': gt_text_mask,
+ 'gt_center_region_mask': gt_center_region_mask,
+ 'gt_mask': gt_mask,
+ 'gt_radius_map': gt_radius_map,
+ 'gt_sin_map': gt_sin_map,
+ 'gt_cos_map': gt_cos_map
+ }
+ for key, value in mapping.items():
+ value = value if isinstance(value, list) else [value]
+ results[key] = BitmapMasks(value, h, w)
+ results['mask_fields'].append(key)
+
+ return results
diff --git a/mmocr/datasets/pipelines/transforms.py b/mmocr/datasets/pipelines/transforms.py
new file mode 100644
index 00000000..537c107e
--- /dev/null
+++ b/mmocr/datasets/pipelines/transforms.py
@@ -0,0 +1,727 @@
+import math
+
+import cv2
+import numpy as np
+import torchvision.transforms as transforms
+from PIL import Image
+
+import mmocr.core.evaluation.utils as eval_utils
+from mmdet.core import BitmapMasks, PolygonMasks
+from mmdet.datasets.builder import PIPELINES
+from mmdet.datasets.pipelines.transforms import Resize
+from mmocr.utils import check_argument
+
+
+@PIPELINES.register_module()
+class RandomCropInstances:
+ """Randomly crop images and make sure to contain text instances.
+
+ Args:
+ target_size (tuple or int): The target crop size of (height, width).
+ instance_key (str): The key of the instance masks in results that
+ guides the positive sampling.
+ mask_type (str): Either 'inx0' (use the first instance mask) or
+ 'union_all' (use the union of all instance masks).
+ positive_sample_ratio (float): The probability of sampling a crop
+ window that overlaps positive regions.
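+
+ A config sketch; ``instance_key`` and ``target_size`` below are
+ illustrative choices, not defaults:
+
+ .. code-block::
+
+     dict(
+         type='RandomCropInstances',
+         instance_key='gt_kernels',
+         target_size=(640, 640))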
+ """
+
+ def __init__(
+ self,
+ target_size,
+ instance_key,
+ mask_type='inx0', # 'inx0' or 'union_all'
+ positive_sample_ratio=5.0 / 8.0):
+
+ assert mask_type in ['inx0', 'union_all']
+
+ self.mask_type = mask_type
+ self.instance_key = instance_key
+ self.positive_sample_ratio = positive_sample_ratio
+ self.target_size = target_size if (target_size is None or isinstance(
+ target_size, tuple)) else (target_size, target_size)
+
+ def sample_offset(self, img_gt, img_size):
+ h, w = img_size
+ t_h, t_w = self.target_size
+
+ # clip the target size to the original image size if necessary
+ t_h = t_h if t_h < h else h
+ t_w = t_w if t_w < w else w
+ if (img_gt is not None
+ and np.random.random_sample() < self.positive_sample_ratio
+ and np.max(img_gt) > 0):
+
+ # make sure to crop the positive region
+
+ # the minimum top left to crop positive region (h,w)
+ tl = np.min(np.where(img_gt > 0), axis=1) - (t_h, t_w)
+ tl[tl < 0] = 0
+ # the maximum top left to crop positive region
+ br = np.max(np.where(img_gt > 0), axis=1) - (t_h, t_w)
+ br[br < 0] = 0
+ # clip br so that the crop window stays inside the image
+ br[0] = min(br[0], h - t_h)
+ br[1] = min(br[1], w - t_w)
+ #
+ h = np.random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
+ w = np.random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
+ else:
+ # make sure not to crop outside of img
+
+ h = np.random.randint(0, h - t_h) if h - t_h > 0 else 0
+ w = np.random.randint(0, w - t_w) if w - t_w > 0 else 0
+
+ return (h, w)
+
+ @staticmethod
+ def crop_img(img, offset, target_size):
+ h, w = img.shape[:2]
+ br = np.min(
+ np.stack((np.array(offset) + np.array(target_size), np.array(
+ (h, w)))),
+ axis=0)
+ return img[offset[0]:br[0], offset[1]:br[1]], np.array(
+ [offset[1], offset[0], br[1], br[0]])
+
+ def crop_bboxes(self, bboxes, canvas_bbox):
+ kept_bboxes = []
+ kept_inx = []
+ canvas_poly = eval_utils.box2polygon(canvas_bbox)
+ tl = canvas_bbox[0:2]
+
+ for inx, bbox in enumerate(bboxes):
+ poly = eval_utils.box2polygon(bbox)
+ area, inters = eval_utils.poly_intersection(poly, canvas_poly)
+ if area == 0:
+ continue
+ xmin, xmax, ymin, ymax = inters.boundingBox()
+ kept_bboxes += [
+ np.array(
+ [xmin - tl[0], ymin - tl[1], xmax - tl[0], ymax - tl[1]],
+ dtype=np.float32)
+ ]
+ kept_inx += [inx]
+
+ if len(kept_inx) == 0:
+ return np.array([]).astype(np.float32).reshape(0, 4), kept_inx
+
+ return np.stack(kept_bboxes), kept_inx
+
+ @staticmethod
+ def generate_mask(gt_mask, type):
+
+ if type == 'inx0':
+ return gt_mask.masks[0]
+ if type == 'union_all':
+ mask = gt_mask.masks[0].copy()
+ for inx in range(1, len(gt_mask.masks)):
+ mask = np.logical_or(mask, gt_mask.masks[inx])
+ return mask
+
+ raise NotImplementedError
+
+ def __call__(self, results):
+
+ gt_mask = results[self.instance_key]
+ mask = None
+ if len(gt_mask.masks) > 0:
+ mask = self.generate_mask(gt_mask, self.mask_type)
+ results['crop_offset'] = self.sample_offset(mask,
+ results['img'].shape[:2])
+
+ # crop img. bbox = [x1,y1,x2,y2]
+ img, bbox = self.crop_img(results['img'], results['crop_offset'],
+ self.target_size)
+ results['img'] = img
+ img_shape = img.shape
+ results['img_shape'] = img_shape
+
+ # crop masks
+ for key in results.get('mask_fields', []):
+ results[key] = results[key].crop(bbox)
+
+ # for mask rcnn
+ for key in results.get('bbox_fields', []):
+ results[key], kept_inx = self.crop_bboxes(results[key], bbox)
+ if key == 'gt_bboxes':
+ # filter gt_labels accordingly
+ if 'gt_labels' in results:
+ ori_labels = results['gt_labels']
+ ori_inst_num = len(ori_labels)
+ results['gt_labels'] = [
+ ori_labels[inx] for inx in range(ori_inst_num)
+ if inx in kept_inx
+ ]
+ # filter gt_masks accordingly
+ if 'gt_masks' in results:
+ ori_mask = results['gt_masks'].masks
+ kept_mask = [
+ ori_mask[inx] for inx in range(ori_inst_num)
+ if inx in kept_inx
+ ]
+ target_h, target_w = bbox[3] - bbox[1], bbox[2] - bbox[0]
+ if len(kept_inx) > 0:
+ kept_mask = np.stack(kept_mask)
+ else:
+ kept_mask = np.empty((0, target_h, target_w),
+ dtype=np.float32)
+ results['gt_masks'] = BitmapMasks(kept_mask, target_h,
+ target_w)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomRotateTextDet:
+ """Randomly rotate images."""
+
+ def __init__(self, rotate_ratio=1.0, max_angle=10):
+ self.rotate_ratio = rotate_ratio
+ self.max_angle = max_angle
+
+ @staticmethod
+ def sample_angle(max_angle):
+ angle = np.random.random_sample() * 2 * max_angle - max_angle
+ return angle
+
+ @staticmethod
+ def rotate_img(img, angle):
+ h, w = img.shape[:2]
+ rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
+ img_target = cv2.warpAffine(
+ img, rotation_matrix, (w, h), flags=cv2.INTER_NEAREST)
+ assert img_target.shape == img.shape
+ return img_target
+
+ def __call__(self, results):
+ if np.random.random_sample() < self.rotate_ratio:
+ # rotate imgs
+ results['rotated_angle'] = self.sample_angle(self.max_angle)
+ img = self.rotate_img(results['img'], results['rotated_angle'])
+ results['img'] = img
+ img_shape = img.shape
+ results['img_shape'] = img_shape
+
+ # rotate masks
+ for key in results.get('mask_fields', []):
+ masks = results[key].masks
+ mask_list = []
+ for m in masks:
+ rotated_m = self.rotate_img(m, results['rotated_angle'])
+ mask_list.append(rotated_m)
+ results[key] = BitmapMasks(mask_list, *(img_shape[:2]))
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class ColorJitter:
+ """An interface for torch color jitter so that it can be invoked in
+ mmdetection pipeline."""
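+
+ A config sketch; the keyword arguments are forwarded to
+ ``torchvision.transforms.ColorJitter``, so the values below are only
+ illustrative:
+
+ .. code-block::
+
+     dict(
+         type='ColorJitter',
+         brightness=0.5,
+         contrast=0.5,
+         saturation=0.5)
+ """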
+
+ def __init__(self, **kwargs):
+ self.transform = transforms.ColorJitter(**kwargs)
+
+ def __call__(self, results):
+ # img is bgr
+ img = results['img'][..., ::-1]
+ img = Image.fromarray(img)
+ img = self.transform(img)
+ img = np.asarray(img)
+ img = img[..., ::-1]
+ results['img'] = img
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class ScaleAspectJitter(Resize):
+ """Resize image and segmentation mask encoded by coordinates.
+
+ Allowed resize types are `around_min_img_scale`, `long_short_bound`, and
+ `indep_sample_in_range`.
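+
+ A config sketch for the ``long_short_bound`` resize type; all numeric
+ values below are illustrative assumptions:
+
+ .. code-block::
+
+     dict(
+         type='ScaleAspectJitter',
+         img_scale=[(3000, 736)],
+         ratio_range=(0.5, 3),
+         aspect_ratio_range=(1, 1),
+         multiscale_mode='value',
+         long_size_bound=1280,
+         short_size_bound=640,
+         resize_type='long_short_bound',
+         keep_ratio=False)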
+ """
+
+ def __init__(self,
+ img_scale=None,
+ multiscale_mode='range',
+ ratio_range=None,
+ keep_ratio=False,
+ resize_type='around_min_img_scale',
+ aspect_ratio_range=None,
+ long_size_bound=None,
+ short_size_bound=None,
+ scale_range=None):
+ super().__init__(
+ img_scale=img_scale,
+ multiscale_mode=multiscale_mode,
+ ratio_range=ratio_range,
+ keep_ratio=keep_ratio)
+ assert not keep_ratio
+ assert resize_type in [
+ 'around_min_img_scale', 'long_short_bound', 'indep_sample_in_range'
+ ]
+ self.resize_type = resize_type
+
+ if resize_type == 'indep_sample_in_range':
+ assert ratio_range is None
+ assert aspect_ratio_range is None
+ assert short_size_bound is None
+ assert long_size_bound is None
+ assert scale_range is not None
+ else:
+ assert scale_range is None
+ assert isinstance(ratio_range, tuple)
+ assert isinstance(aspect_ratio_range, tuple)
+ assert check_argument.equal_len(ratio_range, aspect_ratio_range)
+
+ if resize_type in ['long_short_bound']:
+ assert short_size_bound is not None
+ assert long_size_bound is not None
+
+ self.aspect_ratio_range = aspect_ratio_range
+ self.long_size_bound = long_size_bound
+ self.short_size_bound = short_size_bound
+ self.scale_range = scale_range
+
+ @staticmethod
+ def sample_from_range(range):
+ assert len(range) == 2
+ min_value, max_value = min(range), max(range)
+ value = np.random.random_sample() * (max_value - min_value) + min_value
+
+ return value
+
+ def _random_scale(self, results):
+
+ if self.resize_type == 'indep_sample_in_range':
+ w = self.sample_from_range(self.scale_range)
+ h = self.sample_from_range(self.scale_range)
+ results['scale'] = (int(w), int(h)) # (w,h)
+ results['scale_idx'] = None
+ return
+ h, w = results['img'].shape[0:2]
+ if self.resize_type == 'long_short_bound':
+ scale1 = 1
+ if max(h, w) > self.long_size_bound:
+ scale1 = self.long_size_bound / max(h, w)
+ scale2 = self.sample_from_range(self.ratio_range)
+ scale = scale1 * scale2
+ if min(h, w) * scale <= self.short_size_bound:
+ scale = (self.short_size_bound + 10) * 1.0 / min(h, w)
+ elif self.resize_type == 'around_min_img_scale':
+ short_size = min(self.img_scale[0])
+ ratio = self.sample_from_range(self.ratio_range)
+ scale = (ratio * short_size) / min(h, w)
+ else:
+ raise NotImplementedError
+
+ aspect = self.sample_from_range(self.aspect_ratio_range)
+ h_scale = scale * math.sqrt(aspect)
+ w_scale = scale / math.sqrt(aspect)
+ results['scale'] = (int(w * w_scale), int(h * h_scale)) # (w,h)
+ results['scale_idx'] = None
+
+
+@PIPELINES.register_module()
+class AffineJitter:
+ """An interface for torchvision random affine so that it can be invoked in
+ mmdet pipeline."""
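+
+ A config sketch; the keyword arguments are forwarded to
+ ``torchvision.transforms.RandomAffine``, and the values below simply
+ mirror the defaults of ``__init__``:
+
+ .. code-block::
+
+     dict(
+         type='AffineJitter',
+         degrees=4,
+         translate=(0.02, 0.04),
+         scale=(0.9, 1.1))
+ """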
+
+ def __init__(self,
+ degrees=4,
+ translate=(0.02, 0.04),
+ scale=(0.9, 1.1),
+ shear=None,
+ resample=False,
+ fillcolor=0):
+ self.transform = transforms.RandomAffine(
+ degrees=degrees,
+ translate=translate,
+ scale=scale,
+ shear=shear,
+ resample=resample,
+ fillcolor=fillcolor)
+
+ def __call__(self, results):
+ # img is bgr
+ img = results['img'][..., ::-1]
+ img = Image.fromarray(img)
+ img = self.transform(img)
+ img = np.asarray(img)
+ img = img[..., ::-1]
+ results['img'] = img
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomCropPolyInstances:
+ """Randomly crop images and make sure to contain at least one intact
+ instance."""
+
+ def __init__(self,
+ instance_key='gt_masks',
+ crop_ratio=5.0 / 8.0,
+ min_side_ratio=0.4):
+ super().__init__()
+ self.instance_key = instance_key
+ self.crop_ratio = crop_ratio
+ self.min_side_ratio = min_side_ratio
+
+ def sample_valid_start_end(self, valid_array, min_len, max_start, min_end):
+
+ assert isinstance(min_len, int)
+ assert len(valid_array) > min_len
+
+ start_array = valid_array.copy()
+ max_start = min(len(start_array) - min_len, max_start)
+ start_array[max_start:] = 0
+ start_array[0] = 1
+ diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
+ region_starts = np.where(diff_array < 0)[0]
+ region_ends = np.where(diff_array > 0)[0]
+ region_ind = np.random.randint(0, len(region_starts))
+ start = np.random.randint(region_starts[region_ind],
+ region_ends[region_ind])
+
+ end_array = valid_array.copy()
+ min_end = max(start + min_len, min_end)
+ end_array[:min_end] = 0
+ end_array[-1] = 1
+ diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
+ region_starts = np.where(diff_array < 0)[0]
+ region_ends = np.where(diff_array > 0)[0]
+ region_ind = np.random.randint(0, len(region_starts))
+ end = np.random.randint(region_starts[region_ind],
+ region_ends[region_ind])
+ return start, end
+
+ def sample_crop_box(self, img_size, masks):
+ """Generate crop box and make sure not to crop the polygon instances.
+
+ Args:
+ img_size (tuple(int)): The image size.
+ masks (list[list[ndarray]]): The polygon masks.
+ """
+
+ assert isinstance(img_size, tuple)
+ h, w = img_size[:2]
+
+ x_valid_array = np.ones(w, dtype=np.int32)
+ y_valid_array = np.ones(h, dtype=np.int32)
+
+ selected_mask = masks[np.random.randint(0, len(masks))]
+ selected_mask = selected_mask[0].reshape((-1, 2)).astype(np.int32)
+ max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0)
+ min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1)
+ max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
+ min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)
+
+ for mask in masks:
+ assert len(mask) == 1
+ mask = mask[0].reshape((-1, 2)).astype(np.int32)
+ clip_x = np.clip(mask[:, 0], 0, w - 1)
+ clip_y = np.clip(mask[:, 1], 0, h - 1)
+ min_x, max_x = np.min(clip_x), np.max(clip_x)
+ min_y, max_y = np.min(clip_y), np.max(clip_y)
+
+ x_valid_array[min_x - 2:max_x + 3] = 0
+ y_valid_array[min_y - 2:max_y + 3] = 0
+
+ min_w = int(w * self.min_side_ratio)
+ min_h = int(h * self.min_side_ratio)
+
+ x1, x2 = self.sample_valid_start_end(x_valid_array, min_w, max_x_start,
+ min_x_end)
+ y1, y2 = self.sample_valid_start_end(y_valid_array, min_h, max_y_start,
+ min_y_end)
+
+ return np.array([x1, y1, x2, y2])
+
+ def crop_img(self, img, bbox):
+ assert img.ndim == 3
+ h, w, _ = img.shape
+ assert 0 <= bbox[1] < bbox[3] <= h
+ assert 0 <= bbox[0] < bbox[2] <= w
+ return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
+
+ def __call__(self, results):
+ if np.random.random_sample() < self.crop_ratio:
+ crop_box = self.sample_crop_box(results['img'].shape,
+ results[self.instance_key].masks)
+ results['crop_region'] = crop_box
+ img = self.crop_img(results['img'], crop_box)
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ # crop and filter masks
+ x1, y1, x2, y2 = crop_box
+ w = max(x2 - x1, 1)
+ h = max(y2 - y1, 1)
+ labels = results['gt_labels']
+ valid_labels = []
+ for key in results.get('mask_fields', []):
+ if len(results[key].masks) == 0:
+ continue
+ results[key] = results[key].crop(crop_box)
+ # filter out polygons beyond crop box.
+ masks = results[key].masks
+ valid_masks_list = []
+
+ for ind, mask in enumerate(masks):
+ assert len(mask) == 1
+ polygon = mask[0].reshape((-1, 2))
+ if (polygon[:, 0] >
+ -4).all() and (polygon[:, 0] < w + 4).all() and (
+ polygon[:, 1] > -4).all() and (polygon[:, 1] <
+ h + 4).all():
+ mask[0][::2] = np.clip(mask[0][::2], 0, w)
+ mask[0][1::2] = np.clip(mask[0][1::2], 0, h)
+ if key == self.instance_key:
+ valid_labels.append(labels[ind])
+ valid_masks_list.append(mask)
+
+ results[key] = PolygonMasks(valid_masks_list, h, w)
+ results['gt_labels'] = np.array(valid_labels)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomRotatePolyInstances:
+
+ def __init__(self,
+ rotate_ratio=0.5,
+ max_angle=10,
+ pad_with_fixed_color=False,
+ pad_value=(0, 0, 0)):
+ """Randomly rotate images and polygon masks.
+
+ Args:
+ rotate_ratio (float): The ratio of samples to operate rotation.
+ max_angle (int): The maximum rotation angle.
+ pad_with_fixed_color (bool): Whether to pad the rotated image with
+ a fixed color. If False, the border is filled with a patch
+ cropped from the original image.
+ pad_value (tuple(int)): The color value for padding rotated image.
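+
+ A config sketch using the defaults above:
+
+ .. code-block::
+
+     dict(
+         type='RandomRotatePolyInstances',
+         rotate_ratio=0.5,
+         max_angle=10,
+         pad_with_fixed_color=False)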
+ """
+ self.rotate_ratio = rotate_ratio
+ self.max_angle = max_angle
+ self.pad_with_fixed_color = pad_with_fixed_color
+ self.pad_value = pad_value
+
+ def rotate(self, center, points, theta, center_shift=(0, 0)):
+ # rotate points.
+ (center_x, center_y) = center
+ center_y = -center_y
+ x, y = points[::2], points[1::2]
+ y = -y
+
+ theta = theta / 180 * math.pi
+ cos = math.cos(theta)
+ sin = math.sin(theta)
+
+ x = (x - center_x)
+ y = (y - center_y)
+
+ _x = center_x + x * cos - y * sin + center_shift[0]
+ _y = -(center_y + x * sin + y * cos) + center_shift[1]
+
+ points[::2], points[1::2] = _x, _y
+ return points
+
+ def cal_canvas_size(self, ori_size, degree):
+ assert isinstance(ori_size, tuple)
+ angle = degree * math.pi / 180.0
+ h, w = ori_size[:2]
+
+ cos = math.cos(angle)
+ sin = math.sin(angle)
+ canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos))
+ canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin))
+
+ canvas_size = (canvas_h, canvas_w)
+ return canvas_size
+
+ def sample_angle(self, max_angle):
+ angle = np.random.random_sample() * 2 * max_angle - max_angle
+ return angle
+
+ def rotate_img(self, img, angle, canvas_size):
+ h, w = img.shape[:2]
+ rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
+ rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2)
+ rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2)
+
+ if self.pad_with_fixed_color:
+ target_img = cv2.warpAffine(
+ img,
+ rotation_matrix, (canvas_size[1], canvas_size[0]),
+ flags=cv2.INTER_NEAREST,
+ borderValue=self.pad_value)
+ else:
+ mask = np.zeros_like(img)
+ (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
+ np.random.randint(0, w * 7 // 8))
+ img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
+ img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0]))
+ mask = cv2.warpAffine(
+ mask,
+ rotation_matrix, (canvas_size[1], canvas_size[0]),
+ borderValue=[1, 1, 1])
+ target_img = cv2.warpAffine(
+ img,
+ rotation_matrix, (canvas_size[1], canvas_size[0]),
+ borderValue=[0, 0, 0])
+ target_img = target_img + img_cut * mask
+
+ return target_img
+
+ def __call__(self, results):
+ if np.random.random_sample() < self.rotate_ratio:
+ img = results['img']
+ h, w = img.shape[:2]
+ angle = self.sample_angle(self.max_angle)
+ canvas_size = self.cal_canvas_size((h, w), angle)
+ center_shift = (int(
+ (canvas_size[1] - w) / 2), int((canvas_size[0] - h) / 2))
+
+ # rotate image
+ results['rotated_poly_angle'] = angle
+ img = self.rotate_img(img, angle, canvas_size)
+ results['img'] = img
+ img_shape = img.shape
+ results['img_shape'] = img_shape
+
+ # rotate polygons
+ for key in results.get('mask_fields', []):
+ if len(results[key].masks) == 0:
+ continue
+ masks = results[key].masks
+ rotated_masks = []
+ for mask in masks:
+ rotated_mask = self.rotate((w / 2, h / 2), mask[0], angle,
+ center_shift)
+ rotated_masks.append([rotated_mask])
+
+ results[key] = PolygonMasks(rotated_masks, *(img_shape[:2]))
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+@PIPELINES.register_module()
+class SquareResizePad:
+
+ def __init__(self,
+ target_size,
+ pad_ratio=0.6,
+ pad_with_fixed_color=False,
+ pad_value=(0, 0, 0)):
+ """Resize or pad images to be square shape.
+
+ Args:
+ target_size (int): The target size of the square shaped image.
+ pad_ratio (float): The probability of resizing with the aspect ratio
+ kept and then padding to a square, instead of resizing directly
+ to the square target size.
+ pad_with_fixed_color (bool): Whether to pad the resized image with
+ a fixed color. If False, the padding is filled with a patch
+ cropped from the original image.
+ pad_value (tuple(int)): The color value for padding.
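+
+ A config sketch; ``target_size=800`` below is an illustrative value:
+
+ .. code-block::
+
+     dict(
+         type='SquareResizePad',
+         target_size=800,
+         pad_ratio=0.6)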
+ """
+ assert isinstance(target_size, int)
+ assert isinstance(pad_ratio, float)
+ assert isinstance(pad_with_fixed_color, bool)
+ assert isinstance(pad_value, tuple)
+
+ self.target_size = target_size
+ self.pad_ratio = pad_ratio
+ self.pad_with_fixed_color = pad_with_fixed_color
+ self.pad_value = pad_value
+
+ def resize_img(self, img, keep_ratio=True):
+ h, w, _ = img.shape
+ if keep_ratio:
+ t_h = self.target_size if h >= w else int(h * self.target_size / w)
+ t_w = self.target_size if h <= w else int(w * self.target_size / h)
+ else:
+ t_h = t_w = self.target_size
+ img = cv2.resize(img, (t_w, t_h))
+ return img, (t_h, t_w)
+
+ def square_pad(self, img):
+ h, w = img.shape[:2]
+ if h == w:
+ return img, (0, 0)
+ pad_size = max(h, w)
+ if self.pad_with_fixed_color:
+ expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8)
+ expand_img[:] = self.pad_value
+ else:
+ (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8),
+ np.random.randint(0, w * 7 // 8))
+ img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)]
+ expand_img = cv2.resize(img_cut, (pad_size, pad_size))
+ if h > w:
+ y0, x0 = 0, (h - w) // 2
+ else:
+ y0, x0 = (w - h) // 2, 0
+ expand_img[y0:y0 + h, x0:x0 + w] = img
+ offset = (x0, y0)
+
+ return expand_img, offset
+
+ def square_pad_mask(self, points, offset):
+ x0, y0 = offset
+ pad_points = points.copy()
+ pad_points[::2] = pad_points[::2] + x0
+ pad_points[1::2] = pad_points[1::2] + y0
+ return pad_points
+
+ def __call__(self, results):
+ img = results['img']
+
+ if np.random.random_sample() < self.pad_ratio:
+ img, out_size = self.resize_img(img, keep_ratio=True)
+ img, offset = self.square_pad(img)
+ else:
+ img, out_size = self.resize_img(img, keep_ratio=False)
+ offset = (0, 0)
+
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ for key in results.get('mask_fields', []):
+ if len(results[key].masks) == 0:
+ continue
+ results[key] = results[key].resize(out_size)
+ masks = results[key].masks
+ processed_masks = []
+ for mask in masks:
+ square_pad_mask = self.square_pad_mask(mask[0], offset)
+ processed_masks.append([square_pad_mask])
+
+ results[key] = PolygonMasks(processed_masks, *(img.shape[:2]))
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
diff --git a/mmocr/datasets/text_det_dataset.py b/mmocr/datasets/text_det_dataset.py
new file mode 100644
index 00000000..65a2d9d9
--- /dev/null
+++ b/mmocr/datasets/text_det_dataset.py
@@ -0,0 +1,121 @@
+import numpy as np
+
+from mmdet.datasets.builder import DATASETS
+from mmocr.core.evaluation.hmean import eval_hmean
+from mmocr.datasets.base_dataset import BaseDataset
+
+
+@DATASETS.register_module()
+class TextDetDataset(BaseDataset):
+
+ def _parse_anno_info(self, annotations):
+ """Parse bbox and mask annotation.
+ Args:
+ annotations (dict): Annotations of one image.
+
+ Returns:
+ dict: A dict containing the following keys: bboxes, bboxes_ignore,
+ labels, masks, masks_ignore. "masks" and
+ "masks_ignore" are represented by polygon boundary
+ point sequences.
+ """
+ gt_bboxes, gt_bboxes_ignore = [], []
+ gt_masks, gt_masks_ignore = [], []
+ gt_labels = []
+ for ann in annotations:
+ if ann.get('iscrowd', False):
+ gt_bboxes_ignore.append(ann['bbox'])
+ gt_masks_ignore.append(ann.get('segmentation', None))
+ else:
+ gt_bboxes.append(ann['bbox'])
+ gt_labels.append(ann['category_id'])
+ gt_masks.append(ann.get('segmentation', None))
+ if gt_bboxes:
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ else:
+ gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+ gt_labels = np.array([], dtype=np.int64)
+
+ if gt_bboxes_ignore:
+ gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+ else:
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks_ignore=gt_masks_ignore,
+ masks=gt_masks)
+
+ return ann
+
+ def prepare_train_img(self, index):
+ """Get training data and annotations from pipeline.
+
+ Args:
+ index (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys
+ introduced by pipeline.
+ """
+ img_ann_info = self.data_infos[index]
+ img_info = {
+ 'filename': img_ann_info['file_name'],
+ 'height': img_ann_info['height'],
+ 'width': img_ann_info['width']
+ }
+ ann_info = self._parse_anno_info(img_ann_info['annotations'])
+ results = dict(img_info=img_info, ann_info=ann_info)
+ results['bbox_fields'] = []
+ results['mask_fields'] = []
+ results['seg_fields'] = []
+ self.pre_pipeline(results)
+
+ return self.pipeline(results)
+
+ def evaluate(self,
+ results,
+ metric='hmean-iou',
+ score_thr=0.3,
+ rank_list=None,
+ logger=None,
+ **kwargs):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ score_thr (float): Score threshold for prediction map.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ rank_list (str): json file used to save eval result
+ of each image after ranking.
+ Returns:
+ dict[str: float]
+ """
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['hmean-iou', 'hmean-ic13']
+ metrics = set(metrics) & set(allowed_metrics)
+
+ img_infos = []
+ ann_infos = []
+ for i in range(len(self)):
+ img_ann_info = self.data_infos[i]
+ img_info = {'filename': img_ann_info['file_name']}
+ ann_info = self._parse_anno_info(img_ann_info['annotations'])
+ img_infos.append(img_info)
+ ann_infos.append(ann_info)
+
+ eval_results = eval_hmean(
+ results,
+ img_infos,
+ ann_infos,
+ metrics=metrics,
+ score_thr=score_thr,
+ logger=logger,
+ rank_list=rank_list)
+
+ return eval_results
diff --git a/mmocr/datasets/utils/__init__.py b/mmocr/datasets/utils/__init__.py
new file mode 100644
index 00000000..f014de7e
--- /dev/null
+++ b/mmocr/datasets/utils/__init__.py
@@ -0,0 +1,4 @@
+from .loader import HardDiskLoader, LmdbLoader
+from .parser import LineJsonParser, LineStrParser
+
+__all__ = ['HardDiskLoader', 'LmdbLoader', 'LineStrParser', 'LineJsonParser']
diff --git a/mmocr/datasets/utils/loader.py b/mmocr/datasets/utils/loader.py
new file mode 100644
index 00000000..55a4d075
--- /dev/null
+++ b/mmocr/datasets/utils/loader.py
@@ -0,0 +1,108 @@
+import os.path as osp
+
+from mmocr.datasets.builder import LOADERS, build_parser
+
+
+@LOADERS.register_module()
+class Loader:
+ """Load annotation from annotation file, and parse instance information to
+ dict format with parser.
+
+ Args:
+ ann_file (str): Annotation file path.
+ parser (dict): Dictionary to construct parser
+ to parse original annotation infos.
+ repeat (int): Repeated times of annotations.
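+
+ A usage sketch with a hypothetical annotation file; the concrete
+ subclass and parser below are the ones defined in this PR:
+
+ .. code-block::
+
+     parser = dict(
+         type='LineStrParser',
+         keys=['filename', 'text'],
+         keys_idx=[0, 1],
+         separator=' ')
+     loader = HardDiskLoader('train_label.txt', parser=parser, repeat=1)
+     anno = loader[0]  # e.g. {'filename': 'img_1.jpg', 'text': 'hello'}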
+ """
+
+ def __init__(self, ann_file, parser, repeat=1):
+ assert isinstance(ann_file, str)
+ assert isinstance(repeat, int)
+ assert isinstance(parser, dict)
+ assert repeat > 0
+ assert osp.exists(ann_file), f'{ann_file} does not exist'
+
+ self.ori_data_infos = self._load(ann_file)
+ self.parser = build_parser(parser)
+ self.repeat = repeat
+
+ def __len__(self):
+ return len(self.ori_data_infos) * self.repeat
+
+ def _load(self, ann_file):
+ """Load annotation file."""
+ raise NotImplementedError
+
+ def __getitem__(self, index):
+ """Retrieve anno info of one instance with dict format."""
+ return self.parser.get_item(self.ori_data_infos, index)
+
+
+@LOADERS.register_module()
+class HardDiskLoader(Loader):
+ """Load annotation file from hard disk to RAM.
+
+ Args:
+ ann_file (str): Annotation file path.
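+
+ The annotation file is read into memory line by line; each line
+ typically holds whitespace-separated fields that a parser such as
+ ``LineStrParser`` turns into a dict, e.g. (illustrative):
+
+ .. code-block::
+
+     img_1.jpg hello
+     img_2.jpg world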
+ """
+
+ def _load(self, ann_file):
+ data_ret = []
+ with open(ann_file, 'r', encoding='utf-8') as f:
+ for line in f:
+ line = line.strip()
+ data_ret.append(line)
+
+ return data_ret
+
+
+@LOADERS.register_module()
+class LmdbLoader(Loader):
+ """Load annotation file with lmdb storage backend."""
+
+ def _load(self, ann_file):
+ lmdb_anno_obj = LmdbAnnFileBackend(ann_file)
+
+ return lmdb_anno_obj
+
+
+class LmdbAnnFileBackend:
+ """Lmdb storage backend for annotation file.
+
+ Args:
+ lmdb_path (str): Lmdb file path.
+ """
+
+ def __init__(self, lmdb_path, coding='utf8'):
+ self.lmdb_path = lmdb_path
+ self.coding = coding
+ env = self._get_env()
+ with env.begin(write=False) as txn:
+ self.total_number = int(
+ txn.get('total_number'.encode(self.coding)).decode(
+ self.coding))
+
+ def __getitem__(self, index):
+ """Retrieval one line from lmdb file by index."""
+ # only attach env to self when __getitem__ is called
+ # because env object cannot be pickle
+ if not hasattr(self, 'env'):
+ self.env = self._get_env()
+
+ with self.env.begin(write=False) as txn:
+ line = txn.get(str(index).encode(self.coding)).decode(self.coding)
+ return line
+
+ def __len__(self):
+ return self.total_number
+
+ def _get_env(self):
+ import lmdb
+ return lmdb.open(
+ self.lmdb_path,
+ max_readers=1,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False,
+ )
diff --git a/mmocr/datasets/utils/parser.py b/mmocr/datasets/utils/parser.py
new file mode 100644
index 00000000..a895e217
--- /dev/null
+++ b/mmocr/datasets/utils/parser.py
@@ -0,0 +1,69 @@
+import json
+
+from mmocr.datasets.builder import PARSERS
+
+
+@PARSERS.register_module()
+class LineStrParser:
+ """Parse string of one line in annotation file to dict format.
+
+ Args:
+ keys (list[str]): Keys in result dict.
+ keys_idx (list[int]): Value index in sub-string list
+ for each key above.
+ separator (str): Separator to separate string to list of sub-string.
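+
+ A usage sketch with the default arguments; the annotation line below
+ is illustrative:
+
+ .. code-block::
+
+     parser = LineStrParser(keys=['filename', 'text'],
+                            keys_idx=[0, 1],
+                            separator=' ')
+     parser.get_item(['img_1.jpg hello'], 0)
+     # -> {'filename': 'img_1.jpg', 'text': 'hello'}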
+ """
+
+ def __init__(self,
+ keys=['filename', 'text'],
+ keys_idx=[0, 1],
+ separator=' '):
+ assert isinstance(keys, list)
+ assert isinstance(keys_idx, list)
+ assert isinstance(separator, str)
+ assert len(keys) > 0
+ assert len(keys) == len(keys_idx)
+ self.keys = keys
+ self.keys_idx = keys_idx
+ self.separator = separator
+
+ def get_item(self, data_ret, index):
+ map_index = index % len(data_ret)
+ line_str = data_ret[map_index]
+ for split_key in self.separator:
+ if split_key != ' ':
+ line_str = line_str.replace(split_key, ' ')
+ line_str = line_str.split()
+ if len(line_str) <= max(self.keys_idx):
+ raise Exception(
+ f'key index: {max(self.keys_idx)} out of range: {line_str}')
+
+ line_info = {}
+ for i, key in enumerate(self.keys):
+ line_info[key] = line_str[self.keys_idx[i]]
+ return line_info
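+
+# A worked example with hypothetical data: with the default keys, keys_idx
+# and separator ' ', the line 'img_1.jpg hello' is split into
+# ['img_1.jpg', 'hello'] and parsed to
+# {'filename': 'img_1.jpg', 'text': 'hello'}. Any non-space character in
+# `separator` is first replaced by a space before splitting, so e.g.
+# separator='\t' also handles tab-delimited files.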
+
+
+@PARSERS.register_module()
+class LineJsonParser:
+ """Parse json-string of one line in annotation file to dict format.
+
+ Args:
+ keys (list[str]): Keys in both json-string and result dict.
+ """
+
+ def __init__(self, keys=[], **kwargs):
+ assert isinstance(keys, list)
+ assert len(keys) > 0
+ self.keys = keys
+
+ def get_item(self, data_ret, index):
+ map_index = index % len(data_ret)
+ line_json_obj = json.loads(data_ret[map_index])
+ line_info = {}
+ for key in self.keys:
+ if key not in line_json_obj:
+ raise Exception(f'key {key} not in line json {line_json_obj}')
+ line_info[key] = line_json_obj[key]
+
+ return line_info
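+
+# A worked example with hypothetical data: with keys=['filename', 'text'],
+# the line '{"filename": "img_1.jpg", "text": "hello"}' is parsed to
+# {'filename': 'img_1.jpg', 'text': 'hello'}; a line missing any requested
+# key raises an exception instead of being silently skipped.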
diff --git a/mmocr/models/__init__.py b/mmocr/models/__init__.py
new file mode 100644
index 00000000..529a6208
--- /dev/null
+++ b/mmocr/models/__init__.py
@@ -0,0 +1,16 @@
+from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
+ build_backbone, build_detector, build_loss)
+from .builder import (CONVERTORS, DECODERS, ENCODERS, PREPROCESSOR,
+ build_convertor, build_decoder, build_encoder,
+ build_preprocessor)
+from .common import * # noqa: F401,F403
+from .kie import * # noqa: F401,F403
+from .textdet import * # noqa: F401,F403
+from .textrecog import * # noqa: F401,F403
+
+__all__ = [
+ 'BACKBONES', 'DETECTORS', 'HEADS', 'LOSSES', 'NECKS', 'build_backbone',
+ 'build_detector', 'build_loss', 'CONVERTORS', 'ENCODERS', 'DECODERS',
+ 'PREPROCESSOR', 'build_convertor', 'build_encoder', 'build_decoder',
+ 'build_preprocessor'
+]
diff --git a/mmocr/models/builder.py b/mmocr/models/builder.py
new file mode 100644
index 00000000..cd606a8a
--- /dev/null
+++ b/mmocr/models/builder.py
@@ -0,0 +1,33 @@
+from mmcv.utils import Registry, build_from_cfg
+
+RECOGNIZERS = Registry('recognizer')
+CONVERTORS = Registry('convertor')
+ENCODERS = Registry('encoder')
+DECODERS = Registry('decoder')
+PREPROCESSOR = Registry('preprocessor')
+
+
+def build_recognizer(cfg, train_cfg=None, test_cfg=None):
+ """Build recognizer."""
+ return build_from_cfg(cfg, RECOGNIZERS,
+ dict(train_cfg=train_cfg, test_cfg=test_cfg))
+
+
+def build_convertor(cfg):
+ """Build label convertor for scene text recognizer."""
+ return build_from_cfg(cfg, CONVERTORS)
+
+
+def build_encoder(cfg):
+ """Build encoder for scene text recognizer."""
+ return build_from_cfg(cfg, ENCODERS)
+
+
+def build_decoder(cfg):
+ """Build decoder for scene text recognizer."""
+ return build_from_cfg(cfg, DECODERS)
+
+
+def build_preprocessor(cfg):
+ """Build preprocessor for scene text recognizer."""
+ return build_from_cfg(cfg, PREPROCESSOR)
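+
+# A minimal sketch of how these registries are used (the config values are
+# illustrative assumptions, not defaults): a component is described by a dict
+# whose 'type' field names a registered class, e.g.
+#
+#   convertor = build_convertor(
+#       dict(type='AttnConvertor', dict_type='DICT90', max_seq_len=40))
+#
+# which is roughly equivalent to AttnConvertor(dict_type='DICT90',
+# max_seq_len=40), looked up through the CONVERTORS registry.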
diff --git a/mmocr/models/common/__init__.py b/mmocr/models/common/__init__.py
new file mode 100644
index 00000000..8d662575
--- /dev/null
+++ b/mmocr/models/common/__init__.py
@@ -0,0 +1,2 @@
+from .backbones import * # noqa: F401,F403
+from .losses import * # noqa: F401,F403
diff --git a/mmocr/models/common/backbones/__init__.py b/mmocr/models/common/backbones/__init__.py
new file mode 100644
index 00000000..de67ca96
--- /dev/null
+++ b/mmocr/models/common/backbones/__init__.py
@@ -0,0 +1,3 @@
+from .unet import UNet
+
+__all__ = ['UNet']
diff --git a/mmocr/models/common/backbones/unet.py b/mmocr/models/common/backbones/unet.py
new file mode 100644
index 00000000..cc11a8f8
--- /dev/null
+++ b/mmocr/models/common/backbones/unet.py
@@ -0,0 +1,529 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
+ build_norm_layer, build_upsample_layer, constant_init,
+ kaiming_init)
+from mmcv.runner import load_checkpoint
+from mmcv.utils.parrots_wrapper import _BatchNorm
+
+from mmdet.models.builder import BACKBONES
+from mmdet.utils import get_root_logger
+
+
+class UpConvBlock(nn.Module):
+ """Upsample convolution block in decoder for UNet.
+
+ This upsample convolution block consists of one upsample module
+ followed by one convolution block. The upsample module expands the
+ high-level low-resolution feature map and the convolution block fuses
+ the upsampled high-level low-resolution feature map and the low-level
+ high-resolution feature map from encoder.
+
+ Args:
+ conv_block (nn.Sequential): Sequential of convolutional layers.
+ in_channels (int): Number of input channels of the high-level
+ low-resolution feature map from the decoder.
+ skip_channels (int): Number of input channels of the low-level
+ high-resolution feature map from the encoder.
+ out_channels (int): Number of output channels.
+ num_convs (int): Number of convolutional layers in the conv_block.
+ Default: 2.
+ stride (int): Stride of convolutional layer in conv_block. Default: 1.
+ dilation (int): Dilation rate of convolutional layer in conv_block.
+ Default: 1.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ upsample_cfg (dict): The upsample config of the upsample module in
+ decoder. Default: dict(type='InterpConv'). If the size of
+ high-level feature map is the same as that of skip feature map
+ (low-level feature map from encoder), the high-level feature map
+ does not need to be upsampled and upsample_cfg should be None.
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+ """
+
+ def __init__(self,
+ conv_block,
+ in_channels,
+ skip_channels,
+ out_channels,
+ num_convs=2,
+ stride=1,
+ dilation=1,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ dcn=None,
+ plugins=None):
+ super().__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.conv_block = conv_block(
+ in_channels=2 * skip_channels,
+ out_channels=out_channels,
+ num_convs=num_convs,
+ stride=stride,
+ dilation=dilation,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dcn=None,
+ plugins=None)
+ if upsample_cfg is not None:
+ self.upsample = build_upsample_layer(
+ cfg=upsample_cfg,
+ in_channels=in_channels,
+ out_channels=skip_channels,
+ with_cp=with_cp,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ else:
+ self.upsample = ConvModule(
+ in_channels,
+ skip_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ def forward(self, skip, x):
+ """Forward function."""
+
+ x = self.upsample(x)
+ out = torch.cat([skip, x], dim=1)
+ out = self.conv_block(out)
+
+ return out
+
+
+class BasicConvBlock(nn.Module):
+ """Basic convolutional block for UNet.
+
+ This module consists of several plain convolutional layers.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ num_convs (int): Number of convolutional layers. Default: 2.
+ stride (int): Whether to use stride convolution to downsample the
+ input feature map. If stride=2, only the first convolutional
+ layer uses stride convolution to downsample the input feature
+ map. Options are 1 or 2. Default: 1.
+ dilation (int): Whether to use dilated convolution to expand the
+ receptive field. Sets the dilation rate of each convolutional
+ layer except the first one, whose dilation rate is always 1.
+ Default: 1.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_convs=2,
+ stride=1,
+ dilation=1,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ dcn=None,
+ plugins=None):
+ super().__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.with_cp = with_cp
+ convs = []
+ for i in range(num_convs):
+ convs.append(
+ ConvModule(
+ in_channels=in_channels if i == 0 else out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride if i == 0 else 1,
+ dilation=1 if i == 0 else dilation,
+ padding=1 if i == 0 else dilation,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ self.convs = nn.Sequential(*convs)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.convs, x)
+ else:
+ out = self.convs(x)
+ return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class DeconvModule(nn.Module):
+ """Deconvolution upsample module in decoder for UNet (2X upsample).
+
+ This module uses deconvolution to upsample feature map in the decoder
+ of UNet.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ kernel_size (int): Kernel size of the convolutional layer. Default: 4.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ with_cp=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ *,
+ kernel_size=4,
+ scale_factor=2):
+ super().__init__()
+
+ assert (
+ kernel_size - scale_factor >= 0
+ and (kernel_size - scale_factor) % 2 == 0), (
+ f'kernel_size should be greater than or equal to scale_factor '
+ f'and (kernel_size - scale_factor) should be even numbers, '
+ f'while the kernel size is {kernel_size} and scale_factor is '
+ f'{scale_factor}.')
+
+ stride = scale_factor
+ padding = (kernel_size - scale_factor) // 2
+ self.with_cp = with_cp
+ deconv = nn.ConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+
+ _, norm = build_norm_layer(norm_cfg, out_channels)
+ activate = build_activation_layer(act_cfg)
+ self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.deconv_upsamping, x)
+ else:
+ out = self.deconv_upsamping(x)
+ return out
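+
+# Shape note (a sketch under the defaults kernel_size=4, scale_factor=2): for
+# an H x W input, ConvTranspose2d with stride=2, padding=1 and kernel_size=4
+# outputs (H - 1) * 2 - 2 * 1 + 4 = 2 * H rows (and likewise for columns), so
+# the feature map is exactly doubled; the (kernel_size - scale_factor) parity
+# check above guarantees this exact scale_factor upsampling in general.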
+
+
+@UPSAMPLE_LAYERS.register_module()
+class InterpConv(nn.Module):
+ """Interpolation upsample module in decoder for UNet.
+
+ This module uses interpolation to upsample feature map in the decoder
+ of UNet. It consists of one interpolation upsample layer and one
+ convolutional layer. It can be one interpolation upsample layer followed
+ by one convolutional layer (conv_first=False) or one convolutional layer
+ followed by one interpolation upsample layer (conv_first=True).
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ conv_first (bool): Whether the convolutional layer comes before the
+ interpolation upsample layer. If False (default), the
+ interpolation upsample layer is followed by the convolutional
+ layer.
+ kernel_size (int): Kernel size of the convolutional layer. Default: 1.
+ stride (int): Stride of the convolutional layer. Default: 1.
+ padding (int): Padding of the convolutional layer. Default: 0.
+ upsample_cfg (dict): Interpolation config of the upsample layer.
+ Default: dict(
+ scale_factor=2, mode='bilinear', align_corners=False).
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ with_cp=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ *,
+ conv_cfg=None,
+ conv_first=False,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ upsample_cfg=dict(
+ scale_factor=2, mode='bilinear', align_corners=False)):
+ super().__init__()
+
+ self.with_cp = with_cp
+ conv = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ upsample = nn.Upsample(**upsample_cfg)
+ if conv_first:
+ self.interp_upsample = nn.Sequential(conv, upsample)
+ else:
+ self.interp_upsample = nn.Sequential(upsample, conv)
+
+ def forward(self, x):
+ """Forward function."""
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(self.interp_upsample, x)
+ else:
+ out = self.interp_upsample(x)
+ return out
+
+
+@BACKBONES.register_module()
+class UNet(nn.Module):
+ """UNet backbone.
+ U-Net: Convolutional Networks for Biomedical Image Segmentation.
+ https://arxiv.org/pdf/1505.04597.pdf
+
+ Args:
+ in_channels (int): Number of input image channels. Default: 3.
+ base_channels (int): Number of base channels of each stage, i.e. the
+ number of output channels of the first stage. Default: 64.
+ num_stages (int): Number of stages in encoder, normally 5. Default: 5.
+ strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
+ len(strides) is equal to num_stages. Normally the stride of the
+ first stage in encoder is 1. If strides[i]=2, it uses stride
+ convolution to downsample in the correspondence encoder stage.
+ Default: (1, 1, 1, 1, 1).
+ enc_num_convs (Sequence[int]): Number of convolutional layers in the
+ convolution block of the corresponding encoder stage.
+ Default: (2, 2, 2, 2, 2).
+ dec_num_convs (Sequence[int]): Number of convolutional layers in the
+ convolution block of the corresponding decoder stage.
+ Default: (2, 2, 2, 2).
+ downsamples (Sequence[int]): Whether to use MaxPool to downsample the
+ feature map after the first stage of the encoder
+ (stages: [1, num_stages)). If the corresponding encoder stage uses
+ stride convolution (strides[i]=2), it will never use MaxPool to
+ downsample, even if downsamples[i-1] is True.
+ Default: (True, True, True, True).
+ enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
+ Default: (1, 1, 1, 1, 1).
+ dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
+ Default: (1, 1, 1, 1).
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ conv_cfg (dict | None): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict | None): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict | None): Config dict for activation layer in ConvModule.
+ Default: dict(type='ReLU').
+ upsample_cfg (dict): The upsample config of the upsample module in
+ decoder. Default: dict(type='InterpConv').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ dcn (bool): Use deformable convolution in convolutional layer or not.
+ Default: None.
+ plugins (dict): plugins for convolutional layers. Default: None.
+
+ Notice:
+ The input image size should be divisible by the whole downsample rate
+ of the encoder. More detail of the whole downsample rate can be found
+ in UNet._check_input_divisible.
+
+ """
+
+ def __init__(self,
+ in_channels=3,
+ base_channels=64,
+ num_stages=5,
+ strides=(1, 1, 1, 1, 1),
+ enc_num_convs=(2, 2, 2, 2, 2),
+ dec_num_convs=(2, 2, 2, 2),
+ downsamples=(True, True, True, True),
+ enc_dilations=(1, 1, 1, 1, 1),
+ dec_dilations=(1, 1, 1, 1),
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ upsample_cfg=dict(type='InterpConv'),
+ norm_eval=False,
+ dcn=None,
+ plugins=None):
+ super().__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+ assert len(strides) == num_stages, (
+ 'The length of strides should be equal to num_stages, '
+ f'while the strides is {strides}, the length of '
+ f'strides is {len(strides)}, and the num_stages is '
+ f'{num_stages}.')
+ assert len(enc_num_convs) == num_stages, (
+ 'The length of enc_num_convs should be equal to num_stages, '
+ f'while the enc_num_convs is {enc_num_convs}, the length of '
+ f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '
+ f'{num_stages}.')
+ assert len(dec_num_convs) == (num_stages - 1), (
+ 'The length of dec_num_convs should be equal to (num_stages-1), '
+ f'while the dec_num_convs is {dec_num_convs}, the length of '
+ f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '
+ f'{num_stages}.')
+ assert len(downsamples) == (num_stages - 1), (
+ 'The length of downsamples should be equal to (num_stages-1), '
+ f'while the downsamples is {downsamples}, the length of '
+ f'downsamples is {len(downsamples)}, and the num_stages is '
+ f'{num_stages}.')
+ assert len(enc_dilations) == num_stages, (
+ 'The length of enc_dilations should be equal to num_stages, '
+ f'while the enc_dilations is {enc_dilations}, the length of '
+ f'enc_dilations is {len(enc_dilations)}, and the num_stages is '
+ f'{num_stages}.')
+ assert len(dec_dilations) == (num_stages - 1), (
+ 'The length of dec_dilations should be equal to (num_stages-1), '
+ f'while the dec_dilations is {dec_dilations}, the length of '
+ f'dec_dilations is {len(dec_dilations)}, and the num_stages is '
+ f'{num_stages}.')
+ self.num_stages = num_stages
+ self.strides = strides
+ self.downsamples = downsamples
+ self.norm_eval = norm_eval
+ self.base_channels = base_channels
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ for i in range(num_stages):
+ enc_conv_block = []
+ if i != 0:
+ if strides[i] == 1 and downsamples[i - 1]:
+ enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
+ upsample = (strides[i] != 1 or downsamples[i - 1])
+ self.decoder.append(
+ UpConvBlock(
+ conv_block=BasicConvBlock,
+ in_channels=base_channels * 2**i,
+ skip_channels=base_channels * 2**(i - 1),
+ out_channels=base_channels * 2**(i - 1),
+ num_convs=dec_num_convs[i - 1],
+ stride=1,
+ dilation=dec_dilations[i - 1],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ upsample_cfg=upsample_cfg if upsample else None,
+ dcn=None,
+ plugins=None))
+
+ enc_conv_block.append(
+ BasicConvBlock(
+ in_channels=in_channels,
+ out_channels=base_channels * 2**i,
+ num_convs=enc_num_convs[i],
+ stride=strides[i],
+ dilation=enc_dilations[i],
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ dcn=None,
+ plugins=None))
+ self.encoder.append((nn.Sequential(*enc_conv_block)))
+ in_channels = base_channels * 2**i
+
+ def forward(self, x):
+ self._check_input_divisible(x)
+ enc_outs = []
+ for enc in self.encoder:
+ x = enc(x)
+ enc_outs.append(x)
+ dec_outs = [x]
+ for i in reversed(range(len(self.decoder))):
+ x = self.decoder[i](enc_outs[i], x)
+ dec_outs.append(x)
+
+ return dec_outs
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep normalization layer
+ freezed."""
+ super().train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval() only has an effect on BatchNorm layers
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ def _check_input_divisible(self, x):
+ h, w = x.shape[-2:]
+ whole_downsample_rate = 1
+ for i in range(1, self.num_stages):
+ if self.strides[i] == 2 or self.downsamples[i - 1]:
+ whole_downsample_rate *= 2
+ assert (
+ h % whole_downsample_rate == 0 and w % whole_downsample_rate == 0
+ ), (f'The input image size {(h, w)} should be divisible by the whole '
+ f'downsample rate {whole_downsample_rate}, when num_stages is '
+ f'{self.num_stages}, strides is {self.strides}, and downsamples '
+ f'is {self.downsamples}.')
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
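+
+# A small usage sketch (shapes under the default configuration with a
+# 256 x 256 input; illustrative, not a test):
+#
+#   model = UNet(in_channels=3, base_channels=64, num_stages=5)
+#   outs = model(torch.rand(1, 3, 256, 256))
+#   # outs is a list of 5 decoder outputs, from coarse to fine:
+#   # [1, 1024, 16, 16], [1, 512, 32, 32], [1, 256, 64, 64],
+#   # [1, 128, 128, 128], [1, 64, 256, 256]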
diff --git a/mmocr/models/common/losses/__init__.py b/mmocr/models/common/losses/__init__.py
new file mode 100644
index 00000000..7daa7345
--- /dev/null
+++ b/mmocr/models/common/losses/__init__.py
@@ -0,0 +1,3 @@
+from .dice_loss import DiceLoss
+
+__all__ = ['DiceLoss']
diff --git a/mmocr/models/common/losses/dice_loss.py b/mmocr/models/common/losses/dice_loss.py
new file mode 100644
index 00000000..0dfdf009
--- /dev/null
+++ b/mmocr/models/common/losses/dice_loss.py
@@ -0,0 +1,30 @@
+import torch
+import torch.nn as nn
+
+from mmdet.models.builder import LOSSES
+
+
+@LOSSES.register_module()
+class DiceLoss(nn.Module):
+
+ def __init__(self, eps=1e-6):
+ super().__init__()
+ assert isinstance(eps, float)
+ self.eps = eps
+
+ def forward(self, pred, target, mask=None):
+
+ pred = pred.contiguous().view(pred.size()[0], -1)
+ target = target.contiguous().view(target.size()[0], -1)
+
+ if mask is not None:
+ mask = mask.contiguous().view(mask.size()[0], -1)
+ pred = pred * mask
+ target = target * mask
+
+ a = torch.sum(pred * target)
+ b = torch.sum(pred)
+ c = torch.sum(target)
+ d = (2 * a) / (b + c + self.eps)
+
+ return 1 - d
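+
+# A quick sanity check of the formula: the loss equals
+# 1 - 2 * sum(pred * target) / (sum(pred) + sum(target) + eps), so for
+# pred == target == 1 over N pixels it is 1 - 2N / (2N + eps), i.e. close to
+# 0, while for completely non-overlapping pred and target it is exactly 1.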
diff --git a/mmocr/models/kie/__init__.py b/mmocr/models/kie/__init__.py
new file mode 100644
index 00000000..46d98163
--- /dev/null
+++ b/mmocr/models/kie/__init__.py
@@ -0,0 +1,3 @@
+from .extractors import * # noqa: F401,F403
+from .heads import * # noqa: F401,F403
+from .losses import * # noqa: F401,F403
diff --git a/mmocr/models/kie/extractors/__init__.py b/mmocr/models/kie/extractors/__init__.py
new file mode 100644
index 00000000..f58541e6
--- /dev/null
+++ b/mmocr/models/kie/extractors/__init__.py
@@ -0,0 +1,3 @@
+from .sdmgr import SDMGR
+
+__all__ = ['SDMGR']
diff --git a/mmocr/models/kie/extractors/sdmgr.py b/mmocr/models/kie/extractors/sdmgr.py
new file mode 100644
index 00000000..c8275a28
--- /dev/null
+++ b/mmocr/models/kie/extractors/sdmgr.py
@@ -0,0 +1,154 @@
+import warnings
+
+import mmcv
+from torch import nn
+from torch.nn import functional as F
+
+from mmdet.core import bbox2roi
+from mmdet.models.builder import DETECTORS, build_roi_extractor
+from mmdet.models.detectors import SingleStageDetector
+from mmocr.core import imshow_edge_node
+
+
+@DETECTORS.register_module()
+class SDMGR(SingleStageDetector):
+ """The implementation of the paper: Spatial Dual-Modality Graph Reasoning
+ for Key Information Extraction. https://arxiv.org/abs/2103.14470.
+
+ Args:
+ visual_modality (bool): Whether to use the visual modality.
+ class_list (None | str): Mapping file from class index to
+ class name. If None, the class index is shown in
+ `show_result`; otherwise, the class name is shown.
+ """
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ bbox_head=None,
+ extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(type='RoIAlign', output_size=7),
+ featmap_strides=[1]),
+ visual_modality=False,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ class_list=None):
+ super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
+ pretrained)
+ self.visual_modality = visual_modality
+ if visual_modality:
+ self.extractor = build_roi_extractor({
+ **extractor, 'out_channels':
+ self.backbone.base_channels
+ })
+ self.maxpool = nn.MaxPool2d(extractor['roi_layer']['output_size'])
+ else:
+ self.extractor = None
+ self.class_list = class_list
+
+ def forward_train(self, img, img_metas, relations, texts, gt_bboxes,
+ gt_labels):
+ """
+ Args:
+ img (tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A list of image info dict where each dict
+ contains: 'img_shape', 'scale_factor', 'flip', and may also
+ contain 'filename', 'ori_shape', 'pad_shape', and
+ 'img_norm_cfg'. For details of the values of these keys,
+ please see :class:`mmdet.datasets.pipelines.Collect`.
+ relations (list[tensor]): Relations between bboxes.
+ texts (list[tensor]): Texts in bboxes.
+ gt_bboxes (list[tensor]): Each item contains the ground-truth
+ boxes of one image in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[tensor]): Class indices corresponding to each box.
+
+ Returns:
+ dict[str, tensor]: A dictionary of loss components.
+ """
+ x = self.extract_feat(img, gt_bboxes)
+ node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
+ return self.bbox_head.loss(node_preds, edge_preds, gt_labels)
+
+ def forward_test(self,
+ img,
+ img_metas,
+ relations,
+ texts,
+ gt_bboxes,
+ rescale=False):
+ x = self.extract_feat(img, gt_bboxes)
+ node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
+ return [
+ dict(
+ img_metas=img_metas,
+ nodes=F.softmax(node_preds, -1),
+ edges=F.softmax(edge_preds, -1))
+ ]
+
+ def extract_feat(self, img, gt_bboxes):
+ if self.visual_modality:
+ x = super().extract_feat(img)[-1]
+ feats = self.maxpool(self.extractor([x], bbox2roi(gt_bboxes)))
+ return feats.view(feats.size(0), -1)
+ return None
+
+ def show_result(self,
+ img,
+ result,
+ boxes,
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None,
+ **kwargs):
+ """Draw `result` on `img`.
+
+ Args:
+ img (str or tensor): The image to be displayed.
+ result (dict): The results to draw on `img`.
+ boxes (list): Bbox of img.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ Default: 0.
+ show (bool): Whether to show the image.
+ Default: False.
+ out_file (str or None): The output filename.
+ Default: None.
+
+ Returns:
+ img (tensor): The image with drawn results, returned only if
+ neither `show` nor `out_file` is specified.
+ """
+ img = mmcv.imread(img)
+ img = img.copy()
+
+ idx_to_cls = {}
+ if self.class_list is not None:
+ with open(self.class_list, 'r') as fr:
+ for line in fr:
+ line = line.strip().split()
+ class_idx, class_label = line
+ idx_to_cls[class_idx] = class_label
+
+ # if out_file specified, do not show image in window
+ if out_file is not None:
+ show = False
+
+ img = imshow_edge_node(
+ img,
+ result,
+ boxes,
+ idx_to_cls=idx_to_cls,
+ show=show,
+ win_name=win_name,
+ wait_time=wait_time,
+ out_file=out_file)
+
+ if not (show or out_file):
+ warnings.warn('show==False and out_file is not specified, only '
+ 'result image will be returned')
+ return img
+
+ return img
diff --git a/mmocr/models/kie/heads/__init__.py b/mmocr/models/kie/heads/__init__.py
new file mode 100644
index 00000000..00a11469
--- /dev/null
+++ b/mmocr/models/kie/heads/__init__.py
@@ -0,0 +1,3 @@
+from .sdmgr_head import SDMGRHead
+
+__all__ = ['SDMGRHead']
diff --git a/mmocr/models/kie/heads/sdmgr_head.py b/mmocr/models/kie/heads/sdmgr_head.py
new file mode 100644
index 00000000..3fc0ba36
--- /dev/null
+++ b/mmocr/models/kie/heads/sdmgr_head.py
@@ -0,0 +1,193 @@
+import torch
+from mmcv.cnn import normal_init
+from torch import nn
+from torch.nn import functional as F
+
+from mmdet.models.builder import HEADS, build_loss
+
+
+@HEADS.register_module()
+class SDMGRHead(nn.Module):
+
+ def __init__(self,
+ num_chars=92,
+ visual_dim=64,
+ fusion_dim=1024,
+ node_input=32,
+ node_embed=256,
+ edge_input=5,
+ edge_embed=256,
+ num_gnn=2,
+ num_classes=26,
+ loss=dict(type='SDMGRLoss'),
+ bidirectional=False,
+ train_cfg=None,
+ test_cfg=None):
+ super().__init__()
+
+ self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
+ self.node_embed = nn.Embedding(num_chars, node_input, 0)
+ hidden = node_embed // 2 if bidirectional else node_embed
+ self.rnn = nn.LSTM(
+ input_size=node_input,
+ hidden_size=hidden,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=bidirectional)
+ self.edge_embed = nn.Linear(edge_input, edge_embed)
+ self.gnn_layers = nn.ModuleList(
+ [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
+ self.node_cls = nn.Linear(node_embed, num_classes)
+ self.edge_cls = nn.Linear(edge_embed, 2)
+ self.loss = build_loss(loss)
+
+ def init_weights(self, pretrained=False):
+ normal_init(self.edge_embed, mean=0, std=0.01)
+
+ def forward(self, relations, texts, x=None):
+ node_nums, char_nums = [], []
+ for text in texts:
+ node_nums.append(text.size(0))
+ char_nums.append((text > 0).sum(-1))
+
+ max_num = max([char_num.max() for char_num in char_nums])
+ all_nodes = torch.cat([
+ torch.cat(
+ [text,
+ text.new_zeros(text.size(0), max_num - text.size(1))], -1)
+ for text in texts
+ ])
+ embed_nodes = self.node_embed(all_nodes.clamp(min=0).long())
+ rnn_nodes, _ = self.rnn(embed_nodes)
+
+ nodes = rnn_nodes.new_zeros(*rnn_nodes.shape[::2])
+ all_nums = torch.cat(char_nums)
+ valid = all_nums > 0
+ nodes[valid] = rnn_nodes[valid].gather(
+ 1, (all_nums[valid] - 1).unsqueeze(-1).unsqueeze(-1).expand(
+ -1, -1, rnn_nodes.size(-1))).squeeze(1)
+
+ if x is not None:
+ nodes = self.fusion([x, nodes])
+
+ all_edges = torch.cat(
+ [rel.view(-1, rel.size(-1)) for rel in relations])
+ embed_edges = self.edge_embed(all_edges.float())
+ embed_edges = F.normalize(embed_edges)
+
+ for gnn_layer in self.gnn_layers:
+ nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)
+
+ node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
+ return node_cls, edge_cls
+
+
+class GNNLayer(nn.Module):
+
+ def __init__(self, node_dim=256, edge_dim=256):
+ super().__init__()
+ self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
+ self.coef_fc = nn.Linear(node_dim, 1)
+ self.out_fc = nn.Linear(node_dim, node_dim)
+ self.relu = nn.ReLU()
+
+ def forward(self, nodes, edges, nums):
+ start, cat_nodes = 0, []
+ for num in nums:
+ sample_nodes = nodes[start:start + num]
+ cat_nodes.append(
+ torch.cat([
+ sample_nodes.unsqueeze(1).expand(-1, num, -1),
+ sample_nodes.unsqueeze(0).expand(num, -1, -1)
+ ], -1).view(num**2, -1))
+ start += num
+ cat_nodes = torch.cat([torch.cat(cat_nodes), edges], -1)
+ cat_nodes = self.relu(self.in_fc(cat_nodes))
+ coefs = self.coef_fc(cat_nodes)
+
+ start, residuals = 0, []
+ for num in nums:
+ residual = F.softmax(
+ -torch.eye(num).to(coefs.device).unsqueeze(-1) * 1e9 +
+ coefs[start:start + num**2].view(num, num, -1), 1)
+ residuals.append(
+ (residual *
+ cat_nodes[start:start + num**2].view(num, num, -1)).sum(1))
+ start += num**2
+
+ nodes += self.relu(self.out_fc(torch.cat(residuals)))
+ return nodes, cat_nodes
+
+
+class Block(nn.Module):
+
+ def __init__(self,
+ input_dims,
+ output_dim,
+ mm_dim=1600,
+ chunks=20,
+ rank=15,
+ shared=False,
+ dropout_input=0.,
+ dropout_pre_lin=0.,
+ dropout_output=0.,
+ pos_norm='before_cat'):
+ super().__init__()
+ self.rank = rank
+ self.dropout_input = dropout_input
+ self.dropout_pre_lin = dropout_pre_lin
+ self.dropout_output = dropout_output
+ assert (pos_norm in ['before_cat', 'after_cat'])
+ self.pos_norm = pos_norm
+ # Modules
+ self.linear0 = nn.Linear(input_dims[0], mm_dim)
+ self.linear1 = (
+ self.linear0 if shared else nn.Linear(input_dims[1], mm_dim))
+ self.merge_linears0 = nn.ModuleList()
+ self.merge_linears1 = nn.ModuleList()
+ self.chunks = self.chunk_sizes(mm_dim, chunks)
+ for size in self.chunks:
+ ml0 = nn.Linear(size, size * rank)
+ self.merge_linears0.append(ml0)
+ ml1 = ml0 if shared else nn.Linear(size, size * rank)
+ self.merge_linears1.append(ml1)
+ self.linear_out = nn.Linear(mm_dim, output_dim)
+
+ def forward(self, x):
+ x0 = self.linear0(x[0])
+ x1 = self.linear1(x[1])
+ bs = x1.size(0)
+ if self.dropout_input > 0:
+ x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
+ x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
+ x0_chunks = torch.split(x0, self.chunks, -1)
+ x1_chunks = torch.split(x1, self.chunks, -1)
+ zs = []
+ for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks,
+ self.merge_linears0,
+ self.merge_linears1):
+ m = m0(x0_c) * m1(x1_c) # bs x split_size*rank
+ m = m.view(bs, self.rank, -1)
+ z = torch.sum(m, 1)
+ if self.pos_norm == 'before_cat':
+ z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z))
+ z = F.normalize(z)
+ zs.append(z)
+ z = torch.cat(zs, 1)
+ if self.pos_norm == 'after_cat':
+ z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z))
+ z = F.normalize(z)
+
+ if self.dropout_pre_lin > 0:
+ z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
+ z = self.linear_out(z)
+ if self.dropout_output > 0:
+ z = F.dropout(z, p=self.dropout_output, training=self.training)
+ return z
+
+ @staticmethod
+ def chunk_sizes(dim, chunks):
+ split_size = (dim + chunks - 1) // chunks
+ sizes_list = [split_size] * chunks
+ sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
+ return sizes_list
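+
+# A worked example of the chunking used by the fusion block (hypothetical
+# numbers): chunk_sizes(1600, 20) returns [80] * 20, so each of the 20
+# 80-dim chunks of the mm_dim=1600 projections is merged with its own
+# rank-15 bilinear factors, and the per-chunk results are concatenated back
+# to 1600 dims before the final linear_out projection.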
diff --git a/mmocr/models/kie/losses/__init__.py b/mmocr/models/kie/losses/__init__.py
new file mode 100644
index 00000000..96b4afde
--- /dev/null
+++ b/mmocr/models/kie/losses/__init__.py
@@ -0,0 +1,3 @@
+from .sdmgr_loss import SDMGRLoss
+
+__all__ = ['SDMGRLoss']
diff --git a/mmocr/models/kie/losses/sdmgr_loss.py b/mmocr/models/kie/losses/sdmgr_loss.py
new file mode 100644
index 00000000..9e1d2312
--- /dev/null
+++ b/mmocr/models/kie/losses/sdmgr_loss.py
@@ -0,0 +1,39 @@
+import torch
+from torch import nn
+
+from mmdet.models.builder import LOSSES
+from mmdet.models.losses import accuracy
+
+
+@LOSSES.register_module()
+class SDMGRLoss(nn.Module):
+ """The implementation the loss of key information extraction proposed in
+ the paper: Spatial Dual-Modality Graph Reasoning for Key Information
+ Extraction.
+
+ https://arxiv.org/abs/2103.14470.
+ """
+
+ def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0):
+ super().__init__()
+ self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
+ self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
+ self.node_weight = node_weight
+ self.edge_weight = edge_weight
+ self.ignore = ignore
+
+ def forward(self, node_preds, edge_preds, gts):
+ node_gts, edge_gts = [], []
+ for gt in gts:
+ node_gts.append(gt[:, 0])
+ edge_gts.append(gt[:, 1:].contiguous().view(-1))
+ node_gts = torch.cat(node_gts).long()
+ edge_gts = torch.cat(edge_gts).long()
+
+ node_valids = torch.nonzero(node_gts != self.ignore).view(-1)
+ edge_valids = torch.nonzero(edge_gts != -1).view(-1)
+ return dict(
+ loss_node=self.node_weight * self.loss_node(node_preds, node_gts),
+ loss_edge=self.edge_weight * self.loss_edge(edge_preds, edge_gts),
+ acc_node=accuracy(node_preds[node_valids], node_gts[node_valids]),
+ acc_edge=accuracy(edge_preds[edge_valids], edge_gts[edge_valids]))
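+
+# A note on the ground-truth layout assumed above: each item of `gts` is
+# expected to be an N x (N + 1) tensor for an image with N text nodes, where
+# column 0 holds the node class indices and columns 1..N hold the edge
+# labels towards every other node. Nodes equal to `ignore` and edges equal
+# to -1 are excluded from both the loss (via ignore_index) and the accuracy.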
diff --git a/mmocr/models/textdet/__init__.py b/mmocr/models/textdet/__init__.py
new file mode 100644
index 00000000..bf95e0f7
--- /dev/null
+++ b/mmocr/models/textdet/__init__.py
@@ -0,0 +1,5 @@
+from .dense_heads import * # noqa: F401,F403
+from .detectors import * # noqa: F401,F403
+from .losses import * # noqa: F401,F403
+from .necks import * # noqa: F401,F403
+from .postprocess import * # noqa: F401,F403
diff --git a/mmocr/models/textdet/dense_heads/__init__.py b/mmocr/models/textdet/dense_heads/__init__.py
new file mode 100644
index 00000000..8f227b80
--- /dev/null
+++ b/mmocr/models/textdet/dense_heads/__init__.py
@@ -0,0 +1,7 @@
+from .db_head import DBHead
+from .head_mixin import HeadMixin
+from .pan_head import PANHead
+from .pse_head import PSEHead
+from .textsnake_head import TextSnakeHead
+
+__all__ = ['PSEHead', 'PANHead', 'DBHead', 'HeadMixin', 'TextSnakeHead']
diff --git a/mmocr/models/textdet/dense_heads/db_head.py b/mmocr/models/textdet/dense_heads/db_head.py
new file mode 100644
index 00000000..f32b296c
--- /dev/null
+++ b/mmocr/models/textdet/dense_heads/db_head.py
@@ -0,0 +1,86 @@
+import torch
+import torch.nn as nn
+
+from mmdet.models.builder import HEADS, build_loss
+from .head_mixin import HeadMixin
+
+
+@HEADS.register_module()
+class DBHead(HeadMixin, nn.Module):
+ """The class for DBNet head.
+
+ This was partially adapted from https://github.com/MhLiao/DB
+ """
+
+ def __init__(self,
+ in_channels,
+ with_bias=False,
+ decoding_type='db',
+ text_repr_type='poly',
+ downsample_ratio=1.0,
+ loss=dict(type='DBLoss'),
+ train_cfg=None,
+ test_cfg=None):
+ """Initialization.
+
+ Args:
+ in_channels (int): The number of input channels of the db head.
+ decoding_type (str): The type of decoder for dbnet.
+ text_repr_type (str): Boundary encoding type 'poly' or 'quad'.
+ downsample_ratio (float): The downsample ratio of ground truths.
+ loss (dict): The type of loss for dbnet.
+ """
+ super().__init__()
+
+ assert isinstance(in_channels, int)
+
+ self.in_channels = in_channels
+ self.text_repr_type = text_repr_type
+ self.loss_module = build_loss(loss)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.downsample_ratio = downsample_ratio
+ self.decoding_type = decoding_type
+
+ self.binarize = nn.Sequential(
+ nn.Conv2d(
+ in_channels, in_channels // 4, 3, bias=with_bias, padding=1),
+ nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
+ nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid())
+
+ self.threshold = self._init_thr(in_channels)
+
+ def init_weights(self):
+ self.binarize.apply(self.init_class_parameters)
+ self.threshold.apply(self.init_class_parameters)
+
+ def init_class_parameters(self, m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ nn.init.kaiming_normal_(m.weight.data)
+ elif classname.find('BatchNorm') != -1:
+ m.weight.data.fill_(1.)
+ m.bias.data.fill_(1e-4)
+
+ def diff_binarize(self, prob_map, thr_map, k):
+ return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map)))
+
+ def forward(self, inputs):
+ prob_map = self.binarize(inputs)
+ thr_map = self.threshold(inputs)
+ binary_map = self.diff_binarize(prob_map, thr_map, k=50)
+ outputs = torch.cat((prob_map, thr_map, binary_map), dim=1)
+ return (outputs, )
+
+ def _init_thr(self, inner_channels, bias=False):
+ in_channels = inner_channels
+ seq = nn.Sequential(
+ nn.Conv2d(
+ in_channels, inner_channels // 4, 3, padding=1, bias=bias),
+ nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
+ nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid())
+ return seq
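+
+# A short note on diff_binarize (the differentiable binarization of the DB
+# paper): it computes 1 / (1 + exp(-k * (P - T))), a steep sigmoid of the
+# probability map P against the learned threshold map T. With k=50 as used
+# in forward(), P - T = 0.1 already yields roughly 0.993, so the approximate
+# binary map saturates quickly on either side of the threshold.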
diff --git a/mmocr/models/textdet/dense_heads/head_mixin.py b/mmocr/models/textdet/dense_heads/head_mixin.py
new file mode 100644
index 00000000..567d334d
--- /dev/null
+++ b/mmocr/models/textdet/dense_heads/head_mixin.py
@@ -0,0 +1,74 @@
+import numpy as np
+
+from mmdet.models.builder import HEADS
+from mmocr.models.textdet.postprocess import decode
+from mmocr.utils import check_argument
+
+
+@HEADS.register_module()
+class HeadMixin:
+ """The head minxin for dbnet and pannet heads."""
+
+ def resize_boundary(self, boundaries, scale_factor):
+ """Rescale boundaries via scale_factor.
+
+ Args:
+ boundaries (list[list[float]]): The boundary list. Each boundary
+ has size 2k+1 with k>=4; the last element is the score.
+ scale_factor (ndarray): The scale factor of size (4,).
+
+ Returns:
+ boundaries (list[list[float]]): The scaled boundaries.
+ """
+ assert check_argument.is_2dlist(boundaries)
+ assert isinstance(scale_factor, np.ndarray)
+ assert scale_factor.shape[0] == 4
+
+ for b in boundaries:
+ sz = len(b)
+ check_argument.valid_boundary(b, True)
+ b[:sz -
+ 1] = (np.array(b[:sz - 1]) *
+ (np.tile(scale_factor[:2], int(
+ (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
+ return boundaries
+
+ def get_boundary(self, score_maps, img_metas, rescale):
+ """Compute text boundaries via post processing.
+
+ Args:
+ score_maps (Tensor): The text score map.
+ img_metas (dict): The image meta info.
+ rescale (bool): Rescale boundaries to the original image resolution
+ if true, and keep the score_maps resolution if false.
+
+ Returns:
+ results (dict): The result dict.
+ """
+
+ assert check_argument.is_type_list(img_metas, dict)
+ assert isinstance(rescale, bool)
+
+ score_maps = score_maps.squeeze()
+ boundaries = decode(
+ decoding_type=self.decoding_type,
+ preds=score_maps,
+ text_repr_type=self.text_repr_type)
+ if rescale:
+ boundaries = self.resize_boundary(
+ boundaries,
+ 1.0 / self.downsample_ratio / img_metas[0]['scale_factor'])
+ results = dict(boundary_result=boundaries)
+ return results
+
+ def loss(self, pred_maps, **kwargs):
+ """Compute the loss for text detection.
+
+ Args:
+ pred_maps (tensor): The input score maps of NxCxHxW.
+
+ Returns:
+ losses (dict): The dict for losses.
+ """
+ losses = self.loss_module(pred_maps, self.downsample_ratio, **kwargs)
+ return losses
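+
+# A small sketch of how the rescaling composes (numbers are illustrative):
+# a boundary predicted on score maps with downsample_ratio=0.25 for an image
+# resized with scale_factor=[2, 2, 2, 2] is multiplied element-wise by
+# 1.0 / 0.25 / 2 = 2.0 in both x and y, mapping it back to the original
+# image resolution while leaving the trailing confidence score untouched.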
diff --git a/mmocr/models/textdet/dense_heads/pan_head.py b/mmocr/models/textdet/dense_heads/pan_head.py
new file mode 100644
index 00000000..08cc7cff
--- /dev/null
+++ b/mmocr/models/textdet/dense_heads/pan_head.py
@@ -0,0 +1,61 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn import normal_init
+
+from mmdet.models.builder import HEADS, build_loss
+from mmocr.utils import check_argument
+from . import HeadMixin
+
+
+@HEADS.register_module()
+class PANHead(HeadMixin, nn.Module):
+ """The class for PANet head."""
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ text_repr_type='poly', # 'poly' or 'quad'
+ downsample_ratio=0.25,
+ loss=dict(type='PANLoss'),
+ train_cfg=None,
+ test_cfg=None):
+ super().__init__()
+
+ assert check_argument.is_type_list(in_channels, int)
+ assert isinstance(out_channels, int)
+ assert text_repr_type in ['poly', 'quad']
+ assert 0 <= downsample_ratio <= 1
+
+ self.loss_module = build_loss(loss)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.text_repr_type = text_repr_type
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.downsample_ratio = downsample_ratio
+ if loss['type'] == 'PANLoss':
+ self.decoding_type = 'pan'
+ elif loss['type'] == 'PSELoss':
+ self.decoding_type = 'pse'
+ else:
+ loss_type = loss['type']
+ raise NotImplementedError(f'unsupported loss type {loss_type}.')
+
+ self.out_conv = nn.Conv2d(
+ in_channels=np.sum(np.array(in_channels)),
+ out_channels=out_channels,
+ kernel_size=1)
+ self.init_weights()
+
+ def init_weights(self):
+ normal_init(self.out_conv, mean=0, std=0.01)
+
+ def forward(self, inputs):
+ if isinstance(inputs, tuple):
+ outputs = torch.cat(inputs, dim=1)
+ else:
+ outputs = inputs
+ outputs = self.out_conv(outputs)
+ return outputs
diff --git a/mmocr/models/textdet/dense_heads/pse_head.py b/mmocr/models/textdet/dense_heads/pse_head.py
new file mode 100644
index 00000000..db2a9925
--- /dev/null
+++ b/mmocr/models/textdet/dense_heads/pse_head.py
@@ -0,0 +1,25 @@
+from mmdet.models.builder import HEADS
+from . import PANHead
+
+
+@HEADS.register_module()
+class PSEHead(PANHead):
+ """The class for PANet head."""
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ text_repr_type='poly', # 'poly' or 'quad'
+ downsample_ratio=0.25,
+ loss=dict(type='PSELoss'),
+ train_cfg=None,
+ test_cfg=None):
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ text_repr_type=text_repr_type,
+ downsample_ratio=downsample_ratio,
+ loss=loss,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg)
diff --git a/mmocr/models/textdet/dense_heads/textsnake_head.py b/mmocr/models/textdet/dense_heads/textsnake_head.py
new file mode 100644
index 00000000..1645bba7
--- /dev/null
+++ b/mmocr/models/textdet/dense_heads/textsnake_head.py
@@ -0,0 +1,48 @@
+import torch.nn as nn
+from mmcv.cnn import normal_init
+
+from mmdet.models.builder import HEADS, build_loss
+from . import HeadMixin
+
+
+@HEADS.register_module()
+class TextSnakeHead(HeadMixin, nn.Module):
+ """The class for TextSnake head: TextSnake: A Flexible Representation for
+ Detecting Text of Arbitrary Shapes.
+
+ [https://arxiv.org/abs/1807.01544]
+ """
+
+ def __init__(self,
+ in_channels,
+ decoding_type='textsnake',
+ text_repr_type='poly',
+ loss=dict(type='TextSnakeLoss'),
+ train_cfg=None,
+ test_cfg=None):
+ super().__init__()
+
+ assert isinstance(in_channels, int)
+ self.in_channels = in_channels
+ self.out_channels = 5
+ self.downsample_ratio = 1.0
+ self.decoding_type = decoding_type
+ self.text_repr_type = text_repr_type
+ self.loss_module = build_loss(loss)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ self.out_conv = nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.init_weights()
+
+ def init_weights(self):
+ normal_init(self.out_conv, mean=0, std=0.01)
+
+ def forward(self, inputs):
+ outputs = self.out_conv(inputs)
+ return outputs
diff --git a/mmocr/models/textdet/detectors/__init__.py b/mmocr/models/textdet/detectors/__init__.py
new file mode 100644
index 00000000..6aab9c73
--- /dev/null
+++ b/mmocr/models/textdet/detectors/__init__.py
@@ -0,0 +1,12 @@
+from .dbnet import DBNet
+from .ocr_mask_rcnn import OCRMaskRCNN
+from .panet import PANet
+from .psenet import PSENet
+from .single_stage_text_detector import SingleStageTextDetector
+from .text_detector_mixin import TextDetectorMixin
+from .textsnake import TextSnake
+
+__all__ = [
+ 'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',
+ 'PANet', 'PSENet', 'TextSnake'
+]
diff --git a/mmocr/models/textdet/detectors/dbnet.py b/mmocr/models/textdet/detectors/dbnet.py
new file mode 100644
index 00000000..4b3c33d4
--- /dev/null
+++ b/mmocr/models/textdet/detectors/dbnet.py
@@ -0,0 +1,26 @@
+from mmdet.models.builder import DETECTORS
+from mmocr.models.textdet.detectors.single_stage_text_detector import \
+ SingleStageTextDetector
+from mmocr.models.textdet.detectors.text_detector_mixin import \
+ TextDetectorMixin
+
+
+@DETECTORS.register_module()
+class DBNet(TextDetectorMixin, SingleStageTextDetector):
+ """The class for implementing DBNet text detector: Real-time Scene Text
+ Detection with Differentiable Binarization.
+
+ [https://arxiv.org/abs/1911.08947].
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ show_score=False):
+ SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
+ train_cfg, test_cfg, pretrained)
+ TextDetectorMixin.__init__(self, show_score)
diff --git a/mmocr/models/textdet/detectors/ocr_mask_rcnn.py b/mmocr/models/textdet/detectors/ocr_mask_rcnn.py
new file mode 100644
index 00000000..a9ce9cc8
--- /dev/null
+++ b/mmocr/models/textdet/detectors/ocr_mask_rcnn.py
@@ -0,0 +1,41 @@
+from mmdet.models.builder import DETECTORS
+from mmdet.models.detectors import MaskRCNN
+from mmocr.models.textdet.detectors.text_detector_mixin import \
+ TextDetectorMixin
+
+
+@DETECTORS.register_module()
+class OCRMaskRCNN(TextDetectorMixin, MaskRCNN):
+ """Mask RCNN tailored for OCR."""
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None,
+ text_repr_type='quad',
+ show_score=False):
+ TextDetectorMixin.__init__(self, show_score)
+ MaskRCNN.__init__(
+ self,
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained)
+ assert text_repr_type in ['quad', 'poly']
+ self.text_repr_type = text_repr_type
+
+ def simple_test(self, img, img_metas, proposals=None, rescale=False):
+
+ results = super().simple_test(img, img_metas, proposals, rescale)
+
+ boundaries = self.get_boundary(results[0])
+ boundaries = boundaries if isinstance(boundaries,
+ list) else [boundaries]
+ return boundaries
diff --git a/mmocr/models/textdet/detectors/panet.py b/mmocr/models/textdet/detectors/panet.py
new file mode 100644
index 00000000..befa52c8
--- /dev/null
+++ b/mmocr/models/textdet/detectors/panet.py
@@ -0,0 +1,26 @@
+from mmdet.models.builder import DETECTORS
+from mmocr.models.textdet.detectors.single_stage_text_detector import \
+ SingleStageTextDetector
+from mmocr.models.textdet.detectors.text_detector_mixin import \
+ TextDetectorMixin
+
+
+@DETECTORS.register_module()
+class PANet(TextDetectorMixin, SingleStageTextDetector):
+ """The class for implementing PANet text detector:
+
+ Efficient and Accurate Arbitrary-Shaped Text Detection with Pixel
+ Aggregation Network [https://arxiv.org/abs/1908.05900].
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ show_score=False):
+ SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
+ train_cfg, test_cfg, pretrained)
+ TextDetectorMixin.__init__(self, show_score)
diff --git a/mmocr/models/textdet/detectors/psenet.py b/mmocr/models/textdet/detectors/psenet.py
new file mode 100644
index 00000000..7dccad4e
--- /dev/null
+++ b/mmocr/models/textdet/detectors/psenet.py
@@ -0,0 +1,26 @@
+from mmdet.models.builder import DETECTORS
+from mmocr.models.textdet.detectors.single_stage_text_detector import \
+ SingleStageTextDetector
+from mmocr.models.textdet.detectors.text_detector_mixin import \
+ TextDetectorMixin
+
+
+@DETECTORS.register_module()
+class PSENet(TextDetectorMixin, SingleStageTextDetector):
+ """The class for implementing PSENet text detector: Shape Robust Text
+ Detection with Progressive Scale Expansion Network.
+
+ [https://arxiv.org/abs/1806.02559].
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ show_score=False):
+ SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
+ train_cfg, test_cfg, pretrained)
+ TextDetectorMixin.__init__(self, show_score)
diff --git a/mmocr/models/textdet/detectors/single_stage_text_detector.py b/mmocr/models/textdet/detectors/single_stage_text_detector.py
new file mode 100644
index 00000000..3456479c
--- /dev/null
+++ b/mmocr/models/textdet/detectors/single_stage_text_detector.py
@@ -0,0 +1,45 @@
+from mmdet.models.builder import DETECTORS
+from mmdet.models.detectors import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class SingleStageTextDetector(SingleStageDetector):
+ """The class for implementing single stage text detector.
+
+ It is the parent class of PANet, PSENet, and DBNet.
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ SingleStageDetector.__init__(self, backbone, neck, bbox_head,
+ train_cfg, test_cfg, pretrained)
+
+ def forward_train(self, img, img_metas, **kwargs):
+ """
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys, see
+ :class:`mmdet.datasets.pipelines.Collect`.
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ x = self.extract_feat(img)
+ preds = self.bbox_head(x)
+ losses = self.bbox_head.loss(preds, **kwargs)
+ return losses
+
+ def simple_test(self, img, img_metas, rescale=False):
+ x = self.extract_feat(img)
+ outs = self.bbox_head(x)
+ boundaries = self.bbox_head.get_boundary(*outs, img_metas, rescale)
+
+ return [boundaries]
diff --git a/mmocr/models/textdet/detectors/text_detector_mixin.py b/mmocr/models/textdet/detectors/text_detector_mixin.py
new file mode 100644
index 00000000..30ba6560
--- /dev/null
+++ b/mmocr/models/textdet/detectors/text_detector_mixin.py
@@ -0,0 +1,101 @@
+import warnings
+
+import mmcv
+
+from mmocr.core import imshow_pred_boundary, seg2boundary
+
+
+class TextDetectorMixin:
+ """The class for implementing text detector auxiliary methods."""
+
+ def __init__(self, show_score):
+ self.show_score = show_score
+
+ def get_boundary(self, results):
+ """Convert segmentation into text boundaries.
+
+ Args:
+ results (tuple): The result tuple. The first element contains the
+ bbox results with scores, and the second contains the
+ segmentation masks.
+
+ Returns:
+ results (dict): A result dict containing 'boundary_result'.
+ """
+
+ assert isinstance(results, tuple)
+
+ instance_num = len(results[1][0])
+ boundaries = []
+ for i in range(instance_num):
+ seg = results[1][0][i]
+ score = results[0][0][i][-1]
+ boundary = seg2boundary(seg, self.text_repr_type, score)
+ if boundary is not None:
+ boundaries.append(boundary)
+
+ results = dict(boundary_result=boundaries)
+ return results
+
+ def show_result(self,
+ img,
+ result,
+ score_thr=0.5,
+ bbox_color='green',
+ text_color='green',
+ thickness=1,
+ font_scale=0.5,
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None):
+ """Draw `result` over `img`.
+
+ Args:
+ img (str or Tensor): The image to be displayed.
+ result (dict): The results to draw over `img`.
+ score_thr (float, optional): Minimum score of bboxes to be shown.
+ Default: 0.5.
+ bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
+ text_color (str or tuple or :obj:`Color`): Color of texts.
+ thickness (int): Thickness of lines.
+ font_scale (float): Font scales of texts.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ Default: 0.
+ show (bool): Whether to show the image.
+ Default: False.
+ out_file (str or None): The filename to write the image.
+ Default: None.
+ """
+ img = mmcv.imread(img)
+ img = img.copy()
+ boundaries = None
+ labels = None
+ if 'boundary_result' in result.keys():
+ boundaries = result['boundary_result']
+ labels = [0] * len(boundaries)
+
+ # if out_file specified, do not show image in window
+ if out_file is not None:
+ show = False
+ # draw bounding boxes
+ if boundaries is not None:
+ imshow_pred_boundary(
+ img,
+ boundaries,
+ labels,
+ score_thr=score_thr,
+ boundary_color=bbox_color,
+ text_color=text_color,
+ thickness=thickness,
+ font_scale=font_scale,
+ win_name=win_name,
+ show=show,
+ wait_time=wait_time,
+ out_file=out_file,
+ show_score=self.show_score)
+
+ if not (show or out_file):
+ warnings.warn('show==False and out_file is not specified, '
+ 'result image will be returned')
+ return img
diff --git a/mmocr/models/textdet/detectors/textsnake.py b/mmocr/models/textdet/detectors/textsnake.py
new file mode 100644
index 00000000..25f65abf
--- /dev/null
+++ b/mmocr/models/textdet/detectors/textsnake.py
@@ -0,0 +1,23 @@
+from mmdet.models.builder import DETECTORS
+from . import SingleStageTextDetector, TextDetectorMixin
+
+
+@DETECTORS.register_module()
+class TextSnake(TextDetectorMixin, SingleStageTextDetector):
+ """The class for implementing TextSnake text detector: TextSnake: A
+ Flexible Representation for Detecting Text of Arbitrary Shapes.
+
+ [https://arxiv.org/abs/1807.01544]
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ show_score=False):
+ SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
+ train_cfg, test_cfg, pretrained)
+ TextDetectorMixin.__init__(self, show_score)
diff --git a/mmocr/models/textdet/losses/__init__.py b/mmocr/models/textdet/losses/__init__.py
new file mode 100644
index 00000000..eaa4d9cd
--- /dev/null
+++ b/mmocr/models/textdet/losses/__init__.py
@@ -0,0 +1,6 @@
+from .db_loss import DBLoss
+from .pan_loss import PANLoss
+from .pse_loss import PSELoss
+from .textsnake_loss import TextSnakeLoss
+
+__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss']
diff --git a/mmocr/models/textdet/losses/db_loss.py b/mmocr/models/textdet/losses/db_loss.py
new file mode 100644
index 00000000..0081ecda
--- /dev/null
+++ b/mmocr/models/textdet/losses/db_loss.py
@@ -0,0 +1,169 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from mmdet.models.builder import LOSSES
+from mmocr.core.visualize import show_feature  # noqa: F401
+from mmocr.models.common.losses.dice_loss import DiceLoss
+
+
+@LOSSES.register_module()
+class DBLoss(nn.Module):
+ """The class for implementing DBNet loss.
+
+ This is partially adapted from https://github.com/MhLiao/DB.
+ """
+
+ def __init__(self,
+ alpha=1,
+ beta=1,
+ reduction='mean',
+ negative_ratio=3.0,
+ eps=1e-6,
+ bbce_loss=False):
+ """Initialization.
+
+ Args:
+ alpha (float): The binary loss coef.
+ beta (float): The threshold loss coef.
+ reduction (str): The way to reduce the loss.
+ negative_ratio (float): The maximum ratio of negative samples to
+ positive samples.
+ eps (float): Epsilon in the threshold loss function.
+ bbce_loss (bool): Whether to use balanced bce for probability loss.
+ If False, dice loss will be used instead.
+ """
+ super().__init__()
+ assert reduction in ['mean',
+ 'sum'], "reduction must be in ['mean', 'sum']"
+ self.alpha = alpha
+ self.beta = beta
+ self.reduction = reduction
+ self.negative_ratio = negative_ratio
+ self.eps = eps
+ self.bbce_loss = bbce_loss
+ self.dice_loss = DiceLoss(eps=eps)
+
+ def bitmasks2tensor(self, bitmasks, target_sz):
+ """Convert Bitmasks to tensor.
+
+ Args:
+ bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
+ for one img.
+ target_sz (tuple(int, int)): The target tensor size HxW.
+
+ Returns:
+ result_tensors (list[tensor]): The list of kernel tensors. Each
+ element is for one kernel level.
+ """
+ assert isinstance(bitmasks, list)
+ assert isinstance(target_sz, tuple)
+
+ batch_size = len(bitmasks)
+ num_levels = len(bitmasks[0])
+
+ result_tensors = []
+
+ for level_inx in range(num_levels):
+ kernel = []
+ for batch_inx in range(batch_size):
+ mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
+ mask_sz = mask.shape
+ pad = [
+ 0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
+ ]
+ mask = F.pad(mask, pad, mode='constant', value=0)
+ kernel.append(mask)
+ kernel = torch.stack(kernel)
+ result_tensors.append(kernel)
+
+ return result_tensors
+
+ def balance_bce_loss(self, pred, gt, mask):
+
+ positive = (gt * mask)
+ negative = ((1 - gt) * mask)
+ positive_count = int(positive.float().sum())
+ negative_count = min(
+ int(negative.float().sum()),
+ int(positive_count * self.negative_ratio))
+
+ assert gt.max() <= 1 and gt.min() >= 0
+ assert pred.max() <= 1 and pred.min() >= 0
+ loss = F.binary_cross_entropy(pred, gt, reduction='none')
+ positive_loss = loss * positive.float()
+ negative_loss = loss * negative.float()
+
+ negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
+
+ balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
+ positive_count + negative_count + self.eps)
+
+ return balance_loss
+
+ def l1_thr_loss(self, pred, gt, mask):
+ thr_loss = torch.abs((pred - gt) * mask).sum() / (
+ mask.sum() + self.eps)
+ return thr_loss
+
+ def forward(self, preds, downsample_ratio, gt_shrink, gt_shrink_mask,
+ gt_thr, gt_thr_mask):
+ """Compute DBNet loss.
+
+ Args:
+ preds (tensor): The output tensor with size of Nx3xHxW.
+ downsample_ratio (float): The downsample ratio for the
+ ground truths.
+ gt_shrink (list[BitmapMasks]): The mask list with each element
+ being the shrunk text mask for one img.
+ gt_shrink_mask (list[BitmapMasks]): The effective mask list with
+ each element being the shrunk effective mask for one img.
+ gt_thr (list[BitmapMasks]): The mask list with each element
+ being the threshold text mask for one img.
+ gt_thr_mask (list[BitmapMasks]): The effective mask list with
+ each element being the threshold effective mask for one img.
+
+ Returns:
+ results (dict): The dict for dbnet losses with loss_prob,
+ loss_db and loss_thr.
+ """
+ assert isinstance(downsample_ratio, float)
+
+ assert isinstance(gt_shrink, list)
+ assert isinstance(gt_shrink_mask, list)
+ assert isinstance(gt_thr, list)
+ assert isinstance(gt_thr_mask, list)
+
+ preds = preds[0]
+
+ pred_prob = preds[:, 0, :, :]
+ pred_thr = preds[:, 1, :, :]
+ pred_db = preds[:, 2, :, :]
+ feature_sz = preds.size()
+
+ keys = ['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask']
+ gt = {}
+ for k in keys:
+ gt[k] = eval(k)
+ gt[k] = [item.rescale(downsample_ratio) for item in gt[k]]
+ gt[k] = self.bitmasks2tensor(gt[k], feature_sz[2:])
+ gt[k] = [item.to(preds.device) for item in gt[k]]
+ gt['gt_shrink'][0] = (gt['gt_shrink'][0] > 0).float()
+ if self.bbce_loss:
+ loss_prob = self.balance_bce_loss(pred_prob, gt['gt_shrink'][0],
+ gt['gt_shrink_mask'][0])
+ else:
+ loss_prob = self.dice_loss(pred_prob, gt['gt_shrink'][0],
+ gt['gt_shrink_mask'][0])
+
+ loss_db = self.dice_loss(pred_db, gt['gt_shrink'][0],
+ gt['gt_shrink_mask'][0])
+ loss_thr = self.l1_thr_loss(pred_thr, gt['gt_thr'][0],
+ gt['gt_thr_mask'][0])
+
+ results = dict(
+ loss_prob=self.alpha * loss_prob,
+ loss_db=loss_db,
+ loss_thr=self.beta * loss_thr)
+
+ return results
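A toy check of balance_bce_loss (not part of the diff; assumes torch and this repo are installed): with the default negative_ratio of 3 and four positive pixels, at most twelve of the hardest negatives contribute to the loss.

import torch
from mmocr.models.textdet.losses import DBLoss

loss_fn = DBLoss()                 # default negative_ratio=3.0
pred = torch.full((1, 4, 4), 0.4)  # probabilities in [0, 1]
gt = torch.zeros((1, 4, 4))
gt[0, 1:3, 1:3] = 1.0              # 4 positive pixels
mask = torch.ones_like(gt)         # every pixel is valid
loss = loss_fn.balance_bce_loss(pred, gt, mask)
print(loss.item())  # scalar; negatives capped at 3 * 4 = 12 hardest pixels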
diff --git a/mmocr/models/textdet/losses/pan_loss.py b/mmocr/models/textdet/losses/pan_loss.py
new file mode 100644
index 00000000..c5ccd464
--- /dev/null
+++ b/mmocr/models/textdet/losses/pan_loss.py
@@ -0,0 +1,329 @@
+import itertools
+import warnings
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from mmdet.core import BitmapMasks
+from mmdet.models.builder import LOSSES
+from mmocr.utils import check_argument
+
+
+@LOSSES.register_module()
+class PANLoss(nn.Module):
+ """The class for implementing PANet loss: Efficient and Accurate Arbitrary-
+ Shaped Text Detection with Pixel Aggregation Network.
+
+ [https://arxiv.org/abs/1908.05900]. This was partially adapted from
+ https://github.com/WenmuZhou/PAN.pytorch
+ """
+
+ def __init__(self,
+ alpha=0.5,
+ beta=0.25,
+ delta_aggregation=0.5,
+ delta_discrimination=3,
+ ohem_ratio=3,
+ reduction='mean',
+ speedup_bbox_thr=-1):
+ """Initialization.
+
+ Args:
+ alpha (float): The kernel loss coef.
+ beta (float): The aggregation and discriminative loss coef.
+ delta_aggregation (float): The constant for aggregation loss.
+ delta_discrimination (float): The constant for discriminative loss.
+ ohem_ratio (float): The negative/positive ratio in ohem.
+ reduction (str): The way to reduce the loss.
+ speedup_bbox_thr (int): Speed up if speedup_bbox_thr > 0
+ and < total text instance number.
+ """
+ sampled_masks_kernel = (pred_texts > 0.5).float() * (
+ gt['gt_mask'][0].float())
+ loss_kernels = self.dice_loss_with_logits(pred_kernels,
+ gt['gt_kernels'][1],
+ sampled_masks_kernel)
+ losses = [loss_texts, loss_kernels, loss_aggrs, loss_discrs]
+ if self.reduction == 'mean':
+ losses = [item.mean() for item in losses]
+ elif self.reduction == 'sum':
+ losses = [item.sum() for item in losses]
+ else:
+ raise NotImplementedError
+
+ coefs = [1, self.alpha, self.beta, self.beta]
+ losses = [item * scale for item, scale in zip(losses, coefs)]
+
+ results = dict()
+ results.update(
+ loss_text=losses[0],
+ loss_kernel=losses[1],
+ loss_aggregation=losses[2],
+ loss_discrimination=losses[3])
+ return results
+
+ def aggregation_discrimination_loss(self, gt_texts, gt_kernels,
+ inst_embeds):
+ """Compute the aggregation and discrimnative losses.
+
+ Args:
+ gt_texts (tensor): The ground truth text mask of size Nx1xHxW.
+ gt_kernels (tensor): The ground truth text kernel mask of
+ size Nx1xHxW.
+ inst_embeds(tensor): The text instance embedding tensor
+ of size Nx4xHxW.
+
+ Returns:
+ loss_aggrs (tensor): The aggregation loss before reduction.
+ loss_discrs (tensor): The discriminative loss before reduction.
+ """
+
+ batch_size = gt_texts.size()[0]
+ gt_texts = gt_texts.contiguous().reshape(batch_size, -1)
+ gt_kernels = gt_kernels.contiguous().reshape(batch_size, -1)
+
+ assert inst_embeds.shape[1] == 4
+ inst_embeds = inst_embeds.contiguous().reshape(batch_size, 4, -1)
+
+ loss_aggrs = []
+ loss_discrs = []
+
+ for text, kernel, embed in zip(gt_texts, gt_kernels, inst_embeds):
+
+ # for each image
+ text_num = int(text.max().item())
+ loss_aggr_img = []
+ kernel_avgs = []
+ select_num = self.speedup_bbox_thr
+ if 0 < select_num < text_num:
+ inds = np.random.choice(
+ text_num, select_num, replace=False) + 1
+ else:
+ inds = range(1, text_num + 1)
+
+ for i in inds:
+ # for each text instance
+ kernel_i = (kernel == i) # 0.2ms
+ if kernel_i.sum() == 0 or (text == i).sum() == 0: # 0.2ms
+ continue
+
+ # compute G_Ki in Eq (2)
+ avg = embed[:, kernel_i].mean(1) # 0.5ms
+ kernel_avgs.append(avg)
+
+ embed_i = embed[:, text == i] # 0.6ms
+ # ||F(p) - G(K_i)|| - delta_aggregation, shape: nums
+ distance = (embed_i - avg.reshape(4, 1)).norm( # 0.5ms
+ 2, dim=0) - self.delta_aggregation
+ # compute D(p,K_i) in Eq (2)
+ hinge = torch.max(
+ distance,
+ torch.tensor(0, device=distance.device,
+ dtype=torch.float)).pow(2)
+
+ aggr = torch.log(hinge + 1).mean()
+ loss_aggr_img.append(aggr)
+
+ num_inst = len(loss_aggr_img)
+ if num_inst > 0:
+ loss_aggr_img = torch.stack(loss_aggr_img).mean()
+ else:
+ loss_aggr_img = torch.tensor(
+ 0, device=gt_texts.device, dtype=torch.float)
+ loss_aggrs.append(loss_aggr_img)
+
+ loss_discr_img = 0
+ for avg_i, avg_j in itertools.combinations(kernel_avgs, 2):
+ # delta_discrimination - ||G(K_i) - G(K_j)||
+ distance_ij = self.delta_discrimination - (avg_i -
+ avg_j).norm(2)
+ # D(K_i,K_j)
+ D_ij = torch.max(
+ distance_ij,
+ torch.tensor(
+ 0, device=distance_ij.device,
+ dtype=torch.float)).pow(2)
+ loss_discr_img += torch.log(D_ij + 1)
+
+ if num_inst > 1:
+ loss_discr_img /= (num_inst * (num_inst - 1))
+ else:
+ loss_discr_img = torch.tensor(
+ 0, device=gt_texts.device, dtype=torch.float)
+ if num_inst == 0:
+ warnings.warn('num of instance is 0')
+ loss_discrs.append(loss_discr_img)
+ return torch.stack(loss_aggrs), torch.stack(loss_discrs)
+
+ def dice_loss_with_logits(self, pred, target, mask):
+
+ smooth = 0.001
+
+ pred = torch.sigmoid(pred)
+ target[target <= 0.5] = 0
+ target[target > 0.5] = 1
+ pred = pred.contiguous().view(pred.size()[0], -1)
+ target = target.contiguous().view(target.size()[0], -1)
+ mask = mask.contiguous().view(mask.size()[0], -1)
+
+ pred = pred * mask
+ target = target * mask
+
+ a = torch.sum(pred * target, 1)
+ b = torch.sum(pred * pred, 1) + smooth
+ c = torch.sum(target * target, 1) + smooth
+ d = (2 * a) / (b + c)
+ return 1 - d
+
+ def ohem_img(self, text_score, gt_text, gt_mask):
+ """Sample the top-k maximal negative samples and all positive samples.
+
+ Args:
+ text_score (Tensor): The text score with size of HxW.
+ gt_text (Tensor): The ground truth text mask of HxW.
+ gt_mask (Tensor): The effective region mask of HxW.
+
+ Returns:
+ sampled_mask (Tensor): The sampled pixel mask of size HxW.
+ """
+ assert isinstance(text_score, torch.Tensor)
+ assert isinstance(gt_text, torch.Tensor)
+ assert isinstance(gt_mask, torch.Tensor)
+ assert len(text_score.shape) == 2
+ assert text_score.shape == gt_text.shape
+ assert gt_text.shape == gt_mask.shape
+
+ pos_num = int(torch.sum(gt_text > 0.5).item()) - int(
+ torch.sum((gt_text > 0.5) * (gt_mask <= 0.5)).item())
+ neg_num = int(torch.sum(gt_text <= 0.5).item())
+ neg_num = int(min(pos_num * self.ohem_ratio, neg_num))
+
+ if pos_num == 0 or neg_num == 0:
+ warnings.warn('pos_num = 0 or neg_num = 0')
+ return gt_mask.bool()
+
+ neg_score = text_score[gt_text <= 0.5]
+ neg_score_sorted, _ = torch.sort(neg_score, descending=True)
+ threshold = neg_score_sorted[neg_num - 1]
+ sampled_mask = (((text_score >= threshold) + (gt_text > 0.5)) > 0) * (
+ gt_mask > 0.5)
+ return sampled_mask
+
+ def ohem_batch(self, text_scores, gt_texts, gt_mask):
+ """OHEM sampling for a batch of imgs.
+
+ Args:
+ text_scores (Tensor): The text scores of size NxHxW.
+ gt_texts (Tensor): The gt text masks of size NxHxW.
+ gt_mask (Tensor): The gt effective mask of size NxHxW.
+
+ Returns:
+ sampled_masks (Tensor): The sampled mask of size NxHxW.
+ """
+ assert isinstance(text_scores, torch.Tensor)
+ assert isinstance(gt_texts, torch.Tensor)
+ assert isinstance(gt_mask, torch.Tensor)
+ assert len(text_scores.shape) == 3
+ assert text_scores.shape == gt_texts.shape
+ assert gt_texts.shape == gt_mask.shape
+
+ sampled_masks = []
+ for i in range(text_scores.shape[0]):
+ sampled_masks.append(
+ self.ohem_img(text_scores[i], gt_texts[i], gt_mask[i]))
+
+ sampled_masks = torch.stack(sampled_masks)
+
+ return sampled_masks
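A toy check of dice_loss_with_logits (not part of the diff; assumes torch and this repo are installed): it takes raw logits, applies sigmoid internally, and returns one dice loss per image.

import torch
from mmocr.models.textdet.losses import PANLoss

loss_fn = PANLoss()
pred = torch.full((2, 8, 8), 4.0)   # logits; sigmoid(4.0) is about 0.982
target = torch.ones((2, 8, 8))
mask = torch.ones((2, 8, 8))
loss = loss_fn.dice_loss_with_logits(pred, target, mask)
print(loss)  # shape (2,); values close to 0 for a near-perfect match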
diff --git a/mmocr/models/textdet/losses/pse_loss.py b/mmocr/models/textdet/losses/pse_loss.py
new file mode 100644
index 00000000..4580bfc4
--- /dev/null
+++ b/mmocr/models/textdet/losses/pse_loss.py
@@ -0,0 +1,104 @@
+from mmdet.core import BitmapMasks
+from mmdet.models.builder import LOSSES
+from mmocr.utils import check_argument
+from . import PANLoss
+
+
+@LOSSES.register_module()
+class PSELoss(PANLoss):
+ """The class for implementing PSENet loss: Shape Robust Text Detection with
+ Progressive Scale Expansion Network [https://arxiv.org/abs/1806.02559].
+
+ This is partially adapted from https://github.com/whai362/PSENet.
+ """
+
+ def __init__(self,
+ alpha=0.7,
+ ohem_ratio=3,
+ reduction='mean',
+ kernel_sample_type='adaptive'):
+ """Initialization.
+
+ Args:
+ alpha (float): The text loss coef; the kernel loss coef is
+ (1 - alpha).
+ ohem_ratio (float): The negative/positive ratio in ohem.
+ reduction (str): The way to reduce the loss.
+ kernel_sample_type (str): The way to sample kernel pixels:
+ 'adaptive' or 'hard'.
+ """
+ super().__init__()
+ assert reduction in ['mean',
+ 'sum'], "reduction must be in ['mean', 'sum']"
+ self.alpha = alpha
+ self.ohem_ratio = ohem_ratio
+ self.reduction = reduction
+ self.kernel_sample_type = kernel_sample_type
+
+ def forward(self, score_maps, downsample_ratio, gt_kernels, gt_mask):
+ """Compute PSENet loss.
+
+ Args:
+ score_maps (tensor): The output tensor with size of Nx6xHxW.
+ gt_kernels (list[BitmapMasks]): The kernel list with each element
+ being the text kernel mask for one img.
+ gt_mask (list[BitmapMasks]): The effective mask list
+ with each element being the effective mask for one img.
+ downsample_ratio (float): The downsample ratio between score_maps
+ and the input img.
+
+ Returns:
+ results (dict): The loss.
+ """
+
+ assert check_argument.is_type_list(gt_kernels, BitmapMasks)
+ assert check_argument.is_type_list(gt_mask, BitmapMasks)
+ assert isinstance(downsample_ratio, float)
+ losses = []
+
+ pred_texts = score_maps[:, 0, :, :]
+ pred_kernels = score_maps[:, 1:, :, :]
+ feature_sz = score_maps.size()
+
+ gt_kernels = [item.rescale(downsample_ratio) for item in gt_kernels]
+ gt_kernels = self.bitmasks2tensor(gt_kernels, feature_sz[2:])
+ gt_kernels = [item.to(score_maps.device) for item in gt_kernels]
+
+ gt_mask = [item.rescale(downsample_ratio) for item in gt_mask]
+ gt_mask = self.bitmasks2tensor(gt_mask, feature_sz[2:])
+ gt_mask = [item.to(score_maps.device) for item in gt_mask]
+
+ # compute text loss
+ sampled_masks_text = self.ohem_batch(pred_texts.detach(),
+ gt_kernels[0], gt_mask[0])
+ loss_texts = self.dice_loss_with_logits(pred_texts, gt_kernels[0],
+ sampled_masks_text)
+ losses.append(self.alpha * loss_texts)
+
+ # compute kernel loss
+ if self.kernel_sample_type == 'hard':
+ sampled_masks_kernel = (gt_kernels[0] > 0.5).float() * (
+ gt_mask[0].float())
+ elif self.kernel_sample_type == 'adaptive':
+ sampled_masks_kernel = (pred_texts > 0).float() * (
+ gt_mask[0].float())
+ else:
+ raise NotImplementedError
+
+ num_kernel = pred_kernels.shape[1]
+ assert num_kernel == len(gt_kernels) - 1
+ loss_list = []
+ for inx in range(num_kernel):
+ loss_kernels = self.dice_loss_with_logits(
+ pred_kernels[:, inx, :, :], gt_kernels[1 + inx],
+ sampled_masks_kernel)
+ loss_list.append(loss_kernels)
+
+ losses.append((1 - self.alpha) * sum(loss_list) / len(loss_list))
+
+ if self.reduction == 'mean':
+ losses = [item.mean() for item in losses]
+ elif self.reduction == 'sum':
+ losses = [item.sum() for item in losses]
+ else:
+ raise NotImplementedError
+ results = dict(loss_text=losses[0], loss_kernel=losses[1])
+ return results
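In a model config this loss is typically referenced by its registered name inside the detection head's loss field; the snippet below sketches just that sub-dict, mirroring the constructor defaults above (the surrounding head config is not shown).

# Loss sub-config sketch; values mirror the __init__ defaults.
loss = dict(
    type='PSELoss',
    alpha=0.7,
    ohem_ratio=3,
    reduction='mean',
    kernel_sample_type='adaptive')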
diff --git a/mmocr/models/textdet/losses/textsnake_loss.py b/mmocr/models/textdet/losses/textsnake_loss.py
new file mode 100644
index 00000000..b12de444
--- /dev/null
+++ b/mmocr/models/textdet/losses/textsnake_loss.py
@@ -0,0 +1,181 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from mmdet.core import BitmapMasks
+from mmdet.models.builder import LOSSES
+from mmocr.utils import check_argument
+
+
+@LOSSES.register_module()
+class TextSnakeLoss(nn.Module):
+ """The class for implementing TextSnake loss:
+ TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes
+ [https://arxiv.org/abs/1807.01544].
+ This is partially adapted from
+ https://github.com/princewang1994/TextSnake.pytorch.
+ """
+
+ def __init__(self, ohem_ratio=3.0):
+ """Initialization.
+
+ Args:
+ ohem_ratio (float): The negative/positive ratio in ohem.
+ """
+ super().__init__()
+ self.ohem_ratio = ohem_ratio
+
+ def balanced_bce_loss(self, pred, gt, mask):
+
+ assert pred.shape == gt.shape == mask.shape
+ positive = gt * mask
+ negative = (1 - gt) * mask
+ positive_count = int(positive.float().sum())
+ gt = gt.float()
+ if positive_count > 0:
+ loss = F.binary_cross_entropy(pred, gt, reduction='none')
+ positive_loss = torch.sum(loss * positive.float())
+ negative_loss = loss * negative.float()
+ negative_count = min(
+ int(negative.float().sum()),
+ int(positive_count * self.ohem_ratio))
+ else:
+ positive_loss = torch.tensor(0.0, device=pred.device)
+ loss = F.binary_cross_entropy(pred, gt, reduction='none')
+ negative_loss = loss * negative.float()
+ negative_count = 100
+ negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
+
+ balance_loss = (positive_loss + torch.sum(negative_loss)) / (
+ float(positive_count + negative_count) + 1e-5)
+
+ return balance_loss
+
+ def bitmasks2tensor(self, bitmasks, target_sz):
+ """Convert Bitmasks to tensor.
+
+ Args:
+ bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
+ for one img.
+ target_sz (tuple(int, int)): The target tensor size HxW.
+
+ Returns:
+ results (list[tensor]): The list of kernel tensors. Each
+ element is for one kernel level.
+ """
+ assert check_argument.is_type_list(bitmasks, BitmapMasks)
+ assert isinstance(target_sz, tuple)
+
+ batch_size = len(bitmasks)
+ num_masks = len(bitmasks[0])
+
+ results = []
+
+ for level_inx in range(num_masks):
+ kernel = []
+ for batch_inx in range(batch_size):
+ mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
+ # hxw
+ mask_sz = mask.shape
+ # left, right, top, bottom
+ pad = [
+ 0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
+ ]
+ mask = F.pad(mask, pad, mode='constant', value=0)
+ kernel.append(mask)
+ kernel = torch.stack(kernel)
+ results.append(kernel)
+
+ return results
+
+ def forward(self, pred_maps, downsample_ratio, gt_text_mask,
+ gt_center_region_mask, gt_mask, gt_radius_map, gt_sin_map,
+ gt_cos_map):
+
+ assert isinstance(downsample_ratio, float)
+ assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
+ assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
+ assert check_argument.is_type_list(gt_mask, BitmapMasks)
+ assert check_argument.is_type_list(gt_radius_map, BitmapMasks)
+ assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
+ assert check_argument.is_type_list(gt_cos_map, BitmapMasks)
+
+ pred_text_region = pred_maps[:, 0, :, :]
+ pred_center_region = pred_maps[:, 1, :, :]
+ pred_sin_map = pred_maps[:, 2, :, :]
+ pred_cos_map = pred_maps[:, 3, :, :]
+ pred_radius_map = pred_maps[:, 4, :, :]
+ feature_sz = pred_maps.size()
+ device = pred_maps.device
+
+ # bitmask 2 tensor
+ mapping = {
+ 'gt_text_mask': gt_text_mask,
+ 'gt_center_region_mask': gt_center_region_mask,
+ 'gt_mask': gt_mask,
+ 'gt_radius_map': gt_radius_map,
+ 'gt_sin_map': gt_sin_map,
+ 'gt_cos_map': gt_cos_map
+ }
+ gt = {}
+ for key, value in mapping.items():
+ gt[key] = value
+ if abs(downsample_ratio - 1.0) < 1e-2:
+ gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
+ else:
+ gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
+ gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
+ if key == 'gt_radius_map':
+ gt[key] = [item * downsample_ratio for item in gt[key]]
+ gt[key] = [item.to(device) for item in gt[key]]
+
+ scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
+ pred_sin_map = pred_sin_map * scale
+ pred_cos_map = pred_cos_map * scale
+
+ loss_text = self.balanced_bce_loss(
+ torch.sigmoid(pred_text_region), gt['gt_text_mask'][0],
+ gt['gt_mask'][0])
+
+ text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
+ loss_center_map = F.binary_cross_entropy(
+ torch.sigmoid(pred_center_region),
+ gt['gt_center_region_mask'][0].float(),
+ reduction='none')
+ if int(text_mask.sum()) > 0:
+ loss_center = torch.sum(
+ loss_center_map * text_mask) / torch.sum(text_mask)
+ else:
+ loss_center = torch.tensor(0.0, device=device)
+
+ center_mask = (gt['gt_center_region_mask'][0] *
+ gt['gt_mask'][0]).float()
+ if int(center_mask.sum()) > 0:
+ map_sz = pred_radius_map.size()
+ ones = torch.ones(map_sz, dtype=torch.float, device=device)
+ loss_radius = torch.sum(
+ F.smooth_l1_loss(
+ pred_radius_map / (gt['gt_radius_map'][0] + 1e-2),
+ ones,
+ reduction='none') * center_mask) / torch.sum(center_mask)
+ loss_sin = torch.sum(
+ F.smooth_l1_loss(
+ pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
+ center_mask) / torch.sum(center_mask)
+ loss_cos = torch.sum(
+ F.smooth_l1_loss(
+ pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
+ center_mask) / torch.sum(center_mask)
+ else:
+ loss_radius = torch.tensor(0.0, device=device)
+ loss_sin = torch.tensor(0.0, device=device)
+ loss_cos = torch.tensor(0.0, device=device)
+
+ results = dict(
+ loss_text=loss_text,
+ loss_center=loss_center,
+ loss_radius=loss_radius,
+ loss_sin=loss_sin,
+ loss_cos=loss_cos)
+
+ return results
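A toy check of bitmasks2tensor (not part of the diff; assumes numpy, torch, mmdet and this repo are installed): a single 4x6 mask is zero-padded on the right and bottom to the 8x8 feature size.

import numpy as np
from mmdet.core import BitmapMasks
from mmocr.models.textdet.losses import TextSnakeLoss

loss_fn = TextSnakeLoss()
masks = BitmapMasks(np.ones((1, 4, 6), dtype=np.uint8), height=4, width=6)
tensors = loss_fn.bitmasks2tensor([masks], (8, 8))
print(tensors[0].shape)  # torch.Size([1, 8, 8]); padded area is filled with 0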
diff --git a/mmocr/models/textdet/necks/__init__.py b/mmocr/models/textdet/necks/__init__.py
new file mode 100644
index 00000000..101bbbdb
--- /dev/null
+++ b/mmocr/models/textdet/necks/__init__.py
@@ -0,0 +1,6 @@
+from .fpem_ffm import FPEM_FFM
+from .fpn_cat import FPNC
+from .fpn_unet import FPN_UNET
+from .fpnf import FPNF
+
+__all__ = ['FPEM_FFM', 'FPNF', 'FPNC', 'FPN_UNET']
diff --git a/mmocr/models/textdet/necks/fpem_ffm.py b/mmocr/models/textdet/necks/fpem_ffm.py
new file mode 100644
index 00000000..3722649b
--- /dev/null
+++ b/mmocr/models/textdet/necks/fpem_ffm.py
@@ -0,0 +1,143 @@
+import torch.nn.functional as F
+from mmcv.cnn import xavier_init
+from torch import nn
+
+from mmdet.models.builder import NECKS
+
+
+class FPEM(nn.Module):
+ """FPN-like feature fusion module in PANet."""
+
+ def __init__(self, in_channels=128):
+ super().__init__()
+ self.up_add1 = SeparableConv2d(in_channels, in_channels, 1)
+ self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
+ self.up_add3 = SeparableConv2d(in_channels, in_channels, 1)
+ self.down_add1 = SeparableConv2d(in_channels, in_channels, 2)
+ self.down_add2 = SeparableConv2d(in_channels, in_channels, 2)
+ self.down_add3 = SeparableConv2d(in_channels, in_channels, 2)
+
+ def forward(self, c2, c3, c4, c5):
+ # upsample
+ c4 = self.up_add1(self._upsample_add(c5, c4))
+ c3 = self.up_add2(self._upsample_add(c4, c3))
+ c2 = self.up_add3(self._upsample_add(c3, c2))
+
+ # downsample
+ c3 = self.down_add1(self._upsample_add(c3, c2))
+ c4 = self.down_add2(self._upsample_add(c4, c3))
+ c5 = self.down_add3(self._upsample_add(c5, c4))
+ return c2, c3, c4, c5
+
+ def _upsample_add(self, x, y):
+ return F.interpolate(x, size=y.size()[2:]) + y
+
+
+class SeparableConv2d(nn.Module):
+
+ def __init__(self, in_channels, out_channels, stride=1):
+ super().__init__()
+
+ self.depthwise_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ groups=in_channels)
+ self.pointwise_conv = nn.Conv2d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1)
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.relu = nn.ReLU()
+
+ def forward(self, x):
+ x = self.depthwise_conv(x)
+ x = self.pointwise_conv(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ return x
+
+
+@NECKS.register_module()
+class FPEM_FFM(nn.Module):
+ """This code is from https://github.com/WenmuZhou/PAN.pytorch."""
+
+ def __init__(self,
+ in_channels,
+ conv_out=128,
+ fpem_repeat=2,
+ align_corners=False):
+ super().__init__()
+ # reduce layers
+ self.reduce_conv_c2 = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels[0],
+ out_channels=conv_out,
+ kernel_size=1), nn.BatchNorm2d(conv_out), nn.ReLU())
+ self.reduce_conv_c3 = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels[1],
+ out_channels=conv_out,
+ kernel_size=1), nn.BatchNorm2d(conv_out), nn.ReLU())
+ self.reduce_conv_c4 = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels[2],
+ out_channels=conv_out,
+ kernel_size=1), nn.BatchNorm2d(conv_out), nn.ReLU())
+ self.reduce_conv_c5 = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels[3],
+ out_channels=conv_out,
+ kernel_size=1), nn.BatchNorm2d(conv_out), nn.ReLU())
+ self.align_corners = align_corners
+ self.fpems = nn.ModuleList()
+ for _ in range(fpem_repeat):
+ self.fpems.append(FPEM(conv_out))
+
+ def init_weights(self):
+ """Initialize the weights of FPN module."""
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+
+ def forward(self, x):
+ c2, c3, c4, c5 = x
+ # reduce channel
+ c2 = self.reduce_conv_c2(c2)
+ c3 = self.reduce_conv_c3(c3)
+ c4 = self.reduce_conv_c4(c4)
+ c5 = self.reduce_conv_c5(c5)
+
+ # FPEM
+ for i, fpem in enumerate(self.fpems):
+ c2, c3, c4, c5 = fpem(c2, c3, c4, c5)
+ if i == 0:
+ c2_ffm = c2
+ c3_ffm = c3
+ c4_ffm = c4
+ c5_ffm = c5
+ else:
+ c2_ffm += c2
+ c3_ffm += c3
+ c4_ffm += c4
+ c5_ffm += c5
+
+ # FFM
+ c5 = F.interpolate(
+ c5_ffm,
+ c2_ffm.size()[-2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ c4 = F.interpolate(
+ c4_ffm,
+ c2_ffm.size()[-2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ c3 = F.interpolate(
+ c3_ffm,
+ c2_ffm.size()[-2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ outs = [c2_ffm, c3, c4, c5]
+ return tuple(outs)
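A shape sketch for FPEM_FFM (not part of the diff; assumes torch and this repo are installed; the feature sizes imitate strides 4/8/16/32 of a 320x320 input): all four outputs end up at the c2 resolution with conv_out channels.

import torch
from mmocr.models.textdet.necks import FPEM_FFM

neck = FPEM_FFM(in_channels=[256, 512, 1024, 2048])
feats = [
    torch.rand(1, 256, 80, 80),    # c2
    torch.rand(1, 512, 40, 40),    # c3
    torch.rand(1, 1024, 20, 20),   # c4
    torch.rand(1, 2048, 10, 10),   # c5
]
outs = neck(feats)
print([o.shape for o in outs])  # four tensors, each (1, 128, 80, 80)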
diff --git a/mmocr/models/textdet/necks/fpn_cat.py b/mmocr/models/textdet/necks/fpn_cat.py
new file mode 100644
index 00000000..0a3241c1
--- /dev/null
+++ b/mmocr/models/textdet/necks/fpn_cat.py
@@ -0,0 +1,128 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import auto_fp16
+
+from mmdet.models.builder import NECKS
+
+
+@NECKS.register_module()
+class FPNC(nn.Module):
+ """FPN-like fusion module in Real-time Scene Text Detection with
+ Differentiable Binarization.
+
+ This was partially adapted from https://github.com/MhLiao/DB and
+ https://github.com/WenmuZhou/DBNet.pytorch
+ """
+
+ def __init__(self,
+ in_channels,
+ lateral_channels=256,
+ out_channels=64,
+ bias_on_lateral=False,
+ bn_re_on_lateral=False,
+ bias_on_smooth=False,
+ bn_re_on_smooth=False,
+ conv_after_concat=False):
+ super(FPNC, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.lateral_channels = lateral_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.bn_re_on_lateral = bn_re_on_lateral
+ self.bn_re_on_smooth = bn_re_on_smooth
+ self.conv_after_concat = conv_after_concat
+ self.lateral_convs = nn.ModuleList()
+ self.smooth_convs = nn.ModuleList()
+ self.num_outs = self.num_ins
+
+ for i in range(self.num_ins):
+ norm_cfg = None
+ act_cfg = None
+ if self.bn_re_on_lateral:
+ norm_cfg = dict(type='BN')
+ act_cfg = dict(type='ReLU')
+ l_conv = ConvModule(
+ in_channels[i],
+ lateral_channels,
+ 1,
+ bias=bias_on_lateral,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ norm_cfg = None
+ act_cfg = None
+ if self.bn_re_on_smooth:
+ norm_cfg = dict(type='BN')
+ act_cfg = dict(type='ReLU')
+
+ smooth_conv = ConvModule(
+ lateral_channels,
+ out_channels,
+ 3,
+ bias=bias_on_smooth,
+ padding=1,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+
+ self.lateral_convs.append(l_conv)
+ self.smooth_convs.append(smooth_conv)
+ if self.conv_after_concat:
+ norm_cfg = dict(type='BN')
+ act_cfg = dict(type='ReLU')
+ self.out_conv = ConvModule(
+ out_channels * self.num_outs,
+ out_channels * self.num_outs,
+ 3,
+ padding=1,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ """Initialize the weights of FPN module."""
+ for m in self.lateral_convs:
+ m.init_weights()
+ for m in self.smooth_convs:
+ m.init_weights()
+ if self.conv_after_concat:
+ self.out_conv.init_weights()
+
+ @auto_fp16()
+ def forward(self, inputs):
+ assert len(inputs) == len(self.in_channels)
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ used_backbone_levels = len(laterals)
+ # build top-down path
+ for i in range(used_backbone_levels - 1, 0, -1):
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] += F.interpolate(
+ laterals[i], size=prev_shape, mode='nearest')
+ # build outputs
+ # part 1: from original levels
+ outs = [
+ self.smooth_convs[i](laterals[i])
+ for i in range(used_backbone_levels)
+ ]
+
+ for i in range(len(outs)):
+ scale = 2**i
+ outs[i] = F.interpolate(
+ outs[i], scale_factor=scale, mode='nearest')
+ out = torch.cat(outs, dim=1)
+
+ if self.conv_after_concat:
+ out = self.out_conv(out)
+
+ return out
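A shape sketch for FPNC (not part of the diff; assumes torch and this repo are installed; the input channels below imitate a ResNet18-style backbone): each of the four smoothed levels is upsampled to the bottom resolution and concatenated, giving 4 * out_channels channels.

import torch
from mmocr.models.textdet.necks import FPNC

neck = FPNC(in_channels=[64, 128, 256, 512])  # out_channels defaults to 64
feats = [
    torch.rand(1, 64, 160, 160),
    torch.rand(1, 128, 80, 80),
    torch.rand(1, 256, 40, 40),
    torch.rand(1, 512, 20, 20),
]
out = neck(feats)
print(out.shape)  # torch.Size([1, 256, 160, 160])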
diff --git a/mmocr/models/textdet/necks/fpn_unet.py b/mmocr/models/textdet/necks/fpn_unet.py
new file mode 100644
index 00000000..85c09085
--- /dev/null
+++ b/mmocr/models/textdet/necks/fpn_unet.py
@@ -0,0 +1,88 @@
+import torch
+import torch.nn.functional as F
+from mmcv.cnn import xavier_init
+from torch import nn
+
+from mmdet.models.builder import NECKS
+
+
+class UpBlock(nn.Module):
+ """Upsample block for DRRG and TextSnake."""
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+
+ assert isinstance(in_channels, int)
+ assert isinstance(out_channels, int)
+
+ self.conv1x1 = nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.conv3x3 = nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.deconv = nn.ConvTranspose2d(
+ out_channels, out_channels, kernel_size=4, stride=2, padding=1)
+
+ def forward(self, x):
+ x = F.relu(self.conv1x1(x))
+ x = F.relu(self.conv3x3(x))
+ x = self.deconv(x)
+ return x
+
+
+@NECKS.register_module()
+class FPN_UNET(nn.Module):
+ """The class for implementing DRRG and TextSnake U-Net-like FPN.
+
+ DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape
+ Text Detection [https://arxiv.org/abs/2003.07493].
+ TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes
+ [https://arxiv.org/abs/1807.01544].
+ """
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+
+ assert len(in_channels) == 4
+ assert isinstance(out_channels, int)
+
+ blocks_out_channels = [out_channels] + [
+ min(out_channels * 2**i, 256) for i in range(4)
+ ]
+ blocks_in_channels = [blocks_out_channels[1]] + [
+ in_channels[i] + blocks_out_channels[i + 2] for i in range(3)
+ ] + [in_channels[3]]
+
+ self.up4 = nn.ConvTranspose2d(
+ blocks_in_channels[4],
+ blocks_out_channels[4],
+ kernel_size=4,
+ stride=2,
+ padding=1)
+ self.up_block3 = UpBlock(blocks_in_channels[3], blocks_out_channels[3])
+ self.up_block2 = UpBlock(blocks_in_channels[2], blocks_out_channels[2])
+ self.up_block1 = UpBlock(blocks_in_channels[1], blocks_out_channels[1])
+ self.up_block0 = UpBlock(blocks_in_channels[0], blocks_out_channels[0])
+ self.init_weights()
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+ xavier_init(m, distribution='uniform')
+
+ def forward(self, x):
+ c2, c3, c4, c5 = x
+
+ x = F.relu(self.up4(c5))
+
+ x = torch.cat([x, c4], dim=1)
+ x = F.relu(self.up_block3(x))
+
+ x = torch.cat([x, c3], dim=1)
+ x = F.relu(self.up_block2(x))
+
+ x = torch.cat([x, c2], dim=1)
+ x = F.relu(self.up_block1(x))
+
+ x = self.up_block0(x)
+ # the output should be of the same height and width as backbone input
+ return x
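A shape sketch for FPN_UNET (not part of the diff; assumes torch and this repo are installed): starting from c5 at 1/32 scale, every deconvolution doubles the resolution, so the output comes back to the full input size with out_channels channels.

import torch
from mmocr.models.textdet.necks import FPN_UNET

neck = FPN_UNET(in_channels=[256, 512, 1024, 2048], out_channels=32)
feats = [
    torch.rand(1, 256, 16, 16),    # c2, 1/4 of a 64x64 input
    torch.rand(1, 512, 8, 8),      # c3
    torch.rand(1, 1024, 4, 4),     # c4
    torch.rand(1, 2048, 2, 2),     # c5
]
out = neck(feats)
print(out.shape)  # torch.Size([1, 32, 64, 64]): back at the input resolution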
diff --git a/mmocr/models/textdet/necks/fpnf.py b/mmocr/models/textdet/necks/fpnf.py
new file mode 100644
index 00000000..bdb69fea
--- /dev/null
+++ b/mmocr/models/textdet/necks/fpnf.py
@@ -0,0 +1,117 @@
+import torch
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, xavier_init
+from mmcv.runner import auto_fp16
+from torch import nn
+
+from mmdet.models.builder import NECKS
+
+
+@NECKS.register_module()
+class FPNF(nn.Module):
+ """FPN-like fusion module in Shape Robust Text Detection with Progressive
+ Scale Expansion Network."""
+
+ def __init__(
+ self,
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ fusion_type='concat', # 'concat' or 'add'
+ upsample_ratio=1):
+ super().__init__()
+ conv_cfg = None
+ norm_cfg = dict(type='BN')
+ act_cfg = dict(type='ReLU')
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+ self.backbone_end_level = len(in_channels)
+ for i in range(self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.lateral_convs.append(l_conv)
+
+ if i < self.backbone_end_level - 1:
+ fpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.fpn_convs.append(fpn_conv)
+
+ self.fusion_type = fusion_type
+
+ if self.fusion_type == 'concat':
+ feature_channels = 1024
+ elif self.fusion_type == 'add':
+ feature_channels = 256
+ else:
+ raise NotImplementedError
+
+ self.output_convs = ConvModule(
+ feature_channels,
+ out_channels,
+ 3,
+ padding=1,
+ conv_cfg=None,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.upsample_ratio = upsample_ratio
+
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+
+ @auto_fp16()
+ def forward(self, inputs):
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ # step 1: upsample to level i-1 size and add level i-1
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] += F.interpolate(
+ laterals[i], size=prev_shape, mode='nearest')
+ # step 2: smooth level i-1
+ laterals[i - 1] = self.fpn_convs[i - 1](laterals[i - 1])
+
+ # upsample all levels to the bottom level size before fusion
+ bottom_shape = laterals[0].shape[2:]
+ for i in range(1, used_backbone_levels):
+ laterals[i] = F.interpolate(
+ laterals[i], size=bottom_shape, mode='nearest')
+
+ if self.fusion_type == 'concat':
+ out = torch.cat(laterals, 1)
+ elif self.fusion_type == 'add':
+ out = laterals[0]
+ for i in range(1, used_backbone_levels):
+ out += laterals[i]
+ else:
+ raise NotImplementedError
+ out = self.output_convs(out)
+
+ return out
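A shape sketch for FPNF (not part of the diff; assumes torch and this repo are installed): with the default concat fusion the four 256-channel laterals are concatenated into the hard-coded 1024 feature channels and then projected back to out_channels.

import torch
from mmocr.models.textdet.necks import FPNF

neck = FPNF()  # defaults: in_channels=[256, 512, 1024, 2048], fusion_type='concat'
feats = [
    torch.rand(1, 256, 56, 56),
    torch.rand(1, 512, 28, 28),
    torch.rand(1, 1024, 14, 14),
    torch.rand(1, 2048, 7, 7),
]
out = neck(feats)
print(out.shape)  # torch.Size([1, 256, 56, 56])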
diff --git a/mmocr/models/textdet/postprocess/__init__.py b/mmocr/models/textdet/postprocess/__init__.py
new file mode 100644
index 00000000..acc72530
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/__init__.py
@@ -0,0 +1,3 @@
+from .wrapper import decode
+
+__all__ = ['decode']
diff --git a/mmocr/models/textdet/postprocess/include/clipper/clipper.cpp b/mmocr/models/textdet/postprocess/include/clipper/clipper.cpp
new file mode 100644
index 00000000..521c613c
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/include/clipper/clipper.cpp
@@ -0,0 +1,4622 @@
+/*******************************************************************************
+* *
+* Author : Angus Johnson *
+* Version : 6.4.0 *
+* Date : 2 July 2015 *
+* Website : http://www.angusj.com *
+* Copyright : Angus Johnson 2010-2015 *
+* *
+* License: *
+* Use, modification & distribution is subject to Boost Software License Ver 1. *
+* http://www.boost.org/LICENSE_1_0.txt *
+* *
+* Attributions: *
+* The code in this library is an extension of Bala Vatti's clipping algorithm: *
+* "A generic solution to polygon clipping" *
+* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. *
+* http://portal.acm.org/citation.cfm?id=129906 *
+* *
+* Computer graphics and geometric modeling: implementation and algorithms *
+* By Max K. Agoston *
+* Springer; 1 edition (January 4, 2005) *
+* http://books.google.com/books?q=vatti+clipping+agoston *
+* *
+* See also: *
+* "Polygon Offsetting by Computing Winding Numbers" *
+* Paper no. DETC2005-85513 pp. 565-575 *
+* ASME 2005 International Design Engineering Technical Conferences *
+* and Computers and Information in Engineering Conference (IDETC/CIE2005) *
+* September 24-28, 2005 , Long Beach, California, USA *
+* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf *
+* *
+*******************************************************************************/
+
+/*******************************************************************************
+* *
+* This is a translation of the Delphi Clipper library and the naming style *
+* used has retained a Delphi flavour. *
+* *
+*******************************************************************************/
+
+#include "clipper.hpp"
+#include <cmath>
+#include <vector>
+#include <algorithm>
+#include <stdexcept>
+#include <cstring>
+#include <cstdlib>
+#include <ostream>
+#include <functional>
+
+namespace ClipperLib {
+
+static double const pi = 3.141592653589793238;
+static double const two_pi = pi *2;
+static double const def_arc_tolerance = 0.25;
+
+enum Direction { dRightToLeft, dLeftToRight };
+
+static int const Unassigned = -1; //edge not currently 'owning' a solution
+static int const Skip = -2; //edge that would otherwise close a path
+
+#define HORIZONTAL (-1.0E+40)
+#define TOLERANCE (1.0e-20)
+#define NEAR_ZERO(val) (((val) > -TOLERANCE) && ((val) < TOLERANCE))
+
+struct TEdge {
+ IntPoint Bot;
+ IntPoint Curr; //current (updated for every new scanbeam)
+ IntPoint Top;
+ double Dx;
+ PolyType PolyTyp;
+ EdgeSide Side; //side only refers to current side of solution poly
+ int WindDelta; //1 or -1 depending on winding direction
+ int WindCnt;
+ int WindCnt2; //winding count of the opposite polytype
+ int OutIdx;
+ TEdge *Next;
+ TEdge *Prev;
+ TEdge *NextInLML;
+ TEdge *NextInAEL;
+ TEdge *PrevInAEL;
+ TEdge *NextInSEL;
+ TEdge *PrevInSEL;
+};
+
+struct IntersectNode {
+ TEdge *Edge1;
+ TEdge *Edge2;
+ IntPoint Pt;
+};
+
+struct LocalMinimum {
+ cInt Y;
+ TEdge *LeftBound;
+ TEdge *RightBound;
+};
+
+struct OutPt;
+
+//OutRec: contains a path in the clipping solution. Edges in the AEL will
+//carry a pointer to an OutRec when they are part of the clipping solution.
+struct OutRec {
+ int Idx;
+ bool IsHole;
+ bool IsOpen;
+ OutRec *FirstLeft; //see comments in clipper.pas
+ PolyNode *PolyNd;
+ OutPt *Pts;
+ OutPt *BottomPt;
+};
+
+struct OutPt {
+ int Idx;
+ IntPoint Pt;
+ OutPt *Next;
+ OutPt *Prev;
+};
+
+struct Join {
+ OutPt *OutPt1;
+ OutPt *OutPt2;
+ IntPoint OffPt;
+};
+
+struct LocMinSorter
+{
+ inline bool operator()(const LocalMinimum& locMin1, const LocalMinimum& locMin2)
+ {
+ return locMin2.Y < locMin1.Y;
+ }
+};
+
+//------------------------------------------------------------------------------
+//------------------------------------------------------------------------------
+
+inline cInt Round(double val)
+{
+ if ((val < 0)) return static_cast<cInt>(val - 0.5);
+ else return static_cast<cInt>(val + 0.5);
+}
+//------------------------------------------------------------------------------
+
+inline cInt Abs(cInt val)
+{
+ return val < 0 ? -val : val;
+}
+
+//------------------------------------------------------------------------------
+// PolyTree methods ...
+//------------------------------------------------------------------------------
+
+void PolyTree::Clear()
+{
+ for (PolyNodes::size_type i = 0; i < AllNodes.size(); ++i)
+ delete AllNodes[i];
+ AllNodes.resize(0);
+ Childs.resize(0);
+}
+//------------------------------------------------------------------------------
+
+PolyNode* PolyTree::GetFirst() const
+{
+ if (!Childs.empty())
+ return Childs[0];
+ else
+ return 0;
+}
+//------------------------------------------------------------------------------
+
+int PolyTree::Total() const
+{
+ int result = (int)AllNodes.size();
+ //with negative offsets, ignore the hidden outer polygon ...
+ if (result > 0 && Childs[0] != AllNodes[0]) result--;
+ return result;
+}
+
+//------------------------------------------------------------------------------
+// PolyNode methods ...
+//------------------------------------------------------------------------------
+
+PolyNode::PolyNode(): Childs(), Parent(0), Index(0), m_IsOpen(false)
+{
+}
+//------------------------------------------------------------------------------
+
+int PolyNode::ChildCount() const
+{
+ return (int)Childs.size();
+}
+//------------------------------------------------------------------------------
+
+void PolyNode::AddChild(PolyNode& child)
+{
+ unsigned cnt = (unsigned)Childs.size();
+ Childs.push_back(&child);
+ child.Parent = this;
+ child.Index = cnt;
+}
+//------------------------------------------------------------------------------
+
+PolyNode* PolyNode::GetNext() const
+{
+ if (!Childs.empty())
+ return Childs[0];
+ else
+ return GetNextSiblingUp();
+}
+//------------------------------------------------------------------------------
+
+PolyNode* PolyNode::GetNextSiblingUp() const
+{
+ if (!Parent) //protects against PolyTree.GetNextSiblingUp()
+ return 0;
+ else if (Index == Parent->Childs.size() - 1)
+ return Parent->GetNextSiblingUp();
+ else
+ return Parent->Childs[Index + 1];
+}
+//------------------------------------------------------------------------------
+
+bool PolyNode::IsHole() const
+{
+ bool result = true;
+ PolyNode* node = Parent;
+ while (node)
+ {
+ result = !result;
+ node = node->Parent;
+ }
+ return result;
+}
+//------------------------------------------------------------------------------
+
+bool PolyNode::IsOpen() const
+{
+ return m_IsOpen;
+}
+//------------------------------------------------------------------------------
+
+#ifndef use_int32
+
+//------------------------------------------------------------------------------
+// Int128 class (enables safe math on signed 64bit integers)
+// eg Int128 val1((long64)9223372036854775807); //ie 2^63 -1
+// Int128 val2((long64)9223372036854775807);
+// Int128 val3 = val1 * val2;
+// val3.AsString => "85070591730234615847396907784232501249" (8.5e+37)
+//------------------------------------------------------------------------------
+
+class Int128
+{
+ public:
+ ulong64 lo;
+ long64 hi;
+
+ Int128(long64 _lo = 0)
+ {
+ lo = (ulong64)_lo;
+ if (_lo < 0) hi = -1; else hi = 0;
+ }
+
+
+ Int128(const Int128 &val): lo(val.lo), hi(val.hi){}
+
+ Int128(const long64& _hi, const ulong64& _lo): lo(_lo), hi(_hi){}
+
+ Int128& operator = (const long64 &val)
+ {
+ lo = (ulong64)val;
+ if (val < 0) hi = -1; else hi = 0;
+ return *this;
+ }
+
+ bool operator == (const Int128 &val) const
+ {return (hi == val.hi && lo == val.lo);}
+
+ bool operator != (const Int128 &val) const
+ { return !(*this == val);}
+
+ bool operator > (const Int128 &val) const
+ {
+ if (hi != val.hi)
+ return hi > val.hi;
+ else
+ return lo > val.lo;
+ }
+
+ bool operator < (const Int128 &val) const
+ {
+ if (hi != val.hi)
+ return hi < val.hi;
+ else
+ return lo < val.lo;
+ }
+
+ bool operator >= (const Int128 &val) const
+ { return !(*this < val);}
+
+ bool operator <= (const Int128 &val) const
+ { return !(*this > val);}
+
+ Int128& operator += (const Int128 &rhs)
+ {
+ hi += rhs.hi;
+ lo += rhs.lo;
+ if (lo < rhs.lo) hi++;
+ return *this;
+ }
+
+ Int128 operator + (const Int128 &rhs) const
+ {
+ Int128 result(*this);
+ result+= rhs;
+ return result;
+ }
+
+ Int128& operator -= (const Int128 &rhs)
+ {
+ *this += -rhs;
+ return *this;
+ }
+
+ Int128 operator - (const Int128 &rhs) const
+ {
+ Int128 result(*this);
+ result -= rhs;
+ return result;
+ }
+
+ Int128 operator-() const //unary negation
+ {
+ if (lo == 0)
+ return Int128(-hi, 0);
+ else
+ return Int128(~hi, ~lo + 1);
+ }
+
+ operator double() const
+ {
+ const double shift64 = 18446744073709551616.0; //2^64
+ if (hi < 0)
+ {
+ if (lo == 0) return (double)hi * shift64;
+ else return -(double)(~lo + ~hi * shift64);
+ }
+ else
+ return (double)(lo + hi * shift64);
+ }
+
+};
+//------------------------------------------------------------------------------
+
+Int128 Int128Mul (long64 lhs, long64 rhs)
+{
+ bool negate = (lhs < 0) != (rhs < 0);
+
+ if (lhs < 0) lhs = -lhs;
+ ulong64 int1Hi = ulong64(lhs) >> 32;
+ ulong64 int1Lo = ulong64(lhs & 0xFFFFFFFF);
+
+ if (rhs < 0) rhs = -rhs;
+ ulong64 int2Hi = ulong64(rhs) >> 32;
+ ulong64 int2Lo = ulong64(rhs & 0xFFFFFFFF);
+
+ //nb: see comments in clipper.pas
+ ulong64 a = int1Hi * int2Hi;
+ ulong64 b = int1Lo * int2Lo;
+ ulong64 c = int1Hi * int2Lo + int1Lo * int2Hi;
+
+ Int128 tmp;
+ tmp.hi = long64(a + (c >> 32));
+ tmp.lo = long64(c << 32);
+ tmp.lo += long64(b);
+ if (tmp.lo < b) tmp.hi++;
+ if (negate) tmp = -tmp;
+ return tmp;
+};
+#endif
+
+//------------------------------------------------------------------------------
+// Miscellaneous global functions
+//------------------------------------------------------------------------------
+
+bool Orientation(const Path &poly)
+{
+ return Area(poly) >= 0;
+}
+//------------------------------------------------------------------------------
+
+double Area(const Path &poly)
+{
+ int size = (int)poly.size();
+ if (size < 3) return 0;
+
+ double a = 0;
+ for (int i = 0, j = size -1; i < size; ++i)
+ {
+ a += ((double)poly[j].X + poly[i].X) * ((double)poly[j].Y - poly[i].Y);
+ j = i;
+ }
+ return -a * 0.5;
+}
+//------------------------------------------------------------------------------
+
+double Area(const OutPt *op)
+{
+ const OutPt *startOp = op;
+ if (!op) return 0;
+ double a = 0;
+ do {
+ a += (double)(op->Prev->Pt.X + op->Pt.X) * (double)(op->Prev->Pt.Y - op->Pt.Y);
+ op = op->Next;
+ } while (op != startOp);
+ return a * 0.5;
+}
+//------------------------------------------------------------------------------
+
+double Area(const OutRec &outRec)
+{
+ return Area(outRec.Pts);
+}
+//------------------------------------------------------------------------------
+
+bool PointIsVertex(const IntPoint &Pt, OutPt *pp)
+{
+ OutPt *pp2 = pp;
+ do
+ {
+ if (pp2->Pt == Pt) return true;
+ pp2 = pp2->Next;
+ }
+ while (pp2 != pp);
+ return false;
+}
+//------------------------------------------------------------------------------
+
+//See "The Point in Polygon Problem for Arbitrary Polygons" by Hormann & Agathos
+//http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.88.5498&rep=rep1&type=pdf
+int PointInPolygon(const IntPoint &pt, const Path &path)
+{
+ //returns 0 if false, +1 if true, -1 if pt ON polygon boundary
+ int result = 0;
+ size_t cnt = path.size();
+ if (cnt < 3) return 0;
+ IntPoint ip = path[0];
+ for(size_t i = 1; i <= cnt; ++i)
+ {
+ IntPoint ipNext = (i == cnt ? path[0] : path[i]);
+ if (ipNext.Y == pt.Y)
+ {
+ if ((ipNext.X == pt.X) || (ip.Y == pt.Y &&
+ ((ipNext.X > pt.X) == (ip.X < pt.X)))) return -1;
+ }
+ if ((ip.Y < pt.Y) != (ipNext.Y < pt.Y))
+ {
+ if (ip.X >= pt.X)
+ {
+ if (ipNext.X > pt.X) result = 1 - result;
+ else
+ {
+ double d = (double)(ip.X - pt.X) * (ipNext.Y - pt.Y) -
+ (double)(ipNext.X - pt.X) * (ip.Y - pt.Y);
+ if (!d) return -1;
+ if ((d > 0) == (ipNext.Y > ip.Y)) result = 1 - result;
+ }
+ } else
+ {
+ if (ipNext.X > pt.X)
+ {
+ double d = (double)(ip.X - pt.X) * (ipNext.Y - pt.Y) -
+ (double)(ipNext.X - pt.X) * (ip.Y - pt.Y);
+ if (!d) return -1;
+ if ((d > 0) == (ipNext.Y > ip.Y)) result = 1 - result;
+ }
+ }
+ }
+ ip = ipNext;
+ }
+ return result;
+}
+//------------------------------------------------------------------------------
+
+int PointInPolygon (const IntPoint &pt, OutPt *op)
+{
+ //returns 0 if false, +1 if true, -1 if pt ON polygon boundary
+ int result = 0;
+ OutPt* startOp = op;
+ for(;;)
+ {
+ if (op->Next->Pt.Y == pt.Y)
+ {
+ if ((op->Next->Pt.X == pt.X) || (op->Pt.Y == pt.Y &&
+ ((op->Next->Pt.X > pt.X) == (op->Pt.X < pt.X)))) return -1;
+ }
+ if ((op->Pt.Y < pt.Y) != (op->Next->Pt.Y < pt.Y))
+ {
+ if (op->Pt.X >= pt.X)
+ {
+ if (op->Next->Pt.X > pt.X) result = 1 - result;
+ else
+ {
+ double d = (double)(op->Pt.X - pt.X) * (op->Next->Pt.Y - pt.Y) -
+ (double)(op->Next->Pt.X - pt.X) * (op->Pt.Y - pt.Y);
+ if (!d) return -1;
+ if ((d > 0) == (op->Next->Pt.Y > op->Pt.Y)) result = 1 - result;
+ }
+ } else
+ {
+ if (op->Next->Pt.X > pt.X)
+ {
+ double d = (double)(op->Pt.X - pt.X) * (op->Next->Pt.Y - pt.Y) -
+ (double)(op->Next->Pt.X - pt.X) * (op->Pt.Y - pt.Y);
+ if (!d) return -1;
+ if ((d > 0) == (op->Next->Pt.Y > op->Pt.Y)) result = 1 - result;
+ }
+ }
+ }
+ op = op->Next;
+ if (startOp == op) break;
+ }
+ return result;
+}
+//------------------------------------------------------------------------------
+
+bool Poly2ContainsPoly1(OutPt *OutPt1, OutPt *OutPt2)
+{
+ OutPt* op = OutPt1;
+ do
+ {
+ //nb: PointInPolygon returns 0 if false, +1 if true, -1 if pt on polygon
+ int res = PointInPolygon(op->Pt, OutPt2);
+ if (res >= 0) return res > 0;
+ op = op->Next;
+ }
+ while (op != OutPt1);
+ return true;
+}
+//----------------------------------------------------------------------
+
+bool SlopesEqual(const TEdge &e1, const TEdge &e2, bool UseFullInt64Range)
+{
+#ifndef use_int32
+ if (UseFullInt64Range)
+ return Int128Mul(e1.Top.Y - e1.Bot.Y, e2.Top.X - e2.Bot.X) ==
+ Int128Mul(e1.Top.X - e1.Bot.X, e2.Top.Y - e2.Bot.Y);
+ else
+#endif
+ return (e1.Top.Y - e1.Bot.Y) * (e2.Top.X - e2.Bot.X) ==
+ (e1.Top.X - e1.Bot.X) * (e2.Top.Y - e2.Bot.Y);
+}
+//------------------------------------------------------------------------------
+
+bool SlopesEqual(const IntPoint pt1, const IntPoint pt2,
+ const IntPoint pt3, bool UseFullInt64Range)
+{
+#ifndef use_int32
+ if (UseFullInt64Range)
+ return Int128Mul(pt1.Y-pt2.Y, pt2.X-pt3.X) == Int128Mul(pt1.X-pt2.X, pt2.Y-pt3.Y);
+ else
+#endif
+ return (pt1.Y-pt2.Y)*(pt2.X-pt3.X) == (pt1.X-pt2.X)*(pt2.Y-pt3.Y);
+}
+//------------------------------------------------------------------------------
+
+bool SlopesEqual(const IntPoint pt1, const IntPoint pt2,
+ const IntPoint pt3, const IntPoint pt4, bool UseFullInt64Range)
+{
+#ifndef use_int32
+ if (UseFullInt64Range)
+ return Int128Mul(pt1.Y-pt2.Y, pt3.X-pt4.X) == Int128Mul(pt1.X-pt2.X, pt3.Y-pt4.Y);
+ else
+#endif
+ return (pt1.Y-pt2.Y)*(pt3.X-pt4.X) == (pt1.X-pt2.X)*(pt3.Y-pt4.Y);
+}
+//------------------------------------------------------------------------------
+
+inline bool IsHorizontal(TEdge &e)
+{
+ return e.Dx == HORIZONTAL;
+}
+//------------------------------------------------------------------------------
+
+inline double GetDx(const IntPoint pt1, const IntPoint pt2)
+{
+ return (pt1.Y == pt2.Y) ?
+ HORIZONTAL : (double)(pt2.X - pt1.X) / (pt2.Y - pt1.Y);
+}
+//---------------------------------------------------------------------------
+
+inline void SetDx(TEdge &e)
+{
+ cInt dy = (e.Top.Y - e.Bot.Y);
+ if (dy == 0) e.Dx = HORIZONTAL;
+ else e.Dx = (double)(e.Top.X - e.Bot.X) / dy;
+}
+//---------------------------------------------------------------------------
+
+inline void SwapSides(TEdge &Edge1, TEdge &Edge2)
+{
+ EdgeSide Side = Edge1.Side;
+ Edge1.Side = Edge2.Side;
+ Edge2.Side = Side;
+}
+//------------------------------------------------------------------------------
+
+inline void SwapPolyIndexes(TEdge &Edge1, TEdge &Edge2)
+{
+ int OutIdx = Edge1.OutIdx;
+ Edge1.OutIdx = Edge2.OutIdx;
+ Edge2.OutIdx = OutIdx;
+}
+//------------------------------------------------------------------------------
+
+inline cInt TopX(TEdge &edge, const cInt currentY)
+{
+ return ( currentY == edge.Top.Y ) ?
+ edge.Top.X : edge.Bot.X + Round(edge.Dx *(currentY - edge.Bot.Y));
+}
+//------------------------------------------------------------------------------
+
+void IntersectPoint(TEdge &Edge1, TEdge &Edge2, IntPoint &ip)
+{
+#ifdef use_xyz
+ ip.Z = 0;
+#endif
+
+ double b1, b2;
+ if (Edge1.Dx == Edge2.Dx)
+ {
+ ip.Y = Edge1.Curr.Y;
+ ip.X = TopX(Edge1, ip.Y);
+ return;
+ }
+ else if (Edge1.Dx == 0)
+ {
+ ip.X = Edge1.Bot.X;
+ if (IsHorizontal(Edge2))
+ ip.Y = Edge2.Bot.Y;
+ else
+ {
+ b2 = Edge2.Bot.Y - (Edge2.Bot.X / Edge2.Dx);
+ ip.Y = Round(ip.X / Edge2.Dx + b2);
+ }
+ }
+ else if (Edge2.Dx == 0)
+ {
+ ip.X = Edge2.Bot.X;
+ if (IsHorizontal(Edge1))
+ ip.Y = Edge1.Bot.Y;
+ else
+ {
+ b1 = Edge1.Bot.Y - (Edge1.Bot.X / Edge1.Dx);
+ ip.Y = Round(ip.X / Edge1.Dx + b1);
+ }
+ }
+ else
+ {
+ b1 = Edge1.Bot.X - Edge1.Bot.Y * Edge1.Dx;
+ b2 = Edge2.Bot.X - Edge2.Bot.Y * Edge2.Dx;
+ double q = (b2-b1) / (Edge1.Dx - Edge2.Dx);
+ ip.Y = Round(q);
+ if (std::fabs(Edge1.Dx) < std::fabs(Edge2.Dx))
+ ip.X = Round(Edge1.Dx * q + b1);
+ else
+ ip.X = Round(Edge2.Dx * q + b2);
+ }
+
+ if (ip.Y < Edge1.Top.Y || ip.Y < Edge2.Top.Y)
+ {
+ if (Edge1.Top.Y > Edge2.Top.Y)
+ ip.Y = Edge1.Top.Y;
+ else
+ ip.Y = Edge2.Top.Y;
+ if (std::fabs(Edge1.Dx) < std::fabs(Edge2.Dx))
+ ip.X = TopX(Edge1, ip.Y);
+ else
+ ip.X = TopX(Edge2, ip.Y);
+ }
+ //finally, don't allow 'ip' to be BELOW curr.Y (ie bottom of scanbeam) ...
+ if (ip.Y > Edge1.Curr.Y)
+ {
+ ip.Y = Edge1.Curr.Y;
+ //use the more vertical edge to derive X ...
+ if (std::fabs(Edge1.Dx) > std::fabs(Edge2.Dx))
+ ip.X = TopX(Edge2, ip.Y); else
+ ip.X = TopX(Edge1, ip.Y);
+ }
+}
+//------------------------------------------------------------------------------
+
+void ReversePolyPtLinks(OutPt *pp)
+{
+ if (!pp) return;
+ OutPt *pp1, *pp2;
+ pp1 = pp;
+ do {
+ pp2 = pp1->Next;
+ pp1->Next = pp1->Prev;
+ pp1->Prev = pp2;
+ pp1 = pp2;
+ } while( pp1 != pp );
+}
+//------------------------------------------------------------------------------
+
+void DisposeOutPts(OutPt*& pp)
+{
+ if (pp == 0) return;
+ pp->Prev->Next = 0;
+ while( pp )
+ {
+ OutPt *tmpPp = pp;
+ pp = pp->Next;
+ delete tmpPp;
+ }
+}
+//------------------------------------------------------------------------------
+
+inline void InitEdge(TEdge* e, TEdge* eNext, TEdge* ePrev, const IntPoint& Pt)
+{
+ std::memset(e, 0, sizeof(TEdge));
+ e->Next = eNext;
+ e->Prev = ePrev;
+ e->Curr = Pt;
+ e->OutIdx = Unassigned;
+}
+//------------------------------------------------------------------------------
+
+void InitEdge2(TEdge& e, PolyType Pt)
+{
+ if (e.Curr.Y >= e.Next->Curr.Y)
+ {
+ e.Bot = e.Curr;
+ e.Top = e.Next->Curr;
+ } else
+ {
+ e.Top = e.Curr;
+ e.Bot = e.Next->Curr;
+ }
+ SetDx(e);
+ e.PolyTyp = Pt;
+}
+//------------------------------------------------------------------------------
+
+TEdge* RemoveEdge(TEdge* e)
+{
+ //removes e from double_linked_list (but without removing from memory)
+ e->Prev->Next = e->Next;
+ e->Next->Prev = e->Prev;
+ TEdge* result = e->Next;
+ e->Prev = 0; //flag as removed (see ClipperBase.Clear)
+ return result;
+}
+//------------------------------------------------------------------------------
+
+inline void ReverseHorizontal(TEdge &e)
+{
+ //swap horizontal edges' Top and Bottom x's so they follow the natural
+ //progression of the bounds - ie so their xbots will align with the
+ //adjoining lower edge. [Helpful in the ProcessHorizontal() method.]
+ std::swap(e.Top.X, e.Bot.X);
+#ifdef use_xyz
+ std::swap(e.Top.Z, e.Bot.Z);
+#endif
+}
+//------------------------------------------------------------------------------
+
+void SwapPoints(IntPoint &pt1, IntPoint &pt2)
+{
+ IntPoint tmp = pt1;
+ pt1 = pt2;
+ pt2 = tmp;
+}
+//------------------------------------------------------------------------------
+
+bool GetOverlapSegment(IntPoint pt1a, IntPoint pt1b, IntPoint pt2a,
+ IntPoint pt2b, IntPoint &pt1, IntPoint &pt2)
+{
+ //precondition: segments are Collinear.
+ if (Abs(pt1a.X - pt1b.X) > Abs(pt1a.Y - pt1b.Y))
+ {
+ if (pt1a.X > pt1b.X) SwapPoints(pt1a, pt1b);
+ if (pt2a.X > pt2b.X) SwapPoints(pt2a, pt2b);
+ if (pt1a.X > pt2a.X) pt1 = pt1a; else pt1 = pt2a;
+ if (pt1b.X < pt2b.X) pt2 = pt1b; else pt2 = pt2b;
+ return pt1.X < pt2.X;
+ } else
+ {
+ if (pt1a.Y < pt1b.Y) SwapPoints(pt1a, pt1b);
+ if (pt2a.Y < pt2b.Y) SwapPoints(pt2a, pt2b);
+ if (pt1a.Y < pt2a.Y) pt1 = pt1a; else pt1 = pt2a;
+ if (pt1b.Y > pt2b.Y) pt2 = pt1b; else pt2 = pt2b;
+ return pt1.Y > pt2.Y;
+ }
+}
+//------------------------------------------------------------------------------
+
+bool FirstIsBottomPt(const OutPt* btmPt1, const OutPt* btmPt2)
+{
+ OutPt *p = btmPt1->Prev;
+ while ((p->Pt == btmPt1->Pt) && (p != btmPt1)) p = p->Prev;
+ double dx1p = std::fabs(GetDx(btmPt1->Pt, p->Pt));
+ p = btmPt1->Next;
+ while ((p->Pt == btmPt1->Pt) && (p != btmPt1)) p = p->Next;
+ double dx1n = std::fabs(GetDx(btmPt1->Pt, p->Pt));
+
+ p = btmPt2->Prev;
+ while ((p->Pt == btmPt2->Pt) && (p != btmPt2)) p = p->Prev;
+ double dx2p = std::fabs(GetDx(btmPt2->Pt, p->Pt));
+ p = btmPt2->Next;
+ while ((p->Pt == btmPt2->Pt) && (p != btmPt2)) p = p->Next;
+ double dx2n = std::fabs(GetDx(btmPt2->Pt, p->Pt));
+
+ if (std::max(dx1p, dx1n) == std::max(dx2p, dx2n) &&
+ std::min(dx1p, dx1n) == std::min(dx2p, dx2n))
+ return Area(btmPt1) > 0; //if otherwise identical use orientation
+ else
+ return (dx1p >= dx2p && dx1p >= dx2n) || (dx1n >= dx2p && dx1n >= dx2n);
+}
+//------------------------------------------------------------------------------
+
+OutPt* GetBottomPt(OutPt *pp)
+{
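+ //Returns the lowest (largest Y) then left-most vertex of the output point ring;
+ //ties between duplicate vertices are resolved with FirstIsBottomPt().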
+ OutPt* dups = 0;
+ OutPt* p = pp->Next;
+ while (p != pp)
+ {
+ if (p->Pt.Y > pp->Pt.Y)
+ {
+ pp = p;
+ dups = 0;
+ }
+ else if (p->Pt.Y == pp->Pt.Y && p->Pt.X <= pp->Pt.X)
+ {
+ if (p->Pt.X < pp->Pt.X)
+ {
+ dups = 0;
+ pp = p;
+ } else
+ {
+ if (p->Next != pp && p->Prev != pp) dups = p;
+ }
+ }
+ p = p->Next;
+ }
+ if (dups)
+ {
+ //there appears to be at least 2 vertices at BottomPt so ...
+ while (dups != p)
+ {
+ if (!FirstIsBottomPt(p, dups)) pp = dups;
+ dups = dups->Next;
+ while (dups->Pt != pp->Pt) dups = dups->Next;
+ }
+ }
+ return pp;
+}
+//------------------------------------------------------------------------------
+
+bool Pt2IsBetweenPt1AndPt3(const IntPoint pt1,
+ const IntPoint pt2, const IntPoint pt3)
+{
+ if ((pt1 == pt3) || (pt1 == pt2) || (pt3 == pt2))
+ return false;
+ else if (pt1.X != pt3.X)
+ return (pt2.X > pt1.X) == (pt2.X < pt3.X);
+ else
+ return (pt2.Y > pt1.Y) == (pt2.Y < pt3.Y);
+}
+//------------------------------------------------------------------------------
+
+bool HorzSegmentsOverlap(cInt seg1a, cInt seg1b, cInt seg2a, cInt seg2b)
+{
+ if (seg1a > seg1b) std::swap(seg1a, seg1b);
+ if (seg2a > seg2b) std::swap(seg2a, seg2b);
+ return (seg1a < seg2b) && (seg2a < seg1b);
+}
+
+//------------------------------------------------------------------------------
+// ClipperBase class methods ...
+//------------------------------------------------------------------------------
+
+ClipperBase::ClipperBase() //constructor
+{
+ m_CurrentLM = m_MinimaList.begin(); //begin() == end() here
+ m_UseFullRange = false;
+}
+//------------------------------------------------------------------------------
+
+ClipperBase::~ClipperBase() //destructor
+{
+ Clear();
+}
+//------------------------------------------------------------------------------
+
+void RangeTest(const IntPoint& Pt, bool& useFullRange)
+{
+ if (useFullRange)
+ {
+ if (Pt.X > hiRange || Pt.Y > hiRange || -Pt.X > hiRange || -Pt.Y > hiRange)
+ throw clipperException("Coordinate outside allowed range");
+ }
+ else if (Pt.X > loRange|| Pt.Y > loRange || -Pt.X > loRange || -Pt.Y > loRange)
+ {
+ useFullRange = true;
+ RangeTest(Pt, useFullRange);
+ }
+}
+//------------------------------------------------------------------------------
+
+TEdge* FindNextLocMin(TEdge* E)
+{
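+ //Advances around the linked edge ring until E and E->Prev share the same Bot
+ //vertex (ie a local minimum), taking care that runs of horizontal edges at the
+ //minimum are handled so the returned edge starts the ascending bound.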
+ for (;;)
+ {
+ while (E->Bot != E->Prev->Bot || E->Curr == E->Top) E = E->Next;
+ if (!IsHorizontal(*E) && !IsHorizontal(*E->Prev)) break;
+ while (IsHorizontal(*E->Prev)) E = E->Prev;
+ TEdge* E2 = E;
+ while (IsHorizontal(*E)) E = E->Next;
+ if (E->Top.Y == E->Prev->Bot.Y) continue; //ie just an intermediate horz.
+ if (E2->Prev->Bot.X < E->Bot.X) E = E2;
+ break;
+ }
+ return E;
+}
+//------------------------------------------------------------------------------
+
+TEdge* ClipperBase::ProcessBound(TEdge* E, bool NextIsForward)
+{
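+ //Links the edges of one bound together via NextInLML, starting at a local
+ //minimum, and returns the first edge beyond the bound's top. When a Skip edge
+ //is met, the remaining edges are queued as a new LocalMinimum and handled by a
+ //recursive call.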
+ TEdge *Result = E;
+ TEdge *Horz = 0;
+
+ if (E->OutIdx == Skip)
+ {
+ //if edges still remain in the current bound beyond the skip edge then
+ //create another LocMin and call ProcessBound once more
+ if (NextIsForward)
+ {
+ while (E->Top.Y == E->Next->Bot.Y) E = E->Next;
+ //don't include top horizontals when parsing a bound a second time,
+ //they will be contained in the opposite bound ...
+ while (E != Result && IsHorizontal(*E)) E = E->Prev;
+ }
+ else
+ {
+ while (E->Top.Y == E->Prev->Bot.Y) E = E->Prev;
+ while (E != Result && IsHorizontal(*E)) E = E->Next;
+ }
+
+ if (E == Result)
+ {
+ if (NextIsForward) Result = E->Next;
+ else Result = E->Prev;
+ }
+ else
+ {
+ //there are more edges in the bound beyond result starting with E
+ if (NextIsForward)
+ E = Result->Next;
+ else
+ E = Result->Prev;
+ MinimaList::value_type locMin;
+ locMin.Y = E->Bot.Y;
+ locMin.LeftBound = 0;
+ locMin.RightBound = E;
+ E->WindDelta = 0;
+ Result = ProcessBound(E, NextIsForward);
+ m_MinimaList.push_back(locMin);
+ }
+ return Result;
+ }
+
+ TEdge *EStart;
+
+ if (IsHorizontal(*E))
+ {
+ //We need to be careful with open paths because this may not be a
+ //true local minimum (ie E may be following a skip edge).
+ //Also, consecutive horz. edges may start heading left before going right.
+ if (NextIsForward)
+ EStart = E->Prev;
+ else
+ EStart = E->Next;
+ if (IsHorizontal(*EStart)) //ie an adjoining horizontal skip edge
+ {
+ if (EStart->Bot.X != E->Bot.X && EStart->Top.X != E->Bot.X)
+ ReverseHorizontal(*E);
+ }
+ else if (EStart->Bot.X != E->Bot.X)
+ ReverseHorizontal(*E);
+ }
+
+ EStart = E;
+ if (NextIsForward)
+ {
+ while (Result->Top.Y == Result->Next->Bot.Y && Result->Next->OutIdx != Skip)
+ Result = Result->Next;
+ if (IsHorizontal(*Result) && Result->Next->OutIdx != Skip)
+ {
+ //nb: at the top of a bound, horizontals are added to the bound
+ //only when the preceding edge attaches to the horizontal's left vertex
+ //unless a Skip edge is encountered, in which case that edge becomes the top divide
+ Horz = Result;
+ while (IsHorizontal(*Horz->Prev)) Horz = Horz->Prev;
+ if (Horz->Prev->Top.X > Result->Next->Top.X) Result = Horz->Prev;
+ }
+ while (E != Result)
+ {
+ E->NextInLML = E->Next;
+ if (IsHorizontal(*E) && E != EStart &&
+ E->Bot.X != E->Prev->Top.X) ReverseHorizontal(*E);
+ E = E->Next;
+ }
+ if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Prev->Top.X)
+ ReverseHorizontal(*E);
+ Result = Result->Next; //move to the edge just beyond current bound
+ } else
+ {
+ while (Result->Top.Y == Result->Prev->Bot.Y && Result->Prev->OutIdx != Skip)
+ Result = Result->Prev;
+ if (IsHorizontal(*Result) && Result->Prev->OutIdx != Skip)
+ {
+ Horz = Result;
+ while (IsHorizontal(*Horz->Next)) Horz = Horz->Next;
+ if (Horz->Next->Top.X >= Result->Prev->Top.X) Result = Horz->Next;
+ }
+
+ while (E != Result)
+ {
+ E->NextInLML = E->Prev;
+ if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Next->Top.X)
+ ReverseHorizontal(*E);
+ E = E->Prev;
+ }
+ if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Next->Top.X)
+ ReverseHorizontal(*E);
+ Result = Result->Prev; //move to the edge just beyond current bound
+ }
+
+ return Result;
+}
+//------------------------------------------------------------------------------
+
+bool ClipperBase::AddPath(const Path &pg, PolyType PolyTyp, bool Closed)
+{
+#ifdef use_lines
+ if (!Closed && PolyTyp == ptClip)
+ throw clipperException("AddPath: Open paths must be subject.");
+#else
+ if (!Closed)
+ throw clipperException("AddPath: Open paths have been disabled.");
+#endif
+
+ int highI = (int)pg.size() -1;
+ if (Closed) while (highI > 0 && (pg[highI] == pg[0])) --highI;
+ while (highI > 0 && (pg[highI] == pg[highI -1])) --highI;
+ if ((Closed && highI < 2) || (!Closed && highI < 1)) return false;
+
+ //create a new edge array ...
+ TEdge *edges = new TEdge [highI +1];
+
+ bool IsFlat = true;
+ //1. Basic (first) edge initialization ...
+ try
+ {
+ edges[1].Curr = pg[1];
+ RangeTest(pg[0], m_UseFullRange);
+ RangeTest(pg[highI], m_UseFullRange);
+ InitEdge(&edges[0], &edges[1], &edges[highI], pg[0]);
+ InitEdge(&edges[highI], &edges[0], &edges[highI-1], pg[highI]);
+ for (int i = highI - 1; i >= 1; --i)
+ {
+ RangeTest(pg[i], m_UseFullRange);
+ InitEdge(&edges[i], &edges[i+1], &edges[i-1], pg[i]);
+ }
+ }
+ catch(...)
+ {
+ delete [] edges;
+ throw; //range test fails
+ }
+ TEdge *eStart = &edges[0];
+
+ //2. Remove duplicate vertices, and (when closed) collinear edges ...
+ TEdge *E = eStart, *eLoopStop = eStart;
+ for (;;)
+ {
+ //nb: allows matching start and end points when not Closed ...
+ if (E->Curr == E->Next->Curr && (Closed || E->Next != eStart))
+ {
+ if (E == E->Next) break;
+ if (E == eStart) eStart = E->Next;
+ E = RemoveEdge(E);
+ eLoopStop = E;
+ continue;
+ }
+ if (E->Prev == E->Next)
+ break; //only two vertices
+ else if (Closed &&
+ SlopesEqual(E->Prev->Curr, E->Curr, E->Next->Curr, m_UseFullRange) &&
+ (!m_PreserveCollinear ||
+ !Pt2IsBetweenPt1AndPt3(E->Prev->Curr, E->Curr, E->Next->Curr)))
+ {
+ //Collinear edges are allowed for open paths but in closed paths
+ //the default is to merge adjacent collinear edges into a single edge.
+ //However, if the PreserveCollinear property is enabled, only overlapping
+ //collinear edges (ie spikes) will be removed from closed paths.
+ if (E == eStart) eStart = E->Next;
+ E = RemoveEdge(E);
+ E = E->Prev;
+ eLoopStop = E;
+ continue;
+ }
+ E = E->Next;
+ if ((E == eLoopStop) || (!Closed && E->Next == eStart)) break;
+ }
+
+ if ((!Closed && (E == E->Next)) || (Closed && (E->Prev == E->Next)))
+ {
+ delete [] edges;
+ return false;
+ }
+
+ if (!Closed)
+ {
+ m_HasOpenPaths = true;
+ eStart->Prev->OutIdx = Skip;
+ }
+
+ //3. Do second stage of edge initialization ...
+ E = eStart;
+ do
+ {
+ InitEdge2(*E, PolyTyp);
+ E = E->Next;
+ if (IsFlat && E->Curr.Y != eStart->Curr.Y) IsFlat = false;
+ }
+ while (E != eStart);
+
+ //4. Finally, add edge bounds to LocalMinima list ...
+
+ //Totally flat paths must be handled differently when adding them
+ //to LocalMinima list to avoid endless loops etc ...
+ if (IsFlat)
+ {
+ if (Closed)
+ {
+ delete [] edges;
+ return false;
+ }
+ E->Prev->OutIdx = Skip;
+ MinimaList::value_type locMin;
+ locMin.Y = E->Bot.Y;
+ locMin.LeftBound = 0;
+ locMin.RightBound = E;
+ locMin.RightBound->Side = esRight;
+ locMin.RightBound->WindDelta = 0;
+ for (;;)
+ {
+ if (E->Bot.X != E->Prev->Top.X) ReverseHorizontal(*E);
+ if (E->Next->OutIdx == Skip) break;
+ E->NextInLML = E->Next;
+ E = E->Next;
+ }
+ m_MinimaList.push_back(locMin);
+ m_edges.push_back(edges);
+ return true;
+ }
+
+ m_edges.push_back(edges);
+ bool leftBoundIsForward;
+ TEdge* EMin = 0;
+
+ //workaround to avoid an endless loop in the while loop below when
+ //open paths have matching start and end points ...
+ if (E->Prev->Bot == E->Prev->Top) E = E->Next;
+
+ for (;;)
+ {
+ E = FindNextLocMin(E);
+ if (E == EMin) break;
+ else if (!EMin) EMin = E;
+
+ //E and E.Prev now share a local minima (left aligned if horizontal).
+ //Compare their slopes to find which starts which bound ...
+ MinimaList::value_type locMin;
+ locMin.Y = E->Bot.Y;
+ if (E->Dx < E->Prev->Dx)
+ {
+ locMin.LeftBound = E->Prev;
+ locMin.RightBound = E;
+ leftBoundIsForward = false; //Q.nextInLML = Q.prev
+ } else
+ {
+ locMin.LeftBound = E;
+ locMin.RightBound = E->Prev;
+ leftBoundIsForward = true; //Q.nextInLML = Q.next
+ }
+
+ if (!Closed) locMin.LeftBound->WindDelta = 0;
+ else if (locMin.LeftBound->Next == locMin.RightBound)
+ locMin.LeftBound->WindDelta = -1;
+ else locMin.LeftBound->WindDelta = 1;
+ locMin.RightBound->WindDelta = -locMin.LeftBound->WindDelta;
+
+ E = ProcessBound(locMin.LeftBound, leftBoundIsForward);
+ if (E->OutIdx == Skip) E = ProcessBound(E, leftBoundIsForward);
+
+ TEdge* E2 = ProcessBound(locMin.RightBound, !leftBoundIsForward);
+ if (E2->OutIdx == Skip) E2 = ProcessBound(E2, !leftBoundIsForward);
+
+ if (locMin.LeftBound->OutIdx == Skip)
+ locMin.LeftBound = 0;
+ else if (locMin.RightBound->OutIdx == Skip)
+ locMin.RightBound = 0;
+ m_MinimaList.push_back(locMin);
+ if (!leftBoundIsForward) E = E2;
+ }
+ return true;
+}
+//------------------------------------------------------------------------------
+
+bool ClipperBase::AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed)
+{
+ bool result = false;
+ for (Paths::size_type i = 0; i < ppg.size(); ++i)
+ if (AddPath(ppg[i], PolyTyp, Closed)) result = true;
+ return result;
+}
+//------------------------------------------------------------------------------
+
+void ClipperBase::Clear()
+{
+ DisposeLocalMinimaList();
+ for (EdgeList::size_type i = 0; i < m_edges.size(); ++i)
+ {
+ TEdge* edges = m_edges[i];
+ delete [] edges;
+ }
+ m_edges.clear();
+ m_UseFullRange = false;
+ m_HasOpenPaths = false;
+}
+//------------------------------------------------------------------------------
+
+void ClipperBase::Reset()
+{
+ m_CurrentLM = m_MinimaList.begin();
+ if (m_CurrentLM == m_MinimaList.end()) return; //ie nothing to process
+ std::sort(m_MinimaList.begin(), m_MinimaList.end(), LocMinSorter());
+
+ m_Scanbeam = ScanbeamList(); //clears/resets priority_queue
+ //reset all edges ...
+ for (MinimaList::iterator lm = m_MinimaList.begin(); lm != m_MinimaList.end(); ++lm)
+ {
+ InsertScanbeam(lm->Y);
+ TEdge* e = lm->LeftBound;
+ if (e)
+ {
+ e->Curr = e->Bot;
+ e->Side = esLeft;
+ e->OutIdx = Unassigned;
+ }
+
+ e = lm->RightBound;
+ if (e)
+ {
+ e->Curr = e->Bot;
+ e->Side = esRight;
+ e->OutIdx = Unassigned;
+ }
+ }
+ m_ActiveEdges = 0;
+ m_CurrentLM = m_MinimaList.begin();
+}
+//------------------------------------------------------------------------------
+
+void ClipperBase::DisposeLocalMinimaList()
+{
+ m_MinimaList.clear();
+ m_CurrentLM = m_MinimaList.begin();
+}
+//------------------------------------------------------------------------------
+
+bool ClipperBase::PopLocalMinima(cInt Y, const LocalMinimum *&locMin)
+{
+ if (m_CurrentLM == m_MinimaList.end() || (*m_CurrentLM).Y != Y) return false;
+ locMin = &(*m_CurrentLM);
+ ++m_CurrentLM;
+ return true;
+}
+//------------------------------------------------------------------------------
+
+IntRect ClipperBase::GetBounds()
+{
+ IntRect result;
+ MinimaList::iterator lm = m_MinimaList.begin();
+ if (lm == m_MinimaList.end())
+ {
+ result.left = result.top = result.right = result.bottom = 0;
+ return result;
+ }
+ result.left = lm->LeftBound->Bot.X;
+ result.top = lm->LeftBound->Bot.Y;
+ result.right = lm->LeftBound->Bot.X;
+ result.bottom = lm->LeftBound->Bot.Y;
+ while (lm != m_MinimaList.end())
+ {
+ //todo - needs fixing for open paths
+ result.bottom = std::max(result.bottom, lm->LeftBound->Bot.Y);
+ TEdge* e = lm->LeftBound;
+ for (;;) {
+ TEdge* bottomE = e;
+ while (e->NextInLML)
+ {
+ if (e->Bot.X < result.left) result.left = e->Bot.X;
+ if (e->Bot.X > result.right) result.right = e->Bot.X;
+ e = e->NextInLML;
+ }
+ result.left = std::min(result.left, e->Bot.X);
+ result.right = std::max(result.right, e->Bot.X);
+ result.left = std::min(result.left, e->Top.X);
+ result.right = std::max(result.right, e->Top.X);
+ result.top = std::min(result.top, e->Top.Y);
+ if (bottomE == lm->LeftBound) e = lm->RightBound;
+ else break;
+ }
+ ++lm;
+ }
+ return result;
+}
+//------------------------------------------------------------------------------
+
+void ClipperBase::InsertScanbeam(const cInt Y)
+{
+ m_Scanbeam.push(Y);
+}
+//------------------------------------------------------------------------------
+
+bool ClipperBase::PopScanbeam(cInt &Y)
+{
+ if (m_Scanbeam.empty()) return false;
+ Y = m_Scanbeam.top();
+ m_Scanbeam.pop();
+ while (!m_Scanbeam.empty() && Y == m_Scanbeam.top()) { m_Scanbeam.pop(); } // Pop duplicates.
+ return true;
+}
+//------------------------------------------------------------------------------
+
+void ClipperBase::DisposeAllOutRecs(){
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i)
+ DisposeOutRec(i);
+ m_PolyOuts.clear();
+}
+//------------------------------------------------------------------------------
+
+void ClipperBase::DisposeOutRec(PolyOutList::size_type index)
+{
+ OutRec *outRec = m_PolyOuts[index];
+ if (outRec->Pts) DisposeOutPts(outRec->Pts);
+ delete outRec;
+ m_PolyOuts[index] = 0;
+}
+//------------------------------------------------------------------------------
+
+void ClipperBase::DeleteFromAEL(TEdge *e)
+{
+ TEdge* AelPrev = e->PrevInAEL;
+ TEdge* AelNext = e->NextInAEL;
+ if (!AelPrev && !AelNext && (e != m_ActiveEdges)) return; //already deleted
+ if (AelPrev) AelPrev->NextInAEL = AelNext;
+ else m_ActiveEdges = AelNext;
+ if (AelNext) AelNext->PrevInAEL = AelPrev;
+ e->NextInAEL = 0;
+ e->PrevInAEL = 0;
+}
+//------------------------------------------------------------------------------
+
+OutRec* ClipperBase::CreateOutRec()
+{
+ OutRec* result = new OutRec;
+ result->IsHole = false;
+ result->IsOpen = false;
+ result->FirstLeft = 0;
+ result->Pts = 0;
+ result->BottomPt = 0;
+ result->PolyNd = 0;
+ m_PolyOuts.push_back(result);
+ result->Idx = (int)m_PolyOuts.size() - 1;
+ return result;
+}
+//------------------------------------------------------------------------------
+
+void ClipperBase::SwapPositionsInAEL(TEdge *Edge1, TEdge *Edge2)
+{
+ //check that one or other edge hasn't already been removed from AEL ...
+ if (Edge1->NextInAEL == Edge1->PrevInAEL ||
+ Edge2->NextInAEL == Edge2->PrevInAEL) return;
+
+ if (Edge1->NextInAEL == Edge2)
+ {
+ TEdge* Next = Edge2->NextInAEL;
+ if (Next) Next->PrevInAEL = Edge1;
+ TEdge* Prev = Edge1->PrevInAEL;
+ if (Prev) Prev->NextInAEL = Edge2;
+ Edge2->PrevInAEL = Prev;
+ Edge2->NextInAEL = Edge1;
+ Edge1->PrevInAEL = Edge2;
+ Edge1->NextInAEL = Next;
+ }
+ else if (Edge2->NextInAEL == Edge1)
+ {
+ TEdge* Next = Edge1->NextInAEL;
+ if (Next) Next->PrevInAEL = Edge2;
+ TEdge* Prev = Edge2->PrevInAEL;
+ if (Prev) Prev->NextInAEL = Edge1;
+ Edge1->PrevInAEL = Prev;
+ Edge1->NextInAEL = Edge2;
+ Edge2->PrevInAEL = Edge1;
+ Edge2->NextInAEL = Next;
+ }
+ else
+ {
+ TEdge* Next = Edge1->NextInAEL;
+ TEdge* Prev = Edge1->PrevInAEL;
+ Edge1->NextInAEL = Edge2->NextInAEL;
+ if (Edge1->NextInAEL) Edge1->NextInAEL->PrevInAEL = Edge1;
+ Edge1->PrevInAEL = Edge2->PrevInAEL;
+ if (Edge1->PrevInAEL) Edge1->PrevInAEL->NextInAEL = Edge1;
+ Edge2->NextInAEL = Next;
+ if (Edge2->NextInAEL) Edge2->NextInAEL->PrevInAEL = Edge2;
+ Edge2->PrevInAEL = Prev;
+ if (Edge2->PrevInAEL) Edge2->PrevInAEL->NextInAEL = Edge2;
+ }
+
+ if (!Edge1->PrevInAEL) m_ActiveEdges = Edge1;
+ else if (!Edge2->PrevInAEL) m_ActiveEdges = Edge2;
+}
+//------------------------------------------------------------------------------
+
+void ClipperBase::UpdateEdgeIntoAEL(TEdge *&e)
+{
+ if (!e->NextInLML)
+ throw clipperException("UpdateEdgeIntoAEL: invalid call");
+
+ e->NextInLML->OutIdx = e->OutIdx;
+ TEdge* AelPrev = e->PrevInAEL;
+ TEdge* AelNext = e->NextInAEL;
+ if (AelPrev) AelPrev->NextInAEL = e->NextInLML;
+ else m_ActiveEdges = e->NextInLML;
+ if (AelNext) AelNext->PrevInAEL = e->NextInLML;
+ e->NextInLML->Side = e->Side;
+ e->NextInLML->WindDelta = e->WindDelta;
+ e->NextInLML->WindCnt = e->WindCnt;
+ e->NextInLML->WindCnt2 = e->WindCnt2;
+ e = e->NextInLML;
+ e->Curr = e->Bot;
+ e->PrevInAEL = AelPrev;
+ e->NextInAEL = AelNext;
+ if (!IsHorizontal(*e)) InsertScanbeam(e->Top.Y);
+}
+//------------------------------------------------------------------------------
+
+bool ClipperBase::LocalMinimaPending()
+{
+ return (m_CurrentLM != m_MinimaList.end());
+}
+
+//------------------------------------------------------------------------------
+// TClipper methods ...
+//------------------------------------------------------------------------------
+
+Clipper::Clipper(int initOptions) : ClipperBase() //constructor
+{
+ m_ExecuteLocked = false;
+ m_UseFullRange = false;
+ m_ReverseOutput = ((initOptions & ioReverseSolution) != 0);
+ m_StrictSimple = ((initOptions & ioStrictlySimple) != 0);
+ m_PreserveCollinear = ((initOptions & ioPreserveCollinear) != 0);
+ m_HasOpenPaths = false;
+#ifdef use_xyz
+ m_ZFill = 0;
+#endif
+}
+//------------------------------------------------------------------------------
+
+#ifdef use_xyz
+void Clipper::ZFillFunction(ZFillCallback zFillFunc)
+{
+ m_ZFill = zFillFunc;
+}
+//------------------------------------------------------------------------------
+#endif
+
+bool Clipper::Execute(ClipType clipType, Paths &solution, PolyFillType fillType)
+{
+ return Execute(clipType, solution, fillType, fillType);
+}
+//------------------------------------------------------------------------------
+
+bool Clipper::Execute(ClipType clipType, PolyTree &polytree, PolyFillType fillType)
+{
+ return Execute(clipType, polytree, fillType, fillType);
+}
+//------------------------------------------------------------------------------
+
+bool Clipper::Execute(ClipType clipType, Paths &solution,
+ PolyFillType subjFillType, PolyFillType clipFillType)
+{
+ if( m_ExecuteLocked ) return false;
+ if (m_HasOpenPaths)
+ throw clipperException("Error: PolyTree struct is needed for open path clipping.");
+ m_ExecuteLocked = true;
+ solution.resize(0);
+ m_SubjFillType = subjFillType;
+ m_ClipFillType = clipFillType;
+ m_ClipType = clipType;
+ m_UsingPolyTree = false;
+ bool succeeded = ExecuteInternal();
+ if (succeeded) BuildResult(solution);
+ DisposeAllOutRecs();
+ m_ExecuteLocked = false;
+ return succeeded;
+}
+//------------------------------------------------------------------------------
+
+bool Clipper::Execute(ClipType clipType, PolyTree& polytree,
+ PolyFillType subjFillType, PolyFillType clipFillType)
+{
+ if( m_ExecuteLocked ) return false;
+ m_ExecuteLocked = true;
+ m_SubjFillType = subjFillType;
+ m_ClipFillType = clipFillType;
+ m_ClipType = clipType;
+ m_UsingPolyTree = true;
+ bool succeeded = ExecuteInternal();
+ if (succeeded) BuildResult2(polytree);
+ DisposeAllOutRecs();
+ m_ExecuteLocked = false;
+ return succeeded;
+}
+//------------------------------------------------------------------------------
+
+void Clipper::FixHoleLinkage(OutRec &outrec)
+{
+ //skip OutRecs that (a) contain outermost polygons or
+ //(b) already have the correct owner/child linkage ...
+ if (!outrec.FirstLeft ||
+ (outrec.IsHole != outrec.FirstLeft->IsHole &&
+ outrec.FirstLeft->Pts)) return;
+
+ OutRec* orfl = outrec.FirstLeft;
+ while (orfl && ((orfl->IsHole == outrec.IsHole) || !orfl->Pts))
+ orfl = orfl->FirstLeft;
+ outrec.FirstLeft = orfl;
+}
+//------------------------------------------------------------------------------
+
+bool Clipper::ExecuteInternal()
+{
+ bool succeeded = true;
+ try {
+ Reset();
+ m_Maxima = MaximaList();
+ m_SortedEdges = 0;
+
+ succeeded = true;
+ cInt botY, topY;
+ if (!PopScanbeam(botY)) return false;
+ InsertLocalMinimaIntoAEL(botY);
+ while (PopScanbeam(topY) || LocalMinimaPending())
+ {
+ ProcessHorizontals();
+ ClearGhostJoins();
+ if (!ProcessIntersections(topY))
+ {
+ succeeded = false;
+ break;
+ }
+ ProcessEdgesAtTopOfScanbeam(topY);
+ botY = topY;
+ InsertLocalMinimaIntoAEL(botY);
+ }
+ }
+ catch(...)
+ {
+ succeeded = false;
+ }
+
+ if (succeeded)
+ {
+ //fix orientations ...
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i)
+ {
+ OutRec *outRec = m_PolyOuts[i];
+ if (!outRec->Pts || outRec->IsOpen) continue;
+ if ((outRec->IsHole ^ m_ReverseOutput) == (Area(*outRec) > 0))
+ ReversePolyPtLinks(outRec->Pts);
+ }
+
+ if (!m_Joins.empty()) JoinCommonEdges();
+
+ //unfortunately FixupOutPolygon() must be done after JoinCommonEdges()
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i)
+ {
+ OutRec *outRec = m_PolyOuts[i];
+ if (!outRec->Pts) continue;
+ if (outRec->IsOpen)
+ FixupOutPolyline(*outRec);
+ else
+ FixupOutPolygon(*outRec);
+ }
+
+ if (m_StrictSimple) DoSimplePolygons();
+ }
+
+ ClearJoins();
+ ClearGhostJoins();
+ return succeeded;
+}
+//------------------------------------------------------------------------------
+
+void Clipper::SetWindingCount(TEdge &edge)
+{
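+ //Derives edge.WindCnt (the winding count for the edge's own polytype) and
+ //edge.WindCnt2 (the winding contributed by the other polytype) from the edges
+ //that precede it in the AEL, honouring the relevant fill rule.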
+ TEdge *e = edge.PrevInAEL;
+ //find the edge of the same polytype that immediately precedes 'edge' in AEL
+ while (e && ((e->PolyTyp != edge.PolyTyp) || (e->WindDelta == 0))) e = e->PrevInAEL;
+ if (!e)
+ {
+ if (edge.WindDelta == 0)
+ {
+ PolyFillType pft = (edge.PolyTyp == ptSubject ? m_SubjFillType : m_ClipFillType);
+ edge.WindCnt = (pft == pftNegative ? -1 : 1);
+ }
+ else
+ edge.WindCnt = edge.WindDelta;
+ edge.WindCnt2 = 0;
+ e = m_ActiveEdges; //ie get ready to calc WindCnt2
+ }
+ else if (edge.WindDelta == 0 && m_ClipType != ctUnion)
+ {
+ edge.WindCnt = 1;
+ edge.WindCnt2 = e->WindCnt2;
+ e = e->NextInAEL; //ie get ready to calc WindCnt2
+ }
+ else if (IsEvenOddFillType(edge))
+ {
+ //EvenOdd filling ...
+ if (edge.WindDelta == 0)
+ {
+ //are we inside a subj polygon ...
+ bool Inside = true;
+ TEdge *e2 = e->PrevInAEL;
+ while (e2)
+ {
+ if (e2->PolyTyp == e->PolyTyp && e2->WindDelta != 0)
+ Inside = !Inside;
+ e2 = e2->PrevInAEL;
+ }
+ edge.WindCnt = (Inside ? 0 : 1);
+ }
+ else
+ {
+ edge.WindCnt = edge.WindDelta;
+ }
+ edge.WindCnt2 = e->WindCnt2;
+ e = e->NextInAEL; //ie get ready to calc WindCnt2
+ }
+ else
+ {
+ //nonZero, Positive or Negative filling ...
+ if (e->WindCnt * e->WindDelta < 0)
+ {
+ //prev edge is 'decreasing' WindCount (WC) toward zero
+ //so we're outside the previous polygon ...
+ if (Abs(e->WindCnt) > 1)
+ {
+ //outside prev poly but still inside another.
+ //when reversing direction of prev poly use the same WC
+ if (e->WindDelta * edge.WindDelta < 0) edge.WindCnt = e->WindCnt;
+ //otherwise continue to 'decrease' WC ...
+ else edge.WindCnt = e->WindCnt + edge.WindDelta;
+ }
+ else
+ //now outside all polys of same polytype so set own WC ...
+ edge.WindCnt = (edge.WindDelta == 0 ? 1 : edge.WindDelta);
+ } else
+ {
+ //prev edge is 'increasing' WindCount (WC) away from zero
+ //so we're inside the previous polygon ...
+ if (edge.WindDelta == 0)
+ edge.WindCnt = (e->WindCnt < 0 ? e->WindCnt - 1 : e->WindCnt + 1);
+ //if wind direction is reversing prev then use same WC
+ else if (e->WindDelta * edge.WindDelta < 0) edge.WindCnt = e->WindCnt;
+ //otherwise add to WC ...
+ else edge.WindCnt = e->WindCnt + edge.WindDelta;
+ }
+ edge.WindCnt2 = e->WindCnt2;
+ e = e->NextInAEL; //ie get ready to calc WindCnt2
+ }
+
+ //update WindCnt2 ...
+ if (IsEvenOddAltFillType(edge))
+ {
+ //EvenOdd filling ...
+ while (e != &edge)
+ {
+ if (e->WindDelta != 0)
+ edge.WindCnt2 = (edge.WindCnt2 == 0 ? 1 : 0);
+ e = e->NextInAEL;
+ }
+ } else
+ {
+ //nonZero, Positive or Negative filling ...
+ while ( e != &edge )
+ {
+ edge.WindCnt2 += e->WindDelta;
+ e = e->NextInAEL;
+ }
+ }
+}
+//------------------------------------------------------------------------------
+
+bool Clipper::IsEvenOddFillType(const TEdge& edge) const
+{
+ if (edge.PolyTyp == ptSubject)
+ return m_SubjFillType == pftEvenOdd; else
+ return m_ClipFillType == pftEvenOdd;
+}
+//------------------------------------------------------------------------------
+
+bool Clipper::IsEvenOddAltFillType(const TEdge& edge) const
+{
+ if (edge.PolyTyp == ptSubject)
+ return m_ClipFillType == pftEvenOdd; else
+ return m_SubjFillType == pftEvenOdd;
+}
+//------------------------------------------------------------------------------
+
+bool Clipper::IsContributing(const TEdge& edge) const
+{
+ PolyFillType pft, pft2;
+ if (edge.PolyTyp == ptSubject)
+ {
+ pft = m_SubjFillType;
+ pft2 = m_ClipFillType;
+ } else
+ {
+ pft = m_ClipFillType;
+ pft2 = m_SubjFillType;
+ }
+
+ switch(pft)
+ {
+ case pftEvenOdd:
+ //return false if a subj line has been flagged as inside a subj polygon
+ if (edge.WindDelta == 0 && edge.WindCnt != 1) return false;
+ break;
+ case pftNonZero:
+ if (Abs(edge.WindCnt) != 1) return false;
+ break;
+ case pftPositive:
+ if (edge.WindCnt != 1) return false;
+ break;
+ default: //pftNegative
+ if (edge.WindCnt != -1) return false;
+ }
+
+ switch(m_ClipType)
+ {
+ case ctIntersection:
+ switch(pft2)
+ {
+ case pftEvenOdd:
+ case pftNonZero:
+ return (edge.WindCnt2 != 0);
+ case pftPositive:
+ return (edge.WindCnt2 > 0);
+ default:
+ return (edge.WindCnt2 < 0);
+ }
+ break;
+ case ctUnion:
+ switch(pft2)
+ {
+ case pftEvenOdd:
+ case pftNonZero:
+ return (edge.WindCnt2 == 0);
+ case pftPositive:
+ return (edge.WindCnt2 <= 0);
+ default:
+ return (edge.WindCnt2 >= 0);
+ }
+ break;
+ case ctDifference:
+ if (edge.PolyTyp == ptSubject)
+ switch(pft2)
+ {
+ case pftEvenOdd:
+ case pftNonZero:
+ return (edge.WindCnt2 == 0);
+ case pftPositive:
+ return (edge.WindCnt2 <= 0);
+ default:
+ return (edge.WindCnt2 >= 0);
+ }
+ else
+ switch(pft2)
+ {
+ case pftEvenOdd:
+ case pftNonZero:
+ return (edge.WindCnt2 != 0);
+ case pftPositive:
+ return (edge.WindCnt2 > 0);
+ default:
+ return (edge.WindCnt2 < 0);
+ }
+ break;
+ case ctXor:
+ if (edge.WindDelta == 0) //XOr always contributing unless open
+ switch(pft2)
+ {
+ case pftEvenOdd:
+ case pftNonZero:
+ return (edge.WindCnt2 == 0);
+ case pftPositive:
+ return (edge.WindCnt2 <= 0);
+ default:
+ return (edge.WindCnt2 >= 0);
+ }
+ else
+ return true;
+ break;
+ default:
+ return true;
+ }
+}
+//------------------------------------------------------------------------------
+
+OutPt* Clipper::AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &Pt)
+{
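+ //Starts a new output polygon where bounds e1 and e2 meet at a local minimum,
+ //assigning left/right sides from the edges' slopes, and records a join candidate
+ //when the new point coincides with the output of the edge just left of it in the AEL.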
+ OutPt* result;
+ TEdge *e, *prevE;
+ if (IsHorizontal(*e2) || ( e1->Dx > e2->Dx ))
+ {
+ result = AddOutPt(e1, Pt);
+ e2->OutIdx = e1->OutIdx;
+ e1->Side = esLeft;
+ e2->Side = esRight;
+ e = e1;
+ if (e->PrevInAEL == e2)
+ prevE = e2->PrevInAEL;
+ else
+ prevE = e->PrevInAEL;
+ } else
+ {
+ result = AddOutPt(e2, Pt);
+ e1->OutIdx = e2->OutIdx;
+ e1->Side = esRight;
+ e2->Side = esLeft;
+ e = e2;
+ if (e->PrevInAEL == e1)
+ prevE = e1->PrevInAEL;
+ else
+ prevE = e->PrevInAEL;
+ }
+
+ if (prevE && prevE->OutIdx >= 0)
+ {
+ cInt xPrev = TopX(*prevE, Pt.Y);
+ cInt xE = TopX(*e, Pt.Y);
+ if (xPrev == xE && (e->WindDelta != 0) && (prevE->WindDelta != 0) &&
+ SlopesEqual(IntPoint(xPrev, Pt.Y), prevE->Top, IntPoint(xE, Pt.Y), e->Top, m_UseFullRange))
+ {
+ OutPt* outPt = AddOutPt(prevE, Pt);
+ AddJoin(result, outPt, e->Top);
+ }
+ }
+ return result;
+}
+//------------------------------------------------------------------------------
+
+void Clipper::AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &Pt)
+{
+ AddOutPt( e1, Pt );
+ if (e2->WindDelta == 0) AddOutPt(e2, Pt);
+ if( e1->OutIdx == e2->OutIdx )
+ {
+ e1->OutIdx = Unassigned;
+ e2->OutIdx = Unassigned;
+ }
+ else if (e1->OutIdx < e2->OutIdx)
+ AppendPolygon(e1, e2);
+ else
+ AppendPolygon(e2, e1);
+}
+//------------------------------------------------------------------------------
+
+void Clipper::AddEdgeToSEL(TEdge *edge)
+{
+ //SEL pointers in TEdge are reused to build a list of horizontal edges.
+ //However, we don't need to worry about order with horizontal edge processing.
+ if( !m_SortedEdges )
+ {
+ m_SortedEdges = edge;
+ edge->PrevInSEL = 0;
+ edge->NextInSEL = 0;
+ }
+ else
+ {
+ edge->NextInSEL = m_SortedEdges;
+ edge->PrevInSEL = 0;
+ m_SortedEdges->PrevInSEL = edge;
+ m_SortedEdges = edge;
+ }
+}
+//------------------------------------------------------------------------------
+
+bool Clipper::PopEdgeFromSEL(TEdge *&edge)
+{
+ if (!m_SortedEdges) return false;
+ edge = m_SortedEdges;
+ DeleteFromSEL(m_SortedEdges);
+ return true;
+}
+//------------------------------------------------------------------------------
+
+void Clipper::CopyAELToSEL()
+{
+ TEdge* e = m_ActiveEdges;
+ m_SortedEdges = e;
+ while ( e )
+ {
+ e->PrevInSEL = e->PrevInAEL;
+ e->NextInSEL = e->NextInAEL;
+ e = e->NextInAEL;
+ }
+}
+//------------------------------------------------------------------------------
+
+void Clipper::AddJoin(OutPt *op1, OutPt *op2, const IntPoint OffPt)
+{
+ Join* j = new Join;
+ j->OutPt1 = op1;
+ j->OutPt2 = op2;
+ j->OffPt = OffPt;
+ m_Joins.push_back(j);
+}
+//------------------------------------------------------------------------------
+
+void Clipper::ClearJoins()
+{
+ for (JoinList::size_type i = 0; i < m_Joins.size(); i++)
+ delete m_Joins[i];
+ m_Joins.resize(0);
+}
+//------------------------------------------------------------------------------
+
+void Clipper::ClearGhostJoins()
+{
+ for (JoinList::size_type i = 0; i < m_GhostJoins.size(); i++)
+ delete m_GhostJoins[i];
+ m_GhostJoins.resize(0);
+}
+//------------------------------------------------------------------------------
+
+void Clipper::AddGhostJoin(OutPt *op, const IntPoint OffPt)
+{
+ Join* j = new Join;
+ j->OutPt1 = op;
+ j->OutPt2 = 0;
+ j->OffPt = OffPt;
+ m_GhostJoins.push_back(j);
+}
+//------------------------------------------------------------------------------
+
+void Clipper::InsertLocalMinimaIntoAEL(const cInt botY)
+{
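+ //Pops every LocalMinimum whose Y equals botY, inserts its left and right bounds
+ //into the AEL, sets their winding counts, starts output polygons for contributing
+ //edges and records any joins that will be needed later.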
+ const LocalMinimum *lm;
+ while (PopLocalMinima(botY, lm))
+ {
+ TEdge* lb = lm->LeftBound;
+ TEdge* rb = lm->RightBound;
+
+ OutPt *Op1 = 0;
+ if (!lb)
+ {
+ //nb: don't insert LB into either AEL or SEL
+ InsertEdgeIntoAEL(rb, 0);
+ SetWindingCount(*rb);
+ if (IsContributing(*rb))
+ Op1 = AddOutPt(rb, rb->Bot);
+ }
+ else if (!rb)
+ {
+ InsertEdgeIntoAEL(lb, 0);
+ SetWindingCount(*lb);
+ if (IsContributing(*lb))
+ Op1 = AddOutPt(lb, lb->Bot);
+ InsertScanbeam(lb->Top.Y);
+ }
+ else
+ {
+ InsertEdgeIntoAEL(lb, 0);
+ InsertEdgeIntoAEL(rb, lb);
+ SetWindingCount( *lb );
+ rb->WindCnt = lb->WindCnt;
+ rb->WindCnt2 = lb->WindCnt2;
+ if (IsContributing(*lb))
+ Op1 = AddLocalMinPoly(lb, rb, lb->Bot);
+ InsertScanbeam(lb->Top.Y);
+ }
+
+ if (rb)
+ {
+ if (IsHorizontal(*rb))
+ {
+ AddEdgeToSEL(rb);
+ if (rb->NextInLML)
+ InsertScanbeam(rb->NextInLML->Top.Y);
+ }
+ else InsertScanbeam( rb->Top.Y );
+ }
+
+ if (!lb || !rb) continue;
+
+ //if any output polygons share an edge, they'll need joining later ...
+ if (Op1 && IsHorizontal(*rb) &&
+ m_GhostJoins.size() > 0 && (rb->WindDelta != 0))
+ {
+ for (JoinList::size_type i = 0; i < m_GhostJoins.size(); ++i)
+ {
+ Join* jr = m_GhostJoins[i];
+ //if the horizontal Rb and a 'ghost' horizontal overlap, then convert
+ //the 'ghost' join to a real join ready for later ...
+ if (HorzSegmentsOverlap(jr->OutPt1->Pt.X, jr->OffPt.X, rb->Bot.X, rb->Top.X))
+ AddJoin(jr->OutPt1, Op1, jr->OffPt);
+ }
+ }
+
+ if (lb->OutIdx >= 0 && lb->PrevInAEL &&
+ lb->PrevInAEL->Curr.X == lb->Bot.X &&
+ lb->PrevInAEL->OutIdx >= 0 &&
+ SlopesEqual(lb->PrevInAEL->Bot, lb->PrevInAEL->Top, lb->Curr, lb->Top, m_UseFullRange) &&
+ (lb->WindDelta != 0) && (lb->PrevInAEL->WindDelta != 0))
+ {
+ OutPt *Op2 = AddOutPt(lb->PrevInAEL, lb->Bot);
+ AddJoin(Op1, Op2, lb->Top);
+ }
+
+ if(lb->NextInAEL != rb)
+ {
+
+ if (rb->OutIdx >= 0 && rb->PrevInAEL->OutIdx >= 0 &&
+ SlopesEqual(rb->PrevInAEL->Curr, rb->PrevInAEL->Top, rb->Curr, rb->Top, m_UseFullRange) &&
+ (rb->WindDelta != 0) && (rb->PrevInAEL->WindDelta != 0))
+ {
+ OutPt *Op2 = AddOutPt(rb->PrevInAEL, rb->Bot);
+ AddJoin(Op1, Op2, rb->Top);
+ }
+
+ TEdge* e = lb->NextInAEL;
+ if (e)
+ {
+ while( e != rb )
+ {
+ //nb: For calculating winding counts etc, IntersectEdges() assumes
+ //that param1 will be to the Right of param2 ABOVE the intersection ...
+ IntersectEdges(rb , e , lb->Curr); //order important here
+ e = e->NextInAEL;
+ }
+ }
+ }
+
+ }
+}
+//------------------------------------------------------------------------------
+
+void Clipper::DeleteFromSEL(TEdge *e)
+{
+ TEdge* SelPrev = e->PrevInSEL;
+ TEdge* SelNext = e->NextInSEL;
+ if( !SelPrev && !SelNext && (e != m_SortedEdges) ) return; //already deleted
+ if( SelPrev ) SelPrev->NextInSEL = SelNext;
+ else m_SortedEdges = SelNext;
+ if( SelNext ) SelNext->PrevInSEL = SelPrev;
+ e->NextInSEL = 0;
+ e->PrevInSEL = 0;
+}
+//------------------------------------------------------------------------------
+
+#ifdef use_xyz
+void Clipper::SetZ(IntPoint& pt, TEdge& e1, TEdge& e2)
+{
+ if (pt.Z != 0 || !m_ZFill) return;
+ else if (pt == e1.Bot) pt.Z = e1.Bot.Z;
+ else if (pt == e1.Top) pt.Z = e1.Top.Z;
+ else if (pt == e2.Bot) pt.Z = e2.Bot.Z;
+ else if (pt == e2.Top) pt.Z = e2.Top.Z;
+ else (*m_ZFill)(e1.Bot, e1.Top, e2.Bot, e2.Top, pt);
+}
+//------------------------------------------------------------------------------
+#endif
+
+void Clipper::IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &Pt)
+{
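+ //Handles the crossing of e1 and e2 at Pt: winding counts are updated first,
+ //then, depending on the clip operation and fill rules, output polygons are
+ //started, extended or closed at the intersection point.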
+ bool e1Contributing = ( e1->OutIdx >= 0 );
+ bool e2Contributing = ( e2->OutIdx >= 0 );
+
+#ifdef use_xyz
+ SetZ(Pt, *e1, *e2);
+#endif
+
+#ifdef use_lines
+ //if either edge is on an OPEN path ...
+ if (e1->WindDelta == 0 || e2->WindDelta == 0)
+ {
+ //ignore subject-subject open path intersections UNLESS they
+ //are both open paths, AND they are both 'contributing maxima' ...
+ if (e1->WindDelta == 0 && e2->WindDelta == 0) return;
+
+ //if intersecting a subj line with a subj poly ...
+ else if (e1->PolyTyp == e2->PolyTyp &&
+ e1->WindDelta != e2->WindDelta && m_ClipType == ctUnion)
+ {
+ if (e1->WindDelta == 0)
+ {
+ if (e2Contributing)
+ {
+ AddOutPt(e1, Pt);
+ if (e1Contributing) e1->OutIdx = Unassigned;
+ }
+ }
+ else
+ {
+ if (e1Contributing)
+ {
+ AddOutPt(e2, Pt);
+ if (e2Contributing) e2->OutIdx = Unassigned;
+ }
+ }
+ }
+ else if (e1->PolyTyp != e2->PolyTyp)
+ {
+ //toggle subj open path OutIdx on/off when Abs(clip.WndCnt) == 1 ...
+ if ((e1->WindDelta == 0) && abs(e2->WindCnt) == 1 &&
+ (m_ClipType != ctUnion || e2->WindCnt2 == 0))
+ {
+ AddOutPt(e1, Pt);
+ if (e1Contributing) e1->OutIdx = Unassigned;
+ }
+ else if ((e2->WindDelta == 0) && (abs(e1->WindCnt) == 1) &&
+ (m_ClipType != ctUnion || e1->WindCnt2 == 0))
+ {
+ AddOutPt(e2, Pt);
+ if (e2Contributing) e2->OutIdx = Unassigned;
+ }
+ }
+ return;
+ }
+#endif
+
+ //update winding counts...
+ //assumes that e1 will be to the Right of e2 ABOVE the intersection
+ if ( e1->PolyTyp == e2->PolyTyp )
+ {
+ if ( IsEvenOddFillType( *e1) )
+ {
+ int oldE1WindCnt = e1->WindCnt;
+ e1->WindCnt = e2->WindCnt;
+ e2->WindCnt = oldE1WindCnt;
+ } else
+ {
+ if (e1->WindCnt + e2->WindDelta == 0 ) e1->WindCnt = -e1->WindCnt;
+ else e1->WindCnt += e2->WindDelta;
+ if ( e2->WindCnt - e1->WindDelta == 0 ) e2->WindCnt = -e2->WindCnt;
+ else e2->WindCnt -= e1->WindDelta;
+ }
+ } else
+ {
+ if (!IsEvenOddFillType(*e2)) e1->WindCnt2 += e2->WindDelta;
+ else e1->WindCnt2 = ( e1->WindCnt2 == 0 ) ? 1 : 0;
+ if (!IsEvenOddFillType(*e1)) e2->WindCnt2 -= e1->WindDelta;
+ else e2->WindCnt2 = ( e2->WindCnt2 == 0 ) ? 1 : 0;
+ }
+
+ PolyFillType e1FillType, e2FillType, e1FillType2, e2FillType2;
+ if (e1->PolyTyp == ptSubject)
+ {
+ e1FillType = m_SubjFillType;
+ e1FillType2 = m_ClipFillType;
+ } else
+ {
+ e1FillType = m_ClipFillType;
+ e1FillType2 = m_SubjFillType;
+ }
+ if (e2->PolyTyp == ptSubject)
+ {
+ e2FillType = m_SubjFillType;
+ e2FillType2 = m_ClipFillType;
+ } else
+ {
+ e2FillType = m_ClipFillType;
+ e2FillType2 = m_SubjFillType;
+ }
+
+ cInt e1Wc, e2Wc;
+ switch (e1FillType)
+ {
+ case pftPositive: e1Wc = e1->WindCnt; break;
+ case pftNegative: e1Wc = -e1->WindCnt; break;
+ default: e1Wc = Abs(e1->WindCnt);
+ }
+ switch(e2FillType)
+ {
+ case pftPositive: e2Wc = e2->WindCnt; break;
+ case pftNegative: e2Wc = -e2->WindCnt; break;
+ default: e2Wc = Abs(e2->WindCnt);
+ }
+
+ if ( e1Contributing && e2Contributing )
+ {
+ if ((e1Wc != 0 && e1Wc != 1) || (e2Wc != 0 && e2Wc != 1) ||
+ (e1->PolyTyp != e2->PolyTyp && m_ClipType != ctXor) )
+ {
+ AddLocalMaxPoly(e1, e2, Pt);
+ }
+ else
+ {
+ AddOutPt(e1, Pt);
+ AddOutPt(e2, Pt);
+ SwapSides( *e1 , *e2 );
+ SwapPolyIndexes( *e1 , *e2 );
+ }
+ }
+ else if ( e1Contributing )
+ {
+ if (e2Wc == 0 || e2Wc == 1)
+ {
+ AddOutPt(e1, Pt);
+ SwapSides(*e1, *e2);
+ SwapPolyIndexes(*e1, *e2);
+ }
+ }
+ else if ( e2Contributing )
+ {
+ if (e1Wc == 0 || e1Wc == 1)
+ {
+ AddOutPt(e2, Pt);
+ SwapSides(*e1, *e2);
+ SwapPolyIndexes(*e1, *e2);
+ }
+ }
+ else if ( (e1Wc == 0 || e1Wc == 1) && (e2Wc == 0 || e2Wc == 1))
+ {
+ //neither edge is currently contributing ...
+
+ cInt e1Wc2, e2Wc2;
+ switch (e1FillType2)
+ {
+ case pftPositive: e1Wc2 = e1->WindCnt2; break;
+ case pftNegative : e1Wc2 = -e1->WindCnt2; break;
+ default: e1Wc2 = Abs(e1->WindCnt2);
+ }
+ switch (e2FillType2)
+ {
+ case pftPositive: e2Wc2 = e2->WindCnt2; break;
+ case pftNegative: e2Wc2 = -e2->WindCnt2; break;
+ default: e2Wc2 = Abs(e2->WindCnt2);
+ }
+
+ if (e1->PolyTyp != e2->PolyTyp)
+ {
+ AddLocalMinPoly(e1, e2, Pt);
+ }
+ else if (e1Wc == 1 && e2Wc == 1)
+ switch( m_ClipType ) {
+ case ctIntersection:
+ if (e1Wc2 > 0 && e2Wc2 > 0)
+ AddLocalMinPoly(e1, e2, Pt);
+ break;
+ case ctUnion:
+ if ( e1Wc2 <= 0 && e2Wc2 <= 0 )
+ AddLocalMinPoly(e1, e2, Pt);
+ break;
+ case ctDifference:
+ if (((e1->PolyTyp == ptClip) && (e1Wc2 > 0) && (e2Wc2 > 0)) ||
+ ((e1->PolyTyp == ptSubject) && (e1Wc2 <= 0) && (e2Wc2 <= 0)))
+ AddLocalMinPoly(e1, e2, Pt);
+ break;
+ case ctXor:
+ AddLocalMinPoly(e1, e2, Pt);
+ }
+ else
+ SwapSides( *e1, *e2 );
+ }
+}
+//------------------------------------------------------------------------------
+
+void Clipper::SetHoleState(TEdge *e, OutRec *outrec)
+{
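+ //Scans left along the AEL, pairing off edges that belong to the same OutRec;
+ //if a contributing edge is left unpaired, its OutRec becomes this outrec's
+ //FirstLeft (owner) and the hole state is the opposite of that owner's.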
+ TEdge *e2 = e->PrevInAEL;
+ TEdge *eTmp = 0;
+ while (e2)
+ {
+ if (e2->OutIdx >= 0 && e2->WindDelta != 0)
+ {
+ if (!eTmp) eTmp = e2;
+ else if (eTmp->OutIdx == e2->OutIdx) eTmp = 0;
+ }
+ e2 = e2->PrevInAEL;
+ }
+ if (!eTmp)
+ {
+ outrec->FirstLeft = 0;
+ outrec->IsHole = false;
+ }
+ else
+ {
+ outrec->FirstLeft = m_PolyOuts[eTmp->OutIdx];
+ outrec->IsHole = !outrec->FirstLeft->IsHole;
+ }
+}
+//------------------------------------------------------------------------------
+
+OutRec* GetLowermostRec(OutRec *outRec1, OutRec *outRec2)
+{
+ //work out which polygon fragment has the correct hole state ...
+ if (!outRec1->BottomPt)
+ outRec1->BottomPt = GetBottomPt(outRec1->Pts);
+ if (!outRec2->BottomPt)
+ outRec2->BottomPt = GetBottomPt(outRec2->Pts);
+ OutPt *OutPt1 = outRec1->BottomPt;
+ OutPt *OutPt2 = outRec2->BottomPt;
+ if (OutPt1->Pt.Y > OutPt2->Pt.Y) return outRec1;
+ else if (OutPt1->Pt.Y < OutPt2->Pt.Y) return outRec2;
+ else if (OutPt1->Pt.X < OutPt2->Pt.X) return outRec1;
+ else if (OutPt1->Pt.X > OutPt2->Pt.X) return outRec2;
+ else if (OutPt1->Next == OutPt1) return outRec2;
+ else if (OutPt2->Next == OutPt2) return outRec1;
+ else if (FirstIsBottomPt(OutPt1, OutPt2)) return outRec1;
+ else return outRec2;
+}
+//------------------------------------------------------------------------------
+
+bool OutRec1RightOfOutRec2(OutRec* outRec1, OutRec* outRec2)
+{
+ do
+ {
+ outRec1 = outRec1->FirstLeft;
+ if (outRec1 == outRec2) return true;
+ } while (outRec1);
+ return false;
+}
+//------------------------------------------------------------------------------
+
+OutRec* Clipper::GetOutRec(int Idx)
+{
+ OutRec* outrec = m_PolyOuts[Idx];
+ while (outrec != m_PolyOuts[outrec->Idx])
+ outrec = m_PolyOuts[outrec->Idx];
+ return outrec;
+}
+//------------------------------------------------------------------------------
+
+void Clipper::AppendPolygon(TEdge *e1, TEdge *e2)
+{
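+ //Merges the output polygon attached to e2 into the one attached to e1, splicing
+ //the OutPt rings according to each edge's side, then transfers hole state from
+ //whichever fragment is lowermost and retires e2's OutRec index.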
+ //get the start and ends of both output polygons ...
+ OutRec *outRec1 = m_PolyOuts[e1->OutIdx];
+ OutRec *outRec2 = m_PolyOuts[e2->OutIdx];
+
+ OutRec *holeStateRec;
+ if (OutRec1RightOfOutRec2(outRec1, outRec2))
+ holeStateRec = outRec2;
+ else if (OutRec1RightOfOutRec2(outRec2, outRec1))
+ holeStateRec = outRec1;
+ else
+ holeStateRec = GetLowermostRec(outRec1, outRec2);
+
+ //get the start and ends of both output polygons and
+ //join e2 poly onto e1 poly and delete pointers to e2 ...
+
+ OutPt* p1_lft = outRec1->Pts;
+ OutPt* p1_rt = p1_lft->Prev;
+ OutPt* p2_lft = outRec2->Pts;
+ OutPt* p2_rt = p2_lft->Prev;
+
+ //join e2 poly onto e1 poly and delete pointers to e2 ...
+ if( e1->Side == esLeft )
+ {
+ if( e2->Side == esLeft )
+ {
+ //z y x a b c
+ ReversePolyPtLinks(p2_lft);
+ p2_lft->Next = p1_lft;
+ p1_lft->Prev = p2_lft;
+ p1_rt->Next = p2_rt;
+ p2_rt->Prev = p1_rt;
+ outRec1->Pts = p2_rt;
+ } else
+ {
+ //x y z a b c
+ p2_rt->Next = p1_lft;
+ p1_lft->Prev = p2_rt;
+ p2_lft->Prev = p1_rt;
+ p1_rt->Next = p2_lft;
+ outRec1->Pts = p2_lft;
+ }
+ } else
+ {
+ if( e2->Side == esRight )
+ {
+ //a b c z y x
+ ReversePolyPtLinks(p2_lft);
+ p1_rt->Next = p2_rt;
+ p2_rt->Prev = p1_rt;
+ p2_lft->Next = p1_lft;
+ p1_lft->Prev = p2_lft;
+ } else
+ {
+ //a b c x y z
+ p1_rt->Next = p2_lft;
+ p2_lft->Prev = p1_rt;
+ p1_lft->Prev = p2_rt;
+ p2_rt->Next = p1_lft;
+ }
+ }
+
+ outRec1->BottomPt = 0;
+ if (holeStateRec == outRec2)
+ {
+ if (outRec2->FirstLeft != outRec1)
+ outRec1->FirstLeft = outRec2->FirstLeft;
+ outRec1->IsHole = outRec2->IsHole;
+ }
+ outRec2->Pts = 0;
+ outRec2->BottomPt = 0;
+ outRec2->FirstLeft = outRec1;
+
+ int OKIdx = e1->OutIdx;
+ int ObsoleteIdx = e2->OutIdx;
+
+ e1->OutIdx = Unassigned; //nb: safe because we only get here via AddLocalMaxPoly
+ e2->OutIdx = Unassigned;
+
+ TEdge* e = m_ActiveEdges;
+ while( e )
+ {
+ if( e->OutIdx == ObsoleteIdx )
+ {
+ e->OutIdx = OKIdx;
+ e->Side = e1->Side;
+ break;
+ }
+ e = e->NextInAEL;
+ }
+
+ outRec2->Idx = outRec1->Idx;
+}
+//------------------------------------------------------------------------------
+
+OutPt* Clipper::AddOutPt(TEdge *e, const IntPoint &pt)
+{
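+ //Appends pt to the output polygon associated with e, creating a new OutRec when
+ //the edge has no output yet. Points are added at the front for left-side edges
+ //and at the back for right-side edges; duplicate end points are not repeated.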
+ if( e->OutIdx < 0 )
+ {
+ OutRec *outRec = CreateOutRec();
+ outRec->IsOpen = (e->WindDelta == 0);
+ OutPt* newOp = new OutPt;
+ outRec->Pts = newOp;
+ newOp->Idx = outRec->Idx;
+ newOp->Pt = pt;
+ newOp->Next = newOp;
+ newOp->Prev = newOp;
+ if (!outRec->IsOpen)
+ SetHoleState(e, outRec);
+ e->OutIdx = outRec->Idx;
+ return newOp;
+ } else
+ {
+ OutRec *outRec = m_PolyOuts[e->OutIdx];
+ //OutRec.Pts is the 'Left-most' point & OutRec.Pts.Prev is the 'Right-most'
+ OutPt* op = outRec->Pts;
+
+ bool ToFront = (e->Side == esLeft);
+ if (ToFront && (pt == op->Pt)) return op;
+ else if (!ToFront && (pt == op->Prev->Pt)) return op->Prev;
+
+ OutPt* newOp = new OutPt;
+ newOp->Idx = outRec->Idx;
+ newOp->Pt = pt;
+ newOp->Next = op;
+ newOp->Prev = op->Prev;
+ newOp->Prev->Next = newOp;
+ op->Prev = newOp;
+ if (ToFront) outRec->Pts = newOp;
+ return newOp;
+ }
+}
+//------------------------------------------------------------------------------
+
+OutPt* Clipper::GetLastOutPt(TEdge *e)
+{
+ OutRec *outRec = m_PolyOuts[e->OutIdx];
+ if (e->Side == esLeft)
+ return outRec->Pts;
+ else
+ return outRec->Pts->Prev;
+}
+//------------------------------------------------------------------------------
+
+void Clipper::ProcessHorizontals()
+{
+ TEdge* horzEdge;
+ while (PopEdgeFromSEL(horzEdge))
+ ProcessHorizontal(horzEdge);
+}
+//------------------------------------------------------------------------------
+
+inline bool IsMinima(TEdge *e)
+{
+ return e && (e->Prev->NextInLML != e) && (e->Next->NextInLML != e);
+}
+//------------------------------------------------------------------------------
+
+inline bool IsMaxima(TEdge *e, const cInt Y)
+{
+ return e && e->Top.Y == Y && !e->NextInLML;
+}
+//------------------------------------------------------------------------------
+
+inline bool IsIntermediate(TEdge *e, const cInt Y)
+{
+ return e->Top.Y == Y && e->NextInLML;
+}
+//------------------------------------------------------------------------------
+
+TEdge *GetMaximaPair(TEdge *e)
+{
+ if ((e->Next->Top == e->Top) && !e->Next->NextInLML)
+ return e->Next;
+ else if ((e->Prev->Top == e->Top) && !e->Prev->NextInLML)
+ return e->Prev;
+ else return 0;
+}
+//------------------------------------------------------------------------------
+
+TEdge *GetMaximaPairEx(TEdge *e)
+{
+ //as GetMaximaPair() but returns 0 if MaxPair isn't in AEL (unless it's horizontal)
+ TEdge* result = GetMaximaPair(e);
+ if (result && (result->OutIdx == Skip ||
+ (result->NextInAEL == result->PrevInAEL && !IsHorizontal(*result)))) return 0;
+ return result;
+}
+//------------------------------------------------------------------------------
+
+void Clipper::SwapPositionsInSEL(TEdge *Edge1, TEdge *Edge2)
+{
+ if( !( Edge1->NextInSEL ) && !( Edge1->PrevInSEL ) ) return;
+ if( !( Edge2->NextInSEL ) && !( Edge2->PrevInSEL ) ) return;
+
+ if( Edge1->NextInSEL == Edge2 )
+ {
+ TEdge* Next = Edge2->NextInSEL;
+ if( Next ) Next->PrevInSEL = Edge1;
+ TEdge* Prev = Edge1->PrevInSEL;
+ if( Prev ) Prev->NextInSEL = Edge2;
+ Edge2->PrevInSEL = Prev;
+ Edge2->NextInSEL = Edge1;
+ Edge1->PrevInSEL = Edge2;
+ Edge1->NextInSEL = Next;
+ }
+ else if( Edge2->NextInSEL == Edge1 )
+ {
+ TEdge* Next = Edge1->NextInSEL;
+ if( Next ) Next->PrevInSEL = Edge2;
+ TEdge* Prev = Edge2->PrevInSEL;
+ if( Prev ) Prev->NextInSEL = Edge1;
+ Edge1->PrevInSEL = Prev;
+ Edge1->NextInSEL = Edge2;
+ Edge2->PrevInSEL = Edge1;
+ Edge2->NextInSEL = Next;
+ }
+ else
+ {
+ TEdge* Next = Edge1->NextInSEL;
+ TEdge* Prev = Edge1->PrevInSEL;
+ Edge1->NextInSEL = Edge2->NextInSEL;
+ if( Edge1->NextInSEL ) Edge1->NextInSEL->PrevInSEL = Edge1;
+ Edge1->PrevInSEL = Edge2->PrevInSEL;
+ if( Edge1->PrevInSEL ) Edge1->PrevInSEL->NextInSEL = Edge1;
+ Edge2->NextInSEL = Next;
+ if( Edge2->NextInSEL ) Edge2->NextInSEL->PrevInSEL = Edge2;
+ Edge2->PrevInSEL = Prev;
+ if( Edge2->PrevInSEL ) Edge2->PrevInSEL->NextInSEL = Edge2;
+ }
+
+ if( !Edge1->PrevInSEL ) m_SortedEdges = Edge1;
+ else if( !Edge2->PrevInSEL ) m_SortedEdges = Edge2;
+}
+//------------------------------------------------------------------------------
+
+TEdge* GetNextInAEL(TEdge *e, Direction dir)
+{
+ return dir == dLeftToRight ? e->NextInAEL : e->PrevInAEL;
+}
+//------------------------------------------------------------------------------
+
+void GetHorzDirection(TEdge& HorzEdge, Direction& Dir, cInt& Left, cInt& Right)
+{
+ if (HorzEdge.Bot.X < HorzEdge.Top.X)
+ {
+ Left = HorzEdge.Bot.X;
+ Right = HorzEdge.Top.X;
+ Dir = dLeftToRight;
+ } else
+ {
+ Left = HorzEdge.Top.X;
+ Right = HorzEdge.Bot.X;
+ Dir = dRightToLeft;
+ }
+}
+//------------------------------------------------------------------------
+
+/*******************************************************************************
+* Notes: Horizontal edges (HEs) at scanline intersections (ie at the Top or *
+* Bottom of a scanbeam) are processed as if layered. The order in which HEs *
+* are processed doesn't matter. HEs intersect with other HE Bot.Xs only [#] *
+* (or they could intersect with Top.Xs only, ie EITHER Bot.Xs OR Top.Xs), *
+* and with other non-horizontal edges [*]. Once these intersections are *
+* processed, intermediate HEs then 'promote' the Edge above (NextInLML) into *
+* the AEL. These 'promoted' edges may in turn intersect [%] with other HEs. *
+*******************************************************************************/
+
+void Clipper::ProcessHorizontal(TEdge *horzEdge)
+{
+ Direction dir;
+ cInt horzLeft, horzRight;
+ bool IsOpen = (horzEdge->WindDelta == 0);
+
+ GetHorzDirection(*horzEdge, dir, horzLeft, horzRight);
+
+ TEdge* eLastHorz = horzEdge, *eMaxPair = 0;
+ while (eLastHorz->NextInLML && IsHorizontal(*eLastHorz->NextInLML))
+ eLastHorz = eLastHorz->NextInLML;
+ if (!eLastHorz->NextInLML)
+ eMaxPair = GetMaximaPair(eLastHorz);
+
+ MaximaList::const_iterator maxIt;
+ MaximaList::const_reverse_iterator maxRit;
+ if (m_Maxima.size() > 0)
+ {
+ //get the first maxima in range (X) ...
+ if (dir == dLeftToRight)
+ {
+ maxIt = m_Maxima.begin();
+ while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X) maxIt++;
+ if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X)
+ maxIt = m_Maxima.end();
+ }
+ else
+ {
+ maxRit = m_Maxima.rbegin();
+ while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X) maxRit++;
+ if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X)
+ maxRit = m_Maxima.rend();
+ }
+ }
+
+ OutPt* op1 = 0;
+
+ for (;;) //loop through consec. horizontal edges
+ {
+
+ bool IsLastHorz = (horzEdge == eLastHorz);
+ TEdge* e = GetNextInAEL(horzEdge, dir);
+ while(e)
+ {
+
+ //this code block inserts extra coords into horizontal edges (in output
+ //polygons) wherever maxima touch these horizontal edges. This helps
+ //'simplify' polygons (ie when the StrictlySimple property is enabled).
+ if (m_Maxima.size() > 0)
+ {
+ if (dir == dLeftToRight)
+ {
+ while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X)
+ {
+ if (horzEdge->OutIdx >= 0 && !IsOpen)
+ AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y));
+ maxIt++;
+ }
+ }
+ else
+ {
+ while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X)
+ {
+ if (horzEdge->OutIdx >= 0 && !IsOpen)
+ AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y));
+ maxRit++;
+ }
+ }
+ }
+
+ if ((dir == dLeftToRight && e->Curr.X > horzRight) ||
+ (dir == dRightToLeft && e->Curr.X < horzLeft)) break;
+
+ //Also break if we've got to the end of an intermediate horizontal edge ...
+ //nb: Smaller Dx's are to the right of larger Dx's ABOVE the horizontal.
+ if (e->Curr.X == horzEdge->Top.X && horzEdge->NextInLML &&
+ e->Dx < horzEdge->NextInLML->Dx) break;
+
+ if (horzEdge->OutIdx >= 0 && !IsOpen) //note: may be done multiple times
+ {
+ op1 = AddOutPt(horzEdge, e->Curr);
+ TEdge* eNextHorz = m_SortedEdges;
+ while (eNextHorz)
+ {
+ if (eNextHorz->OutIdx >= 0 &&
+ HorzSegmentsOverlap(horzEdge->Bot.X,
+ horzEdge->Top.X, eNextHorz->Bot.X, eNextHorz->Top.X))
+ {
+ OutPt* op2 = GetLastOutPt(eNextHorz);
+ AddJoin(op2, op1, eNextHorz->Top);
+ }
+ eNextHorz = eNextHorz->NextInSEL;
+ }
+ AddGhostJoin(op1, horzEdge->Bot);
+ }
+
+ //OK, so far we're still in range of the horizontal Edge but make sure
+ //we're at the last of consec. horizontals when matching with eMaxPair
+ if(e == eMaxPair && IsLastHorz)
+ {
+ if (horzEdge->OutIdx >= 0)
+ AddLocalMaxPoly(horzEdge, eMaxPair, horzEdge->Top);
+ DeleteFromAEL(horzEdge);
+ DeleteFromAEL(eMaxPair);
+ return;
+ }
+
+ if(dir == dLeftToRight)
+ {
+ IntPoint Pt = IntPoint(e->Curr.X, horzEdge->Curr.Y);
+ IntersectEdges(horzEdge, e, Pt);
+ }
+ else
+ {
+ IntPoint Pt = IntPoint(e->Curr.X, horzEdge->Curr.Y);
+ IntersectEdges( e, horzEdge, Pt);
+ }
+ TEdge* eNext = GetNextInAEL(e, dir);
+ SwapPositionsInAEL( horzEdge, e );
+ e = eNext;
+ } //end while(e)
+
+ //Break out of loop if HorzEdge.NextInLML is not also horizontal ...
+ if (!horzEdge->NextInLML || !IsHorizontal(*horzEdge->NextInLML)) break;
+
+ UpdateEdgeIntoAEL(horzEdge);
+ if (horzEdge->OutIdx >= 0) AddOutPt(horzEdge, horzEdge->Bot);
+ GetHorzDirection(*horzEdge, dir, horzLeft, horzRight);
+
+ } //end for (;;)
+
+ if (horzEdge->OutIdx >= 0 && !op1)
+ {
+ op1 = GetLastOutPt(horzEdge);
+ TEdge* eNextHorz = m_SortedEdges;
+ while (eNextHorz)
+ {
+ if (eNextHorz->OutIdx >= 0 &&
+ HorzSegmentsOverlap(horzEdge->Bot.X,
+ horzEdge->Top.X, eNextHorz->Bot.X, eNextHorz->Top.X))
+ {
+ OutPt* op2 = GetLastOutPt(eNextHorz);
+ AddJoin(op2, op1, eNextHorz->Top);
+ }
+ eNextHorz = eNextHorz->NextInSEL;
+ }
+ AddGhostJoin(op1, horzEdge->Top);
+ }
+
+ if (horzEdge->NextInLML)
+ {
+ if(horzEdge->OutIdx >= 0)
+ {
+ op1 = AddOutPt( horzEdge, horzEdge->Top);
+ UpdateEdgeIntoAEL(horzEdge);
+ if (horzEdge->WindDelta == 0) return;
+ //nb: HorzEdge is no longer horizontal here
+ TEdge* ePrev = horzEdge->PrevInAEL;
+ TEdge* eNext = horzEdge->NextInAEL;
+ if (ePrev && ePrev->Curr.X == horzEdge->Bot.X &&
+ ePrev->Curr.Y == horzEdge->Bot.Y && ePrev->WindDelta != 0 &&
+ (ePrev->OutIdx >= 0 && ePrev->Curr.Y > ePrev->Top.Y &&
+ SlopesEqual(*horzEdge, *ePrev, m_UseFullRange)))
+ {
+ OutPt* op2 = AddOutPt(ePrev, horzEdge->Bot);
+ AddJoin(op1, op2, horzEdge->Top);
+ }
+ else if (eNext && eNext->Curr.X == horzEdge->Bot.X &&
+ eNext->Curr.Y == horzEdge->Bot.Y && eNext->WindDelta != 0 &&
+ eNext->OutIdx >= 0 && eNext->Curr.Y > eNext->Top.Y &&
+ SlopesEqual(*horzEdge, *eNext, m_UseFullRange))
+ {
+ OutPt* op2 = AddOutPt(eNext, horzEdge->Bot);
+ AddJoin(op1, op2, horzEdge->Top);
+ }
+ }
+ else
+ UpdateEdgeIntoAEL(horzEdge);
+ }
+ else
+ {
+ if (horzEdge->OutIdx >= 0) AddOutPt(horzEdge, horzEdge->Top);
+ DeleteFromAEL(horzEdge);
+ }
+}
+//------------------------------------------------------------------------------
+
+bool Clipper::ProcessIntersections(const cInt topY)
+{
+ if( !m_ActiveEdges ) return true;
+ try {
+ BuildIntersectList(topY);
+ size_t IlSize = m_IntersectList.size();
+ if (IlSize == 0) return true;
+ if (IlSize == 1 || FixupIntersectionOrder()) ProcessIntersectList();
+ else return false;
+ }
+ catch(...)
+ {
+ m_SortedEdges = 0;
+ DisposeIntersectNodes();
+ throw clipperException("ProcessIntersections error");
+ }
+ m_SortedEdges = 0;
+ return true;
+}
+//------------------------------------------------------------------------------
+
+void Clipper::DisposeIntersectNodes()
+{
+ for (size_t i = 0; i < m_IntersectList.size(); ++i )
+ delete m_IntersectList[i];
+ m_IntersectList.clear();
+}
+//------------------------------------------------------------------------------
+
+void Clipper::BuildIntersectList(const cInt topY)
+{
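+ //Copies the AEL into the SEL, advances every edge's Curr.X to the top of the
+ //scanbeam, then bubble-sorts the SEL by Curr.X, recording an IntersectNode for
+ //each pair of edges that must swap order.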
+ if ( !m_ActiveEdges ) return;
+
+ //prepare for sorting ...
+ TEdge* e = m_ActiveEdges;
+ m_SortedEdges = e;
+ while( e )
+ {
+ e->PrevInSEL = e->PrevInAEL;
+ e->NextInSEL = e->NextInAEL;
+ e->Curr.X = TopX( *e, topY );
+ e = e->NextInAEL;
+ }
+
+ //bubblesort ...
+ bool isModified;
+ do
+ {
+ isModified = false;
+ e = m_SortedEdges;
+ while( e->NextInSEL )
+ {
+ TEdge *eNext = e->NextInSEL;
+ IntPoint Pt;
+ if(e->Curr.X > eNext->Curr.X)
+ {
+ IntersectPoint(*e, *eNext, Pt);
+ if (Pt.Y < topY) Pt = IntPoint(TopX(*e, topY), topY);
+ IntersectNode * newNode = new IntersectNode;
+ newNode->Edge1 = e;
+ newNode->Edge2 = eNext;
+ newNode->Pt = Pt;
+ m_IntersectList.push_back(newNode);
+
+ SwapPositionsInSEL(e, eNext);
+ isModified = true;
+ }
+ else
+ e = eNext;
+ }
+ if( e->PrevInSEL ) e->PrevInSEL->NextInSEL = 0;
+ else break;
+ }
+ while ( isModified );
+ m_SortedEdges = 0; //important
+}
+//------------------------------------------------------------------------------
+
+
+void Clipper::ProcessIntersectList()
+{
+ for (size_t i = 0; i < m_IntersectList.size(); ++i)
+ {
+ IntersectNode* iNode = m_IntersectList[i];
+ {
+ IntersectEdges( iNode->Edge1, iNode->Edge2, iNode->Pt);
+ SwapPositionsInAEL( iNode->Edge1 , iNode->Edge2 );
+ }
+ delete iNode;
+ }
+ m_IntersectList.clear();
+}
+//------------------------------------------------------------------------------
+
+bool IntersectListSort(IntersectNode* node1, IntersectNode* node2)
+{
+ return node2->Pt.Y < node1->Pt.Y;
+}
+//------------------------------------------------------------------------------
+
+inline bool EdgesAdjacent(const IntersectNode &inode)
+{
+ return (inode.Edge1->NextInSEL == inode.Edge2) ||
+ (inode.Edge1->PrevInSEL == inode.Edge2);
+}
+//------------------------------------------------------------------------------
+
+bool Clipper::FixupIntersectionOrder()
+{
+ //pre-condition: intersections are sorted Bottom-most first.
+ //Now it's crucial that intersections are made only between adjacent edges,
+ //so to ensure this the order of intersections may need adjusting ...
+ CopyAELToSEL();
+ std::sort(m_IntersectList.begin(), m_IntersectList.end(), IntersectListSort);
+ size_t cnt = m_IntersectList.size();
+ for (size_t i = 0; i < cnt; ++i)
+ {
+ if (!EdgesAdjacent(*m_IntersectList[i]))
+ {
+ size_t j = i + 1;
+ while (j < cnt && !EdgesAdjacent(*m_IntersectList[j])) j++;
+ if (j == cnt) return false;
+ std::swap(m_IntersectList[i], m_IntersectList[j]);
+ }
+ SwapPositionsInSEL(m_IntersectList[i]->Edge1, m_IntersectList[i]->Edge2);
+ }
+ return true;
+}
+//------------------------------------------------------------------------------
+
+void Clipper::DoMaxima(TEdge *e)
+{
+ TEdge* eMaxPair = GetMaximaPairEx(e);
+ if (!eMaxPair)
+ {
+ if (e->OutIdx >= 0)
+ AddOutPt(e, e->Top);
+ DeleteFromAEL(e);
+ return;
+ }
+
+ TEdge* eNext = e->NextInAEL;
+ while(eNext && eNext != eMaxPair)
+ {
+ IntersectEdges(e, eNext, e->Top);
+ SwapPositionsInAEL(e, eNext);
+ eNext = e->NextInAEL;
+ }
+
+ if(e->OutIdx == Unassigned && eMaxPair->OutIdx == Unassigned)
+ {
+ DeleteFromAEL(e);
+ DeleteFromAEL(eMaxPair);
+ }
+ else if( e->OutIdx >= 0 && eMaxPair->OutIdx >= 0 )
+ {
+ if (e->OutIdx >= 0) AddLocalMaxPoly(e, eMaxPair, e->Top);
+ DeleteFromAEL(e);
+ DeleteFromAEL(eMaxPair);
+ }
+#ifdef use_lines
+ else if (e->WindDelta == 0)
+ {
+ if (e->OutIdx >= 0)
+ {
+ AddOutPt(e, e->Top);
+ e->OutIdx = Unassigned;
+ }
+ DeleteFromAEL(e);
+
+ if (eMaxPair->OutIdx >= 0)
+ {
+ AddOutPt(eMaxPair, e->Top);
+ eMaxPair->OutIdx = Unassigned;
+ }
+ DeleteFromAEL(eMaxPair);
+ }
+#endif
+ else throw clipperException("DoMaxima error");
+}
+//------------------------------------------------------------------------------
+
+void Clipper::ProcessEdgesAtTopOfScanbeam(const cInt topY)
+{
+ TEdge* e = m_ActiveEdges;
+ while( e )
+ {
+ //1. process maxima, treating them as if they're 'bent' horizontal edges,
+ // but exclude maxima with horizontal edges. nb: e can't be a horizontal.
+ bool IsMaximaEdge = IsMaxima(e, topY);
+
+ if(IsMaximaEdge)
+ {
+ TEdge* eMaxPair = GetMaximaPairEx(e);
+ IsMaximaEdge = (!eMaxPair || !IsHorizontal(*eMaxPair));
+ }
+
+ if(IsMaximaEdge)
+ {
+ if (m_StrictSimple) m_Maxima.push_back(e->Top.X);
+ TEdge* ePrev = e->PrevInAEL;
+ DoMaxima(e);
+ if( !ePrev ) e = m_ActiveEdges;
+ else e = ePrev->NextInAEL;
+ }
+ else
+ {
+ //2. promote horizontal edges, otherwise update Curr.X and Curr.Y ...
+ if (IsIntermediate(e, topY) && IsHorizontal(*e->NextInLML))
+ {
+ UpdateEdgeIntoAEL(e);
+ if (e->OutIdx >= 0)
+ AddOutPt(e, e->Bot);
+ AddEdgeToSEL(e);
+ }
+ else
+ {
+ e->Curr.X = TopX( *e, topY );
+ e->Curr.Y = topY;
+ }
+
+ //When StrictlySimple and 'e' is being touched by another edge, then
+ //make sure both edges have a vertex here ...
+ if (m_StrictSimple)
+ {
+ TEdge* ePrev = e->PrevInAEL;
+ if ((e->OutIdx >= 0) && (e->WindDelta != 0) && ePrev && (ePrev->OutIdx >= 0) &&
+ (ePrev->Curr.X == e->Curr.X) && (ePrev->WindDelta != 0))
+ {
+ IntPoint pt = e->Curr;
+#ifdef use_xyz
+ SetZ(pt, *ePrev, *e);
+#endif
+ OutPt* op = AddOutPt(ePrev, pt);
+ OutPt* op2 = AddOutPt(e, pt);
+ AddJoin(op, op2, pt); //StrictlySimple (type-3) join
+ }
+ }
+
+ e = e->NextInAEL;
+ }
+ }
+
+ //3. Process horizontals at the Top of the scanbeam ...
+ m_Maxima.sort();
+ ProcessHorizontals();
+ m_Maxima.clear();
+
+ //4. Promote intermediate vertices ...
+ e = m_ActiveEdges;
+ while(e)
+ {
+ if(IsIntermediate(e, topY))
+ {
+ OutPt* op = 0;
+ if( e->OutIdx >= 0 )
+ op = AddOutPt(e, e->Top);
+ UpdateEdgeIntoAEL(e);
+
+ //if output polygons share an edge, they'll need joining later ...
+ TEdge* ePrev = e->PrevInAEL;
+ TEdge* eNext = e->NextInAEL;
+ if (ePrev && ePrev->Curr.X == e->Bot.X &&
+ ePrev->Curr.Y == e->Bot.Y && op &&
+ ePrev->OutIdx >= 0 && ePrev->Curr.Y > ePrev->Top.Y &&
+ SlopesEqual(e->Curr, e->Top, ePrev->Curr, ePrev->Top, m_UseFullRange) &&
+ (e->WindDelta != 0) && (ePrev->WindDelta != 0))
+ {
+ OutPt* op2 = AddOutPt(ePrev, e->Bot);
+ AddJoin(op, op2, e->Top);
+ }
+ else if (eNext && eNext->Curr.X == e->Bot.X &&
+ eNext->Curr.Y == e->Bot.Y && op &&
+ eNext->OutIdx >= 0 && eNext->Curr.Y > eNext->Top.Y &&
+ SlopesEqual(e->Curr, e->Top, eNext->Curr, eNext->Top, m_UseFullRange) &&
+ (e->WindDelta != 0) && (eNext->WindDelta != 0))
+ {
+ OutPt* op2 = AddOutPt(eNext, e->Bot);
+ AddJoin(op, op2, e->Top);
+ }
+ }
+ e = e->NextInAEL;
+ }
+}
+//------------------------------------------------------------------------------
+
+void Clipper::FixupOutPolyline(OutRec &outrec)
+{
+ OutPt *pp = outrec.Pts;
+ OutPt *lastPP = pp->Prev;
+ while (pp != lastPP)
+ {
+ pp = pp->Next;
+ if (pp->Pt == pp->Prev->Pt)
+ {
+ if (pp == lastPP) lastPP = pp->Prev;
+ OutPt *tmpPP = pp->Prev;
+ tmpPP->Next = pp->Next;
+ pp->Next->Prev = tmpPP;
+ delete pp;
+ pp = tmpPP;
+ }
+ }
+
+ if (pp == pp->Prev)
+ {
+ DisposeOutPts(pp);
+ outrec.Pts = 0;
+ return;
+ }
+}
+//------------------------------------------------------------------------------
+
+void Clipper::FixupOutPolygon(OutRec &outrec)
+{
+ //FixupOutPolygon() - removes duplicate points and simplifies consecutive
+ //parallel edges by removing the middle vertex.
+ OutPt *lastOK = 0;
+ outrec.BottomPt = 0;
+ OutPt *pp = outrec.Pts;
+ bool preserveCol = m_PreserveCollinear || m_StrictSimple;
+
+ for (;;)
+ {
+ if (pp->Prev == pp || pp->Prev == pp->Next)
+ {
+ DisposeOutPts(pp);
+ outrec.Pts = 0;
+ return;
+ }
+
+ //test for duplicate points and collinear edges ...
+ if ((pp->Pt == pp->Next->Pt) || (pp->Pt == pp->Prev->Pt) ||
+ (SlopesEqual(pp->Prev->Pt, pp->Pt, pp->Next->Pt, m_UseFullRange) &&
+ (!preserveCol || !Pt2IsBetweenPt1AndPt3(pp->Prev->Pt, pp->Pt, pp->Next->Pt))))
+ {
+ lastOK = 0;
+ OutPt *tmp = pp;
+ pp->Prev->Next = pp->Next;
+ pp->Next->Prev = pp->Prev;
+ pp = pp->Prev;
+ delete tmp;
+ }
+ else if (pp == lastOK) break;
+ else
+ {
+ if (!lastOK) lastOK = pp;
+ pp = pp->Next;
+ }
+ }
+ outrec.Pts = pp;
+}
+//------------------------------------------------------------------------------
+
+int PointCount(OutPt *Pts)
+{
+ if (!Pts) return 0;
+ int result = 0;
+ OutPt* p = Pts;
+ do
+ {
+ result++;
+ p = p->Next;
+ }
+ while (p != Pts);
+ return result;
+}
+//------------------------------------------------------------------------------
+
+void Clipper::BuildResult(Paths &polys)
+{
+ polys.reserve(m_PolyOuts.size());
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i)
+ {
+ if (!m_PolyOuts[i]->Pts) continue;
+ Path pg;
+ OutPt* p = m_PolyOuts[i]->Pts->Prev;
+ int cnt = PointCount(p);
+ if (cnt < 2) continue;
+ pg.reserve(cnt);
+    for (int j = 0; j < cnt; ++j)
+ {
+ pg.push_back(p->Pt);
+ p = p->Prev;
+ }
+ polys.push_back(pg);
+ }
+}
+//------------------------------------------------------------------------------
+
+void Clipper::BuildResult2(PolyTree& polytree)
+{
+ polytree.Clear();
+ polytree.AllNodes.reserve(m_PolyOuts.size());
+ //add each output polygon/contour to polytree ...
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); i++)
+ {
+ OutRec* outRec = m_PolyOuts[i];
+ int cnt = PointCount(outRec->Pts);
+ if ((outRec->IsOpen && cnt < 2) || (!outRec->IsOpen && cnt < 3)) continue;
+ FixHoleLinkage(*outRec);
+ PolyNode* pn = new PolyNode();
+ //nb: polytree takes ownership of all the PolyNodes
+ polytree.AllNodes.push_back(pn);
+ outRec->PolyNd = pn;
+ pn->Parent = 0;
+ pn->Index = 0;
+ pn->Contour.reserve(cnt);
+ OutPt *op = outRec->Pts->Prev;
+ for (int j = 0; j < cnt; j++)
+ {
+ pn->Contour.push_back(op->Pt);
+ op = op->Prev;
+ }
+ }
+
+ //fixup PolyNode links etc ...
+ polytree.Childs.reserve(m_PolyOuts.size());
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); i++)
+ {
+ OutRec* outRec = m_PolyOuts[i];
+ if (!outRec->PolyNd) continue;
+ if (outRec->IsOpen)
+ {
+ outRec->PolyNd->m_IsOpen = true;
+ polytree.AddChild(*outRec->PolyNd);
+ }
+ else if (outRec->FirstLeft && outRec->FirstLeft->PolyNd)
+ outRec->FirstLeft->PolyNd->AddChild(*outRec->PolyNd);
+ else
+ polytree.AddChild(*outRec->PolyNd);
+ }
+}
+//------------------------------------------------------------------------------
+
+void SwapIntersectNodes(IntersectNode &int1, IntersectNode &int2)
+{
+ //just swap the contents (because fIntersectNodes is a single-linked-list)
+ IntersectNode inode = int1; //gets a copy of Int1
+ int1.Edge1 = int2.Edge1;
+ int1.Edge2 = int2.Edge2;
+ int1.Pt = int2.Pt;
+ int2.Edge1 = inode.Edge1;
+ int2.Edge2 = inode.Edge2;
+ int2.Pt = inode.Pt;
+}
+//------------------------------------------------------------------------------
+
+inline bool E2InsertsBeforeE1(TEdge &e1, TEdge &e2)
+{
+ if (e2.Curr.X == e1.Curr.X)
+ {
+ if (e2.Top.Y > e1.Top.Y)
+ return e2.Top.X < TopX(e1, e2.Top.Y);
+ else return e1.Top.X > TopX(e2, e1.Top.Y);
+ }
+ else return e2.Curr.X < e1.Curr.X;
+}
+//------------------------------------------------------------------------------
+
+bool GetOverlap(const cInt a1, const cInt a2, const cInt b1, const cInt b2,
+ cInt& Left, cInt& Right)
+{
+ if (a1 < a2)
+ {
+ if (b1 < b2) {Left = std::max(a1,b1); Right = std::min(a2,b2);}
+ else {Left = std::max(a1,b2); Right = std::min(a2,b1);}
+ }
+ else
+ {
+ if (b1 < b2) {Left = std::max(a2,b1); Right = std::min(a1,b2);}
+ else {Left = std::max(a2,b2); Right = std::min(a1,b1);}
+ }
+ return Left < Right;
+}
+//------------------------------------------------------------------------------
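+//Usage sketch (illustrative only; the literal values below are arbitrary
+//example inputs): GetOverlap reports the common interval of two ranges on
+//the same axis, regardless of the order of each range's endpoints, e.g.
+//  cInt l, r;
+//  GetOverlap(2, 8, 12, 5, l, r);  //l == 5, r == 8, returns true
+//  GetOverlap(2, 4, 5, 9, l, r);   //l == 5, r == 4, returns false (no overlap)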
+
+inline void UpdateOutPtIdxs(OutRec& outrec)
+{
+ OutPt* op = outrec.Pts;
+ do
+ {
+ op->Idx = outrec.Idx;
+ op = op->Prev;
+ }
+ while(op != outrec.Pts);
+}
+//------------------------------------------------------------------------------
+
+void Clipper::InsertEdgeIntoAEL(TEdge *edge, TEdge* startEdge)
+{
+ if(!m_ActiveEdges)
+ {
+ edge->PrevInAEL = 0;
+ edge->NextInAEL = 0;
+ m_ActiveEdges = edge;
+ }
+ else if(!startEdge && E2InsertsBeforeE1(*m_ActiveEdges, *edge))
+ {
+ edge->PrevInAEL = 0;
+ edge->NextInAEL = m_ActiveEdges;
+ m_ActiveEdges->PrevInAEL = edge;
+ m_ActiveEdges = edge;
+ }
+ else
+ {
+ if(!startEdge) startEdge = m_ActiveEdges;
+ while(startEdge->NextInAEL &&
+ !E2InsertsBeforeE1(*startEdge->NextInAEL , *edge))
+ startEdge = startEdge->NextInAEL;
+ edge->NextInAEL = startEdge->NextInAEL;
+ if(startEdge->NextInAEL) startEdge->NextInAEL->PrevInAEL = edge;
+ edge->PrevInAEL = startEdge;
+ startEdge->NextInAEL = edge;
+ }
+}
+//----------------------------------------------------------------------
+
+OutPt* DupOutPt(OutPt* outPt, bool InsertAfter)
+{
+ OutPt* result = new OutPt;
+ result->Pt = outPt->Pt;
+ result->Idx = outPt->Idx;
+ if (InsertAfter)
+ {
+ result->Next = outPt->Next;
+ result->Prev = outPt;
+ outPt->Next->Prev = result;
+ outPt->Next = result;
+ }
+ else
+ {
+ result->Prev = outPt->Prev;
+ result->Next = outPt;
+ outPt->Prev->Next = result;
+ outPt->Prev = result;
+ }
+ return result;
+}
+//------------------------------------------------------------------------------
+
+bool JoinHorz(OutPt* op1, OutPt* op1b, OutPt* op2, OutPt* op2b,
+ const IntPoint Pt, bool DiscardLeft)
+{
+ Direction Dir1 = (op1->Pt.X > op1b->Pt.X ? dRightToLeft : dLeftToRight);
+ Direction Dir2 = (op2->Pt.X > op2b->Pt.X ? dRightToLeft : dLeftToRight);
+ if (Dir1 == Dir2) return false;
+
+ //When DiscardLeft, we want Op1b to be on the Left of Op1, otherwise we
+ //want Op1b to be on the Right. (And likewise with Op2 and Op2b.)
+ //So, to facilitate this while inserting Op1b and Op2b ...
+ //when DiscardLeft, make sure we're AT or RIGHT of Pt before adding Op1b,
+ //otherwise make sure we're AT or LEFT of Pt. (Likewise with Op2b.)
+ if (Dir1 == dLeftToRight)
+ {
+ while (op1->Next->Pt.X <= Pt.X &&
+ op1->Next->Pt.X >= op1->Pt.X && op1->Next->Pt.Y == Pt.Y)
+ op1 = op1->Next;
+ if (DiscardLeft && (op1->Pt.X != Pt.X)) op1 = op1->Next;
+ op1b = DupOutPt(op1, !DiscardLeft);
+ if (op1b->Pt != Pt)
+ {
+ op1 = op1b;
+ op1->Pt = Pt;
+ op1b = DupOutPt(op1, !DiscardLeft);
+ }
+ }
+ else
+ {
+ while (op1->Next->Pt.X >= Pt.X &&
+ op1->Next->Pt.X <= op1->Pt.X && op1->Next->Pt.Y == Pt.Y)
+ op1 = op1->Next;
+ if (!DiscardLeft && (op1->Pt.X != Pt.X)) op1 = op1->Next;
+ op1b = DupOutPt(op1, DiscardLeft);
+ if (op1b->Pt != Pt)
+ {
+ op1 = op1b;
+ op1->Pt = Pt;
+ op1b = DupOutPt(op1, DiscardLeft);
+ }
+ }
+
+ if (Dir2 == dLeftToRight)
+ {
+ while (op2->Next->Pt.X <= Pt.X &&
+ op2->Next->Pt.X >= op2->Pt.X && op2->Next->Pt.Y == Pt.Y)
+ op2 = op2->Next;
+ if (DiscardLeft && (op2->Pt.X != Pt.X)) op2 = op2->Next;
+ op2b = DupOutPt(op2, !DiscardLeft);
+ if (op2b->Pt != Pt)
+ {
+ op2 = op2b;
+ op2->Pt = Pt;
+ op2b = DupOutPt(op2, !DiscardLeft);
+ };
+ } else
+ {
+ while (op2->Next->Pt.X >= Pt.X &&
+ op2->Next->Pt.X <= op2->Pt.X && op2->Next->Pt.Y == Pt.Y)
+ op2 = op2->Next;
+ if (!DiscardLeft && (op2->Pt.X != Pt.X)) op2 = op2->Next;
+ op2b = DupOutPt(op2, DiscardLeft);
+ if (op2b->Pt != Pt)
+ {
+ op2 = op2b;
+ op2->Pt = Pt;
+ op2b = DupOutPt(op2, DiscardLeft);
+ };
+ };
+
+ if ((Dir1 == dLeftToRight) == DiscardLeft)
+ {
+ op1->Prev = op2;
+ op2->Next = op1;
+ op1b->Next = op2b;
+ op2b->Prev = op1b;
+ }
+ else
+ {
+ op1->Next = op2;
+ op2->Prev = op1;
+ op1b->Prev = op2b;
+ op2b->Next = op1b;
+ }
+ return true;
+}
+//------------------------------------------------------------------------------
+
+bool Clipper::JoinPoints(Join *j, OutRec* outRec1, OutRec* outRec2)
+{
+ OutPt *op1 = j->OutPt1, *op1b;
+ OutPt *op2 = j->OutPt2, *op2b;
+
+ //There are 3 kinds of joins for output polygons ...
+ //1. Horizontal joins where Join.OutPt1 & Join.OutPt2 are vertices anywhere
+ //along (horizontal) collinear edges (& Join.OffPt is on the same horizontal).
+ //2. Non-horizontal joins where Join.OutPt1 & Join.OutPt2 are at the same
+ //location at the Bottom of the overlapping segment (& Join.OffPt is above).
+ //3. StrictSimple joins where edges touch but are not collinear and where
+ //Join.OutPt1, Join.OutPt2 & Join.OffPt all share the same point.
+ bool isHorizontal = (j->OutPt1->Pt.Y == j->OffPt.Y);
+
+ if (isHorizontal && (j->OffPt == j->OutPt1->Pt) &&
+ (j->OffPt == j->OutPt2->Pt))
+ {
+ //Strictly Simple join ...
+ if (outRec1 != outRec2) return false;
+ op1b = j->OutPt1->Next;
+ while (op1b != op1 && (op1b->Pt == j->OffPt))
+ op1b = op1b->Next;
+ bool reverse1 = (op1b->Pt.Y > j->OffPt.Y);
+ op2b = j->OutPt2->Next;
+ while (op2b != op2 && (op2b->Pt == j->OffPt))
+ op2b = op2b->Next;
+ bool reverse2 = (op2b->Pt.Y > j->OffPt.Y);
+ if (reverse1 == reverse2) return false;
+ if (reverse1)
+ {
+ op1b = DupOutPt(op1, false);
+ op2b = DupOutPt(op2, true);
+ op1->Prev = op2;
+ op2->Next = op1;
+ op1b->Next = op2b;
+ op2b->Prev = op1b;
+ j->OutPt1 = op1;
+ j->OutPt2 = op1b;
+ return true;
+ } else
+ {
+ op1b = DupOutPt(op1, true);
+ op2b = DupOutPt(op2, false);
+ op1->Next = op2;
+ op2->Prev = op1;
+ op1b->Prev = op2b;
+ op2b->Next = op1b;
+ j->OutPt1 = op1;
+ j->OutPt2 = op1b;
+ return true;
+ }
+ }
+ else if (isHorizontal)
+ {
+ //treat horizontal joins differently to non-horizontal joins since with
+ //them we're not yet sure where the overlapping is. OutPt1.Pt & OutPt2.Pt
+ //may be anywhere along the horizontal edge.
+ op1b = op1;
+ while (op1->Prev->Pt.Y == op1->Pt.Y && op1->Prev != op1b && op1->Prev != op2)
+ op1 = op1->Prev;
+ while (op1b->Next->Pt.Y == op1b->Pt.Y && op1b->Next != op1 && op1b->Next != op2)
+ op1b = op1b->Next;
+ if (op1b->Next == op1 || op1b->Next == op2) return false; //a flat 'polygon'
+
+ op2b = op2;
+ while (op2->Prev->Pt.Y == op2->Pt.Y && op2->Prev != op2b && op2->Prev != op1b)
+ op2 = op2->Prev;
+ while (op2b->Next->Pt.Y == op2b->Pt.Y && op2b->Next != op2 && op2b->Next != op1)
+ op2b = op2b->Next;
+ if (op2b->Next == op2 || op2b->Next == op1) return false; //a flat 'polygon'
+
+ cInt Left, Right;
+    //Op1 --> Op1b & Op2 --> Op2b are the extremities of the horizontal edges
+ if (!GetOverlap(op1->Pt.X, op1b->Pt.X, op2->Pt.X, op2b->Pt.X, Left, Right))
+ return false;
+
+    //DiscardLeftSide: when overlapping edges are joined, a spike will be created
+    //which needs to be cleaned up. However, we don't want Op1 or Op2 caught up
+    //on the discarded side as either may still be needed for other joins ...
+ IntPoint Pt;
+ bool DiscardLeftSide;
+ if (op1->Pt.X >= Left && op1->Pt.X <= Right)
+ {
+ Pt = op1->Pt; DiscardLeftSide = (op1->Pt.X > op1b->Pt.X);
+ }
+    else if (op2->Pt.X >= Left && op2->Pt.X <= Right)
+ {
+ Pt = op2->Pt; DiscardLeftSide = (op2->Pt.X > op2b->Pt.X);
+ }
+ else if (op1b->Pt.X >= Left && op1b->Pt.X <= Right)
+ {
+ Pt = op1b->Pt; DiscardLeftSide = op1b->Pt.X > op1->Pt.X;
+ }
+ else
+ {
+ Pt = op2b->Pt; DiscardLeftSide = (op2b->Pt.X > op2->Pt.X);
+ }
+ j->OutPt1 = op1; j->OutPt2 = op2;
+ return JoinHorz(op1, op1b, op2, op2b, Pt, DiscardLeftSide);
+ } else
+ {
+ //nb: For non-horizontal joins ...
+ // 1. Jr.OutPt1.Pt.Y == Jr.OutPt2.Pt.Y
+    //    2. Jr.OutPt1.Pt.Y > Jr.OffPt.Y
+
+ //make sure the polygons are correctly oriented ...
+ op1b = op1->Next;
+ while ((op1b->Pt == op1->Pt) && (op1b != op1)) op1b = op1b->Next;
+ bool Reverse1 = ((op1b->Pt.Y > op1->Pt.Y) ||
+ !SlopesEqual(op1->Pt, op1b->Pt, j->OffPt, m_UseFullRange));
+ if (Reverse1)
+ {
+ op1b = op1->Prev;
+ while ((op1b->Pt == op1->Pt) && (op1b != op1)) op1b = op1b->Prev;
+ if ((op1b->Pt.Y > op1->Pt.Y) ||
+ !SlopesEqual(op1->Pt, op1b->Pt, j->OffPt, m_UseFullRange)) return false;
+ };
+ op2b = op2->Next;
+ while ((op2b->Pt == op2->Pt) && (op2b != op2))op2b = op2b->Next;
+ bool Reverse2 = ((op2b->Pt.Y > op2->Pt.Y) ||
+ !SlopesEqual(op2->Pt, op2b->Pt, j->OffPt, m_UseFullRange));
+ if (Reverse2)
+ {
+ op2b = op2->Prev;
+ while ((op2b->Pt == op2->Pt) && (op2b != op2)) op2b = op2b->Prev;
+ if ((op2b->Pt.Y > op2->Pt.Y) ||
+ !SlopesEqual(op2->Pt, op2b->Pt, j->OffPt, m_UseFullRange)) return false;
+ }
+
+ if ((op1b == op1) || (op2b == op2) || (op1b == op2b) ||
+ ((outRec1 == outRec2) && (Reverse1 == Reverse2))) return false;
+
+ if (Reverse1)
+ {
+ op1b = DupOutPt(op1, false);
+ op2b = DupOutPt(op2, true);
+ op1->Prev = op2;
+ op2->Next = op1;
+ op1b->Next = op2b;
+ op2b->Prev = op1b;
+ j->OutPt1 = op1;
+ j->OutPt2 = op1b;
+ return true;
+ } else
+ {
+ op1b = DupOutPt(op1, true);
+ op2b = DupOutPt(op2, false);
+ op1->Next = op2;
+ op2->Prev = op1;
+ op1b->Prev = op2b;
+ op2b->Next = op1b;
+ j->OutPt1 = op1;
+ j->OutPt2 = op1b;
+ return true;
+ }
+ }
+}
+//----------------------------------------------------------------------
+
+static OutRec* ParseFirstLeft(OutRec* FirstLeft)
+{
+ while (FirstLeft && !FirstLeft->Pts)
+ FirstLeft = FirstLeft->FirstLeft;
+ return FirstLeft;
+}
+//------------------------------------------------------------------------------
+
+void Clipper::FixupFirstLefts1(OutRec* OldOutRec, OutRec* NewOutRec)
+{
+ //tests if NewOutRec contains the polygon before reassigning FirstLeft
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i)
+ {
+ OutRec* outRec = m_PolyOuts[i];
+ OutRec* firstLeft = ParseFirstLeft(outRec->FirstLeft);
+ if (outRec->Pts && firstLeft == OldOutRec)
+ {
+ if (Poly2ContainsPoly1(outRec->Pts, NewOutRec->Pts))
+ outRec->FirstLeft = NewOutRec;
+ }
+ }
+}
+//----------------------------------------------------------------------
+
+void Clipper::FixupFirstLefts2(OutRec* InnerOutRec, OutRec* OuterOutRec)
+{
+ //A polygon has split into two such that one is now the inner of the other.
+ //It's possible that these polygons now wrap around other polygons, so check
+ //every polygon that's also contained by OuterOutRec's FirstLeft container
+ //(including 0) to see if they've become inner to the new inner polygon ...
+ OutRec* orfl = OuterOutRec->FirstLeft;
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i)
+ {
+ OutRec* outRec = m_PolyOuts[i];
+
+ if (!outRec->Pts || outRec == OuterOutRec || outRec == InnerOutRec)
+ continue;
+ OutRec* firstLeft = ParseFirstLeft(outRec->FirstLeft);
+ if (firstLeft != orfl && firstLeft != InnerOutRec && firstLeft != OuterOutRec)
+ continue;
+ if (Poly2ContainsPoly1(outRec->Pts, InnerOutRec->Pts))
+ outRec->FirstLeft = InnerOutRec;
+ else if (Poly2ContainsPoly1(outRec->Pts, OuterOutRec->Pts))
+ outRec->FirstLeft = OuterOutRec;
+ else if (outRec->FirstLeft == InnerOutRec || outRec->FirstLeft == OuterOutRec)
+ outRec->FirstLeft = orfl;
+ }
+}
+//----------------------------------------------------------------------
+void Clipper::FixupFirstLefts3(OutRec* OldOutRec, OutRec* NewOutRec)
+{
+ //reassigns FirstLeft WITHOUT testing if NewOutRec contains the polygon
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i)
+ {
+ OutRec* outRec = m_PolyOuts[i];
+ OutRec* firstLeft = ParseFirstLeft(outRec->FirstLeft);
+        if (outRec->Pts && firstLeft == OldOutRec)
+ outRec->FirstLeft = NewOutRec;
+ }
+}
+//----------------------------------------------------------------------
+
+void Clipper::JoinCommonEdges()
+{
+ for (JoinList::size_type i = 0; i < m_Joins.size(); i++)
+ {
+ Join* join = m_Joins[i];
+
+ OutRec *outRec1 = GetOutRec(join->OutPt1->Idx);
+ OutRec *outRec2 = GetOutRec(join->OutPt2->Idx);
+
+ if (!outRec1->Pts || !outRec2->Pts) continue;
+ if (outRec1->IsOpen || outRec2->IsOpen) continue;
+
+ //get the polygon fragment with the correct hole state (FirstLeft)
+ //before calling JoinPoints() ...
+ OutRec *holeStateRec;
+ if (outRec1 == outRec2) holeStateRec = outRec1;
+ else if (OutRec1RightOfOutRec2(outRec1, outRec2)) holeStateRec = outRec2;
+ else if (OutRec1RightOfOutRec2(outRec2, outRec1)) holeStateRec = outRec1;
+ else holeStateRec = GetLowermostRec(outRec1, outRec2);
+
+ if (!JoinPoints(join, outRec1, outRec2)) continue;
+
+ if (outRec1 == outRec2)
+ {
+ //instead of joining two polygons, we've just created a new one by
+ //splitting one polygon into two.
+ outRec1->Pts = join->OutPt1;
+ outRec1->BottomPt = 0;
+ outRec2 = CreateOutRec();
+ outRec2->Pts = join->OutPt2;
+
+ //update all OutRec2.Pts Idx's ...
+ UpdateOutPtIdxs(*outRec2);
+
+ if (Poly2ContainsPoly1(outRec2->Pts, outRec1->Pts))
+ {
+ //outRec1 contains outRec2 ...
+ outRec2->IsHole = !outRec1->IsHole;
+ outRec2->FirstLeft = outRec1;
+
+ if (m_UsingPolyTree) FixupFirstLefts2(outRec2, outRec1);
+
+ if ((outRec2->IsHole ^ m_ReverseOutput) == (Area(*outRec2) > 0))
+ ReversePolyPtLinks(outRec2->Pts);
+
+ } else if (Poly2ContainsPoly1(outRec1->Pts, outRec2->Pts))
+ {
+ //outRec2 contains outRec1 ...
+ outRec2->IsHole = outRec1->IsHole;
+ outRec1->IsHole = !outRec2->IsHole;
+ outRec2->FirstLeft = outRec1->FirstLeft;
+ outRec1->FirstLeft = outRec2;
+
+ if (m_UsingPolyTree) FixupFirstLefts2(outRec1, outRec2);
+
+ if ((outRec1->IsHole ^ m_ReverseOutput) == (Area(*outRec1) > 0))
+ ReversePolyPtLinks(outRec1->Pts);
+ }
+ else
+ {
+ //the 2 polygons are completely separate ...
+ outRec2->IsHole = outRec1->IsHole;
+ outRec2->FirstLeft = outRec1->FirstLeft;
+
+ //fixup FirstLeft pointers that may need reassigning to OutRec2
+ if (m_UsingPolyTree) FixupFirstLefts1(outRec1, outRec2);
+ }
+
+ } else
+ {
+ //joined 2 polygons together ...
+
+ outRec2->Pts = 0;
+ outRec2->BottomPt = 0;
+ outRec2->Idx = outRec1->Idx;
+
+ outRec1->IsHole = holeStateRec->IsHole;
+ if (holeStateRec == outRec2)
+ outRec1->FirstLeft = outRec2->FirstLeft;
+ outRec2->FirstLeft = outRec1;
+
+ if (m_UsingPolyTree) FixupFirstLefts3(outRec2, outRec1);
+ }
+ }
+}
+
+//------------------------------------------------------------------------------
+// ClipperOffset support functions ...
+//------------------------------------------------------------------------------
+
+DoublePoint GetUnitNormal(const IntPoint &pt1, const IntPoint &pt2)
+{
+ if(pt2.X == pt1.X && pt2.Y == pt1.Y)
+ return DoublePoint(0, 0);
+
+ double Dx = (double)(pt2.X - pt1.X);
+ double dy = (double)(pt2.Y - pt1.Y);
+  double f = 1.0 / std::sqrt(Dx*Dx + dy*dy);
+ Dx *= f;
+ dy *= f;
+ return DoublePoint(dy, -Dx);
+}
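+//Worked example (illustrative only; the points are arbitrary example values):
+//for the edge (0,0) -> (0,10), Dx = 0, dy = 10 and f = 0.1, so the returned
+//normal is (dy*f, -Dx*f) = (1, 0), a unit vector perpendicular to the edge.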
+
+//------------------------------------------------------------------------------
+// ClipperOffset class
+//------------------------------------------------------------------------------
+
+ClipperOffset::ClipperOffset(double miterLimit, double arcTolerance)
+{
+ this->MiterLimit = miterLimit;
+ this->ArcTolerance = arcTolerance;
+ m_lowest.X = -1;
+}
+//------------------------------------------------------------------------------
+
+ClipperOffset::~ClipperOffset()
+{
+ Clear();
+}
+//------------------------------------------------------------------------------
+
+void ClipperOffset::Clear()
+{
+ for (int i = 0; i < m_polyNodes.ChildCount(); ++i)
+ delete m_polyNodes.Childs[i];
+ m_polyNodes.Childs.clear();
+ m_lowest.X = -1;
+}
+//------------------------------------------------------------------------------
+
+void ClipperOffset::AddPath(const Path& path, JoinType joinType, EndType endType)
+{
+ int highI = (int)path.size() - 1;
+ if (highI < 0) return;
+ PolyNode* newNode = new PolyNode();
+ newNode->m_jointype = joinType;
+ newNode->m_endtype = endType;
+
+ //strip duplicate points from path and also get index to the lowest point ...
+ if (endType == etClosedLine || endType == etClosedPolygon)
+ while (highI > 0 && path[0] == path[highI]) highI--;
+ newNode->Contour.reserve(highI + 1);
+ newNode->Contour.push_back(path[0]);
+ int j = 0, k = 0;
+ for (int i = 1; i <= highI; i++)
+ if (newNode->Contour[j] != path[i])
+ {
+ j++;
+ newNode->Contour.push_back(path[i]);
+ if (path[i].Y > newNode->Contour[k].Y ||
+ (path[i].Y == newNode->Contour[k].Y &&
+ path[i].X < newNode->Contour[k].X)) k = j;
+ }
+ if (endType == etClosedPolygon && j < 2)
+ {
+ delete newNode;
+ return;
+ }
+ m_polyNodes.AddChild(*newNode);
+
+ //if this path's lowest pt is lower than all the others then update m_lowest
+ if (endType != etClosedPolygon) return;
+ if (m_lowest.X < 0)
+ m_lowest = IntPoint(m_polyNodes.ChildCount() - 1, k);
+ else
+ {
+ IntPoint ip = m_polyNodes.Childs[(int)m_lowest.X]->Contour[(int)m_lowest.Y];
+ if (newNode->Contour[k].Y > ip.Y ||
+ (newNode->Contour[k].Y == ip.Y &&
+ newNode->Contour[k].X < ip.X))
+ m_lowest = IntPoint(m_polyNodes.ChildCount() - 1, k);
+ }
+}
+//------------------------------------------------------------------------------
+
+void ClipperOffset::AddPaths(const Paths& paths, JoinType joinType, EndType endType)
+{
+ for (Paths::size_type i = 0; i < paths.size(); ++i)
+ AddPath(paths[i], joinType, endType);
+}
+//------------------------------------------------------------------------------
+
+void ClipperOffset::FixOrientations()
+{
+ //fixup orientations of all closed paths if the orientation of the
+ //closed path with the lowermost vertex is wrong ...
+ if (m_lowest.X >= 0 &&
+ !Orientation(m_polyNodes.Childs[(int)m_lowest.X]->Contour))
+ {
+ for (int i = 0; i < m_polyNodes.ChildCount(); ++i)
+ {
+ PolyNode& node = *m_polyNodes.Childs[i];
+ if (node.m_endtype == etClosedPolygon ||
+ (node.m_endtype == etClosedLine && Orientation(node.Contour)))
+ ReversePath(node.Contour);
+ }
+ } else
+ {
+ for (int i = 0; i < m_polyNodes.ChildCount(); ++i)
+ {
+ PolyNode& node = *m_polyNodes.Childs[i];
+ if (node.m_endtype == etClosedLine && !Orientation(node.Contour))
+ ReversePath(node.Contour);
+ }
+ }
+}
+//------------------------------------------------------------------------------
+
+void ClipperOffset::Execute(Paths& solution, double delta)
+{
+ solution.clear();
+ FixOrientations();
+ DoOffset(delta);
+
+ //now clean up 'corners' ...
+ Clipper clpr;
+ clpr.AddPaths(m_destPolys, ptSubject, true);
+ if (delta > 0)
+ {
+ clpr.Execute(ctUnion, solution, pftPositive, pftPositive);
+ }
+ else
+ {
+ IntRect r = clpr.GetBounds();
+ Path outer(4);
+ outer[0] = IntPoint(r.left - 10, r.bottom + 10);
+ outer[1] = IntPoint(r.right + 10, r.bottom + 10);
+ outer[2] = IntPoint(r.right + 10, r.top - 10);
+ outer[3] = IntPoint(r.left - 10, r.top - 10);
+
+ clpr.AddPath(outer, ptSubject, true);
+ clpr.ReverseSolution(true);
+ clpr.Execute(ctUnion, solution, pftNegative, pftNegative);
+ if (solution.size() > 0) solution.erase(solution.begin());
+ }
+}
+//------------------------------------------------------------------------------
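+//Usage sketch (illustrative only; the square coordinates, the delta of 5 and
+//the constructor arguments 2.0 / 0.25 (miter limit, arc tolerance) are
+//arbitrary example values): inflating a closed square with round joins.
+//
+//  ClipperOffset co(2.0, 0.25);
+//  Path square;
+//  square << IntPoint(0,0) << IntPoint(100,0) << IntPoint(100,100) << IntPoint(0,100);
+//  co.AddPath(square, jtRound, etClosedPolygon);
+//  Paths solution;
+//  co.Execute(solution, 5.0);  //solution should hold the square grown by 5 with rounded corners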
+
+void ClipperOffset::Execute(PolyTree& solution, double delta)
+{
+ solution.Clear();
+ FixOrientations();
+ DoOffset(delta);
+
+ //now clean up 'corners' ...
+ Clipper clpr;
+ clpr.AddPaths(m_destPolys, ptSubject, true);
+ if (delta > 0)
+ {
+ clpr.Execute(ctUnion, solution, pftPositive, pftPositive);
+ }
+ else
+ {
+ IntRect r = clpr.GetBounds();
+ Path outer(4);
+ outer[0] = IntPoint(r.left - 10, r.bottom + 10);
+ outer[1] = IntPoint(r.right + 10, r.bottom + 10);
+ outer[2] = IntPoint(r.right + 10, r.top - 10);
+ outer[3] = IntPoint(r.left - 10, r.top - 10);
+
+ clpr.AddPath(outer, ptSubject, true);
+ clpr.ReverseSolution(true);
+ clpr.Execute(ctUnion, solution, pftNegative, pftNegative);
+ //remove the outer PolyNode rectangle ...
+ if (solution.ChildCount() == 1 && solution.Childs[0]->ChildCount() > 0)
+ {
+ PolyNode* outerNode = solution.Childs[0];
+ solution.Childs.reserve(outerNode->ChildCount());
+ solution.Childs[0] = outerNode->Childs[0];
+ solution.Childs[0]->Parent = outerNode->Parent;
+ for (int i = 1; i < outerNode->ChildCount(); ++i)
+ solution.AddChild(*outerNode->Childs[i]);
+ }
+ else
+ solution.Clear();
+ }
+}
+//------------------------------------------------------------------------------
+
+void ClipperOffset::DoOffset(double delta)
+{
+ m_destPolys.clear();
+ m_delta = delta;
+
+ //if Zero offset, just copy any CLOSED polygons to m_p and return ...
+  //if Zero offset, just copy any CLOSED polygons to m_destPolys and return ...
+ {
+ m_destPolys.reserve(m_polyNodes.ChildCount());
+ for (int i = 0; i < m_polyNodes.ChildCount(); i++)
+ {
+ PolyNode& node = *m_polyNodes.Childs[i];
+ if (node.m_endtype == etClosedPolygon)
+ m_destPolys.push_back(node.Contour);
+ }
+ return;
+ }
+
+ //see offset_triginometry3.svg in the documentation folder ...
+ if (MiterLimit > 2) m_miterLim = 2/(MiterLimit * MiterLimit);
+ else m_miterLim = 0.5;
+
+ double y;
+ if (ArcTolerance <= 0.0) y = def_arc_tolerance;
+ else if (ArcTolerance > std::fabs(delta) * def_arc_tolerance)
+ y = std::fabs(delta) * def_arc_tolerance;
+ else y = ArcTolerance;
+ //see offset_triginometry2.svg in the documentation folder ...
+ double steps = pi / std::acos(1 - y / std::fabs(delta));
+ if (steps > std::fabs(delta) * pi)
+ steps = std::fabs(delta) * pi; //ie excessive precision check
+ m_sin = std::sin(two_pi / steps);
+ m_cos = std::cos(two_pi / steps);
+ m_StepsPerRad = steps / two_pi;
+ if (delta < 0.0) m_sin = -m_sin;
+
+ m_destPolys.reserve(m_polyNodes.ChildCount() * 2);
+ for (int i = 0; i < m_polyNodes.ChildCount(); i++)
+ {
+ PolyNode& node = *m_polyNodes.Childs[i];
+ m_srcPoly = node.Contour;
+
+ int len = (int)m_srcPoly.size();
+ if (len == 0 || (delta <= 0 && (len < 3 || node.m_endtype != etClosedPolygon)))
+ continue;
+
+ m_destPoly.clear();
+ if (len == 1)
+ {
+ if (node.m_jointype == jtRound)
+ {
+ double X = 1.0, Y = 0.0;
+ for (cInt j = 1; j <= steps; j++)
+ {
+ m_destPoly.push_back(IntPoint(
+ Round(m_srcPoly[0].X + X * delta),
+ Round(m_srcPoly[0].Y + Y * delta)));
+ double X2 = X;
+ X = X * m_cos - m_sin * Y;
+ Y = X2 * m_sin + Y * m_cos;
+ }
+ }
+ else
+ {
+ double X = -1.0, Y = -1.0;
+ for (int j = 0; j < 4; ++j)
+ {
+ m_destPoly.push_back(IntPoint(
+ Round(m_srcPoly[0].X + X * delta),
+ Round(m_srcPoly[0].Y + Y * delta)));
+ if (X < 0) X = 1;
+ else if (Y < 0) Y = 1;
+ else X = -1;
+ }
+ }
+ m_destPolys.push_back(m_destPoly);
+ continue;
+ }
+ //build m_normals ...
+ m_normals.clear();
+ m_normals.reserve(len);
+ for (int j = 0; j < len - 1; ++j)
+ m_normals.push_back(GetUnitNormal(m_srcPoly[j], m_srcPoly[j + 1]));
+ if (node.m_endtype == etClosedLine || node.m_endtype == etClosedPolygon)
+ m_normals.push_back(GetUnitNormal(m_srcPoly[len - 1], m_srcPoly[0]));
+ else
+ m_normals.push_back(DoublePoint(m_normals[len - 2]));
+
+ if (node.m_endtype == etClosedPolygon)
+ {
+ int k = len - 1;
+ for (int j = 0; j < len; ++j)
+ OffsetPoint(j, k, node.m_jointype);
+ m_destPolys.push_back(m_destPoly);
+ }
+ else if (node.m_endtype == etClosedLine)
+ {
+ int k = len - 1;
+ for (int j = 0; j < len; ++j)
+ OffsetPoint(j, k, node.m_jointype);
+ m_destPolys.push_back(m_destPoly);
+ m_destPoly.clear();
+ //re-build m_normals ...
+ DoublePoint n = m_normals[len -1];
+ for (int j = len - 1; j > 0; j--)
+ m_normals[j] = DoublePoint(-m_normals[j - 1].X, -m_normals[j - 1].Y);
+ m_normals[0] = DoublePoint(-n.X, -n.Y);
+ k = 0;
+ for (int j = len - 1; j >= 0; j--)
+ OffsetPoint(j, k, node.m_jointype);
+ m_destPolys.push_back(m_destPoly);
+ }
+ else
+ {
+ int k = 0;
+ for (int j = 1; j < len - 1; ++j)
+ OffsetPoint(j, k, node.m_jointype);
+
+ IntPoint pt1;
+ if (node.m_endtype == etOpenButt)
+ {
+ int j = len - 1;
+ pt1 = IntPoint((cInt)Round(m_srcPoly[j].X + m_normals[j].X *
+ delta), (cInt)Round(m_srcPoly[j].Y + m_normals[j].Y * delta));
+ m_destPoly.push_back(pt1);
+ pt1 = IntPoint((cInt)Round(m_srcPoly[j].X - m_normals[j].X *
+ delta), (cInt)Round(m_srcPoly[j].Y - m_normals[j].Y * delta));
+ m_destPoly.push_back(pt1);
+ }
+ else
+ {
+ int j = len - 1;
+ k = len - 2;
+ m_sinA = 0;
+ m_normals[j] = DoublePoint(-m_normals[j].X, -m_normals[j].Y);
+ if (node.m_endtype == etOpenSquare)
+ DoSquare(j, k);
+ else
+ DoRound(j, k);
+ }
+
+ //re-build m_normals ...
+ for (int j = len - 1; j > 0; j--)
+ m_normals[j] = DoublePoint(-m_normals[j - 1].X, -m_normals[j - 1].Y);
+ m_normals[0] = DoublePoint(-m_normals[1].X, -m_normals[1].Y);
+
+ k = len - 1;
+ for (int j = k - 1; j > 0; --j) OffsetPoint(j, k, node.m_jointype);
+
+ if (node.m_endtype == etOpenButt)
+ {
+ pt1 = IntPoint((cInt)Round(m_srcPoly[0].X - m_normals[0].X * delta),
+ (cInt)Round(m_srcPoly[0].Y - m_normals[0].Y * delta));
+ m_destPoly.push_back(pt1);
+ pt1 = IntPoint((cInt)Round(m_srcPoly[0].X + m_normals[0].X * delta),
+ (cInt)Round(m_srcPoly[0].Y + m_normals[0].Y * delta));
+ m_destPoly.push_back(pt1);
+ }
+ else
+ {
+ k = 1;
+ m_sinA = 0;
+ if (node.m_endtype == etOpenSquare)
+ DoSquare(0, 1);
+ else
+ DoRound(0, 1);
+ }
+ m_destPolys.push_back(m_destPoly);
+ }
+ }
+}
+//------------------------------------------------------------------------------
+
+void ClipperOffset::OffsetPoint(int j, int& k, JoinType jointype)
+{
+ //cross product ...
+ m_sinA = (m_normals[k].X * m_normals[j].Y - m_normals[j].X * m_normals[k].Y);
+ if (std::fabs(m_sinA * m_delta) < 1.0)
+ {
+ //dot product ...
+ double cosA = (m_normals[k].X * m_normals[j].X + m_normals[j].Y * m_normals[k].Y );
+ if (cosA > 0) // angle => 0 degrees
+ {
+ m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + m_normals[k].X * m_delta),
+ Round(m_srcPoly[j].Y + m_normals[k].Y * m_delta)));
+ return;
+ }
+ //else angle => 180 degrees
+ }
+ else if (m_sinA > 1.0) m_sinA = 1.0;
+ else if (m_sinA < -1.0) m_sinA = -1.0;
+
+ if (m_sinA * m_delta < 0)
+ {
+ m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + m_normals[k].X * m_delta),
+ Round(m_srcPoly[j].Y + m_normals[k].Y * m_delta)));
+ m_destPoly.push_back(m_srcPoly[j]);
+ m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + m_normals[j].X * m_delta),
+ Round(m_srcPoly[j].Y + m_normals[j].Y * m_delta)));
+ }
+ else
+ switch (jointype)
+ {
+ case jtMiter:
+ {
+ double r = 1 + (m_normals[j].X * m_normals[k].X +
+ m_normals[j].Y * m_normals[k].Y);
+ if (r >= m_miterLim) DoMiter(j, k, r); else DoSquare(j, k);
+ break;
+ }
+ case jtSquare: DoSquare(j, k); break;
+ case jtRound: DoRound(j, k); break;
+ }
+ k = j;
+}
+//------------------------------------------------------------------------------
+
+void ClipperOffset::DoSquare(int j, int k)
+{
+ double dx = std::tan(std::atan2(m_sinA,
+ m_normals[k].X * m_normals[j].X + m_normals[k].Y * m_normals[j].Y) / 4);
+ m_destPoly.push_back(IntPoint(
+ Round(m_srcPoly[j].X + m_delta * (m_normals[k].X - m_normals[k].Y * dx)),
+ Round(m_srcPoly[j].Y + m_delta * (m_normals[k].Y + m_normals[k].X * dx))));
+ m_destPoly.push_back(IntPoint(
+ Round(m_srcPoly[j].X + m_delta * (m_normals[j].X + m_normals[j].Y * dx)),
+ Round(m_srcPoly[j].Y + m_delta * (m_normals[j].Y - m_normals[j].X * dx))));
+}
+//------------------------------------------------------------------------------
+
+void ClipperOffset::DoMiter(int j, int k, double r)
+{
+ double q = m_delta / r;
+ m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + (m_normals[k].X + m_normals[j].X) * q),
+ Round(m_srcPoly[j].Y + (m_normals[k].Y + m_normals[j].Y) * q)));
+}
+//------------------------------------------------------------------------------
+
+void ClipperOffset::DoRound(int j, int k)
+{
+ double a = std::atan2(m_sinA,
+ m_normals[k].X * m_normals[j].X + m_normals[k].Y * m_normals[j].Y);
+ int steps = std::max((int)Round(m_StepsPerRad * std::fabs(a)), 1);
+
+ double X = m_normals[k].X, Y = m_normals[k].Y, X2;
+ for (int i = 0; i < steps; ++i)
+ {
+ m_destPoly.push_back(IntPoint(
+ Round(m_srcPoly[j].X + X * m_delta),
+ Round(m_srcPoly[j].Y + Y * m_delta)));
+ X2 = X;
+ X = X * m_cos - m_sin * Y;
+ Y = X2 * m_sin + Y * m_cos;
+ }
+ m_destPoly.push_back(IntPoint(
+ Round(m_srcPoly[j].X + m_normals[j].X * m_delta),
+ Round(m_srcPoly[j].Y + m_normals[j].Y * m_delta)));
+}
+
+//------------------------------------------------------------------------------
+// Miscellaneous public functions
+//------------------------------------------------------------------------------
+
+void Clipper::DoSimplePolygons()
+{
+ PolyOutList::size_type i = 0;
+ while (i < m_PolyOuts.size())
+ {
+ OutRec* outrec = m_PolyOuts[i++];
+ OutPt* op = outrec->Pts;
+ if (!op || outrec->IsOpen) continue;
+ do //for each Pt in Polygon until duplicate found do ...
+ {
+ OutPt* op2 = op->Next;
+ while (op2 != outrec->Pts)
+ {
+ if ((op->Pt == op2->Pt) && op2->Next != op && op2->Prev != op)
+ {
+ //split the polygon into two ...
+ OutPt* op3 = op->Prev;
+ OutPt* op4 = op2->Prev;
+ op->Prev = op4;
+ op4->Next = op;
+ op2->Prev = op3;
+ op3->Next = op2;
+
+ outrec->Pts = op;
+ OutRec* outrec2 = CreateOutRec();
+ outrec2->Pts = op2;
+ UpdateOutPtIdxs(*outrec2);
+ if (Poly2ContainsPoly1(outrec2->Pts, outrec->Pts))
+ {
+ //OutRec2 is contained by OutRec1 ...
+ outrec2->IsHole = !outrec->IsHole;
+ outrec2->FirstLeft = outrec;
+ if (m_UsingPolyTree) FixupFirstLefts2(outrec2, outrec);
+ }
+ else
+ if (Poly2ContainsPoly1(outrec->Pts, outrec2->Pts))
+ {
+ //OutRec1 is contained by OutRec2 ...
+ outrec2->IsHole = outrec->IsHole;
+ outrec->IsHole = !outrec2->IsHole;
+ outrec2->FirstLeft = outrec->FirstLeft;
+ outrec->FirstLeft = outrec2;
+ if (m_UsingPolyTree) FixupFirstLefts2(outrec, outrec2);
+ }
+ else
+ {
+ //the 2 polygons are separate ...
+ outrec2->IsHole = outrec->IsHole;
+ outrec2->FirstLeft = outrec->FirstLeft;
+ if (m_UsingPolyTree) FixupFirstLefts1(outrec, outrec2);
+ }
+ op2 = op; //ie get ready for the Next iteration
+ }
+ op2 = op2->Next;
+ }
+ op = op->Next;
+ }
+ while (op != outrec->Pts);
+ }
+}
+//------------------------------------------------------------------------------
+
+void ReversePath(Path& p)
+{
+ std::reverse(p.begin(), p.end());
+}
+//------------------------------------------------------------------------------
+
+void ReversePaths(Paths& p)
+{
+ for (Paths::size_type i = 0; i < p.size(); ++i)
+ ReversePath(p[i]);
+}
+//------------------------------------------------------------------------------
+
+void SimplifyPolygon(const Path &in_poly, Paths &out_polys, PolyFillType fillType)
+{
+ Clipper c;
+ c.StrictlySimple(true);
+ c.AddPath(in_poly, ptSubject, true);
+ c.Execute(ctUnion, out_polys, fillType, fillType);
+}
+//------------------------------------------------------------------------------
+
+void SimplifyPolygons(const Paths &in_polys, Paths &out_polys, PolyFillType fillType)
+{
+ Clipper c;
+ c.StrictlySimple(true);
+ c.AddPaths(in_polys, ptSubject, true);
+ c.Execute(ctUnion, out_polys, fillType, fillType);
+}
+//------------------------------------------------------------------------------
+
+void SimplifyPolygons(Paths &polys, PolyFillType fillType)
+{
+ SimplifyPolygons(polys, polys, fillType);
+}
+//------------------------------------------------------------------------------
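+//Usage sketch (illustrative only; the "bow-tie" coordinates are arbitrary
+//example values): splitting a self-intersecting path into simple polygons.
+//
+//  Path bowtie;
+//  bowtie << IntPoint(0,0) << IntPoint(10,10) << IntPoint(10,0) << IntPoint(0,10);
+//  Paths simple;
+//  SimplifyPolygon(bowtie, simple, pftEvenOdd);  //typically two triangles meeting at (5,5)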
+
+inline double DistanceSqrd(const IntPoint& pt1, const IntPoint& pt2)
+{
+ double Dx = ((double)pt1.X - pt2.X);
+ double dy = ((double)pt1.Y - pt2.Y);
+ return (Dx*Dx + dy*dy);
+}
+//------------------------------------------------------------------------------
+
+double DistanceFromLineSqrd(
+ const IntPoint& pt, const IntPoint& ln1, const IntPoint& ln2)
+{
+ //The equation of a line in general form (Ax + By + C = 0)
+ //given 2 points (x¹,y¹) & (x²,y²) is ...
+ //(y¹ - y²)x + (x² - x¹)y + (y² - y¹)x¹ - (x² - x¹)y¹ = 0
+ //A = (y¹ - y²); B = (x² - x¹); C = (y² - y¹)x¹ - (x² - x¹)y¹
+ //perpendicular distance of point (x³,y³) = (Ax³ + By³ + C)/Sqrt(A² + B²)
+ //see http://en.wikipedia.org/wiki/Perpendicular_distance
+ double A = double(ln1.Y - ln2.Y);
+ double B = double(ln2.X - ln1.X);
+ double C = A * ln1.X + B * ln1.Y;
+ C = A * pt.X + B * pt.Y - C;
+ return (C * C) / (A * A + B * B);
+}
+//---------------------------------------------------------------------------
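+//Worked example (illustrative only; the points are arbitrary example values):
+//pt = (3,4), ln1 = (0,0), ln2 = (10,0) gives A = 0, B = 10, C = 0, then
+//C = 0*3 + 10*4 - 0 = 40, and the result is 40*40 / (0 + 100) = 16,
+//i.e. the squared distance from (3,4) to the X axis.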
+
+bool SlopesNearCollinear(const IntPoint& pt1,
+ const IntPoint& pt2, const IntPoint& pt3, double distSqrd)
+{
+ //this function is more accurate when the point that's geometrically
+ //between the other 2 points is the one that's tested for distance.
+ //ie makes it more likely to pick up 'spikes' ...
+ if (Abs(pt1.X - pt2.X) > Abs(pt1.Y - pt2.Y))
+ {
+ if ((pt1.X > pt2.X) == (pt1.X < pt3.X))
+ return DistanceFromLineSqrd(pt1, pt2, pt3) < distSqrd;
+ else if ((pt2.X > pt1.X) == (pt2.X < pt3.X))
+ return DistanceFromLineSqrd(pt2, pt1, pt3) < distSqrd;
+ else
+ return DistanceFromLineSqrd(pt3, pt1, pt2) < distSqrd;
+ }
+ else
+ {
+ if ((pt1.Y > pt2.Y) == (pt1.Y < pt3.Y))
+ return DistanceFromLineSqrd(pt1, pt2, pt3) < distSqrd;
+ else if ((pt2.Y > pt1.Y) == (pt2.Y < pt3.Y))
+ return DistanceFromLineSqrd(pt2, pt1, pt3) < distSqrd;
+ else
+ return DistanceFromLineSqrd(pt3, pt1, pt2) < distSqrd;
+ }
+}
+//------------------------------------------------------------------------------
+
+bool PointsAreClose(IntPoint pt1, IntPoint pt2, double distSqrd)
+{
+ double Dx = (double)pt1.X - pt2.X;
+ double dy = (double)pt1.Y - pt2.Y;
+ return ((Dx * Dx) + (dy * dy) <= distSqrd);
+}
+//------------------------------------------------------------------------------
+
+OutPt* ExcludeOp(OutPt* op)
+{
+ OutPt* result = op->Prev;
+ result->Next = op->Next;
+ op->Next->Prev = result;
+ result->Idx = 0;
+ return result;
+}
+//------------------------------------------------------------------------------
+
+void CleanPolygon(const Path& in_poly, Path& out_poly, double distance)
+{
+ //distance = proximity in units/pixels below which vertices
+ //will be stripped. Default ~= sqrt(2).
+
+ size_t size = in_poly.size();
+
+ if (size == 0)
+ {
+ out_poly.clear();
+ return;
+ }
+
+ OutPt* outPts = new OutPt[size];
+ for (size_t i = 0; i < size; ++i)
+ {
+ outPts[i].Pt = in_poly[i];
+ outPts[i].Next = &outPts[(i + 1) % size];
+ outPts[i].Next->Prev = &outPts[i];
+ outPts[i].Idx = 0;
+ }
+
+ double distSqrd = distance * distance;
+ OutPt* op = &outPts[0];
+ while (op->Idx == 0 && op->Next != op->Prev)
+ {
+ if (PointsAreClose(op->Pt, op->Prev->Pt, distSqrd))
+ {
+ op = ExcludeOp(op);
+ size--;
+ }
+ else if (PointsAreClose(op->Prev->Pt, op->Next->Pt, distSqrd))
+ {
+ ExcludeOp(op->Next);
+ op = ExcludeOp(op);
+ size -= 2;
+ }
+ else if (SlopesNearCollinear(op->Prev->Pt, op->Pt, op->Next->Pt, distSqrd))
+ {
+ op = ExcludeOp(op);
+ size--;
+ }
+ else
+ {
+ op->Idx = 1;
+ op = op->Next;
+ }
+ }
+
+ if (size < 3) size = 0;
+ out_poly.resize(size);
+ for (size_t i = 0; i < size; ++i)
+ {
+ out_poly[i] = op->Pt;
+ op = op->Next;
+ }
+ delete [] outPts;
+}
+//------------------------------------------------------------------------------
+
+void CleanPolygon(Path& poly, double distance)
+{
+ CleanPolygon(poly, poly, distance);
+}
+//------------------------------------------------------------------------------
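+//Usage sketch (illustrative only; the coordinates below are arbitrary example
+//values): with the default distance (~1.415) vertices that nearly coincide
+//with, or are nearly collinear with, their neighbours are stripped.
+//
+//  Path p;
+//  p << IntPoint(0,0) << IntPoint(50,1) << IntPoint(100,0) << IntPoint(100,100) << IntPoint(0,100);
+//  CleanPolygon(p, 1.415);  //the near-collinear vertex (50,1) is removed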
+
+void CleanPolygons(const Paths& in_polys, Paths& out_polys, double distance)
+{
+ out_polys.resize(in_polys.size());
+ for (Paths::size_type i = 0; i < in_polys.size(); ++i)
+ CleanPolygon(in_polys[i], out_polys[i], distance);
+}
+//------------------------------------------------------------------------------
+
+void CleanPolygons(Paths& polys, double distance)
+{
+ CleanPolygons(polys, polys, distance);
+}
+//------------------------------------------------------------------------------
+
+void Minkowski(const Path& poly, const Path& path,
+ Paths& solution, bool isSum, bool isClosed)
+{
+ int delta = (isClosed ? 1 : 0);
+ size_t polyCnt = poly.size();
+ size_t pathCnt = path.size();
+ Paths pp;
+ pp.reserve(pathCnt);
+ if (isSum)
+ for (size_t i = 0; i < pathCnt; ++i)
+ {
+ Path p;
+ p.reserve(polyCnt);
+ for (size_t j = 0; j < poly.size(); ++j)
+ p.push_back(IntPoint(path[i].X + poly[j].X, path[i].Y + poly[j].Y));
+ pp.push_back(p);
+ }
+ else
+ for (size_t i = 0; i < pathCnt; ++i)
+ {
+ Path p;
+ p.reserve(polyCnt);
+ for (size_t j = 0; j < poly.size(); ++j)
+ p.push_back(IntPoint(path[i].X - poly[j].X, path[i].Y - poly[j].Y));
+ pp.push_back(p);
+ }
+
+ solution.clear();
+ solution.reserve((pathCnt + delta) * (polyCnt + 1));
+ for (size_t i = 0; i < pathCnt - 1 + delta; ++i)
+ for (size_t j = 0; j < polyCnt; ++j)
+ {
+ Path quad;
+ quad.reserve(4);
+ quad.push_back(pp[i % pathCnt][j % polyCnt]);
+ quad.push_back(pp[(i + 1) % pathCnt][j % polyCnt]);
+ quad.push_back(pp[(i + 1) % pathCnt][(j + 1) % polyCnt]);
+ quad.push_back(pp[i % pathCnt][(j + 1) % polyCnt]);
+ if (!Orientation(quad)) ReversePath(quad);
+ solution.push_back(quad);
+ }
+}
+//------------------------------------------------------------------------------
+
+void MinkowskiSum(const Path& pattern, const Path& path, Paths& solution, bool pathIsClosed)
+{
+ Minkowski(pattern, path, solution, true, pathIsClosed);
+ Clipper c;
+ c.AddPaths(solution, ptSubject, true);
+ c.Execute(ctUnion, solution, pftNonZero, pftNonZero);
+}
+//------------------------------------------------------------------------------
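+//Usage sketch (illustrative only; the diamond pattern and the open polyline
+//are arbitrary example values): sweeping a small diamond along an open path
+//to obtain the dilated (swept) region.
+//
+//  Path diamond, line;
+//  diamond << IntPoint(-2,0) << IntPoint(0,2) << IntPoint(2,0) << IntPoint(0,-2);
+//  line << IntPoint(0,0) << IntPoint(50,50) << IntPoint(100,0);
+//  Paths swept;
+//  MinkowskiSum(diamond, line, swept, false);  //false: 'line' is an open path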
+
+void TranslatePath(const Path& input, Path& output, const IntPoint delta)
+{
+ //precondition: input != output
+ output.resize(input.size());
+ for (size_t i = 0; i < input.size(); ++i)
+ output[i] = IntPoint(input[i].X + delta.X, input[i].Y + delta.Y);
+}
+//------------------------------------------------------------------------------
+
+void MinkowskiSum(const Path& pattern, const Paths& paths, Paths& solution, bool pathIsClosed)
+{
+ Clipper c;
+ for (size_t i = 0; i < paths.size(); ++i)
+ {
+ Paths tmp;
+ Minkowski(pattern, paths[i], tmp, true, pathIsClosed);
+ c.AddPaths(tmp, ptSubject, true);
+ if (pathIsClosed)
+ {
+ Path tmp2;
+ TranslatePath(paths[i], tmp2, pattern[0]);
+ c.AddPath(tmp2, ptClip, true);
+ }
+ }
+ c.Execute(ctUnion, solution, pftNonZero, pftNonZero);
+}
+//------------------------------------------------------------------------------
+
+void MinkowskiDiff(const Path& poly1, const Path& poly2, Paths& solution)
+{
+ Minkowski(poly1, poly2, solution, false, true);
+ Clipper c;
+ c.AddPaths(solution, ptSubject, true);
+ c.Execute(ctUnion, solution, pftNonZero, pftNonZero);
+}
+//------------------------------------------------------------------------------
+
+enum NodeType {ntAny, ntOpen, ntClosed};
+
+void AddPolyNodeToPaths(const PolyNode& polynode, NodeType nodetype, Paths& paths)
+{
+ bool match = true;
+ if (nodetype == ntClosed) match = !polynode.IsOpen();
+ else if (nodetype == ntOpen) return;
+
+ if (!polynode.Contour.empty() && match)
+ paths.push_back(polynode.Contour);
+ for (int i = 0; i < polynode.ChildCount(); ++i)
+ AddPolyNodeToPaths(*polynode.Childs[i], nodetype, paths);
+}
+//------------------------------------------------------------------------------
+
+void PolyTreeToPaths(const PolyTree& polytree, Paths& paths)
+{
+ paths.resize(0);
+ paths.reserve(polytree.Total());
+ AddPolyNodeToPaths(polytree, ntAny, paths);
+}
+//------------------------------------------------------------------------------
+
+void ClosedPathsFromPolyTree(const PolyTree& polytree, Paths& paths)
+{
+ paths.resize(0);
+ paths.reserve(polytree.Total());
+ AddPolyNodeToPaths(polytree, ntClosed, paths);
+}
+//------------------------------------------------------------------------------
+
+void OpenPathsFromPolyTree(PolyTree& polytree, Paths& paths)
+{
+ paths.resize(0);
+ paths.reserve(polytree.Total());
+ //Open paths are top level only, so ...
+ for (int i = 0; i < polytree.ChildCount(); ++i)
+ if (polytree.Childs[i]->IsOpen())
+ paths.push_back(polytree.Childs[i]->Contour);
+}
+//------------------------------------------------------------------------------
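+//Usage sketch (illustrative only; 'subject' and 'clip' stand for Paths built
+//elsewhere): extracting flat path lists from a PolyTree clipping solution.
+//
+//  Clipper c;
+//  c.AddPaths(subject, ptSubject, true);
+//  c.AddPaths(clip, ptClip, true);
+//  PolyTree tree;
+//  c.Execute(ctIntersection, tree, pftNonZero, pftNonZero);
+//  Paths closed, open;
+//  ClosedPathsFromPolyTree(tree, closed);
+//  OpenPathsFromPolyTree(tree, open);  //empty here, since only closed paths were added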
+
+std::ostream& operator <<(std::ostream &s, const IntPoint &p)
+{
+ s << "(" << p.X << "," << p.Y << ")";
+ return s;
+}
+//------------------------------------------------------------------------------
+
+std::ostream& operator <<(std::ostream &s, const Path &p)
+{
+ if (p.empty()) return s;
+ Path::size_type last = p.size() -1;
+ for (Path::size_type i = 0; i < last; i++)
+ s << "(" << p[i].X << "," << p[i].Y << "), ";
+ s << "(" << p[last].X << "," << p[last].Y << ")\n";
+ return s;
+}
+//------------------------------------------------------------------------------
+
+std::ostream& operator <<(std::ostream &s, const Paths &p)
+{
+ for (Paths::size_type i = 0; i < p.size(); i++)
+ s << p[i];
+ s << "\n";
+ return s;
+}
+//------------------------------------------------------------------------------
+
+} //ClipperLib namespace
diff --git a/mmocr/models/textdet/postprocess/include/clipper/clipper.hpp b/mmocr/models/textdet/postprocess/include/clipper/clipper.hpp
new file mode 100644
index 00000000..2ac51182
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/include/clipper/clipper.hpp
@@ -0,0 +1,402 @@
+/*******************************************************************************
+* *
+* Author : Angus Johnson *
+* Version : 6.4.0 *
+* Date : 2 July 2015 *
+* Website : http://www.angusj.com *
+* Copyright : Angus Johnson 2010-2015 *
+* *
+* License: *
+* Use, modification & distribution is subject to Boost Software License Ver 1. *
+* http://www.boost.org/LICENSE_1_0.txt *
+* *
+* Attributions: *
+* The code in this library is an extension of Bala Vatti's clipping algorithm: *
+* "A generic solution to polygon clipping" *
+* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. *
+* http://portal.acm.org/citation.cfm?id=129906 *
+* *
+* Computer graphics and geometric modeling: implementation and algorithms *
+* By Max K. Agoston *
+* Springer; 1 edition (January 4, 2005) *
+* http://books.google.com/books?q=vatti+clipping+agoston *
+* *
+* See also: *
+* "Polygon Offsetting by Computing Winding Numbers" *
+* Paper no. DETC2005-85513 pp. 565-575 *
+* ASME 2005 International Design Engineering Technical Conferences *
+* and Computers and Information in Engineering Conference (IDETC/CIE2005) *
+* September 24-28, 2005 , Long Beach, California, USA *
+* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf *
+* *
+*******************************************************************************/
+
+#ifndef clipper_hpp
+#define clipper_hpp
+
+#define CLIPPER_VERSION "6.2.6"
+
+//use_int32: When enabled 32bit ints are used instead of 64bit ints. This
+//improves performance but coordinate values are limited to the range +/- 46340
+//#define use_int32
+
+//use_xyz: adds a Z member to IntPoint. Adds a minor cost to performance.
+//#define use_xyz
+
+//use_lines: Enables line clipping. Adds a very minor cost to performance.
+#define use_lines
+
+//use_deprecated: Enables temporary support for the obsolete functions
+//#define use_deprecated
+
+#include <vector>
+#include <list>
+#include <set>
+#include <stdexcept>
+#include <cstring>
+#include <cstdlib>
+#include <ostream>
+#include <functional>
+#include <queue>
+
+namespace ClipperLib {
+
+enum ClipType { ctIntersection, ctUnion, ctDifference, ctXor };
+enum PolyType { ptSubject, ptClip };
+//By far the most widely used winding rules for polygon filling are
+//EvenOdd & NonZero (GDI, GDI+, XLib, OpenGL, Cairo, AGG, Quartz, SVG, Gr32)
+//Other rules include Positive, Negative and ABS_GTR_EQ_TWO (only in OpenGL)
+//see http://glprogramming.com/red/chapter11.html
+enum PolyFillType { pftEvenOdd, pftNonZero, pftPositive, pftNegative };
+
+#ifdef use_int32
+ typedef int cInt;
+ static cInt const loRange = 0x7FFF;
+ static cInt const hiRange = 0x7FFF;
+#else
+ typedef signed long long cInt;
+ static cInt const loRange = 0x3FFFFFFF;
+ static cInt const hiRange = 0x3FFFFFFFFFFFFFFFLL;
+ typedef signed long long long64; //used by Int128 class
+ typedef unsigned long long ulong64;
+
+#endif
+
+struct IntPoint {
+ cInt X;
+ cInt Y;
+#ifdef use_xyz
+ cInt Z;
+ IntPoint(cInt x = 0, cInt y = 0, cInt z = 0): X(x), Y(y), Z(z) {};
+#else
+ IntPoint(cInt x = 0, cInt y = 0): X(x), Y(y) {};
+#endif
+
+ friend inline bool operator== (const IntPoint& a, const IntPoint& b)
+ {
+ return a.X == b.X && a.Y == b.Y;
+ }
+ friend inline bool operator!= (const IntPoint& a, const IntPoint& b)
+ {
+ return a.X != b.X || a.Y != b.Y;
+ }
+};
+//------------------------------------------------------------------------------
+
+typedef std::vector< IntPoint > Path;
+typedef std::vector< Path > Paths;
+
+inline Path& operator <<(Path& poly, const IntPoint& p) {poly.push_back(p); return poly;}
+inline Paths& operator <<(Paths& polys, const Path& p) {polys.push_back(p); return polys;}
+
+std::ostream& operator <<(std::ostream &s, const IntPoint &p);
+std::ostream& operator <<(std::ostream &s, const Path &p);
+std::ostream& operator <<(std::ostream &s, const Paths &p);
+
+struct DoublePoint
+{
+ double X;
+ double Y;
+ DoublePoint(double x = 0, double y = 0) : X(x), Y(y) {}
+ DoublePoint(IntPoint ip) : X((double)ip.X), Y((double)ip.Y) {}
+};
+//------------------------------------------------------------------------------
+
+#ifdef use_xyz
+typedef void (*ZFillCallback)(IntPoint& e1bot, IntPoint& e1top, IntPoint& e2bot, IntPoint& e2top, IntPoint& pt);
+#endif
+
+enum InitOptions {ioReverseSolution = 1, ioStrictlySimple = 2, ioPreserveCollinear = 4};
+enum JoinType {jtSquare, jtRound, jtMiter};
+enum EndType {etClosedPolygon, etClosedLine, etOpenButt, etOpenSquare, etOpenRound};
+
+class PolyNode;
+typedef std::vector< PolyNode* > PolyNodes;
+
+class PolyNode
+{
+public:
+ PolyNode();
+ virtual ~PolyNode(){};
+ Path Contour;
+ PolyNodes Childs;
+ PolyNode* Parent;
+ PolyNode* GetNext() const;
+ bool IsHole() const;
+ bool IsOpen() const;
+ int ChildCount() const;
+private:
+ unsigned Index; //node index in Parent.Childs
+ bool m_IsOpen;
+ JoinType m_jointype;
+ EndType m_endtype;
+ PolyNode* GetNextSiblingUp() const;
+ void AddChild(PolyNode& child);
+ friend class Clipper; //to access Index
+ friend class ClipperOffset;
+};
+
+class PolyTree: public PolyNode
+{
+public:
+ ~PolyTree(){Clear();};
+ PolyNode* GetFirst() const;
+ void Clear();
+ int Total() const;
+private:
+ PolyNodes AllNodes;
+ friend class Clipper; //to access AllNodes
+};
+
+bool Orientation(const Path &poly);
+double Area(const Path &poly);
+int PointInPolygon(const IntPoint &pt, const Path &path);
+
+void SimplifyPolygon(const Path &in_poly, Paths &out_polys, PolyFillType fillType = pftEvenOdd);
+void SimplifyPolygons(const Paths &in_polys, Paths &out_polys, PolyFillType fillType = pftEvenOdd);
+void SimplifyPolygons(Paths &polys, PolyFillType fillType = pftEvenOdd);
+
+void CleanPolygon(const Path& in_poly, Path& out_poly, double distance = 1.415);
+void CleanPolygon(Path& poly, double distance = 1.415);
+void CleanPolygons(const Paths& in_polys, Paths& out_polys, double distance = 1.415);
+void CleanPolygons(Paths& polys, double distance = 1.415);
+
+void MinkowskiSum(const Path& pattern, const Path& path, Paths& solution, bool pathIsClosed);
+void MinkowskiSum(const Path& pattern, const Paths& paths, Paths& solution, bool pathIsClosed);
+void MinkowskiDiff(const Path& poly1, const Path& poly2, Paths& solution);
+
+void PolyTreeToPaths(const PolyTree& polytree, Paths& paths);
+void ClosedPathsFromPolyTree(const PolyTree& polytree, Paths& paths);
+void OpenPathsFromPolyTree(PolyTree& polytree, Paths& paths);
+
+void ReversePath(Path& p);
+void ReversePaths(Paths& p);
+
+struct IntRect { cInt left; cInt top; cInt right; cInt bottom; };
+
+//enums that are used internally ...
+enum EdgeSide { esLeft = 1, esRight = 2};
+
+//forward declarations (for stuff used internally) ...
+struct TEdge;
+struct IntersectNode;
+struct LocalMinimum;
+struct OutPt;
+struct OutRec;
+struct Join;
+
+typedef std::vector < OutRec* > PolyOutList;
+typedef std::vector < TEdge* > EdgeList;
+typedef std::vector < Join* > JoinList;
+typedef std::vector < IntersectNode* > IntersectList;
+
+//------------------------------------------------------------------------------
+
+//ClipperBase is the ancestor to the Clipper class. It should not be
+//instantiated directly. This class simply abstracts the conversion of sets of
+//polygon coordinates into edge objects that are stored in a LocalMinima list.
+class ClipperBase
+{
+public:
+ ClipperBase();
+ virtual ~ClipperBase();
+ virtual bool AddPath(const Path &pg, PolyType PolyTyp, bool Closed);
+ bool AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed);
+ virtual void Clear();
+ IntRect GetBounds();
+ bool PreserveCollinear() {return m_PreserveCollinear;};
+ void PreserveCollinear(bool value) {m_PreserveCollinear = value;};
+protected:
+ void DisposeLocalMinimaList();
+ TEdge* AddBoundsToLML(TEdge *e, bool IsClosed);
+ virtual void Reset();
+ TEdge* ProcessBound(TEdge* E, bool IsClockwise);
+ void InsertScanbeam(const cInt Y);
+ bool PopScanbeam(cInt &Y);
+ bool LocalMinimaPending();
+ bool PopLocalMinima(cInt Y, const LocalMinimum *&locMin);
+ OutRec* CreateOutRec();
+ void DisposeAllOutRecs();
+ void DisposeOutRec(PolyOutList::size_type index);
+ void SwapPositionsInAEL(TEdge *edge1, TEdge *edge2);
+ void DeleteFromAEL(TEdge *e);
+ void UpdateEdgeIntoAEL(TEdge *&e);
+
+ typedef std::vector<LocalMinimum> MinimaList;
+ MinimaList::iterator m_CurrentLM;
+ MinimaList m_MinimaList;
+
+ bool m_UseFullRange;
+ EdgeList m_edges;
+ bool m_PreserveCollinear;
+ bool m_HasOpenPaths;
+ PolyOutList m_PolyOuts;
+ TEdge *m_ActiveEdges;
+
+ typedef std::priority_queue<cInt> ScanbeamList;
+ ScanbeamList m_Scanbeam;
+};
+//------------------------------------------------------------------------------
+
+class Clipper : public virtual ClipperBase
+{
+public:
+ Clipper(int initOptions = 0);
+ bool Execute(ClipType clipType,
+ Paths &solution,
+ PolyFillType fillType = pftEvenOdd);
+ bool Execute(ClipType clipType,
+ Paths &solution,
+ PolyFillType subjFillType,
+ PolyFillType clipFillType);
+ bool Execute(ClipType clipType,
+ PolyTree &polytree,
+ PolyFillType fillType = pftEvenOdd);
+ bool Execute(ClipType clipType,
+ PolyTree &polytree,
+ PolyFillType subjFillType,
+ PolyFillType clipFillType);
+ bool ReverseSolution() { return m_ReverseOutput; };
+ void ReverseSolution(bool value) {m_ReverseOutput = value;};
+ bool StrictlySimple() {return m_StrictSimple;};
+ void StrictlySimple(bool value) {m_StrictSimple = value;};
+ //set the callback function for z value filling on intersections (otherwise Z is 0)
+#ifdef use_xyz
+ void ZFillFunction(ZFillCallback zFillFunc);
+#endif
+protected:
+ virtual bool ExecuteInternal();
+private:
+ JoinList m_Joins;
+ JoinList m_GhostJoins;
+ IntersectList m_IntersectList;
+ ClipType m_ClipType;
+ typedef std::list<cInt> MaximaList;
+ MaximaList m_Maxima;
+ TEdge *m_SortedEdges;
+ bool m_ExecuteLocked;
+ PolyFillType m_ClipFillType;
+ PolyFillType m_SubjFillType;
+ bool m_ReverseOutput;
+ bool m_UsingPolyTree;
+ bool m_StrictSimple;
+#ifdef use_xyz
+ ZFillCallback m_ZFill; //custom callback
+#endif
+ void SetWindingCount(TEdge& edge);
+ bool IsEvenOddFillType(const TEdge& edge) const;
+ bool IsEvenOddAltFillType(const TEdge& edge) const;
+ void InsertLocalMinimaIntoAEL(const cInt botY);
+ void InsertEdgeIntoAEL(TEdge *edge, TEdge* startEdge);
+ void AddEdgeToSEL(TEdge *edge);
+ bool PopEdgeFromSEL(TEdge *&edge);
+ void CopyAELToSEL();
+ void DeleteFromSEL(TEdge *e);
+ void SwapPositionsInSEL(TEdge *edge1, TEdge *edge2);
+ bool IsContributing(const TEdge& edge) const;
+ bool IsTopHorz(const cInt XPos);
+ void DoMaxima(TEdge *e);
+ void ProcessHorizontals();
+ void ProcessHorizontal(TEdge *horzEdge);
+ void AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
+ OutPt* AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
+ OutRec* GetOutRec(int idx);
+ void AppendPolygon(TEdge *e1, TEdge *e2);
+ void IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &pt);
+ OutPt* AddOutPt(TEdge *e, const IntPoint &pt);
+ OutPt* GetLastOutPt(TEdge *e);
+ bool ProcessIntersections(const cInt topY);
+ void BuildIntersectList(const cInt topY);
+ void ProcessIntersectList();
+ void ProcessEdgesAtTopOfScanbeam(const cInt topY);
+ void BuildResult(Paths& polys);
+ void BuildResult2(PolyTree& polytree);
+ void SetHoleState(TEdge *e, OutRec *outrec);
+ void DisposeIntersectNodes();
+ bool FixupIntersectionOrder();
+ void FixupOutPolygon(OutRec &outrec);
+ void FixupOutPolyline(OutRec &outrec);
+ bool IsHole(TEdge *e);
+ bool FindOwnerFromSplitRecs(OutRec &outRec, OutRec *&currOrfl);
+ void FixHoleLinkage(OutRec &outrec);
+ void AddJoin(OutPt *op1, OutPt *op2, const IntPoint offPt);
+ void ClearJoins();
+ void ClearGhostJoins();
+ void AddGhostJoin(OutPt *op, const IntPoint offPt);
+ bool JoinPoints(Join *j, OutRec* outRec1, OutRec* outRec2);
+ void JoinCommonEdges();
+ void DoSimplePolygons();
+ void FixupFirstLefts1(OutRec* OldOutRec, OutRec* NewOutRec);
+ void FixupFirstLefts2(OutRec* InnerOutRec, OutRec* OuterOutRec);
+ void FixupFirstLefts3(OutRec* OldOutRec, OutRec* NewOutRec);
+#ifdef use_xyz
+ void SetZ(IntPoint& pt, TEdge& e1, TEdge& e2);
+#endif
+};
+//------------------------------------------------------------------------------
+
+class ClipperOffset
+{
+public:
+ ClipperOffset(double miterLimit = 2.0, double roundPrecision = 0.25);
+ ~ClipperOffset();
+ void AddPath(const Path& path, JoinType joinType, EndType endType);
+ void AddPaths(const Paths& paths, JoinType joinType, EndType endType);
+ void Execute(Paths& solution, double delta);
+ void Execute(PolyTree& solution, double delta);
+ void Clear();
+ double MiterLimit;
+ double ArcTolerance;
+private:
+ Paths m_destPolys;
+ Path m_srcPoly;
+ Path m_destPoly;
+ std::vector<DoublePoint> m_normals;
+ double m_delta, m_sinA, m_sin, m_cos;
+ double m_miterLim, m_StepsPerRad;
+ IntPoint m_lowest;
+ PolyNode m_polyNodes;
+
+ void FixOrientations();
+ void DoOffset(double delta);
+ void OffsetPoint(int j, int& k, JoinType jointype);
+ void DoSquare(int j, int k);
+ void DoMiter(int j, int k, double r);
+ void DoRound(int j, int k);
+};
+//------------------------------------------------------------------------------
+
+class clipperException : public std::exception
+{
+ public:
+ clipperException(const char* description): m_descr(description) {}
+ virtual ~clipperException() throw() {}
+ virtual const char* what() const throw() {return m_descr.c_str();}
+ private:
+ std::string m_descr;
+};
+//------------------------------------------------------------------------------
+
+} //ClipperLib namespace
+
+#endif //clipper_hpp
diff --git a/mmocr/models/textdet/postprocess/include/pybind11/attr.h b/mmocr/models/textdet/postprocess/include/pybind11/attr.h
new file mode 100644
index 00000000..8732cfe1
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/include/pybind11/attr.h
@@ -0,0 +1,492 @@
+/*
+ pybind11/attr.h: Infrastructure for processing custom
+ type and function attributes
+
+ Copyright (c) 2016 Wenzel Jakob
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "cast.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+/// \addtogroup annotations
+/// @{
+
+/// Annotation for methods
+struct is_method { handle class_; is_method(const handle &c) : class_(c) { } };
+
+/// Annotation for operators
+struct is_operator { };
+
+/// Annotation for parent scope
+struct scope { handle value; scope(const handle &s) : value(s) { } };
+
+/// Annotation for documentation
+struct doc { const char *value; doc(const char *value) : value(value) { } };
+
+/// Annotation for function names
+struct name { const char *value; name(const char *value) : value(value) { } };
+
+/// Annotation indicating that a function is an overload associated with a given "sibling"
+struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } };
+
+/// Annotation indicating that a class derives from another given type
+template <typename T> struct base {
+ PYBIND11_DEPRECATED("base() was deprecated in favor of specifying 'T' as a template argument to class_")
+ base() { }
+};
+
+/// Keep patient alive while nurse lives
+template <size_t Nurse, size_t Patient> struct keep_alive { };
+
+/// Annotation indicating that a class is involved in a multiple inheritance relationship
+struct multiple_inheritance { };
+
+/// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class
+struct dynamic_attr { };
+
+/// Annotation which enables the buffer protocol for a type
+struct buffer_protocol { };
+
+/// Annotation which requests that a special metaclass is created for a type
+struct metaclass {
+ handle value;
+
+ PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.")
+ metaclass() {}
+
+ /// Override pybind11's default metaclass
+ explicit metaclass(handle value) : value(value) { }
+};
+
+/// Annotation that marks a class as local to the module:
+struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } };
+
+/// Annotation to mark enums as an arithmetic type
+struct arithmetic { };
+
+/** \rst
+ A call policy which places one or more guard variables (``Ts...``) around the function call.
+
+ For example, this definition:
+
+ .. code-block:: cpp
+
+ m.def("foo", foo, py::call_guard());
+
+ is equivalent to the following pseudocode:
+
+ .. code-block:: cpp
+
+ m.def("foo", [](args...) {
+ T scope_guard;
+ return foo(args...); // forwarded arguments
+ });
+ \endrst */
+template <typename... Ts> struct call_guard;
+
+template <> struct call_guard<> { using type = detail::void_type; };
+
+template <typename T>
+struct call_guard<T> {
+ static_assert(std::is_default_constructible<T>::value,
+ "The guard type must be default constructible");
+
+ using type = T;
+};
+
+template <typename T, typename... Ts>
+struct call_guard<T, Ts...> {
+ struct type {
+ T guard{}; // Compose multiple guard types with left-to-right default-constructor order
+ typename call_guard<Ts...>::type next{};
+ };
+};
+
+/// @} annotations
+
+NAMESPACE_BEGIN(detail)
+/* Forward declarations */
+enum op_id : int;
+enum op_type : int;
+struct undefined_t;
+template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined_t> struct op_;
+inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret);
+
+/// Internal data structure which holds metadata about a keyword argument
+struct argument_record {
+ const char *name; ///< Argument name
+ const char *descr; ///< Human-readable version of the argument value
+ handle value; ///< Associated Python object
+ bool convert : 1; ///< True if the argument is allowed to convert when loading
+ bool none : 1; ///< True if None is allowed when loading
+
+ argument_record(const char *name, const char *descr, handle value, bool convert, bool none)
+ : name(name), descr(descr), value(value), convert(convert), none(none) { }
+};
+
+/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
+struct function_record {
+ function_record()
+ : is_constructor(false), is_new_style_constructor(false), is_stateless(false),
+ is_operator(false), has_args(false), has_kwargs(false), is_method(false) { }
+
+ /// Function name
+ char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
+
+ // User-specified documentation string
+ char *doc = nullptr;
+
+ /// Human-readable version of the function signature
+ char *signature = nullptr;
+
+ /// List of registered keyword arguments
+ std::vector<argument_record> args;
+
+ /// Pointer to lambda function which converts arguments and performs the actual call
+ handle (*impl) (function_call &) = nullptr;
+
+ /// Storage for the wrapped function pointer and captured data, if any
+ void *data[3] = { };
+
+ /// Pointer to custom destructor for 'data' (if needed)
+ void (*free_data) (function_record *ptr) = nullptr;
+
+ /// Return value policy associated with this function
+ return_value_policy policy = return_value_policy::automatic;
+
+ /// True if name == '__init__'
+ bool is_constructor : 1;
+
+ /// True if this is a new-style `__init__` defined in `detail/init.h`
+ bool is_new_style_constructor : 1;
+
+ /// True if this is a stateless function pointer
+ bool is_stateless : 1;
+
+ /// True if this is an operator (__add__), etc.
+ bool is_operator : 1;
+
+ /// True if the function has a '*args' argument
+ bool has_args : 1;
+
+ /// True if the function has a '**kwargs' argument
+ bool has_kwargs : 1;
+
+ /// True if this is a method
+ bool is_method : 1;
+
+ /// Number of arguments (including py::args and/or py::kwargs, if present)
+ std::uint16_t nargs;
+
+ /// Python method object
+ PyMethodDef *def = nullptr;
+
+ /// Python handle to the parent scope (a class or a module)
+ handle scope;
+
+ /// Python handle to the sibling function representing an overload chain
+ handle sibling;
+
+ /// Pointer to next overload
+ function_record *next = nullptr;
+};
+
+/// Special data structure which (temporarily) holds metadata about a bound class
+struct type_record {
+ PYBIND11_NOINLINE type_record()
+ : multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), default_holder(true), module_local(false) { }
+
+ /// Handle to the parent scope
+ handle scope;
+
+ /// Name of the class
+ const char *name = nullptr;
+
+ // Pointer to RTTI type_info data structure
+ const std::type_info *type = nullptr;
+
+ /// How large is the underlying C++ type?
+ size_t type_size = 0;
+
+ /// What is the alignment of the underlying C++ type?
+ size_t type_align = 0;
+
+ /// How large is the type's holder?
+ size_t holder_size = 0;
+
+ /// The global operator new can be overridden with a class-specific variant
+ void *(*operator_new)(size_t) = nullptr;
+
+ /// Function pointer to class_<..>::init_instance
+ void (*init_instance)(instance *, const void *) = nullptr;
+
+ /// Function pointer to class_<..>::dealloc
+ void (*dealloc)(detail::value_and_holder &) = nullptr;
+
+ /// List of base classes of the newly created type
+ list bases;
+
+ /// Optional docstring
+ const char *doc = nullptr;
+
+ /// Custom metaclass (optional)
+ handle metaclass;
+
+ /// Multiple inheritance marker
+ bool multiple_inheritance : 1;
+
+ /// Does the class manage a __dict__?
+ bool dynamic_attr : 1;
+
+ /// Does the class implement the buffer protocol?
+ bool buffer_protocol : 1;
+
+ /// Is the default (unique_ptr) holder type used?
+ bool default_holder : 1;
+
+ /// Is the class definition local to the module shared object?
+ bool module_local : 1;
+
+ PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
+ auto base_info = detail::get_type_info(base, false);
+ if (!base_info) {
+ std::string tname(base.name());
+ detail::clean_type_id(tname);
+ pybind11_fail("generic_type: type \"" + std::string(name) +
+ "\" referenced unknown base type \"" + tname + "\"");
+ }
+
+ if (default_holder != base_info->default_holder) {
+ std::string tname(base.name());
+ detail::clean_type_id(tname);
+ pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
+ (default_holder ? "does not have" : "has") +
+ " a non-default holder type while its base \"" + tname + "\" " +
+ (base_info->default_holder ? "does not" : "does"));
+ }
+
+ bases.append((PyObject *) base_info->type);
+
+ if (base_info->type->tp_dictoffset != 0)
+ dynamic_attr = true;
+
+ if (caster)
+ base_info->implicit_casts.emplace_back(type, caster);
+ }
+};
+
+inline function_call::function_call(const function_record &f, handle p) :
+ func(f), parent(p) {
+ args.reserve(f.nargs);
+ args_convert.reserve(f.nargs);
+}
+
+/// Tag for a new-style `__init__` defined in `detail/init.h`
+struct is_new_style_constructor { };
+
+/**
+ * Partial template specializations to process custom attributes provided to
+ * cpp_function_ and class_. These are either used to initialize the respective
+ * fields in the type_record and function_record data structures or executed at
+ * runtime to deal with custom call policies (e.g. keep_alive).
+ */
+template <typename T, typename SFINAE = void> struct process_attribute;
+
+template <typename T> struct process_attribute_default {
+ /// Default implementation: do nothing
+ static void init(const T &, function_record *) { }
+ static void init(const T &, type_record *) { }
+ static void precall(function_call &) { }
+ static void postcall(function_call &, handle) { }
+};
+
+/// Process an attribute specifying the function's name
+template <> struct process_attribute<name> : process_attribute_default<name> {
+ static void init(const name &n, function_record *r) { r->name = const_cast<char *>(n.value); }
+};
+
+/// Process an attribute specifying the function's docstring
+template <> struct process_attribute<doc> : process_attribute_default<doc> {
+ static void init(const doc &n, function_record *r) { r->doc = const_cast<char *>(n.value); }
+};
+
+/// Process an attribute specifying the function's docstring (provided as a C-style string)
+template <> struct process_attribute<const char *> : process_attribute_default<const char *> {
+ static void init(const char *d, function_record *r) { r->doc = const_cast<char *>(d); }
+ static void init(const char *d, type_record *r) { r->doc = const_cast<char *>(d); }
+};
+template <> struct process_attribute<char *> : process_attribute<const char *> { };
+
+/// Process an attribute indicating the function's return value policy
+template <> struct process_attribute<return_value_policy> : process_attribute_default<return_value_policy> {
+ static void init(const return_value_policy &p, function_record *r) { r->policy = p; }
+};
+
+/// Process an attribute which indicates that this is an overloaded function associated with a given sibling
+template <> struct process_attribute<sibling> : process_attribute_default<sibling> {
+ static void init(const sibling &s, function_record *r) { r->sibling = s.value; }
+};
+
+/// Process an attribute which indicates that this function is a method
+template <> struct process_attribute<is_method> : process_attribute_default<is_method> {
+ static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; }
+};
+
+/// Process an attribute which indicates the parent scope of a method
+template <> struct process_attribute<scope> : process_attribute_default<scope> {
+ static void init(const scope &s, function_record *r) { r->scope = s.value; }
+};
+
+/// Process an attribute which indicates that this function is an operator
+template <> struct process_attribute<is_operator> : process_attribute_default<is_operator> {
+ static void init(const is_operator &, function_record *r) { r->is_operator = true; }
+};
+
+template <> struct process_attribute<is_new_style_constructor> : process_attribute_default<is_new_style_constructor> {
+ static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; }
+};
+
+/// Process a keyword argument attribute (*without* a default value)
+template <> struct process_attribute<arg> : process_attribute_default<arg> {
+ static void init(const arg &a, function_record *r) {
+ if (r->is_method && r->args.empty())
+ r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/);
+ r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none);
+ }
+};
+
+/// Process a keyword argument attribute (*with* a default value)
+template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
+ static void init(const arg_v &a, function_record *r) {
+ if (r->is_method && r->args.empty())
+ r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/);
+
+ if (!a.value) {
+#if !defined(NDEBUG)
+ std::string descr("'");
+ if (a.name) descr += std::string(a.name) + ": ";
+ descr += a.type + "'";
+ if (r->is_method) {
+ if (r->name)
+ descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'";
+ else
+ descr += " in method of '" + (std::string) str(r->scope) + "'";
+ } else if (r->name) {
+ descr += " in function '" + (std::string) r->name + "'";
+ }
+ pybind11_fail("arg(): could not convert default argument "
+ + descr + " into a Python object (type not registered yet?)");
+#else
+ pybind11_fail("arg(): could not convert default argument "
+ "into a Python object (type not registered yet?). "
+ "Compile in debug mode for more information.");
+#endif
+ }
+ r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none);
+ }
+};
+
+/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that)
+template <typename T>
+struct process_attribute<T, enable_if_t<is_pyobject<T>::value>> : process_attribute_default<handle> {
+ static void init(const handle &h, type_record *r) { r->bases.append(h); }
+};
+
+/// Process a parent class attribute (deprecated, does not support multiple inheritance)
+template <typename T>
+struct process_attribute<base<T>> : process_attribute_default<base<T>> {
+ static void init(const base<T> &, type_record *r) { r->add_base(typeid(T), nullptr); }
+};
+
+/// Process a multiple inheritance attribute
+template <>
+struct process_attribute<multiple_inheritance> : process_attribute_default<multiple_inheritance> {
+ static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; }
+};
+
+template <>
+struct process_attribute<dynamic_attr> : process_attribute_default<dynamic_attr> {
+ static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; }
+};
+
+template <>
+struct process_attribute<buffer_protocol> : process_attribute_default<buffer_protocol> {
+ static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; }
+};
+
+template <>
+struct process_attribute<metaclass> : process_attribute_default<metaclass> {
+ static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; }
+};
+
+template <>
+struct process_attribute<module_local> : process_attribute_default<module_local> {
+ static void init(const module_local &l, type_record *r) { r->module_local = l.value; }
+};
+
+/// Process an 'arithmetic' attribute for enums (does nothing here)
+template <>
+struct process_attribute<arithmetic> : process_attribute_default<arithmetic> {};
+
+template <typename... Ts>
+struct process_attribute<call_guard<Ts...>> : process_attribute_default<call_guard<Ts...>> { };
+
+/**
+ * Process a keep_alive call policy -- invokes keep_alive_impl during the
+ * pre-call handler if both Nurse, Patient != 0 and use the post-call handler
+ * otherwise
+ */
+template <size_t Nurse, size_t Patient> struct process_attribute<keep_alive<Nurse, Patient>> : public process_attribute_default<keep_alive<Nurse, Patient>> {
+ template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
+ static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); }
+ template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
+ static void postcall(function_call &, handle) { }
+ template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
+ static void precall(function_call &) { }
+ template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
+ static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); }
+};
+
+/// Recursively iterate over variadic template arguments
+template <typename... Args> struct process_attributes {
+ static void init(const Args&... args, function_record *r) {
+ int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
+ ignore_unused(unused);
+ }
+ static void init(const Args&... args, type_record *r) {
+ int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
+ ignore_unused(unused);
+ }
+ static void precall(function_call &call) {
+ int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::precall(call), 0) ... };
+ ignore_unused(unused);
+ }
+ static void postcall(function_call &call, handle fn_ret) {
+ int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::postcall(call, fn_ret), 0) ... };
+ ignore_unused(unused);
+ }
+};
+
+template <typename T>
+using is_call_guard = is_instantiation<call_guard, T>;
+
+/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found)
+template <typename... Extra>
+using extract_guard_t = typename exactly_one_t<is_call_guard, call_guard<>, Extra...>::type;
+
+/// Check the number of named arguments at compile time
+template <typename... Extra, size_t named = constexpr_sum(std::is_base_of<arg, Extra>::value...),
+ size_t self = constexpr_sum(std::is_same<is_method, Extra>::value...)>
+constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) {
+ return named == 0 || (self + named + has_args + has_kwargs) == nargs;
+}
+
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/mmocr/models/textdet/postprocess/include/pybind11/buffer_info.h b/mmocr/models/textdet/postprocess/include/pybind11/buffer_info.h
new file mode 100644
index 00000000..9f072fa7
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/include/pybind11/buffer_info.h
@@ -0,0 +1,108 @@
+/*
+ pybind11/buffer_info.h: Python buffer object interface
+
+ Copyright (c) 2016 Wenzel Jakob
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "detail/common.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+/// Information record describing a Python buffer object
+struct buffer_info {
+ void *ptr = nullptr; // Pointer to the underlying storage
+ ssize_t itemsize = 0; // Size of individual items in bytes
+ ssize_t size = 0; // Total number of entries
+ std::string format; // For homogeneous buffers, this should be set to format_descriptor<T>::format()
+ ssize_t ndim = 0; // Number of dimensions
+ std::vector<ssize_t> shape; // Shape of the tensor (1 entry per dimension)
+ std::vector<ssize_t> strides; // Number of entries between adjacent entries (for each per dimension)
+
+ buffer_info() { }
+
+ buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
+ detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
+ : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
+ shape(std::move(shape_in)), strides(std::move(strides_in)) {
+ if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size())
+ pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
+ for (size_t i = 0; i < (size_t) ndim; ++i)
+ size *= shape[i];
+ }
+
+ template <typename T>
+ buffer_info(T *ptr, detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
+ : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor<T>::format(), static_cast<ssize_t>(shape_in->size()), std::move(shape_in), std::move(strides_in)) { }
+
+ buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size)
+ : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { }
+
+ template <typename T>
+ buffer_info(T *ptr, ssize_t size)
+ : buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) { }
+
+ explicit buffer_info(Py_buffer *view, bool ownview = true)
+ : buffer_info(view->buf, view->itemsize, view->format, view->ndim,
+ {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
+ this->view = view;
+ this->ownview = ownview;
+ }
+
+ buffer_info(const buffer_info &) = delete;
+ buffer_info& operator=(const buffer_info &) = delete;
+
+ buffer_info(buffer_info &&other) {
+ (*this) = std::move(other);
+ }
+
+ buffer_info& operator=(buffer_info &&rhs) {
+ ptr = rhs.ptr;
+ itemsize = rhs.itemsize;
+ size = rhs.size;
+ format = std::move(rhs.format);
+ ndim = rhs.ndim;
+ shape = std::move(rhs.shape);
+ strides = std::move(rhs.strides);
+ std::swap(view, rhs.view);
+ std::swap(ownview, rhs.ownview);
+ return *this;
+ }
+
+ ~buffer_info() {
+ if (view && ownview) { PyBuffer_Release(view); delete view; }
+ }
+
+private:
+ struct private_ctr_tag { };
+
+ buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
+ detail::any_container<ssize_t> &&shape_in, detail::any_container<ssize_t> &&strides_in)
+ : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { }
+
+ Py_buffer *view = nullptr;
+ bool ownview = false;
+};
+
+NAMESPACE_BEGIN(detail)
+
+template <typename T, typename SFINAE = void> struct compare_buffer_info {
+ static bool compare(const buffer_info& b) {
+ return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
+ }
+};
+
+template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
+ static bool compare(const buffer_info& b) {
+ return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
+ ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
+ ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
+ }
+};
+
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/mmocr/models/textdet/postprocess/include/pybind11/cast.h b/mmocr/models/textdet/postprocess/include/pybind11/cast.h
new file mode 100644
index 00000000..80abb2b9
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/include/pybind11/cast.h
@@ -0,0 +1,2128 @@
+/*
+ pybind11/cast.h: Partial template specializations to cast between
+ C++ and Python types
+
+ Copyright (c) 2016 Wenzel Jakob
+
+ All rights reserved. Use of this source code is governed by a
+ BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "pytypes.h"
+#include "detail/typeid.h"
+#include "detail/descr.h"
+#include "detail/internals.h"
+#include <array>
+#include <limits>
+#include <tuple>
+#include <type_traits>
+
+#if defined(PYBIND11_CPP17)
+# if defined(__has_include)
+# if __has_include()
+# define PYBIND11_HAS_STRING_VIEW
+# endif
+# elif defined(_MSC_VER)
+# define PYBIND11_HAS_STRING_VIEW
+# endif
+#endif
+#ifdef PYBIND11_HAS_STRING_VIEW
+#include <string_view>
+#endif
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+/// A life support system for temporary objects created by `type_caster::load()`.
+/// Adding a patient will keep it alive up until the enclosing function returns.
+class loader_life_support {
+public:
+ /// A new patient frame is created when a function is entered
+ loader_life_support() {
+ get_internals().loader_patient_stack.push_back(nullptr);
+ }
+
+ /// ... and destroyed after it returns
+ ~loader_life_support() {
+ auto &stack = get_internals().loader_patient_stack;
+ if (stack.empty())
+ pybind11_fail("loader_life_support: internal error");
+
+ auto ptr = stack.back();
+ stack.pop_back();
+ Py_CLEAR(ptr);
+
+ // A heuristic to reduce the stack's capacity (e.g. after long recursive calls)
+ if (stack.capacity() > 16 && stack.size() != 0 && stack.capacity() / stack.size() > 2)
+ stack.shrink_to_fit();
+ }
+
+ /// This can only be used inside a pybind11-bound function, either by `argument_loader`
+ /// at argument preparation time or by `py::cast()` at execution time.
+ PYBIND11_NOINLINE static void add_patient(handle h) {
+ auto &stack = get_internals().loader_patient_stack;
+ if (stack.empty())
+ throw cast_error("When called outside a bound function, py::cast() cannot "
+ "do Python -> C++ conversions which require the creation "
+ "of temporary values");
+
+ auto &list_ptr = stack.back();
+ if (list_ptr == nullptr) {
+ list_ptr = PyList_New(1);
+ if (!list_ptr)
+ pybind11_fail("loader_life_support: error allocating list");
+ PyList_SET_ITEM(list_ptr, 0, h.inc_ref().ptr());
+ } else {
+ auto result = PyList_Append(list_ptr, h.ptr());
+ if (result == -1)
+ pybind11_fail("loader_life_support: error adding patient");
+ }
+ }
+};
+
+// Gets the cache entry for the given type, creating it if necessary. The return value is the pair
+// returned by emplace, i.e. an iterator for the entry and a bool set to `true` if the entry was
+// just created.
+inline std::pair<decltype(internals::registered_types_py)::iterator, bool> all_type_info_get_cache(PyTypeObject *type);
+
+// Populates a just-created cache entry.
+PYBIND11_NOINLINE inline void all_type_info_populate(PyTypeObject *t, std::vector<type_info *> &bases) {
+ std::vector<PyTypeObject *> check;
+ for (handle parent : reinterpret_borrow<tuple>(t->tp_bases))
+ check.push_back((PyTypeObject *) parent.ptr());
+
+ auto const &type_dict = get_internals().registered_types_py;
+ for (size_t i = 0; i < check.size(); i++) {
+ auto type = check[i];
+ // Ignore Python2 old-style class super type:
+ if (!PyType_Check((PyObject *) type)) continue;
+
+ // Check `type` in the current set of registered python types:
+ auto it = type_dict.find(type);
+ if (it != type_dict.end()) {
+ // We found a cache entry for it, so it's either pybind-registered or has pre-computed
+ // pybind bases, but we have to make sure we haven't already seen the type(s) before: we
+ // want to follow Python/virtual C++ rules that there should only be one instance of a
+ // common base.
+ for (auto *tinfo : it->second) {
+ // NB: Could use a second set here, rather than doing a linear search, but since
+ // having a large number of immediate pybind11-registered types seems fairly
+ // unlikely, that probably isn't worthwhile.
+ bool found = false;
+ for (auto *known : bases) {
+ if (known == tinfo) { found = true; break; }
+ }
+ if (!found) bases.push_back(tinfo);
+ }
+ }
+ else if (type->tp_bases) {
+ // It's some python type, so keep follow its bases classes to look for one or more
+ // registered types
+ if (i + 1 == check.size()) {
+ // When we're at the end, we can pop off the current element to avoid growing
+ // `check` when adding just one base (which is typical--i.e. when there is no
+ // multiple inheritance)
+ check.pop_back();
+ i--;
+ }
+ for (handle parent : reinterpret_borrow<tuple>(type->tp_bases))
+ check.push_back((PyTypeObject *) parent.ptr());
+ }
+ }
+}
+
+/**
+ * Extracts vector of type_info pointers of pybind-registered roots of the given Python type. Will
+ * be just 1 pybind type for the Python type of a pybind-registered class, or for any Python-side
+ * derived class that uses single inheritance. Will contain as many types as required for a Python
+ * class that uses multiple inheritance to inherit (directly or indirectly) from multiple
+ * pybind-registered classes. Will be empty if neither the type nor any base classes are
+ * pybind-registered.
+ *
+ * The value is cached for the lifetime of the Python type.
+ */
+inline const std::vector<detail::type_info *> &all_type_info(PyTypeObject *type) {
+ auto ins = all_type_info_get_cache(type);
+ if (ins.second)
+ // New cache entry: populate it
+ all_type_info_populate(type, ins.first->second);
+
+ return ins.first->second;
+}
+
+/**
+ * Gets a single pybind11 type info for a python type. Returns nullptr if neither the type nor any
+ * ancestors are pybind11-registered. Throws an exception if there are multiple bases--use
+ * `all_type_info` instead if you want to support multiple bases.
+ */
+PYBIND11_NOINLINE inline detail::type_info* get_type_info(PyTypeObject *type) {
+ auto &bases = all_type_info(type);
+ if (bases.size() == 0)
+ return nullptr;
+ if (bases.size() > 1)
+ pybind11_fail("pybind11::detail::get_type_info: type has multiple pybind11-registered bases");
+ return bases.front();
+}
+
+inline detail::type_info *get_local_type_info(const std::type_index &tp) {
+ auto &locals = registered_local_types_cpp();
+ auto it = locals.find(tp);
+ if (it != locals.end())
+ return it->second;
+ return nullptr;
+}
+
+inline detail::type_info *get_global_type_info(const std::type_index &tp) {
+ auto &types = get_internals().registered_types_cpp;
+ auto it = types.find(tp);
+ if (it != types.end())
+ return it->second;
+ return nullptr;
+}
+
+/// Return the type info for a given C++ type; on lookup failure can either throw or return nullptr.
+PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_index &tp,
+ bool throw_if_missing = false) {
+ if (auto ltype = get_local_type_info(tp))
+ return ltype;
+ if (auto gtype = get_global_type_info(tp))
+ return gtype;
+
+ if (throw_if_missing) {
+ std::string tname = tp.name();
+ detail::clean_type_id(tname);
+ pybind11_fail("pybind11::detail::get_type_info: unable to find type info for \"" + tname + "\"");
+ }
+ return nullptr;
+}
+
+PYBIND11_NOINLINE inline handle get_type_handle(const std::type_info &tp, bool throw_if_missing) {
+ detail::type_info *type_info = get_type_info(tp, throw_if_missing);
+ return handle(type_info ? ((PyObject *) type_info->type) : nullptr);
+}
+
+struct value_and_holder {
+ instance *inst;
+ size_t index;
+ const detail::type_info *type;
+ void **vh;
+
+ // Main constructor for a found value/holder:
+ value_and_holder(instance *i, const detail::type_info *type, size_t vpos, size_t index) :
+ inst{i}, index{index}, type{type},
+ vh{inst->simple_layout ? inst->simple_value_holder : &inst->nonsimple.values_and_holders[vpos]}
+ {}
+
+ // Default constructor (used to signal a value-and-holder not found by get_value_and_holder())
+ value_and_holder() : inst{nullptr} {}
+
+ // Used for past-the-end iterator
+ value_and_holder(size_t index) : index{index} {}
+
+ template <typename V = void> V *&value_ptr() const {
+ return reinterpret_cast<V *&>(vh[0]);
+ }
+ // True if this `value_and_holder` has a non-null value pointer
+ explicit operator bool() const { return value_ptr(); }
+
+ template <typename H> H &holder() const {
+ return reinterpret_cast<H &>(vh[1]);
+ }
+ bool holder_constructed() const {
+ return inst->simple_layout
+ ? inst->simple_holder_constructed
+ : inst->nonsimple.status[index] & instance::status_holder_constructed;
+ }
+ void set_holder_constructed(bool v = true) {
+ if (inst->simple_layout)
+ inst->simple_holder_constructed = v;
+ else if (v)
+ inst->nonsimple.status[index] |= instance::status_holder_constructed;
+ else
+ inst->nonsimple.status[index] &= (uint8_t) ~instance::status_holder_constructed;
+ }
+ bool instance_registered() const {
+ return inst->simple_layout
+ ? inst->simple_instance_registered
+ : inst->nonsimple.status[index] & instance::status_instance_registered;
+ }
+ void set_instance_registered(bool v = true) {
+ if (inst->simple_layout)
+ inst->simple_instance_registered = v;
+ else if (v)
+ inst->nonsimple.status[index] |= instance::status_instance_registered;
+ else
+ inst->nonsimple.status[index] &= (uint8_t) ~instance::status_instance_registered;
+ }
+};
+
+// Container for accessing and iterating over an instance's values/holders
+struct values_and_holders {
+private:
+ instance *inst;
+ using type_vec = std::vector<detail::type_info *>;
+ const type_vec &tinfo;
+
+public:
+ values_and_holders(instance *inst) : inst{inst}, tinfo(all_type_info(Py_TYPE(inst))) {}
+
+ struct iterator {
+ private:
+ instance *inst;
+ const type_vec *types;
+ value_and_holder curr;
+ friend struct values_and_holders;
+ iterator(instance *inst, const type_vec *tinfo)
+ : inst{inst}, types{tinfo},
+ curr(inst /* instance */,
+ types->empty() ? nullptr : (*types)[0] /* type info */,
+ 0, /* vpos: (non-simple types only): the first vptr comes first */
+ 0 /* index */)
+ {}
+ // Past-the-end iterator:
+ iterator(size_t end) : curr(end) {}
+ public:
+ bool operator==(const iterator &other) { return curr.index == other.curr.index; }
+ bool operator!=(const iterator &other) { return curr.index != other.curr.index; }
+ iterator &operator++() {
+ if (!inst->simple_layout)
+ curr.vh += 1 + (*types)[curr.index]->holder_size_in_ptrs;
+ ++curr.index;
+ curr.type = curr.index < types->size() ? (*types)[curr.index] : nullptr;
+ return *this;
+ }
+ value_and_holder &operator*() { return curr; }
+ value_and_holder *operator->() { return &curr; }
+ };
+
+ iterator begin() { return iterator(inst, &tinfo); }
+ iterator end() { return iterator(tinfo.size()); }
+
+ iterator find(const type_info *find_type) {
+ auto it = begin(), endit = end();
+ while (it != endit && it->type != find_type) ++it;
+ return it;
+ }
+
+ size_t size() { return tinfo.size(); }
+};
+
+/**
+ * Extracts C++ value and holder pointer references from an instance (which may contain multiple
+ * values/holders for python-side multiple inheritance) that match the given type. Throws an error
+ * if the given type (or ValueType, if omitted) is not a pybind11 base of the given instance. If
+ * `find_type` is omitted (or explicitly specified as nullptr) the first value/holder are returned,
+ * regardless of type (and the resulting .type will be nullptr).
+ *
+ * The returned object should be short-lived: in particular, it must not outlive the called-upon
+ * instance.
+ */
+PYBIND11_NOINLINE inline value_and_holder instance::get_value_and_holder(const type_info *find_type /*= nullptr default in common.h*/, bool throw_if_missing /*= true in common.h*/) {
+ // Optimize common case:
+ if (!find_type || Py_TYPE(this) == find_type->type)
+ return value_and_holder(this, find_type, 0, 0);
+
+ detail::values_and_holders vhs(this);
+ auto it = vhs.find(find_type);
+ if (it != vhs.end())
+ return *it;
+
+ if (!throw_if_missing)
+ return value_and_holder();
+
+#if defined(NDEBUG)
+ pybind11_fail("pybind11::detail::instance::get_value_and_holder: "
+ "type is not a pybind11 base of the given instance "
+ "(compile in debug mode for type details)");
+#else
+ pybind11_fail("pybind11::detail::instance::get_value_and_holder: `" +
+ std::string(find_type->type->tp_name) + "' is not a pybind11 base of the given `" +
+ std::string(Py_TYPE(this)->tp_name) + "' instance");
+#endif
+}
+
+PYBIND11_NOINLINE inline void instance::allocate_layout() {
+ auto &tinfo = all_type_info(Py_TYPE(this));
+
+ const size_t n_types = tinfo.size();
+
+ if (n_types == 0)
+ pybind11_fail("instance allocation failed: new instance has no pybind11-registered base types");
+
+ simple_layout =
+ n_types == 1 && tinfo.front()->holder_size_in_ptrs <= instance_simple_holder_in_ptrs();
+
+ // Simple path: no python-side multiple inheritance, and a small-enough holder
+ if (simple_layout) {
+ simple_value_holder[0] = nullptr;
+ simple_holder_constructed = false;
+ simple_instance_registered = false;
+ }
+ else { // multiple base types or a too-large holder
+ // Allocate space to hold: [v1*][h1][v2*][h2]...[bb...] where [vN*] is a value pointer,
+ // [hN] is the (uninitialized) holder instance for value N, and [bb...] is a set of bool
+ // values that tracks whether each associated holder has been initialized. Each [block] is
+ // padded, if necessary, to an integer multiple of sizeof(void *).
+ size_t space = 0;
+ for (auto t : tinfo) {
+ space += 1; // value pointer
+ space += t->holder_size_in_ptrs; // holder instance
+ }
+ size_t flags_at = space;
+ space += size_in_ptrs(n_types); // status bytes (holder_constructed and instance_registered)
+
+ // Allocate space for flags, values, and holders, and initialize it to 0 (flags and values,
+ // in particular, need to be 0). Use Python's memory allocation functions: in Python 3.6
+ // they default to using pymalloc, which is designed to be efficient for small allocations
+ // like the one we're doing here; in earlier versions (and for larger allocations) they are
+ // just wrappers around malloc.
+#if PY_VERSION_HEX >= 0x03050000
+ nonsimple.values_and_holders = (void **) PyMem_Calloc(space, sizeof(void *));
+ if (!nonsimple.values_and_holders) throw std::bad_alloc();
+#else
+ nonsimple.values_and_holders = (void **) PyMem_New(void *, space);
+ if (!nonsimple.values_and_holders) throw std::bad_alloc();
+ std::memset(nonsimple.values_and_holders, 0, space * sizeof(void *));
+#endif
+ nonsimple.status = reinterpret_cast<uint8_t *>(&nonsimple.values_and_holders[flags_at]);
+ }
+ owned = true;
+}
+
+PYBIND11_NOINLINE inline void instance::deallocate_layout() {
+ if (!simple_layout)
+ PyMem_Free(nonsimple.values_and_holders);
+}
+
+PYBIND11_NOINLINE inline bool isinstance_generic(handle obj, const std::type_info &tp) {
+ handle type = detail::get_type_handle(tp, false);
+ if (!type)
+ return false;
+ return isinstance(obj, type);
+}
+
+PYBIND11_NOINLINE inline std::string error_string() {
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_RuntimeError, "Unknown internal error occurred");
+ return "Unknown internal error occurred";
+ }
+
+ error_scope scope; // Preserve error state
+
+ std::string errorString;
+ if (scope.type) {
+ errorString += handle(scope.type).attr("__name__").cast<std::string>();
+ errorString += ": ";
+ }
+ if (scope.value)
+ errorString += (std::string) str(scope.value);
+
+ PyErr_NormalizeException(&scope.type, &scope.value, &scope.trace);
+
+#if PY_MAJOR_VERSION >= 3
+ if (scope.trace != nullptr)
+ PyException_SetTraceback(scope.value, scope.trace);
+#endif
+
+#if !defined(PYPY_VERSION)
+ if (scope.trace) {
+ PyTracebackObject *trace = (PyTracebackObject *) scope.trace;
+
+ /* Get the deepest trace possible */
+ while (trace->tb_next)
+ trace = trace->tb_next;
+
+ PyFrameObject *frame = trace->tb_frame;
+ errorString += "\n\nAt:\n";
+ while (frame) {
+ int lineno = PyFrame_GetLineNumber(frame);
+ errorString +=
+ " " + handle(frame->f_code->co_filename).cast() +
+ "(" + std::to_string(lineno) + "): " +
+ handle(frame->f_code->co_name).cast<std::string>() + "\n";
+ frame = frame->f_back;
+ }
+ }
+#endif
+
+ return errorString;
+}
+
+PYBIND11_NOINLINE inline handle get_object_handle(const void *ptr, const detail::type_info *type ) {
+ auto &instances = get_internals().registered_instances;
+ auto range = instances.equal_range(ptr);
+ for (auto it = range.first; it != range.second; ++it) {
+ for (auto vh : values_and_holders(it->second)) {
+ if (vh.type == type)
+ return handle((PyObject *) it->second);
+ }
+ }
+ return handle();
+}
+
+inline PyThreadState *get_thread_state_unchecked() {
+#if defined(PYPY_VERSION)
+ return PyThreadState_GET();
+#elif PY_VERSION_HEX < 0x03000000
+ return _PyThreadState_Current;
+#elif PY_VERSION_HEX < 0x03050000
+ return (PyThreadState*) _Py_atomic_load_relaxed(&_PyThreadState_Current);
+#elif PY_VERSION_HEX < 0x03050200
+ return (PyThreadState*) _PyThreadState_Current.value;
+#else
+ return _PyThreadState_UncheckedGet();
+#endif
+}
+
+// Forward declarations
+inline void keep_alive_impl(handle nurse, handle patient);
+inline PyObject *make_new_instance(PyTypeObject *type);
+
+class type_caster_generic {
+public:
+ PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info)
+ : typeinfo(get_type_info(type_info)), cpptype(&type_info) { }
+
+ type_caster_generic(const type_info *typeinfo)
+ : typeinfo(typeinfo), cpptype(typeinfo ? typeinfo->cpptype : nullptr) { }
+
+ bool load(handle src, bool convert) {
+ return load_impl<type_caster_generic>(src, convert);
+ }
+
+ PYBIND11_NOINLINE static handle cast(const void *_src, return_value_policy policy, handle parent,
+ const detail::type_info *tinfo,
+ void *(*copy_constructor)(const void *),
+ void *(*move_constructor)(const void *),
+ const void *existing_holder = nullptr) {
+ if (!tinfo) // no type info: error will be set already
+ return handle();
+
+ void *src = const_cast<void *>(_src);
+ if (src == nullptr)
+ return none().release();
+
+ auto it_instances = get_internals().registered_instances.equal_range(src);
+ for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) {
+ for (auto instance_type : detail::all_type_info(Py_TYPE(it_i->second))) {
+ if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype))
+ return handle((PyObject *) it_i->second).inc_ref();
+ }
+ }
+
+ auto inst = reinterpret_steal