Fix: x.item() can not be called when creating the model on a "meta" device
parent
e1277af2ba
commit
85a2460209
|
@ -12,6 +12,7 @@ import math
|
|||
import logging
|
||||
from typing import Sequence, Tuple, Union, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
@ -116,7 +117,7 @@ class DinoVisionTransformer(nn.Module):
|
|||
if drop_path_uniform is True:
|
||||
dpr = [drop_path_rate] * depth
|
||||
else:
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
dpr = np.linspace(0, drop_path_rate, depth).tolist() # stochastic depth decay rule
|
||||
|
||||
if ffn_layer == "mlp":
|
||||
logger.info("using MLP layer as FFN")
|
||||
|
|
Loading…
Reference in New Issue