From 7d7589e8da92d7fe0f65787c234eaa0d49740520 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 18 Aug 2023 23:23:11 -0700
Subject: [PATCH] Fixing efficient_vit torchscript, fx, default_cfg issues
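
efficientvit_msra:
* Replace the raw `assert` in LocalWindowAttention with timm's `_assert`
  helper so the shape check survives both FX tracing and torchscript.
* torchscript can't handle the attribute add/delete dance that cached the
  attention biases on `self.ab`; use an explicitly annotated, device-keyed
  `attention_bias_cache: Dict[str, torch.Tensor]` behind a
  `get_attention_biases()` method instead (cache cleared on `.train()`).
* Iterate the qkv/dws ModuleLists with zip() rather than indexing them with
  a loop variable, which torchscript does not support, and rename `self.d`
  to the clearer `self.val_dim`.

efficientvit_mit:
* `reset_classifier()` drops its unused `dropout` arg and now pools with
  `self.global_pool` instead of the possibly-None `global_pool` local.
* Use float literals for float defaults (`dropout=0.`, pad `value=1.`).
* Add `pool_size` to the default_cfgs: (7, 7) at 224, (8, 8) at 256,
  (9, 9) at 288 for the mit models, (4, 4) for the fixed-224 msra models.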
---
 timm/models/efficientvit_mit.py  |  23 +++----
 timm/models/efficientvit_msra.py | 103 +++++++++++++++++--------------
 2 files changed, 68 insertions(+), 58 deletions(-)
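
Notes (this region, between the --- above and the first diff line, is ignored
when the patch is applied): a minimal smoke-test sketch for the torchscript
path, assuming this branch is installed; the two model names are 224px
variants registered in the files below, and this is not an exhaustive check
of every variant:

    import torch
    import timm

    for name in ('efficientvit_b1', 'efficientvit_m0'):
        # build one mit and one msra variant at their default 224px size
        model = timm.create_model(name, pretrained=False).eval()
        # scripting is the path this patch is meant to fix
        scripted = torch.jit.script(model)
        out = scripted(torch.randn(2, 3, 224, 224))
        assert out.shape == (2, 1000)
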
diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py
index 0b9ddfc3..6d123cd4 100644
--- a/timm/models/efficientvit_mit.py
+++ b/timm/models/efficientvit_mit.py
@@ -53,7 +53,7 @@ class ConvNormAct(nn.Module):
             dilation=1,
             groups=1,
             bias=False,
-            dropout=0,
+            dropout=0.,
             norm_layer=nn.BatchNorm2d,
             act_layer=nn.ReLU,
     ):
@@ -248,7 +248,7 @@ class LiteMSA(nn.Module):
         # lightweight global attention
         q = self.kernel_func(q)
         k = self.kernel_func(k)
-        v = F.pad(v, (0, 1), mode="constant", value=1)
+        v = F.pad(v, (0, 1), mode="constant", value=1.)
 
         kv = k.transpose(-1, -2) @ v
         out = q @ kv
@@ -443,7 +443,7 @@ class ClassifierHead(nn.Module):
             in_channels,
             widths,
             n_classes=1000,
-            dropout=0,
+            dropout=0.,
             norm_layer=nn.BatchNorm2d,
             act_layer=nn.Hardswish,
             global_pool='avg',
@@ -547,7 +547,7 @@ class EfficientVit(nn.Module):
     def get_classifier(self):
         return self.head.classifier[-1]
 
-    def reset_classifier(self, num_classes, global_pool=None, dropout=0):
+    def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = global_pool
@@ -561,7 +561,7 @@ class EfficientVit(nn.Module):
             )
         else:
             if self.global_pool == 'avg':
-                self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
+                self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
             else:
                 self.head = nn.Identity()
 
@@ -592,6 +592,7 @@ def _cfg(url='', **kwargs):
         'classifier': 'head.classifier.4',
         'crop_pct': 0.95,
         'input_size': (3, 224, 224),
+        'pool_size': (7, 7),
         **kwargs,
     }
 
@@ -605,33 +606,33 @@ default_cfgs = generate_default_cfgs({
     ),
     'efficientvit_b1.r256_in1k': _cfg(
         hf_hub_id='timm/',
-        input_size=(3, 256, 256), crop_pct=1.0,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
     ),
     'efficientvit_b1.r288_in1k': _cfg(
         hf_hub_id='timm/',
-        input_size=(3, 288, 288), crop_pct=1.0,
+        input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
     ),
     'efficientvit_b2.r224_in1k': _cfg(
         hf_hub_id='timm/',
     ),
     'efficientvit_b2.r256_in1k': _cfg(
         hf_hub_id='timm/',
-        input_size=(3, 256, 256), crop_pct=1.0,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
     ),
     'efficientvit_b2.r288_in1k': _cfg(
         hf_hub_id='timm/',
-        input_size=(3, 288, 288), crop_pct=1.0,
+        input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
     ),
     'efficientvit_b3.r224_in1k': _cfg(
         hf_hub_id='timm/',
     ),
     'efficientvit_b3.r256_in1k': _cfg(
         hf_hub_id='timm/',
-        input_size=(3, 256, 256), crop_pct=1.0,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
     ),
     'efficientvit_b3.r288_in1k': _cfg(
         hf_hub_id='timm/',
-        input_size=(3, 288, 288), crop_pct=1.0,
+        input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
     ),
 })
 
diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py
index 69ee0bb6..8940df0f 100644
--- a/timm/models/efficientvit_msra.py
+++ b/timm/models/efficientvit_msra.py
@@ -9,12 +9,13 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/EfficientViT
 __all__ = ['EfficientVitMsra']
 import itertools
 from collections import OrderedDict
+from typing import Dict
 
 import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_
+from timm.models.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -113,6 +114,8 @@ class ConvMlp(torch.nn.Module):
 
 
 class CascadedGroupAttention(torch.nn.Module):
+    attention_bias_cache: Dict[str, torch.Tensor]
+
     r""" Cascaded Group Attention.
 
     Args:
@@ -136,19 +139,19 @@ class CascadedGroupAttention(torch.nn.Module):
         self.num_heads = num_heads
         self.scale = key_dim ** -0.5
         self.key_dim = key_dim
-        self.d = int(attn_ratio * key_dim)
+        self.val_dim = int(attn_ratio * key_dim)
         self.attn_ratio = attn_ratio
 
         qkvs = []
         dws = []
         for i in range(num_heads):
-            qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.d))
+            qkvs.append(ConvNorm(dim // (num_heads), self.key_dim * 2 + self.val_dim))
             dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim))
         self.qkvs = torch.nn.ModuleList(qkvs)
         self.dws = torch.nn.ModuleList(dws)
         self.proj = torch.nn.Sequential(
             torch.nn.ReLU(),
-            ConvNorm(self.d * num_heads, dim, bn_weight_init=0)
+            ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0)
         )
 
         points = list(itertools.product(range(resolution), range(resolution)))
@@ -161,37 +164,44 @@ class CascadedGroupAttention(torch.nn.Module):
             if offset not in attention_offsets:
                 attention_offsets[offset] = len(attention_offsets)
             idxs.append(attention_offsets[offset])
-        self.attention_biases = torch.nn.Parameter(
-            torch.zeros(num_heads, len(attention_offsets)))
+        self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)
+        self.attention_bias_cache = {}
 
     @torch.no_grad()
     def train(self, mode=True):
         super().train(mode)
-        if mode and hasattr(self, 'ab'):
-            del self.ab
+        if mode and self.attention_bias_cache:
+            self.attention_bias_cache = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if torch.jit.is_tracing() or self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
         else:
-            self.ab = self.attention_biases[:, self.attention_bias_idxs]
+            device_key = str(device)
+            if device_key not in self.attention_bias_cache:
+                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.attention_bias_cache[device_key]
 
     def forward(self, x):
         B, C, H, W = x.shape
         feats_in = x.chunk(len(self.qkvs), dim=1)
         feats_out = []
         feat = feats_in[0]
-        for i, qkv in enumerate(self.qkvs):
-            attn_bias = self.attention_biases[:, self.attention_bias_idxs][i] if self.training else self.ab[i]
-            if i > 0:
-                feat = feat + feats_in[i]
+        attn_bias = self.get_attention_biases(x.device)
+        for head_idx, (qkv, dws) in enumerate(zip(self.qkvs, self.dws)):
+            if head_idx > 0:
+                feat = feat + feats_in[head_idx]
             feat = qkv(feat)
-            q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1)
-            q = self.dws[i](q)
+            q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.val_dim], dim=1)
+            q = dws(q)
             q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
             q = q * self.scale
             attn = q.transpose(-2, -1) @ k
-            attn = attn + attn_bias
+            attn = attn + attn_bias[head_idx]
             attn = attn.softmax(dim=-1)
             feat = v @ attn.transpose(-2, -1)
-            feat = feat.view(B, self.d, H, W)
+            feat = feat.view(B, self.val_dim, H, W)
             feats_out.append(feat)
         x = self.proj(torch.cat(feats_out, 1))
         return x
@@ -237,8 +247,8 @@ class LocalWindowAttention(torch.nn.Module):
         H = W = self.resolution
         B, C, H_, W_ = x.shape
         # Only check this for classifcation models
-        assert H == H_ and W == W_, 'input feature has wrong size, expect {}, got {}'.format((H, W), (H_, W_))
-
+        _assert(H == H_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
+        _assert(W == W_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
         if H <= self.window_resolution and W <= self.window_resolution:
             x = self.attn(x)
         else:
@@ -519,38 +529,37 @@ def _cfg(url='', **kwargs):
         'first_conv': 'patch_embed.conv1.conv',
         'classifier': 'head.linear',
         'fixed_input_size': True,
+        'pool_size': (4, 4),
         **kwargs,
     }
 
 
-default_cfgs = generate_default_cfgs(
-    {
-        'efficientvit_m0.r224_in1k': _cfg(
-            hf_hub_id='timm/',
-            #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
-        ),
-        'efficientvit_m1.r224_in1k': _cfg(
-            hf_hub_id='timm/',
-            #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
-        ),
-        'efficientvit_m2.r224_in1k': _cfg(
-            hf_hub_id='timm/',
-            #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
-        ),
-        'efficientvit_m3.r224_in1k': _cfg(
-            hf_hub_id='timm/',
-            #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
-        ),
-        'efficientvit_m4.r224_in1k': _cfg(
-            hf_hub_id='timm/',
-            #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
-        ),
-        'efficientvit_m5.r224_in1k': _cfg(
-            hf_hub_id='timm/',
-            #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
-        ),
-    }
-)
+default_cfgs = generate_default_cfgs({
+    'efficientvit_m0.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
+    ),
+    'efficientvit_m1.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
+    ),
+    'efficientvit_m2.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
+    ),
+    'efficientvit_m3.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
+    ),
+    'efficientvit_m4.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
+    ),
+    'efficientvit_m5.r224_in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
+    ),
+})
 
 
 def _create_efficientvit_msra(variant, pretrained=False, **kwargs):