Yixiao Fang 0ea07c0750
[Refactor] Refactor hooks to support non-dist training (#397)
* apply mmengine.dist instead of torch.dist (see the sketch below)

* apply get_model to densecl_hook

* fix bug to pass unit tests

* update type hints
2022-08-08 15:01:47 +08:00
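
The refactor leans on two facts, sketched below under the assumption that mmselfsup.utils.get_model is a thin unwrapping helper (the names and bodies here are illustrative, not the exact mmselfsup implementation): mmengine.dist helpers fall back to single-process behaviour when torch.distributed is not initialized, and the trained model may or may not be wrapped for distributed training depending on the launcher.

    import torch
    from mmengine.dist import all_reduce, get_world_size
    from mmengine.model import is_model_wrapper

    def get_model(model):
        # Assumed behaviour of mmselfsup.utils.get_model: return the bare
        # module whether or not the runner wrapped it (e.g. in a
        # DistributedDataParallel wrapper) for distributed training.
        return model.module if is_model_wrapper(model) else model

    def mean_across_processes(tensor: torch.Tensor) -> torch.Tensor:
        # mmengine.dist degrades gracefully outside a process group:
        # all_reduce is a no-op and get_world_size() returns 1, so this
        # runs unchanged in non-dist training, where the torch.distributed
        # calls it replaces would fail without an initialized process group.
        all_reduce(tensor)  # defaults to a sum reduction, in-place
        return tensor / get_world_size()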


# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
from unittest import TestCase

import torch
import torch.nn as nn
from mmengine import Runner
from mmengine.data import LabelData
from mmengine.model import BaseModule
from mmengine.optim import OptimWrapper
from torch.utils.data import Dataset

from mmselfsup.engine import SwAVHook
from mmselfsup.models.algorithms import BaseModel
from mmselfsup.models.heads import SwAVHead
from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
from mmselfsup.utils import get_model
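

# A 12-sample toy dataset that yields the dict format mmselfsup dataloaders
# produce: a list of input tensors plus a SelfSupDataSample holding the label.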
class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        data_sample = SelfSupDataSample()
        gt_label = LabelData(value=self.label[index])
        setattr(data_sample, 'gt_label', gt_label)
        return dict(inputs=[self.data[index]], data_sample=data_sample)
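

# Minimal backbone registered in MODELS so that ToyModel can build it from a
# config dict.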
@MODELS.register_module()
class SwAVDummyLayer(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)
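

# Toy algorithm with a real SwAVHead; the `prototypes_test` layer gives the
# hook's frozen_layers_cfg=dict(prototypes=2) a parameter name to match.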
class ToyModel(BaseModel):

    def __init__(self):
        super().__init__(backbone=dict(type='SwAVDummyLayer'))
        self.prototypes_test = nn.Linear(1, 1)
        self.head = SwAVHead(
            loss=dict(
                type='SwAVLoss',
                feat_dim=2,
                num_crops=[2, 6],
                num_prototypes=3))

    def loss(self, batch_inputs, data_samples):
        labels = []
        for x in data_samples:
            labels.append(x.gt_label.value)
        labels = torch.stack(labels)
        outputs = self.backbone(batch_inputs[0])
        loss = (labels - outputs).sum()
        outputs = dict(loss=loss)
        return outputs
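

# End-to-end check: train for two short epochs with SwAVHook attached, then
# inspect the hook and loss state.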
class TestSwAVHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_swav_hook(self):
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        dummy_dataset = DummyDataset()
        toy_model = ToyModel().to(device)
        swav_hook = SwAVHook(
            batch_size=1,
            epoch_queue_starts=15,
            crops_for_assign=[0, 1],
            feat_dim=128,
            queue_length=300,
            frozen_layers_cfg=dict(prototypes=2))

        # test SwAVHook
        runner = Runner(
            model=toy_model,
            work_dir=self.temp_dir.name,
            train_dataloader=dict(
                dataset=dummy_dataset,
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=1,
                num_workers=0),
            optim_wrapper=OptimWrapper(
                torch.optim.Adam(toy_model.parameters())),
            param_scheduler=dict(type='MultiStepLR', milestones=[1]),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            custom_hooks=[swav_hook],
            default_hooks=dict(logger=None),
            log_processor=dict(window_size=1),
            experiment_name='test_swav_hook')
        runner.train()

        for hook in runner.hooks:
            if isinstance(hook, SwAVHook):
                assert hook.queue_length == 300
        assert get_model(runner.model).head.loss.use_queue is False
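
# Note: use_queue presumably stays False here because epoch_queue_starts=15
# is never reached within max_epochs=2; get_model() lets the same assertion
# hold whether runner.model is the bare ToyModel or a distributed wrapper,
# which is the point of the non-dist refactor this test covers.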