diff --git a/configs/selfsup/_base_/datasets/imagenet_npid.py b/configs/selfsup/_base_/datasets/imagenet_npid.py index 9f0281f5..22436139 100644 --- a/configs/selfsup/_base_/datasets/imagenet_npid.py +++ b/configs/selfsup/_base_/datasets/imagenet_npid.py @@ -1,12 +1,18 @@ # dataset settings +custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False) dataset_type = 'mmcls.ImageNet' data_root = 'data/imagenet/' file_client_args = dict(backend='disk') train_pipeline = [ dict(type='LoadImageFromFile', file_client_args=file_client_args), - dict(type='RandomResizedCrop', size=224, scale=(0.2, 1.)), - dict(type='RandomGrayscale', prob=0.2, keep_channels=True), + dict( + type='RandomResizedCrop', size=224, scale=(0.2, 1.), backend='pillow'), + dict( + type='RandomGrayscale', + prob=0.2, + keep_channels=True, + channel_weights=(0.114, 0.587, 0.2989)), dict( type='ColorJitter', brightness=0.4, @@ -22,7 +28,8 @@ train_pipeline = [ train_dataloader = dict( batch_size=32, - num_workers=4, + num_workers=8, + drop_last=True, persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=True), dataset=dict( diff --git a/configs/selfsup/_base_/models/npid.py b/configs/selfsup/_base_/models/npid.py index 614849bc..f62a79a4 100644 --- a/configs/selfsup/_base_/models/npid.py +++ b/configs/selfsup/_base_/models/npid.py @@ -2,6 +2,10 @@ model = dict( type='NPID', neg_num=65536, + data_preprocessor=dict( + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + bgr_to_rgb=True), backbone=dict( type='ResNet', depth=50, @@ -13,7 +17,9 @@ model = dict( in_channels=2048, out_channels=128, with_avg_pool=True), - head=dict(type='ContrastiveHead', temperature=0.07), - loss=dict(type='mmcls.CrossEntropyLoss'), + head=dict( + type='ContrastiveHead', + loss=dict(type='mmcls.CrossEntropyLoss'), + temperature=0.07), memory_bank=dict( type='SimpleMemory', length=1281167, feat_dim=128, momentum=0.5)) diff --git a/configs/selfsup/npid/npid_resnet50_8xb32-steplr-200e_in1k.py b/configs/selfsup/npid/npid_resnet50_8xb32-steplr-200e_in1k.py index c3f47bab..7b32ce80 100644 --- a/configs/selfsup/npid/npid_resnet50_8xb32-steplr-200e_in1k.py +++ b/configs/selfsup/npid/npid_resnet50_8xb32-steplr-200e_in1k.py @@ -6,7 +6,7 @@ _base_ = [ ] # runtime settings -# the max_keep_ckpts controls the max number of ckpt file in your work_dirs -# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt -# it will remove the oldest one to keep the number of total ckpts as 3 -checkpoint_config = dict(interval=10, max_keep_ckpts=3) +default_hooks = dict( + logger=dict(type='LoggerHook', interval=50), + # only keeps the latest 3 checkpoints + checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3)) diff --git a/mmselfsup/models/algorithms/npid.py b/mmselfsup/models/algorithms/npid.py index 190f989e..e9a0e077 100644 --- a/mmselfsup/models/algorithms/npid.py +++ b/mmselfsup/models/algorithms/npid.py @@ -5,12 +5,11 @@ import torch import torch.nn as nn from mmselfsup.core import SelfSupDataSample -from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss, - build_memory, build_neck) +from mmselfsup.registry import MODELS from .base import BaseModel -@ALGORITHMS.register_module() +@MODELS.register_module() class NPID(BaseModel): """NPID. @@ -18,80 +17,79 @@ class NPID(BaseModel): Instance Discrimination `_. Args: - backbone (Dict): Config dict for module of backbone. - neck (Dict, optional): Config dict for module of deep features to - compact feature vectors. Defaults to None. - head (Dict, optional): Config dict for module of head functions. - Defaults to None. - loss (dict): Config dict for module of loss functions. - Defaults to None. - memory_bank (Dict, optional): Config dict for module of memory banks. - Defaults to None. + backbone (dict): Config dict for module of backbone. + neck (dict): Config dict for module of deep features to + compact feature vectors. + head (dict): Config dict for module of head functions. + memory_bank (dict): Config dict for module of memory banks. neg_num (int): Number of negative samples for each image. Defaults to 65536. ensure_neg (bool): If False, there is a small probability that negative samples contain positive ones. Defaults to False. - preprocess_cfg (Dict, optional): Config to preprocess images. + pretrained (str, optional): The pretrained checkpoint path, support + local path and remote path. Defaults to None. + data_preprocessor (Union[dict, nn.Module], optional): The config for + preprocessing input data. If None or no specified type, it will use + "SelfSupDataPreprocessor" as type. + See :class:`SelfSupDataPreprocessor` for more details. Defaults to None. init_cfg (Dict or List[Dict], optional): Config dict for weight initialization. Defaults to None. """ def __init__(self, - backbone: Dict, - neck: Optional[Dict] = None, - head: Optional[Dict] = None, - loss: Optional[Dict] = None, - memory_bank: Optional[Dict] = None, - neg_num: Optional[int] = 65536, - ensure_neg: Optional[bool] = False, - preprocess_cfg: Optional[Dict] = None, - init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None: - super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg) - self.backbone = build_backbone(backbone) - if neck is not None: - self.neck = build_neck(neck) - assert head is not None - self.head = build_head(head) - assert loss is not None - self.loss = build_loss(loss) + backbone: dict, + neck: dict, + head: dict, + memory_bank: dict, + neg_num: int = 65536, + ensure_neg: bool = False, + pretrained: Optional[str] = None, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) assert memory_bank is not None - self.memory_bank = build_memory(memory_bank) + self.memory_bank = MODELS.build(memory_bank) self.neg_num = neg_num self.ensure_neg = ensure_neg - def extract_feat(self, inputs: List[torch.Tensor], - data_samples: List[SelfSupDataSample], + def extract_feat(self, batch_inputs: List[torch.Tensor], **kwarg) -> Tuple[torch.Tensor]: """Function to extract features from backbone. Args: - inputs (List[torch.Tensor]): The input images. + batch_inputs (List[torch.Tensor]): The input images. data_samples (List[SelfSupDataSample]): All elements required during the forward function. Returns: Tuple[torch.Tensor]: backbone outputs. """ - x = self.backbone(inputs[0]) + x = self.backbone(batch_inputs[0]) return x - def forward_train(self, inputs: List[torch.Tensor], - data_samples: List[SelfSupDataSample], - **kwargs) -> Dict[str, torch.Tensor]: + def loss(self, batch_inputs: List[torch.Tensor], + data_samples: List[SelfSupDataSample], + **kwargs) -> Dict[str, torch.Tensor]: """The forward function in training. Args: - inputs (List[torch.Tensor]): The input images. + batch_inputs (List[torch.Tensor]): The input images. data_samples (List[SelfSupDataSample]): All elements required during the forward function. Returns: Dict[str, Tensor]: A dictionary of loss components. """ - feature = self.extract_feat(inputs[0]) - idx = [data_sample.idx for data_sample in data_samples] + feature = self.backbone(batch_inputs[0]) + idx = [data_sample.sample_idx.value for data_sample in data_samples] idx = torch.cat(idx) if self.with_neck: feature = self.neck(feature)[0] @@ -119,8 +117,7 @@ class NPID(BaseModel): [pos_feat, feature]).unsqueeze(-1) neg_logits = torch.bmm(neg_feat, feature.unsqueeze(2)).squeeze(2) - logits, labels = self.head(pos_logits, neg_logits) - loss = self.loss(logits, labels) + loss = self.head(pos_logits, neg_logits) losses = dict(loss=loss) # update memory bank with torch.no_grad(): diff --git a/mmselfsup/models/heads/contrastive_head.py b/mmselfsup/models/heads/contrastive_head.py index ac495737..789d46db 100644 --- a/mmselfsup/models/heads/contrastive_head.py +++ b/mmselfsup/models/heads/contrastive_head.py @@ -1,6 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Tuple - import torch from mmengine.model import BaseModule @@ -26,8 +24,7 @@ class ContrastiveHead(BaseModule): self.loss = MODELS.build(loss) self.temperature = temperature - def forward(self, pos: torch.Tensor, - neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, pos: torch.Tensor, neg: torch.Tensor) -> torch.Tensor: """Forward function to compute contrastive loss. Args: diff --git a/mmselfsup/models/memories/simple_memory.py b/mmselfsup/models/memories/simple_memory.py index c55ddd3e..d4a42475 100644 --- a/mmselfsup/models/memories/simple_memory.py +++ b/mmselfsup/models/memories/simple_memory.py @@ -2,15 +2,15 @@ from typing import Tuple import torch -import torch.distributed as dist import torch.nn as nn from mmcv.runner import BaseModule, get_dist_info +from mmengine.dist import all_gather +from mmselfsup.registry import MODELS from mmselfsup.utils import AliasMethod -from ..builder import MEMORIES -@MEMORIES.register_module() +@MODELS.register_module() class SimpleMemory(BaseModule): """Simple memory bank (e.g., for NPID). @@ -27,11 +27,10 @@ class SimpleMemory(BaseModule): **kwargs) -> None: super().__init__() self.rank, self.num_replicas = get_dist_info() - self.feature_bank = torch.randn(length, feat_dim).cuda() - self.feature_bank = nn.functional.normalize(self.feature_bank).cuda() + self.register_buffer('feature_bank', torch.randn(length, feat_dim)) + self.feature_bank = nn.functional.normalize(self.feature_bank) self.momentum = momentum self.multinomial = AliasMethod(torch.ones(length)) - self.multinomial.cuda() def update(self, idx: torch.Tensor, feature: torch.Tensor) -> None: """Update features in memory bank. @@ -61,14 +60,8 @@ class SimpleMemory(BaseModule): - idx_gathered: Gathered indices. - feature_gathered: Gathered features. """ - idx_gathered = [ - torch.ones_like(idx).cuda() for _ in range(self.num_replicas) - ] - feature_gathered = [ - torch.ones_like(feature).cuda() for _ in range(self.num_replicas) - ] - dist.all_gather(idx_gathered, idx) - dist.all_gather(feature_gathered, feature) + idx_gathered = all_gather(idx) + feature_gathered = all_gather(feature) idx_gathered = torch.cat(idx_gathered, dim=0) feature_gathered = torch.cat(feature_gathered, dim=0) return idx_gathered, feature_gathered diff --git a/mmselfsup/models/necks/linear_neck.py b/mmselfsup/models/necks/linear_neck.py index c05268e7..6e88c7c2 100644 --- a/mmselfsup/models/necks/linear_neck.py +++ b/mmselfsup/models/necks/linear_neck.py @@ -1,11 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple, Union + +import torch import torch.nn as nn -from mmcv.runner import BaseModule +from mmengine.model import BaseModule -from ..builder import NECKS +from mmselfsup.registry import MODELS -@NECKS.register_module() +@MODELS.register_module() class LinearNeck(BaseModule): """The linear neck: fc only. @@ -19,17 +22,17 @@ class LinearNeck(BaseModule): """ def __init__(self, - in_channels, - out_channels, - with_avg_pool=True, - init_cfg=None): + in_channels: int, + out_channels: int, + with_avg_pool: bool = True, + init_cfg: Optional[Union[dict, List[dict]]] = None) -> None: super(LinearNeck, self).__init__(init_cfg) self.with_avg_pool = with_avg_pool if with_avg_pool: self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(in_channels, out_channels) - def forward(self, x): + def forward(self, x: Tuple[torch.Tensor]) -> List[torch.Tensor]: assert len(x) == 1 x = x[0] if self.with_avg_pool: diff --git a/mmselfsup/utils/alias_multinomial.py b/mmselfsup/utils/alias_multinomial.py index 4002bbf2..02ad40fb 100644 --- a/mmselfsup/utils/alias_multinomial.py +++ b/mmselfsup/utils/alias_multinomial.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +import torch.nn as nn -class AliasMethod(): +class AliasMethod(nn.Module): """The alias method for sampling. From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ @@ -12,11 +13,12 @@ class AliasMethod(): """ # noqa: E501 def __init__(self, probs: torch.Tensor) -> None: + super().__init__() if probs.sum() > 1: probs.div_(probs.sum()) K = len(probs) - self.prob = torch.zeros(K) - self.alias = torch.LongTensor([0] * K) + self.register_buffer('prob', torch.zeros(K)) + self.register_buffer('alias', torch.LongTensor([0] * K)) # Sort the data into the outcomes with probabilities # that are larger and smaller than 1/K. @@ -47,10 +49,6 @@ class AliasMethod(): for last_one in smaller + larger: self.prob[last_one] = 1 - def cuda(self) -> None: - self.prob = self.prob.cuda() - self.alias = self.alias.cuda() - def draw(self, N: int) -> None: """Draw N samples from multinomial. diff --git a/tests/test_models/test_algorithms/test_npid.py b/tests/test_models/test_algorithms/test_npid.py index f9dd2dcb..8f42e809 100644 --- a/tests/test_models/test_algorithms/test_npid.py +++ b/tests/test_models/test_algorithms/test_npid.py @@ -4,6 +4,7 @@ import platform import pytest import torch +from mmengine.data import InstanceData from mmselfsup.core import SelfSupDataSample from mmselfsup.models.algorithms import NPID @@ -16,58 +17,38 @@ backbone = dict( norm_cfg=dict(type='BN')) neck = dict( type='LinearNeck', in_channels=512, out_channels=2, with_avg_pool=True) -head = dict(type='ContrastiveHead', temperature=0.07) -loss = dict(type='mmcls.CrossEntropyLoss'), +head = dict( + type='ContrastiveHead', + loss=dict(type='mmcls.CrossEntropyLoss'), + temperature=0.07) memory_bank = dict(type='SimpleMemory', length=8, feat_dim=2, momentum=0.5) -preprocess_cfg = { - 'mean': [0.5, 0.5, 0.5], - 'std': [0.5, 0.5, 0.5], - 'to_rgb': True -} -@pytest.mark.skipif( - not torch.cuda.is_available() or platform.system() == 'Windows', - reason='CUDA is not available or Windows mem limit') +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') def test_npid(): - with pytest.raises(AssertionError): - alg = NPID( - backbone=backbone, - neck=neck, - head=head, - memory_bank=None, - loss=loss, - preprocess_cfg=copy.deepcopy(preprocess_cfg)) - with pytest.raises(AssertionError): - alg = NPID( - backbone=backbone, - neck=neck, - head=None, - loss=loss, - memory_bank=memory_bank, - preprocess_cfg=copy.deepcopy(preprocess_cfg)) - with pytest.raises(AssertionError): - alg = NPID( - backbone=backbone, - neck=neck, - head=head, - loss=None, - memory_bank=memory_bank, - preprocess_cfg=copy.deepcopy(preprocess_cfg)) - + data_preprocessor = { + 'mean': (123.675, 116.28, 103.53), + 'std': (58.395, 57.12, 57.375), + 'bgr_to_rgb': True + } alg = NPID( backbone=backbone, neck=neck, head=head, - loss=loss, memory_bank=memory_bank, - preprocess_cfg=copy.deepcopy(preprocess_cfg)) + data_preprocessor=copy.deepcopy(data_preprocessor)) fake_data = [{ - 'inputs': torch.randn((3, 224, 224)), - 'data_sample': SelfSupDataSample() + 'inputs': [torch.randn((3, 224, 224))], + 'data_sample': + SelfSupDataSample( + sample_idx=InstanceData(value=torch.randint(0, 7, (1, )))) } for _ in range(2)] - fake_inputs, _ = alg.preprocss_data(fake_data) - fake_backbone_out = alg.extract_feat(fake_inputs) - assert fake_backbone_out[0].size() == torch.Size([2, 512, 7, 7]) + fake_inputs, fake_data_samples = alg.data_preprocessor(fake_data) + fake_loss = alg(fake_inputs, fake_data_samples, mode='loss') + assert fake_loss['loss'] > -1 + + # test extract + fake_feats = alg(fake_inputs, fake_data_samples, mode='tensor') + assert fake_feats[0].size() == torch.Size([2, 512, 7, 7])