Cleanup before samvit merge. Resize abs posembed on the fly, undo some line-wraps, remove redundant unbind, fix HF hub weight load

Ross Wightman 2023-05-18 16:43:48 -07:00
parent c1c6eeb909
commit e9373b1b92
4 changed files with 93 additions and 60 deletions

View File

@@ -36,7 +36,7 @@ from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
-from .pos_embed import resample_abs_pos_embed
+from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
 from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
 from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
     build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \

View File

@@ -37,6 +37,7 @@ class PatchEmbed(nn.Module):
             flatten: bool = True,
             output_fmt: Optional[str] = None,
             bias: bool = True,
+            strict_img_size: bool = True,
     ):
         super().__init__()
         self.patch_size = to_2tuple(patch_size)
@@ -56,6 +57,7 @@ class PatchEmbed(nn.Module):
             # flatten spatial dim and transpose to channels last, kept for bwd compat
             self.flatten = flatten
             self.output_fmt = Format.NCHW
+        self.strict_img_size = strict_img_size

         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
@@ -63,8 +65,18 @@ class PatchEmbed(nn.Module):

     def forward(self, x):
         B, C, H, W = x.shape
         if self.img_size is not None:
-            _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
-            _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+            if self.strict_img_size:
+                _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
+                _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
+            else:
+                _assert(
+                    H % self.patch_size[0] == 0,
+                    f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
+                )
+                _assert(
+                    W % self.patch_size[1] == 0,
+                    f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
+                )
         x = self.proj(x)
         if self.flatten:
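For reference, a minimal usage sketch of the relaxed check (assuming a timm build that includes this commit): with `strict_img_size=False`, `PatchEmbed` only requires the input to be divisible by the patch size instead of matching `img_size` exactly.

```python
import torch
from timm.layers import PatchEmbed

# strict_img_size=False: 512 != img_size (224), but 512 % 16 == 0, so the check passes
embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768, strict_img_size=False)
tokens = embed(torch.randn(1, 3, 512, 512))
print(tokens.shape)  # torch.Size([1, 1024, 768]) with the default flatten=True
```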

View File

@@ -52,3 +52,24 @@ def resample_abs_pos_embed(
         _logger.info(f'Resized position embedding: {old_size} to {new_size}.')

     return posemb
+
+
+def resample_abs_pos_embed_nhwc(
+        posemb,
+        new_size: List[int],
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+        verbose: bool = False,
+):
+    if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
+        return posemb
+
+    # do the interpolation
+    posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
+    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
+    posemb = posemb.permute(0, 2, 3, 1)
+
+    if not torch.jit.is_scripting() and verbose:
+        _logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
+
+    return posemb
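A usage sketch for the new helper (names as defined above, export as added in the first hunk): the embedding stays in NHWC layout and is only interpolated when the grid actually changes.

```python
import torch
from timm.layers import resample_abs_pos_embed_nhwc

pos_embed = torch.randn(1, 64, 64, 768)  # e.g. the 1024px / 16px SAM pretrain grid
resized = resample_abs_pos_embed_nhwc(pos_embed, new_size=[32, 32])
print(resized.shape)  # torch.Size([1, 32, 32, 768]); bicubic + antialias by default
```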

View File

@@ -19,7 +19,8 @@ import torch.nn.functional as F
 import torch.utils.checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, Format
+from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \
+    Format, resample_abs_pos_embed_nhwc
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
@@ -71,24 +72,21 @@ class Attention(nn.Module):
     def forward(self, x):
         B, H, W, _ = x.shape
-        # qkv with shape (3, B, nHead, H * W, C)
         qkv = self.qkv(x).reshape(
             B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        # q, k, v with shape (B * nHead, H * W, C)
-        q, k, v = qkv.unbind(0)
+        # qkv with shape (3, B, nHead, H * W, C)
         q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+        # q, k, v with shape (B * nHead, H * W, C)
         q, k = self.q_norm(q), self.k_norm(k)

-        attn = (q * self.scale) @ k.transpose(-2, -1)
+        q = q * self.scale
+        attn = q @ k.transpose(-2, -1)
         if self.use_rel_pos:
-            attn = add_decomposed_rel_pos(
-                attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
-        x = (attn @ v).view(B, self.num_heads, H, W, -
-                            1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
         x = self.proj(x)

         return x
@@ -136,13 +134,10 @@ class Block(nn.Module):
             proj_drop=proj_drop,
             norm_layer=norm_layer,
             use_rel_pos=use_rel_pos,
-            input_size=input_size if window_size == 0 else (
-                window_size, window_size),
+            input_size=input_size if window_size == 0 else (window_size, window_size),
         )
-        self.ls1 = LayerScale(
-            dim, init_values=init_values) if init_values else nn.Identity()
-        self.drop_path1 = DropPath(
-            drop_path) if drop_path > 0. else nn.Identity()
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

         self.norm2 = norm_layer(dim)
         self.mlp = mlp_layer(
@@ -151,10 +146,8 @@ class Block(nn.Module):
             act_layer=act_layer,
             drop=proj_drop,
         )
-        self.ls2 = LayerScale(
-            dim, init_values=init_values) if init_values else nn.Identity()
-        self.drop_path2 = DropPath(
-            drop_path) if drop_path > 0. else nn.Identity()
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x):
         shortcut = x
@@ -194,10 +187,8 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
         x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
     Hp, Wp = H + pad_h, W + pad_w

-    x = x.view(B, Hp // window_size, window_size,
-               Wp // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous(
-    ).view(-1, window_size, window_size, C)
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
     return windows, (Hp, Wp)
@@ -218,8 +209,7 @@ def window_unpartition(
     Hp, Wp = pad_hw
     H, W = hw
     B = windows.shape[0] // (Hp * Wp // window_size // window_size)
-    x = windows.view(B, Hp // window_size, Wp // window_size,
-                     window_size, window_size, -1)
+    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

     if Hp > H or Wp > W:
@@ -248,16 +238,14 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
             size=max_rel_dist,
             mode="linear",
         )
-        rel_pos_resized = rel_pos_resized.reshape(
-            -1, max_rel_dist).permute(1, 0)
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
     else:
         rel_pos_resized = rel_pos

     # Scale the coords with short length if shapes for q and k are different.
     q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
     k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
-    relative_coords = (q_coords - k_coords) + \
-        (k_size - 1) * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

     return rel_pos_resized[relative_coords.long()]
@@ -331,7 +319,7 @@ class VisionTransformerSAM(nn.Module):
            drop_path_rate: float = 0.,
            weight_init: str = '',
            embed_layer: Callable = partial(
-                PatchEmbed, output_fmt=Format.NHWC),
+                PatchEmbed, output_fmt=Format.NHWC, strict_img_size=False),
            norm_layer: Optional[Callable] = nn.LayerNorm,
            act_layer: Optional[Callable] = nn.GELU,
            block_fn: Callable = Block,
@@ -342,6 +330,7 @@ class VisionTransformerSAM(nn.Module):
            global_attn_indexes: Tuple[int, ...] = (),
            neck_chans: int = 256,
            global_pool: str = 'avg',
+            head_hidden_size: Optional[int] = None
     ):
         """
         Args:
@@ -370,6 +359,7 @@ class VisionTransformerSAM(nn.Module):
             window_size: Window size for window attention blocks. If 0, not use window attention.
             global_attn_indexes: Indexes for blocks using global attention. Used when window_size > 0.
             global_pool: Global pooling type.
+            head_hidden_size: If set, use NormMlpHead
         """
         super().__init__()
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
@@ -388,14 +378,12 @@ class VisionTransformerSAM(nn.Module):
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used
        )
+        grid_size = self.patch_embed.grid_size
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
-            self.pos_embed = nn.Parameter(
-                torch.zeros(1, img_size // patch_size,
-                            img_size // patch_size, embed_dim)
-            )
+            self.pos_embed = nn.Parameter(torch.zeros(1, grid_size[0], grid_size[1], embed_dim))
        else:
-            self.pos_embed = 0.
+            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
@@ -424,7 +412,7 @@ class VisionTransformerSAM(nn.Module):
                mlp_layer=mlp_layer,
                use_rel_pos=use_rel_pos,
                window_size=window_size if i not in global_attn_indexes else 0,
-                input_size=(img_size // patch_size, img_size // patch_size),
+                input_size=grid_size,
            )
            for i in range(depth)])
@@ -451,12 +439,21 @@ class VisionTransformerSAM(nn.Module):
            neck_chans = embed_dim

        # Classifier Head
-        self.head = ClassifierHead(
-            neck_chans,
-            num_classes,
-            pool_type=global_pool,
-            drop_rate=drop_rate,
-        )
+        if head_hidden_size:
+            self.head = NormMlpClassifierHead(
+                neck_chans,
+                num_classes,
+                hidden_size=head_hidden_size,
+                pool_type=global_pool,
+                drop_rate=drop_rate,
+            )
+        else:
+            self.head = ClassifierHead(
+                neck_chans,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=drop_rate,
+            )

     @torch.jit.ignore
     def no_weight_decay(self):
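A hedged sketch of the new head option, assuming kwargs are forwarded to the `VisionTransformerSAM` constructor as is usual for timm model entrypoints: setting `head_hidden_size` selects `NormMlpClassifierHead` instead of the plain `ClassifierHead`.

```python
import timm
import torch

# head_hidden_size picks the NormMlpClassifierHead branch above (assumed kwarg passthrough)
model = timm.create_model('samvit_base_patch16', num_classes=10, head_hidden_size=512)
logits = model(torch.randn(1, 3, 224, 224))  # 224 works now that strict_img_size=False
print(logits.shape)  # torch.Size([1, 10])
```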
@@ -478,15 +475,14 @@ class VisionTransformerSAM(nn.Module):
         return self.head

     def reset_classifier(self, num_classes=0, global_pool=None):
-        self.head = self.head.reset(num_classes, global_pool) if num_classes > 0 else nn.Identity()
-
-    def _pos_embed(self, x):
-        x = x + self.pos_embed
-        return self.pos_drop(x)
+        self.head.reset(num_classes, global_pool)

     def forward_features(self, x):
         x = self.patch_embed(x)
-        x = self._pos_embed(x)
+        if self.pos_embed is not None:
+            # dynamically resize abs pos embedding if needed
+            x = x + resample_abs_pos_embed_nhwc(self.pos_embed, x.shape[1:3])
+        x = self.pos_drop(x)
         x = self.patch_drop(x)
         x = self.norm_pre(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
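Taken together with the non-strict patch embed and the `grid_size`-based `pos_embed`, `forward_features` now accepts inputs whose sides are any multiple of the patch size; a small shape-only sketch with untrained weights:

```python
import timm
import torch

model = timm.create_model('samvit_base_patch16.sa1b', pretrained=False)
feats = model.forward_features(torch.randn(1, 3, 256, 256))  # not the 1024px pretrain size
print(feats.shape)  # neck output in NCHW, e.g. torch.Size([1, 256, 16, 16])
```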
@@ -507,15 +503,19 @@

 def checkpoint_filter_fn(
         state_dict,
-        model
+        model,
 ):
     """ Remap SAM checkpoints -> timm """
+    sam_checkpoint = 'image_encoder.patch_embed.proj.weight' in state_dict
     out_dict = {}
     for k, v in state_dict.items():
-        if 'image_encoder.' in k:
-            new_k = k.replace('image_encoder.', '')
-            new_k = new_k.replace('mlp.lin', 'mlp.fc')
-            out_dict[new_k] = v
+        if k.startswith('image_encoder.'):
+            k = k[14:]
+            k = k.replace('mlp.lin', 'mlp.fc')
+        else:
+            if sam_checkpoint:
+                continue
+        out_dict[k] = v
     return out_dict
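An illustrative sketch of the remap (toy dict, not a real checkpoint): encoder keys lose the 14-character 'image_encoder.' prefix and 'mlp.lin' becomes 'mlp.fc', while non-encoder keys (prompt encoder, mask decoder) are skipped once a SAM checkpoint is detected.

```python
from timm.models.vision_transformer_sam import checkpoint_filter_fn

sam_sd = {
    'image_encoder.patch_embed.proj.weight': 'w0',
    'image_encoder.blocks.0.mlp.lin1.weight': 'w1',
    'mask_decoder.iou_token.weight': 'w2',  # dropped: not part of the image encoder
}
print(checkpoint_filter_fn(sam_sd, model=None))
# {'patch_embed.proj.weight': 'w0', 'blocks.0.mlp.fc1.weight': 'w1'}
```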
@@ -535,19 +535,19 @@ default_cfgs = generate_default_cfgs({

    # Segment-Anything Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only)
    'samvit_base_patch16.sa1b': _cfg(
        url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
-        # hf_hub_id='timm/',
+        hf_hub_id='timm/',
        license='apache-2.0',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
        input_size=(3, 1024, 1024), crop_pct=1.0),
    'samvit_large_patch16.sa1b': _cfg(
        url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
-        # hf_hub_id='timm/',
+        hf_hub_id='timm/',
        license='apache-2.0',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
        input_size=(3, 1024, 1024), crop_pct=1.0),
    'samvit_huge_patch16.sa1b': _cfg(
        url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
-        # hf_hub_id='timm/',
+        hf_hub_id='timm/',
        license='apache-2.0',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
        input_size=(3, 1024, 1024), crop_pct=1.0),
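With `hf_hub_id` uncommented for the `.sa1b` weights, pretrained loading can resolve through the Hugging Face hub rather than only the fbaipublicfiles URL; a sketch (requires network access and `huggingface_hub` installed):

```python
import timm

# pretrained=True pulls the remapped SAM encoder weights via the hub config above
model = timm.create_model('samvit_base_patch16.sa1b', pretrained=True).eval()
print(sum(p.numel() for p in model.parameters()))  # sanity check: encoder parameter count
```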