Update forward_intermediates for hiera to have its own fwd impl w/ early stopping. Remove return_intermediates bool from forward(). Still an fx issue with None mask arg :(

pull/2156/head
Ross Wightman 2024-04-29 17:23:37 -07:00
parent e8b08a4e7b
commit c6db4043cd
1 changed file with 40 additions and 23 deletions
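For context, a minimal usage sketch of the feature-extraction path this commit gives its own implementation. The model name, input size, and stage indices below are illustrative assumptions; the forward_intermediates() keyword arguments follow the signature in the diff below.

import torch
import timm

# assumed model/input; any Hiera variant registered in timm should behave the same way
model = timm.create_model('hiera_small_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# take features from the first three stages only; with stop_early=True (the default)
# blocks beyond the last requested stage end are never run
feats = model.forward_intermediates(x, indices=[0, 1, 2], intermediates_only=True)
for f in feats:
    print(f.shape)  # NCHW feature maps, one per requested stage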


@@ -32,12 +32,13 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, use_fused_attn
+from timm.layers import DropPath, Mlp, use_fused_attn, _assert
 from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._features_fx import register_notrace_function

 def conv_nd(n: int) -> Type[nn.Module]:
@@ -48,13 +49,14 @@ def conv_nd(n: int) -> Type[nn.Module]:
     return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]

+@register_notrace_function
 def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
     # target_size: [(T), (H), W]
     # (spatial) mask: [B, C, (t), (h), w]
     if mask is None:
         return mask
-    assert len(mask.shape[2:]) == len(target_size)
+    _assert(len(mask.shape[2:]) == len(target_size), "mask spatial shape and target_size must match.")
     if mask.shape[2:] != target_size:
         return F.interpolate(mask.float(), size=target_size)
     return mask
@@ -622,6 +624,7 @@ class Hiera(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
+            mask: Optional[torch.Tensor] = None,
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = True,
@@ -643,10 +646,31 @@ class Hiera(nn.Module):
         assert not norm, 'normalization of features not supported'
         assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
         take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        take_indices = [self.stage_ends[i] for i in take_indices]
+        max_index = self.stage_ends[max_index]
+        if mask is not None:
+            patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
+        else:
+            patch_mask = None
+        x = self.patch_embed(x, mask=patch_mask)
+        x = self._pos_embed(x)
+        x = self.unroll(x)
+        # Discard masked tokens
+        if mask is not None:
+            x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1])
+        intermediates = []
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2))
-        # FIXME using existing return_intermediates support in model, doesn't have early stopping.
-        x, intermediates = self.forward_features(x, return_intermediates=True)
-        intermediates = [y.permute(0, 3, 1, 2) for i, y in enumerate(intermediates) if i in take_indices]
         if intermediates_only:
             return intermediates
@@ -673,18 +697,18 @@ class Hiera(nn.Module):
     def forward_features(
             self,
             x: torch.Tensor,
-            mask: torch.Tensor = None,
+            mask: Optional[torch.Tensor] = None,
             return_intermediates: bool = False,
     ) -> torch.Tensor:
         """
         mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
         Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.
         """
-        x = self.patch_embed(
-            x,
-            mask=mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
-            if mask is not None else None,
-        )
+        if mask is not None:
+            patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
+        else:
+            patch_mask = None
+        x = self.patch_embed(x, mask=patch_mask)
         x = self._pos_embed(x)
         x = self.unroll(x)
@@ -718,19 +742,12 @@ class Hiera(nn.Module):
     def forward(
             self,
             x: torch.Tensor,
-            mask: torch.Tensor = None,
-            return_intermediates: bool = False,
+            mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if return_intermediates:
-            x, intermediates = self.forward_features(x, mask=mask, return_intermediates=return_intermediates)
-            if mask is not None:
-                x = self.forward_head(x)
-            return x, intermediates
-        else:
-            x = self.forward_features(x, mask=mask)
-            if mask is None:
-                x = self.forward_head(x)
-            return x
+        x = self.forward_features(x, mask=mask)
+        if mask is None:
+            x = self.forward_head(x)
+        return x

 def _cfg(url='', **kwargs):
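As a closing note, the mask convention documented in the forward_features docstring above (boolean [B, #MU] tensor, 1 = keep, equal keep count per sample) can be built along these lines; the mask-unit count and keep ratio below are illustrative assumptions, not values read from the model.

import torch

B = 2
num_mask_units = 49   # assumed, e.g. a 7x7 mask_spatial_shape
keep = 20             # assumed ~0.6 mask ratio, identical for every sample

# choose `keep` mask units per sample so mask.sum(dim=-1) matches across the batch
noise = torch.rand(B, num_mask_units)
keep_idx = noise.argsort(dim=-1)[:, :keep]
mask = torch.zeros(B, num_mask_units, dtype=torch.bool)
mask[torch.arange(B).unsqueeze(1), keep_idx] = True   # 1 = keep, 0 = remove

# out = model(x, mask=mask)  # with a mask, the new forward() above skips forward_head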