refactor: change params to be consistent with amp
parent d6d5efe055
commit 7040ce8314
@@ -22,7 +22,8 @@ Global:
 AMP:
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
-  use_pure_fp16: &use_pure_fp16 False
+  # O1: mixed fp16
+  level: O1

 # model architecture
 Arch:
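Note: these configs drop the boolean `use_pure_fp16` (and its YAML anchor) in favor of a single `level` key. A minimal sketch of the mapping between the old flag and the new key, using a hypothetical helper that is not part of this commit:

    def amp_level_from_legacy(amp_cfg):
        # New-style configs carry an explicit level; fall back to the old flag.
        if "level" in amp_cfg:
            return amp_cfg["level"]
        return "O2" if amp_cfg.get("use_pure_fp16", False) else "O1"

    assert amp_level_from_legacy({"use_pure_fp16": False}) == "O1"  # mixed fp16
    assert amp_level_from_legacy({"use_pure_fp16": True}) == "O2"   # pure fp16
    assert amp_level_from_legacy({"level": "O2"}) == "O2"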
@@ -44,6 +45,7 @@ Loss:
 Optimizer:
   name: Momentum
   momentum: 0.9
+  multi_precision: True
   lr:
     name: Piecewise
     learning_rate: 0.1
@@ -74,12 +76,11 @@ DataLoader:
             mean: [0.485, 0.456, 0.406]
             std: [0.229, 0.224, 0.225]
             order: ''
-            output_fp16: *use_pure_fp16
             channel_num: *image_channel

     sampler:
       name: DistributedBatchSampler
-      batch_size: 256
+      batch_size: 64
       drop_last: False
       shuffle: True
     loader:
@@ -104,7 +105,6 @@ DataLoader:
             mean: [0.485, 0.456, 0.406]
             std: [0.229, 0.224, 0.225]
             order: ''
-            output_fp16: *use_pure_fp16
             channel_num: *image_channel
     sampler:
       name: DistributedBatchSampler
@@ -131,7 +131,6 @@ Infer:
         mean: [0.485, 0.456, 0.406]
         std: [0.229, 0.224, 0.225]
         order: ''
-        output_fp16: *use_pure_fp16
         channel_num: *image_channel
     - ToCHWImage:
 PostProcess:
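With the `&use_pure_fp16` anchor gone, every `output_fp16: *use_pure_fp16` reference in this O1 config is removed as well, so normalization falls back to emitting float32. A rough, hypothetical sketch of what the flag controls (assumed behavior for illustration, not PaddleClas source):

    import numpy as np

    def normalize_image(img, mean, std, output_fp16=False):
        # Normalize an HWC uint8 image, then pick the output dtype:
        # float16 feeds half-precision tensors straight into the loader.
        out = (img.astype("float32") / 255.0 - np.asarray(mean)) / np.asarray(std)
        return out.astype("float16" if output_fp16 else "float32")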
@@ -10,8 +10,8 @@ Global:
   epochs: 120
   print_batch_step: 10
   use_visualdl: False
+  # used for static mode and model export
   image_channel: &image_channel 4
-  # used for static mode and model export
   image_shape: [*image_channel, 224, 224]
   save_inference_dir: ./inference
   # training model under @to_static
@@ -22,7 +22,8 @@ Global:
 AMP:
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
-  use_pure_fp16: &use_pure_fp16 True
+  # O2: pure fp16
+  level: O2

 # model architecture
 Arch:
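For context (my reading of the paddle.amp docs, not text from this commit): under O1 only list-selected ops run in float16 and parameters stay float32; under O2 parameters and most ops are float16, with float32 master weights kept by the optimizer. A minimal illustration, meaningful on a GPU device:

    import paddle

    x = paddle.randn([4, 4], dtype="float32")
    with paddle.amp.auto_cast(level="O1"):
        y = paddle.matmul(x, x)   # matmul is fp16-safe, runs in float16 on GPU
    print(y.dtype)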
@@ -43,7 +44,7 @@ Loss:
 Optimizer:
   name: Momentum
   momentum: 0.9
-  multi_precision: *use_pure_fp16
+  multi_precision: True
   lr:
     name: Piecewise
     learning_rate: 0.1
@@ -74,7 +75,7 @@ DataLoader:
             mean: [0.485, 0.456, 0.406]
             std: [0.229, 0.224, 0.225]
             order: ''
-            output_fp16: *use_pure_fp16
+            output_fp16: True
             channel_num: *image_channel

     sampler:
@@ -104,7 +105,7 @@ DataLoader:
             mean: [0.485, 0.456, 0.406]
             std: [0.229, 0.224, 0.225]
             order: ''
-            output_fp16: *use_pure_fp16
+            output_fp16: True
             channel_num: *image_channel
     sampler:
       name: DistributedBatchSampler
@@ -131,7 +132,7 @@ Infer:
         mean: [0.485, 0.456, 0.406]
         std: [0.229, 0.224, 0.225]
         order: ''
-        output_fp16: *use_pure_fp16
+        output_fp16: True
         channel_num: *image_channel
     - ToCHWImage:
 PostProcess:
@@ -35,11 +35,13 @@ Loss:
 AMP:
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
-  use_pure_fp16: &use_pure_fp16 True
+  # O2: pure fp16
+  level: O2

 Optimizer:
   name: Momentum
   momentum: 0.9
+  multi_precision: True
   lr:
     name: Cosine
     learning_rate: 0.1
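The configs now pin `multi_precision: True` instead of tying it to the removed anchor. In Paddle's Momentum optimizer this keeps a float32 master copy of each float16 parameter, which O2 training relies on. A standalone sketch, not taken from this commit:

    import paddle

    model = paddle.nn.Linear(8, 2)
    opt = paddle.optimizer.Momentum(
        learning_rate=0.1,
        momentum=0.9,
        parameters=model.parameters(),
        multi_precision=True)  # fp32 master weights for fp16 params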
@@ -67,7 +69,7 @@ DataLoader:
             mean: [0.485, 0.456, 0.406]
             std: [0.229, 0.224, 0.225]
             order: ''
-            output_fp16: *use_pure_fp16
+            output_fp16: True
             channel_num: *image_channel
     sampler:
       name: DistributedBatchSampler
@@ -96,7 +98,7 @@ DataLoader:
             mean: [0.485, 0.456, 0.406]
             std: [0.229, 0.224, 0.225]
             order: ''
-            output_fp16: *use_pure_fp16
+            output_fp16: True
             channel_num: *image_channel
     sampler:
       name: BatchSampler
@@ -123,7 +125,7 @@ Infer:
         mean: [0.485, 0.456, 0.406]
         std: [0.229, 0.224, 0.225]
         order: ''
-        output_fp16: *use_pure_fp16
+        output_fp16: True
         channel_num: *image_channel
     - ToCHWImage:
 PostProcess:
@@ -217,8 +217,14 @@ class Engine(object):
             self.scaler = paddle.amp.GradScaler(
                 init_loss_scaling=self.scale_loss,
                 use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
-            if self.config['AMP']['use_pure_fp16'] is True:
-                self.model = paddle.amp.decorate(models=self.model, level='O2', save_dtype='float32')
+            amp_level = self.config['AMP'].get("level", "O1")
+            if amp_level not in ["O1", "O2"]:
+                msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
+                logger.warning(msg)
+                self.config['AMP']["level"] = "O1"
+                amp_level = "O1"
+            self.model = paddle.amp.decorate(
+                models=self.model, level=amp_level, save_dtype='float32')

         # for distributed
         self.config["Global"][
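The engine now validates `level` once, warns and falls back to O1 on anything else, and always calls `paddle.amp.decorate` (under O1 the decoration leaves parameters untouched). Minimal standalone usage of the call this hunk generalizes:

    import paddle

    model = paddle.nn.Linear(8, 2)
    # Under O2 parameters are cast to float16; save_dtype='float32' ensures
    # checkpoints are still written as fp32.
    model = paddle.amp.decorate(models=model, level="O2", save_dtype="float32")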
@@ -59,10 +59,12 @@ def classification_eval(engine, epoch_id=0):

         # image input
         if engine.amp:
-            amp_level = 'O1'
-            if engine.config['AMP']['use_pure_fp16'] is True:
-                amp_level = 'O2'
-            with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
+            amp_level = engine.config['AMP'].get("level", "O1").upper()
+            with paddle.amp.auto_cast(
+                    custom_black_list={
+                        "flatten_contiguous_range", "greater_than"
+                    },
+                    level=amp_level):
                 out = engine.model(batch[0])
                 # calc loss
                 if engine.eval_loss_func is not None:
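Note the added `.upper()`, which tolerates lowercase `o1`/`o2` in configs. A compact standalone version of the eval-side pattern (no GradScaler is needed because evaluation never calls backward); `amp_cfg` is a stand-in for `engine.config['AMP']`:

    import paddle

    def eval_forward(model, images, amp_cfg):
        amp_level = amp_cfg.get("level", "O1").upper()
        with paddle.amp.auto_cast(
                custom_black_list={"flatten_contiguous_range", "greater_than"},
                level=amp_level):
            return model(images)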
@@ -70,7 +72,8 @@ def classification_eval(engine, epoch_id=0):
                     for key in loss_dict:
                         if key not in output_info:
                             output_info[key] = AverageMeter(key, '7.5f')
-                        output_info[key].update(loss_dict[key].numpy()[0], batch_size)
+                        output_info[key].update(loss_dict[key].numpy()[0],
+                                                batch_size)
         else:
             out = engine.model(batch[0])
             # calc loss
@@ -79,7 +82,8 @@ def classification_eval(engine, epoch_id=0):
                 for key in loss_dict:
                     if key not in output_info:
                         output_info[key] = AverageMeter(key, '7.5f')
-                    output_info[key].update(loss_dict[key].numpy()[0], batch_size)
+                    output_info[key].update(loss_dict[key].numpy()[0],
+                                            batch_size)

         # just for DistributedBatchSampler issue: repeat sampling
         current_samples = batch_size * paddle.distributed.get_world_size()
@@ -42,10 +42,12 @@ def train_epoch(engine, epoch_id, print_batch_step):

         # image input
         if engine.amp:
-            amp_level = 'O1'
-            if engine.config['AMP']['use_pure_fp16'] is True:
-                amp_level = 'O2'
-            with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
+            amp_level = engine.config['AMP'].get("level", "O1").upper()
+            with paddle.amp.auto_cast(
+                    custom_black_list={
+                        "flatten_contiguous_range", "greater_than"
+                    },
+                    level=amp_level):
                 out = forward(engine, batch)
                 loss_dict = engine.train_loss_func(out, batch[1])
         else:
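The training loop gets the same treatment. For reference, a self-contained sketch of the full dynamic-graph AMP step around this hunk (names like `model`, `opt`, and `loss_fn` are assumptions, not engine code):

    import paddle

    model = paddle.nn.Linear(8, 2)
    opt = paddle.optimizer.Momentum(parameters=model.parameters())
    loss_fn = paddle.nn.CrossEntropyLoss()
    scaler = paddle.amp.GradScaler(init_loss_scaling=128.0,
                                   use_dynamic_loss_scaling=True)

    def train_step(x, y, amp_level="O1"):
        with paddle.amp.auto_cast(level=amp_level):
            loss = loss_fn(model(x), y)
        scaled = scaler.scale(loss)       # scale loss to avoid fp16 grad underflow
        scaled.backward()
        scaler.minimize(opt, scaled)      # unscale grads, step, update loss scale
        opt.clear_grad()
        return loss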
@@ -158,7 +158,7 @@ def create_strategy(config):
     exec_strategy.num_threads = 1
     exec_strategy.num_iteration_per_drop_scope = (
         10000
-        if 'AMP' in config and config.AMP.get("use_pure_fp16", False) else 10)
+        if 'AMP' in config and config.AMP.get("level", "O1") == "O2" else 10)

     fuse_op = True if 'AMP' in config else False
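Same substitution in the static-graph strategy: "pure fp16" is now detected as `level == "O2"`. Extracted as a plain-dict helper for clarity (the real config object supports attribute access, hence `config.AMP` above):

    def is_pure_fp16(config):
        return "AMP" in config and config["AMP"].get("level", "O1") == "O2"

    assert is_pure_fp16({"AMP": {"level": "O2"}})
    assert not is_pure_fp16({"AMP": {}})    # defaults to O1
    assert not is_pure_fp16({})             # AMP disabled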
@@ -206,7 +206,7 @@ def mixed_precision_optimizer(config, optimizer):
         scale_loss = amp_cfg.get('scale_loss', 1.0)
         use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling',
                                                False)
-        use_pure_fp16 = amp_cfg.get('use_pure_fp16', False)
+        use_pure_fp16 = amp_cfg.get("level", "O1") == "O2"
         optimizer = paddle.static.amp.decorate(
             optimizer,
             init_loss_scaling=scale_loss,
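Standalone usage of the static-graph decorator this feeds into, with values mirroring the configs above (a sketch, assuming Paddle's static-mode API):

    import paddle

    paddle.enable_static()
    opt = paddle.optimizer.Momentum(learning_rate=0.1, momentum=0.9)
    opt = paddle.static.amp.decorate(
        opt,
        init_loss_scaling=128.0,
        use_dynamic_loss_scaling=True,
        use_pure_fp16=False)  # i.e. level O1; True corresponds to O2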
@@ -1,11 +1,8 @@
 #!/usr/bin/env bash

-export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
-export FLAGS_fraction_of_gpu_memory_to_use=0.80
+export CUDA_VISIBLE_DEVICES="0,1,2,3"

 python3.7 -m paddle.distributed.launch \
-    --gpus="0,1,2,3,4,5,6,7" \
+    --gpus="0,1,2,3" \
     ppcls/static/train.py \
-        -c ./ppcls/configs/ImageNet/ResNet/ResNet50_fp16.yaml \
-        -o Global.use_dali=True
-
+        -c ./ppcls/configs/ImageNet/ResNet/ResNet50_amp_O1.yaml
@@ -158,7 +158,7 @@ def main(args):
     # load pretrained models or checkpoints
     init_model(global_config, train_prog, exe)

-    if 'AMP' in config and config.AMP.get("use_pure_fp16", False):
+    if 'AMP' in config and config.AMP.get("level", "O1") == "O2":
         optimizer.amp_init(
             device,
             scope=paddle.static.global_scope(),
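For O2 static runs, `amp_init` casts the program's fp32 parameters to fp16 once after `init_model`, before the first step; the level check now gates that call. The full call pattern, with `device` and `test_program` assumed from the surrounding train.py rather than shown in this hunk:

    if 'AMP' in config and config.AMP.get("level", "O1") == "O2":
        optimizer.amp_init(
            device,
            scope=paddle.static.global_scope(),
            test_program=test_program)  # test_program is an assumed name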