From da1da48eb6bb8285b28277f1dd06ca30ffbe3dfe Mon Sep 17 00:00:00 2001 From: ZhangYiqin <312065559@qq.com> Date: Mon, 4 Sep 2023 13:11:16 +0800 Subject: [PATCH] [Enhance] Add iTPN Supports for Non-three channel image (#1735) * Add channel argments to mae_head When trying iTPN pretrain, it only supports images with 3 channels. One of the restrictions is from MAEHead. * Transfer other argments from iTPNHiViT to HiViT The HiViT supports specifying channels, but the iTPNHiViT class can't pass channel argments to it. This is one of the reasons that iTPNHiViT implementation only support images with 3 channels. * Update itpn.py Fix hint problem --- mmpretrain/models/heads/mae_head.py | 23 +++++++++++++---------- mmpretrain/models/selfsup/itpn.py | 5 ++++- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/mmpretrain/models/heads/mae_head.py b/mmpretrain/models/heads/mae_head.py index 1a5366d1..b76ecedd 100644 --- a/mmpretrain/models/heads/mae_head.py +++ b/mmpretrain/models/heads/mae_head.py @@ -14,15 +14,18 @@ class MAEPretrainHead(BaseModule): norm_pix_loss (bool): Whether or not normalize target. Defaults to False. patch_size (int): Patch size. Defaults to 16. + in_channels (int): Number of input channels. Defaults to 3. """ def __init__(self, loss: dict, norm_pix: bool = False, - patch_size: int = 16) -> None: + patch_size: int = 16, + in_channels: int = 3) -> None: super().__init__() self.norm_pix = norm_pix self.patch_size = patch_size + self.in_channels = in_channels self.loss_module = MODELS.build(loss) def patchify(self, imgs: torch.Tensor) -> torch.Tensor: @@ -30,19 +33,19 @@ class MAEPretrainHead(BaseModule): Args: imgs (torch.Tensor): A batch of images. The shape should - be :math:`(B, 3, H, W)`. + be :math:`(B, C, H, W)`. Returns: torch.Tensor: Patchified images. The shape is - :math:`(B, L, \text{patch_size}^2 \times 3)`. + :math:`(B, L, \text{patch_size}^2 \times C)`. """ p = self.patch_size assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 h = w = imgs.shape[2] // p - x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = imgs.reshape(shape=(imgs.shape[0], self.in_channels, h, p, w, p)) x = torch.einsum('nchpwq->nhwpqc', x) - x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_channels)) return x def unpatchify(self, x: torch.Tensor) -> torch.Tensor: @@ -50,18 +53,18 @@ class MAEPretrainHead(BaseModule): Args: x (torch.Tensor): The shape is - :math:`(B, L, \text{patch_size}^2 \times 3)`. + :math:`(B, L, \text{patch_size}^2 \times C)`. Returns: - torch.Tensor: The shape is :math:`(B, 3, H, W)`. + torch.Tensor: The shape is :math:`(B, C, H, W)`. """ p = self.patch_size h = w = int(x.shape[1]**.5) assert h * w == x.shape[1] - x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_channels)) x = torch.einsum('nhwpqc->nchpwq', x) - imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + imgs = x.reshape(shape=(x.shape[0], self.in_channels, h * p, h * p)) return imgs def construct_target(self, target: torch.Tensor) -> torch.Tensor: @@ -71,7 +74,7 @@ class MAEPretrainHead(BaseModule): normalize the image according to ``norm_pix``. Args: - target (torch.Tensor): Image with the shape of B x 3 x H x W + target (torch.Tensor): Image with the shape of B x C x H x W Returns: torch.Tensor: Tokenized images with the shape of B x L x C diff --git a/mmpretrain/models/selfsup/itpn.py b/mmpretrain/models/selfsup/itpn.py index 85efd254..488a9963 100644 --- a/mmpretrain/models/selfsup/itpn.py +++ b/mmpretrain/models/selfsup/itpn.py @@ -64,6 +64,7 @@ class iTPNHiViT(HiViT): layer_scale_init_value: float = 0.0, mask_ratio: float = 0.75, reconstruction_type: str = 'pixel', + **kwargs, ): super().__init__( arch=arch, @@ -80,7 +81,9 @@ class iTPNHiViT(HiViT): norm_cfg=norm_cfg, ape=ape, rpe=rpe, - layer_scale_init_value=layer_scale_init_value) + layer_scale_init_value=layer_scale_init_value, + **kwargs, + ) self.pos_embed.requires_grad = False self.mask_ratio = mask_ratio