Yixiao Fang 08dc8c75d3
[Refactor] Add selfsup algorithms. (#1389)
* remove basehead

* add moco series

* add byol simclr simsiam

* add ut

* update configs

* add simsiam hook

* add and refactor beit

* update ut

* add cae

* update extract_feat

* refactor cae

* add mae

* refactor data preprocessor

* update heads

* add maskfeat

* add milan

* add simmim

* add mixmim

* fix lint

* fix ut

* fix lint

* add eva

* add densecl

* add barlowtwins

* add swav

* fix lint

* update readthedocs rst

* update docs

* update

* Decrease UT memory usage

* Fix docstring

* update DALLEEncoder

* Update model docs

* refactor dalle encoder

* update docstring

* fix ut

* fix config error

* add val_cfg and test_cfg

* refactor clip generator

* fix lint

* pass check

* fix ut

* add lars

* update type of BEiT in configs

* Use MMEngine style momentum in EMA.

* apply mmpretrain solarize

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-03-06 16:53:15 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
import logging
import tempfile
from unittest import TestCase

import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.model import BaseModule
from mmengine.optim import OptimWrapper
from mmengine.runner import Runner
from mmengine.structures import LabelData
from torch.utils.data import Dataset

from mmpretrain.engine import SwAVHook
from mmpretrain.models.heads import SwAVHead
from mmpretrain.models.selfsup import BaseSelfSupervisor
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from mmpretrain.utils import get_ori_model
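

# DummyDataset: a 12-sample toy dataset whose items mimic the structure the
# runner expects -- a list of input tensors plus a `DataSample` carrying the
# ground-truth label.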
class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        data_sample = DataSample()
        gt_label = LabelData(value=self.label[index])
        setattr(data_sample, 'gt_label', gt_label)
        return dict(inputs=[self.data[index]], data_samples=data_sample)
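

# SwAVDummyLayer is registered in MODELS so that ToyModel below can build it
# from the config dict `dict(type='SwAVDummyLayer')` as its backbone.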
@MODELS.register_module()
class SwAVDummyLayer(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)
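

# ToyModel: a minimal BaseSelfSupervisor with a real SwAVHead. The extra
# `prototypes_test` layer appears to exist so that some parameter name matches
# the 'prototypes' key in the hook's `frozen_layers_cfg`, exercising the
# layer-freezing logic during the first iterations of training.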
class ToyModel(BaseSelfSupervisor):

    def __init__(self):
        super().__init__(backbone=dict(type='SwAVDummyLayer'))
        self.prototypes_test = nn.Linear(1, 1)
        self.head = SwAVHead(
            loss=dict(
                type='SwAVLoss',
                feat_dim=2,
                num_crops=[2, 6],
                num_prototypes=3))

    def loss(self, inputs, data_samples):
        labels = []
        for x in data_samples:
            labels.append(x.gt_label.value)
        labels = torch.stack(labels)
        outputs = self.backbone(inputs[0])
        loss = (labels - outputs).sum()
        outputs = dict(loss=loss)
        return outputs
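

# The test builds a two-epoch Runner around ToyModel, attaches SwAVHook as a
# custom hook, and checks after training that the hook kept its configured
# queue length and that the loss's queue was never enabled -- the queue only
# starts at `epoch_queue_starts` (15), beyond the two epochs trained here.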
class TestSwAVHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        # `FileHandler` should be closed on Windows, otherwise we cannot
        # delete the temporary directory
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.temp_dir.cleanup()

    def test_swav_hook(self):
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        dummy_dataset = DummyDataset()
        toy_model = ToyModel().to(device)
        swav_hook = SwAVHook(
            batch_size=1,
            epoch_queue_starts=15,
            crops_for_assign=[0, 1],
            feat_dim=128,
            queue_length=300,
            frozen_layers_cfg=dict(prototypes=2))

        # test SwAVHook
        runner = Runner(
            model=toy_model,
            work_dir=self.temp_dir.name,
            train_dataloader=dict(
                dataset=dummy_dataset,
                sampler=dict(type='DefaultSampler', shuffle=True),
                collate_fn=dict(type='default_collate'),
                batch_size=1,
                num_workers=0),
            optim_wrapper=OptimWrapper(
                torch.optim.Adam(toy_model.parameters())),
            param_scheduler=dict(type='MultiStepLR', milestones=[1]),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            custom_hooks=[swav_hook],
            default_hooks=dict(logger=None),
            log_processor=dict(window_size=1),
            experiment_name='test_swav_hook',
            default_scope='mmpretrain')
        runner.train()

        for hook in runner.hooks:
            if isinstance(hook, SwAVHook):
                assert hook.queue_length == 300

        assert get_ori_model(
            runner.model).head.loss_module.use_queue is False
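

# A typical standalone invocation, assuming this file lives at
# tests/test_engine/test_hooks/test_swav_hook.py (the path is illustrative):
#   pytest tests/test_engine/test_hooks/test_swav_hook.py -q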