mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix resume from checkpoint. (#174)
This commit is contained in:
parent
798eab4825
commit
6e4bcc997d
@ -1084,7 +1084,13 @@ class Runner:
|
||||
# decide to load from checkpoint or resume from checkpoint
|
||||
resume_from = None
|
||||
if self._resume and self._load_from is None:
|
||||
# auto resume from the latest checkpoint
|
||||
resume_from = find_latest_checkpoint(self.work_dir)
|
||||
self.logger.info(
|
||||
f'Auto resumed from the latest checkpoint {resume_from}.')
|
||||
elif self._resume and self._load_from is not None:
|
||||
# resume from the specified checkpoint
|
||||
resume_from = self._load_from
|
||||
|
||||
if resume_from is not None:
|
||||
self.resume(resume_from)
|
||||
|
@ -1075,9 +1075,34 @@ class TestRunner(TestCase):
|
||||
self.assertIsInstance(runner.optimizer, SGD)
|
||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||
|
||||
# 2. test iter based
|
||||
# 1.4 test auto resume
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_checkpoint4'
|
||||
cfg.resume = True
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.load_or_resume()
|
||||
self.assertEqual(runner.epoch, 3)
|
||||
self.assertEqual(runner.iter, 12)
|
||||
self.assertTrue(runner._has_loaded)
|
||||
self.assertIsInstance(runner.optimizer, SGD)
|
||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||
|
||||
# 1.5 test resume from a specified checkpoint
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_checkpoint5'
|
||||
cfg.resume = True
|
||||
cfg.load_from = osp.join(self.temp_dir, 'epoch_1.pth')
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.load_or_resume()
|
||||
self.assertEqual(runner.epoch, 1)
|
||||
self.assertEqual(runner.iter, 4)
|
||||
self.assertTrue(runner._has_loaded)
|
||||
self.assertIsInstance(runner.optimizer, SGD)
|
||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||
|
||||
# 2. test iter based
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_checkpoint6'
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.train()
|
||||
|
||||
@ -1096,7 +1121,7 @@ class TestRunner(TestCase):
|
||||
|
||||
# 2.2 test `load_checkpoint`
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_checkpoint5'
|
||||
cfg.experiment_name = 'test_checkpoint7'
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.load_checkpoint(path)
|
||||
self.assertEqual(runner.epoch, 0)
|
||||
@ -1105,7 +1130,7 @@ class TestRunner(TestCase):
|
||||
|
||||
# 2.3 test `resume`
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_checkpoint6'
|
||||
cfg.experiment_name = 'test_checkpoint8'
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.resume(path)
|
||||
self.assertEqual(runner.epoch, 0)
|
||||
@ -1113,3 +1138,28 @@ class TestRunner(TestCase):
|
||||
self.assertTrue(runner._has_loaded)
|
||||
self.assertIsInstance(runner.optimizer, SGD)
|
||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||
|
||||
# 2.4 test auto resume
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_checkpoint9'
|
||||
cfg.resume = True
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.load_or_resume()
|
||||
self.assertEqual(runner.epoch, 0)
|
||||
self.assertEqual(runner.iter, 12)
|
||||
self.assertTrue(runner._has_loaded)
|
||||
self.assertIsInstance(runner.optimizer, SGD)
|
||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||
|
||||
# 2.5 test resume from a specified checkpoint
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_checkpoint10'
|
||||
cfg.resume = True
|
||||
cfg.load_from = osp.join(self.temp_dir, 'iter_3.pth')
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.load_or_resume()
|
||||
self.assertEqual(runner.epoch, 0)
|
||||
self.assertEqual(runner.iter, 3)
|
||||
self.assertTrue(runner._has_loaded)
|
||||
self.assertIsInstance(runner.optimizer, SGD)
|
||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||
|
Loading…
x
Reference in New Issue
Block a user