convert npu roll op into paddle roll (#3139)

pull/3223/head
zhuyipin 2024-05-15 17:11:28 +08:00 committed by GitHub
parent 168097fd61
commit 28dc67e3e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 66 deletions

View File

@ -42,70 +42,6 @@ MODEL_URLS = {
__all__ = list(MODEL_URLS.keys())
# The following re-implementation of roll is inspired by
# https://gitee.com/ascend/pytorch/blob/master/torch_npu/contrib/function/roll.py
class RollWithIndexSelect(paddle.autograd.PyLayer):
    """Roll over the flattened H*W positions of a 4-D tensor via index_select.

    The forward pass gathers positions with ``index_fp``; the backward pass
    gathers the incoming gradient with ``index_bp``. For the gradient to be
    correct, ``index_bp`` must be the inverse permutation of ``index_fp``.
    """

    @staticmethod
    def forward(ctx, input1, index_fp, index_bp):
        """Gather the flattened positions of ``input1`` with ``index_fp``.

        Args:
            ctx: PyLayer context used to pass state to ``backward``.
            input1: 4-D tensor, unpacked here as (N, H, W, C).
            index_fp: index tensor applied along the flattened H*W axis.
            index_bp: index tensor saved for use in ``backward``.

        Returns:
            Tensor with the same (N, H, W, C) shape as ``input1``.
        """
        N, H, W, C = input1.shape
        # backward only needs the shape, so keep that instead of the tensor
        # itself to avoid holding a reference to the activation.
        ctx.input_shape = [N, H, W, C]
        ctx.index_bp = index_bp
        flat = input1.reshape([N, H * W, C])
        return flat.index_select(index_fp, 1).reshape([N, H, W, C])

    @staticmethod
    def backward(ctx, grad):
        """Gather the flattened positions of ``grad`` with the saved index.

        Returns:
            A 3-tuple: the gradient for ``input1`` and ``None`` for each of
            the two index arguments.
        """
        N, H, W, C = ctx.input_shape
        flat_grad = grad.reshape([N, H * W, C])
        grad_input = flat_grad.index_select(ctx.index_bp, 1).reshape(
            [N, H, W, C])
        return grad_input, None, None
def get_roll_index(H, W, shifts, place):
    """Build forward/backward gather indices emulating a roll on an H x W grid.

    Args:
        H: grid height.
        W: grid width.
        shifts: shift amounts forwarded to ``np.roll`` for axes (0, 1) of the
            H x W grid.
        place: device placement forwarded to ``paddle.to_tensor``.

    Returns:
        ``[index_fp, index_bp]``: two int64 tensors of length ``H * W``.
        ``index_fp`` is the rolled flat position index; ``index_bp`` is its
        inverse permutation, i.e. ``index_bp[index_fp[j]] == j``.
    """
    num_positions = H * W
    index = np.arange(0, num_positions, dtype=np.int64).reshape([H, W])
    index_fp = np.roll(index, shift=shifts, axis=(0, 1)).reshape([-1])

    # Invert the permutation: the position j that reads from source i in the
    # forward gather is where the gradient for source i must be read from.
    index_bp = np.empty_like(index_fp)
    index_bp[index_fp] = np.arange(num_positions, dtype=np.int64)

    index_fp = paddle.to_tensor(index_fp, place=place)
    # The backward tensor must be built from the inverse permutation, not
    # from index_fp.
    index_bp = paddle.to_tensor(index_bp, dtype='int64', place=place)
    return [index_fp, index_bp]
class NpuRollWithIndexSelect():
    """Callable roll replacement built on ``RollWithIndexSelect``.

    Index tensors are computed once per (H, W, shifts, axis) combination and
    cached in ``self.index_dict``.
    """

    def __init__(self):
        # Cache: (H, W, shifts, axis) -> [index_fp, index_bp]
        self.index_dict = {}
        self.roll_with_index_select = RollWithIndexSelect.apply

    def __call__(self, x, shifts, axis):
        """Roll a 4-D tensor using cached gather indices.

        Args:
            x: 4-D tensor, unpacked here as (N, H, W, C).
            shifts: sequence of two shift amounts.
            axis: sequence of two axes. NOTE(review): ``axis`` only takes
                part in the length check and the cache key; the indices are
                always built over the H x W grid (dims 1 and 2 of ``x``), so
                callers presumably must pass ``axis=(1, 2)`` - confirm.

        Returns:
            The rolled tensor, same shape as ``x``.
        """
        assert x.dim() == 4
        assert len(shifts) == 2
        assert len(axis) == 2
        N, H, W, C = x.shape
        # Normalize to tuples so list arguments are hashable as a cache key.
        shifts = tuple(shifts)
        axis = tuple(axis)
        key = (H, W, shifts, axis)
        if key not in self.index_dict:
            self.index_dict[key] = get_roll_index(H, W, shifts, x.place)
        index_fp, index_bp = self.index_dict[key]
        return self.roll_with_index_select(x, index_fp, index_bp)
class RollWrapper(object):
    """Dispatch ``roll`` to an implementation chosen on first use."""

    # Selected roll callable; None until the first call to ``roll``.
    _roll = None

    @staticmethod
    def roll(x, shifts, axis):
        """Roll ``x`` by ``shifts`` along ``axis`` with the selected backend.

        On the first call, ``NpuRollWithIndexSelect`` is selected when 'npu'
        is among the custom device types reported by paddle; otherwise
        ``paddle.roll`` is used. The choice is stored on the class and reused
        by later calls.
        """
        impl = RollWrapper._roll
        if impl is None:
            custom_devices = paddle.device.get_all_custom_device_type()
            if 'npu' in custom_devices:
                impl = NpuRollWithIndexSelect()
            else:
                impl = paddle.roll
            RollWrapper._roll = impl
        return impl(x, shifts, axis)
class Mlp(nn.Layer):
def __init__(self,
@ -420,7 +356,7 @@ class SwinTransformerBlock(nn.Layer):
# cyclic shift
if self.shift_size > 0:
shifted_x = RollWrapper.roll(
shifted_x = paddle.roll(
x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2))
else:
shifted_x = x
@ -444,7 +380,7 @@ class SwinTransformerBlock(nn.Layer):
# reverse cyclic shift
if self.shift_size > 0:
x = RollWrapper.roll(
x = paddle.roll(
shifted_x,
shifts=(self.shift_size, self.shift_size),
axis=(1, 2))