mmdeploy/csrc/codebase/mmpose/keypoints_from_heatmap.cpp
lvhan028 9306bcec80
Dev v0.4.0 (#301)
* bump version to v0.4.0

* [Enhancement] Make rewriter more powerful (#150)

* Finish function tests

* lint

* resolve comments

* Fix tests

* docstring & fix

* Complement informations

* lint

* Add example

* Fix version

* Remove todo

Co-authored-by: RunningLeon <mnsheng@yeah.net>

* Torchscript support (#159)

* support torchscript

* add nms

* add torchscript configs and update deploy process and dump-info

* typescript -> torchscript

* add torchscript custom extension support

* add ts custom ops again

* support mmseg unet

* [WIP] add optimizer for torchscript (#119)

* add passes

* add python api

* Torchscript optimizer python api (#121)

* add passes

* add python api

* use python api instead of executable

* Merge Master, update optimizer (#151)

* [Feature] add yolox ncnn (#29)

* add yolox ncnn

* add ncnn android performance of yolox

* add ut

* fix lint

* fix None bugs for ncnn

* test codecov

* test codecov

* add device

* fix yapf

* remove if-else for img shape

* use channelshuffle optimize

* change benchmark after channelshuffle

* fix yapf

* fix yapf

* fuse continuous reshape

* fix static shape deploy

* fix code

* drop pad

* only static shape

* fix static

* fix docstring

* Added mask overlay to output image, changed fprintf info messages to … (#55)

* Added mask overlay to output image, changed fprintf info messages to stdout

* Improved box filtering (filter area/score), make sure roi coordinates stay within bounds

* clang-format

* Support UNet in mmseg (#77)

* Repeatdataset in train has no CLASSES & PALETTE

* update result for unet

* update docstring for mmdet

* remove ppl for unet in docs

* fix ort wrap about input type (#81)

* Fix memleak (#86)

* delete []

* fix build error when enble MMDEPLOY_ACTIVE_LEVEL

* fix lint

* [Doc] Nano benchmark and tutorial (#71)

* add cls benchmark

* add nano zh-cn benchmark and en tutorial

* add device row

* add doc path to index.rst

* fix typo

* [Fix] fix missing deploy_core (#80)

* fix missing deploy_core

* mv flag to demo

* target link

* [Docs] Fix links in Chinese doc (#84)

* Fix docs in Chinese link

* Fix links

* Delete symbolic link and add links to html

* delete files

* Fix link

* [Feature] Add docker files (#67)

* add gpu and cpu dockerfile

* fix lint

* fix cpu docker and remove redundant

* use pip instead

* add build arg and readme

* fix grammar

* update readme

* add chinese doc for dockerfile and add docker build to build.md

* grammar

* refine dockerfiles

* add FAQs

* update Dpplcv_DIR for SDK building

* remove mmcls

* add sdk demos

* fix typo and lint

* update FAQs

* [Fix]fix check_env (#101)

* fix check_env

* update

* Replace convert_syncbatchnorm in mmseg (#93)

* replace convert_syncbatchnorm with revert_sync_batchnorm from mmcv

* change logger

* [Doc] Update FAQ for TensorRT (#96)

* update FAQ

* comment

* [Docs]: Update doc for openvino installation (#102)

* fix docs

* fix docs

* fix docs

* fix mmcv version

* fix docs

* rm blank line

* simplify non batch nms (#99)

* [Enhacement] Allow test.py to save evaluation results (#108)

* Add log file

* Delete debug code

* Rename logger

* resolve comments

* [Enhancement] Support mmocr v0.4+ (#115)

* support mmocr v0.4+

* 0.4.0 -> 0.4.1

* fix onnxruntime wrapper for gpu inference (#123)

* fix ncnn wrapper for ort-gpu

* resolve comment

* fix lint

* Fix typo (#132)

* lock mmcls version (#131)

* [Enhancement] upgrade isort in pre-commit config (#141)

* [Enhancement] upgrade isort in pre-commit config by refering to mmflow pr #87

* fix lint

* remove .isort.cfg and put its known_third_party to setup.cfg

* Fix ci for mmocr (#144)

* fix mmocr unittests

* remove useless

* lock mmdet maximum version to 2.20

* pip install -U numpy

* Fix capture_output (#125)

Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
Co-authored-by: Johannes L <tehkillerbee@users.noreply.github.com>
Co-authored-by: RunningLeon <mnsheng@yeah.net>
Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com>
Co-authored-by: lvhan028 <lvhan_028@163.com>
Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com>
Co-authored-by: Yifan Zhou <singlezombie@163.com>
Co-authored-by: 杨培文 (Yang Peiwen) <915505626@qq.com>
Co-authored-by: Semyon Bevzyuk <semen.bevzuk@gmail.com>

* configs for all tasks

* use torchvision roi align

* remote unnecessary code

* fix ut

* fix ut

* export

* det dynamic

* det dynamic

* add ut

* fix ut

* add ut and docs

* fix ut

* skip torchscript ut if no ops available

* add torchscript option to build.md

* update benchmark and resolve comments

* resolve conflicts

* rename configs

* fix mrcnn cuda test

* remove useless

* add version requirements to docs and comments to codes

* enable empty image exporting for torchscript and accelerate ORT inference for MRCNN

* rebase

* update example for torchscript.md

* update FAQs for torchscript.md

* resolve comments

* only use torchvision roi_align for torchscript

* fix ut

* use torchvision roi align when pool model is avg

* resolve comments

Co-authored-by: grimoire <streetyao@live.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
Co-authored-by: Johannes L <tehkillerbee@users.noreply.github.com>
Co-authored-by: RunningLeon <mnsheng@yeah.net>
Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com>
Co-authored-by: lvhan028 <lvhan_028@163.com>
Co-authored-by: Yifan Zhou <singlezombie@163.com>
Co-authored-by: 杨培文 (Yang Peiwen) <915505626@qq.com>
Co-authored-by: Semyon Bevzyuk <semen.bevzuk@gmail.com>

* Update supported mmseg models (#181)

* fix ocrnet cascade decoder

* update mmseg support models

* update mmseg configs

* support emanet and icnet

* set max K of TopK for tensorrt

* update supported models for mmseg in docs

* add test for emamodule

* add configs and update docs

* Update docs

* update benchmark

* [Features]Support mmdet3d (#103)

* add mmdet3d code

* add code

* update code

* [log]This commit finish pointpillar export and evaluate on onnxruntime.The model is sample with nvidia repo model

* add tensorrt config

* fix config

* update

* support for tensorrt

* add config

* fix config`

* fix apis about torch2onnx

* update

* mmdet3d deploy version1.0

* map is ok

* fix code

* version1.0

* fix code

* fix visual

* fix bug

* tensorrt support success

* add docstring

* add docs

* fix docs

* fix comments

* fix comment

* fix comment

* fix openvino wrapper

* add unit test

* fix device about cpu

* fix comment

* fix show_result

* fix lint

* fix requirments

* remove ci about det3d

* fix ut

* add ut data

* support for new version pointpillars

* fix comment

* fix support_list

* fix comments

* fix config name

* [Enhancement] Update pad logic in detection heads (#168)

* pad with register

* fix lint

Co-authored-by: AllentDan <dongchunyu@sensetime.com>

* [Enhancement] Additional arguments support for OpenVINO Model Optimizer (#178)

* Add mo args.

* [Docs]: update docs and argument descriptions (#196)

* bump version to v0.4.0

* update docs and argument descriptions

* revert version change

* fix unnecessary change of config for dynamic exportation (#199)

* fix mmcls get classes (#215)

* fix mmcls get classes

* resolve comment

* resolve comment

* Add ModelOptimizerOptions.

* Fix merge bugs.

* Update mmpose.md (#224)

* [Dostring]add example in apis docstring (#214)

* add example in apis docstring

* add backend example in docstring

* rm blank line

* Fixed get_mo_options_from_cfg args

* fix l2norm test

Co-authored-by: RunningLeon <mnsheng@yeah.net>
Co-authored-by: Haofan Wang <frankmiracle@outlook.com>
Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>

* [Enhancement] Switch to statically typed Value::Any (#209)

* replace std::any with StaticAny

* fix __compare_typeid

* remove fallback id support

* constraint on traits::TypeId<T>::value

* fix includes

* [Enhancement] TensorRT DCN support (#205)

* add tensorrt dcn support

* fix lint

* remove roi_align plugin for ORT (#258)

* remove roi_align plugin

* remove ut

* skip single_roi_extractor UT for ORT in CI

* move align to symbolic and update docs

* recover UT

* resolve comments

* [Enhancement]: Support fcn_unet deployment with dynamic shape (#251)

* support mmseg fcn+unet dynamic shape

* add test

* fix ci

* fix units

* resolve comments

* [Enhancement] fix-cmake-relocatable (#223)

* require user to specify xxx_dir

* fix line ending

* fix end-of-file-fixer

* try to fix ld cudart cublas

* add ENV var search

* fix CMAKE_CUDA_COMPILER

* cpu, cuda should all work well

* remove commented code

* fix ncnn example find ncnn package (#282)

* table format is wrong (#283)

* update pre-commit (#284)

* update pre-commit

* fix clang-format

* fix mmseg config (#281)

* fix mmseg config

* fix mmpose evaluate outputs

* fix lint

* update pre-commit config

* fix lint

* Revert "update pre-commit config"

This reverts commit c3fd71611f0b79dfa9ad73fc0f4555c1b3563665.

* miss code symbol (#296)

* refactor cmake build (#295)

* add-mmpose-sdk (#259)

* add-mmpose-codebase

* fix ci

* fix img_shape after TopDownAffine

* rename TopDown module -> XheadDecode & implement regression decode

* align keypoints_from_heatmap

* remove hardcode keypoint_head, need refactor, current only support topdown config

* add mmpose python api

* update mmpose-python code

* can't clip fake box

* fix rebase error

* fix rebase error

* link mspn decoder to base decoder

* fix ci

* compile with gcc7.5

* remove no use code

* fix

* fix prompt

* remove unnecessary cv::parallel_for_

* rewrite TopdownHeatmapMultiStageHead.inference_model

* add comment

* add more detail docstring why use _cs2xyxy in sdk backend

* fix Registry name

* remove no use param & add comment of output result

Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com>
Co-authored-by: grimoire <streetyao@live.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
Co-authored-by: Johannes L <tehkillerbee@users.noreply.github.com>
Co-authored-by: RunningLeon <mnsheng@yeah.net>
Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com>
Co-authored-by: lvhan028 <lvhan_028@163.com>
Co-authored-by: Yifan Zhou <singlezombie@163.com>
Co-authored-by: 杨培文 (Yang Peiwen) <915505626@qq.com>
Co-authored-by: Semyon Bevzyuk <semen.bevzuk@gmail.com>
Co-authored-by: Haofan Wang <frankmiracle@outlook.com>

* update faq about WinError 1455 (#297)

* update faq about WinError 1455

* Update faq.md

* Update faq.md

* fix ci

Co-authored-by: chenxin2 <chenxin2@sensetime.com>

* [Feature]Support centerpoint (#252)

* support for centerpoint

* add docstring and dcn model support

* add centerpoint ut and docs

* add config and fix input rank

* fix merge error

* fix a bug

* fix comment

* [Doc] update benchmark add supported-model-list (#286)

* update benchmark add supported-model-list

* fix lint

* fix lint

* loc mmocr maximum version

* fix ut

Co-authored-by: maningsheng <mnsheng@yeah.net>
Co-authored-by: Yifan Zhou <singlezombie@163.com>
Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com>
Co-authored-by: grimoire <streetyao@live.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
Co-authored-by: Johannes L <tehkillerbee@users.noreply.github.com>
Co-authored-by: lvhan028 <lvhan_028@163.com>
Co-authored-by: 杨培文 (Yang Peiwen) <915505626@qq.com>
Co-authored-by: Semyon Bevzyuk <semen.bevzuk@gmail.com>
Co-authored-by: AllentDan <dongchunyu@sensetime.com>
Co-authored-by: Haofan Wang <frankmiracle@outlook.com>
Co-authored-by: lzhangzz <lzhang329@gmail.com>

Co-authored-by: maningsheng <mnsheng@yeah.net>
Co-authored-by: Yifan Zhou <singlezombie@163.com>
Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com>
Co-authored-by: grimoire <streetyao@live.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
Co-authored-by: Johannes L <tehkillerbee@users.noreply.github.com>
Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com>
Co-authored-by: 杨培文 (Yang Peiwen) <915505626@qq.com>
Co-authored-by: Semyon Bevzyuk <semen.bevzuk@gmail.com>
Co-authored-by: AllentDan <dongchunyu@sensetime.com>
Co-authored-by: Haofan Wang <frankmiracle@outlook.com>
Co-authored-by: lzhangzz <lzhang329@gmail.com>
Co-authored-by: Chen Xin <xinchen.tju@gmail.com>
Co-authored-by: chenxin2 <chenxin2@sensetime.com>
2022-04-01 18:14:23 +08:00
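The `_LoopBody` template in the source below wraps a lambda so it can be passed where OpenCV expects a `cv::ParallelLoopBody`, an abstract class with a virtual `operator()(const cv::Range&)`. The following sketch shows the same adapter pattern without an OpenCV dependency. `Range`, `ParallelLoopBody`, `LoopBody`, and `RunChunks` are hypothetical stand-ins; `RunChunks` is a serial driver that only imitates how a parallel loop hands sub-ranges to the body, with no threading.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>

// Hypothetical stand-ins for cv::Range and cv::ParallelLoopBody.
struct Range {
  int start;
  int end;
};

struct ParallelLoopBody {
  virtual ~ParallelLoopBody() = default;
  virtual void operator()(const Range& range) const = 0;
};

// Adapter: stores any callable and forwards the virtual call to it,
// in the same way _LoopBody does for cv::ParallelLoopBody.
template <class F>
struct LoopBody : ParallelLoopBody {
  F f_;
  explicit LoopBody(F f) : f_(std::move(f)) {}
  void operator()(const Range& range) const override { f_(range); }
};

// Serial driver: splits [start, end) into chunks of `chunk` indices and
// invokes the body once per chunk. A real parallel loop would dispatch
// the chunks to worker threads instead.
void RunChunks(const Range& range, const ParallelLoopBody& body, int chunk) {
  assert(chunk > 0);
  for (int s = range.start; s < range.end; s += chunk) {
    body(Range{s, std::min(s + chunk, range.end)});
  }
}
```

Because C++17 class template argument deduction infers `F` from the constructor argument, a call site can write `LoopBody{[&](const Range& r) { ... }}`, just as the source writes `_LoopBody{...}` without spelling out the lambda type.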

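`get_max_pred` is called on both branches of `keypoints_from_heatmap`, but its body is not part of this excerpt. The standalone sketch below shows what a plain per-channel argmax decode computes, assuming a single image whose heatmaps form a row-major `[K, H, W]` float buffer, and a flat `[x, y, score]` triple per keypoint (the stride-3 layout that `GetOutput` reads). `ArgmaxDecode` is a hypothetical name, and marking channels whose peak is not positive with `-1` coordinates is an assumption of this sketch, not confirmed behavior of `get_max_pred`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: per-keypoint argmax over a [K, H, W] heatmap buffer.
// Returns K triples laid out as [x0, y0, s0, x1, y1, s1, ...].
std::vector<float> ArgmaxDecode(const std::vector<float>& heatmap, int K, int H, int W) {
  const std::size_t hw = static_cast<std::size_t>(H) * W;
  assert(heatmap.size() == hw * K);
  std::vector<float> pred(static_cast<std::size_t>(K) * 3);
  for (int k = 0; k < K; ++k) {
    const float* channel = heatmap.data() + hw * k;
    std::size_t best = 0;  // the first maximum wins on ties
    for (std::size_t i = 1; i < hw; ++i) {
      if (channel[i] > channel[best]) best = i;
    }
    const float score = channel[best];
    const bool valid = score > 0.f;  // assumption: non-positive peaks are invalid
    pred[3 * k + 0] = valid ? static_cast<float>(best % W) : -1.f;  // column -> x
    pred[3 * k + 1] = valid ? static_cast<float>(best / W) : -1.f;  // row -> y
    pred[3 * k + 2] = score;
  }
  return pred;
}
```

In the source, both `post_dark_udp` and the `unbiased` branch operate on the output of `get_max_pred`, so they start from an integer peak of this kind and refine it.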

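In the `CombinedTarget` branch of the source, each keypoint owns three consecutive channels (heatmap, x-offset, y-offset). The offsets are scaled by `valid_radius = valid_radius_factor * H` and added to the argmax location of the heatmap channel. The sketch below restates that arithmetic on a plain buffer so it can be checked in isolation. `DecodeCombinedTarget` is a hypothetical name, and the sketch deliberately omits the per-channel Gaussian blur that the source applies first.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: decode a CombinedTarget output of K3 = 3 * K channels,
// ordered [heat_0, off_x_0, off_y_0, heat_1, ...], each of size H * W.
// Returns K triples laid out as [x, y, score].
std::vector<float> DecodeCombinedTarget(const std::vector<float>& out, int K3, int H, int W,
                                        float valid_radius_factor) {
  assert(K3 % 3 == 0);
  const std::size_t hw = static_cast<std::size_t>(H) * W;
  assert(out.size() == hw * K3);
  const int K = K3 / 3;
  // Offsets are stored normalized; this factor maps them back to heatmap pixels.
  const float valid_radius = valid_radius_factor * H;
  std::vector<float> pred(static_cast<std::size_t>(K) * 3);
  for (int k = 0; k < K; ++k) {
    const float* heat = out.data() + hw * (3 * k);
    const float* off_x = heat + hw;
    const float* off_y = heat + 2 * hw;
    std::size_t best = 0;
    for (std::size_t i = 1; i < hw; ++i) {
      if (heat[i] > heat[best]) best = i;
    }
    // Offsets are read at the same flat index as the integer peak.
    pred[3 * k + 0] = static_cast<float>(best % W) + off_x[best] * valid_radius;
    pred[3 * k + 1] = static_cast<float>(best / W) + off_y[best] * valid_radius;
    pred[3 * k + 2] = heat[best];
  }
  return pred;
}
```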
// Copyright (c) OpenMMLab. All rights reserved.

#include <algorithm>  // std::copy_n, std::for_each, std::max, std::transform
#include <cassert>    // assert
#include <cctype>     // std::tolower
#include <cmath>      // std::log
#include <string>
#include <vector>

#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>

#include "core/device.h"
#include "core/registry.h"
#include "core/serialization.h"
#include "core/tensor.h"
#include "core/utils/device_utils.h"
#include "core/utils/formatter.h"
#include "core/value.h"
#include "experimental/module_adapter.h"
#include "mmpose.h"
#include "opencv_utils.h"
namespace mmdeploy::mmpose {
using std::string;
using std::vector;
template <class F>
struct _LoopBody : public cv::ParallelLoopBody {
F f_;
_LoopBody(F f) : f_(std::move(f)) {}
void operator()(const cv::Range& range) const override { f_(range); }
};
std::string to_lower(const std::string& s) {
std::string t = s;
std::transform(t.begin(), t.end(), t.begin(), [](unsigned char c) { return std::tolower(c); });
return t;
}
class TopdownHeatmapBaseHeadDecode : public MMPose {
public:
explicit TopdownHeatmapBaseHeadDecode(const Value& config) : MMPose(config) {
if (config.contains("params")) {
auto& params = config["params"];
flip_test_ = params.value("flip_test", flip_test_);
use_udp_ = params.value("use_udp", use_udp_);
target_type_ = params.value("target_type", target_type_);
valid_radius_factor_ = params.value("valid_radius_factor", valid_radius_factor_);
unbiased_decoding_ = params.value("unbiased_decoding", unbiased_decoding_);
post_process_ = params.value("post_process", post_process_);
shift_heatmap_ = params.value("shift_heatmap", shift_heatmap_);
modulate_kernel_ = params.value("modulate_kernel", modulate_kernel_);
}
}
Result<Value> operator()(const Value& _data, const Value& _prob) {
MMDEPLOY_DEBUG("preprocess_result: {}", _data);
MMDEPLOY_DEBUG("inference_result: {}", _prob);
Device cpu_device{"cpu"};
OUTCOME_TRY(auto heatmap,
MakeAvailableOnDevice(_prob["output"].get<Tensor>(), cpu_device, stream()));
OUTCOME_TRY(stream().Wait());
if (!(heatmap.shape().size() == 4 && heatmap.data_type() == DataType::kFLOAT)) {
MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", heatmap.shape(),
(int)heatmap.data_type());
return Status(eNotSupported);
}
auto& img_metas = _data["img_metas"];
vector<float> center;
vector<float> scale;
from_value(img_metas["center"], center);
from_value(img_metas["scale"], scale);
Tensor pred =
keypoints_from_heatmap(heatmap, center, scale, unbiased_decoding_, post_process_,
modulate_kernel_, valid_radius_factor_, use_udp_, target_type_);
return GetOutput(pred);
}
Value GetOutput(Tensor& pred) {
PoseDetectorOutput output;
int K = pred.shape(1);
float* data = pred.data<float>();
for (int i = 0; i < K; i++) {
float x = *(data + 0);
float y = *(data + 1);
float s = *(data + 2);
output.key_points.push_back({{x, y}, s});
data += 3;
}
return to_value(std::move(output));
}
Tensor keypoints_from_heatmap(const Tensor& _heatmap, const vector<float>& center,
const vector<float>& scale, bool unbiased_decoding,
const string& post_process, int modulate_kernel,
float valid_radius_factor, bool use_udp,
const string& target_type) {
Tensor heatmap(_heatmap.desc());
heatmap.CopyFrom(_heatmap, stream()).value();
stream().Wait().value();
int K = heatmap.shape(1);
int H = heatmap.shape(2);
int W = heatmap.shape(3);
if (post_process == "megvii") {
heatmap = gaussian_blur(heatmap, modulate_kernel);
}
Tensor pred;
if (use_udp) {
if (to_lower(target_type) == to_lower(string("GaussianHeatMap"))) {
pred = get_max_pred(heatmap);
post_dark_udp(pred, heatmap, modulate_kernel);
} else if (to_lower(target_type) == to_lower(string("CombinedTarget"))) {
// Each keypoint has 3 channels (heatmap, x offset, y offset):
// output channel = 3 * channel_cfg['num_output_channels']
assert(K % 3 == 0);
cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) {
for (int i = r.start; i < r.end; i++) {
int kt = (i % 3 == 0) ? 2 * modulate_kernel + 1 : modulate_kernel;
float* data = heatmap.data<float>() + i * H * W;
cv::Mat work = cv::Mat(H, W, CV_32FC(1), data);
cv::GaussianBlur(work, work, {kt, kt}, 0); // inplace
}
}});
float valid_radius = valid_radius_factor * H;
TensorDesc desc = {Device{"cpu"}, DataType::kFLOAT, {1, K / 3, H, W}};
Tensor offset_x(desc);
Tensor offset_y(desc);
Tensor heatmap_(desc);
{
// split heatmap
float* src = heatmap.data<float>();
float* dst0 = heatmap_.data<float>();
float* dst1 = offset_x.data<float>();
float* dst2 = offset_y.data<float>();
for (int i = 0; i < K / 3; i++) {
std::copy_n(src, H * W, dst0);
std::transform(src + H * W, src + 2 * H * W, dst1,
[=](float x) { return x * valid_radius; });
std::transform(src + 2 * H * W, src + 3 * H * W, dst2,
[=](float x) { return x * valid_radius; });
src += 3 * H * W;
dst0 += H * W;
dst1 += H * W;
dst2 += H * W;
}
}
pred = get_max_pred(heatmap_);
for (int i = 0; i < K / 3; i++) {
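float* data = pred.data<float>() + i * 3;
// get_max_pred() returns (-1, -1) for channels whose maximum is <= 0; skip them so
// the offset lookup below stays inside offset_x / offset_y.
if (data[0] < 0 || data[1] < 0) {
continue;
}
int index = *(data + 0) + *(data + 1) * W + H * W * i;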
float* offx = offset_x.data<float>() + index;
float* offy = offset_y.data<float>() + index;
*(data + 0) += *offx;
*(data + 1) += *offy;
}
}
} else {
pred = get_max_pred(heatmap);
if (post_process == "unbiased") {
heatmap = gaussian_blur(heatmap, modulate_kernel);
float* data = heatmap.data<float>();
std::for_each(data, data + K * H * W, [](float& v) {
double _v = std::max((double)v, 1e-10);
v = std::log(_v);
});
for (int i = 0; i < K; i++) {
taylor(heatmap, pred, i);
}
} else if (post_process != "null") {
for (int i = 0; i < K; i++) {
float* data = heatmap.data<float>() + i * W * H;
auto _data = [&](int y, int x) { return *(data + y * W + x); };
int px = *(pred.data<float>() + i * 3 + 0);
int py = *(pred.data<float>() + i * 3 + 1);
if (1 < px && px < W - 1 && 1 < py && py < H - 1) {
float v1 = _data(py, px + 1) - _data(py, px - 1);
float v2 = _data(py + 1, px) - _data(py - 1, px);
*(pred.data<float>() + i * 3 + 0) += (v1 > 0) ? 0.25 : ((v1 < 0) ? -0.25 : 0);
*(pred.data<float>() + i * 3 + 1) += (v2 > 0) ? 0.25 : ((v2 < 0) ? -0.25 : 0);
if (post_process == "megvii") {
*(pred.data<float>() + i * 3 + 0) += 0.5;
*(pred.data<float>() + i * 3 + 1) += 0.5;
}
}
}
}
}
K = pred.shape(1); // K / 3 if target_type is CombinedTarget, otherwise unchanged
// Transform back to the image
for (int i = 0; i < K; i++) {
transform_pred(pred, i, center, scale, {W, H}, use_udp);
}
if (post_process == "megvii") {
for (int i = 0; i < K; i++) {
float* data = pred.data<float>() + i * 3 + 2;
*data = *data / 255.0 + 0.5;
}
}
return pred;
}
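// DARK refinement for UDP with GaussianHeatmap targets. Each channel is blurred in
// place, clamped to [0.001, 50] and log-transformed; each argmax in `pred` is then
// shifted by the Newton step -H^{-1} g, where the gradient g and Hessian H are
// estimated with finite differences on the edge-padded log-heatmap.
// Note: `heatmap` is modified in place.
void post_dark_udp(Tensor& pred, Tensor& heatmap, int kernel) {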
int K = heatmap.shape(1);
int H = heatmap.shape(2);
int W = heatmap.shape(3);
cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) {
for (int i = r.start; i < r.end; i++) {
float* data = heatmap.data<float>() + i * H * W;
cv::Mat work = cv::Mat(H, W, CV_32FC(1), data);
cv::GaussianBlur(work, work, {kernel, kernel}, 0); // inplace
}
}});
std::for_each(heatmap.data<float>(), heatmap.data<float>() + K * H * W, [](float& x) {
x = std::max(0.001f, std::min(50.f, x));
x = std::log(x);
});
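// Reads the log-heatmap as if it were edge-padded by one pixel on every side;
// `index` addresses the padded (H + 2) x (W + 2) grid. Clamping both ends keeps
// reads on the bottom/right padding inside channel `c`.
auto _heatmap_data = [&](int index, int c) -> float {
int y = index / (W + 2);
int x = index % (W + 2);
y = std::min(H - 1, std::max(0, y - 1));
x = std::min(W - 1, std::max(0, x - 1));
return *(heatmap.data<float>() + c * H * W + y * W + x);
};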
for (int i = 0; i < K; i++) {
float* data = pred.data<float>() + i * 3;
int index = *(data + 0) + 1 + (*(data + 1) + 1) * (W + 2);
float i_ = _heatmap_data(index, i);
float ix1 = _heatmap_data(index + 1, i);
float iy1 = _heatmap_data(index + W + 2, i);
float ix1y1 = _heatmap_data(index + W + 3, i);
float ix1_y1_ = _heatmap_data(index - W - 3, i);
float ix1_ = _heatmap_data(index - 1, i);
float iy1_ = _heatmap_data(index - 2 - W, i);
float dx = 0.5 * (ix1 - ix1_);
float dy = 0.5 * (iy1 - iy1_);
float dxx = ix1 - 2 * i_ + ix1_;
float dyy = iy1 - 2 * i_ + iy1_;
float dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_);
vector<float> _data0 = {dx, dy};
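// Add float32 epsilon to the diagonal, as mmpose's Python post_dark_udp does, so
// the Hessian stays invertible where the log-heatmap is locally flat.
constexpr float kEps = 1.1920929e-07f;
vector<float> _data1 = {dxx + kEps, dxy, dxy, dyy + kEps};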
cv::Mat derivative = cv::Mat(2, 1, CV_32FC1, _data0.data());
cv::Mat hessian = cv::Mat(2, 2, CV_32FC1, _data1.data());
cv::Mat hessianinv = hessian.inv();
cv::Mat offset = -hessianinv * derivative;
*(data + 0) += offset.at<float>(0, 0);
*(data + 1) += offset.at<float>(1, 0);
}
}
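// Maps keypoint `k` of `pred` from heatmap coordinates back to image coordinates.
// `_scale` is the box size divided by 200 and `output_size` is {W, H}; with UDP the
// scale is divided by (size - 1) instead of size.
void transform_pred(Tensor& pred, int k, const vector<float>& center, const vector<float>& _scale,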
const vector<int>& output_size, bool use_udp = false) {
auto scale = _scale;
scale[0] *= 200;
scale[1] *= 200;
float scale_x, scale_y;
if (use_udp) {
scale_x = scale[0] / (output_size[0] - 1.0);
scale_y = scale[1] / (output_size[1] - 1.0);
} else {
scale_x = scale[0] / output_size[0];
scale_y = scale[1] / output_size[1];
}
float* data = pred.data<float>() + k * 3;
*(data + 0) = *(data + 0) * scale_x + center[0] - scale[0] * 0.5;
*(data + 1) = *(data + 1) * scale_y + center[1] - scale[1] * 0.5;
}
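// Second-order Taylor refinement used when post_process == "unbiased". `heatmap` must
// already be log-transformed. The argmax of channel `k` is moved by -H^{-1} g; keypoints
// in the two outermost rows or columns, or with |det(H)| <= 1e-6, are left unchanged.
void taylor(const Tensor& heatmap, Tensor& pred, int k) {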
int H = heatmap.shape(2);
int W = heatmap.shape(3);
int px = *(pred.data<float>() + k * 3 + 0);
int py = *(pred.data<float>() + k * 3 + 1);
if (1 < px && px < W - 2 && 1 < py && py < H - 2) {
float* data = const_cast<float*>(heatmap.data<float>() + k * H * W);
auto get_data = [&](int r, int c) { return *(data + r * W + c); };
float dx = 0.5 * (get_data(py, px + 1) - get_data(py, px - 1));
float dy = 0.5 * (get_data(py + 1, px) - get_data(py - 1, px));
float dxx = 0.25 * (get_data(py, px + 2) - 2 * get_data(py, px) + get_data(py, px - 2));
float dxy = 0.25 * (get_data(py + 1, px + 1) - get_data(py - 1, px + 1) -
get_data(py + 1, px - 1) + get_data(py - 1, px - 1));
float dyy = 0.25 * (get_data(py + 2, px) - 2 * get_data(py, px) + get_data(py - 2, px));
vector<float> _data0 = {dx, dy};
vector<float> _data1 = {dxx, dxy, dxy, dyy};
cv::Mat derivative = cv::Mat(2, 1, CV_32FC1, _data0.data());
cv::Mat hessian = cv::Mat(2, 2, CV_32FC1, _data1.data());
if (std::fabs(dxx * dyy - dxy * dxy) > 1e-6) {
cv::Mat hessianinv = hessian.inv();
cv::Mat offset = -hessianinv * derivative;
*(pred.data<float>() + k * 3 + 0) += offset.at<float>(0, 0);
*(pred.data<float>() + k * 3 + 1) += offset.at<float>(1, 0);
}
}
}
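// Blurs each channel with a `kernel` x `kernel` Gaussian after zero-padding it by
// (kernel - 1) / 2 pixels, then rescales it so that its maximum equals the original
// maximum. Returns a new tensor; the input is not modified.
Tensor gaussian_blur(const Tensor& _heatmap, int kernel) {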
assert(kernel % 2 == 1);
auto desc = _heatmap.desc();
Tensor heatmap(desc);
int K = _heatmap.shape(1);
int H = _heatmap.shape(2);
int W = _heatmap.shape(3);
int num_points = H * W;
int border = (kernel - 1) / 2;
for (int i = 0; i < K; i++) {
int offset = i * H * W;
float* data = const_cast<float*>(_heatmap.data<float>()) + offset;
float origin_max = *std::max_element(data, data + num_points);
cv::Mat work = cv::Mat(H + 2 * border, W + 2 * border, CV_32FC1, cv::Scalar{});
cv::Mat curr = cv::Mat(H, W, CV_32FC1, data);
cv::Rect roi = {border, border, W, H};
curr.copyTo(work(roi));
cv::GaussianBlur(work, work, {kernel, kernel}, 0);
cv::Mat valid = work(roi).clone();
float cur_max = *std::max_element((float*)valid.data, (float*)valid.data + num_points);
float* dst = heatmap.data<float>() + offset;
std::transform((float*)valid.data, (float*)valid.data + num_points, dst,
[&](float v) { return v * origin_max / cur_max; });
}
return heatmap;
}
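// Returns a [1, K, 3] tensor holding (x, y, maxval) of each channel's maximum;
// x and y are set to -1 when maxval <= 0.
Tensor get_max_pred(const Tensor& heatmap) {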
int K = heatmap.shape(1);
int H = heatmap.shape(2);
int W = heatmap.shape(3);
int num_points = H * W;
TensorDesc pred_desc = {Device{"cpu"}, DataType::kFLOAT, {1, K, 3}};
Tensor pred(pred_desc);
cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) {
for (int i = r.start; i < r.end; i++) {
float* src_data = const_cast<float*>(heatmap.data<float>()) + i * H * W;
cv::Mat mat = cv::Mat(H, W, CV_32FC1, src_data);
double min_val, max_val;
cv::Point min_loc, max_loc;
cv::minMaxLoc(mat, &min_val, &max_val, &min_loc, &max_loc);
float* dst_data = pred.data<float>() + i * 3;
*(dst_data + 0) = -1;
*(dst_data + 1) = -1;
*(dst_data + 2) = max_val;
if (max_val > 0.0) {
*(dst_data + 0) = max_loc.x;
*(dst_data + 1) = max_loc.y;
}
}
}});
return pred;
}
private:
bool flip_test_{true};
bool shift_heatmap_{true};
string post_process_ = {"default"};
int modulate_kernel_{11};
bool unbiased_decoding_{false};
float valid_radius_factor_{0.0546875f};
bool use_udp_{false};
string target_type_{"GaussianHeatmap"};
};
REGISTER_CODEBASE_COMPONENT(MMPose, TopdownHeatmapBaseHeadDecode);
// All of the heads below share the same decoding process.
using TopdownHeatmapSimpleHeadDecode = TopdownHeatmapBaseHeadDecode;
REGISTER_CODEBASE_COMPONENT(MMPose, TopdownHeatmapSimpleHeadDecode);
using TopdownHeatmapMultiStageHeadDecode = TopdownHeatmapBaseHeadDecode;
REGISTER_CODEBASE_COMPONENT(MMPose, TopdownHeatmapMultiStageHeadDecode);
using ViPNASHeatmapSimpleHeadDecode = TopdownHeatmapBaseHeadDecode;
REGISTER_CODEBASE_COMPONENT(MMPose, ViPNASHeatmapSimpleHeadDecode);
using TopdownHeatmapMSMUHeadDecode = TopdownHeatmapBaseHeadDecode;
REGISTER_CODEBASE_COMPONENT(MMPose, TopdownHeatmapMSMUHeadDecode);
} // namespace mmdeploy::mmpose