Update OpenMP reduction detection for new ops

ptr-field-support
A. R. Shajii 2025-02-07 12:04:12 -05:00
parent 56c00d36c2
commit b58b1ee767
3 changed files with 38 additions and 8 deletions

View File

@ -402,7 +402,8 @@ struct ReductionIdentifier : public util::Operator {
static void extractAssociativeOpChain(Value *v, const std::string &op, static void extractAssociativeOpChain(Value *v, const std::string &op,
types::Type *type, types::Type *type,
std::vector<Value *> &result) { std::vector<Value *> &result) {
if (util::isCallOf(v, op, {type, type}, type, /*method=*/true)) { if (util::isCallOf(v, op, {type, nullptr}, type, /*method=*/true) ||
util::isCallOf(v, op, {nullptr, type}, type, /*method=*/true)) {
auto *call = cast<CallInstr>(v); auto *call = cast<CallInstr>(v);
extractAssociativeOpChain(call->front(), op, type, result); extractAssociativeOpChain(call->front(), op, type, result);
extractAssociativeOpChain(call->back(), op, type, result); extractAssociativeOpChain(call->back(), op, type, result);
@ -450,7 +451,8 @@ struct ReductionIdentifier : public util::Operator {
for (auto &rf : reductionFunctions) { for (auto &rf : reductionFunctions) {
if (rf.method) { if (rf.method) {
if (!util::isCallOf(item, rf.name, {type, type}, type, /*method=*/true)) if (!(util::isCallOf(item, rf.name, {type, nullptr}, type, /*method=*/true) ||
util::isCallOf(item, rf.name, {nullptr, type}, type, /*method=*/true)))
continue; continue;
} else { } else {
if (!util::isCallOf(item, rf.name, if (!util::isCallOf(item, rf.name,
@ -464,8 +466,7 @@ struct ReductionIdentifier : public util::Operator {
if (rf.method) { if (rf.method) {
std::vector<Value *> opChain; std::vector<Value *> opChain;
extractAssociativeOpChain(callRHS, rf.name, callRHS->front()->getType(), extractAssociativeOpChain(callRHS, rf.name, type, opChain);
opChain);
if (opChain.size() < 2) if (opChain.size() < 2)
continue; continue;

View File

@ -38,16 +38,21 @@ bool isCallOf(const Value *value, const std::string &name,
unsigned i = 0; unsigned i = 0;
for (auto *arg : *call) { for (auto *arg : *call) {
if (!arg->getType()->is(inputs[i++])) if (inputs[i] && !arg->getType()->is(inputs[i]))
return false; return false;
++i;
} }
if (output && !value->getType()->is(output)) if (output && !value->getType()->is(output))
return false; return false;
if (method && if (method) {
(inputs.empty() || !fn->getParentType() || !fn->getParentType()->is(inputs[0]))) if (inputs.empty() || !fn->getParentType())
return false; return false;
if (inputs[0] && !fn->getParentType()->is(inputs[0]))
return false;
}
return true; return true;
} }

View File

@ -450,6 +450,18 @@ def test_omp_reductions():
c = min(b, c) c = min(b, c)
assert c == -1. assert c == -1.
c = 0.
@par
for i in L:
c += i # float-int op
assert c == expected(N, 0., float.__add__)
c = 0.
@par
for i in L:
c = i + c # int-float op
assert c == expected(N, 0., float.__add__)
# float32s # float32s
c = f32(0.) c = f32(0.)
# this one can give different results due to # this one can give different results due to
@ -479,6 +491,18 @@ def test_omp_reductions():
c = min(b, c) c = min(b, c)
assert c == f32(-1.) assert c == f32(-1.)
c = f32(0.)
@par
for i in L[:12]:
c += i # float-int op
assert c == f32(1+2+3+4+5+6+7+8+9+10+11)
c = f32(0.)
@par
for i in L[:12]:
c = i + c # int-float op
assert c == f32(1+2+3+4+5+6+7+8+9+10+11)
x_add = 10. x_add = 10.
x_min = inf x_min = inf
x_max = -inf x_max = -inf