[Refactor]: refactor densecl algorithm
parent d6dfa9fe40
commit 5f778aa552
@@ -5,6 +5,10 @@ model = dict(
     feat_dim=128,
     momentum=0.999,
     loss_lambda=0.5,
+    data_preprocessor=dict(
+        mean=(123.675, 116.28, 103.53),
+        std=(58.395, 57.12, 57.375),
+        bgr_to_rgb=True),
     backbone=dict(
         type='ResNet',
         depth=50,
@@ -17,5 +21,8 @@ model = dict(
         hid_channels=2048,
         out_channels=128,
         num_grid=None),
-    head=dict(type='ContrastiveHead', temperature=0.2),
-    loss=dict(type='mmcls.CrossEntropyLoss'))
+    head=dict(
+        type='ContrastiveHead',
+        loss=dict(type='mmcls.CrossEntropyLoss'),
+        temperature=0.2),
+)
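The two hunks above fold the standalone `loss` field into the `head` config. A minimal sketch of why that simplifies the algorithm, using a hypothetical `SketchContrastiveHead` and a plain `nn.CrossEntropyLoss` standing in for `mmcls.CrossEntropyLoss`: the head owns its criterion and returns a ready loss, so the algorithm no longer needs a separate `self.loss`.

import torch
import torch.nn as nn


class SketchContrastiveHead(nn.Module):
    """Scores positive/negative pairs and returns a scalar loss."""

    def __init__(self, loss: nn.Module, temperature: float = 0.2) -> None:
        super().__init__()
        # the loss module now lives inside the head, so callers only ever
        # invoke the head and get a loss back
        self.loss = loss
        self.temperature = temperature

    def forward(self, pos: torch.Tensor, neg: torch.Tensor) -> torch.Tensor:
        # logits: N x (1 + K); the positive logit comes first, so the
        # correct "class" for every sample is index 0
        logits = torch.cat((pos, neg), dim=1) / self.temperature
        labels = torch.zeros((pos.size(0), ), dtype=torch.long)
        return self.loss(logits, labels)


head = SketchContrastiveHead(loss=nn.CrossEntropyLoss(), temperature=0.2)
loss_value = head(torch.randn(4, 1), torch.randn(4, 16))

This is exactly the shape of the change made to the `loss()` method further down, where two `self.head(...)` calls replace the old head-then-loss pipeline.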
@@ -5,8 +5,10 @@ _base_ = [
     '../_base_/default_runtime.py',
 ]
 
+find_unused_parameters = True
+
 # runtime settings
-# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
-# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
-# it will remove the oldest one to keep the number of total ckpts as 3
-checkpoint_config = dict(interval=10, max_keep_ckpts=3)
+default_hooks = dict(
+    logger=dict(type='LoggerHook', interval=50),
+    # only keeps the latest 3 checkpoints
+    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))
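The removed comments describe the rotation rule that `max_keep_ckpts` enforces; the new config expresses the same rule through a `CheckpointHook` under `default_hooks`. A toy simulation of that rule (illustration only, not mmengine code):

from collections import deque


def simulate_checkpoints(num_epochs: int, interval: int,
                         max_keep_ckpts: int) -> list:
    kept = deque()
    for epoch in range(interval, num_epochs + 1, interval):
        kept.append(f'epoch_{epoch}.pth')
        if len(kept) > max_keep_ckpts:
            kept.popleft()  # the oldest checkpoint is dropped from work_dirs
    return list(kept)


# with interval=10 and max_keep_ckpts=3, a 200-epoch run ends with only
# epoch_180.pth, epoch_190.pth and epoch_200.pth on disk
print(simulate_checkpoints(200, 10, 3))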
@@ -4,16 +4,16 @@ from typing import Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
+from mmengine.data import BaseDataElement
+from mmengine.model import ExponentialMovingAverage
 
 from mmselfsup.core import SelfSupDataSample
 from mmselfsup.registry import MODELS
 from mmselfsup.utils import (batch_shuffle_ddp, batch_unshuffle_ddp,
                              concat_all_gather)
-from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
-                       build_neck)
 from .base import BaseModel
 
 
-@ALGORITHMS.register_module()
+@MODELS.register_module()
 class DenseCL(BaseModel):
     """DenseCL.
 
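Note the registration change: `@ALGORITHMS.register_module()` becomes `@MODELS.register_module()` and the `..builder` helpers disappear, with per-component registries consolidated into a single `MODELS` registry. A minimal, self-contained sketch of that registry pattern (the `TinyNeck` class is invented for illustration):

from mmengine.registry import Registry

MODELS = Registry('model')


@MODELS.register_module()
class TinyNeck:

    def __init__(self, out_channels: int = 2) -> None:
        self.out_channels = out_channels


# configs like dict(type='TinyNeck', ...) are built through the registry,
# which is why build_backbone/build_neck/build_head are no longer needed
neck = MODELS.build(dict(type='TinyNeck', out_channels=8))
assert neck.out_channels == 8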
@@ -25,11 +25,8 @@ class DenseCL(BaseModel):
     Args:
         backbone (dict): Config dict for module of backbone.
         neck (dict): Config dict for module of deep features to compact
-            feature vectors. Defaults to None.
+            feature vectors.
         head (dict): Config dict for module of head functions.
-            Defaults to None.
-        loss (dict): Config dict for module of loss functions.
-            Defaults to None.
         queue_len (int): Number of negative keys maintained in the queue.
             Defaults to 65536.
         feat_dim (int): Dimension of compact feature vectors. Defaults to 128.
@@ -37,66 +34,55 @@ class DenseCL(BaseModel):
             encoder. Defaults to 0.999.
         loss_lambda (float): Loss weight for the single and dense contrastive
             loss. Defaults to 0.5.
-        preprocess_cfg (Dict, optional): Config dict to preprocess images.
+        pretrained (str, optional): The pretrained checkpoint path, support
+            local path and remote path. Defaults to None.
+        data_preprocessor (Union[dict, nn.Module], optional): The config for
+            preprocessing input data. If None or no specified type, it will use
+            "SelfSupDataPreprocessor" as type.
+            See :class:`SelfSupDataPreprocessor` for more details.
+            Defaults to None.
         init_cfg (Dict or List[Dict], optional): Config dict for weight
             initialization. Defaults to None.
     """
 
     def __init__(self,
-                 backbone,
-                 neck: Optional[Dict] = None,
-                 head: Optional[Dict] = None,
-                 loss: Optional[Dict] = None,
+                 backbone: dict,
+                 neck: dict,
+                 head: dict,
                  queue_len: int = 65536,
                  feat_dim: int = 128,
                  momentum: float = 0.999,
                  loss_lambda: float = 0.5,
-                 preprocess_cfg: Optional[Dict] = None,
-                 init_cfg: Optional[Union[Dict, List[Dict]]] = None,
-                 **kwargs):
-        super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
-        assert backbone is not None
-        assert neck is not None
-        self.encoder_q = nn.Sequential(
-            build_backbone(backbone), build_neck(neck))
-        self.encoder_k = nn.Sequential(
-            build_backbone(backbone), build_neck(neck))
+                 pretrained: Optional[str] = None,
+                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
+                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
+        super().__init__(
+            backbone=backbone,
+            neck=neck,
+            head=head,
+            pretrained=pretrained,
+            data_preprocessor=data_preprocessor,
+            init_cfg=init_cfg)
 
-        for param_q, param_k in zip(self.encoder_q.parameters(),
-                                    self.encoder_k.parameters()):
-            param_k.data.copy_(param_q.data)
-            param_k.requires_grad = False
-
-        self.backbone = self.encoder_q[0]
-        assert head is not None
-        self.head = build_head(head)
-        assert loss is not None
-        self.loss = build_loss(loss)
+        # create momentum model
+        self.encoder_k = ExponentialMovingAverage(
+            nn.Sequential(self.backbone, self.neck), 1 - momentum)
 
         self.queue_len = queue_len
         self.momentum = momentum
         self.loss_lambda = loss_lambda
 
         # create the queue
         self.register_buffer('queue', torch.randn(feat_dim, queue_len))
         self.queue = nn.functional.normalize(self.queue, dim=0)
         self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
 
         # create the second queue for dense output
         self.register_buffer('queue2', torch.randn(feat_dim, queue_len))
         self.queue2 = nn.functional.normalize(self.queue2, dim=0)
         self.register_buffer('queue2_ptr', torch.zeros(1, dtype=torch.long))
 
     @torch.no_grad()
-    def _momentum_update_key_encoder(self):
-        """Momentum update of the key encoder."""
-        for param_q, param_k in zip(self.encoder_q.parameters(),
-                                    self.encoder_k.parameters()):
-            param_k.data = param_k.data * self.momentum + \
-                param_q.data * (1. - self.momentum)
-
-    @torch.no_grad()
-    def _dequeue_and_enqueue(self, keys):
+    def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None:
         """Update queue."""
        # gather keys before updating queue
         keys = concat_all_gather(keys)
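The hand-rolled key encoder (a second `nn.Sequential` plus a manual copy/freeze loop and `_momentum_update_key_encoder`) is replaced by mmengine's `ExponentialMovingAverage` wrapper. The `1 - momentum` argument is the detail worth pausing on: the deleted update was `p_k = m * p_k + (1 - m) * p_q`, while the EMA wrapper computes `p_avg = (1 - momentum) * p_avg + momentum * p_src`, so passing `1 - m` reproduces the old behaviour. A numeric check of that identity (standalone sketch, not the wrapper itself):

import torch

m = 0.999                    # DenseCL's momentum coefficient
p_q = torch.randn(8)         # query-encoder parameter
p_k = torch.randn(8)         # key-encoder (averaged) parameter

old_update = p_k * m + p_q * (1. - m)            # removed manual update
ema_momentum = 1 - m                             # argument given to the wrapper
new_update = (1 - ema_momentum) * p_k + ema_momentum * p_q
assert torch.allclose(old_update, new_update)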
@@ -113,7 +99,7 @@ class DenseCL(BaseModel):
         self.queue_ptr[0] = ptr
 
     @torch.no_grad()
-    def _dequeue_and_enqueue2(self, keys):
+    def _dequeue_and_enqueue2(self, keys: torch.Tensor) -> None:
         """Update queue2."""
         # gather keys before updating queue
         keys = concat_all_gather(keys)
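Both `_dequeue_and_enqueue` methods implement the MoCo-style ring buffer: new keys overwrite the oldest entries at a pointer that wraps around `queue_len`. A single-process sketch of that logic (the real methods first run `concat_all_gather(keys)` so every rank enqueues the full global batch):

import torch

feat_dim, queue_len = 4, 16
queue = torch.randn(feat_dim, queue_len)
queue_ptr = torch.zeros(1, dtype=torch.long)


def dequeue_and_enqueue(keys: torch.Tensor) -> None:
    batch_size = keys.shape[0]
    assert queue_len % batch_size == 0  # simplifies the pointer arithmetic
    ptr = int(queue_ptr)
    # overwrite the oldest keys at the pointer (queue is stored as C x K)
    queue[:, ptr:ptr + batch_size] = keys.T
    queue_ptr[0] = (ptr + batch_size) % queue_len


dequeue_and_enqueue(torch.randn(8, feat_dim))
assert queue_ptr.item() == 8

This is also why the test at the bottom of the diff expects `alg.queue_ptr.item() == 2` after a single training step with a batch of two images.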
@@ -129,41 +115,40 @@ class DenseCL(BaseModel):
 
         self.queue2_ptr[0] = ptr
 
-    def extract_feat(self, inputs: List[torch.Tensor],
-                     data_samples: List[SelfSupDataSample],
+    def extract_feat(self, batch_inputs: List[torch.Tensor],
                      **kwargs) -> Tuple[torch.Tensor]:
         """Function to extract features from backbone.
 
         Args:
-            inputs (List[torch.Tensor]): The input images.
+            batch_inputs (List[torch.Tensor]): The input images.
-            data_samples (List[SelfSupDataSample]): All elements required
-                during the forward function.
 
         Returns:
             Tuple[torch.Tensor]: backbone outputs.
         """
-        x = self.backbone(inputs[0])
+        x = self.backbone(batch_inputs[0])
         return x
 
-    def forward_train(self, inputs: List[torch.Tensor],
+    def loss(self, batch_inputs: List[torch.Tensor],
             data_samples: List[SelfSupDataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
         """Forward computation during training.
 
         Args:
-            inputs (List[torch.Tensor]): The input images.
+            batch_inputs (List[torch.Tensor]): The input images.
             data_samples (List[SelfSupDataSample]): All elements required
                 during the forward function.
 
         Returns:
             Dict[str, torch.Tensor]: A dictionary of loss components.
         """
-        assert isinstance(inputs, list)
-        im_q = inputs[0]
-        im_k = inputs[1]
+        assert isinstance(batch_inputs, list)
+        im_q = batch_inputs[0]
+        im_k = batch_inputs[1]
         # compute query features
-        q_b = self.encoder_q[0](im_q)  # backbone features
-        q, q_grid, q2 = self.encoder_q[1](q_b)  # queries: NxC; NxCxS^2
+        q_b = self.backbone(im_q)  # backbone features
+        q, q_grid, q2 = self.neck(q_b)  # queries: NxC; NxCxS^2
         q_b = q_b[0]
         q_b = q_b.view(q_b.size(0), q_b.size(1), -1)
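`forward_train` and `forward_test` are not simply renamed: in the 1.x design the base model exposes a single `forward(batch_inputs, data_samples, mode=...)` entry point that dispatches to `loss()`, `predict()`, or the raw-feature path. A schematic of that contract with stubbed-out methods (an assumed shape, not the actual `BaseModel` source):

from typing import List, Optional

import torch


class ModeDispatchSketch:
    """Stub showing how one forward() fans out to tensor/loss/predict."""

    def extract_feat(self, batch_inputs: List[torch.Tensor]):
        return (batch_inputs[0], )                  # stand-in for the backbone

    def loss(self, batch_inputs, data_samples) -> dict:
        return {'loss_single': torch.tensor(0.5)}   # stand-in loss dict

    def predict(self, batch_inputs, data_samples):
        return data_samples                         # stand-in predictions

    def forward(self,
                batch_inputs: List[torch.Tensor],
                data_samples: Optional[list] = None,
                mode: str = 'tensor'):
        if mode == 'tensor':
            return self.extract_feat(batch_inputs)
        elif mode == 'loss':
            return self.loss(batch_inputs, data_samples)
        elif mode == 'predict':
            return self.predict(batch_inputs, data_samples)
        raise RuntimeError(f'Invalid mode "{mode}".')

The rewritten tests at the end of the diff exercise exactly these three modes.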
@@ -175,13 +160,14 @@ class DenseCL(BaseModel):
         # compute key features
         with torch.no_grad():  # no gradient to keys
             # update the key encoder
-            self._momentum_update_key_encoder()
+            self.encoder_k.update_parameters(
+                nn.Sequential(self.backbone, self.neck))
 
             # shuffle for making use of BN
             im_k, idx_unshuffle = batch_shuffle_ddp(im_k)
 
-            k_b = self.encoder_k[0](im_k)  # backbone features
-            k, k_grid, k2 = self.encoder_k[1](k_b)  # keys: NxC; NxCxS^2
+            k_b = self.encoder_k.module[0](im_k)  # backbone features
+            k, k_grid, k2 = self.encoder_k.module[1](k_b)  # keys: NxC; NxCxS^2
             k_b = k_b[0]
             k_b = k_b.view(k_b.size(0), k_b.size(1), -1)
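`batch_shuffle_ddp` survives the refactor for the same reason it exists in MoCo: without shuffling, each GPU's BatchNorm statistics would let the key encoder match queries and keys through per-device batch statistics rather than through the features themselves. A single-process imitation of the shuffle/unshuffle contract (the real utilities permute the batch across all GPUs):

import torch


def fake_batch_shuffle(x: torch.Tensor):
    idx_shuffle = torch.randperm(x.size(0))
    idx_unshuffle = torch.argsort(idx_shuffle)  # inverse permutation
    return x[idx_shuffle], idx_unshuffle


def fake_batch_unshuffle(x: torch.Tensor, idx_unshuffle: torch.Tensor):
    return x[idx_unshuffle]


im_k = torch.randn(8, 3, 4, 4)
shuffled, idx_unshuffle = fake_batch_shuffle(im_k)
assert torch.equal(fake_batch_unshuffle(shuffled, idx_unshuffle), im_k)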
@@ -221,10 +207,8 @@ class DenseCL(BaseModel):
         l_neg_dense = torch.einsum(
             'nc,ck->nk', [q_grid, self.queue2.clone().detach()])
 
-        logits, labels = self.head(l_pos, l_neg)
-        logits_dense, labels_dense = self.head(l_pos_dense, l_neg_dense)
-        loss_single = self.loss(logits, labels)
-        loss_dense = self.loss(logits_dense, labels_dense)
+        loss_single = self.head(l_pos, l_neg)
+        loss_dense = self.head(l_pos_dense, l_neg_dense)
 
         losses = dict()
         losses['loss_single'] = loss_single * (1 - self.loss_lambda)
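With the head returning losses directly, the objective reduces to a convex mix of the global (MoCo-style) term and the dense term: `L = (1 - loss_lambda) * L_single + loss_lambda * L_dense`. A toy check of the weighting (values invented for illustration):

loss_lambda = 0.5                   # default from the config hunk above
loss_single, loss_dense = 2.0, 3.0  # invented values
total = (1 - loss_lambda) * loss_single + loss_lambda * loss_dense
assert total == 2.5                 # equal mix at the default lambda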
@@ -235,16 +219,17 @@ class DenseCL(BaseModel):
 
         return losses
 
-    def forward_test(self, inputs: List[torch.Tensor],
+    def predict(self, batch_inputs: List[torch.Tensor],
                 data_samples: List[SelfSupDataSample],
-                **kwargs) -> object:
-        """The forward function in testing
+                **kwargs) -> SelfSupDataSample:
+        """Predict results from the extracted features.
 
         Args:
             inputs (List[torch.Tensor]): The input images.
             data_samples (List[SelfSupDataSample]): All elements required
                 during the forward function.
         """
-        q_grid = self.extract_feat(inputs, data_samples)[0]
+        q_grid = self.extract_feat(batch_inputs)[0]
         q_grid = q_grid.view(q_grid.size(0), q_grid.size(1), -1)
         q_grid = nn.functional.normalize(q_grid, dim=1)
@@ -1,11 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+import torch
 import torch.nn as nn
-from mmcv.runner import BaseModule
+from mmengine.model import BaseModule
 
-from ..builder import NECKS
+from mmselfsup.registry import MODELS
 
 
-@NECKS.register_module()
+@MODELS.register_module()
 class DenseCLNeck(BaseModule):
     """The non-linear neck of DenseCL.
 
@@ -22,11 +25,11 @@ class DenseCLNeck(BaseModule):
     """
 
     def __init__(self,
-                 in_channels,
-                 hid_channels,
-                 out_channels,
-                 num_grid=None,
-                 init_cfg=None):
+                 in_channels: int,
+                 hid_channels: int,
+                 out_channels: int,
+                 num_grid: Optional[int] = None,
+                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
         super(DenseCLNeck, self).__init__(init_cfg)
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.mlp = nn.Sequential(
@@ -41,7 +44,7 @@ class DenseCLNeck(BaseModule):
             nn.Conv2d(hid_channels, out_channels, 1))
         self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))
 
-    def forward(self, x):
+    def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
         """Forward function of neck.
 
         Args:
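For orientation, the comment `queries: NxC; NxCxS^2` in the algorithm refers to the three outputs of this neck. A shape-level sketch, assuming the neck follows the DenseCL paper's design of a global avgpool+MLP branch and a dense 1x1-conv MLP branch (names and sizes here are illustrative):

import torch
import torch.nn as nn

in_c, hid_c, out_c = 8, 4, 2
avgpool = nn.AdaptiveAvgPool2d((1, 1))
mlp = nn.Sequential(nn.Linear(in_c, hid_c), nn.ReLU(inplace=True),
                    nn.Linear(hid_c, out_c))
mlp2 = nn.Sequential(nn.Conv2d(in_c, hid_c, 1), nn.ReLU(inplace=True),
                     nn.Conv2d(hid_c, out_c, 1))
avgpool2 = nn.AdaptiveAvgPool2d((1, 1))

x = torch.randn(2, in_c, 7, 7)      # backbone feature map
q = mlp(avgpool(x).flatten(1))      # N x C     (global query)
dense = mlp2(x)                     # N x C x 7 x 7
q2 = avgpool2(dense).flatten(1)     # N x C     (pooled dense query)
q_grid = dense.flatten(2)           # N x C x S^2 (dense queries)
print(q.shape, q_grid.shape, q2.shape)

The final test assertion `[2, 512, 49]` below is this same `N x C x S^2` layout, with a 7x7 feature map flattened to 49 locations.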
@@ -27,8 +27,10 @@ neck = dict(
     hid_channels=2,
     out_channels=2,
     num_grid=None)
-head = dict(type='ContrastiveHead', temperature=0.2)
-loss = dict(type='mmcls.CrossEntropyLoss')
+head = dict(
+    type='ContrastiveHead',
+    loss=dict(type='mmcls.CrossEntropyLoss'),
+    temperature=0.2)
 
 
 def mock_batch_shuffle_ddp(img):
@@ -45,43 +47,21 @@ def mock_concat_all_gather(img):
 
 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
 def test_densecl():
-    preprocess_cfg = {
-        'mean': [0.5, 0.5, 0.5],
-        'std': [0.5, 0.5, 0.5],
-        'to_rgb': True
+    data_preprocessor = {
+        'mean': (123.675, 116.28, 103.53),
+        'std': (58.395, 57.12, 57.375),
+        'bgr_to_rgb': True
     }
-    with pytest.raises(AssertionError):
-        alg = DenseCL(
-            backbone=backbone,
-            neck=None,
-            head=head,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = DenseCL(
-            backbone=backbone,
-            neck=neck,
-            head=None,
-            loss=loss,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-    with pytest.raises(AssertionError):
-        alg = DenseCL(
-            backbone=backbone,
-            neck=neck,
-            head=head,
-            loss=None,
-            preprocess_cfg=copy.deepcopy(preprocess_cfg))
-
     alg = DenseCL(
         backbone=backbone,
         neck=neck,
         head=head,
-        loss=loss,
         queue_len=queue_len,
         feat_dim=feat_dim,
         momentum=momentum,
         loss_lambda=loss_lambda,
-        preprocess_cfg=copy.deepcopy(preprocess_cfg))
+        data_preprocessor=copy.deepcopy(data_preprocessor))
 
     assert alg.queue.size() == torch.Size([feat_dim, queue_len])
     assert alg.queue2.size() == torch.Size([feat_dim, queue_len])
@@ -100,21 +80,19 @@ def test_densecl():
             SelfSupDataSample(),
     } for _ in range(2)]
 
-    fake_outputs = alg(fake_data, return_loss=True)
-    assert isinstance(fake_outputs['loss'].item(), float)
-    assert isinstance(fake_outputs['log_vars']['loss_single'], float)
-    assert isinstance(fake_outputs['log_vars']['loss_dense'], float)
-    assert fake_outputs['log_vars']['loss_single'] > 0
-    assert fake_outputs['log_vars']['loss_dense'] > 0
+    fake_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
+    fake_loss = alg(fake_inputs, fake_data_samples, mode='loss')
+    assert isinstance(fake_loss['loss_single'].item(), float)
+    assert isinstance(fake_loss['loss_dense'].item(), float)
+    assert fake_loss['loss_single'].item() > 0
+    assert fake_loss['loss_dense'].item() > 0
     assert alg.queue_ptr.item() == 2
     assert alg.queue2_ptr.item() == 2
 
-    fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
-    fake_feat = alg.extract_feat(
-        inputs=fake_inputs, data_samples=fake_data_samples)
+    fake_feat = alg(fake_inputs, fake_data_samples, mode='tensor')
     assert list(fake_feat[0].shape) == [2, 512, 7, 7]
 
-    fake_outputs = alg(fake_data, return_loss=False)
+    fake_outputs = alg(fake_inputs, fake_data_samples, mode='predict')
     assert 'q_grid' in fake_outputs
     assert 'value' in fake_outputs.q_grid
     assert list(fake_outputs.q_grid.value.shape) == [2, 512, 49]