Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Davit std (#5)
Squashed commit history: many iterations of "Update davit.py" and "Update test_models.py", plus "starting point", "clean up", and the earlier "Davit revised (#4)" changes.
This commit is contained in:
parent edea013dd1
commit c43340ddd4
test_models.py

@@ -40,7 +40,7 @@ if 'GITHUB_ACTIONS' in os.environ:
        '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
        '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
        'swin*giant*']
        'swin*giant*', 'davit*giant', 'davit*huge']
    NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
else:
    EXCLUDE_FILTERS = []

@@ -271,7 +271,7 @@ if 'GITHUB_ACTIONS' not in os.environ:

EXCLUDE_JIT_FILTERS = [
    '*iabn*', 'tresnet*',  # models using inplace abn unlikely to ever be scriptable
    'dla*', 'hrnet*', 'ghostnet*',  # hopefully fix at some point
    'dla*', 'hrnet*', 'ghostnet*'  # hopefully fix at some point
    'vit_large_*', 'vit_huge_*', 'vit_gi*',
]
davit.py

@@ -12,8 +12,10 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# All rights reserved.
# This source code is licensed under the MIT license

# FIXME remove unused imports

import itertools
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, overload, Tuple, TypeVar, Union
from collections import OrderedDict

import torch

@@ -32,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

__all__ = ['DaViT']


class ConvPosEnc(nn.Module):
    def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
@@ -50,25 +53,21 @@ class ConvPosEnc(nn.Module):
            self.norm = nn.LayerNorm(dim)
        self.activation = nn.GELU() if act else nn.Identity()

    def forward(self, x : Tensor, size: Tuple[int, int]):
        B, N, C = x.shape
        H, W = size
    def forward(self, x : Tensor):
        B, C, H, W = x.shape

        feat = x.transpose(1, 2).view(B, C, H, W)
        feat = self.proj(feat)
        #feat = x.transpose(1, 2).view(B, C, H, W)
        feat = self.proj(x)
        if self.normtype == 'batch':
            feat = self.norm(feat).flatten(2).transpose(1, 2)
        elif self.normtype == 'layer':
            feat = self.norm(feat.flatten(2).transpose(1, 2))
        else:
            feat = feat.flatten(2).transpose(1, 2)
        x = x + self.activation(feat)
        x = x + self.activation(feat).transpose(1, 2).view(B, C, H, W)
        return x


# reason: dim in control sequence
# FIXME reimplement to allow tracing
@register_notrace_module
class PatchEmbed(nn.Module):
    """ Size-agnostic implementation of 2D image to patch embedding,
        allowing input size to be adjusted during model forward operation
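For reference, the rewritten ConvPosEnc takes and returns NCHW tensors directly instead of a (tokens, size) pair. Below is a minimal standalone sketch of that shape contract, assuming the depthwise 3x3 conv projection and skipping the optional norm branches; `ConvPosEncSketch` is illustrative, not the module as committed.

import torch
import torch.nn as nn

class ConvPosEncSketch(nn.Module):
    # simplified depthwise-conv positional encoding, no norm branch
    def __init__(self, dim: int, k: int = 3, act: bool = False):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
        self.activation = nn.GELU() if act else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        feat = self.proj(x)                     # NCHW -> NCHW
        feat = feat.flatten(2).transpose(1, 2)  # NCHW -> (B, H*W, C) token layout
        # residual is added after reshaping the tokens back to NCHW
        return x + self.activation(feat).transpose(1, 2).view(B, C, H, W)

x = torch.randn(2, 96, 56, 56)
assert ConvPosEncSketch(96)(x).shape == x.shape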
@@ -76,13 +75,15 @@ class PatchEmbed(nn.Module):

    def __init__(
            self,
            patch_size=16,
            patch_size=4,
            in_chans=3,
            embed_dim=96,
            overlapped=False):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if patch_size[0] == 4:
            self.proj = nn.Conv2d(
@@ -104,31 +105,20 @@ class PatchEmbed(nn.Module):
            self.norm = nn.LayerNorm(in_chans)

    def forward(self, x : Tensor, size: Tuple[int, int]):
        H, W = size
        dim = x.dim()
        if dim == 3:
            B, HW, C = x.shape
            x = self.norm(x)
            x = x.reshape(B,
                          H,
                          W,
                          C).permute(0, 3, 1, 2).contiguous()

    def forward(self, x : Tensor):
        B, C, H, W = x.shape
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
        if self.norm.normalized_shape[0] == self.in_chans:
            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        x = F.pad(x, (0, (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]))
        x = F.pad(x, (0, 0, 0, (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]))

        x = self.proj(x)
        newsize = (x.size(2), x.size(3))
        x = x.flatten(2).transpose(1, 2)
        if dim == 4:
            x = self.norm(x)
        return x, newsize

        if self.norm.normalized_shape[0] == self.embed_dim:
            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x


class ChannelAttention(nn.Module):

    def __init__(self, dim, num_heads=8, qkv_bias=False):
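The reworked PatchEmbed pads with a modulo so the pad amount is zero when the input is already divisible by the patch size, rather than guarding with an if. A quick arithmetic check of that padding, using hypothetical input sizes:

import torch
import torch.nn.functional as F

patch_size = (4, 4)
x = torch.randn(1, 3, 223, 225)  # hypothetical non-divisible input
pad_h = (patch_size[0] - x.shape[2] % patch_size[0]) % patch_size[0]  # 1
pad_w = (patch_size[1] - x.shape[3] % patch_size[1]) % patch_size[1]  # 3
x = F.pad(x, (0, pad_w, 0, pad_h))  # pad right and bottom only
assert x.shape[2] % patch_size[0] == 0 and x.shape[3] % patch_size[1] == 0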
@@ -153,7 +143,7 @@ class ChannelAttention(nn.Module):
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x


class ChannelBlock(nn.Module):

@@ -162,13 +152,13 @@ class ChannelBlock(nn.Module):
                 ffn=True, cpe_act=False):
        super().__init__()

        self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
                                  ConvPosEnc(dim=dim, k=3, act=cpe_act)])
        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
        self.ffn = ffn
        self.norm1 = norm_layer(dim)
        self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)

        if self.ffn:
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
@@ -178,17 +168,23 @@ class ChannelBlock(nn.Module):
                act_layer=act_layer)

    def forward(self, x : Tensor, size: Tuple[int, int]):
        x = self.cpe[0](x, size)
    def forward(self, x : Tensor):

        B, C, H, W = x.shape

        x = self.cpe1(x).flatten(2).transpose(1, 2)

        cur = self.norm1(x)
        cur = self.attn(cur)
        x = x + self.drop_path(cur)

        x = self.cpe[1](x, size)
        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
        if self.ffn:
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x, size

        x = x.transpose(1, 2).view(B, C, H, W)

        return x


def window_partition(x : Tensor, window_size: int):
    """
@@ -283,9 +279,8 @@ class SpatialBlock(nn.Module):
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
                                  ConvPosEnc(dim=dim, k=3, act=cpe_act)])

        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,

@@ -294,7 +289,8 @@ class SpatialBlock(nn.Module):
            qkv_bias=qkv_bias)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)

        if self.ffn:
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)

@@ -304,12 +300,11 @@ class SpatialBlock(nn.Module):
                act_layer=act_layer)

    def forward(self, x : Tensor, size: Tuple[int, int]):
    def forward(self, x : Tensor):
        B, C, H, W = x.shape

        H, W = size
        B, L, C = x.shape

        shortcut = self.cpe[0](x, size)
        shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
        x = self.norm1(shortcut)
        x = x.view(B, H, W, C)
@@ -338,11 +333,92 @@ class SpatialBlock(nn.Module):
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)

        x = self.cpe[1](x, size)
        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
        if self.ffn:
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x, size

        x = x.transpose(1, 2).view(B, C, H, W)

        return x


class DaViTStage(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            depth = 1,
            patch_size = 4,
            overlapped_patch = False,
            attention_types = ('spatial', 'channel'),
            num_heads = 3,
            window_size = 7,
            mlp_ratio = 4,
            qkv_bias = True,
            drop_path_rates = (0, 0),
            norm_layer = nn.LayerNorm,
            ffn = True,
            cpe_act = False
    ):
        super().__init__()

        self.grad_checkpointing = False

        # patch embedding layer at the beginning of each stage
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chs,
            embed_dim=out_chs,
            overlapped=overlapped_patch
        )
        '''
        repeating alternating attention blocks in each stage
        default: (spatial -> channel) x depth

        potential opportunity to integrate with a more general version of ByobNet/ByoaNet
        since the logic is similar
        '''
        stage_blocks = []
        for block_idx in range(depth):

            dual_attention_block = []

            for attention_id, attention_type in enumerate(attention_types):
                if attention_type == 'spatial':
                    dual_attention_block.append(SpatialBlock(
                        dim=out_chs,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
                        norm_layer=nn.LayerNorm,
                        ffn=ffn,
                        cpe_act=cpe_act,
                        window_size=window_size,
                    ))
                elif attention_type == 'channel':
                    dual_attention_block.append(ChannelBlock(
                        dim=out_chs,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
                        norm_layer=nn.LayerNorm,
                        ffn=ffn,
                        cpe_act=cpe_act
                    ))

            stage_blocks.append(nn.Sequential(*dual_attention_block))

        self.blocks = nn.Sequential(*stage_blocks)

    def forward(self, x : Tensor):
        x = self.patch_embed(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x


class DaViT(nn.Module):
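Inside DaViTStage, each block's drop path rate is indexed as `len(attention_types) * block_idx + attention_id`, so the rates are laid out block-major across the spatial/channel pairs. A small sketch with hypothetical rates showing which rate each block picks up:

attention_types = ('spatial', 'channel')
depth = 2
drop_path_rates = [0.0, 0.1, 0.2, 0.3]  # hypothetical: one rate per attention block in the stage

picked = []
for block_idx in range(depth):
    for attention_id, attention_type in enumerate(attention_types):
        picked.append((block_idx, attention_type,
                       drop_path_rates[len(attention_types) * block_idx + attention_id]))
# picked == [(0, 'spatial', 0.0), (0, 'channel', 0.1), (1, 'spatial', 0.2), (1, 'channel', 0.3)]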
@@ -392,7 +468,7 @@ class DaViT(nn.Module):
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_stages = len(self.embed_dims)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 2 * len(list(itertools.chain(*self.architecture))))]
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, len(attention_types) * len(list(itertools.chain(*self.architecture))))]
        assert self.num_stages == len(self.num_heads) == (sorted(list(itertools.chain(*self.architecture)))[-1] + 1)

        self.num_classes = num_classes
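The stochastic depth schedule is a linear ramp with one entry per attention block; under the new formulation its length is `len(attention_types)` times the total number of dual blocks. A quick check with a hypothetical architecture matching depths (1, 1, 3, 1):

import itertools
import torch

architecture = ((0,), (1,), (2, 2, 2), (3,))  # hypothetical per-stage layout for depths (1, 1, 3, 1)
attention_types = ('spatial', 'channel')
drop_path_rate = 0.1
num_blocks = len(attention_types) * len(list(itertools.chain(*architecture)))  # 12
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
assert len(dpr) == 12 and dpr[0] == 0.0 and abs(dpr[-1] - drop_path_rate) < 1e-6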
@@ -401,51 +477,37 @@ class DaViT(nn.Module):
        self.grad_checkpointing = False
        self.feature_info = []

        self.patch_embeds = nn.ModuleList([
            PatchEmbed(patch_size=patch_size if i == 0 else 2,
                       in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
                       embed_dim=self.embed_dims[i],
                       overlapped=overlapped_patch)
            for i in range(self.num_stages)])
        stages = []

        for stage_id in range(self.num_stages):
            stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]

        self.stages = nn.ModuleList()
        for stage_id, stage_param in enumerate(self.architecture):
            layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))

            stage = nn.ModuleList([
                nn.ModuleList([
                    ChannelBlock(
                        dim=self.embed_dims[item],
                        num_heads=self.num_heads[item],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
                        norm_layer=nn.LayerNorm,
                        ffn=ffn,
                        cpe_act=cpe_act
                    ) if attention_type == 'channel' else
                    SpatialBlock(
                        dim=self.embed_dims[item],
                        num_heads=self.num_heads[item],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
                        norm_layer=nn.LayerNorm,
                        ffn=ffn,
                        cpe_act=cpe_act,
                        window_size=window_size,
                    ) if attention_type == 'spatial' else None
                    for attention_id, attention_type in enumerate(attention_types)]
                ) for layer_id, item in enumerate(stage_param)
            ])
            stage = DaViTStage(
                in_chans if stage_id == 0 else embed_dims[stage_id - 1],
                embed_dims[stage_id],
                depth = depths[stage_id],
                patch_size = patch_size if stage_id == 0 else 2,
                overlapped_patch = overlapped_patch,
                attention_types = attention_types,
                num_heads = num_heads[stage_id],
                window_size = window_size,
                mlp_ratio = mlp_ratio,
                qkv_bias = qkv_bias,
                drop_path_rates = stage_drop_rates,
                norm_layer = nn.LayerNorm,
                ffn = ffn,
                cpe_act = cpe_act
            )

            self.stages.add_module(f'stage_{stage_id}', stage)
            self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.stage_{stage_id}')]

            stages.append(stage)
            self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]

        self.stages = nn.Sequential(*stages)

        self.norms = norm_layer(self.num_features)
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
@@ -469,46 +531,13 @@ class DaViT(nn.Module):
        if global_pool is None:
            global_pool = self.head.global_pool.pool_type
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_network(self, x):
        size: Tuple[int, int] = (x.size(2), x.size(3))
        features = [x]
        sizes = [size]

        for patch_layer, stage in zip(self.patch_embeds, self.stages):
            features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
            for _, block in enumerate(stage):
                for _, layer in enumerate(block):
                    if self.grad_checkpointing and not torch.jit.is_scripting():
                        features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
                    else:
                        features[-1], sizes[-1] = layer(features[-1], sizes[-1])

            # don't append outputs of last stage, since they are already there
            if(len(features) < self.num_stages):
                features.append(features[-1])
                sizes.append(sizes[-1])

        # non-normalized pyramid features + corresponding sizes
        return features, sizes

    def forward_pyramid_features(self, x) -> List[Tensor]:
        x, sizes = self.forward_network(x)
        outs = []
        for i, out in enumerate(x):
            H, W = sizes[i]
            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())

        return outs

    def forward_features(self, x):
        x, sizes = self.forward_network(x)
        x = self.stages(x)
        # take final feature and norm
        x = self.norms(x[-1])
        H, W = sizes[-1]
        x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
        x = self.norms(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        #H, W = sizes[-1]
        #x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
        return x

    def forward_head(self, x, pre_logits: bool = False):
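In the new forward_features, features stay NCHW end to end and the final LayerNorm is applied by permuting to channels-last and back, rather than reshaping token sequences with a tracked size. A minimal sketch of that pattern with hypothetical shapes:

import torch
import torch.nn as nn

num_features = 768                       # hypothetical last-stage width
norm = nn.LayerNorm(num_features)
x = torch.randn(2, num_features, 7, 7)   # NCHW output of the last stage
x = norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # normalize over channels, keep NCHW layout
assert x.shape == (2, num_features, 7, 7)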
@@ -521,17 +550,6 @@ class DaViT(nn.Module):

    def forward(self, x):
        return self.forward_classifier(x)


class DaViTFeatures(DaViT):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))

    def forward(self, x) -> List[Tensor]:
        return self.forward_pyramid_features(x)


def checkpoint_filter_fn(state_dict, model):
    """ Remap MSFT checkpoints -> timm """
@@ -541,38 +559,36 @@ def checkpoint_filter_fn(state_dict, model):
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    import re
    out_dict = {}
    for k, v in state_dict.items():
        k = k.replace('main_blocks.', 'stages.stage_')
        k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
        k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
        k = k.replace('head.', 'head.fc.')
        k = k.replace('cpe.0', 'cpe1')
        k = k.replace('cpe.1', 'cpe2')
        out_dict[k] = v
    return out_dict


def _create_davit(variant, pretrained=False, **kwargs):
    model_cls = DaViT
    features_only = False
    kwargs_filter = None

    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
    out_indices = kwargs.pop('out_indices', default_out_indices)
    if kwargs.pop('features_only', False):
        model_cls = DaViTFeatures
        kwargs_filter = ('num_classes', 'global_pool')
        features_only = True

    model = build_model_with_cfg(
        model_cls,
        DaViT,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs)
    if features_only:
        model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
        model.default_cfg = model.pretrained_cfg  # backwards compat

    return model


def _cfg(url='', **kwargs):  # not sure how this should be set up
    return {
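The checkpoint remap rewrites the original MSFT parameter names into the new stage layout. A small illustration of the substitutions used in checkpoint_filter_fn, applied to hypothetical keys (the example keys and the helper name `remap_key` are illustrative only):

import re

def remap_key(k: str) -> str:
    # mirrors the per-key substitutions shown in the diff above
    k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
    k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
    k = k.replace('head.', 'head.fc.')
    k = k.replace('cpe.0', 'cpe1')
    k = k.replace('cpe.1', 'cpe2')
    return k

assert remap_key('patch_embeds.0.proj.weight') == 'stages.0.patch_embed.proj.weight'
assert remap_key('main_blocks.2.0.0.cpe.0.proj.weight') == 'stages.2.blocks.0.0.cpe1.proj.weight'
assert remap_key('head.weight') == 'head.fc.weight'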
@@ -580,7 +596,7 @@ def _cfg(url='', **kwargs):  # not sure how this should be set up
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embeds.0.proj', 'classifier': 'head.fc',
        'first_conv': 'stages.0.patch_embed.proj', 'classifier': 'head.fc',
        **kwargs
    }
@@ -594,6 +610,9 @@ default_cfgs = generate_default_cfgs({
        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
    'davit_base.msft_in1k': _cfg(
        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
    'davit_large': _cfg(),
    'davit_huge': _cfg(),
    'davit_giant': _cfg(),
})
@@ -616,7 +635,7 @@ def davit_base(pretrained=False, **kwargs):
        num_heads=(4, 8, 16, 32), **kwargs)
    return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)

''' models without weights

# TODO contact authors to get larger pretrained models
@register_model
def davit_large(pretrained=False, **kwargs):

@@ -635,4 +654,3 @@ def davit_giant(pretrained=False, **kwargs):
    model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
                        num_heads=(12, 24, 48, 96), **kwargs)
    return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
'''
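A minimal usage sketch of the registered entrypoints on this branch (assuming this fork of timm is installed; `features_only=True` routes through DaViTFeatures as wired up in _create_davit above):

import timm
import torch

x = torch.randn(1, 3, 224, 224)

model = timm.create_model('davit_base', pretrained=False)
logits = model(x)  # classification logits, shape (1, 1000)

feat_model = timm.create_model('davit_base', pretrained=False, features_only=True)
feats = feat_model(x)  # list of NCHW feature maps, one per stage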