diff --git a/configs/mmrotate/rotated-detection_tensorrt-fp16_static-1024x1024.py b/configs/mmrotate/rotated-detection_tensorrt-fp16_static-1024x1024.py
new file mode 100644
index 000000000..f46a3109e
--- /dev/null
+++ b/configs/mmrotate/rotated-detection_tensorrt-fp16_static-1024x1024.py
@@ -0,0 +1,16 @@
+_base_ = [
+    './rotated-detection_static.py', '../_base_/backends/tensorrt-fp16.py'
+]
+
+onnx_config = dict(output_names=['dets', 'labels'], input_shape=(1024, 1024))
+
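+# static 1024x1024 input: min/opt/max shapes are identical and the TensorRT
+# workspace is capped at 1 GiB (1 << 30 bytes)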
+backend_config = dict(
+    common_config=dict(max_workspace_size=1 << 30),
+    model_inputs=[
+        dict(
+            input_shapes=dict(
+                input=dict(
+                    min_shape=[1, 3, 1024, 1024],
+                    opt_shape=[1, 3, 1024, 1024],
+                    max_shape=[1, 3, 1024, 1024])))
+    ])
diff --git a/docs/en/04-supported-codebases/mmrotate.md b/docs/en/04-supported-codebases/mmrotate.md
index 5099b281a..8ace128e8 100644
--- a/docs/en/04-supported-codebases/mmrotate.md
+++ b/docs/en/04-supported-codebases/mmrotate.md
@@ -182,3 +182,4 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
 | [Rotated FasterRCNN](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/rotated_faster_rcnn) |      Y      |    Y     |
 | [Oriented R-CNN](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/oriented_rcnn)           |      Y      |    Y     |
 | [Gliding Vertex](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/gliding_vertex)          |      Y      |    Y     |
+| [RTMDet-R](https://github.com/open-mmlab/mmrotate/blob/dev-1.x/configs/rotated_rtmdet)            |      Y      |    Y     |
diff --git a/docs/zh_cn/04-supported-codebases/mmrotate.md b/docs/zh_cn/04-supported-codebases/mmrotate.md
index d8f2c8971..990440fab 100644
--- a/docs/zh_cn/04-supported-codebases/mmrotate.md
+++ b/docs/zh_cn/04-supported-codebases/mmrotate.md
@@ -186,3 +186,4 @@ det = detector(img)
 | [Rotated FasterRCNN](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/rotated_faster_rcnn) |      Y      |    Y     |
 | [Oriented R-CNN](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/oriented_rcnn)           |      Y      |    Y     |
 | [Gliding Vertex](https://github.com/open-mmlab/mmrotate/blob/1.x/configs/gliding_vertex)          |      Y      |    Y     |
+| [RTMDet-R](https://github.com/open-mmlab/mmrotate/blob/dev-1.x/configs/rotated_rtmdet)            |      Y      |    Y     |
diff --git a/mmdeploy/codebase/mmrotate/models/dense_heads/__init__.py b/mmdeploy/codebase/mmrotate/models/dense_heads/__init__.py
index d526b50d1..1782f5c68 100644
--- a/mmdeploy/codebase/mmrotate/models/dense_heads/__init__.py
+++ b/mmdeploy/codebase/mmrotate/models/dense_heads/__init__.py
@@ -1,2 +1,3 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from . import oriented_rpn_head  # noqa: F401, F403
+from . import rotated_rtmdet_head  # noqa: F401, F403
diff --git a/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rtmdet_head.py b/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rtmdet_head.py
new file mode 100644
index 000000000..e60b9ca9e
--- /dev/null
+++ b/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rtmdet_head.py
@@ -0,0 +1,119 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import torch
+from mmengine.config import ConfigDict
+from mmrotate.structures import norm_angle
+from torch import Tensor
+
+from mmdeploy.codebase.mmdet import get_post_processing_params
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.mmcv.ops.nms_rotated import multiclass_nms_rotated
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmrotate.models.dense_heads.rotated_rtmdet_head.'
+    'RotatedRTMDetHead.predict_by_feat')
+def rotated_rtmdet_head__predict_by_feat(
+        self,
+        cls_scores: List[Tensor],
+        bbox_preds: List[Tensor],
+        angle_preds: List[Tensor],
+        batch_img_metas: Optional[List[dict]] = None,
+        cfg: Optional[ConfigDict] = None,
+        rescale: bool = False,
+        with_nms: bool = True) -> Tuple[Tensor]:
+    """Rewrite `predict_by_feat` of `Rotated RTMDet` for default backend.
+
+    Rewrite this function for model deployment: transform the network
+    outputs for a batch into bbox predictions.
+
+    Args:
+        cls_scores (list[Tensor]): Classification scores for all
+            scale levels, each is a 4D-tensor, has shape
+            (batch_size, num_priors * num_classes, H, W).
+        bbox_preds (list[Tensor]): Box energies / deltas for all
+            scale levels, each is a 4D-tensor, has shape
+            (batch_size, num_priors * 4, H, W).
+        angle_preds (list[Tensor]): Box angle predictions for all
+            scale levels, each is a 4D-tensor, has shape
+            (batch_size, num_priors * angle_dim, H, W).
+        batch_img_metas (list[dict], Optional): Batch image meta info.
+            Defaults to None.
+        cfg (ConfigDict, optional): Test / postprocessing
+            configuration, if None, test_cfg would be used.
+            Defaults to None.
+        rescale (bool): If True, return boxes in original image space.
+            Defaults to False.
+        with_nms (bool): If True, do nms before returning boxes.
+            Defaults to True.
+
+    Returns:
+        tuple[Tensor, Tensor]: The first item is an (N, num_box, 6) tensor,
+            where the 6 values are (x, y, w, h, angle, score), N is the
+            batch size and the score is between 0 and 1. The shape of the
+            second tensor in the tuple is (N, num_box), and each element
+            represents the class label of the corresponding box.
+    """
+    ctx = FUNCTION_REWRITER.get_context()
+    assert len(cls_scores) == len(bbox_preds)
+    device = cls_scores[0].device
+    cfg = self.test_cfg if cfg is None else cfg
+    batch_size = bbox_preds[0].shape[0]
+    featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
+    mlvl_priors = self.prior_generator.grid_priors(
+        featmap_sizes, device=device)
+
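+    # flatten each level's predictions to (batch, num_priors, C) so that all
+    # levels can be concatenated and decoded in one pass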
+    flatten_cls_scores = [
+        cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
+                                              self.cls_out_channels)
+        for cls_score in cls_scores
+    ]
+    flatten_bbox_preds = [
+        bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
+        for bbox_pred in bbox_preds
+    ]
+    flatten_angle_preds = [
+        angle_pred.permute(0, 2, 3, 1).reshape(batch_size, -1,
+                                               self.angle_coder.encode_size)
+        for angle_pred in angle_preds
+    ]
+    flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
+    flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
+    flatten_angle_preds = torch.cat(flatten_angle_preds, dim=1)
+    priors = torch.cat(mlvl_priors)
+
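+    # decode the predicted angle, then turn the (l, t, r, b) distance
+    # predictions into rotated boxes (cx, cy, w, h, angle): the size is
+    # (l + r, t + b) and the center is the prior point shifted by the
+    # rotated half-offset between the opposite distances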
+    angle = self.angle_coder.decode(flatten_angle_preds, keepdim=True)
+    distance = flatten_bbox_preds
+    cos_angle, sin_angle = torch.cos(angle), torch.sin(angle)
+
+    rot_matrix = torch.cat([cos_angle, -sin_angle, sin_angle, cos_angle],
+                           dim=-1)
+    rot_matrix = rot_matrix.reshape(*rot_matrix.shape[:-1], 2, 2)
+
+    wh = distance[..., :2] + distance[..., 2:]
+    offset_t = (distance[..., 2:] - distance[..., :2]) / 2
+    offset_t = offset_t.unsqueeze(-1)
+    offset = torch.matmul(rot_matrix, offset_t).squeeze(-1)
+    ctr = priors[..., :2] + offset
+
+    angle_regular = norm_angle(angle, self.angle_version)
+    bboxes = torch.cat([ctr, wh, angle_regular], dim=-1)
+
+    # instead of filtering candidates, zero out the scores of classes whose
+    # best score falls below `score_thr`, keeping shapes static for
+    # deployment, and feed everything to nms
+    max_scores, _ = torch.max(flatten_cls_scores, 1)
+    mask = max_scores >= cfg.score_thr
+    scores = flatten_cls_scores.where(mask, flatten_cls_scores.new_zeros(1))
+    if not with_nms:
+        return bboxes, scores
+
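+    # nms parameters come from the post-processing section of the deploy
+    # config; iou_threshold, score_thr and max_per_img from the model's
+    # test_cfg take precedence when present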
+    deploy_cfg = ctx.cfg
+    post_params = get_post_processing_params(deploy_cfg)
+    max_output_boxes_per_class = post_params.max_output_boxes_per_class
+    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
+    score_threshold = cfg.get('score_thr', post_params.score_threshold)
+    pre_top_k = post_params.pre_top_k
+    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
+
+    return multiclass_nms_rotated(bboxes, scores, max_output_boxes_per_class,
+                                  iou_threshold, score_threshold, pre_top_k,
+                                  keep_top_k)
diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py
index 28414fa27..3c5862679 100644
--- a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py
+++ b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py
@@ -335,3 +335,126 @@ def test_gvfixcoder__decode(backend_type: Backend):
         run_with_backend=False)
 
     assert rewrite_outputs is not None
+
+
+def get_rotated_rtmdet_head_model():
+    """RTMDet-R Head Config."""
+    test_cfg = Config(
+        dict(
+            deploy_nms_pre=0,
+            min_bbox_size=0,
+            score_thr=0.05,
+            nms=dict(type='nms_rotated', iou_threshold=0.1),
+            max_per_img=2000))
+
+    from mmrotate.models.dense_heads import RotatedRTMDetHead
+    model = RotatedRTMDetHead(
+        num_classes=4,
+        in_channels=1,
+        anchor_generator=dict(
+            type='mmdet.MlvlPointGenerator', offset=0, strides=[8, 16, 32]),
+        bbox_coder=dict(type='DistanceAnglePointCoder', angle_version='le90'),
+        loss_cls=dict(
+            type='mmdet.QualityFocalLoss',
+            use_sigmoid=True,
+            beta=2.0,
+            loss_weight=1.0),
+        loss_bbox=dict(type='RotatedIoULoss', mode='linear', loss_weight=2.0),
+        test_cfg=test_cfg)
+
+    model.requires_grad_(False)
+    return model
+
+
+@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
+def test_rotated_rtmdet_head_predict_by_feat(backend_type: Backend):
+    """Test predict_by_feat rewrite of RTMDet-R."""
+    check_backend(backend_type)
+    rtm_r_head = get_rotated_rtmdet_head_model()
+    rtm_r_head.cpu().eval()
+    s = 128
+    batch_img_metas = [{
+        'scale_factor': np.ones(4),
+        'pad_shape': (s, s, 3),
+        'img_shape': (s, s, 3)
+    }]
+    output_names = ['dets', 'labels']
+    deploy_cfg = Config(
+        dict(
+            backend_config=dict(type=backend_type.value),
+            onnx_config=dict(output_names=output_names, input_shape=None),
+            codebase_config=dict(
+                type='mmrotate',
+                task='RotatedDetection',
+                post_processing=dict(
+                    score_threshold=0.05,
+                    iou_threshold=0.1,
+                    pre_top_k=3000,
+                    keep_top_k=2000,
+                    max_output_boxes_per_class=2000))))
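+    # simulate three feature levels with spatial sizes 16x16, 8x8 and 4x4,
+    # matching strides 8, 16 and 32 of the point generator on a 128x128 input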
+    seed_everything(1234)
+    cls_scores = [
+        torch.rand(1, rtm_r_head.num_classes, 2 * pow(2, i), 2 * pow(2, i))
+        for i in range(3, 0, -1)
+    ]
+    seed_everything(5678)
+    bbox_preds = [
+        torch.rand(1, 4, 2 * pow(2, i), 2 * pow(2, i))
+        for i in range(3, 0, -1)
+    ]
+    seed_everything(9101)
+    angle_preds = [
+        torch.rand(1, rtm_r_head.angle_coder.encode_size, 2 * pow(2, i),
+                   2 * pow(2, i)) for i in range(3, 0, -1)
+    ]
+
+    # to get outputs of pytorch model
+    model_inputs = {
+        'cls_scores': cls_scores,
+        'bbox_preds': bbox_preds,
+        'angle_preds': angle_preds,
+        'batch_img_metas': batch_img_metas,
+        'with_nms': True
+    }
+    model_outputs = get_model_outputs(rtm_r_head, 'predict_by_feat',
+                                      model_inputs)
+
+    # to get outputs of onnx model after rewrite
+    wrapped_model = WrapModel(
+        rtm_r_head,
+        'predict_by_feat',
+        batch_img_metas=batch_img_metas,
+        with_nms=True)
+    rewrite_inputs = {
+        'cls_scores': cls_scores,
+        'bbox_preds': bbox_preds,
+        'angle_preds': angle_preds,
+    }
+    rewrite_outputs, is_backend_output = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg)
+
+    if is_backend_output:
+        # truncate to a common length because the rewritten and original
+        # code apply different nms strategies, so output shapes may differ
+        min_shape = min(model_outputs[0].bboxes.shape[0],
+                        rewrite_outputs[0].shape[1], 5)
+        for i in range(len(model_outputs)):
+            assert np.allclose(
+                model_outputs[i].bboxes.tensor[:min_shape],
+                rewrite_outputs[0][i, :min_shape, :5],
+                rtol=1e-03,
+                atol=1e-05)
+            assert np.allclose(
+                model_outputs[i].scores[:min_shape],
+                rewrite_outputs[0][i, :min_shape, 5],
+                rtol=1e-03,
+                atol=1e-05)
+            assert np.allclose(
+                model_outputs[i].labels[:min_shape],
+                rewrite_outputs[1][i, :min_shape],
+                rtol=1e-03,
+                atol=1e-05)
+    else:
+        assert rewrite_outputs is not None