mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Improved (hopefully) init for SA/SA-like layers used in ByoaNets
This commit is contained in:
parent
d5473c17f7
commit
0721559511
@ -294,6 +294,8 @@ class SelfAttnBlock(nn.Module):
|
||||
def init_weights(self, zero_init_last_bn=False):
|
||||
if zero_init_last_bn:
|
||||
nn.init.zeros_(self.conv3_1x1.bn.weight)
|
||||
if hasattr(self.self_attn, 'reset_parameters'):
|
||||
self.self_attn.reset_parameters()
|
||||
|
||||
def forward(self, x):
|
||||
shortcut = self.shortcut(x)
|
||||
|
@ -21,6 +21,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .helpers import to_2tuple
|
||||
from .weight_init import trunc_normal_
|
||||
|
||||
|
||||
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
|
||||
@ -101,6 +102,11 @@ class BottleneckAttn(nn.Module):
|
||||
|
||||
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
|
||||
|
||||
def reset_parameters(self):
|
||||
trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
|
||||
trunc_normal_(self.pos_embed.height_rel, std=self.scale)
|
||||
trunc_normal_(self.pos_embed.width_rel, std=self.scale)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.pos_embed.height and W == self.pos_embed.width
|
||||
|
@ -25,6 +25,8 @@ import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .weight_init import trunc_normal_
|
||||
|
||||
|
||||
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
|
||||
""" Compute relative logits along one dimension
|
||||
@ -124,6 +126,13 @@ class HaloAttn(nn.Module):
|
||||
self.pos_embed = PosEmbedRel(
|
||||
block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)
|
||||
|
||||
def reset_parameters(self):
|
||||
std = self.q.weight.shape[1] ** -0.5 # fan-in
|
||||
trunc_normal_(self.q.weight, std=std)
|
||||
trunc_normal_(self.kv.weight, std=std)
|
||||
trunc_normal_(self.pos_embed.height_rel, std=self.scale)
|
||||
trunc_normal_(self.pos_embed.width_rel, std=self.scale)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert H % self.block_size == 0 and W % self.block_size == 0
|
||||
|
@ -24,6 +24,7 @@ import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .weight_init import trunc_normal_
|
||||
|
||||
|
||||
class LambdaLayer(nn.Module):
|
||||
@ -36,6 +37,7 @@ class LambdaLayer(nn.Module):
|
||||
self,
|
||||
dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_out = dim_out or dim
|
||||
self.dim_k = dim_head # query depth 'k'
|
||||
self.num_heads = num_heads
|
||||
@ -55,6 +57,10 @@ class LambdaLayer(nn.Module):
|
||||
|
||||
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
|
||||
|
||||
def reset_parameters(self):
|
||||
trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
|
||||
trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
M = H * W
|
||||
|
@ -107,6 +107,7 @@ class WindowAttention(nn.Module):
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
# 2 * Wh - 1 * 2 * Ww - 1, nH
|
||||
torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
|
||||
trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(self.win_size)
|
||||
@ -120,13 +121,16 @@ class WindowAttention(nn.Module):
|
||||
relative_coords[:, :, 0] *= 2 * self.win_size - 1
|
||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
|
||||
|
||||
def reset_parameters(self):
|
||||
trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
|
||||
trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user