From 9caf32b93f0f0df6c9a0cdf85d5375f3df3ee6cb Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 31 Aug 2023 17:20:14 -0700
Subject: [PATCH] Move levit style pos bias resize with other rel pos bias utils

---
 timm/layers/__init__.py      |  4 ++--
 timm/layers/pos_embed.py     | 35 -----------------------------------
 timm/layers/pos_embed_rel.py | 32 ++++++++++++++++++++++++++++++++
 timm/models/tiny_vit.py      | 20 +++++++++++---------
 4 files changed, 45 insertions(+), 46 deletions(-)

diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py
index eb5a140a..5a610da6 100644
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -36,9 +36,9 @@ from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
-from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc, resample_relative_position_bias_table
+from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
 from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
-    resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple
+    resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple, resize_rel_pos_bias_table_levit
 from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
     build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
     FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
diff --git a/timm/layers/pos_embed.py b/timm/layers/pos_embed.py
index dc96048c..3e67be00 100644
--- a/timm/layers/pos_embed.py
+++ b/timm/layers/pos_embed.py
@@ -78,38 +78,3 @@ def resample_abs_pos_embed_nhwc(
         _logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
 
     return posemb
-
-
-def resample_relative_position_bias_table(
-        position_bias_table,
-        new_size,
-        interpolation: str = 'bicubic',
-        antialias: bool = True,
-        verbose: bool = False
-):
-    """
-    Resample relative position bias table suggested in LeVit
-    Adapted from: https://github.com/microsoft/Cream/blob/main/TinyViT/utils.py
-    """
-    L1, nH1 = position_bias_table.size()
-    L2, nH2 = new_size
-    assert nH1 == nH2
-    if L1 != L2:
-        orig_dtype = position_bias_table.dtype
-        position_bias_table = position_bias_table.float()
-        # bicubic interpolate relative_position_bias_table if not match
-        S1 = int(L1 ** 0.5)
-        S2 = int(L2 ** 0.5)
-        relative_position_bias_table_resized = F.interpolate(
-            position_bias_table.permute(1, 0).view(1, nH1, S1, S1),
-            size=(S2, S2),
-            mode=interpolation,
-            antialias=antialias)
-        relative_position_bias_table_resized = \
-            relative_position_bias_table_resized.view(nH2, L2).permute(1, 0)
-        relative_position_bias_table_resized.to(orig_dtype)
-        if not torch.jit.is_scripting() and verbose:
-            _logger.info(f'Resized position bias: {L1, nH1} to {L2, nH2}.')
-        return relative_position_bias_table_resized
-    else:
-        return position_bias_table
diff --git a/timm/layers/pos_embed_rel.py b/timm/layers/pos_embed_rel.py
index dc4377d6..4620e81d 100644
--- a/timm/layers/pos_embed_rel.py
+++ b/timm/layers/pos_embed_rel.py
@@ -121,6 +121,38 @@ def resize_rel_pos_bias_table_simple(
     return rel_pos_bias
 
 
+def resize_rel_pos_bias_table_levit(
+        position_bias_table,
+        new_size,
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+):
+    """
+    Resample relative position bias table suggested in LeVit
+    Adapted from: https://github.com/microsoft/Cream/blob/main/TinyViT/utils.py
+    """
+    L1, nH1 = position_bias_table.size()
+    L2, nH2 = new_size
+    assert nH1 == nH2
+    if L1 != L2:
+        orig_dtype = position_bias_table.dtype
+        position_bias_table = position_bias_table.float()
+        # bicubic interpolate relative_position_bias_table if sizes do not match
+        S1 = int(L1 ** 0.5)
+        S2 = int(L2 ** 0.5)
+        relative_position_bias_table_resized = F.interpolate(
+            position_bias_table.permute(1, 0).view(1, nH1, S1, S1),
+            size=(S2, S2),
+            mode=interpolation,
+            antialias=antialias)
+        relative_position_bias_table_resized = \
+            relative_position_bias_table_resized.view(nH2, L2).permute(1, 0)
+        relative_position_bias_table_resized = relative_position_bias_table_resized.to(orig_dtype)
+        return relative_position_bias_table_resized
+    else:
+        return position_bias_table
+
+
 def resize_rel_pos_bias_table(
         rel_pos_bias,
         new_window_size: Tuple[int, int],
diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py
index c8a6007d..c0632983 100644
--- a/timm/models/tiny_vit.py
+++ b/timm/models/tiny_vit.py
@@ -19,7 +19,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
-    to_2tuple, trunc_normal_, resample_relative_position_bias_table, use_fused_attn
+    trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -182,6 +182,7 @@ class Attention(torch.nn.Module):
         self.d = int(attn_ratio * key_dim)
         self.dh = int(attn_ratio * key_dim) * num_heads
         self.attn_ratio = attn_ratio
+        self.resolution = resolution
         self.fused_attn = use_fused_attn()
 
         h = self.dh + nh_kd * 2
@@ -551,17 +552,18 @@ def checkpoint_filter_fn(state_dict, model):
     # TODO: temporary use for testing, need change after weight convert
     if 'model' in state_dict.keys():
         state_dict = state_dict['model']
-    targe_sd = model.state_dict()
-    target_keys = list(targe_sd.keys())
+    target_sd = model.state_dict()
+    target_keys = list(target_sd.keys())
     out_dict = {}
     i = 0
     for k, v in state_dict.items():
-        if not k.endswith('attention_bias_idxs'):
-            if 'attention_biases' in k:
-                # dynamic window size by resampling relative_position_bias_table
-                # TODO: whether move this func into model for dynamic input resolution? (high risk)
-                v = resample_relative_position_bias_table(v.T, targe_sd[target_keys[i]].shape[::-1]).T
-            out_dict[target_keys[i]] = v
+        if k.endswith('attention_bias_idxs'):
+            continue
+        tk = target_keys[i]
+        if 'attention_biases' in k:
+            # TODO: whether move this func into model for dynamic input resolution? (high risk)
+            v = resize_rel_pos_bias_table_levit(v.T, target_sd[tk].shape[::-1]).T
+        out_dict[tk] = v
         i += 1
     return out_dict
 
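
A minimal usage sketch of the relocated helper, under assumed sizes (the 7x7 -> 14x14 grids and 4 heads below are hypothetical, not taken from any checkpoint). The table is laid out as (L, num_heads), where L is the flattened size of a square bias grid; L must be a perfect square for the S = int(L ** 0.5) recovery of the grid side to be exact.

# Sketch: resample a relative position bias table with the relocated helper.
# Assumes timm at or after this patch, which re-exports the function from
# timm.layers; all sizes are hypothetical.
import torch
from timm.layers import resize_rel_pos_bias_table_levit

num_heads = 4
old_table = torch.randn(7 * 7, num_heads)    # (L1, nH): flattened 7x7 bias grid
new_table = resize_rel_pos_bias_table_levit(old_table, (14 * 14, num_heads))
print(new_table.shape)  # torch.Size([196, 4])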
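The tiny_vit.py hunk transposes around the call because TinyViT checkpoints store attention_biases as (num_heads, L) while the helper expects (L, num_heads); shape[::-1] flips the target shape to match. A sketch of that round trip, again with hypothetical shapes standing in for a checkpoint value and its matching model state_dict entry:

# Sketch of the transpose round trip used in checkpoint_filter_fn; v and
# target_shape are stand-ins, not real TinyViT weights.
import torch
from timm.layers import resize_rel_pos_bias_table_levit

v = torch.randn(4, 49)           # checkpoint 'attention_biases': (num_heads, L1)
target_shape = (4, 196)          # matching model tensor shape:   (num_heads, L2)
v = resize_rel_pos_bias_table_levit(v.T, target_shape[::-1]).T
print(v.shape)  # torch.Size([4, 196])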