【Feature】Fix the resolution problem for the CLIP vision transformer part and Swin … ()

* fix the resolution problem for the CLIP vision transformer part and the Swin transformer

* adjust the function name

* integrate the padding functions into one

* support non-224 resolutions

* fix the CLIP patch embedding resolution problem

* fix conflict

remove the conflict checkpoint function

* Revert "fix conflict"

This reverts commit d7a7dade71.

* fix conflict

remove the resolution-check function
pull/3018/head
sky 2023-10-18 20:55:37 +08:00 committed by GitHub
parent c446df9b69
commit e1a7840816
4 changed files with 318 additions and 102 deletions
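
The core idea of the change, sketched below in a minimal, hedged form (the helper name pad_to_multiple and the 250x250 input are invented for illustration and are not part of this commit): before patch embedding and window partitioning, the feature map is padded on the bottom/right so that its height and width become divisible by the patch or window size, and the padding is cropped away again after the windowed attention.

import paddle
import paddle.nn.functional as F

def pad_to_multiple(x, multiple):
    """Pad a BCHW tensor on the bottom/right so H and W become multiples of `multiple`."""
    _, _, h, w = x.shape
    pad_h = (-h) % multiple  # 0 when already divisible
    pad_w = (-w) % multiple
    # a 4-element pad list pads the last two axes of an NCHW tensor: (left, right, top, bottom)
    return F.pad(x, [0, pad_w, 0, pad_h]), (pad_h, pad_w)

x = paddle.randn([1, 3, 250, 250])                # a non-224, non-divisible input
x_padded, (pad_h, pad_w) = pad_to_multiple(x, 4)  # patch size 4
print(x_padded.shape)                             # [1, 3, 252, 252]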

.gitignore

@@ -14,3 +14,4 @@ nohup.out
.DS_Store
.idea
inference/
test.py

swin_transformer.py

@@ -46,6 +46,11 @@ __all__ = list(MODEL_URLS.keys())
# https://gitee.com/ascend/pytorch/blob/master/torch_npu/contrib/function/roll.py
def masked_fill(x, mask, value):
y = paddle.full(x.shape, value, x.dtype)
return paddle.where(mask, y, x)
class RollWithIndexSelect(paddle.autograd.PyLayer):
@staticmethod
def forward(ctx, input1, index_fp, index_bp):
@@ -134,6 +139,30 @@ class Mlp(nn.Layer):
return x
def pading_for_not_divisible(pixel_values,
height,
width,
patch_size,
format="BCHW",
function="split"):
if isinstance(patch_size, int):
patch_size = (patch_size, patch_size)
if function == "split":
        # pad only when the size is not already divisible
        pading_width = (patch_size[1] - width % patch_size[1]) % patch_size[1]
        pading_height = (patch_size[0] - height % patch_size[0]) % patch_size[0]
elif function == "merge":
pading_width = width % 2
pading_height = height % 2
if format == "BCHW":
pad_index = (0, 0, 0, 0, 0, pading_height, 0, pading_width)
elif format == "BHWC":
pad_index = (0, 0, 0, pading_height, 0, pading_width, 0, 0)
else:
assert ("vaild format")
return F.pad(pixel_values, pad_index), pad_index
def window_partition(x, window_size):
"""
Args:
@@ -360,7 +389,6 @@ class SwinTransformerBlock(nn.Layer):
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.check_condition()
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
@@ -378,52 +406,50 @@ class SwinTransformerBlock(nn.Layer):
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = paddle.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(
img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.reshape(
[-1, self.window_size * self.window_size])
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
huns = -100.0 * paddle.ones_like(attn_mask)
attn_mask = huns * (attn_mask != 0).astype("float32")
else:
attn_mask = None
H, W = self.input_resolution
attn_mask = paddle.zeros([1, H, W, 1])
self.register_buffer("attn_mask", attn_mask)
def check_condition(self):
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
def get_attn_mask(self, height, width, dtype):
if self.shift_size > 0:
# calculate attention mask for shifted window multihead self attention
img_mask = paddle.zeros((1, height, width, 1), dtype=dtype)
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None), )
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None), )
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
img_mask[:, height_slice, width_slice, :] = count
count += 1
def forward(self, x):
H, W = self.input_resolution
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.reshape(
(-1, self.window_size * self.window_size))
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = masked_fill(attn_mask, attn_mask != 0, float(-100.0))
attn_mask = masked_fill(attn_mask, attn_mask == 0, float(0.0))
else:
attn_mask = None
return attn_mask
def forward(self, x, input_dimensions):
H, W = input_dimensions
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.reshape([B, H, W, C])
x, pad_values = pading_for_not_divisible(x, H, W, self.window_size,
"BHWC")
_, height_pad, width_pad, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = RollWrapper.roll(
@@ -439,14 +465,15 @@ class SwinTransformerBlock(nn.Layer):
C]) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_mask = self.get_attn_mask(height_pad, width_pad, x.dtype)
attn_windows = self.attn(
x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.reshape(
[-1, self.window_size, self.window_size, C])
shifted_x = window_reverse(attn_windows, self.window_size, H, W,
C) # B H' W' C
shifted_x = window_reverse(attn_windows, self.window_size, height_pad,
width_pad, C) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
@@ -456,6 +483,10 @@ class SwinTransformerBlock(nn.Layer):
axis=(1, 2))
else:
x = shifted_x
was_padded = pad_values[3] > 0 or pad_values[5] > 0
if was_padded:
x = x[:, :H, :W, :]
x = x.reshape([B, H * W, C])
# FFN
@@ -500,28 +531,25 @@ class PatchMerging(nn.Layer):
self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
def forward(self, x, input_dimensions):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
H, W = input_dimensions
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, "x size ({}*{}) are not even.".format(
H, W)
x = x.reshape((B, H, W, C))
x, _ = pading_for_not_divisible(x, H, W, 2, "BHWC", function="merge")
# x = x.reshape([B, H, W, C])
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
# x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
# x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
# x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
# x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
# x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
# x = x.reshape([B, H // 2, 2, W // 2, 2, C])
# x = x.transpose((0, 1, 3, 4, 2, 5))
x = x.reshape([B, H // 2, 2, W // 2, 2, C])
x = x.transpose((0, 1, 3, 4, 2, 5))
x = x.reshape([B, H * W // 4, 4 * C]) # B H/2*W/2 4*C
x = x.reshape([B, -1, 4 * C]) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
@@ -606,12 +634,14 @@ class BasicLayer(nn.Layer):
else:
self.downsample = None
def forward(self, x):
def forward(self, x, input_dimensions):
H, W = input_dimensions
for blk in self.blocks:
x = blk(x)
x = blk(x, input_dimensions)
if self.downsample is not None:
x = self.downsample(x)
return x
H, W = (H + 1) // 2, (W + 1) // 2
x = self.downsample(x, input_dimensions)
return x, (H, W)
def extra_repr(self):
return "dim={}, input_resolution={}, depth={}".format(
@@ -666,14 +696,14 @@ class PatchEmbed(nn.Layer):
def forward(self, x):
B, C, H, W = x.shape
# TODO (littletomatodonkey), uncomment the line will cause failure of jit.save
# assert [H, W] == self.img_size[:2], "Input image size ({H}*{W}) doesn't match model ({}*{}).".format(H, W, self.img_size[0], self.img_size[1])
x, _ = pading_for_not_divisible(x, H, W, self.patch_size, "BCHW")
x = self.proj(x)
_, _, height, width = x.shape
output_dimensions = (height, width)
x = x.flatten(2).transpose([0, 2, 1]) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
return x, output_dimensions
def flops(self):
Ho, Wo = self.patches_resolution
@@ -804,13 +834,13 @@ class SwinTransformer(TheseusLayer):
ones_(m.weight)
def forward_features(self, x):
x = self.patch_embed(x)
x, output_dimensions = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x, output_dimensions = layer(x, output_dimensions)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose([0, 2, 1])) # B C 1
@@ -992,4 +1022,4 @@ def SwinTransformer_large_patch4_window12_384(
use_ssld=use_ssld,
use_imagenet22k_pretrained=use_imagenet22k_pretrained,
use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
return model
return model
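
After this change the spatial dimensions are threaded explicitly through the network: PatchEmbed returns (tokens, output_dimensions), each BasicLayer consumes and returns the current (H, W), and PatchMerging pads odd sizes before halving. Below is a hypothetical walk-through of how the dimensions evolve for a 250x250 input (patch size 4, three PatchMerging steps), mirroring the (H + 1) // 2 halving used above; the input size is made up for the example.

def stage_dims(h, w, patch_size=4, num_merges=3):
    # ceil division models the pad-then-split patch embedding for non-divisible sizes
    h, w = -(-h // patch_size), -(-w // patch_size)
    dims = [(h, w)]
    for _ in range(num_merges):
        # PatchMerging pads odd feature maps before the 2x2 merge, hence the rounding up
        h, w = (h + 1) // 2, (w + 1) // 2
        dims.append((h, w))
    return dims

print(stage_dims(250, 250))  # [(63, 63), (32, 32), (16, 16), (8, 8)]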

foundation_vit.py

@@ -55,6 +55,81 @@ MODEL_URLS = {
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/MAE_vit_base_patch16.pdparams",
}
def resize_pos_embed(pos_embed,
src_shape,
dst_shape,
mode='bicubic',
num_extra_tokens=1):
"""Resize pos_embed weights.
Args:
        pos_embed (paddle.Tensor): Position embedding weights with shape
[1, L, C].
src_shape (tuple): The resolution of downsampled origin training
image, in format (H, W).
dst_shape (tuple): The resolution of downsampled new training
image, in format (H, W).
mode (str): Algorithm used for upsampling. Choose one from 'nearest',
'linear', 'bilinear', 'bicubic' and 'trilinear'.
Defaults to 'bicubic'.
num_extra_tokens (int): The number of extra tokens, such as cls_token.
Defaults to 1.
Returns:
        paddle.Tensor: The resized pos_embed of shape [1, L_new, C]
"""
if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
return pos_embed
assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
_, L, C = pos_embed.shape
src_h, src_w = src_shape
assert L == src_h * src_w + num_extra_tokens, \
f"The length of `pos_embed` ({L}) doesn't match the expected " \
        f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the ' \
'`img_size` argument.'
extra_tokens = pos_embed[:, :num_extra_tokens]
src_weight = pos_embed[:, num_extra_tokens:]
src_weight = src_weight.reshape([-1, src_h, src_w, C]).transpose(
[0, 3, 1, 2])
# The cubic interpolate algorithm only accepts float32
dst_weight = paddle.nn.functional.interpolate(
paddle.cast(src_weight, paddle.float32),
size=dst_shape,
align_corners=False,
mode=mode)
dst_weight = paddle.flatten(dst_weight, 2).transpose([0, 2, 1])
dst_weight = paddle.cast(dst_weight, src_weight.dtype)
return paddle.concat((extra_tokens, dst_weight), axis=1)
def pading_for_not_divisible(pixel_values,
height,
width,
patch_size,
format="BCHW",
function="split"):
if isinstance(patch_size, int):
patch_size = (patch_size, patch_size)
if function == "split":
        # pad only when the size is not already divisible
        pading_width = (patch_size[1] - width % patch_size[1]) % patch_size[1]
        pading_height = (patch_size[0] - height % patch_size[0]) % patch_size[0]
elif function == "merge":
pading_width = width % 2
pading_height = height % 2
if format == "BCHW":
pad_index = (0, 0, 0, 0, 0, pading_height, 0, pading_width)
elif format == "BHWC":
pad_index = (0, 0, 0, pading_height, 0, pading_width, 0, 0)
else:
assert ("vaild format")
return paddle.nn.functional.pad(pixel_values, pad_index), pad_index
__all__ = list(MODEL_URLS.keys())
_model_size = None
@@ -501,11 +576,13 @@ class PatchEmbed(nn.Layer):
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x, _ = pading_for_not_divisible(x, H, W, patch_size=self.patch_size)
x = self.proj(x).flatten(2).transpose((0, 2, 1))
return x
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose((0, 2, 1))
return x, (H, W)
class Head(nn.Layer):
@@ -678,14 +755,15 @@ class VisionTransformer(nn.Layer):
def forward_features(self, x):
# B = x.shape[0]
B = paddle.shape(x)[0]
x = self.patch_embed(x)
B, C, H, W = x.shape
x, output_dimensions = self.patch_embed(x)
if not _model_size in _model_diff['remove_cls_token']:
cls_tokens = self.cls_token.expand((B, -1, -1))
x = paddle.concat((cls_tokens, x), axis=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = x + resize_pos_embed(self.pos_embed, self.window_size,
output_dimensions)
x = self.ln_pre(x)
x = self.pos_drop(x)
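
For illustration, a hedged usage of the resize_pos_embed helper added above, assuming the function is in scope and assuming a ViT-B/16-style position embedding trained on a 14x14 patch grid with one cls token being adapted to a 16x16 grid (the shapes below are invented for the example):

import paddle

pos_embed = paddle.randn([1, 14 * 14 + 1, 768])             # [1, L, C], one extra cls token
resized = resize_pos_embed(pos_embed, (14, 14), (16, 16))   # bicubic interpolation by default
print(resized.shape)                                        # [1, 16 * 16 + 1, 768]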

swin_transformer_v2.py

@@ -50,6 +50,11 @@ MODEL_URLS = {
__all__ = list(MODEL_URLS.keys())
def masked_fill(x, mask, value):
y = paddle.full(x.shape, value, x.dtype)
return paddle.where(mask, y, x)
class Mlp(nn.Layer):
def __init__(self,
in_features,
@@ -79,6 +84,30 @@ def masked_fill(x, mask, value):
return paddle.where(mask, y, x)
def pading_for_not_divisible(pixel_values,
height,
width,
patch_size,
format="BCHW",
function="split"):
if isinstance(patch_size, int):
patch_size = (patch_size, patch_size)
if function == "split":
        # pad only when the size is not already divisible
        pading_width = (patch_size[1] - width % patch_size[1]) % patch_size[1]
        pading_height = (patch_size[0] - height % patch_size[0]) % patch_size[0]
elif function == "merge":
pading_width = width % 2
pading_height = height % 2
if format == "BCHW":
pad_index = (0, 0, 0, 0, 0, pading_height, 0, pading_width)
elif format == "BHWC":
pad_index = (0, 0, 0, pading_height, 0, pading_width, 0, 0)
else:
assert ("vaild format")
return F.pad(pixel_values, pad_index), pad_index
def window_partition(x, window_size):
"""
Args:
@@ -96,6 +125,23 @@ def window_partition(x, window_size):
return windows
def pad_patch(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.reshape(
[B, H // window_size, window_size, W // window_size, window_size, C])
windows = x.transpose(perm=[0, 1, 3, 2, 4, 5]).reshape(
[-1, window_size, window_size, C])
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
@@ -345,7 +391,7 @@ class SwinTransformerBlock(nn.Layer):
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
"""
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
@@ -370,17 +416,54 @@ class SwinTransformerBlock(nn.Layer):
attn_mask = masked_fill(attn_mask, attn_mask != 0, float(-100.0))
attn_mask = masked_fill(attn_mask, attn_mask == 0, float(0.0))
else:
attn_mask = None
"""
H, W = self.input_resolution
attn_mask = paddle.zeros([1, H, W, 1])
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
def get_attn_mask(self, height, width, dtype):
if self.shift_size > 0:
# calculate attention mask for shifted window multihead self attention
img_mask = paddle.zeros((1, height, width, 1), dtype=dtype)
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None), )
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None), )
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
img_mask[:, height_slice, width_slice, :] = count
count += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.reshape(
(-1, self.window_size * self.window_size))
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = masked_fill(attn_mask, attn_mask != 0, float(-100.0))
attn_mask = masked_fill(attn_mask, attn_mask == 0, float(0.0))
else:
attn_mask = None
return attn_mask
def forward(self, x, input_dimensions):
H, W = input_dimensions
#token format
B, L, C = x.shape
shortcut = x
x = x.reshape([B, H, W, C])
#feature format
x, pad_values = pading_for_not_divisible(x, H, W, self.window_size,
"BHWC")
_, height_pad, width_pad, _ = x.shape
# cyclic shift
if self.shift_size > 0:
@@ -397,14 +480,15 @@ class SwinTransformerBlock(nn.Layer):
C]) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_mask = self.get_attn_mask(height_pad, width_pad, x.dtype)
attn_windows = self.attn(
x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.reshape(
[-1, self.window_size, self.window_size, C])
shifted_x = window_reverse(attn_windows, self.window_size, H,
W) # B H' W' C
shifted_x = window_reverse(attn_windows, self.window_size, height_pad,
width_pad) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
@@ -414,12 +498,16 @@ class SwinTransformerBlock(nn.Layer):
axis=(1, 2))
else:
x = shifted_x
was_padded = pad_values[3] > 0 or pad_values[5] > 0
if was_padded:
x = x[:, :H, :W, :]
x = x.reshape([B, H * W, C])
x = shortcut + self.drop_path(self.norm1(x))
# FFN
x = x + self.drop_path(self.norm2(self.mlp(x)))
return x
def extra_repr(self):
@@ -457,19 +545,32 @@ class PatchMerging(nn.Layer):
self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
self.norm = norm_layer(2 * dim)
def forward(self, x):
def forward(self, x, input_dimensions):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
H, W = input_dimensions
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.reshape([B, H // 2, 2, W // 2, 2, C])
x = x.transpose((0, 1, 3, 4, 2, 5))
x = x.reshape([B, H * W // 4, 4 * C]) # B H/2*W/2 4*C
x = self.reduction(x)
x = x.reshape((B, H, W, C))
x, _ = pading_for_not_divisible(x, H, W, 2, "BHWC", function="merge")
# [batch_size, height/2, width/2, num_channels]
input_feature_0 = x[:, 0::2, 0::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_1 = x[:, 1::2, 0::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_2 = x[:, 0::2, 1::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_3 = x[:, 1::2, 1::2, :]
# [batch_size, height/2 * width/2, 4*num_channels]
input_feature = paddle.concat([
input_feature_0, input_feature_1, input_feature_2, input_feature_3
], -1)
input_feature = input_feature.reshape(
(B, -1, 4 * C)) # [batch_size, height/2 * width/2, 4*C]
x = self.reduction(input_feature)
x = self.norm(x)
return x
@@ -548,12 +649,15 @@ class BasicLayer(nn.Layer):
else:
self.downsample = None
def forward(self, x):
def forward(self, x, input_dimensions):
H, W = input_dimensions
for blk in self.blocks:
x = blk(x)
x = blk(x, input_dimensions)
if self.downsample is not None:
x = self.downsample(x)
return x
H, W = (H + 1) // 2, (W + 1) // 2
x = self.downsample(x, input_dimensions)
return x, (H, W)
def extra_repr(self):
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
@@ -607,13 +711,15 @@ class PatchEmbed(nn.Layer):
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose([0, 2, 1]) # B Ph*Pw C
x, _ = pading_for_not_divisible(x, H, W, self.patch_size, "BCHW")
x = self.proj(x)
_, _, height, width = x.shape
output_dimensions = (height, width)
x = x.flatten(2).transpose([0, 2, 1]) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
return x, output_dimensions
def flops(self):
Ho, Wo = self.patches_resolution
@@ -674,6 +780,7 @@ class SwinTransformerV2(nn.Layer):
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.img_size = img_size
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2**(self.num_layers - 1))
self.mlp_ratio = mlp_ratio
@@ -740,13 +847,13 @@ class SwinTransformerV2(nn.Layer):
ones_(m.weight)
def forward_features(self, x):
x = self.patch_embed(x)
x, output_dimensions = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x, output_dimensions = layer(x, input_dimensions=output_dimensions)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose([0, 2, 1])) # B C 1
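
As a closing, hedged illustration of the "merge" branch of the padding helper: when the feature map entering PatchMerging has an odd height or width, one extra row/column is padded on the bottom/right so that the 2x2 neighborhood gather lines up. The 63x63x96 toy shape below is invented for the example.

import paddle
import paddle.nn.functional as F

x = paddle.randn([1, 63, 63, 96])               # BHWC feature map with odd H and W
pad_h, pad_w = 63 % 2, 63 % 2                   # pad by 1 only when the size is odd
x = F.pad(x, [0, 0, 0, pad_h, 0, pad_w, 0, 0])  # 8-element pad: (before, after) for each BHWC dim
merged = paddle.concat(
    [x[:, 0::2, 0::2, :], x[:, 1::2, 0::2, :],
     x[:, 0::2, 1::2, :], x[:, 1::2, 1::2, :]],
    axis=-1)                                    # gather each 2x2 neighborhood into the channel dim
print(merged.shape)                             # [1, 32, 32, 384]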