[Enhance] Add iTPN support for non-three-channel images (#1735)
* Add channel arguments to mae_head. When trying iTPN pretraining, it only supports images with 3 channels; one of the restrictions comes from MAEHead.
* Transfer other arguments from iTPNHiViT to HiViT. HiViT supports specifying the number of input channels, but the iTPNHiViT class cannot pass channel arguments to it. This is one of the reasons the iTPNHiViT implementation only supports 3-channel images.
* Update itpn.py. Fix hint problem.

pull/1780/head
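For context, a hypothetical config fragment sketching the intended use (not part of this diff): only the head's `in_channels` and the extra backbone keyword rely on this change; the other fields mirror the stock mmpretrain iTPN/MAE configs, and the exact name of HiViT's channel argument is an assumption to check against `HiViT.__init__`.

```python
# Hypothetical sketch: pretraining iTPN on single-channel images.
backbone = dict(
    type='iTPNHiViT',
    arch='base',
    reconstruction_type='pixel',
    mask_ratio=0.75,
    in_channels=1,  # assumed channel kwarg, now forwarded to HiViT via **kwargs
)
head = dict(
    type='MAEPretrainHead',
    norm_pix=True,
    patch_size=16,
    in_channels=1,  # new argument introduced by this PR
    loss=dict(type='PixelReconstructionLoss', criterion='L2'),
)
```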
parent e1675e893e
commit da1da48eb6
mmpretrain/models/heads/mae_head.py

@@ -14,15 +14,18 @@ class MAEPretrainHead(BaseModule):
         norm_pix_loss (bool): Whether or not normalize target.
             Defaults to False.
         patch_size (int): Patch size. Defaults to 16.
+        in_channels (int): Number of input channels. Defaults to 3.
     """
 
     def __init__(self,
                  loss: dict,
                  norm_pix: bool = False,
-                 patch_size: int = 16) -> None:
+                 patch_size: int = 16,
+                 in_channels: int = 3) -> None:
         super().__init__()
         self.norm_pix = norm_pix
         self.patch_size = patch_size
+        self.in_channels = in_channels
         self.loss_module = MODELS.build(loss)
 
     def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
@@ -30,19 +33,19 @@ class MAEPretrainHead(BaseModule):
 
         Args:
             imgs (torch.Tensor): A batch of images. The shape should
-                be :math:`(B, 3, H, W)`.
+                be :math:`(B, C, H, W)`.
 
         Returns:
             torch.Tensor: Patchified images. The shape is
-                :math:`(B, L, \text{patch_size}^2 \times 3)`.
+                :math:`(B, L, \text{patch_size}^2 \times C)`.
         """
         p = self.patch_size
         assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
 
         h = w = imgs.shape[2] // p
-        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+        x = imgs.reshape(shape=(imgs.shape[0], self.in_channels, h, p, w, p))
         x = torch.einsum('nchpwq->nhwpqc', x)
-        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_channels))
         return x
 
     def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
@@ -50,18 +53,18 @@ class MAEPretrainHead(BaseModule):
 
         Args:
             x (torch.Tensor): The shape is
-                :math:`(B, L, \text{patch_size}^2 \times 3)`.
+                :math:`(B, L, \text{patch_size}^2 \times C)`.
 
         Returns:
-            torch.Tensor: The shape is :math:`(B, 3, H, W)`.
+            torch.Tensor: The shape is :math:`(B, C, H, W)`.
         """
         p = self.patch_size
         h = w = int(x.shape[1]**.5)
         assert h * w == x.shape[1]
 
-        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_channels))
         x = torch.einsum('nhwpqc->nchpwq', x)
-        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
+        imgs = x.reshape(shape=(x.shape[0], self.in_channels, h * p, h * p))
         return imgs
 
     def construct_target(self, target: torch.Tensor) -> torch.Tensor:
@@ -71,7 +74,7 @@ class MAEPretrainHead(BaseModule):
         normalize the image according to ``norm_pix``.
 
         Args:
-            target (torch.Tensor): Image with the shape of B x 3 x H x W
+            target (torch.Tensor): Image with the shape of B x C x H x W
 
         Returns:
             torch.Tensor: Tokenized images with the shape of B x L x C
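As a quick sanity check of the generalized round trip, a minimal sketch (assuming `MAEPretrainHead` is importable from `mmpretrain.models` and that `PixelReconstructionLoss` is registered, as in the stock MAE configs):

```python
import torch
from mmpretrain.models import MAEPretrainHead

# Build the head for single-channel inputs; the loss config is only
# needed so the constructor can build its loss module.
head = MAEPretrainHead(
    loss=dict(type='PixelReconstructionLoss', criterion='L2'),
    patch_size=16,
    in_channels=1)

imgs = torch.rand(2, 1, 224, 224)
patches = head.patchify(imgs)  # (2, 14 * 14, 16 * 16 * 1)
assert patches.shape == (2, 14 * 14, 16 * 16 * 1)
# unpatchify should invert patchify exactly.
assert torch.allclose(head.unpatchify(patches), imgs)
```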
mmpretrain/models/selfsup/itpn.py

@@ -64,6 +64,7 @@ class iTPNHiViT(HiViT):
         layer_scale_init_value: float = 0.0,
         mask_ratio: float = 0.75,
         reconstruction_type: str = 'pixel',
+        **kwargs,
     ):
         super().__init__(
             arch=arch,
@@ -80,7 +81,9 @@ class iTPNHiViT(HiViT):
             norm_cfg=norm_cfg,
             ape=ape,
             rpe=rpe,
-            layer_scale_init_value=layer_scale_init_value)
+            layer_scale_init_value=layer_scale_init_value,
+            **kwargs,
+        )
 
         self.pos_embed.requires_grad = False
         self.mask_ratio = mask_ratio
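With the `**kwargs` pass-through in place, extra HiViT arguments can be handed straight to `iTPNHiViT`. A hedged sketch follows; the channel argument name is an assumption and should be checked against `HiViT.__init__`:

```python
from mmpretrain.models import iTPNHiViT

# Keyword arguments not consumed by iTPNHiViT itself are now forwarded
# to the parent HiViT, so a non-3-channel backbone can be built directly.
backbone = iTPNHiViT(
    arch='base',
    reconstruction_type='pixel',
    mask_ratio=0.75,
    in_channels=1,  # assumed name of HiViT's channel argument
)
```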