Add the profiler back for static training. (#1094)

pull/1104/head
Yiqun Liu 2021-07-29 10:18:45 +08:00 committed by GitHub
parent 274f819040
commit 00455839f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 3 deletions

View File

@ -38,7 +38,7 @@ from ppcls.optimizer import build_optimizer
from ppcls.optimizer import build_lr_scheduler
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils import logger, profiler
def create_feeds(image_shape, use_mix=None, dtype="float32"):
@ -326,7 +326,8 @@ def run(dataloader,
mode='train',
config=None,
vdl_writer=None,
lr_scheduler=None):
lr_scheduler=None,
profiler_options=None):
"""
Feed data to the model and fetch the measures and loss
@ -382,6 +383,8 @@ def run(dataloader,
metric_dict['reader_time'].update(time.time() - tic)
profiler.add_profiler_step(profiler_options)
if use_dali:
batch_size = batch[0]["data"].shape()[0]
feed_dict = batch[0]

View File

@ -43,6 +43,13 @@ def parse_args():
type=str,
default='configs/ResNet/ResNet50.yaml',
help='config file path')
parser.add_argument(
'-p',
'--profiler_options',
type=str,
default=None,
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
)
parser.add_argument(
'-o',
'--override',
@ -166,7 +173,7 @@ def main(args):
# 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_feeds,
train_fetchs, epoch_id, 'train', config, vdl_writer,
lr_scheduler)
lr_scheduler, args.profiler_options)
# 2. evaate with eval dataset
if global_config["eval_during_train"] and epoch_id % global_config[
"eval_interval"] == 0: