Support AMP for SE_ResNeXt101_32x4d (#681)
parent 74dd40c6a2
commit 86346fda54
@@ -0,0 +1,89 @@
+mode: 'train'
+ARCHITECTURE:
+    name: 'SE_ResNeXt101_32x4d'
+
+pretrained_model: ""
+model_save_dir: "./output/"
+classes_num: 1000
+total_images: 1281167
+save_interval: 1
+validate: True
+valid_interval: 1
+epochs: 200
+topk: 5
+is_distributed: False
+
+use_dali: False
+use_gpu: True
+data_format: "NCHW"
+image_channel: &image_channel 4
+image_shape: [*image_channel, 224, 224]
+
+
+use_mix: False
+ls_epsilon: -1
+
+AMP:
+    scale_loss: 128.0
+    use_dynamic_loss_scaling: True
+    use_pure_fp16: &use_pure_fp16 True
+
+LEARNING_RATE:
+    function: 'Cosine'
+    params:
+        lr: 0.1
+
+OPTIMIZER:
+    function: 'Momentum'
+    params:
+        momentum: 0.9
+        multi_precision: *use_pure_fp16
+    regularizer:
+        function: 'L2'
+        factor: 0.000015
+
+TRAIN:
+    batch_size: 96
+    num_workers: 0
+    file_list: "/home/datasets/ILSVRC2012/train_list.txt"
+    data_dir: "/home/datasets/ILSVRC2012/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1./255.
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+            output_fp16: *use_pure_fp16
+            channel_num: *image_channel
+        - ToCHWImage:
+
+VALID:
+    batch_size: 16
+    num_workers: 0
+    file_list: "/home/datasets/ILSVRC2012/val_list.txt"
+    data_dir: "/home/datasets/ILSVRC2012/"
+    shuffle_seed: 0
+    transforms:
+        - DecodeImage:
+            to_rgb: True
+            to_np: False
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+        - ToCHWImage:
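For reference, the AMP keys above correspond to the arguments of Paddle's static-graph AMP decorator. The sketch below is a minimal, illustrative wiring; the variable names and the constant learning rate stand in for values the trainer reads from this config, and it is not the trainer itself:

```python
import paddle

paddle.enable_static()

# Stand-ins for the values read from the AMP / OPTIMIZER sections above.
scale_loss = 128.0
use_dynamic_loss_scaling = True
use_pure_fp16 = True

optimizer = paddle.optimizer.Momentum(
    learning_rate=0.1,  # the config actually uses Cosine decay
    momentum=0.9,
    weight_decay=paddle.regularizer.L2Decay(0.000015),
    multi_precision=use_pure_fp16)  # keep fp32 master weights under pure fp16

# Decorate the optimizer so eligible ops run in fp16 with dynamic loss scaling.
optimizer = paddle.static.amp.decorate(
    optimizer,
    init_loss_scaling=scale_loss,
    use_dynamic_loss_scaling=use_dynamic_loss_scaling,
    use_pure_fp16=use_pure_fp16,
    use_fp16_guard=True)  # honor fp16_guard() regions in the model's forward
```

Under pure fp16, the decorated optimizer also expects a one-time `optimizer.amp_init(place)` after the startup program has run, which casts the fp32 parameters to their fp16 copies.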
@@ -38,7 +38,8 @@ class ConvBNLayer(nn.Layer):
                  stride=1,
                  groups=1,
                  act=None,
-                 name=None):
+                 name=None,
+                 data_format='NCHW'):
         super(ConvBNLayer, self).__init__()

         self._conv = Conv2D(
@@ -49,7 +50,8 @@ class ConvBNLayer(nn.Layer):
             padding=(filter_size - 1) // 2,
             groups=groups,
             weight_attr=ParamAttr(name=name + "_weights"),
-            bias_attr=False)
+            bias_attr=False,
+            data_format=data_format)
         bn_name = name + '_bn'
         self._batch_norm = BatchNorm(
             num_filters,
@@ -57,7 +59,8 @@ class ConvBNLayer(nn.Layer):
             param_attr=ParamAttr(name=bn_name + '_scale'),
             bias_attr=ParamAttr(bn_name + '_offset'),
             moving_mean_name=bn_name + '_mean',
-            moving_variance_name=bn_name + '_variance')
+            moving_variance_name=bn_name + '_variance',
+            data_layout=data_format)

     def forward(self, inputs):
         y = self._conv(inputs)
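To see what the three hunks above thread through, here is a self-contained dygraph sketch (not the PaddleClas class) of a conv + BN pair sharing one layout flag; note that `Conv2D` names it `data_format` while the older-style `BatchNorm` names it `data_layout`. It assumes a GPU build, since the NHWC conv kernels come from cuDNN:

```python
import paddle
import paddle.nn as nn
from paddle.nn import Conv2D, BatchNorm

paddle.set_device("gpu")  # NHWC conv kernels are cuDNN paths; assumes a GPU build

class ConvBN(nn.Layer):
    """Minimal conv + BN pair; both layers must agree on the layout."""

    def __init__(self, in_ch, out_ch, ksize, data_format="NCHW"):
        super(ConvBN, self).__init__()
        self._conv = Conv2D(
            in_ch, out_ch, ksize,
            padding=(ksize - 1) // 2,
            bias_attr=False,
            data_format=data_format)   # Conv2D: data_format
        self._bn = BatchNorm(
            out_ch, act='relu',
            data_layout=data_format)   # BatchNorm: data_layout

    def forward(self, x):
        return self._bn(self._conv(x))

x = paddle.randn([2, 224, 224, 4])           # NHWC, 4-channel padded input
y = ConvBN(4, 64, 7, data_format="NHWC")(x)  # -> [2, 224, 224, 64]
```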
@@ -74,7 +77,8 @@ class BottleneckBlock(nn.Layer):
                  reduction_ratio,
                  shortcut=True,
                  if_first=False,
-                 name=None):
+                 name=None,
+                 data_format="NCHW"):
         super(BottleneckBlock, self).__init__()

         self.conv0 = ConvBNLayer(
@@ -82,7 +86,8 @@ class BottleneckBlock(nn.Layer):
             num_filters=num_filters,
             filter_size=1,
             act='relu',
-            name='conv' + name + '_x1')
+            name='conv' + name + '_x1',
+            data_format=data_format)
         self.conv1 = ConvBNLayer(
             num_channels=num_filters,
             num_filters=num_filters,
@@ -90,18 +95,21 @@ class BottleneckBlock(nn.Layer):
             groups=cardinality,
             stride=stride,
             act='relu',
-            name='conv' + name + '_x2')
+            name='conv' + name + '_x2',
+            data_format=data_format)
         self.conv2 = ConvBNLayer(
             num_channels=num_filters,
             num_filters=num_filters * 2 if cardinality == 32 else num_filters,
             filter_size=1,
             act=None,
-            name='conv' + name + '_x3')
+            name='conv' + name + '_x3',
+            data_format=data_format)
         self.scale = SELayer(
             num_channels=num_filters * 2 if cardinality == 32 else num_filters,
             num_filters=num_filters * 2 if cardinality == 32 else num_filters,
             reduction_ratio=reduction_ratio,
-            name='fc' + name)
+            name='fc' + name,
+            data_format=data_format)

         if not shortcut:
             self.short = ConvBNLayer(
@@ -110,7 +118,8 @@ class BottleneckBlock(nn.Layer):
                 if cardinality == 32 else num_filters,
                 filter_size=1,
                 stride=stride,
-                name='conv' + name + '_prj')
+                name='conv' + name + '_prj',
+                data_format=data_format)

         self.shortcut = shortcut

@@ -130,10 +139,11 @@ class BottleneckBlock(nn.Layer):


 class SELayer(nn.Layer):
-    def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
+    def __init__(self, num_channels, num_filters, reduction_ratio, name=None, data_format="NCHW"):
         super(SELayer, self).__init__()

-        self.pool2d_gap = AdaptiveAvgPool2D(1)
+        self.data_format = data_format
+        self.pool2d_gap = AdaptiveAvgPool2D(1, data_format=self.data_format)

         self._num_channels = num_channels

@@ -157,23 +167,32 @@ class SELayer(nn.Layer):

     def forward(self, input):
         pool = self.pool2d_gap(input)
-        pool = paddle.squeeze(pool, axis=[2, 3])
+        if self.data_format == "NHWC":
+            pool = paddle.squeeze(pool, axis=[1, 2])
+        else:
+            pool = paddle.squeeze(pool, axis=[2, 3])
         squeeze = self.squeeze(pool)
         squeeze = self.relu(squeeze)
         excitation = self.excitation(squeeze)
         excitation = self.sigmoid(excitation)
-        excitation = paddle.unsqueeze(excitation, axis=[2, 3])
+        if self.data_format == "NHWC":
+            excitation = paddle.unsqueeze(excitation, axis=[1, 2])
+        else:
+            excitation = paddle.unsqueeze(excitation, axis=[2, 3])
         out = input * excitation
         return out


 class ResNeXt(nn.Layer):
-    def __init__(self, layers=50, class_dim=1000, cardinality=32):
+    def __init__(self, layers=50, class_dim=1000, cardinality=32, input_image_channel=3, data_format="NCHW"):
         super(ResNeXt, self).__init__()

         self.layers = layers
         self.cardinality = cardinality
         self.reduction_ratio = 16
+        self.data_format = data_format
+        self.input_image_channel = input_image_channel

         supported_layers = [50, 101, 152]
         assert layers in supported_layers, \
             "supported layers are {} but input layer is {}".format(
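The layout branch added to `SELayer.forward` exists because global average pooling leaves its singleton spatial axes in different positions per layout, so the squeeze (and the mirroring unsqueeze) must pick different axes. A quick shape check:

```python
import paddle
from paddle.nn import AdaptiveAvgPool2D

x_nchw = paddle.randn([2, 64, 7, 7])   # N, C, H, W
x_nhwc = paddle.randn([2, 7, 7, 64])   # N, H, W, C

# AdaptiveAvgPool2D(1) collapses the spatial dims in place:
p_nchw = AdaptiveAvgPool2D(1)(x_nchw)                      # [2, 64, 1, 1]
p_nhwc = AdaptiveAvgPool2D(1, data_format="NHWC")(x_nhwc)  # [2, 1, 1, 64]

# Hence the per-layout squeeze axes in the diff above:
v_nchw = paddle.squeeze(p_nchw, axis=[2, 3])  # [2, 64]
v_nhwc = paddle.squeeze(p_nhwc, axis=[1, 2])  # [2, 64]
```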
@@ -193,36 +212,40 @@ class ResNeXt(nn.Layer):
                 1024] if cardinality == 32 else [256, 512, 1024, 2048]
         if layers < 152:
             self.conv = ConvBNLayer(
-                num_channels=3,
+                num_channels=self.input_image_channel,
                 num_filters=64,
                 filter_size=7,
                 stride=2,
                 act='relu',
-                name="conv1")
+                name="conv1",
+                data_format=self.data_format)
         else:
             self.conv1_1 = ConvBNLayer(
-                num_channels=3,
+                num_channels=self.input_image_channel,
                 num_filters=64,
                 filter_size=3,
                 stride=2,
                 act='relu',
-                name="conv1")
+                name="conv1",
+                data_format=self.data_format)
             self.conv1_2 = ConvBNLayer(
                 num_channels=64,
                 num_filters=64,
                 filter_size=3,
                 stride=1,
                 act='relu',
-                name="conv2")
+                name="conv2",
+                data_format=self.data_format)
             self.conv1_3 = ConvBNLayer(
                 num_channels=64,
                 num_filters=128,
                 filter_size=3,
                 stride=1,
                 act='relu',
-                name="conv3")
+                name="conv3",
+                data_format=self.data_format)

-        self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)
+        self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1, data_format=self.data_format)

         self.block_list = []
         n = 1 if layers == 50 or layers == 101 else 3
@@ -241,11 +264,12 @@ class ResNeXt(nn.Layer):
                         reduction_ratio=self.reduction_ratio,
                         shortcut=shortcut,
                         if_first=block == 0,
-                        name=str(n) + '_' + str(i + 1)))
+                        name=str(n) + '_' + str(i + 1),
+                        data_format=self.data_format))
                 self.block_list.append(bottleneck_block)
                 shortcut = True

-        self.pool2d_avg = AdaptiveAvgPool2D(1)
+        self.pool2d_avg = AdaptiveAvgPool2D(1, data_format=self.data_format)

         self.pool2d_avg_channels = num_channels[-1] * 2

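With the stem convs, every `BottleneckBlock`, and both pooling layers now taking the flag, a single constructor argument flips the whole trunk's layout. A hypothetical construction check (the import path is assumed; the class and argument names are the ones this diff defines):

```python
import paddle

from se_resnext import ResNeXt  # hypothetical import path for this file

# Build the trunk in NHWC with the 4-channel padded input from the config.
net = ResNeXt(layers=101, class_dim=1000, cardinality=32,
              input_image_channel=4, data_format="NHWC")

# Parameter shapes are layout-independent in Paddle (conv weights stay
# [out, in, kH, kW]); only the activation layout changes, so no per-layer
# transposes appear at runtime.
n_params = sum(int(p.numel()) for p in net.parameters())
print(n_params)
```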
@@ -259,20 +283,23 @@ class ResNeXt(nn.Layer):
             bias_attr=ParamAttr(name="fc6_offset"))

     def forward(self, inputs):
-        if self.layers < 152:
-            y = self.conv(inputs)
-        else:
-            y = self.conv1_1(inputs)
-            y = self.conv1_2(y)
-            y = self.conv1_3(y)
-        y = self.pool2d_max(y)
-
-        for block in self.block_list:
-            y = block(y)
-        y = self.pool2d_avg(y)
-        y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
-        y = self.out(y)
-        return y
+        with paddle.static.amp.fp16_guard():
+            if self.data_format == "NHWC":
+                inputs = paddle.tensor.transpose(inputs, [0, 2, 3, 1])
+                inputs.stop_gradient = True
+            if self.layers < 152:
+                y = self.conv(inputs)
+            else:
+                y = self.conv1_1(inputs)
+                y = self.conv1_2(y)
+                y = self.conv1_3(y)
+            y = self.pool2d_max(y)
+            for i, block in enumerate(self.block_list):
+                y = block(y)
+            y = self.pool2d_avg(y)
+            y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
+            y = self.out(y)
+            return y


 def SE_ResNeXt50_32x4d(**args):
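The `fp16_guard()` wrapper pairs with `use_fp16_guard=True` at decoration time: in the static graph, only ops created inside the guard are rewritten to fp16, while ops built outside it (typically the loss the trainer adds) stay fp32. A minimal sketch of the mechanism, independent of this model:

```python
import paddle

paddle.enable_static()

x = paddle.static.data(name='x', shape=[None, 4, 224, 224], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')

# Ops built inside the guard become fp16 once the optimizer is decorated
# with use_fp16_guard=True (see the AMP sketch after the config above).
with paddle.static.amp.fp16_guard():
    logits = paddle.static.nn.fc(x=x, size=1000)

# Built outside the guard, the loss keeps running in fp32.
loss = paddle.mean(
    paddle.nn.functional.cross_entropy(input=logits, label=label))
```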