mmselfsup/tests/test_engine/test_hooks/test_simsiam_hook.py
Yixiao Fang ff5f5ae2a9 [CI] update CI package (#476)
* update circleci

* update github action

* fix windows error

* update

* update pr and merge state

* fix lint

* update

* update

* update circle ci

* fix

* fix ut
2022-10-12 19:35:44 +08:00

114 lines
3.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import logging
import tempfile
from unittest import TestCase
import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.model import BaseModule
from mmengine.runner import Runner
from mmengine.structures import LabelData
from torch.utils.data import Dataset
from mmselfsup.engine import SimSiamHook
from mmselfsup.models.algorithms import BaseModel
from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
class DummyDataset(Dataset):
    """Tiny in-memory dataset: twelve 2-d feature vectors, all labeled 1."""

    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        """Expose the (empty) dataset meta information dict."""
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        # Wrap the label in a SelfSupDataSample, mirroring what real
        # self-supervised datasets return to the runner.
        sample = SelfSupDataSample()
        setattr(sample, 'gt_label', LabelData(value=self.label[index]))
        return dict(inputs=[self.data[index]], data_samples=sample)
@MODELS.register_module()
class SimSiamDummyLayer(BaseModule):
    """Toy backbone whose only parameter group is a ``predictor`` layer.

    The attribute is deliberately named ``predictor`` so that the
    optimizer's ``paramwise_cfg.custom_keys`` in the test can match it.
    """

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.predictor = nn.Linear(2, 1)

    def forward(self, x):
        out = self.predictor(x)
        return out
class ToyModel(BaseModel):
    """Toy algorithm producing a simple label-minus-prediction loss."""

    def __init__(self):
        super().__init__(backbone=dict(type='SimSiamDummyLayer'))

    def loss(self, batch_inputs, data_samples):
        """Return ``dict(loss=...)`` summing label/prediction differences."""
        gt = torch.stack([sample.gt_label.value for sample in data_samples])
        preds = self.backbone(batch_inputs[0])
        return dict(loss=(gt - preds).sum())
class TestSimSiamHook(TestCase):
    """Integration test: ``SimSiamHook`` pins the predictor's learning rate.

    SimSiam trains the predictor head with a fixed lr while the rest of the
    model follows the scheduler; after training, groups flagged ``fix_lr``
    must still hold the initial lr and all others must have been decayed.
    """

    def setUp(self):
        # Isolated work_dir per test; cleaned up in tearDown.
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        # `FileHandler` should be closed in Windows, otherwise we cannot
        # delete the temporary directory
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.temp_dir.cleanup()

    def test_simsiam_hook(self):
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        dummy_dataset = DummyDataset()
        toy_model = ToyModel().to(device)
        # adjust_by_epoch=False: the hook restores the fixed lr every
        # iteration rather than once per epoch.
        simsiam_hook = SimSiamHook(
            fix_pred_lr=True, lr=0.05, adjust_by_epoch=False)
        # test SimSiamHook
        runner = Runner(
            model=toy_model,
            work_dir=self.temp_dir.name,
            train_dataloader=dict(
                dataset=dummy_dataset,
                sampler=dict(type='DefaultSampler', shuffle=True),
                collate_fn=dict(type='default_collate'),
                batch_size=1,
                num_workers=0),
            optim_wrapper=dict(
                optimizer=dict(type='SGD', lr=0.05),
                # Tag the predictor's param group with ``fix_lr`` so the
                # hook can identify and pin its learning rate.
                paramwise_cfg=dict(
                    custom_keys={'predictor': dict(fix_lr=True)})),
            # MultiStepLR decays the lr after epoch 1, so over 2 epochs
            # every non-fixed group must drop below the initial 0.05.
            param_scheduler=dict(type='MultiStepLR', milestones=[1]),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            custom_hooks=[simsiam_hook],
            default_hooks=dict(logger=None),
            log_processor=dict(window_size=1),
            experiment_name='test_simsiam_hook',
            default_scope='mmselfsup')
        runner.train()
        # Fixed-lr groups keep 0.05; all others were decayed by the scheduler.
        for param_group in runner.optim_wrapper.optimizer.param_groups:
            if 'fix_lr' in param_group and param_group['fix_lr']:
                assert param_group['lr'] == 0.05
            else:
                assert param_group['lr'] != 0.05