Fixing efficient_vit torchscript, fx, default_cfg issues

This commit is contained in:
Ross Wightman 2023-08-18 23:23:11 -07:00
parent 58ea1c02c4
commit 7d7589e8da
2 changed files with 68 additions and 58 deletions

View File

@ -53,7 +53,7 @@ class ConvNormAct(nn.Module):
dilation=1,
groups=1,
bias=False,
dropout=0,
dropout=0.,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU,
):
@ -248,7 +248,7 @@ class LiteMSA(nn.Module):
# lightweight global attention
q = self.kernel_func(q)
k = self.kernel_func(k)
v = F.pad(v, (0, 1), mode="constant", value=1)
v = F.pad(v, (0, 1), mode="constant", value=1.)
kv = k.transpose(-1, -2) @ v
out = q @ kv
@ -443,7 +443,7 @@ class ClassifierHead(nn.Module):
in_channels,
widths,
n_classes=1000,
dropout=0,
dropout=0.,
norm_layer=nn.BatchNorm2d,
act_layer=nn.Hardswish,
global_pool='avg',
@ -547,7 +547,7 @@ class EfficientVit(nn.Module):
def get_classifier(self):
return self.head.classifier[-1]
def reset_classifier(self, num_classes, global_pool=None, dropout=0):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
@ -561,7 +561,7 @@ class EfficientVit(nn.Module):
)
else:
if self.global_pool == 'avg':
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
else:
self.head = nn.Identity()
@ -592,6 +592,7 @@ def _cfg(url='', **kwargs):
'classifier': 'head.classifier.4',
'crop_pct': 0.95,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
**kwargs,
}
@ -605,33 +606,33 @@ default_cfgs = generate_default_cfgs({
),
'efficientvit_b1.r256_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=1.0,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
),
'efficientvit_b1.r288_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 288, 288), crop_pct=1.0,
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
),
'efficientvit_b2.r224_in1k': _cfg(
hf_hub_id='timm/',
),
'efficientvit_b2.r256_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=1.0,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
),
'efficientvit_b2.r288_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 288, 288), crop_pct=1.0,
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
),
'efficientvit_b3.r224_in1k': _cfg(
hf_hub_id='timm/',
),
'efficientvit_b3.r256_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=1.0,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
),
'efficientvit_b3.r288_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 288, 288), crop_pct=1.0,
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
),
})

View File

@ -9,12 +9,13 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/Effic
__all__ = ['EfficientVitMsra']
import itertools
from collections import OrderedDict
from typing import Dict
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_
from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@ -113,6 +114,8 @@ class ConvMlp(torch.nn.Module):
class CascadedGroupAttention(torch.nn.Module):
attention_bias_cache: Dict[str, torch.Tensor]
r""" Cascaded Group Attention.
Args:
@ -136,19 +139,19 @@ class CascadedGroupAttention(torch.nn.Module):
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.d = int(attn_ratio * key_dim)
self.val_dim = int(attn_ratio * key_dim)
self.attn_ratio = attn_ratio
qkvs = []
dws = []
for i in range(num_heads):
qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.d))
qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.val_dim))
dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim))
self.qkvs = torch.nn.ModuleList(qkvs)
self.dws = torch.nn.ModuleList(dws)
self.proj = torch.nn.Sequential(
torch.nn.ReLU(),
ConvNorm(self.d * num_heads, dim, bn_weight_init=0)
ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0)
)
points = list(itertools.product(range(resolution), range(resolution)))
@ -161,37 +164,44 @@ class CascadedGroupAttention(torch.nn.Module):
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(
torch.zeros(num_heads, len(attention_offsets)))
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
self.attention_bias_cache = {}
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
if mode and self.attention_bias_cache:
self.attention_bias_cache = {} # clear ab cache
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
if torch.jit.is_tracing() or self.training:
return self.attention_biases[:, self.attention_bias_idxs]
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
device_key = str(device)
if device_key not in self.attention_bias_cache:
self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.attention_bias_cache[device_key]
def forward(self, x):
B, C, H, W = x.shape
feats_in = x.chunk(len(self.qkvs), dim=1)
feats_out = []
feat = feats_in[0]
for i, qkv in enumerate(self.qkvs):
attn_bias = self.attention_biases[:, self.attention_bias_idxs][i] if self.training else self.ab[i]
if i > 0:
feat = feat + feats_in[i]
attn_bias = self.get_attention_biases(x.device)
for head_idx, (qkv, dws) in enumerate(zip(self.qkvs, self.dws)):
if head_idx > 0:
feat = feat + feats_in[head_idx]
feat = qkv(feat)
q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1)
q = self.dws[i](q)
q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.val_dim], dim=1)
q = dws(q)
q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
q = q * self.scale
attn = q.transpose(-2, -1) @ k
attn = attn + attn_bias
attn = attn + attn_bias[head_idx]
attn = attn.softmax(dim=-1)
feat = v @ attn.transpose(-2, -1)
feat = feat.view(B, self.d, H, W)
feat = feat.view(B, self.val_dim, H, W)
feats_out.append(feat)
x = self.proj(torch.cat(feats_out, 1))
return x
@ -237,8 +247,8 @@ class LocalWindowAttention(torch.nn.Module):
H = W = self.resolution
B, C, H_, W_ = x.shape
# Only check this for classification models
assert H == H_ and W == W_, 'input feature has wrong size, expect {}, got {}'.format((H, W), (H_, W_))
_assert(H == H_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
_assert(W == W_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
if H <= self.window_resolution and W <= self.window_resolution:
x = self.attn(x)
else:
@ -519,38 +529,37 @@ def _cfg(url='', **kwargs):
'first_conv': 'patch_embed.conv1.conv',
'classifier': 'head.linear',
'fixed_input_size': True,
'pool_size': (4, 4),
**kwargs,
}
default_cfgs = generate_default_cfgs(
{
'efficientvit_m0.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
),
'efficientvit_m1.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
),
'efficientvit_m2.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
),
'efficientvit_m3.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
),
'efficientvit_m4.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
),
'efficientvit_m5.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
),
}
)
default_cfgs = generate_default_cfgs({
'efficientvit_m0.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
),
'efficientvit_m1.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
),
'efficientvit_m2.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
),
'efficientvit_m3.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
),
'efficientvit_m4.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
),
'efficientvit_m5.r224_in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
),
})
def _create_efficientvit_msra(variant, pretrained=False, **kwargs):