[Refactor]: Refactor the DenseCL algorithm

pull/352/head
renqin 2022-07-07 12:37:16 +00:00 committed by fangyixiao18
parent d6dfa9fe40
commit 5f778aa552
5 changed files with 97 additions and 122 deletions

View File

@@ -5,6 +5,10 @@ model = dict(
     feat_dim=128,
     momentum=0.999,
     loss_lambda=0.5,
+    data_preprocessor=dict(
+        mean=(123.675, 116.28, 103.53),
+        std=(58.395, 57.12, 57.375),
+        bgr_to_rgb=True),
     backbone=dict(
         type='ResNet',
         depth=50,
@@ -17,5 +21,8 @@ model = dict(
         hid_channels=2048,
         out_channels=128,
         num_grid=None),
-    head=dict(type='ContrastiveHead', temperature=0.2),
-    loss=dict(type='mmcls.CrossEntropyLoss'))
+    head=dict(
+        type='ContrastiveHead',
+        loss=dict(type='mmcls.CrossEntropyLoss'),
+        temperature=0.2),
+)
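The new data_preprocessor block moves image normalization from the data pipeline into the model. As a rough illustration of what the three fields imply (plain PyTorch, not the actual SelfSupDataPreprocessor code), see the sketch below.

import torch

# Illustrative only: roughly what mean/std/bgr_to_rgb in the config above imply.
# A float batch of BGR images (N, 3, H, W) is flipped to RGB and normalized
# channel-wise with the ImageNet statistics from the config.
def normalize_batch(batch_bgr: torch.Tensor) -> torch.Tensor:
    mean = torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
    batch_rgb = batch_bgr[:, [2, 1, 0], ...]  # bgr_to_rgb=True
    return (batch_rgb - mean) / std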

View File

@@ -5,8 +5,10 @@ _base_ = [
     '../_base_/default_runtime.py',
 ]
 
+find_unused_parameters = True
+
 # runtime settings
-# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
-# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
-# it will remove the oldest one to keep the number of total ckpts as 3
-checkpoint_config = dict(interval=10, max_keep_ckpts=3)
+default_hooks = dict(
+    logger=dict(type='LoggerHook', interval=50),
+    # only keeps the latest 3 checkpoints
+    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))
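For context, MMEngine replaces the standalone checkpoint_config with an entry in default_hooks. A hedged sketch of the resulting runtime block, with the removed explanation restored as comments (the remaining default hooks come from the inherited default_runtime.py):

# Sketch of the new runtime settings (illustrative, mirrors the hunk above).
default_hooks = dict(
    # log training status every 50 iterations
    logger=dict(type='LoggerHook', interval=50),
    # CheckpointHook saves a checkpoint every 10 intervals; with
    # max_keep_ckpts=3 the oldest checkpoint is deleted once a 4th is saved,
    # so at most 3 checkpoints remain in work_dirs.
    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))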

View File

@@ -4,16 +4,16 @@ from typing import Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from mmengine.data import BaseDataElement
+from mmengine.model import ExponentialMovingAverage
 
 from mmselfsup.core import SelfSupDataSample
+from mmselfsup.registry import MODELS
 from mmselfsup.utils import (batch_shuffle_ddp, batch_unshuffle_ddp,
                              concat_all_gather)
-from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
-                       build_neck)
 from .base import BaseModel
 
 
-@ALGORITHMS.register_module()
+@MODELS.register_module()
 class DenseCL(BaseModel):
     """DenseCL.
@@ -25,11 +25,8 @@ class DenseCL(BaseModel):
     Args:
         backbone (dict): Config dict for module of backbone.
         neck (dict): Config dict for module of deep features to compact
-            feature vectors. Defaults to None.
-        head (dict): Config dict for module of head functions.
-            Defaults to None.
-        loss (dict): Config dict for module of loss functions.
-            Defaults to None.
+            feature vectors.
+        head (dict): Config dict for module of head functions.
         queue_len (int): Number of negative keys maintained in the queue.
             Defaults to 65536.
         feat_dim (int): Dimension of compact feature vectors. Defaults to 128.
@@ -37,66 +34,55 @@ class DenseCL(BaseModel):
             encoder. Defaults to 0.999.
         loss_lambda (float): Loss weight for the single and dense contrastive
             loss. Defaults to 0.5.
-        preprocess_cfg (Dict, optional): Config dict to preprocess images.
+        pretrained (str, optional): The pretrained checkpoint path, support
+            local path and remote path. Defaults to None.
+        data_preprocessor (Union[dict, nn.Module], optional): The config for
+            preprocessing input data. If None or no specified type, it will use
+            "SelfSupDataPreprocessor" as type.
+            See :class:`SelfSupDataPreprocessor` for more details.
             Defaults to None.
         init_cfg (Dict or List[Dict], optional): Config dict for weight
             initialization. Defaults to None.
     """
 
     def __init__(self,
-                 backbone,
-                 neck: Optional[Dict] = None,
-                 head: Optional[Dict] = None,
-                 loss: Optional[Dict] = None,
+                 backbone: dict,
+                 neck: dict,
+                 head: dict,
                  queue_len: int = 65536,
                  feat_dim: int = 128,
                  momentum: float = 0.999,
                  loss_lambda: float = 0.5,
-                 preprocess_cfg: Optional[Dict] = None,
-                 init_cfg: Optional[Union[Dict, List[Dict]]] = None,
-                 **kwargs):
-        super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
-        assert backbone is not None
-        assert neck is not None
-        self.encoder_q = nn.Sequential(
-            build_backbone(backbone), build_neck(neck))
-        self.encoder_k = nn.Sequential(
-            build_backbone(backbone), build_neck(neck))
-        for param_q, param_k in zip(self.encoder_q.parameters(),
-                                    self.encoder_k.parameters()):
-            param_k.data.copy_(param_q.data)
-            param_k.requires_grad = False
-        self.backbone = self.encoder_q[0]
-        assert head is not None
-        self.head = build_head(head)
-        assert loss is not None
-        self.loss = build_loss(loss)
+                 pretrained: Optional[str] = None,
+                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
+                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            head=head,
+            pretrained=pretrained,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg)
+
+        # create momentum model
+        self.encoder_k = ExponentialMovingAverage(
+            nn.Sequential(self.backbone, self.neck), 1 - momentum)
 
         self.queue_len = queue_len
-        self.momentum = momentum
         self.loss_lambda = loss_lambda
 
         # create the queue
         self.register_buffer('queue', torch.randn(feat_dim, queue_len))
         self.queue = nn.functional.normalize(self.queue, dim=0)
         self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
         # create the second queue for dense output
         self.register_buffer('queue2', torch.randn(feat_dim, queue_len))
         self.queue2 = nn.functional.normalize(self.queue2, dim=0)
         self.register_buffer('queue2_ptr', torch.zeros(1, dtype=torch.long))
 
     @torch.no_grad()
-    def _momentum_update_key_encoder(self):
-        """Momentum update of the key encoder."""
-        for param_q, param_k in zip(self.encoder_q.parameters(),
-                                    self.encoder_k.parameters()):
-            param_k.data = param_k.data * self.momentum + \
-                           param_q.data * (1. - self.momentum)
-
-    @torch.no_grad()
-    def _dequeue_and_enqueue(self, keys):
+    def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None:
         """Update queue."""
         # gather keys before updating queue
         keys = concat_all_gather(keys)
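The hand-written key-encoder update above is replaced by MMEngine's ExponentialMovingAverage wrapper. A minimal sketch of the equivalence, using stand-in modules rather than the real DenseCL backbone and neck (MMEngine's EMA computes averaged = (1 - momentum) * averaged + momentum * source, which is why the constructor receives 1 - momentum):

import torch
import torch.nn as nn
from mmengine.model import ExponentialMovingAverage

# Stand-ins for backbone + neck; not the real DenseCL modules.
online = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
target = ExponentialMovingAverage(online, momentum=1 - 0.999)

# Each call nudges the averaged (key) weights toward the online (query)
# weights, replacing the removed _momentum_update_key_encoder loop.
target.update_parameters(online)

# The wrapped copy is reached through .module, matching encoder_k.module[...]
out = target.module(torch.randn(2, 8))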
@@ -113,7 +99,7 @@ class DenseCL(BaseModel):
         self.queue_ptr[0] = ptr
 
     @torch.no_grad()
-    def _dequeue_and_enqueue2(self, keys):
+    def _dequeue_and_enqueue2(self, keys: torch.Tensor) -> None:
         """Update queue2."""
         # gather keys before updating queue
         keys = concat_all_gather(keys)
@@ -129,41 +115,40 @@ class DenseCL(BaseModel):
         self.queue2_ptr[0] = ptr
 
-    def extract_feat(self, inputs: List[torch.Tensor],
-                     data_samples: List[SelfSupDataSample],
+    def extract_feat(self, batch_inputs: List[torch.Tensor],
                      **kwargs) -> Tuple[torch.Tensor]:
         """Function to extract features from backbone.
 
         Args:
-            inputs (List[torch.Tensor]): The input images.
+            batch_inputs (List[torch.Tensor]): The input images.
             data_samples (List[SelfSupDataSample]): All elements required
                 during the forward function.
 
         Returns:
             Tuple[torch.Tensor]: backbone outputs.
         """
-        x = self.backbone(inputs[0])
+        x = self.backbone(batch_inputs[0])
         return x
 
-    def forward_train(self, inputs: List[torch.Tensor],
-                      data_samples: List[SelfSupDataSample],
-                      **kwargs) -> Dict[str, torch.Tensor]:
+    def loss(self, batch_inputs: List[torch.Tensor],
+             data_samples: List[SelfSupDataSample],
+             **kwargs) -> Dict[str, torch.Tensor]:
         """Forward computation during training.
 
         Args:
-            inputs (List[torch.Tensor]): The input images.
+            batch_inputs (List[torch.Tensor]): The input images.
             data_samples (List[SelfSupDataSample]): All elements required
                 during the forward function.
 
         Returns:
             Dict[str, torch.Tensor]: A dictionary of loss components.
         """
-        assert isinstance(inputs, list)
-        im_q = inputs[0]
-        im_k = inputs[1]
+        assert isinstance(batch_inputs, list)
+        im_q = batch_inputs[0]
+        im_k = batch_inputs[1]
         # compute query features
-        q_b = self.encoder_q[0](im_q)  # backbone features
-        q, q_grid, q2 = self.encoder_q[1](q_b)  # queries: NxC; NxCxS^2
+        q_b = self.backbone(im_q)  # backbone features
+        q, q_grid, q2 = self.neck(q_b)  # queries: NxC; NxCxS^2
         q_b = q_b[0]
         q_b = q_b.view(q_b.size(0), q_b.size(1), -1)
@@ -175,13 +160,14 @@ class DenseCL(BaseModel):
         # compute key features
         with torch.no_grad():  # no gradient to keys
             # update the key encoder
-            self._momentum_update_key_encoder()
+            self.encoder_k.update_parameters(
+                nn.Sequential(self.backbone, self.neck))
 
             # shuffle for making use of BN
             im_k, idx_unshuffle = batch_shuffle_ddp(im_k)
 
-            k_b = self.encoder_k[0](im_k)  # backbone features
-            k, k_grid, k2 = self.encoder_k[1](k_b)  # keys: NxC; NxCxS^2
+            k_b = self.encoder_k.module[0](im_k)  # backbone features
+            k, k_grid, k2 = self.encoder_k.module[1](k_b)  # keys: NxC; NxCxS^2
             k_b = k_b[0]
             k_b = k_b.view(k_b.size(0), k_b.size(1), -1)
@@ -221,10 +207,8 @@ class DenseCL(BaseModel):
         l_neg_dense = torch.einsum(
             'nc,ck->nk', [q_grid, self.queue2.clone().detach()])
 
-        logits, labels = self.head(l_pos, l_neg)
-        logits_dense, labels_dense = self.head(l_pos_dense, l_neg_dense)
-        loss_single = self.loss(logits, labels)
-        loss_dense = self.loss(logits_dense, labels_dense)
+        loss_single = self.head(l_pos, l_neg)
+        loss_dense = self.head(l_pos_dense, l_neg_dense)
 
         losses = dict()
         losses['loss_single'] = loss_single * (1 - self.loss_lambda)
@@ -235,16 +219,17 @@ class DenseCL(BaseModel):
         return losses
 
-    def forward_test(self, inputs: List[torch.Tensor],
-                     data_samples: List[SelfSupDataSample],
-                     **kwargs) -> object:
-        """The forward function in testing
+    def predict(self, batch_inputs: List[torch.Tensor],
+                data_samples: List[SelfSupDataSample],
+                **kwargs) -> SelfSupDataSample:
+        """Predict results from the extracted features.
 
         Args:
             inputs (List[torch.Tensor]): The input images.
             data_samples (List[SelfSupDataSample]): All elements required
                 during the forward function.
         """
-        q_grid = self.extract_feat(inputs, data_samples)[0]
+        q_grid = self.extract_feat(batch_inputs)[0]
         q_grid = q_grid.view(q_grid.size(0), q_grid.size(1), -1)
         q_grid = nn.functional.normalize(q_grid, dim=1)
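With forward_train/forward_test gone, DenseCL now exposes the loss/predict/extract_feat entry points that the MMEngine-style BaseModel dispatches to by mode. A hedged sketch of that dispatch (illustrative, not the literal mmselfsup BaseModel code):

# Illustrative mode dispatch; the method names match the refactored DenseCL.
def forward(self, batch_inputs, data_samples=None, mode='tensor'):
    if mode == 'tensor':
        return self.extract_feat(batch_inputs)           # backbone features
    elif mode == 'loss':
        return self.loss(batch_inputs, data_samples)      # dict of losses
    elif mode == 'predict':
        return self.predict(batch_inputs, data_samples)   # SelfSupDataSample
    raise RuntimeError(f'Invalid mode "{mode}".')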

View File

@@ -1,11 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+import torch
 import torch.nn as nn
-from mmcv.runner import BaseModule
+from mmengine.model import BaseModule
 
-from ..builder import NECKS
+from mmselfsup.registry import MODELS
 
 
-@NECKS.register_module()
+@MODELS.register_module()
 class DenseCLNeck(BaseModule):
     """The non-linear neck of DenseCL.
@@ -22,11 +25,11 @@ class DenseCLNeck(BaseModule):
     """
 
     def __init__(self,
-                 in_channels,
-                 hid_channels,
-                 out_channels,
-                 num_grid=None,
-                 init_cfg=None):
+                 in_channels: int,
+                 hid_channels: int,
+                 out_channels: int,
+                 num_grid: Optional[int] = None,
+                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
         super(DenseCLNeck, self).__init__(init_cfg)
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.mlp = nn.Sequential(
@@ -41,7 +44,7 @@ class DenseCLNeck(BaseModule):
             nn.Conv2d(hid_channels, out_channels, 1))
         self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))
 
-    def forward(self, x):
+    def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
         """Forward function of neck.
 
         Args:
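Since the neck now registers into the unified MODELS registry instead of NECKS, it is built the same way as the algorithm itself. A hedged usage sketch with the channel sizes from the model config above:

from mmselfsup.registry import MODELS

# Build the refactored neck from the unified registry (previously NECKS).
neck = MODELS.build(
    dict(
        type='DenseCLNeck',
        in_channels=2048,
        hid_channels=2048,
        out_channels=128,
        num_grid=None))

# In DenseCL.loss() the neck returns three tensors per view:
#   q       - global NxC embedding for the single contrastive loss
#   q_grid  - dense NxCxS^2 grid features for the dense contrastive loss
#   q2      - pooled NxC summary of the grid branch
# q, q_grid, q2 = neck(backbone_features)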

View File

@@ -27,8 +27,10 @@ neck = dict(
     hid_channels=2,
     out_channels=2,
     num_grid=None)
-head = dict(type='ContrastiveHead', temperature=0.2)
-loss = dict(type='mmcls.CrossEntropyLoss')
+head = dict(
+    type='ContrastiveHead',
+    loss=dict(type='mmcls.CrossEntropyLoss'),
+    temperature=0.2)
 
 
 def mock_batch_shuffle_ddp(img):
@@ -45,43 +47,21 @@ def mock_concat_all_gather(img):
 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
 def test_densecl():
-    preprocess_cfg = {
-        'mean': [0.5, 0.5, 0.5],
-        'std': [0.5, 0.5, 0.5],
-        'to_rgb': True
+    data_preprocessor = {
+        'mean': (123.675, 116.28, 103.53),
+        'std': (58.395, 57.12, 57.375),
+        'bgr_to_rgb': True
     }
-    with pytest.raises(AssertionError):
-        alg = DenseCL(
-            backbone=backbone,
-            neck=None,
-            head=head,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = DenseCL(
-            backbone=backbone,
-            neck=neck,
-            head=None,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = DenseCL(
-            backbone=backbone,
-            neck=neck,
-            head=head,
-            loss=None,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
 
     alg = DenseCL(
         backbone=backbone,
         neck=neck,
         head=head,
-        loss=loss,
         queue_len=queue_len,
         feat_dim=feat_dim,
         momentum=momentum,
         loss_lambda=loss_lambda,
-        preprocess_cfg=copy.deepcopy(preprocess_cfg))
+        data_preprocessor=copy.deepcopy(data_preprocessor))
     assert alg.queue.size() == torch.Size([feat_dim, queue_len])
     assert alg.queue2.size() == torch.Size([feat_dim, queue_len])
@@ -100,21 +80,19 @@ def test_densecl():
         SelfSupDataSample(),
     } for _ in range(2)]
 
-    fake_outputs = alg(fake_data, return_loss=True)
-    assert isinstance(fake_outputs['loss'].item(), float)
-    assert isinstance(fake_outputs['log_vars']['loss_single'], float)
-    assert isinstance(fake_outputs['log_vars']['loss_dense'], float)
-    assert fake_outputs['log_vars']['loss_single'] > 0
-    assert fake_outputs['log_vars']['loss_dense'] > 0
+    fake_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
+    fake_loss = alg(fake_inputs, fake_data_samples, mode='loss')
+    assert isinstance(fake_loss['loss_single'].item(), float)
+    assert isinstance(fake_loss['loss_dense'].item(), float)
+    assert fake_loss['loss_single'].item() > 0
+    assert fake_loss['loss_dense'].item() > 0
     assert alg.queue_ptr.item() == 2
     assert alg.queue2_ptr.item() == 2
 
-    fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
-    fake_feat = alg.extract_feat(
-        inputs=fake_inputs, data_samples=fake_data_samples)
+    fake_feat = alg(fake_inputs, fake_data_samples, mode='tensor')
     assert list(fake_feat[0].shape) == [2, 512, 7, 7]
 
-    fake_outputs = alg(fake_data, return_loss=False)
+    fake_outputs = alg(fake_inputs, fake_data_samples, mode='predict')
     assert 'q_grid' in fake_outputs
     assert 'value' in fake_outputs.q_grid
     assert list(fake_outputs.q_grid.value.shape) == [2, 512, 49]