watermark

pull/2612/head^2
zhangyubo0722 2023-01-16 12:58:44 +00:00 committed by cuicheng01
parent 601496260c
commit 54f7dd7484
3 changed files with 75 additions and 11 deletions

ppcls/arch/backbone/__init__.py

@@ -79,6 +79,7 @@ from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d
 from .variant_models.vgg_variant import VGG19Sigmoid
 from .variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh
 from .variant_models.pp_lcnetv2_variant import PPLCNetV2_base_ShiTu
+from .variant_models.efficientnet_variant import EfficientNetB3_watermark
 from .model_zoo.adaface_ir_net import AdaFace_IR_18, AdaFace_IR_34, AdaFace_IR_50, AdaFace_IR_101, AdaFace_IR_152, AdaFace_IR_SE_50, AdaFace_IR_SE_101, AdaFace_IR_SE_152, AdaFace_IR_SE_200
 from .model_zoo.wideresnet import WideResNet
 from .model_zoo.uniformer import UniFormer_small, UniFormer_small_plus, UniFormer_small_plus_dim64, UniFormer_base, UniFormer_base_ls
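Registering the import here is what exposes the variant by name alongside the other backbones. A minimal sketch of the effect, assuming a PaddleClas checkout is importable (pretrained=False skips the weight download):

# Sketch: once imported in ppcls/arch/backbone/__init__.py, the variant
# can be built like any other registered architecture.
from ppcls.arch.backbone import EfficientNetB3_watermark

model = EfficientNetB3_watermark(pretrained=False)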

ppcls/arch/backbone/model_zoo/efficientnet.py

@@ -26,6 +26,7 @@ import collections
 import re
 import copy
+from ..base.theseus_layer import TheseusLayer
 from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

 MODEL_URLS = {
@@ -289,7 +290,7 @@ def _drop_connect(inputs, prob, is_test):
     return output


-class Conv2ds(nn.Layer):
+class Conv2ds(TheseusLayer):
     def __init__(self,
                  input_channels,
                  output_channels,
@@ -361,13 +362,14 @@ class Conv2ds(nn.Layer):
         return x


-class ConvBNLayer(nn.Layer):
+class ConvBNLayer(TheseusLayer):
     def __init__(self,
                  input_channels,
                  filter_size,
                  output_channels,
                  stride=1,
                  num_groups=1,
+                 global_params=None,
                  padding_type="SAME",
                  conv_act=None,
                  bn_act="swish",
@ -396,12 +398,13 @@ class ConvBNLayer(nn.Layer):
if use_bn is True:
bn_name = name + bn_name
param_attr, bias_attr = init_batch_norm_layer(bn_name)
epsilon = global_params.batch_norm_epsilon
self._bn = BatchNorm(
num_channels=output_channels,
act=bn_act,
momentum=0.99,
epsilon=0.001,
epsilon=epsilon,
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance",
param_attr=param_attr,
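This hunk is the payoff of threading global_params through every layer: the BN epsilon was previously hardcoded to 0.001, so a batch_norm_epsilon passed via override_params was silently ignored. A sketch of the behavior this enables, assuming get_model_params is the existing helper in efficientnet.py that folds override_params into the GlobalParams namedtuple:

# Sketch: the override now actually reaches every BatchNorm via ConvBNLayer.
_, global_params = get_model_params("efficientnet-b3",
                                    {"batch_norm_epsilon": 0.00001})
assert global_params.batch_norm_epsilon == 0.00001  # consumed as epsilon above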
@@ -416,10 +419,11 @@ class ConvBNLayer(nn.Layer):
         return self._conv(inputs)


-class ExpandConvNorm(nn.Layer):
+class ExpandConvNorm(TheseusLayer):
     def __init__(self,
                  input_channels,
                  block_args,
+                 global_params,
                  padding_type,
                  name=None,
                  model_name=None,
@@ -434,6 +438,7 @@ class ExpandConvNorm(nn.Layer):
             input_channels,
             1,
             self.oup,
+            global_params=global_params,
             bn_act=None,
             padding_type=padding_type,
             name=name,
@@ -449,10 +454,11 @@ class ExpandConvNorm(nn.Layer):
             return inputs


-class DepthwiseConvNorm(nn.Layer):
+class DepthwiseConvNorm(TheseusLayer):
     def __init__(self,
                  input_channels,
                  block_args,
+                 global_params,
                  padding_type,
                  name=None,
                  model_name=None,
@@ -471,6 +477,7 @@ class DepthwiseConvNorm(nn.Layer):
             oup,
             self.s,
             num_groups=input_channels,
+            global_params=global_params,
             bn_act=None,
             padding_type=padding_type,
             name=name,
@@ -483,10 +490,11 @@ class DepthwiseConvNorm(nn.Layer):
         return self._conv(inputs)


-class ProjectConvNorm(nn.Layer):
+class ProjectConvNorm(TheseusLayer):
     def __init__(self,
                  input_channels,
                  block_args,
+                 global_params,
                  padding_type,
                  name=None,
                  model_name=None,
@@ -499,6 +507,7 @@ class ProjectConvNorm(nn.Layer):
             input_channels,
             1,
             final_oup,
+            global_params=global_params,
             bn_act=None,
             padding_type=padding_type,
             name=name,
@@ -511,7 +520,7 @@ class ProjectConvNorm(nn.Layer):
         return self._conv(inputs)


-class SEBlock(nn.Layer):
+class SEBlock(TheseusLayer):
     def __init__(self,
                  input_channels,
                  num_squeezed_channels,
@@ -549,10 +558,11 @@ class SEBlock(nn.Layer):
         return out


-class MbConvBlock(nn.Layer):
+class MbConvBlock(TheseusLayer):
     def __init__(self,
                  input_channels,
                  block_args,
+                 global_params,
                  padding_type,
                  use_se,
                  name=None,
@@ -573,6 +583,7 @@ class MbConvBlock(nn.Layer):
             self._ecn = ExpandConvNorm(
                 input_channels,
                 block_args,
+                global_params,
                 padding_type=padding_type,
                 name=name,
                 model_name=model_name,
@@ -581,6 +592,7 @@ class MbConvBlock(nn.Layer):
         self._dcn = DepthwiseConvNorm(
             input_channels * block_args.expand_ratio,
             block_args,
+            global_params,
             padding_type=padding_type,
             name=name,
             model_name=model_name,
@@ -601,6 +613,7 @@ class MbConvBlock(nn.Layer):
         self._pcn = ProjectConvNorm(
             input_channels * block_args.expand_ratio,
             block_args,
+            global_params,
             padding_type=padding_type,
             name=name,
             model_name=model_name,
@@ -627,7 +640,7 @@ class MbConvBlock(nn.Layer):
         return x


-class ConvStemNorm(nn.Layer):
+class ConvStemNorm(TheseusLayer):
     def __init__(self,
                  input_channels,
                  padding_type,
@@ -643,6 +656,7 @@ class ConvStemNorm(nn.Layer):
             filter_size=3,
             output_channels=output_channels,
             stride=2,
+            global_params=_global_params,
             bn_act=None,
             padding_type=padding_type,
             name="",
@@ -655,7 +669,7 @@ class ConvStemNorm(nn.Layer):
         return self._conv(inputs)


-class ExtractFeatures(nn.Layer):
+class ExtractFeatures(TheseusLayer):
     def __init__(self,
                  input_channels,
                  _block_args,
@@ -708,6 +722,7 @@ class ExtractFeatures(nn.Layer):
                 MbConvBlock(
                     block_args.input_filters,
                     block_args=block_args,
+                    global_params=_global_params,
                     padding_type=padding_type,
                     use_se=use_se,
                     name="_blocks." + str(idx) + ".",
@@ -728,6 +743,7 @@ class ExtractFeatures(nn.Layer):
                 MbConvBlock(
                     block_args.input_filters,
                     block_args,
+                    global_params=_global_params,
                     padding_type=padding_type,
                     use_se=use_se,
                     name="_blocks." + str(idx) + ".",
@@ -746,7 +762,7 @@ class ExtractFeatures(nn.Layer):
         return x


-class EfficientNet(nn.Layer):
+class EfficientNet(TheseusLayer):
     def __init__(self,
                  name="b0",
                  padding_type="SAME",
@@ -789,6 +805,7 @@ class EfficientNet(nn.Layer):
             oup,
             1,
             output_channels,
+            global_params=self._global_params,
             bn_act="swish",
             padding_type=self.padding_type,
             name="",

ppcls/arch/backbone/variant_models/efficientnet_variant.py

@@ -0,0 +1,46 @@
+import paddle
+import paddle.nn as nn
+from paddle.nn import Sigmoid
+from paddle.nn import Tanh
+
+from ..model_zoo.efficientnet import EfficientNetB3, _load_pretrained
+
+MODEL_URLS = {
+    "EfficientNetB3_watermark":
+    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetB3_watermark_pretrained.pdparams"
+}
+
+__all__ = list(MODEL_URLS.keys())
+
+
+def EfficientNetB3_watermark(padding_type='DYNAMIC',
+                             override_params={"batch_norm_epsilon": 0.00001},
+                             use_se=True,
+                             pretrained=False,
+                             use_ssld=False,
+                             **kwargs):
+    def replace_function(_fc, pattern):
+        classifier = nn.Sequential(
+            # 1536 is the original in_features of the B3 head
+            nn.Linear(
+                in_features=1536, out_features=625),
+            nn.ReLU(),  # ReLU as the activation function
+            nn.Dropout(p=0.3),
+            nn.Linear(
+                in_features=625, out_features=256),
+            nn.ReLU(),
+            nn.Linear(
+                in_features=256, out_features=2), )
+        return classifier
+
+    pattern = "_fc"
+    model = EfficientNetB3(
+        padding_type=padding_type,
+        override_params=override_params,
+        use_se=use_se,
+        pretrained=False,  # ImageNet weights skipped; watermark weights loaded below
+        use_ssld=False,
+        **kwargs)
+    model.upgrade_sublayer(pattern, replace_function)
+    _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB3_watermark"],
+                     use_ssld)
+    return model
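A quick usage sketch for the finished variant. The output shape follows from the head defined above (EfficientNetB3's 1536-dim pooled feature narrowed to 2 logits); the 300x300 input size is an assumption based on B3's nominal resolution, and pretrained=True would fetch the pdparams URL in MODEL_URLS:

import paddle
from ppcls.arch.backbone.variant_models.efficientnet_variant import EfficientNetB3_watermark

model = EfficientNetB3_watermark(pretrained=False)
model.eval()
x = paddle.rand([1, 3, 300, 300])  # assumed B3 input resolution
with paddle.no_grad():
    logits = model(x)
print(logits.shape)  # [1, 2] -> watermark vs. no-watermark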