diff --git a/mmpretrain/models/heads/mae_head.py b/mmpretrain/models/heads/mae_head.py
index 1a5366d1..b76ecedd 100644
--- a/mmpretrain/models/heads/mae_head.py
+++ b/mmpretrain/models/heads/mae_head.py
@@ -14,15 +14,18 @@ class MAEPretrainHead(BaseModule):
         norm_pix_loss (bool): Whether or not normalize target.
             Defaults to False.
         patch_size (int): Patch size. Defaults to 16.
+        in_channels (int): Number of input channels. Defaults to 3.
     """

     def __init__(self,
                  loss: dict,
                  norm_pix: bool = False,
-                 patch_size: int = 16) -> None:
+                 patch_size: int = 16,
+                 in_channels: int = 3) -> None:
         super().__init__()
         self.norm_pix = norm_pix
         self.patch_size = patch_size
+        self.in_channels = in_channels
         self.loss_module = MODELS.build(loss)

     def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
@@ -30,19 +33,19 @@ class MAEPretrainHead(BaseModule):

         Args:
             imgs (torch.Tensor): A batch of images. The shape should
-                be :math:`(B, 3, H, W)`.
+                be :math:`(B, C, H, W)`.

         Returns:
             torch.Tensor: Patchified images. The shape is
-                :math:`(B, L, \text{patch_size}^2 \times 3)`.
+                :math:`(B, L, \text{patch_size}^2 \times C)`.
         """
         p = self.patch_size
         assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

         h = w = imgs.shape[2] // p
-        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+        x = imgs.reshape(shape=(imgs.shape[0], self.in_channels, h, p, w, p))
         x = torch.einsum('nchpwq->nhwpqc', x)
-        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_channels))
         return x

     def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
@@ -50,18 +53,18 @@ class MAEPretrainHead(BaseModule):

         Args:
             x (torch.Tensor): The shape is
-                :math:`(B, L, \text{patch_size}^2 \times 3)`.
+                :math:`(B, L, \text{patch_size}^2 \times C)`.

         Returns:
-            torch.Tensor: The shape is :math:`(B, 3, H, W)`.
+            torch.Tensor: The shape is :math:`(B, C, H, W)`.
         """
         p = self.patch_size
         h = w = int(x.shape[1]**.5)
         assert h * w == x.shape[1]

-        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_channels))
         x = torch.einsum('nhwpqc->nchpwq', x)
-        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
+        imgs = x.reshape(shape=(x.shape[0], self.in_channels, h * p, h * p))
         return imgs

     def construct_target(self, target: torch.Tensor) -> torch.Tensor:
@@ -71,7 +74,7 @@
            normalize the image according to ``norm_pix``.

         Args:
-            target (torch.Tensor): Image with the shape of B x 3 x H x W
+            target (torch.Tensor): Image with the shape of B x C x H x W

         Returns:
             torch.Tensor: Tokenized images with the shape of B x L x C
diff --git a/mmpretrain/models/selfsup/itpn.py b/mmpretrain/models/selfsup/itpn.py
index 85efd254..488a9963 100644
--- a/mmpretrain/models/selfsup/itpn.py
+++ b/mmpretrain/models/selfsup/itpn.py
@@ -64,6 +64,7 @@ class iTPNHiViT(HiViT):
         layer_scale_init_value: float = 0.0,
         mask_ratio: float = 0.75,
         reconstruction_type: str = 'pixel',
+        **kwargs,
     ):
         super().__init__(
             arch=arch,
@@ -80,7 +81,9 @@ class iTPNHiViT(HiViT):
             norm_cfg=norm_cfg,
             ape=ape,
             rpe=rpe,
-            layer_scale_init_value=layer_scale_init_value)
+            layer_scale_init_value=layer_scale_init_value,
+            **kwargs,
+        )
         self.pos_embed.requires_grad = False

         self.mask_ratio = mask_ratio
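
Note: patchify/unpatchify must remain exact inverses for any channel count, since the MAE loss compares predicted tokens against patchified targets. A minimal standalone sketch of the generalized logic (same reshape/einsum steps as in the diff, with patch_size and in_channels as plain function arguments; the round-trip check at the end is illustrative, not part of the PR):

import torch

def patchify(imgs: torch.Tensor, patch_size: int, in_channels: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, L, patch_size**2 * C); requires square, patch-aligned inputs.
    p = patch_size
    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
    h = w = imgs.shape[2] // p
    x = imgs.reshape(imgs.shape[0], in_channels, h, p, w, p)
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(imgs.shape[0], h * w, p**2 * in_channels)

def unpatchify(x: torch.Tensor, patch_size: int, in_channels: int) -> torch.Tensor:
    # (B, L, patch_size**2 * C) -> (B, C, H, W); inverse permutation of patchify.
    p = patch_size
    h = w = int(x.shape[1]**.5)
    assert h * w == x.shape[1]
    x = x.reshape(x.shape[0], h, w, p, p, in_channels)
    x = torch.einsum('nhwpqc->nchpwq', x)
    return x.reshape(x.shape[0], in_channels, h * p, h * p)

# Round trip with a single-channel batch, the case the hard-coded `3` could not handle.
imgs = torch.randn(2, 1, 224, 224)
tokens = patchify(imgs, patch_size=16, in_channels=1)
assert tokens.shape == (2, 196, 256)  # L = (224/16)**2, dim = 16**2 * 1
assert torch.equal(unpatchify(tokens, patch_size=16, in_channels=1), imgs)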
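
The **kwargs forwarding in iTPNHiViT lets arguments that HiViT accepts but the subclass does not re-declare pass straight through to the base class; together with the new in_channels on MAEPretrainHead, this presumably enables pretraining on non-RGB inputs. A hedged config sketch (not from this PR; assumes HiViT's patch embedding exposes an in_chans argument, uses the iTPN and PixelReconstructionLoss registry names from mmpretrain's stock iTPN configs, and omits the neck for brevity):

# Hypothetical single-channel iTPN pixel-reconstruction config fragment.
model = dict(
    type='iTPN',
    backbone=dict(
        type='iTPNHiViT',
        arch='base',
        reconstruction_type='pixel',
        mask_ratio=0.75,
        in_chans=1),  # absorbed by **kwargs and forwarded to HiViT.__init__
    head=dict(
        type='MAEPretrainHead',
        norm_pix=True,
        patch_size=16,
        in_channels=1,  # must match the backbone's channel count
        loss=dict(type='PixelReconstructionLoss', criterion='L2')))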