mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

commit 0721559511 (parent d5473c17f7)

Improved (hopefully) init for SA/SA-like layers used in ByoaNets
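Note (illustration, not part of this commit): the scheme applied across these layers draws projection weights from a truncated normal with std = fan_in ** -0.5, and relative position parameters with std equal to the attention scale (dim_head ** -0.5). A minimal, hypothetical sketch of the pattern using torch.nn.init.trunc_normal_ (equivalent to timm's trunc_normal_ for these defaults); ToyAttn and all of its dimensions are made up:

import torch
from torch import nn

class ToyAttn(nn.Module):
    def __init__(self, dim=64, dim_head=16, win=7):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=False)
        self.height_rel = nn.Parameter(torch.zeros(2 * win - 1, dim_head))
        self.width_rel = nn.Parameter(torch.zeros(2 * win - 1, dim_head))
        self.reset_parameters()

    def reset_parameters(self):
        # weight.shape[1] is in_channels for a 1x1 conv, i.e. its fan-in
        nn.init.trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
        # rel-pos tables are scaled like the attention logits they feed into
        nn.init.trunc_normal_(self.height_rel, std=self.scale)
        nn.init.trunc_normal_(self.width_rel, std=self.scale)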
timm/models/byoanet.py
@@ -294,6 +294,8 @@ class SelfAttnBlock(nn.Module):
     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
+        if hasattr(self.self_attn, 'reset_parameters'):
+            self.self_attn.reset_parameters()
 
     def forward(self, x):
         shortcut = self.shortcut(x)
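Note (illustration, not part of this commit): the hasattr guard lets the block hand initialization off to whichever self-attention variant was plugged in, but only when that layer defines reset_parameters. A hypothetical driver loop in the same spirit:

# Hypothetical helper: walk a model and let any module that exposes
# init_weights re-initialize itself (blocks then dispatch to their
# attention layer via the hasattr check above).
def init_model_weights(model, zero_init_last_bn=True):
    for m in model.modules():
        if hasattr(m, 'init_weights'):
            m.init_weights(zero_init_last_bn=zero_init_last_bn)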
timm/models/layers/bottleneck_attn.py
@@ -21,6 +21,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from .helpers import to_2tuple
+from .weight_init import trunc_normal_
 
 
 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -101,6 +102,11 @@ class BottleneckAttn(nn.Module):
 
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
 
+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H == self.pos_embed.height and W == self.pos_embed.width
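Note (illustration, not part of this commit): BottleneckAttn's qkv is a 1x1 conv, whose weight has shape (out_channels, in_channels, 1, 1); weight.shape[1] is therefore the fan-in, and std = fan_in ** -0.5 keeps q/k/v magnitudes roughly independent of width. A quick check with made-up sizes:

import torch
from torch import nn

qkv = nn.Conv2d(256, 3 * 128, kernel_size=1, bias=False)  # made-up dims
fan_in = qkv.weight.shape[1]               # 256 for this 1x1 conv
nn.init.trunc_normal_(qkv.weight, std=fan_in ** -0.5)
print(qkv.weight.std())                    # ~ 256 ** -0.5 = 0.0625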
timm/models/layers/halo_attn.py
@@ -25,6 +25,8 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
+from .weight_init import trunc_normal_
+
 
 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
     """ Compute relative logits along one dimension
@@ -124,6 +126,13 @@ class HaloAttn(nn.Module):
         self.pos_embed = PosEmbedRel(
             block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)
 
+    def reset_parameters(self):
+        std = self.q.weight.shape[1] ** -0.5  # fan-in
+        trunc_normal_(self.q.weight, std=std)
+        trunc_normal_(self.kv.weight, std=std)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H % self.block_size == 0 and W % self.block_size == 0
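Note (illustration, not part of this commit): unlike BottleneckAttn's fused qkv, HaloAttn projects q and kv separately, but both convs read the same input, so a single fan-in std covers them; the relative position tables again reuse the attention scale. A tiny sketch with made-up sizes:

import torch
from torch import nn

q = nn.Conv2d(256, 128, 1, bias=False)   # made-up dims
kv = nn.Conv2d(256, 256, 1, bias=False)
std = q.weight.shape[1] ** -0.5          # shared fan-in (256) for both
for w in (q.weight, kv.weight):
    nn.init.trunc_normal_(w, std=std)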
timm/models/layers/lambda_layer.py
@@ -24,6 +24,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
+from .weight_init import trunc_normal_
 
 
 class LambdaLayer(nn.Module):
@@ -36,6 +37,7 @@ class LambdaLayer(nn.Module):
             self,
             dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
         super().__init__()
+        self.dim = dim
         self.dim_out = dim_out or dim
         self.dim_k = dim_head  # query depth 'k'
         self.num_heads = num_heads
@@ -55,6 +57,10 @@ class LambdaLayer(nn.Module):
 
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
 
+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
+        trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
+
     def forward(self, x):
         B, C, H, W = x.shape
         M = H * W
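Note (illustration, not part of this commit): the new self.dim attribute exists because LambdaLayer scales its qkv projection by the layer's input dim rather than reading fan-in off the weight, and scales the position-lambda conv by the query depth dim_k. A toy equivalent with stand-in shapes (the real qkv and conv_lambda output widths differ):

import torch
from torch import nn

dim, dim_k, r = 128, 16, 7                    # made-up dims
qkv = nn.Conv2d(dim, 3 * dim, 1, bias=False)  # stand-in for self.qkv
conv_lambda = nn.Conv3d(1, dim_k, (r, r, 1), padding=(r // 2, r // 2, 0))
nn.init.trunc_normal_(qkv.weight, std=dim ** -0.5)
nn.init.trunc_normal_(conv_lambda.weight, std=dim_k ** -0.5)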
timm/models/layers/swin_attn.py
@@ -107,6 +107,7 @@ class WindowAttention(nn.Module):
         self.relative_position_bias_table = nn.Parameter(
             # 2 * Wh - 1 * 2 * Ww - 1, nH
             torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
+        trunc_normal_(self.relative_position_bias_table, std=.02)
 
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.win_size)
@@ -120,13 +121,16 @@ class WindowAttention(nn.Module):
         relative_coords[:, :, 0] *= 2 * self.win_size - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
         self.register_buffer("relative_position_index", relative_position_index)
-        trunc_normal_(self.relative_position_bias_table, std=.02)
 
         self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
         self.softmax = nn.Softmax(dim=-1)
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
 
+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+
     def forward(self, x):
         B, C, H, W = x.shape
         x = x.permute(0, 2, 3, 1)
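Note (illustration, not part of this commit): moving the bias-table trunc_normal_ up to the parameter's definition, and repeating it inside reset_parameters, means the table is sensibly initialized at construction and re-randomized the same way whenever a model-level init pass calls reset_parameters. For the nn.Linear qkv, weight.shape is (out_features, in_features), so weight.shape[1] is again the fan-in:

import torch
from torch import nn

qkv = nn.Linear(96, 96 * 3, bias=False)   # made-up dims
fan_in = qkv.weight.shape[1]              # in_features == 96
nn.init.trunc_normal_(qkv.weight, std=fan_in ** -0.5)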