Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00

Bit more Hiera fiddling

commit d88bed6535 (parent 8a54d2a930)
@@ -47,7 +47,7 @@ def conv_nd(n: int) -> Type[nn.Module]:
     return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]
 
 
-def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor:
+def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
     # target_size: [(T), (H), W]
     # (spatial) mask: [B, C, (t), (h), w]
     if mask is None:
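Note: the annotation change from torch.Size to List[int] is presumably a TorchScript-friendliness fix (my reading; the commit message doesn't say): torch.jit.script does not accept torch.Size as an annotation, while Tensor.shape is typed List[int] under scripting. A minimal sketch of the pattern, not library code:

    from typing import List
    import torch

    @torch.jit.script
    def spatial_numel(target_size: List[int]) -> int:
        n = 1
        for s in target_size:
            n *= s  # product of the (T), (H), W dims
        return n

    x = torch.zeros(2, 3, 8, 8)
    print(spatial_numel(list(x.shape[2:])))  # 64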
@@ -59,23 +59,6 @@ def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor:
     return mask
 
 
-def do_masked_conv(
-        x: torch.Tensor,
-        conv: nn.Module,
-        mask: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    """Zero-out the masked regions of the input before conv.
-    Prevents leakage of masked regions when using overlapping kernels.
-    """
-    if conv is None:
-        return x
-    if mask is None:
-        return conv(x)
-
-    mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
-    return conv(x * mask.bool())
-
-
 def undo_windowing(
         x: torch.Tensor,
         shape: List[int],
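Note: do_masked_conv is not gone, it is inlined into PatchEmbed.forward further down this diff, dropping the conv is None / mask is None indirection. A standalone sketch of what the helper did (toy sizes of my choosing):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    conv = nn.Conv2d(3, 96, kernel_size=7, stride=4, padding=3)  # toy patch embed
    x = torch.randn(1, 3, 224, 224)
    mask = (torch.rand(1, 1, 7, 7) > 0.6).float()  # keep-mask at mask-unit resolution

    mask = F.interpolate(mask, size=x.shape[2:])   # nearest-neighbor resize to input size
    out = conv(x * mask.bool())                    # zero masked pixels, then conv
    print(out.shape)  # torch.Size([1, 96, 56, 56])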
@@ -145,7 +128,6 @@ class Unroll(nn.Module):
         Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. performs MaxPoolNd
         """
         B, _, C = x.shape
-
         cur_size = self.size
         x = x.view(*([B] + cur_size + [C]))
 
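The docstring's claim that [B, 4, N//4, C].max(1) performs MaxPoolNd is worth unpacking: Unroll interleaves tokens so that the pooling strides land on the leading axis. A self-contained check for the 2D case with a single 2x2 schedule step (shapes of my choosing):

    import torch
    import torch.nn as nn

    B, H, W, C = 1, 4, 4, 3
    x = torch.randn(B, H, W, C)

    # unroll: split each spatial dim into (size // stride, stride) and move the
    # stride axes to the front, mimicking one 2x2 Unroll step
    u = x.view(B, H // 2, 2, W // 2, 2, C).permute(0, 2, 4, 1, 3, 5).reshape(B, 4, -1, C)
    pooled = u.amax(dim=1).reshape(B, H // 2, W // 2, C)

    # reference: an ordinary 2x2 spatial max-pool
    ref = nn.MaxPool2d(2)(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    print(torch.allclose(pooled, ref))  # True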
@@ -332,6 +314,7 @@ class HieraBlock(nn.Module):
             act_layer: nn.Module = nn.GELU,
             q_stride: int = 1,
             window_size: int = 0,
+            use_expand_proj: bool = True,
             use_mask_unit_attn: bool = False,
     ):
         super().__init__()
@@ -341,8 +324,14 @@ class HieraBlock(nn.Module):
 
         self.norm1 = norm_layer(dim)
         if dim != dim_out:
-            self.proj = nn.Linear(dim, dim_out)
+            self.do_expand = True
+            if use_expand_proj:
+                self.proj = nn.Linear(dim, dim_out)
+            else:
+                assert dim_out == dim * 2
+                self.proj = None
         else:
+            self.do_expand = False
             self.proj = None
         self.attn = MaskUnitAttention(
             dim,
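Note on the new branch: with use_expand_proj=False the block drops the learned expansion entirely, so the channel doubling has to come from the pooling step instead (see the forward change below), which is what the assert dim_out == dim * 2 pins down. The saving is the Linear's dim * dim_out weights plus bias at every stage transition, e.g. (a sketch; the dims are a made-up transition, not any particular config):

    dim, dim_out = 96, 192
    expand_proj_params = dim * dim_out + dim_out   # nn.Linear weight + bias
    concat_pool_params = 0                         # cat(max-pool, avg-pool) is parameter-free
    print(expand_proj_params)  # 18624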
@@ -362,9 +351,17 @@ class HieraBlock(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Attention + Q Pooling
         x_norm = self.norm1(x)
-        if self.proj is not None:
-            x = self.proj(x_norm)
-            x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1)  # max-pool
+        if self.do_expand:
+            if self.proj is not None:
+                x = self.proj(x_norm)
+                x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1)  # max-pool
+            else:
+                x = torch.cat([
+                    x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1),  # max-pool
+                    x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).mean(dim=1),  # avg-pool
+                    ],
+                    dim=-1,
+                )
         x = x + self.drop_path1(self.attn(x_norm))
 
         # MLP
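Shape check for the projection-free path (a sketch with made-up sizes): tokens shrink by q_stride while channels double, matching what the old Linear-then-max-pool route produced:

    import torch

    B, N, C, q_stride = 2, 64, 96, 4
    x = torch.randn(B, N, C)
    x = torch.cat([
        x.view(B, q_stride, -1, C).amax(dim=1),   # max-pool: [B, N//4, C]
        x.view(B, q_stride, -1, C).mean(dim=1),   # avg-pool: [B, N//4, C]
    ], dim=-1)
    print(x.shape)  # torch.Size([2, 16, 192]): tokens / 4, channels x 2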
@@ -419,7 +416,11 @@ class PatchEmbed(nn.Module):
             x: torch.Tensor,
             mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        x = do_masked_conv(x, self.proj, mask)
+        if mask is not None:
+            mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
+            x = self.proj(x * mask.to(torch.bool))
+        else:
+            x = self.proj(x)
         if self.reshape:
             x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1)
         return x
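The inlined path is behaviorally do_masked_conv with the conv is None guard gone (self.proj always exists here); mask.to(torch.bool) is the same cast as the old mask.bool(). A quick demonstration (toy conv, my sizes) of why the pre-conv zeroing matters at all:

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
    x = torch.ones(1, 1, 4, 4)
    mask = torch.ones(1, 1, 4, 4)
    mask[..., :2] = 0  # drop the left half

    leaky = conv(x)               # masked pixels still feed every output
    safe = conv(x * mask.bool())  # masked pixels contribute exact zeros
    print(torch.equal(leaky, safe))  # False: outputs near the mask edge differ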
@@ -570,10 +571,10 @@ class Hiera(nn.Module):
 
     @torch.jit.ignore
     def no_weight_decay(self):
-        if self.sep_pos_embed:
-            return ["pos_embed_spatial", "pos_embed_temporal"]
-        else:
+        if self.pos_embed is not None:
             return ["pos_embed"]
+        else:
+            return ["pos_embed_spatial", "pos_embed_temporal"]
 
     def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
         """
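The no_weight_decay change keys the decision off whether self.pos_embed actually exists rather than off a sep_pos_embed flag, so the skip list tracks the parameters that were really registered. For context, a generic sketch (not timm's actual optimizer factory) of how such a name list is typically consumed:

    import torch

    def param_groups(model: torch.nn.Module, skip: set, weight_decay: float = 0.05):
        decay, no_decay = [], []
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if name in skip or p.ndim <= 1:  # skip listed params, norms, and biases
                no_decay.append(p)
            else:
                decay.append(p)
        return [
            {'params': decay, 'weight_decay': weight_decay},
            {'params': no_decay, 'weight_decay': 0.},
        ]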