[Refactor] refactor hooks and add ut

pull/352/head
fangyixiao.vendor 2022-06-10 11:20:20 +00:00 committed by fangyixiao18
parent f78fe71794
commit df8c204d75
5 changed files with 316 additions and 6 deletions

View File

@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional
+from typing import Optional, Sequence
 from mmengine.hooks import Hook
@@ -26,7 +26,10 @@ class DenseCLHook(Hook):
             "The runner must have attribute \"loss_lambda\" in DenseCL."
         self.loss_lambda = runner.model.module.loss_lambda
-    def before_train_iter(self, runner) -> None:
+    def before_train_iter(self,
+                          runner,
+                          batch_idx: int,
+                          data_batch: Optional[Sequence[dict]] = None) -> None:
         assert hasattr(runner.model.module, 'loss_lambda'), \
             "The runner must have attribute \"loss_lambda\" in DenseCL."
         cur_iter = runner.iter

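For reference, a minimal sketch of the refactored hook once the hunks above are assembled. The signature and the cached loss_lambda come from the hunks; the default start_iters value and the warm-up branch are inferred from the unit test added below (loss_lambda stays 0. until start_iters, then the cached value is restored), so read it as a sketch rather than the verbatim file:

# sketch only: defaults and the warm-up branch are illustrative assumptions
from typing import Optional, Sequence

from mmengine.hooks import Hook


class DenseCLHook(Hook):

    def __init__(self, start_iters: int = 1000) -> None:  # illustrative default
        self.start_iters = start_iters

    def before_run(self, runner) -> None:
        # cache the value configured on the model
        self.loss_lambda = runner.model.module.loss_lambda

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: Optional[Sequence[dict]] = None) -> None:
        cur_iter = runner.iter
        if cur_iter >= self.start_iters:
            # past the warm-up: restore the cached loss weight
            runner.model.module.loss_lambda = self.loss_lambda
        else:
            # during warm-up the single-view loss is switched off
            runner.model.module.loss_lambda = 0.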
View File

@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional
+from typing import Optional, Sequence
 from mmengine.hooks import Hook
@@ -27,19 +27,22 @@ class SimSiamHook(Hook):
         self.lr = lr
         self.adjust_by_epoch = adjust_by_epoch
-    def before_train_iter(self, runner) -> None:
+    def before_train_iter(self,
+                          runner,
+                          batch_idx: int,
+                          data_batch: Optional[Sequence[dict]] = None) -> None:
         """fix lr of predictor by iter."""
         if self.adjust_by_epoch:
             return
         else:
             if self.fix_pred_lr:
-                for param_group in runner.optimizer.param_groups:
+                for param_group in runner.optim_wrapper.optimizer.param_groups:
                     if 'fix_lr' in param_group and param_group['fix_lr']:
                         param_group['lr'] = self.lr
     def before_train_epoch(self, runner) -> None:
         """fix lr of predictor by epoch."""
         if self.fix_pred_lr:
-            for param_group in runner.optimizer.param_groups:
+            for param_group in runner.optim_wrapper.optimizer.param_groups:
                 if 'fix_lr' in param_group and param_group['fix_lr']:
                     param_group['lr'] = self.lr

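The behavioural change above is only where the optimizer is reached: the hook now goes through runner.optim_wrapper.optimizer instead of runner.optimizer. A hedged usage sketch mirroring the new unit test below; whether custom_keys actually copies an extra fix_lr flag into the predictor's param group depends on the optimizer-wrapper constructor in use:

# illustrative config snippet, values taken from the test added in this commit
custom_hooks = [
    dict(type='SimSiamHook', fix_pred_lr=True, lr=0.05, adjust_by_epoch=False)
]
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.05),
    # assumed to place fix_lr=True in the predictor's param group so the hook
    # can pin that group's lr while the scheduler decays the other groups
    paramwise_cfg=dict(custom_keys={'predictor': dict(fix_lr=True)}))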
View File

@@ -0,0 +1,76 @@
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
from unittest import TestCase

import torch
from mmengine.data import LabelData
from torch.utils.data import Dataset

from mmselfsup.core.data_structures import SelfSupDataSample
from mmselfsup.core.hooks import DeepClusterHook

num_classes = 5
with_sobel = True
backbone = dict(
    type='ResNet',
    depth=18,
    in_channels=2,
    out_indices=[4],  # 0: conv-1, x: stage-x
    norm_cfg=dict(type='BN'))
neck = dict(type='AvgPool2dNeck')
head = dict(
    type='ClsHead',
    with_avg_pool=False,  # already has avgpool in the neck
    in_channels=512,
    num_classes=num_classes)
loss = dict(type='mmcls.CrossEntropyLoss')


class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        data_sample = SelfSupDataSample()
        gt_label = LabelData(value=self.label[index])
        setattr(data_sample, 'gt_label', gt_label)
        return dict(inputs=self.data[index], data_sample=data_sample)


class TestDeepClusterHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_deepcluster_hook(self):
        dummy_dataset = DummyDataset()
        extract_dataloader = dict(
            dataset=dummy_dataset,
            sampler=dict(type='DefaultSampler', shuffle=False),
            batch_size=1,
            num_workers=0,
            persistent_workers=False)
        deepcluster_hook = DeepClusterHook(
            extract_dataloader=extract_dataloader,
            clustering=dict(type='Kmeans', k=num_classes, pca_dim=16),
            unif_sampling=True,
            reweight=False,
            reweight_pow=0.5,
            initial=True,
            interval=1,
            dist_mode=False)

        # test DeepClusterHook
        assert deepcluster_hook.clustering_type == 'Kmeans'

View File

@@ -0,0 +1,113 @@
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
from unittest import TestCase

import torch
import torch.nn as nn
from mmengine import Runner
from mmengine.data import LabelData
from mmengine.model import BaseModel as EngineBaseModel
from mmengine.model import BaseModule
from mmengine.optim import OptimWrapper
from torch.utils.data import Dataset

from mmselfsup.core.data_structures import SelfSupDataSample
from mmselfsup.core.hooks import DenseCLHook
from mmselfsup.models.algorithms import BaseModel
from mmselfsup.registry import MODELS


class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        data_sample = SelfSupDataSample()
        gt_label = LabelData(value=self.label[index])
        setattr(data_sample, 'gt_label', gt_label)
        return dict(inputs=self.data[index], data_sample=data_sample)


@MODELS.register_module()
class DenseCLDummyLayer(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)


class ToyModel(BaseModel):

    def __init__(self):
        super().__init__(backbone=dict(type='DenseCLDummyLayer'))
        self.loss_lambda = 0.5

    def loss(self, batch_inputs, data_samples):
        labels = []
        for x in data_samples:
            labels.append(x.gt_label.value)
        labels = torch.stack(labels)
        outputs = self.backbone(batch_inputs)
        loss = (labels - outputs).sum()
        outputs = dict(loss=loss)
        return outputs


class TestDenseCLHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_densecl_hook(self):
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        dummy_dataset = DummyDataset()
        toy_model = ToyModel().to(device)
        densecl_hook = DenseCLHook(start_iters=1)

        class DummyWrapper(EngineBaseModel):

            def __init__(self, model):
                super().__init__()
                self.module = model

            def forward(self, *args, **kwargs):
                return self.module(*args, **kwargs)

        # test DenseCLHook with model wrapper
        runner = Runner(
            model=DummyWrapper(toy_model),
            work_dir=self.temp_dir.name,
            train_dataloader=dict(
                dataset=dummy_dataset,
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=1,
                num_workers=0),
            optim_wrapper=OptimWrapper(
                torch.optim.Adam(toy_model.parameters())),
            param_scheduler=dict(type='MultiStepLR', milestones=[1]),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            custom_hooks=[densecl_hook],
            default_hooks=dict(logger=None),
            log_processor=dict(window_size=1),
            experiment_name='test_densecl_hook')
        runner.train()
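        # start_iters=1, so after training runner.iter is past the warm-up and
        # the hook is expected to have restored the original loss_lambda (0.5)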
        if runner.iter >= 1:
            assert runner.model.module.loss_lambda == 0.5
        else:
            assert runner.model.module.loss_lambda == 0.

View File

@@ -0,0 +1,115 @@
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
from unittest import TestCase

import torch
import torch.nn as nn
from mmengine import Runner
from mmengine.data import LabelData
from mmengine.model import BaseModel as EngineBaseModel
from mmengine.model import BaseModule
from torch.utils.data import Dataset

from mmselfsup.core.data_structures import SelfSupDataSample
from mmselfsup.core.hooks import SimSiamHook
from mmselfsup.models.algorithms import BaseModel
from mmselfsup.registry import MODELS


class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        data_sample = SelfSupDataSample()
        gt_label = LabelData(value=self.label[index])
        setattr(data_sample, 'gt_label', gt_label)
        return dict(inputs=self.data[index], data_sample=data_sample)


@MODELS.register_module()
class SimSiamDummyLayer(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.predictor = nn.Linear(2, 1)

    def forward(self, x):
        return self.predictor(x)


class ToyModel(BaseModel):

    def __init__(self):
        super().__init__(backbone=dict(type='SimSiamDummyLayer'))

    def loss(self, batch_inputs, data_samples):
        labels = []
        for x in data_samples:
            labels.append(x.gt_label.value)
        labels = torch.stack(labels)
        outputs = self.backbone(batch_inputs)
        loss = (labels - outputs).sum()
        outputs = dict(loss=loss)
        return outputs


class TestSimSiamHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_simsiam_hook(self):
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        dummy_dataset = DummyDataset()
        toy_model = ToyModel().to(device)
        simsiam_hook = SimSiamHook(
            fix_pred_lr=True, lr=0.05, adjust_by_epoch=False)

        class DummyWrapper(EngineBaseModel):

            def __init__(self, model):
                super().__init__()
                self.module = model

            def forward(self, *args, **kwargs):
                return self.module(*args, **kwargs)

        # test SimSiamHook
        runner = Runner(
            model=DummyWrapper(toy_model),
            work_dir=self.temp_dir.name,
            train_dataloader=dict(
                dataset=dummy_dataset,
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=1,
                num_workers=0),
            optim_wrapper=dict(
                optimizer=dict(type='SGD', lr=0.05),
                paramwise_cfg=dict(
                    custom_keys={'predictor': dict(fix_lr=True)})),
            param_scheduler=dict(type='MultiStepLR', milestones=[1]),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            custom_hooks=[simsiam_hook],
            default_hooks=dict(logger=None),
            log_processor=dict(window_size=1),
            experiment_name='test_simsiam_hook')
        runner.train()
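        # groups flagged with fix_lr should keep the lr pinned by the hook;
        # the remaining groups follow the MultiStepLR schedule instead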
        for param_group in runner.optim_wrapper.optimizer.param_groups:
            if 'fix_lr' in param_group and param_group['fix_lr']:
                assert param_group['lr'] == 0.05
            else:
                assert param_group['lr'] != 0.05