1
0
mirror of https://github.com/PaddlePaddle/PaddleClas.git synced 2025-06-03 21:55:06 +08:00

29 lines
778 B
Python
Raw Normal View History

2021-08-23 19:07:35 +08:00
import paddle
from paddle.nn import Sigmoid
2022-08-17 14:34:06 +00:00
from ..legendary_models.vgg import VGG19
2021-12-27 14:03:55 +00:00
2021-08-23 19:07:35 +08:00
# Public API of this module: only the factory function is listed, so
# `from ... import *` does not expose the SigmoidSuffix wrapper class.
__all__ = [
    "VGG19Sigmoid",
]
2021-12-27 14:03:55 +00:00
2021-08-23 19:07:35 +08:00
class SigmoidSuffix(paddle.nn.Layer):
    """Layer wrapper that applies ``Sigmoid`` to another layer's output.

    Args:
        origin_layer: The layer being wrapped. The wrapper's ``forward`` calls
            it with the incoming ``input`` and feeds its result to a sigmoid.
    """

    def __init__(self, origin_layer):
        super().__init__()
        # Attribute names and assignment order are kept exactly as-is:
        # presumably they determine the registered sublayer / parameter names
        # (e.g. when loading pretrained weights) -- confirm before renaming.
        self.origin_layer = origin_layer
        self.sigmoid = Sigmoid()

    def forward(self, input, res_dict=None, **kwargs):
        """Return ``sigmoid(origin_layer(input))``.

        ``res_dict`` and any extra keyword arguments are accepted but not
        used. Presumably they exist so this wrapper can be invoked with the
        same call signature as other layers in the model; confirm against
        the caller before removing them.
        """
        return self.sigmoid(self.origin_layer(input))
2021-12-27 14:03:55 +00:00
2021-08-23 19:07:35 +08:00
def VGG19Sigmoid(pretrained=False, use_ssld=False, **kwargs):
    """Build a VGG19 model whose ``fc2`` sublayer is wrapped with a sigmoid.

    Args:
        pretrained: Forwarded unchanged to ``VGG19``.
        use_ssld: Forwarded unchanged to ``VGG19``.
        **kwargs: Any further keyword arguments, forwarded unchanged to
            ``VGG19``.

    Returns:
        The ``VGG19`` instance after ``upgrade_sublayer("fc2", ...)`` has
        replaced the matched sublayer with a ``SigmoidSuffix`` wrapping it.
    """
    backbone = VGG19(pretrained=pretrained, use_ssld=use_ssld, **kwargs)

    def _wrap_with_sigmoid(origin_layer, pattern):
        # upgrade_sublayer presumably also hands over the matched pattern;
        # it is not needed here. Parameter names are kept as in the
        # original in case the callback is invoked with keyword arguments.
        return SigmoidSuffix(origin_layer)

    backbone.upgrade_sublayer("fc2", _wrap_with_sigmoid)
    return backbone