mirror of https://github.com/exaloop/codon.git
Update OpenMP reduction detection for new ops
parent
56c00d36c2
commit
b58b1ee767
|
@ -402,7 +402,8 @@ struct ReductionIdentifier : public util::Operator {
|
||||||
static void extractAssociativeOpChain(Value *v, const std::string &op,
|
static void extractAssociativeOpChain(Value *v, const std::string &op,
|
||||||
types::Type *type,
|
types::Type *type,
|
||||||
std::vector<Value *> &result) {
|
std::vector<Value *> &result) {
|
||||||
if (util::isCallOf(v, op, {type, type}, type, /*method=*/true)) {
|
if (util::isCallOf(v, op, {type, nullptr}, type, /*method=*/true) ||
|
||||||
|
util::isCallOf(v, op, {nullptr, type}, type, /*method=*/true)) {
|
||||||
auto *call = cast<CallInstr>(v);
|
auto *call = cast<CallInstr>(v);
|
||||||
extractAssociativeOpChain(call->front(), op, type, result);
|
extractAssociativeOpChain(call->front(), op, type, result);
|
||||||
extractAssociativeOpChain(call->back(), op, type, result);
|
extractAssociativeOpChain(call->back(), op, type, result);
|
||||||
|
@ -450,7 +451,8 @@ struct ReductionIdentifier : public util::Operator {
|
||||||
|
|
||||||
for (auto &rf : reductionFunctions) {
|
for (auto &rf : reductionFunctions) {
|
||||||
if (rf.method) {
|
if (rf.method) {
|
||||||
if (!util::isCallOf(item, rf.name, {type, type}, type, /*method=*/true))
|
if (!(util::isCallOf(item, rf.name, {type, nullptr}, type, /*method=*/true) ||
|
||||||
|
util::isCallOf(item, rf.name, {nullptr, type}, type, /*method=*/true)))
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
if (!util::isCallOf(item, rf.name,
|
if (!util::isCallOf(item, rf.name,
|
||||||
|
@ -464,8 +466,7 @@ struct ReductionIdentifier : public util::Operator {
|
||||||
|
|
||||||
if (rf.method) {
|
if (rf.method) {
|
||||||
std::vector<Value *> opChain;
|
std::vector<Value *> opChain;
|
||||||
extractAssociativeOpChain(callRHS, rf.name, callRHS->front()->getType(),
|
extractAssociativeOpChain(callRHS, rf.name, type, opChain);
|
||||||
opChain);
|
|
||||||
if (opChain.size() < 2)
|
if (opChain.size() < 2)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
|
|
@ -38,16 +38,21 @@ bool isCallOf(const Value *value, const std::string &name,
|
||||||
|
|
||||||
unsigned i = 0;
|
unsigned i = 0;
|
||||||
for (auto *arg : *call) {
|
for (auto *arg : *call) {
|
||||||
if (!arg->getType()->is(inputs[i++]))
|
if (inputs[i] && !arg->getType()->is(inputs[i]))
|
||||||
return false;
|
return false;
|
||||||
|
++i;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (output && !value->getType()->is(output))
|
if (output && !value->getType()->is(output))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
if (method &&
|
if (method) {
|
||||||
(inputs.empty() || !fn->getParentType() || !fn->getParentType()->is(inputs[0])))
|
if (inputs.empty() || !fn->getParentType())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
if (inputs[0] && !fn->getParentType()->is(inputs[0]))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -450,6 +450,18 @@ def test_omp_reductions():
|
||||||
c = min(b, c)
|
c = min(b, c)
|
||||||
assert c == -1.
|
assert c == -1.
|
||||||
|
|
||||||
|
c = 0.
|
||||||
|
@par
|
||||||
|
for i in L:
|
||||||
|
c += i # float-int op
|
||||||
|
assert c == expected(N, 0., float.__add__)
|
||||||
|
|
||||||
|
c = 0.
|
||||||
|
@par
|
||||||
|
for i in L:
|
||||||
|
c = i + c # int-float op
|
||||||
|
assert c == expected(N, 0., float.__add__)
|
||||||
|
|
||||||
# float32s
|
# float32s
|
||||||
c = f32(0.)
|
c = f32(0.)
|
||||||
# this one can give different results due to
|
# this one can give different results due to
|
||||||
|
@ -479,6 +491,18 @@ def test_omp_reductions():
|
||||||
c = min(b, c)
|
c = min(b, c)
|
||||||
assert c == f32(-1.)
|
assert c == f32(-1.)
|
||||||
|
|
||||||
|
c = f32(0.)
|
||||||
|
@par
|
||||||
|
for i in L[:12]:
|
||||||
|
c += i # float-int op
|
||||||
|
assert c == f32(1+2+3+4+5+6+7+8+9+10+11)
|
||||||
|
|
||||||
|
c = f32(0.)
|
||||||
|
@par
|
||||||
|
for i in L[:12]:
|
||||||
|
c = i + c # int-float op
|
||||||
|
assert c == f32(1+2+3+4+5+6+7+8+9+10+11)
|
||||||
|
|
||||||
x_add = 10.
|
x_add = 10.
|
||||||
x_min = inf
|
x_min = inf
|
||||||
x_max = -inf
|
x_max = -inf
|
||||||
|
|
Loading…
Reference in New Issue