[Refactor]: Refactor relative loc

pull/352/head
YuanLiuuuuuu 2022-06-16 08:09:37 +00:00 committed by fangyixiao18
parent 5ba17adb23
commit bcc4576ace
10 changed files with 148 additions and 211 deletions

View File

@@ -14,16 +14,6 @@ train_pipeline = [
pseudo_label_keys=['patch_box', 'patch_label', 'unpatched_img'],
meta_keys=['img_path'])
]
val_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=292),
dict(type='CenterCrop', crop_size=255),
dict(type='RandomPatchWithLabels'),
dict(
type='PackSelfSupInputs',
pseudo_label_keys=['patch_label'],
meta_keys=['img_path'])
]
train_dataloader = dict(
batch_size=64,
@@ -36,14 +26,3 @@ train_dataloader = dict(
ann_file='meta/train.txt',
data_prefix=dict(img_path='train/'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=64,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/val.txt',
data_prefix=dict(img_path='val/'),
pipeline=val_pipeline))

View File

@@ -1,6 +1,11 @@
# model settings
model = dict(
type='RelativeLoc',
data_preprocessor=dict(
type='mmselfsup.RelativeLocDataPreprocessor',
mean=[124, 117, 104],
std=[59, 58, 58],
bgr_to_rgb=True),
backbone=dict(
type='ResNet',
depth=50,
@@ -14,11 +19,11 @@ model = dict(
with_avg_pool=True),
head=dict(
type='ClsHead',
loss=dict(type='mmcls.CrossEntropyLoss'),
with_avg_pool=False,
in_channels=4096,
num_classes=8,
init_cfg=[
dict(type='Normal', std=0.005, layer='Linear'),
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]),
loss=dict(type='mmcls.CrossEntropyLoss'))
]))

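For reference, a torch-only sketch of what the new data_preprocessor config implies for a single image (illustrative shapes; the channel flip matches the bgr_to_rgb handling added to the preprocessor later in this commit):

import torch

mean = torch.tensor([124., 117., 104.]).view(3, 1, 1)
std = torch.tensor([59., 58., 58.]).view(3, 1, 1)

img_bgr = torch.randint(0, 256, (3, 224, 224)).float()
img_rgb = img_bgr[[2, 1, 0], ...]   # bgr_to_rgb=True flips the channel order
img_norm = (img_rgb - mean) / std   # per-channel normalization
print(img_norm.shape)               # torch.Size([3, 224, 224])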
View File

@@ -1,4 +1,4 @@
# optimizer
# optimizer wrapper
optimizer = dict(type='SGD', lr=0.03, weight_decay=1e-4, momentum=0.9)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)

View File

@@ -5,19 +5,18 @@ _base_ = [
'../_base_/default_runtime.py',
]
# optimizer
optimizer = dict(
type='SGD',
lr=0.2,
weight_decay=1e-4,
momentum=0.9,
paramwise_options={
'\\Aneck.': dict(weight_decay=5e-4),
'\\Ahead.': dict(weight_decay=5e-4)
})
# optimizer wrapper
optimizer = dict(type='SGD', lr=0.2, momentum=0.9, weight_decay=1e-4)
optim_wrapper = dict(
type='OptimWrapper',
optimizer=optimizer,
paramwise_cfg=dict(custom_keys={
'neck': dict(decay_mult=5.0),
'head': dict(decay_mult=5.0)
}))
# learning rate scheduler
scheduler = [
param_scheduler = [
dict(
type='LinearLR',
start_factor=0.1,
@@ -29,8 +28,10 @@ scheduler = [
]
# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=70)
# pre-train for 70 epochs
train_cfg = dict(max_epochs=70)
# the max_keep_ckpts controls the max number of ckpt files in your work_dirs
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
# it will remove the oldest one to keep the number of total ckpts as 3
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))

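A plain-Python sketch of how paramwise_cfg's custom_keys is assumed to resolve per-parameter weight decay (the real matching lives inside mmengine's optimizer wrapper constructor): with the base weight_decay=1e-4, a decay_mult of 5.0 on 'neck' and 'head' reproduces the old per-key weight_decay=5e-4.

base_wd = 1e-4
custom_keys = {'neck': dict(decay_mult=5.0), 'head': dict(decay_mult=5.0)}

def effective_wd(param_name: str) -> float:
    # Assumed rule: a key appearing in the parameter name scales the
    # base weight decay by its decay_mult.
    for key, cfg in custom_keys.items():
        if key in param_name:
            return base_wd * cfg['decay_mult']
    return base_wd

print(effective_wd('backbone.conv1.weight'))  # 1e-4 (base)
print(effective_wd('neck.fc0.weight'))        # 5e-4, as in the old config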
View File

@@ -143,12 +143,12 @@ class SelfSupDataSample(BaseDataElement):
del self._pred_label
@property
def pseudo_label(self) -> InstanceData:
def pseudo_label(self) -> BaseDataElement:
return self._pseudo_label
@pseudo_label.setter
def pseudo_label(self, value: InstanceData):
self.set_field(value, '_pseudo_label', dtype=InstanceData)
def pseudo_label(self, value: BaseDataElement):
self.set_field(value, '_pseudo_label', dtype=BaseDataElement)
@pseudo_label.deleter
def pseudo_label(self):

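A short usage sketch of the widened type: any BaseDataElement subclass (e.g. InstanceData, which the updated test below uses) is now accepted as pseudo_label.

import torch
from mmengine.data import InstanceData
from mmselfsup.core import SelfSupDataSample

sample = SelfSupDataSample()
pseudo_label = InstanceData()
pseudo_label.patch_label = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
sample.pseudo_label = pseudo_label  # stored with dtype=BaseDataElement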
View File

@@ -1,179 +1,86 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple
import torch
from mmengine.data import InstanceData
from mmengine.data import LabelData
from mmselfsup.core import SelfSupDataSample
from mmselfsup.utils import get_module_device
from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
build_neck)
from ..builder import MODELS
from .base import BaseModel
@ALGORITHMS.register_module()
@MODELS.register_module()
class RelativeLoc(BaseModel):
"""Relative patch location.
Implementation of `Unsupervised Visual Representation Learning
by Context Prediction <https://arxiv.org/abs/1505.05192>`_.
Args:
backbone (Dict, optional): Config dict for module of backbone.
Defaults to None.
neck (Dict, optional): Config dict for module of deep features
to compact feature vectors. Defaults to None.
head (Dict, optional): Config dict for module of head function.
Defaults to None.
loss (Dict): Config dict for module of loss function. Defaults to None.
preprocess_cfg (Dict, optional): Config dict to preprocess images.
Defaults to None.
init_cfg (Dict or List[Dict], optional): Config dict for weight
initialization. Defaults to None.
Implementation of `Unsupervised Visual Representation Learning by Context
Prediction <https://arxiv.org/abs/1505.05192>`_.
"""
def __init__(self,
backbone: Optional[Dict] = None,
neck: Optional[Dict] = None,
head: Optional[Dict] = None,
loss: Optional[Dict] = None,
preprocess_cfg: Optional[Dict] = None,
init_cfg: Optional[Union[Dict, List[Dict]]] = None,
**kwargs) -> None:
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
assert backbone is not None
self.backbone = build_backbone(backbone)
assert neck is not None
self.neck = build_neck(neck)
assert head is not None
self.head = build_head(head)
assert loss is not None
self.loss = build_loss(loss)
def extract_feat(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
def extract_feat(self, batch_inputs: List[torch.Tensor],
**kwargs) -> Tuple[torch.Tensor]:
"""Function to extract features from backbone.
Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
batch_inputs (List[torch.Tensor]): The input images.
Returns:
Tuple[torch.Tensor]: backbone outputs.
"""
x = self.backbone(inputs[0])
x = self.backbone(batch_inputs[0])
return x
def forward_train(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
def loss(self, batch_inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""Forward computation during training.
Args:
inputs (List[torch.Tensor]): The input images.
batch_inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
x1 = self.backbone(inputs[0])
x2 = self.backbone(inputs[1])
x1 = self.backbone(batch_inputs[0])
x2 = self.backbone(batch_inputs[1])
x = (torch.cat((x1[0], x2[0]), dim=1), )
x = self.neck(x)
outs = self.head(x)
patch_label = [
data_sample.patch_label.value for data_sample in data_samples
data_sample.pseudo_label.patch_label
for data_sample in data_samples
]
patch_label = torch.flatten(torch.stack(patch_label, 0))
loss = self.loss(outs[0], patch_label)
loss = self.head(x, patch_label)
losses = dict(loss=loss)
return losses
def forward_test(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> List[SelfSupDataSample]:
def predict(self, batch_inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> List[SelfSupDataSample]:
"""The forward function in testing.
Args:
inputs (List[torch.Tensor]): The input images.
batch_inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
List[SelfSupDataSample]: The prediction from model.
"""
x1 = self.backbone(inputs[0])
x2 = self.backbone(inputs[1])
x1 = self.backbone(batch_inputs[0])
x2 = self.backbone(batch_inputs[1])
x = (torch.cat((x1[0], x2[0]), dim=1), )
x = self.neck(x)
outs = self.head(x)
outs = self.head.logits(x)
keys = [f'head{i}' for i in self.backbone.out_indices]
outs = [torch.chunk(out, len(outs[0]) // 8, 0) for out in outs]
for i in range(len(outs[0])):
prediction_data = {key: out[i] for key, out in zip(keys, outs)}
prediction = InstanceData(**prediction_data)
data_samples[i].prediction = prediction
prediction = LabelData(**prediction_data)
data_samples[i].pred_label = prediction
return data_samples
def preprocss_data(
self,
data: List[Dict]) -> Tuple[List[torch.Tensor], SelfSupDataSample]:
"""Process input data during training, testing or extracting.
Args:
data (List[Dict]): The data to be processed, which
comes from dataloader.
Returns:
tuple: It should contain 2 items.
- batch_images (List[torch.Tensor]): The batch image tensor.
- data_samples (List[SelfSupDataSample], Optional): The Data
Samples. It usually includes information such as
`gt_label`. Returns None if the input data does not
contain `data_sample`.
"""
# data_['inputs'] is a list
images = [data_['inputs'] for data_ in data]
data_samples = [data_['data_sample'] for data_ in data]
device = get_module_device(self)
data_samples = [data_sample.to(device) for data_sample in data_samples]
images = [[img_.to(device) for img_ in img] for img in images]
# convert images to rgb
if self.to_rgb and images[0][0].size(0) == 3:
images = [[img_[[2, 1, 0], ...] for img_ in img] for img in images]
# normalize images
images = [[(img_ - self.mean_norm) / self.std_norm for img_ in img]
for img in images]
# reconstruct images into several batches. RelativeLoc needs
# nine crops for each image, and this code snippet will convert images
# into nine batches, each containing one crop of an image.
batch_images = []
for i in range(len(images[0])):
cur_batch = [img[i] for img in images]
batch_images.append(torch.stack(cur_batch))
img1 = torch.stack(batch_images[1:], 1) # Nx8xCxHxW
img1 = img1.view(
img1.size(0) * img1.size(1), img1.size(2), img1.size(3),
img1.size(4)) # (8N)xCxHxW
img2 = torch.unsqueeze(batch_images[0], 1).repeat(1, 8, 1, 1,
1) # Nx8xCxHxW
img2 = img2.view(
img2.size(0) * img2.size(1), img2.size(2), img2.size(3),
img2.size(4)) # (8N)xCxHxW
batch_images = [img1, img2]
return batch_images, data_samples

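For orientation, a usage sketch of the refactored entry points, mirroring the updated test at the end of this commit (it assumes model is a built RelativeLoc instance and data is a list of dicts from the dataloader):

batch_inputs, data_samples = model.data_preprocessor(data)  # replaces preprocss_data
losses = model(batch_inputs, data_samples, mode='loss')     # was forward_train
preds = model(batch_inputs, data_samples, mode='predict')   # was forward_test
feats = model(batch_inputs, data_samples, mode='tensor')    # calls extract_feat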
View File

@@ -5,14 +5,15 @@ import torch
import torch.nn as nn
from mmcv.runner import BaseModule
from ..builder import HEADS
from ..builder import MODELS
@HEADS.register_module()
@MODELS.register_module()
class ClsHead(BaseModule):
"""Simplest classifier head, with only one fc layer.
Args:
loss (dict): Config of the loss.
with_avg_pool (bool): Whether to apply the average pooling
after neck. Defaults to False.
in_channels (int): Number of input channels. Defaults to 2048.
@@ -22,6 +23,7 @@ class ClsHead(BaseModule):
def __init__(
self,
loss: dict,
with_avg_pool: Optional[bool] = False,
in_channels: Optional[int] = 2048,
num_classes: Optional[int] = 1000,
@@ -32,6 +34,7 @@ class ClsHead(BaseModule):
]
) -> None:
super().__init__(init_cfg)
self.loss = MODELS.build(loss)
self.with_avg_pool = with_avg_pool
self.in_channels = in_channels
self.num_classes = num_classes
@@ -41,10 +44,12 @@ class ClsHead(BaseModule):
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc_cls = nn.Linear(in_channels, num_classes)
def forward(
def logits(
self, x: Union[List[torch.Tensor],
Tuple[torch.Tensor]]) -> List[torch.Tensor]:
"""Forward head.
"""Get the logits before the cross_entropy loss.
This module is used to obtain the logits before the loss.
Args:
x (List[Tensor] | Tuple[Tensor]): Feature maps of backbone,
@@ -64,3 +69,16 @@ class ClsHead(BaseModule):
x = x.view(x.size(0), -1)
cls_score = self.fc_cls(x)
return [cls_score]
def forward(self, x: Union[List[torch.Tensor], Tuple[torch.Tensor]],
label: torch.Tensor) -> torch.Tensor:
"""Get the loss.
Args:
x (List[Tensor] | Tuple[Tensor]): Feature maps of backbone,
each tensor has shape (N, C, H, W).
label (torch.Tensor): The label for cross entropy loss.
"""
outs = self.logits(x)
loss = self.loss(outs[0], label)
return loss

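A torch-only sketch of the new head contract (MiniClsHead is a hypothetical stand-in, not the real class): logits returns raw class scores, while forward now takes labels and returns the loss directly.

import torch
import torch.nn as nn

class MiniClsHead(nn.Module):
    """Toy stand-in mirroring the refactored ClsHead interface."""

    def __init__(self, in_channels: int = 4096, num_classes: int = 8) -> None:
        super().__init__()
        self.fc_cls = nn.Linear(in_channels, num_classes)
        self.loss = nn.CrossEntropyLoss()

    def logits(self, x):
        # x is a list/tuple of feature maps, as in the real head
        flat = x[0].view(x[0].size(0), -1)
        return [self.fc_cls(flat)]

    def forward(self, x, label):
        return self.loss(self.logits(x)[0], label)

head = MiniClsHead()
feats = (torch.randn(16, 4096, 1, 1), )
label = torch.randint(0, 8, (16, ))
print(head.logits(feats)[0].shape)  # torch.Size([16, 8])
print(head(feats, label))           # scalar loss tensor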
View File

@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dall_e import Encoder
from .data_preprocessor import SelfSupDataPreprocessor
from .data_preprocessor import (RelativeLocDataPreprocessor,
SelfSupDataPreprocessor)
from .ema import CosineEMA
from .extractor import Extractor
from .gather_layer import GatherLayer
@@ -15,5 +16,5 @@ __all__ = [
'Extractor', 'GatherLayer', 'MultiPooling', 'MultiPrototypes',
'build_2d_sincos_position_embedding', 'Sobel', 'MultiheadAttention',
'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'Encoder',
'CosineEMA', 'SelfSupDataPreprocessor'
'CosineEMA', 'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor'
]

View File

@@ -91,3 +91,58 @@ class SelfSupDataPreprocessor(ImgDataPreprocessor):
batch_inputs.append(torch.stack(cur_batch))
return batch_inputs, batch_data_samples
@MODELS.register_module()
class RelativeLocDataPreprocessor(SelfSupDataPreprocessor):
"""Image pre-processor for Relative Location."""
def forward(
self,
data: Sequence[dict],
training: bool = False
) -> Tuple[List[torch.Tensor], Optional[list]]:
"""Performs normalization、padding and bgr2rgb conversion based on
``BaseDataPreprocessor``.
Args:
data (Sequence[dict]): data sampled from dataloader.
training (bool): Whether to enable training time augmentation. If
subclasses override this method, they can perform different
preprocessing strategies for training and testing based on the
value of ``training``.
Returns:
Tuple[List[torch.Tensor], Optional[list]]: Data in the same format as the
model input.
"""
inputs, batch_data_samples = self.collate_data(data)
# channel transform
if self.channel_conversion:
inputs = [[img_[[2, 1, 0], ...] for img_ in _input]
for _input in inputs]
# Normalization. Here is what is different from
# :class:`mmengine.ImgDataPreprocessor`. Since there are multiple views
# for an image for some algorithms, e.g. SimCLR, each item in inputs
# is a list, containing multi-views for an image.
inputs = [[(img_ - self.mean) / self.std for img_ in _input]
for _input in inputs]
batch_inputs = []
for i in range(len(inputs[0])):
cur_batch = [img[i] for img in inputs]
batch_inputs.append(torch.stack(cur_batch))
# This part is unique to Relative Loc
img1 = torch.stack(batch_inputs[1:], 1) # Nx8xCxHxW
img1 = img1.view(
img1.size(0) * img1.size(1), img1.size(2), img1.size(3),
img1.size(4)) # (8N)xCxHxW
img2 = torch.unsqueeze(batch_inputs[0], 1).repeat(1, 8, 1, 1,
1) # Nx8xCxHxW
img2 = img2.view(
img2.size(0) * img2.size(1), img2.size(2), img2.size(3),
img2.size(4)) # (8N)xCxHxW
batch_inputs = [img1, img2]
return batch_inputs, batch_data_samples

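A standalone shape check of the RelativeLoc-specific pairing above (illustrative sizes): the eight neighbour crops are flattened into one (8N) batch and the centre crop is repeated eight times, so the two batches align index by index.

import torch

N, C, H, W = 2, 3, 32, 32
# nine crops per image: index 0 is the centre, 1-8 are the neighbours
batch_inputs = [torch.randn(N, C, H, W) for _ in range(9)]

img1 = torch.stack(batch_inputs[1:], 1).view(N * 8, C, H, W)
img2 = batch_inputs[0].unsqueeze(1).repeat(1, 8, 1, 1, 1).view(N * 8, C, H, W)
assert img1.shape == img2.shape == (N * 8, C, H, W)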
View File

@@ -1,10 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
import pytest
import torch
from mmengine.data import LabelData
from mmengine.data import InstanceData
from mmselfsup.core.data_structures.selfsup_data_sample import \
SelfSupDataSample
@@ -23,6 +22,7 @@ neck = dict(
with_avg_pool=True)
head = dict(
type='ClsHead',
loss=dict(type='mmcls.CrossEntropyLoss'),
with_avg_pool=False,
in_channels=32,
num_classes=8,
@@ -30,51 +30,22 @@ head = dict(
dict(type='Normal', std=0.005, layer='Linear'),
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
])
loss = dict(type='mmcls.CrossEntropyLoss')
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_relative_loc():
preprocess_cfg = {
data_preprocessor = {
'type': 'mmselfsup.RelativeLocDataPreprocessor',
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'to_rgb': True
'bgr_to_rgb': True
}
with pytest.raises(AssertionError):
alg = RelativeLoc(
backbone=backbone,
neck=None,
head=head,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = RelativeLoc(
backbone=backbone,
neck=neck,
head=None,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = RelativeLoc(
backbone=None,
neck=neck,
head=head,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = RelativeLoc(
backbone=backbone,
neck=neck,
head=head,
loss=None,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
alg = RelativeLoc(
backbone=backbone,
neck=neck,
head=head,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
alg.init_weights()
data_preprocessor=data_preprocessor)
batch_size = 5
fake_data = [{
@@ -89,19 +60,19 @@ def test_relative_loc():
SelfSupDataSample()
} for _ in range(batch_size)]
patch_label = LabelData()
patch_label.value = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
pseudo_label = InstanceData()
pseudo_label.patch_label = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
for i in range(batch_size):
fake_data[i]['data_sample'].patch_label = patch_label
fake_data[i]['data_sample'].pseudo_label = pseudo_label
fake_outputs = alg(fake_data, return_loss=True)
fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
assert isinstance(fake_outputs['loss'].item(), float)
test_results = alg(fake_data, return_loss=False)
test_results = alg(fake_batch_inputs, fake_data_samples, mode='predict')
assert len(test_results) == len(fake_data)
assert list(test_results[0].prediction.head4.shape) == [8, 8]
assert list(test_results[0].pred_label.head4.shape) == [8, 8]
fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
fake_feat = alg.extract_feat(
inputs=fake_inputs, data_samples=fake_data_samples)
fake_feat = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
assert list(fake_feat[0].shape) == [batch_size * 8, 512, 1, 1]