[Refactor] refactor rotation pred

pull/352/head
fangyixiao.vendor 2022-06-25 04:50:52 +00:00 committed by fangyixiao18
parent 87ed42aaeb
commit bc199203a0
8 changed files with 120 additions and 179 deletions

View File

@ -1,11 +1,12 @@
# dataset settings
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
dataset_type = 'mmcls.ImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='RandomResizedCrop', size=224),
dict(type='RandomResizedCrop', size=224, backend='pillow'),
dict(type='RandomFlip', prob=0.5),
dict(type='RotationWithLabels'),
dict(
@ -13,16 +14,6 @@ train_pipeline = [
pseudo_label_keys=['rot_label'],
meta_keys=['img_path'])
]
val_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=256),
dict(type='CenterCrop', crop_size=224),
dict(type='RotationWithLabels'),
dict(
type='PackSelfSupInputs',
pseudo_label_keys=['rot_label'],
meta_keys=['img_path'])
]
train_dataloader = dict(
batch_size=16,
@ -35,14 +26,3 @@ train_dataloader = dict(
ann_file='meta/train.txt',
data_prefix=dict(img_path='train/'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=16,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/val.txt',
data_prefix=dict(img_path='val/'),
pipeline=val_pipeline))

View File

@ -1,6 +1,11 @@
# model settings
model = dict(
type='RotationPred',
data_preprocessor=dict(
type='mmselfsup.RotationPredDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='ResNet',
depth=50,
@ -8,5 +13,8 @@ model = dict(
out_indices=[4], # 0: conv-1, x: stage-x
norm_cfg=dict(type='SyncBN')),
head=dict(
type='ClsHead', with_avg_pool=True, in_channels=2048, num_classes=4),
loss=dict(type='mmcls.CrossEntropyLoss'))
type='ClsHead',
loss=dict(type='mmcls.CrossEntropyLoss'),
with_avg_pool=True,
in_channels=2048,
num_classes=4))

View File

@ -1,4 +1,4 @@
_base_ = 'rotation-pred_resnet50_8xb16-steplr-70e_in1k.py'
# fp16
fp16 = dict(loss_scale=512.)
# mixed precision
optim_wrapper = dict(type='AmpOptimWrapper')

View File

@ -7,9 +7,10 @@ _base_ = [
# optimizer
optimizer = dict(type='SGD', lr=0.2, momentum=0.9, weight_decay=1e-4)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
# learning rate scheduler
scheduler = [
param_scheduler = [
dict(
type='LinearLR',
start_factor=0.1,
@ -21,8 +22,4 @@ scheduler = [
]
# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=70)
# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
# it will remove the oldest one to keep the number of total ckpts as 3
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=70)

View File

@ -1,73 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple
import torch
from mmengine.data import InstanceData
from mmengine.data import LabelData
from mmselfsup.core import SelfSupDataSample
from mmselfsup.utils import get_module_device
from ..builder import ALGORITHMS, build_backbone, build_head, build_loss
from mmselfsup.registry import MODELS
from .base import BaseModel
@ALGORITHMS.register_module()
@MODELS.register_module()
class RotationPred(BaseModel):
"""Rotation prediction.
Implementation of `Unsupervised Representation Learning
by Predicting Image Rotations <https://arxiv.org/abs/1803.07728>`_.
Args:
backbone (Dict, optional): Config dict for module of backbone.
head (Dict, optional): Config dict for module of head.
Defaults to None.
loss (Dict, optional): Config dict for module of loss functions.
Defaults to None.
preprocess_cfg (Dict, optional): Config dict to preprocess images.
Defaults to None.
init_cfg (Dict or List[Dict], optional): Config dict for weight
initialization. Defaults to None.
Implementation of `Unsupervised Representation Learning by Predicting Image
Rotations <https://arxiv.org/abs/1803.07728>`_.
"""
def __init__(self,
backbone: Optional[Dict] = None,
head: Optional[Dict] = None,
loss: Optional[Dict] = None,
preprocess_cfg: Optional[Dict] = None,
init_cfg: Optional[Union[Dict, List[Dict]]] = None,
**kwargs) -> None:
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
assert backbone is not None
self.backbone = build_backbone(backbone)
assert head is not None
self.head = build_head(head)
assert loss is not None
self.loss = build_loss(loss)
def extract_feat(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
def extract_feat(self, batch_inputs: List[torch.Tensor],
**kwargs) -> Tuple[torch.Tensor]:
"""Function to extract features from backbone.
Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
batch_inputs (List[torch.Tensor]): The input images.
Returns:
Tuple[torch.Tensor]: backbone outputs.
"""
x = self.backbone(inputs[0])
x = self.backbone(batch_inputs[0])
return x
def forward_train(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
def loss(self, batch_inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""Forward computation during training.
Args:
inputs (List[torch.Tensor]): The input images.
batch_inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
@ -75,24 +45,23 @@ class RotationPred(BaseModel):
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
x = self.backbone(inputs[0])
outs = self.head(x)
x = self.backbone(batch_inputs[0])
rot_label = [
data_sample.rot_label.value for data_sample in data_samples
data_sample.pseudo_label.rot_label for data_sample in data_samples
]
rot_label = torch.flatten(torch.stack(rot_label, 0)) # (4N, )
loss = self.loss(outs[0], rot_label)
loss = self.head(x, rot_label)
losses = dict(loss=loss)
return losses
def forward_test(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> List[SelfSupDataSample]:
def predict(self, batch_inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> List[SelfSupDataSample]:
"""The forward function in testing.
Args:
inputs (List[torch.Tensor]): The input images.
batch_inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
@ -100,62 +69,13 @@ class RotationPred(BaseModel):
List[SelfSupDataSample]: The prediction from model.
"""
x = self.backbone(inputs[0]) # tuple
outs = self.head(x)
x = self.backbone(batch_inputs[0]) # tuple
outs = self.head.logits(x)
keys = [f'head{i}' for i in self.backbone.out_indices]
outs = [torch.chunk(out, len(outs[0]) // 4, 0) for out in outs]
for i in range(len(outs[0])):
prediction_data = {key: out[i] for key, out in zip(keys, outs)}
prediction = InstanceData(**prediction_data)
data_samples[i].prediction = prediction
prediction = LabelData(**prediction_data)
data_samples[i].pred_score = prediction
return data_samples
def preprocss_data(
self,
data: List[Dict]) -> Tuple[List[torch.Tensor], SelfSupDataSample]:
"""Move a raw batch to the module device, normalize it and merge views.
Args:
data (List[Dict]): The data to be processed, which comes from the
dataloader. Every item is indexed with ``'inputs'`` and
``'data_sample'``.
Returns:
tuple: It contains 2 items.
- batch_images (List[torch.Tensor]): A single-element list holding
all views of all images merged along the first dimension.
- data_samples (List[SelfSupDataSample]): The data samples moved to
the module device.
"""
# data_['inputs'] is a list (the views of one image)
images = [data_['inputs'] for data_ in data]
data_samples = [data_['data_sample'] for data_ in data]
# move data samples and every view to the device of this module
device = get_module_device(self)
data_samples = [data_sample.to(device) for data_sample in data_samples]
images = [[img_.to(device) for img_ in img] for img in images]
# reverse the channel order (bgr -> rgb) only for 3-channel images;
# the channel count is checked on the first view of the first image
if self.to_rgb and images[0][0].size(0) == 3:
images = [[img_[[2, 1, 0], ...] for img_ in img] for img in images]
# normalize every view separately
images = [[(img_ - self.mean_norm) / self.std_norm for img_ in img]
for img in images]
# Reconstruct images into several batches. RotationPred needs
# four views for each image, and this code snippet will convert images
# into four batches, each containing one view of an image.
# The number of views is taken from the first image.
batch_images = []
for i in range(len(images[0])):
cur_batch = [img[i] for img in images]
batch_images.append(torch.stack(cur_batch))
# stack on dim 1 so that the views of one image stay adjacent after the
# two leading dimensions are merged below
img = torch.stack(batch_images, 1) # Nx4xCxHxW
img = img.view(
img.size(0) * img.size(1), img.size(2), img.size(3),
img.size(4)) # (4N)xCxHxW
batch_images = [img]
return batch_images, data_samples

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dall_e import Encoder
from .data_preprocessor import (RelativeLocDataPreprocessor,
RotationPredDataPreprocessor,
SelfSupDataPreprocessor)
from .ema import CosineEMA
from .extractor import Extractor
@ -16,5 +17,6 @@ __all__ = [
'Extractor', 'GatherLayer', 'MultiPooling', 'MultiPrototypes',
'build_2d_sincos_position_embedding', 'Sobel', 'MultiheadAttention',
'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'Encoder',
'CosineEMA', 'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor'
'CosineEMA', 'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor',
'RotationPredDataPreprocessor'
]

View File

@ -146,3 +146,53 @@ class RelativeLocDataPreprocessor(SelfSupDataPreprocessor):
batch_inputs = [img1, img2]
return batch_inputs, batch_data_samples
@MODELS.register_module()
class RotationPredDataPreprocessor(SelfSupDataPreprocessor):
    """Image pre-processor for Rotation Prediction.

    Besides the optional channel conversion and the normalization applied to
    every view, all views of all images are merged into one tensor along the
    first dimension, so that the model receives a single batch.
    """

    def forward(
            self,
            data: Sequence[dict],
            training: bool = False
    ) -> Tuple[List[torch.Tensor], Optional[list]]:
        """Perform channel conversion, normalization and view merging.

        No padding is applied; all views are stacked as they are.

        Args:
            data (Sequence[dict]): data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.
                It is accepted for interface compatibility and is not used
                by this implementation.

        Returns:
            Tuple[List[torch.Tensor], Optional[list]]: A single-element list
            holding the merged image tensor, and the data samples exactly as
            returned by ``self.collate_data``.
        """
        inputs, batch_data_samples = self.collate_data(data)

        # Each item in ``inputs`` is the list of views of one image.
        # channel transform: reverse the order of the first dimension of
        # every view (bgr <-> rgb)
        if self.channel_conversion:
            inputs = [[view[[2, 1, 0], ...] for view in views]
                      for views in inputs]

        # Normalization. Here is what is different from
        # :class:`mmengine.ImgDataPreprocessor`: each item in inputs is a
        # list containing the multiple (rotated) views of an image, so the
        # normalization is applied view by view.
        inputs = [[(view - self.mean) / self.std for view in views]
                  for views in inputs]

        # One batch per view index: entry ``i`` stacks the i-th view of every
        # image. The number of views is taken from the first image.
        num_views = len(inputs[0])
        batch_inputs = [
            torch.stack([views[i] for views in inputs])
            for i in range(num_views)
        ]

        # This part is unique to Rotation Pred. Stacking on dim 1 keeps the
        # views of one image adjacent once the two leading dimensions are
        # merged (presumably V = 4 rotations per image - TODO confirm against
        # the ``RotationWithLabels`` transform).
        img = torch.stack(batch_inputs, 1)  # N x V x C x H x W
        img = img.view(img.size(0) * img.size(1),
                       *img.shape[2:])  # (N*V) x C x H x W

        return [img], batch_data_samples

View File

@ -4,7 +4,7 @@ import platform
import pytest
import torch
from mmengine.data import LabelData
from mmengine.data import InstanceData
from mmselfsup.core.data_structures.selfsup_data_sample import \
SelfSupDataSample
@ -16,41 +16,25 @@ backbone = dict(
in_channels=3,
out_indices=[4], # 0: conv-1, x: stage-x
norm_cfg=dict(type='BN'))
head = dict(type='ClsHead', with_avg_pool=True, in_channels=512, num_classes=4)
loss = dict(type='mmcls.CrossEntropyLoss')
head = dict(
type='ClsHead',
loss=dict(type='mmcls.CrossEntropyLoss'),
with_avg_pool=True,
in_channels=512,
num_classes=4)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_relative_loc():
preprocess_cfg = {
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'to_rgb': True
}
with pytest.raises(AssertionError):
alg = RotationPred(
backbone=backbone,
head=None,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = RotationPred(
backbone=None,
head=head,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
with pytest.raises(AssertionError):
alg = RotationPred(
backbone=backbone,
head=head,
loss=None,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
def test_rotation_pred():
data_preprocessor = dict(
type='mmselfsup.RotationPredDataPreprocessor',
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
bgr_to_rgb=True)
alg = RotationPred(
backbone=backbone,
head=head,
loss=loss,
preprocess_cfg=copy.deepcopy(preprocess_cfg))
alg.init_weights()
data_preprocessor=copy.deepcopy(data_preprocessor))
bach_size = 5
fake_data = [{
@ -62,19 +46,19 @@ def test_relative_loc():
SelfSupDataSample()
} for _ in range(bach_size)]
rot_label = LabelData()
rot_label.value = torch.tensor([0, 1, 2, 3])
pseudo_label = InstanceData()
pseudo_label.rot_label = torch.tensor([0, 1, 2, 3])
for i in range(bach_size):
fake_data[i]['data_sample'].rot_label = rot_label
fake_data[i]['data_sample'].pseudo_label = pseudo_label
fake_outputs = alg(fake_data, return_loss=True)
assert isinstance(fake_outputs['loss'].item(), float)
fake_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
test_results = alg(fake_data, return_loss=False)
assert len(test_results) == len(fake_data)
assert list(test_results[0].prediction.head4.shape) == [4, 4]
fake_loss = alg(fake_inputs, fake_data_samples, mode='loss')
assert isinstance(fake_loss['loss'].item(), float)
fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
fake_feat = alg.extract_feat(
inputs=fake_inputs, data_samples=fake_data_samples)
assert list(fake_feat[0].shape) == [bach_size * 4, 512, 1, 1]
fake_prediction = alg(fake_inputs, fake_data_samples, mode='predict')
assert len(fake_prediction) == len(fake_data)
assert list(fake_prediction[0].pred_score.head4.shape) == [4, 4]
fake_feats = alg(fake_inputs, fake_data_samples, mode='tensor')
assert list(fake_feats[0].shape) == [bach_size * 4, 512, 1, 1]