diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py index 2ce5820f..fc341f9a 100644 --- a/timm/models/tiny_vit.py +++ b/timm/models/tiny_vit.py @@ -9,13 +9,14 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/TinyV __all__ = ['TinyVit'] import math import itertools +from typing import Dict import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, to_2tuple, trunc_normal_, resample_relative_position_bias_table +from timm.layers import DropPath, to_2tuple, trunc_normal_, resample_relative_position_bias_table, _assert from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -178,6 +179,8 @@ class ClassifierHead(nn.Module): class Attention(torch.nn.Module): + attention_bias_cache: Dict[str, torch.Tensor] + def __init__(self, dim, key_dim, num_heads=8, attn_ratio=4, resolution=(14, 14)): super().__init__() assert isinstance(resolution, tuple) and len(resolution) == 2 @@ -304,7 +307,7 @@ class TinyVitBlock(nn.Module): def forward(self, x): H, W = self.input_resolution B, L, C = x.shape - assert L == H * W, "input feature has wrong size" + _assert(L == H * W, f"input feature has wrong size, expect {H * W}, got {L}") res_x = x if H == self.window_size and W == self.window_size: x = self.attn(x)