mirror of https://github.com/exaloop/codon.git
Update OpenMP reduction detection for new ops
parent
56c00d36c2
commit
b58b1ee767
|
@ -402,7 +402,8 @@ struct ReductionIdentifier : public util::Operator {
|
|||
static void extractAssociativeOpChain(Value *v, const std::string &op,
|
||||
types::Type *type,
|
||||
std::vector<Value *> &result) {
|
||||
if (util::isCallOf(v, op, {type, type}, type, /*method=*/true)) {
|
||||
if (util::isCallOf(v, op, {type, nullptr}, type, /*method=*/true) ||
|
||||
util::isCallOf(v, op, {nullptr, type}, type, /*method=*/true)) {
|
||||
auto *call = cast<CallInstr>(v);
|
||||
extractAssociativeOpChain(call->front(), op, type, result);
|
||||
extractAssociativeOpChain(call->back(), op, type, result);
|
||||
|
@ -450,7 +451,8 @@ struct ReductionIdentifier : public util::Operator {
|
|||
|
||||
for (auto &rf : reductionFunctions) {
|
||||
if (rf.method) {
|
||||
if (!util::isCallOf(item, rf.name, {type, type}, type, /*method=*/true))
|
||||
if (!(util::isCallOf(item, rf.name, {type, nullptr}, type, /*method=*/true) ||
|
||||
util::isCallOf(item, rf.name, {nullptr, type}, type, /*method=*/true)))
|
||||
continue;
|
||||
} else {
|
||||
if (!util::isCallOf(item, rf.name,
|
||||
|
@ -464,8 +466,7 @@ struct ReductionIdentifier : public util::Operator {
|
|||
|
||||
if (rf.method) {
|
||||
std::vector<Value *> opChain;
|
||||
extractAssociativeOpChain(callRHS, rf.name, callRHS->front()->getType(),
|
||||
opChain);
|
||||
extractAssociativeOpChain(callRHS, rf.name, type, opChain);
|
||||
if (opChain.size() < 2)
|
||||
continue;
|
||||
|
||||
|
|
|
@ -38,16 +38,21 @@ bool isCallOf(const Value *value, const std::string &name,
|
|||
|
||||
unsigned i = 0;
|
||||
for (auto *arg : *call) {
|
||||
if (!arg->getType()->is(inputs[i++]))
|
||||
if (inputs[i] && !arg->getType()->is(inputs[i]))
|
||||
return false;
|
||||
++i;
|
||||
}
|
||||
|
||||
if (output && !value->getType()->is(output))
|
||||
return false;
|
||||
|
||||
if (method &&
|
||||
(inputs.empty() || !fn->getParentType() || !fn->getParentType()->is(inputs[0])))
|
||||
return false;
|
||||
if (method) {
|
||||
if (inputs.empty() || !fn->getParentType())
|
||||
return false;
|
||||
|
||||
if (inputs[0] && !fn->getParentType()->is(inputs[0]))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -450,6 +450,18 @@ def test_omp_reductions():
|
|||
c = min(b, c)
|
||||
assert c == -1.
|
||||
|
||||
c = 0.
|
||||
@par
|
||||
for i in L:
|
||||
c += i # float-int op
|
||||
assert c == expected(N, 0., float.__add__)
|
||||
|
||||
c = 0.
|
||||
@par
|
||||
for i in L:
|
||||
c = i + c # int-float op
|
||||
assert c == expected(N, 0., float.__add__)
|
||||
|
||||
# float32s
|
||||
c = f32(0.)
|
||||
# this one can give different results due to
|
||||
|
@ -479,6 +491,18 @@ def test_omp_reductions():
|
|||
c = min(b, c)
|
||||
assert c == f32(-1.)
|
||||
|
||||
c = f32(0.)
|
||||
@par
|
||||
for i in L[:12]:
|
||||
c += i # float-int op
|
||||
assert c == f32(1+2+3+4+5+6+7+8+9+10+11)
|
||||
|
||||
c = f32(0.)
|
||||
@par
|
||||
for i in L[:12]:
|
||||
c = i + c # int-float op
|
||||
assert c == f32(1+2+3+4+5+6+7+8+9+10+11)
|
||||
|
||||
x_add = 10.
|
||||
x_min = inf
|
||||
x_max = -inf
|
||||
|
|
Loading…
Reference in New Issue