Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Another efficientvit (mit) tweak, fix torchscript/fx conflict with autocast disable
parent dc18cda2e7
commit 300f54a96f
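
The diff below resolves a conflict between LiteMSA's autocast-disabled float32 attention path and both TorchScript scripting and FX tracing: the attention math moves into a `_attn` helper, the autocast context is only entered when `torch.jit.is_scripting()` is False, and LiteMSA is registered as a non-traceable leaf module. A minimal sketch of that scripting guard, assuming only what the diff shows (the toy module and names below are made up, not timm code):

    import torch
    import torch.nn as nn


    class GuardedFloat32Op(nn.Module):
        # Toy module illustrating the torch.jit.is_scripting() guard this commit
        # places around a disabled-autocast region.
        def _op(self, x):
            # Do the matmul in float32 and cast back, mirroring LiteMSA._attn.
            return (x.float() @ x.float().transpose(-1, -2)).to(x.dtype)

        def forward(self, x):
            if not torch.jit.is_scripting():
                # Eager/AMP path: disable autocast so the float32 upcast sticks.
                with torch.amp.autocast(device_type=x.device.type, enabled=False):
                    return self._op(x)
            else:
                # Scripted path: is_scripting() is resolved at script-compile time,
                # so only this branch is compiled and the context manager above
                # never reaches the TorchScript compiler.
                return self._op(x)


    scripted = torch.jit.script(GuardedFloat32Op())
    print(scripted(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 4])
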
@@ -14,10 +14,11 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from ._registry import register_model, generate_default_cfgs
-from ._builder import build_model_with_cfg
-from ._manipulate import checkpoint_seq
 from timm.layers import SelectAdaptivePool2d, create_conv2d
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_module
+from ._manipulate import checkpoint_seq
+from ._registry import register_model, generate_default_cfgs
 
 
 def val2list(x: list or tuple or any, repeat_time=1):
@@ -233,6 +234,14 @@ class LiteMSA(nn.Module):
             act_layer=act_layer[1],
         )
 
+    def _attn(self, q, k, v):
+        dtype = v.dtype
+        q, k, v = q.float(), k.float(), v.float()
+        kv = k.transpose(-1, -2) @ v
+        out = q @ kv
+        out = out[..., :-1] / (out[..., -1:] + self.eps)
+        return out.to(dtype)
+
     def forward(self, x):
         B, _, H, W = x.shape
 
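
The new `_attn` helper carries out the linear-attention step on a value tensor that forward() has padded with a ones channel: it computes `kv = kᵀ v`, then `q @ kv`, and divides by the last output channel, which accumulates the per-query normalizer. A hedged numerical sketch of that equivalence (not timm code; shapes simplified with no head dimension, and torch.relu standing in for self.kernel_func):

    import torch

    B, N, D = 2, 16, 8
    eps = 1e-15
    q = torch.relu(torch.randn(B, N, D))
    k = torch.relu(torch.randn(B, N, D))
    v = torch.randn(B, N, D)

    # forward() pads v with a ones channel via F.pad(v, (0, 1), value=1.)
    v1 = torch.cat([v, torch.ones(B, N, 1)], dim=-1)

    # Factored form used by _attn: cost O(N * D^2) instead of O(N^2 * D)
    kv = k.transpose(-1, -2) @ v1               # (B, D, D + 1)
    out = q @ kv                                # (B, N, D + 1)
    out = out[..., :-1] / (out[..., -1:] + eps)

    # Reference form with the explicit N x N attention matrix
    attn = q @ k.transpose(-1, -2)              # (B, N, N)
    ref = (attn @ v) / (attn.sum(dim=-1, keepdim=True) + eps)

    print(torch.allclose(out, ref, atol=1e-4))  # expected: True

The helper upcasts q, k, v to float32 before these matmuls and casts the result back to the input dtype, which is why forward() keeps autocast disabled around the call on the eager path below.
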
@@ -243,20 +252,18 @@ class LiteMSA(nn.Module):
             multi_scale_qkv.append(op(qkv))
         multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
         multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
-        q, k, v = multi_scale_qkv.tensor_split(3, dim=-1)
+        q, k, v = multi_scale_qkv.chunk(3, dim=-1)
 
         # lightweight global attention
         q = self.kernel_func(q)
         k = self.kernel_func(k)
         v = F.pad(v, (0, 1), mode="constant", value=1.)
 
-        dtype = v.dtype
-        q, k, v = q.float(), k.float(), v.float()
-        with torch.amp.autocast(device_type=v.device.type, enabled=False):
-            kv = k.transpose(-1, -2) @ v
-            out = q @ kv
-            out = out[..., :-1] / (out[..., -1:] + self.eps)
-        out = out.to(dtype)
+        if not torch.jit.is_scripting():
+            with torch.amp.autocast(device_type=v.device.type, enabled=False):
+                out = self._attn(q, k, v)
+        else:
+            out = self._attn(q, k, v)
 
         # final projection
         out = out.transpose(-1, -2).reshape(B, -1, H, W)
@@ -264,6 +271,9 @@ class LiteMSA(nn.Module):
         return out
 
 
+register_notrace_module(LiteMSA)
+
+
 class EfficientVitBlock(nn.Module):
     def __init__(
             self,
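
The module-level `register_notrace_module(LiteMSA)` call marks LiteMSA as a leaf for timm's FX-based feature extraction, so symbolic tracing treats the whole block as a single call instead of descending into the is_scripting branch and the autocast context. A minimal usage sketch of the same registration pattern on a hypothetical module (the class below is made up, not part of timm):

    import torch
    import torch.nn as nn
    from timm.models._features_fx import register_notrace_module


    class MyLeafBlock(nn.Module):
        # Hypothetical block whose forward() contains control flow that
        # torch.fx symbolic tracing should not descend into.
        def forward(self, x):
            if not torch.jit.is_scripting():
                with torch.amp.autocast(device_type=x.device.type, enabled=False):
                    return x.float().sigmoid().to(x.dtype)
            return x.float().sigmoid().to(x.dtype)


    # Same call form as the diff: timm's FX feature extraction now keeps
    # MyLeafBlock as one leaf node rather than tracing its internals.
    register_notrace_module(MyLeafBlock)
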
|
Loading…
x
Reference in New Issue
Block a user