mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

commit 392b78aee7 (parent f920119f3b)

set_input_size initial impl for vit & swin v1. Move HybridEmbed to own location in timm/layers
@@ -27,6 +27,7 @@ from .gather_excite import GatherExcite
 from .global_context import GlobalContext
 from .grid import ndgrid, meshgrid
 from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
+from .hybrid_embed import HybridEmbed, HybridEmbedWithSize
 from .inplace_abn import InplaceAbn
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d
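With the re-export above, HybridEmbed and HybridEmbedWithSize become part of the public timm.layers namespace rather than living inside the vision_transformer_hybrid model module. A minimal sketch of the new import path (the old path, removed later in this diff, is shown for comparison):

    # new location introduced by this commit
    from timm.layers import HybridEmbed, HybridEmbedWithSize

    # previous location, removed by this commit
    # from timm.models.vision_transformer_hybrid import HybridEmbed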
timm/layers/hybrid_embed.py (new file, 253 lines)
@@ -0,0 +1,253 @@
""" Image to Patch Hybrid Embedding Layer

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn as nn
import torch.nn.functional as F

from .format import Format, nchw_to
from .helpers import to_2tuple
from .patch_embed import resample_patch_embed


_logger = logging.getLogger(__name__)


class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """
    output_fmt: Format
    dynamic_img_pad: torch.jit.Final[bool]

    def __init__(
            self,
            backbone: nn.Module,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            in_chans: int = 3,
            embed_dim: int = 768,
            bias: bool = True,
            proj: bool = True,
            flatten: bool = True,
            output_fmt: Optional[str] = None,
            strict_img_size: bool = True,
            dynamic_img_pad: bool = False,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        self.backbone = backbone
        self.in_chans = in_chans
        (
            self.img_size,
            self.patch_size,
            self.feature_size,
            self.feature_ratio,
            self.feature_dim,
            self.grid_size,
            self.num_patches,
        ) = self._init_backbone(
            img_size=img_size,
            patch_size=patch_size,
            feature_size=feature_size,
            feature_ratio=feature_ratio,
        )

        if output_fmt is not None:
            self.flatten = False
            self.output_fmt = Format(output_fmt)
        else:
            # flatten spatial dim and transpose to channels last, kept for bwd compat
            self.flatten = flatten
            self.output_fmt = Format.NCHW
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad
        if not dynamic_img_pad:
            assert self.feature_size[0] % self.patch_size[0] == 0 and self.feature_size[1] % self.patch_size[1] == 0

        if proj:
            self.proj = nn.Conv2d(
                self.feature_dim,
                embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
                bias=bias,
            )
        else:
            assert self.feature_dim == embed_dim, \
                f'The feature dim ({self.feature_dim}) must match embed dim ({embed_dim}) when projection disabled.'
            self.proj = nn.Identity()

    def _init_backbone(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            feature_dim: Optional[int] = None,
    ):
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        if feature_size is None:
            with torch.no_grad():
                # NOTE Most reliable way of determining output dims is to run forward pass
                training = self.backbone.training
                if training:
                    self.backbone.eval()
                o = self.backbone(torch.zeros(1, self.in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                self.backbone.train(training)
            feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)])
        else:
            feature_size = to_2tuple(feature_size)
            feature_ratio = to_2tuple(feature_ratio or 16)
            if feature_dim is None:
                if hasattr(self.backbone, 'feature_info'):
                    feature_dim = self.backbone.feature_info.channels()[-1]
                else:
                    feature_dim = self.backbone.num_features
        grid_size = tuple([f // p for f, p in zip(feature_size, patch_size)])
        num_patches = grid_size[0] * grid_size[1]
        return img_size, patch_size, feature_size, feature_ratio, feature_dim, grid_size, num_patches

    def set_input_size(
            self,
            img_size: Optional[Union[int, Tuple[int, int]]] = None,
            patch_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            feature_dim: Optional[int] = None,
    ):
        assert img_size is not None or patch_size is not None
        img_size = img_size or self.img_size
        new_patch_size = None
        if patch_size is not None:
            new_patch_size = to_2tuple(patch_size)
        if new_patch_size is not None and new_patch_size != self.patch_size:
            assert isinstance(self.proj, nn.Conv2d), 'HybridEmbed must have a projection layer to change patch size.'
            with torch.no_grad():
                new_proj = nn.Conv2d(
                    self.proj.in_channels,
                    self.proj.out_channels,
                    kernel_size=new_patch_size,
                    stride=new_patch_size,
                    bias=self.proj.bias is not None,
                )
                new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
                if self.proj.bias is not None:
                    new_proj.bias.copy_(self.proj.bias)
            self.proj = new_proj
            patch_size = new_patch_size
        patch_size = patch_size or self.patch_size

        if img_size != self.img_size or patch_size != self.patch_size:
            (
                self.img_size,
                self.patch_size,
                self.feature_size,
                self.feature_ratio,
                self.feature_dim,
                self.grid_size,
                self.num_patches,
            ) = self._init_backbone(
                img_size=img_size,
                patch_size=patch_size,
                feature_size=feature_size,
                feature_ratio=feature_ratio,
                feature_dim=feature_dim,
            )

    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
        total_reduction = (
            self.feature_ratio[0] * self.patch_size[0],
            self.feature_ratio[1] * self.patch_size[1]
        )
        if as_scalar:
            return max(total_reduction)
        else:
            return total_reduction

    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
        """ Get feature grid size taking account dynamic padding and backbone network feat reduction
        """
        feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1])
        if self.dynamic_img_pad:
            return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1])
        else:
            return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        if hasattr(self.backbone, 'set_grad_checkpointing'):
            self.backbone.set_grad_checkpointing(enable=enable)
        elif hasattr(self.backbone, 'grad_checkpointing'):
            self.backbone.grad_checkpointing = enable

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        _, _, H, W = x.shape
        if self.dynamic_img_pad:
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, pad_h))
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        elif self.output_fmt != Format.NCHW:
            x = nchw_to(x, self.output_fmt)
        return x


class HybridEmbedWithSize(HybridEmbed):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """
    def __init__(
            self,
            backbone: nn.Module,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            in_chans: int = 3,
            embed_dim: int = 768,
            bias=True,
            proj=True,
    ):
        super().__init__(
            backbone=backbone,
            img_size=img_size,
            patch_size=patch_size,
            feature_size=feature_size,
            feature_ratio=feature_ratio,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=bias,
            proj=proj,
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        if hasattr(self.backbone, 'set_grad_checkpointing'):
            self.backbone.set_grad_checkpointing(enable=enable)
        elif hasattr(self.backbone, 'grad_checkpointing'):
            self.backbone.grad_checkpointing = enable

    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x)
        return x.flatten(2).transpose(1, 2), x.shape[-2:]
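As defined above, HybridEmbed wraps a CNN backbone, probes it with a zero tensor when feature_size is not given to discover the output feature shape and channel count, then projects that feature map to the embedding dim. A usage sketch, assuming a timm ResNet created with features_only=True as the backbone (the backbone name and sizes are illustrative, not part of this commit):

    import torch
    import timm
    from timm.layers import HybridEmbed

    # CNN backbone returning a list of feature maps; HybridEmbed uses the last one
    backbone = timm.create_model('resnet26d', features_only=True, pretrained=False)

    # feature_size / feature_dim are inferred by a dummy forward pass inside _init_backbone
    embed = HybridEmbed(backbone, img_size=224, patch_size=1, embed_dim=768)

    tokens = embed(torch.randn(1, 3, 224, 224))
    print(embed.grid_size, embed.num_patches, tokens.shape)   # (7, 7) 49 torch.Size([1, 49, 768])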
@@ -44,14 +44,7 @@ class PatchEmbed(nn.Module):
     ):
         super().__init__()
         self.patch_size = to_2tuple(patch_size)
-        if img_size is not None:
-            self.img_size = to_2tuple(img_size)
-            self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
-            self.num_patches = self.grid_size[0] * self.grid_size[1]
-        else:
-            self.img_size = None
-            self.grid_size = None
-            self.num_patches = None
+        self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)

         if output_fmt is not None:
             self.flatten = False
@@ -66,6 +59,41 @@ class PatchEmbed(nn.Module):
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

+    def _init_img_size(self, img_size: Union[int, Tuple[int, int]]):
+        assert self.patch_size
+        if img_size is None:
+            return None, None, None
+        img_size = to_2tuple(img_size)
+        grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
+        num_patches = grid_size[0] * grid_size[1]
+        return img_size, grid_size, num_patches
+
+    def set_input_size(
+            self,
+            img_size: Optional[Union[int, Tuple[int, int]]] = None,
+            patch_size: Optional[Union[int, Tuple[int, int]]] = None,
+    ):
+        new_patch_size = None
+        if patch_size is not None:
+            new_patch_size = to_2tuple(patch_size)
+        if new_patch_size is not None and new_patch_size != self.patch_size:
+            with torch.no_grad():
+                new_proj = nn.Conv2d(
+                    self.proj.in_channels,
+                    self.proj.out_channels,
+                    kernel_size=new_patch_size,
+                    stride=new_patch_size,
+                    bias=self.proj.bias is not None,
+                )
+                new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
+                if self.proj.bias is not None:
+                    new_proj.bias.copy_(self.proj.bias)
+            self.proj = new_proj
+            self.patch_size = new_patch_size
+        img_size = img_size or self.img_size
+        if img_size != self.img_size or new_patch_size is not None:
+            self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
+
     def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
         if as_scalar:
             return max(self.patch_size)
@@ -180,13 +208,9 @@ def resample_patch_embed(
     """
     import numpy as np
     try:
-        import functorch
-        vmap = functorch.vmap
+        from torch import vmap
     except ImportError:
-        if hasattr(torch, 'vmap'):
-            vmap = torch.vmap
-        else:
-            assert False, "functorch or a version of torch with vmap is required for FlexiViT resizing."
+        from functorch import vmap

     assert len(patch_embed.shape) == 4, "Four dimensions expected"
     assert len(new_size) == 2, "New shape should only be hw"
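PatchEmbed.set_input_size recomputes img_size/grid_size/num_patches and, when the patch size actually changes, rebuilds the projection conv with weights resampled from the old kernel via resample_patch_embed (the FlexiViT-style resize that the vmap import above supports). A small sketch of calling it directly on a PatchEmbed module (sizes are illustrative):

    import torch
    from timm.layers import PatchEmbed

    embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=384)
    print(embed.grid_size, embed.num_patches)    # (14, 14) 196

    # switch to a larger image and a smaller patch in place;
    # the 16x16 projection kernel is resampled to 12x12
    embed.set_input_size(img_size=288, patch_size=12)
    print(embed.grid_size, embed.num_patches)    # (24, 24) 576

    tokens = embed(torch.randn(1, 3, 288, 288))  # (1, 576, 384)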
@@ -27,11 +27,10 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm
+from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm, HybridEmbed
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module
 from ._registry import register_model, generate_default_cfgs
-from .vision_transformer_hybrid import HybridEmbed


 __all__ = ['ConVit']
@@ -140,6 +140,27 @@ class WindowAttention(nn.Module):
         trunc_normal_(self.relative_position_bias_table, std=.02)
         self.softmax = nn.Softmax(dim=-1)

+    def set_window_size(self, window_size: Tuple[int, int]) -> None:
+        """Update window size & interpolate position embeddings
+        Args:
+            window_size (int): New window size
+        """
+        window_size = to_2tuple(window_size)
+        if window_size == self.window_size:
+            return
+        self.window_size = window_size
+        win_h, win_w = self.window_size
+        self.window_area = win_h * win_w
+        with torch.no_grad():
+            new_bias_shape = (2 * win_h - 1) * (2 * win_w - 1), self.num_heads
+            self.relative_position_bias_table = nn.Parameter(
+                resize_rel_pos_bias_table(
+                    self.relative_position_bias_table,
+                    new_window_size=self.window_size,
+                    new_bias_shape=new_bias_shape,
+                ))
+        self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False)
+
     def _get_rel_pos_bias(self) -> torch.Tensor:
         relative_position_bias = self.relative_position_bias_table[
             self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1)  # Wh*Ww,Wh*Ww,nH
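set_window_size has to resize the relative position bias table because its first dimension depends on the window: a (win_h, win_w) window needs (2 * win_h - 1) * (2 * win_w - 1) relative offsets per head. Growing the window from 7x7 to 12x12, for example, takes the table from 13 * 13 = 169 rows to 23 * 23 = 529 rows, which is why resize_rel_pos_bias_table interpolates the old table before the relative_position_index buffer is rebuilt for the new window.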
@@ -197,6 +218,7 @@ class SwinTransformerBlock(nn.Module):
             head_dim: Optional[int] = None,
             window_size: _int_or_tuple_2_t = 7,
             shift_size: int = 0,
+            always_partition: bool = False,
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             proj_drop: float = 0.,
@@ -224,9 +246,9 @@ class SwinTransformerBlock(nn.Module):
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
-        ws, ss = self._calc_window_shift(window_size, shift_size)
-        self.window_size: Tuple[int, int] = ws
-        self.shift_size: Tuple[int, int] = ss
+        self.target_shift_size = to_2tuple(shift_size)
+        self.always_partition = always_partition
+        self.window_size, self.shift_size = self._calc_window_shift(window_size, target_shift_size=shift_size)
         self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio

@@ -251,6 +273,9 @@ class SwinTransformerBlock(nn.Module):
         )
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

+        self._make_attention_mask()
+
+    def _make_attention_mask(self):
         if any(self.shift_size):
             # calculate attention mask for SW-MSA
             H, W = self.input_resolution
@@ -274,16 +299,47 @@ class SwinTransformerBlock(nn.Module):
             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
         else:
             attn_mask = None

         self.register_buffer("attn_mask", attn_mask, persistent=False)

-    def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+    def _calc_window_shift(
+            self,
+            target_window_size: Union[int, Tuple[int, int]],
+            target_shift_size: Optional[Union[int, Tuple[int, int]]] = None,
+    ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         target_window_size = to_2tuple(target_window_size)
-        target_shift_size = to_2tuple(target_shift_size)
+        if target_shift_size is None:
+            # if passed value is None, recalculate from default window_size // 2 if it was active
+            target_shift_size = self.target_shift_size
+            if any(target_shift_size):
+                target_shift_size = [target_window_size[0] // 2, target_window_size[1] // 2]
+        else:
+            target_shift_size = to_2tuple(target_shift_size)
+
+        if self.always_partition:
+            return target_window_size, target_shift_size
+
         window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
         shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
         return tuple(window_size), tuple(shift_size)

+    def set_input_size(
+            self,
+            feat_size: Tuple[int, int],
+            window_size: Tuple[int, int],
+            always_partition: Optional[bool] = None,
+    ):
+        """
+        Args:
+            feat_size: New input resolution
+            window_size: New window size
+            always_partition: Change always_partition attribute if not None
+        """
+        self.input_resolution = feat_size
+        if always_partition is not None:
+            self.always_partition = always_partition
+        self.window_size, self.shift_size = self._calc_window_shift(window_size)
+        self.window_area = self.window_size[0] * self.window_size[1]
+        self.attn.set_window_size(self.window_size)
+        self._make_attention_mask()
+
     def _attn(self, x):
         B, H, W, C = x.shape

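The clamping in _calc_window_shift is what the new always_partition flag bypasses: with an input_resolution of (14, 14) and a requested 16x16 window, the window is clamped to 14x14 and the shift drops to 0, while always_partition=True keeps the requested window and shift regardless of resolution, which matters once set_input_size starts changing the resolution after construction.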
@@ -374,6 +430,7 @@ class SwinTransformerStage(nn.Module):
             num_heads: int = 4,
             head_dim: Optional[int] = None,
             window_size: _int_or_tuple_2_t = 7,
+            always_partition: bool = False,
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             proj_drop: float = 0.,
@@ -427,6 +484,7 @@ class SwinTransformerStage(nn.Module):
                 head_dim=head_dim,
                 window_size=window_size,
                 shift_size=0 if (i % 2 == 0) else shift_size,
+                always_partition=always_partition,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop,
@@ -436,6 +494,30 @@ class SwinTransformerStage(nn.Module):
             )
             for i in range(depth)])

+    def set_input_size(
+            self,
+            feat_size: Tuple[int, int],
+            window_size: int,
+            always_partition: Optional[bool] = None,
+    ):
+        """Method updates the resolution to utilize and the window size and so the pair-wise relative positions.
+
+        Args:
+            feat_size (Tuple[int, int]): New input resolution
+            window_size (int): New window size
+        """
+        self.input_resolution = feat_size
+        if isinstance(self.downsample, nn.Identity):
+            self.output_resolution = feat_size
+        else:
+            self.output_resolution = tuple(i // 2 for i in feat_size)
+        for block in self.blocks:
+            block.set_input_size(
+                feat_size=self.output_resolution,
+                window_size=window_size,
+                always_partition=always_partition,
+            )
+
     def forward(self, x):
         x = self.downsample(x)

@@ -465,6 +547,7 @@ class SwinTransformer(nn.Module):
             num_heads: Tuple[int, ...] = (3, 6, 12, 24),
             head_dim: Optional[int] = None,
             window_size: _int_or_tuple_2_t = 7,
+            always_partition: bool = False,
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             drop_rate: float = 0.,
@@ -546,6 +629,7 @@ class SwinTransformer(nn.Module):
                 num_heads=num_heads[i],
                 head_dim=head_dim[i],
                 window_size=window_size[i],
+                always_partition=always_partition,
                 mlp_ratio=mlp_ratio[i],
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop_rate,
@@ -556,7 +640,7 @@ class SwinTransformer(nn.Module):
             in_dim = out_dim
             if i > 0:
                 scale *= 2
-            self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')]
+            self.feature_info += [dict(num_chs=out_dim, reduction=patch_size * scale, module=f'layers.{i}')]
         self.layers = nn.Sequential(*layers)

         self.norm = norm_layer(self.num_features)
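The feature_info change in the last hunk is a small correctness fix: the per-stage reduction was hard-coded as 4 * scale, which only matches the default patch_size of 4. With patch_size=4 the reported reductions are unchanged (4, 8, 16, 32 across the four stages); with patch_size=8 they now correctly become 8, 16, 32, 64.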
@@ -584,6 +668,36 @@ class SwinTransformer(nn.Module):
                 nwd.add(n)
         return nwd

+    def set_input_size(
+            self,
+            img_size: Optional[Tuple[int, int]] = None,
+            patch_size: Optional[Tuple[int, int]] = None,
+            window_size: Optional[Tuple[int, int]] = None,
+            window_ratio: int = 32,
+            always_partition: Optional[bool] = None,
+    ) -> None:
+        """ Updates the image resolution and window size.
+
+        Args:
+            img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
+            window_size (Optional[int]): New window size, if None based on new_img_size // window_div
+            window_ratio (int): divisor for calculating window size from image size
+        """
+        if img_size is not None or patch_size is not None:
+            self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
+            self.patch_grid = self.patch_embed.grid_size
+        if window_size is None:
+            img_size = self.patch_embed.img_size
+            window_size = tuple([s // window_ratio for s in img_size])
+        for index, stage in enumerate(self.layers):
+            stage_scale = 2 ** max(index - 1, 0)
+            print(self.patch_grid, stage_scale)
+            stage.set_input_size(
+                feat_size=(self.patch_grid[0] // stage_scale, self.patch_grid[1] // stage_scale),
+                window_size=window_size,
+                always_partition=always_partition,
+            )
+
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         return dict(
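Taken together, the model-level set_input_size above updates the patch embed, derives a window size from window_ratio when none is given, and pushes the new feature size and window down through every stage. A usage sketch, assuming a standard Swin config from timm (model name and sizes are illustrative):

    import torch
    import timm

    model = timm.create_model('swin_base_patch4_window7_224', pretrained=False)

    # retarget to 256x256 inputs; the window becomes 256 // 32 = 8 per side
    model.set_input_size(img_size=(256, 256), window_ratio=32)

    out = model(torch.randn(1, 3, 256, 256))   # (1, num_classes)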
@@ -119,7 +119,7 @@ class WindowMultiHeadAttention(nn.Module):
         assert dim % num_heads == 0, \
             "The number of input features (in_features) are not divisible by the number of heads (num_heads)."
         self.in_features: int = dim
-        self.window_size: Tuple[int, int] = window_size
+        self.window_size: Tuple[int, int] = to_2tuple(window_size)
         self.num_heads: int = num_heads
         self.sequential_attn: bool = sequential_attn

@@ -152,15 +152,14 @@ class WindowMultiHeadAttention(nn.Module):
             1.0 + relative_coordinates.abs())
         self.register_buffer("relative_coordinates_log", relative_coordinates_log, persistent=False)

-    def update_input_size(self, new_window_size: int, **kwargs: Any) -> None:
-        """Method updates the window size and so the pair-wise relative positions
+    def set_window_size(self, window_size: Tuple[int, int]) -> None:
+        """Update window size & interpolate position embeddings

         Args:
-            new_window_size (int): New window size
-            kwargs (Any): Unused
+            window_size (int): New window size
         """
-        # Set new window size and new pair-wise relative positions
-        self.window_size: int = new_window_size
-        self._make_pair_wise_relative_positions()
+        window_size = to_2tuple(window_size)
+        if window_size != self.window_size:
+            self.window_size = window_size
+            self._make_pair_wise_relative_positions()

     def _relative_positional_encodings(self) -> torch.Tensor:
@@ -321,18 +320,18 @@ class SwinTransformerV2CrBlock(nn.Module):
         nn.init.constant_(self.norm1.weight, self.init_values)
         nn.init.constant_(self.norm2.weight, self.init_values)

-    def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None:
+    def set_input_size(self, feat_size: Tuple[int, int], window_size: Tuple[int, int]) -> None:
         """Method updates the image resolution to be processed and window size and so the pair-wise relative positions.

         Args:
-            new_window_size (int): New window size
-            new_feat_size (Tuple[int, int]): New input resolution
+            feat_size (Tuple[int, int]): New input resolution
+            window_size (int): New window size
         """
         # Update input resolution
-        self.feat_size: Tuple[int, int] = new_feat_size
-        self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(new_window_size))
+        self.feat_size: Tuple[int, int] = feat_size
+        self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
         self.window_area = self.window_size[0] * self.window_size[1]
-        self.attn.update_input_size(new_window_size=self.window_size)
+        self.attn.set_window_size(self.window_size)
         self._make_attention_mask()

     def _shifted_window_attn(self, x):
@@ -510,18 +509,18 @@ class SwinTransformerV2CrStage(nn.Module):
             for index in range(depth)]
         )

-    def update_input_size(self, new_window_size: int, new_feat_size: Tuple[int, int]) -> None:
+    def set_input_size(self, feat_size: Tuple[int, int], window_size: int) -> None:
         """Method updates the resolution to utilize and the window size and so the pair-wise relative positions.

         Args:
-            new_window_size (int): New window size
-            new_feat_size (Tuple[int, int]): New input resolution
+            window_size (int): New window size
+            feat_size (Tuple[int, int]): New input resolution
         """
         self.feat_size: Tuple[int, int] = (
-            (new_feat_size[0] // 2, new_feat_size[1] // 2) if self.downscale else new_feat_size
+            (feat_size[0] // 2, feat_size[1] // 2) if self.downscale else feat_size
         )
         for block in self.blocks:
-            block.update_input_size(new_window_size=new_window_size, new_feat_size=self.feat_size)
+            block.set_input_size(feat_size=self.feat_size, window_size=window_size)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass.
@@ -657,33 +656,32 @@ class SwinTransformerV2Cr(nn.Module):
         if weight_init != 'skip':
             named_apply(init_weights, self)

-    def update_input_size(
+    def set_input_size(
             self,
-            new_img_size: Optional[Tuple[int, int]] = None,
-            new_window_size: Optional[int] = None,
-            img_window_ratio: int = 32,
+            img_size: Optional[Tuple[int, int]] = None,
+            window_size: Optional[Tuple[int, int]] = None,
+            window_ratio: int = 32,
     ) -> None:
         """Method updates the image resolution to be processed and window size and so the pair-wise relative positions.

         Args:
-            new_window_size (Optional[int]): New window size, if None based on new_img_size // window_div
-            new_img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
-            img_window_ratio (int): divisor for calculating window size from image size
+            img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
+            window_size (Optional[int]): New window size, if None based on new_img_size // window_div
+            window_ratio (int): divisor for calculating window size from image size
         """
-        # Check parameters
-        if new_img_size is None:
-            new_img_size = self.img_size
+        if img_size is None:
+            img_size = self.img_size
         else:
-            new_img_size = to_2tuple(new_img_size)
-        if new_window_size is None:
-            new_window_size = tuple([s // img_window_ratio for s in new_img_size])
+            img_size = to_2tuple(img_size)
+        if window_size is None:
+            window_size = tuple([s // window_ratio for s in img_size])
         # Compute new patch resolution & update resolution of each stage
-        new_patch_grid_size = (new_img_size[0] // self.patch_size, new_img_size[1] // self.patch_size)
+        patch_grid_size = (img_size[0] // self.patch_size, img_size[1] // self.patch_size)
         for index, stage in enumerate(self.stages):
             stage_scale = 2 ** max(index - 1, 0)
-            stage.update_input_size(
-                new_window_size=new_window_size,
-                new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),
+            stage.set_input_size(
+                feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale),
+                window_size=window_size,
             )

     @torch.jit.ignore
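For the CR variant the change is mostly a rename plus to_2tuple hardening: update_input_size becomes set_input_size with img_size/window_size/window_ratio arguments, so v1 and v2-cr now share the same resizing entry point. A hedged sketch of the renamed call (model name illustrative):

    import timm

    model = timm.create_model('swinv2_cr_tiny_ns_224', pretrained=False)

    # previously: model.update_input_size(new_img_size=(384, 384), img_window_ratio=32)
    model.set_input_size(img_size=(384, 384), window_ratio=32)   # stages now use 12x12 windows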
@@ -632,6 +632,31 @@ class VisionTransformer(nn.Module):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

+    def set_input_size(
+            self,
+            img_size: Optional[Tuple[int, int]] = None,
+            patch_size: Optional[Tuple[int, int]] = None,
+    ):
+        """Method updates the input image resolution, patch size
+
+        Args:
+            img_size: New input resolution, if None current resolution is used
+            patch_size: New patch size, if None existing patch size is used
+        """
+        prev_grid_size = self.patch_embed.grid_size
+        self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
+        if self.pos_embed is not None:
+            num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
+            num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens
+            if num_new_tokens != self.pos_embed.shape[1]:
+                self.pos_embed = nn.Parameter(resample_abs_pos_embed(
+                    self.pos_embed,
+                    new_size=self.patch_embed.grid_size,
+                    old_size=prev_grid_size,
+                    num_prefix_tokens=num_prefix_tokens,
+                    verbose=True,
+                ))
+
     def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
         if self.pos_embed is None:
             return x.view(x.shape[0], -1, x.shape[-1])
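On the ViT side, set_input_size delegates to the patch embed and then resamples the learned position embedding to the new grid with resample_abs_pos_embed, keeping prefix (class/register) tokens intact. A usage sketch, assuming a standard ViT config (model name and sizes are illustrative):

    import torch
    import timm

    model = timm.create_model('vit_base_patch16_224', pretrained=False)

    # higher resolution, same 16x16 patches: pos_embed is resampled from a 14x14 to a 24x24 grid
    model.set_input_size(img_size=(384, 384))
    out = model(torch.randn(1, 3, 384, 384))

    # the patch size can change too, FlexiViT style: the 16x16 kernels are resampled to 12x12
    model.set_input_size(patch_size=(12, 12))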
@@ -19,10 +19,9 @@ from typing import Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to
+from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, HybridEmbed

 from ._builder import build_model_with_cfg
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -31,172 +30,6 @@ from .resnetv2 import ResNetV2, create_resnetv2_stem
 from .vision_transformer import VisionTransformer


-class HybridEmbed(nn.Module):
-    """ CNN Feature Map Embedding
-    Extract feature map from CNN, flatten, project to embedding dim.
-    """
-    output_fmt: Format
-    dynamic_img_pad: torch.jit.Final[bool]
-
-    def __init__(
-            self,
-            backbone: nn.Module,
-            img_size: Union[int, Tuple[int, int]] = 224,
-            patch_size: Union[int, Tuple[int, int]] = 1,
-            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
-            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
-            in_chans: int = 3,
-            embed_dim: int = 768,
-            bias: bool = True,
-            proj: bool = True,
-            flatten: bool = True,
-            output_fmt: Optional[str] = None,
-            strict_img_size: bool = True,
-            dynamic_img_pad: bool = False,
-    ):
-        super().__init__()
-        assert isinstance(backbone, nn.Module)
-        img_size = to_2tuple(img_size)
-        patch_size = to_2tuple(patch_size)
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.backbone = backbone
-        if feature_size is None:
-            with torch.no_grad():
-                # NOTE Most reliable way of determining output dims is to run forward pass
-                training = backbone.training
-                if training:
-                    backbone.eval()
-                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
-                if isinstance(o, (list, tuple)):
-                    o = o[-1]  # last feature if backbone outputs list/tuple of features
-                feature_size = o.shape[-2:]
-                feature_dim = o.shape[1]
-                backbone.train(training)
-            feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)])
-        else:
-            feature_size = to_2tuple(feature_size)
-            feature_ratio = to_2tuple(feature_ratio or 16)
-            if hasattr(self.backbone, 'feature_info'):
-                feature_dim = self.backbone.feature_info.channels()[-1]
-            else:
-                feature_dim = self.backbone.num_features
-        if not dynamic_img_pad:
-            assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
-        self.feature_size = feature_size
-        self.feature_ratio = feature_ratio
-        self.grid_size = tuple([f // p for f, p in zip(self.feature_size, self.patch_size)])
-        self.num_patches = self.grid_size[0] * self.grid_size[1]
-        if output_fmt is not None:
-            self.flatten = False
-            self.output_fmt = Format(output_fmt)
-        else:
-            # flatten spatial dim and transpose to channels last, kept for bwd compat
-            self.flatten = flatten
-            self.output_fmt = Format.NCHW
-        self.strict_img_size = strict_img_size
-        self.dynamic_img_pad = dynamic_img_pad
-
-        if proj:
-            self.proj = nn.Conv2d(
-                feature_dim,
-                embed_dim,
-                kernel_size=patch_size,
-                stride=patch_size,
-                bias=bias,
-            )
-        else:
-            assert feature_dim == embed_dim,\
-                f'The feature dim ({feature_dim} must match embed dim ({embed_dim}) when projection disabled.'
-            self.proj = nn.Identity()
-
-    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
-        total_reduction = (
-            self.feature_ratio[0] * self.patch_size[0],
-            self.feature_ratio[1] * self.patch_size[1]
-        )
-        if as_scalar:
-            return max(total_reduction)
-        else:
-            return total_reduction
-
-    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
-        """ Get feature grid size taking account dynamic padding and backbone network feat reduction
-        """
-        feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1])
-        if self.dynamic_img_pad:
-            return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1])
-        else:
-            return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]
-
-    @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        if hasattr(self.backbone, 'set_grad_checkpointing'):
-            self.backbone.set_grad_checkpointing(enable=enable)
-        elif hasattr(self.backbone, 'grad_checkpointing'):
-            self.backbone.grad_checkpointing = enable
-
-    def forward(self, x):
-        x = self.backbone(x)
-        if isinstance(x, (list, tuple)):
-            x = x[-1]  # last feature if backbone outputs list/tuple of features
-        _, _, H, W = x.shape
-        if self.dynamic_img_pad:
-            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
-            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
-            x = F.pad(x, (0, pad_w, 0, pad_h))
-        x = self.proj(x)
-        if self.flatten:
-            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
-        elif self.output_fmt != Format.NCHW:
-            x = nchw_to(x, self.output_fmt)
-        return x
-
-
-class HybridEmbedWithSize(nn.Module):
-    """ CNN Feature Map Embedding
-    Extract feature map from CNN, flatten, project to embedding dim.
-    """
-    def __init__(
-            self,
-            backbone: nn.Module,
-            img_size: Union[int, Tuple[int, int]] = 224,
-            patch_size: Union[int, Tuple[int, int]] = 1,
-            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
-            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
-            in_chans: int = 3,
-            embed_dim: int = 768,
-            bias=True,
-            proj=True,
-    ):
-        super().__init__(
-            backbone=backbone,
-            img_size=img_size,
-            patch_size=patch_size,
-            feature_size=feature_size,
-            feature_ratio=feature_ratio,
-            in_chans=in_chans,
-            embed_dim=embed_dim,
-            bias=bias,
-            proj=proj,
-        )
-
-    @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        if hasattr(self.backbone, 'set_grad_checkpointing'):
-            self.backbone.set_grad_checkpointing(enable=enable)
-        elif hasattr(self.backbone, 'grad_checkpointing'):
-            self.backbone.grad_checkpointing = enable
-
-    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
-        x = self.backbone(x)
-        if isinstance(x, (list, tuple)):
-            x = x[-1]  # last feature if backbone outputs list/tuple of features
-        x = self.proj(x)
-        return x.flatten(2).transpose(1, 2), x.shape[-2:]
-
-
 class ConvStem(nn.Sequential):
     def __init__(
             self,
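With the class gone from vision_transformer_hybrid.py, hybrid ViT configs keep working through the relocated layer, and because HybridEmbed now implements set_input_size they can be resized through the same VisionTransformer entry point added above. A hedged sketch, assuming one of timm's ResNet+ViT hybrid configs (model name and sizes are illustrative):

    import torch
    import timm

    # R26+S/32 hybrid: a ResNet stem feeds the ViT, so patch_embed is a HybridEmbed
    model = timm.create_model('vit_small_r26_s32_224', pretrained=False)

    model.set_input_size(img_size=(320, 320))   # backbone re-probed, pos_embed resampled
    out = model(torch.randn(1, 3, 320, 320))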
@@ -29,12 +29,11 @@ import torch.nn as nn

 from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import create_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, \
-    make_divisible, DropPath
+    make_divisible, DropPath, HybridEmbed
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
 from .vision_transformer import VisionTransformer, checkpoint_filter_fn
-from .vision_transformer_hybrid import HybridEmbed


 @dataclass