mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
mmclassification ConformerHead support (#1905)
* mmclassification ConformerHead support * add mmclassification ConformerHead test config --------- Co-authored-by: lishengxi <mtdp@MacBook-Pro-8.local>
This commit is contained in:
parent
a14177c0eb
commit
032ce75afa
@ -35,9 +35,11 @@ def base_classifier__forward(
|
||||
if self.head is not None:
|
||||
output = self.head(output)
|
||||
|
||||
from mmcls.models.heads import MultiLabelClsHead
|
||||
from mmcls.models.heads import ConformerHead, MultiLabelClsHead
|
||||
if isinstance(self.head, MultiLabelClsHead):
|
||||
output = torch.sigmoid(output)
|
||||
elif isinstance(self.head, ConformerHead):
|
||||
output = F.softmax(torch.add(output[0], output[1]), dim=1)
|
||||
else:
|
||||
output = F.softmax(output, dim=1)
|
||||
return output
|
||||
|
@ -228,3 +228,11 @@ models:
|
||||
- *pipeline_ort_static_fp32
|
||||
- convert_image: *convert_image
|
||||
deploy_config: configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py
|
||||
|
||||
- name: Conformer
|
||||
metafile: configs/conformer/metafile.yml
|
||||
model_configs:
|
||||
- configs/conformer/conformer-tiny-p16_8xb128_in1k.py
|
||||
pipelines:
|
||||
- *pipeline_ort_dynamic_fp32
|
||||
- *pipeline_trt_dynamic_fp32
|
||||
|
Loading…
x
Reference in New Issue
Block a user