[Refactor]: refactor SwAV algorithm

parent 5f778aa552
commit dfa4d180df
@@ -1,4 +1,5 @@
 # dataset settings
+custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
 dataset_type = 'mmcls.ImageNet'
 data_root = 'data/imagenet/'
 file_client_args = dict(backend='disk')
@@ -6,7 +7,9 @@ file_client_args = dict(backend='disk')
 num_crops = [2, 6]
 color_distort_strength = 1.0
 view_pipeline1 = [
-    dict(type='RandomResizedCrop', size=224, scale=(0.14, 1.)),
+    dict(
+        type='RandomResizedCrop', size=224, scale=(0.14, 1.),
+        backend='pillow'),
     dict(
         type='RandomApply',
         transforms=[
@@ -18,12 +21,20 @@ view_pipeline1 = [
                 hue=0.2 * color_distort_strength)
         ],
         prob=0.8),
-    dict(type='RandomGrayscale', prob=0.2, keep_channels=True),
+    dict(
+        type='RandomGrayscale',
+        prob=0.2,
+        keep_channels=True,
+        channel_weights=(0.114, 0.587, 0.2989)),
     dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.5),
     dict(type='RandomFlip', prob=0.5),
 ]
 view_pipeline2 = [
-    dict(type='RandomResizedCrop', size=96, scale=(0.05, 0.14)),
+    dict(
+        type='RandomResizedCrop',
+        size=96,
+        scale=(0.05, 0.14),
+        backend='pillow'),
     dict(
         type='RandomApply',
         transforms=[
@@ -35,7 +46,11 @@ view_pipeline2 = [
                 hue=0.2 * color_distort_strength)
         ],
         prob=0.8),
-    dict(type='RandomGrayscale', prob=0.2, keep_channels=True),
+    dict(
+        type='RandomGrayscale',
+        prob=0.2,
+        keep_channels=True,
+        channel_weights=(0.114, 0.587, 0.2989)),
     dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=0.5),
     dict(type='RandomFlip', prob=0.5),
 ]
@@ -51,7 +66,8 @@ train_pipeline = [

 train_dataloader = dict(
     batch_size=32,
-    num_workers=4,
+    num_workers=8,
+    drop_last=True,
     persistent_workers=True,
     sampler=dict(type='DefaultSampler', shuffle=True),
     dataset=dict(
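Note: `num_crops = [2, 6]` pairs with the two pipelines above: `view_pipeline1` yields the 2 global 224px views and `view_pipeline2` the 6 local 96px views per image. A minimal sketch of how a multi-view wrapper could assemble them (the `MultiViewSketch` name and call convention are illustrative assumptions, not mmselfsup's actual transform):

```python
# hypothetical sketch, not the real mmselfsup transform: run each view
# pipeline its configured number of times and collect all crops per image
class MultiViewSketch:

    def __init__(self, pipelines, num_views):  # e.g. num_views=[2, 6]
        assert len(pipelines) == len(num_views)
        self.pipelines = pipelines
        self.num_views = num_views

    def __call__(self, img):
        views = []
        for pipeline, n in zip(self.pipelines, self.num_views):
            views.extend(pipeline(img) for _ in range(n))
        return views  # 2 global 224x224 crops + 6 local 96x96 crops
```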
@@ -1,6 +1,10 @@
 # model settings
 model = dict(
     type='SwAV',
+    data_preprocessor=dict(
+        mean=(123.675, 116.28, 103.53),
+        std=(58.395, 57.12, 57.375),
+        bgr_to_rgb=True),
     backbone=dict(
         type='ResNet',
         depth=50,
@@ -14,10 +18,12 @@ model = dict(
         hid_channels=2048,
         out_channels=128,
         with_avg_pool=True),
     head=dict(
         type='SwAVHead',
+        loss=dict(
+            type='SwAVLoss',
             feat_dim=128,  # equal to neck['out_channels']
             epsilon=0.05,
             temperature=0.1,
             num_crops=[2, 6],
-        ))
+        )))
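With this change the loss hyper-parameters (`feat_dim`, `epsilon`, `temperature`, `num_crops`) sit one level deeper, under `head.loss`. The nesting resolves because `SwAVHead.__init__` calls `MODELS.build(loss)` on its sub-config (see the new `swav_head.py` further down); a minimal sketch of the two-step build:

```python
# sketch: building the head also builds the loss, since SwAVHead
# constructs its `loss` sub-config through the MODELS registry
from mmselfsup.registry import MODELS

head = MODELS.build(
    dict(
        type='SwAVHead',
        loss=dict(
            type='SwAVLoss',
            feat_dim=128,  # must equal neck['out_channels']
            epsilon=0.05,
            temperature=0.1,
            num_crops=[2, 6])))
```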
@@ -6,7 +6,7 @@ _base_ = [
 ]

 # model settings
-model = dict(head=dict(num_crops={{_base_.num_crops}}))
+model = dict(head=dict(loss=dict(num_crops={{_base_.num_crops}})))

 # additional hooks
 custom_hooks = [
@@ -17,7 +17,8 @@ custom_hooks = [
         epoch_queue_starts=15,
         crops_for_assign=[0, 1],
         feat_dim=128,
-        queue_length=3840)
+        queue_length=3840,
+        frozen_layers_cfg=dict(prototypes=5005))
 ]

 # dataset summary
@@ -25,7 +26,7 @@ data = dict(num_views={{_base_.num_crops}})

 # optimizer
 optimizer = dict(type='LARS', lr=0.6)
-optimizer_config = dict(frozen_layers_cfg=dict(prototypes=5005))
+optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)

 # learning policy
 param_scheduler = [
@@ -35,11 +36,14 @@ param_scheduler = [
         eta_min=6e-4,
         by_epoch=True,
         begin=0,
-        end=200)
+        end=200,
+        convert_to_iter_based=True)
 ]

 # runtime settings
-# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
-# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
-# it will remove the oldest one to keep the number of total ckpts as 3
-checkpoint_config = dict(interval=10, max_keep_ckpts=3)
+default_hooks = dict(
+    logger=dict(type='LoggerHook', interval=50),
+    # only keeps the latest 3 checkpoints
+    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))

 find_unused_parameters = True
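`frozen_layers_cfg=dict(prototypes=5005)` moved from the deleted `optimizer_config` into `SwAVHook` above. In SwAV the prototype weights are typically left untouched by the optimizer for the first N iterations so the cluster assignments can stabilize; a sketch of the idea (matching parameters by a substring of their name is an assumption, not shown in this diff):

```python
# sketch: cancel gradients of named layers for their first N iterations
def cancel_frozen_gradients(model, cur_iter, frozen_layers_cfg):
    # e.g. frozen_layers_cfg = {'prototypes': 5005}
    for layer_name, frozen_iters in frozen_layers_cfg.items():
        if cur_iter < frozen_iters:
            for name, param in model.named_parameters():
                if layer_name in name:
                    param.grad = None  # optimizer then skips this update
```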
@@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Sequence
 import torch
+import torch.distributed as dist
 from mmengine.hooks import Hook
 from mmengine.logging import MMLogger

 from mmselfsup.registry import HOOKS
@@ -64,7 +65,10 @@ class SwAVHook(Hook):
         # build the queue
         if osp.isfile(self.queue_path):
             self.queue = torch.load(self.queue_path)['queue']
-            runner.model.module.head.queue = self.queue
+            runner.model.module.head.loss.queue = self.queue
+            MMLogger.get_current_instance().info(
+                f'Load queue from file: {self.queue_path}')

         # the queue needs to be divisible by the batch size
         self.queue_length -= self.queue_length % self.batch_size
@@ -96,11 +100,11 @@ class SwAVHook(Hook):
             ).cuda()

         # set the boolean type of use_the_queue
-        runner.model.module.head.queue = self.queue
-        runner.model.module.head.use_queue = False
+        runner.model.module.head.loss.queue = self.queue
+        runner.model.module.head.loss.use_queue = False

     def after_train_epoch(self, runner) -> None:
-        self.queue = runner.model.module.head.queue
+        self.queue = runner.model.module.head.loss.queue

         if self.queue is not None and self.every_n_epochs(
                 runner, self.interval):
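The queue and `use_queue` flag now live on `SwAVLoss`, so the hook consistently reaches through `head.loss`. For orientation, a sketch of the feature queue this hook maintains, using the values from the 200-epoch config (the shape layout and the world size are assumptions for illustration, since the allocation itself is outside this hunk):

```python
# sketch: the SwAV queue stores past projected features per assigned crop
import torch

crops_for_assign = [0, 1]  # only the two global crops get assignments
feat_dim = 128
queue_length = 3840
batch_size = 32
world_size = 8             # assumed GPU count, for illustration

# the queue needs to be divisible by the batch size (see the hunk above)
queue_length -= queue_length % batch_size
queue = torch.zeros(len(crops_for_assign),
                    queue_length // world_size,
                    feat_dim)
print(queue.shape)  # torch.Size([2, 480, 128])
```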
@@ -1,92 +1,62 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple

 import torch

 from mmselfsup.core import SelfSupDataSample
-from ..builder import ALGORITHMS, build_backbone, build_loss, build_neck
+from mmselfsup.registry import MODELS
 from .base import BaseModel


-@ALGORITHMS.register_module()
+@MODELS.register_module()
 class SwAV(BaseModel):
     """SwAV.

     Implementation of `Unsupervised Learning of Visual Features by Contrasting
-    Cluster Assignments <https://arxiv.org/abs/2006.09882>`_.
-    The queue is built in `core/hooks/swav_hook.py`.
-
-    Args:
-        backbone (Dict, optional): Config dict for module of backbone.
-        neck (Dict, optional): Config dict for module of deep features
-            to compact feature vectors. Defaults to None.
-        loss (Dict, optional): Config dict for module of loss functions.
-            Defaults to None.
-        preprocess_cfg (Dict, optional): Config dict to preprocess images.
-            Defaults to None.
-        init_cfg (Dict or List[Dict], optional): Config dict for weight
-            initialization. Defaults to None.
+    Cluster Assignments <https://arxiv.org/abs/2006.09882>`_. The queue is
+    built in `core/hooks/swav_hook.py`.
     """

-    def __init__(self,
-                 backbone: Optional[Dict] = None,
-                 neck: Optional[Dict] = None,
-                 loss: Optional[Dict] = None,
-                 preprocess_cfg: Optional[Dict] = None,
-                 init_cfg: Optional[Union[Dict, List[Dict]]] = None,
-                 **kwargs) -> None:
-        super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
-        assert backbone is not None
-        self.backbone = build_backbone(backbone)
-        assert neck is not None
-        self.neck = build_neck(neck)
-        assert loss is not None
-        self.loss = build_loss(loss)
-
-    def extract_feat(self, inputs: List[torch.Tensor],
-                     data_samples: List[SelfSupDataSample],
+    def extract_feat(self, batch_inputs: List[torch.Tensor],
                      **kwargs) -> Tuple[torch.Tensor]:
         """Function to extract features from backbone.

         Args:
-            inputs (List[torch.Tensor]): The input images.
-            data_samples (List[SelfSupDataSample]): All elements required
-                during the forward function.
+            batch_inputs (List[torch.Tensor]): The input images.

         Returns:
             Tuple[torch.Tensor]: backbone outputs.
         """
-        x = self.backbone(inputs[0])
+        x = self.backbone(batch_inputs[0])
         return x

-    def forward_train(self, inputs: List[torch.Tensor],
+    def loss(self, batch_inputs: List[torch.Tensor],
              data_samples: List[SelfSupDataSample],
              **kwargs) -> Dict[str, torch.Tensor]:
         """Forward computation during training.

         Args:
-            inputs (List[torch.Tensor]): The input images.
+            batch_inputs (List[torch.Tensor]): The input images.
             data_samples (List[SelfSupDataSample]): All elements required
                 during the forward function.

         Returns:
             Dict[str, torch.Tensor]: A dictionary of loss components.
         """
-        assert isinstance(inputs, list)
+        assert isinstance(batch_inputs, list)
         # multi-res forward passes
         idx_crops = torch.cumsum(
             torch.unique_consecutive(
-                torch.tensor([input.shape[-1] for input in inputs]),
+                torch.tensor([input.shape[-1] for input in batch_inputs]),
                 return_counts=True)[1], 0)
         start_idx = 0
         output = []
         for end_idx in idx_crops:
-            _out = self.backbone(torch.cat(inputs[start_idx:end_idx]))
+            _out = self.backbone(torch.cat(batch_inputs[start_idx:end_idx]))
             output.append(_out)
             start_idx = end_idx
         output = self.neck(output)[0]

-        loss = self.loss(output)
+        loss = self.head(output)
         losses = dict(loss=loss)
         return losses
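The `unique_consecutive`/`cumsum` indexing in the new `loss()` groups crops by resolution so the backbone runs once per crop size instead of once per crop. A runnable example of just that indexing:

```python
# worked example of the multi-res grouping in SwAV.loss()
import torch

# eight views per image: 2 global 224px crops, then 6 local 96px crops
sizes = torch.tensor([224, 224, 96, 96, 96, 96, 96, 96])
counts = torch.unique_consecutive(sizes, return_counts=True)[1]
idx_crops = torch.cumsum(counts, 0)
print(counts.tolist())     # [2, 6]
print(idx_crops.tolist())  # [2, 8] -> backbone sees batches [0:2], [2:8]
```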
@@ -7,9 +7,11 @@ from .mae_head import MAEFinetuneHead, MAELinprobeHead, MAEPretrainHead
 from .mocov3_head import MoCoV3Head
 from .multi_cls_head import MultiClsHead
 from .simmim_head import SimMIMHead
+from .swav_head import SwAVHead

 __all__ = [
     'ContrastiveHead', 'ClsHead', 'LatentPredictHead',
     'LatentCrossCorrelationHead', 'MultiClsHead', 'MAEFinetuneHead',
-    'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead', 'CAEHead', 'MAELinprobeHead'
+    'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead', 'CAEHead',
+    'MAELinprobeHead', 'SwAVHead'
 ]
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmengine.model import BaseModule
+
+from ..builder import MODELS
+
+
+@MODELS.register_module()
+class SwAVHead(BaseModule):
+    """Head for SwAV.
+
+    Args:
+        loss (dict): Config dict for module of loss functions.
+    """
+
+    def __init__(self, loss: dict) -> None:
+        super().__init__()
+        self.loss = MODELS.build(loss)
+
+    def forward(self, pred: torch.Tensor) -> torch.Tensor:
+        """Forward function of SwAV head.
+
+        Args:
+            pred (torch.Tensor): NxC input features.
+
+        Returns:
+            torch.Tensor: The SwAV loss.
+        """
+        loss = self.loss(pred)
+
+        return loss
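The new head is deliberately thin: it only builds `SwAVLoss` and forwards features to it, keeping the algorithm/head/loss layering consistent with the other refactored algorithms. A short usage sketch (argument values borrowed from the unit test below):

```python
# sketch: SwAVHead delegates projected features straight to SwAVLoss
head = SwAVHead(
    loss=dict(
        type='SwAVLoss',
        feat_dim=2,
        num_crops=[2, 6],
        num_prototypes=3))
# pred: concatenated neck projections of all crops, shape (N, feat_dim)
# loss = head(pred)
```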
@@ -8,11 +8,11 @@ import torch.nn as nn
 from mmcv.runner import BaseModule

 from mmselfsup.utils import distributed_sinkhorn
-from ..builder import LOSSES
+from ..builder import MODELS
 from ..utils import MultiPrototypes


-@LOSSES.register_module()
+@MODELS.register_module()
 class SwAVLoss(BaseModule):
     """The Loss for SwAV.
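`SwAVLoss` keeps its `distributed_sinkhorn` dependency: prototype scores are turned into soft cluster assignments by a few Sinkhorn-Knopp normalization steps. A single-process sketch of that routine, following the SwAV paper (`epsilon=0.05` comes from the configs above; the iteration count of 3 is the paper's default and an assumption here):

```python
# minimal single-GPU sketch of Sinkhorn-Knopp normalization for SwAV
import torch


def sinkhorn(scores: torch.Tensor, epsilon: float = 0.05,
             n_iters: int = 3) -> torch.Tensor:
    """Turn prototype scores (B x K) into soft assignments whose rows sum
    to 1 and whose columns are roughly balanced across prototypes."""
    q = torch.exp(scores / epsilon).t()  # K x B
    q /= q.sum()
    num_prototypes, num_samples = q.shape
    for _ in range(n_iters):
        q /= q.sum(dim=1, keepdim=True)  # balance the prototypes (rows)
        q /= num_prototypes
        q /= q.sum(dim=0, keepdim=True)  # normalize the samples (columns)
        q /= num_samples
    return (q * num_samples).t()  # B x K, each row sums to 1
```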
@@ -14,7 +14,7 @@ from torch.utils.data import Dataset
 from mmselfsup.core.data_structures import SelfSupDataSample
 from mmselfsup.core.hooks import SwAVHook
 from mmselfsup.models.algorithms import BaseModel
-from mmselfsup.models.losses import SwAVLoss
+from mmselfsup.models.heads import SwAVHead
 from mmselfsup.registry import MODELS
@@ -53,7 +53,12 @@ class ToyModel(BaseModel):
     def __init__(self):
         super().__init__(backbone=dict(type='SwAVDummyLayer'))
         self.prototypes_test = nn.Linear(1, 1)
-        self.head = SwAVLoss(feat_dim=2, num_crops=[2, 6], num_prototypes=3)
+        self.head = SwAVHead(
+            loss=dict(
+                type='SwAVLoss',
+                feat_dim=2,
+                num_crops=[2, 6],
+                num_prototypes=3))

     def loss(self, batch_inputs, data_samples):
         labels = []
@@ -119,4 +124,4 @@ class TestSwAVHook(TestCase):
         if isinstance(hook, SwAVHook):
             assert hook.queue_length == 300

-            assert runner.model.module.head.use_queue is False
+            assert runner.model.module.head.loss.use_queue is False
@@ -24,44 +24,29 @@ neck = dict(
     out_channels=2,
     norm_cfg=dict(type='BN1d'),
     with_avg_pool=True)
-loss = dict(
-    type='SwAVLoss',
-    feat_dim=2,  # equal to neck['out_channels']
-    epsilon=0.05,
-    temperature=0.1,
-    num_crops=nmb_crops)
+head = dict(
+    type='SwAVHead',
+    loss=dict(
+        type='SwAVLoss',
+        feat_dim=2,  # equal to neck['out_channels']
+        epsilon=0.05,
+        temperature=0.1,
+        num_crops=nmb_crops))


 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
 def test_swav():
-    preprocess_cfg = {
-        'mean': [0.5, 0.5, 0.5],
-        'std': [0.5, 0.5, 0.5],
-        'to_rgb': True
+    data_preprocessor = {
+        'mean': (123.675, 116.28, 103.53),
+        'std': (58.395, 57.12, 57.375),
+        'bgr_to_rgb': True
     }
-    with pytest.raises(AssertionError):
-        alg = SwAV(
-            backbone=backbone,
-            neck=neck,
-            loss=None,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = SwAV(
-            backbone=backbone,
-            neck=None,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = SwAV(
-            backbone=None,
-            neck=neck,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
     alg = SwAV(
         backbone=backbone,
         neck=neck,
-        loss=loss,
-        preprocess_cfg=copy.deepcopy(preprocess_cfg))
+        head=head,
+        data_preprocessor=copy.deepcopy(data_preprocessor))

     fake_data = [{
         'inputs': [
@@ -78,10 +63,9 @@ def test_swav():
         SelfSupDataSample()
     } for _ in range(2)]

-    fake_outputs = alg(fake_data, return_loss=True)
+    fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
+    fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
     assert isinstance(fake_outputs['loss'].item(), float)

-    fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
-    fake_feat = alg.extract_feat(
-        inputs=fake_inputs, data_samples=fake_data_samples)
+    fake_feat = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
     assert list(fake_feat[0].shape) == [2, 512, 7, 7]
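The rewritten assertions exercise the mmengine-style calling convention: raw data first goes through `data_preprocessor`, then a single `forward` dispatches on `mode`. In short:

```python
# sketch of the calling convention exercised by the updated test
fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
losses = alg(fake_batch_inputs, fake_data_samples, mode='loss')   # loss dict
feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor')  # features
```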