From b54ee04491d081d98efea0c30737b0563f433ad7 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Mon, 20 Dec 2021 06:36:56 +0000
Subject: [PATCH 1/3] Accelerate dynamic graph amp training

---
 ppcls/engine/engine.py       |  2 ++
 ppcls/engine/train/train.py  | 13 +++++++------
 ppcls/optimizer/optimizer.py |  5 ++++-
 3 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index fe069b1de..fca3a82bb 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -250,6 +250,8 @@ class Engine(object):
             self.scaler = paddle.amp.GradScaler(
                 init_loss_scaling=self.scale_loss,
                 use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
+            if self.config['AMP']['use_pure_fp16'] is True:
+                self.model = paddle.amp.decorate(models=self.model, level='O2')
 
         self.max_iter = len(self.train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(self.train_dataloader)

diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py
index cbf868e4e..d8f425dc8 100644
--- a/ppcls/engine/train/train.py
+++ b/ppcls/engine/train/train.py
@@ -41,14 +41,15 @@ def train_epoch(engine, epoch_id, print_batch_step):
 
         # image input
         if engine.amp:
-            with paddle.amp.auto_cast(custom_black_list={
-                    "flatten_contiguous_range", "greater_than"
-            }):
+            amp_level = 'O1'
+            if engine.config['AMP']['use_pure_fp16'] is True:
+                amp_level = 'O2'
+            with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
                 out = forward(engine, batch)
+                loss_dict = engine.train_loss_func(out, batch[1])
         else:
             out = forward(engine, batch)
-
-        loss_dict = engine.train_loss_func(out, batch[1])
+            loss_dict = engine.train_loss_func(out, batch[1])
 
         # step opt and lr
         if engine.amp:
@@ -58,7 +59,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
         else:
             loss_dict["loss"].backward()
             engine.optimizer.step()
-        engine.optimizer.clear_grad()
+        engine.optimizer.clear_grad(set_to_zero=True)
         engine.lr_sch.step()
 
         # below code just for logging

diff --git a/ppcls/optimizer/optimizer.py b/ppcls/optimizer/optimizer.py
index f429755fc..290632d02 100644
--- a/ppcls/optimizer/optimizer.py
+++ b/ppcls/optimizer/optimizer.py
@@ -36,13 +36,15 @@ class Momentum(object):
                  momentum,
                  weight_decay=None,
                  grad_clip=None,
-                 multi_precision=False):
+                 multi_precision=True,
+                 use_multi_tensor=True):
         super().__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
         self.multi_precision = multi_precision
+        self.use_multi_tensor = use_multi_tensor
 
     def __call__(self, model_list):
         # model_list is None in static graph
@@ -54,6 +56,7 @@ class Momentum(object):
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
             multi_precision=self.multi_precision,
+            use_multi_tensor=self.use_multi_tensor,
             parameters=parameters)
         return opt
 

From 28061f537cca8a93d7e82d0ad0cd16391ebd461d Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 21 Dec 2021 06:28:13 +0000
Subject: [PATCH 2/3] refine optimizer init logic

---
 ppcls/engine/train/train.py  |  3 ++-
 ppcls/optimizer/optimizer.py | 15 +++++++++++----
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py
index d8f425dc8..b7fa9d3a0 100644
--- a/ppcls/engine/train/train.py
+++ b/ppcls/engine/train/train.py
@@ -21,6 +21,7 @@ from ppcls.utils import profiler
 
 def train_epoch(engine, epoch_id, print_batch_step):
     tic = time.time()
+    v_current = [int(i) for i in paddle.__version__.split(".")]
     for iter_id, batch in enumerate(engine.train_dataloader):
         if iter_id >= engine.max_iter:
             break
@@ -59,7 +60,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
         else:
             loss_dict["loss"].backward()
             engine.optimizer.step()
-        engine.optimizer.clear_grad(set_to_zero=True)
+        engine.optimizer.clear_grad()
         engine.lr_sch.step()
 
         # below code just for logging

diff --git a/ppcls/optimizer/optimizer.py b/ppcls/optimizer/optimizer.py
index 290632d02..fe8348a24 100644
--- a/ppcls/optimizer/optimizer.py
+++ b/ppcls/optimizer/optimizer.py
@@ -17,6 +17,7 @@ from __future__ import division
 from __future__ import print_function
 
 from paddle import optimizer as optim
+import paddle
 
 from ppcls.utils import logger
 
@@ -36,15 +37,13 @@ class Momentum(object):
                  momentum,
                  weight_decay=None,
                  grad_clip=None,
-                 multi_precision=True,
-                 use_multi_tensor=True):
+                 multi_precision=True):
         super().__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
         self.multi_precision = multi_precision
-        self.use_multi_tensor = use_multi_tensor
 
     def __call__(self, model_list):
         # model_list is None in static graph
@@ -56,8 +55,16 @@ class Momentum(object):
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
             multi_precision=self.multi_precision,
-            use_multi_tensor=self.use_multi_tensor,
             parameters=parameters)
+        if hasattr(opt, '_use_multi_tensor'):
+            opt = optim.Momentum(
+                learning_rate=self.learning_rate,
+                momentum=self.momentum,
+                weight_decay=self.weight_decay,
+                grad_clip=self.grad_clip,
+                multi_precision=self.multi_precision,
+                parameters=parameters,
+                use_multi_tensor=False)
         return opt
 

From 558f03d68494fc5d286f43cda0089cee7520a8ba Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 21 Dec 2021 06:30:13 +0000
Subject: [PATCH 3/3] refine code

---
 ppcls/optimizer/optimizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ppcls/optimizer/optimizer.py b/ppcls/optimizer/optimizer.py
index fe8348a24..4422ea70d 100644
--- a/ppcls/optimizer/optimizer.py
+++ b/ppcls/optimizer/optimizer.py
@@ -64,7 +64,7 @@ class Momentum(object):
             grad_clip=self.grad_clip,
             multi_precision=self.multi_precision,
             parameters=parameters,
-            use_multi_tensor=False)
+            use_multi_tensor=True)
         return opt
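
Taken together, the patches let PaddleClas run dynamic-graph AMP training at level O1 or, when AMP.use_pure_fp16 is set, at level O2 (pure-fp16 parameters via paddle.amp.decorate), while keeping loss scaling through paddle.amp.GradScaler. The standalone sketch below mirrors that training step outside of PaddleClas. The tiny model, synthetic batch, and hyper-parameters are placeholders of mine, not code from the patches, and it assumes a Paddle release new enough to accept the level argument of auto_cast and decorate.

import paddle
from paddle import nn

# Placeholder model and data; none of this comes from the patches themselves.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.1,
    momentum=0.9,
    parameters=model.parameters(),
    multi_precision=True)

use_pure_fp16 = True  # mirrors the AMP.use_pure_fp16 switch in the config
amp_level = 'O2' if use_pure_fp16 else 'O1'

scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
if amp_level == 'O2':
    # Pure-fp16 training: cast the model's parameters to fp16 once, up front.
    model = paddle.amp.decorate(models=model, level='O2')

images = paddle.randn([8, 3, 32, 32])
labels = paddle.randint(0, 10, shape=[8])

with paddle.amp.auto_cast(
        custom_black_list={"flatten_contiguous_range", "greater_than"},
        level=amp_level):
    out = model(images)
    loss = nn.functional.cross_entropy(out, labels)  # computed inside auto_cast, as in the patch

scaled = scaler.scale(loss)         # scale the loss to avoid fp16 gradient underflow
scaled.backward()
scaler.minimize(optimizer, scaled)  # unscale, skip the step on inf/nan, update the scale
optimizer.clear_grad()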
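
The second and third patches guard the use_multi_tensor flag of paddle.optimizer.Momentum, which only exists in newer Paddle releases, by probing a freshly built optimizer for the private _use_multi_tensor attribute and rebuilding it with the flag when supported. A small sketch of the same version-guard idea, written against the public constructor signature rather than the private attribute; the helper name build_momentum is mine, not the patch's.

import inspect

import paddle
from paddle import optimizer as optim


def build_momentum(parameters, learning_rate=0.1, momentum=0.9):
    """Create a Momentum optimizer, enabling the fused multi-tensor path when
    the installed Paddle exposes the use_multi_tensor argument."""
    kwargs = dict(
        learning_rate=learning_rate,
        momentum=momentum,
        multi_precision=True,
        parameters=parameters)
    if 'use_multi_tensor' in inspect.signature(optim.Momentum.__init__).parameters:
        kwargs['use_multi_tensor'] = True  # older releases would raise TypeError on this argument
    return optim.Momentum(**kwargs)


# Usage on a toy layer:
linear = paddle.nn.Linear(10, 10)
opt = build_momentum(linear.parameters())

Probing the constructor keeps a single optimizer instance, whereas the patch builds one, inspects it, and builds a second; either way the flag is only passed where Paddle understands it.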