diff --git a/ppcls/arch/backbone/legendary_models/swin_transformer.py b/ppcls/arch/backbone/legendary_models/swin_transformer.py index 412d76ba8..50d781c96 100644 --- a/ppcls/arch/backbone/legendary_models/swin_transformer.py +++ b/ppcls/arch/backbone/legendary_models/swin_transformer.py @@ -62,7 +62,6 @@ class RollWithIndexSelect(paddle.autograd.PyLayer): grad_input = grad.reshape([N, H * W, C]).index_select(index_bp, 1).reshape([N, H, W, C]) return grad_input, None, None -roll_with_index_select = RollWithIndexSelect.apply def get_roll_index(H, W, shifts, place): # following tensors will be created on cpu place with npu custom device @@ -74,11 +73,26 @@ def get_roll_index(H, W, shifts, place): index_bp = paddle.to_tensor(index_fp, dtype='int64', place=place) return [index_fp, index_bp] -class NpuRollWithIndexSelect(): +def singleton(class_): + instances = {} + def getinstance(*args, **kwargs): + if class_ not in instances: + instances[class_] = class_(*args, **kwargs) + return instances[class_] + return getinstance + +@singleton +class RollWrapperSingleton(): def __init__(self): self.index_dict = {} + self.roll_with_index_select = RollWithIndexSelect.apply + if 'npu' in paddle.device.get_all_custom_device_type(): + self.enable = True def __call__(self, x, shifts, axis): + if not self.enable: + return padlde.roll(x, shifts, axis) + assert x.dim() == 4 assert len(shifts) == 2 assert len(axis) == 2 @@ -87,16 +101,8 @@ class NpuRollWithIndexSelect(): if key not in self.index_dict: self.index_dict[key] = get_roll_index(H, W, shifts, x.place) index_fp, index_bp = self.index_dict[key] - return roll_with_index_select(x, index_fp, index_bp) + return self.roll_with_index_select(x, index_fp, index_bp) -roll = None - -def _lazy_init_roll(x): - global roll - if 'npu' in paddle.device.get_all_custom_device_type() and hasattr(x, '_place_str') and 'npu' in x._place_str: - roll = NpuRollWithIndexSelect() - else: - roll = paddle.roll class Mlp(nn.Layer): def __init__(self, @@ -409,11 +415,9 @@ class SwinTransformerBlock(nn.Layer): x = self.norm1(x) x = x.reshape([B, H, W, C]) + roll = RollWrapperSingleton() # cyclic shift if self.shift_size > 0: - if roll is None: - _lazy_init_roll(x) - shifted_x = roll( x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2)) else: @@ -438,9 +442,6 @@ class SwinTransformerBlock(nn.Layer): # reverse cyclic shift if self.shift_size > 0: - if roll is None: - _lazy_init_roll(shifted_x) - x = roll( shifted_x, shifts=(self.shift_size, self.shift_size),