mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

Initial set_input_size() implementation for ViT & Swin V1; move HybridEmbed to its own location in timm/layers
This commit is contained in:
parent
f920119f3b
commit
392b78aee7
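Before the per-file hunks, a brief usage sketch of what the new API is for. This is not part of the commit; it is a minimal example assuming a timm build that already includes these changes, and it uses a stock timm model name.

import torch
import timm

# A ViT built for 224x224 can be switched to another input resolution in place;
# set_input_size() resamples the absolute position embedding to the new token grid.
model = timm.create_model('vit_base_patch16_224', pretrained=False)
model.set_input_size(img_size=(384, 384))
out = model(torch.randn(1, 3, 384, 384))
print(out.shape)  # (1, 1000)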
@@ -27,6 +27,7 @@ from .gather_excite import GatherExcite
 from .global_context import GlobalContext
 from .grid import ndgrid, meshgrid
 from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
+from .hybrid_embed import HybridEmbed, HybridEmbedWithSize
 from .inplace_abn import InplaceAbn
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d
253 timm/layers/hybrid_embed.py Normal file
@@ -0,0 +1,253 @@
""" Image to Patch Hybrid Embedding Layer

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn as nn
import torch.nn.functional as F

from .format import Format, nchw_to
from .helpers import to_2tuple
from .patch_embed import resample_patch_embed


_logger = logging.getLogger(__name__)


class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """
    output_fmt: Format
    dynamic_img_pad: torch.jit.Final[bool]

    def __init__(
            self,
            backbone: nn.Module,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            in_chans: int = 3,
            embed_dim: int = 768,
            bias: bool = True,
            proj: bool = True,
            flatten: bool = True,
            output_fmt: Optional[str] = None,
            strict_img_size: bool = True,
            dynamic_img_pad: bool = False,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        self.backbone = backbone
        self.in_chans = in_chans
        (
            self.img_size,
            self.patch_size,
            self.feature_size,
            self.feature_ratio,
            self.feature_dim,
            self.grid_size,
            self.num_patches,
        ) = self._init_backbone(
            img_size=img_size,
            patch_size=patch_size,
            feature_size=feature_size,
            feature_ratio=feature_ratio,
        )

        if output_fmt is not None:
            self.flatten = False
            self.output_fmt = Format(output_fmt)
        else:
            # flatten spatial dim and transpose to channels last, kept for bwd compat
            self.flatten = flatten
            self.output_fmt = Format.NCHW
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad
        if not dynamic_img_pad:
            assert self.feature_size[0] % self.patch_size[0] == 0 and self.feature_size[1] % self.patch_size[1] == 0

        if proj:
            self.proj = nn.Conv2d(
                self.feature_dim,
                embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
                bias=bias,
            )
        else:
            assert self.feature_dim == embed_dim, \
                f'The feature dim ({self.feature_dim}) must match embed dim ({embed_dim}) when projection disabled.'
            self.proj = nn.Identity()

    def _init_backbone(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            feature_dim: Optional[int] = None,
    ):
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        if feature_size is None:
            with torch.no_grad():
                # NOTE Most reliable way of determining output dims is to run forward pass
                training = self.backbone.training
                if training:
                    self.backbone.eval()
                o = self.backbone(torch.zeros(1, self.in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                self.backbone.train(training)
            feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)])
        else:
            feature_size = to_2tuple(feature_size)
            feature_ratio = to_2tuple(feature_ratio or 16)
            if feature_dim is None:
                if hasattr(self.backbone, 'feature_info'):
                    feature_dim = self.backbone.feature_info.channels()[-1]
                else:
                    feature_dim = self.backbone.num_features
        grid_size = tuple([f // p for f, p in zip(feature_size, patch_size)])
        num_patches = grid_size[0] * grid_size[1]
        return img_size, patch_size, feature_size, feature_ratio, feature_dim, grid_size, num_patches

    def set_input_size(
            self,
            img_size: Optional[Union[int, Tuple[int, int]]] = None,
            patch_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            feature_dim: Optional[int] = None,
    ):
        assert img_size is not None or patch_size is not None
        img_size = img_size or self.img_size
        new_patch_size = None
        if patch_size is not None:
            new_patch_size = to_2tuple(patch_size)
        if new_patch_size is not None and new_patch_size != self.patch_size:
            assert isinstance(self.proj, nn.Conv2d), 'HybridEmbed must have a projection layer to change patch size.'
            with torch.no_grad():
                new_proj = nn.Conv2d(
                    self.proj.in_channels,
                    self.proj.out_channels,
                    kernel_size=new_patch_size,
                    stride=new_patch_size,
                    bias=self.proj.bias is not None,
                )
                new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
                if self.proj.bias is not None:
                    new_proj.bias.copy_(self.proj.bias)
                self.proj = new_proj
            patch_size = new_patch_size
        patch_size = patch_size or self.patch_size

        if img_size != self.img_size or patch_size != self.patch_size:
            (
                self.img_size,
                self.patch_size,
                self.feature_size,
                self.feature_ratio,
                self.feature_dim,
                self.grid_size,
                self.num_patches,
            ) = self._init_backbone(
                img_size=img_size,
                patch_size=patch_size,
                feature_size=feature_size,
                feature_ratio=feature_ratio,
                feature_dim=feature_dim,
            )

    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
        total_reduction = (
            self.feature_ratio[0] * self.patch_size[0],
            self.feature_ratio[1] * self.patch_size[1]
        )
        if as_scalar:
            return max(total_reduction)
        else:
            return total_reduction

    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
        """ Get feature grid size taking account dynamic padding and backbone network feat reduction
        """
        feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1])
        if self.dynamic_img_pad:
            return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1])
        else:
            return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        if hasattr(self.backbone, 'set_grad_checkpointing'):
            self.backbone.set_grad_checkpointing(enable=enable)
        elif hasattr(self.backbone, 'grad_checkpointing'):
            self.backbone.grad_checkpointing = enable

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        _, _, H, W = x.shape
        if self.dynamic_img_pad:
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, pad_h))
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        elif self.output_fmt != Format.NCHW:
            x = nchw_to(x, self.output_fmt)
        return x


class HybridEmbedWithSize(HybridEmbed):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """
    def __init__(
            self,
            backbone: nn.Module,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            in_chans: int = 3,
            embed_dim: int = 768,
            bias=True,
            proj=True,
    ):
        super().__init__(
            backbone=backbone,
            img_size=img_size,
            patch_size=patch_size,
            feature_size=feature_size,
            feature_ratio=feature_ratio,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=bias,
            proj=proj,
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        if hasattr(self.backbone, 'set_grad_checkpointing'):
            self.backbone.set_grad_checkpointing(enable=enable)
        elif hasattr(self.backbone, 'grad_checkpointing'):
            self.backbone.grad_checkpointing = enable

    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x)
        return x.flatten(2).transpose(1, 2), x.shape[-2:]
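A small usage sketch of the relocated HybridEmbed, not part of the diff. It assumes a timm build containing this commit (so the class is exported from timm.layers) and uses a stock features_only ResNet as the CNN backbone.

import torch
import timm
from timm.layers import HybridEmbed  # exported via timm/layers/__init__.py in this commit

# A features_only backbone returns a list of feature maps; HybridEmbed uses the last one.
backbone = timm.create_model('resnet18', pretrained=False, features_only=True, out_indices=(4,))
embed = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=768)

tokens = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape)        # (1, 49, 768): stride-32 backbone -> 7x7 feature map -> 49 tokens
print(embed.feat_ratio())  # 32

# Re-size for 320x320 inputs; the backbone is probed again to infer the new feature grid.
embed.set_input_size(img_size=320)
print(embed.num_patches)   # 100 (10x10 grid)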
@@ -44,14 +44,7 @@ class PatchEmbed(nn.Module):
     ):
         super().__init__()
         self.patch_size = to_2tuple(patch_size)
-        if img_size is not None:
-            self.img_size = to_2tuple(img_size)
-            self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
-            self.num_patches = self.grid_size[0] * self.grid_size[1]
-        else:
-            self.img_size = None
-            self.grid_size = None
-            self.num_patches = None
+        self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)

         if output_fmt is not None:
             self.flatten = False
@@ -66,6 +59,41 @@ class PatchEmbed(nn.Module):
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

+    def _init_img_size(self, img_size: Union[int, Tuple[int, int]]):
+        assert self.patch_size
+        if img_size is None:
+            return None, None, None
+        img_size = to_2tuple(img_size)
+        grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
+        num_patches = grid_size[0] * grid_size[1]
+        return img_size, grid_size, num_patches
+
+    def set_input_size(
+            self,
+            img_size: Optional[Union[int, Tuple[int, int]]] = None,
+            patch_size: Optional[Union[int, Tuple[int, int]]] = None,
+    ):
+        new_patch_size = None
+        if patch_size is not None:
+            new_patch_size = to_2tuple(patch_size)
+        if new_patch_size is not None and new_patch_size != self.patch_size:
+            with torch.no_grad():
+                new_proj = nn.Conv2d(
+                    self.proj.in_channels,
+                    self.proj.out_channels,
+                    kernel_size=new_patch_size,
+                    stride=new_patch_size,
+                    bias=self.proj.bias is not None,
+                )
+                new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
+                if self.proj.bias is not None:
+                    new_proj.bias.copy_(self.proj.bias)
+                self.proj = new_proj
+            self.patch_size = new_patch_size
+        img_size = img_size or self.img_size
+        if img_size != self.img_size or new_patch_size is not None:
+            self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
+
     def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
         if as_scalar:
             return max(self.patch_size)
@@ -180,13 +208,9 @@ def resample_patch_embed(
     """
     import numpy as np
     try:
-        import functorch
-        vmap = functorch.vmap
+        from torch import vmap
     except ImportError:
-        if hasattr(torch, 'vmap'):
-            vmap = torch.vmap
-        else:
-            assert False, "functorch or a version of torch with vmap is required for FlexiViT resizing."
+        from functorch import vmap

     assert len(patch_embed.shape) == 4, "Four dimensions expected"
     assert len(new_size) == 2, "New shape should only be hw"
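A short sketch of the new PatchEmbed.set_input_size() on its own, not part of the diff; it assumes a timm build with this commit and a torch version that provides vmap (needed by resample_patch_embed, as the hunk above shows).

import torch
from timm.layers import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=384)
print(embed.grid_size, embed.num_patches)   # (14, 14) 196

# Change both the expected input size and the patch size; the 16x16 projection
# kernel is resampled to 8x8 via resample_patch_embed.
embed.set_input_size(img_size=256, patch_size=8)
print(embed.grid_size, embed.num_patches)   # (32, 32) 1024

x = embed(torch.randn(1, 3, 256, 256))
print(x.shape)                              # (1, 1024, 384)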
@@ -27,11 +27,10 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm
+from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm, HybridEmbed
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module
 from ._registry import register_model, generate_default_cfgs
-from .vision_transformer_hybrid import HybridEmbed


 __all__ = ['ConVit']
@@ -140,6 +140,27 @@ class WindowAttention(nn.Module):
         trunc_normal_(self.relative_position_bias_table, std=.02)
         self.softmax = nn.Softmax(dim=-1)

+    def set_window_size(self, window_size: Tuple[int, int]) -> None:
+        """Update window size & interpolate position embeddings
+        Args:
+            window_size (int): New window size
+        """
+        window_size = to_2tuple(window_size)
+        if window_size == self.window_size:
+            return
+        self.window_size = window_size
+        win_h, win_w = self.window_size
+        self.window_area = win_h * win_w
+        with torch.no_grad():
+            new_bias_shape = (2 * win_h - 1) * (2 * win_w - 1), self.num_heads
+            self.relative_position_bias_table = nn.Parameter(
+                resize_rel_pos_bias_table(
+                    self.relative_position_bias_table,
+                    new_window_size=self.window_size,
+                    new_bias_shape=new_bias_shape,
+            ))
+            self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False)
+
     def _get_rel_pos_bias(self) -> torch.Tensor:
         relative_position_bias = self.relative_position_bias_table[
             self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1)  # Wh*Ww,Wh*Ww,nH
@@ -197,6 +218,7 @@ class SwinTransformerBlock(nn.Module):
             head_dim: Optional[int] = None,
             window_size: _int_or_tuple_2_t = 7,
             shift_size: int = 0,
+            always_partition: bool = False,
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             proj_drop: float = 0.,
@@ -224,9 +246,9 @@ class SwinTransformerBlock(nn.Module):
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
-        ws, ss = self._calc_window_shift(window_size, shift_size)
-        self.window_size: Tuple[int, int] = ws
-        self.shift_size: Tuple[int, int] = ss
+        self.target_shift_size = to_2tuple(shift_size)
+        self.always_partition = always_partition
+        self.window_size, self.shift_size = self._calc_window_shift(window_size, target_shift_size=shift_size)
         self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio

@@ -251,6 +273,9 @@ class SwinTransformerBlock(nn.Module):
         )
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

+        self._make_attention_mask()
+
+    def _make_attention_mask(self):
         if any(self.shift_size):
             # calculate attention mask for SW-MSA
             H, W = self.input_resolution
@@ -274,16 +299,47 @@ class SwinTransformerBlock(nn.Module):
             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
         else:
             attn_mask = None

         self.register_buffer("attn_mask", attn_mask, persistent=False)

-    def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+    def _calc_window_shift(
+            self,
+            target_window_size: Union[int, Tuple[int, int]],
+            target_shift_size: Optional[Union[int, Tuple[int, int]]] = None,
+    ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         target_window_size = to_2tuple(target_window_size)
-        target_shift_size = to_2tuple(target_shift_size)
+        if target_shift_size is None:
+            # if passed value is None, recalculate from default window_size // 2 if it was active
+            target_shift_size = self.target_shift_size
+            if any(target_shift_size):
+                target_shift_size = [target_window_size[0] // 2, target_window_size[1] // 2]
+        else:
+            target_shift_size = to_2tuple(target_shift_size)
+        if self.always_partition:
+            return target_window_size, target_shift_size
         window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
         shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
         return tuple(window_size), tuple(shift_size)

+    def set_input_size(
+            self,
+            feat_size: Tuple[int, int],
+            window_size: Tuple[int, int],
+            always_partition: Optional[bool] = None,
+    ):
+        """
+        Args:
+            feat_size: New input resolution
+            window_size: New window size
+            always_partition: Change always_partition attribute if not None
+        """
+        self.input_resolution = feat_size
+        if always_partition is not None:
+            self.always_partition = always_partition
+        self.window_size, self.shift_size = self._calc_window_shift(window_size)
+        self.window_area = self.window_size[0] * self.window_size[1]
+        self.attn.set_window_size(self.window_size)
+        self._make_attention_mask()
+
     def _attn(self, x):
         B, H, W, C = x.shape

@@ -374,6 +430,7 @@ class SwinTransformerStage(nn.Module):
             num_heads: int = 4,
             head_dim: Optional[int] = None,
             window_size: _int_or_tuple_2_t = 7,
+            always_partition: bool = False,
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             proj_drop: float = 0.,
@@ -427,6 +484,7 @@ class SwinTransformerStage(nn.Module):
                 head_dim=head_dim,
                 window_size=window_size,
                 shift_size=0 if (i % 2 == 0) else shift_size,
+                always_partition=always_partition,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop,
@@ -436,6 +494,30 @@ class SwinTransformerStage(nn.Module):
             )
             for i in range(depth)])

+    def set_input_size(
+            self,
+            feat_size: Tuple[int, int],
+            window_size: int,
+            always_partition: Optional[bool] = None,
+    ):
+        """Method updates the resolution to utilize and the window size and so the pair-wise relative positions.
+
+        Args:
+            feat_size (Tuple[int, int]): New input resolution
+            window_size (int): New window size
+        """
+        self.input_resolution = feat_size
+        if isinstance(self.downsample, nn.Identity):
+            self.output_resolution = feat_size
+        else:
+            self.output_resolution = tuple(i // 2 for i in feat_size)
+        for block in self.blocks:
+            block.set_input_size(
+                feat_size=self.output_resolution,
+                window_size=window_size,
+                always_partition=always_partition,
+            )
+
     def forward(self, x):
         x = self.downsample(x)

@@ -465,6 +547,7 @@ class SwinTransformer(nn.Module):
             num_heads: Tuple[int, ...] = (3, 6, 12, 24),
             head_dim: Optional[int] = None,
             window_size: _int_or_tuple_2_t = 7,
+            always_partition: bool = False,
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             drop_rate: float = 0.,
@@ -546,6 +629,7 @@ class SwinTransformer(nn.Module):
                 num_heads=num_heads[i],
                 head_dim=head_dim[i],
                 window_size=window_size[i],
+                always_partition=always_partition,
                 mlp_ratio=mlp_ratio[i],
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop_rate,
@@ -556,7 +640,7 @@ class SwinTransformer(nn.Module):
             in_dim = out_dim
             if i > 0:
                 scale *= 2
-            self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')]
+            self.feature_info += [dict(num_chs=out_dim, reduction=patch_size * scale, module=f'layers.{i}')]
         self.layers = nn.Sequential(*layers)

         self.norm = norm_layer(self.num_features)
@@ -584,6 +668,36 @@ class SwinTransformer(nn.Module):
                 nwd.add(n)
         return nwd

+    def set_input_size(
+            self,
+            img_size: Optional[Tuple[int, int]] = None,
+            patch_size: Optional[Tuple[int, int]] = None,
+            window_size: Optional[Tuple[int, int]] = None,
+            window_ratio: int = 32,
+            always_partition: Optional[bool] = None,
+    ) -> None:
+        """ Updates the image resolution and window size.
+
+        Args:
+            img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
+            window_size (Optional[int]): New window size, if None based on new_img_size // window_div
+            window_ratio (int): divisor for calculating window size from image size
+        """
+        if img_size is not None or patch_size is not None:
+            self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
+            self.patch_grid = self.patch_embed.grid_size
+        if window_size is None:
+            img_size = self.patch_embed.img_size
+            window_size = tuple([s // window_ratio for s in img_size])
+        for index, stage in enumerate(self.layers):
+            stage_scale = 2 ** max(index - 1, 0)
+            print(self.patch_grid, stage_scale)
+            stage.set_input_size(
+                feat_size=(self.patch_grid[0] // stage_scale, self.patch_grid[1] // stage_scale),
+                window_size=window_size,
+                always_partition=always_partition,
+            )
+
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         return dict(
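The hunks above appear to belong to the Swin V1 model (timm/models/swin_transformer.py). A usage sketch of the new model-level call, not part of the diff, assuming a timm build containing this commit and a stock model name.

import torch
import timm

model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False)

# Move to 256x256 input; with window_size left as None and the default window_ratio=32,
# the new window becomes 8x8 and every stage rebuilds its shift/partition and attn masks.
# always_partition is left as None, so the existing behaviour is unchanged.
model.set_input_size(img_size=(256, 256))
out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # (1, 1000)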
@@ -119,7 +119,7 @@ class WindowMultiHeadAttention(nn.Module):
         assert dim % num_heads == 0, \
             "The number of input features (in_features) are not divisible by the number of heads (num_heads)."
         self.in_features: int = dim
-        self.window_size: Tuple[int, int] = window_size
+        self.window_size: Tuple[int, int] = to_2tuple(window_size)
         self.num_heads: int = num_heads
         self.sequential_attn: bool = sequential_attn

@@ -152,16 +152,15 @@ class WindowMultiHeadAttention(nn.Module):
             1.0 + relative_coordinates.abs())
         self.register_buffer("relative_coordinates_log", relative_coordinates_log, persistent=False)

-    def update_input_size(self, new_window_size: int, **kwargs: Any) -> None:
-        """Method updates the window size and so the pair-wise relative positions
-
+    def set_window_size(self, window_size: Tuple[int, int]) -> None:
+        """Update window size & interpolate position embeddings
         Args:
-            new_window_size (int): New window size
-            kwargs (Any): Unused
+            window_size (int): New window size
         """
-        # Set new window size and new pair-wise relative positions
-        self.window_size: int = new_window_size
-        self._make_pair_wise_relative_positions()
+        window_size = to_2tuple(window_size)
+        if window_size != self.window_size:
+            self.window_size = window_size
+            self._make_pair_wise_relative_positions()

     def _relative_positional_encodings(self) -> torch.Tensor:
         """Method computes the relative positional encodings
@@ -321,18 +320,18 @@ class SwinTransformerV2CrBlock(nn.Module):
             nn.init.constant_(self.norm1.weight, self.init_values)
             nn.init.constant_(self.norm2.weight, self.init_values)

-    def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None:
+    def set_input_size(self, feat_size: Tuple[int, int], window_size: Tuple[int, int]) -> None:
         """Method updates the image resolution to be processed and window size and so the pair-wise relative positions.

         Args:
-            new_window_size (int): New window size
-            new_feat_size (Tuple[int, int]): New input resolution
+            feat_size (Tuple[int, int]): New input resolution
+            window_size (int): New window size
         """
         # Update input resolution
-        self.feat_size: Tuple[int, int] = new_feat_size
-        self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(new_window_size))
+        self.feat_size: Tuple[int, int] = feat_size
+        self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
         self.window_area = self.window_size[0] * self.window_size[1]
-        self.attn.update_input_size(new_window_size=self.window_size)
+        self.attn.set_window_size(self.window_size)
         self._make_attention_mask()

     def _shifted_window_attn(self, x):
@@ -510,18 +509,18 @@ class SwinTransformerV2CrStage(nn.Module):
             for index in range(depth)]
         )

-    def update_input_size(self, new_window_size: int, new_feat_size: Tuple[int, int]) -> None:
+    def set_input_size(self, feat_size: Tuple[int, int], window_size: int) -> None:
         """Method updates the resolution to utilize and the window size and so the pair-wise relative positions.

         Args:
-            new_window_size (int): New window size
-            new_feat_size (Tuple[int, int]): New input resolution
+            window_size (int): New window size
+            feat_size (Tuple[int, int]): New input resolution
         """
         self.feat_size: Tuple[int, int] = (
-            (new_feat_size[0] // 2, new_feat_size[1] // 2) if self.downscale else new_feat_size
+            (feat_size[0] // 2, feat_size[1] // 2) if self.downscale else feat_size
         )
         for block in self.blocks:
-            block.update_input_size(new_window_size=new_window_size, new_feat_size=self.feat_size)
+            block.set_input_size(feat_size=self.feat_size, window_size=window_size)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass.
@@ -657,33 +656,32 @@ class SwinTransformerV2Cr(nn.Module):
         if weight_init != 'skip':
             named_apply(init_weights, self)

-    def update_input_size(
+    def set_input_size(
             self,
-            new_img_size: Optional[Tuple[int, int]] = None,
-            new_window_size: Optional[int] = None,
-            img_window_ratio: int = 32,
+            img_size: Optional[Tuple[int, int]] = None,
+            window_size: Optional[Tuple[int, int]] = None,
+            window_ratio: int = 32,
     ) -> None:
         """Method updates the image resolution to be processed and window size and so the pair-wise relative positions.

         Args:
-            new_window_size (Optional[int]): New window size, if None based on new_img_size // window_div
-            new_img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
-            img_window_ratio (int): divisor for calculating window size from image size
+            img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
+            window_size (Optional[int]): New window size, if None based on new_img_size // window_div
+            window_ratio (int): divisor for calculating window size from image size
         """
-        # Check parameters
-        if new_img_size is None:
-            new_img_size = self.img_size
+        if img_size is None:
+            img_size = self.img_size
         else:
-            new_img_size = to_2tuple(new_img_size)
-        if new_window_size is None:
-            new_window_size = tuple([s // img_window_ratio for s in new_img_size])
+            img_size = to_2tuple(img_size)
+        if window_size is None:
+            window_size = tuple([s // window_ratio for s in img_size])
         # Compute new patch resolution & update resolution of each stage
-        new_patch_grid_size = (new_img_size[0] // self.patch_size, new_img_size[1] // self.patch_size)
+        patch_grid_size = (img_size[0] // self.patch_size, img_size[1] // self.patch_size)
         for index, stage in enumerate(self.stages):
             stage_scale = 2 ** max(index - 1, 0)
-            stage.update_input_size(
-                new_window_size=new_window_size,
-                new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),
+            stage.set_input_size(
+                feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale),
+                window_size=window_size,
             )

     @torch.jit.ignore
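The hunks above appear to be from the Swin V2 CR model (timm/models/swin_transformer_v2_cr.py), where the old update_input_size() is renamed and its arguments are aligned with the rest of the set_input_size API. A sketch of the renamed call only, not part of the diff, using a stock timm model name; these hunks alone do not show an end-to-end forward at the new resolution, so none is claimed here.

import timm

model = timm.create_model('swinv2_cr_tiny_ns_224', pretrained=False)

# Old call (before this commit):
#   model.update_input_size(new_img_size=(256, 256), new_window_size=(8, 8), img_window_ratio=32)
# New call (after this commit), matching the Swin V1 naming:
model.set_input_size(img_size=(256, 256), window_size=(8, 8))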
@@ -632,6 +632,31 @@ class VisionTransformer(nn.Module):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

+    def set_input_size(
+            self,
+            img_size: Optional[Tuple[int, int]] = None,
+            patch_size: Optional[Tuple[int, int]] = None,
+    ):
+        """Method updates the input image resolution, patch size
+
+        Args:
+            img_size: New input resolution, if None current resolution is used
+            patch_size: New patch size, if None existing patch size is used
+        """
+        prev_grid_size = self.patch_embed.grid_size
+        self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
+        if self.pos_embed is not None:
+            num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
+            num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens
+            if num_new_tokens != self.pos_embed.shape[1]:
+                self.pos_embed = nn.Parameter(resample_abs_pos_embed(
+                    self.pos_embed,
+                    new_size=self.patch_embed.grid_size,
+                    old_size=prev_grid_size,
+                    num_prefix_tokens=num_prefix_tokens,
+                    verbose=True,
+                ))
+
     def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
         if self.pos_embed is None:
             return x.view(x.shape[0], -1, x.shape[-1])
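The hunk above adds set_input_size() to VisionTransformer. A sketch of the patch-size path, not part of the diff, assuming a timm build with this commit and a torch version providing vmap (used by resample_patch_embed).

import torch
import timm

model = timm.create_model('vit_base_patch16_224', pretrained=False)

# Changing the patch size resamples the 16x16 patch projection to 8x8 and the absolute
# position embedding to the resulting 28x28 grid, as the hunk above shows.
model.set_input_size(patch_size=(8, 8))
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # (1, 1000)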
@@ -19,10 +19,9 @@ from typing import Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to
+from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, HybridEmbed

 from ._builder import build_model_with_cfg
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -31,172 +30,6 @@ from .resnetv2 import ResNetV2, create_resnetv2_stem
 from .vision_transformer import VisionTransformer


-class HybridEmbed(nn.Module):
-    """ CNN Feature Map Embedding
-    Extract feature map from CNN, flatten, project to embedding dim.
-    """
-    output_fmt: Format
-    dynamic_img_pad: torch.jit.Final[bool]
-
-    def __init__(
-            self,
-            backbone: nn.Module,
-            img_size: Union[int, Tuple[int, int]] = 224,
-            patch_size: Union[int, Tuple[int, int]] = 1,
-            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
-            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
-            in_chans: int = 3,
-            embed_dim: int = 768,
-            bias: bool = True,
-            proj: bool = True,
-            flatten: bool = True,
-            output_fmt: Optional[str] = None,
-            strict_img_size: bool = True,
-            dynamic_img_pad: bool = False,
-    ):
-        super().__init__()
-        assert isinstance(backbone, nn.Module)
-        img_size = to_2tuple(img_size)
-        patch_size = to_2tuple(patch_size)
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.backbone = backbone
-        if feature_size is None:
-            with torch.no_grad():
-                # NOTE Most reliable way of determining output dims is to run forward pass
-                training = backbone.training
-                if training:
-                    backbone.eval()
-                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
-                if isinstance(o, (list, tuple)):
-                    o = o[-1]  # last feature if backbone outputs list/tuple of features
-                feature_size = o.shape[-2:]
-                feature_dim = o.shape[1]
-                backbone.train(training)
-            feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)])
-        else:
-
-            feature_size = to_2tuple(feature_size)
-            feature_ratio = to_2tuple(feature_ratio or 16)
-            if hasattr(self.backbone, 'feature_info'):
-                feature_dim = self.backbone.feature_info.channels()[-1]
-            else:
-                feature_dim = self.backbone.num_features
-        if not dynamic_img_pad:
-            assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
-        self.feature_size = feature_size
-        self.feature_ratio = feature_ratio
-        self.grid_size = tuple([f // p for f, p in zip(self.feature_size, self.patch_size)])
-        self.num_patches = self.grid_size[0] * self.grid_size[1]
-        if output_fmt is not None:
-            self.flatten = False
-            self.output_fmt = Format(output_fmt)
-        else:
-            # flatten spatial dim and transpose to channels last, kept for bwd compat
-            self.flatten = flatten
-            self.output_fmt = Format.NCHW
-        self.strict_img_size = strict_img_size
-        self.dynamic_img_pad = dynamic_img_pad
-
-        if proj:
-            self.proj = nn.Conv2d(
-                feature_dim,
-                embed_dim,
-                kernel_size=patch_size,
-                stride=patch_size,
-                bias=bias,
-            )
-        else:
-            assert feature_dim == embed_dim,\
-                f'The feature dim ({feature_dim} must match embed dim ({embed_dim}) when projection disabled.'
-            self.proj = nn.Identity()
-
-    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
-        total_reduction = (
-            self.feature_ratio[0] * self.patch_size[0],
-            self.feature_ratio[1] * self.patch_size[1]
-        )
-        if as_scalar:
-            return max(total_reduction)
-        else:
-            return total_reduction
-
-    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
-        """ Get feature grid size taking account dynamic padding and backbone network feat reduction
-        """
-        feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1])
-        if self.dynamic_img_pad:
-            return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1])
-        else:
-            return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]
-
-    @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        if hasattr(self.backbone, 'set_grad_checkpointing'):
-            self.backbone.set_grad_checkpointing(enable=enable)
-        elif hasattr(self.backbone, 'grad_checkpointing'):
-            self.backbone.grad_checkpointing = enable
-
-    def forward(self, x):
-        x = self.backbone(x)
-        if isinstance(x, (list, tuple)):
-            x = x[-1]  # last feature if backbone outputs list/tuple of features
-        _, _, H, W = x.shape
-        if self.dynamic_img_pad:
-            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
-            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
-            x = F.pad(x, (0, pad_w, 0, pad_h))
-        x = self.proj(x)
-        if self.flatten:
-            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
-        elif self.output_fmt != Format.NCHW:
-            x = nchw_to(x, self.output_fmt)
-        return x
-
-
-class HybridEmbedWithSize(nn.Module):
-    """ CNN Feature Map Embedding
-    Extract feature map from CNN, flatten, project to embedding dim.
-    """
-    def __init__(
-            self,
-            backbone: nn.Module,
-            img_size: Union[int, Tuple[int, int]] = 224,
-            patch_size: Union[int, Tuple[int, int]] = 1,
-            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
-            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
-            in_chans: int = 3,
-            embed_dim: int = 768,
-            bias=True,
-            proj=True,
-    ):
-        super().__init__(
-            backbone=backbone,
-            img_size=img_size,
-            patch_size=patch_size,
-            feature_size=feature_size,
-            feature_ratio=feature_ratio,
-            in_chans=in_chans,
-            embed_dim=embed_dim,
-            bias=bias,
-            proj=proj,
-        )
-
-    @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        if hasattr(self.backbone, 'set_grad_checkpointing'):
-            self.backbone.set_grad_checkpointing(enable=enable)
-        elif hasattr(self.backbone, 'grad_checkpointing'):
-            self.backbone.grad_checkpointing = enable
-
-    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
-        x = self.backbone(x)
-        if isinstance(x, (list, tuple)):
-            x = x[-1]  # last feature if backbone outputs list/tuple of features
-        x = self.proj(x)
-        return x.flatten(2).transpose(1, 2), x.shape[-2:]
-
-
 class ConvStem(nn.Sequential):
     def __init__(
             self,
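The deletion block above removes the old in-module copy of HybridEmbed / HybridEmbedWithSize. A small migration sketch, not part of the diff, assuming the timm.layers export added at the top of this commit; downstream code would simply switch its import.

# Old location (deleted by this commit):
#   from timm.models.vision_transformer_hybrid import HybridEmbed
# New location:
from timm.layers import HybridEmbed, HybridEmbedWithSize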
@@ -29,12 +29,11 @@ import torch.nn as nn

 from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \
-    make_divisible, DropPath
+    make_divisible, DropPath, HybridEmbed
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
 from .vision_transformer import VisionTransformer, checkpoint_filter_fn
-from .vision_transformer_hybrid import HybridEmbed


 @dataclass