1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
2021-09-27 14:02:44 -04:00

148 lines
4.5 KiB
C++

#include "dict.h"
#include <algorithm>
#include "sir/util/cloning.h"
#include "sir/util/irtools.h"
#include "sir/util/matching.h"
namespace seq {
namespace ir {
namespace transform {
namespace pythonic {
namespace {
/// get or __getitem__ call metadata
struct GetCall {
/// the function, nullptr if not a get call
Func *func = nullptr;
/// the dictionary, must not be a call
Value *dict = nullptr;
/// the key, must not be a call
Value *key = nullptr;
/// the default value, may be null
Const *dflt = nullptr;
};
/// Identify the call and return its metadata.
/// @param call the call
/// @return the metadata
GetCall analyzeGet(CallInstr *call) {
// extract the function
auto *func = util::getFunc(call->getCallee());
if (!func)
return {};
auto unmangled = func->getUnmangledName();
// canonical get/__getitem__ calls have at least two arguments
auto it = call->begin();
auto dist = std::distance(it, call->end());
if (dist < 2)
return {};
// extract the dictionary and keys
auto *dict = *it++;
auto *k = *it++;
// dictionary and key must not be calls
if (isA<CallInstr>(dict) || isA<CallInstr>(k))
return {};
// get calls have a default
if (unmangled == "get" && std::distance(it, call->end()) == 1) {
auto *dflt = cast<Const>(*it);
return {func, dict, k, dflt};
} else if (unmangled == "__getitem__" && std::distance(it, call->end()) == 0) {
return {func, dict, k, nullptr};
}
// call is not correct
return {};
}
} // namespace
const std::string DictArithmeticOptimization::KEY = "core-pythonic-dict-arithmetic-opt";
void DictArithmeticOptimization::handle(CallInstr *v) {
auto *M = v->getModule();
// get and check the exterior function (should be a __setitem__ with 3 args)
auto *setFunc = util::getFunc(v->getCallee());
if (setFunc && setFunc->getUnmangledName() == "__setitem__" &&
std::distance(v->begin(), v->end()) == 3) {
auto it = v->begin();
// extract all the arguments to the function
// the dictionary and key must not be calls, and the value must
// be a call
auto *dictValue = *it++;
auto *keyValue = *it++;
if (isA<CallInstr>(dictValue) || isA<CallInstr>(keyValue))
return;
auto *opCall = cast<CallInstr>(*it++);
// the call must take exactly two arguments
if (!dictValue || !opCall || std::distance(opCall->begin(), opCall->end()) != 2)
return;
// grab the function, which needs to be an int or float call for now
auto *opFunc = util::getFunc(opCall->getCallee());
auto *getCall = cast<CallInstr>(opCall->front());
if (!opFunc || !getCall)
return;
auto *intType = M->getIntType();
auto *floatType = M->getFloatType();
auto *parentType = opFunc->getParentType();
if (!parentType || !(parentType->is(intType) || parentType->is(floatType)))
return;
// check the first argument
auto getAnalysis = analyzeGet(getCall);
if (!getAnalysis.func)
return;
// second argument can be any non-null value
auto *secondValue = opCall->back();
// verify that we are dealing with the same dictionary and key
if (util::match(dictValue, getAnalysis.dict, false, true) &&
util::match(keyValue, getAnalysis.key, false, true)) {
util::CloneVisitor cv(M);
Func *replacementFunc;
// call non-throwing version if we have a default
if (getAnalysis.dflt) {
replacementFunc = M->getOrRealizeMethod(
dictValue->getType(), "__dict_do_op__",
{dictValue->getType(), keyValue->getType(), secondValue->getType(),
getAnalysis.dflt->getType(), opFunc->getType()});
} else {
replacementFunc =
M->getOrRealizeMethod(dictValue->getType(), "__dict_do_op_throws__",
{dictValue->getType(), keyValue->getType(),
secondValue->getType(), opFunc->getType()});
}
if (replacementFunc) {
std::vector<Value *> args = {cv.clone(dictValue), cv.clone(keyValue),
cv.clone(secondValue)};
if (getAnalysis.dflt)
args.push_back(cv.clone(getAnalysis.dflt));
// sanity check to make sure function is inlined
if (args.size() !=
std::distance(replacementFunc->arg_begin(), replacementFunc->arg_end()))
args.push_back(M->N<VarValue>(v, opFunc));
v->replaceAll(util::call(replacementFunc, args));
}
}
}
}
} // namespace pythonic
} // namespace transform
} // namespace ir
} // namespace seq