add rewrite for ort (#1138)

pull/1155/head
AllentDan 2022-09-29 18:35:51 +08:00 committed by GitHub
parent 161f27d493
commit d153b5aa0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 4 deletions

View File

@ -175,10 +175,9 @@ def shift_window_msa__window_partition__tensorrt(ctx, self, x):
@FUNCTION_REWRITER.register_rewriter( @FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.forward', func_name='mmdet.models.backbones.swin.ShiftWindowMSA.forward')
backend='tensorrt') def shift_window_msa__forward__default(ctx, self, query, hw_shape):
def shift_window_msa__forward__tensorrt(ctx, self, query, hw_shape): """Rewrite forward function of ShiftWindowMSA class.
"""Rewrite forward function of ShiftWindowMSA class for TensorRT.
1. replace dynamic padding with static padding and dynamic slice. 1. replace dynamic padding with static padding and dynamic slice.
2. always do slice `x = x[:, :H, :W, :].contiguous()` for stability. 2. always do slice `x = x[:, :H, :W, :].contiguous()` for stability.