Mirror of https://github.com/huggingface/pytorch-image-models.git
Merge remote-tracking branch 'upstream/main' into vit_siglip_and_reg
commit 49a459e8f1

.github/workflows/tests.yml (vendored): 14 lines changed
@@ -16,10 +16,12 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest]
-        python: ['3.10']
-        torch: ['1.13.0']
-        torchvision: ['0.14.0']
+        python: ['3.10', '3.11']
+        torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.1.0', vision: '0.16.0'}]
         testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward']
+        exclude:
+          - python: '3.11'
+            torch: {base: '1.13.0', vision: '0.14.0'}
     runs-on: ${{ matrix.os }}

     steps:
@@ -34,17 +36,17 @@ jobs:
           pip install -r requirements-dev.txt
       - name: Install torch on mac
         if: startsWith(matrix.os, 'macOS')
-        run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
+        run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }}
       - name: Install torch on Windows
         if: startsWith(matrix.os, 'windows')
-        run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
+        run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }}
       - name: Install torch on ubuntu
         if: startsWith(matrix.os, 'ubuntu')
         run: |
           sudo sed -i 's/azure\.//' /etc/apt/sources.list
           sudo apt update
           sudo apt install -y google-perftools
-          pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
+          pip install --no-cache-dir torch==${{ matrix.torch.base }}+cpu torchvision==${{ matrix.torch.vision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
       - name: Install requirements
         run: |
           pip install -r requirements.txt
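For context on the matrix change above: each torch release is now paired with its matching torchvision release in a single matrix entry, so the install steps can reference ${{ matrix.torch.base }} and ${{ matrix.torch.vision }} together, and the exclude rule drops the Python 3.11 + torch 1.13.0 combination (presumably because torch 1.13.0 ships no Python 3.11 wheels). A rough sketch of the combinations this expands to, written as illustrative Python rather than workflow syntax (only the python and torch axes are shown):

from itertools import product

# Illustrative reconstruction of the matrix expansion GitHub Actions performs;
# the os and testmarker axes are omitted for brevity.
pythons = ['3.10', '3.11']
torches = [
    {'base': '1.13.0', 'vision': '0.14.0'},
    {'base': '2.1.0', 'vision': '0.16.0'},
]
excludes = [{'python': '3.11', 'torch': {'base': '1.13.0', 'vision': '0.14.0'}}]

combos = [
    {'python': py, 'torch': t}
    for py, t in product(pythons, torches)
    if {'python': py, 'torch': t} not in excludes
]
for c in combos:
    # e.g. "3.11 torch==2.1.0 torchvision==0.16.0"
    print(c['python'], f"torch=={c['torch']['base']}", f"torchvision=={c['torch']['vision']}")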
@@ -10,7 +10,7 @@ from copy import deepcopy

 import torch
 from torch.testing._internal.common_utils import TestCase
-from torch.autograd import Variable
+from torch.nn import Parameter
 from timm.scheduler import PlateauLRScheduler

 from timm.optim import create_optimizer_v2
@@ -21,9 +21,9 @@ torch_tc = TestCase()


 def _test_basic_cases_template(weight, bias, input, constructor, scheduler_constructors):
-    weight = Variable(weight, requires_grad=True)
-    bias = Variable(bias, requires_grad=True)
-    input = Variable(input)
+    weight = Parameter(weight)
+    bias = Parameter(bias)
+    input = Parameter(input)
     optimizer = constructor(weight, bias)
     schedulers = []
     for scheduler_constructor in scheduler_constructors:
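The test changes in this hunk (and the ones that follow) swap the long-deprecated torch.autograd.Variable wrapper for torch.nn.Parameter, a Tensor subclass with requires_grad=True by default that optimizers accept directly. A minimal sketch of the equivalence, for illustration only:

import torch
from torch.nn import Parameter

# Old style: weight = Variable(torch.randn(10, 5), requires_grad=True)
# New style: Parameter wraps the tensor and enables gradients by default.
weight = Parameter(torch.randn(10, 5))
assert weight.requires_grad

opt = torch.optim.SGD([weight], lr=0.1)
loss = (weight ** 2).sum()
loss.backward()
opt.step()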
@@ -55,9 +55,9 @@ def _test_basic_cases_template(weight, bias, input, constructor, scheduler_const


 def _test_state_dict(weight, bias, input, constructor):
-    weight = Variable(weight, requires_grad=True)
-    bias = Variable(bias, requires_grad=True)
-    input = Variable(input)
+    weight = Parameter(weight)
+    bias = Parameter(bias)
+    input = Parameter(input)

     def fn_base(optimizer, weight, bias):
         optimizer.zero_grad()
@@ -73,8 +73,9 @@ def _test_state_dict(weight, bias, input, constructor):
     for _i in range(20):
         optimizer.step(fn)
     # Clone the weights and construct new optimizer for them
-    weight_c = Variable(weight.data.clone(), requires_grad=True)
-    bias_c = Variable(bias.data.clone(), requires_grad=True)
+    with torch.no_grad():
+        weight_c = Parameter(weight.clone().detach())
+        bias_c = Parameter(bias.clone().detach())
     optimizer_c = constructor(weight_c, bias_c)
     fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
     # Load state dict
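This hunk also drops the .data accessor: weight.data.clone() sidesteps autograd in a way that is discouraged in current PyTorch, while weight.clone().detach() inside torch.no_grad() yields an equivalent, independent leaf tensor. A small sketch of the pattern, assuming weight is a Parameter as in the test:

import torch
from torch.nn import Parameter

weight = Parameter(torch.randn(4, 4))

# Copy the current values into a fresh, independent Parameter without touching .data.
with torch.no_grad():
    weight_c = Parameter(weight.clone().detach())

assert torch.equal(weight, weight_c)   # same values
assert weight_c is not weight          # separate storage and autograd history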
@@ -86,12 +87,8 @@ def _test_state_dict(weight, bias, input, constructor):
     for _i in range(20):
         optimizer.step(fn)
         optimizer_c.step(fn_c)
-        #assert torch.equal(weight, weight_c)
-        #assert torch.equal(bias, bias_c)
         torch_tc.assertEqual(weight, weight_c)
         torch_tc.assertEqual(bias, bias_c)
-    # Make sure state dict wasn't modified
-    torch_tc.assertEqual(state_dict, state_dict_c)
     # Make sure state dict is deterministic with equal but not identical parameters
     torch_tc.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
     # Make sure repeated parameters have identical representation in state dict
@@ -103,9 +100,10 @@ def _test_state_dict(weight, bias, input, constructor):
     if not torch.cuda.is_available():
         return

-    input_cuda = Variable(input.data.float().cuda())
-    weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
-    bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
+    with torch.no_grad():
+        input_cuda = Parameter(input.clone().detach().float().cuda())
+        weight_cuda = Parameter(weight.clone().detach().cuda())
+        bias_cuda = Parameter(bias.clone().detach().cuda())
     optimizer_cuda = constructor(weight_cuda, bias_cuda)
     fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)

@@ -216,21 +214,21 @@ def _test_rosenbrock(constructor, scheduler_constructors=None):
         scheduler_constructors = []
     params_t = torch.tensor([1.5, 1.5])

-    params = Variable(params_t, requires_grad=True)
+    params = Parameter(params_t)
     optimizer = constructor([params])
     schedulers = []
     for scheduler_constructor in scheduler_constructors:
         schedulers.append(scheduler_constructor(optimizer))

     solution = torch.tensor([1, 1])
-    initial_dist = params.data.dist(solution)
+    initial_dist = params.clone().detach().dist(solution)

     def eval(params, w):
         # Depending on w, provide only the x or y gradient
         optimizer.zero_grad()
         loss = rosenbrock(params)
         loss.backward()
-        grad = drosenbrock(params.data)
+        grad = drosenbrock(params.clone().detach())
         # NB: We torture test the optimizer by returning an
         # uncoalesced sparse tensor
         if w:
@@ -256,7 +254,7 @@ def _test_rosenbrock(constructor, scheduler_constructors=None):
             else:
                 scheduler.step()

-    torch_tc.assertLessEqual(params.data.dist(solution), initial_dist)
+    torch_tc.assertLessEqual(params.clone().detach().dist(solution), initial_dist)


 def _build_params_dict(weight, bias, **kwargs):
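For reference, _test_rosenbrock walks the optimizer down the Rosenbrock valley and asserts that the final distance to the minimum at (1, 1) is no larger than the starting distance. The helper names rosenbrock/drosenbrock come from the diff; the bodies below are the standard 2-D definitions, shown only to make the test readable:

import torch

def rosenbrock(t):
    # f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2, minimized at (1, 1)
    x, y = t
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

def drosenbrock(t):
    # analytic gradient of the function above
    x, y = t
    return torch.stack((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * (y - x ** 2)))

params = torch.nn.Parameter(torch.tensor([1.5, 1.5]))
rosenbrock(params).backward()
# params.grad now matches drosenbrock(params.clone().detach()) up to floating point error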
@@ -130,8 +130,6 @@ class SwiGLU(nn.Module):
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
         self.drop2 = nn.Dropout(drop_probs[1])

-        self.drop = nn.Dropout(drop)
-
     def init_weights(self):
         # override init of fc1 w/ gate portion set to weight near zero, bias=1
         nn.init.ones_(self.fc1_g.bias)
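The removed self.drop was redundant next to the per-branch dropouts the module already defines (drop_probs[1] is visible above). As a rough sketch of how a gated SwiGLU MLP of this shape is wired; this is a simplified illustration, not the timm implementation, and the forward wiring is an assumption based on the layer names in the diff:

import torch
import torch.nn as nn

class SwiGLUSketch(nn.Module):
    # simplified sketch of a gated SwiGLU MLP; not a copy of timm's SwiGLU
    def __init__(self, in_features, hidden_features, out_features, drop=0.):
        super().__init__()
        self.fc1_g = nn.Linear(in_features, hidden_features)   # gate branch
        self.fc1_x = nn.Linear(in_features, hidden_features)   # value branch
        self.act = nn.SiLU()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.act(self.fc1_g(x)) * self.fc1_x(x)  # SwiGLU: silu(gate) * value
        x = self.drop1(x)
        x = self.fc2(x)
        return self.drop2(x)

y = SwiGLUSketch(64, 128, 64)(torch.randn(2, 16, 64))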
@@ -155,7 +155,7 @@ class Attention(nn.Module):
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=rel_pos_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
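The dropout_p change in this hunk is the pattern repeated across the remaining attention hunks. F.scaled_dot_product_attention takes a plain float dropout probability and knows nothing about the module's train/eval state, whereas nn.Dropout disables itself under eval(); passing self.attn_drop.p unconditionally therefore kept applying attention dropout at inference time. Gating on self.training restores the expected behaviour. A minimal sketch of the fix in isolation, using a toy module rather than any timm class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySelfAttention(nn.Module):
    # toy single-head attention, for illustration only
    def __init__(self, dim, attn_drop=0.1):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Gate the probability on self.training so eval() really disables dropout.
        return F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )

m = TinySelfAttention(32).eval()
x = torch.randn(1, 8, 32)
assert torch.equal(m(x), m(x))  # deterministic in eval mode with the gated dropout_p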
@@ -50,7 +50,7 @@ class ClassAttn(nn.Module):
         if self.fused_attn:
             x_cls = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -126,7 +126,7 @@ class EvaAttention(nn.Module):
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -514,7 +514,7 @@ class Attention(nn.Module):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -190,7 +190,7 @@ class Attention2d(nn.Module):
                 k.transpose(-1, -2).contiguous(),
                 v.transpose(-1, -2).contiguous(),
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             ).transpose(-1, -2).reshape(B, -1, H, W)
         else:
             q = q * self.scale
@@ -259,7 +259,7 @@ class AttentionCl(nn.Module):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -198,7 +198,7 @@ class Attention(nn.Module):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             attn = (q @ k.transpose(-2, -1)) * self.scale
@@ -59,14 +59,14 @@ class Attention(nn.Module):
     def forward(self, x):
         """
         x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim)
         """
         B, T, N, C = x.shape
         # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
         qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)  # (B, H, T, N, N)
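For readers following the shapes in the hunk above: this NesT-style attention runs ordinary self-attention independently inside each of the T image blocks, which is why qkv is reshaped to (3, B, heads, T, N, head_dim) before unbinding. A small shape-bookkeeping sketch, illustrative only:

import torch

B, T, N, C, num_heads = 2, 4, 16, 64, 8
x = torch.randn(B, T, N, C)
qkv_proj = torch.nn.Linear(C, 3 * C)

qkv = qkv_proj(x).reshape(B, T, N, 3, num_heads, C // num_heads).permute(3, 0, 4, 1, 2, 5)
q, k, v = qkv.unbind(0)          # each: (B, heads, T, N, head_dim)
attn = q @ k.transpose(-2, -1)   # (B, heads, T, N, N): attention stays within each block
print(attn.shape)                # torch.Size([2, 8, 4, 16, 16])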
@@ -330,7 +330,7 @@ class Nest(nn.Module):
         # Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
         # number of blocks along edge of image
         self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0]))

         # Patch embedding
         self.patch_embed = PatchEmbed(
             img_size=img_size,
@@ -130,7 +130,7 @@ class Attention(nn.Module):
         k, v = kv.unbind(0)

         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
@@ -164,7 +164,7 @@ class WindowAttention(nn.Module):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -75,7 +75,7 @@ class LocallyGroupedAttn(nn.Module):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -172,7 +172,7 @@ class GlobalSubSampleAttn(nn.Module):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -95,7 +95,7 @@ class Attention(nn.Module):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q.contiguous(), k.contiguous(), v.contiguous(),
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             attn = (q @ k.transpose(-2, -1)) * self.scale
@@ -85,7 +85,7 @@ class Attention(nn.Module):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -285,7 +285,7 @@ class ParallelScalingBlock(nn.Module):
         if self.fused_attn:
             x_attn = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -1208,7 +1208,7 @@ default_cfgs = generate_default_cfgs({
         url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
         hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),

     # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only)
     'vit_small_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
@@ -1528,7 +1528,7 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),

     'vit_huge_patch14_ijepa_224.in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
         # hf_hub_id='timm/',
@@ -2182,7 +2182,7 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192

     model_args = dict(
         patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
         mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
     )
     model = _create_vision_transformer(
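On the vit_giant_patch14_dinov2 config above: with a packed SwiGLU MLP the single fc1 projection produces both the gate and the value halves, so the nominal hidden size is doubled, hence mlp_ratio=2.66667 * 2, which for embed_dim=1536 gives the hidden_features = 2 * 4096 = 8192 mentioned in the comment. A quick check of that arithmetic, for illustration:

embed_dim = 1536
mlp_ratio = 2.66667 * 2          # doubled because packed SwiGLU splits fc1 into gate + value halves

hidden_features = int(embed_dim * mlp_ratio)
assert hidden_features == 8192   # 2 * 4096, matching the comment in the diff
print(hidden_features)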
@@ -71,7 +71,7 @@ class RelPosAttention(nn.Module):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -168,7 +168,7 @@ class Attention(nn.Module):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale