Fixing TinyViT TorchScript issue

This commit is contained in:
方曦 2023-08-31 10:47:31 +08:00 committed by Ross Wightman
parent bae949f830
commit fabc4e5bcd

View File

@ -9,13 +9,14 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/TinyV
__all__ = ['TinyVit'] __all__ = ['TinyVit']
import math import math
import itertools import itertools
from typing import Dict
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, resample_relative_position_bias_table from timm.layers import DropPath, to_2tuple, trunc_normal_, resample_relative_position_bias_table, _assert
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs from ._registry import register_model, generate_default_cfgs
@ -178,6 +179,8 @@ class ClassifierHead(nn.Module):
class Attention(torch.nn.Module): class Attention(torch.nn.Module):
attention_bias_cache: Dict[str, torch.Tensor]
def __init__(self, dim, key_dim, num_heads=8, attn_ratio=4, resolution=(14, 14)): def __init__(self, dim, key_dim, num_heads=8, attn_ratio=4, resolution=(14, 14)):
super().__init__() super().__init__()
assert isinstance(resolution, tuple) and len(resolution) == 2 assert isinstance(resolution, tuple) and len(resolution) == 2
@ -304,7 +307,7 @@ class TinyVitBlock(nn.Module):
def forward(self, x): def forward(self, x):
H, W = self.input_resolution H, W = self.input_resolution
B, L, C = x.shape B, L, C = x.shape
assert L == H * W, "input feature has wrong size" _assert(L == H * W, f"input feature has wrong size, expect {H * W}, got {L}")
res_x = x res_x = x
if H == self.window_size and W == self.window_size: if H == self.window_size and W == self.window_size:
x = self.attn(x) x = self.attn(x)