Fix: x.item() can not be called when creating the model on a "meta" device

pull/477/head
Federico Baldassarre 2024-10-26 02:19:00 +02:00 committed by GitHub
parent e1277af2ba
commit 85a2460209
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 1 deletions

View File

@ -12,6 +12,7 @@ import math
import logging
from typing import Sequence, Tuple, Union, Callable
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
@ -116,7 +117,7 @@ class DinoVisionTransformer(nn.Module):
if drop_path_uniform is True:
dpr = [drop_path_rate] * depth
else:
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
dpr = np.linspace(0, drop_path_rate, depth).tolist() # stochastic depth decay rule
if ffn_layer == "mlp":
logger.info("using MLP layer as FFN")