[Refactor]: Refactor CAE
parent
e687aff595
commit
be1dd2f5c2
|
@ -1,9 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import Normalize
|
||||
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from mmselfsup.utils import get_module_device
|
||||
from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
|
||||
from .base import BaseModel
|
||||
|
||||
|
@ -16,24 +17,26 @@ class CAE(BaseModel):
|
|||
Learning <https://arxiv.org/abs/2202.03026>`_.
|
||||
|
||||
Args:
|
||||
backbone (dict, optional): Config dict for module of backbone.
|
||||
neck (dict, optional): Config dict for module of deep features to
|
||||
compact feature vectors. Defaults to None.
|
||||
head (dict, optional): Config dict for module of loss functions.
|
||||
backbone (Dict, optional): Config dict for encoder. Defaults to None.
|
||||
neck (Dict, optional): Config dict for encoder. Defaults to None.
|
||||
head (Dict, optional): Config dict for loss functions.
|
||||
Defaults to None.
|
||||
base_momentum (float): The base momentum coefficient for the target
|
||||
network. Defaults to 0.0.
|
||||
init_cfg (dict, optional): the config to control the initialization.
|
||||
preprocess_cfg (Dict, optional): Config to preprocess images.
|
||||
Defaults to None.
|
||||
init_cfg (Union[List[Dict], Dict], optional): Config dict for weight
|
||||
initialization. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone: dict = None,
|
||||
neck: dict = None,
|
||||
head: dict = None,
|
||||
backbone: Optional[Dict] = None,
|
||||
neck: Optional[Dict] = None,
|
||||
head: Optional[Dict] = None,
|
||||
base_momentum: float = 0.0,
|
||||
init_cfg: dict = None,
|
||||
**kwargs) -> None:
|
||||
super(CAE, self).__init__(init_cfg)
|
||||
preprocess_cfg: Optional[Dict] = None,
|
||||
init_cfg: Optional[Union[List[Dict], Dict]] = None) -> None:
|
||||
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
|
||||
assert backbone is not None
|
||||
self.backbone = build_backbone(backbone)
|
||||
self.teacher = build_backbone(backbone)
|
||||
|
@ -44,10 +47,6 @@ class CAE(BaseModel):
|
|||
|
||||
self.momentum = base_momentum
|
||||
|
||||
self.img_norm = Normalize(
|
||||
mean=torch.tensor((0.485, 0.456, 0.406)),
|
||||
std=torch.tensor((0.229, 0.224, 0.225)))
|
||||
|
||||
def init_weights(self) -> None:
|
||||
super().init_weights()
|
||||
self._init_teacher()
|
||||
|
@ -67,36 +66,46 @@ class CAE(BaseModel):
|
|||
param_teacher.data = param_teacher.data * self.momentum + \
|
||||
param_bacbone.data * (1. - self.momentum)
|
||||
|
||||
def extract_feat(self, img: torch.Tensor,
|
||||
mask: torch.Tensor) -> torch.Tensor:
|
||||
def extract_feat(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwarg) -> Tuple[torch.Tensor]:
|
||||
"""The forward function to extract features.
|
||||
|
||||
x = self.backbone(img, mask)
|
||||
return x
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
def forward_train(self, samples: Sequence, **kwargs) -> dict:
|
||||
img, img_target, mask = samples
|
||||
Returns:
|
||||
Tuple[torch.Tensor]: backbone outputs.
|
||||
"""
|
||||
mask = torch.stack(
|
||||
[data_sample.mask.value for data_sample in data_samples])
|
||||
return self.backbone(inputs[0], mask)
|
||||
|
||||
# normalize images and the images to get the target
|
||||
img_list = [self.img_norm(x).unsqueeze(0) for x in img]
|
||||
img = torch.cat(img_list)
|
||||
img_target = 0.8 * img_target + 0.1
|
||||
def forward_train(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
|
||||
mask = torch.stack(
|
||||
[data_sample.mask.value for data_sample in data_samples])
|
||||
|
||||
mask = mask.flatten(1).to(torch.bool)
|
||||
|
||||
unmasked = self.backbone(img, mask)
|
||||
unmasked = self.backbone(inputs[0], mask)
|
||||
|
||||
# get the latent prediction for the masked patches
|
||||
with torch.no_grad():
|
||||
latent_target = self.teacher(img, ~mask)
|
||||
latent_target = self.teacher(inputs[0], ~mask)
|
||||
latent_target = latent_target[:, 1:, :]
|
||||
self.momentum_update()
|
||||
|
||||
pos_embed = self.backbone.pos_embed.expand(img.shape[0], -1, -1)
|
||||
pos_embed = self.backbone.pos_embed.expand(inputs[0].shape[0], -1, -1)
|
||||
pos_embed_masked = pos_embed[:,
|
||||
1:][mask].reshape(img.shape[0], -1,
|
||||
1:][mask].reshape(inputs[0].shape[0], -1,
|
||||
pos_embed.shape[-1])
|
||||
pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape(
|
||||
img.shape[0], -1, pos_embed.shape[-1])
|
||||
inputs[0].shape[0], -1, pos_embed.shape[-1])
|
||||
|
||||
# input the unmasked tokens and masked tokens to the decoder
|
||||
logits, latent_pred = self.neck(unmasked[:, 1:], pos_embed_masked,
|
||||
|
@ -104,6 +113,51 @@ class CAE(BaseModel):
|
|||
|
||||
logits = logits.view(-1, logits.shape[-1])
|
||||
|
||||
losses = self.head(img_target, logits, latent_pred, latent_target,
|
||||
mask)
|
||||
losses = self.head(inputs[1], logits, latent_pred, latent_target, mask)
|
||||
return losses
|
||||
|
||||
def preprocss_data(
|
||||
self,
|
||||
data: List[Dict]) -> Tuple[List[torch.Tensor], SelfSupDataSample]:
|
||||
"""Process input data during training, testing or extracting.
|
||||
|
||||
This function overwrites the defaults function in BaseModel by
|
||||
normalizing img_target with dalle style normalization.
|
||||
|
||||
Args:
|
||||
data (List[Dict]): The data to be processed, which
|
||||
comes from dataloader.
|
||||
|
||||
Returns:
|
||||
tuple: It should contain 2 item.
|
||||
- batch_images (List[torch.Tensor]): The batch image tensor.
|
||||
- data_samples (List[SelfSupDataSample], Optional): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_label`. Return None If the input data does not
|
||||
contain `data_sample`.
|
||||
"""
|
||||
# data_['inputs] is a list
|
||||
images = [data_['inputs'] for data_ in data]
|
||||
data_samples = [data_['data_sample'] for data_ in data]
|
||||
|
||||
device = get_module_device(self)
|
||||
data_samples = [data_sample.to(device) for data_sample in data_samples]
|
||||
images = [[img_.to(device) for img_ in img] for img in images]
|
||||
|
||||
# convert images to rgb
|
||||
if self.to_rgb and images[0][0].size(0) == 3:
|
||||
images = [[img_[[2, 1, 0], ...] for img_ in img] for img in images]
|
||||
|
||||
# normalize images
|
||||
images = [[(img[0] - self.mean_norm) / self.std_norm,
|
||||
img[1] * 0.8 + 0.1] for img in images]
|
||||
|
||||
# reconstruct images into several batches. For example, SimCLR needs
|
||||
# two crops for each image, and this code snippet will convert images
|
||||
# into two batches, each containing one crop of an image.
|
||||
batch_images = []
|
||||
for i in range(len(images[0])):
|
||||
cur_batch = [img[i] for img in images]
|
||||
batch_images.append(torch.stack(cur_batch))
|
||||
|
||||
return batch_images, data_samples
|
||||
|
|
|
@ -1,14 +1,29 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
import copy
|
||||
>>>>>>> 6491042 ([Refactor]: Refactor CAE)
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
<<<<<<< HEAD
|
||||
|
||||
from mmselfsup.models.algorithms import CAE
|
||||
|
||||
# model settings
|
||||
backbone = dict(
|
||||
type='CAEViT', arch='b', patch_size=16, init_values=0.1, qkv_bias=False)
|
||||
=======
|
||||
from mmengine.data import BaseDataElement as PixelData
|
||||
|
||||
from mmselfsup.core.data_structures.selfsup_data_sample import \
|
||||
SelfSupDataSample
|
||||
from mmselfsup.models.algorithms.cae import CAE
|
||||
|
||||
# model settings
|
||||
backbone = dict(type='CAEViT', arch='b', patch_size=16, init_values=0.1)
|
||||
>>>>>>> 6491042 ([Refactor]: Refactor CAE)
|
||||
neck = dict(
|
||||
type='CAENeck',
|
||||
patch_size=16,
|
||||
|
@ -22,27 +37,61 @@ neck = dict(
|
|||
head = dict(
|
||||
type='CAEHead', tokenizer_path='cae_ckpt/encoder_stat_dict.pth', lambd=2)
|
||||
|
||||
preprocess_cfg = {
|
||||
'mean': [0.5, 0.5, 0.5],
|
||||
'std': [0.5, 0.5, 0.5],
|
||||
'to_rgb': True
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_cae():
|
||||
with pytest.raises(AssertionError):
|
||||
model = CAE(backbone=None, neck=neck, head=head)
|
||||
model = CAE(
|
||||
backbone=None,
|
||||
neck=neck,
|
||||
head=head,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
model = CAE(backbone=backbone, neck=None, head=head)
|
||||
model = CAE(
|
||||
backbone=backbone,
|
||||
neck=None,
|
||||
head=head,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
model = CAE(backbone=backbone, neck=neck, head=None)
|
||||
model = CAE(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=None,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
|
||||
model = CAE(backbone=backbone, neck=neck, head=head)
|
||||
model = CAE(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
model.init_weights()
|
||||
|
||||
fake_input = torch.rand((1, 3, 224, 224))
|
||||
fake_target = torch.rand((1, 3, 112, 112))
|
||||
fake_mask = torch.zeros((1, 196)).bool()
|
||||
fake_mask[:, 75:150] = 1
|
||||
fake_img = torch.rand((3, 224, 224))
|
||||
fake_target_img = torch.rand((3, 112, 112))
|
||||
fake_mask = torch.zeros((196)).bool()
|
||||
fake_mask[75:150] = 1
|
||||
fake_data_sample = SelfSupDataSample()
|
||||
fake_mask = PixelData(value=fake_mask)
|
||||
fake_data_sample.mask = fake_mask
|
||||
|
||||
inputs = (fake_input, fake_target, fake_mask)
|
||||
fake_data = [{
|
||||
'inputs': [fake_img, fake_target_img],
|
||||
'data_sample': fake_data_sample
|
||||
}]
|
||||
|
||||
fake_loss = model.forward_train(inputs)
|
||||
fake_feat = model.extract_feat(fake_input, fake_mask)
|
||||
fake_loss = model(fake_data, return_loss=True)
|
||||
|
||||
# test forward_train
|
||||
assert isinstance(fake_loss['loss'].item(), float)
|
||||
|
||||
# test extract_feat
|
||||
fake_inputs, fake_data_samples = model.preprocss_data(fake_data)
|
||||
fake_feat = model.extract_feat(fake_inputs, fake_data_samples)
|
||||
|
||||
assert list(fake_feat.shape) == [1, 122, 768]
|
||||
|
|
Loading…
Reference in New Issue