fix dnl_head export onnx inference difference type Cast error (#1161)

* fix export onnx inference difference type Cast error

* fix export onnx inference difference type Cast error.

* use yapf format

* use same device type with pairwise_weight
pull/1801/head
sshuair 2022-03-01 10:16:08 +08:00 committed by GitHub
parent 941d3619c0
commit 06b3533ee5
1 changed files with 7 additions and 2 deletions

View File

@ -26,8 +26,13 @@ class DisentangledNonLocal2d(NonLocal2d):
pairwise_weight = torch.matmul(theta_x, phi_x)
if self.use_scale:
# theta_x.shape[-1] is `self.inter_channels`
pairwise_weight /= theta_x.shape[-1]**0.5
pairwise_weight /= self.temperature
pairwise_weight /= torch.tensor(
theta_x.shape[-1],
dtype=torch.float,
device=pairwise_weight.device)**torch.tensor(
0.5, device=pairwise_weight.device)
pairwise_weight /= torch.tensor(
self.temperature, device=pairwise_weight.device)
pairwise_weight = pairwise_weight.softmax(dim=-1)
return pairwise_weight