Add xavier_uniform init of MNVC hybrid attention modules. Small improvement in training stability.
parent
9558a7f8eb
commit
ab8cb070fc
|
@ -205,6 +205,16 @@ class MultiQueryAttention2d(nn.Module):
|
|||
|
||||
self.einsum = False
|
||||
|
||||
def init_weights(self):
|
||||
# using xavier appeared to improve stability for mobilenetv4 hybrid w/ this layer
|
||||
nn.init.xavier_uniform_(self.query.proj.weight)
|
||||
nn.init.xavier_uniform_(self.key.proj.weight)
|
||||
nn.init.xavier_uniform_(self.value.proj.weight)
|
||||
if self.kv_stride > 1:
|
||||
nn.init.xavier_uniform_(self.key.down_conv.weight)
|
||||
nn.init.xavier_uniform_(self.value.down_conv.weight)
|
||||
nn.init.xavier_uniform_(self.output.proj.weight)
|
||||
|
||||
def _reshape_input(self, t: torch.Tensor):
|
||||
"""Reshapes a tensor to three dimensions, keeping the batch and channels."""
|
||||
s = t.shape
|
||||
|
|
|
@ -16,8 +16,9 @@ from typing import Any, Dict, List
|
|||
|
||||
import torch.nn as nn
|
||||
|
||||
from ._efficientnet_blocks import *
|
||||
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible, LayerType
|
||||
from ._efficientnet_blocks import *
|
||||
from ._manipulate import named_modules
|
||||
|
||||
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
|
||||
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
|
||||
|
@ -569,3 +570,7 @@ def efficientnet_init_weights(model: nn.Module, init_fn=None):
|
|||
for n, m in model.named_modules():
|
||||
init_fn(m, n)
|
||||
|
||||
# iterate and call any module.init_weights() fn, children first
|
||||
for n, m in named_modules(model):
|
||||
if hasattr(m, 'init_weights'):
|
||||
m.init_weights()
|
||||
|
|
Loading…
Reference in New Issue