# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
from mmengine.model import BaseModule
|
|
|
|
from mmpretrain.registry import MODELS
|
|
|
|
|
|
@MODELS.register_module()
class MAEPretrainHead(BaseModule):
    """Head for MAE Pre-training.

    The head converts input images into per-patch reconstruction targets
    (optionally normalized patch by patch) and passes the prediction, the
    target and the mask to the configured loss module.

    Args:
        loss (dict): Config of loss.
        norm_pix (bool): Whether or not normalize target.
            Defaults to False.
        patch_size (int): Patch size. Defaults to 16.
        in_channels (int): Number of channels of the input images.
            Defaults to 3.
    """

    def __init__(self,
                 loss: dict,
                 norm_pix: bool = False,
                 patch_size: int = 16,
                 in_channels: int = 3) -> None:
        super().__init__()
        self.norm_pix = norm_pix
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.loss_module = MODELS.build(loss)

    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
        r"""Split images into non-overlapped patches.

        Args:
            imgs (torch.Tensor): A batch of images. The shape should
                be :math:`(B, C, H, W)` with ``C == in_channels``,
                ``H == W`` and ``H`` divisible by the patch size.

        Returns:
            torch.Tensor: Patchified images. The shape is
            :math:`(B, L, \text{patch_size}^2 \times C)`.
        """
        p = self.patch_size
        c = self.in_channels
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0, (
            f'Expected square images whose side is divisible by the patch '
            f'size {p}, but got spatial shape {tuple(imgs.shape[2:])}.')

        # Number of patches along each spatial axis.
        h = w = imgs.shape[2] // p
        # (B, C, H, W) -> (B, C, h, p, w, p): split each axis into patches.
        x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
        # Move the patch grid forward and the channel last, so that every
        # patch becomes one contiguous (p, p, C) chunk.
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * c))
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        r"""Combine non-overlapped patches into images.

        This is the inverse of :meth:`patchify`.

        Args:
            x (torch.Tensor): The shape is
                :math:`(B, L, \text{patch_size}^2 \times C)`, where
                :math:`L` must be a perfect square.

        Returns:
            torch.Tensor: The shape is :math:`(B, C, H, W)`.
        """
        p = self.patch_size
        c = self.in_channels
        # The patch grid is assumed to be square, so L = h * w with h == w.
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1], (
            f'The number of patches must be a perfect square, '
            f'but got {x.shape[1]}.')

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        # Undo the permutation applied in ``patchify``.
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def construct_target(self, target: torch.Tensor) -> torch.Tensor:
        """Construct the reconstruction target.

        In addition to splitting images into tokens, this module will also
        normalize the image according to ``norm_pix``.

        Args:
            target (torch.Tensor): Image with the shape of B x C x H x W

        Returns:
            torch.Tensor: Tokenized images with the shape of B x L x C
        """
        target = self.patchify(target)
        if self.norm_pix:
            # Normalize every patch independently over its own pixels
            # (last dimension); the small epsilon keeps the division
            # finite for patches with zero variance.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        return target

    def loss(self, pred: torch.Tensor, target: torch.Tensor,
             mask: torch.Tensor) -> torch.Tensor:
        """Generate loss.

        The target image is first converted into patch tokens via
        :meth:`construct_target`, then forwarded to the loss module
        together with the prediction and the mask.

        Args:
            pred (torch.Tensor): The reconstructed image.
            target (torch.Tensor): The target image.
            mask (torch.Tensor): The mask of the target image.

        Returns:
            torch.Tensor: The reconstruction loss.
        """
        target = self.construct_target(target)
        loss = self.loss_module(pred, target, mask)

        return loss
|