# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.drop import build_dropout

from .layer_scale import LayerScale
from .norm import build_norm_layer


class SwiGLUFFN(nn.Module):
    """SwiGLU FFN layer.

    Applies a SiLU-gated linear unit to the input, followed by a linear
    output projection.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
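
    Examples:
        A minimal usage sketch; the shapes and settings below are
        illustrative only.

        >>> import torch
        >>> ffn = SwiGLUFFN(embed_dims=768, feedforward_channels=2048)
        >>> x = torch.randn(2, 197, 768)
        >>> ffn(x).shape
        torch.Size([2, 197, 768])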
    """  # noqa

    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: Optional[int] = None,
        out_dims: Optional[int] = None,
        layer_scale_init_value: float = 0.,
        bias: bool = True,
        dropout_layer: Optional[dict] = None,
        norm_cfg: Optional[dict] = None,
        add_identity: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dims = embed_dims
        self.out_dims = out_dims or embed_dims
        hidden_dims = feedforward_channels or embed_dims

        # ``w12`` packs the gate and value projections into one linear
        # layer; its output is split into two halves in ``forward``.
        self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, hidden_dims)
        else:
            self.norm = nn.Identity()

        self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias)

        if layer_scale_init_value > 0:
            self.gamma2 = LayerScale(
                dim=embed_dims, layer_scale_init_value=layer_scale_init_value)
        else:
            self.gamma2 = nn.Identity()

        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

    def forward(self,
                x: torch.Tensor,
                identity: Optional[torch.Tensor] = None) -> torch.Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        # SwiGLU gating: the SiLU-activated branch gates the linear branch.
        hidden = F.silu(x1) * x2
        hidden = self.norm(hidden)
        out = self.w3(hidden)
        out = self.gamma2(out)
        out = self.dropout_layer(out)

        if self.out_dims != self.embed_dims or not self.add_identity:
            # Skip the residual connection when the output dimension differs
            # from the input dimension or when the user disables it.
            return out

        if identity is None:
            identity = x
        return identity + out


class SwiGLUFFNFused(SwiGLUFFN):
    """SwiGLU FFN layer with fusing.

    Compared with :class:`SwiGLUFFN`, the hidden dimension is rescaled to
    2/3 of ``feedforward_channels`` and rounded up to a multiple of 8.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
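
    Examples:
        A minimal sketch of the rescaled hidden width; the values below are
        illustrative only.

        >>> import torch
        >>> ffn = SwiGLUFFNFused(embed_dims=768)
        >>> ffn.w12.out_features  # 2 * (768 * 2 / 3) = 1024
        1024
        >>> ffn(torch.randn(2, 197, 768)).shape
        torch.Size([2, 197, 768])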
    """  # noqa

    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: Optional[int] = None,
        out_dims: Optional[int] = None,
        layer_scale_init_value: float = 0.,
        bias: bool = True,
    ) -> None:
        out_dims = out_dims or embed_dims
        feedforward_channels = feedforward_channels or embed_dims
        # Rescale the hidden dimension to 2/3 and round it up to a
        # multiple of 8.
        feedforward_channels = (int(feedforward_channels * 2 / 3) + 7) // 8 * 8
        super().__init__(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            out_dims=out_dims,
            layer_scale_init_value=layer_scale_init_value,
            bias=bias,
        )