"""Tests the hooks with runners.

CommandLine:
    pytest tests/test_runner/test_hooks.py
    xdoctest tests/test_runner/test_hooks.py zero
"""
import logging
import os.path as osp
import re
import shutil
import sys
import tempfile
from unittest.mock import MagicMock, call

import pytest
import torch
import torch.nn as nn
from torch.nn.init import constant_
from torch.utils.data import DataLoader

from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
                         MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook,
                         WandbLoggerHook, build_runner)
from mmcv.runner.hooks.hook import HOOKS, Hook
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
                                          CyclicLrUpdaterHook,
                                          OneCycleLrUpdaterHook,
                                          StepLrUpdaterHook)


def test_checkpoint_hook():
    """xdoctest -m tests/test_runner/test_hooks.py test_checkpoint_hook."""
    # test epoch based runner
    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner('EpochBasedRunner', max_epochs=1)
    runner.meta = dict()
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader], [('train', 1)])
    assert runner.meta['hook_msgs']['last_ckpt'] == osp.join(
        runner.work_dir, 'epoch_1.pth')
    shutil.rmtree(runner.work_dir)

    # test iter based runner
    runner = _build_demo_runner(
        'IterBasedRunner', max_iters=1, max_epochs=None)
    runner.meta = dict()
    checkpointhook = CheckpointHook(interval=1, by_epoch=False)
    runner.register_hook(checkpointhook)
    runner.run([loader], [('train', 1)])
    assert runner.meta['hook_msgs']['last_ckpt'] == osp.join(
        runner.work_dir, 'iter_1.pth')
    shutil.rmtree(runner.work_dir)


def test_ema_hook():
    """xdoctest -m tests/test_runner/test_hooks.py test_ema_hook."""

    class DemoModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(
                in_channels=1,
                out_channels=2,
                kernel_size=1,
                padding=1,
                bias=True)
            self._init_weight()

        def _init_weight(self):
            constant_(self.conv.weight, 0)
            constant_(self.conv.bias, 0)

        def forward(self, x):
            return self.conv(x).sum()

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    loader = DataLoader(torch.ones((1, 1, 1, 1)))
    runner = _build_demo_runner()
    demo_model = DemoModel()
    runner.model = demo_model
    emahook = EMAHook(momentum=0.1, interval=2, warm_up=100, resume_from=None)
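    # EMAHook keeps shadow `ema_*` buffers next to the raw parameters and
    # updates them roughly as ema = (1 - momentum) * ema + momentum * param
    # (a sketch of the rule in the EMAHook docstring). Since
    # _build_demo_runner registers no OptimizerHook, the zero-initialized
    # weights are never updated, which the assertions below rely on.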
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(emahook, priority='HIGHEST')
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
    contain_ema_buffer = False
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            contain_ema_buffer = True
            assert value.sum() == 0
            value.fill_(1)
        else:
            assert value.sum() == 0
    assert contain_ema_buffer
    torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth')
    work_dir = runner.work_dir
    resume_ema_hook = EMAHook(
        momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
    runner = _build_demo_runner(max_epochs=2)
    runner.model = demo_model
    runner.register_hook(resume_ema_hook, priority='HIGHEST')
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
    contain_ema_buffer = False
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            contain_ema_buffer = True
            assert value.sum() == 2
        else:
            assert value.sum() == 1
    assert contain_ema_buffer
    shutil.rmtree(runner.work_dir)
    shutil.rmtree(work_dir)


def test_custom_hook():

    @HOOKS.register_module()
    class ToyHook(Hook):

        def __init__(self, info, *args, **kwargs):
            super().__init__()
            self.info = info

    runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
    # test if custom_hooks is None
    runner.register_custom_hooks(None)
    assert len(runner.hooks) == 0
    # test if custom_hooks is a list of dicts
    custom_hooks_cfg = [
        dict(type='ToyHook', priority=51, info=51),
        dict(type='ToyHook', priority=49, info=49)
    ]
    runner.register_custom_hooks(custom_hooks_cfg)
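    # runner.hooks is kept sorted by ascending priority value, so the
    # priority-49 hook (higher priority) is placed before the priority-51 one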
    assert [hook.info for hook in runner.hooks] == [49, 51]
    # test if custom_hooks is an object without a priority
    runner.register_custom_hooks(ToyHook(info='default'))
    assert len(runner.hooks) == 3 and runner.hooks[1].info == 'default'
    shutil.rmtree(runner.work_dir)

    runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
    # test register_training_hooks order
    custom_hooks_cfg = [
        dict(type='ToyHook', priority=1, info='custom 1'),
        dict(type='ToyHook', priority=89, info='custom 89')
    ]
    runner.register_training_hooks(
        lr_config=ToyHook('lr'),
        optimizer_config=ToyHook('optimizer'),
        checkpoint_config=ToyHook('checkpoint'),
        log_config=dict(interval=1, hooks=[dict(type='ToyHook', info='log')]),
        momentum_config=ToyHook('momentum'),
        timer_config=ToyHook('timer'),
        custom_hooks_config=custom_hooks_cfg)
    hooks_order = [
        'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint', 'timer',
        'custom 89', 'log'
    ]
    assert [hook.info for hook in runner.hooks] == hooks_order
    shutil.rmtree(runner.work_dir)


def test_pavi_hook():
    sys.modules['pavi'] = MagicMock()
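    # Injecting a MagicMock into sys.modules lets PaviLoggerHook import
    # `pavi` without the real package installed, and records every call
    # made on the writer for the assertions below.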

    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()
    runner.meta = dict(config_dict=dict(lr=0.02, gpu_ids=range(1)))
    hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    shutil.rmtree(runner.work_dir)

    assert hasattr(hook, 'writer')
    hook.writer.add_scalars.assert_called_with('val', {
        'learning_rate': 0.02,
        'momentum': 0.95
    }, 1)
    hook.writer.add_snapshot_file.assert_called_with(
        tag=runner.work_dir.split('/')[-1],
        snapshot_file_path=osp.join(runner.work_dir, 'epoch_1.pth'),
        iteration=1)


def test_sync_buffers_hook():
    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()
    runner.register_hook_from_cfg(dict(type='SyncBuffersHook'))
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    shutil.rmtree(runner.work_dir)


@pytest.mark.parametrize('multi_optimizers', (True, False))
def test_momentum_runner_hook(multi_optimizers):
    """xdoctest -m tests/test_runner/test_hooks.py test_momentum_runner_hook."""
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(multi_optimizers=multi_optimizers)

    # add momentum scheduler
    hook_cfg = dict(
        type='CyclicMomentumUpdaterHook',
        by_epoch=False,
        target_ratio=(0.85 / 0.95, 1),
        cyclic_times=1,
        step_ratio_up=0.4)
    runner.register_hook_from_cfg(hook_cfg)
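    # target_ratio is relative to the initial value: momentum cycles from
    # 0.95 down to 0.95 * (0.85 / 0.95) = 0.85 at the cycle peak and back,
    # while the LR hook below cycles the LR up to 10x its initial value.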

    # add LR scheduler
    hook_cfg = dict(
        type='CyclicLrUpdaterHook',
        by_epoch=False,
        target_ratio=(10, 1),
        cyclic_times=1,
        step_ratio_up=0.4)
    runner.register_hook_from_cfg(hook_cfg)
    runner.register_hook_from_cfg(dict(type='IterTimerHook'))

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    if multi_optimizers:
        calls = [
            call(
                'train', {
                    'learning_rate/model1': 0.01999999999999999,
                    'learning_rate/model2': 0.009999999999999995,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9,
                }, 1),
            call(
                'train', {
                    'learning_rate/model1': 0.2,
                    'learning_rate/model2': 0.1,
                    'momentum/model1': 0.85,
                    'momentum/model2': 0.8052631578947369,
                }, 5),
            call(
                'train', {
                    'learning_rate/model1': 0.155,
                    'learning_rate/model2': 0.0775,
                    'momentum/model1': 0.875,
                    'momentum/model2': 0.8289473684210527,
                }, 7)
        ]
    else:
        calls = [
            call('train', {
                'learning_rate': 0.01999999999999999,
                'momentum': 0.95
            }, 1),
            call('train', {
                'learning_rate': 0.2,
                'momentum': 0.85
            }, 5),
            call('train', {
                'learning_rate': 0.155,
                'momentum': 0.875
            }, 7),
        ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


@pytest.mark.parametrize('multi_optimizers', (True, False))
def test_cosine_runner_hook(multi_optimizers):
    """xdoctest -m tests/test_runner/test_hooks.py test_cosine_runner_hook."""
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(multi_optimizers=multi_optimizers)

    # add momentum scheduler
    hook_cfg = dict(
        type='CosineAnnealingMomentumUpdaterHook',
        min_momentum_ratio=0.99 / 0.95,
        by_epoch=False,
        warmup_iters=2,
        warmup_ratio=0.9 / 0.95)
    runner.register_hook_from_cfg(hook_cfg)

    # add LR scheduler
    hook_cfg = dict(
        type='CosineAnnealingLrUpdaterHook',
        by_epoch=False,
        min_lr_ratio=0,
        warmup_iters=2,
        warmup_ratio=0.9)
    runner.register_hook_from_cfg(hook_cfg)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    if multi_optimizers:
        calls = [
            call(
                'train', {
                    'learning_rate/model1': 0.02,
                    'learning_rate/model2': 0.01,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9,
                }, 1),
            call(
                'train', {
                    'learning_rate/model1': 0.01,
                    'learning_rate/model2': 0.005,
                    'momentum/model1': 0.97,
                    'momentum/model2': 0.9189473684210527,
                }, 6),
            call(
                'train', {
                    'learning_rate/model1': 0.0004894348370484647,
                    'learning_rate/model2': 0.00024471741852423234,
                    'momentum/model1': 0.9890211303259032,
                    'momentum/model2': 0.9369673866245399,
                }, 10)
        ]
    else:
        calls = [
            call('train', {
                'learning_rate': 0.02,
                'momentum': 0.95
            }, 1),
            call('train', {
                'learning_rate': 0.01,
                'momentum': 0.97
            }, 6),
            call(
                'train', {
                    'learning_rate': 0.0004894348370484647,
                    'momentum': 0.9890211303259032
                }, 10)
        ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


@pytest.mark.parametrize('multi_optimizers, max_iters', [(True, 10), (True, 2),
                                                         (False, 10),
                                                         (False, 2)])
def test_one_cycle_runner_hook(multi_optimizers, max_iters):
    """Test OneCycleLrUpdaterHook and OneCycleMomentumUpdaterHook."""
    with pytest.raises(AssertionError):
        # by_epoch should be False
        OneCycleLrUpdaterHook(max_lr=0.1, by_epoch=True)

    with pytest.raises(ValueError):
        # expected float between 0 and 1
        OneCycleLrUpdaterHook(max_lr=0.1, pct_start=-0.1)

    with pytest.raises(ValueError):
        # anneal_strategy should be either 'cos' or 'linear'
        OneCycleLrUpdaterHook(max_lr=0.1, anneal_strategy='sin')

    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(multi_optimizers=multi_optimizers)

    # add momentum scheduler
    hook_cfg = dict(
        type='OneCycleMomentumUpdaterHook',
        base_momentum=0.85,
        max_momentum=0.95,
        pct_start=0.5,
        anneal_strategy='cos',
        three_phase=False)
    runner.register_hook_from_cfg(hook_cfg)

    # add LR scheduler
    hook_cfg = dict(
        type='OneCycleLrUpdaterHook',
        max_lr=0.01,
        pct_start=0.5,
        anneal_strategy='cos',
        div_factor=25,
        final_div_factor=1e4,
        three_phase=False)
    runner.register_hook_from_cfg(hook_cfg)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    if multi_optimizers:
        calls = [
            call(
                'train', {
                    'learning_rate/model1': 0.0003999999999999993,
                    'learning_rate/model2': 0.0003999999999999993,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.95,
                }, 1),
            call(
                'train', {
                    'learning_rate/model1': 0.00904508879153485,
                    'learning_rate/model2': 0.00904508879153485,
                    'momentum/model1': 0.8595491502812526,
                    'momentum/model2': 0.8595491502812526,
                }, 6),
            call(
                'train', {
                    'learning_rate/model1': 4e-08,
                    'learning_rate/model2': 4e-08,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.95,
                }, 10)
        ]
    else:
        calls = [
            call('train', {
                'learning_rate': 0.0003999999999999993,
                'momentum': 0.95
            }, 1),
            call(
                'train', {
                    'learning_rate': 0.00904508879153485,
                    'momentum': 0.8595491502812526
                }, 6),
            call('train', {
                'learning_rate': 4e-08,
                'momentum': 0.95
            }, 10)
        ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)

    # Test OneCycleLrUpdaterHook with an explicit total_steps
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(
        runner_type='IterBasedRunner', max_epochs=None, max_iters=max_iters)

    args = dict(
        max_lr=0.01,
        total_steps=5,
        pct_start=0.5,
        anneal_strategy='linear',
        div_factor=25,
        final_div_factor=1e4,
    )
    hook = OneCycleLrUpdaterHook(**args)
    runner.register_hook(hook)
    if max_iters == 10:
        # test total_steps < max_iters
        with pytest.raises(ValueError):
            runner.run([loader], [('train', 1)])
    else:
        # test total_steps > max_iters
        runner.run([loader], [('train', 1)])
        lr_last = runner.current_lr()
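        # cross-check the hook against PyTorch's own OneCycleLR schedule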
        t = torch.tensor([0.0], requires_grad=True)
        optim = torch.optim.SGD([t], lr=0.01)
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optim, **args)
        lr_target = []
        for _ in range(max_iters):
            optim.step()
            lr_target.append(optim.param_groups[0]['lr'])
            lr_scheduler.step()
        assert lr_target[-1] == lr_last[0]


@pytest.mark.parametrize('multi_optimizers', (True, False))
def test_cosine_restart_lr_update_hook(multi_optimizers):
    """Test CosineRestartLrUpdaterHook."""
    with pytest.raises(AssertionError):
        # either `min_lr` or `min_lr_ratio` should be specified
        CosineRestartLrUpdaterHook(
            by_epoch=False,
            periods=[2, 10],
            restart_weights=[0.5, 0.5],
            min_lr=0.1,
            min_lr_ratio=0)

    with pytest.raises(AssertionError):
        # periods and restart_weights should have the same length
        CosineRestartLrUpdaterHook(
            by_epoch=False,
            periods=[2, 10],
            restart_weights=[0.5],
            min_lr_ratio=0)

    with pytest.raises(ValueError):
        # the last cumulative period 7 (out of [5, 7]) should be >= 10
        sys.modules['pavi'] = MagicMock()
        loader = DataLoader(torch.ones((10, 2)))
        runner = _build_demo_runner()

        # add cosine restart LR scheduler
        hook = CosineRestartLrUpdaterHook(
            by_epoch=False,
            periods=[5, 2],  # cumulative_periods [5, 7 (5 + 2)]
            restart_weights=[0.5, 0.5],
            min_lr=0.0001)
        runner.register_hook(hook)
        runner.register_hook(IterTimerHook())

        # add pavi hook
        hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
        runner.register_hook(hook)
        runner.run([loader], [('train', 1)])
        shutil.rmtree(runner.work_dir)

    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(multi_optimizers=multi_optimizers)

    # add cosine restart LR scheduler
    hook = CosineRestartLrUpdaterHook(
        by_epoch=False,
        periods=[5, 5],
        restart_weights=[0.5, 0.5],
        min_lr_ratio=0)
    runner.register_hook(hook)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    if multi_optimizers:
        calls = [
            call(
                'train', {
                    'learning_rate/model1': 0.01,
                    'learning_rate/model2': 0.005,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9,
                }, 1),
            call(
                'train', {
                    'learning_rate/model1': 0.01,
                    'learning_rate/model2': 0.005,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9,
                }, 6),
            call(
                'train', {
                    'learning_rate/model1': 0.0009549150281252633,
                    'learning_rate/model2': 0.00047745751406263163,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9,
                }, 10)
        ]
    else:
        calls = [
            call('train', {
                'learning_rate': 0.01,
                'momentum': 0.95
            }, 1),
            call('train', {
                'learning_rate': 0.01,
                'momentum': 0.95
            }, 6),
            call('train', {
                'learning_rate': 0.0009549150281252633,
                'momentum': 0.95
            }, 10)
        ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


@pytest.mark.parametrize('multi_optimizers', (True, False))
def test_step_runner_hook(multi_optimizers):
    """Test StepLrUpdaterHook."""
    with pytest.raises(TypeError):
        # `step` should be specified
        StepLrUpdaterHook()
    with pytest.raises(AssertionError):
        # if `step` is int, it should be positive
        StepLrUpdaterHook(-10)
    with pytest.raises(AssertionError):
        # if `step` is a list of int, all values should be positive
        StepLrUpdaterHook([10, 16, -20])

    # test StepLrUpdaterHook with int `step` value
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((30, 2)))
    runner = _build_demo_runner(multi_optimizers=multi_optimizers)

    # add momentum scheduler
    hook_cfg = dict(
        type='StepMomentumUpdaterHook',
        by_epoch=False,
        step=5,
        gamma=0.5,
        min_momentum=0.05)
    runner.register_hook_from_cfg(hook_cfg)

    # add step LR scheduler
    hook = StepLrUpdaterHook(by_epoch=False, step=5, gamma=0.5, min_lr=1e-3)
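    # with step=5 and gamma=0.5, values are halved every 5 iterations and
    # clamped from below at min_lr / min_momentum (see the expected calls)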
    runner.register_hook(hook)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    if multi_optimizers:
        calls = [
            call(
                'train', {
                    'learning_rate/model1': 0.02,
                    'learning_rate/model2': 0.01,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9
                }, 1),
            call(
                'train', {
                    'learning_rate/model1': 0.01,
                    'learning_rate/model2': 0.005,
                    'momentum/model1': 0.475,
                    'momentum/model2': 0.45
                }, 6),
            call(
                'train', {
                    'learning_rate/model1': 0.0025,
                    'learning_rate/model2': 0.00125,
                    'momentum/model1': 0.11875,
                    'momentum/model2': 0.1125
                }, 16),
            call(
                'train', {
                    'learning_rate/model1': 0.00125,
                    'learning_rate/model2': 0.001,
                    'momentum/model1': 0.059375,
                    'momentum/model2': 0.05625
                }, 21),
            call(
                'train', {
                    'learning_rate/model1': 0.001,
                    'learning_rate/model2': 0.001,
                    'momentum/model1': 0.05,
                    'momentum/model2': 0.05
                }, 26),
            call(
                'train', {
                    'learning_rate/model1': 0.001,
                    'learning_rate/model2': 0.001,
                    'momentum/model1': 0.05,
                    'momentum/model2': 0.05
                }, 30)
        ]
    else:
        calls = [
            call('train', {
                'learning_rate': 0.02,
                'momentum': 0.95
            }, 1),
            call('train', {
                'learning_rate': 0.01,
                'momentum': 0.475
            }, 6),
            call('train', {
                'learning_rate': 0.0025,
                'momentum': 0.11875
            }, 16),
            call('train', {
                'learning_rate': 0.00125,
                'momentum': 0.059375
            }, 21),
            call('train', {
                'learning_rate': 0.001,
                'momentum': 0.05
            }, 26),
            call('train', {
                'learning_rate': 0.001,
                'momentum': 0.05
            }, 30)
        ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)

    # test StepLrUpdaterHook with list[int] `step` value
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(multi_optimizers=multi_optimizers)

    # add momentum scheduler
    hook_cfg = dict(
        type='StepMomentumUpdaterHook',
        by_epoch=False,
        step=[4, 6, 8],
        gamma=0.1)
    runner.register_hook_from_cfg(hook_cfg)

    # add step LR scheduler
    hook = StepLrUpdaterHook(by_epoch=False, step=[4, 6, 8], gamma=0.1)
    runner.register_hook(hook)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    if multi_optimizers:
        calls = [
            call(
                'train', {
                    'learning_rate/model1': 0.02,
                    'learning_rate/model2': 0.01,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9
                }, 1),
            call(
                'train', {
                    'learning_rate/model1': 0.002,
                    'learning_rate/model2': 0.001,
                    'momentum/model1': 9.5e-2,
                    'momentum/model2': 9.000000000000001e-2
                }, 5),
            call(
                'train', {
                    'learning_rate/model1': 2.0000000000000004e-4,
                    'learning_rate/model2': 1.0000000000000002e-4,
                    'momentum/model1': 9.500000000000001e-3,
                    'momentum/model2': 9.000000000000003e-3
                }, 7),
            call(
                'train', {
                    'learning_rate/model1': 2.0000000000000005e-05,
                    'learning_rate/model2': 1.0000000000000003e-05,
                    'momentum/model1': 9.500000000000002e-4,
                    'momentum/model2': 9.000000000000002e-4
                }, 9)
        ]
    else:
        calls = [
            call('train', {
                'learning_rate': 0.02,
                'momentum': 0.95
            }, 1),
            call('train', {
                'learning_rate': 0.002,
                'momentum': 0.095
            }, 5),
            call(
                'train', {
                    'learning_rate': 2.0000000000000004e-4,
                    'momentum': 9.500000000000001e-3
                }, 7),
            call(
                'train', {
                    'learning_rate': 2.0000000000000005e-05,
                    'momentum': 9.500000000000002e-4
                }, 9)
        ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


@pytest.mark.parametrize('multi_optimizers, max_iters', [(True, 8),
                                                         (False, 8)])
def test_cyclic_lr_update_hook(multi_optimizers, max_iters):
    """Test CyclicLrUpdaterHook."""
    with pytest.raises(AssertionError):
        # by_epoch should be False
        CyclicLrUpdaterHook(by_epoch=True)

    with pytest.raises(AssertionError):
        # `target_ratio` must be either float or a tuple/list of two floats
        CyclicLrUpdaterHook(by_epoch=False, target_ratio=(10.0, 0.1, 0.2))

    with pytest.raises(AssertionError):
        # `step_ratio_up` must be in range [0, 1)
        CyclicLrUpdaterHook(by_epoch=False, step_ratio_up=1.4)

    with pytest.raises(ValueError):
        # anneal_strategy must be one of 'cos' or 'linear'
        CyclicLrUpdaterHook(by_epoch=False, anneal_strategy='sin')

    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(
        runner_type='IterBasedRunner',
        max_epochs=None,
        max_iters=max_iters,
        multi_optimizers=multi_optimizers)

    # add cyclic LR scheduler
    hook = CyclicLrUpdaterHook(
        by_epoch=False,
        target_ratio=(10.0, 1.0),
        cyclic_times=1,
        step_ratio_up=0.5,
        anneal_strategy='linear')
    runner.register_hook(hook)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    assert hasattr(hook, 'writer')
    if multi_optimizers:
        calls = [
            call(
                'train', {
                    'learning_rate/model1': 0.02,
                    'learning_rate/model2': 0.01,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9,
                }, 1),
            call(
                'train', {
                    'learning_rate/model1': 0.155,
                    'learning_rate/model2': 0.0775,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9,
                }, 4),
            call(
                'train', {
                    'learning_rate/model1': 0.155,
                    'learning_rate/model2': 0.0775,
                    'momentum/model1': 0.95,
                    'momentum/model2': 0.9,
                }, 6)
        ]
    else:
        calls = [
            call('train', {
                'learning_rate': 0.02,
                'momentum': 0.95
            }, 1),
            call('train', {
                'learning_rate': 0.155,
                'momentum': 0.95
            }, 4),
            call('train', {
                'learning_rate': 0.155,
                'momentum': 0.95
            }, 6),
        ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


@pytest.mark.parametrize('log_model', (True, False))
def test_mlflow_hook(log_model):
    sys.modules['mlflow'] = MagicMock()
    sys.modules['mlflow.pytorch'] = MagicMock()

    runner = _build_demo_runner()
    loader = DataLoader(torch.ones((5, 2)))

    hook = MlflowLoggerHook(exp_name='test', log_model=log_model)
    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    shutil.rmtree(runner.work_dir)

    hook.mlflow.set_experiment.assert_called_with('test')
    hook.mlflow.log_metrics.assert_called_with(
        {
            'learning_rate': 0.02,
            'momentum': 0.95
        }, step=6)
    if log_model:
        hook.mlflow_pytorch.log_model.assert_called_with(
            runner.model, 'models')
    else:
        assert not hook.mlflow_pytorch.log_model.called


def test_wandb_hook():
    sys.modules['wandb'] = MagicMock()
    runner = _build_demo_runner()
    hook = WandbLoggerHook()
    loader = DataLoader(torch.ones((5, 2)))

    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    shutil.rmtree(runner.work_dir)

    hook.wandb.init.assert_called_with()
    hook.wandb.log.assert_called_with({
        'learning_rate': 0.02,
        'momentum': 0.95
    },
                                      step=6,
                                      commit=True)
    hook.wandb.join.assert_called_with()


def test_neptune_hook():
    sys.modules['neptune'] = MagicMock()
    sys.modules['neptune.new'] = MagicMock()
    runner = _build_demo_runner()
    hook = NeptuneLoggerHook()
    loader = DataLoader(torch.ones((5, 2)))

    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    shutil.rmtree(runner.work_dir)

    hook.neptune.init.assert_called_with()
    hook.run['momentum'].log.assert_called_with(0.95, step=6)
    hook.run.stop.assert_called_with()


def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
                                    max_epochs=1,
                                    max_iters=None,
                                    multi_optimizers=False):

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 1)
            self.conv = nn.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.linear(x)

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    model = Model()

    if multi_optimizers:
        optimizer = {
            'model1':
            torch.optim.SGD(model.linear.parameters(), lr=0.02, momentum=0.95),
            'model2':
            torch.optim.SGD(model.conv.parameters(), lr=0.01, momentum=0.9),
        }
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)

    tmp_dir = tempfile.mkdtemp()
    runner = build_runner(
        dict(type=runner_type),
        default_args=dict(
            model=model,
            work_dir=tmp_dir,
            optimizer=optimizer,
            logger=logging.getLogger(),
            max_epochs=max_epochs,
            max_iters=max_iters))
    return runner


def _build_demo_runner(runner_type='EpochBasedRunner',
                       max_epochs=1,
                       max_iters=None,
                       multi_optimizers=False):

    log_config = dict(
        interval=1, hooks=[
            dict(type='TextLoggerHook'),
        ])

    runner = _build_demo_runner_without_hook(runner_type, max_epochs,
                                             max_iters, multi_optimizers)

    runner.register_checkpoint_hook(dict(interval=1))
    runner.register_logger_hooks(log_config)
    return runner


def test_runner_with_revise_keys():
    import os

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

    class PrefixModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.backbone = Model()

    pmodel = PrefixModel()
    model = Model()
    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
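    # revise_keys is a list of (pattern, replacement) pairs applied to every
    # state-dict key with re.sub while loading the checkpoint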

    # add prefix
    torch.save(model.state_dict(), checkpoint_path)
    runner = _build_demo_runner(runner_type='EpochBasedRunner')
    runner.model = pmodel
    state_dict = runner.load_checkpoint(
        checkpoint_path, revise_keys=[(r'^', 'backbone.')])
    for key in pmodel.backbone.state_dict().keys():
        assert torch.equal(pmodel.backbone.state_dict()[key], state_dict[key])
    # strip prefix
    torch.save(pmodel.state_dict(), checkpoint_path)
    runner.model = model
    state_dict = runner.load_checkpoint(
        checkpoint_path, revise_keys=[(r'^backbone\.', '')])
    for key in state_dict.keys():
        key_stripped = re.sub(r'^backbone\.', '', key)
        assert torch.equal(model.state_dict()[key_stripped], state_dict[key])
    os.remove(checkpoint_path)