From 40dbaafef564b9abf270d0a6ad988dddcbf41b25 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 22 Aug 2023 15:19:12 -0700
Subject: [PATCH] Stagify FastViT /w downsample to top of stage

---
 timm/models/fastvit.py | 282 ++++++++++++++++++++---------------------
 1 file changed, 141 insertions(+), 141 deletions(-)

diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py
index 133cb39d..8e5051de 100644
--- a/timm/models/fastvit.py
+++ b/timm/models/fastvit.py
@@ -761,16 +761,16 @@ class RepCPE(nn.Module):
 
     def __init__(
         self,
-        in_chs: int,
-        embed_dim: int = 768,
+        dim: int,
+        dim_out: Optional[int] = None,
         spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
         inference_mode=False,
     ) -> None:
         """Build reparameterizable conditional positional encoding
 
         Args:
-            in_chs: Number of input channels.
-            embed_dim: Number of embedding dimensions. Default: 768
+            dim: Number of input channels.
+            dim_out: Number of output channels. Defaults to `dim` when None.
             spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
             inference_mode: Flag to instantiate block in inference mode. Default: ``False``
         """
@@ -787,29 +787,29 @@ class RepCPE(nn.Module):
         )
 
         self.spatial_shape = spatial_shape
-        self.embed_dim = embed_dim
-        self.in_chs = in_chs
-        self.groups = embed_dim
+        self.dim = dim
+        self.dim_out = dim_out or dim
+        self.groups = dim
 
         if inference_mode:
             self.reparam_conv = nn.Conv2d(
-                self.in_chs,
-                self.embed_dim,
+                self.dim,
+                self.dim_out,
                 kernel_size=self.spatial_shape,
                 stride=1,
                 padding=spatial_shape[0] // 2,
-                groups=self.embed_dim,
+                groups=self.groups,
                 bias=True,
             )
         else:
             self.reparam_conv = None
             self.pe = nn.Conv2d(
-                in_chs,
-                embed_dim,
+                self.dim,
+                self.dim_out,
                 spatial_shape,
                 1,
                 int(spatial_shape[0] // 2),
-                groups=embed_dim,
+                groups=self.groups,
                 bias=True,
             )
 
@@ -823,10 +823,10 @@ class RepCPE(nn.Module):
 
     def reparameterize(self) -> None:
         # Build equivalent Id tensor
-        input_dim = self.in_chs // self.groups
+        input_dim = self.dim // self.groups
         kernel_value = torch.zeros(
             (
-                self.in_chs,
+                self.dim,
                 input_dim,
                 self.spatial_shape[0],
                 self.spatial_shape[1],
@@ -834,7 +834,7 @@
             dtype=self.pe.weight.dtype,
             device=self.pe.weight.device,
         )
-        for i in range(self.in_chs):
+        for i in range(self.dim):
             kernel_value[
                 i,
                 i % input_dim,
@@ -849,12 +849,12 @@ class RepCPE(nn.Module):
 
         # Introduce reparam conv
         self.reparam_conv = nn.Conv2d(
-            self.in_chs,
-            self.embed_dim,
+            self.dim,
+            self.dim_out,
             kernel_size=self.spatial_shape,
             stride=1,
             padding=int(self.spatial_shape[0] // 2),
-            groups=self.embed_dim,
+            groups=self.groups,
             bias=True,
         )
         self.reparam_conv.weight.data = w_final
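For context on the hunks above: RepCPE applies a depthwise convolutional positional encoding on a residual branch at train time (pe(x) + x), and reparameterize() folds the identity branch into the conv weights so inference runs a single fused depthwise conv. A minimal equivalence check, assuming a timm checkout with this patch applied (RepCPE and the dim/dim_out names come straight from the diff):

    import torch
    from timm.models.fastvit import RepCPE

    cpe = RepCPE(dim=64, spatial_shape=(7, 7))  # dim_out defaults to dim

    x = torch.randn(1, 64, 14, 14)
    with torch.no_grad():
        y_train = cpe(x)        # residual form: pe(x) + x
        cpe.reparameterize()    # fold the identity into a single depthwise conv
        y_fused = cpe(x)        # fused form: reparam_conv(x)

    print(torch.allclose(y_train, y_fused, atol=1e-5))  # expect True
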
@@ -1002,78 +1002,97 @@ class AttentionBlock(nn.Module):
         return x
 
 
-def basic_blocks(
-    dim: int,
-    block_index: int,
-    num_blocks: List[int],
-    token_mixer_type: str,
-    kernel_size: int = 3,
-    mlp_ratio: float = 4.0,
-    act_layer: nn.Module = nn.GELU,
-    norm_layer: nn.Module = nn.BatchNorm2d,
-    drop_rate: float = 0.0,
-    drop_path_rate: float = 0.0,
-    use_layer_scale: bool = True,
-    layer_scale_init_value: float = 1e-5,
-    inference_mode=False,
-) -> nn.Sequential:
-    """Build FastViT blocks within a stage.
+class FastVitStage(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: int,
+        depth: int,
+        token_mixer_type: str,
+        downsample: bool = True,
+        down_patch_size: int = 7,
+        down_stride: int = 2,
+        pos_emb_layer: Optional[nn.Module] = None,
+        kernel_size: int = 3,
+        mlp_ratio: float = 4.0,
+        act_layer: nn.Module = nn.GELU,
+        norm_layer: nn.Module = nn.BatchNorm2d,
+        drop_rate: float = 0.0,
+        drop_path_rate: Union[float, List[float]] = 0.0,
+        use_layer_scale: bool = True,
+        layer_scale_init_value: float = 1e-5,
+        inference_mode=False,
+    ):
+        """FastViT stage.
 
-    Args:
-        dim: Number of embedding dimensions.
-        block_index: block index.
-        num_blocks: List containing number of blocks per stage.
-        token_mixer_type: Token mixer type.
-        kernel_size: Kernel size for repmixer.
-        mlp_ratio: MLP expansion ratio.
-        act_layer: Activation layer.
-        norm_layer: Normalization layer.
-        drop_rate: Dropout rate.
-        drop_path_rate: Drop path rate.
-        use_layer_scale: Flag to turn on layer scale regularization.
-        layer_scale_init_value: Layer scale value at initialization.
-        inference_mode: Flag to instantiate block in inference mode.
-
-    Returns:
-        nn.Sequential object of all the blocks within the stage.
-    """
-    blocks = []
-    for block_idx in range(num_blocks[block_index]):
-        block_dpr = (
-            drop_path_rate
-            * (block_idx + sum(num_blocks[:block_index]))
-            / (sum(num_blocks) - 1)
-        )
-        if token_mixer_type == "repmixer":
-            blocks.append(RepMixerBlock(
-                dim,
-                kernel_size=kernel_size,
-                mlp_ratio=mlp_ratio,
-                act_layer=act_layer,
-                drop=drop_rate,
-                drop_path=block_dpr,
-                use_layer_scale=use_layer_scale,
-                layer_scale_init_value=layer_scale_init_value,
+        Args:
+            dim: Number of input channels.
+            dim_out: Number of output channels.
+            depth: Number of blocks in the stage.
+            token_mixer_type: Token mixer type, "repmixer" or "attention".
+            downsample: Flag to downsample at the start of the stage.
+            down_patch_size: Patch size for the downsample patch embedding.
+            down_stride: Stride for the downsample patch embedding.
+            pos_emb_layer: Optional positional encoding layer applied after downsampling.
+            kernel_size: Kernel size for repmixer.
+            mlp_ratio: MLP expansion ratio.
+            act_layer: Activation layer.
+            norm_layer: Normalization layer.
+            drop_rate: Dropout rate.
+            drop_path_rate: Drop path rate(s), one per block; a scalar is broadcast.
+            use_layer_scale: Flag to turn on layer scale regularization.
+            layer_scale_init_value: Layer scale value at initialization.
+            inference_mode: Flag to instantiate block in inference mode.
+        """
+        super().__init__()
+        if downsample:
+            self.downsample = PatchEmbed(
+                patch_size=down_patch_size,
+                stride=down_stride,
+                in_chs=dim,
+                embed_dim=dim_out,
                 inference_mode=inference_mode,
-            ))
-        elif token_mixer_type == "attention":
-            blocks.append(AttentionBlock(
-                dim,
-                mlp_ratio=mlp_ratio,
-                act_layer=act_layer,
-                norm_layer=norm_layer,
-                drop=drop_rate,
-                drop_path=block_dpr,
-                use_layer_scale=use_layer_scale,
-                layer_scale_init_value=layer_scale_init_value,
-            ))
-        else:
-            raise ValueError(
-                "Token mixer type: {} not supported".format(token_mixer_type)
             )
-    blocks = nn.Sequential(*blocks)
+        else:
+            assert dim == dim_out
+            self.downsample = nn.Identity()
 
-    return blocks
+        if pos_emb_layer is not None:
+            self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode)
+        else:
+            self.pos_emb = nn.Identity()
+
+        if not isinstance(drop_path_rate, (list, tuple)):
+            # Broadcast a scalar drop path rate to one value per block.
+            drop_path_rate = [drop_path_rate] * depth
+        blocks = []
+        for block_idx in range(depth):
+            if token_mixer_type == "repmixer":
+                blocks.append(RepMixerBlock(
+                    dim_out,
+                    kernel_size=kernel_size,
+                    mlp_ratio=mlp_ratio,
+                    act_layer=act_layer,
+                    drop=drop_rate,
+                    drop_path=drop_path_rate[block_idx],
+                    use_layer_scale=use_layer_scale,
+                    layer_scale_init_value=layer_scale_init_value,
+                    inference_mode=inference_mode,
+                ))
+            elif token_mixer_type == "attention":
+                blocks.append(AttentionBlock(
+                    dim_out,
+                    mlp_ratio=mlp_ratio,
+                    act_layer=act_layer,
+                    norm_layer=norm_layer,
+                    drop=drop_rate,
+                    drop_path=drop_path_rate[block_idx],
+                    use_layer_scale=use_layer_scale,
+                    layer_scale_init_value=layer_scale_init_value,
+                ))
+            else:
+                raise ValueError(
+                    "Token mixer type: {} not supported".format(token_mixer_type)
+                )
+        self.blocks = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        x = self.downsample(x)
+        x = self.pos_emb(x)
+        x = self.blocks(x)
+        return x
 
 
 class FastVit(nn.Module):
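The new FastVitStage bundles the optional downsample, positional encoding, and block stack into one module, so a stage can now be exercised in isolation. A small sketch of standalone use; the 2x spatial reduction comes from the stride-2 PatchEmbed shown in the hunk above:

    import torch
    from timm.models.fastvit import FastVitStage

    # Stage that downsamples 2x and widens 64 -> 128, matching the new
    # downsample-at-top-of-stage layout.
    stage = FastVitStage(
        dim=64,
        dim_out=128,
        depth=2,
        token_mixer_type="repmixer",
        downsample=True,
        drop_path_rate=[0.0, 0.05],  # one rate per block
    )

    x = torch.randn(1, 64, 56, 56)
    print(stage(x).shape)  # torch.Size([1, 128, 28, 28])
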
+ """ + super().__init__() + if downsample: + self.downsample = PatchEmbed( + patch_size=down_patch_size, + stride=down_stride, + in_chs=dim, + embed_dim=dim_out, inference_mode=inference_mode, - )) - elif token_mixer_type == "attention": - blocks.append(AttentionBlock( - dim, - mlp_ratio=mlp_ratio, - act_layer=act_layer, - norm_layer=norm_layer, - drop=drop_rate, - drop_path=block_dpr, - use_layer_scale=use_layer_scale, - layer_scale_init_value=layer_scale_init_value, - )) - else: - raise ValueError( - "Token mixer type: {} not supported".format(token_mixer_type) ) - blocks = nn.Sequential(*blocks) + else: + assert dim == dim_out + self.downsample = nn.Identity() - return blocks + if pos_emb_layer is not None: + self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode) + else: + self.pos_emb = nn.Identity() + + blocks = [] + for block_idx in range(depth): + if token_mixer_type == "repmixer": + blocks.append(RepMixerBlock( + dim_out, + kernel_size=kernel_size, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + drop=drop_rate, + drop_path=drop_path_rate[block_idx], + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + inference_mode=inference_mode, + )) + elif token_mixer_type == "attention": + blocks.append(AttentionBlock( + dim_out, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + drop=drop_rate, + drop_path=drop_path_rate[block_idx], + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + )) + else: + raise ValueError( + "Token mixer type: {} not supported".format(token_mixer_type) + ) + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + x = self.downsample(x) + x = self.pos_emb(x) + x = self.blocks(x) + return x class FastVit(nn.Module): @@ -1085,78 +1104,66 @@ class FastVit(nn.Module): def __init__( self, - in_chans=3, - layers=(2, 2, 6, 2), + in_chans: int = 3, + layers: Tuple[int, ...] = (2, 2, 6, 2), token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"), - embed_dims=None, - mlp_ratios=None, - downsamples=None, - repmixer_kernel_size=3, + embed_dims: Tuple[int, ...] = (64, 128, 256, 512), + mlp_ratios: Tuple[float, ...] = (4,) * 4, + downsamples: Tuple[bool, ...] = (False, True, True, True), + repmixer_kernel_size: int = 3, + num_classes: int = 1000, + pos_embs: Tuple[Optional[nn.Module], ...] 
@@ -1338,7 +1345,6 @@ def fastvit_t8(pretrained=False, **kwargs):
         layers=(2, 2, 4, 2),
         embed_dims=(48, 96, 192, 384),
         mlp_ratios=(3, 3, 3, 3),
-        downsamples=(True, True, True, True),
         token_mixers=("repmixer", "repmixer", "repmixer", "repmixer")
     )
     return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -1351,8 +1357,7 @@ def fastvit_t12(pretrained=False, **kwargs):
         layers=(2, 2, 6, 2),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(3, 3, 3, 3),
-        downsamples=(True, True, True, True),
-        token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer"),
+        token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
     )
     return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -1364,7 +1369,6 @@ def fastvit_s12(pretrained=False, **kwargs):
         layers=(2, 2, 6, 2),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
     )
     return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -1377,7 +1381,6 @@ def fastvit_sa12(pretrained=False, **kwargs):
         layers=(2, 2, 6, 2),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
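The downsamples=(True, True, True, True) lines removed from these configs are redundant rather than lost. Old semantics: downsamples[i] inserted a PatchEmbed after stage i. New semantics: the constructor default (False, True, True, True), plus the prev_dim != embed_dims[i] check, puts a downsample at the top of stages 1-3, which lands in exactly the same three positions; stage 0 never needs one because the stem already emits embed_dims[0] at stride 4. A quick smoke test against the registrations above:

    import timm
    import torch

    model = timm.create_model("fastvit_sa12", pretrained=False)
    x = torch.randn(1, 3, 256, 256)
    print(model(x).shape)  # torch.Size([1, 1000])
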
@@ -1391,7 +1394,6 @@ def fastvit_sa24(pretrained=False, **kwargs):
         layers=(4, 4, 12, 4),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
@@ -1405,7 +1407,6 @@ def fastvit_sa36(pretrained=False, **kwargs):
         layers=(6, 6, 18, 6),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
@@ -1418,7 +1419,6 @@ def fastvit_ma36(pretrained=False, **kwargs):
         layers=(6, 6, 18, 6),
         embed_dims=(76, 152, 304, 608),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention")
     )
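End to end, nothing about deployment-time fusion changes: the MobileOne stem, RepMixer blocks, RepCPE, and the large-kernel downsample convs all still expose a reparameterize() hook. A hedged sketch of whole-model fusion following the pattern from the reference FastViT release (the reparameterize_model helper below is illustrative, not part of this patch):

    import copy

    import timm
    import torch

    def reparameterize_model(model: torch.nn.Module) -> torch.nn.Module:
        # Fuse every branch that exposes a reparameterize() hook.
        model = copy.deepcopy(model)
        for module in model.modules():
            if hasattr(module, "reparameterize"):
                module.reparameterize()
        return model

    model = timm.create_model("fastvit_t8", pretrained=False).eval()
    fused = reparameterize_model(model)

    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        # Loose tolerance: fused float32 arithmetic differs slightly.
        print(torch.allclose(model(x), fused(x), atol=1e-4))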