Mirror of https://github.com/open-mmlab/mmengine.git
[Enhance] Enhance checkpoint meta info. (#279)
commit ad965a5309 (parent 538ff48aec)
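In brief, this commit enriches the `meta` dict that `Runner.save_checkpoint` writes and re-validates parts of it in `Runner.resume`. A rough sketch of the resulting structure, with made-up placeholder values (only the keys are taken from the diff below):

# Illustrative only: keys mirror the meta.update(...) call in the runner diff,
# values here are placeholders, not real output.
meta = {
    'cfg': '<pretty-printed config text>',        # self.cfg.pretty_text
    'dataset_meta': {'CLASSES': ('cat', 'dog')},  # train_dataloader.dataset.metainfo
    'seed': 2022,                                 # self.seed
    'experiment_name': 'example_exp',             # self.experiment_name
    'time': '20220601_120000',                    # time.strftime('%Y%m%d_%H%M%S', ...)
    'mmengine_version': '0.1.0+abc1234',          # mmengine.__version__ + get_git_hash()
}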
@@ -10,4 +10,5 @@ from .logging import *
 from .registry import *
 from .runner import *
 from .utils import *
 from .version import __version__, version_info
+from .visualization import *
@@ -162,6 +162,4 @@ class InfiniteSampler(Sampler):
 
     def set_epoch(self, epoch: int) -> None:
         """Not supported in iteration-based runner."""
-        raise NotImplementedError(
-            'The `InfiniteSampler` is only used in iteration-based runner, '
-            "and doesn't need `set_epoch`")
+        pass
@@ -146,7 +146,8 @@ class IterBasedTrainLoop(BaseLoop):
                 f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                 'metainfo. ``dataset_meta`` in visualizer will be '
                 'None.')
-        self.dataloader = iter(self.dataloader)
+        # get the iterator of the dataloader
+        self.dataloader_iterator = iter(self.dataloader)
 
     @property
     def max_epochs(self):
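One way to read the change above: re-assigning `iter(self.dataloader)` to `self.dataloader` replaced the DataLoader with its iterator, so the loader itself (and things like its dataset) was no longer reachable through that attribute; keeping the iterator in `self.dataloader_iterator` preserves both. A minimal, self-contained sketch with a toy dataset (not code from the repository):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data purely for illustration.
dataloader = DataLoader(TensorDataset(torch.randn(4, 2)), batch_size=2)

# New pattern: keep the DataLoader and its iterator as separate objects.
dataloader_iterator = iter(dataloader)
batch = next(dataloader_iterator)
print(len(dataloader.dataset))          # the DataLoader itself is still usable

# Old pattern: the name now holds an iterator, so DataLoader attributes
# such as `dataset` are no longer accessible through it.
dataloader = iter(dataloader)
print(hasattr(dataloader, 'dataset'))   # False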
@@ -177,7 +178,7 @@ class IterBasedTrainLoop(BaseLoop):
         while self._iter < self._max_iters:
             self.runner.model.train()
 
-            data_batch = next(self.dataloader)
+            data_batch = next(self.dataloader_iterator)
             self.run_iter(data_batch)
 
             if (self.runner.val_loop is not None
@@ -34,7 +34,7 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
                                count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSORS
 from mmengine.utils import (TORCH_VERSION, digit_version,
-                            find_latest_checkpoint, is_list_of,
+                            find_latest_checkpoint, get_git_hash, is_list_of,
                             set_multi_processing, symlink)
 from mmengine.visualization import Visualizer
 from .base_loop import BaseLoop
@@ -331,6 +331,7 @@ class Runner:
         self.setup_env(env_cfg)
         # self._deterministic and self._seed will be set in the
         # `set_randomness`` method
+        self._randomness_cfg = randomness
         self.set_randomness(**randomness)
 
         if experiment_name is not None:
@@ -1796,8 +1797,26 @@ class Runner:
                 'previous training state resuming from the checkpoint '
                 'or set `enable` in `auto_scale_lr to False.')
 
-        # resume meta information meta
-        self.meta = checkpoint['meta']
+        # resume random seed
+        resumed_seed = checkpoint['meta'].get('seed', None)
+        current_seed = self._randomness_cfg.get('seed')
+        if resumed_seed is not None and resumed_seed != current_seed:
+            if current_seed is not None:
+                warnings.warn(f'The value of random seed in the '
+                              f'checkpoint "{resumed_seed}" is '
+                              f'different from the value in '
+                              f'`randomness` config "{current_seed}"')
+            self._randomness_cfg.update(seed=resumed_seed)
+            self.set_randomness(**self._randomness_cfg)
+
+        dataset_meta = checkpoint['meta'].get('dataset_meta', None)
+        if (dataset_meta is not None
+                and dataset_meta != self.train_dataloader.dataset.metainfo):
+            warnings.warn(
+                'The dataset metainfo from the resumed checkpoint is '
+                'different from the current training dataset, please '
+                'check the correctness of the checkpoint or the training '
+                'dataset.')
 
         # resume optimizer
         if 'optimizer' in checkpoint and resume_optimizer:
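The block above is the core of the resume-time validation. Below is a hedged, standalone paraphrase of that logic (not an import from mmengine), handy for seeing which combinations trigger a warning:

import warnings

def check_resumed_meta(checkpoint_meta, current_seed, current_dataset_meta):
    """Mimic the seed / dataset_meta checks added to Runner.resume."""
    resumed_seed = checkpoint_meta.get('seed', None)
    if resumed_seed is not None and resumed_seed != current_seed:
        if current_seed is not None:
            warnings.warn(f'The value of random seed in the checkpoint '
                          f'"{resumed_seed}" is different from the value in '
                          f'`randomness` config "{current_seed}"')
        current_seed = resumed_seed  # the checkpoint seed wins, as in the diff

    dataset_meta = checkpoint_meta.get('dataset_meta', None)
    if dataset_meta is not None and dataset_meta != current_dataset_meta:
        warnings.warn('The dataset metainfo from the resumed checkpoint is '
                      'different from the current training dataset, please '
                      'check the correctness of the checkpoint or the '
                      'training dataset.')
    return current_seed

# Example: a mismatched seed and a mismatched dataset meta both warn.
check_resumed_meta({'seed': 123, 'dataset_meta': {'CLASSES': ['cat']}},
                   current_seed=2022,
                   current_dataset_meta={'CLASSES': ['cat', 'dog']})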
@@ -1909,9 +1928,13 @@ class Runner:
 
         filepath = osp.join(out_dir, filename)
 
-        if hasattr(self.model, 'CLASSES') and self.model.CLASSES is not None:
-            # save class name to the meta
-            meta.update(CLASSES=self.model.CLASSES)
+        meta.update(
+            cfg=self.cfg.pretty_text,
+            dataset_meta=self.train_dataloader.dataset.metainfo,
+            seed=self.seed,
+            experiment_name=self.experiment_name,
+            time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
+            mmengine_version=mmengine.__version__ + get_git_hash())
 
         if is_model_wrapper(self.model):
             model = self.model.module
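A small usage sketch of what the richer checkpoint exposes once saved; the path is illustrative and assumes an epoch-based run like the one exercised in the tests below:

import torch

# Hypothetical checkpoint path; any checkpoint written after this change works.
ckpt = torch.load('work_dir/epoch_3.pth', map_location='cpu')

print(ckpt['meta']['seed'])              # random seed used for the run
print(ckpt['meta']['experiment_name'])   # Runner.experiment_name
print(ckpt['meta']['dataset_meta'])      # metainfo of the training dataset
print(ckpt['meta']['mmengine_version'])  # e.g. '0.x.y+<git hash>'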
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
 from unittest import TestCase
 from unittest.mock import patch
 
@@ -139,4 +138,4 @@ class TestInfiniteSampler(TestCase):
 
     def test_set_epoch(self):
         sampler = InfiniteSampler(self.dataset)
-        self.assertRaises(NotImplementedError, partial(sampler.set_epoch, 10))
+        sampler.set_epoch(10)
@@ -46,6 +46,10 @@ class DummyDataset(Dataset):
     data = torch.randn(12, 2)
     label = torch.ones(12)
 
+    @property
+    def metainfo(self):
+        return self.METAINFO
+
     def __len__(self):
         return self.data.size(0)
 
@@ -109,6 +109,10 @@ class ToyDataset(Dataset):
     data = torch.randn(12, 2)
     label = torch.ones(12)
 
+    @property
+    def metainfo(self):
+        return self.METAINFO
+
     def __len__(self):
         return self.data.size(0)
 
@@ -1333,7 +1337,7 @@ class TestRunner(TestCase):
                 self.runner.cur_dataloader = self.warmup_loader
                 self.runner.call_hook('before_train_epoch')
                 while self.runner.iter < self._max_iters:
-                    data_batch = next(self.dataloader)
+                    data_batch = next(self.dataloader_iterator)
                     self.run_iter(data_batch)
                 self.runner.call_hook('after_train_epoch')
 
@@ -1404,7 +1408,11 @@ class TestRunner(TestCase):
         ckpt = torch.load(path)
         self.assertEqual(ckpt['meta']['epoch'], 3)
         self.assertEqual(ckpt['meta']['iter'], 12)
-        # self.assertEqual(ckpt['meta']['hook_msgs']['last_ckpt'], path)
+        self.assertEqual(ckpt['meta']['dataset_meta'],
+                         runner.train_dataloader.dataset.metainfo)
+        self.assertEqual(ckpt['meta']['experiment_name'],
+                         runner.experiment_name)
+        self.assertEqual(ckpt['meta']['seed'], runner.seed)
         assert isinstance(ckpt['optimizer'], dict)
         assert isinstance(ckpt['param_schedulers'], list)
 
@@ -1424,7 +1432,7 @@ class TestRunner(TestCase):
         self.assertIsInstance(runner.param_schedulers, list)
         self.assertIsInstance(runner.param_schedulers[0], dict)
 
-        # 1.3 test `resume`
+        # 1.3.1 test `resume`
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_checkpoint3'
         cfg.optim_wrapper = dict(
@@ -1441,8 +1449,38 @@ class TestRunner(TestCase):
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
         self.assertEqual(runner.param_schedulers[0].milestones, {1: 1, 2: 1})
 
+        # 1.3.2 test resume with unmatched dataset_meta
+        ckpt_modified = copy.deepcopy(ckpt)
+        ckpt_modified['meta']['dataset_meta'] = {'CLASSES': ['cat', 'dog']}
+        # ckpt_modified['meta']['seed'] = 123
+        path_modified = osp.join(self.temp_dir, 'modified.pth')
+        torch.save(ckpt_modified, path_modified)
+        with self.assertWarnsRegex(
+                Warning, 'The dataset metainfo from the resumed checkpoint is '
+                'different from the current training dataset, please '
+                'check the correctness of the checkpoint or the training '
+                'dataset.'):
+            runner.resume(path_modified)
+
+        # 1.3.3 test resume with unmatched seed
+        ckpt_modified = copy.deepcopy(ckpt)
+        ckpt_modified['meta']['seed'] = 123
+        path_modified = osp.join(self.temp_dir, 'modified.pth')
+        torch.save(ckpt_modified, path_modified)
+        with self.assertWarnsRegex(
+                Warning, 'The value of random seed in the checkpoint'):
+            runner.resume(path_modified)
+
+        # 1.3.3 test resume with no seed and dataset meta
+        ckpt_modified = copy.deepcopy(ckpt)
+        ckpt_modified['meta'].pop('seed')
+        ckpt_modified['meta'].pop('dataset_meta')
+        path_modified = osp.join(self.temp_dir, 'modified.pth')
+        torch.save(ckpt_modified, path_modified)
+        runner.resume(path_modified)
+
         # 1.4 test auto resume
-        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_checkpoint4'
         cfg.resume = True
         runner = Runner.from_cfg(cfg)
@@ -1454,7 +1492,7 @@ class TestRunner(TestCase):
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
 
         # 1.5 test resume from a specified checkpoint
-        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_checkpoint5'
         cfg.resume = True
         cfg.load_from = osp.join(self.temp_dir, 'epoch_1.pth')