[Refactor]: refactor npid algorithm
parent
3e5d18ea83
commit
4c5c4b88f4
|
@ -1,12 +1,18 @@
|
|||
# dataset settings
|
||||
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
|
||||
dataset_type = 'mmcls.ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
file_client_args = dict(backend='disk')
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(type='RandomResizedCrop', size=224, scale=(0.2, 1.)),
|
||||
dict(type='RandomGrayscale', prob=0.2, keep_channels=True),
|
||||
dict(
|
||||
type='RandomResizedCrop', size=224, scale=(0.2, 1.), backend='pillow'),
|
||||
dict(
|
||||
type='RandomGrayscale',
|
||||
prob=0.2,
|
||||
keep_channels=True,
|
||||
channel_weights=(0.114, 0.587, 0.2989)),
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=0.4,
|
||||
|
@ -22,7 +28,8 @@ train_pipeline = [
|
|||
|
||||
train_dataloader = dict(
|
||||
batch_size=32,
|
||||
num_workers=4,
|
||||
num_workers=8,
|
||||
drop_last=True,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
|
|
|
@ -2,6 +2,10 @@
|
|||
model = dict(
|
||||
type='NPID',
|
||||
neg_num=65536,
|
||||
data_preprocessor=dict(
|
||||
mean=(123.675, 116.28, 103.53),
|
||||
std=(58.395, 57.12, 57.375),
|
||||
bgr_to_rgb=True),
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
|
@ -13,7 +17,9 @@ model = dict(
|
|||
in_channels=2048,
|
||||
out_channels=128,
|
||||
with_avg_pool=True),
|
||||
head=dict(type='ContrastiveHead', temperature=0.07),
|
||||
loss=dict(type='mmcls.CrossEntropyLoss'),
|
||||
head=dict(
|
||||
type='ContrastiveHead',
|
||||
loss=dict(type='mmcls.CrossEntropyLoss'),
|
||||
temperature=0.07),
|
||||
memory_bank=dict(
|
||||
type='SimpleMemory', length=1281167, feat_dim=128, momentum=0.5))
|
||||
|
|
|
@ -6,7 +6,7 @@ _base_ = [
|
|||
]
|
||||
|
||||
# runtime settings
|
||||
# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
|
||||
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
|
||||
# it will remove the oldest one to keep the number of total ckpts as 3
|
||||
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
|
||||
default_hooks = dict(
|
||||
logger=dict(type='LoggerHook', interval=50),
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))
|
||||
|
|
|
@ -5,12 +5,11 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
|
||||
build_memory, build_neck)
|
||||
from mmselfsup.registry import MODELS
|
||||
from .base import BaseModel
|
||||
|
||||
|
||||
@ALGORITHMS.register_module()
|
||||
@MODELS.register_module()
|
||||
class NPID(BaseModel):
|
||||
"""NPID.
|
||||
|
||||
|
@ -18,80 +17,79 @@ class NPID(BaseModel):
|
|||
Instance Discrimination <https://arxiv.org/abs/1805.01978>`_.
|
||||
|
||||
Args:
|
||||
backbone (Dict): Config dict for module of backbone.
|
||||
neck (Dict, optional): Config dict for module of deep features to
|
||||
compact feature vectors. Defaults to None.
|
||||
head (Dict, optional): Config dict for module of head functions.
|
||||
Defaults to None.
|
||||
loss (dict): Config dict for module of loss functions.
|
||||
Defaults to None.
|
||||
memory_bank (Dict, optional): Config dict for module of memory banks.
|
||||
Defaults to None.
|
||||
backbone (dict): Config dict for module of backbone.
|
||||
neck (dict): Config dict for module of deep features to
|
||||
compact feature vectors.
|
||||
head (dict): Config dict for module of head functions.
|
||||
memory_bank (dict): Config dict for module of memory banks.
|
||||
neg_num (int): Number of negative samples for each image.
|
||||
Defaults to 65536.
|
||||
ensure_neg (bool): If False, there is a small probability
|
||||
that negative samples contain positive ones. Defaults to False.
|
||||
preprocess_cfg (Dict, optional): Config to preprocess images.
|
||||
pretrained (str, optional): The pretrained checkpoint path, support
|
||||
local path and remote path. Defaults to None.
|
||||
data_preprocessor (Union[dict, nn.Module], optional): The config for
|
||||
preprocessing input data. If None or no specified type, it will use
|
||||
"SelfSupDataPreprocessor" as type.
|
||||
See :class:`SelfSupDataPreprocessor` for more details.
|
||||
Defaults to None.
|
||||
init_cfg (Dict or List[Dict], optional): Config dict for weight
|
||||
initialization. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone: Dict,
|
||||
neck: Optional[Dict] = None,
|
||||
head: Optional[Dict] = None,
|
||||
loss: Optional[Dict] = None,
|
||||
memory_bank: Optional[Dict] = None,
|
||||
neg_num: Optional[int] = 65536,
|
||||
ensure_neg: Optional[bool] = False,
|
||||
preprocess_cfg: Optional[Dict] = None,
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
|
||||
self.backbone = build_backbone(backbone)
|
||||
if neck is not None:
|
||||
self.neck = build_neck(neck)
|
||||
assert head is not None
|
||||
self.head = build_head(head)
|
||||
assert loss is not None
|
||||
self.loss = build_loss(loss)
|
||||
backbone: dict,
|
||||
neck: dict,
|
||||
head: dict,
|
||||
memory_bank: dict,
|
||||
neg_num: int = 65536,
|
||||
ensure_neg: bool = False,
|
||||
pretrained: Optional[str] = None,
|
||||
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
|
||||
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
|
||||
super().__init__(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
pretrained=pretrained,
|
||||
data_preprocessor=data_preprocessor,
|
||||
init_cfg=init_cfg)
|
||||
assert memory_bank is not None
|
||||
self.memory_bank = build_memory(memory_bank)
|
||||
self.memory_bank = MODELS.build(memory_bank)
|
||||
|
||||
self.neg_num = neg_num
|
||||
self.ensure_neg = ensure_neg
|
||||
|
||||
def extract_feat(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
def extract_feat(self, batch_inputs: List[torch.Tensor],
|
||||
**kwarg) -> Tuple[torch.Tensor]:
|
||||
"""Function to extract features from backbone.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
batch_inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor]: backbone outputs.
|
||||
"""
|
||||
x = self.backbone(inputs[0])
|
||||
x = self.backbone(batch_inputs[0])
|
||||
return x
|
||||
|
||||
def forward_train(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
def loss(self, batch_inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
"""The forward function in training.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
batch_inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
Returns:
|
||||
Dict[str, Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
feature = self.extract_feat(inputs[0])
|
||||
idx = [data_sample.idx for data_sample in data_samples]
|
||||
feature = self.backbone(batch_inputs[0])
|
||||
idx = [data_sample.sample_idx.value for data_sample in data_samples]
|
||||
idx = torch.cat(idx)
|
||||
if self.with_neck:
|
||||
feature = self.neck(feature)[0]
|
||||
|
@ -119,8 +117,7 @@ class NPID(BaseModel):
|
|||
[pos_feat, feature]).unsqueeze(-1)
|
||||
neg_logits = torch.bmm(neg_feat, feature.unsqueeze(2)).squeeze(2)
|
||||
|
||||
logits, labels = self.head(pos_logits, neg_logits)
|
||||
loss = self.loss(logits, labels)
|
||||
loss = self.head(pos_logits, neg_logits)
|
||||
losses = dict(loss=loss)
|
||||
# update memory bank
|
||||
with torch.no_grad():
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
|
@ -26,8 +24,7 @@ class ContrastiveHead(BaseModule):
|
|||
self.loss = MODELS.build(loss)
|
||||
self.temperature = temperature
|
||||
|
||||
def forward(self, pos: torch.Tensor,
|
||||
neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def forward(self, pos: torch.Tensor, neg: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function to compute contrastive loss.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -2,15 +2,15 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import BaseModule, get_dist_info
|
||||
from mmengine.dist import all_gather
|
||||
|
||||
from mmselfsup.registry import MODELS
|
||||
from mmselfsup.utils import AliasMethod
|
||||
from ..builder import MEMORIES
|
||||
|
||||
|
||||
@MEMORIES.register_module()
|
||||
@MODELS.register_module()
|
||||
class SimpleMemory(BaseModule):
|
||||
"""Simple memory bank (e.g., for NPID).
|
||||
|
||||
|
@ -27,11 +27,10 @@ class SimpleMemory(BaseModule):
|
|||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
self.rank, self.num_replicas = get_dist_info()
|
||||
self.feature_bank = torch.randn(length, feat_dim).cuda()
|
||||
self.feature_bank = nn.functional.normalize(self.feature_bank).cuda()
|
||||
self.register_buffer('feature_bank', torch.randn(length, feat_dim))
|
||||
self.feature_bank = nn.functional.normalize(self.feature_bank)
|
||||
self.momentum = momentum
|
||||
self.multinomial = AliasMethod(torch.ones(length))
|
||||
self.multinomial.cuda()
|
||||
|
||||
def update(self, idx: torch.Tensor, feature: torch.Tensor) -> None:
|
||||
"""Update features in memory bank.
|
||||
|
@ -61,14 +60,8 @@ class SimpleMemory(BaseModule):
|
|||
- idx_gathered: Gathered indices.
|
||||
- feature_gathered: Gathered features.
|
||||
"""
|
||||
idx_gathered = [
|
||||
torch.ones_like(idx).cuda() for _ in range(self.num_replicas)
|
||||
]
|
||||
feature_gathered = [
|
||||
torch.ones_like(feature).cuda() for _ in range(self.num_replicas)
|
||||
]
|
||||
dist.all_gather(idx_gathered, idx)
|
||||
dist.all_gather(feature_gathered, feature)
|
||||
idx_gathered = all_gather(idx)
|
||||
feature_gathered = all_gather(feature)
|
||||
idx_gathered = torch.cat(idx_gathered, dim=0)
|
||||
feature_gathered = torch.cat(feature_gathered, dim=0)
|
||||
return idx_gathered, feature_gathered
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import BaseModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from ..builder import NECKS
|
||||
from mmselfsup.registry import MODELS
|
||||
|
||||
|
||||
@NECKS.register_module()
|
||||
@MODELS.register_module()
|
||||
class LinearNeck(BaseModule):
|
||||
"""The linear neck: fc only.
|
||||
|
||||
|
@ -19,17 +22,17 @@ class LinearNeck(BaseModule):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
with_avg_pool=True,
|
||||
init_cfg=None):
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
with_avg_pool: bool = True,
|
||||
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
|
||||
super(LinearNeck, self).__init__(init_cfg)
|
||||
self.with_avg_pool = with_avg_pool
|
||||
if with_avg_pool:
|
||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(in_channels, out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: Tuple[torch.Tensor]) -> List[torch.Tensor]:
|
||||
assert len(x) == 1
|
||||
x = x[0]
|
||||
if self.with_avg_pool:
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class AliasMethod():
|
||||
class AliasMethod(nn.Module):
|
||||
"""The alias method for sampling.
|
||||
|
||||
From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
|
||||
|
@ -12,11 +13,12 @@ class AliasMethod():
|
|||
""" # noqa: E501
|
||||
|
||||
def __init__(self, probs: torch.Tensor) -> None:
|
||||
super().__init__()
|
||||
if probs.sum() > 1:
|
||||
probs.div_(probs.sum())
|
||||
K = len(probs)
|
||||
self.prob = torch.zeros(K)
|
||||
self.alias = torch.LongTensor([0] * K)
|
||||
self.register_buffer('prob', torch.zeros(K))
|
||||
self.register_buffer('alias', torch.LongTensor([0] * K))
|
||||
|
||||
# Sort the data into the outcomes with probabilities
|
||||
# that are larger and smaller than 1/K.
|
||||
|
@ -47,10 +49,6 @@ class AliasMethod():
|
|||
for last_one in smaller + larger:
|
||||
self.prob[last_one] = 1
|
||||
|
||||
def cuda(self) -> None:
|
||||
self.prob = self.prob.cuda()
|
||||
self.alias = self.alias.cuda()
|
||||
|
||||
def draw(self, N: int) -> None:
|
||||
"""Draw N samples from multinomial.
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import platform
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.data import InstanceData
|
||||
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from mmselfsup.models.algorithms import NPID
|
||||
|
@ -16,58 +17,38 @@ backbone = dict(
|
|||
norm_cfg=dict(type='BN'))
|
||||
neck = dict(
|
||||
type='LinearNeck', in_channels=512, out_channels=2, with_avg_pool=True)
|
||||
head = dict(type='ContrastiveHead', temperature=0.07)
|
||||
loss = dict(type='mmcls.CrossEntropyLoss'),
|
||||
head = dict(
|
||||
type='ContrastiveHead',
|
||||
loss=dict(type='mmcls.CrossEntropyLoss'),
|
||||
temperature=0.07)
|
||||
memory_bank = dict(type='SimpleMemory', length=8, feat_dim=2, momentum=0.5)
|
||||
preprocess_cfg = {
|
||||
'mean': [0.5, 0.5, 0.5],
|
||||
'std': [0.5, 0.5, 0.5],
|
||||
'to_rgb': True
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available() or platform.system() == 'Windows',
|
||||
reason='CUDA is not available or Windows mem limit')
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_npid():
|
||||
with pytest.raises(AssertionError):
|
||||
alg = NPID(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
memory_bank=None,
|
||||
loss=loss,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
alg = NPID(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=None,
|
||||
loss=loss,
|
||||
memory_bank=memory_bank,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
alg = NPID(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
loss=None,
|
||||
memory_bank=memory_bank,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
|
||||
data_preprocessor = {
|
||||
'mean': (123.675, 116.28, 103.53),
|
||||
'std': (58.395, 57.12, 57.375),
|
||||
'bgr_to_rgb': True
|
||||
}
|
||||
alg = NPID(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
loss=loss,
|
||||
memory_bank=memory_bank,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
data_preprocessor=copy.deepcopy(data_preprocessor))
|
||||
|
||||
fake_data = [{
|
||||
'inputs': torch.randn((3, 224, 224)),
|
||||
'data_sample': SelfSupDataSample()
|
||||
'inputs': [torch.randn((3, 224, 224))],
|
||||
'data_sample':
|
||||
SelfSupDataSample(
|
||||
sample_idx=InstanceData(value=torch.randint(0, 7, (1, ))))
|
||||
} for _ in range(2)]
|
||||
|
||||
fake_inputs, _ = alg.preprocss_data(fake_data)
|
||||
fake_backbone_out = alg.extract_feat(fake_inputs)
|
||||
assert fake_backbone_out[0].size() == torch.Size([2, 512, 7, 7])
|
||||
fake_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
|
||||
fake_loss = alg(fake_inputs, fake_data_samples, mode='loss')
|
||||
assert fake_loss['loss'] > -1
|
||||
|
||||
# test extract
|
||||
fake_feats = alg(fake_inputs, fake_data_samples, mode='tensor')
|
||||
assert fake_feats[0].size() == torch.Size([2, 512, 7, 7])
|
||||
|
|
Loading…
Reference in New Issue