Another effcientvit (mit) tweak, fix torchscript/fx conflict with autocast disable

samvit_fix_and_rope^2
Ross Wightman 2023-08-20 15:07:25 -07:00
parent dc18cda2e7
commit 300f54a96f
1 changed files with 21 additions and 11 deletions

View File

@ -14,10 +14,11 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ._registry import register_model, generate_default_cfgs
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from timm.layers import SelectAdaptivePool2d, create_conv2d
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
def val2list(x: list or tuple or any, repeat_time=1):
@ -233,6 +234,14 @@ class LiteMSA(nn.Module):
act_layer=act_layer[1],
)
def _attn(self, q, k, v):
dtype = v.dtype
q, k, v = q.float(), k.float(), v.float()
kv = k.transpose(-1, -2) @ v
out = q @ kv
out = out[..., :-1] / (out[..., -1:] + self.eps)
return out.to(dtype)
def forward(self, x):
B, _, H, W = x.shape
@ -243,20 +252,18 @@ class LiteMSA(nn.Module):
multi_scale_qkv.append(op(qkv))
multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
q, k, v = multi_scale_qkv.tensor_split(3, dim=-1)
q, k, v = multi_scale_qkv.chunk(3, dim=-1)
# lightweight global attention
q = self.kernel_func(q)
k = self.kernel_func(k)
v = F.pad(v, (0, 1), mode="constant", value=1.)
dtype = v.dtype
q, k, v = q.float(), k.float(), v.float()
with torch.amp.autocast(device_type=v.device.type, enabled=False):
kv = k.transpose(-1, -2) @ v
out = q @ kv
out = out[..., :-1] / (out[..., -1:] + self.eps)
out = out.to(dtype)
if not torch.jit.is_scripting():
with torch.amp.autocast(device_type=v.device.type, enabled=False):
out = self._attn(q, k, v)
else:
out = self._attn(q, k, v)
# final projection
out = out.transpose(-1, -2).reshape(B, -1, H, W)
@ -264,6 +271,9 @@ class LiteMSA(nn.Module):
return out
register_notrace_module(LiteMSA)
class EfficientVitBlock(nn.Module):
def __init__(
self,