Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Commit: bb50b69a57
Parent: 7ab9d4555c

fix for torch script
@@ -26,6 +26,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.hub
 from functools import partial
+from typing import List

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
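
Note: the new typing import is there because TorchScript needs explicit type annotations for non-Tensor arguments; an unannotated parameter is assumed to be a Tensor. A minimal sketch of the pattern (the SumBranches module below is illustrative only, not part of this commit):

from typing import List

import torch
import torch.nn as nn


class SumBranches(nn.Module):
    # Without the List[torch.Tensor] annotation, TorchScript would assume `xs`
    # is a single Tensor and the scripted module would reject a Python list.
    def forward(self, xs: List[torch.Tensor]) -> torch.Tensor:
        out = xs[0]
        for i in range(1, len(xs)):
            out = out + xs[i]
        return out


scripted = torch.jit.script(SumBranches())
print(scripted([torch.ones(2), torch.ones(2)]))  # tensor([2., 2.])
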
@@ -135,23 +136,16 @@ class CrossAttention(nn.Module):
 class CrossAttentionBlock(nn.Module):

     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
-                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=True):
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = CrossAttention(
             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.has_mlp = has_mlp
-        if has_mlp:
-            self.norm2 = norm_layer(dim)
-            mlp_hidden_dim = int(dim * mlp_ratio)
-            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

     def forward(self, x):
         x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
-        if self.has_mlp:
-            x = x + self.drop_path(self.mlp(self.norm2(x)))

         return x
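
Note: TorchScript resolves every attribute referenced in forward() at compile time, so a block that only sometimes creates self.norm2/self.mlp cannot be scripted in its has_mlp=False configuration. The fusion call sites in the next hunk always passed has_mlp=False, so the MLP path is dropped rather than made scriptable. A rough illustration of the underlying failure, using a made-up MaybeMlp module:

import torch
import torch.nn as nn


class MaybeMlp(nn.Module):
    def __init__(self, dim, has_mlp=True):
        super().__init__()
        self.has_mlp = has_mlp
        if has_mlp:
            self.mlp = nn.Linear(dim, dim)

    def forward(self, x):
        if self.has_mlp:
            x = self.mlp(x)  # self.mlp does not exist when has_mlp=False
        return x


torch.jit.script(MaybeMlp(8, has_mlp=True))    # scripts fine, self.mlp exists
# torch.jit.script(MaybeMlp(8, has_mlp=False)) # fails: module has no attribute 'mlp'
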
@@ -192,14 +186,12 @@ class MultiScaleBlock(nn.Module):
             nh = num_heads[d_]
             if depth[-1] == 0: # backward capability:
                 self.fusion.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale,
-                                                       drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer,
-                                                       has_mlp=False))
+                                                       drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
             else:
                 tmp = []
                 for _ in range(depth[-1]):
                     tmp.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale,
-                                                   drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer,
-                                                   has_mlp=False))
+                                                   drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
                 self.fusion.append(nn.Sequential(*tmp))

         self.revert_projs = nn.ModuleList()
@@ -210,16 +202,23 @@ class MultiScaleBlock(nn.Module):
             tmp = [norm_layer(dim[(d+1) % num_branches]), act_layer(), nn.Linear(dim[(d+1) % num_branches], dim[d])]
             self.revert_projs.append(nn.Sequential(*tmp))

-    def forward(self, x):
-        outs_b = [block(x_) for x_, block in zip(x, self.blocks)]
+    def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
+
+        outs_b = []
+        for i, block in enumerate(self.blocks):
+            outs_b.append(block(x[i]))
+
         # only take the cls token out
-        proj_cls_token = [proj(x[:, 0:1]) for x, proj in zip(outs_b, self.projs)]
+        proj_cls_token = torch.jit.annotate(List[torch.Tensor], [])
+        for i, proj in enumerate(self.projs):
+            proj_cls_token.append(proj(outs_b[i][:, 0:1, ...]))
+
         # cross attention
         outs = []
-        for i in range(self.num_branches):
+        for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)):
             tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1)
-            tmp = self.fusion[i](tmp)
-            reverted_proj_cls_token = self.revert_projs[i](tmp[:, 0:1, ...])
+            tmp = fusion(tmp)
+            reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...])
             tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)
             outs.append(tmp)
         return outs
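
Note: two TorchScript limitations drive the forward() rewrite above. The comprehensions over zip() of module containers are replaced with explicit loops because, at the time of this change, TorchScript did not support zipping over or indexing an nn.ModuleList with a runtime integer like self.fusion[i], while iterating it with enumerate works; and torch.jit.annotate(List[torch.Tensor], []) spells out the element type of a list that starts out empty. A minimal sketch of both patterns (the PerBranch module is illustrative only):

from typing import List

import torch
import torch.nn as nn


class PerBranch(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.projs = nn.ModuleList([nn.Linear(d, d) for d in dims])

    def forward(self, xs: List[torch.Tensor]) -> List[torch.Tensor]:
        # Explicitly typed empty list for the TorchScript compiler.
        outs = torch.jit.annotate(List[torch.Tensor], [])
        # Iterate the ModuleList; self.projs[i] with a loop variable would not script.
        for i, proj in enumerate(self.projs):
            outs.append(proj(xs[i]))
        return outs


scripted = torch.jit.script(PerBranch([4, 8]))
print([t.shape for t in scripted([torch.randn(1, 4), torch.randn(1, 8)])])
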
@@ -246,11 +245,15 @@ class CrossViT(nn.Module):
         self.num_branches = len(patch_size)

         self.patch_embed = nn.ModuleList()
-        self.pos_embed = nn.ParameterList([nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])) for i in range(self.num_branches)])
+
+        # hard-coded for torch jit script
+        for i in range(self.num_branches):
+            setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
+            setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))

         for im_s, p, d in zip(img_size, patch_size, embed_dim):
             self.patch_embed.append(PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv))

-        self.cls_token = nn.ParameterList([nn.Parameter(torch.zeros(1, 1, embed_dim[i])) for i in range(self.num_branches)])
         self.pos_drop = nn.Dropout(p=drop_rate)

         total_depth = sum([sum(x[-2:]) for x in depth])
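
Note: pos_embed and cls_token were nn.ParameterList containers indexed per branch in forward(); TorchScript could not handle that at the time, which is why the commit registers one named Parameter per branch ("hard-coded for torch jit script") and later selects between them explicitly (getattr with a runtime-formatted name is not scriptable either). A rough sketch of the pattern, using an illustrative two-branch module that is not part of this commit:

import torch
import torch.nn as nn


class TwoBranchTokens(nn.Module):
    def __init__(self, dim0=4, dim1=8):
        super().__init__()
        # One named Parameter per branch instead of an nn.ParameterList.
        self.cls_token_0 = nn.Parameter(torch.zeros(1, 1, dim0))
        self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, dim1))

    def forward(self, i: int, batch: int) -> torch.Tensor:
        # Branch selection spelled out explicitly for the scripting compiler.
        cls_token = self.cls_token_0 if i == 0 else self.cls_token_1
        return cls_token.expand(batch, -1, -1)


scripted = torch.jit.script(TwoBranchTokens())
print(scripted(1, 2).shape)  # torch.Size([2, 1, 8])
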
@@ -270,9 +273,10 @@ class CrossViT(nn.Module):
         self.head = nn.ModuleList([nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)])

         for i in range(self.num_branches):
-            if self.pos_embed[i].requires_grad:
-                trunc_normal_(self.pos_embed[i], std=.02)
-            trunc_normal_(self.cls_token[i], std=.02)
+            if hasattr(self, f'pos_embed_{i}'):
+                # if self.pos_embed[i].requires_grad:
+                trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
+            trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)

         self.apply(self._init_weights)
@@ -302,27 +306,29 @@ class CrossViT(nn.Module):
     def forward_features(self, x):
         B, C, H, W = x.shape
         xs = []
-        for i in range(self.num_branches):
+        for i, patch_embed in enumerate(self.patch_embed):
             x_ = torch.nn.functional.interpolate(x, size=(self.img_size[i], self.img_size[i]), mode='bicubic') if H != self.img_size[i] else x
-            tmp = self.patch_embed[i](x_)
-            cls_tokens = self.cls_token[i].expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+            tmp = patch_embed(x_)
+            cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
+            cls_tokens = cls_tokens.expand(B, -1, -1)
             tmp = torch.cat((cls_tokens, tmp), dim=1)
-            tmp = tmp + self.pos_embed[i]
+            pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1 # hard-coded for torch jit script
+            tmp = tmp + pos_embed
             tmp = self.pos_drop(tmp)
             xs.append(tmp)

-        for blk in self.blocks:
+        for i, blk in enumerate(self.blocks):
             xs = blk(xs)

         # NOTE: was before branch token section, move to here to assure all branch token are before layer norm
-        xs = [self.norm[i](x) for i, x in enumerate(xs)]
+        xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
         out = [x[:, 0] for x in xs]

         return out

     def forward(self, x):
         xs = self.forward_features(x)
-        ce_logits = [self.head[i](x) for i, x in enumerate(xs)]
+        ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
         ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
         return ce_logits
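
Note: in forward_features the per-branch parameters are now picked with an explicit i == 0 check because getattr with a dynamically formatted name cannot be scripted. A quick way to exercise the overall intent of the commit is to script a CrossViT end to end; the snippet below is a sketch that assumes a timm build containing this fix and uses crossvit_15_240 purely as an example model name:

import timm
import torch

# Build a CrossViT variant and compile it with TorchScript.
model = timm.create_model('crossvit_15_240', pretrained=False)
model.eval()
scripted = torch.jit.script(model)

with torch.no_grad():
    out = scripted(torch.randn(1, 3, 240, 240))
print(out.shape)  # expected: torch.Size([1, 1000])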