[Fix] Support Runner dump cfg without filename (#228)

* fix runner dump cfg

* convert dict cfg to Config
pull/236/head
Mashiro 2022-05-17 17:32:10 +08:00 committed by GitHub
parent 84e60f6463
commit fd962437e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 15 deletions

View File

@ -198,7 +198,7 @@ class Runner:
>>> runner.train()
>>> runner.test()
"""
cfg: ConfigType
cfg: Config
train_loop: Optional[Union[BaseLoop, Dict]]
val_loop: Optional[Union[BaseLoop, Dict]]
test_loop: Optional[Union[BaseLoop, Dict]]
@ -237,9 +237,12 @@ class Runner:
# recursively copy the `cfg` because `self.cfg` will be modified
# everywhere.
if cfg is not None:
self.cfg = copy.deepcopy(cfg)
if isinstance(cfg, Config):
self.cfg = copy.deepcopy(cfg)
elif isinstance(cfg, dict):
self.cfg = Config(cfg)
else:
self.cfg = dict()
self.cfg = Config(dict())
self._epoch = 0
self._iter = 0
@ -313,9 +316,8 @@ class Runner:
if experiment_name is not None:
self._experiment_name = f'{experiment_name}_{self._timestamp}'
elif self.cfg.get('filename') is not None:
filename_no_ext = osp.splitext(osp.basename(
self.cfg['filename']))[0]
elif self.cfg.filename is not None:
filename_no_ext = osp.splitext(osp.basename(self.cfg.filename))[0]
self._experiment_name = f'{filename_no_ext}_{self._timestamp}'
else:
self._experiment_name = self.timestamp
@ -1596,10 +1598,8 @@ class Runner:
@master_only
def dump_config(self) -> None:
"""Dump config to `work_dir`."""
if isinstance(self.cfg,
Config) and self.cfg.get('filename') is not None:
self.cfg.dump(
osp.join(self.work_dir, osp.basename(self.cfg.filename)))
elif self.cfg:
# TODO
pass
if self.cfg.filename is not None:
filename = osp.basename(self.cfg.filename)
else:
filename = f'{self.timestamp}.py'
self.cfg.dump(osp.join(self.work_dir, filename))

View File

@ -398,8 +398,26 @@ class TestRunner(TestCase):
runner.train()
runner.test()
# 5. test `dump_config`
# TODO
def test_dump_config(self):
# dump config from dict.
cfg = copy.deepcopy(self.epoch_based_cfg)
for idx, cfg in enumerate((cfg, cfg._cfg_dict)):
cfg.experiment_name = f'test_dump{idx}'
runner = Runner.from_cfg(cfg=cfg)
assert osp.exists(
osp.join(runner.work_dir, f'{runner.timestamp}.py'))
# dump config from file.
with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(
dir=temp_config_dir, suffix='.py')
file_cfg = Config(
self.epoch_based_cfg._cfg_dict,
filename=temp_config_file.name)
file_cfg.experiment_name = f'test_dump2{idx}'
runner = Runner.from_cfg(cfg=file_cfg)
assert osp.exists(
osp.join(runner.work_dir,
osp.basename(temp_config_file.name)))
def test_from_cfg(self):
runner = Runner.from_cfg(cfg=self.epoch_based_cfg)