[Enhance] Add iTPN Supports for Non-three channel image ()

* Add channel argments to mae_head

When trying iTPN pretrain, it only supports images with 3 channels. One of the restrictions is from MAEHead.

* Transfer other argments from iTPNHiViT to HiViT

The HiViT supports specifying channels, but the iTPNHiViT class can't pass channel argments to it. This is one of the reasons that iTPNHiViT implementation only support images with 3 channels.

* Update itpn.py

Fix hint problem
pull/1780/head
ZhangYiqin 2023-09-04 13:11:16 +08:00 committed by GitHub
parent e1675e893e
commit da1da48eb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 11 deletions
mmpretrain/models
selfsup

View File

@ -14,15 +14,18 @@ class MAEPretrainHead(BaseModule):
norm_pix_loss (bool): Whether or not normalize target.
Defaults to False.
patch_size (int): Patch size. Defaults to 16.
in_channels (int): Number of input channels. Defaults to 3.
"""
def __init__(self,
loss: dict,
norm_pix: bool = False,
patch_size: int = 16) -> None:
patch_size: int = 16,
in_channels: int = 3) -> None:
super().__init__()
self.norm_pix = norm_pix
self.patch_size = patch_size
self.in_channels = in_channels
self.loss_module = MODELS.build(loss)
def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
@ -30,19 +33,19 @@ class MAEPretrainHead(BaseModule):
Args:
imgs (torch.Tensor): A batch of images. The shape should
be :math:`(B, 3, H, W)`.
be :math:`(B, C, H, W)`.
Returns:
torch.Tensor: Patchified images. The shape is
:math:`(B, L, \text{patch_size}^2 \times 3)`.
:math:`(B, L, \text{patch_size}^2 \times C)`.
"""
p = self.patch_size
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = imgs.reshape(shape=(imgs.shape[0], self.in_channels, h, p, w, p))
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_channels))
return x
def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
@ -50,18 +53,18 @@ class MAEPretrainHead(BaseModule):
Args:
x (torch.Tensor): The shape is
:math:`(B, L, \text{patch_size}^2 \times 3)`.
:math:`(B, L, \text{patch_size}^2 \times C)`.
Returns:
torch.Tensor: The shape is :math:`(B, 3, H, W)`.
torch.Tensor: The shape is :math:`(B, C, H, W)`.
"""
p = self.patch_size
h = w = int(x.shape[1]**.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_channels))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
imgs = x.reshape(shape=(x.shape[0], self.in_channels, h * p, h * p))
return imgs
def construct_target(self, target: torch.Tensor) -> torch.Tensor:
@ -71,7 +74,7 @@ class MAEPretrainHead(BaseModule):
normalize the image according to ``norm_pix``.
Args:
target (torch.Tensor): Image with the shape of B x 3 x H x W
target (torch.Tensor): Image with the shape of B x C x H x W
Returns:
torch.Tensor: Tokenized images with the shape of B x L x C

View File

@ -64,6 +64,7 @@ class iTPNHiViT(HiViT):
layer_scale_init_value: float = 0.0,
mask_ratio: float = 0.75,
reconstruction_type: str = 'pixel',
**kwargs,
):
super().__init__(
arch=arch,
@ -80,7 +81,9 @@ class iTPNHiViT(HiViT):
norm_cfg=norm_cfg,
ape=ape,
rpe=rpe,
layer_scale_init_value=layer_scale_init_value)
layer_scale_init_value=layer_scale_init_value,
**kwargs,
)
self.pos_embed.requires_grad = False
self.mask_ratio = mask_ratio