support fp16 training (#435)
* support fp16 training
* Use compiled training program
* Change timing ips.
* Use dali
* add pure fp16 training
* fix a bug that prevented the fuse passes from being used with pure fp16 training
* modify code as reviewed
* modify the loss so that a different loss reduction is used with pure fp16 training
* remove some fluid API
* add static optimizer

pull/468/head
parent 9992415867
commit dc3020ab4a
@@ -12,11 +12,16 @@ valid_interval: 1
 epochs: 120
 topk: 5
 image_shape: [3, 224, 224]
+is_distributed: True
 
 # mixed precision training
-use_fp16: True
-amp_scale_loss: 128.0
+use_amp: True
+use_pure_fp16: False
+multi_precision: False
+scale_loss: 128.0
 use_dynamic_loss_scaling: True
+data_format: "NCHW"
+image_shape: [3, 224, 224]
 
 use_mix: False
 ls_epsilon: -1

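The renamed keys above are the ones the static training entry points read later in this commit: use_amp and use_pure_fp16 pick the precision mode, scale_loss and use_dynamic_loss_scaling feed the loss scaler, multi_precision controls fp32 master weights in the optimizer, and data_format selects NCHW/NHWC. A minimal sketch of reading them from a plain dict; the helper name is hypothetical and only illustrates the config.get(...) defaults used elsewhere in the commit:

```python
# Hypothetical helper, only to illustrate how the keys above are consumed
# elsewhere in this commit (config.get(...) with the same defaults).
def precision_mode(config):
    if config.get("use_pure_fp16", False):
        return "pure_fp16"   # fp16 weights/activations, loss reduced by sum
    if config.get("use_amp", False):
        return "amp"         # fp32 weights, fp16 compute via AMP
    return "fp32"

cfg = {"use_amp": True, "use_pure_fp16": False,
       "scale_loss": 128.0, "use_dynamic_loss_scaling": True,
       "multi_precision": False, "data_format": "NCHW"}
print(precision_mode(cfg))  # -> "amp"
```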
@@ -38,7 +38,8 @@ class ConvBNLayer(nn.Layer):
                  stride=1,
                  groups=1,
                  act=None,
-                 name=None):
+                 name=None,
+                 data_format="NCHW"):
         super(ConvBNLayer, self).__init__()
 
         self._conv = Conv2D(
@@ -49,7 +50,8 @@ class ConvBNLayer(nn.Layer):
             padding=(filter_size - 1) // 2,
             groups=groups,
             weight_attr=ParamAttr(name=name + "_weights"),
-            bias_attr=False)
+            bias_attr=False,
+            data_format=data_format)
         if name == "conv1":
             bn_name = "bn_" + name
         else:
@@ -60,7 +62,8 @@ class ConvBNLayer(nn.Layer):
             param_attr=ParamAttr(name=bn_name + "_scale"),
             bias_attr=ParamAttr(bn_name + "_offset"),
             moving_mean_name=bn_name + "_mean",
-            moving_variance_name=bn_name + "_variance")
+            moving_variance_name=bn_name + "_variance",
+            data_layout=data_format)
 
     def forward(self, inputs):
         y = self._conv(inputs)
@@ -74,7 +77,8 @@ class BottleneckBlock(nn.Layer):
                  num_filters,
                  stride,
                  shortcut=True,
-                 name=None):
+                 name=None,
+                 data_format="NCHW"):
         super(BottleneckBlock, self).__init__()
 
         self.conv0 = ConvBNLayer(
@@ -82,20 +86,23 @@ class BottleneckBlock(nn.Layer):
             num_filters=num_filters,
             filter_size=1,
             act="relu",
-            name=name + "_branch2a")
+            name=name + "_branch2a",
+            data_format=data_format)
         self.conv1 = ConvBNLayer(
             num_channels=num_filters,
             num_filters=num_filters,
             filter_size=3,
             stride=stride,
             act="relu",
-            name=name + "_branch2b")
+            name=name + "_branch2b",
+            data_format=data_format)
         self.conv2 = ConvBNLayer(
             num_channels=num_filters,
             num_filters=num_filters * 4,
             filter_size=1,
             act=None,
-            name=name + "_branch2c")
+            name=name + "_branch2c",
+            data_format=data_format)
 
         if not shortcut:
             self.short = ConvBNLayer(
@@ -103,7 +110,8 @@ class BottleneckBlock(nn.Layer):
                 num_filters=num_filters * 4,
                 filter_size=1,
                 stride=stride,
-                name=name + "_branch1")
+                name=name + "_branch1",
+                data_format=data_format)
 
         self.shortcut = shortcut
 
@@ -130,7 +138,8 @@ class BasicBlock(nn.Layer):
                  num_filters,
                  stride,
                  shortcut=True,
-                 name=None):
+                 name=None,
+                 data_format="NCHW"):
         super(BasicBlock, self).__init__()
         self.stride = stride
         self.conv0 = ConvBNLayer(
@@ -139,13 +148,15 @@ class BasicBlock(nn.Layer):
             filter_size=3,
             stride=stride,
             act="relu",
-            name=name + "_branch2a")
+            name=name + "_branch2a",
+            data_format=data_format)
         self.conv1 = ConvBNLayer(
             num_channels=num_filters,
             num_filters=num_filters,
             filter_size=3,
             act=None,
-            name=name + "_branch2b")
+            name=name + "_branch2b",
+            data_format=data_format)
 
         if not shortcut:
             self.short = ConvBNLayer(
@@ -153,7 +164,8 @@ class BasicBlock(nn.Layer):
                 num_filters=num_filters,
                 filter_size=1,
                 stride=stride,
-                name=name + "_branch1")
+                name=name + "_branch1",
+                data_format=data_format)
 
         self.shortcut = shortcut
 
@@ -171,10 +183,13 @@ class BasicBlock(nn.Layer):
 
 
 class ResNet(nn.Layer):
-    def __init__(self, layers=50, class_dim=1000):
+    def __init__(self, layers=50, class_dim=1000, input_image_channel=3, data_format="NCHW"):
         super(ResNet, self).__init__()
 
         self.layers = layers
+        self.data_format = data_format
+        self.input_image_channel = input_image_channel
 
         supported_layers = [18, 34, 50, 101, 152]
         assert layers in supported_layers, \
             "supported layers are {} but input layer is {}".format(
@@ -193,13 +208,18 @@ class ResNet(nn.Layer):
         num_filters = [64, 128, 256, 512]
 
         self.conv = ConvBNLayer(
-            num_channels=3,
+            num_channels=self.input_image_channel,
             num_filters=64,
             filter_size=7,
             stride=2,
             act="relu",
-            name="conv1")
-        self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)
+            name="conv1",
+            data_format=self.data_format)
+        self.pool2d_max = MaxPool2D(
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            data_format=self.data_format)
 
         self.block_list = []
         if layers >= 50:
@@ -221,7 +241,8 @@ class ResNet(nn.Layer):
                             num_filters=num_filters[block],
                             stride=2 if i == 0 and block != 0 else 1,
                             shortcut=shortcut,
-                            name=conv_name))
+                            name=conv_name,
+                            data_format=self.data_format))
                     self.block_list.append(bottleneck_block)
                     shortcut = True
         else:
@@ -237,11 +258,12 @@ class ResNet(nn.Layer):
                             num_filters=num_filters[block],
                             stride=2 if i == 0 and block != 0 else 1,
                             shortcut=shortcut,
-                            name=conv_name))
+                            name=conv_name,
+                            data_format=self.data_format))
                     self.block_list.append(basic_block)
                     shortcut = True
 
-        self.pool2d_avg = AdaptiveAvgPool2D(1)
+        self.pool2d_avg = AdaptiveAvgPool2D(1, data_format=self.data_format)
 
         self.pool2d_avg_channels = num_channels[-1] * 2
 

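Every layer in the ResNet definition now threads a data_format argument through, so the whole backbone can run channels-last (NHWC), the layout that fp16 Tensor Core kernels prefer. A rough construction sketch; the import path is assumed, and an NHWC input has to be laid out as [N, H, W, C] (create_model in this commit performs that transpose):

```python
# Module path assumed for illustration; the diff above only shows the class itself.
from ppcls.modeling.architectures.resnet import ResNet

# Channels-last ResNet50: ConvBNLayer/BatchNorm/MaxPool2D/AdaptiveAvgPool2D all
# receive data_format="NHWC" via the constructor arguments added in this commit.
model = ResNet(layers=50, class_dim=1000, input_image_channel=3, data_format="NHWC")
print(len(model.parameters()), "parameter tensors")
```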
@@ -42,14 +42,17 @@ class Loss(object):
         soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
         return soft_target
 
-    def _crossentropy(self, input, target):
+    def _crossentropy(self, input, target, use_pure_fp16=False):
         if self._label_smoothing:
             target = self._labelsmoothing(target)
             input = -F.log_softmax(input, axis=-1)
             cost = paddle.sum(target * input, axis=-1)
         else:
-            cost = F.cross_entropy(input=input, label=target)
-        avg_cost = paddle.mean(cost)
+            cost = F.cross_entropy(input=input, label=target)
+        if use_pure_fp16:
+            avg_cost = paddle.sum(cost)
+        else:
+            avg_cost = paddle.mean(cost)
         return avg_cost
 
     def _kldiv(self, input, target, name=None):
@@ -78,8 +81,8 @@ class CELoss(Loss):
     def __init__(self, class_dim=1000, epsilon=None):
         super(CELoss, self).__init__(class_dim, epsilon)
 
-    def __call__(self, input, target):
-        cost = self._crossentropy(input, target)
+    def __call__(self, input, target, use_pure_fp16=False):
+        cost = self._crossentropy(input, target, use_pure_fp16)
         return cost
 
 
@@ -91,11 +94,14 @@ class MixCELoss(Loss):
     def __init__(self, class_dim=1000, epsilon=None):
         super(MixCELoss, self).__init__(class_dim, epsilon)
 
-    def __call__(self, input, target0, target1, lam):
-        cost0 = self._crossentropy(input, target0)
-        cost1 = self._crossentropy(input, target1)
-        cost = lam * cost0 + (1.0 - lam) * cost1
-        avg_cost = paddle.mean(cost)
+    def __call__(self, input, target0, target1, lam, use_pure_fp16=False):
+        cost0 = self._crossentropy(input, target0, use_pure_fp16)
+        cost1 = self._crossentropy(input, target1, use_pure_fp16)
+        cost = lam * cost0 + (1.0 - lam) * cost1
+        if use_pure_fp16:
+            avg_cost = paddle.sum(cost)
+        else:
+            avg_cost = paddle.mean(cost)
         return avg_cost
 
 

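The pure fp16 path swaps paddle.mean for paddle.sum, but the division by the batch size is not lost: the static Momentum wrapper added later in this commit sets rescale_grad = 1 / (batch_size / #GPUs), so the averaging moves from the fp16 loss into the optimizer's gradient rescaling. A small NumPy sketch of that equivalence (values are illustrative):

```python
import numpy as np

np.random.seed(0)
per_sample_loss = np.random.rand(32).astype(np.float32)  # one loss per sample
batch_size = per_sample_loss.size

mean_loss = per_sample_loss.mean()   # reduction used for fp32 / AMP training
sum_loss = per_sample_loss.sum()     # reduction used for pure fp16 training
rescale_grad = 1.0 / batch_size      # applied to gradients inside Momentum

# d(mean)/d(x_i) = 1/N, and rescale_grad * d(sum)/d(x_i) = 1/N as well,
# so both setups take the same optimizer step.
assert np.isclose(mean_loss, sum_loss * rescale_grad)
```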
@@ -44,7 +44,9 @@ class HybridTrainPipe(Pipeline):
                  num_shards=1,
                  random_shuffle=True,
                  num_threads=4,
-                 seed=42):
+                 seed=42,
+                 pad_output=False,
+                 output_dtype=types.FLOAT):
         super(HybridTrainPipe, self).__init__(
             batch_size, num_threads, device_id, seed=seed)
         self.input = ops.FileReader(
@@ -69,12 +71,13 @@ class HybridTrainPipe(Pipeline):
             device='gpu', resize_x=crop, resize_y=crop, interp_type=interp)
         self.cmnp = ops.CropMirrorNormalize(
             device="gpu",
-            output_dtype=types.FLOAT,
+            output_dtype=output_dtype,
             output_layout=types.NCHW,
             crop=(crop, crop),
             image_type=types.RGB,
             mean=mean,
-            std=std)
+            std=std,
+            pad_output=pad_output)
         self.coin = ops.CoinFlip(probability=0.5)
         self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")
 
@@ -105,7 +108,9 @@ class HybridValPipe(Pipeline):
                  num_shards=1,
                  random_shuffle=False,
                  num_threads=4,
-                 seed=42):
+                 seed=42,
+                 pad_output=False,
+                 output_dtype=types.FLOAT):
         super(HybridValPipe, self).__init__(
             batch_size, num_threads, device_id, seed=seed)
         self.input = ops.FileReader(
@@ -119,12 +124,13 @@ class HybridValPipe(Pipeline):
             device="gpu", resize_shorter=resize_shorter, interp_type=interp)
         self.cmnp = ops.CropMirrorNormalize(
             device="gpu",
-            output_dtype=types.FLOAT,
+            output_dtype=output_dtype,
             output_layout=types.NCHW,
             crop=(crop, crop),
             image_type=types.RGB,
             mean=mean,
-            std=std)
+            std=std,
+            pad_output=pad_output)
         self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")
 
     def define_graph(self):
@@ -170,8 +176,13 @@ def build(config, mode='train'):
         2: types.INTERP_CUBIC,  # cv2.INTER_CUBIC
         4: types.INTERP_LANCZOS3,  # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
     }
+    output_dtype = types.FLOAT16 if config.get("use_pure_fp16", False) else types.FLOAT
     assert interp in interp_map, "interpolation method not supported by DALI"
     interp = interp_map[interp]
+    pad_output = False
+    image_shape = config.get("image_shape", None)
+    if image_shape and image_shape[0] == 4:
+        pad_output = True
 
     transforms = {
         k: v
@@ -214,7 +225,9 @@ def build(config, mode='train'):
             device_id,
             shard_id,
             num_shards,
-            seed=42 + shard_id)
+            seed=42 + shard_id,
+            pad_output=pad_output,
+            output_dtype=output_dtype)
         pipe.build()
         pipelines = [pipe]
         sample_per_shard = len(pipe) // num_shards
@@ -241,7 +254,9 @@ def build(config, mode='train'):
                     device_id,
                     idx,
                     num_shards,
-                    seed=42 + idx)
+                    seed=42 + idx,
+                    pad_output=pad_output,
+                    output_dtype=output_dtype)
                 pipe.build()
                 pipelines.append(pipe)
             sample_per_shard = len(pipelines[0])
@@ -264,7 +279,9 @@ def build(config, mode='train'):
             interp,
             mean,
             std,
-            device_id=device_id)
+            device_id=device_id,
+            pad_output=pad_output,
+            output_dtype=output_dtype)
         pipe.build()
         return DALIGenericIterator(
             pipe, ['feed_image', 'feed_label'],

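Two new knobs reach the DALI pipelines: output_dtype switches the normalized output to FLOAT16 for pure fp16 runs, and pad_output pads the RGB image to a 4th channel when image_shape is configured with 4 channels, a layout fp16 kernels handle more efficiently. A condensed restatement of the selection logic in build(); it assumes the same nvidia.dali package the pipelines already import:

```python
from nvidia.dali import types  # same DALI dependency the pipelines above use

def pick_pipeline_output(config):
    # fp16 pipeline output only when training in pure fp16
    output_dtype = (types.FLOAT16
                    if config.get("use_pure_fp16", False) else types.FLOAT)
    # pad RGB to 4 channels when the config asks for a 4-channel input tensor
    pad_output = False
    image_shape = config.get("image_shape", None)
    if image_shape and image_shape[0] == 4:
        pad_output = True
    return output_dtype, pad_output

print(pick_pipeline_output({"use_pure_fp16": True, "image_shape": [4, 224, 224]}))
```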
@@ -0,0 +1,171 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import paddle
+import paddle.fluid as fluid
+import paddle.regularizer as regularizer
+
+__all__ = ['OptimizerBuilder']
+
+
+class L1Decay(object):
+    """
+    L1 Weight Decay Regularization, which encourages the weights to be sparse.
+
+    Args:
+        factor(float): regularization coeff. Default:0.0.
+    """
+
+    def __init__(self, factor=0.0):
+        super(L1Decay, self).__init__()
+        self.factor = factor
+
+    def __call__(self):
+        reg = regularizer.L1Decay(self.factor)
+        return reg
+
+
+class L2Decay(object):
+    """
+    L2 Weight Decay Regularization, which encourages the weights to be sparse.
+
+    Args:
+        factor(float): regularization coeff. Default:0.0.
+    """
+
+    def __init__(self, factor=0.0):
+        super(L2Decay, self).__init__()
+        self.factor = factor
+
+    def __call__(self):
+        reg = regularizer.L2Decay(self.factor)
+        return reg
+
+
+class Momentum(object):
+    """
+    Simple Momentum optimizer with velocity state.
+
+    Args:
+        learning_rate (float|Variable) - The learning rate used to update parameters.
+            Can be a float value or a Variable with one float value as data element.
+        momentum (float) - Momentum factor.
+        regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
+    """
+
+    def __init__(self,
+                 learning_rate,
+                 momentum,
+                 parameter_list=None,
+                 regularization=None,
+                 config=None,
+                 **args):
+        super(Momentum, self).__init__()
+        self.learning_rate = learning_rate
+        self.momentum = momentum
+        self.parameter_list = parameter_list
+        self.regularization = regularization
+        self.multi_precision = config.get('multi_precision', False)
+        self.rescale_grad = (1.0 / (config['TRAIN']['batch_size'] / len(fluid.cuda_places()))
+                             if config.get('use_pure_fp16', False) else 1.0)
+
+    def __call__(self):
+        opt = fluid.contrib.optimizer.Momentum(
+            learning_rate=self.learning_rate,
+            momentum=self.momentum,
+            regularization=self.regularization,
+            multi_precision=self.multi_precision,
+            rescale_grad=self.rescale_grad)
+        return opt
+
+
+class RMSProp(object):
+    """
+    Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning rate method.
+
+    Args:
+        learning_rate (float|Variable) - The learning rate used to update parameters.
+            Can be a float value or a Variable with one float value as data element.
+        momentum (float) - Momentum factor.
+        rho (float) - rho value in equation.
+        epsilon (float) - avoid division by zero, default is 1e-6.
+        regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
+    """
+
+    def __init__(self,
+                 learning_rate,
+                 momentum,
+                 rho=0.95,
+                 epsilon=1e-6,
+                 parameter_list=None,
+                 regularization=None,
+                 **args):
+        super(RMSProp, self).__init__()
+        self.learning_rate = learning_rate
+        self.momentum = momentum
+        self.rho = rho
+        self.epsilon = epsilon
+        self.parameter_list = parameter_list
+        self.regularization = regularization
+
+    def __call__(self):
+        opt = paddle.optimizer.RMSProp(
+            learning_rate=self.learning_rate,
+            momentum=self.momentum,
+            rho=self.rho,
+            epsilon=self.epsilon,
+            parameters=self.parameter_list,
+            weight_decay=self.regularization)
+        return opt
+
+
+class OptimizerBuilder(object):
+    """
+    Build optimizer
+
+    Args:
+        function(str): optimizer name of learning rate
+        params(dict): parameters used for init the class
+        regularizer (dict): parameters used for create regularization
+    """
+
+    def __init__(self,
+                 config=None,
+                 function='Momentum',
+                 params={'momentum': 0.9},
+                 regularizer=None):
+        self.function = function
+        self.params = params
+        self.config = config
+        # create regularizer
+        if regularizer is not None:
+            mod = sys.modules[__name__]
+            reg_func = regularizer['function'] + 'Decay'
+            del regularizer['function']
+            reg = getattr(mod, reg_func)(**regularizer)()
+            self.params['regularization'] = reg
+
+    def __call__(self, learning_rate, parameter_list=None):
+        mod = sys.modules[__name__]
+        opt = getattr(mod, self.function)
+        return opt(learning_rate=learning_rate,
+                   parameter_list=parameter_list,
+                   config=self.config,
+                   **self.params)()

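program.create_optimizer (below) instantiates this builder as OptimizerBuilder(config, **opt_config) and then calls the result with a learning rate. A hedged usage sketch; the OPTIMIZER/TRAIN values are illustrative and, with use_pure_fp16 on, the rescale_grad computation requires a GPU build of Paddle because it divides by len(fluid.cuda_places()):

```python
from optimizer import OptimizerBuilder  # import used by the static program.py below

config = {
    "use_pure_fp16": True,          # illustrative values, not from a real yaml
    "multi_precision": True,        # keep fp32 master weights inside Momentum
    "TRAIN": {"batch_size": 256},   # used to derive rescale_grad per GPU
}
opt_config = {
    "function": "Momentum",
    "params": {"momentum": 0.9},
    "regularizer": {"function": "L2", "factor": 0.0001},
}

builder = OptimizerBuilder(config, **opt_config)
optimizer = builder(learning_rate=0.1)  # -> fluid.contrib.optimizer.Momentum(...)
```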
@@ -21,12 +21,14 @@ import time
 import numpy as np
 
 from collections import OrderedDict
+from optimizer import OptimizerBuilder
 
 import paddle
 import paddle.nn.functional as F
+from paddle import fluid
+from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16
 
 from ppcls.optimizer.learning_rate import LearningRateBuilder
-from ppcls.optimizer.optimizer import OptimizerBuilder
 from ppcls.modeling import architectures
 from ppcls.modeling.loss import CELoss
 from ppcls.modeling.loss import MixCELoss
@@ -39,7 +41,7 @@ from paddle.distributed import fleet
 from paddle.distributed.fleet import DistributedStrategy
 
 
-def create_feeds(image_shape, use_mix=None, use_dali=None):
+def create_feeds(image_shape, use_mix=None, use_dali=None, dtype="float32"):
     """
     Create feeds as model input
 
@@ -52,14 +54,14 @@ def create_feeds(image_shape, use_mix=None, use_dali=None):
     """
     feeds = OrderedDict()
     feeds['image'] = paddle.static.data(
-        name="feed_image", shape=[None] + image_shape, dtype="float32")
+        name="feed_image", shape=[None] + image_shape, dtype=dtype)
     if use_mix and not use_dali:
         feeds['feed_y_a'] = paddle.static.data(
             name="feed_y_a", shape=[None, 1], dtype="int64")
         feeds['feed_y_b'] = paddle.static.data(
             name="feed_y_b", shape=[None, 1], dtype="int64")
         feeds['feed_lam'] = paddle.static.data(
-            name="feed_lam", shape=[None, 1], dtype="float32")
+            name="feed_lam", shape=[None, 1], dtype=dtype)
     else:
         feeds['label'] = paddle.static.data(
             name="feed_label", shape=[None, 1], dtype="int64")
@@ -67,7 +69,7 @@ def create_feeds(image_shape, use_mix=None, use_dali=None):
     return feeds
 
 
-def create_model(architecture, image, classes_num, is_train):
+def create_model(architecture, image, classes_num, config, is_train):
     """
     Create a model
 
@@ -76,16 +78,33 @@ def create_model(architecture, image, classes_num, is_train):
             name(such as ResNet50) is needed
         image(variable): model input variable
         classes_num(int): num of classes
+        config(dict): model config
 
     Returns:
         out(variable): model output variable
     """
+    use_pure_fp16 = config.get("use_pure_fp16", False)
     name = architecture["name"]
     params = architecture.get("params", {})
+    data_format = config.get("data_format", "NCHW")
+    input_image_channel = config.get('image_shape', [3, 224, 224])[0]
    if "is_test" in params:
         params['is_test'] = not is_train
-    model = architectures.__dict__[name](class_dim=classes_num, **params)
+    model = architectures.__dict__[name](
+        class_dim=classes_num,
+        input_image_channel=input_image_channel,
+        data_format=data_format,
+        **params)
+
+    if use_pure_fp16 and not config.get("use_dali", False):
+        image = image.astype('float16')
+    if data_format == "NHWC":
+        image = paddle.tensor.transpose(image, [0, 2, 3, 1])
+        image.stop_gradient = True
     out = model(image)
+    if config.get("use_pure_fp16", False):
+        cast_model_to_fp16(paddle.static.default_main_program())
+        out = out.astype('float32')
     return out
 
 
@@ -95,7 +114,8 @@ def create_loss(out,
                 classes_num=1000,
                 epsilon=None,
                 use_mix=False,
-                use_distillation=False):
+                use_distillation=False,
+                use_pure_fp16=False):
     """
     Create a loss for optimization, such as:
         1. CrossEnotry loss
@@ -112,6 +132,7 @@ def create_loss(out,
         classes_num(int): num of classes
         epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
         use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
+        use_pure_fp16(bool): whether to use pure fp16 data as training parameter
 
     Returns:
         loss(variable): loss variable
@@ -136,10 +157,10 @@ def create_loss(out,
 
     if use_mix:
         loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
-        return loss(out, feed_y_a, feed_y_b, feed_lam)
+        return loss(out, feed_y_a, feed_y_b, feed_lam, use_pure_fp16)
     else:
         loss = CELoss(class_dim=classes_num, epsilon=epsilon)
-        return loss(out, target)
+        return loss(out, target, use_pure_fp16)
 
 
 def create_metric(out,
@@ -147,6 +168,7 @@ def create_metric(out,
                   architecture,
                   topk=5,
                   classes_num=1000,
+                  config=None,
                   use_distillation=False):
     """
     Create measures of model accuracy, such as top1 and top5
@@ -156,6 +178,7 @@ def create_metric(out,
         feeds(dict): dict of model input variables(included label)
         topk(int): usually top5
         classes_num(int): num of classes
+        config(dict) : model config
 
     Returns:
         fetchs(dict): dict of measures
@@ -189,6 +212,7 @@ def create_fetchs(out,
                   classes_num=1000,
                   epsilon=None,
                   use_mix=False,
+                  config=None,
                   use_distillation=False):
     """
     Create fetchs as model outputs(included loss and measures),
@@ -204,17 +228,19 @@ def create_fetchs(out,
         classes_num(int): num of classes
         epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
         use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
+        config(dict): model config
 
     Returns:
         fetchs(dict): dict of model outputs(included loss and measures)
     """
     fetchs = OrderedDict()
+    use_pure_fp16 = config.get("use_pure_fp16", False)
     loss = create_loss(out, feeds, architecture, classes_num, epsilon, use_mix,
-                       use_distillation)
+                       use_distillation, use_pure_fp16)
     fetchs['loss'] = (loss, AverageMeter('loss', '7.4f', need_avg=True))
     if not use_mix:
         metric = create_metric(out, feeds, architecture, topk, classes_num,
-                               use_distillation)
+                               config, use_distillation)
         fetchs.update(metric)
 
     return fetchs
@@ -254,7 +280,7 @@ def create_optimizer(config):
 
     # create optimizer instance
     opt_config = config['OPTIMIZER']
-    opt = OptimizerBuilder(**opt_config)
+    opt = OptimizerBuilder(config, **opt_config)
     return opt(lr), lr
 
 
@@ -283,13 +309,13 @@ def dist_optimizer(config, optimizer):
 
 
 def mixed_precision_optimizer(config, optimizer):
-    use_fp16 = config.get('use_fp16', False)
-    amp_scale_loss = config.get('amp_scale_loss', 1.0)
+    use_amp = config.get('use_amp', False)
+    scale_loss = config.get('scale_loss', 1.0)
     use_dynamic_loss_scaling = config.get('use_dynamic_loss_scaling', False)
-    if use_fp16:
+    if use_amp:
         optimizer = fluid.contrib.mixed_precision.decorate(
             optimizer,
-            init_loss_scaling=amp_scale_loss,
+            init_loss_scaling=scale_loss,
             use_dynamic_loss_scaling=use_dynamic_loss_scaling)
 
     return optimizer
@@ -320,13 +346,18 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
             use_mix = config.get('use_mix') and is_train
            use_dali = config.get('use_dali', False)
            use_distillation = config.get('use_distillation')
+
+            image_dtype = "float32"
+            if config["ARCHITECTURE"]["name"] == "ResNet50" and config.get("use_pure_fp16", False) \
+                and config.get("use_dali", False):
+                image_dtype = "float16"
             feeds = create_feeds(
-                config.image_shape, use_mix=use_mix, use_dali=use_dali)
+                config.image_shape, use_mix=use_mix, use_dali=use_dali, dtype = image_dtype)
             if use_dali and use_mix:
                 import dali
                 feeds = dali.mix(feeds, config, is_train)
             out = create_model(config.ARCHITECTURE, feeds['image'],
-                               config.classes_num, is_train)
+                               config.classes_num, config, is_train)
             fetchs = create_fetchs(
                 out,
                 feeds,
@@ -335,6 +366,7 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
                 config.classes_num,
                 epsilon=config.get('ls_epsilon'),
                 use_mix=use_mix,
+                config=config,
                 use_distillation=use_distillation)
             lr_scheduler = None
             if is_train:
@@ -342,7 +374,6 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
                 optimizer = mixed_precision_optimizer(config, optimizer)
                 if is_distributed:
                     optimizer = dist_optimizer(config, optimizer)
-
                 optimizer.minimize(fetchs['loss'][0])
     return fetchs, lr_scheduler, feeds
 
@@ -364,7 +395,40 @@ def compile(config, program, loss_name=None, share_prog=None):
     exec_strategy = paddle.static.ExecutionStrategy()
 
     exec_strategy.num_threads = 1
-    exec_strategy.num_iteration_per_drop_scope = 10
+    exec_strategy.num_iteration_per_drop_scope = 10000 if config.get('use_pure_fp16', False) else 10
 
+    fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16', False)
+    fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
+    fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
+    fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
+    enable_addto = config.get('enable_addto', fuse_op)
+
+    try:
+        build_strategy.fuse_bn_act_ops = fuse_bn_act_ops
+    except Exception as e:
+        logger.info(
+            "PaddlePaddle version 1.7.0 or higher is "
+            "required when you want to fuse batch_norm and activation_op.")
+
+    try:
+        build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
+    except Exception as e:
+        logger.info(
+            "PaddlePaddle version 1.7.0 or higher is "
+            "required when you want to fuse elewise_add_act and activation_op.")
+
+    try:
+        build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
+    except Exception as e:
+        logger.info(
+            "PaddlePaddle 2.0-rc or higher is "
+            "required when you want to enable fuse_bn_add_act_ops strategy.")
+
+    try:
+        build_strategy.enable_addto = enable_addto
+    except Exception as e:
+        logger.info("PaddlePaddle 2.0-rc or higher is "
+                    "required when you want to enable addto strategy.")
+
     compiled_program = paddle.static.CompiledProgram(
         program).with_data_parallel(

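mixed_precision_optimizer now reads the renamed keys (use_amp, scale_loss, use_dynamic_loss_scaling) and hands them to Paddle's static AMP decorator, while compile() turns the fuse passes on by default whenever use_amp or use_pure_fp16 is set. A minimal sketch of the decoration step, using the same fluid.contrib call the function makes; the optimizer construction here is illustrative:

```python
import paddle
import paddle.fluid as fluid

paddle.enable_static()

# Illustrative optimizer; in the commit it comes from create_optimizer(config).
optimizer = fluid.contrib.optimizer.Momentum(learning_rate=0.1, momentum=0.9)

config = {"use_amp": True, "scale_loss": 128.0, "use_dynamic_loss_scaling": True}
if config.get("use_amp", False):
    # Same call mixed_precision_optimizer performs above.
    optimizer = fluid.contrib.mixed_precision.decorate(
        optimizer,
        init_loss_scaling=config.get("scale_loss", 1.0),
        use_dynamic_loss_scaling=config.get("use_dynamic_loss_scaling", False))
```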
@@ -26,6 +26,8 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
 from sys import version_info
 
 import paddle
+import paddle.fluid as fluid
+from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16
 from paddle.distributed import fleet
 
 from ppcls.data import Reader
@@ -59,11 +61,26 @@ def parse_args():
 
 
 def main(args):
-    fleet.init(is_collective=True)
-
     config = get_config(args.config, overrides=args.override, show=True)
+    if config.get("is_distributed", True):
+        fleet.init(is_collective=True)
     # assign the place
     use_gpu = config.get("use_gpu", False)
+    assert use_gpu is True, "gpu must be true in static mode!"
+    place = paddle.set_device("gpu")
+
+    # amp related config
+    use_amp = config.get('use_amp', False)
+    use_pure_fp16 = config.get('use_pure_fp16', False)
+    if use_amp or use_pure_fp16:
+        AMP_RELATED_FLAGS_SETTING = {
+            'FLAGS_cudnn_exhaustive_search': 1,
+            'FLAGS_conv_workspace_size_limit': 4000,
+            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
+            'FLAGS_max_inplace_grad_add': 8,
+        }
+        os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'
+        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
     use_xpu = config.get("use_xpu", False)
     assert (
         use_gpu and use_xpu
@@ -105,10 +122,17 @@ def main(args):
     exe = paddle.static.Executor(place)
     # Parameter initialization
     exe.run(startup_prog)
 
+    if config.get("use_pure_fp16", False):
+        cast_parameters_to_fp16(place, train_prog, fluid.global_scope())
     # load pretrained models or checkpoints
     init_model(config, train_prog, exe)
-
+    if not config.get("is_distributed", True):
+        compiled_train_prog = program.compile(
+            config, train_prog, loss_name=train_fetchs["loss"][0].name)
+    else:
+        compiled_train_prog = train_prog
+
     if not config.get('use_dali', False):
         train_dataloader = Reader(config, 'train', places=place)()
         if config.validate and paddle.distributed.get_rank() == 0:
@@ -137,7 +161,7 @@ def main(args):
 
     for epoch_id in range(config.epochs):
         # 1. train with train dataset
-        program.run(train_dataloader, exe, train_prog, train_feeds,
+        program.run(train_dataloader, exe, compiled_train_prog, train_feeds,
                     train_fetchs, epoch_id, 'train', config, vdl_writer,
                     lr_scheduler)
         if paddle.distributed.get_rank() == 0:

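For pure fp16 the initialization order in main() matters: the startup program first creates the parameters in fp32, cast_parameters_to_fp16 then converts them in place, and only afterwards are checkpoints loaded and the program compiled. A compressed sketch of that ordering on a toy program; the fc layer is merely a stand-in for the real network built by program.build():

```python
import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16

paddle.enable_static()

# Toy stand-ins for train_prog / startup_prog from program.build().
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data("x", shape=[None, 8], dtype="float32")
    y = paddle.static.nn.fc(x, size=4)

if paddle.is_compiled_with_cuda():        # pure fp16 training targets GPUs
    place = paddle.CUDAPlace(0)
    exe = paddle.static.Executor(place)
    exe.run(startup_prog)                 # 1) initialize parameters in fp32
    cast_parameters_to_fp16(place, main_prog, fluid.global_scope())  # 2) cast in place
    # 3) only now load checkpoints (init_model) and compile, as main(args) does above
```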