From 06b3533ee56fe0164e6770f022c1cc5fa7bea25a Mon Sep 17 00:00:00 2001 From: sshuair Date: Tue, 1 Mar 2022 10:16:08 +0800 Subject: [PATCH] fix dnl_head export onnx inference difference type Cast error (#1161) * fix export onnx inference difference type Cast error * fix export onnx inference difference type Cast error. * use yapf format * use same device type with pairwise_weight --- mmseg/models/decode_heads/dnl_head.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mmseg/models/decode_heads/dnl_head.py b/mmseg/models/decode_heads/dnl_head.py index ab53d9a24..dabf15421 100644 --- a/mmseg/models/decode_heads/dnl_head.py +++ b/mmseg/models/decode_heads/dnl_head.py @@ -26,8 +26,13 @@ class DisentangledNonLocal2d(NonLocal2d): pairwise_weight = torch.matmul(theta_x, phi_x) if self.use_scale: # theta_x.shape[-1] is `self.inter_channels` - pairwise_weight /= theta_x.shape[-1]**0.5 - pairwise_weight /= self.temperature + pairwise_weight /= torch.tensor( + theta_x.shape[-1], + dtype=torch.float, + device=pairwise_weight.device)**torch.tensor( + 0.5, device=pairwise_weight.device) + pairwise_weight /= torch.tensor( + self.temperature, device=pairwise_weight.device) pairwise_weight = pairwise_weight.softmax(dim=-1) return pairwise_weight