// NOTE(review): in this copy of the file every template argument list and the
// system header names of the #include directives appear to have been stripped
// (e.g. `std::vector` with no element type; `std::make_shared(...)`, `Nr(...)`
// and `cast(...)` with no explicit type argument; bare `#include` lines).
// As shown this will not compile -- restore the missing text from version
// control before building.
#include "module.h"

#include
#include

#include "parser/cache.h"

#include "func.h"

namespace seq {
namespace ir {
namespace {
/// Converts IR generic arguments into the type checker's AST representation.
/// A static generic is converted from its static value; any other generic
/// contributes the AST type of its type value. Each converted value is then
/// wrapped in a newly allocated shared object and appended to the result.
/// @param generics the IR generics; each must be static or carry a type value
///                 (checked with seqassert)
/// @return one translated entry per input generic, in input order
std::vector translateGenerics(std::vector &generics) {
  std::vector ret;
  for (auto &g : generics) {
    seqassert(g.isStatic() || g.getTypeValue(), "generic must be static or a type");
    ret.push_back(std::make_shared(
        g.isStatic() ? std::make_shared(g.getStaticValue())
                     : g.getTypeValue()->getAstType()));
  }
  return ret;
}

/// Builds the argument list handed to cache->findMethod() by
/// getOrRealizeMethod(): one (name, AST type) entry per argument, where the
/// name is always the empty string.
/// @param types IR argument types; each must have an AST type (seqassert)
std::vector> generateDummyNames(std::vector &types) {
  std::vector> ret;
  for (auto *t : types) {
    seqassert(t->getAstType(), "{} must have an ast type", *t);
    ret.emplace_back("", t->getAstType());
  }
  return ret;
}

/// Translates IR argument types into the AST type list passed to the cache's
/// realizeFunction()/makeFunction().
/// Slot 0 is a placeholder constructed from (LinkType::Kind::Unbound, 0);
/// getFuncType() overwrites this slot with the return type. Slots 1..N hold
/// the arguments in order.
/// An argument whose AST type is a function type is not passed through
/// unchanged: it is rebuilt from its AST record, the function, and a mask
/// containing one zero per element of the IR function type's [begin, end)
/// range (the same range getFuncType() treats as the argument types).
/// @param types IR argument types; each must have an AST type (seqassert)
std::vector translateArgs(std::vector &types) {
  std::vector ret = {
      std::make_shared(seq::ast::types::LinkType::Kind::Unbound, 0)};
  for (auto *t : types) {
    seqassert(t->getAstType(), "{} must have an ast type", *t);
    if (auto f = t->getAstType()->getFunc()) {
      auto *irType = cast(t);
      // NOTE(review): what an all-zero mask means is defined by the type
      // checker, not in this file -- confirm against the parser's type classes.
      std::vector mask(std::distance(irType->begin(), irType->end()), 0);
      ret.push_back(std::make_shared(t->getAstType()->getRecord(), f, mask));
    } else {
      ret.push_back(t->getAstType());
    }
  }
  return ret;
}
} // namespace

// Names under which the built-in types are looked up by the get*Type()
// accessors below.
const std::string Module::VOID_NAME = "void";
const std::string Module::BOOL_NAME = "bool";
const std::string Module::BYTE_NAME = "byte";
const std::string Module::INT_NAME = "int";
const std::string Module::FLOAT_NAME = "float";
const std::string Module::STRING_NAME = "str";

// Magic (double-underscore) method names.
const std::string Module::EQ_MAGIC_NAME = "__eq__";
const std::string Module::NE_MAGIC_NAME = "__ne__";
const std::string Module::LT_MAGIC_NAME = "__lt__";
const std::string Module::GT_MAGIC_NAME = "__gt__";
const std::string Module::LE_MAGIC_NAME = "__le__";
const std::string Module::GE_MAGIC_NAME = "__ge__";
const std::string Module::POS_MAGIC_NAME = "__pos__";
const std::string Module::NEG_MAGIC_NAME = "__neg__";
const std::string Module::INVERT_MAGIC_NAME = "__invert__";
const std::string Module::ADD_MAGIC_NAME = "__add__";
const std::string Module::SUB_MAGIC_NAME = "__sub__";
const std::string Module::MUL_MAGIC_NAME = "__mul__";
const std::string Module::MATMUL_MAGIC_NAME = "__matmul__";
const std::string Module::TRUE_DIV_MAGIC_NAME = "__truediv__";
const std::string Module::FLOOR_DIV_MAGIC_NAME = "__floordiv__";
const std::string Module::MOD_MAGIC_NAME = "__mod__";
const std::string Module::POW_MAGIC_NAME = "__pow__";
const std::string Module::LSHIFT_MAGIC_NAME = "__lshift__";
const std::string Module::RSHIFT_MAGIC_NAME = "__rshift__";
const std::string Module::AND_MAGIC_NAME = "__and__";
const std::string Module::OR_MAGIC_NAME = "__or__";
const std::string Module::XOR_MAGIC_NAME = "__xor__";
const std::string Module::INT_MAGIC_NAME = "__int__";
const std::string Module::FLOAT_MAGIC_NAME = "__float__";
const std::string Module::BOOL_MAGIC_NAME = "__bool__";
const std::string Module::STR_MAGIC_NAME = "__str__";
const std::string Module::GETITEM_MAGIC_NAME = "__getitem__";
const std::string Module::SETITEM_MAGIC_NAME = "__setitem__";
const std::string Module::ITER_MAGIC_NAME = "__iter__";
const std::string Module::LEN_MAGIC_NAME = "__len__";
const std::string Module::NEW_MAGIC_NAME = "__new__";
const std::string Module::INIT_MAGIC_NAME = "__init__";

// Out-of-line definition of the static NodeId member.
// NOTE(review): presumably only its address is meaningful (a per-class
// identity tag) -- confirm against the node/RTTI machinery.
const char Module::NodeId = 0;

/// Creates a module from `name` and the type checker cache `cache`.
/// Also creates:
///  - the "main" function, realized with the dummy function type (empty name,
///    void return, no arguments -- see unsafeGetDummyFuncType());
///  - the "argv" variable, typed as an array of str.
/// Both are attached to this module and marked non-replaceable.
Module::Module(std::string name, std::shared_ptr cache)
    : AcceptorExtend(std::move(name)), cache(std::move(cache)) {
  mainFunc = std::make_unique("main");
  mainFunc->realize(cast(unsafeGetDummyFuncType()), {});
  mainFunc->setModule(this);
  mainFunc->setReplaceable(false);
  // NOTE(review): the meaning of the `true` argument is defined by the
  // variable's constructor (not visible here) -- presumably "global"; confirm.
  argVar = std::make_unique(unsafeGetArrayType(getStringType()), true, "argv");
  argVar->setModule(this);
  argVar->setReplaceable(false);
}

/// Looks up method `methodName` on the AST class of `parent`, matching the AST
/// types of `args`. It then realizes the method with the translated arguments
/// and generics.
/// @return the realized function, or nullptr if no matching method is found
///         or realization throws exc::ParserException (the error is logged)
/// NOTE(review): neither parent->getAstType() nor the result of getClass() is
/// null-checked before being dereferenced / passed to findMethod().
Func *Module::getOrRealizeMethod(types::Type *parent, const std::string &methodName,
                                 std::vector args, std::vector generics) {
  auto cls = std::const_pointer_cast(parent->getAstType())->getClass();
  auto method = cache->findMethod(cls.get(), methodName, generateDummyNames(args));
  if (!method)
    return nullptr;
  try {
    return cache->realizeFunction(method, translateArgs(args),
                                  translateGenerics(generics), cls);
  } catch (const exc::ParserException &e) {
    LOG_IR("getOrRealizeMethod parser error: {}", e.what());
    return nullptr;
  }
}

/// Looks up function `funcName` and realizes it with the given argument types
/// and generics. The name is qualified as "module.funcName" when `module` is
/// non-empty.
/// @return the realized function, or nullptr if the function is not found or
///         realization throws exc::ParserException (the error is logged)
Func *Module::getOrRealizeFunc(const std::string &funcName, std::vector args,
                               std::vector generics, const std::string &module) {
  auto fqName =
      module.empty() ? funcName : fmt::format(FMT_STRING("{}.{}"), module, funcName);
  auto func = cache->findFunction(fqName);
  if (!func)
    return nullptr;
  auto arg = translateArgs(args);
  auto gens = translateGenerics(generics);
  try {
    return cache->realizeFunction(func, arg, gens);
  } catch (const exc::ParserException &e) {
    LOG_IR("getOrRealizeFunc parser error: {}", e.what());
    return nullptr;
  }
}

/// Looks up class `typeName` and realizes it with the given generics. The name
/// is qualified as "module.typeName" when `module` is non-empty.
/// @return the realized type, or nullptr if the class is not found or
///         realization throws exc::ParserException (the error is logged)
types::Type *Module::getOrRealizeType(const std::string &typeName, std::vector generics,
                                      const std::string &module) {
  auto fqName =
      module.empty() ? typeName : fmt::format(FMT_STRING("{}.{}"), module, typeName);
  auto type = cache->findClass(fqName);
  if (!type)
    return nullptr;
  try {
    return cache->realizeType(type, translateGenerics(generics));
  } catch (const exc::ParserException &e) {
    LOG_IR("getOrRealizeType parser error: {}", e.what());
    return nullptr;
  }
}

/// Returns the type found by getType(VOID_NAME), or otherwise a new one
/// created with Nr(). getBoolType/getByteType/getIntType/getFloatType below
/// follow the same lookup-then-create pattern with their own names.
types::Type *Module::getVoidType() {
  if (auto *rVal = getType(VOID_NAME))
    return rVal;
  return Nr();
}

types::Type *Module::getBoolType() {
  if (auto *rVal = getType(BOOL_NAME))
    return rVal;
  return Nr();
}

types::Type *Module::getByteType() {
  if (auto *rVal = getType(BYTE_NAME))
    return rVal;
  return Nr();
}

types::Type *Module::getIntType() {
  if (auto *rVal = getType(INT_NAME))
    return rVal;
  return Nr();
}

types::Type *Module::getFloatType() {
  if (auto *rVal = getType(FLOAT_NAME))
    return rVal;
  return Nr();
}

/// Returns the type found by getType(STRING_NAME). Otherwise it creates one
/// named STRING_NAME with two members: "len" (int) and "ptr" (pointer to
/// byte).
types::Type *Module::getStringType() {
  if (auto *rVal = getType(STRING_NAME))
    return rVal;
  return Nr(
      STRING_NAME,
      std::vector{getIntType(), unsafeGetPointerType(getByteType())},
      std::vector{"len", "ptr"});
}

/// Realizes the library type "Ptr" with `base` as its generic argument.
/// Like the Array/Generator/Optional accessors below, this returns nullptr if
/// getOrRealizeType() cannot find or realize the type.
types::Type *Module::getPointerType(types::Type *base) {
  return getOrRealizeType("Ptr", {base});
}

types::Type *Module::getArrayType(types::Type *base) {
  return getOrRealizeType("Array", {base});
}

types::Type *Module::getGeneratorType(types::Type *base) {
  return getOrRealizeType("Generator", {base});
}

types::Type *Module::getOptionalType(types::Type *base) {
  return getOrRealizeType("Optional", {base});
}

/// Returns the function type taking `argTypes` and returning `rType`, as built
/// by the type checker cache. The return type replaces slot 0 of the
/// translated argument list.
/// When `variadic` is set, the result is instead an IR function type named
/// "<name>$variadic" (reused if it already exists). It has the same return
/// type and argument types, and its AST type is set to that of the
/// non-variadic type.
types::Type *Module::getFuncType(types::Type *rType, std::vector argTypes,
                                 bool variadic) {
  auto args = translateArgs(argTypes);
  args[0] = std::make_shared(rType->getAstType());
  auto *result = cache->makeFunction(args);
  if (variadic) {
    // Type checker types have no concept of variadic functions, so we will
    // create a new IR type here with the same AST type.
    auto *f = cast(result);
    result = unsafeGetFuncType(f->getName() + "$variadic", f->getReturnType(),
                               std::vector(f->begin(), f->end()),
                               /*variadic=*/true);
    result->setAstType(f->getAstType());
  }
  return result;
}

/// Realizes "Int" (when `sign` is true) or "UInt" (otherwise) with `len` as
/// its generic argument; nullptr on failure (see getOrRealizeType()).
types::Type *Module::getIntNType(unsigned int len, bool sign) {
  return getOrRealizeType(sign ? "Int" : "UInt", {len});
}

/// Returns the tuple type of `args`, built by the type checker cache from
/// their AST types. Each element must have an AST type (seqassert).
types::Type *Module::getTupleType(std::vector args) {
  std::vector argTypes;
  for (auto *t : args) {
    seqassert(t->getAstType(), "{} must have an ast type", *t);
    argTypes.push_back(t->getAstType());
  }
  return cache->makeTuple(argTypes);
}

// Constant factories: each wraps `v` in a new node typed with the matching
// built-in type (int, float, bool, str).
Value *Module::getInt(int64_t v) { return Nr(v, getIntType()); }

Value *Module::getFloat(double v) { return Nr(v, getFloatType()); }

Value *Module::getBool(bool v) { return Nr(v, getBoolType()); }

Value *Module::getString(std::string v) {
  return Nr(std::move(v), getStringType());
}

/// Function type with an empty name, void return and no arguments. The
/// constructor uses it to realize "main".
types::Type *Module::unsafeGetDummyFuncType() {
  return unsafeGetFuncType("", getVoidType(), {});
}

// The unsafeGet*Type accessors look a type up by name and otherwise construct
// the IR type directly with Nr(). Unlike the get*Type counterparts that call
// getOrRealizeType(), they do not go through the type checker cache.

/// Pointer-to-`base` type, keyed by types::PointerType::getInstanceName(base).
types::Type *Module::unsafeGetPointerType(types::Type *base) {
  auto name = types::PointerType::getInstanceName(base);
  if (auto *rVal = getType(name))
    return rVal;
  return Nr(base);
}

/// Array-of-`base` type named ".Array[<base->referenceString()>]". It has two
/// members: "len" (int) and "ptr" (pointer to `base`).
/// NOTE(review): unlike its siblings, the name is formatted inline rather than
/// via a getInstanceName() helper.
types::Type *Module::unsafeGetArrayType(types::Type *base) {
  auto name = fmt::format(FMT_STRING(".Array[{}]"), base->referenceString());
  if (auto *rVal = getType(name))
    return rVal;
  std::vector members = {getIntType(), unsafeGetPointerType(base)};
  std::vector names = {"len", "ptr"};
  return Nr(name, members, names);
}

/// Generator-of-`base` type, keyed by types::GeneratorType::getInstanceName.
types::Type *Module::unsafeGetGeneratorType(types::Type *base) {
  auto name = types::GeneratorType::getInstanceName(base);
  if (auto *rVal = getType(name))
    return rVal;
  return Nr(base);
}

/// Optional-of-`base` type, keyed by types::OptionalType::getInstanceName.
types::Type *Module::unsafeGetOptionalType(types::Type *base) {
  auto name = types::OptionalType::getInstanceName(base);
  if (auto *rVal = getType(name))
    return rVal;
  return Nr(base);
}

/// Function type named `name`. If a type with that name already exists, it is
/// returned unchanged; the other arguments are then ignored.
types::Type *Module::unsafeGetFuncType(const std::string &name, types::Type *rType,
                                       std::vector argTypes, bool variadic) {
  if (auto *rVal = getType(name))
    return rVal;
  return Nr(name, rType, std::move(argTypes), variadic);
}

/// Returns the type named `name`, creating it if absent.
/// When `ref` is true, the new type wraps a type named "<name>.contents",
/// which is itself looked up first and created if missing.
types::Type *Module::unsafeGetMemberedType(const std::string &name, bool ref) {
  auto *rVal = getType(name);
  if (!rVal) {
    if (ref) {
      auto contentName = name + ".contents";
      auto *record = getType(contentName);
      if (!record) {
        record = Nr(contentName);
      }
      rVal = Nr(name, cast(record));
    } else {
      rVal = Nr(name);
    }
  }
  return rVal;
}

/// Fixed-width integer type keyed by types::IntNType::getInstanceName(len, sign).
types::Type *Module::unsafeGetIntNType(unsigned int len, bool sign) {
  auto name = types::IntNType::getInstanceName(len, sign);
  if (auto *rVal = getType(name))
    return rVal;
  return Nr(len, sign);
}
} // namespace ir
} // namespace seq