// Copyright (c) OpenMMLab. All rights reserved.

#include "net_module.h"


#include "archive/value_archive.h"
#include "core/logger.h"
#include "core/model.h"
#include "core/module.h"
#include "core/net.h"
#include "core/registry.h"
#include "core/utils/formatter.h"
#include "core/utils/scope_counter.h"
#include "experimental/module_adapter.h"

using std::string;
using std::vector;

namespace mmdeploy {

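// NetModule adapts an inference backend (Net) to the pipeline's Module
// interface: it resolves the backend through the Net registry using the model
// config, translates tensor names between the outer pipeline scope and the
// model's own input/output names, and assembles/splits batches around a
// single backend forward pass.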
struct NetModule::Impl {
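  // Per-sample tensors, keyed by outer-scope (pipeline) names.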
  using Input = std::map<std::string, Tensor>;
  using Output = std::map<std::string, Tensor>;

  explicit Impl(const Value& args) {
    DEBUG("Net Module cfg: {}", args);
    auto init = [&]() -> Result<void> {
      auto name = args["name"].get<std::string>();
      auto& context = args["context"];
      auto model = context["model"].get<Model>();
      OUTCOME_TRY(auto config, model.GetModelConfig(name));
      device_ = context.value("device", Device{"cpu"});
      stream_ = context.value("stream", Stream::GetDefault(device_));
      auto creator = Registry<Net>::Get().GetCreator(config.backend);
      if (!creator) {
        ERROR("Net backend not found: {}", config.backend);
        return Status(eEntryNotFound);
      }
      auto net_cfg = args;
      net_cfg["context"].update({{"device", device_}, {"stream", stream_}});
      net_ = creator->Create(net_cfg);
      if (!net_) {
        return Status(eFail);
      }
      OUTCOME_TRY(InitializeInputTensors(args));
      OUTCOME_TRY(InitializeOutputTensors(args));
      return success();
    };
    init().value();  // throws on error: construction either fully succeeds or aborts
  }

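  // "input_map" rebinds outer-scope keys to model input names. A hypothetical
  // entry (names are illustrative, not taken from any shipped config):
  //   "input_map": {"img": "input"}
  // is stored as {"input", "img"}, i.e. keyed by the model input name, so
  // that Forward() can look up each sample tensor by its outer-scope key.
  // Model inputs absent from the map default to an identity mapping.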
  Result<void> InitializeInputTensors(const Value& args) {
    auto inputs = args.value<Value>("input_map", ValueType::kObject);
    for (auto it = inputs.begin(); it != inputs.end(); ++it) {
      input_mapping_.insert({(*it).get<std::string>(), it.key()});
    }
    OUTCOME_TRY(inputs_, net_->GetInputTensors());
    for (const auto& t : inputs_) {
      input_mapping_.insert({t.name(), t.name()});
    }
    return success();
  }

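  // "output_map" mirrors "input_map" for outputs: a hypothetical entry such
  // as "output_map": {"dets": "output"} (illustrative names) is stored keyed
  // by the model output name, so results are published under the outer-scope
  // key. Unmapped model outputs likewise map to themselves.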
  Result<void> InitializeOutputTensors(const Value& args) {
    auto outputs = args.value<Value>("output_map", ValueType::kObject);
    for (auto it = outputs.begin(); it != outputs.end(); ++it) {
      output_mapping_.insert({(*it).get<std::string>(), it.key()});
    }
    OUTCOME_TRY(outputs_, net_->GetOutputTensors());
    for (const auto& t : outputs_) {
      output_mapping_.insert({t.name(), t.name()});
    }
    return success();
  }

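  // Derives the batched input shape from per-sample tensors. Each sample must
  // have a leading batch dimension of 1 and agree on all remaining
  // dimensions; e.g. three samples of shape [1, 3, 224, 224] assemble into
  // [3, 3, 224, 224].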
  Result<TensorShape> InferInputShape(const vector<Tensor>& input) {
    auto batch_size = input.size();
    auto& exemplar = input.front();
    auto shape = exemplar.shape();
    if (batch_size == 1) {
      return shape;
    }
    if (shape[0] != 1) {
      ERROR("unsupported shape for batch assembly: {}", shape);
      return Status(eNotSupported);
    }
    for (size_t i = 1; i < input.size(); ++i) {
      auto& sample = input[i];
      if (sample.shape() != shape) {
        ERROR("shapes are not consistent across the batch");
        return Status(eNotSupported);
      }
    }
    shape[0] = static_cast<int64_t>(batch_size);
    return shape;
  }

  Result<vector<TensorShape>> InferInputShape(const vector<vector<Tensor>>& inputs) {
    vector<TensorShape> shapes;
    shapes.reserve(inputs.size());
    for (const auto& input : inputs) {
      OUTCOME_TRY(auto shape, InferInputShape(input));
      shapes.push_back(std::move(shape));
    }
    return shapes;
  }

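  // One batched inference: gather per-sample tensors by name, reshape the
  // backend to the assembled batch shapes, copy samples into the backend's
  // input tensors, run the network, then split each output across the batch.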
  Result<std::vector<Output>> Forward(const std::vector<Input>& input) {
    if (input.empty()) {
      ERROR("empty input batch");
      return Status(eInvalidArgument);
    }
    auto batch_size = static_cast<int>(input.size());

    std::vector<std::vector<Tensor>> input_samples;
    input_samples.reserve(inputs_.size());
    for (const auto& t : inputs_) {
      auto name = input_mapping_.at(t.name());
      std::vector<Tensor> tmp;
      tmp.reserve(input.size());
      for (size_t i = 0; i < input.size(); ++i) {
        auto& sample = input[i];
        if (auto it = sample.find(name); it != sample.end()) {
          tmp.push_back(it->second);
        } else {
          ERROR("sample {} missing key {}", i, name);
          return Status(eInvalidArgument);
        }
      }
      input_samples.push_back(std::move(tmp));
    }

    // 1. calculate input shape
    OUTCOME_TRY(auto input_shapes, InferInputShape(input_samples));

    // 2. call backend's reshape
    OUTCOME_TRY(net_->Reshape(input_shapes));

    // 3. fill input tensor
    for (size_t i = 0; i < inputs_.size(); ++i) {
      auto& src = input_samples[i];
      auto& dst = inputs_[i];
      if (dst.shape() != input_shapes[i]) {
        ERROR("inconsistent input shape, expected {}, got {}", input_shapes[i], dst.shape());
        return Status(eFail);
      }
      if (src.size() > 1) {
        for (size_t j = 0; j < src.size(); ++j) {
          auto slice = dst.Slice(j);
          OUTCOME_TRY(src[j].CopyTo(slice, stream_));
        }
      } else {
        OUTCOME_TRY(src[0].CopyTo(dst, stream_));
      }
    }

    // 4. forward
    OUTCOME_TRY(net_->Forward());

    // 5. copy each output off the backend and scatter it across the batch
    vector<Output> output(batch_size);
    for (const auto& t : outputs_) {
      auto name = output_mapping_.at(t.name());
      auto desc = t.desc();
      desc.device = device_;
      Tensor tmp(desc);
      if (tmp.size()) {
        OUTCOME_TRY(t.CopyTo(tmp, stream_));
      } else {
        WARN("copy skipped due to zero-sized tensor");
      }
      if (output.size() > 1) {
        for (size_t i = 0; i < output.size(); ++i) {
          output[i].insert({name, tmp.Slice(i)});
        }
      } else {
        output[0].insert({name, std::move(tmp)});
      }
    }

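    // Note: the copies above are only enqueued on stream_; this function does
    // not Wait() on it, so consumers are expected to synchronize on the same
    // stream before reading the returned tensors.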
    return output;
  }

  Device device_;
  Stream stream_;
  std::unique_ptr<Net> net_;
  Span<Tensor> inputs_;
  Span<Tensor> outputs_;
  // model input names -> outer scope names
  std::map<std::string, std::string> input_mapping_;
  // model output names -> outer scope names
  std::map<std::string, std::string> output_mapping_;
};

NetModule::~NetModule() = default;

NetModule::NetModule(const Value& args) : impl_(std::make_unique<Impl>(args)) {}

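// Accepts either a single sample (an object whose Tensor-valued entries are
// forwarded to the net) or an array of such objects treated as a batch; the
// result mirrors the input's arity.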
Result<Value> NetModule::operator()(const Value& input) {
  auto filter = [](const Value& sample) {
    Impl::Input tensors;
    for (auto it = sample.begin(); it != sample.end(); ++it) {
      if (it->is_any<Tensor>()) {
        tensors.insert({it.key(), it->get<Tensor>()});
      }
    }
    return tensors;
  };
  std::vector<Impl::Input> batch;
  if (input.is_array()) {
    batch.reserve(input.size());
    for (const auto& sample : input) {
      batch.push_back(filter(sample));
    }
  } else if (input.is_object()) {
    batch.push_back(filter(input));
  } else {
    return Status(eNotSupported);
  }
  OUTCOME_TRY(auto batch_output, impl_->Forward(batch));
  if (input.is_array()) {
    return to_value(batch_output);
  } else {
    return to_value(batch_output.at(0));
  }
}

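// Registers the module under the name "Net" so that pipeline configs can
// instantiate it by name. A sketch of a config fragment that would resolve to
// this module (field values are illustrative only):
//   {"type": "Task", "module": "Net", "input": ["img"], "output": ["dets"]}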
class NetModuleCreator : public Creator<Module> {
 public:
  const char* GetName() const override { return "Net"; }
  int GetVersion() const override { return 0; }
  std::unique_ptr<Module> Create(const Value& value) override {
    return CreateTask(NetModule{value});
  }
};

REGISTER_MODULE(Module, NetModuleCreator);

}  // namespace mmdeploy