225 lines
7.6 KiB
Python
225 lines
7.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import logging
|
|
import os.path as osp
|
|
import tempfile
|
|
from collections import OrderedDict
|
|
from unittest import TestCase
|
|
from unittest.mock import ANY, MagicMock, call
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmengine.device import get_device
|
|
from mmengine.evaluator import Evaluator
|
|
from mmengine.logging import MMLogger
|
|
from mmengine.model import BaseModel
|
|
from mmengine.optim import OptimWrapper
|
|
from mmengine.runner import Runner
|
|
from mmengine.testing import assert_allclose
|
|
from torch.utils.data import Dataset
|
|
|
|
from mmpretrain.engine import EMAHook
|
|
|
|
|
|
class SimpleModel(BaseModel):
    """Toy model with a single learnable one-element parameter ``para``.

    ``para`` starts at zero; the tests overwrite it from a checkpoint.
    """

    def __init__(self):
        super().__init__()
        # The only parameter of the model; initialised to a zero tensor of
        # length 1.
        self.para = nn.Parameter(torch.zeros(1))

    def forward(self, *args, mode='tensor', **kwargs):
        """Return a value derived from ``para``; all inputs are ignored.

        Args:
            mode (str): ``'loss'`` returns ``{'loss': para.mean()}``,
                ``'predict'`` returns a clone of ``para``. Any other mode
                (including the default ``'tensor'``) returns ``None``.
        """
        if mode == 'loss':
            return {'loss': self.para.mean()}
        if mode == 'predict':
            return self.para.clone()
        return None
|
|
|
|
|
|
class DummyDataset(Dataset):
    """Six random 2-element samples, each paired with the label ``1.``.

    ``data`` and ``label`` are class attributes, so every instance shares
    the same tensors.
    """

    METAINFO: dict = {}
    data = torch.randn(6, 2)
    label = torch.ones(6)

    @property
    def metainfo(self):
        # Returns the shared class-level dict itself (no copy).
        return self.METAINFO

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        sample = self.data[index]
        target = self.label[index]
        return dict(inputs=sample, data_sample=target)
|
|
|
|
|
|
class TestEMAHook(TestCase):
    """Tests for checkpoint loading and evaluation with ``EMAHook``.

    ``setUp`` saves a checkpoint whose ``state_dict`` stores ``para == 1``
    and whose ``ema_state_dict`` stores ``module.para == 2``. The assertions
    below expect 1 to be produced by the EMA model and 2 by the original
    model.
    """

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()
        state_dict = OrderedDict(
            meta=dict(epoch=1, iter=2),
            # The actual ema para
            state_dict={'para': torch.tensor([1.])},
            # The actual original para
            ema_state_dict={'module.para': torch.tensor([2.])},
        )
        self.ckpt = osp.join(self.temp_dir.name, 'ema.pth')
        torch.save(state_dict, self.ckpt)

    def tearDown(self):
        # `FileHandler` should be closed in Windows, otherwise we cannot
        # delete the temporary directory
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.temp_dir.cleanup()

    def _build_eval_runner(self, model, stage, ema_cfg, experiment_name):
        """Build a runner that evaluates ``model`` against a mocked metric.

        Args:
            model (nn.Module): Model to evaluate. The runner is given
                ``load_from=self.ckpt``.
            stage (str): ``'val'`` or ``'test'``; selects which dataloader,
                evaluator and loop-config keyword arguments are passed.
            ema_cfg (dict): Config of the ``EMAHook`` to register.
            experiment_name (str): Name for this runner; keep it unique per
                runner within a test.

        Returns:
            tuple: ``(runner, evaluator)``. The evaluator wraps a single
            ``MagicMock`` metric whose ``process`` calls can be inspected
            after running the loop.

        Raises:
            ValueError: If ``stage`` is neither ``'val'`` nor ``'test'``.
        """
        if stage not in ('val', 'test'):
            raise ValueError(f"stage must be 'val' or 'test', got {stage!r}")
        evaluator = Evaluator([MagicMock()])
        dataloader = dict(
            dataset=DummyDataset(),
            # `SimpleModel` ignores its inputs, so a fixed order does not
            # change the predictions; it just keeps every run deterministic.
            sampler=dict(type='DefaultSampler', shuffle=False),
            batch_size=3,
            num_workers=0)
        stage_kwargs = {
            f'{stage}_dataloader': dataloader,
            f'{stage}_evaluator': evaluator,
            f'{stage}_cfg': dict(),
        }
        runner = Runner(
            model=model,
            work_dir=self.temp_dir.name,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[ema_cfg],
            default_scope='mmpretrain',
            experiment_name=experiment_name,
            **stage_kwargs)
        return runner, evaluator

    def test_load_state_dict(self):
        """``para`` should equal 1 after training from ``load_from``."""
        device = get_device()
        model = SimpleModel().to(device)
        ema_hook = EMAHook()
        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            # lr=0 so the optimizer step itself cannot move `para`.
            optim_wrapper=OptimWrapper(
                optimizer=torch.optim.Adam(model.parameters(), lr=0.)),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            work_dir=self.temp_dir.name,
            resume=False,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[ema_hook],
            default_scope='mmpretrain',
            experiment_name='load_state_dict')
        runner.train()
        assert_allclose(runner.model.para, torch.tensor([1.], device=device))

    def test_evaluate_on_ema(self):
        """Check which model(s) are evaluated for each ``EMAHook`` option."""
        device = get_device()
        model = SimpleModel().to(device)
        # Metric inputs expected from the EMA model and the original model.
        ema_pred = [torch.tensor([1.]).to(device)]
        origin_pred = [torch.tensor([2.]).to(device)]

        # Test validate on ema model
        runner, evaluator = self._build_eval_runner(
            model, 'val', dict(type='EMAHook'), 'validate_on_ema')
        runner.val()
        process = evaluator.metrics[0].process
        process.assert_has_calls([call(ANY, ema_pred)])
        self.assertNotIn(call(ANY, origin_pred), process.mock_calls)

        # Test test on ema model
        runner, evaluator = self._build_eval_runner(
            model, 'test', dict(type='EMAHook'), 'test_on_ema')
        runner.test()
        process = evaluator.metrics[0].process
        process.assert_has_calls([call(ANY, ema_pred)])
        self.assertNotIn(call(ANY, origin_pred), process.mock_calls)

        # Test validate on both models
        runner, evaluator = self._build_eval_runner(
            model, 'val', dict(type='EMAHook', evaluate_on_origin=True),
            'validate_on_ema_and_origin')
        runner.val()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, ema_pred),
            call(ANY, origin_pred),
        ])

        # Test test on both models
        runner, evaluator = self._build_eval_runner(
            model, 'test', dict(type='EMAHook', evaluate_on_origin=True),
            'test_on_ema_and_origin')
        runner.test()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, ema_pred),
            call(ANY, origin_pred),
        ])

        # Test evaluate_on_ema=False: building the hook should warn about
        # `evaluate_on_origin`, and only the original model is evaluated.
        with self.assertWarnsRegex(UserWarning, 'evaluate_on_origin'):
            runner, evaluator = self._build_eval_runner(
                model, 'test', dict(type='EMAHook', evaluate_on_ema=False),
                'not_test_on_ema')
        runner.test()
        process = evaluator.metrics[0].process
        process.assert_has_calls([call(ANY, origin_pred)])
        self.assertNotIn(call(ANY, ema_pred), process.mock_calls)
|