diff --git a/csrc/mmdeploy/codebase/mmpose/CMakeLists.txt b/csrc/mmdeploy/codebase/mmpose/CMakeLists.txt
index 2267b029d..968b6e956 100644
--- a/csrc/mmdeploy/codebase/mmpose/CMakeLists.txt
+++ b/csrc/mmdeploy/codebase/mmpose/CMakeLists.txt
@@ -6,7 +6,9 @@ project(mmdeploy_mmpose)
 file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
 mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
 target_link_libraries(${PROJECT_NAME} PRIVATE
-        mmdeploy::transform mmdeploy_opencv_utils)
+        mmdeploy::transform
+        mmdeploy_operation
+        mmdeploy_opencv_utils)
 add_library(mmdeploy::mmpose ALIAS ${PROJECT_NAME})
 
 set(MMDEPLOY_TASKS ${MMDEPLOY_TASKS} pose_detector CACHE INTERNAL "")
diff --git a/csrc/mmdeploy/codebase/mmpose/topdown_affine.cpp b/csrc/mmdeploy/codebase/mmpose/topdown_affine.cpp
index d49885fbc..75b93fa84 100644
--- a/csrc/mmdeploy/codebase/mmpose/topdown_affine.cpp
+++ b/csrc/mmdeploy/codebase/mmpose/topdown_affine.cpp
@@ -7,6 +7,8 @@
 #include "mmdeploy/core/tensor.h"
 #include "mmdeploy/core/utils/device_utils.h"
 #include "mmdeploy/core/utils/formatter.h"
+#include "mmdeploy/operation/managed.h"
+#include "mmdeploy/operation/vision.h"
 #include "mmdeploy/preprocess/transform/transform.h"
 #include "opencv2/imgproc.hpp"
 #include "opencv_utils.h"
@@ -32,6 +34,7 @@ class TopDownAffine : public transform::Transform {
     stream_ = args["context"]["stream"].get<Stream>();
     assert(args.contains("image_size"));
     from_value(args["image_size"], image_size_);
+    warp_affine_ = operation::Managed<operation::WarpAffine>::Create("bilinear");
   }
 
   ~TopDownAffine() override = default;
@@ -39,11 +42,7 @@ class TopDownAffine : public transform::Transform {
   Result<void> Apply(Value& data) override {
     MMDEPLOY_DEBUG("top_down_affine input: {}", data);
 
-    Device host{"cpu"};
-    auto _img = data["img"].get<Tensor>();
-    OUTCOME_TRY(auto img, MakeAvailableOnDevice(_img, host, stream_));
-    stream_.Wait().value();
-    auto src = cpu::Tensor2CVMat(img);
+    auto img = data["img"].get<Tensor>();
 
     // prepare data
     vector<float> bbox;
@@ -62,21 +61,20 @@ class TopDownAffine : public transform::Transform {
 
     auto r = data["rotation"].get<float>();
 
-    cv::Mat dst;
+    Tensor dst;
     if (use_udp_) {
       cv::Mat trans =
           GetWarpMatrix(r, {c[0] * 2.f, c[1] * 2.f}, {image_size_[0] - 1.f, image_size_[1] - 1.f},
                         {s[0] * 200.f, s[1] * 200.f});
-
-      cv::warpAffine(src, dst, trans, {image_size_[0], image_size_[1]}, cv::INTER_LINEAR);
+      OUTCOME_TRY(warp_affine_.Apply(img, dst, trans.ptr<float>(), image_size_[1], image_size_[0]));
     } else {
       cv::Mat trans =
           GetAffineTransform({c[0], c[1]}, {s[0], s[1]}, r, {image_size_[0], image_size_[1]});
-      cv::warpAffine(src, dst, trans, {image_size_[0], image_size_[1]}, cv::INTER_LINEAR);
+      OUTCOME_TRY(warp_affine_.Apply(img, dst, trans.ptr<float>(), image_size_[1], image_size_[0]));
     }
 
-    data["img"] = cpu::CVMat2Tensor(dst);
-    data["img_shape"] = {1, image_size_[1], image_size_[0], dst.channels()};
+    data["img_shape"] = {1, image_size_[1], image_size_[0], dst.shape(3)};
+    data["img"] = std::move(dst);
     data["center"] = to_value(c);
     data["scale"] = to_value(s);
     MMDEPLOY_DEBUG("output: {}", data);
@@ -106,7 +104,7 @@ class TopDownAffine : public transform::Transform {
     theta = theta * 3.1415926 / 180;
     float scale_x = size_dst.width / size_target.width;
     float scale_y = size_dst.height / size_target.height;
-    cv::Mat matrix = cv::Mat(2, 3, CV_32FC1);
+    cv::Mat matrix = cv::Mat(2, 3, CV_32F);
     matrix.at<float>(0, 0) = std::cos(theta) * scale_x;
     matrix.at<float>(0, 1) = -std::sin(theta) * scale_x;
     matrix.at<float>(0, 2) =
@@ -142,6 +140,7 @@ class TopDownAffine : public transform::Transform {
 
     cv::Mat trans = inv ? cv::getAffineTransform(dst_points, src_points)
                         : cv::getAffineTransform(src_points, dst_points);
+    trans.convertTo(trans, CV_32F);
     return trans;
   }
 
@@ -160,6 +159,7 @@ class TopDownAffine : public transform::Transform {
   }
 
  protected:
+  operation::Managed<operation::WarpAffine> warp_affine_;
   bool use_udp_{false};
   vector<int> image_size_;
   std::string backend_;
diff --git a/csrc/mmdeploy/operation/cpu/CMakeLists.txt b/csrc/mmdeploy/operation/cpu/CMakeLists.txt
index 7a4edad41..d1310baae 100644
--- a/csrc/mmdeploy/operation/cpu/CMakeLists.txt
+++ b/csrc/mmdeploy/operation/cpu/CMakeLists.txt
@@ -9,7 +9,8 @@ set(SRCS resize.cpp
         hwc2chw.cpp
         normalize.cpp
         crop.cpp
-        flip.cpp)
+        flip.cpp
+        warp_affine.cpp)
 
 mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
 
diff --git a/csrc/mmdeploy/operation/cpu/resize.cpp b/csrc/mmdeploy/operation/cpu/resize.cpp
index 8c345df17..33c5ce313 100644
--- a/csrc/mmdeploy/operation/cpu/resize.cpp
+++ b/csrc/mmdeploy/operation/cpu/resize.cpp
@@ -7,7 +7,7 @@ namespace mmdeploy::operation::cpu {
 
 class ResizeImpl : public Resize {
  public:
-  ResizeImpl(std::string interp) : interp_(std::move(interp)) {}
+  explicit ResizeImpl(std::string interp) : interp_(std::move(interp)) {}
 
   Result<void> apply(const Tensor& src, Tensor& dst, int dst_h, int dst_w) override {
     auto src_mat = mmdeploy::cpu::Tensor2CVMat(src);
diff --git a/csrc/mmdeploy/operation/cpu/warp_affine.cpp b/csrc/mmdeploy/operation/cpu/warp_affine.cpp
new file mode 100644
index 000000000..5b5914db7
--- /dev/null
+++ b/csrc/mmdeploy/operation/cpu/warp_affine.cpp
@@ -0,0 +1,29 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "mmdeploy/operation/vision.h"
+#include "mmdeploy/utils/opencv/opencv_utils.h"
+
+namespace mmdeploy::operation::cpu {
+
+class WarpAffineImpl : public WarpAffine {
+ public:
+  explicit WarpAffineImpl(int method) : method_(method) {}
+
+  Result<void> apply(const Tensor& src, Tensor& dst, const float affine_matrix[6], int dst_h,
+                     int dst_w) override {
+    auto src_mat = mmdeploy::cpu::Tensor2CVMat(src);
+    cv::Mat_<float> _matrix(2, 3, const_cast<float*>(affine_matrix));  // 2x3 over the 6 input floats
+    auto dst_mat = mmdeploy::cpu::WarpAffine(src_mat, _matrix, dst_h, dst_w, method_);
+    dst = mmdeploy::cpu::CVMat2Tensor(dst_mat);  // result replaces whatever dst held
+    return success();
+  }
+
+ private:
+  int method_;
+};
+
+MMDEPLOY_REGISTER_FACTORY_FUNC(WarpAffine, (cpu, 0), [](const string_view& interp) {
+  return std::make_unique<WarpAffineImpl>(::mmdeploy::cpu::GetInterpolationMethod(interp).value());
+});
+
+}  // namespace mmdeploy::operation::cpu
diff --git a/csrc/mmdeploy/operation/cuda/CMakeLists.txt b/csrc/mmdeploy/operation/cuda/CMakeLists.txt
index d962d3c5b..5e04f640b 100644
--- a/csrc/mmdeploy/operation/cuda/CMakeLists.txt
+++ b/csrc/mmdeploy/operation/cuda/CMakeLists.txt
@@ -17,7 +17,8 @@ set(SRCS resize.cpp
         normalize.cu
         crop.cpp
         crop.cu
-        flip.cpp)
+        flip.cpp
+        warp_affine.cpp)
 
 mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
 
diff --git a/csrc/mmdeploy/operation/cuda/warp_affine.cpp b/csrc/mmdeploy/operation/cuda/warp_affine.cpp
new file mode 100644
index 000000000..4f2071c06
--- /dev/null
+++ b/csrc/mmdeploy/operation/cuda/warp_affine.cpp
@@ -0,0 +1,118 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "mmdeploy/core/utils/formatter.h"
+#include "mmdeploy/operation/vision.h"
+#include "ppl/cv/cuda/warpaffine.h"
+
+namespace mmdeploy::operation::cuda {
+
+class WarpAffineImpl : public WarpAffine {
+ public:
+  explicit WarpAffineImpl(ppl::cv::InterpolationType interp) : interp_(interp) {}
+
+  Result<void> apply(const Tensor& src, Tensor& dst, const float affine_matrix[6], int dst_h,
+                     int dst_w) override {
+    assert(src.device() == device());
+
+    TensorDesc desc{device(), src.data_type(), {1, dst_h, dst_w, src.shape(3)}, src.name()};
+    Tensor dst_tensor(desc);
+
+    // the kernel takes the inverted transform (see Invert), so convert the caller's matrix first
+    auto inv = Invert(affine_matrix);
+
+    auto cuda_stream = GetNative<cudaStream_t>(stream());
+    if (src.data_type() == DataType::kINT8) {
+      OUTCOME_TRY(Dispatch<uint8_t>(src, dst_tensor, inv.data(), cuda_stream));
+    } else if (src.data_type() == DataType::kFLOAT) {
+      OUTCOME_TRY(Dispatch<float>(src, dst_tensor, inv.data(), cuda_stream));
+    } else {
+      MMDEPLOY_ERROR("unsupported data type {}", src.data_type());
+      return Status(eNotSupported);
+    }
+
+    dst = std::move(dst_tensor);  // publish the output only after the dispatch succeeded
+    return success();
+  }
+
+ private:
+  // ppl.cv uses inverted transform
+  // https://github.com/opencv/opencv/blob/bc6544c0bcfa9ca5db5e0d0551edf5c8e7da3852/modules/imgproc/src/imgwarp.cpp#L3478
+  static std::array<float, 6> Invert(const float affine_matrix[6]) {
+    const auto* M = affine_matrix;
+    std::array<float, 6> inv{};
+    auto iM = inv.data();
+
+    auto D = M[0] * M[3 + 1] - M[1] * M[3];  // determinant of the 2x2 linear part
+    D = D != 0.f ? 1.f / D : 0.f;            // zero determinant: use 0 instead of dividing by zero
+    auto A11 = M[3 + 1] * D, A22 = M[0] * D, A12 = -M[1] * D, A21 = -M[3] * D;
+    auto b1 = -A11 * M[2] - A12 * M[3 + 2];  // inverse translation: -(A * t), t = (M[2], M[5])
+    auto b2 = -A21 * M[2] - A22 * M[3 + 2];
+
+    iM[0] = A11;
+    iM[1] = A12;
+    iM[2] = b1;
+    iM[3] = A21;
+    iM[3 + 1] = A22;
+    iM[3 + 2] = b2;
+
+    return inv;
+  }
+
+  template <typename T>
+  auto Select(int channels) -> decltype(&ppl::cv::cuda::WarpAffine<T, 1>) {
+    switch (channels) {
+      case 1:
+        return &ppl::cv::cuda::WarpAffine<T, 1>;
+      case 3:
+        return &ppl::cv::cuda::WarpAffine<T, 3>;
+      case 4:
+        return &ppl::cv::cuda::WarpAffine<T, 4>;
+      default:
+        MMDEPLOY_ERROR("unsupported channels {}", channels);
+        return nullptr;
+    }
+  }
+
+  template <class T>
+  Result<void> Dispatch(const Tensor& src, Tensor& dst, const float affine_matrix[6],
+                        cudaStream_t stream) {
+    int h = (int)src.shape(1);
+    int w = (int)src.shape(2);
+    int c = (int)src.shape(3);
+    int dst_h = (int)dst.shape(1);
+    int dst_w = (int)dst.shape(2);
+
+    auto input = src.data<T>();
+    auto output = dst.data<T>();
+
+    ppl::common::RetCode ret = 0;
+
+    if (auto warp_affine = Select<T>(c); warp_affine) {
+      ret = warp_affine(stream, h, w, w * c, input, dst_h, dst_w, dst_w * c, output, affine_matrix,
+                        interp_, ppl::cv::BORDER_CONSTANT, 0);
+    } else {
+      return Status(eNotSupported);
+    }
+
+    return ret == 0 ? success() : Result<void>(Status(eFail));
+  }
+
+  ppl::cv::InterpolationType interp_;
+};
+
+static auto Create(const string_view& interp) {
+  ppl::cv::InterpolationType type{};
+  if (interp == "bilinear") {
+    type = ppl::cv::InterpolationType::INTERPOLATION_LINEAR;
+  } else if (interp == "nearest") {
+    type = ppl::cv::InterpolationType::INTERPOLATION_NEAREST_POINT;
+  } else {
+    MMDEPLOY_ERROR("unsupported interpolation method: {}", interp);
+    throw_exception(eNotSupported);
+  }
+  return std::make_unique<WarpAffineImpl>(type);
+}
+
+MMDEPLOY_REGISTER_FACTORY_FUNC(WarpAffine, (cuda, 0), Create);
+
+}  // namespace mmdeploy::operation::cuda
diff --git a/csrc/mmdeploy/operation/vision.cpp b/csrc/mmdeploy/operation/vision.cpp
index c7f7ba77d..35076e2bd 100644
--- a/csrc/mmdeploy/operation/vision.cpp
+++ b/csrc/mmdeploy/operation/vision.cpp
@@ -12,5 +12,6 @@ MMDEPLOY_DEFINE_REGISTRY(HWC2CHW);
 MMDEPLOY_DEFINE_REGISTRY(Normalize);
 MMDEPLOY_DEFINE_REGISTRY(Crop);
 MMDEPLOY_DEFINE_REGISTRY(Flip);
+MMDEPLOY_DEFINE_REGISTRY(WarpAffine);
 
 }  // namespace mmdeploy::operation
diff --git a/csrc/mmdeploy/operation/vision.h b/csrc/mmdeploy/operation/vision.h
index aea99859c..9b65dbaaa 100644
--- a/csrc/mmdeploy/operation/vision.h
+++ b/csrc/mmdeploy/operation/vision.h
@@ -76,7 +76,13 @@ class Flip : public Operation {
 };
 MMDEPLOY_DECLARE_REGISTRY(Flip, unique_ptr<Flip>(int flip_code));
 
-// TODO: warp affine
+// 2x3 OpenCV affine matrix, row major
+class WarpAffine : public Operation {
+ public:
+  virtual Result<void> apply(const Tensor& src, Tensor& dst, const float affine_matrix[6],
+                             int dst_h, int dst_w) = 0;
+};
+MMDEPLOY_DECLARE_REGISTRY(WarpAffine, unique_ptr<WarpAffine>(const string_view& interp));
 
 }  // namespace mmdeploy::operation
 
diff --git a/csrc/mmdeploy/preprocess/transform/load.cpp b/csrc/mmdeploy/preprocess/transform/load.cpp
index 5640d1c47..57879d5b4 100644
--- a/csrc/mmdeploy/preprocess/transform/load.cpp
+++ b/csrc/mmdeploy/preprocess/transform/load.cpp
@@ -48,6 +48,12 @@ class PrepareImage : public Transform {
 
   Result<void> Apply(Value& data) override {
     MMDEPLOY_DEBUG("input: {}", data);
+
+    // early exit
+    if (data.contains("img") && data["img"].is_any<Tensor>()) {
+      return success();
+    }
+
     assert(data.contains("ori_img"));
 
     Mat src_mat = data["ori_img"].get<Mat>();
diff --git a/csrc/mmdeploy/utils/opencv/opencv_utils.cpp b/csrc/mmdeploy/utils/opencv/opencv_utils.cpp
index b2801cb3e..d410d5dcc 100644
--- a/csrc/mmdeploy/utils/opencv/opencv_utils.cpp
+++ b/csrc/mmdeploy/utils/opencv/opencv_utils.cpp
@@ -106,23 +106,34 @@ Tensor CVMat2Tensor(const cv::Mat& mat) {
   return Tensor{desc, data};
 }
 
+Result<int> GetInterpolationMethod(const std::string_view& method) {
+  if (method == "bilinear") {
+    return cv::INTER_LINEAR;
+  } else if (method == "nearest") {
+    return cv::INTER_NEAREST;
+  } else if (method == "area") {
+    return cv::INTER_AREA;
+  } else if (method == "bicubic") {
+    return cv::INTER_CUBIC;
+  } else if (method == "lanczos") {
+    return cv::INTER_LANCZOS4;
+  }
+  MMDEPLOY_ERROR("unsupported interpolation method: {}", method);
+  return Status(eNotSupported);
+}
+
 cv::Mat Resize(const cv::Mat& src, int dst_height, int dst_width,
                const std::string& interpolation) {
   cv::Mat dst(dst_height, dst_width, src.type());
-  if (interpolation == "bilinear") {
-    cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_LINEAR);
-  } else if (interpolation == "nearest") {
-    cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_NEAREST);
-  } else if (interpolation == "area") {
-    cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_AREA);
-  } else if (interpolation == "bicubic") {
-    cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_CUBIC);
-  } else if (interpolation == "lanczos") {
-    cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_LANCZOS4);
-  } else {
-    MMDEPLOY_ERROR("{} interpolation is not supported", interpolation);
-    assert(0);
-  }
+  auto method = GetInterpolationMethod(interpolation).value();
+  cv::resize(src, dst, dst.size(), 0, 0, method);  // fx/fy unused; 6th arg selects interpolation
+  return dst;
+}
+
+cv::Mat WarpAffine(const cv::Mat& src, const cv::Mat& affine_matrix, int dst_height, int dst_width,
+                   int interpolation) {
+  cv::Mat dst(dst_height, dst_width, src.type());
+  cv::warpAffine(src, dst, affine_matrix, dst.size(), interpolation);
   return dst;
 }
 
diff --git a/csrc/mmdeploy/utils/opencv/opencv_utils.h b/csrc/mmdeploy/utils/opencv/opencv_utils.h
index 6f5e432b9..0c9646466 100644
--- a/csrc/mmdeploy/utils/opencv/opencv_utils.h
+++ b/csrc/mmdeploy/utils/opencv/opencv_utils.h
@@ -18,6 +18,8 @@ MMDEPLOY_API cv::Mat Tensor2CVMat(const framework::Tensor& tensor);
 MMDEPLOY_API framework::Mat CVMat2Mat(const cv::Mat& mat, PixelFormat format);
 MMDEPLOY_API framework::Tensor CVMat2Tensor(const cv::Mat& mat);
 
+MMDEPLOY_API Result<int> GetInterpolationMethod(const std::string_view& method);
+
 /**
  * @brief resize an image to specified size
  *
@@ -29,6 +31,9 @@ MMDEPLOY_API framework::Tensor CVMat2Tensor(const cv::Mat& mat);
 MMDEPLOY_API cv::Mat Resize(const cv::Mat& src, int dst_height, int dst_width,
                             const std::string& interpolation);
 
+MMDEPLOY_API cv::Mat WarpAffine(const cv::Mat& src, const cv::Mat& affine_matrix, int dst_height,
+                                int dst_width, int interpolation);
+
 /**
  * @brief crop an image
  *
diff --git a/demo/csrc/cpp/pose_tracker.cpp b/demo/csrc/cpp/pose_tracker.cpp
index 896fe7565..1ddab8978 100644
--- a/demo/csrc/cpp/pose_tracker.cpp
+++ b/demo/csrc/cpp/pose_tracker.cpp
@@ -1,5 +1,8 @@
 
 
+#include <cmath>
+#include <numeric>
+
 #include "mmdeploy/archive/json_archive.h"
 #include "mmdeploy/archive/value_archive.h"
 #include "mmdeploy/common.hpp"
@@ -15,9 +18,17 @@
 const auto config_json = R"(
 {
   "type": "Pipeline",
-  "input": ["data", "use_det", "state"],
+  "input": ["img", "use_det", "state"],
   "output": "targets",
   "tasks": [
+    {
+      "type": "Task",
+      "module": "Transform",
+      "name": "preload",
+      "input": "img",
+      "output": "data",
+      "transforms": [ { "type": "LoadImageFromFile" } ]
+    },
     {
       "type": "Cond",
       "input": ["use_det", "data"],
@@ -32,7 +43,7 @@ const auto config_json = R"(
       "type": "Task",
       "module": "ProcessBboxes",
       "input": ["dets", "data", "state"],
-      "output": "rois"
+      "output": ["rois", "track_ids"]
     },
     {
       "input": "*rois",
@@ -45,7 +56,7 @@ const auto config_json = R"(
       "type": "Task",
       "module": "TrackPose",
       "scheduler": "pool",
-      "input": ["keypoints", "state"],
+      "input": ["keypoints", "track_ids", "state"],
       "output": "targets"
     }
   ]
@@ -57,26 +68,38 @@ namespace mmdeploy {
 #define REGISTER_SIMPLE_MODULE(name, fn) \
   MMDEPLOY_REGISTER_FACTORY_FUNC(Module, (name, 0), [](const Value&) { return CreateTask(fn); });
 
-std::optional<std::array<float, 4>> keypoints_to_bbox(const std::vector<cv::Point2f>& keypoints,
-                                                      const std::vector<float>& scores, float img_h,
-                                                      float img_w, float scale = 1.5,
-                                                      float kpt_thr = 0.3) {
-  auto valid = false;
+#define POSE_TRACKER_DEBUG(...) MMDEPLOY_INFO(__VA_ARGS__)
+
+using std::vector;
+using Bbox = std::array<float, 4>;
+using Bboxes = vector<Bbox>;
+using Point = cv::Point2f;
+using Points = vector<cv::Point2f>;
+using Score = float;
+using Scores = vector<float>;
+
+// scale = 1.5, kpt_thr = 0.3
+std::optional<Bbox> keypoints_to_bbox(const Points& keypoints, const Scores& scores, float img_h,
+                                      float img_w, float scale, float kpt_thr, int min_keypoints) {
+  int valid = 0;
   auto x1 = static_cast<float>(img_w);
   auto y1 = static_cast<float>(img_h);
   auto x2 = 0.f;
   auto y2 = 0.f;
   for (size_t i = 0; i < keypoints.size(); ++i) {
     auto& kpt = keypoints[i];
-    if (scores[i] > kpt_thr) {
+    if (scores[i] >= kpt_thr) {
       x1 = std::min(x1, kpt.x);
       y1 = std::min(y1, kpt.y);
       x2 = std::max(x2, kpt.x);
       y2 = std::max(y2, kpt.y);
-      valid = true;
+      ++valid;
     }
   }
-  if (!valid) {
+  if (min_keypoints < 0) {
+    min_keypoints = (static_cast<int>(scores.size()) + 1) / 2;
+  }
+  if (valid < min_keypoints) {
     return std::nullopt;
   }
   auto xc = .5f * (x1 + x2);
@@ -92,223 +115,781 @@ std::optional<std::array<float, 4>> keypoints_to_bbox(const std::vector<cv::Poin
   };
 }
 
+class Filter {
+ public:
+  virtual ~Filter() = default;
+  virtual cv::Mat_<float> Predict(float t) = 0;
+  virtual cv::Mat_<float> Correct(const cv::Mat_<float>& x) = 0;
+};
+
+class OneEuroFilter : public Filter {
+ public:
+  explicit OneEuroFilter(const cv::Mat_<float>& x, float beta, float fc_min, float fc_d)
+      : x_(x.clone()), beta_(beta), fc_min_(fc_min), fc_d_(fc_d) {
+    v_ = cv::Mat::zeros(x_.size(), x.type());
+  }
+
+  cv::Mat_<float> Predict(float t) override { return x_ + v_; }
+
+  cv::Mat_<float> Correct(const cv::Mat_<float>& x) override {
+    auto a_v = SmoothingFactor(fc_d_);
+    v_ = ExponentialSmoothing(a_v, x - x_, v_);
+    auto fc = fc_min_ + beta_ * (float)cv::norm(v_);
+    auto a_x = SmoothingFactor(fc);
+    x_ = ExponentialSmoothing(a_x, x, x_);
+    return x_.clone();
+  }
+
+ private:
+  static float SmoothingFactor(float cutoff) {
+    static constexpr float kPi = 3.1415926;
+    auto r = 2 * kPi * cutoff;
+    return r / (r + 1);
+  }
+
+  static cv::Mat_<float> ExponentialSmoothing(float a, const cv::Mat_<float>& x,
+                                              const cv::Mat_<float>& x0) {
+    return a * x + (1 - a) * x0;
+  }
+
+ private:
+  cv::Mat_<float> x_;
+  cv::Mat_<float> v_;
+  float beta_;
+  float fc_min_;
+  float fc_d_;
+};
+
+template <typename T>
+class PointFilterArray : public Filter {
+ public:
+  template <typename... Args>
+  explicit PointFilterArray(const Points& ps, const Args&... args) {
+    for (const auto& p : ps) {
+      fs_.emplace_back(cv::Mat_<float>(p, false), args...);
+    }
+  }
+
+  cv::Mat_<float> Predict(float t) override {  // returns one (x, y) row per element filter
+    cv::Mat_<float> m(fs_.size() * 2, 1);
+    for (int i = 0; i < fs_.size(); ++i) {
+      cv::Range r(i * 2, i * 2 + 2);
+      fs_[i].Predict(t).copyTo(m.rowRange(r));  // forward the caller's t, not a hard-coded 1
+    }
+    return m.reshape(0, fs_.size());
+  }
+
+  cv::Mat_<float> Correct(const cv::Mat_<float>& x) override {
+    cv::Mat_<float> m(fs_.size() * 2, 1);
+    auto _x = x.reshape(1, x.rows * x.cols);
+    for (int i = 0; i < fs_.size(); ++i) {
+      cv::Range r(i * 2, i * 2 + 2);
+      fs_[i].Correct(_x.rowRange(r)).copyTo(m.rowRange(r));
+    }
+    return m.reshape(0, fs_.size());
+  }
+
+ private:
+  vector<T> fs_;
+};
+
+class TrackerFilter {
+ public:
+  using Points = vector<cv::Point2f>;
+
+  explicit TrackerFilter(float c_beta, float c_fc_min, float c_fc_d, float k_beta, float k_fc_min,
+                         float k_fc_d, const Bbox& bbox, const Points& kpts)
+      : n_kpts_(kpts.size()) {
+    c_ = std::make_unique<OneEuroFilter>(cv::Mat_<float>(Center(bbox)), c_beta, c_fc_min, c_fc_d);
+    s_ = std::make_unique<OneEuroFilter>(cv::Mat_<float>(Scale(bbox)), 0, 1, 0);
+    kpts_ = std::make_unique<PointFilterArray<OneEuroFilter>>(kpts, k_beta, k_fc_min, k_fc_d);
+  }
+
+  std::pair<Bbox, Points> Predict() {
+    cv::Point2f c;
+    c_->Predict(1).copyTo(cv::Mat(c, false));
+    cv::Point2f s;
+    s_->Predict(0).copyTo(cv::Mat(s, false));
+    Points p(n_kpts_);
+    kpts_->Predict(1).copyTo(cv::Mat(p, false).reshape(1));
+    return {GetBbox(c, s), std::move(p)};
+  }
+
+  std::pair<Bbox, Points> Correct(const Bbox& bbox, const Points& kpts) {
+    cv::Point2f c;
+    c_->Correct(cv::Mat_<float>(Center(bbox), false)).copyTo(cv::Mat(c, false));
+    cv::Point2f s;
+    s_->Correct(cv::Mat_<float>(Scale(bbox), false)).copyTo(cv::Mat(s, false));
+    Points p(kpts.size());
+    kpts_->Correct(cv::Mat(kpts, false)).copyTo(cv::Mat(p, false).reshape(1));
+    return {GetBbox(c, s), std::move(p)};
+  }
+
+ private:
+  static cv::Point2f Center(const Bbox& bbox) {
+    return {.5f * (bbox[0] + bbox[2]), .5f * (bbox[1] + bbox[3])};
+  }
+  static cv::Point2f Scale(const Bbox& bbox) {
+    return {bbox[2] - bbox[0], bbox[3] - bbox[1]};
+    //    return {std::log(bbox[2] - bbox[0]), std::log(bbox[3] - bbox[1])};
+  }
+  static Bbox GetBbox(const cv::Point2f& center, const cv::Point2f& scale) {
+    //    cv::Point2f half_size(.5 * std::exp(scale.x), .5 * std::exp(scale.y));
+    Point half_size(.5f * scale.x, .5f * scale.y);
+    auto lo = center - half_size;
+    auto hi = center + half_size;
+    return {lo.x, lo.y, hi.x, hi.y};
+  }
+  int n_kpts_;
+  std::unique_ptr<Filter> c_;
+  std::unique_ptr<Filter> s_;
+  std::unique_ptr<Filter> kpts_;
+};
+
 struct Track {
-  std::vector<std::vector<cv::Point2f>> keypoints;
-  std::vector<std::vector<float>> scores;
-  std::vector<std::array<float, 4>> bboxes;
+  vector<Points> keypoints;
+  vector<Scores> scores;
+  vector<float> avg_scores;
+  vector<Bbox> bboxes;
+  vector<int> is_missing;
   int64_t track_id{-1};
+  std::shared_ptr<TrackerFilter> filter;
+  Bbox bbox_pred{};
+  Points kpts_pred;
+  int64_t age{0};
+  int64_t n_missing{0};
 };
 
 struct TrackInfo {
-  std::vector<Track> tracks;
+  vector<Track> tracks;
   int64_t next_id{0};
 };
 
-MMDEPLOY_REGISTER_TYPE_ID(TrackInfo, 0xcfe87980aa895d3a);  // randomly generated type id
+static inline float Area(const Bbox& bbox) { return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]); }
 
-Value::Array GetObjectsByTracking(Value& state, int img_h, int img_w) {
-  Value::Array objs;
-  auto& track_info = state["track_info"].get_ref<TrackInfo&>();
-  for (auto& track : track_info.tracks) {
-    auto bbox = keypoints_to_bbox(track.keypoints.back(), track.scores.back(),
-                                  static_cast<float>(img_h), static_cast<float>(img_w));
-    if (bbox) {
-      objs.push_back({{"bbox", to_value(*bbox)}});
+struct TrackerParams {
+  // detector params
+  int det_interval = 5;           // detection interval
+  int det_label = 0;              // label used to filter detections
+  float det_min_bbox_size = 100;  // threshold for sqrt(area(bbox))
+  float det_thr = .5f;            // confidence threshold used to filter detections
+  float det_nms_thr = .7f;        // detection nms threshold
+
+  // pose model params
+  int pose_max_num_bboxes = 1;    // max num of bboxes for pose model per frame
+  int pose_min_keypoints = -1;    // min of visible key-points for valid bbox, -1 -> len(kpts)/2
+  float pose_min_bbox_size = 64;  // threshold for sqrt(area(bbox))
+  vector<float> sigmas;           // sigmas for key-points
+
+  // tracker params
+  float track_nms_oks_thr = .5f;        // OKS threshold for suppressing duplicated key-points
+  float track_kpts_thr = .6f;           // threshold for key-point visibility
+  float track_oks_thr = .3f;            // OKS assignment threshold
+  float track_iou_thr = .3f;            // IOU assignment threshold
+  float track_bbox_scale = 1.25f;       // scale factor for bboxes
+  int track_max_missing = 10;           // max number of missing frames before track removal
+  float track_missing_momentum = .95f;  // extrapolation momentum for missing tracks
+  int track_n_history = 10;             // track history length
+
+  // filter params for bbox center
+  float filter_c_beta = .005;
+  float filter_c_fc_min = .05;
+  float filter_c_fc_d = 1.;
+  // filter params for key-points
+  float filter_k_beta = .0075;
+  float filter_k_fc_min = .1;
+  float filter_k_fc_d = .25;
+};
+
+class Tracker {
+ public:
+  explicit Tracker(const TrackerParams& _params) : params(_params) {}
+  // xyxy format
+  float IntersectionOverUnion(const std::array<float, 4>& a, const std::array<float, 4>& b) {
+    auto x1 = std::max(a[0], b[0]);
+    auto y1 = std::max(a[1], b[1]);
+    auto x2 = std::min(a[2], b[2]);
+    auto y2 = std::min(a[3], b[3]);
+
+    auto inter_area = std::max(0.f, x2 - x1) * std::max(0.f, y2 - y1);
+
+    auto a_area = Area(a);
+    auto b_area = Area(b);
+    auto union_area = a_area + b_area - inter_area;
+
+    if (union_area == 0.f) {
+      return 0;
     }
-  }
-  return objs;
-}
 
-Value ProcessBboxes(const Value& detections, const Value& data, Value state) {
-  assert(state.is_pointer());
-  Value::Array bboxes;
-  if (detections.is_array()) {  // has detections
-    auto& dets = detections.array();
-    for (const auto& det : dets) {
-      if (det["label_id"].get<int>() == 0 && det["score"].get<float>() >= .3f) {
-        bboxes.push_back(det);
-      }
+    return inter_area / union_area;
+  }
+
+  // TopDownAffine's internal logic for mapping pose detector inputs
+  Bbox MapBbox(const Bbox& box) {
+    Point p0(box[0], box[1]);
+    Point p1(box[2], box[3]);
+    auto c = .5f * (p0 + p1);
+    auto s = p1 - p0;
+    static constexpr std::array image_size{192.f, 256.f};
+    float aspect_ratio = image_size[0] * 1.0 / image_size[1];
+    if (s.x > aspect_ratio * s.y) {
+      s.y = s.x / aspect_ratio;
+    } else if (s.x < aspect_ratio * s.y) {
+      s.x = s.y * aspect_ratio;
     }
-    MMDEPLOY_INFO("bboxes by detection: {}", bboxes.size());
-    state["bboxes"] = bboxes;
-  } else {  // no detections, use tracked results
-    auto img_h = state["img_shape"][0].get<int>();
-    auto img_w = state["img_shape"][1].get<int>();
-    bboxes = GetObjectsByTracking(state, img_h, img_w);
-    MMDEPLOY_INFO("GetObjectsByTracking: {}", bboxes.size());
-  }
-  // attach bboxes to image data
-  for (auto& bbox : bboxes) {
-    auto img = data["ori_img"].get<framework::Mat>();
-    auto box = from_value<std::array<float, 4>>(bbox["bbox"]);
-    cv::Rect rect(cv::Rect2f(cv::Point2f(box[0], box[1]), cv::Point2f(box[2], box[3])));
-    bbox = Value::Object{
-        {"ori_img", img}, {"bbox", {rect.x, rect.y, rect.width, rect.height}}, {"rotation", 0.f}};
-  };
-  return bboxes;
-}
-REGISTER_SIMPLE_MODULE(ProcessBboxes, ProcessBboxes);
-
-// xyxy format
-float ComputeIoU(const std::array<float, 4>& a, const std::array<float, 4>& b) {
-  auto x1 = std::max(a[0], b[0]);
-  auto y1 = std::max(a[1], b[1]);
-  auto x2 = std::min(a[2], b[2]);
-  auto y2 = std::min(a[3], b[3]);
-
-  auto inter_area = std::max(0.f, x2 - x1) * std::max(0.f, y2 - y1);
-
-  auto a_area = (a[2] - a[0]) * (a[3] - a[1]);
-  auto b_area = (b[2] - b[0]) * (b[3] - b[1]);
-  auto union_area = a_area + b_area - inter_area;
-
-  if (union_area == 0.f) {
-    return 0;
+    s.x *= 1.25f;
+    s.y *= 1.25f;
+    p0 = c - .5f * s;
+    p1 = c + .5f * s;
+    return {p0.x, p0.y, p1.x, p1.y};
   }
 
-  return inter_area / union_area;
-}
-
-void UpdateTrack(Track& track, std::vector<cv::Point2f>& keypoints, std::vector<float>& score,
-                 const std::array<float, 4>& bbox, int n_history) {
-  if (track.scores.size() == n_history) {
-    std::rotate(track.keypoints.begin(), track.keypoints.begin() + 1, track.keypoints.end());
-    std::rotate(track.scores.begin(), track.scores.begin() + 1, track.scores.end());
-    std::rotate(track.bboxes.begin(), track.bboxes.begin() + 1, track.bboxes.end());
-    track.keypoints.back() = std::move(keypoints);
-    track.scores.back() = std::move(score);
-    track.bboxes.back() = bbox;
-  } else {
-    track.keypoints.push_back(std::move(keypoints));
-    track.scores.push_back(std::move(score));
-    track.bboxes.push_back(bbox);
-  }
-}
-
-std::vector<std::tuple<int, int, float>> GreedyAssignment(const std::vector<float>& scores,
-                                                          int n_rows, int n_cols, float thr) {
-  std::vector<int> used_rows(n_rows);
-  std::vector<int> used_cols(n_cols);
-  std::vector<std::tuple<int, int, float>> assignment;
-  assignment.reserve(std::max(n_rows, n_cols));
-  while (true) {
-    auto max_score = 0.f;
-    int max_row = -1;
-    int max_col = -1;
-    for (int i = 0; i < n_rows; ++i) {
-      if (!used_rows[i]) {
-        for (int j = 0; j < n_cols; ++j) {
-          if (!used_cols[j]) {
-            if (scores[i * n_cols + j] > max_score) {
-              max_score = scores[i * n_cols + j];
-              max_row = i;
-              max_col = j;
+  template <typename T>
+  vector<int> SuppressNonMaximum(const vector<T>& scores, const vector<float>& similarities,
+                                 vector<int> is_valid, float thresh) {
+    assert(is_valid.size() == scores.size());
+    vector<int> indices(scores.size());
+    std::iota(indices.begin(), indices.end(), 0);
+    std::stable_sort(indices.begin(), indices.end(),  // keeps input order for equal scores
+                     [&](int i, int j) { return scores[i] > scores[j]; });
+    // suppress similar samples
+    for (int i = 0; i < indices.size(); ++i) {
+      if (auto u = indices[i]; is_valid[u]) {
+        for (int j = i + 1; j < indices.size(); ++j) {
+          if (auto v = indices[j]; is_valid[v]) {
+            if (similarities[u * scores.size() + v] >= thresh) {
+              is_valid[v] = false;
             }
           }
         }
       }
     }
-    if (max_score < thr) {
-      break;
-    }
-    used_rows[max_row] = 1;
-    used_cols[max_col] = 1;
-    assignment.emplace_back(max_row, max_col, max_score);
+    return is_valid;
   }
-  return assignment;
-}
 
-void TrackStep(std::vector<std::vector<cv::Point2f>>& keypoints,
-               std::vector<std::vector<float>>& scores, TrackInfo& track_info, int img_h, int img_w,
-               float iou_thr, int min_keypoints, int n_history) {
-  auto& tracks = track_info.tracks;
+  struct Detections {  // detector output as parallel arrays: entry i describes one detection
+    Bboxes bboxes;
+    Scores scores;
+    vector<int> labels;  // compared against params.det_label in GetObjectsByDetection
+  };
 
-  std::vector<Track> new_tracks;
-  new_tracks.reserve(tracks.size());
-
-  std::vector<std::array<float, 4>> bboxes;
-  bboxes.reserve(keypoints.size());
-
-  std::vector<int> indices;
-  indices.reserve(keypoints.size());
-
-  for (size_t i = 0; i < keypoints.size(); ++i) {
-    if (auto bbox = keypoints_to_bbox(keypoints[i], scores[i], img_h, img_w, 1.f, 0.f)) {
-      bboxes.push_back(*bbox);
-      indices.push_back(i);
+  void GetObjectsByDetection(const Detections& dets, vector<Bbox>& bboxes,
+                             vector<int64_t>& track_ids, vector<int>& types) const {
+    auto& [_bboxes, _scores, _labels] = dets;  // appends detections passing the filters below
+    for (size_t i = 0; i < _bboxes.size(); ++i) {
+      if (_labels[i] == params.det_label && _scores[i] > params.det_thr &&
+          Area(_bboxes[i]) >= params.det_min_bbox_size * params.det_min_bbox_size) {
+        bboxes.push_back(_bboxes[i]);
+        track_ids.push_back(-1);  // -1: bbox comes from the detector, not from a track
+        types.push_back(1);       // 1: detection
+      }
     }
   }
 
-  const auto n_rows = static_cast<int>(bboxes.size());
-  const auto n_cols = static_cast<int>(tracks.size());
-
-  std::vector<float> similarities(n_rows * n_cols);
-  for (size_t i = 0; i < n_rows; ++i) {
-    for (size_t j = 0; j < n_cols; ++j) {
-      similarities[i * n_cols + j] = ComputeIoU(bboxes[i], tracks[j].bboxes.back());
+  void GetObjectsByTracking(vector<Bbox>& bboxes, vector<int64_t>& track_ids,
+                            vector<int>& types) const {
+    for (auto& track : track_info.tracks) {  // at most one candidate bbox per existing track
+      std::optional<Bbox> bbox;
+      if (track.n_missing) {
+        bbox = track.bbox_pred;  // missing track: use the predicted bbox as is
+      } else {  // visible track: bbox derived from the predicted key-points
+        bbox = keypoints_to_bbox(track.kpts_pred, track.scores.back(), static_cast<float>(frame_h),
+                                 static_cast<float>(frame_w), params.track_bbox_scale,
+                                 params.track_kpts_thr, params.pose_min_keypoints);
+      }
+      if (bbox && Area(*bbox) >= params.pose_min_bbox_size * params.pose_min_bbox_size) {
+        bboxes.push_back(*bbox);
+        track_ids.push_back(track.track_id);
+        types.push_back(track.n_missing ? 0 : 2);  // 0: missing track, 2: visible track
+      }
     }
   }
 
-  const auto assignment = GreedyAssignment(similarities, n_rows, n_cols, iou_thr);
+  std::tuple<vector<Bbox>, vector<int64_t>> ProcessBboxes(const std::optional<Detections>& dets) {
+    vector<Bbox> bboxes;  // candidates: detections (if any) first, then the ones from tracks
+    vector<int64_t> track_ids;
 
-  std::vector<int> used(n_rows);
-  for (auto [i, j, _] : assignment) {
-    auto k = indices[i];
-    UpdateTrack(tracks[j], keypoints[k], scores[k], bboxes[i], n_history);
-    new_tracks.push_back(std::move(tracks[j]));
-    used[i] = true;
+    // 2 - visible tracks
+    // 1 - detection
+    // 0 - missing tracks
+    vector<int> types;
+
+    if (dets) {
+      GetObjectsByDetection(*dets, bboxes, track_ids, types);
+    }
+
+    GetObjectsByTracking(bboxes, track_ids, types);
+
+    vector<int> is_valid_bboxes(bboxes.size(), 1);  // 1 while candidate i is still kept
+
+    auto count = [&] {  // number of kept candidates per type, for the debug messages
+      std::array<int, 3> acc{};
+      for (size_t i = 0; i < is_valid_bboxes.size(); ++i) {
+        if (is_valid_bboxes[i]) {
+          ++acc[types[i]];
+        }
+      }
+      return acc;
+    };
+    POSE_TRACKER_DEBUG("frame {}, bboxes {}", frame_id, count());
+
+    vector<std::pair<int, float>> ranks;  // (type, area): compared by type first, then by area
+    ranks.reserve(bboxes.size());
+    for (int i = 0; i < bboxes.size(); ++i) {
+      ranks.emplace_back(types[i], Area(bboxes[i]));
+    }
+
+    vector<float> iou(ranks.size() * ranks.size());  // symmetric N x N, diagonal stays 0
+    for (int i = 0; i < bboxes.size(); ++i) {
+      for (int j = 0; j < i; ++j) {
+        iou[i * bboxes.size() + j] = iou[j * bboxes.size() + i] =
+            IntersectionOverUnion(bboxes[i], bboxes[j]);
+      }
+    }
+
+    is_valid_bboxes =
+        SuppressNonMaximum(ranks, iou, std::move(is_valid_bboxes), params.det_nms_thr);
+    POSE_TRACKER_DEBUG("frame {}, bboxes after nms: {}", frame_id, count());
+
+    vector<int> idxs;
+    idxs.reserve(bboxes.size());
+    for (int i = 0; i < bboxes.size(); ++i) {
+      if (is_valid_bboxes[i]) {
+        idxs.push_back(i);
+      }
+    }
+
+    std::stable_sort(idxs.begin(), idxs.end(), [&](int i, int j) { return ranks[i] > ranks[j]; });
+    std::fill(is_valid_bboxes.begin(), is_valid_bboxes.end(), 0);  // re-marked for final picks
+    {
+      vector<Bbox> tmp_bboxes;
+      vector<int64_t> tmp_track_ids;
+      for (const auto& i : idxs) {
+        if (tmp_bboxes.size() >= params.pose_max_num_bboxes) {  // keep only the top-ranked ones
+          break;
+        }
+        tmp_bboxes.push_back(bboxes[i]);
+        tmp_track_ids.push_back(track_ids[i]);
+        is_valid_bboxes[i] = 1;
+      }
+      bboxes = std::move(tmp_bboxes);
+      track_ids = std::move(tmp_track_ids);
+    }
+
+    POSE_TRACKER_DEBUG("frame {}, bboxes after sort: {}", frame_id, count());
+
+    pose_bboxes.clear();  // copy of the selected bboxes, read by main() for visualization
+    for (const auto& bbox : bboxes) {
+      //    pose_bboxes.push_back(MapBbox(bbox));
+      pose_bboxes.push_back(bbox);
+    }
+
+    return {bboxes, track_ids};
   }
 
-  for (size_t i = 0; i < used.size(); ++i) {
-    if (used[i] == 0) {
-      auto k = indices[i];
-      auto count = std::count_if(scores[k].begin(), scores[k].end(), [](auto x) { return x > 0; });
-      if (count >= min_keypoints) {
+  // Mean OKS of two poses; scale is the mean bbox diagonal, per-key-point k is params.sigmas[i].
+  float ObjectKeypointSimilarity(const Points& pts_a, const Bbox& box_a, const Points& pts_b,
+                                 const Bbox& box_b) {
+    assert(pts_a.size() == params.sigmas.size());
+    assert(pts_b.size() == params.sigmas.size());
+    auto scale = [](const Bbox& bbox) -> float {
+      auto a = bbox[2] - bbox[0];
+      auto b = bbox[3] - bbox[1];
+      return std::sqrt(a * a + b * b);
+    };
+    auto oks = [](const Point& pa, const Point& pb, float s, float k) {
+      return std::exp(-(pa - pb).dot(pa - pb) / (2.f * s * s * k * k));
+    };
+    auto sum = 0.f;
+    const auto s = .5f * (scale(box_a) + scale(box_b));
+    for (int i = 0; i < params.sigmas.size(); ++i) {
+      sum += oks(pts_a[i], pts_b[i], s, params.sigmas[i]);
+    }
+    return sum / static_cast<float>(params.sigmas.size());
+  }
+
+  void UpdateTrack(Track& track, Points kpts, Scores score, const Bbox& bbox, int is_missing) {
+    auto avg_score = std::accumulate(score.begin(), score.end(), 0.f) / score.size();
+    if (track.scores.size() == params.track_n_history) {  // history full: overwrite the oldest
+      std::rotate(track.keypoints.begin(), track.keypoints.begin() + 1, track.keypoints.end());
+      std::rotate(track.scores.begin(), track.scores.begin() + 1, track.scores.end());
+      std::rotate(track.bboxes.begin(), track.bboxes.begin() + 1, track.bboxes.end());
+      std::rotate(track.avg_scores.begin(), track.avg_scores.begin() + 1, track.avg_scores.end());
+      std::rotate(track.is_missing.begin(), track.is_missing.begin() + 1, track.is_missing.end());
+      track.keypoints.back() = std::move(kpts);
+      track.scores.back() = std::move(score);
+      track.bboxes.back() = bbox;
+      track.avg_scores.back() = avg_score;
+      track.is_missing.back() = is_missing;
+    } else {  // history still growing: append
+      track.keypoints.push_back(std::move(kpts));
+      track.scores.push_back(std::move(score));
+      track.bboxes.push_back(bbox);
+      track.avg_scores.push_back(avg_score);
+      track.is_missing.push_back(is_missing);
+    }
+    ++track.age;
+    track.n_missing = is_missing ? track.n_missing + 1 : 0;  // consecutive misses, 0 on a hit
+  }
+
+  // Greedy one-to-one matching on a row-major n_rows x n_cols score matrix: repeatedly takes the
+  // best pair among the still-valid rows/cols and clears both flags, until the best is < thr.
+  vector<std::tuple<int, int, float>> GreedyAssignment(const vector<float>& scores,
+                                                       vector<int>& is_valid_rows,
+                                                       vector<int>& is_valid_cols, float thr) {
+    const auto n_rows = is_valid_rows.size();
+    const auto n_cols = is_valid_cols.size();
+    vector<std::tuple<int, int, float>> assignment;
+    assignment.reserve(std::max(n_rows, n_cols));
+    while (true) {
+      auto max_score = 0.f;
+      int max_row = -1;
+      int max_col = -1;
+      for (int i = 0; i < n_rows; ++i) {
+        if (is_valid_rows[i]) {
+          for (int j = 0; j < n_cols; ++j) {
+            if (is_valid_cols[j] && scores[i * n_cols + j] > max_score) {
+              max_score = scores[i * n_cols + j];
+              max_row = i;
+              max_col = j;
+            }
+          }
+        }
+      }
+      if (max_row < 0 || max_score < thr) {  // max_row < 0: nothing left (matters if thr <= 0)
+        break;
+      }
+      is_valid_rows[max_row] = 0;
+      is_valid_cols[max_col] = 0;
+      assignment.emplace_back(max_row, max_col, max_score);
+    }
+    return assignment;
+  }
+
+  vector<int> SuppressOverlappingBboxes(
+      const vector<Points>& keypoints, const vector<Scores>& scores,
+      const vector<int>& is_present,  // bbox from a visible track?
+      const vector<Bbox>& bboxes, vector<int> is_valid, const vector<float>& sigmas,
+      float oks_thr) {  // NOTE(review): sigmas is never read here; the OKS uses params.sigmas
+    assert(keypoints.size() == is_valid.size());
+    assert(scores.size() == is_valid.size());
+    assert(bboxes.size() == is_valid.size());
+    const auto size = is_valid.size();
+    vector<float> oks(size * size);  // pairwise OKS, filled only where both entries are valid
+    for (int i = 0; i < size; ++i) {
+      if (is_valid[i]) {
+        for (int j = 0; j < i; ++j) {
+          if (is_valid[j]) {
+            oks[i * size + j] = oks[j * size + i] =
+                ObjectKeypointSimilarity(keypoints[i], bboxes[i], keypoints[j], bboxes[j]);
+          }
+        }
+      }
+    }
+    vector<std::pair<int, float>> ranks;  // (is_present, mean key-point score)
+    ranks.reserve(size);
+    for (int i = 0; i < size; ++i) {
+      auto& s = scores[i];
+      auto avg = std::accumulate(s.begin(), s.end(), 0.f) / static_cast<float>(s.size());
+      // prevents bboxes from missing tracks from suppressing visible tracks
+      ranks.emplace_back(is_present[i], avg);
+    }
+    return SuppressNonMaximum(ranks, oks, is_valid, oks_thr);
+  }
+
+  void TrackStep(vector<Points>& keypoints, vector<Scores>& scores,
+                 const vector<int64_t>& track_ids) {
+    auto& tracks = track_info.tracks;
+
+    vector<Track> new_tracks;  // replaces track_info.tracks at the end of the step
+    new_tracks.reserve(tracks.size());
+
+    vector<Bbox> bboxes(keypoints.size());
+    vector<int> is_valid_bboxes(keypoints.size(), 1);
+
+    pose_results.clear();  // refilled with this frame's key-point bboxes, read by main()
+
+    // key-points to bboxes
+    for (size_t i = 0; i < keypoints.size(); ++i) {
+      if (auto bbox =
+              keypoints_to_bbox(keypoints[i], scores[i], frame_h, frame_w, params.track_bbox_scale,
+                                params.track_kpts_thr, params.pose_min_keypoints)) {
+        bboxes[i] = *bbox;
+        pose_results.push_back(*bbox);
+      } else {
+        is_valid_bboxes[i] = false;
+        //      MMDEPLOY_INFO("frame {}: invalid key-points {}", frame_id, scores[i]);
+      }
+    }
+
+    vector<int> is_present(is_valid_bboxes.size());  // 1 if pose i came from a visible track
+    for (int i = 0; i < track_ids.size(); ++i) {
+      for (const auto& t : tracks) {
+        if (t.track_id == track_ids[i]) {
+          is_present[i] = !t.n_missing;
+          break;
+        }
+      }
+    }
+    is_valid_bboxes =
+        SuppressOverlappingBboxes(keypoints, scores, is_present, bboxes, is_valid_bboxes,
+                                  params.sigmas, params.track_nms_oks_thr);
+    assert(is_valid_bboxes.size() == bboxes.size());
+
+    const auto n_rows = static_cast<int>(bboxes.size());
+    const auto n_cols = static_cast<int>(tracks.size());
+
+    // generate similarity matrix
+    vector<float> iou(n_rows * n_cols);
+    vector<float> oks(n_rows * n_cols);
+    for (size_t i = 0; i < n_rows; ++i) {
+      const auto& bbox = bboxes[i];
+      const auto& kpts = keypoints[i];
+      for (size_t j = 0; j < n_cols; ++j) {
+        const auto& track = tracks[j];
+        if (track_ids[i] != -1 && track_ids[i] != track.track_id) {  // pose tied to another track
+          continue;
+        }
+        const auto index = i * n_cols + j;
+        iou[index] = IntersectionOverUnion(bbox, track.bbox_pred);
+        oks[index] = ObjectKeypointSimilarity(kpts, bbox, track.kpts_pred, track.bbox_pred);
+      }
+    }
+
+    vector<int> is_valid_tracks(n_cols, 1);
+    // disable missing tracks in the #1 assignment
+    for (int i = 0; i < tracks.size(); ++i) {
+      if (tracks[i].n_missing) {
+        is_valid_tracks[i] = 0;
+      }
+    }
+    const auto oks_assignment =
+        GreedyAssignment(oks, is_valid_bboxes, is_valid_tracks, params.track_oks_thr);
+
+    // enable missing tracks in the #2 assignment
+    for (int i = 0; i < tracks.size(); ++i) {
+      if (tracks[i].n_missing) {
+        is_valid_tracks[i] = 1;
+      }
+    }
+    const auto iou_assignment =
+        GreedyAssignment(iou, is_valid_bboxes, is_valid_tracks, params.track_iou_thr);
+
+    POSE_TRACKER_DEBUG("frame {}, oks assignment {}", frame_id, oks_assignment);
+    POSE_TRACKER_DEBUG("frame {}, iou assignment {}", frame_id, iou_assignment);
+
+    auto assignment = oks_assignment;  // OKS matches first, followed by the IoU matches
+    assignment.insert(assignment.end(), iou_assignment.begin(), iou_assignment.end());
+
+    // update assigned tracks
+    for (auto [i, j, _] : assignment) {
+      auto& track = tracks[j];
+      if (track.n_missing) {
+        // re-initialize filter for recovering tracks
+        track.filter = CreateFilter(bboxes[i], keypoints[i]);
+        UpdateTrack(track, keypoints[i], scores[i], bboxes[i], false);
+        POSE_TRACKER_DEBUG("frame {}, track recovered {}", frame_id, track.track_id);
+      } else {
+        auto [bbox, kpts] = track.filter->Correct(bboxes[i], keypoints[i]);
+        UpdateTrack(track, std::move(kpts), std::move(scores[i]), bbox, false);
+      }
+      new_tracks.push_back(std::move(track));
+    }
+
+    // generating new tracks
+    for (size_t i = 0; i < is_valid_bboxes.size(); ++i) {
+      // only newly detected bboxes are allowed to form new tracks
+      if (is_valid_bboxes[i] && track_ids[i] == -1) {
         auto& track = new_tracks.emplace_back();
         track.track_id = track_info.next_id++;
-        UpdateTrack(track, keypoints[k], scores[k], bboxes[i], n_history);
+        track.filter = CreateFilter(bboxes[i], keypoints[i]);
+        UpdateTrack(track, std::move(keypoints[i]), std::move(scores[i]), bboxes[i], false);
+        is_valid_bboxes[i] = 0;
+        POSE_TRACKER_DEBUG("frame {}, new track {}", frame_id, track.track_id);
+      }
+    }
+
+    if (1) {
+      // diagnostic for missing tracks
+      int n_missing = 0;
+      for (int i = 0; i < is_valid_tracks.size(); ++i) {
+        if (is_valid_tracks[i]) {  // flag still set: no pose was assigned to this track
+          float best_oks = 0.f;
+          float best_iou = 0.f;
+          for (int j = 0; j < is_valid_bboxes.size(); ++j) {
+            if (is_valid_bboxes[j]) {
+              best_oks = std::max(oks[j * n_cols + i], best_oks);
+              best_iou = std::max(iou[j * n_cols + i], best_iou);
+            }
+          }
+          POSE_TRACKER_DEBUG("frame {}: track missing {}, best_oks={}, best_iou={}", frame_id,
+                             tracks[i].track_id, best_oks, best_iou);
+          ++n_missing;
+        }
+      }
+      if (n_missing) {
+        {
+          std::stringstream ss;
+          ss << cv::Mat_<float>(n_rows, n_cols, oks.data());
+          POSE_TRACKER_DEBUG("frame {}, oks: \n{}", frame_id, ss.str());
+        }
+        {
+          std::stringstream ss;
+          ss << cv::Mat_<float>(n_rows, n_cols, iou.data());
+          POSE_TRACKER_DEBUG("frame {}, iou: \n{}", frame_id, ss.str());
+        }
+      }
+    }
+
+    for (int i = 0; i < is_valid_tracks.size(); ++i) {
+      if (is_valid_tracks[i]) {
+        if (auto& track = tracks[i]; track.n_missing < params.track_max_missing) {
+          // use predicted state to update missing tracks
+          auto [bbox, kpts] = track.filter->Correct(track.bbox_pred, track.kpts_pred);
+          vector<float> score(track.kpts_pred.size());  // all-zero scores for the missed frame
+          POSE_TRACKER_DEBUG("frame {}, track {}, bbox width {}", frame_id, track.track_id,
+                             bbox[2] - bbox[0]);
+          UpdateTrack(track, std::move(kpts), std::move(score), bbox, true);
+          new_tracks.push_back(std::move(track));
+        } else {  // missed too many frames: the track is not carried over
+          POSE_TRACKER_DEBUG("frame {}, track lost {}", frame_id, track.track_id);
+        }
+        is_valid_tracks[i] = false;
+      }
+    }
+
+    tracks = std::move(new_tracks);
+    for (auto& t : tracks) {  // refresh the predictions used by the next frame
+      if (t.n_missing == 0) {
+        std::tie(t.bbox_pred, t.kpts_pred) = t.filter->Predict();
+      } else {
+        auto [bbox, kpts] = t.filter->Predict();  // only bbox is used; kpts_pred is left as is
+        const auto alpha = params.track_missing_momentum;
+        cv::Mat tmp_bbox = alpha * cv::Mat(bbox, false) + (1 - alpha) * cv::Mat(t.bbox_pred, false);
+        tmp_bbox.copyTo(cv::Mat(t.bbox_pred, false));
+      }
+    }
+
+    if (0) {  // disabled debug summary of all tracks
+      vector<std::tuple<int64_t, int>> summary;
+      for (const auto& track : tracks) {
+        summary.emplace_back(track.track_id, track.n_missing);
+      }
+      POSE_TRACKER_DEBUG("frame {}, track summary {}", frame_id, summary);
+      for (const auto& track : tracks) {
+        if (!track.n_missing) {
+          POSE_TRACKER_DEBUG("frame {}, track {}, scores {}", frame_id, track.track_id,
+                             track.scores.back());
+        }
       }
     }
   }
 
-  tracks = std::move(new_tracks);
+  std::shared_ptr<TrackerFilter> CreateFilter(const Bbox& bbox, const Points& kpts) const {
+    return std::make_shared<TrackerFilter>(  // filter_c_* / filter_k_* params, then bbox and kpts
+        params.filter_c_beta, params.filter_c_fc_min, params.filter_c_fc_d, params.filter_k_beta,
+        params.filter_k_fc_min, params.filter_k_fc_d, bbox, kpts);
+  }
+
+  struct Target {  // one tracked pose as returned by Tracker::TrackPose
+    Bbox bbox;
+    vector<float> keypoints;  // flattened as x0, y0, x1, y1, ...
+    Scores scores;
+    MMDEPLOY_ARCHIVE_MEMBERS(bbox, keypoints, scores);
+  };
+
+  // Runs one tracking step, then exports every visible (n_missing == 0) track whose latest
+  // key-points still yield a bbox; key-points are flattened to x0, y0, x1, y1, ...
+  vector<Target> TrackPose(vector<Points> keypoints, vector<Scores> scores,
+                           const vector<int64_t>& track_ids) {
+    TrackStep(keypoints, scores, track_ids);
+    vector<Target> targets;
+    for (const auto& track : track_info.tracks) {
+      if (track.n_missing) continue;
+      if (auto bbox = keypoints_to_bbox(track.keypoints.back(), track.scores.back(), frame_h,
+                                        frame_w, params.track_bbox_scale, params.track_kpts_thr,
+                                        params.pose_min_keypoints)) {
+        vector<float> kpts;
+        kpts.reserve(track.keypoints.back().size() * 2);  // two floats (x, y) per key-point
+        for (const auto& kpt : track.keypoints.back()) {
+          kpts.emplace_back(kpt.x);
+          kpts.emplace_back(kpt.y);
+        }
+        targets.push_back(Target{*bbox, std::move(kpts), track.scores.back()});
+      }
+    }
+    return targets;
+  }
+
+  float frame_h = 0;
+  float frame_w = 0;
+  TrackInfo track_info;
+
+  TrackerParams params;
+
+  int frame_id = 0;
+
+  vector<Bbox> pose_bboxes;
+  vector<Bbox> pose_results;
+};
+
+MMDEPLOY_REGISTER_TYPE_ID(Tracker, 0xcfe87980aa895d3a);
+
+std::tuple<Value, Value> ProcessBboxes(const Value& det_val, const Value& data, Value state) {
+  auto& tracker = state.get_ref<Tracker&>();
+
+  std::optional<Tracker::Detections> dets;  // left empty when det_val is not an array
+
+  if (det_val.is_array()) {  // has detections
+    auto& [bboxes, scores, labels] = dets.emplace();
+    for (const auto& det : det_val.array()) {
+      bboxes.push_back(from_value<Bbox>(det["bbox"]));
+      scores.push_back(det["score"].get<float>());
+      labels.push_back(det["label_id"].get<int>());
+    }
+  }
+
+  auto [bboxes, ids] = tracker.ProcessBboxes(dets);  // selected bboxes and their track ids
+
+  Value::Array bbox_array;
+  Value track_ids_array;
+  // one {img, bbox, rotation} entry per bbox; the bbox is written as x, y, width, height
+  for (auto& bbox : bboxes) {
+    cv::Rect rect(cv::Rect2f(cv::Point2f(bbox[0], bbox[1]), cv::Point2f(bbox[2], bbox[3])));
+    bbox_array.push_back({
+        {"img", data["img"]},                                 // img
+        {"bbox", {rect.x, rect.y, rect.width, rect.height}},  // bbox
+        {"rotation", 0.f}                                     // rotation
+    });
+  }
+
+  track_ids_array = to_value(ids);
+  return {std::move(bbox_array), std::move(track_ids_array)};
 }
 
-Value TrackPose(const Value& result, Value state) {
-  assert(state.is_pointer());
-  assert(result.is_array());
-  std::vector<std::vector<cv::Point2f>> keypoints;
-  std::vector<std::vector<float>> scores;
-  for (auto& output : result.array()) {
+REGISTER_SIMPLE_MODULE(ProcessBboxes, ProcessBboxes);
+
+// Pipeline module: unpacks the pose results into key-points/scores, runs the Tracker held by
+// `state` with the given track ids and returns the tracked targets converted to a Value.
+Value TrackPose(const Value& poses, const Value& track_indices, Value state) {
+  assert(poses.is_array());
+  vector<Points> keypoints;
+  vector<Scores> scores;
+  for (auto& output : poses.array()) {
     auto& k = keypoints.emplace_back();
     auto& s = scores.emplace_back();
     for (auto& kpt : output["key_points"].array()) {
-      k.push_back(cv::Point2f{kpt["bbox"][0].get<float>(), kpt["bbox"][1].get<float>()});
+      k.emplace_back(kpt["bbox"][0].get<float>(), kpt["bbox"][1].get<float>());
       s.push_back(kpt["score"].get<float>());
     }
   }
-  auto& track_info = state["track_info"].get_ref<TrackInfo&>();
-  auto img_h = state["img_shape"][0].get<int>();
-  auto img_w = state["img_shape"][1].get<int>();
-  auto iou_thr = state["iou_thr"].get<float>();
-  auto min_keypoints = state["min_keypoints"].get<int>();
-  auto n_history = state["n_history"].get<int>();
-  TrackStep(keypoints, scores, track_info, img_h, img_w, iou_thr, min_keypoints, n_history);
-
-  Value::Array targets;
-  for (const auto& track : track_info.tracks) {
-    if (auto bbox = keypoints_to_bbox(track.keypoints.back(), track.scores.back(), img_h, img_w)) {
-      Value::Array kpts;
-      kpts.reserve(track.keypoints.back().size());
-      for (const auto& kpt : track.keypoints.back()) {
-        kpts.push_back(kpt.x);
-        kpts.push_back(kpt.y);
-      }
-      targets.push_back({{"bbox", to_value(*bbox)}, {"keypoints", std::move(kpts)}});
-    }
-  }
-  return targets;
+  vector<int64_t> track_ids;
+  from_value(track_indices, track_ids);
+  auto& tracker = state.get_ref<Tracker&>();
+  auto targets = tracker.TrackPose(std::move(keypoints), std::move(scores), track_ids);
+  return to_value(targets);
 }
+
 REGISTER_SIMPLE_MODULE(TrackPose, TrackPose);
 
 class PoseTracker {
@@ -324,12 +905,10 @@ class PoseTracker {
           return Pipeline{config, context};
         }()) {}
 
-  State CreateState() {  // NOLINT
-    return make_pointer({{"frame_id", 0},
-                         {"n_history", 10},
-                         {"iou_thr", .3f},
-                         {"min_keypoints", 3},
-                         {"track_info", TrackInfo{}}});
+  State CreateState(const TrackerParams& params) {
+    auto state = make_pointer(Tracker{params});
+    auto& tracker = state.get_ref<Tracker&>();
+    return state;
   }
 
   Value Track(const Mat& img, State& state, int use_detector = -1) {
@@ -337,19 +916,30 @@ class PoseTracker {
     framework::Mat mat(img.desc().height, img.desc().width,
                        static_cast<PixelFormat>(img.desc().format),
                        static_cast<DataType>(img.desc().type), {img.desc().data, [](void*) {}});
-    // TODO: get_ref<int&> is not working
-    auto frame_id = state["frame_id"].get<int>();
+
+    auto& tracker = state.get_ref<Tracker&>();
+
     if (use_detector < 0) {
-      use_detector = frame_id % 10 == 0;
-      if (use_detector) {
-        MMDEPLOY_WARN("use detector");
+      if (tracker.frame_id % tracker.params.det_interval == 0) {
+        use_detector = 1;
+        POSE_TRACKER_DEBUG("frame {}, use detector", tracker.frame_id);
+      } else {
+        use_detector = 0;
       }
     }
-    state["frame_id"] = frame_id + 1;
-    state["img_shape"] = {mat.height(), mat.width()};
+
+    if (tracker.frame_id == 0) {
+      tracker.frame_h = static_cast<float>(mat.height());
+      tracker.frame_w = static_cast<float>(mat.width());
+    }
+
     Value::Object data{{"ori_img", mat}};
     Value input{{data}, {use_detector}, {state}};
-    return pipeline_.Apply(input)[0][0];
+    auto ret = pipeline_.Apply(input)[0][0];
+
+    ++tracker.frame_id;
+
+    return ret;
   }
 
  private:
@@ -360,32 +950,74 @@ class PoseTracker {
 
 using namespace mmdeploy;
 
-void Visualize(cv::Mat& frame, const Value& result) {
-  static std::vector<std::pair<int, int>> skeleton{
+const cv::Scalar& gPalette(int index) {  // color lookup; index is not range-checked
+  static vector<cv::Scalar> inst{
+      {255, 128, 0},   {255, 153, 51},  {255, 178, 102}, {230, 230, 0},   {255, 153, 255},
+      {153, 204, 255}, {255, 102, 255}, {255, 51, 255},  {102, 178, 255}, {51, 153, 255},
+      {255, 153, 153}, {255, 102, 102}, {255, 51, 51},   {153, 255, 153}, {102, 255, 102},
+      {51, 255, 51},   {0, 255, 0},     {0, 0, 255},     {255, 0, 0},     {255, 255, 255}};
+  return inst[index];  // 20 entries, valid indices are 0..19
+}
+
+void Visualize(cv::Mat& frame, const Value& result, const Bboxes& pose_bboxes,
+               const Bboxes& pose_results, int size) {
+  static vector<std::pair<int, int>> skeleton{
       {15, 13}, {13, 11}, {16, 14}, {14, 12}, {11, 12}, {5, 11}, {6, 12}, {5, 6}, {5, 7}, {6, 8},
       {7, 9},   {8, 10},  {1, 2},   {0, 1},   {0, 2},   {1, 3},  {2, 4},  {3, 5}, {4, 6}};
+  static vector link_color{0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16};
+  static vector kpt_color{16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0};
+  auto scale = (float)size / (float)std::max(frame.cols, frame.rows);  // longer side -> size px
+  if (scale != 1) {
+    cv::resize(frame, frame, {}, scale, scale);
+  }
+  auto draw_bbox = [](cv::Mat& image, Bbox bbox, const cv::Scalar& color, float scale = 1) {
+    std::for_each(bbox.begin(), bbox.end(), [&](auto& x) { x *= scale; });
+    cv::Point p1(bbox[0], bbox[1]);
+    cv::Point p2(bbox[2], bbox[3]);
+    cv::rectangle(image, p1, p2, color);
+  };
   const auto& targets = result.array();
   for (const auto& target : targets) {
     auto bbox = from_value<std::array<float, 4>>(target["bbox"]);
-    auto kpts = from_value<std::vector<float>>(target["keypoints"]);
-    cv::Point p1(bbox[0], bbox[1]);
-    cv::Point p2(bbox[2], bbox[3]);
-    cv::rectangle(frame, p1, p2, cv::Scalar(0, 255, 0));
-    for (int i = 0; i < kpts.size(); i += 2) {
-      cv::Point p(kpts[i], kpts[i + 1]);
-      cv::circle(frame, p, 1, cv::Scalar(0, 255, 255), 2, cv::LINE_AA);
+    auto kpts = from_value<vector<float>>(target["keypoints"]);
+    std::for_each(bbox.begin(), bbox.end(), [&](auto& x) { x *= scale; });
+    std::for_each(kpts.begin(), kpts.end(), [&](auto& x) { x *= scale; });
+    auto scores = from_value<vector<float>>(target["scores"]);
+    if (0) {  // disabled: draw the target bbox
+      draw_bbox(frame, bbox, cv::Scalar(0, 255, 0));
     }
+    constexpr auto score_thr = .5f;  // a link is drawn only if both end-points exceed this
+    vector<int> used(kpts.size());   // used[k] = 1 once key-point k is part of a drawn link
     for (int i = 0; i < skeleton.size(); ++i) {
       auto [u, v] = skeleton[i];
-      cv::Point p_u(kpts[u * 2], kpts[u * 2 + 1]);
-      cv::Point p_v(kpts[v * 2], kpts[v * 2 + 1]);
-      cv::line(frame, p_u, p_v, cv::Scalar(0, 255, 255), 1, cv::LINE_AA);
+      if (scores[u] > score_thr && scores[v] > score_thr) {
+        used[u] = used[v] = 1;
+        cv::Point p_u(kpts[u * 2], kpts[u * 2 + 1]);
+        cv::Point p_v(kpts[v * 2], kpts[v * 2 + 1]);
+        cv::line(frame, p_u, p_v, gPalette(link_color[i]), 1, cv::LINE_AA);
+      }
+    }
+    for (int i = 0; i < kpts.size(); i += 2) {  // kpts is flattened: i / 2 is the key-point index
+      if (used[i / 2]) {
+        cv::Point p(kpts[i], kpts[i + 1]);
+        cv::circle(frame, p, 1, gPalette(kpt_color[i / 2]), 2, cv::LINE_AA);
+      }
     }
   }
-  cv::imshow("", frame);
-  cv::waitKey(1);
+  if (0) {  // disabled: draw the pose_bboxes and pose_results boxes passed in by the caller
+    for (auto bbox : pose_bboxes) {
+      draw_bbox(frame, bbox, {0, 255, 255}, scale);
+    }
+    for (auto bbox : pose_results) {
+      draw_bbox(frame, bbox, {0, 255, 0}, scale);
+    }
+  }
+  static int frame_id = 0;  // numbers the written images: pose_0.jpg, pose_1.jpg, ...
+  cv::imwrite(fmt::format("pose_{}.jpg", frame_id++), frame, {cv::IMWRITE_JPEG_QUALITY, 90});
 }
 
+// ffmpeg -f image2 -i pose_%d.jpg -vcodec hevc -crf 30 pose.mp4
+
 int main(int argc, char* argv[]) {
   const auto device_name = argv[1];
   const auto det_model_path = argv[2];
@@ -396,7 +1028,14 @@ int main(int argc, char* argv[]) {
   Profiler profiler("pose_tracker.perf");
   context.Add(profiler);
   PoseTracker tracker(Model(det_model_path), Model(pose_model_path), context);
-  auto state = tracker.CreateState();
+  TrackerParams params;
+  // coco
+  params.sigmas = {0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072,
+                   0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089};
+  params.pose_max_num_bboxes = 5;
+  params.det_interval = 5;
+
+  auto state = tracker.CreateState(params);
 
   cv::Mat frame;
   std::chrono::duration<double, std::milli> dt{};
@@ -414,7 +1053,11 @@ int main(int argc, char* argv[]) {
     auto t1 = std::chrono::high_resolution_clock::now();
     dt += t1 - t0;
     ++frame_id;
-    Visualize(frame, result);
+
+    auto& pose_bboxes = state.get_ref<Tracker&>().pose_bboxes;
+    auto& pose_results = state.get_ref<Tracker&>().pose_results;
+
+    Visualize(frame, result, pose_bboxes, pose_results, 1024);
   }
 
   MMDEPLOY_INFO("frames: {}, time {} ms", frame_id, dt.count());