[Fix] Support Runner dump cfg without filename (#228)

* fix runner dump cfg

* convert dict cfg to Config
This commit is contained in:
Mashiro 2022-05-17 17:32:10 +08:00 committed by GitHub
parent 84e60f6463
commit fd962437e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 15 deletions

View File

@ -198,7 +198,7 @@ class Runner:
>>> runner.train() >>> runner.train()
>>> runner.test() >>> runner.test()
""" """
cfg: ConfigType cfg: Config
train_loop: Optional[Union[BaseLoop, Dict]] train_loop: Optional[Union[BaseLoop, Dict]]
val_loop: Optional[Union[BaseLoop, Dict]] val_loop: Optional[Union[BaseLoop, Dict]]
test_loop: Optional[Union[BaseLoop, Dict]] test_loop: Optional[Union[BaseLoop, Dict]]
@ -237,9 +237,12 @@ class Runner:
# recursively copy the `cfg` because `self.cfg` will be modified # recursively copy the `cfg` because `self.cfg` will be modified
# everywhere. # everywhere.
if cfg is not None: if cfg is not None:
if isinstance(cfg, Config):
self.cfg = copy.deepcopy(cfg) self.cfg = copy.deepcopy(cfg)
elif isinstance(cfg, dict):
self.cfg = Config(cfg)
else: else:
self.cfg = dict() self.cfg = Config(dict())
self._epoch = 0 self._epoch = 0
self._iter = 0 self._iter = 0
@ -313,9 +316,8 @@ class Runner:
if experiment_name is not None: if experiment_name is not None:
self._experiment_name = f'{experiment_name}_{self._timestamp}' self._experiment_name = f'{experiment_name}_{self._timestamp}'
elif self.cfg.get('filename') is not None: elif self.cfg.filename is not None:
filename_no_ext = osp.splitext(osp.basename( filename_no_ext = osp.splitext(osp.basename(self.cfg.filename))[0]
self.cfg['filename']))[0]
self._experiment_name = f'{filename_no_ext}_{self._timestamp}' self._experiment_name = f'{filename_no_ext}_{self._timestamp}'
else: else:
self._experiment_name = self.timestamp self._experiment_name = self.timestamp
@ -1596,10 +1598,8 @@ class Runner:
@master_only @master_only
def dump_config(self) -> None: def dump_config(self) -> None:
"""Dump config to `work_dir`.""" """Dump config to `work_dir`."""
if isinstance(self.cfg, if self.cfg.filename is not None:
Config) and self.cfg.get('filename') is not None: filename = osp.basename(self.cfg.filename)
self.cfg.dump( else:
osp.join(self.work_dir, osp.basename(self.cfg.filename))) filename = f'{self.timestamp}.py'
elif self.cfg: self.cfg.dump(osp.join(self.work_dir, filename))
# TODO
pass

View File

@ -398,8 +398,26 @@ class TestRunner(TestCase):
runner.train() runner.train()
runner.test() runner.test()
# 5. test `dump_config` def test_dump_config(self):
# TODO # dump config from dict.
cfg = copy.deepcopy(self.epoch_based_cfg)
for idx, cfg in enumerate((cfg, cfg._cfg_dict)):
cfg.experiment_name = f'test_dump{idx}'
runner = Runner.from_cfg(cfg=cfg)
assert osp.exists(
osp.join(runner.work_dir, f'{runner.timestamp}.py'))
# dump config from file.
with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(
dir=temp_config_dir, suffix='.py')
file_cfg = Config(
self.epoch_based_cfg._cfg_dict,
filename=temp_config_file.name)
file_cfg.experiment_name = f'test_dump2{idx}'
runner = Runner.from_cfg(cfg=file_cfg)
assert osp.exists(
osp.join(runner.work_dir,
osp.basename(temp_config_file.name)))
def test_from_cfg(self): def test_from_cfg(self):
runner = Runner.from_cfg(cfg=self.epoch_based_cfg) runner = Runner.from_cfg(cfg=self.epoch_based_cfg)