mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fixing efficient_vit torchscript, fx, default_cfg issues
This commit is contained in:
parent
58ea1c02c4
commit
7d7589e8da
@ -53,7 +53,7 @@ class ConvNormAct(nn.Module):
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=False,
|
||||
dropout=0,
|
||||
dropout=0.,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
act_layer=nn.ReLU,
|
||||
):
|
||||
@ -248,7 +248,7 @@ class LiteMSA(nn.Module):
|
||||
# lightweight global attention
|
||||
q = self.kernel_func(q)
|
||||
k = self.kernel_func(k)
|
||||
v = F.pad(v, (0, 1), mode="constant", value=1)
|
||||
v = F.pad(v, (0, 1), mode="constant", value=1.)
|
||||
|
||||
kv = k.transpose(-1, -2) @ v
|
||||
out = q @ kv
|
||||
@ -443,7 +443,7 @@ class ClassifierHead(nn.Module):
|
||||
in_channels,
|
||||
widths,
|
||||
n_classes=1000,
|
||||
dropout=0,
|
||||
dropout=0.,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
act_layer=nn.Hardswish,
|
||||
global_pool='avg',
|
||||
@ -547,7 +547,7 @@ class EfficientVit(nn.Module):
|
||||
def get_classifier(self):
|
||||
return self.head.classifier[-1]
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None, dropout=0):
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
@ -561,7 +561,7 @@ class EfficientVit(nn.Module):
|
||||
)
|
||||
else:
|
||||
if self.global_pool == 'avg':
|
||||
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
|
||||
self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
|
||||
else:
|
||||
self.head = nn.Identity()
|
||||
|
||||
@ -592,6 +592,7 @@ def _cfg(url='', **kwargs):
|
||||
'classifier': 'head.classifier.4',
|
||||
'crop_pct': 0.95,
|
||||
'input_size': (3, 224, 224),
|
||||
'pool_size': (7, 7),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
@ -605,33 +606,33 @@ default_cfgs = generate_default_cfgs({
|
||||
),
|
||||
'efficientvit_b1.r256_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256), crop_pct=1.0,
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
|
||||
),
|
||||
'efficientvit_b1.r288_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 288, 288), crop_pct=1.0,
|
||||
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
|
||||
),
|
||||
'efficientvit_b2.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
),
|
||||
'efficientvit_b2.r256_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256), crop_pct=1.0,
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
|
||||
),
|
||||
'efficientvit_b2.r288_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 288, 288), crop_pct=1.0,
|
||||
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
|
||||
),
|
||||
'efficientvit_b3.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
),
|
||||
'efficientvit_b3.r256_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256), crop_pct=1.0,
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
|
||||
),
|
||||
'efficientvit_b3.r288_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 288, 288), crop_pct=1.0,
|
||||
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
|
||||
),
|
||||
})
|
||||
|
||||
|
@ -9,12 +9,13 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic
|
||||
__all__ = ['EfficientVitMsra']
|
||||
import itertools
|
||||
from collections import OrderedDict
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_
|
||||
from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
@ -113,6 +114,8 @@ class ConvMlp(torch.nn.Module):
|
||||
|
||||
|
||||
class CascadedGroupAttention(torch.nn.Module):
|
||||
attention_bias_cache: Dict[str, torch.Tensor]
|
||||
|
||||
r""" Cascaded Group Attention.
|
||||
|
||||
Args:
|
||||
@ -136,19 +139,19 @@ class CascadedGroupAttention(torch.nn.Module):
|
||||
self.num_heads = num_heads
|
||||
self.scale = key_dim ** -0.5
|
||||
self.key_dim = key_dim
|
||||
self.d = int(attn_ratio * key_dim)
|
||||
self.val_dim = int(attn_ratio * key_dim)
|
||||
self.attn_ratio = attn_ratio
|
||||
|
||||
qkvs = []
|
||||
dws = []
|
||||
for i in range(num_heads):
|
||||
qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.d))
|
||||
qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.val_dim))
|
||||
dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim))
|
||||
self.qkvs = torch.nn.ModuleList(qkvs)
|
||||
self.dws = torch.nn.ModuleList(dws)
|
||||
self.proj = torch.nn.Sequential(
|
||||
torch.nn.ReLU(),
|
||||
ConvNorm(self.d * num_heads, dim, bn_weight_init=0)
|
||||
ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0)
|
||||
)
|
||||
|
||||
points = list(itertools.product(range(resolution), range(resolution)))
|
||||
@ -161,37 +164,44 @@ class CascadedGroupAttention(torch.nn.Module):
|
||||
if offset not in attention_offsets:
|
||||
attention_offsets[offset] = len(attention_offsets)
|
||||
idxs.append(attention_offsets[offset])
|
||||
self.attention_biases = torch.nn.Parameter(
|
||||
torch.zeros(num_heads, len(attention_offsets)))
|
||||
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
|
||||
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
|
||||
self.attention_bias_cache = {}
|
||||
|
||||
@torch.no_grad()
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
if mode and hasattr(self, 'ab'):
|
||||
del self.ab
|
||||
if mode and self.attention_bias_cache:
|
||||
self.attention_bias_cache = {} # clear ab cache
|
||||
|
||||
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
|
||||
if torch.jit.is_tracing() or self.training:
|
||||
return self.attention_biases[:, self.attention_bias_idxs]
|
||||
else:
|
||||
self.ab = self.attention_biases[:, self.attention_bias_idxs]
|
||||
device_key = str(device)
|
||||
if device_key not in self.attention_bias_cache:
|
||||
self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
|
||||
return self.attention_bias_cache[device_key]
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
feats_in = x.chunk(len(self.qkvs), dim=1)
|
||||
feats_out = []
|
||||
feat = feats_in[0]
|
||||
for i, qkv in enumerate(self.qkvs):
|
||||
attn_bias = self.attention_biases[:, self.attention_bias_idxs][i] if self.training else self.ab[i]
|
||||
if i > 0:
|
||||
feat = feat + feats_in[i]
|
||||
attn_bias = self.get_attention_biases(x.device)
|
||||
for head_idx, (qkv, dws) in enumerate(zip(self.qkvs, self.dws)):
|
||||
if head_idx > 0:
|
||||
feat = feat + feats_in[head_idx]
|
||||
feat = qkv(feat)
|
||||
q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1)
|
||||
q = self.dws[i](q)
|
||||
q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.val_dim], dim=1)
|
||||
q = dws(q)
|
||||
q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
|
||||
q = q * self.scale
|
||||
attn = q.transpose(-2, -1) @ k
|
||||
attn = attn + attn_bias
|
||||
attn = attn + attn_bias[head_idx]
|
||||
attn = attn.softmax(dim=-1)
|
||||
feat = v @ attn.transpose(-2, -1)
|
||||
feat = feat.view(B, self.d, H, W)
|
||||
feat = feat.view(B, self.val_dim, H, W)
|
||||
feats_out.append(feat)
|
||||
x = self.proj(torch.cat(feats_out, 1))
|
||||
return x
|
||||
@ -237,8 +247,8 @@ class LocalWindowAttention(torch.nn.Module):
|
||||
H = W = self.resolution
|
||||
B, C, H_, W_ = x.shape
|
||||
# Only check this for classification models
|
||||
assert H == H_ and W == W_, 'input feature has wrong size, expect {}, got {}'.format((H, W), (H_, W_))
|
||||
|
||||
_assert(H == H_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
|
||||
_assert(W == W_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
|
||||
if H <= self.window_resolution and W <= self.window_resolution:
|
||||
x = self.attn(x)
|
||||
else:
|
||||
@ -519,38 +529,37 @@ def _cfg(url='', **kwargs):
|
||||
'first_conv': 'patch_embed.conv1.conv',
|
||||
'classifier': 'head.linear',
|
||||
'fixed_input_size': True,
|
||||
'pool_size': (4, 4),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs(
|
||||
{
|
||||
'efficientvit_m0.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
|
||||
),
|
||||
'efficientvit_m1.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
|
||||
),
|
||||
'efficientvit_m2.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
|
||||
),
|
||||
'efficientvit_m3.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
|
||||
),
|
||||
'efficientvit_m4.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
|
||||
),
|
||||
'efficientvit_m5.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
|
||||
),
|
||||
}
|
||||
)
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'efficientvit_m0.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
|
||||
),
|
||||
'efficientvit_m1.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
|
||||
),
|
||||
'efficientvit_m2.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
|
||||
),
|
||||
'efficientvit_m3.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
|
||||
),
|
||||
'efficientvit_m4.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
|
||||
),
|
||||
'efficientvit_m5.r224_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
|
||||
),
|
||||
})
|
||||
|
||||
|
||||
def _create_efficientvit_msra(variant, pretrained=False, **kwargs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user