Fix JIT output capture

pull/6/head
A. R. Shajii 2021-11-21 09:21:40 -05:00
parent 71cad478f7
commit a6ada78aa6
3 changed files with 88 additions and 7 deletions

View File

@ -9,5 +9,7 @@ char RuntimeErrorInfo::ID = 0;
char PluginErrorInfo::ID = 0;
char IOErrorInfo::ID = 0;
} // namespace error
} // namespace codon

View File

@ -128,5 +128,23 @@ public:
static char ID;
};
class IOErrorInfo : public llvm::ErrorInfo<IOErrorInfo> {
private:
std::string message;
public:
explicit IOErrorInfo(const std::string &message) : message(message) {}
std::string getMessage() const { return message; }
void log(llvm::raw_ostream &out) const override { out << message; }
std::error_code convertToErrorCode() const override {
return llvm::inconvertibleErrorCode();
}
static char ID;
};
} // namespace error
} // namespace codon

View File

@ -1,6 +1,8 @@
#include "jit.h"
#include <sstream>
#include <cstdio>
#include <cstdlib>
#include <unistd.h>
#include "codon/parser/peg/peg.h"
#include "codon/parser/visitors/doc/doc.h"
@ -20,11 +22,65 @@ const std::string JIT_FILENAME = "<jit>";
class CaptureOutput {
private:
std::streambuf *orig;
std::vector<char> buf;
int outpipe[2];
int saved;
bool stopped;
std::string result;
llvm::Error err(const std::string &msg) {
return llvm::make_error<error::IOErrorInfo>(msg);
}
public:
CaptureOutput(std::streambuf *buf) : orig(std::cout.rdbuf(buf)) {}
~CaptureOutput() { std::cout.rdbuf(orig); }
static constexpr size_t BUFFER_SIZE = 65536;
CaptureOutput() : buf(BUFFER_SIZE), outpipe(), saved(0), stopped(false), result() {}
std::string getResult() const { return result; }
llvm::Error start() {
if (stopped)
return llvm::Error::success();
saved = dup(STDOUT_FILENO);
if (saved == -1)
return err("dup(STDOUT_FILENO) call failed");
if (pipe(outpipe) != 0)
return err("pipe(outpipe) call failed");
if (dup2(outpipe[1], STDOUT_FILENO) == -1)
return err("dup2(outpipe[1], STDOUT_FILENO) call failed");
if (close(outpipe[1]) == -1)
return err("close(outpipe[1]) call failed");
return llvm::Error::success();
}
llvm::Error stop() {
if (stopped)
return llvm::Error::success();
stopped = true;
if (fflush(stdout) != 0)
return err("fflush(stdout) call failed");
auto count = read(outpipe[0], buf.data(), buf.size() - 1);
if (count == -1)
return err("read(outpipe[0], buf.data(), buf.size() - 1) call failed");
if (dup2(saved, STDOUT_FILENO) == -1)
return err("dup2(saved, STDOUT_FILENO) call failed");
result = std::string(buf.data(), count);
return llvm::Error::success();
}
~CaptureOutput() {
seqassert(dup2(saved, STDOUT_FILENO) != -1, "IO error when capturing stdout");
}
};
} // namespace
@ -90,10 +146,15 @@ llvm::Expected<std::string> JIT::run(const ir::Func *input,
return std::move(err);
auto *repl = (InputFunc *)func->getAddress();
std::stringstream buffer;
std::string output;
try {
CaptureOutput(buffer.rdbuf());
CaptureOutput capture;
if (auto err = capture.start())
return std::move(err);
(*repl)();
if (auto err = capture.stop())
return std::move(err);
output = capture.getResult();
} catch (const JITError &e) {
std::vector<std::string> backtrace;
for (auto pc : e.getBacktrace()) {
@ -105,7 +166,7 @@ llvm::Expected<std::string> JIT::run(const ir::Func *input,
e.what(), e.getFile(), e.getLine(),
e.getCol(), backtrace);
}
return buffer.str();
return output;
}
llvm::Expected<std::string> JIT::exec(const std::string &code) {