fix dnl_head export onnx inference difference type Cast error (#1161)
* fix export onnx inference difference type Cast error * fix export onnx inference difference type Cast error. * use yapf format * use same device type with pairwise_weightpull/1801/head
parent
941d3619c0
commit
06b3533ee5
|
@ -26,8 +26,13 @@ class DisentangledNonLocal2d(NonLocal2d):
|
|||
pairwise_weight = torch.matmul(theta_x, phi_x)
|
||||
if self.use_scale:
|
||||
# theta_x.shape[-1] is `self.inter_channels`
|
||||
pairwise_weight /= theta_x.shape[-1]**0.5
|
||||
pairwise_weight /= self.temperature
|
||||
pairwise_weight /= torch.tensor(
|
||||
theta_x.shape[-1],
|
||||
dtype=torch.float,
|
||||
device=pairwise_weight.device)**torch.tensor(
|
||||
0.5, device=pairwise_weight.device)
|
||||
pairwise_weight /= torch.tensor(
|
||||
self.temperature, device=pairwise_weight.device)
|
||||
pairwise_weight = pairwise_weight.softmax(dim=-1)
|
||||
return pairwise_weight
|
||||
|
||||
|
|
Loading…
Reference in New Issue