mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] CheckpointHook behavior incorrect if given filename_tmpl
argument (#518)
This commit is contained in:
parent
e56b1736c1
commit
c64243aa9e
@ -74,6 +74,12 @@ class CheckpointHook(Hook):
|
|||||||
file_client_args (dict, optional): Arguments to instantiate a
|
file_client_args (dict, optional): Arguments to instantiate a
|
||||||
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
FileClient. See :class:`mmcv.fileio.FileClient` for details.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
filename_tmpl (str, optional): String template to indicate checkpoint
|
||||||
|
name. If specified, must contain one and only one "{}", which will
|
||||||
|
be replaced with ``epoch + 1`` if ``by_epoch=True`` else
|
||||||
|
``iteration + 1``.
|
||||||
|
Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth"
|
||||||
|
accordingly.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # Save best based on single metric
|
>>> # Save best based on single metric
|
||||||
@ -116,6 +122,7 @@ class CheckpointHook(Hook):
|
|||||||
greater_keys: Optional[Sequence[str]] = None,
|
greater_keys: Optional[Sequence[str]] = None,
|
||||||
less_keys: Optional[Sequence[str]] = None,
|
less_keys: Optional[Sequence[str]] = None,
|
||||||
file_client_args: Optional[dict] = None,
|
file_client_args: Optional[dict] = None,
|
||||||
|
filename_tmpl: Optional[str] = None,
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.by_epoch = by_epoch
|
self.by_epoch = by_epoch
|
||||||
@ -124,8 +131,15 @@ class CheckpointHook(Hook):
|
|||||||
self.out_dir = out_dir # type: ignore
|
self.out_dir = out_dir # type: ignore
|
||||||
self.max_keep_ckpts = max_keep_ckpts
|
self.max_keep_ckpts = max_keep_ckpts
|
||||||
self.save_last = save_last
|
self.save_last = save_last
|
||||||
self.args = kwargs
|
|
||||||
self.file_client_args = file_client_args
|
self.file_client_args = file_client_args
|
||||||
|
if filename_tmpl is None:
|
||||||
|
if self.by_epoch:
|
||||||
|
self.filename_tmpl = 'epoch_{}.pth'
|
||||||
|
else:
|
||||||
|
self.filename_tmpl = 'iter_{}.pth'
|
||||||
|
else:
|
||||||
|
self.filename_tmpl = filename_tmpl
|
||||||
|
self.args = kwargs
|
||||||
|
|
||||||
# save best logic
|
# save best logic
|
||||||
assert (isinstance(save_best, str) or is_list_of(save_best, str)
|
assert (isinstance(save_best, str) or is_list_of(save_best, str)
|
||||||
@ -277,11 +291,9 @@ class CheckpointHook(Hook):
|
|||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
"""
|
"""
|
||||||
if self.by_epoch:
|
if self.by_epoch:
|
||||||
ckpt_filename = self.args.get(
|
ckpt_filename = self.filename_tmpl.format(runner.epoch + 1)
|
||||||
'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
|
|
||||||
else:
|
else:
|
||||||
ckpt_filename = self.args.get(
|
ckpt_filename = self.filename_tmpl.format(runner.iter + 1)
|
||||||
'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
|
|
||||||
|
|
||||||
runner.save_checkpoint(
|
runner.save_checkpoint(
|
||||||
self.out_dir,
|
self.out_dir,
|
||||||
@ -299,18 +311,15 @@ class CheckpointHook(Hook):
|
|||||||
# remove other checkpoints
|
# remove other checkpoints
|
||||||
if self.max_keep_ckpts > 0:
|
if self.max_keep_ckpts > 0:
|
||||||
if self.by_epoch:
|
if self.by_epoch:
|
||||||
name = 'epoch_{}.pth'
|
|
||||||
current_ckpt = runner.epoch + 1
|
current_ckpt = runner.epoch + 1
|
||||||
else:
|
else:
|
||||||
name = 'iter_{}.pth'
|
|
||||||
current_ckpt = runner.iter + 1
|
current_ckpt = runner.iter + 1
|
||||||
redundant_ckpts = range(
|
redundant_ckpts = range(
|
||||||
current_ckpt - self.max_keep_ckpts * self.interval, 0,
|
current_ckpt - self.max_keep_ckpts * self.interval, 0,
|
||||||
-self.interval)
|
-self.interval)
|
||||||
filename_tmpl = self.args.get('filename_tmpl', name)
|
|
||||||
for _step in redundant_ckpts:
|
for _step in redundant_ckpts:
|
||||||
ckpt_path = self.file_client.join_path(
|
ckpt_path = self.file_client.join_path(
|
||||||
self.out_dir, filename_tmpl.format(_step))
|
self.out_dir, self.filename_tmpl.format(_step))
|
||||||
if self.file_client.isfile(ckpt_path):
|
if self.file_client.isfile(ckpt_path):
|
||||||
self.file_client.remove(ckpt_path)
|
self.file_client.remove(ckpt_path)
|
||||||
else:
|
else:
|
||||||
@ -334,12 +343,10 @@ class CheckpointHook(Hook):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self.by_epoch:
|
if self.by_epoch:
|
||||||
ckpt_filename = self.args.get('filename_tmpl',
|
ckpt_filename = self.filename_tmpl.format(runner.epoch)
|
||||||
'epoch_{}.pth').format(runner.epoch)
|
|
||||||
cur_type, cur_time = 'epoch', runner.epoch
|
cur_type, cur_time = 'epoch', runner.epoch
|
||||||
else:
|
else:
|
||||||
ckpt_filename = self.args.get('filename_tmpl',
|
ckpt_filename = self.filename_tmpl.format(runner.iter)
|
||||||
'iter_{}.pth').format(runner.iter)
|
|
||||||
cur_type, cur_time = 'iter', runner.iter
|
cur_type, cur_time = 'iter', runner.iter
|
||||||
|
|
||||||
# handle auto in self.key_indicators and self.rules before the loop
|
# handle auto in self.key_indicators and self.rules before the loop
|
||||||
|
@ -4,9 +4,71 @@ import os.path as osp
|
|||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from mmengine.evaluator import BaseMetric
|
||||||
from mmengine.hooks import CheckpointHook
|
from mmengine.hooks import CheckpointHook
|
||||||
from mmengine.logging import MessageHub
|
from mmengine.logging import MessageHub
|
||||||
|
from mmengine.model import BaseModel
|
||||||
|
from mmengine.optim import OptimWrapper
|
||||||
|
from mmengine.runner import Runner
|
||||||
|
|
||||||
|
|
||||||
|
class ToyModel(BaseModel):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(2, 1)
|
||||||
|
|
||||||
|
def forward(self, inputs, data_sample, mode='tensor'):
|
||||||
|
labels = torch.stack(data_sample)
|
||||||
|
inputs = torch.stack(inputs)
|
||||||
|
outputs = self.linear(inputs)
|
||||||
|
if mode == 'tensor':
|
||||||
|
return outputs
|
||||||
|
elif mode == 'loss':
|
||||||
|
loss = (labels - outputs).sum()
|
||||||
|
outputs = dict(loss=loss)
|
||||||
|
return outputs
|
||||||
|
else:
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class DummyDataset(Dataset):
|
||||||
|
METAINFO = dict() # type: ignore
|
||||||
|
data = torch.randn(12, 2)
|
||||||
|
label = torch.ones(12)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def metainfo(self):
|
||||||
|
return self.METAINFO
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.data.size(0)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||||
|
|
||||||
|
|
||||||
|
class TriangleMetric(BaseMetric):
|
||||||
|
|
||||||
|
default_prefix: str = 'test'
|
||||||
|
|
||||||
|
def __init__(self, length):
|
||||||
|
super().__init__()
|
||||||
|
self.length = length
|
||||||
|
self.best_idx = length // 2
|
||||||
|
self.cur_idx = 0
|
||||||
|
|
||||||
|
def process(self, *args, **kwargs):
|
||||||
|
self.results.append(0)
|
||||||
|
|
||||||
|
def compute_metrics(self, *args, **kwargs):
|
||||||
|
self.cur_idx += 1
|
||||||
|
acc = 1.0 - abs(self.cur_idx - self.best_idx) / self.length
|
||||||
|
return dict(acc=acc)
|
||||||
|
|
||||||
|
|
||||||
class MockPetrel:
|
class MockPetrel:
|
||||||
@ -370,3 +432,40 @@ class TestCheckpointHook:
|
|||||||
checkpoint_hook.before_train(runner)
|
checkpoint_hook.before_train(runner)
|
||||||
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||||
assert not os.path.exists(f'{work_dir}/iter_8.pth')
|
assert not os.path.exists(f'{work_dir}/iter_8.pth')
|
||||||
|
|
||||||
|
def test_with_runner(self, tmp_path):
|
||||||
|
max_epoch = 10
|
||||||
|
work_dir = osp.join(str(tmp_path), 'runner_test')
|
||||||
|
tmpl = '{}.pth'
|
||||||
|
save_interval = 2
|
||||||
|
checkpoint_cfg = dict(
|
||||||
|
type='CheckpointHook',
|
||||||
|
interval=save_interval,
|
||||||
|
filename_tmpl=tmpl,
|
||||||
|
by_epoch=True)
|
||||||
|
runner = Runner(
|
||||||
|
model=ToyModel(),
|
||||||
|
work_dir=work_dir,
|
||||||
|
train_dataloader=dict(
|
||||||
|
dataset=DummyDataset(),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
batch_size=3,
|
||||||
|
num_workers=0),
|
||||||
|
val_dataloader=dict(
|
||||||
|
dataset=DummyDataset(),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
|
batch_size=3,
|
||||||
|
num_workers=0),
|
||||||
|
val_evaluator=dict(type=TriangleMetric, length=max_epoch),
|
||||||
|
optim_wrapper=OptimWrapper(
|
||||||
|
torch.optim.Adam(ToyModel().parameters())),
|
||||||
|
train_cfg=dict(
|
||||||
|
by_epoch=True, max_epochs=max_epoch, val_interval=1),
|
||||||
|
val_cfg=dict(),
|
||||||
|
default_hooks=dict(checkpoint=checkpoint_cfg))
|
||||||
|
runner.train()
|
||||||
|
for epoch in range(max_epoch):
|
||||||
|
if epoch % save_interval != 0 or epoch == 0:
|
||||||
|
continue
|
||||||
|
path = osp.join(work_dir, tmpl.format(epoch))
|
||||||
|
assert osp.isfile(path=path)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user