[Refactor] refactor moco and mocov3

pull/352/head
fangyixiao.vendor 2022-05-17 02:27:17 +00:00 committed by fangyixiao18
parent dfbe3f6235
commit e87be11a98
4 changed files with 279 additions and 64 deletions

View File

@ -1,7 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmselfsup.core import SelfSupDataSample
from mmselfsup.utils import (batch_shuffle_ddp, batch_unshuffle_ddp,
concat_all_gather)
from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
@ -18,28 +21,33 @@ class MoCo(BaseModel):
`<https://github.com/facebookresearch/moco/blob/master/moco/builder.py>`_.
Args:
backbone (dict): Config dict for module of backbone.
neck (dict): Config dict for module of deep features to compact
feature vectors. Defaults to None.
head (dict): Config dict for module of loss functions.
backbone (Dict): Config dict for module of backbone.
neck (Dict): Config dict for module of deep features to compact feature
vectors.
head (Dict): Config dict for module of loss functions.
queue_len (int, optional): Number of negative keys maintained in the
queue. Defaults to 65536.
feat_dim (int, optional): Dimension of compact feature vectors.
Defaults to 128.
momentum (float, optional): Momentum coefficient for the
momentum-updated encoder. Defaults to 0.999.
preprocess_cfg (Dict, optional): Config to preprocess images.
Defaults to None.
queue_len (int): Number of negative keys maintained in the queue.
Defaults to 65536.
feat_dim (int): Dimension of compact feature vectors. Defaults to 128.
momentum (float): Momentum coefficient for the momentum-updated
encoder. Defaults to 0.999.
init_cfg (Dict or list[Dict], optional): Initialization config dict.
Defaults to None
"""
def __init__(self,
backbone,
neck=None,
head=None,
queue_len=65536,
feat_dim=128,
momentum=0.999,
init_cfg=None,
**kwargs):
super(MoCo, self).__init__(init_cfg)
backbone: Dict,
neck: Dict,
head: Dict,
queue_len: Optional[int] = 65536,
feat_dim: Optional[int] = 128,
momentum: Optional[float] = 0.999,
preprocess_cfg: Optional[Dict] = None,
init_cfg: Optional[Union[List[Dict], Dict]] = None) -> None:
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
assert backbone is not None
assert neck is not None
self.encoder_q = nn.Sequential(
build_backbone(backbone), build_neck(neck))
@ -65,7 +73,7 @@ class MoCo(BaseModel):
self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
@torch.no_grad()
def _momentum_update_key_encoder(self):
def _momentum_update_key_encoder(self) -> None:
"""Momentum update of the key encoder."""
for param_q, param_k in zip(self.encoder_q.parameters(),
self.encoder_k.parameters()):
@ -73,7 +81,7 @@ class MoCo(BaseModel):
param_q.data * (1. - self.momentum)
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None:
"""Update queue."""
# gather keys before updating queue
keys = concat_all_gather(keys)
@ -89,33 +97,37 @@ class MoCo(BaseModel):
self.queue_ptr[0] = ptr
def extract_feat(self, img):
def extract_feat(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwarg) -> Tuple[torch.Tensor]:
"""Function to extract features from backbone.
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
tuple[Tensor]: backbone outputs.
Tuple[torch.Tensor]: backbone outputs.
"""
x = self.backbone(img)
x = self.backbone(inputs[0])
return x
def forward_train(self, img, **kwargs):
"""Forward computation during training.
def forward_train(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
img (list[Tensor]): A list of input images with shape
(N, C, H, W). Typically these should be mean centered
and std scaled.
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
dict[str, Tensor]: A dictionary of loss components.
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
assert isinstance(img, list)
im_q = img[0]
im_k = img[1]
im_q = inputs[0]
im_k = inputs[1]
# compute query features
q = self.encoder_q(im_q)[0] # queries: NxC
q = nn.functional.normalize(q, dim=1)

View File

@ -1,7 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmselfsup.core import SelfSupDataSample
from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
from .base import BaseModel
@ -14,25 +17,27 @@ class MoCoV3(BaseModel):
Transformers <https://arxiv.org/abs/2104.02057>`_.
Args:
backbone (dict): Config dict for module of backbone.
neck (dict): Config dict for module of deep features to compact
feature vectors. Defaults to None.
head (dict): Config dict for module of loss functions.
backbone (Dict): Config dict for module of backbone
neck (Dict): Config dict for module of deep features to compact feature
vectors.
head (Dict): Config dict for module of loss functions.
base_momentum (float, , optional): Momentum coefficient for the
momentum-updated encoder. Defaults to 0.99.
preprocess_cfg (Dict, optional): Config to preprocess images.
Defaults to None.
base_momentum (float): Momentum coefficient for the momentum-updated
encoder. Defaults to 0.99.
init_cfg (dict or list[dict], optional): Initialization config dict.
init_cfg (Dict or list[Dict], optional): Initialization config dict.
Defaults to None
"""
def __init__(self,
backbone,
neck,
head,
base_momentum=0.99,
init_cfg=None,
**kwargs):
super(MoCoV3, self).__init__(init_cfg)
backbone: Dict,
neck: Dict,
head: Dict,
base_momentum: Optional[float] = 0.99,
preprocess_cfg: Optional[Dict] = None,
init_cfg: Optional[Union[List[Dict], Dict]] = None) -> None:
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
assert backbone is not None
assert neck is not None
self.base_encoder = nn.Sequential(
build_backbone(backbone), build_neck(neck))
@ -46,9 +51,9 @@ class MoCoV3(BaseModel):
self.base_momentum = base_momentum
self.momentum = base_momentum
def init_weights(self):
def init_weights(self) -> None:
"""Initialize base_encoder with init_cfg defined in backbone."""
super(MoCoV3, self).init_weights()
super().init_weights()
for param_b, param_m in zip(self.base_encoder.parameters(),
self.momentum_encoder.parameters()):
@ -56,39 +61,44 @@ class MoCoV3(BaseModel):
param_m.requires_grad = False
@torch.no_grad()
def momentum_update(self):
def momentum_update(self) -> None:
"""Momentum update of the momentum encoder."""
for param_b, param_m in zip(self.base_encoder.parameters(),
self.momentum_encoder.parameters()):
param_m.data = param_m.data * self.momentum + param_b.data * (
1. - self.momentum)
def extract_feat(self, img):
def extract_feat(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwarg) -> Tuple[torch.Tensor]:
"""Function to extract features from backbone.
Args:
img (Tensor): Input images. Typically these should be mean centered
and std scaled.
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
tuple[Tensor]: backbone outputs.
Tuple[torch.Tensor]: backbone outputs.
"""
x = self.backbone(img)
x = self.backbone(inputs[0])
return x
def forward_train(self, img, **kwargs):
"""Forward computation during training.
def forward_train(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
img (list[Tensor]): A list of input images. Typically these should
be mean centered and std scaled.
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
dict[str, Tensor]: A dictionary of loss components.
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
assert isinstance(img, list)
view_1 = img[0].cuda(non_blocking=True)
view_2 = img[1].cuda(non_blocking=True)
view_1 = inputs[0]
view_2 = inputs[1]
# compute query features, [N, C] each
q1 = self.base_encoder(view_1)[0]

View File

@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
from unittest.mock import MagicMock
import pytest
import torch
import mmselfsup
from mmselfsup.core import SelfSupDataSample
from mmselfsup.models.algorithms import MoCo
queue_len = 32
feat_dim = 2
momentum = 0.999
backbone = dict(
type='ResNet',
depth=18,
in_channels=3,
out_indices=[4], # 0: conv-1, x: stage-x
norm_cfg=dict(type='BN'))
neck = dict(
type='MoCoV2Neck',
in_channels=512,
hid_channels=2,
out_channels=2,
with_avg_pool=True)
head = dict(type='ContrastiveHead', temperature=0.2)
def mock_batch_shuffle_ddp(img):
return img, 0
def mock_batch_unshuffle_ddp(img, mock_input):
return img
def mock_concat_all_gather(img):
return img
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_moco():
preprocess_cfg = {
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'to_rgb': True
}
with pytest.raises(AssertionError):
alg = MoCo(
backbone=None,
neck=neck,
head=head,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = MoCo(
backbone=backbone,
neck=None,
head=head,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = MoCo(
backbone=backbone,
neck=neck,
head=None,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
alg = MoCo(
backbone=backbone,
neck=neck,
head=head,
queue_len=queue_len,
feat_dim=feat_dim,
momentum=momentum,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
assert alg.queue.size() == torch.Size([feat_dim, queue_len])
fake_data = [{
'inputs': [torch.randn((3, 224, 224)),
torch.randn((3, 224, 224))],
'data_sample':
SelfSupDataSample()
} for _ in range(2)]
mmselfsup.models.algorithms.moco.batch_shuffle_ddp = MagicMock(
side_effect=mock_batch_shuffle_ddp)
mmselfsup.models.algorithms.moco.batch_unshuffle_ddp = MagicMock(
side_effect=mock_batch_unshuffle_ddp)
mmselfsup.models.algorithms.moco.concat_all_gather = MagicMock(
side_effect=mock_concat_all_gather)
fake_loss = alg(fake_data, return_loss=True)
assert fake_loss['loss'] > 0
assert alg.queue_ptr.item() == 2
# test extract
fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
fake_backbone_out = alg.extract_feat(
inputs=fake_inputs, data_samples=fake_data_samples)
assert fake_backbone_out[0].size() == torch.Size([2, 512, 7, 7])

View File

@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
import pytest
import torch
from mmselfsup.core import SelfSupDataSample
from mmselfsup.models import MoCoV3
backbone = dict(
type='VisionTransformer',
arch='mocov3-small', # embed_dim = 384
img_size=224,
patch_size=16,
stop_grad_conv1=True)
neck = dict(
type='NonLinearNeck',
in_channels=384,
hid_channels=2,
out_channels=2,
num_layers=2,
with_bias=False,
with_last_bn=True,
with_last_bn_affine=False,
with_last_bias=False,
with_avg_pool=False,
vit_backbone=True,
norm_cfg=dict(type='BN1d'))
head = dict(
type='MoCoV3Head',
predictor=dict(
type='NonLinearNeck',
in_channels=2,
hid_channels=2,
out_channels=2,
num_layers=2,
with_bias=False,
with_last_bn=True,
with_last_bn_affine=False,
with_last_bias=False,
with_avg_pool=False,
norm_cfg=dict(type='BN1d')),
temperature=0.2)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_mocov3():
preprocess_cfg = {
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'to_rgb': True
}
with pytest.raises(AssertionError):
alg = MoCoV3(
backbone=None,
neck=neck,
head=head,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = MoCoV3(
backbone=backbone,
neck=None,
head=head,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = MoCoV3(
backbone=backbone,
neck=neck,
head=None,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
alg = MoCoV3(
backbone=backbone,
neck=neck,
head=head,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
alg.init_weights()
alg.momentum_update()
fake_data = [{
'inputs': [torch.randn((3, 224, 224)),
torch.randn((3, 224, 224))],
'data_sample':
SelfSupDataSample()
} for _ in range(2)]
# test extract
fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
fake_backbone_out = alg.extract_feat(
inputs=fake_inputs, data_samples=fake_data_samples)
assert fake_backbone_out[0][0].size() == torch.Size([2, 384, 14, 14])
assert fake_backbone_out[0][1].size() == torch.Size([2, 384])