Update forward_intermediates for hiera to have its own fwd impl w/ early stopping. Remove return_intermediates bool from forward(). Still an fx issue with None mask arg :(
parent e8b08a4e7b
commit c6db4043cd
@@ -32,12 +32,13 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, use_fused_attn
+from timm.layers import DropPath, Mlp, use_fused_attn, _assert
 
 from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._features_fx import register_notrace_function
 
 
 def conv_nd(n: int) -> Type[nn.Module]:
@@ -48,13 +49,14 @@ def conv_nd(n: int) -> Type[nn.Module]:
     return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]
 
 
+@register_notrace_function
 def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
     # target_size: [(T), (H), W]
     # (spatial) mask: [B, C, (t), (h), w]
     if mask is None:
         return mask
 
-    assert len(mask.shape[2:]) == len(target_size)
+    _assert(len(mask.shape[2:]) == len(target_size), "mask spatial shape and target_size must match.")
     if mask.shape[2:] != target_size:
         return F.interpolate(mask.float(), size=target_size)
     return mask
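The swap from a plain assert to timm's _assert, plus the new @register_notrace_function decorator, line up with the fx issue noted in the commit message: torch.fx symbolic tracing cannot evaluate a Proxy condition in a Python assert, and it cannot trace a data-dependent branch like the mask-is-None check, so get_resized_mask is registered to be left untraced. A rough, hedged illustration of the assert difference with standalone toy functions (not part of this commit):

import torch
from torch import fx


def strict_check(x: torch.Tensor) -> torch.Tensor:
    # A plain assert needs a concrete bool; under fx.symbolic_trace the condition is a
    # Proxy, and bool(Proxy) raises a TraceError.
    assert x.shape[-1] == 8
    return x * 2


def traceable_check(x: torch.Tensor) -> torch.Tensor:
    # torch._assert (which timm's _assert wraps) is dispatched as a graph node instead,
    # so tracing succeeds and the check still runs when the traced graph executes.
    torch._assert(x.shape[-1] == 8, 'last dim must be 8')
    return x * 2


gm = fx.symbolic_trace(traceable_check)   # traces fine
gm(torch.randn(2, 8))                     # the recorded check runs here
# fx.symbolic_trace(strict_check)         # would raise a TraceError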
@@ -622,6 +624,7 @@ class Hiera(nn.Module):
     def forward_intermediates(
             self,
             x: torch.Tensor,
+            mask: Optional[torch.Tensor] = None,
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = True,
@@ -643,10 +646,31 @@
         assert not norm, 'normalization of features not supported'
         assert output_fmt in ('NCHW',), 'Output format must be one of NCHW.'
         take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        take_indices = [self.stage_ends[i] for i in take_indices]
+        max_index = self.stage_ends[max_index]
 
-        # FIXME using existing return_intermediates support in model, doesn't have early stopping.
-        x, intermediates = self.forward_features(x, return_intermediates=True)
-        intermediates = [y.permute(0, 3, 1, 2) for i, y in enumerate(intermediates) if i in take_indices]
+        if mask is not None:
+            patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
+        else:
+            patch_mask = None
+        x = self.patch_embed(x, mask=patch_mask)
+        x = self._pos_embed(x)
+        x = self.unroll(x)
+
+        # Discard masked tokens
+        if mask is not None:
+            x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1])
+
+        intermediates = []
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                intermediates.append(self.reroll(x, i, mask=mask).permute(0, 3, 1, 2))
+
         if intermediates_only:
             return intermediates
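For context, a minimal usage sketch of the new forward_intermediates path (not part of the commit; hiera_small_224 is assumed to be one of the registered hiera variants and is used only for illustration):

import timm
import torch

model = timm.create_model('hiera_small_224', pretrained=False)
x = torch.randn(2, 3, 224, 224)

# Ask for the first two stages only. With the default stop_early=True, blocks past the
# deepest requested stage end are never run, unlike the old return_intermediates path.
feats = model.forward_intermediates(x, indices=(0, 1), intermediates_only=True)
for f in feats:
    print(f.shape)  # one NCHW feature map per requested stage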
@@ -673,18 +697,18 @@
     def forward_features(
             self,
             x: torch.Tensor,
-            mask: torch.Tensor = None,
-            return_intermediates: bool = False,
+            mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
         Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.
         """
-        x = self.patch_embed(
-            x,
-            mask=mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
-            if mask is not None else None,
-        )
+        if mask is not None:
+            patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape)  # B, C, *mask_spatial_shape
+        else:
+            patch_mask = None
+        x = self.patch_embed(x, mask=patch_mask)
         x = self._pos_embed(x)
         x = self.unroll(x)
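The docstring above pins down the mask format; below is a hedged sketch of constructing such a mask, leaning on the mask_spatial_shape attribute this diff already references (model name again assumed only for illustration):

import math

import timm
import torch

model = timm.create_model('hiera_small_224', pretrained=False)
x = torch.randn(2, 3, 224, 224)

# One boolean entry per mask unit: True = keep, False = remove, with the same keep
# count for every sample in the batch, as the docstring requires.
num_mask_units = math.prod(model.mask_spatial_shape)
keep = num_mask_units // 4  # e.g. keep 25% of the mask units, MAE-style
mask = torch.zeros(2, num_mask_units, dtype=torch.bool)
mask[:, :keep] = True

latent = model.forward_features(x, mask=mask)  # masked tokens are dropped before the blocks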
@@ -718,19 +742,12 @@
     def forward(
             self,
             x: torch.Tensor,
-            mask: torch.Tensor = None,
-            return_intermediates: bool = False,
+            mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if return_intermediates:
-            x, intermediates = self.forward_features(x, mask=mask, return_intermediates=return_intermediates)
-            if mask is not None:
-                x = self.forward_head(x)
-            return x, intermediates
-        else:
-            x = self.forward_features(x, mask=mask)
-            if mask is None:
-                x = self.forward_head(x)
-            return x
+        x = self.forward_features(x, mask=mask)
+        if mask is None:
+            x = self.forward_head(x)
+        return x
 
 
 def _cfg(url='', **kwargs):
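With return_intermediates removed, forward() now takes only the optional mask (skipping the head when a mask is present), and intermediates are requested via forward_intermediates() instead. A short hedged sketch of the resulting call surface, same assumed model as in the earlier snippets:

import timm
import torch

model = timm.create_model('hiera_small_224', pretrained=False)  # assumed variant for illustration
x = torch.randn(2, 3, 224, 224)

logits = model(x)  # mask is None, so forward_head() is applied and class logits come back

# Intermediate feature maps now come from forward_intermediates() rather than a
# return_intermediates flag on forward().
feats = model.forward_intermediates(x, intermediates_only=True)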