mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
add profile
This commit is contained in:
parent
3b2e50a16a
commit
d89c6b4308
@ -42,6 +42,13 @@ class ArgsParser(ArgumentParser):
|
|||||||
self.add_argument("-c", "--config", help="configuration file to use")
|
self.add_argument("-c", "--config", help="configuration file to use")
|
||||||
self.add_argument(
|
self.add_argument(
|
||||||
"-o", "--opt", nargs='+', help="set configuration options")
|
"-o", "--opt", nargs='+', help="set configuration options")
|
||||||
|
self.add_argument(
|
||||||
|
'-p',
|
||||||
|
'--profiler_options',
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
|
||||||
|
)
|
||||||
|
|
||||||
def parse_args(self, argv=None):
|
def parse_args(self, argv=None):
|
||||||
args = super(ArgsParser, self).parse_args(argv)
|
args = super(ArgsParser, self).parse_args(argv)
|
||||||
@ -151,7 +158,8 @@ def train(config,
|
|||||||
eval_class,
|
eval_class,
|
||||||
pre_best_model_dict,
|
pre_best_model_dict,
|
||||||
logger,
|
logger,
|
||||||
vdl_writer=None):
|
vdl_writer=None,
|
||||||
|
profiler_options=None):
|
||||||
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
|
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
|
||||||
False)
|
False)
|
||||||
log_smooth_window = config['Global']['log_smooth_window']
|
log_smooth_window = config['Global']['log_smooth_window']
|
||||||
@ -208,6 +216,7 @@ def train(config,
|
|||||||
max_iter = len(train_dataloader) - 1 if platform.system(
|
max_iter = len(train_dataloader) - 1 if platform.system(
|
||||||
) == "Windows" else len(train_dataloader)
|
) == "Windows" else len(train_dataloader)
|
||||||
for idx, batch in enumerate(train_dataloader):
|
for idx, batch in enumerate(train_dataloader):
|
||||||
|
profiler.add_profiler_step(profiler_options)
|
||||||
train_reader_cost += time.time() - batch_start
|
train_reader_cost += time.time() - batch_start
|
||||||
if idx >= max_iter:
|
if idx >= max_iter:
|
||||||
break
|
break
|
||||||
@ -392,6 +401,7 @@ def eval(model,
|
|||||||
|
|
||||||
def preprocess(is_train=False):
|
def preprocess(is_train=False):
|
||||||
FLAGS = ArgsParser().parse_args()
|
FLAGS = ArgsParser().parse_args()
|
||||||
|
profiler_options = FLAGS.profiler_options
|
||||||
config = load_config(FLAGS.config)
|
config = load_config(FLAGS.config)
|
||||||
merge_config(FLAGS.opt)
|
merge_config(FLAGS.opt)
|
||||||
|
|
||||||
@ -431,4 +441,4 @@ def preprocess(is_train=False):
|
|||||||
print_dict(config, logger)
|
print_dict(config, logger)
|
||||||
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
|
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
|
||||||
device))
|
device))
|
||||||
return config, device, logger, vdl_writer
|
return config, device, logger, vdl_writer, profiler_options
|
||||||
|
@ -41,7 +41,7 @@ import tools.program as program
|
|||||||
dist.get_world_size()
|
dist.get_world_size()
|
||||||
|
|
||||||
|
|
||||||
def main(config, device, logger, vdl_writer):
|
def main(config, device, logger, vdl_writer, profiler_options):
|
||||||
# init dist environment
|
# init dist environment
|
||||||
if config['Global']['distributed']:
|
if config['Global']['distributed']:
|
||||||
dist.init_parallel_env()
|
dist.init_parallel_env()
|
||||||
@ -105,7 +105,8 @@ def main(config, device, logger, vdl_writer):
|
|||||||
# start train
|
# start train
|
||||||
program.train(config, train_dataloader, valid_dataloader, device, model,
|
program.train(config, train_dataloader, valid_dataloader, device, model,
|
||||||
loss_class, optimizer, lr_scheduler, post_process_class,
|
loss_class, optimizer, lr_scheduler, post_process_class,
|
||||||
eval_class, pre_best_model_dict, logger, vdl_writer)
|
eval_class, pre_best_model_dict, logger, vdl_writer,
|
||||||
|
profiler_options)
|
||||||
|
|
||||||
|
|
||||||
def test_reader(config, device, logger):
|
def test_reader(config, device, logger):
|
||||||
@ -127,6 +128,8 @@ def test_reader(config, device, logger):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
config, device, logger, vdl_writer = program.preprocess(is_train=True)
|
config, device, logger, vdl_writer, profiler_options = program.preprocess(
|
||||||
main(config, device, logger, vdl_writer)
|
is_train=True)
|
||||||
|
main(config, device, logger, vdl_writer, profiler_options)
|
||||||
|
# test_reader(config, device, logger)
|
||||||
# test_reader(config, device, logger)
|
# test_reader(config, device, logger)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user