[Fix] Fix save scheduler state dict with optim wrapper (#375)
* fix save scheduler state dict with optim wrapper
* remove for loop and inherit TestParameterScheduler
* minor refine
parent 5b065b10fd
commit 6b47035fdf
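For context, the round trip this commit is about: saving a parameter scheduler's state dict to disk and restoring it when the underlying optimizer is wrapped in an OptimWrapper. A minimal sketch, assuming the mmengine API as used in the tests below (StepParamScheduler is just an illustrative choice of scheduler):

import os.path as osp
import tempfile

import torch
import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper
from mmengine.optim.scheduler import StepParamScheduler

model = nn.Linear(2, 2)
optim_wrapper = OptimWrapper(optimizer=SGD(model.parameters(), lr=0.1, momentum=0.9))

# Build the scheduler on top of the wrapper rather than a bare Optimizer.
scheduler = StepParamScheduler(optim_wrapper, param_name='lr', step_size=2)

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt = osp.join(tmp_dir, 'scheduler.pth')
    # Serializing through torch.save is the step the new tests below exercise.
    torch.save(scheduler.state_dict(), ckpt)
    resumed = StepParamScheduler(optim_wrapper, param_name='lr', step_size=2)
    resumed.load_state_dict(torch.load(ckpt))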
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
 from contextlib import contextmanager
 from typing import Dict, List, Optional
 
@@ -7,7 +8,7 @@ import torch.nn as nn
 from torch.nn.utils import clip_grad
 from torch.optim import Optimizer
 
-from mmengine.logging import MessageHub, MMLogger
+from mmengine.logging import MessageHub, print_log
 from mmengine.registry import OPTIM_WRAPPERS
 from mmengine.utils import has_batch_norm
 
@@ -106,7 +107,6 @@ class OptimWrapper:
                 'If `clip_grad` is not None, it should be a `dict` '
                 'which is the arguments of `torch.nn.utils.clip_grad`')
             self.clip_grad_kwargs = clip_grad
-        self.logger = MMLogger.get_current_instance()
         # Used to update `grad_norm` log message.
         self.message_hub = MessageHub.get_current_instance()
         self._inner_count = 0
@@ -318,16 +318,20 @@ class OptimWrapper:
         self._inner_count = init_counts
         self._max_counts = max_counts
         if self._inner_count % self._accumulative_counts != 0:
-            self.logger.warning(
+            print_log(
                 'Resumed iteration number is not divisible by '
                 '`_accumulative_counts` in `GradientCumulativeOptimizerHook`, '
                 'which means the gradient of some iterations is lost and the '
-                'result may be influenced slightly.')
+                'result may be influenced slightly.',
+                logger='current',
+                level=logging.WARNING)
 
         if has_batch_norm(model) and self._accumulative_counts > 1:
-            self.logger.warning(
+            print_log(
                 'Gradient accumulative may slightly decrease '
-                'performance because the model has BatchNorm layers.')
+                'performance because the model has BatchNorm layers.',
+                logger='current',
+                level=logging.WARNING)
         # Remainder of `_max_counts` divided by `_accumulative_counts`
         self._remainder_counts = self._max_counts % self._accumulative_counts
 
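With self.logger gone, the wrapper reports these conditions through print_log. A minimal usage sketch of the call pattern adopted above (the arguments mirror the diff; the message text here is illustrative):

import logging

from mmengine.logging import print_log

# Route a warning through the currently active logger instead of keeping
# a logger object as an attribute of the OptimWrapper instance.
print_log('resumed iteration is not divisible by accumulative counts',
          logger='current',
          level=logging.WARNING)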
@@ -1025,13 +1025,15 @@ class OneCycleParamScheduler(_ParamScheduler):
         else:
             return [param] * len(optimizer.param_groups)
 
-    def _annealing_cos(self, start, end, pct):
+    @staticmethod
+    def _annealing_cos(start, end, pct):
         """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
 
         cos_out = math.cos(math.pi * pct) + 1
         return end + (start - end) / 2.0 * cos_out
 
-    def _annealing_linear(self, start, end, pct):
+    @staticmethod
+    def _annealing_linear(start, end, pct):
         """Linearly anneal from `start` to `end` as pct goes from 0.0 to
         1.0."""
         return (end - start) * pct + start
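Why the staticmethod change matters for saving: OneCycleParamScheduler stores its chosen anneal helper on the instance (self.anneal_func), and, judging by the test loop further below, a scheduler's state dict mirrors __dict__ minus the optimizer, so the helper ends up inside what torch.save has to pickle. A bound method drags the whole scheduler, and through it the wrapped optimizer, into the pickle; a staticmethod looked up on an instance is a plain function that pickles by name. A minimal sketch of the distinction (class and method names here are illustrative, not from the repository):

import pickle


class Sched:

    def _cos_bound(self, pct):
        # Bound method: the attribute `Sched()._cos_bound` carries a
        # reference to the instance it was looked up on.
        return pct

    @staticmethod
    def _cos_static(pct):
        # Static method: looked up on an instance it is a plain function.
        return pct


s = Sched()
print(type(s._cos_bound).__name__)   # 'method'   -> pickling it also pickles `s`
print(type(s._cos_static).__name__)  # 'function' -> pickled by qualified name only
pickle.dumps(s._cos_static)          # stays small, no instance state involved

Presumably this, together with dropping self.logger from OptimWrapper above, is what lets torch.save succeed on a scheduler state dict when the optimizer is an OptimWrapper.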
@@ -64,7 +64,6 @@ class TestOptimWrapper(MultiProcessTestCase):
         self.assertIs(optim_wrapper.optimizer, self.optimizer)
         self.assertIsNone(optim_wrapper.clip_grad_kwargs)
         self.assertEqual(optim_wrapper._accumulative_counts, 1)
-        self.assertIs(optim_wrapper.logger, self.logger)
         self.assertIs(optim_wrapper.message_hub, self.message_hub)
         self.assertEqual(optim_wrapper._inner_count, 0)
         self.assertEqual(optim_wrapper._max_counts, -1)
@@ -1,11 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+import os.path as osp
+import tempfile
 from unittest import TestCase
 
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
 
+from mmengine.optim import OptimWrapper
 # yapf: disable
 from mmengine.optim.scheduler import (ConstantParamScheduler,
                                       CosineAnnealingParamScheduler,
@@ -55,6 +58,7 @@ class TestParameterScheduler(TestCase):
             lr=lr,
             momentum=momentum,
             weight_decay=weight_decay)
+        self.temp_dir = tempfile.TemporaryDirectory()
 
     def test_base_scheduler_step(self):
         with self.assertRaises(NotImplementedError):
@@ -408,7 +412,10 @@ class TestParameterScheduler(TestCase):
             scheduler.optimizer.step()
             scheduler.step()
         scheduler_copy = construct2()
-        scheduler_copy.load_state_dict(scheduler.state_dict())
+        torch.save(scheduler.state_dict(),
+                   osp.join(self.temp_dir.name, 'tmp.pth'))
+        state_dict = torch.load(osp.join(self.temp_dir.name, 'tmp.pth'))
+        scheduler_copy.load_state_dict(state_dict)
         for key in scheduler.__dict__.keys():
             if key != 'optimizer':
                 self.assertEqual(scheduler.__dict__[key],
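Note that the test now round-trips the state dict through torch.save and torch.load instead of passing scheduler.state_dict() straight to load_state_dict: only the on-disk round trip actually exercises pickling, which is the step that can fail when the state dict holds a reference back to a wrapped optimizer.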
@@ -743,3 +750,10 @@ class TestParameterScheduler(TestCase):
             param_name='lr',
             total_steps=10,
             anneal_strategy='a')
+
+
+class TestParameterSchedulerOptimWrapper(TestParameterScheduler):
+
+    def setUp(self):
+        super().setUp()
+        self.optimizer = OptimWrapper(optimizer=self.optimizer)
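The new class re-runs the whole scheduler test suite against an OptimWrapper by overriding only setUp, which is what replaces the per-test for loop mentioned in the commit message: unittest collects every inherited test_* method and executes it with the wrapped optimizer fixture. A minimal sketch of the pattern (names here are illustrative, not from the repository):

import unittest


class BaseCase(unittest.TestCase):

    def setUp(self):
        self.optimizer = 'plain optimizer'

    def test_fixture_is_usable(self):
        self.assertTrue(self.optimizer)


class WrappedCase(BaseCase):

    def setUp(self):
        super().setUp()
        # Every test_* inherited from BaseCase now runs with this fixture.
        self.optimizer = ('wrapper', self.optimizer)


if __name__ == '__main__':
    unittest.main()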