From 3a3217b0c05b4c2e1dd861210e9afe3494f5fcb5 Mon Sep 17 00:00:00 2001 From: Federico Baldassarre Date: Sat, 26 Oct 2024 01:30:36 +0200 Subject: [PATCH] Defer LayerScale initialization for compatibility with "meta" devices --- dinov2/layers/layer_scale.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/dinov2/layers/layer_scale.py b/dinov2/layers/layer_scale.py index 51df0d7..0b38971 100644 --- a/dinov2/layers/layer_scale.py +++ b/dinov2/layers/layer_scale.py @@ -5,7 +5,7 @@ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 -from typing import Union +from typing import Optional, Union import torch from torch import Tensor @@ -18,10 +18,17 @@ class LayerScale(nn.Module): dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ) -> None: super().__init__() self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) + self.init_values = init_values + self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.gamma, self.init_values) def forward(self, x: Tensor) -> Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma