Add the profiler back for static training. (#1094)
parent
274f819040
commit
00455839f9
|
@ -38,7 +38,7 @@ from ppcls.optimizer import build_optimizer
|
|||
from ppcls.optimizer import build_lr_scheduler
|
||||
|
||||
from ppcls.utils.misc import AverageMeter
|
||||
from ppcls.utils import logger
|
||||
from ppcls.utils import logger, profiler
|
||||
|
||||
|
||||
def create_feeds(image_shape, use_mix=None, dtype="float32"):
|
||||
|
@ -326,7 +326,8 @@ def run(dataloader,
|
|||
mode='train',
|
||||
config=None,
|
||||
vdl_writer=None,
|
||||
lr_scheduler=None):
|
||||
lr_scheduler=None,
|
||||
profiler_options=None):
|
||||
"""
|
||||
Feed data to the model and fetch the measures and loss
|
||||
|
||||
|
@ -382,6 +383,8 @@ def run(dataloader,
|
|||
|
||||
metric_dict['reader_time'].update(time.time() - tic)
|
||||
|
||||
profiler.add_profiler_step(profiler_options)
|
||||
|
||||
if use_dali:
|
||||
batch_size = batch[0]["data"].shape()[0]
|
||||
feed_dict = batch[0]
|
||||
|
|
|
@ -43,6 +43,13 @@ def parse_args():
|
|||
type=str,
|
||||
default='configs/ResNet/ResNet50.yaml',
|
||||
help='config file path')
|
||||
parser.add_argument(
|
||||
'-p',
|
||||
'--profiler_options',
|
||||
type=str,
|
||||
default=None,
|
||||
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-o',
|
||||
'--override',
|
||||
|
@ -166,7 +173,7 @@ def main(args):
|
|||
# 1. train with train dataset
|
||||
program.run(train_dataloader, exe, compiled_train_prog, train_feeds,
|
||||
train_fetchs, epoch_id, 'train', config, vdl_writer,
|
||||
lr_scheduler)
|
||||
lr_scheduler, args.profiler_options)
|
||||
# 2. evaate with eval dataset
|
||||
if global_config["eval_during_train"] and epoch_id % global_config[
|
||||
"eval_interval"] == 0:
|
||||
|
|
Loading…
Reference in New Issue