Merge pull request #25 from yhcao6/vgg

add ceil_mode and with_last_pool to vgg
This commit is contained in:
Kai Chen 2018-12-07 10:48:29 +08:00 committed by GitHub
commit a097c65fbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -16,7 +16,8 @@ def conv3x3(in_planes, out_planes, dilation=1):
dilation=dilation) dilation=dilation)
def make_vgg_layer(inplanes, planes, num_blocks, dilation=1, with_bn=False): def make_vgg_layer(inplanes, planes, num_blocks, dilation=1, with_bn=False,
ceil_mode=False):
layers = [] layers = []
for _ in range(num_blocks): for _ in range(num_blocks):
layers.append(conv3x3(inplanes, planes, dilation)) layers.append(conv3x3(inplanes, planes, dilation))
@ -24,7 +25,7 @@ def make_vgg_layer(inplanes, planes, num_blocks, dilation=1, with_bn=False):
layers.append(nn.BatchNorm2d(planes)) layers.append(nn.BatchNorm2d(planes))
layers.append(nn.ReLU(inplace=True)) layers.append(nn.ReLU(inplace=True))
inplanes = planes inplanes = planes
layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
return layers return layers
@ -62,7 +63,9 @@ class VGG(nn.Module):
out_indices=(0, 1, 2, 3, 4), out_indices=(0, 1, 2, 3, 4),
frozen_stages=-1, frozen_stages=-1,
bn_eval=True, bn_eval=True,
bn_frozen=False): bn_frozen=False,
ceil_mode=False,
with_last_pool=True):
super(VGG, self).__init__() super(VGG, self).__init__()
if depth not in self.arch_settings: if depth not in self.arch_settings:
raise KeyError('invalid depth {} for vgg'.format(depth)) raise KeyError('invalid depth {} for vgg'.format(depth))
@ -92,11 +95,14 @@ class VGG(nn.Module):
planes, planes,
num_blocks, num_blocks,
dilation=dilation, dilation=dilation,
with_bn=with_bn) with_bn=with_bn,
ceil_mode=ceil_mode)
vgg_layers.extend(vgg_layer) vgg_layers.extend(vgg_layer)
self.inplanes = planes self.inplanes = planes
self.range_sub_modules.append([start_idx, end_idx]) self.range_sub_modules.append([start_idx, end_idx])
start_idx = end_idx start_idx = end_idx
if not with_last_pool:
vgg_layers.pop(-1)
self.module_name = 'features' self.module_name = 'features'
self.add_module(self.module_name, nn.Sequential(*vgg_layers)) self.add_module(self.module_name, nn.Sequential(*vgg_layers))