add static training (#1037)

* add static training
* fix typo
* add se fp16
* rm note
* fix loader
* fix cfg

parent 73004f78f5, commit 9d9cd3719e
ResNet backbone (TheseusLayer-based ConvBNLayer / BottleneckBlock / BasicBlock / ResNet): plumb a data_format argument through every layer and wrap forward() in an fp16 guard.

@@ -104,7 +104,8 @@ class ConvBNLayer(TheseusLayer):
                  groups=1,
                  is_vd_mode=False,
                  act=None,
-                 lr_mult=1.0):
+                 lr_mult=1.0,
+                 data_format="NCHW"):
         super().__init__()
         self.is_vd_mode = is_vd_mode
         self.act = act
@@ -118,11 +119,13 @@ class ConvBNLayer(TheseusLayer):
             padding=(filter_size - 1) // 2,
             groups=groups,
             weight_attr=ParamAttr(learning_rate=lr_mult),
-            bias_attr=False)
+            bias_attr=False,
+            data_format=data_format)
         self.bn = BatchNorm(
             num_filters,
             param_attr=ParamAttr(learning_rate=lr_mult),
-            bias_attr=ParamAttr(learning_rate=lr_mult))
+            bias_attr=ParamAttr(learning_rate=lr_mult),
+            data_layout=data_format)
         self.relu = nn.ReLU()

     def forward(self, x):
@@ -136,14 +139,14 @@ class ConvBNLayer(TheseusLayer):


 class BottleneckBlock(TheseusLayer):
-    def __init__(
-            self,
-            num_channels,
-            num_filters,
-            stride,
-            shortcut=True,
-            if_first=False,
-            lr_mult=1.0, ):
+    def __init__(self,
+                 num_channels,
+                 num_filters,
+                 stride,
+                 shortcut=True,
+                 if_first=False,
+                 lr_mult=1.0,
+                 data_format="NCHW"):
         super().__init__()

         self.conv0 = ConvBNLayer(
@@ -151,20 +154,23 @@ class BottleneckBlock(TheseusLayer):
             num_filters=num_filters,
             filter_size=1,
             act="relu",
-            lr_mult=lr_mult)
+            lr_mult=lr_mult,
+            data_format=data_format)
         self.conv1 = ConvBNLayer(
             num_channels=num_filters,
             num_filters=num_filters,
             filter_size=3,
             stride=stride,
             act="relu",
-            lr_mult=lr_mult)
+            lr_mult=lr_mult,
+            data_format=data_format)
         self.conv2 = ConvBNLayer(
             num_channels=num_filters,
             num_filters=num_filters * 4,
             filter_size=1,
             act=None,
-            lr_mult=lr_mult)
+            lr_mult=lr_mult,
+            data_format=data_format)

         if not shortcut:
             self.short = ConvBNLayer(
@@ -173,7 +179,8 @@ class BottleneckBlock(TheseusLayer):
                 filter_size=1,
                 stride=stride if if_first else 1,
                 is_vd_mode=False if if_first else True,
-                lr_mult=lr_mult)
+                lr_mult=lr_mult,
+                data_format=data_format)
         self.relu = nn.ReLU()
         self.shortcut = shortcut

@@ -199,7 +206,8 @@ class BasicBlock(TheseusLayer):
                  stride,
                  shortcut=True,
                  if_first=False,
-                 lr_mult=1.0):
+                 lr_mult=1.0,
+                 data_format="NCHW"):
         super().__init__()

         self.stride = stride
@@ -209,13 +217,15 @@ class BasicBlock(TheseusLayer):
             filter_size=3,
             stride=stride,
             act="relu",
-            lr_mult=lr_mult)
+            lr_mult=lr_mult,
+            data_format=data_format)
         self.conv1 = ConvBNLayer(
             num_channels=num_filters,
             num_filters=num_filters,
             filter_size=3,
             act=None,
-            lr_mult=lr_mult)
+            lr_mult=lr_mult,
+            data_format=data_format)
         if not shortcut:
             self.short = ConvBNLayer(
                 num_channels=num_channels,
@@ -223,7 +233,8 @@ class BasicBlock(TheseusLayer):
                 filter_size=1,
                 stride=stride if if_first else 1,
                 is_vd_mode=False if if_first else True,
-                lr_mult=lr_mult)
+                lr_mult=lr_mult,
+                data_format=data_format)
         self.shortcut = shortcut
         self.relu = nn.ReLU()

@@ -256,7 +267,9 @@ class ResNet(TheseusLayer):
                  config,
                  version="vb",
                  class_num=1000,
-                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
+                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
+                 data_format="NCHW",
+                 input_image_channel=3):
         super().__init__()

         self.cfg = config
@@ -279,22 +292,25 @@ class ResNet(TheseusLayer):

         self.stem_cfg = {
             #num_channels, num_filters, filter_size, stride
-            "vb": [[3, 64, 7, 2]],
-            "vd": [[3, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
+            "vb": [[input_image_channel, 64, 7, 2]],
+            "vd":
+            [[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
         }

-        self.stem = nn.Sequential(*[
+        self.stem = nn.Sequential(* [
             ConvBNLayer(
                 num_channels=in_c,
                 num_filters=out_c,
                 filter_size=k,
                 stride=s,
                 act="relu",
-                lr_mult=self.lr_mult_list[0])
+                lr_mult=self.lr_mult_list[0],
+                data_format=data_format)
             for in_c, out_c, k, s in self.stem_cfg[version]
         ])

-        self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
+        self.max_pool = MaxPool2D(
+            kernel_size=3, stride=2, padding=1, data_format=data_format)
         block_list = []
         for block_idx in range(len(self.block_depth)):
             shortcut = False
@@ -306,11 +322,12 @@ class ResNet(TheseusLayer):
                     stride=2 if i == 0 and block_idx != 0 else 1,
                     shortcut=shortcut,
                     if_first=block_idx == i == 0 if version == "vd" else True,
-                    lr_mult=self.lr_mult_list[block_idx + 1]))
+                    lr_mult=self.lr_mult_list[block_idx + 1],
+                    data_format=data_format))
                 shortcut = True
         self.blocks = nn.Sequential(*block_list)

-        self.avg_pool = AdaptiveAvgPool2D(1)
+        self.avg_pool = AdaptiveAvgPool2D(1, data_format=data_format)
         self.flatten = nn.Flatten()
         self.avg_pool_channels = self.num_channels[-1] * 2
         stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0)
@@ -319,13 +336,19 @@ class ResNet(TheseusLayer):
             self.class_num,
             weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))

+        self.data_format = data_format
+
     def forward(self, x):
-        x = self.stem(x)
-        x = self.max_pool(x)
-        x = self.blocks(x)
-        x = self.avg_pool(x)
-        x = self.flatten(x)
-        x = self.fc(x)
+        with paddle.static.amp.fp16_guard():
+            if self.data_format == "NHWC":
+                x = paddle.transpose(x, [0, 2, 3, 1])
+                x.stop_gradient = True
+            x = self.stem(x)
+            x = self.max_pool(x)
+            x = self.blocks(x)
+            x = self.avg_pool(x)
+            x = self.flatten(x)
+            x = self.fc(x)
         return x
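A small dynamic-graph sketch (assuming a local Paddle install; not part of the patch) of what the new forward() does when data_format == "NHWC": the feed keeps its NCHW shape, and the first op inside the fp16 guard moves channels to the last axis so the layers below run in NHWC layout.

import paddle

x = paddle.randn([8, 4, 224, 224])          # NCHW feed; the 4th channel is zero-padded by NormalizeImage
x_nhwc = paddle.transpose(x, [0, 2, 3, 1])  # -> [8, 224, 224, 4]
x_nhwc.stop_gradient = True                 # the input transpose needs no gradient
print(x_nhwc.shape)                         # [8, 224, 224, 4]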
New config file (the run script below points at ./ppcls/configs/ImageNet/ResNet/ResNet50_fp16.yaml):

@@ -0,0 +1,145 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 120
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_channel: &image_channel 4
  image_shape: [*image_channel, 224, 224]
  save_inference_dir: ./inference
  # training model under @to_static
  to_static: False

# mixed precision training
AMP:
  scale_loss: 128.0
  use_dynamic_loss_scaling: True
  use_pure_fp16: &use_pure_fp16 True

# model architecture
Arch:
  name: ResNet50
  class_num: 1000

# loss function config for traing/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  multi_precision: False # *use_pure_fp16
  lr:
    name: Piecewise
    learning_rate: 0.1
    decay_epochs: [30, 60, 90]
    values: [0.1, 0.01, 0.001, 0.0001]
  regularizer:
    name: 'L2'
    coeff: 0.0001

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
            output_fp16: *use_pure_fp16
            channel_num: *image_channel

    sampler:
      name: DistributedBatchSampler
      batch_size: 32
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
            output_fp16: *use_pure_fp16
            channel_num: *image_channel
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/whl/demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
        output_fp16: *use_pure_fp16
        channel_num: *image_channel
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
  Eval:
    - TopkAcc:
        topk: [1, 5]
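The &image_channel / &use_pure_fp16 entries are plain YAML anchors, so the channel count and the FP16 switch are declared once in Global/AMP and reused by every NormalizeImage op. A small PyYAML sketch (toy snippet, assuming the yaml package is available) of how the aliases resolve:

import yaml

snippet = """
Global:
  image_channel: &image_channel 4
  image_shape: [*image_channel, 224, 224]
AMP:
  use_pure_fp16: &use_pure_fp16 True
NormalizeImage:
  output_fp16: *use_pure_fp16
  channel_num: *image_channel
"""
cfg = yaml.safe_load(snippet)
print(cfg["Global"]["image_shape"])          # [4, 224, 224]
print(cfg["NormalizeImage"]["channel_num"])  # 4
print(cfg["NormalizeImage"]["output_fp16"])  # True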
New config file for SE_ResNeXt101_32x4d FP16 training (same layout as the ResNet50 config above):

@@ -0,0 +1,139 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 200
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_channel: &image_channel 4
  image_shape: [*image_channel, 224, 224]
  save_inference_dir: ./inference

# model architecture
Arch:
  name: SE_ResNeXt101_32x4d
  class_num: 1000

# loss function config for traing/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
  Eval:
    - CELoss:
        weight: 1.0

# mixed precision training
AMP:
  scale_loss: 128.0
  use_dynamic_loss_scaling: True
  use_pure_fp16: &use_pure_fp16 True

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.1
  regularizer:
    name: 'L2'
    coeff: 0.00007

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
            output_fp16: *use_pure_fp16
            channel_num: *image_channel
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
            output_fp16: *use_pure_fp16
            channel_num: *image_channel
    sampler:
      name: BatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/whl/demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
        output_fp16: *use_pure_fp16
        channel_num: *image_channel
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
  Eval:
    - TopkAcc:
        topk: [1, 5]
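The epsilon: 0.1 entry under the train-time CELoss enables label smoothing (the create_fetchs docstring below describes epsilon the same way). A minimal NumPy sketch of the usual smoothing rule, using a hypothetical smooth_labels helper rather than the real loss code:

import numpy as np

def smooth_labels(label_ids, class_num, epsilon=0.1):
    """Turn integer labels into smoothed one-hot targets."""
    one_hot = np.eye(class_num)[label_ids]
    return one_hot * (1.0 - epsilon) + epsilon / class_num

targets = smooth_labels(np.array([2, 0]), class_num=4, epsilon=0.1)
print(targets)  # rows sum to 1; the true class gets 0.925, the others 0.025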
ppcls.data build_dataloader(): drop the stale batch-size rescaling comments.

@@ -60,6 +60,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
     if use_dali:
         from ppcls.data.dataloader.dali import dali_dataloader
         return dali_dataloader(config, mode, paddle.device.get_device(), seed)
+
     config_dataset = config[mode]['dataset']
     config_dataset = copy.deepcopy(config_dataset)
     dataset_name = config_dataset.pop('name')
@@ -74,10 +75,6 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):

     # build sampler
     config_sampler = config[mode]['sampler']
-    #config_sampler["batch_size"] = config_sampler[
-    #    "batch_size"] // paddle.distributed.get_world_size()
-    #assert config_sampler[
-    #    "batch_size"] >= 1, "The batch_size should be larger than gpu number."
     if "name" not in config_sampler:
         batch_sampler = None
         batch_size = config_sampler["batch_size"]
ppcls.data.dataloader.dali dali_dataloader(): read channel_num / output_fp16 from the NormalizeImage op and honor the configured sampler name.

@@ -148,7 +148,6 @@ def dali_dataloader(config, mode, device, seed=None):
     assert "gpu" in device, "gpu training is required for DALI"
     device_id = int(device.split(':')[1])
     config_dataloader = config[mode]
-    # mode = 'train' if mode.lower() == 'train' else 'eval'
     seed = 42 if seed is None else seed
     ops = [
         list(x.keys())[0]
@@ -160,6 +159,7 @@ def dali_dataloader(config, mode, device, seed=None):
     support_ops_eval = [
         "DecodeImage", "ResizeImage", "CropImage", "NormalizeImage"
     ]
+
     if mode.lower() == 'train':
         assert set(ops) == set(
             support_ops_train
@@ -171,6 +171,14 @@ def dali_dataloader(config, mode, device, seed=None):
         ), "The supported trasform_ops for eval_dataset in dali is : {}".format(
             ",".join(support_ops_eval))

+    normalize_ops = [
+        op for op in config_dataloader["dataset"]["transform_ops"]
+        if "NormalizeImage" in op
+    ][0]["NormalizeImage"]
+    channel_num = normalize_ops.get("channel_num", 3)
+    output_dtype = types.FLOAT16 if normalize_ops.get("output_fp16",
+                                                      False) else types.FLOAT
+
     env = os.environ
     # assert float(env.get('FLAGS_fraction_of_gpu_memory_to_use', 0.92)) < 0.9, \
     #     "Please leave enough GPU memory for DALI workspace, e.g., by setting" \
@@ -179,9 +187,6 @@ def dali_dataloader(config, mode, device, seed=None):
     gpu_num = paddle.distributed.get_world_size()

     batch_size = config_dataloader["sampler"]["batch_size"]
-    # assert batch_size % gpu_num == 0, \
-    #     "batch size must be multiple of number of devices"
-    # batch_size = batch_size // gpu_num

     file_root = config_dataloader["dataset"]["image_root"]
     file_list = config_dataloader["dataset"]["cls_label_path"]
@@ -195,15 +200,9 @@ def dali_dataloader(config, mode, device, seed=None):
         INTERP_LANCZOS3,  # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
     }

-    output_dtype = (types.FLOAT16 if 'AMP' in config and
-                    config.AMP.get("use_pure_fp16", False) else types.FLOAT)
-
     assert interp in interp_map, "interpolation method not supported by DALI"
     interp = interp_map[interp]
-    pad_output = False
-    image_shape = config.get("image_shape", None)
-    if image_shape and image_shape[0] == 4:
-        pad_output = True
+    pad_output = channel_num == 4

     transforms = {
         k: v
@@ -218,6 +217,10 @@ def dali_dataloader(config, mode, device, seed=None):
     mean = [v / scale for v in mean]
     std = [v / scale for v in std]

+    sampler_name = config_dataloader["sampler"].get("name",
+                                                    "DistributedBatchSampler")
+    assert sampler_name in ["DistributedBatchSampler", "BatchSampler"]
+
     if mode.lower() == "train":
         resize_shorter = 256
         crop = transforms["RandCropImage"]["size"]
@@ -279,10 +282,11 @@ def dali_dataloader(config, mode, device, seed=None):
     else:
         resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
         crop = transforms["CropImage"]["size"]
-        if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env:
+        if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and sampler_name == "DistributedBatchSampler":
             shard_id = int(env['PADDLE_TRAINER_ID'])
             num_shards = int(env['PADDLE_TRAINERS_NUM'])
             device_id = int(env['FLAGS_selected_gpus'])

             pipe = HybridValPipe(
                 file_root,
                 file_list,
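The added lookup relies on transform_ops being a list of single-key dicts. A standalone sketch (toy config values, not the real loader) of how channel_num / output_fp16 end up controlling pad_output and the DALI output dtype:

transform_ops = [
    {"DecodeImage": {"to_rgb": True, "channel_first": False}},
    {"RandCropImage": {"size": 224}},
    {"NormalizeImage": {"scale": "1.0/255.0", "output_fp16": True, "channel_num": 4}},
]
normalize_ops = [op for op in transform_ops if "NormalizeImage" in op][0]["NormalizeImage"]
channel_num = normalize_ops.get("channel_num", 3)      # 4
output_fp16 = normalize_ops.get("output_fp16", False)  # True -> DALI emits float16
pad_output = channel_num == 4                          # True -> DALI pads to 4 channels
print(channel_num, output_fp16, pad_output)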
Preprocessing NormalizeImage: optionally emit float16 and zero-pad a 4th channel.

@@ -197,14 +197,26 @@ class NormalizeImage(object):
     """ normalize image such as substract mean, divide std
     """

-    def __init__(self, scale=None, mean=None, std=None, order='chw'):
+    def __init__(self,
+                 scale=None,
+                 mean=None,
+                 std=None,
+                 order='chw',
+                 output_fp16=False,
+                 channel_num=3):
         if isinstance(scale, str):
             scale = eval(scale)
+        assert channel_num in [
+            3, 4
+        ], "channel number of input image should be set to 3 or 4."
+        self.channel_num = channel_num
+        self.output_dtype = 'float16' if output_fp16 else 'float32'
         self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        self.order = order
         mean = mean if mean is not None else [0.485, 0.456, 0.406]
         std = std if std is not None else [0.229, 0.224, 0.225]

-        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+        shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
         self.mean = np.array(mean).reshape(shape).astype('float32')
         self.std = np.array(std).reshape(shape).astype('float32')

@@ -215,7 +227,20 @@ class NormalizeImage(object):

         assert isinstance(img,
                           np.ndarray), "invalid input 'img' in NormalizeImage"
-        return (img.astype('float32') * self.scale - self.mean) / self.std
+        img = (img.astype('float32') * self.scale - self.mean) / self.std
+
+        if self.channel_num == 4:
+            img_h = img.shape[1] if self.order == 'chw' else img.shape[0]
+            img_w = img.shape[2] if self.order == 'chw' else img.shape[1]
+            pad_zeros = np.zeros(
+                (1, img_h, img_w)) if self.order == 'chw' else np.zeros(
+                    (img_h, img_w, 1))
+            img = (np.concatenate(
+                (img, pad_zeros), axis=0)
+                   if self.order == 'chw' else np.concatenate(
+                       (img, pad_zeros), axis=2))
+        return img.astype(self.output_dtype)


 class ToCHWImage(object):
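A quick NumPy check (not part of the patch) of what the new 4-channel branch produces for an HWC input: one extra all-zero channel, which matches the image_shape: [4, 224, 224] setting in the FP16 configs after ToCHWImage; padding to four channels is commonly done so the first convolution can use faster FP16 kernels.

import numpy as np

img = np.random.rand(224, 224, 3).astype('float32')   # HWC input, order=''
pad = np.zeros((224, 224, 1), dtype='float32')
padded = np.concatenate((img, pad), axis=2)
print(padded.shape, padded.dtype)                      # (224, 224, 4) float32
print(np.abs(padded[..., 3]).max())                    # 0.0 - the padded channel is empty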
ppcls.optimizer build_optimizer(): make parameters optional, so the static-mode program builder can call it without a parameter list.

@@ -41,7 +41,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
     return lr


-def build_optimizer(config, epochs, step_each_epoch, parameters):
+def build_optimizer(config, epochs, step_each_epoch, parameters=None):
     config = copy.deepcopy(config)
     # step1 build lr
     lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
Optimizer wrappers (Momentum / Adam / RMSProp): add a multi_precision switch and pass it through to the underlying paddle optimizer.

@@ -33,12 +33,14 @@ class Momentum(object):
                  learning_rate,
                  momentum,
                  weight_decay=None,
-                 grad_clip=None):
+                 grad_clip=None,
+                 multi_precision=False):
         super(Momentum, self).__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
+        self.multi_precision = multi_precision

     def __call__(self, parameters):
         opt = optim.Momentum(
@@ -46,6 +48,7 @@ class Momentum(object):
             momentum=self.momentum,
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
+            multi_precision=self.multi_precision,
             parameters=parameters)
         return opt

@@ -60,7 +63,8 @@ class Adam(object):
                  weight_decay=None,
                  grad_clip=None,
                  name=None,
-                 lazy_mode=False):
+                 lazy_mode=False,
+                 multi_precision=False):
         self.learning_rate = learning_rate
         self.beta1 = beta1
         self.beta2 = beta2
@@ -71,6 +75,7 @@ class Adam(object):
         self.grad_clip = grad_clip
         self.name = name
         self.lazy_mode = lazy_mode
+        self.multi_precision = multi_precision

     def __call__(self, parameters):
         opt = optim.Adam(
@@ -82,6 +87,7 @@ class Adam(object):
             grad_clip=self.grad_clip,
             name=self.name,
             lazy_mode=self.lazy_mode,
+            multi_precision=self.multi_precision,
             parameters=parameters)
         return opt

@@ -104,7 +110,8 @@ class RMSProp(object):
                  rho=0.95,
                  epsilon=1e-6,
                  weight_decay=None,
-                 grad_clip=None):
+                 grad_clip=None,
+                 multi_precision=False):
         super(RMSProp, self).__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
@@ -112,6 +119,7 @@ class RMSProp(object):
         self.epsilon = epsilon
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
+        self.multi_precision = multi_precision

     def __call__(self, parameters):
         opt = optim.RMSProp(
@@ -121,5 +129,6 @@ class RMSProp(object):
             epsilon=self.epsilon,
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
+            multi_precision=self.multi_precision,
             parameters=parameters)
         return opt
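multi_precision maps to the same-named flag on the paddle optimizers; with pure-FP16 training it keeps an FP32 master copy of each parameter so the update itself runs in FP32. A minimal dynamic-graph sketch (assuming Paddle 2.x; not part of the patch):

import paddle

model = paddle.nn.Linear(16, 4)
opt = paddle.optimizer.Momentum(
    learning_rate=0.1,
    momentum=0.9,
    parameters=model.parameters(),
    multi_precision=True)  # keep FP32 master weights alongside low-precision params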
New file: ppcls/static/program.py (imported as ppcls.static.program by the trainer below).

@@ -0,0 +1,456 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import numpy as np

from collections import OrderedDict

import paddle
import paddle.nn.functional as F

from paddle.distributed import fleet
from paddle.distributed.fleet import DistributedStrategy

# from ppcls.optimizer import OptimizerBuilder
# from ppcls.optimizer.learning_rate import LearningRateBuilder

from ppcls.arch import build_model
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.optimizer import build_lr_scheduler

from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger


def create_feeds(image_shape, use_mix=None, dtype="float32"):
    """
    Create feeds as model input

    Args:
        image_shape(list[int]): model input shape, such as [3, 224, 224]
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)

    Returns:
        feeds(dict): dict of model input variables
    """
    feeds = OrderedDict()
    feeds['data'] = paddle.static.data(
        name="data", shape=[None] + image_shape, dtype=dtype)
    if use_mix:
        feeds['y_a'] = paddle.static.data(
            name="y_a", shape=[None, 1], dtype="int64")
        feeds['y_b'] = paddle.static.data(
            name="y_b", shape=[None, 1], dtype="int64")
        feeds['lam'] = paddle.static.data(
            name="lam", shape=[None, 1], dtype=dtype)
    else:
        feeds['label'] = paddle.static.data(
            name="label", shape=[None, 1], dtype="int64")

    return feeds


def create_fetchs(out,
                  feeds,
                  architecture,
                  topk=5,
                  epsilon=None,
                  use_mix=False,
                  config=None,
                  mode="Train"):
    """
    Create fetchs as model outputs(included loss and measures),
    will call create_loss and create_metric(if use_mix).
    Args:
        out(variable): model output variable
        feeds(dict): dict of model input variables.
            If use mix_up, it will not include label.
        architecture(dict): architecture information,
            name(such as ResNet50) is needed
        topk(int): usually top5
        epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
        config(dict): model config

    Returns:
        fetchs(dict): dict of model outputs(included loss and measures)
    """
    fetchs = OrderedDict()
    # build loss
    # TODO(littletomatodonkey): support mix training
    if use_mix:
        y_a = paddle.reshape(feeds['y_a'], [-1, 1])
        y_b = paddle.reshape(feeds['y_b'], [-1, 1])
        lam = paddle.reshape(feeds['lam'], [-1, 1])
    else:
        target = paddle.reshape(feeds['label'], [-1, 1])

    loss_func = build_loss(config["Loss"][mode])

    # TODO: support mix training
    loss_dict = loss_func(out, target)

    loss_out = loss_dict["loss"]
    # if "AMP" in config and config.AMP.get("use_pure_fp16", False):
    #     loss_out = loss_out.astype("float16")

    # if use_mix:
    #     return loss_func(out, feed_y_a, feed_y_b, feed_lam)
    # else:
    #     return loss_func(out, target)

    fetchs['loss'] = (loss_out, AverageMeter('loss', '7.4f', need_avg=True))

    assert use_mix is False

    # build metric
    if not use_mix:
        metric_func = build_metrics(config["Metric"][mode])

        metric_dict = metric_func(out, target)

        for key in metric_dict:
            if mode != "Train" and paddle.distributed.get_world_size() > 1:
                paddle.distributed.all_reduce(
                    metric_dict[key], op=paddle.distributed.ReduceOp.SUM)
                metric_dict[key] = metric_dict[
                    key] / paddle.distributed.get_world_size()

            fetchs[key] = (metric_dict[key], AverageMeter(
                key, '7.4f', need_avg=True))

    return fetchs


def create_optimizer(config, step_each_epoch):
    # create learning_rate instance
    optimizer, lr_sch = build_optimizer(
        config["Optimizer"], config["Global"]["epochs"], step_each_epoch)
    return optimizer, lr_sch


def create_strategy(config):
    """
    Create build strategy and exec strategy.

    Args:
        config(dict): config

    Returns:
        build_strategy: build strategy
        exec_strategy: exec strategy
    """
    build_strategy = paddle.static.BuildStrategy()
    exec_strategy = paddle.static.ExecutionStrategy()

    exec_strategy.num_threads = 1
    exec_strategy.num_iteration_per_drop_scope = (
        10000
        if 'AMP' in config and config.AMP.get("use_pure_fp16", False) else 10)

    fuse_op = True if 'AMP' in config else False

    fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
    fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
    fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
    enable_addto = config.get('enable_addto', fuse_op)

    build_strategy.fuse_bn_act_ops = fuse_bn_act_ops
    build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
    build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
    build_strategy.enable_addto = enable_addto

    return build_strategy, exec_strategy


def dist_optimizer(config, optimizer):
    """
    Create a distributed optimizer based on a normal optimizer

    Args:
        config(dict):
        optimizer(): a normal optimizer

    Returns:
        optimizer: a distributed optimizer
    """
    build_strategy, exec_strategy = create_strategy(config)

    dist_strategy = DistributedStrategy()
    dist_strategy.execution_strategy = exec_strategy
    dist_strategy.build_strategy = build_strategy

    dist_strategy.nccl_comm_num = 1
    dist_strategy.fuse_all_reduce_ops = True
    dist_strategy.fuse_grad_size_in_MB = 16
    optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)

    return optimizer


def mixed_precision_optimizer(config, optimizer):
    if 'AMP' in config:
        amp_cfg = config.AMP if config.AMP else dict()
        scale_loss = amp_cfg.get('scale_loss', 1.0)
        use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling',
                                               False)
        use_pure_fp16 = amp_cfg.get('use_pure_fp16', False)
        optimizer = paddle.static.amp.decorate(
            optimizer,
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling,
            use_pure_fp16=use_pure_fp16,
            use_fp16_guard=True)

    return optimizer


def build(config,
          main_prog,
          startup_prog,
          step_each_epoch=100,
          is_train=True,
          is_distributed=True):
    """
    Build a program using a model and an optimizer
        1. create feeds
        2. create a dataloader
        3. create a model
        4. create fetchs
        5. create an optimizer

    Args:
        config(dict): config
        main_prog(): main program
        startup_prog(): startup program
        is_train(bool): train or eval
        is_distributed(bool): whether to use distributed training method

    Returns:
        dataloader(): a bridge between the model and the data
        fetchs(dict): dict of model outputs(included loss and measures)
    """
    with paddle.static.program_guard(main_prog, startup_prog):
        with paddle.utils.unique_name.guard():
            mode = "Train" if is_train else "Eval"
            use_mix = "batch_transform_ops" in config["DataLoader"][mode][
                "dataset"]
            use_dali = config["Global"].get('use_dali', False)
            feeds = create_feeds(
                config["Global"]["image_shape"],
                use_mix=use_mix,
                dtype="float32")

            # build model
            # data_format should be assigned in arch-dict
            input_image_channel = config["Global"]["image_shape"][
                0]  # default as [3, 224, 224]
            if input_image_channel != 3:
                logger.warning(
                    "Input image channel is changed to {}, maybe for better speed-up".
                    format(input_image_channel))
                config["Arch"]["input_image_channel"] = input_image_channel
            model = build_model(config["Arch"])
            out = model(feeds["data"])
            # end of build model

            fetchs = create_fetchs(
                out,
                feeds,
                config["Arch"],
                epsilon=config.get('ls_epsilon'),
                use_mix=use_mix,
                config=config,
                mode=mode)
            lr_scheduler = None
            optimizer = None
            if is_train:
                optimizer, lr_scheduler = build_optimizer(
                    config["Optimizer"], config["Global"]["epochs"],
                    step_each_epoch)
                optimizer = mixed_precision_optimizer(config, optimizer)
                if is_distributed:
                    optimizer = dist_optimizer(config, optimizer)
                optimizer.minimize(fetchs['loss'][0])
    return fetchs, lr_scheduler, feeds, optimizer


def compile(config, program, loss_name=None, share_prog=None):
    """
    Compile the program

    Args:
        config(dict): config
        program(): the program which is wrapped by
        loss_name(str): loss name
        share_prog(): the shared program, used for evaluation during training

    Returns:
        compiled_program(): a compiled program
    """
    build_strategy, exec_strategy = create_strategy(config)

    compiled_program = paddle.static.CompiledProgram(
        program).with_data_parallel(
            share_vars_from=share_prog,
            loss_name=loss_name,
            build_strategy=build_strategy,
            exec_strategy=exec_strategy)

    return compiled_program


total_step = 0


def run(dataloader,
        exe,
        program,
        feeds,
        fetchs,
        epoch=0,
        mode='train',
        config=None,
        vdl_writer=None,
        lr_scheduler=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(paddle io dataloader):
        exe():
        program():
        fetchs(dict): dict of measures and the loss
        epoch(int): epoch of training or evaluation
        model(str): log only

    Returns:
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_dict = OrderedDict([("lr", AverageMeter(
        'lr', 'f', postfix=",", need_avg=False))])

    for k in fetchs:
        metric_dict[k] = fetchs[k][1]

    metric_dict["batch_time"] = AverageMeter(
        'batch_cost', '.5f', postfix=" s,")
    metric_dict["reader_time"] = AverageMeter(
        'reader_cost', '.5f', postfix=" s,")

    for m in metric_dict.values():
        m.reset()

    use_dali = config["Global"].get('use_dali', False)
    tic = time.time()

    if not use_dali:
        dataloader = dataloader()

    idx = 0
    batch_size = None
    while True:
        # The DALI maybe raise RuntimeError for some particular images, such as ImageNet1k/n04418357_26036.JPEG
        try:
            batch = next(dataloader)
        except StopIteration:
            break
        except RuntimeError:
            logger.warning(
                "Except RuntimeError when reading data from dataloader, try to read once again..."
            )
            continue
        idx += 1
        # ignore the warmup iters
        if idx == 5:
            metric_dict["batch_time"].reset()
            metric_dict["reader_time"].reset()

        metric_dict['reader_time'].update(time.time() - tic)

        if use_dali:
            batch_size = batch[0]["data"].shape()[0]
            feed_dict = batch[0]
        else:
            batch_size = batch[0].shape()[0]
            feed_dict = {
                key.name: batch[idx]
                for idx, key in enumerate(feeds.values())
            }

        metrics = exe.run(program=program,
                          feed=feed_dict,
                          fetch_list=fetch_list)

        for name, m in zip(fetchs.keys(), metrics):
            metric_dict[name].update(np.mean(m), batch_size)
        metric_dict["batch_time"].update(time.time() - tic)
        if mode == "train":
            metric_dict['lr'].update(lr_scheduler.get_lr())

        fetchs_str = ' '.join([
            str(metric_dict[key].mean)
            if "time" in key else str(metric_dict[key].value)
            for key in metric_dict
        ])
        ips_info = " ips: {:.5f} images/sec.".format(
            batch_size / metric_dict["batch_time"].avg)
        fetchs_str += ips_info

        if lr_scheduler is not None:
            lr_scheduler.step()

        if vdl_writer:
            global total_step
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1
        if mode == 'eval':
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
                                                           fetchs_str))
        else:
            epoch_str = "epoch:{:<3d}".format(epoch)
            step_str = "{:s} step:{:<4d}".format(mode, idx)

            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} {:s} {:s}".format(epoch_str, step_str,
                                                    fetchs_str))

        tic = time.time()

    end_str = ' '.join([str(m.mean) for m in metric_dict.values()] +
                       [metric_dict["batch_time"].total])
    ips_info = "ips: {:.5f} images/sec.".format(
        batch_size * metric_dict["batch_time"].count /
        metric_dict["batch_time"].sum)
    if mode == 'eval':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))
    if use_dali:
        dataloader.reset()

    # return top1_acc in order to save the best model
    if mode == 'eval':
        return fetchs["top1"][1].avg
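For orientation, a minimal CPU-runnable static-graph loop showing the same skeleton that build() and run() assemble (toy fc model, hypothetical names; the real code additionally wraps the optimizer with paddle.static.amp.decorate and calls optimizer.amp_init when use_pure_fp16 is enabled):

import numpy as np
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[None, 16], dtype="float32")
    y = paddle.static.data(name="y", shape=[None, 1], dtype="float32")
    pred = paddle.static.nn.fc(x, size=1)
    loss = paddle.mean(paddle.nn.functional.square_error_cost(pred, y))
    paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9).minimize(loss)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)  # parameter initialization, as train.py does below
for _ in range(3):
    feed = {"x": np.random.rand(8, 16).astype("float32"),
            "y": np.random.rand(8, 1).astype("float32")}
    loss_val, = exe.run(main_prog, feed=feed, fetch_list=[loss])
    print(float(loss_val))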
New launch script for static-mode FP16 training with DALI:

@@ -0,0 +1,11 @@
#!/usr/bin/env bash

export CUDA_VISIBLE_DEVICES="0,1,2,3"
export FLAGS_fraction_of_gpu_memory_to_use=0.80

python3.7 -m paddle.distributed.launch \
    --gpus="0,1,2,3" \
    ppcls/static/train.py \
    -c ./ppcls/configs/ImageNet/ResNet/ResNet50_fp16.yaml \
    -o Global.use_dali=True
New file: ppcls/static/save_load.py (imported by the static trainer below).

@@ -0,0 +1,139 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import errno
import os
import re
import shutil
import tempfile

import paddle

from ppcls.utils import logger

__all__ = ['init_model', 'save_model']


def _mkdir_if_not_exist(path):
    """
    mkdir if not exists, ignore the exception when multiprocess mkdir together
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    'be happy if some process has already created {}'.format(
                        path))
            else:
                raise OSError('Failed to mkdir {}'.format(path))


def _load_state(path):
    if os.path.exists(path + '.pdopt'):
        # XXX another hack to ignore the optimizer state
        tmp = tempfile.mkdtemp()
        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
        shutil.copy(path + '.pdparams', dst + '.pdparams')
        state = paddle.static.load_program_state(dst)
        shutil.rmtree(tmp)
    else:
        state = paddle.static.load_program_state(path)
    return state


def load_params(exe, prog, path, ignore_params=None):
    """
    Load model from the given path.
    Args:
        exe (fluid.Executor): The fluid.Executor object.
        prog (fluid.Program): load weight to which Program object.
        path (string): URL string or loca model path.
        ignore_params (list): ignore variable to load when finetuning.
            It can be specified by finetune_exclude_pretrained_params
            and the usage can refer to the document
            docs/advanced_tutorials/TRANSFER_LEARNING.md
    """
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(path))

    logger.info("Loading parameters from {}...".format(path))

    ignore_set = set()
    state = _load_state(path)

    # ignore the parameter which mismatch the shape
    # between the model and pretrain weight.
    all_var_shape = {}
    for block in prog.blocks:
        for param in block.all_parameters():
            all_var_shape[param.name] = param.shape
    ignore_set.update([
        name for name, shape in all_var_shape.items()
        if name in state and shape != state[name].shape
    ])

    if ignore_params:
        all_var_names = [var.name for var in prog.list_vars()]
        ignore_list = filter(
            lambda var: any([re.match(name, var) for name in ignore_params]),
            all_var_names)
        ignore_set.update(list(ignore_list))

    if len(ignore_set) > 0:
        for k in ignore_set:
            if k in state:
                logger.warning(
                    'variable {} is already excluded automatically'.format(k))
                del state[k]

    paddle.static.set_program_state(prog, state)


def init_model(config, program, exe):
    """
    load model from checkpoint or pretrained_model
    """
    checkpoints = config.get('checkpoints')
    if checkpoints:
        paddle.static.load(program, checkpoints, exe)
        logger.info("Finish initing model from {}".format(checkpoints))
        return

    pretrained_model = config.get('pretrained_model')
    if pretrained_model:
        if not isinstance(pretrained_model, list):
            pretrained_model = [pretrained_model]
        for pretrain in pretrained_model:
            load_params(exe, program, pretrain)
        logger.info("Finish initing model from {}".format(pretrained_model))


def save_model(program, model_path, epoch_id, prefix='ppcls'):
    """
    save model to the target path
    """
    if paddle.distributed.get_rank() != 0:
        return
    model_path = os.path.join(model_path, str(epoch_id))
    _mkdir_if_not_exist(model_path)
    model_prefix = os.path.join(model_path, prefix)
    paddle.static.save(program, model_prefix)
    logger.info("Already save model in {}".format(model_path))
New file: ppcls/static/train.py — the static-graph training entry point.

@@ -0,0 +1,197 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

import paddle
from paddle.distributed import fleet
from visualdl import LogWriter

from ppcls.data import build_dataloader
from ppcls.utils.config import get_config, print_config
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.static.save_load import init_model, save_model
from ppcls.static import program


def parse_args():
    parser = argparse.ArgumentParser("PaddleClas train script")
    parser.add_argument(
        '-c',
        '--config',
        type=str,
        default='configs/ResNet/ResNet50.yaml',
        help='config file path')
    parser.add_argument(
        '-o',
        '--override',
        action='append',
        default=[],
        help='config options to be overridden')
    args = parser.parse_args()
    return args


def main(args):
    """
    all the config of training paradigm should be in config["Global"]
    """
    config = get_config(args.config, overrides=args.override, show=False)
    global_config = config["Global"]

    mode = "train"

    log_file = os.path.join(global_config['output_dir'],
                            config["Arch"]["name"], f"{mode}.log")
    init_logger(name='root', log_file=log_file)
    print_config(config)

    if global_config.get("is_distributed", True):
        fleet.init(is_collective=True)
    # assign the device
    use_gpu = global_config.get("use_gpu", True)
    # amp related config
    if 'AMP' in config:
        AMP_RELATED_FLAGS_SETTING = {
            'FLAGS_cudnn_exhaustive_search': "1",
            'FLAGS_conv_workspace_size_limit': "1500",
            'FLAGS_cudnn_batchnorm_spatial_persistent': "1",
            'FLAGS_max_indevice_grad_add': "8",
            "FLAGS_cudnn_batchnorm_spatial_persistent": "1",
        }
        for k in AMP_RELATED_FLAGS_SETTING:
            os.environ[k] = AMP_RELATED_FLAGS_SETTING[k]

    use_xpu = global_config.get("use_xpu", False)
    assert (
        use_gpu and use_xpu
    ) is not True, "gpu and xpu can not be true in the same time in static mode!"

    if use_gpu:
        device = paddle.set_device('gpu')
    elif use_xpu:
        device = paddle.set_device('xpu')
    else:
        device = paddle.set_device('cpu')

    # visualDL
    vdl_writer = None
    if global_config["use_visualdl"]:
        vdl_dir = os.path.join(global_config["output_dir"], "vdl")
        vdl_writer = LogWriter(vdl_dir)

    # build dataloader
    eval_dataloader = None
    use_dali = global_config.get('use_dali', False)

    train_dataloader = build_dataloader(
        config["DataLoader"], "Train", device=device, use_dali=use_dali)
    if global_config["eval_during_train"]:
        eval_dataloader = build_dataloader(
            config["DataLoader"], "Eval", device=device, use_dali=use_dali)

    step_each_epoch = len(train_dataloader)

    # startup_prog is used to do some parameter init work,
    # and train prog is used to hold the network
    startup_prog = paddle.static.Program()
    train_prog = paddle.static.Program()

    best_top1_acc = 0.0  # best top1 acc record

    train_fetchs, lr_scheduler, train_feeds, optimizer = program.build(
        config,
        train_prog,
        startup_prog,
        step_each_epoch=step_each_epoch,
        is_train=True,
        is_distributed=global_config.get("is_distributed", True))

    if global_config["eval_during_train"]:
        eval_prog = paddle.static.Program()
        eval_fetchs, _, eval_feeds, _ = program.build(
            config,
            eval_prog,
            startup_prog,
            is_train=False,
            is_distributed=global_config.get("is_distributed", True))
        # clone to prune some content which is irrelevant in eval_prog
        eval_prog = eval_prog.clone(for_test=True)

    # create the "Executor" with the statement of which device
    exe = paddle.static.Executor(device)
    # Parameter initialization
    exe.run(startup_prog)
    # load pretrained models or checkpoints
    init_model(global_config, train_prog, exe)

    if 'AMP' in config and config.AMP.get("use_pure_fp16", False):
        optimizer.amp_init(
            device,
            scope=paddle.static.global_scope(),
            test_program=eval_prog
            if global_config["eval_during_train"] else None)

    if not global_config.get("is_distributed", True):
        compiled_train_prog = program.compile(
            config, train_prog, loss_name=train_fetchs["loss"][0].name)
    else:
        compiled_train_prog = train_prog

    if eval_dataloader is not None:
        compiled_eval_prog = program.compile(config, eval_prog)

    for epoch_id in range(global_config["epochs"]):
        # 1. train with train dataset
        program.run(train_dataloader, exe, compiled_train_prog, train_feeds,
                    train_fetchs, epoch_id, 'train', config, vdl_writer,
                    lr_scheduler)
        # 2. evaluate with eval dataset
        if global_config["eval_during_train"] and epoch_id % global_config[
                "eval_interval"] == 0:
            top1_acc = program.run(eval_dataloader, exe, compiled_eval_prog,
                                   eval_feeds, eval_fetchs, epoch_id, "eval",
                                   config)
            if top1_acc > best_top1_acc:
                best_top1_acc = top1_acc
                message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
                    best_top1_acc, epoch_id)
                logger.info(message)
                if epoch_id % global_config["save_interval"] == 0:
                    model_path = os.path.join(global_config["output_dir"],
                                              config["Arch"]["name"])
                    save_model(train_prog, model_path, "best_model")

        # 3. save the persistable model
        if epoch_id % global_config["save_interval"] == 0:
            model_path = os.path.join(global_config["output_dir"],
                                      config["Arch"]["name"])
            save_model(train_prog, model_path, epoch_id)


if __name__ == '__main__':
    paddle.enable_static()
    args = parse_args()
    main(args)