mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
add profile
This commit is contained in:
parent
3b2e50a16a
commit
d89c6b4308
@ -42,6 +42,13 @@ class ArgsParser(ArgumentParser):
|
||||
self.add_argument("-c", "--config", help="configuration file to use")
|
||||
self.add_argument(
|
||||
"-o", "--opt", nargs='+', help="set configuration options")
|
||||
self.add_argument(
|
||||
'-p',
|
||||
'--profiler_options',
|
||||
type=str,
|
||||
default=None,
|
||||
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
|
||||
)
|
||||
|
||||
def parse_args(self, argv=None):
|
||||
args = super(ArgsParser, self).parse_args(argv)
|
||||
@ -151,7 +158,8 @@ def train(config,
|
||||
eval_class,
|
||||
pre_best_model_dict,
|
||||
logger,
|
||||
vdl_writer=None):
|
||||
vdl_writer=None,
|
||||
profiler_options=None):
|
||||
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
|
||||
False)
|
||||
log_smooth_window = config['Global']['log_smooth_window']
|
||||
@ -208,6 +216,7 @@ def train(config,
|
||||
max_iter = len(train_dataloader) - 1 if platform.system(
|
||||
) == "Windows" else len(train_dataloader)
|
||||
for idx, batch in enumerate(train_dataloader):
|
||||
profiler.add_profiler_step(profiler_options)
|
||||
train_reader_cost += time.time() - batch_start
|
||||
if idx >= max_iter:
|
||||
break
|
||||
@ -392,6 +401,7 @@ def eval(model,
|
||||
|
||||
def preprocess(is_train=False):
|
||||
FLAGS = ArgsParser().parse_args()
|
||||
profiler_options = FLAGS.profiler_options
|
||||
config = load_config(FLAGS.config)
|
||||
merge_config(FLAGS.opt)
|
||||
|
||||
@ -431,4 +441,4 @@ def preprocess(is_train=False):
|
||||
print_dict(config, logger)
|
||||
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
|
||||
device))
|
||||
return config, device, logger, vdl_writer
|
||||
return config, device, logger, vdl_writer, profiler_options
|
||||
|
@ -41,7 +41,7 @@ import tools.program as program
|
||||
dist.get_world_size()
|
||||
|
||||
|
||||
def main(config, device, logger, vdl_writer):
|
||||
def main(config, device, logger, vdl_writer, profiler_options):
|
||||
# init dist environment
|
||||
if config['Global']['distributed']:
|
||||
dist.init_parallel_env()
|
||||
@ -105,7 +105,8 @@ def main(config, device, logger, vdl_writer):
|
||||
# start train
|
||||
program.train(config, train_dataloader, valid_dataloader, device, model,
|
||||
loss_class, optimizer, lr_scheduler, post_process_class,
|
||||
eval_class, pre_best_model_dict, logger, vdl_writer)
|
||||
eval_class, pre_best_model_dict, logger, vdl_writer,
|
||||
profiler_options)
|
||||
|
||||
|
||||
def test_reader(config, device, logger):
|
||||
@ -127,6 +128,8 @@ def test_reader(config, device, logger):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config, device, logger, vdl_writer = program.preprocess(is_train=True)
|
||||
main(config, device, logger, vdl_writer)
|
||||
config, device, logger, vdl_writer, profiler_options = program.preprocess(
|
||||
is_train=True)
|
||||
main(config, device, logger, vdl_writer, profiler_options)
|
||||
# test_reader(config, device, logger)
|
||||
# test_reader(config, device, logger)
|
||||
|
Loading…
x
Reference in New Issue
Block a user