Mirror of https://github.com/PaddlePaddle/PaddleClas.git
Accelerate dynamic graph amp training

commit b54ee04491
parent 7732a69f1b
@@ -250,6 +250,8 @@ class Engine(object):
             self.scaler = paddle.amp.GradScaler(
                 init_loss_scaling=self.scale_loss,
                 use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
+            if self.config['AMP']['use_pure_fp16'] is True:
+                self.model = paddle.amp.decorate(models=self.model, level='O2')
 
         self.max_iter = len(self.train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(self.train_dataloader)
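The hunk above enables pure-FP16 ("O2") training in the dynamic graph by decorating the model right after the GradScaler is created. As a rough standalone illustration of that setup, assuming a toy Linear model and made-up loss-scaling settings rather than the PaddleClas Engine and its config:

import paddle

# A rough sketch of the AMP setup step, with a toy model and made-up
# loss-scaling settings (not the PaddleClas Engine or its config).
model = paddle.nn.Linear(16, 10)

scale_loss = 128.0                # hypothetical init_loss_scaling value
use_dynamic_loss_scaling = True
use_pure_fp16 = True              # stand-in for config['AMP']['use_pure_fp16']

scaler = paddle.amp.GradScaler(
    init_loss_scaling=scale_loss,
    use_dynamic_loss_scaling=use_dynamic_loss_scaling)

if use_pure_fp16:
    # Level 'O2' casts the model to FP16 up front ("pure fp16"); level 'O1'
    # would only auto-cast selected ops inside paddle.amp.auto_cast.
    model = paddle.amp.decorate(models=model, level='O2')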
@@ -41,13 +41,14 @@ def train_epoch(engine, epoch_id, print_batch_step):
 
         # image input
         if engine.amp:
-            with paddle.amp.auto_cast(custom_black_list={
-                    "flatten_contiguous_range", "greater_than"
-            }):
+            amp_level = 'O1'
+            if engine.config['AMP']['use_pure_fp16'] is True:
+                amp_level = 'O2'
+            with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
                 out = forward(engine, batch)
                 loss_dict = engine.train_loss_func(out, batch[1])
         else:
             out = forward(engine, batch)
             loss_dict = engine.train_loss_func(out, batch[1])
 
         # step opt and lr
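The rewritten forward pass picks the auto_cast level from the config instead of always running the default O1 behaviour. The sketch below mirrors that pattern in isolation; the model, batch, and use_pure_fp16 flag are placeholders, not PaddleClas code:

import paddle

# Placeholder model, batch and config flag; only the auto_cast pattern
# mirrors the hunk above.
model = paddle.nn.Linear(16, 10)
loss_fn = paddle.nn.CrossEntropyLoss()

x = paddle.randn([4, 16])          # dummy "image" batch
y = paddle.randint(0, 10, [4])     # dummy labels
use_pure_fp16 = False              # stand-in for engine.config['AMP']['use_pure_fp16']

amp_level = 'O2' if use_pure_fp16 else 'O1'
with paddle.amp.auto_cast(
        custom_black_list={"flatten_contiguous_range", "greater_than"},
        level=amp_level):
    # Forward pass and loss are both computed under auto_cast, as in the
    # new version of train_epoch.
    out = model(x)
    loss = loss_fn(out, y)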
@@ -58,7 +59,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
         else:
             loss_dict["loss"].backward()
             engine.optimizer.step()
-        engine.optimizer.clear_grad()
+        engine.optimizer.clear_grad(set_to_zero=True)
         engine.lr_sch.step()
 
         # below code just for logging
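clear_grad(set_to_zero=True) zeroes the gradient buffers in place instead of releasing them, so they can be reused on the next iteration. A rough, self-contained "O1" AMP training step showing where that call sits relative to the scaler might look like the following; the scaler.minimize call and all hyper-parameters are illustrative assumptions, since the surrounding PaddleClas loop is not part of this diff:

import paddle

# A rough, self-contained AMP (level 'O1') training step; the model, data
# and hyper-parameters are placeholders, not the PaddleClas configuration.
model = paddle.nn.Linear(16, 10)
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.1, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=128.0)

x = paddle.randn([4, 16])
y = paddle.randint(0, 10, [4])

with paddle.amp.auto_cast(level='O1'):
    loss = paddle.nn.functional.cross_entropy(model(x), y)

scaled = scaler.scale(loss)         # scale the loss to avoid FP16 underflow
scaled.backward()                   # gradients carry the same scale factor
scaler.minimize(optimizer, scaled)  # unscale, step (or skip), update the scale

# Zero the gradient buffers in place instead of freeing them, so they can
# be reused on the next iteration.
optimizer.clear_grad(set_to_zero=True)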
@@ -36,13 +36,15 @@ class Momentum(object):
                  momentum,
                  weight_decay=None,
                  grad_clip=None,
-                 multi_precision=False):
+                 multi_precision=True,
+                 use_multi_tensor=True):
         super().__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
         self.multi_precision = multi_precision
+        self.use_multi_tensor = use_multi_tensor
 
     def __call__(self, model_list):
         # model_list is None in static graph
@@ -54,6 +56,7 @@ class Momentum(object):
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
             multi_precision=self.multi_precision,
+            use_multi_tensor=self.use_multi_tensor,
             parameters=parameters)
         return opt
 
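These two hunks switch the wrapper's default to multi_precision=True (keep FP32 master weights for FP16 parameters during the update) and forward a new use_multi_tensor flag to paddle.optimizer.Momentum, which lets Paddle batch the per-parameter momentum updates into fused kernels. A minimal construction sketch, assuming a Paddle version whose Momentum accepts use_multi_tensor (as the change above requires) and using placeholder hyper-parameters:

import paddle

# Minimal sketch of the optimizer construction; hyper-parameters are
# placeholders, and use_multi_tensor is assumed to exist in this Paddle version.
model = paddle.nn.Linear(16, 10)

opt = paddle.optimizer.Momentum(
    learning_rate=0.1,
    momentum=0.9,
    weight_decay=paddle.regularizer.L2Decay(1e-4),
    grad_clip=None,
    multi_precision=True,    # keep FP32 master copies of FP16 parameters
    use_multi_tensor=True,   # fuse per-parameter updates into batched kernels
    parameters=model.parameters())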