Add channel argments to mae_head
When trying iTPN pretrain, it only supports images with 3 channels. One of the restrictions is from MAEHead.pull/1735/head
parent
58a2243d99
commit
18f0503ef4
|
@ -14,15 +14,18 @@ class MAEPretrainHead(BaseModule):
|
|||
norm_pix_loss (bool): Whether or not normalize target.
|
||||
Defaults to False.
|
||||
patch_size (int): Patch size. Defaults to 16.
|
||||
in_channels (int): Number of input channels. Defaults to 3.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
loss: dict,
|
||||
norm_pix: bool = False,
|
||||
patch_size: int = 16) -> None:
|
||||
patch_size: int = 16,
|
||||
in_channels: int = 3) -> None:
|
||||
super().__init__()
|
||||
self.norm_pix = norm_pix
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.loss_module = MODELS.build(loss)
|
||||
|
||||
def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -30,19 +33,19 @@ class MAEPretrainHead(BaseModule):
|
|||
|
||||
Args:
|
||||
imgs (torch.Tensor): A batch of images. The shape should
|
||||
be :math:`(B, 3, H, W)`.
|
||||
be :math:`(B, C, H, W)`.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Patchified images. The shape is
|
||||
:math:`(B, L, \text{patch_size}^2 \times 3)`.
|
||||
:math:`(B, L, \text{patch_size}^2 \times C)`.
|
||||
"""
|
||||
p = self.patch_size
|
||||
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
|
||||
|
||||
h = w = imgs.shape[2] // p
|
||||
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
|
||||
x = imgs.reshape(shape=(imgs.shape[0], self.in_channels, h, p, w, p))
|
||||
x = torch.einsum('nchpwq->nhwpqc', x)
|
||||
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
|
||||
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_channels))
|
||||
return x
|
||||
|
||||
def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -50,18 +53,18 @@ class MAEPretrainHead(BaseModule):
|
|||
|
||||
Args:
|
||||
x (torch.Tensor): The shape is
|
||||
:math:`(B, L, \text{patch_size}^2 \times 3)`.
|
||||
:math:`(B, L, \text{patch_size}^2 \times C)`.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The shape is :math:`(B, 3, H, W)`.
|
||||
torch.Tensor: The shape is :math:`(B, C, H, W)`.
|
||||
"""
|
||||
p = self.patch_size
|
||||
h = w = int(x.shape[1]**.5)
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_channels))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
|
||||
imgs = x.reshape(shape=(x.shape[0], self.in_channels, h * p, h * p))
|
||||
return imgs
|
||||
|
||||
def construct_target(self, target: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -71,7 +74,7 @@ class MAEPretrainHead(BaseModule):
|
|||
normalize the image according to ``norm_pix``.
|
||||
|
||||
Args:
|
||||
target (torch.Tensor): Image with the shape of B x 3 x H x W
|
||||
target (torch.Tensor): Image with the shape of B x C x H x W
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Tokenized images with the shape of B x L x C
|
||||
|
|
Loading…
Reference in New Issue