diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index 448d2545..7c1ca50b 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -74,6 +74,12 @@ class CheckpointHook(Hook):
         file_client_args (dict, optional): Arguments to instantiate a
             FileClient. See :class:`mmcv.fileio.FileClient` for details.
             Defaults to None.
+        filename_tmpl (str, optional): String template to indicate checkpoint
+            name. If specified, must contain one and only one "{}", which will
+            be replaced with ``epoch + 1`` if ``by_epoch=True`` else
+            ``iteration + 1``.
+            Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth"
+            accordingly.
 
     Examples:
         >>> # Save best based on single metric
@@ -116,6 +122,7 @@ class CheckpointHook(Hook):
                  greater_keys: Optional[Sequence[str]] = None,
                  less_keys: Optional[Sequence[str]] = None,
                  file_client_args: Optional[dict] = None,
+                 filename_tmpl: Optional[str] = None,
                  **kwargs) -> None:
         self.interval = interval
         self.by_epoch = by_epoch
@@ -124,8 +131,15 @@ class CheckpointHook(Hook):
         self.out_dir = out_dir  # type: ignore
         self.max_keep_ckpts = max_keep_ckpts
         self.save_last = save_last
-        self.args = kwargs
         self.file_client_args = file_client_args
+        if filename_tmpl is None:
+            if self.by_epoch:
+                self.filename_tmpl = 'epoch_{}.pth'
+            else:
+                self.filename_tmpl = 'iter_{}.pth'
+        else:
+            self.filename_tmpl = filename_tmpl
+        self.args = kwargs
 
         # save best logic
         assert (isinstance(save_best, str) or is_list_of(save_best, str)
@@ -277,11 +291,9 @@ class CheckpointHook(Hook):
             runner (Runner): The runner of the training process.
         """
         if self.by_epoch:
-            ckpt_filename = self.args.get(
-                'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
+            ckpt_filename = self.filename_tmpl.format(runner.epoch + 1)
         else:
-            ckpt_filename = self.args.get(
-                'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+            ckpt_filename = self.filename_tmpl.format(runner.iter + 1)
 
         runner.save_checkpoint(
             self.out_dir,
@@ -299,18 +311,15 @@ class CheckpointHook(Hook):
         # remove other checkpoints
         if self.max_keep_ckpts > 0:
             if self.by_epoch:
-                name = 'epoch_{}.pth'
                 current_ckpt = runner.epoch + 1
             else:
-                name = 'iter_{}.pth'
                 current_ckpt = runner.iter + 1
             redundant_ckpts = range(
                 current_ckpt - self.max_keep_ckpts * self.interval, 0,
                 -self.interval)
-            filename_tmpl = self.args.get('filename_tmpl', name)
             for _step in redundant_ckpts:
                 ckpt_path = self.file_client.join_path(
-                    self.out_dir, filename_tmpl.format(_step))
+                    self.out_dir, self.filename_tmpl.format(_step))
                 if self.file_client.isfile(ckpt_path):
                     self.file_client.remove(ckpt_path)
                 else:
@@ -334,12 +343,10 @@ class CheckpointHook(Hook):
             return
 
         if self.by_epoch:
-            ckpt_filename = self.args.get('filename_tmpl',
-                                          'epoch_{}.pth').format(runner.epoch)
+            ckpt_filename = self.filename_tmpl.format(runner.epoch)
             cur_type, cur_time = 'epoch', runner.epoch
         else:
-            ckpt_filename = self.args.get('filename_tmpl',
-                                          'iter_{}.pth').format(runner.iter)
+            ckpt_filename = self.filename_tmpl.format(runner.iter)
             cur_type, cur_time = 'iter', runner.iter
 
         # handle auto in self.key_indicators and self.rules before the loop
diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py
index d6bf79d4..deb01992 100644
--- a/tests/test_hooks/test_checkpoint_hook.py
+++ b/tests/test_hooks/test_checkpoint_hook.py
@@ -4,9 +4,71 @@ import os.path as osp
 from unittest.mock import Mock, patch
 
 import pytest
+import torch
+import torch.nn as nn
+from torch.utils.data import Dataset
 
+from mmengine.evaluator import BaseMetric
 from mmengine.hooks import CheckpointHook
 from mmengine.logging import MessageHub
+from mmengine.model import BaseModel
+from mmengine.optim import OptimWrapper
+from mmengine.runner import Runner
+
+
+class ToyModel(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(2, 1)
+
+    def forward(self, inputs, data_sample, mode='tensor'):
+        labels = torch.stack(data_sample)
+        inputs = torch.stack(inputs)
+        outputs = self.linear(inputs)
+        if mode == 'tensor':
+            return outputs
+        elif mode == 'loss':
+            loss = (labels - outputs).sum()
+            outputs = dict(loss=loss)
+            return outputs
+        else:
+            return outputs
+
+
+class DummyDataset(Dataset):
+    METAINFO = dict()  # type: ignore
+    data = torch.randn(12, 2)
+    label = torch.ones(12)
+
+    @property
+    def metainfo(self):
+        return self.METAINFO
+
+    def __len__(self):
+        return self.data.size(0)
+
+    def __getitem__(self, index):
+        return dict(inputs=self.data[index], data_sample=self.label[index])
+
+
+class TriangleMetric(BaseMetric):
+
+    default_prefix: str = 'test'
+
+    def __init__(self, length):
+        super().__init__()
+        self.length = length
+        self.best_idx = length // 2
+        self.cur_idx = 0
+
+    def process(self, *args, **kwargs):
+        self.results.append(0)
+
+    def compute_metrics(self, *args, **kwargs):
+        self.cur_idx += 1
+        acc = 1.0 - abs(self.cur_idx - self.best_idx) / self.length
+        return dict(acc=acc)
 
 
 class MockPetrel:
@@ -370,3 +432,44 @@ class TestCheckpointHook:
         checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
         assert not os.path.exists(f'{work_dir}/iter_8.pth')
+
+    def test_with_runner(self, tmp_path):
+        """Check that a custom ``filename_tmpl`` names every checkpoint."""
+        max_epoch = 10
+        work_dir = osp.join(str(tmp_path), 'runner_test')
+        tmpl = '{}.pth'
+        save_interval = 2
+        checkpoint_cfg = dict(
+            type='CheckpointHook',
+            interval=save_interval,
+            filename_tmpl=tmpl,
+            by_epoch=True)
+        # Build the model once so the optimizer updates the parameters of
+        # the model that the runner actually trains.
+        model = ToyModel()
+        runner = Runner(
+            model=model,
+            work_dir=work_dir,
+            train_dataloader=dict(
+                dataset=DummyDataset(),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=3,
+                num_workers=0),
+            val_dataloader=dict(
+                dataset=DummyDataset(),
+                sampler=dict(type='DefaultSampler', shuffle=False),
+                batch_size=3,
+                num_workers=0),
+            val_evaluator=dict(type=TriangleMetric, length=max_epoch),
+            optim_wrapper=OptimWrapper(
+                torch.optim.Adam(model.parameters())),
+            train_cfg=dict(
+                by_epoch=True, max_epochs=max_epoch, val_interval=1),
+            val_cfg=dict(),
+            default_hooks=dict(checkpoint=checkpoint_cfg))
+        runner.train()
+        # Checkpoints are named after ``epoch + 1``, so they land on every
+        # multiple of ``save_interval`` up to and including ``max_epoch``.
+        for epoch in range(save_interval, max_epoch + 1, save_interval):
+            path = osp.join(work_dir, tmpl.format(epoch))
+            assert osp.isfile(path=path)