watermark
parent
601496260c
commit
54f7dd7484
|
@ -79,6 +79,7 @@ from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d
|
|||
from .variant_models.vgg_variant import VGG19Sigmoid
|
||||
from .variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh
|
||||
from .variant_models.pp_lcnetv2_variant import PPLCNetV2_base_ShiTu
|
||||
from .variant_models.efficientnet_variant import EfficientNetB3_watermark
|
||||
from .model_zoo.adaface_ir_net import AdaFace_IR_18, AdaFace_IR_34, AdaFace_IR_50, AdaFace_IR_101, AdaFace_IR_152, AdaFace_IR_SE_50, AdaFace_IR_SE_101, AdaFace_IR_SE_152, AdaFace_IR_SE_200
|
||||
from .model_zoo.wideresnet import WideResNet
|
||||
from .model_zoo.uniformer import UniFormer_small, UniFormer_small_plus, UniFormer_small_plus_dim64, UniFormer_base, UniFormer_base_ls
|
||||
|
|
|
@ -26,6 +26,7 @@ import collections
|
|||
import re
|
||||
import copy
|
||||
|
||||
from ..base.theseus_layer import TheseusLayer
|
||||
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
|
||||
|
||||
MODEL_URLS = {
|
||||
|
@ -289,7 +290,7 @@ def _drop_connect(inputs, prob, is_test):
|
|||
return output
|
||||
|
||||
|
||||
class Conv2ds(nn.Layer):
|
||||
class Conv2ds(TheseusLayer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
output_channels,
|
||||
|
@ -361,13 +362,14 @@ class Conv2ds(nn.Layer):
|
|||
return x
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
class ConvBNLayer(TheseusLayer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
filter_size,
|
||||
output_channels,
|
||||
stride=1,
|
||||
num_groups=1,
|
||||
global_params=None,
|
||||
padding_type="SAME",
|
||||
conv_act=None,
|
||||
bn_act="swish",
|
||||
|
@ -396,12 +398,13 @@ class ConvBNLayer(nn.Layer):
|
|||
if use_bn is True:
|
||||
bn_name = name + bn_name
|
||||
param_attr, bias_attr = init_batch_norm_layer(bn_name)
|
||||
epsilon = global_params.batch_norm_epsilon
|
||||
|
||||
self._bn = BatchNorm(
|
||||
num_channels=output_channels,
|
||||
act=bn_act,
|
||||
momentum=0.99,
|
||||
epsilon=0.001,
|
||||
epsilon=epsilon,
|
||||
moving_mean_name=bn_name + "_mean",
|
||||
moving_variance_name=bn_name + "_variance",
|
||||
param_attr=param_attr,
|
||||
|
@ -416,10 +419,11 @@ class ConvBNLayer(nn.Layer):
|
|||
return self._conv(inputs)
|
||||
|
||||
|
||||
class ExpandConvNorm(nn.Layer):
|
||||
class ExpandConvNorm(TheseusLayer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
block_args,
|
||||
global_params,
|
||||
padding_type,
|
||||
name=None,
|
||||
model_name=None,
|
||||
|
@ -434,6 +438,7 @@ class ExpandConvNorm(nn.Layer):
|
|||
input_channels,
|
||||
1,
|
||||
self.oup,
|
||||
global_params=global_params,
|
||||
bn_act=None,
|
||||
padding_type=padding_type,
|
||||
name=name,
|
||||
|
@ -449,10 +454,11 @@ class ExpandConvNorm(nn.Layer):
|
|||
return inputs
|
||||
|
||||
|
||||
class DepthwiseConvNorm(nn.Layer):
|
||||
class DepthwiseConvNorm(TheseusLayer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
block_args,
|
||||
global_params,
|
||||
padding_type,
|
||||
name=None,
|
||||
model_name=None,
|
||||
|
@ -471,6 +477,7 @@ class DepthwiseConvNorm(nn.Layer):
|
|||
oup,
|
||||
self.s,
|
||||
num_groups=input_channels,
|
||||
global_params=global_params,
|
||||
bn_act=None,
|
||||
padding_type=padding_type,
|
||||
name=name,
|
||||
|
@ -483,10 +490,11 @@ class DepthwiseConvNorm(nn.Layer):
|
|||
return self._conv(inputs)
|
||||
|
||||
|
||||
class ProjectConvNorm(nn.Layer):
|
||||
class ProjectConvNorm(TheseusLayer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
block_args,
|
||||
global_params,
|
||||
padding_type,
|
||||
name=None,
|
||||
model_name=None,
|
||||
|
@ -499,6 +507,7 @@ class ProjectConvNorm(nn.Layer):
|
|||
input_channels,
|
||||
1,
|
||||
final_oup,
|
||||
global_params=global_params,
|
||||
bn_act=None,
|
||||
padding_type=padding_type,
|
||||
name=name,
|
||||
|
@ -511,7 +520,7 @@ class ProjectConvNorm(nn.Layer):
|
|||
return self._conv(inputs)
|
||||
|
||||
|
||||
class SEBlock(nn.Layer):
|
||||
class SEBlock(TheseusLayer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
num_squeezed_channels,
|
||||
|
@ -549,10 +558,11 @@ class SEBlock(nn.Layer):
|
|||
return out
|
||||
|
||||
|
||||
class MbConvBlock(nn.Layer):
|
||||
class MbConvBlock(TheseusLayer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
block_args,
|
||||
global_params,
|
||||
padding_type,
|
||||
use_se,
|
||||
name=None,
|
||||
|
@ -573,6 +583,7 @@ class MbConvBlock(nn.Layer):
|
|||
self._ecn = ExpandConvNorm(
|
||||
input_channels,
|
||||
block_args,
|
||||
global_params,
|
||||
padding_type=padding_type,
|
||||
name=name,
|
||||
model_name=model_name,
|
||||
|
@ -581,6 +592,7 @@ class MbConvBlock(nn.Layer):
|
|||
self._dcn = DepthwiseConvNorm(
|
||||
input_channels * block_args.expand_ratio,
|
||||
block_args,
|
||||
global_params,
|
||||
padding_type=padding_type,
|
||||
name=name,
|
||||
model_name=model_name,
|
||||
|
@ -601,6 +613,7 @@ class MbConvBlock(nn.Layer):
|
|||
self._pcn = ProjectConvNorm(
|
||||
input_channels * block_args.expand_ratio,
|
||||
block_args,
|
||||
global_params,
|
||||
padding_type=padding_type,
|
||||
name=name,
|
||||
model_name=model_name,
|
||||
|
@ -627,7 +640,7 @@ class MbConvBlock(nn.Layer):
|
|||
return x
|
||||
|
||||
|
||||
class ConvStemNorm(nn.Layer):
|
||||
class ConvStemNorm(TheseusLayer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
padding_type,
|
||||
|
@ -643,6 +656,7 @@ class ConvStemNorm(nn.Layer):
|
|||
filter_size=3,
|
||||
output_channels=output_channels,
|
||||
stride=2,
|
||||
global_params=_global_params,
|
||||
bn_act=None,
|
||||
padding_type=padding_type,
|
||||
name="",
|
||||
|
@ -655,7 +669,7 @@ class ConvStemNorm(nn.Layer):
|
|||
return self._conv(inputs)
|
||||
|
||||
|
||||
class ExtractFeatures(nn.Layer):
|
||||
class ExtractFeatures(TheseusLayer):
|
||||
def __init__(self,
|
||||
input_channels,
|
||||
_block_args,
|
||||
|
@ -708,6 +722,7 @@ class ExtractFeatures(nn.Layer):
|
|||
MbConvBlock(
|
||||
block_args.input_filters,
|
||||
block_args=block_args,
|
||||
global_params=_global_params,
|
||||
padding_type=padding_type,
|
||||
use_se=use_se,
|
||||
name="_blocks." + str(idx) + ".",
|
||||
|
@ -728,6 +743,7 @@ class ExtractFeatures(nn.Layer):
|
|||
MbConvBlock(
|
||||
block_args.input_filters,
|
||||
block_args,
|
||||
global_params=_global_params,
|
||||
padding_type=padding_type,
|
||||
use_se=use_se,
|
||||
name="_blocks." + str(idx) + ".",
|
||||
|
@ -746,7 +762,7 @@ class ExtractFeatures(nn.Layer):
|
|||
return x
|
||||
|
||||
|
||||
class EfficientNet(nn.Layer):
|
||||
class EfficientNet(TheseusLayer):
|
||||
def __init__(self,
|
||||
name="b0",
|
||||
padding_type="SAME",
|
||||
|
@ -789,6 +805,7 @@ class EfficientNet(nn.Layer):
|
|||
oup,
|
||||
1,
|
||||
output_channels,
|
||||
global_params=self._global_params,
|
||||
bn_act="swish",
|
||||
padding_type=self.padding_type,
|
||||
name="",
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddle.nn import Sigmoid
|
||||
from paddle.nn import Tanh
|
||||
from ..model_zoo.efficientnet import EfficientNetB3, _load_pretrained
|
||||
|
||||
MODEL_URLS = {
|
||||
"EfficientNetB3_watermark":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetB3_watermark_pretrained.pdparams"
|
||||
}
|
||||
|
||||
__all__ = list(MODEL_URLS.keys())
|
||||
|
||||
|
||||
def EfficientNetB3_watermark(padding_type='DYNAMIC',
|
||||
override_params={"batch_norm_epsilon": 0.00001},
|
||||
use_se=True,
|
||||
pretrained=False,
|
||||
use_ssld=False,
|
||||
**kwargs):
|
||||
def replace_function(_fc, pattern):
|
||||
classifier = nn.Sequential(
|
||||
# 1536 is the orginal in_features
|
||||
nn.Linear(
|
||||
in_features=1536, out_features=625),
|
||||
nn.ReLU(), # ReLu to be the activation function
|
||||
nn.Dropout(p=0.3),
|
||||
nn.Linear(
|
||||
in_features=625, out_features=256),
|
||||
nn.ReLU(),
|
||||
nn.Linear(
|
||||
in_features=256, out_features=2), )
|
||||
return classifier
|
||||
|
||||
pattern = "_fc"
|
||||
model = EfficientNetB3(
|
||||
padding_type=padding_type,
|
||||
override_params=override_params,
|
||||
use_se=True,
|
||||
pretrained=False,
|
||||
use_ssld=False,
|
||||
**kwargs)
|
||||
model.upgrade_sublayer(pattern, replace_function)
|
||||
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB3_watermark"],
|
||||
use_ssld)
|
||||
return model
|
Loading…
Reference in New Issue