mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] change CheckPointHook before_run to before train (#214)
* change CheckPointHook before_run to before train * using tmp_path in each checkpointhook test case
This commit is contained in:
parent
a1adbff11e
commit
5007825619
@ -69,7 +69,7 @@ class CheckpointHook(Hook):
|
|||||||
self.args = kwargs
|
self.args = kwargs
|
||||||
self.file_client_args = file_client_args
|
self.file_client_args = file_client_args
|
||||||
|
|
||||||
def before_run(self, runner) -> None:
|
def before_train(self, runner) -> None:
|
||||||
"""Finish all operations, related to checkpoint.
|
"""Finish all operations, related to checkpoint.
|
||||||
|
|
||||||
This function will get the appropriate file client, and the directory
|
This function will get the appropriate file client, and the directory
|
||||||
@ -78,12 +78,11 @@ class CheckpointHook(Hook):
|
|||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
"""
|
"""
|
||||||
if not self.out_dir:
|
if self.out_dir is None:
|
||||||
self.out_dir = runner.work_dir
|
self.out_dir = runner.work_dir
|
||||||
|
|
||||||
self.file_client = FileClient.infer_client(self.file_client_args,
|
self.file_client = FileClient.infer_client(self.file_client_args,
|
||||||
self.out_dir)
|
self.out_dir)
|
||||||
|
|
||||||
# if `self.out_dir` is not equal to `runner.work_dir`, it means that
|
# if `self.out_dir` is not equal to `runner.work_dir`, it means that
|
||||||
# `self.out_dir` is set so the final `self.out_dir` is the
|
# `self.out_dir` is set so the final `self.out_dir` is the
|
||||||
# concatenation of `self.out_dir` and the last level directory of
|
# concatenation of `self.out_dir` and the last level directory of
|
||||||
@ -186,8 +185,7 @@ class CheckpointHook(Hook):
|
|||||||
batch_idx (int): The index of the current batch in the train loop.
|
batch_idx (int): The index of the current batch in the train loop.
|
||||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
outputs (dict, optional): Outputs from model.
|
outputs (dict, optional): Outputs from model. Defaults to None.
|
||||||
Defaults to None.
|
|
||||||
"""
|
"""
|
||||||
if self.by_epoch:
|
if self.by_epoch:
|
||||||
return
|
return
|
||||||
|
@ -1,13 +1,10 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import os
|
import os
|
||||||
import sys
|
import os.path as osp
|
||||||
from tempfile import TemporaryDirectory
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from mmengine.hooks import CheckpointHook
|
from mmengine.hooks import CheckpointHook
|
||||||
|
|
||||||
sys.modules['file_client'] = sys.modules['mmengine.fileio.file_client']
|
|
||||||
|
|
||||||
|
|
||||||
class MockPetrel:
|
class MockPetrel:
|
||||||
|
|
||||||
@ -30,75 +27,80 @@ prefix_to_backends = {'s3': MockPetrel}
|
|||||||
|
|
||||||
class TestCheckpointHook:
|
class TestCheckpointHook:
|
||||||
|
|
||||||
@patch('file_client.FileClient._prefix_to_backends', prefix_to_backends)
|
@patch('mmengine.fileio.file_client.FileClient._prefix_to_backends',
|
||||||
def test_before_run(self):
|
prefix_to_backends)
|
||||||
|
def test_before_train(self, tmp_path):
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.work_dir = './tmp'
|
work_dir = str(tmp_path)
|
||||||
|
runner.work_dir = work_dir
|
||||||
|
|
||||||
# the out_dir of the checkpoint hook is None
|
# the out_dir of the checkpoint hook is None
|
||||||
checkpoint_hook = CheckpointHook(interval=1, by_epoch=True)
|
checkpoint_hook = CheckpointHook(interval=1, by_epoch=True)
|
||||||
checkpoint_hook.before_run(runner)
|
checkpoint_hook.before_train(runner)
|
||||||
assert checkpoint_hook.out_dir == runner.work_dir
|
assert checkpoint_hook.out_dir == runner.work_dir
|
||||||
|
|
||||||
# the out_dir of the checkpoint hook is not None
|
# the out_dir of the checkpoint hook is not None
|
||||||
checkpoint_hook = CheckpointHook(
|
checkpoint_hook = CheckpointHook(
|
||||||
interval=1, by_epoch=True, out_dir='test_dir')
|
interval=1, by_epoch=True, out_dir='test_dir')
|
||||||
checkpoint_hook.before_run(runner)
|
checkpoint_hook.before_train(runner)
|
||||||
assert checkpoint_hook.out_dir == 'test_dir/tmp'
|
assert checkpoint_hook.out_dir == (
|
||||||
|
f'test_dir/{osp.basename(work_dir)}')
|
||||||
|
|
||||||
# create_symlink in args and create_symlink is True
|
# create_symlink in args and create_symlink is True
|
||||||
checkpoint_hook = CheckpointHook(
|
checkpoint_hook = CheckpointHook(
|
||||||
interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True)
|
interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True)
|
||||||
checkpoint_hook.before_run(runner)
|
checkpoint_hook.before_train(runner)
|
||||||
assert checkpoint_hook.args['create_symlink']
|
assert checkpoint_hook.args['create_symlink']
|
||||||
|
|
||||||
runner.work_dir = 's3://path/of/file'
|
runner.work_dir = 's3://path/of/file'
|
||||||
checkpoint_hook = CheckpointHook(
|
checkpoint_hook = CheckpointHook(
|
||||||
interval=1, by_epoch=True, create_symlink=True)
|
interval=1, by_epoch=True, create_symlink=True)
|
||||||
checkpoint_hook.before_run(runner)
|
checkpoint_hook.before_train(runner)
|
||||||
assert not checkpoint_hook.args['create_symlink']
|
assert not checkpoint_hook.args['create_symlink']
|
||||||
|
|
||||||
def test_after_train_epoch(self):
|
def test_after_train_epoch(self, tmp_path):
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.work_dir = './tmp'
|
work_dir = str(tmp_path)
|
||||||
|
runner.work_dir = tmp_path
|
||||||
runner.epoch = 9
|
runner.epoch = 9
|
||||||
runner.meta = dict()
|
runner.meta = dict()
|
||||||
runner.model = Mock()
|
runner.model = Mock()
|
||||||
|
|
||||||
# by epoch is True
|
# by epoch is True
|
||||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
||||||
checkpoint_hook.before_run(runner)
|
checkpoint_hook.before_train(runner)
|
||||||
checkpoint_hook.after_train_epoch(runner)
|
checkpoint_hook.after_train_epoch(runner)
|
||||||
assert (runner.epoch + 1) % 2 == 0
|
assert (runner.epoch + 1) % 2 == 0
|
||||||
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth'
|
assert runner.meta['hook_msgs']['last_ckpt'] == (
|
||||||
|
f'{work_dir}/epoch_10.pth')
|
||||||
# epoch can not be evenly divided by 2
|
# epoch can not be evenly divided by 2
|
||||||
runner.epoch = 10
|
runner.epoch = 10
|
||||||
checkpoint_hook.after_train_epoch(runner)
|
checkpoint_hook.after_train_epoch(runner)
|
||||||
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth'
|
assert runner.meta['hook_msgs']['last_ckpt'] == (
|
||||||
|
f'{work_dir}/epoch_10.pth')
|
||||||
|
|
||||||
# by epoch is False
|
# by epoch is False
|
||||||
runner.epoch = 9
|
runner.epoch = 9
|
||||||
runner.meta = dict()
|
runner.meta = dict()
|
||||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
|
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
|
||||||
checkpoint_hook.before_run(runner)
|
checkpoint_hook.before_train(runner)
|
||||||
checkpoint_hook.after_train_epoch(runner)
|
checkpoint_hook.after_train_epoch(runner)
|
||||||
assert runner.meta.get('hook_msgs', None) is None
|
assert runner.meta.get('hook_msgs', None) is None
|
||||||
|
|
||||||
# max_keep_ckpts > 0
|
# max_keep_ckpts > 0
|
||||||
with TemporaryDirectory() as tempo_dir:
|
runner.work_dir = work_dir
|
||||||
runner.work_dir = tempo_dir
|
os.system(f'touch {work_dir}/epoch_8.pth')
|
||||||
os.system(f'touch {tempo_dir}/epoch_8.pth')
|
|
||||||
checkpoint_hook = CheckpointHook(
|
checkpoint_hook = CheckpointHook(
|
||||||
interval=2, by_epoch=True, max_keep_ckpts=1)
|
interval=2, by_epoch=True, max_keep_ckpts=1)
|
||||||
checkpoint_hook.before_run(runner)
|
checkpoint_hook.before_train(runner)
|
||||||
checkpoint_hook.after_train_epoch(runner)
|
checkpoint_hook.after_train_epoch(runner)
|
||||||
assert (runner.epoch + 1) % 2 == 0
|
assert (runner.epoch + 1) % 2 == 0
|
||||||
assert not os.path.exists(f'{tempo_dir}/epoch_8.pth')
|
assert not os.path.exists(f'{work_dir}/epoch_8.pth')
|
||||||
|
|
||||||
def test_after_train_iter(self):
|
def test_after_train_iter(self, tmp_path):
|
||||||
|
work_dir = str(tmp_path)
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
runner.work_dir = './tmp'
|
runner.work_dir = str(work_dir)
|
||||||
runner.iter = 9
|
runner.iter = 9
|
||||||
batch_idx = 9
|
batch_idx = 9
|
||||||
runner.meta = dict()
|
runner.meta = dict()
|
||||||
@ -106,29 +108,30 @@ class TestCheckpointHook:
|
|||||||
|
|
||||||
# by epoch is True
|
# by epoch is True
|
||||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
|
||||||
checkpoint_hook.before_run(runner)
|
checkpoint_hook.before_train(runner)
|
||||||
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||||
assert runner.meta.get('hook_msgs', None) is None
|
assert runner.meta.get('hook_msgs', None) is None
|
||||||
|
|
||||||
# by epoch is False
|
# by epoch is False
|
||||||
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
|
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
|
||||||
checkpoint_hook.before_run(runner)
|
checkpoint_hook.before_train(runner)
|
||||||
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||||
assert (runner.iter + 1) % 2 == 0
|
assert (runner.iter + 1) % 2 == 0
|
||||||
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
|
assert runner.meta['hook_msgs']['last_ckpt'] == (
|
||||||
|
f'{work_dir}/iter_10.pth')
|
||||||
|
|
||||||
# epoch can not be evenly divided by 2
|
# epoch can not be evenly divided by 2
|
||||||
runner.iter = 10
|
runner.iter = 10
|
||||||
checkpoint_hook.after_train_epoch(runner)
|
checkpoint_hook.after_train_epoch(runner)
|
||||||
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
|
assert runner.meta['hook_msgs']['last_ckpt'] == (
|
||||||
|
f'{work_dir}/iter_10.pth')
|
||||||
|
|
||||||
# max_keep_ckpts > 0
|
# max_keep_ckpts > 0
|
||||||
runner.iter = 9
|
runner.iter = 9
|
||||||
with TemporaryDirectory() as tempo_dir:
|
runner.work_dir = work_dir
|
||||||
runner.work_dir = tempo_dir
|
os.system(f'touch {work_dir}/iter_8.pth')
|
||||||
os.system(f'touch {tempo_dir}/iter_8.pth')
|
|
||||||
checkpoint_hook = CheckpointHook(
|
checkpoint_hook = CheckpointHook(
|
||||||
interval=2, by_epoch=False, max_keep_ckpts=1)
|
interval=2, by_epoch=False, max_keep_ckpts=1)
|
||||||
checkpoint_hook.before_run(runner)
|
checkpoint_hook.before_train(runner)
|
||||||
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||||
assert not os.path.exists(f'{tempo_dir}/iter_8.pth')
|
assert not os.path.exists(f'{work_dir}/iter_8.pth')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user