add rewrite for ort (#1138)

pull/1155/head
AllentDan 2022-09-29 18:35:51 +08:00 committed by GitHub
parent 161f27d493
commit d153b5aa0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 4 deletions

View File

@ -175,10 +175,9 @@ def shift_window_msa__window_partition__tensorrt(ctx, self, x):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.forward',
backend='tensorrt')
def shift_window_msa__forward__tensorrt(ctx, self, query, hw_shape):
"""Rewrite forward function of ShiftWindowMSA class for TensorRT.
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.forward')
def shift_window_msa__forward__default(ctx, self, query, hw_shape):
"""Rewrite forward function of ShiftWindowMSA class.
1. replace dynamic padding with static padding and dynamic slice.
2. always do slice `x = x[:, :H, :W, :].contiguous()` for stability.