1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
2021-09-27 14:02:44 -04:00

974 lines
32 KiB
C++

#include "openmp.h"
#include "schedule.h"
#include <algorithm>
#include <iterator>
#include <limits>
#include "sir/util/cloning.h"
#include "sir/util/irtools.h"
#include "sir/util/outlining.h"
namespace seq {
namespace ir {
namespace transform {
namespace parallel {
namespace {
const std::string ompModule = "std.openmp";
const std::string builtinModule = "std.internal.builtin";
struct OMPTypes {
types::Type *i32 = nullptr;
types::Type *i8ptr = nullptr;
types::Type *i32ptr = nullptr;
types::Type *micro = nullptr;
types::Type *routine = nullptr;
types::Type *ident = nullptr;
types::Type *task = nullptr;
explicit OMPTypes(Module *M) {
i32 = M->getIntNType(32, /*sign=*/true);
i8ptr = M->getPointerType(M->getByteType());
i32ptr = M->getPointerType(i32);
micro = M->getFuncType(M->getVoidType(), {i32ptr, i32ptr});
routine = M->getFuncType(i32, {i32ptr, i8ptr});
ident = M->getOrRealizeType("Ident", {}, ompModule);
task = M->getOrRealizeType("Task", {}, ompModule);
seqassert(ident, "openmp.Ident type not found");
seqassert(task, "openmp.Task type not found");
}
};
Var *getVarFromOutlinedArg(Value *arg) {
if (auto *val = cast<VarValue>(arg)) {
return val->getVar();
} else if (auto *val = cast<PointerValue>(arg)) {
return val->getVar();
} else {
seqassert(false, "unknown outline var");
}
return nullptr;
}
// we create the locks lazily to avoid them when they're not needed
struct ReductionLocks {
Var *mainLock =
nullptr; // lock used in calls to _reduce_no_wait and _end_reduce_no_wait
Var *critLock = nullptr; // lock used in reduction critical sections
Var *createLock(Module *M) {
auto *main = cast<BodiedFunc>(M->getMainFunc());
auto *lck = util::alloc(M->getByteType(), 32);
auto *val = util::makeVar(lck, cast<SeriesFlow>(main->getBody()),
/*parent=*/nullptr, /*prepend=*/true);
return val->getVar();
}
Var *getMainLock(Module *M) {
if (!mainLock)
mainLock = createLock(M);
return mainLock;
}
Var *getCritLock(Module *M) {
if (!critLock)
critLock = createLock(M);
return critLock;
}
};
struct Reduction {
enum Kind {
NONE,
ADD,
MUL,
AND,
OR,
XOR,
MIN,
MAX,
};
Kind kind = Kind::NONE;
Var *shared = nullptr;
types::Type *getType() {
auto *ptrType = cast<types::PointerType>(shared->getType());
seqassert(ptrType, "expected shared var to be of pointer type");
return ptrType->getBase();
}
Value *getInitial() {
if (!*this)
return nullptr;
auto *M = shared->getModule();
auto *type = getType();
if (isA<types::IntType>(type)) {
switch (kind) {
case Kind::ADD:
return M->getInt(0);
case Kind::MUL:
return M->getInt(1);
case Kind::AND:
return M->getInt(~0);
case Kind::OR:
return M->getInt(0);
case Kind::XOR:
return M->getInt(0);
case Kind::MIN:
return M->getInt(std::numeric_limits<int64_t>::max());
case Kind::MAX:
return M->getInt(std::numeric_limits<int64_t>::min());
default:
return nullptr;
}
} else if (isA<types::FloatType>(type)) {
switch (kind) {
case Kind::ADD:
return M->getFloat(0.);
case Kind::MUL:
return M->getFloat(1.);
case Kind::MIN:
return M->getFloat(std::numeric_limits<double>::max());
case Kind::MAX:
return M->getFloat(std::numeric_limits<double>::min());
default:
return nullptr;
}
}
auto *init = (*type)();
if (!init || !init->getType()->is(type))
return nullptr;
return init;
}
Value *generateNonAtomicReduction(Value *ptr, Value *arg) {
auto *M = ptr->getModule();
Value *lhs = util::ptrLoad(ptr);
Value *result = nullptr;
switch (kind) {
case Kind::ADD:
result = *lhs + *arg;
break;
case Kind::MUL:
result = *lhs * *arg;
break;
case Kind::AND:
result = *lhs & *arg;
break;
case Kind::OR:
result = *lhs | *arg;
break;
case Kind::XOR:
result = *lhs ^ *arg;
break;
case Kind::MIN: {
auto *tup = util::makeTuple({lhs, arg});
auto *fn = M->getOrRealizeFunc("min", {tup->getType()}, {}, builtinModule);
seqassert(fn, "min function not found");
result = util::call(fn, {tup});
break;
}
case Kind::MAX: {
auto *tup = util::makeTuple({lhs, arg});
auto *fn = M->getOrRealizeFunc("max", {tup->getType()}, {}, builtinModule);
seqassert(fn, "max function not found");
result = util::call(fn, {tup});
break;
}
default:
return nullptr;
}
return util::ptrStore(ptr, result);
}
Value *generateAtomicReduction(Value *ptr, Value *arg, Var *loc, Var *gtid,
ReductionLocks &locks) {
auto *M = ptr->getModule();
auto *type = getType();
std::string func = "";
if (isA<types::IntType>(type)) {
switch (kind) {
case Kind::ADD:
func = "_atomic_int_add";
break;
case Kind::MUL:
func = "_atomic_int_mul";
break;
case Kind::AND:
func = "_atomic_int_and";
break;
case Kind::OR:
func = "_atomic_int_or";
break;
case Kind::XOR:
func = "_atomic_int_xor";
break;
case Kind::MIN:
func = "_atomic_int_min";
break;
case Kind::MAX:
func = "_atomic_int_max";
break;
default:
break;
}
} else if (isA<types::FloatType>(type)) {
switch (kind) {
case Kind::ADD:
func = "_atomic_float_add";
break;
case Kind::MUL:
func = "_atomic_float_mul";
break;
case Kind::MIN:
func = "_atomic_float_min";
break;
case Kind::MAX:
func = "_atomic_float_max";
break;
default:
break;
}
}
if (!func.empty()) {
auto *atomicOp =
M->getOrRealizeFunc(func, {ptr->getType(), arg->getType()}, {}, ompModule);
seqassert(atomicOp, "atomic op '{}' not found", func);
return util::call(atomicOp, {ptr, arg});
}
switch (kind) {
case Kind::ADD:
func = "__atomic_add__";
break;
case Kind::MUL:
func = "__atomic_mul__";
break;
case Kind::AND:
func = "__atomic_and__";
break;
case Kind::OR:
func = "__atomic_or__";
break;
case Kind::XOR:
func = "__atomic_xor__";
break;
case Kind::MIN:
func = "__atomic_min__";
break;
case Kind::MAX:
func = "__atomic_max__";
break;
default:
break;
}
if (!func.empty()) {
auto *atomicOp =
M->getOrRealizeMethod(arg->getType(), func, {ptr->getType(), arg->getType()});
if (atomicOp)
return util::call(atomicOp, {ptr, arg});
}
seqassert(loc && gtid, "loc and/or gtid are null");
auto *lck = locks.getCritLock(M);
auto *critBegin = M->getOrRealizeFunc(
"_critical_begin", {loc->getType(), gtid->getType(), lck->getType()}, {},
ompModule);
seqassert(critBegin, "critical begin function not found");
auto *critEnd = M->getOrRealizeFunc(
"_critical_end", {loc->getType(), gtid->getType(), lck->getType()}, {},
ompModule);
seqassert(critEnd, "critical end function not found");
auto *critEnter = util::call(
critBegin, {M->Nr<VarValue>(loc), M->Nr<VarValue>(gtid), M->Nr<VarValue>(lck)});
auto *operation = generateNonAtomicReduction(ptr, arg);
auto *critExit = util::call(
critEnd, {M->Nr<VarValue>(loc), M->Nr<VarValue>(gtid), M->Nr<VarValue>(lck)});
// make sure the unlock is in a finally-block
return util::series(critEnter, M->Nr<TryCatchFlow>(util::series(operation),
util::series(critExit)));
}
operator bool() const { return kind != Kind::NONE; }
};
struct ReductionFunction {
std::string name;
Reduction::Kind kind;
bool method;
};
struct ReductionIdentifier : public util::Operator {
std::vector<Var *> shareds;
std::unordered_map<id_t, Reduction> reductions;
explicit ReductionIdentifier(std::vector<Var *> shareds)
: util::Operator(), shareds(std::move(shareds)), reductions() {}
bool isShared(Var *shared) {
for (auto *v : shareds) {
if (shared->getId() == v->getId())
return true;
}
return false;
}
bool isSharedDeref(Var *shared, Value *v) {
auto *M = v->getModule();
auto *ptrType = cast<types::PointerType>(shared->getType());
seqassert(ptrType, "expected shared var to be of pointer type");
auto *type = ptrType->getBase();
if (util::isCallOf(v, Module::GETITEM_MAGIC_NAME, {ptrType, M->getIntType()}, type,
/*method=*/true)) {
auto *call = cast<CallInstr>(v);
auto *var = util::getVar(call->front());
return util::isConst<int64_t>(call->back(), 0) && var &&
var->getId() == shared->getId();
}
return false;
}
Reduction getReductionFromCall(CallInstr *v) {
auto *M = v->getModule();
auto *func = util::getFunc(v->getCallee());
if (v->numArgs() != 3 || !func ||
func->getUnmangledName() != Module::SETITEM_MAGIC_NAME)
return {};
std::vector<Value *> args(v->begin(), v->end());
Value *self = args[0];
Value *idx = args[1];
Value *item = args[2];
Var *shared = util::getVar(self);
if (!shared || !isShared(shared) || !util::isConst<int64_t>(idx, 0))
return {};
auto *ptrType = cast<types::PointerType>(shared->getType());
seqassert(ptrType, "expected shared var to be of pointer type");
auto *type = ptrType->getBase();
// double-check the call
if (!util::isCallOf(v, Module::SETITEM_MAGIC_NAME,
{self->getType(), idx->getType(), item->getType()},
M->getVoidType(), /*method=*/true))
return {};
const std::vector<ReductionFunction> reductionFunctions = {
{Module::ADD_MAGIC_NAME, Reduction::Kind::ADD, true},
{Module::MUL_MAGIC_NAME, Reduction::Kind::MUL, true},
{Module::AND_MAGIC_NAME, Reduction::Kind::AND, true},
{Module::OR_MAGIC_NAME, Reduction::Kind::OR, true},
{Module::XOR_MAGIC_NAME, Reduction::Kind::XOR, true},
{"min", Reduction::Kind::MIN, false},
{"max", Reduction::Kind::MAX, false},
};
for (auto &rf : reductionFunctions) {
if (rf.method) {
if (!util::isCallOf(item, rf.name, {type, type}, type, /*method=*/true))
continue;
} else {
if (!util::isCallOf(item, rf.name, {M->getTupleType({type, type})}, type,
/*method=*/false))
continue;
}
auto *callRHS = cast<CallInstr>(item);
if (!rf.method) {
callRHS = cast<CallInstr>(callRHS->front()); // this will be Tuple.__new__
if (!callRHS || callRHS->numArgs() != 2)
continue;
}
auto *arg1 = callRHS->front();
auto *arg2 = callRHS->back();
Value *deref = nullptr;
Value *other = nullptr;
if (isSharedDeref(shared, arg1)) {
deref = arg1;
other = arg2;
} else if (isSharedDeref(shared, arg2)) {
deref = arg2;
other = arg1;
}
if (!deref)
return {};
Reduction reduction = {rf.kind, shared};
if (!reduction.getInitial())
return {};
return reduction;
}
return {};
}
Reduction getReduction(Var *shared) {
auto it = reductions.find(shared->getId());
return (it != reductions.end()) ? it->second : Reduction();
}
void handle(CallInstr *v) override {
if (auto reduction = getReductionFromCall(v)) {
auto it = reductions.find(reduction.shared->getId());
// if we've seen the var before, make sure it's consistent
// otherwise mark as invalid via an empty reduction
if (it == reductions.end()) {
reductions.emplace(reduction.shared->getId(), reduction);
} else if (it->second && it->second.kind != reduction.kind) {
it->second = {};
}
}
}
};
struct ImperativeLoopTemplateReplacer : public util::Operator {
struct SharedInfo {
unsigned memb; // member index in template's `extra` arg
Var *local; // the local var we create to store current value
Reduction reduction; // the reduction we're performing, or empty if none
};
BodiedFunc *parent;
CallInstr *replacement;
Var *loopVar;
OMPSched *sched;
ReductionIdentifier *reds;
int64_t step;
std::vector<SharedInfo> sharedInfo;
ReductionLocks locks;
Var *locRef;
Var *reductionLocRef;
Var *gtid;
ImperativeLoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement,
Var *loopVar, OMPSched *sched,
ReductionIdentifier *reds, int64_t step)
: util::Operator(), parent(parent), replacement(replacement), loopVar(loopVar),
sched(sched), reds(reds), step(step), sharedInfo(), locks(), locRef(nullptr),
reductionLocRef(nullptr), gtid(nullptr) {}
unsigned numReductions() {
unsigned num = 0;
for (auto &info : sharedInfo) {
if (info.reduction)
num += 1;
}
return num;
}
Value *getReductionTuple() {
auto *M = parent->getModule();
std::vector<Value *> elements;
for (auto &info : sharedInfo) {
if (info.reduction)
elements.push_back(M->Nr<PointerValue>(info.local));
}
return util::makeTuple(elements, M);
}
BodiedFunc *makeReductionFunc() {
auto *M = parent->getModule();
auto *tupleType = getReductionTuple()->getType();
auto *argType = M->getPointerType(tupleType);
auto *funcType = M->getFuncType(M->getVoidType(), {argType, argType});
auto *reducer = M->Nr<BodiedFunc>("__omp_reducer");
reducer->realize(funcType, {"lhs", "rhs"});
auto *lhsVar = reducer->arg_front();
auto *rhsVar = reducer->arg_back();
auto *body = M->Nr<SeriesFlow>();
unsigned next = 0;
for (auto &info : sharedInfo) {
if (info.reduction) {
auto *lhs = util::ptrLoad(M->Nr<VarValue>(lhsVar));
auto *rhs = util::ptrLoad(M->Nr<VarValue>(rhsVar));
auto *lhsElem = util::tupleGet(lhs, next);
auto *rhsElem = util::tupleGet(rhs, next);
body->push_back(
info.reduction.generateNonAtomicReduction(lhsElem, util::ptrLoad(rhsElem)));
++next;
}
}
reducer->setBody(body);
return reducer;
}
void handle(CallInstr *v) override {
auto *M = v->getModule();
auto *func = util::getFunc(v->getCallee());
if (!func)
return;
auto name = func->getUnmangledName();
if (name == "_loop_step") {
v->replaceAll(M->getInt(step));
}
if (name == "_loop_loc_and_gtid") {
seqassert(v->numArgs() == 3 &&
std::all_of(v->begin(), v->end(),
[](auto x) { return isA<VarValue>(x); }),
"unexpected loop loc and gtid stub");
std::vector<Value *> args(v->begin(), v->end());
locRef = util::getVar(args[0]);
reductionLocRef = util::getVar(args[1]);
gtid = util::getVar(args[2]);
}
if (name == "_loop_body_stub") {
seqassert(replacement, "unexpected double replacement");
seqassert(v->numArgs() == 2 && isA<VarValue>(v->front()) &&
isA<VarValue>(v->back()),
"unexpected loop body stub");
auto *outlinedFunc = util::getFunc(replacement->getCallee());
// the template passes the new loop var and extra args
// to the body stub for convenience
auto *newLoopVar = util::getVar(v->front());
auto *extras = util::getVar(v->back());
std::vector<Value *> newArgs;
auto outlinedArgs = outlinedFunc->arg_begin(); // arg vars of *outlined func*
unsigned next = 0; // next index in "extra" args tuple, passed to template
// `arg` is an argument of the original outlined func call
for (auto *arg : *replacement) {
if (getVarFromOutlinedArg(arg)->getId() != loopVar->getId()) {
Value *newArg = nullptr;
// shared vars will be stored in a new var
if (isA<PointerValue>(arg)) {
types::Type *base = cast<types::PointerType>(arg->getType())->getBase();
// get extras again since we'll be inserting the new var before extras local
Var *lastArg = parent->arg_back(); // ptr to {chunk, start, stop, extras}
Value *val = util::tupleGet(util::ptrLoad(M->Nr<VarValue>(lastArg)), 3);
Value *initVal = util::ptrLoad(util::tupleGet(val, next));
Reduction reduction = reds->getReduction(*outlinedArgs);
if (reduction) {
initVal = reduction.getInitial();
seqassert(initVal && initVal->getType()->is(base),
"unknown reduction init value");
}
VarValue *newVar = util::makeVar(
initVal, cast<SeriesFlow>(parent->getBody()), parent, /*prepend=*/true);
sharedInfo.push_back({next, newVar->getVar(), reduction});
newArg = M->Nr<PointerValue>(newVar->getVar());
++next;
} else {
newArg = util::tupleGet(M->Nr<VarValue>(extras), next++);
}
newArgs.push_back(newArg);
} else {
if (isA<VarValue>(arg)) {
newArgs.push_back(M->Nr<VarValue>(newLoopVar));
} else if (isA<PointerValue>(arg)) {
newArgs.push_back(M->Nr<PointerValue>(newLoopVar));
} else {
seqassert(false, "unknown outline var");
}
}
++outlinedArgs;
}
v->replaceAll(util::call(outlinedFunc, newArgs));
replacement = nullptr;
}
if (name == "_loop_shared_updates") {
// for all non-reduction shareds, set the final values
// this will be similar to OpenMP's "lastprivate"
seqassert(v->numArgs() == 1 && isA<VarValue>(v->front()),
"unexpected shared updates stub");
auto *extras = util::getVar(v->front());
auto *series = M->Nr<SeriesFlow>();
for (auto &info : sharedInfo) {
if (info.reduction)
continue;
auto *finalValue = M->Nr<VarValue>(info.local);
auto *val = M->Nr<VarValue>(extras);
auto *origPtr = util::tupleGet(val, info.memb);
series->push_back(util::ptrStore(origPtr, finalValue));
}
v->replaceAll(series);
}
if (name == "_loop_reductions") {
seqassert(reductionLocRef && gtid, "bad visit order in template");
seqassert(v->numArgs() == 1 && isA<VarValue>(v->front()),
"unexpected shared updates stub");
if (numReductions() == 0)
return;
auto *M = parent->getModule();
auto *extras = util::getVar(v->front());
auto *reductionTuple = getReductionTuple();
auto *reducer = makeReductionFunc();
auto *lck = locks.getMainLock(M);
auto *reducerType = reducer->getType();
auto *rawMethod = M->getOrRealizeMethod(reducerType, "__raw__", {reducerType});
auto *rawReducer = util::call(rawMethod, {M->Nr<VarValue>(reducer)});
auto *reduceNoWait = M->getOrRealizeFunc(
"_reduce_nowait",
{reductionLocRef->getType(), gtid->getType(), reductionTuple->getType(),
rawReducer->getType(), lck->getType()},
{}, ompModule);
seqassert(reduceNoWait, "reduce nowait function not found");
auto *reduceNoWaitEnd = M->getOrRealizeFunc(
"_end_reduce_nowait",
{reductionLocRef->getType(), gtid->getType(), lck->getType()}, {}, ompModule);
seqassert(reduceNoWaitEnd, "end reduce nowait function not found");
auto *series = M->Nr<SeriesFlow>();
auto *tupleVal = util::makeVar(reductionTuple, series, parent);
auto *reduceCode = util::call(reduceNoWait, {M->Nr<VarValue>(reductionLocRef),
M->Nr<VarValue>(gtid), tupleVal,
rawReducer, M->Nr<VarValue>(lck)});
auto *codeVar = util::makeVar(reduceCode, series, parent)->getVar();
seqassert(codeVar->getType()->is(M->getIntType()), "wrong reduce code type");
auto *sectionNonAtomic = M->Nr<SeriesFlow>();
auto *sectionAtomic = M->Nr<SeriesFlow>();
for (auto &info : sharedInfo) {
if (info.reduction) {
Value *ptr = util::tupleGet(M->Nr<VarValue>(extras), info.memb);
Value *arg = M->Nr<VarValue>(info.local);
sectionNonAtomic->push_back(
info.reduction.generateNonAtomicReduction(ptr, arg));
}
}
sectionNonAtomic->push_back(
util::call(reduceNoWaitEnd, {M->Nr<VarValue>(reductionLocRef),
M->Nr<VarValue>(gtid), M->Nr<VarValue>(lck)}));
for (auto &info : sharedInfo) {
if (info.reduction) {
Value *ptr = util::tupleGet(M->Nr<VarValue>(extras), info.memb);
Value *arg = M->Nr<VarValue>(info.local);
sectionAtomic->push_back(
info.reduction.generateAtomicReduction(ptr, arg, locRef, gtid, locks));
}
}
// make: if code == 1 { sectionNonAtomic } elif code == 2 { sectionAtomic }
auto *theSwitch = M->Nr<IfFlow>(
*M->Nr<VarValue>(codeVar) == *M->getInt(1), sectionNonAtomic,
util::series(M->Nr<IfFlow>(*M->Nr<VarValue>(codeVar) == *M->getInt(2),
sectionAtomic)));
series->push_back(theSwitch);
v->replaceAll(series);
}
if (name == "_loop_schedule") {
v->replaceAll(M->getInt(sched->code));
}
if (name == "_loop_ordered") {
v->replaceAll(M->getBool(sched->ordered));
}
}
};
struct TaskLoopBodyStubReplacer : public util::Operator {
CallInstr *replacement;
explicit TaskLoopBodyStubReplacer(CallInstr *replacement)
: util::Operator(), replacement(replacement) {}
void handle(CallInstr *v) override {
auto *func = util::getFunc(v->getCallee());
if (func && func->getUnmangledName() == "_task_loop_body_stub") {
seqassert(replacement, "unexpected double replacement");
seqassert(v->numArgs() == 2 && isA<VarValue>(v->front()) &&
isA<VarValue>(v->back()),
"unexpected loop body stub");
// the template passes the privs and shareds to the body stub for convenience
auto *privatesTuple = v->front();
auto *sharedsTuple = v->back();
unsigned privatesNext = 0;
unsigned sharedsNext = 0;
std::vector<Value *> newArgs;
for (auto *arg : *replacement) {
if (isA<VarValue>(arg)) {
newArgs.push_back(util::tupleGet(privatesTuple, privatesNext++));
} else if (isA<PointerValue>(arg)) {
newArgs.push_back(util::tupleGet(sharedsTuple, sharedsNext++));
} else {
seqassert(false, "unknown outline var");
}
}
v->replaceAll(util::call(util::getFunc(replacement->getCallee()), newArgs));
replacement = nullptr;
}
}
};
struct TaskLoopRoutineStubReplacer : public util::Operator {
std::vector<Value *> &privates;
std::vector<Value *> &shareds;
CallInstr *replacement;
Var *loopVar;
TaskLoopRoutineStubReplacer(std::vector<Value *> &privates,
std::vector<Value *> &shareds, CallInstr *replacement,
Var *loopVar)
: util::Operator(), privates(privates), shareds(shareds),
replacement(replacement), loopVar(loopVar) {}
void handle(VarValue *v) override {
auto *M = v->getModule();
auto *func = util::getFunc(v);
if (func && func->getUnmangledName() == "_routine_stub") {
util::CloneVisitor cv(M);
auto *newRoutine = cv.forceClone(func);
TaskLoopBodyStubReplacer rep(replacement);
newRoutine->accept(rep);
v->setVar(newRoutine);
}
}
void handle(CallInstr *v) override {
auto *M = v->getModule();
auto *func = util::getFunc(v->getCallee());
if (func && func->getUnmangledName() == "_insert_new_loop_var") {
std::vector<Value *> args(v->begin(), v->end());
seqassert(args.size() == 3, "invalid _insert_new_loop_var call found");
auto *newLoopVar = args[0];
auto *privatesTuple = args[1];
auto *sharedsTuple = args[2];
unsigned privatesNext = 0;
unsigned sharedsNext = 0;
bool needNewPrivates = false;
bool needNewShareds = false;
std::vector<Value *> newPrivates;
std::vector<Value *> newShareds;
for (auto *val : privates) {
if (getVarFromOutlinedArg(val)->getId() != loopVar->getId()) {
newPrivates.push_back(util::tupleGet(privatesTuple, privatesNext));
} else {
newPrivates.push_back(newLoopVar);
needNewPrivates = true;
}
++privatesNext;
}
for (auto *val : shareds) {
if (getVarFromOutlinedArg(val)->getId() != loopVar->getId()) {
newShareds.push_back(util::tupleGet(sharedsTuple, sharedsNext));
} else {
newShareds.push_back(M->Nr<PointerValue>(util::getVar(newLoopVar)));
needNewShareds = true;
}
++sharedsNext;
}
privatesTuple = needNewPrivates ? util::makeTuple(newPrivates, M) : privatesTuple;
sharedsTuple = needNewShareds ? util::makeTuple(newShareds, M) : sharedsTuple;
Value *result = util::makeTuple({privatesTuple, sharedsTuple}, M);
v->replaceAll(result);
}
}
};
template <typename T> void unpar(T *v) { v->setParallel(false); }
} // namespace
const std::string OpenMPPass::KEY = "core-parallel-openmp";
void OpenMPPass::handle(ForFlow *v) {
if (!v->isParallel())
return unpar(v);
auto *M = v->getModule();
auto *parent = cast<BodiedFunc>(getParentFunc());
auto *body = cast<SeriesFlow>(v->getBody());
if (!parent || !body)
return unpar(v);
auto outline = util::outlineRegion(parent, body, /*allowOutflows=*/false);
if (!outline)
return unpar(v);
// set up args to pass fork_call
auto *sched = v->getSchedule();
Var *loopVar = v->getVar();
OMPTypes types(M);
// separate arguments into 'private' and 'shared'
std::vector<Value *> privates, shareds;
for (auto *arg : *outline.call) {
const bool isPrivate = isA<VarValue>(arg);
auto &vec = (isPrivate ? privates : shareds);
vec.push_back(arg);
}
auto *privatesTuple = util::makeTuple(privates, M);
auto *sharedsTuple = util::makeTuple(shareds, M);
// template call
std::vector<types::Type *> templateFuncArgs = {
types.i32ptr, types.i32ptr,
M->getPointerType(
M->getTupleType({v->getIter()->getType(), privatesTuple->getType(),
sharedsTuple->getType()}))};
auto *templateFunc = M->getOrRealizeFunc("_task_loop_outline_template",
templateFuncArgs, {}, ompModule);
seqassert(templateFunc, "task loop outline template not found");
util::CloneVisitor cv(M);
templateFunc = cast<BodiedFunc>(cv.forceClone(templateFunc));
TaskLoopRoutineStubReplacer rep(privates, shareds, outline.call, loopVar);
templateFunc->accept(rep);
// raw template func
auto *templateFuncType = templateFunc->getType();
auto *rawMethod =
M->getOrRealizeMethod(templateFuncType, "__raw__", {templateFuncType});
auto *rawTemplateFunc = util::call(rawMethod, {M->Nr<VarValue>(templateFunc)});
// fork call
std::vector<Value *> forkExtraArgs = {v->getIter(), privatesTuple, sharedsTuple};
auto *forkExtra = util::makeTuple(forkExtraArgs, M);
std::vector<types::Type *> forkArgTypes = {types.i8ptr, forkExtra->getType()};
auto *forkFunc = M->getOrRealizeFunc("_fork_call", forkArgTypes, {}, ompModule);
seqassert(forkFunc, "fork call function not found");
auto *fork = util::call(forkFunc, {rawTemplateFunc, forkExtra});
if (sched->threads && sched->threads->getType()->is(M->getIntType())) {
auto *pushNumThreadsFunc =
M->getOrRealizeFunc("_push_num_threads", {M->getIntType()}, {}, ompModule);
seqassert(pushNumThreadsFunc, "push num threads func not found");
auto *pushNumThreads = util::call(pushNumThreadsFunc, {sched->threads});
insertBefore(pushNumThreads);
}
v->replaceAll(fork);
}
void OpenMPPass::handle(ImperativeForFlow *v) {
if (!v->isParallel() || v->getStep() == 0)
return unpar(v);
auto *M = v->getModule();
auto *parent = cast<BodiedFunc>(getParentFunc());
auto *body = cast<SeriesFlow>(v->getBody());
if (!parent || !body)
return unpar(v);
auto outline = util::outlineRegion(parent, body, /*allowOutflows=*/false,
/*outlineGlobals=*/true);
if (!outline)
return unpar(v);
// set up args to pass fork_call
auto *sched = v->getSchedule();
Var *loopVar = v->getVar();
OMPTypes types(M);
// shared argument vars
std::vector<Var *> shareds;
unsigned i = 0;
for (auto it = outline.func->arg_begin(); it != outline.func->arg_end(); ++it) {
if (outline.argKinds[i++] == util::OutlineResult::ArgKind::MODIFIED)
shareds.push_back(*it);
}
ReductionIdentifier reds(shareds);
outline.func->accept(reds);
// gather extra arguments
std::vector<Value *> extraArgs;
std::vector<types::Type *> extraArgTypes;
for (auto *arg : *outline.call) {
if (getVarFromOutlinedArg(arg)->getId() != loopVar->getId()) {
extraArgs.push_back(arg);
extraArgTypes.push_back(arg->getType());
}
}
// template call
std::string templateFuncName;
if (sched->dynamic) {
templateFuncName = "_dynamic_loop_outline_template";
} else if (sched->chunk) {
templateFuncName = "_static_chunked_loop_outline_template";
} else {
templateFuncName = "_static_loop_outline_template";
}
auto *intType = M->getIntType();
std::vector<types::Type *> templateFuncArgs = {
types.i32ptr, types.i32ptr,
M->getPointerType(M->getTupleType(
{intType, intType, intType, M->getTupleType(extraArgTypes)}))};
auto *templateFunc =
M->getOrRealizeFunc(templateFuncName, templateFuncArgs, {}, ompModule);
seqassert(templateFunc, "imperative loop outline template not found");
util::CloneVisitor cv(M);
templateFunc = cast<Func>(cv.forceClone(templateFunc));
ImperativeLoopTemplateReplacer rep(cast<BodiedFunc>(templateFunc), outline.call,
loopVar, sched, &reds, v->getStep());
templateFunc->accept(rep);
// raw template func
auto *templateFuncType = templateFunc->getType();
auto *rawMethod =
M->getOrRealizeMethod(templateFuncType, "__raw__", {templateFuncType});
auto *rawTemplateFunc = util::call(rawMethod, {M->Nr<VarValue>(templateFunc)});
// fork call
auto *chunk = (sched->chunk && sched->chunk->getType()->is(intType)) ? sched->chunk
: M->getInt(1);
std::vector<Value *> forkExtraArgs = {chunk, v->getStart(), v->getEnd()};
for (auto *arg : extraArgs) {
forkExtraArgs.push_back(arg);
}
auto *forkExtra = util::makeTuple(forkExtraArgs, M);
std::vector<types::Type *> forkArgTypes = {types.i8ptr, forkExtra->getType()};
auto *forkFunc = M->getOrRealizeFunc("_fork_call", forkArgTypes, {}, ompModule);
seqassert(forkFunc, "fork call function not found");
auto *fork = util::call(forkFunc, {rawTemplateFunc, forkExtra});
if (sched->threads && sched->threads->getType()->is(intType)) {
auto *pushNumThreadsFunc =
M->getOrRealizeFunc("_push_num_threads", {intType}, {}, ompModule);
seqassert(pushNumThreadsFunc, "push num threads func not found");
auto *pushNumThreads = util::call(pushNumThreadsFunc, {sched->threads});
insertBefore(pushNumThreads);
}
v->replaceAll(fork);
}
} // namespace parallel
} // namespace transform
} // namespace ir
} // namespace seq