EasyCV/tests/test_hooks/test_ema_hook.py

102 lines
2.9 KiB
Python

#! -*- coding: utf8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import time
import unittest
import numpy as np
import torch
from mmcv.parallel import MMDataParallel
from torch import nn
from easycv.datasets import build_dataloader
from easycv.file import io
from easycv.hooks.ema_hook import EMAHook
from easycv.hooks.optimizer_hook import OptimizerHook
from easycv.runner import EVRunner
from easycv.utils.logger import get_root_logger
from easycv.utils.test_util import get_tmp_dir
class DummyDataset(object):
def __getitem__(self, idx):
return torch.ones((5, 2))
def __len__(self):
return 3
def _build_model():
class Model(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 1)
nn.init.constant_(self.linear.bias, 1.0)
nn.init.constant_(self.linear.weight, 2.0)
def forward(self, x):
return self.linear(x)
def train_step(self, x, optimizer, **kwargs):
return dict(loss=self(x).squeeze()[0])
def val_step(self, x, optimizer, **kwargs):
return dict(loss=self(x).squeeze()[0])
return Model()
class EMAHookTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_ema_hook(self):
model = _build_model()
model = MMDataParallel(model, device_ids=[0])
optimizer = torch.optim.SGD(model.parameters(), lr=1, momentum=0.9)
tmp_dir = get_tmp_dir()
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = os.path.join(tmp_dir, '{}.log'.format(timestamp))
logger = get_root_logger(log_file=log_file)
runner = EVRunner(
model=model, work_dir=tmp_dir, optimizer=optimizer, logger=logger)
optimizer_config = OptimizerHook()
runner.register_optimizer_hook(optimizer_config)
hook = EMAHook(decay=0.99)
dataset = DummyDataset()
dataloader = build_dataloader(
dataset, imgs_per_gpu=1, workers_per_gpu=1)
runner.register_hook(hook)
runner.run([dataloader], [('train', 1)], 2)
state_dict = runner.model.state_dict()
self.assertEqual(
state_dict['module.linear.weight'].detach().cpu().numpy().all(),
np.asarray([[-15.8297, -15.8297]]).all())
self.assertEqual(
state_dict['module.linear.bias'].detach().cpu().numpy().all(),
np.asarray([-16.8297]).all())
ema_state_dict = runner.ema.model.state_dict()
self.assertEqual(
ema_state_dict['linear.weight'].detach().cpu().numpy().all(),
np.asarray([[-15.8158, -15.8158]]).all())
self.assertEqual(
ema_state_dict['linear.bias'].detach().cpu().numpy().all(),
np.asarray([-16.8158]).all())
io.rmtree(runner.work_dir)
if __name__ == '__main__':
unittest.main()