Stagify FastViT /w downsample to top of stage

This commit is contained in:
Ross Wightman 2023-08-22 15:19:12 -07:00 committed by Ross Wightman
parent 8470eb1cb5
commit 40dbaafef5

View File

@ -761,16 +761,16 @@ class RepCPE(nn.Module):
def __init__(
self,
in_chs: int,
embed_dim: int = 768,
dim: int,
dim_out: Optional[int] = None,
spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
inference_mode=False,
) -> None:
"""Build reparameterizable conditional positional encoding
Args:
in_chs: Number of input channels.
embed_dim: Number of embedding dimensions. Default: 768
dim: Number of input channels.
dim_out: Number of embedding dimensions. Default: 768
spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
inference_mode: Flag to instantiate block in inference mode. Default: ``False``
"""
@ -787,29 +787,29 @@ class RepCPE(nn.Module):
)
self.spatial_shape = spatial_shape
self.embed_dim = embed_dim
self.in_chs = in_chs
self.groups = embed_dim
self.dim = dim
self.dim_out = dim_out or dim
self.groups = dim
if inference_mode:
self.reparam_conv = nn.Conv2d(
self.in_chs,
self.embed_dim,
self.dim,
self.dim_out,
kernel_size=self.spatial_shape,
stride=1,
padding=spatial_shape[0] // 2,
groups=self.embed_dim,
groups=self.groups,
bias=True,
)
else:
self.reparam_conv = None
self.pe = nn.Conv2d(
in_chs,
embed_dim,
self.dim,
self.dim_out,
spatial_shape,
1,
int(spatial_shape[0] // 2),
groups=embed_dim,
groups=self.groups,
bias=True,
)
@ -823,10 +823,10 @@ class RepCPE(nn.Module):
def reparameterize(self) -> None:
# Build equivalent Id tensor
input_dim = self.in_chs // self.groups
input_dim = self.dim // self.groups
kernel_value = torch.zeros(
(
self.in_chs,
self.dim,
input_dim,
self.spatial_shape[0],
self.spatial_shape[1],
@ -834,7 +834,7 @@ class RepCPE(nn.Module):
dtype=self.pe.weight.dtype,
device=self.pe.weight.device,
)
for i in range(self.in_chs):
for i in range(self.dim):
kernel_value[
i,
i % input_dim,
@ -849,12 +849,12 @@ class RepCPE(nn.Module):
# Introduce reparam conv
self.reparam_conv = nn.Conv2d(
self.in_chs,
self.embed_dim,
self.dim,
self.dim_out,
kernel_size=self.spatial_shape,
stride=1,
padding=int(self.spatial_shape[0] // 2),
groups=self.embed_dim,
groups=self.groups,
bias=True,
)
self.reparam_conv.weight.data = w_final
@ -1002,78 +1002,97 @@ class AttentionBlock(nn.Module):
return x
def basic_blocks(
dim: int,
block_index: int,
num_blocks: List[int],
token_mixer_type: str,
kernel_size: int = 3,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.BatchNorm2d,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
use_layer_scale: bool = True,
layer_scale_init_value: float = 1e-5,
inference_mode=False,
) -> nn.Sequential:
"""Build FastViT blocks within a stage.
class FastVitStage(nn.Module):
def __init__(
self,
dim: int,
dim_out: int,
depth: int,
token_mixer_type: str,
downsample: bool = True,
down_patch_size: int = 7,
down_stride: int = 2,
pos_emb_layer: Optional[nn.Module] = None,
kernel_size: int = 3,
mlp_ratio: float = 4.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.BatchNorm2d,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
use_layer_scale: bool = True,
layer_scale_init_value: float = 1e-5,
inference_mode=False,
):
"""FastViT stage.
Args:
dim: Number of embedding dimensions.
block_index: block index.
num_blocks: List containing number of blocks per stage.
token_mixer_type: Token mixer type.
kernel_size: Kernel size for repmixer.
mlp_ratio: MLP expansion ratio.
act_layer: Activation layer.
norm_layer: Normalization layer.
drop_rate: Dropout rate.
drop_path_rate: Drop path rate.
use_layer_scale: Flag to turn on layer scale regularization.
layer_scale_init_value: Layer scale value at initialization.
inference_mode: Flag to instantiate block in inference mode.
Returns:
nn.Sequential object of all the blocks within the stage.
"""
blocks = []
for block_idx in range(num_blocks[block_index]):
block_dpr = (
drop_path_rate
* (block_idx + sum(num_blocks[:block_index]))
/ (sum(num_blocks) - 1)
)
if token_mixer_type == "repmixer":
blocks.append(RepMixerBlock(
dim,
kernel_size=kernel_size,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
drop=drop_rate,
drop_path=block_dpr,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
Args:
dim: Number of embedding dimensions.
num_blocks: List containing number of blocks per stage.
token_mixer_type: Token mixer type.
kernel_size: Kernel size for repmixer.
mlp_ratio: MLP expansion ratio.
act_layer: Activation layer.
norm_layer: Normalization layer.
drop_rate: Dropout rate.
drop_path_rate: Drop path rate.
use_layer_scale: Flag to turn on layer scale regularization.
layer_scale_init_value: Layer scale value at initialization.
inference_mode: Flag to instantiate block in inference mode.
"""
super().__init__()
if downsample:
self.downsample = PatchEmbed(
patch_size=down_patch_size,
stride=down_stride,
in_chs=dim,
embed_dim=dim_out,
inference_mode=inference_mode,
))
elif token_mixer_type == "attention":
blocks.append(AttentionBlock(
dim,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer,
drop=drop_rate,
drop_path=block_dpr,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
))
else:
raise ValueError(
"Token mixer type: {} not supported".format(token_mixer_type)
)
blocks = nn.Sequential(*blocks)
else:
assert dim == dim_out
self.downsample = nn.Identity()
return blocks
if pos_emb_layer is not None:
self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode)
else:
self.pos_emb = nn.Identity()
blocks = []
for block_idx in range(depth):
if token_mixer_type == "repmixer":
blocks.append(RepMixerBlock(
dim_out,
kernel_size=kernel_size,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
drop=drop_rate,
drop_path=drop_path_rate[block_idx],
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
inference_mode=inference_mode,
))
elif token_mixer_type == "attention":
blocks.append(AttentionBlock(
dim_out,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer,
drop=drop_rate,
drop_path=drop_path_rate[block_idx],
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
))
else:
raise ValueError(
"Token mixer type: {} not supported".format(token_mixer_type)
)
self.blocks = nn.Sequential(*blocks)
def forward(self, x):
x = self.downsample(x)
x = self.pos_emb(x)
x = self.blocks(x)
return x
class FastVit(nn.Module):
@ -1085,78 +1104,66 @@ class FastVit(nn.Module):
def __init__(
self,
in_chans=3,
layers=(2, 2, 6, 2),
in_chans: int = 3,
layers: Tuple[int, ...] = (2, 2, 6, 2),
token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
embed_dims=None,
mlp_ratios=None,
downsamples=None,
repmixer_kernel_size=3,
embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
mlp_ratios: Tuple[float, ...] = (4,) * 4,
downsamples: Tuple[bool, ...] = (False, True, True, True),
repmixer_kernel_size: int = 3,
num_classes: int = 1000,
pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
down_patch_size: int = 7,
down_stride: int = 2,
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
use_layer_scale: bool = True,
layer_scale_init_value: float = 1e-5,
fork_feat: bool = False,
cls_ratio: float = 2.0,
norm_layer: nn.Module = nn.BatchNorm2d,
act_layer: nn.Module = nn.GELU,
num_classes=1000,
pos_embs=None,
down_patch_size=7,
down_stride=2,
drop_rate=0.0,
drop_path_rate=0.0,
use_layer_scale=True,
layer_scale_init_value=1e-5,
fork_feat=False,
cls_ratio=2.0,
inference_mode=False,
inference_mode: bool = False,
) -> None:
super().__init__()
self.num_classes = 0 if fork_feat else num_classes
self.fork_feat = fork_feat
if pos_embs is None:
pos_embs = [None] * len(layers)
# Convolutional stem
self.patch_embed = convolutional_stem(
in_chans, embed_dims[0], inference_mode)
in_chans,
embed_dims[0],
inference_mode,
)
# Build the main stages of the network architecture
prev_dim = embed_dims[0]
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
network = []
for i in range(len(layers)):
# Add position embeddings if requested
if pos_embs[i] is not None:
network.append(pos_embs[i](
embed_dims[i],
embed_dims[i],
inference_mode=inference_mode,
))
stage = basic_blocks(
embed_dims[i],
i,
layers,
stage = FastVitStage(
dim=prev_dim,
dim_out=embed_dims[i],
depth=layers[i],
downsample=downsamples[i] or prev_dim != embed_dims[i],
down_patch_size=down_patch_size,
down_stride=down_stride,
pos_emb_layer=pos_embs[i],
token_mixer_type=token_mixers[i],
kernel_size=repmixer_kernel_size,
mlp_ratio=mlp_ratios[i],
act_layer=act_layer,
norm_layer=norm_layer,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
drop_path_rate=dpr[i],
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
inference_mode=inference_mode,
)
network.append(stage)
if i >= len(layers) - 1:
break
prev_dim = embed_dims[i]
# Patch merging/downsampling between stages.
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
network += [PatchEmbed(
patch_size=down_patch_size,
stride=down_stride,
in_chs=embed_dims[i],
embed_dim=embed_dims[i + 1],
inference_mode=inference_mode,
)]
self.network = nn.ModuleList(network)
self.network = nn.Sequential(*network)
# For segmentation and detection, extract intermediate output
if self.fork_feat:
@ -1338,7 +1345,6 @@ def fastvit_t8(pretrained=False, **kwargs):
layers=(2, 2, 4, 2),
embed_dims=(48, 96, 192, 384),
mlp_ratios=(3, 3, 3, 3),
downsamples=(True, True, True, True),
token_mixers=("repmixer", "repmixer", "repmixer", "repmixer")
)
return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, **kwargs))
@ -1351,8 +1357,7 @@ def fastvit_t12(pretrained=False, **kwargs):
layers=(2, 2, 6, 2),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(3, 3, 3, 3),
downsamples=(True, True, True, True),
token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer"),
token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
)
return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs))
@ -1364,7 +1369,6 @@ def fastvit_s12(pretrained=False, **kwargs):
layers=(2, 2, 6, 2),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(4, 4, 4, 4),
downsamples=(True, True, True, True),
token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
)
return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs))
@ -1377,7 +1381,6 @@ def fastvit_sa12(pretrained=False, **kwargs):
layers=(2, 2, 6, 2),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(4, 4, 4, 4),
downsamples=(True, True, True, True),
pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
)
@ -1391,7 +1394,6 @@ def fastvit_sa24(pretrained=False, **kwargs):
layers=(4, 4, 12, 4),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(4, 4, 4, 4),
downsamples=(True, True, True, True),
pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
)
@ -1405,7 +1407,6 @@ def fastvit_sa36(pretrained=False, **kwargs):
layers=(6, 6, 18, 6),
embed_dims=(64, 128, 256, 512),
mlp_ratios=(4, 4, 4, 4),
downsamples=(True, True, True, True),
pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
)
@ -1418,7 +1419,6 @@ def fastvit_ma36(pretrained=False, **kwargs):
layers=(6, 6, 18, 6),
embed_dims=(76, 152, 304, 608),
mlp_ratios=(4, 4, 4, 4),
downsamples=(True, True, True, True),
pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
token_mixers=("repmixer", "repmixer", "repmixer", "attention")
)