mirror of https://github.com/PaddlePaddle/PaddleClas.git
remove global vars
commit a37cee70bf
parent fd4a97d144
@@ -62,7 +62,6 @@ class RollWithIndexSelect(paddle.autograd.PyLayer):
         grad_input = grad.reshape([N, H * W, C]).index_select(index_bp, 1).reshape([N, H, W, C])
         return grad_input, None, None
 
-roll_with_index_select = RollWithIndexSelect.apply
 
 def get_roll_index(H, W, shifts, place):
     # following tensors will be created on cpu place with npu custom device
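Context for the hunk above: a spatial roll of a [N, H, W, C] tensor is a fixed permutation of the H * W flattened rows, so the forward pass can be replayed with index_select, and the backward pass only needs the inverse permutation (index_bp). A minimal sketch of that trick, runnable on CPU; build_roll_index is our illustrative stand-in for get_roll_index, not part of the patch:

    import numpy as np
    import paddle

    def build_roll_index(H, W, shifts):
        # index_fp[k] = flattened source row that lands at position k after the roll
        idx = np.arange(H * W).reshape(H, W)
        index_fp = np.roll(idx, shift=shifts, axis=(0, 1)).reshape(-1)
        index_bp = np.argsort(index_fp)  # inverse permutation, used by backward
        return paddle.to_tensor(index_fp), paddle.to_tensor(index_bp)

    N, H, W, C, shifts = 2, 4, 4, 3, (1, 1)
    x = paddle.rand([N, H, W, C])
    index_fp, _ = build_roll_index(H, W, shifts)
    rolled = x.reshape([N, H * W, C]).index_select(index_fp, 1).reshape([N, H, W, C])
    assert np.allclose(rolled.numpy(), paddle.roll(x, shifts, axis=(1, 2)).numpy())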
@@ -74,11 +73,26 @@ def get_roll_index(H, W, shifts, place):
     index_bp = paddle.to_tensor(index_fp, dtype='int64', place=place)
     return [index_fp, index_bp]
 
-class NpuRollWithIndexSelect():
+def singleton(class_):
+    instances = {}
+    def getinstance(*args, **kwargs):
+        if class_ not in instances:
+            instances[class_] = class_(*args, **kwargs)
+        return instances[class_]
+    return getinstance
+
+
+@singleton
+class RollWrapperSingleton():
     def __init__(self):
         self.index_dict = {}
+        self.roll_with_index_select = RollWithIndexSelect.apply
+        # enable the index_select path only when an NPU custom device is present
+        self.enable = 'npu' in paddle.device.get_all_custom_device_type()
 
     def __call__(self, x, shifts, axis):
+        if not self.enable:
+            return paddle.roll(x, shifts, axis)
         assert x.dim() == 4
         assert len(shifts) == 2
         assert len(axis) == 2
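The decorator added here is the standard instance-cache singleton, nothing Paddle-specific. A self-contained sketch of its behavior (Counter is a toy class for demonstration only):

    def singleton(class_):
        instances = {}
        def getinstance(*args, **kwargs):
            # construct on first use, then always return the cached object
            if class_ not in instances:
                instances[class_] = class_(*args, **kwargs)
            return instances[class_]
        return getinstance

    @singleton
    class Counter:
        def __init__(self):
            self.calls = 0

    a, b = Counter(), Counter()
    a.calls += 1
    assert a is b and b.calls == 1  # one shared instance, __init__ ran once

This is what lets SwinTransformerBlock call RollWrapperSingleton() inside every forward pass: after the first call it is a dictionary lookup, and the index cache inside the wrapper is shared across all blocks.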
@@ -87,16 +101,8 @@ class NpuRollWithIndexSelect():
         if key not in self.index_dict:
             self.index_dict[key] = get_roll_index(H, W, shifts, x.place)
         index_fp, index_bp = self.index_dict[key]
-        return roll_with_index_select(x, index_fp, index_bp)
+        return self.roll_with_index_select(x, index_fp, index_bp)
 
-roll = None
-
-def _lazy_init_roll(x):
-    global roll
-    if 'npu' in paddle.device.get_all_custom_device_type() and hasattr(x, '_place_str') and 'npu' in x._place_str:
-        roll = NpuRollWithIndexSelect()
-    else:
-        roll = paddle.roll
 
 class Mlp(nn.Layer):
     def __init__(self,
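Note on the caching logic retained above: __call__ stores its index tensors in self.index_dict, keyed (judging from the arguments passed to get_roll_index) by shape, shifts, and device, so index construction is paid once per distinct resolution. The pattern in miniature, with build_indices as a stand-in for get_roll_index:

    cache = {}

    def build_indices(H, W, shifts):
        print('building indices for', (H, W, shifts))  # fires once per key
        return list(range(H * W)), list(range(H * W))

    def lookup(H, W, shifts):
        key = (H, W, shifts)
        if key not in cache:
            cache[key] = build_indices(H, W, shifts)
        return cache[key]

    lookup(56, 56, (3, 3))
    lookup(56, 56, (3, 3))  # cache hit: build_indices is not called again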
@@ -409,11 +415,9 @@ class SwinTransformerBlock(nn.Layer):
         x = self.norm1(x)
         x = x.reshape([B, H, W, C])
 
+        roll = RollWrapperSingleton()
         # cyclic shift
         if self.shift_size > 0:
-            if roll is None:
-                _lazy_init_roll(x)
-
             shifted_x = roll(
                 x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2))
         else:
@@ -438,9 +442,6 @@
 
         # reverse cyclic shift
         if self.shift_size > 0:
-            if roll is None:
-                _lazy_init_roll(shifted_x)
-
             x = roll(
                 shifted_x,
                 shifts=(self.shift_size, self.shift_size),
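Both call sites depend on roll being exactly invertible: the block shifts by (-shift_size, -shift_size) before window attention and by (+shift_size, +shift_size) after it, restoring the original spatial layout. A quick check using the plain paddle.roll fallback:

    import paddle

    s = 3
    x = paddle.rand([2, 8, 8, 4])  # [B, H, W, C], as in the block above
    shifted = paddle.roll(x, shifts=(-s, -s), axis=(1, 2))
    restored = paddle.roll(shifted, shifts=(s, s), axis=(1, 2))
    assert bool((restored == x).all())  # the round trip is exact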