[Refactor] refactor moco and mocov3
parent
dfbe3f6235
commit
e87be11a98
|
@ -1,7 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from mmselfsup.utils import (batch_shuffle_ddp, batch_unshuffle_ddp,
|
||||
concat_all_gather)
|
||||
from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
|
||||
|
@ -18,28 +21,33 @@ class MoCo(BaseModel):
|
|||
`<https://github.com/facebookresearch/moco/blob/master/moco/builder.py>`_.
|
||||
|
||||
Args:
|
||||
backbone (dict): Config dict for module of backbone.
|
||||
neck (dict): Config dict for module of deep features to compact
|
||||
feature vectors. Defaults to None.
|
||||
head (dict): Config dict for module of loss functions.
|
||||
backbone (Dict): Config dict for module of backbone.
|
||||
neck (Dict): Config dict for module of deep features to compact feature
|
||||
vectors.
|
||||
head (Dict): Config dict for module of loss functions.
|
||||
queue_len (int, optional): Number of negative keys maintained in the
|
||||
queue. Defaults to 65536.
|
||||
feat_dim (int, optional): Dimension of compact feature vectors.
|
||||
Defaults to 128.
|
||||
momentum (float, optional): Momentum coefficient for the
|
||||
momentum-updated encoder. Defaults to 0.999.
|
||||
preprocess_cfg (Dict, optional): Config to preprocess images.
|
||||
Defaults to None.
|
||||
queue_len (int): Number of negative keys maintained in the queue.
|
||||
Defaults to 65536.
|
||||
feat_dim (int): Dimension of compact feature vectors. Defaults to 128.
|
||||
momentum (float): Momentum coefficient for the momentum-updated
|
||||
encoder. Defaults to 0.999.
|
||||
init_cfg (Dict or list[Dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone,
|
||||
neck=None,
|
||||
head=None,
|
||||
queue_len=65536,
|
||||
feat_dim=128,
|
||||
momentum=0.999,
|
||||
init_cfg=None,
|
||||
**kwargs):
|
||||
super(MoCo, self).__init__(init_cfg)
|
||||
backbone: Dict,
|
||||
neck: Dict,
|
||||
head: Dict,
|
||||
queue_len: Optional[int] = 65536,
|
||||
feat_dim: Optional[int] = 128,
|
||||
momentum: Optional[float] = 0.999,
|
||||
preprocess_cfg: Optional[Dict] = None,
|
||||
init_cfg: Optional[Union[List[Dict], Dict]] = None) -> None:
|
||||
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
|
||||
assert backbone is not None
|
||||
assert neck is not None
|
||||
self.encoder_q = nn.Sequential(
|
||||
build_backbone(backbone), build_neck(neck))
|
||||
|
@ -65,7 +73,7 @@ class MoCo(BaseModel):
|
|||
self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
|
||||
|
||||
@torch.no_grad()
|
||||
def _momentum_update_key_encoder(self):
|
||||
def _momentum_update_key_encoder(self) -> None:
|
||||
"""Momentum update of the key encoder."""
|
||||
for param_q, param_k in zip(self.encoder_q.parameters(),
|
||||
self.encoder_k.parameters()):
|
||||
|
@ -73,7 +81,7 @@ class MoCo(BaseModel):
|
|||
param_q.data * (1. - self.momentum)
|
||||
|
||||
@torch.no_grad()
|
||||
def _dequeue_and_enqueue(self, keys):
|
||||
def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None:
|
||||
"""Update queue."""
|
||||
# gather keys before updating queue
|
||||
keys = concat_all_gather(keys)
|
||||
|
@ -89,33 +97,37 @@ class MoCo(BaseModel):
|
|||
|
||||
self.queue_ptr[0] = ptr
|
||||
|
||||
def extract_feat(self, img):
|
||||
def extract_feat(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwarg) -> Tuple[torch.Tensor]:
|
||||
"""Function to extract features from backbone.
|
||||
|
||||
Args:
|
||||
img (Tensor): Input images of shape (N, C, H, W).
|
||||
Typically these should be mean centered and std scaled.
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor]: backbone outputs.
|
||||
Tuple[torch.Tensor]: backbone outputs.
|
||||
"""
|
||||
x = self.backbone(img)
|
||||
x = self.backbone(inputs[0])
|
||||
return x
|
||||
|
||||
def forward_train(self, img, **kwargs):
|
||||
"""Forward computation during training.
|
||||
def forward_train(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
"""The forward function in training.
|
||||
|
||||
Args:
|
||||
img (list[Tensor]): A list of input images with shape
|
||||
(N, C, H, W). Typically these should be mean centered
|
||||
and std scaled.
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: A dictionary of loss components.
|
||||
Dict[str, torch.Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
assert isinstance(img, list)
|
||||
im_q = img[0]
|
||||
im_k = img[1]
|
||||
im_q = inputs[0]
|
||||
im_k = inputs[1]
|
||||
# compute query features
|
||||
q = self.encoder_q(im_q)[0] # queries: NxC
|
||||
q = nn.functional.normalize(q, dim=1)
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
|
||||
from .base import BaseModel
|
||||
|
||||
|
@ -14,25 +17,27 @@ class MoCoV3(BaseModel):
|
|||
Transformers <https://arxiv.org/abs/2104.02057>`_.
|
||||
|
||||
Args:
|
||||
backbone (dict): Config dict for module of backbone.
|
||||
neck (dict): Config dict for module of deep features to compact
|
||||
feature vectors. Defaults to None.
|
||||
head (dict): Config dict for module of loss functions.
|
||||
backbone (Dict): Config dict for module of backbone.
|
||||
neck (Dict): Config dict for module of deep features to compact feature
|
||||
vectors.
|
||||
head (Dict): Config dict for module of loss functions.
|
||||
base_momentum (float, optional): Momentum coefficient for the
|
||||
momentum-updated encoder. Defaults to 0.99.
|
||||
preprocess_cfg (Dict, optional): Config to preprocess images.
|
||||
Defaults to None.
|
||||
base_momentum (float): Momentum coefficient for the momentum-updated
|
||||
encoder. Defaults to 0.99.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
init_cfg (Dict or list[Dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone,
|
||||
neck,
|
||||
head,
|
||||
base_momentum=0.99,
|
||||
init_cfg=None,
|
||||
**kwargs):
|
||||
super(MoCoV3, self).__init__(init_cfg)
|
||||
backbone: Dict,
|
||||
neck: Dict,
|
||||
head: Dict,
|
||||
base_momentum: Optional[float] = 0.99,
|
||||
preprocess_cfg: Optional[Dict] = None,
|
||||
init_cfg: Optional[Union[List[Dict], Dict]] = None) -> None:
|
||||
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
|
||||
assert backbone is not None
|
||||
assert neck is not None
|
||||
self.base_encoder = nn.Sequential(
|
||||
build_backbone(backbone), build_neck(neck))
|
||||
|
@ -46,9 +51,9 @@ class MoCoV3(BaseModel):
|
|||
self.base_momentum = base_momentum
|
||||
self.momentum = base_momentum
|
||||
|
||||
def init_weights(self):
|
||||
def init_weights(self) -> None:
|
||||
"""Initialize base_encoder with init_cfg defined in backbone."""
|
||||
super(MoCoV3, self).init_weights()
|
||||
super().init_weights()
|
||||
|
||||
for param_b, param_m in zip(self.base_encoder.parameters(),
|
||||
self.momentum_encoder.parameters()):
|
||||
|
@ -56,39 +61,44 @@ class MoCoV3(BaseModel):
|
|||
param_m.requires_grad = False
|
||||
|
||||
@torch.no_grad()
|
||||
def momentum_update(self):
|
||||
def momentum_update(self) -> None:
|
||||
"""Momentum update of the momentum encoder."""
|
||||
for param_b, param_m in zip(self.base_encoder.parameters(),
|
||||
self.momentum_encoder.parameters()):
|
||||
param_m.data = param_m.data * self.momentum + param_b.data * (
|
||||
1. - self.momentum)
|
||||
|
||||
def extract_feat(self, img):
|
||||
def extract_feat(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwarg) -> Tuple[torch.Tensor]:
|
||||
"""Function to extract features from backbone.
|
||||
|
||||
Args:
|
||||
img (Tensor): Input images. Typically these should be mean centered
|
||||
and std scaled.
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor]: backbone outputs.
|
||||
Tuple[torch.Tensor]: backbone outputs.
|
||||
"""
|
||||
x = self.backbone(img)
|
||||
x = self.backbone(inputs[0])
|
||||
return x
|
||||
|
||||
def forward_train(self, img, **kwargs):
|
||||
"""Forward computation during training.
|
||||
def forward_train(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
"""The forward function in training.
|
||||
|
||||
Args:
|
||||
img (list[Tensor]): A list of input images. Typically these should
|
||||
be mean centered and std scaled.
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: A dictionary of loss components.
|
||||
Dict[str, torch.Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
assert isinstance(img, list)
|
||||
view_1 = img[0].cuda(non_blocking=True)
|
||||
view_2 = img[1].cuda(non_blocking=True)
|
||||
view_1 = inputs[0]
|
||||
view_2 = inputs[1]
|
||||
|
||||
# compute query features, [N, C] each
|
||||
q1 = self.base_encoder(view_1)[0]
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import platform
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import mmselfsup
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from mmselfsup.models.algorithms import MoCo
|
||||
|
||||
queue_len = 32
|
||||
feat_dim = 2
|
||||
momentum = 0.999
|
||||
backbone = dict(
|
||||
type='ResNet',
|
||||
depth=18,
|
||||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'))
|
||||
neck = dict(
|
||||
type='MoCoV2Neck',
|
||||
in_channels=512,
|
||||
hid_channels=2,
|
||||
out_channels=2,
|
||||
with_avg_pool=True)
|
||||
head = dict(type='ContrastiveHead', temperature=0.2)
|
||||
|
||||
|
||||
def mock_batch_shuffle_ddp(img):
|
||||
return img, 0
|
||||
|
||||
|
||||
def mock_batch_unshuffle_ddp(img, mock_input):
|
||||
return img
|
||||
|
||||
|
||||
def mock_concat_all_gather(img):
|
||||
return img
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_moco():
|
||||
preprocess_cfg = {
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'to_rgb': True
|
||||
}
|
||||
with pytest.raises(AssertionError):
|
||||
alg = MoCo(
|
||||
backbone=None,
|
||||
neck=neck,
|
||||
head=head,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
alg = MoCo(
|
||||
backbone=backbone,
|
||||
neck=None,
|
||||
head=head,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
alg = MoCo(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=None,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
|
||||
alg = MoCo(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
queue_len=queue_len,
|
||||
feat_dim=feat_dim,
|
||||
momentum=momentum,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
assert alg.queue.size() == torch.Size([feat_dim, queue_len])
|
||||
|
||||
fake_data = [{
|
||||
'inputs': [torch.randn((3, 224, 224)),
|
||||
torch.randn((3, 224, 224))],
|
||||
'data_sample':
|
||||
SelfSupDataSample()
|
||||
} for _ in range(2)]
|
||||
|
||||
mmselfsup.models.algorithms.moco.batch_shuffle_ddp = MagicMock(
|
||||
side_effect=mock_batch_shuffle_ddp)
|
||||
mmselfsup.models.algorithms.moco.batch_unshuffle_ddp = MagicMock(
|
||||
side_effect=mock_batch_unshuffle_ddp)
|
||||
mmselfsup.models.algorithms.moco.concat_all_gather = MagicMock(
|
||||
side_effect=mock_concat_all_gather)
|
||||
fake_loss = alg(fake_data, return_loss=True)
|
||||
assert fake_loss['loss'] > 0
|
||||
assert alg.queue_ptr.item() == 2
|
||||
|
||||
# test extract
|
||||
fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
|
||||
fake_backbone_out = alg.extract_feat(
|
||||
inputs=fake_inputs, data_samples=fake_data_samples)
|
||||
assert fake_backbone_out[0].size() == torch.Size([2, 512, 7, 7])
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from mmselfsup.models import MoCoV3
|
||||
|
||||
backbone = dict(
|
||||
type='VisionTransformer',
|
||||
arch='mocov3-small', # embed_dim = 384
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
stop_grad_conv1=True)
|
||||
neck = dict(
|
||||
type='NonLinearNeck',
|
||||
in_channels=384,
|
||||
hid_channels=2,
|
||||
out_channels=2,
|
||||
num_layers=2,
|
||||
with_bias=False,
|
||||
with_last_bn=True,
|
||||
with_last_bn_affine=False,
|
||||
with_last_bias=False,
|
||||
with_avg_pool=False,
|
||||
vit_backbone=True,
|
||||
norm_cfg=dict(type='BN1d'))
|
||||
head = dict(
|
||||
type='MoCoV3Head',
|
||||
predictor=dict(
|
||||
type='NonLinearNeck',
|
||||
in_channels=2,
|
||||
hid_channels=2,
|
||||
out_channels=2,
|
||||
num_layers=2,
|
||||
with_bias=False,
|
||||
with_last_bn=True,
|
||||
with_last_bn_affine=False,
|
||||
with_last_bias=False,
|
||||
with_avg_pool=False,
|
||||
norm_cfg=dict(type='BN1d')),
|
||||
temperature=0.2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_mocov3():
|
||||
preprocess_cfg = {
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'to_rgb': True
|
||||
}
|
||||
with pytest.raises(AssertionError):
|
||||
alg = MoCoV3(
|
||||
backbone=None,
|
||||
neck=neck,
|
||||
head=head,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
alg = MoCoV3(
|
||||
backbone=backbone,
|
||||
neck=None,
|
||||
head=head,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
alg = MoCoV3(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=None,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
|
||||
alg = MoCoV3(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
alg.init_weights()
|
||||
alg.momentum_update()
|
||||
|
||||
fake_data = [{
|
||||
'inputs': [torch.randn((3, 224, 224)),
|
||||
torch.randn((3, 224, 224))],
|
||||
'data_sample':
|
||||
SelfSupDataSample()
|
||||
} for _ in range(2)]
|
||||
|
||||
# test extract
|
||||
fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
|
||||
fake_backbone_out = alg.extract_feat(
|
||||
inputs=fake_inputs, data_samples=fake_data_samples)
|
||||
assert fake_backbone_out[0][0].size() == torch.Size([2, 384, 14, 14])
|
||||
assert fake_backbone_out[0][1].size() == torch.Size([2, 384])
|
Loading…
Reference in New Issue