[Fix] Support Runner dump cfg without filename (#228)
* fix runner dump cfg * convert dict cfg to Configpull/236/head
parent
84e60f6463
commit
fd962437e9
|
@ -198,7 +198,7 @@ class Runner:
|
|||
>>> runner.train()
|
||||
>>> runner.test()
|
||||
"""
|
||||
cfg: ConfigType
|
||||
cfg: Config
|
||||
train_loop: Optional[Union[BaseLoop, Dict]]
|
||||
val_loop: Optional[Union[BaseLoop, Dict]]
|
||||
test_loop: Optional[Union[BaseLoop, Dict]]
|
||||
|
@ -237,9 +237,12 @@ class Runner:
|
|||
# recursively copy the `cfg` because `self.cfg` will be modified
|
||||
# everywhere.
|
||||
if cfg is not None:
|
||||
self.cfg = copy.deepcopy(cfg)
|
||||
if isinstance(cfg, Config):
|
||||
self.cfg = copy.deepcopy(cfg)
|
||||
elif isinstance(cfg, dict):
|
||||
self.cfg = Config(cfg)
|
||||
else:
|
||||
self.cfg = dict()
|
||||
self.cfg = Config(dict())
|
||||
|
||||
self._epoch = 0
|
||||
self._iter = 0
|
||||
|
@ -313,9 +316,8 @@ class Runner:
|
|||
|
||||
if experiment_name is not None:
|
||||
self._experiment_name = f'{experiment_name}_{self._timestamp}'
|
||||
elif self.cfg.get('filename') is not None:
|
||||
filename_no_ext = osp.splitext(osp.basename(
|
||||
self.cfg['filename']))[0]
|
||||
elif self.cfg.filename is not None:
|
||||
filename_no_ext = osp.splitext(osp.basename(self.cfg.filename))[0]
|
||||
self._experiment_name = f'{filename_no_ext}_{self._timestamp}'
|
||||
else:
|
||||
self._experiment_name = self.timestamp
|
||||
|
@ -1596,10 +1598,8 @@ class Runner:
|
|||
@master_only
|
||||
def dump_config(self) -> None:
|
||||
"""Dump config to `work_dir`."""
|
||||
if isinstance(self.cfg,
|
||||
Config) and self.cfg.get('filename') is not None:
|
||||
self.cfg.dump(
|
||||
osp.join(self.work_dir, osp.basename(self.cfg.filename)))
|
||||
elif self.cfg:
|
||||
# TODO
|
||||
pass
|
||||
if self.cfg.filename is not None:
|
||||
filename = osp.basename(self.cfg.filename)
|
||||
else:
|
||||
filename = f'{self.timestamp}.py'
|
||||
self.cfg.dump(osp.join(self.work_dir, filename))
|
||||
|
|
|
@ -398,8 +398,26 @@ class TestRunner(TestCase):
|
|||
runner.train()
|
||||
runner.test()
|
||||
|
||||
# 5. test `dump_config`
|
||||
# TODO
|
||||
def test_dump_config(self):
|
||||
# dump config from dict.
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
for idx, cfg in enumerate((cfg, cfg._cfg_dict)):
|
||||
cfg.experiment_name = f'test_dump{idx}'
|
||||
runner = Runner.from_cfg(cfg=cfg)
|
||||
assert osp.exists(
|
||||
osp.join(runner.work_dir, f'{runner.timestamp}.py'))
|
||||
# dump config from file.
|
||||
with tempfile.TemporaryDirectory() as temp_config_dir:
|
||||
temp_config_file = tempfile.NamedTemporaryFile(
|
||||
dir=temp_config_dir, suffix='.py')
|
||||
file_cfg = Config(
|
||||
self.epoch_based_cfg._cfg_dict,
|
||||
filename=temp_config_file.name)
|
||||
file_cfg.experiment_name = f'test_dump2{idx}'
|
||||
runner = Runner.from_cfg(cfg=file_cfg)
|
||||
assert osp.exists(
|
||||
osp.join(runner.work_dir,
|
||||
osp.basename(temp_config_file.name)))
|
||||
|
||||
def test_from_cfg(self):
|
||||
runner = Runner.from_cfg(cfg=self.epoch_based_cfg)
|
||||
|
|
Loading…
Reference in New Issue