Update resnet.py

pull/742/head
cuicheng01 2021-05-28 11:29:37 +08:00 committed by GitHub
parent 45b02b86af
commit 90321ce38d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 10 deletions

View File

@ -82,7 +82,7 @@ class ConvBNLayer(TheseusLayer):
super().__init__()
self.is_vd_mode = is_vd_mode
self.act = act
self.avgpool = AvgPool2D(
self.avg_pool = AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.conv = Conv2D(
in_channels=num_channels,
@ -101,7 +101,7 @@ class ConvBNLayer(TheseusLayer):
def forward(self, x):
if self.is_vd_mode:
x = self.avgpool(x)
x = self.avg_pool(x)
x = self.conv(x)
x = self.bn(x)
if self.act:
@ -273,7 +273,7 @@ class ResNet(TheseusLayer):
for in_c, out_c, k, s in self.stem_cfg[version]
])
self.maxpool = MaxPool2D(kernel_size=3, stride=2, padding=1)
self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
block_list = []
for block_idx in range(len(self.block_depth)):
shortcut = False
@ -290,22 +290,22 @@ class ResNet(TheseusLayer):
shortcut = True
self.blocks = nn.Sequential(*block_list)
self.avgpool = AdaptiveAvgPool2D(1)
self.avgpool_channels = self.num_channels[-1] * 2
self.avg_pool = AdaptiveAvgPool2D(1)
self.avg_pool_channels = self.num_channels[-1] * 2
stdv = 1.0 / math.sqrt(self.avgpool_channels * 1.0)
stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0)
self.out = Linear(
self.avgpool_channels,
self.avg_pool_channels,
self.class_num,
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv)))
def forward(self, x):
x = self.stem(x)
x = self.maxpool(x)
x = self.max_pool(x)
x = self.blocks(x)
x = self.avgpool(x)
x = paddle.reshape(x, shape=[-1, self.avgpool_channels])
x = self.avg_pool(x)
x = paddle.reshape(x, shape=[-1, self.avg_pool_channels])
x = self.out(x)
return x