Add xavier_uniform init of MNV4 hybrid attention modules. Small improvement in training stability.
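For reference, a minimal sketch (not part of this commit) of what Xavier uniform init does: nn.init.xavier_uniform_ fills a weight with samples from U(-a, a), where a = gain * sqrt(6 / (fan_in + fan_out)). The 1x1 conv below is an illustrative stand-in for the attention q/k/v projections touched here; the channel counts are made up.

    import torch.nn as nn

    # Illustrative 1x1 conv, shaped like a q/k/v projection (channels are made up).
    proj = nn.Conv2d(64, 64, kernel_size=1, bias=False)
    nn.init.xavier_uniform_(proj.weight)

    # For a Conv2d weight of shape (out_ch, in_ch, kh, kw):
    #   fan_in = in_ch * kh * kw, fan_out = out_ch * kh * kw
    a = (6 / (64 + 64)) ** 0.5
    assert proj.weight.abs().max() <= a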

pull/2252/head
Ross Wightman 2024-07-26 17:03:40 -07:00
parent 9558a7f8eb
commit ab8cb070fc
2 changed files with 16 additions and 1 deletion

@@ -205,6 +205,16 @@ class MultiQueryAttention2d(nn.Module):
         self.einsum = False
 
+    def init_weights(self):
+        # using xavier appeared to improve stability for mobilenetv4 hybrid w/ this layer
+        nn.init.xavier_uniform_(self.query.proj.weight)
+        nn.init.xavier_uniform_(self.key.proj.weight)
+        nn.init.xavier_uniform_(self.value.proj.weight)
+        if self.kv_stride > 1:
+            nn.init.xavier_uniform_(self.key.down_conv.weight)
+            nn.init.xavier_uniform_(self.value.down_conv.weight)
+        nn.init.xavier_uniform_(self.output.proj.weight)
+
     def _reshape_input(self, t: torch.Tensor):
         """Reshapes a tensor to three dimensions, keeping the batch and channels."""
         s = t.shape
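As a usage note (not part of the diff), the new hook is picked up automatically by the builder change below, but it can also be re-run by hand. The model name here is one of timm's MobileNetV4 hybrid configs, and the loop mirrors what the builder does:

    import timm

    # Re-run per-module init hooks on a MobileNetV4 hybrid; the builder already
    # does this once at construction time via efficientnet_init_weights().
    model = timm.create_model('mobilenetv4_hybrid_medium', pretrained=False)
    for name, module in model.named_modules():
        if hasattr(module, 'init_weights'):
            module.init_weights()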

@@ -16,8 +16,9 @@ from typing import Any, Dict, List
 
 import torch.nn as nn
 
-from ._efficientnet_blocks import *
 from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible, LayerType
+from ._efficientnet_blocks import *
+from ._manipulate import named_modules
 
 __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
            'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
@@ -569,3 +570,7 @@ def efficientnet_init_weights(model: nn.Module, init_fn=None):
     for n, m in model.named_modules():
         init_fn(m, n)
 
+    # iterate and call any module.init_weights() fn, children first
+    for n, m in named_modules(model):
+        if hasattr(m, 'init_weights'):
+            m.init_weights()
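The "children first" note is the key detail: modules are visited in post-order, so a composite module's init_weights() runs after its children's hooks and can override them. Below is a minimal sketch of such a traversal, assuming only torch; timm's actual named_modules helper in _manipulate takes additional arguments, so this is illustrative only:

    import torch.nn as nn

    def named_modules_children_first(module: nn.Module, prefix: str = ''):
        # Post-order walk: yield every child subtree before its parent, so a
        # parent's init_weights() can overwrite what its children just set.
        for child_name, child in module.named_children():
            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
            yield from named_modules_children_first(child, child_prefix)
        yield prefix, module

    # e.g. nn.Sequential(nn.Linear(2, 2)) yields ('0', Linear) then ('', Sequential)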