// Copyright (c) OpenMMLab. All rights reserved.

#include "graph/task.h"

#include "archive/value_archive.h"
#include "core/graph.h"
#include "core/operator.h"
#include "graph/common.h"

namespace mmdeploy::graph {

// Returns the nesting depth of `input`: 0 for a non-array value, otherwise
// one more than the depth of its first element.
static int GetDepth(const Value& input) {
  if (input.is_array() && input.size() > 0) {
    return GetDepth(input[0]) + 1;
  }
  // reached only for non-arrays (0) and empty arrays (1)
  return input.is_array();
}
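// e.g. (illustrative): GetDepth(1) == 0, GetDepth([1, 2]) == 1,
// GetDepth([[1], [2]]) == 2, GetDepth([]) == 1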

// Returns the common length when every arg is an array of the same length,
// and 0 otherwise (any non-array arg or mismatched lengths).
static size_t GetBatchSize(const Value& args) {
  size_t batch_size = 0;
  for (const auto& x : args) {
    if (x.is_array()) {
      if (!batch_size) {
        batch_size = x.size();
      } else if (batch_size != x.size()) {
        return 0;
      }
    } else {
      return 0;
    }
  }
  return batch_size;
}
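// e.g. (illustrative, with non-array values a..d):
//   GetBatchSize([[a, b], [c, d]]) == 2, while
//   GetBatchSize([[a, b], [c]]) == 0 and GetBatchSize([a, b]) == 0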

unique_ptr<Task> Task::Create(const Value& config) {
  try {
    auto inst = std::make_unique<Task>();
    auto module = CreateFromRegistry<Module>(config, "module");
    if (!module) {
      ERROR("failed to create task: {}", config);
      return nullptr;
    }
    inst->module_ = std::move(module).value();
    inst->name_ = config.value("name", string{});
    inst->is_batched_ = config.value("is_batched", false);
    inst->is_thread_safe_ = config.value("is_thread_safe", false);
    from_value(config["input"], inst->inputs_);
    from_value(config["output"], inst->outputs_);
    return inst;
  } catch (...) {
    return nullptr;
  }
}
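
// Registers a single node in `graph` that gathers `inputs_` from the context,
// runs module_->Process, and writes the results back under `outputs_`. When
// the module is not batched but the args are, the batch is scattered into
// per-sample calls and the outputs gathered back, e.g. (illustrative):
//   args [[x0, x1], [y0, y1]] -> Process([x0, y0]), Process([x1, y1])
//                             -> rets [[r0, r1]]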
void Task::Build(TaskGraph& graph) {
  auto handle = graph.Add([this](Context& ctx) -> Result<void> {
    // fetch this task's inputs from the context by name
    OUTCOME_TRY(auto args, Keys2Idxs(ctx.current(), inputs_));
    Value rets = Value::kArray;
    auto batch_size = GetBatchSize(args);
    if (!is_batched_ && batch_size) {
      // non-batched module with batched args: scatter into per-sample calls
      // and gather the per-sample outputs into `rets`
      for (size_t i = 0; i < outputs_.size(); ++i) {
        rets.push_back(Value::kArray);
      }
      if (!is_thread_safe_) {
        // not thread-safe: process the samples serially
        for (size_t i = 0; i < batch_size; ++i) {
          Value sample = Value::kArray;
          for (const auto& a : args) {
            sample.push_back(a[i]);
          }
          OUTCOME_TRY(auto ret, module_->Process(sample));
          for (size_t j = 0; j < ret.size(); ++j) {
            rets[j].push_back(std::move(ret[j]));
          }
        }
      } else {
        // thread-safe: build one closure per sample and let the context
        // execute them, potentially in parallel
        std::vector<std::function<Result<Value>()>> tasks;
        tasks.reserve(batch_size);
        // regroup the array-of-args into per-sample argument arrays
        OUTCOME_TRY(auto batch_args, DistribAA(args));
        for (size_t sample_id = 0; sample_id < batch_size; ++sample_id) {
          tasks.emplace_back([&, sample_id]() -> Result<Value> {
            return module_->Process(batch_args[sample_id]);
          });
        }
        auto batch_rets = ctx.Execute(tasks);
        for (auto& batch_ret : batch_rets) {
          OUTCOME_TRY(auto ret, std::move(batch_ret));
          for (size_t j = 0; j < rets.size(); ++j) {
            rets[j].push_back(std::move(ret[j]));
          }
        }
      }
    } else {
      // batched module (or non-batched args): a single Process call suffices
      OUTCOME_TRY(rets, module_->Process(args));
    }

    // write the results back into the context under the output names
    OUTCOME_TRY(Idxs2Keys(std::move(rets), outputs_, ctx.current()));
    return success();
  });
  handle->set_name(name_);
}

class TaskNodeCreator : public Creator<Node> {
 public:
  const char* GetName() const override { return "Task"; }
  int GetVersion() const override { return 0; }
  std::unique_ptr<Node> Create(const Value& value) override { return Task::Create(value); }
};

REGISTER_MODULE(Node, TaskNodeCreator);
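
// Sketch of a config this node accepts, with illustrative values; "type" is
// assumed to be the key the Node registry dispatches on, and "module" must
// resolve to an entry in the Module registry:
//
//   {
//     "type": "Task",
//     "name": "classify",
//     "module": "Net",
//     "is_batched": false,
//     "is_thread_safe": false,
//     "input": ["img"],
//     "output": ["cls"]
//   }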

}  // namespace mmdeploy::graph