mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix resume from checkpoint. (#174)
This commit is contained in:
parent
798eab4825
commit
6e4bcc997d
@ -1084,7 +1084,13 @@ class Runner:
|
|||||||
# decide to load from checkpoint or resume from checkpoint
|
# decide to load from checkpoint or resume from checkpoint
|
||||||
resume_from = None
|
resume_from = None
|
||||||
if self._resume and self._load_from is None:
|
if self._resume and self._load_from is None:
|
||||||
|
# auto resume from the latest checkpoint
|
||||||
resume_from = find_latest_checkpoint(self.work_dir)
|
resume_from = find_latest_checkpoint(self.work_dir)
|
||||||
|
self.logger.info(
|
||||||
|
f'Auto resumed from the latest checkpoint {resume_from}.')
|
||||||
|
elif self._resume and self._load_from is not None:
|
||||||
|
# resume from the specified checkpoint
|
||||||
|
resume_from = self._load_from
|
||||||
|
|
||||||
if resume_from is not None:
|
if resume_from is not None:
|
||||||
self.resume(resume_from)
|
self.resume(resume_from)
|
||||||
|
@ -1075,9 +1075,34 @@ class TestRunner(TestCase):
|
|||||||
self.assertIsInstance(runner.optimizer, SGD)
|
self.assertIsInstance(runner.optimizer, SGD)
|
||||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||||
|
|
||||||
# 2. test iter based
|
# 1.4 test auto resume
|
||||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
cfg.experiment_name = 'test_checkpoint4'
|
cfg.experiment_name = 'test_checkpoint4'
|
||||||
|
cfg.resume = True
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
runner.load_or_resume()
|
||||||
|
self.assertEqual(runner.epoch, 3)
|
||||||
|
self.assertEqual(runner.iter, 12)
|
||||||
|
self.assertTrue(runner._has_loaded)
|
||||||
|
self.assertIsInstance(runner.optimizer, SGD)
|
||||||
|
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||||
|
|
||||||
|
# 1.5 test resume from a specified checkpoint
|
||||||
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_checkpoint5'
|
||||||
|
cfg.resume = True
|
||||||
|
cfg.load_from = osp.join(self.temp_dir, 'epoch_1.pth')
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
runner.load_or_resume()
|
||||||
|
self.assertEqual(runner.epoch, 1)
|
||||||
|
self.assertEqual(runner.iter, 4)
|
||||||
|
self.assertTrue(runner._has_loaded)
|
||||||
|
self.assertIsInstance(runner.optimizer, SGD)
|
||||||
|
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||||
|
|
||||||
|
# 2. test iter based
|
||||||
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_checkpoint6'
|
||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
@ -1096,7 +1121,7 @@ class TestRunner(TestCase):
|
|||||||
|
|
||||||
# 2.2 test `load_checkpoint`
|
# 2.2 test `load_checkpoint`
|
||||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
cfg.experiment_name = 'test_checkpoint5'
|
cfg.experiment_name = 'test_checkpoint7'
|
||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.load_checkpoint(path)
|
runner.load_checkpoint(path)
|
||||||
self.assertEqual(runner.epoch, 0)
|
self.assertEqual(runner.epoch, 0)
|
||||||
@ -1105,7 +1130,7 @@ class TestRunner(TestCase):
|
|||||||
|
|
||||||
# 2.3 test `resume`
|
# 2.3 test `resume`
|
||||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
cfg.experiment_name = 'test_checkpoint6'
|
cfg.experiment_name = 'test_checkpoint8'
|
||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.resume(path)
|
runner.resume(path)
|
||||||
self.assertEqual(runner.epoch, 0)
|
self.assertEqual(runner.epoch, 0)
|
||||||
@ -1113,3 +1138,28 @@ class TestRunner(TestCase):
|
|||||||
self.assertTrue(runner._has_loaded)
|
self.assertTrue(runner._has_loaded)
|
||||||
self.assertIsInstance(runner.optimizer, SGD)
|
self.assertIsInstance(runner.optimizer, SGD)
|
||||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||||
|
|
||||||
|
# 2.4 test auto resume
|
||||||
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_checkpoint9'
|
||||||
|
cfg.resume = True
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
runner.load_or_resume()
|
||||||
|
self.assertEqual(runner.epoch, 0)
|
||||||
|
self.assertEqual(runner.iter, 12)
|
||||||
|
self.assertTrue(runner._has_loaded)
|
||||||
|
self.assertIsInstance(runner.optimizer, SGD)
|
||||||
|
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||||
|
|
||||||
|
# 2.5 test resume from a specified checkpoint
|
||||||
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_checkpoint10'
|
||||||
|
cfg.resume = True
|
||||||
|
cfg.load_from = osp.join(self.temp_dir, 'iter_3.pth')
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
runner.load_or_resume()
|
||||||
|
self.assertEqual(runner.epoch, 0)
|
||||||
|
self.assertEqual(runner.iter, 3)
|
||||||
|
self.assertTrue(runner._has_loaded)
|
||||||
|
self.assertIsInstance(runner.optimizer, SGD)
|
||||||
|
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user