Sync profiler (#1446)

* Sdk profiler (#1274)

* sdk-profiler

* fix lint

* support lift

* sync net module when profiling

* use Scope*

* update use task name

* fix

* use std::unique_ptr<Event>

* remove mmdeploy::graph link for c and transform

* fix

* fix

* fix

* [Enhancement] refactor profiler (#1403)

* reduce profile node name

* add profiler for pipeline

* add profiler for cond

* update

* fix total time (#1451)
Chen Xin 2022-12-07 18:51:17 +08:00 committed by GitHub
parent 817917f111
commit c4e95f1ade
19 changed files with 4217 additions and 5 deletions


@ -3,6 +3,7 @@
#include "common_internal.h"
#include "executor_internal.h"
#include "mmdeploy/core/mat.h"
#include "mmdeploy/core/profiler.h"
mmdeploy_value_t mmdeploy_value_copy(mmdeploy_value_t value) {
if (!value) {
@ -72,6 +73,19 @@ int mmdeploy_device_create(const char* device_name, int device_id, mmdeploy_devi
void mmdeploy_device_destroy(mmdeploy_device_t device) { delete (Device*)device; }
int mmdeploy_profiler_create(const char* path, mmdeploy_profiler_t* profiler) {
*profiler = (mmdeploy_profiler_t) new profiler::Profiler(path);
return MMDEPLOY_SUCCESS;
}
void mmdeploy_profiler_destroy(mmdeploy_profiler_t profiler) {
if (profiler) {
auto p = (profiler::Profiler*)profiler;
p->Release();
delete p;
}
}
int mmdeploy_context_add(mmdeploy_context_t context, mmdeploy_context_type_t type, const char* name,
const void* object) {
auto& ctx = *Cast(context);
@ -88,6 +102,12 @@ int mmdeploy_context_add(mmdeploy_context_t context, mmdeploy_context_type_t typ
case MMDEPLOY_TYPE_MODEL:
ctx["model"][name] = *Cast((const mmdeploy_model_t)object);
break;
case MMDEPLOY_TYPE_PROFILER: {
const auto& profiler = *(profiler::Profiler*)object;
profiler::Scope* root(profiler.scope());
ctx["scope"] = root;
break;
}
default:
return MMDEPLOY_E_NOT_SUPPORTED;
}


@ -56,6 +56,8 @@ typedef enum mmdeploy_status_t {
typedef struct mmdeploy_device* mmdeploy_device_t;
typedef struct mmdeploy_profiler* mmdeploy_profiler_t;
typedef struct mmdeploy_mat_t {
uint8_t* data;
int height;
@ -88,6 +90,7 @@ typedef enum mmdeploy_context_type_t {
MMDEPLOY_TYPE_MODEL = 2,
MMDEPLOY_TYPE_SCHEDULER = 3,
MMDEPLOY_TYPE_MAT = 4,
MMDEPLOY_TYPE_PROFILER = 5,
} mmdeploy_context_type_t;
#if __cplusplus
@ -123,6 +126,21 @@ MMDEPLOY_API int mmdeploy_device_create(const char* device_name, int device_id,
*/
MMDEPLOY_API void mmdeploy_device_destroy(mmdeploy_device_t device);
/**
* Create a profiler
* @param path path to which the profile data will be saved
* @param profiler handle of the created profiler; it should be added to a context and later
* destroyed by mmdeploy_profiler_destroy
* @return status of the operation
*/
MMDEPLOY_API int mmdeploy_profiler_create(const char* path, mmdeploy_profiler_t* profiler);
/**
* Destroy a profiler handle
* @param profiler handle of the profiler; the collected profile data is written to disk by this call
*/
MMDEPLOY_API void mmdeploy_profiler_destroy(mmdeploy_profiler_t profiler);
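For orientation, a minimal end-to-end use of the profiler C API looks roughly like the sketch below. It is illustrative only: the include path is assumed, the pipeline or detector built from the context is omitted, and error handling is skipped.

#include "mmdeploy/common.h"  // assumed install path of the declarations above

void run_with_profiler() {
  mmdeploy_profiler_t profiler{};
  mmdeploy_profiler_create("profile.bin", &profiler);

  mmdeploy_context_t ctx{};
  mmdeploy_context_create(&ctx);
  // Registering the profiler lets every node built from this context create its own scope.
  mmdeploy_context_add(ctx, MMDEPLOY_TYPE_PROFILER, nullptr, profiler);

  // ... create a pipeline/detector from `ctx` and run inference ...

  mmdeploy_context_destroy(ctx);
  // The profile data is written to "profile.bin" here, so destroy the profiler last.
  mmdeploy_profiler_destroy(profiler);
}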
/**
* Create context
* @param context


@ -76,6 +76,24 @@ class Device {
std::shared_ptr<mmdeploy_device> device_;
};
class Profiler {
public:
explicit Profiler(std::string_view path) : path_(path) {
mmdeploy_profiler_t profiler{};
auto ec = mmdeploy_profiler_create(path_.c_str(), &profiler);
if (ec != MMDEPLOY_SUCCESS) {
throw_exception(static_cast<ErrorCode>(ec));
}
profiler_.reset(profiler, [](auto p) { mmdeploy_profiler_destroy(p); });
};
operator mmdeploy_profiler_t() const noexcept { return profiler_.get(); }
private:
std::string path_;
std::shared_ptr<mmdeploy_profiler> profiler_;
};
class Mat {
public:
Mat() : desc_{} {}
@ -187,6 +205,10 @@ class Context {
mmdeploy_context_add(*this, MMDEPLOY_TYPE_DEVICE, nullptr, device);
}
void Add(const Profiler& profiler) {
mmdeploy_context_add(*this, MMDEPLOY_TYPE_PROFILER, nullptr, profiler);
}
operator mmdeploy_context_t() const noexcept { return context_.get(); }
private:
@ -199,6 +221,7 @@ using cxx::Context;
using cxx::Device;
using cxx::Mat;
using cxx::Model;
using cxx::Profiler;
using cxx::Rect;
using cxx::Scheduler;
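A minimal sketch of how the new C++ wrapper is intended to be used; the default Context constructor, the Device construction, and the pipeline around them are assumed for illustration and are not part of this diff.

mmdeploy::Profiler profiler{"profile.bin"};
mmdeploy::Context context;
context.Add(mmdeploy::Device{"cpu"});
context.Add(profiler);
// ... build a pipeline/detector with `context` and run inference ...
// The profile is flushed to "profile.bin" when the last copy of `profiler` goes out of scope,
// since the shared_ptr deleter calls mmdeploy_profiler_destroy.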


@ -40,15 +40,16 @@ set(SRCS
utils/device_utils.cpp
utils/formatter.cpp
utils/stacktrace.cpp
profiler.cpp
)
mmdeploy_add_library(${PROJECT_NAME} ${SRCS})
target_include_directories(${PROJECT_NAME}
PUBLIC
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/csrc>
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/third_party/outcome>
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/third_party/concurrentqueue>
# TODO: remove dependency of `json`
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/third_party/json>
)


@ -0,0 +1,87 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "mmdeploy/core/profiler.h"
#include <iomanip>
namespace mmdeploy {
namespace profiler {
Event* Scope::Add(Event::Type type, Index index, TimePoint time_point) {
return profiler_->AddEvent({this, type, index, time_point});
}
Scope* Scope::CreateScope(std::string_view name) {
auto node = children_.emplace_back(profiler_->CreateScope(name));
node->parent_ = this;
return node;
}
void Scope::Dump(Scope* scope, std::ofstream& ofs) {
ofs << scope->name_ << " " << (void*)scope << " ";
for (auto& child : scope->children_) {
ofs << (void*)child << " ";
}
ofs << "\n";
for (const auto& child : scope->children_) {
Dump(child, ofs);
}
}
ScopedCounter::ScopedCounter(Scope* scope) {
if (scope) {
start_ = scope->Add(Event::kStart, scope->next_.fetch_add(1, std::memory_order_relaxed),
Clock::now());
}
}
ScopedCounter::~ScopedCounter() {
if (start_) {
start_->scope->Add(Event::kEnd, start_->index, Clock::now());
}
}
Profiler::Profiler(std::string_view path) : path_(path) { root_ = CreateScope("."); }
Scope* Profiler::CreateScope(std::string_view name) {
auto& node = nodes_.emplace_back();
node.profiler_ = this;
node.name_ = name;
return &node;
}
Event* Profiler::AddEvent(Event e) {
auto uptr = std::make_unique<Event>(e);
Event* pe = uptr.get();
events_.enqueue(std::move(uptr));
return pe;
}
void Profiler::Release() {
std::ofstream ofs(path_);
root_->Dump(ofs);
ofs << "----\n";
std::unique_ptr<Event> item;
std::vector<std::unique_ptr<Event>> vec;
while (events_.try_dequeue(item)) {
vec.push_back(std::move(item));
}
std::sort(vec.begin(), vec.end(),
[](const std::unique_ptr<Event>& a, const std::unique_ptr<Event>& b) {
return a->time_point < b->time_point;
});
for (int i = 0; i < vec.size(); i++) {
ofs << (void*)vec[i]->scope << " " << vec[i]->type << " " << vec[i]->index << " "
<< std::chrono::duration_cast<std::chrono::microseconds>(vec[i]->time_point -
vec[0]->time_point)
.count()
<< "\n";
}
}
} // namespace profiler
} // namespace mmdeploy
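Release() thus produces a plain-text dump in two parts separated by a "----" line: first the scope tree, one line per scope with its name, its address, and the addresses of its children; then the event list, one line per event with the scope address, the event type (0 = kStart, 1 = kEnd), the per-scope event index, and the timestamp in microseconds relative to the first event. The names and addresses below are made up purely to illustrate the layout:

. 0x1000 0x2000
Pipeline 0x2000 0x3000
Compose 0x3000
----
0x2000 0 0 0
0x3000 0 0 15
0x3000 1 0 980
0x2000 1 0 1002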


@ -0,0 +1,86 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_CSRC_MMDEPLOY_CORE_PROFILER_H_
#define MMDEPLOY_CSRC_MMDEPLOY_CORE_PROFILER_H_
#include <atomic>
#include <chrono>
#include <deque>
#include <fstream>
#include <iostream>
#include <string>
#include <string_view>
#include <vector>
#include "concurrentqueue.h"
#include "mmdeploy/core/logger.h"
#include "mmdeploy/core/macro.h"
#include "mmdeploy/core/value.h"
namespace mmdeploy {
namespace profiler {
struct Profiler;
struct Scope;
using Clock = std::conditional_t<std::chrono::high_resolution_clock::is_steady,
std::chrono::high_resolution_clock, std::chrono::steady_clock>;
using TimePoint = Clock::time_point;
using Index = uint64_t;
struct Event {
enum Type { kStart, kEnd };
Scope* scope;
Type type;
Index index;
TimePoint time_point;
};
struct MMDEPLOY_API Scope {
Scope() = default;
Scope(const Scope&) = delete;
Scope(Scope&&) noexcept = delete;
Scope& operator=(const Scope&) = delete;
Scope& operator=(Scope&&) noexcept = delete;
Event* Add(Event::Type type, Index index, TimePoint time_point);
Scope* CreateScope(std::string_view name);
void Dump(Scope* scope, std::ofstream& ofs);
void Dump(std::ofstream& ofs) { Dump(this, ofs); }
Profiler* profiler_{};
Scope* parent_{};
std::vector<Scope*> children_;
std::atomic<Index> next_{};
std::string name_;
};
struct MMDEPLOY_API ScopedCounter {
explicit ScopedCounter(Scope* scope);
~ScopedCounter();
Event* start_{};
};
struct MMDEPLOY_API Profiler {
explicit Profiler(std::string_view path);
Scope* CreateScope(std::string_view name);
Event* AddEvent(Event e);
Scope* scope() const noexcept { return root_; }
void Release();
std::string path_;
std::deque<Scope> nodes_;
moodycamel::ConcurrentQueue<std::unique_ptr<Event>> events_;
Scope* root_{};
};
} // namespace profiler
MMDEPLOY_REGISTER_TYPE_ID(profiler::Scope*, 10);
} // namespace mmdeploy
#endif // MMDEPLOY_CSRC_MMDEPLOY_CORE_PROFILER_H_
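A minimal sketch of the core API declared above, assuming the classes are used directly (inside the SDK, nodes instead obtain their Scope* from the "scope" entry of the pipeline context):

#include "mmdeploy/core/profiler.h"

void profile_example() {
  using namespace mmdeploy::profiler;
  Profiler profiler("profile.bin");                        // also creates the root scope "."
  Scope* scope = profiler.scope()->CreateScope("MyTask");  // "MyTask" is an illustrative name
  {
    ScopedCounter counter(scope);  // records a kStart event now and a kEnd event on destruction
    // ... timed work ...
  }
  profiler.Release();  // sorts the queued events and writes them, together with the scope tree
}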


@ -54,7 +54,16 @@ Value get_divergent_output(Value::Array& rs, const vector<int>& ps) {
} // namespace
Sender<Value> Cond::Process(Sender<Value> input) {
return LetValue(std::move(input), [this](Value& _input) -> Sender<Value> {
auto index = std::make_shared<profiler::Index>();
if (scope_) {
*index = scope_->next_.fetch_add(1, std::memory_order_relaxed);
input = Then(std::move(input), [this, index](Value v) mutable {
scope_->Add(profiler::Event::kStart, *index, profiler::Clock::now());
return std::move(v);
});
}
Sender<Value> output = LetValue(std::move(input), [this](Value& _input) -> Sender<Value> {
assert(_input.is_array());
auto& as = _input.array();
auto ps = get_predicates(as.front().array());
@ -75,6 +84,14 @@ Sender<Value> Cond::Process(Sender<Value> input) {
});
}
});
if (scope_) {
output = Then(std::move(output), [this, index](Value v) {
scope_->Add(profiler::Event::kEnd, *index, profiler::Clock::now());
return std::move(v);
});
}
return output;
}
CondBuilder::CondBuilder(Value config) : Builder(std::move(config)) {}
@ -97,6 +114,12 @@ Result<unique_ptr<Node>> CondBuilder::BuildImpl() {
}
if (config_.contains("context")) {
update(body_config["context"].object(), config_["context"].object(), 2);
if (config_["context"].contains("scope")) {
auto scope = config_["context"]["scope"].get<profiler::Scope*>();
auto name = config_.value("name", std::string("Cond"));
cond->scope_ = scope->CreateScope(name);
body_config["context"]["scope"] = cond->scope_;
}
}
if (auto builder = Builder::CreateFromConfig(body_config).value()) {


@ -4,6 +4,7 @@
#define MMDEPLOY_CSRC_MMDEPLOY_GRAPH_COND_H_
#include "mmdeploy/core/graph.h"
#include "mmdeploy/core/profiler.h"
namespace mmdeploy::graph {
@ -14,6 +15,7 @@ class Cond : public Node {
Sender<Value> Process(Sender<Value> input) override;
private:
profiler::Scope* scope_{nullptr};
std::unique_ptr<Node> node_;
int n_output_{0};
};


@ -5,6 +5,7 @@
#include "mmdeploy/archive/json_archive.h"
#include "mmdeploy/core/graph.h"
#include "mmdeploy/core/model.h"
#include "mmdeploy/core/profiler.h"
#include "mmdeploy/graph/common.h"
namespace mmdeploy::graph {
@ -33,6 +34,11 @@ Result<unique_ptr<Node>> InferenceBuilder::BuildImpl() {
context["model"] = std::move(model);
auto pipeline_config = from_json<Value>(json);
if (context.contains("scope")) {
auto name = config_.value("name", config_["type"].get<std::string>());
auto scope = context["scope"].get_ref<profiler::Scope*&>()->CreateScope(name);
context["scope"] = scope;
}
pipeline_config["context"] = context;
MMDEPLOY_INFO("{}", pipeline_config);


@ -97,13 +97,33 @@ StaticRouter::State::State(vector<int> use_count, Sender<Value> args)
}
Sender<Value> StaticRouter::Process(Sender<Value> args) {
auto index = std::make_shared<profiler::Index>();
auto start = std::make_shared<bool>(false);
if (scope_) {
*index = scope_->next_.fetch_add(1, std::memory_order_relaxed);
args = Then(std::move(args), [this, index, start](Value v) mutable {
if (*start == false) {
scope_->Add(profiler::Event::kStart, *index, profiler::Clock::now());
*start = true;
}
return std::move(v);
});
}
State state(use_count_, std::move(args));
for (size_t i = 0; i < nodes_.size(); ++i) {
auto input = state.Collect(input_coords_[i]);
auto output = nodes_[i]->Process(std::move(input));
state.Write(static_cast<int>(i), std::move(output));
}
return state.Collect(ret_coords_);
auto output = state.Collect(ret_coords_);
if (scope_) {
output = Then(std::move(output), [this, index](Value v) {
scope_->Add(profiler::Event::kEnd, *index, profiler::Clock::now());
return std::move(v);
});
}
return output;
}
/////////////////////////////////////////////////////////////////////
@ -112,6 +132,11 @@ Sender<Value> StaticRouter::Process(Sender<Value> args) {
Result<unique_ptr<StaticRouter>> StaticRouterBuilder::Build(const Value& config) {
try {
auto pipeline = std::make_unique<StaticRouter>();
if (config.contains("context") && config["context"].contains("scope")) {
auto name = config.value("name", std::string("Pipeline"));
auto scope = config["context"]["scope"].get<profiler::Scope*>();
pipeline->scope_ = scope->CreateScope(name);
}
const auto& task_configs = config["tasks"];
auto size = task_configs.size();
@ -139,6 +164,9 @@ Result<unique_ptr<StaticRouter>> StaticRouterBuilder::Build(const Value& config)
}
if (config.contains("context")) {
update(task_config["context"].object(), config["context"].object(), 2);
if (pipeline->scope_) {
task_config["context"]["scope"] = pipeline->scope_;
}
}
OUTCOME_TRY(auto builder, Builder::CreateFromConfig(task_config));


@ -8,6 +8,7 @@
#include "mmdeploy/core/graph.h"
#include "mmdeploy/core/module.h"
#include "mmdeploy/core/operator.h"
#include "mmdeploy/core/profiler.h"
#include "mmdeploy/core/value.h"
#include "mmdeploy/execution/schedulers/registry.h"
#include "mmdeploy/execution/when_all_value.h"
@ -34,6 +35,7 @@ class StaticRouter : public Node {
vector<int> use_count_;
vector<vector<Coords>> input_coords_;
vector<Coords> ret_coords_;
profiler::Scope* scope_{nullptr};
};
class StaticRouterBuilder {


@ -12,6 +12,7 @@ Sender<Value> Task::Process(Sender<Value> input) {
assert(v.is_array());
// handle empty input
if (v.front().empty()) {
profiler::ScopedCounter counter(scope_);
return TransferJust(*sched_, Value(Value::Array(v.size(), Value::kArray)));
}
if (v.front().is_array() && !is_batched_) {
@ -24,6 +25,7 @@ Sender<Value> Task::Process(Sender<Value> input) {
return Value{std::move(input), std::move(output)};
})
| Bulk(batch_size, [&](size_t index, Value& in_out) {
profiler::ScopedCounter counter(scope_);
const auto& input = in_out[0];
auto& output = in_out[1];
output[index] = module_->Process(input[index]).value();
@ -33,8 +35,10 @@ Sender<Value> Task::Process(Sender<Value> input) {
});
// clang-format on
} else {
return DynamicBatch(TransferJust(*sched_, std::move(v)), batch_context_,
[&](const Value& u) { return module_->Process(u).value(); });
return DynamicBatch(TransferJust(*sched_, std::move(v)), batch_context_, [&](const Value& u) {
profiler::ScopedCounter counter(scope_);
return module_->Process(u).value();
});
}
});
}
@ -63,6 +67,17 @@ inline Result<unique_ptr<Module>> CreateModule(const Value& config) {
Result<unique_ptr<Node>> TaskBuilder::BuildImpl() {
try {
auto task = std::make_unique<Task>();
if (auto scope = Maybe{config_} / "context" / "scope" / identity<profiler::Scope*>{}) {
auto module_name = config_.value<std::string>("module", "");
auto name = config_.value<std::string>("name", "");
string scope_name = (name != "") ? name : module_name;
task->scope_ = (*scope)->CreateScope(scope_name);
config_["context"]["scope"] = task->scope_;
if (module_name == "Transform") {
task->scope_ = nullptr;
}
}
OUTCOME_TRY(task->module_, CreateModule(config_));
if (auto name = Maybe{config_} / "scheduler" / identity<string>{}) {


@ -4,6 +4,7 @@
#define MMDEPLOY_CSRC_GRAPH_TASK_H_
#include "mmdeploy/core/graph.h"
#include "mmdeploy/core/profiler.h"
namespace mmdeploy::graph {
@ -19,6 +20,7 @@ class Task : public Node {
bool is_batched_{false};
bool is_thread_safe_{false};
dynamic_batch_t::context_t batch_context_;
profiler::Scope* scope_{nullptr};
};
class TaskBuilder : public Builder {


@ -28,6 +28,9 @@ struct NetModule::Impl {
auto init = [&]() -> Result<void> {
auto name = args["name"].get<std::string>();
auto& context = args["context"];
if (context.contains("scope")) {
is_profiling_ = true;
}
auto model = context["model"].get<Model>();
OUTCOME_TRY(auto config, model.GetModelConfig(name));
device_ = context.value("device", Device{"cpu"});
@ -177,6 +180,9 @@ struct NetModule::Impl {
output[0].emplace(name, std::move(tmp));
}
}
if (is_profiling_) {
OUTCOME_TRY(stream_.Wait());
}
return output;
}
@ -190,6 +196,7 @@ struct NetModule::Impl {
std::map<std::string, std::string> input_mapping_;
// outer scope to model output names
std::map<std::string, std::string> output_mapping_;
bool is_profiling_{false};
};
NetModule::~NetModule() = default;


@ -35,6 +35,10 @@ Compose::Compose(const Value& args, int version) : Transform(args) {
context["fuse_transform"] = true;
context["sha256"] = sha256;
}
if (context.contains("scope")) {
auto scope = context["scope"].get<profiler::Scope*>();
scope_ = scope->CreateScope("Compose");
}
for (auto cfg : args["transforms"]) {
cfg["context"] = context;
auto type = cfg.value("type", std::string{});
@ -45,6 +49,17 @@ Compose::Compose(const Value& args, int version) : Transform(args) {
gRegistry<Transform>().List());
throw_exception(eEntryNotFound);
}
if (scope_) {
auto scope = scope_->CreateScope(type);
if (type == "Lift") {
cfg["context"]["scope"] = scope;
transform_scopes_.push_back(nullptr);
} else {
transform_scopes_.push_back(scope);
}
} else {
transform_scopes_.push_back(nullptr);
}
auto transform = creator->Create(cfg);
if (!transform) {
MMDEPLOY_ERROR("Failed to create transform: {}, config: {}", type, cfg);
@ -57,10 +72,15 @@ Compose::Compose(const Value& args, int version) : Transform(args) {
Result<Value> Compose::Process(const Value& input) {
Value output = input;
Value::Array intermediates;
int idx = 0;
for (auto& transform : transforms_) {
profiler::ScopedCounter counter(transform_scopes_[idx++]);
OUTCOME_TRY(auto t, transform->Process(output));
SaveIntermediates(t, intermediates);
output = std::move(t);
if (transform_scopes_[idx - 1]) {
OUTCOME_TRY(stream_.Wait());
}
}
OUTCOME_TRY(stream_.Wait());
return std::move(output);


@ -3,6 +3,7 @@
#ifndef MMDEPLOY_SRC_PREPROCESS_TRANSFORM_COMPOSE_H_
#define MMDEPLOY_SRC_PREPROCESS_TRANSFORM_COMPOSE_H_
#include "mmdeploy/core/profiler.h"
#include "transform.h"
namespace mmdeploy {
@ -17,6 +18,8 @@ class MMDEPLOY_API Compose : public Transform {
private:
std::vector<std::unique_ptr<Transform>> transforms_;
Stream stream_;
std::vector<profiler::Scope*> transform_scopes_;
profiler::Scope* scope_{nullptr};
};
} // namespace mmdeploy


@ -78,11 +78,14 @@ int main() {
mmdeploy_device_t device{};
mmdeploy_device_create("cpu", 0, &device);
mmdeploy_profiler_t profiler{};
mmdeploy_profiler_create("profile.bin", &profiler);
mmdeploy_context_t ctx{};
mmdeploy_context_create(&ctx);
mmdeploy_context_add(ctx, MMDEPLOY_TYPE_DEVICE, nullptr, device);
mmdeploy_context_add(ctx, MMDEPLOY_TYPE_PROFILER, nullptr, profiler);
auto thread_pool = mmdeploy_executor_create_thread_pool(4);
auto infer_thread = mmdeploy_executor_create_thread();
@ -118,6 +121,7 @@ int main() {
mmdeploy_scheduler_destroy(thread_pool);
mmdeploy_device_destroy(device);
mmdeploy_profiler_destroy(profiler);
return 0;
}

File diff suppressed because it is too large.


@ -0,0 +1,118 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import numpy as np
from texttable import Texttable
def parse_args():
parser = argparse.ArgumentParser(
description='Analyze sdk profiler file tool.')
parser.add_argument('profile_file', help='SDK profile file path')
args = parser.parse_args()
return args
def get_name(addr, prev, addr2name, used_addr, depth, skip):
node_name = addr2name[addr] if not skip else ''
if addr not in prev:
return ' ' * depth * 4 + node_name
prev_addr = prev[addr]
if prev_addr in used_addr:
depth += 1
skip = True
prev_name = get_name(prev[addr], prev, addr2name, used_addr, depth, skip)
if len(prev_name.split()) == 0:
return prev_name + node_name
return prev_name + '/' + node_name
def main():
args = parse_args()
with open(args.profile_file) as f:
data = f.read()
graph, events = data.split('----\n')
graph = graph.strip().split('\n')
events = events.strip().split('\n')
addr2name = {}
addr2id = {}
id2addr = {}
next = {}
prev = {}
for i, line in enumerate(graph):
info = line.split()
name, addr = info[:2]
addr2name[addr] = name
addr2id[addr] = i
id2addr[i] = addr
next[addr] = []
for child in info[2:]:
next[addr].append(child)
prev[child] = addr
n_active = {i: 0 for i in range(len(addr2id))}
n_call = {i: 0 for i in range(len(addr2id))}
t_occupy = {i: 0 for i in range(len(addr2id))}
t_usage = {i: 0 for i in range(len(addr2id))}
t_time = {i: [] for i in range(len(addr2id))}
used_id = set()
used_addr = set()
event_start = {}
now = 0
first_id = None
for event in events:
words = event.split()
addr = words[0]
id = addr2id[addr]
used_addr.add(addr)
used_id.add(id)
kind, index, ts = map(int, words[1:])
if first_id is None:
first_id = id
if id == first_id and kind == 0 and n_active[id] == 0:
now = ts
key = (id, index)
delta = ts - now
now = ts
for i, n_act in n_active.items():
if n_act > 0:
t_occupy[i] += delta
t_usage[i] += delta * n_act
if kind == 0:
event_start[key] = ts
n_active[id] += 1
n_call[id] += 1
else:
dt = ts - event_start[key]
t_time[id].append(dt)
event_start.pop(key)
n_active[id] -= 1
table = Texttable(max_width=0)
table.header(
['name', 'occupy', 'usage', 'n_call', 't_mean', 't_50%', 't_90%'])
for id in sorted(list(used_id)):
occupy = t_occupy[id] / (t_occupy[first_id])
usage = t_usage[id] / (t_occupy[first_id])
times = sorted(t_time[id])
t_mean = np.mean(times) / 1000
t_50 = times[int(len(times) * 0.5)] / 1000
t_90 = times[int(len(times) * 0.9)] / 1000
name = get_name(id2addr[id], prev, addr2name, used_addr, 0, False)
if len(next[id2addr[id]]) != 0:
occupy = '-'
usage = '-'
table.add_row([name, occupy, usage, n_call[id], t_mean, t_50, t_90])
print(table.draw())
if __name__ == '__main__':
main()