[Refactor] refactor rotation pred
parent
87ed42aaeb
commit
bc199203a0
|
@ -1,11 +1,12 @@
|
|||
# dataset settings
|
||||
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
|
||||
dataset_type = 'mmcls.ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
file_client_args = dict(backend='disk')
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(type='RandomResizedCrop', size=224),
|
||||
dict(type='RandomResizedCrop', size=224, backend='pillow'),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='RotationWithLabels'),
|
||||
dict(
|
||||
|
@ -13,16 +14,6 @@ train_pipeline = [
|
|||
pseudo_label_keys=['rot_label'],
|
||||
meta_keys=['img_path'])
|
||||
]
|
||||
val_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(type='Resize', scale=256),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='RotationWithLabels'),
|
||||
dict(
|
||||
type='PackSelfSupInputs',
|
||||
pseudo_label_keys=['rot_label'],
|
||||
meta_keys=['img_path'])
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=16,
|
||||
|
@ -35,14 +26,3 @@ train_dataloader = dict(
|
|||
ann_file='meta/train.txt',
|
||||
data_prefix=dict(img_path='train/'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=16,
|
||||
num_workers=2,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix=dict(img_path='val/'),
|
||||
pipeline=val_pipeline))
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='RotationPred',
|
||||
data_preprocessor=dict(
|
||||
type='mmselfsup.RotationPredDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True),
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
|
@ -8,5 +13,8 @@ model = dict(
|
|||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='SyncBN')),
|
||||
head=dict(
|
||||
type='ClsHead', with_avg_pool=True, in_channels=2048, num_classes=4),
|
||||
loss=dict(type='mmcls.CrossEntropyLoss'))
|
||||
type='ClsHead',
|
||||
loss=dict(type='mmcls.CrossEntropyLoss'),
|
||||
with_avg_pool=True,
|
||||
in_channels=2048,
|
||||
num_classes=4))
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
_base_ = 'rotation-pred_resnet50_8xb16-steplr-70e_in1k.py'
|
||||
|
||||
# fp16
|
||||
fp16 = dict(loss_scale=512.)
|
||||
# mixed precision
|
||||
optim_wrapper = dict(type='AmpOptimWrapper')
|
||||
|
|
|
@ -7,9 +7,10 @@ _base_ = [
|
|||
|
||||
# optimizer
|
||||
optimizer = dict(type='SGD', lr=0.2, momentum=0.9, weight_decay=1e-4)
|
||||
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = [
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=0.1,
|
||||
|
@ -21,8 +22,4 @@ scheduler = [
|
|||
]
|
||||
|
||||
# runtime settings
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=70)
|
||||
# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
|
||||
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
|
||||
# it will remove the oldest one to keep the number of total ckpts as 3
|
||||
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=70)
|
||||
|
|
|
@ -1,73 +1,43 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from mmengine.data import InstanceData
|
||||
from mmengine.data import LabelData
|
||||
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from mmselfsup.utils import get_module_device
|
||||
from ..builder import ALGORITHMS, build_backbone, build_head, build_loss
|
||||
from mmselfsup.registry import MODELS
|
||||
from .base import BaseModel
|
||||
|
||||
|
||||
@ALGORITHMS.register_module()
|
||||
@MODELS.register_module()
|
||||
class RotationPred(BaseModel):
|
||||
"""Rotation prediction.
|
||||
|
||||
Implementation of `Unsupervised Representation Learning
|
||||
by Predicting Image Rotations <https://arxiv.org/abs/1803.07728>`_.
|
||||
|
||||
Args:
|
||||
backbone (Dict, optional): Config dict for module of backbone.
|
||||
head (Dict, optional): Config dict for module of loss functions.
|
||||
Defaults to None.
|
||||
loss (Dict, optional): Config dict for module of loss functions.
|
||||
Defaults to None.
|
||||
preprocess_cfg (Dict, optional): Config dict to preprocess images.
|
||||
Defaults to None.
|
||||
init_cfg (Dict or List[Dict], optional): Config dict for weight
|
||||
initialization. Defaults to None.
|
||||
Implementation of `Unsupervised Representation Learning by Predicting Image
|
||||
Rotations <https://arxiv.org/abs/1803.07728>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone: Optional[Dict] = None,
|
||||
head: Optional[Dict] = None,
|
||||
loss: Optional[Dict] = None,
|
||||
preprocess_cfg: Optional[Dict] = None,
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None,
|
||||
**kwargs) -> None:
|
||||
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
|
||||
assert backbone is not None
|
||||
self.backbone = build_backbone(backbone)
|
||||
assert head is not None
|
||||
self.head = build_head(head)
|
||||
assert loss is not None
|
||||
self.loss = build_loss(loss)
|
||||
|
||||
def extract_feat(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
def extract_feat(self, batch_inputs: List[torch.Tensor],
|
||||
**kwargs) -> Tuple[torch.Tensor]:
|
||||
"""Function to extract features from backbone.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
batch_inputs (List[torch.Tensor]): The input images.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor]: backbone outputs.
|
||||
"""
|
||||
|
||||
x = self.backbone(inputs[0])
|
||||
x = self.backbone(batch_inputs[0])
|
||||
return x
|
||||
|
||||
def forward_train(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
def loss(self, batch_inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
"""Forward computation during training.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
batch_inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
|
@ -75,24 +45,23 @@ class RotationPred(BaseModel):
|
|||
Dict[str, torch.Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
|
||||
x = self.backbone(inputs[0])
|
||||
outs = self.head(x)
|
||||
x = self.backbone(batch_inputs[0])
|
||||
|
||||
rot_label = [
|
||||
data_sample.rot_label.value for data_sample in data_samples
|
||||
data_sample.pseudo_label.rot_label for data_sample in data_samples
|
||||
]
|
||||
rot_label = torch.flatten(torch.stack(rot_label, 0)) # (4N, )
|
||||
loss = self.loss(outs[0], rot_label)
|
||||
loss = self.head(x, rot_label)
|
||||
losses = dict(loss=loss)
|
||||
return losses
|
||||
|
||||
def forward_test(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> List[SelfSupDataSample]:
|
||||
def predict(self, batch_inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> List[SelfSupDataSample]:
|
||||
"""The forward function in testing.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
batch_inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
|
@ -100,62 +69,13 @@ class RotationPred(BaseModel):
|
|||
List[SelfSupDataSample]: The prediction from model.
|
||||
"""
|
||||
|
||||
x = self.backbone(inputs[0]) # tuple
|
||||
outs = self.head(x)
|
||||
x = self.backbone(batch_inputs[0]) # tuple
|
||||
outs = self.head.logits(x)
|
||||
keys = [f'head{i}' for i in self.backbone.out_indices]
|
||||
outs = [torch.chunk(out, len(outs[0]) // 4, 0) for out in outs]
|
||||
|
||||
for i in range(len(outs[0])):
|
||||
prediction_data = {key: out[i] for key, out in zip(keys, outs)}
|
||||
prediction = InstanceData(**prediction_data)
|
||||
data_samples[i].prediction = prediction
|
||||
prediction = LabelData(**prediction_data)
|
||||
data_samples[i].pred_score = prediction
|
||||
return data_samples
|
||||
|
||||
def preprocss_data(
|
||||
self,
|
||||
data: List[Dict]) -> Tuple[List[torch.Tensor], SelfSupDataSample]:
|
||||
"""Process input data during training, testing or extracting.
|
||||
|
||||
Args:
|
||||
data (List[Dict]): The data to be processed, which
|
||||
comes from dataloader.
|
||||
|
||||
Returns:
|
||||
tuple: It should contain 2 item.
|
||||
- batch_images (List[torch.Tensor]): The batch image tensor.
|
||||
- data_samples (List[SelfSupDataSample], Optional): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_label`. Return None If the input data does not
|
||||
contain `data_sample`.
|
||||
"""
|
||||
# data_['inputs] is a list
|
||||
images = [data_['inputs'] for data_ in data]
|
||||
data_samples = [data_['data_sample'] for data_ in data]
|
||||
|
||||
device = get_module_device(self)
|
||||
data_samples = [data_sample.to(device) for data_sample in data_samples]
|
||||
images = [[img_.to(device) for img_ in img] for img in images]
|
||||
|
||||
# convert images to rgb
|
||||
if self.to_rgb and images[0][0].size(0) == 3:
|
||||
images = [[img_[[2, 1, 0], ...] for img_ in img] for img in images]
|
||||
|
||||
# normalize images
|
||||
images = [[(img_ - self.mean_norm) / self.std_norm for img_ in img]
|
||||
for img in images]
|
||||
|
||||
# reconstruct images into several batches. RotationPred needs
|
||||
# four views for each image, and this code snippet will convert images
|
||||
# into four batches, each containing one view of an image.
|
||||
batch_images = []
|
||||
for i in range(len(images[0])):
|
||||
cur_batch = [img[i] for img in images]
|
||||
batch_images.append(torch.stack(cur_batch))
|
||||
|
||||
img = torch.stack(batch_images, 1) # Nx4xCxHxW
|
||||
img = img.view(
|
||||
img.size(0) * img.size(1), img.size(2), img.size(3),
|
||||
img.size(4)) # (4N)xCxHxW
|
||||
batch_images = [img]
|
||||
|
||||
return batch_images, data_samples
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .dall_e import Encoder
|
||||
from .data_preprocessor import (RelativeLocDataPreprocessor,
|
||||
RotationPredDataPreprocessor,
|
||||
SelfSupDataPreprocessor)
|
||||
from .ema import CosineEMA
|
||||
from .extractor import Extractor
|
||||
|
@ -16,5 +17,6 @@ __all__ = [
|
|||
'Extractor', 'GatherLayer', 'MultiPooling', 'MultiPrototypes',
|
||||
'build_2d_sincos_position_embedding', 'Sobel', 'MultiheadAttention',
|
||||
'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'Encoder',
|
||||
'CosineEMA', 'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor'
|
||||
'CosineEMA', 'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor',
|
||||
'RotationPredDataPreprocessor'
|
||||
]
|
||||
|
|
|
@ -146,3 +146,53 @@ class RelativeLocDataPreprocessor(SelfSupDataPreprocessor):
|
|||
batch_inputs = [img1, img2]
|
||||
|
||||
return batch_inputs, batch_data_samples
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class RotationPredDataPreprocessor(SelfSupDataPreprocessor):
|
||||
"""Image pre-processor for Relative Location."""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
data: Sequence[dict],
|
||||
training: bool = False
|
||||
) -> Tuple[List[torch.Tensor], Optional[list]]:
|
||||
"""Performs normalization、padding and bgr2rgb conversion based on
|
||||
``BaseDataPreprocessor``.
|
||||
|
||||
Args:
|
||||
data (Sequence[dict]): data sampled from dataloader.
|
||||
training (bool): Whether to enable training time augmentation. If
|
||||
subclasses override this method, they can perform different
|
||||
preprocessing strategies for training and testing based on the
|
||||
value of ``training``.
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
|
||||
model input.
|
||||
"""
|
||||
inputs, batch_data_samples = self.collate_data(data)
|
||||
# channel transform
|
||||
if self.channel_conversion:
|
||||
inputs = [[img_[[2, 1, 0], ...] for img_ in _input]
|
||||
for _input in inputs]
|
||||
|
||||
# Normalization. Here is what is different from
|
||||
# :class:`mmengine.ImgDataPreprocessor`. Since there are multiple views
|
||||
# for an image for some algorithms, e.g. SimCLR, each item in inputs
|
||||
# is a list, containing multi-views for an image.
|
||||
inputs = [[(img_ - self.mean) / self.std for img_ in _input]
|
||||
for _input in inputs]
|
||||
|
||||
batch_inputs = []
|
||||
for i in range(len(inputs[0])):
|
||||
cur_batch = [img[i] for img in inputs]
|
||||
batch_inputs.append(torch.stack(cur_batch))
|
||||
|
||||
# This part is unique to Rotation Pred
|
||||
img = torch.stack(batch_inputs, 1) # Nx4xCxHxW
|
||||
img = img.view(
|
||||
img.size(0) * img.size(1), img.size(2), img.size(3),
|
||||
img.size(4)) # (4N)xCxHxW
|
||||
batch_inputs = [img]
|
||||
|
||||
return batch_inputs, batch_data_samples
|
||||
|
|
|
@ -4,7 +4,7 @@ import platform
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.data import LabelData
|
||||
from mmengine.data import InstanceData
|
||||
|
||||
from mmselfsup.core.data_structures.selfsup_data_sample import \
|
||||
SelfSupDataSample
|
||||
|
@ -16,41 +16,25 @@ backbone = dict(
|
|||
in_channels=3,
|
||||
out_indices=[4], # 0: conv-1, x: stage-x
|
||||
norm_cfg=dict(type='BN'))
|
||||
head = dict(type='ClsHead', with_avg_pool=True, in_channels=512, num_classes=4)
|
||||
loss = dict(type='mmcls.CrossEntropyLoss')
|
||||
head = dict(
|
||||
type='ClsHead',
|
||||
loss=dict(type='mmcls.CrossEntropyLoss'),
|
||||
with_avg_pool=True,
|
||||
in_channels=512,
|
||||
num_classes=4)
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_relative_loc():
|
||||
preprocess_cfg = {
|
||||
'mean': [0.5, 0.5, 0.5],
|
||||
'std': [0.5, 0.5, 0.5],
|
||||
'to_rgb': True
|
||||
}
|
||||
with pytest.raises(AssertionError):
|
||||
alg = RotationPred(
|
||||
backbone=backbone,
|
||||
head=None,
|
||||
loss=loss,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
alg = RotationPred(
|
||||
backbone=None,
|
||||
head=head,
|
||||
loss=loss,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
alg = RotationPred(
|
||||
backbone=backbone,
|
||||
head=head,
|
||||
loss=None,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
def test_rotation_pred():
|
||||
data_preprocessor = dict(
|
||||
type='mmselfsup.RotationPredDataPreprocessor',
|
||||
mean=(123.675, 116.28, 103.53),
|
||||
std=(58.395, 57.12, 57.375),
|
||||
bgr_to_rgb=True)
|
||||
alg = RotationPred(
|
||||
backbone=backbone,
|
||||
head=head,
|
||||
loss=loss,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
alg.init_weights()
|
||||
data_preprocessor=copy.deepcopy(data_preprocessor))
|
||||
|
||||
bach_size = 5
|
||||
fake_data = [{
|
||||
|
@ -62,19 +46,19 @@ def test_relative_loc():
|
|||
SelfSupDataSample()
|
||||
} for _ in range(bach_size)]
|
||||
|
||||
rot_label = LabelData()
|
||||
rot_label.value = torch.tensor([0, 1, 2, 3])
|
||||
pseudo_label = InstanceData()
|
||||
pseudo_label.rot_label = torch.tensor([0, 1, 2, 3])
|
||||
for i in range(bach_size):
|
||||
fake_data[i]['data_sample'].rot_label = rot_label
|
||||
fake_data[i]['data_sample'].pseudo_label = pseudo_label
|
||||
|
||||
fake_outputs = alg(fake_data, return_loss=True)
|
||||
assert isinstance(fake_outputs['loss'].item(), float)
|
||||
fake_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
|
||||
|
||||
test_results = alg(fake_data, return_loss=False)
|
||||
assert len(test_results) == len(fake_data)
|
||||
assert list(test_results[0].prediction.head4.shape) == [4, 4]
|
||||
fake_loss = alg(fake_inputs, fake_data_samples, mode='loss')
|
||||
assert isinstance(fake_loss['loss'].item(), float)
|
||||
|
||||
fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
|
||||
fake_feat = alg.extract_feat(
|
||||
inputs=fake_inputs, data_samples=fake_data_samples)
|
||||
assert list(fake_feat[0].shape) == [bach_size * 4, 512, 1, 1]
|
||||
fake_prediction = alg(fake_inputs, fake_data_samples, mode='predict')
|
||||
assert len(fake_prediction) == len(fake_data)
|
||||
assert list(fake_prediction[0].pred_score.head4.shape) == [4, 4]
|
||||
|
||||
fake_feats = alg(fake_inputs, fake_data_samples, mode='tensor')
|
||||
assert list(fake_feats[0].shape) == [bach_size * 4, 512, 1, 1]
|
||||
|
|
Loading…
Reference in New Issue