【Feature】Fix the resolution problem for the CLIP vision transformer part and the Swin Transformer (#3001)
* fix the resolution problem for the CLIP vision transformer part and the Swin Transformer
* adjust function name
* integrate the padding helpers into one function
* support non-224 input resolutions
* fix the CLIP patch embedding resolution problem
* fix conflict: remove the conflicting checkpoint function
* Revert "fix conflict"
This reverts commit d7a7dade71
.
* fix conflict: remove the check-resolution function
pull/3018/head
parent c446df9b69
commit e1a7840816
ppcls/arch/backbone/legendary_models/
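All of the hunks below follow the same pattern: pad the input (or feature map) up to the next multiple of the patch or window size, run the layer on the padded tensor, then crop back to the original size and carry the true (H, W) forward explicitly. A minimal, standalone sketch of that pattern (variable names and numbers are illustrative, not the PR's exact helpers):

```python
import paddle
import paddle.nn.functional as F

x = paddle.randn([1, 3, 250, 250])            # a non-224 input in BCHW layout
patch = 4
pad_h = (patch - x.shape[2] % patch) % patch  # 2
pad_w = (patch - x.shape[3] % patch) % patch  # 2

# Per-dimension (begin, end) pads for a 4-D tensor: B, C, H, W.
x_pad = F.pad(x, [0, 0, 0, 0, 0, pad_h, 0, pad_w])
print(x_pad.shape)                            # [1, 3, 252, 252]

# ... run patch embedding / windowed attention on the padded tensor ...

x_out = x_pad[:, :, :250, :250]               # crop back to the real resolution
```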
@@ -14,3 +14,4 @@ nohup.out
.DS_Store
.idea
inference/
test.py
@@ -46,6 +46,11 @@ __all__ = list(MODEL_URLS.keys())
# https://gitee.com/ascend/pytorch/blob/master/torch_npu/contrib/function/roll.py


def masked_fill(x, mask, value):
    y = paddle.full(x.shape, value, x.dtype)
    return paddle.where(mask, y, x)


class RollWithIndexSelect(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, input1, index_fp, index_bp):
@@ -134,6 +139,30 @@ class Mlp(nn.Layer):
        return x


def pading_for_not_divisible(pixel_values,
                             height,
                             width,
                             patch_size,
                             format="BCHW",
                             function="split"):
    if isinstance(patch_size, int):
        patch_size = (patch_size, patch_size)
    if function == "split":
        pading_width = patch_size[1] - width % patch_size[1]
        pading_height = patch_size[0] - height % patch_size[0]
    elif function == "merge":
        pading_width = width % 2
        pading_height = height % 2
    if format == "BCHW":
        pad_index = (0, 0, 0, 0, 0, pading_height, 0, pading_width)
    elif format == "BHWC":
        pad_index = (0, 0, 0, pading_height, 0, pading_width, 0, 0)
    else:
        assert ("vaild format")

    return F.pad(pixel_values, pad_index), pad_index


def window_partition(x, window_size):
    """
    Args:
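Note that, as written, `pading_for_not_divisible` pads a full extra patch when the size is already divisible (`patch_size - size % patch_size` equals `patch_size` in that case), and the `assert ("vaild format")` branch can never fire because a non-empty string is always truthy. A hedged, cleaned-up sketch of the same helper with those two points addressed (an illustration, not the code that was merged):

```python
import paddle.nn.functional as F


def pad_to_divisible(pixel_values, height, width, patch_size, format="BCHW"):
    """Pad H and W up to the next multiple of patch_size; no-op when already divisible."""
    if isinstance(patch_size, int):
        patch_size = (patch_size, patch_size)
    pad_h = (patch_size[0] - height % patch_size[0]) % patch_size[0]
    pad_w = (patch_size[1] - width % patch_size[1]) % patch_size[1]
    if format == "BCHW":
        pad_index = (0, 0, 0, 0, 0, pad_h, 0, pad_w)
    elif format == "BHWC":
        pad_index = (0, 0, 0, pad_h, 0, pad_w, 0, 0)
    else:
        raise ValueError("unsupported format: " + format)
    return F.pad(pixel_values, pad_index), pad_index
```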
@@ -360,7 +389,6 @@ class SwinTransformerBlock(nn.Layer):
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.check_condition()
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
@@ -378,52 +406,50 @@ class SwinTransformerBlock(nn.Layer):
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = paddle.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(
                img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.reshape(
                [-1, self.window_size * self.window_size])
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)

            huns = -100.0 * paddle.ones_like(attn_mask)
            attn_mask = huns * (attn_mask != 0).astype("float32")
        else:
            attn_mask = None
        H, W = self.input_resolution
        attn_mask = paddle.zeros([1, H, W, 1])

        self.register_buffer("attn_mask", attn_mask)

    def check_condition(self):
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

    def get_attn_mask(self, height, width, dtype):
        if self.shift_size > 0:
            # calculate attention mask for shifted window multihead self attention
            img_mask = paddle.zeros((1, height, width, 1), dtype=dtype)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None), )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None), )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

    def forward(self, x):
        H, W = self.input_resolution
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.reshape(
                (-1, self.window_size * self.window_size))
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = masked_fill(attn_mask, attn_mask != 0, float(-100.0))
            attn_mask = masked_fill(attn_mask, attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask

    def forward(self, x, input_dimensions):
        H, W = input_dimensions
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.reshape([B, H, W, C])

        x, pad_values = pading_for_not_divisible(x, H, W, self.window_size,
                                                 "BHWC")
        _, height_pad, width_pad, _ = x.shape
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = RollWrapper.roll(
@@ -439,14 +465,15 @@ class SwinTransformerBlock(nn.Layer):
            C])  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_mask = self.get_attn_mask(height_pad, width_pad, x.dtype)
        attn_windows = self.attn(
            x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
            x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.reshape(
            [-1, self.window_size, self.window_size, C])
        shifted_x = window_reverse(attn_windows, self.window_size, H, W,
                                   C)  # B H' W' C
        shifted_x = window_reverse(attn_windows, self.window_size, height_pad,
                                   width_pad, C)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
@@ -456,6 +483,10 @@ class SwinTransformerBlock(nn.Layer):
                axis=(1, 2))
        else:
            x = shifted_x

        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            x = x[:, :H, :W, :]
        x = x.reshape([B, H * W, C])

        # FFN
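Because the feature map can now be padded per block, the attention mask is no longer a fixed buffer: `get_attn_mask` rebuilds it for the padded size, and the output is cropped back to (H, W) before the residual connection. A rough shape walk-through under assumed numbers (a 65x60 token grid and window size 7; the modulo-guarded formula is used here for clarity):

```python
H, W, window_size = 65, 60, 7
pad_h = (window_size - H % window_size) % window_size   # 5
pad_w = (window_size - W % window_size) % window_size   # 3
height_pad, width_pad = H + pad_h, W + pad_w            # 70, 63
num_windows = (height_pad // window_size) * (width_pad // window_size)  # 10 * 9 = 90
print(height_pad, width_pad, num_windows)
# After window_reverse and the reverse shift, x[:, :H, :W, :] restores the 65 x 60 grid.
```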
@@ -500,28 +531,25 @@ class PatchMerging(nn.Layer):
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
    def forward(self, x, input_dimensions):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        H, W = input_dimensions
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, "x size ({}*{}) are not even.".format(
            H, W)
        x = x.reshape((B, H, W, C))
        x, _ = pading_for_not_divisible(x, H, W, 2, "BHWC", function="merge")

        # x = x.reshape([B, H, W, C])
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = paddle.concat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C

        # x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        # x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        # x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        # x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        # x = paddle.concat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        # x = x.reshape([B, H // 2, 2, W // 2, 2, C])
        # x = x.transpose((0, 1, 3, 4, 2, 5))

        x = x.reshape([B, H // 2, 2, W // 2, 2, C])
        x = x.transpose((0, 1, 3, 4, 2, 5))

        x = x.reshape([B, H * W // 4, 4 * C])  # B H/2*W/2 4*C
        x = x.reshape([B, -1, 4 * C])  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)
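PatchMerging now pads odd-sized grids to even before gathering the four 2x2 sub-grids and concatenating them along channels. A small self-contained illustration (local names, arbitrary shapes):

```python
import paddle
import paddle.nn.functional as F

B, H, W, C = 1, 5, 7, 8
x = paddle.randn([B, H, W, C])

# Pad H and W to even, mirroring the function="merge" branch above (BHWC per-dim pads).
x = F.pad(x, [0, 0, 0, H % 2, 0, W % 2, 0, 0])           # -> [1, 6, 8, 8]

x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
merged = paddle.concat([x0, x1, x2, x3], axis=-1)
print(merged.shape)   # [1, 3, 4, 32]; reshape to [B, -1, 4 * C] before the Linear reduction
```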
@@ -606,12 +634,14 @@ class BasicLayer(nn.Layer):
        else:
            self.downsample = None

    def forward(self, x):
    def forward(self, x, input_dimensions):
        H, W = input_dimensions
        for blk in self.blocks:
            x = blk(x)
            x = blk(x, input_dimensions)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
            H, W = (H + 1) // 2, (W + 1) // 2
            x = self.downsample(x, input_dimensions)
        return x, (H, W)

    def extra_repr(self):
        return "dim={}, input_resolution={}, depth={}".format(
@@ -666,14 +696,14 @@ class PatchEmbed(nn.Layer):

    def forward(self, x):
        B, C, H, W = x.shape
        # TODO (littletomatodonkey), uncomment the line will cause failure of jit.save
        # assert [H, W] == self.img_size[:2], "Input image size ({H}*{W}) doesn't match model ({}*{}).".format(H, W, self.img_size[0], self.img_size[1])
        x, _ = pading_for_not_divisible(x, H, W, self.patch_size, "BCHW")
        x = self.proj(x)

        _, _, height, width = x.shape
        output_dimensions = (height, width)
        x = x.flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x
        return x, output_dimensions

    def flops(self):
        Ho, Wo = self.patches_resolution
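With PatchEmbed returning both the tokens and the post-padding grid size, every downstream stage receives the real spatial dimensions instead of relying on a fixed `input_resolution`. A generic sketch of that contract (hypothetical names, not the classes above):

```python
def run_backbone(patch_embed, stages, x):
    tokens, dims = patch_embed(x)           # dims = (H', W') after padding and projection
    for stage in stages:
        tokens, dims = stage(tokens, dims)  # each stage returns the grid size for the next
    return tokens, dims
```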
@@ -804,13 +834,13 @@ class SwinTransformer(TheseusLayer):
            ones_(m.weight)

    def forward_features(self, x):
        x = self.patch_embed(x)
        x, output_dimensions = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)
            x, output_dimensions = layer(x, output_dimensions)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose([0, 2, 1]))  # B C 1
@@ -992,4 +1022,4 @@ def SwinTransformer_large_patch4_window12_384(
        use_ssld=use_ssld,
        use_imagenet22k_pretrained=use_imagenet22k_pretrained,
        use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
    return model
    return model
@@ -55,6 +55,81 @@ MODEL_URLS = {
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/MAE_vit_base_patch16.pdparams",
}


def resize_pos_embed(pos_embed,
                     src_shape,
                     dst_shape,
                     mode='bicubic',
                     num_extra_tokens=1):
    """Resize pos_embed weights.

    Args:
        pos_embed (torch.Tensor): Position embedding weights with shape
            [1, L, C].
        src_shape (tuple): The resolution of downsampled origin training
            image, in format (H, W).
        dst_shape (tuple): The resolution of downsampled new training
            image, in format (H, W).
        mode (str): Algorithm used for upsampling. Choose one from 'nearest',
            'linear', 'bilinear', 'bicubic' and 'trilinear'.
            Defaults to 'bicubic'.
        num_extra_tokens (int): The number of extra tokens, such as cls_token.
            Defaults to 1.

    Returns:
        torch.Tensor: The resized pos_embed of shape [1, L_new, C]
    """
    if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
        return pos_embed
    assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
    _, L, C = pos_embed.shape
    src_h, src_w = src_shape
    assert L == src_h * src_w + num_extra_tokens, \
        f"The length of `pos_embed` ({L}) doesn't match the expected " \
        f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \
        '`img_size` argument.'
    extra_tokens = pos_embed[:, :num_extra_tokens]

    src_weight = pos_embed[:, num_extra_tokens:]
    src_weight = src_weight.reshape([-1, src_h, src_w, C]).transpose(
        [0, 3, 1, 2])

    # The cubic interpolate algorithm only accepts float32
    dst_weight = paddle.nn.functional.interpolate(
        paddle.cast(src_weight, paddle.float32),
        size=dst_shape,
        align_corners=False,
        mode=mode)
    dst_weight = paddle.flatten(dst_weight, 2).transpose([0, 2, 1])
    dst_weight = paddle.cast(dst_weight, src_weight.dtype)

    return paddle.concat((extra_tokens, dst_weight), axis=1)


def pading_for_not_divisible(pixel_values,
                             height,
                             width,
                             patch_size,
                             format="BCHW",
                             function="split"):
    if isinstance(patch_size, int):
        patch_size = (patch_size, patch_size)
    if function == "split":
        pading_width = patch_size[1] - width % patch_size[1]
        pading_height = patch_size[0] - height % patch_size[0]
    elif function == "merge":
        pading_width = width % 2
        pading_height = height % 2
    if format == "BCHW":
        pad_index = (0, 0, 0, 0, 0, pading_height, 0, pading_width)
    elif format == "BHWC":
        pad_index = (0, 0, 0, pading_height, 0, pading_width, 0, 0)
    else:
        assert ("vaild format")

    return paddle.nn.functional.pad(pixel_values, pad_index), pad_index


__all__ = list(MODEL_URLS.keys())

_model_size = None
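A hedged usage sketch for `resize_pos_embed` as defined above, assuming ViT-B/16 shapes: a position embedding trained on a 14x14 grid (224 input) interpolated to a 16x16 grid (256 input):

```python
import paddle

pos_embed = paddle.randn([1, 14 * 14 + 1, 768])   # 196 patch tokens plus one cls token
resized = resize_pos_embed(
    pos_embed, src_shape=(14, 14), dst_shape=(16, 16), num_extra_tokens=1)
print(resized.shape)                              # [1, 257, 768]
```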
@@ -501,11 +576,13 @@ class PatchEmbed(nn.Layer):

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x, _ = pading_for_not_divisible(x, H, W, patch_size=self.patch_size)

        x = self.proj(x).flatten(2).transpose((0, 2, 1))
        return x
        x = self.proj(x)
        _, _, H, W = x.shape

        x = x.flatten(2).transpose((0, 2, 1))
        return x, (H, W)


class Head(nn.Layer):
@@ -678,14 +755,15 @@ class VisionTransformer(nn.Layer):

    def forward_features(self, x):
        # B = x.shape[0]
        B = paddle.shape(x)[0]
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        x, output_dimensions = self.patch_embed(x)
        if not _model_size in _model_diff['remove_cls_token']:
            cls_tokens = self.cls_token.expand((B, -1, -1))
            x = paddle.concat((cls_tokens, x), axis=1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
            x = x + resize_pos_embed(self.pos_embed, self.window_size,
                                     output_dimensions)

        x = self.ln_pre(x)
        x = self.pos_drop(x)
@@ -50,6 +50,11 @@ MODEL_URLS = {
__all__ = list(MODEL_URLS.keys())


def masked_fill(x, mask, value):
    y = paddle.full(x.shape, value, x.dtype)
    return paddle.where(mask, y, x)


class Mlp(nn.Layer):
    def __init__(self,
                 in_features,
@@ -79,6 +84,30 @@ def masked_fill(x, mask, value):
    return paddle.where(mask, y, x)


def pading_for_not_divisible(pixel_values,
                             height,
                             width,
                             patch_size,
                             format="BCHW",
                             function="split"):
    if isinstance(patch_size, int):
        patch_size = (patch_size, patch_size)
    if function == "split":
        pading_width = patch_size[1] - width % patch_size[1]
        pading_height = patch_size[0] - height % patch_size[0]
    elif function == "merge":
        pading_width = width % 2
        pading_height = height % 2
    if format == "BCHW":
        pad_index = (0, 0, 0, 0, 0, pading_height, 0, pading_width)
    elif format == "BHWC":
        pad_index = (0, 0, 0, pading_height, 0, pading_width, 0, 0)
    else:
        assert ("vaild format")

    return F.pad(pixel_values, pad_index), pad_index


def window_partition(x, window_size):
    """
    Args:
@@ -96,6 +125,23 @@ def window_partition(x, window_size):
    return windows


def pad_patch(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.reshape(
        [B, H // window_size, window_size, W // window_size, window_size, C])
    windows = x.transpose(perm=[0, 1, 3, 2, 4, 5]).reshape(
        [-1, window_size, window_size, C])
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
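Once the feature map is padded to a multiple of the window size, partitioning and reversing windows is an exact round trip. A self-contained check that mirrors `window_partition` / `window_reverse` above (the two helpers are re-defined locally for illustration):

```python
import paddle


def partition(x, ws):
    B, H, W, C = x.shape
    x = x.reshape([B, H // ws, ws, W // ws, ws, C])
    return x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, ws, ws, C])


def reverse(windows, ws, H, W):
    B = windows.shape[0] // ((H // ws) * (W // ws))
    x = windows.reshape([B, H // ws, W // ws, ws, ws, -1])
    return x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1])


x = paddle.randn([2, 14, 21, 8])   # already padded to multiples of window size 7
wins = partition(x, 7)             # [12, 7, 7, 8] = [2 * (14 // 7) * (21 // 7), 7, 7, 8]
y = reverse(wins, 7, 14, 21)
print(bool((x == y).all()))        # True: only reshapes and transposes are involved
```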
@@ -345,7 +391,7 @@ class SwinTransformerBlock(nn.Layer):
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop)

        """
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
@@ -370,17 +416,54 @@ class SwinTransformerBlock(nn.Layer):
            attn_mask = masked_fill(attn_mask, attn_mask != 0, float(-100.0))
            attn_mask = masked_fill(attn_mask, attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        """
        H, W = self.input_resolution
        attn_mask = paddle.zeros([1, H, W, 1])

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
    def get_attn_mask(self, height, width, dtype):
        if self.shift_size > 0:
            # calculate attention mask for shifted window multihead self attention
            img_mask = paddle.zeros((1, height, width, 1), dtype=dtype)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None), )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None), )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.reshape(
                (-1, self.window_size * self.window_size))
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = masked_fill(attn_mask, attn_mask != 0, float(-100.0))
            attn_mask = masked_fill(attn_mask, attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask

    def forward(self, x, input_dimensions):

        H, W = input_dimensions

        # token format
        B, L, C = x.shape
        shortcut = x

        x = x.reshape([B, H, W, C])
        # feature format

        x, pad_values = pading_for_not_divisible(x, H, W, self.window_size,
                                                 "BHWC")
        _, height_pad, width_pad, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
@@ -397,14 +480,15 @@ class SwinTransformerBlock(nn.Layer):
            C])  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_mask = self.get_attn_mask(height_pad, width_pad, x.dtype)
        attn_windows = self.attn(
            x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
            x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.reshape(
            [-1, self.window_size, self.window_size, C])
        shifted_x = window_reverse(attn_windows, self.window_size, H,
                                   W)  # B H' W' C
        shifted_x = window_reverse(attn_windows, self.window_size, height_pad,
                                   width_pad)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
@@ -414,12 +498,16 @@ class SwinTransformerBlock(nn.Layer):
                axis=(1, 2))
        else:
            x = shifted_x

        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            x = x[:, :H, :W, :]

        x = x.reshape([B, H * W, C])
        x = shortcut + self.drop_path(self.norm1(x))

        # FFN
        x = x + self.drop_path(self.norm2(self.mlp(x)))

        return x

    def extra_repr(self):
@@ -457,19 +545,32 @@ class PatchMerging(nn.Layer):
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, x):
    def forward(self, x, input_dimensions):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        H, W = input_dimensions
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.reshape([B, H // 2, 2, W // 2, 2, C])
        x = x.transpose((0, 1, 3, 4, 2, 5))
        x = x.reshape([B, H * W // 4, 4 * C])  # B H/2*W/2 4*C
        x = self.reduction(x)
        x = x.reshape((B, H, W, C))
        x, _ = pading_for_not_divisible(x, H, W, 2, "BHWC", function="merge")

        # [batch_size, height/2, width/2, num_channels]
        input_feature_0 = x[:, 0::2, 0::2, :]
        # [batch_size, height/2, width/2, num_channels]
        input_feature_1 = x[:, 1::2, 0::2, :]
        # [batch_size, height/2, width/2, num_channels]
        input_feature_2 = x[:, 0::2, 1::2, :]
        # [batch_size, height/2, width/2, num_channels]
        input_feature_3 = x[:, 1::2, 1::2, :]

        # [batch_size, height/2 * width/2, 4*num_channels]
        input_feature = paddle.concat([
            input_feature_0, input_feature_1, input_feature_2, input_feature_3
        ], -1)
        input_feature = input_feature.reshape(
            (B, -1, 4 * C))  # [batch_size, height/2 * width/2, 4*C]
        x = self.reduction(input_feature)
        x = self.norm(x)
        return x

@@ -548,12 +649,15 @@ class BasicLayer(nn.Layer):
        else:
            self.downsample = None

    def forward(self, x):
    def forward(self, x, input_dimensions):
        H, W = input_dimensions
        for blk in self.blocks:
            x = blk(x)
            x = blk(x, input_dimensions)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
            H, W = (H + 1) // 2, (W + 1) // 2
            x = self.downsample(x, input_dimensions)

        return x, (H, W)

    def extra_repr(self):
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
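BasicLayer halves the grid with ceiling division, so odd grid sizes survive each downsampling stage. A tiny sketch of how the dimensions propagate (the starting 63x63 grid assumes a 250x250 input padded to 252 with patch size 4):

```python
def stage_dims(h, w, num_stages):
    dims = [(h, w)]
    for _ in range(num_stages - 1):
        h, w = (h + 1) // 2, (w + 1) // 2   # mirrors the (H + 1) // 2 update above
        dims.append((h, w))
    return dims


print(stage_dims(63, 63, 4))   # [(63, 63), (32, 32), (16, 16), (8, 8)]
```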
@@ -607,13 +711,15 @@ class PatchEmbed(nn.Layer):

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose([0, 2, 1])  # B Ph*Pw C

        x, _ = pading_for_not_divisible(x, H, W, self.patch_size, "BCHW")
        x = self.proj(x)
        _, _, height, width = x.shape
        output_dimensions = (height, width)
        x = x.flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x
        return x, output_dimensions

    def flops(self):
        Ho, Wo = self.patches_resolution
@@ -674,6 +780,7 @@ class SwinTransformerV2(nn.Layer):
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.img_size = img_size
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2**(self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
@@ -740,13 +847,13 @@ class SwinTransformerV2(nn.Layer):
            ones_(m.weight)

    def forward_features(self, x):
        x = self.patch_embed(x)
        x, output_dimensions = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)
            x, output_dimensions = layer(x, input_dimensions=output_dimensions)

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose([0, 2, 1]))  # B C 1