Mirror of https://github.com/PaddlePaddle/PaddleClas.git
fix mv2 and mv3

This commit is contained in:
parent 515c9c996b
commit 0e1789d4c9

The diff below ports the MobileNetV2 (mv2) and MobileNetV3 (mv3) model definitions from the legacy paddle.fluid dygraph API to the paddle 2.0 paddle.nn API.
@@ -18,9 +18,10 @@ from __future__ import print_function
 
 import numpy as np
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, Dropout
 
 import math
 
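Note: this hunk swaps the legacy paddle.fluid dygraph imports for the paddle 2.0 paddle.nn API. A minimal sketch of the renamed constructor arguments, assuming the 2.0-beta spelling used in this diff (Conv2d; later stable releases re-capitalize it to Conv2D):

    import paddle
    import paddle.nn as nn

    # fluid:  Conv2D(num_channels=3, num_filters=8, filter_size=3,
    #                param_attr=..., act="relu", use_cudnn=True)
    # 2.0:    activation and the cuDNN toggle are no longer constructor args
    conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
    x = paddle.ones([1, 3, 32, 32], dtype="float32")  # NCHW input
    y = conv(x)  # shape [1, 8, 32, 32]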
@@ -30,7 +31,7 @@ __all__ = [
 ]
 
 
-class ConvBNLayer(fluid.dygraph.Layer):
+class ConvBNLayer(nn.Layer):
     def __init__(self,
                  num_channels,
                  filter_size,
@@ -43,16 +44,14 @@ class ConvBNLayer(fluid.dygraph.Layer):
                  use_cudnn=True):
         super(ConvBNLayer, self).__init__()
 
-        self._conv = Conv2D(
-            num_channels=num_channels,
-            num_filters=num_filters,
-            filter_size=filter_size,
+        self._conv = Conv2d(
+            in_channels=num_channels,
+            out_channels=num_filters,
+            kernel_size=filter_size,
             stride=stride,
             padding=padding,
             groups=num_groups,
-            act=None,
-            use_cudnn=use_cudnn,
-            param_attr=ParamAttr(name=name + "_weights"),
+            weight_attr=ParamAttr(name=name + "_weights"),
             bias_attr=False)
 
         self._batch_norm = BatchNorm(
@@ -66,11 +65,11 @@ class ConvBNLayer(fluid.dygraph.Layer):
         y = self._conv(inputs)
         y = self._batch_norm(y)
         if if_act:
-            y = fluid.layers.relu6(y)
+            y = F.relu6(y)
         return y
 
 
-class InvertedResidualUnit(fluid.dygraph.Layer):
+class InvertedResidualUnit(nn.Layer):
     def __init__(self, num_channels, num_in_filter, num_filters, stride,
                  filter_size, padding, expansion_factor, name):
         super(InvertedResidualUnit, self).__init__()
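With act=None gone from the layer constructors, activations move into forward as explicit functional calls. A short sketch of the functional form (paddle 2.0 API):

    import paddle
    import paddle.nn.functional as F

    x = paddle.full([1, 3], 7.0)  # hypothetical input; values above 6 get clipped
    y = F.relu6(x)                # relu6(x) = min(max(0, x), 6) -> all 6.0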
@@ -108,11 +107,11 @@ class InvertedResidualUnit(fluid.dygraph.Layer):
         y = self._bottleneck_conv(y, if_act=True)
         y = self._linear_conv(y, if_act=False)
         if ifshortcut:
-            y = fluid.layers.elementwise_add(inputs, y)
+            y = paddle.elementwise_add(inputs, y)
         return y
 
 
-class InvresiBlocks(fluid.dygraph.Layer):
+class InvresiBlocks(nn.Layer):
     def __init__(self, in_c, t, c, n, s, name):
         super(InvresiBlocks, self).__init__()
 
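fluid.layers.elementwise_add becomes paddle.elementwise_add here; that name was a transitional alias in early 2.0 builds, and later releases expose the same elementwise sum as paddle.add (or the + operator). A sketch under that later spelling:

    import paddle

    a = paddle.ones([2, 3])
    b = paddle.ones([2, 3])
    c = paddle.add(a, b)  # elementwise sum, equivalent to a + b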
@@ -148,7 +147,7 @@ class InvresiBlocks(fluid.dygraph.Layer):
         return y
 
 
-class MobileNet(fluid.dygraph.Layer):
+class MobileNet(nn.Layer):
     def __init__(self, class_dim=1000, scale=1.0):
         super(MobileNet, self).__init__()
         self.scale = scale
@@ -204,7 +203,7 @@ class MobileNet(fluid.dygraph.Layer):
         self.out = Linear(
             self.out_c,
             class_dim,
-            param_attr=ParamAttr(name="fc10_weights"),
+            weight_attr=ParamAttr(name="fc10_weights"),
             bias_attr=ParamAttr(name="fc10_offset"))
 
     def forward(self, inputs):
@@ -213,7 +212,7 @@ class MobileNet(fluid.dygraph.Layer):
             y = block(y)
         y = self.conv9(y, if_act=True)
         y = self.pool2d_avg(y)
-        y = fluid.layers.reshape(y, shape=[-1, self.out_c])
+        y = paddle.reshape(y, shape=[-1, self.out_c])
         y = self.out(y)
         return y
 
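fluid.layers.reshape maps to paddle.reshape with unchanged shape semantics, so one dimension may still be -1 and inferred:

    import paddle

    y = paddle.ones([4, 1280, 1, 1])          # hypothetical pooled feature map
    y = paddle.reshape(y, shape=[-1, 1280])   # flatten to [4, 1280] for the FC head

The hunks below apply the same migration to the second file in the commit, the MobileNetV3 definition.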
@@ -18,9 +18,12 @@ from __future__ import print_function
 
 import numpy as np
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, Dropout
+# TODO: need to be removed later!
+from paddle.fluid.regularizer import L2Decay
 
 import math
 
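L2Decay is still imported from paddle.fluid.regularizer, hence the TODO; later 2.x releases expose the same class as paddle.regularizer.L2Decay. The zero-coefficient regularizer is how the BN scale and offset are excluded from global weight decay; a sketch assuming the later import path:

    from paddle import ParamAttr
    from paddle.regularizer import L2Decay

    # A zero-coefficient L2Decay on the parameter overrides any global
    # weight decay, leaving the BN scale/offset unregularized.
    scale_attr = ParamAttr(name="conv1_bn_scale", regularizer=L2Decay(0.0))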
@@ -42,7 +45,7 @@ def make_divisible(v, divisor=8, min_value=None):
     return new_v
 
 
-class MobileNetV3(fluid.dygraph.Layer):
+class MobileNetV3(nn.Layer):
     def __init__(self, scale=1.0, model_name="small", class_dim=1000):
         super(MobileNetV3, self).__init__()
 
@@ -133,20 +136,19 @@ class MobileNetV3(fluid.dygraph.Layer):
         self.pool = Pool2D(
             pool_type="avg", global_pooling=True, use_cudnn=False)
 
-        self.last_conv = Conv2D(
-            num_channels=make_divisible(scale * self.cls_ch_squeeze),
-            num_filters=self.cls_ch_expand,
-            filter_size=1,
+        self.last_conv = Conv2d(
+            in_channels=make_divisible(scale * self.cls_ch_squeeze),
+            out_channels=self.cls_ch_expand,
+            kernel_size=1,
             stride=1,
             padding=0,
-            act=None,
-            param_attr=ParamAttr(name="last_1x1_conv_weights"),
+            weight_attr=ParamAttr(name="last_1x1_conv_weights"),
             bias_attr=False)
 
         self.out = Linear(
-            input_dim=self.cls_ch_expand,
-            output_dim=class_dim,
-            param_attr=ParamAttr("fc_weights"),
+            self.cls_ch_expand,
+            class_dim,
+            weight_attr=ParamAttr("fc_weights"),
             bias_attr=ParamAttr(name="fc_offset"))
 
     def forward(self, inputs, label=None, dropout_prob=0.2):
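paddle.nn.Linear drops the fluid keyword names: input_dim/output_dim become the positional in_features/out_features, and param_attr becomes weight_attr. A minimal sketch:

    import paddle
    import paddle.nn as nn
    from paddle import ParamAttr

    fc = nn.Linear(1280, 1000, weight_attr=ParamAttr(name="fc_weights"))
    x = paddle.ones([4, 1280])
    logits = fc(x)  # shape [4, 1000]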
@@ -156,15 +158,15 @@ class MobileNetV3(fluid.dygraph.Layer):
         x = self.last_second_conv(x)
         x = self.pool(x)
         x = self.last_conv(x)
-        x = fluid.layers.hard_swish(x)
-        x = fluid.layers.dropout(x=x, dropout_prob=dropout_prob)
-        x = fluid.layers.reshape(x, shape=[x.shape[0], x.shape[1]])
+        x = F.hard_swish(x)
+        x = F.dropout(x=x, p=dropout_prob)
+        x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
         x = self.out(x)
 
         return x
 
 
-class ConvBNLayer(fluid.dygraph.Layer):
+class ConvBNLayer(nn.Layer):
     def __init__(self,
                  in_c,
                  out_c,
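In the classifier head, fluid.layers.dropout(dropout_prob=...) maps to F.dropout(p=...); note that the default scaling mode also differs between the two APIs (downgrade_in_infer in fluid vs upscale_in_train in 2.0). A short sketch (F.hard_swish was later renamed F.hardswish in stable 2.x):

    import paddle
    import paddle.nn.functional as F

    x = paddle.ones([4, 960, 1, 1])
    x = F.hard_swish(x)              # x * relu6(x + 3) / 6
    x = F.dropout(x, p=0.2)          # p replaces dropout_prob
    x = paddle.reshape(x, [x.shape[0], x.shape[1]])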
@@ -179,28 +181,24 @@ class ConvBNLayer(fluid.dygraph.Layer):
         super(ConvBNLayer, self).__init__()
         self.if_act = if_act
         self.act = act
-        self.conv = fluid.dygraph.Conv2D(
-            num_channels=in_c,
-            num_filters=out_c,
-            filter_size=filter_size,
+        self.conv = Conv2d(
+            in_channels=in_c,
+            out_channels=out_c,
+            kernel_size=filter_size,
             stride=stride,
             padding=padding,
             groups=num_groups,
-            param_attr=ParamAttr(name=name + "_weights"),
-            bias_attr=False,
-            use_cudnn=use_cudnn,
-            act=None)
-        self.bn = fluid.dygraph.BatchNorm(
+            weight_attr=ParamAttr(name=name + "_weights"),
+            bias_attr=False)
+        self.bn = BatchNorm(
             num_channels=out_c,
             act=None,
             param_attr=ParamAttr(
                 name=name + "_bn_scale",
-                regularizer=fluid.regularizer.L2DecayRegularizer(
-                    regularization_coeff=0.0)),
+                regularizer=L2Decay(regularization_coeff=0.0)),
             bias_attr=ParamAttr(
                 name=name + "_bn_offset",
-                regularizer=fluid.regularizer.L2DecayRegularizer(
-                    regularization_coeff=0.0)),
+                regularizer=L2Decay(regularization_coeff=0.0)),
             moving_mean_name=name + "_bn_mean",
             moving_variance_name=name + "_bn_variance")
 
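Note the asymmetry in this hunk: the conv switches to the new Conv2d argument names, while BatchNorm (now imported from paddle.nn) intentionally keeps the fluid-style signature, so only its import changes. A sketch of the compatibility-style BatchNorm as used here (names are illustrative):

    import paddle.nn as nn
    from paddle import ParamAttr

    bn = nn.BatchNorm(
        num_channels=16,              # fluid-style keywords survive
        act=None,
        param_attr=ParamAttr(name="conv1_bn_scale"),
        bias_attr=ParamAttr(name="conv1_bn_offset"),
        moving_mean_name="conv1_bn_mean",
        moving_variance_name="conv1_bn_variance")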
@@ -209,16 +207,16 @@ class ConvBNLayer(fluid.dygraph.Layer):
         x = self.bn(x)
         if self.if_act:
             if self.act == "relu":
-                x = fluid.layers.relu(x)
+                x = F.relu(x)
             elif self.act == "hard_swish":
-                x = fluid.layers.hard_swish(x)
+                x = F.hard_swish(x)
             else:
                 print("The activation function is selected incorrectly.")
                 exit()
         return x
 
 
-class ResidualUnit(fluid.dygraph.Layer):
+class ResidualUnit(nn.Layer):
     def __init__(self,
                  in_c,
                  mid_c,
@@ -270,40 +268,38 @@ class ResidualUnit(fluid.dygraph.Layer):
             x = self.mid_se(x)
         x = self.linear_conv(x)
         if self.if_shortcut:
-            x = fluid.layers.elementwise_add(inputs, x)
+            x = paddle.elementwise_add(inputs, x)
         return x
 
 
-class SEModule(fluid.dygraph.Layer):
+class SEModule(nn.Layer):
     def __init__(self, channel, reduction=4, name=""):
         super(SEModule, self).__init__()
-        self.avg_pool = fluid.dygraph.Pool2D(
-            pool_type="avg", global_pooling=True, use_cudnn=False)
-        self.conv1 = fluid.dygraph.Conv2D(
-            num_channels=channel,
-            num_filters=channel // reduction,
-            filter_size=1,
+        self.avg_pool = Pool2D(pool_type="avg", global_pooling=True)
+        self.conv1 = Conv2d(
+            in_channels=channel,
+            out_channels=channel // reduction,
+            kernel_size=1,
             stride=1,
             padding=0,
-            act="relu",
-            param_attr=ParamAttr(name=name + "_1_weights"),
+            weight_attr=ParamAttr(name=name + "_1_weights"),
             bias_attr=ParamAttr(name=name + "_1_offset"))
-        self.conv2 = fluid.dygraph.Conv2D(
-            num_channels=channel // reduction,
-            num_filters=channel,
-            filter_size=1,
+        self.conv2 = Conv2d(
+            in_channels=channel // reduction,
+            out_channels=channel,
+            kernel_size=1,
             stride=1,
             padding=0,
-            act=None,
-            param_attr=ParamAttr(name + "_2_weights"),
+            weight_attr=ParamAttr(name + "_2_weights"),
             bias_attr=ParamAttr(name=name + "_2_offset"))
 
     def forward(self, inputs):
         outputs = self.avg_pool(inputs)
         outputs = self.conv1(outputs)
+        outputs = F.relu(outputs)
         outputs = self.conv2(outputs)
-        outputs = fluid.layers.hard_sigmoid(outputs)
-        return fluid.layers.elementwise_mul(x=inputs, y=outputs, axis=0)
+        outputs = F.hard_sigmoid(outputs)
+        return paddle.multiply(x=inputs, y=outputs, axis=0)
 
 
 def MobileNetV3_small_x0_35(**args):
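In the SE block, the per-channel attention produced by hard_sigmoid is multiplied back onto the input. paddle.multiply(x, y, axis=0) mirrors the old elementwise_mul(..., axis=0) broadcasting; the axis argument was dropped in later releases, where plain broadcasting suffices because the pooled attention keeps its [N, C, 1, 1] shape. A hedged sketch of the scaling step using that broadcasting form (F.hard_sigmoid is spelled F.hardsigmoid in later releases):

    import paddle
    import paddle.nn.functional as F

    feat = paddle.ones([2, 16, 8, 8])   # input feature map, NCHW
    attn = paddle.ones([2, 16, 1, 1])   # per-channel attention logits
    attn = F.hard_sigmoid(attn)         # squash to (0, 1)
    out = feat * attn                   # channel-wise scaling via broadcast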