diff --git a/tools/program.py b/tools/program.py index 2015a0fa74..c89d7e3c99 100755 --- a/tools/program.py +++ b/tools/program.py @@ -42,6 +42,13 @@ class ArgsParser(ArgumentParser): self.add_argument("-c", "--config", help="configuration file to use") self.add_argument( "-o", "--opt", nargs='+', help="set configuration options") + self.add_argument( + '-p', + '--profiler_options', + type=str, + default=None, + help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".' + ) def parse_args(self, argv=None): args = super(ArgsParser, self).parse_args(argv) @@ -151,7 +158,8 @@ def train(config, eval_class, pre_best_model_dict, logger, - vdl_writer=None): + vdl_writer=None, + profiler_options=None): cal_metric_during_train = config['Global'].get('cal_metric_during_train', False) log_smooth_window = config['Global']['log_smooth_window'] @@ -208,6 +216,7 @@ def train(config, max_iter = len(train_dataloader) - 1 if platform.system( ) == "Windows" else len(train_dataloader) for idx, batch in enumerate(train_dataloader): + profiler.add_profiler_step(profiler_options) train_reader_cost += time.time() - batch_start if idx >= max_iter: break @@ -392,6 +401,7 @@ def eval(model, def preprocess(is_train=False): FLAGS = ArgsParser().parse_args() + profiler_options = FLAGS.profiler_options config = load_config(FLAGS.config) merge_config(FLAGS.opt) @@ -431,4 +441,4 @@ def preprocess(is_train=False): print_dict(config, logger) logger.info('train with paddle {} and device {}'.format(paddle.__version__, device)) - return config, device, logger, vdl_writer + return config, device, logger, vdl_writer, profiler_options diff --git a/tools/train.py b/tools/train.py index 05d295aa99..17a1239040 100755 --- a/tools/train.py +++ b/tools/train.py @@ -41,7 +41,7 @@ import tools.program as program dist.get_world_size() -def main(config, device, logger, vdl_writer): +def main(config, device, logger, vdl_writer, profiler_options): # init dist environment if config['Global']['distributed']: dist.init_parallel_env() @@ -105,7 +105,8 @@ def main(config, device, logger, vdl_writer): # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, - eval_class, pre_best_model_dict, logger, vdl_writer) + eval_class, pre_best_model_dict, logger, vdl_writer, + profiler_options) def test_reader(config, device, logger): @@ -127,6 +128,8 @@ def test_reader(config, device, logger): if __name__ == '__main__': - config, device, logger, vdl_writer = program.preprocess(is_train=True) - main(config, device, logger, vdl_writer) + config, device, logger, vdl_writer, profiler_options = program.preprocess( + is_train=True) + main(config, device, logger, vdl_writer, profiler_options) + # test_reader(config, device, logger) # test_reader(config, device, logger)