mmcv/tests/test_hooks.py

94 lines
2.7 KiB
Python
Raw Normal View History

import os.path as osp
import sys
from unittest.mock import MagicMock
import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import mmcv.runner
def test_pavi_hook():
sys.modules['pavi'] = MagicMock()
model = nn.Linear(1, 1)
loader = DataLoader(torch.ones((5, 5)))
work_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
runner = mmcv.runner.Runner(
model=model,
work_dir=work_dir,
batch_processor=lambda model, x, **kwargs: {
'log_vars': {
'loss': 2.333
},
'num_samples': 5
})
hook = mmcv.runner.hooks.PaviLoggerHook(
add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
assert hasattr(hook, 'writer')
hook.writer.add_scalars.assert_called_with('val', {'loss': 2.333}, 5)
hook.writer.add_snapshot_file.assert_called_with(
tag='data',
snapshot_file_path=osp.join(work_dir, 'latest.pth'),
iteration=5)
@pytest.mark.parametrize('log_model', (True, False))
def test_mlflow_hook(log_model):
sys.modules['mlflow'] = MagicMock()
sys.modules['mlflow.pytorch'] = MagicMock()
model = nn.Linear(1, 1)
loader = DataLoader(torch.ones((5, 5)))
work_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
runner = mmcv.runner.Runner(
model=model,
work_dir=work_dir,
batch_processor=lambda model, x, **kwargs: {
'log_vars': {
'accuracy': 0.98
},
'num_samples': 5
})
hook = mmcv.runner.hooks.MlflowLoggerHook(
exp_name='test', log_model=log_model)
runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
hook.mlflow.set_experiment.assert_called_with('test')
hook.mlflow.log_metrics.assert_called_with({'accuracy/val': 0.98}, step=5)
if log_model:
hook.mlflow_pytorch.log_model.assert_called_with(
runner.model, 'models')
else:
assert not hook.mlflow_pytorch.log_model.called
def test_wandb_hook():
sys.modules['wandb'] = MagicMock()
hook = mmcv.runner.hooks.WandbLoggerHook()
loader = DataLoader(torch.ones((5, 5)))
model = nn.Linear(1, 1)
runner = mmcv.runner.Runner(
model=model,
batch_processor=lambda model, x, **kwargs: {
'log_vars': {
'accuracy': 0.98
},
'num_samples': 5
})
runner.register_hook(hook)
runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
hook.wandb.init.assert_called_with()
hook.wandb.log.assert_called_with({'accuracy/val': 0.98}, step=5)
hook.wandb.join.assert_called_with()