Stagify FastViT w/ downsample at top of stage

Ross Wightman 2023-08-22 15:19:12 -07:00 committed by Ross Wightman
parent 8470eb1cb5
commit 40dbaafef5


@@ -761,16 +761,16 @@ class RepCPE(nn.Module):
     def __init__(
             self,
-            in_chs: int,
-            embed_dim: int = 768,
+            dim: int,
+            dim_out: Optional[int] = None,
             spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
             inference_mode=False,
     ) -> None:
         """Build reparameterizable conditional positional encoding
 
         Args:
-            in_chs: Number of input channels.
-            embed_dim: Number of embedding dimensions. Default: 768
+            dim: Number of input channels.
+            dim_out: Number of embedding dimensions. Default: 768
             spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
             inference_mode: Flag to instantiate block in inference mode. Default: ``False``
         """
@@ -787,29 +787,29 @@ class RepCPE(nn.Module):
         )
 
         self.spatial_shape = spatial_shape
-        self.embed_dim = embed_dim
-        self.in_chs = in_chs
-        self.groups = embed_dim
+        self.dim = dim
+        self.dim_out = dim_out or dim
+        self.groups = dim
 
         if inference_mode:
             self.reparam_conv = nn.Conv2d(
-                self.in_chs,
-                self.embed_dim,
+                self.dim,
+                self.dim_out,
                 kernel_size=self.spatial_shape,
                 stride=1,
                 padding=spatial_shape[0] // 2,
-                groups=self.embed_dim,
+                groups=self.groups,
                 bias=True,
             )
         else:
             self.reparam_conv = None
             self.pe = nn.Conv2d(
-                in_chs,
-                embed_dim,
+                self.dim,
+                self.dim_out,
                 spatial_shape,
                 1,
                 int(spatial_shape[0] // 2),
-                groups=embed_dim,
+                groups=self.groups,
                 bias=True,
             )
@@ -823,10 +823,10 @@ class RepCPE(nn.Module):
     def reparameterize(self) -> None:
         # Build equivalent Id tensor
-        input_dim = self.in_chs // self.groups
+        input_dim = self.dim // self.groups
         kernel_value = torch.zeros(
             (
-                self.in_chs,
+                self.dim,
                 input_dim,
                 self.spatial_shape[0],
                 self.spatial_shape[1],
@@ -834,7 +834,7 @@ class RepCPE(nn.Module):
             dtype=self.pe.weight.dtype,
             device=self.pe.weight.device,
         )
-        for i in range(self.in_chs):
+        for i in range(self.dim):
             kernel_value[
                 i,
                 i % input_dim,
@@ -849,12 +849,12 @@ class RepCPE(nn.Module):
         # Introduce reparam conv
         self.reparam_conv = nn.Conv2d(
-            self.in_chs,
-            self.embed_dim,
+            self.dim,
+            self.dim_out,
             kernel_size=self.spatial_shape,
             stride=1,
             padding=int(self.spatial_shape[0] // 2),
-            groups=self.embed_dim,
+            groups=self.groups,
             bias=True,
         )
         self.reparam_conv.weight.data = w_final
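
The reparameterize() hunks above are mechanical renames, but the underlying trick is worth spelling out: a depthwise conv whose kernel is zero everywhere except a 1 at the spatial center of each channel is an identity map, so adding it to the pe weights folds the residual pe(x) + x into a single conv. A self-contained sketch in plain PyTorch (not the timm code itself):

import torch
import torch.nn as nn

dim, k = 8, 7
pe = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim, bias=True)

# Identity kernel, same shape as the depthwise pe weight (dim, 1, k, k).
w_id = torch.zeros_like(pe.weight)
for i in range(dim):
    w_id[i, 0, k // 2, k // 2] = 1.0

# Fold the skip connection into one conv: w_final = w_pe + w_id.
reparam = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim, bias=True)
reparam.weight.data = pe.weight.data + w_id
reparam.bias.data = pe.bias.data

x = torch.randn(2, dim, 14, 14)
print(torch.allclose(reparam(x), pe(x) + x, atol=1e-6))  # True
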
@@ -1002,78 +1002,97 @@ class AttentionBlock(nn.Module):
         return x
 
 
-def basic_blocks(
-        dim: int,
-        block_index: int,
-        num_blocks: List[int],
-        token_mixer_type: str,
-        kernel_size: int = 3,
-        mlp_ratio: float = 4.0,
-        act_layer: nn.Module = nn.GELU,
-        norm_layer: nn.Module = nn.BatchNorm2d,
-        drop_rate: float = 0.0,
-        drop_path_rate: float = 0.0,
-        use_layer_scale: bool = True,
-        layer_scale_init_value: float = 1e-5,
-        inference_mode=False,
-) -> nn.Sequential:
-    """Build FastViT blocks within a stage.
-
-    Args:
-        dim: Number of embedding dimensions.
-        block_index: block index.
-        num_blocks: List containing number of blocks per stage.
-        token_mixer_type: Token mixer type.
-        kernel_size: Kernel size for repmixer.
-        mlp_ratio: MLP expansion ratio.
-        act_layer: Activation layer.
-        norm_layer: Normalization layer.
-        drop_rate: Dropout rate.
-        drop_path_rate: Drop path rate.
-        use_layer_scale: Flag to turn on layer scale regularization.
-        layer_scale_init_value: Layer scale value at initialization.
-        inference_mode: Flag to instantiate block in inference mode.
-
-    Returns:
-        nn.Sequential object of all the blocks within the stage.
-    """
-    blocks = []
-    for block_idx in range(num_blocks[block_index]):
-        block_dpr = (
-            drop_path_rate
-            * (block_idx + sum(num_blocks[:block_index]))
-            / (sum(num_blocks) - 1)
-        )
-        if token_mixer_type == "repmixer":
-            blocks.append(RepMixerBlock(
-                dim,
-                kernel_size=kernel_size,
-                mlp_ratio=mlp_ratio,
-                act_layer=act_layer,
-                drop=drop_rate,
-                drop_path=block_dpr,
-                use_layer_scale=use_layer_scale,
-                layer_scale_init_value=layer_scale_init_value,
-                inference_mode=inference_mode,
-            ))
-        elif token_mixer_type == "attention":
-            blocks.append(AttentionBlock(
-                dim,
-                mlp_ratio=mlp_ratio,
-                act_layer=act_layer,
-                norm_layer=norm_layer,
-                drop=drop_rate,
-                drop_path=block_dpr,
-                use_layer_scale=use_layer_scale,
-                layer_scale_init_value=layer_scale_init_value,
-            ))
-        else:
-            raise ValueError(
-                "Token mixer type: {} not supported".format(token_mixer_type)
-            )
-    blocks = nn.Sequential(*blocks)
-
-    return blocks
+class FastVitStage(nn.Module):
+    def __init__(
+            self,
+            dim: int,
+            dim_out: int,
+            depth: int,
+            token_mixer_type: str,
+            downsample: bool = True,
+            down_patch_size: int = 7,
+            down_stride: int = 2,
+            pos_emb_layer: Optional[nn.Module] = None,
+            kernel_size: int = 3,
+            mlp_ratio: float = 4.0,
+            act_layer: nn.Module = nn.GELU,
+            norm_layer: nn.Module = nn.BatchNorm2d,
+            drop_rate: float = 0.0,
+            drop_path_rate: float = 0.0,
+            use_layer_scale: bool = True,
+            layer_scale_init_value: float = 1e-5,
+            inference_mode=False,
+    ):
+        """FastViT stage.
+
+        Args:
+            dim: Number of embedding dimensions.
+            num_blocks: List containing number of blocks per stage.
+            token_mixer_type: Token mixer type.
+            kernel_size: Kernel size for repmixer.
+            mlp_ratio: MLP expansion ratio.
+            act_layer: Activation layer.
+            norm_layer: Normalization layer.
+            drop_rate: Dropout rate.
+            drop_path_rate: Drop path rate.
+            use_layer_scale: Flag to turn on layer scale regularization.
+            layer_scale_init_value: Layer scale value at initialization.
+            inference_mode: Flag to instantiate block in inference mode.
+        """
+        super().__init__()
+        if downsample:
+            self.downsample = PatchEmbed(
+                patch_size=down_patch_size,
+                stride=down_stride,
+                in_chs=dim,
+                embed_dim=dim_out,
+                inference_mode=inference_mode,
+            )
+        else:
+            assert dim == dim_out
+            self.downsample = nn.Identity()
+
+        if pos_emb_layer is not None:
+            self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode)
+        else:
+            self.pos_emb = nn.Identity()
+
+        blocks = []
+        for block_idx in range(depth):
+            if token_mixer_type == "repmixer":
+                blocks.append(RepMixerBlock(
+                    dim_out,
+                    kernel_size=kernel_size,
+                    mlp_ratio=mlp_ratio,
+                    act_layer=act_layer,
+                    drop=drop_rate,
+                    drop_path=drop_path_rate[block_idx],
+                    use_layer_scale=use_layer_scale,
+                    layer_scale_init_value=layer_scale_init_value,
+                    inference_mode=inference_mode,
+                ))
+            elif token_mixer_type == "attention":
+                blocks.append(AttentionBlock(
+                    dim_out,
+                    mlp_ratio=mlp_ratio,
+                    act_layer=act_layer,
+                    norm_layer=norm_layer,
+                    drop=drop_rate,
+                    drop_path=drop_path_rate[block_idx],
+                    use_layer_scale=use_layer_scale,
+                    layer_scale_init_value=layer_scale_init_value,
+                ))
+            else:
+                raise ValueError(
+                    "Token mixer type: {} not supported".format(token_mixer_type)
+                )
+        self.blocks = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        x = self.downsample(x)
+        x = self.pos_emb(x)
+        x = self.blocks(x)
+        return x
 
 
 class FastVit(nn.Module):
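
With the stage refactor above, downsampling moves from between-stage PatchEmbed modules into the top of each stage, and the per-stage forward is simply downsample -> pos_emb -> blocks. A hedged usage sketch (assumes FastVitStage is importable from timm.models.fastvit at this revision; note that drop_path_rate is now indexed per block, so it expects a sequence of length depth):

import torch
from timm.models.fastvit import FastVitStage  # import path assumed

stage = FastVitStage(
    dim=64,
    dim_out=128,
    depth=2,
    token_mixer_type="repmixer",
    downsample=True,             # PatchEmbed at the top of the stage
    drop_path_rate=[0.0, 0.05],  # one rate per block, not a scalar
)
x = torch.randn(1, 64, 56, 56)
print(stage(x).shape)  # torch.Size([1, 128, 28, 28]) with the default stride-2 downsample

When downsample=False, the stage asserts dim == dim_out and substitutes nn.Identity(), so the channel and spatial dimensions pass through unchanged.
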
@@ -1085,78 +1104,66 @@ class FastVit(nn.Module):
     def __init__(
             self,
-            in_chans=3,
-            layers=(2, 2, 6, 2),
+            in_chans: int = 3,
+            layers: Tuple[int, ...] = (2, 2, 6, 2),
             token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
-            embed_dims=None,
-            mlp_ratios=None,
-            downsamples=None,
-            repmixer_kernel_size=3,
+            embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
+            mlp_ratios: Tuple[float, ...] = (4,) * 4,
+            downsamples: Tuple[bool, ...] = (False, True, True, True),
+            repmixer_kernel_size: int = 3,
+            num_classes: int = 1000,
+            pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
+            down_patch_size: int = 7,
+            down_stride: int = 2,
+            drop_rate: float = 0.0,
+            drop_path_rate: float = 0.0,
+            use_layer_scale: bool = True,
+            layer_scale_init_value: float = 1e-5,
+            fork_feat: bool = False,
+            cls_ratio: float = 2.0,
             norm_layer: nn.Module = nn.BatchNorm2d,
             act_layer: nn.Module = nn.GELU,
-            num_classes=1000,
-            pos_embs=None,
-            down_patch_size=7,
-            down_stride=2,
-            drop_rate=0.0,
-            drop_path_rate=0.0,
-            use_layer_scale=True,
-            layer_scale_init_value=1e-5,
-            fork_feat=False,
-            cls_ratio=2.0,
-            inference_mode=False,
+            inference_mode: bool = False,
     ) -> None:
         super().__init__()
         self.num_classes = 0 if fork_feat else num_classes
         self.fork_feat = fork_feat
 
-        if pos_embs is None:
-            pos_embs = [None] * len(layers)
-
         # Convolutional stem
         self.patch_embed = convolutional_stem(
-            in_chans, embed_dims[0], inference_mode)
+            in_chans,
+            embed_dims[0],
+            inference_mode,
+        )
 
         # Build the main stages of the network architecture
+        prev_dim = embed_dims[0]
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
         network = []
         for i in range(len(layers)):
-            # Add position embeddings if requested
-            if pos_embs[i] is not None:
-                network.append(pos_embs[i](
-                    embed_dims[i],
-                    embed_dims[i],
-                    inference_mode=inference_mode,
-                ))
-            stage = basic_blocks(
-                embed_dims[i],
-                i,
-                layers,
+            stage = FastVitStage(
+                dim=prev_dim,
+                dim_out=embed_dims[i],
+                depth=layers[i],
+                downsample=downsamples[i] or prev_dim != embed_dims[i],
+                down_patch_size=down_patch_size,
+                down_stride=down_stride,
+                pos_emb_layer=pos_embs[i],
                 token_mixer_type=token_mixers[i],
                 kernel_size=repmixer_kernel_size,
                 mlp_ratio=mlp_ratios[i],
                 act_layer=act_layer,
                 norm_layer=norm_layer,
                 drop_rate=drop_rate,
-                drop_path_rate=drop_path_rate,
+                drop_path_rate=dpr[i],
                 use_layer_scale=use_layer_scale,
                 layer_scale_init_value=layer_scale_init_value,
                 inference_mode=inference_mode,
             )
             network.append(stage)
-            if i >= len(layers) - 1:
-                break
-
-            # Patch merging/downsampling between stages.
-            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
-                network += [PatchEmbed(
-                    patch_size=down_patch_size,
-                    stride=down_stride,
-                    in_chs=embed_dims[i],
-                    embed_dim=embed_dims[i + 1],
-                    inference_mode=inference_mode,
-                )]
-        self.network = nn.ModuleList(network)
+            prev_dim = embed_dims[i]
+        self.network = nn.Sequential(*network)
 
         # For segmentation and detection, extract intermediate output
         if self.fork_feat:
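
The new dpr computation above yields the same per-block rates as the old closed-form block_dpr (a linspace from 0 to drop_path_rate over the total depth), just pre-split into per-stage lists so each FastVitStage can index them locally. A quick check:

import torch

layers, drop_path_rate = (2, 2, 6, 2), 0.1
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
print([len(d) for d in dpr])  # [2, 2, 6, 2], one list per stage
print(dpr[0])                 # first-stage rates, starting at 0.0
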
@@ -1338,7 +1345,6 @@ def fastvit_t8(pretrained=False, **kwargs):
         layers=(2, 2, 4, 2),
         embed_dims=(48, 96, 192, 384),
         mlp_ratios=(3, 3, 3, 3),
-        downsamples=(True, True, True, True),
         token_mixers=("repmixer", "repmixer", "repmixer", "repmixer")
     )
     return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, **kwargs))
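
Dropping downsamples=(True, True, True, True) from the factory configs below relies on the new constructor default (False, True, True, True) together with the downsample=downsamples[i] or prev_dim != embed_dims[i] guard: the flag now refers to the top of each stage rather than the transition after it, and since the stem already produces embed_dims[0], stage 0 defaults to no downsample. A hedged smoke test, assuming fastvit_t8 is registered at this revision:

import timm
import torch

model = timm.create_model('fastvit_t8', num_classes=10)  # model name assumed registered
out = model(torch.randn(1, 3, 256, 256))                 # FastViT default 256x256 input
print(out.shape)  # torch.Size([1, 10])
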
@@ -1351,8 +1357,7 @@ def fastvit_t12(pretrained=False, **kwargs):
         layers=(2, 2, 6, 2),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(3, 3, 3, 3),
-        downsamples=(True, True, True, True),
-        token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer"),
+        token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
     )
     return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -1364,7 +1369,6 @@ def fastvit_s12(pretrained=False, **kwargs):
         layers=(2, 2, 6, 2),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
     )
     return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -1377,7 +1381,6 @@ def fastvit_sa12(pretrained=False, **kwargs):
         layers=(2, 2, 6, 2),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
@@ -1391,7 +1394,6 @@ def fastvit_sa24(pretrained=False, **kwargs):
         layers=(4, 4, 12, 4),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
@@ -1405,7 +1407,6 @@ def fastvit_sa36(pretrained=False, **kwargs):
         layers=(6, 6, 18, 6),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
@@ -1418,7 +1419,6 @@ def fastvit_ma36(pretrained=False, **kwargs):
         layers=(6, 6, 18, 6),
         embed_dims=(76, 152, 304, 608),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention")
     )