diff --git a/.gitignore b/.gitignore
index ea0c4fab2..c0b097fda 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,3 +14,4 @@ nohup.out
 .DS_Store
 .idea
 inference/
+test.py
diff --git a/ppcls/arch/backbone/legendary_models/swin_transformer.py b/ppcls/arch/backbone/legendary_models/swin_transformer.py
index a177462d1..e0586c153 100644
--- a/ppcls/arch/backbone/legendary_models/swin_transformer.py
+++ b/ppcls/arch/backbone/legendary_models/swin_transformer.py
@@ -46,6 +46,11 @@ __all__ = list(MODEL_URLS.keys())
 
 
 # https://gitee.com/ascend/pytorch/blob/master/torch_npu/contrib/function/roll.py
+def masked_fill(x, mask, value):
+    y = paddle.full(x.shape, value, x.dtype)
+    return paddle.where(mask, y, x)
+
+
 class RollWithIndexSelect(paddle.autograd.PyLayer):
     @staticmethod
     def forward(ctx, input1, index_fp, index_bp):
@@ -134,6 +139,30 @@ class Mlp(nn.Layer):
         return x
 
 
+def pading_for_not_divisible(pixel_values,
+                             height,
+                             width,
+                             patch_size,
+                             format="BCHW",
+                             function="split"):
+    if isinstance(patch_size, int):
+        patch_size = (patch_size, patch_size)
+    if function == "split":
+        pading_width = patch_size[1] - width % patch_size[1]
+        pading_height = patch_size[0] - height % patch_size[0]
+    elif function == "merge":
+        pading_width = width % 2
+        pading_height = height % 2
+    if format == "BCHW":
+        pad_index = (0, 0, 0, 0, 0, pading_height, 0, pading_width)
+    elif format == "BHWC":
+        pad_index = (0, 0, 0, pading_height, 0, pading_width, 0, 0)
+    else:
+        assert ("vaild format")
+
+    return F.pad(pixel_values, pad_index), pad_index
+
+
 def window_partition(x, window_size):
     """
     Args:
@@ -360,7 +389,6 @@ class SwinTransformerBlock(nn.Layer):
         self.shift_size = shift_size
         self.mlp_ratio = mlp_ratio
 
-        self.check_condition()
         self.norm1 = norm_layer(dim)
         self.attn = WindowAttention(
             dim,
@@ -378,52 +406,50 @@ class SwinTransformerBlock(nn.Layer):
             hidden_features=mlp_hidden_dim,
             act_layer=act_layer,
             drop=drop)
-
-        if self.shift_size > 0:
-            # calculate attention mask for SW-MSA
-            H, W = self.input_resolution
-            img_mask = paddle.zeros((1, H, W, 1))  # 1 H W 1
-            h_slices = (slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None))
-            w_slices = (slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None))
-            cnt = 0
-            for h in h_slices:
-                for w in w_slices:
-                    img_mask[:, h, w, :] = cnt
-                    cnt += 1
-
-            mask_windows = window_partition(
-                img_mask, self.window_size)  # nW, window_size, window_size, 1
-            mask_windows = mask_windows.reshape(
-                [-1, self.window_size * self.window_size])
-            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
-
-            huns = -100.0 * paddle.ones_like(attn_mask)
-            attn_mask = huns * (attn_mask != 0).astype("float32")
-        else:
-            attn_mask = None
+        H, W = self.input_resolution
+        attn_mask = paddle.zeros([1, H, W, 1])
 
         self.register_buffer("attn_mask", attn_mask)
 
-    def check_condition(self):
-        if min(self.input_resolution) <= self.window_size:
-            # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(self.input_resolution)
-        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+    def get_attn_mask(self, height, width, dtype):
+        if self.shift_size > 0:
+            # calculate attention mask for shifted window multihead self attention
+            img_mask = paddle.zeros((1, height, width, 1), dtype=dtype)
+            height_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None), )
+            width_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None), )
+            count = 0
+            for height_slice in height_slices:
+                for width_slice in width_slices:
+                    img_mask[:, height_slice, width_slice, :] = count
+                    count += 1
 
-    def forward(self, x):
-        H, W = self.input_resolution
+            mask_windows = window_partition(img_mask, self.window_size)
+            mask_windows = mask_windows.reshape(
+                (-1, self.window_size * self.window_size))
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+            attn_mask = masked_fill(attn_mask, attn_mask != 0, float(-100.0))
+            attn_mask = masked_fill(attn_mask, attn_mask == 0, float(0.0))
+        else:
+            attn_mask = None
+        return attn_mask
+
+    def forward(self, x, input_dimensions):
+        H, W = input_dimensions
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
 
         shortcut = x
         x = self.norm1(x)
         x = x.reshape([B, H, W, C])
+        x, pad_values = pading_for_not_divisible(x, H, W, self.window_size,
+                                                 "BHWC")
+        _, height_pad, width_pad, _ = x.shape
 
         # cyclic shift
         if self.shift_size > 0:
             shifted_x = RollWrapper.roll(
@@ -439,14 +465,15 @@ class SwinTransformerBlock(nn.Layer):
                                        C])  # nW*B, window_size*window_size, C
 
         # W-MSA/SW-MSA
+        attn_mask = self.get_attn_mask(height_pad, width_pad, x.dtype)
         attn_windows = self.attn(
-            x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
+            x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
 
         # merge windows
         attn_windows = attn_windows.reshape(
             [-1, self.window_size, self.window_size, C])
-        shifted_x = window_reverse(attn_windows, self.window_size, H, W,
-                                   C)  # B H' W' C
+        shifted_x = window_reverse(attn_windows, self.window_size, height_pad,
+                                   width_pad, C)  # B H' W' C
 
         # reverse cyclic shift
         if self.shift_size > 0:
@@ -456,6 +483,10 @@ class SwinTransformerBlock(nn.Layer):
                 axis=(1, 2))
         else:
             x = shifted_x
+
+        was_padded = pad_values[3] > 0 or pad_values[5] > 0
+        if was_padded:
+            x = x[:, :H, :W, :]
         x = x.reshape([B, H * W, C])
 
         # FFN
@@ -500,28 +531,25 @@ class PatchMerging(nn.Layer):
         self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
         self.norm = norm_layer(4 * dim)
 
-    def forward(self, x):
+    def forward(self, x, input_dimensions):
         """
         x: B, H*W, C
         """
-        H, W = self.input_resolution
+        H, W = input_dimensions
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
-        assert H % 2 == 0 and W % 2 == 0, "x size ({}*{}) are not even.".format(
-            H, W)
+        x = x.reshape((B, H, W, C))
+        x, _ = pading_for_not_divisible(x, H, W, 2, "BHWC", function="merge")
 
-        # x = x.reshape([B, H, W, C])
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = paddle.concat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
 
-        # x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
-        # x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
-        # x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
-        # x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
-        # x = paddle.concat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        # x = x.reshape([B, H // 2, 2, W // 2, 2, C])
+        # x = x.transpose((0, 1, 3, 4, 2, 5))
 
-        x = x.reshape([B, H // 2, 2, W // 2, 2, C])
-        x = x.transpose((0, 1, 3, 4, 2, 5))
-
-        x = x.reshape([B, H * W // 4, 4 * C])  # B H/2*W/2 4*C
+        x = x.reshape([B, -1, 4 * C])  # B H/2*W/2 4*C
 
         x = self.norm(x)
         x = self.reduction(x)
 
@@ -606,12 +634,14 @@ class BasicLayer(nn.Layer):
         else:
             self.downsample = None
 
-    def forward(self, x):
+    def forward(self, x, input_dimensions):
+        H, W = input_dimensions
         for blk in self.blocks:
-            x = blk(x)
+            x = blk(x, input_dimensions)
         if self.downsample is not None:
-            x = self.downsample(x)
-        return x
+            H, W = (H + 1) // 2, (W + 1) // 2
+            x = self.downsample(x, input_dimensions)
+        return x, (H, W)
 
     def extra_repr(self):
         return "dim={}, input_resolution={}, depth={}".format(
@@ -666,14 +696,14 @@ class PatchEmbed(nn.Layer):
 
     def forward(self, x):
         B, C, H, W = x.shape
-        # TODO (littletomatodonkey), uncomment the line will cause failure of jit.save
-        # assert [H, W] == self.img_size[:2], "Input image size ({H}*{W}) doesn't match model ({}*{}).".format(H, W, self.img_size[0], self.img_size[1])
+        x, _ = pading_for_not_divisible(x, H, W, self.patch_size, "BCHW")
         x = self.proj(x)
-
+        _, _, height, width = x.shape
+        output_dimensions = (height, width)
         x = x.flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
         if self.norm is not None:
             x = self.norm(x)
-        return x
+        return x, output_dimensions
 
     def flops(self):
         Ho, Wo = self.patches_resolution
@@ -804,13 +834,13 @@ class SwinTransformer(TheseusLayer):
                 ones_(m.weight)
 
     def forward_features(self, x):
-        x = self.patch_embed(x)
+        x, output_dimensions = self.patch_embed(x)
         if self.ape:
             x = x + self.absolute_pos_embed
         x = self.pos_drop(x)
 
         for layer in self.layers:
-            x = layer(x)
+            x, output_dimensions = layer(x, output_dimensions)
 
         x = self.norm(x)  # B L C
         x = self.avgpool(x.transpose([0, 2, 1]))  # B C 1
@@ -992,4 +1022,4 @@ def SwinTransformer_large_patch4_window12_384(
         use_ssld=use_ssld,
         use_imagenet22k_pretrained=use_imagenet22k_pretrained,
         use_imagenet22kto1k_pretrained=use_imagenet22kto1k_pretrained)
-    return model
\ No newline at end of file
+    return model
diff --git a/ppcls/arch/backbone/model_zoo/foundation_vit.py b/ppcls/arch/backbone/model_zoo/foundation_vit.py
index 588020fe0..15a56bbaf 100644
--- a/ppcls/arch/backbone/model_zoo/foundation_vit.py
+++ b/ppcls/arch/backbone/model_zoo/foundation_vit.py
@@ -55,6 +55,81 @@ MODEL_URLS = {
     "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/MAE_vit_base_patch16.pdparams",
 }
 
+
+def resize_pos_embed(pos_embed,
+                     src_shape,
+                     dst_shape,
+                     mode='bicubic',
+                     num_extra_tokens=1):
+    """Resize pos_embed weights.
+
+    Args:
+        pos_embed (torch.Tensor): Position embedding weights with shape
+            [1, L, C].
+        src_shape (tuple): The resolution of downsampled origin training
+            image, in format (H, W).
+        dst_shape (tuple): The resolution of downsampled new training
+            image, in format (H, W).
+        mode (str): Algorithm used for upsampling. Choose one from 'nearest',
+            'linear', 'bilinear', 'bicubic' and 'trilinear'.
+            Defaults to 'bicubic'.
+        num_extra_tokens (int): The number of extra tokens, such as cls_token.
+            Defaults to 1.
+
+    Returns:
+        torch.Tensor: The resized pos_embed of shape [1, L_new, C]
+    """
+    if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
+        return pos_embed
+    assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
+    _, L, C = pos_embed.shape
+    src_h, src_w = src_shape
+    assert L == src_h * src_w + num_extra_tokens, \
+        f"The length of `pos_embed` ({L}) doesn't match the expected " \
+        f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the' \
+        '`img_size` argument.'
+    extra_tokens = pos_embed[:, :num_extra_tokens]
+
+    src_weight = pos_embed[:, num_extra_tokens:]
+    src_weight = src_weight.reshape([-1, src_h, src_w, C]).transpose(
+        [0, 3, 1, 2])
+
+    # The cubic interpolate algorithm only accepts float32
+    dst_weight = paddle.nn.functional.interpolate(
+        paddle.cast(src_weight, paddle.float32),
+        size=dst_shape,
+        align_corners=False,
+        mode=mode)
+    dst_weight = paddle.flatten(dst_weight, 2).transpose([0, 2, 1])
+    dst_weight = paddle.cast(dst_weight, src_weight.dtype)
+
+    return paddle.concat((extra_tokens, dst_weight), axis=1)
+
+
+def pading_for_not_divisible(pixel_values,
+                             height,
+                             width,
+                             patch_size,
+                             format="BCHW",
+                             function="split"):
+    if isinstance(patch_size, int):
+        patch_size = (patch_size, patch_size)
+    if function == "split":
+        pading_width = patch_size[1] - width % patch_size[1]
+        pading_height = patch_size[0] - height % patch_size[0]
+    elif function == "merge":
+        pading_width = width % 2
+        pading_height = height % 2
+    if format == "BCHW":
+        pad_index = (0, 0, 0, 0, 0, pading_height, 0, pading_width)
+    elif format == "BHWC":
+        pad_index = (0, 0, 0, pading_height, 0, pading_width, 0, 0)
+    else:
+        assert ("vaild format")
+
+    return paddle.nn.functional.pad(pixel_values, pad_index), pad_index
+
+
 __all__ = list(MODEL_URLS.keys())
 
 _model_size = None
@@ -501,11 +576,13 @@ class PatchEmbed(nn.Layer):
 
     def forward(self, x):
         B, C, H, W = x.shape
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x, _ = pading_for_not_divisible(x, H, W, patch_size=self.patch_size)
 
-        x = self.proj(x).flatten(2).transpose((0, 2, 1))
-        return x
+        x = self.proj(x)
+        _, _, H, W = x.shape
+
+        x = x.flatten(2).transpose((0, 2, 1))
+        return x, (H, W)
 
 
 class Head(nn.Layer):
@@ -678,14 +755,15 @@ class VisionTransformer(nn.Layer):
 
     def forward_features(self, x):
         # B = x.shape[0]
-        B = paddle.shape(x)[0]
-        x = self.patch_embed(x)
+        B, C, H, W = x.shape
+        x, output_dimensions = self.patch_embed(x)
         if not _model_size in _model_diff['remove_cls_token']:
             cls_tokens = self.cls_token.expand((B, -1, -1))
             x = paddle.concat((cls_tokens, x), axis=1)
         if self.pos_embed is not None:
-            x = x + self.pos_embed
+            x = x + resize_pos_embed(self.pos_embed, self.window_size,
+                                     output_dimensions)
         x = self.ln_pre(x)
         x = self.pos_drop(x)
 
diff --git a/ppcls/arch/backbone/model_zoo/swin_transformer_v2.py b/ppcls/arch/backbone/model_zoo/swin_transformer_v2.py
index c5180ca10..e51b97044 100644
--- a/ppcls/arch/backbone/model_zoo/swin_transformer_v2.py
+++ b/ppcls/arch/backbone/model_zoo/swin_transformer_v2.py
@@ -50,6 +50,11 @@ MODEL_URLS = {
 __all__ = list(MODEL_URLS.keys())
 
 
+def masked_fill(x, mask, value):
+    y = paddle.full(x.shape, value, x.dtype)
+    return paddle.where(mask, y, x)
+
+
 class Mlp(nn.Layer):
     def __init__(self,
                  in_features,
@@ -79,6 +84,30 @@ def masked_fill(x, mask, value):
     return paddle.where(mask, y, x)
 
 
+def pading_for_not_divisible(pixel_values,
+                             height,
+                             width,
+                             patch_size,
+                             format="BCHW",
+                             function="split"):
+    if isinstance(patch_size, int):
+        patch_size = (patch_size, patch_size)
+    if function == "split":
+        pading_width = patch_size[1] - width % patch_size[1]
+        pading_height = patch_size[0] - height % patch_size[0]
+    elif function == "merge":
+        pading_width = width % 2
+        pading_height = height % 2
+    if format == "BCHW":
+        pad_index = (0, 0, 0, 0, 0, pading_height, 0, pading_width)
+    elif format == "BHWC":
+        pad_index = (0, 0, 0, pading_height, 0, pading_width, 0, 0)
+    else:
+        assert ("vaild format")
+
+    return F.pad(pixel_values, pad_index), pad_index
+
+
 def window_partition(x, window_size):
     """
     Args:
@@ -96,6 +125,23 @@ def window_partition(x, window_size):
     return windows
 
 
+def pad_patch(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.reshape(
+        [B, H // window_size, window_size, W // window_size, window_size, C])
+    windows = x.transpose(perm=[0, 1, 3, 2, 4, 5]).reshape(
+        [-1, window_size, window_size, C])
+    return windows
+
+
 def window_reverse(windows, window_size, H, W):
     """
     Args:
@@ -345,7 +391,7 @@ class SwinTransformerBlock(nn.Layer):
             hidden_features=mlp_hidden_dim,
             act_layer=act_layer,
             drop=drop)
-
+        """
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
             H, W = self.input_resolution
@@ -370,17 +416,54 @@ class SwinTransformerBlock(nn.Layer):
             attn_mask = masked_fill(attn_mask, attn_mask != 0, float(-100.0))
             attn_mask = masked_fill(attn_mask, attn_mask == 0, float(0.0))
         else:
-            attn_mask = None
+        """
+        H, W = self.input_resolution
+        attn_mask = paddle.zeros([1, H, W, 1])
 
         self.register_buffer("attn_mask", attn_mask)
 
-    def forward(self, x):
-        H, W = self.input_resolution
-        B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
+    def get_attn_mask(self, height, width, dtype):
+        if self.shift_size > 0:
+            # calculate attention mask for shifted window multihead self attention
+            img_mask = paddle.zeros((1, height, width, 1), dtype=dtype)
+            height_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None), )
+            width_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None), )
+            count = 0
+            for height_slice in height_slices:
+                for width_slice in width_slices:
+                    img_mask[:, height_slice, width_slice, :] = count
+                    count += 1
+            mask_windows = window_partition(img_mask, self.window_size)
+            mask_windows = mask_windows.reshape(
+                (-1, self.window_size * self.window_size))
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+            attn_mask = masked_fill(attn_mask, attn_mask != 0, float(-100.0))
+            attn_mask = masked_fill(attn_mask, attn_mask == 0, float(0.0))
+        else:
+            attn_mask = None
+        return attn_mask
+
+    def forward(self, x, input_dimensions):
+
+        H, W = input_dimensions
+
+        #token format
+        B, L, C = x.shape
 
         shortcut = x
+        x = x.reshape([B, H, W, C])
+        #feature format
+
+        x, pad_values = pading_for_not_divisible(x, H, W, self.window_size,
+                                                 "BHWC")
+        _, height_pad, width_pad, _ = x.shape
 
         # cyclic shift
         if self.shift_size > 0:
@@ -397,14 +480,15 @@ class SwinTransformerBlock(nn.Layer):
                                        C])  # nW*B, window_size*window_size, C
 
         # W-MSA/SW-MSA
+        attn_mask = self.get_attn_mask(height_pad, width_pad, x.dtype)
         attn_windows = self.attn(
-            x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
+            x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
 
         # merge windows
         attn_windows = attn_windows.reshape(
             [-1, self.window_size, self.window_size, C])
-        shifted_x = window_reverse(attn_windows, self.window_size, H,
-                                   W)  # B H' W' C
+        shifted_x = window_reverse(attn_windows, self.window_size, height_pad,
+                                   width_pad)  # B H' W' C
 
         # reverse cyclic shift
         if self.shift_size > 0:
@@ -414,12 +498,16 @@ class SwinTransformerBlock(nn.Layer):
                 axis=(1, 2))
         else:
             x = shifted_x
+
+        was_padded = pad_values[3] > 0 or pad_values[5] > 0
+        if was_padded:
+            x = x[:, :H, :W, :]
+
         x = x.reshape([B, H * W, C])
         x = shortcut + self.drop_path(self.norm1(x))
 
         # FFN
         x = x + self.drop_path(self.norm2(self.mlp(x)))
-
         return x
 
     def extra_repr(self):
@@ -457,19 +545,32 @@ class PatchMerging(nn.Layer):
         self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
         self.norm = norm_layer(2 * dim)
 
-    def forward(self, x):
+    def forward(self, x, input_dimensions):
         """
         x: B, H*W, C
         """
-        H, W = self.input_resolution
+        H, W = input_dimensions
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
-        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
-        x = x.reshape([B, H // 2, 2, W // 2, 2, C])
-        x = x.transpose((0, 1, 3, 4, 2, 5))
-        x = x.reshape([B, H * W // 4, 4 * C])  # B H/2*W/2 4*C
-        x = self.reduction(x)
+        x = x.reshape((B, H, W, C))
+        x, _ = pading_for_not_divisible(x, H, W, 2, "BHWC", function="merge")
+
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_0 = x[:, 0::2, 0::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_1 = x[:, 1::2, 0::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_2 = x[:, 0::2, 1::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_3 = x[:, 1::2, 1::2, :]
+
+        # [batch_size, height/2 * width/2, 4*num_channels]
+        input_feature = paddle.concat([
+            input_feature_0, input_feature_1, input_feature_2, input_feature_3
+        ], -1)
+        input_feature = input_feature.reshape(
+            (B, -1, 4 * C))  # [batch_size, height/2 * width/2, 4*C]
+        x = self.reduction(input_feature)
         x = self.norm(x)
 
         return x
@@ -548,12 +649,15 @@ class BasicLayer(nn.Layer):
         else:
             self.downsample = None
 
-    def forward(self, x):
+    def forward(self, x, input_dimensions):
+        H, W = input_dimensions
        for blk in self.blocks:
-            x = blk(x)
+            x = blk(x, input_dimensions)
         if self.downsample is not None:
-            x = self.downsample(x)
-        return x
+            H, W = (H + 1) // 2, (W + 1) // 2
+            x = self.downsample(x, input_dimensions)
+
+        return x, (H, W)
 
     def extra_repr(self):
         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
@@ -607,13 +711,15 @@ class PatchEmbed(nn.Layer):
 
     def forward(self, x):
         B, C, H, W = x.shape
-        # FIXME look at relaxing size constraints
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
-        x = self.proj(x).flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
+
+        x, _ = pading_for_not_divisible(x, H, W, self.patch_size, "BCHW")
+        x = self.proj(x)
+        _, _, height, width = x.shape
+        output_dimensions = (height, width)
+        x = x.flatten(2).transpose([0, 2, 1])  # B Ph*Pw C
         if self.norm is not None:
             x = self.norm(x)
-        return x
+        return x, output_dimensions
 
     def flops(self):
         Ho, Wo = self.patches_resolution
@@ -674,6 +780,7 @@ class SwinTransformerV2(nn.Layer):
         self.num_layers = len(depths)
         self.embed_dim = embed_dim
         self.ape = ape
+        self.img_size = img_size
         self.patch_norm = patch_norm
         self.num_features = int(embed_dim * 2**(self.num_layers - 1))
         self.mlp_ratio = mlp_ratio
@@ -740,13 +847,13 @@ class SwinTransformerV2(nn.Layer):
                 ones_(m.weight)
 
     def forward_features(self, x):
-        x = self.patch_embed(x)
+        x, output_dimensions = self.patch_embed(x)
         if self.ape:
             x = x + self.absolute_pos_embed
         x = self.pos_drop(x)
 
         for layer in self.layers:
-            x = layer(x)
+            x, output_dimensions = layer(x, input_dimensions=output_dimensions)
 
         x = self.norm(x)  # B L C
         x = self.avgpool(x.transpose([0, 2, 1]))  # B C 1
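
Note (illustration only, not part of the patch): the new pading_for_not_divisible helper pads the bottom and right edges so that H and W become multiples of the window or patch size, and forward() slices the padded rows and columns off again after window attention (x = x[:, :H, :W, :]). The minimal sketch below mirrors that arithmetic under stated assumptions; the compute_padding name and the example sizes are made up for the example and do not appear in the diff.

# Minimal sketch of the padding arithmetic used by the patch (BHWC layout).
import paddle
import paddle.nn.functional as F


def compute_padding(height, width, patch_size):
    """Return (pad_h, pad_w) as in the helper's "split" branch."""
    if isinstance(patch_size, int):
        patch_size = (patch_size, patch_size)
    pad_h = patch_size[0] - height % patch_size[0]
    pad_w = patch_size[1] - width % patch_size[1]
    # When height/width are already divisible, these formulas yield a full
    # extra patch of padding rather than 0, matching the helper as written.
    return pad_h, pad_w


x = paddle.rand([1, 30, 45, 96])  # B, H, W, C; 30 and 45 are not multiples of 7
pad_h, pad_w = compute_padding(30, 45, patch_size=7)  # -> (5, 4)
# Paddle's full-length pad list runs front-to-back over dimensions, as in the
# patch: (B_front, B_back, H_front, H_back, W_front, W_back, C_front, C_back).
x_padded = F.pad(x, (0, 0, 0, pad_h, 0, pad_w, 0, 0))
print(x_padded.shape)  # [1, 35, 49, 96]: both spatial dims now divisible by 7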