[Fix] Fix save scheduler state dict with optim wrapper (#375)

* fix save scheduler state dict with optim wrapper

* remove for loop and inherit TestParameterScheduler

* remove for loop and inherit TestParameterScheduler

* minor refine
pull/379/head
Mashiro 2022-07-20 16:32:48 +08:00 committed by GitHub
parent 5b065b10fd
commit 6b47035fdf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 29 additions and 10 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from contextlib import contextmanager
from typing import Dict, List, Optional
@ -7,7 +8,7 @@ import torch.nn as nn
from torch.nn.utils import clip_grad
from torch.optim import Optimizer
from mmengine.logging import MessageHub, MMLogger
from mmengine.logging import MessageHub, print_log
from mmengine.registry import OPTIM_WRAPPERS
from mmengine.utils import has_batch_norm
@ -106,7 +107,6 @@ class OptimWrapper:
'If `clip_grad` is not None, it should be a `dict` '
'which is the arguments of `torch.nn.utils.clip_grad`')
self.clip_grad_kwargs = clip_grad
self.logger = MMLogger.get_current_instance()
# Used to update `grad_norm` log message.
self.message_hub = MessageHub.get_current_instance()
self._inner_count = 0
@ -318,16 +318,20 @@ class OptimWrapper:
self._inner_count = init_counts
self._max_counts = max_counts
if self._inner_count % self._accumulative_counts != 0:
self.logger.warning(
print_log(
'Resumed iteration number is not divisible by '
'`_accumulative_counts` in `GradientCumulativeOptimizerHook`, '
'which means the gradient of some iterations is lost and the '
'result may be influenced slightly.')
'result may be influenced slightly.',
logger='current',
level=logging.WARNING)
if has_batch_norm(model) and self._accumulative_counts > 1:
self.logger.warning(
print_log(
'Gradient accumulative may slightly decrease '
'performance because the model has BatchNorm layers.')
'performance because the model has BatchNorm layers.',
logger='current',
level=logging.WARNING)
# Remainder of `_max_counts` divided by `_accumulative_counts`
self._remainder_counts = self._max_counts % self._accumulative_counts

View File

@ -1025,13 +1025,15 @@ class OneCycleParamScheduler(_ParamScheduler):
else:
return [param] * len(optimizer.param_groups)
def _annealing_cos(self, start, end, pct):
@staticmethod
def _annealing_cos(start, end, pct):
    """Anneal from ``start`` to ``end`` along a half-cosine curve.

    Args:
        start: Value produced when ``pct`` is 0.0.
        end: Value produced when ``pct`` is 1.0.
        pct: Progress through the annealing phase, from 0.0 to 1.0.

    Returns:
        ``end + (start - end) / 2 * (cos(pi * pct) + 1)``.
    """
    # The cosine weight falls from 2 (pct == 0.0) to 0 (pct == 1.0).
    weight = math.cos(math.pi * pct) + 1
    half_span = (start - end) / 2.0
    return end + half_span * weight
def _annealing_linear(self, start, end, pct):
@staticmethod
def _annealing_linear(start, end, pct):
    """Anneal from ``start`` to ``end`` along a straight line.

    Args:
        start: Value produced when ``pct`` is 0.0.
        end: Value produced when ``pct`` is 1.0.
        pct: Progress through the annealing phase, from 0.0 to 1.0.

    Returns:
        ``(end - start) * pct + start``.
    """
    span = end - start
    return span * pct + start

View File

@ -64,7 +64,6 @@ class TestOptimWrapper(MultiProcessTestCase):
self.assertIs(optim_wrapper.optimizer, self.optimizer)
self.assertIsNone(optim_wrapper.clip_grad_kwargs)
self.assertEqual(optim_wrapper._accumulative_counts, 1)
self.assertIs(optim_wrapper.logger, self.logger)
self.assertIs(optim_wrapper.message_hub, self.message_hub)
self.assertEqual(optim_wrapper._inner_count, 0)
self.assertEqual(optim_wrapper._max_counts, -1)

View File

@ -1,11 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import os.path as osp
import tempfile
from unittest import TestCase
import torch
import torch.nn.functional as F
import torch.optim as optim
from mmengine.optim import OptimWrapper
# yapf: disable
from mmengine.optim.scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
@ -55,6 +58,7 @@ class TestParameterScheduler(TestCase):
lr=lr,
momentum=momentum,
weight_decay=weight_decay)
self.temp_dir = tempfile.TemporaryDirectory()
def test_base_scheduler_step(self):
with self.assertRaises(NotImplementedError):
@ -408,7 +412,10 @@ class TestParameterScheduler(TestCase):
scheduler.optimizer.step()
scheduler.step()
scheduler_copy = construct2()
scheduler_copy.load_state_dict(scheduler.state_dict())
torch.save(scheduler.state_dict(),
osp.join(self.temp_dir.name, 'tmp.pth'))
state_dict = torch.load(osp.join(self.temp_dir.name, 'tmp.pth'))
scheduler_copy.load_state_dict(state_dict)
for key in scheduler.__dict__.keys():
if key != 'optimizer':
self.assertEqual(scheduler.__dict__[key],
@ -743,3 +750,10 @@ class TestParameterScheduler(TestCase):
param_name='lr',
total_steps=10,
anneal_strategy='a')
class TestParameterSchedulerOptimWrapper(TestParameterScheduler):
    """``TestParameterScheduler`` with the optimizer held in an
    ``OptimWrapper``.

    Only ``setUp`` is overridden; everything else comes from the base
    class.
    """

    def setUp(self):
        # Build the base fixtures first, then swap the optimizer they
        # created for an ``OptimWrapper`` around it.
        super().setUp()
        wrapped = OptimWrapper(optimizer=self.optimizer)
        self.optimizer = wrapped