# Copyright (c) OpenMMLab. All rights reserved.
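# Unit test for SimSiamHook: trains a toy model for two epochs and checks
# that the hook keeps the `predictor` param group's learning rate fixed
# while the scheduler decays all other groups.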
import logging
import tempfile
from unittest import TestCase

import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.model import BaseModule
from mmengine.runner import Runner
from mmengine.structures import LabelData
from torch.utils.data import Dataset

from mmselfsup.engine import SimSiamHook
from mmselfsup.models.algorithms import BaseModel
from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
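

# DummyDataset: a 12-sample toy dataset. Each item packs a random 2-d
# input and a constant scalar label into a SelfSupDataSample, matching
# the (inputs, data_samples) format produced by default_collate.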
class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        data_sample = SelfSupDataSample()
        gt_label = LabelData(value=self.label[index])
        setattr(data_sample, 'gt_label', gt_label)
        return dict(inputs=[self.data[index]], data_samples=data_sample)
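

# Backbone stub registered in MODELS. Its only submodule is named
# `predictor`, so the optimizer's paramwise_cfg and SimSiamHook can
# match it by name.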
@MODELS.register_module()
class SimSiamDummyLayer(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.predictor = nn.Linear(2, 1)

    def forward(self, x):
        return self.predictor(x)
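

# Minimal BaseModel whose loss is a plain difference between labels and
# backbone outputs, just enough to drive real optimization steps.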
class ToyModel(BaseModel):

    def __init__(self):
        super().__init__(backbone=dict(type='SimSiamDummyLayer'))

    def loss(self, batch_inputs, data_samples):
        labels = []
        for x in data_samples:
            labels.append(x.gt_label.value)
        labels = torch.stack(labels)
        outputs = self.backbone(batch_inputs[0])
        loss = (labels - outputs).sum()
        outputs = dict(loss=loss)
        return outputs
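

# SimSiamHook implements the SimSiam trick of keeping the predictor's
# learning rate fixed at the base value while the rest of the model
# follows the scheduler; the test below trains for two epochs and then
# inspects the optimizer's param groups.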
class TestSimSiamHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        # On Windows, every `FileHandler` must be closed before the
        # temporary directory can be deleted.
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.temp_dir.cleanup()

    def test_simsiam_hook(self):
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        dummy_dataset = DummyDataset()
        toy_model = ToyModel().to(device)
        simsiam_hook = SimSiamHook(
            fix_pred_lr=True, lr=0.05, adjust_by_epoch=False)

        # test SimSiamHook: the optimizer below tags the predictor's
        # param group with `fix_lr=True`, which the hook should pin to
        # the base lr.
        runner = Runner(
            model=toy_model,
            work_dir=self.temp_dir.name,
            train_dataloader=dict(
                dataset=dummy_dataset,
                sampler=dict(type='DefaultSampler', shuffle=True),
                collate_fn=dict(type='default_collate'),
                batch_size=1,
                num_workers=0),
            optim_wrapper=dict(
                optimizer=dict(type='SGD', lr=0.05),
                paramwise_cfg=dict(
                    custom_keys={'predictor': dict(fix_lr=True)})),
            param_scheduler=dict(type='MultiStepLR', milestones=[1]),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            custom_hooks=[simsiam_hook],
            default_hooks=dict(logger=None),
            log_processor=dict(window_size=1),
            experiment_name='test_simsiam_hook',
            default_scope='mmselfsup')
        runner.train()
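
        # After two epochs, MultiStepLR (milestone at epoch 1, default
        # gamma) has decayed the lr of every ordinary param group, while
        # the hook keeps groups flagged with `fix_lr` at the base 0.05.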
        for param_group in runner.optim_wrapper.optimizer.param_groups:
            if 'fix_lr' in param_group and param_group['fix_lr']:
                assert param_group['lr'] == 0.05
            else:
                assert param_group['lr'] != 0.05