Mirror of https://github.com/open-mmlab/mmengine.git
[Enhance] Enhance checkpoint meta info. (#279)
This commit is contained in: parent 538ff48aec, commit ad965a5309
@@ -10,4 +10,5 @@ from .logging import *
 from .registry import *
 from .runner import *
 from .utils import *
+from .version import __version__, version_info
 from .visualization import *
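With `__version__` and `version_info` re-exported from the package root, the installed version can be read directly off `mmengine`. A minimal usage sketch (the printed values are illustrative):

    import mmengine

    print(mmengine.__version__)   # e.g. '0.1.0'
    print(mmengine.version_info)  # the same version as a tuple, e.g. (0, 1, 0)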
@@ -162,6 +162,4 @@ class InfiniteSampler(Sampler):

     def set_epoch(self, epoch: int) -> None:
         """Not supported in iteration-based runner."""
-        raise NotImplementedError(
-            'The `InfiniteSampler` is only used in iteration-based runner, '
-            "and doesn't need `set_epoch`")
+        pass
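Making `set_epoch` a silent no-op lets epoch-oriented glue code call it on every sampler without special-casing `InfiniteSampler` (this is what the updated `test_set_epoch` below verifies). A minimal sketch, with the surrounding loop purely illustrative:

    sampler = InfiniteSampler(dataset)
    for epoch in range(max_epochs):
        sampler.set_epoch(epoch)  # previously raised NotImplementedError, now harmless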
@@ -146,7 +146,8 @@ class IterBasedTrainLoop(BaseLoop):
                 f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                 'metainfo. ``dataset_meta`` in visualizer will be '
                 'None.')
-        self.dataloader = iter(self.dataloader)
+        # get the iterator of the dataloader
+        self.dataloader_iterator = iter(self.dataloader)

     @property
     def max_epochs(self):
@@ -177,7 +178,7 @@ class IterBasedTrainLoop(BaseLoop):
         while self._iter < self._max_iters:
             self.runner.model.train()

-            data_batch = next(self.dataloader)
+            data_batch = next(self.dataloader_iterator)
             self.run_iter(data_batch)

             if (self.runner.val_loop is not None
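The new `dataloader_iterator` attribute fixes a subtle problem: the old code rebound `self.dataloader` to its iterator, which discards `DataLoader` attributes such as `.dataset` (needed above for the metainfo lookup). A minimal sketch of the distinction, assuming a plain PyTorch `DataLoader`:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(12, 2))
    loader = DataLoader(dataset)
    iterator = iter(loader)  # keep the DataLoader and its iterator separate
    batch = next(iterator)   # loader.dataset is still reachable afterwards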
@@ -34,7 +34,7 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
                                count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSORS
 from mmengine.utils import (TORCH_VERSION, digit_version,
-                            find_latest_checkpoint, is_list_of,
+                            find_latest_checkpoint, get_git_hash, is_list_of,
                             set_multi_processing, symlink)
 from mmengine.visualization import Visualizer
 from .base_loop import BaseLoop
@@ -331,6 +331,7 @@ class Runner:
         self.setup_env(env_cfg)
         # self._deterministic and self._seed will be set in the
         # `set_randomness`` method
+        self._randomness_cfg = randomness
         self.set_randomness(**randomness)

         if experiment_name is not None:
@@ -1796,8 +1797,26 @@ class Runner:
                 'previous training state resuming from the checkpoint '
                 'or set `enable` in `auto_scale_lr to False.')

-        # resume meta information meta
-        self.meta = checkpoint['meta']
+        # resume random seed
+        resumed_seed = checkpoint['meta'].get('seed', None)
+        current_seed = self._randomness_cfg.get('seed')
+        if resumed_seed is not None and resumed_seed != current_seed:
+            if current_seed is not None:
+                warnings.warn(f'The value of random seed in the '
+                              f'checkpoint "{resumed_seed}" is '
+                              f'different from the value in '
+                              f'`randomness` config "{current_seed}"')
+            self._randomness_cfg.update(seed=resumed_seed)
+            self.set_randomness(**self._randomness_cfg)
+
+        dataset_meta = checkpoint['meta'].get('dataset_meta', None)
+        if (dataset_meta is not None
+                and dataset_meta != self.train_dataloader.dataset.metainfo):
+            warnings.warn(
+                'The dataset metainfo from the resumed checkpoint is '
+                'different from the current training dataset, please '
+                'check the correctness of the checkpoint or the training '
+                'dataset.')

         # resume optimizer
         if 'optimizer' in checkpoint and resume_optimizer:
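In effect, the seed stored in the checkpoint now takes precedence on resume: if it differs from the `randomness` config, a warning is emitted (when a config seed was set) and randomness is re-initialized from the checkpoint value. A minimal sketch of that precedence rule, with the helper name purely illustrative:

    import warnings

    def effective_seed(checkpoint_meta: dict, config_seed):
        """Prefer the seed recorded in the checkpoint over the config seed."""
        resumed_seed = checkpoint_meta.get('seed')
        if resumed_seed is not None and resumed_seed != config_seed:
            if config_seed is not None:
                warnings.warn(f'checkpoint seed {resumed_seed} overrides '
                              f'config seed {config_seed}')
            return resumed_seed
        return config_seed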
@@ -1909,9 +1928,13 @@ class Runner:

         filepath = osp.join(out_dir, filename)

-        if hasattr(self.model, 'CLASSES') and self.model.CLASSES is not None:
-            # save class name to the meta
-            meta.update(CLASSES=self.model.CLASSES)
+        meta.update(
+            cfg=self.cfg.pretty_text,
+            dataset_meta=self.train_dataloader.dataset.metainfo,
+            seed=self.seed,
+            experiment_name=self.experiment_name,
+            time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
+            mmengine_version=mmengine.__version__ + get_git_hash())

         if is_model_wrapper(self.model):
             model = self.model.module
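A checkpoint saved after this change carries the enriched meta block, which can be inspected offline. A minimal sketch, with the file path and printed fields illustrative:

    import torch

    ckpt = torch.load('work_dir/epoch_3.pth', map_location='cpu')
    meta = ckpt['meta']
    print(meta['experiment_name'])   # name of the run that produced the file
    print(meta['seed'])              # random seed the run was trained with
    print(meta['dataset_meta'])      # training dataset metainfo, e.g. class names
    print(meta['mmengine_version'])  # mmengine version plus git hash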
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.

-from functools import partial
 from unittest import TestCase
 from unittest.mock import patch

@@ -139,4 +138,4 @@ class TestInfiniteSampler(TestCase):

     def test_set_epoch(self):
         sampler = InfiniteSampler(self.dataset)
-        self.assertRaises(NotImplementedError, partial(sampler.set_epoch, 10))
+        sampler.set_epoch(10)
@@ -46,6 +46,10 @@ class DummyDataset(Dataset):
     data = torch.randn(12, 2)
     label = torch.ones(12)

+    @property
+    def metainfo(self):
+        return self.METAINFO
+
     def __len__(self):
         return self.data.size(0)

@@ -109,6 +109,10 @@ class ToyDataset(Dataset):
     data = torch.randn(12, 2)
     label = torch.ones(12)

+    @property
+    def metainfo(self):
+        return self.METAINFO
+
     def __len__(self):
         return self.data.size(0)

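Both test datasets expose `metainfo` through a class-level `METAINFO` attribute that is defined outside these hunks; presumably it is a plain dict along these lines (the value shown is an assumption, not taken from the diff):

    class DummyDataset(Dataset):
        METAINFO = dict(CLASSES=('cat', 'dog'))  # assumed shape, not shown in the diff

        @property
        def metainfo(self):
            return self.METAINFO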
@@ -1333,7 +1337,7 @@ class TestRunner(TestCase):
             self.runner.cur_dataloader = self.warmup_loader
             self.runner.call_hook('before_train_epoch')
             while self.runner.iter < self._max_iters:
-                data_batch = next(self.dataloader)
+                data_batch = next(self.dataloader_iterator)
                 self.run_iter(data_batch)
             self.runner.call_hook('after_train_epoch')

@@ -1404,7 +1408,11 @@ class TestRunner(TestCase):
         ckpt = torch.load(path)
         self.assertEqual(ckpt['meta']['epoch'], 3)
         self.assertEqual(ckpt['meta']['iter'], 12)
-        # self.assertEqual(ckpt['meta']['hook_msgs']['last_ckpt'], path)
+        self.assertEqual(ckpt['meta']['dataset_meta'],
+                         runner.train_dataloader.dataset.metainfo)
+        self.assertEqual(ckpt['meta']['experiment_name'],
+                         runner.experiment_name)
+        self.assertEqual(ckpt['meta']['seed'], runner.seed)
         assert isinstance(ckpt['optimizer'], dict)
         assert isinstance(ckpt['param_schedulers'], list)

@@ -1424,7 +1432,7 @@ class TestRunner(TestCase):
         self.assertIsInstance(runner.param_schedulers, list)
         self.assertIsInstance(runner.param_schedulers[0], dict)

-        # 1.3 test `resume`
+        # 1.3.1 test `resume`
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_checkpoint3'
         cfg.optim_wrapper = dict(
@@ -1441,8 +1449,38 @@ class TestRunner(TestCase):
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
         self.assertEqual(runner.param_schedulers[0].milestones, {1: 1, 2: 1})

+        # 1.3.2 test resume with unmatched dataset_meta
+        ckpt_modified = copy.deepcopy(ckpt)
+        ckpt_modified['meta']['dataset_meta'] = {'CLASSES': ['cat', 'dog']}
+        # ckpt_modified['meta']['seed'] = 123
+        path_modified = osp.join(self.temp_dir, 'modified.pth')
+        torch.save(ckpt_modified, path_modified)
+        with self.assertWarnsRegex(
+                Warning, 'The dataset metainfo from the resumed checkpoint is '
+                'different from the current training dataset, please '
+                'check the correctness of the checkpoint or the training '
+                'dataset.'):
+            runner.resume(path_modified)
+
+        # 1.3.3 test resume with unmatched seed
+        ckpt_modified = copy.deepcopy(ckpt)
+        ckpt_modified['meta']['seed'] = 123
+        path_modified = osp.join(self.temp_dir, 'modified.pth')
+        torch.save(ckpt_modified, path_modified)
+        with self.assertWarnsRegex(
+                Warning, 'The value of random seed in the checkpoint'):
+            runner.resume(path_modified)
+
+        # 1.3.4 test resume with no seed and dataset meta
+        ckpt_modified = copy.deepcopy(ckpt)
+        ckpt_modified['meta'].pop('seed')
+        ckpt_modified['meta'].pop('dataset_meta')
+        path_modified = osp.join(self.temp_dir, 'modified.pth')
+        torch.save(ckpt_modified, path_modified)
+        runner.resume(path_modified)
+
         # 1.4 test auto resume
-        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_checkpoint4'
         cfg.resume = True
         runner = Runner.from_cfg(cfg)
@@ -1454,7 +1492,7 @@ class TestRunner(TestCase):
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

         # 1.5 test resume from a specified checkpoint
-        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_checkpoint5'
         cfg.resume = True
         cfg.load_from = osp.join(self.temp_dir, 'epoch_1.pth')
|
Loading…
x
Reference in New Issue
Block a user