mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
parent
434a03937d
commit
edea013dd1
@ -27,7 +27,9 @@ NON_STD_FILTERS = [
|
||||
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
||||
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
|
||||
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
|
||||
|
||||
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*'
|
||||
|
||||
]
|
||||
NUM_NON_STD = len(NON_STD_FILTERS)
|
||||
|
||||
|
@ -7,47 +7,34 @@ attention in each block. The attention mechanisms used are linear in complexity.
|
||||
|
||||
DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
"""
|
||||
# Copyright (c) 2022 Mingyu Ding
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the MIT license
|
||||
|
||||
import itertools
|
||||
from typing import Tuple
|
||||
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .helpers import build_model_with_cfg
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
|
||||
from collections import OrderedDict
|
||||
from torch import Tensor
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from .features import FeatureInfo
|
||||
from .fx_features import register_notrace_function, register_notrace_module
|
||||
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
|
||||
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
|
||||
from .pretrained import generate_default_cfgs
|
||||
from .registry import register_model
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
__all__ = ['DaViT']
|
||||
|
||||
|
||||
|
||||
|
||||
class MySequential(nn.Sequential):
|
||||
def forward(self, *inputs):
|
||||
for module in self._modules.values():
|
||||
if type(inputs) == tuple:
|
||||
inputs = module(*inputs)
|
||||
else:
|
||||
inputs = module(inputs)
|
||||
return inputs
|
||||
|
||||
|
||||
class ConvPosEnc(nn.Module):
|
||||
def __init__(self, dim, k=3, act=False, normtype=False):
|
||||
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
|
||||
|
||||
super(ConvPosEnc, self).__init__()
|
||||
self.proj = nn.Conv2d(dim,
|
||||
dim,
|
||||
@ -56,16 +43,16 @@ class ConvPosEnc(nn.Module):
|
||||
to_2tuple(k // 2),
|
||||
groups=dim)
|
||||
self.normtype = normtype
|
||||
self.norm = nn.Identity()
|
||||
if self.normtype == 'batch':
|
||||
self.norm = nn.BatchNorm2d(dim)
|
||||
elif self.normtype == 'layer':
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.activation = nn.GELU() if act else nn.Identity()
|
||||
|
||||
def forward(self, x, size: Tuple[int, int]):
|
||||
def forward(self, x : Tensor, size: Tuple[int, int]):
|
||||
B, N, C = x.shape
|
||||
H, W = size
|
||||
assert N == H * W
|
||||
|
||||
feat = x.transpose(1, 2).view(B, C, H, W)
|
||||
feat = self.proj(feat)
|
||||
@ -77,8 +64,11 @@ class ConvPosEnc(nn.Module):
|
||||
feat = feat.flatten(2).transpose(1, 2)
|
||||
x = x + self.activation(feat)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
# reason: dim in control sequence
|
||||
# FIXME reimplement to allow tracing
|
||||
@register_notrace_module
|
||||
class PatchEmbed(nn.Module):
|
||||
""" Size-agnostic implementation of 2D image to patch embedding,
|
||||
allowing input size to be adjusted during model forward operation
|
||||
@ -113,9 +103,10 @@ class PatchEmbed(nn.Module):
|
||||
padding=to_2tuple(pad))
|
||||
self.norm = nn.LayerNorm(in_chans)
|
||||
|
||||
def forward(self, x, size):
|
||||
|
||||
def forward(self, x : Tensor, size: Tuple[int, int]):
|
||||
H, W = size
|
||||
dim = len(x.shape)
|
||||
dim = x.dim()
|
||||
if dim == 3:
|
||||
B, HW, C = x.shape
|
||||
x = self.norm(x)
|
||||
@ -149,7 +140,7 @@ class ChannelAttention(nn.Module):
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x : Tensor):
|
||||
B, N, C = x.shape
|
||||
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
@ -186,7 +177,8 @@ class ChannelBlock(nn.Module):
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer)
|
||||
|
||||
def forward(self, x, size):
|
||||
|
||||
def forward(self, x : Tensor, size: Tuple[int, int]):
|
||||
x = self.cpe[0](x, size)
|
||||
cur = self.norm1(x)
|
||||
cur = self.attn(cur)
|
||||
@ -198,7 +190,7 @@ class ChannelBlock(nn.Module):
|
||||
return x, size
|
||||
|
||||
|
||||
def window_partition(x, window_size: int):
|
||||
def window_partition(x : Tensor, window_size: int):
|
||||
"""
|
||||
Args:
|
||||
x: (B, H, W, C)
|
||||
@ -211,8 +203,8 @@ def window_partition(x, window_size: int):
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, window_size: int, H: int, W: int):
|
||||
@register_notrace_function # reason: int argument is a Proxy
|
||||
def window_reverse(windows : Tensor, window_size: int, H: int, W: int):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
@ -222,6 +214,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
|
||||
Returns:
|
||||
x: (B, H, W, C)
|
||||
"""
|
||||
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
@ -252,7 +245,7 @@ class WindowAttention(nn.Module):
|
||||
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x : Tensor):
|
||||
B_, N, C = x.shape
|
||||
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
@ -310,10 +303,11 @@ class SpatialBlock(nn.Module):
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer)
|
||||
|
||||
def forward(self, x, size):
|
||||
|
||||
def forward(self, x : Tensor, size: Tuple[int, int]):
|
||||
|
||||
H, W = size
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, "input feature has wrong size"
|
||||
|
||||
shortcut = self.cpe[0](x, size)
|
||||
x = self.norm1(shortcut)
|
||||
@ -338,8 +332,8 @@ class SpatialBlock(nn.Module):
|
||||
C)
|
||||
x = window_reverse(attn_windows, self.window_size, Hp, Wp)
|
||||
|
||||
if pad_r > 0 or pad_b > 0:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
#if pad_r > 0 or pad_b > 0:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
x = shortcut + self.drop_path(x)
|
||||
@ -352,12 +346,17 @@ class SpatialBlock(nn.Module):
|
||||
|
||||
|
||||
class DaViT(nn.Module):
|
||||
r""" Dual Attention Transformer
|
||||
r""" DaViT
|
||||
A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645
|
||||
Supports arbitrary input sizes and pyramid feature extraction
|
||||
|
||||
Args:
|
||||
patch_size (int | tuple(int)): Patch size. Default: 4
|
||||
in_chans (int): Number of input image channels. Default: 3
|
||||
embed_dims (tuple(int)): Patch embedding dimension. Default: (64, 128, 192, 256)
|
||||
num_heads (tuple(int)): Number of attention heads in different layers. Default: (4, 8, 12, 16)
|
||||
num_classes (int): Number of classes for classification head. Default: 1000
|
||||
depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
|
||||
patch_size (int | tuple(int)): Patch size. Default: 4
|
||||
embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768)
|
||||
num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
|
||||
window_size (int): Window size. Default: 7
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||
@ -383,7 +382,6 @@ class DaViT(nn.Module):
|
||||
cpe_act=False,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
img_size=224,
|
||||
num_classes=1000,
|
||||
global_pool='avg'
|
||||
):
|
||||
@ -401,7 +399,7 @@ class DaViT(nn.Module):
|
||||
self.num_features = embed_dims[-1]
|
||||
self.drop_rate=drop_rate
|
||||
self.grad_checkpointing = False
|
||||
|
||||
self.feature_info = []
|
||||
|
||||
self.patch_embeds = nn.ModuleList([
|
||||
PatchEmbed(patch_size=patch_size if i == 0 else 2,
|
||||
@ -410,12 +408,12 @@ class DaViT(nn.Module):
|
||||
overlapped=overlapped_patch)
|
||||
for i in range(self.num_stages)])
|
||||
|
||||
main_blocks = []
|
||||
for block_id, block_param in enumerate(self.architecture):
|
||||
layer_offset_id = len(list(itertools.chain(*self.architecture[:block_id])))
|
||||
self.stages = nn.ModuleList()
|
||||
for stage_id, stage_param in enumerate(self.architecture):
|
||||
layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
|
||||
|
||||
block = nn.ModuleList([
|
||||
MySequential(*[
|
||||
stage = nn.ModuleList([
|
||||
nn.ModuleList([
|
||||
ChannelBlock(
|
||||
dim=self.embed_dims[item],
|
||||
num_heads=self.num_heads[item],
|
||||
@ -438,27 +436,17 @@ class DaViT(nn.Module):
|
||||
window_size=window_size,
|
||||
) if attention_type == 'spatial' else None
|
||||
for attention_id, attention_type in enumerate(attention_types)]
|
||||
) for layer_id, item in enumerate(block_param)
|
||||
) for layer_id, item in enumerate(stage_param)
|
||||
])
|
||||
main_blocks.append(block)
|
||||
self.main_blocks = nn.ModuleList(main_blocks)
|
||||
|
||||
'''
|
||||
# layer norms for pyramid feature extraction
|
||||
#
|
||||
# TODO implement pyramid feature extraction
|
||||
#
|
||||
# davit should be a good transformer candidate, since the only official implementation
|
||||
# is for segmentation and detection
|
||||
for i_layer in range(self.num_stages):
|
||||
layer = norm_layer(self.embed_dims[i_layer])
|
||||
layer_name = f'norm{i_layer}'
|
||||
self.add_module(layer_name, layer)
|
||||
'''
|
||||
|
||||
self.stages.add_module(f'stage_{stage_id}', stage)
|
||||
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.stage_{stage_id}')]
|
||||
|
||||
self.norms = norm_layer(self.num_features)
|
||||
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
@ -467,9 +455,7 @@ class DaViT(nn.Module):
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.grad_checkpointing = enable
|
||||
@ -485,55 +471,67 @@ class DaViT(nn.Module):
|
||||
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
||||
|
||||
|
||||
def forward_features_full(self, x):
|
||||
x, size = self.patch_embeds[0](x, (x.size(2), x.size(3)))
|
||||
def forward_network(self, x):
|
||||
size: Tuple[int, int] = (x.size(2), x.size(3))
|
||||
features = [x]
|
||||
sizes = [size]
|
||||
branches = [0]
|
||||
|
||||
for block_index, block_param in enumerate(self.architecture):
|
||||
branch_ids = sorted(set(block_param))
|
||||
for branch_id in branch_ids:
|
||||
if branch_id not in branches:
|
||||
x, size = self.patch_embeds[branch_id](features[-1], sizes[-1])
|
||||
features.append(x)
|
||||
sizes.append(size)
|
||||
branches.append(branch_id)
|
||||
for layer_index, branch_id in enumerate(block_param):
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id])
|
||||
else:
|
||||
features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id])
|
||||
'''
|
||||
# pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model
|
||||
outs = []
|
||||
for i in range(self.num_stages):
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
x_out = norm_layer(features[i])
|
||||
H, W = sizes[i]
|
||||
out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(out)
|
||||
'''
|
||||
# non-normalized pyramid features + corresponding sizes
|
||||
return tuple(features), tuple(sizes)
|
||||
|
||||
for patch_layer, stage in zip(self.patch_embeds, self.stages):
|
||||
features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
|
||||
for _, block in enumerate(stage):
|
||||
for _, layer in enumerate(block):
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
|
||||
else:
|
||||
features[-1], sizes[-1] = layer(features[-1], sizes[-1])
|
||||
|
||||
# don't append outputs of last stage, since they are already there
|
||||
if(len(features) < self.num_stages):
|
||||
features.append(features[-1])
|
||||
sizes.append(sizes[-1])
|
||||
|
||||
|
||||
# non-normalized pyramid features + corresponding sizes
|
||||
return features, sizes
|
||||
|
||||
def forward_pyramid_features(self, x) -> List[Tensor]:
|
||||
x, sizes = self.forward_network(x)
|
||||
outs = []
|
||||
for i, out in enumerate(x):
|
||||
H, W = sizes[i]
|
||||
outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
|
||||
|
||||
return outs
|
||||
|
||||
def forward_features(self, x):
|
||||
x, sizes = self.forward_features_full(x)
|
||||
x, sizes = self.forward_network(x)
|
||||
# take final feature and norm
|
||||
x = self.norms(x[-1])
|
||||
H, W = sizes[-1]
|
||||
x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
|
||||
#print(x.shape)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
|
||||
return self.head(x, pre_logits=pre_logits)
|
||||
|
||||
def forward(self, x):
|
||||
def forward_classifier(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
return self.forward_classifier(x)
|
||||
|
||||
|
||||
class DaViTFeatures(DaViT):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
|
||||
|
||||
def forward(self, x) -> List[Tensor]:
|
||||
return self.forward_pyramid_features(x)
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
""" Remap MSFT checkpoints -> timm """
|
||||
@ -542,11 +540,10 @@ def checkpoint_filter_fn(state_dict, model):
|
||||
|
||||
if 'state_dict' in state_dict:
|
||||
state_dict = state_dict['state_dict']
|
||||
|
||||
|
||||
out_dict = {}
|
||||
import re
|
||||
for k, v in state_dict.items():
|
||||
|
||||
k = k.replace('main_blocks.', 'stages.stage_')
|
||||
k = k.replace('head.', 'head.fc.')
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
@ -554,8 +551,25 @@ def checkpoint_filter_fn(state_dict, model):
|
||||
|
||||
|
||||
def _create_davit(variant, pretrained=False, **kwargs):
|
||||
model = build_model_with_cfg(DaViT, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn, **kwargs)
|
||||
model_cls = DaViT
|
||||
features_only = False
|
||||
kwargs_filter = None
|
||||
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
||||
out_indices = kwargs.pop('out_indices', default_out_indices)
|
||||
if kwargs.pop('features_only', False):
|
||||
model_cls = DaViTFeatures
|
||||
kwargs_filter = ('num_classes', 'global_pool')
|
||||
features_only = True
|
||||
model = build_model_with_cfg(
|
||||
model_cls,
|
||||
variant,
|
||||
pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
**kwargs)
|
||||
if features_only:
|
||||
model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
|
||||
model.default_cfg = model.pretrained_cfg # backwards compat
|
||||
return model
|
||||
|
||||
|
||||
@ -573,13 +587,13 @@ def _cfg(url='', **kwargs): # not sure how this should be set up
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
|
||||
'davit_tiny.msft_in1k': _cfg(
|
||||
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"),
|
||||
'davit_small.msft_in1k': _cfg(
|
||||
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
|
||||
'davit_base.msft_in1k': _cfg(
|
||||
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
|
||||
# official microsoft weights from https://github.com/dingmyu/davit
|
||||
'davit_tiny.msft_in1k': _cfg(
|
||||
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"),
|
||||
'davit_small.msft_in1k': _cfg(
|
||||
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
|
||||
'davit_base.msft_in1k': _cfg(
|
||||
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
|
||||
})
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user