add rewrite for ort (#1138)
parent
161f27d493
commit
d153b5aa0a
|
@ -175,10 +175,9 @@ def shift_window_msa__window_partition__tensorrt(ctx, self, x):
|
|||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.forward',
|
||||
backend='tensorrt')
|
||||
def shift_window_msa__forward__tensorrt(ctx, self, query, hw_shape):
|
||||
"""Rewrite forward function of ShiftWindowMSA class for TensorRT.
|
||||
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.forward')
|
||||
def shift_window_msa__forward__default(ctx, self, query, hw_shape):
|
||||
"""Rewrite forward function of ShiftWindowMSA class.
|
||||
|
||||
1. replace dynamic padding with static padding and dynamic slice.
|
||||
2. always do slice `x = x[:, :H, :W, :].contiguous()` for stability.
|
||||
|
|
Loading…
Reference in New Issue