Mirror of https://github.com/huggingface/pytorch-image-models.git
Redo LeViT attention bias caching in a way that works with both torchscript and DataParallel
parent d400f1dbdd
commit 11ae795e99
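The gist of the change: instead of caching the indexed bias tensor in a single `ab` attribute that is `None` during training, each attention module now keeps a `Dict[str, torch.Tensor]` keyed by `str(device)`, annotated at class level so TorchScript can type it, cleared when switching back to training, and filled lazily in eval mode so each device (e.g. each DataParallel replica) gets a bias tensor that lives where it runs. Below is a minimal self-contained sketch of that pattern, distilled from the diff that follows; the class name, toy shapes, and the trivial forward are illustrative only, not part of the commit.

from typing import Dict

import torch
import torch.nn as nn


class BiasCacheSketch(nn.Module):
    # TorchScript requires dict attributes to be annotated at class level.
    ab: Dict[str, torch.Tensor]

    def __init__(self, num_heads: int = 4, num_offsets: int = 9, seq_len: int = 16):
        super().__init__()
        self.attention_biases = nn.Parameter(torch.zeros(num_heads, num_offsets))
        self.register_buffer(
            'attention_bias_idxs',
            torch.randint(0, num_offsets, (seq_len, seq_len)))
        self.ab = {}  # per-device cache of the indexed biases, used only in eval mode

    @torch.no_grad()
    def train(self, mode: bool = True):
        super().train(mode)
        if mode and self.ab:
            self.ab = {}  # biases may change during training, so drop any cached copies
        return self

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        if self.training:
            # recompute every step so gradients reach attention_biases
            return self.attention_biases[:, self.attention_bias_idxs]
        device_key = str(device)
        if device_key not in self.ab:
            # first eval call on this device: index once and cache the result
            self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
        return self.ab[device_key]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, num_heads, N, N) attention logits; just add the cached bias
        return x + self.get_attention_biases(x.device)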
@@ -26,6 +26,7 @@ Modifications by/coyright Copyright 2021 Ross Wightman
 import itertools
 from copy import deepcopy
 from functools import partial
+from typing import Dict
 
 import torch
 import torch.nn as nn
@@ -255,6 +256,8 @@ class Subsample(nn.Module):
 
 
 class Attention(nn.Module):
+    ab: Dict[str, torch.Tensor]
+
     def __init__(
             self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
         super().__init__()
@@ -286,20 +289,31 @@ class Attention(nn.Module):
             idxs.append(attention_offsets[offset])
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
-        self.ab = None
+        self.ab = {}
 
     @torch.no_grad()
     def train(self, mode=True):
         super().train(mode)
-        self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
+        if mode and self.ab:
+            self.ab = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.ab:
+                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.ab[device_key]
 
     def forward(self, x):  # x (B,C,H,W)
         if self.use_conv:
             B, C, H, W = x.shape
             q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = (q.transpose(-2, -1) @ k) * self.scale + ab
+
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
+
             x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
         else:
             B, N, C = x.shape
@@ -308,15 +322,18 @@ class Attention(nn.Module):
             q = q.permute(0, 2, 1, 3)
             k = k.permute(0, 2, 1, 3)
             v = v.permute(0, 2, 1, 3)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = q @ k.transpose(-2, -1) * self.scale + ab
+
+            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
+
             x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
         x = self.proj(x)
         return x
 
 
 class AttentionSubsample(nn.Module):
+    ab: Dict[str, torch.Tensor]
+
     def __init__(
             self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
             act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
@@ -366,12 +383,22 @@ class AttentionSubsample(nn.Module):
             idxs.append(attention_offsets[offset])
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
-        self.ab = None
+        self.ab = {}  # per-device attention_biases cache
 
     @torch.no_grad()
     def train(self, mode=True):
         super().train(mode)
-        self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
+        if mode and self.ab:
+            self.ab = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.ab:
+                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.ab[device_key]
 
     def forward(self, x):
         if self.use_conv:
@@ -379,8 +406,7 @@ class AttentionSubsample(nn.Module):
             k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
             q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
 
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = (q.transpose(-2, -1) @ k) * self.scale + ab
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
 
             x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
@@ -391,8 +417,7 @@ class AttentionSubsample(nn.Module):
             v = v.permute(0, 2, 1, 3)  # BHNC
             q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
 
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = q @ k.transpose(-2, -1) * self.scale + ab
+            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
 
             x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
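A quick way to exercise both paths named in the commit message is sketched below. It assumes a timm checkout that includes this commit, that the rest of the LeViT model scripts cleanly (which is what this change is after), and, for the second half, at least two CUDA devices; `levit_128s` is one of the LeViT variants registered in timm and 224x224 is its default input size.

import torch
import torch.nn as nn
import timm

# Build one of the registered LeViT variants and switch to eval so the bias cache is used.
model = timm.create_model('levit_128s', pretrained=False).eval()

# TorchScript path: the Dict[str, torch.Tensor] annotation on `ab` gives the cache a
# stable type, where an attribute that flips between None and Tensor would not script.
scripted = torch.jit.script(model)
with torch.no_grad():
    out = scripted(torch.randn(2, 3, 224, 224))
print(out.shape)

# DataParallel path (only if >= 2 CUDA devices are available): each replica fills its
# own entry in the per-device cache the first time it runs in eval mode.
if torch.cuda.device_count() >= 2:
    dp = nn.DataParallel(model.cuda(), device_ids=[0, 1])
    with torch.no_grad():
        out = dp(torch.randn(8, 3, 224, 224).cuda())
    print(out.shape)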