diff --git a/mmcv/runner/hooks/profiler.py b/mmcv/runner/hooks/profiler.py index fef9adc13..6b0fc4b86 100644 --- a/mmcv/runner/hooks/profiler.py +++ b/mmcv/runner/hooks/profiler.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp import warnings from typing import Callable, List, Optional, Union @@ -131,6 +132,15 @@ class ProfilerHook(Hook): raise ImportError('please run "pip install ' 'torch-tb-profiler" to install ' 'torch_tb_profiler') + if 'dir_name' not in trace_cfg: + trace_cfg['dir_name'] = osp.join(runner.work_dir, + 'tf_tracing_logs') + elif not osp.isabs(trace_cfg['dir_name']): + trace_cfg['dir_name'] = osp.join(runner.work_dir, + trace_cfg['dir_name']) + runner.logger.info( + 'tracing files of ProfilerHook will be saved to ' + f"{trace_cfg['dir_name']}.") _on_trace_ready = torch.profiler.tensorboard_trace_handler( **trace_cfg) else: