mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
* Remove support for Python 2.7 * Fix the unit test for Python 3.5 * Add Python 3.8 to CI * Try Ubuntu 18.04 as the environment
43 lines
1.2 KiB
Python
43 lines
1.2 KiB
Python
import os.path as osp
|
|
import sys
|
|
import warnings
|
|
from unittest.mock import MagicMock
|
|
|
|
import mmcv.runner
|
|
|
|
|
|
def test_pavi_hook():
    """Check that PaviLoggerHook reports scalars and the last checkpoint.

    A ``MagicMock`` is installed as the ``pavi`` module so the calls made on
    ``hook.writer`` can be asserted without the real package. The stub is
    removed again (or any previous ``pavi`` entry restored) when the test
    finishes, because ``sys.modules`` is shared by every test in the process.

    The test is skipped with a warning when torch is not installed.
    """
    try:
        import torch
        import torch.nn as nn
        from torch.utils.data import DataLoader
    except ImportError:
        warnings.warn('Skipping test_pavi_hook in the absense of torch')
        return

    # Remember whatever 'pavi' entry existed before so it can be put back;
    # a membership test is used because None is a legal sys.modules value.
    had_pavi = 'pavi' in sys.modules
    orig_pavi = sys.modules.get('pavi')
    sys.modules['pavi'] = MagicMock()
    try:
        model = nn.Linear(1, 1)
        # Five rows; with DataLoader's default batch_size of 1 this yields
        # five batches per epoch.
        loader = DataLoader(torch.ones((5, 5)))
        work_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'data')
        runner = mmcv.runner.Runner(
            model=model,
            work_dir=work_dir,
            # Constant outputs so the logged values are known in advance.
            batch_processor=lambda model, x, **kwargs: {
                'log_vars': {
                    'loss': 2.333
                },
                'num_samples': 5
            })

        hook = mmcv.runner.hooks.PaviLoggerHook(
            add_graph=False, add_last_ckpt=True)
        runner.register_hook(hook)
        runner.run([loader, loader], [('train', 1), ('val', 1)], 1)

        # The writer's assert_called_with methods exist only because the
        # writer comes from the mocked 'pavi' module installed above.
        assert hasattr(hook, 'writer')
        hook.writer.add_scalars.assert_called_with('val', {'loss': 2.333}, 5)
        # NOTE(review): tag 'data' presumably mirrors the basename of
        # work_dir; confirm against PaviLoggerHook.
        hook.writer.add_snapshot_file.assert_called_with(
            tag='data',
            snapshot_file_path=osp.join(work_dir, 'latest.pth'),
            iteration=5)
    finally:
        # Undo only the 'pavi' key; other modules imported during the run
        # stay registered.
        if had_pavi:
            sys.modules['pavi'] = orig_pavi
        else:
            sys.modules.pop('pavi', None)