mirror of https://github.com/open-mmlab/mmcv.git
add ceil_mode and with_last_pool to vgg
parent
d4d1108a88
commit
020348ecb1
|
@ -16,7 +16,8 @@ def conv3x3(in_planes, out_planes, dilation=1):
|
|||
dilation=dilation)
|
||||
|
||||
|
||||
def make_vgg_layer(inplanes, planes, num_blocks, dilation=1, with_bn=False):
|
||||
def make_vgg_layer(inplanes, planes, num_blocks, dilation=1, with_bn=False,
|
||||
ceil_mode=False):
|
||||
layers = []
|
||||
for _ in range(num_blocks):
|
||||
layers.append(conv3x3(inplanes, planes, dilation))
|
||||
|
@ -24,7 +25,7 @@ def make_vgg_layer(inplanes, planes, num_blocks, dilation=1, with_bn=False):
|
|||
layers.append(nn.BatchNorm2d(planes))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
inplanes = planes
|
||||
layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
|
||||
layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
|
||||
|
||||
return layers
|
||||
|
||||
|
@ -62,7 +63,9 @@ class VGG(nn.Module):
|
|||
out_indices=(0, 1, 2, 3, 4),
|
||||
frozen_stages=-1,
|
||||
bn_eval=True,
|
||||
bn_frozen=False):
|
||||
bn_frozen=False,
|
||||
ceil_mode=False,
|
||||
with_last_pool=True):
|
||||
super(VGG, self).__init__()
|
||||
if depth not in self.arch_settings:
|
||||
raise KeyError('invalid depth {} for vgg'.format(depth))
|
||||
|
@ -92,11 +95,14 @@ class VGG(nn.Module):
|
|||
planes,
|
||||
num_blocks,
|
||||
dilation=dilation,
|
||||
with_bn=with_bn)
|
||||
with_bn=with_bn,
|
||||
ceil_mode=ceil_mode)
|
||||
vgg_layers.extend(vgg_layer)
|
||||
self.inplanes = planes
|
||||
self.range_sub_modules.append([start_idx, end_idx])
|
||||
start_idx = end_idx
|
||||
if not with_last_pool:
|
||||
vgg_layers.pop(-1)
|
||||
self.module_name = 'features'
|
||||
self.add_module(self.module_name, nn.Sequential(*vgg_layers))
|
||||
|
||||
|
|
Loading…
Reference in New Issue