diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 00000000..a7ee6382
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,3 @@
+[run]
+omit =
+    */__init__.py
diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md
new file mode 100644
index 00000000..efd43057
--- /dev/null
+++ b/.github/CODE_OF_CONDUCT.md
@@ -0,0 +1,76 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to making participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+  advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+  address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies both within project spaces and in public spaces
+when an individual is representing the project or its community. Examples of
+representing a project or community include using an official project e-mail
+address, posting via an official social media account, or acting as an appointed
+representative at an online or offline event. Representation of a project may be
+further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at chenkaidev@gmail.com. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
new file mode 100644
index 00000000..4a6fe184
--- /dev/null
+++ b/.github/CONTRIBUTING.md
@@ -0,0 +1 @@
+We appreciate all contributions to improve MMOCR. Please refer to [CONTRIBUTING.md](/docs/contributing.md) in MMCV for more details about the contributing guidelines.
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
new file mode 100644
index 00000000..f9b07021
--- /dev/null
+++ b/.github/workflows/build.yml
@@ -0,0 +1,68 @@
+name: build
+
+on:
+  push:
+    branches:
+      - master
+
+  pull_request:
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python 3.7
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.7
+      - name: Install pre-commit hook
+        run: |
+          pip install pre-commit
+          pre-commit install
+      - name: Linting
+        run: pre-commit run --all-files
+      - name: Check docstring coverage
+        run: |
+          pip install interrogate
+          interrogate -v --ignore-init-method --ignore-module --ignore-nested-functions --ignore-regex "__repr__" --fail-under 50 mmocr
+
+  build_cpu:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.7]
+        torch: [1.5.0, 1.6.0, 1.7.0]
+        include:
+          - torch: 1.5.0
+            torchvision: 0.6.0
+          - torch: 1.6.0
+            torchvision: 0.7.0
+          - torch: 1.7.0
+            torchvision: 0.8.1
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Upgrade pip
+        run: pip install pip --upgrade
+      - name: Install Pillow
+        run: pip install Pillow==6.2.2
+        if: ${{matrix.torchvision == '0.4.1'}}
+      - name: Install PyTorch
+        run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
+      - name: Install MMCV
+        run: pip install mmcv-full==1.3.0 -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html
+      - name: Install MMDet
+        run: pip install git+https://github.com/open-mmlab/mmdetection/
+      - name: Install other dependencies
+        run: pip install -r requirements.txt
+      - name: Build and install
+        run: rm -rf .eggs && pip install -e .
+      - name: Run unittests and generate coverage report
+        run: |
+          coverage run --branch --source mmocr -m pytest tests/
+          coverage xml
+          coverage report -m
diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml
new file mode 100644
index 00000000..af71b058
--- /dev/null
+++ b/.github/workflows/publish-to-pypi.yml
@@ -0,0 +1,20 @@
+name: deploy
+
+on: push
+
+jobs:
+  build-n-publish:
+    runs-on: ubuntu-latest
+    if: startsWith(github.event.ref, 'refs/tags')
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python 3.7
+        uses: actions/setup-python@v1
+        with:
+          python-version: 3.7
+      - name: Build MMOCR
+        run: python setup.py sdist
+      - name: Publish distribution to PyPI
+        run: |
+          pip install twine
+          twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..52720295
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,138 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+*.ipynb
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# cython generated cpp
+!data/dict
+data/*
+.vscode
+.idea
+
+# custom
+*.pkl
+*.pkl.json
+*.log.json
+work_dirs/
+exps/
+*~
+show_dir/
+
+# Pytorch
+*.pth
+
+# demo
+!tests/data
+tests/results
+
+#temp files
+.DS_Store
+
+checkpoints
+
+htmlcov
+*.swp
+log.txt
+workspace.code-workspace
+results
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..38324d51
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,36 @@
+exclude: ^tests/data/
+repos:
+  - repo: https://gitlab.com/pycqa/flake8
+    rev: 3.8.1
+    hooks:
+      - id: flake8
+  - repo: https://github.com/asottile/seed-isort-config
+    rev: v2.2.0
+    hooks:
+      - id: seed-isort-config
+  - repo: https://github.com/timothycrosley/isort
+    rev: 4.3.21
+    hooks:
+      - id: isort
+  - repo: https://github.com/pre-commit/mirrors-yapf
+    rev: v0.30.0
+    hooks:
+      - id: yapf
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.1.0
+    hooks:
+      - id: trailing-whitespace
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: requirements-txt-fixer
+      - id: double-quote-string-fixer
+      - id: check-merge-conflict
+      - id: fix-encoding-pragma
+        args: ["--remove"]
+      - id: mixed-line-ending
+        args: ["--fix=lf"]
+  - repo: https://github.com/myint/docformatter
+    rev: v1.3.1
+    hooks:
+      - id: docformatter
+        args: ["--in-place", "--wrap-descriptions", "79"]
diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 00000000..d7a39be8
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,621 @@
+[MASTER]
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code.
+extension-pkg-whitelist=
+
+# Specify a score threshold to be exceeded before program exits with error.
+fail-under=10.0
+
+# Add files or directories to the blacklist. They should be base names, not
+# paths.
+ignore=CVS,configs
+
+# Add files or directories matching the regex patterns to the blacklist. The
+# regex matches against base names, not paths.
+ignore-patterns=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
+# number of processors available to use.
+jobs=1
+
+# Control the amount of potential inferred values when inferring a single
+# object. This can help the performance when dealing with large functions or
+# complex, nested conditions.
+limit-inference-results=100
+
+# List of plugins (as comma separated values of python module names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages.
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
+confidence=
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once). You can also use "--disable=all" to
+# disable everything first and then reenable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use "--disable=all --enable=classes
+# --disable=W".
+disable=print-statement,
+        parameter-unpacking,
+        unpacking-in-except,
+        old-raise-syntax,
+        backtick,
+        long-suffix,
+        old-ne-operator,
+        old-octal-literal,
+        import-star-module-level,
+        non-ascii-bytes-literal,
+        raw-checker-failed,
+        bad-inline-option,
+        locally-disabled,
+        file-ignored,
+        suppressed-message,
+        useless-suppression,
+        deprecated-pragma,
+        use-symbolic-message-instead,
+        apply-builtin,
+        basestring-builtin,
+        buffer-builtin,
+        cmp-builtin,
+        coerce-builtin,
+        execfile-builtin,
+        file-builtin,
+        long-builtin,
+        raw_input-builtin,
+        reduce-builtin,
+        standarderror-builtin,
+        unicode-builtin,
+        xrange-builtin,
+        coerce-method,
+        delslice-method,
+        getslice-method,
+        setslice-method,
+        no-absolute-import,
+        old-division,
+        dict-iter-method,
+        dict-view-method,
+        next-method-called,
+        metaclass-assignment,
+        indexing-exception,
+        raising-string,
+        reload-builtin,
+        oct-method,
+        hex-method,
+        nonzero-method,
+        cmp-method,
+        input-builtin,
+        round-builtin,
+        intern-builtin,
+        unichr-builtin,
+        map-builtin-not-iterating,
+        zip-builtin-not-iterating,
+        range-builtin-not-iterating,
+        filter-builtin-not-iterating,
+        using-cmp-argument,
+        eq-without-hash,
+        div-method,
+        idiv-method,
+        rdiv-method,
+        exception-message-attribute,
+        invalid-str-codec,
+        sys-max-int,
+        bad-python3-import,
+        deprecated-string-function,
+        deprecated-str-translate-call,
+        deprecated-itertools-function,
+        deprecated-types-field,
+        next-method-defined,
+        dict-items-not-iterating,
+        dict-keys-not-iterating,
+        dict-values-not-iterating,
+        deprecated-operator-function,
+        deprecated-urllib-function,
+        xreadlines-attribute,
+        deprecated-sys-function,
+        exception-escape,
+        comprehension-escape,
+        no-member,
+        invalid-name,
+        too-many-branches,
+        wrong-import-order,
+        too-many-arguments,
+        missing-function-docstring,
+        missing-module-docstring,
+        too-many-locals,
+        too-few-public-methods,
+        abstract-method,
+        broad-except,
+        too-many-nested-blocks,
+        too-many-instance-attributes,
+        missing-class-docstring,
+        duplicate-code,
+        not-callable,
+        protected-access,
+        dangerous-default-value,
+        no-name-in-module,
+        logging-fstring-interpolation,
+        super-init-not-called,
+        redefined-builtin,
+        attribute-defined-outside-init,
+        arguments-differ,
+        cyclic-import,
+        bad-super-call,
+        too-many-statements
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=c-extension-no-member
+
+
+[REPORTS]
+
+# Python expression which should return a score less than or equal to 10. You
+# have access to the variables 'error', 'warning', 'refactor', and 'convention'
+# which contain the number of messages in each category, as well as 'statement'
+# which is the total number of statements analyzed. This score is used by the
+# global evaluation report (RP0004).
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details.
+#msg-template=
+
+# Set the output format. Available formats are text, parseable, colorized, json
+# and msvs (visual studio). You can also give a reporter class, e.g.
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages.
+reports=no
+
+# Activate the evaluation score.
+score=yes
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never returns. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=sys.exit
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=
+
+# Tells whether missing members accessed in mixin class should be ignored. A
+# mixin class is detected if its name ends with "mixin" (case insensitive).
+ignore-mixin-members=yes
+
+# Tells whether to warn about missing members when the owner of the attribute
+# is inferred to be None.
+ignore-none=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis). It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+# List of decorators that change the signature of a decorated function.
+signature-mutators=
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes.
+max-spelling-suggestions=4
+
+# Spelling dictionary name. Available dictionaries: none. To make it work,
+# install the python-enchant package.
+spelling-dict=
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains the private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to the private dictionary (see the
+# --spelling-private-dict-file option) instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[LOGGING]
+
+# The type of string formatting that logging methods do. `old` means using %
+# formatting, `new` is for `{}` formatting.
+logging-format-style=old
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format.
+logging-modules=logging
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid defining new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+          _cb
+
+# A regular expression matching the name of dummy variables (i.e. expected to
+# not be used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored. Default to name
+# with leading underscore.
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
+# tab).
+indent-string='    '
+
+# Maximum number of characters on a single line.
+max-line-length=100
+
+# Maximum number of lines in a module.
+max-module-lines=1000
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[STRING]
+
+# This flag controls whether inconsistent-quotes generates a warning when the
+# character used as a quote delimiter is used inconsistently within a module.
+check-quote-consistency=no
+
+# This flag controls whether the implicit-str-concat should generate a warning
+# on implicit string concatenation in sequences defined over several lines.
+check-str-concat-over-line-jumps=no
+
+
+[SIMILARITIES]
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+      XXX,
+      TODO
+
+# Regular expression of note tags to take in consideration.
+#notes-rgx=
+
+
+[BASIC]
+
+# Naming style matching correct argument names.
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style.
+#argument-rgx=
+
+# Naming style matching correct attribute names.
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style.
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma.
+bad-names=foo,
+          bar,
+          baz,
+          toto,
+          tutu,
+          tata
+
+# Bad variable names regexes, separated by a comma. If names match any regex,
+# they will always be refused
+bad-names-rgxs=
+
+# Naming style matching correct class attribute names.
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style.
+#class-attribute-rgx=
+
+# Naming style matching correct class names.
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-
+# style.
+#class-rgx=
+
+# Naming style matching correct constant names.
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style.
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names.
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style.
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma.
+good-names=i,
+           j,
+           k,
+           ex,
+           Run,
+           _,
+           x,
+           y,
+           w,
+           h,
+           a,
+           b
+
+# Good variable names regexes, separated by a comma. If names match any regex,
+# they will always be accepted
+good-names-rgxs=
+
+# Include a hint for the correct naming format with invalid-name.
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names.
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style.
+#inlinevar-rgx=
+
+# Naming style matching correct method names.
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style.
+#method-rgx=
+
+# Naming style matching correct module names.
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style.
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+# These decorators are taken in consideration only for invalid-name.
+property-classes=abc.abstractproperty
+
+# Naming style matching correct variable names.
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style.
+#variable-rgx=
+
+
+[DESIGN]
+
+# Maximum number of arguments for function / method.
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in an if statement (see R0916).
+max-bool-expr=5
+
+# Maximum number of branch for function / method body.
+max-branches=12
+
+# Maximum number of locals for function / method body.
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body.
+max-returns=6
+
+# Maximum number of statements in function / method body.
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[IMPORTS]
+
+# List of modules that can be imported at any level, not just the top level
+# one.
+allow-any-import-level=
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Deprecated modules which should not be used, separated by a comma.
+deprecated-modules=optparse,tkinter.tix
+
+# Create a graph of external dependencies in the given file (report RP0402 must
+# not be disabled).
+ext-import-graph=
+
+# Create a graph of every (i.e. internal and external) dependencies in the
+# given file (report RP0402 must not be disabled).
+import-graph=
+
+# Create a graph of internal dependencies in the given file (report RP0402 must
+# not be disabled).
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+# Couples of modules and preferred modules, separated by a comma.
+preferred-modules=
+
+
+[CLASSES]
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+                      __new__,
+                      setUp,
+                      __post_init__
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+                  _fields,
+                  _replace,
+                  _source,
+                  _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=cls
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when being caught. Defaults to
+# "BaseException, Exception".
+overgeneral-exceptions=BaseException,
+                       Exception
diff --git a/.readthedocs.yml b/.readthedocs.yml
new file mode 100644
index 00000000..73ea4cb7
--- /dev/null
+++ b/.readthedocs.yml
@@ -0,0 +1,7 @@
+version: 2
+
+python:
+  version: 3.7
+  install:
+    - requirements: requirements/docs.txt
+    - requirements: requirements/readthedocs.txt
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 00000000..3076a437
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,203 @@
+Copyright (c) MMOCR Authors. All rights reserved.
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted.
+      If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty.
+      Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2021 MMOCR Authors. All rights reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/README.md b/README.md
index c553ca6a..96f13097 100644
--- a/README.md
+++ b/README.md
@@ -1,17 +1,17 @@
## Introduction

-[![build](https://github.com/open-mmlab/mmediting/workflows/build/badge.svg)](https://github.com/open-mmlab/mmediting/actions)
-[![docs](https://readthedocs.org/projects/mmediting/badge/?version=latest)](https://mmediting.readthedocs.io/en/latest/?badge=latest)
-[![codecov](https://codecov.io/gh/open-mmlab/mmediting/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmediting)
-[![license](https://img.shields.io/github/license/open-mmlab/mmediting.svg)](https://github.com/open-mmlab/mmediting/blob/master/LICENSE)
+[![build](https://github.com/open-mmlab/mmocr/workflows/build/badge.svg)](https://github.com/open-mmlab/mmocr/actions)
+[![docs](https://readthedocs.org/projects/mmocr/badge/?version=latest)](https://mmocr.readthedocs.io/en/latest/?badge=latest)
+[![codecov](https://codecov.io/gh/open-mmlab/mmocr/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmocr)
+[![license](https://img.shields.io/github/license/open-mmlab/mmocr.svg)](https://github.com/open-mmlab/mmocr/blob/master/LICENSE)

MMOCR is an open-source toolbox based on PyTorch and mmdetection for text detection, text recognition, and the corresponding downstream tasks including key information extraction. It is part of the open-mmlab project developed by [Multimedia Laboratory, CUHK](http://mmlab.ie.cuhk.edu.hk/).

-The master branch works with **PyTorch 1.5**.
+The master branch works with **PyTorch 1.5+**.

Documentation: https://mmocr.readthedocs.io/en/latest/.

@@ -31,7 +31,7 @@ Documentation: https://mmocr.readthedocs.io/en/latest/.

- **Modular Design**

-  The modular design of MMOCR enables users to define their own optimizers, data preprocessors, and model components such as backbones, necks and heads as well as losses. Please refer to [GETTING_STARTED.md](docs/GETTING_STARTED.md) for how to construct a customized model.
+  The modular design of MMOCR enables users to define their own optimizers, data preprocessors, and model components such as backbones, necks and heads as well as losses. Please refer to [getting_started.md](docs/getting_started.md) for how to construct a customized model.

- **Numerous Utilities**

@@ -43,24 +43,24 @@ This project is released under the [Apache 2.0 license](LICENSE).

## Changelog

-v1.0 was released on 31/03/2021.
+v1.0 was released on 07/04/2021.

## Benchmark and Model Zoo

-Please refer to [MODEL_ZOO.md](MODEL_ZOO.md) for more details.
+Please refer to [modelzoo.md](modelzoo.md) for more details.

## Installation

-Please refer to [INSTALL.md](docs/INSTALL.md) for installation.
+Please refer to [install.md](docs/install.md) for installation.

## Get Started

-Please see [GETTING_STARTED.md](docs/GETTING_STARTED.md) for the basic usage of MMOCR.
+Please see [getting_started.md](docs/getting_started.md) for the basic usage of MMOCR.

## Contributing

-We appreciate all contributions to improve MMOCR. Please refer to [CONTRIBUTING.md](docs/CONTRIBUTING.md) for the contributing guidelines.
+We appreciate all contributions to improve MMOCR. Please refer to [contributing.md](docs/contributing.md) for the contributing guidelines.
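The modular design mentioned in the README shows up directly in the `configs/_base_` files added by this diff: a complete config is assembled by inheritance rather than written from scratch. The snippet below is a minimal sketch of how a downstream config might compose them with mmcv's `_base_` mechanism; the file path and the override values are hypothetical, not part of this PR.

```python
# configs/textdet/example/mask_rcnn_toy.py -- hypothetical file, for illustration only.
# mmcv.Config.fromfile() loads each file listed in _base_ and merges them in
# order; keys declared in this file override the inherited values, and dicts
# are merged recursively, so only the fields being changed need to be repeated.
_base_ = [
    '../../_base_/models/ocr_mask_rcnn_r50_fpn_ohem.py',
    '../../_base_/det_dataset/toy_dataset.py',
    '../../_base_/schedules/schedule_1200e.py',
    '../../_base_/default_runtime.py',
]

optimizer = dict(lr=0.001)      # schedule_1200e.py sets lr=0.007
data = dict(samples_per_gpu=4)  # toy_dataset.py sets samples_per_gpu=2
```

Loading such a file with `mmcv.Config.fromfile` yields the fully merged configuration that the training and testing tools consume.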
## Acknowledgement
diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py
new file mode 100644
index 00000000..949e800b
--- /dev/null
+++ b/configs/_base_/default_runtime.py
@@ -0,0 +1,14 @@
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+    interval=5,
+    hooks=[
+        dict(type='TextLoggerHook')
+
+    ])
+# yapf:enable
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/_base_/det_dataset/toy_dataset.py b/configs/_base_/det_dataset/toy_dataset.py
new file mode 100644
index 00000000..0b87d55f
--- /dev/null
+++ b/configs/_base_/det_dataset/toy_dataset.py
@@ -0,0 +1,97 @@
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+train_cfg = None
+test_cfg = None
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='LoadTextAnnotations',
+        with_bbox=True,
+        with_mask=True,
+        poly2mask=False),
+    dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(
+        type='ScaleAspectJitter',
+        img_scale=[(3000, 640)],
+        ratio_range=(0.7, 1.3),
+        aspect_ratio_range=(0.9, 1.1),
+        multiscale_mode='value',
+        keep_ratio=False),
+    # shrink_ratio values go from large to small; the first must be 1.0
+    dict(type='PANetTargets', shrink_ratio=(1.0, 0.7)),
+    dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
+    dict(type='RandomRotateTextDet'),
+    dict(
+        type='RandomCropInstances',
+        target_size=(640, 640),
+        instance_key='gt_kernels'),
+    dict(type='Pad', size_divisor=32),
+    dict(
+        type='CustomFormatBundle',
+        keys=['gt_kernels', 'gt_mask'],
+        visualize=dict(flag=False, boundary_key='gt_kernels')),
+    dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask'])
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(3000, 640),
+        flip=False,
+        transforms=[
+            dict(type='Resize', img_scale=(3000, 640), keep_ratio=True),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+
+dataset_type = 'TextDetDataset'
+img_prefix = 'tests/data/toy_dataset/imgs'
+train_anno_file = 'tests/data/toy_dataset/instances_test.txt'
+train1 = dict(
+    type=dataset_type,
+    img_prefix=img_prefix,
+    ann_file=train_anno_file,
+    loader=dict(
+        type='HardDiskLoader',
+        repeat=4,
+        parser=dict(
+            type='LineJsonParser',
+            keys=['file_name', 'height', 'width', 'annotations'])),
+    pipeline=train_pipeline,
+    test_mode=False)
+
+data_root = 'tests/data/toy_dataset'
+train2 = dict(
+    type='IcdarDataset',
+    ann_file=data_root + '/instances_test.json',
+    img_prefix=data_root + '/imgs',
+    pipeline=train_pipeline)
+
+test_anno_file = 'tests/data/toy_dataset/instances_test.txt'
+test = dict(
+    type=dataset_type,
+    img_prefix=img_prefix,
+    ann_file=test_anno_file,
+    loader=dict(
+        type='HardDiskLoader',
+        repeat=1,
+        parser=dict(
+            type='LineJsonParser',
+            keys=['file_name', 'height', 'width', 'annotations'])),
+    pipeline=test_pipeline,
+    test_mode=True)
+
+data = dict(
+    samples_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(type='ConcatDataset', datasets=[train1, train2]),
+    val=dict(type='ConcatDataset', datasets=[test]),
+    test=dict(type='ConcatDataset', datasets=[test]))
+
+evaluation = dict(interval=1, metric='hmean-iou')
diff --git a/configs/_base_/models/ocr_mask_rcnn_r50_fpn_ohem.py b/configs/_base_/models/ocr_mask_rcnn_r50_fpn_ohem.py
new file mode 100644
index 00000000..348778d8
--- /dev/null
+++ b/configs/_base_/models/ocr_mask_rcnn_r50_fpn_ohem.py
@@ -0,0 +1,126 @@
+# model settings
+model = dict(
+    type='OCRMaskRCNN',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        norm_eval=True,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_generator=dict(
+            type='AnchorGenerator',
+            scales=[4],
+            ratios=[0.17, 0.44, 1.13, 2.90, 7.46],
+            strides=[4, 8, 16, 32, 64]),
+        bbox_coder=dict(
+            type='DeltaXYWHBBoxCoder',
+            target_means=[.0, .0, .0, .0],
+            target_stds=[1.0, 1.0, 1.0, 1.0]),
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+    roi_head=dict(
+        type='StandardRoIHead',
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        bbox_head=dict(
+            type='Shared2FCBBoxHead',
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=1,
+            bbox_coder=dict(
+                type='DeltaXYWHBBoxCoder',
+                target_means=[0., 0., 0., 0.],
+                target_stds=[0.1, 0.1, 0.2, 0.2]),
+            reg_class_agnostic=False,
+            loss_cls=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+        mask_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        mask_head=dict(
+            type='FCNMaskHead',
+            num_convs=4,
+            in_channels=256,
+            conv_out_channels=256,
+            num_classes=1,
+            loss_mask=dict(
+                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
+
+    # model training and testing settings
+    train_cfg=dict(
+        rpn=dict(
+            assigner=dict(
+                type='MaxIoUAssigner',
+                pos_iou_thr=0.7,
+                neg_iou_thr=0.3,
+                min_pos_iou=0.3,
+                match_low_quality=True,
+                ignore_iof_thr=-1,
+                gpu_assign_thr=50),
+            sampler=dict(
+                type='RandomSampler',
+                num=256,
+                pos_fraction=0.5,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=False),
+            allowed_border=-1,
+            pos_weight=-1,
+            debug=False),
+        rpn_proposal=dict(
+            nms_across_levels=False,
+            nms_pre=2000,
+            nms_post=1000,
+            max_num=1000,
+            nms_thr=0.7,
+            min_bbox_size=0),
+        rcnn=dict(
+            assigner=dict(
+                type='MaxIoUAssigner',
+                pos_iou_thr=0.5,
+                neg_iou_thr=0.5,
+                min_pos_iou=0.5,
+                match_low_quality=True,
+                ignore_iof_thr=-1),
+            sampler=dict(
+                type='OHEMSampler',
+                num=512,
+                pos_fraction=0.25,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=True),
+            mask_size=28,
+            pos_weight=-1,
+            debug=False)),
+    test_cfg=dict(
+        rpn=dict(
+            nms_across_levels=False,
+            nms_pre=1000,
+            nms_post=1000,
+            max_num=1000,
+            nms_thr=0.7,
+            min_bbox_size=0),
+        rcnn=dict(
+            score_thr=0.05,
+            nms=dict(type='nms', iou_threshold=0.5),
+            max_per_img=100,
+            mask_thr_binary=0.5)))
diff --git a/configs/_base_/models/ocr_mask_rcnn_r50_fpn_ohem_poly.py b/configs/_base_/models/ocr_mask_rcnn_r50_fpn_ohem_poly.py
new file mode 100644
index 00000000..3ef9a55e
--- /dev/null
+++ b/configs/_base_/models/ocr_mask_rcnn_r50_fpn_ohem_poly.py
@@ -0,0 +1,126 @@
+# model settings
+model = dict(
+    type='OCRMaskRCNN',
+    text_repr_type='poly',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        norm_eval=True,
+        style='pytorch'),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=5),
+    rpn_head=dict(
+        type='RPNHead',
+        in_channels=256,
+        feat_channels=256,
+        anchor_generator=dict(
+            type='AnchorGenerator',
+            scales=[4],
+            ratios=[0.17, 0.44, 1.13, 2.90, 7.46],
+            strides=[4, 8, 16, 32, 64]),
+        bbox_coder=dict(
+            type='DeltaXYWHBBoxCoder',
+            target_means=[.0, .0, .0, .0],
+            target_stds=[1.0, 1.0, 1.0, 1.0]),
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+    roi_head=dict(
+        type='StandardRoIHead',
+        bbox_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', output_size=7, sample_num=0),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        bbox_head=dict(
+            type='Shared2FCBBoxHead',
+            in_channels=256,
+            fc_out_channels=1024,
+            roi_feat_size=7,
+            num_classes=80,
+            bbox_coder=dict(
+                type='DeltaXYWHBBoxCoder',
+                target_means=[0., 0., 0., 0.],
+                target_stds=[0.1, 0.1, 0.2, 0.2]),
+            reg_class_agnostic=False,
+            loss_cls=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+        mask_roi_extractor=dict(
+            type='SingleRoIExtractor',
+            roi_layer=dict(type='RoIAlign', output_size=14, sample_num=0),
+            out_channels=256,
+            featmap_strides=[4, 8, 16, 32]),
+        mask_head=dict(
+            type='FCNMaskHead',
+            num_convs=4,
+            in_channels=256,
+            conv_out_channels=256,
+            num_classes=80,
+            loss_mask=dict(
+                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
+    # model training and testing settings
+    train_cfg=dict(
+        rpn=dict(
+            assigner=dict(
+                type='MaxIoUAssigner',
+                pos_iou_thr=0.7,
+                neg_iou_thr=0.3,
+                min_pos_iou=0.3,
+                match_low_quality=True,
+                ignore_iof_thr=-1),
+            sampler=dict(
+                type='RandomSampler',
+                num=256,
+                pos_fraction=0.5,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=False),
+            allowed_border=-1,
+            pos_weight=-1,
+            debug=False),
+        rpn_proposal=dict(
+            nms_across_levels=False,
+            nms_pre=2000,
+            nms_post=1000,
+            max_num=1000,
+            nms_thr=0.7,
+            min_bbox_size=0),
+        rcnn=dict(
+            assigner=dict(
+                type='MaxIoUAssigner',
+                pos_iou_thr=0.5,
+                neg_iou_thr=0.5,
+                min_pos_iou=0.5,
+                match_low_quality=True,
+                ignore_iof_thr=-1,
+                gpu_assign_thr=50),
+            sampler=dict(
+                type='OHEMSampler',
+                num=512,
+                pos_fraction=0.25,
+                neg_pos_ub=-1,
+                add_gt_as_proposals=True),
+            mask_size=28,
+            pos_weight=-1,
+            debug=False)),
+    test_cfg=dict(
+        rpn=dict(
+            nms_across_levels=False,
+            nms_pre=1000,
+            nms_post=1000,
+            max_num=1000,
+            nms_thr=0.7,
+            min_bbox_size=0),
+        rcnn=dict(
+            score_thr=0.05,
+            nms=dict(type='nms', iou_threshold=0.5),
+            max_per_img=100,
+            mask_thr_binary=0.5)))
diff --git a/configs/_base_/recog_datasets/seg_toy_dataset.py b/configs/_base_/recog_datasets/seg_toy_dataset.py
new file mode 100644
index 00000000..d6c49dbf
--- /dev/null
+++ b/configs/_base_/recog_datasets/seg_toy_dataset.py
@@ -0,0 +1,96 @@
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+gt_label_convertor = dict(
+    type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomPaddingOCR',
+        max_ratio=[0.15, 0.2, 0.15, 0.2],
+        box_type='char_quads'),
+    dict(type='OpencvToPil'),
+    dict(
+        type='RandomRotateImageBox',
+        min_angle=-17,
+        max_angle=17,
+        box_type='char_quads'),
+    dict(type='PilToOpencv'),
+    dict(
+        type='ResizeOCR',
+        height=64,
+        min_width=64,
+        max_width=512,
+        keep_aspect_ratio=True),
+    dict(
+        type='OCRSegTargets',
+        label_convertor=gt_label_convertor,
+        box_type='char_quads'),
+    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+    dict(type='ToTensorOCR'),
+    dict(type='FancyPCA'),
+    dict(type='NormalizeOCR', **img_norm_cfg),
+    dict(
+        type='CustomFormatBundle',
+        keys=['gt_kernels'],
+        visualize=dict(flag=False, boundary_key=None),
+        call_super=False),
+    dict(
+        type='Collect',
+        keys=['img', 'gt_kernels'],
+        meta_keys=['filename', 'ori_shape', 'img_shape'])
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeOCR',
+        height=64,
+        min_width=64,
+        max_width=None,
+        keep_aspect_ratio=True),
+    dict(type='ToTensorOCR'),
+    dict(type='NormalizeOCR', **img_norm_cfg),
+    dict(type='CustomFormatBundle', call_super=False),
+    dict(
+        type='Collect',
+        keys=['img'],
+        meta_keys=['filename', 'ori_shape', 'img_shape'])
+]
+
+prefix = 'tests/data/ocr_char_ann_toy_dataset/'
+train = dict(
+    type='OCRSegDataset',
+    img_prefix=prefix + 'imgs',
+    ann_file=prefix + 'instances_train.txt',
+    loader=dict(
+        type='HardDiskLoader',
+        repeat=100,
+        parser=dict(
+            type='LineJsonParser', keys=['file_name', 'annotations', 'text'])),
+    pipeline=train_pipeline,
+    test_mode=True)
+
+test = dict(
+    type='OCRDataset',
+    img_prefix=prefix + 'imgs',
+    ann_file=prefix + 'instances_test.txt',
+    loader=dict(
+        type='HardDiskLoader',
+        repeat=1,
+        parser=dict(
+            type='LineStrParser',
+            keys=['filename', 'text'],
+            keys_idx=[0, 1],
+            separator=' ')),
+    pipeline=test_pipeline,
+    test_mode=True)
+
+data = dict(
+    samples_per_gpu=8,
+    workers_per_gpu=1,
+    train=dict(type='ConcatDataset', datasets=[train]),
+    val=dict(type='ConcatDataset', datasets=[test]),
+    test=dict(type='ConcatDataset', datasets=[test]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/_base_/recog_datasets/toy_dataset.py b/configs/_base_/recog_datasets/toy_dataset.py
new file mode 100755
index 00000000..83848863
--- /dev/null
+++ b/configs/_base_/recog_datasets/toy_dataset.py
@@ -0,0 +1,99 @@
+img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeOCR',
+        height=32,
+        min_width=32,
+        max_width=160,
+        keep_aspect_ratio=True),
+    dict(type='ToTensorOCR'),
+    dict(type='NormalizeOCR', **img_norm_cfg),
+    dict(
+        type='Collect',
+        keys=['img'],
+        meta_keys=[
+            'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
+        ]),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiRotateAugOCR',
+        rotate_degrees=[0, 90, 270],
+        transforms=[
+            dict(
+                type='ResizeOCR',
+                height=32,
+                min_width=32,
+                max_width=160,
+                keep_aspect_ratio=True),
+            dict(type='ToTensorOCR'),
+            dict(type='NormalizeOCR', **img_norm_cfg),
+            dict(
+                type='Collect',
+                keys=['img'],
+                meta_keys=[
+                    'filename', 'ori_shape', 'img_shape', 'valid_ratio'
+                ]),
+        ])
+]
+
+dataset_type = 'OCRDataset'
+img_prefix = 'tests/data/ocr_toy_dataset/imgs'
+train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt'
+train1 = dict(
+    type=dataset_type,
+    img_prefix=img_prefix,
+    ann_file=train_anno_file1,
+    loader=dict(
+        type='HardDiskLoader',
+        repeat=100,
+        parser=dict(
+            type='LineStrParser',
+            keys=['filename', 'text'],
+            keys_idx=[0, 1],
+            separator=' ')),
+    pipeline=train_pipeline,
+    test_mode=False)
+
+train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
+train2 = dict(
+    type=dataset_type,
+    img_prefix=img_prefix,
+    ann_file=train_anno_file2,
+    loader=dict(
+        type='LmdbLoader',
+        repeat=100,
+        parser=dict(
+            type='LineStrParser',
+            keys=['filename', 'text'],
+            keys_idx=[0, 1],
+            separator=' ')),
+    pipeline=train_pipeline,
+    test_mode=False)
+
+test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
+test = dict(
+    type=dataset_type,
+    img_prefix=img_prefix,
+    ann_file=test_anno_file1,
+    loader=dict(
+        type='LmdbLoader',
+        repeat=1,
+        parser=dict(
+            type='LineStrParser',
+            keys=['filename', 'text'],
+            keys_idx=[0, 1],
+            separator=' ')),
+    pipeline=test_pipeline,
+    test_mode=True)
+
+data = dict(
+    samples_per_gpu=16,
+    workers_per_gpu=2,
+    train=dict(type='ConcatDataset', datasets=[train1, train2]),
+    val=dict(type='ConcatDataset', datasets=[test]),
+    test=dict(type='ConcatDataset', datasets=[test]))
+
+evaluation = dict(interval=1, metric='acc')
diff --git a/configs/_base_/recog_models/crnn.py b/configs/_base_/recog_models/crnn.py
new file mode 100644
index 00000000..6b98c3d9
--- /dev/null
+++ b/configs/_base_/recog_models/crnn.py
@@ -0,0 +1,11 @@
+label_convertor = dict(
+    type='CTCConvertor', dict_type='DICT90', with_unknown=False)
+
+model = dict(
+    type='CRNNNet',
+    preprocessor=None,
+    backbone=dict(type='VeryDeepVgg', leakyRelu=False),
+    encoder=None,
+    decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
+    loss=dict(type='CTCLoss', flatten=False),
+    label_convertor=label_convertor)
diff --git a/configs/_base_/recog_models/nrtr.py b/configs/_base_/recog_models/nrtr.py
new file mode 100644
index 00000000..40657578
--- /dev/null
+++ b/configs/_base_/recog_models/nrtr.py
@@ -0,0 +1,11 @@
+label_convertor = dict(
+    type='AttnConvertor', dict_type='DICT36', with_unknown=True, lower=True)
+
+model = dict(
+    type='NRTR',
+    backbone=dict(type='NRTRModalityTransform'),
+    encoder=dict(type='TFEncoder'),
+    decoder=dict(type='TFDecoder'),
+    loss=dict(type='TFLoss'),
+    label_convertor=label_convertor,
+    max_seq_len=40)
diff --git a/configs/_base_/recog_models/robust_scanner.py b/configs/_base_/recog_models/robust_scanner.py
new file mode 100644
index 00000000..4cc2fa10
--- /dev/null
+++ b/configs/_base_/recog_models/robust_scanner.py
@@ -0,0 +1,24 @@
+label_convertor = dict(
+    type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+
+hybrid_decoder = dict(type='SequenceAttentionDecoder')
+
+position_decoder = dict(type='PositionAttentionDecoder')
+
+model = dict(
+    type='RobustScanner',
+    backbone=dict(type='ResNet31OCR'),
+    encoder=dict(
+        type='ChannelReductionEncoder',
+        in_channels=512,
+        out_channels=128,
+    ),
+    decoder=dict(
+        type='RobustScannerDecoder',
+        dim_input=512,
+        dim_model=128,
+        hybrid_decoder=hybrid_decoder,
+        position_decoder=position_decoder),
+    loss=dict(type='SARLoss'),
+    label_convertor=label_convertor,
+    max_seq_len=30)
diff --git a/configs/_base_/recog_models/sar.py b/configs/_base_/recog_models/sar.py
new file mode 100755
index 00000000..8438d9b9
--- /dev/null
+++ b/configs/_base_/recog_models/sar.py
@@ -0,0 +1,24 @@
+label_convertor = dict(
+    type='AttnConvertor', dict_type='DICT90', with_unknown=True)
+
+model = dict(
+    type='SARNet',
+    backbone=dict(type='ResNet31OCR'),
+    encoder=dict(
+        type='SAREncoder',
+        enc_bi_rnn=False,
+        enc_do_rnn=0.1,
+        enc_gru=False,
+    ),
+    decoder=dict(
+        type='ParallelSARDecoder',
+        enc_bi_rnn=False,
+        dec_bi_rnn=False,
+        dec_do_rnn=0,
+        dec_gru=False,
+        pred_dropout=0.1,
+        d_k=512,
+        pred_concat=True),
+    loss=dict(type='SARLoss'),
+    label_convertor=label_convertor,
+    max_seq_len=30)
a/configs/_base_/recog_models/transformer.py b/configs/_base_/recog_models/transformer.py new file mode 100644 index 00000000..476643fa --- /dev/null +++ b/configs/_base_/recog_models/transformer.py @@ -0,0 +1,11 @@ +label_convertor = dict( + type='AttnConvertor', dict_type='DICT90', with_unknown=False) + +model = dict( + type='TransformerNet', + backbone=dict(type='ResNet31OCR'), + encoder=dict(type='TFEncoder'), + decoder=dict(type='TFDecoder'), + loss=dict(type='TFLoss'), + label_convertor=label_convertor, + max_seq_len=40) diff --git a/configs/_base_/runtime_10e.py b/configs/_base_/runtime_10e.py new file mode 100644 index 00000000..393c5f14 --- /dev/null +++ b/configs/_base_/runtime_10e.py @@ -0,0 +1,14 @@ +checkpoint_config = dict(interval=10) +# yapf:disable +log_config = dict( + interval=5, + hooks=[ + dict(type='TextLoggerHook') + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/configs/_base_/schedules/schedule_1200e.py b/configs/_base_/schedules/schedule_1200e.py new file mode 100644 index 00000000..31e00920 --- /dev/null +++ b/configs/_base_/schedules/schedule_1200e.py @@ -0,0 +1,6 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.007, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-7, by_epoch=True) +total_epochs = 1200 diff --git a/configs/_base_/schedules/schedule_160e.py b/configs/_base_/schedules/schedule_160e.py new file mode 100644 index 00000000..0958701a --- /dev/null +++ b/configs/_base_/schedules/schedule_160e.py @@ -0,0 +1,11 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.08, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[80, 128]) +total_epochs = 160 diff --git a/configs/_base_/schedules/schedule_1x.py b/configs/_base_/schedules/schedule_1x.py new file mode 100644 index 00000000..12694c87 --- /dev/null +++ b/configs/_base_/schedules/schedule_1x.py @@ -0,0 +1,11 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[8, 11]) +total_epochs = 12 diff --git a/configs/_base_/schedules/schedule_20e.py b/configs/_base_/schedules/schedule_20e.py new file mode 100644 index 00000000..e6ca2b24 --- /dev/null +++ b/configs/_base_/schedules/schedule_20e.py @@ -0,0 +1,4 @@ +_base_ = './schedule_1x.py' +# learning policy +lr_config = dict(step=[16, 19]) +total_epochs = 20 diff --git a/configs/_base_/schedules/schedule_2e.py b/configs/_base_/schedules/schedule_2e.py new file mode 100644 index 00000000..110ef8c5 --- /dev/null +++ b/configs/_base_/schedules/schedule_2e.py @@ -0,0 +1,6 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.007, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=True) +total_epochs = 2 diff --git a/configs/_base_/schedules/schedule_2x.py b/configs/_base_/schedules/schedule_2x.py new file mode 100644 index 00000000..72b4135c --- /dev/null +++ b/configs/_base_/schedules/schedule_2x.py @@ -0,0 +1,4 @@ +_base_ = './schedule_1x.py' +# learning policy +lr_config = 
dict(step=[16, 22]) +total_epochs = 24 diff --git a/configs/_base_/schedules/schedule_adadelta_16e.py b/configs/_base_/schedules/schedule_adadelta_16e.py new file mode 100644 index 00000000..f8cc8b9b --- /dev/null +++ b/configs/_base_/schedules/schedule_adadelta_16e.py @@ -0,0 +1,6 @@ +# optimizer +optimizer = dict(type='Adadelta', lr=1.0) +optimizer_config = dict(grad_clip=dict(max_norm=0.5)) +# learning policy +lr_config = dict(policy='step', step=[8, 10, 12]) +total_epochs = 16 diff --git a/configs/_base_/schedules/schedule_adadelta_8e.py b/configs/_base_/schedules/schedule_adadelta_8e.py new file mode 100644 index 00000000..2b4a8444 --- /dev/null +++ b/configs/_base_/schedules/schedule_adadelta_8e.py @@ -0,0 +1,6 @@ +# optimizer +optimizer = dict(type='Adadelta', lr=1.0) +optimizer_config = dict(grad_clip=dict(max_norm=0.5)) +# learning policy +lr_config = dict(policy='step', step=[4, 6, 7]) +total_epochs = 8 diff --git a/configs/_base_/schedules/schedule_adam_1e.py b/configs/_base_/schedules/schedule_adam_1e.py new file mode 100644 index 00000000..b4b13379 --- /dev/null +++ b/configs/_base_/schedules/schedule_adam_1e.py @@ -0,0 +1,6 @@ +# optimizer +optimizer = dict(type='Adam', lr=1e-3) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='poly', power=0.9) +total_epochs = 1 diff --git a/configs/_base_/schedules/schedule_adam_600e.py b/configs/_base_/schedules/schedule_adam_600e.py new file mode 100644 index 00000000..e946603e --- /dev/null +++ b/configs/_base_/schedules/schedule_adam_600e.py @@ -0,0 +1,6 @@ +# optimizer +optimizer = dict(type='Adam', lr=1e-3) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='poly', power=0.9) +total_epochs = 600 diff --git a/configs/_base_/schedules/schedule_sgd_600e.py b/configs/_base_/schedules/schedule_sgd_600e.py new file mode 100644 index 00000000..9a605291 --- /dev/null +++ b/configs/_base_/schedules/schedule_sgd_600e.py @@ -0,0 +1,6 @@ +# optimizer +optimizer = dict(type='SGD', lr=1e-3, momentum=0.99, weight_decay=5e-4) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[200, 400]) +total_epochs = 600 diff --git a/configs/kie/sdmgr/README.md b/configs/kie/sdmgr/README.md new file mode 100644 index 00000000..d0de5641 --- /dev/null +++ b/configs/kie/sdmgr/README.md @@ -0,0 +1,25 @@ +# Spatial Dual-Modality Graph Reasoning for Key Information Extraction + +## Introduction + +[ALGORITHM] + +```bibtex +@misc{sun2021spatial, + title={Spatial Dual-Modality Graph Reasoning for Key Information Extraction}, + author={Hongbin Sun and Zhanghui Kuang and Xiaoyu Yue and Chenhao Lin and Wayne Zhang}, + year={2021}, + eprint={2103.14470}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + +## Results and models + +### WildReceipt + +| Method | Modality | Macro F1-Score | Download | +| :--------------------------------------------------------------------: | :--------------: | :------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [sdmgr_unet16](/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py) | Visual + Textual | 0.876 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/sdmgr_unet16_60e_wildreceipt_20210405-16a47642.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/20210405_104508.log.json) | +| [sdmgr_novisual](/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py) 
| Textual | 0.864 | [model](https://download.openmmlab.com/mmocr/kie/sdmgr/sdmgr_novisual_60e_wildreceipt_20210405-07bc26ad.pth) \| [log](https://download.openmmlab.com/mmocr/kie/sdmgr/20210405_141138.log.json) | diff --git a/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py b/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py new file mode 100644 index 00000000..6568909e --- /dev/null +++ b/configs/kie/sdmgr/sdmgr_novisual_60e_wildreceipt.py @@ -0,0 +1,93 @@ +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +max_scale, min_scale = 1024, 512 + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='KIEFormatBundle'), + dict( + type='Collect', + keys=['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='KIEFormatBundle'), + dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes']) +] + +dataset_type = 'KIEDataset' +data_root = 'data/wildreceipt' + +loader = dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineJsonParser', + keys=['file_name', 'height', 'width', 'annotations'])) + +train = dict( + type=dataset_type, + ann_file=f'{data_root}/train.txt', + pipeline=train_pipeline, + img_prefix=data_root, + loader=loader, + dict_file=f'{data_root}/dict.txt', + test_mode=False) +test = dict( + type=dataset_type, + ann_file=f'{data_root}/test.txt', + pipeline=test_pipeline, + img_prefix=data_root, + loader=loader, + dict_file=f'{data_root}/dict.txt', + test_mode=True) + +data = dict( + samples_per_gpu=4, workers_per_gpu=0, train=train, val=test, test=test) + +evaluation = dict( + interval=1, + metric='macro_f1', + metric_options=dict( + macro_f1=dict( + ignores=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25]))) + +model = dict( + type='SDMGR', + backbone=dict(type='UNet', base_channels=16), + bbox_head=dict( + type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26), + visual_modality=False, + train_cfg=None, + test_cfg=None, + class_list=f'{data_root}/class_list.txt') + +optimizer = dict(type='Adam', weight_decay=0.0001) +optimizer_config = dict(grad_clip=None) +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=1, + warmup_ratio=1, + step=[40, 50]) +total_epochs = 60 + +checkpoint_config = dict(interval=1) +log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] + +find_unused_parameters = True diff --git a/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py b/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py new file mode 100644 index 00000000..afc38393 --- /dev/null +++ b/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py @@ -0,0 +1,93 @@ +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +max_scale, min_scale = 1024, 512 + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True), + 
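# flip_ratio=0. keeps images unflipped, since mirrored text would no longer match its annotations +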
dict(type='RandomFlip', flip_ratio=0.), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='KIEFormatBundle'), + dict( + type='Collect', + keys=['img', 'relations', 'texts', 'gt_bboxes', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(max_scale, min_scale), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='KIEFormatBundle'), + dict(type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes']) +] + +dataset_type = 'KIEDataset' +data_root = 'data/wildreceipt' + +loader = dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineJsonParser', + keys=['file_name', 'height', 'width', 'annotations'])) + +train = dict( + type=dataset_type, + ann_file=f'{data_root}/train.txt', + pipeline=train_pipeline, + img_prefix=data_root, + loader=loader, + dict_file=f'{data_root}/dict.txt', + test_mode=False) +test = dict( + type=dataset_type, + ann_file=f'{data_root}/test.txt', + pipeline=test_pipeline, + img_prefix=data_root, + loader=loader, + dict_file=f'{data_root}/dict.txt', + test_mode=True) + +data = dict( + samples_per_gpu=4, workers_per_gpu=0, train=train, val=test, test=test) + +evaluation = dict( + interval=1, + metric='macro_f1', + metric_options=dict( + macro_f1=dict( + ignores=[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25]))) + +model = dict( + type='SDMGR', + backbone=dict(type='UNet', base_channels=16), + bbox_head=dict( + type='SDMGRHead', visual_dim=16, num_chars=92, num_classes=26), + visual_modality=True, + train_cfg=None, + test_cfg=None, + class_list=f'{data_root}/class_list.txt') + +optimizer = dict(type='Adam', weight_decay=0.0001) +optimizer_config = dict(grad_clip=None) +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=1, + warmup_ratio=1, + step=[40, 50]) +total_epochs = 60 + +checkpoint_config = dict(interval=1) +log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] + +find_unused_parameters = True diff --git a/configs/textdet/dbnet/README.md b/configs/textdet/dbnet/README.md new file mode 100644 index 00000000..c8b6a094 --- /dev/null +++ b/configs/textdet/dbnet/README.md @@ -0,0 +1,23 @@ +# Real-time Scene Text Detection with Differentiable Binarization + +## Introduction + +[ALGORITHM] + +```bibtex +@article{Liao_Wan_Yao_Chen_Bai_2020, + title={Real-Time Scene Text Detection with Differentiable Binarization}, + journal={Proceedings of the AAAI Conference on Artificial Intelligence}, + author={Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang}, + year={2020}, + pages={11474-11481}} +``` + +## Results and models + +### ICDAR2015 + +| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +| :---------------------------------------------------------------: | :--------------: | :-------------: | :------------: | :-----: | :-------: | :----: | :-------: | :---: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [DBNet](/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py) | ImageNet | ICDAR2015 Train |
ICDAR2015 Test | 1200 | 736 | 0.731 | 0.871 | 0.795 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.log.json) | +| [DBNet](/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py) | [Synthtext](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_2e_synthtext_20210325-aa96e477.pth) | ICDAR2015 Train | ICDAR2015 Test | 1200 | 1024 | 0.796 | 0.866 | 0.830 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth.log.json) | diff --git a/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py b/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py new file mode 100644 index 00000000..790355fc --- /dev/null +++ b/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py @@ -0,0 +1,96 @@ +_base_ = [ + '../../_base_/schedules/schedule_1200e.py', '../../_base_/runtime_10e.py' +] +model = dict( + type='DBNet', + pretrained='torchvision://resnet18', + backbone=dict( + type='ResNet', + depth=18, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + style='caffe'), + neck=dict( + type='FPNC', in_channels=[64, 128, 256, 512], lateral_channels=256), + bbox_head=dict( + type='DBHead', + text_repr_type='quad', + in_channels=256, + loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True)), + train_cfg=None, + test_cfg=None) + +dataset_type = 'IcdarDataset' +data_root = 'data/icdar2015/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +# for visualizing img, please uncomment it.
+# img_norm_cfg = dict(mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='LoadTextAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5), + dict(type='Normalize', **img_norm_cfg), + # img aug + dict( + type='ImgAug', + args=[['Fliplr', 0.5], + dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]), + # random crop + dict(type='EastRandomCrop', target_size=(640, 640)), + dict(type='DBNetTargets', shrink_ratio=0.4), + dict(type='Pad', size_divisor=32), + # for visualizing img and gts, pls set visualize = True + dict( + type='CustomFormatBundle', + keys=['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'], + visualize=dict(flag=False, boundary_key='gt_shrink')), + dict( + type='Collect', + keys=['img', 'gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 736), + flip=False, + transforms=[ + dict(type='Resize', img_scale=(2944, 736), keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=16, + workers_per_gpu=8, + train=dict( + type=dataset_type, + ann_file=data_root + '/instances_training.json', + # for debugging top k imgs + # select_first_k=200, + img_prefix=data_root + '/imgs', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + # select_first_k=100, + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + # select_first_k=100, + pipeline=test_pipeline)) +evaluation = dict(interval=100, metric='hmean-iou') diff --git a/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py b/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py new file mode 100644 index 00000000..f1ccec51 --- /dev/null +++ b/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py @@ -0,0 +1,105 @@ +_base_ = [ + '../../_base_/schedules/schedule_1200e.py', '../../_base_/runtime_10e.py' +] +load_from = 'checkpoints/textdet/dbnet/res50dcnv2_synthtext.pth' + +model = dict( + type='DBNet', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + style='caffe', + dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False), + stage_with_dcn=(False, True, True, True)), + neck=dict( + type='FPNC', in_channels=[256, 512, 1024, 2048], lateral_channels=256), + bbox_head=dict( + type='DBHead', + text_repr_type='quad', + in_channels=256, + loss=dict(type='DBLoss', alpha=5.0, beta=10.0, bbce_loss=True)), + train_cfg=None, + test_cfg=None) + +dataset_type = 'IcdarDataset' +data_root = 'data/icdar2015/' +# img_norm_cfg = dict( +# mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +# from official dbnet code +img_norm_cfg = dict( + mean=[122.67891434, 116.66876762, 104.00698793], + std=[255, 255, 255], + to_rgb=False) +# for visualizing img, pls uncomment it. 
+# img_norm_cfg = dict(mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='LoadTextAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5), + dict(type='Normalize', **img_norm_cfg), + # img aug + dict( + type='ImgAug', + args=[['Fliplr', 0.5], + dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]]), + # random crop + dict(type='EastRandomCrop', target_size=(640, 640)), + dict(type='DBNetTargets', shrink_ratio=0.4), + dict(type='Pad', size_divisor=32), + # for visualizing img and gts, pls set visualize = True + dict( + type='CustomFormatBundle', + keys=['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'], + visualize=dict(flag=False, boundary_key='gt_shrink')), + dict( + type='Collect', + keys=['img', 'gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(4068, 1024), + flip=False, + transforms=[ + dict(type='Resize', img_scale=(4068, 1024), keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type=dataset_type, + ann_file=data_root + '/instances_training.json', + # for debugging top k imgs + # select_first_k=200, + img_prefix=data_root + '/imgs', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + # select_first_k=100, + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + # select_first_k=100, + pipeline=test_pipeline)) +evaluation = dict(interval=100, metric='hmean-iou') diff --git a/configs/textdet/maskrcnn/README.md b/configs/textdet/maskrcnn/README.md new file mode 100644 index 00000000..520a9606 --- /dev/null +++ b/configs/textdet/maskrcnn/README.md @@ -0,0 +1,35 @@ +# Mask R-CNN + +## Introduction + +[ALGORITHM] + +```bibtex +@article{pmtd, + author={Jingchao Liu and Xuebo Liu and Jie Sheng and Ding Liang and Xin Li and Qingjie Liu}, + title={Pyramid Mask Text Detector}, + journal={CoRR}, + volume={abs/1903.11800}, + year={2019} +} +``` + +## Results and models + +### CTW1500 + +| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +| :---------------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [MaskRCNN](/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 160 | 1600 | 0.753 | 0.712 | 0.732 | [model](https://download.openmmlab.com/mmocr/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.log.json) | + +### ICDAR2015 + +| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +| 
:-----------------------------------------------------------------------: | :--------------: | :-------------: | :------------: | :-----: | :-------: | :----: | :-------: | :---: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [MaskRCNN](/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py) | ImageNet | ICDAR2015 Train | ICDAR2015 Test | 160 | 1920 | 0.783 | 0.872 | 0.825 | [model](https://download.openmmlab.com/mmocr/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.log.json) | + +### ICDAR2017 + +| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +| :-----------------------------------------------------------------------: | :--------------: | :-------------: | :-----------: | :-----: | :-------: | :----: | :-------: | :---: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [MaskRCNN](/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py) | ImageNet | ICDAR2017 Train | ICDAR2017 Val | 160 | 1600 | 0.754 | 0.827 | 0.789 | [model](https://download.openmmlab.com/mmocr/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.log.json) | diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py new file mode 100644 index 00000000..d821bb7f --- /dev/null +++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py @@ -0,0 +1,67 @@ +_base_ = [ + '../../_base_/models/ocr_mask_rcnn_r50_fpn_ohem_poly.py', + '../../_base_/schedules/schedule_160e.py', '../../_base_/runtime_10e.py' +] + +dataset_type = 'IcdarDataset' +data_root = 'data/ctw1500/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +# img_norm_cfg = dict(mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='ScaleAspectJitter', + img_scale=None, + keep_ratio=False, + resize_type='indep_sample_in_range', + scale_range=(640, 2560)), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict( + type='RandomCropInstances', + target_size=(640, 640), + mask_type='union_all', + instance_key='gt_masks'), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + # resize the long size to 1600 + img_scale=(1600, 1600), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + # no flip + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type=dataset_type, + 
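# CTW1500 annotations are converted to COCO-style instance JSON, so the generic IcdarDataset type is reused +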
ann_file=data_root + '/instances_training.json', + img_prefix=data_root + '/imgs', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + # select_first_k=1, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + # select_first_k=1, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline)) +evaluation = dict(interval=10, metric='hmean-iou') diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py new file mode 100644 index 00000000..e2d8be68 --- /dev/null +++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py @@ -0,0 +1,66 @@ +_base_ = [ + '../../_base_/models/ocr_mask_rcnn_r50_fpn_ohem.py', + '../../_base_/schedules/schedule_160e.py', '../../_base_/runtime_10e.py' +] +dataset_type = 'IcdarDataset' +data_root = 'data/icdar2015/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +# img_norm_cfg = dict(mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='ScaleAspectJitter', + img_scale=None, + keep_ratio=False, + resize_type='indep_sample_in_range', + scale_range=(640, 2560)), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict( + type='RandomCropInstances', + target_size=(640, 640), + mask_type='union_all', + instance_key='gt_masks'), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + # resize the long side to 1920 + img_scale=(1920, 1920), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + # no flip + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type=dataset_type, + ann_file=data_root + '/instances_training.json', + img_prefix=data_root + '/imgs', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + # select_first_k=1, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + # select_first_k=1, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline)) +evaluation = dict(interval=10, metric='hmean-iou') diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py new file mode 100644 index 00000000..8b948e09 --- /dev/null +++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py @@ -0,0 +1,67 @@ +_base_ = [ + '../../_base_/models/ocr_mask_rcnn_r50_fpn_ohem.py', + '../../_base_/schedules/schedule_160e.py', '../../_base_/runtime_10e.py' +] + +dataset_type = 'IcdarDataset' +data_root = 'data/icdar2017/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +# img_norm_cfg = dict(mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='ScaleAspectJitter', + img_scale=None,
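+        # img_scale stays None: with resize_type='indep_sample_in_range', the scale is sampled from scale_range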
+ keep_ratio=False, + resize_type='indep_sample_in_range', + scale_range=(640, 2560)), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict( + type='RandomCropInstances', + target_size=(640, 640), + mask_type='union_all', + instance_key='gt_masks'), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + # resize the long size to 1600 + img_scale=(1600, 1600), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + # no flip + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type=dataset_type, + ann_file=data_root + '/instances_training.json', + img_prefix=data_root + '/imgs', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + # select_first_k=1, + ann_file=data_root + '/instances_val.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + # select_first_k=1, + ann_file=data_root + '/instances_val.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline)) +evaluation = dict(interval=10, metric='hmean-iou') diff --git a/configs/textdet/panet/README.md b/configs/textdet/panet/README.md new file mode 100644 index 00000000..b4290bf6 --- /dev/null +++ b/configs/textdet/panet/README.md @@ -0,0 +1,35 @@ +# Efficient and Accurate Arbitrary-Shaped Text Detection with Pixel Aggregation Network + +## Introduction + +[ALGORITHM] + +```bibtex +@inproceedings{WangXSZWLYS19, + author={Wenhai Wang and Enze Xie and Xiaoge Song and Yuhang Zang and Wenjia Wang and Tong Lu and Gang Yu and Chunhua Shen}, + title={Efficient and Accurate Arbitrary-Shaped Text Detection With Pixel Aggregation Network}, + booktitle={ICCV}, + pages={8439--8448}, + year={2019} + } +``` + +## Results and models + +### CTW1500 + +| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +| :----------------------------------------------------------------: | :--------------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [PANet](/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 600 | 640 | 0.790 | 0.838 | 0.813 | [model](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.log.json) | + +### ICDAR2015 + +| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +| :------------------------------------------------------------------: | :--------------: | :-------------: | :------------: | :-----: | :-------: | :----: | :-------: | :---: | 
:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [PANet](/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py) | ImageNet | ICDAR2015 Train | ICDAR2015 Test | 600 | 736 | 0.734 | 0.856 | 0.791 | [model](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.log.json) | + +### ICDAR2017 + +| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +| :------------------------------------------------------------------: | :--------------: | :-------------: | :-----------: | :-----: | :-------: | :----: | :-------: | :---: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [PANet](/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py) | ImageNet | ICDAR2017 Train | ICDAR2017 Val | 600 | 800 | 0.604 | 0.812 | 0.693 | [model](https://download.openmmlab.com/mmocr/textdet/panet/panet_r50_fpem_ffm_sbn_600e_icdar2017_20210219-b4877a4f.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/panet/panet_r50_fpem_ffm_sbn_600e_icdar2017_20210219-b4877a4f.log.json) | diff --git a/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py b/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py new file mode 100644 index 00000000..58d1d22b --- /dev/null +++ b/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py @@ -0,0 +1,104 @@ +_base_ = [ + '../../_base_/schedules/schedule_adam_600e.py', + '../../_base_/runtime_10e.py' +] +model = dict( + type='PANet', + pretrained='torchvision://resnet18', + backbone=dict( + type='ResNet', + depth=18, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=True, + style='caffe'), + neck=dict(type='FPEM_FFM', in_channels=[64, 128, 256, 512]), + bbox_head=dict( + type='PANHead', + text_repr_type='poly', + in_channels=[128, 128, 128, 128], + out_channels=6, + loss=dict(type='PANLoss')), + train_cfg=None, + test_cfg=None) + +dataset_type = 'IcdarDataset' +data_root = 'data/ctw1500/' + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +# for visualizing img, pls uncomment it. +# img_norm_cfg = dict( +# mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='LoadTextAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5), + dict(type='Normalize', **img_norm_cfg), + dict( + type='ScaleAspectJitter', + img_scale=[(3000, 640)], + ratio_range=(0.7, 1.3), + aspect_ratio_range=(0.9, 1.1), + multiscale_mode='value', + keep_ratio=False), + # shrink_ratio is from big to small. 
The 1st must be 1.0 + dict(type='PANetTargets', shrink_ratio=(1.0, 0.7)), + dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'), + dict(type='RandomRotateTextDet'), + dict( + type='RandomCropInstances', + target_size=(640, 640), + instance_key='gt_kernels'), + dict(type='Pad', size_divisor=32), + # for visualizing img and gts, pls set visualize = True + dict( + type='CustomFormatBundle', + keys=['gt_kernels', 'gt_mask'], + visualize=dict(flag=False, boundary_key='gt_kernels')), + dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(3000, 640), + flip=False, + transforms=[ + dict(type='Resize', img_scale=(3000, 640), keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + '/instances_training.json', + # for debugging top k imgs + # select_first_k=200, + img_prefix=data_root + '/imgs', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + # select_first_k=100, + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + # select_first_k=100, + pipeline=test_pipeline)) +evaluation = dict(interval=10, metric='hmean-iou') diff --git a/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py b/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py new file mode 100644 index 00000000..de12c26c --- /dev/null +++ b/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py @@ -0,0 +1,102 @@ +_base_ = [ + '../../_base_/schedules/schedule_adam_600e.py', + '../../_base_/runtime_10e.py' +] +model = dict( + type='PANet', + pretrained='torchvision://resnet18', + backbone=dict( + type='ResNet', + depth=18, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=True, + style='caffe'), + neck=dict(type='FPEM_FFM', in_channels=[64, 128, 256, 512]), + bbox_head=dict( + type='PANHead', + text_repr_type='quad', + in_channels=[128, 128, 128, 128], + out_channels=6, + loss=dict(type='PANLoss')), + train_cfg=None, + test_cfg=None) + +dataset_type = 'IcdarDataset' +data_root = 'data/icdar2015/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +# for visualizing img, pls uncomment it. 
+# img_norm_cfg = dict( +# mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='LoadTextAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5), + dict(type='Normalize', **img_norm_cfg), + dict( + type='ScaleAspectJitter', + img_scale=[(3000, 736)], + ratio_range=(0.7, 1.3), + aspect_ratio_range=(0.9, 1.1), + multiscale_mode='value', + keep_ratio=False), + dict(type='PANetTargets', shrink_ratio=(1.0, 0.5), max_shrink=20), + dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'), + dict(type='RandomRotateTextDet'), + dict( + type='RandomCropInstances', + target_size=(736, 736), + instance_key='gt_kernels'), + dict(type='Pad', size_divisor=32), + # for visualizing img and gts, pls set visualize = True + dict( + type='CustomFormatBundle', + keys=['gt_kernels', 'gt_mask'], + visualize=dict(flag=False, boundary_key='gt_kernels')), + dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 736), + flip=False, + transforms=[ + dict(type='Resize', img_scale=(1333, 736), keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=8, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + '/instances_training.json', + # for debugging top k imgs + # select_first_k=200, + img_prefix=data_root + '/imgs', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + # select_first_k=100, + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + # select_first_k=100, + pipeline=test_pipeline)) +evaluation = dict(interval=10, metric='hmean-iou') diff --git a/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py b/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py new file mode 100644 index 00000000..933f0bae --- /dev/null +++ b/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py @@ -0,0 +1,93 @@ +_base_ = [ + '../../_base_/schedules/schedule_adam_600e.py', + '../../_base_/runtime_10e.py' +] +model = dict( + type='PANet', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='caffe'), + neck=dict(type='FPEM_FFM', in_channels=[256, 512, 1024, 2048]), + bbox_head=dict( + type='PANHead', + in_channels=[128, 128, 128, 128], + out_channels=6, + loss=dict(type='PANLoss', speedup_bbox_thr=32)), + train_cfg=None, + test_cfg=None) + +dataset_type = 'IcdarDataset' +data_root = 'data/icdar2017/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='LoadTextAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5), + dict(type='Normalize', **img_norm_cfg), + dict( + type='ScaleAspectJitter', + img_scale=[(3000, 800)], + ratio_range=(0.7, 1.3), + aspect_ratio_range=(0.9, 1.1), + multiscale_mode='value', + keep_ratio=False), + 
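# shrink ratios run from large to small and the first must be 1.0, as noted in the CTW1500 config +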
dict(type='PANetTargets', shrink_ratio=(1.0, 0.5)), + dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'), + dict(type='RandomRotateTextDet'), + dict( + type='RandomCropInstances', + target_size=(800, 800), + instance_key='gt_kernels'), + dict(type='Pad', size_divisor=32), + # for visualizing img and gts, pls set visualize = True + dict( + type='CustomFormatBundle', + keys=['gt_kernels', 'gt_mask'], + visualize=dict(flag=False, boundary_key='gt_kernels')), + dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + ann_file=data_root + '/instances_training.json', + img_prefix=data_root + '/imgs', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + '/instances_val.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + '/instances_val.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline)) +evaluation = dict(interval=10, metric='hmean-iou') diff --git a/configs/textdet/psenet/README.md b/configs/textdet/psenet/README.md new file mode 100644 index 00000000..9995b4d8 --- /dev/null +++ b/configs/textdet/psenet/README.md @@ -0,0 +1,29 @@ +# PSENet + +## Introduction + +[ALGORITHM] + +```bibtex +@article{li2018shape, + title={Shape robust text detection with progressive scale expansion network}, + author={Li, Xiang and Wang, Wenhai and Hou, Wenbo and Liu, Ruo-Ze and Lu, Tong and Yang, Jian}, + journal={arXiv preprint arXiv:1806.02559}, + year={2018} +} +``` + +## Results and models + +### CTW1500 + +| Method | Backbone | Extra Data | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +| :------------------------------------------------------------------: | :------: | :--------: | :-----------: | :----------: | :-----: | :-------: | :----: | :-------: | :---: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [PSENet-4s](/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py) | ResNet50 | - | CTW1500 Train | CTW1500 Test | 600 | 1280 | 0.728 | 0.849 | 0.784 | [model](https://download.openmmlab.com/mmocr/textdet/psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/psenet/20210401_215421.log.json) | + +### ICDAR2015 + +| Method | Backbone | Extra Data | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +| :--------------------------------------------------------------------: | :------: | :---------------------------------------------------------------------------------------------------------------------------------------: | :----------: | :-------: | :-----: | :-------: | :----: | :-------: | :---: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| 
[PSENet-4s](/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py) | ResNet50 | - | IC15 Train | IC15 Test | 600 | 2240 | 0.784 | 0.831 | 0.807 | [model](https://download.openmmlab.com/mmocr/textdet/psenet/psenet_r50_fpnf_600e_icdar2015-c6131f0d.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/psenet/20210331_214145.log.json) | +| [PSENet-4s](/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py) | ResNet50 | pretrain on IC17 MLT [model](https://download.openmmlab.com/mmocr/textdet/psenet/psenet_r50_fpnf_600e_icdar2017_as_pretrain-3bd6056c.pth) | IC15 Train | IC15 Test | 600 | 2240 | 0.834 | 0.861 | 0.847 | [model](https://download.openmmlab.com/mmocr/textdet/psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth) \| [log]() | diff --git a/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py b/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py new file mode 100644 index 00000000..e1f5fd04 --- /dev/null +++ b/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py @@ -0,0 +1,108 @@ +_base_ = ['../../_base_/default_runtime.py'] + +# optimizer +optimizer = dict(type='Adam', lr=1e-4) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[200, 400]) +total_epochs = 600 + +model = dict( + type='PSENet', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=True, + style='caffe'), + neck=dict( + type='FPNF', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + fusion_type='concat'), + bbox_head=dict( + type='PSEHead', + text_repr_type='poly', + in_channels=[256], + out_channels=7, + loss=dict(type='PSELoss')), + train_cfg=None, + test_cfg=None) + +dataset_type = 'IcdarDataset' +data_root = 'data/ctw1500/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='LoadTextAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5), + dict(type='Normalize', **img_norm_cfg), + dict( + type='ScaleAspectJitter', + img_scale=[(3000, 736)], + ratio_range=(0.5, 3), + aspect_ratio_range=(1, 1), + multiscale_mode='value', + long_size_bound=1280, + short_size_bound=640, + resize_type='long_short_bound', + keep_ratio=False), + dict(type='PSENetTargets'), + dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'), + dict(type='RandomRotateTextDet'), + dict( + type='RandomCropInstances', + target_size=(640, 640), + instance_key='gt_kernels'), + dict(type='Pad', size_divisor=32), + dict( + type='CustomFormatBundle', + keys=['gt_kernels', 'gt_mask'], + visualize=dict(flag=False, boundary_key='gt_kernels')), + dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1280, 1280), + flip=False, + transforms=[ + dict(type='Resize', img_scale=(1280, 1280), keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] + +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + '/instances_training.json', + img_prefix=data_root + '/imgs', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + 
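# CTW1500 has no dedicated validation split; the test annotations below double as val data +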
ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline)) + +evaluation = dict(interval=10, metric='hmean-iou') diff --git a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py new file mode 100644 index 00000000..5eb7538c --- /dev/null +++ b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py @@ -0,0 +1,108 @@ +_base_ = ['../../_base_/runtime_10e.py'] + +# optimizer +optimizer = dict(type='Adam', lr=1e-4) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[200, 400]) +total_epochs = 600 + +model = dict( + type='PSENet', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='SyncBN', requires_grad=True), + norm_eval=True, + style='caffe'), + neck=dict( + type='FPNF', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + fusion_type='concat'), + bbox_head=dict( + type='PSEHead', + text_repr_type='quad', + in_channels=[256], + out_channels=7, + loss=dict(type='PSELoss')), + train_cfg=None, + test_cfg=None) + +dataset_type = 'IcdarDataset' +data_root = 'data/icdar2015/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='LoadTextAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5), + dict(type='Normalize', **img_norm_cfg), + dict( + type='ScaleAspectJitter', + img_scale=[(3000, 736)], # unused + ratio_range=(0.5, 3), + aspect_ratio_range=(1, 1), + multiscale_mode='value', + long_size_bound=1280, + short_size_bound=640, + resize_type='long_short_bound', + keep_ratio=False), + dict(type='PSENetTargets'), + dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'), + dict(type='RandomRotateTextDet'), + dict( + type='RandomCropInstances', + target_size=(640, 640), + instance_key='gt_kernels'), + dict(type='Pad', size_divisor=32), + dict( + type='CustomFormatBundle', + keys=['gt_kernels', 'gt_mask'], + visualize=dict(flag=False, boundary_key='gt_kernels')), + dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2240, 2200), + flip=False, + transforms=[ + dict(type='Resize', img_scale=(2240, 2200), keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] + +data = dict( + samples_per_gpu=8, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + '/instances_training.json', + img_prefix=data_root + '/imgs', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline)) + +evaluation = dict(interval=10, metric='hmean-iou') diff --git a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py new file mode 100644 
index 00000000..3c20c4cc --- /dev/null +++ b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py @@ -0,0 +1,103 @@ +_base_ = [ + '../../_base_/schedules/schedule_sgd_600e.py', + '../../_base_/runtime_10e.py' +] +model = dict( + type='PSENet', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='caffe'), + neck=dict( + type='FPNF', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + fusion_type='concat'), + bbox_head=dict( + type='PSEHead', + text_repr_type='quad', + in_channels=[256], + out_channels=7, + loss=dict(type='PSELoss')), + train_cfg=None, + test_cfg=None) + +dataset_type = 'IcdarDataset' +data_root = 'data/icdar2017/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='LoadTextAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5), + dict(type='Normalize', **img_norm_cfg), + dict( + type='ScaleAspectJitter', + img_scale=[(3000, 736)], + ratio_range=(0.5, 3), + aspect_ratio_range=(1, 1), + multiscale_mode='value', + long_size_bound=1280, + short_size_bound=640, + resize_type='long_short_bound', + keep_ratio=False), + dict(type='PSENetTargets'), + dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'), + dict(type='RandomRotateTextDet'), + dict( + type='RandomCropInstances', + target_size=(640, 640), + instance_key='gt_kernels'), + dict(type='Pad', size_divisor=32), + dict( + type='CustomFormatBundle', + keys=['gt_kernels', 'gt_mask'], + visualize=dict(flag=False, boundary_key='gt_kernels')), + dict(type='Collect', keys=['img', 'gt_kernels', 'gt_mask']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2240, 2200), + flip=False, + transforms=[ + dict(type='Resize', img_scale=(2240, 2200), keep_ratio=True), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] + +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type=dataset_type, + ann_file=data_root + '/instances_training.json', + img_prefix=data_root + '/imgs', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + '/instances_val.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + '/instances_val.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline)) + +evaluation = dict(interval=10, metric='hmean-iou') diff --git a/configs/textdet/textsnake/README.md b/configs/textdet/textsnake/README.md new file mode 100644 index 00000000..8e761f4c --- /dev/null +++ b/configs/textdet/textsnake/README.md @@ -0,0 +1,23 @@ +# TextSnake + +## Introduction + +[ALGORITHM] + +```bibtex +@inproceedings{long2018textsnake, + title={TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes}, + author={Long, Shangbang and Ruan, Jiaqiang and Zhang, Wenjie and He, Xin and Wu, Wenhao and Yao, Cong}, + booktitle={ECCV}, + pages={20-36}, + year={2018} +} +``` + +## Results and models + +### CTW1500 + +| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download | +|
:---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| [TextSnake](/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py) | ImageNet | CTW1500 Train | CTW1500 Test | 1200 | 736 | 0.795 | 0.840 | 0.817 | [model](https://download.openmmlab.com/mmocr/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth) \| [log]() |
diff --git a/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py b/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py
new file mode 100644
index 00000000..dba03cd1
--- /dev/null
+++ b/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py
@@ -0,0 +1,113 @@
+_base_ = [
+    '../../_base_/schedules/schedule_1200e.py',
+    '../../_base_/default_runtime.py'
+]
+model = dict(
+    type='TextSnake',
+    pretrained='torchvision://resnet50',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=-1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        norm_eval=True,
+        style='caffe'),
+    neck=dict(
+        type='FPN_UNET', in_channels=[256, 512, 1024, 2048], out_channels=32),
+    bbox_head=dict(
+        type='TextSnakeHead',
+        in_channels=32,
+        text_repr_type='poly',
+        loss=dict(type='TextSnakeLoss')),
+    train_cfg=None,
+    test_cfg=None)
+
+dataset_type = 'IcdarDataset'
+data_root = 'data/ctw1500/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='LoadTextAnnotations',
+        with_bbox=True,
+        with_mask=True,
+        poly2mask=False),
+    dict(type='ColorJitter', brightness=32.0 / 255, saturation=0.5),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(
+        type='RandomCropPolyInstances',
+        instance_key='gt_masks',
+        crop_ratio=0.65,
+        min_side_ratio=0.3),
+    dict(
+        type='RandomRotatePolyInstances',
+        rotate_ratio=0.5,
+        max_angle=20,
+        pad_with_fixed_color=False),
+    dict(
+        type='ScaleAspectJitter',
+        img_scale=[(3000, 736)],  # unused
+        ratio_range=(0.7, 1.3),
+        aspect_ratio_range=(0.9, 1.1),
+        multiscale_mode='value',
+        long_size_bound=800,
+        short_size_bound=480,
+        resize_type='long_short_bound',
+        keep_ratio=False),
+    dict(type='SquareResizePad', target_size=800, pad_ratio=0.6),
+    dict(type='RandomFlip', flip_ratio=0.5, direction='horizontal'),
+    dict(type='TextSnakeTargets'),
+    dict(type='Pad', size_divisor=32),
+    dict(
+        type='CustomFormatBundle',
+        keys=[
+            'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
+            'gt_radius_map', 'gt_sin_map', 'gt_cos_map'
+        ],
+        visualize=dict(flag=False, boundary_key='gt_text_mask')),
+    dict(
+        type='Collect',
+        keys=[
+            'img', 'gt_text_mask', 'gt_center_region_mask', 'gt_mask',
+            'gt_radius_map', 'gt_sin_map', 'gt_cos_map'
+        ])
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1333, 736),
+        flip=False,
+        transforms=[
+            dict(type='Resize', img_scale=(1333, 736), keep_ratio=True),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + '/instances_training.json',
+        img_prefix=data_root + '/imgs',
+        pipeline=train_pipeline),
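+    # NOTE: val/test below reuse the lighter test_pipeline and the same
+    # COCO-style annotation file (instances_test.json); only train applies
+    # the heavy augmentations defined above.
+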
val=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + '/instances_test.json', + img_prefix=data_root + '/imgs', + pipeline=test_pipeline)) + +evaluation = dict(interval=10, metric='hmean-iou') diff --git a/configs/textrecog/crnn/README.md b/configs/textrecog/crnn/README.md new file mode 100644 index 00000000..489cc64a --- /dev/null +++ b/configs/textrecog/crnn/README.md @@ -0,0 +1,37 @@ +# An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition + +## Introduction + +[ALGORITHM] + +```bibtex +@article{shi2016end, + title={An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition}, + author={Shi, Baoguang and Bai, Xiang and Yao, Cong}, + journal={IEEE transactions on pattern analysis and machine intelligence}, + year={2016} +} +``` + +## Results and Models + +### Train Dataset + +| trainset | instance_num | repeat_num | note | +| :------: | :----------: | :--------: | :---: | +| Syn90k | 8919273 | 1 | synth | + +### Test Dataset + +| testset | instance_num | note | +| :-----: | :----------: | :-----: | +| IIIT5K | 3000 | regular | +| SVT | 647 | regular | +| IC13 | 1015 | regular | + +## Results and models + +| methods | | Regular Text | | | | Irregular Text | | download | +| :-----: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :------------------: | +| methods | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 | +| CRNN | 80.5 | 81.5 | 86.5 | | - | - | - | [model](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_academic-a723a1c5.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/crnn/20210326_111035.log.json) | diff --git a/configs/textrecog/crnn/crnn_academic_dataset.py b/configs/textrecog/crnn/crnn_academic_dataset.py new file mode 100644 index 00000000..75701698 --- /dev/null +++ b/configs/textrecog/crnn/crnn_academic_dataset.py @@ -0,0 +1,138 @@ +_base_ = [] +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=1, + hooks=[ + dict(type='TextLoggerHook') + + ]) +# yapf:enable +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] + +# model +label_convertor = dict( + type='CTCConvertor', dict_type='DICT36', with_unknown=False, lower=True) + +model = dict( + type='CRNNNet', + preprocessor=None, + backbone=dict(type='VeryDeepVgg', leakyRelu=False, input_channels=1), + encoder=None, + decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True), + loss=dict(type='CTCLoss'), + label_convertor=label_convertor, + pretrained=None) + +train_cfg = None +test_cfg = None + +# optimizer +optimizer = dict(type='Adadelta', lr=1.0) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[]) +total_epochs = 5 + +# data +img_norm_cfg = dict(mean=[0.5], std=[0.5]) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeOCR', + height=32, + min_width=100, + max_width=100, + keep_aspect_ratio=False), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio' + ]), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeOCR', + height=32, + min_width=4, + max_width=None, + 
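# max_width=None leaves the resized width unbounded at test time, so the
+        # full aspect ratio is kept (training above fixes the width to 100
+        # for batching).
+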
keep_aspect_ratio=True), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=['filename', 'ori_shape', 'img_shape', 'valid_ratio']), +] + +dataset_type = 'OCRDataset' + +train_img_prefix = 'data/mixture/Syn90k/mnt/ramdisk/max/90kDICT32px' +train_ann_file = 'data/mixture/Syn90k/label.lmdb' + +train1 = dict( + type=dataset_type, + img_prefix=train_img_prefix, + ann_file=train_ann_file, + loader=dict( + type='LmdbLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +test_prefix = 'data/mixture/' +test_img_prefix1 = test_prefix + 'icdar_2013/' +test_img_prefix2 = test_prefix + 'IIIT5K/' +test_img_prefix3 = test_prefix + 'svt/' + +test_ann_file1 = test_prefix + 'icdar_2013/test_label_1015.txt' +test_ann_file2 = test_prefix + 'IIIT5K/test_label.txt' +test_ann_file3 = test_prefix + 'svt/test_label.txt' + +test1 = dict( + type=dataset_type, + img_prefix=test_img_prefix1, + ann_file=test_ann_file1, + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=test_pipeline, + test_mode=True) + +test2 = {key: value for key, value in test1.items()} +test2['img_prefix'] = test_img_prefix2 +test2['ann_file'] = test_ann_file2 + +test3 = {key: value for key, value in test1.items()} +test3['img_prefix'] = test_img_prefix3 +test3['ann_file'] = test_ann_file3 + +data = dict( + samples_per_gpu=64, + workers_per_gpu=4, + train=dict(type='ConcatDataset', datasets=[train1]), + val=dict(type='ConcatDataset', datasets=[test1, test2, test3]), + test=dict(type='ConcatDataset', datasets=[test1, test2, test3])) + +evaluation = dict(interval=1, metric='acc') + +cudnn_benchmark = True diff --git a/configs/textrecog/crnn/crnn_toy_dataset.py b/configs/textrecog/crnn/crnn_toy_dataset.py new file mode 100644 index 00000000..76854024 --- /dev/null +++ b/configs/textrecog/crnn/crnn_toy_dataset.py @@ -0,0 +1,6 @@ +_base_ = [ + '../../_base_/schedules/schedule_adadelta_8e.py', + '../../_base_/default_runtime.py', + '../../_base_/recog_datasets/toy_dataset.py', + '../../_base_/recog_models/crnn.py' +] diff --git a/configs/textrecog/nrtr/README.md b/configs/textrecog/nrtr/README.md new file mode 100644 index 00000000..7d018559 --- /dev/null +++ b/configs/textrecog/nrtr/README.md @@ -0,0 +1,61 @@ +# NRTR + +## Introduction + +[ALGORITHM] + +```bibtex +@inproceedings{sheng2019nrtr, + title={NRTR: A no-recurrence sequence-to-sequence model for scene text recognition}, + author={Sheng, Fenfen and Chen, Zhineng and Xu, Bo}, + booktitle={2019 International Conference on Document Analysis and Recognition (ICDAR)}, + pages={781--786}, + year={2019}, + organization={IEEE} +} +``` + +[BACKBONE] + +```bibtex +@inproceedings{li2019show, + title={Show, attend and read: A simple and strong baseline for irregular text recognition}, + author={Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu}, + booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, + volume={33}, + number={01}, + pages={8610--8617}, + year={2019} +} +``` + +## Dataset + +### Train Dataset + +| trainset | instance_num | repeat_num | source | +| :--------: | :----------: | :--------: | :----------------------: | +| SynthText | 7266686 | 1 | synth | +| Syn90k | 8919273 | 1 | synth | + +### Test Dataset + +| testset | instance_num | type | +| :-----: | 
:----------: | :-------------------------: | +| IIIT5K | 3000 | regular | +| SVT | 647 | regular | +| IC13 | 1015 | regular | +| IC15 | 2077 | irregular | +| SVTP | 645 | irregular | +| CT80 | 288 | irregular | + +## Results and Models + +| Methods | Backbone || Regular Text |||| Irregular Text ||download| +| :-------: | :---------: | :----: | :----: | :--: | :-: | :--: | :------: | :--: | :-----: | +| | | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 | +| [NRTR](/configs/textrecog/nrtr/nrtr_r31_academic.py) | R31-1/16-1/8 | 93.9 | 90.0| 93.5 | | 74.5 | 78.5 | 86.5 | [model](https://download.openmmlab.com/mmocr/textrecog/nrtr/nrtr_r31_academic_20210406-954db95e.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/nrtr/20210406_010150.log.json) | + +**Notes:** + +- `R31-1/16-1/8` means the height of feature from backbone is 1/16 of input image, where 1/8 for width. diff --git a/configs/textrecog/nrtr/nrtr_modality_toy.py b/configs/textrecog/nrtr/nrtr_modality_toy.py new file mode 100644 index 00000000..e8201c6f --- /dev/null +++ b/configs/textrecog/nrtr/nrtr_modality_toy.py @@ -0,0 +1,112 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/recog_models/nrtr.py', +] + +# optimizer +optimizer = dict(type='Adam', lr=1e-3) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[3, 4]) +total_epochs = 6 + +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), + dict( + type='ResizeOCR', + height=32, + min_width=32, + max_width=100, + keep_aspect_ratio=False), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio' + ]), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiRotateAugOCR', + rotate_degrees=[0, 90, 270], + transforms=[ + dict( + type='ResizeOCR', + height=32, + min_width=32, + max_width=100, + keep_aspect_ratio=False), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'valid_ratio' + ]), + ]) +] + +dataset_type = 'OCRDataset' +img_prefix = 'tests/data/ocr_toy_dataset/imgs' +train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt' +train1 = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=train_anno_file1, + loader=dict( + type='HardDiskLoader', + repeat=100, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb' +train2 = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=train_anno_file2, + loader=dict( + type='LmdbLoader', + repeat=100, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb' +test = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=test_anno_file1, + loader=dict( + type='LmdbLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=test_pipeline, + test_mode=True) + +data = dict( + samples_per_gpu=16, + workers_per_gpu=2, + 
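# samples_per_gpu is the per-GPU batch size; the effective batch size is
+    # samples_per_gpu * (number of GPUs).
+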
train=dict(type='ConcatDataset', datasets=[train1, train2]), + val=dict(type='ConcatDataset', datasets=[test]), + test=dict(type='ConcatDataset', datasets=[test])) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/nrtr/nrtr_r31_academic.py b/configs/textrecog/nrtr/nrtr_r31_academic.py new file mode 100644 index 00000000..9299583b --- /dev/null +++ b/configs/textrecog/nrtr/nrtr_r31_academic.py @@ -0,0 +1,163 @@ +_base_ = [ + '../../_base_/default_runtime.py', '../../_base_/recog_models/nrtr.py' +] + +label_convertor = dict( + type='AttnConvertor', dict_type='DICT90', with_unknown=True) + +model = dict( + type='NRTR', + backbone=dict( + type='ResNet31OCR', + layers=[1, 2, 5, 3], + channels=[32, 64, 128, 256, 512, 512], + stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)), + last_stage_pool=True), + encoder=dict(type='TFEncoder'), + decoder=dict(type='TFDecoder'), + loss=dict(type='TFLoss'), + label_convertor=label_convertor, + max_seq_len=40) + +# optimizer +optimizer = dict(type='Adam', lr=1e-3) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[3, 4]) +total_epochs = 6 + +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeOCR', + height=32, + min_width=32, + max_width=160, + keep_aspect_ratio=True, + width_downsample_ratio=0.25), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio' + ]), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiRotateAugOCR', + rotate_degrees=[0, 90, 270], + transforms=[ + dict( + type='ResizeOCR', + height=32, + min_width=32, + max_width=160, + keep_aspect_ratio=True, + width_downsample_ratio=0.25), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'valid_ratio' + ]), + ]) +] + +dataset_type = 'OCRDataset' + +train_prefix = 'data/mixture/' + +train_img_prefix1 = train_prefix + \ + 'SynthText/synthtext/SynthText_patch_horizontal' +train_img_prefix2 = train_prefix + 'Syn90k/mnt/ramdisk/max/90kDICT32px' + +train_ann_file1 = train_prefix + 'SynthText/label.lmdb', +train_ann_file2 = train_prefix + 'Syn90k/label.lmdb' + +train1 = dict( + type=dataset_type, + img_prefix=train_img_prefix1, + ann_file=train_ann_file1, + loader=dict( + type='LmdbLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +train2 = {key: value for key, value in train1.items()} +train2['img_prefix'] = train_img_prefix2 +train2['ann_file'] = train_ann_file2 + +test_prefix = 'data/mixture/' +test_img_prefix1 = test_prefix + 'IIIT5K/' +test_img_prefix2 = test_prefix + 'svt/' +test_img_prefix3 = test_prefix + 'icdar_2013/' +test_img_prefix4 = test_prefix + 'icdar_2015/' +test_img_prefix5 = test_prefix + 'svtp/' +test_img_prefix6 = test_prefix + 'ct80/' + +test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt' +test_ann_file2 = test_prefix + 'svt/test_label.txt' +test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt' +test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt' +test_ann_file5 = test_prefix + 'svtp/test_label.txt' +test_ann_file6 = test_prefix + 'ct80/test_label.txt' + +test1 = dict( + type=dataset_type, 
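+    # test1 is a template: test2..test6 below are shallow copies built with
+    # {key: value for key, value in test1.items()}, so nested dicts such as
+    # loader/pipeline remain shared objects; only top-level keys are
+    # overridden.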
+ img_prefix=test_img_prefix1, + ann_file=test_ann_file1, + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=test_pipeline, + test_mode=True) + +test2 = {key: value for key, value in test1.items()} +test2['img_prefix'] = test_img_prefix2 +test2['ann_file'] = test_ann_file2 + +test3 = {key: value for key, value in test1.items()} +test3['img_prefix'] = test_img_prefix3 +test3['ann_file'] = test_ann_file3 + +test4 = {key: value for key, value in test1.items()} +test4['img_prefix'] = test_img_prefix4 +test4['ann_file'] = test_ann_file4 + +test5 = {key: value for key, value in test1.items()} +test5['img_prefix'] = test_img_prefix5 +test5['ann_file'] = test_ann_file5 + +test6 = {key: value for key, value in test1.items()} +test6['img_prefix'] = test_img_prefix6 +test6['ann_file'] = test_ann_file6 + +data = dict( + samples_per_gpu=128, + workers_per_gpu=4, + train=dict(type='ConcatDataset', datasets=[train1, train2]), + val=dict( + type='ConcatDataset', + datasets=[test1, test2, test3, test4, test5, test6]), + test=dict( + type='ConcatDataset', + datasets=[test1, test2, test3, test4, test5, test6])) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/robust_scanner/README.md b/configs/textrecog/robust_scanner/README.md new file mode 100644 index 00000000..01e42971 --- /dev/null +++ b/configs/textrecog/robust_scanner/README.md @@ -0,0 +1,51 @@ +# RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition + +## Introduction + +[ALGORITHM] + +```bibtex +@inproceedings{yue2020robustscanner, + title={RobustScanner: Dynamically Enhancing Positional Clues for Robust Text Recognition}, + author={Yue, Xiaoyu and Kuang, Zhanghui and Lin, Chenhao and Sun, Hongbin and Zhang, Wayne}, + booktitle={European Conference on Computer Vision}, + year={2020} +} +``` + +## Dataset + +### Train Dataset + +| trainset | instance_num | repeat_num | source | +| :--------: | :----------: | :--------: | :----------------------: | +| icdar_2011 | 3567 | 20 | real | +| icdar_2013 | 848 | 20 | real | +| icdar2015 | 4468 | 20 | real | +| coco_text | 42142 | 20 | real | +| IIIT5K | 2000 | 20 | real | +| SynthText | 2400000 | 1 | synth | +| SynthAdd | 1216889 | 1 | synth, 1.6m in [[1]](#1) | +| Syn90k | 2400000 | 1 | synth | + +### Test Dataset + +| testset | instance_num | type | +| :-----: | :----------: | :-------------------------: | +| IIIT5K | 3000 | regular | +| SVT | 647 | regular | +| IC13 | 1015 | regular | +| IC15 | 2077 | irregular | +| SVTP | 645 | irregular, 639 in [[1]](#1) | +| CT80 | 288 | irregular | + +## Results and Models + +| Methods | GPUs | | Regular Text | | | | Irregular Text | | download | +| :-----------------------------------------------------------------: | :---------: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| | | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 | +| [RobustScanner](configs/textrecog/robust_scanner/robustscanner_r31_academic.py) | 16 | 95.1 | 89.2 | 93.1 | | 77.8 | 80.3 | 90.3 | [model](https://download.openmmlab.com/mmocr/textrecog/robustscanner/robustscanner_r31_academic-5f05874f.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/robustscanner/20210401_170932.log.json) | + +## References + 
+[1] Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu. Show, attend and read: A simple and strong baseline for irregular text recognition. In AAAI 2019. diff --git a/configs/textrecog/robust_scanner/robust_scanner_toy_dataset.py b/configs/textrecog/robust_scanner/robust_scanner_toy_dataset.py new file mode 100644 index 00000000..6fc6c125 --- /dev/null +++ b/configs/textrecog/robust_scanner/robust_scanner_toy_dataset.py @@ -0,0 +1,12 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/recog_models/robust_scanner.py', + '../../_base_/recog_datasets/toy_dataset.py' +] + +# optimizer +optimizer = dict(type='Adam', lr=1e-3) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[3, 4]) +total_epochs = 6 diff --git a/configs/textrecog/robust_scanner/robustscanner_r31_academic.py b/configs/textrecog/robust_scanner/robustscanner_r31_academic.py new file mode 100644 index 00000000..e90dd75d --- /dev/null +++ b/configs/textrecog/robust_scanner/robustscanner_r31_academic.py @@ -0,0 +1,197 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/recog_models/robust_scanner.py' +] + +# optimizer +optimizer = dict(type='Adam', lr=1e-3) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[3, 4]) +total_epochs = 5 + +img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeOCR', + height=48, + min_width=48, + max_width=160, + keep_aspect_ratio=True, + width_downsample_ratio=0.25), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio' + ]), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiRotateAugOCR', + rotate_degrees=[0, 90, 270], + transforms=[ + dict( + type='ResizeOCR', + height=48, + min_width=48, + max_width=160, + keep_aspect_ratio=True, + width_downsample_ratio=0.25), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'valid_ratio' + ]), + ]) +] + +dataset_type = 'OCRDataset' + +train_prefix = 'data/mixture/' + +train_img_prefix1 = train_prefix + 'icdar_2011' +train_img_prefix2 = train_prefix + 'icdar_2013' +train_img_prefix3 = train_prefix + 'icdar_2015' +train_img_prefix4 = train_prefix + 'coco_text' +train_img_prefix5 = train_prefix + 'III5K' +train_img_prefix6 = train_prefix + 'SynthText_Add' +train_img_prefix7 = train_prefix + 'SynthText' +train_img_prefix8 = train_prefix + 'Syn90k' + +train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt', +train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt', +train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt', +train_ann_file4 = train_prefix + 'coco_text/train_label.txt', +train_ann_file5 = train_prefix + 'III5K/train_label.txt', +train_ann_file6 = train_prefix + 'SynthText_Add/label.txt', +train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt', +train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt' + +train1 = dict( + type=dataset_type, + img_prefix=train_img_prefix1, + ann_file=train_ann_file1, + loader=dict( + type='HardDiskLoader', + repeat=20, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +train2 = {key: value for key, 
value in train1.items()} +train2['img_prefix'] = train_img_prefix2 +train2['ann_file'] = train_ann_file2 + +train3 = {key: value for key, value in train1.items()} +train3['img_prefix'] = train_img_prefix3 +train3['ann_file'] = train_ann_file3 + +train4 = {key: value for key, value in train1.items()} +train4['img_prefix'] = train_img_prefix4 +train4['ann_file'] = train_ann_file4 + +train5 = {key: value for key, value in train1.items()} +train5['img_prefix'] = train_img_prefix5 +train5['ann_file'] = train_ann_file5 + +train6 = dict( + type=dataset_type, + img_prefix=train_img_prefix6, + ann_file=train_ann_file6, + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +train7 = {key: value for key, value in train6.items()} +train7['img_prefix'] = train_img_prefix7 +train7['ann_file'] = train_ann_file7 + +train8 = {key: value for key, value in train6.items()} +train8['img_prefix'] = train_img_prefix8 +train8['ann_file'] = train_ann_file8 + +test_prefix = 'data/mixture/' +test_img_prefix1 = test_prefix + 'IIIT5K/' +test_img_prefix2 = test_prefix + 'svt/' +test_img_prefix3 = test_prefix + 'icdar_2013/' +test_img_prefix4 = test_prefix + 'icdar_2015/' +test_img_prefix5 = test_prefix + 'svtp/' +test_img_prefix6 = test_prefix + 'ct80/' + +test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt' +test_ann_file2 = test_prefix + 'svt/test_label.txt' +test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt' +test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt' +test_ann_file5 = test_prefix + 'svtp/test_label.txt' +test_ann_file6 = test_prefix + 'ct80/test_label.txt' + +test1 = dict( + type=dataset_type, + img_prefix=test_img_prefix1, + ann_file=test_ann_file1, + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=test_pipeline, + test_mode=True) + +test2 = {key: value for key, value in test1.items()} +test2['img_prefix'] = test_img_prefix2 +test2['ann_file'] = test_ann_file2 + +test3 = {key: value for key, value in test1.items()} +test3['img_prefix'] = test_img_prefix3 +test3['ann_file'] = test_ann_file3 + +test4 = {key: value for key, value in test1.items()} +test4['img_prefix'] = test_img_prefix4 +test4['ann_file'] = test_ann_file4 + +test5 = {key: value for key, value in test1.items()} +test5['img_prefix'] = test_img_prefix5 +test5['ann_file'] = test_ann_file5 + +test6 = {key: value for key, value in test1.items()} +test6['img_prefix'] = test_img_prefix6 +test6['ann_file'] = test_ann_file6 + +data = dict( + samples_per_gpu=64, + workers_per_gpu=2, + train=dict( + type='ConcatDataset', + datasets=[ + train1, train2, train3, train4, train5, train6, train7, train8 + ]), + val=dict( + type='ConcatDataset', + datasets=[test1, test2, test3, test4, test5, test6]), + test=dict( + type='ConcatDataset', + datasets=[test1, test2, test3, test4, test5, test6])) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/sar/README.md b/configs/textrecog/sar/README.md new file mode 100644 index 00000000..8854ae04 --- /dev/null +++ b/configs/textrecog/sar/README.md @@ -0,0 +1,67 @@ +# Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition + +## Introduction + +[ALGORITHM] + +```bibtex +@inproceedings{li2019show, + title={Show, attend and read: A simple and strong baseline for irregular text 
recognition},
+  author={Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu},
+  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+  volume={33},
+  number={01},
+  pages={8610--8617},
+  year={2019}
+}
+```
+
+## Dataset
+
+### Train Dataset
+
+| trainset | instance_num | repeat_num | source |
+| :--------: | :----------: | :--------: | :----------------------: |
+| icdar_2011 | 3567 | 20 | real |
+| icdar_2013 | 848 | 20 | real |
+| icdar2015 | 4468 | 20 | real |
+| coco_text | 42142 | 20 | real |
+| IIIT5K | 2000 | 20 | real |
+| SynthText | 2400000 | 1 | synth |
+| SynthAdd | 1216889 | 1 | synth, 1.6m in [[1]](#1) |
+| Syn90k | 2400000 | 1 | synth |
+
+### Test Dataset
+
+| testset | instance_num | type |
+| :-----: | :----------: | :-------------------------: |
+| IIIT5K | 3000 | regular |
+| SVT | 647 | regular |
+| IC13 | 1015 | regular |
+| IC15 | 2077 | irregular |
+| SVTP | 645 | irregular, 639 in [[1]](#1) |
+| CT80 | 288 | irregular |
+
+## Results and Models
+
+| Methods | Backbone | Decoder | IIIT5K | SVT | IC13 | IC15 | SVTP | CT80 | download |
+| :-----: | :------: | :-----: | :----: | :-: | :--: | :--: | :--: | :--: | :------: |
+| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 95.0 | 89.6 | 93.7 | 79.0 | 82.2 | 88.9 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210327_154129.log.json) |
+| [SAR](/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 95.2 | 88.7 | 92.4 | 78.2 | 81.9 | 89.6 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_sequential_decoder_academic-d06c9a8e.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210330_105728.log.json) |
+
+**Notes:**
+
+- `R31-1/8-1/4` means the height of the feature map from the backbone is 1/8 of the input image, while the width is 1/4.
+- We did not use beam search during decoding.
+- We implemented two kinds of decoder, namely `ParallelSARDecoder` and `SequentialSARDecoder`.
+  - `ParallelSARDecoder`: parallel decoding during training with an `LSTM` layer. It is faster.
+  - `SequentialSARDecoder`: sequential decoding during training with `LSTMCell`. It is easier to understand.
+- For the training dataset:
+  - We did not construct the distinct data groups (20 groups in [[1]](#1)) and train the model group by group, since that would make training overly complicated.
+  - Instead, we randomly selected `2.4m` patches from `Syn90k`, `2.4m` from `SynthText` and `1.2m` from `SynthAdd`, and grouped all the data together. See [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_academic.py) for details.
+- We used 48 GPUs with `total_batch_size = 64 * 48` in the experiment above to speed up training, while keeping the `initial lr = 1e-3` unchanged.
+
+## References
+
+[1] Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu. Show, attend and read: A simple and strong baseline for irregular text recognition. In AAAI 2019.
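+
+**Usage (illustrative):** a minimal inference sketch in the style of `demo/image_demo.py` from this patch; the config and checkpoint paths are assumptions to substitute with your own.
+
+```python
+from mmdet.apis import init_detector
+
+from mmocr.apis import model_inference
+
+cfg = 'configs/textrecog/sar/sar_r31_parallel_decoder_academic.py'  # assumed path
+ckpt = 'sar_r31_parallel_decoder_academic-dba3a4a3.pth'  # assumed local checkpoint
+model = init_detector(cfg, ckpt, device='cuda:0')
+
+# As in demo/image_demo.py: ConcatDataset test configs need a flat pipeline.
+if model.cfg.data.test['type'] == 'ConcatDataset':
+    model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][0].pipeline
+
+result = model_inference(model, 'demo/demo_text_recog.jpg')
+print(result)  # recognition output, e.g. predicted text and score
+```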
diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py new file mode 100644 index 00000000..4e405227 --- /dev/null +++ b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py @@ -0,0 +1,219 @@ +_base_ = ['../../_base_/default_runtime.py'] + +label_convertor = dict( + type='AttnConvertor', dict_type='DICT90', with_unknown=True) + +model = dict( + type='SARNet', + backbone=dict(type='ResNet31OCR'), + encoder=dict( + type='SAREncoder', + enc_bi_rnn=False, + enc_do_rnn=0.1, + enc_gru=False, + ), + decoder=dict( + type='ParallelSARDecoder', + enc_bi_rnn=False, + dec_bi_rnn=False, + dec_do_rnn=0, + dec_gru=False, + pred_dropout=0.1, + d_k=512, + pred_concat=True), + loss=dict(type='SARLoss'), + label_convertor=label_convertor, + max_seq_len=30) + +# optimizer +optimizer = dict(type='Adam', lr=1e-3) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[3, 4]) +total_epochs = 5 + +img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeOCR', + height=48, + min_width=48, + max_width=160, + keep_aspect_ratio=True, + width_downsample_ratio=0.25), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio' + ]), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiRotateAugOCR', + rotate_degrees=[0, 90, 270], + transforms=[ + dict( + type='ResizeOCR', + height=48, + min_width=48, + max_width=160, + keep_aspect_ratio=True, + width_downsample_ratio=0.25), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'valid_ratio' + ]), + ]) +] + +dataset_type = 'OCRDataset' + +train_prefix = 'data/mixture/' + +train_img_prefix1 = train_prefix + 'icdar_2011' +train_img_prefix2 = train_prefix + 'icdar_2013' +train_img_prefix3 = train_prefix + 'icdar_2015' +train_img_prefix4 = train_prefix + 'coco_text' +train_img_prefix5 = train_prefix + 'III5K' +train_img_prefix6 = train_prefix + 'SynthText_Add' +train_img_prefix7 = train_prefix + 'SynthText' +train_img_prefix8 = train_prefix + 'Syn90k' + +train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt', +train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt', +train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt', +train_ann_file4 = train_prefix + 'coco_text/train_label.txt', +train_ann_file5 = train_prefix + 'III5K/train_label.txt', +train_ann_file6 = train_prefix + 'SynthText_Add/label.txt', +train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt', +train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt' + +train1 = dict( + type=dataset_type, + img_prefix=train_img_prefix1, + ann_file=train_ann_file1, + loader=dict( + type='HardDiskLoader', + repeat=20, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +train2 = {key: value for key, value in train1.items()} +train2['img_prefix'] = train_img_prefix2 +train2['ann_file'] = train_ann_file2 + +train3 = {key: value for key, value in train1.items()} +train3['img_prefix'] = train_img_prefix3 +train3['ann_file'] = train_ann_file3 + +train4 = {key: value for key, value in train1.items()} +train4['img_prefix'] = 
train_img_prefix4 +train4['ann_file'] = train_ann_file4 + +train5 = {key: value for key, value in train1.items()} +train5['img_prefix'] = train_img_prefix5 +train5['ann_file'] = train_ann_file5 + +train6 = dict( + type=dataset_type, + img_prefix=train_img_prefix6, + ann_file=train_ann_file6, + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +train7 = {key: value for key, value in train6.items()} +train7['img_prefix'] = train_img_prefix7 +train7['ann_file'] = train_ann_file7 + +train8 = {key: value for key, value in train6.items()} +train8['img_prefix'] = train_img_prefix8 +train8['ann_file'] = train_ann_file8 + +test_prefix = 'data/mixture/' +test_img_prefix1 = test_prefix + 'IIIT5K/' +test_img_prefix2 = test_prefix + 'svt/' +test_img_prefix3 = test_prefix + 'icdar_2013/' +test_img_prefix4 = test_prefix + 'icdar_2015/' +test_img_prefix5 = test_prefix + 'svtp/' +test_img_prefix6 = test_prefix + 'ct80/' + +test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt' +test_ann_file2 = test_prefix + 'svt/test_label.txt' +test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt' +test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt' +test_ann_file5 = test_prefix + 'svtp/test_label.txt' +test_ann_file6 = test_prefix + 'ct80/test_label.txt' + +test1 = dict( + type=dataset_type, + img_prefix=test_img_prefix1, + ann_file=test_ann_file1, + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=test_pipeline, + test_mode=True) + +test2 = {key: value for key, value in test1.items()} +test2['img_prefix'] = test_img_prefix2 +test2['ann_file'] = test_ann_file2 + +test3 = {key: value for key, value in test1.items()} +test3['img_prefix'] = test_img_prefix3 +test3['ann_file'] = test_ann_file3 + +test4 = {key: value for key, value in test1.items()} +test4['img_prefix'] = test_img_prefix4 +test4['ann_file'] = test_ann_file4 + +test5 = {key: value for key, value in test1.items()} +test5['img_prefix'] = test_img_prefix5 +test5['ann_file'] = test_ann_file5 + +test6 = {key: value for key, value in test1.items()} +test6['img_prefix'] = test_img_prefix6 +test6['ann_file'] = test_ann_file6 + +data = dict( + samples_per_gpu=64, + workers_per_gpu=2, + train=dict( + type='ConcatDataset', + datasets=[ + train1, train2, train3, train4, train5, train6, train7, train8 + ]), + val=dict( + type='ConcatDataset', + datasets=[test1, test2, test3, test4, test5, test6]), + test=dict( + type='ConcatDataset', + datasets=[test1, test2, test3, test4, test5, test6])) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py b/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py new file mode 100755 index 00000000..0c8b53e2 --- /dev/null +++ b/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py @@ -0,0 +1,110 @@ +_base_ = [ + '../../_base_/default_runtime.py', '../../_base_/recog_models/sar.py' +] + +# optimizer +optimizer = dict(type='Adam', lr=1e-3) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[3, 4]) +total_epochs = 5 + +img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeOCR', + height=48, + min_width=48, + max_width=160, + 
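# With keep_aspect_ratio=True (below), the width presumably scales with
+        # the input aspect ratio within the [min_width, max_width] bounds.
+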
keep_aspect_ratio=True), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio' + ]), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiRotateAugOCR', + rotate_degrees=[0, 90, 270], + transforms=[ + dict( + type='ResizeOCR', + height=48, + min_width=48, + max_width=160, + keep_aspect_ratio=True), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'valid_ratio' + ]), + ]) +] + +dataset_type = 'OCRDataset' +img_prefix = 'tests/data/ocr_toy_dataset/imgs' +train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt' +train1 = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=train_anno_file1, + loader=dict( + type='HardDiskLoader', + repeat=100, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb' +train2 = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=train_anno_file2, + loader=dict( + type='LmdbLoader', + repeat=100, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb' +test = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=test_anno_file1, + loader=dict( + type='LmdbLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=test_pipeline, + test_mode=True) + +data = dict( + samples_per_gpu=16, + workers_per_gpu=2, + train=dict(type='ConcatDataset', datasets=[train1, train2]), + val=dict(type='ConcatDataset', datasets=[test]), + test=dict(type='ConcatDataset', datasets=[test])) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py new file mode 100644 index 00000000..6fa00dd7 --- /dev/null +++ b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py @@ -0,0 +1,219 @@ +_base_ = ['../../_base_/default_runtime.py'] + +label_convertor = dict( + type='AttnConvertor', dict_type='DICT90', with_unknown=True) + +model = dict( + type='SARNet', + backbone=dict(type='ResNet31OCR'), + encoder=dict( + type='SAREncoder', + enc_bi_rnn=False, + enc_do_rnn=0.1, + enc_gru=False, + ), + decoder=dict( + type='SequentialSARDecoder', + enc_bi_rnn=False, + dec_bi_rnn=False, + dec_do_rnn=0, + dec_gru=False, + pred_dropout=0.1, + d_k=512, + pred_concat=True), + loss=dict(type='SARLoss'), + label_convertor=label_convertor, + max_seq_len=30) + +# optimizer +optimizer = dict(type='Adam', lr=1e-3) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[3, 4]) +total_epochs = 5 + +img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeOCR', + height=48, + min_width=48, + max_width=160, + keep_aspect_ratio=True, + width_downsample_ratio=0.25), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio' + ]), +] +test_pipeline = 
[ + dict(type='LoadImageFromFile'), + dict( + type='MultiRotateAugOCR', + rotate_degrees=[0, 90, 270], + transforms=[ + dict( + type='ResizeOCR', + height=48, + min_width=48, + max_width=160, + keep_aspect_ratio=True, + width_downsample_ratio=0.25), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'valid_ratio' + ]), + ]) +] + +dataset_type = 'OCRDataset' + +train_prefix = 'data/mixture/' + +train_img_prefix1 = train_prefix + 'icdar_2011' +train_img_prefix2 = train_prefix + 'icdar_2013' +train_img_prefix3 = train_prefix + 'icdar_2015' +train_img_prefix4 = train_prefix + 'coco_text' +train_img_prefix5 = train_prefix + 'III5K' +train_img_prefix6 = train_prefix + 'SynthText_Add' +train_img_prefix7 = train_prefix + 'SynthText' +train_img_prefix8 = train_prefix + 'Syn90k' + +train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt', +train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt', +train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt', +train_ann_file4 = train_prefix + 'coco_text/train_label.txt', +train_ann_file5 = train_prefix + 'III5K/train_label.txt', +train_ann_file6 = train_prefix + 'SynthText_Add/label.txt', +train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt', +train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt' + +train1 = dict( + type=dataset_type, + img_prefix=train_img_prefix1, + ann_file=train_ann_file1, + loader=dict( + type='HardDiskLoader', + repeat=20, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +train2 = {key: value for key, value in train1.items()} +train2['img_prefix'] = train_img_prefix2 +train2['ann_file'] = train_ann_file2 + +train3 = {key: value for key, value in train1.items()} +train3['img_prefix'] = train_img_prefix3 +train3['ann_file'] = train_ann_file3 + +train4 = {key: value for key, value in train1.items()} +train4['img_prefix'] = train_img_prefix4 +train4['ann_file'] = train_ann_file4 + +train5 = {key: value for key, value in train1.items()} +train5['img_prefix'] = train_img_prefix5 +train5['ann_file'] = train_ann_file5 + +train6 = dict( + type=dataset_type, + img_prefix=train_img_prefix6, + ann_file=train_ann_file6, + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) + +train7 = {key: value for key, value in train6.items()} +train7['img_prefix'] = train_img_prefix7 +train7['ann_file'] = train_ann_file7 + +train8 = {key: value for key, value in train6.items()} +train8['img_prefix'] = train_img_prefix8 +train8['ann_file'] = train_ann_file8 + +test_prefix = 'data/mixture/' +test_img_prefix1 = test_prefix + 'IIIT5K/' +test_img_prefix2 = test_prefix + 'svt/' +test_img_prefix3 = test_prefix + 'icdar_2013/' +test_img_prefix4 = test_prefix + 'icdar_2015/' +test_img_prefix5 = test_prefix + 'svtp/' +test_img_prefix6 = test_prefix + 'ct80/' + +test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt' +test_ann_file2 = test_prefix + 'svt/test_label.txt' +test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt' +test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt' +test_ann_file5 = test_prefix + 'svtp/test_label.txt' +test_ann_file6 = test_prefix + 'ct80/test_label.txt' + +test1 = dict( + type=dataset_type, + img_prefix=test_img_prefix1, + 
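# LineStrParser below splits each label line on ' ' and maps the fields
+    # selected by keys_idx to 'filename' and 'text'.
+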
ann_file=test_ann_file1, + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=test_pipeline, + test_mode=True) + +test2 = {key: value for key, value in test1.items()} +test2['img_prefix'] = test_img_prefix2 +test2['ann_file'] = test_ann_file2 + +test3 = {key: value for key, value in test1.items()} +test3['img_prefix'] = test_img_prefix3 +test3['ann_file'] = test_ann_file3 + +test4 = {key: value for key, value in test1.items()} +test4['img_prefix'] = test_img_prefix4 +test4['ann_file'] = test_ann_file4 + +test5 = {key: value for key, value in test1.items()} +test5['img_prefix'] = test_img_prefix5 +test5['ann_file'] = test_ann_file5 + +test6 = {key: value for key, value in test1.items()} +test6['img_prefix'] = test_img_prefix6 +test6['ann_file'] = test_ann_file6 + +data = dict( + samples_per_gpu=64, + workers_per_gpu=2, + train=dict( + type='ConcatDataset', + datasets=[ + train1, train2, train3, train4, train5, train6, train7, train8 + ]), + val=dict( + type='ConcatDataset', + datasets=[test1, test2, test3, test4, test5, test6]), + test=dict( + type='ConcatDataset', + datasets=[test1, test2, test3, test4, test5, test6])) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/seg/README.md b/configs/textrecog/seg/README.md new file mode 100644 index 00000000..28a96425 --- /dev/null +++ b/configs/textrecog/seg/README.md @@ -0,0 +1,43 @@ +# SegOCR Simple Baseline. + +## Introduction + +[ALGORITHM] + +```bibtex +@unpublished{key, + title={SegOCR Simple Baseline.}, + author={}, + note={Unpublished Manuscript}, + year={2021} +} +``` + +## Dataset + +### Train Dataset + +| trainset | instance_num | repeat_num | source | +| :-------: | :----------: | :--------: | :----: | +| SynthText | 7266686 | 1 | synth | + +### Test Dataset + +| testset | instance_num | type | +| :-----: | :----------: | :-------: | +| IIIT5K | 3000 | regular | +| SVT | 647 | regular | +| IC13 | 1015 | regular | +| CT80 | 288 | irregular | + +## Results and Models + +|Backbone|Neck|Head|||Regular Text|||Irregular Text|download +| :-------------: | :-----: | :-----: | :------: | :-----: | :----: | :-----: | :-----: | :-----: | :-----: | +|||||IIIT5K|SVT|IC13||CT80| +|R31-1/16|FPNOCR|1x||90.9|81.8|90.7||80.9|[model](https://download.openmmlab.com/mmocr/textrecog/seg/seg_r31_1by16_fpnocr_academic-72235b11.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/seg/20210325_112835.log.json) | + +**Notes:** + +- `R31-1/16` means the size (both height and width ) of feature from backbone is 1/16 of input image. +- `1x` means the size (both height and width) of feature from head is the same with input image. 
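+
+**Shape arithmetic (illustrative):** a small sketch of what the notes above imply for feature sizes; the input size is an assumption chosen for illustration.
+
+```python
+h, w = 64, 256  # assumed input crop size (height x width)
+
+# `R31-1/16`: the backbone output is 1/16 of the input in both dimensions.
+backbone_stride = 16
+feat_h, feat_w = h // backbone_stride, w // backbone_stride
+print(feat_h, feat_w)  # 4 16
+
+# `1x`: the head output is restored to the input resolution.
+out_h, out_w = feat_h * backbone_stride, feat_w * backbone_stride
+print(out_h, out_w)  # 64 256
+```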
diff --git a/configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py b/configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py new file mode 100644 index 00000000..8a568f8e --- /dev/null +++ b/configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py @@ -0,0 +1,160 @@ +_base_ = ['../../_base_/default_runtime.py'] + +# optimizer +optimizer = dict(type='Adam', lr=1e-4) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[3, 4]) +total_epochs = 5 + +label_convertor = dict( + type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True) + +model = dict( + type='SegRecognizer', + backbone=dict( + type='ResNet31OCR', + layers=[1, 2, 5, 3], + channels=[32, 64, 128, 256, 512, 512], + out_indices=[0, 1, 2, 3], + stage4_pool_cfg=dict(kernel_size=2, stride=2), + last_stage_pool=True), + neck=dict( + type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256), + head=dict( + type='SegHead', + in_channels=256, + upsample_param=dict(scale_factor=2.0, mode='nearest')), + loss=dict( + type='SegLoss', seg_downsample_ratio=1.0, seg_with_loss_weight=True), + label_convertor=label_convertor) + +find_unused_parameters = True + +img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + +gt_label_convertor = dict( + type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomPaddingOCR', + max_ratio=[0.15, 0.2, 0.15, 0.2], + box_type='char_quads'), + dict(type='OpencvToPil'), + dict( + type='RandomRotateImageBox', + min_angle=-17, + max_angle=17, + box_type='char_quads'), + dict(type='PilToOpencv'), + dict( + type='ResizeOCR', + height=64, + min_width=64, + max_width=512, + keep_aspect_ratio=True), + dict( + type='OCRSegTargets', + label_convertor=gt_label_convertor, + box_type='char_quads'), + dict(type='RandomRotateTextDet', rotate_ratio=0.5, max_angle=15), + dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), + dict(type='ToTensorOCR'), + dict(type='FancyPCA'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='CustomFormatBundle', + keys=['gt_kernels'], + visualize=dict(flag=False, boundary_key=None), + call_super=False), + dict( + type='Collect', + keys=['img', 'gt_kernels'], + meta_keys=['filename', 'ori_shape', 'img_shape']) +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeOCR', + height=64, + min_width=64, + max_width=None, + keep_aspect_ratio=True), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict(type='CustomFormatBundle', call_super=False), + dict( + type='Collect', + keys=['img'], + meta_keys=['filename', 'ori_shape', 'img_shape']) +] + +train_img_root = 'data/mixture/' + +train_img_prefix = train_img_root + 'SynthText' + +train_ann_file = train_img_root + 'SynthText/instances_train.txt' + +train = dict( + type='OCRSegDataset', + img_prefix=train_img_prefix, + ann_file=train_ann_file, + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineJsonParser', keys=['file_name', 'annotations', 'text'])), + pipeline=train_pipeline, + test_mode=False) + +dataset_type = 'OCRDataset' +test_prefix = 'data/mixture/' + +test_img_prefix1 = test_prefix + 'IIIT5K/' +test_img_prefix2 = test_prefix + 'svt/' +test_img_prefix3 = test_prefix + 'icdar_2013/' +test_img_prefix4 = test_prefix + 'ct80/' + +test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt' +test_ann_file2 = test_prefix + 'svt/test_label.txt' +test_ann_file3 = test_prefix + 
'icdar_2013/test_label_1015.txt' +test_ann_file4 = test_prefix + 'ct80/test_label.txt' + +test1 = dict( + type=dataset_type, + img_prefix=test_img_prefix1, + ann_file=test_ann_file1, + loader=dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=test_pipeline, + test_mode=True) + +test2 = {key: value for key, value in test1.items()} +test2['img_prefix'] = test_img_prefix2 +test2['ann_file'] = test_ann_file2 + +test3 = {key: value for key, value in test1.items()} +test3['img_prefix'] = test_img_prefix3 +test3['ann_file'] = test_ann_file3 + +test4 = {key: value for key, value in test1.items()} +test4['img_prefix'] = test_img_prefix4 +test4['ann_file'] = test_ann_file4 + +data = dict( + samples_per_gpu=16, + workers_per_gpu=2, + train=dict(type='ConcatDataset', datasets=[train]), + val=dict(type='ConcatDataset', datasets=[test1, test2, test3, test4]), + test=dict(type='ConcatDataset', datasets=[test1, test2, test3, test4])) + +evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py b/configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py new file mode 100644 index 00000000..63b3d08c --- /dev/null +++ b/configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py @@ -0,0 +1,35 @@ +_base_ = [ + '../../_base_/default_runtime.py', + '../../_base_/recog_datasets/seg_toy_dataset.py' +] + +# optimizer +optimizer = dict(type='Adam', lr=1e-4) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[3, 4]) +total_epochs = 5 + +label_convertor = dict( + type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True) + +model = dict( + type='SegRecognizer', + backbone=dict( + type='ResNet31OCR', + layers=[1, 2, 5, 3], + channels=[32, 64, 128, 256, 512, 512], + out_indices=[0, 1, 2, 3], + stage4_pool_cfg=dict(kernel_size=2, stride=2), + last_stage_pool=True), + neck=dict( + type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256), + head=dict( + type='SegHead', + in_channels=256, + upsample_param=dict(scale_factor=2.0, mode='nearest')), + loss=dict( + type='SegLoss', seg_downsample_ratio=1.0, seg_with_loss_weight=False), + label_convertor=label_convertor) + +find_unused_parameters = True diff --git a/demo/demo_text_det.jpg b/demo/demo_text_det.jpg new file mode 100644 index 00000000..d23de3cd Binary files /dev/null and b/demo/demo_text_det.jpg differ diff --git a/demo/demo_text_recog.jpg b/demo/demo_text_recog.jpg new file mode 100644 index 00000000..d9915983 Binary files /dev/null and b/demo/demo_text_recog.jpg differ diff --git a/demo/image_demo.py b/demo/image_demo.py new file mode 100644 index 00000000..dbccf784 --- /dev/null +++ b/demo/image_demo.py @@ -0,0 +1,44 @@ +from argparse import ArgumentParser + +import mmcv + +from mmdet.apis import init_detector +from mmocr.apis.inference import model_inference +from mmocr.datasets import build_dataset # noqa: F401 +from mmocr.models import build_detector # noqa: F401 + + +def main(): + parser = ArgumentParser() + parser.add_argument('img', help='Image file.') + parser.add_argument('config', help='Config file.') + parser.add_argument('checkpoint', help='Checkpoint file.') + parser.add_argument('save_path', help='Path to save visualized image.') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference.') + parser.add_argument( + '--imshow', + action='store_true', + help='Whether show image with OpenCV.') + 
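# Example invocation (config/checkpoint paths are illustrative):
+    #   python demo/image_demo.py demo/demo_text_det.jpg <config> <ckpt> \
+    #       out_vis.jpg --imshow
+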
args = parser.parse_args() + + # build the model from a config file and a checkpoint file + model = init_detector(args.config, args.checkpoint, device=args.device) + if model.cfg.data.test['type'] == 'ConcatDataset': + model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][ + 0].pipeline + + # test a single image + result = model_inference(model, args.img) + print(f'result: {result}') + + # show the results + img = model.show_result(args.img, result, out_file=None, show=False) + + mmcv.imwrite(img, args.save_path) + if args.imshow: + mmcv.imshow(img, 'predicted results') + + +if __name__ == '__main__': + main() diff --git a/demo/webcam_demo.py b/demo/webcam_demo.py new file mode 100644 index 00000000..45e23c5b --- /dev/null +++ b/demo/webcam_demo.py @@ -0,0 +1,52 @@ +import argparse + +import cv2 +import torch + +from mmdet.apis import init_detector +from mmocr.apis import model_inference +from mmocr.datasets import build_dataset # noqa: F401 +from mmocr.models import build_detector # noqa: F401 + + +def parse_args(): + parser = argparse.ArgumentParser(description='MMDetection webcam demo.') + parser.add_argument('config', help='Test config file path.') + parser.add_argument('checkpoint', help='Checkpoint file.') + parser.add_argument( + '--device', type=str, default='cuda:0', help='CPU/CUDA device option.') + parser.add_argument( + '--camera-id', type=int, default=0, help='Camera device id.') + parser.add_argument( + '--score-thr', type=float, default=0.5, help='Bbox score threshold.') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + device = torch.device(args.device) + + model = init_detector(args.config, args.checkpoint, device=device) + if model.cfg.data.test['type'] == 'ConcatDataset': + model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][ + 0].pipeline + + camera = cv2.VideoCapture(args.camera_id) + + print('Press "Esc", "q" or "Q" to exit.') + while True: + ret_val, img = camera.read() + result = model_inference(model, img) + + ch = cv2.waitKey(1) + if ch == 27 or ch == ord('q') or ch == ord('Q'): + break + + model.show_result( + img, result, score_thr=args.score_thr, wait_time=1, show=True) + + +if __name__ == '__main__': + main() diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 00000000..afd8fe55 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,28 @@ +ARG PYTORCH="1.5" +ARG CUDA="10.1" +ARG CUDNN="7" + +FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel + +ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX" +ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all" +ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" + +RUN apt-get update && apt-get install -y git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +RUN conda clean --all +RUN pip install mmcv-full==1.2.6+torch1.5.0+cu101 -f https://download.openmmlab.com/mmcv/dist/index.html + +RUN git clone https://github.com/open-mmlab/mmdetection.git /mmdet +WORKDIR /mmdet +RUN git checkout -b v2.9.0 v2.9.0 +RUN pip install -r requirements.txt +RUN pip install . + +RUN git clone https://github.com/open-mmlab/mmocr.git /mmocr +WORKDIR /mmocr +ENV FORCE_CUDA="1" +RUN pip install -r requirements.txt +RUN pip install --no-cache-dir -e . 
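+
+# Illustrative build/run commands (the image tag is an assumption):
+#   docker build -t mmocr docker/
+#   docker run --gpus all -it mmocr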
diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..d4bb2cbb --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 00000000..a23ab961 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,15 @@ +API Reference +============= + +mmocr.apis +------------- +.. automodule:: mmocr.apis + :members: + +mmocr.core +------------- + +evaluation +^^^^^^^^^^ +.. automodule:: mmocr.core.evaluation + :members: diff --git a/docs/changelog.md b/docs/changelog.md new file mode 100644 index 00000000..8a802039 --- /dev/null +++ b/docs/changelog.md @@ -0,0 +1 @@ +## Changelog diff --git a/docs/code_of_conduct.md b/docs/code_of_conduct.md new file mode 100644 index 00000000..89039925 --- /dev/null +++ b/docs/code_of_conduct.md @@ -0,0 +1,93 @@ + +# Contributor Covenant Code of Conduct + + +- [Contributor Covenant Code of Conduct](#contributor-covenant-code-of-conduct) + - [Our Pledge](#our-pledge) + - [Our Standards](#our-standards) + - [Our Responsibilities](#our-responsibilities) + - [Scope](#scope) + - [Enforcement](#enforcement) + - [Attribution](#attribution) + + + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. 
+ +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at chenkaidev@gmail.com. All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 00000000..e709591e --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,83 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. + +import os +import subprocess +import sys + +sys.path.insert(0, os.path.abspath('..')) + +# -- Project information ----------------------------------------------------- + +project = 'MMOCR' +copyright = '2020-2030, OpenMMLab' +author = 'OpenMMLab' + +# The full version, including alpha/beta/rc tags +release = '0.1.0' + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'recommonmark', + 'sphinx_markdown_tables', +] + +autodoc_mock_imports = ['torch', 'torchvision', 'mmcv', 'mmocr.version'] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. 
+# You can specify multiple suffixes as a list of strings: +# +source_suffix = { + '.rst': 'restructuredtext', + '.md': 'markdown', +} + +# The master toctree document. +master_doc = 'index' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = [] + + +def builder_inited_handler(app): + subprocess.run(['./merge_docs.sh']) + subprocess.run(['./stats.py']) + + +def setup(app): + app.connect('builder-inited', builder_inited_handler) diff --git a/docs/contributing.md b/docs/contributing.md new file mode 100644 index 00000000..2ac726b1 --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1,187 @@ + +# Contributing to MMOCR + +All kinds of contributions are welcome, including but not limited to the following. + +- Fixes (typos, bugs) +- New features and components +- Enhancements such as function speedup + + +- [Contributing to MMOCR](#contributing-to-mmocr) + - [Workflow](#workflow) + - [Step 1: Create a Fork](#step-1-create-a-fork) + - [Step 2: Develop a new feature](#step-2-develop-a-new-feature) + - [Step 2.1: Keep your fork up to date](#step-21-keep-your-fork-up-to-date) + - [Step 2.2: Create a feature branch](#step-22-create-a-feature-branch) + - [Create an issue on GitHub](#create-an-issue-on-github) + - [Create branch](#create-branch) + - [Step 2.3: Develop and test `your_new_feature`](#step-23-develop-and-test-your_new_feature) + - [Step 2.4: Prepare to Pull Request](#step-24-prepare-to-pull-request) + - [Merge official repo updates to your fork](#merge-official-repo-updates-to-your-fork) + - [Push `your_new_feature` branch to your remote forked repo](#push-your_new_feature-branch-to-your-remote-forked-repo) + - [Step 2.5: Create a Pull Request](#step-25-create-a-pull-request) + - [Step 2.6: Review code](#step-26-review-code) + - [Step 2.7: Revise `your_new_feature` (optional)](#step-27-revise-your_new_feature--optional) + - [Step 2.8: Delete `your_new_feature` branch if your PR is accepted](#step-28-delete-your_new_feature-branch-if-your-pr-is-accepted) + - [Code style](#code-style) + - [Python](#python) + - [C++ and CUDA](#c-and-cuda) + + + +## Workflow + +This document describes the fork & merge request workflow that should be used when contributing to **MMOCR**. + +The official public [repository](https://github.com/open-mmlab/mmocr) holds only two branches with an infinite lifetime: ++ master ++ develop + +The *master* branch is the main branch where the source code of **HEAD** always reflects a *production-ready state*. + +The *develop* branch is the branch where the source code of **HEAD** always reflects a state with the latest development changes for the next release. + +Feature branches are used to develop new features for the upcoming or a distant future release. + +![](res/git-workflow-master-develop.png) + +All new developers to **MMOCR** should follow these steps: + + +### Step 1: Create a Fork + +1. 
Fork the repo on GitHub or GitLab to your personal account. Click the `Fork` button on the [project page](https://github.com/open-mmlab/mmocr). + +2. Clone your new forked repo to your computer. +``` +git clone https://github.com/<your_username>/mmocr.git +``` +3. Add the official repo as an upstream: +``` +git remote add upstream https://github.com/open-mmlab/mmocr.git +``` + + +### Step 2: Develop a new feature + + +#### Step 2.1: Keep your fork up to date + +Whenever you want to update your fork with the latest upstream changes, you need to fetch the upstream repo's branches and latest commits to bring them into your repository: + +``` +# Fetch from upstream remote +git fetch upstream + +# Update your master branch +git checkout master +git rebase upstream/master +git push origin master + +# Update your develop branch +git checkout develop +git rebase upstream/develop +git push origin develop +``` + + +#### Step 2.2: Create a feature branch + +##### Create an issue on [GitHub](https://github.com/open-mmlab/mmocr) +- The title of the issue should be one of the following formats: `[Feature]: xxx`, `[Fix]: xxx`, `[Enhance]: xxx`, `[Refactor]: xxx`. +- More details can be written in comments. + + +##### Create branch +``` +git checkout -b feature/iss_<index> develop +# <index> is the issue number above +``` +By now, your fork has three branches as follows: + +![](res/git-workflow-feature.png) + + +#### Step 2.3: Develop and test `your_new_feature` + +Develop your new feature and test it to make sure it works well. + +Please run +``` +pre-commit run --all-files +pytest tests +``` +and fix all failures before every git commit. +``` +git commit -m "fix #<issue_number>: <commit_message>" +``` +**Note:** +- `<issue_number>` is the number of the [issue](#step-22-create-a-feature-branch) above. + + +#### Step 2.4: Prepare to Pull Request +- Make sure to link your pull request to the related issue. Please refer to the [instruction](https://docs.github.com/en/github/managing-your-work-on-github/linking-a-pull-request-to-an-issue) + + + +##### Merge official repo updates to your fork + +``` +# fetch from upstream remote. i.e., the official repo +git fetch upstream + +# update the develop branch of your fork +git checkout develop +git rebase upstream/develop +git push origin develop + +# update the <your_new_feature> branch +git checkout <your_new_feature> +git rebase develop +# resolve conflicts if any, and test +``` + + +##### Push `your_new_feature` branch to your remote forked repo +``` +git checkout <your_new_feature> +git push origin <your_new_feature> +``` + +#### Step 2.5: Create a Pull Request + +Go to the page for your fork on GitHub, select your new feature branch, and click the pull request button to integrate your feature branch into the upstream remote’s develop branch. + + +#### Step 2.6: Review code + + + +#### Step 2.7: Revise `your_new_feature` (optional) +If the PR is not accepted, please follow Steps 2.1, 2.3, 2.4 and 2.5 until your PR is accepted. + + +#### Step 2.8: Delete `your_new_feature` branch if your PR is accepted +``` +git branch -d <your_new_feature> +git push origin :<your_new_feature> +``` + + +## Code style + + +### Python +We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style. + +We use the following tools for linting and formatting: +- [flake8](http://flake8.pycqa.org/en/latest/): linter +- [yapf](https://github.com/google/yapf): formatter +- [isort](https://github.com/timothycrosley/isort): sort imports + +>Before you create a PR, make sure that your code lints and is formatted by yapf. + + +### C++ and CUDA +We follow the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). 
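+ +The pre-commit hook runs the Python tools above for you, but they can also be invoked manually; a rough sketch (exact flags may vary with the tool versions pinned by pre-commit): +``` +flake8 mmocr tests +isort mmocr tests +yapf -r -i mmocr tests +``` 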
diff --git a/docs/datasets.md b/docs/datasets.md new file mode 100644 index 00000000..da9e06ba --- /dev/null +++ b/docs/datasets.md @@ -0,0 +1,208 @@ + +# Datasets Preparation +This page lists the datasets which are commonly used in text detection, text recognition and key information extraction, together with their download links. + + +- [Datasets Preparation](#datasets-preparation) + - [Text Detection](#text-detection) + - [Text Recognition](#text-recognition) + - [Key Information Extraction](#key-information-extraction) + + + +## Text Detection +**The structure of the text detection dataset directory is organized as follows.** +``` +├── ctw1500 +│   ├── imgs +│   ├── instances_test.json +│   └── instances_training.json +├── icdar2015 +│   ├── imgs +│   ├── instances_test.json +│   └── instances_training.json +├── icdar2017 +│   ├── imgs +│   ├── instances_training.json +│   └── instances_val.json +├── synthtext +│   ├── imgs +│   └── instances_training.lmdb +``` +| Dataset | | Images | | | Annotation Files | | | Note | | +|:---------:|:-:|:--------------------------:|:-:|:--------------------------------------------:|:---------------------------------------:|:----------------------------------------:|:-:|:----:|---| +| | | | | training | validation | testing | | | | +| CTW1500 | | [homepage](https://github.com/Yuliang-Liu/Curve-Text-Detector) | | [instances_training.json](https://download.openmmlab.com/mmocr/data/ctw1500/instances_training.json) | - | [instances_test.json](https://download.openmmlab.com/mmocr/data/ctw1500/instances_test.json) | | | | +| ICDAR2015 | | [homepage](https://rrc.cvc.uab.es/?ch=4&com=downloads) | | [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_training.json) | - | [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_test.json) | | | | +| ICDAR2017 | | [homepage](https://rrc.cvc.uab.es/?ch=8&com=downloads) | [renamed_imgs](https://download.openmmlab.com/mmocr/data/icdar2017/renamed_imgs.tar) | [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2017/instances_training.json) | [instances_val.json](https://openmmlab) | [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2017/instances_test.json) | | | | +| SynthText | | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | | [instances_training.lmdb](https://download.openmmlab.com/mmocr/data/synthtext/instances_training.lmdb)|-| | | | + +- For `icdar2015`: + - Step1: Download `ch4_training_images.zip` and `ch4_test_images.zip` from [homepage](https://rrc.cvc.uab.es/?ch=4&com=downloads) + - Step2: Download [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_training.json) and [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_test.json) + - Step3: + ```bash + mkdir icdar2015 && cd icdar2015 + mv /path/to/instances_training.json . + mv /path/to/instances_test.json . + + mkdir imgs && cd imgs + ln -s /path/to/ch4_training_images training + ln -s /path/to/ch4_test_images test + ``` +- For `icdar2017`: + - To avoid rotation effects when loading `jpg` images with OpenCV, we provide re-saved `png` format images in [renamed_imgs](https://download.openmmlab.com/mmocr/data/icdar2017/renamed_imgs.tar). You can copy these images to `imgs` (see the sketch below). 
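+ - A rough sketch of preparing `icdar2017` along the same lines as `icdar2015` (all paths are placeholders, and the exact layout inside the extracted archive may differ): + ```bash + mkdir icdar2017 && cd icdar2017 + mv /path/to/instances_training.json . + mv /path/to/instances_val.json . + + # place the re-saved images under imgs + mkdir imgs && cd imgs + tar -xf /path/to/renamed_imgs.tar + ``` 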
+ + +## Text Recognition +**The structure of the text recognition dataset directory is organized as follows.** + +``` +├── mixture +│   ├── coco_text +│ │ ├── train_label.txt +│ │ ├── train_words +│   ├── icdar_2011 +│ │ ├── training_label.txt +│ │ ├── Challenge1_Training_Task3_Images_GT +│   ├── icdar_2013 +│ │ ├── train_label.txt +│ │ ├── test_label_1015.txt +│ │ ├── test_label_1095.txt +│ │ ├── Challenge2_Training_Task3_Images_GT +│ │ ├── Challenge2_Test_Task3_Images +│   ├── icdar_2015 +│ │ ├── train_label.txt +│ │ ├── test_label.txt +│ │ ├── ch4_training_word_images_gt +│ │ ├── ch4_test_word_images_gt +│   ├── IIIT5K +│ │ ├── train_label.txt +│ │ ├── test_label.txt +│ │ ├── train +│ │ ├── test +│   ├── ct80 +│ │ ├── test_label.txt +│ │ ├── image +│   ├── svt +│ │ ├── test_label.txt +│ │ ├── image +│   ├── svtp +│ │ ├── test_label.txt +│ │ ├── image +│   ├── Synth90k +│ │ ├── shuffle_labels.txt +│ │ ├── label.lmdb +│ │ ├── mnt +│   ├── SynthText +│ │ ├── shuffle_labels.txt +│ │ ├── instances_train.txt +│ │ ├── label.lmdb +│ │ ├── synthtext +│   ├── SynthAdd +│ │ ├── label.txt +│ │ ├── SynthText_Add + +``` +| Dataset | | images | annotation file | annotation file | Note | +|:----------:|:-:|:---------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------------------:|:----:| +|| | |training | test | | +| coco_text ||[homepage](https://rrc.cvc.uab.es/?ch=5&com=downloads) |[train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/coco_text/train_label.txt) |- | | +| icdar_2011 ||[homepage](http://www.cvc.uab.es/icdar2011competition/?com=downloads) |[train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/train_label.txt) |- | | +| icdar_2013 | | [homepage](https://rrc.cvc.uab.es/?ch=2&com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/train_label.txt) | [test_label_1015.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/test_label_1015.txt) | | +| icdar_2015 | | [homepage](https://rrc.cvc.uab.es/?ch=4&com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/train_label.txt) | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/test_label.txt) | | +| IIIT5K | | [homepage](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/train_label.txt) | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/test_label.txt) | | +| ct80 | | - |-|[test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/ct80/test_label.txt)|| +| svt | | [homepage](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svt/test_label.txt) | | +| svtp | | - | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svtp/test_label.txt) | | +| Synth90k | | [homepage](https://www.robots.ox.ac.uk/~vgg/data/text/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/shuffle_labels.txt) \| [label.lmdb](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/label.lmdb) | - | | +| SynthText | | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | 
[shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) \| [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) \| [label.lmdb](https://download.openmmlab.com/mmocr/data/mixture/SynthText/label.lmdb) | - | | +| SynthAdd | | [SynthText_Add.zip](https://pan.baidu.com/s/1uV0LtoNmcxbO-0YA7Ch4dg) (code:627x) | [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt)|- | | + +- For `icdar_2013`: + - Step1: Download `Challenge2_Test_Task3_Images.zip` and `Challenge2_Training_Task3_Images_GT.zip` from [homepage](https://rrc.cvc.uab.es/?ch=2&com=downloads) + - Step2: Download [test_label_1015.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/test_label_1015.txt) and [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/train_label.txt) +- For `icdar_2015`: + - Step1: Download `ch4_training_word_images_gt.zip` and `ch4_test_word_images_gt.zip` from [homepage](https://rrc.cvc.uab.es/?ch=4&com=downloads) + - Step2: Download [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/train_label.txt) and [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/test_label.txt) +- For `IIIT5K`: + - Step1: Download `IIIT5K-Word_V3.0.tar.gz` from [homepage](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html) + - Step2: Download [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/train_label.txt) and [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/test_label.txt) +- For `svt`: + - Step1: Download `svt.zip` from [homepage](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset) + - Step2: Download [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svt/test_label.txt) +- For `ct80`: + - Step1: Download [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/ct80/test_label.txt) +- For `svtp`: + - Step1: Download [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svtp/test_label.txt) +- For `coco_text`: + - Step1: Download from [homepage](https://rrc.cvc.uab.es/?ch=5&com=downloads) + - Step2: Download [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/coco_text/train_label.txt) + +- For `Syn90k`: + - Step1: Download `mjsynth.tar.gz` from [homepage](https://www.robots.ox.ac.uk/~vgg/data/text/) + - Step2: Download [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Synth90k/shuffle_labels.txt) + - Step3: + ```bash + mkdir Syn90k && cd Syn90k + + mv /path/to/mjsynth.tar.gz . + + tar -xzf mjsynth.tar.gz + + mv /path/to/shuffle_labels.txt . + + # create soft link + cd /path/to/mmocr/data/mixture + + ln -s /path/to/Syn90k Syn90k + ``` +- For `SynthText`: + - Step1: Download `SynthText.zip` from [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) + - Step2: Download [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) + - Step3: Download [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) + - Step4: + ```bash + unzip SynthText.zip + + cd SynthText + + mv /path/to/shuffle_labels.txt . 
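+ + # also move the annotation file downloaded in Step3 + mv /path/to/instances_train.txt . 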
+ + # create soft link + cd /path/to/mmocr/data/mixture + + ln -s /path/to/SynthText SynthText + ``` +- For `SynthAdd`: + - Step1: Download `SynthText_Add.zip` from [SynthAdd](https://pan.baidu.com/s/1uV0LtoNmcxbO-0YA7Ch4dg) (code:627x) + - Step2: Download [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt) + - Step3: + ```bash + mkdir SynthAdd && cd SynthAdd + + mv /path/to/SynthText_Add.zip . + + unzip SynthText_Add.zip + + mv /path/to/label.txt . + + # create soft link + cd /path/to/mmocr/data/mixture + + ln -s /path/to/SynthAdd SynthAdd + ``` + + +## Key Information Extraction +**The structure of the key information extraction dataset directory is organized as follows.** +``` +└── wildreceipt + ├── anno_files + ├── class_list.txt + ├── dict.txt + ├── image_files + ├── test.txt + └── train.txt +``` +- Download [wildreceipt.tar](https://download.openmmlab.com/mmocr/data/wildreceipt.tar) diff --git a/docs/getting_started.md b/docs/getting_started.md new file mode 100644 index 00000000..20c45a92 --- /dev/null +++ b/docs/getting_started.md @@ -0,0 +1,369 @@ + +# Getting Started + +This page provides basic tutorials on the usage of MMOCR. +For the installation instructions, please see [install.md](install.md). + + +- [Getting Started](#getting-started) + - [Inference with Pretrained Models](#inference-with-pretrained-models) + - [Test a Single Image](#test-a-single-image) + - [Test Multiple Images](#test-multiple-images) + - [Test a Dataset](#test-a-dataset) + - [Test with Single/Multiple GPUs](#test-with-singlemultiple-gpus) + - [Optional Arguments](#optional-arguments) + - [Test with Slurm](#test-with-slurm) + - [Optional Arguments](#optional-arguments-1) + - [Train a Model](#train-a-model) + - [Train with Single/Multiple GPUs](#train-with-singlemultiple-gpus) + - [Train with Toy Dataset](#train-with-toy-dataset) + - [Train with Slurm](#train-with-slurm) + - [Launch Multiple Jobs on a Single Machine](#launch-multiple-jobs-on-a-single-machine) + - [Useful Tools](#useful-tools) + - [Publish a Model](#publish-a-model) + - [Customized Settings](#customized-settings) + - [Flexible Dataset](#flexible-dataset) + - [Encoder-Decoder-Based Text Recognition Task](#encoder-decoder-based-text-recognition-task) + - [Optional Arguments](#optional-arguments-2) + - [Segmentation-Based Text Recognition Task](#segmentation-based-text-recognition-task) + - [Text Detection Task](#text-detection-task) + - [COCO-like Dataset](#coco-like-dataset) + + + + +## Inference with Pretrained Models + +We provide testing scripts to evaluate a full dataset, as well as some task-specific image demos. + + +### Test a Single Image + +You can use the following command to test a single image with one GPU. + +```shell +python demo/image_demo.py ${TEST_IMG} ${CONFIG_FILE} ${CHECKPOINT_FILE} ${SAVE_PATH} [--imshow] [--device ${GPU_ID}] +``` + +If `--imshow` is specified, the demo will also show the image with OpenCV. For example: + +```shell +python demo/image_demo.py demo/demo_text_det.jpg configs/xxx.py xxx.pth demo/demo_text_det_pred.jpg +``` + +The predicted result will be saved as `demo/demo_text_det_pred.jpg`. 
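+ +A text recognition model can be tested on the bundled sample image in the same way; the config and checkpoint below are placeholders for whichever recognition model you use: + +```shell +python demo/image_demo.py demo/demo_text_recog.jpg configs/xxx.py xxx.pth demo/demo_text_recog_pred.jpg +``` 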
+ + +### Test Multiple Images + +```shell +# for text detection +sh tools/test_imgs.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${IMG_ROOT_PATH} ${IMG_LIST} ${RESULTS_DIR} + +# for text recognition +sh tools/ocr_test_imgs.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${IMG_ROOT_PATH} ${IMG_LIST} ${RESULTS_DIR} +``` +It will save both the prediction results and visualized images to `${RESULTS_DIR}`. + + +### Test a Dataset + +MMOCR implements **distributed** testing with `MMDistributedDataParallel`. (Please refer to [datasets.md](datasets.md) to prepare your datasets.) + + +#### Test with Single/Multiple GPUs + +You can use the following command to test a dataset with single/multiple GPUs. + +```shell +./tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--eval ${EVAL_METRIC}] +``` +For example, + +```shell +./tools/dist_test.sh configs/example_config.py work_dirs/example_exp/example_model_20200202.pth 1 --eval hmean-iou +``` + +##### Optional Arguments + +- `--eval`: Specify the evaluation metric. For text detection, the metric should be either 'hmean-ic13' or 'hmean-iou'. For text recognition, the metric should be 'acc'. + + +#### Test with Slurm + +If you run MMOCR on a cluster managed with [Slurm](https://slurm.schedmd.com/), you can use the script `slurm_test.sh`. + +```shell +[GPUS=${GPUS}] ./tools/slurm_test.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${CHECKPOINT_FILE} [--eval ${EVAL_METRIC}] +``` +Here is an example of using 8 GPUs to test an example model on the 'dev' partition with job name 'test_job'. + +```shell +GPUS=8 ./tools/slurm_test.sh dev test_job configs/example_config.py work_dirs/example_exp/example_model_20200202.pth --eval hmean-iou +``` + +You can check [slurm_test.sh](https://github.com/open-mmlab/mmocr/blob/master/tools/slurm_test.sh) for full arguments and environment variables. + + + +##### Optional Arguments + +- `--eval`: Specify the evaluation metric. For text detection, the metric should be either 'hmean-ic13' or 'hmean-iou'. For text recognition, the metric should be 'acc'. + + + +## Train a Model + +MMOCR implements **distributed** training with `MMDistributedDataParallel`. (Please refer to [datasets.md](datasets.md) to prepare your datasets.) + +All outputs (log files and checkpoints) will be saved to a working directory specified by `work_dir` in the config file. + +By default, we evaluate the model on the validation set after several iterations. You can change the evaluation interval by adding the interval argument in the training config as follows: +```python +evaluation = dict(interval=1, by_epoch=True) # This evaluates the model per epoch. +``` + + + +### Train with Single/Multiple GPUs + +```shell +./tools/dist_train.sh ${CONFIG_FILE} ${WORK_DIR} ${GPU_NUM} [optional arguments] +``` + +Optional Arguments: + +- `--no-validate` (**not suggested**): By default, the codebase will perform evaluation at every k-th iteration during training. To disable this behavior, use `--no-validate`. + + +#### Train with Toy Dataset +We provide a toy dataset under `tests/data`, so you can train a toy model directly before the academic datasets are prepared. 
+ +For example, to train a text recognition model with the `seg` method on the toy dataset: +``` +./tools/dist_train.sh configs/textrecog/seg/seg_r31_1by16_fpnocr_toy_dataset.py work_dirs/seg 1 +``` + +Or to train a text recognition model with the `sar` method on the toy dataset: +``` +./tools/dist_train.sh configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py work_dirs/sar 1 +``` + + +### Train with Slurm + +If you run MMOCR on a cluster managed with [Slurm](https://slurm.schedmd.com/), you can use the script `slurm_train.sh`. + +```shell +[GPUS=${GPUS}] ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE} ${WORK_DIR} +``` + +Here is an example of using 8 GPUs to train a text detection model on the dev partition. + +```shell +GPUS=8 ./tools/slurm_train.sh dev psenet-ic15 configs/textdet/psenet/psenet_r50_fpnf_sbn_1x_icdar2015.py /nfs/xxxx/psenet-ic15 +``` + +You can check [slurm_train.sh](https://github.com/open-mmlab/mmocr/blob/master/tools/slurm_train.sh) for full arguments and environment variables. + + +### Launch Multiple Jobs on a Single Machine + +If you launch multiple jobs on a single machine, e.g., 2 jobs of 4-GPU training on a machine with 8 GPUs, +you need to specify different ports (29500 by default) for each job to avoid communication conflicts. + +If you use `dist_train.sh` to launch training jobs, you can set the ports in the command shell. + +```shell +CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4 +CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4 +``` + +If you launch training jobs with Slurm, you need to modify the config files to set different communication ports. + +In `config1.py`, +```python +dist_params = dict(backend='nccl', port=29500) +``` + +In `config2.py`, +```python +dist_params = dict(backend='nccl', port=29501) +``` + +Then you can launch two jobs with `config1.py` and `config2.py`. + +```shell +CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR} +CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR} +``` + + + +## Useful Tools + +We provide numerous useful tools under the `mmocr/tools` directory. + + +### Publish a Model + +Before you upload a model to AWS, you may want to +(1) convert the model weights to CPU tensors, (2) delete the optimizer states, and +(3) compute the hash of the checkpoint file and append the hash id to the filename. + +```shell +python tools/publish_model.py ${INPUT_FILENAME} ${OUTPUT_FILENAME} +``` + +E.g., + +```shell +python tools/publish_model.py work_dirs/psenet/latest.pth psenet_r50_fpnf_sbn_1x_20190801.pth +``` + +The final output filename will be `psenet_r50_fpnf_sbn_1x_20190801-{hash id}.pth`. + + +## Customized Settings + + +### Flexible Dataset +To support the tasks of `text detection`, `text recognition` and `key information extraction`, we have designed a new type of dataset which consists of `loader` and `parser` to load and parse different types of annotation files. +- **loader**: Load the annotation file. There are two types of loaders: `HardDiskLoader` and `LmdbLoader`. + - `HardDiskLoader`: Load `txt` format annotation file from hard disk to memory. + - `LmdbLoader`: Load `lmdb` format annotation file with lmdb backend, which is very useful for **extremely large** annotation files to avoid the out-of-memory problem when ten or more GPUs are used, since each GPU starts multiple processes to load the annotation file into memory. 
+- **parser**: Parse the annotation file line by line and return it in `dict` format. There are two types of parsers: `LineStrParser` and `LineJsonParser`. + - `LineStrParser`: Parse one line in the annotation file as a string and split it into several parts with a `separator`. It can be used on tasks with simple annotation files such as text recognition, where each line of the annotation file contains only the `filename` and `label` attributes. + - `LineJsonParser`: Parse one line in the annotation file as a JSON string and use `json.loads` to convert it to a `dict`. It can be used on tasks with complex annotation files such as text detection, where each line of the annotation file contains multiple attributes (e.g. `filename`, `height`, `width`, `box`, `segmentation`, `iscrowd`, `category_id`, etc.). + +Here we show some examples of using different combinations of `loader` and `parser`. + + +#### Encoder-Decoder-Based Text Recognition Task +```python +dataset_type = 'OCRDataset' +img_prefix = 'tests/data/ocr_toy_dataset/imgs' +train_anno_file = 'tests/data/ocr_toy_dataset/label.txt' +train = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=train_anno_file, + loader=dict( + type='HardDiskLoader', + repeat=10, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) +``` +You can check the content of the annotation file in `tests/data/ocr_toy_dataset/label.txt`. +The combination of `HardDiskLoader` and `LineStrParser` will return a dict for each file by calling `__getitem__`: `{'filename': '1223731.jpg', 'text': 'GRAND'}`. + + +##### Optional Arguments + +- `repeat`: The number of times the annotation file is repeated. For example, if there are `10` lines in the annotation file, setting `repeat=10` yields a dataset of size `100`. + +If the annotation file is extremely large, you can convert it from txt format to lmdb format with the following command: +```shell +python tools/data_converter/txt2lmdb.py -i ann_file.txt -o ann_file.lmdb +``` + +After that, you can use `LmdbLoader` in the dataset config as below. +```python +img_prefix = 'tests/data/ocr_toy_dataset/imgs' +train_anno_file = 'tests/data/ocr_toy_dataset/label.lmdb' +train = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=train_anno_file, + loader=dict( + type='LmdbLoader', + repeat=10, + parser=dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ')), + pipeline=train_pipeline, + test_mode=False) +``` + + +#### Segmentation-Based Text Recognition Task +```python +prefix = 'tests/data/ocr_char_ann_toy_dataset/' +train = dict( + type='OCRSegDataset', + img_prefix=prefix + 'imgs', + ann_file=prefix + 'instances_train.txt', + loader=dict( + type='HardDiskLoader', + repeat=10, + parser=dict( + type='LineJsonParser', + keys=['file_name', 'annotations', 'text'])), + pipeline=train_pipeline, + test_mode=True) +``` +You can check the content of the annotation file in `tests/data/ocr_char_ann_toy_dataset/instances_train.txt`. 
+The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for each file by calling `__getitem__` each time: +```python +{"file_name": "resort_88_101_1.png", "annotations": [{"char_text": "F", "char_box": [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]}, {"char_text": "r", "char_box": [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0]}, {"char_text": "o", "char_box": [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0]}, {"char_text": "m", "char_box": [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0]}, {"char_text": ":", "char_box": [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0]}], "text": "From:"} +``` + + +#### Text Detection Task +```python +dataset_type = 'TextDetDataset' +img_prefix = 'tests/data/toy_dataset/imgs' +test_anno_file = 'tests/data/toy_dataset/instances_test.txt' +test = dict( + type=dataset_type, + img_prefix=img_prefix, + ann_file=test_anno_file, + loader=dict( + type='HardDiskLoader', + repeat=4, + parser=dict( + type='LineJsonParser', + keys=['file_name', 'height', 'width', 'annotations'])), + pipeline=test_pipeline, + test_mode=True) +``` +The results are generated in the same way as the segmentation-based text recognition task above. +You can check the content of the annotation file in `tests/data/toy_dataset/instances_test.txt`. +The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for each file by calling `__getitem__`: +```python +{"file_name": "test/img_10.jpg", "height": 720, "width": 1280, "annotations": [{"iscrowd": 1, "category_id": 1, "bbox": [260.0, 138.0, 24.0, 20.0], "segmentation": [[261, 138, 284, 140, 279, 158, 260, 158]]}, {"iscrowd": 0, "category_id": 1, "bbox": [288.0, 138.0, 129.0, 23.0], "segmentation": [[288, 138, 417, 140, 416, 161, 290, 157]]}, {"iscrowd": 0, "category_id": 1, "bbox": [743.0, 145.0, 37.0, 18.0], "segmentation": [[743, 145, 779, 146, 780, 163, 746, 163]]}, {"iscrowd": 0, "category_id": 1, "bbox": [783.0, 129.0, 50.0, 26.0], "segmentation": [[783, 129, 831, 132, 833, 155, 785, 153]]}, {"iscrowd": 1, "category_id": 1, "bbox": [831.0, 133.0, 43.0, 23.0], "segmentation": [[831, 133, 870, 135, 874, 156, 835, 155]]}, {"iscrowd": 1, "category_id": 1, "bbox": [159.0, 204.0, 72.0, 15.0], "segmentation": [[159, 205, 230, 204, 231, 218, 159, 219]]}, {"iscrowd": 1, "category_id": 1, "bbox": [785.0, 158.0, 75.0, 21.0], "segmentation": [[785, 158, 856, 158, 860, 178, 787, 179]]}, {"iscrowd": 1, "category_id": 1, "bbox": [1011.0, 157.0, 68.0, 16.0], "segmentation": [[1011, 157, 1079, 160, 1076, 173, 1011, 170]]}]} +``` + + + +### COCO-like Dataset +For text detection, you can also use an annotation file in a COCO format that is defined in [mmdet](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/coco.py): +```python +dataset_type = 'IcdarDataset' +prefix = 'tests/data/toy_dataset/' +test=dict( + type=dataset_type, + ann_file=prefix + 'instances_test.json', + img_prefix=prefix + 'imgs', + pipeline=test_pipeline) +``` +You can check the content of the annotation file in `tests/data/toy_dataset/instances_test.json` +- The icdar2015/2017 annotations have to be converted into the COCO format using `tools/data_converter/icdar_converter.py`: + + ```shell + python tools/data_converter/icdar_converter.py ${src_root_path} -o ${out_path} -d ${data_type} --split-list training validation test + ``` + +- The ctw1500 annotations have to be converted into the COCO format using `tools/data_converter/ctw1500_converter.py`: + + ```shell + python tools/data_converter/ctw1500_converter.py ${src_root_path} -o ${out_path} 
--split-list training test + ``` diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 00000000..8242dffc --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,38 @@ +Welcome to MMOCR's documentation! +======================================= + +.. toctree:: + :maxdepth: 2 + :caption: Get Started + + install.md + getting_started.md + technical_details.md + contributing.md + +.. toctree:: + :maxdepth: 2 + :caption: Model Zoo + + modelzoo.md + textdet_models.md + textrecog_models.md + kie_models.md + +.. toctree:: + :maxdepth: 2 + :caption: Notes + + changelog.md + faq.md + +.. toctree:: + :caption: API Reference + + api.rst + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`search` diff --git a/docs/install.md b/docs/install.md new file mode 100644 index 00000000..78d1b0be --- /dev/null +++ b/docs/install.md @@ -0,0 +1,249 @@ + +# Installation + + +- [Installation](#installation) + - [Prerequisites](#prerequisites) + - [Step-by-Step Installation Instructions](#step-by-step-installation-instructions) + - [Full Set-up Script](#full-set-up-script) + - [Another option: Docker Image](#another-option-docker-image) + - [Prepare Datasets](#prepare-datasets) + + + +## Prerequisites + +- Linux (Windows is not officially supported) +- Python 3.7 +- PyTorch 1.5 or higher +- torchvision 0.6.0 +- CUDA 10.1 +- NCCL 2 +- GCC 5.4.0 or higher +- [mmcv](https://github.com/open-mmlab/mmcv) 1.2.6 + +We have tested the following versions of OS and software: + +- OS: Ubuntu 16.04 +- CUDA: 10.1 +- GCC(G++): 5.4.0 +- mmcv 1.2.6 +- PyTorch 1.5 +- torchvision 0.6.0 + +MMOCR depends on PyTorch and mmdetection v2.9.0. + + +## Step-by-Step Installation Instructions + +a. Create a conda virtual environment and activate it. + +```shell +conda create -n open-mmlab python=3.7 -y +conda activate open-mmlab +``` + +b. Install PyTorch and torchvision following the [official instructions](https://pytorch.org/), e.g., + +```shell +conda install pytorch==1.5.0 torchvision==0.6.0 cudatoolkit=10.1 -c pytorch +``` +Note: Make sure that your compilation CUDA version and runtime CUDA version match. +You can check the supported CUDA version for precompiled packages on the [PyTorch website](https://pytorch.org/). + +`E.g. 1` If you have CUDA 10.1 installed under `/usr/local/cuda` and would like to install +PyTorch 1.5, you need to install the prebuilt PyTorch with CUDA 10.1. + +```shell +conda install pytorch cudatoolkit=10.1 torchvision -c pytorch +``` + +`E.g. 2` If you have CUDA 9.2 installed under `/usr/local/cuda` and would like to install +PyTorch 1.3.1, you need to install the prebuilt PyTorch with CUDA 9.2. + +```shell +conda install pytorch=1.3.1 cudatoolkit=9.2 torchvision=0.4.2 -c pytorch +``` + +If you build PyTorch from source instead of installing the prebuilt package, +you can use more CUDA versions such as 9.0. + +c. Create a folder called `code` and clone the mmcv repository into it. + +```shell +mkdir code +cd code +git clone https://github.com/open-mmlab/mmcv.git +cd mmcv +git checkout -b v1.2.6 v1.2.6 +pip install -r requirements.txt +MMCV_WITH_OPS=1 pip install -v -e . +``` + +d. Clone the mmdetection repository into it. The mmdetection repo is separate from the mmcv repo in `code`. + +```shell +cd .. +git clone https://github.com/open-mmlab/mmdetection.git +cd mmdetection +git checkout -b v2.9.0 v2.9.0 +pip install -r requirements.txt +pip install -v -e . +export PYTHONPATH=$(pwd):$PYTHONPATH +``` + +Note that we have tested mmdetection v2.9.0 only; other versions might be incompatible. 
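+ +As an optional sanity check, you can verify that the mmdetection installation is importable before moving on: + +```shell +python -c "import mmdet; print(mmdet.__version__)" +``` 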
+ +e. Clone the mmocr repository into it. The mmocr repo is separate from the mmcv and mmdetection repos in `code`. + +```shell +cd .. +git clone https://github.com/open-mmlab/mmocr.git +cd mmocr +``` + +f. Install build requirements and then install MMOCR. + +```shell +pip install -r requirements.txt +pip install -v -e . # or "python setup.py build_ext --inplace" +export PYTHONPATH=$(pwd):$PYTHONPATH +``` + + +## Full Set-up Script + +Here is the full script for setting up mmocr with conda. + +```shell +conda create -n open-mmlab python=3.7 -y +conda activate open-mmlab + +# install PyTorch 1.5.0 and torchvision 0.6.0 prebuilt with CUDA 10.1 +conda install pytorch==1.5.0 torchvision==0.6.0 cudatoolkit=10.1 -c pytorch + +# install mmcv +mkdir code +cd code +git clone https://github.com/open-mmlab/mmcv.git +cd mmcv # code/mmcv +git checkout -b v1.2.6 v1.2.6 +pip install -r requirements.txt +MMCV_WITH_OPS=1 pip install -v -e . + +# install mmdetection +cd .. # exit to code +git clone https://github.com/open-mmlab/mmdetection.git +cd mmdetection # code/mmdetection +git checkout -b v2.9.0 v2.9.0 +pip install -r requirements.txt +pip install -v -e . +export PYTHONPATH=$(pwd):$PYTHONPATH + +# install mmocr +cd .. +git clone https://github.com/open-mmlab/mmocr.git +cd mmocr # code/mmocr + +pip install -r requirements.txt +pip install -v -e . # or "python setup.py build_ext --inplace" +export PYTHONPATH=$(pwd):$PYTHONPATH +``` + + +## Another option: Docker Image + +We provide a [Dockerfile](https://github.com/open-mmlab/mmocr/blob/master/docker/Dockerfile) to build an image. + +```shell +# build an image with PyTorch 1.5, CUDA 10.1 +docker build -t mmocr docker/ +``` + +Run it with + +```shell +docker run --gpus all --shm-size=8g -it -v {DATA_DIR}:/mmocr/data mmocr +``` + + +## Prepare Datasets + +It is recommended to symlink the dataset root to `mmocr/data`. Please refer to [datasets.md](datasets.md) to prepare your datasets. +If your folder structure is different, you may need to change the corresponding paths in config files. + +The `mmocr` folder is organized as follows: +``` +mmocr +. 
+├── configs +│   ├── _base_ +│   ├── kie +│   ├── textdet +│   └── textrecog +├── demo +│   ├── demo_text_det.jpg +│   ├── demo_text_recog.jpg +│   ├── image_demo.py +│   └── webcam_demo.py +├── docs +│   ├── api.rst +│   ├── changelog.md +│   ├── code_of_conduct.md +│   ├── conf.py +│   ├── contributing.md +│   ├── datasets.md +│   ├── getting_started.md +│   ├── index.rst +│   ├── install.md +│   ├── make.bat +│   ├── Makefile +│   ├── merge_docs.sh +│   ├── requirements.txt +│   ├── res +│   ├── stats.py +│   └── technical_details.md +├── LICENSE +├── mmocr +│   ├── apis +│   ├── core +│   ├── datasets +│   ├── __init__.py +│   ├── models +│   ├── utils +│   └── version.py +├── README.md +├── requirements +│   ├── build.txt +│   ├── docs.txt +│   ├── optional.txt +│   ├── readthedocs.txt +│   ├── runtime.txt +│   └── tests.txt +├── requirements.txt +├── resources +│   ├── illustration.jpg +│   └── mmocr-logo.png +├── setup.cfg +├── setup.py +├── tests +│   ├── data +│   ├── test_dataset +│   ├── test_metrics +│   ├── test_models +│   ├── test_tools +│   └── test_utils +└── tools + ├── data + ├── dist_test.sh + ├── dist_train.sh + ├── ocr_test_imgs.py + ├── ocr_test_imgs.sh + ├── publish_model.py + ├── slurm_test.sh + ├── slurm_train.sh + ├── test_imgs.py + ├── test_imgs.sh + ├── test.py + └── train.py +``` diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 00000000..8a3a0e25 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,36 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/merge_docs.sh b/docs/merge_docs.sh new file mode 100755 index 00000000..2482fb74 --- /dev/null +++ b/docs/merge_docs.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +sed -i '$a\\n' ../configs/kie/*/*.md +sed -i '$a\\n' ../configs/textdet/*/*.md +sed -i '$a\\n' ../configs/textrecog/*/*.md + +# gather models +cat ../configs/kie/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# KIE Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >kie_models.md +cat ../configs/textdet/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Text Detection Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >textdet_models.md +cat ../configs/textrecog/*/*.md | sed "s/md###t/html#t/g" | sed "s/#/#&/" | sed '1i\# Text Recognition Models' | sed 's/](\/docs\//](/g' | sed 's=](/=](https://github.com/open-mmlab/mmocr/tree/master/=g' >textrecog_models.md diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..89fbf86c --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,4 @@ +recommonmark +sphinx +sphinx_markdown_tables +sphinx_rtd_theme diff --git a/docs/res/git-workflow-feature.png b/docs/res/git-workflow-feature.png new file mode 100644 index 00000000..4d9f9083 Binary files /dev/null and b/docs/res/git-workflow-feature.png differ diff --git a/docs/res/git-workflow-master-develop.png b/docs/res/git-workflow-master-develop.png new file mode 100644 index 00000000..624111c8 Binary files /dev/null and b/docs/res/git-workflow-master-develop.png differ diff --git a/docs/stats.py b/docs/stats.py new file mode 100755 index 00000000..ef337ba5 --- /dev/null +++ b/docs/stats.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +import functools as func +import glob +import re +from os.path import basename, splitext + +import numpy as np +import titlecase + + +def anchor(name): + return re.sub(r'-+', '-', re.sub(r'[^a-zA-Z0-9]', '-', + name.strip().lower())).strip('-') + + +# Count algorithms + +files = sorted(glob.glob('*_models.md')) +# files = sorted(glob.glob('docs/*_models.md')) + +stats = [] + +for f in files: + with open(f, 'r') as content_file: + content = content_file.read() + + # title + title = content.split('\n')[0].replace('#', '') + + # count papers + papers = set((papertype, titlecase.titlecase(paper.lower().strip())) + for (papertype, paper) in re.findall( + r'\n\s*\[([A-Z]+?)\]\s*\n.*?\btitle\s*=\s*{(.*?)}', + content, re.DOTALL)) + # paper links + revcontent = '\n'.join(list(reversed(content.splitlines()))) + paperlinks = {} + for _, p in papers: + print(p) + q = p.replace('\\', '\\\\').replace('?', '\\?') + paperlinks[p] = ' '.join( + (f'[⇨]({splitext(basename(f))[0]}.html#{anchor(paperlink)})' + for paperlink in re.findall( + rf'\btitle\s*=\s*{{\s*{q}\s*}}.*?\n## (.*?)\s*[,;]?\s*\n', + revcontent, re.DOTALL | re.IGNORECASE))) + print(' ', paperlinks[p]) + paperlist = '\n'.join( + sorted(f' - [{t}] {x} ({paperlinks[x]})' for t, x in papers)) + # count configs + configs = set(x.lower().strip() + for x in re.findall(r'https.*configs/.*\.py', content)) + + # count ckpts (only those hosted for this repo) + ckpts = set(x.lower().strip() + for x in re.findall(r'https://download.*\.pth', content) + if 'mmocr' in x) + + statsmsg = 
f""" +## [{title}]({f}) + +* Number of checkpoints: {len(ckpts)} +* Number of configs: {len(configs)} +* Number of papers: {len(papers)} +{paperlist} + + """ + + stats.append((papers, configs, ckpts, statsmsg)) + +allpapers = func.reduce(lambda a, b: a.union(b), [p for p, _, _, _ in stats]) +allconfigs = func.reduce(lambda a, b: a.union(b), [c for _, c, _, _ in stats]) +allckpts = func.reduce(lambda a, b: a.union(b), [c for _, _, c, _ in stats]) +msglist = '\n'.join(x for _, _, _, x in stats) + +papertypes, papercounts = np.unique([t for t, _ in allpapers], + return_counts=True) +countstr = '\n'.join( + [f' - {t}: {c}' for t, c in zip(papertypes, papercounts)]) + +modelzoo = f""" +# Overview + +* Number of checkpoints: {len(allckpts)} +* Number of configs: {len(allconfigs)} +* Number of papers: {len(allpapers)} +{countstr} + +For supported datasets, see [datasets overview](datasets.md). + +{msglist} +""" + +with open('modelzoo.md', 'w') as f: + f.write(modelzoo) diff --git a/mmocr/__init__.py b/mmocr/__init__.py new file mode 100644 index 00000000..1c4f7e8f --- /dev/null +++ b/mmocr/__init__.py @@ -0,0 +1,3 @@ +from .version import __version__, short_version + +__all__ = ['__version__', 'short_version'] diff --git a/mmocr/apis/__init__.py b/mmocr/apis/__init__.py new file mode 100644 index 00000000..1c4b014c --- /dev/null +++ b/mmocr/apis/__init__.py @@ -0,0 +1,3 @@ +from .inference import model_inference + +__all__ = ['model_inference'] diff --git a/mmocr/apis/inference.py b/mmocr/apis/inference.py new file mode 100644 index 00000000..615682ce --- /dev/null +++ b/mmocr/apis/inference.py @@ -0,0 +1,43 @@ +import torch +from mmcv.ops import RoIPool +from mmcv.parallel import collate, scatter + +from mmdet.datasets.pipelines import Compose + + +def model_inference(model, img): + """Inference image(s) with the detector. + + Args: + model (nn.Module): The loaded detector. + imgs (str): Image files. + + Returns: + result (dict): Detection results. + """ + assert isinstance(img, str) + + cfg = model.cfg + device = next(model.parameters()).device # model device + data = dict(img_info=dict(filename=img), img_prefix=None) + # build the data pipeline + test_pipeline = Compose(cfg.data.test.pipeline) + data = test_pipeline(data) + data = collate([data], samples_per_gpu=1) + + # process img_metas + data['img_metas'] = data['img_metas'][0].data + + if next(model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [device])[0] + else: + for m in model.modules(): + assert not isinstance( + m, RoIPool + ), 'CPU inference with RoIPool is not supported currently.' 
+ + # forward the model + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data)[0] + return result diff --git a/mmocr/core/__init__.py b/mmocr/core/__init__.py new file mode 100644 index 00000000..cb717d4b --- /dev/null +++ b/mmocr/core/__init__.py @@ -0,0 +1,3 @@ +from .evaluation import * # noqa: F401, F403 +from .mask import * # noqa: F401, F403 +from .visualize import * # noqa: F401, F403 diff --git a/mmocr/core/evaluation/__init__.py b/mmocr/core/evaluation/__init__.py new file mode 100644 index 00000000..493e894e --- /dev/null +++ b/mmocr/core/evaluation/__init__.py @@ -0,0 +1,10 @@ +from .hmean import eval_hmean +from .hmean_ic13 import eval_hmean_ic13 +from .hmean_iou import eval_hmean_iou +from .kie_metric import compute_f1_score +from .ocr_metric import eval_ocr_metric + +__all__ = [ + 'eval_hmean_ic13', 'eval_hmean_iou', 'eval_ocr_metric', 'eval_hmean', + 'compute_f1_score' +] diff --git a/mmocr/core/evaluation/hmean.py b/mmocr/core/evaluation/hmean.py new file mode 100644 index 00000000..bbb7d679 --- /dev/null +++ b/mmocr/core/evaluation/hmean.py @@ -0,0 +1,149 @@ +from operator import itemgetter + +import mmcv +from mmcv.utils import print_log + +import mmocr.utils as utils +from mmocr.core.evaluation import hmean_ic13, hmean_iou +from mmocr.core.evaluation.utils import (filter_2dlist_result, + select_top_boundary) +from mmocr.core.mask import extract_boundary + + +def output_ranklist(img_results, img_infos, out_file): + """Output the worst results for debugging. + + Args: + img_results (list[dict]): Image result list. + img_infos (list[dict]): Image information list. + out_file (str): The output file path. + + Returns: + sorted_results (list[dict]): Image results sorted by hmean. + """ + assert utils.is_type_list(img_results, dict) + assert utils.is_type_list(img_infos, dict) + assert isinstance(out_file, str) + assert out_file.endswith('json') + + sorted_results = [] + for inx, result in enumerate(img_results): + name = img_infos[inx]['file_name'] + img_result = result + img_result['file_name'] = name + sorted_results.append(img_result) + sorted_results = sorted( + sorted_results, key=itemgetter('hmean'), reverse=False) + + mmcv.dump(sorted_results, file=out_file) + + return sorted_results + + +def get_gt_masks(ann_infos): + """Get ground truth masks and ignored masks. + + Args: + ann_infos (list[dict]): Each dict contains annotation + infos of one image, containing following keys: + masks, masks_ignore. + Returns: + gt_masks (list[list[list[int]]]): Ground truth masks. + gt_masks_ignore (list[list[list[int]]]): Ignored masks. + """ + assert utils.is_type_list(ann_infos, dict) + + gt_masks = [] + gt_masks_ignore = [] + for ann_info in ann_infos: + masks = ann_info['masks'] + mask_gt = [] + for mask in masks: + assert len(mask[0]) >= 8 and len(mask[0]) % 2 == 0 + mask_gt.append(mask[0]) + gt_masks.append(mask_gt) + + masks_ignore = ann_info['masks_ignore'] + mask_gt_ignore = [] + for mask_ignore in masks_ignore: + assert len(mask_ignore[0]) >= 8 and len(mask_ignore[0]) % 2 == 0 + mask_gt_ignore.append(mask_ignore[0]) + gt_masks_ignore.append(mask_gt_ignore) + + return gt_masks, gt_masks_ignore + + +def eval_hmean(results, + img_infos, + ann_infos, + metrics={'hmean-iou'}, + score_thr=0.3, + rank_list=None, + logger=None, + **kwargs): + """Evaluation in hmean metric. 
+ + Args: + results (list[dict]): Each dict corresponds to one image, + containing the following keys: boundary_result + img_infos (list[dict]): Each dict corresponds to one image, + containing the following keys: filename, height, width + ann_infos (list[dict]): Each dict corresponds to one image, + containing the following keys: masks, masks_ignore + score_thr (float): Score threshold of prediction map. + metrics (set{str}): Hmean metric set, should be one or all of + {'hmean-iou', 'hmean-ic13'} + Returns: + dict[str: float] + """ + assert utils.is_type_list(results, dict) + assert utils.is_type_list(img_infos, dict) + assert utils.is_type_list(ann_infos, dict) + assert len(results) == len(img_infos) == len(ann_infos) + assert isinstance(metrics, set) + + gts, gts_ignore = get_gt_masks(ann_infos) + + preds = [] + pred_scores = [] + for result in results: + _, texts, scores = extract_boundary(result) + if len(texts) > 0: + assert utils.valid_boundary(texts[0], False) + valid_texts, valid_text_scores = filter_2dlist_result( + texts, scores, score_thr) + preds.append(valid_texts) + pred_scores.append(valid_text_scores) + + eval_results = {} + for metric in metrics: + msg = f'Evaluating {metric}...' + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + best_result = dict(hmean=-1) + for iter in range(3, 10): + thr = iter * 0.1 + top_preds = select_top_boundary(preds, pred_scores, thr) + if metric == 'hmean-iou': + result, img_result = hmean_iou.eval_hmean_iou( + top_preds, gts, gts_ignore) + elif metric == 'hmean-ic13': + result, img_result = hmean_ic13.eval_hmean_ic13( + top_preds, gts, gts_ignore) + else: + raise NotImplementedError + if rank_list is not None: + output_ranklist(img_result, img_infos, rank_list) + + print_log( + 'thr {0:.1f}, recall:{1[recall]:.3f}, ' + 'precision: {1[precision]:.3f}, ' + 'hmean:{1[hmean]:.3f}'.format(thr, result), + logger=logger) + if result['hmean'] > best_result['hmean']: + best_result = result + eval_results[metric + ':recall'] = best_result['recall'] + eval_results[metric + ':precision'] = best_result['precision'] + eval_results[metric + ':hmean'] = best_result['hmean'] + return eval_results diff --git a/mmocr/core/evaluation/hmean_ic13.py b/mmocr/core/evaluation/hmean_ic13.py new file mode 100644 index 00000000..d3c69467 --- /dev/null +++ b/mmocr/core/evaluation/hmean_ic13.py @@ -0,0 +1,216 @@ +import numpy as np + +import mmocr.utils as utils +from . import utils as eval_utils + + +def compute_recall_precision(gt_polys, pred_polys): + """Compute the recall and the precision matrices between gt and predicted + polygons. + + Args: + gt_polys (list[Polygon]): List of gt polygons. + pred_polys (list[Polygon]): List of predicted polygons. + + Returns: + recall (ndarray): Recall matrix of size gt_num x det_num. + precision (ndarray): Precision matrix of size gt_num x det_num. 
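+
+        Element (i, j) of ``recall`` is inter_area(gt_i, det_j) / area(gt_i);
+        in ``precision`` the denominator is area(det_j) instead.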
+ """ + assert isinstance(gt_polys, list) + assert isinstance(pred_polys, list) + + gt_num = len(gt_polys) + det_num = len(pred_polys) + sz = [gt_num, det_num] + + recall = np.zeros(sz) + precision = np.zeros(sz) + # compute area recall and precision for each (gt, det) pair + # in one img + for gt_id in range(gt_num): + for pred_id in range(det_num): + gt = gt_polys[gt_id] + det = pred_polys[pred_id] + + inter_area, _ = eval_utils.poly_intersection(det, gt) + gt_area = gt.area() + det_area = det.area() + if gt_area != 0: + recall[gt_id, pred_id] = inter_area / gt_area + if det_area != 0: + precision[gt_id, pred_id] = inter_area / det_area + + return recall, precision + + +def eval_hmean_ic13(det_boxes, + gt_boxes, + gt_ignored_boxes, + precision_thr=0.4, + recall_thr=0.8, + center_dist_thr=1.0, + one2one_score=1., + one2many_score=0.8, + many2one_score=1.): + """Evalute hmean of text detection using the icdar2013 standard. + + Args: + det_boxes (list[list[list[float]]]): List of arrays of shape (n, 2k). + Each element is the det_boxes for one img. k>=4. + gt_boxes (list[list[list[float]]]): List of arrays of shape (m, 2k). + Each element is the gt_boxes for one img. k>=4. + gt_ignored_boxes (list[list[list[float]]]): List of arrays of + (l, 2k). Each element is the ignored gt_boxes for one img. k>=4. + precision_thr (float): Precision threshold of the iou of one + (gt_box, det_box) pair. + recall_thr (float): Recall threshold of the iou of one + (gt_box, det_box) pair. + center_dist_thr (float): Distance threshold of one (gt_box, det_box) + center point pair. + one2one_score (float): Reward when one gt matches one det_box. + one2many_score (float): Reward when one gt matches many det_boxes. + many2one_score (float): Reward when many gts match one det_box. + + Returns: + hmean (tuple[dict]): Tuple of dicts which encodes the hmean for + the dataset and all images. + """ + assert utils.is_3dlist(det_boxes) + assert utils.is_3dlist(gt_boxes) + assert utils.is_3dlist(gt_ignored_boxes) + + assert 0 <= precision_thr <= 1 + assert 0 <= recall_thr <= 1 + assert center_dist_thr > 0 + assert 0 <= one2one_score <= 1 + assert 0 <= one2many_score <= 1 + assert 0 <= many2one_score <= 1 + + img_num = len(det_boxes) + assert img_num == len(gt_boxes) + assert img_num == len(gt_ignored_boxes) + + dataset_gt_num = 0 + dataset_pred_num = 0 + dataset_hit_recall = 0.0 + dataset_hit_prec = 0.0 + + img_results = [] + + for i in range(img_num): + gt = gt_boxes[i] + gt_ignored = gt_ignored_boxes[i] + pred = det_boxes[i] + + gt_num = len(gt) + ignored_num = len(gt_ignored) + pred_num = len(pred) + + accum_recall = 0. + accum_precision = 0. + + gt_points = gt + gt_ignored + gt_polys = [eval_utils.points2polygon(p) for p in gt_points] + gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))] + gt_num = len(gt_polys) + + pred_polys, pred_points, pred_ignored_index = eval_utils.ignore_pred( + pred, gt_ignored_index, gt_polys, precision_thr) + + if pred_num > 0 and gt_num > 0: + + gt_hit = np.zeros(gt_num, np.int8).tolist() + pred_hit = np.zeros(pred_num, np.int8).tolist() + + # compute area recall and precision for each (gt, pred) pair + # in one img. + recall_mat, precision_mat = compute_recall_precision( + gt_polys, pred_polys) + + # match one gt to one pred box. 
+ for gt_id in range(gt_num): + for pred_id in range(pred_num): + if (gt_hit[gt_id] != 0 or pred_hit[pred_id] != 0 + or gt_id in gt_ignored_index + or pred_id in pred_ignored_index): + continue + match = eval_utils.one2one_match_ic13( + gt_id, pred_id, recall_mat, precision_mat, recall_thr, + precision_thr) + + if match: + gt_point = np.array(gt_points[gt_id]) + det_point = np.array(pred_points[pred_id]) + + norm_dist = eval_utils.box_center_distance( + det_point, gt_point) + norm_dist /= eval_utils.box_diag( + det_point) + eval_utils.box_diag(gt_point) + norm_dist *= 2.0 + + if norm_dist < center_dist_thr: + gt_hit[gt_id] = 1 + pred_hit[pred_id] = 1 + accum_recall += one2one_score + accum_precision += one2one_score + + # match one gt to many det boxes. + for gt_id in range(gt_num): + if gt_id in gt_ignored_index: + continue + match, match_det_set = eval_utils.one2many_match_ic13( + gt_id, recall_mat, precision_mat, recall_thr, + precision_thr, gt_hit, pred_hit, pred_ignored_index) + + if match: + gt_hit[gt_id] = 1 + accum_recall += one2many_score + accum_precision += one2many_score * len(match_det_set) + for pred_id in match_det_set: + pred_hit[pred_id] = 1 + + # match many gt to one det box. One pair of (det,gt) are matched + # successfully if their recall, precision, normalized distance + # meet some thresholds. + for pred_id in range(pred_num): + if pred_id in pred_ignored_index: + continue + + match, match_gt_set = eval_utils.many2one_match_ic13( + pred_id, recall_mat, precision_mat, recall_thr, + precision_thr, gt_hit, pred_hit, gt_ignored_index) + + if match: + pred_hit[pred_id] = 1 + accum_recall += many2one_score * len(match_gt_set) + accum_precision += many2one_score + for gt_id in match_gt_set: + gt_hit[gt_id] = 1 + + gt_care_number = gt_num - ignored_num + pred_care_number = pred_num - len(pred_ignored_index) + + r, p, h = eval_utils.compute_hmean(accum_recall, accum_precision, + gt_care_number, pred_care_number) + + img_results.append({'recall': r, 'precision': p, 'hmean': h}) + + dataset_gt_num += gt_care_number + dataset_pred_num += pred_care_number + dataset_hit_recall += accum_recall + dataset_hit_prec += accum_precision + + total_r, total_p, total_h = eval_utils.compute_hmean( + dataset_hit_recall, dataset_hit_prec, dataset_gt_num, dataset_pred_num) + + dataset_results = { + 'num_gts': dataset_gt_num, + 'num_dets': dataset_pred_num, + 'num_recall': dataset_hit_recall, + 'num_precision': dataset_hit_prec, + 'recall': total_r, + 'precision': total_p, + 'hmean': total_h + } + + return dataset_results, img_results diff --git a/mmocr/core/evaluation/hmean_iou.py b/mmocr/core/evaluation/hmean_iou.py new file mode 100644 index 00000000..8ad0363f --- /dev/null +++ b/mmocr/core/evaluation/hmean_iou.py @@ -0,0 +1,116 @@ +import numpy as np + +import mmocr.utils as utils +from . import utils as eval_utils + + +def eval_hmean_iou(pred_boxes, + gt_boxes, + gt_ignored_boxes, + iou_thr=0.5, + precision_thr=0.5): + """Evalute hmean of text detection using IOU standard. + + Args: + pred_boxes (list[list[list[float]]]): Text boxes for an img list. Each + box has 2k (>=8) values. + gt_boxes (list[list[list[float]]]): Ground truth text boxes for an img + list. Each box has 2k (>=8) values. + gt_ignored_boxes (list[list[list[float]]]): Ignored ground truth text + boxes for an img list. Each box has 2k (>=8) values. + iou_thr (float): Iou threshold when one (gt_box, det_box) pair is + matched. + precision_thr (float): Precision threshold when one (gt_box, det_box) + pair is matched. 
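+
+    Note: a prediction whose overlap with any ignored ground truth exceeds
+    ``precision_thr`` is itself ignored, i.e. excluded from both the recall
+    and the precision denominators.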
+ + Returns: + hmean (tuple[dict]): Tuple of dicts indicates the hmean for the dataset + and all images. + """ + assert utils.is_3dlist(pred_boxes) + assert utils.is_3dlist(gt_boxes) + assert utils.is_3dlist(gt_ignored_boxes) + assert 0 <= iou_thr <= 1 + assert 0 <= precision_thr <= 1 + + img_num = len(pred_boxes) + assert img_num == len(gt_boxes) + assert img_num == len(gt_ignored_boxes) + + dataset_gt_num = 0 + dataset_pred_num = 0 + dataset_hit_num = 0 + + img_results = [] + + for i in range(img_num): + gt = gt_boxes[i] + gt_ignored = gt_ignored_boxes[i] + pred = pred_boxes[i] + + gt_num = len(gt) + gt_ignored_num = len(gt_ignored) + pred_num = len(pred) + + hit_num = 0 + + # get gt polygons. + gt_all = gt + gt_ignored + gt_polys = [eval_utils.points2polygon(p) for p in gt_all] + gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))] + gt_num = len(gt_polys) + pred_polys, _, pred_ignored_index = eval_utils.ignore_pred( + pred, gt_ignored_index, gt_polys, precision_thr) + + # match. + if gt_num > 0 and pred_num > 0: + sz = [gt_num, pred_num] + iou_mat = np.zeros(sz) + + gt_hit = np.zeros(gt_num, np.int8) + pred_hit = np.zeros(pred_num, np.int8) + + for gt_id in range(gt_num): + for pred_id in range(pred_num): + gt_pol = gt_polys[gt_id] + det_pol = pred_polys[pred_id] + + iou_mat[gt_id, + pred_id] = eval_utils.poly_iou(det_pol, gt_pol) + + for gt_id in range(gt_num): + for pred_id in range(pred_num): + if (gt_hit[gt_id] != 0 or pred_hit[pred_id] != 0 + or gt_id in gt_ignored_index + or pred_id in pred_ignored_index): + continue + if iou_mat[gt_id, pred_id] > iou_thr: + gt_hit[gt_id] = 1 + pred_hit[pred_id] = 1 + hit_num += 1 + + gt_care_number = gt_num - gt_ignored_num + pred_care_number = pred_num - len(pred_ignored_index) + + r, p, h = eval_utils.compute_hmean(hit_num, hit_num, gt_care_number, + pred_care_number) + + img_results.append({'recall': r, 'precision': p, 'hmean': h}) + + dataset_hit_num += hit_num + dataset_gt_num += gt_care_number + dataset_pred_num += pred_care_number + + dataset_r, dataset_p, dataset_h = eval_utils.compute_hmean( + dataset_hit_num, dataset_hit_num, dataset_gt_num, dataset_pred_num) + + dataset_results = { + 'num_gts': dataset_gt_num, + 'num_dets': dataset_pred_num, + 'num_match': dataset_hit_num, + 'recall': dataset_r, + 'precision': dataset_p, + 'hmean': dataset_h + } + + return dataset_results, img_results diff --git a/mmocr/core/evaluation/kie_metric.py b/mmocr/core/evaluation/kie_metric.py new file mode 100644 index 00000000..00dc2387 --- /dev/null +++ b/mmocr/core/evaluation/kie_metric.py @@ -0,0 +1,27 @@ +import torch + + +def compute_f1_score(preds, gts, ignores=[]): + """Compute the F1-score of prediction. + + Args: + preds (Tensor): The predicted probability NxC map + with N and C being the sample number and class + number respectively. + gts (Tensor): The ground truth vector of size N. + ignores (list): The index set of classes that are ignored when + reporting results. + Note: all samples are participated in computing. + + Returns: + The numpy list of f1-scores of valid classes. 
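+
+    Example (illustrative):
+        >>> preds = torch.rand(4, 3)  # 4 samples, 3 classes
+        >>> gts = torch.LongTensor([0, 1, 2, 1])
+        >>> compute_f1_score(preds, gts, ignores=[2]).shape
+        (2,)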
+ """ + C = preds.size(1) + classes = torch.LongTensor(sorted(set(range(C)) - set(ignores))) + hist = torch.bincount( + gts * C + preds.argmax(1), minlength=C**2).view(C, C).float() + diag = torch.diag(hist) + recalls = diag / hist.sum(1).clamp(min=1) + precisions = diag / hist.sum(0).clamp(min=1) + f1 = 2 * recalls * precisions / (recalls + precisions).clamp(min=1e-8) + return f1[classes].cpu().numpy() diff --git a/mmocr/core/evaluation/ocr_metric.py b/mmocr/core/evaluation/ocr_metric.py new file mode 100644 index 00000000..5c5124f0 --- /dev/null +++ b/mmocr/core/evaluation/ocr_metric.py @@ -0,0 +1,133 @@ +import re +from difflib import SequenceMatcher + +import Levenshtein + + +def cal_true_positive_char(pred, gt): + """Calculate correct character number in prediction. + + Args: + pred (str): Prediction text. + gt (str): Ground truth text. + + Returns: + true_positive_char_num (int): The true positive number. + """ + + all_opt = SequenceMatcher(None, pred, gt) + true_positive_char_num = 0 + for opt, _, _, s2, e2 in all_opt.get_opcodes(): + if opt == 'equal': + true_positive_char_num += (e2 - s2) + else: + pass + return true_positive_char_num + + +def count_matches(pred_texts, gt_texts): + """Count the various match number for metric calculation. + + Args: + pred_texts (list[str]): Predicted text string. + gt_texts (list[str]): Ground truth text string. + + Returns: + match_res: (dict[str: int]): Match number used for + metric calculation. + """ + match_res = { + 'gt_char_num': 0, + 'pred_char_num': 0, + 'true_positive_char_num': 0, + 'gt_word_num': 0, + 'match_word_num': 0, + 'match_word_ignore_case': 0, + 'match_word_ignore_case_symbol': 0 + } + comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]') + norm_ed_sum = 0.0 + for pred_text, gt_text in zip(pred_texts, gt_texts): + if gt_text == pred_text: + match_res['match_word_num'] += 1 + gt_text_lower = gt_text.lower() + pred_text_lower = pred_text.lower() + if gt_text_lower == pred_text_lower: + match_res['match_word_ignore_case'] += 1 + gt_text_lower_ignore = comp.sub('', gt_text_lower) + pred_text_lower_ignore = comp.sub('', pred_text_lower) + if gt_text_lower_ignore == pred_text_lower_ignore: + match_res['match_word_ignore_case_symbol'] += 1 + match_res['gt_word_num'] += 1 + + # normalized edit distance + edit_dist = Levenshtein.distance(pred_text_lower_ignore, + gt_text_lower_ignore) + norm_ed = float(edit_dist) / max(1, len(gt_text_lower_ignore), + len(pred_text_lower_ignore)) + norm_ed_sum += norm_ed + + # number to calculate char level recall & precision + match_res['gt_char_num'] += len(gt_text_lower_ignore) + match_res['pred_char_num'] += len(pred_text_lower_ignore) + true_positive_char_num = cal_true_positive_char( + pred_text_lower_ignore, gt_text_lower_ignore) + match_res['true_positive_char_num'] += true_positive_char_num + + normalized_edit_distance = norm_ed_sum / max(1, len(gt_texts)) + match_res['ned'] = normalized_edit_distance + + return match_res + + +def eval_ocr_metric(pred_texts, gt_texts): + """Evaluate the text recognition performance with metric: word accuracy and + 1-N.E.D. See https://rrc.cvc.uab.es/?ch=14&com=tasks for details. + + Args: + pred_texts (list[str]): Text strings of prediction. + gt_texts (list[str]): Text strings of ground truth. + + Returns: + eval_res (dict[str: float]): Metric dict for text recognition, include: + - word_acc: Accuracy in word level. + - word_acc_ignore_case: Accuracy in word level, ignore letter case. 
+ - word_acc_ignore_case_symbol: Accuracy in word level, ignore + letter case and symbol. (default metric for + academic evaluation) + - char_recall: Recall in character level, ignore + letter case and symbol. + - char_precision: Precision in character level, ignore + letter case and symbol. + - 1-N.E.D: 1 - normalized_edit_distance. + """ + assert isinstance(pred_texts, list) + assert isinstance(gt_texts, list) + assert len(pred_texts) == len(gt_texts) + + match_res = count_matches(pred_texts, gt_texts) + eps = 1e-8 + char_recall = 1.0 * match_res['true_positive_char_num'] / ( + eps + match_res['gt_char_num']) + char_precision = 1.0 * match_res['true_positive_char_num'] / ( + eps + match_res['pred_char_num']) + word_acc = 1.0 * match_res['match_word_num'] / ( + eps + match_res['gt_word_num']) + word_acc_ignore_case = 1.0 * match_res['match_word_ignore_case'] / ( + eps + match_res['gt_word_num']) + word_acc_ignore_case_symbol = 1.0 * match_res[ + 'match_word_ignore_case_symbol'] / ( + eps + match_res['gt_word_num']) + + eval_res = {} + eval_res['word_acc'] = word_acc + eval_res['word_acc_ignore_case'] = word_acc_ignore_case + eval_res['word_acc_ignore_case_symbol'] = word_acc_ignore_case_symbol + eval_res['char_recall'] = char_recall + eval_res['char_precision'] = char_precision + eval_res['1-N.E.D'] = 1.0 - match_res['ned'] + + for key, value in eval_res.items(): + eval_res[key] = float('{:.4f}'.format(value)) + + return eval_res diff --git a/mmocr/core/evaluation/utils.py b/mmocr/core/evaluation/utils.py new file mode 100644 index 00000000..8569fde8 --- /dev/null +++ b/mmocr/core/evaluation/utils.py @@ -0,0 +1,496 @@ +import numpy as np +import Polygon as plg + +import mmocr.utils as utils + + +def ignore_pred(pred_boxes, gt_ignored_index, gt_polys, precision_thr): + """Ignore the predicted box if it hits any ignored ground truth. + + Args: + pred_boxes (list[ndarray or list]): The predicted boxes of one image. + gt_ignored_index (list[int]): The ignored ground truth index list. + gt_polys (list[Polygon]): The polygon list of one image. + precision_thr (float): The precision threshold. + + Returns: + pred_polys (list[Polygon]): The predicted polygon list. + pred_points (list[list]): The predicted box list represented + by point sequences. + pred_ignored_index (list[int]): The ignored text index list. + """ + + assert isinstance(pred_boxes, list) + assert isinstance(gt_ignored_index, list) + assert isinstance(gt_polys, list) + assert 0 <= precision_thr <= 1 + + pred_polys = [] + pred_points = [] + pred_ignored_index = [] + + gt_ignored_num = len(gt_ignored_index) + # get detection polygons + for box_id, box in enumerate(pred_boxes): + poly = points2polygon(box) + pred_polys.append(poly) + pred_points.append(box) + + if gt_ignored_num < 1: + continue + + # ignore the current detection box + # if its overlap with any ignored gt > precision_thr + for ignored_box_id in gt_ignored_index: + ignored_box = gt_polys[ignored_box_id] + inter_area, _ = poly_intersection(poly, ignored_box) + area = poly.area() + precision = 0 if area == 0 else inter_area / area + if precision > precision_thr: + pred_ignored_index.append(box_id) + break + + return pred_polys, pred_points, pred_ignored_index + + +def compute_hmean(accum_hit_recall, accum_hit_prec, gt_num, pred_num): + """Compute hmean given hit number, ground truth number and prediction + number. + + Args: + accum_hit_recall (int|float): Accumulated hits for computing recall. + accum_hit_prec (int|float): Accumulated hits for computing precision. 
+ gt_num (int): Ground truth number. + pred_num (int): Prediction number. + + Returns: + recall (float): The recall value. + precision (float): The precision value. + hmean (float): The hmean value. + """ + + assert isinstance(accum_hit_recall, (float, int)) + assert isinstance(accum_hit_prec, (float, int)) + + assert isinstance(gt_num, int) + assert isinstance(pred_num, int) + assert accum_hit_recall >= 0.0 + assert accum_hit_prec >= 0.0 + assert gt_num >= 0.0 + assert pred_num >= 0.0 + + if gt_num == 0: + recall = 1.0 + precision = 0.0 if pred_num > 0 else 1.0 + else: + recall = float(accum_hit_recall) / gt_num + precision = 0.0 if pred_num == 0 else float(accum_hit_prec) / pred_num + + denom = recall + precision + + hmean = 0.0 if denom == 0 else (2.0 * precision * recall / denom) + + return recall, precision, hmean + + +def box2polygon(box): + """Convert box to polygon. + + Args: + box (ndarray or list): A ndarray or a list of shape (4) + that indicates 2 points. + + Returns: + polygon (Polygon): A polygon object. + """ + if isinstance(box, list): + box = np.array(box) + + assert isinstance(box, np.ndarray) + assert box.size == 4 + boundary = np.array( + [box[0], box[1], box[2], box[1], box[2], box[3], box[0], box[3]]) + + point_mat = boundary.reshape([-1, 2]) + return plg.Polygon(point_mat) + + +def points2polygon(points): + """Convert k points to 1 polygon. + + Args: + points (ndarray or list): A ndarray or a list of shape (2k) + that indicates k points. + + Returns: + polygon (Polygon): A polygon object. + """ + if isinstance(points, list): + points = np.array(points) + + assert isinstance(points, np.ndarray) + assert (points.size % 2 == 0) and (points.size >= 8) + + point_mat = points.reshape([-1, 2]) + return plg.Polygon(point_mat) + + +def poly_intersection(poly_det, poly_gt): + """Calculate the intersection area between two polygon. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + + Returns: + intersection_area (float): The intersection area between two polygons. + """ + assert isinstance(poly_det, plg.Polygon) + assert isinstance(poly_gt, plg.Polygon) + + poly_inter = poly_det & poly_gt + if len(poly_inter) == 0: + return 0, poly_inter + return poly_inter.area(), poly_inter + + +def poly_union(poly_det, poly_gt): + """Calculate the union area between two polygon. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + + Returns: + union_area (float): The union area between two polygons. + """ + assert isinstance(poly_det, plg.Polygon) + assert isinstance(poly_gt, plg.Polygon) + + area_det = poly_det.area() + area_gt = poly_gt.area() + area_inters, _ = poly_intersection(poly_det, poly_gt) + return area_det + area_gt - area_inters + + +def boundary_iou(src, target): + """Calculate the IOU between two boundaries. + + Args: + src (list): Source boundary. + target (list): Target boundary. + + Returns: + iou (float): The iou between two boundaries. + """ + assert utils.valid_boundary(src, False) + assert utils.valid_boundary(target, False) + src_poly = points2polygon(src) + target_poly = points2polygon(target) + + return poly_iou(src_poly, target_poly) + + +def poly_iou(poly_det, poly_gt): + """Calculate the IOU between two polygons. + + Args: + poly_det (Polygon): A polygon predicted by detector. + poly_gt (Polygon): A gt polygon. + + Returns: + iou (float): The IOU between two polygons. 
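+
+    Note: if both polygons are degenerate (zero union area), the division
+    raises ZeroDivisionError, so callers should pass valid polygons.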
+ """ + assert isinstance(poly_det, plg.Polygon) + assert isinstance(poly_gt, plg.Polygon) + area_inters, _ = poly_intersection(poly_det, poly_gt) + + return area_inters / poly_union(poly_det, poly_gt) + + +def one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, recall_thr, + precision_thr): + """One-to-One match gt and det with icdar2013 standards. + + Args: + gt_id (int): The ground truth id index. + det_id (int): The detection result id index. + recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j) + being the recall ratio of gt i to det j. + precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j) + being the precision ratio of gt i to det j. + recall_thr (float): The recall threshold. + precision_thr (float): The precision threshold. + Returns: + True|False: Whether the gt and det are matched. + """ + assert isinstance(gt_id, int) + assert isinstance(det_id, int) + assert isinstance(recall_mat, np.ndarray) + assert isinstance(precision_mat, np.ndarray) + assert 0 <= recall_thr <= 1 + assert 0 <= precision_thr <= 1 + + cont = 0 + for i in range(recall_mat.shape[1]): + if recall_mat[gt_id, + i] > recall_thr and precision_mat[gt_id, + i] > precision_thr: + cont += 1 + if cont != 1: + return False + + cont = 0 + for i in range(recall_mat.shape[0]): + if recall_mat[i, det_id] > recall_thr and precision_mat[ + i, det_id] > precision_thr: + cont += 1 + if cont != 1: + return False + + if recall_mat[gt_id, det_id] > recall_thr and precision_mat[ + gt_id, det_id] > precision_thr: + return True + + return False + + +def one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr, + precision_thr, gt_match_flag, det_match_flag, + det_ignored_index): + """One-to-Many match gt and detections with icdar2013 standards. + + Args: + gt_id (int): gt index. + recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j) + being the recall ratio of gt i to det j. + precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j) + being the precision ratio of gt i to det j. + recall_thr (float): The recall threshold. + precision_thr (float): The precision threshold. + gt_match_flag (ndarray): An array indicates each gt matched already. + det_match_flag (ndarray): An array indicates each box has been + matched already or not. + det_ignored_index (list): A list indicates each detection box can be + ignored or not. + + Returns: + tuple (True|False, list): The first indicates the gt is matched or not; + the second is the matched detection ids. + """ + assert isinstance(gt_id, int) + assert isinstance(recall_mat, np.ndarray) + assert isinstance(precision_mat, np.ndarray) + assert 0 <= recall_thr <= 1 + assert 0 <= precision_thr <= 1 + + assert isinstance(gt_match_flag, list) + assert isinstance(det_match_flag, list) + assert isinstance(det_ignored_index, list) + + many_sum = 0. + det_ids = [] + for det_id in range(recall_mat.shape[1]): + if gt_match_flag[gt_id] == 0 and det_match_flag[ + det_id] == 0 and det_id not in det_ignored_index: + if precision_mat[gt_id, det_id] >= precision_thr: + many_sum += recall_mat[gt_id, det_id] + det_ids.append(det_id) + if many_sum >= recall_thr: + return True, det_ids + return False, [] + + +def many2one_match_ic13(det_id, recall_mat, precision_mat, recall_thr, + precision_thr, gt_match_flag, det_match_flag, + gt_ignored_index): + """Many-to-One match gt and detections with icdar2013 standards. + + Args: + det_id (int): Detection index. 
+ recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j) + being the recall ratio of gt i to det j. + precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j) + being the precision ratio of gt i to det j. + recall_thr (float): The recall threshold. + precision_thr (float): The precision threshold. + gt_match_flag (ndarray): An array indicates each gt has been matched + already. + det_match_flag (ndarray): An array indicates each detection box has + been matched already or not. + gt_ignored_index (list): A list indicates each gt box can be ignored + or not. + + Returns: + tuple (True|False, list): The first indicates the detection is matched + or not; the second is the matched gt ids. + """ + assert isinstance(det_id, int) + assert isinstance(recall_mat, np.ndarray) + assert isinstance(precision_mat, np.ndarray) + assert 0 <= recall_thr <= 1 + assert 0 <= precision_thr <= 1 + + assert isinstance(gt_match_flag, list) + assert isinstance(det_match_flag, list) + assert isinstance(gt_ignored_index, list) + many_sum = 0. + gt_ids = [] + for gt_id in range(recall_mat.shape[0]): + if gt_match_flag[gt_id] == 0 and det_match_flag[ + det_id] == 0 and gt_id not in gt_ignored_index: + if recall_mat[gt_id, det_id] >= recall_thr: + many_sum += precision_mat[gt_id, det_id] + gt_ids.append(gt_id) + if many_sum >= precision_thr: + return True, gt_ids + return False, [] + + +def points_center(points): + + assert isinstance(points, np.ndarray) + assert points.size % 2 == 0 + + points = points.reshape([-1, 2]) + return np.mean(points, axis=0) + + +def point_distance(p1, p2): + assert isinstance(p1, np.ndarray) + assert isinstance(p2, np.ndarray) + + assert p1.size == 2 + assert p2.size == 2 + + dist = np.square(p2 - p1) + dist = np.sum(dist) + dist = np.sqrt(dist) + return dist + + +def box_center_distance(b1, b2): + assert isinstance(b1, np.ndarray) + assert isinstance(b2, np.ndarray) + return point_distance(points_center(b1), points_center(b2)) + + +def box_diag(box): + assert isinstance(box, np.ndarray) + assert box.size == 8 + + return point_distance(box[0:2], box[4:6]) + + +def filter_2dlist_result(results, scores, score_thr): + """Find out detected results whose score > score_thr. + + Args: + results (list[list[float]]): The result list. + score (list): The score list. + score_thr (float): The score threshold. + Returns: + valid_results (list[list[float]]): The valid results. + valid_score (list[float]): The scores which correspond to the valid + results. + """ + assert isinstance(results, list) + assert len(results) == len(scores) + assert isinstance(score_thr, float) + assert 0 <= score_thr <= 1 + + inds = np.array(scores) > score_thr + valid_results = [results[inx] for inx in np.where(inds)[0].tolist()] + valid_scores = [scores[inx] for inx in np.where(inds)[0].tolist()] + return valid_results, valid_scores + + +def filter_result(results, scores, score_thr): + """Find out detected results whose score > score_thr. + + Args: + results (ndarray): The results matrix of shape (n, k). + score (ndarray): The score vector of shape (n,). + score_thr (float): The score threshold. + Returns: + valid_results (ndarray): The valid results of shape (m,k) with m<=n. + valid_score (ndarray): The scores which correspond to the + valid results. 
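+
+    Example (illustrative):
+        >>> results = np.random.rand(5, 8)
+        >>> scores = np.array([0.9, 0.2, 0.7, 0.1, 0.8])
+        >>> filter_result(results, scores, 0.5)[0].shape
+        (3, 8)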
+ """ + assert results.ndim == 2 + assert scores.shape[0] == results.shape[0] + assert isinstance(score_thr, float) + assert 0 <= score_thr <= 1 + + inds = scores > score_thr + valid_results = results[inds, :] + valid_scores = scores[inds] + return valid_results, valid_scores + + +def select_top_boundary(boundaries_list, scores_list, score_thr): + """Select poly boundaries with scores >= score_thr. + + Args: + boundaries_list (list[list[list[float]]]): List of boundaries. + The 1st, 2rd, and 3rd indices are for image, text and + vertice, respectively. + scores_list (list(list[float])): List of lists of scores. + score_thr (float): The score threshold to filter out bboxes. + + Returns: + selected_bboxes (list[list[list[float]]]): List of boundaries. + The 1st, 2rd, and 3rd indices are for image, text and vertice, + respectively. + """ + assert isinstance(boundaries_list, list) + assert isinstance(scores_list, list) + assert isinstance(score_thr, float) + assert len(boundaries_list) == len(scores_list) + assert 0 <= score_thr <= 1 + + selected_boundaries = [] + for boundary, scores in zip(boundaries_list, scores_list): + if len(scores) > 0: + assert len(scores) == len(boundary) + inds = [ + iter for iter in range(len(scores)) + if scores[iter] >= score_thr + ] + selected_boundaries.append([boundary[i] for i in inds]) + else: + selected_boundaries.append(boundary) + return selected_boundaries + + +def select_bboxes_via_score(bboxes_list, scores_list, score_thr): + """Select bboxes with scores >= score_thr. + + Args: + bboxes_list (list[ndarray]): List of bboxes. Each element is ndarray of + shape (n,8) + scores_list (list(list[float])): List of lists of scores. + score_thr (float): The score threshold to filter out bboxes. + + Returns: + selected_bboxes (list[ndarray]): List of bboxes. Each element is + ndarray of shape (m,8) with m<=n. + """ + assert isinstance(bboxes_list, list) + assert isinstance(scores_list, list) + assert isinstance(score_thr, float) + assert len(bboxes_list) == len(scores_list) + assert 0 <= score_thr <= 1 + + selected_bboxes = [] + for bboxes, scores in zip(bboxes_list, scores_list): + if len(scores) > 0: + assert len(scores) == bboxes.shape[0] + inds = [ + iter for iter in range(len(scores)) + if scores[iter] >= score_thr + ] + selected_bboxes.append(bboxes[inds, :]) + else: + selected_bboxes.append(bboxes) + return selected_bboxes diff --git a/mmocr/core/mask.py b/mmocr/core/mask.py new file mode 100644 index 00000000..c9b46d19 --- /dev/null +++ b/mmocr/core/mask.py @@ -0,0 +1,101 @@ +import cv2 +import numpy as np + +import mmocr.utils as utils + + +def points2boundary(points, text_repr_type, text_score=None, min_width=-1): + """Convert a text mask represented by point coordinates sequence into a + text boundary. + + Args: + points (ndarray): Mask index of size (n, 2). + text_repr_type (str): Text instance encoding type + ('quad' for quadrangle or 'poly' for polygon). + text_score (float): Text score. + + Returns: + boundary (list[float]): The text boundary point coordinates (x, y) + list. Return None if no text boundary found. 
+ """ + assert isinstance(points, np.ndarray) + assert points.shape[1] == 2 + assert text_repr_type in ['quad', 'poly'] + assert text_score is None or 0 <= text_score <= 1 + + if text_repr_type == 'quad': + rect = cv2.minAreaRect(points) + vertices = cv2.boxPoints(rect) + boundary = [] + if min(rect[1]) > min_width: + boundary = [p for p in vertices.flatten().tolist()] + + elif text_repr_type == 'poly': + + height = np.max(points[:, 1]) + 10 + width = np.max(points[:, 0]) + 10 + + mask = np.zeros((height, width), np.uint8) + mask[points[:, 1], points[:, 0]] = 255 + + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_SIMPLE) + boundary = list(contours[0].flatten().tolist()) + + if text_score is not None: + boundary = boundary + [text_score] + if len(boundary) < 8: + return None + + return boundary + + +def seg2boundary(seg, text_repr_type, text_score=None): + """Convert a segmentation mask to a text boundary. + + Args: + seg (ndarray): The segmentation mask. + text_repr_type (str): Text instance encoding type + ('quad' for quadrangle or 'poly' for polygon). + text_score (float): The text score. + + Returns: + boundary (list): The text boundary. Return None if no text found. + """ + assert isinstance(seg, np.ndarray) + assert isinstance(text_repr_type, str) + assert text_score is None or 0 <= text_score <= 1 + + points = np.where(seg) + # x, y order + points = np.concatenate([points[1], points[0]]).reshape(2, -1).transpose() + boundary = None + if len(points) != 0: + boundary = points2boundary(points, text_repr_type, text_score) + + return boundary + + +def extract_boundary(result): + """Extract boundaries and their scores from result. + + Args: + result (dict): The detection result with the key 'boundary_result' + of one image. + + Returns: + boundaries_with_scores (list[list[float]]): The boundary and score + list. + boundaries (list[list[float]]): The boundary list. + scores (list[float]): The boundary score list. + """ + assert isinstance(result, dict) + assert 'boundary_result' in result.keys() + + boundaries_with_scores = result['boundary_result'] + assert utils.is_2dlist(boundaries_with_scores) + + boundaries = [b[:-1] for b in boundaries_with_scores] + scores = [b[-1] for b in boundaries_with_scores] + + return (boundaries_with_scores, boundaries, scores) diff --git a/mmocr/core/visualize.py b/mmocr/core/visualize.py new file mode 100644 index 00000000..f1965970 --- /dev/null +++ b/mmocr/core/visualize.py @@ -0,0 +1,419 @@ +import math +import warnings + +import cv2 +import mmcv +import numpy as np +import torch +from matplotlib import pyplot as plt + +import mmocr.utils as utils + + +def overlay_mask_img(img, mask): + """Draw mask boundaries on image for visualization. + + Args: + img (ndarray): The input image. + mask (ndarray): The instance mask. + + Returns: + img (ndarray): The output image with instance boundaries on it. + """ + assert isinstance(img, np.ndarray) + assert isinstance(mask, np.ndarray) + + contours, _ = cv2.findContours( + mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + cv2.drawContours(img, contours, -1, (0, 255, 0), 1) + + return img + + +def show_feature(features, names, to_uint8, out_file=None): + """Visualize a list of feature maps. + + Args: + features (list(ndarray)): The feature map list. + names (list(str)): The visualized title list. + to_uint8 (list(1|0)): The list indicating whether to convent + feature maps to uint8. + out_file (str): The output file name. 
If set to None,
+            the output image will be shown without saving.
+    """
+    assert utils.is_ndarray_list(features)
+    assert utils.is_type_list(names, str)
+    assert utils.is_type_list(to_uint8, int)
+    assert utils.is_none_or_type(out_file, str)
+    assert utils.equal_len(features, names, to_uint8)
+
+    num = len(features)
+    row = col = math.ceil(math.sqrt(num))
+
+    for i, (f, n) in enumerate(zip(features, names)):
+        plt.subplot(row, col, i + 1)
+        plt.title(n)
+        if to_uint8[i]:
+            f = f.astype(np.uint8)
+        plt.imshow(f)
+    if out_file is None:
+        plt.show()
+    else:
+        plt.savefig(out_file)
+
+
+def show_img_boundary(img, boundary):
+    """Show image and instance boundaries.
+
+    Args:
+        img (ndarray): The input image.
+        boundary (list[float or int]): The input boundary.
+    """
+    assert isinstance(img, np.ndarray)
+    assert utils.is_type_list(boundary, int) or utils.is_type_list(
+        boundary, float)
+
+    cv2.polylines(
+        img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)],
+        True,
+        color=(0, 255, 0),
+        thickness=1)
+    plt.imshow(img)
+    plt.show()
+
+
+def show_pred_gt(preds,
+                 gts,
+                 show=False,
+                 win_name='',
+                 wait_time=0,
+                 out_file=None):
+    """Show detection and ground truth for one image.
+
+    Args:
+        preds (list[list[float]]): The detection boundary list.
+        gts (list[list[float]]): The ground truth boundary list.
+        show (bool): Whether to show the image.
+        win_name (str): The window name.
+        wait_time (int): The value of waitKey param.
+        out_file (str): The filename of the output.
+    """
+    assert utils.is_2dlist(preds)
+    assert utils.is_2dlist(gts)
+    assert isinstance(show, bool)
+    assert isinstance(win_name, str)
+    assert isinstance(wait_time, int)
+    assert utils.is_none_or_type(out_file, str)
+
+    p_xy = [p for boundary in preds for p in boundary]
+    gt_xy = [g for gt in gts for g in gt]
+
+    max_xy = np.max(np.array(p_xy + gt_xy).reshape(-1, 2), axis=0)
+
+    width = int(max_xy[0]) + 100
+    height = int(max_xy[1]) + 100
+
+    # uint8 so the canvas is actually white (255 overflows int8)
+    img = np.ones((height, width, 3), np.uint8) * 255
+    pred_color = mmcv.color_val('red')
+    gt_color = mmcv.color_val('blue')
+    thickness = 1
+
+    for boundary in preds:
+        cv2.polylines(
+            img, [np.array(boundary).astype(np.int32).reshape(-1, 1, 2)],
+            True,
+            color=pred_color,
+            thickness=thickness)
+    for gt in gts:
+        cv2.polylines(
+            img, [np.array(gt).astype(np.int32).reshape(-1, 1, 2)],
+            True,
+            color=gt_color,
+            thickness=thickness)
+    if show:
+        mmcv.imshow(img, win_name, wait_time)
+    if out_file is not None:
+        mmcv.imwrite(img, out_file)
+
+    return img
+
+
+def imshow_pred_boundary(img,
+                         boundaries_with_scores,
+                         labels,
+                         score_thr=0,
+                         boundary_color='blue',
+                         text_color='blue',
+                         thickness=1,
+                         font_scale=0.5,
+                         show=True,
+                         win_name='',
+                         wait_time=0,
+                         out_file=None,
+                         show_score=False):
+    """Draw boundaries and class labels (with scores) on an image.
+
+    Args:
+        img (str or ndarray): The image to be displayed.
+        boundaries_with_scores (list[list[float]]): Boundaries with scores.
+        labels (list[int]): Labels of boundaries.
+        score_thr (float): Minimum score of boundaries to be shown.
+        boundary_color (str or tuple or :obj:`Color`): Color of boundaries.
+        text_color (str or tuple or :obj:`Color`): Color of texts.
+        thickness (int): Thickness of lines.
+        font_scale (float): Font scales of texts.
+        show (bool): Whether to show the image.
+        win_name (str): The window name.
+        wait_time (int): Value of waitKey param.
+        out_file (str or None): The filename of the output.
+        show_score (bool): Whether to show text instance score.
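+
+    Returns:
+        img (ndarray): The image with filtered boundaries (and, optionally,
+            their scores) drawn on it.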
+ """ + assert isinstance(img, (str, np.ndarray)) + assert utils.is_2dlist(boundaries_with_scores) + assert utils.is_type_list(labels, int) + assert utils.equal_len(boundaries_with_scores, labels) + if len(boundaries_with_scores) == 0: + warnings.warn('0 text found in ' + out_file) + return + + utils.valid_boundary(boundaries_with_scores[0]) + img = mmcv.imread(img) + + scores = np.array([b[-1] for b in boundaries_with_scores]) + inds = scores > score_thr + boundaries = [boundaries_with_scores[i][:-1] for i in np.where(inds)[0]] + scores = [scores[i] for i in np.where(inds)[0]] + labels = [labels[i] for i in np.where(inds)[0]] + + boundary_color = mmcv.color_val(boundary_color) + text_color = mmcv.color_val(text_color) + font_scale = 0.5 + + for boundary, score, label in zip(boundaries, scores, labels): + boundary_int = np.array(boundary).astype(np.int32) + + cv2.polylines( + img, [boundary_int.reshape(-1, 1, 2)], + True, + color=boundary_color, + thickness=thickness) + + if show_score: + label_text = f'{score:.02f}' + cv2.putText(img, label_text, + (boundary_int[0], boundary_int[1] - 2), + cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color) + if show: + mmcv.imshow(img, win_name, wait_time) + if out_file is not None: + mmcv.imwrite(img, out_file) + + return img + + +def imshow_text_char_boundary(img, + text_quads, + boundaries, + char_quads, + chars, + show=False, + thickness=1, + font_scale=0.5, + win_name='', + wait_time=-1, + out_file=None): + """Draw text boxes and char boxes on img. + + Args: + img (str or ndarray): The img to be displayed. + text_quads (list[list[int|float]]): The text boxes. + boundaries (list[list[int|float]]): The boundary list. + char_quads (list[list[list[int|float]]]): A 2d list of char boxes. + char_quads[i] is for the ith text, and char_quads[i][j] is the jth + char of the ith text. + chars (list[list[char]]). The string for each text box. + thickness (int): Thickness of lines. + font_scale (float): Font scales of texts. + show (bool): Whether to show the image. + win_name (str): The window name. + wait_time (int): Value of waitKey param. + out_file (str or None): The filename of the output. 
+ """ + assert isinstance(img, (np.ndarray, str)) + assert utils.is_2dlist(text_quads) + assert utils.is_2dlist(boundaries) + assert utils.is_3dlist(char_quads) + assert utils.is_2dlist(chars) + assert utils.equal_len(text_quads, char_quads, boundaries) + + img = mmcv.imread(img) + char_color = [mmcv.color_val('blue'), mmcv.color_val('green')] + text_color = mmcv.color_val('red') + text_inx = 0 + for text_box, boundary, char_box, txt in zip(text_quads, boundaries, + char_quads, chars): + text_box = np.array(text_box) + boundary = np.array(boundary) + + text_box = text_box.reshape(-1, 2).astype(np.int32) + cv2.polylines( + img, [text_box.reshape(-1, 1, 2)], + True, + color=text_color, + thickness=thickness) + if boundary.shape[0] > 0: + cv2.polylines( + img, [boundary.reshape(-1, 1, 2)], + True, + color=text_color, + thickness=thickness) + + for b in char_box: + b = np.array(b) + c = char_color[text_inx % 2] + b = b.astype(np.int32) + cv2.polylines( + img, [b.reshape(-1, 1, 2)], True, color=c, thickness=thickness) + + label_text = ''.join(txt) + cv2.putText(img, label_text, (text_box[0, 0], text_box[0, 1] - 2), + cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color) + text_inx = text_inx + 1 + + if show: + mmcv.imshow(img, win_name, wait_time) + if out_file is not None: + mmcv.imwrite(img, out_file) + + return img + + +def tile_image(images): + """Combined multiple images to one vertically. + + Args: + images (list[np.ndarray]): Images to be combined. + """ + assert isinstance(images, list) + assert len(images) > 0 + + for i, _ in enumerate(images): + if len(images[i].shape) == 2: + images[i] = cv2.cvtColor(images[i], cv2.COLOR_GRAY2BGR) + + widths = [img.shape[1] for img in images] + heights = [img.shape[0] for img in images] + h, w = sum(heights), max(widths) + vis_img = np.zeros((h, w, 3), dtype=np.uint8) + + offset_y = 0 + for image in images: + img_h, img_w = image.shape[:2] + vis_img[offset_y:(offset_y + img_h), 0:img_w, :] = image + offset_y += img_h + + return vis_img + + +def imshow_text_label(img, + pred_label, + gt_label, + show=False, + win_name='', + wait_time=-1, + out_file=None): + """Draw predicted texts and ground truth texts on images. + + Args: + img (str or np.ndarray): Image filename or loaded image. + pred_label (str): Predicted texts. + gt_label (str): Ground truth texts. + show (bool): Whether to show the image. + win_name (str): The window name. + wait_time (int): Value of waitKey param. + out_file (str): The filename of the output. 
+ """ + assert isinstance(img, (np.ndarray, str)) + assert isinstance(pred_label, str) + assert isinstance(gt_label, str) + assert isinstance(show, bool) + assert isinstance(win_name, str) + assert isinstance(wait_time, int) + + img = mmcv.imread(img) + + src_h, src_w = img.shape[:2] + resize_height = 64 + resize_width = int(1.0 * src_w / src_h * resize_height) + img = cv2.resize(img, (resize_width, resize_height)) + h, w = img.shape[:2] + pred_img = np.ones((h, w, 3), dtype=np.uint8) * 255 + gt_img = np.ones((h, w, 3), dtype=np.uint8) * 255 + + cv2.putText(pred_img, pred_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9, + (0, 0, 255), 2) + images = [pred_img, img] + + if gt_label != '': + cv2.putText(gt_img, gt_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9, + (255, 0, 0), 2) + images.append(gt_img) + + img = tile_image(images) + + if show: + mmcv.imshow(img, win_name, wait_time) + if out_file is not None: + mmcv.imwrite(img, out_file) + + return img + + +def imshow_edge_node(img, + result, + boxes, + idx_to_cls={}, + show=False, + win_name='', + wait_time=-1, + out_file=None): + + img = mmcv.imread(img) + h, w = img.shape[:2] + + pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255 + max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1) + node_pred_label = max_idx.numpy().tolist() + node_pred_score = max_value.numpy().tolist() + + for i, box in enumerate(boxes): + new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]], + [box[0], box[3]]] + Pts = np.array([new_box], np.int32) + cv2.polylines( + img, [Pts.reshape((-1, 1, 2))], + True, + color=(255, 255, 0), + thickness=1) + x_min = int(min([point[0] for point in new_box])) + y_min = int(min([point[1] for point in new_box])) + + pred_label = str(node_pred_label[i]) + if pred_label in idx_to_cls: + pred_label = idx_to_cls[pred_label] + pred_score = '{:.2f}'.format(node_pred_score[i]) + text = pred_label + '(' + pred_score + ')' + cv2.putText(pred_img, text, (x_min * 2, y_min), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1) + + vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 + vis_img[:, :w] = img + vis_img[:, w:] = pred_img + + if show: + mmcv.imshow(vis_img, win_name, wait_time) + if out_file is not None: + mmcv.imwrite(vis_img, out_file) + + return vis_img diff --git a/mmocr/datasets/__init__.py b/mmocr/datasets/__init__.py new file mode 100644 index 00000000..1dbba21c --- /dev/null +++ b/mmocr/datasets/__init__.py @@ -0,0 +1,15 @@ +from mmdet.datasets.builder import DATASETS, build_dataloader, build_dataset +from .base_dataset import BaseDataset +from .icdar_dataset import IcdarDataset +from .kie_dataset import KIEDataset +from .ocr_dataset import OCRDataset +from .ocr_seg_dataset import OCRSegDataset +from .pipelines import CustomFormatBundle, DBNetTargets +from .text_det_dataset import TextDetDataset +from .utils import * # noqa: F401,F403 + +__all__ = [ + 'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset', + 'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle', + 'DBNetTargets', 'OCRSegDataset', 'KIEDataset' +] diff --git a/mmocr/datasets/base_dataset.py b/mmocr/datasets/base_dataset.py new file mode 100644 index 00000000..f5c8f9d8 --- /dev/null +++ b/mmocr/datasets/base_dataset.py @@ -0,0 +1,166 @@ +import numpy as np +from mmcv.utils import print_log +from torch.utils.data import Dataset + +from mmdet.datasets.builder import DATASETS +from mmdet.datasets.pipelines import Compose +from mmocr.datasets.builder import build_loader + + +@DATASETS.register_module() +class 
BaseDataset(Dataset): + """Custom dataset for text detection, text recognition, and their + downstream tasks. + + 1. The text detection annotation format is as follows: + The `annotations` field is optional for testing + (this is one line of anno_file, with line-json-str + converted to dict for visualizing only). + + { + "file_name": "sample.jpg", + "height": 1080, + "width": 960, + "annotations": + [ + { + "iscrowd": 0, + "category_id": 1, + "bbox": [357.0, 667.0, 804.0, 100.0], + "segmentation": [[361, 667, 710, 670, + 72, 767, 357, 763]] + } + ] + } + + 2. The two text recognition annotation formats are as follows: + The `x1,y1,x2,y2,x3,y3,x4,y4` field is used for online crop + augmentation during training. + + format1: sample.jpg hello + format2: sample.jpg 20 20 100 20 100 40 20 40 hello + + Args: + ann_file (str): Annotation file path. + pipeline (list[dict]): Processing pipeline. + loader (dict): Dictionary to construct loader + to load annotation infos. + img_prefix (str, optional): Image prefix to generate full + image path. + test_mode (bool, optional): If set True, try...except will + be turned off in __getitem__. + """ + + def __init__(self, + ann_file, + loader, + pipeline, + img_prefix='', + test_mode=False): + super().__init__() + self.test_mode = test_mode + self.img_prefix = img_prefix + self.ann_file = ann_file + # load annotations + loader.update(ann_file=ann_file) + self.data_infos = build_loader(loader) + # processing pipeline + self.pipeline = Compose(pipeline) + # set group flag and class, no meaning + # for text detect and recognize + self._set_group_flag() + self.CLASSES = 0 + + def __len__(self): + return len(self.data_infos) + + def _set_group_flag(self): + """Set flag.""" + self.flag = np.zeros(len(self), dtype=np.uint8) + + def pre_pipeline(self, results): + """Prepare results dict for pipeline.""" + results['img_prefix'] = self.img_prefix + + def prepare_train_img(self, index): + """Get training data and annotations from pipeline. + + Args: + index (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys + introduced by pipeline. + """ + img_info = self.data_infos[index] + results = dict(img_info=img_info) + self.pre_pipeline(results) + return self.pipeline(results) + + def prepare_test_img(self, img_info): + """Get testing data from pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Testing data after pipeline with new keys introduced by + pipeline. + """ + return self.prepare_train_img(img_info) + + def _log_error_index(self, index): + """Logging data info of bad index.""" + try: + data_info = self.data_infos[index] + img_prefix = self.img_prefix + print_log(f'Warning: skip broken file {data_info} ' + f'with img_prefix {img_prefix}') + except Exception as e: + print_log(f'load index {index} with error {e}') + + def _get_next_index(self, index): + """Get next index from dataset.""" + self._log_error_index(index) + index = (index + 1) % len(self) + return index + + def __getitem__(self, index): + """Get training/test data from pipeline. + + Args: + index (int): Index of data. + + Returns: + dict: Training/test data. 
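+
+        In training mode, a sample that fails to load is skipped and the
+        next index is tried until one loads successfully.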
+ """ + if self.test_mode: + return self.prepare_test_img(index) + + while True: + try: + data = self.prepare_train_img(index) + if data is None: + raise Exception('prepared train data empty') + break + except Exception as e: + print_log(f'prepare index {index} with error {e}') + index = self._get_next_index(index) + return data + + def format_results(self, results, **kwargs): + """Placeholder to format result to dataset-specific output.""" + pass + + def evaluate(self, results, metric=None, logger=None, **kwargs): + """Evaluate the dataset. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + Returns: + dict[str: float] + """ + raise NotImplementedError diff --git a/mmocr/datasets/builder.py b/mmocr/datasets/builder.py new file mode 100644 index 00000000..e7bcf423 --- /dev/null +++ b/mmocr/datasets/builder.py @@ -0,0 +1,14 @@ +from mmcv.utils import Registry, build_from_cfg + +LOADERS = Registry('loader') +PARSERS = Registry('parser') + + +def build_loader(cfg): + """Build anno file loader.""" + return build_from_cfg(cfg, LOADERS) + + +def build_parser(cfg): + """Build anno file parser.""" + return build_from_cfg(cfg, PARSERS) diff --git a/mmocr/datasets/icdar_dataset.py b/mmocr/datasets/icdar_dataset.py new file mode 100644 index 00000000..34a97584 --- /dev/null +++ b/mmocr/datasets/icdar_dataset.py @@ -0,0 +1,158 @@ +import numpy as np +from pycocotools.coco import COCO + +import mmocr.utils as utils +from mmdet.datasets.builder import DATASETS +from mmdet.datasets.coco import CocoDataset +from mmocr.core.evaluation.hmean import eval_hmean + + +@DATASETS.register_module() +class IcdarDataset(CocoDataset): + CLASSES = ('text') + + def __init__(self, + ann_file, + pipeline, + classes=None, + data_root=None, + img_prefix='', + seg_prefix=None, + proposal_file=None, + test_mode=False, + filter_empty_gt=True, + select_first_k=-1): + # select first k images for fast debugging. + self.select_first_k = select_first_k + + super().__init__(ann_file, pipeline, classes, data_root, img_prefix, + seg_prefix, proposal_file, test_mode, filter_empty_gt) + + def load_annotations(self, ann_file): + """Load annotation from COCO style annotation file. + + Args: + ann_file (str): Path of annotation file. + + Returns: + list[dict]: Annotation info from COCO api. + """ + + self.coco = COCO(ann_file) + self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES) + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.img_ids = self.coco.get_img_ids() + data_infos = [] + + count = 0 + for i in self.img_ids: + info = self.coco.load_imgs([i])[0] + info['filename'] = info['file_name'] + data_infos.append(info) + count = count + 1 + if count > self.select_first_k and self.select_first_k > 0: + break + return data_infos + + def _parse_ann_info(self, img_info, ann_info): + """Parse bbox and mask annotation. + + Args: + ann_info (list[dict]): Annotation info of an image. + + Returns: + dict: A dict containing the following keys: bboxes, bboxes_ignore, + labels, masks, masks_ignore, seg_map. "masks" and + "masks_ignore" are represented by polygon boundary + point sequences. 
+ """ + gt_bboxes = [] + gt_labels = [] + gt_bboxes_ignore = [] + gt_masks_ignore = [] + gt_masks_ann = [] + + for ann in ann_info: + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + if ann['area'] <= 0 or w < 1 or h < 1: + continue + if ann['category_id'] not in self.cat_ids: + continue + bbox = [x1, y1, x1 + w, y1 + h] + if ann.get('iscrowd', False): + gt_bboxes_ignore.append(bbox) + gt_masks_ignore.append(ann.get( + 'segmentation', None)) # to float32 for latter processing + + else: + gt_bboxes.append(bbox) + gt_labels.append(self.cat2label[ann['category_id']]) + gt_masks_ann.append(ann.get('segmentation', None)) + if gt_bboxes: + gt_bboxes = np.array(gt_bboxes, dtype=np.float32) + gt_labels = np.array(gt_labels, dtype=np.int64) + else: + gt_bboxes = np.zeros((0, 4), dtype=np.float32) + gt_labels = np.array([], dtype=np.int64) + + if gt_bboxes_ignore: + gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32) + else: + gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) + + seg_map = img_info['filename'].replace('jpg', 'png') + + ann = dict( + bboxes=gt_bboxes, + labels=gt_labels, + bboxes_ignore=gt_bboxes_ignore, + masks_ignore=gt_masks_ignore, + masks=gt_masks_ann, + seg_map=seg_map) + + return ann + + def evaluate(self, + results, + metric='hmean-iou', + logger=None, + score_thr=0.3, + rank_list=None, + **kwargs): + """Evaluate the hmean metric. + + Args: + results (list[dict]): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + rank_list (str): json file used to save eval result + of each image after ranking. + Returns: + dict[dict[str: float]]: The evaluation results. + """ + assert utils.is_type_list(results, dict) + + metrics = metric if isinstance(metric, list) else [metric] + allowed_metrics = ['hmean-iou', 'hmean-ic13'] + metrics = set(metrics) & set(allowed_metrics) + + img_infos = [] + ann_infos = [] + for i in range(len(self)): + img_info = {'filename': self.data_infos[i]['file_name']} + img_infos.append(img_info) + ann_infos.append(self.get_ann_info(i)) + + eval_results = eval_hmean( + results, + img_infos, + ann_infos, + metrics=metrics, + score_thr=score_thr, + logger=logger, + rank_list=rank_list) + + return eval_results diff --git a/mmocr/datasets/kie_dataset.py b/mmocr/datasets/kie_dataset.py new file mode 100644 index 00000000..79868a68 --- /dev/null +++ b/mmocr/datasets/kie_dataset.py @@ -0,0 +1,218 @@ +import copy +from os import path as osp + +import numpy as np +import torch + +import mmocr.utils as utils +from mmdet.datasets.builder import DATASETS +from mmocr.core import compute_f1_score +from mmocr.datasets.base_dataset import BaseDataset +from mmocr.datasets.pipelines.crop import sort_vertex + + +@DATASETS.register_module() +class KIEDataset(BaseDataset): + """ + Args: + ann_file (str): Annotation file path. + pipeline (list[dict]): Processing pipeline. + loader (dict): Dictionary to construct loader + to load annotation infos. + img_prefix (str, optional): Image prefix to generate full + image path. + test_mode (bool, optional): If True, try...except will + be turned off in __getitem__. + dict_file (str): Character dict file path. + norm (float): Norm to map value from one range to another. 
+ """ + + def __init__(self, + ann_file, + loader, + dict_file, + img_prefix='', + pipeline=None, + norm=10., + directed=False, + test_mode=True, + **kwargs): + super().__init__( + ann_file, + loader, + pipeline, + img_prefix=img_prefix, + test_mode=test_mode) + assert osp.exists(dict_file) + + self.norm = norm + self.directed = directed + + self.dict = dict({'': 0}) + with open(dict_file, 'r') as fr: + idx = 1 + for line in fr: + char = line.strip() + self.dict[char] = idx + idx += 1 + + def pre_pipeline(self, results): + results['img_prefix'] = self.img_prefix + results['bbox_fields'] = [] + + def _parse_anno_info(self, annotations): + """Parse annotations of boxes, texts and labels for one image. + Args: + annotations (list[dict]): Annotations of one image, where + each dict is for one character. + + Returns: + dict: A dict containing the following keys: + + - bboxes (np.ndarray): Bbox in one image with shape: + box_num * 4. + - relations (np.ndarray): Relations between bbox with shape: + box_num * box_num * D. + - texts (np.ndarray): Text index with shape: + box_num * text_max_len. + - labels (np.ndarray): Box Labels with shape: + box_num * (box_num + 1). + """ + + assert utils.is_type_list(annotations, dict) + assert 'box' in annotations[0] + assert 'text' in annotations[0] + assert 'label' in annotations[0] + + boxes, texts, text_inds, labels, edges = [], [], [], [], [] + for ann in annotations: + box = ann['box'] + x_list, y_list = box[0:8:2], box[1:9:2] + sorted_x_list, sorted_y_list = sort_vertex(x_list, y_list) + sorted_box = [] + for x, y in zip(sorted_x_list, sorted_y_list): + sorted_box.append(x) + sorted_box.append(y) + boxes.append(sorted_box) + text = ann['text'] + texts.append(ann['text']) + text_ind = [self.dict[c] for c in text if c in self.dict] + text_inds.append(text_ind) + labels.append(ann['label']) + edges.append(ann.get('edge', 0)) + + ann_infos = dict( + boxes=boxes, + texts=texts, + text_inds=text_inds, + edges=edges, + labels=labels) + + return self.list_to_numpy(ann_infos) + + def prepare_train_img(self, index): + """Get training data and annotations from pipeline. + + Args: + index (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys + introduced by pipeline. 
+ """ + img_ann_info = self.data_infos[index] + img_info = { + 'filename': img_ann_info['file_name'], + 'height': img_ann_info['height'], + 'width': img_ann_info['width'] + } + ann_info = self._parse_anno_info(img_ann_info['annotations']) + results = dict(img_info=img_info, ann_info=ann_info) + + self.pre_pipeline(results) + + return self.pipeline(results) + + def evaluate(self, + results, + metric='macro_f1', + metric_options=dict(macro_f1=dict(ignores=[])), + **kwargs): + # allow some kwargs to pass through + assert set(kwargs).issubset(['logger']) + + # Protect ``metric_options`` since it uses mutable value as default + metric_options = copy.deepcopy(metric_options) + + metrics = metric if isinstance(metric, list) else [metric] + allowed_metrics = ['macro_f1'] + for m in metrics: + if m not in allowed_metrics: + raise KeyError(f'metric {m} is not supported') + + return self.compute_macro_f1(results, **metric_options['macro_f1']) + + def compute_macro_f1(self, results, ignores=[]): + node_preds = [] + node_gts = [] + for idx, result in enumerate(results): + node_preds.append(result['nodes']) + box_ann_infos = self.data_infos[idx]['annotations'] + node_gt = [box_ann_info['label'] for box_ann_info in box_ann_infos] + node_gts.append(torch.Tensor(node_gt)) + + node_preds = torch.cat(node_preds) + node_gts = torch.cat(node_gts).int().to(node_preds.device) + + node_f1s = compute_f1_score(node_preds, node_gts, ignores) + + return { + 'macro_f1': node_f1s.mean(), + } + + def list_to_numpy(self, ann_infos): + """Convert bboxes, relations, texts and labels to ndarray.""" + boxes, text_inds = ann_infos['boxes'], ann_infos['text_inds'] + boxes = np.array(boxes, np.int32) + relations, bboxes = self.compute_relation(boxes) + + labels = ann_infos.get('labels', None) + if labels is not None: + labels = np.array(labels, np.int32) + edges = ann_infos.get('edges', None) + if edges is not None: + labels = labels[:, None] + edges = np.array(edges) + edges = (edges[:, None] == edges[None, :]).astype(np.int32) + if self.directed: + edges = (edges & labels == 1).astype(np.int32) + np.fill_diagonal(edges, -1) + labels = np.concatenate([labels, edges], -1) + padded_text_inds = self.pad_text_indices(text_inds) + + return dict( + bboxes=bboxes, + relations=relations, + texts=padded_text_inds, + labels=labels) + + def pad_text_indices(self, text_inds): + """Pad text index to same length.""" + max_len = max([len(text_ind) for text_ind in text_inds]) + padded_text_inds = -np.ones((len(text_inds), max_len), np.int32) + for idx, text_ind in enumerate(text_inds): + padded_text_inds[idx, :len(text_ind)] = np.array(text_ind) + return padded_text_inds + + def compute_relation(self, boxes): + """Compute relation between every two boxes.""" + x1s, y1s = boxes[:, 0:1], boxes[:, 1:2] + x2s, y2s = boxes[:, 4:5], boxes[:, 5:6] + ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1) + dxs = (x1s[:, 0][None] - x1s) / self.norm + dys = (y1s[:, 0][None] - y1s) / self.norm + xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs + whs = ws / hs + np.zeros_like(xhhs) + relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1) + bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32) + return relations, bboxes diff --git a/mmocr/datasets/ocr_dataset.py b/mmocr/datasets/ocr_dataset.py new file mode 100644 index 00000000..4ec6d962 --- /dev/null +++ b/mmocr/datasets/ocr_dataset.py @@ -0,0 +1,34 @@ +from mmdet.datasets.builder import DATASETS +from mmocr.core.evaluation.ocr_metric import eval_ocr_metric +from 
mmocr.datasets.base_dataset import BaseDataset + + +@DATASETS.register_module() +class OCRDataset(BaseDataset): + + def pre_pipeline(self, results): + results['img_prefix'] = self.img_prefix + results['text'] = results['img_info']['text'] + + def evaluate(self, results, metric='acc', logger=None, **kwargs): + """Evaluate the dataset. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + Returns: + dict[str: float] + """ + gt_texts = [] + pred_texts = [] + for i in range(len(self)): + item_info = self.data_infos[i] + text = item_info['text'] + gt_texts.append(text) + pred_texts.append(results[i]['text']) + + eval_results = eval_ocr_metric(pred_texts, gt_texts) + + return eval_results diff --git a/mmocr/datasets/ocr_seg_dataset.py b/mmocr/datasets/ocr_seg_dataset.py new file mode 100644 index 00000000..638de163 --- /dev/null +++ b/mmocr/datasets/ocr_seg_dataset.py @@ -0,0 +1,89 @@ +import mmocr.utils as utils +from mmdet.datasets.builder import DATASETS +from mmocr.datasets.ocr_dataset import OCRDataset + + +@DATASETS.register_module() +class OCRSegDataset(OCRDataset): + + def pre_pipeline(self, results): + results['img_prefix'] = self.img_prefix + + def _parse_anno_info(self, annotations): + """Parse char boxes annotations. + Args: + annotations (list[dict]): Annotations of one image, where + each dict is for one character. + + Returns: + dict: A dict containing the following keys: + + - chars (list[str]): List of character strings. + - char_rects (list[list[float]]): List of char box, with each + in style of rectangle: [x_min, y_min, x_max, y_max]. + - char_quads (list[list[float]]): List of char box, with each + in style of quadrangle: [x1, y1, x2, y2, x3, y3, x4, y4]. + """ + + assert utils.is_type_list(annotations, dict) + assert 'char_box' in annotations[0] + assert 'char_text' in annotations[0] + assert len(annotations[0]['char_box']) in [4, 8] + + chars, char_rects, char_quads = [], [], [] + for ann in annotations: + char_box = ann['char_box'] + if len(char_box) == 4: + char_box_type = ann.get('char_box_type', 'xyxy') + if char_box_type == 'xyxy': + char_rects.append(char_box) + char_quads.append([ + char_box[0], char_box[1], char_box[2], char_box[1], + char_box[2], char_box[3], char_box[0], char_box[3] + ]) + elif char_box_type == 'xywh': + x1, y1, w, h = char_box + x2 = x1 + w + y2 = y1 + h + char_rects.append([x1, y1, x2, y2]) + char_quads.append([x1, y1, x2, y1, x2, y2, x1, y2]) + else: + raise ValueError(f'invalid char_box_type {char_box_type}') + elif len(char_box) == 8: + x_list, y_list = [], [] + for i in range(4): + x_list.append(char_box[2 * i]) + y_list.append(char_box[2 * i + 1]) + x_max, x_min = max(x_list), min(x_list) + y_max, y_min = max(y_list), min(y_list) + char_rects.append([x_min, y_min, x_max, y_max]) + char_quads.append(char_box) + else: + raise Exception( + f'invalid num in char box: {len(char_box)} not in (4, 8)') + chars.append(ann['char_text']) + + ann = dict(chars=chars, char_rects=char_rects, char_quads=char_quads) + + return ann + + def prepare_train_img(self, index): + """Get training data and annotations from pipeline. + + Args: + index (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys + introduced by pipeline. 
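+
+        Example (illustrative; this is the annotation layout that
+        ``_parse_anno_info`` above expects, with made-up values and
+        ``dataset`` being a constructed OCRSegDataset):
+
+        .. code-block:: python
+
+            # A typical entry of self.data_infos:
+            # dict(
+            #     file_name='word_1.png',
+            #     annotations=[
+            #         dict(char_text='V',
+            #              char_box=[10, 2, 30, 2, 30, 40, 10, 40]),
+            #         dict(char_text='N',
+            #              char_box=[35, 2, 55, 2, 55, 40, 35, 40])])
+            results = dataset.prepare_train_img(0)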
+ """ + img_ann_info = self.data_infos[index] + img_info = { + 'filename': img_ann_info['file_name'], + } + ann_info = self._parse_anno_info(img_ann_info['annotations']) + results = dict(img_info=img_info, ann_info=ann_info) + + self.pre_pipeline(results) + + return self.pipeline(results) diff --git a/mmocr/datasets/pipelines/__init__.py b/mmocr/datasets/pipelines/__init__.py new file mode 100644 index 00000000..4eca60a3 --- /dev/null +++ b/mmocr/datasets/pipelines/__init__.py @@ -0,0 +1,26 @@ +from .box_utils import sort_vertex +from .custom_format_bundle import CustomFormatBundle +from .dbnet_transforms import EastRandomCrop, ImgAug +from .kie_transforms import KIEFormatBundle +from .loading import LoadTextAnnotations +from .ocr_seg_targets import OCRSegTargets +from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR, + OpencvToPil, PilToOpencv, RandomPaddingOCR, + RandomRotateImageBox, ResizeOCR, ToTensorOCR) +from .test_time_aug import MultiRotateAugOCR +from .textdet_targets import DBNetTargets, PANetTargets, TextSnakeTargets +from .transforms import (ColorJitter, RandomCropInstances, + RandomCropPolyInstances, RandomRotatePolyInstances, + RandomRotateTextDet, ScaleAspectJitter, + SquareResizePad) + +__all__ = [ + 'LoadTextAnnotations', 'NormalizeOCR', 'OnlineCropOCR', 'ResizeOCR', + 'ToTensorOCR', 'CustomFormatBundle', 'DBNetTargets', 'PANetTargets', + 'ColorJitter', 'RandomCropInstances', 'RandomRotateTextDet', + 'ScaleAspectJitter', 'MultiRotateAugOCR', 'OCRSegTargets', 'FancyPCA', + 'RandomCropPolyInstances', 'RandomRotatePolyInstances', 'RandomPaddingOCR', + 'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil', + 'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets', + 'sort_vertex' +] diff --git a/mmocr/datasets/pipelines/box_utils.py b/mmocr/datasets/pipelines/box_utils.py new file mode 100644 index 00000000..23af4817 --- /dev/null +++ b/mmocr/datasets/pipelines/box_utils.py @@ -0,0 +1,83 @@ +import numpy as np +from shapely.geometry import LineString, Point, Polygon + +import mmocr.utils as utils + + +def sort_vertex(points_x, points_y): + """Sort box vertices in clockwise order from left-top first. + + Args: + points_x (list[float]): x of four vertices. + points_y (list[float]): y of four vertices. + Returns: + sorted_points_x (list[float]): x of sorted four vertices. + sorted_points_y (list[float]): y of sorted four vertices. + """ + assert utils.is_type_list(points_x, float) or utils.is_type_list( + points_x, int) + assert utils.is_type_list(points_y, float) or utils.is_type_list( + points_y, int) + assert len(points_x) == 4 + assert len(points_y) == 4 + + x = np.array(points_x) + y = np.array(points_y) + center_x = np.sum(x) * 0.25 + center_y = np.sum(y) * 0.25 + + x_arr = np.array(x - center_x) + y_arr = np.array(y - center_y) + + angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi + sort_idx = np.argsort(angle) + + sorted_points_x, sorted_points_y = [], [] + for i in range(4): + sorted_points_x.append(points_x[sort_idx[i]]) + sorted_points_y.append(points_y[sort_idx[i]]) + + return convert_canonical(sorted_points_x, sorted_points_y) + + +def convert_canonical(points_x, points_y): + """Make left-top be first. + + Args: + points_x (list[float]): x of four vertices. + points_y (list[float]): y of four vertices. + Returns: + sorted_points_x (list[float]): x of sorted four vertices. + sorted_points_y (list[float]): y of sorted four vertices. 
+ """ + assert utils.is_type_list(points_x, float) or utils.is_type_list( + points_x, int) + assert utils.is_type_list(points_y, float) or utils.is_type_list( + points_y, int) + assert len(points_x) == 4 + assert len(points_y) == 4 + + points = [Point(points_x[i], points_y[i]) for i in range(4)] + + polygon = Polygon([(p.x, p.y) for p in points]) + min_x, min_y, _, _ = polygon.bounds + points_to_lefttop = [ + LineString([points[i], Point(min_x, min_y)]) for i in range(4) + ] + distances = np.array([line.length for line in points_to_lefttop]) + sort_dist_idx = np.argsort(distances) + lefttop_idx = sort_dist_idx[0] + + if lefttop_idx == 0: + point_orders = [0, 1, 2, 3] + elif lefttop_idx == 1: + point_orders = [1, 2, 3, 0] + elif lefttop_idx == 2: + point_orders = [2, 3, 0, 1] + else: + point_orders = [3, 0, 1, 2] + + sorted_points_x = [points_x[i] for i in point_orders] + sorted_points_y = [points_y[j] for j in point_orders] + + return sorted_points_x, sorted_points_y diff --git a/mmocr/datasets/pipelines/crop.py b/mmocr/datasets/pipelines/crop.py new file mode 100644 index 00000000..eea0ffbb --- /dev/null +++ b/mmocr/datasets/pipelines/crop.py @@ -0,0 +1,107 @@ +import cv2 +import numpy as np +from shapely.geometry import LineString, Point + +import mmocr.utils as utils +from .box_utils import sort_vertex + + +def box_jitter(points_x, points_y, jitter_ratio_x=0.5, jitter_ratio_y=0.1): + """Jitter on the coordinates of bounding box. + + Args: + points_x (list[float | int]): List of y for four vertices. + points_y (list[float | int]): List of x for four vertices. + jitter_ratio_x (float): Horizontal jitter ratio relative to the height. + jitter_ratio_y (float): Vertical jitter ratio relative to the height. + """ + assert len(points_x) == 4 + assert len(points_y) == 4 + assert isinstance(jitter_ratio_x, float) + assert isinstance(jitter_ratio_y, float) + assert 0 <= jitter_ratio_x < 1 + assert 0 <= jitter_ratio_y < 1 + + points = [Point(points_x[i], points_y[i]) for i in range(4)] + line_list = [ + LineString([points[i], points[i + 1 if i < 3 else 0]]) + for i in range(4) + ] + + tmp_h = max(line_list[1].length, line_list[3].length) + + for i in range(4): + jitter_pixel_x = (np.random.rand() - 0.5) * 2 * jitter_ratio_x * tmp_h + jitter_pixel_y = (np.random.rand() - 0.5) * 2 * jitter_ratio_y * tmp_h + points_x[i] += jitter_pixel_x + points_y[i] += jitter_pixel_y + + +def warp_img(src_img, + box, + jitter_flag=False, + jitter_ratio_x=0.5, + jitter_ratio_y=0.1): + """Crop box area from image using opencv warpPerspective w/o box jitter. + + Args: + src_img (np.array): Image before cropping. + box (list[float | int]): Coordinates of quadrangle. 
+ """ + assert utils.is_type_list(box, float) or utils.is_type_list(box, int) + assert len(box) == 8 + + h, w = src_img.shape[:2] + points_x = [min(max(x, 0), w) for x in box[0:8:2]] + points_y = [min(max(y, 0), h) for y in box[1:9:2]] + + points_x, points_y = sort_vertex(points_x, points_y) + + if jitter_flag: + box_jitter( + points_x, + points_y, + jitter_ratio_x=jitter_ratio_x, + jitter_ratio_y=jitter_ratio_y) + + points = [Point(points_x[i], points_y[i]) for i in range(4)] + edges = [ + LineString([points[i], points[i + 1 if i < 3 else 0]]) + for i in range(4) + ] + + pts1 = np.float32([[points[i].x, points[i].y] for i in range(4)]) + box_width = max(edges[0].length, edges[2].length) + box_height = max(edges[1].length, edges[3].length) + + pts2 = np.float32([[0, 0], [box_width, 0], [box_width, box_height], + [0, box_height]]) + M = cv2.getPerspectiveTransform(pts1, pts2) + dst_img = cv2.warpPerspective(src_img, M, + (int(box_width), int(box_height))) + + return dst_img + + +def crop_img(src_img, box): + """Crop box area to rectangle. + + Args: + src_img (np.array): Image before crop. + box (list[float | int]): Points of quadrangle. + """ + assert utils.is_type_list(box, float) or utils.is_type_list(box, int) + assert len(box) == 8 + + h, w = src_img.shape[:2] + points_x = [min(max(x, 0), w) for x in box[0:8:2]] + points_y = [min(max(y, 0), h) for y in box[1:9:2]] + + left = int(min(points_x)) + top = int(min(points_y)) + right = int(max(points_x)) + bottom = int(max(points_y)) + + dst_img = src_img[top:bottom, left:right] + + return dst_img diff --git a/mmocr/datasets/pipelines/custom_format_bundle.py b/mmocr/datasets/pipelines/custom_format_bundle.py new file mode 100644 index 00000000..276cde69 --- /dev/null +++ b/mmocr/datasets/pipelines/custom_format_bundle.py @@ -0,0 +1,65 @@ +import numpy as np +from mmcv.parallel import DataContainer as DC + +from mmdet.datasets.builder import PIPELINES +from mmdet.datasets.pipelines.formating import DefaultFormatBundle +from mmocr.core.visualize import overlay_mask_img, show_feature + + +@PIPELINES.register_module() +class CustomFormatBundle(DefaultFormatBundle): + """Custom formatting bundle. + + It formats common fields such as 'img' and 'proposals' as done in + DefaultFormatBundle, while other fields such as 'gt_kernels' and + 'gt_effective_region_mask' will be formatted to DC as follows: + + - gt_kernels: to DataContainer (cpu_only=True) + - gt_effective_mask: to DataContainer (cpu_only=True) + + Args: + keys (list[str]): Fields to be formatted to DC only. + call_super (bool): If True, format common fields + by DefaultFormatBundle, else format fields in keys above only. + visualize (dict): If flag=True, visualize gt mask for debugging. 
+ """ + + def __init__(self, + keys=[], + call_super=True, + visualize=dict(flag=False, boundary_key=None)): + + super().__init__() + self.visualize = visualize + self.keys = keys + self.call_super = call_super + + def __call__(self, results): + + if self.visualize['flag']: + img = results['img'].astype(np.uint8) + boundary_key = self.visualize['boundary_key'] + if boundary_key is not None: + img = overlay_mask_img(img, results[boundary_key].masks[0]) + + features = [img] + names = ['img'] + to_uint8 = [1] + + for k in results['mask_fields']: + for iter in range(len(results[k].masks)): + features.append(results[k].masks[iter]) + names.append(k + str(iter)) + to_uint8.append(0) + show_feature(features, names, to_uint8) + + if self.call_super: + results = super().__call__(results) + + for k in self.keys: + results[k] = DC(results[k], cpu_only=True) + + return results + + def __repr__(self): + return self.__class__.__name__ diff --git a/mmocr/datasets/pipelines/dbnet_transforms.py b/mmocr/datasets/pipelines/dbnet_transforms.py new file mode 100644 index 00000000..9ec7936c --- /dev/null +++ b/mmocr/datasets/pipelines/dbnet_transforms.py @@ -0,0 +1,272 @@ +import cv2 +import imgaug +import imgaug.augmenters as iaa +import numpy as np + +from mmdet.core.mask import PolygonMasks +from mmdet.datasets.builder import PIPELINES + + +class AugmenterBuilder: + """Build imgaug object according ImgAug argmentations.""" + + def __init__(self): + pass + + def build(self, args, root=True): + if args is None: + return None + elif isinstance(args, (int, float, str)): + return args + elif isinstance(args, list): + if root: + sequence = [self.build(value, root=False) for value in args] + return iaa.Sequential(sequence) + arg_list = [self.to_tuple_if_list(a) for a in args[1:]] + return getattr(iaa, args[0])(*arg_list) + elif isinstance(args, dict): + if 'cls' in args: + cls = getattr(iaa, args['cls']) + return cls( + **{ + k: self.to_tuple_if_list(v) + for k, v in args.items() if not k == 'cls' + }) + else: + return { + key: self.build(value, root=False) + for key, value in args.items() + } + else: + raise RuntimeError('unknown augmenter arg: ' + str(args)) + + def to_tuple_if_list(self, obj): + if isinstance(obj, list): + return tuple(obj) + return obj + + +@PIPELINES.register_module() +class ImgAug: + """A wrapper to use imgaug https://github.com/aleju/imgaug. + + Args: + args ([list[list|dict]]): The argumentation list. For details, please + refer to imgaug document. Take args=[['Fliplr', 0.5], + dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]] as an + example. The args horizontally flip images with probability 0.5, + followed by random rotation with angles in range [-10, 10], and + resize with an independent scale in range [0.5, 3.0] for each + side of images. 
+ """ + + def __init__(self, args=None): + self.augmenter_args = args + self.augmenter = AugmenterBuilder().build(self.augmenter_args) + + def __call__(self, results): + # img is bgr + image = results['img'] + aug = None + shape = image.shape + + if self.augmenter: + aug = self.augmenter.to_deterministic() + results['img'] = aug.augment_image(image) + results['img_shape'] = results['img'].shape + results['flip'] = 'unknown' # it's unknown + results['flip_direction'] = 'unknown' # it's unknown + target_shape = results['img_shape'] + + self.may_augment_annotation(aug, shape, target_shape, results) + + return results + + def may_augment_annotation(self, aug, shape, target_shape, results): + if aug is None: + return results + for key in results['mask_fields']: + # augment polygon mask + masks = [] + for mask in results[key]: + masks.append( + [self.may_augment_poly(aug, shape, target_shape, mask[0])]) + if len(masks) > 0: + results[key] = PolygonMasks(masks, *target_shape[:2]) + + for key in results['bbox_fields']: + # augment bbox + bboxes = [] + for bbox in results[key]: + bbox = self.may_augment_poly(aug, shape, target_shape, bbox) + bboxes.append(bbox) + results[key] = np.zeros(0) + if len(bboxes) > 0: + results[key] = np.stack(bboxes) + + return results + + def may_augment_poly(self, aug, img_shape, target_shape, poly): + # poly n x 2 + poly = poly.reshape(-1, 2) + keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] + keypoints = aug.augment_keypoints( + [imgaug.KeypointsOnImage(keypoints, shape=img_shape)])[0].keypoints + poly = [[p.x, p.y] for p in keypoints] + poly = np.array(poly).flatten() + return poly + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class EastRandomCrop: + + def __init__(self, + target_size=(640, 640), + max_tries=10, + min_crop_side_ratio=0.1): + self.target_size = target_size + self.max_tries = max_tries + self.min_crop_side_ratio = min_crop_side_ratio + + def __call__(self, results): + # sampling crop + # crop image, boxes, masks + img = results['img'] + crop_x, crop_y, crop_w, crop_h = self.crop_area( + img, results['gt_masks']) + scale_w = self.target_size[0] / crop_w + scale_h = self.target_size[1] / crop_h + scale = min(scale_w, scale_h) + h = int(crop_h * scale) + w = int(crop_w * scale) + padded_img = np.zeros( + (self.target_size[1], self.target_size[0], img.shape[2]), + img.dtype) + padded_img[:h, :w] = cv2.resize( + img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) + + # for bboxes + for key in results['bbox_fields']: + lines = [] + for box in results[key]: + box = box.reshape(2, 2) + poly = ((box - (crop_x, crop_y)) * scale) + if not self.is_poly_outside_rect(poly, 0, 0, w, h): + lines.append(poly.flatten()) + results[key] = np.array(lines) + # for masks + for key in results['mask_fields']: + polys = [] + polys_label = [] + for poly in results[key]: + poly = np.array(poly).reshape(-1, 2) + poly = ((poly - (crop_x, crop_y)) * scale) + if not self.is_poly_outside_rect(poly, 0, 0, w, h): + polys.append([poly]) + polys_label.append(0) + results[key] = PolygonMasks(polys, *self.target_size) + if key == 'gt_masks': + results['gt_labels'] = polys_label + + results['img'] = padded_img + results['img_shape'] = padded_img.shape + + return results + + def is_poly_in_rect(self, poly, x, y, w, h): + poly = np.array(poly) + if poly[:, 0].min() < x or poly[:, 0].max() > x + w: + return False + if poly[:, 1].min() < y or poly[:, 1].max() > y + h: + return False + return True + + def 
is_poly_outside_rect(self, poly, x, y, w, h):
+        poly = np.array(poly).reshape(-1, 2)
+        if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
+            return True
+        if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
+            return True
+        return False
+
+    def split_regions(self, axis):
+        regions = []
+        min_axis = 0
+        for i in range(1, axis.shape[0]):
+            if axis[i] != axis[i - 1] + 1:
+                region = axis[min_axis:i]
+                min_axis = i
+                regions.append(region)
+        return regions
+
+    def random_select(self, axis, max_size):
+        xx = np.random.choice(axis, size=2)
+        xmin = np.min(xx)
+        xmax = np.max(xx)
+        xmin = np.clip(xmin, 0, max_size - 1)
+        xmax = np.clip(xmax, 0, max_size - 1)
+        return xmin, xmax
+
+    def region_wise_random_select(self, regions, max_size):
+        selected_index = list(np.random.choice(len(regions), 2))
+        selected_values = []
+        for index in selected_index:
+            axis = regions[index]
+            xx = int(np.random.choice(axis, size=1))
+            selected_values.append(xx)
+        xmin = min(selected_values)
+        xmax = max(selected_values)
+        return xmin, xmax
+
+    def crop_area(self, img, polys):
+        h, w, _ = img.shape
+        h_array = np.zeros(h, dtype=np.int32)
+        w_array = np.zeros(w, dtype=np.int32)
+        for points in polys:
+            points = np.round(
+                points, decimals=0).astype(np.int32).reshape(-1, 2)
+            min_x = np.min(points[:, 0])
+            max_x = np.max(points[:, 0])
+            w_array[min_x:max_x] = 1
+            min_y = np.min(points[:, 1])
+            max_y = np.max(points[:, 1])
+            h_array[min_y:max_y] = 1
+        # ensure the cropped area does not cross a text instance
+        h_axis = np.where(h_array == 0)[0]
+        w_axis = np.where(w_array == 0)[0]
+
+        if len(h_axis) == 0 or len(w_axis) == 0:
+            return 0, 0, w, h
+
+        h_regions = self.split_regions(h_axis)
+        w_regions = self.split_regions(w_axis)
+
+        for i in range(self.max_tries):
+            if len(w_regions) > 1:
+                xmin, xmax = self.region_wise_random_select(w_regions, w)
+            else:
+                xmin, xmax = self.random_select(w_axis, w)
+            if len(h_regions) > 1:
+                ymin, ymax = self.region_wise_random_select(h_regions, h)
+            else:
+                ymin, ymax = self.random_select(h_axis, h)
+
+            if (xmax - xmin < self.min_crop_side_ratio * w
+                    or ymax - ymin < self.min_crop_side_ratio * h):
+                # area too small
+                continue
+            num_poly_in_rect = 0
+            for poly in polys:
+                if not self.is_poly_outside_rect(poly, xmin, ymin,
+                                                 xmax - xmin, ymax - ymin):
+                    num_poly_in_rect += 1
+                    break
+
+            if num_poly_in_rect > 0:
+                return xmin, ymin, xmax - xmin, ymax - ymin
+
+        return 0, 0, w, h
diff --git a/mmocr/datasets/pipelines/kie_transforms.py b/mmocr/datasets/pipelines/kie_transforms.py
new file mode 100644
index 00000000..0787f07f
--- /dev/null
+++ b/mmocr/datasets/pipelines/kie_transforms.py
@@ -0,0 +1,55 @@
+import numpy as np
+from mmcv.parallel import DataContainer as DC
+
+from mmdet.datasets.builder import PIPELINES
+from mmdet.datasets.pipelines.formating import DefaultFormatBundle, to_tensor
+
+
+@PIPELINES.register_module()
+class KIEFormatBundle(DefaultFormatBundle):
+    """Key information extraction formatting bundle.
+
+    Based on the DefaultFormatBundle, it simplifies the pipeline of
+    formatting common fields, including "img", "proposals", "gt_bboxes",
+    "gt_labels", "gt_masks", "gt_semantic_seg", "relations" and "texts".
+    These fields are formatted as follows.
+ + - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True) + - proposals: (1) to tensor, (2) to DataContainer + - gt_bboxes: (1) to tensor, (2) to DataContainer + - gt_bboxes_ignore: (1) to tensor, (2) to DataContainer + - gt_labels: (1) to tensor, (2) to DataContainer + - gt_masks: (1) to tensor, (2) to DataContainer (cpu_only=True) + - gt_semantic_seg: (1) unsqueeze dim-0 (2) to tensor, + (3) to DataContainer (stack=True) + - relations: (1) scale, (2) to tensor, (3) to DataContainer + - texts: (1) to tensor, (2) to DataContainer + """ + + def __call__(self, results): + """Call function to transform and format common fields in results. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data that is formatted with + default bundle. + """ + super().__call__(results) + if 'ann_info' in results: + for key in ['relations', 'texts']: + value = results['ann_info'][key] + if key == 'relations' and 'scale_factor' in results: + scale_factor = results['scale_factor'] + if isinstance(scale_factor, float): + sx = sy = scale_factor + else: + sx, sy = results['scale_factor'][:2] + r = sx / sy + value = value * np.array([sx, sy, r, 1, r])[None, None] + results[key] = DC(to_tensor(value)) + return results + + def __repr__(self): + return self.__class__.__name__ diff --git a/mmocr/datasets/pipelines/loading.py b/mmocr/datasets/pipelines/loading.py new file mode 100644 index 00000000..5c3cda6e --- /dev/null +++ b/mmocr/datasets/pipelines/loading.py @@ -0,0 +1,68 @@ +import numpy as np + +from mmdet.core import BitmapMasks, PolygonMasks +from mmdet.datasets.builder import PIPELINES +from mmdet.datasets.pipelines.loading import LoadAnnotations + + +@PIPELINES.register_module() +class LoadTextAnnotations(LoadAnnotations): + + def __init__(self, + with_bbox=True, + with_label=True, + with_mask=False, + with_seg=False, + poly2mask=True): + super().__init__( + with_bbox=with_bbox, + with_label=with_label, + with_mask=with_mask, + with_seg=with_seg, + poly2mask=poly2mask) + + def process_polygons(self, polygons): + """Convert polygons to list of ndarray and filter invalid polygons. + + Args: + polygons (list[list]): Polygons of one instance. + + Returns: + list[numpy.ndarray]: Processed polygons. 
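+
+        Example (illustrative; a polygon needs an even number of values and
+        at least three points to be kept):
+
+        .. code-block:: python
+
+            polygons = [[0, 0, 10, 0, 10, 10], [0, 0, 10, 10]]
+            valid = self.process_polygons(polygons)
+            # len(valid) -> 1; the two-point polygon is dropped.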
+ """ + + polygons = [np.array(p).astype(np.float32) for p in polygons] + valid_polygons = [] + for polygon in polygons: + if len(polygon) % 2 == 0 and len(polygon) >= 6: + valid_polygons.append(polygon) + return valid_polygons + + def _load_masks(self, results): + ann_info = results['ann_info'] + h, w = results['img_info']['height'], results['img_info']['width'] + gt_masks = ann_info['masks'] + if self.poly2mask: + gt_masks = BitmapMasks( + [self._poly2mask(mask, h, w) for mask in gt_masks], h, w) + else: + gt_masks = PolygonMasks( + [self.process_polygons(polygons) for polygons in gt_masks], h, + w) + gt_masks_ignore = ann_info.get('masks_ignore', None) + if gt_masks_ignore is not None: + if self.poly2mask: + gt_masks_ignore = BitmapMasks( + [self._poly2mask(mask, h, w) for mask in gt_masks_ignore], + h, w) + else: + gt_masks_ignore = PolygonMasks([ + self.process_polygons(polygons) + for polygons in gt_masks_ignore + ], h, w) + results['gt_masks_ignore'] = gt_masks_ignore + results['mask_fields'].append('gt_masks_ignore') + + results['gt_masks'] = gt_masks + results['mask_fields'].append('gt_masks') + return results diff --git a/mmocr/datasets/pipelines/ocr_seg_targets.py b/mmocr/datasets/pipelines/ocr_seg_targets.py new file mode 100644 index 00000000..cbd9b869 --- /dev/null +++ b/mmocr/datasets/pipelines/ocr_seg_targets.py @@ -0,0 +1,201 @@ +import cv2 +import numpy as np + +import mmocr.utils.check_argument as check_argument +from mmdet.core import BitmapMasks +from mmdet.datasets.builder import PIPELINES +from mmocr.models.builder import build_convertor + + +@PIPELINES.register_module() +class OCRSegTargets: + """Generate gt shrinked kernels for segmentation based OCR framework. + + Args: + label_convertor (dict): Dictionary to construct label_convertor + to convert char to index. + attn_shrink_ratio (float): The area shrinked ratio + between attention kernels and gt text masks. + seg_shrink_ratio (float): The area shrinked ratio + between segmentation kernels and gt text masks. + box_type (str): Character box type, should be either + 'char_rects' or 'char_quads', with 'char_rects' + for rectangle with ``xyxy`` style and 'char_quads' + for quadrangle with ``x1y1x2y2x3y3x4y4`` style. + """ + + def __init__(self, + label_convertor=None, + attn_shrink_ratio=0.5, + seg_shrink_ratio=0.25, + box_type='char_rects', + pad_val=255): + + assert isinstance(attn_shrink_ratio, float) + assert isinstance(seg_shrink_ratio, float) + assert 0. < attn_shrink_ratio < 1.0 + assert 0. < seg_shrink_ratio < 1.0 + assert label_convertor is not None + assert box_type in ('char_rects', 'char_quads') + + self.attn_shrink_ratio = attn_shrink_ratio + self.seg_shrink_ratio = seg_shrink_ratio + self.label_convertor = build_convertor(label_convertor) + self.box_type = box_type + self.pad_val = pad_val + + def shrink_char_quad(self, char_quad, shrink_ratio): + """Shrink char box in style of quadrangle. + + Args: + char_quad (list[float]): Char box with format + [x1, y1, x2, y2, x3, y3, x4, y4]. + shrink_ratio (float): The area shrinked ratio + between gt kernels and gt text masks. 
+ """ + points = [[char_quad[0], char_quad[1]], [char_quad[2], char_quad[3]], + [char_quad[4], char_quad[5]], [char_quad[6], char_quad[7]]] + shrink_points = [] + for p_idx, point in enumerate(points): + p1 = points[(p_idx + 3) % 4] + p2 = points[(p_idx + 1) % 4] + + dist1 = self.l2_dist_two_points(p1, point) + dist2 = self.l2_dist_two_points(p2, point) + min_dist = min(dist1, dist2) + + v1 = [p1[0] - point[0], p1[1] - point[1]] + v2 = [p2[0] - point[0], p2[1] - point[1]] + + temp_dist1 = (shrink_ratio * min_dist / + dist1) if min_dist != 0 else 0. + temp_dist2 = (shrink_ratio * min_dist / + dist2) if min_dist != 0 else 0. + + v1 = [temp * temp_dist1 for temp in v1] + v2 = [temp * temp_dist2 for temp in v2] + + shrink_point = [ + round(point[0] + v1[0] + v2[0]), + round(point[1] + v1[1] + v2[1]) + ] + shrink_points.append(shrink_point) + + poly = np.array(shrink_points) + + return poly + + def shrink_char_rect(self, char_rect, shrink_ratio): + """Shrink char box in style of rectangle. + + Args: + char_rect (list[float]): Char box with format + [x_min, y_min, x_max, y_max]. + shrink_ratio (float): The area shrinked ratio + between gt kernels and gt text masks. + """ + x_min, y_min, x_max, y_max = char_rect + w = x_max - x_min + h = y_max - y_min + x_min_s = round((x_min + x_max - w * shrink_ratio) / 2) + y_min_s = round((y_min + y_max - h * shrink_ratio) / 2) + x_max_s = round((x_min + x_max + w * shrink_ratio) / 2) + y_max_s = round((y_min + y_max + h * shrink_ratio) / 2) + poly = np.array([[x_min_s, y_min_s], [x_max_s, y_min_s], + [x_max_s, y_max_s], [x_min_s, y_max_s]]) + + return poly + + def generate_kernels(self, + resize_shape, + pad_shape, + char_boxes, + char_inds, + shrink_ratio=0.5, + binary=True): + """Generate char instance kernels for one shrink ratio. + + Args: + resize_shape (tuple(int, int)): Image size (height, width) + after resizing. + pad_shape (tuple(int, int)): Image size (height, width) + after padding. + char_boxes (list[list[float]]): The list of char polygons. + char_inds (list[int]): List of char indexes. + shrink_ratio (float): The shrink ratio of kernel. + binary (bool): If True, return binary ndarray + containing 0 & 1 only. + Returns: + char_kernel (ndarray): The text kernel mask of (height, width). 
+ """ + assert isinstance(resize_shape, tuple) + assert isinstance(pad_shape, tuple) + assert check_argument.is_2dlist(char_boxes) + assert check_argument.is_type_list(char_inds, int) + assert isinstance(shrink_ratio, float) + assert isinstance(binary, bool) + + char_kernel = np.zeros(pad_shape, dtype=np.int32) + char_kernel[:resize_shape[0], resize_shape[1]:] = self.pad_val + + for i, char_box in enumerate(char_boxes): + if self.box_type == 'char_rects': + poly = self.shrink_char_rect(char_box, shrink_ratio) + elif self.box_type == 'char_quads': + poly = self.shrink_char_quad(char_box, shrink_ratio) + + fill_value = 1 if binary else char_inds[i] + cv2.fillConvexPoly(char_kernel, poly.astype(np.int32), + (fill_value)) + + return char_kernel + + def l2_dist_two_points(self, p1, p2): + return ((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5 + + def __call__(self, results): + img_shape = results['img_shape'] + resize_shape = results['resize_shape'] + + h_scale = 1.0 * resize_shape[0] / img_shape[0] + w_scale = 1.0 * resize_shape[1] / img_shape[1] + + char_boxes, char_inds = [], [] + char_num = len(results['ann_info'][self.box_type]) + for i in range(char_num): + char_box = results['ann_info'][self.box_type][i] + num_points = 2 if self.box_type == 'char_rects' else 4 + for j in range(num_points): + char_box[j * 2] = round(char_box[j * 2] * w_scale) + char_box[j * 2 + 1] = round(char_box[j * 2 + 1] * h_scale) + char_boxes.append(char_box) + char = results['ann_info']['chars'][i] + char_ind = self.label_convertor.str2idx([char])[0][0] + char_inds.append(char_ind) + + resize_shape = tuple(results['resize_shape'][:2]) + pad_shape = tuple(results['pad_shape'][:2]) + binary_target = self.generate_kernels( + resize_shape, + pad_shape, + char_boxes, + char_inds, + shrink_ratio=self.attn_shrink_ratio, + binary=True) + + seg_target = self.generate_kernels( + resize_shape, + pad_shape, + char_boxes, + char_inds, + shrink_ratio=self.seg_shrink_ratio, + binary=False) + + mask = np.ones(pad_shape, dtype=np.int32) + mask[:resize_shape[0], resize_shape[1]:] = 0 + + results['gt_kernels'] = BitmapMasks([binary_target, seg_target, mask], + pad_shape[0], pad_shape[1]) + results['mask_fields'] = ['gt_kernels'] + + return results diff --git a/mmocr/datasets/pipelines/ocr_transforms.py b/mmocr/datasets/pipelines/ocr_transforms.py new file mode 100644 index 00000000..5b7370e5 --- /dev/null +++ b/mmocr/datasets/pipelines/ocr_transforms.py @@ -0,0 +1,446 @@ +import math + +import cv2 +import mmcv +import numpy as np +import torch +import torchvision.transforms.functional as TF +from mmcv.runner.dist_utils import get_dist_info +from PIL import Image +from shapely.geometry import Polygon +from shapely.geometry import box as shapely_box + +import mmocr.utils as utils +from mmdet.datasets.builder import PIPELINES +from mmocr.datasets.pipelines.crop import warp_img + + +@PIPELINES.register_module() +class ResizeOCR: + """Image resizing and padding for OCR. + + Args: + height (int | tuple(int)): Image height after resizing. + min_width (none | int | tuple(int)): Image minimum width + after resizing. + max_width (none | int | tuple(int)): Image maximum width + after resizing. + keep_aspect_ratio (bool): Keep image aspect ratio if True + during resizing, Otherwise resize to the size height * + max_width. + img_pad_value (int): Scalar to fill padding area. + width_downsample_ratio (float): Downsample ratio in horizontal + direction from input image to output feature. 
+ """ + + def __init__(self, + height, + min_width=None, + max_width=None, + keep_aspect_ratio=True, + img_pad_value=0, + width_downsample_ratio=1.0 / 16): + assert isinstance(height, (int, tuple)) + assert utils.is_none_or_type(min_width, (int, tuple)) + assert utils.is_none_or_type(max_width, (int, tuple)) + if not keep_aspect_ratio: + assert max_width is not None, ('"max_width" must assigned ' + 'if "keep_aspect_ratio" is False') + assert isinstance(img_pad_value, int) + if isinstance(height, tuple): + assert isinstance(min_width, tuple) + assert isinstance(max_width, tuple) + assert len(height) == len(min_width) == len(max_width) + + self.height = height + self.min_width = min_width + self.max_width = max_width + self.keep_aspect_ratio = keep_aspect_ratio + self.img_pad_value = img_pad_value + self.width_downsample_ratio = width_downsample_ratio + + def __call__(self, results): + rank, _ = get_dist_info() + if isinstance(self.height, int): + dst_height = self.height + dst_min_width = self.min_width + dst_max_width = self.max_width + else: + """Multi-scale resize used in distributed training. + + Choose one (height, width) pair for one rank id. + """ + idx = rank % len(self.height) + dst_height = self.height[idx] + dst_min_width = self.min_width[idx] + dst_max_width = self.max_width[idx] + + img_shape = results['img_shape'] + ori_height, ori_width = img_shape[:2] + valid_ratio = 1.0 + resize_shape = list(img_shape) + pad_shape = list(img_shape) + + if self.keep_aspect_ratio: + new_width = math.ceil(float(dst_height) / ori_height * ori_width) + width_divisor = int(1 / self.width_downsample_ratio) + # make sure new_width is an integral multiple of width_divisor. + if new_width % width_divisor != 0: + new_width = round(new_width / width_divisor) * width_divisor + if dst_min_width is not None: + new_width = max(dst_min_width, new_width) + if dst_max_width is not None: + valid_ratio = min(1.0, 1.0 * new_width / dst_max_width) + resize_width = min(dst_max_width, new_width) + img_resize = cv2.resize(results['img'], + (resize_width, dst_height)) + resize_shape = img_resize.shape + pad_shape = img_resize.shape + if new_width < dst_max_width: + img_resize = mmcv.impad( + img_resize, + shape=(dst_height, dst_max_width), + pad_val=self.img_pad_value) + pad_shape = img_resize.shape + else: + img_resize = cv2.resize(results['img'], + (new_width, dst_height)) + resize_shape = img_resize.shape + pad_shape = img_resize.shape + else: + img_resize = cv2.resize(results['img'], + (dst_max_width, dst_height)) + resize_shape = img_resize.shape + pad_shape = img_resize.shape + + results['img'] = img_resize + results['resize_shape'] = resize_shape + results['pad_shape'] = pad_shape + results['valid_ratio'] = valid_ratio + + return results + + +@PIPELINES.register_module() +class ToTensorOCR: + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.""" + + def __init__(self): + pass + + def __call__(self, results): + results['img'] = TF.to_tensor(results['img'].copy()) + + return results + + +@PIPELINES.register_module() +class NormalizeOCR: + """Normalize a tensor image with mean and standard deviation.""" + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, results): + results['img'] = TF.normalize(results['img'], self.mean, self.std) + + return results + + +@PIPELINES.register_module() +class OnlineCropOCR: + """Crop text areas from whole image with bounding box jitter. If no bbox is + given, return directly. 
+
+    Args:
+        box_keys (list[str]): Keys in results which correspond to RoI bbox.
+        jitter_prob (float): The probability of box jitter.
+        max_jitter_ratio_x (float): Maximum horizontal jitter ratio
+            relative to height.
+        max_jitter_ratio_y (float): Maximum vertical jitter ratio
+            relative to height.
+    """

+    def __init__(self,
+                 box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'],
+                 jitter_prob=0.5,
+                 max_jitter_ratio_x=0.05,
+                 max_jitter_ratio_y=0.02):
+        assert utils.is_type_list(box_keys, str)
+        assert 0 <= jitter_prob <= 1
+        assert 0 <= max_jitter_ratio_x <= 1
+        assert 0 <= max_jitter_ratio_y <= 1
+
+        self.box_keys = box_keys
+        self.jitter_prob = jitter_prob
+        self.max_jitter_ratio_x = max_jitter_ratio_x
+        self.max_jitter_ratio_y = max_jitter_ratio_y
+
+    def __call__(self, results):
+
+        if 'img_info' not in results:
+            return results
+
+        crop_flag = True
+        box = []
+        for key in self.box_keys:
+            if key not in results['img_info']:
+                crop_flag = False
+                break
+
+            box.append(float(results['img_info'][key]))
+
+        if not crop_flag:
+            return results
+
+        # jitter the box with probability ``jitter_prob``
+        jitter_flag = np.random.random() < self.jitter_prob
+
+        kwargs = dict(
+            jitter_flag=jitter_flag,
+            jitter_ratio_x=self.max_jitter_ratio_x,
+            jitter_ratio_y=self.max_jitter_ratio_y)
+        crop_img = warp_img(results['img'], box, **kwargs)
+
+        results['img'] = crop_img
+        results['img_shape'] = crop_img.shape
+
+        return results
+
+
+@PIPELINES.register_module()
+class FancyPCA:
+    """Implementation of PCA based image augmentation, proposed in the paper
+    ``Imagenet Classification With Deep Convolutional Neural Networks``.
+
+    It alters the intensities of RGB values along the principal components of
+    ImageNet dataset.
+    """
+
+    def __init__(self, eig_vec=None, eig_val=None):
+        if eig_vec is None:
+            eig_vec = torch.Tensor([
+                [-0.5675, +0.7192, +0.4009],
+                [-0.5808, -0.0045, -0.8140],
+                [-0.5836, -0.6948, +0.4203],
+            ]).t()
+        if eig_val is None:
+            eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]])
+        self.eig_val = eig_val  # 1*3
+        self.eig_vec = eig_vec  # 3*3
+
+    def pca(self, tensor):
+        assert tensor.size(0) == 3
+        alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1
+        reconst = torch.mm(self.eig_val * alpha, self.eig_vec)
+        tensor = tensor + reconst.view(3, 1, 1)
+
+        return tensor
+
+    def __call__(self, results):
+        img = results['img']
+        tensor = self.pca(img)
+        results['img'] = tensor
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        return repr_str
+
+
+@PIPELINES.register_module()
+class RandomPaddingOCR:
+    """Pad the given image on all sides, and modify the coordinates of
+    character bounding boxes accordingly.
+
+    Args:
+        max_ratio (list[float]): Maximum padding ratios for the
+            [left, top, right, bottom] sides, relative to the image width
+            (left/right) or height (top/bottom).
+        box_type (None|str): Character box type. If not none,
+            should be either 'char_rects' or 'char_quads', with
+            'char_rects' for rectangle with ``xyxy`` style and
+            'char_quads' for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
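+
+    Example (a pipeline entry using the default ratios; values are
+    illustrative):
+
+    .. code-block:: python
+
+        dict(
+            type='RandomPaddingOCR',
+            max_ratio=[0.1, 0.2, 0.1, 0.2],
+            box_type=None)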
+ """ + + def __init__(self, max_ratio=None, box_type=None): + if max_ratio is None: + max_ratio = [0.1, 0.2, 0.1, 0.2] + else: + assert utils.is_type_list(max_ratio, float) + assert len(max_ratio) == 4 + assert box_type is None or box_type in ('char_rects', 'char_quads') + + self.max_ratio = max_ratio + self.box_type = box_type + + def __call__(self, results): + + img_shape = results['img_shape'] + ori_height, ori_width = img_shape[:2] + + random_padding_left = round( + np.random.uniform(0, self.max_ratio[0]) * ori_width) + random_padding_top = round( + np.random.uniform(0, self.max_ratio[1]) * ori_height) + random_padding_right = round( + np.random.uniform(0, self.max_ratio[2]) * ori_width) + random_padding_bottom = round( + np.random.uniform(0, self.max_ratio[3]) * ori_height) + + img = np.copy(results['img']) + img = cv2.copyMakeBorder(img, random_padding_top, + random_padding_bottom, random_padding_left, + random_padding_right, cv2.BORDER_REPLICATE) + results['img'] = img + results['img_shape'] = img.shape + + if self.box_type is not None: + num_points = 2 if self.box_type == 'char_rects' else 4 + char_num = len(results['ann_info'][self.box_type]) + for i in range(char_num): + for j in range(num_points): + results['ann_info'][self.box_type][i][ + j * 2] += random_padding_left + results['ann_info'][self.box_type][i][ + j * 2 + 1] += random_padding_top + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class RandomRotateImageBox: + """Rotate augmentation for segmentation based text recognition. + + Args: + min_angle (int): Minimum rotation angle for image and box. + max_angle (int): Maximum rotation angle for image and box. + box_type (str): Character box type, should be either + 'char_rects' or 'char_quads', with 'char_rects' + for rectangle with ``xyxy`` style and 'char_quads' + for quadrangle with ``x1y1x2y2x3y3x4y4`` style. + """ + + def __init__(self, min_angle=-10, max_angle=10, box_type='char_quads'): + assert box_type in ('char_rects', 'char_quads') + + self.min_angle = min_angle + self.max_angle = max_angle + self.box_type = box_type + + def __call__(self, results): + in_img = results['img'] + in_chars = results['ann_info']['chars'] + in_boxes = results['ann_info'][self.box_type] + + img_width, img_height = in_img.size + rotate_center = [img_width / 2., img_height / 2.] + + tan_temp_max_angle = rotate_center[1] / rotate_center[0] + temp_max_angle = np.arctan(tan_temp_max_angle) * 180. / np.pi + + random_angle = np.random.uniform( + max(self.min_angle, -temp_max_angle), + min(self.max_angle, temp_max_angle)) + random_angle_radian = random_angle * np.pi / 180. 
+ + img_box = shapely_box(0, 0, img_width, img_height) + + out_img = TF.rotate( + in_img, + random_angle, + resample=False, + expand=False, + center=rotate_center) + + out_boxes, out_chars = self.rotate_bbox(in_boxes, in_chars, + random_angle_radian, + rotate_center, img_box) + + results['img'] = out_img + results['ann_info']['chars'] = out_chars + results['ann_info'][self.box_type] = out_boxes + + return results + + @staticmethod + def rotate_bbox(boxes, chars, angle, center, img_box): + out_boxes = [] + out_chars = [] + for idx, bbox in enumerate(boxes): + temp_bbox = [] + for i in range(len(bbox) // 2): + point = [bbox[2 * i], bbox[2 * i + 1]] + temp_bbox.append( + RandomRotateImageBox.rotate_point(point, angle, center)) + poly_temp_bbox = Polygon(temp_bbox).buffer(0) + if poly_temp_bbox.is_valid: + if img_box.intersects(poly_temp_bbox) and ( + not img_box.touches(poly_temp_bbox)): + temp_bbox_area = poly_temp_bbox.area + + intersect_area = img_box.intersection(poly_temp_bbox).area + intersect_ratio = intersect_area / temp_bbox_area + + if intersect_ratio >= 0.7: + out_box = [] + for p in temp_bbox: + out_box.extend(p) + out_boxes.append(out_box) + out_chars.append(chars[idx]) + + return out_boxes, out_chars + + @staticmethod + def rotate_point(point, angle, center): + cos_theta = math.cos(-angle) + sin_theta = math.sin(-angle) + c_x = center[0] + c_y = center[1] + new_x = (point[0] - c_x) * cos_theta - (point[1] - + c_y) * sin_theta + c_x + new_y = (point[0] - c_x) * sin_theta + (point[1] - + c_y) * cos_theta + c_y + + return [new_x, new_y] + + +@PIPELINES.register_module() +class OpencvToPil: + """Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb).""" + + def __init__(self, **kwargs): + pass + + def __call__(self, results): + img = results['img'][..., ::-1] + img = Image.fromarray(img) + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class PilToOpencv: + """Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr).""" + + def __init__(self, **kwargs): + pass + + def __call__(self, results): + img = np.asarray(results['img']) + img = img[..., ::-1] + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str diff --git a/mmocr/datasets/pipelines/test_time_aug.py b/mmocr/datasets/pipelines/test_time_aug.py new file mode 100644 index 00000000..5c8c1a60 --- /dev/null +++ b/mmocr/datasets/pipelines/test_time_aug.py @@ -0,0 +1,108 @@ +import mmcv +import numpy as np + +from mmdet.datasets.builder import PIPELINES +from mmdet.datasets.pipelines.compose import Compose + + +@PIPELINES.register_module() +class MultiRotateAugOCR: + """Test-time augmentation with multiple rotations in the case that + img_height > img_width. + + An example configuration is as follows: + + .. code-block:: + + rotate_degrees=[0, 90, 270], + transforms=[ + dict( + type='ResizeOCR', + height=32, + min_width=32, + max_width=160, + keep_aspect_ratio=True), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'valid_ratio' + ]), + ] + + After MultiRotateAugOCR with above configuration, the results are wrapped + into lists of the same length as follows: + + .. code-block:: + + dict( + img=[...], + img_shape=[...] + ... + ) + + Args: + transforms (list[dict]): Transformation applied for each augmentation. 
+ rotate_degrees (list[int] | None): Degrees of anti-clockwise rotation. + force_rotate (bool): If True, rotate image by 'rotate_degrees' + while ignore image aspect ratio. + """ + + def __init__(self, transforms, rotate_degrees=None, force_rotate=False): + self.transforms = Compose(transforms) + self.force_rotate = force_rotate + if rotate_degrees is not None: + self.rotate_degrees = rotate_degrees if isinstance( + rotate_degrees, list) else [rotate_degrees] + assert mmcv.is_list_of(self.rotate_degrees, int) + for degree in self.rotate_degrees: + assert 0 <= degree < 360 + assert degree % 90 == 0 + if 0 not in self.rotate_degrees: + self.rotate_degrees.append(0) + else: + self.rotate_degrees = [0] + + def __call__(self, results): + """Call function to apply test time augment transformation to results. + + Args: + results (dict): Result dict contains the data to be transformed. + + Returns: + dict[str: list]: The augmented data, where each value is wrapped + into a list. + """ + img_shape = results['img_shape'] + ori_height, ori_width = img_shape[:2] + if not self.force_rotate and ori_height <= ori_width: + rotate_degrees = [0] + else: + rotate_degrees = self.rotate_degrees + aug_data = [] + for degree in set(rotate_degrees): + _results = results.copy() + if degree == 0: + pass + elif degree == 90: + _results['img'] = np.rot90(_results['img'], 1) + elif degree == 180: + _results['img'] = np.rot90(_results['img'], 2) + elif degree == 270: + _results['img'] = np.rot90(_results['img'], 3) + data = self.transforms(_results) + aug_data.append(data) + # list of dict to dict of list + aug_data_dict = {key: [] for key in aug_data[0]} + for data in aug_data: + for key, val in data.items(): + aug_data_dict[key].append(val) + return aug_data_dict + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(transforms={self.transforms}, ' + repr_str += f'rotate_degrees={self.rotate_degrees})' + return repr_str diff --git a/mmocr/datasets/pipelines/textdet_targets/__init__.py b/mmocr/datasets/pipelines/textdet_targets/__init__.py new file mode 100644 index 00000000..1565e924 --- /dev/null +++ b/mmocr/datasets/pipelines/textdet_targets/__init__.py @@ -0,0 +1,10 @@ +from .base_textdet_targets import BaseTextDetTargets +from .dbnet_targets import DBNetTargets +from .panet_targets import PANetTargets +from .psenet_targets import PSENetTargets +from .textsnake_targets import TextSnakeTargets + +__all__ = [ + 'BaseTextDetTargets', 'PANetTargets', 'PSENetTargets', 'DBNetTargets', + 'TextSnakeTargets' +] diff --git a/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py b/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py new file mode 100644 index 00000000..183743f8 --- /dev/null +++ b/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py @@ -0,0 +1,168 @@ +import sys + +import cv2 +import numpy as np +import Polygon as plg +import pyclipper +from mmcv.utils import print_log + +import mmocr.utils.check_argument as check_argument + + +class BaseTextDetTargets: + """Generate text detector ground truths.""" + + def __init__(self): + pass + + def point2line(self, xs, ys, point_1, point_2): + """Compute the distance from point to a line. This is adapted from + https://github.com/MhLiao/DB. + + Args: + xs (ndarray): The x coordinates of size hxw. + ys (ndarray): The y coordinates of size hxw. + point_1 (ndarray): The first point with shape 1x2. + point_2 (ndarray): The second point with shape 1x2. + + Returns: + result (ndarray): The distance matrix of size hxw. 
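+
+        Example (a small sketch; distances from a 2 x 2 grid to the
+        segment from (0, 0) to (2, 0)):
+
+        .. code-block:: python
+
+            import numpy as np
+            ys, xs = np.mgrid[0:2, 0:2].astype(np.float32)
+            d = self.point2line(xs, ys, np.array([0, 0]), np.array([2, 0]))
+            # d[1, 0] -> 1.0 (one pixel above the segment)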
+ """ + # suppose a triangle with three edge abc with c=point_1 point_2 + # a^2 + a_square = np.square(xs - point_1[0]) + np.square(ys - point_1[1]) + # b^2 + b_square = np.square(xs - point_2[0]) + np.square(ys - point_2[1]) + # c^2 + c_square = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] - + point_2[1]) + # -cosC=(c^2-a^2-b^2)/2(ab) + neg_cos_c = ( + (c_square - a_square - b_square) / + (np.finfo(np.float32).eps + 2 * np.sqrt(a_square * b_square))) + # sinC^2=1-cosC^2 + square_sin = 1 - np.square(neg_cos_c) + square_sin = np.nan_to_num(square_sin) + # distance=a*b*sinC/c=a*h/c=2*area/c + result = np.sqrt(a_square * b_square * square_sin / + (np.finfo(np.float32).eps + c_square)) + # set result to minimum edge if C 0: + padded_polygon = np.array(padded_polygon[0]) + else: + print(f'padding {polygon} with {distance} gets {padded_polygon}') + padded_polygon = polygon.copy().astype(np.int32) + cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) + + x_min = padded_polygon[:, 0].min() + x_max = padded_polygon[:, 0].max() + y_min = padded_polygon[:, 1].min() + y_max = padded_polygon[:, 1].max() + width = x_max - x_min + 1 + height = y_max - y_min + 1 + + polygon[:, 0] = polygon[:, 0] - x_min + polygon[:, 1] = polygon[:, 1] - y_min + + xs = np.broadcast_to( + np.linspace(0, width - 1, num=width).reshape(1, width), + (height, width)) + ys = np.broadcast_to( + np.linspace(0, height - 1, num=height).reshape(height, 1), + (height, width)) + + distance_map = np.zeros((polygon.shape[0], height, width), + dtype=np.float32) + for i in range(polygon.shape[0]): + j = (i + 1) % polygon.shape[0] + absolute_distance = self.point2line(xs, ys, polygon[i], polygon[j]) + distance_map[i] = np.clip(absolute_distance / distance, 0, 1) + distance_map = distance_map.min(axis=0) + + x_min_valid = min(max(0, x_min), canvas.shape[1] - 1) + x_max_valid = min(max(0, x_max), canvas.shape[1] - 1) + y_min_valid = min(max(0, y_min), canvas.shape[0] - 1) + y_max_valid = min(max(0, y_max), canvas.shape[0] - 1) + canvas[y_min_valid:y_max_valid + 1, + x_min_valid:x_max_valid + 1] = np.fmax( + 1 - distance_map[y_min_valid - y_min:y_max_valid - y_max + + height, x_min_valid - x_min:x_max_valid - + x_max + width], + canvas[y_min_valid:y_max_valid + 1, + x_min_valid:x_max_valid + 1]) + + def generate_targets(self, results): + """Generate the gt targets for DBNet. + + Args: + results (dict): The input result dictionary. + + Returns: + results (dict): The output result dictionary. + """ + assert isinstance(results, dict) + polygons = results['gt_masks'].masks + if 'bbox_fields' in results: + results['bbox_fields'].clear() + ignore_tags = self.find_invalid(results) + h, w, _ = results['img_shape'] + + gt_shrink, ignore_tags = self.generate_kernels((h, w), + polygons, + self.shrink_ratio, + ignore_tags=ignore_tags) + + results = self.ignore_texts(results, ignore_tags) + + # polygons and polygons_ignore reassignment. 
+ polygons = results['gt_masks'].masks + polygons_ignore = results['gt_masks_ignore'].masks + + gt_shrink_mask = self.generate_effective_mask((h, w), polygons_ignore) + + gt_thr, gt_thr_mask = self.generate_thr_map((h, w), polygons) + + results['mask_fields'].clear() # rm gt_masks encoded by polygons + results.pop('gt_labels', None) + results.pop('gt_masks', None) + results.pop('gt_bboxes', None) + results.pop('gt_bboxes_ignore', None) + + mapping = { + 'gt_shrink': gt_shrink, + 'gt_shrink_mask': gt_shrink_mask, + 'gt_thr': gt_thr, + 'gt_thr_mask': gt_thr_mask + } + for key, value in mapping.items(): + value = value if isinstance(value, list) else [value] + results[key] = BitmapMasks(value, h, w) + results['mask_fields'].append(key) + + return results diff --git a/mmocr/datasets/pipelines/textdet_targets/panet_targets.py b/mmocr/datasets/pipelines/textdet_targets/panet_targets.py new file mode 100644 index 00000000..11b85283 --- /dev/null +++ b/mmocr/datasets/pipelines/textdet_targets/panet_targets.py @@ -0,0 +1,63 @@ +from mmdet.core import BitmapMasks +from mmdet.datasets.builder import PIPELINES +from . import BaseTextDetTargets + + +@PIPELINES.register_module() +class PANetTargets(BaseTextDetTargets): + """Generate the ground truths for PANet: Efficient and Accurate Arbitrary- + Shaped Text Detection with Pixel Aggregation Network. + + [https://arxiv.org/abs/1908.05900]. This code is partially adapted from + https://github.com/WenmuZhou/PAN.pytorch. + + Args: + shrink_ratio (tuple[float]): The ratios for shrinking text instances. + max_shrink (int): The maximum shrink distance. + """ + + def __init__(self, shrink_ratio=(1.0, 0.5), max_shrink=20): + self.shrink_ratio = shrink_ratio + self.max_shrink = max_shrink + + def generate_targets(self, results): + """Generate the gt targets for PANet. + + Args: + results (dict): The input result dictionary. + + Returns: + results (dict): The output result dictionary. + """ + + assert isinstance(results, dict) + + polygon_masks = results['gt_masks'].masks + polygon_masks_ignore = results['gt_masks_ignore'].masks + + h, w, _ = results['img_shape'] + gt_kernels = [] + for ratio in self.shrink_ratio: + mask, _ = self.generate_kernels((h, w), + polygon_masks, + ratio, + max_shrink=self.max_shrink, + ignore_tags=None) + gt_kernels.append(mask) + gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore) + + results['mask_fields'].clear() # rm gt_masks encoded by polygons + if 'bbox_fields' in results: + results['bbox_fields'].clear() + results.pop('gt_labels', None) + results.pop('gt_masks', None) + results.pop('gt_bboxes', None) + results.pop('gt_bboxes_ignore', None) + + mapping = {'gt_kernels': gt_kernels, 'gt_mask': gt_mask} + for key, value in mapping.items(): + value = value if isinstance(value, list) else [value] + results[key] = BitmapMasks(value, h, w) + results['mask_fields'].append(key) + + return results diff --git a/mmocr/datasets/pipelines/textdet_targets/psenet_targets.py b/mmocr/datasets/pipelines/textdet_targets/psenet_targets.py new file mode 100644 index 00000000..d78b7fd7 --- /dev/null +++ b/mmocr/datasets/pipelines/textdet_targets/psenet_targets.py @@ -0,0 +1,21 @@ +from mmdet.datasets.builder import PIPELINES +from . import PANetTargets + + +@PIPELINES.register_module() +class PSENetTargets(PANetTargets): + """Generate the ground truth targets of PSENet: Shape robust text detection + with progressive scale expansion network. + + [https://arxiv.org/abs/1903.12473]. 
This code is partially adapted from
+    https://github.com/whai362/PSENet.
+
+    Args:
+        shrink_ratio (tuple(float)): The ratios for shrinking text instances.
+        max_shrink (int): The maximum shrinking distance.
+    """
+
+    def __init__(self,
+                 shrink_ratio=(1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4),
+                 max_shrink=20):
+        super().__init__(shrink_ratio=shrink_ratio, max_shrink=max_shrink)
diff --git a/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py b/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py
new file mode 100644
index 00000000..eeebf5bd
--- /dev/null
+++ b/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py
@@ -0,0 +1,454 @@
+import cv2
+import numpy as np
+from numpy.linalg import norm
+
+import mmocr.utils.check_argument as check_argument
+from mmdet.core import BitmapMasks
+from mmdet.datasets.builder import PIPELINES
+from . import BaseTextDetTargets
+
+
+@PIPELINES.register_module()
+class TextSnakeTargets(BaseTextDetTargets):
+    """Generate the ground truth targets of TextSnake: A Flexible
+    Representation for Detecting Text of Arbitrary Shapes.
+
+    [https://arxiv.org/abs/1807.01544]. This was partially adapted from
+    https://github.com/princewang1994/TextSnake.pytorch.
+
+    Args:
+        orientation_thr (float): The threshold for distinguishing between
+            head edge and tail edge among the horizontal and vertical edges
+            of a quadrangle.
+        resample_step (float): The step size for resampling the sidelines
+            and center line of a text instance.
+        center_region_shrink_ratio (float): The shrink ratio of the text
+            center region.
+    """
+
+    def __init__(self,
+                 orientation_thr=2.0,
+                 resample_step=4.0,
+                 center_region_shrink_ratio=0.3):
+
+        super().__init__()
+        self.orientation_thr = orientation_thr
+        self.resample_step = resample_step
+        self.center_region_shrink_ratio = center_region_shrink_ratio
+
+    def vector_angle(self, vec1, vec2):
+        if vec1.ndim > 1:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
+        if vec2.ndim > 1:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
+        return np.arccos(
+            np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+
+    def vector_slope(self, vec):
+        assert len(vec) == 2
+        return abs(vec[1] / (vec[0] + 1e-8))
+
+    def vector_sin(self, vec):
+        assert len(vec) == 2
+        return vec[1] / (norm(vec) + 1e-8)
+
+    def vector_cos(self, vec):
+        assert len(vec) == 2
+        return vec[0] / (norm(vec) + 1e-8)
+
+    def find_head_tail(self, points, orientation_thr):
+        """Find the head edge and tail edge of a text polygon.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+            orientation_thr (float): The threshold for distinguishing between
+                head edge and tail edge among the horizontal and vertical
+                edges of a quadrangle.
+
+        Returns:
+            head_inds (list): The indexes of two points composing head edge.
+            tail_inds (list): The indexes of two points composing tail edge.
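+
+        Example (a minimal sketch; the wide quadrangle below is an assumed
+        input, whose short vertical edges are picked as head and tail):
+            >>> import numpy as np
+            >>> targets = TextSnakeTargets()
+            >>> points = np.array([[0, 0], [10, 0], [10, 2], [0, 2]])
+            >>> targets.find_head_tail(points, 2.0)
+            ([3, 0], [1, 2])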
+ """ + + assert points.ndim == 2 + assert points.shape[0] >= 4 + assert points.shape[1] == 2 + assert isinstance(orientation_thr, float) + + if len(points) > 4: + pad_points = np.vstack([points, points[0]]) + edge_vec = pad_points[1:] - pad_points[:-1] + + theta_sum = [] + + for i, edge_vec1 in enumerate(edge_vec): + adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]] + adjacent_edge_vec = edge_vec[adjacent_ind] + temp_theta_sum = np.sum( + self.vector_angle(edge_vec1, adjacent_edge_vec)) + theta_sum.append(temp_theta_sum) + theta_sum = np.array(theta_sum) + head_start, tail_start = np.argsort(theta_sum)[::-1][0:2] + + if (abs(head_start - tail_start) < 2 + or abs(head_start - tail_start) > 12): + tail_start = (head_start + len(points) // 2) % len(points) + head_end = (head_start + 1) % len(points) + tail_end = (tail_start + 1) % len(points) + + if head_end > tail_end: + head_start, tail_start = tail_start, head_start + head_end, tail_end = tail_end, head_end + head_inds = [head_start, head_end] + tail_inds = [tail_start, tail_end] + else: + if self.vector_slope(points[1] - points[0]) + self.vector_slope( + points[3] - points[2]) < self.vector_slope( + points[2] - points[1]) + self.vector_slope(points[0] - + points[3]): + horizontal_edge_inds = [[0, 1], [2, 3]] + vertical_edge_inds = [[3, 0], [1, 2]] + else: + horizontal_edge_inds = [[3, 0], [1, 2]] + vertical_edge_inds = [[0, 1], [2, 3]] + + vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - + points[vertical_edge_inds[0][1]]) + norm( + points[vertical_edge_inds[1][0]] - + points[vertical_edge_inds[1][1]]) + horizontal_len_sum = norm( + points[horizontal_edge_inds[0][0]] - + points[horizontal_edge_inds[0][1]]) + norm( + points[horizontal_edge_inds[1][0]] - + points[horizontal_edge_inds[1][1]]) + + if vertical_len_sum > horizontal_len_sum * orientation_thr: + head_inds = horizontal_edge_inds[0] + tail_inds = horizontal_edge_inds[1] + else: + head_inds = vertical_edge_inds[0] + tail_inds = vertical_edge_inds[1] + + return head_inds, tail_inds + + def reorder_poly_edge(self, points): + """Get the respective points composing head edge, tail edge, top + sideline and bottom sideline. + + Args: + points (ndarray): The points composing a text polygon. + + Returns: + head_edge (ndarray): The two points composing the head edge of text + polygon. + tail_edge (ndarray): The two points composing the tail edge of text + polygon. + top_sideline (ndarray): The points composing top curved sideline of + text polygon. + bot_sideline (ndarray): The points composing bottom curved sideline + of text polygon. + """ + + assert points.ndim == 2 + assert points.shape[0] >= 4 + assert points.shape[1] == 2 + + head_inds, tail_inds = self.find_head_tail(points, + self.orientation_thr) + head_edge, tail_edge = points[head_inds], points[tail_inds] + + pad_points = np.vstack([points, points]) + if tail_inds[1] < 1: + tail_inds[1] = len(points) + sideline1 = pad_points[head_inds[1]:tail_inds[1]] + sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))] + sideline_mean_shift = np.mean( + sideline1, axis=0) - np.mean( + sideline2, axis=0) + + if sideline_mean_shift[1] > 0: + top_sideline, bot_sideline = sideline2, sideline1 + else: + top_sideline, bot_sideline = sideline1, sideline2 + + return head_edge, tail_edge, top_sideline, bot_sideline + + def resample_line(self, line, n): + """Resample n points on a line. + + Args: + line (ndarray): The points composing a line. + n (int): The resampled points number. 
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled
+                line.
+        """
+
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+
+        length_list = [
+            norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
+        ]
+        total_length = sum(length_list)
+        length_cumsum = np.cumsum([0.0] + length_list)
+        delta_length = total_length / (float(n) + 1e-8)
+
+        current_edge_ind = 0
+        resampled_line = [line[0]]
+
+        for i in range(1, n):
+            current_line_len = i * delta_length
+
+            while current_line_len >= length_cumsum[current_edge_ind + 1]:
+                current_edge_ind += 1
+            current_edge_end_shift = current_line_len - length_cumsum[
+                current_edge_ind]
+            end_shift_ratio = current_edge_end_shift / length_list[
+                current_edge_ind]
+            current_point = line[current_edge_ind] + (
+                line[current_edge_ind + 1] -
+                line[current_edge_ind]) * end_shift_ratio
+            resampled_line.append(current_point)
+
+        resampled_line.append(line[-1])
+        resampled_line = np.array(resampled_line)
+
+        return resampled_line
+
+    def resample_sidelines(self, sideline1, sideline2, resample_step):
+        """Resample two sidelines so that they have the same number of
+        points, as determined by the step size.
+
+        Args:
+            sideline1 (ndarray): The points composing a sideline of a text
+                polygon.
+            sideline2 (ndarray): The points composing another sideline of a
+                text polygon.
+            resample_step (float): The resampling step size.
+
+        Returns:
+            resampled_line1 (ndarray): The resampled line 1.
+            resampled_line2 (ndarray): The resampled line 2.
+        """
+
+        assert sideline1.ndim == sideline2.ndim == 2
+        assert sideline1.shape[1] == sideline2.shape[1] == 2
+        assert sideline1.shape[0] >= 2
+        assert sideline2.shape[0] >= 2
+        assert isinstance(resample_step, float)
+
+        length1 = sum([
+            norm(sideline1[i + 1] - sideline1[i])
+            for i in range(len(sideline1) - 1)
+        ])
+        length2 = sum([
+            norm(sideline2[i + 1] - sideline2[i])
+            for i in range(len(sideline2) - 1)
+        ])
+
+        total_length = (length1 + length2) / 2
+        resample_point_num = int(float(total_length) / resample_step)
+
+        resampled_line1 = self.resample_line(sideline1, resample_point_num)
+        resampled_line2 = self.resample_line(sideline2, resample_point_num)
+
+        return resampled_line1, resampled_line2
+
+    def draw_center_region_maps(self, top_line, bot_line, center_line,
+                                center_region_mask, radius_map, sin_map,
+                                cos_map, region_shrink_ratio):
+        """Draw attributes on the text center region.
+
+        Args:
+            top_line (ndarray): The points composing the top curved sideline
+                of the text polygon.
+            bot_line (ndarray): The points composing the bottom curved
+                sideline of the text polygon.
+            center_line (ndarray): The points composing the center line of
+                the text instance.
+            center_region_mask (ndarray): The text center region mask.
+            radius_map (ndarray): The map where the distance from each pixel
+                in the text center region to the sidelines will be drawn.
+            sin_map (ndarray): The map where vector_sin(theta) will be drawn
+                on text center regions. Theta is the angle between the
+                tangent line and the vector (1, 0).
+            cos_map (ndarray): The map where vector_cos(theta) will be drawn
+                on text center regions. Theta is the angle between the
+                tangent line and the vector (1, 0).
+            region_shrink_ratio (float): The shrink ratio of the text center
+                region.
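+
+        Example (a minimal sketch; the straight sidelines and map size below
+        are assumed values; all maps are filled in place):
+            >>> import numpy as np
+            >>> targets = TextSnakeTargets()
+            >>> top = np.array([[2., 2.], [10., 2.]])
+            >>> bot = np.array([[2., 6.], [10., 6.]])
+            >>> center = (top + bot) / 2
+            >>> mask = np.zeros((8, 12), np.uint8)
+            >>> radius = np.zeros((8, 12), np.float32)
+            >>> sin_map = np.zeros((8, 12), np.float32)
+            >>> cos_map = np.zeros((8, 12), np.float32)
+            >>> targets.draw_center_region_maps(top, bot, center, mask,
+            ...                                 radius, sin_map, cos_map, 0.3)
+            >>> bool(mask.any())
+            True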
+ """ + + assert top_line.shape == bot_line.shape == center_line.shape + assert (center_region_mask.shape == radius_map.shape == sin_map.shape + == cos_map.shape) + assert isinstance(region_shrink_ratio, float) + for i in range(0, len(center_line) - 1): + + top_mid_point = (top_line[i] + top_line[i + 1]) / 2 + bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2 + radius = norm(top_mid_point - bot_mid_point) / 2 + + text_direction = center_line[i + 1] - center_line[i] + sin_theta = self.vector_sin(text_direction) + cos_theta = self.vector_cos(text_direction) + + pnt_tl = center_line[i] + (top_line[i] - + center_line[i]) * region_shrink_ratio + pnt_tr = center_line[i + 1] + ( + top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio + pnt_br = center_line[i + 1] + ( + bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio + pnt_bl = center_line[i] + (bot_line[i] - + center_line[i]) * region_shrink_ratio + current_center_box = np.vstack([pnt_tl, pnt_tr, pnt_br, + pnt_bl]).astype(np.int32) + + cv2.fillPoly(center_region_mask, [current_center_box], color=1) + cv2.fillPoly(sin_map, [current_center_box], color=sin_theta) + cv2.fillPoly(cos_map, [current_center_box], color=cos_theta) + cv2.fillPoly(radius_map, [current_center_box], color=radius) + + def generate_center_mask_attrib_maps(self, img_size, text_polys): + """Generate text center region mask and geometric attribute maps. + + Args: + img_size (tuple): The image size of (height, width). + text_polys (list[list[ndarray]]): The list of text polygons. + + Returns: + center_region_mask (ndarray): The text center region mask. + radius_map (ndarray): The distance map from each pixel in text + center region to top sideline. + sin_map (ndarray): The sin(theta) map where theta is the angle + between vector (top point - bottom point) and vector (1, 0). + cos_map (ndarray): The cos(theta) map where theta is the angle + between vector (top point - bottom point) and vector (1, 0). 
+ """ + + assert isinstance(img_size, tuple) + assert check_argument.is_2dlist(text_polys) + + h, w = img_size + + center_region_mask = np.zeros((h, w), np.uint8) + radius_map = np.zeros((h, w), dtype=np.float32) + sin_map = np.zeros((h, w), dtype=np.float32) + cos_map = np.zeros((h, w), dtype=np.float32) + + for poly in text_polys: + assert len(poly) == 1 + text_instance = [[poly[0][i], poly[0][i + 1]] + for i in range(0, len(poly[0]), 2)] + polygon_points = np.array( + text_instance, dtype=np.int32).reshape(-1, 2) + + _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points) + resampled_top_line, resampled_bot_line = self.resample_sidelines( + top_line, bot_line, self.resample_step) + resampled_bot_line = resampled_bot_line[::-1] + center_line = (resampled_top_line + resampled_bot_line) / 2 + + if self.vector_slope(center_line[-1] - center_line[0]) > 0.9: + if (center_line[-1] - center_line[0])[1] < 0: + center_line = center_line[::-1] + resampled_top_line = resampled_top_line[::-1] + resampled_bot_line = resampled_bot_line[::-1] + else: + if (center_line[-1] - center_line[0])[0] < 0: + center_line = center_line[::-1] + resampled_top_line = resampled_top_line[::-1] + resampled_bot_line = resampled_bot_line[::-1] + + line_head_shrink_len = norm(resampled_top_line[0] - + resampled_bot_line[0]) / 4.0 + line_tail_shrink_len = norm(resampled_top_line[-1] - + resampled_bot_line[-1]) / 4.0 + head_shrink_num = int(line_head_shrink_len // self.resample_step) + tail_shrink_num = int(line_tail_shrink_len // self.resample_step) + + if len(center_line) > head_shrink_num + tail_shrink_num + 2: + center_line = center_line[head_shrink_num:len(center_line) - + tail_shrink_num] + resampled_top_line = resampled_top_line[ + head_shrink_num:len(resampled_top_line) - tail_shrink_num] + resampled_bot_line = resampled_bot_line[ + head_shrink_num:len(resampled_bot_line) - tail_shrink_num] + + self.draw_center_region_maps(resampled_top_line, + resampled_bot_line, center_line, + center_region_mask, radius_map, + sin_map, cos_map, + self.center_region_shrink_ratio) + + return center_region_mask, radius_map, sin_map, cos_map + + def generate_text_region_mask(self, img_size, text_polys): + """Generate text center region mask and geometry attribute maps. + + Args: + img_size (tuple): The image size (height, width). + text_polys (list[list[ndarray]]): The list of text polygons. + + Returns: + text_region_mask (ndarray): The text region mask. + """ + + assert isinstance(img_size, tuple) + assert check_argument.is_2dlist(text_polys) + + h, w = img_size + text_region_mask = np.zeros((h, w), dtype=np.uint8) + + for poly in text_polys: + assert len(poly) == 1 + text_instance = [[poly[0][i], poly[0][i + 1]] + for i in range(0, len(poly[0]), 2)] + polygon = np.array( + text_instance, dtype=np.int32).reshape((1, -1, 2)) + cv2.fillPoly(text_region_mask, polygon, 1) + + return text_region_mask + + def generate_targets(self, results): + """Generate the gt targets for TextSnake. + + Args: + results (dict): The input result dictionary. + + Returns: + results (dict): The output result dictionary. 
+ """ + + assert isinstance(results, dict) + + polygon_masks = results['gt_masks'].masks + polygon_masks_ignore = results['gt_masks_ignore'].masks + + h, w, _ = results['img_shape'] + + gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks) + gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore) + + (gt_center_region_mask, gt_radius_map, gt_sin_map, + gt_cos_map) = self.generate_center_mask_attrib_maps((h, w), + polygon_masks) + + results['mask_fields'].clear() # rm gt_masks encoded by polygons + mapping = { + 'gt_text_mask': gt_text_mask, + 'gt_center_region_mask': gt_center_region_mask, + 'gt_mask': gt_mask, + 'gt_radius_map': gt_radius_map, + 'gt_sin_map': gt_sin_map, + 'gt_cos_map': gt_cos_map + } + for key, value in mapping.items(): + value = value if isinstance(value, list) else [value] + results[key] = BitmapMasks(value, h, w) + results['mask_fields'].append(key) + + return results diff --git a/mmocr/datasets/pipelines/transforms.py b/mmocr/datasets/pipelines/transforms.py new file mode 100644 index 00000000..537c107e --- /dev/null +++ b/mmocr/datasets/pipelines/transforms.py @@ -0,0 +1,727 @@ +import math + +import cv2 +import numpy as np +import torchvision.transforms as transforms +from PIL import Image + +import mmocr.core.evaluation.utils as eval_utils +from mmdet.core import BitmapMasks, PolygonMasks +from mmdet.datasets.builder import PIPELINES +from mmdet.datasets.pipelines.transforms import Resize +from mmocr.utils import check_argument + + +@PIPELINES.register_module() +class RandomCropInstances: + """Randomly crop images and make sure to contain text instances. + + Args: + target_size (tuple or int): (height, width) + positive_sample_ratio (float): The probability of sampling regions + that go through positive regions. 
+ """ + + def __init__( + self, + target_size, + instance_key, + mask_type='inx0', # 'inx0' or 'union_all' + positive_sample_ratio=5.0 / 8.0): + + assert mask_type in ['inx0', 'union_all'] + + self.mask_type = mask_type + self.instance_key = instance_key + self.positive_sample_ratio = positive_sample_ratio + self.target_size = target_size if (target_size is None or isinstance( + target_size, tuple)) else (target_size, target_size) + + def sample_offset(self, img_gt, img_size): + h, w = img_size + t_h, t_w = self.target_size + + # target size is bigger than origin size + t_h = t_h if t_h < h else h + t_w = t_w if t_w < w else w + if (img_gt is not None + and np.random.random_sample() < self.positive_sample_ratio + and np.max(img_gt) > 0): + + # make sure to crop the positive region + + # the minimum top left to crop positive region (h,w) + tl = np.min(np.where(img_gt > 0), axis=1) - (t_h, t_w) + tl[tl < 0] = 0 + # the maximum top left to crop positive region + br = np.max(np.where(img_gt > 0), axis=1) - (t_h, t_w) + br[br < 0] = 0 + # if br is too big so that crop the outside region of img + br[0] = min(br[0], h - t_h) + br[1] = min(br[1], w - t_w) + # + h = np.random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 + w = np.random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 + else: + # make sure not to crop outside of img + + h = np.random.randint(0, h - t_h) if h - t_h > 0 else 0 + w = np.random.randint(0, w - t_w) if w - t_w > 0 else 0 + + return (h, w) + + @staticmethod + def crop_img(img, offset, target_size): + h, w = img.shape[:2] + br = np.min( + np.stack((np.array(offset) + np.array(target_size), np.array( + (h, w)))), + axis=0) + return img[offset[0]:br[0], offset[1]:br[1]], np.array( + [offset[1], offset[0], br[1], br[0]]) + + def crop_bboxes(self, bboxes, canvas_bbox): + kept_bboxes = [] + kept_inx = [] + canvas_poly = eval_utils.box2polygon(canvas_bbox) + tl = canvas_bbox[0:2] + + for inx, bbox in enumerate(bboxes): + poly = eval_utils.box2polygon(bbox) + area, inters = eval_utils.poly_intersection(poly, canvas_poly) + if area == 0: + continue + xmin, xmax, ymin, ymax = inters.boundingBox() + kept_bboxes += [ + np.array( + [xmin - tl[0], ymin - tl[1], xmax - tl[0], ymax - tl[1]], + dtype=np.float32) + ] + kept_inx += [inx] + + if len(kept_inx) == 0: + return np.array([]).astype(np.float32).reshape(0, 4), kept_inx + + return np.stack(kept_bboxes), kept_inx + + @staticmethod + def generate_mask(gt_mask, type): + + if type == 'inx0': + return gt_mask.masks[0] + if type == 'union_all': + mask = gt_mask.masks[0].copy() + for inx in range(1, len(gt_mask.masks)): + mask = np.logical_or(mask, gt_mask.masks[inx]) + return mask + + raise NotImplementedError + + def __call__(self, results): + + gt_mask = results[self.instance_key] + mask = None + if len(gt_mask.masks) > 0: + mask = self.generate_mask(gt_mask, self.mask_type) + results['crop_offset'] = self.sample_offset(mask, + results['img'].shape[:2]) + + # crop img. 
bbox = [x1,y1,x2,y2] + img, bbox = self.crop_img(results['img'], results['crop_offset'], + self.target_size) + results['img'] = img + img_shape = img.shape + results['img_shape'] = img_shape + + # crop masks + for key in results.get('mask_fields', []): + results[key] = results[key].crop(bbox) + + # for mask rcnn + for key in results.get('bbox_fields', []): + results[key], kept_inx = self.crop_bboxes(results[key], bbox) + if key == 'gt_bboxes': + # ignore gt_labels accordingly + if 'gt_labels' in results: + ori_labels = results['gt_labels'] + ori_inst_num = len(ori_labels) + results['gt_labels'] = [ + ori_labels[inx] for inx in range(ori_inst_num) + if inx in kept_inx + ] + # ignore g_masks accordingly + if 'gt_masks' in results: + ori_mask = results['gt_masks'].masks + kept_mask = [ + ori_mask[inx] for inx in range(ori_inst_num) + if inx in kept_inx + ] + target_h, target_w = bbox[3] - bbox[1], bbox[2] - bbox[0] + if len(kept_inx) > 0: + kept_mask = np.stack(kept_mask) + else: + kept_mask = np.empty((0, target_h, target_w), + dtype=np.float32) + results['gt_masks'] = BitmapMasks(kept_mask, target_h, + target_w) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class RandomRotateTextDet: + """Randomly rotate images.""" + + def __init__(self, rotate_ratio=1.0, max_angle=10): + self.rotate_ratio = rotate_ratio + self.max_angle = max_angle + + @staticmethod + def sample_angle(max_angle): + angle = np.random.random_sample() * 2 * max_angle - max_angle + return angle + + @staticmethod + def rotate_img(img, angle): + h, w = img.shape[:2] + rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) + img_target = cv2.warpAffine( + img, rotation_matrix, (w, h), flags=cv2.INTER_NEAREST) + assert img_target.shape == img.shape + return img_target + + def __call__(self, results): + if np.random.random_sample() < self.rotate_ratio: + # rotate imgs + results['rotated_angle'] = self.sample_angle(self.max_angle) + img = self.rotate_img(results['img'], results['rotated_angle']) + results['img'] = img + img_shape = img.shape + results['img_shape'] = img_shape + + # rotate masks + for key in results.get('mask_fields', []): + masks = results[key].masks + mask_list = [] + for m in masks: + rotated_m = self.rotate_img(m, results['rotated_angle']) + mask_list.append(rotated_m) + results[key] = BitmapMasks(mask_list, *(img_shape[:2])) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class ColorJitter: + """An interface for torch color jitter so that it can be invoked in + mmdetection pipeline.""" + + def __init__(self, **kwargs): + self.transform = transforms.ColorJitter(**kwargs) + + def __call__(self, results): + # img is bgr + img = results['img'][..., ::-1] + img = Image.fromarray(img) + img = self.transform(img) + img = np.asarray(img) + img = img[..., ::-1] + results['img'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class ScaleAspectJitter(Resize): + """Resize image and segmentation mask encoded by coordinates. + + Allowed resize types are `around_min_img_scale`, `long_short_bound`, and + `indep_sample_in_range`. 
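+
+        Example (a config sketch in the style of typical detection pipelines;
+        the scale numbers below are assumptions, not defaults):
+            >>> jitter = ScaleAspectJitter(
+            ...     img_scale=[(3000, 640)],
+            ...     ratio_range=(0.7, 1.3),
+            ...     aspect_ratio_range=(0.9, 1.1),
+            ...     multiscale_mode='value',
+            ...     keep_ratio=False,
+            ...     resize_type='around_min_img_scale')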
+ """ + + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=False, + resize_type='around_min_img_scale', + aspect_ratio_range=None, + long_size_bound=None, + short_size_bound=None, + scale_range=None): + super().__init__( + img_scale=img_scale, + multiscale_mode=multiscale_mode, + ratio_range=ratio_range, + keep_ratio=keep_ratio) + assert not keep_ratio + assert resize_type in [ + 'around_min_img_scale', 'long_short_bound', 'indep_sample_in_range' + ] + self.resize_type = resize_type + + if resize_type == 'indep_sample_in_range': + assert ratio_range is None + assert aspect_ratio_range is None + assert short_size_bound is None + assert long_size_bound is None + assert scale_range is not None + else: + assert scale_range is None + assert isinstance(ratio_range, tuple) + assert isinstance(aspect_ratio_range, tuple) + assert check_argument.equal_len(ratio_range, aspect_ratio_range) + + if resize_type in ['long_short_bound']: + assert short_size_bound is not None + assert long_size_bound is not None + + self.aspect_ratio_range = aspect_ratio_range + self.long_size_bound = long_size_bound + self.short_size_bound = short_size_bound + self.scale_range = scale_range + + @staticmethod + def sample_from_range(range): + assert len(range) == 2 + min_value, max_value = min(range), max(range) + value = np.random.random_sample() * (max_value - min_value) + min_value + + return value + + def _random_scale(self, results): + + if self.resize_type == 'indep_sample_in_range': + w = self.sample_from_range(self.scale_range) + h = self.sample_from_range(self.scale_range) + results['scale'] = (int(w), int(h)) # (w,h) + results['scale_idx'] = None + return + h, w = results['img'].shape[0:2] + if self.resize_type == 'long_short_bound': + scale1 = 1 + if max(h, w) > self.long_size_bound: + scale1 = self.long_size_bound / max(h, w) + scale2 = self.sample_from_range(self.ratio_range) + scale = scale1 * scale2 + if min(h, w) * scale <= self.short_size_bound: + scale = (self.short_size_bound + 10) * 1.0 / min(h, w) + elif self.resize_type == 'around_min_img_scale': + short_size = min(self.img_scale[0]) + ratio = self.sample_from_range(self.ratio_range) + scale = (ratio * short_size) / min(h, w) + else: + raise NotImplementedError + + aspect = self.sample_from_range(self.aspect_ratio_range) + h_scale = scale * math.sqrt(aspect) + w_scale = scale / math.sqrt(aspect) + results['scale'] = (int(w * w_scale), int(h * h_scale)) # (w,h) + results['scale_idx'] = None + + +@PIPELINES.register_module() +class AffineJitter: + """An interface for torchvision random affine so that it can be invoked in + mmdet pipeline.""" + + def __init__(self, + degrees=4, + translate=(0.02, 0.04), + scale=(0.9, 1.1), + shear=None, + resample=False, + fillcolor=0): + self.transform = transforms.RandomAffine( + degrees=degrees, + translate=translate, + scale=scale, + shear=shear, + resample=resample, + fillcolor=fillcolor) + + def __call__(self, results): + # img is bgr + img = results['img'][..., ::-1] + img = Image.fromarray(img) + img = self.transform(img) + img = np.asarray(img) + img = img[..., ::-1] + results['img'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class RandomCropPolyInstances: + """Randomly crop images and make sure to contain at least one intact + instance.""" + + def __init__(self, + instance_key='gt_masks', + crop_ratio=5.0 / 8.0, + min_side_ratio=0.4): + super().__init__() + 
self.instance_key = instance_key + self.crop_ratio = crop_ratio + self.min_side_ratio = min_side_ratio + + def sample_valid_start_end(self, valid_array, min_len, max_start, min_end): + + assert isinstance(min_len, int) + assert len(valid_array) > min_len + + start_array = valid_array.copy() + max_start = min(len(start_array) - min_len, max_start) + start_array[max_start:] = 0 + start_array[0] = 1 + diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0]) + region_starts = np.where(diff_array < 0)[0] + region_ends = np.where(diff_array > 0)[0] + region_ind = np.random.randint(0, len(region_starts)) + start = np.random.randint(region_starts[region_ind], + region_ends[region_ind]) + + end_array = valid_array.copy() + min_end = max(start + min_len, min_end) + end_array[:min_end] = 0 + end_array[-1] = 1 + diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0]) + region_starts = np.where(diff_array < 0)[0] + region_ends = np.where(diff_array > 0)[0] + region_ind = np.random.randint(0, len(region_starts)) + end = np.random.randint(region_starts[region_ind], + region_ends[region_ind]) + return start, end + + def sample_crop_box(self, img_size, masks): + """Generate crop box and make sure not to crop the polygon instances. + + Args: + img_size (tuple(int)): The image size. + masks (list[list[ndarray]]): The polygon masks. + """ + + assert isinstance(img_size, tuple) + h, w = img_size[:2] + + x_valid_array = np.ones(w, dtype=np.int32) + y_valid_array = np.ones(h, dtype=np.int32) + + selected_mask = masks[np.random.randint(0, len(masks))] + selected_mask = selected_mask[0].reshape((-1, 2)).astype(np.int32) + max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0) + min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1) + max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0) + min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1) + + for mask in masks: + assert len(mask) == 1 + mask = mask[0].reshape((-1, 2)).astype(np.int32) + clip_x = np.clip(mask[:, 0], 0, w - 1) + clip_y = np.clip(mask[:, 1], 0, h - 1) + min_x, max_x = np.min(clip_x), np.max(clip_x) + min_y, max_y = np.min(clip_y), np.max(clip_y) + + x_valid_array[min_x - 2:max_x + 3] = 0 + y_valid_array[min_y - 2:max_y + 3] = 0 + + min_w = int(w * self.min_side_ratio) + min_h = int(h * self.min_side_ratio) + + x1, x2 = self.sample_valid_start_end(x_valid_array, min_w, max_x_start, + min_x_end) + y1, y2 = self.sample_valid_start_end(y_valid_array, min_h, max_y_start, + min_y_end) + + return np.array([x1, y1, x2, y2]) + + def crop_img(self, img, bbox): + assert img.ndim == 3 + h, w, _ = img.shape + assert 0 <= bbox[1] < bbox[3] <= h + assert 0 <= bbox[0] < bbox[2] <= w + return img[bbox[1]:bbox[3], bbox[0]:bbox[2]] + + def __call__(self, results): + if np.random.random_sample() < self.crop_ratio: + crop_box = self.sample_crop_box(results['img'].shape, + results[self.instance_key].masks) + results['crop_region'] = crop_box + img = self.crop_img(results['img'], crop_box) + results['img'] = img + results['img_shape'] = img.shape + + # crop and filter masks + x1, y1, x2, y2 = crop_box + w = max(x2 - x1, 1) + h = max(y2 - y1, 1) + labels = results['gt_labels'] + valid_labels = [] + for key in results.get('mask_fields', []): + if len(results[key].masks) == 0: + continue + results[key] = results[key].crop(crop_box) + # filter out polygons beyond crop box. 
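+                # A polygon survives the crop only if every vertex lies
+                # within ~4 px of the crop box (the tolerance below); the
+                # survivors are then clipped to the [0, w] x [0, h] canvas.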
+ masks = results[key].masks + valid_masks_list = [] + + for ind, mask in enumerate(masks): + assert len(mask) == 1 + polygon = mask[0].reshape((-1, 2)) + if (polygon[:, 0] > + -4).all() and (polygon[:, 0] < w + 4).all() and ( + polygon[:, 1] > -4).all() and (polygon[:, 1] < + h + 4).all(): + mask[0][::2] = np.clip(mask[0][::2], 0, w) + mask[0][1::2] = np.clip(mask[0][1::2], 0, h) + if key == self.instance_key: + valid_labels.append(labels[ind]) + valid_masks_list.append(mask) + + results[key] = PolygonMasks(valid_masks_list, h, w) + results['gt_labels'] = np.array(valid_labels) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class RandomRotatePolyInstances: + + def __init__(self, + rotate_ratio=0.5, + max_angle=10, + pad_with_fixed_color=False, + pad_value=(0, 0, 0)): + """Randomly rotate images and polygon masks. + + Args: + rotate_ratio (float): The ratio of samples to operate rotation. + max_angle (int): The maximum rotation angle. + pad_with_fixed_color (bool): The flag for whether to pad rotated + image with fixed value. If set to False, the rotated image will + be padded onto cropped image. + pad_value (tuple(int)): The color value for padding rotated image. + """ + self.rotate_ratio = rotate_ratio + self.max_angle = max_angle + self.pad_with_fixed_color = pad_with_fixed_color + self.pad_value = pad_value + + def rotate(self, center, points, theta, center_shift=(0, 0)): + # rotate points. + (center_x, center_y) = center + center_y = -center_y + x, y = points[::2], points[1::2] + y = -y + + theta = theta / 180 * math.pi + cos = math.cos(theta) + sin = math.sin(theta) + + x = (x - center_x) + y = (y - center_y) + + _x = center_x + x * cos - y * sin + center_shift[0] + _y = -(center_y + x * sin + y * cos) + center_shift[1] + + points[::2], points[1::2] = _x, _y + return points + + def cal_canvas_size(self, ori_size, degree): + assert isinstance(ori_size, tuple) + angle = degree * math.pi / 180.0 + h, w = ori_size[:2] + + cos = math.cos(angle) + sin = math.sin(angle) + canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos)) + canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin)) + + canvas_size = (canvas_h, canvas_w) + return canvas_size + + def sample_angle(self, max_angle): + angle = np.random.random_sample() * 2 * max_angle - max_angle + return angle + + def rotate_img(self, img, angle, canvas_size): + h, w = img.shape[:2] + rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) + rotation_matrix[0, 2] += int((canvas_size[1] - w) / 2) + rotation_matrix[1, 2] += int((canvas_size[0] - h) / 2) + + if self.pad_with_fixed_color: + target_img = cv2.warpAffine( + img, + rotation_matrix, (canvas_size[1], canvas_size[0]), + flags=cv2.INTER_NEAREST, + borderValue=self.pad_value) + else: + mask = np.zeros_like(img) + (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8), + np.random.randint(0, w * 7 // 8)) + img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)] + img_cut = cv2.resize(img_cut, (canvas_size[1], canvas_size[0])) + mask = cv2.warpAffine( + mask, + rotation_matrix, (canvas_size[1], canvas_size[0]), + borderValue=[1, 1, 1]) + target_img = cv2.warpAffine( + img, + rotation_matrix, (canvas_size[1], canvas_size[0]), + borderValue=[0, 0, 0]) + target_img = target_img + img_cut * mask + + return target_img + + def __call__(self, results): + if np.random.random_sample() < self.rotate_ratio: + img = results['img'] + h, w = img.shape[:2] + angle = self.sample_angle(self.max_angle) + 
canvas_size = self.cal_canvas_size((h, w), angle) + center_shift = (int( + (canvas_size[1] - w) / 2), int((canvas_size[0] - h) / 2)) + + # rotate image + results['rotated_poly_angle'] = angle + img = self.rotate_img(img, angle, canvas_size) + results['img'] = img + img_shape = img.shape + results['img_shape'] = img_shape + + # rotate polygons + for key in results.get('mask_fields', []): + if len(results[key].masks) == 0: + continue + masks = results[key].masks + rotated_masks = [] + for mask in masks: + rotated_mask = self.rotate((w / 2, h / 2), mask[0], angle, + center_shift) + rotated_masks.append([rotated_mask]) + + results[key] = PolygonMasks(rotated_masks, *(img_shape[:2])) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +@PIPELINES.register_module() +class SquareResizePad: + + def __init__(self, + target_size, + pad_ratio=0.6, + pad_with_fixed_color=False, + pad_value=(0, 0, 0)): + """Resize or pad images to be square shape. + + Args: + target_size (int): The target size of square shaped image. + pad_with_fixed_color (bool): The flag for whether to pad rotated + image with fixed value. If set to False, the rescales image will + be padded onto cropped image. + pad_value (tuple(int)): The color value for padding rotated image. + """ + assert isinstance(target_size, int) + assert isinstance(pad_ratio, float) + assert isinstance(pad_with_fixed_color, bool) + assert isinstance(pad_value, tuple) + + self.target_size = target_size + self.pad_ratio = pad_ratio + self.pad_with_fixed_color = pad_with_fixed_color + self.pad_value = pad_value + + def resize_img(self, img, keep_ratio=True): + h, w, _ = img.shape + if keep_ratio: + t_h = self.target_size if h >= w else int(h * self.target_size / w) + t_w = self.target_size if h <= w else int(w * self.target_size / h) + else: + t_h = t_w = self.target_size + img = cv2.resize(img, (t_w, t_h)) + return img, (t_h, t_w) + + def square_pad(self, img): + h, w = img.shape[:2] + if h == w: + return img, (0, 0) + pad_size = max(h, w) + if self.pad_with_fixed_color: + expand_img = np.ones((pad_size, pad_size, 3), dtype=np.uint8) + expand_img[:] = self.pad_value + else: + (h_ind, w_ind) = (np.random.randint(0, h * 7 // 8), + np.random.randint(0, w * 7 // 8)) + img_cut = img[h_ind:(h_ind + h // 9), w_ind:(w_ind + w // 9)] + expand_img = cv2.resize(img_cut, (pad_size, pad_size)) + if h > w: + y0, x0 = 0, (h - w) // 2 + else: + y0, x0 = (w - h) // 2, 0 + expand_img[y0:y0 + h, x0:x0 + w] = img + offset = (x0, y0) + + return expand_img, offset + + def square_pad_mask(self, points, offset): + x0, y0 = offset + pad_points = points.copy() + pad_points[::2] = pad_points[::2] + x0 + pad_points[1::2] = pad_points[1::2] + y0 + return pad_points + + def __call__(self, results): + img = results['img'] + + if np.random.random_sample() < self.pad_ratio: + img, out_size = self.resize_img(img, keep_ratio=True) + img, offset = self.square_pad(img) + else: + img, out_size = self.resize_img(img, keep_ratio=False) + offset = (0, 0) + + results['img'] = img + results['img_shape'] = img.shape + + for key in results.get('mask_fields', []): + if len(results[key].masks) == 0: + continue + results[key] = results[key].resize(out_size) + masks = results[key].masks + processed_masks = [] + for mask in masks: + square_pad_mask = self.square_pad_mask(mask[0], offset) + processed_masks.append([square_pad_mask]) + + results[key] = PolygonMasks(processed_masks, *(img.shape[:2])) + + return results + + def __repr__(self): + repr_str = 
self.__class__.__name__ + return repr_str diff --git a/mmocr/datasets/text_det_dataset.py b/mmocr/datasets/text_det_dataset.py new file mode 100644 index 00000000..65a2d9d9 --- /dev/null +++ b/mmocr/datasets/text_det_dataset.py @@ -0,0 +1,121 @@ +import numpy as np + +from mmdet.datasets.builder import DATASETS +from mmocr.core.evaluation.hmean import eval_hmean +from mmocr.datasets.base_dataset import BaseDataset + + +@DATASETS.register_module() +class TextDetDataset(BaseDataset): + + def _parse_anno_info(self, annotations): + """Parse bbox and mask annotation. + Args: + annotations (dict): Annotations of one image. + + Returns: + dict: A dict containing the following keys: bboxes, bboxes_ignore, + labels, masks, masks_ignore. "masks" and + "masks_ignore" are represented by polygon boundary + point sequences. + """ + gt_bboxes, gt_bboxes_ignore = [], [] + gt_masks, gt_masks_ignore = [], [] + gt_labels = [] + for ann in annotations: + if ann.get('iscrowd', False): + gt_bboxes_ignore.append(ann['bbox']) + gt_masks_ignore.append(ann.get('segmentation', None)) + else: + gt_bboxes.append(ann['bbox']) + gt_labels.append(ann['category_id']) + gt_masks.append(ann.get('segmentation', None)) + if gt_bboxes: + gt_bboxes = np.array(gt_bboxes, dtype=np.float32) + gt_labels = np.array(gt_labels, dtype=np.int64) + else: + gt_bboxes = np.zeros((0, 4), dtype=np.float32) + gt_labels = np.array([], dtype=np.int64) + + if gt_bboxes_ignore: + gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32) + else: + gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) + + ann = dict( + bboxes=gt_bboxes, + labels=gt_labels, + bboxes_ignore=gt_bboxes_ignore, + masks_ignore=gt_masks_ignore, + masks=gt_masks) + + return ann + + def prepare_train_img(self, index): + """Get training data and annotations from pipeline. + + Args: + index (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys + introduced by pipeline. + """ + img_ann_info = self.data_infos[index] + img_info = { + 'filename': img_ann_info['file_name'], + 'height': img_ann_info['height'], + 'width': img_ann_info['width'] + } + ann_info = self._parse_anno_info(img_ann_info['annotations']) + results = dict(img_info=img_info, ann_info=ann_info) + results['bbox_fields'] = [] + results['mask_fields'] = [] + results['seg_fields'] = [] + self.pre_pipeline(results) + + return self.pipeline(results) + + def evaluate(self, + results, + metric='hmean-iou', + score_thr=0.3, + rank_list=None, + logger=None, + **kwargs): + """Evaluate the dataset. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + score_thr (float): Score threshold for prediction map. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + rank_list (str): json file used to save eval result + of each image after ranking. 
+ Returns: + dict[str: float] + """ + metrics = metric if isinstance(metric, list) else [metric] + allowed_metrics = ['hmean-iou', 'hmean-ic13'] + metrics = set(metrics) & set(allowed_metrics) + + img_infos = [] + ann_infos = [] + for i in range(len(self)): + img_ann_info = self.data_infos[i] + img_info = {'filename': img_ann_info['file_name']} + ann_info = self._parse_anno_info(img_ann_info['annotations']) + img_infos.append(img_info) + ann_infos.append(ann_info) + + eval_results = eval_hmean( + results, + img_infos, + ann_infos, + metrics=metrics, + score_thr=score_thr, + logger=logger, + rank_list=rank_list) + + return eval_results diff --git a/mmocr/datasets/utils/__init__.py b/mmocr/datasets/utils/__init__.py new file mode 100644 index 00000000..f014de7e --- /dev/null +++ b/mmocr/datasets/utils/__init__.py @@ -0,0 +1,4 @@ +from .loader import HardDiskLoader, LmdbLoader +from .parser import LineJsonParser, LineStrParser + +__all__ = ['HardDiskLoader', 'LmdbLoader', 'LineStrParser', 'LineJsonParser'] diff --git a/mmocr/datasets/utils/loader.py b/mmocr/datasets/utils/loader.py new file mode 100644 index 00000000..55a4d075 --- /dev/null +++ b/mmocr/datasets/utils/loader.py @@ -0,0 +1,108 @@ +import os.path as osp + +from mmocr.datasets.builder import LOADERS, build_parser + + +@LOADERS.register_module() +class Loader: + """Load annotation from annotation file, and parse instance information to + dict format with parser. + + Args: + ann_file (str): Annotation file path. + parser (dict): Dictionary to construct parser + to parse original annotation infos. + repeat (int): Repeated times of annotations. + """ + + def __init__(self, ann_file, parser, repeat=1): + assert isinstance(ann_file, str) + assert isinstance(repeat, int) + assert isinstance(parser, dict) + assert repeat > 0 + assert osp.exists(ann_file), f'{ann_file} is not exist' + + self.ori_data_infos = self._load(ann_file) + self.parser = build_parser(parser) + self.repeat = repeat + + def __len__(self): + return len(self.ori_data_infos) * self.repeat + + def _load(self, ann_file): + """Load annotation file.""" + raise NotImplementedError + + def __getitem__(self, index): + """Retrieve anno info of one instance with dict format.""" + return self.parser.get_item(self.ori_data_infos, index) + + +@LOADERS.register_module() +class HardDiskLoader(Loader): + """Load annotation file from hard disk to RAM. + + Args: + ann_file (str): Annotation file path. + """ + + def _load(self, ann_file): + data_ret = [] + with open(ann_file, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + data_ret.append(line) + + return data_ret + + +@LOADERS.register_module() +class LmdbLoader(Loader): + """Load annotation file with lmdb storage backend.""" + + def _load(self, ann_file): + lmdb_anno_obj = LmdbAnnFileBackend(ann_file) + + return lmdb_anno_obj + + +class LmdbAnnFileBackend: + """Lmdb storage backend for annotation file. + + Args: + lmdb_path (str): Lmdb file path. 
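+
+        Example (a usage sketch; the path is a placeholder, and the lmdb
+        file is assumed to store a 'total_number' key plus one line per
+        integer index):
+            >>> backend = LmdbAnnFileBackend('data/anns.lmdb')
+            >>> num_lines = len(backend)
+            >>> first_line = backend[0]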
+ """ + + def __init__(self, lmdb_path, coding='utf8'): + self.lmdb_path = lmdb_path + self.coding = coding + env = self._get_env() + with env.begin(write=False) as txn: + self.total_number = int( + txn.get('total_number'.encode(self.coding)).decode( + self.coding)) + + def __getitem__(self, index): + """Retrieval one line from lmdb file by index.""" + # only attach env to self when __getitem__ is called + # because env object cannot be pickle + if not hasattr(self, 'env'): + self.env = self._get_env() + + with self.env.begin(write=False) as txn: + line = txn.get(str(index).encode(self.coding)).decode(self.coding) + return line + + def __len__(self): + return self.total_number + + def _get_env(self): + import lmdb + return lmdb.open( + self.lmdb_path, + max_readers=1, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) diff --git a/mmocr/datasets/utils/parser.py b/mmocr/datasets/utils/parser.py new file mode 100644 index 00000000..a895e217 --- /dev/null +++ b/mmocr/datasets/utils/parser.py @@ -0,0 +1,69 @@ +import json + +from mmocr.datasets.builder import PARSERS + + +@PARSERS.register_module() +class LineStrParser: + """Parse string of one line in annotation file to dict format. + + Args: + keys (list[str]): Keys in result dict. + keys_idx (list[int]): Value index in sub-string list + for each key above. + separator (str): Separator to separate string to list of sub-string. + """ + + def __init__(self, + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' '): + assert isinstance(keys, list) + assert isinstance(keys_idx, list) + assert isinstance(separator, str) + assert len(keys) > 0 + assert len(keys) == len(keys_idx) + self.keys = keys + self.keys_idx = keys_idx + self.separator = separator + + def get_item(self, data_ret, index): + map_index = index % len(data_ret) + line_str = data_ret[map_index] + for split_key in self.separator: + if split_key != ' ': + line_str = line_str.replace(split_key, ' ') + line_str = line_str.split() + if len(line_str) <= max(self.keys_idx): + raise Exception( + f'key index: {max(self.keys_idx)} out of range: {line_str}') + + line_info = {} + for i, key in enumerate(self.keys): + line_info[key] = line_str[self.keys_idx[i]] + return line_info + + +@PARSERS.register_module() +class LineJsonParser: + """Parse json-string of one line in annotation file to dict format. + + Args: + keys (list[str]): Keys in both json-string and result dict. 
+ """ + + def __init__(self, keys=[], **kwargs): + assert isinstance(keys, list) + assert len(keys) > 0 + self.keys = keys + + def get_item(self, data_ret, index): + map_index = index % len(data_ret) + line_json_obj = json.loads(data_ret[map_index]) + line_info = {} + for key in self.keys: + if key not in line_json_obj: + raise Exception(f'key {key} not in line json {line_json_obj}') + line_info[key] = line_json_obj[key] + + return line_info diff --git a/mmocr/models/__init__.py b/mmocr/models/__init__.py new file mode 100644 index 00000000..529a6208 --- /dev/null +++ b/mmocr/models/__init__.py @@ -0,0 +1,16 @@ +from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS, + build_backbone, build_detector, build_loss) +from .builder import (CONVERTORS, DECODERS, ENCODERS, PREPROCESSOR, + build_convertor, build_decoder, build_encoder, + build_preprocessor) +from .common import * # noqa: F401,F403 +from .kie import * # noqa: F401,F403 +from .textdet import * # noqa: F401,F403 +from .textrecog import * # noqa: F401,F403 + +__all__ = [ + 'BACKBONES', 'DETECTORS', 'HEADS', 'LOSSES', 'NECKS', 'build_backbone', + 'build_detector', 'build_loss', 'CONVERTORS', 'ENCODERS', 'DECODERS', + 'PREPROCESSOR', 'build_convertor', 'build_encoder', 'build_decoder', + 'build_preprocessor' +] diff --git a/mmocr/models/builder.py b/mmocr/models/builder.py new file mode 100644 index 00000000..cd606a8a --- /dev/null +++ b/mmocr/models/builder.py @@ -0,0 +1,33 @@ +from mmcv.utils import Registry, build_from_cfg + +RECOGNIZERS = Registry('recognizer') +CONVERTORS = Registry('convertor') +ENCODERS = Registry('encoder') +DECODERS = Registry('decoder') +PREPROCESSOR = Registry('preprocessor') + + +def build_recognizer(cfg, train_cfg=None, test_cfg=None): + """Build recognizer.""" + return build_from_cfg(cfg, RECOGNIZERS, + dict(train_cfg=train_cfg, test_cfg=test_cfg)) + + +def build_convertor(cfg): + """Build label convertor for scene text recognizer.""" + return build_from_cfg(cfg, CONVERTORS) + + +def build_encoder(cfg): + """Build encoder for scene text recognizer.""" + return build_from_cfg(cfg, ENCODERS) + + +def build_decoder(cfg): + """Build decoder for scene text recognizer.""" + return build_from_cfg(cfg, DECODERS) + + +def build_preprocessor(cfg): + """Build preprocessor for scene text recognizer.""" + return build_from_cfg(cfg, PREPROCESSOR) diff --git a/mmocr/models/common/__init__.py b/mmocr/models/common/__init__.py new file mode 100644 index 00000000..8d662575 --- /dev/null +++ b/mmocr/models/common/__init__.py @@ -0,0 +1,2 @@ +from .backbones import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 diff --git a/mmocr/models/common/backbones/__init__.py b/mmocr/models/common/backbones/__init__.py new file mode 100644 index 00000000..de67ca96 --- /dev/null +++ b/mmocr/models/common/backbones/__init__.py @@ -0,0 +1,3 @@ +from .unet import UNet + +__all__ = ['UNet'] diff --git a/mmocr/models/common/backbones/unet.py b/mmocr/models/common/backbones/unet.py new file mode 100644 index 00000000..cc11a8f8 --- /dev/null +++ b/mmocr/models/common/backbones/unet.py @@ -0,0 +1,529 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer, + build_norm_layer, build_upsample_layer, constant_init, + kaiming_init) +from mmcv.runner import load_checkpoint +from mmcv.utils.parrots_wrapper import _BatchNorm + +from mmdet.models.builder import BACKBONES +from mmdet.utils import get_root_logger + + +class 
UpConvBlock(nn.Module): + """Upsample convolution block in decoder for UNet. + + This upsample convolution block consists of one upsample module + followed by one convolution block. The upsample module expands the + high-level low-resolution feature map and the convolution block fuses + the upsampled high-level low-resolution feature map and the low-level + high-resolution feature map from encoder. + + Args: + conv_block (nn.Sequential): Sequential of convolutional layers. + in_channels (int): Number of input channels of the high-level + skip_channels (int): Number of input channels of the low-level + high-resolution feature map from encoder. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers in the conv_block. + Default: 2. + stride (int): Stride of convolutional layer in conv_block. Default: 1. + dilation (int): Dilation rate of convolutional layer in conv_block. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). If the size of + high-level feature map is the same as that of skip feature map + (low-level feature map from encoder), it does not need upsample the + high-level feature map and the upsample_cfg is None. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + conv_block, + in_channels, + skip_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + dcn=None, + plugins=None): + super().__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.conv_block = conv_block( + in_channels=2 * skip_channels, + out_channels=out_channels, + num_convs=num_convs, + stride=stride, + dilation=dilation, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dcn=None, + plugins=None) + if upsample_cfg is not None: + self.upsample = build_upsample_layer( + cfg=upsample_cfg, + in_channels=in_channels, + out_channels=skip_channels, + with_cp=with_cp, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.upsample = ConvModule( + in_channels, + skip_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, skip, x): + """Forward function.""" + + x = self.upsample(x) + out = torch.cat([skip, x], dim=1) + out = self.conv_block(out) + + return out + + +class BasicConvBlock(nn.Module): + """Basic convolutional block for UNet. + + This module consists of several plain convolutional layers. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers. Default: 2. + stride (int): Whether use stride convolution to downsample + the input feature map. 
If stride=2, it only uses stride convolution + in the first convolutional layer to downsample the input feature + map. Options are 1 or 2. Default: 1. + dilation (int): Whether use dilated convolution to expand the + receptive field. Set dilation rate of each convolutional layer and + the dilation rate of the first convolutional layer is always 1. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + dcn=None, + plugins=None): + super().__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.with_cp = with_cp + convs = [] + for i in range(num_convs): + convs.append( + ConvModule( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride if i == 0 else 1, + dilation=1 if i == 0 else dilation, + padding=1 if i == 0 else dilation, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + self.convs = nn.Sequential(*convs) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.convs, x) + else: + out = self.convs(x) + return out + + +@UPSAMPLE_LAYERS.register_module() +class DeconvModule(nn.Module): + """Deconvolution upsample module in decoder for UNet (2X upsample). + + This module uses deconvolution to upsample feature map in the decoder + of UNet. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + kernel_size (int): Kernel size of the convolutional layer. Default: 4. 
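+
+        Example (a minimal sketch; channel counts and input size are assumed
+        values):
+            >>> import torch
+            >>> up = DeconvModule(64, 32)
+            >>> x = torch.rand(1, 64, 16, 16)
+            >>> up(x).shape
+            torch.Size([1, 32, 32, 32])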
+ """ + + def __init__(self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + *, + kernel_size=4, + scale_factor=2): + super().__init__() + + assert ( + kernel_size - scale_factor >= 0 + and (kernel_size - scale_factor) % 2 == 0), ( + f'kernel_size should be greater than or equal to scale_factor ' + f'and (kernel_size - scale_factor) should be even numbers, ' + f'while the kernel size is {kernel_size} and scale_factor is ' + f'{scale_factor}.') + + stride = scale_factor + padding = (kernel_size - scale_factor) // 2 + self.with_cp = with_cp + deconv = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding) + + _, norm = build_norm_layer(norm_cfg, out_channels) + activate = build_activation_layer(act_cfg) + self.deconv_upsamping = nn.Sequential(deconv, norm, activate) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.deconv_upsamping, x) + else: + out = self.deconv_upsamping(x) + return out + + +@UPSAMPLE_LAYERS.register_module() +class InterpConv(nn.Module): + """Interpolation upsample module in decoder for UNet. + + This module uses interpolation to upsample feature map in the decoder + of UNet. It consists of one interpolation upsample layer and one + convolutional layer. It can be one interpolation upsample layer followed + by one convolutional layer (conv_first=False) or one convolutional layer + followed by one interpolation upsample layer (conv_first=True). + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + conv_first (bool): Whether convolutional layer or interpolation + upsample layer first. Default: False. It means interpolation + upsample layer followed by one convolutional layer. + kernel_size (int): Kernel size of the convolutional layer. Default: 1. + stride (int): Stride of the convolutional layer. Default: 1. + padding (int): Padding of the convolutional layer. Default: 1. + upsample_cfg (dict): Interpolation config of the upsample layer. + Default: dict( + scale_factor=2, mode='bilinear', align_corners=False). 
+ """ + + def __init__(self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + *, + conv_cfg=None, + conv_first=False, + kernel_size=1, + stride=1, + padding=0, + upsample_cfg=dict( + scale_factor=2, mode='bilinear', align_corners=False)): + super().__init__() + + self.with_cp = with_cp + conv = ConvModule( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + upsample = nn.Upsample(**upsample_cfg) + if conv_first: + self.interp_upsample = nn.Sequential(conv, upsample) + else: + self.interp_upsample = nn.Sequential(upsample, conv) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.interp_upsample, x) + else: + out = self.interp_upsample(x) + return out + + +@BACKBONES.register_module() +class UNet(nn.Module): + """UNet backbone. + U-Net: Convolutional Networks for Biomedical Image Segmentation. + https://arxiv.org/pdf/1505.04597.pdf + + Args: + in_channels (int): Number of input image channels. Default" 3. + base_channels (int): Number of base channels of each stage. + The output channels of the first stage. Default: 64. + num_stages (int): Number of stages in encoder, normally 5. Default: 5. + strides (Sequence[int 1 | 2]): Strides of each stage in encoder. + len(strides) is equal to num_stages. Normally the stride of the + first stage in encoder is 1. If strides[i]=2, it uses stride + convolution to downsample in the correspondence encoder stage. + Default: (1, 1, 1, 1, 1). + enc_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence encoder stage. + Default: (2, 2, 2, 2, 2). + dec_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence decoder stage. + Default: (2, 2, 2, 2). + downsamples (Sequence[int]): Whether use MaxPool to downsample the + feature map after the first stage of encoder + (stages: [1, num_stages)). If the correspondence encoder stage use + stride convolution (strides[i]=2), it will never use MaxPool to + downsample, even downsamples[i-1]=True. + Default: (True, True, True, True). + enc_dilations (Sequence[int]): Dilation rate of each stage in encoder. + Default: (1, 1, 1, 1, 1). + dec_dilations (Sequence[int]): Dilation rate of each stage in decoder. + Default: (1, 1, 1, 1). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + + Notice: + The input image size should be divisible by the whole downsample rate + of the encoder. More detail of the whole downsample rate can be found + in UNet._check_input_divisible. 
+ + """ + + def __init__(self, + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False, + dcn=None, + plugins=None): + super().__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + assert len(strides) == num_stages, ( + 'The length of strides should be equal to num_stages, ' + f'while the strides is {strides}, the length of ' + f'strides is {len(strides)}, and the num_stages is ' + f'{num_stages}.') + assert len(enc_num_convs) == num_stages, ( + 'The length of enc_num_convs should be equal to num_stages, ' + f'while the enc_num_convs is {enc_num_convs}, the length of ' + f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is ' + f'{num_stages}.') + assert len(dec_num_convs) == (num_stages - 1), ( + 'The length of dec_num_convs should be equal to (num_stages-1), ' + f'while the dec_num_convs is {dec_num_convs}, the length of ' + f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is ' + f'{num_stages}.') + assert len(downsamples) == (num_stages - 1), ( + 'The length of downsamples should be equal to (num_stages-1), ' + f'while the downsamples is {downsamples}, the length of ' + f'downsamples is {len(downsamples)}, and the num_stages is ' + f'{num_stages}.') + assert len(enc_dilations) == num_stages, ( + 'The length of enc_dilations should be equal to num_stages, ' + f'while the enc_dilations is {enc_dilations}, the length of ' + f'enc_dilations is {len(enc_dilations)}, and the num_stages is ' + f'{num_stages}.') + assert len(dec_dilations) == (num_stages - 1), ( + 'The length of dec_dilations should be equal to (num_stages-1), ' + f'while the dec_dilations is {dec_dilations}, the length of ' + f'dec_dilations is {len(dec_dilations)}, and the num_stages is ' + f'{num_stages}.') + self.num_stages = num_stages + self.strides = strides + self.downsamples = downsamples + self.norm_eval = norm_eval + self.base_channels = base_channels + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + for i in range(num_stages): + enc_conv_block = [] + if i != 0: + if strides[i] == 1 and downsamples[i - 1]: + enc_conv_block.append(nn.MaxPool2d(kernel_size=2)) + upsample = (strides[i] != 1 or downsamples[i - 1]) + self.decoder.append( + UpConvBlock( + conv_block=BasicConvBlock, + in_channels=base_channels * 2**i, + skip_channels=base_channels * 2**(i - 1), + out_channels=base_channels * 2**(i - 1), + num_convs=dec_num_convs[i - 1], + stride=1, + dilation=dec_dilations[i - 1], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + upsample_cfg=upsample_cfg if upsample else None, + dcn=None, + plugins=None)) + + enc_conv_block.append( + BasicConvBlock( + in_channels=in_channels, + out_channels=base_channels * 2**i, + num_convs=enc_num_convs[i], + stride=strides[i], + dilation=enc_dilations[i], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dcn=None, + plugins=None)) + self.encoder.append((nn.Sequential(*enc_conv_block))) + in_channels = base_channels * 2**i + + def forward(self, x): + self._check_input_divisible(x) + enc_outs = [] + for enc in self.encoder: + x = enc(x) + enc_outs.append(x) + dec_outs = [x] + for i in 
reversed(range(len(self.decoder))):
+            x = self.decoder[i](enc_outs[i], x)
+            dec_outs.append(x)
+
+        return dec_outs
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keeping the
+        normalization layers frozen."""
+        super().train(mode)
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval() has an effect on BatchNorm only
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+    def _check_input_divisible(self, x):
+        h, w = x.shape[-2:]
+        whole_downsample_rate = 1
+        for i in range(1, self.num_stages):
+            if self.strides[i] == 2 or self.downsamples[i - 1]:
+                whole_downsample_rate *= 2
+        assert (
+            h % whole_downsample_rate == 0 and w % whole_downsample_rate == 0
+        ), (f'The input image size {(h, w)} should be divisible by the whole '
+            f'downsample rate {whole_downsample_rate}, when num_stages is '
+            f'{self.num_stages}, strides is {self.strides}, and downsamples '
+            f'is {self.downsamples}.')
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+        else:
+            raise TypeError('pretrained must be a str or None')
diff --git a/mmocr/models/common/losses/__init__.py b/mmocr/models/common/losses/__init__.py
new file mode 100644
index 00000000..7daa7345
--- /dev/null
+++ b/mmocr/models/common/losses/__init__.py
@@ -0,0 +1,3 @@
+from .dice_loss import DiceLoss
+
+__all__ = ['DiceLoss']
diff --git a/mmocr/models/common/losses/dice_loss.py b/mmocr/models/common/losses/dice_loss.py
new file mode 100644
index 00000000..0dfdf009
--- /dev/null
+++ b/mmocr/models/common/losses/dice_loss.py
@@ -0,0 +1,30 @@
+import torch
+import torch.nn as nn
+
+from mmdet.models.builder import LOSSES
+
+
+@LOSSES.register_module()
+class DiceLoss(nn.Module):
+
+    def __init__(self, eps=1e-6):
+        super().__init__()
+        assert isinstance(eps, float)
+        self.eps = eps
+
+    def forward(self, pred, target, mask=None):
+
+        pred = pred.contiguous().view(pred.size()[0], -1)
+        target = target.contiguous().view(target.size()[0], -1)
+
+        if mask is not None:
+            mask = mask.contiguous().view(mask.size()[0], -1)
+            pred = pred * mask
+            target = target * mask
+
+        a = torch.sum(pred * target)
+        b = torch.sum(pred)
+        c = torch.sum(target)
+        d = (2 * a) / (b + c + self.eps)
+
+        return 1 - d
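A small worked example of the dice loss just added (values chosen for illustration; `mmocr.models.common.losses` is the package introduced by this diff):

```python
# DiceLoss computes 1 - 2*sum(P*T) / (sum(P) + sum(T) + eps); a perfect
# match gives a loss near 0. Numbers below are illustrative.
import torch

from mmocr.models.common.losses import DiceLoss

loss_fn = DiceLoss(eps=1e-6)
pred = torch.tensor([[0.9, 0.1, 0.8, 0.2]])
target = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
print(loss_fn(pred, target))  # 1 - 3.4 / 4.0 = 0.15

mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])  # ignore the last two pixels
print(loss_fn(pred, target, mask))  # 1 - 1.8 / 2.0 = 0.10
```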
diff --git a/mmocr/models/kie/__init__.py b/mmocr/models/kie/__init__.py
new file mode 100644
index 00000000..46d98163
--- /dev/null
+++ b/mmocr/models/kie/__init__.py
@@ -0,0 +1,3 @@
+from .extractors import *  # noqa: F401,F403
+from .heads import *  # noqa: F401,F403
+from .losses import *  # noqa: F401,F403
diff --git a/mmocr/models/kie/extractors/__init__.py b/mmocr/models/kie/extractors/__init__.py
new file mode 100644
index 00000000..f58541e6
--- /dev/null
+++ b/mmocr/models/kie/extractors/__init__.py
@@ -0,0 +1,3 @@
+from .sdmgr import SDMGR
+
+__all__ = ['SDMGR']
diff --git a/mmocr/models/kie/extractors/sdmgr.py b/mmocr/models/kie/extractors/sdmgr.py
new file mode 100644
index 00000000..c8275a28
--- /dev/null
+++ b/mmocr/models/kie/extractors/sdmgr.py
@@ -0,0 +1,154 @@
+import warnings
+
+import mmcv
+from torch import nn
+from torch.nn import functional as F
+
+from mmdet.core import bbox2roi
+from mmdet.models.builder import DETECTORS, build_roi_extractor
+from mmdet.models.detectors import SingleStageDetector
+from mmocr.core import imshow_edge_node
+
+
+@DETECTORS.register_module()
+class SDMGR(SingleStageDetector):
+    """The implementation of the paper: Spatial Dual-Modality Graph Reasoning
+    for Key Information Extraction. https://arxiv.org/abs/2103.14470.
+
+    Args:
+        visual_modality (bool): Whether to use the visual modality.
+        class_list (None | str): Mapping file from class indexes to class
+            names. If None, class indexes are shown in `show_result`;
+            otherwise class names are shown.
+    """
+
+    def __init__(self,
+                 backbone,
+                 neck=None,
+                 bbox_head=None,
+                 extractor=dict(
+                     type='SingleRoIExtractor',
+                     roi_layer=dict(type='RoIAlign', output_size=7),
+                     featmap_strides=[1]),
+                 visual_modality=False,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None,
+                 class_list=None):
+        super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
+                         pretrained)
+        self.visual_modality = visual_modality
+        if visual_modality:
+            self.extractor = build_roi_extractor({
+                **extractor, 'out_channels':
+                self.backbone.base_channels
+            })
+            self.maxpool = nn.MaxPool2d(extractor['roi_layer']['output_size'])
+        else:
+            self.extractor = None
+        self.class_list = class_list
+
+    def forward_train(self, img, img_metas, relations, texts, gt_bboxes,
+                      gt_labels):
+        """
+        Args:
+            img (tensor): Input images of shape (N, C, H, W).
+                Typically these should be mean centered and std scaled.
+            img_metas (list[dict]): A list of image info dict where each dict
+                contains: 'img_shape', 'scale_factor', 'flip', and may also
+                contain 'filename', 'ori_shape', 'pad_shape', and
+                'img_norm_cfg'. For details of the values of these keys,
+                please see :class:`mmdet.datasets.pipelines.Collect`.
+            relations (list[tensor]): Relations between bboxes.
+            texts (list[tensor]): Texts in bboxes.
+            gt_bboxes (list[tensor]): Each item is the ground-truth boxes
+                for one image, in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (list[tensor]): Class indices corresponding to each
+                box.
+
+        Returns:
+            dict[str, tensor]: A dictionary of loss components.
+        """
+        x = self.extract_feat(img, gt_bboxes)
+        node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
+        return self.bbox_head.loss(node_preds, edge_preds, gt_labels)
+
+    def forward_test(self,
+                     img,
+                     img_metas,
+                     relations,
+                     texts,
+                     gt_bboxes,
+                     rescale=False):
+        x = self.extract_feat(img, gt_bboxes)
+        node_preds, edge_preds = self.bbox_head.forward(relations, texts, x)
+        return [
+            dict(
+                img_metas=img_metas,
+                nodes=F.softmax(node_preds, -1),
+                edges=F.softmax(edge_preds, -1))
+        ]
+
+    def extract_feat(self, img, gt_bboxes):
+        if self.visual_modality:
+            x = super().extract_feat(img)[-1]
+            feats = self.maxpool(self.extractor([x], bbox2roi(gt_bboxes)))
+            return feats.view(feats.size(0), -1)
+        return None
+
+    def show_result(self,
+                    img,
+                    result,
+                    boxes,
+                    win_name='',
+                    show=False,
+                    wait_time=0,
+                    out_file=None,
+                    **kwargs):
+        """Draw `result` on `img`.
+
+        Args:
+            img (str or tensor): The image to be displayed.
+            result (dict): The results to draw on `img`.
+            boxes (list): Bbox of img.
+            win_name (str): The window name.
+            wait_time (int): Value of waitKey param.
+                Default: 0.
+            show (bool): Whether to show the image.
+                Default: False.
+            out_file (str or None): The output filename.
+                Default: None.
+
+        Returns:
+            img (ndarray): The image with the drawn results. It is only
+                returned when neither `show` nor `out_file` is specified.
+ """ + img = mmcv.imread(img) + img = img.copy() + + idx_to_cls = {} + if self.class_list is not None: + with open(self.class_list, 'r') as fr: + for line in fr: + line = line.strip().split() + class_idx, class_label = line + idx_to_cls[class_idx] = class_label + + # if out_file specified, do not show image in window + if out_file is not None: + show = False + + img = imshow_edge_node( + img, + result, + boxes, + idx_to_cls=idx_to_cls, + show=show, + win_name=win_name, + wait_time=wait_time, + out_file=out_file) + + if not (show or out_file): + warnings.warn('show==False and out_file is not specified, only ' + 'result image will be returned') + return img + + return img diff --git a/mmocr/models/kie/heads/__init__.py b/mmocr/models/kie/heads/__init__.py new file mode 100644 index 00000000..00a11469 --- /dev/null +++ b/mmocr/models/kie/heads/__init__.py @@ -0,0 +1,3 @@ +from .sdmgr_head import SDMGRHead + +__all__ = ['SDMGRHead'] diff --git a/mmocr/models/kie/heads/sdmgr_head.py b/mmocr/models/kie/heads/sdmgr_head.py new file mode 100644 index 00000000..3fc0ba36 --- /dev/null +++ b/mmocr/models/kie/heads/sdmgr_head.py @@ -0,0 +1,193 @@ +import torch +from mmcv.cnn import normal_init +from torch import nn +from torch.nn import functional as F + +from mmdet.models.builder import HEADS, build_loss + + +@HEADS.register_module() +class SDMGRHead(nn.Module): + + def __init__(self, + num_chars=92, + visual_dim=64, + fusion_dim=1024, + node_input=32, + node_embed=256, + edge_input=5, + edge_embed=256, + num_gnn=2, + num_classes=26, + loss=dict(type='SDMGRLoss'), + bidirectional=False, + train_cfg=None, + test_cfg=None): + super().__init__() + + self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim) + self.node_embed = nn.Embedding(num_chars, node_input, 0) + hidden = node_embed // 2 if bidirectional else node_embed + self.rnn = nn.LSTM( + input_size=node_input, + hidden_size=hidden, + num_layers=1, + batch_first=True, + bidirectional=bidirectional) + self.edge_embed = nn.Linear(edge_input, edge_embed) + self.gnn_layers = nn.ModuleList( + [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)]) + self.node_cls = nn.Linear(node_embed, num_classes) + self.edge_cls = nn.Linear(edge_embed, 2) + self.loss = build_loss(loss) + + def init_weights(self, pretrained=False): + normal_init(self.edge_embed, mean=0, std=0.01) + + def forward(self, relations, texts, x=None): + node_nums, char_nums = [], [] + for text in texts: + node_nums.append(text.size(0)) + char_nums.append((text > 0).sum(-1)) + + max_num = max([char_num.max() for char_num in char_nums]) + all_nodes = torch.cat([ + torch.cat( + [text, + text.new_zeros(text.size(0), max_num - text.size(1))], -1) + for text in texts + ]) + embed_nodes = self.node_embed(all_nodes.clamp(min=0).long()) + rnn_nodes, _ = self.rnn(embed_nodes) + + nodes = rnn_nodes.new_zeros(*rnn_nodes.shape[::2]) + all_nums = torch.cat(char_nums) + valid = all_nums > 0 + nodes[valid] = rnn_nodes[valid].gather( + 1, (all_nums[valid] - 1).unsqueeze(-1).unsqueeze(-1).expand( + -1, -1, rnn_nodes.size(-1))).squeeze(1) + + if x is not None: + nodes = self.fusion([x, nodes]) + + all_edges = torch.cat( + [rel.view(-1, rel.size(-1)) for rel in relations]) + embed_edges = self.edge_embed(all_edges.float()) + embed_edges = F.normalize(embed_edges) + + for gnn_layer in self.gnn_layers: + nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums) + + node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes) + return node_cls, edge_cls + + +class GNNLayer(nn.Module): 
+ + def __init__(self, node_dim=256, edge_dim=256): + super().__init__() + self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim) + self.coef_fc = nn.Linear(node_dim, 1) + self.out_fc = nn.Linear(node_dim, node_dim) + self.relu = nn.ReLU() + + def forward(self, nodes, edges, nums): + start, cat_nodes = 0, [] + for num in nums: + sample_nodes = nodes[start:start + num] + cat_nodes.append( + torch.cat([ + sample_nodes.unsqueeze(1).expand(-1, num, -1), + sample_nodes.unsqueeze(0).expand(num, -1, -1) + ], -1).view(num**2, -1)) + start += num + cat_nodes = torch.cat([torch.cat(cat_nodes), edges], -1) + cat_nodes = self.relu(self.in_fc(cat_nodes)) + coefs = self.coef_fc(cat_nodes) + + start, residuals = 0, [] + for num in nums: + residual = F.softmax( + -torch.eye(num).to(coefs.device).unsqueeze(-1) * 1e9 + + coefs[start:start + num**2].view(num, num, -1), 1) + residuals.append( + (residual * + cat_nodes[start:start + num**2].view(num, num, -1)).sum(1)) + start += num**2 + + nodes += self.relu(self.out_fc(torch.cat(residuals))) + return nodes, cat_nodes + + +class Block(nn.Module): + + def __init__(self, + input_dims, + output_dim, + mm_dim=1600, + chunks=20, + rank=15, + shared=False, + dropout_input=0., + dropout_pre_lin=0., + dropout_output=0., + pos_norm='before_cat'): + super().__init__() + self.rank = rank + self.dropout_input = dropout_input + self.dropout_pre_lin = dropout_pre_lin + self.dropout_output = dropout_output + assert (pos_norm in ['before_cat', 'after_cat']) + self.pos_norm = pos_norm + # Modules + self.linear0 = nn.Linear(input_dims[0], mm_dim) + self.linear1 = ( + self.linear0 if shared else nn.Linear(input_dims[1], mm_dim)) + self.merge_linears0 = nn.ModuleList() + self.merge_linears1 = nn.ModuleList() + self.chunks = self.chunk_sizes(mm_dim, chunks) + for size in self.chunks: + ml0 = nn.Linear(size, size * rank) + self.merge_linears0.append(ml0) + ml1 = ml0 if shared else nn.Linear(size, size * rank) + self.merge_linears1.append(ml1) + self.linear_out = nn.Linear(mm_dim, output_dim) + + def forward(self, x): + x0 = self.linear0(x[0]) + x1 = self.linear1(x[1]) + bs = x1.size(0) + if self.dropout_input > 0: + x0 = F.dropout(x0, p=self.dropout_input, training=self.training) + x1 = F.dropout(x1, p=self.dropout_input, training=self.training) + x0_chunks = torch.split(x0, self.chunks, -1) + x1_chunks = torch.split(x1, self.chunks, -1) + zs = [] + for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks, + self.merge_linears0, + self.merge_linears1): + m = m0(x0_c) * m1(x1_c) # bs x split_size*rank + m = m.view(bs, self.rank, -1) + z = torch.sum(m, 1) + if self.pos_norm == 'before_cat': + z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z)) + z = F.normalize(z) + zs.append(z) + z = torch.cat(zs, 1) + if self.pos_norm == 'after_cat': + z = torch.sqrt(F.relu(z)) - torch.sqrt(F.relu(-z)) + z = F.normalize(z) + + if self.dropout_pre_lin > 0: + z = F.dropout(z, p=self.dropout_pre_lin, training=self.training) + z = self.linear_out(z) + if self.dropout_output > 0: + z = F.dropout(z, p=self.dropout_output, training=self.training) + return z + + @staticmethod + def chunk_sizes(dim, chunks): + split_size = (dim + chunks - 1) // chunks + sizes_list = [split_size] * chunks + sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim) + return sizes_list diff --git a/mmocr/models/kie/losses/__init__.py b/mmocr/models/kie/losses/__init__.py new file mode 100644 index 00000000..96b4afde --- /dev/null +++ b/mmocr/models/kie/losses/__init__.py @@ -0,0 +1,3 @@ +from .sdmgr_loss import SDMGRLoss + 
+__all__ = ['SDMGRLoss']
diff --git a/mmocr/models/kie/losses/sdmgr_loss.py b/mmocr/models/kie/losses/sdmgr_loss.py
new file mode 100644
index 00000000..9e1d2312
--- /dev/null
+++ b/mmocr/models/kie/losses/sdmgr_loss.py
@@ -0,0 +1,39 @@
+import torch
+from torch import nn
+
+from mmdet.models.builder import LOSSES
+from mmdet.models.losses import accuracy
+
+
+@LOSSES.register_module()
+class SDMGRLoss(nn.Module):
+    """The implementation of the loss for key information extraction
+    proposed in the paper: Spatial Dual-Modality Graph Reasoning for Key
+    Information Extraction.
+
+    https://arxiv.org/abs/2103.14470.
+
+    Args:
+        node_weight (float): Weight of the node classification loss.
+        edge_weight (float): Weight of the edge classification loss.
+        ignore (int): Node label to be ignored in the node loss.
+    """
+
+    def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0):
+        super().__init__()
+        self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
+        self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
+        self.node_weight = node_weight
+        self.edge_weight = edge_weight
+        self.ignore = ignore
+
+    def forward(self, node_preds, edge_preds, gts):
+        node_gts, edge_gts = [], []
+        for gt in gts:
+            node_gts.append(gt[:, 0])
+            edge_gts.append(gt[:, 1:].contiguous().view(-1))
+        node_gts = torch.cat(node_gts).long()
+        edge_gts = torch.cat(edge_gts).long()
+
+        node_valids = torch.nonzero(node_gts != self.ignore).view(-1)
+        edge_valids = torch.nonzero(edge_gts != -1).view(-1)
+        return dict(
+            loss_node=self.node_weight * self.loss_node(node_preds, node_gts),
+            loss_edge=self.edge_weight * self.loss_edge(edge_preds, edge_gts),
+            acc_node=accuracy(node_preds[node_valids], node_gts[node_valids]),
+            acc_edge=accuracy(edge_preds[edge_valids], edge_gts[edge_valids]))
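The ground-truth layout this loss expects is easy to get wrong, so here is an illustrative sketch: for each image, gt is an (N, 1 + N) tensor whose first column holds node labels and whose remaining N columns hold the flattened pairwise edge labels. All values below are hypothetical.

```python
# Hypothetical two-node image: node labels [3, 5]; edge matrix
# flattened row-wise into columns 1..N. Label -1 is ignored for edges,
# label `ignore` (0 here) is ignored for nodes.
import torch

from mmocr.models.kie.losses import SDMGRLoss

loss_fn = SDMGRLoss(node_weight=1.0, edge_weight=1.0, ignore=0)

gt = torch.tensor([[3, 0, 1],
                   [5, 1, 0]])
node_preds = torch.rand(2, 26)  # (num_nodes, num_classes)
edge_preds = torch.rand(4, 2)   # (num_nodes ** 2, 2)
print(loss_fn(node_preds, edge_preds, [gt]))
# dict with loss_node, loss_edge, acc_node, acc_edge
```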
+ """ + super().__init__() + + assert isinstance(in_channels, int) + + self.in_channels = in_channels + self.text_repr_type = text_repr_type + self.loss_module = build_loss(loss) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.downsample_ratio = downsample_ratio + self.decoding_type = decoding_type + + self.binarize = nn.Sequential( + nn.Conv2d( + in_channels, in_channels // 4, 3, bias=with_bias, padding=1), + nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True), + nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2), + nn.BatchNorm2d(in_channels // 4), nn.ReLU(inplace=True), + nn.ConvTranspose2d(in_channels // 4, 1, 2, 2), nn.Sigmoid()) + + self.threshold = self._init_thr(in_channels) + + def init_weights(self): + self.binarize.apply(self.init_class_parameters) + self.threshold.apply(self.init_class_parameters) + + def init_class_parameters(self, m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find('BatchNorm') != -1: + m.weight.data.fill_(1.) + m.bias.data.fill_(1e-4) + + def diff_binarize(self, prob_map, thr_map, k): + return torch.reciprocal(1.0 + torch.exp(-k * (prob_map - thr_map))) + + def forward(self, inputs): + prob_map = self.binarize(inputs) + thr_map = self.threshold(inputs) + binary_map = self.diff_binarize(prob_map, thr_map, k=50) + outputs = torch.cat((prob_map, thr_map, binary_map), dim=1) + return (outputs, ) + + def _init_thr(self, inner_channels, bias=False): + in_channels = inner_channels + seq = nn.Sequential( + nn.Conv2d( + in_channels, inner_channels // 4, 3, padding=1, bias=bias), + nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True), + nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2), + nn.BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True), + nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid()) + return seq diff --git a/mmocr/models/textdet/dense_heads/head_mixin.py b/mmocr/models/textdet/dense_heads/head_mixin.py new file mode 100644 index 00000000..567d334d --- /dev/null +++ b/mmocr/models/textdet/dense_heads/head_mixin.py @@ -0,0 +1,74 @@ +import numpy as np + +from mmdet.models.builder import HEADS +from mmocr.models.textdet.postprocess import decode +from mmocr.utils import check_argument + + +@HEADS.register_module() +class HeadMixin: + """The head minxin for dbnet and pannet heads.""" + + def resize_boundary(self, boundaries, scale_factor): + """Rescale boundaries via scale_factor. + + Args: + boundaries (list[list[float]]): The boundary list. Each boundary + with size 2k+1 with k>=4. + scale_factor(ndarray): The scale factor of size (4,). + + Returns: + boundaries (list[list[float]]): The scaled boundaries. + """ + assert check_argument.is_2dlist(boundaries) + assert isinstance(scale_factor, np.ndarray) + assert scale_factor.shape[0] == 4 + + for b in boundaries: + sz = len(b) + check_argument.valid_boundary(b, True) + b[:sz - + 1] = (np.array(b[:sz - 1]) * + (np.tile(scale_factor[:2], int( + (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist() + return boundaries + + def get_boundary(self, score_maps, img_metas, rescale): + """Compute text boundaries via post processing. + + Args: + score_maps (Tensor): The text score map. + img_metas (dict): The image meta info. + rescale (bool): Rescale boundaries to the original image resolution + if true, and keep the score_maps resolution if false. + + Returns: + results (dict): The result dict. 
+ """ + + assert check_argument.is_type_list(img_metas, dict) + assert isinstance(rescale, bool) + + score_maps = score_maps.squeeze() + boundaries = decode( + decoding_type=self.decoding_type, + preds=score_maps, + text_repr_type=self.text_repr_type) + if rescale: + boundaries = self.resize_boundary( + boundaries, + 1.0 / self.downsample_ratio / img_metas[0]['scale_factor']) + results = dict(boundary_result=boundaries) + return results + + def loss(self, pred_maps, **kwargs): + """Compute the loss for text detection. + + Args: + pred_maps (tensor): The input score maps of NxCxHxW. + + Returns: + losses (dict): The dict for losses. + """ + losses = self.loss_module(pred_maps, self.downsample_ratio, **kwargs) + return losses diff --git a/mmocr/models/textdet/dense_heads/pan_head.py b/mmocr/models/textdet/dense_heads/pan_head.py new file mode 100644 index 00000000..08cc7cff --- /dev/null +++ b/mmocr/models/textdet/dense_heads/pan_head.py @@ -0,0 +1,61 @@ +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import normal_init + +from mmdet.models.builder import HEADS, build_loss +from mmocr.utils import check_argument +from . import HeadMixin + + +@HEADS.register_module() +class PANHead(HeadMixin, nn.Module): + """The class for PANet head.""" + + def __init__( + self, + in_channels, + out_channels, + text_repr_type='poly', # 'poly' or 'quad' + downsample_ratio=0.25, + loss=dict(type='PANLoss'), + train_cfg=None, + test_cfg=None): + super().__init__() + + assert check_argument.is_type_list(in_channels, int) + assert isinstance(out_channels, int) + assert text_repr_type in ['poly', 'quad'] + assert 0 <= downsample_ratio <= 1 + + self.loss_module = build_loss(loss) + self.in_channels = in_channels + self.out_channels = out_channels + self.text_repr_type = text_repr_type + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.downsample_ratio = downsample_ratio + if loss['type'] == 'PANLoss': + self.decoding_type = 'pan' + elif loss['type'] == 'PSELoss': + self.decoding_type = 'pse' + else: + type = loss['type'] + raise NotImplementedError(f'unsupported loss type {type}.') + + self.out_conv = nn.Conv2d( + in_channels=np.sum(np.array(in_channels)), + out_channels=out_channels, + kernel_size=1) + self.init_weights() + + def init_weights(self): + normal_init(self.out_conv, mean=0, std=0.01) + + def forward(self, inputs): + if isinstance(inputs, tuple): + outputs = torch.cat(inputs, dim=1) + else: + outputs = inputs + outputs = self.out_conv(outputs) + return outputs diff --git a/mmocr/models/textdet/dense_heads/pse_head.py b/mmocr/models/textdet/dense_heads/pse_head.py new file mode 100644 index 00000000..db2a9925 --- /dev/null +++ b/mmocr/models/textdet/dense_heads/pse_head.py @@ -0,0 +1,25 @@ +from mmdet.models.builder import HEADS +from . 
diff --git a/mmocr/models/textdet/dense_heads/pse_head.py b/mmocr/models/textdet/dense_heads/pse_head.py
new file mode 100644
index 00000000..db2a9925
--- /dev/null
+++ b/mmocr/models/textdet/dense_heads/pse_head.py
@@ -0,0 +1,25 @@
+from mmdet.models.builder import HEADS
+from . import PANHead
+
+
+@HEADS.register_module()
+class PSEHead(PANHead):
+    """The class for PSENet head."""
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            text_repr_type='poly',  # 'poly' or 'quad'
+            downsample_ratio=0.25,
+            loss=dict(type='PSELoss'),
+            train_cfg=None,
+            test_cfg=None):
+        super().__init__(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            text_repr_type=text_repr_type,
+            downsample_ratio=downsample_ratio,
+            loss=loss,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg)
diff --git a/mmocr/models/textdet/dense_heads/textsnake_head.py b/mmocr/models/textdet/dense_heads/textsnake_head.py
new file mode 100644
index 00000000..1645bba7
--- /dev/null
+++ b/mmocr/models/textdet/dense_heads/textsnake_head.py
@@ -0,0 +1,48 @@
+import torch.nn as nn
+from mmcv.cnn import normal_init
+
+from mmdet.models.builder import HEADS, build_loss
+from . import HeadMixin
+
+
+@HEADS.register_module()
+class TextSnakeHead(HeadMixin, nn.Module):
+    """The class for TextSnake head: TextSnake: A Flexible Representation
+    for Detecting Text of Arbitrary Shapes.
+
+    [https://arxiv.org/abs/1807.01544]
+    """
+
+    def __init__(self,
+                 in_channels,
+                 decoding_type='textsnake',
+                 text_repr_type='poly',
+                 loss=dict(type='TextSnakeLoss'),
+                 train_cfg=None,
+                 test_cfg=None):
+        super().__init__()
+
+        assert isinstance(in_channels, int)
+        self.in_channels = in_channels
+        self.out_channels = 5
+        self.downsample_ratio = 1.0
+        self.decoding_type = decoding_type
+        self.text_repr_type = text_repr_type
+        self.loss_module = build_loss(loss)
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+        self.out_conv = nn.Conv2d(
+            in_channels=self.in_channels,
+            out_channels=self.out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0)
+        self.init_weights()
+
+    def init_weights(self):
+        normal_init(self.out_conv, mean=0, std=0.01)
+
+    def forward(self, inputs):
+        outputs = self.out_conv(inputs)
+        return outputs
diff --git a/mmocr/models/textdet/detectors/__init__.py b/mmocr/models/textdet/detectors/__init__.py
new file mode 100644
index 00000000..6aab9c73
--- /dev/null
+++ b/mmocr/models/textdet/detectors/__init__.py
@@ -0,0 +1,12 @@
+from .dbnet import DBNet
+from .ocr_mask_rcnn import OCRMaskRCNN
+from .panet import PANet
+from .psenet import PSENet
+from .single_stage_text_detector import SingleStageTextDetector
+from .text_detector_mixin import TextDetectorMixin
+from .textsnake import TextSnake
+
+__all__ = [
+    'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',
+    'PANet', 'PSENet', 'TextSnake'
+]
diff --git a/mmocr/models/textdet/detectors/dbnet.py b/mmocr/models/textdet/detectors/dbnet.py
new file mode 100644
index 00000000..4b3c33d4
--- /dev/null
+++ b/mmocr/models/textdet/detectors/dbnet.py
@@ -0,0 +1,26 @@
+from mmdet.models.builder import DETECTORS
+from mmocr.models.textdet.detectors.single_stage_text_detector import \
+    SingleStageTextDetector
+from mmocr.models.textdet.detectors.text_detector_mixin import \
+    TextDetectorMixin
+
+
+@DETECTORS.register_module()
+class DBNet(TextDetectorMixin, SingleStageTextDetector):
+    """The class for implementing DBNet text detector: Real-time Scene Text
+    Detection with Differentiable Binarization.
+
+    [https://arxiv.org/abs/1911.08947].
+ """ + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None, + show_score=False): + SingleStageTextDetector.__init__(self, backbone, neck, bbox_head, + train_cfg, test_cfg, pretrained) + TextDetectorMixin.__init__(self, show_score) diff --git a/mmocr/models/textdet/detectors/ocr_mask_rcnn.py b/mmocr/models/textdet/detectors/ocr_mask_rcnn.py new file mode 100644 index 00000000..a9ce9cc8 --- /dev/null +++ b/mmocr/models/textdet/detectors/ocr_mask_rcnn.py @@ -0,0 +1,41 @@ +from mmdet.models.builder import DETECTORS +from mmdet.models.detectors import MaskRCNN +from mmocr.models.textdet.detectors.text_detector_mixin import \ + TextDetectorMixin + + +@DETECTORS.register_module() +class OCRMaskRCNN(TextDetectorMixin, MaskRCNN): + """Mask RCNN tailored for OCR.""" + + def __init__(self, + backbone, + rpn_head, + roi_head, + train_cfg, + test_cfg, + neck=None, + pretrained=None, + text_repr_type='quad', + show_score=False): + TextDetectorMixin.__init__(self, show_score) + MaskRCNN.__init__( + self, + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + pretrained=pretrained) + assert text_repr_type in ['quad', 'poly'] + self.text_repr_type = text_repr_type + + def simple_test(self, img, img_metas, proposals=None, rescale=False): + + results = super().simple_test(img, img_metas, proposals, rescale) + + boundaries = self.get_boundary(results[0]) + boundaries = boundaries if isinstance(boundaries, + list) else [boundaries] + return boundaries diff --git a/mmocr/models/textdet/detectors/panet.py b/mmocr/models/textdet/detectors/panet.py new file mode 100644 index 00000000..befa52c8 --- /dev/null +++ b/mmocr/models/textdet/detectors/panet.py @@ -0,0 +1,26 @@ +from mmdet.models.builder import DETECTORS +from mmocr.models.textdet.detectors.single_stage_text_detector import \ + SingleStageTextDetector +from mmocr.models.textdet.detectors.text_detector_mixin import \ + TextDetectorMixin + + +@DETECTORS.register_module() +class PANet(TextDetectorMixin, SingleStageTextDetector): + """The class for implementing PANet text detector: + + Efficient and Accurate Arbitrary-Shaped Text Detection with Pixel + Aggregation Network [https://arxiv.org/abs/1908.05900]. + """ + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None, + show_score=False): + SingleStageTextDetector.__init__(self, backbone, neck, bbox_head, + train_cfg, test_cfg, pretrained) + TextDetectorMixin.__init__(self, show_score) diff --git a/mmocr/models/textdet/detectors/psenet.py b/mmocr/models/textdet/detectors/psenet.py new file mode 100644 index 00000000..7dccad4e --- /dev/null +++ b/mmocr/models/textdet/detectors/psenet.py @@ -0,0 +1,26 @@ +from mmdet.models.builder import DETECTORS +from mmocr.models.textdet.detectors.single_stage_text_detector import \ + SingleStageTextDetector +from mmocr.models.textdet.detectors.text_detector_mixin import \ + TextDetectorMixin + + +@DETECTORS.register_module() +class PSENet(TextDetectorMixin, SingleStageTextDetector): + """The class for implementing PSENet text detector: Shape Robust Text + Detection with Progressive Scale Expansion Network. + + [https://arxiv.org/abs/1806.02559]. 
+ """ + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None, + show_score=False): + SingleStageTextDetector.__init__(self, backbone, neck, bbox_head, + train_cfg, test_cfg, pretrained) + TextDetectorMixin.__init__(self, show_score) diff --git a/mmocr/models/textdet/detectors/single_stage_text_detector.py b/mmocr/models/textdet/detectors/single_stage_text_detector.py new file mode 100644 index 00000000..3456479c --- /dev/null +++ b/mmocr/models/textdet/detectors/single_stage_text_detector.py @@ -0,0 +1,45 @@ +from mmdet.models.builder import DETECTORS +from mmdet.models.detectors import SingleStageDetector + + +@DETECTORS.register_module() +class SingleStageTextDetector(SingleStageDetector): + """The class for implementing single stage text detector. + + It is the parent class of PANet, PSENet, and DBNet. + """ + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None): + SingleStageDetector.__init__(self, backbone, neck, bbox_head, + train_cfg, test_cfg, pretrained) + + def forward_train(self, img, img_metas, **kwargs): + """ + Args: + img (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + img_metas (list[dict]): A list of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys, see + :class:`mmdet.datasets.pipelines.Collect`. + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + x = self.extract_feat(img) + preds = self.bbox_head(x) + losses = self.bbox_head.loss(preds, **kwargs) + return losses + + def simple_test(self, img, img_metas, rescale=False): + x = self.extract_feat(img) + outs = self.bbox_head(x) + boundaries = self.bbox_head.get_boundary(*outs, img_metas, rescale) + + return [boundaries] diff --git a/mmocr/models/textdet/detectors/text_detector_mixin.py b/mmocr/models/textdet/detectors/text_detector_mixin.py new file mode 100644 index 00000000..30ba6560 --- /dev/null +++ b/mmocr/models/textdet/detectors/text_detector_mixin.py @@ -0,0 +1,101 @@ +import warnings + +import mmcv + +from mmocr.core import imshow_pred_boundary, seg2boundary + + +class TextDetectorMixin: + """The class for implementing text detector auxiliary methods.""" + + def __init__(self, show_score): + self.show_score = show_score + + def get_boundary(self, results): + """Convert segmentation into text boundaries. + + Args: + results (tuple): The result tuple. The first element is + segmentation while the second is its scores. + + Returns: + results (dict): A result dict containing 'boundary_result'. + """ + + assert isinstance(results, tuple) + + instance_num = len(results[1][0]) + boundaries = [] + for i in range(instance_num): + seg = results[1][0][i] + score = results[0][0][i][-1] + boundary = seg2boundary(seg, self.text_repr_type, score) + if boundary is not None: + boundaries.append(boundary) + + results = dict(boundary_result=boundaries) + return results + + def show_result(self, + img, + result, + score_thr=0.5, + bbox_color='green', + text_color='green', + thickness=1, + font_scale=0.5, + win_name='', + show=False, + wait_time=0, + out_file=None): + """Draw `result` over `img`. + + Args: + img (str or Tensor): The image to be displayed. + result (dict): The results to draw over `img`. + score_thr (float, optional): Minimum score of bboxes to be shown. + Default: 0.3. 
diff --git a/mmocr/models/textdet/detectors/text_detector_mixin.py b/mmocr/models/textdet/detectors/text_detector_mixin.py
new file mode 100644
index 00000000..30ba6560
--- /dev/null
+++ b/mmocr/models/textdet/detectors/text_detector_mixin.py
@@ -0,0 +1,101 @@
+import warnings
+
+import mmcv
+
+from mmocr.core import imshow_pred_boundary, seg2boundary
+
+
+class TextDetectorMixin:
+    """The class for implementing text detector auxiliary methods."""
+
+    def __init__(self, show_score):
+        self.show_score = show_score
+
+    def get_boundary(self, results):
+        """Convert segmentation into text boundaries.
+
+        Args:
+            results (tuple): The result tuple. The first element is
+                segmentation while the second is its scores.
+
+        Returns:
+            results (dict): A result dict containing 'boundary_result'.
+        """
+
+        assert isinstance(results, tuple)
+
+        instance_num = len(results[1][0])
+        boundaries = []
+        for i in range(instance_num):
+            seg = results[1][0][i]
+            score = results[0][0][i][-1]
+            boundary = seg2boundary(seg, self.text_repr_type, score)
+            if boundary is not None:
+                boundaries.append(boundary)
+
+        results = dict(boundary_result=boundaries)
+        return results
+
+    def show_result(self,
+                    img,
+                    result,
+                    score_thr=0.5,
+                    bbox_color='green',
+                    text_color='green',
+                    thickness=1,
+                    font_scale=0.5,
+                    win_name='',
+                    show=False,
+                    wait_time=0,
+                    out_file=None):
+        """Draw `result` over `img`.
+
+        Args:
+            img (str or Tensor): The image to be displayed.
+            result (dict): The results to draw over `img`.
+            score_thr (float, optional): Minimum score of bboxes to be shown.
+                Default: 0.5.
+            bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
+            text_color (str or tuple or :obj:`Color`): Color of texts.
+            thickness (int): Thickness of lines.
+            font_scale (float): Font scales of texts.
+            win_name (str): The window name.
+            wait_time (int): Value of waitKey param.
+                Default: 0.
+            show (bool): Whether to show the image.
+                Default: False.
+            out_file (str or None): The filename to write the image.
+                Default: None.
+        """
+        img = mmcv.imread(img)
+        img = img.copy()
+        boundaries = None
+        labels = None
+        if 'boundary_result' in result.keys():
+            boundaries = result['boundary_result']
+            labels = [0] * len(boundaries)
+
+        # if out_file specified, do not show image in window
+        if out_file is not None:
+            show = False
+        # draw bounding boxes
+        if boundaries is not None:
+            imshow_pred_boundary(
+                img,
+                boundaries,
+                labels,
+                score_thr=score_thr,
+                boundary_color=bbox_color,
+                text_color=text_color,
+                thickness=thickness,
+                font_scale=font_scale,
+                win_name=win_name,
+                show=show,
+                wait_time=wait_time,
+                out_file=out_file,
+                show_score=self.show_score)
+
+        if not (show or out_file):
+            warnings.warn('show==False and out_file is not specified, '
+                          'result image will be returned')
+        return img
diff --git a/mmocr/models/textdet/detectors/textsnake.py b/mmocr/models/textdet/detectors/textsnake.py
new file mode 100644
index 00000000..25f65abf
--- /dev/null
+++ b/mmocr/models/textdet/detectors/textsnake.py
@@ -0,0 +1,23 @@
+from mmdet.models.builder import DETECTORS
+from . import SingleStageTextDetector, TextDetectorMixin
+
+
+@DETECTORS.register_module()
+class TextSnake(TextDetectorMixin, SingleStageTextDetector):
+    """The class for implementing TextSnake text detector: TextSnake: A
+    Flexible Representation for Detecting Text of Arbitrary Shapes.
+
+    [https://arxiv.org/abs/1807.01544]
+    """
+
+    def __init__(self,
+                 backbone,
+                 neck,
+                 bbox_head,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None,
+                 show_score=False):
+        SingleStageTextDetector.__init__(self, backbone, neck, bbox_head,
+                                         train_cfg, test_cfg, pretrained)
+        TextDetectorMixin.__init__(self, show_score)
diff --git a/mmocr/models/textdet/losses/__init__.py b/mmocr/models/textdet/losses/__init__.py
new file mode 100644
index 00000000..eaa4d9cd
--- /dev/null
+++ b/mmocr/models/textdet/losses/__init__.py
@@ -0,0 +1,6 @@
+from .db_loss import DBLoss
+from .pan_loss import PANLoss
+from .pse_loss import PSELoss
+from .textsnake_loss import TextSnakeLoss
+
+__all__ = ['PANLoss', 'PSELoss', 'DBLoss', 'TextSnakeLoss']
diff --git a/mmocr/models/textdet/losses/db_loss.py b/mmocr/models/textdet/losses/db_loss.py
new file mode 100644
index 00000000..0081ecda
--- /dev/null
+++ b/mmocr/models/textdet/losses/db_loss.py
@@ -0,0 +1,169 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from mmdet.models.builder import LOSSES
+from mmocr.core.visualize import show_feature  # noqa F401
+from mmocr.models.common.losses.dice_loss import DiceLoss
+
+
+@LOSSES.register_module()
+class DBLoss(nn.Module):
+    """The class for implementing DBNet loss.
+
+    This is partially adapted from https://github.com/MhLiao/DB.
+    """
+
+    def __init__(self,
+                 alpha=1,
+                 beta=1,
+                 reduction='mean',
+                 negative_ratio=3.0,
+                 eps=1e-6,
+                 bbce_loss=False):
+        """Initialization.
+
+        Args:
+            alpha (float): The binary loss coef.
+            beta (float): The threshold loss coef.
+            reduction (str): The way to reduce the loss.
+            negative_ratio (float): The maximum ratio of negative to positive
+                samples kept in the balanced bce loss.
+            eps (float): Epsilon in the threshold loss function.
+            bbce_loss (bool): Whether to use balanced bce for probability
+                loss. If False, dice loss will be used instead.
+        """
+        super().__init__()
+        assert reduction in ['mean',
+                             'sum'], "reduction must be in ['mean', 'sum']"
+        self.alpha = alpha
+        self.beta = beta
+        self.reduction = reduction
+        self.negative_ratio = negative_ratio
+        self.eps = eps
+        self.bbce_loss = bbce_loss
+        self.dice_loss = DiceLoss(eps=eps)
+
+    def bitmasks2tensor(self, bitmasks, target_sz):
+        """Convert BitMasks to tensors.
+
+        Args:
+            bitmasks (list[BitMasks]): The BitMasks list. Each item is for
+                one img.
+            target_sz (tuple(int, int)): The target tensor size of KxHxW
+                with K being the number of kernels.
+
+        Returns:
+            result_tensors (list[tensor]): The list of kernel tensors. Each
+                element is for one kernel level.
+        """
+        assert isinstance(bitmasks, list)
+        assert isinstance(target_sz, tuple)
+
+        batch_size = len(bitmasks)
+        num_levels = len(bitmasks[0])
+
+        result_tensors = []
+
+        for level_inx in range(num_levels):
+            kernel = []
+            for batch_inx in range(batch_size):
+                mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
+                mask_sz = mask.shape
+                pad = [
+                    0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
+                ]
+                mask = F.pad(mask, pad, mode='constant', value=0)
+                kernel.append(mask)
+            kernel = torch.stack(kernel)
+            result_tensors.append(kernel)
+
+        return result_tensors
+
+    def balance_bce_loss(self, pred, gt, mask):
+
+        positive = (gt * mask)
+        negative = ((1 - gt) * mask)
+        positive_count = int(positive.float().sum())
+        negative_count = min(
+            int(negative.float().sum()),
+            int(positive_count * self.negative_ratio))
+
+        assert gt.max() <= 1 and gt.min() >= 0
+        assert pred.max() <= 1 and pred.min() >= 0
+        loss = F.binary_cross_entropy(pred, gt, reduction='none')
+        positive_loss = loss * positive.float()
+        negative_loss = loss * negative.float()
+
+        negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
+
+        balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
+            positive_count + negative_count + self.eps)
+
+        return balance_loss
+
+    def l1_thr_loss(self, pred, gt, mask):
+        thr_loss = torch.abs((pred - gt) * mask).sum() / (
+            mask.sum() + self.eps)
+        return thr_loss
+
+    def forward(self, preds, downsample_ratio, gt_shrink, gt_shrink_mask,
+                gt_thr, gt_thr_mask):
+        """Compute DBNet loss.
+
+        Args:
+            preds (tensor): The output tensor with size of Nx3xHxW.
+            downsample_ratio (float): The downsample ratio for the
+                ground truths.
+            gt_shrink (list[BitmapMasks]): The mask list with each element
+                being the shrunk text mask for one img.
+            gt_shrink_mask (list[BitmapMasks]): The effective mask list with
+                each element being the shrunk effective mask for one img.
+            gt_thr (list[BitmapMasks]): The mask list with each element
+                being the threshold text mask for one img.
+            gt_thr_mask (list[BitmapMasks]): The effective mask list with
+                each element being the threshold effective mask for one img.
+
+        Returns:
+            results (dict): The dict for dbnet losses with loss_prob,
+                loss_db and loss_thr.
+ """ + assert isinstance(downsample_ratio, float) + + assert isinstance(gt_shrink, list) + assert isinstance(gt_shrink_mask, list) + assert isinstance(gt_thr, list) + assert isinstance(gt_thr_mask, list) + + preds = preds[0] + + pred_prob = preds[:, 0, :, :] + pred_thr = preds[:, 1, :, :] + pred_db = preds[:, 2, :, :] + feature_sz = preds.size() + + keys = ['gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'] + gt = {} + for k in keys: + gt[k] = eval(k) + gt[k] = [item.rescale(downsample_ratio) for item in gt[k]] + gt[k] = self.bitmasks2tensor(gt[k], feature_sz[2:]) + gt[k] = [item.to(preds.device) for item in gt[k]] + gt['gt_shrink'][0] = (gt['gt_shrink'][0] > 0).float() + if self.bbce_loss: + loss_prob = self.balance_bce_loss(pred_prob, gt['gt_shrink'][0], + gt['gt_shrink_mask'][0]) + else: + loss_prob = self.dice_loss(pred_prob, gt['gt_shrink'][0], + gt['gt_shrink_mask'][0]) + + loss_db = self.dice_loss(pred_db, gt['gt_shrink'][0], + gt['gt_shrink_mask'][0]) + loss_thr = self.l1_thr_loss(pred_thr, gt['gt_thr'][0], + gt['gt_thr_mask'][0]) + + results = dict( + loss_prob=self.alpha * loss_prob, + loss_db=loss_db, + loss_thr=self.beta * loss_thr) + + return results diff --git a/mmocr/models/textdet/losses/pan_loss.py b/mmocr/models/textdet/losses/pan_loss.py new file mode 100644 index 00000000..c5ccd464 --- /dev/null +++ b/mmocr/models/textdet/losses/pan_loss.py @@ -0,0 +1,329 @@ +import itertools +import warnings + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from mmdet.core import BitmapMasks +from mmdet.models.builder import LOSSES +from mmocr.utils import check_argument + + +@LOSSES.register_module() +class PANLoss(nn.Module): + """The class for implementing PANet loss: Efficient and Accurate Arbitrary- + Shaped Text Detection with Pixel Aggregation Network. + + [https://arxiv.org/abs/1908.05900]. This was partially adapted from + https://github.com/WenmuZhou/PAN.pytorch + """ + + def __init__(self, + alpha=0.5, + beta=0.25, + delta_aggregation=0.5, + delta_discrimination=3, + ohem_ratio=3, + reduction='mean', + speedup_bbox_thr=-1): + """Initialization. + + Args: + alpha (float): The kernel loss coef. + beta (float): The aggregation and discriminative loss coef. + delta_aggregation (float): The constant for aggregation loss. + delta_discrimination (float): The constant for discriminative loss. + ohem_ratio (float): The negative/positive ratio in ohem. + reduction (str): The way to reduce the loss. + speedup_bbox_thr (int): Speed up if speedup_bbox_thr >0 + and 0.5).float() * ( + gt['gt_mask'][0].float()) + loss_kernels = self.dice_loss_with_logits(pred_kernels, + gt['gt_kernels'][1], + sampled_masks_kernel) + losses = [loss_texts, loss_kernels, loss_aggrs, loss_discrs] + if self.reduction == 'mean': + losses = [item.mean() for item in losses] + elif self.reduction == 'sum': + losses = [item.sum() for item in losses] + else: + raise NotImplementedError + + coefs = [1, self.alpha, self.beta, self.beta] + losses = [item * scale for item, scale in zip(losses, coefs)] + + results = dict() + results.update( + loss_text=losses[0], + loss_kernel=losses[1], + loss_aggregation=losses[2], + loss_discrimination=losses[3]) + return results + + def aggregation_discrimination_loss(self, gt_texts, gt_kernels, + inst_embeds): + """Compute the aggregation and discrimnative losses. + + Args: + gt_texts (tensor): The ground truth text mask of size Nx1xHxW. + gt_kernels (tensor): The ground truth text kernel mask of + size Nx1xHxW. 
diff --git a/mmocr/models/textdet/losses/pan_loss.py b/mmocr/models/textdet/losses/pan_loss.py
new file mode 100644
index 00000000..c5ccd464
--- /dev/null
+++ b/mmocr/models/textdet/losses/pan_loss.py
@@ -0,0 +1,329 @@
+import itertools
+import warnings
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from mmdet.core import BitmapMasks
+from mmdet.models.builder import LOSSES
+from mmocr.utils import check_argument
+
+
+@LOSSES.register_module()
+class PANLoss(nn.Module):
+    """The class for implementing PANet loss: Efficient and Accurate
+    Arbitrary-Shaped Text Detection with Pixel Aggregation Network.
+
+    [https://arxiv.org/abs/1908.05900]. This was partially adapted from
+    https://github.com/WenmuZhou/PAN.pytorch
+    """
+
+    def __init__(self,
+                 alpha=0.5,
+                 beta=0.25,
+                 delta_aggregation=0.5,
+                 delta_discrimination=3,
+                 ohem_ratio=3,
+                 reduction='mean',
+                 speedup_bbox_thr=-1):
+        """Initialization.
+
+        Args:
+            alpha (float): The kernel loss coef.
+            beta (float): The aggregation and discriminative loss coef.
+            delta_aggregation (float): The constant for aggregation loss.
+            delta_discrimination (float): The constant for discriminative
+                loss.
+            ohem_ratio (float): The negative/positive ratio in ohem.
+            reduction (str): The way to reduce the loss.
+            speedup_bbox_thr (int): Speed up if speedup_bbox_thr > 0
+                and < bbox num.
+        sampled_masks_kernel = (gt['gt_kernels'][0] > 0.5).float() * (
+            gt['gt_mask'][0].float())
+        loss_kernels = self.dice_loss_with_logits(pred_kernels,
+                                                  gt['gt_kernels'][1],
+                                                  sampled_masks_kernel)
+        losses = [loss_texts, loss_kernels, loss_aggrs, loss_discrs]
+        if self.reduction == 'mean':
+            losses = [item.mean() for item in losses]
+        elif self.reduction == 'sum':
+            losses = [item.sum() for item in losses]
+        else:
+            raise NotImplementedError
+
+        coefs = [1, self.alpha, self.beta, self.beta]
+        losses = [item * scale for item, scale in zip(losses, coefs)]
+
+        results = dict()
+        results.update(
+            loss_text=losses[0],
+            loss_kernel=losses[1],
+            loss_aggregation=losses[2],
+            loss_discrimination=losses[3])
+        return results
+
+    def aggregation_discrimination_loss(self, gt_texts, gt_kernels,
+                                        inst_embeds):
+        """Compute the aggregation and discriminative losses.
+
+        Args:
+            gt_texts (tensor): The ground truth text mask of size Nx1xHxW.
+            gt_kernels (tensor): The ground truth text kernel mask of
+                size Nx1xHxW.
+            inst_embeds (tensor): The text instance embedding tensor
+                of size Nx4xHxW.
+
+        Returns:
+            loss_aggrs (tensor): The aggregation loss before reduction.
+            loss_discrs (tensor): The discriminative loss before reduction.
+        """
+
+        batch_size = gt_texts.size()[0]
+        gt_texts = gt_texts.contiguous().reshape(batch_size, -1)
+        gt_kernels = gt_kernels.contiguous().reshape(batch_size, -1)
+
+        assert inst_embeds.shape[1] == 4
+        inst_embeds = inst_embeds.contiguous().reshape(batch_size, 4, -1)
+
+        loss_aggrs = []
+        loss_discrs = []
+
+        for text, kernel, embed in zip(gt_texts, gt_kernels, inst_embeds):
+
+            # for each image
+            text_num = int(text.max().item())
+            loss_aggr_img = []
+            kernel_avgs = []
+            select_num = self.speedup_bbox_thr
+            if 0 < select_num < text_num:
+                inds = np.random.choice(
+                    text_num, select_num, replace=False) + 1
+            else:
+                inds = range(1, text_num + 1)
+
+            for i in inds:
+                # for each text instance
+                kernel_i = (kernel == i)  # 0.2ms
+                if kernel_i.sum() == 0 or (text == i).sum() == 0:  # 0.2ms
+                    continue
+
+                # compute G_Ki in Eq (2)
+                avg = embed[:, kernel_i].mean(1)  # 0.5ms
+                kernel_avgs.append(avg)
+
+                embed_i = embed[:, text == i]  # 0.6ms
+                # ||F(p) - G(K_i)|| - delta_aggregation, shape: nums
+                distance = (embed_i - avg.reshape(4, 1)).norm(  # 0.5ms
+                    2, dim=0) - self.delta_aggregation
+                # compute D(p,K_i) in Eq (2)
+                hinge = torch.max(
+                    distance,
+                    torch.tensor(0, device=distance.device,
+                                 dtype=torch.float)).pow(2)
+
+                aggr = torch.log(hinge + 1).mean()
+                loss_aggr_img.append(aggr)
+
+            num_inst = len(loss_aggr_img)
+            if num_inst > 0:
+                loss_aggr_img = torch.stack(loss_aggr_img).mean()
+            else:
+                loss_aggr_img = torch.tensor(
+                    0, device=gt_texts.device, dtype=torch.float)
+            loss_aggrs.append(loss_aggr_img)
+
+            loss_discr_img = 0
+            for avg_i, avg_j in itertools.combinations(kernel_avgs, 2):
+                # delta_discrimination - ||G(K_i) - G(K_j)||
+                distance_ij = self.delta_discrimination - (avg_i -
+                                                           avg_j).norm(2)
+                # D(K_i,K_j)
+                D_ij = torch.max(
+                    distance_ij,
+                    torch.tensor(
+                        0, device=distance_ij.device,
+                        dtype=torch.float)).pow(2)
+                loss_discr_img += torch.log(D_ij + 1)
+
+            if num_inst > 1:
+                loss_discr_img /= (num_inst * (num_inst - 1))
+            else:
+                loss_discr_img = torch.tensor(
+                    0, device=gt_texts.device, dtype=torch.float)
+            if num_inst == 0:
+                warnings.warn('num of instance is 0')
+            loss_discrs.append(loss_discr_img)
+        return torch.stack(loss_aggrs), torch.stack(loss_discrs)
+
+    def dice_loss_with_logits(self, pred, target, mask):
+
+        smooth = 0.001
+
+        pred = torch.sigmoid(pred)
+        # binarize the target without mutating the caller's tensor
+        target = (target > 0.5).float()
+        pred = pred.contiguous().view(pred.size()[0], -1)
+        target = target.contiguous().view(target.size()[0], -1)
+        mask = mask.contiguous().view(mask.size()[0], -1)
+
+        pred = pred * mask
+        target = target * mask
+
+        a = torch.sum(pred * target, 1)
+        b = torch.sum(pred * pred, 1) + smooth
+        c = torch.sum(target * target, 1) + smooth
+        d = (2 * a) / (b + c)
+        return 1 - d
+
+    def ohem_img(self, text_score, gt_text, gt_mask):
+        """Sample the top-k maximal negative samples and all positive samples.
+
+        Args:
+            text_score (Tensor): The text score with size of HxW.
+            gt_text (Tensor): The ground truth text mask of HxW.
+            gt_mask (Tensor): The effective region mask of HxW.
+
+        Returns:
+            sampled_mask (Tensor): The sampled pixel mask of size HxW.
+ """ + assert isinstance(text_score, torch.Tensor) + assert isinstance(gt_text, torch.Tensor) + assert isinstance(gt_mask, torch.Tensor) + assert len(text_score.shape) == 2 + assert text_score.shape == gt_text.shape + assert gt_text.shape == gt_mask.shape + + pos_num = (int)(torch.sum(gt_text > 0.5).item()) - (int)( + torch.sum((gt_text > 0.5) * (gt_mask <= 0.5)).item()) + neg_num = (int)(torch.sum(gt_text <= 0.5).item()) + neg_num = (int)(min(pos_num * self.ohem_ratio, neg_num)) + + if pos_num == 0 or neg_num == 0: + warnings.warn('pos_num = 0 or neg_num = 0') + return gt_mask.bool() + + neg_score = text_score[gt_text <= 0.5] + neg_score_sorted, _ = torch.sort(neg_score, descending=True) + threshold = neg_score_sorted[neg_num - 1] + sampled_mask = (((text_score >= threshold) + (gt_text > 0.5)) > 0) * ( + gt_mask > 0.5) + return sampled_mask + + def ohem_batch(self, text_scores, gt_texts, gt_mask): + """OHEM sampling for a batch of imgs. + + Args: + text_scores (Tensor): The text scores of size NxHxW. + gt_texts (Tensor): The gt text masks of size NxHxW. + gt_mask (Tensor): The gt effective mask of size NxHxW. + + Returns: + sampled_masks (Tensor): The sampled mask of size NxHxW. + """ + assert isinstance(text_scores, torch.Tensor) + assert isinstance(gt_texts, torch.Tensor) + assert isinstance(gt_mask, torch.Tensor) + assert len(text_scores.shape) == 3 + assert text_scores.shape == gt_texts.shape + assert gt_texts.shape == gt_mask.shape + + sampled_masks = [] + for i in range(text_scores.shape[0]): + sampled_masks.append( + self.ohem_img(text_scores[i], gt_texts[i], gt_mask[i])) + + sampled_masks = torch.stack(sampled_masks) + + return sampled_masks diff --git a/mmocr/models/textdet/losses/pse_loss.py b/mmocr/models/textdet/losses/pse_loss.py new file mode 100644 index 00000000..4580bfc4 --- /dev/null +++ b/mmocr/models/textdet/losses/pse_loss.py @@ -0,0 +1,104 @@ +from mmdet.core import BitmapMasks +from mmdet.models.builder import LOSSES +from mmocr.utils import check_argument +from . import PANLoss + + +@LOSSES.register_module() +class PSELoss(PANLoss): + """The class for implementing PSENet loss: Shape Robust Text Detection with + Progressive Scale Expansion Network [https://arxiv.org/abs/1806.02559]. + + This is partially adapted from https://github.com/whai362/PSENet. + """ + + def __init__(self, + alpha=0.7, + ohem_ratio=3, + reduction='mean', + kernel_sample_type='adaptive'): + """Initialization. + + Args: + alpha (float): alpha: The text loss coef; + (1-alpha): the kernel loss coef. + ohem_ratio (float): The negative/positive ratio in ohem. + reduction (str): The way to reduce the loss. + """ + super().__init__() + assert reduction in ['mean', + 'sum'], " reduction must in ['mean','sum']" + self.alpha = alpha + self.ohem_ratio = ohem_ratio + self.reduction = reduction + self.kernel_sample_type = kernel_sample_type + + def forward(self, score_maps, downsample_ratio, gt_kernels, gt_mask): + """Compute PSENet loss. + + Args: + score_maps (tensor): The output tensor with size of Nx6xHxW. + gt_kernels (list[BitmapMasks]): The kernel list with each element + being the text kernel mask for one img. + gt_mask (list[BitmapMasks]): The effective mask list + with each element being the effective mask fo one img. + downsample_ratio (float): The downsample ratio between score_maps + and the input img. + + Returns: + results (dict): The loss. 
+ """ + + assert check_argument.is_type_list(gt_kernels, BitmapMasks) + assert check_argument.is_type_list(gt_mask, BitmapMasks) + assert isinstance(downsample_ratio, float) + losses = [] + + pred_texts = score_maps[:, 0, :, :] + pred_kernels = score_maps[:, 1:, :, :] + feature_sz = score_maps.size() + + gt_kernels = [item.rescale(downsample_ratio) for item in gt_kernels] + gt_kernels = self.bitmasks2tensor(gt_kernels, feature_sz[2:]) + gt_kernels = [item.to(score_maps.device) for item in gt_kernels] + + gt_mask = [item.rescale(downsample_ratio) for item in gt_mask] + gt_mask = self.bitmasks2tensor(gt_mask, feature_sz[2:]) + gt_mask = [item.to(score_maps.device) for item in gt_mask] + + # compute text loss + sampled_masks_text = self.ohem_batch(pred_texts.detach(), + gt_kernels[0], gt_mask[0]) + loss_texts = self.dice_loss_with_logits(pred_texts, gt_kernels[0], + sampled_masks_text) + losses.append(self.alpha * loss_texts) + + # compute kernel loss + if self.kernel_sample_type == 'hard': + sampled_masks_kernel = (gt_kernels[0] > 0.5).float() * ( + gt_mask[0].float()) + elif self.kernel_sample_type == 'adaptive': + sampled_masks_kernel = (pred_texts > 0).float() * ( + gt_mask[0].float()) + else: + raise NotImplementedError + + num_kernel = pred_kernels.shape[1] + assert num_kernel == len(gt_kernels) - 1 + loss_list = [] + for inx in range(num_kernel): + loss_kernels = self.dice_loss_with_logits( + pred_kernels[:, inx, :, :], gt_kernels[1 + inx], + sampled_masks_kernel) + loss_list.append(loss_kernels) + + losses.append((1 - self.alpha) * sum(loss_list) / len(loss_list)) + + if self.reduction == 'mean': + losses = [item.mean() for item in losses] + elif self.reduction == 'sum': + losses = [item.sum() for item in losses] + else: + raise NotImplementedError + results = dict(loss_text=losses[0], loss_kernel=losses[1]) + return results diff --git a/mmocr/models/textdet/losses/textsnake_loss.py b/mmocr/models/textdet/losses/textsnake_loss.py new file mode 100644 index 00000000..b12de444 --- /dev/null +++ b/mmocr/models/textdet/losses/textsnake_loss.py @@ -0,0 +1,181 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from mmdet.core import BitmapMasks +from mmdet.models.builder import LOSSES +from mmocr.utils import check_argument + + +@LOSSES.register_module() +class TextSnakeLoss(nn.Module): + """The class for implementing TextSnake loss: + TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes + [https://arxiv.org/abs/1807.01544]. + This is partially adapted from + https://github.com/princewang1994/TextSnake.pytorch. + """ + + def __init__(self, ohem_ratio=3.0): + """Initialization. + + Args: + ohem_ratio (float): The negative/positive ratio in ohem. 
+ """ + super().__init__() + self.ohem_ratio = ohem_ratio + + def balanced_bce_loss(self, pred, gt, mask): + + assert pred.shape == gt.shape == mask.shape + positive = gt * mask + negative = (1 - gt) * mask + positive_count = int(positive.float().sum()) + gt = gt.float() + if positive_count > 0: + loss = F.binary_cross_entropy(pred, gt, reduction='none') + positive_loss = torch.sum(loss * positive.float()) + negative_loss = loss * negative.float() + negative_count = min( + int(negative.float().sum()), + int(positive_count * self.ohem_ratio)) + else: + positive_loss = torch.tensor(0.0, device=pred.device) + loss = F.binary_cross_entropy(pred, gt, reduction='none') + negative_loss = loss * negative.float() + negative_count = 100 + negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count) + + balance_loss = (positive_loss + torch.sum(negative_loss)) / ( + float(positive_count + negative_count) + 1e-5) + + return balance_loss + + def bitmasks2tensor(self, bitmasks, target_sz): + """Convert Bitmasks to tensor. + + Args: + bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is + for one img. + target_sz (tuple(int, int)): The target tensor size HxW. + + Returns + results (list[tensor]): The list of kernel tensors. Each + element is for one kernel level. + """ + assert check_argument.is_type_list(bitmasks, BitmapMasks) + assert isinstance(target_sz, tuple) + + batch_size = len(bitmasks) + num_masks = len(bitmasks[0]) + + results = [] + + for level_inx in range(num_masks): + kernel = [] + for batch_inx in range(batch_size): + mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx]) + # hxw + mask_sz = mask.shape + # left, right, top, bottom + pad = [ + 0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0] + ] + mask = F.pad(mask, pad, mode='constant', value=0) + kernel.append(mask) + kernel = torch.stack(kernel) + results.append(kernel) + + return results + + def forward(self, pred_maps, downsample_ratio, gt_text_mask, + gt_center_region_mask, gt_mask, gt_radius_map, gt_sin_map, + gt_cos_map): + + assert isinstance(downsample_ratio, float) + assert check_argument.is_type_list(gt_text_mask, BitmapMasks) + assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks) + assert check_argument.is_type_list(gt_mask, BitmapMasks) + assert check_argument.is_type_list(gt_radius_map, BitmapMasks) + assert check_argument.is_type_list(gt_sin_map, BitmapMasks) + assert check_argument.is_type_list(gt_cos_map, BitmapMasks) + + pred_text_region = pred_maps[:, 0, :, :] + pred_center_region = pred_maps[:, 1, :, :] + pred_sin_map = pred_maps[:, 2, :, :] + pred_cos_map = pred_maps[:, 3, :, :] + pred_radius_map = pred_maps[:, 4, :, :] + feature_sz = pred_maps.size() + device = pred_maps.device + + # bitmask 2 tensor + mapping = { + 'gt_text_mask': gt_text_mask, + 'gt_center_region_mask': gt_center_region_mask, + 'gt_mask': gt_mask, + 'gt_radius_map': gt_radius_map, + 'gt_sin_map': gt_sin_map, + 'gt_cos_map': gt_cos_map + } + gt = {} + for key, value in mapping.items(): + gt[key] = value + if abs(downsample_ratio - 1.0) < 1e-2: + gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:]) + else: + gt[key] = [item.rescale(downsample_ratio) for item in gt[key]] + gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:]) + if key == 'gt_radius_map': + gt[key] = [item * downsample_ratio for item in gt[key]] + gt[key] = [item.to(device) for item in gt[key]] + + scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8)) + pred_sin_map = pred_sin_map * scale + pred_cos_map = 
pred_cos_map * scale + + loss_text = self.balanced_bce_loss( + torch.sigmoid(pred_text_region), gt['gt_text_mask'][0], + gt['gt_mask'][0]) + + text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float() + loss_center_map = F.binary_cross_entropy( + torch.sigmoid(pred_center_region), + gt['gt_center_region_mask'][0].float(), + reduction='none') + if int(text_mask.sum()) > 0: + loss_center = torch.sum( + loss_center_map * text_mask) / torch.sum(text_mask) + else: + loss_center = torch.tensor(0.0, device=device) + + center_mask = (gt['gt_center_region_mask'][0] * + gt['gt_mask'][0]).float() + if int(center_mask.sum()) > 0: + map_sz = pred_radius_map.size() + ones = torch.ones(map_sz, dtype=torch.float, device=device) + loss_radius = torch.sum( + F.smooth_l1_loss( + pred_radius_map / (gt['gt_radius_map'][0] + 1e-2), + ones, + reduction='none') * center_mask) / torch.sum(center_mask) + loss_sin = torch.sum( + F.smooth_l1_loss( + pred_sin_map, gt['gt_sin_map'][0], reduction='none') * + center_mask) / torch.sum(center_mask) + loss_cos = torch.sum( + F.smooth_l1_loss( + pred_cos_map, gt['gt_cos_map'][0], reduction='none') * + center_mask) / torch.sum(center_mask) + else: + loss_radius = torch.tensor(0.0, device=device) + loss_sin = torch.tensor(0.0, device=device) + loss_cos = torch.tensor(0.0, device=device) + + results = dict( + loss_text=loss_text, + loss_center=loss_center, + loss_radius=loss_radius, + loss_sin=loss_sin, + loss_cos=loss_cos) + + return results diff --git a/mmocr/models/textdet/necks/__init__.py b/mmocr/models/textdet/necks/__init__.py new file mode 100644 index 00000000..101bbbdb --- /dev/null +++ b/mmocr/models/textdet/necks/__init__.py @@ -0,0 +1,6 @@ +from .fpem_ffm import FPEM_FFM +from .fpn_cat import FPNC +from .fpn_unet import FPN_UNET +from .fpnf import FPNF + +__all__ = ['FPEM_FFM', 'FPNF', 'FPNC', 'FPN_UNET'] diff --git a/mmocr/models/textdet/necks/fpem_ffm.py b/mmocr/models/textdet/necks/fpem_ffm.py new file mode 100644 index 00000000..3722649b --- /dev/null +++ b/mmocr/models/textdet/necks/fpem_ffm.py @@ -0,0 +1,143 @@ +import torch.nn.functional as F +from mmcv.cnn import xavier_init +from torch import nn + +from mmdet.models.builder import NECKS + + +class FPEM(nn.Module): + """FPN-like feature fusion module in PANet.""" + + def __init__(self, in_channels=128): + super().__init__() + self.up_add1 = SeparableConv2d(in_channels, in_channels, 1) + self.up_add2 = SeparableConv2d(in_channels, in_channels, 1) + self.up_add3 = SeparableConv2d(in_channels, in_channels, 1) + self.down_add1 = SeparableConv2d(in_channels, in_channels, 2) + self.down_add2 = SeparableConv2d(in_channels, in_channels, 2) + self.down_add3 = SeparableConv2d(in_channels, in_channels, 2) + + def forward(self, c2, c3, c4, c5): + # upsample + c4 = self.up_add1(self._upsample_add(c5, c4)) + c3 = self.up_add2(self._upsample_add(c4, c3)) + c2 = self.up_add3(self._upsample_add(c3, c2)) + + # downsample + c3 = self.down_add1(self._upsample_add(c3, c2)) + c4 = self.down_add2(self._upsample_add(c4, c3)) + c5 = self.down_add3(self._upsample_add(c5, c4)) + return c2, c3, c4, c5 + + def _upsample_add(self, x, y): + return F.interpolate(x, size=y.size()[2:]) + y + + +class SeparableConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, stride=1): + super().__init__() + + self.depthwise_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + padding=1, + stride=stride, + groups=in_channels) + self.pointwise_conv = nn.Conv2d( + in_channels=in_channels, 
out_channels=out_channels, kernel_size=1) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.depthwise_conv(x) + x = self.pointwise_conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +@NECKS.register_module() +class FPEM_FFM(nn.Module): + """This code is from https://github.com/WenmuZhou/PAN.pytorch.""" + + def __init__(self, + in_channels, + conv_out=128, + fpem_repeat=2, + align_corners=False): + super().__init__() + # reduce layers + self.reduce_conv_c2 = nn.Sequential( + nn.Conv2d( + in_channels=in_channels[0], + out_channels=conv_out, + kernel_size=1), nn.BatchNorm2d(conv_out), nn.ReLU()) + self.reduce_conv_c3 = nn.Sequential( + nn.Conv2d( + in_channels=in_channels[1], + out_channels=conv_out, + kernel_size=1), nn.BatchNorm2d(conv_out), nn.ReLU()) + self.reduce_conv_c4 = nn.Sequential( + nn.Conv2d( + in_channels=in_channels[2], + out_channels=conv_out, + kernel_size=1), nn.BatchNorm2d(conv_out), nn.ReLU()) + self.reduce_conv_c5 = nn.Sequential( + nn.Conv2d( + in_channels=in_channels[3], + out_channels=conv_out, + kernel_size=1), nn.BatchNorm2d(conv_out), nn.ReLU()) + self.align_corners = align_corners + self.fpems = nn.ModuleList() + for _ in range(fpem_repeat): + self.fpems.append(FPEM(conv_out)) + + def init_weights(self): + """Initialize the weights of FPN module.""" + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m, distribution='uniform') + + def forward(self, x): + c2, c3, c4, c5 = x + # reduce channel + c2 = self.reduce_conv_c2(c2) + c3 = self.reduce_conv_c3(c3) + c4 = self.reduce_conv_c4(c4) + c5 = self.reduce_conv_c5(c5) + + # FPEM + for i, fpem in enumerate(self.fpems): + c2, c3, c4, c5 = fpem(c2, c3, c4, c5) + if i == 0: + c2_ffm = c2 + c3_ffm = c3 + c4_ffm = c4 + c5_ffm = c5 + else: + c2_ffm += c2 + c3_ffm += c3 + c4_ffm += c4 + c5_ffm += c5 + + # FFM + c5 = F.interpolate( + c5_ffm, + c2_ffm.size()[-2:], + mode='bilinear', + align_corners=self.align_corners) + c4 = F.interpolate( + c4_ffm, + c2_ffm.size()[-2:], + mode='bilinear', + align_corners=self.align_corners) + c3 = F.interpolate( + c3_ffm, + c2_ffm.size()[-2:], + mode='bilinear', + align_corners=self.align_corners) + outs = [c2_ffm, c3, c4, c5] + return tuple(outs) diff --git a/mmocr/models/textdet/necks/fpn_cat.py b/mmocr/models/textdet/necks/fpn_cat.py new file mode 100644 index 00000000..0a3241c1 --- /dev/null +++ b/mmocr/models/textdet/necks/fpn_cat.py @@ -0,0 +1,128 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmcv.runner import auto_fp16 + +from mmdet.models.builder import NECKS + + +@NECKS.register_module() +class FPNC(nn.Module): + """FPN-like fusion module in Real-time Scene Text Detection with + Differentiable Binarization. 
+ + This was partially adapted from https://github.com/MhLiao/DB and + https://github.com/WenmuZhou/DBNet.pytorch + """ + + def __init__(self, + in_channels, + lateral_channels=256, + out_channels=64, + bias_on_lateral=False, + bn_re_on_lateral=False, + bias_on_smooth=False, + bn_re_on_smooth=False, + conv_after_concat=False): + super(FPNC, self).__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.lateral_channels = lateral_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.bn_re_on_lateral = bn_re_on_lateral + self.bn_re_on_smooth = bn_re_on_smooth + self.conv_after_concat = conv_after_concat + self.lateral_convs = nn.ModuleList() + self.smooth_convs = nn.ModuleList() + self.num_outs = self.num_ins + + for i in range(self.num_ins): + norm_cfg = None + act_cfg = None + if self.bn_re_on_lateral: + norm_cfg = dict(type='BN') + act_cfg = dict(type='ReLU') + l_conv = ConvModule( + in_channels[i], + lateral_channels, + 1, + bias=bias_on_lateral, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + norm_cfg = None + act_cfg = None + if self.bn_re_on_smooth: + norm_cfg = dict(type='BN') + act_cfg = dict(type='ReLU') + + smooth_conv = ConvModule( + lateral_channels, + out_channels, + 3, + bias=bias_on_smooth, + padding=1, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.smooth_convs.append(smooth_conv) + if self.conv_after_concat: + norm_cfg = dict(type='BN') + act_cfg = dict(type='ReLU') + self.out_conv = ConvModule( + out_channels * self.num_outs, + out_channels * self.num_outs, + 3, + padding=1, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + # default init_weights for conv(msra) and norm in ConvModule + def init_weights(self): + """Initialize the weights of FPN module.""" + for m in self.lateral_convs: + m.init_weights() + for m in self.smooth_convs: + m.init_weights() + if self.conv_after_concat: + self.out_conv.init_weights() + + @auto_fp16() + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + # build laterals + laterals = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + used_backbone_levels = len(laterals) + # build top-down path + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] += F.interpolate( + laterals[i], size=prev_shape, mode='nearest') + # build outputs + # part 1: from original levels + outs = [ + self.smooth_convs[i](laterals[i]) + for i in range(used_backbone_levels) + ] + + for i in range(len(outs)): + scale = 2**i + outs[i] = F.interpolate( + outs[i], scale_factor=scale, mode='nearest') + out = torch.cat(outs, dim=1) + + if self.conv_after_concat: + out = self.out_conv(out) + + return out diff --git a/mmocr/models/textdet/necks/fpn_unet.py b/mmocr/models/textdet/necks/fpn_unet.py new file mode 100644 index 00000000..85c09085 --- /dev/null +++ b/mmocr/models/textdet/necks/fpn_unet.py @@ -0,0 +1,88 @@ +import torch +import torch.nn.functional as F +from mmcv.cnn import xavier_init +from torch import nn + +from mmdet.models.builder import NECKS + + +class UpBlock(nn.Module): + """Upsample block for DRRG and TextSnake.""" + + def __init__(self, in_channels, out_channels): + super().__init__() + + assert isinstance(in_channels, int) + assert isinstance(out_channels, int) + + self.conv1x1 = nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0) + 
self.conv3x3 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.deconv = nn.ConvTranspose2d( + out_channels, out_channels, kernel_size=4, stride=2, padding=1) + + def forward(self, x): + x = F.relu(self.conv1x1(x)) + x = F.relu(self.conv3x3(x)) + x = self.deconv(x) + return x + + +@NECKS.register_module() +class FPN_UNET(nn.Module): + """The class for implementing DRRG and TextSnake U-Net-like FPN. + + DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape + Text Detection [https://arxiv.org/abs/2003.07493]. + TextSnake: A Flexible Representation for Detecting Text of Arbitrary Shapes + [https://arxiv.org/abs/1807.01544]. + """ + + def __init__(self, in_channels, out_channels): + super().__init__() + + assert len(in_channels) == 4 + assert isinstance(out_channels, int) + + blocks_out_channels = [out_channels] + [ + min(out_channels * 2**i, 256) for i in range(4) + ] + blocks_in_channels = [blocks_out_channels[1]] + [ + in_channels[i] + blocks_out_channels[i + 2] for i in range(3) + ] + [in_channels[3]] + + self.up4 = nn.ConvTranspose2d( + blocks_in_channels[4], + blocks_out_channels[4], + kernel_size=4, + stride=2, + padding=1) + self.up_block3 = UpBlock(blocks_in_channels[3], blocks_out_channels[3]) + self.up_block2 = UpBlock(blocks_in_channels[2], blocks_out_channels[2]) + self.up_block1 = UpBlock(blocks_in_channels[1], blocks_out_channels[1]) + self.up_block0 = UpBlock(blocks_in_channels[0], blocks_out_channels[0]) + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + xavier_init(m, distribution='uniform') + + def forward(self, x): + c2, c3, c4, c5 = x + + x = F.relu(self.up4(c5)) + + x = torch.cat([x, c4], dim=1) + x = F.relu(self.up_block3(x)) + + x = torch.cat([x, c3], dim=1) + x = F.relu(self.up_block2(x)) + + x = torch.cat([x, c2], dim=1) + x = F.relu(self.up_block1(x)) + + x = self.up_block0(x) + # the output should be of the same height and width as backbone input + return x diff --git a/mmocr/models/textdet/necks/fpnf.py b/mmocr/models/textdet/necks/fpnf.py new file mode 100644 index 00000000..bdb69fea --- /dev/null +++ b/mmocr/models/textdet/necks/fpnf.py @@ -0,0 +1,117 @@ +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule, xavier_init +from mmcv.runner import auto_fp16 +from torch import nn + +from mmdet.models.builder import NECKS + + +@NECKS.register_module() +class FPNF(nn.Module): + """FPN-like fusion module in Shape Robust Text Detection with Progressive + Scale Expansion Network.""" + + def __init__( + self, + in_channels=[256, 512, 1024, 2048], + out_channels=256, + fusion_type='concat', # 'concat' or 'add' + upsample_ratio=1): + super().__init__() + conv_cfg = None + norm_cfg = dict(type='BN') + act_cfg = dict(type='ReLU') + + self.in_channels = in_channels + self.out_channels = out_channels + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + self.backbone_end_level = len(in_channels) + for i in range(self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.lateral_convs.append(l_conv) + + if i < self.backbone_end_level - 1: + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(fpn_conv) + + self.fusion_type = fusion_type + + if self.fusion_type == 'concat': + 
feature_channels = 1024 + elif self.fusion_type == 'add': + feature_channels = 256 + else: + raise NotImplementedError + + self.output_convs = ConvModule( + feature_channels, + out_channels, + 3, + padding=1, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.upsample_ratio = upsample_ratio + + # default init_weights for conv(msra) and norm in ConvModule + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m, distribution='uniform') + + @auto_fp16() + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # step 1: upsample to level i-1 size and add level i-1 + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] += F.interpolate( + laterals[i], size=prev_shape, mode='nearest') + # step 2: smooth level i-1 + laterals[i - 1] = self.fpn_convs[i - 1](laterals[i - 1]) + + # upsample and cont + bottom_shape = laterals[0].shape[2:] + for i in range(1, used_backbone_levels): + laterals[i] = F.interpolate( + laterals[i], size=bottom_shape, mode='nearest') + + if self.fusion_type == 'concat': + out = torch.cat(laterals, 1) + elif self.fusion_type == 'add': + out = laterals[0] + for i in range(1, used_backbone_levels): + out += laterals[i] + else: + raise NotImplementedError + out = self.output_convs(out) + + return out diff --git a/mmocr/models/textdet/postprocess/__init__.py b/mmocr/models/textdet/postprocess/__init__.py new file mode 100644 index 00000000..acc72530 --- /dev/null +++ b/mmocr/models/textdet/postprocess/__init__.py @@ -0,0 +1,3 @@ +from .wrapper import decode + +__all__ = ['decode'] diff --git a/mmocr/models/textdet/postprocess/include/clipper/clipper.cpp b/mmocr/models/textdet/postprocess/include/clipper/clipper.cpp new file mode 100644 index 00000000..521c613c --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/clipper/clipper.cpp @@ -0,0 +1,4622 @@ +/******************************************************************************* +* * +* Author : Angus Johnson * +* Version : 6.4.0 * +* Date : 2 July 2015 * +* Website : http://www.angusj.com * +* Copyright : Angus Johnson 2010-2015 * +* * +* License: * +* Use, modification & distribution is subject to Boost Software License Ver 1. * +* http://www.boost.org/LICENSE_1_0.txt * +* * +* Attributions: * +* The code in this library is an extension of Bala Vatti's clipping algorithm: * +* "A generic solution to polygon clipping" * +* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. * +* http://portal.acm.org/citation.cfm?id=129906 * +* * +* Computer graphics and geometric modeling: implementation and algorithms * +* By Max K. Agoston * +* Springer; 1 edition (January 4, 2005) * +* http://books.google.com/books?q=vatti+clipping+agoston * +* * +* See also: * +* "Polygon Offsetting by Computing Winding Numbers" * +* Paper no. DETC2005-85513 pp. 
565-575 * +* ASME 2005 International Design Engineering Technical Conferences * +* and Computers and Information in Engineering Conference (IDETC/CIE2005) * +* September 24-28, 2005 , Long Beach, California, USA * +* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf * +* * +*******************************************************************************/ + +/******************************************************************************* +* * +* This is a translation of the Delphi Clipper library and the naming style * +* used has retained a Delphi flavour. * +* * +*******************************************************************************/ + +#include "clipper.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ClipperLib { + +static double const pi = 3.141592653589793238; +static double const two_pi = pi *2; +static double const def_arc_tolerance = 0.25; + +enum Direction { dRightToLeft, dLeftToRight }; + +static int const Unassigned = -1; //edge not currently 'owning' a solution +static int const Skip = -2; //edge that would otherwise close a path + +#define HORIZONTAL (-1.0E+40) +#define TOLERANCE (1.0e-20) +#define NEAR_ZERO(val) (((val) > -TOLERANCE) && ((val) < TOLERANCE)) + +struct TEdge { + IntPoint Bot; + IntPoint Curr; //current (updated for every new scanbeam) + IntPoint Top; + double Dx; + PolyType PolyTyp; + EdgeSide Side; //side only refers to current side of solution poly + int WindDelta; //1 or -1 depending on winding direction + int WindCnt; + int WindCnt2; //winding count of the opposite polytype + int OutIdx; + TEdge *Next; + TEdge *Prev; + TEdge *NextInLML; + TEdge *NextInAEL; + TEdge *PrevInAEL; + TEdge *NextInSEL; + TEdge *PrevInSEL; +}; + +struct IntersectNode { + TEdge *Edge1; + TEdge *Edge2; + IntPoint Pt; +}; + +struct LocalMinimum { + cInt Y; + TEdge *LeftBound; + TEdge *RightBound; +}; + +struct OutPt; + +//OutRec: contains a path in the clipping solution. Edges in the AEL will +//carry a pointer to an OutRec when they are part of the clipping solution. +struct OutRec { + int Idx; + bool IsHole; + bool IsOpen; + OutRec *FirstLeft; //see comments in clipper.pas + PolyNode *PolyNd; + OutPt *Pts; + OutPt *BottomPt; +}; + +struct OutPt { + int Idx; + IntPoint Pt; + OutPt *Next; + OutPt *Prev; +}; + +struct Join { + OutPt *OutPt1; + OutPt *OutPt2; + IntPoint OffPt; +}; + +struct LocMinSorter +{ + inline bool operator()(const LocalMinimum& locMin1, const LocalMinimum& locMin2) + { + return locMin2.Y < locMin1.Y; + } +}; + +//------------------------------------------------------------------------------ +//------------------------------------------------------------------------------ + +inline cInt Round(double val) +{ + if ((val < 0)) return static_cast(val - 0.5); + else return static_cast(val + 0.5); +} +//------------------------------------------------------------------------------ + +inline cInt Abs(cInt val) +{ + return val < 0 ? -val : val; +} + +//------------------------------------------------------------------------------ +// PolyTree methods ... 
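+//nb: a PolyTree owns every node through AllNodes and exposes only the
+//outermost polygons via Childs; holes and nested islands hang off each
+//node's own Childs list, which is why PolyNode::IsHole() below simply
+//counts parent links.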
+//------------------------------------------------------------------------------ + +void PolyTree::Clear() +{ + for (PolyNodes::size_type i = 0; i < AllNodes.size(); ++i) + delete AllNodes[i]; + AllNodes.resize(0); + Childs.resize(0); +} +//------------------------------------------------------------------------------ + +PolyNode* PolyTree::GetFirst() const +{ + if (!Childs.empty()) + return Childs[0]; + else + return 0; +} +//------------------------------------------------------------------------------ + +int PolyTree::Total() const +{ + int result = (int)AllNodes.size(); + //with negative offsets, ignore the hidden outer polygon ... + if (result > 0 && Childs[0] != AllNodes[0]) result--; + return result; +} + +//------------------------------------------------------------------------------ +// PolyNode methods ... +//------------------------------------------------------------------------------ + +PolyNode::PolyNode(): Childs(), Parent(0), Index(0), m_IsOpen(false) +{ +} +//------------------------------------------------------------------------------ + +int PolyNode::ChildCount() const +{ + return (int)Childs.size(); +} +//------------------------------------------------------------------------------ + +void PolyNode::AddChild(PolyNode& child) +{ + unsigned cnt = (unsigned)Childs.size(); + Childs.push_back(&child); + child.Parent = this; + child.Index = cnt; +} +//------------------------------------------------------------------------------ + +PolyNode* PolyNode::GetNext() const +{ + if (!Childs.empty()) + return Childs[0]; + else + return GetNextSiblingUp(); +} +//------------------------------------------------------------------------------ + +PolyNode* PolyNode::GetNextSiblingUp() const +{ + if (!Parent) //protects against PolyTree.GetNextSiblingUp() + return 0; + else if (Index == Parent->Childs.size() - 1) + return Parent->GetNextSiblingUp(); + else + return Parent->Childs[Index + 1]; +} +//------------------------------------------------------------------------------ + +bool PolyNode::IsHole() const +{ + bool result = true; + PolyNode* node = Parent; + while (node) + { + result = !result; + node = node->Parent; + } + return result; +} +//------------------------------------------------------------------------------ + +bool PolyNode::IsOpen() const +{ + return m_IsOpen; +} +//------------------------------------------------------------------------------ + +#ifndef use_int32 + +//------------------------------------------------------------------------------ +// Int128 class (enables safe math on signed 64bit integers) +// eg Int128 val1((long64)9223372036854775807); //ie 2^63 -1 +// Int128 val2((long64)9223372036854775807); +// Int128 val3 = val1 * val2; +// val3.AsString => "85070591730234615847396907784232501249" (8.5e+37) +//------------------------------------------------------------------------------ + +class Int128 +{ + public: + ulong64 lo; + long64 hi; + + Int128(long64 _lo = 0) + { + lo = (ulong64)_lo; + if (_lo < 0) hi = -1; else hi = 0; + } + + + Int128(const Int128 &val): lo(val.lo), hi(val.hi){} + + Int128(const long64& _hi, const ulong64& _lo): lo(_lo), hi(_hi){} + + Int128& operator = (const long64 &val) + { + lo = (ulong64)val; + if (val < 0) hi = -1; else hi = 0; + return *this; + } + + bool operator == (const Int128 &val) const + {return (hi == val.hi && lo == val.lo);} + + bool operator != (const Int128 &val) const + { return !(*this == val);} + + bool operator > (const Int128 &val) const + { + if (hi != val.hi) + return hi > val.hi; + else + return lo > 
val.lo; + } + + bool operator < (const Int128 &val) const + { + if (hi != val.hi) + return hi < val.hi; + else + return lo < val.lo; + } + + bool operator >= (const Int128 &val) const + { return !(*this < val);} + + bool operator <= (const Int128 &val) const + { return !(*this > val);} + + Int128& operator += (const Int128 &rhs) + { + hi += rhs.hi; + lo += rhs.lo; + if (lo < rhs.lo) hi++; + return *this; + } + + Int128 operator + (const Int128 &rhs) const + { + Int128 result(*this); + result+= rhs; + return result; + } + + Int128& operator -= (const Int128 &rhs) + { + *this += -rhs; + return *this; + } + + Int128 operator - (const Int128 &rhs) const + { + Int128 result(*this); + result -= rhs; + return result; + } + + Int128 operator-() const //unary negation + { + if (lo == 0) + return Int128(-hi, 0); + else + return Int128(~hi, ~lo + 1); + } + + operator double() const + { + const double shift64 = 18446744073709551616.0; //2^64 + if (hi < 0) + { + if (lo == 0) return (double)hi * shift64; + else return -(double)(~lo + ~hi * shift64); + } + else + return (double)(lo + hi * shift64); + } + +}; +//------------------------------------------------------------------------------ + +Int128 Int128Mul (long64 lhs, long64 rhs) +{ + bool negate = (lhs < 0) != (rhs < 0); + + if (lhs < 0) lhs = -lhs; + ulong64 int1Hi = ulong64(lhs) >> 32; + ulong64 int1Lo = ulong64(lhs & 0xFFFFFFFF); + + if (rhs < 0) rhs = -rhs; + ulong64 int2Hi = ulong64(rhs) >> 32; + ulong64 int2Lo = ulong64(rhs & 0xFFFFFFFF); + + //nb: see comments in clipper.pas + ulong64 a = int1Hi * int2Hi; + ulong64 b = int1Lo * int2Lo; + ulong64 c = int1Hi * int2Lo + int1Lo * int2Hi; + + Int128 tmp; + tmp.hi = long64(a + (c >> 32)); + tmp.lo = long64(c << 32); + tmp.lo += long64(b); + if (tmp.lo < b) tmp.hi++; + if (negate) tmp = -tmp; + return tmp; +}; +#endif + +//------------------------------------------------------------------------------ +// Miscellaneous global functions +//------------------------------------------------------------------------------ + +bool Orientation(const Path &poly) +{ + return Area(poly) >= 0; +} +//------------------------------------------------------------------------------ + +double Area(const Path &poly) +{ + int size = (int)poly.size(); + if (size < 3) return 0; + + double a = 0; + for (int i = 0, j = size -1; i < size; ++i) + { + a += ((double)poly[j].X + poly[i].X) * ((double)poly[j].Y - poly[i].Y); + j = i; + } + return -a * 0.5; +} +//------------------------------------------------------------------------------ + +double Area(const OutPt *op) +{ + const OutPt *startOp = op; + if (!op) return 0; + double a = 0; + do { + a += (double)(op->Prev->Pt.X + op->Pt.X) * (double)(op->Prev->Pt.Y - op->Pt.Y); + op = op->Next; + } while (op != startOp); + return a * 0.5; +} +//------------------------------------------------------------------------------ + +double Area(const OutRec &outRec) +{ + return Area(outRec.Pts); +} +//------------------------------------------------------------------------------ + +bool PointIsVertex(const IntPoint &Pt, OutPt *pp) +{ + OutPt *pp2 = pp; + do + { + if (pp2->Pt == Pt) return true; + pp2 = pp2->Next; + } + while (pp2 != pp); + return false; +} +//------------------------------------------------------------------------------ + +//See "The Point in Polygon Problem for Arbitrary Polygons" by Hormann & Agathos +//http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.88.5498&rep=rep1&type=pdf +int PointInPolygon(const IntPoint &pt, const Path &path) +{ + //returns 0 if false, +1 
if true, -1 if pt ON polygon boundary + int result = 0; + size_t cnt = path.size(); + if (cnt < 3) return 0; + IntPoint ip = path[0]; + for(size_t i = 1; i <= cnt; ++i) + { + IntPoint ipNext = (i == cnt ? path[0] : path[i]); + if (ipNext.Y == pt.Y) + { + if ((ipNext.X == pt.X) || (ip.Y == pt.Y && + ((ipNext.X > pt.X) == (ip.X < pt.X)))) return -1; + } + if ((ip.Y < pt.Y) != (ipNext.Y < pt.Y)) + { + if (ip.X >= pt.X) + { + if (ipNext.X > pt.X) result = 1 - result; + else + { + double d = (double)(ip.X - pt.X) * (ipNext.Y - pt.Y) - + (double)(ipNext.X - pt.X) * (ip.Y - pt.Y); + if (!d) return -1; + if ((d > 0) == (ipNext.Y > ip.Y)) result = 1 - result; + } + } else + { + if (ipNext.X > pt.X) + { + double d = (double)(ip.X - pt.X) * (ipNext.Y - pt.Y) - + (double)(ipNext.X - pt.X) * (ip.Y - pt.Y); + if (!d) return -1; + if ((d > 0) == (ipNext.Y > ip.Y)) result = 1 - result; + } + } + } + ip = ipNext; + } + return result; +} +//------------------------------------------------------------------------------ + +int PointInPolygon (const IntPoint &pt, OutPt *op) +{ + //returns 0 if false, +1 if true, -1 if pt ON polygon boundary + int result = 0; + OutPt* startOp = op; + for(;;) + { + if (op->Next->Pt.Y == pt.Y) + { + if ((op->Next->Pt.X == pt.X) || (op->Pt.Y == pt.Y && + ((op->Next->Pt.X > pt.X) == (op->Pt.X < pt.X)))) return -1; + } + if ((op->Pt.Y < pt.Y) != (op->Next->Pt.Y < pt.Y)) + { + if (op->Pt.X >= pt.X) + { + if (op->Next->Pt.X > pt.X) result = 1 - result; + else + { + double d = (double)(op->Pt.X - pt.X) * (op->Next->Pt.Y - pt.Y) - + (double)(op->Next->Pt.X - pt.X) * (op->Pt.Y - pt.Y); + if (!d) return -1; + if ((d > 0) == (op->Next->Pt.Y > op->Pt.Y)) result = 1 - result; + } + } else + { + if (op->Next->Pt.X > pt.X) + { + double d = (double)(op->Pt.X - pt.X) * (op->Next->Pt.Y - pt.Y) - + (double)(op->Next->Pt.X - pt.X) * (op->Pt.Y - pt.Y); + if (!d) return -1; + if ((d > 0) == (op->Next->Pt.Y > op->Pt.Y)) result = 1 - result; + } + } + } + op = op->Next; + if (startOp == op) break; + } + return result; +} +//------------------------------------------------------------------------------ + +bool Poly2ContainsPoly1(OutPt *OutPt1, OutPt *OutPt2) +{ + OutPt* op = OutPt1; + do + { + //nb: PointInPolygon returns 0 if false, +1 if true, -1 if pt on polygon + int res = PointInPolygon(op->Pt, OutPt2); + if (res >= 0) return res > 0; + op = op->Next; + } + while (op != OutPt1); + return true; +} +//---------------------------------------------------------------------- + +bool SlopesEqual(const TEdge &e1, const TEdge &e2, bool UseFullInt64Range) +{ +#ifndef use_int32 + if (UseFullInt64Range) + return Int128Mul(e1.Top.Y - e1.Bot.Y, e2.Top.X - e2.Bot.X) == + Int128Mul(e1.Top.X - e1.Bot.X, e2.Top.Y - e2.Bot.Y); + else +#endif + return (e1.Top.Y - e1.Bot.Y) * (e2.Top.X - e2.Bot.X) == + (e1.Top.X - e1.Bot.X) * (e2.Top.Y - e2.Bot.Y); +} +//------------------------------------------------------------------------------ + +bool SlopesEqual(const IntPoint pt1, const IntPoint pt2, + const IntPoint pt3, bool UseFullInt64Range) +{ +#ifndef use_int32 + if (UseFullInt64Range) + return Int128Mul(pt1.Y-pt2.Y, pt2.X-pt3.X) == Int128Mul(pt1.X-pt2.X, pt2.Y-pt3.Y); + else +#endif + return (pt1.Y-pt2.Y)*(pt2.X-pt3.X) == (pt1.X-pt2.X)*(pt2.Y-pt3.Y); +} +//------------------------------------------------------------------------------ + +bool SlopesEqual(const IntPoint pt1, const IntPoint pt2, + const IntPoint pt3, const IntPoint pt4, bool UseFullInt64Range) +{ +#ifndef use_int32 + if (UseFullInt64Range) + return 
Int128Mul(pt1.Y-pt2.Y, pt3.X-pt4.X) == Int128Mul(pt1.X-pt2.X, pt3.Y-pt4.Y); + else +#endif + return (pt1.Y-pt2.Y)*(pt3.X-pt4.X) == (pt1.X-pt2.X)*(pt3.Y-pt4.Y); +} +//------------------------------------------------------------------------------ + +inline bool IsHorizontal(TEdge &e) +{ + return e.Dx == HORIZONTAL; +} +//------------------------------------------------------------------------------ + +inline double GetDx(const IntPoint pt1, const IntPoint pt2) +{ + return (pt1.Y == pt2.Y) ? + HORIZONTAL : (double)(pt2.X - pt1.X) / (pt2.Y - pt1.Y); +} +//--------------------------------------------------------------------------- + +inline void SetDx(TEdge &e) +{ + cInt dy = (e.Top.Y - e.Bot.Y); + if (dy == 0) e.Dx = HORIZONTAL; + else e.Dx = (double)(e.Top.X - e.Bot.X) / dy; +} +//--------------------------------------------------------------------------- + +inline void SwapSides(TEdge &Edge1, TEdge &Edge2) +{ + EdgeSide Side = Edge1.Side; + Edge1.Side = Edge2.Side; + Edge2.Side = Side; +} +//------------------------------------------------------------------------------ + +inline void SwapPolyIndexes(TEdge &Edge1, TEdge &Edge2) +{ + int OutIdx = Edge1.OutIdx; + Edge1.OutIdx = Edge2.OutIdx; + Edge2.OutIdx = OutIdx; +} +//------------------------------------------------------------------------------ + +inline cInt TopX(TEdge &edge, const cInt currentY) +{ + return ( currentY == edge.Top.Y ) ? + edge.Top.X : edge.Bot.X + Round(edge.Dx *(currentY - edge.Bot.Y)); +} +//------------------------------------------------------------------------------ + +void IntersectPoint(TEdge &Edge1, TEdge &Edge2, IntPoint &ip) +{ +#ifdef use_xyz + ip.Z = 0; +#endif + + double b1, b2; + if (Edge1.Dx == Edge2.Dx) + { + ip.Y = Edge1.Curr.Y; + ip.X = TopX(Edge1, ip.Y); + return; + } + else if (Edge1.Dx == 0) + { + ip.X = Edge1.Bot.X; + if (IsHorizontal(Edge2)) + ip.Y = Edge2.Bot.Y; + else + { + b2 = Edge2.Bot.Y - (Edge2.Bot.X / Edge2.Dx); + ip.Y = Round(ip.X / Edge2.Dx + b2); + } + } + else if (Edge2.Dx == 0) + { + ip.X = Edge2.Bot.X; + if (IsHorizontal(Edge1)) + ip.Y = Edge1.Bot.Y; + else + { + b1 = Edge1.Bot.Y - (Edge1.Bot.X / Edge1.Dx); + ip.Y = Round(ip.X / Edge1.Dx + b1); + } + } + else + { + b1 = Edge1.Bot.X - Edge1.Bot.Y * Edge1.Dx; + b2 = Edge2.Bot.X - Edge2.Bot.Y * Edge2.Dx; + double q = (b2-b1) / (Edge1.Dx - Edge2.Dx); + ip.Y = Round(q); + if (std::fabs(Edge1.Dx) < std::fabs(Edge2.Dx)) + ip.X = Round(Edge1.Dx * q + b1); + else + ip.X = Round(Edge2.Dx * q + b2); + } + + if (ip.Y < Edge1.Top.Y || ip.Y < Edge2.Top.Y) + { + if (Edge1.Top.Y > Edge2.Top.Y) + ip.Y = Edge1.Top.Y; + else + ip.Y = Edge2.Top.Y; + if (std::fabs(Edge1.Dx) < std::fabs(Edge2.Dx)) + ip.X = TopX(Edge1, ip.Y); + else + ip.X = TopX(Edge2, ip.Y); + } + //finally, don't allow 'ip' to be BELOW curr.Y (ie bottom of scanbeam) ... + if (ip.Y > Edge1.Curr.Y) + { + ip.Y = Edge1.Curr.Y; + //use the more vertical edge to derive X ... 
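+ //(an edge's Dx is its x-shift per unit y, so the smaller |Dx|, the
+ //more vertical the edge and the stabler the derived X once Y has
+ //been clamped)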
+ if (std::fabs(Edge1.Dx) > std::fabs(Edge2.Dx)) + ip.X = TopX(Edge2, ip.Y); else + ip.X = TopX(Edge1, ip.Y); + } +} +//------------------------------------------------------------------------------ + +void ReversePolyPtLinks(OutPt *pp) +{ + if (!pp) return; + OutPt *pp1, *pp2; + pp1 = pp; + do { + pp2 = pp1->Next; + pp1->Next = pp1->Prev; + pp1->Prev = pp2; + pp1 = pp2; + } while( pp1 != pp ); +} +//------------------------------------------------------------------------------ + +void DisposeOutPts(OutPt*& pp) +{ + if (pp == 0) return; + pp->Prev->Next = 0; + while( pp ) + { + OutPt *tmpPp = pp; + pp = pp->Next; + delete tmpPp; + } +} +//------------------------------------------------------------------------------ + +inline void InitEdge(TEdge* e, TEdge* eNext, TEdge* ePrev, const IntPoint& Pt) +{ + std::memset(e, 0, sizeof(TEdge)); + e->Next = eNext; + e->Prev = ePrev; + e->Curr = Pt; + e->OutIdx = Unassigned; +} +//------------------------------------------------------------------------------ + +void InitEdge2(TEdge& e, PolyType Pt) +{ + if (e.Curr.Y >= e.Next->Curr.Y) + { + e.Bot = e.Curr; + e.Top = e.Next->Curr; + } else + { + e.Top = e.Curr; + e.Bot = e.Next->Curr; + } + SetDx(e); + e.PolyTyp = Pt; +} +//------------------------------------------------------------------------------ + +TEdge* RemoveEdge(TEdge* e) +{ + //removes e from double_linked_list (but without removing from memory) + e->Prev->Next = e->Next; + e->Next->Prev = e->Prev; + TEdge* result = e->Next; + e->Prev = 0; //flag as removed (see ClipperBase.Clear) + return result; +} +//------------------------------------------------------------------------------ + +inline void ReverseHorizontal(TEdge &e) +{ + //swap horizontal edges' Top and Bottom x's so they follow the natural + //progression of the bounds - ie so their xbots will align with the + //adjoining lower edge. [Helpful in the ProcessHorizontal() method.] + std::swap(e.Top.X, e.Bot.X); +#ifdef use_xyz + std::swap(e.Top.Z, e.Bot.Z); +#endif +} +//------------------------------------------------------------------------------ + +void SwapPoints(IntPoint &pt1, IntPoint &pt2) +{ + IntPoint tmp = pt1; + pt1 = pt2; + pt2 = tmp; +} +//------------------------------------------------------------------------------ + +bool GetOverlapSegment(IntPoint pt1a, IntPoint pt1b, IntPoint pt2a, + IntPoint pt2b, IntPoint &pt1, IntPoint &pt2) +{ + //precondition: segments are Collinear. 
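+ //nb: project onto whichever axis the segments span furthest; the
+ //overlap, if any, then reduces to a 1-D interval intersection on that axis.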
+ if (Abs(pt1a.X - pt1b.X) > Abs(pt1a.Y - pt1b.Y)) + { + if (pt1a.X > pt1b.X) SwapPoints(pt1a, pt1b); + if (pt2a.X > pt2b.X) SwapPoints(pt2a, pt2b); + if (pt1a.X > pt2a.X) pt1 = pt1a; else pt1 = pt2a; + if (pt1b.X < pt2b.X) pt2 = pt1b; else pt2 = pt2b; + return pt1.X < pt2.X; + } else + { + if (pt1a.Y < pt1b.Y) SwapPoints(pt1a, pt1b); + if (pt2a.Y < pt2b.Y) SwapPoints(pt2a, pt2b); + if (pt1a.Y < pt2a.Y) pt1 = pt1a; else pt1 = pt2a; + if (pt1b.Y > pt2b.Y) pt2 = pt1b; else pt2 = pt2b; + return pt1.Y > pt2.Y; + } +} +//------------------------------------------------------------------------------ + +bool FirstIsBottomPt(const OutPt* btmPt1, const OutPt* btmPt2) +{ + OutPt *p = btmPt1->Prev; + while ((p->Pt == btmPt1->Pt) && (p != btmPt1)) p = p->Prev; + double dx1p = std::fabs(GetDx(btmPt1->Pt, p->Pt)); + p = btmPt1->Next; + while ((p->Pt == btmPt1->Pt) && (p != btmPt1)) p = p->Next; + double dx1n = std::fabs(GetDx(btmPt1->Pt, p->Pt)); + + p = btmPt2->Prev; + while ((p->Pt == btmPt2->Pt) && (p != btmPt2)) p = p->Prev; + double dx2p = std::fabs(GetDx(btmPt2->Pt, p->Pt)); + p = btmPt2->Next; + while ((p->Pt == btmPt2->Pt) && (p != btmPt2)) p = p->Next; + double dx2n = std::fabs(GetDx(btmPt2->Pt, p->Pt)); + + if (std::max(dx1p, dx1n) == std::max(dx2p, dx2n) && + std::min(dx1p, dx1n) == std::min(dx2p, dx2n)) + return Area(btmPt1) > 0; //if otherwise identical use orientation + else + return (dx1p >= dx2p && dx1p >= dx2n) || (dx1n >= dx2p && dx1n >= dx2n); +} +//------------------------------------------------------------------------------ + +OutPt* GetBottomPt(OutPt *pp) +{ + OutPt* dups = 0; + OutPt* p = pp->Next; + while (p != pp) + { + if (p->Pt.Y > pp->Pt.Y) + { + pp = p; + dups = 0; + } + else if (p->Pt.Y == pp->Pt.Y && p->Pt.X <= pp->Pt.X) + { + if (p->Pt.X < pp->Pt.X) + { + dups = 0; + pp = p; + } else + { + if (p->Next != pp && p->Prev != pp) dups = p; + } + } + p = p->Next; + } + if (dups) + { + //there appears to be at least 2 vertices at BottomPt so ... + while (dups != p) + { + if (!FirstIsBottomPt(p, dups)) pp = dups; + dups = dups->Next; + while (dups->Pt != pp->Pt) dups = dups->Next; + } + } + return pp; +} +//------------------------------------------------------------------------------ + +bool Pt2IsBetweenPt1AndPt3(const IntPoint pt1, + const IntPoint pt2, const IntPoint pt3) +{ + if ((pt1 == pt3) || (pt1 == pt2) || (pt3 == pt2)) + return false; + else if (pt1.X != pt3.X) + return (pt2.X > pt1.X) == (pt2.X < pt3.X); + else + return (pt2.Y > pt1.Y) == (pt2.Y < pt3.Y); +} +//------------------------------------------------------------------------------ + +bool HorzSegmentsOverlap(cInt seg1a, cInt seg1b, cInt seg2a, cInt seg2b) +{ + if (seg1a > seg1b) std::swap(seg1a, seg1b); + if (seg2a > seg2b) std::swap(seg2a, seg2b); + return (seg1a < seg2b) && (seg2a < seg1b); +} + +//------------------------------------------------------------------------------ +// ClipperBase class methods ... 
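+//nb: AddPath() converts each input path into a TEdge array (stored in
+//m_edges) and appends its bounds to m_MinimaList; Reset() then re-primes the
+//edges and the scanbeam queue so the same paths can be clipped repeatedly.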
+//------------------------------------------------------------------------------ + +ClipperBase::ClipperBase() //constructor +{ + m_CurrentLM = m_MinimaList.begin(); //begin() == end() here + m_UseFullRange = false; +} +//------------------------------------------------------------------------------ + +ClipperBase::~ClipperBase() //destructor +{ + Clear(); +} +//------------------------------------------------------------------------------ + +void RangeTest(const IntPoint& Pt, bool& useFullRange) +{ + if (useFullRange) + { + if (Pt.X > hiRange || Pt.Y > hiRange || -Pt.X > hiRange || -Pt.Y > hiRange) + throw clipperException("Coordinate outside allowed range"); + } + else if (Pt.X > loRange|| Pt.Y > loRange || -Pt.X > loRange || -Pt.Y > loRange) + { + useFullRange = true; + RangeTest(Pt, useFullRange); + } +} +//------------------------------------------------------------------------------ + +TEdge* FindNextLocMin(TEdge* E) +{ + for (;;) + { + while (E->Bot != E->Prev->Bot || E->Curr == E->Top) E = E->Next; + if (!IsHorizontal(*E) && !IsHorizontal(*E->Prev)) break; + while (IsHorizontal(*E->Prev)) E = E->Prev; + TEdge* E2 = E; + while (IsHorizontal(*E)) E = E->Next; + if (E->Top.Y == E->Prev->Bot.Y) continue; //ie just an intermediate horz. + if (E2->Prev->Bot.X < E->Bot.X) E = E2; + break; + } + return E; +} +//------------------------------------------------------------------------------ + +TEdge* ClipperBase::ProcessBound(TEdge* E, bool NextIsForward) +{ + TEdge *Result = E; + TEdge *Horz = 0; + + if (E->OutIdx == Skip) + { + //if edges still remain in the current bound beyond the skip edge then + //create another LocMin and call ProcessBound once more + if (NextIsForward) + { + while (E->Top.Y == E->Next->Bot.Y) E = E->Next; + //don't include top horizontals when parsing a bound a second time, + //they will be contained in the opposite bound ... + while (E != Result && IsHorizontal(*E)) E = E->Prev; + } + else + { + while (E->Top.Y == E->Prev->Bot.Y) E = E->Prev; + while (E != Result && IsHorizontal(*E)) E = E->Next; + } + + if (E == Result) + { + if (NextIsForward) Result = E->Next; + else Result = E->Prev; + } + else + { + //there are more edges in the bound beyond result starting with E + if (NextIsForward) + E = Result->Next; + else + E = Result->Prev; + MinimaList::value_type locMin; + locMin.Y = E->Bot.Y; + locMin.LeftBound = 0; + locMin.RightBound = E; + E->WindDelta = 0; + Result = ProcessBound(E, NextIsForward); + m_MinimaList.push_back(locMin); + } + return Result; + } + + TEdge *EStart; + + if (IsHorizontal(*E)) + { + //We need to be careful with open paths because this may not be a + //true local minima (ie E may be following a skip edge). + //Also, consecutive horz. edges may start heading left before going right. 
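+ //nb: EStart is the edge adjoining the bound's start; comparing its X
+ //with E->Bot.X decides whether this horizontal's Bot/Top need swapping
+ //to follow the bound's direction.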
+ if (NextIsForward) + EStart = E->Prev; + else + EStart = E->Next; + if (IsHorizontal(*EStart)) //ie an adjoining horizontal skip edge + { + if (EStart->Bot.X != E->Bot.X && EStart->Top.X != E->Bot.X) + ReverseHorizontal(*E); + } + else if (EStart->Bot.X != E->Bot.X) + ReverseHorizontal(*E); + } + + EStart = E; + if (NextIsForward) + { + while (Result->Top.Y == Result->Next->Bot.Y && Result->Next->OutIdx != Skip) + Result = Result->Next; + if (IsHorizontal(*Result) && Result->Next->OutIdx != Skip) + { + //nb: at the top of a bound, horizontals are added to the bound + //only when the preceding edge attaches to the horizontal's left vertex + //unless a Skip edge is encountered when that becomes the top divide + Horz = Result; + while (IsHorizontal(*Horz->Prev)) Horz = Horz->Prev; + if (Horz->Prev->Top.X > Result->Next->Top.X) Result = Horz->Prev; + } + while (E != Result) + { + E->NextInLML = E->Next; + if (IsHorizontal(*E) && E != EStart && + E->Bot.X != E->Prev->Top.X) ReverseHorizontal(*E); + E = E->Next; + } + if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Prev->Top.X) + ReverseHorizontal(*E); + Result = Result->Next; //move to the edge just beyond current bound + } else + { + while (Result->Top.Y == Result->Prev->Bot.Y && Result->Prev->OutIdx != Skip) + Result = Result->Prev; + if (IsHorizontal(*Result) && Result->Prev->OutIdx != Skip) + { + Horz = Result; + while (IsHorizontal(*Horz->Next)) Horz = Horz->Next; + if (Horz->Next->Top.X == Result->Prev->Top.X || + Horz->Next->Top.X > Result->Prev->Top.X) Result = Horz->Next; + } + + while (E != Result) + { + E->NextInLML = E->Prev; + if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Next->Top.X) + ReverseHorizontal(*E); + E = E->Prev; + } + if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Next->Top.X) + ReverseHorizontal(*E); + Result = Result->Prev; //move to the edge just beyond current bound + } + + return Result; +} +//------------------------------------------------------------------------------ + +bool ClipperBase::AddPath(const Path &pg, PolyType PolyTyp, bool Closed) +{ +#ifdef use_lines + if (!Closed && PolyTyp == ptClip) + throw clipperException("AddPath: Open paths must be subject."); +#else + if (!Closed) + throw clipperException("AddPath: Open paths have been disabled."); +#endif + + int highI = (int)pg.size() -1; + if (Closed) while (highI > 0 && (pg[highI] == pg[0])) --highI; + while (highI > 0 && (pg[highI] == pg[highI -1])) --highI; + if ((Closed && highI < 2) || (!Closed && highI < 1)) return false; + + //create a new edge array ... + TEdge *edges = new TEdge [highI +1]; + + bool IsFlat = true; + //1. Basic (first) edge initialization ... + try + { + edges[1].Curr = pg[1]; + RangeTest(pg[0], m_UseFullRange); + RangeTest(pg[highI], m_UseFullRange); + InitEdge(&edges[0], &edges[1], &edges[highI], pg[0]); + InitEdge(&edges[highI], &edges[0], &edges[highI-1], pg[highI]); + for (int i = highI - 1; i >= 1; --i) + { + RangeTest(pg[i], m_UseFullRange); + InitEdge(&edges[i], &edges[i+1], &edges[i-1], pg[i]); + } + } + catch(...) + { + delete [] edges; + throw; //range test fails + } + TEdge *eStart = &edges[0]; + + //2. Remove duplicate vertices, and (when closed) collinear edges ... + TEdge *E = eStart, *eLoopStop = eStart; + for (;;) + { + //nb: allows matching start and end points when not Closed ... 
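+ //nb: a single pass around the ring removes duplicate vertices and, for
+ //closed paths, merges adjacent collinear edges (unless PreserveCollinear
+ //keeps the non-overlapping ones).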
+ if (E->Curr == E->Next->Curr && (Closed || E->Next != eStart)) + { + if (E == E->Next) break; + if (E == eStart) eStart = E->Next; + E = RemoveEdge(E); + eLoopStop = E; + continue; + } + if (E->Prev == E->Next) + break; //only two vertices + else if (Closed && + SlopesEqual(E->Prev->Curr, E->Curr, E->Next->Curr, m_UseFullRange) && + (!m_PreserveCollinear || + !Pt2IsBetweenPt1AndPt3(E->Prev->Curr, E->Curr, E->Next->Curr))) + { + //Collinear edges are allowed for open paths but in closed paths + //the default is to merge adjacent collinear edges into a single edge. + //However, if the PreserveCollinear property is enabled, only overlapping + //collinear edges (ie spikes) will be removed from closed paths. + if (E == eStart) eStart = E->Next; + E = RemoveEdge(E); + E = E->Prev; + eLoopStop = E; + continue; + } + E = E->Next; + if ((E == eLoopStop) || (!Closed && E->Next == eStart)) break; + } + + if ((!Closed && (E == E->Next)) || (Closed && (E->Prev == E->Next))) + { + delete [] edges; + return false; + } + + if (!Closed) + { + m_HasOpenPaths = true; + eStart->Prev->OutIdx = Skip; + } + + //3. Do second stage of edge initialization ... + E = eStart; + do + { + InitEdge2(*E, PolyTyp); + E = E->Next; + if (IsFlat && E->Curr.Y != eStart->Curr.Y) IsFlat = false; + } + while (E != eStart); + + //4. Finally, add edge bounds to LocalMinima list ... + + //Totally flat paths must be handled differently when adding them + //to LocalMinima list to avoid endless loops etc ... + if (IsFlat) + { + if (Closed) + { + delete [] edges; + return false; + } + E->Prev->OutIdx = Skip; + MinimaList::value_type locMin; + locMin.Y = E->Bot.Y; + locMin.LeftBound = 0; + locMin.RightBound = E; + locMin.RightBound->Side = esRight; + locMin.RightBound->WindDelta = 0; + for (;;) + { + if (E->Bot.X != E->Prev->Top.X) ReverseHorizontal(*E); + if (E->Next->OutIdx == Skip) break; + E->NextInLML = E->Next; + E = E->Next; + } + m_MinimaList.push_back(locMin); + m_edges.push_back(edges); + return true; + } + + m_edges.push_back(edges); + bool leftBoundIsForward; + TEdge* EMin = 0; + + //workaround to avoid an endless loop in the while loop below when + //open paths have matching start and end points ... + if (E->Prev->Bot == E->Prev->Top) E = E->Next; + + for (;;) + { + E = FindNextLocMin(E); + if (E == EMin) break; + else if (!EMin) EMin = E; + + //E and E.Prev now share a local minima (left aligned if horizontal). + //Compare their slopes to find which starts which bound ... 
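+ //nb: both edges ascend from the shared bottom vertex; the Dx comparison
+ //below picks which one forms the left bound, and the two bounds are then
+ //given opposite WindDelta values.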
+ MinimaList::value_type locMin; + locMin.Y = E->Bot.Y; + if (E->Dx < E->Prev->Dx) + { + locMin.LeftBound = E->Prev; + locMin.RightBound = E; + leftBoundIsForward = false; //Q.nextInLML = Q.prev + } else + { + locMin.LeftBound = E; + locMin.RightBound = E->Prev; + leftBoundIsForward = true; //Q.nextInLML = Q.next + } + + if (!Closed) locMin.LeftBound->WindDelta = 0; + else if (locMin.LeftBound->Next == locMin.RightBound) + locMin.LeftBound->WindDelta = -1; + else locMin.LeftBound->WindDelta = 1; + locMin.RightBound->WindDelta = -locMin.LeftBound->WindDelta; + + E = ProcessBound(locMin.LeftBound, leftBoundIsForward); + if (E->OutIdx == Skip) E = ProcessBound(E, leftBoundIsForward); + + TEdge* E2 = ProcessBound(locMin.RightBound, !leftBoundIsForward); + if (E2->OutIdx == Skip) E2 = ProcessBound(E2, !leftBoundIsForward); + + if (locMin.LeftBound->OutIdx == Skip) + locMin.LeftBound = 0; + else if (locMin.RightBound->OutIdx == Skip) + locMin.RightBound = 0; + m_MinimaList.push_back(locMin); + if (!leftBoundIsForward) E = E2; + } + return true; +} +//------------------------------------------------------------------------------ + +bool ClipperBase::AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed) +{ + bool result = false; + for (Paths::size_type i = 0; i < ppg.size(); ++i) + if (AddPath(ppg[i], PolyTyp, Closed)) result = true; + return result; +} +//------------------------------------------------------------------------------ + +void ClipperBase::Clear() +{ + DisposeLocalMinimaList(); + for (EdgeList::size_type i = 0; i < m_edges.size(); ++i) + { + TEdge* edges = m_edges[i]; + delete [] edges; + } + m_edges.clear(); + m_UseFullRange = false; + m_HasOpenPaths = false; +} +//------------------------------------------------------------------------------ + +void ClipperBase::Reset() +{ + m_CurrentLM = m_MinimaList.begin(); + if (m_CurrentLM == m_MinimaList.end()) return; //ie nothing to process + std::sort(m_MinimaList.begin(), m_MinimaList.end(), LocMinSorter()); + + m_Scanbeam = ScanbeamList(); //clears/resets priority_queue + //reset all edges ... 
+ for (MinimaList::iterator lm = m_MinimaList.begin(); lm != m_MinimaList.end(); ++lm) + { + InsertScanbeam(lm->Y); + TEdge* e = lm->LeftBound; + if (e) + { + e->Curr = e->Bot; + e->Side = esLeft; + e->OutIdx = Unassigned; + } + + e = lm->RightBound; + if (e) + { + e->Curr = e->Bot; + e->Side = esRight; + e->OutIdx = Unassigned; + } + } + m_ActiveEdges = 0; + m_CurrentLM = m_MinimaList.begin(); +} +//------------------------------------------------------------------------------ + +void ClipperBase::DisposeLocalMinimaList() +{ + m_MinimaList.clear(); + m_CurrentLM = m_MinimaList.begin(); +} +//------------------------------------------------------------------------------ + +bool ClipperBase::PopLocalMinima(cInt Y, const LocalMinimum *&locMin) +{ + if (m_CurrentLM == m_MinimaList.end() || (*m_CurrentLM).Y != Y) return false; + locMin = &(*m_CurrentLM); + ++m_CurrentLM; + return true; +} +//------------------------------------------------------------------------------ + +IntRect ClipperBase::GetBounds() +{ + IntRect result; + MinimaList::iterator lm = m_MinimaList.begin(); + if (lm == m_MinimaList.end()) + { + result.left = result.top = result.right = result.bottom = 0; + return result; + } + result.left = lm->LeftBound->Bot.X; + result.top = lm->LeftBound->Bot.Y; + result.right = lm->LeftBound->Bot.X; + result.bottom = lm->LeftBound->Bot.Y; + while (lm != m_MinimaList.end()) + { + //todo - needs fixing for open paths + result.bottom = std::max(result.bottom, lm->LeftBound->Bot.Y); + TEdge* e = lm->LeftBound; + for (;;) { + TEdge* bottomE = e; + while (e->NextInLML) + { + if (e->Bot.X < result.left) result.left = e->Bot.X; + if (e->Bot.X > result.right) result.right = e->Bot.X; + e = e->NextInLML; + } + result.left = std::min(result.left, e->Bot.X); + result.right = std::max(result.right, e->Bot.X); + result.left = std::min(result.left, e->Top.X); + result.right = std::max(result.right, e->Top.X); + result.top = std::min(result.top, e->Top.Y); + if (bottomE == lm->LeftBound) e = lm->RightBound; + else break; + } + ++lm; + } + return result; +} +//------------------------------------------------------------------------------ + +void ClipperBase::InsertScanbeam(const cInt Y) +{ + m_Scanbeam.push(Y); +} +//------------------------------------------------------------------------------ + +bool ClipperBase::PopScanbeam(cInt &Y) +{ + if (m_Scanbeam.empty()) return false; + Y = m_Scanbeam.top(); + m_Scanbeam.pop(); + while (!m_Scanbeam.empty() && Y == m_Scanbeam.top()) { m_Scanbeam.pop(); } // Pop duplicates. 
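+ // (aside, not part of upstream Clipper: the same Y can be pushed once
+ // per local minimum and once per edge top ending there, so equal keys
+ // are common; collapsing them ensures each scanline is processed once.)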
+ return true; +} +//------------------------------------------------------------------------------ + +void ClipperBase::DisposeAllOutRecs(){ + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) + DisposeOutRec(i); + m_PolyOuts.clear(); +} +//------------------------------------------------------------------------------ + +void ClipperBase::DisposeOutRec(PolyOutList::size_type index) +{ + OutRec *outRec = m_PolyOuts[index]; + if (outRec->Pts) DisposeOutPts(outRec->Pts); + delete outRec; + m_PolyOuts[index] = 0; +} +//------------------------------------------------------------------------------ + +void ClipperBase::DeleteFromAEL(TEdge *e) +{ + TEdge* AelPrev = e->PrevInAEL; + TEdge* AelNext = e->NextInAEL; + if (!AelPrev && !AelNext && (e != m_ActiveEdges)) return; //already deleted + if (AelPrev) AelPrev->NextInAEL = AelNext; + else m_ActiveEdges = AelNext; + if (AelNext) AelNext->PrevInAEL = AelPrev; + e->NextInAEL = 0; + e->PrevInAEL = 0; +} +//------------------------------------------------------------------------------ + +OutRec* ClipperBase::CreateOutRec() +{ + OutRec* result = new OutRec; + result->IsHole = false; + result->IsOpen = false; + result->FirstLeft = 0; + result->Pts = 0; + result->BottomPt = 0; + result->PolyNd = 0; + m_PolyOuts.push_back(result); + result->Idx = (int)m_PolyOuts.size() - 1; + return result; +} +//------------------------------------------------------------------------------ + +void ClipperBase::SwapPositionsInAEL(TEdge *Edge1, TEdge *Edge2) +{ + //check that one or other edge hasn't already been removed from AEL ... + if (Edge1->NextInAEL == Edge1->PrevInAEL || + Edge2->NextInAEL == Edge2->PrevInAEL) return; + + if (Edge1->NextInAEL == Edge2) + { + TEdge* Next = Edge2->NextInAEL; + if (Next) Next->PrevInAEL = Edge1; + TEdge* Prev = Edge1->PrevInAEL; + if (Prev) Prev->NextInAEL = Edge2; + Edge2->PrevInAEL = Prev; + Edge2->NextInAEL = Edge1; + Edge1->PrevInAEL = Edge2; + Edge1->NextInAEL = Next; + } + else if (Edge2->NextInAEL == Edge1) + { + TEdge* Next = Edge1->NextInAEL; + if (Next) Next->PrevInAEL = Edge2; + TEdge* Prev = Edge2->PrevInAEL; + if (Prev) Prev->NextInAEL = Edge1; + Edge1->PrevInAEL = Prev; + Edge1->NextInAEL = Edge2; + Edge2->PrevInAEL = Edge1; + Edge2->NextInAEL = Next; + } + else + { + TEdge* Next = Edge1->NextInAEL; + TEdge* Prev = Edge1->PrevInAEL; + Edge1->NextInAEL = Edge2->NextInAEL; + if (Edge1->NextInAEL) Edge1->NextInAEL->PrevInAEL = Edge1; + Edge1->PrevInAEL = Edge2->PrevInAEL; + if (Edge1->PrevInAEL) Edge1->PrevInAEL->NextInAEL = Edge1; + Edge2->NextInAEL = Next; + if (Edge2->NextInAEL) Edge2->NextInAEL->PrevInAEL = Edge2; + Edge2->PrevInAEL = Prev; + if (Edge2->PrevInAEL) Edge2->PrevInAEL->NextInAEL = Edge2; + } + + if (!Edge1->PrevInAEL) m_ActiveEdges = Edge1; + else if (!Edge2->PrevInAEL) m_ActiveEdges = Edge2; +} +//------------------------------------------------------------------------------ + +void ClipperBase::UpdateEdgeIntoAEL(TEdge *&e) +{ + if (!e->NextInLML) + throw clipperException("UpdateEdgeIntoAEL: invalid call"); + + e->NextInLML->OutIdx = e->OutIdx; + TEdge* AelPrev = e->PrevInAEL; + TEdge* AelNext = e->NextInAEL; + if (AelPrev) AelPrev->NextInAEL = e->NextInLML; + else m_ActiveEdges = e->NextInLML; + if (AelNext) AelNext->PrevInAEL = e->NextInLML; + e->NextInLML->Side = e->Side; + e->NextInLML->WindDelta = e->WindDelta; + e->NextInLML->WindCnt = e->WindCnt; + e->NextInLML->WindCnt2 = e->WindCnt2; + e = e->NextInLML; + e->Curr = e->Bot; + e->PrevInAEL = AelPrev; + e->NextInAEL = AelNext; + if 
(!IsHorizontal(*e)) InsertScanbeam(e->Top.Y); +} +//------------------------------------------------------------------------------ + +bool ClipperBase::LocalMinimaPending() +{ + return (m_CurrentLM != m_MinimaList.end()); +} + +//------------------------------------------------------------------------------ +// TClipper methods ... +//------------------------------------------------------------------------------ + +Clipper::Clipper(int initOptions) : ClipperBase() //constructor +{ + m_ExecuteLocked = false; + m_UseFullRange = false; + m_ReverseOutput = ((initOptions & ioReverseSolution) != 0); + m_StrictSimple = ((initOptions & ioStrictlySimple) != 0); + m_PreserveCollinear = ((initOptions & ioPreserveCollinear) != 0); + m_HasOpenPaths = false; +#ifdef use_xyz + m_ZFill = 0; +#endif +} +//------------------------------------------------------------------------------ + +#ifdef use_xyz +void Clipper::ZFillFunction(ZFillCallback zFillFunc) +{ + m_ZFill = zFillFunc; +} +//------------------------------------------------------------------------------ +#endif + +bool Clipper::Execute(ClipType clipType, Paths &solution, PolyFillType fillType) +{ + return Execute(clipType, solution, fillType, fillType); +} +//------------------------------------------------------------------------------ + +bool Clipper::Execute(ClipType clipType, PolyTree &polytree, PolyFillType fillType) +{ + return Execute(clipType, polytree, fillType, fillType); +} +//------------------------------------------------------------------------------ + +bool Clipper::Execute(ClipType clipType, Paths &solution, + PolyFillType subjFillType, PolyFillType clipFillType) +{ + if( m_ExecuteLocked ) return false; + if (m_HasOpenPaths) + throw clipperException("Error: PolyTree struct is needed for open path clipping."); + m_ExecuteLocked = true; + solution.resize(0); + m_SubjFillType = subjFillType; + m_ClipFillType = clipFillType; + m_ClipType = clipType; + m_UsingPolyTree = false; + bool succeeded = ExecuteInternal(); + if (succeeded) BuildResult(solution); + DisposeAllOutRecs(); + m_ExecuteLocked = false; + return succeeded; +} +//------------------------------------------------------------------------------ + +bool Clipper::Execute(ClipType clipType, PolyTree& polytree, + PolyFillType subjFillType, PolyFillType clipFillType) +{ + if( m_ExecuteLocked ) return false; + m_ExecuteLocked = true; + m_SubjFillType = subjFillType; + m_ClipFillType = clipFillType; + m_ClipType = clipType; + m_UsingPolyTree = true; + bool succeeded = ExecuteInternal(); + if (succeeded) BuildResult2(polytree); + DisposeAllOutRecs(); + m_ExecuteLocked = false; + return succeeded; +} +//------------------------------------------------------------------------------ + +void Clipper::FixHoleLinkage(OutRec &outrec) +{ + //skip OutRecs that (a) contain outermost polygons or + //(b) already have the correct owner/child linkage ... 
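+ //(illustrative aside, not part of upstream Clipper: FirstLeft pointers
+ // form an ownership chain - a hole's FirstLeft should be its containing
+ // outer polygon. The walk below climbs that chain until it reaches an
+ // OutRec with the opposite hole state that still owns points.)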
+ if (!outrec.FirstLeft ||
+ (outrec.IsHole != outrec.FirstLeft->IsHole &&
+ outrec.FirstLeft->Pts)) return;
+
+ OutRec* orfl = outrec.FirstLeft;
+ while (orfl && ((orfl->IsHole == outrec.IsHole) || !orfl->Pts))
+ orfl = orfl->FirstLeft;
+ outrec.FirstLeft = orfl;
+}
+//------------------------------------------------------------------------------
+
+bool Clipper::ExecuteInternal()
+{
+ bool succeeded = true;
+ try {
+ Reset();
+ m_Maxima = MaximaList();
+ m_SortedEdges = 0;
+
+ succeeded = true;
+ cInt botY, topY;
+ if (!PopScanbeam(botY)) return false;
+ InsertLocalMinimaIntoAEL(botY);
+ while (PopScanbeam(topY) || LocalMinimaPending())
+ {
+ ProcessHorizontals();
+ ClearGhostJoins();
+ if (!ProcessIntersections(topY))
+ {
+ succeeded = false;
+ break;
+ }
+ ProcessEdgesAtTopOfScanbeam(topY);
+ botY = topY;
+ InsertLocalMinimaIntoAEL(botY);
+ }
+ }
+ catch(...)
+ {
+ succeeded = false;
+ }
+
+ if (succeeded)
+ {
+ //fix orientations ...
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i)
+ {
+ OutRec *outRec = m_PolyOuts[i];
+ if (!outRec->Pts || outRec->IsOpen) continue;
+ if ((outRec->IsHole ^ m_ReverseOutput) == (Area(*outRec) > 0))
+ ReversePolyPtLinks(outRec->Pts);
+ }
+
+ if (!m_Joins.empty()) JoinCommonEdges();
+
+ //unfortunately FixupOutPolygon() must be done after JoinCommonEdges()
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i)
+ {
+ OutRec *outRec = m_PolyOuts[i];
+ if (!outRec->Pts) continue;
+ if (outRec->IsOpen)
+ FixupOutPolyline(*outRec);
+ else
+ FixupOutPolygon(*outRec);
+ }
+
+ if (m_StrictSimple) DoSimplePolygons();
+ }
+
+ ClearJoins();
+ ClearGhostJoins();
+ return succeeded;
+}
+//------------------------------------------------------------------------------
+
+void Clipper::SetWindingCount(TEdge &edge)
+{
+ TEdge *e = edge.PrevInAEL;
+ //find the edge of the same polytype that immediately precedes 'edge' in AEL
+ while (e && ((e->PolyTyp != edge.PolyTyp) || (e->WindDelta == 0))) e = e->PrevInAEL;
+ if (!e)
+ {
+ if (edge.WindDelta == 0)
+ {
+ PolyFillType pft = (edge.PolyTyp == ptSubject ? m_SubjFillType : m_ClipFillType);
+ edge.WindCnt = (pft == pftNegative ? -1 : 1);
+ }
+ else
+ edge.WindCnt = edge.WindDelta;
+ edge.WindCnt2 = 0;
+ e = m_ActiveEdges; //ie get ready to calc WindCnt2
+ }
+ else if (edge.WindDelta == 0 && m_ClipType != ctUnion)
+ {
+ edge.WindCnt = 1;
+ edge.WindCnt2 = e->WindCnt2;
+ e = e->NextInAEL; //ie get ready to calc WindCnt2
+ }
+ else if (IsEvenOddFillType(edge))
+ {
+ //EvenOdd filling ...
+ if (edge.WindDelta == 0)
+ {
+ //are we inside a subj polygon ...
+ bool Inside = true;
+ TEdge *e2 = e->PrevInAEL;
+ while (e2)
+ {
+ if (e2->PolyTyp == e->PolyTyp && e2->WindDelta != 0)
+ Inside = !Inside;
+ e2 = e2->PrevInAEL;
+ }
+ edge.WindCnt = (Inside ? 0 : 1);
+ }
+ else
+ {
+ edge.WindCnt = edge.WindDelta;
+ }
+ edge.WindCnt2 = e->WindCnt2;
+ e = e->NextInAEL; //ie get ready to calc WindCnt2
+ }
+ else
+ {
+ //nonZero, Positive or Negative filling ...
+ if (e->WindCnt * e->WindDelta < 0)
+ {
+ //prev edge is 'decreasing' WindCount (WC) toward zero
+ //so we're outside the previous polygon ...
+ if (Abs(e->WindCnt) > 1)
+ {
+ //outside prev poly but still inside another.
+ //when reversing direction of prev poly use the same WC
+ if (e->WindDelta * edge.WindDelta < 0) edge.WindCnt = e->WindCnt;
+ //otherwise continue to 'decrease' WC ...
+ else edge.WindCnt = e->WindCnt + edge.WindDelta;
+ }
+ else
+ //now outside all polys of same polytype so set own WC ...
+ edge.WindCnt = (edge.WindDelta == 0 ?
1 : edge.WindDelta); + } else + { + //prev edge is 'increasing' WindCount (WC) away from zero + //so we're inside the previous polygon ... + if (edge.WindDelta == 0) + edge.WindCnt = (e->WindCnt < 0 ? e->WindCnt - 1 : e->WindCnt + 1); + //if wind direction is reversing prev then use same WC + else if (e->WindDelta * edge.WindDelta < 0) edge.WindCnt = e->WindCnt; + //otherwise add to WC ... + else edge.WindCnt = e->WindCnt + edge.WindDelta; + } + edge.WindCnt2 = e->WindCnt2; + e = e->NextInAEL; //ie get ready to calc WindCnt2 + } + + //update WindCnt2 ... + if (IsEvenOddAltFillType(edge)) + { + //EvenOdd filling ... + while (e != &edge) + { + if (e->WindDelta != 0) + edge.WindCnt2 = (edge.WindCnt2 == 0 ? 1 : 0); + e = e->NextInAEL; + } + } else + { + //nonZero, Positive or Negative filling ... + while ( e != &edge ) + { + edge.WindCnt2 += e->WindDelta; + e = e->NextInAEL; + } + } +} +//------------------------------------------------------------------------------ + +bool Clipper::IsEvenOddFillType(const TEdge& edge) const +{ + if (edge.PolyTyp == ptSubject) + return m_SubjFillType == pftEvenOdd; else + return m_ClipFillType == pftEvenOdd; +} +//------------------------------------------------------------------------------ + +bool Clipper::IsEvenOddAltFillType(const TEdge& edge) const +{ + if (edge.PolyTyp == ptSubject) + return m_ClipFillType == pftEvenOdd; else + return m_SubjFillType == pftEvenOdd; +} +//------------------------------------------------------------------------------ + +bool Clipper::IsContributing(const TEdge& edge) const +{ + PolyFillType pft, pft2; + if (edge.PolyTyp == ptSubject) + { + pft = m_SubjFillType; + pft2 = m_ClipFillType; + } else + { + pft = m_ClipFillType; + pft2 = m_SubjFillType; + } + + switch(pft) + { + case pftEvenOdd: + //return false if a subj line has been flagged as inside a subj polygon + if (edge.WindDelta == 0 && edge.WindCnt != 1) return false; + break; + case pftNonZero: + if (Abs(edge.WindCnt) != 1) return false; + break; + case pftPositive: + if (edge.WindCnt != 1) return false; + break; + default: //pftNegative + if (edge.WindCnt != -1) return false; + } + + switch(m_ClipType) + { + case ctIntersection: + switch(pft2) + { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 != 0); + case pftPositive: + return (edge.WindCnt2 > 0); + default: + return (edge.WindCnt2 < 0); + } + break; + case ctUnion: + switch(pft2) + { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 == 0); + case pftPositive: + return (edge.WindCnt2 <= 0); + default: + return (edge.WindCnt2 >= 0); + } + break; + case ctDifference: + if (edge.PolyTyp == ptSubject) + switch(pft2) + { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 == 0); + case pftPositive: + return (edge.WindCnt2 <= 0); + default: + return (edge.WindCnt2 >= 0); + } + else + switch(pft2) + { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 != 0); + case pftPositive: + return (edge.WindCnt2 > 0); + default: + return (edge.WindCnt2 < 0); + } + break; + case ctXor: + if (edge.WindDelta == 0) //XOr always contributing unless open + switch(pft2) + { + case pftEvenOdd: + case pftNonZero: + return (edge.WindCnt2 == 0); + case pftPositive: + return (edge.WindCnt2 <= 0); + default: + return (edge.WindCnt2 >= 0); + } + else + return true; + break; + default: + return true; + } +} +//------------------------------------------------------------------------------ + +OutPt* Clipper::AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &Pt) +{ + OutPt* result; + TEdge *e, 
*prevE; + if (IsHorizontal(*e2) || ( e1->Dx > e2->Dx )) + { + result = AddOutPt(e1, Pt); + e2->OutIdx = e1->OutIdx; + e1->Side = esLeft; + e2->Side = esRight; + e = e1; + if (e->PrevInAEL == e2) + prevE = e2->PrevInAEL; + else + prevE = e->PrevInAEL; + } else + { + result = AddOutPt(e2, Pt); + e1->OutIdx = e2->OutIdx; + e1->Side = esRight; + e2->Side = esLeft; + e = e2; + if (e->PrevInAEL == e1) + prevE = e1->PrevInAEL; + else + prevE = e->PrevInAEL; + } + + if (prevE && prevE->OutIdx >= 0) + { + cInt xPrev = TopX(*prevE, Pt.Y); + cInt xE = TopX(*e, Pt.Y); + if (xPrev == xE && (e->WindDelta != 0) && (prevE->WindDelta != 0) && + SlopesEqual(IntPoint(xPrev, Pt.Y), prevE->Top, IntPoint(xE, Pt.Y), e->Top, m_UseFullRange)) + { + OutPt* outPt = AddOutPt(prevE, Pt); + AddJoin(result, outPt, e->Top); + } + } + return result; +} +//------------------------------------------------------------------------------ + +void Clipper::AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &Pt) +{ + AddOutPt( e1, Pt ); + if (e2->WindDelta == 0) AddOutPt(e2, Pt); + if( e1->OutIdx == e2->OutIdx ) + { + e1->OutIdx = Unassigned; + e2->OutIdx = Unassigned; + } + else if (e1->OutIdx < e2->OutIdx) + AppendPolygon(e1, e2); + else + AppendPolygon(e2, e1); +} +//------------------------------------------------------------------------------ + +void Clipper::AddEdgeToSEL(TEdge *edge) +{ + //SEL pointers in PEdge are reused to build a list of horizontal edges. + //However, we don't need to worry about order with horizontal edge processing. + if( !m_SortedEdges ) + { + m_SortedEdges = edge; + edge->PrevInSEL = 0; + edge->NextInSEL = 0; + } + else + { + edge->NextInSEL = m_SortedEdges; + edge->PrevInSEL = 0; + m_SortedEdges->PrevInSEL = edge; + m_SortedEdges = edge; + } +} +//------------------------------------------------------------------------------ + +bool Clipper::PopEdgeFromSEL(TEdge *&edge) +{ + if (!m_SortedEdges) return false; + edge = m_SortedEdges; + DeleteFromSEL(m_SortedEdges); + return true; +} +//------------------------------------------------------------------------------ + +void Clipper::CopyAELToSEL() +{ + TEdge* e = m_ActiveEdges; + m_SortedEdges = e; + while ( e ) + { + e->PrevInSEL = e->PrevInAEL; + e->NextInSEL = e->NextInAEL; + e = e->NextInAEL; + } +} +//------------------------------------------------------------------------------ + +void Clipper::AddJoin(OutPt *op1, OutPt *op2, const IntPoint OffPt) +{ + Join* j = new Join; + j->OutPt1 = op1; + j->OutPt2 = op2; + j->OffPt = OffPt; + m_Joins.push_back(j); +} +//------------------------------------------------------------------------------ + +void Clipper::ClearJoins() +{ + for (JoinList::size_type i = 0; i < m_Joins.size(); i++) + delete m_Joins[i]; + m_Joins.resize(0); +} +//------------------------------------------------------------------------------ + +void Clipper::ClearGhostJoins() +{ + for (JoinList::size_type i = 0; i < m_GhostJoins.size(); i++) + delete m_GhostJoins[i]; + m_GhostJoins.resize(0); +} +//------------------------------------------------------------------------------ + +void Clipper::AddGhostJoin(OutPt *op, const IntPoint OffPt) +{ + Join* j = new Join; + j->OutPt1 = op; + j->OutPt2 = 0; + j->OffPt = OffPt; + m_GhostJoins.push_back(j); +} +//------------------------------------------------------------------------------ + +void Clipper::InsertLocalMinimaIntoAEL(const cInt botY) +{ + const LocalMinimum *lm; + while (PopLocalMinima(botY, lm)) + { + TEdge* lb = lm->LeftBound; + TEdge* rb = lm->RightBound; + + OutPt *Op1 = 0; + if 
(!lb) + { + //nb: don't insert LB into either AEL or SEL + InsertEdgeIntoAEL(rb, 0); + SetWindingCount(*rb); + if (IsContributing(*rb)) + Op1 = AddOutPt(rb, rb->Bot); + } + else if (!rb) + { + InsertEdgeIntoAEL(lb, 0); + SetWindingCount(*lb); + if (IsContributing(*lb)) + Op1 = AddOutPt(lb, lb->Bot); + InsertScanbeam(lb->Top.Y); + } + else + { + InsertEdgeIntoAEL(lb, 0); + InsertEdgeIntoAEL(rb, lb); + SetWindingCount( *lb ); + rb->WindCnt = lb->WindCnt; + rb->WindCnt2 = lb->WindCnt2; + if (IsContributing(*lb)) + Op1 = AddLocalMinPoly(lb, rb, lb->Bot); + InsertScanbeam(lb->Top.Y); + } + + if (rb) + { + if (IsHorizontal(*rb)) + { + AddEdgeToSEL(rb); + if (rb->NextInLML) + InsertScanbeam(rb->NextInLML->Top.Y); + } + else InsertScanbeam( rb->Top.Y ); + } + + if (!lb || !rb) continue; + + //if any output polygons share an edge, they'll need joining later ... + if (Op1 && IsHorizontal(*rb) && + m_GhostJoins.size() > 0 && (rb->WindDelta != 0)) + { + for (JoinList::size_type i = 0; i < m_GhostJoins.size(); ++i) + { + Join* jr = m_GhostJoins[i]; + //if the horizontal Rb and a 'ghost' horizontal overlap, then convert + //the 'ghost' join to a real join ready for later ... + if (HorzSegmentsOverlap(jr->OutPt1->Pt.X, jr->OffPt.X, rb->Bot.X, rb->Top.X)) + AddJoin(jr->OutPt1, Op1, jr->OffPt); + } + } + + if (lb->OutIdx >= 0 && lb->PrevInAEL && + lb->PrevInAEL->Curr.X == lb->Bot.X && + lb->PrevInAEL->OutIdx >= 0 && + SlopesEqual(lb->PrevInAEL->Bot, lb->PrevInAEL->Top, lb->Curr, lb->Top, m_UseFullRange) && + (lb->WindDelta != 0) && (lb->PrevInAEL->WindDelta != 0)) + { + OutPt *Op2 = AddOutPt(lb->PrevInAEL, lb->Bot); + AddJoin(Op1, Op2, lb->Top); + } + + if(lb->NextInAEL != rb) + { + + if (rb->OutIdx >= 0 && rb->PrevInAEL->OutIdx >= 0 && + SlopesEqual(rb->PrevInAEL->Curr, rb->PrevInAEL->Top, rb->Curr, rb->Top, m_UseFullRange) && + (rb->WindDelta != 0) && (rb->PrevInAEL->WindDelta != 0)) + { + OutPt *Op2 = AddOutPt(rb->PrevInAEL, rb->Bot); + AddJoin(Op1, Op2, rb->Top); + } + + TEdge* e = lb->NextInAEL; + if (e) + { + while( e != rb ) + { + //nb: For calculating winding counts etc, IntersectEdges() assumes + //that param1 will be to the Right of param2 ABOVE the intersection ... + IntersectEdges(rb , e , lb->Curr); //order important here + e = e->NextInAEL; + } + } + } + + } +} +//------------------------------------------------------------------------------ + +void Clipper::DeleteFromSEL(TEdge *e) +{ + TEdge* SelPrev = e->PrevInSEL; + TEdge* SelNext = e->NextInSEL; + if( !SelPrev && !SelNext && (e != m_SortedEdges) ) return; //already deleted + if( SelPrev ) SelPrev->NextInSEL = SelNext; + else m_SortedEdges = SelNext; + if( SelNext ) SelNext->PrevInSEL = SelPrev; + e->NextInSEL = 0; + e->PrevInSEL = 0; +} +//------------------------------------------------------------------------------ + +#ifdef use_xyz +void Clipper::SetZ(IntPoint& pt, TEdge& e1, TEdge& e2) +{ + if (pt.Z != 0 || !m_ZFill) return; + else if (pt == e1.Bot) pt.Z = e1.Bot.Z; + else if (pt == e1.Top) pt.Z = e1.Top.Z; + else if (pt == e2.Bot) pt.Z = e2.Bot.Z; + else if (pt == e2.Top) pt.Z = e2.Top.Z; + else (*m_ZFill)(e1.Bot, e1.Top, e2.Bot, e2.Top, pt); +} +//------------------------------------------------------------------------------ +#endif + +void Clipper::IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &Pt) +{ + bool e1Contributing = ( e1->OutIdx >= 0 ); + bool e2Contributing = ( e2->OutIdx >= 0 ); + +#ifdef use_xyz + SetZ(Pt, *e1, *e2); +#endif + +#ifdef use_lines + //if either edge is on an OPEN path ... 
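+ //(aside, not part of upstream Clipper: open 'line' edges carry
+ // WindDelta == 0, so they never alter winding counts; the cases below
+ // only decide whether the line contributes an output vertex here.)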
+ if (e1->WindDelta == 0 || e2->WindDelta == 0) + { + //ignore subject-subject open path intersections UNLESS they + //are both open paths, AND they are both 'contributing maximas' ... + if (e1->WindDelta == 0 && e2->WindDelta == 0) return; + + //if intersecting a subj line with a subj poly ... + else if (e1->PolyTyp == e2->PolyTyp && + e1->WindDelta != e2->WindDelta && m_ClipType == ctUnion) + { + if (e1->WindDelta == 0) + { + if (e2Contributing) + { + AddOutPt(e1, Pt); + if (e1Contributing) e1->OutIdx = Unassigned; + } + } + else + { + if (e1Contributing) + { + AddOutPt(e2, Pt); + if (e2Contributing) e2->OutIdx = Unassigned; + } + } + } + else if (e1->PolyTyp != e2->PolyTyp) + { + //toggle subj open path OutIdx on/off when Abs(clip.WndCnt) == 1 ... + if ((e1->WindDelta == 0) && abs(e2->WindCnt) == 1 && + (m_ClipType != ctUnion || e2->WindCnt2 == 0)) + { + AddOutPt(e1, Pt); + if (e1Contributing) e1->OutIdx = Unassigned; + } + else if ((e2->WindDelta == 0) && (abs(e1->WindCnt) == 1) && + (m_ClipType != ctUnion || e1->WindCnt2 == 0)) + { + AddOutPt(e2, Pt); + if (e2Contributing) e2->OutIdx = Unassigned; + } + } + return; + } +#endif + + //update winding counts... + //assumes that e1 will be to the Right of e2 ABOVE the intersection + if ( e1->PolyTyp == e2->PolyTyp ) + { + if ( IsEvenOddFillType( *e1) ) + { + int oldE1WindCnt = e1->WindCnt; + e1->WindCnt = e2->WindCnt; + e2->WindCnt = oldE1WindCnt; + } else + { + if (e1->WindCnt + e2->WindDelta == 0 ) e1->WindCnt = -e1->WindCnt; + else e1->WindCnt += e2->WindDelta; + if ( e2->WindCnt - e1->WindDelta == 0 ) e2->WindCnt = -e2->WindCnt; + else e2->WindCnt -= e1->WindDelta; + } + } else + { + if (!IsEvenOddFillType(*e2)) e1->WindCnt2 += e2->WindDelta; + else e1->WindCnt2 = ( e1->WindCnt2 == 0 ) ? 1 : 0; + if (!IsEvenOddFillType(*e1)) e2->WindCnt2 -= e1->WindDelta; + else e2->WindCnt2 = ( e2->WindCnt2 == 0 ) ? 1 : 0; + } + + PolyFillType e1FillType, e2FillType, e1FillType2, e2FillType2; + if (e1->PolyTyp == ptSubject) + { + e1FillType = m_SubjFillType; + e1FillType2 = m_ClipFillType; + } else + { + e1FillType = m_ClipFillType; + e1FillType2 = m_SubjFillType; + } + if (e2->PolyTyp == ptSubject) + { + e2FillType = m_SubjFillType; + e2FillType2 = m_ClipFillType; + } else + { + e2FillType = m_ClipFillType; + e2FillType2 = m_SubjFillType; + } + + cInt e1Wc, e2Wc; + switch (e1FillType) + { + case pftPositive: e1Wc = e1->WindCnt; break; + case pftNegative: e1Wc = -e1->WindCnt; break; + default: e1Wc = Abs(e1->WindCnt); + } + switch(e2FillType) + { + case pftPositive: e2Wc = e2->WindCnt; break; + case pftNegative: e2Wc = -e2->WindCnt; break; + default: e2Wc = Abs(e2->WindCnt); + } + + if ( e1Contributing && e2Contributing ) + { + if ((e1Wc != 0 && e1Wc != 1) || (e2Wc != 0 && e2Wc != 1) || + (e1->PolyTyp != e2->PolyTyp && m_ClipType != ctXor) ) + { + AddLocalMaxPoly(e1, e2, Pt); + } + else + { + AddOutPt(e1, Pt); + AddOutPt(e2, Pt); + SwapSides( *e1 , *e2 ); + SwapPolyIndexes( *e1 , *e2 ); + } + } + else if ( e1Contributing ) + { + if (e2Wc == 0 || e2Wc == 1) + { + AddOutPt(e1, Pt); + SwapSides(*e1, *e2); + SwapPolyIndexes(*e1, *e2); + } + } + else if ( e2Contributing ) + { + if (e1Wc == 0 || e1Wc == 1) + { + AddOutPt(e2, Pt); + SwapSides(*e1, *e2); + SwapPolyIndexes(*e1, *e2); + } + } + else if ( (e1Wc == 0 || e1Wc == 1) && (e2Wc == 0 || e2Wc == 1)) + { + //neither edge is currently contributing ... 
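+ //(worked example, not part of upstream Clipper: intersecting two closed
+ // nonZero polygons, a new output polygon can only start here when each
+ // edge lies inside a filled region of the OTHER polytype, ie when the
+ // opposite-type counts e1Wc2 and e2Wc2 computed below are both > 0.)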
+ + cInt e1Wc2, e2Wc2; + switch (e1FillType2) + { + case pftPositive: e1Wc2 = e1->WindCnt2; break; + case pftNegative : e1Wc2 = -e1->WindCnt2; break; + default: e1Wc2 = Abs(e1->WindCnt2); + } + switch (e2FillType2) + { + case pftPositive: e2Wc2 = e2->WindCnt2; break; + case pftNegative: e2Wc2 = -e2->WindCnt2; break; + default: e2Wc2 = Abs(e2->WindCnt2); + } + + if (e1->PolyTyp != e2->PolyTyp) + { + AddLocalMinPoly(e1, e2, Pt); + } + else if (e1Wc == 1 && e2Wc == 1) + switch( m_ClipType ) { + case ctIntersection: + if (e1Wc2 > 0 && e2Wc2 > 0) + AddLocalMinPoly(e1, e2, Pt); + break; + case ctUnion: + if ( e1Wc2 <= 0 && e2Wc2 <= 0 ) + AddLocalMinPoly(e1, e2, Pt); + break; + case ctDifference: + if (((e1->PolyTyp == ptClip) && (e1Wc2 > 0) && (e2Wc2 > 0)) || + ((e1->PolyTyp == ptSubject) && (e1Wc2 <= 0) && (e2Wc2 <= 0))) + AddLocalMinPoly(e1, e2, Pt); + break; + case ctXor: + AddLocalMinPoly(e1, e2, Pt); + } + else + SwapSides( *e1, *e2 ); + } +} +//------------------------------------------------------------------------------ + +void Clipper::SetHoleState(TEdge *e, OutRec *outrec) +{ + TEdge *e2 = e->PrevInAEL; + TEdge *eTmp = 0; + while (e2) + { + if (e2->OutIdx >= 0 && e2->WindDelta != 0) + { + if (!eTmp) eTmp = e2; + else if (eTmp->OutIdx == e2->OutIdx) eTmp = 0; + } + e2 = e2->PrevInAEL; + } + if (!eTmp) + { + outrec->FirstLeft = 0; + outrec->IsHole = false; + } + else + { + outrec->FirstLeft = m_PolyOuts[eTmp->OutIdx]; + outrec->IsHole = !outrec->FirstLeft->IsHole; + } +} +//------------------------------------------------------------------------------ + +OutRec* GetLowermostRec(OutRec *outRec1, OutRec *outRec2) +{ + //work out which polygon fragment has the correct hole state ... + if (!outRec1->BottomPt) + outRec1->BottomPt = GetBottomPt(outRec1->Pts); + if (!outRec2->BottomPt) + outRec2->BottomPt = GetBottomPt(outRec2->Pts); + OutPt *OutPt1 = outRec1->BottomPt; + OutPt *OutPt2 = outRec2->BottomPt; + if (OutPt1->Pt.Y > OutPt2->Pt.Y) return outRec1; + else if (OutPt1->Pt.Y < OutPt2->Pt.Y) return outRec2; + else if (OutPt1->Pt.X < OutPt2->Pt.X) return outRec1; + else if (OutPt1->Pt.X > OutPt2->Pt.X) return outRec2; + else if (OutPt1->Next == OutPt1) return outRec2; + else if (OutPt2->Next == OutPt2) return outRec1; + else if (FirstIsBottomPt(OutPt1, OutPt2)) return outRec1; + else return outRec2; +} +//------------------------------------------------------------------------------ + +bool OutRec1RightOfOutRec2(OutRec* outRec1, OutRec* outRec2) +{ + do + { + outRec1 = outRec1->FirstLeft; + if (outRec1 == outRec2) return true; + } while (outRec1); + return false; +} +//------------------------------------------------------------------------------ + +OutRec* Clipper::GetOutRec(int Idx) +{ + OutRec* outrec = m_PolyOuts[Idx]; + while (outrec != m_PolyOuts[outrec->Idx]) + outrec = m_PolyOuts[outrec->Idx]; + return outrec; +} +//------------------------------------------------------------------------------ + +void Clipper::AppendPolygon(TEdge *e1, TEdge *e2) +{ + //get the start and ends of both output polygons ... + OutRec *outRec1 = m_PolyOuts[e1->OutIdx]; + OutRec *outRec2 = m_PolyOuts[e2->OutIdx]; + + OutRec *holeStateRec; + if (OutRec1RightOfOutRec2(outRec1, outRec2)) + holeStateRec = outRec2; + else if (OutRec1RightOfOutRec2(outRec2, outRec1)) + holeStateRec = outRec1; + else + holeStateRec = GetLowermostRec(outRec1, outRec2); + + //get the start and ends of both output polygons and + //join e2 poly onto e1 poly and delete pointers to e2 ... 
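+ //(illustrative aside, not part of upstream Clipper: in the comments
+ // below, 'a b c' is outRec1's point list from Pts (left-most) to
+ // Pts->Prev (right-most) and 'x y z' is outRec2's; eg 'z y x a b c'
+ // means outRec2 is reversed and spliced in front of outRec1.)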
+ + OutPt* p1_lft = outRec1->Pts; + OutPt* p1_rt = p1_lft->Prev; + OutPt* p2_lft = outRec2->Pts; + OutPt* p2_rt = p2_lft->Prev; + + //join e2 poly onto e1 poly and delete pointers to e2 ... + if( e1->Side == esLeft ) + { + if( e2->Side == esLeft ) + { + //z y x a b c + ReversePolyPtLinks(p2_lft); + p2_lft->Next = p1_lft; + p1_lft->Prev = p2_lft; + p1_rt->Next = p2_rt; + p2_rt->Prev = p1_rt; + outRec1->Pts = p2_rt; + } else + { + //x y z a b c + p2_rt->Next = p1_lft; + p1_lft->Prev = p2_rt; + p2_lft->Prev = p1_rt; + p1_rt->Next = p2_lft; + outRec1->Pts = p2_lft; + } + } else + { + if( e2->Side == esRight ) + { + //a b c z y x + ReversePolyPtLinks(p2_lft); + p1_rt->Next = p2_rt; + p2_rt->Prev = p1_rt; + p2_lft->Next = p1_lft; + p1_lft->Prev = p2_lft; + } else + { + //a b c x y z + p1_rt->Next = p2_lft; + p2_lft->Prev = p1_rt; + p1_lft->Prev = p2_rt; + p2_rt->Next = p1_lft; + } + } + + outRec1->BottomPt = 0; + if (holeStateRec == outRec2) + { + if (outRec2->FirstLeft != outRec1) + outRec1->FirstLeft = outRec2->FirstLeft; + outRec1->IsHole = outRec2->IsHole; + } + outRec2->Pts = 0; + outRec2->BottomPt = 0; + outRec2->FirstLeft = outRec1; + + int OKIdx = e1->OutIdx; + int ObsoleteIdx = e2->OutIdx; + + e1->OutIdx = Unassigned; //nb: safe because we only get here via AddLocalMaxPoly + e2->OutIdx = Unassigned; + + TEdge* e = m_ActiveEdges; + while( e ) + { + if( e->OutIdx == ObsoleteIdx ) + { + e->OutIdx = OKIdx; + e->Side = e1->Side; + break; + } + e = e->NextInAEL; + } + + outRec2->Idx = outRec1->Idx; +} +//------------------------------------------------------------------------------ + +OutPt* Clipper::AddOutPt(TEdge *e, const IntPoint &pt) +{ + if( e->OutIdx < 0 ) + { + OutRec *outRec = CreateOutRec(); + outRec->IsOpen = (e->WindDelta == 0); + OutPt* newOp = new OutPt; + outRec->Pts = newOp; + newOp->Idx = outRec->Idx; + newOp->Pt = pt; + newOp->Next = newOp; + newOp->Prev = newOp; + if (!outRec->IsOpen) + SetHoleState(e, outRec); + e->OutIdx = outRec->Idx; + return newOp; + } else + { + OutRec *outRec = m_PolyOuts[e->OutIdx]; + //OutRec.Pts is the 'Left-most' point & OutRec.Pts.Prev is the 'Right-most' + OutPt* op = outRec->Pts; + + bool ToFront = (e->Side == esLeft); + if (ToFront && (pt == op->Pt)) return op; + else if (!ToFront && (pt == op->Prev->Pt)) return op->Prev; + + OutPt* newOp = new OutPt; + newOp->Idx = outRec->Idx; + newOp->Pt = pt; + newOp->Next = op; + newOp->Prev = op->Prev; + newOp->Prev->Next = newOp; + op->Prev = newOp; + if (ToFront) outRec->Pts = newOp; + return newOp; + } +} +//------------------------------------------------------------------------------ + +OutPt* Clipper::GetLastOutPt(TEdge *e) +{ + OutRec *outRec = m_PolyOuts[e->OutIdx]; + if (e->Side == esLeft) + return outRec->Pts; + else + return outRec->Pts->Prev; +} +//------------------------------------------------------------------------------ + +void Clipper::ProcessHorizontals() +{ + TEdge* horzEdge; + while (PopEdgeFromSEL(horzEdge)) + ProcessHorizontal(horzEdge); +} +//------------------------------------------------------------------------------ + +inline bool IsMinima(TEdge *e) +{ + return e && (e->Prev->NextInLML != e) && (e->Next->NextInLML != e); +} +//------------------------------------------------------------------------------ + +inline bool IsMaxima(TEdge *e, const cInt Y) +{ + return e && e->Top.Y == Y && !e->NextInLML; +} +//------------------------------------------------------------------------------ + +inline bool IsIntermediate(TEdge *e, const cInt Y) +{ + return e->Top.Y == Y && 
e->NextInLML; +} +//------------------------------------------------------------------------------ + +TEdge *GetMaximaPair(TEdge *e) +{ + if ((e->Next->Top == e->Top) && !e->Next->NextInLML) + return e->Next; + else if ((e->Prev->Top == e->Top) && !e->Prev->NextInLML) + return e->Prev; + else return 0; +} +//------------------------------------------------------------------------------ + +TEdge *GetMaximaPairEx(TEdge *e) +{ + //as GetMaximaPair() but returns 0 if MaxPair isn't in AEL (unless it's horizontal) + TEdge* result = GetMaximaPair(e); + if (result && (result->OutIdx == Skip || + (result->NextInAEL == result->PrevInAEL && !IsHorizontal(*result)))) return 0; + return result; +} +//------------------------------------------------------------------------------ + +void Clipper::SwapPositionsInSEL(TEdge *Edge1, TEdge *Edge2) +{ + if( !( Edge1->NextInSEL ) && !( Edge1->PrevInSEL ) ) return; + if( !( Edge2->NextInSEL ) && !( Edge2->PrevInSEL ) ) return; + + if( Edge1->NextInSEL == Edge2 ) + { + TEdge* Next = Edge2->NextInSEL; + if( Next ) Next->PrevInSEL = Edge1; + TEdge* Prev = Edge1->PrevInSEL; + if( Prev ) Prev->NextInSEL = Edge2; + Edge2->PrevInSEL = Prev; + Edge2->NextInSEL = Edge1; + Edge1->PrevInSEL = Edge2; + Edge1->NextInSEL = Next; + } + else if( Edge2->NextInSEL == Edge1 ) + { + TEdge* Next = Edge1->NextInSEL; + if( Next ) Next->PrevInSEL = Edge2; + TEdge* Prev = Edge2->PrevInSEL; + if( Prev ) Prev->NextInSEL = Edge1; + Edge1->PrevInSEL = Prev; + Edge1->NextInSEL = Edge2; + Edge2->PrevInSEL = Edge1; + Edge2->NextInSEL = Next; + } + else + { + TEdge* Next = Edge1->NextInSEL; + TEdge* Prev = Edge1->PrevInSEL; + Edge1->NextInSEL = Edge2->NextInSEL; + if( Edge1->NextInSEL ) Edge1->NextInSEL->PrevInSEL = Edge1; + Edge1->PrevInSEL = Edge2->PrevInSEL; + if( Edge1->PrevInSEL ) Edge1->PrevInSEL->NextInSEL = Edge1; + Edge2->NextInSEL = Next; + if( Edge2->NextInSEL ) Edge2->NextInSEL->PrevInSEL = Edge2; + Edge2->PrevInSEL = Prev; + if( Edge2->PrevInSEL ) Edge2->PrevInSEL->NextInSEL = Edge2; + } + + if( !Edge1->PrevInSEL ) m_SortedEdges = Edge1; + else if( !Edge2->PrevInSEL ) m_SortedEdges = Edge2; +} +//------------------------------------------------------------------------------ + +TEdge* GetNextInAEL(TEdge *e, Direction dir) +{ + return dir == dLeftToRight ? e->NextInAEL : e->PrevInAEL; +} +//------------------------------------------------------------------------------ + +void GetHorzDirection(TEdge& HorzEdge, Direction& Dir, cInt& Left, cInt& Right) +{ + if (HorzEdge.Bot.X < HorzEdge.Top.X) + { + Left = HorzEdge.Bot.X; + Right = HorzEdge.Top.X; + Dir = dLeftToRight; + } else + { + Left = HorzEdge.Top.X; + Right = HorzEdge.Bot.X; + Dir = dRightToLeft; + } +} +//------------------------------------------------------------------------ + +/******************************************************************************* +* Notes: Horizontal edges (HEs) at scanline intersections (ie at the Top or * +* Bottom of a scanbeam) are processed as if layered. The order in which HEs * +* are processed doesn't matter. HEs intersect with other HE Bot.Xs only [#] * +* (or they could intersect with Top.Xs only, ie EITHER Bot.Xs OR Top.Xs), * +* and with other non-horizontal edges [*]. Once these intersections are * +* processed, intermediate HEs then 'promote' the Edge above (NextInLML) into * +* the AEL. These 'promoted' edges may in turn intersect [%] with other HEs. 
*
+*******************************************************************************/
+
+void Clipper::ProcessHorizontal(TEdge *horzEdge)
+{
+ Direction dir;
+ cInt horzLeft, horzRight;
+ bool IsOpen = (horzEdge->WindDelta == 0);
+
+ GetHorzDirection(*horzEdge, dir, horzLeft, horzRight);
+
+ TEdge* eLastHorz = horzEdge, *eMaxPair = 0;
+ while (eLastHorz->NextInLML && IsHorizontal(*eLastHorz->NextInLML))
+ eLastHorz = eLastHorz->NextInLML;
+ if (!eLastHorz->NextInLML)
+ eMaxPair = GetMaximaPair(eLastHorz);
+
+ MaximaList::const_iterator maxIt;
+ MaximaList::const_reverse_iterator maxRit;
+ if (m_Maxima.size() > 0)
+ {
+ //get the first maxima in range (X) ...
+ if (dir == dLeftToRight)
+ {
+ maxIt = m_Maxima.begin();
+ while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X) maxIt++;
+ if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X)
+ maxIt = m_Maxima.end();
+ }
+ else
+ {
+ maxRit = m_Maxima.rbegin();
+ while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X) maxRit++;
+ if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X)
+ maxRit = m_Maxima.rend();
+ }
+ }
+
+ OutPt* op1 = 0;
+
+ for (;;) //loop through consec. horizontal edges
+ {
+
+ bool IsLastHorz = (horzEdge == eLastHorz);
+ TEdge* e = GetNextInAEL(horzEdge, dir);
+ while(e)
+ {
+
+ //this code block inserts extra coords into horizontal edges (in output
+ //polygons) wherever maxima touch these horizontal edges. This helps when
+ //'simplifying' polygons (ie if the StrictlySimple property is set).
+ if (m_Maxima.size() > 0)
+ {
+ if (dir == dLeftToRight)
+ {
+ while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X)
+ {
+ if (horzEdge->OutIdx >= 0 && !IsOpen)
+ AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y));
+ maxIt++;
+ }
+ }
+ else
+ {
+ while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X)
+ {
+ if (horzEdge->OutIdx >= 0 && !IsOpen)
+ AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y));
+ maxRit++;
+ }
+ }
+ }
+
+ if ((dir == dLeftToRight && e->Curr.X > horzRight) ||
+ (dir == dRightToLeft && e->Curr.X < horzLeft)) break;
+
+ //Also break if we've got to the end of an intermediate horizontal edge ...
+ //nb: Smaller Dx's are to the right of larger Dx's ABOVE the horizontal.
+ if (e->Curr.X == horzEdge->Top.X && horzEdge->NextInLML &&
+ e->Dx < horzEdge->NextInLML->Dx) break;
+
+ if (horzEdge->OutIdx >= 0 && !IsOpen) //note: may be done multiple times
+ {
+ op1 = AddOutPt(horzEdge, e->Curr);
+ TEdge* eNextHorz = m_SortedEdges;
+ while (eNextHorz)
+ {
+ if (eNextHorz->OutIdx >= 0 &&
+ HorzSegmentsOverlap(horzEdge->Bot.X,
+ horzEdge->Top.X, eNextHorz->Bot.X, eNextHorz->Top.X))
+ {
+ OutPt* op2 = GetLastOutPt(eNextHorz);
+ AddJoin(op2, op1, eNextHorz->Top);
+ }
+ eNextHorz = eNextHorz->NextInSEL;
+ }
+ AddGhostJoin(op1, horzEdge->Bot);
+ }
+
+ //OK, so far we're still in range of the horizontal Edge but make sure
+ //we're at the last of consec. horizontals when matching with eMaxPair
+ if(e == eMaxPair && IsLastHorz)
+ {
+ if (horzEdge->OutIdx >= 0)
+ AddLocalMaxPoly(horzEdge, eMaxPair, horzEdge->Top);
+ DeleteFromAEL(horzEdge);
+ DeleteFromAEL(eMaxPair);
+ return;
+ }
+
+ if(dir == dLeftToRight)
+ {
+ IntPoint Pt = IntPoint(e->Curr.X, horzEdge->Curr.Y);
+ IntersectEdges(horzEdge, e, Pt);
+ }
+ else
+ {
+ IntPoint Pt = IntPoint(e->Curr.X, horzEdge->Curr.Y);
+ IntersectEdges( e, horzEdge, Pt);
+ }
+ TEdge* eNext = GetNextInAEL(e, dir);
+ SwapPositionsInAEL( horzEdge, e );
+ e = eNext;
+ } //end while(e)
+
+ //Break out of loop if HorzEdge.NextInLML is not also horizontal ...
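+ //(aside, not part of upstream Clipper: a bound may contain several
+ // consecutive horizontal edges; while the next edge in the LML chain is
+ // also horizontal, UpdateEdgeIntoAEL below steps into it and this outer
+ // loop repeats, so the whole run is handled as one logical pass.)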
+ if (!horzEdge->NextInLML || !IsHorizontal(*horzEdge->NextInLML)) break; + + UpdateEdgeIntoAEL(horzEdge); + if (horzEdge->OutIdx >= 0) AddOutPt(horzEdge, horzEdge->Bot); + GetHorzDirection(*horzEdge, dir, horzLeft, horzRight); + + } //end for (;;) + + if (horzEdge->OutIdx >= 0 && !op1) + { + op1 = GetLastOutPt(horzEdge); + TEdge* eNextHorz = m_SortedEdges; + while (eNextHorz) + { + if (eNextHorz->OutIdx >= 0 && + HorzSegmentsOverlap(horzEdge->Bot.X, + horzEdge->Top.X, eNextHorz->Bot.X, eNextHorz->Top.X)) + { + OutPt* op2 = GetLastOutPt(eNextHorz); + AddJoin(op2, op1, eNextHorz->Top); + } + eNextHorz = eNextHorz->NextInSEL; + } + AddGhostJoin(op1, horzEdge->Top); + } + + if (horzEdge->NextInLML) + { + if(horzEdge->OutIdx >= 0) + { + op1 = AddOutPt( horzEdge, horzEdge->Top); + UpdateEdgeIntoAEL(horzEdge); + if (horzEdge->WindDelta == 0) return; + //nb: HorzEdge is no longer horizontal here + TEdge* ePrev = horzEdge->PrevInAEL; + TEdge* eNext = horzEdge->NextInAEL; + if (ePrev && ePrev->Curr.X == horzEdge->Bot.X && + ePrev->Curr.Y == horzEdge->Bot.Y && ePrev->WindDelta != 0 && + (ePrev->OutIdx >= 0 && ePrev->Curr.Y > ePrev->Top.Y && + SlopesEqual(*horzEdge, *ePrev, m_UseFullRange))) + { + OutPt* op2 = AddOutPt(ePrev, horzEdge->Bot); + AddJoin(op1, op2, horzEdge->Top); + } + else if (eNext && eNext->Curr.X == horzEdge->Bot.X && + eNext->Curr.Y == horzEdge->Bot.Y && eNext->WindDelta != 0 && + eNext->OutIdx >= 0 && eNext->Curr.Y > eNext->Top.Y && + SlopesEqual(*horzEdge, *eNext, m_UseFullRange)) + { + OutPt* op2 = AddOutPt(eNext, horzEdge->Bot); + AddJoin(op1, op2, horzEdge->Top); + } + } + else + UpdateEdgeIntoAEL(horzEdge); + } + else + { + if (horzEdge->OutIdx >= 0) AddOutPt(horzEdge, horzEdge->Top); + DeleteFromAEL(horzEdge); + } +} +//------------------------------------------------------------------------------ + +bool Clipper::ProcessIntersections(const cInt topY) +{ + if( !m_ActiveEdges ) return true; + try { + BuildIntersectList(topY); + size_t IlSize = m_IntersectList.size(); + if (IlSize == 0) return true; + if (IlSize == 1 || FixupIntersectionOrder()) ProcessIntersectList(); + else return false; + } + catch(...) + { + m_SortedEdges = 0; + DisposeIntersectNodes(); + throw clipperException("ProcessIntersections error"); + } + m_SortedEdges = 0; + return true; +} +//------------------------------------------------------------------------------ + +void Clipper::DisposeIntersectNodes() +{ + for (size_t i = 0; i < m_IntersectList.size(); ++i ) + delete m_IntersectList[i]; + m_IntersectList.clear(); +} +//------------------------------------------------------------------------------ + +void Clipper::BuildIntersectList(const cInt topY) +{ + if ( !m_ActiveEdges ) return; + + //prepare for sorting ... + TEdge* e = m_ActiveEdges; + m_SortedEdges = e; + while( e ) + { + e->PrevInSEL = e->PrevInAEL; + e->NextInSEL = e->NextInAEL; + e->Curr.X = TopX( *e, topY ); + e = e->NextInAEL; + } + + //bubblesort ... 
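+ //(aside, not part of upstream Clipper: the bubble sort is deliberate -
+ // sorting the SEL by Curr.X at topY through adjacent swaps enumerates
+ // every pair of edges that cross within this scanbeam, and each swap
+ // records exactly one IntersectNode.)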
+ bool isModified; + do + { + isModified = false; + e = m_SortedEdges; + while( e->NextInSEL ) + { + TEdge *eNext = e->NextInSEL; + IntPoint Pt; + if(e->Curr.X > eNext->Curr.X) + { + IntersectPoint(*e, *eNext, Pt); + if (Pt.Y < topY) Pt = IntPoint(TopX(*e, topY), topY); + IntersectNode * newNode = new IntersectNode; + newNode->Edge1 = e; + newNode->Edge2 = eNext; + newNode->Pt = Pt; + m_IntersectList.push_back(newNode); + + SwapPositionsInSEL(e, eNext); + isModified = true; + } + else + e = eNext; + } + if( e->PrevInSEL ) e->PrevInSEL->NextInSEL = 0; + else break; + } + while ( isModified ); + m_SortedEdges = 0; //important +} +//------------------------------------------------------------------------------ + + +void Clipper::ProcessIntersectList() +{ + for (size_t i = 0; i < m_IntersectList.size(); ++i) + { + IntersectNode* iNode = m_IntersectList[i]; + { + IntersectEdges( iNode->Edge1, iNode->Edge2, iNode->Pt); + SwapPositionsInAEL( iNode->Edge1 , iNode->Edge2 ); + } + delete iNode; + } + m_IntersectList.clear(); +} +//------------------------------------------------------------------------------ + +bool IntersectListSort(IntersectNode* node1, IntersectNode* node2) +{ + return node2->Pt.Y < node1->Pt.Y; +} +//------------------------------------------------------------------------------ + +inline bool EdgesAdjacent(const IntersectNode &inode) +{ + return (inode.Edge1->NextInSEL == inode.Edge2) || + (inode.Edge1->PrevInSEL == inode.Edge2); +} +//------------------------------------------------------------------------------ + +bool Clipper::FixupIntersectionOrder() +{ + //pre-condition: intersections are sorted Bottom-most first. + //Now it's crucial that intersections are made only between adjacent edges, + //so to ensure this the order of intersections may need adjusting ... 
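+ //(illustrative aside, not part of upstream Clipper: suppose edges A B C
+ // produce the sorted list [AxC, AxB, BxC] while A and C are not yet
+ // adjacent in the SEL; the loop below swaps the adjacent node AxB into
+ // position first, then replays that swap in the SEL so AxC becomes an
+ // adjacent - and therefore safe - intersection to process.)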
+ CopyAELToSEL(); + std::sort(m_IntersectList.begin(), m_IntersectList.end(), IntersectListSort); + size_t cnt = m_IntersectList.size(); + for (size_t i = 0; i < cnt; ++i) + { + if (!EdgesAdjacent(*m_IntersectList[i])) + { + size_t j = i + 1; + while (j < cnt && !EdgesAdjacent(*m_IntersectList[j])) j++; + if (j == cnt) return false; + std::swap(m_IntersectList[i], m_IntersectList[j]); + } + SwapPositionsInSEL(m_IntersectList[i]->Edge1, m_IntersectList[i]->Edge2); + } + return true; +} +//------------------------------------------------------------------------------ + +void Clipper::DoMaxima(TEdge *e) +{ + TEdge* eMaxPair = GetMaximaPairEx(e); + if (!eMaxPair) + { + if (e->OutIdx >= 0) + AddOutPt(e, e->Top); + DeleteFromAEL(e); + return; + } + + TEdge* eNext = e->NextInAEL; + while(eNext && eNext != eMaxPair) + { + IntersectEdges(e, eNext, e->Top); + SwapPositionsInAEL(e, eNext); + eNext = e->NextInAEL; + } + + if(e->OutIdx == Unassigned && eMaxPair->OutIdx == Unassigned) + { + DeleteFromAEL(e); + DeleteFromAEL(eMaxPair); + } + else if( e->OutIdx >= 0 && eMaxPair->OutIdx >= 0 ) + { + if (e->OutIdx >= 0) AddLocalMaxPoly(e, eMaxPair, e->Top); + DeleteFromAEL(e); + DeleteFromAEL(eMaxPair); + } +#ifdef use_lines + else if (e->WindDelta == 0) + { + if (e->OutIdx >= 0) + { + AddOutPt(e, e->Top); + e->OutIdx = Unassigned; + } + DeleteFromAEL(e); + + if (eMaxPair->OutIdx >= 0) + { + AddOutPt(eMaxPair, e->Top); + eMaxPair->OutIdx = Unassigned; + } + DeleteFromAEL(eMaxPair); + } +#endif + else throw clipperException("DoMaxima error"); +} +//------------------------------------------------------------------------------ + +void Clipper::ProcessEdgesAtTopOfScanbeam(const cInt topY) +{ + TEdge* e = m_ActiveEdges; + while( e ) + { + //1. process maxima, treating them as if they're 'bent' horizontal edges, + // but exclude maxima with horizontal edges. nb: e can't be a horizontal. + bool IsMaximaEdge = IsMaxima(e, topY); + + if(IsMaximaEdge) + { + TEdge* eMaxPair = GetMaximaPairEx(e); + IsMaximaEdge = (!eMaxPair || !IsHorizontal(*eMaxPair)); + } + + if(IsMaximaEdge) + { + if (m_StrictSimple) m_Maxima.push_back(e->Top.X); + TEdge* ePrev = e->PrevInAEL; + DoMaxima(e); + if( !ePrev ) e = m_ActiveEdges; + else e = ePrev->NextInAEL; + } + else + { + //2. promote horizontal edges, otherwise update Curr.X and Curr.Y ... + if (IsIntermediate(e, topY) && IsHorizontal(*e->NextInLML)) + { + UpdateEdgeIntoAEL(e); + if (e->OutIdx >= 0) + AddOutPt(e, e->Bot); + AddEdgeToSEL(e); + } + else + { + e->Curr.X = TopX( *e, topY ); + e->Curr.Y = topY; + } + + //When StrictlySimple and 'e' is being touched by another edge, then + //make sure both edges have a vertex here ... + if (m_StrictSimple) + { + TEdge* ePrev = e->PrevInAEL; + if ((e->OutIdx >= 0) && (e->WindDelta != 0) && ePrev && (ePrev->OutIdx >= 0) && + (ePrev->Curr.X == e->Curr.X) && (ePrev->WindDelta != 0)) + { + IntPoint pt = e->Curr; +#ifdef use_xyz + SetZ(pt, *ePrev, *e); +#endif + OutPt* op = AddOutPt(ePrev, pt); + OutPt* op2 = AddOutPt(e, pt); + AddJoin(op, op2, pt); //StrictlySimple (type-3) join + } + } + + e = e->NextInAEL; + } + } + + //3. Process horizontals at the Top of the scanbeam ... + m_Maxima.sort(); + ProcessHorizontals(); + m_Maxima.clear(); + + //4. Promote intermediate vertices ... + e = m_ActiveEdges; + while(e) + { + if(IsIntermediate(e, topY)) + { + OutPt* op = 0; + if( e->OutIdx >= 0 ) + op = AddOutPt(e, e->Top); + UpdateEdgeIntoAEL(e); + + //if output polygons share an edge, they'll need joining later ... 
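+ //(aside, not part of upstream Clipper: the two checks below record a
+ // Join whenever the promoted edge and an AEL neighbour leave vertices at
+ // the same point with equal slopes, ie the output polygons share a
+ // common edge, so JoinCommonEdges() can merge them after the sweep.)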
+ TEdge* ePrev = e->PrevInAEL; + TEdge* eNext = e->NextInAEL; + if (ePrev && ePrev->Curr.X == e->Bot.X && + ePrev->Curr.Y == e->Bot.Y && op && + ePrev->OutIdx >= 0 && ePrev->Curr.Y > ePrev->Top.Y && + SlopesEqual(e->Curr, e->Top, ePrev->Curr, ePrev->Top, m_UseFullRange) && + (e->WindDelta != 0) && (ePrev->WindDelta != 0)) + { + OutPt* op2 = AddOutPt(ePrev, e->Bot); + AddJoin(op, op2, e->Top); + } + else if (eNext && eNext->Curr.X == e->Bot.X && + eNext->Curr.Y == e->Bot.Y && op && + eNext->OutIdx >= 0 && eNext->Curr.Y > eNext->Top.Y && + SlopesEqual(e->Curr, e->Top, eNext->Curr, eNext->Top, m_UseFullRange) && + (e->WindDelta != 0) && (eNext->WindDelta != 0)) + { + OutPt* op2 = AddOutPt(eNext, e->Bot); + AddJoin(op, op2, e->Top); + } + } + e = e->NextInAEL; + } +} +//------------------------------------------------------------------------------ + +void Clipper::FixupOutPolyline(OutRec &outrec) +{ + OutPt *pp = outrec.Pts; + OutPt *lastPP = pp->Prev; + while (pp != lastPP) + { + pp = pp->Next; + if (pp->Pt == pp->Prev->Pt) + { + if (pp == lastPP) lastPP = pp->Prev; + OutPt *tmpPP = pp->Prev; + tmpPP->Next = pp->Next; + pp->Next->Prev = tmpPP; + delete pp; + pp = tmpPP; + } + } + + if (pp == pp->Prev) + { + DisposeOutPts(pp); + outrec.Pts = 0; + return; + } +} +//------------------------------------------------------------------------------ + +void Clipper::FixupOutPolygon(OutRec &outrec) +{ + //FixupOutPolygon() - removes duplicate points and simplifies consecutive + //parallel edges by removing the middle vertex. + OutPt *lastOK = 0; + outrec.BottomPt = 0; + OutPt *pp = outrec.Pts; + bool preserveCol = m_PreserveCollinear || m_StrictSimple; + + for (;;) + { + if (pp->Prev == pp || pp->Prev == pp->Next) + { + DisposeOutPts(pp); + outrec.Pts = 0; + return; + } + + //test for duplicate points and collinear edges ... + if ((pp->Pt == pp->Next->Pt) || (pp->Pt == pp->Prev->Pt) || + (SlopesEqual(pp->Prev->Pt, pp->Pt, pp->Next->Pt, m_UseFullRange) && + (!preserveCol || !Pt2IsBetweenPt1AndPt3(pp->Prev->Pt, pp->Pt, pp->Next->Pt)))) + { + lastOK = 0; + OutPt *tmp = pp; + pp->Prev->Next = pp->Next; + pp->Next->Prev = pp->Prev; + pp = pp->Prev; + delete tmp; + } + else if (pp == lastOK) break; + else + { + if (!lastOK) lastOK = pp; + pp = pp->Next; + } + } + outrec.Pts = pp; +} +//------------------------------------------------------------------------------ + +int PointCount(OutPt *Pts) +{ + if (!Pts) return 0; + int result = 0; + OutPt* p = Pts; + do + { + result++; + p = p->Next; + } + while (p != Pts); + return result; +} +//------------------------------------------------------------------------------ + +void Clipper::BuildResult(Paths &polys) +{ + polys.reserve(m_PolyOuts.size()); + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) + { + if (!m_PolyOuts[i]->Pts) continue; + Path pg; + OutPt* p = m_PolyOuts[i]->Pts->Prev; + int cnt = PointCount(p); + if (cnt < 2) continue; + pg.reserve(cnt); + for (int i = 0; i < cnt; ++i) + { + pg.push_back(p->Pt); + p = p->Prev; + } + polys.push_back(pg); + } +} +//------------------------------------------------------------------------------ + +void Clipper::BuildResult2(PolyTree& polytree) +{ + polytree.Clear(); + polytree.AllNodes.reserve(m_PolyOuts.size()); + //add each output polygon/contour to polytree ... 
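+ //(illustrative aside, not part of upstream Clipper: an outer square with
+ // a square hole becomes a PolyTree whose single top-level child holds
+ // the outer contour and whose grandchild holds the hole; open paths are
+ // attached directly to the tree root with m_IsOpen set.)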
+ for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); i++) + { + OutRec* outRec = m_PolyOuts[i]; + int cnt = PointCount(outRec->Pts); + if ((outRec->IsOpen && cnt < 2) || (!outRec->IsOpen && cnt < 3)) continue; + FixHoleLinkage(*outRec); + PolyNode* pn = new PolyNode(); + //nb: polytree takes ownership of all the PolyNodes + polytree.AllNodes.push_back(pn); + outRec->PolyNd = pn; + pn->Parent = 0; + pn->Index = 0; + pn->Contour.reserve(cnt); + OutPt *op = outRec->Pts->Prev; + for (int j = 0; j < cnt; j++) + { + pn->Contour.push_back(op->Pt); + op = op->Prev; + } + } + + //fixup PolyNode links etc ... + polytree.Childs.reserve(m_PolyOuts.size()); + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); i++) + { + OutRec* outRec = m_PolyOuts[i]; + if (!outRec->PolyNd) continue; + if (outRec->IsOpen) + { + outRec->PolyNd->m_IsOpen = true; + polytree.AddChild(*outRec->PolyNd); + } + else if (outRec->FirstLeft && outRec->FirstLeft->PolyNd) + outRec->FirstLeft->PolyNd->AddChild(*outRec->PolyNd); + else + polytree.AddChild(*outRec->PolyNd); + } +} +//------------------------------------------------------------------------------ + +void SwapIntersectNodes(IntersectNode &int1, IntersectNode &int2) +{ + //just swap the contents (because fIntersectNodes is a single-linked-list) + IntersectNode inode = int1; //gets a copy of Int1 + int1.Edge1 = int2.Edge1; + int1.Edge2 = int2.Edge2; + int1.Pt = int2.Pt; + int2.Edge1 = inode.Edge1; + int2.Edge2 = inode.Edge2; + int2.Pt = inode.Pt; +} +//------------------------------------------------------------------------------ + +inline bool E2InsertsBeforeE1(TEdge &e1, TEdge &e2) +{ + if (e2.Curr.X == e1.Curr.X) + { + if (e2.Top.Y > e1.Top.Y) + return e2.Top.X < TopX(e1, e2.Top.Y); + else return e1.Top.X > TopX(e2, e1.Top.Y); + } + else return e2.Curr.X < e1.Curr.X; +} +//------------------------------------------------------------------------------ + +bool GetOverlap(const cInt a1, const cInt a2, const cInt b1, const cInt b2, + cInt& Left, cInt& Right) +{ + if (a1 < a2) + { + if (b1 < b2) {Left = std::max(a1,b1); Right = std::min(a2,b2);} + else {Left = std::max(a1,b2); Right = std::min(a2,b1);} + } + else + { + if (b1 < b2) {Left = std::max(a2,b1); Right = std::min(a1,b2);} + else {Left = std::max(a2,b2); Right = std::min(a1,b1);} + } + return Left < Right; +} +//------------------------------------------------------------------------------ + +inline void UpdateOutPtIdxs(OutRec& outrec) +{ + OutPt* op = outrec.Pts; + do + { + op->Idx = outrec.Idx; + op = op->Prev; + } + while(op != outrec.Pts); +} +//------------------------------------------------------------------------------ + +void Clipper::InsertEdgeIntoAEL(TEdge *edge, TEdge* startEdge) +{ + if(!m_ActiveEdges) + { + edge->PrevInAEL = 0; + edge->NextInAEL = 0; + m_ActiveEdges = edge; + } + else if(!startEdge && E2InsertsBeforeE1(*m_ActiveEdges, *edge)) + { + edge->PrevInAEL = 0; + edge->NextInAEL = m_ActiveEdges; + m_ActiveEdges->PrevInAEL = edge; + m_ActiveEdges = edge; + } + else + { + if(!startEdge) startEdge = m_ActiveEdges; + while(startEdge->NextInAEL && + !E2InsertsBeforeE1(*startEdge->NextInAEL , *edge)) + startEdge = startEdge->NextInAEL; + edge->NextInAEL = startEdge->NextInAEL; + if(startEdge->NextInAEL) startEdge->NextInAEL->PrevInAEL = edge; + edge->PrevInAEL = startEdge; + startEdge->NextInAEL = edge; + } +} +//---------------------------------------------------------------------- + +OutPt* DupOutPt(OutPt* outPt, bool InsertAfter) +{ + OutPt* result = new OutPt; + result->Pt = 
outPt->Pt; + result->Idx = outPt->Idx; + if (InsertAfter) + { + result->Next = outPt->Next; + result->Prev = outPt; + outPt->Next->Prev = result; + outPt->Next = result; + } + else + { + result->Prev = outPt->Prev; + result->Next = outPt; + outPt->Prev->Next = result; + outPt->Prev = result; + } + return result; +} +//------------------------------------------------------------------------------ + +bool JoinHorz(OutPt* op1, OutPt* op1b, OutPt* op2, OutPt* op2b, + const IntPoint Pt, bool DiscardLeft) +{ + Direction Dir1 = (op1->Pt.X > op1b->Pt.X ? dRightToLeft : dLeftToRight); + Direction Dir2 = (op2->Pt.X > op2b->Pt.X ? dRightToLeft : dLeftToRight); + if (Dir1 == Dir2) return false; + + //When DiscardLeft, we want Op1b to be on the Left of Op1, otherwise we + //want Op1b to be on the Right. (And likewise with Op2 and Op2b.) + //So, to facilitate this while inserting Op1b and Op2b ... + //when DiscardLeft, make sure we're AT or RIGHT of Pt before adding Op1b, + //otherwise make sure we're AT or LEFT of Pt. (Likewise with Op2b.) + if (Dir1 == dLeftToRight) + { + while (op1->Next->Pt.X <= Pt.X && + op1->Next->Pt.X >= op1->Pt.X && op1->Next->Pt.Y == Pt.Y) + op1 = op1->Next; + if (DiscardLeft && (op1->Pt.X != Pt.X)) op1 = op1->Next; + op1b = DupOutPt(op1, !DiscardLeft); + if (op1b->Pt != Pt) + { + op1 = op1b; + op1->Pt = Pt; + op1b = DupOutPt(op1, !DiscardLeft); + } + } + else + { + while (op1->Next->Pt.X >= Pt.X && + op1->Next->Pt.X <= op1->Pt.X && op1->Next->Pt.Y == Pt.Y) + op1 = op1->Next; + if (!DiscardLeft && (op1->Pt.X != Pt.X)) op1 = op1->Next; + op1b = DupOutPt(op1, DiscardLeft); + if (op1b->Pt != Pt) + { + op1 = op1b; + op1->Pt = Pt; + op1b = DupOutPt(op1, DiscardLeft); + } + } + + if (Dir2 == dLeftToRight) + { + while (op2->Next->Pt.X <= Pt.X && + op2->Next->Pt.X >= op2->Pt.X && op2->Next->Pt.Y == Pt.Y) + op2 = op2->Next; + if (DiscardLeft && (op2->Pt.X != Pt.X)) op2 = op2->Next; + op2b = DupOutPt(op2, !DiscardLeft); + if (op2b->Pt != Pt) + { + op2 = op2b; + op2->Pt = Pt; + op2b = DupOutPt(op2, !DiscardLeft); + }; + } else + { + while (op2->Next->Pt.X >= Pt.X && + op2->Next->Pt.X <= op2->Pt.X && op2->Next->Pt.Y == Pt.Y) + op2 = op2->Next; + if (!DiscardLeft && (op2->Pt.X != Pt.X)) op2 = op2->Next; + op2b = DupOutPt(op2, DiscardLeft); + if (op2b->Pt != Pt) + { + op2 = op2b; + op2->Pt = Pt; + op2b = DupOutPt(op2, DiscardLeft); + }; + }; + + if ((Dir1 == dLeftToRight) == DiscardLeft) + { + op1->Prev = op2; + op2->Next = op1; + op1b->Next = op2b; + op2b->Prev = op1b; + } + else + { + op1->Next = op2; + op2->Prev = op1; + op1b->Prev = op2b; + op2b->Next = op1b; + } + return true; +} +//------------------------------------------------------------------------------ + +bool Clipper::JoinPoints(Join *j, OutRec* outRec1, OutRec* outRec2) +{ + OutPt *op1 = j->OutPt1, *op1b; + OutPt *op2 = j->OutPt2, *op2b; + + //There are 3 kinds of joins for output polygons ... + //1. Horizontal joins where Join.OutPt1 & Join.OutPt2 are vertices anywhere + //along (horizontal) collinear edges (& Join.OffPt is on the same horizontal). + //2. Non-horizontal joins where Join.OutPt1 & Join.OutPt2 are at the same + //location at the Bottom of the overlapping segment (& Join.OffPt is above). + //3. StrictSimple joins where edges touch but are not collinear and where + //Join.OutPt1, Join.OutPt2 & Join.OffPt all share the same point. + bool isHorizontal = (j->OutPt1->Pt.Y == j->OffPt.Y); + + if (isHorizontal && (j->OffPt == j->OutPt1->Pt) && + (j->OffPt == j->OutPt2->Pt)) + { + //Strictly Simple join ... 
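+    //ie all three join points coincide, so the join can only succeed when
+    //both OutPts belong to the same polygon and their adjacent edges leave
+    //the join point in opposite vertical directions (reverse1 != reverse2) ...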
+    if (outRec1 != outRec2) return false;
+    op1b = j->OutPt1->Next;
+    while (op1b != op1 && (op1b->Pt == j->OffPt))
+      op1b = op1b->Next;
+    bool reverse1 = (op1b->Pt.Y > j->OffPt.Y);
+    op2b = j->OutPt2->Next;
+    while (op2b != op2 && (op2b->Pt == j->OffPt))
+      op2b = op2b->Next;
+    bool reverse2 = (op2b->Pt.Y > j->OffPt.Y);
+    if (reverse1 == reverse2) return false;
+    if (reverse1)
+    {
+      op1b = DupOutPt(op1, false);
+      op2b = DupOutPt(op2, true);
+      op1->Prev = op2;
+      op2->Next = op1;
+      op1b->Next = op2b;
+      op2b->Prev = op1b;
+      j->OutPt1 = op1;
+      j->OutPt2 = op1b;
+      return true;
+    } else
+    {
+      op1b = DupOutPt(op1, true);
+      op2b = DupOutPt(op2, false);
+      op1->Next = op2;
+      op2->Prev = op1;
+      op1b->Prev = op2b;
+      op2b->Next = op1b;
+      j->OutPt1 = op1;
+      j->OutPt2 = op1b;
+      return true;
+    }
+  }
+  else if (isHorizontal)
+  {
+    //treat horizontal joins differently to non-horizontal joins since with
+    //them we're not yet sure where the overlapping is. OutPt1.Pt & OutPt2.Pt
+    //may be anywhere along the horizontal edge.
+    op1b = op1;
+    while (op1->Prev->Pt.Y == op1->Pt.Y && op1->Prev != op1b && op1->Prev != op2)
+      op1 = op1->Prev;
+    while (op1b->Next->Pt.Y == op1b->Pt.Y && op1b->Next != op1 && op1b->Next != op2)
+      op1b = op1b->Next;
+    if (op1b->Next == op1 || op1b->Next == op2) return false; //a flat 'polygon'
+
+    op2b = op2;
+    while (op2->Prev->Pt.Y == op2->Pt.Y && op2->Prev != op2b && op2->Prev != op1b)
+      op2 = op2->Prev;
+    while (op2b->Next->Pt.Y == op2b->Pt.Y && op2b->Next != op2 && op2b->Next != op1)
+      op2b = op2b->Next;
+    if (op2b->Next == op2 || op2b->Next == op1) return false; //a flat 'polygon'
+
+    cInt Left, Right;
+    //Op1 --> Op1b & Op2 --> Op2b are the extremities of the horizontal edges
+    if (!GetOverlap(op1->Pt.X, op1b->Pt.X, op2->Pt.X, op2b->Pt.X, Left, Right))
+      return false;
+
+    //DiscardLeftSide: when overlapping edges are joined, a spike will be created
+    //which needs to be cleaned up. However, we don't want Op1 or Op2 caught up
+    //on the discarded side as either may still be needed for other joins ...
+    IntPoint Pt;
+    bool DiscardLeftSide;
+    if (op1->Pt.X >= Left && op1->Pt.X <= Right)
+    {
+      Pt = op1->Pt; DiscardLeftSide = (op1->Pt.X > op1b->Pt.X);
+    }
+    else if (op2->Pt.X >= Left && op2->Pt.X <= Right)
+    {
+      Pt = op2->Pt; DiscardLeftSide = (op2->Pt.X > op2b->Pt.X);
+    }
+    else if (op1b->Pt.X >= Left && op1b->Pt.X <= Right)
+    {
+      Pt = op1b->Pt; DiscardLeftSide = op1b->Pt.X > op1->Pt.X;
+    }
+    else
+    {
+      Pt = op2b->Pt; DiscardLeftSide = (op2b->Pt.X > op2->Pt.X);
+    }
+    j->OutPt1 = op1; j->OutPt2 = op2;
+    return JoinHorz(op1, op1b, op2, op2b, Pt, DiscardLeftSide);
+  } else
+  {
+    //nb: For non-horizontal joins ...
+    //    1. Jr.OutPt1.Pt.Y == Jr.OutPt2.Pt.Y
+    //    2. Jr.OutPt1.Pt > Jr.OffPt.Y
+
+    //make sure the polygons are correctly oriented ...
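+    //(ie walk away from each join point and compare the adjoining edge's
+    //slope with Join.OffPt, trying the opposite direction on a mismatch) ...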
+ op1b = op1->Next; + while ((op1b->Pt == op1->Pt) && (op1b != op1)) op1b = op1b->Next; + bool Reverse1 = ((op1b->Pt.Y > op1->Pt.Y) || + !SlopesEqual(op1->Pt, op1b->Pt, j->OffPt, m_UseFullRange)); + if (Reverse1) + { + op1b = op1->Prev; + while ((op1b->Pt == op1->Pt) && (op1b != op1)) op1b = op1b->Prev; + if ((op1b->Pt.Y > op1->Pt.Y) || + !SlopesEqual(op1->Pt, op1b->Pt, j->OffPt, m_UseFullRange)) return false; + }; + op2b = op2->Next; + while ((op2b->Pt == op2->Pt) && (op2b != op2))op2b = op2b->Next; + bool Reverse2 = ((op2b->Pt.Y > op2->Pt.Y) || + !SlopesEqual(op2->Pt, op2b->Pt, j->OffPt, m_UseFullRange)); + if (Reverse2) + { + op2b = op2->Prev; + while ((op2b->Pt == op2->Pt) && (op2b != op2)) op2b = op2b->Prev; + if ((op2b->Pt.Y > op2->Pt.Y) || + !SlopesEqual(op2->Pt, op2b->Pt, j->OffPt, m_UseFullRange)) return false; + } + + if ((op1b == op1) || (op2b == op2) || (op1b == op2b) || + ((outRec1 == outRec2) && (Reverse1 == Reverse2))) return false; + + if (Reverse1) + { + op1b = DupOutPt(op1, false); + op2b = DupOutPt(op2, true); + op1->Prev = op2; + op2->Next = op1; + op1b->Next = op2b; + op2b->Prev = op1b; + j->OutPt1 = op1; + j->OutPt2 = op1b; + return true; + } else + { + op1b = DupOutPt(op1, true); + op2b = DupOutPt(op2, false); + op1->Next = op2; + op2->Prev = op1; + op1b->Prev = op2b; + op2b->Next = op1b; + j->OutPt1 = op1; + j->OutPt2 = op1b; + return true; + } + } +} +//---------------------------------------------------------------------- + +static OutRec* ParseFirstLeft(OutRec* FirstLeft) +{ + while (FirstLeft && !FirstLeft->Pts) + FirstLeft = FirstLeft->FirstLeft; + return FirstLeft; +} +//------------------------------------------------------------------------------ + +void Clipper::FixupFirstLefts1(OutRec* OldOutRec, OutRec* NewOutRec) +{ + //tests if NewOutRec contains the polygon before reassigning FirstLeft + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) + { + OutRec* outRec = m_PolyOuts[i]; + OutRec* firstLeft = ParseFirstLeft(outRec->FirstLeft); + if (outRec->Pts && firstLeft == OldOutRec) + { + if (Poly2ContainsPoly1(outRec->Pts, NewOutRec->Pts)) + outRec->FirstLeft = NewOutRec; + } + } +} +//---------------------------------------------------------------------- + +void Clipper::FixupFirstLefts2(OutRec* InnerOutRec, OutRec* OuterOutRec) +{ + //A polygon has split into two such that one is now the inner of the other. + //It's possible that these polygons now wrap around other polygons, so check + //every polygon that's also contained by OuterOutRec's FirstLeft container + //(including 0) to see if they've become inner to the new inner polygon ... 
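+  //nb: only polygons whose FirstLeft currently resolves to OuterOutRec's own
+  //container, or to one of the two split polygons, can need reassigning ...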
+ OutRec* orfl = OuterOutRec->FirstLeft; + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) + { + OutRec* outRec = m_PolyOuts[i]; + + if (!outRec->Pts || outRec == OuterOutRec || outRec == InnerOutRec) + continue; + OutRec* firstLeft = ParseFirstLeft(outRec->FirstLeft); + if (firstLeft != orfl && firstLeft != InnerOutRec && firstLeft != OuterOutRec) + continue; + if (Poly2ContainsPoly1(outRec->Pts, InnerOutRec->Pts)) + outRec->FirstLeft = InnerOutRec; + else if (Poly2ContainsPoly1(outRec->Pts, OuterOutRec->Pts)) + outRec->FirstLeft = OuterOutRec; + else if (outRec->FirstLeft == InnerOutRec || outRec->FirstLeft == OuterOutRec) + outRec->FirstLeft = orfl; + } +} +//---------------------------------------------------------------------- +void Clipper::FixupFirstLefts3(OutRec* OldOutRec, OutRec* NewOutRec) +{ + //reassigns FirstLeft WITHOUT testing if NewOutRec contains the polygon + for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) + { + OutRec* outRec = m_PolyOuts[i]; + OutRec* firstLeft = ParseFirstLeft(outRec->FirstLeft); + if (outRec->Pts && outRec->FirstLeft == OldOutRec) + outRec->FirstLeft = NewOutRec; + } +} +//---------------------------------------------------------------------- + +void Clipper::JoinCommonEdges() +{ + for (JoinList::size_type i = 0; i < m_Joins.size(); i++) + { + Join* join = m_Joins[i]; + + OutRec *outRec1 = GetOutRec(join->OutPt1->Idx); + OutRec *outRec2 = GetOutRec(join->OutPt2->Idx); + + if (!outRec1->Pts || !outRec2->Pts) continue; + if (outRec1->IsOpen || outRec2->IsOpen) continue; + + //get the polygon fragment with the correct hole state (FirstLeft) + //before calling JoinPoints() ... + OutRec *holeStateRec; + if (outRec1 == outRec2) holeStateRec = outRec1; + else if (OutRec1RightOfOutRec2(outRec1, outRec2)) holeStateRec = outRec2; + else if (OutRec1RightOfOutRec2(outRec2, outRec1)) holeStateRec = outRec1; + else holeStateRec = GetLowermostRec(outRec1, outRec2); + + if (!JoinPoints(join, outRec1, outRec2)) continue; + + if (outRec1 == outRec2) + { + //instead of joining two polygons, we've just created a new one by + //splitting one polygon into two. + outRec1->Pts = join->OutPt1; + outRec1->BottomPt = 0; + outRec2 = CreateOutRec(); + outRec2->Pts = join->OutPt2; + + //update all OutRec2.Pts Idx's ... + UpdateOutPtIdxs(*outRec2); + + if (Poly2ContainsPoly1(outRec2->Pts, outRec1->Pts)) + { + //outRec1 contains outRec2 ... + outRec2->IsHole = !outRec1->IsHole; + outRec2->FirstLeft = outRec1; + + if (m_UsingPolyTree) FixupFirstLefts2(outRec2, outRec1); + + if ((outRec2->IsHole ^ m_ReverseOutput) == (Area(*outRec2) > 0)) + ReversePolyPtLinks(outRec2->Pts); + + } else if (Poly2ContainsPoly1(outRec1->Pts, outRec2->Pts)) + { + //outRec2 contains outRec1 ... + outRec2->IsHole = outRec1->IsHole; + outRec1->IsHole = !outRec2->IsHole; + outRec2->FirstLeft = outRec1->FirstLeft; + outRec1->FirstLeft = outRec2; + + if (m_UsingPolyTree) FixupFirstLefts2(outRec1, outRec2); + + if ((outRec1->IsHole ^ m_ReverseOutput) == (Area(*outRec1) > 0)) + ReversePolyPtLinks(outRec1->Pts); + } + else + { + //the 2 polygons are completely separate ... + outRec2->IsHole = outRec1->IsHole; + outRec2->FirstLeft = outRec1->FirstLeft; + + //fixup FirstLeft pointers that may need reassigning to OutRec2 + if (m_UsingPolyTree) FixupFirstLefts1(outRec1, outRec2); + } + + } else + { + //joined 2 polygons together ... 
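+        //outRec1 keeps the merged vertex list; outRec2 is emptied and its
+        //Idx redirected so later GetOutRec calls resolve to outRec1 ...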
+ + outRec2->Pts = 0; + outRec2->BottomPt = 0; + outRec2->Idx = outRec1->Idx; + + outRec1->IsHole = holeStateRec->IsHole; + if (holeStateRec == outRec2) + outRec1->FirstLeft = outRec2->FirstLeft; + outRec2->FirstLeft = outRec1; + + if (m_UsingPolyTree) FixupFirstLefts3(outRec2, outRec1); + } + } +} + +//------------------------------------------------------------------------------ +// ClipperOffset support functions ... +//------------------------------------------------------------------------------ + +DoublePoint GetUnitNormal(const IntPoint &pt1, const IntPoint &pt2) +{ + if(pt2.X == pt1.X && pt2.Y == pt1.Y) + return DoublePoint(0, 0); + + double Dx = (double)(pt2.X - pt1.X); + double dy = (double)(pt2.Y - pt1.Y); + double f = 1 *1.0/ std::sqrt( Dx*Dx + dy*dy ); + Dx *= f; + dy *= f; + return DoublePoint(dy, -Dx); +} + +//------------------------------------------------------------------------------ +// ClipperOffset class +//------------------------------------------------------------------------------ + +ClipperOffset::ClipperOffset(double miterLimit, double arcTolerance) +{ + this->MiterLimit = miterLimit; + this->ArcTolerance = arcTolerance; + m_lowest.X = -1; +} +//------------------------------------------------------------------------------ + +ClipperOffset::~ClipperOffset() +{ + Clear(); +} +//------------------------------------------------------------------------------ + +void ClipperOffset::Clear() +{ + for (int i = 0; i < m_polyNodes.ChildCount(); ++i) + delete m_polyNodes.Childs[i]; + m_polyNodes.Childs.clear(); + m_lowest.X = -1; +} +//------------------------------------------------------------------------------ + +void ClipperOffset::AddPath(const Path& path, JoinType joinType, EndType endType) +{ + int highI = (int)path.size() - 1; + if (highI < 0) return; + PolyNode* newNode = new PolyNode(); + newNode->m_jointype = joinType; + newNode->m_endtype = endType; + + //strip duplicate points from path and also get index to the lowest point ... 
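+  //('lowest' here means max Y, with min X breaking ties; FixOrientations
+  //later uses this point to judge the orientation of the closed paths) ...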
+ if (endType == etClosedLine || endType == etClosedPolygon) + while (highI > 0 && path[0] == path[highI]) highI--; + newNode->Contour.reserve(highI + 1); + newNode->Contour.push_back(path[0]); + int j = 0, k = 0; + for (int i = 1; i <= highI; i++) + if (newNode->Contour[j] != path[i]) + { + j++; + newNode->Contour.push_back(path[i]); + if (path[i].Y > newNode->Contour[k].Y || + (path[i].Y == newNode->Contour[k].Y && + path[i].X < newNode->Contour[k].X)) k = j; + } + if (endType == etClosedPolygon && j < 2) + { + delete newNode; + return; + } + m_polyNodes.AddChild(*newNode); + + //if this path's lowest pt is lower than all the others then update m_lowest + if (endType != etClosedPolygon) return; + if (m_lowest.X < 0) + m_lowest = IntPoint(m_polyNodes.ChildCount() - 1, k); + else + { + IntPoint ip = m_polyNodes.Childs[(int)m_lowest.X]->Contour[(int)m_lowest.Y]; + if (newNode->Contour[k].Y > ip.Y || + (newNode->Contour[k].Y == ip.Y && + newNode->Contour[k].X < ip.X)) + m_lowest = IntPoint(m_polyNodes.ChildCount() - 1, k); + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::AddPaths(const Paths& paths, JoinType joinType, EndType endType) +{ + for (Paths::size_type i = 0; i < paths.size(); ++i) + AddPath(paths[i], joinType, endType); +} +//------------------------------------------------------------------------------ + +void ClipperOffset::FixOrientations() +{ + //fixup orientations of all closed paths if the orientation of the + //closed path with the lowermost vertex is wrong ... + if (m_lowest.X >= 0 && + !Orientation(m_polyNodes.Childs[(int)m_lowest.X]->Contour)) + { + for (int i = 0; i < m_polyNodes.ChildCount(); ++i) + { + PolyNode& node = *m_polyNodes.Childs[i]; + if (node.m_endtype == etClosedPolygon || + (node.m_endtype == etClosedLine && Orientation(node.Contour))) + ReversePath(node.Contour); + } + } else + { + for (int i = 0; i < m_polyNodes.ChildCount(); ++i) + { + PolyNode& node = *m_polyNodes.Childs[i]; + if (node.m_endtype == etClosedLine && !Orientation(node.Contour)) + ReversePath(node.Contour); + } + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::Execute(Paths& solution, double delta) +{ + solution.clear(); + FixOrientations(); + DoOffset(delta); + + //now clean up 'corners' ... + Clipper clpr; + clpr.AddPaths(m_destPolys, ptSubject, true); + if (delta > 0) + { + clpr.Execute(ctUnion, solution, pftPositive, pftPositive); + } + else + { + IntRect r = clpr.GetBounds(); + Path outer(4); + outer[0] = IntPoint(r.left - 10, r.bottom + 10); + outer[1] = IntPoint(r.right + 10, r.bottom + 10); + outer[2] = IntPoint(r.right + 10, r.top - 10); + outer[3] = IntPoint(r.left - 10, r.top - 10); + + clpr.AddPath(outer, ptSubject, true); + clpr.ReverseSolution(true); + clpr.Execute(ctUnion, solution, pftNegative, pftNegative); + if (solution.size() > 0) solution.erase(solution.begin()); + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::Execute(PolyTree& solution, double delta) +{ + solution.Clear(); + FixOrientations(); + DoOffset(delta); + + //now clean up 'corners' ... 
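+  //a union with Positive filling removes any self-intersections left by
+  //DoOffset; for negative deltas the union runs with Negative filling inside
+  //a temporary outer rectangle that is discarded again afterwards ...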
+ Clipper clpr; + clpr.AddPaths(m_destPolys, ptSubject, true); + if (delta > 0) + { + clpr.Execute(ctUnion, solution, pftPositive, pftPositive); + } + else + { + IntRect r = clpr.GetBounds(); + Path outer(4); + outer[0] = IntPoint(r.left - 10, r.bottom + 10); + outer[1] = IntPoint(r.right + 10, r.bottom + 10); + outer[2] = IntPoint(r.right + 10, r.top - 10); + outer[3] = IntPoint(r.left - 10, r.top - 10); + + clpr.AddPath(outer, ptSubject, true); + clpr.ReverseSolution(true); + clpr.Execute(ctUnion, solution, pftNegative, pftNegative); + //remove the outer PolyNode rectangle ... + if (solution.ChildCount() == 1 && solution.Childs[0]->ChildCount() > 0) + { + PolyNode* outerNode = solution.Childs[0]; + solution.Childs.reserve(outerNode->ChildCount()); + solution.Childs[0] = outerNode->Childs[0]; + solution.Childs[0]->Parent = outerNode->Parent; + for (int i = 1; i < outerNode->ChildCount(); ++i) + solution.AddChild(*outerNode->Childs[i]); + } + else + solution.Clear(); + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::DoOffset(double delta) +{ + m_destPolys.clear(); + m_delta = delta; + + //if Zero offset, just copy any CLOSED polygons to m_p and return ... + if (NEAR_ZERO(delta)) + { + m_destPolys.reserve(m_polyNodes.ChildCount()); + for (int i = 0; i < m_polyNodes.ChildCount(); i++) + { + PolyNode& node = *m_polyNodes.Childs[i]; + if (node.m_endtype == etClosedPolygon) + m_destPolys.push_back(node.Contour); + } + return; + } + + //see offset_triginometry3.svg in the documentation folder ... + if (MiterLimit > 2) m_miterLim = 2/(MiterLimit * MiterLimit); + else m_miterLim = 0.5; + + double y; + if (ArcTolerance <= 0.0) y = def_arc_tolerance; + else if (ArcTolerance > std::fabs(delta) * def_arc_tolerance) + y = std::fabs(delta) * def_arc_tolerance; + else y = ArcTolerance; + //see offset_triginometry2.svg in the documentation folder ... + double steps = pi / std::acos(1 - y / std::fabs(delta)); + if (steps > std::fabs(delta) * pi) + steps = std::fabs(delta) * pi; //ie excessive precision check + m_sin = std::sin(two_pi / steps); + m_cos = std::cos(two_pi / steps); + m_StepsPerRad = steps / two_pi; + if (delta < 0.0) m_sin = -m_sin; + + m_destPolys.reserve(m_polyNodes.ChildCount() * 2); + for (int i = 0; i < m_polyNodes.ChildCount(); i++) + { + PolyNode& node = *m_polyNodes.Childs[i]; + m_srcPoly = node.Contour; + + int len = (int)m_srcPoly.size(); + if (len == 0 || (delta <= 0 && (len < 3 || node.m_endtype != etClosedPolygon))) + continue; + + m_destPoly.clear(); + if (len == 1) + { + if (node.m_jointype == jtRound) + { + double X = 1.0, Y = 0.0; + for (cInt j = 1; j <= steps; j++) + { + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[0].X + X * delta), + Round(m_srcPoly[0].Y + Y * delta))); + double X2 = X; + X = X * m_cos - m_sin * Y; + Y = X2 * m_sin + Y * m_cos; + } + } + else + { + double X = -1.0, Y = -1.0; + for (int j = 0; j < 4; ++j) + { + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[0].X + X * delta), + Round(m_srcPoly[0].Y + Y * delta))); + if (X < 0) X = 1; + else if (Y < 0) Y = 1; + else X = -1; + } + } + m_destPolys.push_back(m_destPoly); + continue; + } + //build m_normals ... 
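+    //one unit normal per edge; closed paths also get the wrap-around normal
+    //(last vertex -> first vertex), open paths duplicate the last normal ...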
+ m_normals.clear(); + m_normals.reserve(len); + for (int j = 0; j < len - 1; ++j) + m_normals.push_back(GetUnitNormal(m_srcPoly[j], m_srcPoly[j + 1])); + if (node.m_endtype == etClosedLine || node.m_endtype == etClosedPolygon) + m_normals.push_back(GetUnitNormal(m_srcPoly[len - 1], m_srcPoly[0])); + else + m_normals.push_back(DoublePoint(m_normals[len - 2])); + + if (node.m_endtype == etClosedPolygon) + { + int k = len - 1; + for (int j = 0; j < len; ++j) + OffsetPoint(j, k, node.m_jointype); + m_destPolys.push_back(m_destPoly); + } + else if (node.m_endtype == etClosedLine) + { + int k = len - 1; + for (int j = 0; j < len; ++j) + OffsetPoint(j, k, node.m_jointype); + m_destPolys.push_back(m_destPoly); + m_destPoly.clear(); + //re-build m_normals ... + DoublePoint n = m_normals[len -1]; + for (int j = len - 1; j > 0; j--) + m_normals[j] = DoublePoint(-m_normals[j - 1].X, -m_normals[j - 1].Y); + m_normals[0] = DoublePoint(-n.X, -n.Y); + k = 0; + for (int j = len - 1; j >= 0; j--) + OffsetPoint(j, k, node.m_jointype); + m_destPolys.push_back(m_destPoly); + } + else + { + int k = 0; + for (int j = 1; j < len - 1; ++j) + OffsetPoint(j, k, node.m_jointype); + + IntPoint pt1; + if (node.m_endtype == etOpenButt) + { + int j = len - 1; + pt1 = IntPoint((cInt)Round(m_srcPoly[j].X + m_normals[j].X * + delta), (cInt)Round(m_srcPoly[j].Y + m_normals[j].Y * delta)); + m_destPoly.push_back(pt1); + pt1 = IntPoint((cInt)Round(m_srcPoly[j].X - m_normals[j].X * + delta), (cInt)Round(m_srcPoly[j].Y - m_normals[j].Y * delta)); + m_destPoly.push_back(pt1); + } + else + { + int j = len - 1; + k = len - 2; + m_sinA = 0; + m_normals[j] = DoublePoint(-m_normals[j].X, -m_normals[j].Y); + if (node.m_endtype == etOpenSquare) + DoSquare(j, k); + else + DoRound(j, k); + } + + //re-build m_normals ... + for (int j = len - 1; j > 0; j--) + m_normals[j] = DoublePoint(-m_normals[j - 1].X, -m_normals[j - 1].Y); + m_normals[0] = DoublePoint(-m_normals[1].X, -m_normals[1].Y); + + k = len - 1; + for (int j = k - 1; j > 0; --j) OffsetPoint(j, k, node.m_jointype); + + if (node.m_endtype == etOpenButt) + { + pt1 = IntPoint((cInt)Round(m_srcPoly[0].X - m_normals[0].X * delta), + (cInt)Round(m_srcPoly[0].Y - m_normals[0].Y * delta)); + m_destPoly.push_back(pt1); + pt1 = IntPoint((cInt)Round(m_srcPoly[0].X + m_normals[0].X * delta), + (cInt)Round(m_srcPoly[0].Y + m_normals[0].Y * delta)); + m_destPoly.push_back(pt1); + } + else + { + k = 1; + m_sinA = 0; + if (node.m_endtype == etOpenSquare) + DoSquare(0, 1); + else + DoRound(0, 1); + } + m_destPolys.push_back(m_destPoly); + } + } +} +//------------------------------------------------------------------------------ + +void ClipperOffset::OffsetPoint(int j, int& k, JoinType jointype) +{ + //cross product ... + m_sinA = (m_normals[k].X * m_normals[j].Y - m_normals[j].X * m_normals[k].Y); + if (std::fabs(m_sinA * m_delta) < 1.0) + { + //dot product ... 
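+    //cosA > 0 means the adjoining edges are nearly parallel (angle close to
+    //0 degrees) and a single offset point suffices; otherwise (close to 180
+    //degrees) fall through to the join logic below ...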
+ double cosA = (m_normals[k].X * m_normals[j].X + m_normals[j].Y * m_normals[k].Y ); + if (cosA > 0) // angle => 0 degrees + { + m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + m_normals[k].X * m_delta), + Round(m_srcPoly[j].Y + m_normals[k].Y * m_delta))); + return; + } + //else angle => 180 degrees + } + else if (m_sinA > 1.0) m_sinA = 1.0; + else if (m_sinA < -1.0) m_sinA = -1.0; + + if (m_sinA * m_delta < 0) + { + m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + m_normals[k].X * m_delta), + Round(m_srcPoly[j].Y + m_normals[k].Y * m_delta))); + m_destPoly.push_back(m_srcPoly[j]); + m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + m_normals[j].X * m_delta), + Round(m_srcPoly[j].Y + m_normals[j].Y * m_delta))); + } + else + switch (jointype) + { + case jtMiter: + { + double r = 1 + (m_normals[j].X * m_normals[k].X + + m_normals[j].Y * m_normals[k].Y); + if (r >= m_miterLim) DoMiter(j, k, r); else DoSquare(j, k); + break; + } + case jtSquare: DoSquare(j, k); break; + case jtRound: DoRound(j, k); break; + } + k = j; +} +//------------------------------------------------------------------------------ + +void ClipperOffset::DoSquare(int j, int k) +{ + double dx = std::tan(std::atan2(m_sinA, + m_normals[k].X * m_normals[j].X + m_normals[k].Y * m_normals[j].Y) / 4); + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[j].X + m_delta * (m_normals[k].X - m_normals[k].Y * dx)), + Round(m_srcPoly[j].Y + m_delta * (m_normals[k].Y + m_normals[k].X * dx)))); + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[j].X + m_delta * (m_normals[j].X + m_normals[j].Y * dx)), + Round(m_srcPoly[j].Y + m_delta * (m_normals[j].Y - m_normals[j].X * dx)))); +} +//------------------------------------------------------------------------------ + +void ClipperOffset::DoMiter(int j, int k, double r) +{ + double q = m_delta / r; + m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + (m_normals[k].X + m_normals[j].X) * q), + Round(m_srcPoly[j].Y + (m_normals[k].Y + m_normals[j].Y) * q))); +} +//------------------------------------------------------------------------------ + +void ClipperOffset::DoRound(int j, int k) +{ + double a = std::atan2(m_sinA, + m_normals[k].X * m_normals[j].X + m_normals[k].Y * m_normals[j].Y); + int steps = std::max((int)Round(m_StepsPerRad * std::fabs(a)), 1); + + double X = m_normals[k].X, Y = m_normals[k].Y, X2; + for (int i = 0; i < steps; ++i) + { + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[j].X + X * m_delta), + Round(m_srcPoly[j].Y + Y * m_delta))); + X2 = X; + X = X * m_cos - m_sin * Y; + Y = X2 * m_sin + Y * m_cos; + } + m_destPoly.push_back(IntPoint( + Round(m_srcPoly[j].X + m_normals[j].X * m_delta), + Round(m_srcPoly[j].Y + m_normals[j].Y * m_delta))); +} + +//------------------------------------------------------------------------------ +// Miscellaneous public functions +//------------------------------------------------------------------------------ + +void Clipper::DoSimplePolygons() +{ + PolyOutList::size_type i = 0; + while (i < m_PolyOuts.size()) + { + OutRec* outrec = m_PolyOuts[i++]; + OutPt* op = outrec->Pts; + if (!op || outrec->IsOpen) continue; + do //for each Pt in Polygon until duplicate found do ... + { + OutPt* op2 = op->Next; + while (op2 != outrec->Pts) + { + if ((op->Pt == op2->Pt) && op2->Next != op && op2->Prev != op) + { + //split the polygon into two ... 
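+            //relink the two loops around the duplicate point: op keeps the
+            //current OutRec while op2 seeds a new one; hole states and
+            //FirstLeft pointers are then fixed up below ...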
+ OutPt* op3 = op->Prev; + OutPt* op4 = op2->Prev; + op->Prev = op4; + op4->Next = op; + op2->Prev = op3; + op3->Next = op2; + + outrec->Pts = op; + OutRec* outrec2 = CreateOutRec(); + outrec2->Pts = op2; + UpdateOutPtIdxs(*outrec2); + if (Poly2ContainsPoly1(outrec2->Pts, outrec->Pts)) + { + //OutRec2 is contained by OutRec1 ... + outrec2->IsHole = !outrec->IsHole; + outrec2->FirstLeft = outrec; + if (m_UsingPolyTree) FixupFirstLefts2(outrec2, outrec); + } + else + if (Poly2ContainsPoly1(outrec->Pts, outrec2->Pts)) + { + //OutRec1 is contained by OutRec2 ... + outrec2->IsHole = outrec->IsHole; + outrec->IsHole = !outrec2->IsHole; + outrec2->FirstLeft = outrec->FirstLeft; + outrec->FirstLeft = outrec2; + if (m_UsingPolyTree) FixupFirstLefts2(outrec, outrec2); + } + else + { + //the 2 polygons are separate ... + outrec2->IsHole = outrec->IsHole; + outrec2->FirstLeft = outrec->FirstLeft; + if (m_UsingPolyTree) FixupFirstLefts1(outrec, outrec2); + } + op2 = op; //ie get ready for the Next iteration + } + op2 = op2->Next; + } + op = op->Next; + } + while (op != outrec->Pts); + } +} +//------------------------------------------------------------------------------ + +void ReversePath(Path& p) +{ + std::reverse(p.begin(), p.end()); +} +//------------------------------------------------------------------------------ + +void ReversePaths(Paths& p) +{ + for (Paths::size_type i = 0; i < p.size(); ++i) + ReversePath(p[i]); +} +//------------------------------------------------------------------------------ + +void SimplifyPolygon(const Path &in_poly, Paths &out_polys, PolyFillType fillType) +{ + Clipper c; + c.StrictlySimple(true); + c.AddPath(in_poly, ptSubject, true); + c.Execute(ctUnion, out_polys, fillType, fillType); +} +//------------------------------------------------------------------------------ + +void SimplifyPolygons(const Paths &in_polys, Paths &out_polys, PolyFillType fillType) +{ + Clipper c; + c.StrictlySimple(true); + c.AddPaths(in_polys, ptSubject, true); + c.Execute(ctUnion, out_polys, fillType, fillType); +} +//------------------------------------------------------------------------------ + +void SimplifyPolygons(Paths &polys, PolyFillType fillType) +{ + SimplifyPolygons(polys, polys, fillType); +} +//------------------------------------------------------------------------------ + +inline double DistanceSqrd(const IntPoint& pt1, const IntPoint& pt2) +{ + double Dx = ((double)pt1.X - pt2.X); + double dy = ((double)pt1.Y - pt2.Y); + return (Dx*Dx + dy*dy); +} +//------------------------------------------------------------------------------ + +double DistanceFromLineSqrd( + const IntPoint& pt, const IntPoint& ln1, const IntPoint& ln2) +{ + //The equation of a line in general form (Ax + By + C = 0) + //given 2 points (x¹,y¹) & (x²,y²) is ... 
+ //(y¹ - y²)x + (x² - x¹)y + (y² - y¹)x¹ - (x² - x¹)y¹ = 0 + //A = (y¹ - y²); B = (x² - x¹); C = (y² - y¹)x¹ - (x² - x¹)y¹ + //perpendicular distance of point (x³,y³) = (Ax³ + By³ + C)/Sqrt(A² + B²) + //see http://en.wikipedia.org/wiki/Perpendicular_distance + double A = double(ln1.Y - ln2.Y); + double B = double(ln2.X - ln1.X); + double C = A * ln1.X + B * ln1.Y; + C = A * pt.X + B * pt.Y - C; + return (C * C) / (A * A + B * B); +} +//--------------------------------------------------------------------------- + +bool SlopesNearCollinear(const IntPoint& pt1, + const IntPoint& pt2, const IntPoint& pt3, double distSqrd) +{ + //this function is more accurate when the point that's geometrically + //between the other 2 points is the one that's tested for distance. + //ie makes it more likely to pick up 'spikes' ... + if (Abs(pt1.X - pt2.X) > Abs(pt1.Y - pt2.Y)) + { + if ((pt1.X > pt2.X) == (pt1.X < pt3.X)) + return DistanceFromLineSqrd(pt1, pt2, pt3) < distSqrd; + else if ((pt2.X > pt1.X) == (pt2.X < pt3.X)) + return DistanceFromLineSqrd(pt2, pt1, pt3) < distSqrd; + else + return DistanceFromLineSqrd(pt3, pt1, pt2) < distSqrd; + } + else + { + if ((pt1.Y > pt2.Y) == (pt1.Y < pt3.Y)) + return DistanceFromLineSqrd(pt1, pt2, pt3) < distSqrd; + else if ((pt2.Y > pt1.Y) == (pt2.Y < pt3.Y)) + return DistanceFromLineSqrd(pt2, pt1, pt3) < distSqrd; + else + return DistanceFromLineSqrd(pt3, pt1, pt2) < distSqrd; + } +} +//------------------------------------------------------------------------------ + +bool PointsAreClose(IntPoint pt1, IntPoint pt2, double distSqrd) +{ + double Dx = (double)pt1.X - pt2.X; + double dy = (double)pt1.Y - pt2.Y; + return ((Dx * Dx) + (dy * dy) <= distSqrd); +} +//------------------------------------------------------------------------------ + +OutPt* ExcludeOp(OutPt* op) +{ + OutPt* result = op->Prev; + result->Next = op->Next; + op->Next->Prev = result; + result->Idx = 0; + return result; +} +//------------------------------------------------------------------------------ + +void CleanPolygon(const Path& in_poly, Path& out_poly, double distance) +{ + //distance = proximity in units/pixels below which vertices + //will be stripped. Default ~= sqrt(2). 
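+//eg with the default distance (~1.415) two neighbouring vertices that differ
+//by at most one unit in both X and Y (sqrt(2) ~= 1.414) will be merged ...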
+ + size_t size = in_poly.size(); + + if (size == 0) + { + out_poly.clear(); + return; + } + + OutPt* outPts = new OutPt[size]; + for (size_t i = 0; i < size; ++i) + { + outPts[i].Pt = in_poly[i]; + outPts[i].Next = &outPts[(i + 1) % size]; + outPts[i].Next->Prev = &outPts[i]; + outPts[i].Idx = 0; + } + + double distSqrd = distance * distance; + OutPt* op = &outPts[0]; + while (op->Idx == 0 && op->Next != op->Prev) + { + if (PointsAreClose(op->Pt, op->Prev->Pt, distSqrd)) + { + op = ExcludeOp(op); + size--; + } + else if (PointsAreClose(op->Prev->Pt, op->Next->Pt, distSqrd)) + { + ExcludeOp(op->Next); + op = ExcludeOp(op); + size -= 2; + } + else if (SlopesNearCollinear(op->Prev->Pt, op->Pt, op->Next->Pt, distSqrd)) + { + op = ExcludeOp(op); + size--; + } + else + { + op->Idx = 1; + op = op->Next; + } + } + + if (size < 3) size = 0; + out_poly.resize(size); + for (size_t i = 0; i < size; ++i) + { + out_poly[i] = op->Pt; + op = op->Next; + } + delete [] outPts; +} +//------------------------------------------------------------------------------ + +void CleanPolygon(Path& poly, double distance) +{ + CleanPolygon(poly, poly, distance); +} +//------------------------------------------------------------------------------ + +void CleanPolygons(const Paths& in_polys, Paths& out_polys, double distance) +{ + out_polys.resize(in_polys.size()); + for (Paths::size_type i = 0; i < in_polys.size(); ++i) + CleanPolygon(in_polys[i], out_polys[i], distance); +} +//------------------------------------------------------------------------------ + +void CleanPolygons(Paths& polys, double distance) +{ + CleanPolygons(polys, polys, distance); +} +//------------------------------------------------------------------------------ + +void Minkowski(const Path& poly, const Path& path, + Paths& solution, bool isSum, bool isClosed) +{ + int delta = (isClosed ? 
1 : 0); + size_t polyCnt = poly.size(); + size_t pathCnt = path.size(); + Paths pp; + pp.reserve(pathCnt); + if (isSum) + for (size_t i = 0; i < pathCnt; ++i) + { + Path p; + p.reserve(polyCnt); + for (size_t j = 0; j < poly.size(); ++j) + p.push_back(IntPoint(path[i].X + poly[j].X, path[i].Y + poly[j].Y)); + pp.push_back(p); + } + else + for (size_t i = 0; i < pathCnt; ++i) + { + Path p; + p.reserve(polyCnt); + for (size_t j = 0; j < poly.size(); ++j) + p.push_back(IntPoint(path[i].X - poly[j].X, path[i].Y - poly[j].Y)); + pp.push_back(p); + } + + solution.clear(); + solution.reserve((pathCnt + delta) * (polyCnt + 1)); + for (size_t i = 0; i < pathCnt - 1 + delta; ++i) + for (size_t j = 0; j < polyCnt; ++j) + { + Path quad; + quad.reserve(4); + quad.push_back(pp[i % pathCnt][j % polyCnt]); + quad.push_back(pp[(i + 1) % pathCnt][j % polyCnt]); + quad.push_back(pp[(i + 1) % pathCnt][(j + 1) % polyCnt]); + quad.push_back(pp[i % pathCnt][(j + 1) % polyCnt]); + if (!Orientation(quad)) ReversePath(quad); + solution.push_back(quad); + } +} +//------------------------------------------------------------------------------ + +void MinkowskiSum(const Path& pattern, const Path& path, Paths& solution, bool pathIsClosed) +{ + Minkowski(pattern, path, solution, true, pathIsClosed); + Clipper c; + c.AddPaths(solution, ptSubject, true); + c.Execute(ctUnion, solution, pftNonZero, pftNonZero); +} +//------------------------------------------------------------------------------ + +void TranslatePath(const Path& input, Path& output, const IntPoint delta) +{ + //precondition: input != output + output.resize(input.size()); + for (size_t i = 0; i < input.size(); ++i) + output[i] = IntPoint(input[i].X + delta.X, input[i].Y + delta.Y); +} +//------------------------------------------------------------------------------ + +void MinkowskiSum(const Path& pattern, const Paths& paths, Paths& solution, bool pathIsClosed) +{ + Clipper c; + for (size_t i = 0; i < paths.size(); ++i) + { + Paths tmp; + Minkowski(pattern, paths[i], tmp, true, pathIsClosed); + c.AddPaths(tmp, ptSubject, true); + if (pathIsClosed) + { + Path tmp2; + TranslatePath(paths[i], tmp2, pattern[0]); + c.AddPath(tmp2, ptClip, true); + } + } + c.Execute(ctUnion, solution, pftNonZero, pftNonZero); +} +//------------------------------------------------------------------------------ + +void MinkowskiDiff(const Path& poly1, const Path& poly2, Paths& solution) +{ + Minkowski(poly1, poly2, solution, false, true); + Clipper c; + c.AddPaths(solution, ptSubject, true); + c.Execute(ctUnion, solution, pftNonZero, pftNonZero); +} +//------------------------------------------------------------------------------ + +enum NodeType {ntAny, ntOpen, ntClosed}; + +void AddPolyNodeToPaths(const PolyNode& polynode, NodeType nodetype, Paths& paths) +{ + bool match = true; + if (nodetype == ntClosed) match = !polynode.IsOpen(); + else if (nodetype == ntOpen) return; + + if (!polynode.Contour.empty() && match) + paths.push_back(polynode.Contour); + for (int i = 0; i < polynode.ChildCount(); ++i) + AddPolyNodeToPaths(*polynode.Childs[i], nodetype, paths); +} +//------------------------------------------------------------------------------ + +void PolyTreeToPaths(const PolyTree& polytree, Paths& paths) +{ + paths.resize(0); + paths.reserve(polytree.Total()); + AddPolyNodeToPaths(polytree, ntAny, paths); +} +//------------------------------------------------------------------------------ + +void ClosedPathsFromPolyTree(const PolyTree& polytree, Paths& paths) +{ + paths.resize(0); 
+ paths.reserve(polytree.Total()); + AddPolyNodeToPaths(polytree, ntClosed, paths); +} +//------------------------------------------------------------------------------ + +void OpenPathsFromPolyTree(PolyTree& polytree, Paths& paths) +{ + paths.resize(0); + paths.reserve(polytree.Total()); + //Open paths are top level only, so ... + for (int i = 0; i < polytree.ChildCount(); ++i) + if (polytree.Childs[i]->IsOpen()) + paths.push_back(polytree.Childs[i]->Contour); +} +//------------------------------------------------------------------------------ + +std::ostream& operator <<(std::ostream &s, const IntPoint &p) +{ + s << "(" << p.X << "," << p.Y << ")"; + return s; +} +//------------------------------------------------------------------------------ + +std::ostream& operator <<(std::ostream &s, const Path &p) +{ + if (p.empty()) return s; + Path::size_type last = p.size() -1; + for (Path::size_type i = 0; i < last; i++) + s << "(" << p[i].X << "," << p[i].Y << "), "; + s << "(" << p[last].X << "," << p[last].Y << ")\n"; + return s; +} +//------------------------------------------------------------------------------ + +std::ostream& operator <<(std::ostream &s, const Paths &p) +{ + for (Paths::size_type i = 0; i < p.size(); i++) + s << p[i]; + s << "\n"; + return s; +} +//------------------------------------------------------------------------------ + +} //ClipperLib namespace diff --git a/mmocr/models/textdet/postprocess/include/clipper/clipper.hpp b/mmocr/models/textdet/postprocess/include/clipper/clipper.hpp new file mode 100644 index 00000000..2ac51182 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/clipper/clipper.hpp @@ -0,0 +1,402 @@ +/******************************************************************************* +* * +* Author : Angus Johnson * +* Version : 6.4.0 * +* Date : 2 July 2015 * +* Website : http://www.angusj.com * +* Copyright : Angus Johnson 2010-2015 * +* * +* License: * +* Use, modification & distribution is subject to Boost Software License Ver 1. * +* http://www.boost.org/LICENSE_1_0.txt * +* * +* Attributions: * +* The code in this library is an extension of Bala Vatti's clipping algorithm: * +* "A generic solution to polygon clipping" * +* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. * +* http://portal.acm.org/citation.cfm?id=129906 * +* * +* Computer graphics and geometric modeling: implementation and algorithms * +* By Max K. Agoston * +* Springer; 1 edition (January 4, 2005) * +* http://books.google.com/books?q=vatti+clipping+agoston * +* * +* See also: * +* "Polygon Offsetting by Computing Winding Numbers" * +* Paper no. DETC2005-85513 pp. 565-575 * +* ASME 2005 International Design Engineering Technical Conferences * +* and Computers and Information in Engineering Conference (IDETC/CIE2005) * +* September 24-28, 2005 , Long Beach, California, USA * +* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf * +* * +*******************************************************************************/ + +#ifndef clipper_hpp +#define clipper_hpp + +#define CLIPPER_VERSION "6.2.6" + +//use_int32: When enabled 32bit ints are used instead of 64bit ints. This +//improve performance but coordinate values are limited to the range +/- 46340 +//#define use_int32 + +//use_xyz: adds a Z member to IntPoint. Adds a minor cost to perfomance. +//#define use_xyz + +//use_lines: Enables line clipping. Adds a very minor cost to performance. 
+#define use_lines
+
+//use_deprecated: Enables temporary support for the obsolete functions
+//#define use_deprecated
+
+#include <vector>
+#include <list>
+#include <set>
+#include <stdexcept>
+#include <cstring>
+#include <cstdlib>
+#include <ostream>
+#include <functional>
+#include <queue>
+
+namespace ClipperLib {
+
+enum ClipType { ctIntersection, ctUnion, ctDifference, ctXor };
+enum PolyType { ptSubject, ptClip };
+//By far the most widely used winding rules for polygon filling are
+//EvenOdd & NonZero (GDI, GDI+, XLib, OpenGL, Cairo, AGG, Quartz, SVG, Gr32)
+//Other rules include Positive, Negative and ABS_GTR_EQ_TWO (only in OpenGL)
+//see http://glprogramming.com/red/chapter11.html
+enum PolyFillType { pftEvenOdd, pftNonZero, pftPositive, pftNegative };
+
+#ifdef use_int32
+  typedef int cInt;
+  static cInt const loRange = 0x7FFF;
+  static cInt const hiRange = 0x7FFF;
+#else
+  typedef signed long long cInt;
+  static cInt const loRange = 0x3FFFFFFF;
+  static cInt const hiRange = 0x3FFFFFFFFFFFFFFFLL;
+  typedef signed long long long64; //used by Int128 class
+  typedef unsigned long long ulong64;
+
+#endif
+
+struct IntPoint {
+  cInt X;
+  cInt Y;
+#ifdef use_xyz
+  cInt Z;
+  IntPoint(cInt x = 0, cInt y = 0, cInt z = 0): X(x), Y(y), Z(z) {};
+#else
+  IntPoint(cInt x = 0, cInt y = 0): X(x), Y(y) {};
+#endif
+
+  friend inline bool operator== (const IntPoint& a, const IntPoint& b)
+  {
+    return a.X == b.X && a.Y == b.Y;
+  }
+  friend inline bool operator!= (const IntPoint& a, const IntPoint& b)
+  {
+    return a.X != b.X || a.Y != b.Y;
+  }
+};
+//------------------------------------------------------------------------------
+
+typedef std::vector< IntPoint > Path;
+typedef std::vector< Path > Paths;
+
+inline Path& operator <<(Path& poly, const IntPoint& p) {poly.push_back(p); return poly;}
+inline Paths& operator <<(Paths& polys, const Path& p) {polys.push_back(p); return polys;}
+
+std::ostream& operator <<(std::ostream &s, const IntPoint &p);
+std::ostream& operator <<(std::ostream &s, const Path &p);
+std::ostream& operator <<(std::ostream &s, const Paths &p);
+
+struct DoublePoint
+{
+  double X;
+  double Y;
+  DoublePoint(double x = 0, double y = 0) : X(x), Y(y) {}
+  DoublePoint(IntPoint ip) : X((double)ip.X), Y((double)ip.Y) {}
+};
+//------------------------------------------------------------------------------
+
+#ifdef use_xyz
+typedef void (*ZFillCallback)(IntPoint& e1bot, IntPoint& e1top, IntPoint& e2bot, IntPoint& e2top, IntPoint& pt);
+#endif
+
+enum InitOptions {ioReverseSolution = 1, ioStrictlySimple = 2, ioPreserveCollinear = 4};
+enum JoinType {jtSquare, jtRound, jtMiter};
+enum EndType {etClosedPolygon, etClosedLine, etOpenButt, etOpenSquare, etOpenRound};
+
+class PolyNode;
+typedef std::vector< PolyNode* > PolyNodes;
+
+class PolyNode
+{
+public:
+    PolyNode();
+    virtual ~PolyNode(){};
+    Path Contour;
+    PolyNodes Childs;
+    PolyNode* Parent;
+    PolyNode* GetNext() const;
+    bool IsHole() const;
+    bool IsOpen() const;
+    int ChildCount() const;
+private:
+    unsigned Index; //node index in Parent.Childs
+    bool m_IsOpen;
+    JoinType m_jointype;
+    EndType m_endtype;
+    PolyNode* GetNextSiblingUp() const;
+    void AddChild(PolyNode& child);
+    friend class Clipper; //to access Index
+    friend class ClipperOffset;
+};
+
+class PolyTree: public PolyNode
+{
+public:
+    ~PolyTree(){Clear();};
+    PolyNode* GetFirst() const;
+    void Clear();
+    int Total() const;
+private:
+    PolyNodes AllNodes;
+    friend class Clipper; //to access AllNodes
+};
+
+bool Orientation(const Path &poly);
+double Area(const Path &poly);
+int PointInPolygon(const IntPoint &pt, const Path &path);
+
+void
SimplifyPolygon(const Path &in_poly, Paths &out_polys, PolyFillType fillType = pftEvenOdd); +void SimplifyPolygons(const Paths &in_polys, Paths &out_polys, PolyFillType fillType = pftEvenOdd); +void SimplifyPolygons(Paths &polys, PolyFillType fillType = pftEvenOdd); + +void CleanPolygon(const Path& in_poly, Path& out_poly, double distance = 1.415); +void CleanPolygon(Path& poly, double distance = 1.415); +void CleanPolygons(const Paths& in_polys, Paths& out_polys, double distance = 1.415); +void CleanPolygons(Paths& polys, double distance = 1.415); + +void MinkowskiSum(const Path& pattern, const Path& path, Paths& solution, bool pathIsClosed); +void MinkowskiSum(const Path& pattern, const Paths& paths, Paths& solution, bool pathIsClosed); +void MinkowskiDiff(const Path& poly1, const Path& poly2, Paths& solution); + +void PolyTreeToPaths(const PolyTree& polytree, Paths& paths); +void ClosedPathsFromPolyTree(const PolyTree& polytree, Paths& paths); +void OpenPathsFromPolyTree(PolyTree& polytree, Paths& paths); + +void ReversePath(Path& p); +void ReversePaths(Paths& p); + +struct IntRect { cInt left; cInt top; cInt right; cInt bottom; }; + +//enums that are used internally ... +enum EdgeSide { esLeft = 1, esRight = 2}; + +//forward declarations (for stuff used internally) ... +struct TEdge; +struct IntersectNode; +struct LocalMinimum; +struct OutPt; +struct OutRec; +struct Join; + +typedef std::vector < OutRec* > PolyOutList; +typedef std::vector < TEdge* > EdgeList; +typedef std::vector < Join* > JoinList; +typedef std::vector < IntersectNode* > IntersectList; + +//------------------------------------------------------------------------------ + +//ClipperBase is the ancestor to the Clipper class. It should not be +//instantiated directly. This class simply abstracts the conversion of sets of +//polygon coordinates into edge objects that are stored in a LocalMinima list. 
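+//nb: typical use goes through the derived Clipper class, eg (an illustrative
+//sketch only - the two rectangles below are arbitrary placeholder data):
+//
+//  Path subj, clip;
+//  subj << IntPoint(0, 0) << IntPoint(10, 0) << IntPoint(10, 10) << IntPoint(0, 10);
+//  clip << IntPoint(5, 5) << IntPoint(15, 5) << IntPoint(15, 15) << IntPoint(5, 15);
+//  Clipper c;
+//  c.AddPath(subj, ptSubject, true);
+//  c.AddPath(clip, ptClip, true);
+//  Paths solution;
+//  c.Execute(ctIntersection, solution, pftEvenOdd, pftEvenOdd);
+//  //solution now holds the single overlap square (5,5)..(10,10)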
+class ClipperBase
+{
+public:
+  ClipperBase();
+  virtual ~ClipperBase();
+  virtual bool AddPath(const Path &pg, PolyType PolyTyp, bool Closed);
+  bool AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed);
+  virtual void Clear();
+  IntRect GetBounds();
+  bool PreserveCollinear() {return m_PreserveCollinear;};
+  void PreserveCollinear(bool value) {m_PreserveCollinear = value;};
+protected:
+  void DisposeLocalMinimaList();
+  TEdge* AddBoundsToLML(TEdge *e, bool IsClosed);
+  virtual void Reset();
+  TEdge* ProcessBound(TEdge* E, bool IsClockwise);
+  void InsertScanbeam(const cInt Y);
+  bool PopScanbeam(cInt &Y);
+  bool LocalMinimaPending();
+  bool PopLocalMinima(cInt Y, const LocalMinimum *&locMin);
+  OutRec* CreateOutRec();
+  void DisposeAllOutRecs();
+  void DisposeOutRec(PolyOutList::size_type index);
+  void SwapPositionsInAEL(TEdge *edge1, TEdge *edge2);
+  void DeleteFromAEL(TEdge *e);
+  void UpdateEdgeIntoAEL(TEdge *&e);
+
+  typedef std::vector<LocalMinimum> MinimaList;
+  MinimaList::iterator m_CurrentLM;
+  MinimaList m_MinimaList;
+
+  bool m_UseFullRange;
+  EdgeList m_edges;
+  bool m_PreserveCollinear;
+  bool m_HasOpenPaths;
+  PolyOutList m_PolyOuts;
+  TEdge *m_ActiveEdges;
+
+  typedef std::priority_queue<cInt> ScanbeamList;
+  ScanbeamList m_Scanbeam;
+};
+//------------------------------------------------------------------------------
+
+class Clipper : public virtual ClipperBase
+{
+public:
+  Clipper(int initOptions = 0);
+  bool Execute(ClipType clipType,
+      Paths &solution,
+      PolyFillType fillType = pftEvenOdd);
+  bool Execute(ClipType clipType,
+      Paths &solution,
+      PolyFillType subjFillType,
+      PolyFillType clipFillType);
+  bool Execute(ClipType clipType,
+      PolyTree &polytree,
+      PolyFillType fillType = pftEvenOdd);
+  bool Execute(ClipType clipType,
+      PolyTree &polytree,
+      PolyFillType subjFillType,
+      PolyFillType clipFillType);
+  bool ReverseSolution() { return m_ReverseOutput; };
+  void ReverseSolution(bool value) {m_ReverseOutput = value;};
+  bool StrictlySimple() {return m_StrictSimple;};
+  void StrictlySimple(bool value) {m_StrictSimple = value;};
+  //set the callback function for z value filling on intersections (otherwise Z is 0)
+#ifdef use_xyz
+  void ZFillFunction(ZFillCallback zFillFunc);
+#endif
+protected:
+  virtual bool ExecuteInternal();
+private:
+  JoinList m_Joins;
+  JoinList m_GhostJoins;
+  IntersectList m_IntersectList;
+  ClipType m_ClipType;
+  typedef std::list<cInt> MaximaList;
+  MaximaList m_Maxima;
+  TEdge *m_SortedEdges;
+  bool m_ExecuteLocked;
+  PolyFillType m_ClipFillType;
+  PolyFillType m_SubjFillType;
+  bool m_ReverseOutput;
+  bool m_UsingPolyTree;
+  bool m_StrictSimple;
+#ifdef use_xyz
+  ZFillCallback m_ZFill; //custom callback
+#endif
+  void SetWindingCount(TEdge& edge);
+  bool IsEvenOddFillType(const TEdge& edge) const;
+  bool IsEvenOddAltFillType(const TEdge& edge) const;
+  void InsertLocalMinimaIntoAEL(const cInt botY);
+  void InsertEdgeIntoAEL(TEdge *edge, TEdge* startEdge);
+  void AddEdgeToSEL(TEdge *edge);
+  bool PopEdgeFromSEL(TEdge *&edge);
+  void CopyAELToSEL();
+  void DeleteFromSEL(TEdge *e);
+  void SwapPositionsInSEL(TEdge *edge1, TEdge *edge2);
+  bool IsContributing(const TEdge& edge) const;
+  bool IsTopHorz(const cInt XPos);
+  void DoMaxima(TEdge *e);
+  void ProcessHorizontals();
+  void ProcessHorizontal(TEdge *horzEdge);
+  void AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
+  OutPt* AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
+  OutRec* GetOutRec(int idx);
+  void AppendPolygon(TEdge *e1, TEdge *e2);
+  void IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &pt);
+  OutPt* AddOutPt(TEdge *e, const IntPoint &pt);
+  OutPt* GetLastOutPt(TEdge *e);
+  bool ProcessIntersections(const cInt topY);
+  void BuildIntersectList(const cInt topY);
+  void ProcessIntersectList();
+  void ProcessEdgesAtTopOfScanbeam(const cInt topY);
+  void BuildResult(Paths& polys);
+  void BuildResult2(PolyTree& polytree);
+  void SetHoleState(TEdge *e, OutRec *outrec);
+  void DisposeIntersectNodes();
+  bool FixupIntersectionOrder();
+  void FixupOutPolygon(OutRec &outrec);
+  void FixupOutPolyline(OutRec &outrec);
+  bool IsHole(TEdge *e);
+  bool FindOwnerFromSplitRecs(OutRec &outRec, OutRec *&currOrfl);
+  void FixHoleLinkage(OutRec &outrec);
+  void AddJoin(OutPt *op1, OutPt *op2, const IntPoint offPt);
+  void ClearJoins();
+  void ClearGhostJoins();
+  void AddGhostJoin(OutPt *op, const IntPoint offPt);
+  bool JoinPoints(Join *j, OutRec* outRec1, OutRec* outRec2);
+  void JoinCommonEdges();
+  void DoSimplePolygons();
+  void FixupFirstLefts1(OutRec* OldOutRec, OutRec* NewOutRec);
+  void FixupFirstLefts2(OutRec* InnerOutRec, OutRec* OuterOutRec);
+  void FixupFirstLefts3(OutRec* OldOutRec, OutRec* NewOutRec);
+#ifdef use_xyz
+  void SetZ(IntPoint& pt, TEdge& e1, TEdge& e2);
+#endif
+};
+//------------------------------------------------------------------------------
+
+class ClipperOffset
+{
+public:
+  ClipperOffset(double miterLimit = 2.0, double roundPrecision = 0.25);
+  ~ClipperOffset();
+  void AddPath(const Path& path, JoinType joinType, EndType endType);
+  void AddPaths(const Paths& paths, JoinType joinType, EndType endType);
+  void Execute(Paths& solution, double delta);
+  void Execute(PolyTree& solution, double delta);
+  void Clear();
+  double MiterLimit;
+  double ArcTolerance;
+private:
+  Paths m_destPolys;
+  Path m_srcPoly;
+  Path m_destPoly;
+  std::vector<DoublePoint> m_normals;
+  double m_delta, m_sinA, m_sin, m_cos;
+  double m_miterLim, m_StepsPerRad;
+  IntPoint m_lowest;
+  PolyNode m_polyNodes;
+
+  void FixOrientations();
+  void DoOffset(double delta);
+  void OffsetPoint(int j, int& k, JoinType jointype);
+  void DoSquare(int j, int k);
+  void DoMiter(int j, int k, double r);
+  void DoRound(int j, int k);
+};
+//------------------------------------------------------------------------------
+
+class clipperException : public std::exception
+{
+  public:
+    clipperException(const char* description): m_descr(description) {}
+    virtual ~clipperException() throw() {}
+    virtual const char* what() const throw() {return m_descr.c_str();}
+  private:
+    std::string m_descr;
+};
+//------------------------------------------------------------------------------
+
+} //ClipperLib namespace
+
+#endif //clipper_hpp
diff --git a/mmocr/models/textdet/postprocess/include/pybind11/attr.h b/mmocr/models/textdet/postprocess/include/pybind11/attr.h
new file mode 100644
index 00000000..8732cfe1
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/include/pybind11/attr.h
@@ -0,0 +1,492 @@
+/*
+    pybind11/attr.h: Infrastructure for processing custom
+    type and function attributes
+
+    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+    All rights reserved. Use of this source code is governed by a
+    BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "cast.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+/// \addtogroup annotations
+/// @{
+
+/// Annotation for methods
+struct is_method { handle class_; is_method(const handle &c) : class_(c) { } };
+
+/// Annotation for operators
+struct is_operator { };
+
+/// Annotation for parent scope
+struct scope { handle value; scope(const handle &s) : value(s) { } };
+
+/// Annotation for documentation
+struct doc { const char *value; doc(const char *value) : value(value) { } };
+
+/// Annotation for function names
+struct name { const char *value; name(const char *value) : value(value) { } };
+
+/// Annotation indicating that a function is an overload associated with a given "sibling"
+struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } };
+
+/// Annotation indicating that a class derives from another given type
+template <typename T> struct base {
+    PYBIND11_DEPRECATED("base<T>() was deprecated in favor of specifying 'T' as a template argument to class_")
+    base() { }
+};
+
+/// Keep patient alive while nurse lives
+template <size_t Nurse, size_t Patient> struct keep_alive { };
+
+/// Annotation indicating that a class is involved in a multiple inheritance relationship
+struct multiple_inheritance { };
+
+/// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class
+struct dynamic_attr { };
+
+/// Annotation which enables the buffer protocol for a type
+struct buffer_protocol { };
+
+/// Annotation which requests that a special metaclass is created for a type
+struct metaclass {
+    handle value;
+
+    PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.")
+    metaclass() {}
+
+    /// Override pybind11's default metaclass
+    explicit metaclass(handle value) : value(value) { }
+};
+
+/// Annotation that marks a class as local to the module:
+struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } };
+
+/// Annotation to mark enums as an arithmetic type
+struct arithmetic { };
+
+/** \rst
+    A call policy which places one or more guard variables (``Ts...``) around the function call.
+
+    For example, this definition:
+
+    .. code-block:: cpp
+
+        m.def("foo", foo, py::call_guard<T>());
+
+    is equivalent to the following pseudocode:
+
+    .. code-block:: cpp
+
+        m.def("foo", [](args...) {
+            T scope_guard;
+            return foo(args...); // forwarded arguments
+        });
+ \endrst */
+template <typename... Ts> struct call_guard;
+
+template <> struct call_guard<> { using type = detail::void_type; };
+
+template <typename T>
+struct call_guard<T> {
+    static_assert(std::is_default_constructible<T>::value,
+                  "The guard type must be default constructible");
+
+    using type = T;
+};
+
+template <typename T, typename... Ts>
+struct call_guard<T, Ts...> {
+    struct type {
+        T guard{}; // Compose multiple guard types with left-to-right default-constructor order
+        typename call_guard<Ts...>::type next{};
+    };
+};
+
+/// @} annotations
+
+NAMESPACE_BEGIN(detail)
+/* Forward declarations */
+enum op_id : int;
+enum op_type : int;
+struct undefined_t;
+template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined_t> struct op_;
+inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret);
+
+/// Internal data structure which holds metadata about a keyword argument
+struct argument_record {
+    const char *name; ///< Argument name
+    const char *descr; ///< Human-readable version of the argument value
+    handle value; ///< Associated Python object
+    bool convert : 1; ///< True if the argument is allowed to convert when loading
+    bool none : 1; ///< True if None is allowed when loading
+
+    argument_record(const char *name, const char *descr, handle value, bool convert, bool none)
+        : name(name), descr(descr), value(value), convert(convert), none(none) { }
+};
+
+/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
+struct function_record {
+    function_record()
+        : is_constructor(false), is_new_style_constructor(false), is_stateless(false),
+          is_operator(false), has_args(false), has_kwargs(false), is_method(false) { }
+
+    /// Function name
+    char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
+
+    // User-specified documentation string
+    char *doc = nullptr;
+
+    /// Human-readable version of the function signature
+    char *signature = nullptr;
+
+    /// List of registered keyword arguments
+    std::vector<argument_record> args;
+
+    /// Pointer to lambda function which converts arguments and performs the actual call
+    handle (*impl) (function_call &) = nullptr;
+
+    /// Storage for the wrapped function pointer and captured data, if any
+    void *data[3] = { };
+
+    /// Pointer to custom destructor for 'data' (if needed)
+    void (*free_data) (function_record *ptr) = nullptr;
+
+    /// Return value policy associated with this function
+    return_value_policy policy = return_value_policy::automatic;
+
+    /// True if name == '__init__'
+    bool is_constructor : 1;
+
+    /// True if this is a new-style `__init__` defined in `detail/init.h`
+    bool is_new_style_constructor : 1;
+
+    /// True if this is a stateless function pointer
+    bool is_stateless : 1;
+
+    /// True if this is an operator (__add__), etc.
+    bool is_operator : 1;
+
+    /// True if the function has a '*args' argument
+    bool has_args : 1;
+
+    /// True if the function has a '**kwargs' argument
+    bool has_kwargs : 1;
+
+    /// True if this is a method
+    bool is_method : 1;
+
+    /// Number of arguments (including py::args and/or py::kwargs, if present)
+    std::uint16_t nargs;
+
+    /// Python method object
+    PyMethodDef *def = nullptr;
+
+    /// Python handle to the parent scope (a class or a module)
+    handle scope;
+
+    /// Python handle to the sibling function representing an overload chain
+    handle sibling;
+
+    /// Pointer to next overload
+    function_record *next = nullptr;
+};
+
+/// Special data structure which (temporarily) holds metadata about a bound class
+struct type_record {
+    PYBIND11_NOINLINE type_record()
+        : multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false),
+          default_holder(true), module_local(false) { }
+
+    /// Handle to the parent scope
+    handle scope;
+
+    /// Name of the class
+    const char *name = nullptr;
+
+    /// Pointer to RTTI type_info data structure
+    const std::type_info *type = nullptr;
+
+    /// How large is the underlying C++ type?
+    size_t type_size = 0;
+
+    /// What is the alignment of the underlying C++ type?
+    size_t type_align = 0;
+
+    /// How large is the type's holder?
+    size_t holder_size = 0;
+
+    /// The global operator new can be overridden with a class-specific variant
+    void *(*operator_new)(size_t) = nullptr;
+
+    /// Function pointer to class_<..>::init_instance
+    void (*init_instance)(instance *, const void *) = nullptr;
+
+    /// Function pointer to class_<..>::dealloc
+    void (*dealloc)(detail::value_and_holder &) = nullptr;
+
+    /// List of base classes of the newly created type
+    list bases;
+
+    /// Optional docstring
+    const char *doc = nullptr;
+
+    /// Custom metaclass (optional)
+    handle metaclass;
+
+    /// Multiple inheritance marker
+    bool multiple_inheritance : 1;
+
+    /// Does the class manage a __dict__?
+    bool dynamic_attr : 1;
+
+    /// Does the class implement the buffer protocol?
+    bool buffer_protocol : 1;
+
+    /// Is the default (unique_ptr) holder type used?
+    bool default_holder : 1;
+
+    /// Is the class definition local to the module shared object?
+    bool module_local : 1;
+
+    PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
+        auto base_info = detail::get_type_info(base, false);
+        if (!base_info) {
+            std::string tname(base.name());
+            detail::clean_type_id(tname);
+            pybind11_fail("generic_type: type \"" + std::string(name) +
+                          "\" referenced unknown base type \"" + tname + "\"");
+        }
+
+        if (default_holder != base_info->default_holder) {
+            std::string tname(base.name());
+            detail::clean_type_id(tname);
+            pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
+                    (default_holder ? "does not have" : "has") +
+                    " a non-default holder type while its base \"" + tname + "\" " +
+                    (base_info->default_holder ? "does not" : "does"));
+        }
+
+        bases.append((PyObject *) base_info->type);
+
+        if (base_info->type->tp_dictoffset != 0)
+            dynamic_attr = true;
+
+        if (caster)
+            base_info->implicit_casts.emplace_back(type, caster);
+    }
+};
+
+inline function_call::function_call(const function_record &f, handle p) :
+        func(f), parent(p) {
+    args.reserve(f.nargs);
+    args_convert.reserve(f.nargs);
+}
+
+/// Tag for a new-style `__init__` defined in `detail/init.h`
+struct is_new_style_constructor { };
+
+/**
+ * Partial template specializations to process custom attributes provided to
+ * cpp_function_ and class_. These are either used to initialize the respective
+ * fields in the type_record and function_record data structures or executed at
+ * runtime to deal with custom call policies (e.g. keep_alive).
+ */
+template <typename T, typename SFINAE = void> struct process_attribute;
+
+template <typename T> struct process_attribute_default {
+    /// Default implementation: do nothing
+    static void init(const T &, function_record *) { }
+    static void init(const T &, type_record *) { }
+    static void precall(function_call &) { }
+    static void postcall(function_call &, handle) { }
+};
+
+/// Process an attribute specifying the function's name
+template <> struct process_attribute<name> : process_attribute_default<name> {
+    static void init(const name &n, function_record *r) { r->name = const_cast<char *>(n.value); }
+};
+
+/// Process an attribute specifying the function's docstring
+template <> struct process_attribute<doc> : process_attribute_default<doc> {
+    static void init(const doc &n, function_record *r) { r->doc = const_cast<char *>(n.value); }
+};
+
+/// Process an attribute specifying the function's docstring (provided as a C-style string)
+template <> struct process_attribute<const char *> : process_attribute_default<const char *> {
+    static void init(const char *d, function_record *r) { r->doc = const_cast<char *>(d); }
+    static void init(const char *d, type_record *r) { r->doc = const_cast<char *>(d); }
+};
+template <> struct process_attribute<char *> : process_attribute<const char *> { };
+
+/// Process an attribute indicating the function's return value policy
+template <> struct process_attribute<return_value_policy> : process_attribute_default<return_value_policy> {
+    static void init(const return_value_policy &p, function_record *r) { r->policy = p; }
+};
+
+/// Process an attribute which indicates that this is an overloaded function associated with a given sibling
+template <> struct process_attribute<sibling> : process_attribute_default<sibling> {
+    static void init(const sibling &s, function_record *r) { r->sibling = s.value; }
+};
+
+/// Process an attribute which indicates that this function is a method
+template <> struct process_attribute<is_method> : process_attribute_default<is_method> {
+    static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; }
+};
+
+/// Process an attribute which indicates the parent scope of a method
+template <> struct process_attribute<scope> : process_attribute_default<scope> {
+    static void init(const scope &s, function_record *r) { r->scope = s.value; }
+};
+
+/// Process an attribute which indicates that this function is an operator
+template <> struct process_attribute<is_operator> : process_attribute_default<is_operator> {
+    static void init(const is_operator &, function_record *r) { r->is_operator = true; }
+};
+
+template <> struct process_attribute<is_new_style_constructor> : process_attribute_default<is_new_style_constructor> {
+    static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; }
+};
+
+/// Process a keyword argument attribute (*without* a default value)
+template <> struct process_attribute<arg> : process_attribute_default<arg> {
+    static void init(const arg &a, function_record *r) {
+        if (r->is_method && r->args.empty())
+            r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/);
+        r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none);
+    }
+};
+
+/// Process a keyword argument attribute (*with* a default value)
+template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
+    static void init(const arg_v &a, function_record *r) {
+        if (r->is_method && r->args.empty())
+            r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/);
+
+        if (!a.value) {
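+            // Reaching this branch means the default value could not be converted to a
+            // Python object -- most commonly because the default's type was used before
+            // being registered. An illustrative (hypothetical) fix, assuming a module
+            // `m`, a function `f`, and a custom type `SomeType`:
+            //
+            //     py::class_<SomeType>(m, "SomeType");        // register the type first
+            //     m.def("f", &f, py::arg("x") = SomeType());  // then bind the default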
+#if !defined(NDEBUG)
+            std::string descr("'");
+            if (a.name) descr += std::string(a.name) + ": ";
+            descr += a.type + "'";
+            if (r->is_method) {
+                if (r->name)
+                    descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'";
+                else
+                    descr += " in method of '" + (std::string) str(r->scope) + "'";
+            } else if (r->name) {
+                descr += " in function '" + (std::string) r->name + "'";
+            }
+            pybind11_fail("arg(): could not convert default argument "
+                          + descr + " into a Python object (type not registered yet?)");
+#else
+            pybind11_fail("arg(): could not convert default argument "
+                          "into a Python object (type not registered yet?). "
+                          "Compile in debug mode for more information.");
+#endif
+        }
+        r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none);
+    }
+};
+
+/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that)
+template <typename T>
+struct process_attribute<T, enable_if_t<is_pyobject<T>::value>> : process_attribute_default<handle> {
+    static void init(const handle &h, type_record *r) { r->bases.append(h); }
+};
+
+/// Process a parent class attribute (deprecated, does not support multiple inheritance)
+template <typename T>
+struct process_attribute<base<T>> : process_attribute_default<base<T>> {
+    static void init(const base<T> &, type_record *r) { r->add_base(typeid(T), nullptr); }
+};
+
+/// Process a multiple inheritance attribute
+template <>
+struct process_attribute<multiple_inheritance> : process_attribute_default<multiple_inheritance> {
+    static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; }
+};
+
+template <>
+struct process_attribute<dynamic_attr> : process_attribute_default<dynamic_attr> {
+    static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; }
+};
+
+template <>
+struct process_attribute<buffer_protocol> : process_attribute_default<buffer_protocol> {
+    static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; }
+};
+
+template <>
+struct process_attribute<metaclass> : process_attribute_default<metaclass> {
+    static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; }
+};
+
+template <>
+struct process_attribute<module_local> : process_attribute_default<module_local> {
+    static void init(const module_local &l, type_record *r) { r->module_local = l.value; }
+};
+
+/// Process an 'arithmetic' attribute for enums (does nothing here)
+template <>
+struct process_attribute<arithmetic> : process_attribute_default<arithmetic> {};
+
+template <typename... Ts>
+struct process_attribute<call_guard<Ts...>> : process_attribute_default<call_guard<Ts...>> { };
+
+/**
+ * Process a keep_alive call policy -- invokes keep_alive_impl during the
+ * pre-call handler if both Nurse, Patient != 0 and use the post-call handler
+ * otherwise
+ */
+template <size_t Nurse, size_t Patient> struct process_attribute<keep_alive<Nurse, Patient>> : public process_attribute_default<keep_alive<Nurse, Patient>> {
+    template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
+    static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); }
+    template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
+    static void postcall(function_call &, handle) { }
+    template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
+    static void precall(function_call &) { }
+    template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
+    static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); }
+};
+
+/// Recursively iterate over variadic template arguments
+template <typename... Args> struct process_attributes {
+    static void init(const Args&... args, function_record *r) {
+        int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
+        ignore_unused(unused);
+    }
+    static void init(const Args&... args, type_record *r) {
+        int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
+        ignore_unused(unused);
+    }
+    static void precall(function_call &call) {
+        int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::precall(call), 0) ... };
+        ignore_unused(unused);
+    }
+    static void postcall(function_call &call, handle fn_ret) {
+        int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::postcall(call, fn_ret), 0) ... };
+        ignore_unused(unused);
+    }
+};
+
+template <typename T>
+using is_call_guard = is_instantiation<call_guard, T>;
+
+/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found)
+template <typename... Extra>
+using extract_guard_t = typename exactly_one_t<is_call_guard, call_guard<>, Extra...>::type;
+
+/// Check the number of named arguments at compile time
+template <typename... Extra,
+          size_t named = constexpr_sum(std::is_base_of<arg, Extra>::value...),
+          size_t self  = constexpr_sum(std::is_same<is_method, Extra>::value...)>
+constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) {
+    return named == 0 || (self + named + has_args + has_kwargs) == nargs;
+}
+
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/mmocr/models/textdet/postprocess/include/pybind11/buffer_info.h b/mmocr/models/textdet/postprocess/include/pybind11/buffer_info.h
new file mode 100644
index 00000000..9f072fa7
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/include/pybind11/buffer_info.h
@@ -0,0 +1,108 @@
+/*
+    pybind11/buffer_info.h: Python buffer object interface
+
+    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+    All rights reserved. Use of this source code is governed by a
+    BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "detail/common.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+/// Information record describing a Python buffer object
+struct buffer_info {
+    void *ptr = nullptr;          // Pointer to the underlying storage
+    ssize_t itemsize = 0;         // Size of individual items in bytes
+    ssize_t size = 0;             // Total number of entries
+    std::string format;           // For homogeneous buffers, this should be set to format_descriptor<T>::format()
+    ssize_t ndim = 0;             // Number of dimensions
+    std::vector<ssize_t> shape;   // Shape of the tensor (1 entry per dimension)
+    std::vector<ssize_t> strides; // Number of bytes between adjacent entries (for each dimension)
+
+    buffer_info() { }
+
+    buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
+                detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
+    : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
+      shape(std::move(shape_in)), strides(std::move(strides_in)) {
+        if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size())
+            pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
+        for (size_t i = 0; i < (size_t) ndim; ++i)
+            size *= shape[i];
+    }
+
+    template <typename T>
+    buffer_info(T *ptr, detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
+    : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor<T>::format(), static_cast<ssize_t>(shape_in->size()), std::move(shape_in), std::move(strides_in)) { }
+
+    buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size)
+    : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { }
+
+    template <typename T>
+    buffer_info(T *ptr, ssize_t size)
+    : buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) { }
+
+    explicit buffer_info(Py_buffer *view, bool ownview = true)
+    : buffer_info(view->buf, view->itemsize, view->format, view->ndim,
+            {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
+        this->view = view;
+        this->ownview = ownview;
+    }
+
+    buffer_info(const buffer_info &) = delete;
+    buffer_info& operator=(const buffer_info &) = delete;
+
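+    // Illustrative usage sketch (not part of the original header): a binding for a
+    // hypothetical row-major `Matrix` class might hand the fields above to Python's
+    // buffer protocol like this:
+    //
+    //     py::class_<Matrix>(m, "Matrix", py::buffer_protocol())
+    //         .def_buffer([](Matrix &mat) -> py::buffer_info {
+    //             return py::buffer_info(
+    //                 mat.data(),                              // ptr: underlying storage
+    //                 sizeof(float),                           // itemsize
+    //                 py::format_descriptor<float>::format(),  // format string
+    //                 2,                                       // ndim
+    //                 { mat.rows(), mat.cols() },              // shape
+    //                 { sizeof(float) * mat.cols(),            // strides (in bytes)
+    //                   sizeof(float) });
+    //         });
+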
+    buffer_info(buffer_info &&other) {
+        (*this) = std::move(other);
+    }
+
+    buffer_info& operator=(buffer_info &&rhs) {
+        ptr = rhs.ptr;
+        itemsize = rhs.itemsize;
+        size = rhs.size;
+        format = std::move(rhs.format);
+        ndim = rhs.ndim;
+        shape = std::move(rhs.shape);
+        strides = std::move(rhs.strides);
+        std::swap(view, rhs.view);
+        std::swap(ownview, rhs.ownview);
+        return *this;
+    }
+
+    ~buffer_info() {
+        if (view && ownview) { PyBuffer_Release(view); delete view; }
+    }
+
+private:
+    struct private_ctr_tag { };
+
+    buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
+            detail::any_container<ssize_t> &&shape_in, detail::any_container<ssize_t> &&strides_in)
+    : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { }
+
+    Py_buffer *view = nullptr;
+    bool ownview = false;
+};
+
+NAMESPACE_BEGIN(detail)
+
+template <typename T, typename SFINAE = void> struct compare_buffer_info {
+    static bool compare(const buffer_info& b) {
+        return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
+    }
+};
+
+template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
+    static bool compare(const buffer_info& b) {
+        return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
+            ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
+            ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
+    }
+};
+
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/mmocr/models/textdet/postprocess/include/pybind11/cast.h b/mmocr/models/textdet/postprocess/include/pybind11/cast.h
new file mode 100644
index 00000000..80abb2b9
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/include/pybind11/cast.h
@@ -0,0 +1,2128 @@
+/*
+    pybind11/cast.h: Partial template specializations to cast between
+    C++ and Python types
+
+    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+    All rights reserved. Use of this source code is governed by a
+    BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "pytypes.h"
+#include "detail/typeid.h"
+#include "detail/descr.h"
+#include "detail/internals.h"
+#include <array>
+#include <limits>
+#include <tuple>
+#include <type_traits>
+
+#if defined(PYBIND11_CPP17)
+#  if defined(__has_include)
+#    if __has_include(<string_view>)
+#      define PYBIND11_HAS_STRING_VIEW
+#    endif
+#  elif defined(_MSC_VER)
+#    define PYBIND11_HAS_STRING_VIEW
+#  endif
+#endif
+#ifdef PYBIND11_HAS_STRING_VIEW
+#include <string_view>
+#endif
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+/// A life support system for temporary objects created by `type_caster::load()`.
+/// Adding a patient will keep it alive up until the enclosing function returns.
+class loader_life_support {
+public:
+    /// A new patient frame is created when a function is entered
+    loader_life_support() {
+        get_internals().loader_patient_stack.push_back(nullptr);
+    }
+
+    /// ... and destroyed after it returns
+    ~loader_life_support() {
+        auto &stack = get_internals().loader_patient_stack;
+        if (stack.empty())
+            pybind11_fail("loader_life_support: internal error");
+
+        auto ptr = stack.back();
+        stack.pop_back();
+        Py_CLEAR(ptr);
+
+        // A heuristic to reduce the stack's capacity (e.g. after long recursive calls)
+        if (stack.capacity() > 16 && stack.size() != 0 && stack.capacity() / stack.size() > 2)
+            stack.shrink_to_fit();
+    }
+
+    /// This can only be used inside a pybind11-bound function, either by `argument_loader`
+    /// at argument preparation time or by `py::cast()` at execution time.
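+    /// For instance (an illustrative note, not in the original header): when a
+    /// ``std::string_view`` argument is loaded from a Python ``str``, the string
+    /// caster decodes into a temporary encoded ``bytes`` object and registers it
+    /// here so the view's data stays valid until the bound function returns
+    /// (see the string caster further below in this header).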
+ PYBIND11_NOINLINE static void add_patient(handle h) { + auto &stack = get_internals().loader_patient_stack; + if (stack.empty()) + throw cast_error("When called outside a bound function, py::cast() cannot " + "do Python -> C++ conversions which require the creation " + "of temporary values"); + + auto &list_ptr = stack.back(); + if (list_ptr == nullptr) { + list_ptr = PyList_New(1); + if (!list_ptr) + pybind11_fail("loader_life_support: error allocating list"); + PyList_SET_ITEM(list_ptr, 0, h.inc_ref().ptr()); + } else { + auto result = PyList_Append(list_ptr, h.ptr()); + if (result == -1) + pybind11_fail("loader_life_support: error adding patient"); + } + } +}; + +// Gets the cache entry for the given type, creating it if necessary. The return value is the pair +// returned by emplace, i.e. an iterator for the entry and a bool set to `true` if the entry was +// just created. +inline std::pair all_type_info_get_cache(PyTypeObject *type); + +// Populates a just-created cache entry. +PYBIND11_NOINLINE inline void all_type_info_populate(PyTypeObject *t, std::vector &bases) { + std::vector check; + for (handle parent : reinterpret_borrow(t->tp_bases)) + check.push_back((PyTypeObject *) parent.ptr()); + + auto const &type_dict = get_internals().registered_types_py; + for (size_t i = 0; i < check.size(); i++) { + auto type = check[i]; + // Ignore Python2 old-style class super type: + if (!PyType_Check((PyObject *) type)) continue; + + // Check `type` in the current set of registered python types: + auto it = type_dict.find(type); + if (it != type_dict.end()) { + // We found a cache entry for it, so it's either pybind-registered or has pre-computed + // pybind bases, but we have to make sure we haven't already seen the type(s) before: we + // want to follow Python/virtual C++ rules that there should only be one instance of a + // common base. + for (auto *tinfo : it->second) { + // NB: Could use a second set here, rather than doing a linear search, but since + // having a large number of immediate pybind11-registered types seems fairly + // unlikely, that probably isn't worthwhile. + bool found = false; + for (auto *known : bases) { + if (known == tinfo) { found = true; break; } + } + if (!found) bases.push_back(tinfo); + } + } + else if (type->tp_bases) { + // It's some python type, so keep follow its bases classes to look for one or more + // registered types + if (i + 1 == check.size()) { + // When we're at the end, we can pop off the current element to avoid growing + // `check` when adding just one base (which is typical--i.e. when there is no + // multiple inheritance) + check.pop_back(); + i--; + } + for (handle parent : reinterpret_borrow(type->tp_bases)) + check.push_back((PyTypeObject *) parent.ptr()); + } + } +} + +/** + * Extracts vector of type_info pointers of pybind-registered roots of the given Python type. Will + * be just 1 pybind type for the Python type of a pybind-registered class, or for any Python-side + * derived class that uses single inheritance. Will contain as many types as required for a Python + * class that uses multiple inheritance to inherit (directly or indirectly) from multiple + * pybind-registered classes. Will be empty if neither the type nor any base classes are + * pybind-registered. + * + * The value is cached for the lifetime of the Python type. 
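+ *
+ * (Illustrative note, not in the original header: for a Python class defined as
+ * ``class D(B1, B2)`` where ``B1`` and ``B2`` wrap two distinct pybind11-registered
+ * C++ types, the vector returned for ``D`` holds the two ``type_info`` pointers of
+ * ``B1`` and ``B2``; for a plain single-inheritance subclass it holds exactly one.)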
+ */ +inline const std::vector &all_type_info(PyTypeObject *type) { + auto ins = all_type_info_get_cache(type); + if (ins.second) + // New cache entry: populate it + all_type_info_populate(type, ins.first->second); + + return ins.first->second; +} + +/** + * Gets a single pybind11 type info for a python type. Returns nullptr if neither the type nor any + * ancestors are pybind11-registered. Throws an exception if there are multiple bases--use + * `all_type_info` instead if you want to support multiple bases. + */ +PYBIND11_NOINLINE inline detail::type_info* get_type_info(PyTypeObject *type) { + auto &bases = all_type_info(type); + if (bases.size() == 0) + return nullptr; + if (bases.size() > 1) + pybind11_fail("pybind11::detail::get_type_info: type has multiple pybind11-registered bases"); + return bases.front(); +} + +inline detail::type_info *get_local_type_info(const std::type_index &tp) { + auto &locals = registered_local_types_cpp(); + auto it = locals.find(tp); + if (it != locals.end()) + return it->second; + return nullptr; +} + +inline detail::type_info *get_global_type_info(const std::type_index &tp) { + auto &types = get_internals().registered_types_cpp; + auto it = types.find(tp); + if (it != types.end()) + return it->second; + return nullptr; +} + +/// Return the type info for a given C++ type; on lookup failure can either throw or return nullptr. +PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_index &tp, + bool throw_if_missing = false) { + if (auto ltype = get_local_type_info(tp)) + return ltype; + if (auto gtype = get_global_type_info(tp)) + return gtype; + + if (throw_if_missing) { + std::string tname = tp.name(); + detail::clean_type_id(tname); + pybind11_fail("pybind11::detail::get_type_info: unable to find type info for \"" + tname + "\""); + } + return nullptr; +} + +PYBIND11_NOINLINE inline handle get_type_handle(const std::type_info &tp, bool throw_if_missing) { + detail::type_info *type_info = get_type_info(tp, throw_if_missing); + return handle(type_info ? ((PyObject *) type_info->type) : nullptr); +} + +struct value_and_holder { + instance *inst; + size_t index; + const detail::type_info *type; + void **vh; + + // Main constructor for a found value/holder: + value_and_holder(instance *i, const detail::type_info *type, size_t vpos, size_t index) : + inst{i}, index{index}, type{type}, + vh{inst->simple_layout ? inst->simple_value_holder : &inst->nonsimple.values_and_holders[vpos]} + {} + + // Default constructor (used to signal a value-and-holder not found by get_value_and_holder()) + value_and_holder() : inst{nullptr} {} + + // Used for past-the-end iterator + value_and_holder(size_t index) : index{index} {} + + template V *&value_ptr() const { + return reinterpret_cast(vh[0]); + } + // True if this `value_and_holder` has a non-null value pointer + explicit operator bool() const { return value_ptr(); } + + template H &holder() const { + return reinterpret_cast(vh[1]); + } + bool holder_constructed() const { + return inst->simple_layout + ? inst->simple_holder_constructed + : inst->nonsimple.status[index] & instance::status_holder_constructed; + } + void set_holder_constructed(bool v = true) { + if (inst->simple_layout) + inst->simple_holder_constructed = v; + else if (v) + inst->nonsimple.status[index] |= instance::status_holder_constructed; + else + inst->nonsimple.status[index] &= (uint8_t) ~instance::status_holder_constructed; + } + bool instance_registered() const { + return inst->simple_layout + ? 
inst->simple_instance_registered + : inst->nonsimple.status[index] & instance::status_instance_registered; + } + void set_instance_registered(bool v = true) { + if (inst->simple_layout) + inst->simple_instance_registered = v; + else if (v) + inst->nonsimple.status[index] |= instance::status_instance_registered; + else + inst->nonsimple.status[index] &= (uint8_t) ~instance::status_instance_registered; + } +}; + +// Container for accessing and iterating over an instance's values/holders +struct values_and_holders { +private: + instance *inst; + using type_vec = std::vector; + const type_vec &tinfo; + +public: + values_and_holders(instance *inst) : inst{inst}, tinfo(all_type_info(Py_TYPE(inst))) {} + + struct iterator { + private: + instance *inst; + const type_vec *types; + value_and_holder curr; + friend struct values_and_holders; + iterator(instance *inst, const type_vec *tinfo) + : inst{inst}, types{tinfo}, + curr(inst /* instance */, + types->empty() ? nullptr : (*types)[0] /* type info */, + 0, /* vpos: (non-simple types only): the first vptr comes first */ + 0 /* index */) + {} + // Past-the-end iterator: + iterator(size_t end) : curr(end) {} + public: + bool operator==(const iterator &other) { return curr.index == other.curr.index; } + bool operator!=(const iterator &other) { return curr.index != other.curr.index; } + iterator &operator++() { + if (!inst->simple_layout) + curr.vh += 1 + (*types)[curr.index]->holder_size_in_ptrs; + ++curr.index; + curr.type = curr.index < types->size() ? (*types)[curr.index] : nullptr; + return *this; + } + value_and_holder &operator*() { return curr; } + value_and_holder *operator->() { return &curr; } + }; + + iterator begin() { return iterator(inst, &tinfo); } + iterator end() { return iterator(tinfo.size()); } + + iterator find(const type_info *find_type) { + auto it = begin(), endit = end(); + while (it != endit && it->type != find_type) ++it; + return it; + } + + size_t size() { return tinfo.size(); } +}; + +/** + * Extracts C++ value and holder pointer references from an instance (which may contain multiple + * values/holders for python-side multiple inheritance) that match the given type. Throws an error + * if the given type (or ValueType, if omitted) is not a pybind11 base of the given instance. If + * `find_type` is omitted (or explicitly specified as nullptr) the first value/holder are returned, + * regardless of type (and the resulting .type will be nullptr). + * + * The returned object should be short-lived: in particular, it must not outlive the called-upon + * instance. 
+ */ +PYBIND11_NOINLINE inline value_and_holder instance::get_value_and_holder(const type_info *find_type /*= nullptr default in common.h*/, bool throw_if_missing /*= true in common.h*/) { + // Optimize common case: + if (!find_type || Py_TYPE(this) == find_type->type) + return value_and_holder(this, find_type, 0, 0); + + detail::values_and_holders vhs(this); + auto it = vhs.find(find_type); + if (it != vhs.end()) + return *it; + + if (!throw_if_missing) + return value_and_holder(); + +#if defined(NDEBUG) + pybind11_fail("pybind11::detail::instance::get_value_and_holder: " + "type is not a pybind11 base of the given instance " + "(compile in debug mode for type details)"); +#else + pybind11_fail("pybind11::detail::instance::get_value_and_holder: `" + + std::string(find_type->type->tp_name) + "' is not a pybind11 base of the given `" + + std::string(Py_TYPE(this)->tp_name) + "' instance"); +#endif +} + +PYBIND11_NOINLINE inline void instance::allocate_layout() { + auto &tinfo = all_type_info(Py_TYPE(this)); + + const size_t n_types = tinfo.size(); + + if (n_types == 0) + pybind11_fail("instance allocation failed: new instance has no pybind11-registered base types"); + + simple_layout = + n_types == 1 && tinfo.front()->holder_size_in_ptrs <= instance_simple_holder_in_ptrs(); + + // Simple path: no python-side multiple inheritance, and a small-enough holder + if (simple_layout) { + simple_value_holder[0] = nullptr; + simple_holder_constructed = false; + simple_instance_registered = false; + } + else { // multiple base types or a too-large holder + // Allocate space to hold: [v1*][h1][v2*][h2]...[bb...] where [vN*] is a value pointer, + // [hN] is the (uninitialized) holder instance for value N, and [bb...] is a set of bool + // values that tracks whether each associated holder has been initialized. Each [block] is + // padded, if necessary, to an integer multiple of sizeof(void *). + size_t space = 0; + for (auto t : tinfo) { + space += 1; // value pointer + space += t->holder_size_in_ptrs; // holder instance + } + size_t flags_at = space; + space += size_in_ptrs(n_types); // status bytes (holder_constructed and instance_registered) + + // Allocate space for flags, values, and holders, and initialize it to 0 (flags and values, + // in particular, need to be 0). Use Python's memory allocation functions: in Python 3.6 + // they default to using pymalloc, which is designed to be efficient for small allocations + // like the one we're doing here; in earlier versions (and for larger allocations) they are + // just wrappers around malloc. 
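+        // Worked example (illustrative, not in the original source): with two
+        // pybind11-registered bases whose holders are std::shared_ptr<T> (two pointers
+        // each on a typical 64-bit platform), the loop above yields (1 + 2) + (1 + 2) = 6
+        // pointer slots for values/holders; size_in_ptrs(2) then adds 1 more slot to
+        // hold the two per-type status bytes, for 7 slots in total.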
+#if PY_VERSION_HEX >= 0x03050000 + nonsimple.values_and_holders = (void **) PyMem_Calloc(space, sizeof(void *)); + if (!nonsimple.values_and_holders) throw std::bad_alloc(); +#else + nonsimple.values_and_holders = (void **) PyMem_New(void *, space); + if (!nonsimple.values_and_holders) throw std::bad_alloc(); + std::memset(nonsimple.values_and_holders, 0, space * sizeof(void *)); +#endif + nonsimple.status = reinterpret_cast(&nonsimple.values_and_holders[flags_at]); + } + owned = true; +} + +PYBIND11_NOINLINE inline void instance::deallocate_layout() { + if (!simple_layout) + PyMem_Free(nonsimple.values_and_holders); +} + +PYBIND11_NOINLINE inline bool isinstance_generic(handle obj, const std::type_info &tp) { + handle type = detail::get_type_handle(tp, false); + if (!type) + return false; + return isinstance(obj, type); +} + +PYBIND11_NOINLINE inline std::string error_string() { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_RuntimeError, "Unknown internal error occurred"); + return "Unknown internal error occurred"; + } + + error_scope scope; // Preserve error state + + std::string errorString; + if (scope.type) { + errorString += handle(scope.type).attr("__name__").cast(); + errorString += ": "; + } + if (scope.value) + errorString += (std::string) str(scope.value); + + PyErr_NormalizeException(&scope.type, &scope.value, &scope.trace); + +#if PY_MAJOR_VERSION >= 3 + if (scope.trace != nullptr) + PyException_SetTraceback(scope.value, scope.trace); +#endif + +#if !defined(PYPY_VERSION) + if (scope.trace) { + PyTracebackObject *trace = (PyTracebackObject *) scope.trace; + + /* Get the deepest trace possible */ + while (trace->tb_next) + trace = trace->tb_next; + + PyFrameObject *frame = trace->tb_frame; + errorString += "\n\nAt:\n"; + while (frame) { + int lineno = PyFrame_GetLineNumber(frame); + errorString += + " " + handle(frame->f_code->co_filename).cast() + + "(" + std::to_string(lineno) + "): " + + handle(frame->f_code->co_name).cast() + "\n"; + frame = frame->f_back; + } + } +#endif + + return errorString; +} + +PYBIND11_NOINLINE inline handle get_object_handle(const void *ptr, const detail::type_info *type ) { + auto &instances = get_internals().registered_instances; + auto range = instances.equal_range(ptr); + for (auto it = range.first; it != range.second; ++it) { + for (auto vh : values_and_holders(it->second)) { + if (vh.type == type) + return handle((PyObject *) it->second); + } + } + return handle(); +} + +inline PyThreadState *get_thread_state_unchecked() { +#if defined(PYPY_VERSION) + return PyThreadState_GET(); +#elif PY_VERSION_HEX < 0x03000000 + return _PyThreadState_Current; +#elif PY_VERSION_HEX < 0x03050000 + return (PyThreadState*) _Py_atomic_load_relaxed(&_PyThreadState_Current); +#elif PY_VERSION_HEX < 0x03050200 + return (PyThreadState*) _PyThreadState_Current.value; +#else + return _PyThreadState_UncheckedGet(); +#endif +} + +// Forward declarations +inline void keep_alive_impl(handle nurse, handle patient); +inline PyObject *make_new_instance(PyTypeObject *type); + +class type_caster_generic { +public: + PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info) + : typeinfo(get_type_info(type_info)), cpptype(&type_info) { } + + type_caster_generic(const type_info *typeinfo) + : typeinfo(typeinfo), cpptype(typeinfo ? 
typeinfo->cpptype : nullptr) { } + + bool load(handle src, bool convert) { + return load_impl(src, convert); + } + + PYBIND11_NOINLINE static handle cast(const void *_src, return_value_policy policy, handle parent, + const detail::type_info *tinfo, + void *(*copy_constructor)(const void *), + void *(*move_constructor)(const void *), + const void *existing_holder = nullptr) { + if (!tinfo) // no type info: error will be set already + return handle(); + + void *src = const_cast(_src); + if (src == nullptr) + return none().release(); + + auto it_instances = get_internals().registered_instances.equal_range(src); + for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) { + for (auto instance_type : detail::all_type_info(Py_TYPE(it_i->second))) { + if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype)) + return handle((PyObject *) it_i->second).inc_ref(); + } + } + + auto inst = reinterpret_steal(make_new_instance(tinfo->type)); + auto wrapper = reinterpret_cast(inst.ptr()); + wrapper->owned = false; + void *&valueptr = values_and_holders(wrapper).begin()->value_ptr(); + + switch (policy) { + case return_value_policy::automatic: + case return_value_policy::take_ownership: + valueptr = src; + wrapper->owned = true; + break; + + case return_value_policy::automatic_reference: + case return_value_policy::reference: + valueptr = src; + wrapper->owned = false; + break; + + case return_value_policy::copy: + if (copy_constructor) + valueptr = copy_constructor(src); + else + throw cast_error("return_value_policy = copy, but the " + "object is non-copyable!"); + wrapper->owned = true; + break; + + case return_value_policy::move: + if (move_constructor) + valueptr = move_constructor(src); + else if (copy_constructor) + valueptr = copy_constructor(src); + else + throw cast_error("return_value_policy = move, but the " + "object is neither movable nor copyable!"); + wrapper->owned = true; + break; + + case return_value_policy::reference_internal: + valueptr = src; + wrapper->owned = false; + keep_alive_impl(inst, parent); + break; + + default: + throw cast_error("unhandled return_value_policy: should not happen!"); + } + + tinfo->init_instance(wrapper, existing_holder); + + return inst.release(); + } + + // Base methods for generic caster; there are overridden in copyable_holder_caster + void load_value(value_and_holder &&v_h) { + auto *&vptr = v_h.value_ptr(); + // Lazy allocation for unallocated values: + if (vptr == nullptr) { + auto *type = v_h.type ? 
v_h.type : typeinfo; + if (type->operator_new) { + vptr = type->operator_new(type->type_size); + } else { + #if defined(PYBIND11_CPP17) + if (type->type_align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) + vptr = ::operator new(type->type_size, + (std::align_val_t) type->type_align); + else + #endif + vptr = ::operator new(type->type_size); + } + } + value = vptr; + } + bool try_implicit_casts(handle src, bool convert) { + for (auto &cast : typeinfo->implicit_casts) { + type_caster_generic sub_caster(*cast.first); + if (sub_caster.load(src, convert)) { + value = cast.second(sub_caster.value); + return true; + } + } + return false; + } + bool try_direct_conversions(handle src) { + for (auto &converter : *typeinfo->direct_conversions) { + if (converter(src.ptr(), value)) + return true; + } + return false; + } + void check_holder_compat() {} + + PYBIND11_NOINLINE static void *local_load(PyObject *src, const type_info *ti) { + auto caster = type_caster_generic(ti); + if (caster.load(src, false)) + return caster.value; + return nullptr; + } + + /// Try to load with foreign typeinfo, if available. Used when there is no + /// native typeinfo, or when the native one wasn't able to produce a value. + PYBIND11_NOINLINE bool try_load_foreign_module_local(handle src) { + constexpr auto *local_key = PYBIND11_MODULE_LOCAL_ID; + const auto pytype = src.get_type(); + if (!hasattr(pytype, local_key)) + return false; + + type_info *foreign_typeinfo = reinterpret_borrow(getattr(pytype, local_key)); + // Only consider this foreign loader if actually foreign and is a loader of the correct cpp type + if (foreign_typeinfo->module_local_load == &local_load + || (cpptype && !same_type(*cpptype, *foreign_typeinfo->cpptype))) + return false; + + if (auto result = foreign_typeinfo->module_local_load(src.ptr(), foreign_typeinfo)) { + value = result; + return true; + } + return false; + } + + // Implementation of `load`; this takes the type of `this` so that it can dispatch the relevant + // bits of code between here and copyable_holder_caster where the two classes need different + // logic (without having to resort to virtual inheritance). + template + PYBIND11_NOINLINE bool load_impl(handle src, bool convert) { + if (!src) return false; + if (!typeinfo) return try_load_foreign_module_local(src); + if (src.is_none()) { + // Defer accepting None to other overloads (if we aren't in convert mode): + if (!convert) return false; + value = nullptr; + return true; + } + + auto &this_ = static_cast(*this); + this_.check_holder_compat(); + + PyTypeObject *srctype = Py_TYPE(src.ptr()); + + // Case 1: If src is an exact type match for the target type then we can reinterpret_cast + // the instance's value pointer to the target type: + if (srctype == typeinfo->type) { + this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); + return true; + } + // Case 2: We have a derived class + else if (PyType_IsSubtype(srctype, typeinfo->type)) { + auto &bases = all_type_info(srctype); + bool no_cpp_mi = typeinfo->simple_type; + + // Case 2a: the python type is a Python-inherited derived class that inherits from just + // one simple (no MI) pybind11 class, or is an exact match, so the C++ instance is of + // the right type and we can use reinterpret_cast. 
+ // (This is essentially the same as case 2b, but because not using multiple inheritance + // is extremely common, we handle it specially to avoid the loop iterator and type + // pointer lookup overhead) + if (bases.size() == 1 && (no_cpp_mi || bases.front()->type == typeinfo->type)) { + this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); + return true; + } + // Case 2b: the python type inherits from multiple C++ bases. Check the bases to see if + // we can find an exact match (or, for a simple C++ type, an inherited match); if so, we + // can safely reinterpret_cast to the relevant pointer. + else if (bases.size() > 1) { + for (auto base : bases) { + if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type) : base->type == typeinfo->type) { + this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder(base)); + return true; + } + } + } + + // Case 2c: C++ multiple inheritance is involved and we couldn't find an exact type match + // in the registered bases, above, so try implicit casting (needed for proper C++ casting + // when MI is involved). + if (this_.try_implicit_casts(src, convert)) + return true; + } + + // Perform an implicit conversion + if (convert) { + for (auto &converter : typeinfo->implicit_conversions) { + auto temp = reinterpret_steal(converter(src.ptr(), typeinfo->type)); + if (load_impl(temp, false)) { + loader_life_support::add_patient(temp); + return true; + } + } + if (this_.try_direct_conversions(src)) + return true; + } + + // Failed to match local typeinfo. Try again with global. + if (typeinfo->module_local) { + if (auto gtype = get_global_type_info(*typeinfo->cpptype)) { + typeinfo = gtype; + return load(src, false); + } + } + + // Global typeinfo has precedence over foreign module_local + return try_load_foreign_module_local(src); + } + + + // Called to do type lookup and wrap the pointer and type in a pair when a dynamic_cast + // isn't needed or can't be used. If the type is unknown, sets the error and returns a pair + // with .second = nullptr. (p.first = nullptr is not an error: it becomes None). + PYBIND11_NOINLINE static std::pair src_and_type( + const void *src, const std::type_info &cast_type, const std::type_info *rtti_type = nullptr) { + if (auto *tpi = get_type_info(cast_type)) + return {src, const_cast(tpi)}; + + // Not found, set error: + std::string tname = rtti_type ? rtti_type->name() : cast_type.name(); + detail::clean_type_id(tname); + std::string msg = "Unregistered type : " + tname; + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return {nullptr, nullptr}; + } + + const type_info *typeinfo = nullptr; + const std::type_info *cpptype = nullptr; + void *value = nullptr; +}; + +/** + * Determine suitable casting operator for pointer-or-lvalue-casting type casters. The type caster + * needs to provide `operator T*()` and `operator T&()` operators. + * + * If the type supports moving the value away via an `operator T&&() &&` method, it should use + * `movable_cast_op_type` instead. + */ +template +using cast_op_type = + conditional_t>::value, + typename std::add_pointer>::type, + typename std::add_lvalue_reference>::type>; + +/** + * Determine suitable casting operator for a type caster with a movable value. Such a type caster + * needs to provide `operator T*()`, `operator T&()`, and `operator T&&() &&`. The latter will be + * called in appropriate contexts where the value can be moved rather than copied. + * + * These operator are automatically provided when using the PYBIND11_TYPE_CASTER macro. 
+ */ +template +using movable_cast_op_type = + conditional_t::type>::value, + typename std::add_pointer>::type, + conditional_t::value, + typename std::add_rvalue_reference>::type, + typename std::add_lvalue_reference>::type>>; + +// std::is_copy_constructible isn't quite enough: it lets std::vector (and similar) through when +// T is non-copyable, but code containing such a copy constructor fails to actually compile. +template struct is_copy_constructible : std::is_copy_constructible {}; + +// Specialization for types that appear to be copy constructible but also look like stl containers +// (we specifically check for: has `value_type` and `reference` with `reference = value_type&`): if +// so, copy constructability depends on whether the value_type is copy constructible. +template struct is_copy_constructible, + std::is_same + >::value>> : is_copy_constructible {}; + +#if !defined(PYBIND11_CPP17) +// Likewise for std::pair before C++17 (which mandates that the copy constructor not exist when the +// two types aren't themselves copy constructible). +template struct is_copy_constructible> + : all_of, is_copy_constructible> {}; +#endif + +NAMESPACE_END(detail) + +// polymorphic_type_hook::get(src, tinfo) determines whether the object pointed +// to by `src` actually is an instance of some class derived from `itype`. +// If so, it sets `tinfo` to point to the std::type_info representing that derived +// type, and returns a pointer to the start of the most-derived object of that type +// (in which `src` is a subobject; this will be the same address as `src` in most +// single inheritance cases). If not, or if `src` is nullptr, it simply returns `src` +// and leaves `tinfo` at its default value of nullptr. +// +// The default polymorphic_type_hook just returns src. A specialization for polymorphic +// types determines the runtime type of the passed object and adjusts the this-pointer +// appropriately via dynamic_cast. This is what enables a C++ Animal* to appear +// to Python as a Dog (if Dog inherits from Animal, Animal is polymorphic, Dog is +// registered with pybind11, and this Animal is in fact a Dog). +// +// You may specialize polymorphic_type_hook yourself for types that want to appear +// polymorphic to Python but do not use C++ RTTI. (This is a not uncommon pattern +// in performance-sensitive applications, used most notably in LLVM.) +template +struct polymorphic_type_hook +{ + static const void *get(const itype *src, const std::type_info*&) { return src; } +}; +template +struct polymorphic_type_hook::value>> +{ + static const void *get(const itype *src, const std::type_info*& type) { + type = src ? 
&typeid(*src) : nullptr; + return dynamic_cast(src); + } +}; + +NAMESPACE_BEGIN(detail) + +/// Generic type caster for objects stored on the heap +template class type_caster_base : public type_caster_generic { + using itype = intrinsic_t; + +public: + static constexpr auto name = _(); + + type_caster_base() : type_caster_base(typeid(type)) { } + explicit type_caster_base(const std::type_info &info) : type_caster_generic(info) { } + + static handle cast(const itype &src, return_value_policy policy, handle parent) { + if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference) + policy = return_value_policy::copy; + return cast(&src, policy, parent); + } + + static handle cast(itype &&src, return_value_policy, handle parent) { + return cast(&src, return_value_policy::move, parent); + } + + // Returns a (pointer, type_info) pair taking care of necessary type lookup for a + // polymorphic type (using RTTI by default, but can be overridden by specializing + // polymorphic_type_hook). If the instance isn't derived, returns the base version. + static std::pair src_and_type(const itype *src) { + auto &cast_type = typeid(itype); + const std::type_info *instance_type = nullptr; + const void *vsrc = polymorphic_type_hook::get(src, instance_type); + if (instance_type && !same_type(cast_type, *instance_type)) { + // This is a base pointer to a derived type. If the derived type is registered + // with pybind11, we want to make the full derived object available. + // In the typical case where itype is polymorphic, we get the correct + // derived pointer (which may be != base pointer) by a dynamic_cast to + // most derived type. If itype is not polymorphic, we won't get here + // except via a user-provided specialization of polymorphic_type_hook, + // and the user has promised that no this-pointer adjustment is + // required in that case, so it's OK to use static_cast. + if (const auto *tpi = get_type_info(*instance_type)) + return {vsrc, tpi}; + } + // Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer, so + // don't do a cast + return type_caster_generic::src_and_type(src, cast_type, instance_type); + } + + static handle cast(const itype *src, return_value_policy policy, handle parent) { + auto st = src_and_type(src); + return type_caster_generic::cast( + st.first, policy, parent, st.second, + make_copy_constructor(src), make_move_constructor(src)); + } + + static handle cast_holder(const itype *src, const void *holder) { + auto st = src_and_type(src); + return type_caster_generic::cast( + st.first, return_value_policy::take_ownership, {}, st.second, + nullptr, nullptr, holder); + } + + template using cast_op_type = detail::cast_op_type; + + operator itype*() { return (type *) value; } + operator itype&() { if (!value) throw reference_cast_error(); return *((itype *) value); } + +protected: + using Constructor = void *(*)(const void *); + + /* Only enabled when the types are {copy,move}-constructible *and* when the type + does not have a private operator new implementation. 
*/ + template ::value>> + static auto make_copy_constructor(const T *x) -> decltype(new T(*x), Constructor{}) { + return [](const void *arg) -> void * { + return new T(*reinterpret_cast(arg)); + }; + } + + template ::value>> + static auto make_move_constructor(const T *x) -> decltype(new T(std::move(*const_cast(x))), Constructor{}) { + return [](const void *arg) -> void * { + return new T(std::move(*const_cast(reinterpret_cast(arg)))); + }; + } + + static Constructor make_copy_constructor(...) { return nullptr; } + static Constructor make_move_constructor(...) { return nullptr; } +}; + +template class type_caster : public type_caster_base { }; +template using make_caster = type_caster>; + +// Shortcut for calling a caster's `cast_op_type` cast operator for casting a type_caster to a T +template typename make_caster::template cast_op_type cast_op(make_caster &caster) { + return caster.operator typename make_caster::template cast_op_type(); +} +template typename make_caster::template cast_op_type::type> +cast_op(make_caster &&caster) { + return std::move(caster).operator + typename make_caster::template cast_op_type::type>(); +} + +template class type_caster> { +private: + using caster_t = make_caster; + caster_t subcaster; + using subcaster_cast_op_type = typename caster_t::template cast_op_type; + static_assert(std::is_same::type &, subcaster_cast_op_type>::value, + "std::reference_wrapper caster requires T to have a caster with an `T &` operator"); +public: + bool load(handle src, bool convert) { return subcaster.load(src, convert); } + static constexpr auto name = caster_t::name; + static handle cast(const std::reference_wrapper &src, return_value_policy policy, handle parent) { + // It is definitely wrong to take ownership of this pointer, so mask that rvp + if (policy == return_value_policy::take_ownership || policy == return_value_policy::automatic) + policy = return_value_policy::automatic_reference; + return caster_t::cast(&src.get(), policy, parent); + } + template using cast_op_type = std::reference_wrapper; + operator std::reference_wrapper() { return subcaster.operator subcaster_cast_op_type&(); } +}; + +#define PYBIND11_TYPE_CASTER(type, py_name) \ + protected: \ + type value; \ + public: \ + static constexpr auto name = py_name; \ + template >::value, int> = 0> \ + static handle cast(T_ *src, return_value_policy policy, handle parent) { \ + if (!src) return none().release(); \ + if (policy == return_value_policy::take_ownership) { \ + auto h = cast(std::move(*src), policy, parent); delete src; return h; \ + } else { \ + return cast(*src, policy, parent); \ + } \ + } \ + operator type*() { return &value; } \ + operator type&() { return value; } \ + operator type&&() && { return std::move(value); } \ + template using cast_op_type = pybind11::detail::movable_cast_op_type + + +template using is_std_char_type = any_of< + std::is_same, /* std::string */ + std::is_same, /* std::u16string */ + std::is_same, /* std::u32string */ + std::is_same /* std::wstring */ +>; + +template +struct type_caster::value && !is_std_char_type::value>> { + using _py_type_0 = conditional_t; + using _py_type_1 = conditional_t::value, _py_type_0, typename std::make_unsigned<_py_type_0>::type>; + using py_type = conditional_t::value, double, _py_type_1>; +public: + + bool load(handle src, bool convert) { + py_type py_value; + + if (!src) + return false; + + if (std::is_floating_point::value) { + if (convert || PyFloat_Check(src.ptr())) + py_value = (py_type) PyFloat_AsDouble(src.ptr()); + else + return 
false; + } else if (PyFloat_Check(src.ptr())) { + return false; + } else if (std::is_unsigned::value) { + py_value = as_unsigned(src.ptr()); + } else { // signed integer: + py_value = sizeof(T) <= sizeof(long) + ? (py_type) PyLong_AsLong(src.ptr()) + : (py_type) PYBIND11_LONG_AS_LONGLONG(src.ptr()); + } + + bool py_err = py_value == (py_type) -1 && PyErr_Occurred(); + if (py_err || (std::is_integral::value && sizeof(py_type) != sizeof(T) && + (py_value < (py_type) std::numeric_limits::min() || + py_value > (py_type) std::numeric_limits::max()))) { + bool type_error = py_err && PyErr_ExceptionMatches( +#if PY_VERSION_HEX < 0x03000000 && !defined(PYPY_VERSION) + PyExc_SystemError +#else + PyExc_TypeError +#endif + ); + PyErr_Clear(); + if (type_error && convert && PyNumber_Check(src.ptr())) { + auto tmp = reinterpret_steal(std::is_floating_point::value + ? PyNumber_Float(src.ptr()) + : PyNumber_Long(src.ptr())); + PyErr_Clear(); + return load(tmp, false); + } + return false; + } + + value = (T) py_value; + return true; + } + + template + static typename std::enable_if::value, handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PyFloat_FromDouble((double) src); + } + + template + static typename std::enable_if::value && std::is_signed::value && (sizeof(U) <= sizeof(long)), handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PYBIND11_LONG_FROM_SIGNED((long) src); + } + + template + static typename std::enable_if::value && std::is_unsigned::value && (sizeof(U) <= sizeof(unsigned long)), handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PYBIND11_LONG_FROM_UNSIGNED((unsigned long) src); + } + + template + static typename std::enable_if::value && std::is_signed::value && (sizeof(U) > sizeof(long)), handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PyLong_FromLongLong((long long) src); + } + + template + static typename std::enable_if::value && std::is_unsigned::value && (sizeof(U) > sizeof(unsigned long)), handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PyLong_FromUnsignedLongLong((unsigned long long) src); + } + + PYBIND11_TYPE_CASTER(T, _::value>("int", "float")); +}; + +template struct void_caster { +public: + bool load(handle src, bool) { + if (src && src.is_none()) + return true; + return false; + } + static handle cast(T, return_value_policy /* policy */, handle /* parent */) { + return none().inc_ref(); + } + PYBIND11_TYPE_CASTER(T, _("None")); +}; + +template <> class type_caster : public void_caster {}; + +template <> class type_caster : public type_caster { +public: + using type_caster::cast; + + bool load(handle h, bool) { + if (!h) { + return false; + } else if (h.is_none()) { + value = nullptr; + return true; + } + + /* Check if this is a capsule */ + if (isinstance(h)) { + value = reinterpret_borrow(h); + return true; + } + + /* Check if this is a C++ type */ + auto &bases = all_type_info((PyTypeObject *) h.get_type().ptr()); + if (bases.size() == 1) { // Only allowing loading from a single-value type + value = values_and_holders(reinterpret_cast(h.ptr())).begin()->value_ptr(); + return true; + } + + /* Fail */ + return false; + } + + static handle cast(const void *ptr, return_value_policy /* policy */, handle /* parent */) { + if (ptr) + return capsule(ptr).release(); + else + return none().inc_ref(); + } + + template using cast_op_type = void*&; + operator void *&() { 
return value; } + static constexpr auto name = _("capsule"); +private: + void *value = nullptr; +}; + +template <> class type_caster : public void_caster { }; + +template <> class type_caster { +public: + bool load(handle src, bool convert) { + if (!src) return false; + else if (src.ptr() == Py_True) { value = true; return true; } + else if (src.ptr() == Py_False) { value = false; return true; } + else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) { + // (allow non-implicit conversion for numpy booleans) + + Py_ssize_t res = -1; + if (src.is_none()) { + res = 0; // None is implicitly converted to False + } + #if defined(PYPY_VERSION) + // On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr exists + else if (hasattr(src, PYBIND11_BOOL_ATTR)) { + res = PyObject_IsTrue(src.ptr()); + } + #else + // Alternate approach for CPython: this does the same as the above, but optimized + // using the CPython API so as to avoid an unneeded attribute lookup. + else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) { + if (PYBIND11_NB_BOOL(tp_as_number)) { + res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr()); + } + } + #endif + if (res == 0 || res == 1) { + value = (bool) res; + return true; + } + } + return false; + } + static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) { + return handle(src ? Py_True : Py_False).inc_ref(); + } + PYBIND11_TYPE_CASTER(bool, _("bool")); +}; + +// Helper class for UTF-{8,16,32} C++ stl strings: +template struct string_caster { + using CharT = typename StringType::value_type; + + // Simplify life by being able to assume standard char sizes (the standard only guarantees + // minimums, but Python requires exact sizes) + static_assert(!std::is_same::value || sizeof(CharT) == 1, "Unsupported char size != 1"); + static_assert(!std::is_same::value || sizeof(CharT) == 2, "Unsupported char16_t size != 2"); + static_assert(!std::is_same::value || sizeof(CharT) == 4, "Unsupported char32_t size != 4"); + // wchar_t can be either 16 bits (Windows) or 32 (everywhere else) + static_assert(!std::is_same::value || sizeof(CharT) == 2 || sizeof(CharT) == 4, + "Unsupported wchar_t size != 2/4"); + static constexpr size_t UTF_N = 8 * sizeof(CharT); + + bool load(handle src, bool) { +#if PY_MAJOR_VERSION < 3 + object temp; +#endif + handle load_src = src; + if (!src) { + return false; + } else if (!PyUnicode_Check(load_src.ptr())) { +#if PY_MAJOR_VERSION >= 3 + return load_bytes(load_src); +#else + if (sizeof(CharT) == 1) { + return load_bytes(load_src); + } + + // The below is a guaranteed failure in Python 3 when PyUnicode_Check returns false + if (!PYBIND11_BYTES_CHECK(load_src.ptr())) + return false; + + temp = reinterpret_steal(PyUnicode_FromObject(load_src.ptr())); + if (!temp) { PyErr_Clear(); return false; } + load_src = temp; +#endif + } + + object utfNbytes = reinterpret_steal(PyUnicode_AsEncodedString( + load_src.ptr(), UTF_N == 8 ? "utf-8" : UTF_N == 16 ? 
"utf-16" : "utf-32", nullptr)); + if (!utfNbytes) { PyErr_Clear(); return false; } + + const CharT *buffer = reinterpret_cast(PYBIND11_BYTES_AS_STRING(utfNbytes.ptr())); + size_t length = (size_t) PYBIND11_BYTES_SIZE(utfNbytes.ptr()) / sizeof(CharT); + if (UTF_N > 8) { buffer++; length--; } // Skip BOM for UTF-16/32 + value = StringType(buffer, length); + + // If we're loading a string_view we need to keep the encoded Python object alive: + if (IsView) + loader_life_support::add_patient(utfNbytes); + + return true; + } + + static handle cast(const StringType &src, return_value_policy /* policy */, handle /* parent */) { + const char *buffer = reinterpret_cast(src.data()); + ssize_t nbytes = ssize_t(src.size() * sizeof(CharT)); + handle s = decode_utfN(buffer, nbytes); + if (!s) throw error_already_set(); + return s; + } + + PYBIND11_TYPE_CASTER(StringType, _(PYBIND11_STRING_NAME)); + +private: + static handle decode_utfN(const char *buffer, ssize_t nbytes) { +#if !defined(PYPY_VERSION) + return + UTF_N == 8 ? PyUnicode_DecodeUTF8(buffer, nbytes, nullptr) : + UTF_N == 16 ? PyUnicode_DecodeUTF16(buffer, nbytes, nullptr, nullptr) : + PyUnicode_DecodeUTF32(buffer, nbytes, nullptr, nullptr); +#else + // PyPy seems to have multiple problems related to PyUnicode_UTF*: the UTF8 version + // sometimes segfaults for unknown reasons, while the UTF16 and 32 versions require a + // non-const char * arguments, which is also a nuisance, so bypass the whole thing by just + // passing the encoding as a string value, which works properly: + return PyUnicode_Decode(buffer, nbytes, UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr); +#endif + } + + // When loading into a std::string or char*, accept a bytes object as-is (i.e. + // without any encoding/decoding attempt). For other C++ char sizes this is a no-op. + // which supports loading a unicode from a str, doesn't take this path. + template + bool load_bytes(enable_if_t src) { + if (PYBIND11_BYTES_CHECK(src.ptr())) { + // We were passed a Python 3 raw bytes; accept it into a std::string or char* + // without any encoding attempt. + const char *bytes = PYBIND11_BYTES_AS_STRING(src.ptr()); + if (bytes) { + value = StringType(bytes, (size_t) PYBIND11_BYTES_SIZE(src.ptr())); + return true; + } + } + + return false; + } + + template + bool load_bytes(enable_if_t) { return false; } +}; + +template +struct type_caster, enable_if_t::value>> + : string_caster> {}; + +#ifdef PYBIND11_HAS_STRING_VIEW +template +struct type_caster, enable_if_t::value>> + : string_caster, true> {}; +#endif + +// Type caster for C-style strings. We basically use a std::string type caster, but also add the +// ability to use None as a nullptr char* (which the string caster doesn't allow). 
+template <typename CharT> struct type_caster<CharT, enable_if_t<is_std_char_type<CharT>::value>> {
+    using StringType = std::basic_string<CharT>;
+    using StringCaster = type_caster<StringType>;
+    StringCaster str_caster;
+    bool none = false;
+    CharT one_char = 0;
+public:
+    bool load(handle src, bool convert) {
+        if (!src) return false;
+        if (src.is_none()) {
+            // Defer accepting None to other overloads (if we aren't in convert mode):
+            if (!convert) return false;
+            none = true;
+            return true;
+        }
+        return str_caster.load(src, convert);
+    }
+
+    static handle cast(const CharT *src, return_value_policy policy, handle parent) {
+        if (src == nullptr) return pybind11::none().inc_ref();
+        return StringCaster::cast(StringType(src), policy, parent);
+    }
+
+    static handle cast(CharT src, return_value_policy policy, handle parent) {
+        if (std::is_same<char, CharT>::value) {
+            handle s = PyUnicode_DecodeLatin1((const char *) &src, 1, nullptr);
+            if (!s) throw error_already_set();
+            return s;
+        }
+        return StringCaster::cast(StringType(1, src), policy, parent);
+    }
+
+    operator CharT*() { return none ? nullptr : const_cast<CharT *>(static_cast<StringType &>(str_caster).c_str()); }
+    operator CharT&() {
+        if (none)
+            throw value_error("Cannot convert None to a character");
+
+        auto &value = static_cast<StringType &>(str_caster);
+        size_t str_len = value.size();
+        if (str_len == 0)
+            throw value_error("Cannot convert empty string to a character");
+
+        // If we're in UTF-8 mode, we have two possible failures: one for a unicode character that
+        // is too high, and one for multiple unicode characters (caught later), so we need to figure
+        // out how long the first encoded character is in bytes to distinguish between these two
+        // errors. We also want to allow unicode characters U+0080 through U+00FF, as those
+        // can fit into a single char value.
+        if (StringCaster::UTF_N == 8 && str_len > 1 && str_len <= 4) {
+            unsigned char v0 = static_cast<unsigned char>(value[0]);
+            size_t char0_bytes = !(v0 & 0x80) ? 1 : // low bits only: 0-127
+                (v0 & 0xE0) == 0xC0 ? 2 : // 0b110xxxxx - start of 2-byte sequence
+                (v0 & 0xF0) == 0xE0 ? 3 : // 0b1110xxxx - start of 3-byte sequence
+                4; // 0b11110xxx - start of 4-byte sequence
+
+            if (char0_bytes == str_len) {
+                // If we have a 128-255 value, we can decode it into a single char:
+                if (char0_bytes == 2 && (v0 & 0xFC) == 0xC0) { // 0x110000xx 0x10xxxxxx
+                    one_char = static_cast<CharT>(((v0 & 3) << 6) + (static_cast<unsigned char>(value[1]) & 0x3F));
+                    return one_char;
+                }
+                // Otherwise we have a single character, but it's > U+00FF
+                throw value_error("Character code point not in range(0x100)");
+            }
+        }
+
+        // UTF-16 is much easier: we can only have a surrogate pair for values above U+FFFF, thus a
+        // surrogate pair with total length 2 instantly indicates a range error (but not a "your
+        // string was too long" error).
+        else if (StringCaster::UTF_N == 16 && str_len == 2) {
+            one_char = static_cast<CharT>(value[0]);
+            if (one_char >= 0xD800 && one_char < 0xE000)
+                throw value_error("Character code point not in range(0x10000)");
+        }
+
+        if (str_len != 1)
+            throw value_error("Expected a character, but multi-character string found");
+
+        one_char = value[0];
+        return one_char;
+    }
+
+    static constexpr auto name = _(PYBIND11_STRING_NAME);
+    template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
+};
+
+// Base implementation for std::tuple and std::pair
+template <template<typename...> class Tuple, typename... Ts> class tuple_caster {
+    using type = Tuple<Ts...>;
+    static constexpr auto size = sizeof...(Ts);
+    using indices = make_index_sequence<size>;
+public:
+
+    bool load(handle src, bool convert) {
+        if (!isinstance<sequence>(src))
+            return false;
+        const auto seq = reinterpret_borrow<sequence>(src);
+        if (seq.size() != size)
+            return false;
+        return load_impl(seq, convert, indices{});
+    }
+
+    template <typename T>
+    static handle cast(T &&src, return_value_policy policy, handle parent) {
+        return cast_impl(std::forward<T>(src), policy, parent, indices{});
+    }
+
+    static constexpr auto name = _("Tuple[") + concat(make_caster<Ts>::name...) + _("]");
+
+    template <typename T> using cast_op_type = type;
+
+    operator type() & { return implicit_cast(indices{}); }
+    operator type() && { return std::move(*this).implicit_cast(indices{}); }
+
+protected:
+    template <size_t... Is>
+    type implicit_cast(index_sequence<Is...>) & { return type(cast_op<Ts>(std::get<Is>(subcasters))...); }
+    template <size_t... Is>
+    type implicit_cast(index_sequence<Is...>) && { return type(cast_op<Ts>(std::move(std::get<Is>(subcasters)))...); }
+
+    static constexpr bool load_impl(const sequence &, bool, index_sequence<>) { return true; }
+
+    template <size_t... Is>
+    bool load_impl(const sequence &seq, bool convert, index_sequence<Is...>) {
+        for (bool r : {std::get<Is>(subcasters).load(seq[Is], convert)...})
+            if (!r)
+                return false;
+        return true;
+    }
+
+    /* Implementation: Convert a C++ tuple into a Python tuple */
+    template <typename T, size_t... Is>
+    static handle cast_impl(T &&src, return_value_policy policy, handle parent, index_sequence<Is...>) {
+        std::array<object, size> entries{{
+            reinterpret_steal<object>(make_caster<Ts>::cast(std::get<Is>(std::forward<T>(src)), policy, parent))...
+        }};
+        for (const auto &entry: entries)
+            if (!entry)
+                return handle();
+        tuple result(size);
+        int counter = 0;
+        for (auto & entry: entries)
+            PyTuple_SET_ITEM(result.ptr(), counter++, entry.release().ptr());
+        return result.release();
+    }
+
+    Tuple<make_caster<Ts>...> subcasters;
+};
+
+template <typename T1, typename T2> class type_caster<std::pair<T1, T2>>
+    : public tuple_caster<std::pair, T1, T2> {};
+
+template <typename... Ts> class type_caster<std::tuple<Ts...>>
+    : public tuple_caster<std::tuple, Ts...> {};
+
+/// Helper class which abstracts away certain actions. Users can provide specializations for
+/// custom holders, but it's only necessary if the type has a non-standard interface.
+template <typename T>
+struct holder_helper {
+    static auto get(const T &p) -> decltype(p.get()) { return p.get(); }
+};
+
+/// Type caster for holder types like std::shared_ptr, etc.
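+/// Illustrative sketch (added note): a hypothetical holder `MyHolder` exposing
+/// its raw pointer through a non-standard `raw()` accessor could be adapted via
+///     template <typename T> struct holder_helper<MyHolder<T>> {
+///         static const T *get(const MyHolder<T> &h) { return h.raw(); }
+///     };
+/// so that the casters below can retrieve the wrapped pointer.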
+template <typename type, typename holder_type>
+struct copyable_holder_caster : public type_caster_base<type> {
+public:
+    using base = type_caster_base<type>;
+    static_assert(std::is_base_of<base, type_caster<type>>::value,
+            "Holder classes are only supported for custom types");
+    using base::base;
+    using base::cast;
+    using base::typeinfo;
+    using base::value;
+
+    bool load(handle src, bool convert) {
+        return base::template load_impl<copyable_holder_caster<type, holder_type>>(src, convert);
+    }
+
+    explicit operator type*() { return this->value; }
+    explicit operator type&() { return *(this->value); }
+    explicit operator holder_type*() { return std::addressof(holder); }
+
+    // Workaround for Intel compiler bug
+    // see pybind11 issue 94
+    #if defined(__ICC) || defined(__INTEL_COMPILER)
+    operator holder_type&() { return holder; }
+    #else
+    explicit operator holder_type&() { return holder; }
+    #endif
+
+    static handle cast(const holder_type &src, return_value_policy, handle) {
+        const auto *ptr = holder_helper<holder_type>::get(src);
+        return type_caster_base<type>::cast_holder(ptr, &src);
+    }
+
+protected:
+    friend class type_caster_generic;
+    void check_holder_compat() {
+        if (typeinfo->default_holder)
+            throw cast_error("Unable to load a custom holder type from a default-holder instance");
+    }
+
+    bool load_value(value_and_holder &&v_h) {
+        if (v_h.holder_constructed()) {
+            value = v_h.value_ptr();
+            holder = v_h.template holder<holder_type>();
+            return true;
+        } else {
+            throw cast_error("Unable to cast from non-held to held instance (T& to Holder<T>) "
+#if defined(NDEBUG)
+                             "(compile in debug mode for type information)");
+#else
+                             "of type '" + type_id<holder_type>() + "'");
+#endif
+        }
+    }
+
+    template <typename T = holder_type, detail::enable_if_t<!std::is_constructible<T, const T &, type*>::value, int> = 0>
+    bool try_implicit_casts(handle, bool) { return false; }
+
+    template <typename T = holder_type, detail::enable_if_t<std::is_constructible<T, const T &, type*>::value, int> = 0>
+    bool try_implicit_casts(handle src, bool convert) {
+        for (auto &cast : typeinfo->implicit_casts) {
+            copyable_holder_caster sub_caster(*cast.first);
+            if (sub_caster.load(src, convert)) {
+                value = cast.second(sub_caster.value);
+                holder = holder_type(sub_caster.holder, (type *) value);
+                return true;
+            }
+        }
+        return false;
+    }
+
+    static bool try_direct_conversions(handle) { return false; }
+
+
+    holder_type holder;
+};
+
+/// Specialize for the common std::shared_ptr, so users don't need to
+template <typename T>
+class type_caster<std::shared_ptr<T>> : public copyable_holder_caster<T, std::shared_ptr<T>> { };
+
+template <typename type, typename holder_type>
+struct move_only_holder_caster {
+    static_assert(std::is_base_of<type_caster_base<type>, type_caster<type>>::value,
+            "Holder classes are only supported for custom types");
+
+    static handle cast(holder_type &&src, return_value_policy, handle) {
+        auto *ptr = holder_helper<holder_type>::get(src);
+        return type_caster_base<type>::cast_holder(ptr, std::addressof(src));
+    }
+    static constexpr auto name = type_caster_base<type>::name;
+};
+
+template <typename type, typename deleter>
+class type_caster<std::unique_ptr<type, deleter>>
+    : public move_only_holder_caster<type, std::unique_ptr<type, deleter>> { };
+
+template <typename type, typename holder_type>
+using type_caster_holder = conditional_t<is_copy_constructible<holder_type>::value,
+                                         copyable_holder_caster<type, holder_type>,
+                                         move_only_holder_caster<type, holder_type>>;
+
+template <typename T, bool Value = false> struct always_construct_holder { static constexpr bool value = Value; };
+
+/// Create a specialization for custom holder types (silently ignores std::shared_ptr)
+#define PYBIND11_DECLARE_HOLDER_TYPE(type, holder_type, ...) \
+    namespace pybind11 { namespace detail { \
+    template <typename type> \
+    struct always_construct_holder<holder_type> : always_construct_holder<void, ##__VA_ARGS__> { }; \
+    template <typename type> \
+    class type_caster<holder_type, enable_if_t<!is_shared_ptr<holder_type>::value>> \
+        : public type_caster_holder<type, holder_type> { }; \
+    }}
+
+// PYBIND11_DECLARE_HOLDER_TYPE holder types:
+template <typename base, typename holder> struct is_holder_type :
+    std::is_base_of<detail::type_caster_holder<base, holder>, detail::type_caster<holder>> {};
+// Specialization for always-supported unique_ptr holders:
+template <typename base, typename deleter> struct is_holder_type<base, std::unique_ptr<base, deleter>> :
+    std::true_type {};
+
+template <typename T> struct handle_type_name { static constexpr auto name = _<T>(); };
+template <> struct handle_type_name<bytes> { static constexpr auto name = _(PYBIND11_BYTES_NAME); };
+template <> struct handle_type_name<args> { static constexpr auto name = _("*args"); };
+template <> struct handle_type_name<kwargs> { static constexpr auto name = _("**kwargs"); };
+
+template <typename type>
+struct pyobject_caster {
+    template <typename T = type, enable_if_t<std::is_same<T, handle>::value, int> = 0>
+    bool load(handle src, bool /* convert */) { value = src; return static_cast<bool>(value); }
+
+    template <typename T = type, enable_if_t<std::is_base_of<object, T>::value, int> = 0>
+    bool load(handle src, bool /* convert */) {
+        if (!isinstance<type>(src))
+            return false;
+        value = reinterpret_borrow<type>(src);
+        return true;
+    }
+
+    static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
+        return src.inc_ref();
+    }
+    PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name);
+};
+
+template <typename T>
+class type_caster<T, enable_if_t<is_pyobject<T>::value>> : public pyobject_caster<T> { };
+
+// Our conditions for enabling moving are quite restrictive:
+// At compile time:
+// - T needs to be a non-const, non-pointer, non-reference type
+// - type_caster<T>::operator T&() must exist
+// - the type must be move constructible (obviously)
+// At run-time:
+// - if the type is non-copy-constructible, the object must be the sole owner of the type (i.e. it
+//   must have ref_count() == 1)
+// If any of the above are not satisfied, we fall back to copying.
+template <typename T> using move_is_plain_type = satisfies_none_of<T,
+    std::is_void, std::is_pointer, std::is_reference, std::is_const
+>;
+template <typename T, typename SFINAE = void> struct move_always : std::false_type {};
+template <typename T> struct move_always<T, enable_if_t<all_of<
+    move_is_plain_type<T>,
+    negation<is_copy_constructible<T>>,
+    std::is_move_constructible<T>,
+    std::is_same<decltype(std::declval<make_caster<T>>().operator T&()), T&>
+>::value>> : std::true_type {};
+template <typename T, typename SFINAE = void> struct move_if_unreferenced : std::false_type {};
+template <typename T> struct move_if_unreferenced<T, enable_if_t<all_of<
+    move_is_plain_type<T>,
+    negation<move_always<T>>,
+    std::is_move_constructible<T>,
+    std::is_same<decltype(std::declval<make_caster<T>>().operator T&()), T&>
+>::value>> : std::true_type {};
+template <typename T> using move_never = none_of<move_always<T>, move_if_unreferenced<T>>;
+
+// Detect whether returning a `type` from a cast on type's type_caster is going to result in a
+// reference or pointer to a local variable of the type_caster. Basically, only
+// non-reference/pointer `type`s and reference/pointers from a type_caster_generic are safe;
+// everything else returns a reference/pointer to a local variable.
+template <typename type> using cast_is_temporary_value_reference = bool_constant<
+    (std::is_reference<type>::value || std::is_pointer<type>::value) &&
+    !std::is_base_of<type_caster_generic, make_caster<type>>::value &&
+    !std::is_same<intrinsic_t<type>, void>::value
+>;
+
+// When a value returned from a C++ function is being cast back to Python, we almost always want to
+// force `policy = move`, regardless of the return value policy the function/method was declared
+// with.
+template <typename Return, typename SFINAE = void> struct return_value_policy_override {
+    static return_value_policy policy(return_value_policy p) { return p; }
+};
+
+template <typename Return> struct return_value_policy_override<Return,
+        detail::enable_if_t<std::is_base_of<type_caster_generic, make_caster<Return>>::value, void>> {
+    static return_value_policy policy(return_value_policy p) {
+        return !std::is_lvalue_reference<Return>::value &&
+               !std::is_pointer<Return>::value
+                   ? return_value_policy::move : p;
+    }
+};
+
+// Basic python -> C++ casting; throws if casting fails
+template <typename T, typename SFINAE> type_caster<T, SFINAE> &load_type(type_caster<T, SFINAE> &conv, const handle &handle) {
+    if (!conv.load(handle, true)) {
+#if defined(NDEBUG)
+        throw cast_error("Unable to cast Python instance to C++ type (compile in debug mode for details)");
+#else
+        throw cast_error("Unable to cast Python instance of type " +
+            (std::string) str(handle.get_type()) + " to C++ type '" + type_id<T>() + "'");
+#endif
+    }
+    return conv;
+}
+// Wrapper around the above that also constructs and returns a type_caster
+template <typename T> make_caster<T> load_type(const handle &handle) {
+    make_caster<T> conv;
+    load_type(conv, handle);
+    return conv;
+}
+
+NAMESPACE_END(detail)
+
+// pytype -> C++ type
+template <typename T, detail::enable_if_t<!detail::is_pyobject<T>::value, int> = 0>
+T cast(const handle &handle) {
+    using namespace detail;
+    static_assert(!cast_is_temporary_value_reference<T>::value,
+            "Unable to cast type to reference: value is local to type caster");
+    return cast_op<T>(load_type<T>(handle));
+}
+
+// pytype -> pytype (calls converting constructor)
+template <typename T, detail::enable_if_t<detail::is_pyobject<T>::value, int> = 0>
+T cast(const handle &handle) { return T(reinterpret_borrow<object>(handle)); }
+
+// C++ type -> py::object
+template <typename T, detail::enable_if_t<!detail::is_pyobject<T>::value, int> = 0>
+object cast(const T &value, return_value_policy policy = return_value_policy::automatic_reference,
+            handle parent = handle()) {
+    if (policy == return_value_policy::automatic)
+        policy = std::is_pointer<T>::value ? return_value_policy::take_ownership : return_value_policy::copy;
+    else if (policy == return_value_policy::automatic_reference)
+        policy = std::is_pointer<T>::value ? return_value_policy::reference : return_value_policy::copy;
+    return reinterpret_steal<object>(detail::make_caster<T>::cast(value, policy, parent));
+}
+
+template <typename T> T handle::cast() const { return pybind11::cast<T>(*this); }
+template <> inline void handle::cast() const { return; }
+
+template <typename T>
+detail::enable_if_t<!detail::move_never<T>::value, T> move(object &&obj) {
+    if (obj.ref_count() > 1)
+#if defined(NDEBUG)
+        throw cast_error("Unable to cast Python instance to C++ rvalue: instance has multiple references"
+            " (compile in debug mode for details)");
+#else
+        throw cast_error("Unable to move from Python " + (std::string) str(obj.get_type()) +
+                " instance to C++ " + type_id<T>() + " instance: instance has multiple references");
+#endif
+
+    // Move into a temporary and return that, because the reference may be a local value of `conv`
+    T ret = std::move(detail::load_type<T>(obj).operator T&());
+    return ret;
+}
+
+// Calling cast() on an rvalue calls pybind11::cast with the object rvalue, which does:
+// - If we have to move (because T has no copy constructor), do it. This will fail if the moved
+//   object has multiple references, but trying to copy will fail to compile.
+// - If both movable and copyable, check ref count: if 1, move; otherwise copy
+// - Otherwise (not movable), copy.
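+// For example (illustrative, assuming a movable but non-copyable bound type
+// `Buffer`): `std::move(obj).cast<Buffer>()` succeeds and moves the C++ value
+// out when `obj` holds the only reference, while a second live Python
+// reference to the same instance makes the identical call throw cast_error.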
+template <typename T> detail::enable_if_t<detail::move_always<T>::value, T> cast(object &&object) {
+    return move<T>(std::move(object));
+}
+template <typename T> detail::enable_if_t<detail::move_if_unreferenced<T>::value, T> cast(object &&object) {
+    if (object.ref_count() > 1)
+        return cast<T>(object);
+    else
+        return move<T>(std::move(object));
+}
+template <typename T> detail::enable_if_t<detail::move_never<T>::value, T> cast(object &&object) {
+    return cast<T>(object);
+}
+
+template <typename T> T object::cast() const & { return pybind11::cast<T>(*this); }
+template <typename T> T object::cast() && { return pybind11::cast<T>(std::move(*this)); }
+template <> inline void object::cast() const & { return; }
+template <> inline void object::cast() && { return; }
+
+NAMESPACE_BEGIN(detail)
+
+// Declared in pytypes.h:
+template <typename T, enable_if_t<!is_pyobject<T>::value, int>>
+object object_or_cast(T &&o) { return pybind11::cast(std::forward<T>(o)); }
+
+struct overload_unused {}; // Placeholder type for the unneeded (and dead code) static variable in the OVERLOAD_INT macro
+template <typename ret_type> using overload_caster_t = conditional_t<
+    cast_is_temporary_value_reference<ret_type>::value, make_caster<ret_type>, overload_unused>;
+
+// Trampoline use: for reference/pointer types to value-converted values, we do a value cast, then
+// store the result in the given variable. For other types, this is a no-op.
+template <typename T> enable_if_t<cast_is_temporary_value_reference<T>::value, T> cast_ref(object &&o, make_caster<T> &caster) {
+    return cast_op<T>(load_type(caster, o));
+}
+template <typename T> enable_if_t<!cast_is_temporary_value_reference<T>::value, T> cast_ref(object &&, overload_unused &) {
+    pybind11_fail("Internal error: cast_ref fallback invoked"); }
+
+// Trampoline use: a pybind11::cast with an invalid reference type will static_assert, even
+// if it's in dead code, so we provide a "trampoline" to pybind11::cast that only does anything in
+// cases where pybind11::cast is valid.
+template <typename T> enable_if_t<!cast_is_temporary_value_reference<T>::value, T> cast_safe(object &&o) {
+    return pybind11::cast<T>(std::move(o)); }
+template <typename T> enable_if_t<cast_is_temporary_value_reference<T>::value, T> cast_safe(object &&) {
+    pybind11_fail("Internal error: cast_safe fallback invoked"); }
+template <> inline void cast_safe<void>(object &&) {}
+
+NAMESPACE_END(detail)
+
+template <return_value_policy policy = return_value_policy::automatic_reference>
+tuple make_tuple() { return tuple(0); }
+
+template <return_value_policy policy = return_value_policy::automatic_reference,
+          typename... Args> tuple make_tuple(Args&&... args_) {
+    constexpr size_t size = sizeof...(Args);
+    std::array<object, size> args {
+        { reinterpret_steal<object>(detail::make_caster<Args>::cast(
+            std::forward<Args>(args_), policy, nullptr))... }
+    };
+    for (size_t i = 0; i < args.size(); i++) {
+        if (!args[i]) {
+#if defined(NDEBUG)
+            throw cast_error("make_tuple(): unable to convert arguments to Python object (compile in debug mode for details)");
+#else
+            std::array<std::string, size> argtypes { {type_id<Args>()...} };
+            throw cast_error("make_tuple(): unable to convert argument of type '" +
+                argtypes[i] + "' to Python object");
+#endif
+        }
+    }
+    tuple result(size);
+    int counter = 0;
+    for (auto &arg_value : args)
+        PyTuple_SET_ITEM(result.ptr(), counter++, arg_value.release().ptr());
+    return result;
+}
+
+/// \ingroup annotations
+/// Annotation for arguments
+struct arg {
+    /// Constructs an argument with the name of the argument; if null or omitted, this is a positional argument.
+    constexpr explicit arg(const char *name = nullptr) : name(name), flag_noconvert(false), flag_none(true) { }
+    /// Assign a value to this argument
+    template <typename T> arg_v operator=(T &&value) const;
+    /// Indicate that the type should not be converted in the type caster
+    arg &noconvert(bool flag = true) { flag_noconvert = flag; return *this; }
+    /// Indicates that the argument should/shouldn't allow None (e.g. for nullable pointer args)
+    arg &none(bool flag = true) { flag_none = flag; return *this; }
+
+    const char *name; ///< If non-null, this is a named kwargs argument
+    bool flag_noconvert : 1; ///< If set, do not allow conversion (requires a supporting type caster!)
+    bool flag_none : 1; ///< If set (the default), allow None to be passed to this argument
+};
+
+/// \ingroup annotations
+/// Annotation for arguments with values
+struct arg_v : arg {
+private:
+    template <typename T>
+    arg_v(arg &&base, T &&x, const char *descr = nullptr)
+        : arg(base),
+          value(reinterpret_steal<object>(
+              detail::make_caster<T>::cast(x, return_value_policy::automatic, {})
+          )),
+          descr(descr)
+#if !defined(NDEBUG)
+        , type(type_id<T>())
+#endif
+    { }
+
+public:
+    /// Direct construction with name, default, and description
+    template <typename T>
+    arg_v(const char *name, T &&x, const char *descr = nullptr)
+        : arg_v(arg(name), std::forward<T>(x), descr) { }
+
+    /// Called internally when invoking `py::arg("a") = value`
+    template <typename T>
+    arg_v(const arg &base, T &&x, const char *descr = nullptr)
+        : arg_v(arg(base), std::forward<T>(x), descr) { }
+
+    /// Same as `arg::noconvert()`, but returns *this as arg_v&, not arg&
+    arg_v &noconvert(bool flag = true) { arg::noconvert(flag); return *this; }
+
+    /// Same as `arg::none()`, but returns *this as arg_v&, not arg&
+    arg_v &none(bool flag = true) { arg::none(flag); return *this; }
+
+    /// The default value
+    object value;
+    /// The (optional) description of the default value
+    const char *descr;
+#if !defined(NDEBUG)
+    /// The C++ type name of the default value (only available when compiled in debug mode)
+    std::string type;
+#endif
+};
+
+template <typename T>
+arg_v arg::operator=(T &&value) const { return {std::move(*this), std::forward<T>(value)}; }
+
+/// Alias for backward compatibility -- to be removed in version 2.0
+template <typename T> using arg_t = arg_v;
+
+inline namespace literals {
+/** \rst
+    String literal version of `arg`
+ \endrst */
+constexpr arg operator"" _a(const char *name, size_t) { return arg(name); }
+}
+
+NAMESPACE_BEGIN(detail)
+
+// forward declaration (definition in attr.h)
+struct function_record;
+
+/// Internal data associated with a single function call
+struct function_call {
+    function_call(const function_record &f, handle p); // Implementation in attr.h
+
+    /// The function data:
+    const function_record &func;
+
+    /// Arguments passed to the function:
+    std::vector<handle> args;
+
+    /// The `convert` value the arguments should be loaded with
+    std::vector<bool> args_convert;
+
+    /// Extra references for the optional `py::args` and/or `py::kwargs` arguments (which, if
+    /// present, are also in `args` but without a reference).
+    object args_ref, kwargs_ref;
+
+    /// The parent, if any
+    handle parent;
+
+    /// If this is a call to an initializer, this argument contains `self`
+    handle init_self;
+};
+
+
+/// Helper class which loads arguments for C++ functions called from Python
+template <typename... Args>
+class argument_loader {
+    using indices = make_index_sequence<sizeof...(Args)>;
+
+    template <typename Arg> using argument_is_args   = std::is_same<intrinsic_t<Arg>, args>;
+    template <typename Arg> using argument_is_kwargs = std::is_same<intrinsic_t<Arg>, kwargs>;
+    // Get args/kwargs argument positions relative to the end of the argument list:
+    static constexpr auto args_pos = constexpr_first<argument_is_args, Args...>() - (int) sizeof...(Args),
+                        kwargs_pos = constexpr_first<argument_is_kwargs, Args...>() - (int) sizeof...(Args);
+
+    static constexpr bool args_kwargs_are_last = kwargs_pos >= - 1 && args_pos >= kwargs_pos - 1;
+
+    static_assert(args_kwargs_are_last, "py::args/py::kwargs are only permitted as the last argument(s) of a function");
+
+public:
+    static constexpr bool has_kwargs = kwargs_pos < 0;
+    static constexpr bool has_args = args_pos < 0;
+
+    static constexpr auto arg_names = concat(type_descr(make_caster<Args>::name)...);
+
+    bool load_args(function_call &call) {
+        return load_impl_sequence(call, indices{});
+    }
+
+    template <typename Return, typename Guard, typename Func>
+    enable_if_t<!std::is_void<Return>::value, Return> call(Func &&f) && {
+        return std::move(*this).template call_impl<Return>(std::forward<Func>(f), indices{}, Guard{});
+    }
+
+    template <typename Return, typename Guard, typename Func>
+    enable_if_t<std::is_void<Return>::value, void_type> call(Func &&f) && {
+        std::move(*this).template call_impl<void_type>(std::forward<Func>(f), indices{}, Guard{});
+        return void_type();
+    }
+
+private:
+
+    static bool load_impl_sequence(function_call &, index_sequence<>) { return true; }
+
+    template <size_t... Is>
+    bool load_impl_sequence(function_call &call, index_sequence<Is...>) {
+        for (bool r : {std::get<Is>(argcasters).load(call.args[Is], call.args_convert[Is])...})
+            if (!r)
+                return false;
+        return true;
+    }
+
+    template <typename Return, size_t... Is, typename Func, typename Guard>
+    Return call_impl(Func &&f, index_sequence<Is...>, Guard &&) {
+        return std::forward<Func>(f)(cast_op<Args>(std::move(std::get<Is>(argcasters)))...);
+    }
+
+    std::tuple<make_caster<Args>...> argcasters;
+};
+
+/// Helper class which collects only positional arguments for a Python function call.
+/// A fancier version below can collect any argument, but this one is optimal for simple calls.
+template <return_value_policy policy>
+class simple_collector {
+public:
+    template <typename... Ts>
+    explicit simple_collector(Ts &&...values)
+        : m_args(pybind11::make_tuple<policy>(std::forward<Ts>(values)...)) { }
+
+    const tuple &args() const & { return m_args; }
+    dict kwargs() const { return {}; }
+
+    tuple args() && { return std::move(m_args); }
+
+    /// Call a Python function and pass the collected arguments
+    object call(PyObject *ptr) const {
+        PyObject *result = PyObject_CallObject(ptr, m_args.ptr());
+        if (!result)
+            throw error_already_set();
+        return reinterpret_steal<object>(result);
+    }
+
+private:
+    tuple m_args;
+};
+
+/// Helper class which collects positional, keyword, * and ** arguments for a Python function call
+template <return_value_policy policy>
+class unpacking_collector {
+public:
+    template <typename... Ts>
+    explicit unpacking_collector(Ts &&...values) {
+        // Tuples aren't (easily) resizable so a list is needed for collection,
+        // but the actual function call strictly requires a tuple.
+        auto args_list = list();
+        int _[] = { 0, (process(args_list, std::forward<Ts>(values)), 0)... };
+        ignore_unused(_);
+
+        m_args = std::move(args_list);
+    }
+
+    const tuple &args() const & { return m_args; }
+    const dict &kwargs() const & { return m_kwargs; }
+
+    tuple args() && { return std::move(m_args); }
+    dict kwargs() && { return std::move(m_kwargs); }
+
+    /// Call a Python function and pass the collected arguments
+    object call(PyObject *ptr) const {
+        PyObject *result = PyObject_Call(ptr, m_args.ptr(), m_kwargs.ptr());
+        if (!result)
+            throw error_already_set();
+        return reinterpret_steal<object>(result);
+    }
+
+private:
+    template <typename T>
+    void process(list &args_list, T &&x) {
+        auto o = reinterpret_steal<object>(detail::make_caster<T>::cast(std::forward<T>(x), policy, {}));
+        if (!o) {
+#if defined(NDEBUG)
+            argument_cast_error();
+#else
+            argument_cast_error(std::to_string(args_list.size()), type_id<T>());
+#endif
+        }
+        args_list.append(o);
+    }
+
+    void process(list &args_list, detail::args_proxy ap) {
+        for (const auto &a : ap)
+            args_list.append(a);
+    }
+
+    void process(list &/*args_list*/, arg_v a) {
+        if (!a.name)
+#if defined(NDEBUG)
+            nameless_argument_error();
+#else
+            nameless_argument_error(a.type);
+#endif
+
+        if (m_kwargs.contains(a.name)) {
+#if defined(NDEBUG)
+            multiple_values_error();
+#else
+            multiple_values_error(a.name);
+#endif
+        }
+        if (!a.value) {
+#if defined(NDEBUG)
+            argument_cast_error();
+#else
+            argument_cast_error(a.name, a.type);
+#endif
+        }
+        m_kwargs[a.name] = a.value;
+    }
+
+    void process(list &/*args_list*/, detail::kwargs_proxy kp) {
+        if (!kp)
+            return;
+        for (const auto &k : reinterpret_borrow<dict>(kp)) {
+            if (m_kwargs.contains(k.first)) {
+#if defined(NDEBUG)
+                multiple_values_error();
+#else
+                multiple_values_error(str(k.first));
+#endif
+            }
+            m_kwargs[k.first] = k.second;
+        }
+    }
+
+    [[noreturn]] static void nameless_argument_error() {
+        throw type_error("Got kwargs without a name; only named arguments "
+                         "may be passed via py::arg() to a python function call. "
+                         "(compile in debug mode for details)");
+    }
+    [[noreturn]] static void nameless_argument_error(std::string type) {
+        throw type_error("Got kwargs without a name of type '" + type + "'; only named "
+                         "arguments may be passed via py::arg() to a python function call. ");
+    }
"); + } + [[noreturn]] static void multiple_values_error() { + throw type_error("Got multiple values for keyword argument " + "(compile in debug mode for details)"); + } + + [[noreturn]] static void multiple_values_error(std::string name) { + throw type_error("Got multiple values for keyword argument '" + name + "'"); + } + + [[noreturn]] static void argument_cast_error() { + throw cast_error("Unable to convert call argument to Python object " + "(compile in debug mode for details)"); + } + + [[noreturn]] static void argument_cast_error(std::string name, std::string type) { + throw cast_error("Unable to convert call argument '" + name + + "' of type '" + type + "' to Python object"); + } + +private: + tuple m_args; + dict m_kwargs; +}; + +/// Collect only positional arguments for a Python function call +template ...>::value>> +simple_collector collect_arguments(Args &&...args) { + return simple_collector(std::forward(args)...); +} + +/// Collect all arguments, including keywords and unpacking (only instantiated when needed) +template ...>::value>> +unpacking_collector collect_arguments(Args &&...args) { + // Following argument order rules for generalized unpacking according to PEP 448 + static_assert( + constexpr_last() < constexpr_first() + && constexpr_last() < constexpr_first(), + "Invalid function call: positional args must precede keywords and ** unpacking; " + "* unpacking must precede ** unpacking" + ); + return unpacking_collector(std::forward(args)...); +} + +template +template +object object_api::operator()(Args &&...args) const { + return detail::collect_arguments(std::forward(args)...).call(derived().ptr()); +} + +template +template +object object_api::call(Args &&...args) const { + return operator()(std::forward(args)...); +} + +NAMESPACE_END(detail) + +#define PYBIND11_MAKE_OPAQUE(...) \ + namespace pybind11 { namespace detail { \ + template<> class type_caster<__VA_ARGS__> : public type_caster_base<__VA_ARGS__> { }; \ + }} + +/// Lets you pass a type containing a `,` through a macro parameter without needing a separate +/// typedef, e.g.: `PYBIND11_OVERLOAD(PYBIND11_TYPE(ReturnType), PYBIND11_TYPE(Parent), f, arg)` +#define PYBIND11_TYPE(...) __VA_ARGS__ + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/chrono.h b/mmocr/models/textdet/postprocess/include/pybind11/chrono.h new file mode 100644 index 00000000..95ada76e --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/chrono.h @@ -0,0 +1,162 @@ +/* + pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime + + Copyright (c) 2016 Trent Houliston and + Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
+*/
+
+#pragma once
+
+#include "pybind11.h"
+#include <cmath>
+#include <ctime>
+#include <chrono>
+#include <datetime.h>
+
+// Backport the PyDateTime_DELTA functions from Python3.3 if required
+#ifndef PyDateTime_DELTA_GET_DAYS
+#define PyDateTime_DELTA_GET_DAYS(o)         (((PyDateTime_Delta*)o)->days)
+#endif
+#ifndef PyDateTime_DELTA_GET_SECONDS
+#define PyDateTime_DELTA_GET_SECONDS(o)      (((PyDateTime_Delta*)o)->seconds)
+#endif
+#ifndef PyDateTime_DELTA_GET_MICROSECONDS
+#define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds)
+#endif
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+template <typename type> class duration_caster {
+public:
+    typedef typename type::rep rep;
+    typedef typename type::period period;
+
+    typedef std::chrono::duration<uint_fast32_t, std::ratio<86400>> days;
+
+    bool load(handle src, bool) {
+        using namespace std::chrono;
+
+        // Lazy initialise the PyDateTime import
+        if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
+
+        if (!src) return false;
+        // If invoked with datetime.delta object
+        if (PyDelta_Check(src.ptr())) {
+            value = type(duration_cast<duration<rep, period>>(
+                  days(PyDateTime_DELTA_GET_DAYS(src.ptr()))
+                + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr()))
+                + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
+            return true;
+        }
+        // If invoked with a float we assume it is seconds and convert
+        else if (PyFloat_Check(src.ptr())) {
+            value = type(duration_cast<duration<rep, period>>(duration<double>(PyFloat_AsDouble(src.ptr()))));
+            return true;
+        }
+        else return false;
+    }
+
+    // If this is a duration just return it back
+    static const std::chrono::duration<rep, period>& get_duration(const std::chrono::duration<rep, period> &src) {
+        return src;
+    }
+
+    // If this is a time_point get the time_since_epoch
+    template <typename Clock> static std::chrono::duration<rep, period> get_duration(const std::chrono::time_point<Clock, std::chrono::duration<rep, period>> &src) {
+        return src.time_since_epoch();
+    }
+
+    static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) {
+        using namespace std::chrono;
+
+        // Use overloaded function to get our duration from our source
+        // Works out if it is a duration or time_point and get the duration
+        auto d = get_duration(src);
+
+        // Lazy initialise the PyDateTime import
+        if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
+
+        // Declare these special duration types so the conversions happen with the correct primitive types (int)
+        using dd_t = duration<int, std::ratio<86400>>;
+        using ss_t = duration<int, std::ratio<1>>;
+        using us_t = duration<int, std::micro>;
+
+        auto dd = duration_cast<dd_t>(d);
+        auto subd = d - dd;
+        auto ss = duration_cast<ss_t>(subd);
+        auto us = duration_cast<us_t>(subd - ss);
+        return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
+    }
+
+    PYBIND11_TYPE_CASTER(type, _("datetime.timedelta"));
+};
+
+// This is for casting times on the system clock into datetime.datetime instances
+template <typename Duration> class type_caster<std::chrono::time_point<std::chrono::system_clock, Duration>> {
+public:
+    typedef std::chrono::time_point<std::chrono::system_clock, Duration> type;
+    bool load(handle src, bool) {
+        using namespace std::chrono;
+
+        // Lazy initialise the PyDateTime import
+        if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
+
+        if (!src) return false;
+        if (PyDateTime_Check(src.ptr())) {
+            std::tm cal;
+            cal.tm_sec   = PyDateTime_DATE_GET_SECOND(src.ptr());
+            cal.tm_min   = PyDateTime_DATE_GET_MINUTE(src.ptr());
+            cal.tm_hour  = PyDateTime_DATE_GET_HOUR(src.ptr());
+            cal.tm_mday  = PyDateTime_GET_DAY(src.ptr());
+            cal.tm_mon   = PyDateTime_GET_MONTH(src.ptr()) - 1;
+            cal.tm_year  = PyDateTime_GET_YEAR(src.ptr()) - 1900;
+            cal.tm_isdst = -1;
+
+            value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
+            return true;
+        }
+        else return false;
+    }
+
+    static handle cast(const std::chrono::time_point<std::chrono::system_clock, Duration>
&src, return_value_policy /* policy */, handle /* parent */) {
+        using namespace std::chrono;
+
+        // Lazy initialise the PyDateTime import
+        if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
+
+        std::time_t tt = system_clock::to_time_t(src);
+        // this function uses static memory so it's best to copy it out asap just in case
+        // otherwise other code that is using localtime may break this (not just python code)
+        std::tm localtime = *std::localtime(&tt);
+
+        // Declare these special duration types so the conversions happen with the correct primitive types (int)
+        using us_t = duration<int, std::micro>;
+
+        return PyDateTime_FromDateAndTime(localtime.tm_year + 1900,
+                                          localtime.tm_mon + 1,
+                                          localtime.tm_mday,
+                                          localtime.tm_hour,
+                                          localtime.tm_min,
+                                          localtime.tm_sec,
+                                          (duration_cast<us_t>(src.time_since_epoch() % seconds(1))).count());
+    }
+    PYBIND11_TYPE_CASTER(type, _("datetime.datetime"));
+};
+
+// Other clocks that are not the system clock are not measured as datetime.datetime objects
+// since they are not measured on calendar time. So instead we just make them timedeltas
+// Or if they have passed us a time as a float we convert that
+template <typename Clock, typename Duration> class type_caster<std::chrono::time_point<Clock, Duration>>
+: public duration_caster<std::chrono::time_point<Clock, Duration>> {
+};
+
+template <typename Rep, typename Period> class type_caster<std::chrono::duration<Rep, Period>>
+: public duration_caster<std::chrono::duration<Rep, Period>> {
+};
+
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/mmocr/models/textdet/postprocess/include/pybind11/class_support.h b/mmocr/models/textdet/postprocess/include/pybind11/class_support.h
new file mode 100644
index 00000000..8e18c4c6
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/include/pybind11/class_support.h
@@ -0,0 +1,603 @@
+/*
+    pybind11/class_support.h: Python C API implementation details for py::class_
+
+    Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+    All rights reserved. Use of this source code is governed by a
+    BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "attr.h"
+
+NAMESPACE_BEGIN(pybind11)
+NAMESPACE_BEGIN(detail)
+
+inline PyTypeObject *type_incref(PyTypeObject *type) {
+    Py_INCREF(type);
+    return type;
+}
+
+#if !defined(PYPY_VERSION)
+
+/// `pybind11_static_property.__get__()`: Always pass the class instead of the instance.
+extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) {
+    return PyProperty_Type.tp_descr_get(self, cls, cls);
+}
+
+/// `pybind11_static_property.__set__()`: Just like the above `__get__()`.
+extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) {
+    PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj);
+    return PyProperty_Type.tp_descr_set(self, cls, value);
+}
+
+/** A `static_property` is the same as a `property` but the `__get__()` and `__set__()`
+    methods are modified to always use the object type instead of a concrete instance.
+    Return value: New reference.
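+
+    Illustrative consequence (added note): for a bound class `Pet` with a
+    static property `count`, both `Pet.count` and `Pet().count` resolve via
+    the class, so assigning through an instance updates the shared C++
+    static member instead of creating an instance attribute.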
*/
+inline PyTypeObject *make_static_property_type() {
+    constexpr auto *name = "pybind11_static_property";
+    auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
+
+    /* Danger zone: from now (and until PyType_Ready), make sure to
+       issue no Python C API calls which could potentially invoke the
+       garbage collector (the GC will call type_traverse(), which will in
+       turn find the newly constructed type in an invalid state) */
+    auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
+    if (!heap_type)
+        pybind11_fail("make_static_property_type(): error allocating type!");
+
+    heap_type->ht_name = name_obj.inc_ref().ptr();
+#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3
+    heap_type->ht_qualname = name_obj.inc_ref().ptr();
+#endif
+
+    auto type = &heap_type->ht_type;
+    type->tp_name = name;
+    type->tp_base = type_incref(&PyProperty_Type);
+    type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
+    type->tp_descr_get = pybind11_static_get;
+    type->tp_descr_set = pybind11_static_set;
+
+    if (PyType_Ready(type) < 0)
+        pybind11_fail("make_static_property_type(): failure in PyType_Ready()!");
+
+    setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
+
+    return type;
+}
+
+#else // PYPY
+
+/** PyPy has some issues with the above C API, so we evaluate Python code instead.
+    This function will only be called once so performance isn't really a concern.
+    Return value: New reference. */
+inline PyTypeObject *make_static_property_type() {
+    auto d = dict();
+    PyObject *result = PyRun_String(R"(\
+        class pybind11_static_property(property):
+            def __get__(self, obj, cls):
+                return property.__get__(self, cls, cls)
+
+            def __set__(self, obj, value):
+                cls = obj if isinstance(obj, type) else type(obj)
+                property.__set__(self, cls, value)
+        )", Py_file_input, d.ptr(), d.ptr()
+    );
+    if (result == nullptr)
+        throw error_already_set();
+    Py_DECREF(result);
+    return (PyTypeObject *) d["pybind11_static_property"].cast<object>().release().ptr();
+}
+
+#endif // PYPY
+
+/** Types with static properties need to handle `Type.static_prop = x` in a specific way.
+    By default, Python replaces the `static_property` itself, but for wrapped C++ types
+    we need to call `static_property.__set__()` in order to propagate the new value to
+    the underlying C++ data structure. */
+extern "C" inline int pybind11_meta_setattro(PyObject* obj, PyObject* name, PyObject* value) {
+    // Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw
+    // descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`).
+    PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
+
+    // The following assignment combinations are possible:
+    //   1. `Type.static_prop = value`             --> descr_set: `Type.static_prop.__set__(value)`
+    //   2. `Type.static_prop = other_static_prop` --> setattro:  replace existing `static_prop`
+    //   3. `Type.regular_attribute = value`       --> setattro:  regular attribute assignment
+    const auto static_prop = (PyObject *) get_internals().static_property_type;
+    const auto call_descr_set = descr && PyObject_IsInstance(descr, static_prop)
+                                && !PyObject_IsInstance(value, static_prop);
+    if (call_descr_set) {
+        // Call `static_property.__set__()` instead of replacing the `static_property`.
+#if !defined(PYPY_VERSION) + return Py_TYPE(descr)->tp_descr_set(descr, obj, value); +#else + if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) { + Py_DECREF(result); + return 0; + } else { + return -1; + } +#endif + } else { + // Replace existing attribute. + return PyType_Type.tp_setattro(obj, name, value); + } +} + +#if PY_MAJOR_VERSION >= 3 +/** + * Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing + * methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function, + * when called on a class, or a PyMethod, when called on an instance. Override that behaviour here + * to do a special case bypass for PyInstanceMethod_Types. + */ +extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) { + PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); + if (descr && PyInstanceMethod_Check(descr)) { + Py_INCREF(descr); + return descr; + } + else { + return PyType_Type.tp_getattro(obj, name); + } +} +#endif + +/** This metaclass is assigned by default to all pybind11 types and is required in order + for static properties to function correctly. Users may override this using `py::metaclass`. + Return value: New reference. */ +inline PyTypeObject* make_default_metaclass() { + constexpr auto *name = "pybind11_type"; + auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); + if (!heap_type) + pybind11_fail("make_default_metaclass(): error allocating metaclass!"); + + heap_type->ht_name = name_obj.inc_ref().ptr(); +#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 + heap_type->ht_qualname = name_obj.inc_ref().ptr(); +#endif + + auto type = &heap_type->ht_type; + type->tp_name = name; + type->tp_base = type_incref(&PyType_Type); + type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + + type->tp_setattro = pybind11_meta_setattro; +#if PY_MAJOR_VERSION >= 3 + type->tp_getattro = pybind11_meta_getattro; +#endif + + if (PyType_Ready(type) < 0) + pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!"); + + setattr((PyObject *) type, "__module__", str("pybind11_builtins")); + + return type; +} + +/// For multiple inheritance types we need to recursively register/deregister base pointers for any +/// base classes with pointers that are difference from the instance value pointer so that we can +/// correctly recognize an offset base class pointer. This calls a function with any offset base ptrs. 
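+/// Illustrative layout (added note): for `struct C : A, B {}` bound with
+/// multiple inheritance, the `B` subobject usually sits at a non-zero offset
+/// inside `C`, so its `B*` differs from the instance's `C*` and must be
+/// registered separately for pointer lookups through `B` to find the instance.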
+inline void traverse_offset_bases(void *valueptr, const detail::type_info *tinfo, instance *self, + bool (*f)(void * /*parentptr*/, instance * /*self*/)) { + for (handle h : reinterpret_borrow(tinfo->type->tp_bases)) { + if (auto parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) { + for (auto &c : parent_tinfo->implicit_casts) { + if (c.first == tinfo->cpptype) { + auto *parentptr = c.second(valueptr); + if (parentptr != valueptr) + f(parentptr, self); + traverse_offset_bases(parentptr, parent_tinfo, self, f); + break; + } + } + } + } +} + +inline bool register_instance_impl(void *ptr, instance *self) { + get_internals().registered_instances.emplace(ptr, self); + return true; // unused, but gives the same signature as the deregister func +} +inline bool deregister_instance_impl(void *ptr, instance *self) { + auto ®istered_instances = get_internals().registered_instances; + auto range = registered_instances.equal_range(ptr); + for (auto it = range.first; it != range.second; ++it) { + if (Py_TYPE(self) == Py_TYPE(it->second)) { + registered_instances.erase(it); + return true; + } + } + return false; +} + +inline void register_instance(instance *self, void *valptr, const type_info *tinfo) { + register_instance_impl(valptr, self); + if (!tinfo->simple_ancestors) + traverse_offset_bases(valptr, tinfo, self, register_instance_impl); +} + +inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) { + bool ret = deregister_instance_impl(valptr, self); + if (!tinfo->simple_ancestors) + traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl); + return ret; +} + +/// Instance creation function for all pybind11 types. It only allocates space for the C++ object +/// (or multiple objects, for Python-side inheritance from multiple pybind11 types), but doesn't +/// call the constructor -- an `__init__` function must do that (followed by an `init_instance` +/// to set up the holder and register the instance). +inline PyObject *make_new_instance(PyTypeObject *type, bool allocate_value /*= true (in cast.h)*/) { +#if defined(PYPY_VERSION) + // PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first inherited + // object is a a plain Python type (i.e. not derived from an extension type). Fix it. + ssize_t instance_size = static_cast(sizeof(instance)); + if (type->tp_basicsize < instance_size) { + type->tp_basicsize = instance_size; + } +#endif + PyObject *self = type->tp_alloc(type, 0); + auto inst = reinterpret_cast(self); + // Allocate the value/holder internals: + inst->allocate_layout(); + + inst->owned = true; + // Allocate (if requested) the value pointers; otherwise leave them as nullptr + if (allocate_value) { + for (auto &v_h : values_and_holders(inst)) { + void *&vptr = v_h.value_ptr(); + vptr = v_h.type->operator_new(v_h.type->type_size); + } + } + + return self; +} + +/// Instance creation function for all pybind11 types. It only allocates space for the +/// C++ object, but doesn't call the constructor -- an `__init__` function must do that. +extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) { + return make_new_instance(type); +} + +/// An `__init__` function constructs the C++ object. Users should provide at least one +/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the +/// following default function will be used which simply throws an exception. 
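+/// For example (illustrative): a class bound without any `.def(py::init<>())`
+/// raises a TypeError of the form `Name: No constructor defined!` when
+/// instantiated from Python, which is the message assembled below.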
+extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) { + PyTypeObject *type = Py_TYPE(self); + std::string msg; +#if defined(PYPY_VERSION) + msg += handle((PyObject *) type).attr("__module__").cast() + "."; +#endif + msg += type->tp_name; + msg += ": No constructor defined!"; + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return -1; +} + +inline void add_patient(PyObject *nurse, PyObject *patient) { + auto &internals = get_internals(); + auto instance = reinterpret_cast(nurse); + instance->has_patients = true; + Py_INCREF(patient); + internals.patients[nurse].push_back(patient); +} + +inline void clear_patients(PyObject *self) { + auto instance = reinterpret_cast(self); + auto &internals = get_internals(); + auto pos = internals.patients.find(self); + assert(pos != internals.patients.end()); + // Clearing the patients can cause more Python code to run, which + // can invalidate the iterator. Extract the vector of patients + // from the unordered_map first. + auto patients = std::move(pos->second); + internals.patients.erase(pos); + instance->has_patients = false; + for (PyObject *&patient : patients) + Py_CLEAR(patient); +} + +/// Clears all internal data from the instance and removes it from registered instances in +/// preparation for deallocation. +inline void clear_instance(PyObject *self) { + auto instance = reinterpret_cast(self); + + // Deallocate any values/holders, if present: + for (auto &v_h : values_and_holders(instance)) { + if (v_h) { + + // We have to deregister before we call dealloc because, for virtual MI types, we still + // need to be able to get the parent pointers. + if (v_h.instance_registered() && !deregister_instance(instance, v_h.value_ptr(), v_h.type)) + pybind11_fail("pybind11_object_dealloc(): Tried to deallocate unregistered instance!"); + + if (instance->owned || v_h.holder_constructed()) + v_h.type->dealloc(v_h); + } + } + // Deallocate the value/holder layout internals: + instance->deallocate_layout(); + + if (instance->weakrefs) + PyObject_ClearWeakRefs(self); + + PyObject **dict_ptr = _PyObject_GetDictPtr(self); + if (dict_ptr) + Py_CLEAR(*dict_ptr); + + if (instance->has_patients) + clear_patients(self); +} + +/// Instance destructor function for all pybind11 types. It calls `type_info.dealloc` +/// to destroy the C++ object itself, while the rest is Python bookkeeping. +extern "C" inline void pybind11_object_dealloc(PyObject *self) { + clear_instance(self); + Py_TYPE(self)->tp_free(self); +} + +/** Create the type which can be used as a common base for all classes. This is + needed in order to satisfy Python's requirements for multiple inheritance. + Return value: New reference. 
*/ +inline PyObject *make_object_base_type(PyTypeObject *metaclass) { + constexpr auto *name = "pybind11_object"; + auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); + if (!heap_type) + pybind11_fail("make_object_base_type(): error allocating type!"); + + heap_type->ht_name = name_obj.inc_ref().ptr(); +#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 + heap_type->ht_qualname = name_obj.inc_ref().ptr(); +#endif + + auto type = &heap_type->ht_type; + type->tp_name = name; + type->tp_base = type_incref(&PyBaseObject_Type); + type->tp_basicsize = static_cast(sizeof(instance)); + type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + + type->tp_new = pybind11_object_new; + type->tp_init = pybind11_object_init; + type->tp_dealloc = pybind11_object_dealloc; + + /* Support weak references (needed for the keep_alive feature) */ + type->tp_weaklistoffset = offsetof(instance, weakrefs); + + if (PyType_Ready(type) < 0) + pybind11_fail("PyType_Ready failed in make_object_base_type():" + error_string()); + + setattr((PyObject *) type, "__module__", str("pybind11_builtins")); + + assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); + return (PyObject *) heap_type; +} + +/// dynamic_attr: Support for `d = instance.__dict__`. +extern "C" inline PyObject *pybind11_get_dict(PyObject *self, void *) { + PyObject *&dict = *_PyObject_GetDictPtr(self); + if (!dict) + dict = PyDict_New(); + Py_XINCREF(dict); + return dict; +} + +/// dynamic_attr: Support for `instance.__dict__ = dict()`. +extern "C" inline int pybind11_set_dict(PyObject *self, PyObject *new_dict, void *) { + if (!PyDict_Check(new_dict)) { + PyErr_Format(PyExc_TypeError, "__dict__ must be set to a dictionary, not a '%.200s'", + Py_TYPE(new_dict)->tp_name); + return -1; + } + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_INCREF(new_dict); + Py_CLEAR(dict); + dict = new_dict; + return 0; +} + +/// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`. +extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) { + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_VISIT(dict); + return 0; +} + +/// dynamic_attr: Allow the GC to clear the dictionary. +extern "C" inline int pybind11_clear(PyObject *self) { + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); + return 0; +} + +/// Give instances of this type a `__dict__` and opt into garbage collection. 
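+/// Illustrative usage (added note): this path is enabled for classes bound with
+///     py::class_<Pet>(m, "Pet", py::dynamic_attr());
+/// after which Python code may attach arbitrary attributes, e.g. `p.nickname = "x"`.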
+inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) { + auto type = &heap_type->ht_type; +#if defined(PYPY_VERSION) + pybind11_fail(std::string(type->tp_name) + ": dynamic attributes are " + "currently not supported in " + "conjunction with PyPy!"); +#endif + type->tp_flags |= Py_TPFLAGS_HAVE_GC; + type->tp_dictoffset = type->tp_basicsize; // place dict at the end + type->tp_basicsize += (ssize_t)sizeof(PyObject *); // and allocate enough space for it + type->tp_traverse = pybind11_traverse; + type->tp_clear = pybind11_clear; + + static PyGetSetDef getset[] = { + {const_cast("__dict__"), pybind11_get_dict, pybind11_set_dict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr} + }; + type->tp_getset = getset; +} + +/// buffer_protocol: Fill in the view as specified by flags. +extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) { + // Look for a `get_buffer` implementation in this type's info or any bases (following MRO). + type_info *tinfo = nullptr; + for (auto type : reinterpret_borrow(Py_TYPE(obj)->tp_mro)) { + tinfo = get_type_info((PyTypeObject *) type.ptr()); + if (tinfo && tinfo->get_buffer) + break; + } + if (view == nullptr || obj == nullptr || !tinfo || !tinfo->get_buffer) { + if (view) + view->obj = nullptr; + PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error"); + return -1; + } + std::memset(view, 0, sizeof(Py_buffer)); + buffer_info *info = tinfo->get_buffer(obj, tinfo->get_buffer_data); + view->obj = obj; + view->ndim = 1; + view->internal = info; + view->buf = info->ptr; + view->itemsize = info->itemsize; + view->len = view->itemsize; + for (auto s : info->shape) + view->len *= s; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) + view->format = const_cast(info->format.c_str()); + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + view->ndim = (int) info->ndim; + view->strides = &info->strides[0]; + view->shape = &info->shape[0]; + } + Py_INCREF(view->obj); + return 0; +} + +/// buffer_protocol: Release the resources of the buffer. +extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) { + delete (buffer_info *) view->internal; +} + +/// Give this type a buffer interface. +inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) { + heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer; +#if PY_MAJOR_VERSION < 3 + heap_type->ht_type.tp_flags |= Py_TPFLAGS_HAVE_NEWBUFFER; +#endif + + heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer; + heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer; +} + +/** Create a brand new Python type according to the `type_record` specification. + Return value: New reference. */ +inline PyObject* make_new_python_type(const type_record &rec) { + auto name = reinterpret_steal(PYBIND11_FROM_STRING(rec.name)); + +#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 + auto ht_qualname = name; + if (rec.scope && hasattr(rec.scope, "__qualname__")) { + ht_qualname = reinterpret_steal( + PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr())); + } +#endif + + object module; + if (rec.scope) { + if (hasattr(rec.scope, "__module__")) + module = rec.scope.attr("__module__"); + else if (hasattr(rec.scope, "__name__")) + module = rec.scope.attr("__name__"); + } + +#if !defined(PYPY_VERSION) + const auto full_name = module ? str(module).cast() + "." 
+ rec.name + : std::string(rec.name); +#else + const auto full_name = std::string(rec.name); +#endif + + char *tp_doc = nullptr; + if (rec.doc && options::show_user_defined_docstrings()) { + /* Allocate memory for docstring (using PyObject_MALLOC, since + Python will free this later on) */ + size_t size = strlen(rec.doc) + 1; + tp_doc = (char *) PyObject_MALLOC(size); + memcpy((void *) tp_doc, rec.doc, size); + } + + auto &internals = get_internals(); + auto bases = tuple(rec.bases); + auto base = (bases.size() == 0) ? internals.instance_base + : bases[0].ptr(); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto metaclass = rec.metaclass.ptr() ? (PyTypeObject *) rec.metaclass.ptr() + : internals.default_metaclass; + + auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); + if (!heap_type) + pybind11_fail(std::string(rec.name) + ": Unable to create type object!"); + + heap_type->ht_name = name.release().ptr(); +#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 + heap_type->ht_qualname = ht_qualname.release().ptr(); +#endif + + auto type = &heap_type->ht_type; + type->tp_name = strdup(full_name.c_str()); + type->tp_doc = tp_doc; + type->tp_base = type_incref((PyTypeObject *)base); + type->tp_basicsize = static_cast(sizeof(instance)); + if (bases.size() > 0) + type->tp_bases = bases.release().ptr(); + + /* Don't inherit base __init__ */ + type->tp_init = pybind11_object_init; + + /* Supported protocols */ + type->tp_as_number = &heap_type->as_number; + type->tp_as_sequence = &heap_type->as_sequence; + type->tp_as_mapping = &heap_type->as_mapping; + + /* Flags */ + type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; +#if PY_MAJOR_VERSION < 3 + type->tp_flags |= Py_TPFLAGS_CHECKTYPES; +#endif + + if (rec.dynamic_attr) + enable_dynamic_attributes(heap_type); + + if (rec.buffer_protocol) + enable_buffer_protocol(heap_type); + + if (PyType_Ready(type) < 0) + pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!"); + + assert(rec.dynamic_attr ? PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) + : !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); + + /* Register type with the parent scope */ + if (rec.scope) + setattr(rec.scope, rec.name, (PyObject *) type); + + if (module) // Needed by pydoc + setattr((PyObject *) type, "__module__", module); + + return (PyObject *) type; +} + +NAMESPACE_END(detail) +NAMESPACE_END(pybind11) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/common.h b/mmocr/models/textdet/postprocess/include/pybind11/common.h new file mode 100644 index 00000000..6c8a4f1e --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/common.h @@ -0,0 +1,2 @@ +#include "detail/common.h" +#warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'." diff --git a/mmocr/models/textdet/postprocess/include/pybind11/complex.h b/mmocr/models/textdet/postprocess/include/pybind11/complex.h new file mode 100644 index 00000000..3f896385 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/complex.h @@ -0,0 +1,65 @@ +/* + pybind11/complex.h: Complex number support + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
+*/
+
+#pragma once
+
+#include "pybind11.h"
+#include <complex>
+
+/// glibc defines I as a macro which breaks things, e.g., boost template names
+#ifdef I
+#  undef I
+#endif
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+template <typename T> struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
+    static constexpr const char c = format_descriptor<T>::c;
+    static constexpr const char value[3] = { 'Z', c, '\0' };
+    static std::string format() { return std::string(value); }
+};
+
+#ifndef PYBIND11_CPP17
+
+template <typename T> constexpr const char format_descriptor<
+    std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
+
+#endif
+
+NAMESPACE_BEGIN(detail)
+
+template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
+    static constexpr bool value = true;
+    static constexpr int index = is_fmt_numeric<T>::index + 3;
+};
+
+template <typename T> class type_caster<std::complex<T>> {
+public:
+    bool load(handle src, bool convert) {
+        if (!src)
+            return false;
+        if (!convert && !PyComplex_Check(src.ptr()))
+            return false;
+        Py_complex result = PyComplex_AsCComplex(src.ptr());
+        if (result.real == -1.0 && PyErr_Occurred()) {
+            PyErr_Clear();
+            return false;
+        }
+        value = std::complex<T>((T) result.real, (T) result.imag);
+        return true;
+    }
+
+    static handle cast(const std::complex<T> &src, return_value_policy /* policy */, handle /* parent */) {
+        return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
+    }
+
+    PYBIND11_TYPE_CASTER(std::complex<T>, _("complex"));
+};
+NAMESPACE_END(detail)
+NAMESPACE_END(PYBIND11_NAMESPACE)
diff --git a/mmocr/models/textdet/postprocess/include/pybind11/descr.h b/mmocr/models/textdet/postprocess/include/pybind11/descr.h
new file mode 100644
index 00000000..23a099cf
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/include/pybind11/descr.h
@@ -0,0 +1,185 @@
+/*
+    pybind11/descr.h: Helper type for concatenating type signatures
+    either at runtime (C++11) or compile time (C++14)
+
+    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+    All rights reserved. Use of this source code is governed by a
+    BSD-style license that can be found in the LICENSE file.
+*/ + +#pragma once + +#include "common.h" + +NAMESPACE_BEGIN(pybind11) +NAMESPACE_BEGIN(detail) + +/* Concatenate type signatures at compile time using C++14 */ +#if defined(PYBIND11_CPP14) && !defined(_MSC_VER) +#define PYBIND11_CONSTEXPR_DESCR + +template class descr { + template friend class descr; +public: + constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1]) + : descr(text, types, + make_index_sequence(), + make_index_sequence()) { } + + constexpr const char *text() const { return m_text; } + constexpr const std::type_info * const * types() const { return m_types; } + + template + constexpr descr operator+(const descr &other) const { + return concat(other, + make_index_sequence(), + make_index_sequence(), + make_index_sequence(), + make_index_sequence()); + } + +protected: + template + constexpr descr( + char const (&text) [Size1+1], + const std::type_info * const (&types) [Size2+1], + index_sequence, index_sequence) + : m_text{text[Indices1]..., '\0'}, + m_types{types[Indices2]..., nullptr } {} + + template + constexpr descr + concat(const descr &other, + index_sequence, index_sequence, + index_sequence, index_sequence) const { + return descr( + { m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' }, + { m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr } + ); + } + +protected: + char m_text[Size1 + 1]; + const std::type_info * m_types[Size2 + 1]; +}; + +template constexpr descr _(char const(&text)[Size]) { + return descr(text, { nullptr }); +} + +template struct int_to_str : int_to_str { }; +template struct int_to_str<0, Digits...> { + static constexpr auto digits = descr({ ('0' + Digits)..., '\0' }, { nullptr }); +}; + +// Ternary description (like std::conditional) +template +constexpr enable_if_t> _(char const(&text1)[Size1], char const(&)[Size2]) { + return _(text1); +} +template +constexpr enable_if_t> _(char const(&)[Size1], char const(&text2)[Size2]) { + return _(text2); +} +template +constexpr enable_if_t> _(descr d, descr) { return d; } +template +constexpr enable_if_t> _(descr, descr d) { return d; } + +template auto constexpr _() -> decltype(int_to_str::digits) { + return int_to_str::digits; +} + +template constexpr descr<1, 1> _() { + return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr }); +} + +inline constexpr descr<0, 0> concat() { return _(""); } +template auto constexpr concat(descr descr) { return descr; } +template auto constexpr concat(descr descr, Args&&... 
args) { return descr + _(", ") + concat(args...); } +template auto constexpr type_descr(descr descr) { return _("{") + descr + _("}"); } + +#define PYBIND11_DESCR constexpr auto + +#else /* Simpler C++11 implementation based on run-time memory allocation and copying */ + +class descr { +public: + PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) { + size_t nChars = len(text), nTypes = len(types); + m_text = new char[nChars]; + m_types = new const std::type_info *[nTypes]; + memcpy(m_text, text, nChars * sizeof(char)); + memcpy(m_types, types, nTypes * sizeof(const std::type_info *)); + } + + PYBIND11_NOINLINE descr operator+(descr &&d2) && { + descr r; + + size_t nChars1 = len(m_text), nTypes1 = len(m_types); + size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types); + + r.m_text = new char[nChars1 + nChars2 - 1]; + r.m_types = new const std::type_info *[nTypes1 + nTypes2 - 1]; + memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char)); + memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char)); + memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *)); + memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *)); + + delete[] m_text; delete[] m_types; + delete[] d2.m_text; delete[] d2.m_types; + + return r; + } + + char *text() { return m_text; } + const std::type_info * * types() { return m_types; } + +protected: + PYBIND11_NOINLINE descr() { } + + template static size_t len(const T *ptr) { // return length including null termination + const T *it = ptr; + while (*it++ != (T) 0) + ; + return static_cast(it - ptr); + } + + const std::type_info **m_types = nullptr; + char *m_text = nullptr; +}; + +/* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */ + +PYBIND11_NOINLINE inline descr _(const char *text) { + const std::type_info *types[1] = { nullptr }; + return descr(text, types); +} + +template PYBIND11_NOINLINE enable_if_t _(const char *text1, const char *) { return _(text1); } +template PYBIND11_NOINLINE enable_if_t _(char const *, const char *text2) { return _(text2); } +template PYBIND11_NOINLINE enable_if_t _(descr d, descr) { return d; } +template PYBIND11_NOINLINE enable_if_t _(descr, descr d) { return d; } + +template PYBIND11_NOINLINE descr _() { + const std::type_info *types[2] = { &typeid(Type), nullptr }; + return descr("%", types); +} + +template PYBIND11_NOINLINE descr _() { + const std::type_info *types[1] = { nullptr }; + return descr(std::to_string(Size).c_str(), types); +} + +PYBIND11_NOINLINE inline descr concat() { return _(""); } +PYBIND11_NOINLINE inline descr concat(descr &&d) { return d; } +template PYBIND11_NOINLINE descr concat(descr &&d, Args&&... args) { return std::move(d) + _(", ") + concat(std::forward(args)...); } +PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); } + +#define PYBIND11_DESCR ::pybind11::detail::descr +#endif + +NAMESPACE_END(detail) +NAMESPACE_END(pybind11) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/detail/class.h b/mmocr/models/textdet/postprocess/include/pybind11/detail/class.h new file mode 100644 index 00000000..7a5dd013 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/detail/class.h @@ -0,0 +1,622 @@ +/* + pybind11/detail/class.h: Python C API implementation details for py::class_ + + Copyright (c) 2017 Wenzel Jakob + + All rights reserved. 
Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "../attr.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +#if PY_VERSION_HEX >= 0x03030000 +# define PYBIND11_BUILTIN_QUALNAME +# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) +#else +// In pre-3.3 Python, we still set __qualname__ so that we can produce reliable function type +// signatures; in 3.3+ this macro expands to nothing: +# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) setattr((PyObject *) obj, "__qualname__", nameobj) +#endif + +inline PyTypeObject *type_incref(PyTypeObject *type) { + Py_INCREF(type); + return type; +} + +#if !defined(PYPY_VERSION) + +/// `pybind11_static_property.__get__()`: Always pass the class instead of the instance. +extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) { + return PyProperty_Type.tp_descr_get(self, cls, cls); +} + +/// `pybind11_static_property.__set__()`: Just like the above `__get__()`. +extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) { + PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj); + return PyProperty_Type.tp_descr_set(self, cls, value); +} + +/** A `static_property` is the same as a `property` but the `__get__()` and `__set__()` + methods are modified to always use the object type instead of a concrete instance. + Return value: New reference. */ +inline PyTypeObject *make_static_property_type() { + constexpr auto *name = "pybind11_static_property"; + auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); + if (!heap_type) + pybind11_fail("make_static_property_type(): error allocating type!"); + + heap_type->ht_name = name_obj.inc_ref().ptr(); +#ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = name_obj.inc_ref().ptr(); +#endif + + auto type = &heap_type->ht_type; + type->tp_name = name; + type->tp_base = type_incref(&PyProperty_Type); + type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + type->tp_descr_get = pybind11_static_get; + type->tp_descr_set = pybind11_static_set; + + if (PyType_Ready(type) < 0) + pybind11_fail("make_static_property_type(): failure in PyType_Ready()!"); + + setattr((PyObject *) type, "__module__", str("pybind11_builtins")); + PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); + + return type; +} + +#else // PYPY + +/** PyPy has some issues with the above C API, so we evaluate Python code instead. + This function will only be called once so performance isn't really a concern. + Return value: New reference. 
*/ +inline PyTypeObject *make_static_property_type() { + auto d = dict(); + PyObject *result = PyRun_String(R"(\ + class pybind11_static_property(property): + def __get__(self, obj, cls): + return property.__get__(self, cls, cls) + + def __set__(self, obj, value): + cls = obj if isinstance(obj, type) else type(obj) + property.__set__(self, cls, value) + )", Py_file_input, d.ptr(), d.ptr() + ); + if (result == nullptr) + throw error_already_set(); + Py_DECREF(result); + return (PyTypeObject *) d["pybind11_static_property"].cast().release().ptr(); +} + +#endif // PYPY + +/** Types with static properties need to handle `Type.static_prop = x` in a specific way. + By default, Python replaces the `static_property` itself, but for wrapped C++ types + we need to call `static_property.__set__()` in order to propagate the new value to + the underlying C++ data structure. */ +extern "C" inline int pybind11_meta_setattro(PyObject* obj, PyObject* name, PyObject* value) { + // Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw + // descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`). + PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); + + // The following assignment combinations are possible: + // 1. `Type.static_prop = value` --> descr_set: `Type.static_prop.__set__(value)` + // 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop` + // 3. `Type.regular_attribute = value` --> setattro: regular attribute assignment + const auto static_prop = (PyObject *) get_internals().static_property_type; + const auto call_descr_set = descr && PyObject_IsInstance(descr, static_prop) + && !PyObject_IsInstance(value, static_prop); + if (call_descr_set) { + // Call `static_property.__set__()` instead of replacing the `static_property`. +#if !defined(PYPY_VERSION) + return Py_TYPE(descr)->tp_descr_set(descr, obj, value); +#else + if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) { + Py_DECREF(result); + return 0; + } else { + return -1; + } +#endif + } else { + // Replace existing attribute. + return PyType_Type.tp_setattro(obj, name, value); + } +} + +#if PY_MAJOR_VERSION >= 3 +/** + * Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing + * methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function, + * when called on a class, or a PyMethod, when called on an instance. Override that behaviour here + * to do a special case bypass for PyInstanceMethod_Types. + */ +extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) { + PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); + if (descr && PyInstanceMethod_Check(descr)) { + Py_INCREF(descr); + return descr; + } + else { + return PyType_Type.tp_getattro(obj, name); + } +} +#endif + +/** This metaclass is assigned by default to all pybind11 types and is required in order + for static properties to function correctly. Users may override this using `py::metaclass`. + Return value: New reference. 
*/
+inline PyTypeObject* make_default_metaclass() {
+    constexpr auto *name = "pybind11_type";
+    auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
+
+    /* Danger zone: from now (and until PyType_Ready), make sure to
+       issue no Python C API calls which could potentially invoke the
+       garbage collector (the GC will call type_traverse(), which will in
+       turn find the newly constructed type in an invalid state) */
+    auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
+    if (!heap_type)
+        pybind11_fail("make_default_metaclass(): error allocating metaclass!");
+
+    heap_type->ht_name = name_obj.inc_ref().ptr();
+#ifdef PYBIND11_BUILTIN_QUALNAME
+    heap_type->ht_qualname = name_obj.inc_ref().ptr();
+#endif
+
+    auto type = &heap_type->ht_type;
+    type->tp_name = name;
+    type->tp_base = type_incref(&PyType_Type);
+    type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
+
+    type->tp_setattro = pybind11_meta_setattro;
+#if PY_MAJOR_VERSION >= 3
+    type->tp_getattro = pybind11_meta_getattro;
+#endif
+
+    if (PyType_Ready(type) < 0)
+        pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!");
+
+    setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
+    PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
+
+    return type;
+}
+
+/// For multiple inheritance types we need to recursively register/deregister base pointers for any
+/// base classes with pointers that are different from the instance value pointer so that we can
+/// correctly recognize an offset base class pointer. This calls a function with any offset base ptrs.
+inline void traverse_offset_bases(void *valueptr, const detail::type_info *tinfo, instance *self,
+                                  bool (*f)(void * /*parentptr*/, instance * /*self*/)) {
+    for (handle h : reinterpret_borrow<tuple>(tinfo->type->tp_bases)) {
+        if (auto parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) {
+            for (auto &c : parent_tinfo->implicit_casts) {
+                if (c.first == tinfo->cpptype) {
+                    auto *parentptr = c.second(valueptr);
+                    if (parentptr != valueptr)
+                        f(parentptr, self);
+                    traverse_offset_bases(parentptr, parent_tinfo, self, f);
+                    break;
+                }
+            }
+        }
+    }
+}
+
+inline bool register_instance_impl(void *ptr, instance *self) {
+    get_internals().registered_instances.emplace(ptr, self);
+    return true; // unused, but gives the same signature as the deregister func
+}
+inline bool deregister_instance_impl(void *ptr, instance *self) {
+    auto &registered_instances = get_internals().registered_instances;
+    auto range = registered_instances.equal_range(ptr);
+    for (auto it = range.first; it != range.second; ++it) {
+        if (Py_TYPE(self) == Py_TYPE(it->second)) {
+            registered_instances.erase(it);
+            return true;
+        }
+    }
+    return false;
+}
+
+inline void register_instance(instance *self, void *valptr, const type_info *tinfo) {
+    register_instance_impl(valptr, self);
+    if (!tinfo->simple_ancestors)
+        traverse_offset_bases(valptr, tinfo, self, register_instance_impl);
+}
+
+inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) {
+    bool ret = deregister_instance_impl(valptr, self);
+    if (!tinfo->simple_ancestors)
+        traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl);
+    return ret;
+}
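+
+// [editorial note] A brief illustration, not part of upstream pybind11: tp_new (below) only
+// reserves the value/holder slots; the C++ object itself is constructed by whatever __init__
+// the binding code registers. The `Point` name is illustrative only.
+//
+//     py::class_<Point>(m, "Point")
+//         .def(py::init<double, double>());  // registers the __init__ that constructs Point
+//
+// If no init is registered, calling `Point()` from Python falls through to
+// pybind11_object_init() below, which raises "No constructor defined!".
+
+/// Instance creation function for all pybind11 types. It allocates the internal instance layout for
+/// holding C++ objects and holders. Allocation is done lazily (the first time the instance is cast
+/// to a reference or pointer), and initialization is done by an `__init__` function.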
+inline PyObject *make_new_instance(PyTypeObject *type) {
+#if defined(PYPY_VERSION)
+    // PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first inherited
+    // object is a plain Python type (i.e. not derived from an extension type). Fix it.
+    ssize_t instance_size = static_cast<ssize_t>(sizeof(instance));
+    if (type->tp_basicsize < instance_size) {
+        type->tp_basicsize = instance_size;
+    }
+#endif
+    PyObject *self = type->tp_alloc(type, 0);
+    auto inst = reinterpret_cast<instance *>(self);
+    // Allocate the value/holder internals:
+    inst->allocate_layout();
+
+    inst->owned = true;
+
+    return self;
+}
+
+/// Instance creation function for all pybind11 types. It only allocates space for the
+/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
+extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
+    return make_new_instance(type);
+}
+
+/// An `__init__` function constructs the C++ object. Users should provide at least one
+/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the
+/// following default function will be used which simply throws an exception.
+extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) {
+    PyTypeObject *type = Py_TYPE(self);
+    std::string msg;
+#if defined(PYPY_VERSION)
+    msg += handle((PyObject *) type).attr("__module__").cast<std::string>() + ".";
+#endif
+    msg += type->tp_name;
+    msg += ": No constructor defined!";
+    PyErr_SetString(PyExc_TypeError, msg.c_str());
+    return -1;
+}
+
+inline void add_patient(PyObject *nurse, PyObject *patient) {
+    auto &internals = get_internals();
+    auto instance = reinterpret_cast<detail::instance *>(nurse);
+    instance->has_patients = true;
+    Py_INCREF(patient);
+    internals.patients[nurse].push_back(patient);
+}
+
+inline void clear_patients(PyObject *self) {
+    auto instance = reinterpret_cast<detail::instance *>(self);
+    auto &internals = get_internals();
+    auto pos = internals.patients.find(self);
+    assert(pos != internals.patients.end());
+    // Clearing the patients can cause more Python code to run, which
+    // can invalidate the iterator. Extract the vector of patients
+    // from the unordered_map first.
+    auto patients = std::move(pos->second);
+    internals.patients.erase(pos);
+    instance->has_patients = false;
+    for (PyObject *&patient : patients)
+        Py_CLEAR(patient);
+}
+
+/// Clears all internal data from the instance and removes it from registered instances in
+/// preparation for deallocation.
+inline void clear_instance(PyObject *self) {
+    auto instance = reinterpret_cast<detail::instance *>(self);
+
+    // Deallocate any values/holders, if present:
+    for (auto &v_h : values_and_holders(instance)) {
+        if (v_h) {
+
+            // We have to deregister before we call dealloc because, for virtual MI types, we still
+            // need to be able to get the parent pointers.
+            if (v_h.instance_registered() && !deregister_instance(instance, v_h.value_ptr(), v_h.type))
+                pybind11_fail("pybind11_object_dealloc(): Tried to deallocate unregistered instance!");
+
+            if (instance->owned || v_h.holder_constructed())
+                v_h.type->dealloc(v_h);
+        }
+    }
+    // Deallocate the value/holder layout internals:
+    instance->deallocate_layout();
+
+    if (instance->weakrefs)
+        PyObject_ClearWeakRefs(self);
+
+    PyObject **dict_ptr = _PyObject_GetDictPtr(self);
+    if (dict_ptr)
+        Py_CLEAR(*dict_ptr);
+
+    if (instance->has_patients)
+        clear_patients(self);
+}
+
+/// Instance destructor function for all pybind11 types.
It calls `type_info.dealloc` +/// to destroy the C++ object itself, while the rest is Python bookkeeping. +extern "C" inline void pybind11_object_dealloc(PyObject *self) { + clear_instance(self); + + auto type = Py_TYPE(self); + type->tp_free(self); + + // `type->tp_dealloc != pybind11_object_dealloc` means that we're being called + // as part of a derived type's dealloc, in which case we're not allowed to decref + // the type here. For cross-module compatibility, we shouldn't compare directly + // with `pybind11_object_dealloc`, but with the common one stashed in internals. + auto pybind11_object_type = (PyTypeObject *) get_internals().instance_base; + if (type->tp_dealloc == pybind11_object_type->tp_dealloc) + Py_DECREF(type); +} + +/** Create the type which can be used as a common base for all classes. This is + needed in order to satisfy Python's requirements for multiple inheritance. + Return value: New reference. */ +inline PyObject *make_object_base_type(PyTypeObject *metaclass) { + constexpr auto *name = "pybind11_object"; + auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); + if (!heap_type) + pybind11_fail("make_object_base_type(): error allocating type!"); + + heap_type->ht_name = name_obj.inc_ref().ptr(); +#ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = name_obj.inc_ref().ptr(); +#endif + + auto type = &heap_type->ht_type; + type->tp_name = name; + type->tp_base = type_incref(&PyBaseObject_Type); + type->tp_basicsize = static_cast(sizeof(instance)); + type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; + + type->tp_new = pybind11_object_new; + type->tp_init = pybind11_object_init; + type->tp_dealloc = pybind11_object_dealloc; + + /* Support weak references (needed for the keep_alive feature) */ + type->tp_weaklistoffset = offsetof(instance, weakrefs); + + if (PyType_Ready(type) < 0) + pybind11_fail("PyType_Ready failed in make_object_base_type():" + error_string()); + + setattr((PyObject *) type, "__module__", str("pybind11_builtins")); + PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); + + assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); + return (PyObject *) heap_type; +} + +/// dynamic_attr: Support for `d = instance.__dict__`. +extern "C" inline PyObject *pybind11_get_dict(PyObject *self, void *) { + PyObject *&dict = *_PyObject_GetDictPtr(self); + if (!dict) + dict = PyDict_New(); + Py_XINCREF(dict); + return dict; +} + +/// dynamic_attr: Support for `instance.__dict__ = dict()`. +extern "C" inline int pybind11_set_dict(PyObject *self, PyObject *new_dict, void *) { + if (!PyDict_Check(new_dict)) { + PyErr_Format(PyExc_TypeError, "__dict__ must be set to a dictionary, not a '%.200s'", + Py_TYPE(new_dict)->tp_name); + return -1; + } + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_INCREF(new_dict); + Py_CLEAR(dict); + dict = new_dict; + return 0; +} + +/// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`. +extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) { + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_VISIT(dict); + return 0; +} + +/// dynamic_attr: Allow the GC to clear the dictionary. 
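+// [editorial note] An illustrative aside, not part of upstream pybind11: tp_traverse and
+// tp_clear are what let CPython's cycle collector reclaim reference cycles routed through a
+// dynamic instance __dict__, e.g. (Python, with a py::dynamic_attr() type called `Pet`):
+//
+//     p = Pet()
+//     p.self = p    # cycle: p -> p.__dict__ -> p
+//     del p         # only a GC pass that calls these two hooks can free this cycle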
+extern "C" inline int pybind11_clear(PyObject *self) { + PyObject *&dict = *_PyObject_GetDictPtr(self); + Py_CLEAR(dict); + return 0; +} + +/// Give instances of this type a `__dict__` and opt into garbage collection. +inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) { + auto type = &heap_type->ht_type; +#if defined(PYPY_VERSION) + pybind11_fail(std::string(type->tp_name) + ": dynamic attributes are " + "currently not supported in " + "conjunction with PyPy!"); +#endif + type->tp_flags |= Py_TPFLAGS_HAVE_GC; + type->tp_dictoffset = type->tp_basicsize; // place dict at the end + type->tp_basicsize += (ssize_t)sizeof(PyObject *); // and allocate enough space for it + type->tp_traverse = pybind11_traverse; + type->tp_clear = pybind11_clear; + + static PyGetSetDef getset[] = { + {const_cast("__dict__"), pybind11_get_dict, pybind11_set_dict, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr} + }; + type->tp_getset = getset; +} + +/// buffer_protocol: Fill in the view as specified by flags. +extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) { + // Look for a `get_buffer` implementation in this type's info or any bases (following MRO). + type_info *tinfo = nullptr; + for (auto type : reinterpret_borrow(Py_TYPE(obj)->tp_mro)) { + tinfo = get_type_info((PyTypeObject *) type.ptr()); + if (tinfo && tinfo->get_buffer) + break; + } + if (view == nullptr || obj == nullptr || !tinfo || !tinfo->get_buffer) { + if (view) + view->obj = nullptr; + PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error"); + return -1; + } + std::memset(view, 0, sizeof(Py_buffer)); + buffer_info *info = tinfo->get_buffer(obj, tinfo->get_buffer_data); + view->obj = obj; + view->ndim = 1; + view->internal = info; + view->buf = info->ptr; + view->itemsize = info->itemsize; + view->len = view->itemsize; + for (auto s : info->shape) + view->len *= s; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) + view->format = const_cast(info->format.c_str()); + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + view->ndim = (int) info->ndim; + view->strides = &info->strides[0]; + view->shape = &info->shape[0]; + } + Py_INCREF(view->obj); + return 0; +} + +/// buffer_protocol: Release the resources of the buffer. +extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) { + delete (buffer_info *) view->internal; +} + +/// Give this type a buffer interface. +inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) { + heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer; +#if PY_MAJOR_VERSION < 3 + heap_type->ht_type.tp_flags |= Py_TPFLAGS_HAVE_NEWBUFFER; +#endif + + heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer; + heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer; +} + +/** Create a brand new Python type according to the `type_record` specification. + Return value: New reference. */ +inline PyObject* make_new_python_type(const type_record &rec) { + auto name = reinterpret_steal(PYBIND11_FROM_STRING(rec.name)); + + auto qualname = name; + if (rec.scope && !PyModule_Check(rec.scope.ptr()) && hasattr(rec.scope, "__qualname__")) { +#if PY_MAJOR_VERSION >= 3 + qualname = reinterpret_steal( + PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr())); +#else + qualname = str(rec.scope.attr("__qualname__").cast() + "." 
+ rec.name); +#endif + } + + object module; + if (rec.scope) { + if (hasattr(rec.scope, "__module__")) + module = rec.scope.attr("__module__"); + else if (hasattr(rec.scope, "__name__")) + module = rec.scope.attr("__name__"); + } + + auto full_name = c_str( +#if !defined(PYPY_VERSION) + module ? str(module).cast() + "." + rec.name : +#endif + rec.name); + + char *tp_doc = nullptr; + if (rec.doc && options::show_user_defined_docstrings()) { + /* Allocate memory for docstring (using PyObject_MALLOC, since + Python will free this later on) */ + size_t size = strlen(rec.doc) + 1; + tp_doc = (char *) PyObject_MALLOC(size); + memcpy((void *) tp_doc, rec.doc, size); + } + + auto &internals = get_internals(); + auto bases = tuple(rec.bases); + auto base = (bases.size() == 0) ? internals.instance_base + : bases[0].ptr(); + + /* Danger zone: from now (and until PyType_Ready), make sure to + issue no Python C API calls which could potentially invoke the + garbage collector (the GC will call type_traverse(), which will in + turn find the newly constructed type in an invalid state) */ + auto metaclass = rec.metaclass.ptr() ? (PyTypeObject *) rec.metaclass.ptr() + : internals.default_metaclass; + + auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); + if (!heap_type) + pybind11_fail(std::string(rec.name) + ": Unable to create type object!"); + + heap_type->ht_name = name.release().ptr(); +#ifdef PYBIND11_BUILTIN_QUALNAME + heap_type->ht_qualname = qualname.inc_ref().ptr(); +#endif + + auto type = &heap_type->ht_type; + type->tp_name = full_name; + type->tp_doc = tp_doc; + type->tp_base = type_incref((PyTypeObject *)base); + type->tp_basicsize = static_cast(sizeof(instance)); + if (bases.size() > 0) + type->tp_bases = bases.release().ptr(); + + /* Don't inherit base __init__ */ + type->tp_init = pybind11_object_init; + + /* Supported protocols */ + type->tp_as_number = &heap_type->as_number; + type->tp_as_sequence = &heap_type->as_sequence; + type->tp_as_mapping = &heap_type->as_mapping; + + /* Flags */ + type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; +#if PY_MAJOR_VERSION < 3 + type->tp_flags |= Py_TPFLAGS_CHECKTYPES; +#endif + + if (rec.dynamic_attr) + enable_dynamic_attributes(heap_type); + + if (rec.buffer_protocol) + enable_buffer_protocol(heap_type); + + if (PyType_Ready(type) < 0) + pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!"); + + assert(rec.dynamic_attr ? PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) + : !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); + + /* Register type with the parent scope */ + if (rec.scope) + setattr(rec.scope, rec.name, (PyObject *) type); + else + Py_INCREF(type); // Keep it alive forever (reference leak) + + if (module) // Needed by pydoc + setattr((PyObject *) type, "__module__", module); + + PYBIND11_SET_OLDPY_QUALNAME(type, qualname); + + return (PyObject *) type; +} + +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/detail/common.h b/mmocr/models/textdet/postprocess/include/pybind11/detail/common.h new file mode 100644 index 00000000..5ff74856 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/detail/common.h @@ -0,0 +1,807 @@ +/* + pybind11/detail/common.h -- Basic macros + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
+*/ + +#pragma once + +#if !defined(NAMESPACE_BEGIN) +# define NAMESPACE_BEGIN(name) namespace name { +#endif +#if !defined(NAMESPACE_END) +# define NAMESPACE_END(name) } +#endif + +// Robust support for some features and loading modules compiled against different pybind versions +// requires forcing hidden visibility on pybind code, so we enforce this by setting the attribute on +// the main `pybind11` namespace. +#if !defined(PYBIND11_NAMESPACE) +# ifdef __GNUG__ +# define PYBIND11_NAMESPACE pybind11 __attribute__((visibility("hidden"))) +# else +# define PYBIND11_NAMESPACE pybind11 +# endif +#endif + +#if !(defined(_MSC_VER) && __cplusplus == 199711L) && !defined(__INTEL_COMPILER) +# if __cplusplus >= 201402L +# define PYBIND11_CPP14 +# if __cplusplus >= 201703L +# define PYBIND11_CPP17 +# endif +# endif +#elif defined(_MSC_VER) && __cplusplus == 199711L +// MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard is fully implemented) +// Unless you use the /Zc:__cplusplus flag on Visual Studio 2017 15.7 Preview 3 or newer +# if _MSVC_LANG >= 201402L +# define PYBIND11_CPP14 +# if _MSVC_LANG > 201402L && _MSC_VER >= 1910 +# define PYBIND11_CPP17 +# endif +# endif +#endif + +// Compiler version assertions +#if defined(__INTEL_COMPILER) +# if __INTEL_COMPILER < 1700 +# error pybind11 requires Intel C++ compiler v17 or newer +# endif +#elif defined(__clang__) && !defined(__apple_build_version__) +# if __clang_major__ < 3 || (__clang_major__ == 3 && __clang_minor__ < 3) +# error pybind11 requires clang 3.3 or newer +# endif +#elif defined(__clang__) +// Apple changes clang version macros to its Xcode version; the first Xcode release based on +// (upstream) clang 3.3 was Xcode 5: +# if __clang_major__ < 5 +# error pybind11 requires Xcode/clang 5.0 or newer +# endif +#elif defined(__GNUG__) +# if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 8) +# error pybind11 requires gcc 4.8 or newer +# endif +#elif defined(_MSC_VER) +// Pybind hits various compiler bugs in 2015u2 and earlier, and also makes use of some stl features +// (e.g. 
std::negation) added in 2015u3:
+#  if _MSC_FULL_VER < 190024210
+#    error pybind11 requires MSVC 2015 update 3 or newer
+#  endif
+#endif
+
+#if !defined(PYBIND11_EXPORT)
+#  if defined(WIN32) || defined(_WIN32)
+#    define PYBIND11_EXPORT __declspec(dllexport)
+#  else
+#    define PYBIND11_EXPORT __attribute__ ((visibility("default")))
+#  endif
+#endif
+
+#if defined(_MSC_VER)
+#  define PYBIND11_NOINLINE __declspec(noinline)
+#else
+#  define PYBIND11_NOINLINE __attribute__ ((noinline))
+#endif
+
+#if defined(PYBIND11_CPP14)
+#  define PYBIND11_DEPRECATED(reason) [[deprecated(reason)]]
+#else
+#  define PYBIND11_DEPRECATED(reason) __attribute__((deprecated(reason)))
+#endif
+
+#define PYBIND11_VERSION_MAJOR 2
+#define PYBIND11_VERSION_MINOR 3
+#define PYBIND11_VERSION_PATCH dev0
+
+/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
+#if defined(_MSC_VER)
+#  if (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 4)
+#    define HAVE_ROUND 1
+#  endif
+#  pragma warning(push)
+#  pragma warning(disable: 4510 4610 4512 4005)
+#  if defined(_DEBUG)
+#    define PYBIND11_DEBUG_MARKER
+#    undef _DEBUG
+#  endif
+#endif
+
+#include <Python.h>
+#include <frameobject.h>
+#include <pythread.h>
+
+#if defined(_WIN32) && (defined(min) || defined(max))
+#  error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows
+#endif
+
+#if defined(isalnum)
+#  undef isalnum
+#  undef isalpha
+#  undef islower
+#  undef isspace
+#  undef isupper
+#  undef tolower
+#  undef toupper
+#endif
+
+#if defined(_MSC_VER)
+#  if defined(PYBIND11_DEBUG_MARKER)
+#    define _DEBUG
+#    undef PYBIND11_DEBUG_MARKER
+#  endif
+#  pragma warning(pop)
+#endif
+
+#include <cstddef>
+#include <cstring>
+#include <forward_list>
+#include <vector>
+#include <string>
+#include <stdexcept>
+#include <unordered_set>
+#include <unordered_map>
+#include <memory>
+#include <typeindex>
+#include <type_traits>
+
+#if PY_MAJOR_VERSION >= 3 /// Compatibility macros for various Python versions
+#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyInstanceMethod_New(ptr)
+#define PYBIND11_INSTANCE_METHOD_CHECK PyInstanceMethod_Check
+#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyInstanceMethod_GET_FUNCTION
+#define PYBIND11_BYTES_CHECK PyBytes_Check
+#define PYBIND11_BYTES_FROM_STRING PyBytes_FromString
+#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyBytes_FromStringAndSize
+#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyBytes_AsStringAndSize
+#define PYBIND11_BYTES_AS_STRING PyBytes_AsString
+#define PYBIND11_BYTES_SIZE PyBytes_Size
+#define PYBIND11_LONG_CHECK(o) PyLong_Check(o)
+#define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o)
+#define PYBIND11_LONG_FROM_SIGNED(o) PyLong_FromSsize_t((ssize_t) o)
+#define PYBIND11_LONG_FROM_UNSIGNED(o) PyLong_FromSize_t((size_t) o)
+#define PYBIND11_BYTES_NAME "bytes"
+#define PYBIND11_STRING_NAME "str"
+#define PYBIND11_SLICE_OBJECT PyObject
+#define PYBIND11_FROM_STRING PyUnicode_FromString
+#define PYBIND11_STR_TYPE ::pybind11::str
+#define PYBIND11_BOOL_ATTR "__bool__"
+#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
+#define PYBIND11_PLUGIN_IMPL(name) \
+    extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
+
+#else
+#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_)
+#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check
+#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyMethod_GET_FUNCTION
+#define PYBIND11_BYTES_CHECK PyString_Check
+#define PYBIND11_BYTES_FROM_STRING PyString_FromString
+#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyString_FromStringAndSize
+#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyString_AsStringAndSize
+#define PYBIND11_BYTES_AS_STRING PyString_AsString
+#define PYBIND11_BYTES_SIZE PyString_Size
+#define PYBIND11_LONG_CHECK(o) (PyInt_Check(o) || PyLong_Check(o))
+#define PYBIND11_LONG_AS_LONGLONG(o) (PyInt_Check(o) ? (long long) PyLong_AsLong(o) : PyLong_AsLongLong(o))
+#define PYBIND11_LONG_FROM_SIGNED(o) PyInt_FromSsize_t((ssize_t) o) // Returns long if needed.
+#define PYBIND11_LONG_FROM_UNSIGNED(o) PyInt_FromSize_t((size_t) o) // Returns long if needed.
+#define PYBIND11_BYTES_NAME "str"
+#define PYBIND11_STRING_NAME "unicode"
+#define PYBIND11_SLICE_OBJECT PySliceObject
+#define PYBIND11_FROM_STRING PyString_FromString
+#define PYBIND11_STR_TYPE ::pybind11::bytes
+#define PYBIND11_BOOL_ATTR "__nonzero__"
+#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero)
+#define PYBIND11_PLUGIN_IMPL(name) \
+    static PyObject *pybind11_init_wrapper();               \
+    extern "C" PYBIND11_EXPORT void init##name() {          \
+        (void)pybind11_init_wrapper();                      \
+    }                                                       \
+    PyObject *pybind11_init_wrapper()
+#endif
+
+#if PY_VERSION_HEX >= 0x03050000 && PY_VERSION_HEX < 0x03050200
+extern "C" {
+    struct _Py_atomic_address { void *value; };
+    PyAPI_DATA(_Py_atomic_address) _PyThreadState_Current;
+}
+#endif
+
+#define PYBIND11_TRY_NEXT_OVERLOAD ((PyObject *) 1) // special failure return code
+#define PYBIND11_STRINGIFY(x) #x
+#define PYBIND11_TOSTRING(x) PYBIND11_STRINGIFY(x)
+#define PYBIND11_CONCAT(first, second) first##second
+
+#define PYBIND11_CHECK_PYTHON_VERSION \
+    { \
+        const char *compiled_ver = PYBIND11_TOSTRING(PY_MAJOR_VERSION) \
+            "." PYBIND11_TOSTRING(PY_MINOR_VERSION); \
+        const char *runtime_ver = Py_GetVersion(); \
+        size_t len = std::strlen(compiled_ver); \
+        if (std::strncmp(runtime_ver, compiled_ver, len) != 0 \
+                || (runtime_ver[len] >= '0' && runtime_ver[len] <= '9')) { \
+            PyErr_Format(PyExc_ImportError, \
+                "Python version mismatch: module was compiled for Python %s, " \
+                "but the interpreter version is incompatible: %s.", \
+                compiled_ver, runtime_ver); \
+            return nullptr; \
+        } \
+    }
+
+#define PYBIND11_CATCH_INIT_EXCEPTIONS \
+    catch (pybind11::error_already_set &e) { \
+        PyErr_SetString(PyExc_ImportError, e.what()); \
+        return nullptr; \
+    } catch (const std::exception &e) { \
+        PyErr_SetString(PyExc_ImportError, e.what()); \
+        return nullptr; \
+    } \
+
+/** \rst
+    ***Deprecated in favor of PYBIND11_MODULE***
+
+    This macro creates the entry point that will be invoked when the Python interpreter
+    imports a plugin library. Please create a `module` in the function body and return
+    the pointer to its underlying Python object at the end.
+
+    .. code-block:: cpp
+
+        PYBIND11_PLUGIN(example) {
+            pybind11::module m("example", "pybind11 example plugin");
+            /// Set up bindings here
+            return m.ptr();
+        }
+\endrst */
+#define PYBIND11_PLUGIN(name) \
+    PYBIND11_DEPRECATED("PYBIND11_PLUGIN is deprecated, use PYBIND11_MODULE") \
+    static PyObject *pybind11_init(); \
+    PYBIND11_PLUGIN_IMPL(name) { \
+        PYBIND11_CHECK_PYTHON_VERSION \
+        try { \
+            return pybind11_init(); \
+        } PYBIND11_CATCH_INIT_EXCEPTIONS \
+    } \
+    PyObject *pybind11_init()
+
+/** \rst
+    This macro creates the entry point that will be invoked when the Python interpreter
+    imports an extension module. The module name is given as the first argument and it
+    should not be in quotes. The second macro argument defines a variable of type
+    `py::module` which can be used to initialize the module.
+
+    .. code-block:: cpp
+
+        PYBIND11_MODULE(example, m) {
+            m.doc() = "pybind11 example module";
+
+            // Add bindings here
+            m.def("foo", []() {
+                return "Hello, World!";
+            });
+        }
+\endrst */
+#define PYBIND11_MODULE(name, variable) \
+    static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
+    PYBIND11_PLUGIN_IMPL(name) { \
+        PYBIND11_CHECK_PYTHON_VERSION \
+        auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
+        try { \
+            PYBIND11_CONCAT(pybind11_init_, name)(m); \
+            return m.ptr(); \
+        } PYBIND11_CATCH_INIT_EXCEPTIONS \
+    } \
+    void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
+
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+
+using ssize_t = Py_ssize_t;
+using size_t  = std::size_t;
+
+/// Approach used to cast a previously unknown C++ instance into a Python object
+enum class return_value_policy : uint8_t {
+    /** This is the default return value policy, which falls back to the policy
+        return_value_policy::take_ownership when the return value is a pointer.
+        Otherwise, it uses return_value_policy::move or return_value_policy::copy
+        for rvalue and lvalue references, respectively. See below for a description
+        of what all of these different policies do. */
+    automatic = 0,
+
+    /** As above, but use policy return_value_policy::reference when the return
+        value is a pointer. This is the default conversion policy for function
+        arguments when calling Python functions manually from C++ code (i.e. via
+        handle::operator()). You probably won't need to use this. */
+    automatic_reference,
+
+    /** Reference an existing object (i.e. do not create a new copy) and take
+        ownership. Python will call the destructor and delete operator when the
+        object's reference count reaches zero. Undefined behavior ensues when
+        the C++ side does the same. */
+    take_ownership,
+
+    /** Create a new copy of the returned object, which will be owned by
+        Python. This policy is comparably safe because the lifetimes of the two
+        instances are decoupled. */
+    copy,
+
+    /** Use std::move to move the return value contents into a new instance
+        that will be owned by Python. This policy is comparably safe because the
+        lifetimes of the two instances (move source and destination) are
+        decoupled. */
+    move,
+
+    /** Reference an existing object, but do not take ownership. The C++ side
+        is responsible for managing the object's lifetime and deallocating it
+        when it is no longer used. Warning: undefined behavior will ensue when
+        the C++ side deletes an object that is still referenced and used by
+        Python. */
+    reference,
+
+    /** This policy only applies to methods and properties. It references the
+        object without taking ownership similar to the above
+        return_value_policy::reference policy. In contrast to that policy, the
+        function or property's implicit this argument (called the parent) is
+        considered to be the owner of the return value (the child).
+        pybind11 then couples the lifetime of the parent to the child via a
+        reference relationship that ensures that the parent cannot be garbage
+        collected while Python is still using the child. More advanced
+        variations of this scheme are also possible using combinations of
+        return_value_policy::reference and the keep_alive call policy */
+    reference_internal
+};
+
+NAMESPACE_BEGIN(detail)
+
+inline static constexpr int log2(size_t n, int k = 0) { return (n <= 1) ? k : log2(n >> 1, k + 1); }
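+
+// [editorial note] A quick sanity check, not part of upstream pybind11: log2() above is a
+// constexpr integer log, so e.g.
+//
+//     static_assert(log2(1) == 0 && log2(8) == 3, "integer log2");
+//
+// size_in_ptrs() below uses it to round a byte count up to whole pointers: on a 64-bit
+// platform, size_in_ptrs(9) == 2 (9 bytes need two 8-byte slots).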
+// Returns the size as a multiple of sizeof(void *), rounded up.
+inline static constexpr size_t size_in_ptrs(size_t s) { return 1 + ((s - 1) >> log2(sizeof(void *))); }
+
+/**
+ * The space to allocate for simple layout instance holders (see below) in multiple of the size of
+ * a pointer (e.g. 2 means 16 bytes on 64-bit architectures). The default is the minimum required
+ * to hold either a std::unique_ptr or std::shared_ptr (which is almost always
+ * sizeof(std::shared_ptr<T>)).
+ */
+constexpr size_t instance_simple_holder_in_ptrs() {
+    static_assert(sizeof(std::shared_ptr<int>) >= sizeof(std::unique_ptr<int>),
+            "pybind assumes std::shared_ptrs are at least as big as std::unique_ptrs");
+    return size_in_ptrs(sizeof(std::shared_ptr<int>));
+}
+
+// Forward declarations
+struct type_info;
+struct value_and_holder;
+
+struct nonsimple_values_and_holders {
+    void **values_and_holders;
+    uint8_t *status;
+};
+
+/// The 'instance' type which needs to be standard layout (need to be able to use 'offsetof')
+struct instance {
+    PyObject_HEAD
+    /// Storage for pointers and holder; see simple_layout, below, for a description
+    union {
+        void *simple_value_holder[1 + instance_simple_holder_in_ptrs()];
+        nonsimple_values_and_holders nonsimple;
+    };
+    /// Weak references
+    PyObject *weakrefs;
+    /// If true, the pointer is owned which means we're free to manage it with a holder.
+    bool owned : 1;
+    /**
+     * An instance has two possible value/holder layouts.
+     *
+     * Simple layout (when this flag is true), means the `simple_value_holder` is set with a pointer
+     * and the holder object governing that pointer, i.e. [val1*][holder]. This layout is applied
+     * whenever there is no python-side multiple inheritance of bound C++ types *and* the type's
+     * holder will fit in the default space (which is large enough to hold either a std::unique_ptr
+     * or std::shared_ptr).
+     *
+     * Non-simple layout applies when using custom holders that require more space than `shared_ptr`
+     * (which is typically the size of two pointers), or when multiple inheritance is used on the
+     * python side. Non-simple layout allocates the required amount of memory to have multiple
+     * bound C++ classes as parents. Under this layout, `nonsimple.values_and_holders` is set to a
+     * pointer to allocated space of the required space to hold a sequence of value pointers and
+     * holders followed by `status`, a set of bit flags (1 byte each), i.e.
+     * [val1*][holder1][val2*][holder2]...[bb...] where each [block] is rounded up to a multiple of
+     * `sizeof(void *)`. `nonsimple.status` is, for convenience, a pointer to the
+     * beginning of the [bb...] block (but not independently allocated).
+     *
+     * Status bits indicate whether the associated holder is constructed (&
+     * status_holder_constructed) and whether the value pointer is registered (&
+     * status_instance_registered) in `registered_instances`.
+     */
+    bool simple_layout : 1;
+    /// For simple layout, tracks whether the holder has been constructed
+    bool simple_holder_constructed : 1;
+    /// For simple layout, tracks whether the instance is registered in `registered_instances`
+    bool simple_instance_registered : 1;
+    /// If true, get_internals().patients has an entry for this object
+    bool has_patients : 1;
+
+    /// Initializes all of the above type/values/holders data (but not the instance values themselves)
+    void allocate_layout();
+
+    /// Destroys/deallocates all of the above
+    void deallocate_layout();
+
+    /// Returns the value_and_holder wrapper for the given type (or the first, if `find_type`
Returns a default-constructed (with `.inst = nullptr`) object on failure if
+    /// `throw_if_missing` is false.
+    value_and_holder get_value_and_holder(const type_info *find_type = nullptr, bool throw_if_missing = true);
+
+    /// Bit values for the non-simple status flags
+    static constexpr uint8_t status_holder_constructed  = 1;
+    static constexpr uint8_t status_instance_registered = 2;
+};
+
+static_assert(std::is_standard_layout<instance>::value, "Internal error: `pybind11::detail::instance` is not standard layout!");
+
+/// from __cpp_future__ import (convenient aliases from C++14/17)
+#if defined(PYBIND11_CPP14) && (!defined(_MSC_VER) || _MSC_VER >= 1910)
+using std::enable_if_t;
+using std::conditional_t;
+using std::remove_cv_t;
+using std::remove_reference_t;
+#else
+template <bool B, typename T = void> using enable_if_t = typename std::enable_if<B, T>::type;
+template <bool B, typename T, typename F> using conditional_t = typename std::conditional<B, T, F>::type;
+template <typename T> using remove_cv_t = typename std::remove_cv<T>::type;
+template <typename T> using remove_reference_t = typename std::remove_reference<T>::type;
+#endif
+
+/// Index sequences
+#if defined(PYBIND11_CPP14)
+using std::index_sequence;
+using std::make_index_sequence;
+#else
+template<size_t ...> struct index_sequence  { };
+template<size_t N, size_t ...S> struct make_index_sequence_impl : make_index_sequence_impl <N - 1, N - 1, S...> { };
+template<size_t ...S> struct make_index_sequence_impl <0, S...> { typedef index_sequence<S...> type; };
+template<size_t N> using make_index_sequence = typename make_index_sequence_impl<N>::type;
+#endif
+
+/// Make an index sequence of the indices of true arguments
+template <typename ISeq, size_t, bool...> struct select_indices_impl { using type = ISeq; };
+template <size_t... IPrev, size_t I, bool B, bool... Bs> struct select_indices_impl<index_sequence<IPrev...>, I, B, Bs...>
+    : select_indices_impl<conditional_t<B, index_sequence<IPrev..., I>, index_sequence<IPrev...>>, I + 1, Bs...> {};
+template <bool... Bs> using select_indices = typename select_indices_impl<index_sequence<>, 0, Bs...>::type;
+
+/// Backports of std::bool_constant and std::negation to accommodate older compilers
+template <bool B> using bool_constant = std::integral_constant<bool, B>;
+template <typename T> struct negation : bool_constant<!T::value> { };
+
+template <typename...> struct void_t_impl { using type = void; };
+template <typename... Ts> using void_t = typename void_t_impl<Ts...>::type;
+
+/// Compile-time all/any/none of that check the boolean value of all template types
+#if defined(__cpp_fold_expressions) && !(defined(_MSC_VER) && (_MSC_VER < 1916))
+template <class... Ts> using all_of = bool_constant<(Ts::value && ...)>;
+template <class... Ts> using any_of = bool_constant<(Ts::value || ...)>;
+#elif !defined(_MSC_VER)
+template <bool...> struct bools {};
+template <class... Ts> using all_of = std::is_same<
+    bools<Ts::value..., true>,
+    bools<true, Ts::value...>>;
+template <class... Ts> using any_of = negation<all_of<negation<Ts>...>>;
+#else
+// MSVC has trouble with the above, but supports std::conjunction, which we can use instead (albeit
+// at a slight loss of compilation efficiency).
+template <class... Ts> using all_of = std::conjunction<Ts...>;
+template <class... Ts> using any_of = std::disjunction<Ts...>;
+#endif
+template <class... Ts> using none_of = negation<any_of<Ts...>>;
+
+template <class T, template<class> class... Predicates> using satisfies_all_of = all_of<Predicates<T>...>;
+template <class T, template<class> class... Predicates> using satisfies_any_of = any_of<Predicates<T>...>;
+template <class T, template<class> class... Predicates> using satisfies_none_of = none_of<Predicates<T>...>;
+
+/// Strip the class from a method type
+template <typename T> struct remove_class { };
+template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...)> { typedef R type(A...); };
+template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...) const> { typedef R type(A...); };
+
+/// Helper template to strip away type modifiers
+template <typename T> struct intrinsic_type                       { typedef T type; };
+template <typename T> struct intrinsic_type<const T>              { typedef typename intrinsic_type<T>::type type; };
+template <typename T> struct intrinsic_type<T*>                   { typedef typename intrinsic_type<T>::type type; };
+template <typename T> struct intrinsic_type<T&>                   { typedef typename intrinsic_type<T>::type type; };
+template <typename T> struct intrinsic_type<T&&>                  { typedef typename intrinsic_type<T>::type type; };
+template <typename T, size_t N> struct intrinsic_type<const T[N]> { typedef typename intrinsic_type<T>::type type; };
+template <typename T, size_t N> struct intrinsic_type<T[N]>       { typedef typename intrinsic_type<T>::type type; };
+template <typename T> using intrinsic_t = typename intrinsic_type<T>::type;
+
+/// Helper type to replace 'void' in some expressions
+struct void_type { };
+
+/// Helper template which holds a list of types
+template <typename...> struct type_list { };
+
+/// Compile-time integer sum
+#ifdef __cpp_fold_expressions
+template <typename... Ts> constexpr size_t constexpr_sum(Ts... ns) { return (0 + ... + size_t{ns}); }
+#else
+constexpr size_t constexpr_sum() { return 0; }
+template <typename T, typename... Ts>
+constexpr size_t constexpr_sum(T n, Ts... ns) { return size_t{n} + constexpr_sum(ns...); }
+#endif
+
+NAMESPACE_BEGIN(constexpr_impl)
+/// Implementation details for constexpr functions
+constexpr int first(int i) { return i; }
+template <typename T, typename... Ts>
+constexpr int first(int i, T v, Ts... vs) { return v ? i : first(i + 1, vs...); }
+
+constexpr int last(int /*i*/, int result) { return result; }
+template <typename T, typename... Ts>
+constexpr int last(int i, int result, T v, Ts... vs) { return last(i + 1, v ? i : result, vs...); }
+NAMESPACE_END(constexpr_impl)
+
+/// Return the index of the first type in Ts which satisfies Predicate<T>.  Returns sizeof...(Ts) if
+/// none match.
+template <template<typename> class Predicate, typename... Ts>
+constexpr int constexpr_first() { return constexpr_impl::first(0, Predicate<Ts>::value...); }
+
+/// Return the index of the last type in Ts which satisfies Predicate<T>, or -1 if none match.
+template <template<typename> class Predicate, typename... Ts>
+constexpr int constexpr_last() { return constexpr_impl::last(0, -1, Predicate<Ts>::value...); }
+
+/// Return the Nth element from the parameter pack
+template <size_t N, typename T, typename... Ts>
+struct pack_element { using type = typename pack_element<N - 1, Ts...>::type; };
+template <typename T, typename... Ts>
+struct pack_element<0, T, Ts...> { using type = T; };
+
+/// Return the one and only type which matches the predicate, or Default if none match.
+/// If more than one type matches the predicate, fail at compile-time.
+template <template<typename> class Predicate, typename Default, typename... Ts>
+struct exactly_one {
+    static constexpr auto found = constexpr_sum(Predicate<Ts>::value...);
+    static_assert(found <= 1, "Found more than one type matching the predicate");
+
+    static constexpr auto index = found ? constexpr_first<Predicate, Ts...>() : 0;
+    using type = conditional_t<found, typename pack_element<index, Ts...>::type, Default>;
+};
+template <template<typename> class P, typename Default>
+struct exactly_one<P, Default> { using type = Default; };
+
+template <template<typename> class Predicate, typename Default, typename... Ts>
+using exactly_one_t = typename exactly_one<Predicate, Default, Ts...>::type;
+
+/// Defer the evaluation of type T until types Us are instantiated
+template <typename T, typename... /*Us*/> struct deferred_type { using type = T; };
+template <typename T, typename... Us> using deferred_t = typename deferred_type<T, Us...>::type;
+
+/// Like is_base_of, but requires a strict base (i.e. `is_strict_base_of<T, T>::value == false`,
+/// unlike `std::is_base_of`)
+template <typename Base, typename Derived> using is_strict_base_of = bool_constant<
+    std::is_base_of<Base, Derived>::value && !std::is_same<Base, Derived>::value>;
+
+/// Like is_base_of, but also requires that the base type is accessible (i.e. that a Derived pointer
+/// can be converted to a Base pointer)
+template <typename Base, typename Derived> using is_accessible_base_of = bool_constant<
+    std::is_base_of<Base, Derived>::value && std::is_convertible<Derived *, Base *>::value>;
+
+template <template<typename...> class Base>
+struct is_template_base_of_impl {
+    template <typename... Us> static std::true_type check(Base<Us...> *);
+    static std::false_type check(...);
+};
+
+/// Check if a template is the base of a type. For example:
+/// `is_template_base_of<Base, T>` is true if `struct T : Base<U> {}` where U can be anything
+template <template<typename...> class Base, typename T>
+#if !defined(_MSC_VER)
+using is_template_base_of = decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T>*)nullptr));
+#else // MSVC2015 has trouble with decltype in template aliases
+struct is_template_base_of : decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T>*)nullptr)) { };
+#endif
+
+/// Check if T is an instantiation of the template `Class`. For example:
+/// `is_instantiation<shared_ptr, T>` is true if `T == shared_ptr<U>` where U can be anything.
+template <template<typename...> class Class, typename T>
+struct is_instantiation : std::false_type { };
+template <template<typename...> class Class, typename... Us>
+struct is_instantiation<Class, Class<Us...>> : std::true_type { };
+
+/// Check if T is std::shared_ptr<U> where U can be anything
+template <typename T> using is_shared_ptr = is_instantiation<std::shared_ptr, T>;
+
+/// Check if T looks like an input iterator
+template <typename T, typename = void> struct is_input_iterator : std::false_type {};
+template <typename T>
+struct is_input_iterator<T, void_t<decltype(*std::declval<T &>()), decltype(++std::declval<T &>())>>
+    : std::true_type {};
+
+template <typename T> using is_function_pointer = bool_constant<
+    std::is_pointer<T>::value && std::is_function<typename std::remove_pointer<T>::type>::value>;
+
+template <typename F> struct strip_function_object {
+    using type = typename remove_class<decltype(&F::operator())>::type;
+};
+
+// Extracts the function signature from a function, function pointer or lambda.
+template <typename Function, typename F = remove_reference_t<Function>>
+using function_signature_t = conditional_t<
+    std::is_function<F>::value,
+    F,
+    typename conditional_t<
+        std::is_pointer<F>::value || std::is_member_pointer<F>::value,
+        std::remove_pointer<F>,
+        strip_function_object<F>
+    >::type
+>;
+
+/// Returns true if the type looks like a lambda: that is, isn't a function, pointer or member
+/// pointer.  Note that this can catch all sorts of other things, too; this is intended to be used
+/// in a place where passing a lambda makes sense.
+template <typename T> using is_lambda = satisfies_none_of<remove_reference_t<T>,
+        std::is_function, std::is_pointer, std::is_member_pointer>;
+
+/// Ignore that a variable is unused in compiler warnings
+inline void ignore_unused(const int *) { }
+
+/// Apply a function over each element of a parameter pack
+#ifdef __cpp_fold_expressions
+#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) (((PATTERN), void()), ...)
+#else +using expand_side_effects = bool[]; +#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) pybind11::detail::expand_side_effects{ ((PATTERN), void(), false)..., false } +#endif + +NAMESPACE_END(detail) + +/// C++ bindings of builtin Python exceptions +class builtin_exception : public std::runtime_error { +public: + using std::runtime_error::runtime_error; + /// Set the error using the Python C API + virtual void set_error() const = 0; +}; + +#define PYBIND11_RUNTIME_EXCEPTION(name, type) \ + class name : public builtin_exception { public: \ + using builtin_exception::builtin_exception; \ + name() : name("") { } \ + void set_error() const override { PyErr_SetString(type, what()); } \ + }; + +PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration) +PYBIND11_RUNTIME_EXCEPTION(index_error, PyExc_IndexError) +PYBIND11_RUNTIME_EXCEPTION(key_error, PyExc_KeyError) +PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError) +PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError) +PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or handle::call fail due to a type casting error +PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally + +[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const char *reason) { throw std::runtime_error(reason); } +[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const std::string &reason) { throw std::runtime_error(reason); } + +template struct format_descriptor { }; + +NAMESPACE_BEGIN(detail) +// Returns the index of the given type in the type char array below, and in the list in numpy.h +// The order here is: bool; 8 ints ((signed,unsigned)x(8,16,32,64)bits); float,double,long double; +// complex float,double,long double. Note that the long double types only participate when long +// double is actually longer than double (it isn't under MSVC). +// NB: not only the string below but also complex.h and numpy.h rely on this order. +template struct is_fmt_numeric { static constexpr bool value = false; }; +template struct is_fmt_numeric::value>> { + static constexpr bool value = true; + static constexpr int index = std::is_same::value ? 0 : 1 + ( + std::is_integral::value ? detail::log2(sizeof(T))*2 + std::is_unsigned::value : 8 + ( + std::is_same::value ? 1 : std::is_same::value ? 
2 : 0)); +}; +NAMESPACE_END(detail) + +template struct format_descriptor::value>> { + static constexpr const char c = "?bBhHiIqQfdg"[detail::is_fmt_numeric::index]; + static constexpr const char value[2] = { c, '\0' }; + static std::string format() { return std::string(1, c); } +}; + +#if !defined(PYBIND11_CPP17) + +template constexpr const char format_descriptor< + T, detail::enable_if_t::value>>::value[2]; + +#endif + +/// RAII wrapper that temporarily clears any Python error state +struct error_scope { + PyObject *type, *value, *trace; + error_scope() { PyErr_Fetch(&type, &value, &trace); } + ~error_scope() { PyErr_Restore(type, value, trace); } +}; + +/// Dummy destructor wrapper that can be used to expose classes with a private destructor +struct nodelete { template void operator()(T*) { } }; + +// overload_cast requires variable templates: C++14 +#if defined(PYBIND11_CPP14) +#define PYBIND11_OVERLOAD_CAST 1 + +NAMESPACE_BEGIN(detail) +template +struct overload_cast_impl { + constexpr overload_cast_impl() {} // MSVC 2015 needs this + + template + constexpr auto operator()(Return (*pf)(Args...)) const noexcept + -> decltype(pf) { return pf; } + + template + constexpr auto operator()(Return (Class::*pmf)(Args...), std::false_type = {}) const noexcept + -> decltype(pmf) { return pmf; } + + template + constexpr auto operator()(Return (Class::*pmf)(Args...) const, std::true_type) const noexcept + -> decltype(pmf) { return pmf; } +}; +NAMESPACE_END(detail) + +/// Syntax sugar for resolving overloaded function pointers: +/// - regular: static_cast(&Class::func) +/// - sweet: overload_cast(&Class::func) +template +static constexpr detail::overload_cast_impl overload_cast = {}; +// MSVC 2015 only accepts this particular initialization syntax for this variable template. + +/// Const member function selector for overload_cast +/// - regular: static_cast(&Class::func) +/// - sweet: overload_cast(&Class::func, const_) +static constexpr auto const_ = std::true_type{}; + +#else // no overload_cast: providing something that static_assert-fails: +template struct overload_cast { + static_assert(detail::deferred_t::value, + "pybind11::overload_cast<...> requires compiling in C++14 mode"); +}; +#endif // overload_cast + +NAMESPACE_BEGIN(detail) + +// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from +// any standard container (or C-style array) supporting std::begin/std::end, any singleton +// arithmetic type (if T is arithmetic), or explicitly constructible from an iterator pair. +template +class any_container { + std::vector v; +public: + any_container() = default; + + // Can construct from a pair of iterators + template ::value>> + any_container(It first, It last) : v(first, last) { } + + // Implicit conversion constructor from any arbitrary container type with values convertible to T + template ())), T>::value>> + any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { } + + // initializer_list's aren't deducible, so don't get matched by the above template; we need this + // to explicitly allow implicit conversion from one: + template ::value>> + any_container(const std::initializer_list &c) : any_container(c.begin(), c.end()) { } + + // Avoid copying if given an rvalue vector of the correct type. 
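+    // (Illustrative caller-side sketch, not part of the upstream header: `make_zeros` is a
+    // hypothetical function taking its shape as a `detail::any_container<ssize_t>`, the same
+    // pattern `py::array`'s constructors use.  All of these calls resolve through the
+    // constructors here:
+    //
+    //     py::array make_zeros(py::detail::any_container<ssize_t> shape);
+    //     make_zeros({2, 3});                      // from a braced initializer_list
+    //     make_zeros(std::vector<ssize_t>{2, 3});  // from any container of convertible values
+    // )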
+ any_container(std::vector &&v) : v(std::move(v)) { } + + // Moves the vector out of an rvalue any_container + operator std::vector &&() && { return std::move(v); } + + // Dereferencing obtains a reference to the underlying vector + std::vector &operator*() { return v; } + const std::vector &operator*() const { return v; } + + // -> lets you call methods on the underlying vector + std::vector *operator->() { return &v; } + const std::vector *operator->() const { return &v; } +}; + +NAMESPACE_END(detail) + + + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/detail/descr.h b/mmocr/models/textdet/postprocess/include/pybind11/detail/descr.h new file mode 100644 index 00000000..8d404e53 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/detail/descr.h @@ -0,0 +1,100 @@ +/* + pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "common.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +#if !defined(_MSC_VER) +# define PYBIND11_DESCR_CONSTEXPR static constexpr +#else +# define PYBIND11_DESCR_CONSTEXPR const +#endif + +/* Concatenate type signatures at compile time */ +template +struct descr { + char text[N + 1]; + + constexpr descr() : text{'\0'} { } + constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence()) { } + + template + constexpr descr(char const (&s)[N+1], index_sequence) : text{s[Is]..., '\0'} { } + + template + constexpr descr(char c, Chars... cs) : text{c, static_cast(cs)..., '\0'} { } + + static constexpr std::array types() { + return {{&typeid(Ts)..., nullptr}}; + } +}; + +template +constexpr descr plus_impl(const descr &a, const descr &b, + index_sequence, index_sequence) { + return {a.text[Is1]..., b.text[Is2]...}; +} + +template +constexpr descr operator+(const descr &a, const descr &b) { + return plus_impl(a, b, make_index_sequence(), make_index_sequence()); +} + +template +constexpr descr _(char const(&text)[N]) { return descr(text); } +constexpr descr<0> _(char const(&)[1]) { return {}; } + +template struct int_to_str : int_to_str { }; +template struct int_to_str<0, Digits...> { + static constexpr auto digits = descr(('0' + Digits)...); +}; + +// Ternary description (like std::conditional) +template +constexpr enable_if_t> _(char const(&text1)[N1], char const(&)[N2]) { + return _(text1); +} +template +constexpr enable_if_t> _(char const(&)[N1], char const(&text2)[N2]) { + return _(text2); +} + +template +constexpr enable_if_t _(const T1 &d, const T2 &) { return d; } +template +constexpr enable_if_t _(const T1 &, const T2 &d) { return d; } + +template auto constexpr _() -> decltype(int_to_str::digits) { + return int_to_str::digits; +} + +template constexpr descr<1, Type> _() { return {'%'}; } + +constexpr descr<0> concat() { return {}; } + +template +constexpr descr concat(const descr &descr) { return descr; } + +template +constexpr auto concat(const descr &d, const Args &...args) + -> decltype(std::declval>() + concat(args...)) { + return d + _(", ") + concat(args...); +} + +template +constexpr descr type_descr(const descr &descr) { + return _("{") + descr + _("}"); +} + +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/detail/init.h 
b/mmocr/models/textdet/postprocess/include/pybind11/detail/init.h
new file mode 100644
index 00000000..acfe00bd
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/include/pybind11/detail/init.h
@@ -0,0 +1,335 @@
+/*
+    pybind11/detail/init.h: init factory function implementation and support code.
+
+    Copyright (c) 2017 Jason Rhinelander <jason@imaginary.ca>
+
+    All rights reserved. Use of this source code is governed by a
+    BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include "class.h"
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+template <>
+class type_caster<value_and_holder> {
+public:
+    bool load(handle h, bool) {
+        value = reinterpret_cast<value_and_holder *>(h.ptr());
+        return true;
+    }
+
+    template <typename> using cast_op_type = value_and_holder &;
+    operator value_and_holder &() { return *value; }
+    static constexpr auto name = _<value_and_holder>();
+
+private:
+    value_and_holder *value = nullptr;
+};
+
+NAMESPACE_BEGIN(initimpl)
+
+inline void no_nullptr(void *ptr) {
+    if (!ptr) throw type_error("pybind11::init(): factory function returned nullptr");
+}
+
+// Implementing functions for all forms of py::init<...> and py::init(...)
+template <typename Class> using Cpp = typename Class::type;
+template <typename Class> using Alias = typename Class::type_alias;
+template <typename Class> using Holder = typename Class::holder_type;
+
+template <typename Class> using is_alias_constructible = std::is_constructible<Alias<Class>, Cpp<Class> &&>;
+
+// Takes a Cpp pointer and returns true if it actually is a polymorphic Alias instance.
+template <typename Class, enable_if_t<Class::has_alias, int> = 0>
+bool is_alias(Cpp<Class> *ptr) {
+    return dynamic_cast<Alias<Class> *>(ptr) != nullptr;
+}
+// Failing fallback version of the above for a no-alias class (always returns false)
+template <typename /*Class*/>
+constexpr bool is_alias(void *) { return false; }
+
+// Constructs and returns a new object; if the given arguments don't map to a constructor, we fall
+// back to brace aggregate initialization so that aggregate initialization can be used with
+// py::init, e.g. `py::init<int, int>` to initialize a `struct T { int a; int b; }`.  For
+// non-aggregate types, we need to use an ordinary T(...) constructor (invoking as `T{...}` usually
+// works, but will not do the expected thing when `T` has an `initializer_list<T>` constructor).
+template <typename Class, typename... Args, detail::enable_if_t<std::is_constructible<Class, Args...>::value, int> = 0>
+inline Class *construct_or_initialize(Args &&...args) { return new Class(std::forward<Args>(args)...); }
+template <typename Class, typename... Args, detail::enable_if_t<!std::is_constructible<Class, Args...>::value, int> = 0>
+inline Class *construct_or_initialize(Args &&...args) { return new Class{std::forward<Args>(args)...}; }
+
+// Attempts to construct an alias using an `Alias(Cpp &&)` constructor.  This allows types with
+// an alias to provide only a single Cpp factory function as long as the Alias can be
+// constructed from an rvalue reference of the base Cpp type.  This means that Alias classes
+// can, when appropriate, simply define an `Alias(Cpp &&)` constructor rather than needing to
+// inherit all the base class constructors.
+template <typename Class>
+void construct_alias_from_cpp(std::true_type /*is_alias_constructible*/,
+                              value_and_holder &v_h, Cpp<Class> &&base) {
+    v_h.value_ptr() = new Alias<Class>(std::move(base));
+}
+template <typename Class>
+[[noreturn]] void construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/,
+                                           value_and_holder &, Cpp<Class> &&) {
+    throw type_error("pybind11::init(): unable to convert returned instance to required "
+                     "alias class: no `Alias<Class>(Class &&)` constructor available");
+}
+
+// Error-generating fallback for factories that don't match one of the below construction
+// mechanisms.
+template <typename Class>
+void construct(...) {
+    static_assert(!std::is_same<Class, Class>::value /* always false */,
+        "pybind11::init(): init function must return a compatible pointer, "
+        "holder, or value");
+}
+
+// Pointer return v1: the factory function returns a class pointer for a registered class.
+// If we don't need an alias (because this class doesn't have one, or because the final type is
+// inherited on the Python side) we can simply take over ownership.  Otherwise we need to try to
+// construct an Alias from the returned base instance.
+template <typename Class>
+void construct(value_and_holder &v_h, Cpp<Class> *ptr, bool need_alias) {
+    no_nullptr(ptr);
+    if (Class::has_alias && need_alias && !is_alias<Class>(ptr)) {
+        // We're going to try to construct an alias by moving the cpp type.  Whether or not
+        // that succeeds, we still need to destroy the original cpp pointer (either the
+        // moved away leftover, if the alias construction works, or the value itself if we
+        // throw an error), but we can't just call `delete ptr`: it might have a special
+        // deleter, or might be shared_from_this.  So we construct a holder around it as if
+        // it was a normal instance, then steal the holder away into a local variable; thus
+        // the holder and destruction happens when we leave the C++ scope, and the holder
+        // class gets to handle the destruction however it likes.
+        v_h.value_ptr() = ptr;
+        v_h.set_instance_registered(true); // To prevent init_instance from registering it
+        v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder
+        Holder<Class> temp_holder(std::move(v_h.holder<Holder<Class>>())); // Steal the holder
+        v_h.type->dealloc(v_h); // Destroys the moved-out holder remains, resets value ptr to null
+        v_h.set_instance_registered(false);
+
+        construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(*ptr));
+    } else {
+        // Otherwise the type isn't inherited, so we don't need an Alias
+        v_h.value_ptr() = ptr;
+    }
+}
+
+// Pointer return v2: a factory that always returns an alias instance ptr.  We simply take over
+// ownership of the pointer.
+template <typename Class, enable_if_t<Class::has_alias, int> = 0>
+void construct(value_and_holder &v_h, Alias<Class> *alias_ptr, bool) {
+    no_nullptr(alias_ptr);
+    v_h.value_ptr() = static_cast<Cpp<Class> *>(alias_ptr);
+}
+
+// Holder return: copy its pointer, and move or copy the returned holder into the new instance's
+// holder.  This also handles types like std::shared_ptr<T> and std::unique_ptr<T> where T is a
+// derived type (through those holder's implicit conversion from derived class holder constructors).
+template <typename Class>
+void construct(value_and_holder &v_h, Holder<Class> holder, bool need_alias) {
+    auto *ptr = holder_helper<Holder<Class>>::get(holder);
+    // If we need an alias, check that the held pointer is actually an alias instance
+    if (Class::has_alias && need_alias && !is_alias<Class>(ptr))
+        throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance "
+                         "is not an alias instance");
+
+    v_h.value_ptr() = ptr;
+    v_h.type->init_instance(v_h.inst, &holder);
+}
+
+// return-by-value version 1: returning a cpp class by value.  If the class has an alias and an
+// alias is required the alias must have an `Alias(Cpp &&)` constructor so that we can construct
+// the alias from the base when needed (i.e. because of Python-side inheritance).  When we don't
+// need it, we simply move-construct the cpp value into a new instance.
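+// (Illustrative user-side view of the overloads above and below -- `Example` is a hypothetical
+// bound class:
+//
+//     py::class_<Example>(m, "Example")
+//         .def(py::init([](int v) { return new Example(v); }))  // pointer return, handled above
+//         .def(py::init([](int v) { return Example(v); }));     // by-value return, handled below
+// )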
+template <typename Class>
+void construct(value_and_holder &v_h, Cpp<Class> &&result, bool need_alias) {
+    static_assert(std::is_move_constructible<Cpp<Class>>::value,
+        "pybind11::init() return-by-value factory function requires a movable class");
+    if (Class::has_alias && need_alias)
+        construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(result));
+    else
+        v_h.value_ptr() = new Cpp<Class>(std::move(result));
+}
+
+// return-by-value version 2: returning a value of the alias type itself.  We move-construct an
+// Alias instance (even if no python-side inheritance is involved).  This is intended for
+// cases where Alias initialization is always desired.
+template <typename Class>
+void construct(value_and_holder &v_h, Alias<Class> &&result, bool) {
+    static_assert(std::is_move_constructible<Alias<Class>>::value,
+        "pybind11::init() return-by-alias-value factory function requires a movable alias class");
+    v_h.value_ptr() = new Alias<Class>(std::move(result));
+}
+
+// Implementing class for py::init<...>()
+template <typename... Args> struct constructor {
+    template <typename Class, typename... Extra, enable_if_t<!Class::has_alias, int> = 0>
+    static void execute(Class &cl, const Extra&... extra) {
+        cl.def("__init__", [](value_and_holder &v_h, Args... args) {
+            v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
+        }, is_new_style_constructor(), extra...);
+    }
+
+    template <typename Class, typename... Extra,
+              enable_if_t<Class::has_alias &&
+                          std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
+    static void execute(Class &cl, const Extra&... extra) {
+        cl.def("__init__", [](value_and_holder &v_h, Args... args) {
+            if (Py_TYPE(v_h.inst) == v_h.type->type)
+                v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
+            else
+                v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
+        }, is_new_style_constructor(), extra...);
+    }
+
+    template <typename Class, typename... Extra,
+              enable_if_t<Class::has_alias &&
+                          !std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
+    static void execute(Class &cl, const Extra&... extra) {
+        cl.def("__init__", [](value_and_holder &v_h, Args... args) {
+            v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
+        }, is_new_style_constructor(), extra...);
+    }
+};
+
+// Implementing class for py::init_alias<...>()
+template <typename... Args> struct alias_constructor {
+    template <typename Class, typename... Extra,
+              enable_if_t<Class::has_alias && std::is_constructible<Alias<Class>, Args...>::value, int> = 0>
+    static void execute(Class &cl, const Extra&... extra) {
+        cl.def("__init__", [](value_and_holder &v_h, Args... args) {
+            v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
+        }, is_new_style_constructor(), extra...);
+    }
+};
+
+// Implementation class for py::init(Func) and py::init(Func, AliasFunc)
+template <typename CFunc, typename AFunc = void_type (*)(),
+          typename = function_signature_t<CFunc>, typename = function_signature_t<AFunc>>
+struct factory;
+
+// Specialization for py::init(Func)
+template <typename Func, typename Return, typename... Args>
+struct factory<Func, void_type (*)(), Return(Args...)> {
+    remove_reference_t<Func> class_factory;
+
+    factory(Func &&f) : class_factory(std::forward<Func>(f)) { }
+
+    // The given class either has no alias or has no separate alias factory;
+    // this always constructs the class itself.  If the class is registered with an alias
+    // type and an alias instance is needed (i.e. because the final type is a Python class
+    // inheriting from the C++ type) the returned value needs to either already be an alias
+    // instance, or the alias needs to be constructible from a `Class &&` argument.
+    template <typename Class, typename... Extra>
+    void execute(Class &cl, const Extra &...extra) && {
+        #if defined(PYBIND11_CPP14)
+        cl.def("__init__", [func = std::move(class_factory)]
+        #else
+        auto &func = class_factory;
+        cl.def("__init__", [func]
+        #endif
+                (value_and_holder &v_h, Args...
args) { + construct(v_h, func(std::forward(args)...), + Py_TYPE(v_h.inst) != v_h.type->type); + }, is_new_style_constructor(), extra...); + } +}; + +// Specialization for py::init(Func, AliasFunc) +template +struct factory { + static_assert(sizeof...(CArgs) == sizeof...(AArgs), + "pybind11::init(class_factory, alias_factory): class and alias factories " + "must have identical argument signatures"); + static_assert(all_of...>::value, + "pybind11::init(class_factory, alias_factory): class and alias factories " + "must have identical argument signatures"); + + remove_reference_t class_factory; + remove_reference_t alias_factory; + + factory(CFunc &&c, AFunc &&a) + : class_factory(std::forward(c)), alias_factory(std::forward(a)) { } + + // The class factory is called when the `self` type passed to `__init__` is the direct + // class (i.e. not inherited), the alias factory when `self` is a Python-side subtype. + template + void execute(Class &cl, const Extra&... extra) && { + static_assert(Class::has_alias, "The two-argument version of `py::init()` can " + "only be used if the class has an alias"); + #if defined(PYBIND11_CPP14) + cl.def("__init__", [class_func = std::move(class_factory), alias_func = std::move(alias_factory)] + #else + auto &class_func = class_factory; + auto &alias_func = alias_factory; + cl.def("__init__", [class_func, alias_func] + #endif + (value_and_holder &v_h, CArgs... args) { + if (Py_TYPE(v_h.inst) == v_h.type->type) + // If the instance type equals the registered type we don't have inheritance, so + // don't need the alias and can construct using the class function: + construct(v_h, class_func(std::forward(args)...), false); + else + construct(v_h, alias_func(std::forward(args)...), true); + }, is_new_style_constructor(), extra...); + } +}; + +/// Set just the C++ state. Same as `__init__`. 
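+/// (Illustrative: setstate is the machinery behind py::pickle -- `P` is a hypothetical class
+/// with an int `value` member:
+///
+///     py::class_<P>(m, "P")
+///         .def(py::pickle(
+///             [](const P &p) { return py::make_tuple(p.value); },  // __getstate__
+///             [](py::tuple t) { return P(t[0].cast<int>()); }));   // __setstate__
+/// )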
+template +void setstate(value_and_holder &v_h, T &&result, bool need_alias) { + construct(v_h, std::forward(result), need_alias); +} + +/// Set both the C++ and Python states +template ::value, int> = 0> +void setstate(value_and_holder &v_h, std::pair &&result, bool need_alias) { + construct(v_h, std::move(result.first), need_alias); + setattr((PyObject *) v_h.inst, "__dict__", result.second); +} + +/// Implementation for py::pickle(GetState, SetState) +template , typename = function_signature_t> +struct pickle_factory; + +template +struct pickle_factory { + static_assert(std::is_same, intrinsic_t>::value, + "The type returned by `__getstate__` must be the same " + "as the argument accepted by `__setstate__`"); + + remove_reference_t get; + remove_reference_t set; + + pickle_factory(Get get, Set set) + : get(std::forward(get)), set(std::forward(set)) { } + + template + void execute(Class &cl, const Extra &...extra) && { + cl.def("__getstate__", std::move(get)); + +#if defined(PYBIND11_CPP14) + cl.def("__setstate__", [func = std::move(set)] +#else + auto &func = set; + cl.def("__setstate__", [func] +#endif + (value_and_holder &v_h, ArgState state) { + setstate(v_h, func(std::forward(state)), + Py_TYPE(v_h.inst) != v_h.type->type); + }, is_new_style_constructor(), extra...); + } +}; + +NAMESPACE_END(initimpl) +NAMESPACE_END(detail) +NAMESPACE_END(pybind11) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/detail/internals.h b/mmocr/models/textdet/postprocess/include/pybind11/detail/internals.h new file mode 100644 index 00000000..6d7dc5cf --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/detail/internals.h @@ -0,0 +1,291 @@ +/* + pybind11/detail/internals.h: Internal data structure and related functions + + Copyright (c) 2017 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "../pytypes.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) +// Forward declarations +inline PyTypeObject *make_static_property_type(); +inline PyTypeObject *make_default_metaclass(); +inline PyObject *make_object_base_type(PyTypeObject *metaclass); + +// The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new +// Thread Specific Storage (TSS) API. 
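+// (Illustrative: the raw TSS calls wrapped by the macros below are
+//
+//     Py_tss_t *key = PyThread_tss_alloc();
+//     PyThread_tss_create(key);
+//     PyThread_tss_set(key, tstate);
+//     void *value = PyThread_tss_get(key);
+// )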
+#if PY_VERSION_HEX >= 0x03070000
+# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
+# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
+# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
+# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
+#else
+    // Usually an int but a long on Cygwin64 with Python 3.x
+# define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0
+# define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key))
+# if PY_MAJOR_VERSION < 3
+#   define PYBIND11_TLS_DELETE_VALUE(key) \
+        PyThread_delete_key_value(key)
+#   define PYBIND11_TLS_REPLACE_VALUE(key, value) \
+        do { \
+            PyThread_delete_key_value((key)); \
+            PyThread_set_key_value((key), (value)); \
+        } while (false)
+# else
+#   define PYBIND11_TLS_DELETE_VALUE(key) \
+        PyThread_set_key_value((key), nullptr)
+#   define PYBIND11_TLS_REPLACE_VALUE(key, value) \
+        PyThread_set_key_value((key), (value))
+# endif
+#endif
+
+// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly
+// other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module
+// even when `A` is the same, non-hidden-visibility type (e.g. from a common include).  Under
+// libstdc++, this doesn't happen: equality and the type_index hash are based on the type name,
+// which works.  If not under a known-good stl, provide our own name-based hash and equality
+// functions that use the type name.
+#if defined(__GLIBCXX__)
+inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; }
+using type_hash = std::hash<std::type_index>;
+using type_equal_to = std::equal_to<std::type_index>;
+#else
+inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) {
+    return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
+}
+
+struct type_hash {
+    size_t operator()(const std::type_index &t) const {
+        size_t hash = 5381;
+        const char *ptr = t.name();
+        while (auto c = static_cast<unsigned char>(*ptr++))
+            hash = (hash * 33) ^ c;
+        return hash;
+    }
+};
+
+struct type_equal_to {
+    bool operator()(const std::type_index &lhs, const std::type_index &rhs) const {
+        return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
+    }
+};
+#endif
+
+template <typename value_type>
+using type_map = std::unordered_map<std::type_index, value_type, type_hash, type_equal_to>;
+
+struct overload_hash {
+    inline size_t operator()(const std::pair<const PyObject *, const char *>& v) const {
+        size_t value = std::hash<const void *>()(v.first);
+        value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2);
+        return value;
+    }
+};
+
+/// Internal data structure used to track registered instances and types.
+/// Whenever binary incompatible changes are made to this structure,
+/// `PYBIND11_INTERNALS_VERSION` must be incremented.
+struct internals { + type_map registered_types_cpp; // std::type_index -> pybind11's type information + std::unordered_map> registered_types_py; // PyTypeObject* -> base type_info(s) + std::unordered_multimap registered_instances; // void * -> instance* + std::unordered_set, overload_hash> inactive_overload_cache; + type_map> direct_conversions; + std::unordered_map> patients; + std::forward_list registered_exception_translators; + std::unordered_map shared_data; // Custom data to be shared across extensions + std::vector loader_patient_stack; // Used by `loader_life_support` + std::forward_list static_strings; // Stores the std::strings backing detail::c_str() + PyTypeObject *static_property_type; + PyTypeObject *default_metaclass; + PyObject *instance_base; +#if defined(WITH_THREAD) + PYBIND11_TLS_KEY_INIT(tstate); + PyInterpreterState *istate = nullptr; +#endif +}; + +/// Additional type information which does not fit into the PyTypeObject. +/// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`. +struct type_info { + PyTypeObject *type; + const std::type_info *cpptype; + size_t type_size, type_align, holder_size_in_ptrs; + void *(*operator_new)(size_t); + void (*init_instance)(instance *, const void *); + void (*dealloc)(value_and_holder &v_h); + std::vector implicit_conversions; + std::vector> implicit_casts; + std::vector *direct_conversions; + buffer_info *(*get_buffer)(PyObject *, void *) = nullptr; + void *get_buffer_data = nullptr; + void *(*module_local_load)(PyObject *, const type_info *) = nullptr; + /* A simple type never occurs as a (direct or indirect) parent + * of a class that makes use of multiple inheritance */ + bool simple_type : 1; + /* True if there is no multiple inheritance in this type's inheritance tree */ + bool simple_ancestors : 1; + /* for base vs derived holder_type checks */ + bool default_holder : 1; + /* true if this is a type registered with py::module_local */ + bool module_local : 1; +}; + +/// Tracks the `internals` and `type_info` ABI version independent of the main library version +#define PYBIND11_INTERNALS_VERSION 3 + +#if defined(_DEBUG) +# define PYBIND11_BUILD_TYPE "_debug" +#else +# define PYBIND11_BUILD_TYPE "" +#endif + +#if defined(WITH_THREAD) +# define PYBIND11_INTERNALS_KIND "" +#else +# define PYBIND11_INTERNALS_KIND "_without_thread" +#endif + +#define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \ + PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__" + +#define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \ + PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__" + +/// Each module locally stores a pointer to the `internals` data. The data +/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`. 
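+/// (Illustrative: the double indirection lets a later-loaded module discard its own pointer and
+/// re-point at the shared table found in the builtins capsule, roughly
+///
+///     internals **pp = get_internals_pp();  // per-module static
+///     *pp = shared_internals_ptr;           // hypothetical; all lookups now hit the shared table
+/// )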
+inline internals **&get_internals_pp() { + static internals **internals_pp = nullptr; + return internals_pp; +} + +/// Return a reference to the current `internals` data +PYBIND11_NOINLINE inline internals &get_internals() { + auto **&internals_pp = get_internals_pp(); + if (internals_pp && *internals_pp) + return **internals_pp; + + constexpr auto *id = PYBIND11_INTERNALS_ID; + auto builtins = handle(PyEval_GetBuiltins()); + if (builtins.contains(id) && isinstance(builtins[id])) { + internals_pp = static_cast(capsule(builtins[id])); + + // We loaded builtins through python's builtins, which means that our `error_already_set` + // and `builtin_exception` may be different local classes than the ones set up in the + // initial exception translator, below, so add another for our local exception classes. + // + // libstdc++ doesn't require this (types there are identified only by name) +#if !defined(__GLIBCXX__) + (*internals_pp)->registered_exception_translators.push_front( + [](std::exception_ptr p) -> void { + try { + if (p) std::rethrow_exception(p); + } catch (error_already_set &e) { e.restore(); return; + } catch (const builtin_exception &e) { e.set_error(); return; + } + } + ); +#endif + } else { + if (!internals_pp) internals_pp = new internals*(); + auto *&internals_ptr = *internals_pp; + internals_ptr = new internals(); +#if defined(WITH_THREAD) + PyEval_InitThreads(); + PyThreadState *tstate = PyThreadState_Get(); + #if PY_VERSION_HEX >= 0x03070000 + internals_ptr->tstate = PyThread_tss_alloc(); + if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate)) + pybind11_fail("get_internals: could not successfully initialize the TSS key!"); + PyThread_tss_set(internals_ptr->tstate, tstate); + #else + internals_ptr->tstate = PyThread_create_key(); + if (internals_ptr->tstate == -1) + pybind11_fail("get_internals: could not successfully initialize the TLS key!"); + PyThread_set_key_value(internals_ptr->tstate, tstate); + #endif + internals_ptr->istate = tstate->interp; +#endif + builtins[id] = capsule(internals_pp); + internals_ptr->registered_exception_translators.push_front( + [](std::exception_ptr p) -> void { + try { + if (p) std::rethrow_exception(p); + } catch (error_already_set &e) { e.restore(); return; + } catch (const builtin_exception &e) { e.set_error(); return; + } catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return; + } catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; + } catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; + } catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; + } catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return; + } catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; + } catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return; + } catch (...) 
{ + PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!"); + return; + } + } + ); + internals_ptr->static_property_type = make_static_property_type(); + internals_ptr->default_metaclass = make_default_metaclass(); + internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass); + } + return **internals_pp; +} + +/// Works like `internals.registered_types_cpp`, but for module-local registered types: +inline type_map ®istered_local_types_cpp() { + static type_map locals{}; + return locals; +} + +/// Constructs a std::string with the given arguments, stores it in `internals`, and returns its +/// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only +/// cleared when the program exits or after interpreter shutdown (when embedding), and so are +/// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name). +template +const char *c_str(Args &&...args) { + auto &strings = get_internals().static_strings; + strings.emplace_front(std::forward(args)...); + return strings.front().c_str(); +} + +NAMESPACE_END(detail) + +/// Returns a named pointer that is shared among all extension modules (using the same +/// pybind11 version) running in the current interpreter. Names starting with underscores +/// are reserved for internal usage. Returns `nullptr` if no matching entry was found. +inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) { + auto &internals = detail::get_internals(); + auto it = internals.shared_data.find(name); + return it != internals.shared_data.end() ? it->second : nullptr; +} + +/// Set the shared data that can be later recovered by `get_shared_data()`. +inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) { + detail::get_internals().shared_data[name] = data; + return data; +} + +/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if +/// such entry exists. Otherwise, a new object of default-constructible type `T` is +/// added to the shared data under the given name and a reference to it is returned. +template +T &get_or_create_shared_data(const std::string &name) { + auto &internals = detail::get_internals(); + auto it = internals.shared_data.find(name); + T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr); + if (!ptr) { + ptr = new T(); + internals.shared_data[name] = ptr; + } + return *ptr; +} + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/detail/typeid.h b/mmocr/models/textdet/postprocess/include/pybind11/detail/typeid.h new file mode 100644 index 00000000..6f36aab7 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/detail/typeid.h @@ -0,0 +1,53 @@ +/* + pybind11/detail/typeid.h: Compiler-independent access to type identifiers + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
+*/ + +#pragma once + +#include +#include + +#if defined(__GNUG__) +#include +#endif + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) +/// Erase all occurrences of a substring +inline void erase_all(std::string &string, const std::string &search) { + for (size_t pos = 0;;) { + pos = string.find(search, pos); + if (pos == std::string::npos) break; + string.erase(pos, search.length()); + } +} + +PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { +#if defined(__GNUG__) + int status = 0; + std::unique_ptr res { + abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; + if (status == 0) + name = res.get(); +#else + detail::erase_all(name, "class "); + detail::erase_all(name, "struct "); + detail::erase_all(name, "enum "); +#endif + detail::erase_all(name, "pybind11::"); +} +NAMESPACE_END(detail) + +/// Return a string representation of a C++ type +template static std::string type_id() { + std::string name(typeid(T).name()); + detail::clean_type_id(name); + return name; +} + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/eigen.h b/mmocr/models/textdet/postprocess/include/pybind11/eigen.h new file mode 100644 index 00000000..d963d965 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/eigen.h @@ -0,0 +1,607 @@ +/* + pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "numpy.h" + +#if defined(__INTEL_COMPILER) +# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem) +#elif defined(__GNUG__) || defined(__clang__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wconversion" +# pragma GCC diagnostic ignored "-Wdeprecated-declarations" +# ifdef __clang__ +// Eigen generates a bunch of implicit-copy-constructor-is-deprecated warnings with -Wdeprecated +// under Clang, so disable that warning here: +# pragma GCC diagnostic ignored "-Wdeprecated" +# endif +# if __GNUC__ >= 7 +# pragma GCC diagnostic ignored "-Wint-in-bool-context" +# endif +#endif + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +# pragma warning(disable: 4996) // warning C4996: std::unary_negate is deprecated in C++17 +#endif + +#include +#include + +// Eigen prior to 3.2.7 doesn't have proper move constructors--but worse, some classes get implicit +// move constructors that break things. We could detect this an explicitly copy, but an extra copy +// of matrices seems highly undesirable. 
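+// (Illustrative use of the convenience aliases declared just below -- a stride-agnostic,
+//  typically copy-free argument:
+//
+//     m.def("norm", [](py::EigenDRef<const Eigen::MatrixXd> m) { return m.norm(); });
+// )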
+static_assert(EIGEN_VERSION_AT_LEAST(3,2,7), "Eigen support in pybind11 requires Eigen >= 3.2.7"); + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +// Provide a convenience alias for easier pass-by-ref usage with fully dynamic strides: +using EigenDStride = Eigen::Stride; +template using EigenDRef = Eigen::Ref; +template using EigenDMap = Eigen::Map; + +NAMESPACE_BEGIN(detail) + +#if EIGEN_VERSION_AT_LEAST(3,3,0) +using EigenIndex = Eigen::Index; +#else +using EigenIndex = EIGEN_DEFAULT_DENSE_INDEX_TYPE; +#endif + +// Matches Eigen::Map, Eigen::Ref, blocks, etc: +template using is_eigen_dense_map = all_of, std::is_base_of, T>>; +template using is_eigen_mutable_map = std::is_base_of, T>; +template using is_eigen_dense_plain = all_of>, is_template_base_of>; +template using is_eigen_sparse = is_template_base_of; +// Test for objects inheriting from EigenBase that aren't captured by the above. This +// basically covers anything that can be assigned to a dense matrix but that don't have a typical +// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and +// SelfAdjointView fall into this category. +template using is_eigen_other = all_of< + is_template_base_of, + negation, is_eigen_dense_plain, is_eigen_sparse>> +>; + +// Captures numpy/eigen conformability status (returned by EigenProps::conformable()): +template struct EigenConformable { + bool conformable = false; + EigenIndex rows = 0, cols = 0; + EigenDStride stride{0, 0}; // Only valid if negativestrides is false! + bool negativestrides = false; // If true, do not use stride! + + EigenConformable(bool fits = false) : conformable{fits} {} + // Matrix type: + EigenConformable(EigenIndex r, EigenIndex c, + EigenIndex rstride, EigenIndex cstride) : + conformable{true}, rows{r}, cols{c} { + // TODO: when Eigen bug #747 is fixed, remove the tests for non-negativity. http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747 + if (rstride < 0 || cstride < 0) { + negativestrides = true; + } else { + stride = {EigenRowMajor ? rstride : cstride /* outer stride */, + EigenRowMajor ? cstride : rstride /* inner stride */ }; + } + } + // Vector type: + EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride) + : EigenConformable(r, c, r == 1 ? c*stride : stride, c == 1 ? r : r*stride) {} + + template bool stride_compatible() const { + // To have compatible strides, we need (on both dimensions) one of fully dynamic strides, + // matching strides, or a dimension size of 1 (in which case the stride value is irrelevant) + return + !negativestrides && + (props::inner_stride == Eigen::Dynamic || props::inner_stride == stride.inner() || + (EigenRowMajor ? cols : rows) == 1) && + (props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer() || + (EigenRowMajor ? 
rows : cols) == 1); + } + operator bool() const { return conformable; } +}; + +template struct eigen_extract_stride { using type = Type; }; +template +struct eigen_extract_stride> { using type = StrideType; }; +template +struct eigen_extract_stride> { using type = StrideType; }; + +// Helper struct for extracting information from an Eigen type +template struct EigenProps { + using Type = Type_; + using Scalar = typename Type::Scalar; + using StrideType = typename eigen_extract_stride::type; + static constexpr EigenIndex + rows = Type::RowsAtCompileTime, + cols = Type::ColsAtCompileTime, + size = Type::SizeAtCompileTime; + static constexpr bool + row_major = Type::IsRowMajor, + vector = Type::IsVectorAtCompileTime, // At least one dimension has fixed size 1 + fixed_rows = rows != Eigen::Dynamic, + fixed_cols = cols != Eigen::Dynamic, + fixed = size != Eigen::Dynamic, // Fully-fixed size + dynamic = !fixed_rows && !fixed_cols; // Fully-dynamic size + + template using if_zero = std::integral_constant; + static constexpr EigenIndex inner_stride = if_zero::value, + outer_stride = if_zero::value; + static constexpr bool dynamic_stride = inner_stride == Eigen::Dynamic && outer_stride == Eigen::Dynamic; + static constexpr bool requires_row_major = !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1; + static constexpr bool requires_col_major = !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1; + + // Takes an input array and determines whether we can make it fit into the Eigen type. If + // the array is a vector, we attempt to fit it into either an Eigen 1xN or Nx1 vector + // (preferring the latter if it will fit in either, i.e. for a fully dynamic matrix type). + static EigenConformable conformable(const array &a) { + const auto dims = a.ndim(); + if (dims < 1 || dims > 2) + return false; + + if (dims == 2) { // Matrix type: require exact match (or dynamic) + + EigenIndex + np_rows = a.shape(0), + np_cols = a.shape(1), + np_rstride = a.strides(0) / static_cast(sizeof(Scalar)), + np_cstride = a.strides(1) / static_cast(sizeof(Scalar)); + if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols)) + return false; + + return {np_rows, np_cols, np_rstride, np_cstride}; + } + + // Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever + // is used, we want the (single) numpy stride value. + const EigenIndex n = a.shape(0), + stride = a.strides(0) / static_cast(sizeof(Scalar)); + + if (vector) { // Eigen type is a compile-time vector + if (fixed && size != n) + return false; // Vector size mismatch + return {rows == 1 ? 1 : n, cols == 1 ? 1 : n, stride}; + } + else if (fixed) { + // The type has a fixed size, but is not a vector: abort + return false; + } + else if (fixed_cols) { + // Since this isn't a vector, cols must be != 1. We allow this only if it exactly + // equals the number of elements (rows is Dynamic, and so 1 row is allowed). 
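+        // (e.g. a 1-D array of length 4 can be loaded into an
+        //  Eigen::Matrix<double, Eigen::Dynamic, 4> as a single 1x4 row.)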
+ if (cols != n) return false; + return {1, n, stride}; + } + else { + // Otherwise it's either fully dynamic, or column dynamic; both become a column vector + if (fixed_rows && rows != n) return false; + return {n, 1, stride}; + } + } + + static constexpr bool show_writeable = is_eigen_dense_map::value && is_eigen_mutable_map::value; + static constexpr bool show_order = is_eigen_dense_map::value; + static constexpr bool show_c_contiguous = show_order && requires_row_major; + static constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major; + + static constexpr auto descriptor = + _("numpy.ndarray[") + npy_format_descriptor::name + + _("[") + _(_<(size_t) rows>(), _("m")) + + _(", ") + _(_<(size_t) cols>(), _("n")) + + _("]") + + // For a reference type (e.g. Ref) we have other constraints that might need to be + // satisfied: writeable=True (for a mutable reference), and, depending on the map's stride + // options, possibly f_contiguous or c_contiguous. We include them in the descriptor output + // to provide some hint as to why a TypeError is occurring (otherwise it can be confusing to + // see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an error message that you + // *gave* a numpy.ndarray of the right type and dimensions. + _(", flags.writeable", "") + + _(", flags.c_contiguous", "") + + _(", flags.f_contiguous", "") + + _("]"); +}; + +// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data, +// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array. +template handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) { + constexpr ssize_t elem_size = sizeof(typename props::Scalar); + array a; + if (props::vector) + a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base); + else + a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() }, + src.data(), base); + + if (!writeable) + array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_; + + return a.release(); +} + +// Takes an lvalue ref to some Eigen type and a (python) base object, creating a numpy array that +// reference the Eigen object's data with `base` as the python-registered base class (if omitted, +// the base will be set to None, and lifetime management is up to the caller). The numpy array is +// non-writeable if the given type is const. +template +handle eigen_ref_array(Type &src, handle parent = none()) { + // none here is to get past array's should-we-copy detection, which currently always + // copies when there is no base. Setting the base to None should be harmless. + return eigen_array_cast(src, parent, !std::is_const::value); +} + +// Takes a pointer to some dense, plain Eigen type, builds a capsule around it, then returns a numpy +// array that references the encapsulated data with a python-side reference to the capsule to tie +// its destruction to that of any dependent python objects. Const-ness is determined by whether or +// not the Type of the pointer given is const. +template ::value>> +handle eigen_encapsulate(Type *src) { + capsule base(src, [](void *o) { delete static_cast(o); }); + return eigen_ref_array(*src, base); +} + +// Type caster for regular, dense matrix types (e.g. MatrixXd), but not maps/refs/etc. of dense +// types. 
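+// (Illustrative binding that goes through this caster -- conversion is by copy in both
+//  directions:
+//
+//     m.def("scale", [](const Eigen::MatrixXd &m, double f) -> Eigen::MatrixXd { return m * f; });
+// )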
+template +struct type_caster::value>> { + using Scalar = typename Type::Scalar; + using props = EigenProps; + + bool load(handle src, bool convert) { + // If we're in no-convert mode, only load if given an array of the correct type + if (!convert && !isinstance>(src)) + return false; + + // Coerce into an array, but don't do type conversion yet; the copy below handles it. + auto buf = array::ensure(src); + + if (!buf) + return false; + + auto dims = buf.ndim(); + if (dims < 1 || dims > 2) + return false; + + auto fits = props::conformable(buf); + if (!fits) + return false; + + // Allocate the new type, then build a numpy reference into it + value = Type(fits.rows, fits.cols); + auto ref = reinterpret_steal(eigen_ref_array(value)); + if (dims == 1) ref = ref.squeeze(); + else if (ref.ndim() == 1) buf = buf.squeeze(); + + int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr()); + + if (result < 0) { // Copy failed! + PyErr_Clear(); + return false; + } + + return true; + } + +private: + + // Cast implementation + template + static handle cast_impl(CType *src, return_value_policy policy, handle parent) { + switch (policy) { + case return_value_policy::take_ownership: + case return_value_policy::automatic: + return eigen_encapsulate(src); + case return_value_policy::move: + return eigen_encapsulate(new CType(std::move(*src))); + case return_value_policy::copy: + return eigen_array_cast(*src); + case return_value_policy::reference: + case return_value_policy::automatic_reference: + return eigen_ref_array(*src); + case return_value_policy::reference_internal: + return eigen_ref_array(*src, parent); + default: + throw cast_error("unhandled return_value_policy: should not happen!"); + }; + } + +public: + + // Normal returned non-reference, non-const value: + static handle cast(Type &&src, return_value_policy /* policy */, handle parent) { + return cast_impl(&src, return_value_policy::move, parent); + } + // If you return a non-reference const, we mark the numpy array readonly: + static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) { + return cast_impl(&src, return_value_policy::move, parent); + } + // lvalue reference return; default (automatic) becomes copy + static handle cast(Type &src, return_value_policy policy, handle parent) { + if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference) + policy = return_value_policy::copy; + return cast_impl(&src, policy, parent); + } + // const lvalue reference return; default (automatic) becomes copy + static handle cast(const Type &src, return_value_policy policy, handle parent) { + if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference) + policy = return_value_policy::copy; + return cast(&src, policy, parent); + } + // non-const pointer return + static handle cast(Type *src, return_value_policy policy, handle parent) { + return cast_impl(src, policy, parent); + } + // const pointer return + static handle cast(const Type *src, return_value_policy policy, handle parent) { + return cast_impl(src, policy, parent); + } + + static constexpr auto name = props::descriptor; + + operator Type*() { return &value; } + operator Type&() { return value; } + operator Type&&() && { return std::move(value); } + template using cast_op_type = movable_cast_op_type; + +private: + Type value; +}; + +// Base class for casting reference/map/block/etc. objects back to python. 
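+// (Illustrative: returning a block/view relies on this caster plus lifetime management by the
+//  caller, e.g.
+//
+//     m.def("top_left", [](Eigen::MatrixXd &m) { return m.block(0, 0, 2, 2); },
+//           py::keep_alive<0, 1>());
+// )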
+template struct eigen_map_caster { +private: + using props = EigenProps; + +public: + + // Directly referencing a ref/map's data is a bit dangerous (whatever the map/ref points to has + // to stay around), but we'll allow it under the assumption that you know what you're doing (and + // have an appropriate keep_alive in place). We return a numpy array pointing directly at the + // ref's data (The numpy array ends up read-only if the ref was to a const matrix type.) Note + // that this means you need to ensure you don't destroy the object in some other way (e.g. with + // an appropriate keep_alive, or with a reference to a statically allocated matrix). + static handle cast(const MapType &src, return_value_policy policy, handle parent) { + switch (policy) { + case return_value_policy::copy: + return eigen_array_cast(src); + case return_value_policy::reference_internal: + return eigen_array_cast(src, parent, is_eigen_mutable_map::value); + case return_value_policy::reference: + case return_value_policy::automatic: + case return_value_policy::automatic_reference: + return eigen_array_cast(src, none(), is_eigen_mutable_map::value); + default: + // move, take_ownership don't make any sense for a ref/map: + pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type"); + } + } + + static constexpr auto name = props::descriptor; + + // Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return + // types but not bound arguments). We still provide them (with an explicitly delete) so that + // you end up here if you try anyway. + bool load(handle, bool) = delete; + operator MapType() = delete; + template using cast_op_type = MapType; +}; + +// We can return any map-like object (but can only load Refs, specialized next): +template struct type_caster::value>> + : eigen_map_caster {}; + +// Loader for Ref<...> arguments. See the documentation for info on how to make this work without +// copying (it requires some extra effort in many cases). +template +struct type_caster< + Eigen::Ref, + enable_if_t>::value> +> : public eigen_map_caster> { +private: + using Type = Eigen::Ref; + using props = EigenProps; + using Scalar = typename props::Scalar; + using MapType = Eigen::Map; + using Array = array_t; + static constexpr bool need_writeable = is_eigen_mutable_map::value; + // Delay construction (these have no default constructor) + std::unique_ptr map; + std::unique_ptr ref; + // Our array. When possible, this is just a numpy array pointing to the source data, but + // sometimes we can't avoid copying (e.g. input is not a numpy array at all, has an incompatible + // layout, or is an array of a type that needs to be converted). Using a numpy temporary + // (rather than an Eigen temporary) saves an extra copy when we need both type conversion and + // storage order conversion. (Note that we refuse to use this temporary copy when loading an + // argument for a Ref with M non-const, i.e. a read-write reference). + Array copy_or_ref; +public: + bool load(handle src, bool convert) { + // First check whether what we have is already an array of the right type. If not, we can't + // avoid a copy (because the copy is also going to do type conversion). 
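+        // (Illustrative contrast at the binding level:
+        //      m.def("sum",    [](Eigen::Ref<const Eigen::VectorXd> v) { return v.sum(); });  // may copy
+        //      m.def("scale3", [](Eigen::Ref<Eigen::VectorXd> v) { v *= 3.; });               // never copies
+        //  a writeable Ref must fail rather than copy, which is what need_writeable enforces.)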
+ bool need_copy = !isinstance(src); + + EigenConformable fits; + if (!need_copy) { + // We don't need a converting copy, but we also need to check whether the strides are + // compatible with the Ref's stride requirements + Array aref = reinterpret_borrow(src); + + if (aref && (!need_writeable || aref.writeable())) { + fits = props::conformable(aref); + if (!fits) return false; // Incompatible dimensions + if (!fits.template stride_compatible()) + need_copy = true; + else + copy_or_ref = std::move(aref); + } + else { + need_copy = true; + } + } + + if (need_copy) { + // We need to copy: If we need a mutable reference, or we're not supposed to convert + // (either because we're in the no-convert overload pass, or because we're explicitly + // instructed not to copy (via `py::arg().noconvert()`) we have to fail loading. + if (!convert || need_writeable) return false; + + Array copy = Array::ensure(src); + if (!copy) return false; + fits = props::conformable(copy); + if (!fits || !fits.template stride_compatible()) + return false; + copy_or_ref = std::move(copy); + loader_life_support::add_patient(copy_or_ref); + } + + ref.reset(); + map.reset(new MapType(data(copy_or_ref), fits.rows, fits.cols, make_stride(fits.stride.outer(), fits.stride.inner()))); + ref.reset(new Type(*map)); + + return true; + } + + operator Type*() { return ref.get(); } + operator Type&() { return *ref; } + template using cast_op_type = pybind11::detail::cast_op_type<_T>; + +private: + template ::value, int> = 0> + Scalar *data(Array &a) { return a.mutable_data(); } + + template ::value, int> = 0> + const Scalar *data(Array &a) { return a.data(); } + + // Attempt to figure out a constructor of `Stride` that will work. + // If both strides are fixed, use a default constructor: + template using stride_ctor_default = bool_constant< + S::InnerStrideAtCompileTime != Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic && + std::is_default_constructible::value>; + // Otherwise, if there is a two-index constructor, assume it is (outer,inner) like + // Eigen::Stride, and use it: + template using stride_ctor_dual = bool_constant< + !stride_ctor_default::value && std::is_constructible::value>; + // Otherwise, if there is a one-index constructor, and just one of the strides is dynamic, use + // it (passing whichever stride is dynamic). + template using stride_ctor_outer = bool_constant< + !any_of, stride_ctor_dual>::value && + S::OuterStrideAtCompileTime == Eigen::Dynamic && S::InnerStrideAtCompileTime != Eigen::Dynamic && + std::is_constructible::value>; + template using stride_ctor_inner = bool_constant< + !any_of, stride_ctor_dual>::value && + S::InnerStrideAtCompileTime == Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic && + std::is_constructible::value>; + + template ::value, int> = 0> + static S make_stride(EigenIndex, EigenIndex) { return S(); } + template ::value, int> = 0> + static S make_stride(EigenIndex outer, EigenIndex inner) { return S(outer, inner); } + template ::value, int> = 0> + static S make_stride(EigenIndex outer, EigenIndex) { return S(outer); } + template ::value, int> = 0> + static S make_stride(EigenIndex, EigenIndex inner) { return S(inner); } + +}; + +// type_caster for special matrix types (e.g. DiagonalMatrix), which are EigenBase, but not +// EigenDense (i.e. they don't have a data(), at least not with the usual matrix layout). +// load() is not supported, but we can cast them into the python domain by first copying to a +// regular Eigen::Matrix, then casting that. 
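+// (Illustrative: a function returning, say, Eigen::DiagonalMatrix<double, 3> reaches Python as a
+//  dense 3x3 numpy array via the Matrix copy described above, e.g.
+//
+//     m.def("diag3", []() { return Eigen::DiagonalMatrix<double, 3>(1., 2., 3.); });
+// )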
+template +struct type_caster::value>> { +protected: + using Matrix = Eigen::Matrix; + using props = EigenProps; +public: + static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { + handle h = eigen_encapsulate(new Matrix(src)); + return h; + } + static handle cast(const Type *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); } + + static constexpr auto name = props::descriptor; + + // Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return + // types but not bound arguments). We still provide them (with an explicitly delete) so that + // you end up here if you try anyway. + bool load(handle, bool) = delete; + operator Type() = delete; + template using cast_op_type = Type; +}; + +template +struct type_caster::value>> { + typedef typename Type::Scalar Scalar; + typedef remove_reference_t().outerIndexPtr())> StorageIndex; + typedef typename Type::Index Index; + static constexpr bool rowMajor = Type::IsRowMajor; + + bool load(handle src, bool) { + if (!src) + return false; + + auto obj = reinterpret_borrow(src); + object sparse_module = module::import("scipy.sparse"); + object matrix_type = sparse_module.attr( + rowMajor ? "csr_matrix" : "csc_matrix"); + + if (!obj.get_type().is(matrix_type)) { + try { + obj = matrix_type(obj); + } catch (const error_already_set &) { + return false; + } + } + + auto values = array_t((object) obj.attr("data")); + auto innerIndices = array_t((object) obj.attr("indices")); + auto outerIndices = array_t((object) obj.attr("indptr")); + auto shape = pybind11::tuple((pybind11::object) obj.attr("shape")); + auto nnz = obj.attr("nnz").cast(); + + if (!values || !innerIndices || !outerIndices) + return false; + + value = Eigen::MappedSparseMatrix( + shape[0].cast(), shape[1].cast(), nnz, + outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data()); + + return true; + } + + static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { + const_cast(src).makeCompressed(); + + object matrix_type = module::import("scipy.sparse").attr( + rowMajor ? "csr_matrix" : "csc_matrix"); + + array data(src.nonZeros(), src.valuePtr()); + array outerIndices((rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr()); + array innerIndices(src.nonZeros(), src.innerIndexPtr()); + + return matrix_type( + std::make_tuple(data, innerIndices, outerIndices), + std::make_pair(src.rows(), src.cols()) + ).release(); + } + + PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[") + + npy_format_descriptor::name + _("]")); +}; + +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) + +#if defined(__GNUG__) || defined(__clang__) +# pragma GCC diagnostic pop +#elif defined(_MSC_VER) +# pragma warning(pop) +#endif diff --git a/mmocr/models/textdet/postprocess/include/pybind11/embed.h b/mmocr/models/textdet/postprocess/include/pybind11/embed.h new file mode 100644 index 00000000..72655885 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/embed.h @@ -0,0 +1,200 @@ +/* + pybind11/embed.h: Support for embedding the interpreter + + Copyright (c) 2017 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
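+
+    Usage sketch (an illustrative main(), not part of the original header; assumes the
+    program links against libpython):
+
+        #include <pybind11/embed.h>
+        namespace py = pybind11;
+
+        int main() {
+            py::scoped_interpreter guard{};                  // start the interpreter
+            py::exec("print('hello from embedded Python')");
+        }                                                    // interpreter shut down here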
+*/ + +#pragma once + +#include "pybind11.h" +#include "eval.h" + +#if defined(PYPY_VERSION) +# error Embedding the interpreter is not supported with PyPy +#endif + +#if PY_MAJOR_VERSION >= 3 +# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ + extern "C" PyObject *pybind11_init_impl_##name() { \ + return pybind11_init_wrapper_##name(); \ + } +#else +# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ + extern "C" void pybind11_init_impl_##name() { \ + pybind11_init_wrapper_##name(); \ + } +#endif + +/** \rst + Add a new module to the table of builtins for the interpreter. Must be + defined in global scope. The first macro parameter is the name of the + module (without quotes). The second parameter is the variable which will + be used as the interface to add functions and classes to the module. + + .. code-block:: cpp + + PYBIND11_EMBEDDED_MODULE(example, m) { + // ... initialize functions and classes here + m.def("foo", []() { + return "Hello, World!"; + }); + } + \endrst */ +#define PYBIND11_EMBEDDED_MODULE(name, variable) \ + static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \ + static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \ + auto m = pybind11::module(PYBIND11_TOSTRING(name)); \ + try { \ + PYBIND11_CONCAT(pybind11_init_, name)(m); \ + return m.ptr(); \ + } catch (pybind11::error_already_set &e) { \ + PyErr_SetString(PyExc_ImportError, e.what()); \ + return nullptr; \ + } catch (const std::exception &e) { \ + PyErr_SetString(PyExc_ImportError, e.what()); \ + return nullptr; \ + } \ + } \ + PYBIND11_EMBEDDED_MODULE_IMPL(name) \ + pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \ + PYBIND11_CONCAT(pybind11_init_impl_, name)); \ + void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable) + + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +/// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks. +struct embedded_module { +#if PY_MAJOR_VERSION >= 3 + using init_t = PyObject *(*)(); +#else + using init_t = void (*)(); +#endif + embedded_module(const char *name, init_t init) { + if (Py_IsInitialized()) + pybind11_fail("Can't add new modules after the interpreter has been initialized"); + + auto result = PyImport_AppendInittab(name, init); + if (result == -1) + pybind11_fail("Insufficient memory to add a new module"); + } +}; + +NAMESPACE_END(detail) + +/** \rst + Initialize the Python interpreter. No other pybind11 or CPython API functions can be + called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The + optional parameter can be used to skip the registration of signal handlers (see the + `Python documentation`_ for details). Calling this function again after the interpreter + has already been initialized is a fatal error. + + If initializing the Python interpreter fails, then the program is terminated. (This + is controlled by the CPython runtime and is an exception to pybind11's normal behavior + of throwing exceptions on errors.) + + .. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx + \endrst */ +inline void initialize_interpreter(bool init_signal_handlers = true) { + if (Py_IsInitialized()) + pybind11_fail("The interpreter is already running"); + + Py_InitializeEx(init_signal_handlers ? 1 : 0); + + // Make .py files in the working directory available by default + module::import("sys").attr("path").cast().append("."); +} + +/** \rst + Shut down the Python interpreter. 
No pybind11 or CPython API functions can be called + after this. In addition, pybind11 objects must not outlive the interpreter: + + .. code-block:: cpp + + { // BAD + py::initialize_interpreter(); + auto hello = py::str("Hello, World!"); + py::finalize_interpreter(); + } // <-- BOOM, hello's destructor is called after interpreter shutdown + + { // GOOD + py::initialize_interpreter(); + { // scoped + auto hello = py::str("Hello, World!"); + } // <-- OK, hello is cleaned up properly + py::finalize_interpreter(); + } + + { // BETTER + py::scoped_interpreter guard{}; + auto hello = py::str("Hello, World!"); + } + + .. warning:: + + The interpreter can be restarted by calling `initialize_interpreter` again. + Modules created using pybind11 can be safely re-initialized. However, Python + itself cannot completely unload binary extension modules and there are several + caveats with regard to interpreter restarting. All the details can be found + in the CPython documentation. In short, not all interpreter memory may be + freed, either due to reference cycles or user-created global data. + + \endrst */ +inline void finalize_interpreter() { + handle builtins(PyEval_GetBuiltins()); + const char *id = PYBIND11_INTERNALS_ID; + + // Get the internals pointer (without creating it if it doesn't exist). It's possible for the + // internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()` + // during destruction), so we get the pointer-pointer here and check it after Py_Finalize(). + detail::internals **internals_ptr_ptr = detail::get_internals_pp(); + // It could also be stashed in builtins, so look there too: + if (builtins.contains(id) && isinstance(builtins[id])) + internals_ptr_ptr = capsule(builtins[id]); + + Py_Finalize(); + + if (internals_ptr_ptr) { + delete *internals_ptr_ptr; + *internals_ptr_ptr = nullptr; + } +} + +/** \rst + Scope guard version of `initialize_interpreter` and `finalize_interpreter`. + This a move-only guard and only a single instance can exist. + + .. code-block:: cpp + + #include + + int main() { + py::scoped_interpreter guard{}; + py::print(Hello, World!); + } // <-- interpreter shutdown + \endrst */ +class scoped_interpreter { +public: + scoped_interpreter(bool init_signal_handlers = true) { + initialize_interpreter(init_signal_handlers); + } + + scoped_interpreter(const scoped_interpreter &) = delete; + scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; } + scoped_interpreter &operator=(const scoped_interpreter &) = delete; + scoped_interpreter &operator=(scoped_interpreter &&) = delete; + + ~scoped_interpreter() { + if (is_valid) + finalize_interpreter(); + } + +private: + bool is_valid = true; +}; + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/eval.h b/mmocr/models/textdet/postprocess/include/pybind11/eval.h new file mode 100644 index 00000000..ea85ba1d --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/eval.h @@ -0,0 +1,117 @@ +/* + pybind11/exec.h: Support for evaluating Python expressions and statements + from strings and files + + Copyright (c) 2016 Klemens Morgenstern and + Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
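+
+    Usage sketch (illustrative): py::eval evaluates a single expression and returns its
+    result, while py::exec runs one or more statements for their side effects:
+
+        auto three = py::eval("1 + 2").cast<int>();
+        py::exec("import math\nprint(math.pi)");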
+*/ + +#pragma once + +#include "pybind11.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +enum eval_mode { + /// Evaluate a string containing an isolated expression + eval_expr, + + /// Evaluate a string containing a single statement. Returns \c none + eval_single_statement, + + /// Evaluate a string containing a sequence of statement. Returns \c none + eval_statements +}; + +template +object eval(str expr, object global = globals(), object local = object()) { + if (!local) + local = global; + + /* PyRun_String does not accept a PyObject / encoding specifier, + this seems to be the only alternative */ + std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr; + + int start; + switch (mode) { + case eval_expr: start = Py_eval_input; break; + case eval_single_statement: start = Py_single_input; break; + case eval_statements: start = Py_file_input; break; + default: pybind11_fail("invalid evaluation mode"); + } + + PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr()); + if (!result) + throw error_already_set(); + return reinterpret_steal(result); +} + +template +object eval(const char (&s)[N], object global = globals(), object local = object()) { + /* Support raw string literals by removing common leading whitespace */ + auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s)) + : str(s); + return eval(expr, global, local); +} + +inline void exec(str expr, object global = globals(), object local = object()) { + eval(expr, global, local); +} + +template +void exec(const char (&s)[N], object global = globals(), object local = object()) { + eval(s, global, local); +} + +template +object eval_file(str fname, object global = globals(), object local = object()) { + if (!local) + local = global; + + int start; + switch (mode) { + case eval_expr: start = Py_eval_input; break; + case eval_single_statement: start = Py_single_input; break; + case eval_statements: start = Py_file_input; break; + default: pybind11_fail("invalid evaluation mode"); + } + + int closeFile = 1; + std::string fname_str = (std::string) fname; +#if PY_VERSION_HEX >= 0x03040000 + FILE *f = _Py_fopen_obj(fname.ptr(), "r"); +#elif PY_VERSION_HEX >= 0x03000000 + FILE *f = _Py_fopen(fname.ptr(), "r"); +#else + /* No unicode support in open() :( */ + auto fobj = reinterpret_steal(PyFile_FromString( + const_cast(fname_str.c_str()), + const_cast("r"))); + FILE *f = nullptr; + if (fobj) + f = PyFile_AsFile(fobj.ptr()); + closeFile = 0; +#endif + if (!f) { + PyErr_Clear(); + pybind11_fail("File \"" + fname_str + "\" could not be opened!"); + } + +#if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION) + PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(), + local.ptr()); + (void) closeFile; +#else + PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(), + local.ptr(), closeFile); +#endif + + if (!result) + throw error_already_set(); + return reinterpret_steal(result); +} + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/functional.h b/mmocr/models/textdet/postprocess/include/pybind11/functional.h new file mode 100644 index 00000000..9cdf21f7 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/functional.h @@ -0,0 +1,83 @@ +/* + pybind11/functional.h: std::function<> support + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
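+
+    Usage sketch (module/function names are illustrative): once this header is included,
+    any Python callable converts to a std::function parameter:
+
+        m.def("apply", [](const std::function<int(int)> &f, int x) { return f(x); });
+        // Python: example.apply(lambda v: v + 1, 41) returns 42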
+*/ + +#pragma once + +#include "pybind11.h" +#include + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +template +struct type_caster> { + using type = std::function; + using retval_type = conditional_t::value, void_type, Return>; + using function_type = Return (*) (Args...); + +public: + bool load(handle src, bool convert) { + if (src.is_none()) { + // Defer accepting None to other overloads (if we aren't in convert mode): + if (!convert) return false; + return true; + } + + if (!isinstance(src)) + return false; + + auto func = reinterpret_borrow(src); + + /* + When passing a C++ function as an argument to another C++ + function via Python, every function call would normally involve + a full C++ -> Python -> C++ roundtrip, which can be prohibitive. + Here, we try to at least detect the case where the function is + stateless (i.e. function pointer or lambda function without + captured variables), in which case the roundtrip can be avoided. + */ + if (auto cfunc = func.cpp_function()) { + auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); + auto rec = (function_record *) c; + + if (rec && rec->is_stateless && + same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { + struct capture { function_type f; }; + value = ((capture *) &rec->data)->f; + return true; + } + } + + value = [func](Args... args) -> Return { + gil_scoped_acquire acq; + object retval(func(std::forward(args)...)); + /* Visual studio 2015 parser issue: need parentheses around this expression */ + return (retval.template cast()); + }; + return true; + } + + template + static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { + if (!f_) + return none().inc_ref(); + + auto result = f_.template target(); + if (result) + return cpp_function(*result, policy).release(); + else + return cpp_function(std::forward(f_), policy).release(); + } + + PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster::name...) + _("], ") + + make_caster::name + _("]")); +}; + +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/iostream.h b/mmocr/models/textdet/postprocess/include/pybind11/iostream.h new file mode 100644 index 00000000..182e8eef --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/iostream.h @@ -0,0 +1,200 @@ +/* + pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python + + Copyright (c) 2017 Henry F. Schreiner + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" + +#include +#include +#include +#include +#include + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +// Buffer that writes to Python instead of C++ +class pythonbuf : public std::streambuf { +private: + using traits_type = std::streambuf::traits_type; + + char d_buffer[1024]; + object pywrite; + object pyflush; + + int overflow(int c) { + if (!traits_type::eq_int_type(c, traits_type::eof())) { + *pptr() = traits_type::to_char_type(c); + pbump(1); + } + return sync() == 0 ? 
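+            // sync() has flushed the buffered text to the Python stream; signal success by
+            // returning a non-eof value for c, and eof on failure: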
traits_type::not_eof(c) : traits_type::eof(); + } + + int sync() { + if (pbase() != pptr()) { + // This subtraction cannot be negative, so dropping the sign + str line(pbase(), static_cast(pptr() - pbase())); + + pywrite(line); + pyflush(); + + setp(pbase(), epptr()); + } + return 0; + } + +public: + pythonbuf(object pyostream) + : pywrite(pyostream.attr("write")), + pyflush(pyostream.attr("flush")) { + setp(d_buffer, d_buffer + sizeof(d_buffer) - 1); + } + + /// Sync before destroy + ~pythonbuf() { + sync(); + } +}; + +NAMESPACE_END(detail) + + +/** \rst + This a move-only guard that redirects output. + + .. code-block:: cpp + + #include + + ... + + { + py::scoped_ostream_redirect output; + std::cout << "Hello, World!"; // Python stdout + } // <-- return std::cout to normal + + You can explicitly pass the c++ stream and the python object, + for example to guard stderr instead. + + .. code-block:: cpp + + { + py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")}; + std::cerr << "Hello, World!"; + } + \endrst */ +class scoped_ostream_redirect { +protected: + std::streambuf *old; + std::ostream &costream; + detail::pythonbuf buffer; + +public: + scoped_ostream_redirect( + std::ostream &costream = std::cout, + object pyostream = module::import("sys").attr("stdout")) + : costream(costream), buffer(pyostream) { + old = costream.rdbuf(&buffer); + } + + ~scoped_ostream_redirect() { + costream.rdbuf(old); + } + + scoped_ostream_redirect(const scoped_ostream_redirect &) = delete; + scoped_ostream_redirect(scoped_ostream_redirect &&other) = default; + scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete; + scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete; +}; + + +/** \rst + Like `scoped_ostream_redirect`, but redirects cerr by default. This class + is provided primary to make ``py::call_guard`` easier to make. + + .. code-block:: cpp + + m.def("noisy_func", &noisy_func, + py::call_guard()); + +\endrst */ +class scoped_estream_redirect : public scoped_ostream_redirect { +public: + scoped_estream_redirect( + std::ostream &costream = std::cerr, + object pyostream = module::import("sys").attr("stderr")) + : scoped_ostream_redirect(costream,pyostream) {} +}; + + +NAMESPACE_BEGIN(detail) + +// Class to redirect output as a context manager. C++ backend. +class OstreamRedirect { + bool do_stdout_; + bool do_stderr_; + std::unique_ptr redirect_stdout; + std::unique_ptr redirect_stderr; + +public: + OstreamRedirect(bool do_stdout = true, bool do_stderr = true) + : do_stdout_(do_stdout), do_stderr_(do_stderr) {} + + void enter() { + if (do_stdout_) + redirect_stdout.reset(new scoped_ostream_redirect()); + if (do_stderr_) + redirect_stderr.reset(new scoped_estream_redirect()); + } + + void exit() { + redirect_stdout.reset(); + redirect_stderr.reset(); + } +}; + +NAMESPACE_END(detail) + +/** \rst + This is a helper function to add a C++ redirect context manager to Python + instead of using a C++ guard. To use it, add the following to your binding code: + + .. code-block:: cpp + + #include + + ... + + py::add_ostream_redirect(m, "ostream_redirect"); + + You now have a Python context manager that redirects your output: + + .. code-block:: python + + with m.ostream_redirect(): + m.print_to_cout_function() + + This manager can optionally be told which streams to operate on: + + .. 
code-block:: python + + with m.ostream_redirect(stdout=true, stderr=true): + m.noisy_function_with_error_printing() + + \endrst */ +inline class_ add_ostream_redirect(module m, std::string name = "ostream_redirect") { + return class_(m, name.c_str(), module_local()) + .def(init(), arg("stdout")=true, arg("stderr")=true) + .def("__enter__", &detail::OstreamRedirect::enter) + .def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); }); +} + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/numpy.h b/mmocr/models/textdet/postprocess/include/pybind11/numpy.h new file mode 100644 index 00000000..37471d8b --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/numpy.h @@ -0,0 +1,1610 @@ +/* + pybind11/numpy.h: Basic NumPy support, vectorize() wrapper + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" +#include "complex.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +#endif + +/* This will be true on all flat address space platforms and allows us to reduce the + whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size + and dimension types (e.g. shape, strides, indexing), instead of inflicting this + upon the library user. */ +static_assert(sizeof(ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t"); + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +class array; // Forward declaration + +NAMESPACE_BEGIN(detail) +template struct npy_format_descriptor; + +struct PyArrayDescr_Proxy { + PyObject_HEAD + PyObject *typeobj; + char kind; + char type; + char byteorder; + char flags; + int type_num; + int elsize; + int alignment; + char *subarray; + PyObject *fields; + PyObject *names; +}; + +struct PyArray_Proxy { + PyObject_HEAD + char *data; + int nd; + ssize_t *dimensions; + ssize_t *strides; + PyObject *base; + PyObject *descr; + int flags; +}; + +struct PyVoidScalarObject_Proxy { + PyObject_VAR_HEAD + char *obval; + PyArrayDescr_Proxy *descr; + int flags; + PyObject *base; +}; + +struct numpy_type_info { + PyObject* dtype_ptr; + std::string format_str; +}; + +struct numpy_internals { + std::unordered_map registered_dtypes; + + numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) { + auto it = registered_dtypes.find(std::type_index(tinfo)); + if (it != registered_dtypes.end()) + return &(it->second); + if (throw_if_missing) + pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name()); + return nullptr; + } + + template numpy_type_info *get_type_info(bool throw_if_missing = true) { + return get_type_info(typeid(typename std::remove_cv::type), throw_if_missing); + } +}; + +inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) { + ptr = &get_or_create_shared_data("_numpy_internals"); +} + +inline numpy_internals& get_numpy_internals() { + static numpy_internals* ptr = nullptr; + if (!ptr) + load_numpy_internals(ptr); + return *ptr; +} + +struct npy_api { + enum constants { + NPY_ARRAY_C_CONTIGUOUS_ = 0x0001, + NPY_ARRAY_F_CONTIGUOUS_ = 0x0002, + NPY_ARRAY_OWNDATA_ = 0x0004, + NPY_ARRAY_FORCECAST_ = 0x0010, + NPY_ARRAY_ENSUREARRAY_ = 0x0040, + NPY_ARRAY_ALIGNED_ = 0x0100, + 
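+        // The constants above and below mirror NumPy's own NPY_* enums (with a trailing
+        // underscore added) so that this header avoids including numpy's headers directly.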
NPY_ARRAY_WRITEABLE_ = 0x0400, + NPY_BOOL_ = 0, + NPY_BYTE_, NPY_UBYTE_, + NPY_SHORT_, NPY_USHORT_, + NPY_INT_, NPY_UINT_, + NPY_LONG_, NPY_ULONG_, + NPY_LONGLONG_, NPY_ULONGLONG_, + NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_, + NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_, + NPY_OBJECT_ = 17, + NPY_STRING_, NPY_UNICODE_, NPY_VOID_ + }; + + typedef struct { + Py_intptr_t *ptr; + int len; + } PyArray_Dims; + + static npy_api& get() { + static npy_api api = lookup(); + return api; + } + + bool PyArray_Check_(PyObject *obj) const { + return (bool) PyObject_TypeCheck(obj, PyArray_Type_); + } + bool PyArrayDescr_Check_(PyObject *obj) const { + return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_); + } + + unsigned int (*PyArray_GetNDArrayCFeatureVersion_)(); + PyObject *(*PyArray_DescrFromType_)(int); + PyObject *(*PyArray_NewFromDescr_) + (PyTypeObject *, PyObject *, int, Py_intptr_t *, + Py_intptr_t *, void *, int, PyObject *); + PyObject *(*PyArray_DescrNewFromType_)(int); + int (*PyArray_CopyInto_)(PyObject *, PyObject *); + PyObject *(*PyArray_NewCopy_)(PyObject *, int); + PyTypeObject *PyArray_Type_; + PyTypeObject *PyVoidArrType_Type_; + PyTypeObject *PyArrayDescr_Type_; + PyObject *(*PyArray_DescrFromScalar_)(PyObject *); + PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *); + int (*PyArray_DescrConverter_) (PyObject *, PyObject **); + bool (*PyArray_EquivTypes_) (PyObject *, PyObject *); + int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, + Py_ssize_t *, PyObject **, PyObject *); + PyObject *(*PyArray_Squeeze_)(PyObject *); + int (*PyArray_SetBaseObject_)(PyObject *, PyObject *); + PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int); +private: + enum functions { + API_PyArray_GetNDArrayCFeatureVersion = 211, + API_PyArray_Type = 2, + API_PyArrayDescr_Type = 3, + API_PyVoidArrType_Type = 39, + API_PyArray_DescrFromType = 45, + API_PyArray_DescrFromScalar = 57, + API_PyArray_FromAny = 69, + API_PyArray_Resize = 80, + API_PyArray_CopyInto = 82, + API_PyArray_NewCopy = 85, + API_PyArray_NewFromDescr = 94, + API_PyArray_DescrNewFromType = 9, + API_PyArray_DescrConverter = 174, + API_PyArray_EquivTypes = 182, + API_PyArray_GetArrayParamsFromObject = 278, + API_PyArray_Squeeze = 136, + API_PyArray_SetBaseObject = 282 + }; + + static npy_api lookup() { + module m = module::import("numpy.core.multiarray"); + auto c = m.attr("_ARRAY_API"); +#if PY_MAJOR_VERSION >= 3 + void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL); +#else + void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr()); +#endif + npy_api api; +#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func]; + DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion); + if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7) + pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0"); + DECL_NPY_API(PyArray_Type); + DECL_NPY_API(PyVoidArrType_Type); + DECL_NPY_API(PyArrayDescr_Type); + DECL_NPY_API(PyArray_DescrFromType); + DECL_NPY_API(PyArray_DescrFromScalar); + DECL_NPY_API(PyArray_FromAny); + DECL_NPY_API(PyArray_Resize); + DECL_NPY_API(PyArray_CopyInto); + DECL_NPY_API(PyArray_NewCopy); + DECL_NPY_API(PyArray_NewFromDescr); + DECL_NPY_API(PyArray_DescrNewFromType); + DECL_NPY_API(PyArray_DescrConverter); + DECL_NPY_API(PyArray_EquivTypes); + DECL_NPY_API(PyArray_GetArrayParamsFromObject); + DECL_NPY_API(PyArray_Squeeze); + DECL_NPY_API(PyArray_SetBaseObject); +#undef DECL_NPY_API + return api; + } +}; + +inline 
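+// (The helpers that follow reinterpret raw PyObject pointers as the proxy layouts declared
+// earlier, again keeping this header independent of numpy's own headers.)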
PyArray_Proxy* array_proxy(void* ptr) { + return reinterpret_cast(ptr); +} + +inline const PyArray_Proxy* array_proxy(const void* ptr) { + return reinterpret_cast(ptr); +} + +inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) { + return reinterpret_cast(ptr); +} + +inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) { + return reinterpret_cast(ptr); +} + +inline bool check_flags(const void* ptr, int flag) { + return (flag == (array_proxy(ptr)->flags & flag)); +} + +template struct is_std_array : std::false_type { }; +template struct is_std_array> : std::true_type { }; +template struct is_complex : std::false_type { }; +template struct is_complex> : std::true_type { }; + +template struct array_info_scalar { + typedef T type; + static constexpr bool is_array = false; + static constexpr bool is_empty = false; + static constexpr auto extents = _(""); + static void append_extents(list& /* shape */) { } +}; +// Computes underlying type and a comma-separated list of extents for array +// types (any mix of std::array and built-in arrays). An array of char is +// treated as scalar because it gets special handling. +template struct array_info : array_info_scalar { }; +template struct array_info> { + using type = typename array_info::type; + static constexpr bool is_array = true; + static constexpr bool is_empty = (N == 0) || array_info::is_empty; + static constexpr size_t extent = N; + + // appends the extents to shape + static void append_extents(list& shape) { + shape.append(N); + array_info::append_extents(shape); + } + + static constexpr auto extents = _::is_array>( + concat(_(), array_info::extents), _() + ); +}; +// For numpy we have special handling for arrays of characters, so we don't include +// the size in the array extents. +template struct array_info : array_info_scalar { }; +template struct array_info> : array_info_scalar> { }; +template struct array_info : array_info> { }; +template using remove_all_extents_t = typename array_info::type; + +template using is_pod_struct = all_of< + std::is_standard_layout, // since we're accessing directly in memory we need a standard layout type +#if !defined(__GNUG__) || defined(_LIBCPP_VERSION) || defined(_GLIBCXX_USE_CXX11_ABI) + // _GLIBCXX_USE_CXX11_ABI indicates that we're using libstdc++ from GCC 5 or newer, independent + // of the actual compiler (Clang can also use libstdc++, but it always defines __GNUC__ == 4). + std::is_trivially_copyable, +#else + // GCC 4 doesn't implement is_trivially_copyable, so approximate it + std::is_trivially_destructible, + satisfies_any_of, +#endif + satisfies_none_of +>; + +template ssize_t byte_offset_unsafe(const Strides &) { return 0; } +template +ssize_t byte_offset_unsafe(const Strides &strides, ssize_t i, Ix... index) { + return i * strides[Dim] + byte_offset_unsafe(strides, index...); +} + +/** + * Proxy class providing unsafe, unchecked const access to array data. This is constructed through + * the `unchecked()` method of `array` or the `unchecked()` method of `array_t`. `Dims` + * will be -1 for dimensions determined at runtime. + */ +template +class unchecked_reference { +protected: + static constexpr bool Dynamic = Dims < 0; + const unsigned char *data_; + // Storing the shape & strides in local variables (i.e. 
these arrays) allows the compiler to + // make large performance gains on big, nested loops, but requires compile-time dimensions + conditional_t> + shape_, strides_; + const ssize_t dims_; + + friend class pybind11::array; + // Constructor for compile-time dimensions: + template + unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t) + : data_{reinterpret_cast(data)}, dims_{Dims} { + for (size_t i = 0; i < (size_t) dims_; i++) { + shape_[i] = shape[i]; + strides_[i] = strides[i]; + } + } + // Constructor for runtime dimensions: + template + unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t dims) + : data_{reinterpret_cast(data)}, shape_{shape}, strides_{strides}, dims_{dims} {} + +public: + /** + * Unchecked const reference access to data at the given indices. For a compile-time known + * number of dimensions, this requires the correct number of arguments; for run-time + * dimensionality, this is not checked (and so is up to the caller to use safely). + */ + template const T &operator()(Ix... index) const { + static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic, + "Invalid number of indices for unchecked array reference"); + return *reinterpret_cast(data_ + byte_offset_unsafe(strides_, ssize_t(index)...)); + } + /** + * Unchecked const reference access to data; this operator only participates if the reference + * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`. + */ + template > + const T &operator[](ssize_t index) const { return operator()(index); } + + /// Pointer access to the data at the given indices. + template const T *data(Ix... ix) const { return &operator()(ssize_t(ix)...); } + + /// Returns the item size, i.e. sizeof(T) + constexpr static ssize_t itemsize() { return sizeof(T); } + + /// Returns the shape (i.e. size) of dimension `dim` + ssize_t shape(ssize_t dim) const { return shape_[(size_t) dim]; } + + /// Returns the number of dimensions of the array + ssize_t ndim() const { return dims_; } + + /// Returns the total number of elements in the referenced array, i.e. the product of the shapes + template + enable_if_t size() const { + return std::accumulate(shape_.begin(), shape_.end(), (ssize_t) 1, std::multiplies()); + } + template + enable_if_t size() const { + return std::accumulate(shape_, shape_ + ndim(), (ssize_t) 1, std::multiplies()); + } + + /// Returns the total number of bytes used by the referenced data. Note that the actual span in + /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice). + ssize_t nbytes() const { + return size() * itemsize(); + } +}; + +template +class unchecked_mutable_reference : public unchecked_reference { + friend class pybind11::array; + using ConstBase = unchecked_reference; + using ConstBase::ConstBase; + using ConstBase::Dynamic; +public: + /// Mutable, unchecked access to data at the given indices. + template T& operator()(Ix... index) { + static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic, + "Invalid number of indices for unchecked array reference"); + return const_cast(ConstBase::operator()(index...)); + } + /** + * Mutable, unchecked access data at the given index; this operator only participates if the + * reference is to a 1-dimensional array (or has runtime dimensions). When present, this is + * exactly equivalent to `obj(index)`. 
+ */ + template > + T &operator[](ssize_t index) { return operator()(index); } + + /// Mutable pointer access to the data at the given indices. + template T *mutable_data(Ix... ix) { return &operator()(ssize_t(ix)...); } +}; + +template +struct type_caster> { + static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable"); +}; +template +struct type_caster> : type_caster> {}; + +NAMESPACE_END(detail) + +class dtype : public object { +public: + PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_); + + explicit dtype(const buffer_info &info) { + dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format))); + // If info.itemsize == 0, use the value calculated from the format string + m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr(); + } + + explicit dtype(const std::string &format) { + m_ptr = from_args(pybind11::str(format)).release().ptr(); + } + + dtype(const char *format) : dtype(std::string(format)) { } + + dtype(list names, list formats, list offsets, ssize_t itemsize) { + dict args; + args["names"] = names; + args["formats"] = formats; + args["offsets"] = offsets; + args["itemsize"] = pybind11::int_(itemsize); + m_ptr = from_args(args).release().ptr(); + } + + /// This is essentially the same as calling numpy.dtype(args) in Python. + static dtype from_args(object args) { + PyObject *ptr = nullptr; + if (!detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) || !ptr) + throw error_already_set(); + return reinterpret_steal(ptr); + } + + /// Return dtype associated with a C++ type. + template static dtype of() { + return detail::npy_format_descriptor::type>::dtype(); + } + + /// Size of the data type in bytes. + ssize_t itemsize() const { + return detail::array_descriptor_proxy(m_ptr)->elsize; + } + + /// Returns true for structured data types. + bool has_fields() const { + return detail::array_descriptor_proxy(m_ptr)->names != nullptr; + } + + /// Single-character type code. + char kind() const { + return detail::array_descriptor_proxy(m_ptr)->kind; + } + +private: + static object _dtype_from_pep3118() { + static PyObject *obj = module::import("numpy.core._internal") + .attr("_dtype_from_pep3118").cast().release().ptr(); + return reinterpret_borrow(obj); + } + + dtype strip_padding(ssize_t itemsize) { + // Recursively strip all void fields with empty names that are generated for + // padding fields (as of NumPy v1.11). 
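+        // For instance, a dtype whose fields arrive as {('x', float32, 0), ('', void, 4)} keeps
+        // only 'x': the empty-named void entry is numpy-generated padding, skipped below.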
+ if (!has_fields()) + return *this; + + struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; }; + std::vector field_descriptors; + + for (auto field : attr("fields").attr("items")()) { + auto spec = field.cast(); + auto name = spec[0].cast(); + auto format = spec[1].cast()[0].cast(); + auto offset = spec[1].cast()[1].cast(); + if (!len(name) && format.kind() == 'V') + continue; + field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset}); + } + + std::sort(field_descriptors.begin(), field_descriptors.end(), + [](const field_descr& a, const field_descr& b) { + return a.offset.cast() < b.offset.cast(); + }); + + list names, formats, offsets; + for (auto& descr : field_descriptors) { + names.append(descr.name); + formats.append(descr.format); + offsets.append(descr.offset); + } + return dtype(names, formats, offsets, itemsize); + } +}; + +class array : public buffer { +public: + PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array) + + enum { + c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_, + f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_, + forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_ + }; + + array() : array({{0}}, static_cast(nullptr)) {} + + using ShapeContainer = detail::any_container; + using StridesContainer = detail::any_container; + + // Constructs an array taking shape/strides from arbitrary container types + array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides, + const void *ptr = nullptr, handle base = handle()) { + + if (strides->empty()) + *strides = c_strides(*shape, dt.itemsize()); + + auto ndim = shape->size(); + if (ndim != strides->size()) + pybind11_fail("NumPy: shape ndim doesn't match strides ndim"); + auto descr = dt; + + int flags = 0; + if (base && ptr) { + if (isinstance(base)) + /* Copy flags from base (except ownership bit) */ + flags = reinterpret_borrow(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_; + else + /* Writable by default, easy to downgrade later on if needed */ + flags = detail::npy_api::NPY_ARRAY_WRITEABLE_; + } + + auto &api = detail::npy_api::get(); + auto tmp = reinterpret_steal(api.PyArray_NewFromDescr_( + api.PyArray_Type_, descr.release().ptr(), (int) ndim, shape->data(), strides->data(), + const_cast(ptr), flags, nullptr)); + if (!tmp) + throw error_already_set(); + if (ptr) { + if (base) { + api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr()); + } else { + tmp = reinterpret_steal(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */)); + } + } + m_ptr = tmp.release().ptr(); + } + + array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle()) + : array(dt, std::move(shape), {}, ptr, base) { } + + template ::value && !std::is_same::value>> + array(const pybind11::dtype &dt, T count, const void *ptr = nullptr, handle base = handle()) + : array(dt, {{count}}, ptr, base) { } + + template + array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle()) + : array(pybind11::dtype::of(), std::move(shape), std::move(strides), ptr, base) { } + + template + array(ShapeContainer shape, const T *ptr, handle base = handle()) + : array(std::move(shape), {}, ptr, base) { } + + template + explicit array(ssize_t count, const T *ptr, handle base = handle()) : array({count}, {}, ptr, base) { } + + explicit array(const buffer_info &info) + : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { } + + /// Array descriptor 
(dtype) + pybind11::dtype dtype() const { + return reinterpret_borrow(detail::array_proxy(m_ptr)->descr); + } + + /// Total number of elements + ssize_t size() const { + return std::accumulate(shape(), shape() + ndim(), (ssize_t) 1, std::multiplies()); + } + + /// Byte size of a single element + ssize_t itemsize() const { + return detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize; + } + + /// Total number of bytes + ssize_t nbytes() const { + return size() * itemsize(); + } + + /// Number of dimensions + ssize_t ndim() const { + return detail::array_proxy(m_ptr)->nd; + } + + /// Base object + object base() const { + return reinterpret_borrow(detail::array_proxy(m_ptr)->base); + } + + /// Dimensions of the array + const ssize_t* shape() const { + return detail::array_proxy(m_ptr)->dimensions; + } + + /// Dimension along a given axis + ssize_t shape(ssize_t dim) const { + if (dim >= ndim()) + fail_dim_check(dim, "invalid axis"); + return shape()[dim]; + } + + /// Strides of the array + const ssize_t* strides() const { + return detail::array_proxy(m_ptr)->strides; + } + + /// Stride along a given axis + ssize_t strides(ssize_t dim) const { + if (dim >= ndim()) + fail_dim_check(dim, "invalid axis"); + return strides()[dim]; + } + + /// Return the NumPy array flags + int flags() const { + return detail::array_proxy(m_ptr)->flags; + } + + /// If set, the array is writeable (otherwise the buffer is read-only) + bool writeable() const { + return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_); + } + + /// If set, the array owns the data (will be freed when the array is deleted) + bool owndata() const { + return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_); + } + + /// Pointer to the contained data. If index is not provided, points to the + /// beginning of the buffer. May throw if the index would lead to out of bounds access. + template const void* data(Ix... index) const { + return static_cast(detail::array_proxy(m_ptr)->data + offset_at(index...)); + } + + /// Mutable pointer to the contained data. If index is not provided, points to the + /// beginning of the buffer. May throw if the index would lead to out of bounds access. + /// May throw if the array is not writeable. + template void* mutable_data(Ix... index) { + check_writeable(); + return static_cast(detail::array_proxy(m_ptr)->data + offset_at(index...)); + } + + /// Byte offset from beginning of the array to a given index (full or partial). + /// May throw if the index would lead to out of bounds access. + template ssize_t offset_at(Ix... index) const { + if ((ssize_t) sizeof...(index) > ndim()) + fail_dim_check(sizeof...(index), "too many indices for an array"); + return byte_offset(ssize_t(index)...); + } + + ssize_t offset_at() const { return 0; } + + /// Item count from beginning of the array to a given index (full or partial). + /// May throw if the index would lead to out of bounds access. + template ssize_t index_at(Ix... index) const { + return offset_at(index...) / itemsize(); + } + + /** + * Returns a proxy object that provides access to the array's data without bounds or + * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with + * care: the array must not be destroyed or reshaped for the duration of the returned object, + * and the caller must take care not to access invalid dimensions or dimension indices. 
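+     *
+     * Usage sketch (names illustrative):
+     *
+     *     auto r = arr.mutable_unchecked<2>();   // arr: a 2-D py::array_t<double>
+     *     for (ssize_t i = 0; i < r.shape(0); i++)
+     *         for (ssize_t j = 0; j < r.shape(1); j++)
+     *             r(i, j) *= 2;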
+ */ + template detail::unchecked_mutable_reference mutable_unchecked() & { + if (Dims >= 0 && ndim() != Dims) + throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + + "; expected " + std::to_string(Dims)); + return detail::unchecked_mutable_reference(mutable_data(), shape(), strides(), ndim()); + } + + /** + * Returns a proxy object that provides const access to the array's data without bounds or + * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the + * underlying array have the `writable` flag. Use with care: the array must not be destroyed or + * reshaped for the duration of the returned object, and the caller must take care not to access + * invalid dimensions or dimension indices. + */ + template detail::unchecked_reference unchecked() const & { + if (Dims >= 0 && ndim() != Dims) + throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + + "; expected " + std::to_string(Dims)); + return detail::unchecked_reference(data(), shape(), strides(), ndim()); + } + + /// Return a new view with all of the dimensions of length 1 removed + array squeeze() { + auto& api = detail::npy_api::get(); + return reinterpret_steal(api.PyArray_Squeeze_(m_ptr)); + } + + /// Resize array to given shape + /// If refcheck is true and more that one reference exist to this array + /// then resize will succeed only if it makes a reshape, i.e. original size doesn't change + void resize(ShapeContainer new_shape, bool refcheck = true) { + detail::npy_api::PyArray_Dims d = { + new_shape->data(), int(new_shape->size()) + }; + // try to resize, set ordering param to -1 cause it's not used anyway + object new_array = reinterpret_steal( + detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1) + ); + if (!new_array) throw error_already_set(); + if (isinstance(new_array)) { *this = std::move(new_array); } + } + + /// Ensure that the argument is a NumPy array + /// In case of an error, nullptr is returned and the Python error is cleared. + static array ensure(handle h, int ExtraFlags = 0) { + auto result = reinterpret_steal(raw_array(h.ptr(), ExtraFlags)); + if (!result) + PyErr_Clear(); + return result; + } + +protected: + template friend struct detail::npy_format_descriptor; + + void fail_dim_check(ssize_t dim, const std::string& msg) const { + throw index_error(msg + ": " + std::to_string(dim) + + " (ndim = " + std::to_string(ndim()) + ")"); + } + + template ssize_t byte_offset(Ix... index) const { + check_dimensions(index...); + return detail::byte_offset_unsafe(strides(), ssize_t(index)...); + } + + void check_writeable() const { + if (!writeable()) + throw std::domain_error("array is not writeable"); + } + + // Default, C-style strides + static std::vector c_strides(const std::vector &shape, ssize_t itemsize) { + auto ndim = shape.size(); + std::vector strides(ndim, itemsize); + if (ndim > 0) + for (size_t i = ndim - 1; i > 0; --i) + strides[i - 1] = strides[i] * shape[i]; + return strides; + } + + // F-style strides; default when constructing an array_t with `ExtraFlags & f_style` + static std::vector f_strides(const std::vector &shape, ssize_t itemsize) { + auto ndim = shape.size(); + std::vector strides(ndim, itemsize); + for (size_t i = 1; i < ndim; ++i) + strides[i] = strides[i - 1] * shape[i - 1]; + return strides; + } + + template void check_dimensions(Ix... 
index) const { + check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...); + } + + void check_dimensions_impl(ssize_t, const ssize_t*) const { } + + template void check_dimensions_impl(ssize_t axis, const ssize_t* shape, ssize_t i, Ix... index) const { + if (i >= *shape) { + throw index_error(std::string("index ") + std::to_string(i) + + " is out of bounds for axis " + std::to_string(axis) + + " with size " + std::to_string(*shape)); + } + check_dimensions_impl(axis + 1, shape + 1, index...); + } + + /// Create array from any object -- always returns a new reference + static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) { + if (ptr == nullptr) { + PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array from a nullptr"); + return nullptr; + } + return detail::npy_api::get().PyArray_FromAny_( + ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr); + } +}; + +template class array_t : public array { +private: + struct private_ctor {}; + // Delegating constructor needed when both moving and accessing in the same constructor + array_t(private_ctor, ShapeContainer &&shape, StridesContainer &&strides, const T *ptr, handle base) + : array(std::move(shape), std::move(strides), ptr, base) {} +public: + static_assert(!detail::array_info::is_array, "Array types cannot be used with array_t"); + + using value_type = T; + + array_t() : array(0, static_cast(nullptr)) {} + array_t(handle h, borrowed_t) : array(h, borrowed_t{}) { } + array_t(handle h, stolen_t) : array(h, stolen_t{}) { } + + PYBIND11_DEPRECATED("Use array_t::ensure() instead") + array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) { + if (!m_ptr) PyErr_Clear(); + if (!is_borrowed) Py_XDECREF(h.ptr()); + } + + array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) { + if (!m_ptr) throw error_already_set(); + } + + explicit array_t(const buffer_info& info) : array(info) { } + + array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle()) + : array(std::move(shape), std::move(strides), ptr, base) { } + + explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle()) + : array_t(private_ctor{}, std::move(shape), + ExtraFlags & f_style ? f_strides(*shape, itemsize()) : c_strides(*shape, itemsize()), + ptr, base) { } + + explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle()) + : array({count}, {}, ptr, base) { } + + constexpr ssize_t itemsize() const { + return sizeof(T); + } + + template ssize_t index_at(Ix... index) const { + return offset_at(index...) / itemsize(); + } + + template const T* data(Ix... index) const { + return static_cast(array::data(index...)); + } + + template T* mutable_data(Ix... index) { + return static_cast(array::mutable_data(index...)); + } + + // Reference to element at a given index + template const T& at(Ix... index) const { + if (sizeof...(index) != ndim()) + fail_dim_check(sizeof...(index), "index dimension mismatch"); + return *(static_cast(array::data()) + byte_offset(ssize_t(index)...) / itemsize()); + } + + // Mutable reference to element at a given index + template T& mutable_at(Ix... index) { + if (sizeof...(index) != ndim()) + fail_dim_check(sizeof...(index), "index dimension mismatch"); + return *(static_cast(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize()); + } + + /** + * Returns a proxy object that provides access to the array's data without bounds or + * dimensionality checking. 
Will throw if the array is missing the `writeable` flag. Use with + * care: the array must not be destroyed or reshaped for the duration of the returned object, + * and the caller must take care not to access invalid dimensions or dimension indices. + */ + template detail::unchecked_mutable_reference mutable_unchecked() & { + return array::mutable_unchecked(); + } + + /** + * Returns a proxy object that provides const access to the array's data without bounds or + * dimensionality checking. Unlike `unchecked()`, this does not require that the underlying + * array have the `writable` flag. Use with care: the array must not be destroyed or reshaped + * for the duration of the returned object, and the caller must take care not to access invalid + * dimensions or dimension indices. + */ + template detail::unchecked_reference unchecked() const & { + return array::unchecked(); + } + + /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert + /// it). In case of an error, nullptr is returned and the Python error is cleared. + static array_t ensure(handle h) { + auto result = reinterpret_steal(raw_array_t(h.ptr())); + if (!result) + PyErr_Clear(); + return result; + } + + static bool check_(handle h) { + const auto &api = detail::npy_api::get(); + return api.PyArray_Check_(h.ptr()) + && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of().ptr()); + } + +protected: + /// Create array from any object -- always returns a new reference + static PyObject *raw_array_t(PyObject *ptr) { + if (ptr == nullptr) { + PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr"); + return nullptr; + } + return detail::npy_api::get().PyArray_FromAny_( + ptr, dtype::of().release().ptr(), 0, 0, + detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr); + } +}; + +template +struct format_descriptor::value>> { + static std::string format() { + return detail::npy_format_descriptor::type>::format(); + } +}; + +template struct format_descriptor { + static std::string format() { return std::to_string(N) + "s"; } +}; +template struct format_descriptor> { + static std::string format() { return std::to_string(N) + "s"; } +}; + +template +struct format_descriptor::value>> { + static std::string format() { + return format_descriptor< + typename std::remove_cv::type>::type>::format(); + } +}; + +template +struct format_descriptor::is_array>> { + static std::string format() { + using namespace detail; + static constexpr auto extents = _("(") + array_info::extents + _(")"); + return extents.text + format_descriptor>::format(); + } +}; + +NAMESPACE_BEGIN(detail) +template +struct pyobject_caster> { + using type = array_t; + + bool load(handle src, bool convert) { + if (!convert && !type::check_(src)) + return false; + value = type::ensure(src); + return static_cast(value); + } + + static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) { + return src.inc_ref(); + } + PYBIND11_TYPE_CASTER(type, handle_type_name::name); +}; + +template +struct compare_buffer_info::value>> { + static bool compare(const buffer_info& b) { + return npy_api::get().PyArray_EquivTypes_(dtype::of().ptr(), dtype(b).ptr()); + } +}; + +template +struct npy_format_descriptor_name; + +template +struct npy_format_descriptor_name::value>> { + static constexpr auto name = _::value>( + _("bool"), _::value>("int", "uint") + _() + ); +}; + +template +struct npy_format_descriptor_name::value>> { + static constexpr auto name = _::value || 
std::is_same::value>( + _("float") + _(), _("longdouble") + ); +}; + +template +struct npy_format_descriptor_name::value>> { + static constexpr auto name = _::value + || std::is_same::value>( + _("complex") + _(), _("longcomplex") + ); +}; + +template +struct npy_format_descriptor::value>> + : npy_format_descriptor_name { +private: + // NB: the order here must match the one in common.h + constexpr static const int values[15] = { + npy_api::NPY_BOOL_, + npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_, + npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_, + npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_, + npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_ + }; + +public: + static constexpr int value = values[detail::is_fmt_numeric::index]; + + static pybind11::dtype dtype() { + if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) + return reinterpret_borrow(ptr); + pybind11_fail("Unsupported buffer format!"); + } +}; + +#define PYBIND11_DECL_CHAR_FMT \ + static constexpr auto name = _("S") + _(); \ + static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); } +template struct npy_format_descriptor { PYBIND11_DECL_CHAR_FMT }; +template struct npy_format_descriptor> { PYBIND11_DECL_CHAR_FMT }; +#undef PYBIND11_DECL_CHAR_FMT + +template struct npy_format_descriptor::is_array>> { +private: + using base_descr = npy_format_descriptor::type>; +public: + static_assert(!array_info::is_empty, "Zero-sized arrays are not supported"); + + static constexpr auto name = _("(") + array_info::extents + _(")") + base_descr::name; + static pybind11::dtype dtype() { + list shape; + array_info::append_extents(shape); + return pybind11::dtype::from_args(pybind11::make_tuple(base_descr::dtype(), shape)); + } +}; + +template struct npy_format_descriptor::value>> { +private: + using base_descr = npy_format_descriptor::type>; +public: + static constexpr auto name = base_descr::name; + static pybind11::dtype dtype() { return base_descr::dtype(); } +}; + +struct field_descriptor { + const char *name; + ssize_t offset; + ssize_t size; + std::string format; + dtype descr; +}; + +inline PYBIND11_NOINLINE void register_structured_dtype( + any_container fields, + const std::type_info& tinfo, ssize_t itemsize, + bool (*direct_converter)(PyObject *, void *&)) { + + auto& numpy_internals = get_numpy_internals(); + if (numpy_internals.get_type_info(tinfo, false)) + pybind11_fail("NumPy: dtype is already registered"); + + list names, formats, offsets; + for (auto field : *fields) { + if (!field.descr) + pybind11_fail(std::string("NumPy: unsupported field dtype: `") + + field.name + "` @ " + tinfo.name()); + names.append(PYBIND11_STR_TYPE(field.name)); + formats.append(field.descr); + offsets.append(pybind11::int_(field.offset)); + } + auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr(); + + // There is an existing bug in NumPy (as of v1.11): trailing bytes are + // not encoded explicitly into the format string. This will supposedly + // get fixed in v1.12; for further details, see these: + // - https://github.com/numpy/numpy/issues/7797 + // - https://github.com/numpy/numpy/pull/7798 + // Because of this, we won't use numpy's logic to generate buffer format + // strings and will just do it ourselves. 
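+    // For example (hypothetical struct): { float x; double y; } with 4 bytes of padding between
+    // the fields produces the buffer format string "^T{f:x:4xd:y:}".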
+ std::vector ordered_fields(std::move(fields)); + std::sort(ordered_fields.begin(), ordered_fields.end(), + [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); + ssize_t offset = 0; + std::ostringstream oss; + // mark the structure as unaligned with '^', because numpy and C++ don't + // always agree about alignment (particularly for complex), and we're + // explicitly listing all our padding. This depends on none of the fields + // overriding the endianness. Putting the ^ in front of individual fields + // isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049 + oss << "^T{"; + for (auto& field : ordered_fields) { + if (field.offset > offset) + oss << (field.offset - offset) << 'x'; + oss << field.format << ':' << field.name << ':'; + offset = field.offset + field.size; + } + if (itemsize > offset) + oss << (itemsize - offset) << 'x'; + oss << '}'; + auto format_str = oss.str(); + + // Sanity check: verify that NumPy properly parses our buffer format string + auto& api = npy_api::get(); + auto arr = array(buffer_info(nullptr, itemsize, format_str, 1)); + if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) + pybind11_fail("NumPy: invalid buffer descriptor!"); + + auto tindex = std::type_index(tinfo); + numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str }; + get_internals().direct_conversions[tindex].push_back(direct_converter); +} + +template struct npy_format_descriptor { + static_assert(is_pod_struct::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype"); + + static constexpr auto name = make_caster::name; + + static pybind11::dtype dtype() { + return reinterpret_borrow(dtype_ptr()); + } + + static std::string format() { + static auto format_str = get_numpy_internals().get_type_info(true)->format_str; + return format_str; + } + + static void register_dtype(any_container fields) { + register_structured_dtype(std::move(fields), typeid(typename std::remove_cv::type), + sizeof(T), &direct_converter); + } + +private: + static PyObject* dtype_ptr() { + static PyObject* ptr = get_numpy_internals().get_type_info(true)->dtype_ptr; + return ptr; + } + + static bool direct_converter(PyObject *obj, void*& value) { + auto& api = npy_api::get(); + if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) + return false; + if (auto descr = reinterpret_steal(api.PyArray_DescrFromScalar_(obj))) { + if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) { + value = ((PyVoidScalarObject_Proxy *) obj)->obval; + return true; + } + } + return false; + } +}; + +#ifdef __CLION_IDE__ // replace heavy macro with dummy code for the IDE (doesn't affect code) +# define PYBIND11_NUMPY_DTYPE(Type, ...) ((void)0) +# define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void)0) +#else + +#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \ + ::pybind11::detail::field_descriptor { \ + Name, offsetof(T, Field), sizeof(decltype(std::declval().Field)), \ + ::pybind11::format_descriptor().Field)>::format(), \ + ::pybind11::detail::npy_format_descriptor().Field)>::dtype() \ + } + +// Extract name, offset and format descriptor for a struct field +#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field) + +// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro +// (C) William Swanson, Paul Fultz +#define PYBIND11_EVAL0(...) __VA_ARGS__ +#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__))) +#define PYBIND11_EVAL2(...) 
PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__))) +#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__))) +#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__))) +#define PYBIND11_EVAL(...) PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__))) +#define PYBIND11_MAP_END(...) +#define PYBIND11_MAP_OUT +#define PYBIND11_MAP_COMMA , +#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END +#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT +#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0) +#define PYBIND11_MAP_NEXT(test, next) PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next) +#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround +#define PYBIND11_MAP_LIST_NEXT1(test, next) \ + PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)) +#else +#define PYBIND11_MAP_LIST_NEXT1(test, next) \ + PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0) +#endif +#define PYBIND11_MAP_LIST_NEXT(test, next) \ + PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next) +#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \ + f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__) +#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \ + f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__) +// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ... +#define PYBIND11_MAP_LIST(f, t, ...) \ + PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0)) + +#define PYBIND11_NUMPY_DTYPE(Type, ...) \ + ::pybind11::detail::npy_format_descriptor::register_dtype \ + (::std::vector<::pybind11::detail::field_descriptor> \ + {PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)}) + +#ifdef _MSC_VER +#define PYBIND11_MAP2_LIST_NEXT1(test, next) \ + PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)) +#else +#define PYBIND11_MAP2_LIST_NEXT1(test, next) \ + PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0) +#endif +#define PYBIND11_MAP2_LIST_NEXT(test, next) \ + PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next) +#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \ + f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__) +#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \ + f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__) +// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ... +#define PYBIND11_MAP2_LIST(f, t, ...) \ + PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0)) + +#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) 
\ + ::pybind11::detail::npy_format_descriptor::register_dtype \ + (::std::vector<::pybind11::detail::field_descriptor> \ + {PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)}) + +#endif // __CLION_IDE__ + +template +using array_iterator = typename std::add_pointer::type; + +template +array_iterator array_begin(const buffer_info& buffer) { + return array_iterator(reinterpret_cast(buffer.ptr)); +} + +template +array_iterator array_end(const buffer_info& buffer) { + return array_iterator(reinterpret_cast(buffer.ptr) + buffer.size); +} + +class common_iterator { +public: + using container_type = std::vector; + using value_type = container_type::value_type; + using size_type = container_type::size_type; + + common_iterator() : p_ptr(0), m_strides() {} + + common_iterator(void* ptr, const container_type& strides, const container_type& shape) + : p_ptr(reinterpret_cast(ptr)), m_strides(strides.size()) { + m_strides.back() = static_cast(strides.back()); + for (size_type i = m_strides.size() - 1; i != 0; --i) { + size_type j = i - 1; + value_type s = static_cast(shape[i]); + m_strides[j] = strides[j] + m_strides[i] - strides[i] * s; + } + } + + void increment(size_type dim) { + p_ptr += m_strides[dim]; + } + + void* data() const { + return p_ptr; + } + +private: + char* p_ptr; + container_type m_strides; +}; + +template class multi_array_iterator { +public: + using container_type = std::vector; + + multi_array_iterator(const std::array &buffers, + const container_type &shape) + : m_shape(shape.size()), m_index(shape.size(), 0), + m_common_iterator() { + + // Manual copy to avoid conversion warning if using std::copy + for (size_t i = 0; i < shape.size(); ++i) + m_shape[i] = shape[i]; + + container_type strides(shape.size()); + for (size_t i = 0; i < N; ++i) + init_common_iterator(buffers[i], shape, m_common_iterator[i], strides); + } + + multi_array_iterator& operator++() { + for (size_t j = m_index.size(); j != 0; --j) { + size_t i = j - 1; + if (++m_index[i] != m_shape[i]) { + increment_common_iterator(i); + break; + } else { + m_index[i] = 0; + } + } + return *this; + } + + template T* data() const { + return reinterpret_cast(m_common_iterator[K].data()); + } + +private: + + using common_iter = common_iterator; + + void init_common_iterator(const buffer_info &buffer, + const container_type &shape, + common_iter &iterator, + container_type &strides) { + auto buffer_shape_iter = buffer.shape.rbegin(); + auto buffer_strides_iter = buffer.strides.rbegin(); + auto shape_iter = shape.rbegin(); + auto strides_iter = strides.rbegin(); + + while (buffer_shape_iter != buffer.shape.rend()) { + if (*shape_iter == *buffer_shape_iter) + *strides_iter = *buffer_strides_iter; + else + *strides_iter = 0; + + ++buffer_shape_iter; + ++buffer_strides_iter; + ++shape_iter; + ++strides_iter; + } + + std::fill(strides_iter, strides.rend(), 0); + iterator = common_iter(buffer.ptr, strides, shape); + } + + void increment_common_iterator(size_t dim) { + for (auto &iter : m_common_iterator) + iter.increment(dim); + } + + container_type m_shape; + container_type m_index; + std::array m_common_iterator; +}; + +enum class broadcast_trivial { non_trivial, c_trivial, f_trivial }; + +// Populates the shape and number of dimensions for the set of buffers. 
Returns a broadcast_trivial +// enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a +// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage +// buffer; returns `non_trivial` otherwise. +template +broadcast_trivial broadcast(const std::array &buffers, ssize_t &ndim, std::vector &shape) { + ndim = std::accumulate(buffers.begin(), buffers.end(), ssize_t(0), [](ssize_t res, const buffer_info &buf) { + return std::max(res, buf.ndim); + }); + + shape.clear(); + shape.resize((size_t) ndim, 1); + + // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or + // the full size). + for (size_t i = 0; i < N; ++i) { + auto res_iter = shape.rbegin(); + auto end = buffers[i].shape.rend(); + for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) { + const auto &dim_size_in = *shape_iter; + auto &dim_size_out = *res_iter; + + // Each input dimension can either be 1 or `n`, but `n` values must match across buffers + if (dim_size_out == 1) + dim_size_out = dim_size_in; + else if (dim_size_in != 1 && dim_size_in != dim_size_out) + pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!"); + } + } + + bool trivial_broadcast_c = true; + bool trivial_broadcast_f = true; + for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) { + if (buffers[i].size == 1) + continue; + + // Require the same number of dimensions: + if (buffers[i].ndim != ndim) + return broadcast_trivial::non_trivial; + + // Require all dimensions be full-size: + if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin())) + return broadcast_trivial::non_trivial; + + // Check for C contiguity (but only if previous inputs were also C contiguous) + if (trivial_broadcast_c) { + ssize_t expect_stride = buffers[i].itemsize; + auto end = buffers[i].shape.crend(); + for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin(); + trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) { + if (expect_stride == *stride_iter) + expect_stride *= *shape_iter; + else + trivial_broadcast_c = false; + } + } + + // Check for Fortran contiguity (if previous inputs were also F contiguous) + if (trivial_broadcast_f) { + ssize_t expect_stride = buffers[i].itemsize; + auto end = buffers[i].shape.cend(); + for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin(); + trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) { + if (expect_stride == *stride_iter) + expect_stride *= *shape_iter; + else + trivial_broadcast_f = false; + } + } + } + + return + trivial_broadcast_c ? broadcast_trivial::c_trivial : + trivial_broadcast_f ? broadcast_trivial::f_trivial : + broadcast_trivial::non_trivial; +} + +template +struct vectorize_arg { + static_assert(!std::is_rvalue_reference::value, "Functions with rvalue reference arguments cannot be vectorized"); + // The wrapped function gets called with this type: + using call_type = remove_reference_t; + // Is this a vectorized argument? 
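    // (In practice: arithmetic, complex, and POD types taken by value or by const
    // lvalue reference vectorize; pointers, C arrays, std::array, enums, and
    // non-const reference parameters are passed through as-is.)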
    static constexpr bool vectorize =
        satisfies_any_of<call_type, std::is_arithmetic, is_complex, std::is_pod>::value &&
        satisfies_none_of<call_type, std::is_pointer, std::is_array, is_std_array, std::is_enum>::value &&
        (!std::is_reference<T>::value ||
         (std::is_lvalue_reference<T>::value && std::is_const<call_type>::value));
    // Accept this type: an array for vectorized types, otherwise the type as-is:
    using type = conditional_t<vectorize, array_t<remove_cv_t<call_type>, array::forcecast>, T>;
};

template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
private:
    static constexpr size_t N = sizeof...(Args);
    static constexpr size_t NVectorized = constexpr_sum(vectorize_arg<Args>::vectorize...);
    static_assert(NVectorized >= 1,
        "pybind11::vectorize(...) requires a function with at least one vectorizable argument");

public:
    template <typename T>
    explicit vectorize_helper(T &&f) : f(std::forward<T>(f)) { }

    object operator()(typename vectorize_arg<Args>::type... args) {
        return run(args...,
                   make_index_sequence<N>(),
                   select_indices<vectorize_arg<Args>::vectorize...>(),
                   make_index_sequence<NVectorized>());
    }

private:
    remove_reference_t<Func> f;

    // Internal compiler error in MSVC 19.16.27025.1 (Visual Studio 2017 15.9.4), when compiling with "/permissive-" flag
    // when arg_call_types is manually inlined.
    using arg_call_types = std::tuple<typename vectorize_arg<Args>::call_type...>;
    template <size_t Index> using param_n_t = typename std::tuple_element<Index, arg_call_types>::type;

    // Runs a vectorized function given arguments tuple and three index sequences:
    //     - Index is the full set of 0 ... (N-1) argument indices;
    //     - VIndex is the subset of argument indices with vectorized parameters, letting us access
    //       vectorized arguments (anything not in this sequence is passed through)
    //     - BIndex is an incremental sequence (beginning at 0) of the same size as VIndex, so that
    //       we can store vectorized buffer_infos in an array (argument VIndex has its buffer at
    //       index BIndex in the array).
    template <size_t... Index, size_t... VIndex, size_t... BIndex> object run(
            typename vectorize_arg<Args>::type &...args,
            index_sequence<Index...> i_seq, index_sequence<VIndex...> vi_seq, index_sequence<BIndex...> bi_seq) {

        // Pointers to values the function was called with; the vectorized ones set here will start
        // out as array_t<T> pointers, but they will be changed to T pointers before we call the
        // wrapped function. Non-vectorized pointers are left as-is.
        std::array<void *, N> params{{ &args... }};

        // The array of `buffer_info`s of vectorized arguments:
        std::array<buffer_info, NVectorized> buffers{{ reinterpret_cast<array *>(params[VIndex])->request()... }};

        /* Determine dimensions parameters of output array */
        ssize_t nd = 0;
        std::vector<ssize_t> shape(0);
        auto trivial = broadcast(buffers, nd, shape);
        size_t ndim = (size_t) nd;

        size_t size = std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies<size_t>());

        // If all arguments are 0-dimension arrays (i.e. single values) return a plain value (i.e.
        // not wrapped in an array).
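
        // Usage sketch (illustrative; `my_func` and the module handle `m` are hypothetical):
        //
        //     double my_func(int x, float y, double z) { return (double) x * y * z; }
        //     ...
        //     m.def("vectorized_func", py::vectorize(my_func));
        //
        // Calling vectorized_func(np.arange(3), 2.0, np.ones((2, 1))) from Python then
        // broadcasts the inputs to a common (2, 3) shape and applies my_func elementwise;
        // the `size == 1 && ndim == 0` check below implements the scalar unwrapping
        // described in the comment above.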
+ if (size == 1 && ndim == 0) { + PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr); + return cast(f(*reinterpret_cast *>(params[Index])...)); + } + + array_t result; + if (trivial == broadcast_trivial::f_trivial) result = array_t(shape); + else result = array_t(shape); + + if (size == 0) return result; + + /* Call the function */ + if (trivial == broadcast_trivial::non_trivial) + apply_broadcast(buffers, params, result, i_seq, vi_seq, bi_seq); + else + apply_trivial(buffers, params, result.mutable_data(), size, i_seq, vi_seq, bi_seq); + + return result; + } + + template + void apply_trivial(std::array &buffers, + std::array ¶ms, + Return *out, + size_t size, + index_sequence, index_sequence, index_sequence) { + + // Initialize an array of mutable byte references and sizes with references set to the + // appropriate pointer in `params`; as we iterate, we'll increment each pointer by its size + // (except for singletons, which get an increment of 0). + std::array, NVectorized> vecparams{{ + std::pair( + reinterpret_cast(params[VIndex] = buffers[BIndex].ptr), + buffers[BIndex].size == 1 ? 0 : sizeof(param_n_t) + )... + }}; + + for (size_t i = 0; i < size; ++i) { + out[i] = f(*reinterpret_cast *>(params[Index])...); + for (auto &x : vecparams) x.first += x.second; + } + } + + template + void apply_broadcast(std::array &buffers, + std::array ¶ms, + array_t &output_array, + index_sequence, index_sequence, index_sequence) { + + buffer_info output = output_array.request(); + multi_array_iterator input_iter(buffers, output.shape); + + for (array_iterator iter = array_begin(output), end = array_end(output); + iter != end; + ++iter, ++input_iter) { + PYBIND11_EXPAND_SIDE_EFFECTS(( + params[VIndex] = input_iter.template data() + )); + *iter = f(*reinterpret_cast *>(std::get(params))...); + } + } +}; + +template +vectorize_helper +vectorize_extractor(const Func &f, Return (*) (Args ...)) { + return detail::vectorize_helper(f); +} + +template struct handle_type_name> { + static constexpr auto name = _("numpy.ndarray[") + npy_format_descriptor::name + _("]"); +}; + +NAMESPACE_END(detail) + +// Vanilla pointer vectorizer: +template +detail::vectorize_helper +vectorize(Return (*f) (Args ...)) { + return detail::vectorize_helper(f); +} + +// lambda vectorizer: +template ::value, int> = 0> +auto vectorize(Func &&f) -> decltype( + detail::vectorize_extractor(std::forward(f), (detail::function_signature_t *) nullptr)) { + return detail::vectorize_extractor(std::forward(f), (detail::function_signature_t *) nullptr); +} + +// Vectorize a class method (non-const): +template ())), Return, Class *, Args...>> +Helper vectorize(Return (Class::*f)(Args...)) { + return Helper(std::mem_fn(f)); +} + +// Vectorize a class method (const): +template ())), Return, const Class *, Args...>> +Helper vectorize(Return (Class::*f)(Args...) const) { + return Helper(std::mem_fn(f)); +} + +NAMESPACE_END(PYBIND11_NAMESPACE) + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/mmocr/models/textdet/postprocess/include/pybind11/operators.h b/mmocr/models/textdet/postprocess/include/pybind11/operators.h new file mode 100644 index 00000000..b3dd62c3 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/operators.h @@ -0,0 +1,168 @@ +/* + pybind11/operator.h: Metatemplates for operator overloading + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
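
    Usage sketch (illustrative; `Vector2` is a hypothetical bound C++ type that
    defines the corresponding C++ operators):

        py::class_<Vector2>(m, "Vector2")
            .def(py::self + py::self)
            .def(py::self += py::self)
            .def(py::self * float())
            .def(float() * py::self)
            .def(-py::self);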
+*/ + +#pragma once + +#include "pybind11.h" + +#if defined(__clang__) && !defined(__INTEL_COMPILER) +# pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type())) +#elif defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +#endif + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +/// Enumeration with all supported operator types +enum op_id : int { + op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift, + op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert, + op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le, + op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift, + op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero, + op_repr, op_truediv, op_itruediv, op_hash +}; + +enum op_type : int { + op_l, /* base type on left */ + op_r, /* base type on right */ + op_u /* unary operator */ +}; + +struct self_t { }; +static const self_t self = self_t(); + +/// Type for an unused type slot +struct undefined_t { }; + +/// Don't warn about an unused variable +inline self_t __self() { return self; } + +/// base template of operator implementations +template struct op_impl { }; + +/// Operator implementation generator +template struct op_ { + template void execute(Class &cl, const Extra&... extra) const { + using Base = typename Class::type; + using L_type = conditional_t::value, Base, L>; + using R_type = conditional_t::value, Base, R>; + using op = op_impl; + cl.def(op::name(), &op::execute, is_operator(), extra...); + #if PY_MAJOR_VERSION < 3 + if (id == op_truediv || id == op_itruediv) + cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__", + &op::execute, is_operator(), extra...); + #endif + } + template void execute_cast(Class &cl, const Extra&... extra) const { + using Base = typename Class::type; + using L_type = conditional_t::value, Base, L>; + using R_type = conditional_t::value, Base, R>; + using op = op_impl; + cl.def(op::name(), &op::execute_cast, is_operator(), extra...); + #if PY_MAJOR_VERSION < 3 + if (id == op_truediv || id == op_itruediv) + cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? 
"__div__" : "__rdiv__", + &op::execute, is_operator(), extra...); + #endif + } +}; + +#define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \ +template struct op_impl { \ + static char const* name() { return "__" #id "__"; } \ + static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \ + static B execute_cast(const L &l, const R &r) { return B(expr); } \ +}; \ +template struct op_impl { \ + static char const* name() { return "__" #rid "__"; } \ + static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \ + static B execute_cast(const R &r, const L &l) { return B(expr); } \ +}; \ +inline op_ op(const self_t &, const self_t &) { \ + return op_(); \ +} \ +template op_ op(const self_t &, const T &) { \ + return op_(); \ +} \ +template op_ op(const T &, const self_t &) { \ + return op_(); \ +} + +#define PYBIND11_INPLACE_OPERATOR(id, op, expr) \ +template struct op_impl { \ + static char const* name() { return "__" #id "__"; } \ + static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \ + static B execute_cast(L &l, const R &r) { return B(expr); } \ +}; \ +template op_ op(const self_t &, const T &) { \ + return op_(); \ +} + +#define PYBIND11_UNARY_OPERATOR(id, op, expr) \ +template struct op_impl { \ + static char const* name() { return "__" #id "__"; } \ + static auto execute(const L &l) -> decltype(expr) { return expr; } \ + static B execute_cast(const L &l) { return B(expr); } \ +}; \ +inline op_ op(const self_t &) { \ + return op_(); \ +} + +PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r) +PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r) +PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r) +PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r) +PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r) +PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r) +PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r) +PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r) +PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r) +PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r) +PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r) +PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r) +PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r) +PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r) +PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r) +PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r) +//PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r)) +PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r) +PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r) +PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r) +PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r) +PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r) +PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r) +PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r) +PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r) +PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r) +PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r) +PYBIND11_UNARY_OPERATOR(neg, operator-, -l) +PYBIND11_UNARY_OPERATOR(pos, operator+, +l) +PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l)) +PYBIND11_UNARY_OPERATOR(hash, hash, std::hash()(l)) +PYBIND11_UNARY_OPERATOR(invert, operator~, (~l)) +PYBIND11_UNARY_OPERATOR(bool, operator!, !!l) +PYBIND11_UNARY_OPERATOR(int, int_, (int) l) +PYBIND11_UNARY_OPERATOR(float, float_, (double) l) + +#undef PYBIND11_BINARY_OPERATOR +#undef PYBIND11_INPLACE_OPERATOR +#undef PYBIND11_UNARY_OPERATOR +NAMESPACE_END(detail) + +using 
detail::self; + +NAMESPACE_END(PYBIND11_NAMESPACE) + +#if defined(_MSC_VER) +# pragma warning(pop) +#endif diff --git a/mmocr/models/textdet/postprocess/include/pybind11/options.h b/mmocr/models/textdet/postprocess/include/pybind11/options.h new file mode 100644 index 00000000..cc1e1f6f --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/options.h @@ -0,0 +1,65 @@ +/* + pybind11/options.h: global settings that are configurable at runtime. + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "detail/common.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +class options { +public: + + // Default RAII constructor, which leaves settings as they currently are. + options() : previous_state(global_state()) {} + + // Class is non-copyable. + options(const options&) = delete; + options& operator=(const options&) = delete; + + // Destructor, which restores settings that were in effect before. + ~options() { + global_state() = previous_state; + } + + // Setter methods (affect the global state): + + options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } + + options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } + + options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } + + options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } + + // Getter methods (return the global state): + + static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } + + static bool show_function_signatures() { return global_state().show_function_signatures; } + + // This type is not meant to be allocated on the heap. + void* operator new(size_t) = delete; + +private: + + struct state { + bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. + bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings. + }; + + static state &global_state() { + static state instance; + return instance; + } + + state previous_state; +}; + +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/pybind11.h b/mmocr/models/textdet/postprocess/include/pybind11/pybind11.h new file mode 100644 index 00000000..7fa0f0e1 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/pybind11.h @@ -0,0 +1,2094 @@ +/* + pybind11/pybind11.h: Main header file of the C++11 python + binding generator library + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
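
    Minimal usage sketch (illustrative):

        #include <pybind11/pybind11.h>

        int add(int i, int j) { return i + j; }

        PYBIND11_MODULE(example, m) {
            m.doc() = "pybind11 example plugin";
            m.def("add", &add, "A function which adds two numbers");
        }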
+*/ + +#pragma once + +#if defined(__INTEL_COMPILER) +# pragma warning push +# pragma warning disable 68 // integer conversion resulted in a change of sign +# pragma warning disable 186 // pointless comparison of unsigned integer with zero +# pragma warning disable 878 // incompatible exception specifications +# pragma warning disable 1334 // the "template" keyword used for syntactic disambiguation may only be used within a template +# pragma warning disable 1682 // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem) +# pragma warning disable 1786 // function "strdup" was declared deprecated +# pragma warning disable 1875 // offsetof applied to non-POD (Plain Old Data) types is nonstandard +# pragma warning disable 2196 // warning #2196: routine is both "inline" and "noinline" +#elif defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 4100) // warning C4100: Unreferenced formal parameter +# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +# pragma warning(disable: 4512) // warning C4512: Assignment operator was implicitly defined as deleted +# pragma warning(disable: 4800) // warning C4800: 'int': forcing value to bool 'true' or 'false' (performance warning) +# pragma warning(disable: 4996) // warning C4996: The POSIX name for this item is deprecated. Instead, use the ISO C and C++ conformant name +# pragma warning(disable: 4702) // warning C4702: unreachable code +# pragma warning(disable: 4522) // warning C4522: multiple assignment operators specified +#elif defined(__GNUG__) && !defined(__clang__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wunused-but-set-parameter" +# pragma GCC diagnostic ignored "-Wunused-but-set-variable" +# pragma GCC diagnostic ignored "-Wmissing-field-initializers" +# pragma GCC diagnostic ignored "-Wstrict-aliasing" +# pragma GCC diagnostic ignored "-Wattributes" +# if __GNUC__ >= 7 +# pragma GCC diagnostic ignored "-Wnoexcept-type" +# endif +#endif + +#include "attr.h" +#include "options.h" +#include "detail/class.h" +#include "detail/init.h" + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +/// Wraps an arbitrary C++ function/method/lambda function/.. into a callable Python object +class cpp_function : public function { +public: + cpp_function() { } + cpp_function(std::nullptr_t) { } + + /// Construct a cpp_function from a vanilla function pointer + template + cpp_function(Return (*f)(Args...), const Extra&... extra) { + initialize(f, f, extra...); + } + + /// Construct a cpp_function from a lambda function (possibly with internal state) + template ::value>> + cpp_function(Func &&f, const Extra&... extra) { + initialize(std::forward(f), + (detail::function_signature_t *) nullptr, extra...); + } + + /// Construct a cpp_function from a class method (non-const) + template + cpp_function(Return (Class::*f)(Arg...), const Extra&... extra) { + initialize([f](Class *c, Arg... args) -> Return { return (c->*f)(args...); }, + (Return (*) (Class *, Arg...)) nullptr, extra...); + } + + /// Construct a cpp_function from a class method (const) + template + cpp_function(Return (Class::*f)(Arg...) const, const Extra&... extra) { + initialize([f](const Class *c, Arg... 
args) -> Return { return (c->*f)(args...); }, + (Return (*)(const Class *, Arg ...)) nullptr, extra...); + } + + /// Return the function name + object name() const { return attr("__name__"); } + +protected: + /// Space optimization: don't inline this frequently instantiated fragment + PYBIND11_NOINLINE detail::function_record *make_function_record() { + return new detail::function_record(); + } + + /// Special internal constructor for functors, lambda functions, etc. + template + void initialize(Func &&f, Return (*)(Args...), const Extra&... extra) { + using namespace detail; + struct capture { remove_reference_t f; }; + + /* Store the function including any extra state it might have (e.g. a lambda capture object) */ + auto rec = make_function_record(); + + /* Store the capture object directly in the function record if there is enough space */ + if (sizeof(capture) <= sizeof(rec->data)) { + /* Without these pragmas, GCC warns that there might not be + enough space to use the placement new operator. However, the + 'if' statement above ensures that this is the case. */ +#if defined(__GNUG__) && !defined(__clang__) && __GNUC__ >= 6 +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wplacement-new" +#endif + new ((capture *) &rec->data) capture { std::forward(f) }; +#if defined(__GNUG__) && !defined(__clang__) && __GNUC__ >= 6 +# pragma GCC diagnostic pop +#endif + if (!std::is_trivially_destructible::value) + rec->free_data = [](function_record *r) { ((capture *) &r->data)->~capture(); }; + } else { + rec->data[0] = new capture { std::forward(f) }; + rec->free_data = [](function_record *r) { delete ((capture *) r->data[0]); }; + } + + /* Type casters for the function arguments and return value */ + using cast_in = argument_loader; + using cast_out = make_caster< + conditional_t::value, void_type, Return> + >; + + static_assert(expected_num_args(sizeof...(Args), cast_in::has_args, cast_in::has_kwargs), + "The number of argument annotations does not match the number of function arguments"); + + /* Dispatch code which converts function arguments and performs the actual function call */ + rec->impl = [](function_call &call) -> handle { + cast_in args_converter; + + /* Try to cast the function arguments into the C++ domain */ + if (!args_converter.load_args(call)) + return PYBIND11_TRY_NEXT_OVERLOAD; + + /* Invoke call policy pre-call hook */ + process_attributes::precall(call); + + /* Get a pointer to the capture object */ + auto data = (sizeof(capture) <= sizeof(call.func.data) + ? 
&call.func.data : call.func.data[0]); + capture *cap = const_cast(reinterpret_cast(data)); + + /* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */ + return_value_policy policy = return_value_policy_override::policy(call.func.policy); + + /* Function scope guard -- defaults to the compile-to-nothing `void_type` */ + using Guard = extract_guard_t; + + /* Perform the function call */ + handle result = cast_out::cast( + std::move(args_converter).template call(cap->f), policy, call.parent); + + /* Invoke call policy post-call hook */ + process_attributes::postcall(call, result); + + return result; + }; + + /* Process any user-provided function attributes */ + process_attributes::init(extra..., rec); + + /* Generate a readable signature describing the function's arguments and return value types */ + static constexpr auto signature = _("(") + cast_in::arg_names + _(") -> ") + cast_out::name; + PYBIND11_DESCR_CONSTEXPR auto types = decltype(signature)::types(); + + /* Register the function with Python from generic (non-templated) code */ + initialize_generic(rec, signature.text, types.data(), sizeof...(Args)); + + if (cast_in::has_args) rec->has_args = true; + if (cast_in::has_kwargs) rec->has_kwargs = true; + + /* Stash some additional information used by an important optimization in 'functional.h' */ + using FunctionType = Return (*)(Args...); + constexpr bool is_function_ptr = + std::is_convertible::value && + sizeof(capture) == sizeof(void *); + if (is_function_ptr) { + rec->is_stateless = true; + rec->data[1] = const_cast(reinterpret_cast(&typeid(FunctionType))); + } + } + + /// Register a function call with Python (generic non-templated code goes here) + void initialize_generic(detail::function_record *rec, const char *text, + const std::type_info *const *types, size_t args) { + + /* Create copies of all referenced C-style strings */ + rec->name = strdup(rec->name ? rec->name : ""); + if (rec->doc) rec->doc = strdup(rec->doc); + for (auto &a: rec->args) { + if (a.name) + a.name = strdup(a.name); + if (a.descr) + a.descr = strdup(a.descr); + else if (a.value) + a.descr = strdup(a.value.attr("__repr__")().cast().c_str()); + } + + rec->is_constructor = !strcmp(rec->name, "__init__") || !strcmp(rec->name, "__setstate__"); + +#if !defined(NDEBUG) && !defined(PYBIND11_DISABLE_NEW_STYLE_INIT_WARNING) + if (rec->is_constructor && !rec->is_new_style_constructor) { + const auto class_name = std::string(((PyTypeObject *) rec->scope.ptr())->tp_name); + const auto func_name = std::string(rec->name); + PyErr_WarnEx( + PyExc_FutureWarning, + ("pybind11-bound class '" + class_name + "' is using an old-style " + "placement-new '" + func_name + "' which has been deprecated. See " + "the upgrade guide in pybind11's docs. This message is only visible " + "when compiled in debug mode.").c_str(), 0 + ); + } +#endif + + /* Generate a proper function signature */ + std::string signature; + size_t type_index = 0, arg_index = 0; + for (auto *pc = text; *pc != '\0'; ++pc) { + const auto c = *pc; + + if (c == '{') { + // Write arg name for everything except *args and **kwargs. + if (*(pc + 1) == '*') + continue; + + if (arg_index < rec->args.size() && rec->args[arg_index].name) { + signature += rec->args[arg_index].name; + } else if (arg_index == 0 && rec->is_method) { + signature += "self"; + } else { + signature += "arg" + std::to_string(arg_index - (rec->is_method ? 1 : 0)); + } + signature += ": "; + } else if (c == '}') { + // Write default value if available. 
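                // e.g. a binding declared as m.def("f", &f, py::arg("x") = 1), where f is
                // int f(int), would render here (illustratively) as "f(x: int = 1) -> int".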
+ if (arg_index < rec->args.size() && rec->args[arg_index].descr) { + signature += " = "; + signature += rec->args[arg_index].descr; + } + arg_index++; + } else if (c == '%') { + const std::type_info *t = types[type_index++]; + if (!t) + pybind11_fail("Internal error while parsing type signature (1)"); + if (auto tinfo = detail::get_type_info(*t)) { + handle th((PyObject *) tinfo->type); + signature += + th.attr("__module__").cast() + "." + + th.attr("__qualname__").cast(); // Python 3.3+, but we backport it to earlier versions + } else if (rec->is_new_style_constructor && arg_index == 0) { + // A new-style `__init__` takes `self` as `value_and_holder`. + // Rewrite it to the proper class type. + signature += + rec->scope.attr("__module__").cast() + "." + + rec->scope.attr("__qualname__").cast(); + } else { + std::string tname(t->name()); + detail::clean_type_id(tname); + signature += tname; + } + } else { + signature += c; + } + } + if (arg_index != args || types[type_index] != nullptr) + pybind11_fail("Internal error while parsing type signature (2)"); + +#if PY_MAJOR_VERSION < 3 + if (strcmp(rec->name, "__next__") == 0) { + std::free(rec->name); + rec->name = strdup("next"); + } else if (strcmp(rec->name, "__bool__") == 0) { + std::free(rec->name); + rec->name = strdup("__nonzero__"); + } +#endif + rec->signature = strdup(signature.c_str()); + rec->args.shrink_to_fit(); + rec->nargs = (std::uint16_t) args; + + if (rec->sibling && PYBIND11_INSTANCE_METHOD_CHECK(rec->sibling.ptr())) + rec->sibling = PYBIND11_INSTANCE_METHOD_GET_FUNCTION(rec->sibling.ptr()); + + detail::function_record *chain = nullptr, *chain_start = rec; + if (rec->sibling) { + if (PyCFunction_Check(rec->sibling.ptr())) { + auto rec_capsule = reinterpret_borrow(PyCFunction_GET_SELF(rec->sibling.ptr())); + chain = (detail::function_record *) rec_capsule; + /* Never append a method to an overload chain of a parent class; + instead, hide the parent's overloads in this case */ + if (!chain->scope.is(rec->scope)) + chain = nullptr; + } + // Don't trigger for things like the default __init__, which are wrapper_descriptors that we are intentionally replacing + else if (!rec->sibling.is_none() && rec->name[0] != '_') + pybind11_fail("Cannot overload existing non-function object \"" + std::string(rec->name) + + "\" with a function of the same name"); + } + + if (!chain) { + /* No existing overload was found, create a new function object */ + rec->def = new PyMethodDef(); + std::memset(rec->def, 0, sizeof(PyMethodDef)); + rec->def->ml_name = rec->name; + rec->def->ml_meth = reinterpret_cast(reinterpret_cast(*dispatcher)); + rec->def->ml_flags = METH_VARARGS | METH_KEYWORDS; + + capsule rec_capsule(rec, [](void *ptr) { + destruct((detail::function_record *) ptr); + }); + + object scope_module; + if (rec->scope) { + if (hasattr(rec->scope, "__module__")) { + scope_module = rec->scope.attr("__module__"); + } else if (hasattr(rec->scope, "__name__")) { + scope_module = rec->scope.attr("__name__"); + } + } + + m_ptr = PyCFunction_NewEx(rec->def, rec_capsule.ptr(), scope_module.ptr()); + if (!m_ptr) + pybind11_fail("cpp_function::cpp_function(): Could not allocate function object"); + } else { + /* Append at the end of the overload chain */ + m_ptr = rec->sibling.ptr(); + inc_ref(); + chain_start = chain; + if (chain->is_method != rec->is_method) + pybind11_fail("overloading a method with both static and instance methods is not supported; " + #if defined(NDEBUG) + "compile in debug mode for more details" + #else + "error while 
attempting to bind " + std::string(rec->is_method ? "instance" : "static") + " method " + + std::string(pybind11::str(rec->scope.attr("__name__"))) + "." + std::string(rec->name) + signature + #endif + ); + while (chain->next) + chain = chain->next; + chain->next = rec; + } + + std::string signatures; + int index = 0; + /* Create a nice pydoc rec including all signatures and + docstrings of the functions in the overload chain */ + if (chain && options::show_function_signatures()) { + // First a generic signature + signatures += rec->name; + signatures += "(*args, **kwargs)\n"; + signatures += "Overloaded function.\n\n"; + } + // Then specific overload signatures + bool first_user_def = true; + for (auto it = chain_start; it != nullptr; it = it->next) { + if (options::show_function_signatures()) { + if (index > 0) signatures += "\n"; + if (chain) + signatures += std::to_string(++index) + ". "; + signatures += rec->name; + signatures += it->signature; + signatures += "\n"; + } + if (it->doc && strlen(it->doc) > 0 && options::show_user_defined_docstrings()) { + // If we're appending another docstring, and aren't printing function signatures, we + // need to append a newline first: + if (!options::show_function_signatures()) { + if (first_user_def) first_user_def = false; + else signatures += "\n"; + } + if (options::show_function_signatures()) signatures += "\n"; + signatures += it->doc; + if (options::show_function_signatures()) signatures += "\n"; + } + } + + /* Install docstring */ + PyCFunctionObject *func = (PyCFunctionObject *) m_ptr; + if (func->m_ml->ml_doc) + std::free(const_cast(func->m_ml->ml_doc)); + func->m_ml->ml_doc = strdup(signatures.c_str()); + + if (rec->is_method) { + m_ptr = PYBIND11_INSTANCE_METHOD_NEW(m_ptr, rec->scope.ptr()); + if (!m_ptr) + pybind11_fail("cpp_function::cpp_function(): Could not allocate instance method object"); + Py_DECREF(func); + } + } + + /// When a cpp_function is GCed, release any memory allocated by pybind11 + static void destruct(detail::function_record *rec) { + while (rec) { + detail::function_record *next = rec->next; + if (rec->free_data) + rec->free_data(rec); + std::free((char *) rec->name); + std::free((char *) rec->doc); + std::free((char *) rec->signature); + for (auto &arg: rec->args) { + std::free(const_cast(arg.name)); + std::free(const_cast(arg.descr)); + arg.value.dec_ref(); + } + if (rec->def) { + std::free(const_cast(rec->def->ml_doc)); + delete rec->def; + } + delete rec; + rec = next; + } + } + + /// Main dispatch logic for calls to functions bound using pybind11 + static PyObject *dispatcher(PyObject *self, PyObject *args_in, PyObject *kwargs_in) { + using namespace detail; + + /* Iterator over the list of potentially admissible overloads */ + const function_record *overloads = (function_record *) PyCapsule_GetPointer(self, nullptr), + *it = overloads; + + /* Need to know how many arguments + keyword arguments there are to pick the right overload */ + const size_t n_args_in = (size_t) PyTuple_GET_SIZE(args_in); + + handle parent = n_args_in > 0 ? PyTuple_GET_ITEM(args_in, 0) : nullptr, + result = PYBIND11_TRY_NEXT_OVERLOAD; + + auto self_value_and_holder = value_and_holder(); + if (overloads->is_constructor) { + const auto tinfo = get_type_info((PyTypeObject *) overloads->scope.ptr()); + const auto pi = reinterpret_cast(parent.ptr()); + self_value_and_holder = pi->get_value_and_holder(tinfo, false); + + if (!self_value_and_holder.type || !self_value_and_holder.inst) { + PyErr_SetString(PyExc_TypeError, "__init__(self, ...) 
called with invalid `self` argument"); + return nullptr; + } + + // If this value is already registered it must mean __init__ is invoked multiple times; + // we really can't support that in C++, so just ignore the second __init__. + if (self_value_and_holder.instance_registered()) + return none().release().ptr(); + } + + try { + // We do this in two passes: in the first pass, we load arguments with `convert=false`; + // in the second, we allow conversion (except for arguments with an explicit + // py::arg().noconvert()). This lets us prefer calls without conversion, with + // conversion as a fallback. + std::vector second_pass; + + // However, if there are no overloads, we can just skip the no-convert pass entirely + const bool overloaded = it != nullptr && it->next != nullptr; + + for (; it != nullptr; it = it->next) { + + /* For each overload: + 1. Copy all positional arguments we were given, also checking to make sure that + named positional arguments weren't *also* specified via kwarg. + 2. If we weren't given enough, try to make up the omitted ones by checking + whether they were provided by a kwarg matching the `py::arg("name")` name. If + so, use it (and remove it from kwargs; if not, see if the function binding + provided a default that we can use. + 3. Ensure that either all keyword arguments were "consumed", or that the function + takes a kwargs argument to accept unconsumed kwargs. + 4. Any positional arguments still left get put into a tuple (for args), and any + leftover kwargs get put into a dict. + 5. Pack everything into a vector; if we have py::args or py::kwargs, they are an + extra tuple or dict at the end of the positional arguments. + 6. Call the function call dispatcher (function_record::impl) + + If one of these fail, move on to the next overload and keep trying until we get a + result other than PYBIND11_TRY_NEXT_OVERLOAD. + */ + + const function_record &func = *it; + size_t pos_args = func.nargs; // Number of positional arguments that we need + if (func.has_args) --pos_args; // (but don't count py::args + if (func.has_kwargs) --pos_args; // or py::kwargs) + + if (!func.has_args && n_args_in > pos_args) + continue; // Too many arguments for this overload + + if (n_args_in < pos_args && func.args.size() < pos_args) + continue; // Not enough arguments given, and not enough defaults to fill in the blanks + + function_call call(func, parent); + + size_t args_to_copy = std::min(pos_args, n_args_in); + size_t args_copied = 0; + + // 0. Inject new-style `self` argument + if (func.is_new_style_constructor) { + // The `value` may have been preallocated by an old-style `__init__` + // if it was a preceding candidate for overload resolution. + if (self_value_and_holder) + self_value_and_holder.type->dealloc(self_value_and_holder); + + call.init_self = PyTuple_GET_ITEM(args_in, 0); + call.args.push_back(reinterpret_cast(&self_value_and_holder)); + call.args_convert.push_back(false); + ++args_copied; + } + + // 1. Copy any position arguments given. + bool bad_arg = false; + for (; args_copied < args_to_copy; ++args_copied) { + const argument_record *arg_rec = args_copied < func.args.size() ? &func.args[args_copied] : nullptr; + if (kwargs_in && arg_rec && arg_rec->name && PyDict_GetItemString(kwargs_in, arg_rec->name)) { + bad_arg = true; + break; + } + + handle arg(PyTuple_GET_ITEM(args_in, args_copied)); + if (arg_rec && !arg_rec->none && arg.is_none()) { + bad_arg = true; + break; + } + call.args.push_back(arg); + call.args_convert.push_back(arg_rec ? 
arg_rec->convert : true); + } + if (bad_arg) + continue; // Maybe it was meant for another overload (issue #688) + + // We'll need to copy this if we steal some kwargs for defaults + dict kwargs = reinterpret_borrow(kwargs_in); + + // 2. Check kwargs and, failing that, defaults that may help complete the list + if (args_copied < pos_args) { + bool copied_kwargs = false; + + for (; args_copied < pos_args; ++args_copied) { + const auto &arg = func.args[args_copied]; + + handle value; + if (kwargs_in && arg.name) + value = PyDict_GetItemString(kwargs.ptr(), arg.name); + + if (value) { + // Consume a kwargs value + if (!copied_kwargs) { + kwargs = reinterpret_steal(PyDict_Copy(kwargs.ptr())); + copied_kwargs = true; + } + PyDict_DelItemString(kwargs.ptr(), arg.name); + } else if (arg.value) { + value = arg.value; + } + + if (value) { + call.args.push_back(value); + call.args_convert.push_back(arg.convert); + } + else + break; + } + + if (args_copied < pos_args) + continue; // Not enough arguments, defaults, or kwargs to fill the positional arguments + } + + // 3. Check everything was consumed (unless we have a kwargs arg) + if (kwargs && kwargs.size() > 0 && !func.has_kwargs) + continue; // Unconsumed kwargs, but no py::kwargs argument to accept them + + // 4a. If we have a py::args argument, create a new tuple with leftovers + if (func.has_args) { + tuple extra_args; + if (args_to_copy == 0) { + // We didn't copy out any position arguments from the args_in tuple, so we + // can reuse it directly without copying: + extra_args = reinterpret_borrow(args_in); + } else if (args_copied >= n_args_in) { + extra_args = tuple(0); + } else { + size_t args_size = n_args_in - args_copied; + extra_args = tuple(args_size); + for (size_t i = 0; i < args_size; ++i) { + extra_args[i] = PyTuple_GET_ITEM(args_in, args_copied + i); + } + } + call.args.push_back(extra_args); + call.args_convert.push_back(false); + call.args_ref = std::move(extra_args); + } + + // 4b. If we have a py::kwargs, pass on any remaining kwargs + if (func.has_kwargs) { + if (!kwargs.ptr()) + kwargs = dict(); // If we didn't get one, send an empty one + call.args.push_back(kwargs); + call.args_convert.push_back(false); + call.kwargs_ref = std::move(kwargs); + } + + // 5. Put everything in a vector. Not technically step 5, we've been building it + // in `call.args` all along. + #if !defined(NDEBUG) + if (call.args.size() != func.nargs || call.args_convert.size() != func.nargs) + pybind11_fail("Internal error: function call dispatcher inserted wrong number of arguments!"); + #endif + + std::vector second_pass_convert; + if (overloaded) { + // We're in the first no-convert pass, so swap out the conversion flags for a + // set of all-false flags. If the call fails, we'll swap the flags back in for + // the conversion-allowed call below. + second_pass_convert.resize(func.nargs, false); + call.args_convert.swap(second_pass_convert); + } + + // 6. Call the function. + try { + loader_life_support guard{}; + result = func.impl(call); + } catch (reference_cast_error &) { + result = PYBIND11_TRY_NEXT_OVERLOAD; + } + + if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD) + break; + + if (overloaded) { + // The (overloaded) call failed; if the call has at least one argument that + // permits conversion (i.e. it hasn't been explicitly specified `.noconvert()`) + // then add this call to the list of second pass overloads to try. + for (size_t i = func.is_method ? 
1 : 0; i < pos_args; i++) {
                        if (second_pass_convert[i]) {
                            // Found one: swap the converting flags back in and store the call for
                            // the second pass.
                            call.args_convert.swap(second_pass_convert);
                            second_pass.push_back(std::move(call));
                            break;
                        }
                    }
                }
            }

            if (overloaded && !second_pass.empty() && result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) {
                // The no-conversion pass finished without success, try again with conversion allowed
                for (auto &call : second_pass) {
                    try {
                        loader_life_support guard{};
                        result = call.func.impl(call);
                    } catch (reference_cast_error &) {
                        result = PYBIND11_TRY_NEXT_OVERLOAD;
                    }

                    if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD) {
                        // The error reporting logic below expects 'it' to be valid, as it would be
                        // if we'd encountered this failure in the first-pass loop.
                        if (!result)
                            it = &call.func;
                        break;
                    }
                }
            }
        } catch (error_already_set &e) {
            e.restore();
            return nullptr;
        } catch (...) {
            /* When an exception is caught, give each registered exception
               translator a chance to translate it to a Python exception
               in reverse order of registration.

               A translator may choose to do one of the following:

                - catch the exception and call PyErr_SetString or PyErr_SetObject
                  to set a standard (or custom) Python exception, or
                - do nothing and let the exception fall through to the next translator, or
                - delegate translation to the next translator by throwing a new type of exception. */

            auto last_exception = std::current_exception();
            auto &registered_exception_translators = get_internals().registered_exception_translators;
            for (auto& translator : registered_exception_translators) {
                try {
                    translator(last_exception);
                } catch (...) {
                    last_exception = std::current_exception();
                    continue;
                }
                return nullptr;
            }
            PyErr_SetString(PyExc_SystemError, "Exception escaped from default exception translator!");
            return nullptr;
        }

        auto append_note_if_missing_header_is_suspected = [](std::string &msg) {
            if (msg.find("std::") != std::string::npos) {
                msg += "\n\n"
                       "Did you forget to `#include <pybind11/stl.h>`? Or <pybind11/complex.h>,\n"
                       "<pybind11/functional.h>, <pybind11/chrono.h>, etc. Some automatic\n"
                       "conversions are optional and require extra headers to be included\n"
                       "when compiling your pybind11 module.";
            }
        };

        if (result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) {
            if (overloads->is_operator)
                return handle(Py_NotImplemented).inc_ref().ptr();

            std::string msg = std::string(overloads->name) + "(): incompatible " +
                std::string(overloads->is_constructor ? "constructor" : "function") +
                " arguments. The following argument types are supported:\n";

            int ctr = 0;
            for (const function_record *it2 = overloads; it2 != nullptr; it2 = it2->next) {
                msg += "    " + std::to_string(++ctr) + ". ";

                bool wrote_sig = false;
                if (overloads->is_constructor) {
                    // For a constructor, rewrite `(self: Object, arg0, ...)
-> NoneType` as `Object(arg0, ...)` + std::string sig = it2->signature; + size_t start = sig.find('(') + 7; // skip "(self: " + if (start < sig.size()) { + // End at the , for the next argument + size_t end = sig.find(", "), next = end + 2; + size_t ret = sig.rfind(" -> "); + // Or the ), if there is no comma: + if (end >= sig.size()) next = end = sig.find(')'); + if (start < end && next < sig.size()) { + msg.append(sig, start, end - start); + msg += '('; + msg.append(sig, next, ret - next); + wrote_sig = true; + } + } + } + if (!wrote_sig) msg += it2->signature; + + msg += "\n"; + } + msg += "\nInvoked with: "; + auto args_ = reinterpret_borrow(args_in); + bool some_args = false; + for (size_t ti = overloads->is_constructor ? 1 : 0; ti < args_.size(); ++ti) { + if (!some_args) some_args = true; + else msg += ", "; + msg += pybind11::repr(args_[ti]); + } + if (kwargs_in) { + auto kwargs = reinterpret_borrow(kwargs_in); + if (kwargs.size() > 0) { + if (some_args) msg += "; "; + msg += "kwargs: "; + bool first = true; + for (auto kwarg : kwargs) { + if (first) first = false; + else msg += ", "; + msg += pybind11::str("{}={!r}").format(kwarg.first, kwarg.second); + } + } + } + + append_note_if_missing_header_is_suspected(msg); + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return nullptr; + } else if (!result) { + std::string msg = "Unable to convert function return value to a " + "Python type! The signature was\n\t"; + msg += it->signature; + append_note_if_missing_header_is_suspected(msg); + PyErr_SetString(PyExc_TypeError, msg.c_str()); + return nullptr; + } else { + if (overloads->is_constructor && !self_value_and_holder.holder_constructed()) { + auto *pi = reinterpret_cast(parent.ptr()); + self_value_and_holder.type->init_instance(pi, nullptr); + } + return result.ptr(); + } + } +}; + +/// Wrapper for Python extension modules +class module : public object { +public: + PYBIND11_OBJECT_DEFAULT(module, object, PyModule_Check) + + /// Create a new top-level Python module with the given name and docstring + explicit module(const char *name, const char *doc = nullptr) { + if (!options::show_user_defined_docstrings()) doc = nullptr; +#if PY_MAJOR_VERSION >= 3 + PyModuleDef *def = new PyModuleDef(); + std::memset(def, 0, sizeof(PyModuleDef)); + def->m_name = name; + def->m_doc = doc; + def->m_size = -1; + Py_INCREF(def); + m_ptr = PyModule_Create(def); +#else + m_ptr = Py_InitModule3(name, nullptr, doc); +#endif + if (m_ptr == nullptr) + pybind11_fail("Internal error in module::module()"); + inc_ref(); + } + + /** \rst + Create Python binding for a new function within the module scope. ``Func`` + can be a plain C++ function, a function pointer, or a lambda function. For + details on the ``Extra&& ... extra`` argument, see section :ref:`extras`. + \endrst */ + template + module &def(const char *name_, Func &&f, const Extra& ... extra) { + cpp_function func(std::forward(f), name(name_), scope(*this), + sibling(getattr(*this, name_, none())), extra...); + // NB: allow overwriting here because cpp_function sets up a chain with the intention of + // overwriting (and has already checked internally that it isn't overwriting non-functions). + add_object(name_, func, true /* overwrite */); + return *this; + } + + /** \rst + Create and return a new Python submodule with the given name and docstring. + This also works recursively, i.e. + + .. 
code-block:: cpp + + py::module m("example", "pybind11 example plugin"); + py::module m2 = m.def_submodule("sub", "A submodule of 'example'"); + py::module m3 = m2.def_submodule("subsub", "A submodule of 'example.sub'"); + \endrst */ + module def_submodule(const char *name, const char *doc = nullptr) { + std::string full_name = std::string(PyModule_GetName(m_ptr)) + + std::string(".") + std::string(name); + auto result = reinterpret_borrow(PyImport_AddModule(full_name.c_str())); + if (doc && options::show_user_defined_docstrings()) + result.attr("__doc__") = pybind11::str(doc); + attr(name) = result; + return result; + } + + /// Import and return a module or throws `error_already_set`. + static module import(const char *name) { + PyObject *obj = PyImport_ImportModule(name); + if (!obj) + throw error_already_set(); + return reinterpret_steal(obj); + } + + /// Reload the module or throws `error_already_set`. + void reload() { + PyObject *obj = PyImport_ReloadModule(ptr()); + if (!obj) + throw error_already_set(); + *this = reinterpret_steal(obj); + } + + // Adds an object to the module using the given name. Throws if an object with the given name + // already exists. + // + // overwrite should almost always be false: attempting to overwrite objects that pybind11 has + // established will, in most cases, break things. + PYBIND11_NOINLINE void add_object(const char *name, handle obj, bool overwrite = false) { + if (!overwrite && hasattr(*this, name)) + pybind11_fail("Error during initialization: multiple incompatible definitions with name \"" + + std::string(name) + "\""); + + PyModule_AddObject(ptr(), name, obj.inc_ref().ptr() /* steals a reference */); + } +}; + +/// \ingroup python_builtins +/// Return a dictionary representing the global variables in the current execution frame, +/// or ``__main__.__dict__`` if there is no frame (usually when the interpreter is embedded). +inline dict globals() { + PyObject *p = PyEval_GetGlobals(); + return reinterpret_borrow(p ? p : module::import("__main__").attr("__dict__").ptr()); +} + +NAMESPACE_BEGIN(detail) +/// Generic support for creating new Python heap types +class generic_type : public object { + template friend class class_; +public: + PYBIND11_OBJECT_DEFAULT(generic_type, object, PyType_Check) +protected: + void initialize(const type_record &rec) { + if (rec.scope && hasattr(rec.scope, rec.name)) + pybind11_fail("generic_type: cannot initialize type \"" + std::string(rec.name) + + "\": an object with that name is already defined"); + + if (rec.module_local ? 
get_local_type_info(*rec.type) : get_global_type_info(*rec.type)) + pybind11_fail("generic_type: type \"" + std::string(rec.name) + + "\" is already registered!"); + + m_ptr = make_new_python_type(rec); + + /* Register supplemental type information in C++ dict */ + auto *tinfo = new detail::type_info(); + tinfo->type = (PyTypeObject *) m_ptr; + tinfo->cpptype = rec.type; + tinfo->type_size = rec.type_size; + tinfo->type_align = rec.type_align; + tinfo->operator_new = rec.operator_new; + tinfo->holder_size_in_ptrs = size_in_ptrs(rec.holder_size); + tinfo->init_instance = rec.init_instance; + tinfo->dealloc = rec.dealloc; + tinfo->simple_type = true; + tinfo->simple_ancestors = true; + tinfo->default_holder = rec.default_holder; + tinfo->module_local = rec.module_local; + + auto &internals = get_internals(); + auto tindex = std::type_index(*rec.type); + tinfo->direct_conversions = &internals.direct_conversions[tindex]; + if (rec.module_local) + registered_local_types_cpp()[tindex] = tinfo; + else + internals.registered_types_cpp[tindex] = tinfo; + internals.registered_types_py[(PyTypeObject *) m_ptr] = { tinfo }; + + if (rec.bases.size() > 1 || rec.multiple_inheritance) { + mark_parents_nonsimple(tinfo->type); + tinfo->simple_ancestors = false; + } + else if (rec.bases.size() == 1) { + auto parent_tinfo = get_type_info((PyTypeObject *) rec.bases[0].ptr()); + tinfo->simple_ancestors = parent_tinfo->simple_ancestors; + } + + if (rec.module_local) { + // Stash the local typeinfo and loader so that external modules can access it. + tinfo->module_local_load = &type_caster_generic::local_load; + setattr(m_ptr, PYBIND11_MODULE_LOCAL_ID, capsule(tinfo)); + } + } + + /// Helper function which tags all parents of a type using mult. inheritance + void mark_parents_nonsimple(PyTypeObject *value) { + auto t = reinterpret_borrow(value->tp_bases); + for (handle h : t) { + auto tinfo2 = get_type_info((PyTypeObject *) h.ptr()); + if (tinfo2) + tinfo2->simple_type = false; + mark_parents_nonsimple((PyTypeObject *) h.ptr()); + } + } + + void install_buffer_funcs( + buffer_info *(*get_buffer)(PyObject *, void *), + void *get_buffer_data) { + PyHeapTypeObject *type = (PyHeapTypeObject*) m_ptr; + auto tinfo = detail::get_type_info(&type->ht_type); + + if (!type->ht_type.tp_as_buffer) + pybind11_fail( + "To be able to register buffer protocol support for the type '" + + std::string(tinfo->type->tp_name) + + "' the associated class<>(..) invocation must " + "include the pybind11::buffer_protocol() annotation!"); + + tinfo->get_buffer = get_buffer; + tinfo->get_buffer_data = get_buffer_data; + } + + // rec_func must be set for either fget or fset. + void def_property_static_impl(const char *name, + handle fget, handle fset, + detail::function_record *rec_func) { + const auto is_static = rec_func && !(rec_func->is_method && rec_func->scope); + const auto has_doc = rec_func && rec_func->doc && pybind11::options::show_user_defined_docstrings(); + auto property = handle((PyObject *) (is_static ? get_internals().static_property_type + : &PyProperty_Type)); + attr(name) = property(fget.ptr() ? fget : none(), + fset.ptr() ? fset : none(), + /*deleter*/none(), + pybind11::str(has_doc ? rec_func->doc : "")); + } +}; + +/// Set the pointer to operator new if it exists. The cast is needed because it can be overloaded. +template (T::operator new))>> +void set_operator_new(type_record *r) { r->operator_new = &T::operator new; } + +template void set_operator_new(...) 
{ } + +template struct has_operator_delete : std::false_type { }; +template struct has_operator_delete(T::operator delete))>> + : std::true_type { }; +template struct has_operator_delete_size : std::false_type { }; +template struct has_operator_delete_size(T::operator delete))>> + : std::true_type { }; +/// Call class-specific delete if it exists or global otherwise. Can also be an overload set. +template ::value, int> = 0> +void call_operator_delete(T *p, size_t, size_t) { T::operator delete(p); } +template ::value && has_operator_delete_size::value, int> = 0> +void call_operator_delete(T *p, size_t s, size_t) { T::operator delete(p, s); } + +inline void call_operator_delete(void *p, size_t s, size_t a) { + (void)s; (void)a; +#if defined(PYBIND11_CPP17) + if (a > __STDCPP_DEFAULT_NEW_ALIGNMENT__) + ::operator delete(p, s, std::align_val_t(a)); + else + ::operator delete(p, s); +#else + ::operator delete(p); +#endif +} + +NAMESPACE_END(detail) + +/// Given a pointer to a member function, cast it to its `Derived` version. +/// Forward everything else unchanged. +template +auto method_adaptor(F &&f) -> decltype(std::forward(f)) { return std::forward(f); } + +template +auto method_adaptor(Return (Class::*pmf)(Args...)) -> Return (Derived::*)(Args...) { + static_assert(detail::is_accessible_base_of::value, + "Cannot bind an inaccessible base class method; use a lambda definition instead"); + return pmf; +} + +template +auto method_adaptor(Return (Class::*pmf)(Args...) const) -> Return (Derived::*)(Args...) const { + static_assert(detail::is_accessible_base_of::value, + "Cannot bind an inaccessible base class method; use a lambda definition instead"); + return pmf; +} + +template +class class_ : public detail::generic_type { + template using is_holder = detail::is_holder_type; + template using is_subtype = detail::is_strict_base_of; + template using is_base = detail::is_strict_base_of; + // struct instead of using here to help MSVC: + template struct is_valid_class_option : + detail::any_of, is_subtype, is_base> {}; + +public: + using type = type_; + using type_alias = detail::exactly_one_t; + constexpr static bool has_alias = !std::is_void::value; + using holder_type = detail::exactly_one_t, options...>; + + static_assert(detail::all_of...>::value, + "Unknown/invalid class_ template parameters provided"); + + static_assert(!has_alias || std::is_polymorphic::value, + "Cannot use an alias class with a non-polymorphic type"); + + PYBIND11_OBJECT(class_, generic_type, PyType_Check) + + template + class_(handle scope, const char *name, const Extra &... extra) { + using namespace detail; + + // MI can only be specified via class_ template options, not constructor parameters + static_assert( + none_of...>::value || // no base class arguments, or: + ( constexpr_sum(is_pyobject::value...) == 1 && // Exactly one base + constexpr_sum(is_base::value...) 
== 0 && // no template option bases + none_of...>::value), // no multiple_inheritance attr + "Error: multiple inheritance bases must be specified via class_ template options"); + + type_record record; + record.scope = scope; + record.name = name; + record.type = &typeid(type); + record.type_size = sizeof(conditional_t); + record.type_align = alignof(conditional_t&); + record.holder_size = sizeof(holder_type); + record.init_instance = init_instance; + record.dealloc = dealloc; + record.default_holder = detail::is_instantiation::value; + + set_operator_new(&record); + + /* Register base classes specified via template arguments to class_, if any */ + PYBIND11_EXPAND_SIDE_EFFECTS(add_base(record)); + + /* Process optional arguments, if any */ + process_attributes::init(extra..., &record); + + generic_type::initialize(record); + + if (has_alias) { + auto &instances = record.module_local ? registered_local_types_cpp() : get_internals().registered_types_cpp; + instances[std::type_index(typeid(type_alias))] = instances[std::type_index(typeid(type))]; + } + } + + template ::value, int> = 0> + static void add_base(detail::type_record &rec) { + rec.add_base(typeid(Base), [](void *src) -> void * { + return static_cast(reinterpret_cast(src)); + }); + } + + template ::value, int> = 0> + static void add_base(detail::type_record &) { } + + template + class_ &def(const char *name_, Func&& f, const Extra&... extra) { + cpp_function cf(method_adaptor(std::forward(f)), name(name_), is_method(*this), + sibling(getattr(*this, name_, none())), extra...); + attr(cf.name()) = cf; + return *this; + } + + template class_ & + def_static(const char *name_, Func &&f, const Extra&... extra) { + static_assert(!std::is_member_function_pointer::value, + "def_static(...) called with a non-static member function pointer"); + cpp_function cf(std::forward(f), name(name_), scope(*this), + sibling(getattr(*this, name_, none())), extra...); + attr(cf.name()) = cf; + return *this; + } + + template + class_ &def(const detail::op_ &op, const Extra&... extra) { + op.execute(*this, extra...); + return *this; + } + + template + class_ & def_cast(const detail::op_ &op, const Extra&... extra) { + op.execute_cast(*this, extra...); + return *this; + } + + template + class_ &def(const detail::initimpl::constructor &init, const Extra&... extra) { + init.execute(*this, extra...); + return *this; + } + + template + class_ &def(const detail::initimpl::alias_constructor &init, const Extra&... extra) { + init.execute(*this, extra...); + return *this; + } + + template + class_ &def(detail::initimpl::factory &&init, const Extra&... extra) { + std::move(init).execute(*this, extra...); + return *this; + } + + template + class_ &def(detail::initimpl::pickle_factory &&pf, const Extra &...extra) { + std::move(pf).execute(*this, extra...); + return *this; + } + + template class_& def_buffer(Func &&func) { + struct capture { Func func; }; + capture *ptr = new capture { std::forward(func) }; + install_buffer_funcs([](PyObject *obj, void *ptr) -> buffer_info* { + detail::make_caster caster; + if (!caster.load(obj, false)) + return nullptr; + return new buffer_info(((capture *) ptr)->func(caster)); + }, ptr); + return *this; + } + + template + class_ &def_buffer(Return (Class::*func)(Args...)) { + return def_buffer([func] (type &obj) { return (obj.*func)(); }); + } + + template + class_ &def_buffer(Return (Class::*func)(Args...) 
const) { + return def_buffer([func] (const type &obj) { return (obj.*func)(); }); + } + + template + class_ &def_readwrite(const char *name, D C::*pm, const Extra&... extra) { + static_assert(std::is_base_of::value, "def_readwrite() requires a class member (or base class member)"); + cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this)), + fset([pm](type &c, const D &value) { c.*pm = value; }, is_method(*this)); + def_property(name, fget, fset, return_value_policy::reference_internal, extra...); + return *this; + } + + template + class_ &def_readonly(const char *name, const D C::*pm, const Extra& ...extra) { + static_assert(std::is_base_of::value, "def_readonly() requires a class member (or base class member)"); + cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this)); + def_property_readonly(name, fget, return_value_policy::reference_internal, extra...); + return *this; + } + + template + class_ &def_readwrite_static(const char *name, D *pm, const Extra& ...extra) { + cpp_function fget([pm](object) -> const D &{ return *pm; }, scope(*this)), + fset([pm](object, const D &value) { *pm = value; }, scope(*this)); + def_property_static(name, fget, fset, return_value_policy::reference, extra...); + return *this; + } + + template + class_ &def_readonly_static(const char *name, const D *pm, const Extra& ...extra) { + cpp_function fget([pm](object) -> const D &{ return *pm; }, scope(*this)); + def_property_readonly_static(name, fget, return_value_policy::reference, extra...); + return *this; + } + + /// Uses return_value_policy::reference_internal by default + template + class_ &def_property_readonly(const char *name, const Getter &fget, const Extra& ...extra) { + return def_property_readonly(name, cpp_function(method_adaptor(fget)), + return_value_policy::reference_internal, extra...); + } + + /// Uses cpp_function's return_value_policy by default + template + class_ &def_property_readonly(const char *name, const cpp_function &fget, const Extra& ...extra) { + return def_property(name, fget, nullptr, extra...); + } + + /// Uses return_value_policy::reference by default + template + class_ &def_property_readonly_static(const char *name, const Getter &fget, const Extra& ...extra) { + return def_property_readonly_static(name, cpp_function(fget), return_value_policy::reference, extra...); + } + + /// Uses cpp_function's return_value_policy by default + template + class_ &def_property_readonly_static(const char *name, const cpp_function &fget, const Extra& ...extra) { + return def_property_static(name, fget, nullptr, extra...); + } + + /// Uses return_value_policy::reference_internal by default + template + class_ &def_property(const char *name, const Getter &fget, const Setter &fset, const Extra& ...extra) { + return def_property(name, fget, cpp_function(method_adaptor(fset)), extra...); + } + template + class_ &def_property(const char *name, const Getter &fget, const cpp_function &fset, const Extra& ...extra) { + return def_property(name, cpp_function(method_adaptor(fget)), fset, + return_value_policy::reference_internal, extra...); + } + + /// Uses cpp_function's return_value_policy by default + template + class_ &def_property(const char *name, const cpp_function &fget, const cpp_function &fset, const Extra& ...extra) { + return def_property_static(name, fget, fset, is_method(*this), extra...); + } + + /// Uses return_value_policy::reference by default + template + class_ &def_property_static(const char *name, const Getter &fget, const 
cpp_function &fset, const Extra& ...extra) { + return def_property_static(name, cpp_function(fget), fset, return_value_policy::reference, extra...); + } + + /// Uses cpp_function's return_value_policy by default + template + class_ &def_property_static(const char *name, const cpp_function &fget, const cpp_function &fset, const Extra& ...extra) { + auto rec_fget = get_function_record(fget), rec_fset = get_function_record(fset); + auto *rec_active = rec_fget; + if (rec_fget) { + char *doc_prev = rec_fget->doc; /* 'extra' field may include a property-specific documentation string */ + detail::process_attributes::init(extra..., rec_fget); + if (rec_fget->doc && rec_fget->doc != doc_prev) { + free(doc_prev); + rec_fget->doc = strdup(rec_fget->doc); + } + } + if (rec_fset) { + char *doc_prev = rec_fset->doc; + detail::process_attributes::init(extra..., rec_fset); + if (rec_fset->doc && rec_fset->doc != doc_prev) { + free(doc_prev); + rec_fset->doc = strdup(rec_fset->doc); + } + if (! rec_active) rec_active = rec_fset; + } + def_property_static_impl(name, fget, fset, rec_active); + return *this; + } + +private: + /// Initialize holder object, variant 1: object derives from enable_shared_from_this + template + static void init_holder(detail::instance *inst, detail::value_and_holder &v_h, + const holder_type * /* unused */, const std::enable_shared_from_this * /* dummy */) { + try { + auto sh = std::dynamic_pointer_cast( + v_h.value_ptr()->shared_from_this()); + if (sh) { + new (std::addressof(v_h.holder())) holder_type(std::move(sh)); + v_h.set_holder_constructed(); + } + } catch (const std::bad_weak_ptr &) {} + + if (!v_h.holder_constructed() && inst->owned) { + new (std::addressof(v_h.holder())) holder_type(v_h.value_ptr()); + v_h.set_holder_constructed(); + } + } + + static void init_holder_from_existing(const detail::value_and_holder &v_h, + const holder_type *holder_ptr, std::true_type /*is_copy_constructible*/) { + new (std::addressof(v_h.holder())) holder_type(*reinterpret_cast(holder_ptr)); + } + + static void init_holder_from_existing(const detail::value_and_holder &v_h, + const holder_type *holder_ptr, std::false_type /*is_copy_constructible*/) { + new (std::addressof(v_h.holder())) holder_type(std::move(*const_cast(holder_ptr))); + } + + /// Initialize holder object, variant 2: try to construct from existing holder object, if possible + static void init_holder(detail::instance *inst, detail::value_and_holder &v_h, + const holder_type *holder_ptr, const void * /* dummy -- not enable_shared_from_this) */) { + if (holder_ptr) { + init_holder_from_existing(v_h, holder_ptr, std::is_copy_constructible()); + v_h.set_holder_constructed(); + } else if (inst->owned || detail::always_construct_holder::value) { + new (std::addressof(v_h.holder())) holder_type(v_h.value_ptr()); + v_h.set_holder_constructed(); + } + } + + /// Performs instance initialization including constructing a holder and registering the known + /// instance. Should be called as soon as the `type` value_ptr is set for an instance. Takes an + /// optional pointer to an existing holder to use; if not specified and the instance is + /// `.owned`, a new holder will be constructed to manage the value pointer. 
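Usage note: the `def_readwrite`/`def_property` family above is what most bindings are built from. A minimal sketch, assuming a hypothetical `Pet` struct and module name `example` (neither is part of this header):

```cpp
#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

struct Pet {                                    // hypothetical example type
    explicit Pet(std::string n) : name(std::move(n)) { }
    std::string name;
    int age = 0;
};

PYBIND11_MODULE(example, m) {
    py::class_<Pet>(m, "Pet")
        .def(py::init<std::string>())
        .def_readwrite("age", &Pet::age)        // member pointer -> getter/setter pair
        .def_property("name",                   // explicit getter and setter callables
                      [](const Pet &p) { return p.name; },
                      [](Pet &p, const std::string &n) { p.name = n; });
}
```

`def_readwrite` expands to exactly the `cpp_function` getter/setter pair shown in its implementation above, with `return_value_policy::reference_internal` applied to the getter.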
+ static void init_instance(detail::instance *inst, const void *holder_ptr) { + auto v_h = inst->get_value_and_holder(detail::get_type_info(typeid(type))); + if (!v_h.instance_registered()) { + register_instance(inst, v_h.value_ptr(), v_h.type); + v_h.set_instance_registered(); + } + init_holder(inst, v_h, (const holder_type *) holder_ptr, v_h.value_ptr<type>()); + } + + /// Deallocates an instance; via holder, if constructed; otherwise via operator delete. + static void dealloc(detail::value_and_holder &v_h) { + if (v_h.holder_constructed()) { + v_h.holder<holder_type>().~holder_type(); + v_h.set_holder_constructed(false); + } + else { + detail::call_operator_delete(v_h.value_ptr<type>(), + v_h.type->type_size, + v_h.type->type_align + ); + } + v_h.value_ptr() = nullptr; + } + + static detail::function_record *get_function_record(handle h) { + h = detail::get_function(h); + return h ? (detail::function_record *) reinterpret_borrow<capsule>(PyCFunction_GET_SELF(h.ptr())) + : nullptr; + } +}; + +/// Binds an existing constructor taking arguments Args... +template <typename... Args> detail::initimpl::constructor<Args...> init() { return {}; } +/// Like `init<Args...>()`, but the instance is always constructed through the alias class (even +/// when not inheriting on the Python side). +template <typename... Args> detail::initimpl::alias_constructor<Args...> init_alias() { return {}; } + +/// Binds a factory function as a constructor +template <typename Func, typename Ret = detail::initimpl::factory<Func>> +Ret init(Func &&f) { return {std::forward<Func>(f)}; } + +/// Dual-argument factory function: the first function is called when no alias is needed, the second +/// when an alias is needed (i.e. due to python-side inheritance). Arguments must be identical. +template <typename CFunc, typename AFunc, typename Ret = detail::initimpl::factory<CFunc, AFunc>> +Ret init(CFunc &&c, AFunc &&a) { + return {std::forward<CFunc>(c), std::forward<AFunc>(a)}; +} + +/// Binds pickling functions `__getstate__` and `__setstate__` and ensures that the type +/// returned by `__getstate__` is the same as the argument accepted by `__setstate__`.
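The `pickle()` factory declared just below pairs naturally with `class_::def`. A minimal sketch, reusing the hypothetical `Pet` type from the previous example:

```cpp
py::class_<Pet>(m, "Pet")
    .def(py::init<std::string>())
    .def(py::pickle(
        [](const Pet &p) {                        // __getstate__: encode as a tuple
            return py::make_tuple(p.name, p.age);
        },
        [](py::tuple t) {                         // __setstate__: rebuild the C++ object
            Pet p(t[0].cast<std::string>());
            p.age = t[1].cast<int>();
            return p;
        }));
```

As the doc comment states, the factory checks that the state type produced by the first callable matches the argument accepted by the second.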
+template +detail::initimpl::pickle_factory pickle(GetState &&g, SetState &&s) { + return {std::forward(g), std::forward(s)}; +} + +NAMESPACE_BEGIN(detail) +struct enum_base { + enum_base(handle base, handle parent) : m_base(base), m_parent(parent) { } + + PYBIND11_NOINLINE void init(bool is_arithmetic, bool is_convertible) { + m_base.attr("__entries") = dict(); + auto property = handle((PyObject *) &PyProperty_Type); + auto static_property = handle((PyObject *) get_internals().static_property_type); + + m_base.attr("__repr__") = cpp_function( + [](handle arg) -> str { + handle type = arg.get_type(); + object type_name = type.attr("__name__"); + dict entries = type.attr("__entries"); + for (const auto &kv : entries) { + object other = kv.second[int_(0)]; + if (other.equal(arg)) + return pybind11::str("{}.{}").format(type_name, kv.first); + } + return pybind11::str("{}.???").format(type_name); + }, is_method(m_base) + ); + + m_base.attr("name") = property(cpp_function( + [](handle arg) -> str { + dict entries = arg.get_type().attr("__entries"); + for (const auto &kv : entries) { + if (handle(kv.second[int_(0)]).equal(arg)) + return pybind11::str(kv.first); + } + return "???"; + }, is_method(m_base) + )); + + m_base.attr("__doc__") = static_property(cpp_function( + [](handle arg) -> std::string { + std::string docstring; + dict entries = arg.attr("__entries"); + if (((PyTypeObject *) arg.ptr())->tp_doc) + docstring += std::string(((PyTypeObject *) arg.ptr())->tp_doc) + "\n\n"; + docstring += "Members:"; + for (const auto &kv : entries) { + auto key = std::string(pybind11::str(kv.first)); + auto comment = kv.second[int_(1)]; + docstring += "\n\n " + key; + if (!comment.is_none()) + docstring += " : " + (std::string) pybind11::str(comment); + } + return docstring; + } + ), none(), none(), ""); + + m_base.attr("__members__") = static_property(cpp_function( + [](handle arg) -> dict { + dict entries = arg.attr("__entries"), m; + for (const auto &kv : entries) + m[kv.first] = kv.second[int_(0)]; + return m; + }), none(), none(), "" + ); + + #define PYBIND11_ENUM_OP_STRICT(op, expr, strict_behavior) \ + m_base.attr(op) = cpp_function( \ + [](object a, object b) { \ + if (!a.get_type().is(b.get_type())) \ + strict_behavior; \ + return expr; \ + }, \ + is_method(m_base)) + + #define PYBIND11_ENUM_OP_CONV(op, expr) \ + m_base.attr(op) = cpp_function( \ + [](object a_, object b_) { \ + int_ a(a_), b(b_); \ + return expr; \ + }, \ + is_method(m_base)) + + if (is_convertible) { + PYBIND11_ENUM_OP_CONV("__eq__", !b.is_none() && a.equal(b)); + PYBIND11_ENUM_OP_CONV("__ne__", b.is_none() || !a.equal(b)); + + if (is_arithmetic) { + PYBIND11_ENUM_OP_CONV("__lt__", a < b); + PYBIND11_ENUM_OP_CONV("__gt__", a > b); + PYBIND11_ENUM_OP_CONV("__le__", a <= b); + PYBIND11_ENUM_OP_CONV("__ge__", a >= b); + PYBIND11_ENUM_OP_CONV("__and__", a & b); + PYBIND11_ENUM_OP_CONV("__rand__", a & b); + PYBIND11_ENUM_OP_CONV("__or__", a | b); + PYBIND11_ENUM_OP_CONV("__ror__", a | b); + PYBIND11_ENUM_OP_CONV("__xor__", a ^ b); + PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b); + } + } else { + PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false); + PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)), return true); + + if (is_arithmetic) { + #define PYBIND11_THROW throw type_error("Expected an enumeration of matching type!"); + PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b), PYBIND11_THROW); + PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b), PYBIND11_THROW); + PYBIND11_ENUM_OP_STRICT("__le__", int_(a) 
<= int_(b), PYBIND11_THROW); + PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b), PYBIND11_THROW); + #undef PYBIND11_THROW + } + } + + #undef PYBIND11_ENUM_OP_CONV + #undef PYBIND11_ENUM_OP_STRICT + + object getstate = cpp_function( + [](object arg) { return int_(arg); }, is_method(m_base)); + + m_base.attr("__getstate__") = getstate; + m_base.attr("__hash__") = getstate; + } + + PYBIND11_NOINLINE void value(char const* name_, object value, const char *doc = nullptr) { + dict entries = m_base.attr("__entries"); + str name(name_); + if (entries.contains(name)) { + std::string type_name = (std::string) str(m_base.attr("__name__")); + throw value_error(type_name + ": element \"" + std::string(name_) + "\" already exists!"); + } + + entries[name] = std::make_pair(value, doc); + m_base.attr(name) = value; + } + + PYBIND11_NOINLINE void export_values() { + dict entries = m_base.attr("__entries"); + for (const auto &kv : entries) + m_parent.attr(kv.first) = kv.second[int_(0)]; + } + + handle m_base; + handle m_parent; +}; + +NAMESPACE_END(detail) + +/// Binds C++ enumerations and enumeration classes to Python +template class enum_ : public class_ { +public: + using Base = class_; + using Base::def; + using Base::attr; + using Base::def_property_readonly; + using Base::def_property_readonly_static; + using Scalar = typename std::underlying_type::type; + + template + enum_(const handle &scope, const char *name, const Extra&... extra) + : class_(scope, name, extra...), m_base(*this, scope) { + constexpr bool is_arithmetic = detail::any_of...>::value; + constexpr bool is_convertible = std::is_convertible::value; + m_base.init(is_arithmetic, is_convertible); + + def(init([](Scalar i) { return static_cast(i); })); + def("__int__", [](Type value) { return (Scalar) value; }); + #if PY_MAJOR_VERSION < 3 + def("__long__", [](Type value) { return (Scalar) value; }); + #endif + cpp_function setstate( + [](Type &value, Scalar arg) { value = static_cast(arg); }, + is_method(*this)); + attr("__setstate__") = setstate; + } + + /// Export enumeration entries into the parent scope + enum_& export_values() { + m_base.export_values(); + return *this; + } + + /// Add an enumeration entry + enum_& value(char const* name, Type value, const char *doc = nullptr) { + m_base.value(name, pybind11::cast(value, return_value_policy::copy), doc); + return *this; + } + +private: + detail::enum_base m_base; +}; + +NAMESPACE_BEGIN(detail) + + +inline void keep_alive_impl(handle nurse, handle patient) { + if (!nurse || !patient) + pybind11_fail("Could not activate keep_alive!"); + + if (patient.is_none() || nurse.is_none()) + return; /* Nothing to keep alive or nothing to be kept alive by */ + + auto tinfo = all_type_info(Py_TYPE(nurse.ptr())); + if (!tinfo.empty()) { + /* It's a pybind-registered type, so we can store the patient in the + * internal list. */ + add_patient(nurse.ptr(), patient.ptr()); + } + else { + /* Fall back to clever approach based on weak references taken from + * Boost.Python. This is not used for pybind-registered types because + * the objects can be destroyed out-of-order in a GC pass. 
*/ + cpp_function disable_lifesupport( + [patient](handle weakref) { patient.dec_ref(); weakref.dec_ref(); }); + + weakref wr(nurse, disable_lifesupport); + + patient.inc_ref(); /* reference patient and leak the weak reference */ + (void) wr.release(); + } +} + +PYBIND11_NOINLINE inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret) { + auto get_arg = [&](size_t n) { + if (n == 0) + return ret; + else if (n == 1 && call.init_self) + return call.init_self; + else if (n <= call.args.size()) + return call.args[n - 1]; + return handle(); + }; + + keep_alive_impl(get_arg(Nurse), get_arg(Patient)); +} + +inline std::pair all_type_info_get_cache(PyTypeObject *type) { + auto res = get_internals().registered_types_py +#ifdef __cpp_lib_unordered_map_try_emplace + .try_emplace(type); +#else + .emplace(type, std::vector()); +#endif + if (res.second) { + // New cache entry created; set up a weak reference to automatically remove it if the type + // gets destroyed: + weakref((PyObject *) type, cpp_function([type](handle wr) { + get_internals().registered_types_py.erase(type); + wr.dec_ref(); + })).release(); + } + + return res; +} + +template +struct iterator_state { + Iterator it; + Sentinel end; + bool first_or_done; +}; + +NAMESPACE_END(detail) + +/// Makes a python iterator from a first and past-the-end C++ InputIterator. +template ()), + typename... Extra> +iterator make_iterator(Iterator first, Sentinel last, Extra &&... extra) { + typedef detail::iterator_state state; + + if (!detail::get_type_info(typeid(state), false)) { + class_(handle(), "iterator", pybind11::module_local()) + .def("__iter__", [](state &s) -> state& { return s; }) + .def("__next__", [](state &s) -> ValueType { + if (!s.first_or_done) + ++s.it; + else + s.first_or_done = false; + if (s.it == s.end) { + s.first_or_done = true; + throw stop_iteration(); + } + return *s.it; + }, std::forward(extra)..., Policy); + } + + return cast(state{first, last, true}); +} + +/// Makes an python iterator over the keys (`.first`) of a iterator over pairs from a +/// first and past-the-end InputIterator. +template ()).first), + typename... Extra> +iterator make_key_iterator(Iterator first, Sentinel last, Extra &&... extra) { + typedef detail::iterator_state state; + + if (!detail::get_type_info(typeid(state), false)) { + class_(handle(), "iterator", pybind11::module_local()) + .def("__iter__", [](state &s) -> state& { return s; }) + .def("__next__", [](state &s) -> KeyType { + if (!s.first_or_done) + ++s.it; + else + s.first_or_done = false; + if (s.it == s.end) { + s.first_or_done = true; + throw stop_iteration(); + } + return (*s.it).first; + }, std::forward(extra)..., Policy); + } + + return cast(state{first, last, true}); +} + +/// Makes an iterator over values of an stl container or other container supporting +/// `std::begin()`/`std::end()` +template iterator make_iterator(Type &value, Extra&&... extra) { + return make_iterator(std::begin(value), std::end(value), extra...); +} + +/// Makes an iterator over the keys (`.first`) of a stl map-like container supporting +/// `std::begin()`/`std::end()` +template iterator make_key_iterator(Type &value, Extra&&... 
extra) { + return make_key_iterator(std::begin(value), std::end(value), extra...); +} + +template void implicitly_convertible() { + struct set_flag { + bool &flag; + set_flag(bool &flag) : flag(flag) { flag = true; } + ~set_flag() { flag = false; } + }; + auto implicit_caster = [](PyObject *obj, PyTypeObject *type) -> PyObject * { + static bool currently_used = false; + if (currently_used) // implicit conversions are non-reentrant + return nullptr; + set_flag flag_helper(currently_used); + if (!detail::make_caster().load(obj, false)) + return nullptr; + tuple args(1); + args[0] = obj; + PyObject *result = PyObject_Call((PyObject *) type, args.ptr(), nullptr); + if (result == nullptr) + PyErr_Clear(); + return result; + }; + + if (auto tinfo = detail::get_type_info(typeid(OutputType))) + tinfo->implicit_conversions.push_back(implicit_caster); + else + pybind11_fail("implicitly_convertible: Unable to find type " + type_id()); +} + +template +void register_exception_translator(ExceptionTranslator&& translator) { + detail::get_internals().registered_exception_translators.push_front( + std::forward(translator)); +} + +/** + * Wrapper to generate a new Python exception type. + * + * This should only be used with PyErr_SetString for now. + * It is not (yet) possible to use as a py::base. + * Template type argument is reserved for future use. + */ +template +class exception : public object { +public: + exception() = default; + exception(handle scope, const char *name, PyObject *base = PyExc_Exception) { + std::string full_name = scope.attr("__name__").cast() + + std::string(".") + name; + m_ptr = PyErr_NewException(const_cast(full_name.c_str()), base, NULL); + if (hasattr(scope, name)) + pybind11_fail("Error during initialization: multiple incompatible " + "definitions with name \"" + std::string(name) + "\""); + scope.attr(name) = *this; + } + + // Sets the current python exception to this exception object with the given message + void operator()(const char *message) { + PyErr_SetString(m_ptr, message); + } +}; + +NAMESPACE_BEGIN(detail) +// Returns a reference to a function-local static exception object used in the simple +// register_exception approach below. (It would be simpler to have the static local variable +// directly in register_exception, but that makes clang <3.5 segfault - issue #1349). +template +exception &get_exception_object() { static exception ex; return ex; } +NAMESPACE_END(detail) + +/** + * Registers a Python exception in `m` of the given `name` and installs an exception translator to + * translate the C++ exception to the created Python exception using the exceptions what() method. + * This is intended for simple exception translations; for more complex translation, register the + * exception object and translator directly. + */ +template +exception ®ister_exception(handle scope, + const char *name, + PyObject *base = PyExc_Exception) { + auto &ex = detail::get_exception_object(); + if (!ex) ex = exception(scope, name, base); + + register_exception_translator([](std::exception_ptr p) { + if (!p) return; + try { + std::rethrow_exception(p); + } catch (const CppException &e) { + detail::get_exception_object()(e.what()); + } + }); + return ex; +} + +NAMESPACE_BEGIN(detail) +PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) { + auto strings = tuple(args.size()); + for (size_t i = 0; i < args.size(); ++i) { + strings[i] = str(args[i]); + } + auto sep = kwargs.contains("sep") ? 
kwargs["sep"] : cast(" "); + auto line = sep.attr("join")(strings); + + object file; + if (kwargs.contains("file")) { + file = kwargs["file"].cast(); + } else { + try { + file = module::import("sys").attr("stdout"); + } catch (const error_already_set &) { + /* If print() is called from code that is executed as + part of garbage collection during interpreter shutdown, + importing 'sys' can fail. Give up rather than crashing the + interpreter in this case. */ + return; + } + } + + auto write = file.attr("write"); + write(line); + write(kwargs.contains("end") ? kwargs["end"] : cast("\n")); + + if (kwargs.contains("flush") && kwargs["flush"].cast()) + file.attr("flush")(); +} +NAMESPACE_END(detail) + +template +void print(Args &&...args) { + auto c = detail::collect_arguments(std::forward(args)...); + detail::print(c.args(), c.kwargs()); +} + +#if defined(WITH_THREAD) && !defined(PYPY_VERSION) + +/* The functions below essentially reproduce the PyGILState_* API using a RAII + * pattern, but there are a few important differences: + * + * 1. When acquiring the GIL from an non-main thread during the finalization + * phase, the GILState API blindly terminates the calling thread, which + * is often not what is wanted. This API does not do this. + * + * 2. The gil_scoped_release function can optionally cut the relationship + * of a PyThreadState and its associated thread, which allows moving it to + * another thread (this is a fairly rare/advanced use case). + * + * 3. The reference count of an acquired thread state can be controlled. This + * can be handy to prevent cases where callbacks issued from an external + * thread would otherwise constantly construct and destroy thread state data + * structures. + * + * See the Python bindings of NanoGUI (http://github.com/wjakob/nanogui) for an + * example which uses features 2 and 3 to migrate the Python thread of + * execution to another thread (to run the event loop on the original thread, + * in this case). + */ + +class gil_scoped_acquire { +public: + PYBIND11_NOINLINE gil_scoped_acquire() { + auto const &internals = detail::get_internals(); + tstate = (PyThreadState *) PYBIND11_TLS_GET_VALUE(internals.tstate); + + if (!tstate) { + /* Check if the GIL was acquired using the PyGILState_* API instead (e.g. if + calling from a Python thread). Since we use a different key, this ensures + we don't create a new thread state and deadlock in PyEval_AcquireThread + below. Note we don't save this state with internals.tstate, since we don't + create it we would fail to clear it (its reference count should be > 0). 
*/ + tstate = PyGILState_GetThisThreadState(); + } + + if (!tstate) { + tstate = PyThreadState_New(internals.istate); + #if !defined(NDEBUG) + if (!tstate) + pybind11_fail("scoped_acquire: could not create thread state!"); + #endif + tstate->gilstate_counter = 0; + PYBIND11_TLS_REPLACE_VALUE(internals.tstate, tstate); + } else { + release = detail::get_thread_state_unchecked() != tstate; + } + + if (release) { + /* Work around an annoying assertion in PyThreadState_Swap */ + #if defined(Py_DEBUG) + PyInterpreterState *interp = tstate->interp; + tstate->interp = nullptr; + #endif + PyEval_AcquireThread(tstate); + #if defined(Py_DEBUG) + tstate->interp = interp; + #endif + } + + inc_ref(); + } + + void inc_ref() { + ++tstate->gilstate_counter; + } + + PYBIND11_NOINLINE void dec_ref() { + --tstate->gilstate_counter; + #if !defined(NDEBUG) + if (detail::get_thread_state_unchecked() != tstate) + pybind11_fail("scoped_acquire::dec_ref(): thread state must be current!"); + if (tstate->gilstate_counter < 0) + pybind11_fail("scoped_acquire::dec_ref(): reference count underflow!"); + #endif + if (tstate->gilstate_counter == 0) { + #if !defined(NDEBUG) + if (!release) + pybind11_fail("scoped_acquire::dec_ref(): internal error!"); + #endif + PyThreadState_Clear(tstate); + PyThreadState_DeleteCurrent(); + PYBIND11_TLS_DELETE_VALUE(detail::get_internals().tstate); + release = false; + } + } + + PYBIND11_NOINLINE ~gil_scoped_acquire() { + dec_ref(); + if (release) + PyEval_SaveThread(); + } +private: + PyThreadState *tstate = nullptr; + bool release = true; +}; + +class gil_scoped_release { +public: + explicit gil_scoped_release(bool disassoc = false) : disassoc(disassoc) { + // `get_internals()` must be called here unconditionally in order to initialize + // `internals.tstate` for subsequent `gil_scoped_acquire` calls. Otherwise, an + // initialization race could occur as multiple threads try `gil_scoped_acquire`. 
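As a usage sketch (the `heavy_cpp_work` function is hypothetical and must not touch Python state), releasing the GIL around long-running C++ code is the typical reason these RAII classes exist:

```cpp
m.def("compute", [](int n) {
    long result;
    {
        py::gil_scoped_release release;   // GIL dropped; other Python threads may run
        result = heavy_cpp_work(n);       // hypothetical pure-C++ workload
    }                                     // destructor restores the thread state
    return result;                        // GIL held again for the return conversion
});
```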
+ const auto &internals = detail::get_internals(); + tstate = PyEval_SaveThread(); + if (disassoc) { + auto key = internals.tstate; + PYBIND11_TLS_DELETE_VALUE(key); + } + } + ~gil_scoped_release() { + if (!tstate) + return; + PyEval_RestoreThread(tstate); + if (disassoc) { + auto key = detail::get_internals().tstate; + PYBIND11_TLS_REPLACE_VALUE(key, tstate); + } + } +private: + PyThreadState *tstate; + bool disassoc; +}; +#elif defined(PYPY_VERSION) +class gil_scoped_acquire { + PyGILState_STATE state; +public: + gil_scoped_acquire() { state = PyGILState_Ensure(); } + ~gil_scoped_acquire() { PyGILState_Release(state); } +}; + +class gil_scoped_release { + PyThreadState *state; +public: + gil_scoped_release() { state = PyEval_SaveThread(); } + ~gil_scoped_release() { PyEval_RestoreThread(state); } +}; +#else +class gil_scoped_acquire { }; +class gil_scoped_release { }; +#endif + +error_already_set::~error_already_set() { + if (type) { + error_scope scope; + gil_scoped_acquire gil; + type.release().dec_ref(); + value.release().dec_ref(); + trace.release().dec_ref(); + } +} + +inline function get_type_overload(const void *this_ptr, const detail::type_info *this_type, const char *name) { + handle self = detail::get_object_handle(this_ptr, this_type); + if (!self) + return function(); + handle type = self.get_type(); + auto key = std::make_pair(type.ptr(), name); + + /* Cache functions that aren't overloaded in Python to avoid + many costly Python dictionary lookups below */ + auto &cache = detail::get_internals().inactive_overload_cache; + if (cache.find(key) != cache.end()) + return function(); + + function overload = getattr(self, name, function()); + if (overload.is_cpp_function()) { + cache.insert(key); + return function(); + } + + /* Don't call dispatch code if invoked from overridden function. + Unfortunately this doesn't work on PyPy. */ +#if !defined(PYPY_VERSION) + PyFrameObject *frame = PyThreadState_Get()->frame; + if (frame && (std::string) str(frame->f_code->co_name) == name && + frame->f_code->co_argcount > 0) { + PyFrame_FastToLocals(frame); + PyObject *self_caller = PyDict_GetItem( + frame->f_locals, PyTuple_GET_ITEM(frame->f_code->co_varnames, 0)); + if (self_caller == self.ptr()) + return function(); + } +#else + /* PyPy currently doesn't provide a detailed cpyext emulation of + frame objects, so we have to emulate this using Python. This + is going to be slow..*/ + dict d; d["self"] = self; d["name"] = pybind11::str(name); + PyObject *result = PyRun_String( + "import inspect\n" + "frame = inspect.currentframe()\n" + "if frame is not None:\n" + " frame = frame.f_back\n" + " if frame is not None and str(frame.f_code.co_name) == name and " + "frame.f_code.co_argcount > 0:\n" + " self_caller = frame.f_locals[frame.f_code.co_varnames[0]]\n" + " if self_caller == self:\n" + " self = None\n", + Py_file_input, d.ptr(), d.ptr()); + if (result == nullptr) + throw error_already_set(); + if (d["self"].is_none()) + return function(); + Py_DECREF(result); +#endif + + return overload; +} + +template function get_overload(const T *this_ptr, const char *name) { + auto tinfo = detail::get_type_info(typeid(T)); + return tinfo ? get_type_overload(this_ptr, tinfo, name) : function(); +} + +#define PYBIND11_OVERLOAD_INT(ret_type, cname, name, ...) 
{ \ + pybind11::gil_scoped_acquire gil; \ + pybind11::function overload = pybind11::get_overload(static_cast(this), name); \ + if (overload) { \ + auto o = overload(__VA_ARGS__); \ + if (pybind11::detail::cast_is_temporary_value_reference::value) { \ + static pybind11::detail::overload_caster_t caster; \ + return pybind11::detail::cast_ref(std::move(o), caster); \ + } \ + else return pybind11::detail::cast_safe(std::move(o)); \ + } \ + } + +#define PYBIND11_OVERLOAD_NAME(ret_type, cname, name, fn, ...) \ + PYBIND11_OVERLOAD_INT(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__) \ + return cname::fn(__VA_ARGS__) + +#define PYBIND11_OVERLOAD_PURE_NAME(ret_type, cname, name, fn, ...) \ + PYBIND11_OVERLOAD_INT(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__) \ + pybind11::pybind11_fail("Tried to call pure virtual function \"" PYBIND11_STRINGIFY(cname) "::" name "\""); + +#define PYBIND11_OVERLOAD(ret_type, cname, fn, ...) \ + PYBIND11_OVERLOAD_NAME(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__) + +#define PYBIND11_OVERLOAD_PURE(ret_type, cname, fn, ...) \ + PYBIND11_OVERLOAD_PURE_NAME(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__) + +NAMESPACE_END(PYBIND11_NAMESPACE) + +#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) +# pragma warning(pop) +#elif defined(__GNUG__) && !defined(__clang__) +# pragma GCC diagnostic pop +#endif diff --git a/mmocr/models/textdet/postprocess/include/pybind11/pytypes.h b/mmocr/models/textdet/postprocess/include/pybind11/pytypes.h new file mode 100644 index 00000000..3329fda2 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/pytypes.h @@ -0,0 +1,1438 @@ +/* + pybind11/pytypes.h: Convenience wrapper classes for basic Python types + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "detail/common.h" +#include "buffer_info.h" +#include +#include + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +/* A few forward declarations */ +class handle; class object; +class str; class iterator; +struct arg; struct arg_v; + +NAMESPACE_BEGIN(detail) +class args_proxy; +inline bool isinstance_generic(handle obj, const std::type_info &tp); + +// Accessor forward declarations +template class accessor; +namespace accessor_policies { + struct obj_attr; + struct str_attr; + struct generic_item; + struct sequence_item; + struct list_item; + struct tuple_item; +} +using obj_attr_accessor = accessor; +using str_attr_accessor = accessor; +using item_accessor = accessor; +using sequence_accessor = accessor; +using list_accessor = accessor; +using tuple_accessor = accessor; + +/// Tag and check to identify a class which implements the Python object API +class pyobject_tag { }; +template using is_pyobject = std::is_base_of>; + +/** \rst + A mixin class which adds common functions to `handle`, `object` and various accessors. + The only requirement for `Derived` is to implement ``PyObject *Derived::ptr() const``. +\endrst */ +template +class object_api : public pyobject_tag { + const Derived &derived() const { return static_cast(*this); } + +public: + /** \rst + Return an iterator equivalent to calling ``iter()`` in Python. The object + must be a collection which supports the iteration protocol. + \endrst */ + iterator begin() const; + /// Return a sentinel which ends iteration. 
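These `begin()`/`end()` members are what allow wrapped Python objects to appear in C++ range-based for loops; a small sketch:

```cpp
namespace py = pybind11;

void print_items(py::iterable seq) {
    // object_api::begin()/end() wrap Python's iteration protocol, so this
    // behaves like `for item in seq: print(item)`.
    for (py::handle item : seq)
        py::print(item);
}
```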
+ iterator end() const; + + /** \rst + Return an internal functor to invoke the object's sequence protocol. Casting + the returned ``detail::item_accessor`` instance to a `handle` or `object` + subclass causes a corresponding call to ``__getitem__``. Assigning a `handle` + or `object` subclass causes a call to ``__setitem__``. + \endrst */ + item_accessor operator[](handle key) const; + /// See above (the only difference is that they key is provided as a string literal) + item_accessor operator[](const char *key) const; + + /** \rst + Return an internal functor to access the object's attributes. Casting the + returned ``detail::obj_attr_accessor`` instance to a `handle` or `object` + subclass causes a corresponding call to ``getattr``. Assigning a `handle` + or `object` subclass causes a call to ``setattr``. + \endrst */ + obj_attr_accessor attr(handle key) const; + /// See above (the only difference is that they key is provided as a string literal) + str_attr_accessor attr(const char *key) const; + + /** \rst + Matches * unpacking in Python, e.g. to unpack arguments out of a ``tuple`` + or ``list`` for a function call. Applying another * to the result yields + ** unpacking, e.g. to unpack a dict as function keyword arguments. + See :ref:`calling_python_functions`. + \endrst */ + args_proxy operator*() const; + + /// Check if the given item is contained within this object, i.e. ``item in obj``. + template bool contains(T &&item) const; + + /** \rst + Assuming the Python object is a function or implements the ``__call__`` + protocol, ``operator()`` invokes the underlying function, passing an + arbitrary set of parameters. The result is returned as a `object` and + may need to be converted back into a Python object using `handle::cast()`. + + When some of the arguments cannot be converted to Python objects, the + function will throw a `cast_error` exception. When the Python function + call fails, a `error_already_set` exception is thrown. + \endrst */ + template + object operator()(Args &&...args) const; + template + PYBIND11_DEPRECATED("call(...) was deprecated in favor of operator()(...)") + object call(Args&&... args) const; + + /// Equivalent to ``obj is other`` in Python. + bool is(object_api const& other) const { return derived().ptr() == other.derived().ptr(); } + /// Equivalent to ``obj is None`` in Python. 
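Taken together, the accessor and call operators above give Python-like syntax from C++; a brief sketch (the imported module and attribute names are only illustrative):

```cpp
py::object path = py::module::import("os.path");
std::string joined = path.attr("join")("usr", "lib").cast<std::string>();

py::dict d;
d["answer"] = 42;                       // operator[] dispatches to __setitem__
bool present = d.contains("answer");    // equivalent to `"answer" in d`
```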
+ bool is_none() const { return derived().ptr() == Py_None; } + /// Equivalent to obj == other in Python + bool equal(object_api const &other) const { return rich_compare(other, Py_EQ); } + bool not_equal(object_api const &other) const { return rich_compare(other, Py_NE); } + bool operator<(object_api const &other) const { return rich_compare(other, Py_LT); } + bool operator<=(object_api const &other) const { return rich_compare(other, Py_LE); } + bool operator>(object_api const &other) const { return rich_compare(other, Py_GT); } + bool operator>=(object_api const &other) const { return rich_compare(other, Py_GE); } + + object operator-() const; + object operator~() const; + object operator+(object_api const &other) const; + object operator+=(object_api const &other) const; + object operator-(object_api const &other) const; + object operator-=(object_api const &other) const; + object operator*(object_api const &other) const; + object operator*=(object_api const &other) const; + object operator/(object_api const &other) const; + object operator/=(object_api const &other) const; + object operator|(object_api const &other) const; + object operator|=(object_api const &other) const; + object operator&(object_api const &other) const; + object operator&=(object_api const &other) const; + object operator^(object_api const &other) const; + object operator^=(object_api const &other) const; + object operator<<(object_api const &other) const; + object operator<<=(object_api const &other) const; + object operator>>(object_api const &other) const; + object operator>>=(object_api const &other) const; + + PYBIND11_DEPRECATED("Use py::str(obj) instead") + pybind11::str str() const; + + /// Get or set the object's docstring, i.e. ``obj.__doc__``. + str_attr_accessor doc() const; + + /// Return the object's current reference count + int ref_count() const { return static_cast(Py_REFCNT(derived().ptr())); } + /// Return a handle to the Python type object underlying the instance + handle get_type() const; + +private: + bool rich_compare(object_api const &other, int value) const; +}; + +NAMESPACE_END(detail) + +/** \rst + Holds a reference to a Python object (no reference counting) + + The `handle` class is a thin wrapper around an arbitrary Python object (i.e. a + ``PyObject *`` in Python's C API). It does not perform any automatic reference + counting and merely provides a basic C++ interface to various Python API functions. + + .. seealso:: + The `object` class inherits from `handle` and adds automatic reference + counting features. +\endrst */ +class handle : public detail::object_api { +public: + /// The default constructor creates a handle with a ``nullptr``-valued pointer + handle() = default; + /// Creates a ``handle`` from the given raw Python object pointer + handle(PyObject *ptr) : m_ptr(ptr) { } // Allow implicit conversion from PyObject* + + /// Return the underlying ``PyObject *`` pointer + PyObject *ptr() const { return m_ptr; } + PyObject *&ptr() { return m_ptr; } + + /** \rst + Manually increase the reference count of the Python object. Usually, it is + preferable to use the `object` class which derives from `handle` and calls + this function automatically. Returns a reference to itself. + \endrst */ + const handle& inc_ref() const & { Py_XINCREF(m_ptr); return *this; } + + /** \rst + Manually decrease the reference count of the Python object. Usually, it is + preferable to use the `object` class which derives from `handle` and calls + this function automatically. Returns a reference to itself. 
+ \endrst */ + const handle& dec_ref() const & { Py_XDECREF(m_ptr); return *this; } + + /** \rst + Attempt to cast the Python object into the given C++ type. A `cast_error` + will be throw upon failure. + \endrst */ + template T cast() const; + /// Return ``true`` when the `handle` wraps a valid Python object + explicit operator bool() const { return m_ptr != nullptr; } + /** \rst + Deprecated: Check that the underlying pointers are the same. + Equivalent to ``obj1 is obj2`` in Python. + \endrst */ + PYBIND11_DEPRECATED("Use obj1.is(obj2) instead") + bool operator==(const handle &h) const { return m_ptr == h.m_ptr; } + PYBIND11_DEPRECATED("Use !obj1.is(obj2) instead") + bool operator!=(const handle &h) const { return m_ptr != h.m_ptr; } + PYBIND11_DEPRECATED("Use handle::operator bool() instead") + bool check() const { return m_ptr != nullptr; } +protected: + PyObject *m_ptr = nullptr; +}; + +/** \rst + Holds a reference to a Python object (with reference counting) + + Like `handle`, the `object` class is a thin wrapper around an arbitrary Python + object (i.e. a ``PyObject *`` in Python's C API). In contrast to `handle`, it + optionally increases the object's reference count upon construction, and it + *always* decreases the reference count when the `object` instance goes out of + scope and is destructed. When using `object` instances consistently, it is much + easier to get reference counting right at the first attempt. +\endrst */ +class object : public handle { +public: + object() = default; + PYBIND11_DEPRECATED("Use reinterpret_borrow() or reinterpret_steal()") + object(handle h, bool is_borrowed) : handle(h) { if (is_borrowed) inc_ref(); } + /// Copy constructor; always increases the reference count + object(const object &o) : handle(o) { inc_ref(); } + /// Move constructor; steals the object from ``other`` and preserves its reference count + object(object &&other) noexcept { m_ptr = other.m_ptr; other.m_ptr = nullptr; } + /// Destructor; automatically calls `handle::dec_ref()` + ~object() { dec_ref(); } + + /** \rst + Resets the internal pointer to ``nullptr`` without without decreasing the + object's reference count. The function returns a raw handle to the original + Python object. + \endrst */ + handle release() { + PyObject *tmp = m_ptr; + m_ptr = nullptr; + return handle(tmp); + } + + object& operator=(const object &other) { + other.inc_ref(); + dec_ref(); + m_ptr = other.m_ptr; + return *this; + } + + object& operator=(object &&other) noexcept { + if (this != &other) { + handle temp(m_ptr); + m_ptr = other.m_ptr; + other.m_ptr = nullptr; + temp.dec_ref(); + } + return *this; + } + + // Calling cast() on an object lvalue just copies (via handle::cast) + template T cast() const &; + // Calling on an object rvalue does a move, if needed and/or possible + template T cast() &&; + +protected: + // Tags for choosing constructors from raw PyObject * + struct borrowed_t { }; + struct stolen_t { }; + + template friend T reinterpret_borrow(handle); + template friend T reinterpret_steal(handle); + +public: + // Only accessible from derived classes and the reinterpret_* functions + object(handle h, borrowed_t) : handle(h) { inc_ref(); } + object(handle h, stolen_t) : handle(h) { } +}; + +/** \rst + Declare that a `handle` or ``PyObject *`` is a certain type and borrow the reference. + The target type ``T`` must be `object` or one of its derived classes. The function + doesn't do any conversions or checks. It's up to the user to make sure that the + target type is correct. + + .. 
code-block:: cpp + + PyObject *p = PyList_GetItem(obj, index); + py::object o = reinterpret_borrow(p); + // or + py::tuple t = reinterpret_borrow(p); // <-- `p` must be already be a `tuple` +\endrst */ +template T reinterpret_borrow(handle h) { return {h, object::borrowed_t{}}; } + +/** \rst + Like `reinterpret_borrow`, but steals the reference. + + .. code-block:: cpp + + PyObject *p = PyObject_Str(obj); + py::str s = reinterpret_steal(p); // <-- `p` must be already be a `str` +\endrst */ +template T reinterpret_steal(handle h) { return {h, object::stolen_t{}}; } + +NAMESPACE_BEGIN(detail) +inline std::string error_string(); +NAMESPACE_END(detail) + +/// Fetch and hold an error which was already set in Python. An instance of this is typically +/// thrown to propagate python-side errors back through C++ which can either be caught manually or +/// else falls back to the function dispatcher (which then raises the captured error back to +/// python). +class error_already_set : public std::runtime_error { +public: + /// Constructs a new exception from the current Python error indicator, if any. The current + /// Python error indicator will be cleared. + error_already_set() : std::runtime_error(detail::error_string()) { + PyErr_Fetch(&type.ptr(), &value.ptr(), &trace.ptr()); + } + + error_already_set(const error_already_set &) = default; + error_already_set(error_already_set &&) = default; + + inline ~error_already_set(); + + /// Give the currently-held error back to Python, if any. If there is currently a Python error + /// already set it is cleared first. After this call, the current object no longer stores the + /// error variables (but the `.what()` string is still available). + void restore() { PyErr_Restore(type.release().ptr(), value.release().ptr(), trace.release().ptr()); } + + // Does nothing; provided for backwards compatibility. + PYBIND11_DEPRECATED("Use of error_already_set.clear() is deprecated") + void clear() {} + + /// Check if the currently trapped error type matches the given Python exception class (or a + /// subclass thereof). May also be passed a tuple to search for any exception class matches in + /// the given tuple. + bool matches(handle ex) const { return PyErr_GivenExceptionMatches(ex.ptr(), type.ptr()); } + +private: + object type, value, trace; +}; + +/** \defgroup python_builtins _ + Unless stated otherwise, the following C++ functions behave the same + as their Python counterparts. + */ + +/** \ingroup python_builtins + \rst + Return true if ``obj`` is an instance of ``T``. Type ``T`` must be a subclass of + `object` or a class which was exposed to Python as ``py::class_``. +\endrst */ +template ::value, int> = 0> +bool isinstance(handle obj) { return T::check_(obj); } + +template ::value, int> = 0> +bool isinstance(handle obj) { return detail::isinstance_generic(obj, typeid(T)); } + +template <> inline bool isinstance(handle obj) = delete; +template <> inline bool isinstance(handle obj) { return obj.ptr() != nullptr; } + +/// \ingroup python_builtins +/// Return true if ``obj`` is an instance of the ``type``. 
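The borrow/steal distinction above is the crux of mixing raw CPython calls with these wrappers; a hedged sketch:

```cpp
py::object stringified(py::handle obj) {
    PyObject *s = PyObject_Str(obj.ptr());        // returns a *new* reference, or nullptr
    if (!s)
        throw py::error_already_set();            // propagate the pending Python error
    return py::reinterpret_steal<py::object>(s);  // take ownership of the new reference
}
```

For APIs that return borrowed references, such as `PyList_GetItem`, `reinterpret_borrow<T>()` is the correct counterpart, as the docstrings above show.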
+inline bool isinstance(handle obj, handle type) { + const auto result = PyObject_IsInstance(obj.ptr(), type.ptr()); + if (result == -1) + throw error_already_set(); + return result != 0; +} + +/// \addtogroup python_builtins +/// @{ +inline bool hasattr(handle obj, handle name) { + return PyObject_HasAttr(obj.ptr(), name.ptr()) == 1; +} + +inline bool hasattr(handle obj, const char *name) { + return PyObject_HasAttrString(obj.ptr(), name) == 1; +} + +inline void delattr(handle obj, handle name) { + if (PyObject_DelAttr(obj.ptr(), name.ptr()) != 0) { throw error_already_set(); } +} + +inline void delattr(handle obj, const char *name) { + if (PyObject_DelAttrString(obj.ptr(), name) != 0) { throw error_already_set(); } +} + +inline object getattr(handle obj, handle name) { + PyObject *result = PyObject_GetAttr(obj.ptr(), name.ptr()); + if (!result) { throw error_already_set(); } + return reinterpret_steal(result); +} + +inline object getattr(handle obj, const char *name) { + PyObject *result = PyObject_GetAttrString(obj.ptr(), name); + if (!result) { throw error_already_set(); } + return reinterpret_steal(result); +} + +inline object getattr(handle obj, handle name, handle default_) { + if (PyObject *result = PyObject_GetAttr(obj.ptr(), name.ptr())) { + return reinterpret_steal(result); + } else { + PyErr_Clear(); + return reinterpret_borrow(default_); + } +} + +inline object getattr(handle obj, const char *name, handle default_) { + if (PyObject *result = PyObject_GetAttrString(obj.ptr(), name)) { + return reinterpret_steal(result); + } else { + PyErr_Clear(); + return reinterpret_borrow(default_); + } +} + +inline void setattr(handle obj, handle name, handle value) { + if (PyObject_SetAttr(obj.ptr(), name.ptr(), value.ptr()) != 0) { throw error_already_set(); } +} + +inline void setattr(handle obj, const char *name, handle value) { + if (PyObject_SetAttrString(obj.ptr(), name, value.ptr()) != 0) { throw error_already_set(); } +} + +inline ssize_t hash(handle obj) { + auto h = PyObject_Hash(obj.ptr()); + if (h == -1) { throw error_already_set(); } + return h; +} + +/// @} python_builtins + +NAMESPACE_BEGIN(detail) +inline handle get_function(handle value) { + if (value) { +#if PY_MAJOR_VERSION >= 3 + if (PyInstanceMethod_Check(value.ptr())) + value = PyInstanceMethod_GET_FUNCTION(value.ptr()); + else +#endif + if (PyMethod_Check(value.ptr())) + value = PyMethod_GET_FUNCTION(value.ptr()); + } + return value; +} + +// Helper aliases/functions to support implicit casting of values given to python accessors/methods. +// When given a pyobject, this simply returns the pyobject as-is; for other C++ type, the value goes +// through pybind11::cast(obj) to convert it to an `object`. +template ::value, int> = 0> +auto object_or_cast(T &&o) -> decltype(std::forward(o)) { return std::forward(o); } +// The following casting version is implemented in cast.h: +template ::value, int> = 0> +object object_or_cast(T &&o); +// Match a PyObject*, which we want to convert directly to handle via its converting constructor +inline handle object_or_cast(PyObject *ptr) { return ptr; } + +template +class accessor : public object_api> { + using key_type = typename Policy::key_type; + +public: + accessor(handle obj, key_type key) : obj(obj), key(std::move(key)) { } + accessor(const accessor &) = default; + accessor(accessor &&) = default; + + // accessor overload required to override default assignment operator (templates are not allowed + // to replace default compiler-generated assignments). 
+ void operator=(const accessor &a) && { std::move(*this).operator=(handle(a)); } + void operator=(const accessor &a) & { operator=(handle(a)); } + + template void operator=(T &&value) && { + Policy::set(obj, key, object_or_cast(std::forward(value))); + } + template void operator=(T &&value) & { + get_cache() = reinterpret_borrow(object_or_cast(std::forward(value))); + } + + template + PYBIND11_DEPRECATED("Use of obj.attr(...) as bool is deprecated in favor of pybind11::hasattr(obj, ...)") + explicit operator enable_if_t::value || + std::is_same::value, bool>() const { + return hasattr(obj, key); + } + template + PYBIND11_DEPRECATED("Use of obj[key] as bool is deprecated in favor of obj.contains(key)") + explicit operator enable_if_t::value, bool>() const { + return obj.contains(key); + } + + operator object() const { return get_cache(); } + PyObject *ptr() const { return get_cache().ptr(); } + template T cast() const { return get_cache().template cast(); } + +private: + object &get_cache() const { + if (!cache) { cache = Policy::get(obj, key); } + return cache; + } + +private: + handle obj; + key_type key; + mutable object cache; +}; + +NAMESPACE_BEGIN(accessor_policies) +struct obj_attr { + using key_type = object; + static object get(handle obj, handle key) { return getattr(obj, key); } + static void set(handle obj, handle key, handle val) { setattr(obj, key, val); } +}; + +struct str_attr { + using key_type = const char *; + static object get(handle obj, const char *key) { return getattr(obj, key); } + static void set(handle obj, const char *key, handle val) { setattr(obj, key, val); } +}; + +struct generic_item { + using key_type = object; + + static object get(handle obj, handle key) { + PyObject *result = PyObject_GetItem(obj.ptr(), key.ptr()); + if (!result) { throw error_already_set(); } + return reinterpret_steal(result); + } + + static void set(handle obj, handle key, handle val) { + if (PyObject_SetItem(obj.ptr(), key.ptr(), val.ptr()) != 0) { throw error_already_set(); } + } +}; + +struct sequence_item { + using key_type = size_t; + + static object get(handle obj, size_t index) { + PyObject *result = PySequence_GetItem(obj.ptr(), static_cast(index)); + if (!result) { throw error_already_set(); } + return reinterpret_steal(result); + } + + static void set(handle obj, size_t index, handle val) { + // PySequence_SetItem does not steal a reference to 'val' + if (PySequence_SetItem(obj.ptr(), static_cast(index), val.ptr()) != 0) { + throw error_already_set(); + } + } +}; + +struct list_item { + using key_type = size_t; + + static object get(handle obj, size_t index) { + PyObject *result = PyList_GetItem(obj.ptr(), static_cast(index)); + if (!result) { throw error_already_set(); } + return reinterpret_borrow(result); + } + + static void set(handle obj, size_t index, handle val) { + // PyList_SetItem steals a reference to 'val' + if (PyList_SetItem(obj.ptr(), static_cast(index), val.inc_ref().ptr()) != 0) { + throw error_already_set(); + } + } +}; + +struct tuple_item { + using key_type = size_t; + + static object get(handle obj, size_t index) { + PyObject *result = PyTuple_GetItem(obj.ptr(), static_cast(index)); + if (!result) { throw error_already_set(); } + return reinterpret_borrow(result); + } + + static void set(handle obj, size_t index, handle val) { + // PyTuple_SetItem steals a reference to 'val' + if (PyTuple_SetItem(obj.ptr(), static_cast(index), val.inc_ref().ptr()) != 0) { + throw error_already_set(); + } + } +}; +NAMESPACE_END(accessor_policies) + +/// STL iterator 
template used for tuple, list, sequence and dict +template +class generic_iterator : public Policy { + using It = generic_iterator; + +public: + using difference_type = ssize_t; + using iterator_category = typename Policy::iterator_category; + using value_type = typename Policy::value_type; + using reference = typename Policy::reference; + using pointer = typename Policy::pointer; + + generic_iterator() = default; + generic_iterator(handle seq, ssize_t index) : Policy(seq, index) { } + + reference operator*() const { return Policy::dereference(); } + reference operator[](difference_type n) const { return *(*this + n); } + pointer operator->() const { return **this; } + + It &operator++() { Policy::increment(); return *this; } + It operator++(int) { auto copy = *this; Policy::increment(); return copy; } + It &operator--() { Policy::decrement(); return *this; } + It operator--(int) { auto copy = *this; Policy::decrement(); return copy; } + It &operator+=(difference_type n) { Policy::advance(n); return *this; } + It &operator-=(difference_type n) { Policy::advance(-n); return *this; } + + friend It operator+(const It &a, difference_type n) { auto copy = a; return copy += n; } + friend It operator+(difference_type n, const It &b) { return b + n; } + friend It operator-(const It &a, difference_type n) { auto copy = a; return copy -= n; } + friend difference_type operator-(const It &a, const It &b) { return a.distance_to(b); } + + friend bool operator==(const It &a, const It &b) { return a.equal(b); } + friend bool operator!=(const It &a, const It &b) { return !(a == b); } + friend bool operator< (const It &a, const It &b) { return b - a > 0; } + friend bool operator> (const It &a, const It &b) { return b < a; } + friend bool operator>=(const It &a, const It &b) { return !(a < b); } + friend bool operator<=(const It &a, const It &b) { return !(a > b); } +}; + +NAMESPACE_BEGIN(iterator_policies) +/// Quick proxy class needed to implement ``operator->`` for iterators which can't return pointers +template +struct arrow_proxy { + T value; + + arrow_proxy(T &&value) : value(std::move(value)) { } + T *operator->() const { return &value; } +}; + +/// Lightweight iterator policy using just a simple pointer: see ``PySequence_Fast_ITEMS`` +class sequence_fast_readonly { +protected: + using iterator_category = std::random_access_iterator_tag; + using value_type = handle; + using reference = const handle; + using pointer = arrow_proxy; + + sequence_fast_readonly(handle obj, ssize_t n) : ptr(PySequence_Fast_ITEMS(obj.ptr()) + n) { } + + reference dereference() const { return *ptr; } + void increment() { ++ptr; } + void decrement() { --ptr; } + void advance(ssize_t n) { ptr += n; } + bool equal(const sequence_fast_readonly &b) const { return ptr == b.ptr; } + ssize_t distance_to(const sequence_fast_readonly &b) const { return ptr - b.ptr; } + +private: + PyObject **ptr; +}; + +/// Full read and write access using the sequence protocol: see ``detail::sequence_accessor`` +class sequence_slow_readwrite { +protected: + using iterator_category = std::random_access_iterator_tag; + using value_type = object; + using reference = sequence_accessor; + using pointer = arrow_proxy; + + sequence_slow_readwrite(handle obj, ssize_t index) : obj(obj), index(index) { } + + reference dereference() const { return {obj, static_cast(index)}; } + void increment() { ++index; } + void decrement() { --index; } + void advance(ssize_t n) { index += n; } + bool equal(const sequence_slow_readwrite &b) const { return index == b.index; } + 
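+
+    // A short usage sketch (illustrative): the policies in this namespace back
+    // the iterators returned by begin()/end() on the py:: container wrappers
+    // further below, which is what makes range-based for loops over Python
+    // containers work:
+    //
+    //     py::list lst;
+    //     lst.append(1);
+    //     lst.append(2);
+    //     for (py::handle item : lst)      // list_iterator
+    //         py::print(item);
+    //
+    //     for (auto kv : some_dict)        // dict_iterator yields key/value pairs
+    //         py::print(kv.first, kv.second);
+    //
+    // `some_dict` is assumed to be an existing py::dict; py::print comes from
+    // pybind11.h.
+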
ssize_t distance_to(const sequence_slow_readwrite &b) const { return index - b.index; } + +private: + handle obj; + ssize_t index; +}; + +/// Python's dictionary protocol permits this to be a forward iterator +class dict_readonly { +protected: + using iterator_category = std::forward_iterator_tag; + using value_type = std::pair; + using reference = const value_type; + using pointer = arrow_proxy; + + dict_readonly() = default; + dict_readonly(handle obj, ssize_t pos) : obj(obj), pos(pos) { increment(); } + + reference dereference() const { return {key, value}; } + void increment() { if (!PyDict_Next(obj.ptr(), &pos, &key, &value)) { pos = -1; } } + bool equal(const dict_readonly &b) const { return pos == b.pos; } + +private: + handle obj; + PyObject *key, *value; + ssize_t pos = -1; +}; +NAMESPACE_END(iterator_policies) + +#if !defined(PYPY_VERSION) +using tuple_iterator = generic_iterator; +using list_iterator = generic_iterator; +#else +using tuple_iterator = generic_iterator; +using list_iterator = generic_iterator; +#endif + +using sequence_iterator = generic_iterator; +using dict_iterator = generic_iterator; + +inline bool PyIterable_Check(PyObject *obj) { + PyObject *iter = PyObject_GetIter(obj); + if (iter) { + Py_DECREF(iter); + return true; + } else { + PyErr_Clear(); + return false; + } +} + +inline bool PyNone_Check(PyObject *o) { return o == Py_None; } +#if PY_MAJOR_VERSION >= 3 +inline bool PyEllipsis_Check(PyObject *o) { return o == Py_Ellipsis; } +#endif + +inline bool PyUnicode_Check_Permissive(PyObject *o) { return PyUnicode_Check(o) || PYBIND11_BYTES_CHECK(o); } + +class kwargs_proxy : public handle { +public: + explicit kwargs_proxy(handle h) : handle(h) { } +}; + +class args_proxy : public handle { +public: + explicit args_proxy(handle h) : handle(h) { } + kwargs_proxy operator*() const { return kwargs_proxy(*this); } +}; + +/// Python argument categories (using PEP 448 terms) +template using is_keyword = std::is_base_of; +template using is_s_unpacking = std::is_same; // * unpacking +template using is_ds_unpacking = std::is_same; // ** unpacking +template using is_positional = satisfies_none_of; +template using is_keyword_or_ds = satisfies_any_of; + +// Call argument collector forward declarations +template +class simple_collector; +template +class unpacking_collector; + +NAMESPACE_END(detail) + +// TODO: After the deprecated constructors are removed, this macro can be simplified by +// inheriting ctors: `using Parent::Parent`. It's not an option right now because +// the `using` statement triggers the parent deprecation warning even if the ctor +// isn't even used. +#define PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \ + public: \ + PYBIND11_DEPRECATED("Use reinterpret_borrow<"#Name">() or reinterpret_steal<"#Name">()") \ + Name(handle h, bool is_borrowed) : Parent(is_borrowed ? Parent(h, borrowed_t{}) : Parent(h, stolen_t{})) { } \ + Name(handle h, borrowed_t) : Parent(h, borrowed_t{}) { } \ + Name(handle h, stolen_t) : Parent(h, stolen_t{}) { } \ + PYBIND11_DEPRECATED("Use py::isinstance(obj) instead") \ + bool check() const { return m_ptr != nullptr && (bool) CheckFun(m_ptr); } \ + static bool check_(handle h) { return h.ptr() != nullptr && CheckFun(h.ptr()); } + +#define PYBIND11_OBJECT_CVT(Name, Parent, CheckFun, ConvertFun) \ + PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \ + /* This is deliberately not 'explicit' to allow implicit conversion from object: */ \ + Name(const object &o) \ + : Parent(check_(o) ? 
o.inc_ref().ptr() : ConvertFun(o.ptr()), stolen_t{}) \ + { if (!m_ptr) throw error_already_set(); } \ + Name(object &&o) \ + : Parent(check_(o) ? o.release().ptr() : ConvertFun(o.ptr()), stolen_t{}) \ + { if (!m_ptr) throw error_already_set(); } \ + template \ + Name(const ::pybind11::detail::accessor &a) : Name(object(a)) { } + +#define PYBIND11_OBJECT(Name, Parent, CheckFun) \ + PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \ + /* This is deliberately not 'explicit' to allow implicit conversion from object: */ \ + Name(const object &o) : Parent(o) { } \ + Name(object &&o) : Parent(std::move(o)) { } + +#define PYBIND11_OBJECT_DEFAULT(Name, Parent, CheckFun) \ + PYBIND11_OBJECT(Name, Parent, CheckFun) \ + Name() : Parent() { } + +/// \addtogroup pytypes +/// @{ + +/** \rst + Wraps a Python iterator so that it can also be used as a C++ input iterator + + Caveat: copying an iterator does not (and cannot) clone the internal + state of the Python iterable. This also applies to the post-increment + operator. This iterator should only be used to retrieve the current + value using ``operator*()``. +\endrst */ +class iterator : public object { +public: + using iterator_category = std::input_iterator_tag; + using difference_type = ssize_t; + using value_type = handle; + using reference = const handle; + using pointer = const handle *; + + PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check) + + iterator& operator++() { + advance(); + return *this; + } + + iterator operator++(int) { + auto rv = *this; + advance(); + return rv; + } + + reference operator*() const { + if (m_ptr && !value.ptr()) { + auto& self = const_cast(*this); + self.advance(); + } + return value; + } + + pointer operator->() const { operator*(); return &value; } + + /** \rst + The value which marks the end of the iteration. ``it == iterator::sentinel()`` + is equivalent to catching ``StopIteration`` in Python. + + .. code-block:: cpp + + void foo(py::iterator it) { + while (it != py::iterator::sentinel()) { + // use `*it` + ++it; + } + } + \endrst */ + static iterator sentinel() { return {}; } + + friend bool operator==(const iterator &a, const iterator &b) { return a->ptr() == b->ptr(); } + friend bool operator!=(const iterator &a, const iterator &b) { return a->ptr() != b->ptr(); } + +private: + void advance() { + value = reinterpret_steal(PyIter_Next(m_ptr)); + if (PyErr_Occurred()) { throw error_already_set(); } + } + +private: + object value = {}; +}; + +class iterable : public object { +public: + PYBIND11_OBJECT_DEFAULT(iterable, object, detail::PyIterable_Check) +}; + +class bytes; + +class str : public object { +public: + PYBIND11_OBJECT_CVT(str, object, detail::PyUnicode_Check_Permissive, raw_str) + + str(const char *c, size_t n) + : object(PyUnicode_FromStringAndSize(c, (ssize_t) n), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate string object!"); + } + + // 'explicit' is explicitly omitted from the following constructors to allow implicit conversion to py::str from C++ string-like objects + str(const char *c = "") + : object(PyUnicode_FromString(c), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate string object!"); + } + + str(const std::string &s) : str(s.data(), s.size()) { } + + explicit str(const bytes &b); + + /** \rst + Return a string representation of the object. This is analogous to + the ``str()`` function in Python. 
+ \endrst */ + explicit str(handle h) : object(raw_str(h.ptr()), stolen_t{}) { } + + operator std::string() const { + object temp = *this; + if (PyUnicode_Check(m_ptr)) { + temp = reinterpret_steal(PyUnicode_AsUTF8String(m_ptr)); + if (!temp) + pybind11_fail("Unable to extract string contents! (encoding issue)"); + } + char *buffer; + ssize_t length; + if (PYBIND11_BYTES_AS_STRING_AND_SIZE(temp.ptr(), &buffer, &length)) + pybind11_fail("Unable to extract string contents! (invalid type)"); + return std::string(buffer, (size_t) length); + } + + template + str format(Args &&...args) const { + return attr("format")(std::forward(args)...); + } + +private: + /// Return string representation -- always returns a new reference, even if already a str + static PyObject *raw_str(PyObject *op) { + PyObject *str_value = PyObject_Str(op); +#if PY_MAJOR_VERSION < 3 + if (!str_value) throw error_already_set(); + PyObject *unicode = PyUnicode_FromEncodedObject(str_value, "utf-8", nullptr); + Py_XDECREF(str_value); str_value = unicode; +#endif + return str_value; + } +}; +/// @} pytypes + +inline namespace literals { +/** \rst + String literal version of `str` + \endrst */ +inline str operator"" _s(const char *s, size_t size) { return {s, size}; } +} + +/// \addtogroup pytypes +/// @{ +class bytes : public object { +public: + PYBIND11_OBJECT(bytes, object, PYBIND11_BYTES_CHECK) + + // Allow implicit conversion: + bytes(const char *c = "") + : object(PYBIND11_BYTES_FROM_STRING(c), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate bytes object!"); + } + + bytes(const char *c, size_t n) + : object(PYBIND11_BYTES_FROM_STRING_AND_SIZE(c, (ssize_t) n), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate bytes object!"); + } + + // Allow implicit conversion: + bytes(const std::string &s) : bytes(s.data(), s.size()) { } + + explicit bytes(const pybind11::str &s); + + operator std::string() const { + char *buffer; + ssize_t length; + if (PYBIND11_BYTES_AS_STRING_AND_SIZE(m_ptr, &buffer, &length)) + pybind11_fail("Unable to extract bytes contents!"); + return std::string(buffer, (size_t) length); + } +}; + +inline bytes::bytes(const pybind11::str &s) { + object temp = s; + if (PyUnicode_Check(s.ptr())) { + temp = reinterpret_steal(PyUnicode_AsUTF8String(s.ptr())); + if (!temp) + pybind11_fail("Unable to extract string contents! (encoding issue)"); + } + char *buffer; + ssize_t length; + if (PYBIND11_BYTES_AS_STRING_AND_SIZE(temp.ptr(), &buffer, &length)) + pybind11_fail("Unable to extract string contents! 
(invalid type)"); + auto obj = reinterpret_steal(PYBIND11_BYTES_FROM_STRING_AND_SIZE(buffer, length)); + if (!obj) + pybind11_fail("Could not allocate bytes object!"); + m_ptr = obj.release().ptr(); +} + +inline str::str(const bytes& b) { + char *buffer; + ssize_t length; + if (PYBIND11_BYTES_AS_STRING_AND_SIZE(b.ptr(), &buffer, &length)) + pybind11_fail("Unable to extract bytes contents!"); + auto obj = reinterpret_steal(PyUnicode_FromStringAndSize(buffer, (ssize_t) length)); + if (!obj) + pybind11_fail("Could not allocate string object!"); + m_ptr = obj.release().ptr(); +} + +class none : public object { +public: + PYBIND11_OBJECT(none, object, detail::PyNone_Check) + none() : object(Py_None, borrowed_t{}) { } +}; + +#if PY_MAJOR_VERSION >= 3 +class ellipsis : public object { +public: + PYBIND11_OBJECT(ellipsis, object, detail::PyEllipsis_Check) + ellipsis() : object(Py_Ellipsis, borrowed_t{}) { } +}; +#endif + +class bool_ : public object { +public: + PYBIND11_OBJECT_CVT(bool_, object, PyBool_Check, raw_bool) + bool_() : object(Py_False, borrowed_t{}) { } + // Allow implicit conversion from and to `bool`: + bool_(bool value) : object(value ? Py_True : Py_False, borrowed_t{}) { } + operator bool() const { return m_ptr && PyLong_AsLong(m_ptr) != 0; } + +private: + /// Return the truth value of an object -- always returns a new reference + static PyObject *raw_bool(PyObject *op) { + const auto value = PyObject_IsTrue(op); + if (value == -1) return nullptr; + return handle(value ? Py_True : Py_False).inc_ref().ptr(); + } +}; + +NAMESPACE_BEGIN(detail) +// Converts a value to the given unsigned type. If an error occurs, you get back (Unsigned) -1; +// otherwise you get back the unsigned long or unsigned long long value cast to (Unsigned). +// (The distinction is critically important when casting a returned -1 error value to some other +// unsigned type: (A)-1 != (B)-1 when A and B are unsigned types of different sizes). +template +Unsigned as_unsigned(PyObject *o) { + if (sizeof(Unsigned) <= sizeof(unsigned long) +#if PY_VERSION_HEX < 0x03000000 + || PyInt_Check(o) +#endif + ) { + unsigned long v = PyLong_AsUnsignedLong(o); + return v == (unsigned long) -1 && PyErr_Occurred() ? (Unsigned) -1 : (Unsigned) v; + } + else { + unsigned long long v = PyLong_AsUnsignedLongLong(o); + return v == (unsigned long long) -1 && PyErr_Occurred() ? (Unsigned) -1 : (Unsigned) v; + } +} +NAMESPACE_END(detail) + +class int_ : public object { +public: + PYBIND11_OBJECT_CVT(int_, object, PYBIND11_LONG_CHECK, PyNumber_Long) + int_() : object(PyLong_FromLong(0), stolen_t{}) { } + // Allow implicit conversion from C++ integral types: + template ::value, int> = 0> + int_(T value) { + if (sizeof(T) <= sizeof(long)) { + if (std::is_signed::value) + m_ptr = PyLong_FromLong((long) value); + else + m_ptr = PyLong_FromUnsignedLong((unsigned long) value); + } else { + if (std::is_signed::value) + m_ptr = PyLong_FromLongLong((long long) value); + else + m_ptr = PyLong_FromUnsignedLongLong((unsigned long long) value); + } + if (!m_ptr) pybind11_fail("Could not allocate int object!"); + } + + template ::value, int> = 0> + operator T() const { + return std::is_unsigned::value + ? detail::as_unsigned(m_ptr) + : sizeof(T) <= sizeof(long) + ? 
(T) PyLong_AsLong(m_ptr) + : (T) PYBIND11_LONG_AS_LONGLONG(m_ptr); + } +}; + +class float_ : public object { +public: + PYBIND11_OBJECT_CVT(float_, object, PyFloat_Check, PyNumber_Float) + // Allow implicit conversion from float/double: + float_(float value) : object(PyFloat_FromDouble((double) value), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate float object!"); + } + float_(double value = .0) : object(PyFloat_FromDouble((double) value), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate float object!"); + } + operator float() const { return (float) PyFloat_AsDouble(m_ptr); } + operator double() const { return (double) PyFloat_AsDouble(m_ptr); } +}; + +class weakref : public object { +public: + PYBIND11_OBJECT_DEFAULT(weakref, object, PyWeakref_Check) + explicit weakref(handle obj, handle callback = {}) + : object(PyWeakref_NewRef(obj.ptr(), callback.ptr()), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate weak reference!"); + } +}; + +class slice : public object { +public: + PYBIND11_OBJECT_DEFAULT(slice, object, PySlice_Check) + slice(ssize_t start_, ssize_t stop_, ssize_t step_) { + int_ start(start_), stop(stop_), step(step_); + m_ptr = PySlice_New(start.ptr(), stop.ptr(), step.ptr()); + if (!m_ptr) pybind11_fail("Could not allocate slice object!"); + } + bool compute(size_t length, size_t *start, size_t *stop, size_t *step, + size_t *slicelength) const { + return PySlice_GetIndicesEx((PYBIND11_SLICE_OBJECT *) m_ptr, + (ssize_t) length, (ssize_t *) start, + (ssize_t *) stop, (ssize_t *) step, + (ssize_t *) slicelength) == 0; + } +}; + +class capsule : public object { +public: + PYBIND11_OBJECT_DEFAULT(capsule, object, PyCapsule_CheckExact) + PYBIND11_DEPRECATED("Use reinterpret_borrow() or reinterpret_steal()") + capsule(PyObject *ptr, bool is_borrowed) : object(is_borrowed ? 
object(ptr, borrowed_t{}) : object(ptr, stolen_t{})) { } + + explicit capsule(const void *value, const char *name = nullptr, void (*destructor)(PyObject *) = nullptr) + : object(PyCapsule_New(const_cast(value), name, destructor), stolen_t{}) { + if (!m_ptr) + pybind11_fail("Could not allocate capsule object!"); + } + + PYBIND11_DEPRECATED("Please pass a destructor that takes a void pointer as input") + capsule(const void *value, void (*destruct)(PyObject *)) + : object(PyCapsule_New(const_cast(value), nullptr, destruct), stolen_t{}) { + if (!m_ptr) + pybind11_fail("Could not allocate capsule object!"); + } + + capsule(const void *value, void (*destructor)(void *)) { + m_ptr = PyCapsule_New(const_cast(value), nullptr, [](PyObject *o) { + auto destructor = reinterpret_cast(PyCapsule_GetContext(o)); + void *ptr = PyCapsule_GetPointer(o, nullptr); + destructor(ptr); + }); + + if (!m_ptr) + pybind11_fail("Could not allocate capsule object!"); + + if (PyCapsule_SetContext(m_ptr, (void *) destructor) != 0) + pybind11_fail("Could not set capsule context!"); + } + + capsule(void (*destructor)()) { + m_ptr = PyCapsule_New(reinterpret_cast(destructor), nullptr, [](PyObject *o) { + auto destructor = reinterpret_cast(PyCapsule_GetPointer(o, nullptr)); + destructor(); + }); + + if (!m_ptr) + pybind11_fail("Could not allocate capsule object!"); + } + + template operator T *() const { + auto name = this->name(); + T * result = static_cast(PyCapsule_GetPointer(m_ptr, name)); + if (!result) pybind11_fail("Unable to extract capsule contents!"); + return result; + } + + const char *name() const { return PyCapsule_GetName(m_ptr); } +}; + +class tuple : public object { +public: + PYBIND11_OBJECT_CVT(tuple, object, PyTuple_Check, PySequence_Tuple) + explicit tuple(size_t size = 0) : object(PyTuple_New((ssize_t) size), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate tuple object!"); + } + size_t size() const { return (size_t) PyTuple_Size(m_ptr); } + detail::tuple_accessor operator[](size_t index) const { return {*this, index}; } + detail::item_accessor operator[](handle h) const { return object::operator[](h); } + detail::tuple_iterator begin() const { return {*this, 0}; } + detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; } +}; + +class dict : public object { +public: + PYBIND11_OBJECT_CVT(dict, object, PyDict_Check, raw_dict) + dict() : object(PyDict_New(), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate dict object!"); + } + template ...>::value>, + // MSVC workaround: it can't compile an out-of-line definition, so defer the collector + typename collector = detail::deferred_t, Args...>> + explicit dict(Args &&...args) : dict(collector(std::forward(args)...).kwargs()) { } + + size_t size() const { return (size_t) PyDict_Size(m_ptr); } + detail::dict_iterator begin() const { return {*this, 0}; } + detail::dict_iterator end() const { return {}; } + void clear() const { PyDict_Clear(ptr()); } + bool contains(handle key) const { return PyDict_Contains(ptr(), key.ptr()) == 1; } + bool contains(const char *key) const { return PyDict_Contains(ptr(), pybind11::str(key).ptr()) == 1; } + +private: + /// Call the `dict` Python type -- always returns a new reference + static PyObject *raw_dict(PyObject *op) { + if (PyDict_Check(op)) + return handle(op).inc_ref().ptr(); + return PyObject_CallFunctionObjArgs((PyObject *) &PyDict_Type, op, nullptr); + } +}; + +class sequence : public object { +public: + PYBIND11_OBJECT_DEFAULT(sequence, object, PySequence_Check) + 
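+
+    // A short usage sketch for the concrete wrappers above (illustrative;
+    // `using namespace pybind11::literals` is assumed for the ``_a`` suffix):
+    //
+    //     py::dict d("alpha"_a = 1, "beta"_a = 2);  // kwargs-style constructor
+    //     d["gamma"] = 3;                           // generic_item accessor
+    //     bool has_alpha = d.contains("alpha");     // PyDict_Contains
+    //
+    //     py::tuple t(2);                           // fresh two-element tuple
+    //     t[0] = py::int_(1);                       // tuple_item::set uses PyTuple_SetItem,
+    //     t[1] = d;                                 // so this is only safe on a fresh tuple
+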
size_t size() const { return (size_t) PySequence_Size(m_ptr); } + detail::sequence_accessor operator[](size_t index) const { return {*this, index}; } + detail::item_accessor operator[](handle h) const { return object::operator[](h); } + detail::sequence_iterator begin() const { return {*this, 0}; } + detail::sequence_iterator end() const { return {*this, PySequence_Size(m_ptr)}; } +}; + +class list : public object { +public: + PYBIND11_OBJECT_CVT(list, object, PyList_Check, PySequence_List) + explicit list(size_t size = 0) : object(PyList_New((ssize_t) size), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate list object!"); + } + size_t size() const { return (size_t) PyList_Size(m_ptr); } + detail::list_accessor operator[](size_t index) const { return {*this, index}; } + detail::item_accessor operator[](handle h) const { return object::operator[](h); } + detail::list_iterator begin() const { return {*this, 0}; } + detail::list_iterator end() const { return {*this, PyList_GET_SIZE(m_ptr)}; } + template void append(T &&val) const { + PyList_Append(m_ptr, detail::object_or_cast(std::forward(val)).ptr()); + } +}; + +class args : public tuple { PYBIND11_OBJECT_DEFAULT(args, tuple, PyTuple_Check) }; +class kwargs : public dict { PYBIND11_OBJECT_DEFAULT(kwargs, dict, PyDict_Check) }; + +class set : public object { +public: + PYBIND11_OBJECT_CVT(set, object, PySet_Check, PySet_New) + set() : object(PySet_New(nullptr), stolen_t{}) { + if (!m_ptr) pybind11_fail("Could not allocate set object!"); + } + size_t size() const { return (size_t) PySet_Size(m_ptr); } + template bool add(T &&val) const { + return PySet_Add(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 0; + } + void clear() const { PySet_Clear(m_ptr); } +}; + +class function : public object { +public: + PYBIND11_OBJECT_DEFAULT(function, object, PyCallable_Check) + handle cpp_function() const { + handle fun = detail::get_function(m_ptr); + if (fun && PyCFunction_Check(fun.ptr())) + return fun; + return handle(); + } + bool is_cpp_function() const { return (bool) cpp_function(); } +}; + +class buffer : public object { +public: + PYBIND11_OBJECT_DEFAULT(buffer, object, PyObject_CheckBuffer) + + buffer_info request(bool writable = false) { + int flags = PyBUF_STRIDES | PyBUF_FORMAT; + if (writable) flags |= PyBUF_WRITABLE; + Py_buffer *view = new Py_buffer(); + if (PyObject_GetBuffer(m_ptr, view, flags) != 0) { + delete view; + throw error_already_set(); + } + return buffer_info(view); + } +}; + +class memoryview : public object { +public: + explicit memoryview(const buffer_info& info) { + static Py_buffer buf { }; + // Py_buffer uses signed sizes, strides and shape!.. 
+ static std::vector py_strides { }; + static std::vector py_shape { }; + buf.buf = info.ptr; + buf.itemsize = info.itemsize; + buf.format = const_cast(info.format.c_str()); + buf.ndim = (int) info.ndim; + buf.len = info.size; + py_strides.clear(); + py_shape.clear(); + for (size_t i = 0; i < (size_t) info.ndim; ++i) { + py_strides.push_back(info.strides[i]); + py_shape.push_back(info.shape[i]); + } + buf.strides = py_strides.data(); + buf.shape = py_shape.data(); + buf.suboffsets = nullptr; + buf.readonly = false; + buf.internal = nullptr; + + m_ptr = PyMemoryView_FromBuffer(&buf); + if (!m_ptr) + pybind11_fail("Unable to create memoryview from buffer descriptor"); + } + + PYBIND11_OBJECT_CVT(memoryview, object, PyMemoryView_Check, PyMemoryView_FromObject) +}; +/// @} pytypes + +/// \addtogroup python_builtins +/// @{ +inline size_t len(handle h) { + ssize_t result = PyObject_Length(h.ptr()); + if (result < 0) + pybind11_fail("Unable to compute length of object"); + return (size_t) result; +} + +inline str repr(handle h) { + PyObject *str_value = PyObject_Repr(h.ptr()); + if (!str_value) throw error_already_set(); +#if PY_MAJOR_VERSION < 3 + PyObject *unicode = PyUnicode_FromEncodedObject(str_value, "utf-8", nullptr); + Py_XDECREF(str_value); str_value = unicode; + if (!str_value) throw error_already_set(); +#endif + return reinterpret_steal(str_value); +} + +inline iterator iter(handle obj) { + PyObject *result = PyObject_GetIter(obj.ptr()); + if (!result) { throw error_already_set(); } + return reinterpret_steal(result); +} +/// @} python_builtins + +NAMESPACE_BEGIN(detail) +template iterator object_api::begin() const { return iter(derived()); } +template iterator object_api::end() const { return iterator::sentinel(); } +template item_accessor object_api::operator[](handle key) const { + return {derived(), reinterpret_borrow(key)}; +} +template item_accessor object_api::operator[](const char *key) const { + return {derived(), pybind11::str(key)}; +} +template obj_attr_accessor object_api::attr(handle key) const { + return {derived(), reinterpret_borrow(key)}; +} +template str_attr_accessor object_api::attr(const char *key) const { + return {derived(), key}; +} +template args_proxy object_api::operator*() const { + return args_proxy(derived().ptr()); +} +template template bool object_api::contains(T &&item) const { + return attr("__contains__")(std::forward(item)).template cast(); +} + +template +pybind11::str object_api::str() const { return pybind11::str(derived()); } + +template +str_attr_accessor object_api::doc() const { return attr("__doc__"); } + +template +handle object_api::get_type() const { return (PyObject *) Py_TYPE(derived().ptr()); } + +template +bool object_api::rich_compare(object_api const &other, int value) const { + int rv = PyObject_RichCompareBool(derived().ptr(), other.derived().ptr(), value); + if (rv == -1) + throw error_already_set(); + return rv == 1; +} + +#define PYBIND11_MATH_OPERATOR_UNARY(op, fn) \ + template object object_api::op() const { \ + object result = reinterpret_steal(fn(derived().ptr())); \ + if (!result.ptr()) \ + throw error_already_set(); \ + return result; \ + } + +#define PYBIND11_MATH_OPERATOR_BINARY(op, fn) \ + template \ + object object_api::op(object_api const &other) const { \ + object result = reinterpret_steal( \ + fn(derived().ptr(), other.derived().ptr())); \ + if (!result.ptr()) \ + throw error_already_set(); \ + return result; \ + } + +PYBIND11_MATH_OPERATOR_UNARY (operator~, PyNumber_Invert) +PYBIND11_MATH_OPERATOR_UNARY 
(operator-, PyNumber_Negative) +PYBIND11_MATH_OPERATOR_BINARY(operator+, PyNumber_Add) +PYBIND11_MATH_OPERATOR_BINARY(operator+=, PyNumber_InPlaceAdd) +PYBIND11_MATH_OPERATOR_BINARY(operator-, PyNumber_Subtract) +PYBIND11_MATH_OPERATOR_BINARY(operator-=, PyNumber_InPlaceSubtract) +PYBIND11_MATH_OPERATOR_BINARY(operator*, PyNumber_Multiply) +PYBIND11_MATH_OPERATOR_BINARY(operator*=, PyNumber_InPlaceMultiply) +PYBIND11_MATH_OPERATOR_BINARY(operator/, PyNumber_TrueDivide) +PYBIND11_MATH_OPERATOR_BINARY(operator/=, PyNumber_InPlaceTrueDivide) +PYBIND11_MATH_OPERATOR_BINARY(operator|, PyNumber_Or) +PYBIND11_MATH_OPERATOR_BINARY(operator|=, PyNumber_InPlaceOr) +PYBIND11_MATH_OPERATOR_BINARY(operator&, PyNumber_And) +PYBIND11_MATH_OPERATOR_BINARY(operator&=, PyNumber_InPlaceAnd) +PYBIND11_MATH_OPERATOR_BINARY(operator^, PyNumber_Xor) +PYBIND11_MATH_OPERATOR_BINARY(operator^=, PyNumber_InPlaceXor) +PYBIND11_MATH_OPERATOR_BINARY(operator<<, PyNumber_Lshift) +PYBIND11_MATH_OPERATOR_BINARY(operator<<=, PyNumber_InPlaceLshift) +PYBIND11_MATH_OPERATOR_BINARY(operator>>, PyNumber_Rshift) +PYBIND11_MATH_OPERATOR_BINARY(operator>>=, PyNumber_InPlaceRshift) + +#undef PYBIND11_MATH_OPERATOR_UNARY +#undef PYBIND11_MATH_OPERATOR_BINARY + +NAMESPACE_END(detail) +NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/mmocr/models/textdet/postprocess/include/pybind11/stl.h b/mmocr/models/textdet/postprocess/include/pybind11/stl.h new file mode 100644 index 00000000..32f8d294 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/stl.h @@ -0,0 +1,386 @@ +/* + pybind11/stl.h: Transparent conversion for STL data types + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +#endif + +#ifdef __has_include +// std::optional (but including it in c++14 mode isn't allowed) +# if defined(PYBIND11_CPP17) && __has_include() +# include +# define PYBIND11_HAS_OPTIONAL 1 +# endif +// std::experimental::optional (but not allowed in c++11 mode) +# if defined(PYBIND11_CPP14) && (__has_include() && \ + !__has_include()) +# include +# define PYBIND11_HAS_EXP_OPTIONAL 1 +# endif +// std::variant +# if defined(PYBIND11_CPP17) && __has_include() +# include +# define PYBIND11_HAS_VARIANT 1 +# endif +#elif defined(_MSC_VER) && defined(PYBIND11_CPP17) +# include +# include +# define PYBIND11_HAS_OPTIONAL 1 +# define PYBIND11_HAS_VARIANT 1 +#endif + +NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +NAMESPACE_BEGIN(detail) + +/// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for +/// forwarding a container element). Typically used indirect via forwarded_type(), below. +template +using forwarded_type = conditional_t< + std::is_lvalue_reference::value, remove_reference_t &, remove_reference_t &&>; + +/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically +/// used for forwarding a container's elements. 
+template +forwarded_type forward_like(U &&u) { + return std::forward>(std::forward(u)); +} + +template struct set_caster { + using type = Type; + using key_conv = make_caster; + + bool load(handle src, bool convert) { + if (!isinstance(src)) + return false; + auto s = reinterpret_borrow(src); + value.clear(); + for (auto entry : s) { + key_conv conv; + if (!conv.load(entry, convert)) + return false; + value.insert(cast_op(std::move(conv))); + } + return true; + } + + template + static handle cast(T &&src, return_value_policy policy, handle parent) { + if (!std::is_lvalue_reference::value) + policy = return_value_policy_override::policy(policy); + pybind11::set s; + for (auto &&value : src) { + auto value_ = reinterpret_steal(key_conv::cast(forward_like(value), policy, parent)); + if (!value_ || !s.add(value_)) + return handle(); + } + return s.release(); + } + + PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]")); +}; + +template struct map_caster { + using key_conv = make_caster; + using value_conv = make_caster; + + bool load(handle src, bool convert) { + if (!isinstance(src)) + return false; + auto d = reinterpret_borrow(src); + value.clear(); + for (auto it : d) { + key_conv kconv; + value_conv vconv; + if (!kconv.load(it.first.ptr(), convert) || + !vconv.load(it.second.ptr(), convert)) + return false; + value.emplace(cast_op(std::move(kconv)), cast_op(std::move(vconv))); + } + return true; + } + + template + static handle cast(T &&src, return_value_policy policy, handle parent) { + dict d; + return_value_policy policy_key = policy; + return_value_policy policy_value = policy; + if (!std::is_lvalue_reference::value) { + policy_key = return_value_policy_override::policy(policy_key); + policy_value = return_value_policy_override::policy(policy_value); + } + for (auto &&kv : src) { + auto key = reinterpret_steal(key_conv::cast(forward_like(kv.first), policy_key, parent)); + auto value = reinterpret_steal(value_conv::cast(forward_like(kv.second), policy_value, parent)); + if (!key || !value) + return handle(); + d[key] = value; + } + return d.release(); + } + + PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]")); +}; + +template struct list_caster { + using value_conv = make_caster; + + bool load(handle src, bool convert) { + if (!isinstance(src) || isinstance(src)) + return false; + auto s = reinterpret_borrow(src); + value.clear(); + reserve_maybe(s, &value); + for (auto it : s) { + value_conv conv; + if (!conv.load(it, convert)) + return false; + value.push_back(cast_op(std::move(conv))); + } + return true; + } + +private: + template ().reserve(0)), void>::value, int> = 0> + void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); } + void reserve_maybe(sequence, void *) { } + +public: + template + static handle cast(T &&src, return_value_policy policy, handle parent) { + if (!std::is_lvalue_reference::value) + policy = return_value_policy_override::policy(policy); + list l(src.size()); + size_t index = 0; + for (auto &&value : src) { + auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); + if (!value_) + return handle(); + PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference + } + return l.release(); + } + + PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]")); +}; + +template struct type_caster> + : list_caster, Type> { }; + +template struct type_caster> + : list_caster, Type> { }; + +template struct type_caster> + : list_caster, Type> 
{ }; + +template struct array_caster { + using value_conv = make_caster; + +private: + template + bool require_size(enable_if_t size) { + if (value.size() != size) + value.resize(size); + return true; + } + template + bool require_size(enable_if_t size) { + return size == Size; + } + +public: + bool load(handle src, bool convert) { + if (!isinstance(src)) + return false; + auto l = reinterpret_borrow(src); + if (!require_size(l.size())) + return false; + size_t ctr = 0; + for (auto it : l) { + value_conv conv; + if (!conv.load(it, convert)) + return false; + value[ctr++] = cast_op(std::move(conv)); + } + return true; + } + + template + static handle cast(T &&src, return_value_policy policy, handle parent) { + list l(src.size()); + size_t index = 0; + for (auto &&value : src) { + auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); + if (!value_) + return handle(); + PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference + } + return l.release(); + } + + PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _(_(""), _("[") + _() + _("]")) + _("]")); +}; + +template struct type_caster> + : array_caster, Type, false, Size> { }; + +template struct type_caster> + : array_caster, Type, true> { }; + +template struct type_caster> + : set_caster, Key> { }; + +template struct type_caster> + : set_caster, Key> { }; + +template struct type_caster> + : map_caster, Key, Value> { }; + +template struct type_caster> + : map_caster, Key, Value> { }; + +// This type caster is intended to be used for std::optional and std::experimental::optional +template struct optional_caster { + using value_conv = make_caster; + + template + static handle cast(T_ &&src, return_value_policy policy, handle parent) { + if (!src) + return none().inc_ref(); + policy = return_value_policy_override::policy(policy); + return value_conv::cast(*std::forward(src), policy, parent); + } + + bool load(handle src, bool convert) { + if (!src) { + return false; + } else if (src.is_none()) { + return true; // default-constructed value is already empty + } + value_conv inner_caster; + if (!inner_caster.load(src, convert)) + return false; + + value.emplace(cast_op(std::move(inner_caster))); + return true; + } + + PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]")); +}; + +#if PYBIND11_HAS_OPTIONAL +template struct type_caster> + : public optional_caster> {}; + +template<> struct type_caster + : public void_caster {}; +#endif + +#if PYBIND11_HAS_EXP_OPTIONAL +template struct type_caster> + : public optional_caster> {}; + +template<> struct type_caster + : public void_caster {}; +#endif + +/// Visit a variant and cast any found type to Python +struct variant_caster_visitor { + return_value_policy policy; + handle parent; + + using result_type = handle; // required by boost::variant in C++11 + + template + result_type operator()(T &&src) const { + return make_caster::cast(std::forward(src), policy, parent); + } +}; + +/// Helper class which abstracts away variant's `visit` function. `std::variant` and similar +/// `namespace::variant` types which provide a `namespace::visit()` function are handled here +/// automatically using argument-dependent lookup. Users can provide specializations for other +/// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`. 
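+
+// A sketch of such a specialization for boost::variant, following the pattern
+// given in the pybind11 documentation (illustrative; requires the boost
+// headers):
+//
+//     namespace pybind11 { namespace detail {
+//         template <typename... Ts>
+//         struct type_caster<boost::variant<Ts...>>
+//             : variant_caster<boost::variant<Ts...>> {};
+//
+//         template <>
+//         struct visit_helper<boost::variant> {
+//             template <typename... Args>
+//             static auto call(Args &&...args)
+//                 -> decltype(boost::apply_visitor(args...)) {
+//                 return boost::apply_visitor(args...);
+//             }
+//         };
+//     }} // namespace pybind11::detail
+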
+template class Variant> +struct visit_helper { + template + static auto call(Args &&...args) -> decltype(visit(std::forward(args)...)) { + return visit(std::forward(args)...); + } +}; + +/// Generic variant caster +template struct variant_caster; + +template class V, typename... Ts> +struct variant_caster> { + static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative."); + + template + bool load_alternative(handle src, bool convert, type_list) { + auto caster = make_caster(); + if (caster.load(src, convert)) { + value = cast_op(caster); + return true; + } + return load_alternative(src, convert, type_list{}); + } + + bool load_alternative(handle, bool, type_list<>) { return false; } + + bool load(handle src, bool convert) { + // Do a first pass without conversions to improve constructor resolution. + // E.g. `py::int_(1).cast>()` needs to fill the `int` + // slot of the variant. Without two-pass loading `double` would be filled + // because it appears first and a conversion is possible. + if (convert && load_alternative(src, false, type_list{})) + return true; + return load_alternative(src, convert, type_list{}); + } + + template + static handle cast(Variant &&src, return_value_policy policy, handle parent) { + return visit_helper::call(variant_caster_visitor{policy, parent}, + std::forward(src)); + } + + using Type = V; + PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster::name...) + _("]")); +}; + +#if PYBIND11_HAS_VARIANT +template +struct type_caster> : variant_caster> { }; +#endif + +NAMESPACE_END(detail) + +inline std::ostream &operator<<(std::ostream &os, const handle &obj) { + os << (std::string) str(obj); + return os; +} + +NAMESPACE_END(PYBIND11_NAMESPACE) + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/mmocr/models/textdet/postprocess/include/pybind11/stl_bind.h b/mmocr/models/textdet/postprocess/include/pybind11/stl_bind.h new file mode 100644 index 00000000..38dd68f6 --- /dev/null +++ b/mmocr/models/textdet/postprocess/include/pybind11/stl_bind.h @@ -0,0 +1,599 @@ +/* + pybind11/std_bind.h: Binding generators for STL data types + + Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
+*/
+
+#pragma once
+
+#include "detail/common.h"
+#include "operators.h"
+
+#include <algorithm>
+#include <sstream>
+
+NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
+NAMESPACE_BEGIN(detail)
+
+/* SFINAE helper class used by 'is_comparable' */
+template <typename T> struct container_traits {
+    template <typename T2> static std::true_type test_comparable(decltype(std::declval<const T2 &>() == std::declval<const T2 &>())*);
+    template <typename T2> static std::false_type test_comparable(...);
+    template <typename T2> static std::true_type test_value(typename T2::value_type *);
+    template <typename T2> static std::false_type test_value(...);
+    template <typename T2> static std::true_type test_pair(typename T2::first_type *, typename T2::second_type *);
+    template <typename T2> static std::false_type test_pair(...);
+
+    static constexpr const bool is_comparable = std::is_same<std::true_type, decltype(test_comparable<T>(nullptr))>::value;
+    static constexpr const bool is_pair = std::is_same<std::true_type, decltype(test_pair<T>(nullptr, nullptr))>::value;
+    static constexpr const bool is_vector = std::is_same<std::true_type, decltype(test_value<T>(nullptr))>::value;
+    static constexpr const bool is_element = !is_pair && !is_vector;
+};
+
+/* Default: is_comparable -> std::false_type */
+template <typename T, typename SFINAE = void>
+struct is_comparable : std::false_type { };
+
+/* For non-map data structures, check whether operator== can be instantiated */
+template <typename T>
+struct is_comparable<
+    T, enable_if_t<container_traits<T>::is_element &&
+                   container_traits<T>::is_comparable>>
+    : std::true_type { };
+
+/* For a vector/map data structure, recursively check the value type (which is std::pair for maps) */
+template <typename T>
+struct is_comparable<T, enable_if_t<container_traits<T>::is_vector>> {
+    static constexpr const bool value =
+        is_comparable<typename T::value_type>::value;
+};
+
+/* For pairs, recursively check the two data types */
+template <typename T>
+struct is_comparable<T, enable_if_t<container_traits<T>::is_pair>> {
+    static constexpr const bool value =
+        is_comparable<typename T::first_type>::value &&
+        is_comparable<typename T::second_type>::value;
+};
+
+/* Fallback functions */
+template <typename, typename, typename... Args> void vector_if_copy_constructible(const Args &...) { }
+template <typename, typename, typename... Args> void vector_if_equal_operator(const Args &...) { }
+template <typename, typename, typename... Args> void vector_if_insertion_operator(const Args &...) { }
+template <typename, typename, typename... Args> void vector_modifiers(const Args &...) { }
+
+template <typename Vector, typename Class_>
+void vector_if_copy_constructible(enable_if_t<is_copy_constructible<Vector>::value, Class_> &cl) {
+    cl.def(init<const Vector &>(), "Copy constructor");
+}
+
+template <typename Vector, typename Class_>
+void vector_if_equal_operator(enable_if_t<is_comparable<Vector>::value, Class_> &cl) {
+    using T = typename Vector::value_type;
+
+    cl.def(self == self);
+    cl.def(self != self);
+
+    cl.def("count",
+        [](const Vector &v, const T &x) {
+            return std::count(v.begin(), v.end(), x);
+        },
+        arg("x"),
+        "Return the number of times ``x`` appears in the list"
+    );
+
+    cl.def("remove", [](Vector &v, const T &x) {
+            auto p = std::find(v.begin(), v.end(), x);
+            if (p != v.end())
+                v.erase(p);
+            else
+                throw value_error();
+        },
+        arg("x"),
+        "Remove the first item from the list whose value is x. "
+        "It is an error if there is no such item."
+    );
+
+    cl.def("__contains__",
+        [](const Vector &v, const T &x) {
+            return std::find(v.begin(), v.end(), x) != v.end();
+        },
+        arg("x"),
+        "Return true if the container contains ``x``"
+    );
+}
+
+// Vector modifiers -- requires a copyable vector_type:
+// (Technically, some of these (pop and __delitem__) don't actually require copyability, but it seems
+// silly to allow deletion but not insertion, so include them here too.)
+template +void vector_modifiers(enable_if_t::value, Class_> &cl) { + using T = typename Vector::value_type; + using SizeType = typename Vector::size_type; + using DiffType = typename Vector::difference_type; + + cl.def("append", + [](Vector &v, const T &value) { v.push_back(value); }, + arg("x"), + "Add an item to the end of the list"); + + cl.def(init([](iterable it) { + auto v = std::unique_ptr(new Vector()); + v->reserve(len(it)); + for (handle h : it) + v->push_back(h.cast()); + return v.release(); + })); + + cl.def("extend", + [](Vector &v, const Vector &src) { + v.insert(v.end(), src.begin(), src.end()); + }, + arg("L"), + "Extend the list by appending all the items in the given list" + ); + + cl.def("insert", + [](Vector &v, SizeType i, const T &x) { + if (i > v.size()) + throw index_error(); + v.insert(v.begin() + (DiffType) i, x); + }, + arg("i") , arg("x"), + "Insert an item at a given position." + ); + + cl.def("pop", + [](Vector &v) { + if (v.empty()) + throw index_error(); + T t = v.back(); + v.pop_back(); + return t; + }, + "Remove and return the last item" + ); + + cl.def("pop", + [](Vector &v, SizeType i) { + if (i >= v.size()) + throw index_error(); + T t = v[i]; + v.erase(v.begin() + (DiffType) i); + return t; + }, + arg("i"), + "Remove and return the item at index ``i``" + ); + + cl.def("__setitem__", + [](Vector &v, SizeType i, const T &t) { + if (i >= v.size()) + throw index_error(); + v[i] = t; + } + ); + + /// Slicing protocol + cl.def("__getitem__", + [](const Vector &v, slice slice) -> Vector * { + size_t start, stop, step, slicelength; + + if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) + throw error_already_set(); + + Vector *seq = new Vector(); + seq->reserve((size_t) slicelength); + + for (size_t i=0; ipush_back(v[start]); + start += step; + } + return seq; + }, + arg("s"), + "Retrieve list elements using a slice object" + ); + + cl.def("__setitem__", + [](Vector &v, slice slice, const Vector &value) { + size_t start, stop, step, slicelength; + if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) + throw error_already_set(); + + if (slicelength != value.size()) + throw std::runtime_error("Left and right hand size of slice assignment have different sizes!"); + + for (size_t i=0; i= v.size()) + throw index_error(); + v.erase(v.begin() + DiffType(i)); + }, + "Delete the list elements at index ``i``" + ); + + cl.def("__delitem__", + [](Vector &v, slice slice) { + size_t start, stop, step, slicelength; + + if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) + throw error_already_set(); + + if (step == 1 && false) { + v.erase(v.begin() + (DiffType) start, v.begin() + DiffType(start + slicelength)); + } else { + for (size_t i = 0; i < slicelength; ++i) { + v.erase(v.begin() + DiffType(start)); + start += step - 1; + } + } + }, + "Delete list elements using a slice object" + ); + +} + +// If the type has an operator[] that doesn't return a reference (most notably std::vector), +// we have to access by copying; otherwise we return by reference. 
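+
+// For example (illustrative): std::vector<bool>::operator[] returns a proxy
+// object rather than a bool&, so vector_needs_copy<std::vector<bool>> holds
+// and the by-copy accessor further below is selected for it.
+//
+// Putting the pieces of this header together, a minimal binding sketch
+// (module and type names are assumptions, following the usual pybind11
+// pattern):
+//
+//     #include <pybind11/stl_bind.h>
+//
+//     PYBIND11_MAKE_OPAQUE(std::vector<int>);
+//
+//     PYBIND11_MODULE(example, m) {
+//         py::bind_vector<std::vector<int>>(m, "VectorInt");
+//     }
+//
+// Python code can then use VectorInt like a list (append, extend, pop and
+// slicing via the modifiers above), and ==/!=/count/remove are added because
+// int satisfies is_comparable.
+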
+template using vector_needs_copy = negation< + std::is_same()[typename Vector::size_type()]), typename Vector::value_type &>>; + +// The usual case: access and iterate by reference +template +void vector_accessor(enable_if_t::value, Class_> &cl) { + using T = typename Vector::value_type; + using SizeType = typename Vector::size_type; + using ItType = typename Vector::iterator; + + cl.def("__getitem__", + [](Vector &v, SizeType i) -> T & { + if (i >= v.size()) + throw index_error(); + return v[i]; + }, + return_value_policy::reference_internal // ref + keepalive + ); + + cl.def("__iter__", + [](Vector &v) { + return make_iterator< + return_value_policy::reference_internal, ItType, ItType, T&>( + v.begin(), v.end()); + }, + keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + ); +} + +// The case for special objects, like std::vector, that have to be returned-by-copy: +template +void vector_accessor(enable_if_t::value, Class_> &cl) { + using T = typename Vector::value_type; + using SizeType = typename Vector::size_type; + using ItType = typename Vector::iterator; + cl.def("__getitem__", + [](const Vector &v, SizeType i) -> T { + if (i >= v.size()) + throw index_error(); + return v[i]; + } + ); + + cl.def("__iter__", + [](Vector &v) { + return make_iterator< + return_value_policy::copy, ItType, ItType, T>( + v.begin(), v.end()); + }, + keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ + ); +} + +template auto vector_if_insertion_operator(Class_ &cl, std::string const &name) + -> decltype(std::declval() << std::declval(), void()) { + using size_type = typename Vector::size_type; + + cl.def("__repr__", + [name](Vector &v) { + std::ostringstream s; + s << name << '['; + for (size_type i=0; i < v.size(); ++i) { + s << v[i]; + if (i != v.size() - 1) + s << ", "; + } + s << ']'; + return s.str(); + }, + "Return the canonical string representation of this list." 
+ ); +} + +// Provide the buffer interface for vectors if we have data() and we have a format for it +// GCC seems to have "void std::vector::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer +template +struct vector_has_data_and_format : std::false_type {}; +template +struct vector_has_data_and_format::format(), std::declval().data()), typename Vector::value_type*>::value>> : std::true_type {}; + +// Add the buffer interface to a vector +template +enable_if_t...>::value> +vector_buffer(Class_& cl) { + using T = typename Vector::value_type; + + static_assert(vector_has_data_and_format::value, "There is not an appropriate format descriptor for this vector"); + + // numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here + format_descriptor::format(); + + cl.def_buffer([](Vector& v) -> buffer_info { + return buffer_info(v.data(), static_cast(sizeof(T)), format_descriptor::format(), 1, {v.size()}, {sizeof(T)}); + }); + + cl.def(init([](buffer buf) { + auto info = buf.request(); + if (info.ndim != 1 || info.strides[0] % static_cast(sizeof(T))) + throw type_error("Only valid 1D buffers can be copied to a vector"); + if (!detail::compare_buffer_info::compare(info) || (ssize_t) sizeof(T) != info.itemsize) + throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor::format() + ")"); + + auto vec = std::unique_ptr(new Vector()); + vec->reserve((size_t) info.shape[0]); + T *p = static_cast(info.ptr); + ssize_t step = info.strides[0] / static_cast(sizeof(T)); + T *end = p + info.shape[0] * step; + for (; p != end; p += step) + vec->push_back(*p); + return vec.release(); + })); + + return; +} + +template +enable_if_t...>::value> vector_buffer(Class_&) {} + +NAMESPACE_END(detail) + +// +// std::vector +// +template , typename... Args> +class_ bind_vector(handle scope, std::string const &name, Args&&... args) { + using Class_ = class_; + + // If the value_type is unregistered (e.g. 
+
+//
+// std::vector
+//
+template <typename Vector, typename holder_type = std::unique_ptr<Vector>, typename... Args>
+class_<Vector, holder_type> bind_vector(handle scope, std::string const &name, Args&&... args) {
+    using Class_ = class_<Vector, holder_type>;
+
+    // If the value_type is unregistered (e.g. a converting type) or is itself registered
+    // module-local then make the vector binding module-local as well:
+    using vtype = typename Vector::value_type;
+    auto vtype_info = detail::get_type_info(typeid(vtype));
+    bool local = !vtype_info || vtype_info->module_local;
+
+    Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
+
+    // Declare the buffer interface if a buffer_protocol() is passed in
+    detail::vector_buffer<Vector, Class_, Args...>(cl);
+
+    cl.def(init<>());
+
+    // Register copy constructor (if possible)
+    detail::vector_if_copy_constructible<Vector, Class_>(cl);
+
+    // Register comparison-related operators and functions (if possible)
+    detail::vector_if_equal_operator<Vector, Class_>(cl);
+
+    // Register stream insertion operator (if possible)
+    detail::vector_if_insertion_operator<Vector, Class_>(cl, name);
+
+    // Modifiers require copyable vector value type
+    detail::vector_modifiers<Vector, Class_>(cl);
+
+    // Accessor and iterator; return by value if copyable, otherwise we return by ref + keep-alive
+    detail::vector_accessor<Vector, Class_>(cl);
+
+    cl.def("__bool__",
+        [](const Vector &v) -> bool {
+            return !v.empty();
+        },
+        "Check whether the list is nonempty"
+    );
+
+    cl.def("__len__", &Vector::size);
+
+#if 0
+    // C++ style functions deprecated, leaving it here as an example
+    cl.def(init<size_type>());
+
+    cl.def("resize",
+         (void (Vector::*) (size_type count)) & Vector::resize,
+         "changes the number of elements stored");
+
+    cl.def("erase",
+        [](Vector &v, SizeType i) {
+        if (i >= v.size())
+            throw index_error();
+        v.erase(v.begin() + i);
+    }, "erases element at index ``i``");
+
+    cl.def("empty", &Vector::empty, "checks whether the container is empty");
+    cl.def("size", &Vector::size, "returns the number of elements");
+    cl.def("push_back", (void (Vector::*)(const T&)) &Vector::push_back, "adds an element to the end");
+    cl.def("pop_back", &Vector::pop_back, "removes the last element");
+
+    cl.def("max_size", &Vector::max_size, "returns the maximum possible number of elements");
+    cl.def("reserve", &Vector::reserve, "reserves storage");
+    cl.def("capacity", &Vector::capacity, "returns the number of elements that can be held in currently allocated storage");
+    cl.def("shrink_to_fit", &Vector::shrink_to_fit, "reduces memory usage by freeing unused memory");
+
+    cl.def("clear", &Vector::clear, "clears the contents");
+    cl.def("swap", &Vector::swap, "swaps the contents");
+
+    cl.def("front", [](Vector &v) {
+        if (v.size()) return v.front();
+        else throw index_error();
+    }, "access the first element");
+
+    cl.def("back", [](Vector &v) {
+        if (v.size()) return v.back();
+        else throw index_error();
+    }, "access the last element ");
+
+#endif
+
+    return cl;
+}
+
+
+
+//
+// std::map, std::unordered_map
+//
+
+NAMESPACE_BEGIN(detail)
+
+/* Fallback functions */
+template <typename, typename, typename... Args> void map_if_insertion_operator(const Args &...) { }
+template <typename, typename, typename... Args> void map_assignment(const Args &...) { }
+
+// Map assignment when copy-assignable: just copy the value
+template <typename Map, typename Class_>
+void map_assignment(enable_if_t<std::is_copy_assignable<typename Map::mapped_type>::value, Class_> &cl) {
+    using KeyType = typename Map::key_type;
+    using MappedType = typename Map::mapped_type;
+
+    cl.def("__setitem__",
+           [](Map &m, const KeyType &k, const MappedType &v) {
+               auto it = m.find(k);
+               if (it != m.end()) it->second = v;
+               else m.emplace(k, v);
+           }
+    );
+}
+
+// Not copy-assignable, but still copy-constructible: we can update the value by erasing and reinserting
+template <typename Map, typename Class_>
+void map_assignment(enable_if_t<
+        !std::is_copy_assignable<typename Map::mapped_type>::value &&
+        is_copy_constructible<typename Map::mapped_type>::value,
+        Class_> &cl) {
+    using KeyType = typename Map::key_type;
+    using MappedType = typename Map::mapped_type;
+
+    cl.def("__setitem__",
+           [](Map &m, const KeyType &k, const MappedType &v) {
+               // We can't use m[k] = v; because the value type might not be default constructible
+               auto r = m.emplace(k, v);
+               if (!r.second) {
+                   // value type is not copy assignable so the only way to insert it is to erase it first...
+                   m.erase(r.first);
+                   m.emplace(k, v);
+               }
+           }
+    );
+}
+
+
+template <typename Map, typename Class_> auto map_if_insertion_operator(Class_ &cl, std::string const &name)
+-> decltype(std::declval<std::ostream&>() << std::declval<typename Map::key_type>() << std::declval<typename Map::mapped_type>(), void()) {
+
+    cl.def("__repr__",
+           [name](Map &m) {
+            std::ostringstream s;
+            s << name << '{';
+            bool f = false;
+            for (auto const &kv : m) {
+                if (f)
+                    s << ", ";
+                s << kv.first << ": " << kv.second;
+                f = true;
+            }
+            s << '}';
+            return s.str();
+        },
+        "Return the canonical string representation of this map."
+    );
+}
+
+
+NAMESPACE_END(detail)
+
+template <typename Map, typename holder_type = std::unique_ptr<Map>, typename... Args>
+class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&... args) {
+    using KeyType = typename Map::key_type;
+    using MappedType = typename Map::mapped_type;
+    using Class_ = class_<Map, holder_type>;
+
+    // If either type is a non-module-local bound type then make the map binding non-local as well;
+    // otherwise (e.g. both types are either module-local or converting) the map will be
+    // module-local.
+    auto tinfo = detail::get_type_info(typeid(MappedType));
+    bool local = !tinfo || tinfo->module_local;
+    if (local) {
+        tinfo = detail::get_type_info(typeid(KeyType));
+        local = !tinfo || tinfo->module_local;
+    }
+
+    Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
+
+    cl.def(init<>());
+
+    // Register stream insertion operator (if possible)
+    detail::map_if_insertion_operator<Map, Class_>(cl, name);
+
+    cl.def("__bool__",
+        [](const Map &m) -> bool { return !m.empty(); },
+        "Check whether the map is nonempty"
+    );
+
+    cl.def("__iter__",
+           [](Map &m) { return make_key_iterator(m.begin(), m.end()); },
+           keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
+    );
+
+    cl.def("items",
+           [](Map &m) { return make_iterator(m.begin(), m.end()); },
+           keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
+    );
+
+    cl.def("__getitem__",
+        [](Map &m, const KeyType &k) -> MappedType & {
+            auto it = m.find(k);
+            if (it == m.end())
+                throw key_error();
+            return it->second;
+        },
+        return_value_policy::reference_internal // ref + keepalive
+    );
+
+    // Assignment provided only if the type is copyable
+    detail::map_assignment<Map, Class_>(cl);
+
+    cl.def("__delitem__",
+           [](Map &m, const KeyType &k) {
+               auto it = m.find(k);
+               if (it == m.end())
+                   throw key_error();
+               m.erase(it);
+           }
+    );
+
+    cl.def("__len__", &Map::size);
+
+    return cl;
+}
+
+NAMESPACE_END(PYBIND11_NAMESPACE)
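`bind_vector` and `bind_map` are the entry points of this header: they register a C++ container as a Python class with list-like or dict-like behavior (`__len__`, `__bool__`, `__getitem__`, `__iter__`, plus `__setitem__`/`__repr__` when the element types allow it). A minimal sketch of the resulting Python-side semantics, assuming a hypothetical extension module `demo` that calls `py::bind_vector<std::vector<int>>(m, "IntVector")` and `py::bind_map<std::map<std::string, int>>(m, "StrIntMap")`:

    # Hypothetical module/class names; the semantics follow stl_bind.h above.
    from demo import IntVector, StrIntMap

    v = IntVector()            # cl.def(init<>())
    print(len(v), bool(v))     # __len__ / __bool__  -> 0 False
    m = StrIntMap()
    m['a'] = 1                 # __setitem__ (copy-assignable mapped type)
    for key in m:              # __iter__ yields keys, as for a Python dict
        print(key, m[key])     # __getitem__ returns a reference kept alive by the map
    del m['a']                 # __delitem__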
diff --git a/mmocr/models/textdet/postprocess/include/pybind11/typeid.h b/mmocr/models/textdet/postprocess/include/pybind11/typeid.h
new file mode 100644
index 00000000..c903fb14
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/include/pybind11/typeid.h
@@ -0,0 +1,53 @@
+/*
+    pybind11/typeid.h: Compiler-independent access to type identifiers
+
+    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
+
+    All rights reserved. Use of this source code is governed by a
+    BSD-style license that can be found in the LICENSE file.
+*/
+
+#pragma once
+
+#include <cstdio>
+#include <cstdlib>
+
+#if defined(__GNUG__)
+#include <cxxabi.h>
+#endif
+
+NAMESPACE_BEGIN(pybind11)
+NAMESPACE_BEGIN(detail)
+/// Erase all occurrences of a substring
+inline void erase_all(std::string &string, const std::string &search) {
+    for (size_t pos = 0;;) {
+        pos = string.find(search, pos);
+        if (pos == std::string::npos) break;
+        string.erase(pos, search.length());
+    }
+}
+
+PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
+#if defined(__GNUG__)
+    int status = 0;
+    std::unique_ptr<char, void (*)(void *)> res {
+        abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
+    if (status == 0)
+        name = res.get();
+#else
+    detail::erase_all(name, "class ");
+    detail::erase_all(name, "struct ");
+    detail::erase_all(name, "enum ");
+#endif
+    detail::erase_all(name, "pybind11::");
+}
+NAMESPACE_END(detail)
+
+/// Return a string representation of a C++ type
+template <typename T> static std::string type_id() {
+    std::string name(typeid(T).name());
+    detail::clean_type_id(name);
+    return name;
+}
+
+NAMESPACE_END(pybind11)
diff --git a/mmocr/models/textdet/postprocess/pan.cpp b/mmocr/models/textdet/postprocess/pan.cpp
new file mode 100644
index 00000000..874ae0a6
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/pan.cpp
@@ -0,0 +1,194 @@
+// This implementation is from https://github.com/WenmuZhou/PAN.pytorch/blob/master/post_processing/pse.cpp
+
+#include <queue>
+#include <vector>
+#include <map>
+#include <tuple>
+#include <cmath>
+#include <iostream>
+#include "include/pybind11/pybind11.h"
+#include "include/pybind11/numpy.h"
+#include "include/pybind11/stl.h"
+#include "include/pybind11/stl_bind.h"
+
+namespace py = pybind11;
+
+
+namespace panet{
+    py::array_t<int32_t> assign_pixels(
+        py::array_t<float, py::array::c_style> text,
+        py::array_t<float, py::array::c_style> similarity_vectors,
+        py::array_t<int32_t, py::array::c_style> label_map,
+        int label_num,
+        float dis_threshold = 0.8)
+    {
+        auto pbuf_text = text.request();
+        auto pbuf_similarity_vectors = similarity_vectors.request();
+        auto pbuf_label_map = label_map.request();
+        if (pbuf_label_map.ndim != 2 || pbuf_label_map.shape[0]==0 || pbuf_label_map.shape[1]==0)
+            throw std::runtime_error("label map must have a shape of (h>0, w>0)");
+        int h = pbuf_label_map.shape[0];
+        int w = pbuf_label_map.shape[1];
+        if (pbuf_similarity_vectors.ndim != 3 || pbuf_similarity_vectors.shape[0]!=h || pbuf_similarity_vectors.shape[1]!=w || pbuf_similarity_vectors.shape[2]!=4 ||
+            pbuf_text.shape[0]!=h || pbuf_text.shape[1]!=w)
+            throw std::runtime_error("similarity_vectors must have a shape of (h,w,4) and text must have a shape of (h,w)");
+        // initialize the result array
+        auto res = py::array_t<int32_t>(pbuf_text.size);
+        auto pbuf_res = res.request();
+        // get pointers to text, similarity_vectors and label_map
+        auto ptr_label_map = static_cast<int32_t *>(pbuf_label_map.ptr);
+        auto ptr_text = static_cast<float *>(pbuf_text.ptr);
+        auto ptr_similarity_vectors = static_cast<float *>(pbuf_similarity_vectors.ptr);
+        auto ptr_res = static_cast<int32_t *>(pbuf_res.ptr);
+
+        std::queue<std::tuple<int, int, int32_t>> q;
+        // accumulate the similarity vector of each kernel
+        float kernel_vector[label_num][5] = {0};
+
+        // push the text pixels of every kernel into the queue
+        for (int i = 0; i<h; i++)
+        {
+            auto p_label_map = ptr_label_map + i*w;
+            auto p_similarity_vectors = ptr_similarity_vectors + i*w*4;
+            auto p_res = ptr_res + i*w;
+            for (int j = 0, k = 0; j<w && k<w*4; j++, k+=4)
+            {
+                int32_t label = p_label_map[j];
+                if (label>0)
+                {
+                    kernel_vector[label][0] += p_similarity_vectors[k];
+                    kernel_vector[label][1] += p_similarity_vectors[k+1];
+                    kernel_vector[label][2] += p_similarity_vectors[k+2];
+                    kernel_vector[label][3] += p_similarity_vectors[k+3];
+                    kernel_vector[label][4] += 1;
+                    q.push(std::make_tuple(i, j, label));
+                }
+                p_res[j] = label;
+            }
+        }
+
+        // average the accumulated similarity vectors
+        for (int i=0; i<label_num; i++)
+        {
+            for (int j=0; j<4; j++)
+            {
+                kernel_vector[i][j] /= kernel_vector[i][4];
+            }
+        }
+
+        int dx[4] = {-1, 1, 0, 0};
+        int dy[4] = {0, 0, -1, 1};
+
+        // breadth-first expansion of every kernel over the text region
+        while (!q.empty())
+        {
+            auto q_n = q.front();
+            q.pop();
+            int y = std::get<0>(q_n);
+            int x = std::get<1>(q_n);
+            int32_t l = std::get<2>(q_n);
+            //store the edge pixel after one expansion
+            auto kernel_cv = kernel_vector[l];
+            for (int idx=0; idx<4; idx++)
+            {
+                int tmpy = y + dy[idx];
+                int tmpx = x + dx[idx];
+                auto p_res = ptr_res + tmpy*w;
+                if (tmpy<0 || tmpy>=h || tmpx<0 || tmpx>=w)
+                    continue;
+                if (!ptr_text[tmpy*w+tmpx] || p_res[tmpx]>0)
+                    continue;
+                // compute the distance to the kernel's mean similarity vector
+                float dis = 0;
+                auto p_similarity_vectors = ptr_similarity_vectors + tmpy * w*4;
+                for (size_t i=0; i<4; i++)
+                {
+                    dis += pow(kernel_cv[i] - p_similarity_vectors[tmpx*4 + i], 2);
+                }
+                dis = sqrt(dis);
+                if (dis >= dis_threshold)
+                    continue;
+                q.push(std::make_tuple(tmpy, tmpx, l));
+                p_res[tmpx] = l;
+            }
+        }
+        return res;
+    }
+
+    std::map<int, std::vector<float>> estimate_text_confidence(
+        py::array_t<int32_t, py::array::c_style> label_map,
+        py::array_t<float, py::array::c_style> score_map,
+        int label_num)
+    {
+        auto pbuf_label_map = label_map.request();
+        auto pbuf_score_map = score_map.request();
+        auto ptr_label_map = static_cast<int32_t *>(pbuf_label_map.ptr);
+        auto ptr_score_map = static_cast<float *>(pbuf_score_map.ptr);
+        int h = pbuf_label_map.shape[0];
+        int w = pbuf_label_map.shape[1];
+
+        std::map<int, std::vector<float>> point_dict;
+        std::vector<std::vector<float>> point_vector;
+        for (int i = 0; i<label_num; i++)
+        {
+            std::vector<float> point;
+            point.push_back(0);
+            point.push_back(0);
+            point_vector.push_back(point);
+        }
+        // accumulate the score and the coordinates of every labeled pixel
+        for (int i = 0; i<h; i++)
+        {
+            auto p_label_map = ptr_label_map + i*w;
+            auto p_score_map = ptr_score_map + i*w;
+            for (int j = 0; j<w; j++)
+            {
+                int32_t label = p_label_map[j];
+                if (label==0) continue;
+                point_vector[label][0] += p_score_map[j];
+                point_vector[label][1] += 1;
+                point_vector[label].push_back(j);
+                point_vector[label].push_back(i);
+            }
+        }
+        // keep instances covering more than two pixels; average their scores
+        for (int i = 0; i<label_num; i++)
+        {
+            if (point_vector[i][1] > 2)
+            {
+                point_vector[i][0] /= point_vector[i][1];
+                point_dict[i] = point_vector[i];
+            }
+        }
+        return point_dict;
+    }
+
+    std::vector<int> get_pixel_num(
+        py::array_t<int32_t, py::array::c_style> label_map,
+        int label_num)
+    {
+        auto pbuf_label_map = label_map.request();
+        auto ptr_label_map = static_cast<int32_t *>(pbuf_label_map.ptr);
+        int h = pbuf_label_map.shape[0];
+        int w = pbuf_label_map.shape[1];
+
+        std::vector<int> point_vector;
+        for (int i = 0; i<label_num; i++)
+        {
+            point_vector.push_back(0);
+        }
+        // count the pixels of every label
+        for (int i = 0; i<h; i++)
+        {
+            auto p_label_map = ptr_label_map + i*w;
+            for (int j = 0; j<w; j++)
+            {
+                int32_t label = p_label_map[j];
+                point_vector[label] += 1;
+            }
+        }
+        return point_vector;
+    }
+}
+
+PYBIND11_MODULE(pan, m) {
+    m.def("assign_pixels", &panet::assign_pixels,
+          "assign text pixels to kernels", py::arg("text"),
+          py::arg("similarity_vectors"), py::arg("label_map"),
+          py::arg("label_num"), py::arg("dis_threshold") = 0.8);
+    m.def("estimate_text_confidence", &panet::estimate_text_confidence,
+          "estimate the confidence and pixel coordinates of each instance",
+          py::arg("label_map"), py::arg("score_map"), py::arg("label_num"));
+    m.def("get_pixel_num", &panet::get_pixel_num, "count pixels per label",
+          py::arg("label_map"), py::arg("label_num"));
+}
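The functions exported above are consumed by `pan_decode` in wrapper.py further below; a minimal standalone sketch of the calling convention, assuming the extension has been built as `mmocr.models.textdet.postprocess.pan` and using tiny dummy inputs:

    # Dummy inputs only; dtypes must match the C++ signatures above
    # (float32 scores/embeddings, int32 label map).
    import numpy as np
    from mmocr.models.textdet.postprocess.pan import (
        assign_pixels, estimate_text_confidence, get_pixel_num)

    h, w, label_num = 8, 8, 2                  # background + one kernel
    text = np.ones((h, w), dtype=np.float32)   # text-region mask
    embeddings = np.zeros((h, w, 4), dtype=np.float32)
    labels = np.zeros((h, w), dtype=np.int32)
    labels[3:5, 3:5] = 1                       # a single kernel

    print(get_pixel_num(labels, label_num))    # pixel count per label
    assignment = assign_pixels(text, embeddings, labels, label_num, 0.8)
    assignment = assignment.reshape((h, w))    # returned flat, as in wrapper.py
    conf = estimate_text_confidence(assignment, np.ones((h, w), np.float32),
                                    label_num)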
diff --git a/mmocr/models/textdet/postprocess/pse.cpp b/mmocr/models/textdet/postprocess/pse.cpp
new file mode 100644
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/pse.cpp
+// This implementation is from https://github.com/whai362/PSENet
+#include "include/pybind11/pybind11.h"
+#include "include/pybind11/numpy.h"
+#include "include/pybind11/stl.h"
+#include "include/pybind11/stl_bind.h"
+#include <cstring>
+#include <iostream>
+#include <queue>
+
+using namespace std;
+
+namespace py = pybind11;
+
+namespace pse_adaptor {
+
+    class Point2d {
+    public:
+        int x;
+        int y;
+
+        Point2d() : x(0), y(0)
+        {}
+
+        Point2d(int _x, int _y) : x(_x), y(_y)
+        {}
+    };
+
+    void growing_text_line(const int *data,
+                           vector<long int> &data_shape,
+                           const int *label_map,
+                           vector<long int> &label_shape,
+                           int &label_num,
+                           float &min_area,
+                           vector<vector<int>> &text_line) {
+        int area[label_num + 1];
+        memset(area, 0, sizeof(area));
+        for (int x = 0; x < label_shape[0]; ++x) {
+            for (int y = 0; y < label_shape[1]; ++y) {
+                int label = label_map[x * label_shape[1] + y];
+                if (label == 0) continue;
+                area[label] += 1;
+            }
+        }
+
+        queue<Point2d> queue, next_queue;
+        for (int x = 0; x < label_shape[0]; ++x) {
+            vector<int> row(label_shape[1]);
+            for (int y = 0; y < label_shape[1]; ++y) {
+                int label = label_map[x * label_shape[1] + y];
+                if (label == 0) continue;
+                if (area[label] < min_area) continue;
+
+                Point2d point(x, y);
+                queue.push(point);
+                row[y] = label;
+            }
+            text_line.emplace_back(row);
+        }
+
+        int dx[] = {-1, 1, 0, 0};
+        int dy[] = {0, 0, -1, 1};
+
+        for (int kernel_id = data_shape[0] - 2; kernel_id >= 0; --kernel_id) {
+            while (!queue.empty()) {
+                Point2d point = queue.front();
+                queue.pop();
+                int x = point.x;
+                int y = point.y;
+                int label = text_line[x][y];
+
+                bool is_edge = true;
+                for (int d = 0; d < 4; ++d) {
+                    int tmp_x = x + dx[d];
+                    int tmp_y = y + dy[d];
+
+                    if (tmp_x < 0 || tmp_x >= (int)text_line.size()) continue;
+                    if (tmp_y < 0 || tmp_y >= (int)text_line[1].size()) continue;
+                    int kernel_value = data[kernel_id * data_shape[1] * data_shape[2] + tmp_x * data_shape[2] + tmp_y];
+                    if (kernel_value == 0) continue;
+                    if (text_line[tmp_x][tmp_y] > 0) continue;
+
+                    Point2d point(tmp_x, tmp_y);
+                    queue.push(point);
+                    text_line[tmp_x][tmp_y] = label;
+                    is_edge = false;
+                }
+
+                if (is_edge) {
+                    next_queue.push(point);
+                }
+            }
+            swap(queue, next_queue);
+        }
+    }
+
+    vector<vector<int>> pse(py::array_t<int, py::array::c_style> quad_n9,
+                            float min_area,
+                            py::array_t<int, py::array::c_style> label_map,
+                            int label_num) {
+        auto buf = quad_n9.request();
+        auto data = static_cast<const int *>(buf.ptr);
+        vector<long int> data_shape = buf.shape;
+
+        auto buf_label_map = label_map.request();
+        auto data_label_map = static_cast<const int *>(buf_label_map.ptr);
+        vector<long int> label_map_shape = buf_label_map.shape;
+
+        vector<vector<int>> text_line;
+
+        growing_text_line(data,
+                          data_shape,
+                          data_label_map,
+                          label_map_shape,
+                          label_num,
+                          min_area,
+                          text_line);
+
+        return text_line;
+    }
+}
+
+PYBIND11_PLUGIN(pse) {
+    py::module m("pse", "pse");
+
+    m.def("pse", &pse_adaptor::pse, "pse");
+
+    return m.ptr();
+}
diff --git a/mmocr/models/textdet/postprocess/wrapper.py b/mmocr/models/textdet/postprocess/wrapper.py
new file mode 100644
index 00000000..0a749465
--- /dev/null
+++ b/mmocr/models/textdet/postprocess/wrapper.py
@@ -0,0 +1,394 @@
+import cv2
+import numpy as np
+import pyclipper
+import torch
+from numpy.linalg import norm
+from shapely.geometry import Polygon
+from skimage.morphology import skeletonize
+
+from mmocr.core import points2boundary
+
+
+def filter_instance(area, confidence, min_area, min_confidence):
+    return bool(area < min_area or confidence < min_confidence)
+
+
+def decode(
+        decoding_type='pan',  # 'pan', 'pse', 'db' or 'textsnake'
+        **kwargs):
+    if decoding_type == 'pan':
+        return pan_decode(**kwargs)
+    if decoding_type == 'pse':
+        return pse_decode(**kwargs)
+    if decoding_type == 'db':
+        return db_decode(**kwargs)
+    if decoding_type == 'textsnake':
+        return textsnake_decode(**kwargs)
+
+    raise NotImplementedError
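+
+
+# A minimal usage sketch of the dispatcher above (illustrative values only;
+# `preds` would normally be the raw output of a detection head):
+#
+#     preds = torch.rand(6, 160, 160)  # e.g. a PANet head output
+#     boundaries = decode(decoding_type='pan', preds=preds)
+#     # each boundary is a flat list [x1, y1, ..., xn, yn, confidence]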
+
+
+def pan_decode(preds,
+               text_repr_type='poly',
+               min_text_confidence=0.5,
+               min_kernel_confidence=0.5,
+               min_text_avg_confidence=0.85,
+               min_kernel_area=0,
+               min_text_area=16):
+    """Convert scores to quadrangles via post processing in PANet. This is
+    partially adapted from https://github.com/WenmuZhou/PAN.pytorch.
+
+    Args:
+        preds (tensor): The head output tensor of size 6xHxW.
+        text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
+        min_text_confidence (float): The minimal text confidence.
+        min_kernel_confidence (float): The minimal kernel confidence.
+        min_text_avg_confidence (float): The minimal text average confidence.
+        min_kernel_area (int): The minimal text kernel area.
+        min_text_area (int): The minimal text instance region area.
+    Returns:
+        boundaries: (list[list[float]]): The instance boundary and its
+            instance confidence list.
+    """
+    from .pan import assign_pixels, estimate_text_confidence, get_pixel_num
+    preds[:2, :, :] = torch.sigmoid(preds[:2, :, :])
+    preds = preds.detach().cpu().numpy()
+
+    text_score = preds[0].astype(np.float32)
+    text = preds[0] > min_text_confidence
+    kernel = (preds[1] > min_kernel_confidence) * text
+    embeddings = preds[2:].transpose((1, 2, 0))  # (h, w, 4)
+
+    region_num, labels = cv2.connectedComponents(
+        kernel.astype(np.uint8), connectivity=4)
+    valid_kernel_inx = []
+    region_pixel_num = get_pixel_num(labels, region_num)
+
+    # kernel labels start from index 1; label 0 is the background
+    for region_idx in range(1, region_num):
+        if region_pixel_num[region_idx] < min_kernel_area:
+            continue
+        valid_kernel_inx.append(region_idx)
+
+    # assign pixels to valid kernels
+    assignment = assign_pixels(
+        text.astype(np.uint8), embeddings, labels, region_num, 0.8)
+    assignment = assignment.reshape(text.shape)
+
+    boundaries = []
+
+    # compute text avg confidence
+    text_points = estimate_text_confidence(assignment, text_score, region_num)
+    for text_inx, text_point in text_points.items():
+        if text_inx not in valid_kernel_inx:
+            continue
+        text_confidence = text_point[0]
+        text_point = text_point[2:]
+        text_point = np.array(text_point, dtype=int).reshape(-1, 2)
+        area = text_point.shape[0]
+
+        if filter_instance(area, text_confidence, min_text_area,
+                           min_text_avg_confidence):
+            continue
+        vertices_confidence = points2boundary(text_point, text_repr_type,
+                                              text_confidence)
+        if vertices_confidence is not None:
+            boundaries.append(vertices_confidence)
+
+    return boundaries
+
+
+def pse_decode(preds,
+               text_repr_type='poly',
+               min_kernel_confidence=0.5,
+               min_text_avg_confidence=0.85,
+               min_kernel_area=0,
+               min_text_area=16):
+    """Decoding predictions of PSENet to instances. This is partially adapted
+    from https://github.com/whai362/PSENet.
+
+    Args:
+        preds (tensor): The head output tensor of size nxHxW.
+        text_repr_type (str): The boundary encoding type 'poly' or 'quad'.
+        min_kernel_confidence (float): The minimal kernel confidence.
+        min_text_avg_confidence (float): The minimal text average confidence.
+        min_kernel_area (int): The minimal text kernel area.
+        min_text_area (int): The minimal text instance region area.
+    Returns:
+        boundaries: (list[list[float]]): The instance boundary and its
+            instance confidence list.
+ """ + preds = torch.sigmoid(preds) # text confidence + + score = preds[0, :, :] + masks = preds > min_kernel_confidence + text_mask = masks[0, :, :] + kernel_masks = masks[0:, :, :] * text_mask + + score = score.data.cpu().numpy().astype(np.float32) # to numpy + + kernel_masks = kernel_masks.data.cpu().numpy().astype(np.uint8) # to numpy + from .pse import pse + + region_num, labels = cv2.connectedComponents( + kernel_masks[-1], connectivity=4) + + # labels = pse(kernel_masks, min_kernel_area) + labels = pse(kernel_masks, min_kernel_area, labels, region_num) + labels = np.array(labels) + label_num = np.max(labels) + 1 + boundaries = [] + for i in range(1, label_num): + points = np.array(np.where(labels == i)).transpose((1, 0))[:, ::-1] + area = points.shape[0] + score_instance = np.mean(score[labels == i]) + if filter_instance(area, score_instance, min_text_area, + min_text_avg_confidence): + continue + + vertices_confidence = points2boundary(points, text_repr_type, + score_instance) + if vertices_confidence is not None: + boundaries.append(vertices_confidence) + + return boundaries + + +def box_score_fast(bitmap, _box): + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + + +def unclip(box, unclip_ratio=1.5): + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + +def db_decode(preds, + text_repr_type='poly', + mask_thr=0.3, + min_text_score=0.3, + min_text_width=5, + unclip_ratio=1.5, + max_candidates=1000): + """Decoding predictions of DbNet to instances. This is partially adapted + from https://github.com/MhLiao/DB. + + Args: + preds (Tensor): The head output tensor of size nxHxW. + text_repr_type (str): The boundary encoding type 'poly' or 'quad'. + mask_thr (float): The mask threshold value for binarization. + min_text_score (float): The threshold value for converting binary map + to shrink text regions. + min_text_width (int): The minimum width of boundary polygon/box + predicted. + unclip_ratio (float): The unclip ratio for text regions dilation. + max_candidates (int): The maximum candidate number. + + Returns: + boundaries: (list[list[float]]): The predicted text boundaries. 
+ """ + prob_map = preds[0, :, :] + text_mask = prob_map > mask_thr + + score_map = prob_map.data.cpu().numpy().astype(np.float32) + text_mask = text_mask.data.cpu().numpy().astype(np.uint8) # to numpy + + contours, _ = cv2.findContours((text_mask * 255).astype(np.uint8), + cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + + boundaries = [] + for i, poly in enumerate(contours): + if i > max_candidates: + break + epsilon = 0.01 * cv2.arcLength(poly, True) + approx = cv2.approxPolyDP(poly, epsilon, True) + points = approx.reshape((-1, 2)) + if points.shape[0] < 4: + continue + score = box_score_fast(score_map, points) + if score < min_text_score: + continue + poly = unclip(points, unclip_ratio=unclip_ratio) + if len(poly) == 0 or isinstance(poly[0], list): + continue + poly = poly.reshape(-1, 2) + poly = points2boundary(poly, text_repr_type, score, min_text_width) + if poly is not None: + boundaries.append(poly) + return boundaries + + +def fill_hole(input_mask): + h, w = input_mask.shape + canvas = np.zeros((h + 2, w + 2), np.uint8) + canvas[1:h + 1, 1:w + 1] = input_mask.copy() + + mask = np.zeros((h + 4, w + 4), np.uint8) + + cv2.floodFill(canvas, mask, (0, 0), 1) + canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool) + + return ~canvas | input_mask + + +def centralize(points_yx, + normal_sin, + normal_cos, + radius, + contour_mask, + step_ratio=0.03): + + h, w = contour_mask.shape + top_yx = bot_yx = points_yx + step_flags = np.ones((len(points_yx), 1), dtype=np.bool) + step = step_ratio * radius * np.hstack([normal_sin, normal_cos]) + while np.any(step_flags): + next_yx = np.array(top_yx + step, dtype=np.int32) + next_y, next_x = next_yx[:, 0], next_yx[:, 1] + step_flags = (0 <= next_y) & (next_y < h) & (0 < next_x) & ( + next_x < w) & contour_mask[np.clip(next_y, 0, h - 1), + np.clip(next_x, 0, w - 1)] + top_yx = top_yx + step_flags.reshape((-1, 1)) * step + step_flags = np.ones((len(points_yx), 1), dtype=np.bool) + while np.any(step_flags): + next_yx = np.array(bot_yx - step, dtype=np.int32) + next_y, next_x = next_yx[:, 0], next_yx[:, 1] + step_flags = (0 <= next_y) & (next_y < h) & (0 < next_x) & ( + next_x < w) & contour_mask[np.clip(next_y, 0, h - 1), + np.clip(next_x, 0, w - 1)] + bot_yx = bot_yx - step_flags.reshape((-1, 1)) * step + centers = np.array((top_yx + bot_yx) * 0.5, dtype=np.int32) + return centers + + +def merge_disks(disks, disk_overlap_thr): + xy = disks[:, 0:2] + radius = disks[:, 2] + scores = disks[:, 3] + order = scores.argsort()[::-1] + + merged_disks = [] + while order.size > 0: + if order.size == 1: + merged_disks.append(disks[order]) + break + else: + i = order[0] + d = norm(xy[i] - xy[order[1:]], axis=1) + ri = radius[i] + r = radius[order[1:]] + d_thr = (ri + r) * disk_overlap_thr + + merge_inds = np.where(d <= d_thr)[0] + 1 + if merge_inds.size > 0: + merge_order = np.hstack([i, order[merge_inds]]) + merged_disks.append(np.mean(disks[merge_order], axis=0)) + else: + merged_disks.append(disks[i]) + + inds = np.where(d > d_thr)[0] + 1 + order = order[inds] + merged_disks = np.vstack(merged_disks) + + return merged_disks + + +def textsnake_decode(preds, + text_repr_type='poly', + min_text_region_confidence=0.6, + min_center_region_confidence=0.2, + min_center_area=30, + disk_overlap_thr=0.03, + radius_shrink_ratio=1.03): + """Decoding predictions of TextSnake to instances. This was partially + adapted from https://github.com/princewang1994/TextSnake.pytorch. + + Args: + preds (tensor): The head output tensor of size 6xHxW. 
+ text_repr_type (str): The boundary encoding type 'poly' or 'quad'. + min_text_region_confidence (float): The confidence threshold of text + region in TextSnake. + min_center_region_confidence (float): The confidence threshold of text + center region in TextSnake. + min_center_area (int): The minimal text center region area. + disk_overlap_thr (float): The radius overlap threshold for merging + disks. + radius_shrink_ratio (float): The shrink ratio of ordered disks radii. + + Returns: + boundaries (list[list[float]]): The instance boundary and its + instance confidence list. + """ + assert text_repr_type == 'poly' + preds[:2, :, :] = torch.sigmoid(preds[:2, :, :]) + preds = preds.detach().cpu().numpy() + + pred_text_score = preds[0] + pred_text_mask = pred_text_score > min_text_region_confidence + pred_center_score = preds[1] * pred_text_score + pred_center_mask = pred_center_score > min_center_region_confidence + pred_sin = preds[2] + pred_cos = preds[3] + pred_radius = preds[4] + mask_sz = pred_text_mask.shape + + scale = np.sqrt(1.0 / (pred_sin**2 + pred_cos**2 + 1e-8)) + pred_sin = pred_sin * scale + pred_cos = pred_cos * scale + + pred_center_mask = fill_hole(pred_center_mask).astype(np.uint8) + center_contours, _ = cv2.findContours(pred_center_mask, cv2.RETR_TREE, + cv2.CHAIN_APPROX_SIMPLE) + + boundaries = [] + for contour in center_contours: + if cv2.contourArea(contour) < min_center_area: + continue + instance_center_mask = np.zeros(mask_sz, dtype=np.uint8) + cv2.drawContours(instance_center_mask, [contour], -1, 1, -1) + skeleton = skeletonize(instance_center_mask) + skeleton_yx = np.argwhere(skeleton > 0) + y, x = skeleton_yx[:, 0], skeleton_yx[:, 1] + cos = pred_cos[y, x].reshape((-1, 1)) + sin = pred_sin[y, x].reshape((-1, 1)) + radius = pred_radius[y, x].reshape((-1, 1)) + + center_line_yx = centralize(skeleton_yx, cos, -sin, radius, + instance_center_mask) + y, x = center_line_yx[:, 0], center_line_yx[:, 1] + radius = (pred_radius[y, x] * radius_shrink_ratio).reshape((-1, 1)) + score = pred_center_score[y, x].reshape((-1, 1)) + instance_disks = np.hstack([np.fliplr(center_line_yx), radius, score]) + instance_disks = merge_disks(instance_disks, disk_overlap_thr) + + instance_mask = np.zeros(mask_sz, dtype=np.uint8) + for x, y, radius, score in instance_disks: + if radius > 0: + cv2.circle(instance_mask, (int(x), int(y)), int(radius), 1, -1) + contours, _ = cv2.findContours(instance_mask, cv2.RETR_TREE, + cv2.CHAIN_APPROX_SIMPLE) + + score = np.sum(instance_mask * pred_text_score) / ( + np.sum(instance_mask) + 1e-8) + if len(contours) > 0: + boundary = contours[0].flatten().tolist() + boundaries.append(boundary + [score]) + + return boundaries diff --git a/mmocr/models/textrecog/__init__.py b/mmocr/models/textrecog/__init__.py new file mode 100644 index 00000000..76bfa419 --- /dev/null +++ b/mmocr/models/textrecog/__init__.py @@ -0,0 +1,8 @@ +from .backbones import * # noqa: F401,F403 +from .convertors import * # noqa: F401,F403 +from .decoders import * # noqa: F401,F403 +from .encoders import * # noqa: F401,F403 +from .heads import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 +from .recognizer import * # noqa: F401,F403 diff --git a/mmocr/models/textrecog/backbones/__init__.py b/mmocr/models/textrecog/backbones/__init__.py new file mode 100644 index 00000000..51517090 --- /dev/null +++ b/mmocr/models/textrecog/backbones/__init__.py @@ -0,0 +1,5 @@ +from .nrtr_modality_transformer import NRTRModalityTransform +from 
.resnet31_ocr import ResNet31OCR
+from .very_deep_vgg import VeryDeepVgg
+
+__all__ = ['ResNet31OCR', 'VeryDeepVgg', 'NRTRModalityTransform']
diff --git a/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py b/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py
new file mode 100644
index 00000000..4471c733
--- /dev/null
+++ b/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py
@@ -0,0 +1,55 @@
+import torch.nn as nn
+from mmcv.cnn import kaiming_init, uniform_init
+
+from mmdet.models.builder import BACKBONES
+
+
+@BACKBONES.register_module()
+class NRTRModalityTransform(nn.Module):
+
+    def __init__(self, input_channels=3, input_height=32):
+        super().__init__()
+
+        self.conv_1 = nn.Conv2d(
+            in_channels=input_channels,
+            out_channels=32,
+            kernel_size=3,
+            stride=2,
+            padding=1)
+        self.relu_1 = nn.ReLU(True)
+        self.bn_1 = nn.BatchNorm2d(32)
+
+        self.conv_2 = nn.Conv2d(
+            in_channels=32,
+            out_channels=64,
+            kernel_size=3,
+            stride=2,
+            padding=1)
+        self.relu_2 = nn.ReLU(True)
+        self.bn_2 = nn.BatchNorm2d(64)
+
+        feat_height = input_height // 4
+
+        self.linear = nn.Linear(64 * feat_height, 512)
+
+    def init_weights(self, pretrained=None):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                kaiming_init(m)
+            elif isinstance(m, nn.BatchNorm2d):
+                uniform_init(m)
+
+    def forward(self, x):
+        x = self.conv_1(x)
+        x = self.relu_1(x)
+        x = self.bn_1(x)
+
+        x = self.conv_2(x)
+        x = self.relu_2(x)
+        x = self.bn_2(x)
+
+        n, c, h, w = x.size()
+        x = x.permute(0, 3, 2, 1).contiguous().view(n, w, h * c)
+        x = self.linear(x)
+        x = x.permute(0, 2, 1).contiguous().view(n, -1, 1, w)
+        return x
diff --git a/mmocr/models/textrecog/backbones/resnet31_ocr.py b/mmocr/models/textrecog/backbones/resnet31_ocr.py
new file mode 100644
index 00000000..e2787556
--- /dev/null
+++ b/mmocr/models/textrecog/backbones/resnet31_ocr.py
@@ -0,0 +1,149 @@
+import torch.nn as nn
+from mmcv.cnn import kaiming_init, uniform_init
+
+import mmocr.utils as utils
+from mmdet.models.builder import BACKBONES
+from mmocr.models.textrecog.layers import BasicBlock
+
+
+@BACKBONES.register_module()
+class ResNet31OCR(nn.Module):
+    """Implement ResNet backbone for text recognition, modified from
+    `ResNet <https://arxiv.org/pdf/1512.03385.pdf>`_
+    Args:
+        base_channels (int): Number of channels of input image tensor.
+        layers (list[int]): List of BasicBlock number for each stage.
+        channels (list[int]): List of out_channels of Conv2d layer.
+        out_indices (None | Sequence[int]): Indices of output stages.
+        stage4_pool_cfg (dict): Dictionary to construct and configure
+            pooling layer in stage 4.
+        last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
+ """ + + def __init__(self, + base_channels=3, + layers=[1, 2, 5, 3], + channels=[64, 128, 256, 256, 512, 512, 512], + out_indices=None, + stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)), + last_stage_pool=False): + super().__init__() + assert isinstance(base_channels, int) + assert utils.is_type_list(layers, int) + assert utils.is_type_list(channels, int) + assert out_indices is None or (isinstance(out_indices, list) + or isinstance(out_indices, tuple)) + assert isinstance(last_stage_pool, bool) + + self.out_indices = out_indices + self.last_stage_pool = last_stage_pool + + # conv 1 (Conv, Conv) + self.conv1_1 = nn.Conv2d( + base_channels, channels[0], kernel_size=3, stride=1, padding=1) + self.bn1_1 = nn.BatchNorm2d(channels[0]) + self.relu1_1 = nn.ReLU(inplace=True) + + self.conv1_2 = nn.Conv2d( + channels[0], channels[1], kernel_size=3, stride=1, padding=1) + self.bn1_2 = nn.BatchNorm2d(channels[1]) + self.relu1_2 = nn.ReLU(inplace=True) + + # conv 2 (Max-pooling, Residual block, Conv) + self.pool2 = nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.block2 = self._make_layer(channels[1], channels[2], layers[0]) + self.conv2 = nn.Conv2d( + channels[2], channels[2], kernel_size=3, stride=1, padding=1) + self.bn2 = nn.BatchNorm2d(channels[2]) + self.relu2 = nn.ReLU(inplace=True) + + # conv 3 (Max-pooling, Residual block, Conv) + self.pool3 = nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self.block3 = self._make_layer(channels[2], channels[3], layers[1]) + self.conv3 = nn.Conv2d( + channels[3], channels[3], kernel_size=3, stride=1, padding=1) + self.bn3 = nn.BatchNorm2d(channels[3]) + self.relu3 = nn.ReLU(inplace=True) + + # conv 4 (Max-pooling, Residual block, Conv) + self.pool4 = nn.MaxPool2d(padding=0, ceil_mode=True, **stage4_pool_cfg) + self.block4 = self._make_layer(channels[3], channels[4], layers[2]) + self.conv4 = nn.Conv2d( + channels[4], channels[4], kernel_size=3, stride=1, padding=1) + self.bn4 = nn.BatchNorm2d(channels[4]) + self.relu4 = nn.ReLU(inplace=True) + + # conv 5 ((Max-pooling), Residual block, Conv) + self.pool5 = None + if self.last_stage_pool: + self.pool5 = nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True) # 1/16 + self.block5 = self._make_layer(channels[4], channels[5], layers[3]) + self.conv5 = nn.Conv2d( + channels[5], channels[5], kernel_size=3, stride=1, padding=1) + self.bn5 = nn.BatchNorm2d(channels[5]) + self.relu5 = nn.ReLU(inplace=True) + + def init_weights(self, pretrained=None): + # initialize weight and bias + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, nn.BatchNorm2d): + uniform_init(m) + + def _make_layer(self, input_channels, output_channels, blocks): + layers = [] + for _ in range(blocks): + downsample = None + if input_channels != output_channels: + downsample = nn.Sequential( + nn.Conv2d( + input_channels, + output_channels, + kernel_size=1, + stride=1, + bias=False), + nn.BatchNorm2d(output_channels), + ) + layers.append( + BasicBlock( + input_channels, output_channels, downsample=downsample)) + input_channels = output_channels + + return nn.Sequential(*layers) + + def forward(self, x): + + x = self.conv1_1(x) + x = self.bn1_1(x) + x = self.relu1_1(x) + + x = self.conv1_2(x) + x = self.bn1_2(x) + x = self.relu1_2(x) + + outs = [] + for i in range(4): + layer_index = i + 2 + pool_layer = getattr(self, f'pool{layer_index}') + block_layer = getattr(self, f'block{layer_index}') + conv_layer = getattr(self, 
f'conv{layer_index}') + bn_layer = getattr(self, f'bn{layer_index}') + relu_layer = getattr(self, f'relu{layer_index}') + + if pool_layer is not None: + x = pool_layer(x) + x = block_layer(x) + x = conv_layer(x) + x = bn_layer(x) + x = relu_layer(x) + + outs.append(x) + + if self.out_indices is not None: + return tuple([outs[i] for i in self.out_indices]) + + return x diff --git a/mmocr/models/textrecog/backbones/very_deep_vgg.py b/mmocr/models/textrecog/backbones/very_deep_vgg.py new file mode 100644 index 00000000..a30ace5c --- /dev/null +++ b/mmocr/models/textrecog/backbones/very_deep_vgg.py @@ -0,0 +1,70 @@ +import torch.nn as nn +from mmcv.cnn import uniform_init, xavier_init + +from mmdet.models.builder import BACKBONES + + +@BACKBONES.register_module() +class VeryDeepVgg(nn.Module): + """Implement VGG-VeryDeep backbone for text recognition, modified from + `VGG-VeryDeep `_ + Args: + input_channels (int): Number of channels of input image tensor. + leakyRelu (bool): Use leakyRelu or not. + """ + + def __init__(self, leakyRelu=True, input_channels=3): + super().__init__() + + ks = [3, 3, 3, 3, 3, 3, 2] + ps = [1, 1, 1, 1, 1, 1, 0] + ss = [1, 1, 1, 1, 1, 1, 1] + nm = [64, 128, 256, 256, 512, 512, 512] + + self.channels = nm + + cnn = nn.Sequential() + + def convRelu(i, batchNormalization=False): + nIn = input_channels if i == 0 else nm[i - 1] + nOut = nm[i] + cnn.add_module('conv{0}'.format(i), + nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i])) + if batchNormalization: + cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut)) + if leakyRelu: + cnn.add_module('relu{0}'.format(i), + nn.LeakyReLU(0.2, inplace=True)) + else: + cnn.add_module('relu{0}'.format(i), nn.ReLU(True)) + + convRelu(0) + cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64 + convRelu(1) + cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32 + convRelu(2, True) + convRelu(3) + cnn.add_module('pooling{0}'.format(2), + nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 + convRelu(4, True) + convRelu(5) + cnn.add_module('pooling{0}'.format(3), + nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16 + convRelu(6, True) # 512x1x16 + + self.cnn = cnn + + def init_weights(self, pretrained=None): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m) + elif isinstance(m, nn.BatchNorm2d): + uniform_init(m) + + def out_channels(self): + return self.channels[-1] + + def forward(self, x): + output = self.cnn(x) + + return output diff --git a/mmocr/models/textrecog/convertors/__init__.py b/mmocr/models/textrecog/convertors/__init__.py new file mode 100644 index 00000000..60fc6300 --- /dev/null +++ b/mmocr/models/textrecog/convertors/__init__.py @@ -0,0 +1,6 @@ +from .attn import AttnConvertor +from .base import BaseConvertor +from .ctc import CTCConvertor +from .seg import SegConvertor + +__all__ = ['BaseConvertor', 'CTCConvertor', 'AttnConvertor', 'SegConvertor'] diff --git a/mmocr/models/textrecog/convertors/attn.py b/mmocr/models/textrecog/convertors/attn.py new file mode 100644 index 00000000..a80282e8 --- /dev/null +++ b/mmocr/models/textrecog/convertors/attn.py @@ -0,0 +1,140 @@ +import torch + +import mmocr.utils as utils +from mmocr.models.builder import CONVERTORS +from .base import BaseConvertor + + +@CONVERTORS.register_module() +class AttnConvertor(BaseConvertor): + """Convert between text, index and tensor for encoder-decoder based + pipeline. + + Args: + dict_type (str): Type of dict, should be one of {'DICT36', 'DICT90'}. 
+        dict_file (None|str): Character dict file path. If not none,
+            higher priority than dict_type.
+        dict_list (None|list[str]): Character list. If not none, higher
+            priority than dict_type, but lower than dict_file.
+        with_unknown (bool): If True, add `UKN` token to class.
+        max_seq_len (int): Maximum sequence length of label.
+        lower (bool): If True, convert original string to lower case.
+        start_end_same (bool): Whether use the same index for
+            start and end token or not. Default: True.
+    """
+
+    def __init__(self,
+                 dict_type='DICT90',
+                 dict_file=None,
+                 dict_list=None,
+                 with_unknown=True,
+                 max_seq_len=40,
+                 lower=False,
+                 start_end_same=True,
+                 **kwargs):
+        super().__init__(dict_type, dict_file, dict_list)
+        assert isinstance(with_unknown, bool)
+        assert isinstance(max_seq_len, int)
+        assert isinstance(lower, bool)
+
+        self.with_unknown = with_unknown
+        self.max_seq_len = max_seq_len
+        self.lower = lower
+        self.start_end_same = start_end_same
+
+        self.update_dict()
+
+    def update_dict(self):
+        start_end_token = '<BOS/EOS>'
+        unknown_token = '<UKN>'
+        padding_token = '<PAD>'
+
+        # unknown
+        self.unknown_idx = None
+        if self.with_unknown:
+            self.idx2char.append(unknown_token)
+            self.unknown_idx = len(self.idx2char) - 1
+
+        # BOS/EOS
+        self.idx2char.append(start_end_token)
+        self.start_idx = len(self.idx2char) - 1
+        if not self.start_end_same:
+            self.idx2char.append(start_end_token)
+        self.end_idx = len(self.idx2char) - 1
+
+        # padding
+        self.idx2char.append(padding_token)
+        self.padding_idx = len(self.idx2char) - 1
+
+        # update char2idx
+        self.char2idx = {}
+        for idx, char in enumerate(self.idx2char):
+            self.char2idx[char] = idx
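+
+    # With the default DICT90, with_unknown=True and start_end_same=True,
+    # update_dict() above produces (indices shown for illustration only):
+    #   idx2char = [...90 dict chars..., '<UKN>', '<BOS/EOS>', '<PAD>']
+    #   unknown_idx=90, start_idx=end_idx=91, padding_idx=92
+    # so str2tensor pads 'hi' to [91, idx('h'), idx('i'), 91, 92, 92, ...].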
+
+    def str2tensor(self, strings):
+        """
+        Convert text-string into tensor.
+        Args:
+            strings (list[str]): ['hello', 'world']
+        Returns:
+            dict (str: Tensor | list[tensor]):
+                tensors (list[Tensor]): [torch.Tensor([1,2,3,3,4]),
+                    torch.Tensor([5,4,6,3,7])]
+                padded_targets (Tensor(bsz * max_seq_len))
+        """
+        assert utils.is_type_list(strings, str)
+
+        tensors, padded_targets = [], []
+        indexes = self.str2idx(strings)
+        for index in indexes:
+            tensor = torch.LongTensor(index)
+            tensors.append(tensor)
+            # target tensor for loss
+            src_target = torch.LongTensor(tensor.size(0) + 2).fill_(0)
+            src_target[-1] = self.end_idx
+            src_target[0] = self.start_idx
+            src_target[1:-1] = tensor
+            padded_target = (torch.ones(self.max_seq_len) *
+                             self.padding_idx).long()
+            char_num = src_target.size(0)
+            if char_num > self.max_seq_len:
+                padded_target = src_target[:self.max_seq_len]
+            else:
+                padded_target[:char_num] = src_target
+            padded_targets.append(padded_target)
+        padded_targets = torch.stack(padded_targets, 0).long()
+
+        return {'targets': tensors, 'padded_targets': padded_targets}
+
+    def tensor2idx(self, outputs, img_metas=None):
+        """
+        Convert output tensor to text-index
+        Args:
+            outputs (tensor): model outputs with size: N * T * C
+            img_metas (list[dict]): Each dict contains one image info.
+        Returns:
+            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]]
+            scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
+                [0.9,0.9,0.98,0.97,0.96]]
+        """
+        batch_size = outputs.size(0)
+        ignore_indexes = [self.padding_idx]
+        indexes, scores = [], []
+        for idx in range(batch_size):
+            seq = outputs[idx, :, :]
+            max_value, max_idx = torch.max(seq, -1)
+            str_index, str_score = [], []
+            output_index = max_idx.cpu().detach().numpy().tolist()
+            output_score = max_value.cpu().detach().numpy().tolist()
+            for char_index, char_score in zip(output_index, output_score):
+                if char_index in ignore_indexes:
+                    continue
+                if char_index == self.end_idx:
+                    break
+                str_index.append(char_index)
+                str_score.append(char_score)
+
+            indexes.append(str_index)
+            scores.append(str_score)
+
+        return indexes, scores
diff --git a/mmocr/models/textrecog/convertors/base.py b/mmocr/models/textrecog/convertors/base.py
new file mode 100644
index 00000000..9002d4fe
--- /dev/null
+++ b/mmocr/models/textrecog/convertors/base.py
@@ -0,0 +1,115 @@
+from mmocr.models.builder import CONVERTORS
+
+
+@CONVERTORS.register_module()
+class BaseConvertor:
+    """Convert between text, index and tensor for text recognize pipeline.
+
+    Args:
+        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
+        dict_file (None|str): Character dict file path. If not none,
+            the dict_file is of higher priority than dict_type.
+        dict_list (None|list[str]): Character list. If not none, the list
+            is of higher priority than dict_type, but lower than dict_file.
+    """
+    start_idx = end_idx = padding_idx = 0
+    unknown_idx = None
+    lower = False
+
+    DICT36 = tuple('0123456789abcdefghijklmnopqrstuvwxyz')
+    DICT90 = tuple('0123456789abcdefghijklmnopqrstuvwxyz'
+                   'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()'
+                   '*+,-./:;<=>?@[\\]_`~')
+
+    def __init__(self, dict_type='DICT90', dict_file=None, dict_list=None):
+        assert dict_type in ('DICT36', 'DICT90')
+        assert dict_file is None or isinstance(dict_file, str)
+        assert dict_list is None or isinstance(dict_list, list)
+        self.idx2char = []
+        if dict_file is not None:
+            with open(dict_file, encoding='utf-8') as fr:
+                for line in fr:
+                    line = line.strip()
+                    if line != '':
+                        self.idx2char.append(line)
+        elif dict_list is not None:
+            self.idx2char = dict_list
+        else:
+            if dict_type == 'DICT36':
+                self.idx2char = list(self.DICT36)
+            else:
+                self.idx2char = list(self.DICT90)
+
+        self.char2idx = {}
+        for idx, char in enumerate(self.idx2char):
+            self.char2idx[char] = idx
+
+    def num_classes(self):
+        """Number of output classes."""
+        return len(self.idx2char)
+
+    def str2idx(self, strings):
+        """Convert strings to indexes.
+
+        Args:
+            strings (list[str]): ['hello', 'world'].
+        Returns:
+            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
+        """
+        assert isinstance(strings, list)
+
+        indexes = []
+        for string in strings:
+            if self.lower:
+                string = string.lower()
+            index = []
+            for char in string:
+                char_idx = self.char2idx.get(char, self.unknown_idx)
+                if char_idx is None:
+                    raise Exception(f'Character: {char} not in dict,'
+                                    f' please check gt_label and use'
+                                    f' custom dict file,'
+                                    f' or set "with_unknown=True"')
+                index.append(char_idx)
+            indexes.append(index)
+
+        return indexes
+
+    def str2tensor(self, strings):
+        """Convert text-string to input tensor.
+
+        Args:
+            strings (list[str]): ['hello', 'world'].
+        Returns:
+            tensors (list[torch.Tensor]): [torch.Tensor([1,2,3,3,4]),
+                torch.Tensor([5,4,6,3,7])].
+        """
+        raise NotImplementedError
+
+    def idx2str(self, indexes):
+        """Convert indexes to text strings.
+
+        Args:
+            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
+        Returns:
+            strings (list[str]): ['hello', 'world'].
+        """
+        assert isinstance(indexes, list)
+
+        strings = []
+        for index in indexes:
+            string = [self.idx2char[i] for i in index]
+            strings.append(''.join(string))
+
+        return strings
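+
+    # A round trip through str2idx/idx2str above, e.g. with DICT36
+    # ('0'-'9' -> 0-9, 'a'-'z' -> 10-35):
+    #   >>> convertor = BaseConvertor(dict_type='DICT36')
+    #   >>> convertor.str2idx(['ocr'])
+    #   [[24, 12, 27]]
+    #   >>> convertor.idx2str([[24, 12, 27]])
+    #   ['ocr']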
+
+    def tensor2idx(self, output):
+        """Convert model output tensor to character indexes and scores.
+        Args:
+            output (tensor): The model outputs with size: N * T * C
+        Returns:
+            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
+            scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
+                [0.9,0.9,0.98,0.97,0.96]].
+        """
+        raise NotImplementedError
diff --git a/mmocr/models/textrecog/convertors/ctc.py b/mmocr/models/textrecog/convertors/ctc.py
new file mode 100644
index 00000000..c14fc23f
--- /dev/null
+++ b/mmocr/models/textrecog/convertors/ctc.py
@@ -0,0 +1,144 @@
+import math
+
+import torch
+import torch.nn.functional as F
+
+import mmocr.utils as utils
+from mmocr.models.builder import CONVERTORS
+from .base import BaseConvertor
+
+
+@CONVERTORS.register_module()
+class CTCConvertor(BaseConvertor):
+    """Convert between text, index and tensor for CTC loss-based pipeline.
+
+    Args:
+        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
+        dict_file (None|str): Character dict file path. If not none, the file
+            is of higher priority than dict_type.
+        dict_list (None|list[str]): Character list. If not none, the list
+            is of higher priority than dict_type, but lower than dict_file.
+        with_unknown (bool): If True, add `UKN` token to class.
+        lower (bool): If True, convert original string to lower case.
+    """
+
+    def __init__(self,
+                 dict_type='DICT90',
+                 dict_file=None,
+                 dict_list=None,
+                 with_unknown=True,
+                 lower=False,
+                 **kwargs):
+        super().__init__(dict_type, dict_file, dict_list)
+        assert isinstance(with_unknown, bool)
+        assert isinstance(lower, bool)
+
+        self.with_unknown = with_unknown
+        self.lower = lower
+        self.update_dict()
+
+    def update_dict(self):
+        # CTC-blank
+        blank_token = '<BLK>'
+        self.blank_idx = 0
+        self.idx2char.insert(0, blank_token)
+
+        # unknown
+        self.unknown_idx = None
+        if self.with_unknown:
+            self.idx2char.append('<UKN>')
+            self.unknown_idx = len(self.idx2char) - 1
+
+        # update char2idx
+        self.char2idx = {}
+        for idx, char in enumerate(self.idx2char):
+            self.char2idx[char] = idx
+
+    def str2tensor(self, strings):
+        """Convert text-string to ctc-loss input tensor.
+
+        Args:
+            strings (list[str]): ['hello', 'world'].
+        Returns:
+            dict (str: tensor | list[tensor]):
+                tensors (list[tensor]): [torch.Tensor([1,2,3,3,4]),
+                    torch.Tensor([5,4,6,3,7])].
+                flatten_targets (tensor): torch.Tensor([1,2,3,3,4,5,4,6,3,7]).
+                target_lengths (tensor): torch.IntTensor([5,5]).
+        """
+        assert utils.is_type_list(strings, str)
+
+        tensors = []
+        indexes = self.str2idx(strings)
+        for index in indexes:
+            tensor = torch.IntTensor(index)
+            tensors.append(tensor)
+        target_lengths = torch.IntTensor([len(t) for t in tensors])
+        flatten_target = torch.cat(tensors)
+
+        return {
+            'targets': tensors,
+            'flatten_targets': flatten_target,
+            'target_lengths': target_lengths
+        }
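+
+    # A minimal sketch of str2tensor above with DICT36: inserting the CTC
+    # blank at index 0 shifts every character index up by one ('a' -> 11):
+    #   >>> convertor = CTCConvertor(dict_type='DICT36')
+    #   >>> out = convertor.str2tensor(['ab', 'c'])
+    #   >>> out['flatten_targets'].tolist(), out['target_lengths'].tolist()
+    #   ([11, 12, 13], [2, 1])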
+
+    def tensor2idx(self, output, img_metas, topk=1, return_topk=False):
+        """Convert model output tensor to index-list.
+        Args:
+            output (tensor): The model outputs with size: N * T * C.
+            img_metas (list[dict]): Each dict contains one image info.
+            topk (int): The highest k classes to be returned.
+            return_topk (bool): Whether to return topk or just top1.
+        Returns:
+            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
+            scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
+                [0.9,0.9,0.98,0.97,0.96]]
+            (
+                indexes_topk (list[list[list[int]->len=topk]]):
+                scores_topk (list[list[list[float]->len=topk]])
+            ).
+        """
+        assert utils.is_type_list(img_metas, dict)
+        assert len(img_metas) == output.size(0)
+        assert isinstance(topk, int)
+        assert topk >= 1
+
+        valid_ratios = [
+            img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+        ]
+
+        batch_size = output.size(0)
+        output = F.softmax(output, dim=2)
+        output = output.cpu().detach()
+        batch_topk_value, batch_topk_idx = output.topk(topk, dim=2)
+        batch_max_idx = batch_topk_idx[:, :, 0]
+        scores_topk, indexes_topk = [], []
+        scores, indexes = [], []
+        feat_len = output.size(1)
+        for b in range(batch_size):
+            valid_ratio = valid_ratios[b]
+            decode_len = min(feat_len, math.ceil(feat_len * valid_ratio))
+            pred = batch_max_idx[b, :]
+            select_idx = []
+            prev_idx = self.blank_idx
+            for t in range(decode_len):
+                tmp_value = pred[t].item()
+                if tmp_value not in (prev_idx, self.blank_idx):
+                    select_idx.append(t)
+                prev_idx = tmp_value
+            select_idx = torch.LongTensor(select_idx)
+            topk_value = torch.index_select(batch_topk_value[b, :, :], 0,
+                                            select_idx)  # valid_seqlen * topk
+            topk_idx = torch.index_select(batch_topk_idx[b, :, :], 0,
+                                          select_idx)
+            topk_idx_list, topk_value_list = topk_idx.numpy().tolist(
+            ), topk_value.numpy().tolist()
+            indexes_topk.append(topk_idx_list)
+            scores_topk.append(topk_value_list)
+            indexes.append([x[0] for x in topk_idx_list])
+            scores.append([x[0] for x in topk_value_list])
+
+        if return_topk:
+            return indexes_topk, scores_topk
+
+        return indexes, scores
diff --git a/mmocr/models/textrecog/convertors/seg.py b/mmocr/models/textrecog/convertors/seg.py
new file mode 100644
index 00000000..0a626dac
--- /dev/null
+++ b/mmocr/models/textrecog/convertors/seg.py
@@ -0,0 +1,123 @@
+import cv2
+import numpy as np
+import torch
+
+import mmocr.utils as utils
+from mmocr.models.builder import CONVERTORS
+from .base import BaseConvertor
+
+
+@CONVERTORS.register_module()
+class SegConvertor(BaseConvertor):
+    """Convert between text, index and tensor for segmentation based pipeline.
+
+    Args:
+        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
+        dict_file (None|str): Character dict file path. If not none, the
+            file is of higher priority than dict_type.
+        dict_list (None|list[str]): Character list. If not none, the list
+            is of higher priority than dict_type, but lower than dict_file.
+        with_unknown (bool): If True, add `UKN` token to class.
+        lower (bool): If True, convert original string to lower case.
+    """
+
+    def __init__(self,
+                 dict_type='DICT36',
+                 dict_file=None,
+                 dict_list=None,
+                 with_unknown=True,
+                 lower=False,
+                 **kwargs):
+        super().__init__(dict_type, dict_file, dict_list)
+        assert isinstance(with_unknown, bool)
+        assert isinstance(lower, bool)
+
+        self.with_unknown = with_unknown
+        self.lower = lower
+        self.update_dict()
+
+    def update_dict(self):
+        # background
+        self.idx2char.insert(0, '<BG>')
+
+        # unknown
+        self.unknown_idx = None
+        if self.with_unknown:
+            self.idx2char.append('<UKN>')
+            self.unknown_idx = len(self.idx2char) - 1
+
+        # update char2idx
+        self.char2idx = {}
+        for idx, char in enumerate(self.idx2char):
+            self.char2idx[char] = idx
+
+    def tensor2str(self, output, img_metas=None):
+        """Convert model output tensor to string labels.
+ Args: + output (tensor): Model outputs with size: N * C * H * W + img_metas (list[dict]): Each dict contains one image info. + Returns: + texts (list[str]): Decoded text labels. + scores (list[list[float]]): Decoded chars scores. + """ + assert utils.is_type_list(img_metas, dict) + assert len(img_metas) == output.size(0) + + texts, scores = [], [] + for b in range(output.size(0)): + seg_pred = output[b].detach() + seg_res = torch.argmax( + seg_pred, dim=0).cpu().numpy().astype(np.int32) + + seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8) + _, labels, stats, centroids = cv2.connectedComponentsWithStats( + seg_thr) + + component_num = stats.shape[0] + + all_res = [] + for i in range(component_num): + temp_loc = (labels == i) + temp_value = seg_res[temp_loc] + temp_center = centroids[i] + + temp_max_num = 0 + temp_max_cls = -1 + temp_total_num = 0 + for c in range(len(self.idx2char)): + c_num = np.sum(temp_value == c) + temp_total_num += c_num + if c_num > temp_max_num: + temp_max_num = c_num + temp_max_cls = c + + if temp_max_cls == 0: + continue + temp_max_score = 1.0 * temp_max_num / temp_total_num + all_res.append( + [temp_max_cls, temp_center, temp_max_num, temp_max_score]) + + all_res = sorted(all_res, key=lambda s: s[1][0]) + chars, char_scores = [], [] + for res in all_res: + temp_area = res[2] + if temp_area < 20: + continue + temp_char_index = res[0] + if temp_char_index >= len(self.idx2char): + temp_char = '' + elif temp_char_index <= 0: + temp_char = '' + elif temp_char_index == self.unknown_idx: + temp_char = '' + else: + temp_char = self.idx2char[temp_char_index] + chars.append(temp_char) + char_scores.append(res[3]) + + text = ''.join(chars) + + texts.append(text) + scores.append(char_scores) + + return texts, scores diff --git a/mmocr/models/textrecog/decoders/__init__.py b/mmocr/models/textrecog/decoders/__init__.py new file mode 100755 index 00000000..8b374733 --- /dev/null +++ b/mmocr/models/textrecog/decoders/__init__.py @@ -0,0 +1,15 @@ +from .base_decoder import BaseDecoder +from .crnn_decoder import CRNNDecoder +from .position_attention_decoder import PositionAttentionDecoder +from .robust_scanner_decoder import RobustScannerDecoder +from .sar_decoder import ParallelSARDecoder, SequentialSARDecoder +from .sar_decoder_with_bs import ParallelSARDecoderWithBS +from .sequence_attention_decoder import SequenceAttentionDecoder +from .transformer_decoder import TFDecoder + +__all__ = [ + 'CRNNDecoder', 'ParallelSARDecoder', 'SequentialSARDecoder', + 'ParallelSARDecoderWithBS', 'TFDecoder', 'BaseDecoder', + 'SequenceAttentionDecoder', 'PositionAttentionDecoder', + 'RobustScannerDecoder' +] diff --git a/mmocr/models/textrecog/decoders/base_decoder.py b/mmocr/models/textrecog/decoders/base_decoder.py new file mode 100644 index 00000000..543cd528 --- /dev/null +++ b/mmocr/models/textrecog/decoders/base_decoder.py @@ -0,0 +1,32 @@ +import torch.nn as nn + +from mmocr.models.builder import DECODERS + + +@DECODERS.register_module() +class BaseDecoder(nn.Module): + """Base decoder class for text recognition.""" + + def __init__(self, **kwargs): + super().__init__() + + def init_weights(self): + pass + + def forward_train(self, feat, out_enc, targets_dict, img_metas): + raise NotImplementedError + + def forward_test(self, feat, out_enc, img_metas): + raise NotImplementedError + + def forward(self, + feat, + out_enc, + targets_dict=None, + img_metas=None, + train_mode=True): + self.train_mode = train_mode + if train_mode: + return self.forward_train(feat, out_enc, 
targets_dict, img_metas) + + return self.forward_test(feat, out_enc, img_metas) diff --git a/mmocr/models/textrecog/decoders/crnn_decoder.py b/mmocr/models/textrecog/decoders/crnn_decoder.py new file mode 100644 index 00000000..1ce5226a --- /dev/null +++ b/mmocr/models/textrecog/decoders/crnn_decoder.py @@ -0,0 +1,49 @@ +import torch.nn as nn +from mmcv.cnn import xavier_init + +from mmocr.models.builder import DECODERS +from mmocr.models.textrecog.layers import BidirectionalLSTM +from .base_decoder import BaseDecoder + + +@DECODERS.register_module() +class CRNNDecoder(BaseDecoder): + + def __init__(self, + in_channels=None, + num_classes=None, + rnn_flag=False, + **kwargs): + super().__init__() + self.num_classes = num_classes + self.rnn_flag = rnn_flag + + if rnn_flag: + self.decoder = nn.Sequential( + BidirectionalLSTM(in_channels, 256, 256), + BidirectionalLSTM(256, 256, num_classes)) + else: + self.decoder = nn.Conv2d( + in_channels, num_classes, kernel_size=1, stride=1) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m) + + def forward_train(self, feat, out_enc, targets_dict, img_metas): + assert feat.size(2) == 1, 'feature height must be 1' + if self.rnn_flag: + x = feat.squeeze(2) # [N, C, W] + x = x.permute(2, 0, 1) # [W, N, C] + x = self.decoder(x) # [W, N, C] + outputs = x.permute(1, 0, 2).contiguous() + else: + x = self.decoder(feat) + x = x.permute(0, 3, 1, 2).contiguous() + n, w, c, h = x.size() + outputs = x.view(n, w, c * h) + return outputs + + def forward_test(self, feat, out_enc, img_metas): + return self.forward_train(feat, out_enc, None, img_metas) diff --git a/mmocr/models/textrecog/decoders/position_attention_decoder.py b/mmocr/models/textrecog/decoders/position_attention_decoder.py new file mode 100644 index 00000000..2ef85255 --- /dev/null +++ b/mmocr/models/textrecog/decoders/position_attention_decoder.py @@ -0,0 +1,138 @@ +import math + +import torch +import torch.nn as nn + +from mmocr.models.builder import DECODERS +from mmocr.models.textrecog.layers import (DotProductAttentionLayer, + PositionAwareLayer) +from .base_decoder import BaseDecoder + + +@DECODERS.register_module() +class PositionAttentionDecoder(BaseDecoder): + + def __init__(self, + num_classes=None, + rnn_layers=2, + dim_input=512, + dim_model=128, + max_seq_len=40, + mask=True, + return_feature=False, + encode_value=False): + super().__init__() + + self.num_classes = num_classes + self.dim_input = dim_input + self.dim_model = dim_model + self.max_seq_len = max_seq_len + self.return_feature = return_feature + self.encode_value = encode_value + self.mask = mask + + self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model) + + self.position_aware_module = PositionAwareLayer( + self.dim_model, rnn_layers) + + self.attention_layer = DotProductAttentionLayer() + + self.prediction = None + if not self.return_feature: + pred_num_classes = num_classes - 1 + self.prediction = nn.Linear( + dim_model if encode_value else dim_input, pred_num_classes) + + def init_weights(self): + pass + + def _get_position_index(self, length, batch_size, device=None): + position_index = torch.arange(0, length, device=device) + position_index = position_index.repeat([batch_size, 1]) + position_index = position_index.long() + return position_index + + def forward_train(self, feat, out_enc, targets_dict, img_metas): + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] if self.mask else None + + targets = 
targets_dict['padded_targets'].to(feat.device) + + # + n, c_enc, h, w = out_enc.size() + assert c_enc == self.dim_model + _, c_feat, _, _ = feat.size() + assert c_feat == self.dim_input + _, len_q = targets.size() + assert len_q <= self.max_seq_len + + position_index = self._get_position_index(len_q, n, feat.device) + + position_out_enc = self.position_aware_module(out_enc) + + query = self.embedding(position_index) + query = query.permute(0, 2, 1).contiguous() + key = position_out_enc.view(n, c_enc, h * w) + if self.encode_value: + value = out_enc.view(n, c_enc, h * w) + else: + value = feat.view(n, c_feat, h * w) + + mask = None + if valid_ratios is not None: + mask = query.new_zeros((n, h, w)) + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, math.ceil(w * valid_ratio)) + mask[i, :, valid_width:] = 1 + mask = mask.bool() + mask = mask.view(n, h * w) + + attn_out = self.attention_layer(query, key, value, mask) + attn_out = attn_out.permute(0, 2, 1).contiguous() # [n, len_q, dim_v] + + if self.return_feature: + return attn_out + + return self.prediction(attn_out) + + def forward_test(self, feat, out_enc, img_metas): + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] if self.mask else None + + seq_len = self.max_seq_len + n, c_enc, h, w = out_enc.size() + assert c_enc == self.dim_model + _, c_feat, _, _ = feat.size() + assert c_feat == self.dim_input + + position_index = self._get_position_index(seq_len, n, feat.device) + + position_out_enc = self.position_aware_module(out_enc) + + query = self.embedding(position_index) + query = query.permute(0, 2, 1).contiguous() + key = position_out_enc.view(n, c_enc, h * w) + if self.encode_value: + value = out_enc.view(n, c_enc, h * w) + else: + value = feat.view(n, c_feat, h * w) + + mask = None + if valid_ratios is not None: + mask = query.new_zeros((n, h, w)) + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, math.ceil(w * valid_ratio)) + mask[i, :, valid_width:] = 1 + mask = mask.bool() + mask = mask.view(n, h * w) + + attn_out = self.attention_layer(query, key, value, mask) + attn_out = attn_out.permute(0, 2, 1).contiguous() + + if self.return_feature: + return attn_out + + return self.prediction(attn_out) diff --git a/mmocr/models/textrecog/decoders/robust_scanner_decoder.py b/mmocr/models/textrecog/decoders/robust_scanner_decoder.py new file mode 100644 index 00000000..0301e153 --- /dev/null +++ b/mmocr/models/textrecog/decoders/robust_scanner_decoder.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmocr.models.builder import DECODERS, build_decoder +from mmocr.models.textrecog.layers import RobustScannerFusionLayer +from .base_decoder import BaseDecoder + + +@DECODERS.register_module() +class RobustScannerDecoder(BaseDecoder): + + def __init__(self, + num_classes=None, + dim_input=512, + dim_model=128, + max_seq_len=40, + start_idx=0, + mask=True, + padding_idx=None, + encode_value=False, + hybrid_decoder=None, + position_decoder=None): + super().__init__() + self.num_classes = num_classes + self.dim_input = dim_input + self.dim_model = dim_model + self.max_seq_len = max_seq_len + self.encode_value = encode_value + self.start_idx = start_idx + self.padding_idx = padding_idx + self.mask = mask + + # init hybrid decoder + hybrid_decoder.update(num_classes=self.num_classes) + hybrid_decoder.update(dim_input=self.dim_input) + hybrid_decoder.update(dim_model=self.dim_model) + hybrid_decoder.update(start_idx=self.start_idx) 
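The two branches being configured here are ordinary registry-built sub-decoders. A minimal sketch of a config that would exercise this class (the branch types match the decoders exported from this package; the numeric values are illustrative assumptions, not taken from this diff):

```python
# Illustrative RobustScannerDecoder config; numeric values are assumptions.
decoder = dict(
    type='RobustScannerDecoder',
    num_classes=93,
    dim_input=512,
    dim_model=128,
    max_seq_len=40,
    start_idx=91,
    padding_idx=92,
    mask=True,
    encode_value=False,
    # Shared keys (num_classes, dim_input, dim_model, ...) are forced by
    # the update() calls in __init__, so only the type is required here.
    hybrid_decoder=dict(type='SequenceAttentionDecoder'),
    position_decoder=dict(type='PositionAttentionDecoder'))
```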
+        hybrid_decoder.update(padding_idx=self.padding_idx)
+        hybrid_decoder.update(max_seq_len=self.max_seq_len)
+        hybrid_decoder.update(mask=self.mask)
+        hybrid_decoder.update(encode_value=self.encode_value)
+        hybrid_decoder.update(return_feature=True)
+
+        self.hybrid_decoder = build_decoder(hybrid_decoder)
+
+        # init position decoder
+        position_decoder.update(num_classes=self.num_classes)
+        position_decoder.update(dim_input=self.dim_input)
+        position_decoder.update(dim_model=self.dim_model)
+        position_decoder.update(max_seq_len=self.max_seq_len)
+        position_decoder.update(mask=self.mask)
+        position_decoder.update(encode_value=self.encode_value)
+        position_decoder.update(return_feature=True)
+
+        self.position_decoder = build_decoder(position_decoder)
+
+        self.fusion_module = RobustScannerFusionLayer(
+            self.dim_model if encode_value else dim_input)
+
+        pred_num_classes = num_classes - 1
+        self.prediction = nn.Linear(dim_model if encode_value else dim_input,
+                                    pred_num_classes)
+
+    def init_weights(self):
+        pass
+
+    def forward_train(self, feat, out_enc, targets_dict, img_metas):
+        hybrid_glimpse = self.hybrid_decoder.forward_train(
+            feat, out_enc, targets_dict, img_metas)
+        position_glimpse = self.position_decoder.forward_train(
+            feat, out_enc, targets_dict, img_metas)
+
+        fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse)
+
+        out = self.prediction(fusion_out)
+
+        return out
+
+    def forward_test(self, feat, out_enc, img_metas):
+        seq_len = self.max_seq_len
+        batch_size = feat.size(0)
+
+        decode_sequence = (feat.new_ones(
+            (batch_size, seq_len)) * self.start_idx).long()
+
+        position_glimpse = self.position_decoder.forward_test(
+            feat, out_enc, img_metas)
+
+        outputs = []
+        for i in range(seq_len):
+            hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
+                feat, out_enc, decode_sequence, i, img_metas)
+
+            fusion_out = self.fusion_module(hybrid_glimpse_step,
+                                            position_glimpse[:, i, :])
+
+            char_out = self.prediction(fusion_out)
+            char_out = F.softmax(char_out, -1)
+            outputs.append(char_out)
+            _, max_idx = torch.max(char_out, dim=1, keepdim=False)
+            if i < seq_len - 1:
+                decode_sequence[:, i + 1] = max_idx
+
+        outputs = torch.stack(outputs, 1)
+
+        return outputs
diff --git a/mmocr/models/textrecog/decoders/sar_decoder.py b/mmocr/models/textrecog/decoders/sar_decoder.py
new file mode 100755
index 00000000..a463c373
--- /dev/null
+++ b/mmocr/models/textrecog/decoders/sar_decoder.py
@@ -0,0 +1,417 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import mmocr.utils as utils
+from mmocr.models.builder import DECODERS
+from .base_decoder import BaseDecoder
+
+
+@DECODERS.register_module()
+class ParallelSARDecoder(BaseDecoder):
+    """Implementation of the parallel decoder module in `SAR
+    <https://arxiv.org/abs/1811.00751>`_.
+
+    Args:
+        num_classes (int): Output class number.
+        channels (list[int]): Network layer channels.
+        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
+        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
+        dec_do_rnn (float): Dropout probability of RNN layer in decoder.
+        dec_gru (bool): If True, use GRU, else LSTM in decoder.
+        d_model (int): Dim of channels from backbone.
+        d_enc (int): Dim of encoder RNN layer.
+        d_k (int): Dim of channels of attention module.
+        pred_dropout (float): Dropout probability of prediction layer.
+        max_seq_len (int): Maximum sequence length for decoding.
+        mask (bool): If True, mask padding in feature map.
+        start_idx (int): Index of start token.
+        padding_idx (int): Index of padding token.
+ pred_concat (bool): If True, concat glimpse feature from + attention with holistic feature and hidden state. + """ + + def __init__(self, + num_classes=37, + enc_bi_rnn=False, + dec_bi_rnn=False, + dec_do_rnn=0.0, + dec_gru=False, + d_model=512, + d_enc=512, + d_k=64, + pred_dropout=0.0, + max_seq_len=40, + mask=True, + start_idx=0, + padding_idx=92, + pred_concat=False, + **kwargs): + super().__init__() + + self.num_classes = num_classes + self.enc_bi_rnn = enc_bi_rnn + self.d_k = d_k + self.start_idx = start_idx + self.max_seq_len = max_seq_len + self.mask = mask + self.pred_concat = pred_concat + + encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) + decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1) + # 2D attention layer + self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k) + self.conv3x3_1 = nn.Conv2d( + d_model, d_k, kernel_size=3, stride=1, padding=1) + self.conv1x1_2 = nn.Linear(d_k, 1) + + # Decoder RNN layer + kwargs = dict( + input_size=encoder_rnn_out_size, + hidden_size=encoder_rnn_out_size, + num_layers=2, + batch_first=True, + dropout=dec_do_rnn, + bidirectional=dec_bi_rnn) + if dec_gru: + self.rnn_decoder = nn.GRU(**kwargs) + else: + self.rnn_decoder = nn.LSTM(**kwargs) + + # Decoder input embedding + self.embedding = nn.Embedding( + self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx) + + # Prediction layer + self.pred_dropout = nn.Dropout(pred_dropout) + pred_num_classes = num_classes - 1 # ignore padding_idx in prediction + if pred_concat: + fc_in_channel = decoder_rnn_out_size + d_model + d_enc + else: + fc_in_channel = d_model + self.prediction = nn.Linear(fc_in_channel, pred_num_classes) + + def _2d_attention(self, + decoder_input, + feat, + holistic_feat, + valid_ratios=None): + y = self.rnn_decoder(decoder_input)[0] + # y: bsz * (seq_len + 1) * hidden_size + + attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size + bsz, seq_len, attn_size = attn_query.size() + attn_query = attn_query.view(bsz, seq_len, attn_size, 1, 1) + + attn_key = self.conv3x3_1(feat) + # bsz * attn_size * h * w + attn_key = attn_key.unsqueeze(1) + # bsz * 1 * attn_size * h * w + + attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1)) + # bsz * (seq_len + 1) * attn_size * h * w + attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous() + # bsz * (seq_len + 1) * h * w * attn_size + attn_weight = self.conv1x1_2(attn_weight) + # bsz * (seq_len + 1) * h * w * 1 + bsz, T, h, w, c = attn_weight.size() + assert c == 1 + + if valid_ratios is not None: + # cal mask of attention weight + attn_mask = torch.zeros_like(attn_weight) + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, math.ceil(w * valid_ratio)) + attn_mask[i, :, :, valid_width:, :] = 1 + attn_weight = attn_weight.masked_fill(attn_mask.bool(), + float('-inf')) + + attn_weight = attn_weight.view(bsz, T, -1) + attn_weight = F.softmax(attn_weight, dim=-1) + attn_weight = attn_weight.view(bsz, T, h, w, + c).permute(0, 1, 4, 2, 3).contiguous() + + attn_feat = torch.sum( + torch.mul(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False) + # bsz * (seq_len + 1) * C + + # linear transformation + if self.pred_concat: + hf_c = holistic_feat.size(-1) + holistic_feat = holistic_feat.expand(bsz, seq_len, hf_c) + y = self.prediction(torch.cat((y, attn_feat, holistic_feat), 2)) + else: + y = self.prediction(attn_feat) + # bsz * (seq_len + 1) * num_classes + if self.train_mode: + y = self.pred_dropout(y) + + return y + + def forward_train(self, feat, out_enc, targets_dict, 
+                      img_metas):
+        if img_metas is not None:
+            assert utils.is_type_list(img_metas, dict)
+            assert len(img_metas) == feat.size(0)
+
+        valid_ratios = None
+        if img_metas is not None:
+            valid_ratios = [
+                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+            ] if self.mask else None
+
+        targets = targets_dict['padded_targets'].to(feat.device)
+        tgt_embedding = self.embedding(targets)
+        # bsz * seq_len * emb_dim
+        out_enc = out_enc.unsqueeze(1)
+        # bsz * 1 * emb_dim
+        in_dec = torch.cat((out_enc, tgt_embedding), dim=1)
+        # bsz * (seq_len + 1) * C
+        out_dec = self._2d_attention(
+            in_dec, feat, out_enc, valid_ratios=valid_ratios)
+        # bsz * (seq_len + 1) * num_classes
+
+        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes
+
+    def forward_test(self, feat, out_enc, img_metas):
+        if img_metas is not None:
+            assert utils.is_type_list(img_metas, dict)
+            assert len(img_metas) == feat.size(0)
+
+        valid_ratios = None
+        if img_metas is not None:
+            valid_ratios = [
+                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+            ] if self.mask else None
+
+        seq_len = self.max_seq_len
+
+        bsz = feat.size(0)
+        start_token = torch.full((bsz, ),
+                                 self.start_idx,
+                                 device=feat.device,
+                                 dtype=torch.long)
+        # bsz
+        start_token = self.embedding(start_token)
+        # bsz * emb_dim
+        start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
+        # bsz * seq_len * emb_dim
+        out_enc = out_enc.unsqueeze(1)
+        # bsz * 1 * emb_dim
+        decoder_input = torch.cat((out_enc, start_token), dim=1)
+        # bsz * (seq_len + 1) * emb_dim
+
+        outputs = []
+        for i in range(1, seq_len + 1):
+            decoder_output = self._2d_attention(
+                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
+            char_output = decoder_output[:, i, :]  # bsz * num_classes
+            char_output = F.softmax(char_output, -1)
+            outputs.append(char_output)
+            _, max_idx = torch.max(char_output, dim=1, keepdim=False)
+            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
+            if i < seq_len:
+                decoder_input[:, i + 1, :] = char_embedding
+
+        outputs = torch.stack(outputs, 1)  # bsz * seq_len * num_classes
+
+        return outputs
+
+
+@DECODERS.register_module()
+class SequentialSARDecoder(BaseDecoder):
+    """Implementation of the sequential decoder module in `SAR
+    <https://arxiv.org/abs/1811.00751>`_.
+
+    Args:
+        num_classes (int): Number of output classes.
+        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
+        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
+        dec_gru (bool): If True, use GRU, else LSTM in decoder.
+        d_k (int): Dim of conv layers in attention module.
+        d_model (int): Dim of channels from backbone.
+        d_enc (int): Dim of encoder RNN layer.
+        pred_dropout (float): Dropout probability of prediction layer.
+        max_seq_len (int): Maximum sequence length during decoding.
+        mask (bool): If True, mask padding in feature map.
+        start_idx (int): Index of start token.
+        padding_idx (int): Index of padding token.
+        pred_concat (bool): If True, concat glimpse feature from
+            attention with holistic feature and hidden state.
+ """ + + def __init__(self, + num_classes=37, + enc_bi_rnn=False, + dec_bi_rnn=False, + dec_gru=False, + d_k=64, + d_model=512, + d_enc=512, + pred_dropout=0.0, + mask=True, + max_seq_len=40, + start_idx=0, + padding_idx=92, + pred_concat=False, + **kwargs): + super().__init__() + + self.num_classes = num_classes + self.enc_bi_rnn = enc_bi_rnn + self.d_k = d_k + self.start_idx = start_idx + self.dec_gru = dec_gru + self.max_seq_len = max_seq_len + self.mask = mask + self.pred_concat = pred_concat + + encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) + decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1) + # 2D attention layer + self.conv1x1_1 = nn.Conv2d( + decoder_rnn_out_size, d_k, kernel_size=1, stride=1) + self.conv3x3_1 = nn.Conv2d( + d_model, d_k, kernel_size=3, stride=1, padding=1) + self.conv1x1_2 = nn.Conv2d(d_k, 1, kernel_size=1, stride=1) + + # Decoder rnn layer + if dec_gru: + self.rnn_decoder_layer1 = nn.GRUCell(encoder_rnn_out_size, + encoder_rnn_out_size) + self.rnn_decoder_layer2 = nn.GRUCell(encoder_rnn_out_size, + encoder_rnn_out_size) + else: + self.rnn_decoder_layer1 = nn.LSTMCell(encoder_rnn_out_size, + encoder_rnn_out_size) + self.rnn_decoder_layer2 = nn.LSTMCell(encoder_rnn_out_size, + encoder_rnn_out_size) + + # Decoder input embedding + self.embedding = nn.Embedding( + self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx) + + # Prediction layer + self.pred_dropout = nn.Dropout(pred_dropout) + pred_num_class = num_classes - 1 # ignore padding index + if pred_concat: + fc_in_channel = decoder_rnn_out_size + d_model + d_enc + else: + fc_in_channel = d_model + self.prediction = nn.Linear(fc_in_channel, pred_num_class) + + def _2d_attention(self, + y_prev, + feat, + holistic_feat, + hx1, + cx1, + hx2, + cx2, + valid_ratios=None): + _, _, h_feat, w_feat = feat.size() + if self.dec_gru: + hx1 = cx1 = self.rnn_decoder_layer1(y_prev, hx1) + hx2 = cx2 = self.rnn_decoder_layer2(hx1, hx2) + else: + hx1, cx1 = self.rnn_decoder_layer1(y_prev, (hx1, cx1)) + hx2, cx2 = self.rnn_decoder_layer2(hx1, (hx2, cx2)) + + tile_hx2 = hx2.view(hx2.size(0), hx2.size(1), 1, 1) + attn_query = self.conv1x1_1(tile_hx2) # bsz * attn_size * 1 * 1 + attn_query = attn_query.expand(-1, -1, h_feat, w_feat) + attn_key = self.conv3x3_1(feat) + attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1)) + attn_weight = self.conv1x1_2(attn_weight) + bsz, c, h, w = attn_weight.size() + assert c == 1 + + if valid_ratios is not None: + # cal mask of attention weight + attn_mask = torch.zeros_like(attn_weight) + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, math.ceil(w * valid_ratio)) + attn_mask[i, :, :, valid_width:] = 1 + attn_weight = attn_weight.masked_fill(attn_mask.bool(), + float('-inf')) + + attn_weight = F.softmax(attn_weight.view(bsz, -1), dim=-1) + attn_weight = attn_weight.view(bsz, c, h, w) + + attn_feat = torch.sum( + torch.mul(feat, attn_weight), (2, 3), keepdim=False) # n * c + + # linear transformation + if self.pred_concat: + y = self.prediction(torch.cat((hx2, attn_feat, holistic_feat), 1)) + else: + y = self.prediction(attn_feat) + + return y, hx1, hx1, hx2, hx2 + + def forward_train(self, feat, out_enc, targets_dict, img_metas=None): + if img_metas is not None: + assert utils.is_type_list(img_metas, dict) + assert len(img_metas) == feat.size(0) + + valid_ratios = None + if img_metas is not None: + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] if self.mask else None + + if self.train_mode: + 
+            targets = targets_dict['padded_targets'].to(feat.device)
+            tgt_embedding = self.embedding(targets)
+
+        outputs = []
+        start_token = torch.full((feat.size(0), ),
+                                 self.start_idx,
+                                 device=feat.device,
+                                 dtype=torch.long)
+        start_token = self.embedding(start_token)
+        for i in range(-1, self.max_seq_len):
+            if i == -1:
+                if self.dec_gru:
+                    hx1 = cx1 = self.rnn_decoder_layer1(out_enc)
+                    hx2 = cx2 = self.rnn_decoder_layer2(hx1)
+                else:
+                    hx1, cx1 = self.rnn_decoder_layer1(out_enc)
+                    hx2, cx2 = self.rnn_decoder_layer2(hx1)
+                if not self.train_mode:
+                    y_prev = start_token
+            else:
+                if self.train_mode:
+                    y_prev = tgt_embedding[:, i, :]
+                y, hx1, cx1, hx2, cx2 = self._2d_attention(
+                    y_prev,
+                    feat,
+                    out_enc,
+                    hx1,
+                    cx1,
+                    hx2,
+                    cx2,
+                    valid_ratios=valid_ratios)
+                if self.train_mode:
+                    y = self.pred_dropout(y)
+                else:
+                    y = F.softmax(y, -1)
+                    _, max_idx = torch.max(y, dim=1, keepdim=False)
+                    char_embedding = self.embedding(max_idx)
+                    y_prev = char_embedding
+                outputs.append(y)
+
+        outputs = torch.stack(outputs, 1)
+
+        return outputs
+
+    def forward_test(self, feat, out_enc, img_metas):
+        if img_metas is not None:
+            assert utils.is_type_list(img_metas, dict)
+            assert len(img_metas) == feat.size(0)
+
+        return self.forward_train(feat, out_enc, None, img_metas)
diff --git a/mmocr/models/textrecog/decoders/sar_decoder_with_bs.py b/mmocr/models/textrecog/decoders/sar_decoder_with_bs.py
new file mode 100755
index 00000000..98094dd4
--- /dev/null
+++ b/mmocr/models/textrecog/decoders/sar_decoder_with_bs.py
@@ -0,0 +1,148 @@
+from queue import PriorityQueue
+
+import torch
+import torch.nn.functional as F
+
+import mmocr.utils as utils
+from mmocr.models.builder import DECODERS
+from . import ParallelSARDecoder
+
+
+class DecodeNode:
+    """Node class to save decoded char indices and scores.
+
+    Args:
+        indexes (list[int]): Char indices decoded so far.
+        scores (list[float]): Scores of the char indices decoded so far.
+    """
+
+    def __init__(self, indexes=None, scores=None):
+        # Use None defaults to avoid the shared-mutable-default pitfall.
+        indexes = [1] if indexes is None else indexes
+        scores = [0.9] if scores is None else scores
+        assert utils.is_type_list(indexes, int)
+        assert utils.is_type_list(scores, float)
+        assert utils.equal_len(indexes, scores)
+
+        self.indexes = indexes
+        self.scores = scores
+
+    def eval(self):
+        """Calculate accumulated score."""
+        accu_score = sum(self.scores)
+        return accu_score
+
+
+@DECODERS.register_module()
+class ParallelSARDecoderWithBS(ParallelSARDecoder):
+    """Parallel decoder module with beam search in SAR.
+
+    Args:
+        beam_width (int): Width for beam search.
+    """
+
+    def __init__(self,
+                 beam_width=5,
+                 num_classes=37,
+                 enc_bi_rnn=False,
+                 dec_bi_rnn=False,
+                 dec_do_rnn=0,
+                 dec_gru=False,
+                 d_model=512,
+                 d_enc=512,
+                 d_k=64,
+                 pred_dropout=0.0,
+                 max_seq_len=40,
+                 mask=True,
+                 start_idx=0,
+                 padding_idx=0,
+                 pred_concat=False,
+                 **kwargs):
+        super().__init__(num_classes, enc_bi_rnn, dec_bi_rnn, dec_do_rnn,
+                         dec_gru, d_model, d_enc, d_k, pred_dropout,
+                         max_seq_len, mask, start_idx, padding_idx,
+                         pred_concat)
+        assert isinstance(beam_width, int)
+        assert beam_width > 0
+
+        self.beam_width = beam_width
+
+    def forward_test(self, feat, out_enc, img_metas):
+        assert utils.is_type_list(img_metas, dict)
+        assert len(img_metas) == feat.size(0)
+
+        valid_ratios = [
+            img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+        ] if self.mask else None
+
+        seq_len = self.max_seq_len
+        bsz = feat.size(0)
+        assert bsz == 1, 'batch size must be 1 for beam search.'
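The search below keeps hypotheses in a `PriorityQueue`. Since Python's queue pops the smallest entry first, accumulated scores are pushed negated (plus a tiny `delta` tie-breaker, because `DecodeNode` itself is not comparable). A minimal self-contained sketch of that bookkeeping, with made-up scores:

```python
from queue import PriorityQueue

# Push (-accumulated_score, hypothesis); get() then pops the
# highest-scoring hypothesis first. Scores here are made up.
q = PriorityQueue()
for score, hyp in [(0.9, 'a'), (1.4, 'ab'), (0.2, 'b')]:
    q.put((-score, hyp))

neg_score, best_hyp = q.get()
print(-neg_score, best_hyp)  # 1.4 ab
```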
+ + start_token = torch.full((bsz, ), + self.start_idx, + device=feat.device, + dtype=torch.long) + # bsz + start_token = self.embedding(start_token) + # bsz * emb_dim + start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1) + # bsz * seq_len * emb_dim + out_enc = out_enc.unsqueeze(1) + # bsz * 1 * emb_dim + decoder_input = torch.cat((out_enc, start_token), dim=1) + # bsz * (seq_len + 1) * emb_dim + + # Initialize beam-search queue + q = PriorityQueue() + init_node = DecodeNode([self.start_idx], [0.0]) + q.put((-init_node.eval(), init_node)) + + for i in range(1, seq_len + 1): + next_nodes = [] + beam_width = self.beam_width if i > 1 else 1 + for _ in range(beam_width): + _, node = q.get() + + input_seq = torch.clone(decoder_input) # bsz * T * emb_dim + # fill previous input tokens (step 1...i) in input_seq + for t, index in enumerate(node.indexes): + input_token = torch.full((bsz, ), + index, + device=input_seq.device, + dtype=torch.long) + input_token = self.embedding(input_token) # bsz * emb_dim + input_seq[:, t + 1, :] = input_token + + output_seq = self._2d_attention( + input_seq, feat, out_enc, valid_ratios=valid_ratios) + + output_char = output_seq[:, i, :] # bsz * num_classes + output_char = F.softmax(output_char, -1) + topk_value, topk_idx = output_char.topk(self.beam_width, dim=1) + topk_value, topk_idx = topk_value.squeeze(0), topk_idx.squeeze( + 0) + + for k in range(self.beam_width): + kth_score = topk_value[k].item() + kth_idx = topk_idx[k].item() + next_node = DecodeNode(node.indexes + [kth_idx], + node.scores + [kth_score]) + delta = k * 1e-6 + next_nodes.append( + (-node.eval() - kth_score - delta, next_node)) + # Use minus since priority queue sort + # with ascending order + + while not q.empty(): + q.get() + + # Put all candidates to queue + for next_node in next_nodes: + q.put(next_node) + + best_node = q.get() + num_classes = self.num_classes - 1 # ignore padding index + outputs = torch.zeros(bsz, seq_len, num_classes) + for i in range(seq_len): + idx = best_node[1].indexes[i + 1] + outputs[0, i, idx] = best_node[1].scores[i + 1] + + return outputs diff --git a/mmocr/models/textrecog/decoders/sequence_attention_decoder.py b/mmocr/models/textrecog/decoders/sequence_attention_decoder.py new file mode 100644 index 00000000..6e7aa9e1 --- /dev/null +++ b/mmocr/models/textrecog/decoders/sequence_attention_decoder.py @@ -0,0 +1,165 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmocr.models.builder import DECODERS +from mmocr.models.textrecog.layers import DotProductAttentionLayer +from .base_decoder import BaseDecoder + + +@DECODERS.register_module() +class SequenceAttentionDecoder(BaseDecoder): + + def __init__(self, + num_classes=None, + rnn_layers=2, + dim_input=512, + dim_model=128, + max_seq_len=40, + start_idx=0, + mask=True, + padding_idx=None, + dropout_ratio=0, + return_feature=False, + encode_value=False): + super().__init__() + + self.num_classes = num_classes + self.dim_input = dim_input + self.dim_model = dim_model + self.return_feature = return_feature + self.encode_value = encode_value + self.max_seq_len = max_seq_len + self.start_idx = start_idx + self.mask = mask + + self.embedding = nn.Embedding( + self.num_classes, self.dim_model, padding_idx=padding_idx) + + self.sequence_layer = nn.LSTM( + input_size=dim_model, + hidden_size=dim_model, + num_layers=rnn_layers, + batch_first=True, + dropout=dropout_ratio) + + self.attention_layer = DotProductAttentionLayer() + + self.prediction = None + if not 
self.return_feature: + pred_num_classes = num_classes - 1 + self.prediction = nn.Linear( + dim_model if encode_value else dim_input, pred_num_classes) + + def init_weights(self): + pass + + def forward_train(self, feat, out_enc, targets_dict, img_metas): + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] if self.mask else None + + targets = targets_dict['padded_targets'].to(feat.device) + tgt_embedding = self.embedding(targets) + + n, c_enc, h, w = out_enc.size() + assert c_enc == self.dim_model + _, c_feat, _, _ = feat.size() + assert c_feat == self.dim_input + _, len_q, c_q = tgt_embedding.size() + assert c_q == self.dim_model + assert len_q <= self.max_seq_len + + query, _ = self.sequence_layer(tgt_embedding) + query = query.permute(0, 2, 1).contiguous() + key = out_enc.view(n, c_enc, h * w) + if self.encode_value: + value = key + else: + value = feat.view(n, c_feat, h * w) + + mask = None + if valid_ratios is not None: + mask = query.new_zeros((n, h, w)) + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, math.ceil(w * valid_ratio)) + mask[i, :, valid_width:] = 1 + mask = mask.bool() + mask = mask.view(n, h * w) + + attn_out = self.attention_layer(query, key, value, mask) + attn_out = attn_out.permute(0, 2, 1).contiguous() + + if self.return_feature: + return attn_out + + out = self.prediction(attn_out) + + return out + + def forward_test(self, feat, out_enc, img_metas): + seq_len = self.max_seq_len + batch_size = feat.size(0) + + decode_sequence = (feat.new_ones( + (batch_size, seq_len)) * self.start_idx).long() + + outputs = [] + for i in range(seq_len): + step_out = self.forward_test_step(feat, out_enc, decode_sequence, + i, img_metas) + outputs.append(step_out) + _, max_idx = torch.max(step_out, dim=1, keepdim=False) + if i < seq_len - 1: + decode_sequence[:, i + 1] = max_idx + + outputs = torch.stack(outputs, 1) + + return outputs + + def forward_test_step(self, feat, out_enc, decode_sequence, current_step, + img_metas): + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] if self.mask else None + + embed = self.embedding(decode_sequence) + + n, c_enc, h, w = out_enc.size() + assert c_enc == self.dim_model + _, c_feat, _, _ = feat.size() + assert c_feat == self.dim_input + _, _, c_q = embed.size() + assert c_q == self.dim_model + + query, _ = self.sequence_layer(embed) + query = query.permute(0, 2, 1).contiguous() + key = out_enc.view(n, c_enc, h * w) + if self.encode_value: + value = key + else: + value = feat.view(n, c_feat, h * w) + + mask = None + if valid_ratios is not None: + mask = query.new_zeros((n, h, w)) + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, math.ceil(w * valid_ratio)) + mask[i, :, valid_width:] = 1 + mask = mask.bool() + mask = mask.view(n, h * w) + + # [n, c, l] + attn_out = self.attention_layer(query, key, value, mask) + + out = attn_out[:, :, current_step] + + if self.return_feature: + return out + + out = self.prediction(out) + out = F.softmax(out, dim=-1) + + return out diff --git a/mmocr/models/textrecog/decoders/transformer_decoder.py b/mmocr/models/textrecog/decoders/transformer_decoder.py new file mode 100644 index 00000000..164cdae9 --- /dev/null +++ b/mmocr/models/textrecog/decoders/transformer_decoder.py @@ -0,0 +1,133 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmocr.models.builder import DECODERS +from mmocr.models.textrecog.layers.transformer_layer import ( + PositionalEncoding, 
TransformerDecoderLayer, get_pad_mask, + get_subsequent_mask) +from .base_decoder import BaseDecoder + + +@DECODERS.register_module() +class TFDecoder(BaseDecoder): + """Transformer Decoder block with self attention mechanism.""" + + def __init__(self, + n_layers=6, + d_embedding=512, + n_head=8, + d_k=64, + d_v=64, + d_model=512, + d_inner=256, + n_position=200, + dropout=0.1, + num_classes=93, + max_seq_len=40, + start_idx=1, + padding_idx=92, + **kwargs): + super().__init__() + + self.padding_idx = padding_idx + self.start_idx = start_idx + self.max_seq_len = max_seq_len + + self.trg_word_emb = nn.Embedding( + num_classes, d_embedding, padding_idx=padding_idx) + + self.position_enc = PositionalEncoding( + d_embedding, n_position=n_position) + self.dropout = nn.Dropout(p=dropout) + + self.layer_stack = nn.ModuleList([ + TransformerDecoderLayer( + d_model, d_inner, n_head, d_k, d_v, dropout=dropout) + for _ in range(n_layers) + ]) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + pred_num_class = num_classes - 1 # ignore padding_idx + self.classifier = nn.Linear(d_model, pred_num_class) + + def _attention(self, trg_seq, src, src_mask=None): + trg_embedding = self.trg_word_emb(trg_seq) + trg_pos_encoded = self.position_enc(trg_embedding) + tgt = self.dropout(trg_pos_encoded) + + trg_mask = get_pad_mask( + trg_seq, pad_idx=self.padding_idx) & get_subsequent_mask(trg_seq) + output = tgt + for dec_layer in self.layer_stack: + output = dec_layer( + output, + src, + self_attn_mask=trg_mask, + dec_enc_attn_mask=src_mask) + output = self.layer_norm(output) + + return output + + def forward_train(self, feat, out_enc, targets_dict, img_metas): + valid_ratios = None + if img_metas is not None: + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] + n, c, h, w = out_enc.size() + src_mask = None + if valid_ratios is not None: + src_mask = out_enc.new_zeros((n, h, w)) + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, math.ceil(w * valid_ratio)) + src_mask[i, :, :valid_width] = 1 + src_mask = src_mask.view(n, h * w) + out_enc = out_enc.view(n, c, h * w).permute(0, 2, 1) + out_enc = out_enc.contiguous() + targets = targets_dict['padded_targets'].to(out_enc.device) + attn_output = self._attention(targets, out_enc, src_mask=src_mask) + outputs = self.classifier(attn_output) + return outputs + + def forward_test(self, feat, out_enc, img_metas): + valid_ratios = None + if img_metas is not None: + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] + n, c, h, w = out_enc.size() + src_mask = None + if valid_ratios is not None: + src_mask = out_enc.new_zeros((n, h, w)) + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, math.ceil(w * valid_ratio)) + src_mask[i, :, :valid_width] = 1 + src_mask = src_mask.view(n, h * w) + out_enc = out_enc.view(n, c, h * w).permute(0, 2, 1) + out_enc = out_enc.contiguous() + + init_target_seq = torch.full((n, self.max_seq_len + 1), + self.padding_idx, + device=out_enc.device, + dtype=torch.long) + # bsz * seq_len + init_target_seq[:, 0] = self.start_idx + + outputs = [] + for step in range(0, self.max_seq_len): + decoder_output = self._attention( + init_target_seq, out_enc, src_mask=src_mask) + # bsz * seq_len * 512 + step_result = F.softmax( + self.classifier(decoder_output[:, step, :]), dim=-1) + # bsz * num_classes + outputs.append(step_result) + _, step_max_index = torch.max(step_result, dim=-1) + init_target_seq[:, step + 1] = step_max_index + + outputs = 
torch.stack(outputs, dim=1)
+
+        return outputs
diff --git a/mmocr/models/textrecog/encoders/__init__.py b/mmocr/models/textrecog/encoders/__init__.py
new file mode 100755
index 00000000..e0d9394a
--- /dev/null
+++ b/mmocr/models/textrecog/encoders/__init__.py
@@ -0,0 +1,6 @@
+from .base_encoder import BaseEncoder
+from .channel_reduction_encoder import ChannelReductionEncoder
+from .sar_encoder import SAREncoder
+from .transformer_encoder import TFEncoder
+
+__all__ = ['SAREncoder', 'TFEncoder', 'BaseEncoder', 'ChannelReductionEncoder']
diff --git a/mmocr/models/textrecog/encoders/base_encoder.py b/mmocr/models/textrecog/encoders/base_encoder.py
new file mode 100644
index 00000000..3dadc687
--- /dev/null
+++ b/mmocr/models/textrecog/encoders/base_encoder.py
@@ -0,0 +1,14 @@
+import torch.nn as nn
+
+from mmocr.models.builder import ENCODERS
+
+
+@ENCODERS.register_module()
+class BaseEncoder(nn.Module):
+    """Base Encoder class for text recognition."""
+
+    def init_weights(self):
+        pass
+
+    def forward(self, feat, **kwargs):
+        return feat
diff --git a/mmocr/models/textrecog/encoders/channel_reduction_encoder.py b/mmocr/models/textrecog/encoders/channel_reduction_encoder.py
new file mode 100644
index 00000000..0eae4c14
--- /dev/null
+++ b/mmocr/models/textrecog/encoders/channel_reduction_encoder.py
@@ -0,0 +1,23 @@
+import torch.nn as nn
+from mmcv.cnn import xavier_init
+
+from mmocr.models.builder import ENCODERS
+from .base_encoder import BaseEncoder
+
+
+@ENCODERS.register_module()
+class ChannelReductionEncoder(BaseEncoder):
+
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+
+        self.layer = nn.Conv2d(
+            in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                xavier_init(m)
+
+    def forward(self, feat, img_metas=None):
+        return self.layer(feat)
diff --git a/mmocr/models/textrecog/encoders/sar_encoder.py b/mmocr/models/textrecog/encoders/sar_encoder.py
new file mode 100644
index 00000000..342472a6
--- /dev/null
+++ b/mmocr/models/textrecog/encoders/sar_encoder.py
@@ -0,0 +1,105 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import uniform_init, xavier_init
+
+import mmocr.utils as utils
+from mmocr.models.builder import ENCODERS
+from .base_encoder import BaseEncoder
+
+
+@ENCODERS.register_module()
+class SAREncoder(BaseEncoder):
+    """Implementation of the encoder module in `SAR
+    <https://arxiv.org/abs/1811.00751>`_.
+
+    Args:
+        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
+        enc_do_rnn (float): Dropout probability of RNN layer in encoder.
+        enc_gru (bool): If True, use GRU, else LSTM in encoder.
+        d_model (int): Dim of channels from backbone.
+        d_enc (int): Dim of encoder RNN layer.
+        mask (bool): If True, mask padding in RNN sequence.
+ """ + + def __init__(self, + enc_bi_rnn=False, + enc_do_rnn=0.0, + enc_gru=False, + d_model=512, + d_enc=512, + mask=True, + **kwargs): + super().__init__() + assert isinstance(enc_bi_rnn, bool) + assert isinstance(enc_do_rnn, (int, float)) + assert 0 <= enc_do_rnn < 1.0 + assert isinstance(enc_gru, bool) + assert isinstance(d_model, int) + assert isinstance(d_enc, int) + assert isinstance(mask, bool) + + self.enc_bi_rnn = enc_bi_rnn + self.enc_do_rnn = enc_do_rnn + self.mask = mask + + # LSTM Encoder + kwargs = dict( + input_size=d_model, + hidden_size=d_enc, + num_layers=2, + batch_first=True, + dropout=enc_do_rnn, + bidirectional=enc_bi_rnn) + if enc_gru: + self.rnn_encoder = nn.GRU(**kwargs) + else: + self.rnn_encoder = nn.LSTM(**kwargs) + + # global feature transformation + encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) + self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size) + + def init_weights(self): + # initialize weight and bias + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m) + elif isinstance(m, nn.BatchNorm2d): + uniform_init(m) + + def forward(self, feat, img_metas=None): + if img_metas is not None: + assert utils.is_type_list(img_metas, dict) + assert len(img_metas) == feat.size(0) + + valid_ratios = None + if img_metas is not None: + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] if self.mask else None + + h_feat = feat.size(2) + feat_v = F.max_pool2d( + feat, kernel_size=(h_feat, 1), stride=1, padding=0) + feat_v = feat_v.squeeze(2) # bsz * C * W + feat_v = feat_v.permute(0, 2, 1).contiguous() # bsz * W * C + + holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C + + if valid_ratios is not None: + valid_hf = [] + T = holistic_feat.size(1) + for i, valid_ratio in enumerate(valid_ratios): + valid_step = min(T, math.ceil(T * valid_ratio)) - 1 + valid_hf.append(holistic_feat[i, valid_step, :]) + valid_hf = torch.stack(valid_hf, dim=0) + else: + valid_hf = holistic_feat[:, -1, :] # bsz * C + + holistic_feat = self.linear(valid_hf) # bsz * C + + return holistic_feat diff --git a/mmocr/models/textrecog/encoders/transformer_encoder.py b/mmocr/models/textrecog/encoders/transformer_encoder.py new file mode 100644 index 00000000..c07b7754 --- /dev/null +++ b/mmocr/models/textrecog/encoders/transformer_encoder.py @@ -0,0 +1,54 @@ +import math + +import torch.nn as nn + +from mmocr.models.builder import ENCODERS +from mmocr.models.textrecog.layers import TransformerEncoderLayer +from .base_encoder import BaseEncoder + + +@ENCODERS.register_module() +class TFEncoder(BaseEncoder): + """Encode 2d feature map to 1d sequence.""" + + def __init__(self, + n_layers=6, + n_head=8, + d_k=64, + d_v=64, + d_model=512, + d_inner=256, + dropout=0.1, + **kwargs): + super().__init__() + self.d_model = d_model + self.layer_stack = nn.ModuleList([ + TransformerEncoderLayer( + d_model, d_inner, n_head, d_k, d_v, dropout=dropout) + for _ in range(n_layers) + ]) + self.layer_norm = nn.LayerNorm(d_model) + + def forward(self, feat, img_metas=None): + valid_ratios = [1.0 for _ in range(feat.size(0))] + if img_metas is not None: + valid_ratios = [ + img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + ] + n, c, h, w = feat.size() + mask = feat.new_zeros((n, h, w)) + for i, valid_ratio in enumerate(valid_ratios): + valid_width = min(w, math.ceil(w * valid_ratio)) + mask[i, :, :valid_width] = 1 + mask = mask.view(n, h * w) + feat = feat.view(n, c, h * w) + + output = feat.permute(0, 2, 1).contiguous() + for enc_layer 
in self.layer_stack: + output = enc_layer(output, mask) + output = self.layer_norm(output) + + output = output.permute(0, 2, 1).contiguous() + output = output.view(n, self.d_model, h, w) + + return output diff --git a/mmocr/models/textrecog/heads/__init__.py b/mmocr/models/textrecog/heads/__init__.py new file mode 100755 index 00000000..761bb9a9 --- /dev/null +++ b/mmocr/models/textrecog/heads/__init__.py @@ -0,0 +1,3 @@ +from .seg_head import SegHead + +__all__ = ['SegHead'] diff --git a/mmocr/models/textrecog/heads/seg_head.py b/mmocr/models/textrecog/heads/seg_head.py new file mode 100644 index 00000000..a0ca59cc --- /dev/null +++ b/mmocr/models/textrecog/heads/seg_head.py @@ -0,0 +1,50 @@ +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from torch import nn + +from mmdet.models.builder import HEADS + + +@HEADS.register_module() +class SegHead(nn.Module): + """Head for segmentation based text recognition. + + Args: + in_channels (int): Number of input channels. + num_classes (int): Number of output classes. + upsample_param (dict | None): Config dict for interpolation layer. + Default: `dict(scale_factor=1.0, mode='nearest')` + """ + + def __init__(self, in_channels=128, num_classes=37, upsample_param=None): + super().__init__() + assert isinstance(num_classes, int) + assert num_classes > 0 + assert upsample_param is None or isinstance(upsample_param, dict) + + self.upsample_param = upsample_param + + self.seg_conv = ConvModule( + in_channels, + in_channels, + 3, + stride=1, + padding=1, + norm_cfg=dict(type='BN')) + + # prediction + self.pred_conv = nn.Conv2d( + in_channels, num_classes, kernel_size=1, stride=1, padding=0) + + def init_weights(self): + pass + + def forward(self, out_neck): + + seg_map = self.seg_conv(out_neck[-1]) + seg_map = self.pred_conv(seg_map) + + if self.upsample_param is not None: + seg_map = F.interpolate(seg_map, **self.upsample_param) + + return seg_map diff --git a/mmocr/models/textrecog/layers/__init__.py b/mmocr/models/textrecog/layers/__init__.py new file mode 100755 index 00000000..7d85a865 --- /dev/null +++ b/mmocr/models/textrecog/layers/__init__.py @@ -0,0 +1,18 @@ +from .conv_layer import BasicBlock, Bottleneck +from .dot_product_attention_layer import DotProductAttentionLayer +from .lstm_layer import BidirectionalLSTM +from .position_aware_layer import PositionAwareLayer +from .robust_scanner_fusion_layer import RobustScannerFusionLayer +from .transformer_layer import (MultiHeadAttention, PositionalEncoding, + PositionwiseFeedForward, + TransformerDecoderLayer, + TransformerEncoderLayer, get_pad_mask, + get_subsequent_mask) + +__all__ = [ + 'BidirectionalLSTM', 'MultiHeadAttention', 'PositionalEncoding', + 'PositionwiseFeedForward', 'BasicBlock', 'Bottleneck', + 'RobustScannerFusionLayer', 'DotProductAttentionLayer', + 'PositionAwareLayer', 'get_pad_mask', 'get_subsequent_mask', + 'TransformerDecoderLayer', 'TransformerEncoderLayer' +] diff --git a/mmocr/models/textrecog/layers/conv_layer.py b/mmocr/models/textrecog/layers/conv_layer.py new file mode 100644 index 00000000..d0ce32a3 --- /dev/null +++ b/mmocr/models/textrecog/layers/conv_layer.py @@ -0,0 +1,93 @@ +import torch.nn as nn + + +def conv3x3(in_planes, out_planes, stride=1): + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=False): + super().__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + 
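As a quick shape check on the `conv3x3` helper defined above (a sketch; it assumes this module is importable as written):

```python
import torch

from mmocr.models.textrecog.layers.conv_layer import conv3x3

# conv3x3 uses padding=1, so stride 1 preserves the spatial size and
# stride 2 halves it (rounding up), as the residual blocks below expect
# when downsampling.
conv_s1 = conv3x3(64, 64, stride=1)
conv_s2 = conv3x3(64, 128, stride=2)
x = torch.randn(1, 64, 32, 100)
print(conv_s1(x).shape)  # torch.Size([1, 64, 32, 100])
print(conv_s2(x).shape)  # torch.Size([1, 128, 16, 50])
```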
self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + if downsample: + self.downsample = nn.Sequential( + nn.Conv2d( + inplanes, planes * self.expansion, 1, stride, bias=False), + nn.BatchNorm2d(planes * self.expansion), + ) + else: + self.downsample = nn.Sequential() + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=False): + super().__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, 3, stride, 1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + if downsample: + self.downsample = nn.Sequential( + nn.Conv2d( + inplanes, planes * self.expansion, 1, stride, bias=False), + nn.BatchNorm2d(planes * self.expansion), + ) + else: + self.downsample = nn.Sequential() + + def forward(self, x): + residual = self.downsample(x) + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out diff --git a/mmocr/models/textrecog/layers/dot_product_attention_layer.py b/mmocr/models/textrecog/layers/dot_product_attention_layer.py new file mode 100644 index 00000000..efa55a8c --- /dev/null +++ b/mmocr/models/textrecog/layers/dot_product_attention_layer.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DotProductAttentionLayer(nn.Module): + + def __init__(self, dim_model=None): + super().__init__() + + self.scale = dim_model**-0.5 if dim_model is not None else 1. 
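This layer reduces to masked scaled dot-product attention over flattened spatial positions. A small shape sketch for the forward pass that follows (illustrative tensors and sizes):

```python
import torch

from mmocr.models.textrecog.layers import DotProductAttentionLayer

# query: (N, C, L_q); key/value: (N, C, L_k); mask: (N, L_k) with True at
# padded positions. The returned glimpse is (N, C, L_q).
layer = DotProductAttentionLayer()
n, c, l_q, l_k = 2, 128, 40, 48
query = torch.randn(n, c, l_q)
key = torch.randn(n, c, l_k)
value = torch.randn(n, c, l_k)
mask = torch.zeros(n, l_k, dtype=torch.bool)
mask[:, 40:] = True  # e.g. positions beyond the valid width
print(layer(query, key, value, mask).shape)  # torch.Size([2, 128, 40])
```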
+
+    def forward(self, query, key, value, mask=None):
+        logits = torch.matmul(query.permute(0, 2, 1), key) * self.scale
+
+        # Only inspect the mask's size when a mask is actually given,
+        # so that mask=None does not crash.
+        if mask is not None:
+            n, seq_len = mask.size()
+            mask = mask.view(n, 1, seq_len)
+            logits = logits.masked_fill(mask, float('-inf'))
+
+        weights = F.softmax(logits, dim=2)
+
+        glimpse = torch.matmul(weights, value.transpose(1, 2))
+
+        glimpse = glimpse.permute(0, 2, 1).contiguous()
+
+        return glimpse
diff --git a/mmocr/models/textrecog/layers/lstm_layer.py b/mmocr/models/textrecog/layers/lstm_layer.py
new file mode 100644
index 00000000..e4017d02
--- /dev/null
+++ b/mmocr/models/textrecog/layers/lstm_layer.py
@@ -0,0 +1,20 @@
+import torch.nn as nn
+
+
+class BidirectionalLSTM(nn.Module):
+
+    def __init__(self, nIn, nHidden, nOut):
+        super().__init__()
+
+        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
+        self.embedding = nn.Linear(nHidden * 2, nOut)
+
+    def forward(self, input):
+        recurrent, _ = self.rnn(input)
+        T, b, h = recurrent.size()
+        t_rec = recurrent.view(T * b, h)
+
+        output = self.embedding(t_rec)  # [T * b, nOut]
+        output = output.view(T, b, -1)
+
+        return output
diff --git a/mmocr/models/textrecog/layers/position_aware_layer.py b/mmocr/models/textrecog/layers/position_aware_layer.py
new file mode 100644
index 00000000..cf8cf27d
--- /dev/null
+++ b/mmocr/models/textrecog/layers/position_aware_layer.py
@@ -0,0 +1,35 @@
+import torch.nn as nn
+
+
+class PositionAwareLayer(nn.Module):
+
+    def __init__(self, dim_model, rnn_layers=2):
+        super().__init__()
+
+        self.dim_model = dim_model
+
+        self.rnn = nn.LSTM(
+            input_size=dim_model,
+            hidden_size=dim_model,
+            num_layers=rnn_layers,
+            batch_first=True)
+
+        self.mixer = nn.Sequential(
+            nn.Conv2d(
+                dim_model, dim_model, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(
+                dim_model, dim_model, kernel_size=3, stride=1, padding=1))
+
+    def forward(self, img_feature):
+        n, c, h, w = img_feature.size()
+
+        rnn_input = img_feature.permute(0, 2, 3, 1).contiguous()
+        rnn_input = rnn_input.view(n * h, w, c)
+        rnn_output, _ = self.rnn(rnn_input)
+        rnn_output = rnn_output.view(n, h, w, c)
+        rnn_output = rnn_output.permute(0, 3, 1, 2).contiguous()
+
+        out = self.mixer(rnn_output)
+
+        return out
diff --git a/mmocr/models/textrecog/layers/robust_scanner_fusion_layer.py b/mmocr/models/textrecog/layers/robust_scanner_fusion_layer.py
new file mode 100644
index 00000000..30a8421d
--- /dev/null
+++ b/mmocr/models/textrecog/layers/robust_scanner_fusion_layer.py
@@ -0,0 +1,22 @@
+import torch
+import torch.nn as nn
+
+
+class RobustScannerFusionLayer(nn.Module):
+
+    def __init__(self, dim_model, dim=-1):
+        super().__init__()
+
+        self.dim_model = dim_model
+        self.dim = dim
+
+        self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2)
+        self.glu_layer = nn.GLU(dim=dim)
+
+    def forward(self, x0, x1):
+        assert x0.size() == x1.size()
+        fusion_input = torch.cat([x0, x1], self.dim)
+        output = self.linear_layer(fusion_input)
+        output = self.glu_layer(output)
+
+        return output
diff --git a/mmocr/models/textrecog/layers/transformer_layer.py b/mmocr/models/textrecog/layers/transformer_layer.py
new file mode 100644
index 00000000..8377827b
--- /dev/null
+++ b/mmocr/models/textrecog/layers/transformer_layer.py
@@ -0,0 +1,230 @@
+"""This code is from https://github.com/jadore801120/attention-is-all-you-need-
+pytorch."""
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+class TransformerEncoderLayer(nn.Module):
+    """Transformer encoder layer with pre-norm residual connections."""
+
+    def __init__(self,
+                 d_model=512,
+                 d_inner=256,
+                 n_head=8,
+                 d_k=64,
+                 d_v=64,
dropout=0.1, + qkv_bias=False, + mask_value=0, + act_layer=nn.GELU): + super().__init__() + self.norm1 = nn.LayerNorm(d_model) + self.attn = MultiHeadAttention( + n_head, + d_model, + d_k, + d_v, + qkv_bias=qkv_bias, + dropout=dropout, + mask_value=mask_value) + self.norm2 = nn.LayerNorm(d_model) + self.mlp = PositionwiseFeedForward( + d_model, d_inner, dropout=dropout, act_layer=act_layer) + + def forward(self, x, mask=None): + residual = x + x = self.norm1(x) + x = residual + self.attn(x, x, x, mask) + residual = x + x = self.norm2(x) + x = residual + self.mlp(x) + + return x + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, + d_model=512, + d_inner=256, + n_head=8, + d_k=64, + d_v=64, + dropout=0.1, + qkv_bias=False, + mask_value=0, + act_layer=nn.GELU): + super().__init__() + self.self_attn = MultiHeadAttention() + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + + self.self_attn = MultiHeadAttention( + n_head, + d_model, + d_k, + d_v, + dropout=dropout, + qkv_bias=qkv_bias, + mask_value=mask_value) + self.enc_attn = MultiHeadAttention( + n_head, + d_model, + d_k, + d_v, + dropout=dropout, + qkv_bias=qkv_bias, + mask_value=mask_value) + self.mlp = PositionwiseFeedForward( + d_model, d_inner, dropout=dropout, act_layer=act_layer) + + def forward(self, + dec_input, + enc_output, + self_attn_mask=None, + dec_enc_attn_mask=None): + self_attn_in = self.norm1(dec_input) + self_attn_out = self.self_attn(self_attn_in, self_attn_in, + self_attn_in, self_attn_mask) + enc_attn_in = dec_input + self_attn_out + + enc_attn_q = self.norm2(enc_attn_in) + enc_attn_out = self.enc_attn(enc_attn_q, enc_output, enc_output, + dec_enc_attn_mask) + + mlp_in = enc_attn_in + enc_attn_out + mlp_out = self.mlp(self.norm3(mlp_in)) + out = mlp_in + mlp_out + + return out + + +class MultiHeadAttention(nn.Module): + """Multi-Head Attention module.""" + + def __init__(self, + n_head=8, + d_model=512, + d_k=64, + d_v=64, + dropout=0.1, + qkv_bias=False, + mask_value=0): + super().__init__() + + self.mask_value = mask_value + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + self.scale = d_k**-0.5 + + self.dim_k = n_head * d_k + self.dim_v = n_head * d_v + + self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias) + + self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias=qkv_bias) + + self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias=qkv_bias) + + self.fc = nn.Linear(self.dim_v, d_model, bias=qkv_bias) + + self.attn_drop = nn.Dropout(dropout) + self.proj_drop = nn.Dropout(dropout) + + def forward(self, q, k, v, mask=None): + batch_size, len_q, _ = q.size() + _, len_k, _ = k.size() + + q = self.linear_q(q).view(batch_size, len_q, self.n_head, self.d_k) + k = self.linear_k(k).view(batch_size, len_k, self.n_head, self.d_k) + v = self.linear_v(v).view(batch_size, len_k, self.n_head, self.d_v) + + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 3, 1) + v = v.permute(0, 2, 1, 3) + + logits = torch.matmul(q, k) * self.scale + + if mask is not None: + if mask.dim() == 3: + mask = mask.unsqueeze(1) + elif mask.dim() == 2: + mask = mask.unsqueeze(1).unsqueeze(1) + logits = logits.masked_fill(mask == self.mask_value, float('-inf')) + weights = logits.softmax(dim=-1) + weights = self.attn_drop(weights) + + attn_out = torch.matmul(weights, v).transpose(1, 2) + attn_out = attn_out.reshape(batch_size, len_q, self.dim_v) + attn_out = self.fc(attn_out) + attn_out = self.proj_drop(attn_out) + + return attn_out + + +class 
PositionwiseFeedForward(nn.Module):
+    """A two-feed-forward-layer module."""
+
+    def __init__(self, d_in, d_hid, dropout=0.1, act_layer=nn.GELU):
+        super().__init__()
+        self.w_1 = nn.Linear(d_in, d_hid)
+        self.w_2 = nn.Linear(d_hid, d_in)
+        self.act = act_layer()
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x = self.w_1(x)
+        x = self.act(x)
+        x = self.dropout(x)
+        x = self.w_2(x)
+        x = self.dropout(x)
+
+        return x
+
+
+class PositionalEncoding(nn.Module):
+
+    def __init__(self, d_hid=512, n_position=200):
+        super().__init__()
+
+        # Not a parameter
+        self.register_buffer(
+            'position_table',
+            self._get_sinusoid_encoding_table(n_position, d_hid))
+
+    def _get_sinusoid_encoding_table(self, n_position, d_hid):
+        """Sinusoid position encoding table."""
+        denominator = torch.Tensor([
+            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
+            for hid_j in range(d_hid)
+        ])
+        denominator = denominator.view(1, -1)
+        pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
+        sinusoid_table = pos_tensor * denominator
+        sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
+        sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])
+
+        return sinusoid_table.unsqueeze(0)
+
+    def forward(self, x):
+        self.device = x.device
+        return x + self.position_table[:, :x.size(1)].clone().detach()
+
+
+def get_pad_mask(seq, pad_idx):
+    return (seq != pad_idx).unsqueeze(-2)
+
+
+def get_subsequent_mask(seq):
+    """For masking out the subsequent info."""
+    len_s = seq.size(1)
+    subsequent_mask = 1 - torch.triu(
+        torch.ones((len_s, len_s), device=seq.device), diagonal=1)
+    subsequent_mask = subsequent_mask.unsqueeze(0).bool()
+    return subsequent_mask
diff --git a/mmocr/models/textrecog/losses/__init__.py b/mmocr/models/textrecog/losses/__init__.py
new file mode 100755
index 00000000..226aa006
--- /dev/null
+++ b/mmocr/models/textrecog/losses/__init__.py
@@ -0,0 +1,5 @@
+from .ce_loss import CELoss, SARLoss, TFLoss
+from .ctc_loss import CTCLoss
+from .seg_loss import SegLoss
+
+__all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss']
diff --git a/mmocr/models/textrecog/losses/ce_loss.py b/mmocr/models/textrecog/losses/ce_loss.py
new file mode 100644
index 00000000..4fad5ae4
--- /dev/null
+++ b/mmocr/models/textrecog/losses/ce_loss.py
@@ -0,0 +1,94 @@
+import torch.nn as nn
+
+from mmdet.models.builder import LOSSES
+
+
+@LOSSES.register_module()
+class CELoss(nn.Module):
+    """Implementation of loss module for encoder-decoder based text
+    recognition method with CrossEntropy loss.
+
+    Args:
+        ignore_index (int): Specifies a target value that is
+            ignored and does not contribute to the input gradient.
+        reduction (str): Specifies the reduction to apply to the output,
+            should be one of the following: ('none', 'mean', 'sum').
+    """
+
+    def __init__(self, ignore_index=-1, reduction='none'):
+        super().__init__()
+        assert isinstance(ignore_index, int)
+        assert isinstance(reduction, str)
+        assert reduction in ['none', 'mean', 'sum']
+
+        self.loss_ce = nn.CrossEntropyLoss(
+            ignore_index=ignore_index, reduction=reduction)
+
+    def format(self, outputs, targets_dict):
+        targets = targets_dict['padded_targets']
+
+        return outputs.permute(0, 2, 1).contiguous(), targets
+
+    def forward(self, outputs, targets_dict):
+        outputs, targets = self.format(outputs, targets_dict)
+
+        loss_ce = self.loss_ce(outputs, targets.to(outputs.device))
+        losses = dict(loss_ce=loss_ce)
+
+        return losses
+
+
+@LOSSES.register_module()
+class SARLoss(CELoss):
+    """Implementation of loss module in `SAR
+    <https://arxiv.org/abs/1811.00751>`_.
+ + Args: + ignore_index (int): Specifies a target value that is + ignored and does not contribute to the input gradient. + reduction (str): Specifies the reduction to apply to the output, + should be one of the following: ('none', 'mean', 'sum'). + """ + + def __init__(self, ignore_index=0, reduction='mean', **kwargs): + super().__init__(ignore_index, reduction) + + def format(self, outputs, targets_dict): + targets = targets_dict['padded_targets'] + # targets[0, :], [start_idx, idx1, idx2, ..., end_idx, pad_idx...] + # outputs[0, :, 0], [idx1, idx2, ..., end_idx, ...] + + # ignore first index of target in loss calculation + targets = targets[:, 1:].contiguous() + # ignore last index of outputs to be in same seq_len with targets + outputs = outputs[:, :-1, :].permute(0, 2, 1).contiguous() + + return outputs, targets + + +@LOSSES.register_module() +class TFLoss(CELoss): + """Implementation of loss module for transformer.""" + + def __init__(self, + ignore_index=-1, + reduction='none', + flatten=True, + **kwargs): + super().__init__(ignore_index, reduction) + assert isinstance(flatten, bool) + + self.flatten = flatten + + def format(self, outputs, targets_dict): + outputs = outputs[:, :-1, :].contiguous() + targets = targets_dict['padded_targets'] + targets = targets[:, 1:].contiguous() + if self.flatten: + outputs = outputs.view(-1, outputs.size(-1)) + targets = targets.view(-1) + else: + outputs = outputs.permute(0, 2, 1).contiguous() + + return outputs, targets diff --git a/mmocr/models/textrecog/losses/ctc_loss.py b/mmocr/models/textrecog/losses/ctc_loss.py new file mode 100644 index 00000000..231469a2 --- /dev/null +++ b/mmocr/models/textrecog/losses/ctc_loss.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn + +from mmdet.models.builder import LOSSES + + +@LOSSES.register_module() +class CTCLoss(nn.Module): + """Implementation of loss module for CTC-loss based text recognition. + + Args: + flatten (bool): If True, use flattened targets, else padded targets. + blank (int): Blank label. Default 0. + reduction (str): Specifies the reduction to apply to the output, + should be one of the following: ('none', 'mean', 'sum'). + zero_infinity (bool): Whether to zero infinite losses and + the associated gradients. Default: False. + Infinite losses mainly occur when the inputs + are too short to be aligned to the targets. 
+ """ + + def __init__(self, + flatten=True, + blank=0, + reduction='mean', + zero_infinity=False, + **kwargs): + super().__init__() + assert isinstance(flatten, bool) + assert isinstance(blank, int) + assert isinstance(reduction, str) + assert isinstance(zero_infinity, bool) + + self.flatten = flatten + self.blank = blank + self.ctc_loss = nn.CTCLoss( + blank=blank, reduction=reduction, zero_infinity=zero_infinity) + + def forward(self, outputs, targets_dict): + + outputs = torch.log_softmax(outputs, dim=2) + bsz, seq_len = outputs.size(0), outputs.size(1) + input_lengths = torch.full( + size=(bsz, ), fill_value=seq_len, dtype=torch.long) + outputs_for_loss = outputs.permute(1, 0, 2).contiguous() # T * N * C + + if self.flatten: + targets = targets_dict['flatten_targets'] + else: + targets = torch.full( + size=(bsz, seq_len), fill_value=self.blank, dtype=torch.long) + for idx, tensor in enumerate(targets_dict['targets']): + valid_len = min(tensor.size(0), seq_len) + targets[idx, :valid_len] = tensor[:valid_len] + + target_lengths = targets_dict['target_lengths'] + + loss_ctc = self.ctc_loss(outputs_for_loss, targets, input_lengths, + target_lengths) + + losses = dict(loss_ctc=loss_ctc) + + return losses diff --git a/mmocr/models/textrecog/losses/seg_loss.py b/mmocr/models/textrecog/losses/seg_loss.py new file mode 100644 index 00000000..9a2911e8 --- /dev/null +++ b/mmocr/models/textrecog/losses/seg_loss.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmdet.models.builder import LOSSES + + +@LOSSES.register_module() +class SegLoss(nn.Module): + """Implementation of loss module for segmentation based text recognition + method. + + Args: + seg_downsample_ratio (float): Downsample ratio of + segmentation map. + seg_with_loss_weight (bool): If True, set weight for + segmentation loss. + ignore_index (int): Specifies a target value that is ignored + and does not contribute to the input gradient. 
+ """ + + def __init__(self, + seg_downsample_ratio=0.5, + seg_with_loss_weight=True, + ignore_index=255, + **kwargs): + super().__init__() + + assert isinstance(seg_downsample_ratio, (int, float)) + assert 0 < seg_downsample_ratio <= 1 + assert isinstance(ignore_index, int) + + self.seg_downsample_ratio = seg_downsample_ratio + self.seg_with_loss_weight = seg_with_loss_weight + self.ignore_index = ignore_index + + def seg_loss(self, out_head, gt_kernels): + seg_map = out_head # bsz * num_classes * H/2 * W/2 + seg_target = [ + item[1].rescale(self.seg_downsample_ratio).to_tensor( + torch.long, seg_map.device) for item in gt_kernels + ] + seg_target = torch.stack(seg_target).squeeze(1) + + loss_weight = None + if self.seg_with_loss_weight: + N = torch.sum(seg_target != self.ignore_index) + N_neg = torch.sum(seg_target == 0) + weight_val = 1.0 * N_neg / (N - N_neg) + loss_weight = torch.ones(seg_map.size(1), device=seg_map.device) + loss_weight[1:] = weight_val + + loss_seg = F.cross_entropy( + seg_map, + seg_target, + weight=loss_weight, + ignore_index=self.ignore_index) + + return loss_seg + + def forward(self, out_neck, out_head, gt_kernels): + + losses = {} + + loss_seg = self.seg_loss(out_head, gt_kernels) + + losses['loss_seg'] = loss_seg + + return losses diff --git a/mmocr/models/textrecog/necks/__init__.py b/mmocr/models/textrecog/necks/__init__.py new file mode 100755 index 00000000..71ceadc1 --- /dev/null +++ b/mmocr/models/textrecog/necks/__init__.py @@ -0,0 +1,3 @@ +from .fpn_ocr import FPNOCR + +__all__ = ['FPNOCR'] diff --git a/mmocr/models/textrecog/necks/fpn_ocr.py b/mmocr/models/textrecog/necks/fpn_ocr.py new file mode 100644 index 00000000..c1e9f178 --- /dev/null +++ b/mmocr/models/textrecog/necks/fpn_ocr.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmdet.models.builder import NECKS + + +@NECKS.register_module() +class FPNOCR(nn.Module): + """FPN-like Network for segmentation based text recognition. + + Args: + in_channels (list[int]): Number of input channels for each scale. + out_channels (int): Number of output channels for each scale. + last_stage_only (bool): If True, output last stage only. 
+ """ + + def __init__(self, in_channels, out_channels, last_stage_only=True): + super(FPNOCR, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + + self.last_stage_only = last_stage_only + + self.lateral_convs = nn.ModuleList() + self.smooth_convs_1x1 = nn.ModuleList() + self.smooth_convs_3x3 = nn.ModuleList() + + for i in range(self.num_ins): + l_conv = ConvModule( + in_channels[i], out_channels, 1, norm_cfg=dict(type='BN')) + self.lateral_convs.append(l_conv) + + for i in range(self.num_ins - 1): + s_conv_1x1 = ConvModule( + out_channels * 2, out_channels, 1, norm_cfg=dict(type='BN')) + s_conv_3x3 = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + norm_cfg=dict(type='BN')) + self.smooth_convs_1x1.append(s_conv_1x1) + self.smooth_convs_3x3.append(s_conv_3x3) + + def init_weights(self): + pass + + def _upsample_x2(self, x): + return F.interpolate(x, scale_factor=2, mode='bilinear') + + def forward(self, inputs): + lateral_features = [ + l_conv(inputs[i]) for i, l_conv in enumerate(self.lateral_convs) + ] + + outs = [] + for i in range(len(self.smooth_convs_3x3), 0, -1): # 3, 2, 1 + last_out = lateral_features[-1] if len(outs) == 0 else outs[-1] + upsample = self._upsample_x2(last_out) + upsample_cat = torch.cat((upsample, lateral_features[i - 1]), + dim=1) + smooth_1x1 = self.smooth_convs_1x1[i - 1](upsample_cat) + smooth_3x3 = self.smooth_convs_3x3[i - 1](smooth_1x1) + outs.append(smooth_3x3) + + return tuple(outs[-1:]) if self.last_stage_only else tuple(outs) diff --git a/mmocr/models/textrecog/recognizer/__init__.py b/mmocr/models/textrecog/recognizer/__init__.py new file mode 100644 index 00000000..91af5666 --- /dev/null +++ b/mmocr/models/textrecog/recognizer/__init__.py @@ -0,0 +1,12 @@ +from .base import BaseRecognizer +from .crnn import CRNNNet +from .encode_decode_recognizer import EncodeDecodeRecognizer +from .nrtr import NRTR +from .robust_scanner import RobustScanner +from .sar import SARNet +from .seg_recognizer import SegRecognizer + +__all__ = [ + 'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet', 'NRTR', + 'SegRecognizer', 'RobustScanner' +] diff --git a/mmocr/models/textrecog/recognizer/base.py b/mmocr/models/textrecog/recognizer/base.py new file mode 100644 index 00000000..bf0f0f4c --- /dev/null +++ b/mmocr/models/textrecog/recognizer/base.py @@ -0,0 +1,236 @@ +import warnings +from abc import ABCMeta, abstractmethod +from collections import OrderedDict + +import mmcv +import torch +import torch.distributed as dist +import torch.nn as nn +from mmcv.runner import auto_fp16 +from mmcv.utils import print_log + +from mmdet.utils import get_root_logger +from mmocr.core import imshow_text_label + + +class BaseRecognizer(nn.Module, metaclass=ABCMeta): + """Base class for text recognition.""" + + def __init__(self): + super().__init__() + self.fp16_enabled = False + + @abstractmethod + def extract_feat(self, imgs): + """Extract features from images.""" + pass + + @abstractmethod + def forward_train(self, imgs, img_metas, **kwargs): + """ + Args: + img (tensor): tensors with shape (N, C, H, W). + Typically should be mean centered and std scaled. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details of the values of these keys, see + :class:`mmdet.datasets.pipelines.Collect`. 
+ kwargs (keyword arguments): Specific to concrete implementation. + """ + pass + + @abstractmethod + def simple_test(self, img, img_metas, **kwargs): + pass + + @abstractmethod + def aug_test(self, imgs, img_metas, **kwargs): + """Test function with test time augmentation. + + Args: + imgs (list[tensor]): Tensor should have shape NxCxHxW, + which contains all images in the batch. + img_metas (list[list[dict]]): The metadata of images. + """ + pass + + def init_weights(self, pretrained=None): + """Initialize the weights for detector. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + if pretrained is not None: + logger = get_root_logger() + print_log(f'load model from: {pretrained}', logger=logger) + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (tensor | list[tensor]): Tensor should have shape NxCxHxW, + which contains all images in the batch. + img_metas (list[dict] | list[list[dict]]): + The outer list indicates images in a batch. + """ + if isinstance(imgs, list): + assert len(imgs) == len(img_metas) + assert len(imgs) > 0 + assert imgs[0].size(0) == 1, ('aug test does not support ' + f'inference with batch size ' + f'{imgs[0].size(0)}') + return self.aug_test(imgs, img_metas, **kwargs) + + return self.simple_test(imgs, img_metas, **kwargs) + + @auto_fp16(apply_to=('img', )) + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note that img and img_meta are single-nested (i.e. tensor and + list[dict]). + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + + return self.forward_test(img, img_metas, **kwargs) + + def _parse_losses(self, losses): + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw outputs of the network, which usually contain + losses and other necessary infomation. + + Returns: + tuple[tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + + loss = sum(_value for _key, _value in log_vars.items() + if 'loss' in _key) + + log_vars['loss'] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars + + def train_step(self, data, optimizer): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer update, which are done by an optimizer + hook. Note that in some complicated cases or models (e.g. GAN), + the whole process (including the back propagation and optimizer update) + is also defined by this method. + + Args: + data (dict): The outputs of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. 
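For reference, the reduction that `_parse_losses` above performs before these values reach the optimizer hook, in miniature (all numbers fabricated): scalar tensors are averaged, lists of tensors are summed, and only keys containing 'loss' enter the total:

```python
import torch

losses = dict(
    loss_ce=torch.tensor(0.7),
    loss_seg=[torch.tensor(0.2), torch.tensor(0.4)],  # list entries summed
    acc=torch.tensor(0.9),  # logged only: 'loss' not in the key
)

log_vars = {
    k: sum(v_i.mean() for v_i in v) if isinstance(v, list) else v.mean()
    for k, v in losses.items()
}
total = sum(v for k, v in log_vars.items() if 'loss' in k)
print(total)  # tensor(1.3000); 'acc' is excluded from the sum
```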
+ + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + + - ``loss`` is a tensor for back propagation, which is a + weighted sum of multiple losses. + - ``log_vars`` contains all the variables to be sent to the + logger. + - ``num_samples`` indicates the batch size used for + averaging the logs (Note: for the + DDP model, num_samples refers to the batch size for each GPU). + """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) + + return outputs + + def val_step(self, data, optimizer): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but is + used during val epochs. Note that the evaluation after training epochs + is not implemented by this method, but by an evaluation hook. + """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['img_metas'])) + + return outputs + + def show_result(self, + img, + result, + gt_label='', + win_name='', + show=False, + wait_time=0, + out_file=None, + **kwargs): + """Draw `result` on `img`. + + Args: + img (str or tensor): The image to be displayed. + result (dict): The results to draw on `img`. + gt_label (str): Ground truth label of img. + win_name (str): The window name. + wait_time (int): Value of waitKey param. + Default: 0. + show (bool): Whether to show the image. + Default: False. + out_file (str or None): The output filename. + Default: None. + + Returns: + img (tensor): Only if not `show` or `out_file`. + """ + img = mmcv.imread(img) + img = img.copy() + pred_label = None + if 'text' in result.keys(): + pred_label = result['text'] + + # if out_file specified, do not show image in window + if out_file is not None: + show = False + # draw text label + if pred_label is not None: + img = imshow_text_label( + img, + pred_label, + gt_label, + show=show, + win_name=win_name, + wait_time=wait_time, + out_file=out_file) + + if not (show or out_file): + warnings.warn('show==False and out_file is not specified, only ' + 'result image will be returned') + return img + + return img diff --git a/mmocr/models/textrecog/recognizer/crnn.py b/mmocr/models/textrecog/recognizer/crnn.py new file mode 100644 index 00000000..cfc98aae --- /dev/null +++ b/mmocr/models/textrecog/recognizer/crnn.py @@ -0,0 +1,7 @@ +from mmdet.models.builder import DETECTORS +from .encode_decode_recognizer import EncodeDecodeRecognizer + + +@DETECTORS.register_module() +class CRNNNet(EncodeDecodeRecognizer): + """CTC-loss based recognizer.""" diff --git a/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py b/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py new file mode 100644 index 00000000..ee1fb29b --- /dev/null +++ b/mmocr/models/textrecog/recognizer/encode_decode_recognizer.py @@ -0,0 +1,173 @@ +from mmdet.models.builder import DETECTORS, build_backbone, build_loss +from mmocr.models.builder import (build_convertor, build_decoder, + build_encoder, build_preprocessor) +from .base import BaseRecognizer + + +@DETECTORS.register_module() +class EncodeDecodeRecognizer(BaseRecognizer): + """Base class for encode-decode recognizer.""" + + def __init__(self, + preprocessor=None, + backbone=None, + encoder=None, + decoder=None, + loss=None, + label_convertor=None, + train_cfg=None, + test_cfg=None, + max_seq_len=40, + pretrained=None): + super().__init__() + + # Label 
convertor (str2tensor, tensor2str) + assert label_convertor is not None + label_convertor.update(max_seq_len=max_seq_len) + self.label_convertor = build_convertor(label_convertor) + + # Preprocessor module, e.g., TPS + self.preprocessor = None + if preprocessor is not None: + self.preprocessor = build_preprocessor(preprocessor) + + # Backbone + assert backbone is not None + self.backbone = build_backbone(backbone) + + # Encoder module + self.encoder = None + if encoder is not None: + self.encoder = build_encoder(encoder) + + # Decoder module + assert decoder is not None + decoder.update(num_classes=self.label_convertor.num_classes()) + decoder.update(start_idx=self.label_convertor.start_idx) + decoder.update(padding_idx=self.label_convertor.padding_idx) + decoder.update(max_seq_len=max_seq_len) + self.decoder = build_decoder(decoder) + + # Loss + assert loss is not None + loss.update(ignore_index=self.label_convertor.padding_idx) + self.loss = build_loss(loss) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.max_seq_len = max_seq_len + self.init_weights(pretrained=pretrained) + + def init_weights(self, pretrained=None): + """Initialize the weights of recognizer.""" + super().init_weights(pretrained) + + if self.preprocessor is not None: + self.preprocessor.init_weights() + + self.backbone.init_weights() + + if self.encoder is not None: + self.encoder.init_weights() + + self.decoder.init_weights() + + def extract_feat(self, img): + """Directly extract features from the backbone.""" + if self.preprocessor is not None: + img = self.preprocessor(img) + + x = self.backbone(img) + + return x + + def forward_train(self, img, img_metas): + """ + Args: + img (tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + img_metas (list[dict]): A list of image info dict where each dict + contains: 'img_shape', 'filename', and may also contain + 'ori_shape', and 'img_norm_cfg'. + For details on the values of these keys see + :class:`mmdet.datasets.pipelines.Collect`. + + Returns: + dict[str, tensor]: A dictionary of loss components. + """ + feat = self.extract_feat(img) + + gt_labels = [img_meta['text'] for img_meta in img_metas] + + targets_dict = self.label_convertor.str2tensor(gt_labels) + + out_enc = None + if self.encoder is not None: + out_enc = self.encoder(feat, img_metas) + + out_dec = self.decoder( + feat, out_enc, targets_dict, img_metas, train_mode=True) + + loss_inputs = ( + out_dec, + targets_dict, + ) + losses = self.loss(*loss_inputs) + + return losses + + def simple_test(self, img, img_metas, **kwargs): + """Test function with test time augmentation. + + Args: + imgs (torch.Tensor): Image input tensor. + img_metas (list[dict]): List of image information. + + Returns: + list[str]: Text label result of each image. 
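Concretely, the list returned here pairs each decoded string with its per-character scores; schematically, for a batch of two images (all values fabricated):

```python
results = [
    dict(text='hello', score=[0.99, 0.97, 0.98, 0.95, 0.99]),
    dict(text='world', score=[0.91, 0.88, 0.93, 0.90, 0.92]),
]
```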
+ """ + feat = self.extract_feat(img) + + out_enc = None + if self.encoder is not None: + out_enc = self.encoder(feat, img_metas) + + out_dec = self.decoder( + feat, out_enc, None, img_metas, train_mode=False) + + label_indexes, label_scores = self.label_convertor.tensor2idx( + out_dec, img_metas) + label_strings = self.label_convertor.idx2str(label_indexes) + + # flatten batch results + results = [] + for string, score in zip(label_strings, label_scores): + results.append(dict(text=string, score=score)) + + return results + + def merge_aug_results(self, aug_results): + out_text, out_score = '', -1 + for result in aug_results: + text = result[0]['text'] + score = sum(result[0]['score']) / max(1, len(text)) + if score > out_score: + out_text = text + out_score = score + out_results = [dict(text=out_text, score=out_score)] + return out_results + + def aug_test(self, imgs, img_metas, **kwargs): + """Test function as well as time augmentation. + + Args: + imgs (list[tensor]): Tensor should have shape NxCxHxW, + which contains all images in the batch. + img_metas (list[list[dict]]): The metadata of images. + """ + aug_results = [] + for img, img_meta in zip(imgs, img_metas): + result = self.simple_test(img, img_meta, **kwargs) + aug_results.append(result) + + return self.merge_aug_results(aug_results) diff --git a/mmocr/models/textrecog/recognizer/nrtr.py b/mmocr/models/textrecog/recognizer/nrtr.py new file mode 100644 index 00000000..e1b3c5d2 --- /dev/null +++ b/mmocr/models/textrecog/recognizer/nrtr.py @@ -0,0 +1,7 @@ +from mmdet.models.builder import DETECTORS +from .encode_decode_recognizer import EncodeDecodeRecognizer + + +@DETECTORS.register_module() +class NRTR(EncodeDecodeRecognizer): + """Implementation of `NRTR `_""" diff --git a/mmocr/models/textrecog/recognizer/robust_scanner.py b/mmocr/models/textrecog/recognizer/robust_scanner.py new file mode 100644 index 00000000..7189396e --- /dev/null +++ b/mmocr/models/textrecog/recognizer/robust_scanner.py @@ -0,0 +1,10 @@ +from mmdet.models.builder import DETECTORS +from .encode_decode_recognizer import EncodeDecodeRecognizer + + +@DETECTORS.register_module() +class RobustScanner(EncodeDecodeRecognizer): + """Implementation of `RobustScanner. 
+ + + """ diff --git a/mmocr/models/textrecog/recognizer/sar.py b/mmocr/models/textrecog/recognizer/sar.py new file mode 100644 index 00000000..bce67dca --- /dev/null +++ b/mmocr/models/textrecog/recognizer/sar.py @@ -0,0 +1,7 @@ +from mmdet.models.builder import DETECTORS +from .encode_decode_recognizer import EncodeDecodeRecognizer + + +@DETECTORS.register_module() +class SARNet(EncodeDecodeRecognizer): + """Implementation of `SAR `_""" diff --git a/mmocr/models/textrecog/recognizer/seg_recognizer.py b/mmocr/models/textrecog/recognizer/seg_recognizer.py new file mode 100644 index 00000000..e013e1d7 --- /dev/null +++ b/mmocr/models/textrecog/recognizer/seg_recognizer.py @@ -0,0 +1,153 @@ +from mmdet.models.builder import (DETECTORS, build_backbone, build_head, + build_loss, build_neck) +from mmocr.models.builder import build_convertor, build_preprocessor +from .base import BaseRecognizer + + +@DETECTORS.register_module() +class SegRecognizer(BaseRecognizer): + """Base class for segmentation based recognizer.""" + + def __init__(self, + preprocessor=None, + backbone=None, + neck=None, + head=None, + loss=None, + label_convertor=None, + train_cfg=None, + test_cfg=None, + pretrained=None): + super().__init__() + + # Label_convertor + assert label_convertor is not None + self.label_convertor = build_convertor(label_convertor) + + # Preprocessor module, e.g., TPS + self.preprocessor = None + if preprocessor is not None: + self.preprocessor = build_preprocessor(preprocessor) + + # Backbone + assert backbone is not None + self.backbone = build_backbone(backbone) + + # Neck + assert neck is not None + self.neck = build_neck(neck) + + # Head + assert head is not None + head.update(num_classes=self.label_convertor.num_classes()) + self.head = build_head(head) + + # Loss + assert loss is not None + self.loss = build_loss(loss) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.init_weights(pretrained=pretrained) + + def init_weights(self, pretrained=None): + """Initialize the weights of recognizer.""" + super().init_weights(pretrained) + + if self.preprocessor is not None: + self.preprocessor.init_weights() + + self.backbone.init_weights(pretrained=pretrained) + + if self.neck is not None: + self.neck.init_weights() + + self.head.init_weights() + + def extract_feat(self, img): + """Directly extract features from the backbone.""" + if self.preprocessor is not None: + img = self.preprocessor(img) + + x = self.backbone(img) + + return x + + def forward_train(self, img, img_metas, gt_kernels=None): + """ + Args: + img (tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + img_metas (list[dict]): A list of image info dict where each dict + contains: 'img_shape', 'filename', and may also contain + 'ori_shape', and 'img_norm_cfg'. + For details on the values of these keys see + :class:`mmdet.datasets.pipelines.Collect`. + + Returns: + dict[str, tensor]: A dictionary of loss components. + """ + + feats = self.extract_feat(img) + + out_neck = self.neck(feats) + + out_head = self.head(out_neck) + + loss_inputs = (out_neck, out_head, gt_kernels) + + losses = self.loss(*loss_inputs) + + return losses + + def simple_test(self, img, img_metas, **kwargs): + """Test function without test time augmentation. + + Args: + imgs (torch.Tensor): Image input tensor. + img_metas (list[dict]): List of image information. + + Returns: + list[str]: Text label result of each image. 
+ """ + + feat = self.extract_feat(img) + + out_neck = self.neck(feat) + + out_head = self.head(out_neck) + + texts, scores = self.label_convertor.tensor2str(out_head, img_metas) + + # flatten batch results + results = [] + for text, score in zip(texts, scores): + results.append(dict(text=text, score=score)) + + return results + + def merge_aug_results(self, aug_results): + out_text, out_score = '', -1 + for result in aug_results: + text = result[0]['text'] + score = sum(result[0]['score']) / max(1, len(text)) + if score > out_score: + out_text = text + out_score = score + out_results = [dict(text=out_text, score=out_score)] + return out_results + + def aug_test(self, imgs, img_metas, **kwargs): + """Test function with test time augmentation. + + Args: + imgs (list[tensor]): Tensor should have shape NxCxHxW, + which contains all images in the batch. + img_metas (list[list[dict]]): The metadata of images. + """ + aug_results = [] + for img, img_meta in zip(imgs, img_metas): + result = self.simple_test(img, img_meta, **kwargs) + aug_results.append(result) + + return self.merge_aug_results(aug_results) diff --git a/mmocr/models/utils/__init__.py b/mmocr/models/utils/__init__.py new file mode 100644 index 00000000..7afea00a --- /dev/null +++ b/mmocr/models/utils/__init__.py @@ -0,0 +1,3 @@ +from .ops.rroi_align import RROIAlign + +__all__ = ['RROIAlign'] diff --git a/mmocr/models/utils/ops/rroi_align/__init__.py b/mmocr/models/utils/ops/rroi_align/__init__.py new file mode 100644 index 00000000..a6161b3e --- /dev/null +++ b/mmocr/models/utils/ops/rroi_align/__init__.py @@ -0,0 +1,3 @@ +from .rroi_align import RROIAlign + +__all__ = ['RROIAlign'] diff --git a/mmocr/models/utils/ops/rroi_align/csrc/csc/RROIAlign.h b/mmocr/models/utils/ops/rroi_align/csrc/csc/RROIAlign.h new file mode 100644 index 00000000..166b95e3 --- /dev/null +++ b/mmocr/models/utils/ops/rroi_align/csrc/csc/RROIAlign.h @@ -0,0 +1,46 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+#pragma once + +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + +// Interface for Python +std::tuple RROIAlign_forward(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return RROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); + //return RROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); +} + +at::Tensor RROIAlign_backward(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& con_idx_x, + const at::Tensor& con_idx_y, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) { + if (grad.type().is_cuda()) { +#ifdef WITH_CUDA + return RROIAlign_backward_cuda(grad, rois, con_idx_x, con_idx_y, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} diff --git a/mmocr/models/utils/ops/rroi_align/csrc/csc/cpu/ROIAlign_cpu.cpp b/mmocr/models/utils/ops/rroi_align/csrc/csc/cpu/ROIAlign_cpu.cpp new file mode 100644 index 00000000..822c0bcc --- /dev/null +++ b/mmocr/models/utils/ops/rroi_align/csrc/csc/cpu/ROIAlign_cpu.cpp @@ -0,0 +1,257 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#include "vision.h" + +// implementation taken from Caffe2 +template +struct PreCalc { + int pos1; + int pos2; + int pos3; + int pos4; + T w1; + T w2; + T w3; + T w4; +}; + +template +void pre_calc_for_bilinear_interpolate( + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int iy_upper, + const int ix_upper, + T roi_start_h, + T roi_start_w, + T bin_size_h, + T bin_size_w, + int roi_bin_grid_h, + int roi_bin_grid_w, + std::vector>& pre_calc) { + int pre_calc_index = 0; + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + for (int iy = 0; iy < iy_upper; iy++) { + const T yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < ix_upper; ix++) { + const T xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T x = xx; + T y = yy; + // deal with: inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + PreCalc pc; + pc.pos1 = 0; + pc.pos2 = 0; + pc.pos3 = 0; + pc.pos4 = 0; + pc.w1 = 0; + pc.w2 = 0; + pc.w3 = 0; + pc.w4 = 0; + pre_calc[pre_calc_index] = pc; + pre_calc_index += 1; + continue; + } + + if (y <= 0) { + y = 0; + } + if (x <= 0) { + x = 0; + } + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + // save weights and indeces + PreCalc pc; + pc.pos1 = y_low * width + x_low; + pc.pos2 = y_low * width + x_high; + pc.pos3 = y_high * width + x_low; + pc.pos4 = y_high * width + x_high; + pc.w1 = w1; + pc.w2 = w2; + pc.w3 = w3; + pc.w4 = w4; + pre_calc[pre_calc_index] = pc; + + pre_calc_index += 1; + } + } + } + } +} + +template +void ROIAlignForward_cpu_kernel( + const int nthreads, + const T* bottom_data, + const T& spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const T* bottom_rois, + //int roi_cols, + T* top_data) { + //AT_ASSERT(roi_cols == 4 || roi_cols == 5); + int roi_cols = 5; + + int n_rois = nthreads / channels / pooled_width / pooled_height; + // (n, c, ph, pw) is an element in the pooled output + // can be parallelized using omp + // #pragma omp parallel for num_threads(32) + for (int n = 0; n < n_rois; n++) { + int index_n = n * channels * pooled_width * pooled_height; + + // roi could have 4 or 5 columns + const T* offset_bottom_rois = bottom_rois + n * roi_cols; + int roi_batch_ind = 0; + if (roi_cols == 5) { + roi_batch_ind = offset_bottom_rois[0]; + offset_bottom_rois++; + } + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_bottom_rois[0] * spatial_scale; + T roi_start_h = offset_bottom_rois[1] * spatial_scale; + T roi_end_w = offset_bottom_rois[2] * spatial_scale; + T roi_end_h = offset_bottom_rois[3] * spatial_scale; + // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale); + // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale); + // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale); + // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale); + + // Force malformed ROIs to be 1x1 + T roi_width = std::max(roi_end_w - roi_start_w, (T)1.); + T roi_height = std::max(roi_end_h - roi_start_h, (T)1.); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 + + // we want to precalculate indeces and weights shared by all chanels, + // this is the key point of optimiation + std::vector> pre_calc( + roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); + pre_calc_for_bilinear_interpolate( + height, + width, + pooled_height, + pooled_width, + roi_bin_grid_h, + roi_bin_grid_w, + roi_start_h, + roi_start_w, + bin_size_h, + bin_size_w, + roi_bin_grid_h, + roi_bin_grid_w, + pre_calc); + + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * pooled_width * pooled_height; + const T* offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + int pre_calc_index = 0; + + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + int index = index_n_c + ph * pooled_width + pw; + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + PreCalc pc = pre_calc[pre_calc_index]; + output_val += pc.w1 * offset_bottom_data[pc.pos1] + + pc.w2 * offset_bottom_data[pc.pos2] + + pc.w3 * offset_bottom_data[pc.pos3] + + pc.w4 * offset_bottom_data[pc.pos4]; + + pre_calc_index += 1; + } + } + output_val /= count; + + top_data[index] = output_val; + } // for pw + } // for ph + } // for c + } // for n +} + +at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio) { + AT_ASSERTM(!input.type().is_cuda(), "input must be a CPU tensor"); + AT_ASSERTM(!rois.type().is_cuda(), "rois must be a CPU tensor"); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options()); + auto output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return output; + } + + AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] { + ROIAlignForward_cpu_kernel( + output_size, + input.data(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois.data(), + output.data()); + }); + return output; +} diff --git a/mmocr/models/utils/ops/rroi_align/csrc/csc/cpu/vision.h b/mmocr/models/utils/ops/rroi_align/csrc/csc/cpu/vision.h new file mode 100644 index 00000000..20c459e0 --- /dev/null +++ b/mmocr/models/utils/ops/rroi_align/csrc/csc/cpu/vision.h @@ -0,0 +1,11 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once +#include + + +at::Tensor ROIAlign_forward_cpu(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio); diff --git a/mmocr/models/utils/ops/rroi_align/csrc/csc/cuda/ROIAlign_cuda.cu b/mmocr/models/utils/ops/rroi_align/csrc/csc/cuda/ROIAlign_cuda.cu new file mode 100644 index 00000000..5fe97ca9 --- /dev/null +++ b/mmocr/models/utils/ops/rroi_align/csrc/csc/cuda/ROIAlign_cuda.cu @@ -0,0 +1,346 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
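The CPU pre-calculation above and the CUDA kernel that follows share the same bilinear-interpolation corner/weight rule; a line-for-line Python transcription may make the boundary handling easier to follow:

```python
def bilinear_corners(y, x, height, width):
    """Return the four (flat position, weight) pairs for sample point (y, x),
    mirroring the clamping rules of the C++/CUDA kernels."""
    # Samples outside the (slightly padded) feature map contribute nothing.
    if y < -1.0 or y > height or x < -1.0 or x > width:
        return [(0, 0.0)] * 4
    y, x = max(y, 0.0), max(x, 0.0)
    y_low, x_low = int(y), int(x)
    if y_low >= height - 1:
        y_high = y_low = height - 1
        y = float(y_low)
    else:
        y_high = y_low + 1
    if x_low >= width - 1:
        x_high = x_low = width - 1
        x = float(x_low)
    else:
        x_high = x_low + 1
    ly, lx = y - y_low, x - x_low
    hy, hx = 1.0 - ly, 1.0 - lx
    return [(y_low * width + x_low, hy * hx),
            (y_low * width + x_high, hy * lx),
            (y_high * width + x_low, ly * hx),
            (y_high * width + x_high, ly * lx)]


# Mid-pixel sample on an 8x8 map: four equal weights of 0.25.
print(bilinear_corners(2.5, 3.5, 8, 8))
```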
+#include +#include + +#include +#include +#include + +// TODO make it in a common file +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + + +template +__device__ T bilinear_interpolate(const T* bottom_data, + const int height, const int width, + T y, T x, + const int index /* index for debug only*/) { + + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + //empty + return 0; + } + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + int y_low = (int) y; + int x_low = (int) x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T) y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T) x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + // do bilinear interpolation + T v1 = bottom_data[y_low * width + x_low]; + T v2 = bottom_data[y_low * width + x_high]; + T v3 = bottom_data[y_high * width + x_low]; + T v4 = bottom_data[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +__global__ void RoIAlignForward(const int nthreads, const T* bottom_data, + const T spatial_scale, const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int sampling_ratio, + const T* bottom_rois, T* top_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_bottom_rois[1] * spatial_scale; + T roi_start_h = offset_bottom_rois[2] * spatial_scale; + T roi_end_w = offset_bottom_rois[3] * spatial_scale; + T roi_end_h = offset_bottom_rois[4] * spatial_scale; + // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); + // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); + // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); + // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + T roi_width = max(roi_end_w - roi_start_w, (T)1.); + T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix ++) + { + const T x = roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index); + output_val += val; + } + } + output_val /= count; + + top_data[index] = output_val; + } +} + + +template +__device__ void bilinear_interpolate_gradient( + const int height, const int width, + T y, T x, + T & w1, T & w2, T & w3, T & w4, + int & x_low, int & x_high, int & y_low, int & y_high, + const int index /* index for debug only*/) { + + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + //empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + y_low = (int) y; + x_low = (int) x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T) y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T) x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = bottom_data[y_low * width + x_low]; + // T v2 = bottom_data[y_low * width + x_high]; + // T v3 = bottom_data[y_high * width + x_low]; + // T v4 = bottom_data[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} + +template +__global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff, + const int num_rois, const T spatial_scale, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, + const int sampling_ratio, + T* bottom_diff, + const T* bottom_rois) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_bottom_rois[1] * spatial_scale; + T roi_start_h = offset_bottom_rois[2] * spatial_scale; + T roi_end_w = offset_bottom_rois[3] * spatial_scale; + T roi_end_h = offset_bottom_rois[4] * spatial_scale; + // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale); + // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale); + // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale); + // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + T roi_width = max(roi_end_w - roi_start_w, (T)1.); + T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width; + + int top_offset = (n * channels + c) * pooled_height * pooled_width; + const T* offset_top_diff = top_diff + top_offset; + 
const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix ++) + { + const T x = roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient(height, width, y, x, + w1, w2, w3, w4, + x_low, x_high, y_low, y_high, + index); + + T g1 = top_diff_this_bin * w1 / count; + T g2 = top_diff_this_bin * w2 / count; + T g3 = top_diff_this_bin * w3 / count; + T g4 = top_diff_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) + { + atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast(g1)); + atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast(g2)); + atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast(g3)); + atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } // CUDA_1D_KERNEL_LOOP +} // RoIAlignBackward + + +at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio) { + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options()); + auto output_size = num_rois * pooled_height * pooled_width * channels; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L)); + dim3 block(512); + + if (output.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return output; + } + + AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] { + RoIAlignForward<<>>( + output_size, + input.contiguous().data(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois.contiguous().data(), + output.data()); + }); + THCudaCheck(cudaGetLastError()); + return output; +} + +// TODO remove the dependency on input and use instead its sizes -> save memory +at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio) { + AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor"); + AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + + auto num_rois = rois.size(0); + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 
4096L)); + dim3 block(512); + + // handle possibly empty gradients + if (grad.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return grad_input; + } + + AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] { + RoIAlignBackwardFeature<<>>( + grad.numel(), + grad.contiguous().data(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + grad_input.data(), + rois.contiguous().data()); + }); + THCudaCheck(cudaGetLastError()); + return grad_input; +} diff --git a/mmocr/models/utils/ops/rroi_align/csrc/csc/cuda/ROIPool_cuda.cu b/mmocr/models/utils/ops/rroi_align/csrc/csc/cuda/ROIPool_cuda.cu new file mode 100644 index 00000000..b826dd9b --- /dev/null +++ b/mmocr/models/utils/ops/rroi_align/csrc/csc/cuda/ROIPool_cuda.cu @@ -0,0 +1,202 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#include +#include + +#include +#include +#include + + +// TODO make it in a common file +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + + +template +__global__ void RoIPoolFForward(const int nthreads, const T* bottom_data, + const T spatial_scale, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const T* bottom_rois, T* top_data, int* argmax_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + int roi_start_w = round(offset_bottom_rois[1] * spatial_scale); + int roi_start_h = round(offset_bottom_rois[2] * spatial_scale); + int roi_end_w = round(offset_bottom_rois[3] * spatial_scale); + int roi_end_h = round(offset_bottom_rois[4] * spatial_scale); + + // Force malformed ROIs to be 1x1 + int roi_width = max(roi_end_w - roi_start_w + 1, 1); + int roi_height = max(roi_end_h - roi_start_h + 1, 1); + T bin_size_h = static_cast(roi_height) + / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) + / static_cast(pooled_width); + + int hstart = static_cast(floor(static_cast(ph) + * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) + * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) + * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) + * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart + roi_start_h, 0), height); + hend = min(max(hend + roi_start_h, 0), height); + wstart = min(max(wstart + roi_start_w, 0), width); + wend = min(max(wend + roi_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Define an empty pooling region to be zero + T maxval = is_empty ? 
0 : -FLT_MAX; + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + int maxidx = -1; + const T* offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int bottom_index = h * width + w; + if (offset_bottom_data[bottom_index] > maxval) { + maxval = offset_bottom_data[bottom_index]; + maxidx = bottom_index; + } + } + } + top_data[index] = maxval; + argmax_data[index] = maxidx; + } +} + +template +__global__ void RoIPoolFBackward(const int nthreads, const T* top_diff, + const int* argmax_data, const int num_rois, const T spatial_scale, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, T* bottom_diff, + const T* bottom_rois) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + int bottom_offset = (roi_batch_ind * channels + c) * height * width; + int top_offset = (n * channels + c) * pooled_height * pooled_width; + const T* offset_top_diff = top_diff + top_offset; + T* offset_bottom_diff = bottom_diff + bottom_offset; + const int* offset_argmax_data = argmax_data + top_offset; + + int argmax = offset_argmax_data[ph * pooled_width + pw]; + if (argmax != -1) { + atomicAdd( + offset_bottom_diff + argmax, + static_cast(offset_top_diff[ph * pooled_width + pw])); + + } + } +} + +std::tuple ROIPool_forward_cuda(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) { + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options()); + auto output_size = num_rois * pooled_height * pooled_width * channels; + auto argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt)); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L)); + dim3 block(512); + + if (output.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return std::make_tuple(output, argmax); + } + + AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIPool_forward", [&] { + RoIPoolFForward<<>>( + output_size, + input.contiguous().data(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois.contiguous().data(), + output.data(), + argmax.data()); + }); + THCudaCheck(cudaGetLastError()); + return std::make_tuple(output, argmax); +} + +// TODO remove the dependency on input and use instead its sizes -> save memory +at::Tensor ROIPool_backward_cuda(const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& rois, + const at::Tensor& argmax, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) { + AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor"); + AT_ASSERTM(rois.type().is_cuda(), 
"rois must be a CUDA tensor"); + // TODO add more checks + + auto num_rois = rois.size(0); + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L)); + dim3 block(512); + + // handle possibly empty gradients + if (grad.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return grad_input; + } + + AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIPool_backward", [&] { + RoIPoolFBackward<<>>( + grad.numel(), + grad.contiguous().data(), + argmax.data(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + grad_input.data(), + rois.contiguous().data()); + }); + THCudaCheck(cudaGetLastError()); + return grad_input; +} diff --git a/mmocr/models/utils/ops/rroi_align/csrc/csc/cuda/RROIAlign_cuda.cu b/mmocr/models/utils/ops/rroi_align/csrc/csc/cuda/RROIAlign_cuda.cu new file mode 100644 index 00000000..7f4462ac --- /dev/null +++ b/mmocr/models/utils/ops/rroi_align/csrc/csc/cuda/RROIAlign_cuda.cu @@ -0,0 +1,375 @@ +// Copyright (c) Jianqi Ma. All Rights Reserved. +#include +#include + +#include +#include +#include + +#include +#include +#include +//#include "rroi_alignment_kernel.h" + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + +template +__global__ void RROIAlignForward( + const int nthreads, + const T* bottom_data, + const T spatial_scale, + int height, + int width, + int channels, + const int pooled_height, + const int pooled_width, + const T* bottom_rois, + T* top_data, + float* con_idx_x, + float* con_idx_y) +{ + + CUDA_1D_KERNEL_LOOP(index, nthreads) + { + // +0.5 shift removed + int imageWidth = width; + int imageHeight = height; + + // (n, c, ph, pw) is an element in the pooled output + int n = index; + int pw = n % pooled_width; + n /= pooled_width; + int ph = n % pooled_height; + n /= pooled_height; + int c = n % channels; + n /= channels; + + const T* offset_bottom_rois = bottom_rois + n * 7; //= 7 is rois dimension 0 + + int roi_batch_ind = offset_bottom_rois[0]; + T cx = offset_bottom_rois[1]; + T cy = offset_bottom_rois[2]; + T h = offset_bottom_rois[3]; + T w = offset_bottom_rois[4]; + //T angle = offset_bottom_rois[5]/180.0*3.1415926535; + T Alpha = offset_bottom_rois[5]; + T Beta = offset_bottom_rois[6]; + + //TransformPrepare + T dx = -pooled_width/2.0; + T dy = -pooled_height/2.0; + T Sx = w*spatial_scale/pooled_width; + T Sy = h*spatial_scale/pooled_height; + //T Alpha = cos(angle); + //T Beta = sin(angle); + T Dx = cx*spatial_scale; + T Dy = cy*spatial_scale; + + T M[2][3]; + M[0][0] = Alpha*Sx; + M[0][1] = Beta*Sy; + M[0][2] = Alpha*Sx*dx+Beta*Sy*dy+Dx; + M[1][0] = -Beta*Sx; + M[1][1] = Alpha*Sy; + M[1][2] = -Beta*Sx*dx+Alpha*Sy*dy+Dy; + + T P[8]; + P[0] = M[0][0]*pw+M[0][1]*ph+M[0][2]; + P[1] = M[1][0]*pw+M[1][1]*ph+M[1][2]; + P[2] = M[0][0]*pw+M[0][1]*(ph+1)+M[0][2]; + P[3] = M[1][0]*pw+M[1][1]*(ph+1)+M[1][2]; + P[4] = M[0][0]*(pw+1)+M[0][1]*ph+M[0][2]; + P[5] = M[1][0]*(pw+1)+M[1][1]*ph+M[1][2]; + P[6] = M[0][0]*(pw+1)+M[0][1]*(ph+1)+M[0][2]; + P[7] = M[1][0]*(pw+1)+M[1][1]*(ph+1)+M[1][2]; + + T leftMost = (max(round(min(min(P[0],P[2]),min(P[4],P[6]))),0.0)); + T rightMost= (min(round(max(max(P[0],P[2]),max(P[4],P[6]))),imageWidth-1.0)); + T topMost= (max(round(min(min(P[1],P[3]),min(P[5],P[7]))),0.0)); + T bottomMost= (min(round(max(max(P[1],P[3]),max(P[5],P[7]))),imageHeight-1.0)); + + //float maxval 
= 0; + //int maxidx = -1; + const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width; + + //float AB[2]; + //AB[0] = P[2] - P[0]; + //AB[1] = P[3] - P[1]; + //float ABAB = AB[0]*AB[0] +AB[1]*AB[1]; + //float AC[2]; + //AC[0] = P[4] - P[0]; + //AC[1] = P[5] - P[1]; + //float ACAC = AC[0]*AC[0] + AC[1]*AC[1]; + + float bin_cx = (leftMost + rightMost) / 2.0; // shift + float bin_cy = (topMost + bottomMost) / 2.0; + + int bin_l = (int)floor(bin_cx); + int bin_r = (int)ceil(bin_cx); + int bin_t = (int)floor(bin_cy); + int bin_b = (int)ceil(bin_cy); + + T lt_value = 0.0; + if (bin_t > 0 && bin_l > 0 && bin_t < height && bin_l < width) + lt_value = offset_bottom_data[bin_t * width + bin_l]; + T rt_value = 0.0; + if (bin_t > 0 && bin_r > 0 && bin_t < height && bin_r < width) + rt_value = offset_bottom_data[bin_t * width + bin_r]; + T lb_value = 0.0; + if (bin_b > 0 && bin_l > 0 && bin_b < height && bin_l < width) + lb_value = offset_bottom_data[bin_b * width + bin_l]; + T rb_value = 0.0; + if (bin_b > 0 && bin_r > 0 && bin_b < height && bin_r < width) + rb_value = offset_bottom_data[bin_b * width + bin_r]; + + T rx = bin_cx - floor(bin_cx); + T ry = bin_cy - floor(bin_cy); + + T wlt = (1.0 - rx) * (1.0 - ry); + T wrt = rx * (1.0 - ry); + T wrb = rx * ry; + T wlb = (1.0 - rx) * ry; + + T inter_val = 0.0; + + inter_val += lt_value * wlt; + inter_val += rt_value * wrt; + inter_val += rb_value * wrb; + inter_val += lb_value * wlb; + + atomicAdd(top_data + index, static_cast(inter_val)); + atomicAdd(con_idx_x + index, static_cast(bin_cx)); + atomicAdd(con_idx_y + index, static_cast(bin_cy)); + + //top_data[index] = static_cast(inter_val); + //con_idx_x[index] = bin_cx; + //con_idx_y[index] = bin_cy; + + } +} + +/** +int RROIAlignForwardLaucher( + const float* bottom_data, const float spatial_scale, const int num_rois, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const float* bottom_rois, + float* top_data, float* con_idx_x, float* con_idx_y, const float* im_info, cudaStream_t stream) +{ + const int kThreadsPerBlock = 1024; + const int output_size = num_rois * pooled_height * pooled_width * channels; + cudaError_t err; + + + RROIAlignForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( + output_size, bottom_data, spatial_scale, height, width, channels, pooled_height, + pooled_width, bottom_rois, top_data, con_idx_x, con_idx_y, im_info); + + err = cudaGetLastError(); + if(cudaSuccess != err) + { + fprintf( stderr, "RRoI forward cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); + exit( -1 ); + } + + return 1; +} +*/ +//ROIAlign_forward_cuda + +//std::tuple ROIPool_forward_cuda(const at::Tensor& input, +// const at::Tensor& rois, +// const float spatial_scale, +// const int pooled_height, +// const int pooled_width) + +std::tuple RROIAlign_forward_cuda( + const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) +{ + + + AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + auto output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options()); + auto output_size = num_rois * pooled_height * pooled_width * channels; + auto con_idx_x = 
at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kFloat)); + auto con_idx_y = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kFloat)); + + dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L)); + dim3 block(512); + + //const int kThreadsPerBlock = 1024; + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (output.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return std::make_tuple(output, con_idx_x, con_idx_y);//, con_idx_y; //std::make_tuple( + } + + AT_DISPATCH_FLOATING_TYPES(input.type(), "RROIAlign_forward", [&] { + RROIAlignForward<<>>( + output_size, + input.contiguous().data(), + spatial_scale, + height, + width, + channels, + pooled_height, + pooled_width, + rois.contiguous().data(), + output.data(), + con_idx_x.data(), + con_idx_y.data()); + } + ); + + THCudaCheck(cudaGetLastError()); + return std::make_tuple(output, con_idx_x, con_idx_y); +} + +template +__global__ void RROIAlignBackward( + const int nthreads, + const T* top_diff, + const float* con_idx_x, + const float* con_idx_y, + const int num_rois, + const float spatial_scale, + const int height, + const int width, + const int channels, + const int pooled_height, + const int pooled_width, + T* bottom_diff, + const T* bottom_rois) { + CUDA_1D_KERNEL_LOOP(index, nthreads) + { + + // (n, c, ph, pw) is an element in the pooled output + int n = index; + //int w = n % width; + n /= pooled_width; + //int h = n % height; + n /= pooled_height; + int c = n % channels; + n /= channels; + + const T* offset_bottom_rois = bottom_rois + n * 7; + int roi_batch_ind = offset_bottom_rois[0]; + T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width; + + //int bottom_index = argmax_data[index]; + + float bw = con_idx_x[index]; + float bh = con_idx_y[index]; + //if (bh > 0.00001 && bw > 0.00001 && bw < height-1 && bw < width-1){ + + int bin_xs = int(floor(bw)); + int bin_ys = int(floor(bh)); + + float rx = bw - float(bin_xs); + float ry = bh - float(bin_ys); + + T wlt = (1.0 - rx) * (1.0 - ry); + T wrt = rx * (1.0 - ry); + T wrb = rx * ry; + T wlb = (1.0 - rx) * ry; + + // if(bottom_index >= 0) // original != 0 maybe wrong + // bottom_diff[bottom_index]+=top_diff[index] ; + + //int min_x = bin_xs, 0), width - 1); + //int min_y = min(max(bin_ys, 0), height - 1); + //int max_x = max(min(bin_xs + 1, width - 1), 0); + //int max_y = max(min(bin_ys + 1, height - 1), 0); + + int min_x = (int)floor(bw); + int max_x = (int)ceil(bw); + int min_y = (int)floor(bh); + int max_y = (int)ceil(bh); + + T top_diff_of_bin = top_diff[index]; + + T v1 = wlt * top_diff_of_bin; + T v2 = wrt * top_diff_of_bin; + T v3 = wrb * top_diff_of_bin; + T v4 = wlb * top_diff_of_bin; + + // Atomic add + + if (min_y > 0 && min_x > 0 && min_y < height - 1 && min_x < width - 1) + atomicAdd(offset_bottom_diff + min_y * width + min_x, static_cast(v1)); + if (min_y > 0 && max_x < width - 1 && min_y < height - 1 && max_x > 0) + atomicAdd(offset_bottom_diff + min_y * width + max_x, static_cast(v2)); + if (max_y < height - 1 && max_x < width - 1 && max_y > 0 && max_x > 0) + atomicAdd(offset_bottom_diff + max_y * width + max_x, static_cast(v3)); + if (max_y < height - 1 && min_x > 0 && max_y > 0 && min_x < width - 1) + atomicAdd(offset_bottom_diff + max_y * width + min_x, static_cast(v4)); + + //} + + } +} + + +// TODO remove the dependency on input and use instead its sizes -> save memory +at::Tensor RROIAlign_backward_cuda(const at::Tensor& grad, + const 
+// TODO remove the dependency on input and use instead its sizes -> save memory +at::Tensor RROIAlign_backward_cuda(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& con_idx_x, + const at::Tensor& con_idx_y, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) { + AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor"); + AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor"); + // TODO add more checks + + auto num_rois = rois.size(0); + auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options()); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L)); + dim3 block(512); + + // handle possibly empty gradients + if (grad.numel() == 0) { + THCudaCheck(cudaGetLastError()); + return grad_input; + } + + AT_DISPATCH_FLOATING_TYPES(grad.type(), "RROIAlign_backward", [&] { + RROIAlignBackward<scalar_t><<<grid, block, 0, stream>>>( + grad.numel(), + grad.contiguous().data<scalar_t>(), + con_idx_x.data<float>(), + con_idx_y.data<float>(), + num_rois, + spatial_scale, + height, + width, + channels, + pooled_height, + pooled_width, + grad_input.data<scalar_t>(), + rois.contiguous().data<scalar_t>()); + }); + THCudaCheck(cudaGetLastError()); + return grad_input; +} diff --git a/mmocr/models/utils/ops/rroi_align/csrc/csc/cuda/vision.h b/mmocr/models/utils/ops/rroi_align/csrc/csc/cuda/vision.h new file mode 100644 index 00000000..25e54c08 --- /dev/null +++ b/mmocr/models/utils/ops/rroi_align/csrc/csc/cuda/vision.h @@ -0,0 +1,25 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +#pragma once +#include <torch/extension.h> + +std::tuple<at::Tensor, at::Tensor, at::Tensor> RROIAlign_forward_cuda(const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width); + +at::Tensor RROIAlign_backward_cuda(const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& con_idx_x, + const at::Tensor& con_idx_y, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width); + +at::Tensor compute_flow_cuda(const at::Tensor& boxes, + const int height, + const int width); diff --git a/mmocr/models/utils/ops/rroi_align/csrc/csc/vision.cpp b/mmocr/models/utils/ops/rroi_align/csrc/csc/vision.cpp new file mode 100644 index 00000000..439bb2de --- /dev/null +++ b/mmocr/models/utils/ops/rroi_align/csrc/csc/vision.cpp @@ -0,0 +1,10 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
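// Note (editorial, not part of this commit): this translation unit only
// registers the Python bindings; RROIAlign_forward / RROIAlign_backward come
// from "RROIAlign.h" and are expected to dispatch to the CUDA entry points
// declared in cuda/vision.h above. A hedged Python-side usage sketch of the
// built extension (the autograd wrapper appears later in this diff):
//
//   align = RROIAlign(output_size=(7, 7), spatial_scale=0.25)
//   pooled = align(feats, rois)  # rois: (num_rois, 7); rois[:, 0] is the
//                                # batch index read by the kernels above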
+#include "RROIAlign.h" + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + + m.def("rroi_align_forward", &RROIAlign_forward, "RROIAlign_forward"); + m.def("rroi_align_backward", &RROIAlign_backward, "RROIAlign_backward"); + +} diff --git a/mmocr/models/utils/ops/rroi_align/rroi_align.py b/mmocr/models/utils/ops/rroi_align/rroi_align.py new file mode 100644 index 00000000..c3f0ca7b --- /dev/null +++ b/mmocr/models/utils/ops/rroi_align/rroi_align.py @@ -0,0 +1,57 @@ +from torch import nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + +from .csrc import rroi_align_backward, rroi_align_forward + + +class _RROIAlign(Function): + + @staticmethod + def forward(ctx, input, roi, output_size, spatial_scale): + + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.input_shape = input.size() + total_output = rroi_align_forward(input, roi, spatial_scale, + output_size[0], output_size[1]) + + output, con_idx_x, con_idx_y = total_output + ctx.save_for_backward(roi, con_idx_x, con_idx_y) + + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + rois, con_idx_x, con_idx_y = ctx.saved_tensors + output_size = ctx.output_size + spatial_scale = ctx.spatial_scale + bs, ch, h, w = ctx.input_shape + grad_input = rroi_align_backward(grad_output, rois, con_idx_x, + con_idx_y, spatial_scale, + output_size[0], output_size[1], bs, + ch, h, w) + return grad_input, None, None, None + + +rroi_align = _RROIAlign.apply + + +class RROIAlign(nn.Module): + + def __init__(self, output_size, spatial_scale): + super(RROIAlign, self).__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + + def forward(self, input, rois): + return rroi_align(input, rois, self.output_size, self.spatial_scale) + + def __repr__(self): + tmpstr = self.__class__.__name__ + '(' + tmpstr += 'output_size=' + str(self.output_size) + tmpstr += ', spatial_scale=' + str(self.spatial_scale) + tmpstr += ')' + return tmpstr diff --git a/mmocr/utils/__init__.py b/mmocr/utils/__init__.py new file mode 100644 index 00000000..5192ce36 --- /dev/null +++ b/mmocr/utils/__init__.py @@ -0,0 +1,15 @@ +from mmcv.utils import Registry, build_from_cfg + +from mmdet.utils import get_root_logger +from .check_argument import (equal_len, is_2dlist, is_3dlist, is_ndarray_list, + is_none_or_type, is_type_list, valid_boundary) +from .collect_env import collect_env +from .img_util import drop_orientation +from .lmdb_util import lmdb_converter + +__all__ = [ + 'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env', + 'is_3dlist', 'is_ndarray_list', 'is_type_list', 'is_none_or_type', + 'equal_len', 'is_2dlist', 'valid_boundary', 'lmdb_converter', + 'drop_orientation' +] diff --git a/mmocr/utils/check_argument.py b/mmocr/utils/check_argument.py new file mode 100644 index 00000000..92955df2 --- /dev/null +++ b/mmocr/utils/check_argument.py @@ -0,0 +1,72 @@ +import numpy as np + + +def is_3dlist(x): + + if not isinstance(x, list): + return False + if len(x) > 0: + if isinstance(x[0], list): + if len(x[0]) > 0: + return isinstance(x[0][0], list) + return True + return False + + return True + + +def is_2dlist(x): + + if not isinstance(x, list): + return False + if len(x) > 0: + return bool(isinstance(x[0], list)) + + return True + + +def is_ndarray_list(x): + + if not isinstance(x, list): + return False + if len(x) > 0: + return isinstance(x[0], np.ndarray) + + return True + + +def is_type_list(x, 
type): + + if not isinstance(x, list): + return False + if len(x) > 0: + return isinstance(x[0], type) + + return True + + +def is_none_or_type(x, type): + + return isinstance(x, type) or x is None + + +def equal_len(*argv): + assert len(argv) > 0 + + num_arg = len(argv[0]) + for arg in argv: + if len(arg) != num_arg: + return False + return True + + +def valid_boundary(x, with_score=True): + num = len(x) + if num < 8: + return False + if num % 2 == 0 and (not with_score): + return True + if num % 2 == 1 and with_score: + return True + + return False diff --git a/mmocr/utils/collect_env.py b/mmocr/utils/collect_env.py new file mode 100644 index 00000000..d664ae65 --- /dev/null +++ b/mmocr/utils/collect_env.py @@ -0,0 +1,60 @@ +import os.path as osp +import subprocess +import sys +from collections import defaultdict + +import cv2 +import mmcv +import torch +import torchvision + +import mmdet +import mmocr + + +def collect_env(): + env_info = {} + env_info['sys.platform'] = sys.platform + env_info['Python'] = sys.version.replace('\n', '') + + cuda_available = torch.cuda.is_available() + env_info['CUDA available'] = cuda_available + + if cuda_available: + from torch.utils.cpp_extension import CUDA_HOME + env_info['CUDA_HOME'] = CUDA_HOME + + if CUDA_HOME is not None and osp.isdir(CUDA_HOME): + try: + nvcc = osp.join(CUDA_HOME, 'bin/nvcc') + nvcc = subprocess.check_output( + '"{}" -V | tail -n1'.format(nvcc), shell=True) + nvcc = nvcc.decode('utf-8').strip() + except subprocess.SubprocessError: + nvcc = 'Not Available' + env_info['NVCC'] = nvcc + + devices = defaultdict(list) + for k in range(torch.cuda.device_count()): + devices[torch.cuda.get_device_name(k)].append(str(k)) + for name, devids in devices.items(): + env_info['GPU ' + ','.join(devids)] = name + + gcc = subprocess.check_output('gcc --version | head -n1', shell=True) + gcc = gcc.decode('utf-8').strip() + env_info['GCC'] = gcc + + env_info['PyTorch'] = torch.__version__ + env_info['PyTorch compiling details'] = torch.__config__.show() + + env_info['TorchVision'] = torchvision.__version__ + + env_info['OpenCV'] = cv2.__version__ + + env_info['MMCV'] = mmcv.__version__ + env_info['MMDetection'] = mmdet.__version__ + env_info['MMOCR'] = mmocr.__version__ + from mmdet.ops import get_compiler_version, get_compiling_cuda_version + env_info['MMOCR Compiler'] = get_compiler_version() + env_info['MMOCR CUDA Compiler'] = get_compiling_cuda_version() + return env_info diff --git a/mmocr/utils/img_util.py b/mmocr/utils/img_util.py new file mode 100644 index 00000000..b11625ef --- /dev/null +++ b/mmocr/utils/img_util.py @@ -0,0 +1,34 @@ +import os + +import mmcv + + +def drop_orientation(img_file): + """Check if the image has orientation information. If yes, ignore it by + converting the image format to png, and return new filename, otherwise + return the original filename. + + Args: + img_file(str): The image path + + Returns: + The converted image filename with proper postfix + """ + assert isinstance(img_file, str) + assert img_file + + # read imgs with ignoring orientations + img = mmcv.imread(img_file, 'unchanged') + # read imgs with orientations as dataloader does when training and testing + img_color = mmcv.imread(img_file, 'color') + # make sure imgs have no orientation info, or annotation gt is wrong. 
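    # Editorial note, not part of this commit: 'unchanged' maps to
    # cv2.IMREAD_UNCHANGED in mmcv's default cv2 backend, which ignores the
    # EXIF orientation tag, while 'color' applies it just as the dataloader
    # will; a shape mismatch between the two reads therefore flags an image
    # whose stored orientation differs from its displayed one.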
+ if img.shape[:2] == img_color.shape[:2]: + return img_file + + target_file = os.path.splitext(img_file)[0] + '.png' + # read img with ignoring orientation information + img = mmcv.imread(img_file, 'unchanged') + mmcv.imwrite(img, target_file) + os.remove(img_file) + print(f'{img_file} has orientation info. Ignore it by converting to png') + return target_file diff --git a/mmocr/utils/lmdb_util.py b/mmocr/utils/lmdb_util.py new file mode 100644 index 00000000..ac3cc3b9 --- /dev/null +++ b/mmocr/utils/lmdb_util.py @@ -0,0 +1,46 @@ +import shutil +import sys +import time +from pathlib import Path + +import lmdb + + +def lmdb_converter(img_list, output, batch_size=1000, coding='utf-8'): + # read img_list + with open(img_list) as f: + lines = f.readlines() + + # create lmdb database + if Path(output).is_dir(): + while True: + print('%s already exists, delete or not? [Y/n]' % output) + Yn = input().strip() + if Yn in ['Y', 'y']: + shutil.rmtree(output) + break + elif Yn in ['N', 'n']: + return + print('create database %s' % output) + Path(output).mkdir(parents=True, exist_ok=False) + env = lmdb.open(output, map_size=1099511627776) + + # build lmdb + beg_time = time.strftime('%H:%M:%S') + for beg_index in range(0, len(lines), batch_size): + end_index = min(beg_index + batch_size, len(lines)) + sys.stdout.write('\r[%s-%s], processing [%d-%d] / %d' % + (beg_time, time.strftime('%H:%M:%S'), beg_index, + end_index, len(lines))) + sys.stdout.flush() + batch = [(str(index).encode(coding), lines[index].encode(coding)) + for index in range(beg_index, end_index)] + with env.begin(write=True) as txn: + cursor = txn.cursor() + cursor.putmulti(batch, dupdata=False, overwrite=True) + sys.stdout.write('\n') + with env.begin(write=True) as txn: + key = 'total_number'.encode(coding) + value = str(len(lines)).encode(coding) + txn.put(key, value) + print('done', flush=True) diff --git a/mmocr/version.py b/mmocr/version.py new file mode 100644 index 00000000..2913da5a --- /dev/null +++ b/mmocr/version.py @@ -0,0 +1,4 @@ +# Copyright (c) Open-MMLab. All rights reserved.
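# Editorial note, not part of this commit: setup.py's get_version() exec()s
# this file rather than importing mmocr, so it must stay import-free and
# side-effect-free.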
+ +__version__ = '0.1.0' +short_version = __version__ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..6981bd72 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +-r requirements/build.txt +-r requirements/optional.txt +-r requirements/runtime.txt +-r requirements/tests.txt diff --git a/requirements/build.txt b/requirements/build.txt new file mode 100644 index 00000000..1dee987a --- /dev/null +++ b/requirements/build.txt @@ -0,0 +1,5 @@ +# These must be installed before building mmocr +numpy +Polygon3 +pyclipper +torch>=1.1 diff --git a/requirements/docs.txt b/requirements/docs.txt new file mode 100644 index 00000000..89fbf86c --- /dev/null +++ b/requirements/docs.txt @@ -0,0 +1,4 @@ +recommonmark +sphinx +sphinx_markdown_tables +sphinx_rtd_theme diff --git a/requirements/optional.txt b/requirements/optional.txt new file mode 100644 index 00000000..e69de29b diff --git a/requirements/readthedocs.txt b/requirements/readthedocs.txt new file mode 100644 index 00000000..0542bfce --- /dev/null +++ b/requirements/readthedocs.txt @@ -0,0 +1,3 @@ +mmcv +torch +torchvision diff --git a/requirements/runtime.txt b/requirements/runtime.txt new file mode 100644 index 00000000..dc48806d --- /dev/null +++ b/requirements/runtime.txt @@ -0,0 +1,14 @@ +imgaug +Levenshtein +lmdb +matplotlib +numba>=0.45.1 +numpy +# need older pillow until torchvision is fixed +Pillow<=6.2.2 +python-Levenshtein +scikit-image +six +terminaltables +torch>=1.1 +torchvision diff --git a/requirements/tests.txt b/requirements/tests.txt new file mode 100644 index 00000000..ba224cb3 --- /dev/null +++ b/requirements/tests.txt @@ -0,0 +1,13 @@ +asynctest +codecov +flake8 +isort +# Note: used for kwarray.group_items, this may be ported to mmcv in the future. 
+kwarray +Polygon3 +pytest +pytest-cov +pytest-runner +ubelt +xdoctest >= 0.10.0 +yapf diff --git a/resources/mmocr-logo.jpg b/resources/mmocr-logo.jpg deleted file mode 100644 index b916627a..00000000 Binary files a/resources/mmocr-logo.jpg and /dev/null differ diff --git a/resources/mmocr-logo.png b/resources/mmocr-logo.png new file mode 100644 index 00000000..c81a3c83 Binary files /dev/null and b/resources/mmocr-logo.png differ diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..febe7f2c --- /dev/null +++ b/setup.cfg @@ -0,0 +1,30 @@ +[bdist_wheel] +universal=1 + +[aliases] +test=pytest + +[tool:pytest] +norecursedirs=tests/integration/* +addopts=tests + +[yapf] +based_on_style = pep8 +blank_line_before_nested_class_or_def = true +split_before_expression_after_opening_paren = true +split_penalty_import_names=0 +SPLIT_PENALTY_AFTER_OPENING_BRACKET=800 + +[isort] +line_length = 79 +multi_line_output = 0 +known_standard_library = setuptools +known_first_party = mmdet,mmocr +known_third_party = Levenshtein,PIL,Polygon,cv2,imgaug,lmdb,matplotlib,mmcv,numpy,pyclipper,pycocotools,pytest,scipy,shapely,skimage,titlecase,torch,torchvision +no_lines_before = STDLIB,LOCALFOLDER +default_section = THIRDPARTY + +[style] +BASED_ON_STYLE = pep8 +BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true +SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..6e316c9c --- /dev/null +++ b/setup.py @@ -0,0 +1,184 @@ +import glob +import os +from setuptools import find_packages, setup + +import torch +from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension, + CUDAExtension) + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +version_file = 'mmocr/version.py' + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + import sys + # return short version for sdist + if 'sdist' in sys.argv or 'bdist_wheel' in sys.argv: + return locals()['short_version'] + else: + return locals()['__version__'] + + +def parse_requirements(fname='requirements.txt', with_version=True): + """Parse the package dependencies listed in a requirements file, keeping + version specifiers unless with_version is False. + + Args: + fname (str): Path to requirements file. + with_version (bool, default=True): If True, include version specs. + Returns: + info (list[str]): List of requirements items.
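    Example (editorial sketch, not part of this commit):
        >>> parse_requirements('requirements/build.txt')  # doctest: +SKIP
        ['numpy', 'Polygon3', 'pyclipper', 'torch>=1.1']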
+ CommandLine: + python -c "import setup; print(setup.parse_requirements())" + """ + import sys + from os.path import exists + import re + require_fpath = fname + + def parse_line(line): + """Parse information from a line in a requirements text file.""" + if line.startswith('-r '): + # Allow specifying requirements in other files + target = line.split(' ')[1] + for info in parse_require_file(target): + yield info + else: + info = {'line': line} + if line.startswith('-e '): + info['package'] = line.split('#egg=')[1] + else: + # Remove versioning from the package + pat = '(' + '|'.join(['>=', '==', '>']) + ')' + parts = re.split(pat, line, maxsplit=1) + parts = [p.strip() for p in parts] + + info['package'] = parts[0] + if len(parts) > 1: + op, rest = parts[1:] + if ';' in rest: + # Handle platform specific dependencies + # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies + version, platform_deps = map(str.strip, + rest.split(';')) + info['platform_deps'] = platform_deps + else: + version = rest # NOQA + info['version'] = (op, version) + yield info + + def parse_require_file(fpath): + with open(fpath, 'r') as f: + for line in f.readlines(): + line = line.strip() + if line and not line.startswith('#'): + for info in parse_line(line): + yield info + + def gen_packages_items(): + if exists(require_fpath): + for info in parse_require_file(require_fpath): + parts = [info['package']] + if with_version and 'version' in info: + parts.extend(info['version']) + if not sys.version.startswith('3.4'): + # apparently package_deps are broken in 3.4 + platform_deps = info.get('platform_deps') + if platform_deps is not None: + parts.append(';' + platform_deps) + item = ''.join(parts) + yield item + + packages = list(gen_packages_items()) + return packages + + +def get_rroi_align_extensions(): + + extensions_dir = 'mmocr/models/utils/ops/rroi_align/csrc/csc' + main_file = glob.glob(os.path.join(extensions_dir, '*.cpp')) + source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp')) + source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu')) + sources = main_file + source_cpu + extension = CppExtension + extra_compile_args = {'cxx': []} + define_macros = [] + + if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [('WITH_CUDA', None)] + extra_compile_args['nvcc'] = [ + '-DCUDA_HAS_FP16=1', + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', + ] + + print(sources) + include_dirs = [extensions_dir] + print('include_dirs', include_dirs, flush=True) + ext = extension( + name='mmocr.models.utils.ops.rroi_align.csrc', + sources=sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + + return ext + + +if __name__ == '__main__': + library_dirs = [ + lp for lp in os.environ.get('LD_LIBRARY_PATH', '').split(':') + if len(lp) > 1 + ] + cpp_root = 'mmocr/models/textdet/postprocess/' + setup( + name='mmocr', + version=get_version(), + description='Text Detection, OCR, and NLP Toolbox', + long_description=readme(), + keywords='Text Detection, OCR, KIE, NLP', + url='https://github.com/jeffreykuang/mmocr', + packages=find_packages(exclude=('configs', 'tools', 'demo')), + package_data={'mmocr.ops': ['*/*.so']}, + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming 
Language :: Python :: 3', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + ], + license='Apache License 2.0', + setup_requires=parse_requirements('requirements/build.txt'), + tests_require=parse_requirements('requirements/tests.txt'), + install_requires=parse_requirements('requirements/runtime.txt'), + extras_require={ + 'all': parse_requirements('requirements.txt'), + 'tests': parse_requirements('requirements/tests.txt'), + 'build': parse_requirements('requirements/build.txt'), + 'optional': parse_requirements('requirements/optional.txt'), + }, + ext_modules=[ + CppExtension( + name='mmocr.models.textdet.postprocess.pan', + sources=[cpp_root + 'pan.cpp']), + CppExtension( + name='mmocr.models.textdet.postprocess.pse', + sources=[cpp_root + 'pse.cpp']), + get_rroi_align_extensions() + ], + cmdclass={'build_ext': BuildExtension}, + zip_safe=False) diff --git a/tests/data/ocr_char_ann_toy_dataset/imgs/resort_88_101_1.png b/tests/data/ocr_char_ann_toy_dataset/imgs/resort_88_101_1.png new file mode 100644 index 00000000..4eecacaa Binary files /dev/null and b/tests/data/ocr_char_ann_toy_dataset/imgs/resort_88_101_1.png differ diff --git a/tests/data/ocr_char_ann_toy_dataset/imgs/resort_95_53_6.png b/tests/data/ocr_char_ann_toy_dataset/imgs/resort_95_53_6.png new file mode 100644 index 00000000..6e788a26 Binary files /dev/null and b/tests/data/ocr_char_ann_toy_dataset/imgs/resort_95_53_6.png differ diff --git a/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_101_8_6.png b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_101_8_6.png new file mode 100644 index 00000000..6f131503 Binary files /dev/null and b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_101_8_6.png differ diff --git a/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_104_58_9.png b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_104_58_9.png new file mode 100644 index 00000000..6c632f67 Binary files /dev/null and b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_104_58_9.png differ diff --git a/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_110_1_6.png b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_110_1_6.png new file mode 100644 index 00000000..6940a4ca Binary files /dev/null and b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_110_1_6.png differ diff --git a/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_12_61_4.png b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_12_61_4.png new file mode 100644 index 00000000..b148b6ce Binary files /dev/null and b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_12_61_4.png differ diff --git a/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_130_74_1.png b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_130_74_1.png new file mode 100644 index 00000000..a29b7f9e Binary files /dev/null and b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_130_74_1.png differ diff --git a/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_134_30_15.png b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_134_30_15.png new file mode 100644 index 00000000..b40863a4 Binary files /dev/null and b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_134_30_15.png differ diff --git a/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_15_43_4.png b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_15_43_4.png new file mode 100644 index 00000000..7465fc2e Binary files 
/dev/null and b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_15_43_4.png differ diff --git a/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_18_18_5.png b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_18_18_5.png new file mode 100644 index 00000000..1ce69334 Binary files /dev/null and b/tests/data/ocr_char_ann_toy_dataset/imgs/richard+feynman_18_18_5.png differ diff --git a/tests/data/ocr_char_ann_toy_dataset/instances_test.txt b/tests/data/ocr_char_ann_toy_dataset/instances_test.txt new file mode 100644 index 00000000..59b63e06 --- /dev/null +++ b/tests/data/ocr_char_ann_toy_dataset/instances_test.txt @@ -0,0 +1,10 @@ +resort_88_101_1.png From: +resort_95_53_6.png out +richard+feynman_101_8_6.png the +richard+feynman_104_58_9.png fast +richard+feynman_110_1_6.png many +richard+feynman_12_61_4.png the +richard+feynman_130_74_1.png the +richard+feynman_134_30_15.png how +richard+feynman_15_43_4.png the +richard+feynman_18_18_5.png Lines: diff --git a/tests/data/ocr_char_ann_toy_dataset/instances_train.txt b/tests/data/ocr_char_ann_toy_dataset/instances_train.txt new file mode 100644 index 00000000..c3c0fb36 --- /dev/null +++ b/tests/data/ocr_char_ann_toy_dataset/instances_train.txt @@ -0,0 +1,10 @@ +{"file_name": "resort_88_101_1.png", "annotations": [{"char_text": "F", "char_box": [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]}, {"char_text": "r", "char_box": [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0]}, {"char_text": "o", "char_box": [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0]}, {"char_text": "m", "char_box": [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0]}, {"char_text": ":", "char_box": [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0]}], "text": "From:"} +{"file_name": "resort_95_53_6.png", "annotations": [{"char_text": "o", "char_box": [0.0, 5.0, 7.0, 5.0, 9.0, 15.0, 2.0, 15.0]}, {"char_text": "u", "char_box": [7.0, 4.0, 14.0, 4.0, 18.0, 18.0, 11.0, 18.0]}, {"char_text": "t", "char_box": [13.0, 1.0, 19.0, 2.0, 24.0, 18.0, 17.0, 18.0]}], "text": "out"} +{"file_name": "richard+feynman_101_8_6.png", "annotations": [{"char_text": "t", "char_box": [5.0, 3.0, 13.0, 6.0, 10.0, 21.0, 1.0, 18.0]}, {"char_text": "h", "char_box": [14.0, 3.0, 27.0, 8.0, 22.0, 25.0, 10.0, 21.0]}, {"char_text": "e", "char_box": [25.0, 14.0, 35.0, 17.0, 32.0, 29.0, 22.0, 25.0]}], "text": "the"} +{"file_name": "richard+feynman_104_58_9.png", "annotations": [{"char_text": "f", "char_box": [22.0, 19.0, 30.0, 15.0, 20.0, 51.0, 12.0, 54.0]}, {"char_text": "a", "char_box": [27.0, 27.0, 37.0, 21.0, 31.0, 46.0, 21.0, 50.0]}, {"char_text": "s", "char_box": [37.0, 22.0, 47.0, 16.0, 40.0, 41.0, 30.0, 46.0]}, {"char_text": "t", "char_box": [50.0, 5.0, 58.0, 0.0, 47.0, 38.0, 40.0, 41.0]}], "text": "fast"} +{"file_name": "richard+feynman_110_1_6.png", "annotations": [{"char_text": "m", "char_box": [6.0, 33.0, 21.0, 23.0, 19.0, 31.0, 4.0, 41.0]}, {"char_text": "a", "char_box": [21.0, 22.0, 33.0, 15.0, 31.0, 24.0, 19.0, 31.0]}, {"char_text": "n", "char_box": [32.0, 16.0, 45.0, 8.0, 43.0, 17.0, 30.0, 25.0]}, {"char_text": "y", "char_box": [45.0, 8.0, 57.0, 0.0, 55.0, 11.0, 43.0, 19.0]}], "text": "many"} +{"file_name": "richard+feynman_12_61_4.png", "annotations": [{"char_text": "t", "char_box": [5.0, 0.0, 35.0, 6.0, 35.0, 34.0, 4.0, 28.0]}, {"char_text": "h", "char_box": [33.0, 6.0, 71.0, 13.0, 70.0, 40.0, 32.0, 33.0]}, {"char_text": "e", "char_box": [71.0, 13.0, 98.0, 18.0, 98.0, 45.0, 70.0, 40.0]}], "text": "the"} +{"file_name": "richard+feynman_130_74_1.png", "annotations": [{"char_text": "t", 
"char_box": [4.0, 12.0, 27.0, 10.0, 26.0, 47.0, 4.0, 49.0]}, {"char_text": "h", "char_box": [30.0, 3.0, 48.0, 2.0, 48.0, 45.0, 29.0, 47.0]}, {"char_text": "e", "char_box": [50.0, 17.0, 68.0, 15.0, 68.0, 44.0, 50.0, 46.0]}], "text": "the"} +{"file_name": "richard+feynman_134_30_15.png", "annotations": [{"char_text": "h", "char_box": [5.0, 1.0, 24.0, 7.0, 23.0, 23.0, 4.0, 17.0]}, {"char_text": "o", "char_box": [25.0, 12.0, 42.0, 18.0, 41.0, 29.0, 24.0, 24.0]}, {"char_text": "w", "char_box": [40.0, 18.0, 69.0, 26.0, 67.0, 37.0, 39.0, 28.0]}], "text": "how"} +{"file_name": "richard+feynman_15_43_4.png", "annotations": [{"char_text": "t", "char_box": [4.0, 8.0, 12.0, 5.0, 12.0, 19.0, 4.0, 22.0]}, {"char_text": "h", "char_box": [13.0, 5.0, 21.0, 2.0, 21.0, 16.0, 13.0, 19.0]}, {"char_text": "e", "char_box": [21.0, 2.0, 28.0, 0.0, 28.0, 14.0, 21.0, 16.0]}], "text": "the"} +{"file_name": "richard+feynman_18_18_5.png", "annotations": [{"char_text": "L", "char_box": [13.0, 14.0, 32.0, 12.0, 23.0, 36.0, 3.0, 38.0]}, {"char_text": "i", "char_box": [35.0, 7.0, 46.0, 6.0, 37.0, 31.0, 26.0, 32.0]}, {"char_text": "n", "char_box": [47.0, 9.0, 66.0, 8.0, 60.0, 27.0, 41.0, 29.0]}, {"char_text": "e", "char_box": [67.0, 9.0, 85.0, 8.0, 80.0, 27.0, 61.0, 28.0]}, {"char_text": "s", "char_box": [88.0, 7.0, 106.0, 6.0, 101.0, 27.0, 82.0, 28.0]}, {"char_text": ":", "char_box": [106.0, 8.0, 118.0, 7.0, 113.0, 29.0, 101.0, 29.0]}], "text": "Lines:"} diff --git a/tests/data/ocr_toy_dataset/imgs/1036169.jpg b/tests/data/ocr_toy_dataset/imgs/1036169.jpg new file mode 100644 index 00000000..51857bb8 Binary files /dev/null and b/tests/data/ocr_toy_dataset/imgs/1036169.jpg differ diff --git a/tests/data/ocr_toy_dataset/imgs/1058891.jpg b/tests/data/ocr_toy_dataset/imgs/1058891.jpg new file mode 100644 index 00000000..dfff2c37 Binary files /dev/null and b/tests/data/ocr_toy_dataset/imgs/1058891.jpg differ diff --git a/tests/data/ocr_toy_dataset/imgs/1058892.jpg b/tests/data/ocr_toy_dataset/imgs/1058892.jpg new file mode 100644 index 00000000..dea537de Binary files /dev/null and b/tests/data/ocr_toy_dataset/imgs/1058892.jpg differ diff --git a/tests/data/ocr_toy_dataset/imgs/1190237.jpg b/tests/data/ocr_toy_dataset/imgs/1190237.jpg new file mode 100644 index 00000000..5395fddf Binary files /dev/null and b/tests/data/ocr_toy_dataset/imgs/1190237.jpg differ diff --git a/tests/data/ocr_toy_dataset/imgs/1210236.jpg b/tests/data/ocr_toy_dataset/imgs/1210236.jpg new file mode 100644 index 00000000..5203a7e4 Binary files /dev/null and b/tests/data/ocr_toy_dataset/imgs/1210236.jpg differ diff --git a/tests/data/ocr_toy_dataset/imgs/1223729.jpg b/tests/data/ocr_toy_dataset/imgs/1223729.jpg new file mode 100644 index 00000000..5fe2a906 Binary files /dev/null and b/tests/data/ocr_toy_dataset/imgs/1223729.jpg differ diff --git a/tests/data/ocr_toy_dataset/imgs/1223731.jpg b/tests/data/ocr_toy_dataset/imgs/1223731.jpg new file mode 100644 index 00000000..1c77d194 Binary files /dev/null and b/tests/data/ocr_toy_dataset/imgs/1223731.jpg differ diff --git a/tests/data/ocr_toy_dataset/imgs/1223732.jpg b/tests/data/ocr_toy_dataset/imgs/1223732.jpg new file mode 100644 index 00000000..8a65a5dc Binary files /dev/null and b/tests/data/ocr_toy_dataset/imgs/1223732.jpg differ diff --git a/tests/data/ocr_toy_dataset/imgs/1223733.jpg b/tests/data/ocr_toy_dataset/imgs/1223733.jpg new file mode 100644 index 00000000..daa61483 Binary files /dev/null and b/tests/data/ocr_toy_dataset/imgs/1223733.jpg differ diff --git 
a/tests/data/ocr_toy_dataset/imgs/1240078.jpg b/tests/data/ocr_toy_dataset/imgs/1240078.jpg new file mode 100644 index 00000000..7346e68e Binary files /dev/null and b/tests/data/ocr_toy_dataset/imgs/1240078.jpg differ diff --git a/tests/data/ocr_toy_dataset/label.lmdb/data.mdb b/tests/data/ocr_toy_dataset/label.lmdb/data.mdb new file mode 100644 index 00000000..5876a258 Binary files /dev/null and b/tests/data/ocr_toy_dataset/label.lmdb/data.mdb differ diff --git a/tests/data/ocr_toy_dataset/label.lmdb/lock.mdb b/tests/data/ocr_toy_dataset/label.lmdb/lock.mdb new file mode 100644 index 00000000..2ad277ed Binary files /dev/null and b/tests/data/ocr_toy_dataset/label.lmdb/lock.mdb differ diff --git a/tests/data/ocr_toy_dataset/label.txt b/tests/data/ocr_toy_dataset/label.txt new file mode 100644 index 00000000..4b20ed5a --- /dev/null +++ b/tests/data/ocr_toy_dataset/label.txt @@ -0,0 +1,10 @@ +1223731.jpg GRAND +1223733.jpg HOTEL +1223732.jpg HOTEL +1223729.jpg PACIFIC +1036169.jpg 03/09/2009 +1190237.jpg ANING +1058891.jpg Virgin +1058892.jpg america +1240078.jpg ATTACK +1210236.jpg DAVIDSON diff --git a/tests/data/test_img1.jpg b/tests/data/test_img1.jpg new file mode 100644 index 00000000..b7f9d1fd Binary files /dev/null and b/tests/data/test_img1.jpg differ diff --git a/tests/data/test_img1.png b/tests/data/test_img1.png new file mode 100644 index 00000000..9e2554b5 Binary files /dev/null and b/tests/data/test_img1.png differ diff --git a/tests/data/test_img2.jpg b/tests/data/test_img2.jpg new file mode 100644 index 00000000..f71aa6bb Binary files /dev/null and b/tests/data/test_img2.jpg differ diff --git a/tests/data/toy_dataset/annotations/test/gt_img_1.txt b/tests/data/toy_dataset/annotations/test/gt_img_1.txt new file mode 100644 index 00000000..1b22ebbd --- /dev/null +++ b/tests/data/toy_dataset/annotations/test/gt_img_1.txt @@ -0,0 +1,7 @@ +377,117,463,117,465,130,378,130,Genaxis Theatre +493,115,519,115,519,131,493,131,[06] +374,155,409,155,409,170,374,170,### +492,151,551,151,551,170,492,170,62-03 +376,198,422,198,422,212,376,212,Carpark +494,190,539,189,539,205,494,206,### +374,1,494,0,492,85,372,86,### diff --git a/tests/data/toy_dataset/annotations/test/gt_img_10.txt b/tests/data/toy_dataset/annotations/test/gt_img_10.txt new file mode 100644 index 00000000..01334be1 --- /dev/null +++ b/tests/data/toy_dataset/annotations/test/gt_img_10.txt @@ -0,0 +1,8 @@ +261,138,284,140,279,158,260,158,### +288,138,417,140,416,161,290,157,HarbourFront +743,145,779,146,780,163,746,163,CC22 +783,129,831,132,833,155,785,153,bua +831,133,870,135,874,156,835,155,### +159,205,230,204,231,218,159,219,### +785,158,856,158,860,178,787,179,### +1011,157,1079,160,1076,173,1011,170,### diff --git a/tests/data/toy_dataset/annotations/test/gt_img_2.txt b/tests/data/toy_dataset/annotations/test/gt_img_2.txt new file mode 100644 index 00000000..19b427c2 --- /dev/null +++ b/tests/data/toy_dataset/annotations/test/gt_img_2.txt @@ -0,0 +1,2 @@ +602,173,635,175,634,197,602,196,EXIT +734,310,792,320,792,364,738,361,I2R diff --git a/tests/data/toy_dataset/annotations/test/gt_img_3.txt b/tests/data/toy_dataset/annotations/test/gt_img_3.txt new file mode 100644 index 00000000..484f6c57 --- /dev/null +++ b/tests/data/toy_dataset/annotations/test/gt_img_3.txt @@ -0,0 +1,13 @@ +58,80,191,71,194,114,61,123,fusionopolis +147,21,176,21,176,36,147,36,### +328,75,391,81,387,112,326,113,### +401,76,448,84,445,108,402,111,### +780,7,1015,6,1016,37,788,42,### +221,72,311,80,312,117,222,118,fusionopolis 
+113,19,144,19,144,33,113,33,### +257,28,308,28,308,57,257,57,### +140,120,196,115,195,129,141,133,### +86,176,110,177,112,189,89,196,### +101,193,129,185,132,198,103,204,### +223,175,244,150,294,183,235,197,### +140,239,174,232,176,247,142,256,### diff --git a/tests/data/toy_dataset/annotations/test/gt_img_4.txt b/tests/data/toy_dataset/annotations/test/gt_img_4.txt new file mode 100644 index 00000000..8b40444a --- /dev/null +++ b/tests/data/toy_dataset/annotations/test/gt_img_4.txt @@ -0,0 +1,3 @@ +692,268,710,268,710,293,692,293,### +663,224,733,230,737,246,661,242,### +668,242,737,244,734,260,670,256,### diff --git a/tests/data/toy_dataset/annotations/test/gt_img_5.txt b/tests/data/toy_dataset/annotations/test/gt_img_5.txt new file mode 100644 index 00000000..815420f9 --- /dev/null +++ b/tests/data/toy_dataset/annotations/test/gt_img_5.txt @@ -0,0 +1,2 @@ +408,409,437,436,434,461,405,433,### +437,434,443,440,441,467,435,462,### diff --git a/tests/data/toy_dataset/annotations/test/gt_img_6.txt b/tests/data/toy_dataset/annotations/test/gt_img_6.txt new file mode 100644 index 00000000..0d483f22 --- /dev/null +++ b/tests/data/toy_dataset/annotations/test/gt_img_6.txt @@ -0,0 +1,20 @@ +875,92,910,92,910,112,875,112,### +748,95,787,95,787,109,748,109,### +106,395,150,394,153,425,106,424,### +165,393,213,396,210,421,165,421,### +706,52,747,49,746,62,705,64,### +111,459,206,461,207,482,113,480,Reserve +831,9,894,9,894,22,831,22,### +641,456,693,454,693,467,641,469,CAUTION +839,32,891,32,891,47,839,47,### +788,46,831,46,831,59,788,59,### +830,95,872,95,872,106,830,106,### +921,92,952,92,952,111,921,111,### +968,40,1013,40,1013,53,968,53,### +1002,89,1031,89,1031,100,1002,100,### +1043,38,1098,38,1098,52,1043,52,### +1069,85,1138,85,1138,99,1069,99,### +1128,36,1178,36,1178,52,1128,52,### +1168,84,1200,84,1200,97,1168,97,### +1223,27,1259,27,1255,49,1219,49,### +1264,28,1279,28,1279,46,1264,46,### diff --git a/tests/data/toy_dataset/annotations/test/gt_img_7.txt b/tests/data/toy_dataset/annotations/test/gt_img_7.txt new file mode 100644 index 00000000..58171fc4 --- /dev/null +++ b/tests/data/toy_dataset/annotations/test/gt_img_7.txt @@ -0,0 +1,15 @@ +346,133,400,130,401,148,345,153,### +301,127,349,123,351,154,303,158,### +869,67,920,61,923,85,872,91,citi +886,144,934,141,932,157,884,160,smrt +634,106,812,86,816,104,634,121,### +418,117,469,112,471,143,420,148,### +634,124,781,107,783,123,635,135,### +634,138,844,117,843,141,636,155,### +468,124,518,117,525,138,468,143,### +301,181,532,162,530,182,301,201,### +296,157,396,147,400,165,300,174,### +420,151,526,136,527,154,421,163,### +617,251,657,250,656,282,616,285,### +695,246,738,243,738,276,698,278,### +739,241,760,241,763,260,742,262,### diff --git a/tests/data/toy_dataset/annotations/test/gt_img_8.txt b/tests/data/toy_dataset/annotations/test/gt_img_8.txt new file mode 100644 index 00000000..65a32e41 --- /dev/null +++ b/tests/data/toy_dataset/annotations/test/gt_img_8.txt @@ -0,0 +1,8 @@ +568,347,623,350,617,380,568,375,WHY +626,347,673,345,668,382,625,380,PAY +675,351,725,350,726,381,678,379,FOR +598,381,728,385,724,420,598,413,NOTHING? 
+762,351,845,357,845,380,760,377,### +562,588,613,588,611,632,564,633,### +615,593,730,603,727,646,614,634,### +560,634,730,650,730,691,556,678,### diff --git a/tests/data/toy_dataset/annotations/test/gt_img_9.txt b/tests/data/toy_dataset/annotations/test/gt_img_9.txt new file mode 100644 index 00000000..f59d7d90 --- /dev/null +++ b/tests/data/toy_dataset/annotations/test/gt_img_9.txt @@ -0,0 +1,4 @@ +344,206,384,207,381,228,342,227,EXIT +47,183,94,183,83,212,42,206,### +913,515,1068,526,1081,595,921,578,STAGE +240,291,273,291,273,298,240,297,### diff --git a/tests/data/toy_dataset/img_list.txt b/tests/data/toy_dataset/img_list.txt new file mode 100644 index 00000000..206384cf --- /dev/null +++ b/tests/data/toy_dataset/img_list.txt @@ -0,0 +1,10 @@ +img_10.jpg +img_1.jpg +img_2.jpg +img_3.jpg +img_4.jpg +img_5.jpg +img_6.jpg +img_7.jpg +img_8.jpg +img_9.jpg diff --git a/tests/data/toy_dataset/imgs/test/img_1.jpg b/tests/data/toy_dataset/imgs/test/img_1.jpg new file mode 100644 index 00000000..e2606540 Binary files /dev/null and b/tests/data/toy_dataset/imgs/test/img_1.jpg differ diff --git a/tests/data/toy_dataset/imgs/test/img_10.jpg b/tests/data/toy_dataset/imgs/test/img_10.jpg new file mode 100644 index 00000000..f8949164 Binary files /dev/null and b/tests/data/toy_dataset/imgs/test/img_10.jpg differ diff --git a/tests/data/toy_dataset/imgs/test/img_2.jpg b/tests/data/toy_dataset/imgs/test/img_2.jpg new file mode 100644 index 00000000..32275ce0 Binary files /dev/null and b/tests/data/toy_dataset/imgs/test/img_2.jpg differ diff --git a/tests/data/toy_dataset/imgs/test/img_3.jpg b/tests/data/toy_dataset/imgs/test/img_3.jpg new file mode 100644 index 00000000..b6c12926 Binary files /dev/null and b/tests/data/toy_dataset/imgs/test/img_3.jpg differ diff --git a/tests/data/toy_dataset/imgs/test/img_4.jpg b/tests/data/toy_dataset/imgs/test/img_4.jpg new file mode 100644 index 00000000..413de43f Binary files /dev/null and b/tests/data/toy_dataset/imgs/test/img_4.jpg differ diff --git a/tests/data/toy_dataset/imgs/test/img_5.jpg b/tests/data/toy_dataset/imgs/test/img_5.jpg new file mode 100644 index 00000000..873f4ac2 Binary files /dev/null and b/tests/data/toy_dataset/imgs/test/img_5.jpg differ diff --git a/tests/data/toy_dataset/imgs/test/img_6.jpg b/tests/data/toy_dataset/imgs/test/img_6.jpg new file mode 100644 index 00000000..2d8e7471 Binary files /dev/null and b/tests/data/toy_dataset/imgs/test/img_6.jpg differ diff --git a/tests/data/toy_dataset/imgs/test/img_7.jpg b/tests/data/toy_dataset/imgs/test/img_7.jpg new file mode 100644 index 00000000..efd7d9b0 Binary files /dev/null and b/tests/data/toy_dataset/imgs/test/img_7.jpg differ diff --git a/tests/data/toy_dataset/imgs/test/img_8.jpg b/tests/data/toy_dataset/imgs/test/img_8.jpg new file mode 100644 index 00000000..89705000 Binary files /dev/null and b/tests/data/toy_dataset/imgs/test/img_8.jpg differ diff --git a/tests/data/toy_dataset/imgs/test/img_9.jpg b/tests/data/toy_dataset/imgs/test/img_9.jpg new file mode 100644 index 00000000..975f5c4b Binary files /dev/null and b/tests/data/toy_dataset/imgs/test/img_9.jpg differ diff --git a/tests/data/toy_dataset/instances_test.json b/tests/data/toy_dataset/instances_test.json new file mode 100644 index 00000000..1dd51cc8 --- /dev/null +++ b/tests/data/toy_dataset/instances_test.json @@ -0,0 +1 @@ +{"images": [{"file_name": "test/img_10.jpg", "height": 720, "width": 1280, "segm_file": "test/gt_img_10.txt", "id": 0}, {"file_name": "test/img_2.jpg", "height": 720, "width": 1280, "segm_file": 
"test/gt_img_2.txt", "id": 1}, {"file_name": "test/img_6.jpg", "height": 720, "width": 1280, "segm_file": "test/gt_img_6.txt", "id": 2}, {"file_name": "test/img_3.jpg", "height": 720, "width": 1280, "segm_file": "test/gt_img_3.txt", "id": 3}, {"file_name": "test/img_9.jpg", "height": 720, "width": 1280, "segm_file": "test/gt_img_9.txt", "id": 4}, {"file_name": "test/img_8.jpg", "height": 720, "width": 1280, "segm_file": "test/gt_img_8.txt", "id": 5}, {"file_name": "test/img_1.jpg", "height": 720, "width": 1280, "segm_file": "test/gt_img_1.txt", "id": 6}, {"file_name": "test/img_5.jpg", "height": 720, "width": 1280, "segm_file": "test/gt_img_5.txt", "id": 7}, {"file_name": "test/img_7.jpg", "height": 720, "width": 1280, "segm_file": "test/gt_img_7.txt", "id": 8}, {"file_name": "test/img_4.jpg", "height": 720, "width": 1280, "segm_file": "test/gt_img_4.txt", "id": 9}], "categories": [{"id": 1, "name": "text"}], "annotations": [{"iscrowd": 1, "category_id": 1, "bbox": [260.0, 138.0, 24.0, 20.0], "area": 402.0, "segmentation": [[261, 138, 284, 140, 279, 158, 260, 158]], "image_id": 0, "id": 0}, {"iscrowd": 0, "category_id": 1, "bbox": [288.0, 138.0, 129.0, 23.0], "area": 2548.5, "segmentation": [[288, 138, 417, 140, 416, 161, 290, 157]], "image_id": 0, "id": 1}, {"iscrowd": 0, "category_id": 1, "bbox": [743.0, 145.0, 37.0, 18.0], "area": 611.5, "segmentation": [[743, 145, 779, 146, 780, 163, 746, 163]], "image_id": 0, "id": 2}, {"iscrowd": 0, "category_id": 1, "bbox": [783.0, 129.0, 50.0, 26.0], "area": 1123.0, "segmentation": [[783, 129, 831, 132, 833, 155, 785, 153]], "image_id": 0, "id": 3}, {"iscrowd": 1, "category_id": 1, "bbox": [831.0, 133.0, 43.0, 23.0], "area": 832.5, "segmentation": [[831, 133, 870, 135, 874, 156, 835, 155]], "image_id": 0, "id": 4}, {"iscrowd": 1, "category_id": 1, "bbox": [159.0, 204.0, 72.0, 15.0], "area": 1001.5, "segmentation": [[159, 205, 230, 204, 231, 218, 159, 219]], "image_id": 0, "id": 5}, {"iscrowd": 1, "category_id": 1, "bbox": [785.0, 158.0, 75.0, 21.0], "area": 1477.5, "segmentation": [[785, 158, 856, 158, 860, 178, 787, 179]], "image_id": 0, "id": 6}, {"iscrowd": 1, "category_id": 1, "bbox": [1011.0, 157.0, 68.0, 16.0], "area": 869.0, "segmentation": [[1011, 157, 1079, 160, 1076, 173, 1011, 170]], "image_id": 0, "id": 7}, {"iscrowd": 0, "category_id": 1, "bbox": [602.0, 173.0, 33.0, 24.0], "area": 732.0, "segmentation": [[602, 173, 635, 175, 634, 197, 602, 196]], "image_id": 1, "id": 8}, {"iscrowd": 0, "category_id": 1, "bbox": [734.0, 310.0, 58.0, 54.0], "area": 2647.0, "segmentation": [[734, 310, 792, 320, 792, 364, 738, 361]], "image_id": 1, "id": 9}, {"iscrowd": 1, "category_id": 1, "bbox": [875.0, 92.0, 35.0, 20.0], "area": 700.0, "segmentation": [[875, 92, 910, 92, 910, 112, 875, 112]], "image_id": 2, "id": 10}, {"iscrowd": 1, "category_id": 1, "bbox": [748.0, 95.0, 39.0, 14.0], "area": 546.0, "segmentation": [[748, 95, 787, 95, 787, 109, 748, 109]], "image_id": 2, "id": 11}, {"iscrowd": 1, "category_id": 1, "bbox": [106.0, 394.0, 47.0, 31.0], "area": 1365.0, "segmentation": [[106, 395, 150, 394, 153, 425, 106, 424]], "image_id": 2, "id": 12}, {"iscrowd": 1, "category_id": 1, "bbox": [165.0, 393.0, 48.0, 28.0], "area": 1234.5, "segmentation": [[165, 393, 213, 396, 210, 421, 165, 421]], "image_id": 2, "id": 13}, {"iscrowd": 1, "category_id": 1, "bbox": [705.0, 49.0, 42.0, 15.0], "area": 510.0, "segmentation": [[706, 52, 747, 49, 746, 62, 705, 64]], "image_id": 2, "id": 14}, {"iscrowd": 0, "category_id": 1, "bbox": [111.0, 459.0, 96.0, 23.0], 
"area": 1981.5, "segmentation": [[111, 459, 206, 461, 207, 482, 113, 480]], "image_id": 2, "id": 15}, {"iscrowd": 1, "category_id": 1, "bbox": [831.0, 9.0, 63.0, 13.0], "area": 819.0, "segmentation": [[831, 9, 894, 9, 894, 22, 831, 22]], "image_id": 2, "id": 16}, {"iscrowd": 0, "category_id": 1, "bbox": [641.0, 454.0, 52.0, 15.0], "area": 676.0, "segmentation": [[641, 456, 693, 454, 693, 467, 641, 469]], "image_id": 2, "id": 17}, {"iscrowd": 1, "category_id": 1, "bbox": [839.0, 32.0, 52.0, 15.0], "area": 780.0, "segmentation": [[839, 32, 891, 32, 891, 47, 839, 47]], "image_id": 2, "id": 18}, {"iscrowd": 1, "category_id": 1, "bbox": [788.0, 46.0, 43.0, 13.0], "area": 559.0, "segmentation": [[788, 46, 831, 46, 831, 59, 788, 59]], "image_id": 2, "id": 19}, {"iscrowd": 1, "category_id": 1, "bbox": [830.0, 95.0, 42.0, 11.0], "area": 462.0, "segmentation": [[830, 95, 872, 95, 872, 106, 830, 106]], "image_id": 2, "id": 20}, {"iscrowd": 1, "category_id": 1, "bbox": [921.0, 92.0, 31.0, 19.0], "area": 589.0, "segmentation": [[921, 92, 952, 92, 952, 111, 921, 111]], "image_id": 2, "id": 21}, {"iscrowd": 1, "category_id": 1, "bbox": [968.0, 40.0, 45.0, 13.0], "area": 585.0, "segmentation": [[968, 40, 1013, 40, 1013, 53, 968, 53]], "image_id": 2, "id": 22}, {"iscrowd": 1, "category_id": 1, "bbox": [1002.0, 89.0, 29.0, 11.0], "area": 319.0, "segmentation": [[1002, 89, 1031, 89, 1031, 100, 1002, 100]], "image_id": 2, "id": 23}, {"iscrowd": 1, "category_id": 1, "bbox": [1043.0, 38.0, 55.0, 14.0], "area": 770.0, "segmentation": [[1043, 38, 1098, 38, 1098, 52, 1043, 52]], "image_id": 2, "id": 24}, {"iscrowd": 1, "category_id": 1, "bbox": [1069.0, 85.0, 69.0, 14.0], "area": 966.0, "segmentation": [[1069, 85, 1138, 85, 1138, 99, 1069, 99]], "image_id": 2, "id": 25}, {"iscrowd": 1, "category_id": 1, "bbox": [1128.0, 36.0, 50.0, 16.0], "area": 800.0, "segmentation": [[1128, 36, 1178, 36, 1178, 52, 1128, 52]], "image_id": 2, "id": 26}, {"iscrowd": 1, "category_id": 1, "bbox": [1168.0, 84.0, 32.0, 13.0], "area": 416.0, "segmentation": [[1168, 84, 1200, 84, 1200, 97, 1168, 97]], "image_id": 2, "id": 27}, {"iscrowd": 1, "category_id": 1, "bbox": [1219.0, 27.0, 40.0, 22.0], "area": 792.0, "segmentation": [[1223, 27, 1259, 27, 1255, 49, 1219, 49]], "image_id": 2, "id": 28}, {"iscrowd": 1, "category_id": 1, "bbox": [1264.0, 28.0, 15.0, 18.0], "area": 270.0, "segmentation": [[1264, 28, 1279, 28, 1279, 46, 1264, 46]], "image_id": 2, "id": 29}, {"iscrowd": 0, "category_id": 1, "bbox": [58.0, 71.0, 136.0, 52.0], "area": 5746.0, "segmentation": [[58, 80, 191, 71, 194, 114, 61, 123]], "image_id": 3, "id": 30}, {"iscrowd": 1, "category_id": 1, "bbox": [147.0, 21.0, 29.0, 15.0], "area": 435.0, "segmentation": [[147, 21, 176, 21, 176, 36, 147, 36]], "image_id": 3, "id": 31}, {"iscrowd": 1, "category_id": 1, "bbox": [326.0, 75.0, 65.0, 38.0], "area": 2146.5, "segmentation": [[328, 75, 391, 81, 387, 112, 326, 113]], "image_id": 3, "id": 32}, {"iscrowd": 1, "category_id": 1, "bbox": [401.0, 76.0, 47.0, 35.0], "area": 1330.0, "segmentation": [[401, 76, 448, 84, 445, 108, 402, 111]], "image_id": 3, "id": 33}, {"iscrowd": 1, "category_id": 1, "bbox": [780.0, 6.0, 236.0, 36.0], "area": 7653.0, "segmentation": [[780, 7, 1015, 6, 1016, 37, 788, 42]], "image_id": 3, "id": 34}, {"iscrowd": 0, "category_id": 1, "bbox": [221.0, 72.0, 91.0, 46.0], "area": 3731.5, "segmentation": [[221, 72, 311, 80, 312, 117, 222, 118]], "image_id": 3, "id": 35}, {"iscrowd": 1, "category_id": 1, "bbox": [113.0, 19.0, 31.0, 14.0], "area": 434.0, 
"segmentation": [[113, 19, 144, 19, 144, 33, 113, 33]], "image_id": 3, "id": 36}, {"iscrowd": 1, "category_id": 1, "bbox": [257.0, 28.0, 51.0, 29.0], "area": 1479.0, "segmentation": [[257, 28, 308, 28, 308, 57, 257, 57]], "image_id": 3, "id": 37}, {"iscrowd": 1, "category_id": 1, "bbox": [140.0, 115.0, 56.0, 18.0], "area": 742.5, "segmentation": [[140, 120, 196, 115, 195, 129, 141, 133]], "image_id": 3, "id": 38}, {"iscrowd": 1, "category_id": 1, "bbox": [86.0, 176.0, 26.0, 20.0], "area": 383.5, "segmentation": [[86, 176, 110, 177, 112, 189, 89, 196]], "image_id": 3, "id": 39}, {"iscrowd": 1, "category_id": 1, "bbox": [101.0, 185.0, 31.0, 19.0], "area": 359.5, "segmentation": [[101, 193, 129, 185, 132, 198, 103, 204]], "image_id": 3, "id": 40}, {"iscrowd": 1, "category_id": 1, "bbox": [223.0, 150.0, 71.0, 47.0], "area": 1704.5, "segmentation": [[223, 175, 244, 150, 294, 183, 235, 197]], "image_id": 3, "id": 41}, {"iscrowd": 1, "category_id": 1, "bbox": [140.0, 232.0, 36.0, 24.0], "area": 560.0, "segmentation": [[140, 239, 174, 232, 176, 247, 142, 256]], "image_id": 3, "id": 42}, {"iscrowd": 0, "category_id": 1, "bbox": [342.0, 206.0, 42.0, 22.0], "area": 832.0, "segmentation": [[344, 206, 384, 207, 381, 228, 342, 227]], "image_id": 4, "id": 43}, {"iscrowd": 1, "category_id": 1, "bbox": [42.0, 183.0, 52.0, 29.0], "area": 1168.0, "segmentation": [[47, 183, 94, 183, 83, 212, 42, 206]], "image_id": 4, "id": 44}, {"iscrowd": 0, "category_id": 1, "bbox": [913.0, 515.0, 168.0, 80.0], "area": 10248.0, "segmentation": [[913, 515, 1068, 526, 1081, 595, 921, 578]], "image_id": 4, "id": 45}, {"iscrowd": 1, "category_id": 1, "bbox": [240.0, 291.0, 33.0, 7.0], "area": 214.5, "segmentation": [[240, 291, 273, 291, 273, 298, 240, 297]], "image_id": 4, "id": 46}, {"iscrowd": 0, "category_id": 1, "bbox": [568.0, 347.0, 55.0, 33.0], "area": 1520.0, "segmentation": [[568, 347, 623, 350, 617, 380, 568, 375]], "image_id": 5, "id": 47}, {"iscrowd": 0, "category_id": 1, "bbox": [625.0, 345.0, 48.0, 37.0], "area": 1575.0, "segmentation": [[626, 347, 673, 345, 668, 382, 625, 380]], "image_id": 5, "id": 48}, {"iscrowd": 0, "category_id": 1, "bbox": [675.0, 350.0, 51.0, 31.0], "area": 1444.5, "segmentation": [[675, 351, 725, 350, 726, 381, 678, 379]], "image_id": 5, "id": 49}, {"iscrowd": 0, "category_id": 1, "bbox": [598.0, 381.0, 130.0, 39.0], "area": 4299.0, "segmentation": [[598, 381, 728, 385, 724, 420, 598, 413]], "image_id": 5, "id": 50}, {"iscrowd": 1, "category_id": 1, "bbox": [760.0, 351.0, 85.0, 29.0], "area": 2062.5, "segmentation": [[762, 351, 845, 357, 845, 380, 760, 377]], "image_id": 5, "id": 51}, {"iscrowd": 1, "category_id": 1, "bbox": [562.0, 588.0, 51.0, 45.0], "area": 2180.5, "segmentation": [[562, 588, 613, 588, 611, 632, 564, 633]], "image_id": 5, "id": 52}, {"iscrowd": 1, "category_id": 1, "bbox": [614.0, 593.0, 116.0, 53.0], "area": 4810.0, "segmentation": [[615, 593, 730, 603, 727, 646, 614, 634]], "image_id": 5, "id": 53}, {"iscrowd": 1, "category_id": 1, "bbox": [556.0, 634.0, 174.0, 57.0], "area": 7339.0, "segmentation": [[560, 634, 730, 650, 730, 691, 556, 678]], "image_id": 5, "id": 54}, {"iscrowd": 0, "category_id": 1, "bbox": [377.0, 117.0, 88.0, 13.0], "area": 1124.5, "segmentation": [[377, 117, 463, 117, 465, 130, 378, 130]], "image_id": 6, "id": 55}, {"iscrowd": 0, "category_id": 1, "bbox": [493.0, 115.0, 26.0, 16.0], "area": 416.0, "segmentation": [[493, 115, 519, 115, 519, 131, 493, 131]], "image_id": 6, "id": 56}, {"iscrowd": 1, "category_id": 1, "bbox": [374.0, 155.0, 35.0, 
15.0], "area": 525.0, "segmentation": [[374, 155, 409, 155, 409, 170, 374, 170]], "image_id": 6, "id": 57}, {"iscrowd": 0, "category_id": 1, "bbox": [492.0, 151.0, 59.0, 19.0], "area": 1121.0, "segmentation": [[492, 151, 551, 151, 551, 170, 492, 170]], "image_id": 6, "id": 58}, {"iscrowd": 0, "category_id": 1, "bbox": [376.0, 198.0, 46.0, 14.0], "area": 644.0, "segmentation": [[376, 198, 422, 198, 422, 212, 376, 212]], "image_id": 6, "id": 59}, {"iscrowd": 1, "category_id": 1, "bbox": [494.0, 189.0, 45.0, 17.0], "area": 720.0, "segmentation": [[494, 190, 539, 189, 539, 205, 494, 206]], "image_id": 6, "id": 60}, {"iscrowd": 1, "category_id": 1, "bbox": [372.0, 0.0, 122.0, 86.0], "area": 10198.0, "segmentation": [[374, 1, 494, 0, 492, 85, 372, 86]], "image_id": 6, "id": 61}, {"iscrowd": 1, "category_id": 1, "bbox": [405.0, 409.0, 32.0, 52.0], "area": 793.0, "segmentation": [[408, 409, 437, 436, 434, 461, 405, 433]], "image_id": 7, "id": 62}, {"iscrowd": 1, "category_id": 1, "bbox": [435.0, 434.0, 8.0, 33.0], "area": 176.0, "segmentation": [[437, 434, 443, 440, 441, 467, 435, 462]], "image_id": 7, "id": 63}, {"iscrowd": 1, "category_id": 1, "bbox": [345.0, 130.0, 56.0, 23.0], "area": 1045.0, "segmentation": [[346, 133, 400, 130, 401, 148, 345, 153]], "image_id": 8, "id": 64}, {"iscrowd": 1, "category_id": 1, "bbox": [301.0, 123.0, 50.0, 35.0], "area": 1496.0, "segmentation": [[301, 127, 349, 123, 351, 154, 303, 158]], "image_id": 8, "id": 65}, {"iscrowd": 0, "category_id": 1, "bbox": [869.0, 61.0, 54.0, 30.0], "area": 1242.0, "segmentation": [[869, 67, 920, 61, 923, 85, 872, 91]], "image_id": 8, "id": 66}, {"iscrowd": 0, "category_id": 1, "bbox": [884.0, 141.0, 50.0, 19.0], "area": 762.0, "segmentation": [[886, 144, 934, 141, 932, 157, 884, 160]], "image_id": 8, "id": 67}, {"iscrowd": 1, "category_id": 1, "bbox": [634.0, 86.0, 182.0, 35.0], "area": 3007.0, "segmentation": [[634, 106, 812, 86, 816, 104, 634, 121]], "image_id": 8, "id": 68}, {"iscrowd": 1, "category_id": 1, "bbox": [418.0, 112.0, 53.0, 36.0], "area": 1591.0, "segmentation": [[418, 117, 469, 112, 471, 143, 420, 148]], "image_id": 8, "id": 69}, {"iscrowd": 1, "category_id": 1, "bbox": [634.0, 107.0, 149.0, 28.0], "area": 2013.0, "segmentation": [[634, 124, 781, 107, 783, 123, 635, 135]], "image_id": 8, "id": 70}, {"iscrowd": 1, "category_id": 1, "bbox": [634.0, 117.0, 210.0, 38.0], "area": 4283.0, "segmentation": [[634, 138, 844, 117, 843, 141, 636, 155]], "image_id": 8, "id": 71}, {"iscrowd": 1, "category_id": 1, "bbox": [468.0, 117.0, 57.0, 26.0], "area": 1091.0, "segmentation": [[468, 124, 518, 117, 525, 138, 468, 143]], "image_id": 8, "id": 72}, {"iscrowd": 1, "category_id": 1, "bbox": [301.0, 162.0, 231.0, 39.0], "area": 4581.0, "segmentation": [[301, 181, 532, 162, 530, 182, 301, 201]], "image_id": 8, "id": 73}, {"iscrowd": 1, "category_id": 1, "bbox": [296.0, 147.0, 104.0, 27.0], "area": 1788.0, "segmentation": [[296, 157, 396, 147, 400, 165, 300, 174]], "image_id": 8, "id": 74}, {"iscrowd": 1, "category_id": 1, "bbox": [420.0, 136.0, 107.0, 27.0], "area": 1602.0, "segmentation": [[420, 151, 526, 136, 527, 154, 421, 163]], "image_id": 8, "id": 75}, {"iscrowd": 1, "category_id": 1, "bbox": [616.0, 250.0, 41.0, 35.0], "area": 1318.0, "segmentation": [[617, 251, 657, 250, 656, 282, 616, 285]], "image_id": 8, "id": 76}, {"iscrowd": 1, "category_id": 1, "bbox": [695.0, 243.0, 43.0, 35.0], "area": 1352.5, "segmentation": [[695, 246, 738, 243, 738, 276, 698, 278]], "image_id": 8, "id": 77}, {"iscrowd": 1, "category_id": 1, 
"bbox": [739.0, 241.0, 24.0, 21.0], "area": 423.0, "segmentation": [[739, 241, 760, 241, 763, 260, 742, 262]], "image_id": 8, "id": 78}, {"iscrowd": 1, "category_id": 1, "bbox": [692.0, 268.0, 18.0, 25.0], "area": 450.0, "segmentation": [[692, 268, 710, 268, 710, 293, 692, 293]], "image_id": 9, "id": 79}, {"iscrowd": 1, "category_id": 1, "bbox": [661.0, 224.0, 76.0, 22.0], "area": 1236.0, "segmentation": [[663, 224, 733, 230, 737, 246, 661, 242]], "image_id": 9, "id": 80}, {"iscrowd": 1, "category_id": 1, "bbox": [668.0, 242.0, 69.0, 18.0], "area": 999.0, "segmentation": [[668, 242, 737, 244, 734, 260, 670, 256]], "image_id": 9, "id": 81}]} diff --git a/tests/data/toy_dataset/instances_test.txt b/tests/data/toy_dataset/instances_test.txt new file mode 100644 index 00000000..af3e8e65 --- /dev/null +++ b/tests/data/toy_dataset/instances_test.txt @@ -0,0 +1,10 @@ +{"file_name": "test/img_10.jpg", "height": 720, "width": 1280, "annotations": [{"iscrowd": 1, "category_id": 1, "bbox": [260.0, 138.0, 24.0, 20.0], "segmentation": [[261, 138, 284, 140, 279, 158, 260, 158]]}, {"iscrowd": 0, "category_id": 1, "bbox": [288.0, 138.0, 129.0, 23.0], "segmentation": [[288, 138, 417, 140, 416, 161, 290, 157]]}, {"iscrowd": 0, "category_id": 1, "bbox": [743.0, 145.0, 37.0, 18.0], "segmentation": [[743, 145, 779, 146, 780, 163, 746, 163]]}, {"iscrowd": 0, "category_id": 1, "bbox": [783.0, 129.0, 50.0, 26.0], "segmentation": [[783, 129, 831, 132, 833, 155, 785, 153]]}, {"iscrowd": 1, "category_id": 1, "bbox": [831.0, 133.0, 43.0, 23.0], "segmentation": [[831, 133, 870, 135, 874, 156, 835, 155]]}, {"iscrowd": 1, "category_id": 1, "bbox": [159.0, 204.0, 72.0, 15.0], "segmentation": [[159, 205, 230, 204, 231, 218, 159, 219]]}, {"iscrowd": 1, "category_id": 1, "bbox": [785.0, 158.0, 75.0, 21.0], "segmentation": [[785, 158, 856, 158, 860, 178, 787, 179]]}, {"iscrowd": 1, "category_id": 1, "bbox": [1011.0, 157.0, 68.0, 16.0], "segmentation": [[1011, 157, 1079, 160, 1076, 173, 1011, 170]]}]} +{"file_name": "test/img_2.jpg", "height": 720, "width": 1280, "annotations": [{"iscrowd": 0, "category_id": 1, "bbox": [602.0, 173.0, 33.0, 24.0], "segmentation": [[602, 173, 635, 175, 634, 197, 602, 196]]}, {"iscrowd": 0, "category_id": 1, "bbox": [734.0, 310.0, 58.0, 54.0], "segmentation": [[734, 310, 792, 320, 792, 364, 738, 361]]}]} +{"file_name": "test/img_6.jpg", "height": 720, "width": 1280, "annotations": [{"iscrowd": 1, "category_id": 1, "bbox": [875.0, 92.0, 35.0, 20.0], "segmentation": [[875, 92, 910, 92, 910, 112, 875, 112]]}, {"iscrowd": 1, "category_id": 1, "bbox": [748.0, 95.0, 39.0, 14.0], "segmentation": [[748, 95, 787, 95, 787, 109, 748, 109]]}, {"iscrowd": 1, "category_id": 1, "bbox": [106.0, 394.0, 47.0, 31.0], "segmentation": [[106, 395, 150, 394, 153, 425, 106, 424]]}, {"iscrowd": 1, "category_id": 1, "bbox": [165.0, 393.0, 48.0, 28.0], "segmentation": [[165, 393, 213, 396, 210, 421, 165, 421]]}, {"iscrowd": 1, "category_id": 1, "bbox": [705.0, 49.0, 42.0, 15.0], "segmentation": [[706, 52, 747, 49, 746, 62, 705, 64]]}, {"iscrowd": 0, "category_id": 1, "bbox": [111.0, 459.0, 96.0, 23.0], "segmentation": [[111, 459, 206, 461, 207, 482, 113, 480]]}, {"iscrowd": 1, "category_id": 1, "bbox": [831.0, 9.0, 63.0, 13.0], "segmentation": [[831, 9, 894, 9, 894, 22, 831, 22]]}, {"iscrowd": 0, "category_id": 1, "bbox": [641.0, 454.0, 52.0, 15.0], "segmentation": [[641, 456, 693, 454, 693, 467, 641, 469]]}, {"iscrowd": 1, "category_id": 1, "bbox": [839.0, 32.0, 52.0, 15.0], "segmentation": [[839, 32, 891, 32, 891, 47, 
839, 47]]}, {"iscrowd": 1, "category_id": 1, "bbox": [788.0, 46.0, 43.0, 13.0], "segmentation": [[788, 46, 831, 46, 831, 59, 788, 59]]}, {"iscrowd": 1, "category_id": 1, "bbox": [830.0, 95.0, 42.0, 11.0], "segmentation": [[830, 95, 872, 95, 872, 106, 830, 106]]}, {"iscrowd": 1, "category_id": 1, "bbox": [921.0, 92.0, 31.0, 19.0], "segmentation": [[921, 92, 952, 92, 952, 111, 921, 111]]}, {"iscrowd": 1, "category_id": 1, "bbox": [968.0, 40.0, 45.0, 13.0], "segmentation": [[968, 40, 1013, 40, 1013, 53, 968, 53]]}, {"iscrowd": 1, "category_id": 1, "bbox": [1002.0, 89.0, 29.0, 11.0], "segmentation": [[1002, 89, 1031, 89, 1031, 100, 1002, 100]]}, {"iscrowd": 1, "category_id": 1, "bbox": [1043.0, 38.0, 55.0, 14.0], "segmentation": [[1043, 38, 1098, 38, 1098, 52, 1043, 52]]}, {"iscrowd": 1, "category_id": 1, "bbox": [1069.0, 85.0, 69.0, 14.0], "segmentation": [[1069, 85, 1138, 85, 1138, 99, 1069, 99]]}, {"iscrowd": 1, "category_id": 1, "bbox": [1128.0, 36.0, 50.0, 16.0], "segmentation": [[1128, 36, 1178, 36, 1178, 52, 1128, 52]]}, {"iscrowd": 1, "category_id": 1, "bbox": [1168.0, 84.0, 32.0, 13.0], "segmentation": [[1168, 84, 1200, 84, 1200, 97, 1168, 97]]}, {"iscrowd": 1, "category_id": 1, "bbox": [1219.0, 27.0, 40.0, 22.0], "segmentation": [[1223, 27, 1259, 27, 1255, 49, 1219, 49]]}, {"iscrowd": 1, "category_id": 1, "bbox": [1264.0, 28.0, 15.0, 18.0], "segmentation": [[1264, 28, 1279, 28, 1279, 46, 1264, 46]]}]} +{"file_name": "test/img_3.jpg", "height": 720, "width": 1280, "annotations": [{"iscrowd": 0, "category_id": 1, "bbox": [58.0, 71.0, 136.0, 52.0], "segmentation": [[58, 80, 191, 71, 194, 114, 61, 123]]}, {"iscrowd": 1, "category_id": 1, "bbox": [147.0, 21.0, 29.0, 15.0], "segmentation": [[147, 21, 176, 21, 176, 36, 147, 36]]}, {"iscrowd": 1, "category_id": 1, "bbox": [326.0, 75.0, 65.0, 38.0], "segmentation": [[328, 75, 391, 81, 387, 112, 326, 113]]}, {"iscrowd": 1, "category_id": 1, "bbox": [401.0, 76.0, 47.0, 35.0], "segmentation": [[401, 76, 448, 84, 445, 108, 402, 111]]}, {"iscrowd": 1, "category_id": 1, "bbox": [780.0, 6.0, 236.0, 36.0], "segmentation": [[780, 7, 1015, 6, 1016, 37, 788, 42]]}, {"iscrowd": 0, "category_id": 1, "bbox": [221.0, 72.0, 91.0, 46.0], "segmentation": [[221, 72, 311, 80, 312, 117, 222, 118]]}, {"iscrowd": 1, "category_id": 1, "bbox": [113.0, 19.0, 31.0, 14.0], "segmentation": [[113, 19, 144, 19, 144, 33, 113, 33]]}, {"iscrowd": 1, "category_id": 1, "bbox": [257.0, 28.0, 51.0, 29.0], "segmentation": [[257, 28, 308, 28, 308, 57, 257, 57]]}, {"iscrowd": 1, "category_id": 1, "bbox": [140.0, 115.0, 56.0, 18.0], "segmentation": [[140, 120, 196, 115, 195, 129, 141, 133]]}, {"iscrowd": 1, "category_id": 1, "bbox": [86.0, 176.0, 26.0, 20.0], "segmentation": [[86, 176, 110, 177, 112, 189, 89, 196]]}, {"iscrowd": 1, "category_id": 1, "bbox": [101.0, 185.0, 31.0, 19.0], "segmentation": [[101, 193, 129, 185, 132, 198, 103, 204]]}, {"iscrowd": 1, "category_id": 1, "bbox": [223.0, 150.0, 71.0, 47.0], "segmentation": [[223, 175, 244, 150, 294, 183, 235, 197]]}, {"iscrowd": 1, "category_id": 1, "bbox": [140.0, 232.0, 36.0, 24.0], "segmentation": [[140, 239, 174, 232, 176, 247, 142, 256]]}]} +{"file_name": "test/img_9.jpg", "height": 720, "width": 1280, "annotations": [{"iscrowd": 0, "category_id": 1, "bbox": [342.0, 206.0, 42.0, 22.0], "segmentation": [[344, 206, 384, 207, 381, 228, 342, 227]]}, {"iscrowd": 1, "category_id": 1, "bbox": [42.0, 183.0, 52.0, 29.0], "segmentation": [[47, 183, 94, 183, 83, 212, 42, 206]]}, {"iscrowd": 0, "category_id": 1, "bbox": [913.0, 515.0, 
168.0, 80.0], "segmentation": [[913, 515, 1068, 526, 1081, 595, 921, 578]]}, {"iscrowd": 1, "category_id": 1, "bbox": [240.0, 291.0, 33.0, 7.0], "segmentation": [[240, 291, 273, 291, 273, 298, 240, 297]]}]} +{"file_name": "test/img_8.jpg", "height": 720, "width": 1280, "annotations": [{"iscrowd": 0, "category_id": 1, "bbox": [568.0, 347.0, 55.0, 33.0], "segmentation": [[568, 347, 623, 350, 617, 380, 568, 375]]}, {"iscrowd": 0, "category_id": 1, "bbox": [625.0, 345.0, 48.0, 37.0], "segmentation": [[626, 347, 673, 345, 668, 382, 625, 380]]}, {"iscrowd": 0, "category_id": 1, "bbox": [675.0, 350.0, 51.0, 31.0], "segmentation": [[675, 351, 725, 350, 726, 381, 678, 379]]}, {"iscrowd": 0, "category_id": 1, "bbox": [598.0, 381.0, 130.0, 39.0], "segmentation": [[598, 381, 728, 385, 724, 420, 598, 413]]}, {"iscrowd": 1, "category_id": 1, "bbox": [760.0, 351.0, 85.0, 29.0], "segmentation": [[762, 351, 845, 357, 845, 380, 760, 377]]}, {"iscrowd": 1, "category_id": 1, "bbox": [562.0, 588.0, 51.0, 45.0], "segmentation": [[562, 588, 613, 588, 611, 632, 564, 633]]}, {"iscrowd": 1, "category_id": 1, "bbox": [614.0, 593.0, 116.0, 53.0], "segmentation": [[615, 593, 730, 603, 727, 646, 614, 634]]}, {"iscrowd": 1, "category_id": 1, "bbox": [556.0, 634.0, 174.0, 57.0], "segmentation": [[560, 634, 730, 650, 730, 691, 556, 678]]}]} +{"file_name": "test/img_1.jpg", "height": 720, "width": 1280, "annotations": [{"iscrowd": 0, "category_id": 1, "bbox": [377.0, 117.0, 88.0, 13.0], "segmentation": [[377, 117, 463, 117, 465, 130, 378, 130]]}, {"iscrowd": 0, "category_id": 1, "bbox": [493.0, 115.0, 26.0, 16.0], "segmentation": [[493, 115, 519, 115, 519, 131, 493, 131]]}, {"iscrowd": 1, "category_id": 1, "bbox": [374.0, 155.0, 35.0, 15.0], "segmentation": [[374, 155, 409, 155, 409, 170, 374, 170]]}, {"iscrowd": 0, "category_id": 1, "bbox": [492.0, 151.0, 59.0, 19.0], "segmentation": [[492, 151, 551, 151, 551, 170, 492, 170]]}, {"iscrowd": 0, "category_id": 1, "bbox": [376.0, 198.0, 46.0, 14.0], "segmentation": [[376, 198, 422, 198, 422, 212, 376, 212]]}, {"iscrowd": 1, "category_id": 1, "bbox": [494.0, 189.0, 45.0, 17.0], "segmentation": [[494, 190, 539, 189, 539, 205, 494, 206]]}, {"iscrowd": 1, "category_id": 1, "bbox": [372.0, 0.0, 122.0, 86.0], "segmentation": [[374, 1, 494, 0, 492, 85, 372, 86]]}]} +{"file_name": "test/img_5.jpg", "height": 720, "width": 1280, "annotations": [{"iscrowd": 1, "category_id": 1, "bbox": [405.0, 409.0, 32.0, 52.0], "segmentation": [[408, 409, 437, 436, 434, 461, 405, 433]]}, {"iscrowd": 1, "category_id": 1, "bbox": [435.0, 434.0, 8.0, 33.0], "segmentation": [[437, 434, 443, 440, 441, 467, 435, 462]]}]} +{"file_name": "test/img_7.jpg", "height": 720, "width": 1280, "annotations": [{"iscrowd": 1, "category_id": 1, "bbox": [345.0, 130.0, 56.0, 23.0], "segmentation": [[346, 133, 400, 130, 401, 148, 345, 153]]}, {"iscrowd": 1, "category_id": 1, "bbox": [301.0, 123.0, 50.0, 35.0], "segmentation": [[301, 127, 349, 123, 351, 154, 303, 158]]}, {"iscrowd": 0, "category_id": 1, "bbox": [869.0, 61.0, 54.0, 30.0], "segmentation": [[869, 67, 920, 61, 923, 85, 872, 91]]}, {"iscrowd": 0, "category_id": 1, "bbox": [884.0, 141.0, 50.0, 19.0], "segmentation": [[886, 144, 934, 141, 932, 157, 884, 160]]}, {"iscrowd": 1, "category_id": 1, "bbox": [634.0, 86.0, 182.0, 35.0], "segmentation": [[634, 106, 812, 86, 816, 104, 634, 121]]}, {"iscrowd": 1, "category_id": 1, "bbox": [418.0, 112.0, 53.0, 36.0], "segmentation": [[418, 117, 469, 112, 471, 143, 420, 148]]}, {"iscrowd": 1, "category_id": 1, "bbox": [634.0, 
107.0, 149.0, 28.0], "segmentation": [[634, 124, 781, 107, 783, 123, 635, 135]]}, {"iscrowd": 1, "category_id": 1, "bbox": [634.0, 117.0, 210.0, 38.0], "segmentation": [[634, 138, 844, 117, 843, 141, 636, 155]]}, {"iscrowd": 1, "category_id": 1, "bbox": [468.0, 117.0, 57.0, 26.0], "segmentation": [[468, 124, 518, 117, 525, 138, 468, 143]]}, {"iscrowd": 1, "category_id": 1, "bbox": [301.0, 162.0, 231.0, 39.0], "segmentation": [[301, 181, 532, 162, 530, 182, 301, 201]]}, {"iscrowd": 1, "category_id": 1, "bbox": [296.0, 147.0, 104.0, 27.0], "segmentation": [[296, 157, 396, 147, 400, 165, 300, 174]]}, {"iscrowd": 1, "category_id": 1, "bbox": [420.0, 136.0, 107.0, 27.0], "segmentation": [[420, 151, 526, 136, 527, 154, 421, 163]]}, {"iscrowd": 1, "category_id": 1, "bbox": [616.0, 250.0, 41.0, 35.0], "segmentation": [[617, 251, 657, 250, 656, 282, 616, 285]]}, {"iscrowd": 1, "category_id": 1, "bbox": [695.0, 243.0, 43.0, 35.0], "segmentation": [[695, 246, 738, 243, 738, 276, 698, 278]]}, {"iscrowd": 1, "category_id": 1, "bbox": [739.0, 241.0, 24.0, 21.0], "segmentation": [[739, 241, 760, 241, 763, 260, 742, 262]]}]} +{"file_name": "test/img_4.jpg", "height": 720, "width": 1280, "annotations": [{"iscrowd": 1, "category_id": 1, "bbox": [692.0, 268.0, 18.0, 25.0], "segmentation": [[692, 268, 710, 268, 710, 293, 692, 293]]}, {"iscrowd": 1, "category_id": 1, "bbox": [661.0, 224.0, 76.0, 22.0], "segmentation": [[663, 224, 733, 230, 737, 246, 661, 242]]}, {"iscrowd": 1, "category_id": 1, "bbox": [668.0, 242.0, 69.0, 18.0], "segmentation": [[668, 242, 737, 244, 734, 260, 670, 256]]}]} diff --git a/tests/test_apis/test_model_inference.py b/tests/test_apis/test_model_inference.py new file mode 100644 index 00000000..d94bec53 --- /dev/null +++ b/tests/test_apis/test_model_inference.py @@ -0,0 +1,46 @@ +import os +import shutil +import urllib.request + +import pytest + +from mmdet.apis import init_detector +from mmocr.apis.inference import model_inference + + +def test_model_inference(): + + project_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + print(project_dir) + config_file = os.path.join( + project_dir, + '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py') + checkpoint_file = os.path.join( + project_dir, + '../checkpoints/sar_r31_parallel_decoder_academic-dba3a4a3.pth') + + if not os.path.exists(checkpoint_file): + url = ('https://download.openmmlab.com/mmocr' + '/textrecog/sar/' + 'sar_r31_parallel_decoder_academic-dba3a4a3.pth') + print(f'Downloading {url} ...') + local_filename, _ = urllib.request.urlretrieve(url) + os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True) + shutil.move(local_filename, checkpoint_file) + print(f'Saved as {checkpoint_file}') + else: + print(f'Using existing checkpoint {checkpoint_file}') + + device = 'cpu' + model = init_detector( + config_file, checkpoint=checkpoint_file, device=device) + if model.cfg.data.test['type'] == 'ConcatDataset': + model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][ + 0].pipeline + + img = os.path.join(project_dir, '../demo/demo_text_recog.jpg') + + with pytest.raises(AssertionError): + model_inference(model, 1) + + model_inference(model, img) diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py new file mode 100644 index 00000000..43260d37 --- /dev/null +++ b/tests/test_dataset/test_base_dataset.py @@ -0,0 +1,74 @@ +import os.path as osp +import tempfile + +import numpy as np +import pytest + +from mmocr.datasets.base_dataset import BaseDataset + + +def
_create_dummy_ann_file(ann_file): + ann_info1 = 'sample1.jpg hello' + ann_info2 = 'sample2.jpg world' + + with open(ann_file, 'w') as fw: + for ann_info in [ann_info1, ann_info2]: + fw.write(ann_info + '\n') + + +def _create_dummy_loader(): + loader = dict( + type='HardDiskLoader', + repeat=1, + parser=dict(type='LineStrParser', keys=['file_name', 'text'])) + return loader + + +def test_custom_dataset(): + tmp_dir = tempfile.TemporaryDirectory() + # create dummy data + ann_file = osp.join(tmp_dir.name, 'fake_data.txt') + _create_dummy_ann_file(ann_file) + loader = _create_dummy_loader() + + for mode in [True, False]: + dataset = BaseDataset(ann_file, loader, pipeline=[], test_mode=mode) + + # test len + assert len(dataset) == len(dataset.data_infos) + + # test set group flag + assert np.allclose(dataset.flag, [0, 0]) + + # test prepare_train_img + expect_results = { + 'img_info': { + 'file_name': 'sample1.jpg', + 'text': 'hello' + }, + 'img_prefix': '' + } + assert dataset.prepare_train_img(0) == expect_results + + # test prepare_test_img + assert dataset.prepare_test_img(0) == expect_results + + # test __getitem__ + assert dataset[0] == expect_results + + # test get_next_index + assert dataset._get_next_index(0) == 1 + + # test format_results + expect_results_copy = { + key: value + for key, value in expect_results.items() + } + dataset.format_results(expect_results) + assert expect_results_copy == expect_results + + # test evaluate + with pytest.raises(NotImplementedError): + dataset.evaluate(expect_results) + + tmp_dir.cleanup() diff --git a/tests/test_dataset/test_crop.py b/tests/test_dataset/test_crop.py new file mode 100644 index 00000000..a50ea54a --- /dev/null +++ b/tests/test_dataset/test_crop.py @@ -0,0 +1,96 @@ +import math + +import numpy as np +import pytest + +from mmocr.datasets.pipelines.box_utils import convert_canonical, sort_vertex +from mmocr.datasets.pipelines.crop import box_jitter, crop_img, warp_img + + +def test_order_vertex(): + dummy_points_x = [20, 20, 120, 120] + dummy_points_y = [20, 40, 40, 20] + + with pytest.raises(AssertionError): + sort_vertex([], dummy_points_y) + with pytest.raises(AssertionError): + sort_vertex(dummy_points_x, []) + + ordered_points_x, ordered_points_y = sort_vertex(dummy_points_x, + dummy_points_y) + + expect_points_x = [20, 120, 120, 20] + expect_points_y = [20, 20, 40, 40] + + assert np.allclose(ordered_points_x, expect_points_x) + assert np.allclose(ordered_points_y, expect_points_y) + + +def test_convert_canonical(): + dummy_points_x = [120, 120, 20, 20] + dummy_points_y = [20, 40, 40, 20] + + with pytest.raises(AssertionError): + convert_canonical([], dummy_points_y) + with pytest.raises(AssertionError): + convert_canonical(dummy_points_x, []) + + ordered_points_x, ordered_points_y = convert_canonical( + dummy_points_x, dummy_points_y) + + expect_points_x = [20, 120, 120, 20] + expect_points_y = [20, 20, 40, 40] + + assert np.allclose(ordered_points_x, expect_points_x) + assert np.allclose(ordered_points_y, expect_points_y) + + +def test_box_jitter(): + dummy_points_x = [20, 120, 120, 20] + dummy_points_y = [20, 20, 40, 40] + + kwargs = dict(jitter_ratio_x=0.0, jitter_ratio_y=0.0) + + with pytest.raises(AssertionError): + box_jitter([], dummy_points_y) + with pytest.raises(AssertionError): + box_jitter(dummy_points_x, []) + with pytest.raises(AssertionError): + box_jitter(dummy_points_x, dummy_points_y, jitter_ratio_x=1.) + with pytest.raises(AssertionError): + box_jitter(dummy_points_x, dummy_points_y, jitter_ratio_y=1.)
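+ # both jitter ratios are zero here, so the jitter applied below should leave the box vertices unchanged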
+ + box_jitter(dummy_points_x, dummy_points_y, **kwargs) + + assert np.allclose(dummy_points_x, [20, 120, 120, 20]) + assert np.allclose(dummy_points_y, [20, 20, 40, 40]) + + +def test_opencv_crop(): + dummy_img = np.ones((600, 600, 3), dtype=np.uint8) + dummy_box = [20, 20, 120, 20, 120, 40, 20, 40] + + cropped_img = warp_img(dummy_img, dummy_box) + + with pytest.raises(AssertionError): + warp_img(dummy_img, []) + with pytest.raises(AssertionError): + warp_img(dummy_img, [20, 40, 40, 20]) + + assert math.isclose(cropped_img.shape[0], 20) + assert math.isclose(cropped_img.shape[1], 100) + + +def test_min_rect_crop(): + dummy_img = np.ones((600, 600, 3), dtype=np.uint8) + dummy_box = [20, 20, 120, 20, 120, 40, 20, 40] + + cropped_img = crop_img(dummy_img, dummy_box) + + with pytest.raises(AssertionError): + crop_img(dummy_img, []) + with pytest.raises(AssertionError): + crop_img(dummy_img, [20, 40, 40, 20]) + + assert math.isclose(cropped_img.shape[0], 20) + assert math.isclose(cropped_img.shape[1], 100) diff --git a/tests/test_dataset/test_detect_dataset.py b/tests/test_dataset/test_detect_dataset.py new file mode 100644 index 00000000..83480c45 --- /dev/null +++ b/tests/test_dataset/test_detect_dataset.py @@ -0,0 +1,83 @@ +import json +import os.path as osp +import tempfile + +import numpy as np + +from mmocr.datasets.text_det_dataset import TextDetDataset + + +def _create_dummy_ann_file(ann_file): + ann_info1 = { + 'file_name': + 'sample1.jpg', + 'height': + 640, + 'width': + 640, + 'annotations': [{ + 'iscrowd': 0, + 'category_id': 1, + 'bbox': [50, 70, 80, 100], + 'segmentation': [[50, 70, 80, 70, 80, 100, 50, 100]] + }, { + 'iscrowd': + 1, + 'category_id': + 1, + 'bbox': [120, 140, 200, 200], + 'segmentation': [[120, 140, 200, 140, 200, 200, 120, 200]] + }] + } + + with open(ann_file, 'w') as fw: + fw.write(json.dumps(ann_info1) + '\n') + + +def _create_dummy_loader(): + loader = dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineJsonParser', + keys=['file_name', 'height', 'width', 'annotations'])) + return loader + + +def test_detect_dataset(): + tmp_dir = tempfile.TemporaryDirectory() + # create dummy data + ann_file = osp.join(tmp_dir.name, 'fake_data.txt') + _create_dummy_ann_file(ann_file) + + # test initialization + loader = _create_dummy_loader() + dataset = TextDetDataset(ann_file, loader, pipeline=[]) + + # test _parse_anno_info + img_ann_info = dataset.data_infos[0] + ann = dataset._parse_anno_info(img_ann_info['annotations']) + print(ann['bboxes']) + assert np.allclose(ann['bboxes'], [[50., 70., 80., 100.]]) + assert np.allclose(ann['labels'], [1]) + assert np.allclose(ann['bboxes_ignore'], [[120, 140, 200, 200]]) + assert np.allclose(ann['masks'], [[[50, 70, 80, 70, 80, 100, 50, 100]]]) + assert np.allclose(ann['masks_ignore'], + [[[120, 140, 200, 140, 200, 200, 120, 200]]]) + + tmp_dir.cleanup() + + # test prepare_train_img + pipeline_results = dataset.prepare_train_img(0) + assert np.allclose(pipeline_results['bbox_fields'], []) + assert np.allclose(pipeline_results['mask_fields'], []) + assert np.allclose(pipeline_results['seg_fields'], []) + expect_img_info = {'filename': 'sample1.jpg', 'height': 640, 'width': 640} + assert pipeline_results['img_info'] == expect_img_info + + # test evaluation + metrics = 'hmean-iou' + results = [{'boundary_result': [[50, 70, 80, 70, 80, 100, 50, 100, 1]]}] + eval_res = dataset.evaluate(results, metrics) + + assert eval_res['hmean-iou:hmean'] == 1 diff --git a/tests/test_dataset/test_icdar_dataset.py
b/tests/test_dataset/test_icdar_dataset.py new file mode 100644 index 00000000..7f6a1e7b --- /dev/null +++ b/tests/test_dataset/test_icdar_dataset.py @@ -0,0 +1,155 @@ +import os.path as osp +import tempfile + +import mmcv +import numpy as np + +from mmocr.datasets.icdar_dataset import IcdarDataset + + +def _create_dummy_icdar_json(json_name): + image_1 = { + 'id': 0, + 'width': 640, + 'height': 640, + 'file_name': 'fake_name.jpg', + } + image_2 = { + 'id': 1, + 'width': 640, + 'height': 640, + 'file_name': 'fake_name1.jpg', + } + + annotation_1 = { + 'id': 1, + 'image_id': 0, + 'category_id': 0, + 'area': 400, + 'bbox': [50, 60, 20, 20], + 'iscrowd': 0, + 'segmentation': [[50, 60, 70, 60, 70, 80, 50, 80]] + } + + annotation_2 = { + 'id': 2, + 'image_id': 0, + 'category_id': 0, + 'area': 900, + 'bbox': [100, 120, 30, 30], + 'iscrowd': 0, + 'segmentation': [[100, 120, 130, 120, 120, 150, 100, 150]] + } + + annotation_3 = { + 'id': 3, + 'image_id': 0, + 'category_id': 0, + 'area': 1600, + 'bbox': [150, 160, 40, 40], + 'iscrowd': 1, + 'segmentation': [[150, 160, 190, 160, 190, 200, 150, 200]] + } + + annotation_4 = { + 'id': 4, + 'image_id': 0, + 'category_id': 0, + 'area': 10000, + 'bbox': [250, 260, 100, 100], + 'iscrowd': 1, + 'segmentation': [[250, 260, 350, 260, 350, 360, 250, 360]] + } + annotation_5 = { + 'id': 5, + 'image_id': 1, + 'category_id': 0, + 'area': 10000, + 'bbox': [250, 260, 100, 100], + 'iscrowd': 1, + 'segmentation': [[250, 260, 350, 260, 350, 360, 250, 360]] + } + + categories = [{ + 'id': 0, + 'name': 'text', + 'supercategory': 'text', + }] + + fake_json = { + 'images': [image_1, image_2], + 'annotations': + [annotation_1, annotation_2, annotation_3, annotation_4, annotation_5], + 'categories': + categories + } + + mmcv.dump(fake_json, json_name) + + +def test_icdar_dataset(): + tmp_dir = tempfile.TemporaryDirectory() + # create dummy data + fake_json_file = osp.join(tmp_dir.name, 'fake_data.json') + _create_dummy_icdar_json(fake_json_file) + + # test initialization + dataset = IcdarDataset(ann_file=fake_json_file, pipeline=[]) + assert dataset.CLASSES == ('text') + assert dataset.img_ids == [0, 1] + assert dataset.select_first_k == -1 + + # test _parse_ann_info + ann = dataset.get_ann_info(0) + assert np.allclose(ann['bboxes'], + [[50., 60., 70., 80.], [100., 120., 130., 150.]]) + assert np.allclose(ann['labels'], [0, 0]) + assert np.allclose(ann['bboxes_ignore'], + [[150., 160., 190., 200.], [250., 260., 350., 360.]]) + assert np.allclose(ann['masks'], + [[[50, 60, 70, 60, 70, 80, 50, 80]], + [[100, 120, 130, 120, 120, 150, 100, 150]]]) + assert np.allclose(ann['masks_ignore'], + [[[150, 160, 190, 160, 190, 200, 150, 200]], + [[250, 260, 350, 260, 350, 360, 250, 360]]]) + assert dataset.cat_ids == [0] + + tmp_dir.cleanup() + + # test rank output + # result = [[]] + # out_file = tempfile.NamedTemporaryFile().name + + # with pytest.raises(AssertionError): + # dataset.output_ranklist(result, out_file) + + # result = [{'hmean': 1}, {'hmean': 0.5}] + + # output = dataset.output_ranklist(result, out_file) + + # assert output[0]['hmean'] == 0.5 + + # test get_gt_mask + # output = dataset.get_gt_mask() + # assert np.allclose(output[0][0], + # [[50, 60, 70, 60, 70, 80, 50, 80], + # [100, 120, 130, 120, 120, 150, 100, 150]]) + # assert output[0][1] == [] + # assert np.allclose(output[1][0], + # [[150, 160, 190, 160, 190, 200, 150, 200], + # [250, 260, 350, 260, 350, 360, 250, 360]]) + # assert np.allclose(output[1][1], + # [[250, 260, 350, 260, 350, 360, 250, 360]]) + + # test 
evaluation + metrics = ['hmean-iou', 'hmean-ic13'] + results = [{ + 'boundary_result': [[50, 60, 70, 60, 70, 80, 50, 80, 1], + [100, 120, 130, 120, 120, 150, 100, 150, 1]] + }, { + 'boundary_result': [] + }] + output = dataset.evaluate(results, metrics) + + assert output['hmean-iou:hmean'] == 1 + assert output['hmean-ic13:hmean'] == 1 diff --git a/tests/test_dataset/test_kie_dataset.py b/tests/test_dataset/test_kie_dataset.py new file mode 100644 index 00000000..da74c6ed --- /dev/null +++ b/tests/test_dataset/test_kie_dataset.py @@ -0,0 +1,114 @@ +import json +import math +import os.path as osp +import tempfile + +import pytest +import torch + +from mmocr.datasets.kie_dataset import KIEDataset + + +def _create_dummy_ann_file(ann_file): + ann_info1 = { + 'file_name': + 'sample1.png', + 'height': + 200, + 'width': + 200, + 'annotations': [{ + 'text': 'store', + 'box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0], + 'label': 1 + }, { + 'text': 'address', + 'box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0], + 'label': 1 + }, { + 'text': 'price', + 'box': [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0], + 'label': 1 + }, { + 'text': '1.0', + 'box': [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0], + 'label': 1 + }, { + 'text': 'google', + 'box': [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0], + 'label': 1 + }] + } + with open(ann_file, 'w') as fw: + for ann_info in [ann_info1]: + fw.write(json.dumps(ann_info) + '\n') + + return ann_info1 + + +def _create_dummy_dict_file(dict_file): + dict_str = '0123' + with open(dict_file, 'w') as fw: + for char in list(dict_str): + fw.write(char + '\n') + + return dict_str + + +def _create_dummy_loader(): + loader = dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineJsonParser', + keys=['file_name', 'height', 'width', 'annotations'])) + return loader + + +def test_kie_dataset(): + tmp_dir = tempfile.TemporaryDirectory() + # create dummy data + ann_file = osp.join(tmp_dir.name, 'fake_data.txt') + ann_info1 = _create_dummy_ann_file(ann_file) + + dict_file = osp.join(tmp_dir.name, 'fake_dict.txt') + _create_dummy_dict_file(dict_file) + + # test initialization + loader = _create_dummy_loader() + dataset = KIEDataset(ann_file, loader, dict_file, pipeline=[]) + + tmp_dir.cleanup() + + # test pre_pipeline + img_info = dataset.data_infos[0] + results = dict(img_info=img_info) + dataset.pre_pipeline(results) + assert results['img_prefix'] == dataset.img_prefix + + # test _parse_anno_info + annos = ann_info1['annotations'] + with pytest.raises(AssertionError): + dataset._parse_anno_info(annos[0]) + tmp_annos = [{ + 'text': 'store', + 'box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0] + }] + with pytest.raises(AssertionError): + dataset._parse_anno_info(tmp_annos) + + return_anno = dataset._parse_anno_info(annos) + assert 'bboxes' in return_anno + assert 'relations' in return_anno + assert 'texts' in return_anno + assert 'labels' in return_anno + + # test evaluation + result = {} + result['nodes'] = torch.full((5, 5), 1, dtype=torch.float) + result['nodes'][:, 1] = 100.
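+ # column 1 now holds the largest score in every row, so all five nodes are predicted as class 1; + # with every gt label also equal to 1, only one of the five classes reaches f1 = 1, hence the asserted macro_f1 of 0.2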
+ print('hello', result['nodes'].size()) + results = [result for _ in range(5)] + + eval_res = dataset.evaluate(results) + assert math.isclose(eval_res['macro_f1'], 0.2, abs_tol=1e-4) diff --git a/tests/test_dataset/test_loader.py b/tests/test_dataset/test_loader.py new file mode 100644 index 00000000..17e00447 --- /dev/null +++ b/tests/test_dataset/test_loader.py @@ -0,0 +1,71 @@ +import json +import os.path as osp +import tempfile + +import pytest + +from mmocr.datasets.utils.loader import HardDiskLoader, LmdbLoader, Loader +from mmocr.utils import lmdb_converter + + +def _create_dummy_line_str_file(ann_file): + ann_info1 = 'sample1.jpg hello' + ann_info2 = 'sample2.jpg world' + + with open(ann_file, 'w') as fw: + for ann_info in [ann_info1, ann_info2]: + fw.write(ann_info + '\n') + + +def _create_dummy_line_json_file(ann_file): + ann_info1 = {'filename': 'sample1.jpg', 'text': 'hello'} + ann_info2 = {'filename': 'sample2.jpg', 'text': 'world'} + + with open(ann_file, 'w') as fw: + for ann_info in [ann_info1, ann_info2]: + fw.write(json.dumps(ann_info) + '\n') + + +def test_loader(): + tmp_dir = tempfile.TemporaryDirectory() + # create dummy data + ann_file = osp.join(tmp_dir.name, 'fake_data.txt') + _create_dummy_line_str_file(ann_file) + + parser = dict( + type='LineStrParser', + keys=['filename', 'text'], + keys_idx=[0, 1], + separator=' ') + + with pytest.raises(AssertionError): + Loader(ann_file, parser, repeat=0) + with pytest.raises(AssertionError): + Loader(ann_file, [], repeat=1) + with pytest.raises(AssertionError): + Loader('sample.txt', parser, repeat=1) + with pytest.raises(NotImplementedError): + loader = Loader(ann_file, parser, repeat=1) + print(loader) + + # test text loader and line str parser + text_loader = HardDiskLoader(ann_file, parser, repeat=1) + assert len(text_loader) == 2 + assert text_loader.ori_data_infos[0] == 'sample1.jpg hello' + assert text_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'} + + # test text loader and line json parser + _create_dummy_line_json_file(ann_file) + json_parser = dict(type='LineJsonParser', keys=['filename', 'text']) + text_loader = HardDiskLoader(ann_file, json_parser, repeat=1) + assert text_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'} + + # test lmdb loader and line str parser + _create_dummy_line_str_file(ann_file) + lmdb_file = osp.join(tmp_dir.name, 'fake_data.lmdb') + lmdb_converter(ann_file, lmdb_file) + + lmdb_loader = LmdbLoader(lmdb_file, parser, repeat=1) + assert lmdb_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'} + + tmp_dir.cleanup() diff --git a/tests/test_dataset/test_loading.py b/tests/test_dataset/test_loading.py new file mode 100644 index 00000000..04e593ab --- /dev/null +++ b/tests/test_dataset/test_loading.py @@ -0,0 +1,38 @@ +import numpy as np + +from mmocr.datasets.pipelines import LoadTextAnnotations + + +def _create_dummy_ann(): + results = {} + results['img_info'] = {} + results['img_info']['height'] = 1000 + results['img_info']['width'] = 1000 + results['ann_info'] = {} + results['ann_info']['masks'] = [] + results['mask_fields'] = [] + results['ann_info']['masks_ignore'] = [ + [[499, 94, 531, 94, 531, 124, 499, 124]], + [[3, 156, 81, 155, 78, 181, 0, 182]], + [[11, 223, 59, 221, 59, 234, 11, 236]], + [[500, 156, 551, 156, 550, 165, 499, 165]] + ] + + return results + + +def test_loadtextannotation(): + + results = _create_dummy_ann() + with_bbox = True + with_label = True + with_mask = True + with_seg = False + poly2mask = False + + loader =
LoadTextAnnotations(with_bbox, with_label, with_mask, with_seg, + poly2mask) + output = loader._load_masks(results) + assert len(output['gt_masks_ignore']) == 4 + assert np.allclose(output['gt_masks_ignore'].masks[0], + [[499, 94, 531, 94, 531, 124, 499, 124]]) diff --git a/tests/test_dataset/test_ocr_dataset.py b/tests/test_dataset/test_ocr_dataset.py new file mode 100644 index 00000000..1787db88 --- /dev/null +++ b/tests/test_dataset/test_ocr_dataset.py @@ -0,0 +1,51 @@ +import math +import os.path as osp +import tempfile + +from mmocr.datasets.ocr_dataset import OCRDataset + + +def _create_dummy_ann_file(ann_file): + ann_info1 = 'sample1.jpg hello' + ann_info2 = 'sample2.jpg world' + + with open(ann_file, 'w') as fw: + for ann_info in [ann_info1, ann_info2]: + fw.write(ann_info + '\n') + + +def _create_dummy_loader(): + loader = dict( + type='HardDiskLoader', + repeat=1, + parser=dict(type='LineStrParser', keys=['file_name', 'text'])) + return loader + + +def test_ocr_dataset(): + tmp_dir = tempfile.TemporaryDirectory() + # create dummy data + ann_file = osp.join(tmp_dir.name, 'fake_data.txt') + _create_dummy_ann_file(ann_file) + + # test initialization + loader = _create_dummy_loader() + dataset = OCRDataset(ann_file, loader, pipeline=[]) + + tmp_dir.cleanup() + + # test pre_pipeline + img_info = dataset.data_infos[0] + results = dict(img_info=img_info) + dataset.pre_pipeline(results) + assert results['img_prefix'] == dataset.img_prefix + assert results['text'] == img_info['text'] + + # test evaluation + metric = 'acc' + results = [{'text': 'hello'}, {'text': 'worl'}] + eval_res = dataset.evaluate(results, metric) + + assert math.isclose(eval_res['word_acc'], 0.5, abs_tol=1e-4) + assert math.isclose(eval_res['char_precision'], 1.0, abs_tol=1e-4) + assert math.isclose(eval_res['char_recall'], 0.9, abs_tol=1e-4) diff --git a/tests/test_dataset/test_ocr_seg_dataset.py b/tests/test_dataset/test_ocr_seg_dataset.py new file mode 100644 index 00000000..0ecfcfdf --- /dev/null +++ b/tests/test_dataset/test_ocr_seg_dataset.py @@ -0,0 +1,127 @@ +import json +import math +import os.path as osp +import tempfile + +import pytest + +from mmocr.datasets.ocr_seg_dataset import OCRSegDataset + + +def _create_dummy_ann_file(ann_file): + ann_info1 = { + 'file_name': + 'sample1.png', + 'annotations': [{ + 'char_text': + 'F', + 'char_box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0] + }, { + 'char_text': + 'r', + 'char_box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0] + }, { + 'char_text': + 'o', + 'char_box': [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0] + }, { + 'char_text': + 'm', + 'char_box': [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0] + }, { + 'char_text': + ':', + 'char_box': [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0] + }], + 'text': + 'From:' + } + ann_info2 = { + 'file_name': + 'sample2.png', + 'annotations': [{ + 'char_text': 'o', + 'char_box': [0.0, 5.0, 7.0, 5.0, 9.0, 15.0, 2.0, 15.0] + }, { + 'char_text': + 'u', + 'char_box': [7.0, 4.0, 14.0, 4.0, 18.0, 18.0, 11.0, 18.0] + }, { + 'char_text': + 't', + 'char_box': [13.0, 1.0, 19.0, 2.0, 24.0, 18.0, 17.0, 18.0] + }], + 'text': + 'out' + } + + with open(ann_file, 'w') as fw: + for ann_info in [ann_info1, ann_info2]: + fw.write(json.dumps(ann_info) + '\n') + + return ann_info1, ann_info2 + + +def _create_dummy_loader(): + loader = dict( + type='HardDiskLoader', + repeat=1, + parser=dict( + type='LineJsonParser', keys=['file_name', 'text', 'annotations'])) + return loader + + +def test_ocr_seg_dataset(): + tmp_dir =
tempfile.TemporaryDirectory() + # create dummy data + ann_file = osp.join(tmp_dir.name, 'fake_data.txt') + ann_info1, ann_info2 = _create_dummy_ann_file(ann_file) + + # test initialization + loader = _create_dummy_loader() + dataset = OCRSegDataset(ann_file, loader, pipeline=[]) + + tmp_dir.cleanup() + + # test pre_pipeline + img_info = dataset.data_infos[0] + results = dict(img_info=img_info) + dataset.pre_pipeline(results) + assert results['img_prefix'] == dataset.img_prefix + + # test _parse_anno_info + annos = ann_info1['annotations'] + with pytest.raises(AssertionError): + dataset._parse_anno_info(annos[0]) + annos2 = ann_info2['annotations'] + with pytest.raises(AssertionError): + dataset._parse_anno_info([{'char_text': 'i'}]) + with pytest.raises(AssertionError): + dataset._parse_anno_info([{'char_box': [1, 2, 3, 4, 5, 6, 7, 8]}]) + annos2[0]['char_box'] = [1, 2, 3] + with pytest.raises(AssertionError): + dataset._parse_anno_info(annos2) + + return_anno = dataset._parse_anno_info(annos) + assert return_anno['chars'] == ['F', 'r', 'o', 'm', ':'] + assert len(return_anno['char_rects']) == 5 + + # test prepare_train_img + expect_results = { + 'img_info': { + 'filename': 'sample1.png' + }, + 'img_prefix': '', + 'ann_info': return_anno + } + data = dataset.prepare_train_img(0) + assert data == expect_results + + # test evaluation + metric = 'acc' + results = [{'text': 'From:'}, {'text': 'ou'}] + eval_res = dataset.evaluate(results, metric) + + assert math.isclose(eval_res['word_acc'], 0.5, abs_tol=1e-4) + assert math.isclose(eval_res['char_precision'], 1.0, abs_tol=1e-4) + assert math.isclose(eval_res['char_recall'], 0.857, abs_tol=1e-4) diff --git a/tests/test_dataset/test_ocr_seg_target.py b/tests/test_dataset/test_ocr_seg_target.py new file mode 100644 index 00000000..45b85352 --- /dev/null +++ b/tests/test_dataset/test_ocr_seg_target.py @@ -0,0 +1,93 @@ +import os.path as osp +import tempfile + +import numpy as np +import pytest + +from mmocr.datasets.pipelines.ocr_seg_targets import OCRSegTargets + + +def _create_dummy_dict_file(dict_file): + chars = list('0123456789') + with open(dict_file, 'w') as fw: + for char in chars: + fw.write(char + '\n') + + +def test_ocr_segm_targets(): + tmp_dir = tempfile.TemporaryDirectory() + # create dummy dict file + dict_file = osp.join(tmp_dir.name, 'fake_chars.txt') + _create_dummy_dict_file(dict_file) + # dummy label convertor + label_convertor = dict( + type='SegConvertor', + dict_file=dict_file, + with_unknown=True, + lower=True) + # test init + with pytest.raises(AssertionError): + OCRSegTargets(None, 0.5, 0.5) + with pytest.raises(AssertionError): + OCRSegTargets(label_convertor, '1by2', 0.5) + with pytest.raises(AssertionError): + OCRSegTargets(label_convertor, 0.5, 2) + + ocr_seg_tgt = OCRSegTargets(label_convertor, 0.5, 0.5) + # test generate kernels + img_size = (8, 8) + pad_size = (8, 10) + char_boxes = [[2, 2, 6, 6]] + char_idxs = [2] + + with pytest.raises(AssertionError): + ocr_seg_tgt.generate_kernels(8, pad_size, char_boxes, char_idxs, 0.5, + True) + with pytest.raises(AssertionError): + ocr_seg_tgt.generate_kernels(img_size, pad_size, [2, 2, 6, 6], + char_idxs, 0.5, True) + with pytest.raises(AssertionError): + ocr_seg_tgt.generate_kernels(img_size, pad_size, char_boxes, 2, 0.5, + True) + + attn_tgt = ocr_seg_tgt.generate_kernels( + img_size, pad_size, char_boxes, char_idxs, 0.5, binary=True) + expect_attn_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255], + [0, 0, 0, 0, 0, 0, 0, 0, 255, 255], + [0, 0, 0, 0, 0, 0, 0, 0, 255, 255], + [0, 0,
0, 1, 1, 1, 0, 0, 255, 255], + [0, 0, 0, 1, 1, 1, 0, 0, 255, 255], + [0, 0, 0, 1, 1, 1, 0, 0, 255, 255], + [0, 0, 0, 0, 0, 0, 0, 0, 255, 255], + [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]] + assert np.allclose(attn_tgt, np.array(expect_attn_tgt, dtype=np.int32)) + + segm_tgt = ocr_seg_tgt.generate_kernels( + img_size, pad_size, char_boxes, char_idxs, 0.5, binary=False) + expect_segm_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255], + [0, 0, 0, 0, 0, 0, 0, 0, 255, 255], + [0, 0, 0, 0, 0, 0, 0, 0, 255, 255], + [0, 0, 0, 2, 2, 2, 0, 0, 255, 255], + [0, 0, 0, 2, 2, 2, 0, 0, 255, 255], + [0, 0, 0, 2, 2, 2, 0, 0, 255, 255], + [0, 0, 0, 0, 0, 0, 0, 0, 255, 255], + [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]] + assert np.allclose(segm_tgt, np.array(expect_segm_tgt, dtype=np.int32)) + + # test __call__ + results = {} + results['img_shape'] = (4, 4, 3) + results['resize_shape'] = (8, 8, 3) + results['pad_shape'] = (8, 10) + results['ann_info'] = {} + results['ann_info']['char_rects'] = [[1, 1, 3, 3]] + results['ann_info']['chars'] = ['1'] + + results = ocr_seg_tgt(results) + assert results['mask_fields'] == ['gt_kernels'] + assert np.allclose(results['gt_kernels'].masks[0], + np.array(expect_attn_tgt, dtype=np.int32)) + assert np.allclose(results['gt_kernels'].masks[1], + np.array(expect_segm_tgt, dtype=np.int32)) + + tmp_dir.cleanup() diff --git a/tests/test_dataset/test_ocr_transforms.py b/tests/test_dataset/test_ocr_transforms.py new file mode 100644 index 00000000..15522b25 --- /dev/null +++ b/tests/test_dataset/test_ocr_transforms.py @@ -0,0 +1,140 @@ +import math +import unittest.mock as mock + +import numpy as np +import torch +import torchvision.transforms.functional as TF +from PIL import Image + +import mmocr.datasets.pipelines.ocr_transforms as transforms + + +def test_resize_ocr(): + input_img = np.ones((64, 256, 3), dtype=np.uint8) + + rci = transforms.ResizeOCR( + 32, min_width=32, max_width=160, keep_aspect_ratio=True) + results = {'img_shape': input_img.shape, 'img': input_img} + + # test call + results = rci(results) + assert np.allclose([32, 160, 3], results['pad_shape']) + assert np.allclose([32, 160, 3], results['img'].shape) + assert 'valid_ratio' in results + assert math.isclose(results['valid_ratio'], 0.8) + assert math.isclose(np.sum(results['img'][:, 129:, :]), 0) + + rci = transforms.ResizeOCR( + 32, min_width=32, max_width=160, keep_aspect_ratio=False) + results = {'img_shape': input_img.shape, 'img': input_img} + results = rci(results) + assert math.isclose(results['valid_ratio'], 1) + + +def test_to_tensor(): + input_img = np.ones((64, 256, 3), dtype=np.uint8) + + expect_output = TF.to_tensor(input_img) + rci = transforms.ToTensorOCR() + + results = {'img': input_img} + results = rci(results) + + assert np.allclose(results['img'].numpy(), expect_output.numpy()) + + +def test_normalize(): + inputs = torch.zeros(3, 10, 10) + + expect_output = torch.ones_like(inputs) * (-1) + rci = transforms.NormalizeOCR(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + results = {'img': inputs} + results = rci(results) + + assert np.allclose(results['img'].numpy(), expect_output.numpy()) + + +@mock.patch('%s.transforms.np.random.random' % __name__) +def test_online_crop(mock_random): + kwargs = dict( + box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'], + jitter_prob=0.5, + max_jitter_ratio_x=0.05, + max_jitter_ratio_y=0.02) + + mock_random.side_effect = [0.1, 1, 1, 1] + + src_img = np.ones((100, 100, 3), dtype=np.uint8) + results = { + 'img': src_img, + 'img_info': { + 'x1': '20', + 'y1': '20', + 'x2': 
'40', + 'y2': '20', + 'x3': '40', + 'y3': '40', + 'x4': '20', + 'y4': '40' + } + } + + rci = transforms.OnlineCropOCR(**kwargs) + + results = rci(results) + + assert np.allclose(results['img_shape'], [20, 20, 3]) + + # test not crop + mock_random.side_effect = [0.1, 1, 1, 1] + results['img_info'] = {} + results['img'] = src_img + + results = rci(results) + assert np.allclose(results['img'].shape, [100, 100, 3]) + + +def test_fancy_pca(): + input_tensor = torch.rand(3, 32, 100) + + rci = transforms.FancyPCA() + + results = {'img': input_tensor} + results = rci(results) + + assert results['img'].shape == torch.Size([3, 32, 100]) + + +@mock.patch('%s.transforms.np.random.uniform' % __name__) +def test_random_padding(mock_random): + kwargs = dict(max_ratio=[0.0, 0.0, 0.0, 0.0], box_type=None) + + mock_random.side_effect = [1, 1, 1, 1] + + src_img = np.ones((32, 100, 3), dtype=np.uint8) + results = {'img': src_img, 'img_shape': (32, 100, 3)} + + rci = transforms.RandomPaddingOCR(**kwargs) + + results = rci(results) + print(results['img'].shape) + assert np.allclose(results['img_shape'], [96, 300, 3]) + + +def test_opencv2pil(): + src_img = np.ones((32, 100, 3), dtype=np.uint8) + results = {'img': src_img} + rci = transforms.OpencvToPil() + + results = rci(results) + assert np.allclose(results['img'].size, (100, 32)) + + +def test_pil2opencv(): + src_img = Image.new('RGB', (100, 32), color=(255, 255, 255)) + results = {'img': src_img} + rci = transforms.PilToOpencv() + + results = rci(results) + assert np.allclose(results['img'].shape, (32, 100, 3)) diff --git a/tests/test_dataset/test_parser.py b/tests/test_dataset/test_parser.py new file mode 100644 index 00000000..da38565d --- /dev/null +++ b/tests/test_dataset/test_parser.py @@ -0,0 +1,63 @@ +import json + +import pytest + +from mmocr.datasets.utils.parser import LineJsonParser, LineStrParser + + +def test_line_str_parser(): + data_ret = ['sample1.jpg hello', 'sample2.jpg world'] + keys = ['filename', 'text'] + keys_idx = [0, 1] + separator = ' ' + + # test init + with pytest.raises(AssertionError): + parser = LineStrParser('filename', keys_idx, separator) + with pytest.raises(AssertionError): + parser = LineStrParser(keys, keys_idx, [' ']) + with pytest.raises(AssertionError): + parser = LineStrParser(keys, [0], separator) + + # test get_item + parser = LineStrParser(keys, keys_idx, separator) + assert parser.get_item(data_ret, 0) == { + 'filename': 'sample1.jpg', + 'text': 'hello' + } + + with pytest.raises(Exception): + parser = LineStrParser(['filename', 'text', 'ignore'], [0, 1, 2], + separator) + parser.get_item(data_ret, 0) + + +def test_line_dict_parser(): + data_ret = [ + json.dumps({ + 'filename': 'sample1.jpg', + 'text': 'hello' + }), + json.dumps({ + 'filename': 'sample2.jpg', + 'text': 'world' + }) + ] + keys = ['filename', 'text'] + + # test init + with pytest.raises(AssertionError): + parser = LineJsonParser('filename') + with pytest.raises(AssertionError): + parser = LineJsonParser([]) + + # test get_item + parser = LineJsonParser(keys) + assert parser.get_item(data_ret, 0) == { + 'filename': 'sample1.jpg', + 'text': 'hello' + } + + with pytest.raises(Exception): + parser = LineJsonParser(['img_name', 'text']) + parser.get_item(data_ret, 0) diff --git a/tests/test_dataset/test_test_time_aug.py b/tests/test_dataset/test_test_time_aug.py new file mode 100644 index 00000000..22bf80c6 --- /dev/null +++ b/tests/test_dataset/test_test_time_aug.py @@ -0,0 +1,33 @@ +import numpy as np +import pytest + +from 
mmocr.datasets.pipelines.test_time_aug import MultiRotateAugOCR + + +def test_multi_rotate_aug_ocr(): + input_img1 = np.ones((64, 256, 3), dtype=np.uint8) + input_img2 = np.ones((64, 32, 3), dtype=np.uint8) + + rci = MultiRotateAugOCR(transforms=[], rotate_degrees=[0, 90, 270]) + + # test invalid arguments + with pytest.raises(AssertionError): + MultiRotateAugOCR(transforms=[], rotate_degrees=[45]) + with pytest.raises(AssertionError): + MultiRotateAugOCR(transforms=[], rotate_degrees=[20.5]) + + # test call with input_img1 + results = {'img_shape': input_img1.shape, 'img': input_img1} + results = rci(results) + assert np.allclose([64, 256, 3], results['img_shape']) + assert len(results['img']) == 1 + assert len(results['img_shape']) == 1 + assert np.allclose([64, 256, 3], results['img_shape'][0]) + + # test call with input_img2 + results = {'img_shape': input_img2.shape, 'img': input_img2} + results = rci(results) + assert np.allclose([64, 32, 3], results['img_shape']) + assert len(results['img']) == 3 + assert len(results['img_shape']) == 3 + assert np.allclose([64, 32, 3], results['img_shape'][0]) diff --git a/tests/test_dataset/test_textdet_targets.py b/tests/test_dataset/test_textdet_targets.py new file mode 100644 index 00000000..e06353a6 --- /dev/null +++ b/tests/test_dataset/test_textdet_targets.py @@ -0,0 +1,212 @@ +from unittest import mock + +import numpy as np + +import mmocr.datasets.pipelines.custom_format_bundle as cf_bundle +import mmocr.datasets.pipelines.textdet_targets as textdet_targets +from mmdet.core import PolygonMasks + + +@mock.patch('%s.cf_bundle.show_feature' % __name__) +def test_gen_pannet_targets(mock_show_feature): + + target_generator = textdet_targets.PANetTargets() + assert target_generator.max_shrink == 20 + + # test generate_kernels + img_size = (3, 10) + text_polys = [[np.array([0, 0, 1, 0, 1, 1, 0, 1])], + [np.array([2, 0, 3, 0, 3, 1, 2, 1])]] + shrink_ratio = 1.0 + kernel = np.array([[1, 1, 2, 2, 0, 0, 0, 0, 0, 0], + [1, 1, 2, 2, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) + output, _ = target_generator.generate_kernels(img_size, text_polys, + shrink_ratio) + print(output) + assert np.allclose(output, kernel) + + # test generate_effective_mask + polys_ignore = text_polys + output = target_generator.generate_effective_mask((3, 10), polys_ignore) + target = np.array([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) + + assert np.allclose(output, target) + + # test generate_targets + results = {} + results['img'] = np.zeros((3, 10, 3), np.uint8) + results['gt_masks'] = PolygonMasks(text_polys, 3, 10) + results['gt_masks_ignore'] = PolygonMasks([], 3, 10) + results['img_shape'] = (3, 10, 3) + results['mask_fields'] = [] + output = target_generator(results) + assert len(output['gt_kernels']) == 2 + assert len(output['gt_mask']) == 1 + + bundle = cf_bundle.CustomFormatBundle( + keys=['gt_kernels', 'gt_mask'], + visualize=dict(flag=True, boundary_key='gt_kernels')) + bundle(output) + assert 'gt_kernels' in output.keys() + assert 'gt_mask' in output.keys() + mock_show_feature.assert_called_once() + + +def test_gen_psenet_targets(): + target_generator = textdet_targets.PSENetTargets() + assert target_generator.max_shrink == 20 + assert target_generator.shrink_ratio == (1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4) + + +# Test DBNetTargets + + +def test_dbnet_targets_find_invalid(): + target_generator = textdet_targets.DBNetTargets() + assert target_generator.shrink_ratio == 0.4 + assert target_generator.thr_min == 0.3
+ assert target_generator.thr_max == 0.7 + + results = {} + text_polys = [[np.array([0, 0, 10, 0, 10, 10, 0, 10])], + [np.array([20, 0, 30, 0, 30, 10, 20, 10])]] + results['gt_masks'] = PolygonMasks(text_polys, 40, 40) + + ignore_tags = target_generator.find_invalid(results) + assert np.allclose(ignore_tags, [False, False]) + + +def test_dbnet_targets(): + target_generator = textdet_targets.DBNetTargets() + assert target_generator.shrink_ratio == 0.4 + assert target_generator.thr_min == 0.3 + assert target_generator.thr_max == 0.7 + + +def test_dbnet_ignore_texts(): + target_generator = textdet_targets.DBNetTargets() + ignore_tags = [True, False] + results = {} + text_polys = [[np.array([0, 0, 10, 0, 10, 10, 0, 10])], + [np.array([20, 0, 30, 0, 30, 10, 20, 10])]] + text_polys_ignore = [[np.array([0, 0, 15, 0, 15, 10, 0, 10])]] + + results['gt_masks_ignore'] = PolygonMasks(text_polys_ignore, 40, 40) + results['gt_masks'] = PolygonMasks(text_polys, 40, 40) + results['gt_bboxes'] = np.array([[0, 0, 10, 10], [20, 0, 30, 10]]) + results['gt_labels'] = np.array([0, 1]) + + target_generator.ignore_texts(results, ignore_tags) + + assert np.allclose(results['gt_labels'], np.array([1])) + assert len(results['gt_masks_ignore'].masks) == 2 + assert np.allclose(results['gt_masks_ignore'].masks[1][0], + text_polys[0][0]) + assert len(results['gt_masks'].masks) == 1 + + +def test_dbnet_generate_thr_map(): + target_generator = textdet_targets.DBNetTargets() + text_polys = [[np.array([0, 0, 10, 0, 10, 10, 0, 10])], + [np.array([20, 0, 30, 0, 30, 10, 20, 10])]] + thr_map, thr_mask = target_generator.generate_thr_map((40, 40), text_polys) + assert np.all((thr_map >= 0.29) * (thr_map <= 0.71)) + + +def test_dbnet_generate_targets(): + target_generator = textdet_targets.DBNetTargets() + text_polys = [[np.array([0, 0, 10, 0, 10, 10, 0, 10])], + [np.array([20, 0, 30, 0, 30, 10, 20, 10])]] + text_polys_ignore = [[np.array([0, 0, 15, 0, 15, 10, 0, 10])]] + + results = {} + results['mask_fields'] = [] + results['img_shape'] = (40, 40, 3) + results['gt_masks_ignore'] = PolygonMasks(text_polys_ignore, 40, 40) + results['gt_masks'] = PolygonMasks(text_polys, 40, 40) + results['gt_bboxes'] = np.array([[0, 0, 10, 10], [20, 0, 30, 10]]) + results['gt_labels'] = np.array([0, 1]) + + target_generator.generate_targets(results) + assert 'gt_shrink' in results['mask_fields'] + assert 'gt_shrink_mask' in results['mask_fields'] + assert 'gt_thr' in results['mask_fields'] + assert 'gt_thr_mask' in results['mask_fields'] + + +@mock.patch('%s.cf_bundle.show_feature' % __name__) +def test_gen_textsnake_targets(mock_show_feature): + + target_generator = textdet_targets.TextSnakeTargets() + assert np.allclose(target_generator.orientation_thr, 2.0) + assert np.allclose(target_generator.resample_step, 4.0) + assert np.allclose(target_generator.center_region_shrink_ratio, 0.3) + + # test find_head_tail + polygon = np.array([[1.0, 1.0], [5.0, 1.0], [5.0, 3.0], [1.0, 3.0]]) + head_inds, tail_inds = target_generator.find_head_tail(polygon, 2.0) + assert np.allclose(head_inds, [3, 0]) + assert np.allclose(tail_inds, [1, 2]) + + # test generate_text_region_mask + img_size = (3, 10) + text_polys = [[np.array([0, 0, 1, 0, 1, 1, 0, 1])], + [np.array([2, 0, 3, 0, 3, 1, 2, 1])]] + output = target_generator.generate_text_region_mask(img_size, text_polys) + target = np.array([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) + assert np.allclose(output, target) + + # test 
generate_center_region_mask + target_generator.center_region_shrink_ratio = 1.0 + (center_region_mask, radius_map, sin_map, + cos_map) = target_generator.generate_center_mask_attrib_maps( + img_size, text_polys) + target = np.array([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) + assert np.allclose(center_region_mask, target) + assert np.allclose(sin_map, np.zeros(img_size)) + assert np.allclose(cos_map, target) + + # test generate_effective_mask + polys_ignore = text_polys + output = target_generator.generate_effective_mask(img_size, polys_ignore) + target = np.array([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) + assert np.allclose(output, target) + + # test generate_targets + results = {} + results['img'] = np.zeros((3, 10, 3), np.uint8) + results['gt_masks'] = PolygonMasks(text_polys, 3, 10) + results['gt_masks_ignore'] = PolygonMasks([], 3, 10) + results['img_shape'] = (3, 10, 3) + results['mask_fields'] = [] + output = target_generator(results) + assert len(output['gt_text_mask']) == 1 + assert len(output['gt_center_region_mask']) == 1 + assert len(output['gt_mask']) == 1 + assert len(output['gt_radius_map']) == 1 + assert len(output['gt_sin_map']) == 1 + assert len(output['gt_cos_map']) == 1 + + bundle = cf_bundle.CustomFormatBundle( + keys=[ + 'gt_text_mask', 'gt_center_region_mask', 'gt_mask', + 'gt_radius_map', 'gt_sin_map', 'gt_cos_map' + ], + visualize=dict(flag=True, boundary_key='gt_text_mask')) + bundle(output) + assert 'gt_text_mask' in output.keys() + assert 'gt_center_region_mask' in output.keys() + assert 'gt_mask' in output.keys() + assert 'gt_radius_map' in output.keys() + assert 'gt_sin_map' in output.keys() + assert 'gt_cos_map' in output.keys() + mock_show_feature.assert_called_once() diff --git a/tests/test_dataset/test_transforms.py b/tests/test_dataset/test_transforms.py new file mode 100644 index 00000000..a50b4d0a --- /dev/null +++ b/tests/test_dataset/test_transforms.py @@ -0,0 +1,166 @@ +import unittest.mock as mock + +import numpy as np +import torchvision.transforms as TF +from PIL import Image + +import mmocr.datasets.pipelines.transforms as transforms +from mmdet.core import BitmapMasks + + +@mock.patch('%s.transforms.np.random.random_sample' % __name__) +@mock.patch('%s.transforms.np.random.randint' % __name__) +def test_random_crop_instances(mock_randint, mock_sample): + + img_gt = np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 1, 1], + [0, 0, 1, 1, 1], [0, 0, 1, 1, 1]]) + # test target is bigger than img size in sample_offset + mock_sample.side_effect = [1] + rci = transforms.RandomCropInstances(6, instance_key='gt_kernels') + (i, j) = rci.sample_offset(img_gt, (5, 5)) + assert i == 0 + assert j == 0 + + # test the second branch in sample_offset + + rci = transforms.RandomCropInstances(3, instance_key='gt_kernels') + mock_sample.side_effect = [1] + mock_randint.side_effect = [1, 2] + (i, j) = rci.sample_offset(img_gt, (5, 5)) + assert i == 1 + assert j == 2 + + mock_sample.side_effect = [1] + mock_randint.side_effect = [1, 2] + rci = transforms.RandomCropInstances(5, instance_key='gt_kernels') + (i, j) = rci.sample_offset(img_gt, (5, 5)) + assert i == 0 + assert j == 0 + + # test the first branch in sample_offset + + rci = transforms.RandomCropInstances(3, instance_key='gt_kernels') + mock_sample.side_effect = [0.1] + mock_randint.side_effect = [1, 1] + (i, j) = rci.sample_offset(img_gt, (5, 5)) + assert i == 1 + assert j == 1 + + # 
test crop_img(img, offset, target_size) + + img = img_gt + offset = [0, 0] + target = [6, 6] + crop = rci.crop_img(img, offset, target) + assert np.allclose(img, crop[0]) + assert np.allclose(crop[1], [0, 0, 5, 5]) + + target = [3, 2] + crop = rci.crop_img(img, offset, target) + assert np.allclose(np.array([[0, 0], [0, 0], [0, 0]]), crop[0]) + assert np.allclose(crop[1], [0, 0, 2, 3]) + + # test __call__ + rci = transforms.RandomCropInstances(3, instance_key='gt_kernels') + results = {} + gt_kernels = [img_gt, img_gt.copy()] + results['gt_kernels'] = BitmapMasks(gt_kernels, 5, 5) + results['img'] = img_gt.copy() + results['mask_fields'] = ['gt_kernels'] + mock_sample.side_effect = [0.1] + mock_randint.side_effect = [1, 1] + output = rci(results) + print(output['img']) + target = np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]]) + assert output['img_shape'] == (3, 3) + + assert np.allclose(output['img'], target) + + assert np.allclose(output['gt_kernels'].masks[0], target) + assert np.allclose(output['gt_kernels'].masks[1], target) + + +@mock.patch('%s.transforms.np.random.random_sample' % __name__) +def test_scale_aspect_jitter(mock_random): + img_scale = [(3000, 1000)] # unused + ratio_range = (0.5, 1.5) + aspect_ratio_range = (1, 1) + multiscale_mode = 'value' + long_size_bound = 2000 + short_size_bound = 640 + resize_type = 'long_short_bound' + keep_ratio = False + jitter = transforms.ScaleAspectJitter( + img_scale=img_scale, + ratio_range=ratio_range, + aspect_ratio_range=aspect_ratio_range, + multiscale_mode=multiscale_mode, + long_size_bound=long_size_bound, + short_size_bound=short_size_bound, + resize_type=resize_type, + keep_ratio=keep_ratio) + mock_random.side_effect = [0.5] + + # test sample_from_range + + result = jitter.sample_from_range([100, 200]) + assert result == 150 + + # test _random_scale + results = {} + results['img'] = np.zeros((4000, 1000)) + mock_random.side_effect = [0.5, 1] + jitter._random_scale(results) + # scale1 = 0.5 and scale2 = 1 give scale = 0.5; the short-side bound then raises it to 650/1000, and 'scale' is stored as (w, h) + # print(results['scale']) + assert results['scale'] == (650, 2600) + + +@mock.patch('%s.transforms.np.random.random_sample' % __name__) +def test_random_rotate(mock_random): + + mock_random.side_effect = [0.5, 0] + results = {} + img = np.random.rand(5, 5) + results['img'] = img.copy() + results['mask_fields'] = ['masks'] + gt_kernels = [results['img'].copy()] + results['masks'] = BitmapMasks(gt_kernels, 5, 5) + + rotater = transforms.RandomRotateTextDet() + + results = rotater(results) + assert np.allclose(results['img'], img) + assert np.allclose(results['masks'].masks, img) + + +def test_color_jitter(): + img = np.ones((64, 256, 3), dtype=np.uint8) + results = {'img': img} + + pt_official_color_jitter = TF.ColorJitter() + output1 = pt_official_color_jitter(img) + + color_jitter = transforms.ColorJitter() + output2 = color_jitter(results) + + assert np.allclose(output1, output2['img']) + + +def test_affine_jitter(): + img = np.ones((64, 256, 3), dtype=np.uint8) + results = {'img': img} + + pt_official_affine_jitter = TF.RandomAffine(degrees=0) + output1 = pt_official_affine_jitter(Image.fromarray(img)) + + affine_jitter = transforms.AffineJitter( + degrees=0, + translate=None, + scale=None, + shear=None, + resample=False, + fillcolor=0) + output2 = affine_jitter(results) + + assert np.allclose(np.array(output1), output2['img']) diff --git a/tests/test_metrics/test_eval_utils.py b/tests/test_metrics/test_eval_utils.py new file mode 100644 index 00000000..a5ca2a35 --- /dev/null +++
b/tests/test_metrics/test_eval_utils.py @@ -0,0 +1,389 @@ +"""Tests the utils of evaluation.""" +import numpy as np +import pytest + +import mmocr.core.evaluation.utils as utils + + +def test_ignore_pred(): + + # test invalid arguments + box = [0, 0, 1, 0, 1, 1, 0, 1] + det_boxes = [box] + gt_dont_care_index = [0] + gt_polys = [utils.points2polygon(box)] + precision_thr = 0.5 + + with pytest.raises(AssertionError): + det_boxes_tmp = 1 + utils.ignore_pred(det_boxes_tmp, gt_dont_care_index, gt_polys, + precision_thr) + with pytest.raises(AssertionError): + gt_dont_care_index_tmp = 1 + utils.ignore_pred(det_boxes, gt_dont_care_index_tmp, gt_polys, + precision_thr) + with pytest.raises(AssertionError): + gt_polys_tmp = 1 + utils.ignore_pred(det_boxes, gt_dont_care_index, gt_polys_tmp, + precision_thr) + with pytest.raises(AssertionError): + precision_thr_tmp = 1.1 + utils.ignore_pred(det_boxes, gt_dont_care_index, gt_polys, + precision_thr_tmp) + + # test ignored cases + result = utils.ignore_pred(det_boxes, gt_dont_care_index, gt_polys, + precision_thr) + assert result[2] == [0] + # test unignored cases + gt_dont_care_index_tmp = [] + result = utils.ignore_pred(det_boxes, gt_dont_care_index_tmp, gt_polys, + precision_thr) + assert result[2] == [] + + det_boxes_tmp = [[10, 10, 15, 10, 15, 15, 10, 15]] + result = utils.ignore_pred(det_boxes_tmp, gt_dont_care_index, gt_polys, + precision_thr) + assert result[2] == [] + + +def test_compute_hmean(): + + # test invalid arguments + with pytest.raises(AssertionError): + utils.compute_hmean(0, 0, 0.0, 0) + with pytest.raises(AssertionError): + utils.compute_hmean(0, 0, 0, 0.0) + with pytest.raises(AssertionError): + utils.compute_hmean([1], 0, 0, 0) + with pytest.raises(AssertionError): + utils.compute_hmean(0, [1], 0, 0) + + _, _, hmean = utils.compute_hmean(2, 2, 2, 2) + assert hmean == 1 + + _, _, hmean = utils.compute_hmean(0, 0, 2, 2) + assert hmean == 0 + + +def test_points2polygon(): + + # test unsupported type + with pytest.raises(AssertionError): + points = 2 + utils.points2polygon(points) + + # test unsupported size + with pytest.raises(AssertionError): + points = [1, 2, 3, 4, 5, 6, 7] + utils.points2polygon(points) + with pytest.raises(AssertionError): + points = [1, 2, 3, 4, 5, 6] + utils.points2polygon(points) + + # test np.array + points = np.array([1, 2, 3, 4, 5, 6, 7, 8]) + poly = utils.points2polygon(points) + assert poly.nPoints() == 4 + + points = [1, 2, 3, 4, 5, 6, 7, 8] + poly = utils.points2polygon(points) + assert poly.nPoints() == 4 + + +def test_poly_intersection(): + + # test unsupported type + with pytest.raises(AssertionError): + utils.poly_intersection(0, 1) + + # test non-overlapping polygons + + points = [0, 0, 0, 1, 1, 1, 1, 0] + points1 = [10, 20, 30, 40, 50, 60, 70, 80] + poly = utils.points2polygon(points) + poly1 = utils.points2polygon(points1) + + area_inters, _ = utils.poly_intersection(poly, poly1) + + assert area_inters == 0 + + # test overlapping polygons + area_inters, _ = utils.poly_intersection(poly, poly) + assert area_inters == 1 + + +def test_poly_union(): + + # test unsupported type + with pytest.raises(AssertionError): + utils.poly_union(0, 1) + + # test non-overlapping polygons + + points = [0, 0, 0, 1, 1, 1, 1, 0] + points1 = [2, 2, 2, 3, 3, 3, 3, 2] + poly = utils.points2polygon(points) + poly1 = utils.points2polygon(points1) + + assert utils.poly_union(poly, poly1) == 2 + + # test overlapping polygons + assert utils.poly_union(poly, poly) == 1 + + +def test_poly_iou(): + + # test unsupported type + 
with pytest.raises(AssertionError): + utils.poly_iou([1], [2]) + + points = [0, 0, 0, 1, 1, 1, 1, 0] + points1 = [10, 20, 30, 40, 50, 60, 70, 80] + poly = utils.points2polygon(points) + poly1 = utils.points2polygon(points1) + + assert utils.poly_iou(poly, poly1) == 0 + + # test overlapping polygons + + assert utils.poly_iou(poly, poly) == 1 + + +def test_boundary_iou(): + points = [0, 0, 0, 1, 1, 1, 1, 0] + points1 = [10, 20, 30, 40, 50, 60, 70, 80] + + assert utils.boundary_iou(points, points1) == 0 + + # test overlapping boundaries + assert utils.boundary_iou(points, points) == 1 + + +def test_points_center(): + + # test unsupported type + with pytest.raises(AssertionError): + utils.points_center([1]) + with pytest.raises(AssertionError): + points = np.array([1, 2, 3]) + utils.points_center(points) + + points = np.array([1, 2, 3, 4]) + assert np.array_equal(utils.points_center(points), np.array([2, 3])) + + +def test_point_distance(): + # test unsupported type + with pytest.raises(AssertionError): + utils.point_distance([1, 2], [1, 2]) + + with pytest.raises(AssertionError): + p = np.array([1, 2, 3]) + utils.point_distance(p, p) + + p = np.array([1, 2]) + assert utils.point_distance(p, p) == 0 + + p1 = np.array([2, 2]) + assert utils.point_distance(p, p1) == 1 + + +def test_box_center_distance(): + p1 = np.array([1, 1, 3, 3]) + p2 = np.array([2, 2, 4, 2]) + + assert utils.box_center_distance(p1, p2) == 1 + + +def test_box_diag(): + # test unsupported type + with pytest.raises(AssertionError): + utils.box_diag([1, 2]) + with pytest.raises(AssertionError): + utils.box_diag(np.array([1, 2, 3, 4])) + + box = np.array([0, 0, 1, 1, 0, 10, -10, 0]) + + assert utils.box_diag(box) == 10 + + +def test_one2one_match_ic13(): + gt_id = 0 + det_id = 0 + recall_mat = np.array([[1, 0], [0, 0]]) + precision_mat = np.array([[1, 0], [0, 0]]) + recall_thr = 0.5 + precision_thr = 0.5 + # test invalid arguments. + with pytest.raises(AssertionError): + utils.one2one_match_ic13(0.0, det_id, recall_mat, precision_mat, + recall_thr, precision_thr) + with pytest.raises(AssertionError): + utils.one2one_match_ic13(gt_id, 0.0, recall_mat, precision_mat, + recall_thr, precision_thr) + with pytest.raises(AssertionError): + utils.one2one_match_ic13(gt_id, det_id, [0, 0], precision_mat, + recall_thr, precision_thr) + with pytest.raises(AssertionError): + utils.one2one_match_ic13(gt_id, det_id, recall_mat, [0, 0], recall_thr, + precision_thr) + with pytest.raises(AssertionError): + utils.one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, 1.1, + precision_thr) + with pytest.raises(AssertionError): + utils.one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, + recall_thr, 1.1) + + assert utils.one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, + recall_thr, precision_thr) + recall_mat = np.array([[1, 0], [0.6, 0]]) + precision_mat = np.array([[1, 0], [0.6, 0]]) + assert not utils.one2one_match_ic13( + gt_id, det_id, recall_mat, precision_mat, recall_thr, precision_thr) + recall_mat = np.array([[1, 0.6], [0, 0]]) + precision_mat = np.array([[1, 0.6], [0, 0]]) + assert not utils.one2one_match_ic13( + gt_id, det_id, recall_mat, precision_mat, recall_thr, precision_thr) + + +def test_one2many_match_ic13(): + gt_id = 0 + recall_mat = np.array([[1, 0], [0, 0]]) + precision_mat = np.array([[1, 0], [0, 0]]) + recall_thr = 0.5 + precision_thr = 0.5 + gt_match_flag = [0, 0] + det_match_flag = [0, 0] + det_dont_care_index = [] + # test invalid arguments. 
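+    # In the IC13 protocol one GT box may be matched to several detections;
+    # per the matched/unmatched cases below, the GT must still be unmatched
+    # (gt_match_flag == 0) and the accumulated recall/precision entries must
+    # clear recall_thr and precision_thr.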
+ with pytest.raises(AssertionError): + gt_id_tmp = 0.0 + utils.one2many_match_ic13(gt_id_tmp, recall_mat, precision_mat, + recall_thr, precision_thr, gt_match_flag, + det_match_flag, det_dont_care_index) + with pytest.raises(AssertionError): + recall_mat_tmp = [1, 0] + utils.one2many_match_ic13(gt_id, recall_mat_tmp, precision_mat, + recall_thr, precision_thr, gt_match_flag, + det_match_flag, det_dont_care_index) + with pytest.raises(AssertionError): + precision_mat_tmp = [1, 0] + utils.one2many_match_ic13(gt_id, recall_mat, precision_mat_tmp, + recall_thr, precision_thr, gt_match_flag, + det_match_flag, det_dont_care_index) + with pytest.raises(AssertionError): + + utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, 1.1, + precision_thr, gt_match_flag, det_match_flag, + det_dont_care_index) + with pytest.raises(AssertionError): + + utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr, + 1.1, gt_match_flag, det_match_flag, + det_dont_care_index) + with pytest.raises(AssertionError): + gt_match_flag_tmp = np.array([0, 1]) + utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr, + precision_thr, gt_match_flag_tmp, + det_match_flag, det_dont_care_index) + with pytest.raises(AssertionError): + det_match_flag_tmp = np.array([0, 1]) + utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr, + precision_thr, gt_match_flag, + det_match_flag_tmp, det_dont_care_index) + with pytest.raises(AssertionError): + det_dont_care_index_tmp = np.array([0, 1]) + utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr, + precision_thr, gt_match_flag, det_match_flag, + det_dont_care_index_tmp) + + # test matched case + + result = utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, + recall_thr, precision_thr, + gt_match_flag, det_match_flag, + det_dont_care_index) + assert result[0] + assert result[1] == [0] + + # test unmatched case + gt_match_flag_tmp = [1, 0] + result = utils.one2many_match_ic13(gt_id, recall_mat, precision_mat, + recall_thr, precision_thr, + gt_match_flag_tmp, det_match_flag, + det_dont_care_index) + assert not result[0] + assert result[1] == [] + + +def test_many2one_match_ic13(): + det_id = 0 + recall_mat = np.array([[1, 0], [0, 0]]) + precision_mat = np.array([[1, 0], [0, 0]]) + recall_thr = 0.5 + precision_thr = 0.5 + gt_match_flag = [0, 0] + det_match_flag = [0, 0] + gt_dont_care_index = [] + # test invalid arguments. 
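+    # The converse of one2many: several GT boxes may match one detection,
+    # and a don't-care GT index vetoes the match (see the unmatched case at
+    # the end of this test).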
+ with pytest.raises(AssertionError): + det_id_tmp = 1.0 + utils.many2one_match_ic13(det_id_tmp, recall_mat, precision_mat, + recall_thr, precision_thr, gt_match_flag, + det_match_flag, gt_dont_care_index) + with pytest.raises(AssertionError): + recall_mat_tmp = [[1, 0], [0, 0]] + utils.many2one_match_ic13(det_id, recall_mat_tmp, precision_mat, + recall_thr, precision_thr, gt_match_flag, + det_match_flag, gt_dont_care_index) + with pytest.raises(AssertionError): + precision_mat_tmp = [[1, 0], [0, 0]] + utils.many2one_match_ic13(det_id, recall_mat, precision_mat_tmp, + recall_thr, precision_thr, gt_match_flag, + det_match_flag, gt_dont_care_index) + with pytest.raises(AssertionError): + recall_thr_tmp = 1.1 + utils.many2one_match_ic13(det_id, recall_mat, precision_mat, + recall_thr_tmp, precision_thr, gt_match_flag, + det_match_flag, gt_dont_care_index) + with pytest.raises(AssertionError): + precision_thr_tmp = 1.1 + utils.many2one_match_ic13(det_id, recall_mat, precision_mat, + recall_thr, precision_thr_tmp, gt_match_flag, + det_match_flag, gt_dont_care_index) + with pytest.raises(AssertionError): + gt_match_flag_tmp = np.array([0, 1]) + utils.many2one_match_ic13(det_id, recall_mat, precision_mat, + recall_thr, precision_thr, gt_match_flag_tmp, + det_match_flag, gt_dont_care_index) + with pytest.raises(AssertionError): + det_match_flag_tmp = np.array([0, 1]) + utils.many2one_match_ic13(det_id, recall_mat, precision_mat, + recall_thr, precision_thr, gt_match_flag, + det_match_flag_tmp, gt_dont_care_index) + with pytest.raises(AssertionError): + gt_dont_care_index_tmp = np.array([0, 1]) + utils.many2one_match_ic13(det_id, recall_mat, precision_mat, + recall_thr, precision_thr, gt_match_flag, + det_match_flag, gt_dont_care_index_tmp) + + # test matched cases + + result = utils.many2one_match_ic13(det_id, recall_mat, precision_mat, + recall_thr, precision_thr, + gt_match_flag, det_match_flag, + gt_dont_care_index) + assert result[0] + assert result[1] == [0] + + # test unmatched cases + + gt_dont_care_index = [0] + + result = utils.many2one_match_ic13(det_id, recall_mat, precision_mat, + recall_thr, precision_thr, + gt_match_flag, det_match_flag, + gt_dont_care_index) + assert not result[0] + assert result[1] == [] diff --git a/tests/test_metrics/test_hmean_detect.py b/tests/test_metrics/test_hmean_detect.py new file mode 100644 index 00000000..99c6f074 --- /dev/null +++ b/tests/test_metrics/test_hmean_detect.py @@ -0,0 +1,71 @@ +import tempfile + +import numpy as np +import pytest + +from mmocr.core.evaluation.hmean import (eval_hmean, get_gt_masks, + output_ranklist) + + +def _create_dummy_ann_infos(): + ann_infos = { + 'bboxes': np.array([[50., 70., 80., 100.]], dtype=np.float32), + 'labels': np.array([1], dtype=np.int64), + 'bboxes_ignore': np.array([[120, 140, 200, 200]], dtype=np.float32), + 'masks': [[[50, 70, 80, 70, 80, 100, 50, 100]]], + 'masks_ignore': [[[120, 140, 200, 140, 200, 200, 120, 200]]] + } + return [ann_infos] + + +def test_output_ranklist(): + result = [{'hmean': 1}, {'hmean': 0.5}] + file_name = tempfile.NamedTemporaryFile().name + img_infos = [{'file_name': 'sample1.jpg'}, {'file_name': 'sample2.jpg'}] + + json_file = file_name + '.json' + with pytest.raises(AssertionError): + output_ranklist([[]], img_infos, json_file) + with pytest.raises(AssertionError): + output_ranklist(result, [[]], json_file) + with pytest.raises(AssertionError): + output_ranklist(result, img_infos, file_name) + + sorted_outputs = output_ranklist(result, img_infos, json_file) + + assert 
sorted_outputs[0]['hmean'] == 0.5 + + +def test_get_gt_mask(): + ann_infos = _create_dummy_ann_infos() + gt_masks, gt_masks_ignore = get_gt_masks(ann_infos) + + assert np.allclose(gt_masks[0], [[50, 70, 80, 70, 80, 100, 50, 100]]) + assert np.allclose(gt_masks_ignore[0], + [[120, 140, 200, 140, 200, 200, 120, 200]]) + + +def test_eval_hmean(): + metrics = set(['hmean-iou', 'hmean-ic13']) + results = [{ + 'boundary_result': [[50, 70, 80, 70, 80, 100, 50, 100, 1], + [120, 140, 200, 140, 200, 200, 120, 200, 1]] + }] + + img_infos = [{'file_name': 'sample1.jpg'}] + ann_infos = _create_dummy_ann_infos() + + # test invalid arguments + with pytest.raises(AssertionError): + eval_hmean(results, [[]], ann_infos, metrics=metrics) + with pytest.raises(AssertionError): + eval_hmean(results, img_infos, [[]], metrics=metrics) + with pytest.raises(AssertionError): + eval_hmean([[]], img_infos, ann_infos, metrics=metrics) + with pytest.raises(AssertionError): + eval_hmean(results, img_infos, ann_infos, metrics='hmean-iou') + + eval_results = eval_hmean(results, img_infos, ann_infos, metrics=metrics) + + assert eval_results['hmean-iou:hmean'] == 1 + assert eval_results['hmean-ic13:hmean'] == 1 diff --git a/tests/test_metrics/test_hmean_ic13.py b/tests/test_metrics/test_hmean_ic13.py new file mode 100644 index 00000000..c5b3e69c --- /dev/null +++ b/tests/test_metrics/test_hmean_ic13.py @@ -0,0 +1,116 @@ +"""Test hmean_ic13.""" +import math + +import pytest + +import mmocr.core.evaluation.hmean_ic13 as hmean_ic13 +import mmocr.core.evaluation.utils as utils + + +def test_compute_recall_precision(): + + gt_polys = [] + det_polys = [] + + # test invalid arguments. + with pytest.raises(AssertionError): + hmean_ic13.compute_recall_precision(1, 1) + + box1 = [0, 0, 1, 0, 1, 1, 0, 1] + + box2 = [0, 0, 10, 0, 10, 1, 0, 1] + + gt_polys = [utils.points2polygon(box1)] + det_polys = [utils.points2polygon(box2)] + recall, precision = hmean_ic13.compute_recall_precision( + gt_polys, det_polys) + assert recall == 1 + assert precision == 0.1 + + +def test_eval_hmean_ic13(): + det_boxes = [] + gt_boxes = [] + gt_ignored_boxes = [] + precision_thr = 0.4 + recall_thr = 0.8 + center_dist_thr = 1.0 + one2one_score = 1. + one2many_score = 0.8 + many2one_score = 1 + # test invalid arguments. 
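+    # The three *_score values weight each match type when accumulating
+    # recall/precision; one2many_score = 0.8 is why the one2many case below
+    # expects a recall of 0.8 rather than 1.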
+ + with pytest.raises(AssertionError): + hmean_ic13.eval_hmean_ic13([1], gt_boxes, gt_ignored_boxes, + precision_thr, recall_thr, center_dist_thr, + one2one_score, one2many_score, + many2one_score) + + with pytest.raises(AssertionError): + hmean_ic13.eval_hmean_ic13(det_boxes, 1, gt_ignored_boxes, + precision_thr, recall_thr, center_dist_thr, + one2one_score, one2many_score, + many2one_score) + with pytest.raises(AssertionError): + hmean_ic13.eval_hmean_ic13(det_boxes, gt_boxes, 1, precision_thr, + recall_thr, center_dist_thr, one2one_score, + one2many_score, many2one_score) + with pytest.raises(AssertionError): + hmean_ic13.eval_hmean_ic13(det_boxes, gt_boxes, gt_ignored_boxes, 1.1, + recall_thr, center_dist_thr, one2one_score, + one2many_score, many2one_score) + with pytest.raises(AssertionError): + hmean_ic13.eval_hmean_ic13(det_boxes, gt_boxes, gt_ignored_boxes, + precision_thr, 1.1, center_dist_thr, + one2one_score, one2many_score, + many2one_score) + with pytest.raises(AssertionError): + hmean_ic13.eval_hmean_ic13(det_boxes, gt_boxes, gt_ignored_boxes, + precision_thr, recall_thr, -1, + one2one_score, one2many_score, + many2one_score) + with pytest.raises(AssertionError): + hmean_ic13.eval_hmean_ic13(det_boxes, gt_boxes, gt_ignored_boxes, + precision_thr, recall_thr, center_dist_thr, + -1, one2many_score, many2one_score) + with pytest.raises(AssertionError): + hmean_ic13.eval_hmean_ic13(det_boxes, gt_boxes, gt_ignored_boxes, + precision_thr, recall_thr, center_dist_thr, + one2one_score, -1, many2one_score) + with pytest.raises(AssertionError): + hmean_ic13.eval_hmean_ic13(det_boxes, gt_boxes, gt_ignored_boxes, + precision_thr, recall_thr, center_dist_thr, + one2one_score, one2many_score, -1) + + # test one2one match + det_boxes = [[[0, 0, 1, 0, 1, 1, 0, 1], [10, 0, 11, 0, 11, 1, 10, 1]]] + gt_boxes = [[[0, 0, 1, 0, 1, 1, 0, 1]]] + gt_ignored_boxes = [[]] + dataset_result, img_result = hmean_ic13.eval_hmean_ic13( + det_boxes, gt_boxes, gt_ignored_boxes, precision_thr, recall_thr, + center_dist_thr, one2one_score, one2many_score, many2one_score) + assert img_result[0]['recall'] == 1 + assert img_result[0]['precision'] == 0.5 + assert math.isclose(img_result[0]['hmean'], 2 * (0.5) / 1.5) + + # test one2many match + gt_boxes = [[[0, 0, 2, 0, 2, 1, 0, 1]]] + det_boxes = [[[0, 0, 1, 0, 1, 1, 0, 1], [1, 0, 2, 0, 2, 1, 1, 1]]] + dataset_result, img_result = hmean_ic13.eval_hmean_ic13( + det_boxes, gt_boxes, gt_ignored_boxes, precision_thr, recall_thr, + center_dist_thr, one2one_score, one2many_score, many2one_score) + assert img_result[0]['recall'] == 0.8 + assert img_result[0]['precision'] == 1.6 / 2 + assert math.isclose(img_result[0]['hmean'], 2 * (0.64) / 1.6) + + # test many2one match + precision_thr = 0.6 + recall_thr = 0.8 + det_boxes = [[[0, 0, 2, 0, 2, 1, 0, 1]]] + gt_boxes = [[[0, 0, 1, 0, 1, 1, 0, 1], [1, 0, 2, 0, 2, 1, 1, 1]]] + dataset_result, img_result = hmean_ic13.eval_hmean_ic13( + det_boxes, gt_boxes, gt_ignored_boxes, precision_thr, recall_thr, + center_dist_thr, one2one_score, one2many_score, many2one_score) + assert img_result[0]['recall'] == 1 + assert img_result[0]['precision'] == 1 + assert math.isclose(img_result[0]['hmean'], 1) diff --git a/tests/test_metrics/test_hmean_iou.py b/tests/test_metrics/test_hmean_iou.py new file mode 100644 index 00000000..8be4ca50 --- /dev/null +++ b/tests/test_metrics/test_hmean_iou.py @@ -0,0 +1,40 @@ +"""Test hmean_iou.""" +import pytest + +import mmocr.core.evaluation.hmean_iou as hmean_iou + + +def test_eval_hmean_iou(): + + 
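+    # hmean is the harmonic mean of recall and precision:
+    #     hmean = 2 * recall * precision / (recall + precision)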
+    pred_boxes = []
+    gt_boxes = []
+    gt_ignored_boxes = []
+    iou_thr = 0.5
+    precision_thr = 0.5
+
+    # test invalid arguments.
+
+    with pytest.raises(AssertionError):
+        hmean_iou.eval_hmean_iou([1], gt_boxes, gt_ignored_boxes, iou_thr,
+                                 precision_thr)
+    with pytest.raises(AssertionError):
+        hmean_iou.eval_hmean_iou(pred_boxes, [1], gt_ignored_boxes, iou_thr,
+                                 precision_thr)
+    with pytest.raises(AssertionError):
+        hmean_iou.eval_hmean_iou(pred_boxes, gt_boxes, [1], iou_thr,
+                                 precision_thr)
+    with pytest.raises(AssertionError):
+        hmean_iou.eval_hmean_iou(pred_boxes, gt_boxes, gt_ignored_boxes, 1.1,
+                                 precision_thr)
+    with pytest.raises(AssertionError):
+        hmean_iou.eval_hmean_iou(pred_boxes, gt_boxes, gt_ignored_boxes,
+                                 iou_thr, 1.1)
+
+    pred_boxes = [[[0, 0, 1, 0, 1, 1, 0, 1], [2, 0, 3, 0, 3, 1, 2, 1]]]
+    gt_boxes = [[[0, 0, 1, 0, 1, 1, 0, 1], [2, 0, 3, 0, 3, 1, 2, 1]]]
+    gt_ignored_boxes = [[]]
+    results = hmean_iou.eval_hmean_iou(pred_boxes, gt_boxes, gt_ignored_boxes,
+                                       iou_thr, precision_thr)
+    assert results[1][0]['recall'] == 1
+    assert results[1][0]['precision'] == 1
+    assert results[1][0]['hmean'] == 1
diff --git a/tests/test_models/test_detector.py b/tests/test_models/test_detector.py
new file mode 100644
index 00000000..a485c490
--- /dev/null
+++ b/tests/test_models/test_detector.py
@@ -0,0 +1,368 @@
+"""pytest tests/test_models/test_detector.py."""
+import copy
+from os.path import dirname, exists, join
+
+import numpy as np
+import pytest
+import torch
+
+import mmocr.core.evaluation.utils as utils
+
+
+def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300),
+                    num_items=None, num_classes=1):  # yapf: disable
+    """Create a superset of inputs needed to run test or train batches.
+
+    Args:
+        input_shape (tuple): Input batch dimensions.
+
+        num_items (None | list[int]): Specifies the number of boxes
+            for each batch item.
+
+        num_classes (int): Number of distinct labels a box might have.
+ """ + from mmdet.core import BitmapMasks + + (N, C, H, W) = input_shape + + rng = np.random.RandomState(0) + + imgs = rng.rand(*input_shape) + + img_metas = [{ + 'img_shape': (H, W, C), + 'ori_shape': (H, W, C), + 'pad_shape': (H, W, C), + 'filename': '.png', + 'scale_factor': np.array([1, 1, 1, 1]), + 'flip': False, + } for _ in range(N)] + + gt_bboxes = [] + gt_labels = [] + gt_masks = [] + gt_kernels = [] + gt_effective_mask = [] + + for batch_idx in range(N): + if num_items is None: + num_boxes = rng.randint(1, 10) + else: + num_boxes = num_items[batch_idx] + + cx, cy, bw, bh = rng.rand(num_boxes, 4).T + + tl_x = ((cx * W) - (W * bw / 2)).clip(0, W) + tl_y = ((cy * H) - (H * bh / 2)).clip(0, H) + br_x = ((cx * W) + (W * bw / 2)).clip(0, W) + br_y = ((cy * H) + (H * bh / 2)).clip(0, H) + + boxes = np.vstack([tl_x, tl_y, br_x, br_y]).T + class_idxs = [0] * num_boxes + + gt_bboxes.append(torch.FloatTensor(boxes)) + gt_labels.append(torch.LongTensor(class_idxs)) + kernels = [] + for kernel_inx in range(num_kernels): + kernel = np.random.rand(H, W) + kernels.append(kernel) + gt_kernels.append(BitmapMasks(kernels, H, W)) + gt_effective_mask.append(BitmapMasks([np.ones((H, W))], H, W)) + + mask = np.random.randint(0, 2, (len(boxes), H, W), dtype=np.uint8) + gt_masks.append(BitmapMasks(mask, H, W)) + + mm_inputs = { + 'imgs': torch.FloatTensor(imgs).requires_grad_(True), + 'img_metas': img_metas, + 'gt_bboxes': gt_bboxes, + 'gt_labels': gt_labels, + 'gt_bboxes_ignore': None, + 'gt_masks': gt_masks, + 'gt_kernels': gt_kernels, + 'gt_mask': gt_effective_mask, + 'gt_thr_mask': gt_effective_mask, + 'gt_text_mask': gt_effective_mask, + 'gt_center_region_mask': gt_effective_mask, + 'gt_radius_map': gt_kernels, + 'gt_sin_map': gt_kernels, + 'gt_cos_map': gt_kernels, + } + return mm_inputs + + +def _get_config_directory(): + """Find the predefined detector config directory.""" + try: + # Assume we are running in the source mmocr repo + repo_dpath = dirname(dirname(dirname(__file__))) + except NameError: + # For IPython development when this __file__ is not defined + import mmocr + repo_dpath = dirname(dirname(mmocr.__file__)) + config_dpath = join(repo_dpath, 'configs') + if not exists(config_dpath): + raise Exception('Cannot find config path') + return config_dpath + + +def _get_config_module(fname): + """Load a configuration as a python module.""" + from mmcv import Config + config_dpath = _get_config_directory() + config_fpath = join(config_dpath, fname) + config_mod = Config.fromfile(config_fpath) + return config_mod + + +def _get_detector_cfg(fname): + """Grab configs necessary to create a detector. + + These are deep copied to allow for safe modification of parameters without + influencing other tests. 
+ """ + config = _get_config_module(fname) + model = copy.deepcopy(config.model) + return model + + +@pytest.mark.parametrize('cfg_file', [ + 'textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py', + 'textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py', + 'textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py' +]) +def test_ocr_mask_rcnn(cfg_file): + model = _get_detector_cfg(cfg_file) + model['pretrained'] = None + + from mmocr.models import build_detector + detector = build_detector(model) + + input_shape = (1, 3, 224, 224) + mm_inputs = _demo_mm_inputs(0, input_shape) + + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + gt_labels = mm_inputs.pop('gt_labels') + gt_masks = mm_inputs.pop('gt_masks') + + # Test forward train + gt_bboxes = mm_inputs['gt_bboxes'] + losses = detector.forward( + imgs, + img_metas, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels, + gt_masks=gt_masks) + assert isinstance(losses, dict) + + # Test forward test + with torch.no_grad(): + img_list = [g[None, :] for g in imgs] + batch_results = [] + for one_img, one_meta in zip(img_list, img_metas): + result = detector.forward([one_img], [[one_meta]], + return_loss=False) + batch_results.append(result) + + # Test get_boundary + results = ([[[1]]], [[ + np.array([[1, 1, 0, 0, 0], [1, 1, 0, 0, 0], [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) + ]]) + + boundaries = detector.get_boundary(results) + assert utils.boundary_iou(boundaries['boundary_result'][0][:-1], + [1, 1, 0, 1, 0, 0, 1, 0]) == 1 + + # Test show_result + + results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} + img = np.random.rand(5, 5) + detector.show_result(img, results) + + +@pytest.mark.parametrize('cfg_file', [ + 'textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py', + 'textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py', + 'textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py' +]) +def test_panet(cfg_file): + model = _get_detector_cfg(cfg_file) + model['pretrained'] = None + model['backbone']['norm_cfg']['type'] = 'BN' + + from mmocr.models import build_detector + detector = build_detector(model) + + input_shape = (1, 3, 224, 224) + num_kernels = 2 + mm_inputs = _demo_mm_inputs(num_kernels, input_shape) + + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + gt_kernels = mm_inputs.pop('gt_kernels') + gt_mask = mm_inputs.pop('gt_mask') + + # Test forward train + losses = detector.forward( + imgs, img_metas, gt_kernels=gt_kernels, gt_mask=gt_mask) + assert isinstance(losses, dict) + + # Test forward test + with torch.no_grad(): + img_list = [g[None, :] for g in imgs] + batch_results = [] + for one_img, one_meta in zip(img_list, img_metas): + result = detector.forward([one_img], [[one_meta]], + return_loss=False) + batch_results.append(result) + + # Test show result + results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} + img = np.random.rand(5, 5) + detector.show_result(img, results) + + +@pytest.mark.parametrize('cfg_file', [ + 'textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py', + 'textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py', + 'textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py' +]) +def test_psenet(cfg_file): + model = _get_detector_cfg(cfg_file) + model['pretrained'] = None + model['backbone']['norm_cfg']['type'] = 'BN' + + from mmocr.models import build_detector + detector = build_detector(model) + + input_shape = (1, 3, 224, 224) + num_kernels = 7 + mm_inputs = _demo_mm_inputs(num_kernels, input_shape) + + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + 
gt_kernels = mm_inputs.pop('gt_kernels') + gt_mask = mm_inputs.pop('gt_mask') + + # Test forward train + losses = detector.forward( + imgs, img_metas, gt_kernels=gt_kernels, gt_mask=gt_mask) + assert isinstance(losses, dict) + + # Test forward test + with torch.no_grad(): + img_list = [g[None, :] for g in imgs] + batch_results = [] + for one_img, one_meta in zip(img_list, img_metas): + result = detector.forward([one_img], [[one_meta]], + return_loss=False) + batch_results.append(result) + + # Test show result + results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} + img = np.random.rand(5, 5) + detector.show_result(img, results) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') +@pytest.mark.parametrize('cfg_file', [ + 'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py', + 'textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py' +]) +def test_dbnet(cfg_file): + model = _get_detector_cfg(cfg_file) + model['pretrained'] = None + model['backbone']['norm_cfg']['type'] = 'BN' + + from mmocr.models import build_detector + detector = build_detector(model) + detector = detector.cuda() + input_shape = (1, 3, 224, 224) + num_kernels = 7 + mm_inputs = _demo_mm_inputs(num_kernels, input_shape) + + imgs = mm_inputs.pop('imgs') + imgs = imgs.cuda() + img_metas = mm_inputs.pop('img_metas') + gt_shrink = mm_inputs.pop('gt_kernels') + gt_shrink_mask = mm_inputs.pop('gt_mask') + gt_thr = mm_inputs.pop('gt_masks') + gt_thr_mask = mm_inputs.pop('gt_thr_mask') + + # Test forward train + losses = detector.forward( + imgs, + img_metas, + gt_shrink=gt_shrink, + gt_shrink_mask=gt_shrink_mask, + gt_thr=gt_thr, + gt_thr_mask=gt_thr_mask) + assert isinstance(losses, dict) + + # Test forward test + with torch.no_grad(): + img_list = [g[None, :] for g in imgs] + batch_results = [] + for one_img, one_meta in zip(img_list, img_metas): + result = detector.forward([one_img], [[one_meta]], + return_loss=False) + batch_results.append(result) + + # Test show result + results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} + img = np.random.rand(5, 5) + detector.show_result(img, results) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') +@pytest.mark.parametrize( + 'cfg_file', ['textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py']) +def test_textsnake(cfg_file): + model = _get_detector_cfg(cfg_file) + model['pretrained'] = None + model['backbone']['norm_cfg']['type'] = 'BN' + + from mmocr.models import build_detector + detector = build_detector(model) + detector = detector.cuda() + input_shape = (1, 3, 64, 64) + num_kernels = 1 + mm_inputs = _demo_mm_inputs(num_kernels, input_shape) + + imgs = mm_inputs.pop('imgs') + imgs = imgs.cuda() + img_metas = mm_inputs.pop('img_metas') + gt_text_mask = mm_inputs.pop('gt_text_mask') + gt_center_region_mask = mm_inputs.pop('gt_center_region_mask') + gt_mask = mm_inputs.pop('gt_mask') + gt_radius_map = mm_inputs.pop('gt_radius_map') + gt_sin_map = mm_inputs.pop('gt_sin_map') + gt_cos_map = mm_inputs.pop('gt_cos_map') + + # Test forward train + losses = detector.forward( + imgs, + img_metas, + gt_text_mask=gt_text_mask, + gt_center_region_mask=gt_center_region_mask, + gt_mask=gt_mask, + gt_radius_map=gt_radius_map, + gt_sin_map=gt_sin_map, + gt_cos_map=gt_cos_map) + assert isinstance(losses, dict) + + # Test forward test + with torch.no_grad(): + img_list = [g[None, :] for g in imgs] + batch_results = [] + for one_img, one_meta in zip(img_list, img_metas): + result = detector.forward([one_img], [[one_meta]], + 
return_loss=False) + batch_results.append(result) + + # Test show result + results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]} + img = np.random.rand(5, 5) + detector.show_result(img, results) diff --git a/tests/test_models/test_label_convertor/test_attn_label_convertor.py b/tests/test_models/test_label_convertor/test_attn_label_convertor.py new file mode 100644 index 00000000..00eaeacc --- /dev/null +++ b/tests/test_models/test_label_convertor/test_attn_label_convertor.py @@ -0,0 +1,77 @@ +import os.path as osp +import tempfile + +import numpy as np +import pytest +import torch + +from mmocr.models.textrecog.convertors import AttnConvertor + + +def _create_dummy_dict_file(dict_file): + characters = list('helowrd') + with open(dict_file, 'w') as fw: + for char in characters: + fw.write(char + '\n') + + +def test_attn_label_convertor(): + tmp_dir = tempfile.TemporaryDirectory() + # create dummy data + dict_file = osp.join(tmp_dir.name, 'fake_dict.txt') + _create_dummy_dict_file(dict_file) + + # test invalid arguments + with pytest.raises(AssertionError): + AttnConvertor(5) + with pytest.raises(AssertionError): + AttnConvertor('DICT90', dict_file, '1') + with pytest.raises(AssertionError): + AttnConvertor('DICT90', dict_file, True, '1') + + label_convertor = AttnConvertor(dict_file=dict_file, max_seq_len=10) + # test init and parse_dict + assert label_convertor.num_classes() == 10 + assert len(label_convertor.idx2char) == 10 + assert label_convertor.idx2char[0] == 'h' + assert label_convertor.idx2char[1] == 'e' + assert label_convertor.idx2char[-3] == '' + assert label_convertor.char2idx['h'] == 0 + assert label_convertor.unknown_idx == 7 + + # test encode str to tensor + strings = ['hell'] + targets_dict = label_convertor.str2tensor(strings) + assert torch.allclose(targets_dict['targets'][0], + torch.LongTensor([0, 1, 2, 2])) + assert torch.allclose(targets_dict['padded_targets'][0], + torch.LongTensor([8, 0, 1, 2, 2, 8, 9, 9, 9, 9])) + + # test decode output to index + dummy_output = torch.Tensor([[[100, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 100, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 100, 4, 5, 6, 7, 8, 9], + [1, 2, 100, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 5, 6, 7, 8, 100], + [1, 2, 3, 4, 5, 6, 7, 100, 9], + [1, 2, 3, 4, 5, 6, 7, 100, 9], + [1, 2, 3, 4, 5, 6, 7, 100, 9], + [1, 2, 3, 4, 5, 6, 7, 100, 9], + [1, 2, 3, 4, 5, 6, 7, 100, 9]]]) + indexes, scores = label_convertor.tensor2idx(dummy_output) + assert np.allclose(indexes, [[0, 1, 2, 2]]) + + # test encode_str_label_to_index + with pytest.raises(AssertionError): + label_convertor.str2idx('hell') + tmp_indexes = label_convertor.str2idx(strings) + assert np.allclose(tmp_indexes, [[0, 1, 2, 2]]) + + # test decode_index to str_label + input_indexes = [[0, 1, 2, 2]] + with pytest.raises(AssertionError): + label_convertor.idx2str('hell') + output_strings = label_convertor.idx2str(input_indexes) + assert output_strings[0] == 'hell' + + tmp_dir.cleanup() diff --git a/tests/test_models/test_label_convertor/test_ctc_label_convertor.py b/tests/test_models/test_label_convertor/test_ctc_label_convertor.py new file mode 100644 index 00000000..07c9cbf0 --- /dev/null +++ b/tests/test_models/test_label_convertor/test_ctc_label_convertor.py @@ -0,0 +1,79 @@ +import os.path as osp +import tempfile + +import numpy as np +import pytest +import torch + +from mmocr.models.textrecog.convertors import BaseConvertor, CTCConvertor + + +def _create_dummy_dict_file(dict_file): + chars = list('helowrd') + with open(dict_file, 'w') as fw: + for char in chars: + 
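+            # one character per line; the convertor parses this file to
+            # build its vocabulary (CTC then reserves index 0 for the blank
+            # token, per the assertions below)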
+            fw.write(char + '\n')
+
+
+def test_ctc_label_convertor():
+    tmp_dir = tempfile.TemporaryDirectory()
+    # create dummy data
+    dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
+    _create_dummy_dict_file(dict_file)
+
+    # test invalid arguments
+    with pytest.raises(AssertionError):
+        CTCConvertor(5)
+
+    label_convertor = CTCConvertor(dict_file=dict_file, with_unknown=False)
+    # test init and parse_chars
+    assert label_convertor.num_classes() == 8
+    assert len(label_convertor.idx2char) == 8
+    assert label_convertor.idx2char[0] == ''
+    assert label_convertor.char2idx['h'] == 1
+    assert label_convertor.unknown_idx is None
+
+    # test encode str to tensor
+    strings = ['hell']
+    expect_tensor = torch.IntTensor([1, 2, 3, 3])
+    targets_dict = label_convertor.str2tensor(strings)
+    assert torch.allclose(targets_dict['targets'][0], expect_tensor)
+    assert torch.allclose(targets_dict['flatten_targets'], expect_tensor)
+    assert torch.allclose(targets_dict['target_lengths'], torch.IntTensor([4]))
+
+    # test decode output to index
+    dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8],
+                                  [100, 2, 3, 4, 5, 6, 7, 8],
+                                  [1, 2, 100, 4, 5, 6, 7, 8],
+                                  [1, 2, 100, 4, 5, 6, 7, 8],
+                                  [100, 2, 3, 4, 5, 6, 7, 8],
+                                  [1, 2, 3, 100, 5, 6, 7, 8],
+                                  [100, 2, 3, 4, 5, 6, 7, 8],
+                                  [1, 2, 3, 100, 5, 6, 7, 8]]])
+    indexes, scores = label_convertor.tensor2idx(
+        dummy_output, img_metas=[{
+            'valid_ratio': 1.0
+        }])
+    assert np.allclose(indexes, [[1, 2, 3, 3]])
+
+    # test encode_str_label_to_index
+    with pytest.raises(AssertionError):
+        label_convertor.str2idx('hell')
+    tmp_indexes = label_convertor.str2idx(strings)
+    assert np.allclose(tmp_indexes, [[1, 2, 3, 3]])
+
+    # test decode_index_to_str_label
+    input_indexes = [[1, 2, 3, 3]]
+    with pytest.raises(AssertionError):
+        label_convertor.idx2str('hell')
+    output_strings = label_convertor.idx2str(input_indexes)
+    assert output_strings[0] == 'hell'
+
+    tmp_dir.cleanup()
+
+
+def test_base_label_convertor():
+    with pytest.raises(NotImplementedError):
+        label_convertor = BaseConvertor()
+        label_convertor.str2tensor(None)
+        label_convertor.tensor2idx(None)
diff --git a/tests/test_models/test_loss.py b/tests/test_models/test_loss.py
new file mode 100644
index 00000000..88d6522c
--- /dev/null
+++ b/tests/test_models/test_loss.py
@@ -0,0 +1,33 @@
+import numpy as np
+import torch
+
+import mmocr.models.textdet.losses as losses
+from mmdet.core import BitmapMasks
+
+
+def test_panloss():
+    panloss = losses.PANLoss()
+
+    # test bitmasks2tensor
+    mask = [[1, 0, 1], [1, 1, 1], [0, 0, 1]]
+    target = [[1, 0, 1, 0, 0], [1, 1, 1, 0, 0], [0, 0, 1, 0, 0],
+              [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
+    masks = [np.array(mask)]
+    bitmasks = BitmapMasks(masks, 3, 3)
+    target_sz = (6, 5)
+    results = panloss.bitmasks2tensor([bitmasks], target_sz)
+    assert len(results) == 1
+    assert torch.sum(torch.abs(results[0].float() -
+                               torch.Tensor(target))).item() == 0
+
+
+def test_textsnakeloss():
+    textsnakeloss = losses.TextSnakeLoss()
+
+    # test balanced_bce_loss
+    pred = torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=torch.float)
+    target = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.long)
+    mask = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=torch.long)
+    bce_loss = textsnakeloss.balanced_bce_loss(pred, target, mask).item()
+
+    assert np.allclose(bce_loss, 0)
diff --git a/tests/test_models/test_ocr_backbone.py b/tests/test_models/test_ocr_backbone.py
new file mode 100644
index 00000000..f49a334d
--- /dev/null
+++ b/tests/test_models/test_ocr_backbone.py
@@ -0,0 +1,36 @@
+import pytest
+import torch
+
+from mmocr.models.textrecog.backbones import ResNet31OCR, VeryDeepVgg
+
+
+def test_resnet31_ocr_backbone():
+    """Test resnet backbone."""
+    with pytest.raises(AssertionError):
+        ResNet31OCR(2.5)
+
+    with pytest.raises(AssertionError):
+        ResNet31OCR(3, layers=5)
+
+    with pytest.raises(AssertionError):
+        ResNet31OCR(3, channels=5)
+
+    # Test ResNet31OCR forward
+    model = ResNet31OCR()
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 32, 160)
+    feat = model(imgs)
+    assert feat.shape == torch.Size([1, 512, 4, 40])
+
+
+def test_very_deep_vgg_ocr_backbone():
+
+    model = VeryDeepVgg()
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 32, 160)
+    feats = model(imgs)
+    assert feats.shape == torch.Size([1, 512, 1, 41])
diff --git a/tests/test_models/test_ocr_decoder.py b/tests/test_models/test_ocr_decoder.py
new file mode 100644
index 00000000..709f86ed
--- /dev/null
+++ b/tests/test_models/test_ocr_decoder.py
@@ -0,0 +1,112 @@
+import math
+
+import pytest
+import torch
+
+from mmocr.models.textrecog.decoders import (BaseDecoder, ParallelSARDecoder,
+                                             ParallelSARDecoderWithBS,
+                                             SequentialSARDecoder, TFDecoder)
+from mmocr.models.textrecog.decoders.sar_decoder_with_bs import DecodeNode
+
+
+def _create_dummy_input():
+    feat = torch.rand(1, 512, 4, 40)
+    out_enc = torch.rand(1, 512)
+    tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
+    img_metas = [{'valid_ratio': 1.0}]
+
+    return feat, out_enc, tgt_dict, img_metas
+
+
+def test_base_decoder():
+    decoder = BaseDecoder()
+    with pytest.raises(NotImplementedError):
+        decoder.forward_train(None, None, None, None)
+    with pytest.raises(NotImplementedError):
+        decoder.forward_test(None, None, None)
+
+
+def test_parallel_sar_decoder():
+    # test parallel sar decoder
+    decoder = ParallelSARDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
+    decoder.init_weights()
+    decoder.train()
+
+    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
+    with pytest.raises(AssertionError):
+        decoder(feat, out_enc, tgt_dict, [], True)
+    with pytest.raises(AssertionError):
+        decoder(feat, out_enc, tgt_dict, img_metas * 2, True)
+
+    out_train = decoder(feat, out_enc, tgt_dict, img_metas, True)
+    assert out_train.shape == torch.Size([1, 5, 36])
+
+    out_test = decoder(feat, out_enc, tgt_dict, img_metas, False)
+    assert out_test.shape == torch.Size([1, 5, 36])
+
+
+def test_sequential_sar_decoder():
+    # test sequential sar decoder
+    decoder = SequentialSARDecoder(
+        num_classes=37, padding_idx=36, max_seq_len=5)
+    decoder.init_weights()
+    decoder.train()
+
+    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
+    with pytest.raises(AssertionError):
+        decoder(feat, out_enc, tgt_dict, [])
+    with pytest.raises(AssertionError):
+        decoder(feat, out_enc, tgt_dict, img_metas * 2)
+
+    out_train = decoder(feat, out_enc, tgt_dict, img_metas, True)
+    assert out_train.shape == torch.Size([1, 5, 36])
+
+    out_test = decoder(feat, out_enc, tgt_dict, img_metas, False)
+    assert out_test.shape == torch.Size([1, 5, 36])
+
+
+def test_parallel_sar_decoder_with_beam_search():
+    with pytest.raises(AssertionError):
+        ParallelSARDecoderWithBS(beam_width='beam')
+    with pytest.raises(AssertionError):
+        ParallelSARDecoderWithBS(beam_width=0)
+
+    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
+    decoder = ParallelSARDecoderWithBS(
+        beam_width=1, num_classes=37, padding_idx=36, max_seq_len=5)
+    decoder.init_weights()
+    decoder.train()
+    with pytest.raises(AssertionError):
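+        # an empty img_metas list should trip the decoder's batch-size check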
+        decoder(feat, out_enc, tgt_dict, [])
+    with pytest.raises(AssertionError):
+        decoder(feat, out_enc, tgt_dict, img_metas * 2)
+
+    out_test = decoder(feat, out_enc, tgt_dict, img_metas, train_mode=False)
+    assert out_test.shape == torch.Size([1, 5, 36])
+
+    # test DecodeNode
+    with pytest.raises(AssertionError):
+        DecodeNode(1, 1)
+    with pytest.raises(AssertionError):
+        DecodeNode([1, 2], ['4', '3'])
+    with pytest.raises(AssertionError):
+        DecodeNode([1, 2], [0.5])
+    decode_node = DecodeNode([1, 2], [0.7, 0.8])
+    assert math.isclose(decode_node.eval(), 1.5)
+
+
+def test_transformer_decoder():
+    decoder = TFDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
+    decoder.init_weights()
+    decoder.train()
+
+    out_enc = torch.rand(1, 512, 1, 25)
+    tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
+    img_metas = [{'valid_ratio': 1.0}]
+
+    out_train = decoder(None, out_enc, tgt_dict, img_metas, True)
+    assert out_train.shape == torch.Size([1, 5, 36])
+
+    out_test = decoder(None, out_enc, tgt_dict, img_metas, False)
+    assert out_test.shape == torch.Size([1, 5, 36])
diff --git a/tests/test_models/test_ocr_encoder.py b/tests/test_models/test_ocr_encoder.py
new file mode 100644
index 00000000..ac185eeb
--- /dev/null
+++ b/tests/test_models/test_ocr_encoder.py
@@ -0,0 +1,52 @@
+import pytest
+import torch
+
+from mmocr.models.textrecog.encoders import BaseEncoder, SAREncoder, TFEncoder
+
+
+def test_sar_encoder():
+    with pytest.raises(AssertionError):
+        SAREncoder(enc_bi_rnn='bi')
+    with pytest.raises(AssertionError):
+        SAREncoder(enc_do_rnn=2)
+    with pytest.raises(AssertionError):
+        SAREncoder(enc_gru='gru')
+    with pytest.raises(AssertionError):
+        SAREncoder(d_model=512.5)
+    with pytest.raises(AssertionError):
+        SAREncoder(d_enc=200.5)
+    with pytest.raises(AssertionError):
+        SAREncoder(mask='mask')
+
+    encoder = SAREncoder()
+    encoder.init_weights()
+    encoder.train()
+
+    feat = torch.randn(1, 512, 4, 40)
+    img_metas = [{'valid_ratio': 1.0}]
+    with pytest.raises(AssertionError):
+        encoder(feat, img_metas * 2)
+    out_enc = encoder(feat, img_metas)
+
+    assert out_enc.shape == torch.Size([1, 512])
+
+
+def test_transformer_encoder():
+    tf_encoder = TFEncoder()
+    tf_encoder.init_weights()
+    tf_encoder.train()
+
+    feat = torch.randn(1, 512, 1, 25)
+    out_enc = tf_encoder(feat)
+    assert out_enc.shape == torch.Size([1, 512, 1, 25])
+
+
+def test_base_encoder():
+    encoder = BaseEncoder()
+    encoder.init_weights()
+    encoder.train()
+
+    feat = torch.randn(1, 256, 4, 40)
+    out_enc = encoder(feat)
+    assert out_enc.shape == torch.Size([1, 256, 4, 40])
diff --git a/tests/test_models/test_ocr_head.py b/tests/test_models/test_ocr_head.py
new file mode 100644
index 00000000..7df0f77b
--- /dev/null
+++ b/tests/test_models/test_ocr_head.py
@@ -0,0 +1,16 @@
+import pytest
+import torch
+
+from mmocr.models.textrecog import SegHead
+
+
+def test_seg_head():
+    with pytest.raises(AssertionError):
+        SegHead(num_classes='100')
+    with pytest.raises(AssertionError):
+        SegHead(num_classes=-1)
+
+    seg_head = SegHead(num_classes=37)
+    out_neck = (torch.rand(1, 128, 32, 32), )
+    out_head = seg_head(out_neck)
+    assert out_head.shape == torch.Size([1, 37, 32, 32])
diff --git a/tests/test_models/test_ocr_layer.py b/tests/test_models/test_ocr_layer.py
new file mode 100644
index 00000000..f6e3d718
--- /dev/null
+++ b/tests/test_models/test_ocr_layer.py
@@ -0,0 +1,56 @@
+import torch
+
+from mmocr.models.textrecog.layers import (BasicBlock, Bottleneck,
+                                           PositionalEncoding,
+                                           TransformerDecoderLayer,
+                                           get_pad_mask, get_subsequent_mask)
+from mmocr.models.textrecog.layers.conv_layer import conv3x3
+
+
+def test_conv_layer():
+    conv3by3 = conv3x3(3, 6)
+    assert conv3by3.in_channels == 3
+    assert conv3by3.out_channels == 6
+    assert conv3by3.kernel_size == (3, 3)
+
+    x = torch.rand(1, 64, 224, 224)
+    # test basic block
+    basic_block = BasicBlock(64, 64)
+    assert basic_block.expansion == 1
+
+    out = basic_block(x)
+
+    assert out.shape == torch.Size([1, 64, 224, 224])
+
+    # test bottleneck
+    bottle_neck = Bottleneck(64, 64, downsample=True)
+    assert bottle_neck.expansion == 4
+
+    out = bottle_neck(x)
+
+    assert out.shape == torch.Size([1, 256, 224, 224])
+
+
+def test_transformer_layer():
+    # test decoder_layer
+    decoder_layer = TransformerDecoderLayer()
+    in_dec = torch.rand(1, 30, 512)
+    out_enc = torch.rand(1, 128, 512)
+    out_dec = decoder_layer(in_dec, out_enc)
+    assert out_dec.shape == torch.Size([1, 30, 512])
+
+    # test positional_encoding
+    pos_encoder = PositionalEncoding()
+    x = torch.rand(1, 30, 512)
+    out = pos_encoder(x)
+    assert out.size() == x.size()
+
+    # test get_pad_mask
+    seq = torch.rand(1, 30)
+    pad_idx = 0
+    out = get_pad_mask(seq, pad_idx)
+    assert out.shape == torch.Size([1, 1, 30])
+
+    # test get_subsequent_mask
+    out_mask = get_subsequent_mask(seq)
+    assert out_mask.shape == torch.Size([1, 30, 30])
diff --git a/tests/test_models/test_ocr_loss.py b/tests/test_models/test_ocr_loss.py
new file mode 100644
index 00000000..51f84fb2
--- /dev/null
+++ b/tests/test_models/test_ocr_loss.py
@@ -0,0 +1,90 @@
+import pytest
+import torch
+
+from mmocr.models.common.losses import DiceLoss
+from mmocr.models.textrecog.losses import CELoss, CTCLoss, SARLoss, TFLoss
+
+
+def test_ctc_loss():
+    with pytest.raises(AssertionError):
+        CTCLoss(flatten='flatten')
+    with pytest.raises(AssertionError):
+        CTCLoss(blank=None)
+    with pytest.raises(AssertionError):
+        CTCLoss(reduction=1)
+    with pytest.raises(AssertionError):
+        CTCLoss(zero_infinity='zero')
+    # test CTCLoss
+    ctc_loss = CTCLoss()
+    outputs = torch.zeros(2, 40, 37)
+    targets_dict = {
+        'flatten_targets': torch.IntTensor([1, 2, 3, 4, 5]),
+        'target_lengths': torch.LongTensor([2, 3])
+    }
+
+    losses = ctc_loss(outputs, targets_dict)
+    assert isinstance(losses, dict)
+    assert 'loss_ctc' in losses
+    assert torch.allclose(losses['loss_ctc'],
+                          torch.tensor(losses['loss_ctc'].item()).float())
+
+
+def test_ce_loss():
+    with pytest.raises(AssertionError):
+        CELoss(ignore_index='ignore')
+    with pytest.raises(AssertionError):
+        CELoss(reduction=1)
+    with pytest.raises(AssertionError):
+        CELoss(reduction='avg')
+
+    ce_loss = CELoss(ignore_index=0)
+    outputs = torch.rand(1, 10, 37)
+    targets_dict = {
+        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
+    }
+    losses = ce_loss(outputs, targets_dict)
+    assert isinstance(losses, dict)
+    assert 'loss_ce' in losses
+    assert losses['loss_ce'].size(1) == 10
+
+
+def test_sar_loss():
+    outputs = torch.rand(1, 10, 37)
+    targets_dict = {
+        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
+    }
+    sar_loss = SARLoss()
+    new_output, new_target = sar_loss.format(outputs, targets_dict)
+    assert new_output.shape == torch.Size([1, 37, 9])
+    assert new_target.shape == torch.Size([1, 9])
+
+
+def test_tf_loss():
+    with pytest.raises(AssertionError):
+        TFLoss(flatten=1.0)
+
+    outputs = torch.rand(1, 10, 37)
+    targets_dict = {
+        'padded_targets': torch.LongTensor([[1,
2, 3, 4, 0, 0, 0, 0, 0, 0]]) + } + tf_loss = TFLoss(flatten=False) + new_output, new_target = tf_loss.format(outputs, targets_dict) + assert new_output.shape == torch.Size([1, 37, 9]) + assert new_target.shape == torch.Size([1, 9]) + + +def test_dice_loss(): + with pytest.raises(AssertionError): + DiceLoss(eps='1') + + dice_loss = DiceLoss() + pred = torch.rand(1, 1, 32, 32) + gt = torch.rand(1, 1, 32, 32) + + loss = dice_loss(pred, gt, None) + assert isinstance(loss, torch.Tensor) + + mask = torch.rand(1, 1, 1, 1) + loss = dice_loss(pred, gt, mask) + assert isinstance(loss, torch.Tensor) diff --git a/tests/test_models/test_ocr_neck.py b/tests/test_models/test_ocr_neck.py new file mode 100644 index 00000000..28009311 --- /dev/null +++ b/tests/test_models/test_ocr_neck.py @@ -0,0 +1,17 @@ +import torch + +from mmocr.models.textrecog.necks import FPNOCR + + +def test_fpn_ocr(): + in_s1 = torch.rand(1, 128, 32, 256) + in_s2 = torch.rand(1, 256, 16, 128) + in_s3 = torch.rand(1, 512, 8, 64) + in_s4 = torch.rand(1, 512, 4, 32) + + fpn_ocr = FPNOCR(in_channels=[128, 256, 512, 512], out_channels=256) + fpn_ocr.init_weights() + fpn_ocr.train() + + out_neck = fpn_ocr((in_s1, in_s2, in_s3, in_s4)) + assert out_neck[0].shape == torch.Size([1, 256, 32, 256]) diff --git a/tests/test_models/test_panhead.py b/tests/test_models/test_panhead.py new file mode 100644 index 00000000..5afbd6d1 --- /dev/null +++ b/tests/test_models/test_panhead.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest + +import mmocr.models.textdet.dense_heads.pan_head as pan_head + + +def test_panhead(): + in_channels = [128] + out_channels = 128 + text_repr_type = 'poly' # 'poly' or 'quad' + downsample_ratio = 0.25 + loss = dict(type='PANLoss') + + # test invalid arguments + with pytest.raises(AssertionError): + panheader = pan_head.PANHead(128, out_channels, text_repr_type, + downsample_ratio, loss) + with pytest.raises(AssertionError): + panheader = pan_head.PANHead(in_channels, [out_channels], + text_repr_type, downsample_ratio, loss) + with pytest.raises(AssertionError): + panheader = pan_head.PANHead(in_channels, out_channels, 'test', + text_repr_type, downsample_ratio, loss) + with pytest.raises(AssertionError): + panheader = pan_head.PANHead(in_channels, out_channels, 'test', + downsample_ratio, loss) + with pytest.raises(AssertionError): + panheader = pan_head.PANHead(in_channels, out_channels, text_repr_type, + 1.1, loss) + + panheader = pan_head.PANHead(in_channels, out_channels, text_repr_type, + downsample_ratio, loss) + + # test resize_boundary + boundaries = [[0, 0, 0, 1, 1, 1, 0, 1, 0.9], + [0, 0, 0, 0.1, 0.1, 0.1, 0, 0.1, 0.9]] + target_boundary = [[0, 0, 0, 0.5, 1, 0.5, 0, 0.5, 0.9], + [0, 0, 0, 0.05, 0.1, 0.05, 0, 0.05, 0.9]] + scale_factor = np.array([1, 0.5, 1, 0.5]) + resized_boundary = panheader.resize_boundary(boundaries, scale_factor) + assert np.allclose(resized_boundary, target_boundary) diff --git a/tests/test_models/test_recog_config.py b/tests/test_models/test_recog_config.py new file mode 100644 index 00000000..478743c1 --- /dev/null +++ b/tests/test_models/test_recog_config.py @@ -0,0 +1,147 @@ +import copy +from os.path import dirname, exists, join + +import numpy as np +import pytest +import torch + + +def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300), + num_items=None): # yapf: disable + """Create a superset of inputs needed to run test or train batches. + + Args: + input_shape (tuple): Input batch dimensions. 
+ + num_items (None | list[int]): Specifies the number of boxes + for each batch item. + """ + + (N, C, H, W) = input_shape + + rng = np.random.RandomState(0) + + imgs = rng.rand(*input_shape) + + img_metas = [{ + 'img_shape': (H, W, C), + 'ori_shape': (H, W, C), + 'pad_shape': (H, W, C), + 'filename': '.png', + 'text': 'hello', + 'valid_ratio': 1.0, + } for _ in range(N)] + + mm_inputs = { + 'imgs': torch.FloatTensor(imgs).requires_grad_(True), + 'img_metas': img_metas + } + return mm_inputs + + +def _demo_gt_kernel_inputs(num_kernels=3, input_shape=(1, 3, 300, 300), + num_items=None): # yapf: disable + """Create a superset of inputs needed to run test or train batches. + + Args: + input_shape (tuple): Input batch dimensions. + + num_items (None | list[int]): Specifies the number of boxes + for each batch item. + """ + from mmdet.core import BitmapMasks + + (N, C, H, W) = input_shape + gt_kernels = [] + + for batch_idx in range(N): + kernels = [] + for kernel_inx in range(num_kernels): + kernel = np.random.rand(H, W) + kernels.append(kernel) + gt_kernels.append(BitmapMasks(kernels, H, W)) + + return gt_kernels + + +def _get_config_directory(): + """Find the predefined detector config directory.""" + try: + # Assume we are running in the source mmocr repo + repo_dpath = dirname(dirname(dirname(__file__))) + except NameError: + # For IPython development when this __file__ is not defined + import mmocr + repo_dpath = dirname(dirname(mmocr.__file__)) + config_dpath = join(repo_dpath, 'configs') + if not exists(config_dpath): + raise Exception('Cannot find config path') + return config_dpath + + +def _get_config_module(fname): + """Load a configuration as a python module.""" + from mmcv import Config + config_dpath = _get_config_directory() + config_fpath = join(config_dpath, fname) + config_mod = Config.fromfile(config_fpath) + return config_mod + + +def _get_detector_cfg(fname): + """Grab configs necessary to create a detector. + + These are deep copied to allow for safe modification of parameters without + influencing other tests. 
+ """ + config = _get_config_module(fname) + model = copy.deepcopy(config.model) + return model + + +@pytest.mark.parametrize('cfg_file', [ + 'textrecog/sar/sar_r31_parallel_decoder_academic.py', + 'textrecog/crnn/crnn_academic_dataset.py', + 'textrecog/nrtr/nrtr_r31_academic.py', + 'textrecog/robust_scanner/robustscanner_r31_academic.py', + 'textrecog/seg/seg_r31_1by16_fpnocr_academic.py' +]) +def test_encoder_decoder_pipeline(cfg_file): + model = _get_detector_cfg(cfg_file) + model['pretrained'] = None + + from mmocr.models import build_detector + detector = build_detector(model) + + input_shape = (1, 3, 32, 160) + if 'crnn' in cfg_file: + input_shape = (1, 1, 32, 160) + mm_inputs = _demo_mm_inputs(0, input_shape) + gt_kernels = None + if 'seg' in cfg_file: + gt_kernels = _demo_gt_kernel_inputs(3, input_shape) + + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + + # Test forward train + if 'seg' in cfg_file: + losses = detector.forward(imgs, img_metas, gt_kernels=gt_kernels) + else: + losses = detector.forward(imgs, img_metas) + assert isinstance(losses, dict) + + # Test forward test + with torch.no_grad(): + img_list = [g[None, :] for g in imgs] + batch_results = [] + for one_img, one_meta in zip(img_list, img_metas): + result = detector.forward([one_img], [[one_meta]], + return_loss=False) + batch_results.append(result) + + # Test show_result + + results = {'text': 'hello', 'score': 1.0} + img = np.random.rand(5, 5, 3) + detector.show_result(img, results) diff --git a/tests/test_models/test_recognizer.py b/tests/test_models/test_recognizer.py new file mode 100644 index 00000000..d1081406 --- /dev/null +++ b/tests/test_models/test_recognizer.py @@ -0,0 +1,166 @@ +import os.path as osp +import tempfile + +import numpy as np +import pytest +import torch + +from mmdet.core import BitmapMasks +from mmocr.models.textrecog.recognizer import (EncodeDecodeRecognizer, + SegRecognizer) + + +def _create_dummy_dict_file(dict_file): + chars = list('helowrd') + with open(dict_file, 'w') as fw: + for char in chars: + fw.write(char + '\n') + + +def test_base_recognizer(): + tmp_dir = tempfile.TemporaryDirectory() + # create dummy data + dict_file = osp.join(tmp_dir.name, 'fake_chars.txt') + _create_dummy_dict_file(dict_file) + + label_convertor = dict( + type='CTCConvertor', dict_file=dict_file, with_unknown=False) + + preprocessor = None + backbone = dict(type='VeryDeepVgg', leakyRelu=False) + encoder = None + decoder = dict(type='CRNNDecoder', in_channels=512, rnn_flag=True) + loss = dict(type='CTCLoss') + + with pytest.raises(AssertionError): + EncodeDecodeRecognizer(backbone=None) + with pytest.raises(AssertionError): + EncodeDecodeRecognizer(decoder=None) + with pytest.raises(AssertionError): + EncodeDecodeRecognizer(loss=None) + with pytest.raises(AssertionError): + EncodeDecodeRecognizer(label_convertor=None) + + recognizer = EncodeDecodeRecognizer( + preprocessor=preprocessor, + backbone=backbone, + encoder=encoder, + decoder=decoder, + loss=loss, + label_convertor=label_convertor) + + recognizer.init_weights() + recognizer.train() + + imgs = torch.rand(1, 3, 32, 160) + + # test extract feat + feat = recognizer.extract_feat(imgs) + assert feat.shape == torch.Size([1, 512, 1, 41]) + + # test forward train + img_metas = [{'text': 'hello', 'valid_ratio': 1.0}] + losses = recognizer.forward_train(imgs, img_metas) + assert isinstance(losses, dict) + assert 'loss_ctc' in losses + + # test simple test + results = recognizer.simple_test(imgs, img_metas) + assert 
isinstance(results, list) + assert isinstance(results[0], dict) + assert 'text' in results[0] + assert 'score' in results[0] + + # test aug_test + aug_results = recognizer.aug_test([imgs, imgs], [img_metas, img_metas]) + assert isinstance(aug_results, list) + assert isinstance(aug_results[0], dict) + assert 'text' in aug_results[0] + assert 'score' in aug_results[0] + + tmp_dir.cleanup() + + +def test_seg_recognizer(): + tmp_dir = tempfile.TemporaryDirectory() + # create dummy data + dict_file = osp.join(tmp_dir.name, 'fake_chars.txt') + _create_dummy_dict_file(dict_file) + + label_convertor = dict( + type='SegConvertor', dict_file=dict_file, with_unknown=False) + + preprocessor = None + backbone = dict( + type='ResNet31OCR', + layers=[1, 2, 5, 3], + channels=[32, 64, 128, 256, 512, 512], + out_indices=[0, 1, 2, 3], + stage4_pool_cfg=dict(kernel_size=2, stride=2), + last_stage_pool=True) + neck = dict( + type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256) + head = dict( + type='SegHead', + in_channels=256, + upsample_param=dict(scale_factor=2.0, mode='nearest')) + loss = dict(type='SegLoss', seg_downsample_ratio=1.0) + + with pytest.raises(AssertionError): + SegRecognizer(backbone=None) + with pytest.raises(AssertionError): + SegRecognizer(neck=None) + with pytest.raises(AssertionError): + SegRecognizer(head=None) + with pytest.raises(AssertionError): + SegRecognizer(loss=None) + with pytest.raises(AssertionError): + SegRecognizer(label_convertor=None) + + recognizer = SegRecognizer( + preprocessor=preprocessor, + backbone=backbone, + neck=neck, + head=head, + loss=loss, + label_convertor=label_convertor) + + recognizer.init_weights() + recognizer.train() + + imgs = torch.rand(1, 3, 64, 256) + + # test extract feat + feats = recognizer.extract_feat(imgs) + assert len(feats) == 4 + + assert feats[0].shape == torch.Size([1, 128, 32, 128]) + assert feats[1].shape == torch.Size([1, 256, 16, 64]) + assert feats[2].shape == torch.Size([1, 512, 8, 32]) + assert feats[3].shape == torch.Size([1, 512, 4, 16]) + + attn_tgt = np.zeros((64, 256), dtype=np.float32) + segm_tgt = np.zeros((64, 256), dtype=np.float32) + mask = np.zeros((64, 256), dtype=np.float32) + gt_kernels = BitmapMasks([attn_tgt, segm_tgt, mask], 64, 256) + + # test forward train + img_metas = [{'text': 'hello', 'valid_ratio': 1.0}] + losses = recognizer.forward_train(imgs, img_metas, gt_kernels=[gt_kernels]) + assert isinstance(losses, dict) + + # test simple test + results = recognizer.simple_test(imgs, img_metas) + assert isinstance(results, list) + assert isinstance(results[0], dict) + assert 'text' in results[0] + assert 'score' in results[0] + + # test aug_test + aug_results = recognizer.aug_test([imgs, imgs], [img_metas, img_metas]) + assert isinstance(aug_results, list) + assert isinstance(aug_results[0], dict) + assert 'text' in aug_results[0] + assert 'score' in aug_results[0] + + tmp_dir.cleanup() diff --git a/tests/test_models/test_targets.py b/tests/test_models/test_targets.py new file mode 100644 index 00000000..d270a53a --- /dev/null +++ b/tests/test_models/test_targets.py @@ -0,0 +1,32 @@ +import numpy as np + +from mmocr.datasets.pipelines.textdet_targets.dbnet_targets import DBNetTargets + + +def test_invalid_polys(): + + dbtarget = DBNetTargets() + + poly = np.array([[256.1229216, 347.17471155], [257.63126133, 347.0069367], + [257.70317729, 347.65337423], + [256.19488113, 347.82114909]]) + + assert dbtarget.invalid_polygon(poly) + + poly = np.array([[570.34735492, + 335.00214526], [570.99778839, 
335.00327318], + [569.69077318, 338.47009908], + [569.04038393, 338.46894904]]) + assert dbtarget.invalid_polygon(poly) + + poly = np.array([[481.18343777, + 305.03190065], [479.88478587, 305.10684512], + [479.90976971, 305.53968843], [480.99197962, + 305.4772347]]) + assert dbtarget.invalid_polygon(poly) + + poly = np.array([[0, 0], [2, 0], [2, 2], [0, 2]]) + assert dbtarget.invalid_polygon(poly) + + poly = np.array([[0, 0], [10, 0], [10, 10], [0, 10]]) + assert not dbtarget.invalid_polygon(poly) diff --git a/tests/test_models/test_textdet_neck.py b/tests/test_models/test_textdet_neck.py new file mode 100644 index 00000000..3a434a83 --- /dev/null +++ b/tests/test_models/test_textdet_neck.py @@ -0,0 +1,50 @@ +import pytest +import torch + +from mmocr.models.textdet.necks import FPN_UNET, FPNC + + +def test_fpnc(): + + in_channels = [64, 128, 256, 512] + size = [112, 56, 28, 14] + for flag in [False, True]: + fpnc = FPNC( + in_channels=in_channels, + bias_on_lateral=flag, + bn_re_on_lateral=flag, + bias_on_smooth=flag, + bn_re_on_smooth=flag, + conv_after_concat=flag) + fpnc.init_weights() + inputs = [] + for i in range(4): + inputs.append(torch.rand(1, in_channels[i], size[i], size[i])) + outputs = fpnc.forward(inputs) + assert list(outputs.size()) == [1, 256, 112, 112] + + +def test_fpn_unet_neck(): + s = 64 + feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8] + in_channels = [8, 16, 32, 64] + out_channels = 4 + + # len(in_channels) is not equal to 4 + with pytest.raises(AssertionError): + FPN_UNET(in_channels + [128], out_channels) + + # `out_channels` is not int type + with pytest.raises(AssertionError): + FPN_UNET(in_channels, [2, 4]) + + feats = [ + torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i]) + for i in range(len(in_channels)) + ] + + fpn_unet_neck = FPN_UNET(in_channels, out_channels) + fpn_unet_neck.init_weights() + + out_neck = fpn_unet_neck(feats) + assert out_neck.shape == torch.Size([1, out_channels, s * 4, s * 4]) diff --git a/tests/test_tools/test_data_converter.py b/tests/test_tools/test_data_converter.py new file mode 100644 index 00000000..3677c91b --- /dev/null +++ b/tests/test_tools/test_data_converter.py @@ -0,0 +1,18 @@ +"""Test orientation check and ignore method.""" + +import shutil +import tempfile + +from mmocr.utils import drop_orientation + + +def test_drop_orientation(): + img_file = 'tests/data/test_img2.jpg' + output_file = drop_orientation(img_file) + assert output_file is img_file + + img_file = 'tests/data/test_img1.jpg' + tmp_dir = tempfile.TemporaryDirectory() + dst_file = shutil.copy(img_file, tmp_dir.name) + output_file = drop_orientation(dst_file) + assert output_file[-3:] == 'png' diff --git a/tests/test_utils/test_check_argument.py b/tests/test_utils/test_check_argument.py new file mode 100644 index 00000000..fd247353 --- /dev/null +++ b/tests/test_utils/test_check_argument.py @@ -0,0 +1,53 @@ +import numpy as np + +import mmocr.utils as utils + + +def test_is_3dlist(): + + assert utils.is_3dlist([]) + assert utils.is_3dlist([[]]) + assert utils.is_3dlist([[[]]]) + assert utils.is_3dlist([[[1]]]) + assert not utils.is_3dlist([[1, 2]]) + assert not utils.is_3dlist([[np.array([1, 2])]]) + + +def test_is_2dlist(): + + assert utils.is_2dlist([]) + assert utils.is_2dlist([[]]) + assert utils.is_2dlist([[1]]) + + +def test_is_ndarray_list(): + assert utils.is_ndarray_list([]) + assert utils.is_ndarray_list([np.ndarray([1])]) + assert not utils.is_ndarray_list([1]) + + +def test_is_type_list(): + assert utils.is_type_list([], int)
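+ # an empty list contains no element of the wrong type, so the two checks on [] above and below hold vacuously for any type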
+ assert utils.is_type_list([], float) + assert utils.is_type_list([np.array([])], np.ndarray) + assert utils.is_type_list([1], int) + assert utils.is_type_list(['str'], str) + + +def test_is_none_or_type(): + + assert utils.is_none_or_type(None, int) + assert utils.is_none_or_type(1.0, float) + assert utils.is_none_or_type(np.ndarray([]), np.ndarray) + assert utils.is_none_or_type(1, int) + assert utils.is_none_or_type('str', str) + + +def test_valid_boundary(): + + x = [0, 0, 1, 0, 1, 1, 0, 1] + assert not utils.valid_boundary(x, True) + assert not utils.valid_boundary([0]) + assert utils.valid_boundary(x, False) + x = [0, 0, 1, 0, 1, 1, 0, 1, 1] + assert utils.valid_boundary(x, True) diff --git a/tests/test_utils/test_mask/test_mask_utils.py b/tests/test_utils/test_mask/test_mask_utils.py new file mode 100644 index 00000000..a1f2b0dd --- /dev/null +++ b/tests/test_utils/test_mask/test_mask_utils.py @@ -0,0 +1,197 @@ +"""Test text mask_utils.""" +import tempfile +from unittest import mock + +import numpy as np +import pytest + +import mmocr.core.evaluation.utils as eval_utils +import mmocr.core.mask as mask_utils +import mmocr.core.visualize as visualize_utils + + +def test_points2boundary(): + + points = np.array([[1, 2]]) + text_repr_type = 'quad' + text_score = None + + # test invalid arguments + with pytest.raises(AssertionError): + mask_utils.points2boundary([], text_repr_type, text_score) + + with pytest.raises(AssertionError): + mask_utils.points2boundary(points, '', text_score) + with pytest.raises(AssertionError): + mask_utils.points2boundary(points, '', 1.1) + + # test quad + points = np.array([[0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1], [0, 2], + [1, 2], [2, 2]]) + text_repr_type = 'quad' + text_score = None + + result = mask_utils.points2boundary(points, text_repr_type, text_score) + pred_poly = eval_utils.points2polygon(result) + target_poly = eval_utils.points2polygon([2, 2, 0, 2, 0, 0, 2, 0]) + assert eval_utils.poly_iou(pred_poly, target_poly) == 1 + + # test poly + text_repr_type = 'poly' + result = mask_utils.points2boundary(points, text_repr_type, text_score) + pred_poly = eval_utils.points2polygon(result) + target_poly = eval_utils.points2polygon([0, 0, 0, 2, 2, 2, 2, 0]) + assert eval_utils.poly_iou(pred_poly, target_poly) == 1 + + +def test_seg2boundary(): + + seg = np.array([[]]) + text_repr_type = 'quad' + text_score = None + # test invalid arguments + with pytest.raises(AssertionError): + mask_utils.seg2boundary([[]], text_repr_type, text_score) + with pytest.raises(AssertionError): + mask_utils.seg2boundary(seg, 1, text_score) + with pytest.raises(AssertionError): + mask_utils.seg2boundary(seg, text_repr_type, 1.1) + + seg = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) + result = mask_utils.seg2boundary(seg, text_repr_type, text_score) + pred_poly = eval_utils.points2polygon(result) + target_poly = eval_utils.points2polygon([2, 2, 0, 2, 0, 0, 2, 0]) + assert eval_utils.poly_iou(pred_poly, target_poly) == 1 + + +@mock.patch('%s.visualize_utils.plt' % __name__) +def test_show_feature(mock_plt): + + features = [np.random.rand(10, 10)] + names = ['test'] + to_uint8 = [0] + out_file = None + + # test invalid arguments + with pytest.raises(AssertionError): + visualize_utils.show_feature([], names, to_uint8, out_file) + with pytest.raises(AssertionError): + visualize_utils.show_feature(features, [1], to_uint8, out_file) + with pytest.raises(AssertionError): + visualize_utils.show_feature(features, names, ['a'], out_file) + with pytest.raises(AssertionError): + 
visualize_utils.show_feature(features, names, to_uint8, 1) + with pytest.raises(AssertionError): + visualize_utils.show_feature(features, names, to_uint8, [0, 1]) + + visualize_utils.show_feature(features, names, to_uint8) + + # test showing img + mock_plt.title.assert_called_once_with('test') + mock_plt.show.assert_called_once() + + # test saving fig + out_file = tempfile.NamedTemporaryFile().name + visualize_utils.show_feature(features, names, to_uint8, out_file) + mock_plt.savefig.assert_called_once() + + +@mock.patch('%s.visualize_utils.plt' % __name__) +def test_show_img_boundary(mock_plt): + img = np.random.rand(10, 10) + boundary = [0, 0, 1, 0, 1, 1, 0, 1] + # test invalid arguments + with pytest.raises(AssertionError): + visualize_utils.show_img_boundary([], boundary) + with pytest.raises(AssertionError): + visualize_utils.show_img_boundary(img, np.array([])) + + # test showing img + + visualize_utils.show_img_boundary(img, boundary) + mock_plt.imshow.assert_called_once() + mock_plt.show.assert_called_once() + + +@mock.patch('%s.visualize_utils.mmcv' % __name__) +def test_show_pred_gt(mock_mmcv): + preds = [[0, 0, 1, 0, 1, 1, 0, 1]] + gts = [[0, 0, 1, 0, 1, 1, 0, 1]] + show = True + win_name = 'test' + wait_time = 0 + out_file = tempfile.NamedTemporaryFile().name + + with pytest.raises(AssertionError): + visualize_utils.show_pred_gt(np.array([]), gts) + with pytest.raises(AssertionError): + visualize_utils.show_pred_gt(preds, np.array([])) + + # test showing img + + visualize_utils.show_pred_gt(preds, gts, show, win_name, wait_time, + out_file) + mock_mmcv.imshow.assert_called_once() + mock_mmcv.imwrite.assert_called_once() + + +@mock.patch('%s.visualize_utils.mmcv.imshow' % __name__) +@mock.patch('%s.visualize_utils.mmcv.imwrite' % __name__) +def test_imshow_pred_boundary(mock_imwrite, mock_imshow): + img = './tests/data/test_img1.jpg' + boundaries_with_scores = [[0, 0, 1, 0, 1, 1, 0, 1, 1]] + labels = [1] + file = tempfile.NamedTemporaryFile().name + visualize_utils.imshow_pred_boundary( + img, boundaries_with_scores, labels, show=True, out_file=file) + mock_imwrite.assert_called_once() + mock_imshow.assert_called_once() + + +@mock.patch('%s.visualize_utils.mmcv.imshow' % __name__) +@mock.patch('%s.visualize_utils.mmcv.imwrite' % __name__) +def test_imshow_text_char_boundary(mock_imwrite, mock_imshow): + + img = './tests/data/test_img1.jpg' + text_quads = [[0, 0, 1, 0, 1, 1, 0, 1]] + boundaries = [[0, 0, 1, 0, 1, 1, 0, 1]] + char_quads = [[[0, 0, 1, 0, 1, 1, 0, 1], [0, 0, 1, 0, 1, 1, 0, 1]]] + chars = [['a', 'b']] + show = True + out_file = tempfile.NamedTemporaryFile().name + visualize_utils.imshow_text_char_boundary( + img, + text_quads, + boundaries, + char_quads, + chars, + show=show, + out_file=out_file) + mock_imwrite.assert_called_once() + mock_imshow.assert_called_once() + + +@mock.patch('%s.visualize_utils.cv2.drawContours' % __name__) +def test_overlay_mask_img(mock_drawContours): + + img = np.random.rand(10, 10) + mask = np.zeros((10, 10)) + visualize_utils.overlay_mask_img(img, mask) + mock_drawContours.assert_called_once() + + +def test_extract_boundary(): + result = {} + + # test invalid arguments + with pytest.raises(AssertionError): + mask_utils.extract_boundary(result) + + result = {'boundary_result': [0, 1]} + with pytest.raises(AssertionError): + mask_utils.extract_boundary(result) + + result = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 1]]} + + output = mask_utils.extract_boundary(result) + assert output[2] == [1] diff --git
a/tests/test_utils/test_text/test_text_utils.py b/tests/test_utils/test_text/test_text_utils.py new file mode 100644 index 00000000..a1d4a50f --- /dev/null +++ b/tests/test_utils/test_text/test_text_utils.py @@ -0,0 +1,66 @@ +"""Test text label visualization.""" +import os.path as osp +import random +import tempfile +from unittest import mock + +import numpy as np +import pytest + +import mmocr.core.visualize as visualize_utils + + +def test_tile_image(): + dummy_imgs, heights, widths = [], [], [] + for _ in range(3): + h = random.randint(100, 300) + w = random.randint(100, 300) + heights.append(h) + widths.append(w) + dummy_img = np.ones((h, w, 3), dtype=np.uint8) + dummy_imgs.append(dummy_img) + joint_img = visualize_utils.tile_image(dummy_imgs) + assert joint_img.shape[0] == sum(heights) + assert joint_img.shape[1] == max(widths) + + # test invalid arguments + with pytest.raises(AssertionError): + visualize_utils.tile_image(dummy_imgs[0]) + with pytest.raises(AssertionError): + visualize_utils.tile_image([]) + + +@mock.patch('%s.visualize_utils.mmcv.imread' % __name__) +@mock.patch('%s.visualize_utils.mmcv.imshow' % __name__) +@mock.patch('%s.visualize_utils.mmcv.imwrite' % __name__) +def test_show_text_label(mock_imwrite, mock_imshow, mock_imread): + img = np.ones((32, 160), dtype=np.uint8) + pred_label = 'hello' + gt_label = 'world' + + tmp_dir = tempfile.TemporaryDirectory() + out_file = osp.join(tmp_dir.name, 'tmp.jpg') + + # test invalid arguments + with pytest.raises(AssertionError): + visualize_utils.imshow_text_label(5, pred_label, gt_label) + with pytest.raises(AssertionError): + visualize_utils.imshow_text_label(img, pred_label, 4) + with pytest.raises(AssertionError): + visualize_utils.imshow_text_label(img, 3, gt_label) + with pytest.raises(AssertionError): + visualize_utils.imshow_text_label( + img, pred_label, gt_label, show=True, wait_time=0.1) + + mock_imread.side_effect = [img, img] + visualize_utils.imshow_text_label( + img, pred_label, gt_label, out_file=out_file) + visualize_utils.imshow_text_label( + img, pred_label, gt_label, out_file=None, show=True) + + # test showing img + mock_imshow.assert_called_once() + mock_imwrite.assert_called_once() + + tmp_dir.cleanup() diff --git a/tests/test_utils/test_wrapper.py b/tests/test_utils/test_wrapper.py new file mode 100644 index 00000000..b8083e03 --- /dev/null +++ b/tests/test_utils/test_wrapper.py @@ -0,0 +1,14 @@ +import numpy as np +import torch + + +def test_db_boxes_from_bitmaps(): + """Test the boxes_from_bitmaps function in db_decoder.""" + from mmocr.models.textdet.postprocess.wrapper import db_decode + pred = np.array([[[0.8, 0.8, 0.8, 0.8, 0], [0.8, 0.8, 0.8, 0.8, 0], + [0.8, 0.8, 0.8, 0.8, 0], [0.8, 0.8, 0.8, 0.8, 0], + [0.8, 0.8, 0.8, 0.8, 0]]]) + preds = torch.FloatTensor(pred).requires_grad_(True) + + boxes = db_decode(preds, text_repr_type='quad', min_text_width=0) + assert len(boxes) == 1 diff --git a/tools/data/textdet/coco_to_line_dict.py b/tools/data/textdet/coco_to_line_dict.py new file mode 100644 index 00000000..44b14505 --- /dev/null +++ b/tools/data/textdet/coco_to_line_dict.py @@ -0,0 +1,70 @@ +import argparse +import codecs +import json + + +def read_json(fpath): + with codecs.open(fpath, 'r', 'utf-8') as f: + obj = json.load(f) + return obj + + +def parse_coco_json(in_path): + json_obj = read_json(in_path) + image_infos = json_obj['images'] + annotations = json_obj['annotations'] + imgid2imgname = {} + img_ids = [] + for image_info in
image_infos: + imgid2imgname[image_info['id']] = image_info + img_ids.append(image_info['id']) + imgid2anno = {} + for img_id in img_ids: + imgid2anno[img_id] = [] + for anno in annotations: + img_id = anno['image_id'] + new_anno = {} + new_anno['iscrowd'] = anno['iscrowd'] + new_anno['category_id'] = anno['category_id'] + new_anno['bbox'] = anno['bbox'] + new_anno['segmentation'] = anno['segmentation'] + if img_id in imgid2anno.keys(): + imgid2anno[img_id].append(new_anno) + + return imgid2imgname, imgid2anno + + +def gen_line_dict_file(out_path, imgid2imgname, imgid2anno): + with codecs.open(out_path, 'w', 'utf-8') as fw: + for key, value in imgid2imgname.items(): + if key in imgid2anno: + anno = imgid2anno[key] + line_dict = {} + line_dict['file_name'] = value['file_name'] + line_dict['height'] = value['height'] + line_dict['width'] = value['width'] + line_dict['annotations'] = anno + line_dict_str = json.dumps(line_dict) + fw.write(line_dict_str + '\n') + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--in-path', help='input json path with coco format') + parser.add_argument( + '--out-path', help='output txt path with line-json format') + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + imgid2imgname, imgid2anno = parse_coco_json(args.in_path) + gen_line_dict_file(args.out_path, imgid2imgname, imgid2anno) + print('finish') + + +if __name__ == '__main__': + main() diff --git a/tools/data/textdet/ctw1500_converter.py b/tools/data/textdet/ctw1500_converter.py new file mode 100644 index 00000000..aad9b3fd --- /dev/null +++ b/tools/data/textdet/ctw1500_converter.py @@ -0,0 +1,239 @@ +import argparse +import glob +import os.path as osp +import xml.etree.ElementTree as ET +from functools import partial + +import mmcv +import numpy as np +from shapely.geometry import Polygon +from tools.data.utils.common import convert_annotations, is_not_png + +from mmocr.utils import drop_orientation + + +def collect_files(img_dir, gt_dir, split): + """Collect all images and their corresponding groundtruth files. + + Args: + img_dir(str): The image directory + gt_dir(str): The groundtruth directory + split(str): The split of dataset. Namely: training or test + + Returns: + files(list): The list of tuples (img_file, groundtruth_file) + """ + assert isinstance(img_dir, str) + assert img_dir + assert isinstance(gt_dir, str) + assert gt_dir + + # note that we handle png and jpg only. Please convert others such as gif + # to jpg or png offline + suffixes = ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG'] + + imgs_list = [] + for suffix in suffixes: + imgs_list.extend(glob.glob(osp.join(img_dir, '*' + suffix))) + + imgs_list = [ + drop_orientation(f) if is_not_png(f) else f for f in imgs_list + ] + + files = [] + if split == 'training': + for img_file in imgs_list: + gt_file = gt_dir + '/' + osp.splitext( + osp.basename(img_file))[0] + '.xml' + files.append((img_file, gt_file)) + assert len(files), f'No images found in {img_dir}' + print(f'Loaded {len(files)} images from {img_dir}') + elif split == 'test': + for img_file in imgs_list: + gt_file = gt_dir + '/000' + osp.splitext( + osp.basename(img_file))[0] + '.txt' + files.append((img_file, gt_file)) + assert len(files), f'No images found in {img_dir}' + print(f'Loaded {len(files)} images from {img_dir}') + + return files + + +def collect_annotations(files, split, nproc=1): + """Collect the annotation information.
+ + Args: + files(list): The list of tuples (image_file, groundtruth_file) + split(str): The split of dataset. Namely: training or test + nproc(int): The number of processes to collect annotations + + Returns: + images(list): The list of image information dicts + """ + assert isinstance(files, list) + assert isinstance(split, str) + assert isinstance(nproc, int) + + load_img_info_with_split = partial(load_img_info, split=split) + if nproc > 1: + images = mmcv.track_parallel_progress( + load_img_info_with_split, files, nproc=nproc) + else: + images = mmcv.track_progress(load_img_info_with_split, files) + + return images + + +def load_txt_info(gt_file, img_info): + with open(gt_file) as f: + gt_list = f.readlines() + + anno_info = [] + for line in gt_list: + # each line has one polygon (n vertices), and one text. + # e.g., 695,885,866,888,867,1146,696,1143,####Latin 9 + line = line.strip() + strs = line.split(',') + category_id = 1 + assert strs[28][0] == '#' + xy = [int(x) for x in strs[0:28]] + assert len(xy) == 28 + coordinates = np.array(xy).reshape(-1, 2) + polygon = Polygon(coordinates) + iscrowd = 0 + area = polygon.area + # convert to COCO style XYWH format + min_x, min_y, max_x, max_y = polygon.bounds + bbox = [min_x, min_y, max_x - min_x, max_y - min_y] + + anno = dict( + iscrowd=iscrowd, + category_id=category_id, + bbox=bbox, + area=area, + segmentation=[xy]) + anno_info.append(anno) + img_info.update(anno_info=anno_info) + return img_info + + +def load_xml_info(gt_file, img_info): + + obj = ET.parse(gt_file) + anno_info = [] + for image in obj.getroot(): # image + for box in image: # box + h = box.attrib['height'] + w = box.attrib['width'] + x = box.attrib['left'] + y = box.attrib['top'] + # label = box[0].text + segs = box[1].text + pts = segs.strip().split(',') + pts = [int(x) for x in pts] + assert len(pts) == 28 + iscrowd = 0 + category_id = 1 + bbox = [int(x), int(y), int(w), int(h)] + + coordinates = np.array(pts).reshape(-1, 2) + polygon = Polygon(coordinates) + area = polygon.area + anno = dict( + iscrowd=iscrowd, + category_id=category_id, + bbox=bbox, + area=area, + segmentation=[pts]) + anno_info.append(anno) + + img_info.update(anno_info=anno_info) + + return img_info + + +def load_img_info(files, split): + """Load the information of one image. + + Args: + files(tuple): The tuple of (img_file, groundtruth_file) + split(str): The split of dataset: training or test + + Returns: + img_info(dict): The dict of the img and annotation information + """ + assert isinstance(files, tuple) + assert isinstance(split, str) + + img_file, gt_file = files + # read imgs with ignoring orientations + img = mmcv.imread(img_file, 'unchanged') + # read imgs with orientations as dataloader does when training and testing + img_color = mmcv.imread(img_file, 'color') + # make sure imgs have no orientation info, or annotation gt is wrong.
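+ # if the two shapes differ, the file still carries an EXIF orientation tag: the 'color' read applies the rotation while the 'unchanged' read does not, so the polygon annotations would be misaligned with the image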
+ assert img.shape[0:2] == img_color.shape[0:2] + + split_name = osp.basename(osp.dirname(img_file)) + img_info = dict( + # remove img_prefix for filename + file_name=osp.join(split_name, osp.basename(img_file)), + height=img.shape[0], + width=img.shape[1], + # anno_info=anno_info, + segm_file=osp.join(split_name, osp.basename(gt_file))) + + if split == 'training': + img_info = load_xml_info(gt_file, img_info) + elif split == 'test': + img_info = load_txt_info(gt_file, img_info) + else: + raise NotImplementedError + + return img_info + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert ctw1500 annotations to COCO format') + parser.add_argument('root_path', help='ctw1500 root path') + parser.add_argument('-o', '--out-dir', help='output path') + parser.add_argument( + '--split-list', + nargs='+', + help='a list of splits. e.g., "--split-list training test"') + + parser.add_argument( + '--nproc', default=1, type=int, help='number of processes') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + root_path = args.root_path + out_dir = args.out_dir if args.out_dir else root_path + mmcv.mkdir_or_exist(out_dir) + + img_dir = osp.join(root_path, 'imgs') + gt_dir = osp.join(root_path, 'annotations') + + set_name = {} + for split in args.split_list: + set_name.update({split: 'instances_' + split + '.json'}) + assert osp.exists(osp.join(img_dir, split)) + + for split, json_name in set_name.items(): + print(f'Converting {split} into {json_name}') + with mmcv.Timer( + print_tmpl='It takes {}s to convert ctw1500 annotation'): + files = collect_files( + osp.join(img_dir, split), osp.join(gt_dir, split), split) + image_infos = collect_annotations(files, split, nproc=args.nproc) + convert_annotations(image_infos, osp.join(out_dir, json_name)) + + +if __name__ == '__main__': + main() diff --git a/tools/data/textdet/icdar_converter.py b/tools/data/textdet/icdar_converter.py new file mode 100644 index 00000000..1744e103 --- /dev/null +++ b/tools/data/textdet/icdar_converter.py @@ -0,0 +1,192 @@ +import argparse +import glob +import os.path as osp +from functools import partial + +import mmcv +import numpy as np +from shapely.geometry import Polygon +from tools.data.utils.common import convert_annotations, is_not_png + +from mmocr.utils import drop_orientation + + +def collect_files(img_dir, gt_dir): + """Collect all images and their corresponding groundtruth files. + + Args: + img_dir(str): The image directory + gt_dir(str): The groundtruth directory + + Returns: + files(list): The list of tuples (img_file, groundtruth_file) + """ + assert isinstance(img_dir, str) + assert img_dir + assert isinstance(gt_dir, str) + assert gt_dir + + # note that we handle png and jpg only. Please convert others such as gif + # to jpg or png offline + suffixes = ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG'] + imgs_list = [] + for suffix in suffixes: + imgs_list.extend(glob.glob(osp.join(img_dir, '*' + suffix))) + + imgs_list = [ + drop_orientation(f) if is_not_png(f) else f for f in imgs_list + ] + + files = [] + for img_file in imgs_list: + gt_file = gt_dir + '/gt_' + osp.splitext( + osp.basename(img_file))[0] + '.txt' + files.append((img_file, gt_file)) + assert len(files), f'No images found in {img_dir}' + print(f'Loaded {len(files)} images from {img_dir}') + + return files + + +def collect_annotations(files, dataset, nproc=1): + """Collect the annotation information.
+ + Args: + files(list): The list of tuples (image_file, groundtruth_file) + dataset(str): The dataset name, icdar2015 or icdar2017 + nproc(int): The number of processes to collect annotations + + Returns: + images(list): The list of image information dicts + """ + assert isinstance(files, list) + assert isinstance(dataset, str) + assert dataset + assert isinstance(nproc, int) + + load_img_info_with_dataset = partial(load_img_info, dataset=dataset) + if nproc > 1: + images = mmcv.track_parallel_progress( + load_img_info_with_dataset, files, nproc=nproc) + else: + images = mmcv.track_progress(load_img_info_with_dataset, files) + + return images + + +def load_img_info(files, dataset): + """Load the information of one image. + + Args: + files(tuple): The tuple of (img_file, groundtruth_file) + dataset(str): Dataset name, icdar2015 or icdar2017 + + Returns: + img_info(dict): The dict of the img and annotation information + """ + assert isinstance(files, tuple) + assert isinstance(dataset, str) + assert dataset + + img_file, gt_file = files + # read imgs with ignoring orientations + img = mmcv.imread(img_file, 'unchanged') + # read imgs with orientations as dataloader does when training and testing + img_color = mmcv.imread(img_file, 'color') + # make sure imgs have no orientations info, or annotation gt is wrong. + assert img.shape[0:2] == img_color.shape[0:2] + + if dataset == 'icdar2017': + with open(gt_file) as f: + gt_list = f.readlines() + elif dataset == 'icdar2015': + with open(gt_file, mode='r', encoding='utf-8-sig') as f: + gt_list = f.readlines() + else: + raise NotImplementedError(f'Dataset {dataset} is not supported') + + anno_info = [] + for line in gt_list: + # each line has one polygon (4 vertices), and others. + # e.g., 695,885,866,888,867,1146,696,1143,Latin,9 + line = line.strip() + strs = line.split(',') + category_id = 1 + xy = [int(x) for x in strs[0:8]] + coordinates = np.array(xy).reshape(-1, 2) + polygon = Polygon(coordinates) + iscrowd = 0 + # set iscrowd to 1 to ignore it. + if (dataset == 'icdar2015' + and strs[8] == '###') or (dataset == 'icdar2017' + and strs[9] == '###'): + iscrowd = 1 + print('ignore text') + + area = polygon.area + # convert to COCO style XYWH format + min_x, min_y, max_x, max_y = polygon.bounds + bbox = [min_x, min_y, max_x - min_x, max_y - min_y] + + anno = dict( + iscrowd=iscrowd, + category_id=category_id, + bbox=bbox, + area=area, + segmentation=[xy]) + anno_info.append(anno) + split_name = osp.basename(osp.dirname(img_file)) + img_info = dict( + # remove img_prefix for filename + file_name=osp.join(split_name, osp.basename(img_file)), + height=img.shape[0], + width=img.shape[1], + anno_info=anno_info, + segm_file=osp.join(split_name, osp.basename(gt_file))) + return img_info + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert Icdar2015 or Icdar2017 annotations to COCO format' + ) + parser.add_argument('icdar_path', help='icdar root path') + parser.add_argument('-o', '--out-dir', help='output path') + parser.add_argument('-d', '--dataset', help='icdar2017 or icdar2015') + parser.add_argument( + '--split-list', + nargs='+', + help='a list of splits.
e.g., "--split-list training validation test"') + + parser.add_argument( + '--nproc', default=1, type=int, help='number of process') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + icdar_path = args.icdar_path + out_dir = args.out_dir if args.out_dir else icdar_path + mmcv.mkdir_or_exist(out_dir) + + img_dir = osp.join(icdar_path, 'imgs') + gt_dir = osp.join(icdar_path, 'annotations') + + set_name = {} + for split in args.split_list: + set_name.update({split: 'instances_' + split + '.json'}) + assert osp.exists(osp.join(img_dir, split)) + + for split, json_name in set_name.items(): + print(f'Converting {split} into {json_name}') + with mmcv.Timer(print_tmpl='It takes {}s to convert icdar annotation'): + files = collect_files( + osp.join(img_dir, split), osp.join(gt_dir, split)) + image_infos = collect_annotations( + files, args.dataset, nproc=args.nproc) + convert_annotations(image_infos, osp.join(out_dir, json_name)) + + +if __name__ == '__main__': + main() diff --git a/tools/data/textdet/synthtext_converter.py b/tools/data/textdet/synthtext_converter.py new file mode 100644 index 00000000..31ad7f02 --- /dev/null +++ b/tools/data/textdet/synthtext_converter.py @@ -0,0 +1,178 @@ +import argparse +import json +import os.path as osp +import time + +import lmdb +import mmcv +import numpy as np +from scipy.io import loadmat +from shapely.geometry import Polygon + +from mmocr.utils import check_argument + + +def trace_boundary(char_boxes): + """Trace the boundary point of text. + + Args: + char_boxes (list[ndarray]): The char boxes for one text. Each element + is 4x2 ndarray. + + Returns: + boundary (ndarray): The boundary point sets with size nx2. + """ + assert check_argument.is_type_list(char_boxes, np.ndarray) + + # from top left to to right + p_top = [box[0:2] for box in char_boxes] + # from bottom right to bottom left + p_bottom = [ + char_boxes[inx][[2, 3], :] + for inx in range(len(char_boxes) - 1, -1, -1) + ] + + p = p_top + p_bottom + + boundary = np.concatenate(p).astype(int) + + return boundary + + +def match_bbox_char_str(bboxes, char_bboxes, strs): + """match the bboxes, char bboxes, and strs. + + Args: + bboxes (ndarray): The text boxes of size (2, 4, num_box). + char_bboxes (ndarray): The char boxes of size (2, 4, num_char_box). 
+ strs (ndarray): The string of size (num_strs,) + """ + assert isinstance(bboxes, np.ndarray) + assert isinstance(char_bboxes, np.ndarray) + assert isinstance(strs, np.ndarray) + bboxes = bboxes.astype(np.int32) + char_bboxes = char_bboxes.astype(np.int32) + + if len(char_bboxes.shape) == 2: + char_bboxes = np.expand_dims(char_bboxes, axis=2) + char_bboxes = np.transpose(char_bboxes, (2, 1, 0)) + if len(bboxes.shape) == 2: + bboxes = np.expand_dims(bboxes, axis=2) + bboxes = np.transpose(bboxes, (2, 1, 0)) + chars = ''.join(strs).replace('\n', '').replace(' ', '') + num_boxes = bboxes.shape[0] + + poly_list = [Polygon(bboxes[iter]) for iter in range(num_boxes)] + poly_box_list = [bboxes[iter] for iter in range(num_boxes)] + + poly_char_list = [[] for iter in range(num_boxes)] + poly_char_idx_list = [[] for iter in range(num_boxes)] + poly_charbox_list = [[] for iter in range(num_boxes)] + + words = [] + for s in strs: + words += s.split() + words_len = [len(w) for w in words] + words_end_inx = np.cumsum(words_len) + start_inx = 0 + for word_inx, end_inx in enumerate(words_end_inx): + for char_inx in range(start_inx, end_inx): + poly_char_idx_list[word_inx].append(char_inx) + poly_char_list[word_inx].append(chars[char_inx]) + poly_charbox_list[word_inx].append(char_bboxes[char_inx]) + start_inx = end_inx + + for box_inx in range(num_boxes): + assert len(poly_charbox_list[box_inx]) > 0 + + poly_boundary_list = [] + for item in poly_charbox_list: + boundary = np.ndarray((0, 2)) + if len(item) > 0: + boundary = trace_boundary(item) + poly_boundary_list.append(boundary) + + return (poly_list, poly_box_list, poly_boundary_list, poly_charbox_list, + poly_char_idx_list, poly_char_list) + + +def convert_annotations(root_path, gt_name, lmdb_name): + """Convert the annotation into lmdb dataset. + + Args: + root_path (str): The root path of dataset. + gt_name (str): The ground truth filename. + lmdb_name (str): The output lmdb filename. 
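+ + The resulting lmdb maps the keys '0' .. str(img_num - 1) to one JSON line per image, and stores the total image count under the key 'total_number'.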
+ """ + assert isinstance(root_path, str) + assert isinstance(gt_name, str) + assert isinstance(lmdb_name, str) + start_time = time.time() + gt = loadmat(gt_name) + img_num = len(gt['imnames'][0]) + env = lmdb.open(lmdb_name, map_size=int(1e9 * 40)) + with env.begin(write=True) as txn: + for img_id in range(img_num): + if img_id % 1000 == 0 and img_id > 0: + total_time_sec = time.time() - start_time + avg_time_sec = total_time_sec / img_id + eta_mins = (avg_time_sec * (img_num - img_id)) / 60 + print(f'\ncurrent_img/total_imgs {img_id}/{img_num} | ' + f'eta: {eta_mins:.3f} mins') + # for each img + img_file = osp.join(root_path, 'imgs', gt['imnames'][0][img_id][0]) + img = mmcv.imread(img_file, 'unchanged') + height, width = img.shape[0:2] + img_json = {} + img_json['file_name'] = gt['imnames'][0][img_id][0] + img_json['height'] = height + img_json['width'] = width + img_json['annotations'] = [] + wordBB = gt['wordBB'][0][img_id] + charBB = gt['charBB'][0][img_id] + txt = gt['txt'][0][img_id] + poly_list, _, poly_boundary_list, _, _, _ = match_bbox_char_str( + wordBB, charBB, txt) + for poly_inx in range(len(poly_list)): + + polygon = poly_list[poly_inx] + min_x, min_y, max_x, max_y = polygon.bounds + bbox = [min_x, min_y, max_x - min_x, max_y - min_y] + anno_info = dict() + anno_info['iscrowd'] = 0 + anno_info['category_id'] = 1 + anno_info['bbox'] = bbox + anno_info['segmentation'] = [ + poly_boundary_list[poly_inx].flatten().tolist() + ] + + img_json['annotations'].append(anno_info) + string = json.dumps(img_json) + txn.put(str(img_id).encode('utf8'), string.encode('utf8')) + key = 'total_number'.encode('utf8') + value = str(img_num).encode('utf8') + txn.put(key, value) + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert synthtext to lmdb dataset') + parser.add_argument('synthtext_path', help='synthetic root path') + parser.add_argument('-o', '--out-dir', help='output path') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + synthtext_path = args.synthtext_path + out_dir = args.out_dir if args.out_dir else synthtext_path + mmcv.mkdir_or_exist(out_dir) + + gt_name = osp.join(synthtext_path, 'gt.mat') + lmdb_name = 'synthtext.lmdb' + convert_annotations(synthtext_path, gt_name, osp.join(out_dir, lmdb_name)) + + +if __name__ == '__main__': + main() diff --git a/tools/data/textdet/totaltext_converter.py b/tools/data/textdet/totaltext_converter.py new file mode 100644 index 00000000..a15c62aa --- /dev/null +++ b/tools/data/textdet/totaltext_converter.py @@ -0,0 +1,312 @@ +import argparse +import glob +import os.path as osp +from functools import partial + +import cv2 +import mmcv +import numpy as np +import scipy.io as scio +from shapely.geometry import Polygon +from tools.data_converter.common import convert_annotations, is_not_png + +from mmocr.utils import drop_orientation + + +def collect_files(img_dir, gt_dir, split): + """Collect all images and their corresponding groundtruth files. + + Args: + img_dir(str): The image directory + gt_dir(str): The groundtruth directory + split(str): The split of dataset. Namely: training or test + + Returns: + files(list): The list of tuples (img_file, groundtruth_file) + """ + assert isinstance(img_dir, str) + assert img_dir + assert isinstance(gt_dir, str) + assert gt_dir + + # note that we handle png and jpg only. 
Please convert others such as gif + # to jpg or png offline + suffixes = ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG'] + + imgs_list = [] + for suffix in suffixes: + imgs_list.extend(glob.glob(osp.join(img_dir, '*' + suffix))) + + imgs_list = [ + drop_orientation(f) if is_not_png(f) else f for f in imgs_list + ] + + files = [] + if split == 'training': + for img_file in imgs_list: + gt_file = gt_dir + '/gt_' + osp.splitext( + osp.basename(img_file))[0] + '.mat' + files.append((img_file, gt_file)) + assert len(files), f'No images found in {img_dir}' + print(f'Loaded {len(files)} images from {img_dir}') + elif split == 'test': + for img_file in imgs_list: + gt_file = gt_dir + '/poly_gt_' + osp.splitext( + osp.basename(img_file))[0] + '.mat' + files.append((img_file, gt_file)) + assert len(files), f'No images found in {img_dir}' + print(f'Loaded {len(files)} images from {img_dir}') + + return files + + +def collect_annotations(files, split, nproc=1): + """Collect the annotation information. + + Args: + files(list): The list of tuples (image_file, groundtruth_file) + split(str): The split of dataset. Namely: training or test + nproc(int): The number of processes to collect annotations + + Returns: + images(list): The list of image information dicts + """ + assert isinstance(files, list) + assert isinstance(split, str) + assert isinstance(nproc, int) + + load_img_info_with_split = partial(load_img_info, split=split) + if nproc > 1: + images = mmcv.track_parallel_progress( + load_img_info_with_split, files, nproc=nproc) + else: + images = mmcv.track_progress(load_img_info_with_split, files) + + return images + + +def get_contours(gt_path, split): + """Get the contours and words for each ground_truth file. + + Args: + gt_path(str): The relative path of the ground_truth mat file + split(str): The split of dataset: training or test + + Returns: + contours(list[ndarray]): A list of contours for the text instances + words(list[str]): A list of words (string) for the text instances + """ + assert isinstance(gt_path, str) + assert isinstance(split, str) + + contours = [] + words = [] + data = scio.loadmat(gt_path) + if split == 'training': + data_polygt = data['gt'] + elif split == 'test': + data_polygt = data['polygt'] + + for i, lines in enumerate(data_polygt): + X = np.array(lines[1]) + Y = np.array(lines[3]) + + point_num = len(X[0]) + word = lines[4] + if len(word) == 0: + word = '???' + else: + word = word[0] + + if word == '#': + word = '###' + continue + + words.append(word) + + arr = np.concatenate([X, Y]).T + contour = [] + for p in range(point_num): + contour.append(arr[p][0]) + contour.append(arr[p][1]) + contours.append(np.asarray(contour)) + + return contours, words + + +def load_mat_info(img_info, gt_file, split): + """Load the information of one ground truth in .mat format.
+ + Args: + img_info(dict): The dict of only the image information + gt_file(str): The relative path of the ground_truth mat + file for one image + split(str): The split of dataset: training or test + + Returns: + img_info(dict): The dict of the img and annotation information + """ + assert isinstance(img_info, dict) + assert isinstance(gt_file, str) + assert isinstance(split, str) + + contours, words = get_contours(gt_file, split) + anno_info = [] + for contour in contours: + if contour.shape[0] == 2: + continue + category_id = 1 + coordinates = np.array(contour).reshape(-1, 2) + polygon = Polygon(coordinates) + iscrowd = 0 + + area = polygon.area + # convert to COCO style XYWH format + min_x, min_y, max_x, max_y = polygon.bounds + bbox = [min_x, min_y, max_x - min_x, max_y - min_y] + + anno = dict( + iscrowd=iscrowd, + category_id=category_id, + bbox=bbox, + area=area, + segmentation=[contour]) + anno_info.append(anno) + + img_info.update(anno_info=anno_info) + + return img_info + + +def load_png_info(gt_file, img_info): + """Load the information of one ground truth in .png format. + + Args: + gt_file(str): The relative path of the ground_truth file for one image + img_info(dict): The dict of only the image information + + Returns: + img_info(dict): The dict of the img and annotation information + """ + assert isinstance(gt_file, str) + assert isinstance(img_info, dict) + gt_img = cv2.imread(gt_file, 0) + contours, _ = cv2.findContours(gt_img, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_SIMPLE) + + anno_info = [] + for contour in contours: + if contour.shape[0] == 2: + continue + category_id = 1 + xy = np.array(contour).flatten().tolist() + + coordinates = np.array(contour).reshape(-1, 2) + polygon = Polygon(coordinates) + iscrowd = 0 + + area = polygon.area + # convert to COCO style XYWH format + min_x, min_y, max_x, max_y = polygon.bounds + bbox = [min_x, min_y, max_x - min_x, max_y - min_y] + + anno = dict( + iscrowd=iscrowd, + category_id=category_id, + bbox=bbox, + area=area, + segmentation=[xy]) + anno_info.append(anno) + + img_info.update(anno_info=anno_info) + + return img_info + + +def load_img_info(files, split): + """Load the information of one image. + + Args: + files(tuple): The tuple of (img_file, groundtruth_file) + split(str): The split of dataset: training or test + + Returns: + img_info(dict): The dict of the img and annotation information + """ + assert isinstance(files, tuple) + assert isinstance(split, str) + + img_file, gt_file = files + # read imgs with ignoring orientations + img = mmcv.imread(img_file, 'unchanged') + # read imgs with orientations as dataloader does when training and testing + img_color = mmcv.imread(img_file, 'color') + # make sure imgs have no orientation info, or annotation gt is wrong. 
+ assert img.shape[0:2] == img_color.shape[0:2] + + split_name = osp.basename(osp.dirname(img_file)) + img_info = dict( + # remove img_prefix for filename + file_name=osp.join(split_name, osp.basename(img_file)), + height=img.shape[0], + width=img.shape[1], + # anno_info=anno_info, + segm_file=osp.join(split_name, osp.basename(gt_file))) + + if split == 'training': + img_info = load_mat_info(img_info, gt_file, split) + elif split == 'test': + img_info = load_mat_info(img_info, gt_file, split) + else: + raise NotImplementedError + + return img_info + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert totaltext annotations to COCO format') + parser.add_argument('root_path', help='totaltext root path') + parser.add_argument('-o', '--out-dir', help='output path') + parser.add_argument( + '--split-list', + nargs='+', + help='a list of splits. e.g., "--split_list training test"') + + parser.add_argument( + '--nproc', default=1, type=int, help='number of process') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + root_path = args.root_path + out_dir = args.out_dir if args.out_dir else root_path + mmcv.mkdir_or_exist(out_dir) + + img_dir = osp.join(root_path, 'imgs') + gt_dir = osp.join(root_path, 'annotations') + + set_name = {} + for split in args.split_list: + set_name.update({split: 'instances_' + split + '.json'}) + assert osp.exists(osp.join(img_dir, split)) + + for split, json_name in set_name.items(): + print(f'Converting {split} into {json_name}') + with mmcv.Timer( + print_tmpl='It takes {}s to convert totaltext annotation'): + files = collect_files( + osp.join(img_dir, split), osp.join(gt_dir, split), split) + image_infos = collect_annotations(files, split, nproc=args.nproc) + convert_annotations(image_infos, osp.join(out_dir, json_name)) + + +if __name__ == '__main__': + main() diff --git a/tools/data/textrecog/seg_synthtext_converter.py b/tools/data/textrecog/seg_synthtext_converter.py new file mode 100644 index 00000000..fc4e0600 --- /dev/null +++ b/tools/data/textrecog/seg_synthtext_converter.py @@ -0,0 +1,90 @@ +import argparse +import json +import os.path as osp + +import cv2 + + +def parse_old_label(data_root, in_path, img_size=False): + imgid2imgname = {} + imgid2anno = {} + idx = 0 + with open(in_path, 'r') as fr: + for line in fr: + line = line.strip().split() + img_full_path = osp.join(data_root, line[0]) + if not osp.exists(img_full_path): + continue + ann_file = osp.join(data_root, line[1]) + if not osp.exists(ann_file): + continue + + img_info = {} + img_info['file_name'] = line[0] + if img_size: + img = cv2.imread(img_full_path) + h, w = img.shape[:2] + img_info['height'] = h + img_info['width'] = w + imgid2imgname[idx] = img_info + + imgid2anno[idx] = [] + char_annos = [] + with open(ann_file, 'r') as fr: + t = 0 + for line in fr: + line = line.strip() + if t == 0: + img_info['text'] = line + else: + char_box = [float(x) for x in line.split()] + char_text = img_info['text'][t - 1] + char_ann = dict(char_box=char_box, char_text=char_text) + char_annos.append(char_ann) + t += 1 + imgid2anno[idx] = char_annos + idx += 1 + + return imgid2imgname, imgid2anno + + +def gen_line_dict_file(out_path, imgid2imgname, imgid2anno, img_size=False): + with open(out_path, 'w', encoding='utf-8') as fw: + for key, value in imgid2imgname.items(): + if key in imgid2anno: + anno = imgid2anno[key] + line_dict = {} + line_dict['file_name'] = value['file_name'] + line_dict['text'] = value['text'] + if img_size: + line_dict['height'] 
= value['height'] + line_dict['width'] = value['width'] + line_dict['annotations'] = anno + line_dict_str = json.dumps(line_dict) + fw.write(line_dict_str + '\n') + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--data-root', help='data root for both image file and anno file') + parser.add_argument( + '--in-path', + help='mapping file of image_name and ann_file,' + ' "image_name ann_file" in each line') + parser.add_argument( + '--out-path', help='output txt path with line-json format') + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + imgid2imgname, imgid2anno = parse_old_label(args.data_root, args.in_path) + gen_line_dict_file(args.out_path, imgid2imgname, imgid2anno) + print('finish') + + +if __name__ == '__main__': + main() diff --git a/tools/data/utils/common.py b/tools/data/utils/common.py new file mode 100644 index 00000000..4e791b9c --- /dev/null +++ b/tools/data/utils/common.py @@ -0,0 +1,60 @@ +import os.path as osp + +import mmcv + + +def is_not_png(img_file): + """Check img_file is not png image. + + Args: + img_file(str): The input image file name + + Returns: + The bool flag indicating whether it is not png + """ + assert isinstance(img_file, str) + assert img_file + + suffix = osp.splitext(img_file)[1] + + return (suffix not in ['.PNG', '.png']) + + +def convert_annotations(image_infos, out_json_name): + """Convert the annotation into coco style. + + Args: + image_infos(list): The list of image information dicts + out_json_name(str): The output json filename + + Returns: + out_json(dict): The coco style dict + """ + assert isinstance(image_infos, list) + assert isinstance(out_json_name, str) + assert out_json_name + + out_json = dict() + img_id = 0 + ann_id = 0 + out_json['images'] = [] + out_json['categories'] = [] + out_json['annotations'] = [] + for image_info in image_infos: + image_info['id'] = img_id + anno_infos = image_info.pop('anno_info') + out_json['images'].append(image_info) + for anno_info in anno_infos: + anno_info['image_id'] = img_id + anno_info['id'] = ann_id + out_json['annotations'].append(anno_info) + ann_id += 1 + img_id += 1 + cat = dict(id=1, name='text') + out_json['categories'].append(cat) + + if len(out_json['annotations']) == 0: + out_json.pop('annotations') + mmcv.dump(out_json, out_json_name) + + return out_json diff --git a/tools/data/utils/txt2lmdb.py b/tools/data/utils/txt2lmdb.py new file mode 100644 index 00000000..e97bfda0 --- /dev/null +++ b/tools/data/utils/txt2lmdb.py @@ -0,0 +1,30 @@ +import argparse + +from mmocr.utils import lmdb_converter + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--imglist', '-i', required=True, help='input imglist path') + parser.add_argument( + '--output', '-o', required=True, help='output lmdb path') + parser.add_argument( + '--batch_size', + '-b', + type=int, + default=10000, + help='processing batch size, default 10000') + parser.add_argument( + '--coding', + '-c', + default='utf8', + help='bytes coding scheme, default utf8') + opt = parser.parse_args() + + lmdb_converter( + opt.imglist, opt.output, batch_size=opt.batch_size, coding=opt.coding) + + +if __name__ == '__main__': + main() diff --git a/tools/dist_test.sh b/tools/dist_test.sh new file mode 100755 index 00000000..6e305059 --- /dev/null +++ b/tools/dist_test.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +if [ $# -lt 3 ] +then + echo "Usage: bash $0 CONFIG CHECKPOINT GPUS" + exit +fi + +CONFIG=$1 +CHECKPOINT=$2 +GPUS=$3 +PORT=${PORT:-29500} + 
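+# example with placeholder paths, forwarding any extra args to test.py: +# PORT=29501 bash tools/dist_test.sh configs/example_config.py example_ckpt.pth 8 --eval hmean-iou +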
+PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} diff --git a/tools/dist_train.sh b/tools/dist_train.sh new file mode 100755 index 00000000..ee3a8efe --- /dev/null +++ b/tools/dist_train.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +if [ $# -lt 3 ] +then + echo "Usage: bash $0 CONFIG WORK_DIR GPUS" + exit +fi + +CONFIG=$1 +WORK_DIR=$2 +GPUS=$3 + +PORT=${PORT:-29500} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ + +if [ ${GPUS} == 1 ]; then + python $(dirname "$0")/train.py $CONFIG --work-dir=${WORK_DIR} ${@:4} +else + python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/train.py $CONFIG --work-dir=${WORK_DIR} --launcher pytorch ${@:4} +fi diff --git a/tools/kie_test_imgs.py b/tools/kie_test_imgs.py new file mode 100644 index 00000000..6fff569d --- /dev/null +++ b/tools/kie_test_imgs.py @@ -0,0 +1,108 @@ +import argparse +import os +import os.path as osp + +import mmcv +import torch +from mmcv import Config +from mmcv.image import tensor2imgs +from mmcv.parallel import MMDataParallel +from mmcv.runner import load_checkpoint + +from mmocr.datasets import build_dataloader, build_dataset +from mmocr.models import build_detector + + +def test(model, data_loader, show=False, out_dir=None): + model.eval() + results = [] + dataset = data_loader.dataset + prog_bar = mmcv.ProgressBar(len(dataset)) + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + + batch_size = len(result) + if show or out_dir: + img_tensor = data['img'].data[0] + img_metas = data['img_metas'].data[0] + imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) + assert len(imgs) == len(img_metas) + gt_bboxes = [data['gt_bboxes'].data[0][0].numpy().tolist()] + + for i, (img, img_meta) in enumerate(zip(imgs, img_metas)): + h, w, _ = img_meta['img_shape'] + img_show = img[:h, :w, :] + + if out_dir: + out_file = osp.join(out_dir, img_meta['ori_filename']) + else: + out_file = None + + model.module.show_result( + img_show, + result[i], + gt_bboxes[i], + show=show, + out_file=out_file) + + for _ in range(batch_size): + prog_bar.update() + return results + + +def parse_args(): + parser = argparse.ArgumentParser( + description='MMOCR visualize for kie model.') + parser.add_argument('config', help='Test config file path.') + parser.add_argument('checkpoint', help='Checkpoint file.') + parser.add_argument('--show', action='store_true', help='Show results.') + parser.add_argument( + '--show-dir', help='Directory where the output images will be saved.') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def main(): + args = parse_args() + assert args.show or args.show_dir, ('Please specify at least one ' + 'operation (show the results / save )' + 'the results with the argument ' + '"--show" or "--show-dir".') + + cfg = Config.fromfile(args.config) + # import modules from string list. 
+ if cfg.get('custom_imports', None): + from mmcv.utils import import_modules_from_strings + import_modules_from_strings(**cfg['custom_imports']) + # set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + cfg.model.pretrained = None + + distributed = False + + # build the dataloader + dataset = build_dataset(cfg.data.test) + data_loader = build_dataloader( + dataset, + samples_per_gpu=1, + workers_per_gpu=cfg.data.workers_per_gpu, + dist=distributed, + shuffle=False) + + # build the model and load checkpoint + cfg.model.train_cfg = None + model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) + load_checkpoint(model, args.checkpoint, map_location='cpu') + + model = MMDataParallel(model, device_ids=[0]) + test(model, data_loader, args.show, args.show_dir) + + +if __name__ == '__main__': + main() diff --git a/tools/kie_test_imgs.sh b/tools/kie_test_imgs.sh new file mode 100644 index 00000000..80444cc6 --- /dev/null +++ b/tools/kie_test_imgs.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +DATE=`date +%Y-%m-%d` +TIME=`date +"%H-%M-%S"` + +if [ $# -lt 3 ] +then + echo "Usage: bash $0 CONFIG CHECKPOINT SHOW_DIR" + exit +fi + +CONFIG_FILE=$1 +CHECKPOINT=$2 +SHOW_DIR=$3_${DATE}_${TIME} + +mkdir ${SHOW_DIR} -p && + +python tools/kie_test_imgs.py \ + ${CONFIG_FILE} \ + ${CHECKPOINT} \ + --show-dir ${SHOW_DIR} diff --git a/tools/ocr_test_imgs.py b/tools/ocr_test_imgs.py new file mode 100644 index 00000000..cecf8625 --- /dev/null +++ b/tools/ocr_test_imgs.py @@ -0,0 +1,132 @@ +import os.path as osp +import shutil +import time +from argparse import ArgumentParser + +import mmcv +import torch +from mmcv.utils import ProgressBar + +from mmdet.apis import init_detector +from mmdet.utils import get_root_logger +from mmocr.apis import model_inference +from mmocr.core.evaluation.ocr_metric import eval_ocr_metric +from mmocr.datasets import build_dataset # noqa: F401 +from mmocr.models import build_detector # noqa: F401 + + +def save_results(img_paths, pred_labels, gt_labels, res_dir): + """Save predicted results to txt file. 
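+ + Writes results.txt with one '<img_path> <pred_label> <gt_label>' line per image under res_dir, and splits the same lines into correct.txt and wrong.txt by whether the prediction matches the ground truth.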
+ + Args: + img_paths (list[str]) + pred_labels (list[str]) + gt_labels (list[str]) + res_dir (str) + """ + assert len(img_paths) == len(pred_labels) == len(gt_labels) + res_file = osp.join(res_dir, 'results.txt') + correct_file = osp.join(res_dir, 'correct.txt') + wrong_file = osp.join(res_dir, 'wrong.txt') + with open(res_file, 'w') as fw, \ + open(correct_file, 'w') as fw_correct, \ + open(wrong_file, 'w') as fw_wrong: + for img_path, pred_label, gt_label in zip(img_paths, pred_labels, + gt_labels): + fw.write(img_path + ' ' + pred_label + ' ' + gt_label + '\n') + if pred_label == gt_label: + fw_correct.write(img_path + ' ' + pred_label + ' ' + gt_label + + '\n') + else: + fw_wrong.write(img_path + ' ' + pred_label + ' ' + gt_label + + '\n') + + +def main(): + parser = ArgumentParser() + parser.add_argument('--img_root_path', type=str, help='Image root path') + parser.add_argument('--img_list', type=str, help='Image path list file') + parser.add_argument('--config', type=str, help='Config file') + parser.add_argument('--checkpoint', type=str, help='Checkpoint file') + parser.add_argument( + '--out_dir', type=str, default='./results', help='Dir to save results') + parser.add_argument( + '--show', action='store_true', help='show image or save') + args = parser.parse_args() + + # init the logger before other steps + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + log_file = osp.join(args.out_dir, f'{timestamp}.log') + logger = get_root_logger(log_file=log_file, log_level='INFO') + + # build the model from a config file and a checkpoint file + device = 'cuda:' + str(torch.cuda.current_device()) + model = init_detector(args.config, args.checkpoint, device=device) + if hasattr(model, 'module'): + model = model.module + if model.cfg.data.test['type'] == 'ConcatDataset': + model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][ + 0].pipeline + + # Start Inference + out_vis_dir = osp.join(args.out_dir, 'out_vis_dir') + mmcv.mkdir_or_exist(out_vis_dir) + correct_vis_dir = osp.join(args.out_dir, 'correct') + mmcv.mkdir_or_exist(correct_vis_dir) + wrong_vis_dir = osp.join(args.out_dir, 'wrong') + mmcv.mkdir_or_exist(wrong_vis_dir) + img_paths, pred_labels, gt_labels = [], [], [] + total_img_num = sum([1 for _ in open(args.img_list)]) + progressbar = ProgressBar(task_num=total_img_num) + num_gt_label = 0 + with open(args.img_list, 'r') as fr: + for line in fr: + progressbar.update() + item_list = line.strip().split() + img_file = item_list[0] + gt_label = '' + if len(item_list) >= 2: + gt_label = item_list[1] + num_gt_label += 1 + img_path = osp.join(args.img_root_path, img_file) + if not osp.exists(img_path): + raise FileNotFoundError(img_path) + # Test a single image + result = model_inference(model, img_path) + pred_label = result['text'] + + out_img_name = '_'.join(img_file.split('/')) + out_file = osp.join(out_vis_dir, out_img_name) + kwargs_dict = { + 'gt_label': gt_label, + 'show': args.show, + 'out_file': '' if args.show else out_file + } + model.show_result(img_path, result, **kwargs_dict) + if gt_label != '': + if gt_label == pred_label: + dst_file = osp.join(correct_vis_dir, out_img_name) + else: + dst_file = osp.join(wrong_vis_dir, out_img_name) + shutil.copy(out_file, dst_file) + img_paths.append(img_path) + gt_labels.append(gt_label) + pred_labels.append(pred_label) + + # Save results + save_results(img_paths, pred_labels, gt_labels, args.out_dir) + + if num_gt_label == len(pred_labels): + # eval + eval_results = eval_ocr_metric(pred_labels, gt_labels) + 
logger.info('\n' + '-' * 100) + info = ('eval on testset with img_root_path ' + f'{args.img_root_path} and img_list {args.img_list}\n') + logger.info(info) + logger.info(eval_results) + + print(f'\nInference done, and results saved in {args.out_dir}\n') + + +if __name__ == '__main__': + main() diff --git a/tools/ocr_test_imgs.sh b/tools/ocr_test_imgs.sh new file mode 100644 index 00000000..69d719a1 --- /dev/null +++ b/tools/ocr_test_imgs.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +DATE=`date +%Y-%m-%d` +TIME=`date +"%H-%M-%S"` + +if [ $# -lt 5 ] +then + echo "Usage: bash $0 CONFIG CHECKPOINT IMG_PREFIX IMG_LIST RESULTS_DIR" + exit +fi + +CONFIG_FILE=$1 +CHECKPOINT=$2 +IMG_ROOT_PATH=$3 +IMG_LIST=$4 +OUT_DIR=$5_${DATE}_${TIME} + +mkdir ${OUT_DIR} -p && + +python tools/ocr_test_imgs.py \ + --img_root_path ${IMG_ROOT_PATH} \ + --img_list ${IMG_LIST} \ + --config ${CONFIG_FILE} \ + --checkpoint ${CHECKPOINT} \ + --out_dir ${OUT_DIR} diff --git a/tools/publish_model.py b/tools/publish_model.py new file mode 100644 index 00000000..34bbb7c0 --- /dev/null +++ b/tools/publish_model.py @@ -0,0 +1,37 @@ +import argparse +import subprocess + +import torch + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Process a checkpoint to be published') + parser.add_argument('in_file', help='input checkpoint filename') + parser.add_argument('out_file', help='output checkpoint filename') + args = parser.parse_args() + return args + + +def process_checkpoint(in_file, out_file): + checkpoint = torch.load(in_file, map_location='cpu') + # remove optimizer for smaller file size + if 'optimizer' in checkpoint: + del checkpoint['optimizer'] + # if it is necessary to remove some sensitive data in checkpoint['meta'], + # add the code here. + if 'meta' in checkpoint: + checkpoint['meta'] = {'CLASSES': 0} + torch.save(checkpoint, out_file) + sha = subprocess.check_output(['sha256sum', out_file]).decode() + final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8]) + subprocess.Popen(['mv', out_file, final_file]) + + +def main(): + args = parse_args() + process_checkpoint(args.in_file, args.out_file) + + +if __name__ == '__main__': + main() diff --git a/tools/slurm_test.sh b/tools/slurm_test.sh new file mode 100755 index 00000000..865f4559 --- /dev/null +++ b/tools/slurm_test.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -x +export PYTHONPATH=`pwd`:$PYTHONPATH + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +CHECKPOINT=$4 +GPUS=${GPUS:-8} +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +PY_ARGS=${@:5} +SRUN_ARGS=${SRUN_ARGS:-""} + +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} diff --git a/tools/slurm_train.sh b/tools/slurm_train.sh new file mode 100755 index 00000000..452b0945 --- /dev/null +++ b/tools/slurm_train.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +export MASTER_PORT=$((12000 + $RANDOM % 20000)) + +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +WORK_DIR=$4 +GPUS=${GPUS:-8} +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +PY_ARGS=${@:5} +SRUN_ARGS=${SRUN_ARGS:-""} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u tools/train.py ${CONFIG} 
diff --git a/tools/test.py b/tools/test.py
new file mode 100644
index 00000000..fa3812b2
--- /dev/null
+++ b/tools/test.py
@@ -0,0 +1,216 @@
+import argparse
+import os
+import warnings
+
+import mmcv
+import torch
+from mmcv import Config, DictAction
+from mmcv.cnn import fuse_conv_bn
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
+                         wrap_fp16_model)
+
+from mmdet.apis import multi_gpu_test, single_gpu_test
+from mmdet.datasets import replace_ImageToTensor
+from mmocr.datasets import build_dataloader, build_dataset
+from mmocr.models import build_detector
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='MMOCR test (and eval) a model.')
+    parser.add_argument('config', help='Test config file path.')
+    parser.add_argument('checkpoint', help='Checkpoint file.')
+    parser.add_argument('--out', help='Output result file in pickle format.')
+    parser.add_argument(
+        '--fuse-conv-bn',
+        action='store_true',
+        help='Whether to fuse conv and bn; this will slightly increase '
+        'the inference speed.')
+    parser.add_argument(
+        '--format-only',
+        action='store_true',
+        help='Format the output results without performing evaluation. It is '
+        'useful when you want to format the results to a specific format and '
+        'submit them to the test server.')
+    parser.add_argument(
+        '--eval',
+        type=str,
+        nargs='+',
+        help='The evaluation metrics, which depend on the dataset, e.g., '
+        '"bbox", "seg", "proposal" for COCO, and "mAP", "recall" for '
+        'PASCAL VOC.')
+    parser.add_argument('--show', action='store_true', help='Show results.')
+    parser.add_argument(
+        '--show-dir', help='Directory where the output images will be saved.')
+    parser.add_argument(
+        '--show-score-thr',
+        type=float,
+        default=0.3,
+        help='Score threshold (default: 0.3).')
+    parser.add_argument(
+        '--gpu-collect',
+        action='store_true',
+        help='Whether to use gpu to collect results.')
+    parser.add_argument(
+        '--tmpdir',
+        help='The tmp directory used for collecting results from multiple '
+        'workers, available when gpu-collect is not specified.')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='Override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into the config file. If the value '
+        'to be overwritten is a list, it should be of the form of either '
+        'key="[a,b]" or key=a,b. The argument also allows nested list/tuple '
+        'values, e.g. key="[(a,b),(c,d)]". Note that the quotation marks '
+        'are necessary and that no white space is allowed.')
+    parser.add_argument(
+        '--options',
+        nargs='+',
+        action=DictAction,
+        help='Custom options for evaluation, the key-value pair in xxx=yyy '
+        'format will be kwargs for the dataset.evaluate() function '
+        '(deprecated; use --eval-options instead).')
+    parser.add_argument(
+        '--eval-options',
+        nargs='+',
+        action=DictAction,
+        help='Custom options for evaluation, the key-value pair in xxx=yyy '
+        'format will be kwargs for the dataset.evaluate() function.')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='Options for job launcher.')
+    parser.add_argument('--local_rank', type=int, default=0)
+    args = parser.parse_args()
+    if 'LOCAL_RANK' not in os.environ:
+        os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+    if args.options and args.eval_options:
+        raise ValueError(
+            '--options and --eval-options cannot both be '
+            'specified; --options is deprecated in favor of --eval-options.')
+    if args.options:
+        warnings.warn('--options is deprecated in favor of --eval-options.')
+        args.eval_options = args.options
+    return args
+
+
+def main():
+    args = parse_args()
+
+    assert (
+        args.out or args.eval or args.format_only or args.show
+        or args.show_dir), (
+            'Please specify at least one operation (save/eval/format/show '
+            'the results) with the argument "--out", "--eval", '
+            '"--format-only", "--show" or "--show-dir".')
+
+    if args.eval and args.format_only:
+        raise ValueError('--eval and --format-only cannot both be specified.')
+
+    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
+        raise ValueError('The output file must be a pkl file.')
+
+    cfg = Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+    # import modules from string list.
+    if cfg.get('custom_imports', None):
+        from mmcv.utils import import_modules_from_strings
+        import_modules_from_strings(**cfg['custom_imports'])
+    # set cudnn_benchmark
+    if cfg.get('cudnn_benchmark', False):
+        torch.backends.cudnn.benchmark = True
+    cfg.model.pretrained = None
+    if cfg.model.get('neck'):
+        if isinstance(cfg.model.neck, list):
+            for neck_cfg in cfg.model.neck:
+                if neck_cfg.get('rfp_backbone'):
+                    if neck_cfg.rfp_backbone.get('pretrained'):
+                        neck_cfg.rfp_backbone.pretrained = None
+        elif cfg.model.neck.get('rfp_backbone'):
+            if cfg.model.neck.rfp_backbone.get('pretrained'):
+                cfg.model.neck.rfp_backbone.pretrained = None
+
+    # in case the test dataset is concatenated
+    samples_per_gpu = 1
+    if isinstance(cfg.data.test, dict):
+        cfg.data.test.test_mode = True
+        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
+        if samples_per_gpu > 1:
+            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
+            cfg.data.test.pipeline = replace_ImageToTensor(
+                cfg.data.test.pipeline)
+    elif isinstance(cfg.data.test, list):
+        for ds_cfg in cfg.data.test:
+            ds_cfg.test_mode = True
+        samples_per_gpu = max(
+            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
+        if samples_per_gpu > 1:
+            for ds_cfg in cfg.data.test:
+                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
+
+    # init distributed env first, since logger depends on the dist info.
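+    # Note: with the 'none' launcher the script runs single-machine,
+    # non-distributed testing; any other launcher initializes the
+    # distributed backend via mmcv's init_dist() using cfg.dist_params.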
+    if args.launcher == 'none':
+        distributed = False
+    else:
+        distributed = True
+        init_dist(args.launcher, **cfg.dist_params)
+
+    # build the dataloader
+    dataset = build_dataset(cfg.data.test)
+    data_loader = build_dataloader(
+        dataset,
+        samples_per_gpu=samples_per_gpu,
+        workers_per_gpu=cfg.data.workers_per_gpu,
+        dist=distributed,
+        shuffle=False)
+
+    # build the model and load checkpoint
+    cfg.model.train_cfg = None
+    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
+    fp16_cfg = cfg.get('fp16', None)
+    if fp16_cfg is not None:
+        wrap_fp16_model(model)
+    load_checkpoint(model, args.checkpoint, map_location='cpu')
+    if args.fuse_conv_bn:
+        model = fuse_conv_bn(model)
+
+    if not distributed:
+        model = MMDataParallel(model, device_ids=[0])
+        outputs = single_gpu_test(model, data_loader, args.show,
+                                  args.show_dir, args.show_score_thr)
+    else:
+        model = MMDistributedDataParallel(
+            model.cuda(),
+            device_ids=[torch.cuda.current_device()],
+            broadcast_buffers=False)
+        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
+                                 args.gpu_collect)
+
+    rank, _ = get_dist_info()
+    if rank == 0:
+        if args.out:
+            print(f'\nWriting results to {args.out}')
+            mmcv.dump(outputs, args.out)
+        kwargs = {} if args.eval_options is None else args.eval_options
+        if args.format_only:
+            dataset.format_results(outputs, **kwargs)
+        if args.eval:
+            eval_kwargs = cfg.get('evaluation', {}).copy()
+            # hard-coded way to remove EvalHook args
+            for key in [
+                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
+                    'rule'
+            ]:
+                eval_kwargs.pop(key, None)
+            eval_kwargs.update(dict(metric=args.eval, **kwargs))
+            print(dataset.evaluate(outputs, **eval_kwargs))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/test_imgs.py b/tools/test_imgs.py
new file mode 100644
index 00000000..132eafd0
--- /dev/null
+++ b/tools/test_imgs.py
@@ -0,0 +1,167 @@
+import codecs
+import os.path as osp
+from argparse import ArgumentParser
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.utils import ProgressBar
+
+from mmdet.apis import inference_detector, init_detector
+from mmocr.core.evaluation.utils import filter_result
+from mmocr.models import build_detector  # noqa: F401
+
+
+def gen_target_path(target_root_path, src_name, suffix):
+    """Generate the target file path.
+
+    Args:
+        target_root_path (str): The target root path.
+        src_name (str): The source file name.
+        suffix (str): The suffix of the target file.
+
+    Returns:
+        str: The generated target file path.
+    """
+    assert isinstance(target_root_path, str)
+    assert isinstance(src_name, str)
+    assert isinstance(suffix, str)
+
+    file_name = osp.split(src_name)[1]
+    name = osp.splitext(file_name)[0]
+    return osp.join(target_root_path, name + suffix)
+
+
+def save_2darray(mat, file_name):
+    """Save a 2d array to a txt file.
+
+    Args:
+        mat (ndarray): 2d-array of shape (n, m).
+        file_name (str): The output file name.
+    """
+    with codecs.open(file_name, 'w', 'utf-8') as fw:
+        for row in mat:
+            row_str = ','.join([str(x) for x in row])
+            fw.write(row_str + '\n')
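+
+
+# For example, save_2darray(np.array([[1, 2], [3, 4]]), 'out.txt') writes
+# two comma-separated rows: '1,2' and '3,4'.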
+
+
+def save_bboxes_quadrangles(bboxes_with_scores,
+                            quadrangles_with_scores,
+                            img_name,
+                            out_bbox_txt_dir,
+                            out_quadrangle_txt_dir,
+                            score_thr=0.3,
+                            save_score=True):
+    """Save results of detected bounding boxes and quadrangles to txt files.
+
+    Args:
+        bboxes_with_scores (ndarray): Detected bboxes of shape (n,5).
+        quadrangles_with_scores (ndarray): Detected quadrangles of shape
+            (n,9).
+        img_name (str): Image file name.
+        out_bbox_txt_dir (str): Dir of txt files to save detected bboxes
+            results.
+        out_quadrangle_txt_dir (str): Dir of txt files to save
+            quadrangle results.
+        score_thr (float, optional): Score threshold for bboxes.
+        save_score (bool, optional): Whether to save the score at each line
+            end, to search for the best threshold when evaluating.
+    """
+    assert bboxes_with_scores.ndim == 2
+    assert bboxes_with_scores.shape[1] == 5 or bboxes_with_scores.shape[1] == 9
+    assert quadrangles_with_scores.ndim == 2
+    assert quadrangles_with_scores.shape[1] == 9
+    assert bboxes_with_scores.shape[0] >= quadrangles_with_scores.shape[0]
+    assert isinstance(img_name, str)
+    assert isinstance(out_bbox_txt_dir, str)
+    assert isinstance(out_quadrangle_txt_dir, str)
+    assert isinstance(score_thr, float)
+    assert 0 <= score_thr < 1
+
+    # filter out invalid results
+    initial_valid_bboxes, valid_bbox_scores = filter_result(
+        bboxes_with_scores[:, :-1], bboxes_with_scores[:, -1], score_thr)
+    if initial_valid_bboxes.shape[1] == 4:
+        # expand (x1, y1, x2, y2) boxes into 8-value quadrangles:
+        # (x1, y1), (x2, y1), (x2, y2), (x1, y2)
+        valid_bboxes = np.empty((initial_valid_bboxes.shape[0], 8), dtype=int)
+        idx_list = [0, 1, 2, 1, 2, 3, 0, 3]
+        for i in range(8):
+            valid_bboxes[:, i] = initial_valid_bboxes[:, idx_list[i]]
+    elif initial_valid_bboxes.shape[1] == 8:
+        valid_bboxes = initial_valid_bboxes
+    else:
+        raise ValueError(
+            f'unexpected bbox dimension {initial_valid_bboxes.shape[1]}, '
+            'expected 4 or 8')
+
+    valid_quadrangles, valid_quadrangle_scores = filter_result(
+        quadrangles_with_scores[:, :-1], quadrangles_with_scores[:, -1],
+        score_thr)
+
+    # gen target file path
+    bbox_txt_file = gen_target_path(out_bbox_txt_dir, img_name, '.txt')
+    quadrangle_txt_file = gen_target_path(out_quadrangle_txt_dir, img_name,
+                                          '.txt')
+
+    # save txt
+    if save_score:
+        valid_bboxes = np.concatenate(
+            (valid_bboxes, valid_bbox_scores.reshape(-1, 1)), axis=1)
+        valid_quadrangles = np.concatenate(
+            (valid_quadrangles, valid_quadrangle_scores.reshape(-1, 1)),
+            axis=1)
+
+    save_2darray(valid_bboxes, bbox_txt_file)
+    save_2darray(valid_quadrangles, quadrangle_txt_file)
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('config', type=str, help='Config file')
+    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
+    parser.add_argument('img_root', type=str, help='Image root path')
+    parser.add_argument('img_list', type=str, help='Image path list file')
+
+    parser.add_argument(
+        '--score-thr', type=float, default=0.5, help='Bbox score threshold')
+    parser.add_argument(
+        '--out-dir',
+        type=str,
+        default='./results',
+        help='Dir to save visualized images and bbox results')
+    args = parser.parse_args()
+
+    assert 0 < args.score_thr < 1
+
+    # build the model from a config file and a checkpoint file
+    device = 'cuda:' + str(torch.cuda.current_device())
+    model = init_detector(args.config, args.checkpoint, device=device)
+    if hasattr(model, 'module'):
+        model = model.module
+    if model.cfg.data.test['type'] == 'ConcatDataset':
+        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
+            0].pipeline
+
+    # Start Inference
+    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
+    mmcv.mkdir_or_exist(out_vis_dir)
+
+    with open(args.img_list) as f:
+        total_img_num = sum(1 for _ in f)
+    progressbar = ProgressBar(task_num=total_img_num)
+    with codecs.open(args.img_list, 'r', 'utf-8') as fr:
+        for line in fr:
+            progressbar.update()
+            img_path = osp.join(args.img_root, line.strip())
+            if not osp.exists(img_path):
+                raise FileNotFoundError(img_path)
+            # Test a single image
+            result = inference_detector(model, img_path)
+            img_name = osp.basename(img_path)
+            out_file = osp.join(out_vis_dir, img_name)
+            kwargs_dict = {
+                'score_thr': args.score_thr,
+                'show': False,
+                'out_file': out_file
+            }
+            model.show_result(img_path, result, **kwargs_dict)
+
+    print(f'\nInference done, and results saved in {args.out_dir}\n')
+
+
+if __name__ == '__main__':
+    main()
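For reference, filter_result (imported above from mmocr.core.evaluation.utils) is used here purely as a score-threshold filter over detection rows. A minimal sketch of the assumed behavior follows; this is a hypothetical reimplementation for illustration, not the library code, and the strictly-greater comparison is an assumption:

    import numpy as np

    def filter_result(results, scores, score_thr):
        # Keep the rows of `results` whose score exceeds `score_thr`,
        # returning the filtered rows together with their scores.
        keep = scores > score_thr
        return results[keep], scores[keep]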
diff --git a/tools/test_imgs.sh b/tools/test_imgs.sh
new file mode 100644
index 00000000..c17a1bf7
--- /dev/null
+++ b/tools/test_imgs.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+DATE=`date +%Y-%m-%d`
+TIME=`date +"%H-%M-%S"`
+
+if [ $# -lt 5 ]
+then
+    echo "Usage: bash $0 CONFIG CHECKPOINT IMG_ROOT_PATH IMG_LIST OUT_DIR"
+    exit 1
+fi
+
+CONFIG_FILE=$1
+CHECKPOINT=$2
+IMG_ROOT_PATH=$3
+IMG_LIST=$4
+OUT_DIR=$5
+
+mkdir -p ${OUT_DIR} &&
+
+python tools/test_imgs.py \
+    ${CONFIG_FILE} ${CHECKPOINT} ${IMG_ROOT_PATH} ${IMG_LIST} \
+    --out-dir ${OUT_DIR}
diff --git a/tools/train.py b/tools/train.py
new file mode 100644
index 00000000..124cf0f7
--- /dev/null
+++ b/tools/train.py
@@ -0,0 +1,205 @@
+import argparse
+import copy
+import os
+import os.path as osp
+import time
+import warnings
+
+import mmcv
+import torch
+from mmcv import Config, DictAction
+from mmcv.runner import get_dist_info, init_dist
+from mmcv.utils import get_git_hash
+
+from mmdet import __version__
+from mmdet.apis import set_random_seed, train_detector
+from mmdet.utils import collect_env, get_root_logger
+from mmocr.datasets import build_dataset
+from mmocr.models import build_detector
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Train a detector.')
+    parser.add_argument('config', help='Train config file path.')
+    parser.add_argument('--work-dir', help='The dir to save logs and models.')
+    parser.add_argument(
+        '--resume-from', help='The checkpoint file to resume from.')
+    parser.add_argument(
+        '--no-validate',
+        action='store_true',
+        help='Whether to skip checkpoint evaluation during training.')
+    group_gpus = parser.add_mutually_exclusive_group()
+    group_gpus.add_argument(
+        '--gpus',
+        type=int,
+        help='Number of gpus to use '
+        '(only applicable to non-distributed training).')
+    group_gpus.add_argument(
+        '--gpu-ids',
+        type=int,
+        nargs='+',
+        help='Ids of gpus to use '
+        '(only applicable to non-distributed training).')
+    parser.add_argument('--seed', type=int, default=None, help='Random seed.')
+    parser.add_argument(
+        '--deterministic',
+        action='store_true',
+        help='Whether to set deterministic options for the CUDNN backend.')
+    parser.add_argument(
+        '--options',
+        nargs='+',
+        action=DictAction,
+        help='Override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into the config file '
+        '(deprecated; use --cfg-options instead).')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='Override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into the config file. If the value '
+        'to be overwritten is a list, it should be of the form of either '
+        'key="[a,b]" or key=a,b. The argument also allows nested list/tuple '
+        'values, e.g. key="[(a,b),(c,d)]". Note that the quotation marks '
+        'are necessary and that no white space is allowed.')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='Options for job launcher.')
+    parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument(
+        '--mc-config',
+        type=str,
+        default='',
+        help='Memory cache config for image loading speed-up during training.')
+
+    args = parser.parse_args()
+    if 'LOCAL_RANK' not in os.environ:
+        os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+    if args.options and args.cfg_options:
+        raise ValueError(
+            '--options and --cfg-options cannot both be '
+            'specified; --options is deprecated in favor of --cfg-options.')
+    if args.options:
+        warnings.warn('--options is deprecated in favor of --cfg-options.')
+        args.cfg_options = args.options
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    cfg = Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    # update memory cache config
+    if args.mc_config:
+        mc = Config.fromfile(args.mc_config)
+        if isinstance(cfg.data.train, list):
+            for i in range(len(cfg.data.train)):
+                cfg.data.train[i].pipeline[0].update(
+                    file_client_args=mc['mc_file_client_args'])
+        else:
+            cfg.data.train.pipeline[0].update(
+                file_client_args=mc['mc_file_client_args'])
+
+    # import modules from string list.
+    if cfg.get('custom_imports', None):
+        from mmcv.utils import import_modules_from_strings
+        import_modules_from_strings(**cfg['custom_imports'])
+    # set cudnn_benchmark
+    if cfg.get('cudnn_benchmark', False):
+        torch.backends.cudnn.benchmark = True
+
+    # work_dir is determined in this priority: CLI > segment in file > filename
+    if args.work_dir is not None:
+        # update configs according to CLI args if args.work_dir is not None
+        cfg.work_dir = args.work_dir
+    elif cfg.get('work_dir', None) is None:
+        # use config filename as default work_dir if cfg.work_dir is None
+        cfg.work_dir = osp.join('./work_dirs',
+                                osp.splitext(osp.basename(args.config))[0])
+    if args.resume_from is not None:
+        cfg.resume_from = args.resume_from
+    if args.gpu_ids is not None:
+        cfg.gpu_ids = args.gpu_ids
+    else:
+        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
+
+    # init distributed env first, since logger depends on the dist info.
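+    # As in tools/test.py, the 'none' launcher selects non-distributed
+    # training; otherwise init_dist() sets up the process group, and
+    # gpu_ids is re-derived from the world size below.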
+    if args.launcher == 'none':
+        distributed = False
+    else:
+        distributed = True
+        init_dist(args.launcher, **cfg.dist_params)
+        # reset gpu_ids for distributed training mode
+        _, world_size = get_dist_info()
+        cfg.gpu_ids = range(world_size)
+
+    # create work_dir
+    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+    # dump config
+    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
+    # init the logger before other steps
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
+    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
+
+    # init the meta dict to record some important information such as
+    # environment info and seed, which will be logged
+    meta = dict()
+    # log env info
+    env_info_dict = collect_env()
+    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
+    dash_line = '-' * 60 + '\n'
+    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
+                dash_line)
+    meta['env_info'] = env_info
+    meta['config'] = cfg.pretty_text
+    # log some basic info
+    logger.info(f'Distributed training: {distributed}')
+    logger.info(f'Config:\n{cfg.pretty_text}')
+
+    # set random seeds
+    if args.seed is not None:
+        logger.info(f'Set random seed to {args.seed}, '
+                    f'deterministic: {args.deterministic}')
+        set_random_seed(args.seed, deterministic=args.deterministic)
+    cfg.seed = args.seed
+    meta['seed'] = args.seed
+    meta['exp_name'] = osp.basename(args.config)
+
+    model = build_detector(
+        cfg.model,
+        train_cfg=cfg.get('train_cfg'),
+        test_cfg=cfg.get('test_cfg'))
+
+    datasets = [build_dataset(cfg.data.train)]
+    if len(cfg.workflow) == 2:
+        val_dataset = copy.deepcopy(cfg.data.val)
+        val_dataset.pipeline = cfg.data.train.pipeline
+        datasets.append(build_dataset(val_dataset))
+    if cfg.checkpoint_config is not None:
+        # save mmdet version, config file content and class names in
+        # checkpoints as metadata
+        cfg.checkpoint_config.meta = dict(
+            mmdet_version=__version__ + get_git_hash()[:7],
+            CLASSES=datasets[0].CLASSES)
+    # add an attribute for visualization convenience
+    model.CLASSES = datasets[0].CLASSES
+    train_detector(
+        model,
+        datasets,
+        cfg,
+        distributed=distributed,
+        validate=(not args.no_validate),
+        timestamp=timestamp,
+        meta=meta)
+
+
+if __name__ == '__main__':
+    main()
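As a rough end-to-end illustration of the per-image recognition flow that tools/ocr_test_imgs.py wraps above, here is a minimal sketch. It assumes init_detector comes from mmdet.apis (as imported in tools/test_imgs.py) and model_inference from mmocr.apis; all file paths are hypothetical placeholders:

    from mmdet.apis import init_detector      # as in tools/test_imgs.py
    from mmocr.apis import model_inference    # assumed import location

    # Build the recognizer from a config and checkpoint (placeholder paths).
    model = init_detector('configs/textrecog/example_config.py',
                          'work_dirs/example_checkpoint.pth',
                          device='cuda:0')
    # Run single-image inference; the result dict carries the predicted
    # transcription under the 'text' key, as read in the loop above.
    result = model_inference(model, 'demo/example_image.jpg')
    print(result['text'])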