[Refactor] refactor moco and mocov3

fangyixiao.vendor 2022-05-17 02:27:17 +00:00 committed by fangyixiao18
parent dfbe3f6235
commit e87be11a98
4 changed files with 279 additions and 64 deletions

View File

@@ -1,7 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple, Union
+
 import torch
 import torch.nn as nn
 
+from mmselfsup.core import SelfSupDataSample
 from mmselfsup.utils import (batch_shuffle_ddp, batch_unshuffle_ddp,
                              concat_all_gather)
 from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
@@ -18,28 +21,33 @@ class MoCo(BaseModel):
     `<https://github.com/facebookresearch/moco/blob/master/moco/builder.py>`_.
 
     Args:
-        backbone (dict): Config dict for module of backbone.
-        neck (dict): Config dict for module of deep features to compact
-            feature vectors. Defaults to None.
-        head (dict): Config dict for module of loss functions.
+        backbone (Dict): Config dict for module of backbone.
+        neck (Dict): Config dict for module of deep features to compact feature
+            vectors.
+        head (Dict): Config dict for module of loss functions.
+        queue_len (int, optional): Number of negative keys maintained in the
+            queue. Defaults to 65536.
+        feat_dim (int, optional): Dimension of compact feature vectors.
+            Defaults to 128.
+        momentum (float, optional): Momentum coefficient for the
+            momentum-updated encoder. Defaults to 0.999.
+        preprocess_cfg (Dict, optional): Config to preprocess images.
             Defaults to None.
-        queue_len (int): Number of negative keys maintained in the queue.
-            Defaults to 65536.
-        feat_dim (int): Dimension of compact feature vectors. Defaults to 128.
-        momentum (float): Momentum coefficient for the momentum-updated
-            encoder. Defaults to 0.999.
+        init_cfg (Dict or list[Dict], optional): Initialization config dict.
+            Defaults to None
     """
 
     def __init__(self,
-                 backbone,
-                 neck=None,
-                 head=None,
-                 queue_len=65536,
-                 feat_dim=128,
-                 momentum=0.999,
-                 init_cfg=None,
-                 **kwargs):
-        super(MoCo, self).__init__(init_cfg)
+                 backbone: Dict,
+                 neck: Dict,
+                 head: Dict,
+                 queue_len: Optional[int] = 65536,
+                 feat_dim: Optional[int] = 128,
+                 momentum: Optional[float] = 0.999,
+                 preprocess_cfg: Optional[Dict] = None,
+                 init_cfg: Optional[Union[List[Dict], Dict]] = None) -> None:
+        super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
+        assert backbone is not None
         assert neck is not None
         self.encoder_q = nn.Sequential(
             build_backbone(backbone), build_neck(neck))
@@ -65,7 +73,7 @@ class MoCo(BaseModel):
         self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
 
     @torch.no_grad()
-    def _momentum_update_key_encoder(self):
+    def _momentum_update_key_encoder(self) -> None:
         """Momentum update of the key encoder."""
         for param_q, param_k in zip(self.encoder_q.parameters(),
                                     self.encoder_k.parameters()):
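The loop shown here, whose update statement continues at the top of the next hunk, implements the standard MoCo exponential moving average, theta_k <- m * theta_k + (1 - m) * theta_q. A minimal self-contained sketch of the same rule; the two Linear modules are placeholders for this sketch, not part of the commit:

# EMA update: the key encoder slowly tracks the query encoder.
import torch
import torch.nn as nn

encoder_q = nn.Linear(8, 4)   # stands in for the query encoder
encoder_k = nn.Linear(8, 4)   # stands in for the key encoder
momentum = 0.999

with torch.no_grad():
    for param_q, param_k in zip(encoder_q.parameters(),
                                encoder_k.parameters()):
        param_k.data = param_k.data * momentum + \
            param_q.data * (1. - momentum)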
@@ -73,7 +81,7 @@ class MoCo(BaseModel):
             param_q.data * (1. - self.momentum)
 
     @torch.no_grad()
-    def _dequeue_and_enqueue(self, keys):
+    def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None:
         """Update queue."""
         # gather keys before updating queue
         keys = concat_all_gather(keys)
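The enqueue step is truncated between this hunk and the next: the queue is a ring buffer over columns. A minimal sketch of that pattern, assuming, as the original code asserts, that the batch size divides queue_len; all names are local to the sketch:

import torch

feat_dim, queue_len = 128, 65536
queue = torch.randn(feat_dim, queue_len)      # (C, K) feature bank
queue_ptr = torch.zeros(1, dtype=torch.long)  # write position

def dequeue_and_enqueue(keys: torch.Tensor) -> None:
    """Overwrite batch_size columns at the pointer, then advance it."""
    batch_size = keys.shape[0]
    ptr = int(queue_ptr)
    assert queue_len % batch_size == 0  # keeps the pointer arithmetic simple
    queue[:, ptr:ptr + batch_size] = keys.t()  # keys: (N, C) -> (C, N)
    queue_ptr[0] = (ptr + batch_size) % queue_len

dequeue_and_enqueue(torch.randn(32, feat_dim))
assert int(queue_ptr) == 32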
@@ -89,33 +97,37 @@ class MoCo(BaseModel):
         self.queue_ptr[0] = ptr
 
-    def extract_feat(self, img):
+    def extract_feat(self, inputs: List[torch.Tensor],
+                     data_samples: List[SelfSupDataSample],
+                     **kwarg) -> Tuple[torch.Tensor]:
         """Function to extract features from backbone.
 
         Args:
-            img (Tensor): Input images of shape (N, C, H, W).
-                Typically these should be mean centered and std scaled.
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[SelfSupDataSample]): All elements required
+                during the forward function.
 
         Returns:
-            tuple[Tensor]: backbone outputs.
+            Tuple[torch.Tensor]: backbone outputs.
         """
-        x = self.backbone(img)
+        x = self.backbone(inputs[0])
         return x
 
-    def forward_train(self, img, **kwargs):
-        """Forward computation during training.
+    def forward_train(self, inputs: List[torch.Tensor],
+                      data_samples: List[SelfSupDataSample],
+                      **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.
 
         Args:
-            img (list[Tensor]): A list of input images with shape
-                (N, C, H, W). Typically these should be mean centered
-                and std scaled.
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[SelfSupDataSample]): All elements required
+                during the forward function.
 
         Returns:
-            dict[str, Tensor]: A dictionary of loss components.
+            Dict[str, torch.Tensor]: A dictionary of loss components.
         """
-        assert isinstance(img, list)
-        im_q = img[0]
-        im_k = img[1]
+        im_q = inputs[0]
+        im_k = inputs[1]
 
         # compute query features
         q = self.encoder_q(im_q)[0]  # queries: NxC
         q = nn.functional.normalize(q, dim=1)
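The hunk ends before the contrastive step. In MoCo, the remainder of forward_train builds one positive logit per query from its matching key and negative logits against the queue, then hands both to the ContrastiveHead. A hedged sketch of those shapes with illustrative values, not the commit's exact code:

import torch

N, C, K = 4, 128, 32
q = torch.nn.functional.normalize(torch.randn(N, C), dim=1)  # queries
k = torch.nn.functional.normalize(torch.randn(N, C), dim=1)  # keys
queue = torch.randn(C, K)                                    # negatives

# positive logits: Nx1, each query dotted with its own key
l_pos = torch.einsum('nc,nc->n', q, k).unsqueeze(-1)
# negative logits: NxK, each query against every queued key
l_neg = torch.einsum('nc,ck->nk', q, queue.clone().detach())

logits = torch.cat([l_pos, l_neg], dim=1)  # Nx(1+K)
labels = torch.zeros(N, dtype=torch.long)  # the positive sits at index 0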

View File

@@ -1,7 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple, Union
+
 import torch
 import torch.nn as nn
 
+from mmselfsup.core import SelfSupDataSample
 from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
 from .base import BaseModel
 
@@ -14,25 +17,27 @@ class MoCoV3(BaseModel):
     Transformers <https://arxiv.org/abs/2104.02057>`_.
 
     Args:
-        backbone (dict): Config dict for module of backbone.
-        neck (dict): Config dict for module of deep features to compact
-            feature vectors. Defaults to None.
-        head (dict): Config dict for module of loss functions.
-            Defaults to None.
-        base_momentum (float): Momentum coefficient for the momentum-updated
-            encoder. Defaults to 0.99.
-        init_cfg (dict or list[dict], optional): Initialization config dict.
+        backbone (Dict): Config dict for module of backbone.
+        neck (Dict): Config dict for module of deep features to compact feature
+            vectors.
+        head (Dict): Config dict for module of loss functions.
+        base_momentum (float, optional): Momentum coefficient for the
+            momentum-updated encoder. Defaults to 0.99.
+        preprocess_cfg (Dict, optional): Config to preprocess images.
+            Defaults to None.
+        init_cfg (Dict or list[Dict], optional): Initialization config dict.
             Defaults to None
     """
 
     def __init__(self,
-                 backbone,
-                 neck,
-                 head,
-                 base_momentum=0.99,
-                 init_cfg=None,
-                 **kwargs):
-        super(MoCoV3, self).__init__(init_cfg)
+                 backbone: Dict,
+                 neck: Dict,
+                 head: Dict,
+                 base_momentum: Optional[float] = 0.99,
+                 preprocess_cfg: Optional[Dict] = None,
+                 init_cfg: Optional[Union[List[Dict], Dict]] = None) -> None:
+        super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
+        assert backbone is not None
         assert neck is not None
         self.base_encoder = nn.Sequential(
             build_backbone(backbone), build_neck(neck))
@@ -46,9 +51,9 @@ class MoCoV3(BaseModel):
         self.base_momentum = base_momentum
         self.momentum = base_momentum
 
-    def init_weights(self):
+    def init_weights(self) -> None:
         """Initialize base_encoder with init_cfg defined in backbone."""
-        super(MoCoV3, self).init_weights()
+        super().init_weights()
 
         for param_b, param_m in zip(self.base_encoder.parameters(),
                                     self.momentum_encoder.parameters()):
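The loop body is cut off between this hunk and the next, which resumes at param_m.requires_grad = False. The usual MoCo v3 initialization that this implies copies the base encoder's weights into the momentum encoder and freezes them; a minimal sketch with placeholder modules, not the commit's code:

import torch.nn as nn

base_encoder = nn.Linear(8, 4)      # placeholder for backbone + neck
momentum_encoder = nn.Linear(8, 4)  # same architecture, updated by EMA only

for param_b, param_m in zip(base_encoder.parameters(),
                            momentum_encoder.parameters()):
    param_m.data.copy_(param_b.data)  # start from identical weights
    param_m.requires_grad = False     # excluded from gradient updates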
@@ -56,39 +61,44 @@ class MoCoV3(BaseModel):
             param_m.requires_grad = False
 
     @torch.no_grad()
-    def momentum_update(self):
+    def momentum_update(self) -> None:
         """Momentum update of the momentum encoder."""
         for param_b, param_m in zip(self.base_encoder.parameters(),
                                     self.momentum_encoder.parameters()):
             param_m.data = param_m.data * self.momentum + param_b.data * (
                 1. - self.momentum)
 
-    def extract_feat(self, img):
+    def extract_feat(self, inputs: List[torch.Tensor],
+                     data_samples: List[SelfSupDataSample],
+                     **kwarg) -> Tuple[torch.Tensor]:
         """Function to extract features from backbone.
 
         Args:
-            img (Tensor): Input images. Typically these should be mean centered
-                and std scaled.
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[SelfSupDataSample]): All elements required
+                during the forward function.
 
         Returns:
-            tuple[Tensor]: backbone outputs.
+            Tuple[torch.Tensor]: backbone outputs.
         """
-        x = self.backbone(img)
+        x = self.backbone(inputs[0])
         return x
 
-    def forward_train(self, img, **kwargs):
-        """Forward computation during training.
+    def forward_train(self, inputs: List[torch.Tensor],
+                      data_samples: List[SelfSupDataSample],
+                      **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.
 
         Args:
-            img (list[Tensor]): A list of input images. Typically these should
-                be mean centered and std scaled.
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[SelfSupDataSample]): All elements required
+                during the forward function.
 
         Returns:
-            dict[str, Tensor]: A dictionary of loss components.
+            Dict[str, torch.Tensor]: A dictionary of loss components.
         """
-        assert isinstance(img, list)
-        view_1 = img[0].cuda(non_blocking=True)
-        view_2 = img[1].cuda(non_blocking=True)
+        view_1 = inputs[0]
+        view_2 = inputs[1]
 
         # compute query features, [N, C] each
         q1 = self.base_encoder(view_1)[0]
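Past the end of this hunk, forward_train follows the MoCo v3 recipe: each view passes through both encoders and the contrastive loss is symmetrized over the two (query, key) pairings. A hedged sketch of that flow, where contrastive_loss stands in for the MoCoV3Head computation and is not the commit's code:

import torch
import torch.nn.functional as F

def contrastive_loss(q: torch.Tensor, k: torch.Tensor,
                     temperature: float = 0.2) -> torch.Tensor:
    """InfoNCE over in-batch pairs; positives sit on the diagonal."""
    q = F.normalize(q, dim=1)
    k = F.normalize(k, dim=1)
    logits = q @ k.t() / temperature
    labels = torch.arange(q.shape[0])
    return F.cross_entropy(logits, labels)

q1, q2 = torch.randn(4, 256), torch.randn(4, 256)  # base-encoder outputs
k1, k2 = torch.randn(4, 256), torch.randn(4, 256)  # momentum-encoder outputs
loss = contrastive_loss(q1, k2) + contrastive_loss(q2, k1)  # symmetrized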

View File

@@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
from unittest.mock import MagicMock

import pytest
import torch

import mmselfsup
from mmselfsup.core import SelfSupDataSample
from mmselfsup.models.algorithms import MoCo

queue_len = 32
feat_dim = 2
momentum = 0.999
backbone = dict(
    type='ResNet',
    depth=18,
    in_channels=3,
    out_indices=[4],  # 0: conv-1, x: stage-x
    norm_cfg=dict(type='BN'))
neck = dict(
    type='MoCoV2Neck',
    in_channels=512,
    hid_channels=2,
    out_channels=2,
    with_avg_pool=True)
head = dict(type='ContrastiveHead', temperature=0.2)


def mock_batch_shuffle_ddp(img):
    return img, 0


def mock_batch_unshuffle_ddp(img, mock_input):
    return img


def mock_concat_all_gather(img):
    return img


@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_moco():
    preprocess_cfg = {
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225],
        'to_rgb': True
    }
    with pytest.raises(AssertionError):
        alg = MoCo(
            backbone=None,
            neck=neck,
            head=head,
            preprocess_cfg=copy.deepcopy(preprocess_cfg))
    with pytest.raises(AssertionError):
        alg = MoCo(
            backbone=backbone,
            neck=None,
            head=head,
            preprocess_cfg=copy.deepcopy(preprocess_cfg))
    with pytest.raises(AssertionError):
        alg = MoCo(
            backbone=backbone,
            neck=neck,
            head=None,
            preprocess_cfg=copy.deepcopy(preprocess_cfg))

    alg = MoCo(
        backbone=backbone,
        neck=neck,
        head=head,
        queue_len=queue_len,
        feat_dim=feat_dim,
        momentum=momentum,
        preprocess_cfg=copy.deepcopy(preprocess_cfg))
    assert alg.queue.size() == torch.Size([feat_dim, queue_len])

    fake_data = [{
        'inputs': [torch.randn((3, 224, 224)),
                   torch.randn((3, 224, 224))],
        'data_sample': SelfSupDataSample()
    } for _ in range(2)]

    mmselfsup.models.algorithms.moco.batch_shuffle_ddp = MagicMock(
        side_effect=mock_batch_shuffle_ddp)
    mmselfsup.models.algorithms.moco.batch_unshuffle_ddp = MagicMock(
        side_effect=mock_batch_unshuffle_ddp)
    mmselfsup.models.algorithms.moco.concat_all_gather = MagicMock(
        side_effect=mock_concat_all_gather)
    fake_loss = alg(fake_data, return_loss=True)
    assert fake_loss['loss'] > 0
    assert alg.queue_ptr.item() == 2

    # test extract
    fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
    fake_backbone_out = alg.extract_feat(
        inputs=fake_inputs, data_samples=fake_data_samples)
    assert fake_backbone_out[0].size() == torch.Size([2, 512, 7, 7])

View File

@@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform

import pytest
import torch

from mmselfsup.core import SelfSupDataSample
from mmselfsup.models import MoCoV3

backbone = dict(
    type='VisionTransformer',
    arch='mocov3-small',  # embed_dim = 384
    img_size=224,
    patch_size=16,
    stop_grad_conv1=True)
neck = dict(
    type='NonLinearNeck',
    in_channels=384,
    hid_channels=2,
    out_channels=2,
    num_layers=2,
    with_bias=False,
    with_last_bn=True,
    with_last_bn_affine=False,
    with_last_bias=False,
    with_avg_pool=False,
    vit_backbone=True,
    norm_cfg=dict(type='BN1d'))
head = dict(
    type='MoCoV3Head',
    predictor=dict(
        type='NonLinearNeck',
        in_channels=2,
        hid_channels=2,
        out_channels=2,
        num_layers=2,
        with_bias=False,
        with_last_bn=True,
        with_last_bn_affine=False,
        with_last_bias=False,
        with_avg_pool=False,
        norm_cfg=dict(type='BN1d')),
    temperature=0.2)


@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_mocov3():
    preprocess_cfg = {
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225],
        'to_rgb': True
    }
    with pytest.raises(AssertionError):
        alg = MoCoV3(
            backbone=None,
            neck=neck,
            head=head,
            preprocess_cfg=copy.deepcopy(preprocess_cfg))
    with pytest.raises(AssertionError):
        alg = MoCoV3(
            backbone=backbone,
            neck=None,
            head=head,
            preprocess_cfg=copy.deepcopy(preprocess_cfg))
    with pytest.raises(AssertionError):
        alg = MoCoV3(
            backbone=backbone,
            neck=neck,
            head=None,
            preprocess_cfg=copy.deepcopy(preprocess_cfg))

    alg = MoCoV3(
        backbone=backbone,
        neck=neck,
        head=head,
        preprocess_cfg=copy.deepcopy(preprocess_cfg))
    alg.init_weights()
    alg.momentum_update()

    fake_data = [{
        'inputs': [torch.randn((3, 224, 224)),
                   torch.randn((3, 224, 224))],
        'data_sample': SelfSupDataSample()
    } for _ in range(2)]

    # test extract
    fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
    fake_backbone_out = alg.extract_feat(
        inputs=fake_inputs, data_samples=fake_data_samples)
    assert fake_backbone_out[0][0].size() == torch.Size([2, 384, 14, 14])
    assert fake_backbone_out[0][1].size() == torch.Size([2, 384])