mmclassification ConformerHead support (#1905)

* mmclassification ConformerHead support

* add mmclassification ConformerHead test config

---------

Co-authored-by: lishengxi <mtdp@MacBook-Pro-8.local>
This commit is contained in:
Shengxi Li 2023-03-23 17:27:39 +08:00 committed by GitHub
parent a14177c0eb
commit 032ce75afa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 1 deletions

View File

@ -35,9 +35,11 @@ def base_classifier__forward(
if self.head is not None:
output = self.head(output)
from mmcls.models.heads import MultiLabelClsHead
from mmcls.models.heads import ConformerHead, MultiLabelClsHead
if isinstance(self.head, MultiLabelClsHead):
output = torch.sigmoid(output)
elif isinstance(self.head, ConformerHead):
output = F.softmax(torch.add(output[0], output[1]), dim=1)
else:
output = F.softmax(output, dim=1)
return output

View File

@ -228,3 +228,11 @@ models:
- *pipeline_ort_static_fp32
- convert_image: *convert_image
deploy_config: configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py
- name: Conformer
metafile: configs/conformer/metafile.yml
model_configs:
- configs/conformer/conformer-tiny-p16_8xb128_in1k.py
pipelines:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp32