Fixing TinyVit TorchScript issue

This commit is contained in:
方曦 2023-08-31 10:47:31 +08:00 committed by Ross Wightman
parent bae949f830
commit fabc4e5bcd
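This change makes TinyVit scriptable by giving the attention bias cache an explicit Dict type annotation and by replacing a bare Python assert with timm's _assert helper. A quick way to check the result (a sketch, assuming a timm checkout that includes this commit; tiny_vit_5m_224 is one of the TinyVit variants registered in timm):

import torch
import timm

# Build a TinyVit variant and compile it with TorchScript;
# this scripting call is what the commit fixes.
model = timm.create_model('tiny_vit_5m_224', pretrained=False).eval()
scripted = torch.jit.script(model)

with torch.no_grad():
    out = scripted(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])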


@@ -9,13 +9,14 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/TinyViT
 __all__ = ['TinyVit']
 import math
 import itertools
+from typing import Dict
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, to_2tuple, trunc_normal_, resample_relative_position_bias_table
+from timm.layers import DropPath, to_2tuple, trunc_normal_, resample_relative_position_bias_table, _assert
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
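The two new imports back the changes below: Dict is needed for the class-level attribute annotation added to Attention, and _assert (timm's trace-friendly assert, which in timm's usual definition resolves to torch._assert when available, with a plain-assert fallback) replaces the Python assert in TinyVitBlock.forward.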
@@ -178,6 +179,8 @@ class ClassifierHead(nn.Module):
 class Attention(torch.nn.Module):
+    attention_bias_cache: Dict[str, torch.Tensor]
+
     def __init__(self, dim, key_dim, num_heads=8, attn_ratio=4, resolution=(14, 14)):
         super().__init__()
         assert isinstance(resolution, tuple) and len(resolution) == 2
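TorchScript needs an explicit type for container attributes that start out empty: an un-annotated dict populated after construction can trip the compiler's type inference. The class-level annotation declares the cache as Dict[str, torch.Tensor] up front. A minimal sketch of the pattern (a standalone toy module, not the actual Attention class; the cache key and bias values are placeholders):

from typing import Dict

import torch
import torch.nn as nn

class BiasCache(nn.Module):
    # Class-level annotation tells TorchScript the attribute's key and
    # value types; the dict is created empty in __init__ and filled lazily.
    attention_bias_cache: Dict[str, torch.Tensor]

    def __init__(self, n: int = 4):
        super().__init__()
        self.attention_biases = nn.Parameter(torch.zeros(n))
        self.attention_bias_cache = {}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        key = str(x.device)
        if key not in self.attention_bias_cache:
            self.attention_bias_cache[key] = self.attention_biases.detach()
        return x + self.attention_bias_cache[key]

scripted = torch.jit.script(BiasCache())  # compiles thanks to the annotation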
@@ -304,7 +307,7 @@ class TinyVitBlock(nn.Module):
     def forward(self, x):
         H, W = self.input_resolution
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
+        _assert(L == H * W, f"input feature has wrong size, expect {H * W}, got {L}")
         res_x = x
         if H == self.window_size and W == self.window_size:
             x = self.attn(x)
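Swapping the bare assert for _assert keeps the size check alive under graph capture: a plain Python assert is simply not recorded by torch.jit.trace, while torch._assert is a symbolically traceable wrapper that TorchScript and FX handle, and the switch also adds the actual and expected sizes to the error message. A toy sketch of the same check in a scripted function (standalone, assuming timm's usual _assert definition):

import torch
from timm.layers import _assert

@torch.jit.script
def check_tokens(x: torch.Tensor, h: int, w: int) -> torch.Tensor:
    # Same shape check as TinyVitBlock.forward: token count must equal H * W.
    seq_len = x.shape[1]
    _assert(seq_len == h * w, f'input feature has wrong size, expect {h * w}, got {seq_len}')
    return x

check_tokens(torch.randn(2, 49, 64), 7, 7)  # passes; a mismatched h, w raises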