Create vgg_variant.py

pull/1166/head
Bin Lu 2021-08-23 19:07:35 +08:00 committed by GitHub
parent d963dc17da
commit 0b9f0ad001
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 28 additions and 0 deletions

View File

@ -0,0 +1,28 @@
import paddle
from paddle.nn import Sigmoid
from ppcls.arch.backbone.legendary_models.vgg import VGG19
__all__ = ["VGG19Sigmoid"]
class SigmoidSuffix(paddle.nn.Layer):
def __init__(self, origin_layer):
super(SigmoidSuffix, self).__init__()
self.origin_layer = origin_layer
self.sigmoid = Sigmoid()
def forward(self, *input, res_dict=None, **kwargs):
x = self.origin_layer(input)
x = self.sigmoid(x)
return x
def VGG19Sigmoid(pretrained=False, use_ssld=False, **kwargs):
def replace_function(origin_layer):
new_layer = SigmoidSuffix(origin_layer)
return new_layer
match_re = "linear_2"
model = VGG19(pretrained=pretrained, use_ssld=use_ssld, **kwargs)
model.replace_sub(match_re, replace_function, True)
return model