mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] CheckpointHook behavior incorrect if given filename_tmpl
argument (#518)
This commit is contained in:
parent
e56b1736c1
commit
c64243aa9e
@ -74,6 +74,12 @@ class CheckpointHook(Hook):
|
||||
file_client_args (dict, optional): Arguments to instantiate a
|
||||
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||
Defaults to None.
|
||||
filename_tmpl (str, optional): String template to indicate checkpoint
|
||||
name. If specified, must contain one and only one "{}", which will
|
||||
be replaced with ``epoch + 1`` if ``by_epoch=True`` else
|
||||
``iteration + 1``.
|
||||
Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth"
|
||||
accordingly.
|
||||
|
||||
Examples:
|
||||
>>> # Save best based on single metric
|
||||
@ -116,6 +122,7 @@ class CheckpointHook(Hook):
|
||||
greater_keys: Optional[Sequence[str]] = None,
|
||||
less_keys: Optional[Sequence[str]] = None,
|
||||
file_client_args: Optional[dict] = None,
|
||||
filename_tmpl: Optional[str] = None,
|
||||
**kwargs) -> None:
|
||||
self.interval = interval
|
||||
self.by_epoch = by_epoch
|
||||
@ -124,8 +131,15 @@ class CheckpointHook(Hook):
|
||||
self.out_dir = out_dir # type: ignore
|
||||
self.max_keep_ckpts = max_keep_ckpts
|
||||
self.save_last = save_last
|
||||
self.args = kwargs
|
||||
self.file_client_args = file_client_args
|
||||
if filename_tmpl is None:
|
||||
if self.by_epoch:
|
||||
self.filename_tmpl = 'epoch_{}.pth'
|
||||
else:
|
||||
self.filename_tmpl = 'iter_{}.pth'
|
||||
else:
|
||||
self.filename_tmpl = filename_tmpl
|
||||
self.args = kwargs
|
||||
|
||||
# save best logic
|
||||
assert (isinstance(save_best, str) or is_list_of(save_best, str)
|
||||
@ -277,11 +291,9 @@ class CheckpointHook(Hook):
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
if self.by_epoch:
|
||||
ckpt_filename = self.args.get(
|
||||
'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
|
||||
ckpt_filename = self.filename_tmpl.format(runner.epoch + 1)
|
||||
else:
|
||||
ckpt_filename = self.args.get(
|
||||
'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
|
||||
ckpt_filename = self.filename_tmpl.format(runner.iter + 1)
|
||||
|
||||
runner.save_checkpoint(
|
||||
self.out_dir,
|
||||
@ -299,18 +311,15 @@ class CheckpointHook(Hook):
|
||||
# remove other checkpoints
|
||||
if self.max_keep_ckpts > 0:
|
||||
if self.by_epoch:
|
||||
name = 'epoch_{}.pth'
|
||||
current_ckpt = runner.epoch + 1
|
||||
else:
|
||||
name = 'iter_{}.pth'
|
||||
current_ckpt = runner.iter + 1
|
||||
redundant_ckpts = range(
|
||||
current_ckpt - self.max_keep_ckpts * self.interval, 0,
|
||||
-self.interval)
|
||||
filename_tmpl = self.args.get('filename_tmpl', name)
|
||||
for _step in redundant_ckpts:
|
||||
ckpt_path = self.file_client.join_path(
|
||||
self.out_dir, filename_tmpl.format(_step))
|
||||
self.out_dir, self.filename_tmpl.format(_step))
|
||||
if self.file_client.isfile(ckpt_path):
|
||||
self.file_client.remove(ckpt_path)
|
||||
else:
|
||||
@ -334,12 +343,10 @@ class CheckpointHook(Hook):
|
||||
return
|
||||
|
||||
if self.by_epoch:
|
||||
ckpt_filename = self.args.get('filename_tmpl',
|
||||
'epoch_{}.pth').format(runner.epoch)
|
||||
ckpt_filename = self.filename_tmpl.format(runner.epoch)
|
||||
cur_type, cur_time = 'epoch', runner.epoch
|
||||
else:
|
||||
ckpt_filename = self.args.get('filename_tmpl',
|
||||
'iter_{}.pth').format(runner.iter)
|
||||
ckpt_filename = self.filename_tmpl.format(runner.iter)
|
||||
cur_type, cur_time = 'iter', runner.iter
|
||||
|
||||
# handle auto in self.key_indicators and self.rules before the loop
|
||||
|
@ -4,9 +4,71 @@ import os.path as osp
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.hooks import CheckpointHook
|
||||
from mmengine.logging import MessageHub
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.runner import Runner
|
||||
|
||||
|
||||
class ToyModel(BaseModel):
    """Minimal single-layer model used to drive a real training loop."""

    def __init__(self):
        super().__init__()
        # Maps 2 input features to 1 output value.
        self.linear = nn.Linear(2, 1)

    def forward(self, inputs, data_sample, mode='tensor'):
        """Run the linear layer on a batch.

        Args:
            inputs: Sequence of tensors, stacked into one batch tensor.
            data_sample: Sequence of label tensors, stacked into one tensor.
            mode (str): ``'loss'`` returns ``dict(loss=...)``; any other
                value (including the default ``'tensor'``) returns the raw
                outputs of the linear layer.
        """
        # Both sequences are stacked up front, regardless of mode.
        targets = torch.stack(data_sample)
        batch = torch.stack(inputs)
        predictions = self.linear(batch)
        if mode == 'loss':
            # Signed sum of residuals; only meant to yield a gradient.
            return dict(loss=(targets - predictions).sum())
        return predictions
|
||||
|
||||
|
||||
class DummyDataset(Dataset):
    """Fixed 12-sample dataset with random 2-d inputs and all-ones labels."""

    METAINFO = {}  # type: ignore
    # Created once at class-definition time and shared by every instance.
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        """Return the class-level ``METAINFO`` mapping."""
        return self.METAINFO

    def __len__(self):
        """Number of samples, i.e. the first dimension of ``data``."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return one sample as a dict with ``inputs`` and ``data_sample``."""
        features = self.data[index]
        target = self.label[index]
        return {'inputs': features, 'data_sample': target}
|
||||
|
||||
|
||||
class TriangleMetric(BaseMetric):
    """Fake metric whose ``acc`` rises to a peak and then falls again.

    Each call to ``compute_metrics`` advances an internal counter. ``acc``
    is 1.0 when the counter equals ``length // 2`` and drops linearly with
    the distance from that point.
    """

    default_prefix: str = 'test'

    def __init__(self, length):
        super().__init__()
        self.length = length
        self.best_idx = length // 2
        self.cur_idx = 0

    def process(self, *args, **kwargs):
        """Record a placeholder result; the inputs are ignored."""
        # ``self.results`` is presumably provided by ``BaseMetric`` -- TODO
        # confirm against the base class.
        self.results.append(0)

    def compute_metrics(self, *args, **kwargs):
        """Advance the counter and return ``dict(acc=...)`` for this step."""
        self.cur_idx += 1
        distance = abs(self.cur_idx - self.best_idx)
        return {'acc': 1.0 - distance / self.length}
|
||||
|
||||
|
||||
class MockPetrel:
|
||||
@ -370,3 +432,40 @@ class TestCheckpointHook:
|
||||
checkpoint_hook.before_train(runner)
|
||||
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
assert not os.path.exists(f'{work_dir}/iter_8.pth')
|
||||
|
||||
def test_with_runner(self, tmp_path):
    """Train with a real ``Runner`` and check ``filename_tmpl`` is honored.

    A ``CheckpointHook`` is configured with ``filename_tmpl='{}.pth'``,
    ``interval=2`` and ``by_epoch=True``. After 10 training epochs, the
    checkpoints named by that template must exist in the work dir.

    Args:
        tmp_path: pytest temporary directory fixture.
    """
    max_epoch = 10
    save_interval = 2
    tmpl = '{}.pth'
    work_dir = osp.join(str(tmp_path), 'runner_test')

    # Build the model once and hand the same instance to both the runner
    # and the optimizer. Previously the optimizer was created from the
    # parameters of a second, discarded ``ToyModel()``, so it never updated
    # the model that was actually being trained.
    model = ToyModel()

    checkpoint_cfg = dict(
        type='CheckpointHook',
        interval=save_interval,
        filename_tmpl=tmpl,
        by_epoch=True)
    runner = Runner(
        model=model,
        work_dir=work_dir,
        train_dataloader=dict(
            dataset=DummyDataset(),
            sampler=dict(type='DefaultSampler', shuffle=True),
            batch_size=3,
            num_workers=0),
        val_dataloader=dict(
            dataset=DummyDataset(),
            sampler=dict(type='DefaultSampler', shuffle=False),
            batch_size=3,
            num_workers=0),
        val_evaluator=dict(type=TriangleMetric, length=max_epoch),
        optim_wrapper=OptimWrapper(torch.optim.Adam(model.parameters())),
        train_cfg=dict(
            by_epoch=True, max_epochs=max_epoch, val_interval=1),
        val_cfg=dict(),
        default_hooks=dict(checkpoint=checkpoint_cfg))
    runner.train()

    # Every positive multiple of ``save_interval`` below ``max_epoch`` must
    # have produced a checkpoint named by the custom template.
    # NOTE(review): the checkpoint for the final epoch (``max_epoch``) is
    # not asserted, matching the original loop -- confirm whether it
    # should be covered as well.
    for epoch in range(save_interval, max_epoch, save_interval):
        path = osp.join(work_dir, tmpl.format(epoch))
        assert osp.isfile(path=path)
|
||||
|
Loading…
x
Reference in New Issue
Block a user