mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add xavier_uniform init of MNVC hybrid attention modules. Small improvement in training stability.
This commit is contained in:
parent
9558a7f8eb
commit
ab8cb070fc
@ -205,6 +205,16 @@ class MultiQueryAttention2d(nn.Module):
|
|||||||
|
|
||||||
self.einsum = False
|
self.einsum = False
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
# using xavier appeared to improve stability for mobilenetv4 hybrid w/ this layer
|
||||||
|
nn.init.xavier_uniform_(self.query.proj.weight)
|
||||||
|
nn.init.xavier_uniform_(self.key.proj.weight)
|
||||||
|
nn.init.xavier_uniform_(self.value.proj.weight)
|
||||||
|
if self.kv_stride > 1:
|
||||||
|
nn.init.xavier_uniform_(self.key.down_conv.weight)
|
||||||
|
nn.init.xavier_uniform_(self.value.down_conv.weight)
|
||||||
|
nn.init.xavier_uniform_(self.output.proj.weight)
|
||||||
|
|
||||||
def _reshape_input(self, t: torch.Tensor):
|
def _reshape_input(self, t: torch.Tensor):
|
||||||
"""Reshapes a tensor to three dimensions, keeping the batch and channels."""
|
"""Reshapes a tensor to three dimensions, keeping the batch and channels."""
|
||||||
s = t.shape
|
s = t.shape
|
||||||
|
@ -16,8 +16,9 @@ from typing import Any, Dict, List
|
|||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from ._efficientnet_blocks import *
|
|
||||||
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible, LayerType
|
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible, LayerType
|
||||||
|
from ._efficientnet_blocks import *
|
||||||
|
from ._manipulate import named_modules
|
||||||
|
|
||||||
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
|
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
|
||||||
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
|
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
|
||||||
@ -569,3 +570,7 @@ def efficientnet_init_weights(model: nn.Module, init_fn=None):
|
|||||||
for n, m in model.named_modules():
|
for n, m in model.named_modules():
|
||||||
init_fn(m, n)
|
init_fn(m, n)
|
||||||
|
|
||||||
|
# iterate and call any module.init_weights() fn, children first
|
||||||
|
for n, m in named_modules(model):
|
||||||
|
if hasattr(m, 'init_weights'):
|
||||||
|
m.init_weights()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user