[Refactor]: refactor npid algorithm

pull/352/head
renqin 2022-07-12 07:23:55 +00:00 committed by fangyixiao18
parent 3e5d18ea83
commit 4c5c4b88f4
9 changed files with 109 additions and 127 deletions

View File

@ -1,12 +1,18 @@
# dataset settings
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
dataset_type = 'mmcls.ImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='RandomResizedCrop', size=224, scale=(0.2, 1.)),
dict(type='RandomGrayscale', prob=0.2, keep_channels=True),
dict(
type='RandomResizedCrop', size=224, scale=(0.2, 1.), backend='pillow'),
dict(
type='RandomGrayscale',
prob=0.2,
keep_channels=True,
channel_weights=(0.114, 0.587, 0.2989)),
dict(
type='ColorJitter',
brightness=0.4,
@ -22,7 +28,8 @@ train_pipeline = [
train_dataloader = dict(
batch_size=32,
num_workers=4,
num_workers=8,
drop_last=True,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(

View File

@ -2,6 +2,10 @@
model = dict(
type='NPID',
neg_num=65536,
data_preprocessor=dict(
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
bgr_to_rgb=True),
backbone=dict(
type='ResNet',
depth=50,
@ -13,7 +17,9 @@ model = dict(
in_channels=2048,
out_channels=128,
with_avg_pool=True),
head=dict(type='ContrastiveHead', temperature=0.07),
loss=dict(type='mmcls.CrossEntropyLoss'),
head=dict(
type='ContrastiveHead',
loss=dict(type='mmcls.CrossEntropyLoss'),
temperature=0.07),
memory_bank=dict(
type='SimpleMemory', length=1281167, feat_dim=128, momentum=0.5))

View File

@ -6,7 +6,7 @@ _base_ = [
]
# runtime settings
# max_keep_ckpts controls the maximum number of ckpt files kept in your
# work_dirs: if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt,
# it removes the oldest one so that the total number of ckpts stays at 3
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=50),
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))

View File

@ -5,12 +5,11 @@ import torch
import torch.nn as nn
from mmselfsup.core import SelfSupDataSample
from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
build_memory, build_neck)
from mmselfsup.registry import MODELS
from .base import BaseModel
@ALGORITHMS.register_module()
@MODELS.register_module()
class NPID(BaseModel):
"""NPID.
@ -18,80 +17,79 @@ class NPID(BaseModel):
Instance Discrimination <https://arxiv.org/abs/1805.01978>`_.
Args:
backbone (Dict): Config dict for module of backbone.
neck (Dict, optional): Config dict for module of deep features to
compact feature vectors. Defaults to None.
head (Dict, optional): Config dict for module of head functions.
Defaults to None.
loss (dict): Config dict for module of loss functions.
Defaults to None.
memory_bank (Dict, optional): Config dict for module of memory banks.
Defaults to None.
backbone (dict): Config dict for module of backbone.
neck (dict): Config dict for module of deep features to
compact feature vectors.
head (dict): Config dict for module of head functions.
memory_bank (dict): Config dict for module of memory banks.
neg_num (int): Number of negative samples for each image.
Defaults to 65536.
ensure_neg (bool): If False, there is a small probability
that negative samples contain positive ones. Defaults to False.
preprocess_cfg (Dict, optional): Config to preprocess images.
pretrained (str, optional): The pretrained checkpoint path, support
local path and remote path. Defaults to None.
data_preprocessor (Union[dict, nn.Module], optional): The config for
preprocessing input data. If None, or if no type is specified, it will
use "SelfSupDataPreprocessor" as the type.
See :class:`SelfSupDataPreprocessor` for more details.
Defaults to None.
init_cfg (Dict or List[Dict], optional): Config dict for weight
initialization. Defaults to None.
"""
def __init__(self,
backbone: Dict,
neck: Optional[Dict] = None,
head: Optional[Dict] = None,
loss: Optional[Dict] = None,
memory_bank: Optional[Dict] = None,
neg_num: Optional[int] = 65536,
ensure_neg: Optional[bool] = False,
preprocess_cfg: Optional[Dict] = None,
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
assert head is not None
self.head = build_head(head)
assert loss is not None
self.loss = build_loss(loss)
backbone: dict,
neck: dict,
head: dict,
memory_bank: dict,
neg_num: int = 65536,
ensure_neg: bool = False,
pretrained: Optional[str] = None,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
super().__init__(
backbone=backbone,
neck=neck,
head=head,
pretrained=pretrained,
data_preprocessor=data_preprocessor,
init_cfg=init_cfg)
assert memory_bank is not None
self.memory_bank = build_memory(memory_bank)
self.memory_bank = MODELS.build(memory_bank)
self.neg_num = neg_num
self.ensure_neg = ensure_neg
def extract_feat(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
def extract_feat(self, batch_inputs: List[torch.Tensor],
**kwarg) -> Tuple[torch.Tensor]:
"""Function to extract features from backbone.
Args:
inputs (List[torch.Tensor]): The input images.
batch_inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
Tuple[torch.Tensor]: backbone outputs.
"""
x = self.backbone(inputs[0])
x = self.backbone(batch_inputs[0])
return x
def forward_train(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
def loss(self, batch_inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
inputs (List[torch.Tensor]): The input images.
batch_inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
Dict[str, Tensor]: A dictionary of loss components.
"""
feature = self.extract_feat(inputs[0])
idx = [data_sample.idx for data_sample in data_samples]
feature = self.backbone(batch_inputs[0])
idx = [data_sample.sample_idx.value for data_sample in data_samples]
idx = torch.cat(idx)
if self.with_neck:
feature = self.neck(feature)[0]
@ -119,8 +117,7 @@ class NPID(BaseModel):
[pos_feat, feature]).unsqueeze(-1)
neg_logits = torch.bmm(neg_feat, feature.unsqueeze(2)).squeeze(2)
logits, labels = self.head(pos_logits, neg_logits)
loss = self.loss(logits, labels)
loss = self.head(pos_logits, neg_logits)
losses = dict(loss=loss)
# update memory bank
with torch.no_grad():

View File

@ -1,6 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
from mmengine.model import BaseModule
@ -26,8 +24,7 @@ class ContrastiveHead(BaseModule):
self.loss = MODELS.build(loss)
self.temperature = temperature
def forward(self, pos: torch.Tensor,
neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, pos: torch.Tensor, neg: torch.Tensor) -> torch.Tensor:
"""Forward function to compute contrastive loss.
Args:

View File

@ -2,15 +2,15 @@
from typing import Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.runner import BaseModule, get_dist_info
from mmengine.dist import all_gather
from mmselfsup.registry import MODELS
from mmselfsup.utils import AliasMethod
from ..builder import MEMORIES
@MEMORIES.register_module()
@MODELS.register_module()
class SimpleMemory(BaseModule):
"""Simple memory bank (e.g., for NPID).
@ -27,11 +27,10 @@ class SimpleMemory(BaseModule):
**kwargs) -> None:
super().__init__()
self.rank, self.num_replicas = get_dist_info()
self.feature_bank = torch.randn(length, feat_dim).cuda()
self.feature_bank = nn.functional.normalize(self.feature_bank).cuda()
self.register_buffer('feature_bank', torch.randn(length, feat_dim))
self.feature_bank = nn.functional.normalize(self.feature_bank)
self.momentum = momentum
self.multinomial = AliasMethod(torch.ones(length))
self.multinomial.cuda()
def update(self, idx: torch.Tensor, feature: torch.Tensor) -> None:
"""Update features in memory bank.
@ -61,14 +60,8 @@ class SimpleMemory(BaseModule):
- idx_gathered: Gathered indices.
- feature_gathered: Gathered features.
"""
idx_gathered = [
torch.ones_like(idx).cuda() for _ in range(self.num_replicas)
]
feature_gathered = [
torch.ones_like(feature).cuda() for _ in range(self.num_replicas)
]
dist.all_gather(idx_gathered, idx)
dist.all_gather(feature_gathered, feature)
idx_gathered = all_gather(idx)
feature_gathered = all_gather(feature)
idx_gathered = torch.cat(idx_gathered, dim=0)
feature_gathered = torch.cat(feature_gathered, dim=0)
return idx_gathered, feature_gathered

View File

@ -1,11 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmcv.runner import BaseModule
from mmengine.model import BaseModule
from ..builder import NECKS
from mmselfsup.registry import MODELS
@NECKS.register_module()
@MODELS.register_module()
class LinearNeck(BaseModule):
"""The linear neck: fc only.
@ -19,17 +22,17 @@ class LinearNeck(BaseModule):
"""
def __init__(self,
in_channels,
out_channels,
with_avg_pool=True,
init_cfg=None):
in_channels: int,
out_channels: int,
with_avg_pool: bool = True,
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
super(LinearNeck, self).__init__(init_cfg)
self.with_avg_pool = with_avg_pool
if with_avg_pool:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(in_channels, out_channels)
def forward(self, x):
def forward(self, x: Tuple[torch.Tensor]) -> List[torch.Tensor]:
assert len(x) == 1
x = x[0]
if self.with_avg_pool:

View File

@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
class AliasMethod():
class AliasMethod(nn.Module):
"""The alias method for sampling.
From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
@ -12,11 +13,12 @@ class AliasMethod():
""" # noqa: E501
def __init__(self, probs: torch.Tensor) -> None:
super().__init__()
if probs.sum() > 1:
probs.div_(probs.sum())
K = len(probs)
self.prob = torch.zeros(K)
self.alias = torch.LongTensor([0] * K)
self.register_buffer('prob', torch.zeros(K))
self.register_buffer('alias', torch.LongTensor([0] * K))
# Sort the data into the outcomes with probabilities
# that are larger and smaller than 1/K.
@ -47,10 +49,6 @@ class AliasMethod():
for last_one in smaller + larger:
self.prob[last_one] = 1
def cuda(self) -> None:
self.prob = self.prob.cuda()
self.alias = self.alias.cuda()
def draw(self, N: int) -> None:
"""Draw N samples from multinomial.

View File

@ -4,6 +4,7 @@ import platform
import pytest
import torch
from mmengine.data import InstanceData
from mmselfsup.core import SelfSupDataSample
from mmselfsup.models.algorithms import NPID
@ -16,58 +17,38 @@ backbone = dict(
norm_cfg=dict(type='BN'))
neck = dict(
type='LinearNeck', in_channels=512, out_channels=2, with_avg_pool=True)
head = dict(type='ContrastiveHead', temperature=0.07)
loss = dict(type='mmcls.CrossEntropyLoss'),
head = dict(
type='ContrastiveHead',
loss=dict(type='mmcls.CrossEntropyLoss'),
temperature=0.07)
memory_bank = dict(type='SimpleMemory', length=8, feat_dim=2, momentum=0.5)
preprocess_cfg = {
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'to_rgb': True
}
@pytest.mark.skipif(
not torch.cuda.is_available() or platform.system() == 'Windows',
reason='CUDA is not available or Windows mem limit')
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_npid():
with pytest.raises(AssertionError):
alg = NPID(
backbone=backbone,
neck=neck,
head=head,
memory_bank=None,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = NPID(
backbone=backbone,
neck=neck,
head=None,
loss=loss,
memory_bank=memory_bank,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = NPID(
backbone=backbone,
neck=neck,
head=head,
loss=None,
memory_bank=memory_bank,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
data_preprocessor = {
'mean': (123.675, 116.28, 103.53),
'std': (58.395, 57.12, 57.375),
'bgr_to_rgb': True
}
alg = NPID(
backbone=backbone,
neck=neck,
head=head,
loss=loss,
memory_bank=memory_bank,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
data_preprocessor=copy.deepcopy(data_preprocessor))
fake_data = [{
'inputs': torch.randn((3, 224, 224)),
'data_sample': SelfSupDataSample()
'inputs': [torch.randn((3, 224, 224))],
'data_sample':
SelfSupDataSample(
sample_idx=InstanceData(value=torch.randint(0, 7, (1, ))))
} for _ in range(2)]
fake_inputs, _ = alg.preprocss_data(fake_data)
fake_backbone_out = alg.extract_feat(fake_inputs)
assert fake_backbone_out[0].size() == torch.Size([2, 512, 7, 7])
fake_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
fake_loss = alg(fake_inputs, fake_data_samples, mode='loss')
assert fake_loss['loss'] > -1
# test extract
fake_feats = alg(fake_inputs, fake_data_samples, mode='tensor')
assert fake_feats[0].size() == torch.Size([2, 512, 7, 7])