mirror of https://github.com/huggingface/pytorch-image-models.git

commit 9caf32b93f
parent 63417b438f

    Move levit style pos bias resize with other rel pos bias utils
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -36,9 +36,9 @@ from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
-from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc, resample_relative_position_bias_table
+from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
 from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
-    resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple
+    resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple, resize_rel_pos_bias_table_levit
 from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
     build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
     FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
--- a/timm/layers/pos_embed.py
+++ b/timm/layers/pos_embed.py
@@ -78,38 +78,3 @@ def resample_abs_pos_embed_nhwc(
         _logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
 
     return posemb
-
-
-def resample_relative_position_bias_table(
-        position_bias_table,
-        new_size,
-        interpolation: str = 'bicubic',
-        antialias: bool = True,
-        verbose: bool = False
-):
-    """
-    Resample relative position bias table suggested in LeVit
-    Adapted from: https://github.com/microsoft/Cream/blob/main/TinyViT/utils.py
-    """
-    L1, nH1 = position_bias_table.size()
-    L2, nH2 = new_size
-    assert nH1 == nH2
-    if L1 != L2:
-        orig_dtype = position_bias_table.dtype
-        position_bias_table = position_bias_table.float()
-        # bicubic interpolate relative_position_bias_table if not match
-        S1 = int(L1 ** 0.5)
-        S2 = int(L2 ** 0.5)
-        relative_position_bias_table_resized = F.interpolate(
-            position_bias_table.permute(1, 0).view(1, nH1, S1, S1),
-            size=(S2, S2),
-            mode=interpolation,
-            antialias=antialias)
-        relative_position_bias_table_resized = \
-            relative_position_bias_table_resized.view(nH2, L2).permute(1, 0)
-        relative_position_bias_table_resized.to(orig_dtype)
-        if not torch.jit.is_scripting() and verbose:
-            _logger.info(f'Resized position bias: {L1, nH1} to {L2, nH2}.')
-        return relative_position_bias_table_resized
-    else:
-        return position_bias_table
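Worth noting: the call relative_position_bias_table_resized.to(orig_dtype) near the end of this function, kept as-is in the relocated copy below, discards its result. Tensor.to returns a new tensor rather than converting in place, so whenever a resize actually happens the returned table stays float32. A minimal demonstration of that behavior:

    import torch

    t = torch.randn(2, 2)              # float32 by default
    t.to(torch.float16)                # returns a new tensor; result is discarded
    assert t.dtype == torch.float32    # unchanged

    t = t.to(torch.float16)            # reassignment is what actually converts
    assert t.dtype == torch.float16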
--- a/timm/layers/pos_embed_rel.py
+++ b/timm/layers/pos_embed_rel.py
@@ -121,6 +121,38 @@ def resize_rel_pos_bias_table_simple(
     return rel_pos_bias
 
 
+def resize_rel_pos_bias_table_levit(
+        position_bias_table,
+        new_size,
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+):
+    """
+    Resample relative position bias table suggested in LeVit
+    Adapted from: https://github.com/microsoft/Cream/blob/main/TinyViT/utils.py
+    """
+    L1, nH1 = position_bias_table.size()
+    L2, nH2 = new_size
+    assert nH1 == nH2
+    if L1 != L2:
+        orig_dtype = position_bias_table.dtype
+        position_bias_table = position_bias_table.float()
+        # bicubic interpolate relative_position_bias_table if not match
+        S1 = int(L1 ** 0.5)
+        S2 = int(L2 ** 0.5)
+        relative_position_bias_table_resized = F.interpolate(
+            position_bias_table.permute(1, 0).view(1, nH1, S1, S1),
+            size=(S2, S2),
+            mode=interpolation,
+            antialias=antialias)
+        relative_position_bias_table_resized = \
+            relative_position_bias_table_resized.view(nH2, L2).permute(1, 0)
+        relative_position_bias_table_resized.to(orig_dtype)
+        return relative_position_bias_table_resized
+    else:
+        return position_bias_table
+
+
 def resize_rel_pos_bias_table(
         rel_pos_bias,
         new_window_size: Tuple[int, int],
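For reference, a minimal usage sketch of the relocated helper; the shapes are illustrative, not from the commit. The table layout is (num_positions, num_heads), matching the position_bias_table.size() unpacking above, and num_positions must be a perfect square since the function reshapes it to an S x S grid:

    import torch
    from timm.layers import resize_rel_pos_bias_table_levit  # exported per the __init__.py hunk above

    # (L1, nH) = (49, 4): a 7x7 relative-position grid with 4 attention heads
    old_table = torch.randn(49, 4)

    # resize to a 14x14 grid (196 positions); head count must match
    new_table = resize_rel_pos_bias_table_levit(old_table, (196, 4))
    print(new_table.shape)  # torch.Size([196, 4])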
--- a/timm/models/tiny_vit.py
+++ b/timm/models/tiny_vit.py
@@ -19,7 +19,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
-    to_2tuple, trunc_normal_, resample_relative_position_bias_table, use_fused_attn
+    trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -182,6 +182,7 @@ class Attention(torch.nn.Module):
         self.d = int(attn_ratio * key_dim)
         self.dh = int(attn_ratio * key_dim) * num_heads
         self.attn_ratio = attn_ratio
+        self.resolution = resolution
         self.fused_attn = use_fused_attn()
 
         h = self.dh + nh_kd * 2
@@ -551,17 +552,18 @@ def checkpoint_filter_fn(state_dict, model):
     # TODO: temporary use for testing, need change after weight convert
     if 'model' in state_dict.keys():
         state_dict = state_dict['model']
-    targe_sd = model.state_dict()
-    target_keys = list(targe_sd.keys())
+    target_sd = model.state_dict()
+    target_keys = list(target_sd.keys())
     out_dict = {}
     i = 0
     for k, v in state_dict.items():
-        if not k.endswith('attention_bias_idxs'):
-            if 'attention_biases' in k:
-                # dynamic window size by resampling relative_position_bias_table
-                # TODO: whether move this func into model for dynamic input resolution? (high risk)
-                v = resample_relative_position_bias_table(v.T, targe_sd[target_keys[i]].shape[::-1]).T
-            out_dict[target_keys[i]] = v
-            i += 1
+        if k.endswith('attention_bias_idxs'):
+            continue
+        tk = target_keys[i]
+        if 'attention_biases' in k:
+            # TODO: whether move this func into model for dynamic input resolution? (high risk)
+            v = resize_rel_pos_bias_table_levit(v.T, target_sd[tk].shape[::-1]).T
+        out_dict[tk] = v
+        i += 1
     return out_dict
 
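The double transpose around the resize call exists because TinyViT stores attention_biases as (num_heads, num_positions), while the helper unpacks (num_positions, num_heads); both the tensor and the target shape are therefore flipped going in and coming out. A sketch with illustrative shapes:

    import torch
    from timm.layers import resize_rel_pos_bias_table_levit

    v = torch.randn(4, 49)        # checkpoint bias: (num_heads, L1) = (4, 49)
    target_shape = (4, 196)       # model expects (num_heads, L2) = (4, 196)

    # flip to the helper's (L, nH) layout, resize, flip back
    v = resize_rel_pos_bias_table_levit(v.T, target_shape[::-1]).T
    assert tuple(v.shape) == target_shape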