diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py index 73f225087..d78508147 100644 --- a/ppcls/engine/train/train.py +++ b/ppcls/engine/train/train.py @@ -16,6 +16,7 @@ from __future__ import absolute_import, division, print_function import time import paddle from ppcls.engine.train.utils import update_loss, update_metric, log_info +from ppcls.utils import profiler def train_epoch(trainer, epoch_id, print_batch_step): @@ -26,6 +27,7 @@ def train_epoch(trainer, epoch_id, print_batch_step): for iter_id, batch in enumerate(train_dataloader): if iter_id >= trainer.max_iter: break + profiler.add_profiler_step(trainer.config["profiler_options"]) if iter_id == 5: for key in trainer.time_info: trainer.time_info[key].reset() diff --git a/ppcls/utils/config.py b/ppcls/utils/config.py index b92f0d945..e3277c480 100644 --- a/ppcls/utils/config.py +++ b/ppcls/utils/config.py @@ -199,5 +199,12 @@ def parse_args(): action='append', default=[], help='config options to be overridden') + parser.add_argument( + '-p', + '--profiler_options', + type=str, + default=None, + help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".' + ) args = parser.parse_args() return args diff --git a/tools/train.py b/tools/train.py index 1d8359036..e7c9d7bcc 100644 --- a/tools/train.py +++ b/tools/train.py @@ -27,5 +27,6 @@ if __name__ == "__main__": args = config.parse_args() config = config.get_config( args.config, overrides=args.override, show=False) + config.profiler_options = args.profiler_options engine = Engine(config, mode="train") engine.train()