"""Tests the hooks with runners. CommandLine: pytest tests/test_hooks.py xdoctest tests/test_hooks.py zero """ import logging import os.path as osp import shutil import sys import tempfile from unittest.mock import MagicMock, call import pytest import torch import torch.nn as nn from torch.nn.init import constant_ from torch.utils.data import DataLoader from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook, MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook, build_runner) from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook def test_checkpoint_hook(): """xdoctest -m tests/test_runner/test_hooks.py test_checkpoint_hook.""" # test epoch based runner loader = DataLoader(torch.ones((5, 2))) runner = _build_demo_runner('EpochBasedRunner', max_epochs=1) runner.meta = dict() checkpointhook = CheckpointHook(interval=1, by_epoch=True) runner.register_hook(checkpointhook) runner.run([loader], [('train', 1)]) assert runner.meta['hook_msgs']['last_ckpt'] == osp.join( runner.work_dir, 'epoch_1.pth') shutil.rmtree(runner.work_dir) # test iter based runner runner = _build_demo_runner( 'IterBasedRunner', max_iters=1, max_epochs=None) runner.meta = dict() checkpointhook = CheckpointHook(interval=1, by_epoch=False) runner.register_hook(checkpointhook) runner.run([loader], [('train', 1)]) assert runner.meta['hook_msgs']['last_ckpt'] == osp.join( runner.work_dir, 'iter_1.pth') shutil.rmtree(runner.work_dir) def test_ema_hook(): """xdoctest -m tests/test_hooks.py test_ema_hook.""" class DemoModel(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d( in_channels=1, out_channels=2, kernel_size=1, padding=1, bias=True) self._init_weight() def _init_weight(self): constant_(self.conv.weight, 0) constant_(self.conv.bias, 0) def forward(self, x): return self.conv(x).sum() def train_step(self, x, optimizer, **kwargs): return dict(loss=self(x)) def val_step(self, x, optimizer, **kwargs): return dict(loss=self(x)) loader = DataLoader(torch.ones((1, 1, 1, 1))) runner = _build_demo_runner() demo_model = DemoModel() runner.model = demo_model emahook = EMAHook(momentum=0.1, interval=2, warm_up=100, resume_from=None) checkpointhook = CheckpointHook(interval=1, by_epoch=True) runner.register_hook(emahook, priority='HIGHEST') runner.register_hook(checkpointhook) runner.run([loader, loader], [('train', 1), ('val', 1)]) checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth') contain_ema_buffer = False for name, value in checkpoint['state_dict'].items(): if 'ema' in name: contain_ema_buffer = True assert value.sum() == 0 value.fill_(1) else: assert value.sum() == 0 assert contain_ema_buffer torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth') work_dir = runner.work_dir resume_ema_hook = EMAHook( momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth') runner = _build_demo_runner(max_epochs=2) runner.model = demo_model runner.register_hook(resume_ema_hook, priority='HIGHEST') checkpointhook = CheckpointHook(interval=1, by_epoch=True) runner.register_hook(checkpointhook) runner.run([loader, loader], [('train', 1), ('val', 1)]) checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth') contain_ema_buffer = False for name, value in checkpoint['state_dict'].items(): if 'ema' in name: contain_ema_buffer = True assert value.sum() == 2 else: assert value.sum() == 1 assert contain_ema_buffer shutil.rmtree(runner.work_dir) shutil.rmtree(work_dir) def test_pavi_hook(): sys.modules['pavi'] = MagicMock() loader = DataLoader(torch.ones((5, 2))) runner = _build_demo_runner() runner.meta = dict(config_dict=dict(lr=0.02, gpu_ids=range(1))) hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True) runner.register_hook(hook) runner.run([loader, loader], [('train', 1), ('val', 1)]) shutil.rmtree(runner.work_dir) assert hasattr(hook, 'writer') hook.writer.add_scalars.assert_called_with('val', { 'learning_rate': 0.02, 'momentum': 0.95 }, 1) hook.writer.add_snapshot_file.assert_called_with( tag=runner.work_dir.split('/')[-1], snapshot_file_path=osp.join(runner.work_dir, 'epoch_1.pth'), iteration=1) def test_sync_buffers_hook(): loader = DataLoader(torch.ones((5, 2))) runner = _build_demo_runner() runner.register_hook_from_cfg(dict(type='SyncBuffersHook')) runner.run([loader, loader], [('train', 1), ('val', 1)]) shutil.rmtree(runner.work_dir) def test_momentum_runner_hook(): """xdoctest -m tests/test_hooks.py test_momentum_runner_hook.""" sys.modules['pavi'] = MagicMock() loader = DataLoader(torch.ones((10, 2))) runner = _build_demo_runner() # add momentum scheduler hook_cfg = dict( type='CyclicMomentumUpdaterHook', by_epoch=False, target_ratio=(0.85 / 0.95, 1), cyclic_times=1, step_ratio_up=0.4) runner.register_hook_from_cfg(hook_cfg) # add momentum LR scheduler hook_cfg = dict( type='CyclicLrUpdaterHook', by_epoch=False, target_ratio=(10, 1), cyclic_times=1, step_ratio_up=0.4) runner.register_hook_from_cfg(hook_cfg) runner.register_hook_from_cfg(dict(type='IterTimerHook')) # add pavi hook hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True) runner.register_hook(hook) runner.run([loader], [('train', 1)]) shutil.rmtree(runner.work_dir) # TODO: use a more elegant way to check values assert hasattr(hook, 'writer') calls = [ call('train', { 'learning_rate': 0.01999999999999999, 'momentum': 0.95 }, 1), call('train', { 'learning_rate': 0.2, 'momentum': 0.85 }, 5), call('train', { 'learning_rate': 0.155, 'momentum': 0.875 }, 7), ] hook.writer.add_scalars.assert_has_calls(calls, any_order=True) def test_cosine_runner_hook(): """xdoctest -m tests/test_hooks.py test_cosine_runner_hook.""" sys.modules['pavi'] = MagicMock() loader = DataLoader(torch.ones((10, 2))) runner = _build_demo_runner() # add momentum scheduler hook_cfg = dict( type='CosineAnnealingMomentumUpdaterHook', min_momentum_ratio=0.99 / 0.95, by_epoch=False, warmup_iters=2, warmup_ratio=0.9 / 0.95) runner.register_hook_from_cfg(hook_cfg) # add momentum LR scheduler hook_cfg = dict( type='CosineAnnealingLrUpdaterHook', by_epoch=False, min_lr_ratio=0, warmup_iters=2, warmup_ratio=0.9) runner.register_hook_from_cfg(hook_cfg) runner.register_hook_from_cfg(dict(type='IterTimerHook')) runner.register_hook(IterTimerHook()) # add pavi hook hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True) runner.register_hook(hook) runner.run([loader], [('train', 1)]) shutil.rmtree(runner.work_dir) # TODO: use a more elegant way to check values assert hasattr(hook, 'writer') calls = [ call('train', { 'learning_rate': 0.02, 'momentum': 0.95 }, 1), call('train', { 'learning_rate': 0.01, 'momentum': 0.97 }, 6), call('train', { 'learning_rate': 0.0004894348370484647, 'momentum': 0.9890211303259032 }, 10) ] hook.writer.add_scalars.assert_has_calls(calls, any_order=True) def test_cosine_restart_lr_update_hook(): """Test CosineRestartLrUpdaterHook.""" with pytest.raises(AssertionError): # either `min_lr` or `min_lr_ratio` should be specified CosineRestartLrUpdaterHook( by_epoch=False, periods=[2, 10], restart_weights=[0.5, 0.5], min_lr=0.1, min_lr_ratio=0) with pytest.raises(AssertionError): # periods and restart_weights should have the same length CosineRestartLrUpdaterHook( by_epoch=False, periods=[2, 10], restart_weights=[0.5], min_lr_ratio=0) with pytest.raises(ValueError): # the last cumulative_periods 7 (out of [5, 7]) should >= 10 sys.modules['pavi'] = MagicMock() loader = DataLoader(torch.ones((10, 2))) runner = _build_demo_runner() # add cosine restart LR scheduler hook = CosineRestartLrUpdaterHook( by_epoch=False, periods=[5, 2], # cumulative_periods [5, 7 (5 + 2)] restart_weights=[0.5, 0.5], min_lr=0.0001) runner.register_hook(hook) runner.register_hook(IterTimerHook()) # add pavi hook hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True) runner.register_hook(hook) runner.run([loader], [('train', 1)]) shutil.rmtree(runner.work_dir) sys.modules['pavi'] = MagicMock() loader = DataLoader(torch.ones((10, 2))) runner = _build_demo_runner() # add cosine restart LR scheduler hook = CosineRestartLrUpdaterHook( by_epoch=False, periods=[5, 5], restart_weights=[0.5, 0.5], min_lr_ratio=0) runner.register_hook(hook) runner.register_hook(IterTimerHook()) # add pavi hook hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True) runner.register_hook(hook) runner.run([loader], [('train', 1)]) shutil.rmtree(runner.work_dir) # TODO: use a more elegant way to check values assert hasattr(hook, 'writer') calls = [ call('train', { 'learning_rate': 0.01, 'momentum': 0.95 }, 1), call('train', { 'learning_rate': 0.01, 'momentum': 0.95 }, 6), call('train', { 'learning_rate': 0.0009549150281252633, 'momentum': 0.95 }, 10) ] hook.writer.add_scalars.assert_has_calls(calls, any_order=True) @pytest.mark.parametrize('log_model', (True, False)) def test_mlflow_hook(log_model): sys.modules['mlflow'] = MagicMock() sys.modules['mlflow.pytorch'] = MagicMock() runner = _build_demo_runner() loader = DataLoader(torch.ones((5, 2))) hook = MlflowLoggerHook(exp_name='test', log_model=log_model) runner.register_hook(hook) runner.run([loader, loader], [('train', 1), ('val', 1)]) shutil.rmtree(runner.work_dir) hook.mlflow.set_experiment.assert_called_with('test') hook.mlflow.log_metrics.assert_called_with( { 'learning_rate': 0.02, 'momentum': 0.95 }, step=6) if log_model: hook.mlflow_pytorch.log_model.assert_called_with( runner.model, 'models') else: assert not hook.mlflow_pytorch.log_model.called def test_wandb_hook(): sys.modules['wandb'] = MagicMock() runner = _build_demo_runner() hook = WandbLoggerHook() loader = DataLoader(torch.ones((5, 2))) runner.register_hook(hook) runner.run([loader, loader], [('train', 1), ('val', 1)]) shutil.rmtree(runner.work_dir) hook.wandb.init.assert_called_with() hook.wandb.log.assert_called_with({ 'learning_rate': 0.02, 'momentum': 0.95 }, step=6, commit=True) hook.wandb.join.assert_called_with() def _build_demo_runner(runner_type='EpochBasedRunner', max_epochs=1, max_iters=None): class Model(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(2, 1) def forward(self, x): return self.linear(x) def train_step(self, x, optimizer, **kwargs): return dict(loss=self(x)) def val_step(self, x, optimizer, **kwargs): return dict(loss=self(x)) model = Model() optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95) log_config = dict( interval=1, hooks=[ dict(type='TextLoggerHook'), ]) tmp_dir = tempfile.mkdtemp() runner = build_runner( dict(type=runner_type), default_args=dict( model=model, work_dir=tmp_dir, optimizer=optimizer, logger=logging.getLogger(), max_epochs=max_epochs, max_iters=max_iters)) runner.register_checkpoint_hook(dict(interval=1)) runner.register_logger_hooks(log_config) return runner