mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fixing tinyvit torchscript issue
This commit is contained in:
parent
bae949f830
commit
fabc4e5bcd
@ -9,13 +9,14 @@ Adapted from official impl at https://github.com/microsoft/Cream/tree/main/TinyV
|
||||
__all__ = ['TinyVit']
|
||||
import math
|
||||
import itertools
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, to_2tuple, trunc_normal_, resample_relative_position_bias_table
|
||||
from timm.layers import DropPath, to_2tuple, trunc_normal_, resample_relative_position_bias_table, _assert
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
@ -178,6 +179,8 @@ class ClassifierHead(nn.Module):
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
attention_bias_cache: Dict[str, torch.Tensor]
|
||||
|
||||
def __init__(self, dim, key_dim, num_heads=8, attn_ratio=4, resolution=(14, 14)):
|
||||
super().__init__()
|
||||
assert isinstance(resolution, tuple) and len(resolution) == 2
|
||||
@ -304,7 +307,7 @@ class TinyVitBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, "input feature has wrong size"
|
||||
_assert(L == H * W, f"input feature has wrong size, expect {H * W}, got {L}")
|
||||
res_x = x
|
||||
if H == self.window_size and W == self.window_size:
|
||||
x = self.attn(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user