convert npu roll op into paddle roll (#3138)
* convert npu roll op into paddle roll

pull/3147/head
parent d1ae38d30d
commit 0f915713ec
@@ -42,79 +42,11 @@ MODEL_URLS = {
 
 __all__ = list(MODEL_URLS.keys())
 
-# The following re-implementation of roll is inspired by
-# https://gitee.com/ascend/pytorch/blob/master/torch_npu/contrib/function/roll.py
 
 def masked_fill(x, mask, value):
     y = paddle.full(x.shape, value, x.dtype)
     return paddle.where(mask, y, x)
 
 
-class RollWithIndexSelect(paddle.autograd.PyLayer):
-    @staticmethod
-    def forward(ctx, input1, index_fp, index_bp):
-        N, H, W, C = input1.shape
-        ctx.input1 = input1
-        ctx.index_bp = index_bp
-        result = input1.reshape([N, H * W, C]).index_select(
-            index_fp, 1).reshape([N, H, W, C])
-        return result
-
-    @staticmethod
-    def backward(ctx, grad):
-        input1 = ctx.input1
-        N, H, W, C = input1.shape
-        index_bp = ctx.index_bp
-        grad_input = grad.reshape([N, H * W, C]).index_select(
-            index_bp, 1).reshape([N, H, W, C])
-        return grad_input, None, None
-
-
-def get_roll_index(H, W, shifts, place):
-    index = np.arange(0, H * W, dtype=np.int64).reshape([H, W])
-    index_fp = np.roll(index, shift=shifts, axis=(0, 1)).reshape([-1])
-    index_bp = {i: idx for idx, i in enumerate(index_fp.tolist())}
-    index_bp = [index_bp[i] for i in range(H * W)]
-    index_fp = paddle.to_tensor(index_fp, place=place)
-    index_bp = paddle.to_tensor(index_bp, dtype='int64', place=place)
-    return [index_fp, index_bp]
-
-
-class NpuRollWithIndexSelect():
-    def __init__(self):
-        self.index_dict = {}
-        self.roll_with_index_select = RollWithIndexSelect.apply
-
-    def __call__(self, x, shifts, axis):
-        assert x.dim() == 4
-        assert len(shifts) == 2
-        assert len(axis) == 2
-        N, H, W, C = x.shape
-        key = (H, W, shifts, axis)
-        if key not in self.index_dict:
-            self.index_dict[key] = get_roll_index(H, W, shifts, x.place)
-        index_fp, index_bp = self.index_dict[key]
-        return self.roll_with_index_select(x, index_fp, index_bp)
-
-
-class RollWrapper(object):
-    _roll = None
-
-    @staticmethod
-    def roll(x, shifts, axis):
-        return RollWrapper._roll(x, shifts, axis)
-
-
-# NOTE(xiongkun): paddle SOT can't analyze this builtin function, which will cause a subgraph break in SOT.
-# Doing the dispatch here does not affect SOT translation.
-paddle_custom_device_types = paddle.device.get_all_custom_device_type()
-
-if RollWrapper._roll is None:
-    RollWrapper._roll = NpuRollWithIndexSelect(
-    ) if 'npu' in paddle_custom_device_types else paddle.roll
-
-
 class Mlp(nn.Layer):
     def __init__(self,
                  in_features,
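The deleted RollWithIndexSelect emulated paddle.roll on NPU by gathering the flattened [N, H * W, C] view with a precomputed forward index permutation; its backward pass gathered gradients through the inverse permutation. A minimal NumPy sketch of that equivalence (illustrative only, not part of this commit; names mirror get_roll_index above):

import numpy as np

N, H, W, C = 2, 4, 5, 3
shifts = (-2, -2)  # forward shift used by SwinTransformerBlock

x = np.random.rand(N, H, W, C).astype("float32")

# Roll a [H, W] grid of flat positions, then flatten: index_fp[i * W + j]
# holds the flat source position whose value should land at (i, j).
index = np.arange(H * W, dtype=np.int64).reshape(H, W)
index_fp = np.roll(index, shift=shifts, axis=(0, 1)).reshape(-1)

# Gathering the flattened view with index_fp equals rolling axes 1 and 2.
by_gather = x.reshape(N, H * W, C)[:, index_fp, :].reshape(N, H, W, C)
assert np.array_equal(by_gather, np.roll(x, shift=shifts, axis=(1, 2)))

# The backward pass needs the inverse permutation (index_bp); argsort
# recovers it, and gathering with it undoes the forward gather.
index_bp = np.argsort(index_fp)
undone = by_gather.reshape(N, H * W, C)[:, index_bp, :].reshape(N, H, W, C)
assert np.array_equal(undone, x)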
@@ -457,7 +389,7 @@ class SwinTransformerBlock(nn.Layer):
                                              5] > 0  # change variable name
         # cyclic shift
         if self.shift_size > 0:
-            shifted_x = RollWrapper.roll(
+            shifted_x = paddle.roll(
                 x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2))
         else:
             shifted_x = x
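With the custom op gone, the cyclic shift is a plain paddle.roll with negative shifts, which wraps the leading rows and columns around to the end. A small illustrative sketch (not from this commit):

import paddle

x = paddle.arange(16).reshape([1, 4, 4, 1])  # N, H, W, C
print(paddle.roll(x, shifts=(-1, -1), axis=(1, 2))[0, :, :, 0])
# The first row and column wrap around to the end:
# [[ 5,  6,  7,  4],
#  [ 9, 10, 11,  8],
#  [13, 14, 15, 12],
#  [ 1,  2,  3,  0]]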
@@ -484,7 +416,7 @@ class SwinTransformerBlock(nn.Layer):
 
         # reverse cyclic shift
         if self.shift_size > 0:
-            x = RollWrapper.roll(
+            x = paddle.roll(
                 shifted_x,
                 shifts=(self.shift_size, self.shift_size),
                 axis=(1, 2))
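The reverse shift uses the positive shift_size and exactly undoes the forward shift. A quick round-trip check (illustrative; shapes are assumed):

import paddle

shift_size = 3
x = paddle.rand([2, 56, 56, 96])
shifted = paddle.roll(x, shifts=(-shift_size, -shift_size), axis=(1, 2))
restored = paddle.roll(shifted, shifts=(shift_size, shift_size), axis=(1, 2))
assert bool(paddle.all(restored == x))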