From ad965a5309e89f887c38adda5053fd1ce3b13571 Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Tue, 7 Jun 2022 18:48:50 +0800
Subject: [PATCH] [Enhance] Enhance checkpoint meta info. (#279)

---
 mmengine/__init__.py             |  1 +
 mmengine/data/sampler.py         |  4 +--
 mmengine/runner/loops.py         |  5 ++--
 mmengine/runner/runner.py        | 35 +++++++++++++++++++----
 tests/test_data/test_sampler.py  |  3 +-
 tests/test_hook/test_ema_hook.py |  4 +++
 tests/test_runner/test_runner.py | 48 ++++++++++++++++++++++++++++----
 7 files changed, 82 insertions(+), 18 deletions(-)

diff --git a/mmengine/__init__.py b/mmengine/__init__.py
index 586ddf57..5e10018c 100644
--- a/mmengine/__init__.py
+++ b/mmengine/__init__.py
@@ -10,4 +10,5 @@ from .logging import *
 from .registry import *
 from .runner import *
 from .utils import *
+from .version import __version__, version_info
 from .visualization import *
diff --git a/mmengine/data/sampler.py b/mmengine/data/sampler.py
index 23d928a9..95e8e2da 100644
--- a/mmengine/data/sampler.py
+++ b/mmengine/data/sampler.py
@@ -162,6 +162,4 @@ class InfiniteSampler(Sampler):
 
     def set_epoch(self, epoch: int) -> None:
         """Not supported in iteration-based runner."""
-        raise NotImplementedError(
-            'The `InfiniteSampler` is only used in iteration-based runner, '
-            "and doesn't need `set_epoch`")
+        pass
diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index fc2634b3..482207dd 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -146,7 +146,8 @@ class IterBasedTrainLoop(BaseLoop):
                 f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                 'metainfo. ``dataset_meta`` in visualizer will be '
                 'None.')
-        self.dataloader = iter(self.dataloader)
+        # get the iterator of the dataloader
+        self.dataloader_iterator = iter(self.dataloader)
 
     @property
     def max_epochs(self):
@@ -177,7 +178,7 @@ class IterBasedTrainLoop(BaseLoop):
         while self._iter < self._max_iters:
             self.runner.model.train()
 
-            data_batch = next(self.dataloader)
+            data_batch = next(self.dataloader_iterator)
             self.run_iter(data_batch)
 
             if (self.runner.val_loop is not None
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 13a0754c..aff5c6ba 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -34,7 +34,7 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
                                count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSORS
 from mmengine.utils import (TORCH_VERSION, digit_version,
-                            find_latest_checkpoint, is_list_of,
+                            find_latest_checkpoint, get_git_hash, is_list_of,
                             set_multi_processing, symlink)
 from mmengine.visualization import Visualizer
 from .base_loop import BaseLoop
@@ -331,6 +331,7 @@ class Runner:
         self.setup_env(env_cfg)
         # self._deterministic and self._seed will be set in the
         # `set_randomness`` method
+        self._randomness_cfg = randomness
         self.set_randomness(**randomness)
 
         if experiment_name is not None:
@@ -1796,8 +1797,26 @@ class Runner:
                     'previous training state resuming from the checkpoint '
                     'or set `enable` in `auto_scale_lr to False.')
 
-        # resume meta information meta
-        self.meta = checkpoint['meta']
+        # resume random seed
+        resumed_seed = checkpoint['meta'].get('seed', None)
+        current_seed = self._randomness_cfg.get('seed')
+        if resumed_seed is not None and resumed_seed != current_seed:
+            if current_seed is not None:
+                warnings.warn(f'The value of random seed in the '
+                              f'checkpoint "{resumed_seed}" is '
+                              f'different from the value in '
+                              f'`randomness` config "{current_seed}"')
+            self._randomness_cfg.update(seed=resumed_seed)
+            self.set_randomness(**self._randomness_cfg)
+
+        dataset_meta = checkpoint['meta'].get('dataset_meta', None)
+        if (dataset_meta is not None
+                and dataset_meta != self.train_dataloader.dataset.metainfo):
+            warnings.warn(
+                'The dataset metainfo from the resumed checkpoint is '
+                'different from the current training dataset, please '
+                'check the correctness of the checkpoint or the training '
+                'dataset.')
 
         # resume optimizer
         if 'optimizer' in checkpoint and resume_optimizer:
@@ -1909,9 +1928,13 @@
 
         filepath = osp.join(out_dir, filename)
 
-        if hasattr(self.model, 'CLASSES') and self.model.CLASSES is not None:
-            # save class name to the meta
-            meta.update(CLASSES=self.model.CLASSES)
+        meta.update(
+            cfg=self.cfg.pretty_text,
+            dataset_meta=self.train_dataloader.dataset.metainfo,
+            seed=self.seed,
+            experiment_name=self.experiment_name,
+            time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
+            mmengine_version=mmengine.__version__ + get_git_hash())
 
         if is_model_wrapper(self.model):
             model = self.model.module
diff --git a/tests/test_data/test_sampler.py b/tests/test_data/test_sampler.py
index ff5fb383..846e5d51 100644
--- a/tests/test_data/test_sampler.py
+++ b/tests/test_data/test_sampler.py
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from functools import partial
 from unittest import TestCase
 from unittest.mock import patch
 
 
@@ -139,4 +138,4 @@ class TestInfiniteSampler(TestCase):
 
     def test_set_epoch(self):
         sampler = InfiniteSampler(self.dataset)
-        self.assertRaises(NotImplementedError, partial(sampler.set_epoch, 10))
+        sampler.set_epoch(10)
diff --git a/tests/test_hook/test_ema_hook.py b/tests/test_hook/test_ema_hook.py
index fa6233d0..4cae6c83 100644
--- a/tests/test_hook/test_ema_hook.py
+++ b/tests/test_hook/test_ema_hook.py
@@ -46,6 +46,10 @@ class DummyDataset(Dataset):
     data = torch.randn(12, 2)
     label = torch.ones(12)
 
+    @property
+    def metainfo(self):
+        return self.METAINFO
+
     def __len__(self):
         return self.data.size(0)
 
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 5005927a..438b840e 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -109,6 +109,10 @@ class ToyDataset(Dataset):
     data = torch.randn(12, 2)
     label = torch.ones(12)
 
+    @property
+    def metainfo(self):
+        return self.METAINFO
+
     def __len__(self):
         return self.data.size(0)
 
@@ -1333,7 +1337,7 @@
                 self.runner.cur_dataloader = self.warmup_loader
                 self.runner.call_hook('before_train_epoch')
                 while self.runner.iter < self._max_iters:
-                    data_batch = next(self.dataloader)
+                    data_batch = next(self.dataloader_iterator)
                     self.run_iter(data_batch)
                 self.runner.call_hook('after_train_epoch')
 
@@ -1404,7 +1408,11 @@
         ckpt = torch.load(path)
         self.assertEqual(ckpt['meta']['epoch'], 3)
         self.assertEqual(ckpt['meta']['iter'], 12)
-        # self.assertEqual(ckpt['meta']['hook_msgs']['last_ckpt'], path)
+        self.assertEqual(ckpt['meta']['dataset_meta'],
+                         runner.train_dataloader.dataset.metainfo)
+        self.assertEqual(ckpt['meta']['experiment_name'],
+                         runner.experiment_name)
+        self.assertEqual(ckpt['meta']['seed'], runner.seed)
         assert isinstance(ckpt['optimizer'], dict)
         assert isinstance(ckpt['param_schedulers'], list)
 
@@ -1424,7 +1432,7 @@
         self.assertIsInstance(runner.param_schedulers, list)
         self.assertIsInstance(runner.param_schedulers[0], dict)
 
-        # 1.3 test `resume`
+        # 1.3.1 test `resume`
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_checkpoint3'
         cfg.optim_wrapper = dict(
@@ -1441,8 +1449,38 @@
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
         self.assertEqual(runner.param_schedulers[0].milestones, {1: 1, 2: 1})
 
+        # 1.3.2 test resume with unmatched dataset_meta
+        ckpt_modified = copy.deepcopy(ckpt)
+        ckpt_modified['meta']['dataset_meta'] = {'CLASSES': ['cat', 'dog']}
+        # ckpt_modified['meta']['seed'] = 123
+        path_modified = osp.join(self.temp_dir, 'modified.pth')
+        torch.save(ckpt_modified, path_modified)
+        with self.assertWarnsRegex(
+                Warning, 'The dataset metainfo from the resumed checkpoint is '
+                'different from the current training dataset, please '
+                'check the correctness of the checkpoint or the training '
+                'dataset.'):
+            runner.resume(path_modified)
+
+        # 1.3.3 test resume with unmatched seed
+        ckpt_modified = copy.deepcopy(ckpt)
+        ckpt_modified['meta']['seed'] = 123
+        path_modified = osp.join(self.temp_dir, 'modified.pth')
+        torch.save(ckpt_modified, path_modified)
+        with self.assertWarnsRegex(
+                Warning, 'The value of random seed in the checkpoint'):
+            runner.resume(path_modified)
+
+        # 1.3.4 test resume with no seed and dataset meta
+        ckpt_modified = copy.deepcopy(ckpt)
+        ckpt_modified['meta'].pop('seed')
+        ckpt_modified['meta'].pop('dataset_meta')
+        path_modified = osp.join(self.temp_dir, 'modified.pth')
+        torch.save(ckpt_modified, path_modified)
+        runner.resume(path_modified)
+
         # 1.4 test auto resume
-        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_checkpoint4'
         cfg.resume = True
         runner = Runner.from_cfg(cfg)
@@ -1454,7 +1492,7 @@
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
 
         # 1.5 test resume from a specified checkpoint
-        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_checkpoint5'
         cfg.resume = True
         cfg.load_from = osp.join(self.temp_dir, 'epoch_1.pth')
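
Once a checkpoint has been saved by a runner that includes this patch, the
enriched meta can be inspected directly, which helps when reproducing or
auditing a run. A minimal sketch, not part of the patch itself: the
checkpoint path below is hypothetical, and only keys written by the
`meta.update(...)` call above (plus the pre-existing epoch/iter counters
checked in the tests) are read:

    import torch

    # Load on CPU; only the meta dict is of interest here.
    ckpt = torch.load('work_dir/epoch_3.pth', map_location='cpu')
    meta = ckpt['meta']

    print(meta['epoch'], meta['iter'])  # training progress at save time
    print(meta['seed'])                 # random seed the run was trained with
    print(meta['experiment_name'])      # experiment name of the runner
    print(meta['dataset_meta'])         # metainfo of the training dataset
    print(meta['time'])                 # save time, '%Y%m%d_%H%M%S'
    print(meta['mmengine_version'])     # __version__ with git hash appended
    print(meta['cfg'])                  # full runner config as pretty text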