minor fix init (#384)

* add densent init

* fix export model
pull/385/head
littletomatodonkey 2020-11-10 17:09:39 +08:00 committed by GitHub
parent 00a0f7fb56
commit c933dcd8db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 4 deletions

View File

@ -23,7 +23,7 @@ from .se_resnet_vd import SE_ResNet18_vd, SE_ResNet34_vd, SE_ResNet50_vd, SE_Res
from .se_resnext_vd import SE_ResNeXt50_vd_32x4d, SE_ResNeXt50_vd_32x4d, SENet154_vd
from .se_resnext import SE_ResNeXt50_32x4d, SE_ResNeXt101_32x4d, SE_ResNeXt152_64x4d
from .dpn import DPN68, DPN92, DPN98, DPN107, DPN131
from .densenet import DenseNet121
from .densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264
from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C
from .efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7
from .resnest import ResNeSt50_fast_1s1x64d, ResNeSt50

View File

@ -243,11 +243,12 @@ inp_shape = {
def _drop_connect(inputs, prob, is_test):
if is_test:
return inputs
keep_prob = 1.0 - prob
inputs_shape = paddle.shape(inputs)
random_tensor = keep_prob + paddle.rand(shape=[inputs_shape[0], 1, 1, 1])
binary_tensor = paddle.floor(random_tensor)
output = inputs / keep_prob * binary_tensor
output = paddle.multiply(inputs, binary_tensor) / keep_prob
return output
@ -507,7 +508,8 @@ class SEBlock(nn.Layer):
x = self._pool(inputs)
x = self._conv1(x)
x = self._conv2(x)
return paddle.multiply(inputs, x)
out = paddle.multiply(inputs, x)
return out
class MbConvBlock(nn.Layer):
@ -572,11 +574,13 @@ class MbConvBlock(nn.Layer):
if self.expand_ratio != 1:
x = self._ecn(x)
x = F.swish(x)
x = self._dcn(x)
x = F.swish(x)
if self.has_se:
x = self._se(x)
x = self._pcn(x)
if self.id_skip and \
self.block_args.stride == 1 and \
self.block_args.input_filters == self.block_args.output_filters:

View File

@ -65,11 +65,12 @@ def main():
net = architectures.__dict__[args.model]
model = Net(net, to_static, args.class_dim)
load_dygraph_pretrain(
model.pre_net,
path=args.pretrained_model,
load_static_weights=args.load_static_weights)
model.eval()
paddle.jit.save(model, args.output_path)