[Refactor]: refactor densecl algorithm

pull/352/head
renqin 2022-07-07 12:37:16 +00:00 committed by fangyixiao18
parent d6dfa9fe40
commit 5f778aa552
5 changed files with 97 additions and 122 deletions


@@ -5,6 +5,10 @@ model = dict(
     feat_dim=128,
     momentum=0.999,
     loss_lambda=0.5,
+    data_preprocessor=dict(
+        mean=(123.675, 116.28, 103.53),
+        std=(58.395, 57.12, 57.375),
+        bgr_to_rgb=True),
     backbone=dict(
         type='ResNet',
         depth=50,
@@ -17,5 +21,8 @@ model = dict(
         hid_channels=2048,
         out_channels=128,
         num_grid=None),
-    head=dict(type='ContrastiveHead', temperature=0.2),
-    loss=dict(type='mmcls.CrossEntropyLoss'))
+    head=dict(
+        type='ContrastiveHead',
+        loss=dict(type='mmcls.CrossEntropyLoss'),
+        temperature=0.2),
+)
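Note on the config change above: the standalone loss config is folded into the head, so a single registry call now builds both modules. A minimal sketch of the new-style construction (assuming the default scope can resolve the mmcls loss in your environment):

    from mmselfsup.registry import MODELS

    head = MODELS.build(
        dict(
            type='ContrastiveHead',
            loss=dict(type='mmcls.CrossEntropyLoss'),
            temperature=0.2))
    # the refactored head returns a loss tensor directly, e.g.
    # loss = head(l_pos, l_neg), replacing the old (logits, labels)
    # pair that was fed into a separate loss module (see the algorithm
    # diff below).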


@@ -5,8 +5,10 @@ _base_ = [
     '../_base_/default_runtime.py',
 ]

 find_unused_parameters = True

 # runtime settings
-# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
-# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
-# it will remove the oldest one to keep the number of total ckpts as 3
-checkpoint_config = dict(interval=10, max_keep_ckpts=3)
+default_hooks = dict(
+    logger=dict(type='LoggerHook', interval=50),
+    # only keeps the latest 3 checkpoints
+    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))
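The long comment block becomes a one-liner because the behaviour is unchanged; only the owner moves from mmcv's checkpoint_config to mmengine's CheckpointHook. A sketch of what this setting leaves on disk, assuming epoch-based saving:

    # interval=10, max_keep_ckpts=3: once epoch_40.pth is written,
    # epoch_10.pth is deleted, so work_dirs keeps only
    # epoch_20.pth, epoch_30.pth and epoch_40.pth.
    default_hooks = dict(
        checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))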


@@ -4,16 +4,16 @@ from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
-from mmengine.data import BaseDataElement
+from mmengine.model import ExponentialMovingAverage

 from mmselfsup.core import SelfSupDataSample
+from mmselfsup.registry import MODELS
 from mmselfsup.utils import (batch_shuffle_ddp, batch_unshuffle_ddp,
                              concat_all_gather)
-from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
-                       build_neck)
 from .base import BaseModel


-@ALGORITHMS.register_module()
+@MODELS.register_module()
 class DenseCL(BaseModel):
     """DenseCL.
@@ -25,11 +25,8 @@ class DenseCL(BaseModel):
     Args:
         backbone (dict): Config dict for module of backbone.
         neck (dict): Config dict for module of deep features to compact
-            feature vectors. Defaults to None.
+            feature vectors.
         head (dict): Config dict for module of head functions.
-            Defaults to None.
-        loss (dict): Config dict for module of loss functions.
-            Defaults to None.
         queue_len (int): Number of negative keys maintained in the queue.
             Defaults to 65536.
         feat_dim (int): Dimension of compact feature vectors. Defaults to 128.
@@ -37,66 +34,55 @@ class DenseCL(BaseModel):
             encoder. Defaults to 0.999.
         loss_lambda (float): Loss weight for the single and dense contrastive
             loss. Defaults to 0.5.
-        preprocess_cfg (Dict, optional): Config dict to preprocess images.
+        pretrained (str, optional): The pretrained checkpoint path, support
+            local path and remote path. Defaults to None.
+        data_preprocessor (Union[dict, nn.Module], optional): The config for
+            preprocessing input data. If None or no specified type, it will use
+            "SelfSupDataPreprocessor" as type.
+            See :class:`SelfSupDataPreprocessor` for more details.
+            Defaults to None.
         init_cfg (Dict or List[Dict], optional): Config dict for weight
             initialization. Defaults to None.
     """

     def __init__(self,
-                 backbone,
-                 neck: Optional[Dict] = None,
-                 head: Optional[Dict] = None,
-                 loss: Optional[Dict] = None,
+                 backbone: dict,
+                 neck: dict,
+                 head: dict,
                  queue_len: int = 65536,
                  feat_dim: int = 128,
                  momentum: float = 0.999,
                  loss_lambda: float = 0.5,
-                 preprocess_cfg: Optional[Dict] = None,
-                 init_cfg: Optional[Union[Dict, List[Dict]]] = None,
-                 **kwargs):
-        super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
-        assert backbone is not None
-        assert neck is not None
-        self.encoder_q = nn.Sequential(
-            build_backbone(backbone), build_neck(neck))
-        self.encoder_k = nn.Sequential(
-            build_backbone(backbone), build_neck(neck))
+                 pretrained: Optional[str] = None,
+                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
+                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            head=head,
+            pretrained=pretrained,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg)

-        for param_q, param_k in zip(self.encoder_q.parameters(),
-                                    self.encoder_k.parameters()):
-            param_k.data.copy_(param_q.data)
-            param_k.requires_grad = False
-
-        self.backbone = self.encoder_q[0]
-        assert head is not None
-        self.head = build_head(head)
-        assert loss is not None
-        self.loss = build_loss(loss)
+        # create momentum model
+        self.encoder_k = ExponentialMovingAverage(
+            nn.Sequential(self.backbone, self.neck), 1 - momentum)

         self.queue_len = queue_len
         self.momentum = momentum
         self.loss_lambda = loss_lambda

         # create the queue
         self.register_buffer('queue', torch.randn(feat_dim, queue_len))
         self.queue = nn.functional.normalize(self.queue, dim=0)
         self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
         # create the second queue for dense output
         self.register_buffer('queue2', torch.randn(feat_dim, queue_len))
         self.queue2 = nn.functional.normalize(self.queue2, dim=0)
         self.register_buffer('queue2_ptr', torch.zeros(1, dtype=torch.long))

-    @torch.no_grad()
-    def _momentum_update_key_encoder(self):
-        """Momentum update of the key encoder."""
-        for param_q, param_k in zip(self.encoder_q.parameters(),
-                                    self.encoder_k.parameters()):
-            param_k.data = param_k.data * self.momentum + \
-                param_q.data * (1. - self.momentum)
-
     @torch.no_grad()
-    def _dequeue_and_enqueue(self, keys):
+    def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None:
         """Update queue."""
         # gather keys before updating queue
         keys = concat_all_gather(keys)
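The hand-written `_momentum_update_key_encoder` can be deleted because mmengine's `ExponentialMovingAverage` performs the same update. A minimal sketch of the equivalence; note that mmengine's `momentum` weights the *new* value, which is why the constructor above passes `1 - momentum`:

    import torch.nn as nn
    from mmengine.model import ExponentialMovingAverage

    online = nn.Linear(4, 4)
    # MoCo-style momentum 0.999 becomes mmengine momentum 0.001
    target = ExponentialMovingAverage(online, momentum=0.001)
    target.update_parameters(online)
    # for each parameter this computes
    #   avg = avg * (1 - 0.001) + src * 0.001
    # i.e. the deleted loop: param_k = param_k * 0.999 + param_q * 0.001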
@@ -113,7 +99,7 @@ class DenseCL(BaseModel):
         self.queue_ptr[0] = ptr

     @torch.no_grad()
-    def _dequeue_and_enqueue2(self, keys):
+    def _dequeue_and_enqueue2(self, keys: torch.Tensor) -> None:
         """Update queue2."""
         # gather keys before updating queue
         keys = concat_all_gather(keys)
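Both `_dequeue_and_enqueue` methods implement the standard MoCo ring buffer. A simplified, self-contained sketch of the shared logic (the real methods first `concat_all_gather` the keys across GPUs and assume the queue length is divisible by the batch size):

    import torch

    def dequeue_and_enqueue(queue: torch.Tensor, queue_ptr: torch.Tensor,
                            keys: torch.Tensor) -> None:
        # queue: (feat_dim, queue_len); keys: (batch_size, feat_dim)
        batch_size = keys.shape[0]
        ptr = int(queue_ptr)
        queue[:, ptr:ptr + batch_size] = keys.T  # overwrite the oldest keys
        queue_ptr[0] = (ptr + batch_size) % queue.shape[1]  # advance pointer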
@@ -129,41 +115,40 @@ class DenseCL(BaseModel):
         self.queue2_ptr[0] = ptr

-    def extract_feat(self, inputs: List[torch.Tensor],
-                     data_samples: List[SelfSupDataSample],
+    def extract_feat(self, batch_inputs: List[torch.Tensor],
                      **kwargs) -> Tuple[torch.Tensor]:
         """Function to extract features from backbone.

         Args:
-            inputs (List[torch.Tensor]): The input images.
+            batch_inputs (List[torch.Tensor]): The input images.
             data_samples (List[SelfSupDataSample]): All elements required
                 during the forward function.

         Returns:
             Tuple[torch.Tensor]: backbone outputs.
         """
-        x = self.backbone(inputs[0])
+        x = self.backbone(batch_inputs[0])
         return x

-    def forward_train(self, inputs: List[torch.Tensor],
-                      data_samples: List[SelfSupDataSample],
-                      **kwargs) -> Dict[str, torch.Tensor]:
+    def loss(self, batch_inputs: List[torch.Tensor],
+             data_samples: List[SelfSupDataSample],
+             **kwargs) -> Dict[str, torch.Tensor]:
         """Forward computation during training.

         Args:
-            inputs (List[torch.Tensor]): The input images.
+            batch_inputs (List[torch.Tensor]): The input images.
             data_samples (List[SelfSupDataSample]): All elements required
                 during the forward function.

         Returns:
             Dict[str, torch.Tensor]: A dictionary of loss components.
         """
-        assert isinstance(inputs, list)
-        im_q = inputs[0]
-        im_k = inputs[1]
+        assert isinstance(batch_inputs, list)
+        im_q = batch_inputs[0]
+        im_k = batch_inputs[1]
         # compute query features
-        q_b = self.encoder_q[0](im_q)  # backbone features
-        q, q_grid, q2 = self.encoder_q[1](q_b)  # queries: NxC; NxCxS^2
+        q_b = self.backbone(im_q)  # backbone features
+        q, q_grid, q2 = self.neck(q_b)  # queries: NxC; NxCxS^2
         q_b = q_b[0]
         q_b = q_b.view(q_b.size(0), q_b.size(1), -1)
@@ -175,13 +160,14 @@ class DenseCL(BaseModel):
         # compute key features
         with torch.no_grad():  # no gradient to keys
             # update the key encoder
-            self._momentum_update_key_encoder()
+            self.encoder_k.update_parameters(
+                nn.Sequential(self.backbone, self.neck))

             # shuffle for making use of BN
             im_k, idx_unshuffle = batch_shuffle_ddp(im_k)

-            k_b = self.encoder_k[0](im_k)  # backbone features
-            k, k_grid, k2 = self.encoder_k[1](k_b)  # keys: NxC; NxCxS^2
+            k_b = self.encoder_k.module[0](im_k)  # backbone features
+            k, k_grid, k2 = self.encoder_k.module[1](k_b)  # keys: NxC; NxCxS^2
             k_b = k_b[0]
             k_b = k_b.view(k_b.size(0), k_b.size(1), -1)
@@ -221,10 +207,8 @@ class DenseCL(BaseModel):
         l_neg_dense = torch.einsum(
             'nc,ck->nk', [q_grid, self.queue2.clone().detach()])

-        logits, labels = self.head(l_pos, l_neg)
-        logits_dense, labels_dense = self.head(l_pos_dense, l_neg_dense)
-        loss_single = self.loss(logits, labels)
-        loss_dense = self.loss(logits_dense, labels_dense)
+        loss_single = self.head(l_pos, l_neg)
+        loss_dense = self.head(l_pos_dense, l_neg_dense)

         losses = dict()
         losses['loss_single'] = loss_single * (1 - self.loss_lambda)
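With the head now owning the criterion, the algorithm only assembles similarity logits before calling it. A shape sketch with dummy tensors (S is the spatial grid size of the neck output; names are illustrative, not from the source):

    import torch

    N, C, K, S = 8, 128, 65536, 7
    q, k = torch.randn(N, C), torch.randn(N, C)
    q_grid = torch.randn(N * S * S, C)  # flattened per-location queries
    queue, queue2 = torch.randn(C, K), torch.randn(C, K)

    l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)  # Nx1 positives
    l_neg = torch.einsum('nc,ck->nk', [q, queue])  # NxK queue negatives
    # the dense branch computes the same per spatial location:
    l_neg_dense = torch.einsum('nc,ck->nk', [q_grid, queue2])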
@@ -235,16 +219,17 @@ class DenseCL(BaseModel):
         return losses

-    def forward_test(self, inputs: List[torch.Tensor],
-                     data_samples: List[SelfSupDataSample],
-                     **kwargs) -> object:
-        """The forward function in testing
+    def predict(self, batch_inputs: List[torch.Tensor],
+                data_samples: List[SelfSupDataSample],
+                **kwargs) -> SelfSupDataSample:
+        """Predict results from the extracted features.

         Args:
-            inputs (List[torch.Tensor]): The input images.
+            batch_inputs (List[torch.Tensor]): The input images.
             data_samples (List[SelfSupDataSample]): All elements required
                 during the forward function.
         """
-        q_grid = self.extract_feat(inputs, data_samples)[0]
+        q_grid = self.extract_feat(batch_inputs)[0]
         q_grid = q_grid.view(q_grid.size(0), q_grid.size(1), -1)
         q_grid = nn.functional.normalize(q_grid, dim=1)
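In short, `forward_train`/`forward_test` fold into the mode-dispatched `forward` inherited from the refactored `BaseModel`; the updated test at the bottom of this commit exercises all three modes. The contract, as driven by that test:

    # model(batch_inputs, data_samples, mode='loss')    -> dict of loss tensors
    # model(batch_inputs, data_samples, mode='tensor')  -> extract_feat() output
    # model(batch_inputs, data_samples, mode='predict') -> SelfSupDataSample results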


@@ -1,11 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+import torch
 import torch.nn as nn
-from mmcv.runner import BaseModule
+from mmengine.model import BaseModule

-from ..builder import NECKS
+from mmselfsup.registry import MODELS


-@NECKS.register_module()
+@MODELS.register_module()
 class DenseCLNeck(BaseModule):
     """The non-linear neck of DenseCL.
@@ -22,11 +25,11 @@ class DenseCLNeck(BaseModule):
     """

     def __init__(self,
-                 in_channels,
-                 hid_channels,
-                 out_channels,
-                 num_grid=None,
-                 init_cfg=None):
+                 in_channels: int,
+                 hid_channels: int,
+                 out_channels: int,
+                 num_grid: Optional[int] = None,
+                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
         super(DenseCLNeck, self).__init__(init_cfg)
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.mlp = nn.Sequential(
@@ -41,7 +44,7 @@ class DenseCLNeck(BaseModule):
             nn.Conv2d(hid_channels, out_channels, 1))
         self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))

-    def forward(self, x):
+    def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
         """Forward function of neck.

         Args:
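For context, the neck's three outputs are what the algorithm unpacks as `q, q_grid, q2` above. A hedged shape sketch for the ResNet-50 stage-4 feature map used in the base config (num_grid=None, so the 7x7 map is used directly; the instantiation is illustrative):

    import torch
    from mmselfsup.models.necks import DenseCLNeck

    neck = DenseCLNeck(in_channels=2048, hid_channels=2048, out_channels=128)
    q, q_grid, q2 = neck([torch.randn(2, 2048, 7, 7)])
    # q:      (2, 128)     global MLP projection
    # q_grid: (2, 128, 49) per-location dense projections
    # q2:     (2, 128)     pooled dense projection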


@@ -27,8 +27,10 @@ neck = dict(
     hid_channels=2,
     out_channels=2,
     num_grid=None)
-head = dict(type='ContrastiveHead', temperature=0.2)
-loss = dict(type='mmcls.CrossEntropyLoss')
+head = dict(
+    type='ContrastiveHead',
+    loss=dict(type='mmcls.CrossEntropyLoss'),
+    temperature=0.2)


 def mock_batch_shuffle_ddp(img):
@@ -45,43 +47,21 @@ def mock_concat_all_gather(img):

 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
 def test_densecl():
-    preprocess_cfg = {
-        'mean': [0.5, 0.5, 0.5],
-        'std': [0.5, 0.5, 0.5],
-        'to_rgb': True
+    data_preprocessor = {
+        'mean': (123.675, 116.28, 103.53),
+        'std': (58.395, 57.12, 57.375),
+        'bgr_to_rgb': True
     }
-    with pytest.raises(AssertionError):
-        alg = DenseCL(
-            backbone=backbone,
-            neck=None,
-            head=head,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = DenseCL(
-            backbone=backbone,
-            neck=neck,
-            head=None,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = DenseCL(
-            backbone=backbone,
-            neck=neck,
-            head=head,
-            loss=None,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))

     alg = DenseCL(
         backbone=backbone,
         neck=neck,
         head=head,
-        loss=loss,
         queue_len=queue_len,
         feat_dim=feat_dim,
         momentum=momentum,
         loss_lambda=loss_lambda,
-        preprocess_cfg=copy.deepcopy(preprocess_cfg))
+        data_preprocessor=copy.deepcopy(data_preprocessor))

     assert alg.queue.size() == torch.Size([feat_dim, queue_len])
     assert alg.queue2.size() == torch.Size([feat_dim, queue_len])
@@ -100,21 +80,19 @@ def test_densecl():
             SelfSupDataSample(),
         } for _ in range(2)]

-    fake_outputs = alg(fake_data, return_loss=True)
-    assert isinstance(fake_outputs['loss'].item(), float)
-    assert isinstance(fake_outputs['log_vars']['loss_single'], float)
-    assert isinstance(fake_outputs['log_vars']['loss_dense'], float)
-    assert fake_outputs['log_vars']['loss_single'] > 0
-    assert fake_outputs['log_vars']['loss_dense'] > 0
+    fake_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
+    fake_loss = alg(fake_inputs, fake_data_samples, mode='loss')
+    assert isinstance(fake_loss['loss_single'].item(), float)
+    assert isinstance(fake_loss['loss_dense'].item(), float)
+    assert fake_loss['loss_single'].item() > 0
+    assert fake_loss['loss_dense'].item() > 0
     assert alg.queue_ptr.item() == 2
     assert alg.queue2_ptr.item() == 2

-    fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
-    fake_feat = alg.extract_feat(
-        inputs=fake_inputs, data_samples=fake_data_samples)
+    fake_feat = alg(fake_inputs, fake_data_samples, mode='tensor')
     assert list(fake_feat[0].shape) == [2, 512, 7, 7]

-    fake_outputs = alg(fake_data, return_loss=False)
+    fake_outputs = alg(fake_inputs, fake_data_samples, mode='predict')
     assert 'q_grid' in fake_outputs
     assert 'value' in fake_outputs.q_grid
     assert list(fake_outputs.q_grid.value.shape) == [2, 512, 49]