Merge pull request #1585 from zhangbo9674/dev/resnet50_optimize
Accelerate dynamic graph amp training
commit f45f9ee4d4
@@ -250,6 +250,8 @@ class Engine(object):
             self.scaler = paddle.amp.GradScaler(
                 init_loss_scaling=self.scale_loss,
                 use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
+            if self.config['AMP']['use_pure_fp16'] is True:
+                self.model = paddle.amp.decorate(models=self.model, level='O2')
 
         self.max_iter = len(self.train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(self.train_dataloader)
@@ -21,6 +21,7 @@ from ppcls.utils import profiler
 
 def train_epoch(engine, epoch_id, print_batch_step):
     tic = time.time()
+    v_current = [int(i) for i in paddle.__version__.split(".")]
     for iter_id, batch in enumerate(engine.train_dataloader):
         if iter_id >= engine.max_iter:
             break
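The added v_current line converts paddle.__version__ (e.g. "2.2.1") into a list of ints so later code can branch on the installed version. A small sketch of that pattern; the 2.2.0 threshold below is purely illustrative.

```python
import paddle

# "2.2.1" -> [2, 2, 1]; lists of ints compare element-wise.
v_current = [int(i) for i in paddle.__version__.split(".")]

# Hypothetical gate on a feature that only newer releases provide.
if v_current >= [2, 2, 0]:
    print("use the newer AMP API")
else:
    print("fall back to the older path")
```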
@@ -41,14 +42,15 @@ def train_epoch(engine, epoch_id, print_batch_step):
 
         # image input
         if engine.amp:
-            with paddle.amp.auto_cast(custom_black_list={
-                    "flatten_contiguous_range", "greater_than"
-            }):
+            amp_level = 'O1'
+            if engine.config['AMP']['use_pure_fp16'] is True:
+                amp_level = 'O2'
+            with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
                 out = forward(engine, batch)
                 loss_dict = engine.train_loss_func(out, batch[1])
         else:
             out = forward(engine, batch)
             loss_dict = engine.train_loss_func(out, batch[1])
 
         # step opt and lr
         if engine.amp:
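Putting the new amp_level selection in context: a minimal sketch of one dynamic-graph training step with auto_cast plus loss scaling, roughly what the engine does around this hunk. The toy model, data, and hyperparameters are stand-ins, not PaddleClas code.

```python
import paddle
import paddle.nn.functional as F

model = paddle.nn.Linear(8, 2)                 # stand-in for engine.model
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.1, momentum=0.9,
    parameters=model.parameters(), multi_precision=True)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

use_pure_fp16 = False                          # would come from config['AMP']
amp_level = 'O2' if use_pure_fp16 else 'O1'
if amp_level == 'O2':
    model = paddle.amp.decorate(models=model, level='O2')

x = paddle.randn([4, 8])
label = paddle.randint(0, 2, [4, 1])

# Forward under auto_cast; the black list keeps fp16-unfriendly ops in fp32.
with paddle.amp.auto_cast(
        custom_black_list={"flatten_contiguous_range", "greater_than"},
        level=amp_level):
    out = model(x)
    loss = F.cross_entropy(out, label)

# Scale the loss, backprop, then let the scaler drive the optimizer step.
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()
```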
@@ -17,6 +17,7 @@ from __future__ import division
 from __future__ import print_function
 
 from paddle import optimizer as optim
+import paddle
 
 from ppcls.utils import logger
@@ -36,7 +37,7 @@ class Momentum(object):
                  momentum,
                  weight_decay=None,
                  grad_clip=None,
-                 multi_precision=False):
+                 multi_precision=True):
         super().__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
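Why the default flips to multi_precision=True: in pure-fp16 training the optimizer keeps fp32 master copies of fp16 parameters, so small updates are not rounded away; for fp32 parameters the flag is harmless. A minimal sketch with a toy layer (not PaddleClas code):

```python
import paddle

layer = paddle.nn.Linear(16, 4)          # toy parameters

# multi_precision=True makes Momentum maintain fp32 master weights for any
# fp16 parameters, which keeps O2 (pure fp16) training numerically stable.
opt = paddle.optimizer.Momentum(
    learning_rate=0.01,
    momentum=0.9,
    parameters=layer.parameters(),
    multi_precision=True)
```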
@@ -55,6 +56,15 @@ class Momentum(object):
             grad_clip=self.grad_clip,
             multi_precision=self.multi_precision,
             parameters=parameters)
+        if hasattr(opt, '_use_multi_tensor'):
+            opt = optim.Momentum(
+                learning_rate=self.learning_rate,
+                momentum=self.momentum,
+                weight_decay=self.weight_decay,
+                grad_clip=self.grad_clip,
+                multi_precision=self.multi_precision,
+                parameters=parameters,
+                use_multi_tensor=True)
         return opt
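The hasattr(opt, '_use_multi_tensor') check is a compatibility guard: only when the installed Paddle exposes the fused multi-tensor Momentum is the optimizer rebuilt with use_multi_tensor=True, which updates all parameters in far fewer kernel launches. A small sketch of the same pattern outside the PaddleClas wrapper; the toy layer and hyperparameters are placeholders.

```python
import paddle
from paddle import optimizer as optim

layer = paddle.nn.Linear(8, 8)
kwargs = dict(
    learning_rate=0.1,
    momentum=0.9,
    parameters=layer.parameters(),
    multi_precision=True)

opt = optim.Momentum(**kwargs)
# Detect the fused implementation via its private attribute instead of
# parsing version strings; rebuild with it enabled when available.
if hasattr(opt, '_use_multi_tensor'):
    opt = optim.Momentum(use_multi_tensor=True, **kwargs)
```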