Another efficientvit (mit) tweak, fix torchscript/fx conflict with autocast disable
parent dc18cda2e7
commit 300f54a96f
@@ -14,10 +14,11 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from ._registry import register_model, generate_default_cfgs
-from ._builder import build_model_with_cfg
-from ._manipulate import checkpoint_seq
 from timm.layers import SelectAdaptivePool2d, create_conv2d
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_module
+from ._manipulate import checkpoint_seq
+from ._registry import register_model, generate_default_cfgs
 
 
 def val2list(x: list or tuple or any, repeat_time=1):
@@ -233,6 +234,14 @@ class LiteMSA(nn.Module)
             act_layer=act_layer[1],
         )
 
+    def _attn(self, q, k, v):
+        dtype = v.dtype
+        q, k, v = q.float(), k.float(), v.float()
+        kv = k.transpose(-1, -2) @ v
+        out = q @ kv
+        out = out[..., :-1] / (out[..., -1:] + self.eps)
+        return out.to(dtype)
+
     def forward(self, x):
         B, _, H, W = x.shape
 
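The new `_attn` helper is the ReLU linear attention computed in float32: `v` arrives padded with a trailing channel of ones (the `F.pad(v, (0, 1), ...)` call in the next hunk), so the last channel of `q @ (k^T @ v)` accumulates the normalizer that the final division uses. A small standalone sketch of that trick, with made-up shapes and an assumed `eps`, not taken from the diff:

import torch
import torch.nn.functional as F

q = torch.relu(torch.randn(2, 4, 16, 8))   # (B, heads, N, dim) after a ReLU kernel
k = torch.relu(torch.randn(2, 4, 16, 8))
v = torch.randn(2, 4, 16, 8)
eps = 1e-5

# explicit normalization: numerator q (k^T v), denominator q (sum_j k_j)
num = q @ (k.transpose(-1, -2) @ v)
den = q @ k.transpose(-1, -2).sum(-1, keepdim=True)
ref = num / (den + eps)

# padded-ones trick as used in _attn: the extra channel carries the denominator
v1 = F.pad(v, (0, 1), value=1.)
out = q @ (k.transpose(-1, -2) @ v1)
out = out[..., :-1] / (out[..., -1:] + eps)

print(torch.allclose(ref, out, atol=1e-5))  # expected: True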
@@ -243,20 +252,18 @@ class LiteMSA(nn.Module)
             multi_scale_qkv.append(op(qkv))
         multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
         multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
-        q, k, v = multi_scale_qkv.tensor_split(3, dim=-1)
+        q, k, v = multi_scale_qkv.chunk(3, dim=-1)
 
         # lightweight global attention
         q = self.kernel_func(q)
         k = self.kernel_func(k)
         v = F.pad(v, (0, 1), mode="constant", value=1.)
 
-        dtype = v.dtype
-        q, k, v = q.float(), k.float(), v.float()
-        with torch.amp.autocast(device_type=v.device.type, enabled=False):
-            kv = k.transpose(-1, -2) @ v
-            out = q @ kv
-            out = out[..., :-1] / (out[..., -1:] + self.eps)
-        out = out.to(dtype)
+        if not torch.jit.is_scripting():
+            with torch.amp.autocast(device_type=v.device.type, enabled=False):
+                out = self._attn(q, k, v)
+        else:
+            out = self._attn(q, k, v)
 
         # final projection
         out = out.transpose(-1, -2).reshape(B, -1, H, W)
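The forward path now only enters the `torch.amp.autocast(..., enabled=False)` context manager when not scripting; under `torch.jit.script` the float32 math in `_attn` still runs, just without the context manager that scripting apparently conflicted with. A quick smoke test one could run against this change, assuming the `efficientvit_b0` entrypoint registered in this file (illustrative usage, not part of the commit):

import torch
import timm

model = timm.create_model('efficientvit_b0', pretrained=False).eval()
scripted = torch.jit.script(model)   # previously tripped over the autocast context under scripting
with torch.inference_mode():
    out = scripted(torch.randn(1, 3, 224, 224))
print(out.shape)                     # e.g. torch.Size([1, 1000]) with the default classifier head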
@@ -264,6 +271,9 @@ class LiteMSA(nn.Module)
         return out
 
 
+register_notrace_module(LiteMSA)
+
+
 class EfficientVitBlock(nn.Module):
     def __init__(
             self,
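Registering `LiteMSA` with `register_notrace_module` covers the FX side of the conflict: timm's feature-extraction tracer treats the module as a leaf, so symbolic tracing never has to evaluate the `torch.jit.is_scripting()` branch or the autocast context. A rough sketch of the same leaf-module idea with a bare `torch.fx.Tracer` (illustrative only; timm's helper maintains its own registry internally):

import torch.fx

class LeafLiteMSATracer(torch.fx.Tracer):
    # keep LiteMSA as an opaque call_module node instead of tracing into it,
    # so its data-dependent control flow never reaches the FX graph
    def is_leaf_module(self, m, module_qualified_name):
        return isinstance(m, LiteMSA) or super().is_leaf_module(m, module_qualified_name)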