Fix torchscript for vit-hybrid dynamic_resize

This commit is contained in:
Ross Wightman 2023-08-25 13:33:13 -07:00 committed by Ross Wightman
parent fdd8c7c2da
commit 4d8ecde6cc
2 changed files with 3 additions and 1 deletions

View File

@ -38,7 +38,7 @@ from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \ from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
resample_abs_pos_embed, resample_abs_pos_embed_nhwc, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._registry import generate_default_cfgs, register_model, register_model_deprecations from ._registry import generate_default_cfgs, register_model, register_model_deprecations

View File

@ -31,6 +31,8 @@ class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding """ CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim. Extract feature map from CNN, flatten, project to embedding dim.
""" """
output_fmt: Format
def __init__( def __init__(
self, self,
backbone, backbone,