Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
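Allow act_layer switch for xcit, fix in_chans for some variants (the patch_size == 8 branch of ConvPatchEmbed passed a hard-coded 3 input channels to its first conv3x3 instead of in_chans)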
This commit is contained in:
parent: d3255adf8e
commit: 748ab852ca
@ -141,7 +141,7 @@ def conv3x3(in_planes, out_planes, stride=1):
|
|||||||
class ConvPatchEmbed(nn.Module):
|
class ConvPatchEmbed(nn.Module):
|
||||||
"""Image to Patch Embedding using multiple convolutional layers"""
|
"""Image to Patch Embedding using multiple convolutional layers"""
|
||||||
|
|
||||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, act_layer=nn.GELU):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
img_size = to_2tuple(img_size)
|
img_size = to_2tuple(img_size)
|
||||||
num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
|
num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
|
||||||
@ -152,19 +152,19 @@ class ConvPatchEmbed(nn.Module):
|
|||||||
if patch_size == 16:
|
if patch_size == 16:
|
||||||
self.proj = torch.nn.Sequential(
|
self.proj = torch.nn.Sequential(
|
||||||
conv3x3(in_chans, embed_dim // 8, 2),
|
conv3x3(in_chans, embed_dim // 8, 2),
|
||||||
nn.GELU(),
|
act_layer(),
|
||||||
conv3x3(embed_dim // 8, embed_dim // 4, 2),
|
conv3x3(embed_dim // 8, embed_dim // 4, 2),
|
||||||
nn.GELU(),
|
act_layer(),
|
||||||
conv3x3(embed_dim // 4, embed_dim // 2, 2),
|
conv3x3(embed_dim // 4, embed_dim // 2, 2),
|
||||||
nn.GELU(),
|
act_layer(),
|
||||||
conv3x3(embed_dim // 2, embed_dim, 2),
|
conv3x3(embed_dim // 2, embed_dim, 2),
|
||||||
)
|
)
|
||||||
elif patch_size == 8:
|
elif patch_size == 8:
|
||||||
self.proj = torch.nn.Sequential(
|
self.proj = torch.nn.Sequential(
|
||||||
conv3x3(3, embed_dim // 4, 2),
|
conv3x3(in_chans, embed_dim // 4, 2),
|
||||||
nn.GELU(),
|
act_layer(),
|
||||||
conv3x3(embed_dim // 4, embed_dim // 2, 2),
|
conv3x3(embed_dim // 4, embed_dim // 2, 2),
|
||||||
nn.GELU(),
|
act_layer(),
|
||||||
conv3x3(embed_dim // 2, embed_dim, 2),
|
conv3x3(embed_dim // 2, embed_dim, 2),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -323,7 +323,7 @@ class XCiT(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||||
num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
||||||
norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False):
|
act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
img_size (int, tuple): input image size
|
img_size (int, tuple): input image size
|
||||||
@ -356,9 +356,10 @@ class XCiT(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.num_features = self.embed_dim = embed_dim
|
self.num_features = self.embed_dim = embed_dim
|
||||||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
||||||
|
act_layer = act_layer or nn.GELU
|
||||||
|
|
||||||
self.patch_embed = ConvPatchEmbed(
|
self.patch_embed = ConvPatchEmbed(
|
||||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, act_layer=act_layer)
|
||||||
|
|
||||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||||
self.use_pos_embed = use_pos_embed
|
self.use_pos_embed = use_pos_embed
|
||||||
@ -369,13 +370,13 @@ class XCiT(nn.Module):
|
|||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([
|
||||||
XCABlock(
|
XCABlock(
|
||||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
|
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
|
||||||
attn_drop=attn_drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer, eta=eta)
|
attn_drop=attn_drop_rate, drop_path=drop_path_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta)
|
||||||
for _ in range(depth)])
|
for _ in range(depth)])
|
||||||
|
|
||||||
self.cls_attn_blocks = nn.ModuleList([
|
self.cls_attn_blocks = nn.ModuleList([
|
||||||
ClassAttentionBlock(
|
ClassAttentionBlock(
|
||||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
|
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
|
||||||
attn_drop=attn_drop_rate, norm_layer=norm_layer, eta=eta, tokens_norm=tokens_norm)
|
attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta, tokens_norm=tokens_norm)
|
||||||
for _ in range(cls_attn_layers)])
|
for _ in range(cls_attn_layers)])
|
||||||
|
|
||||||
# Classifier head
|
# Classifier head
|
||||||
|
Loading…
x
Reference in New Issue
Block a user