diff --git a/ppcls/arch/backbone/legendary_models/swin_transformer.py b/ppcls/arch/backbone/legendary_models/swin_transformer.py index ee1b4e29c..47fc831cd 100644 --- a/ppcls/arch/backbone/legendary_models/swin_transformer.py +++ b/ppcls/arch/backbone/legendary_models/swin_transformer.py @@ -42,70 +42,6 @@ MODEL_URLS = { __all__ = list(MODEL_URLS.keys()) -# The following re-implementation of roll is inspired by -# https://gitee.com/ascend/pytorch/blob/master/torch_npu/contrib/function/roll.py - - -class RollWithIndexSelect(paddle.autograd.PyLayer): - @staticmethod - def forward(ctx, input1, index_fp, index_bp): - N, H, W, C = input1.shape - ctx.input1 = input1 - ctx.index_bp = index_bp - result = input1.reshape([N, H * W, C]).index_select( - index_fp, 1).reshape([N, H, W, C]) - return result - - @staticmethod - def backward(ctx, grad): - input1 = ctx.input1 - N, H, W, C = input1.shape - index_bp = ctx.index_bp - grad_input = grad.reshape([N, H * W, C]).index_select( - index_bp, 1).reshape([N, H, W, C]) - return grad_input, None, None - - -def get_roll_index(H, W, shifts, place): - index = np.arange(0, H * W, dtype=np.int64).reshape([H, W]) - index_fp = np.roll(index, shift=shifts, axis=(0, 1)).reshape([-1]) - index_bp = {i: idx for idx, i in enumerate(index_fp.tolist())} - index_bp = [index_bp[i] for i in range(H * W)] - index_fp = paddle.to_tensor(index_fp, place=place) - index_bp = paddle.to_tensor(index_fp, dtype='int64', place=place) - return [index_fp, index_bp] - - -class NpuRollWithIndexSelect(): - def __init__(self): - self.index_dict = {} - self.roll_with_index_select = RollWithIndexSelect.apply - - def __call__(self, x, shifts, axis): - assert x.dim() == 4 - assert len(shifts) == 2 - assert len(axis) == 2 - N, H, W, C = x.shape - key = (H, W, shifts, axis) - if key not in self.index_dict: - self.index_dict[key] = get_roll_index(H, W, shifts, x.place) - index_fp, index_bp = self.index_dict[key] - return self.roll_with_index_select(x, index_fp, index_bp) - - -class RollWrapper(object): - - _roll = None - - @staticmethod - def roll(x, shifts, axis): - if RollWrapper._roll is None: - RollWrapper._roll = NpuRollWithIndexSelect( - ) if 'npu' in paddle.device.get_all_custom_device_type( - ) else paddle.roll - - return RollWrapper._roll(x, shifts, axis) - class Mlp(nn.Layer): def __init__(self, @@ -420,7 +356,7 @@ class SwinTransformerBlock(nn.Layer): # cyclic shift if self.shift_size > 0: - shifted_x = RollWrapper.roll( + shifted_x = paddle.roll( x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2)) else: shifted_x = x @@ -444,7 +380,7 @@ class SwinTransformerBlock(nn.Layer): # reverse cyclic shift if self.shift_size > 0: - x = RollWrapper.roll( + x = paddle.roll( shifted_x, shifts=(self.shift_size, self.shift_size), axis=(1, 2))