[Refactor]: Refactor ODC
parent
0baac605d1
commit
2f2813ecd4
|
@ -1,6 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from mmengine.data import InstanceData
|
||||
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
|
||||
build_memory, build_neck)
|
||||
from ..utils import Sobel
|
||||
|
@ -17,27 +21,31 @@ class ODC(BaseModel):
|
|||
`core/hooks/odc_hook.py`.
|
||||
|
||||
Args:
|
||||
backbone (dict): Config dict for module of backbone.
|
||||
backbone (Dict): Config dict for module of backbone.
|
||||
with_sobel (bool): Whether to apply a Sobel filter on images.
|
||||
Defaults to False.
|
||||
neck (dict): Config dict for module of deep features to compact feature
|
||||
vectors. Defaults to None.
|
||||
head (dict): Config dict for module of head functions.
|
||||
neck (Dict, optional): Config dict for module of deep features to
|
||||
compact feature vectors. Defaults to None.
|
||||
head (Dict, optional): Config dict for module of head functions.
|
||||
Defaults to None.
|
||||
loss (dict): Config dict for module of loss functions.
|
||||
loss (Dict, optional): Config dict for module of loss functions.
|
||||
memory_bank (Dict, optional): Module of memory banks. Defaults to None.
|
||||
preprocess_cfg (Dict, optional): Config to preprocess images.
|
||||
Defaults to None.
|
||||
memory_bank (dict): Module of memory banks. Defaults to None.
|
||||
init_cfg (Dict or List[Dict], optional): Config dict for weight
|
||||
initialization. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone,
|
||||
with_sobel=False,
|
||||
neck=None,
|
||||
head=None,
|
||||
loss=None,
|
||||
memory_bank=None,
|
||||
init_cfg=None):
|
||||
super(ODC, self).__init__(init_cfg)
|
||||
backbone: Dict,
|
||||
with_sobel: Optional[bool] = False,
|
||||
neck: Optional[Dict] = None,
|
||||
head: Optional[Dict] = None,
|
||||
loss: Optional[Dict] = None,
|
||||
memory_bank: Optional[Dict] = None,
|
||||
preprocess_cfg: Optional[Dict] = None,
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
||||
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
|
||||
self.with_sobel = with_sobel
|
||||
if with_sobel:
|
||||
self.sobel_layer = Sobel()
|
||||
|
@ -57,35 +65,41 @@ class ODC(BaseModel):
|
|||
dtype=torch.float32).cuda()
|
||||
self.loss_weight /= self.loss_weight.sum()
|
||||
|
||||
def extract_feat(self,
                 inputs: Union[torch.Tensor, List[torch.Tensor]],
                 data_samples: Optional[List['SelfSupDataSample']] = None,
                 **kwargs) -> Tuple[torch.Tensor]:
    """The forward function to extract features.

    Args:
        inputs (torch.Tensor or List[torch.Tensor]): The input images,
            given either as the image tensor itself or as a list whose
            first element is that tensor.
        data_samples (List[SelfSupDataSample], optional): All elements
            required during the forward function. Not used by this
            method. Defaults to None.

    Returns:
        Tuple[torch.Tensor]: backbone outputs.
    """
    # Callers pass either the list of inputs or its first element, so
    # accept both instead of indexing a tensor a second time.
    img = inputs if isinstance(inputs, torch.Tensor) else inputs[0]
    # ``img`` is always bound before reaching the backbone; the Sobel
    # filter is only an optional extra step.
    if self.with_sobel:
        img = self.sobel_layer(img)
    return self.backbone(img)
|
||||
|
||||
def forward_train(self, img, idx, **kwargs):
|
||||
def forward_train(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
"""Forward computation during training.
|
||||
|
||||
Args:
|
||||
img (Tensor): Input images of shape (N, C, H, W).
|
||||
Typically these should be mean centered and std scaled.
|
||||
idx (Tensor): Index corresponding to each image.
|
||||
kwargs: Any keyword arguments to be used to forward.
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: A dictionary of loss components.
|
||||
Dict[str, torch.Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
# forward & backward
|
||||
feature = self.extract_feat(img)
|
||||
feature = self.extract_feat(inputs[0])
|
||||
idx = [data_sample.idx for data_sample in data_samples]
|
||||
idx = torch.cat(idx)
|
||||
if self.with_neck:
|
||||
feature = self.neck(feature)
|
||||
outs = self.head(feature)
|
||||
|
@ -104,20 +118,25 @@ class ODC(BaseModel):
|
|||
|
||||
return losses
|
||||
|
||||
def forward_test(self, inputs: List[torch.Tensor],
                 data_samples: List['SelfSupDataSample'],
                 **kwargs) -> List['SelfSupDataSample']:
    """The forward function in testing.

    Args:
        inputs (List[torch.Tensor]): The input images.
        data_samples (List[SelfSupDataSample]): All elements required
            during the forward function.

    Returns:
        List[SelfSupDataSample]: The given ``data_samples``, each with a
        ``prediction`` attribute holding one ``head{i}`` entry per head
        output.
    """
    # ``extract_feat`` takes the list of inputs and selects the image
    # tensor itself, so pass the list through rather than ``inputs[0]``.
    feature = self.extract_feat(inputs, data_samples)  # tuple
    if self.with_neck:
        feature = self.neck(feature)
    outs = self.head(feature)
    keys = [f'head{i}' for i in range(len(outs))]

    # Split every head output along its first dimension and attach the
    # i-th slice to the i-th data sample.
    for i in range(outs[0].shape[0]):
        prediction_data = {key: out[i] for key, out in zip(keys, outs)}
        data_samples[i].prediction = InstanceData(**prediction_data)
    return data_samples
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from mmselfsup.models.algorithms import ODC
|
||||
|
||||
# Number of classes shared by the classification head and the memory bank.
num_classes = 5

# Backbone config.
backbone = {
    'type': 'ResNet',
    'depth': 18,
    'in_channels': 3,
    'out_indices': [4],  # 0: conv-1, x: stage-x
    'norm_cfg': {'type': 'BN'},
}

# Neck config.
neck = {
    'type': 'ODCNeck',
    'in_channels': 512,
    'hid_channels': 2,
    'out_channels': 2,
    'norm_cfg': {'type': 'BN1d'},
    'with_avg_pool': True,
}

# Head config; its input width matches the neck's ``out_channels``.
head = {
    'type': 'ClsHead',
    'with_avg_pool': False,
    'in_channels': 2,
    'num_classes': num_classes,
}

# Loss config.
loss = {'type': 'CrossEntropyLoss'}

# Memory bank config; ``feat_dim`` matches the neck's ``out_channels``.
memory_bank = {
    'type': 'ODCMemory',
    'length': 8,
    'feat_dim': 2,
    'momentum': 0.5,
    'num_classes': num_classes,
    'min_cluster': 2,
    'debug': False,
}

# Image preprocessing config.
preprocess_cfg = dict(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
    to_rgb=True)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not torch.cuda.is_available() or platform.system() == 'Windows',
    reason='CUDA is not available or Windows mem limit')
def test_odc():
    """Check ODC construction errors and the test-mode forward output."""
    # A missing memory bank must be rejected.
    with pytest.raises(AssertionError):
        ODC(backbone=backbone,
            neck=neck,
            head=head,
            loss=loss,
            memory_bank=None,
            preprocess_cfg=preprocess_cfg)

    # A missing head must be rejected.
    with pytest.raises(AssertionError):
        ODC(backbone=backbone,
            neck=neck,
            head=None,
            memory_bank=memory_bank,
            preprocess_cfg=preprocess_cfg)

    # NOTE(review): this case used to pass exactly the same arguments as
    # the valid construction below, so it could never raise while that
    # construction succeeds. It presumably was meant to cover a missing
    # loss -- confirm that ODC.__init__ asserts ``loss is not None``.
    with pytest.raises(AssertionError):
        ODC(backbone=backbone,
            neck=neck,
            head=head,
            loss=None,
            memory_bank=memory_bank,
            preprocess_cfg=preprocess_cfg)

    # A complete configuration must build.
    alg = ODC(
        backbone=backbone,
        neck=neck,
        head=head,
        loss=loss,
        memory_bank=memory_bank,
        preprocess_cfg=preprocess_cfg)
    alg.set_reweight()

    # Two fake 3x224x224 images, each paired with an empty data sample.
    fake_data = [{
        'inputs': torch.randn((3, 224, 224)),
        'data_sample': SelfSupDataSample()
    } for _ in range(2)]

    # The returned samples must carry a ``head0`` prediction with one
    # value per class.
    fake_out = alg(fake_data, return_loss=False)
    assert hasattr(fake_out[0].prediction, 'head0')
    assert fake_out[0].prediction.head0.size() == torch.Size([num_classes])
|
Loading…
Reference in New Issue