Stagify FastViT w/ downsample at top of stage

This commit is contained in:
Ross Wightman 2023-08-22 15:19:12 -07:00 committed by Ross Wightman
parent 8470eb1cb5
commit 40dbaafef5

View File

@ -761,16 +761,16 @@ class RepCPE(nn.Module):
def __init__( def __init__(
self, self,
in_chs: int, dim: int,
embed_dim: int = 768, dim_out: Optional[int] = None,
spatial_shape: Union[int, Tuple[int, int]] = (7, 7), spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
inference_mode=False, inference_mode=False,
) -> None: ) -> None:
"""Build reparameterizable conditional positional encoding """Build reparameterizable conditional positional encoding
Args: Args:
in_chs: Number of input channels. dim: Number of input channels.
embed_dim: Number of embedding dimensions. Default: 768 dim_out: Number of embedding dimensions. Default: 768
spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7) spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
inference_mode: Flag to instantiate block in inference mode. Default: ``False`` inference_mode: Flag to instantiate block in inference mode. Default: ``False``
""" """
@ -787,29 +787,29 @@ class RepCPE(nn.Module):
) )
self.spatial_shape = spatial_shape self.spatial_shape = spatial_shape
self.embed_dim = embed_dim self.dim = dim
self.in_chs = in_chs self.dim_out = dim_out or dim
self.groups = embed_dim self.groups = dim
if inference_mode: if inference_mode:
self.reparam_conv = nn.Conv2d( self.reparam_conv = nn.Conv2d(
self.in_chs, self.dim,
self.embed_dim, self.dim_out,
kernel_size=self.spatial_shape, kernel_size=self.spatial_shape,
stride=1, stride=1,
padding=spatial_shape[0] // 2, padding=spatial_shape[0] // 2,
groups=self.embed_dim, groups=self.groups,
bias=True, bias=True,
) )
else: else:
self.reparam_conv = None self.reparam_conv = None
self.pe = nn.Conv2d( self.pe = nn.Conv2d(
in_chs, self.dim,
embed_dim, self.dim_out,
spatial_shape, spatial_shape,
1, 1,
int(spatial_shape[0] // 2), int(spatial_shape[0] // 2),
groups=embed_dim, groups=self.groups,
bias=True, bias=True,
) )
@ -823,10 +823,10 @@ class RepCPE(nn.Module):
def reparameterize(self) -> None: def reparameterize(self) -> None:
# Build equivalent Id tensor # Build equivalent Id tensor
input_dim = self.in_chs // self.groups input_dim = self.dim // self.groups
kernel_value = torch.zeros( kernel_value = torch.zeros(
( (
self.in_chs, self.dim,
input_dim, input_dim,
self.spatial_shape[0], self.spatial_shape[0],
self.spatial_shape[1], self.spatial_shape[1],
@ -834,7 +834,7 @@ class RepCPE(nn.Module):
dtype=self.pe.weight.dtype, dtype=self.pe.weight.dtype,
device=self.pe.weight.device, device=self.pe.weight.device,
) )
for i in range(self.in_chs): for i in range(self.dim):
kernel_value[ kernel_value[
i, i,
i % input_dim, i % input_dim,
@ -849,12 +849,12 @@ class RepCPE(nn.Module):
# Introduce reparam conv # Introduce reparam conv
self.reparam_conv = nn.Conv2d( self.reparam_conv = nn.Conv2d(
self.in_chs, self.dim,
self.embed_dim, self.dim_out,
kernel_size=self.spatial_shape, kernel_size=self.spatial_shape,
stride=1, stride=1,
padding=int(self.spatial_shape[0] // 2), padding=int(self.spatial_shape[0] // 2),
groups=self.embed_dim, groups=self.groups,
bias=True, bias=True,
) )
self.reparam_conv.weight.data = w_final self.reparam_conv.weight.data = w_final
@ -1002,11 +1002,17 @@ class AttentionBlock(nn.Module):
return x return x
def basic_blocks( class FastVitStage(nn.Module):
def __init__(
self,
dim: int, dim: int,
block_index: int, dim_out: int,
num_blocks: List[int], depth: int,
token_mixer_type: str, token_mixer_type: str,
downsample: bool = True,
down_patch_size: int = 7,
down_stride: int = 2,
pos_emb_layer: Optional[nn.Module] = None,
kernel_size: int = 3, kernel_size: int = 3,
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU, act_layer: nn.Module = nn.GELU,
@ -1016,12 +1022,11 @@ def basic_blocks(
use_layer_scale: bool = True, use_layer_scale: bool = True,
layer_scale_init_value: float = 1e-5, layer_scale_init_value: float = 1e-5,
inference_mode=False, inference_mode=False,
) -> nn.Sequential: ):
"""Build FastViT blocks within a stage. """FastViT stage.
Args: Args:
dim: Number of embedding dimensions. dim: Number of embedding dimensions.
block_index: block index.
num_blocks: List containing number of blocks per stage. num_blocks: List containing number of blocks per stage.
token_mixer_type: Token mixer type. token_mixer_type: Token mixer type.
kernel_size: Kernel size for repmixer. kernel_size: Kernel size for repmixer.
@ -1033,37 +1038,47 @@ def basic_blocks(
use_layer_scale: Flag to turn on layer scale regularization. use_layer_scale: Flag to turn on layer scale regularization.
layer_scale_init_value: Layer scale value at initialization. layer_scale_init_value: Layer scale value at initialization.
inference_mode: Flag to instantiate block in inference mode. inference_mode: Flag to instantiate block in inference mode.
Returns:
nn.Sequential object of all the blocks within the stage.
""" """
blocks = [] super().__init__()
for block_idx in range(num_blocks[block_index]): if downsample:
block_dpr = ( self.downsample = PatchEmbed(
drop_path_rate patch_size=down_patch_size,
* (block_idx + sum(num_blocks[:block_index])) stride=down_stride,
/ (sum(num_blocks) - 1) in_chs=dim,
embed_dim=dim_out,
inference_mode=inference_mode,
) )
else:
assert dim == dim_out
self.downsample = nn.Identity()
if pos_emb_layer is not None:
self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode)
else:
self.pos_emb = nn.Identity()
blocks = []
for block_idx in range(depth):
if token_mixer_type == "repmixer": if token_mixer_type == "repmixer":
blocks.append(RepMixerBlock( blocks.append(RepMixerBlock(
dim, dim_out,
kernel_size=kernel_size, kernel_size=kernel_size,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
act_layer=act_layer, act_layer=act_layer,
drop=drop_rate, drop=drop_rate,
drop_path=block_dpr, drop_path=drop_path_rate[block_idx],
use_layer_scale=use_layer_scale, use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value, layer_scale_init_value=layer_scale_init_value,
inference_mode=inference_mode, inference_mode=inference_mode,
)) ))
elif token_mixer_type == "attention": elif token_mixer_type == "attention":
blocks.append(AttentionBlock( blocks.append(AttentionBlock(
dim, dim_out,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
act_layer=act_layer, act_layer=act_layer,
norm_layer=norm_layer, norm_layer=norm_layer,
drop=drop_rate, drop=drop_rate,
drop_path=block_dpr, drop_path=drop_path_rate[block_idx],
use_layer_scale=use_layer_scale, use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value, layer_scale_init_value=layer_scale_init_value,
)) ))
@ -1071,9 +1086,13 @@ def basic_blocks(
raise ValueError( raise ValueError(
"Token mixer type: {} not supported".format(token_mixer_type) "Token mixer type: {} not supported".format(token_mixer_type)
) )
blocks = nn.Sequential(*blocks) self.blocks = nn.Sequential(*blocks)
return blocks def forward(self, x):
x = self.downsample(x)
x = self.pos_emb(x)
x = self.blocks(x)
return x
class FastVit(nn.Module): class FastVit(nn.Module):
@ -1085,78 +1104,66 @@ class FastVit(nn.Module):
def __init__( def __init__(
self, self,
in_chans=3, in_chans: int = 3,
layers=(2, 2, 6, 2), layers: Tuple[int, ...] = (2, 2, 6, 2),
token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"), token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
embed_dims=None, embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
mlp_ratios=None, mlp_ratios: Tuple[float, ...] = (4,) * 4,
downsamples=None, downsamples: Tuple[bool, ...] = (False, True, True, True),
repmixer_kernel_size=3, repmixer_kernel_size: int = 3,
num_classes: int = 1000,
pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
down_patch_size: int = 7,
down_stride: int = 2,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
use_layer_scale: bool = True,
layer_scale_init_value: float = 1e-5,
fork_feat: bool = False,
cls_ratio: float = 2.0,
norm_layer: nn.Module = nn.BatchNorm2d, norm_layer: nn.Module = nn.BatchNorm2d,
act_layer: nn.Module = nn.GELU, act_layer: nn.Module = nn.GELU,
num_classes=1000, inference_mode: bool = False,
pos_embs=None,
down_patch_size=7,
down_stride=2,
drop_rate=0.0,
drop_path_rate=0.0,
use_layer_scale=True,
layer_scale_init_value=1e-5,
fork_feat=False,
cls_ratio=2.0,
inference_mode=False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.num_classes = 0 if fork_feat else num_classes self.num_classes = 0 if fork_feat else num_classes
self.fork_feat = fork_feat self.fork_feat = fork_feat
if pos_embs is None:
pos_embs = [None] * len(layers)
# Convolutional stem # Convolutional stem
self.patch_embed = convolutional_stem( self.patch_embed = convolutional_stem(
in_chans, embed_dims[0], inference_mode) in_chans,
embed_dims[0],
inference_mode,
)
# Build the main stages of the network architecture # Build the main stages of the network architecture
prev_dim = embed_dims[0]
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
network = [] network = []
for i in range(len(layers)): for i in range(len(layers)):
# Add position embeddings if requested stage = FastVitStage(
if pos_embs[i] is not None: dim=prev_dim,
network.append(pos_embs[i]( dim_out=embed_dims[i],
embed_dims[i], depth=layers[i],
embed_dims[i], downsample=downsamples[i] or prev_dim != embed_dims[i],
inference_mode=inference_mode, down_patch_size=down_patch_size,
)) down_stride=down_stride,
stage = basic_blocks( pos_emb_layer=pos_embs[i],
embed_dims[i],
i,
layers,
token_mixer_type=token_mixers[i], token_mixer_type=token_mixers[i],
kernel_size=repmixer_kernel_size, kernel_size=repmixer_kernel_size,
mlp_ratio=mlp_ratios[i], mlp_ratio=mlp_ratios[i],
act_layer=act_layer, act_layer=act_layer,
norm_layer=norm_layer, norm_layer=norm_layer,
drop_rate=drop_rate, drop_rate=drop_rate,
drop_path_rate=drop_path_rate, drop_path_rate=dpr[i],
use_layer_scale=use_layer_scale, use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value, layer_scale_init_value=layer_scale_init_value,
inference_mode=inference_mode, inference_mode=inference_mode,
) )
network.append(stage) network.append(stage)
if i >= len(layers) - 1: prev_dim = embed_dims[i]
break
# Patch merging/downsampling between stages. self.network = nn.Sequential(*network)
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
network += [PatchEmbed(
patch_size=down_patch_size,
stride=down_stride,
in_chs=embed_dims[i],
embed_dim=embed_dims[i + 1],
inference_mode=inference_mode,
)]
self.network = nn.ModuleList(network)
# For segmentation and detection, extract intermediate output # For segmentation and detection, extract intermediate output
if self.fork_feat: if self.fork_feat:
@ -1338,7 +1345,6 @@ def fastvit_t8(pretrained=False, **kwargs):
layers=(2, 2, 4, 2), layers=(2, 2, 4, 2),
embed_dims=(48, 96, 192, 384), embed_dims=(48, 96, 192, 384),
mlp_ratios=(3, 3, 3, 3), mlp_ratios=(3, 3, 3, 3),
downsamples=(True, True, True, True),
token_mixers=("repmixer", "repmixer", "repmixer", "repmixer") token_mixers=("repmixer", "repmixer", "repmixer", "repmixer")
) )
return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, **kwargs)) return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, **kwargs))
@ -1351,7 +1357,6 @@ def fastvit_t12(pretrained=False, **kwargs):
layers=(2, 2, 6, 2), layers=(2, 2, 6, 2),
embed_dims=(64, 128, 256, 512), embed_dims=(64, 128, 256, 512),
mlp_ratios=(3, 3, 3, 3), mlp_ratios=(3, 3, 3, 3),
downsamples=(True, True, True, True),
token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"), token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
) )
return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs)) return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs))
@ -1364,7 +1369,6 @@ def fastvit_s12(pretrained=False, **kwargs):
layers=(2, 2, 6, 2), layers=(2, 2, 6, 2),
embed_dims=(64, 128, 256, 512), embed_dims=(64, 128, 256, 512),
mlp_ratios=(4, 4, 4, 4), mlp_ratios=(4, 4, 4, 4),
downsamples=(True, True, True, True),
token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"), token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
) )
return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs)) return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs))
@ -1377,7 +1381,6 @@ def fastvit_sa12(pretrained=False, **kwargs):
layers=(2, 2, 6, 2), layers=(2, 2, 6, 2),
embed_dims=(64, 128, 256, 512), embed_dims=(64, 128, 256, 512),
mlp_ratios=(4, 4, 4, 4), mlp_ratios=(4, 4, 4, 4),
downsamples=(True, True, True, True),
pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))), pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"), token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
) )
@ -1391,7 +1394,6 @@ def fastvit_sa24(pretrained=False, **kwargs):
layers=(4, 4, 12, 4), layers=(4, 4, 12, 4),
embed_dims=(64, 128, 256, 512), embed_dims=(64, 128, 256, 512),
mlp_ratios=(4, 4, 4, 4), mlp_ratios=(4, 4, 4, 4),
downsamples=(True, True, True, True),
pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))), pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"), token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
) )
@ -1405,7 +1407,6 @@ def fastvit_sa36(pretrained=False, **kwargs):
layers=(6, 6, 18, 6), layers=(6, 6, 18, 6),
embed_dims=(64, 128, 256, 512), embed_dims=(64, 128, 256, 512),
mlp_ratios=(4, 4, 4, 4), mlp_ratios=(4, 4, 4, 4),
downsamples=(True, True, True, True),
pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))), pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"), token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
) )
@ -1418,7 +1419,6 @@ def fastvit_ma36(pretrained=False, **kwargs):
layers=(6, 6, 18, 6), layers=(6, 6, 18, 6),
embed_dims=(76, 152, 304, 608), embed_dims=(76, 152, 304, 608),
mlp_ratios=(4, 4, 4, 4), mlp_ratios=(4, 4, 4, 4),
downsamples=(True, True, True, True),
pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))), pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention") token_mixers=("repmixer", "repmixer", "repmixer", "attention")
) )