mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fix resnest (#490)
This commit is contained in:
parent
bcaf6a8528
commit
d80d53f955
@ -165,14 +165,15 @@ class SplatConv(nn.Layer):
|
||||
|
||||
atten = self.conv3(gap)
|
||||
atten = self.rsoftmax(atten)
|
||||
atten = paddle.reshape(x=atten, shape=[-1, atten.shape[1], 1, 1])
|
||||
|
||||
if self.radix > 1:
|
||||
attens = paddle.split(atten, num_or_sections=self.radix, axis=1)
|
||||
y = paddle.add_n(
|
||||
[split * att for (att, split) in zip(attens, splited)])
|
||||
y = paddle.add_n([
|
||||
paddle.multiply(split, att)
|
||||
for (att, split) in zip(attens, splited)
|
||||
])
|
||||
else:
|
||||
y = x * atten
|
||||
y = paddle.multiply(x, atten)
|
||||
|
||||
return y
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user