[Refactor]: Refactor CAE

pull/352/head
liuyuan1.vendor 2022-05-17 07:19:00 +00:00 committed by fangyixiao18
parent e687aff595
commit be1dd2f5c2
2 changed files with 148 additions and 45 deletions

mmselfsup/models/algorithms/cae.py

@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Sequence
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from torchvision.transforms import Normalize
 
+from mmselfsup.core import SelfSupDataSample
+from mmselfsup.utils import get_module_device
 from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
 from .base import BaseModel
@@ -16,24 +17,26 @@ class CAE(BaseModel):
     Learning <https://arxiv.org/abs/2202.03026>`_.
 
     Args:
-        backbone (dict, optional): Config dict for module of backbone.
-        neck (dict, optional): Config dict for module of deep features to
-            compact feature vectors. Defaults to None.
-        head (dict, optional): Config dict for module of loss functions.
+        backbone (Dict, optional): Config dict for encoder. Defaults to None.
+        neck (Dict, optional): Config dict for decoder. Defaults to None.
+        head (Dict, optional): Config dict for loss functions.
             Defaults to None.
         base_momentum (float): The base momentum coefficient for the target
             network. Defaults to 0.0.
-        init_cfg (dict, optional): the config to control the initialization.
-            Defaults to None.
+        preprocess_cfg (Dict, optional): Config to preprocess images.
+            Defaults to None.
+        init_cfg (Union[List[Dict], Dict], optional): Config dict for weight
+            initialization. Defaults to None.
     """
 
     def __init__(self,
-                 backbone: dict = None,
-                 neck: dict = None,
-                 head: dict = None,
+                 backbone: Optional[Dict] = None,
+                 neck: Optional[Dict] = None,
+                 head: Optional[Dict] = None,
                  base_momentum: float = 0.0,
-                 init_cfg: dict = None,
-                 **kwargs) -> None:
-        super(CAE, self).__init__(init_cfg)
+                 preprocess_cfg: Optional[Dict] = None,
+                 init_cfg: Optional[Union[List[Dict], Dict]] = None) -> None:
+        super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
         assert backbone is not None
         self.backbone = build_backbone(backbone)
         self.teacher = build_backbone(backbone)
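
For orientation, a minimal usage sketch of the refactored constructor. The config values mirror the updated test file below (the neck config is abridged here, and the head's tokenizer checkpoint must exist on disk, so treat this as illustrative rather than runnable end to end):

from mmselfsup.models.algorithms.cae import CAE

backbone = dict(type='CAEViT', arch='b', patch_size=16, init_values=0.1)
neck = dict(type='CAENeck', patch_size=16)  # abridged; see the test file below
head = dict(
    type='CAEHead', tokenizer_path='cae_ckpt/encoder_stat_dict.pth', lambd=2)
preprocess_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], to_rgb=True)

model = CAE(
    backbone=backbone,
    neck=neck,
    head=head,
    base_momentum=0.0,
    preprocess_cfg=preprocess_cfg)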
@@ -44,10 +47,6 @@ class CAE(BaseModel):
         self.momentum = base_momentum
 
-        self.img_norm = Normalize(
-            mean=torch.tensor((0.485, 0.456, 0.406)),
-            std=torch.tensor((0.229, 0.224, 0.225)))
-
     def init_weights(self) -> None:
         super().init_weights()
         self._init_teacher()
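
The teacher built in __init__ is maintained as an exponential moving average (EMA) of the backbone; the update rule appears as context at the top of the next hunk. A standalone sketch of that rule (the function name and arguments here are illustrative, not part of this commit):

import torch

def ema_update(teacher: torch.nn.Module, backbone: torch.nn.Module,
               momentum: float) -> None:
    # In-place EMA: teacher = momentum * teacher + (1 - momentum) * backbone
    for p_t, p_b in zip(teacher.parameters(), backbone.parameters()):
        p_t.data.mul_(momentum).add_(p_b.data, alpha=1. - momentum)

With the default base_momentum of 0.0, this reduces to copying the backbone weights into the teacher on every update.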
@@ -67,36 +66,46 @@ class CAE(BaseModel):
             param_teacher.data = param_teacher.data * self.momentum + \
                 param_bacbone.data * (1. - self.momentum)
 
-    def extract_feat(self, img: torch.Tensor,
-                     mask: torch.Tensor) -> torch.Tensor:
-        x = self.backbone(img, mask)
-        return x
+    def extract_feat(self, inputs: List[torch.Tensor],
+                     data_samples: List[SelfSupDataSample],
+                     **kwarg) -> Tuple[torch.Tensor]:
+        """The forward function to extract features.
+
+        Args:
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[SelfSupDataSample]): All elements required
+                during the forward function.
+
+        Returns:
+            Tuple[torch.Tensor]: backbone outputs.
+        """
+        mask = torch.stack(
+            [data_sample.mask.value for data_sample in data_samples])
+        return self.backbone(inputs[0], mask)
 
-    def forward_train(self, samples: Sequence, **kwargs) -> dict:
-        img, img_target, mask = samples
-
-        # normalize images and the images to get the target
-        img_list = [self.img_norm(x).unsqueeze(0) for x in img]
-        img = torch.cat(img_list)
-        img_target = 0.8 * img_target + 0.1
-
+    def forward_train(self, inputs: List[torch.Tensor],
+                      data_samples: List[SelfSupDataSample],
+                      **kwargs) -> Dict[str, torch.Tensor]:
+        mask = torch.stack(
+            [data_sample.mask.value for data_sample in data_samples])
         mask = mask.flatten(1).to(torch.bool)
-        unmasked = self.backbone(img, mask)
+        unmasked = self.backbone(inputs[0], mask)
 
         # get the latent prediction for the masked patches
         with torch.no_grad():
-            latent_target = self.teacher(img, ~mask)
+            latent_target = self.teacher(inputs[0], ~mask)
             latent_target = latent_target[:, 1:, :]
             self.momentum_update()
 
-        pos_embed = self.backbone.pos_embed.expand(img.shape[0], -1, -1)
+        pos_embed = self.backbone.pos_embed.expand(inputs[0].shape[0], -1, -1)
         pos_embed_masked = pos_embed[:,
-                                     1:][mask].reshape(img.shape[0], -1,
+                                     1:][mask].reshape(inputs[0].shape[0], -1,
                                                        pos_embed.shape[-1])
         pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape(
-            img.shape[0], -1, pos_embed.shape[-1])
+            inputs[0].shape[0], -1, pos_embed.shape[-1])
 
         # input the unmasked tokens and masked tokens to the decoder
         logits, latent_pred = self.neck(unmasked[:, 1:], pos_embed_masked,
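
One subtlety in the hunk above: indexing pos_embed[:, 1:] (patch tokens only, CLS excluded) with a boolean mask flattens the batch dimension, so the reshape back to (B, -1, C) is only valid because every sample masks the same number of patches. A self-contained sketch using the shapes from the test below:

import torch

B, N, C = 1, 196, 768                   # batch, patch tokens, embed dim
pos_embed = torch.randn(B, N + 1, C)    # index 0 is the CLS token
mask = torch.zeros(B, N, dtype=torch.bool)
mask[:, 75:150] = True                  # 75 masked patches per sample

# Boolean indexing yields (B * 75, C); reshape restores the batch dim.
pos_embed_masked = pos_embed[:, 1:][mask].reshape(B, -1, C)
pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape(B, -1, C)
assert pos_embed_masked.shape == (B, 75, C)
assert pos_embed_unmasked.shape == (B, 121, C)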
@@ -104,6 +113,51 @@ class CAE(BaseModel):
         logits = logits.view(-1, logits.shape[-1])
 
-        losses = self.head(img_target, logits, latent_pred, latent_target,
-                           mask)
+        losses = self.head(inputs[1], logits, latent_pred, latent_target, mask)
 
         return losses
+
+    def preprocss_data(
+            self,
+            data: List[Dict]) -> Tuple[List[torch.Tensor], SelfSupDataSample]:
+        """Process input data during training, testing or extracting.
+
+        This function overrides the default one in BaseModel by normalizing
+        img_target with DALL-E style normalization.
+
+        Args:
+            data (List[Dict]): The data to be processed, which
+                comes from dataloader.
+
+        Returns:
+            tuple: It should contain 2 items.
+
+              - batch_images (List[torch.Tensor]): The batch image tensor.
+              - data_samples (List[SelfSupDataSample], Optional): The Data
+                Samples. It usually includes information such as
+                gt_label. Return None if the input data does not
+                contain data_sample.
+        """
+        # data_['inputs'] is a list
+        images = [data_['inputs'] for data_ in data]
+        data_samples = [data_['data_sample'] for data_ in data]
+
+        device = get_module_device(self)
+        data_samples = [data_sample.to(device) for data_sample in data_samples]
+        images = [[img_.to(device) for img_ in img] for img in images]
+
+        # convert images to rgb
+        if self.to_rgb and images[0][0].size(0) == 3:
+            images = [[img_[[2, 1, 0], ...] for img_ in img] for img in images]
+
+        # normalize images
+        images = [[(img[0] - self.mean_norm) / self.std_norm,
+                   img[1] * 0.8 + 0.1] for img in images]
+
+        # reconstruct images into several batches. For example, SimCLR needs
+        # two crops for each image, and this code snippet will convert images
+        # into two batches, each containing one crop of an image.
+        batch_images = []
+        for i in range(len(images[0])):
+            cur_batch = [img[i] for img in images]
+            batch_images.append(torch.stack(cur_batch))
+
+        return batch_images, data_samples
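
The two entries of each inputs list are normalized differently: the encoder view gets ordinary mean/std normalization, while the reconstruction target gets img * 0.8 + 0.1, the map_pixels transform from DALL-E's dVAE, which compresses [0, 1] pixels into [0.1, 0.9]. A small sketch of both paths (mean/std values taken from preprocess_cfg in the test below):

import torch

mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)

img = torch.rand(3, 224, 224)     # view fed to the ViT encoder
target = torch.rand(3, 112, 112)  # view fed to the dVAE tokenizer

img = (img - mean) / std          # standard channel-wise normalization
target = target * 0.8 + 0.1       # DALL-E map_pixels: [0, 1] -> [0.1, 0.9]
assert target.min() >= 0.1 and target.max() <= 0.9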

tests/test_models/test_algorithms/test_cae.py

@@ -1,14 +1,29 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
 import platform
 
 import pytest
 import torch
+from mmengine.data import BaseDataElement as PixelData
 
-from mmselfsup.models.algorithms import CAE
+from mmselfsup.core.data_structures.selfsup_data_sample import \
+    SelfSupDataSample
+from mmselfsup.models.algorithms.cae import CAE
 
 # model settings
-backbone = dict(
-    type='CAEViT', arch='b', patch_size=16, init_values=0.1, qkv_bias=False)
+backbone = dict(type='CAEViT', arch='b', patch_size=16, init_values=0.1)
 neck = dict(
     type='CAENeck',
     patch_size=16,
@@ -22,27 +37,61 @@ neck = dict(
 head = dict(
     type='CAEHead', tokenizer_path='cae_ckpt/encoder_stat_dict.pth', lambd=2)
 
+preprocess_cfg = {
+    'mean': [0.5, 0.5, 0.5],
+    'std': [0.5, 0.5, 0.5],
+    'to_rgb': True
+}
+
 
 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
 def test_cae():
     with pytest.raises(AssertionError):
-        model = CAE(backbone=None, neck=neck, head=head)
+        model = CAE(
+            backbone=None,
+            neck=neck,
+            head=head,
+            preprocess_cfg=copy.deepcopy(preprocess_cfg))
     with pytest.raises(AssertionError):
-        model = CAE(backbone=backbone, neck=None, head=head)
+        model = CAE(
+            backbone=backbone,
+            neck=None,
+            head=head,
+            preprocess_cfg=copy.deepcopy(preprocess_cfg))
     with pytest.raises(AssertionError):
-        model = CAE(backbone=backbone, neck=neck, head=None)
+        model = CAE(
+            backbone=backbone,
+            neck=neck,
+            head=None,
+            preprocess_cfg=copy.deepcopy(preprocess_cfg))
 
-    model = CAE(backbone=backbone, neck=neck, head=head)
+    model = CAE(
+        backbone=backbone,
+        neck=neck,
+        head=head,
+        preprocess_cfg=copy.deepcopy(preprocess_cfg))
     model.init_weights()
 
-    fake_input = torch.rand((1, 3, 224, 224))
-    fake_target = torch.rand((1, 3, 112, 112))
-    fake_mask = torch.zeros((1, 196)).bool()
-    fake_mask[:, 75:150] = 1
+    fake_img = torch.rand((3, 224, 224))
+    fake_target_img = torch.rand((3, 112, 112))
+    fake_mask = torch.zeros((196)).bool()
+    fake_mask[75:150] = 1
+    fake_data_sample = SelfSupDataSample()
+    fake_mask = PixelData(value=fake_mask)
+    fake_data_sample.mask = fake_mask
 
-    inputs = (fake_input, fake_target, fake_mask)
+    fake_data = [{
+        'inputs': [fake_img, fake_target_img],
+        'data_sample': fake_data_sample
+    }]
 
-    fake_loss = model.forward_train(inputs)
-    fake_feat = model.extract_feat(fake_input, fake_mask)
+    fake_loss = model(fake_data, return_loss=True)
 
+    # test forward_train
     assert isinstance(fake_loss['loss'].item(), float)
+
+    # test extract_feat
+    fake_inputs, fake_data_samples = model.preprocss_data(fake_data)
+    fake_feat = model.extract_feat(fake_inputs, fake_data_samples)
     assert list(fake_feat.shape) == [1, 122, 768]
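
The expected shape [1, 122, 768] follows from the masking arithmetic set up earlier in the test: a 224x224 input with patch size 16 yields 196 patch tokens, 75 of them (indices 75:150) are masked out, and CAEViT prepends a CLS token to the visible patches. A quick check:

num_patches = (224 // 16) ** 2          # 196 patch tokens
num_masked = 150 - 75                   # 75 patches are masked
visible = num_patches - num_masked + 1  # unmasked patches + CLS token
assert visible == 122                   # matches fake_feat.shape[1]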