[Refactor]: refactor swav algorithm

pull/352/head
renqin 2022-07-08 03:55:27 +00:00 committed by fangyixiao18
parent 5f778aa552
commit dfa4d180df
10 changed files with 132 additions and 110 deletions

View File

@ -1,4 +1,5 @@
# dataset settings
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
dataset_type = 'mmcls.ImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')
@ -6,7 +7,9 @@ file_client_args = dict(backend='disk')
num_crops = [2, 6]
color_distort_strength = 1.0
view_pipeline1 = [
dict(type='RandomResizedCrop', size=224, scale=(0.14, 1.)),
dict(
type='RandomResizedCrop', size=224, scale=(0.14, 1.),
backend='pillow'),
dict(
type='RandomApply',
transforms=[
@ -18,12 +21,20 @@ view_pipeline1 = [
hue=0.2 * color_distort_strength)
],
prob=0.8),
dict(type='RandomGrayscale', prob=0.2, keep_channels=True),
dict(
type='RandomGrayscale',
prob=0.2,
keep_channels=True,
channel_weights=(0.114, 0.587, 0.2989)),
dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.5),
dict(type='RandomFlip', prob=0.5),
]
view_pipeline2 = [
dict(type='RandomResizedCrop', size=96, scale=(0.05, 0.14)),
dict(
type='RandomResizedCrop',
size=96,
scale=(0.05, 0.14),
backend='pillow'),
dict(
type='RandomApply',
transforms=[
@ -35,7 +46,11 @@ view_pipeline2 = [
hue=0.2 * color_distort_strength)
],
prob=0.8),
dict(type='RandomGrayscale', prob=0.2, keep_channels=True),
dict(
type='RandomGrayscale',
prob=0.2,
keep_channels=True,
channel_weights=(0.114, 0.587, 0.2989)),
dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.5),
dict(type='RandomFlip', prob=0.5),
]
@ -51,7 +66,8 @@ train_pipeline = [
train_dataloader = dict(
batch_size=32,
num_workers=4,
num_workers=8,
drop_last=True,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(

View File

@ -1,6 +1,10 @@
# model settings
model = dict(
type='SwAV',
data_preprocessor=dict(
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
bgr_to_rgb=True),
backbone=dict(
type='ResNet',
depth=50,
@ -14,10 +18,12 @@ model = dict(
hid_channels=2048,
out_channels=128,
with_avg_pool=True),
loss=dict(
type='SwAVLoss',
feat_dim=128, # equal to neck['out_channels']
epsilon=0.05,
temperature=0.1,
num_crops=[2, 6],
))
head=dict(
type='SwAVHead',
loss=dict(
type='SwAVLoss',
feat_dim=128, # equal to neck['out_channels']
epsilon=0.05,
temperature=0.1,
num_crops=[2, 6],
)))

View File

@ -6,7 +6,7 @@ _base_ = [
]
# model settings
model = dict(head=dict(num_crops={{_base_.num_crops}}))
model = dict(head=dict(loss=dict(num_crops={{_base_.num_crops}})))
# additional hooks
custom_hooks = [
@ -17,7 +17,8 @@ custom_hooks = [
epoch_queue_starts=15,
crops_for_assign=[0, 1],
feat_dim=128,
queue_length=3840)
queue_length=3840,
frozen_layers_cfg=dict(prototypes=5005))
]
# dataset summary
@ -25,7 +26,7 @@ data = dict(num_views={{_base_.num_crops}})
# optimizer
optimizer = dict(type='LARS', lr=0.6)
optimizer_config = dict(frozen_layers_cfg=dict(prototypes=5005))
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
# learning policy
param_scheduler = [
@ -35,11 +36,14 @@ param_scheduler = [
eta_min=6e-4,
by_epoch=True,
begin=0,
end=200)
end=200,
convert_to_iter_based=True)
]
# runtime settings
# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
# it will remove the oldest one to keep the number of total ckpts as 3
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=50),
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))
find_unused_parameters = True

View File

@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Sequence
import torch
import torch.distributed as dist
from mmengine.hooks import Hook
from mmengine.logging import MMLogger
from mmselfsup.registry import HOOKS
@ -64,7 +65,10 @@ class SwAVHook(Hook):
# build the queue
if osp.isfile(self.queue_path):
self.queue = torch.load(self.queue_path)['queue']
runner.model.module.head.queue = self.queue
runner.model.module.head.loss.queue = self.queue
MMLogger.get_current_instance().info(
f'Load queue from file: {self.queue_path}')
# the queue needs to be divisible by the batch size
self.queue_length -= self.queue_length % self.batch_size
@ -96,11 +100,11 @@ class SwAVHook(Hook):
).cuda()
# set the boolean type of use_the_queue
runner.model.module.head.queue = self.queue
runner.model.module.head.use_queue = False
runner.model.module.head.loss.queue = self.queue
runner.model.module.head.loss.use_queue = False
def after_train_epoch(self, runner) -> None:
self.queue = runner.model.module.head.queue
self.queue = runner.model.module.head.loss.queue
if self.queue is not None and self.every_n_epochs(
runner, self.interval):

View File

@ -1,92 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple
import torch
from mmselfsup.core import SelfSupDataSample
from ..builder import ALGORITHMS, build_backbone, build_loss, build_neck
from mmselfsup.registry import MODELS
from .base import BaseModel
@ALGORITHMS.register_module()
@MODELS.register_module()
class SwAV(BaseModel):
"""SwAV.
Implementation of `Unsupervised Learning of Visual Features by Contrasting
Cluster Assignments <https://arxiv.org/abs/2006.09882>`_.
The queue is built in `core/hooks/swav_hook.py`.
Args:
backbone (Dict, optional): Config dict for module of backbone.
neck (Dict, optional): Config dict for module of deep features
to compact
feature vectors. Defaults to None.
loss (Dict, optional): Config dict for module of loss functions.
Defaults to None.
preprocess_cfg (Dict, optional): Config dict to preprocess images.
Defaults to None.
init_cfg (Dict or List[Dict], optional): Config dict for weight
initialization. Defaults to None.
Cluster Assignments <https://arxiv.org/abs/2006.09882>`_. The queue is
built in `core/hooks/swav_hook.py`.
"""
def __init__(self,
backbone: Optional[Dict] = None,
neck: Optional[Dict] = None,
loss: Optional[Dict] = None,
preprocess_cfg: Optional[Dict] = None,
init_cfg: Optional[Union[Dict, List[Dict]]] = None,
**kwargs) -> None:
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
assert backbone is not None
self.backbone = build_backbone(backbone)
assert neck is not None
self.neck = build_neck(neck)
assert loss is not None
self.loss = build_loss(loss)
def extract_feat(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
def extract_feat(self, batch_inputs: List[torch.Tensor],
**kwargs) -> Tuple[torch.Tensor]:
"""Function to extract features from backbone.
Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
batch_inputs (List[torch.Tensor]): The input images.
Returns:
Tuple[torch.Tensor]: backbone outputs.
"""
x = self.backbone(inputs[0])
x = self.backbone(batch_inputs[0])
return x
def forward_train(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
def loss(self, batch_inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""Forward computation during training.
Args:
inputs (List[torch.Tensor]): The input images.
batch_inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
assert isinstance(inputs, list)
assert isinstance(batch_inputs, list)
# multi-res forward passes
idx_crops = torch.cumsum(
torch.unique_consecutive(
torch.tensor([input.shape[-1] for input in inputs]),
torch.tensor([input.shape[-1] for input in batch_inputs]),
return_counts=True)[1], 0)
start_idx = 0
output = []
for end_idx in idx_crops:
_out = self.backbone(torch.cat(inputs[start_idx:end_idx]))
_out = self.backbone(torch.cat(batch_inputs[start_idx:end_idx]))
output.append(_out)
start_idx = end_idx
output = self.neck(output)[0]
loss = self.loss(output)
loss = self.head(output)
losses = dict(loss=loss)
return losses

View File

@ -7,9 +7,11 @@ from .mae_head import MAEFinetuneHead, MAELinprobeHead, MAEPretrainHead
from .mocov3_head import MoCoV3Head
from .multi_cls_head import MultiClsHead
from .simmim_head import SimMIMHead
from .swav_head import SwAVHead
__all__ = [
'ContrastiveHead', 'ClsHead', 'LatentPredictHead',
'LatentCrossCorrelationHead', 'MultiClsHead', 'MAEFinetuneHead',
'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead', 'CAEHead', 'MAELinprobeHead'
'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead', 'CAEHead',
'MAELinprobeHead', 'SwAVHead'
]

View File

@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule
from ..builder import MODELS
@MODELS.register_module()
class SwAVHead(BaseModule):
"""Head for SwAV.
Args:
loss (dict): Config dict for module of loss functions.
"""
def __init__(self, loss: dict) -> None:
super().__init__()
self.loss = MODELS.build(loss)
def forward(self, pred: torch.Tensor) -> torch.Tensor:
"""Forward function of SwAV head.
Args:
pred (torch.Tensor): NxC input features.
Returns:
torch.Tensor: The SwAV loss.
"""
loss = self.loss(pred)
return loss

View File

@ -8,11 +8,11 @@ import torch.nn as nn
from mmcv.runner import BaseModule
from mmselfsup.utils import distributed_sinkhorn
from ..builder import LOSSES
from ..builder import MODELS
from ..utils import MultiPrototypes
@LOSSES.register_module()
@MODELS.register_module()
class SwAVLoss(BaseModule):
"""The Loss for SwAV.

View File

@ -14,7 +14,7 @@ from torch.utils.data import Dataset
from mmselfsup.core.data_structures import SelfSupDataSample
from mmselfsup.core.hooks import SwAVHook
from mmselfsup.models.algorithms import BaseModel
from mmselfsup.models.losses import SwAVLoss
from mmselfsup.models.heads import SwAVHead
from mmselfsup.registry import MODELS
@ -53,7 +53,12 @@ class ToyModel(BaseModel):
def __init__(self):
super().__init__(backbone=dict(type='SwAVDummyLayer'))
self.prototypes_test = nn.Linear(1, 1)
self.head = SwAVLoss(feat_dim=2, num_crops=[2, 6], num_prototypes=3)
self.head = SwAVHead(
loss=dict(
type='SwAVLoss',
feat_dim=2,
num_crops=[2, 6],
num_prototypes=3))
def loss(self, batch_inputs, data_samples):
labels = []
@ -119,4 +124,4 @@ class TestSwAVHook(TestCase):
if isinstance(hook, SwAVHook):
assert hook.queue_length == 300
assert runner.model.module.head.use_queue is False
assert runner.model.module.head.loss.use_queue is False

View File

@ -24,44 +24,29 @@ neck = dict(
out_channels=2,
norm_cfg=dict(type='BN1d'),
with_avg_pool=True)
loss = dict(
type='SwAVLoss',
feat_dim=2, # equal to neck['out_channels']
epsilon=0.05,
temperature=0.1,
num_crops=nmb_crops)
head = dict(
type='SwAVHead',
loss=dict(
type='SwAVLoss',
feat_dim=2, # equal to neck['out_channels']
epsilon=0.05,
temperature=0.1,
num_crops=nmb_crops))
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_swav():
preprocess_cfg = {
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'to_rgb': True
data_preprocessor = {
'mean': (123.675, 116.28, 103.53),
'std': (58.395, 57.12, 57.375),
'bgr_to_rgb': True
}
with pytest.raises(AssertionError):
alg = SwAV(
backbone=backbone,
neck=neck,
loss=None,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = SwAV(
backbone=backbone,
neck=None,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = SwAV(
backbone=None,
neck=neck,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
alg = SwAV(
backbone=backbone,
neck=neck,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
head=head,
data_preprocessor=copy.deepcopy(data_preprocessor))
fake_data = [{
'inputs': [
@ -78,10 +63,9 @@ def test_swav():
SelfSupDataSample()
} for _ in range(2)]
fake_outputs = alg(fake_data, return_loss=True)
fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
assert isinstance(fake_outputs['loss'].item(), float)
fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
fake_feat = alg.extract_feat(
inputs=fake_inputs, data_samples=fake_data_samples)
fake_feat = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
assert list(fake_feat[0].shape) == [2, 512, 7, 7]