[Refactor]: Refactor ODC

pull/352/head
YuanLiuuuuuu 2022-05-25 13:53:47 +00:00 committed by fangyixiao18
parent 0baac605d1
commit 2f2813ecd4
2 changed files with 145 additions and 38 deletions

View File

@ -1,6 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from typing import Dict, List, Optional, Tuple, Union
import torch
from mmengine.data import InstanceData
from mmselfsup.core import SelfSupDataSample
from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
build_memory, build_neck)
from ..utils import Sobel
@ -17,27 +21,31 @@ class ODC(BaseModel):
`core/hooks/odc_hook.py`.
Args:
backbone (dict): Config dict for module of backbone.
backbone (Dict): Config dict for module of backbone.
with_sobel (bool): Whether to apply a Sobel filter on images.
Defaults to False.
neck (dict): Config dict for module of deep features to compact feature
vectors. Defaults to None.
head (dict): Config dict for module of head functions.
neck (Dict, optional): Config dict for module of deep features to
compact feature vectors. Defaults to None.
head (Dict, optional): Config dict for module of head functions.
Defaults to None.
loss (dict): Config dict for module of loss functions.
loss (Dict, optional): Config dict for module of loss functions.
memory_bank (Dict, optional): Module of memory banks. Defaults to None.
preprocess_cfg (Dict, optional): Config to preprocess images.
Defaults to None.
memory_bank (dict): Module of memory banks. Defaults to None.
init_cfg (Dict or List[Dict], optional): Config dict for weight
initialization. Defaults to None.
"""
def __init__(self,
backbone,
with_sobel=False,
neck=None,
head=None,
loss=None,
memory_bank=None,
init_cfg=None):
super(ODC, self).__init__(init_cfg)
backbone: Dict,
with_sobel: Optional[bool] = False,
neck: Optional[Dict] = None,
head: Optional[Dict] = None,
loss: Optional[Dict] = None,
memory_bank: Optional[Dict] = None,
preprocess_cfg: Optional[Dict] = None,
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
self.with_sobel = with_sobel
if with_sobel:
self.sobel_layer = Sobel()
@ -57,35 +65,41 @@ class ODC(BaseModel):
dtype=torch.float32).cuda()
self.loss_weight /= self.loss_weight.sum()
def extract_feat(self, img):
"""Function to extract features from backbone.
def extract_feat(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwarg) -> Tuple[torch.Tensor]:
"""The forward function to extract features.
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
tuple[Tensor]: backbone outputs.
Tuple[torch.Tensor]: backbone outputs.
"""
if self.with_sobel:
img = self.sobel_layer(img)
img = self.sobel_layer(inputs[0])
x = self.backbone(img)
return x
def forward_train(self, img, idx, **kwargs):
def forward_train(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""Forward computation during training.
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
idx (Tensor): Index corresponding to each image.
kwargs: Any keyword arguments to be used to forward.
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
dict[str, Tensor]: A dictionary of loss components.
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
# forward & backward
feature = self.extract_feat(img)
feature = self.extract_feat(inputs[0])
idx = [data_sample.idx for data_sample in data_samples]
idx = torch.cat(idx)
if self.with_neck:
feature = self.neck(feature)
outs = self.head(feature)
@ -104,20 +118,25 @@ class ODC(BaseModel):
return losses
def forward_test(self, img, **kwargs):
"""Forward computation during test.
def forward_test(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> List[SelfSupDataSample]:
"""The forward function in testing
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
dict[str, Tensor]: A dictionary of output features.
List[SelfSupDataSample]: The prediction from model.
"""
feature = self.extract_feat(img) # tuple
feature = self.extract_feat(inputs[0]) # tuple
if self.with_neck:
feature = self.neck(feature)
outs = self.head(feature)
keys = [f'head{i}' for i in range(len(outs))]
out_tensors = [out.cpu() for out in outs] # NxC
return dict(zip(keys, out_tensors))
for i in range(outs[0].shape[0]):
prediction_data = {key: out[i] for key, out in zip(keys, outs)}
prediction = InstanceData(**prediction_data)
data_samples[i].prediction = prediction
return data_samples

View File

@ -0,0 +1,88 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmselfsup.core import SelfSupDataSample
from mmselfsup.models.algorithms import ODC
num_classes = 5
backbone = dict(
type='ResNet',
depth=18,
in_channels=3,
out_indices=[4], # 0: conv-1, x: stage-x
norm_cfg=dict(type='BN'))
neck = dict(
type='ODCNeck',
in_channels=512,
hid_channels=2,
out_channels=2,
norm_cfg=dict(type='BN1d'),
with_avg_pool=True)
head = dict(
type='ClsHead',
with_avg_pool=False,
in_channels=2,
num_classes=num_classes)
loss = dict(type='CrossEntropyLoss')
memory_bank = dict(
type='ODCMemory',
length=8,
feat_dim=2,
momentum=0.5,
num_classes=num_classes,
min_cluster=2,
debug=False)
preprocess_cfg = {
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'to_rgb': True
}
@pytest.mark.skipif(
not torch.cuda.is_available() or platform.system() == 'Windows',
reason='CUDA is not available or Windows mem limit')
def test_odc():
with pytest.raises(AssertionError):
alg = ODC(
backbone=backbone,
neck=neck,
head=head,
loss=loss,
memory_bank=None,
preprocess_cfg=preprocess_cfg)
with pytest.raises(AssertionError):
alg = ODC(
backbone=backbone,
neck=neck,
head=None,
memory_bank=memory_bank,
preprocess_cfg=preprocess_cfg)
with pytest.raises(AssertionError):
alg = ODC(
backbone=backbone,
neck=neck,
head=head,
loss=loss,
memory_bank=memory_bank,
preprocess_cfg=preprocess_cfg)
alg = ODC(
backbone=backbone,
neck=neck,
head=head,
loss=loss,
memory_bank=memory_bank,
preprocess_cfg=preprocess_cfg)
alg.set_reweight()
fake_data = [{
'inputs': torch.randn((3, 224, 224)),
'data_sample': SelfSupDataSample()
} for _ in range(2)]
fake_out = alg(fake_data, return_loss=False)
assert hasattr(fake_out[0].prediction, 'head0')
assert fake_out[0].prediction.head0.size() == torch.Size([num_classes])