[Enhance] Enhance checkpoint meta info. (#279)

RangiLyu 2022-06-07 18:48:50 +08:00 committed by GitHub
parent 538ff48aec
commit ad965a5309
7 changed files with 82 additions and 18 deletions

View File

@@ -10,4 +10,5 @@ from .logging import *
from .registry import *
from .runner import *
from .utils import *
from .version import __version__, version_info
from .visualization import *
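
Not part of the diff: a quick sketch of the version attributes that the package now re-exports; the printed values are purely illustrative.

import mmengine

# __version__ is the release string; version_info is its parsed tuple form.
print(mmengine.__version__)
print(mmengine.version_info)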

View File

@@ -162,6 +162,4 @@ class InfiniteSampler(Sampler):
def set_epoch(self, epoch: int) -> None:
"""Not supported in iteration-based runner."""
raise NotImplementedError(
'The `InfiniteSampler` is only used in iteration-based runner, '
"and doesn't need `set_epoch`")
pass
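
With the change above, `set_epoch` on `InfiniteSampler` becomes a harmless no-op instead of raising, so epoch-oriented callers no longer need to special-case it. A minimal sketch of the calling pattern, assuming `InfiniteSampler` is importable from `mmengine.dataset` and using a toy dataset defined only for illustration:

from torch.utils.data import Dataset

from mmengine.dataset import InfiniteSampler


class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx


sampler = InfiniteSampler(ToyDataset())
for epoch in range(3):
    # Previously this raised NotImplementedError; now it silently does
    # nothing, because shuffling is handled inside the infinite generator.
    sampler.set_epoch(epoch)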

View File

@@ -146,7 +146,8 @@ class IterBasedTrainLoop(BaseLoop):
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in visualizer will be '
'None.')
self.dataloader = iter(self.dataloader)
# get the iterator of the dataloader
self.dataloader_iterator = iter(self.dataloader)
@property
def max_epochs(self):
@@ -177,7 +178,7 @@ class IterBasedTrainLoop(BaseLoop):
while self._iter < self._max_iters:
self.runner.model.train()
data_batch = next(self.dataloader)
data_batch = next(self.dataloader_iterator)
self.run_iter(data_batch)
if (self.runner.val_loop is not None

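A sketch (outside the diff) of why the loop now stores the iterator under a separate name: reassigning `self.dataloader = iter(self.dataloader)` replaced the `DataLoader` with its iterator, losing access to `self.dataloader.dataset`, while a dedicated `dataloader_iterator` keeps both reachable. The class below is a simplified stand-in for `IterBasedTrainLoop`, not the real implementation:

import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyIterLoop:
    """Simplified stand-in for IterBasedTrainLoop."""

    def __init__(self, dataloader: DataLoader):
        # Keep the DataLoader itself so the dataset (and its metainfo)
        # stays reachable after construction.
        self.dataloader = dataloader
        # Hold the iterator separately for next() calls during training.
        self.dataloader_iterator = iter(self.dataloader)

    def next_batch(self):
        return next(self.dataloader_iterator)


loader = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=2)
loop = ToyIterLoop(loader)
batch = loop.next_batch()
assert loop.dataloader.dataset is not None  # original dataset still reachable
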
View File

@@ -34,7 +34,7 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
count_registered_modules)
from mmengine.registry.root import LOG_PROCESSORS
from mmengine.utils import (TORCH_VERSION, digit_version,
find_latest_checkpoint, is_list_of,
find_latest_checkpoint, get_git_hash, is_list_of,
set_multi_processing, symlink)
from mmengine.visualization import Visualizer
from .base_loop import BaseLoop
@@ -331,6 +331,7 @@ class Runner:
self.setup_env(env_cfg)
# self._deterministic and self._seed will be set in the
# `set_randomness` method
self._randomness_cfg = randomness
self.set_randomness(**randomness)
if experiment_name is not None:
@@ -1796,8 +1797,26 @@ class Runner:
'previous training state resuming from the checkpoint '
'or set `enable` in `auto_scale_lr` to False.')
# resume meta information
self.meta = checkpoint['meta']
# resume random seed
resumed_seed = checkpoint['meta'].get('seed', None)
current_seed = self._randomness_cfg.get('seed')
if resumed_seed is not None and resumed_seed != current_seed:
if current_seed is not None:
warnings.warn(f'The value of random seed in the '
f'checkpoint "{resumed_seed}" is '
f'different from the value in '
f'`randomness` config "{current_seed}"')
self._randomness_cfg.update(seed=resumed_seed)
self.set_randomness(**self._randomness_cfg)
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
if (dataset_meta is not None
and dataset_meta != self.train_dataloader.dataset.metainfo):
warnings.warn(
'The dataset metainfo from the resumed checkpoint is '
'different from the current training dataset, please '
'check the correctness of the checkpoint or the training '
'dataset.')
# resume optimizer
if 'optimizer' in checkpoint and resume_optimizer:
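
The resume branch above boils down to two consistency checks. Below is a hedged, standalone restatement, not the Runner method itself; `meta`, `randomness_cfg` and `dataset_metainfo` stand in for the checkpoint meta dict and runner attributes used in the diff (the real method additionally re-applies the updated randomness config via `set_randomness`):

import warnings


def check_resumed_meta(meta: dict, randomness_cfg: dict, dataset_metainfo):
    """Mirror the consistency checks performed when resuming a checkpoint."""
    resumed_seed = meta.get('seed', None)
    current_seed = randomness_cfg.get('seed')
    if resumed_seed is not None and resumed_seed != current_seed:
        if current_seed is not None:
            warnings.warn(
                f'The value of random seed in the checkpoint '
                f'"{resumed_seed}" is different from the value in '
                f'`randomness` config "{current_seed}"')
        # Prefer the seed stored in the checkpoint so the resumed run
        # reproduces the original sampling order.
        randomness_cfg.update(seed=resumed_seed)

    dataset_meta = meta.get('dataset_meta', None)
    if dataset_meta is not None and dataset_meta != dataset_metainfo:
        warnings.warn(
            'The dataset metainfo from the resumed checkpoint is different '
            'from the current training dataset, please check the '
            'correctness of the checkpoint or the training dataset.')
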
@@ -1909,9 +1928,13 @@ class Runner:
filepath = osp.join(out_dir, filename)
if hasattr(self.model, 'CLASSES') and self.model.CLASSES is not None:
# save class name to the meta
meta.update(CLASSES=self.model.CLASSES)
meta.update(
cfg=self.cfg.pretty_text,
dataset_meta=self.train_dataloader.dataset.metainfo,
seed=self.seed,
experiment_name=self.experiment_name,
time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
mmengine_version=mmengine.__version__ + get_git_hash())
if is_model_wrapper(self.model):
model = self.model.module
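
After this change a saved checkpoint carries the extra meta fields written above. A minimal inspection sketch; the checkpoint path is hypothetical:

import torch

# Hypothetical path to a checkpoint written by Runner.save_checkpoint.
ckpt = torch.load('work_dirs/example/epoch_1.pth', map_location='cpu')

meta = ckpt['meta']
# New keys written by this commit, alongside the existing epoch/iter info.
for key in ('cfg', 'dataset_meta', 'seed', 'experiment_name', 'time',
            'mmengine_version'):
    print(key, '->', type(meta.get(key)))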

View File

@@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from unittest import TestCase
from unittest.mock import patch
@@ -139,4 +138,4 @@ class TestInfiniteSampler(TestCase):
def test_set_epoch(self):
sampler = InfiniteSampler(self.dataset)
self.assertRaises(NotImplementedError, partial(sampler.set_epoch, 10))
sampler.set_epoch(10)

View File

@@ -46,6 +46,10 @@ class DummyDataset(Dataset):
data = torch.randn(12, 2)
label = torch.ones(12)
@property
def metainfo(self):
return self.METAINFO
def __len__(self):
return self.data.size(0)
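
The test datasets expose their class-level `METAINFO` through a `metainfo` property so the runner can read `dataloader.dataset.metainfo` when saving or resuming checkpoints. A minimal sketch of the pattern; the `METAINFO` value and `__getitem__` body are illustrative, the real fixtures define their own:

import torch
from torch.utils.data import Dataset


class DummyDataset(Dataset):
    # Illustrative metainfo; the real test fixture defines its own value.
    METAINFO = dict(CLASSES=('cat', 'dog'))

    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return dict(inputs=self.data[index], data_sample=self.label[index])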

View File

@@ -109,6 +109,10 @@ class ToyDataset(Dataset):
data = torch.randn(12, 2)
label = torch.ones(12)
@property
def metainfo(self):
return self.METAINFO
def __len__(self):
return self.data.size(0)
@@ -1333,7 +1337,7 @@ class TestRunner(TestCase):
self.runner.cur_dataloader = self.warmup_loader
self.runner.call_hook('before_train_epoch')
while self.runner.iter < self._max_iters:
data_batch = next(self.dataloader)
data_batch = next(self.dataloader_iterator)
self.run_iter(data_batch)
self.runner.call_hook('after_train_epoch')
@@ -1404,7 +1408,11 @@ class TestRunner(TestCase):
ckpt = torch.load(path)
self.assertEqual(ckpt['meta']['epoch'], 3)
self.assertEqual(ckpt['meta']['iter'], 12)
# self.assertEqual(ckpt['meta']['hook_msgs']['last_ckpt'], path)
self.assertEqual(ckpt['meta']['dataset_meta'],
runner.train_dataloader.dataset.metainfo)
self.assertEqual(ckpt['meta']['experiment_name'],
runner.experiment_name)
self.assertEqual(ckpt['meta']['seed'], runner.seed)
assert isinstance(ckpt['optimizer'], dict)
assert isinstance(ckpt['param_schedulers'], list)
@@ -1424,7 +1432,7 @@ class TestRunner(TestCase):
self.assertIsInstance(runner.param_schedulers, list)
self.assertIsInstance(runner.param_schedulers[0], dict)
# 1.3 test `resume`
# 1.3.1 test `resume`
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint3'
cfg.optim_wrapper = dict(
@@ -1441,8 +1449,38 @@ class TestRunner(TestCase):
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
self.assertEqual(runner.param_schedulers[0].milestones, {1: 1, 2: 1})
# 1.3.2 test resume with unmatched dataset_meta
ckpt_modified = copy.deepcopy(ckpt)
ckpt_modified['meta']['dataset_meta'] = {'CLASSES': ['cat', 'dog']}
# ckpt_modified['meta']['seed'] = 123
path_modified = osp.join(self.temp_dir, 'modified.pth')
torch.save(ckpt_modified, path_modified)
with self.assertWarnsRegex(
Warning, 'The dataset metainfo from the resumed checkpoint is '
'different from the current training dataset, please '
'check the correctness of the checkpoint or the training '
'dataset.'):
runner.resume(path_modified)
# 1.3.3 test resume with unmatched seed
ckpt_modified = copy.deepcopy(ckpt)
ckpt_modified['meta']['seed'] = 123
path_modified = osp.join(self.temp_dir, 'modified.pth')
torch.save(ckpt_modified, path_modified)
with self.assertWarnsRegex(
Warning, 'The value of random seed in the checkpoint'):
runner.resume(path_modified)
# 1.3.4 test resume with no seed and dataset meta
ckpt_modified = copy.deepcopy(ckpt)
ckpt_modified['meta'].pop('seed')
ckpt_modified['meta'].pop('dataset_meta')
path_modified = osp.join(self.temp_dir, 'modified.pth')
torch.save(ckpt_modified, path_modified)
runner.resume(path_modified)
# 1.4 test auto resume
cfg = copy.deepcopy(self.iter_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint4'
cfg.resume = True
runner = Runner.from_cfg(cfg)
@@ -1454,7 +1492,7 @@ class TestRunner(TestCase):
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 1.5 test resume from a specified checkpoint
cfg = copy.deepcopy(self.iter_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint5'
cfg.resume = True
cfg.load_from = osp.join(self.temp_dir, 'epoch_1.pth')