Update resnet.py

pull/742/head
cuicheng01 2021-05-28 11:32:51 +08:00 committed by GitHub
parent 90321ce38d
commit e6792d34da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -294,7 +294,7 @@ class ResNet(TheseusLayer):
self.avg_pool_channels = self.num_channels[-1] * 2
stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0)
self.out = Linear(
self.fc = Linear(
self.avg_pool_channels,
self.class_num,
weight_attr=ParamAttr(
@ -306,7 +306,7 @@ class ResNet(TheseusLayer):
x = self.blocks(x)
x = self.avg_pool(x)
x = paddle.reshape(x, shape=[-1, self.avg_pool_channels])
x = self.out(x)
x = self.fc(x)
return x