Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Stagify FastViT /w downsample to top of stage
parent 8470eb1cb5
commit 40dbaafef5
@@ -761,16 +761,16 @@ class RepCPE(nn.Module):
 
     def __init__(
             self,
-            in_chs: int,
-            embed_dim: int = 768,
+            dim: int,
+            dim_out: Optional[int] = None,
             spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
             inference_mode=False,
     ) -> None:
         """Build reparameterizable conditional positional encoding
 
         Args:
-            in_chs: Number of input channels.
-            embed_dim: Number of embedding dimensions. Default: 768
+            dim: Number of input channels.
+            dim_out: Number of embedding dimensions. Default: 768
             spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
             inference_mode: Flag to instantiate block in inference mode. Default: ``False``
         """
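A minimal sketch (not part of the diff) of how the renamed arguments line up with the pos_embs configs further down: the registered models pass partial(RepCPE, spatial_shape=(7, 7)), and the stage then supplies only a single width, with dim_out defaulting to dim. The width below is arbitrary, and RepCPE is assumed to be in scope (e.g. from timm.models.fastvit).

from functools import partial

# Hypothetical illustration; 512 is an arbitrary width.
pos_emb_layer = partial(RepCPE, spatial_shape=(7, 7))
pe = pos_emb_layer(512, inference_mode=False)   # dim=512, dim_out defaults to dim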
@@ -787,29 +787,29 @@ class RepCPE(nn.Module):
         )
 
         self.spatial_shape = spatial_shape
-        self.embed_dim = embed_dim
-        self.in_chs = in_chs
-        self.groups = embed_dim
+        self.dim = dim
+        self.dim_out = dim_out or dim
+        self.groups = dim
 
         if inference_mode:
             self.reparam_conv = nn.Conv2d(
-                self.in_chs,
-                self.embed_dim,
+                self.dim,
+                self.dim_out,
                 kernel_size=self.spatial_shape,
                 stride=1,
                 padding=spatial_shape[0] // 2,
-                groups=self.embed_dim,
+                groups=self.groups,
                 bias=True,
             )
         else:
             self.reparam_conv = None
             self.pe = nn.Conv2d(
-                in_chs,
-                embed_dim,
+                self.dim,
+                self.dim_out,
                 spatial_shape,
                 1,
                 int(spatial_shape[0] // 2),
-                groups=embed_dim,
+                groups=self.groups,
                 bias=True,
             )
 
@@ -823,10 +823,10 @@ class RepCPE(nn.Module):
 
     def reparameterize(self) -> None:
         # Build equivalent Id tensor
-        input_dim = self.in_chs // self.groups
+        input_dim = self.dim // self.groups
         kernel_value = torch.zeros(
             (
-                self.in_chs,
+                self.dim,
                 input_dim,
                 self.spatial_shape[0],
                 self.spatial_shape[1],
@@ -834,7 +834,7 @@ class RepCPE(nn.Module):
             dtype=self.pe.weight.dtype,
             device=self.pe.weight.device,
         )
-        for i in range(self.in_chs):
+        for i in range(self.dim):
             kernel_value[
                 i,
                 i % input_dim,
@@ -849,12 +849,12 @@ class RepCPE(nn.Module):
 
         # Introduce reparam conv
         self.reparam_conv = nn.Conv2d(
-            self.in_chs,
-            self.embed_dim,
+            self.dim,
+            self.dim_out,
             kernel_size=self.spatial_shape,
             stride=1,
             padding=int(self.spatial_shape[0] // 2),
-            groups=self.embed_dim,
+            groups=self.groups,
             bias=True,
         )
         self.reparam_conv.weight.data = w_final
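A standalone sketch (not part of the diff) of the identity-kernel trick that reparameterize() applies above: RepCPE computes x + pe(x) with a depthwise conv, so adding a center-tap identity kernel to the trained weight folds the skip connection into a single conv. The sizes below are arbitrary and only serve as a numeric check.

import torch
import torch.nn as nn

dim, k = 8, 7                                   # arbitrary small sizes for the check
pe = nn.Conv2d(dim, dim, k, stride=1, padding=k // 2, groups=dim, bias=True)

input_dim = dim // dim                          # in_channels per group == 1 for a depthwise conv
id_kernel = torch.zeros(dim, input_dim, k, k)
for i in range(dim):
    id_kernel[i, i % input_dim, k // 2, k // 2] = 1.0   # center tap reproduces the input

reparam = nn.Conv2d(dim, dim, k, stride=1, padding=k // 2, groups=dim, bias=True)
reparam.weight.data = id_kernel + pe.weight.data        # corresponds to w_final in the diff
reparam.bias.data = pe.bias.data

x = torch.randn(2, dim, 14, 14)
print(torch.allclose(x + pe(x), reparam(x), atol=1e-6)) # True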
@@ -1002,78 +1002,97 @@ class AttentionBlock(nn.Module):
         return x
 
 
-def basic_blocks(
-        dim: int,
-        block_index: int,
-        num_blocks: List[int],
-        token_mixer_type: str,
-        kernel_size: int = 3,
-        mlp_ratio: float = 4.0,
-        act_layer: nn.Module = nn.GELU,
-        norm_layer: nn.Module = nn.BatchNorm2d,
-        drop_rate: float = 0.0,
-        drop_path_rate: float = 0.0,
-        use_layer_scale: bool = True,
-        layer_scale_init_value: float = 1e-5,
-        inference_mode=False,
-) -> nn.Sequential:
-    """Build FastViT blocks within a stage.
-
-    Args:
-        dim: Number of embedding dimensions.
-        block_index: block index.
-        num_blocks: List containing number of blocks per stage.
-        token_mixer_type: Token mixer type.
-        kernel_size: Kernel size for repmixer.
-        mlp_ratio: MLP expansion ratio.
-        act_layer: Activation layer.
-        norm_layer: Normalization layer.
-        drop_rate: Dropout rate.
-        drop_path_rate: Drop path rate.
-        use_layer_scale: Flag to turn on layer scale regularization.
-        layer_scale_init_value: Layer scale value at initialization.
-        inference_mode: Flag to instantiate block in inference mode.
-
-    Returns:
-        nn.Sequential object of all the blocks within the stage.
-    """
-    blocks = []
-    for block_idx in range(num_blocks[block_index]):
-        block_dpr = (
-            drop_path_rate
-            * (block_idx + sum(num_blocks[:block_index]))
-            / (sum(num_blocks) - 1)
-        )
-        if token_mixer_type == "repmixer":
-            blocks.append(RepMixerBlock(
-                dim,
-                kernel_size=kernel_size,
-                mlp_ratio=mlp_ratio,
-                act_layer=act_layer,
-                drop=drop_rate,
-                drop_path=block_dpr,
-                use_layer_scale=use_layer_scale,
-                layer_scale_init_value=layer_scale_init_value,
-                inference_mode=inference_mode,
-            ))
-        elif token_mixer_type == "attention":
-            blocks.append(AttentionBlock(
-                dim,
-                mlp_ratio=mlp_ratio,
-                act_layer=act_layer,
-                norm_layer=norm_layer,
-                drop=drop_rate,
-                drop_path=block_dpr,
-                use_layer_scale=use_layer_scale,
-                layer_scale_init_value=layer_scale_init_value,
-            ))
-        else:
-            raise ValueError(
-                "Token mixer type: {} not supported".format(token_mixer_type)
-            )
-    blocks = nn.Sequential(*blocks)
-
-    return blocks
+class FastVitStage(nn.Module):
+    def __init__(
+            self,
+            dim: int,
+            dim_out: int,
+            depth: int,
+            token_mixer_type: str,
+            downsample: bool = True,
+            down_patch_size: int = 7,
+            down_stride: int = 2,
+            pos_emb_layer: Optional[nn.Module] = None,
+            kernel_size: int = 3,
+            mlp_ratio: float = 4.0,
+            act_layer: nn.Module = nn.GELU,
+            norm_layer: nn.Module = nn.BatchNorm2d,
+            drop_rate: float = 0.0,
+            drop_path_rate: float = 0.0,
+            use_layer_scale: bool = True,
+            layer_scale_init_value: float = 1e-5,
+            inference_mode=False,
+    ):
+        """FastViT stage.
+
+        Args:
+            dim: Number of embedding dimensions.
+            num_blocks: List containing number of blocks per stage.
+            token_mixer_type: Token mixer type.
+            kernel_size: Kernel size for repmixer.
+            mlp_ratio: MLP expansion ratio.
+            act_layer: Activation layer.
+            norm_layer: Normalization layer.
+            drop_rate: Dropout rate.
+            drop_path_rate: Drop path rate.
+            use_layer_scale: Flag to turn on layer scale regularization.
+            layer_scale_init_value: Layer scale value at initialization.
+            inference_mode: Flag to instantiate block in inference mode.
+        """
+        super().__init__()
+        if downsample:
+            self.downsample = PatchEmbed(
+                patch_size=down_patch_size,
+                stride=down_stride,
+                in_chs=dim,
+                embed_dim=dim_out,
+                inference_mode=inference_mode,
+            )
+        else:
+            assert dim == dim_out
+            self.downsample = nn.Identity()
+
+        if pos_emb_layer is not None:
+            self.pos_emb = pos_emb_layer(dim_out, inference_mode=inference_mode)
+        else:
+            self.pos_emb = nn.Identity()
+
+        blocks = []
+        for block_idx in range(depth):
+            if token_mixer_type == "repmixer":
+                blocks.append(RepMixerBlock(
+                    dim_out,
+                    kernel_size=kernel_size,
+                    mlp_ratio=mlp_ratio,
+                    act_layer=act_layer,
+                    drop=drop_rate,
+                    drop_path=drop_path_rate[block_idx],
+                    use_layer_scale=use_layer_scale,
+                    layer_scale_init_value=layer_scale_init_value,
+                    inference_mode=inference_mode,
+                ))
+            elif token_mixer_type == "attention":
+                blocks.append(AttentionBlock(
+                    dim_out,
+                    mlp_ratio=mlp_ratio,
+                    act_layer=act_layer,
+                    norm_layer=norm_layer,
+                    drop=drop_rate,
+                    drop_path=drop_path_rate[block_idx],
+                    use_layer_scale=use_layer_scale,
+                    layer_scale_init_value=layer_scale_init_value,
+                ))
+            else:
+                raise ValueError(
+                    "Token mixer type: {} not supported".format(token_mixer_type)
+                )
+        self.blocks = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        x = self.downsample(x)
+        x = self.pos_emb(x)
+        x = self.blocks(x)
+        return x
 
 
 class FastVit(nn.Module):
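A usage sketch (not part of the diff) of the new stage module, assuming FastVitStage and its dependencies above are importable (e.g. from timm.models.fastvit after this change). Shapes and rates below are arbitrary; drop_path_rate is indexed per block, as the constructor above expects.

import torch

stage = FastVitStage(
    dim=64,
    dim_out=128,
    depth=2,
    token_mixer_type="repmixer",
    downsample=True,                 # the downsample now sits at the top of the stage
    pos_emb_layer=None,
    drop_path_rate=[0.0, 0.05],      # one rate per block
)
x = torch.randn(1, 64, 56, 56)
y = stage(x)                         # downsample -> pos_emb -> blocks
print(y.shape)                       # torch.Size([1, 128, 28, 28])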
@@ -1085,78 +1104,66 @@ class FastVit(nn.Module):
 
     def __init__(
             self,
-            in_chans=3,
-            layers=(2, 2, 6, 2),
+            in_chans: int = 3,
+            layers: Tuple[int, ...] = (2, 2, 6, 2),
             token_mixers: Tuple[str, ...] = ("repmixer", "repmixer", "repmixer", "repmixer"),
-            embed_dims=None,
-            mlp_ratios=None,
-            downsamples=None,
-            repmixer_kernel_size=3,
+            embed_dims: Tuple[int, ...] = (64, 128, 256, 512),
+            mlp_ratios: Tuple[float, ...] = (4,) * 4,
+            downsamples: Tuple[bool, ...] = (False, True, True, True),
+            repmixer_kernel_size: int = 3,
+            num_classes: int = 1000,
+            pos_embs: Tuple[Optional[nn.Module], ...] = (None,) * 4,
+            down_patch_size: int = 7,
+            down_stride: int = 2,
+            drop_rate: float = 0.0,
+            drop_path_rate: float = 0.0,
+            use_layer_scale: bool = True,
+            layer_scale_init_value: float = 1e-5,
+            fork_feat: bool = False,
+            cls_ratio: float = 2.0,
             norm_layer: nn.Module = nn.BatchNorm2d,
             act_layer: nn.Module = nn.GELU,
-            num_classes=1000,
-            pos_embs=None,
-            down_patch_size=7,
-            down_stride=2,
-            drop_rate=0.0,
-            drop_path_rate=0.0,
-            use_layer_scale=True,
-            layer_scale_init_value=1e-5,
-            fork_feat=False,
-            cls_ratio=2.0,
-            inference_mode=False,
+            inference_mode: bool = False,
     ) -> None:
         super().__init__()
         self.num_classes = 0 if fork_feat else num_classes
         self.fork_feat = fork_feat
 
-        if pos_embs is None:
-            pos_embs = [None] * len(layers)
-
         # Convolutional stem
         self.patch_embed = convolutional_stem(
-            in_chans, embed_dims[0], inference_mode)
+            in_chans,
+            embed_dims[0],
+            inference_mode,
+        )
 
         # Build the main stages of the network architecture
+        prev_dim = embed_dims[0]
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
         network = []
         for i in range(len(layers)):
-            # Add position embeddings if requested
-            if pos_embs[i] is not None:
-                network.append(pos_embs[i](
-                    embed_dims[i],
-                    embed_dims[i],
-                    inference_mode=inference_mode,
-                ))
-            stage = basic_blocks(
-                embed_dims[i],
-                i,
-                layers,
+            stage = FastVitStage(
+                dim=prev_dim,
+                dim_out=embed_dims[i],
+                depth=layers[i],
+                downsample=downsamples[i] or prev_dim != embed_dims[i],
+                down_patch_size=down_patch_size,
+                down_stride=down_stride,
+                pos_emb_layer=pos_embs[i],
                 token_mixer_type=token_mixers[i],
                 kernel_size=repmixer_kernel_size,
                 mlp_ratio=mlp_ratios[i],
                 act_layer=act_layer,
                 norm_layer=norm_layer,
                 drop_rate=drop_rate,
-                drop_path_rate=drop_path_rate,
+                drop_path_rate=dpr[i],
                 use_layer_scale=use_layer_scale,
                 layer_scale_init_value=layer_scale_init_value,
                 inference_mode=inference_mode,
             )
             network.append(stage)
-            if i >= len(layers) - 1:
-                break
+            prev_dim = embed_dims[i]
 
-            # Patch merging/downsampling between stages.
-            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
-                network += [PatchEmbed(
-                    patch_size=down_patch_size,
-                    stride=down_stride,
-                    in_chs=embed_dims[i],
-                    embed_dim=embed_dims[i + 1],
-                    inference_mode=inference_mode,
-                )]
-
-        self.network = nn.ModuleList(network)
+        self.network = nn.Sequential(*network)
 
         # For segmentation and detection, extract intermediate output
         if self.fork_feat:
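A quick sketch (not part of the diff) of what the new per-stage stochastic depth schedule evaluates to; the values shown are approximate and only illustrate the shape of dpr that each FastVitStage receives as its drop_path_rate.

import torch

layers = (2, 2, 6, 2)
drop_path_rate = 0.1
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
# Four per-stage lists, increasing linearly across all 12 blocks, roughly:
# [[0.0, 0.009], [0.018, 0.027], [0.036, ..., 0.082], [0.091, 0.1]]
# FastVitStage then indexes drop_path_rate[block_idx] inside its block loop.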
@@ -1338,7 +1345,6 @@ def fastvit_t8(pretrained=False, **kwargs):
         layers=(2, 2, 4, 2),
         embed_dims=(48, 96, 192, 384),
         mlp_ratios=(3, 3, 3, 3),
-        downsamples=(True, True, True, True),
         token_mixers=("repmixer", "repmixer", "repmixer", "repmixer")
     )
     return _create_fastvit('fastvit_t8', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -1351,8 +1357,7 @@ def fastvit_t12(pretrained=False, **kwargs):
         layers=(2, 2, 6, 2),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(3, 3, 3, 3),
-        downsamples=(True, True, True, True),
-        token_mixers = ("repmixer", "repmixer", "repmixer", "repmixer"),
+        token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
     )
     return _create_fastvit('fastvit_t12', pretrained=pretrained, **dict(model_args, **kwargs))
 
@@ -1364,7 +1369,6 @@ def fastvit_s12(pretrained=False, **kwargs):
         layers=(2, 2, 6, 2),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         token_mixers=("repmixer", "repmixer", "repmixer", "repmixer"),
     )
     return _create_fastvit('fastvit_s12', pretrained=pretrained, **dict(model_args, **kwargs))
@@ -1377,7 +1381,6 @@ def fastvit_sa12(pretrained=False, **kwargs):
         layers=(2, 2, 6, 2),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
@@ -1391,7 +1394,6 @@ def fastvit_sa24(pretrained=False, **kwargs):
         layers=(4, 4, 12, 4),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
@@ -1405,7 +1407,6 @@ def fastvit_sa36(pretrained=False, **kwargs):
         layers=(6, 6, 18, 6),
         embed_dims=(64, 128, 256, 512),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention"),
     )
@@ -1418,7 +1419,6 @@ def fastvit_ma36(pretrained=False, **kwargs):
         layers=(6, 6, 18, 6),
         embed_dims=(76, 152, 304, 608),
         mlp_ratios=(4, 4, 4, 4),
-        downsamples=(True, True, True, True),
         pos_embs=(None, None, None, partial(RepCPE, spatial_shape=(7, 7))),
         token_mixers=("repmixer", "repmixer", "repmixer", "attention")
     )
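A usage sketch (not part of the diff): the registered variants above keep their names, so they are still created the usual way through the timm factory. Input size and batch size below are arbitrary.

import timm
import torch

model = timm.create_model('fastvit_sa12', pretrained=False)
x = torch.randn(1, 3, 256, 256)
logits = model(x)
print(logits.shape)   # torch.Size([1, 1000])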