[Enhancement] Support DETR (#924)
* add detr support * fix softmax * add reg test, update documentpull/960/head
parent
f7e0905e95
commit
f4decda86e
|
@ -773,6 +773,19 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](../
|
|||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center"><a href="https://github.com/open-mmlab/mmdetection/tree/master/configs/detr/detr_r50_8x2_150e_coco.py">DETR</a></td>
|
||||
<td align="center">Object Detection</td>
|
||||
<td align="center">COCO2017</td>
|
||||
<td align="center">box AP</td>
|
||||
<td align="center">40.1</td>
|
||||
<td align="center">40.1</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">40.1</td>
|
||||
<td align="center">40.1</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" rowspan="2"><a href="https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py">Mask R-CNN</a></td>
|
||||
<td align="center" rowspan="2">Instance Segmentation</td>
|
||||
|
|
|
@ -2,80 +2,82 @@
|
|||
|
||||
The table below lists the models that are guaranteed to be exportable to other backends.
|
||||
|
||||
| Model | Codebase | TorchScript | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | Model config |
|
||||
| :------------------------- | :--------------- | :---------: | :---------: | :------: | :--: | :---: | :------: | :---------------------------------------------------------------------------------------------: |
|
||||
| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) |
|
||||
| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
|
||||
| YOLOv3 | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) |
|
||||
| YOLOX | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) |
|
||||
| FCOS | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fcos) |
|
||||
| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) |
|
||||
| Mask R-CNN | MMDetection | Y | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) |
|
||||
| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) |
|
||||
| FoveaBox | MMDetection | Y | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) |
|
||||
| ATSS | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) |
|
||||
| GFL | MMDetection | N | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
|
||||
| Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
|
||||
| Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
|
||||
| Swin Transformer | MMDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) |
|
||||
| VFNet | MMDetection | N | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) |
|
||||
| ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) |
|
||||
| ResNeXt | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) |
|
||||
| SE-ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) |
|
||||
| MobileNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2) |
|
||||
| ShuffleNetV1 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1) |
|
||||
| ShuffleNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2) |
|
||||
| VisionTransformer | MMClassification | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/vision_transformer) |
|
||||
| SwinTransformer | MMClassification | Y | Y | Y | N | ? | N | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/swin_transformer) |
|
||||
| FCN | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fcn) |
|
||||
| PSPNet[\*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/pspnet) |
|
||||
| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) |
|
||||
| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) |
|
||||
| Fast-SCNN[\*static](#note) | MMSegmentation | Y | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) |
|
||||
| UNet | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
|
||||
| ANN[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) |
|
||||
| APCNet | MMSegmentation | ? | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) |
|
||||
| BiSeNetV1 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) |
|
||||
| BiSeNetV2 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv2) |
|
||||
| CGNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/cgnet) |
|
||||
| DMNet | MMSegmentation | ? | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dmnet) |
|
||||
| DNLNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dnlnet) |
|
||||
| EMANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/emanet) |
|
||||
| EncNet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/encnet) |
|
||||
| ERFNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/erfnet) |
|
||||
| FastFCN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastfcn) |
|
||||
| GCNet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/gcnet) |
|
||||
| ICNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/icnet) |
|
||||
| ISANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/isanet) |
|
||||
| NonLocal Net | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/nonlocal_net) |
|
||||
| OCRNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ocrnet) |
|
||||
| PointRend | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/point_rend) |
|
||||
| Semantic FPN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/sem_fpn) |
|
||||
| STDC | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/stdc) |
|
||||
| UPerNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/upernet) |
|
||||
| DANet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/danet) |
|
||||
| Segmenter | MMSegmentation | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/segmenter) |
|
||||
| SRCNN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srcnn) |
|
||||
| ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) |
|
||||
| SRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||
| SRResNet | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||
| Real-ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/real_esrgan) |
|
||||
| EDSR | MMEditing | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/edsr) |
|
||||
| RDN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/rdn) |
|
||||
| DBNet | MMOCR | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/dbnet) |
|
||||
| PANet | MMOCR | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/panet) |
|
||||
| PSENet | MMOCR | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/psenet) |
|
||||
| CRNN | MMOCR | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/crnn) |
|
||||
| SAR[\*](#note) | MMOCR | N | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/sar) |
|
||||
| SATRN | MMOCR | Y | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/satrn) |
|
||||
| HRNet | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hrnet-cvpr-2019) |
|
||||
| MSPN | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) |
|
||||
| LiteHRNet | MMPose | N | Y | Y | N | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) |
|
||||
| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
|
||||
| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
|
||||
| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
|
||||
| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) |
|
||||
| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) |
|
||||
| Model | Codebase | TorchScript | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | Model config |
|
||||
| :-------------------------- | :--------------- | :---------: | :---------: | :------: | :--: | :---: | :------: | :---------------------------------------------------------------------------------------------: |
|
||||
| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) |
|
||||
| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
|
||||
| YOLOv3 | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) |
|
||||
| YOLOX | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) |
|
||||
| FCOS | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fcos) |
|
||||
| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) |
|
||||
| Mask R-CNN | MMDetection | Y | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) |
|
||||
| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) |
|
||||
| FoveaBox | MMDetection | Y | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) |
|
||||
| ATSS | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) |
|
||||
| GFL | MMDetection | N | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
|
||||
| Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
|
||||
| Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
|
||||
| Swin Transformer[\*](#note) | MMDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) |
|
||||
| VFNet | MMDetection | N | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) |
|
||||
| RepPoints | MMDetection | N | N | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) |
|
||||
| DETR | MMDetection | N | Y | Y | N | ? | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/detr) |
|
||||
| ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) |
|
||||
| ResNeXt | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) |
|
||||
| SE-ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) |
|
||||
| MobileNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2) |
|
||||
| ShuffleNetV1 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1) |
|
||||
| ShuffleNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2) |
|
||||
| VisionTransformer | MMClassification | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/vision_transformer) |
|
||||
| SwinTransformer | MMClassification | Y | Y | Y | N | ? | N | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/swin_transformer) |
|
||||
| FCN | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fcn) |
|
||||
| PSPNet[\*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/pspnet) |
|
||||
| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) |
|
||||
| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) |
|
||||
| Fast-SCNN[\*static](#note) | MMSegmentation | Y | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) |
|
||||
| UNet | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
|
||||
| ANN[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) |
|
||||
| APCNet | MMSegmentation | ? | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) |
|
||||
| BiSeNetV1 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) |
|
||||
| BiSeNetV2 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv2) |
|
||||
| CGNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/cgnet) |
|
||||
| DMNet | MMSegmentation | ? | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dmnet) |
|
||||
| DNLNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dnlnet) |
|
||||
| EMANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/emanet) |
|
||||
| EncNet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/encnet) |
|
||||
| ERFNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/erfnet) |
|
||||
| FastFCN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastfcn) |
|
||||
| GCNet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/gcnet) |
|
||||
| ICNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/icnet) |
|
||||
| ISANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/isanet) |
|
||||
| NonLocal Net | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/nonlocal_net) |
|
||||
| OCRNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ocrnet) |
|
||||
| PointRend | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/point_rend) |
|
||||
| Semantic FPN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/sem_fpn) |
|
||||
| STDC | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/stdc) |
|
||||
| UPerNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/upernet) |
|
||||
| DANet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/danet) |
|
||||
| Segmenter | MMSegmentation | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/segmenter) |
|
||||
| SRCNN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srcnn) |
|
||||
| ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) |
|
||||
| SRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||
| SRResNet | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||
| Real-ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/real_esrgan) |
|
||||
| EDSR | MMEditing | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/edsr) |
|
||||
| RDN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/rdn) |
|
||||
| DBNet | MMOCR | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/dbnet) |
|
||||
| PANet | MMOCR | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/panet) |
|
||||
| DBNet | MMOCR | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/psenet) |
|
||||
| CRNN | MMOCR | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/crnn) |
|
||||
| SAR | MMOCR | N | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/sar) |
|
||||
| SATRN | MMOCR | Y | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/satrn) |
|
||||
| HRNet | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hrnet-cvpr-2019) |
|
||||
| MSPN | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) |
|
||||
| LiteHRNet | MMPose | N | Y | Y | N | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) |
|
||||
| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
|
||||
| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
|
||||
| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
|
||||
| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) |
|
||||
| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) |
|
||||
|
||||
### Note
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/
|
|||
| Faster R-CNN + DCN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
|
||||
| GFL | ObjectDetection | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
|
||||
| RepPoints | ObjectDetection | N | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) |
|
||||
| DETR | ObjectDetection | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/detr) |
|
||||
| Cascade Mask R-CNN | InstanceSegmentation | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
|
||||
| Mask R-CNN | InstanceSegmentation | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) |
|
||||
| Swin Transformer | InstanceSegmentation | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) |
|
||||
|
|
|
@ -749,6 +749,19 @@ GPU: ncnn, TensorRT, PPLNN
|
|||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center"><a href="https://github.com/open-mmlab/mmdetection/tree/master/configs/detr/detr_r50_8x2_150e_coco.py">DETR</a></td>
|
||||
<td align="center">Object Detection</td>
|
||||
<td align="center">COCO2017</td>
|
||||
<td align="center">box AP</td>
|
||||
<td align="center">40.1</td>
|
||||
<td align="center">40.1</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">40.1</td>
|
||||
<td align="center">40.1</td>
|
||||
<td align="center">-</td>
|
||||
<td align="center">-</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center" rowspan="2"><a href="https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py">Mask R-CNN</a></td>
|
||||
<td align="center" rowspan="2">Instance Segmentation</td>
|
||||
|
|
|
@ -2,81 +2,79 @@
|
|||
|
||||
自测完成的 model-backend 组合:
|
||||
|
||||
| Model | Codebase | TorchScript | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | Model config |
|
||||
| :------------------------- | :--------------- | :---------: | :---------: | :------: | :--: | :---: | :------: | :---------------------------------------------------------------------------------------------: |
|
||||
| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) |
|
||||
| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
|
||||
| YOLOv3 | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) |
|
||||
| YOLOX | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) |
|
||||
| FCOS | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fcos) |
|
||||
| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) |
|
||||
| Mask R-CNN | MMDetection | Y | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) |
|
||||
| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) |
|
||||
| FoveaBox | MMDetection | Y | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) |
|
||||
| ATSS | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) |
|
||||
| GFL | MMDetection | N | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
|
||||
| Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
|
||||
| Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
|
||||
| Swin Transformer | MMDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) |
|
||||
| VFNet | MMDetection | N | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) |
|
||||
| RepPoints | MMDetection | N | N | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) |
|
||||
| ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) |
|
||||
| ResNeXt | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) |
|
||||
| SE-ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) |
|
||||
| MobileNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2) |
|
||||
| ShuffleNetV1 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1) |
|
||||
| ShuffleNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2) |
|
||||
| VisionTransformer | MMClassification | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/vision_transformer) |
|
||||
| SwinTransformer | MMClassification | Y | Y | Y | N | ? | N | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/swin_transformer) |
|
||||
| FCN | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fcn) |
|
||||
| PSPNet[\*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/pspnet) |
|
||||
| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) |
|
||||
| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) |
|
||||
| Fast-SCNN[\*static](#note) | MMSegmentation | Y | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) |
|
||||
| UNet | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
|
||||
| ANN[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) |
|
||||
| APCNet | MMSegmentation | ? | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) |
|
||||
| BiSeNetV1 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) |
|
||||
| BiSeNetV2 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv2) |
|
||||
| CGNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/cgnet) |
|
||||
| DMNet | MMSegmentation | ? | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dmnet) |
|
||||
| DNLNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dnlnet) |
|
||||
| EMANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/emanet) |
|
||||
| EncNet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/encnet) |
|
||||
| ERFNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/erfnet) |
|
||||
| FastFCN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastfcn) |
|
||||
| GCNet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/gcnet) |
|
||||
| ICNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/icnet) |
|
||||
| ISANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/isanet) |
|
||||
| NonLocal Net | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/nonlocal_net) |
|
||||
| OCRNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ocrnet) |
|
||||
| PointRend | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/point_rend) |
|
||||
| Semantic FPN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/sem_fpn) |
|
||||
| STDC | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/stdc) |
|
||||
| UPerNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/upernet) |
|
||||
| DANet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/danet) |
|
||||
| Segmenter | MMSegmentation | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/segmenter) |
|
||||
| SRCNN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srcnn) |
|
||||
| ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) |
|
||||
| SRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||
| SRResNet | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||
| Real-ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/real_esrgan) |
|
||||
| EDSR | MMEditing | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/edsr) |
|
||||
| RDN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/rdn) |
|
||||
| DBNet | MMOCR | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/dbnet) |
|
||||
| PANet | MMOCR | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/panet) |
|
||||
| PSENet | MMOCR | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/psenet) |
|
||||
| CRNN | MMOCR | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/crnn) |
|
||||
| SAR[\*](#note) | MMOCR | N | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/sar) |
|
||||
| SATRN | MMOCR | Y | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/satrn) |
|
||||
| HRNet | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hrnet-cvpr-2019) |
|
||||
| MSPN | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) |
|
||||
| LiteHRNet | MMPose | N | Y | Y | N | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) |
|
||||
| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
|
||||
| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
|
||||
| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
|
||||
| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) |
|
||||
| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) |
|
||||
| Model | Codebase | TorchScript | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | Model config |
|
||||
| :-------------------------- | :--------------- | :---------: | :---------: | :------: | :--: | :---: | :------: | :---------------------------------------------------------------------------------------------: |
|
||||
| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) |
|
||||
| Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
|
||||
| YOLOv3 | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) |
|
||||
| YOLOX | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) |
|
||||
| FCOS | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fcos) |
|
||||
| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) |
|
||||
| Mask R-CNN | MMDetection | Y | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) |
|
||||
| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) |
|
||||
| FoveaBox | MMDetection | Y | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) |
|
||||
| ATSS | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) |
|
||||
| GFL | MMDetection | N | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
|
||||
| Cascade R-CNN | MMDetection | N | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
|
||||
| Cascade Mask R-CNN | MMDetection | N | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
|
||||
| Swin Transformer[\*](#note) | MMDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) |
|
||||
| VFNet | MMDetection | N | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) |
|
||||
| RepPoints | MMDetection | N | N | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) |
|
||||
| DETR | MMDetection | N | Y | Y | N | ? | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/detr) |
|
||||
| ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet) |
|
||||
| ResNeXt | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext) |
|
||||
| SE-ResNet | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet) |
|
||||
| MobileNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2) |
|
||||
| ShuffleNetV1 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1) |
|
||||
| ShuffleNetV2 | MMClassification | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2) |
|
||||
| VisionTransformer | MMClassification | Y | Y | Y | Y | ? | Y | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/vision_transformer) |
|
||||
| SwinTransformer | MMClassification | Y | Y | Y | N | ? | N | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/swin_transformer) |
|
||||
| FCN | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fcn) |
|
||||
| PSPNet[\*static](#note) | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/pspnet) |
|
||||
| DeepLabV3 | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3) |
|
||||
| DeepLabV3+ | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/deeplabv3plus) |
|
||||
| Fast-SCNN[\*static](#note) | MMSegmentation | Y | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastscnn) |
|
||||
| UNet | MMSegmentation | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/unet) |
|
||||
| ANN[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ann) |
|
||||
| APCNet | MMSegmentation | ? | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/apcnet) |
|
||||
| BiSeNetV1 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv1) |
|
||||
| BiSeNetV2 | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/bisenetv2) |
|
||||
| CGNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/cgnet) |
|
||||
| DMNet | MMSegmentation | ? | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dmnet) |
|
||||
| DNLNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/dnlnet) |
|
||||
| EMANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/emanet) |
|
||||
| EncNet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/encnet) |
|
||||
| ERFNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/erfnet) |
|
||||
| FastFCN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/fastfcn) |
|
||||
| GCNet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/gcnet) |
|
||||
| ICNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/icnet) |
|
||||
| ISANet | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/isanet) |
|
||||
| NonLocal Net | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/nonlocal_net) |
|
||||
| OCRNet | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/ocrnet) |
|
||||
| PointRend | MMSegmentation | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/point_rend) |
|
||||
| Semantic FPN | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/sem_fpn) |
|
||||
| STDC | MMSegmentation | ? | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/stdc) |
|
||||
| UPerNet[\*](#note) | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/upernet) |
|
||||
| DANet | MMSegmentation | ? | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/danet) |
|
||||
| Segmenter | MMSegmentation | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/segmenter) |
|
||||
| SRCNN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srcnn) |
|
||||
| ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) |
|
||||
| SRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||
| SRResNet | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||
| Real-ESRGAN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/real_esrgan) |
|
||||
| EDSR | MMEditing | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/edsr) |
|
||||
| RDN | MMEditing | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/rdn) |
|
||||
| DBNet | MMOCR | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/dbnet) |
|
||||
| CRNN | MMOCR | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/crnn) |
|
||||
| SAR | MMOCR | N | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/sar) |
|
||||
| HRNet | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hrnet-cvpr-2019) |
|
||||
| MSPN | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) |
|
||||
| LiteHRNet | MMPose | N | Y | Y | N | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) |
|
||||
| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
|
||||
| CenterPoint (pillar) | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/centerpoint) |
|
||||
| RotatedRetinaNet | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/rotated_retinanet/README.md) |
|
||||
| Oriented RCNN | RotatedDetection | N | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/oriented_rcnn/README.md) |
|
||||
| Gliding Vertex | RotatedDetection | N | N | Y | N | N | N | [config](https://github.com/open-mmlab/mmrotate/blob/main/configs/gliding_vertex/README.md) |
|
||||
|
||||
## Note
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ mmdet 是基于 pytorch 的检测工具箱,属于 [OpenMMLab](https://openmmla
|
|||
| Faster R-CNN + DCN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
|
||||
| GFL | ObjectDetection | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
|
||||
| RepPoints | ObjectDetection | N | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints) |
|
||||
| DETR | ObjectDetection | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/detr) |
|
||||
| Cascade Mask R-CNN | InstanceSegmentation | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
|
||||
| Mask R-CNN | InstanceSegmentation | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) |
|
||||
| Swin Transformer | InstanceSegmentation | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/swin) |
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import detr_head
|
||||
from .base_dense_head import (base_dense_head__get_bbox,
|
||||
base_dense_head__get_bboxes__ncnn)
|
||||
from .fovea_head import fovea_head__get_bboxes
|
||||
|
@ -15,5 +16,5 @@ __all__ = [
|
|||
'yolox_head__get_bboxes', 'base_dense_head__get_bbox',
|
||||
'fovea_head__get_bboxes', 'base_dense_head__get_bboxes__ncnn',
|
||||
'ssd_head__get_bboxes__ncnn', 'yolox_head__get_bboxes__ncnn',
|
||||
'gfl_head__get_bbox', 'reppoints_head__get_bboxes'
|
||||
'gfl_head__get_bbox', 'reppoints_head__get_bboxes', 'detr_head'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmdet.core import bbox_cxcywh_to_xyxy
|
||||
from torch.nn import functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.dense_heads.DETRHead.forward_single')
|
||||
def detrhead__forward_single__default(ctx, self, x, img_metas):
|
||||
"""forward_single of DETRHead.
|
||||
|
||||
Ease the mask computation
|
||||
"""
|
||||
|
||||
batch_size = x.size(0)
|
||||
|
||||
x = self.input_proj(x)
|
||||
# interpolate masks to have the same spatial shape with x
|
||||
masks = x.new_zeros((batch_size, x.size(-2), x.size(-1))).to(torch.bool)
|
||||
|
||||
# position encoding
|
||||
pos_embed = self.positional_encoding(masks) # [bs, embed_dim, h, w]
|
||||
# outs_dec: [nb_dec, bs, num_query, embed_dim]
|
||||
outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
|
||||
pos_embed)
|
||||
all_cls_scores = self.fc_cls(outs_dec)
|
||||
all_bbox_preds = self.fc_reg(self.activate(
|
||||
self.reg_ffn(outs_dec))).sigmoid()
|
||||
return all_cls_scores, all_bbox_preds
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.dense_heads.DETRHead.get_bboxes')
|
||||
def detrhead__get_bboxes__default(ctx,
|
||||
self,
|
||||
all_cls_scores_list,
|
||||
all_bbox_preds_list,
|
||||
img_metas,
|
||||
rescale=False):
|
||||
"""Rewrite `get_bboxes` of `FoveaHead` for default backend."""
|
||||
cls_scores = all_cls_scores_list[-1][-1]
|
||||
bbox_preds = all_bbox_preds_list[-1][-1]
|
||||
|
||||
img_shape = img_metas[0]['img_shape']
|
||||
max_per_img = self.test_cfg.get('max_per_img', self.num_query)
|
||||
batch_size = cls_scores.size(0)
|
||||
# `batch_index_offset` is used for the gather of concatenated tensor
|
||||
|
||||
# supports dynamical batch inference
|
||||
if self.loss_cls.use_sigmoid:
|
||||
batch_index_offset = torch.arange(batch_size).to(
|
||||
cls_scores.device) * max_per_img
|
||||
batch_index_offset = batch_index_offset.unsqueeze(1).expand(
|
||||
batch_size, max_per_img)
|
||||
cls_scores = cls_scores.sigmoid()
|
||||
scores, indexes = cls_scores.flatten(1).topk(max_per_img, dim=1)
|
||||
det_labels = indexes % self.num_classes
|
||||
bbox_index = indexes // self.num_classes
|
||||
bbox_index = (bbox_index + batch_index_offset).view(-1)
|
||||
bbox_preds = bbox_preds.view(-1, 4)[bbox_index]
|
||||
bbox_preds = bbox_preds.view(batch_size, -1, 4)
|
||||
else:
|
||||
scores, det_labels = F.softmax(cls_scores, dim=-1)[..., :-1].max(-1)
|
||||
scores, bbox_index = scores.topk(max_per_img, dim=1)
|
||||
batch_inds = torch.arange(
|
||||
batch_size, device=scores.device).unsqueeze(-1)
|
||||
bbox_preds = bbox_preds[batch_inds, bbox_index, ...]
|
||||
# add unsqueeze to support tensorrt
|
||||
det_labels = det_labels.unsqueeze(-1)[batch_inds, bbox_index,
|
||||
...].squeeze(-1)
|
||||
|
||||
det_bboxes = bbox_cxcywh_to_xyxy(bbox_preds)
|
||||
|
||||
if isinstance(img_shape, torch.Tensor):
|
||||
hw = img_shape.flip(0).to(det_bboxes.device)
|
||||
else:
|
||||
hw = det_bboxes.new_tensor([img_shape[1], img_shape[0]])
|
||||
shape_scale = torch.cat([hw, hw])
|
||||
shape_scale = shape_scale.view(1, 1, -1)
|
||||
det_bboxes = det_bboxes * shape_scale
|
||||
# dynamically clip bboxes
|
||||
x1, y1, x2, y2 = det_bboxes.split((1, 1, 1, 1), dim=-1)
|
||||
from mmdeploy.codebase.mmdet.deploy import clip_bboxes
|
||||
x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, img_shape)
|
||||
det_bboxes = torch.cat([x1, y1, x2, y2], dim=-1)
|
||||
det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(-1)), -1)
|
||||
|
||||
return det_bboxes, det_labels
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import multi_head_attention_forward
|
||||
from .adaptive_pool import (adaptive_avg_pool2d__default,
|
||||
adaptive_avg_pool2d__ncnn)
|
||||
from .atan2 import atan2__default
|
||||
|
@ -23,5 +24,5 @@ __all__ = [
|
|||
'triu__default', 'atan2__default', 'normalize__ncnn', 'expand__ncnn',
|
||||
'chunk__torchscript', 'masked_fill__onnxruntime',
|
||||
'tensor__setitem__default', 'adaptive_avg_pool2d__default',
|
||||
'adaptive_avg_pool2d__ncnn'
|
||||
'adaptive_avg_pool2d__ncnn', 'multi_head_attention_forward'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils.constants import Backend
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.nn.functional._scaled_dot_product_attention',
|
||||
backend=Backend.TENSORRT.value)
|
||||
def _scaled_dot_product_attention__default(
|
||||
ctx,
|
||||
q: Tensor,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Rewrite `_scaled_dot_product_attention` to enable softmax."""
|
||||
B, Nt, E = q.shape
|
||||
q = q / math.sqrt(E)
|
||||
# (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
|
||||
attn = torch.bmm(q, k.transpose(-2, -1))
|
||||
if attn_mask is not None:
|
||||
attn += attn_mask
|
||||
# add slice to enable softmax
|
||||
# TODO: Find the reason
|
||||
step = 500
|
||||
if attn.size(-1) > step:
|
||||
attn_max = attn[..., :step].max(-1, keepdim=True)[0]
|
||||
for i in range(step, attn.size(-1), step):
|
||||
attn_max_new = attn[..., i:i + step].max(-1, keepdim=True)[0]
|
||||
attn_max = attn_max.where(attn_max > attn_max_new, attn_max_new)
|
||||
else:
|
||||
attn_max = attn.max(-1, keepdim=True)[0]
|
||||
|
||||
attn = attn - attn_max
|
||||
attn_exp = attn.exp()
|
||||
if attn_exp.size(-1) > step:
|
||||
attn_sum = attn_exp[..., :step].sum(-1, keepdim=True)
|
||||
for i in range(step, attn_exp.size(-1), step):
|
||||
attn_sum_new = attn_exp[..., i:i + step].sum(-1, keepdim=True)
|
||||
attn_sum += attn_sum_new
|
||||
else:
|
||||
attn_sum = attn_exp.sum(-1, keepdim=True)
|
||||
attn = attn_exp / attn_sum
|
||||
|
||||
# (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
|
||||
output = torch.bmm(attn, v)
|
||||
return output, attn
|
|
@ -323,3 +323,11 @@ models:
|
|||
pipelines:
|
||||
- *pipeline_seg_ort_dynamic_fp32
|
||||
- *pipeline_seg_trt_dynamic_fp32
|
||||
|
||||
- name: DETR
|
||||
metafile: configs/detr/metafile.yml
|
||||
model_configs:
|
||||
- configs/detr/detr_r50_8x2_150e_coco.py
|
||||
pipelines:
|
||||
- *pipeline_seg_ort_dynamic_fp32
|
||||
- *pipeline_seg_trt_dynamic_fp32
|
||||
|
|
|
@ -165,6 +165,70 @@ def get_reppoints_head_model():
|
|||
return model
|
||||
|
||||
|
||||
def get_detrhead_model():
|
||||
"""DETR head Config."""
|
||||
from mmdet.models import build_head
|
||||
model = build_head(
|
||||
dict(
|
||||
type='DETRHead',
|
||||
num_classes=4,
|
||||
in_channels=1,
|
||||
transformer=dict(
|
||||
type='Transformer',
|
||||
encoder=dict(
|
||||
type='DetrTransformerEncoder',
|
||||
num_layers=1,
|
||||
transformerlayers=dict(
|
||||
type='BaseTransformerLayer',
|
||||
attn_cfgs=[
|
||||
dict(
|
||||
type='MultiheadAttention',
|
||||
embed_dims=4,
|
||||
num_heads=1,
|
||||
dropout=0.1)
|
||||
],
|
||||
ffn_cfgs=dict(
|
||||
type='FFN',
|
||||
embed_dims=4,
|
||||
feedforward_channels=32,
|
||||
num_fcs=2,
|
||||
ffn_drop=0.,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
),
|
||||
feedforward_channels=32,
|
||||
ffn_dropout=0.1,
|
||||
operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
|
||||
decoder=dict(
|
||||
type='DetrTransformerDecoder',
|
||||
return_intermediate=True,
|
||||
num_layers=1,
|
||||
transformerlayers=dict(
|
||||
type='DetrTransformerDecoderLayer',
|
||||
attn_cfgs=dict(
|
||||
type='MultiheadAttention',
|
||||
embed_dims=4,
|
||||
num_heads=1,
|
||||
dropout=0.1),
|
||||
ffn_cfgs=dict(
|
||||
type='FFN',
|
||||
embed_dims=4,
|
||||
feedforward_channels=32,
|
||||
num_fcs=2,
|
||||
ffn_drop=0.,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
),
|
||||
feedforward_channels=32,
|
||||
ffn_dropout=0.1,
|
||||
operation_order=('self_attn', 'norm', 'cross_attn',
|
||||
'norm', 'ffn', 'norm')),
|
||||
)),
|
||||
positional_encoding=dict(
|
||||
type='SinePositionalEncoding', num_feats=2, normalize=True),
|
||||
test_cfg=dict(max_per_img=100)))
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
def get_single_roi_extractor():
|
||||
"""SingleRoIExtractor Config."""
|
||||
from mmdet.models.roi_heads import SingleRoIExtractor
|
||||
|
@ -1692,3 +1756,40 @@ def test_mlvl_point_generator__single_level_grid_priors__tensorrt(
|
|||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg,
|
||||
run_with_backend=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type, ir_type',
|
||||
[(Backend.ONNXRUNTIME, 'onnx')])
|
||||
def test_detrhead_get_bboxes(backend_type: Backend, ir_type: str):
|
||||
"""Test get_bboxes rewrite of base dense head."""
|
||||
check_backend(backend_type)
|
||||
dense_head = get_detrhead_model()
|
||||
dense_head.cpu().eval()
|
||||
s = 128
|
||||
img_metas = [{
|
||||
'scale_factor': np.ones(4),
|
||||
'pad_shape': (s, s, 3),
|
||||
'img_shape': (s, s, 3)
|
||||
}]
|
||||
|
||||
deploy_cfg = get_deploy_cfg(backend_type, ir_type)
|
||||
|
||||
seed_everything(1234)
|
||||
cls_score = [[torch.rand(1, 100, 5) for i in range(5, 0, -1)]]
|
||||
seed_everything(5678)
|
||||
bboxes = [[torch.rand(1, 100, 4) for i in range(5, 0, -1)]]
|
||||
|
||||
# to get outputs of onnx model after rewrite
|
||||
img_metas[0]['img_shape'] = torch.Tensor([s, s])
|
||||
wrapped_model = WrapModel(dense_head, 'get_bboxes', img_metas=img_metas)
|
||||
rewrite_inputs = {
|
||||
'all_cls_scores_list': cls_score,
|
||||
'all_bbox_preds_list': bboxes,
|
||||
}
|
||||
rewrite_outputs, _ = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg,
|
||||
run_with_backend=False)
|
||||
|
||||
assert rewrite_outputs is not None
|
||||
|
|
|
@ -390,3 +390,57 @@ def test_adaptive_avg_pool2d(output_size):
|
|||
deploy_cfg=deploy_cfg_ort,
|
||||
run_with_backend=True)
|
||||
assert torch.allclose(pytorch_output, rewrite_output[0])
|
||||
|
||||
|
||||
@backend_checker(Backend.TENSORRT)
|
||||
def test_scaled_dot_product_attention():
|
||||
L = 10
|
||||
B = 1
|
||||
E = 4
|
||||
q = k = v = torch.rand(B, L, E)
|
||||
attn_mask = torch.rand(B, L, L)
|
||||
|
||||
from torch.nn.functional import _scaled_dot_product_attention
|
||||
model = WrapFunction(_scaled_dot_product_attention)
|
||||
pytorch_output = model(q, k, v, attn_mask)
|
||||
deploy_cfg_ort = mmcv.Config(
|
||||
dict(
|
||||
onnx_config=dict(
|
||||
input_shape=None,
|
||||
input_names=['q', 'k', 'v', 'attn_mask'],
|
||||
output_names=['output', 'attn']),
|
||||
backend_config=dict(
|
||||
type='tensorrt',
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
q=dict(
|
||||
min_shape=q.shape,
|
||||
opt_shape=q.shape,
|
||||
max_shape=q.shape),
|
||||
k=dict(
|
||||
min_shape=k.shape,
|
||||
opt_shape=k.shape,
|
||||
max_shape=k.shape),
|
||||
v=dict(
|
||||
min_shape=v.shape,
|
||||
opt_shape=v.shape,
|
||||
max_shape=v.shape),
|
||||
attn_mask=dict(
|
||||
min_shape=attn_mask.shape,
|
||||
opt_shape=attn_mask.shape,
|
||||
max_shape=attn_mask.shape)))
|
||||
]),
|
||||
codebase_config=dict(type='mmdet', task='ObjectDetection')))
|
||||
rewrite_output, _ = get_rewrite_outputs(
|
||||
model,
|
||||
model_inputs={
|
||||
'q': q,
|
||||
'k': k,
|
||||
'v': v,
|
||||
'attn_mask': attn_mask
|
||||
},
|
||||
deploy_cfg=deploy_cfg_ort,
|
||||
run_with_backend=True)
|
||||
assert torch.allclose(pytorch_output[0],
|
||||
rewrite_output[0].to(pytorch_output[0].device))
|
||||
|
|
Loading…
Reference in New Issue