[Fix] CheckpointHook behavior incorrect if given filename_tmpl argument (#518)

Qian Zhao 2022-09-22 12:47:45 +08:00, committed by GitHub
parent e56b1736c1
commit c64243aa9e
2 changed files with 119 additions and 13 deletions
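With this change ``filename_tmpl`` becomes an explicit constructor argument instead of being read out of ``**kwargs``, so a custom template is applied consistently when saving checkpoints, pruning old ones, and saving the best checkpoint. A minimal usage sketch after this commit (the template name, interval and max_keep_ckpts below are illustrative, not taken from the diff):

    from mmengine.hooks import CheckpointHook

    # When attached to a runner, checkpoints are written as ckpt_2.pth,
    # ckpt_4.pth, ... and the same template is used when old checkpoints
    # are pruned via max_keep_ckpts.
    hook = CheckpointHook(
        interval=2,
        by_epoch=True,
        max_keep_ckpts=3,
        filename_tmpl='ckpt_{}.pth')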

mmengine/hooks/checkpoint_hook.py

@@ -74,6 +74,12 @@ class CheckpointHook(Hook):
         file_client_args (dict, optional): Arguments to instantiate a
             FileClient. See :class:`mmcv.fileio.FileClient` for details.
             Defaults to None.
+        filename_tmpl (str, optional): String template to indicate checkpoint
+            name. If specified, must contain one and only one "{}", which will
+            be replaced with ``epoch + 1`` if ``by_epoch=True`` else
+            ``iteration + 1``.
+            Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth"
+            accordingly.

     Examples:
         >>> # Save best based on single metric
@@ -116,6 +122,7 @@ class CheckpointHook(Hook):
                  greater_keys: Optional[Sequence[str]] = None,
                  less_keys: Optional[Sequence[str]] = None,
                  file_client_args: Optional[dict] = None,
+                 filename_tmpl: Optional[str] = None,
                  **kwargs) -> None:
         self.interval = interval
         self.by_epoch = by_epoch
@@ -124,8 +131,15 @@ class CheckpointHook(Hook):
         self.out_dir = out_dir  # type: ignore
         self.max_keep_ckpts = max_keep_ckpts
         self.save_last = save_last
-        self.args = kwargs
         self.file_client_args = file_client_args
+        if filename_tmpl is None:
+            if self.by_epoch:
+                self.filename_tmpl = 'epoch_{}.pth'
+            else:
+                self.filename_tmpl = 'iter_{}.pth'
+        else:
+            self.filename_tmpl = filename_tmpl
+        self.args = kwargs

         # save best logic
         assert (isinstance(save_best, str) or is_list_of(save_best, str)
@@ -277,11 +291,9 @@
             runner (Runner): The runner of the training process.
         """
         if self.by_epoch:
-            ckpt_filename = self.args.get(
-                'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
+            ckpt_filename = self.filename_tmpl.format(runner.epoch + 1)
         else:
-            ckpt_filename = self.args.get(
-                'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+            ckpt_filename = self.filename_tmpl.format(runner.iter + 1)

         runner.save_checkpoint(
             self.out_dir,
@@ -299,18 +311,15 @@
         # remove other checkpoints
         if self.max_keep_ckpts > 0:
             if self.by_epoch:
-                name = 'epoch_{}.pth'
                 current_ckpt = runner.epoch + 1
             else:
-                name = 'iter_{}.pth'
                 current_ckpt = runner.iter + 1
             redundant_ckpts = range(
                 current_ckpt - self.max_keep_ckpts * self.interval, 0,
                 -self.interval)
-            filename_tmpl = self.args.get('filename_tmpl', name)
             for _step in redundant_ckpts:
                 ckpt_path = self.file_client.join_path(
-                    self.out_dir, filename_tmpl.format(_step))
+                    self.out_dir, self.filename_tmpl.format(_step))
                 if self.file_client.isfile(ckpt_path):
                     self.file_client.remove(ckpt_path)
                 else:
@@ -334,12 +343,10 @@
             return

         if self.by_epoch:
-            ckpt_filename = self.args.get('filename_tmpl',
-                                          'epoch_{}.pth').format(runner.epoch)
+            ckpt_filename = self.filename_tmpl.format(runner.epoch)
             cur_type, cur_time = 'epoch', runner.epoch
         else:
-            ckpt_filename = self.args.get('filename_tmpl',
-                                          'iter_{}.pth').format(runner.iter)
+            ckpt_filename = self.filename_tmpl.format(runner.iter)
             cur_type, cur_time = 'iter', runner.iter

         # handle auto in self.key_indicators and self.rules before the loop
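The constructor change above resolves ``self.filename_tmpl`` once, and every save path now reads that attribute. A small sketch of the resulting behaviour, assuming an mmengine build that includes this commit (the custom template name is illustrative):

    from mmengine.hooks import CheckpointHook

    # Defaults mirror the previously hard-coded names ...
    assert CheckpointHook(by_epoch=True).filename_tmpl == 'epoch_{}.pth'
    assert CheckpointHook(by_epoch=False).filename_tmpl == 'iter_{}.pth'
    # ... while an explicit template takes precedence.
    hook = CheckpointHook(filename_tmpl='ckpt_{}.pth')
    assert hook.filename_tmpl == 'ckpt_{}.pth'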

tests/test_hooks/test_checkpoint_hook.py

@@ -4,9 +4,71 @@ import os.path as osp
 from unittest.mock import Mock, patch

 import pytest
 import torch
+import torch.nn as nn
+from torch.utils.data import Dataset

+from mmengine.evaluator import BaseMetric
 from mmengine.hooks import CheckpointHook
 from mmengine.logging import MessageHub
+from mmengine.model import BaseModel
+from mmengine.optim import OptimWrapper
+from mmengine.runner import Runner
+
+
+class ToyModel(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(2, 1)
+
+    def forward(self, inputs, data_sample, mode='tensor'):
+        labels = torch.stack(data_sample)
+        inputs = torch.stack(inputs)
+        outputs = self.linear(inputs)
+        if mode == 'tensor':
+            return outputs
+        elif mode == 'loss':
+            loss = (labels - outputs).sum()
+            outputs = dict(loss=loss)
+            return outputs
+        else:
+            return outputs
+
+
+class DummyDataset(Dataset):
+    METAINFO = dict()  # type: ignore
+    data = torch.randn(12, 2)
+    label = torch.ones(12)
+
+    @property
+    def metainfo(self):
+        return self.METAINFO
+
+    def __len__(self):
+        return self.data.size(0)
+
+    def __getitem__(self, index):
+        return dict(inputs=self.data[index], data_sample=self.label[index])
+
+
+class TriangleMetric(BaseMetric):
+
+    default_prefix: str = 'test'
+
+    def __init__(self, length):
+        super().__init__()
+        self.length = length
+        self.best_idx = length // 2
+        self.cur_idx = 0
+
+    def process(self, *args, **kwargs):
+        self.results.append(0)
+
+    def compute_metrics(self, *args, **kwargs):
+        self.cur_idx += 1
+        acc = 1.0 - abs(self.cur_idx - self.best_idx) / self.length
+        return dict(acc=acc)
+
+
 class MockPetrel:
@@ -370,3 +432,40 @@ class TestCheckpointHook:
         checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
         assert not os.path.exists(f'{work_dir}/iter_8.pth')
+
+    def test_with_runner(self, tmp_path):
+        max_epoch = 10
+        work_dir = osp.join(str(tmp_path), 'runner_test')
+        tmpl = '{}.pth'
+        save_interval = 2
+        checkpoint_cfg = dict(
+            type='CheckpointHook',
+            interval=save_interval,
+            filename_tmpl=tmpl,
+            by_epoch=True)
+        runner = Runner(
+            model=ToyModel(),
+            work_dir=work_dir,
+            train_dataloader=dict(
+                dataset=DummyDataset(),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=3,
+                num_workers=0),
+            val_dataloader=dict(
+                dataset=DummyDataset(),
+                sampler=dict(type='DefaultSampler', shuffle=False),
+                batch_size=3,
+                num_workers=0),
+            val_evaluator=dict(type=TriangleMetric, length=max_epoch),
+            optim_wrapper=OptimWrapper(
+                torch.optim.Adam(ToyModel().parameters())),
+            train_cfg=dict(
+                by_epoch=True, max_epochs=max_epoch, val_interval=1),
+            val_cfg=dict(),
+            default_hooks=dict(checkpoint=checkpoint_cfg))
+        runner.train()
+        for epoch in range(max_epoch):
+            if epoch % save_interval != 0 or epoch == 0:
+                continue
+            path = osp.join(work_dir, tmpl.format(epoch))
+            assert osp.isfile(path=path)