new usage of amp training. (#564)
* new usage of amp training.
* change the usage of amp and pure fp16 training.
* modified code as reviews

parent: ba17052a54
commit: 4e43ec6995

@@ -11,21 +11,23 @@ validate: True
 valid_interval: 1
 epochs: 120
 topk: 5
 image_shape: [3, 224, 224]
-is_distributed: True
+is_distributed: False

-# mixed precision training
-use_amp: True
-use_pure_fp16: False
-multi_precision: False
-scale_loss: 128.0
-use_dynamic_loss_scaling: True
-data_format: "NCHW"
-image_shape: [3, 224, 224]
 use_dali: True
 use_gpu: True
+data_format: "NHWC"
+image_channel: &image_channel 4
+image_shape: [*image_channel, 224, 224]

 use_mix: False
 ls_epsilon: -1

+# mixed precision training
+AMP:
+    scale_loss: 128.0
+    use_dynamic_loss_scaling: True
+    use_pure_fp16: &use_pure_fp16 True
+
 LEARNING_RATE:
     function: 'Piecewise'
     params:
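
The `&` / `*` tokens above are YAML anchors and aliases: `&image_channel` and `&use_pure_fp16` define anchors, and the keys added further down in this diff (`multi_precision`, `output_fp16`, `channel_num`) alias them, so the whole pure-fp16 setup is toggled from one place. A minimal PyYAML sketch of how the anchors resolve (standalone illustration, not PaddleClas code; the key layout is simplified):

# Standalone PyYAML illustration; only the anchor/alias mechanics match the config above.
import yaml

snippet = (
    "image_channel: &image_channel 4\n"
    "AMP:\n"
    "    use_pure_fp16: &use_pure_fp16 True\n"
    "OPTIMIZER:\n"
    "    params:\n"
    "        multi_precision: *use_pure_fp16\n"
    "NormalizeImage:\n"
    "    output_fp16: *use_pure_fp16\n"
    "    channel_num: *image_channel\n"
)
cfg = yaml.safe_load(snippet)
assert cfg["OPTIMIZER"]["params"]["multi_precision"] is True
assert cfg["NormalizeImage"]["channel_num"] == 4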
|
@@ -37,6 +39,7 @@ OPTIMIZER:
     function: 'Momentum'
     params:
         momentum: 0.9
+        multi_precision: *use_pure_fp16
     regularizer:
         function: 'L2'
         factor: 0.000100
|
@@ -61,6 +64,8 @@ TRAIN:
             mean: [0.485, 0.456, 0.406]
             std: [0.229, 0.224, 0.225]
             order: ''
+            output_fp16: *use_pure_fp16
+            channel_num: *image_channel
         - ToCHWImage:

 VALID:

@@ -195,14 +195,18 @@ class NormalizeImage(object):
     """ normalize image such as substract mean, divide std
     """

-    def __init__(self, scale=None, mean=None, std=None, order='chw'):
+    def __init__(self, scale=None, mean=None, std=None, order='chw', output_fp16=False, channel_num=3):
         if isinstance(scale, str):
             scale = eval(scale)
+        assert channel_num in [3, 4], "channel number of input image should be set to 3 or 4."
+        self.channel_num = channel_num
+        self.output_dtype = 'float16' if output_fp16 else 'float32'
         self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
         self.order = order
         mean = mean if mean is not None else [0.485, 0.456, 0.406]
         std = std if std is not None else [0.229, 0.224, 0.225]

-        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+        shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
         self.mean = np.array(mean).reshape(shape).astype('float32')
         self.std = np.array(std).reshape(shape).astype('float32')

@@ -213,7 +217,16 @@ class NormalizeImage(object):

         assert isinstance(img,
                           np.ndarray), "invalid input 'img' in NormalizeImage"
-        return (img.astype('float32') * self.scale - self.mean) / self.std
+
+        img = (img.astype('float32') * self.scale - self.mean) / self.std
+
+        if self.channel_num == 4:
+            img_h = img.shape[1] if self.order == 'chw' else img.shape[0]
+            img_w = img.shape[2] if self.order == 'chw' else img.shape[1]
+            pad_zeros = np.zeros((1, img_h, img_w)) if self.order == 'chw' else np.zeros((img_h, img_w, 1))
+            img = (np.concatenate((img, pad_zeros), axis=0) if self.order == 'chw'
+                   else np.concatenate((img, pad_zeros), axis=2))
+        return img.astype(self.output_dtype)


 class ToCHWImage(object):
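
A rough usage sketch of the extended operator, assuming `NormalizeImage` as defined above is importable and is fed an HWC uint8 image before `ToCHWImage` (matching the TRAIN transforms earlier in this diff); the fourth channel is zero padding so the NHWC fp16 kernels see a channel count of 4:

import numpy as np

# assumes NormalizeImage from this diff is importable
op = NormalizeImage(
    scale=1.0 / 255.0,
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
    order='',            # '' -> HWC layout, matching `order: ''` in the config
    output_fp16=True,    # wired to &use_pure_fp16 in the yaml
    channel_num=4)       # wired to &image_channel; pads a zero channel

img = (np.random.rand(224, 224, 3) * 255).astype('uint8')
out = op(img)
print(out.shape, out.dtype)   # (224, 224, 4) float16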
|
@@ -277,14 +277,18 @@ class ResNet(nn.Layer):
             bias_attr=ParamAttr(name="fc_0.b_0"))

     def forward(self, inputs):
-        y = self.conv(inputs)
-        y = self.pool2d_max(y)
-        for block in self.block_list:
-            y = block(y)
-        y = self.pool2d_avg(y)
-        y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
-        y = self.out(y)
-        return y
+        with paddle.static.amp.fp16_guard():
+            if self.data_format == "NHWC":
+                inputs = paddle.tensor.transpose(inputs, [0, 2, 3, 1])
+                inputs.stop_gradient = True
+            y = self.conv(inputs)
+            y = self.pool2d_max(y)
+            for block in self.block_list:
+                y = block(y)
+            y = self.pool2d_avg(y)
+            y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
+            y = self.out(y)
+            return y


 def ResNet18(**args):
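
`fp16_guard` itself does not change any dtype; it only tags the ops created inside it. The cast to float16 happens when the optimizer is decorated with `use_pure_fp16=True, use_fp16_guard=True` (see `mixed_precision_optimizer` later in this diff) and `minimize` rewrites the guarded region. A minimal build-only sketch, assuming Paddle 2.x static graph with placeholder layer sizes:

import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
    with paddle.static.amp.fp16_guard():
        # only ops created inside this guard are rewritten to float16
        logits = paddle.static.nn.fc(x, size=10)
    loss = paddle.nn.functional.cross_entropy(logits, label)
    opt = paddle.static.amp.decorate(paddle.optimizer.Momentum(0.01, 0.9),
                                     use_pure_fp16=True, use_fp16_guard=True)
    opt.minimize(loss)   # this is where the guarded ops are cast to fp16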
|
@@ -42,17 +42,14 @@ class Loss(object):
         soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
         return soft_target

-    def _crossentropy(self, input, target, use_pure_fp16=False):
+    def _crossentropy(self, input, target):
         if self._label_smoothing:
             target = self._labelsmoothing(target)
             input = -F.log_softmax(input, axis=-1)
             cost = paddle.sum(target * input, axis=-1)
         else:
             cost = F.cross_entropy(input=input, label=target)
-        if use_pure_fp16:
-            avg_cost = paddle.sum(cost)
-        else:
-            avg_cost = paddle.mean(cost)
+        avg_cost = paddle.mean(cost)
         return avg_cost

     def _kldiv(self, input, target, name=None):

@@ -81,8 +78,8 @@ class CELoss(Loss):
     def __init__(self, class_dim=1000, epsilon=None):
         super(CELoss, self).__init__(class_dim, epsilon)

-    def __call__(self, input, target, use_pure_fp16=False):
-        cost = self._crossentropy(input, target, use_pure_fp16)
+    def __call__(self, input, target):
+        cost = self._crossentropy(input, target)
         return cost

@@ -94,14 +91,11 @@ class MixCELoss(Loss):
     def __init__(self, class_dim=1000, epsilon=None):
         super(MixCELoss, self).__init__(class_dim, epsilon)

-    def __call__(self, input, target0, target1, lam, use_pure_fp16=False):
-        cost0 = self._crossentropy(input, target0, use_pure_fp16)
-        cost1 = self._crossentropy(input, target1, use_pure_fp16)
+    def __call__(self, input, target0, target1, lam):
+        cost0 = self._crossentropy(input, target0)
+        cost1 = self._crossentropy(input, target1)
         cost = lam * cost0 + (1.0 - lam) * cost1
-        if use_pure_fp16:
-            avg_cost = paddle.sum(cost)
-        else:
-            avg_cost = paddle.mean(cost)
+        avg_cost = paddle.mean(cost)
         return avg_cost
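
With the pure-fp16 branch gone, both losses always reduce with `paddle.mean`. A hedged usage sketch, assuming `CELoss` from this module is importable (shapes and `class_dim` are illustrative only):

import paddle

# hedged sketch; assumes CELoss from this module is importable
loss_fn = CELoss(class_dim=1000)

logits = paddle.randn([8, 1000])
labels = paddle.randint(0, 1000, shape=[8, 1])
avg_cost = loss_fn(logits, labels)   # no use_pure_fp16 flag any more; always mean-reduced
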
@@ -74,19 +74,22 @@ class Momentum(object):
                  momentum,
                  parameter_list=None,
                  regularization=None,
+                 multi_precision=False,
                  **args):
         super(Momentum, self).__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.parameter_list = parameter_list
         self.regularization = regularization
+        self.multi_precision = multi_precision

     def __call__(self):
         opt = paddle.optimizer.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
             parameters=self.parameter_list,
-            weight_decay=self.regularization)
+            weight_decay=self.regularization,
+            multi_precision=self.multi_precision)
         return opt
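
`multi_precision=True` asks `paddle.optimizer.Momentum` to keep float32 master copies of float16 parameters and update those, which is what keeps pure-fp16 training numerically stable; the builder above simply threads the flag through from the config's `multi_precision: *use_pure_fp16` entry. A minimal sketch outside the builder (the layer and hyper-parameters are placeholders):

import paddle

# toy sketch; in PaddleClas the parameters are cast to fp16 later by amp_init
linear = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.Momentum(
    learning_rate=0.1,
    momentum=0.9,
    parameters=linear.parameters(),
    weight_decay=paddle.regularizer.L2Decay(1e-4),
    multi_precision=True)   # keep fp32 master weights for fp16 parameters
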
@@ -176,7 +176,11 @@ def build(config, mode='train'):
         2: types.INTERP_CUBIC,  # cv2.INTER_CUBIC
         4: types.INTERP_LANCZOS3,  # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
     }
-    output_dtype = types.FLOAT16 if config.get("use_pure_fp16", False) else types.FLOAT
+
+    output_dtype = (types.FLOAT16 if 'AMP' in config and
+                    config.AMP.get("use_pure_fp16", False)
+                    else types.FLOAT)
+
     assert interp in interp_map, "interpolation method not supported by DALI"
     interp = interp_map[interp]
     pad_output = False

@@ -1,171 +0,0 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-import paddle
-import paddle.fluid as fluid
-import paddle.regularizer as regularizer
-
-__all__ = ['OptimizerBuilder']
-
-
-class L1Decay(object):
-    """
-    L1 Weight Decay Regularization, which encourages the weights to be sparse.
-
-    Args:
-        factor(float): regularization coeff. Default:0.0.
-    """
-
-    def __init__(self, factor=0.0):
-        super(L1Decay, self).__init__()
-        self.factor = factor
-
-    def __call__(self):
-        reg = regularizer.L1Decay(self.factor)
-        return reg
-
-
-class L2Decay(object):
-    """
-    L2 Weight Decay Regularization, which encourages the weights to be sparse.
-
-    Args:
-        factor(float): regularization coeff. Default:0.0.
-    """
-
-    def __init__(self, factor=0.0):
-        super(L2Decay, self).__init__()
-        self.factor = factor
-
-    def __call__(self):
-        reg = regularizer.L2Decay(self.factor)
-        return reg
-
-
-class Momentum(object):
-    """
-    Simple Momentum optimizer with velocity state.
-
-    Args:
-        learning_rate (float|Variable) - The learning rate used to update parameters.
-            Can be a float value or a Variable with one float value as data element.
-        momentum (float) - Momentum factor.
-        regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
-    """
-
-    def __init__(self,
-                 learning_rate,
-                 momentum,
-                 parameter_list=None,
-                 regularization=None,
-                 config=None,
-                 **args):
-        super(Momentum, self).__init__()
-        self.learning_rate = learning_rate
-        self.momentum = momentum
-        self.parameter_list = parameter_list
-        self.regularization = regularization
-        self.multi_precision = config.get('multi_precision', False)
-        self.rescale_grad = (1.0 / (config['TRAIN']['batch_size'] / len(fluid.cuda_places()))
-                             if config.get('use_pure_fp16', False) else 1.0)
-
-    def __call__(self):
-        opt = fluid.contrib.optimizer.Momentum(
-            learning_rate=self.learning_rate,
-            momentum=self.momentum,
-            regularization=self.regularization,
-            multi_precision=self.multi_precision,
-            rescale_grad=self.rescale_grad)
-        return opt
-
-
-class RMSProp(object):
-    """
-    Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning rate method.
-
-    Args:
-        learning_rate (float|Variable) - The learning rate used to update parameters.
-            Can be a float value or a Variable with one float value as data element.
-        momentum (float) - Momentum factor.
-        rho (float) - rho value in equation.
-        epsilon (float) - avoid division by zero, default is 1e-6.
-        regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
-    """
-
-    def __init__(self,
-                 learning_rate,
-                 momentum,
-                 rho=0.95,
-                 epsilon=1e-6,
-                 parameter_list=None,
-                 regularization=None,
-                 **args):
-        super(RMSProp, self).__init__()
-        self.learning_rate = learning_rate
-        self.momentum = momentum
-        self.rho = rho
-        self.epsilon = epsilon
-        self.parameter_list = parameter_list
-        self.regularization = regularization
-
-    def __call__(self):
-        opt = paddle.optimizer.RMSProp(
-            learning_rate=self.learning_rate,
-            momentum=self.momentum,
-            rho=self.rho,
-            epsilon=self.epsilon,
-            parameters=self.parameter_list,
-            weight_decay=self.regularization)
-        return opt
-
-
-class OptimizerBuilder(object):
-    """
-    Build optimizer
-
-    Args:
-        function(str): optimizer name of learning rate
-        params(dict): parameters used for init the class
-        regularizer (dict): parameters used for create regularization
-    """
-
-    def __init__(self,
-                 config=None,
-                 function='Momentum',
-                 params={'momentum': 0.9},
-                 regularizer=None):
-        self.function = function
-        self.params = params
-        self.config = config
-        # create regularizer
-        if regularizer is not None:
-            mod = sys.modules[__name__]
-            reg_func = regularizer['function'] + 'Decay'
-            del regularizer['function']
-            reg = getattr(mod, reg_func)(**regularizer)()
-            self.params['regularization'] = reg
-
-    def __call__(self, learning_rate, parameter_list=None):
-        mod = sys.modules[__name__]
-        opt = getattr(mod, self.function)
-        return opt(learning_rate=learning_rate,
-                   parameter_list=parameter_list,
-                   config=self.config,
-                   **self.params)()

@@ -21,12 +21,10 @@ import time
 import numpy as np

 from collections import OrderedDict
-from optimizer import OptimizerBuilder
+from ppcls.optimizer import OptimizerBuilder

 import paddle
 import paddle.nn.functional as F
-from paddle import fluid
-from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16

 from ppcls.optimizer.learning_rate import LearningRateBuilder
 from ppcls.modeling import architectures

@@ -83,11 +81,9 @@ def create_model(architecture, image, classes_num, config, is_train):
     Returns:
         out(variable): model output variable
     """
-    use_pure_fp16 = config.get("use_pure_fp16", False)
     name = architecture["name"]
     params = architecture.get("params", {})

     data_format = "NCHW"
     if "data_format" in config:
         params["data_format"] = config["data_format"]
         data_format = config["data_format"]

@@ -100,16 +96,8 @@ def create_model(architecture, image, classes_num, config, is_train):
     if "is_test" in params:
         params['is_test'] = not is_train
     model = architectures.__dict__[name](class_dim=classes_num, **params)

-    if use_pure_fp16 and not config.get("use_dali", False):
-        image = image.astype('float16')
-    if data_format == "NHWC":
-        image = paddle.tensor.transpose(image, [0, 2, 3, 1])
-        image.stop_gradient = True
     out = model(image)
-    if config.get("use_pure_fp16", False):
-        cast_model_to_fp16(paddle.static.default_main_program())
-        out = out.astype('float32')
     return out

@@ -119,8 +107,7 @@ def create_loss(out,
                 classes_num=1000,
                 epsilon=None,
                 use_mix=False,
-                use_distillation=False,
-                use_pure_fp16=False):
+                use_distillation=False):
     """
     Create a loss for optimization, such as:
     1. CrossEnotry loss

@@ -137,7 +124,6 @@
         classes_num(int): num of classes
         epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
         use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
-        use_pure_fp16(bool): whether to use pure fp16 data as training parameter

     Returns:
         loss(variable): loss variable

@@ -162,10 +148,10 @@

     if use_mix:
         loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
-        return loss(out, feed_y_a, feed_y_b, feed_lam, use_pure_fp16)
+        return loss(out, feed_y_a, feed_y_b, feed_lam)
     else:
         loss = CELoss(class_dim=classes_num, epsilon=epsilon)
-        return loss(out, target, use_pure_fp16)
+        return loss(out, target)


 def create_metric(out,

@@ -239,9 +225,8 @@ def create_fetchs(out,
         fetchs(dict): dict of model outputs(included loss and measures)
     """
     fetchs = OrderedDict()
-    use_pure_fp16 = config.get("use_pure_fp16", False)
     loss = create_loss(out, feeds, architecture, classes_num, epsilon, use_mix,
-                       use_distillation, use_pure_fp16)
+                       use_distillation)
     fetchs['loss'] = (loss, AverageMeter('loss', '7.4f', need_avg=True))
     if not use_mix:
         metric = create_metric(out, feeds, architecture, topk, classes_num,

@@ -285,7 +270,7 @@ def create_optimizer(config):

     # create optimizer instance
     opt_config = config['OPTIMIZER']
-    opt = OptimizerBuilder(config, **opt_config)
+    opt = OptimizerBuilder(**opt_config)
     return opt(lr), lr

@@ -304,11 +289,11 @@ def create_strategy(config):
     exec_strategy = paddle.static.ExecutionStrategy()

     exec_strategy.num_threads = 1
-    exec_strategy.num_iteration_per_drop_scope = 10000 if config.get(
-        'use_pure_fp16', False) else 10
+    exec_strategy.num_iteration_per_drop_scope = (10000 if 'AMP' in config and
+                                                  config.AMP.get("use_pure_fp16", False) else 10)

-    fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16',
-                                                          False)
+    fuse_op = True if 'AMP' in config else False
+
     fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
     fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
     fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
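
These `fuse_*` values end up on the `paddle.static.BuildStrategy` used when the program is compiled; a rough sketch of the gating with a plain-dict stand-in for the parsed config (the attribute names are real `BuildStrategy` fields, the surrounding code is illustrative):

import paddle

config = {'AMP': {'use_pure_fp16': True}}    # stand-in for the parsed yaml config

fuse_op = 'AMP' in config                    # same gating as above
build_strategy = paddle.static.BuildStrategy()
build_strategy.fuse_bn_act_ops = fuse_op
build_strategy.fuse_elewise_add_act_ops = fuse_op
build_strategy.fuse_bn_add_act_ops = fuse_op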
|
@@ -369,14 +354,17 @@ def dist_optimizer(config, optimizer):


 def mixed_precision_optimizer(config, optimizer):
-    use_amp = config.get('use_amp', False)
-    scale_loss = config.get('scale_loss', 1.0)
-    use_dynamic_loss_scaling = config.get('use_dynamic_loss_scaling', False)
-    if use_amp:
-        optimizer = fluid.contrib.mixed_precision.decorate(
+    if 'AMP' in config:
+        amp_cfg = config.AMP if config.AMP else dict()
+        scale_loss = amp_cfg.get('scale_loss', 1.0)
+        use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling', False)
+        use_pure_fp16 = amp_cfg.get('use_pure_fp16', False)
+        optimizer = paddle.static.amp.decorate(
             optimizer,
             init_loss_scaling=scale_loss,
-            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+            use_dynamic_loss_scaling=use_dynamic_loss_scaling,
+            use_pure_fp16=use_pure_fp16,
+            use_fp16_guard=True)

     return optimizer
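
For reference, a condensed sketch of the same call in isolation, with the arguments mapped back to the AMP block of the yaml config shown at the top of this diff:

import paddle

paddle.enable_static()
base_opt = paddle.optimizer.Momentum(learning_rate=0.1, momentum=0.9)
opt = paddle.static.amp.decorate(
    base_opt,
    init_loss_scaling=128.0,          # AMP.scale_loss
    use_dynamic_loss_scaling=True,    # AMP.use_dynamic_loss_scaling
    use_pure_fp16=True,               # AMP.use_pure_fp16
    use_fp16_guard=True)              # honor paddle.static.amp.fp16_guard() regions
# opt.minimize(loss) is then called as usual inside a program guard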
|
@@ -407,15 +395,11 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
     use_dali = config.get('use_dali', False)
     use_distillation = config.get('use_distillation')

-    image_dtype = "float32"
-    if config["ARCHITECTURE"]["name"] == "ResNet50" and config.get("use_pure_fp16", False) \
-            and config.get("use_dali", False):
-        image_dtype = "float16"
     feeds = create_feeds(
         config.image_shape,
         use_mix=use_mix,
         use_dali=use_dali,
-        dtype=image_dtype)
+        dtype="float32")
     if use_dali and use_mix:
         import dali
         feeds = dali.mix(feeds, config, is_train)

@@ -432,13 +416,14 @@
             config=config,
             use_distillation=use_distillation)
     lr_scheduler = None
+    optimizer = None
     if is_train:
         optimizer, lr_scheduler = create_optimizer(config)
         optimizer = mixed_precision_optimizer(config, optimizer)
         if is_distributed:
             optimizer = dist_optimizer(config, optimizer)
         optimizer.minimize(fetchs['loss'][0])
-    return fetchs, lr_scheduler, feeds
+    return fetchs, lr_scheduler, feeds, optimizer


 def compile(config, program, loss_name=None, share_prog=None):

@@ -26,8 +26,6 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
 from sys import version_info

 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16
 from paddle.distributed import fleet

 from ppcls.data import Reader

@@ -67,9 +65,7 @@ def main(args):
     # assign the place
     use_gpu = config.get("use_gpu", True)
     # amp related config
-    use_amp = config.get('use_amp', False)
-    use_pure_fp16 = config.get('use_pure_fp16', False)
-    if use_amp or use_pure_fp16:
+    if 'AMP' in config:
         AMP_RELATED_FLAGS_SETTING = {
             'FLAGS_cudnn_exhaustive_search': 1,
             'FLAGS_conv_workspace_size_limit': 4000,

@@ -97,7 +93,7 @@

     best_top1_acc = 0.0  # best top1 acc record

-    train_fetchs, lr_scheduler, train_feeds = program.build(
+    train_fetchs, lr_scheduler, train_feeds, optimizer = program.build(
         config,
         train_prog,
         startup_prog,

@@ -106,7 +102,7 @@

     if config.validate:
         valid_prog = paddle.static.Program()
-        valid_fetchs, _, valid_feeds = program.build(
+        valid_fetchs, _, valid_feeds, _ = program.build(
             config,
             valid_prog,
             startup_prog,

@@ -119,11 +115,14 @@
     exe = paddle.static.Executor(place)
     # Parameter initialization
     exe.run(startup_prog)
-    if config.get("use_pure_fp16", False):
-        cast_parameters_to_fp16(place, train_prog, fluid.global_scope())
     # load pretrained models or checkpoints
     init_model(config, train_prog, exe)

+    if 'AMP' in config and config.AMP.get("use_pure_fp16", False):
+        optimizer.amp_init(place,
+                           scope=paddle.static.global_scope(),
+                           test_program=valid_prog if config.validate else None)
+
     if not config.get("is_distributed", True) and not use_xpu:
         compiled_train_prog = program.compile(
             config, train_prog, loss_name=train_fetchs["loss"][0].name)
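
`amp_init` has to run after `exe.run(startup_prog)` creates the float32 parameters and after `init_model` loads any pretrained weights, because it is the step that casts those parameters to float16 (and, via `test_program`, patches the evaluation program as well). A hedged end-to-end sketch of that ordering, assuming a CUDA device and Paddle 2.x static graph; the tiny network is a placeholder, not the ResNet used here:

import numpy as np
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
    with paddle.static.amp.fp16_guard():
        logits = paddle.static.nn.fc(x, size=10)
    loss = paddle.nn.functional.cross_entropy(logits, label)
    opt = paddle.static.amp.decorate(paddle.optimizer.Momentum(0.01, 0.9),
                                     init_loss_scaling=128.0,
                                     use_dynamic_loss_scaling=True,
                                     use_pure_fp16=True,
                                     use_fp16_guard=True)
    opt.minimize(loss)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_prog)                 # 1. parameters are created in fp32
# 2. a real job loads pretrained weights / checkpoints here (init_model above)
opt.amp_init(place, scope=paddle.static.global_scope())   # 3. cast parameters to fp16
loss_val, = exe.run(main_prog,
                    feed={'x': np.random.rand(4, 16).astype('float32'),
                          'label': np.random.randint(0, 10, (4, 1)).astype('int64')},
                    fetch_list=[loss])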