Merge 3a3217b0c0
into e1277af2ba
commit
d06d233ba1
|
@ -5,7 +5,7 @@
|
|||
|
||||
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
||||
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -18,10 +18,17 @@ class LayerScale(nn.Module):
|
|||
dim: int,
|
||||
init_values: Union[float, Tensor] = 1e-5,
|
||||
inplace: bool = False,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inplace = inplace
|
||||
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
||||
self.init_values = init_values
|
||||
self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
    """Re-initialize ``self.gamma`` from the stored ``self.init_values``.

    ``init_values`` is annotated as ``Union[float, Tensor]``. A plain number
    is written into every element of ``gamma``. A tensor is copied into
    ``gamma`` with broadcasting, so both a 0-dim tensor and a per-channel
    tensor of shape ``(dim,)`` are accepted.
    """
    init_values = self.init_values
    if isinstance(init_values, torch.Tensor):
        # nn.init.constant_ fills via Tensor.fill_, which only accepts a
        # scalar or 0-dim tensor value; copy_ broadcasts and also converts
        # dtype/device to match gamma.
        with torch.no_grad():
            self.gamma.copy_(init_values)
    else:
        nn.init.constant_(self.gamma, init_values)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
    """Multiply ``x`` by the ``gamma`` parameter.

    When ``self.inplace`` is truthy, ``x`` itself is scaled via ``mul_`` and
    returned; otherwise the product is returned as a new tensor and ``x`` is
    left unmodified.
    """
    if self.inplace:
        return x.mul_(self.gamma)
    return x * self.gamma
|
||||
|
|
Loading…
Reference in New Issue