# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
from unittest.mock import Mock

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from mmengine.hooks import EMAHook
from mmengine.model import ExponentialMovingAverage
from mmengine.optim import OptimWrapper
from mmengine.registry import DATASETS, MODEL_WRAPPERS
from mmengine.runner import Runner

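# ToyModel and DummyDataset below are minimal stand-ins used to drive a
# Runner; TestEMAHook trains for two epochs, checks that the averaged weights
# end up in the checkpoint, and then reloads that checkpoint for testing,
# both with and without a registered model wrapper.
#
# Conceptually, the EMA model maintained by EMAHook is updated after every
# optimizer step along the lines of (a simplified sketch; mmengine's
# ExponentialMovingAverage may use a different momentum convention):
#
#   averaged = (1 - momentum) * averaged + momentum * source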
class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, data_batch, return_loss=False):
        inputs, labels = [], []
        for x in data_batch:
            inputs.append(x['inputs'])
            labels.append(x['data_sample'])

        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        inputs = torch.stack(inputs).to(device)
        labels = torch.stack(labels).to(device)
        outputs = self.linear(inputs)
        if return_loss:
            loss = (labels - outputs).sum()
            outputs = dict(loss=loss, log_vars=dict(loss=loss.item()))
            return outputs
        else:
            outputs = dict(log_vars=dict(a=1, b=0.5))
            return outputs

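# DummyDataset yields dict(inputs=<float tensor of shape (2,)>,
# data_sample=<scalar label>); each batch reaches ToyModel.forward as a list
# of such dicts, which is exactly what the loop above unpacks. An
# illustrative example (DummyDataset is defined just below):
#
#   model = ToyModel().to('cuda:0' if torch.cuda.is_available() else 'cpu')
#   batch = [DummyDataset()[i] for i in range(3)]
#   model(batch, return_loss=True)['log_vars']['loss']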
@DATASETS.register_module()
class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return dict(inputs=self.data[index], data_sample=self.label[index])

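# Because DummyDataset is registered in DATASETS above, the Runner configs
# below can refer to it by name: dict(type='DummyDataset') is resolved
# through the registry, roughly equivalent to
# DATASETS.build(dict(type='DummyDataset')).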
class TestEMAHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_ema_hook(self):
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model = ToyModel().to(device)
        evaluator = Mock()
        evaluator.evaluate = Mock(return_value=dict(acc=0.5))
        # Train for two epochs with the EMA hook enabled.
        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            val_evaluator=evaluator,
            work_dir=self.temp_dir.name,
            optim_wrapper=OptimWrapper(
                torch.optim.Adam(ToyModel().parameters())),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            val_cfg=dict(interval=1),
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook')],
            experiment_name='test1')
        runner.train()
        # After training, the EMA hook should hold an ExponentialMovingAverage
        # model and the EMA weights should have been written to the checkpoint.
        for hook in runner.hooks:
            if isinstance(hook, EMAHook):
                self.assertTrue(
                    isinstance(hook.ema_model, ExponentialMovingAverage))

        self.assertTrue(
            osp.exists(osp.join(self.temp_dir.name, 'epoch_2.pth')))
        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
        self.assertTrue('ema_state_dict' in checkpoint)
        # 12 samples with batch_size=3 give 4 iterations per epoch, so after
        # 2 epochs the EMA has been updated 8 times.
        self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)

        # Load the checkpoint (including the EMA weights) and run testing.
        runner = Runner(
            model=model,
            test_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            test_evaluator=evaluator,
            test_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook')],
            experiment_name='test2')
        runner.test()
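        # Registering DummyWrapper in MODEL_WRAPPERS makes mmengine treat it
        # as a model wrapper (the same mechanism used for
        # DistributedDataParallel), so hooks can reach the real model through
        # ``.module``.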
        @MODEL_WRAPPERS.register_module()
        class DummyWrapper(nn.Module):

            def __init__(self, model):
                super().__init__()
                self.module = model

            def forward(self, *args, **kwargs):
                return self.module(*args, **kwargs)

        # with model wrapper
        runner = Runner(
            model=DummyWrapper(model),
            test_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            test_evaluator=evaluator,
            test_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook')],
            experiment_name='test3')
        runner.test()
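
# Assuming this file lives at tests/test_hooks/test_ema_hook.py in the
# mmengine repository, the tests above can be run directly with pytest:
#
#   pytest tests/test_hooks/test_ema_hook.py -v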