Adjust arg order for recent vit model args, add a few comments
parent 41dc49a337
commit f5ca4141f7
@@ -325,8 +325,8 @@ class VisionTransformer(nn.Module):
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
             embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
-            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', class_token=True,
-            fc_norm=None, embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
+            class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
+            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
         """
         Args:
             img_size (int, tuple): input image size
@@ -340,12 +340,12 @@ class VisionTransformer(nn.Module):
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
             init_values: (float): layer-scale init values
-            class_token (bool): use class token
-            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
             weight_init (str): weight init scheme
+            class_token (bool): use class token
+            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
             embed_layer (nn.Module): patch embedding layer
             norm_layer: (nn.Module): normalization layer
             act_layer: (nn.Module): MLP activation layer
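Side note on the reorder above (an illustrative sketch, not part of the commit): class_token and fc_norm now come before the drop-rate args in VisionTransformer.__init__, so a positional call written against the old order would misbind, e.g. a drop_rate value landing in class_token's slot. Keyword construction is order-proof; this assumes a timm checkout that includes this commit:

    # Minimal sketch: build VisionTransformer with the reordered kwargs.
    from timm.models.vision_transformer import VisionTransformer

    model = VisionTransformer(
        img_size=224, patch_size=16, num_classes=1000, global_pool='token',
        embed_dim=768, depth=12, num_heads=12,
        class_token=True, fc_norm=None,     # these now come before the drop args
        drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
    )

The hunks below make the matching reorder in VisionTransformerRelPos and fold init_values (layer-scale) into its defaults.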
@@ -240,13 +240,19 @@ class ResPostRelPosBlock(nn.Module):
 
 class VisionTransformerRelPos(nn.Module):
     """ Vision Transformer w/ Relative Position Bias
+
+    Differing from classic vit, this impl
+      * uses relative position index (swin v1 / beit) or relative log coord + mlp (swin v2) pos embed
+      * defaults to no class token (can be enabled)
+      * defaults to global avg pool for head (can be changed)
+      * layer-scale (residual branch gain) enabled
     """
 
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
-            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
-            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', class_token=False,
-            rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False,
+            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
+            class_token=False, rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False,
+            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip',
             embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
         """
         Args:
@@ -254,21 +260,21 @@ class VisionTransformerRelPos(nn.Module):
             patch_size (int, tuple): patch size
             in_chans (int): number of input channels
             num_classes (int): number of classes for classification head
-            global_pool (str): type of global pooling for final sequence (default: 'token')
+            global_pool (str): type of global pooling for final sequence (default: 'avg')
             embed_dim (int): embedding dimension
             depth (int): depth of transformer
             num_heads (int): number of attention heads
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
             init_values: (float): layer-scale init values
-            class_token (bool): use class token (default: False)
-            rel_pos_type (str): type of relative position
-            shared_rel_pos (bool): share relative pos across all blocks
-            fc_norm (bool): use pre classifier norm instead of pre-pool
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
             weight_init (str): weight init scheme
+            class_token (bool): use class token (default: False)
+            rel_pos_type (str): type of relative position
+            shared_rel_pos (bool): share relative pos across all blocks
+            fc_norm (bool): use pre classifier norm
             embed_layer (nn.Module): patch embedding layer
             norm_layer: (nn.Module): normalization layer
             act_layer: (nn.Module): MLP activation layer
@@ -384,11 +390,10 @@ def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
 
 @register_model
 def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5,
-        block_fn=ResPostRelPosBlock, **kwargs)
+        patch_size=32, embed_dim=896, depth=12, num_heads=14, block_fn=ResPostRelPosBlock, **kwargs)
     model = _create_vision_transformer_relpos(
         'vit_relpos_base_patch32_plus_rpn_256', pretrained=pretrained, **model_kwargs)
     return model
@@ -398,7 +403,7 @@ def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
 def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16+) w/ relative log-coord position, no class token
     """
-    model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, **kwargs)
     model = _create_vision_transformer_relpos('vit_relpos_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -408,8 +413,7 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
-        fc_norm=True, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, **kwargs)
     model = _create_vision_transformer_relpos('vit_relpos_base_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -419,7 +423,6 @@ def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
-        block_fn=ResPostRelPosBlock, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
     model = _create_vision_transformer_relpos('vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
     return model
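Usage note for the relpos hunks (a sketch, assuming a timm build that includes this commit): init_values=1e-5 is now the VisionTransformerRelPos default, which is why the per-model model_kwargs above no longer pass it; the registered variants pick up layer-scale automatically, and since create_model forwards extra kwargs to the constructor, the default can still be overridden per call:

    # Minimal sketch: relpos variants now default to layer-scale (init_values=1e-5).
    import timm

    model = timm.create_model('vit_relpos_base_patch16_224', pretrained=False)

    # Extra kwargs flow through to the constructor, so layer-scale can be disabled:
    no_ls = timm.create_model('vit_relpos_base_patch16_224', init_values=None)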