diff --git a/CMakeLists.txt b/CMakeLists.txt index fa87c24c..ee50b352 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,8 +16,16 @@ if(CODON_JUPYTER) endif() set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility-inlines-hidden -pedantic -Wno-return-type-c-linkage -Wno-gnu-zero-variadic-macro-arguments") -set(CMAKE_CXX_FLAGS_DEBUG "-g -fno-limit-debug-info") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") +if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pedantic -fvisibility-inlines-hidden -Wno-return-type-c-linkage -Wno-gnu-zero-variadic-macro-arguments") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-return-type") +endif() +set(CMAKE_CXX_FLAGS_DEBUG "-g") +if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fno-limit-debug-info") +endif() set(CMAKE_CXX_FLAGS_RELEASE "-O3") include_directories(.) @@ -74,8 +82,8 @@ else() -Wl,--no-whole-archive) endif() if(ASAN) - target_compile_options(codonrt PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address") - target_link_libraries(codonrt PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address") + target_compile_options(codonrt PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address" "-fsanitize-recover=address") + target_link_libraries(codonrt PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address" "-fsanitize-recover=address") endif() add_custom_command(TARGET codonrt POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${CMAKE_BINARY_DIR}) @@ -209,7 +217,7 @@ set(CODON_HPPFILES codon/util/toml++/toml_node.h codon/util/toml++/toml_parser.hpp codon/util/toml++/toml_utf8_streams.h - extra/jupyter/src/codon.h) + extra/jupyter/jupyter.h) set(CODON_CPPFILES codon/compiler/compiler.cpp codon/compiler/debug_listener.cpp @@ -288,7 +296,7 @@ set(CODON_CPPFILES codon/sir/var.cpp codon/util/common.cpp codon/util/fmt/format.cpp - extra/jupyter/src/codon.cpp) + extra/jupyter/jupyter.cpp) add_library(codonc SHARED ${CODON_HPPFILES}) target_sources(codonc PRIVATE ${CODON_CPPFILES} codon_rules.cpp omp_rules.cpp) if(CODON_JUPYTER) @@ -297,8 +305,8 @@ if(CODON_JUPYTER) target_link_libraries(codonc PRIVATE xeus-static) endif() if(ASAN) - target_compile_options(codonc PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address") - target_link_libraries(codonc PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address") + target_compile_options(codonc PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address" "-fsanitize-recover=address") + target_link_libraries(codonc PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address" "-fsanitize-recover=address") endif() if(CMAKE_BUILD_TYPE MATCHES Debug) set_source_files_properties(codon_rules.cpp codon/parser/peg/peg.cpp PROPERTIES COMPILE_FLAGS "-O2") diff --git a/cmake/deps.cmake b/cmake/deps.cmake index dfa45958..31a3efde 100644 --- a/cmake/deps.cmake +++ b/cmake/deps.cmake @@ -66,10 +66,11 @@ if(bdwgc_ADDED) endif() CPMAddPackage( - GITHUB_REPOSITORY "llvm-mirror/openmp" - VERSION 9.0 - GIT_TAG release_90 - OPTIONS "OPENMP_ENABLE_LIBOMPTARGET OFF" + NAME openmp + GITHUB_REPOSITORY "exaloop/openmp" + VERSION 13.0.0-patch1 + OPTIONS "CMAKE_BUILD_TYPE Release" + "OPENMP_ENABLE_LIBOMPTARGET OFF" "OPENMP_STANDALONE_BUILD ON") CPMAddPackage( @@ -125,8 +126,8 @@ if(CODON_JUPYTER) NAME libzmq VERSION 4.3.4 URL https://github.com/zeromq/libzmq/releases/download/v4.3.4/zeromq-4.3.4.tar.gz - OPTIONS "WITH_PERF_TOOL OFF" - "ZMQ_BUILD_TESTS OFF" + OPTIONS "WITH_PERF_TOOL OFF" + "ZMQ_BUILD_TESTS OFF" "ENABLE_CPACK OFF" "BUILD_SHARED ON" "WITH_LIBSODIUM OFF") @@ -144,16 +145,17 @@ if(CODON_JUPYTER) CPMAddPackage( NAME json GITHUB_REPOSITORY "nlohmann/json" - VERSION 3.10.4) + VERSION 3.10.1) CPMAddPackage( NAME xeus GITHUB_REPOSITORY "jupyter-xeus/xeus" VERSION 2.2.0 GIT_TAG 2.2.0 - PATCH_COMMAND sed -i bak "s/-Wunused-parameter -Wextra -Wreorder//g" CMakeLists.txt + PATCH_COMMAND sed -ibak "s/-Wunused-parameter -Wextra -Wreorder//g" CMakeLists.txt OPTIONS "BUILD_EXAMPLES OFF" "XEUS_BUILD_SHARED_LIBS OFF" - "XEUS_STATIC_DEPENDENCIES ON") + "XEUS_STATIC_DEPENDENCIES ON" + "CMAKE_POSITION_INDEPENDENT_CODE ON") if (xeus_ADDED) install(TARGETS nlohmann_json EXPORT xeus-targets) endif() diff --git a/codon/app/main.cpp b/codon/app/main.cpp index ae830ccf..bb4b2a35 100644 --- a/codon/app/main.cpp +++ b/codon/app/main.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -55,6 +56,12 @@ void display(const codon::error::ParserErrorInfo &e) { } } +void initLogFlags(const llvm::cl::opt &log) { + codon::getLogger().parse(log); + if (auto *d = getenv("CODON_DEBUG")) + codon::getLogger().parse(std::string(d)); +} + enum BuildKind { LLVM, Bitcode, Object, Executable, Detect }; enum OptMode { Debug, Release }; } // namespace @@ -84,10 +91,11 @@ std::unique_ptr processSource(const std::vector & bool standalone) { llvm::cl::opt input(llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-")); + auto regs = llvm::cl::getRegisteredOptions(); llvm::cl::opt optMode( llvm::cl::desc("optimization mode"), llvm::cl::values( - clEnumValN(Debug, "debug", + clEnumValN(Debug, regs.find("debug") != regs.end() ? "default" : "debug", "Turn off compiler optimizations and show backtraces"), clEnumValN(Release, "release", "Turn on compiler optimizations and disable debug info")), @@ -102,9 +110,7 @@ std::unique_ptr processSource(const std::vector & llvm::cl::opt log("log", llvm::cl::desc("Enable given log streams")); llvm::cl::ParseCommandLineOptions(args.size(), args.data()); - codon::getLogger().parse(log); - if (auto *d = getenv("CODON_DEBUG")) - codon::getLogger().parse(std::string(d)); + initLogFlags(log); auto &exts = supportedExtensions(); if (input != "-" && std::find_if(exts.begin(), exts.end(), [&](auto &ext) { @@ -206,7 +212,9 @@ std::string jitExec(codon::jit::JIT *jit, const std::string &code) { int jitMode(const std::vector &args) { llvm::cl::list plugins("plugin", llvm::cl::desc("Load specified plugin")); + llvm::cl::opt log("log", llvm::cl::desc("Enable given log streams")); llvm::cl::ParseCommandLineOptions(args.size(), args.data()); + initLogFlags(log); codon::jit::JIT jit(args[0]); // load plugins @@ -230,13 +238,13 @@ int jitMode(const std::vector &args) { if (line != "#%%") { code += line + "\n"; } else { - fmt::print("{}\n\n[done]\n\n", jitExec(&jit, code)); + fmt::print("{}[done]\n", jitExec(&jit, code)); code = ""; fflush(stdout); } } if (!code.empty()) - fmt::print("{}\n\n[done]\n\n", jitExec(&jit, code)); + fmt::print("{}[done]\n", jitExec(&jit, code)); return EXIT_SUCCESS; } diff --git a/codon/compiler/compiler.cpp b/codon/compiler/compiler.cpp index 2c920274..f781b14f 100644 --- a/codon/compiler/compiler.cpp +++ b/codon/compiler/compiler.cpp @@ -1,7 +1,5 @@ #include "compiler.h" -#include - #include "codon/parser/cache.h" #include "codon/parser/peg/peg.h" #include "codon/parser/visitors/doc/doc.h" @@ -11,13 +9,29 @@ #include "codon/parser/visitors/typecheck/typecheck.h" namespace codon { +namespace { +ir::transform::PassManager::Init getPassManagerInit(Compiler::Mode mode, bool isTest) { + using ir::transform::PassManager; + switch (mode) { + case Compiler::Mode::DEBUG: + return isTest ? PassManager::Init::RELEASE : PassManager::Init::DEBUG; + case Compiler::Mode::RELEASE: + return PassManager::Init::RELEASE; + case Compiler::Mode::JIT: + return PassManager::Init::JIT; + default: + return PassManager::Init::EMPTY; + } +} +} // namespace -Compiler::Compiler(const std::string &argv0, bool debug, +Compiler::Compiler(const std::string &argv0, Compiler::Mode mode, const std::vector &disabledPasses, bool isTest) - : argv0(argv0), debug(debug), input(), plm(std::make_unique()), + : argv0(argv0), debug(mode == Mode::DEBUG), input(), + plm(std::make_unique()), cache(std::make_unique(argv0)), module(std::make_unique()), - pm(std::make_unique(debug && !isTest, + pm(std::make_unique(getPassManagerInit(mode, isTest), disabledPasses)), llvisitor(std::make_unique()) { cache->module = module.get(); @@ -50,9 +64,7 @@ Compiler::parse(bool isCode, const std::string &file, const std::string &code, int startLine, int testFlags, const std::unordered_map &defines) { input = file; - std::string abspath = - (file != "-") ? std::filesystem::absolute(std::filesystem::path(file)).string() - : file; + std::string abspath = (file != "-") ? ast::getAbsolutePath(file) : file; try { Timer t1("parse"); ast::StmtPtr codeStmt = isCode diff --git a/codon/compiler/compiler.h b/codon/compiler/compiler.h index 35a30d69..8d0a8001 100644 --- a/codon/compiler/compiler.h +++ b/codon/compiler/compiler.h @@ -15,6 +15,13 @@ namespace codon { class Compiler { +public: + enum Mode { + DEBUG, + RELEASE, + JIT, + }; + private: std::string argv0; bool debug; @@ -30,9 +37,14 @@ private: const std::unordered_map &defines); public: - Compiler(const std::string &argv0, bool debug = false, + Compiler(const std::string &argv0, Mode mode, const std::vector &disabledPasses = {}, bool isTest = false); + explicit Compiler(const std::string &argv0, bool debug = false, + const std::vector &disabledPasses = {}, + bool isTest = false) + : Compiler(argv0, debug ? Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest) {} + std::string getInput() const { return input; } PluginManager *getPluginManager() const { return plm.get(); } ast::Cache *getCache() const { return cache.get(); } diff --git a/codon/compiler/jit.cpp b/codon/compiler/jit.cpp index 34016480..881e5a82 100644 --- a/codon/compiler/jit.cpp +++ b/codon/compiler/jit.cpp @@ -18,7 +18,7 @@ const std::string JIT_FILENAME = ""; } // namespace JIT::JIT(const std::string &argv0, const std::string &mode) - : compiler(std::make_unique(argv0, /*debug=*/true)), mode(mode) { + : compiler(std::make_unique(argv0, Compiler::Mode::JIT)), mode(mode) { if (auto e = Engine::create()) { engine = std::move(e.get()); } else { @@ -44,7 +44,7 @@ llvm::Error JIT::init() { pm->run(module); module->accept(*llvisitor); - auto pair = llvisitor->takeModule(); + auto pair = llvisitor->takeModule(module); if (auto err = engine->addModule({std::move(pair.first), std::move(pair.second)})) return err; @@ -58,25 +58,22 @@ llvm::Error JIT::init() { return llvm::Error::success(); } -llvm::Expected JIT::run(const ir::Func *input, - const std::vector &newGlobals) { +llvm::Expected JIT::run(const ir::Func *input) { auto *module = compiler->getModule(); auto *pm = compiler->getPassManager(); auto *llvisitor = compiler->getLLVMVisitor(); + + Timer t1("jit/ir"); pm->run(module); + t1.log(); const std::string name = ir::LLVMVisitor::getNameForFunction(input); - llvisitor->registerGlobal(input); - for (auto *var : newGlobals) { - llvisitor->registerGlobal(var); - } - for (auto *var : newGlobals) { - if (auto *func = ir::cast(var)) - func->accept(*llvisitor); - } - input->accept(*llvisitor); - auto pair = llvisitor->takeModule(); + Timer t2("jit/llvm"); + auto pair = llvisitor->takeModule(module); + t2.log(); + + Timer t3("jit/engine"); if (auto err = engine->addModule({std::move(pair.first), std::move(pair.second)})) return std::move(err); @@ -85,6 +82,8 @@ llvm::Expected JIT::run(const ir::Func *input, return std::move(err); auto *repl = (InputFunc *)func->getAddress(); + t3.log(); + try { (*repl)(); } catch (const JITError &e) { @@ -138,11 +137,7 @@ llvm::Expected JIT::exec(const std::string &code) { auto *cache = compiler->getCache(); auto typechecked = ast::TypecheckVisitor::apply(cache, simplified); - std::vector globalNames; - for (auto &g : cache->globals) { - if (!g.second) - globalNames.push_back(g.first); - } + // add newly realized functions std::vector v; std::vector frs; @@ -155,16 +150,7 @@ llvm::Expected JIT::exec(const std::string &code) { ast::TranslateVisitor::apply(cache, std::make_shared(v, false)); cache->jitCell++; - std::vector globalVars; - for (auto &g : globalNames) { - seqassert(cache->globals[g], "JIT global {} not set", g); - globalVars.push_back(cache->globals[g]); - } - for (auto &i : frs) { - seqassert(*i, "JIT fn not set"); - globalVars.push_back(*i); - } - return run(func, globalVars); + return run(func); } catch (const exc::ParserException &e) { *cache = bCache; *(cache->imports[MAIN_IMPORT].ctx) = bSimplify; diff --git a/codon/compiler/jit.h b/codon/compiler/jit.h index b781299d..ab3e42a7 100644 --- a/codon/compiler/jit.h +++ b/codon/compiler/jit.h @@ -28,8 +28,7 @@ public: Engine *getEngine() const { return engine.get(); } llvm::Error init(); - llvm::Expected run(const ir::Func *input, - const std::vector &newGlobals = {}); + llvm::Expected run(const ir::Func *input); llvm::Expected exec(const std::string &code); }; diff --git a/codon/dsl/plugins.cpp b/codon/dsl/plugins.cpp index 07c1fb97..897aab68 100644 --- a/codon/dsl/plugins.cpp +++ b/codon/dsl/plugins.cpp @@ -1,12 +1,14 @@ #include "plugins.h" #include -#include #include "codon/parser/common.h" #include "codon/util/common.h" #include "codon/util/semver/semver.h" #include "codon/util/toml++/toml.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Path.h" namespace codon { namespace { @@ -17,8 +19,6 @@ llvm::Expected pluginError(const std::string &msg) { typedef std::unique_ptr LoadFunc(); } // namespace -namespace fs = std::filesystem; - llvm::Expected PluginManager::load(const std::string &path) { #if __APPLE__ const std::string libExt = "dylib"; @@ -27,35 +27,42 @@ llvm::Expected PluginManager::load(const std::string &path) { #endif const std::string config = "plugin.toml"; - fs::path tomlPath = fs::path(path) / config; - if (!fs::exists(tomlPath)) { + + llvm::SmallString<128> tomlPath(path); + llvm::sys::path::append(tomlPath, config); + if (!llvm::sys::fs::exists(tomlPath)) { // try default install path - if (auto *homeDir = std::getenv("HOME")) - tomlPath = fs::path(homeDir) / ".codon/plugins" / path / config; + if (auto *homeDir = std::getenv("HOME")) { + tomlPath = homeDir; + llvm::sys::path::append(tomlPath, ".codon", "plugins", path, config); + } } toml::parse_result tml; try { - tml = toml::parse_file(tomlPath.string()); + tml = toml::parse_file(tomlPath.str()); } catch (const toml::parse_error &e) { return pluginError( - fmt::format("[toml::parse_file(\"{}\")] {}", tomlPath.string(), e.what())); + fmt::format("[toml::parse_file(\"{}\")] {}", tomlPath.str(), e.what())); } auto about = tml["about"]; auto library = tml["library"]; std::string cppLib = library["cpp"].value_or(""); - std::string dylibPath; - if (!cppLib.empty()) - dylibPath = fs::path(tomlPath) - .replace_filename(library["cpp"].value_or("lib")) - .replace_extension(libExt) - .string(); + llvm::SmallString<128> dylibPath; + if (!cppLib.empty()) { + dylibPath = llvm::sys::path::parent_path(tomlPath); + auto fn = std::string(library["cpp"].value_or("lib")) + "." + libExt; + llvm::sys::path::append(dylibPath, fn); + } std::string codonLib = library["codon"].value_or(""); std::string stdlibPath; - if (!codonLib.empty()) - stdlibPath = fs::path(tomlPath).replace_filename(codonLib).string(); + if (!codonLib.empty()) { + llvm::SmallString<128> p = llvm::sys::path::parent_path(tomlPath.str()); + llvm::sys::path::append(p, codonLib); + stdlibPath = p.str(); + } DSL::Info info = {about["name"].value_or(""), about["description"].value_or(""), about["version"].value_or(""), about["url"].value_or(""), @@ -80,13 +87,13 @@ llvm::Expected PluginManager::load(const std::string &path) { &libLoadErrorMsg); if (!handle.isValid()) return pluginError(fmt::format( - "[llvm::sys::DynamicLibrary::getPermanentLibrary(\"{}\", ...)] {}", dylibPath, - libLoadErrorMsg)); + "[llvm::sys::DynamicLibrary::getPermanentLibrary(\"{}\", ...)] {}", + dylibPath.str(), libLoadErrorMsg)); auto *entry = (LoadFunc *)handle.getAddressOfSymbol("load"); if (!entry) - return pluginError( - fmt::format("could not find 'load' in plugin shared library: {}", dylibPath)); + return pluginError(fmt::format( + "could not find 'load' in plugin shared library: {}", dylibPath.str())); auto dsl = (*entry)(); plugins.push_back(std::make_unique(std::move(dsl), info, handle)); diff --git a/codon/parser/ast/expr.cpp b/codon/parser/ast/expr.cpp index 0f3e3dd2..b9e87b8c 100644 --- a/codon/parser/ast/expr.cpp +++ b/codon/parser/ast/expr.cpp @@ -18,7 +18,7 @@ namespace ast { Expr::Expr() : type(nullptr), isTypeExpr(false), staticValue(StaticValue::NOT_STATIC), - done(false) {} + done(false), attributes(0) {} types::TypePtr Expr::getType() const { return type; } void Expr::setType(types::TypePtr t) { this->type = std::move(t); } bool Expr::isType() const { return isTypeExpr; } @@ -31,6 +31,8 @@ std::string Expr::wrapType(const std::string &sexpr) const { done ? "*" : ""); } bool Expr::isStatic() const { return staticValue.type != StaticValue::NOT_STATIC; } +bool Expr::hasAttr(int attr) const { return (attributes & (1 << attr)); } +void Expr::setAttr(int attr) { attributes |= (1 << attr); } StaticValue::StaticValue(StaticValue::Type t) : value(), type(t), evaluated(false) {} StaticValue::StaticValue(int64_t i) : value(i), type(INT), evaluated(true) {} diff --git a/codon/parser/ast/expr.h b/codon/parser/ast/expr.h index df5716c4..b4f8a11c 100644 --- a/codon/parser/ast/expr.h +++ b/codon/parser/ast/expr.h @@ -78,6 +78,9 @@ struct Expr : public codon::SrcObject { /// type-checking procedure was successful). bool done; + /// Set of attributes. + int attributes; + public: Expr(); Expr(const Expr &expr) = default; @@ -124,6 +127,10 @@ public: virtual const TupleExpr *getTuple() const { return nullptr; } virtual const UnaryExpr *getUnary() const { return nullptr; } + /// Attribute helpers + bool hasAttr(int attr) const; + void setAttr(int attr); + protected: /// Add a type to S-expression string. std::string wrapType(const std::string &sexpr) const; @@ -666,5 +673,7 @@ struct StackAllocExpr : Expr { #undef ACCEPT +enum ExprAttr { SequenceItem, StarSequenceItem, List, Set, Dict, Partial, __LAST__ }; + } // namespace ast } // namespace codon diff --git a/codon/parser/ast/stmt.cpp b/codon/parser/ast/stmt.cpp index 119825aa..45245dae 100644 --- a/codon/parser/ast/stmt.cpp +++ b/codon/parser/ast/stmt.cpp @@ -285,6 +285,8 @@ const std::string Attr::Atomic = "atomic"; const std::string Attr::Property = "property"; const std::string Attr::Internal = "__internal__"; const std::string Attr::ForceRealize = "__force__"; +const std::string Attr::RealizeWithoutSelf = + "std.internal.attributes.realize_without_self"; const std::string Attr::C = "std.internal.attributes.C"; const std::string Attr::CVarArg = ".__vararg__"; const std::string Attr::Method = ".__method__"; @@ -292,6 +294,7 @@ const std::string Attr::Capture = ".__capture__"; const std::string Attr::Extend = "extend"; const std::string Attr::Tuple = "tuple"; const std::string Attr::Test = "std.internal.attributes.test"; +const std::string Attr::Overload = "std.internal.attributes.overload"; FunctionStmt::FunctionStmt(std::string name, ExprPtr ret, std::vector args, StmtPtr suite, Attr attributes, diff --git a/codon/parser/ast/stmt.h b/codon/parser/ast/stmt.h index 23fde86b..c4155f3a 100644 --- a/codon/parser/ast/stmt.h +++ b/codon/parser/ast/stmt.h @@ -398,6 +398,7 @@ struct Attr { // Internal attributes const static std::string Internal; const static std::string ForceRealize; + const static std::string RealizeWithoutSelf; // Compiler-generated attributes const static std::string C; const static std::string CVarArg; @@ -408,6 +409,7 @@ struct Attr { const static std::string Tuple; // Standard library attributes const static std::string Test; + const static std::string Overload; // Function module std::string module; // Parent class (set for methods only) diff --git a/codon/parser/ast/types.cpp b/codon/parser/ast/types.cpp index a08abed4..9aa156b4 100644 --- a/codon/parser/ast/types.cpp +++ b/codon/parser/ast/types.cpp @@ -185,7 +185,7 @@ std::string LinkType::debugString(bool debug) const { // fmt::format("{}->{}", id, type->debugString(debug)); } std::string LinkType::realizedName() const { - if (kind == Unbound) + if (kind == Unbound || kind == Generic) return "?"; seqassert(kind == Link, "unexpected generic link"); return type->realizedName(); @@ -476,12 +476,18 @@ std::vector FuncType::getUnbounds() const { } bool FuncType::canRealize() const { // Important: return type does not have to be realized. - for (int ai = 1; ai < args.size(); ai++) + + bool force = ast->hasAttr(Attr::RealizeWithoutSelf); + + int ai = 1 + force; + for (; ai < args.size(); ai++) if (!args[ai]->getFunc() && !args[ai]->canRealize()) return false; - return std::all_of(funcGenerics.begin(), funcGenerics.end(), - [](auto &a) { return !a.type || a.type->canRealize(); }) && - (!funcParent || funcParent->canRealize()); + bool generics = std::all_of(funcGenerics.begin(), funcGenerics.end(), + [](auto &a) { return !a.type || a.type->canRealize(); }); + if (!force) + generics &= (!funcParent || funcParent->canRealize()); + return generics; } bool FuncType::isInstantiated() const { TypePtr removed = nullptr; @@ -532,15 +538,7 @@ PartialType::PartialType(const std::shared_ptr &baseType, std::shared_ptr func, std::vector known) : RecordType(*baseType), func(move(func)), known(move(known)) {} int PartialType::unify(Type *typ, Unification *us) { - int s1 = 0, s; - if (auto tc = typ->getPartial()) { - // Check names. - if ((s = func->unify(tc->func.get(), us)) == -1) - return -1; - s1 += s; - } - s = this->RecordType::unify(typ, us); - return s == -1 ? s : s1 + s; + return this->RecordType::unify(typ, us); } TypePtr PartialType::generalize(int atLevel) { return std::make_shared( @@ -549,10 +547,9 @@ TypePtr PartialType::generalize(int atLevel) { } TypePtr PartialType::instantiate(int atLevel, int *unboundCount, std::unordered_map *cache) { - return std::make_shared( - std::static_pointer_cast( - this->RecordType::instantiate(atLevel, unboundCount, cache)), - func, known); + auto rec = std::static_pointer_cast( + this->RecordType::instantiate(atLevel, unboundCount, cache)); + return std::make_shared(rec, func, known); } std::string PartialType::debugString(bool debug) const { std::vector gs; @@ -573,7 +570,7 @@ std::string PartialType::debugString(bool debug) const { } std::string PartialType::realizedName() const { std::vector gs; - gs.push_back(func->realizedName()); + gs.push_back(func->ast->name); for (auto &a : generics) if (!a.name.empty()) gs.push_back(a.type->realizedName()); @@ -755,12 +752,24 @@ int CallableTrait::unify(Type *typ, Unification *us) { zeros.emplace_back(pi - 9); if (zeros.size() + 1 != args.size()) return -1; - if (args[0]->unify(pt->func->args[0].get(), us) == -1) - return -1; + + int ic = 0; + std::unordered_map c; + auto pf = pt->func->instantiate(0, &ic, &c)->getFunc(); + // For partial functions, we just check can we unify without actually performing + // unification for (int pi = 0, gi = 1; pi < pt->known.size(); pi++) - if (!pt->known[pi] && !pt->func->ast->args[pi].generic) - if (args[gi++]->unify(pt->func->args[pi + 1].get(), us) == -1) + if (!pt->known[pi] && !pf->ast->args[pi].generic) + if (args[gi++]->unify(pf->args[pi + 1].get(), us) == -1) return -1; + if (us && us->realizator && pf->canRealize()) { + // Realize if possible to allow deduction of return type [and possible + // unification!] + auto rf = us->realizator->realize(pf); + pf->unify(rf.get(), us); + } + if (args[0]->unify(pf->args[0].get(), us) == -1) + return -1; return 1; } } else if (auto tl = typ->getLink()) { diff --git a/codon/parser/ast/types.h b/codon/parser/ast/types.h index 065e21b0..94f30263 100644 --- a/codon/parser/ast/types.h +++ b/codon/parser/ast/types.h @@ -14,6 +14,7 @@ struct Expr; struct StaticValue; struct FunctionStmt; struct TypeContext; +class TypecheckVisitor; namespace types { @@ -43,6 +44,11 @@ struct Type : public codon::SrcObject, public std::enable_shared_from_this std::vector> leveled; /// List of assigned traits. std::vector traits; + /// Pointer to a TypecheckVisitor to support realization function types. + TypecheckVisitor *realizator = nullptr; + /// List of pointers that are owned by unification process + /// (to avoid memory issues with undoing). + std::vector> ownedTypes; public: /// Undo the unification step. diff --git a/codon/parser/cache.cpp b/codon/parser/cache.cpp index e94ca102..1af811eb 100644 --- a/codon/parser/cache.cpp +++ b/codon/parser/cache.cpp @@ -111,14 +111,18 @@ ir::Func *Cache::realizeFunction(types::FuncTypePtr type, } } } + int oldAge = typeCtx->age; + typeCtx->age = 99999; auto tv = TypecheckVisitor(typeCtx); + ir::Func *f = nullptr; if (auto rtv = tv.realize(type)) { auto pr = pendingRealizations; // copy it as it might be modified for (auto &fn : pr) TranslateVisitor(codegenCtx).transform(functions[fn.first].ast->clone()); - return functions[rtv->getFunc()->ast->name].realizations[rtv->realizedName()]->ir; + f = functions[rtv->getFunc()->ast->name].realizations[rtv->realizedName()]->ir; } - return nullptr; + typeCtx->age = oldAge; + return f; } ir::types::Type *Cache::makeTuple(const std::vector &types) { diff --git a/codon/parser/cache.h b/codon/parser/cache.h index 6b0ccbb7..feb234bf 100644 --- a/codon/parser/cache.h +++ b/codon/parser/cache.h @@ -188,6 +188,9 @@ struct Cache : public std::enable_shared_from_this { std::shared_ptr codegenCtx; /// Set of function realizations that are to be translated to IR. std::set> pendingRealizations; + /// Mapping of partial record names to function pointers and corresponding masks. + std::unordered_map>> + partials; /// Custom operators std::unordered_map #include #include #include "codon/parser/common.h" #include "codon/util/fmt/format.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Path.h" namespace codon { namespace ast { @@ -195,28 +196,29 @@ std::string executable_path(const char *argv0) { std::string executable_path(const char *argv0) { return std::string(argv0); } #endif -namespace fs = std::filesystem; - namespace { -void addPath(std::vector &paths, const fs::path &path) { - if (fs::exists(path)) - paths.push_back(fs::canonical(path)); + +void addPath(std::vector &paths, const std::string &path) { + if (llvm::sys::fs::exists(path)) + paths.push_back(getAbsolutePath(path)); } -std::vector getStdLibPaths(const std::string &argv0, - const std::vector &plugins) { - std::vector paths; +std::vector getStdLibPaths(const std::string &argv0, + const std::vector &plugins) { + std::vector paths; if (auto c = getenv("CODON_PATH")) { - addPath(paths, fs::path(std::string(c))); + addPath(paths, c); } if (!argv0.empty()) { - auto base = fs::path(executable_path(argv0.c_str())); + auto base = executable_path(argv0.c_str()); for (auto loci : {"../lib/codon/stdlib", "../stdlib", "stdlib"}) { - addPath(paths, base.parent_path() / loci); + auto path = llvm::SmallString<128>(llvm::sys::path::parent_path(base)); + llvm::sys::path::append(path, loci); + addPath(paths, std::string(path)); } } for (auto &path : plugins) { - addPath(paths, fs::path(path)); + addPath(paths, path); } return paths; } @@ -244,28 +246,47 @@ ImportFile getRoot(const std::string argv0, const std::vector &plug } } // namespace +std::string getAbsolutePath(const std::string &path) { + char *c = realpath(path.c_str(), nullptr); + if (!c) + return path; + std::string result(c); + free(c); + return result; +} + std::shared_ptr getImportFile(const std::string &argv0, const std::string &what, const std::string &relativeTo, bool forceStdlib, const std::string &module0, const std::vector &plugins) { - std::vector paths; + std::vector paths; if (what != "") { - auto parentRelativeTo = fs::path(relativeTo).parent_path(); + auto parentRelativeTo = llvm::sys::path::parent_path(relativeTo); if (!forceStdlib) { - addPath(paths, (parentRelativeTo / what).replace_extension("codon")); - addPath(paths, parentRelativeTo / what / "__init__.codon"); + auto path = llvm::SmallString<128>(parentRelativeTo); + llvm::sys::path::append(path, what); + llvm::sys::path::replace_extension(path, "codon"); + addPath(paths, std::string(path)); + path = llvm::SmallString<128>(parentRelativeTo); + llvm::sys::path::append(path, what, "__init__.codon"); + addPath(paths, std::string(path)); } } for (auto &p : getStdLibPaths(argv0, plugins)) { - addPath(paths, (p / what).replace_extension("codon")); - addPath(paths, p / what / "__init__.codon"); + auto path = llvm::SmallString<128>(p); + llvm::sys::path::append(path, what); + llvm::sys::path::replace_extension(path, "codon"); + addPath(paths, std::string(path)); + path = llvm::SmallString<128>(p); + llvm::sys::path::append(path, what, "__init__.codon"); + addPath(paths, std::string(path)); } - auto module0Root = fs::path(module0).parent_path().string(); + auto module0Root = llvm::sys::path::parent_path(module0).str(); return paths.empty() ? nullptr : std::make_shared( - getRoot(argv0, plugins, module0Root, paths[0].string())); + getRoot(argv0, plugins, module0Root, paths[0])); } } // namespace ast diff --git a/codon/parser/common.h b/codon/parser/common.h index 8c1be3a5..4978e64b 100644 --- a/codon/parser/common.h +++ b/codon/parser/common.h @@ -167,6 +167,9 @@ template std::vector clone_nop(const std::vector &t) { /// Path utilities +/// @return The absolute canonical path of a given path. +std::string getAbsolutePath(const std::string &path); + /// Detect a absolute path of the current executable (whose argv0 is known). /// @return Absolute executable path or argv0 if one cannot be found. std::string executable_path(const char *argv0); diff --git a/codon/parser/peg/peg.cpp b/codon/parser/peg/peg.cpp index 80ef9f49..48a02d27 100644 --- a/codon/parser/peg/peg.cpp +++ b/codon/parser/peg/peg.cpp @@ -109,9 +109,8 @@ StmtPtr parseFile(Cache *cache, const std::string &file) { cache->imports[file].content = lines; auto result = parseCode(cache, file, code); + // For debugging purposes: // LOG("peg/{} := {}", file, result ? result->toString(0) : ""); - // throw; - // LOG("fmt := {}", FormatVisitor::apply(result)); return result; } diff --git a/codon/parser/visitors/doc/doc.cpp b/codon/parser/visitors/doc/doc.cpp index e5dc470b..264502b1 100644 --- a/codon/parser/visitors/doc/doc.cpp +++ b/codon/parser/visitors/doc/doc.cpp @@ -1,6 +1,5 @@ #include "doc.h" -#include #include #include #include @@ -116,10 +115,9 @@ std::shared_ptr DocVisitor::apply(const std::string &argv0, auto ctx = std::make_shared(shared); for (auto &f : files) { - auto path = std::filesystem::canonical(std::filesystem::path(f)).string(); + auto path = getAbsolutePath(f); ctx->setFilename(path); ast = ast::parseFile(shared->cache, path); - // LOG("parsing {}", f); DocVisitor(ctx).transformModule(std::move(ast)); } diff --git a/codon/parser/visitors/format/format.cpp b/codon/parser/visitors/format/format.cpp index 434f0d00..3fbcd4fa 100644 --- a/codon/parser/visitors/format/format.cpp +++ b/codon/parser/visitors/format/format.cpp @@ -406,15 +406,6 @@ void FormatVisitor::visit(FunctionStmt *fstmt) { } void FormatVisitor::visit(ClassStmt *stmt) { - // if (cache && - // cache->realizationAsts.find(fstmt->name) != cache->realizationAsts.end()) { - // fstmt = (const FunctionStmt *)(cache->realizationAsts[fstmt->name].get()); - // } else if (cache) { - // for (auto &real : cache->realizations[fstmt->name]) - // result += simplify(cache->realizationAsts[real.first]); - // return; - // } - std::vector attrs; if (!stmt->attributes.has(Attr::Extend)) diff --git a/codon/parser/visitors/format/format.h b/codon/parser/visitors/format/format.h index 2f0ffd02..e58ce6d2 100644 --- a/codon/parser/visitors/format/format.h +++ b/codon/parser/visitors/format/format.h @@ -30,9 +30,6 @@ class FormatVisitor : public CallbackASTVisitor { private: template std::string renderExpr(T &&t, Ts &&...args) { std::string s; - // if (renderType) - // s += fmt::format("{}{}{}", typeStart, - // t->getType() ? t->getType()->toString() : "-", typeEnd); return fmt::format("{}{}{}{}{}{}", exprStart, s, nodeStart, fmt::format(args...), nodeEnd, exprEnd); } diff --git a/codon/parser/visitors/simplify/simplify_ctx.cpp b/codon/parser/visitors/simplify/simplify_ctx.cpp index 517518db..0ddf3fda 100644 --- a/codon/parser/visitors/simplify/simplify_ctx.cpp +++ b/codon/parser/visitors/simplify/simplify_ctx.cpp @@ -24,7 +24,8 @@ SimplifyContext::SimplifyContext(std::string filename, Cache *cache) allowTypeOf(true), substitutions(nullptr) {} SimplifyContext::Base::Base(std::string name, std::shared_ptr ast, int attributes) - : name(move(name)), ast(move(ast)), attributes(attributes) {} + : name(move(name)), ast(move(ast)), attributes(attributes), deducedMembers(nullptr), + selfName() {} std::shared_ptr SimplifyContext::add(SimplifyItem::Kind kind, const std::string &name, diff --git a/codon/parser/visitors/simplify/simplify_ctx.h b/codon/parser/visitors/simplify/simplify_ctx.h index 468ded2a..e525ff24 100644 --- a/codon/parser/visitors/simplify/simplify_ctx.h +++ b/codon/parser/visitors/simplify/simplify_ctx.h @@ -67,6 +67,9 @@ struct SimplifyContext : public Context { /// Tracks function attributes (e.g. if it has @atomic or @test attributes). int attributes; + std::shared_ptr> deducedMembers; + std::string selfName; + explicit Base(std::string name, ExprPtr ast = nullptr, int attributes = 0); bool isType() const { return ast != nullptr; } }; diff --git a/codon/parser/visitors/simplify/simplify_expr.cpp b/codon/parser/visitors/simplify/simplify_expr.cpp index b45f02e3..eccc0784 100644 --- a/codon/parser/visitors/simplify/simplify_expr.cpp +++ b/codon/parser/visitors/simplify/simplify_expr.cpp @@ -31,6 +31,8 @@ ExprPtr SimplifyVisitor::transform(const ExprPtr &expr, bool allowTypes, ctx->canAssign = oldAssign; if (!allowTypes && v.resultExpr && v.resultExpr->isType()) error("unexpected type expression"); + if (v.resultExpr) + v.resultExpr->attributes |= expr->attributes; return v.resultExpr; } @@ -179,15 +181,21 @@ void SimplifyVisitor::visit(ListExpr *expr) { for (const auto &it : expr->items) { if (auto star = it->getStar()) { ExprPtr forVar = N(ctx->cache->getTemporaryVar("it")); + auto st = star->what->clone(); + st->setAttr(ExprAttr::StarSequenceItem); stmts.push_back(transform(N( - clone(forVar), star->what->clone(), + clone(forVar), st, N(N(N(clone(var), "append"), clone(forVar)))))); } else { - stmts.push_back(transform( - N(N(N(clone(var), "append"), clone(it))))); + auto st = clone(it); + st->setAttr(ExprAttr::SequenceItem); + stmts.push_back( + transform(N(N(N(clone(var), "append"), st)))); } } - resultExpr = N(stmts, transform(var)); + auto e = N(stmts, transform(var)); + e->setAttr(ExprAttr::List); + resultExpr = e; ctx->popBlock(); } @@ -200,14 +208,20 @@ void SimplifyVisitor::visit(SetExpr *expr) { for (auto &it : expr->items) if (auto star = it->getStar()) { ExprPtr forVar = N(ctx->cache->getTemporaryVar("it")); + auto st = star->what->clone(); + st->setAttr(ExprAttr::StarSequenceItem); stmts.push_back(transform(N( - clone(forVar), star->what->clone(), + clone(forVar), st, N(N(N(clone(var), "add"), clone(forVar)))))); } else { - stmts.push_back(transform( - N(N(N(clone(var), "add"), clone(it))))); + auto st = clone(it); + st->setAttr(ExprAttr::SequenceItem); + stmts.push_back( + transform(N(N(N(clone(var), "add"), st)))); } - resultExpr = N(stmts, transform(var)); + auto e = N(stmts, transform(var)); + e->setAttr(ExprAttr::Set); + resultExpr = e; ctx->popBlock(); } @@ -220,16 +234,24 @@ void SimplifyVisitor::visit(DictExpr *expr) { for (auto &it : expr->items) if (auto star = CAST(it.value, KeywordStarExpr)) { ExprPtr forVar = N(ctx->cache->getTemporaryVar("it")); + auto st = star->what->clone(); + st->setAttr(ExprAttr::StarSequenceItem); stmts.push_back(transform(N( - clone(forVar), N(N(star->what->clone(), "items")), + clone(forVar), N(N(st, "items")), N(N(N(clone(var), "__setitem__"), N(clone(forVar), N(0)), N(clone(forVar), N(1))))))); } else { - stmts.push_back(transform(N(N( - N(clone(var), "__setitem__"), clone(it.key), clone(it.value))))); + auto k = clone(it.key); + k->setAttr(ExprAttr::SequenceItem); + auto v = clone(it.value); + v->setAttr(ExprAttr::SequenceItem); + stmts.push_back(transform( + N(N(N(clone(var), "__setitem__"), k, v)))); } - resultExpr = N(stmts, transform(var)); + auto e = N(stmts, transform(var)); + e->setAttr(ExprAttr::Dict); + resultExpr = e; ctx->popBlock(); } @@ -308,10 +330,10 @@ void SimplifyVisitor::visit(UnaryExpr *expr) { void SimplifyVisitor::visit(BinaryExpr *expr) { auto lhs = (startswith(expr->op, "is") && expr->lexpr->getNone()) ? clone(expr->lexpr) - : transform(expr->lexpr); + : transform(expr->lexpr, startswith(expr->op, "is")); auto rhs = (startswith(expr->op, "is") && expr->rexpr->getNone()) ? clone(expr->rexpr) - : transform(expr->rexpr, false, + : transform(expr->rexpr, startswith(expr->op, "is"), /*allowAssign*/ expr->op != "&&" && expr->op != "||"); resultExpr = N(lhs, expr->op, rhs, expr->inPlace); } @@ -493,9 +515,9 @@ void SimplifyVisitor::visit(CallExpr *expr) { ctx->add(SimplifyItem::Var, varName, varName); var = N(varName); ctx->addBlock(); // prevent tmp vars from being toplevel vars - ex = N( - transform(N(clone(g->loops[0].vars), clone(var), nullptr, true)), - transform(ex)); + auto head = + transform(N(clone(g->loops[0].vars), clone(var), nullptr, true)); + ex = N(head, transform(ex)); ctx->popBlock(); } std::vector body; @@ -688,7 +710,9 @@ void SimplifyVisitor::visit(StmtExpr *expr) { for (auto &s : expr->stmts) stmts.emplace_back(transform(s)); auto e = transform(expr->expr); - resultExpr = N(stmts, e); + auto s = N(stmts, e); + s->attributes = expr->attributes; + resultExpr = s; } /**************************************************************************************/ diff --git a/codon/parser/visitors/simplify/simplify_stmt.cpp b/codon/parser/visitors/simplify/simplify_stmt.cpp index 178fc12a..c27dffb7 100644 --- a/codon/parser/visitors/simplify/simplify_stmt.cpp +++ b/codon/parser/visitors/simplify/simplify_stmt.cpp @@ -180,7 +180,9 @@ void SimplifyVisitor::visit(WhileStmt *stmt) { transform(N(N(breakVar), N(true), nullptr, true)); } ctx->loops.push_back(breakVar); // needed for transforming break in loop..else blocks - StmtPtr whileStmt = N(transform(cond), transform(stmt->suite)); + + cond = transform(cond); + StmtPtr whileStmt = N(cond, transform(stmt->suite)); ctx->loops.pop_back(); if (stmt->elseSuite && stmt->elseSuite->firstInBlock()) { resultStmt = @@ -232,8 +234,9 @@ void SimplifyVisitor::visit(ForStmt *stmt) { ctx->addBlock(); if (auto i = stmt->var->getId()) { ctx->add(SimplifyItem::Var, i->value, ctx->generateCanonicalName(i->value)); - forStmt = N(transform(stmt->var), clone(iter), transform(stmt->suite), - nullptr, decorator, ompArgs); + auto var = transform(stmt->var); + forStmt = N(var, clone(iter), transform(stmt->suite), nullptr, decorator, + ompArgs); } else { std::string varName = ctx->cache->getTemporaryVar("for"); ctx->add(SimplifyItem::Var, varName, varName); @@ -259,8 +262,8 @@ void SimplifyVisitor::visit(ForStmt *stmt) { void SimplifyVisitor::visit(IfStmt *stmt) { seqassert(stmt->cond, "invalid if statement"); - resultStmt = N(transform(stmt->cond), transform(stmt->ifSuite), - transform(stmt->elseSuite)); + auto cond = transform(stmt->cond); + resultStmt = N(cond, transform(stmt->ifSuite), transform(stmt->elseSuite)); } void SimplifyVisitor::visit(MatchStmt *stmt) { @@ -473,7 +476,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { // TODO: error on decorators return; } - + bool overload = attr.has(Attr::Overload); bool isClassMember = ctx->inClass(); std::string rootName; if (isClassMember) { @@ -481,10 +484,11 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { auto i = m.find(stmt->name); if (i != m.end()) rootName = i->second; - } else if (auto c = ctx->find(stmt->name)) { - if (c->isFunc() && c->getModule() == ctx->getModule() && - c->getBase() == ctx->getBase()) - rootName = c->canonicalName; + } else if (overload) { + if (auto c = ctx->find(stmt->name)) + if (c->isFunc() && c->getModule() == ctx->getModule() && + c->getBase() == ctx->getBase()) + rootName = c->canonicalName; } if (rootName.empty()) rootName = ctx->generateCanonicalName(stmt->name, true); @@ -504,7 +508,9 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { if (isClassMember) ctx->bases.push_back(oldBases[0]); ctx->bases.emplace_back(SimplifyContext::Base{canonicalName}); // Add new base... - ctx->addBlock(); // ... and a block! + if (isClassMember && ctx->bases[0].deducedMembers) + ctx->bases.back().deducedMembers = ctx->bases[0].deducedMembers; + ctx->addBlock(); // ... and a block! // Set atomic flag if @atomic attribute is present. if (attr.has(Attr::Atomic)) ctx->bases.back().attributes |= FLAG_ATOMIC; @@ -540,9 +546,12 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { error("non-default argument '{}' after a default argument", varName); defaultsStarted |= bool(a.deflt); + auto name = ctx->generateCanonicalName(varName); + auto typeAst = a.type; if (!typeAst && isClassMember && ia == 0 && a.name == "self") { typeAst = ctx->bases[ctx->bases.size() - 2].ast; + ctx->bases.back().selfName = name; attr.set(".changedSelf"); attr.set(Attr::Method); } @@ -557,9 +566,14 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { } // First add all generics! - auto name = ctx->generateCanonicalName(varName); - args.emplace_back( - Param{std::string(stars, '*') + name, typeAst, a.deflt, a.generic}); + auto deflt = a.deflt; + if (typeAst && typeAst->getIndex() && typeAst->getIndex()->expr->isId("Callable") && + deflt && deflt->getNone()) + deflt = N(N("NoneType")); + if (typeAst && (typeAst->isId("type") || typeAst->isId("TypeVar")) && deflt && + deflt->getNone()) + deflt = N("NoneType"); + args.emplace_back(Param{std::string(stars, '*') + name, typeAst, deflt, a.generic}); if (a.generic) { if (a.type->getIndex() && a.type->getIndex()->expr->isId("Static")) ctx->add(SimplifyItem::Var, varName, name); @@ -677,12 +691,15 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { Attr attr = stmt->attributes; std::vector hasMagic(10, 2); hasMagic[Init] = hasMagic[Pickle] = 1; + bool deduce = false; // @tuple(init=, repr=, eq=, order=, hash=, pickle=, container=, python=, add=, // internal=...) // @dataclass(...) // @extend for (auto &d : stmt->decorators) { - if (auto c = d->getCall()) { + if (d->isId("deduce")) { + deduce = true; + } else if (auto c = d->getCall()) { if (c->expr->isId(Attr::Tuple)) attr.set(Attr::Tuple); else if (!c->expr->isId("dataclass")) @@ -861,6 +878,43 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { } argSubstitutions.push_back(substitutions.size() - 1); } + + // Auto-detect fields + StmtPtr autoDeducedInit = nullptr; + Stmt *firstInit = nullptr; + if (deduce && args.empty() && !extension) { + for (auto sp : getClassMethods(stmt->suite)) + if (sp && sp->getFunction()) { + firstInit = sp.get(); + auto f = sp->getFunction(); + if (f->name == "__init__" && f->args.size() >= 1 && f->args[0].name == "self") { + ctx->bases.back().deducedMembers = + std::make_shared>(); + transform(sp); + autoDeducedInit = preamble->functions.back(); + std::dynamic_pointer_cast(autoDeducedInit) + ->attributes.set(Attr::RealizeWithoutSelf); + ctx->cache->functions[autoDeducedInit->getFunction()->name] + .ast->attributes.set(Attr::RealizeWithoutSelf); + + int i = 0; + for (auto &m : *(ctx->bases.back().deducedMembers)) { + auto varName = ctx->generateCanonicalName(format("T{}", ++i)); + auto name = ctx->cache->reverseIdentifierLookup[varName]; + ctx->add(SimplifyItem::Type, name, varName, true); + genAst.push_back(N(varName)); + args.emplace_back(Param{varName, N("type"), nullptr, true}); + argSubstitutions.push_back(substitutions.size() - 1); + + ctx->cache->classes[canonicalName].fields.push_back({m, nullptr}); + args.emplace_back(Param{m, N(varName), nullptr}); + argSubstitutions.push_back(substitutions.size() - 1); + } + ctx->bases.back().deducedMembers = nullptr; + break; + } + } + } if (!genAst.empty()) ctx->bases.back().ast = std::make_shared(N(name), N(genAst)); @@ -925,7 +979,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { magics = {"len", "hash"}; else magics = {"new", "raw"}; - if (hasMagic[Init]) + if (hasMagic[Init] && !firstInit) magics.emplace_back(isRecord ? "new" : "init"); if (hasMagic[Eq]) for (auto &i : {"eq", "ne"}) @@ -959,7 +1013,6 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { } } for (int ai = 0; ai < baseASTs.size(); ai++) { - // FUNCS for (auto &mm : ctx->cache->classes[baseASTs[ai]->name].methods) for (auto &mf : ctx->cache->overloads[mm.second]) { auto f = ctx->cache->functions[mf.name].ast; @@ -999,6 +1052,8 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { } for (auto sp : getClassMethods(stmt->suite)) if (sp && !sp->getClass()) { + if (firstInit && firstInit == sp.get()) + continue; transform(sp); suite->stmts.push_back(preamble->functions.back()); } @@ -1012,8 +1067,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { seqassert(c, "not a class AST for {}", canonicalName); preamble->globals.push_back(c->clone()); c->suite = clone(suite); - // if (stmt->baseClasses.size()) - // LOG("{} -> {}", stmt->name, c->toString(0)); + } stmts[0] = N(canonicalName, std::vector{}, N(), Attr({Attr::Extend}), std::vector{}, @@ -1046,7 +1100,15 @@ StmtPtr SimplifyVisitor::transformAssignment(const ExprPtr &lhs, const ExprPtr & clone(ei->index), rhs->clone()))); } else if (auto ed = lhs->getDot()) { seqassert(!type, "unexpected type annotation"); - return N(transform(ed->expr), ed->member, transform(rhs, false)); + auto l = transform(ed->expr); + if (ctx->bases.size() && ctx->bases.back().deducedMembers && + l->isId(ctx->bases.back().selfName)) { + if (std::find(ctx->bases.back().deducedMembers->begin(), + ctx->bases.back().deducedMembers->end(), + ed->member) == ctx->bases.back().deducedMembers->end()) + ctx->bases.back().deducedMembers->push_back(ed->member); + } + return N(l, ed->member, transform(rhs, false)); } else if (auto e = lhs->getId()) { ExprPtr t = transformType(type, false); if (!shadow && !t) { diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index 7e0ba953..055add0f 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -1,6 +1,5 @@ #include "translate.h" -#include #include #include #include @@ -35,8 +34,7 @@ ir::Func *TranslateVisitor::apply(Cache *cache, StmtPtr stmts) { main->setJIT(); } else { main = cast(cache->module->getMainFunc()); - auto path = - std::filesystem::canonical(std::filesystem::path(cache->module0)).string(); + auto path = getAbsolutePath(cache->module0); main->setSrcInfo({path, 0, 0, 0}); } @@ -57,8 +55,79 @@ ir::Func *TranslateVisitor::apply(Cache *cache, StmtPtr stmts) { ir::Value *TranslateVisitor::transform(const ExprPtr &expr) { TranslateVisitor v(ctx); v.setSrcInfo(expr->getSrcInfo()); + + types::PartialType *p = nullptr; + if (expr->attributes) { + if (expr->hasAttr(ExprAttr::List) || expr->hasAttr(ExprAttr::Set) || + expr->hasAttr(ExprAttr::Dict) || expr->hasAttr(ExprAttr::Partial)) { + ctx->seqItems.push_back(std::vector>()); + } + if (expr->hasAttr(ExprAttr::Partial)) + p = expr->type->getPartial().get(); + } + expr->accept(v); - return v.result; + ir::Value *ir = v.result; + + if (expr->attributes) { + if (expr->hasAttr(ExprAttr::List) || expr->hasAttr(ExprAttr::Set)) { + std::vector v; + for (auto &p : ctx->seqItems.back()) { + seqassert(p.first <= ExprAttr::StarSequenceItem, "invalid list/set element"); + v.push_back( + ir::LiteralElement{p.second, p.first == ExprAttr::StarSequenceItem}); + } + if (expr->hasAttr(ExprAttr::List)) + ir->setAttribute(std::make_unique(v)); + else + ir->setAttribute(std::make_unique(v)); + ctx->seqItems.pop_back(); + } + if (expr->hasAttr(ExprAttr::Dict)) { + std::vector v; + for (int pi = 0; pi < ctx->seqItems.back().size(); pi++) { + auto &p = ctx->seqItems.back()[pi]; + if (p.first == ExprAttr::StarSequenceItem) { + v.push_back({p.second, nullptr}); + } else { + seqassert(p.first == ExprAttr::SequenceItem && + pi + 1 < ctx->seqItems.back().size() && + ctx->seqItems.back()[pi + 1].first == ExprAttr::SequenceItem, + "invalid dict element"); + v.push_back({p.second, ctx->seqItems.back()[pi + 1].second}); + pi++; + } + } + ir->setAttribute(std::make_unique(v)); + ctx->seqItems.pop_back(); + } + if (expr->hasAttr(ExprAttr::Partial)) { + std::vector v; + seqassert(p, "invalid partial element"); + int j = 0; + for (int i = 0; i < p->known.size(); i++) { + if (p->known[i] && !p->func->ast->args[i].generic) { + seqassert(j < ctx->seqItems.back().size() && + ctx->seqItems.back()[j].first == ExprAttr::SequenceItem, + "invalid partial element"); + v.push_back(ctx->seqItems.back()[j++].second); + } else if (!p->func->ast->args[i].generic) { + v.push_back({nullptr}); + } + } + ir->setAttribute( + std::make_unique(p->func->ast->name, v)); + ctx->seqItems.pop_back(); + } + if (expr->hasAttr(ExprAttr::SequenceItem)) { + ctx->seqItems.back().push_back({ExprAttr::SequenceItem, ir}); + } + if (expr->hasAttr(ExprAttr::StarSequenceItem)) { + ctx->seqItems.back().push_back({ExprAttr::StarSequenceItem, ir}); + } + } + + return ir; } void TranslateVisitor::defaultVisit(Expr *n) { @@ -91,8 +160,10 @@ void TranslateVisitor::visit(IdExpr *expr) { } void TranslateVisitor::visit(IfExpr *expr) { - result = make(expr, transform(expr->cond), transform(expr->ifexpr), - transform(expr->elsexpr)); + auto cond = transform(expr->cond); + auto ifexpr = transform(expr->ifexpr); + auto elsexpr = transform(expr->elsexpr); + result = make(expr, cond, ifexpr, elsexpr); } void TranslateVisitor::visit(CallExpr *expr) { @@ -428,13 +499,12 @@ void TranslateVisitor::transformFunction(types::FuncType *type, FunctionStmt *as std::map attr; attr[".module"] = ast->attributes.module; for (auto &a : ast->attributes.customAttr) { - // LOG("{} -> {}", ast->name, a); attr[a] = ""; } func->setAttribute(std::make_unique(attr)); for (int i = 0; i < names.size(); i++) func->getArgVar(names[i])->setSrcInfo(ast->args[indices[i]].getSrcInfo()); - func->setUnmangledName(ctx->cache->reverseIdentifierLookup[type->ast->name]); + // func->setUnmangledName(ctx->cache->reverseIdentifierLookup[type->ast->name]); if (!ast->attributes.has(Attr::C) && !ast->attributes.has(Attr::Internal)) { ctx->addBlock(); for (auto i = 0; i < names.size(); i++) @@ -511,7 +581,7 @@ void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt f->setLLVMBody(join(lines, "\n")); f->setLLVMDeclarations(declare); f->setLLVMLiterals(literals); - func->setUnmangledName(ctx->cache->reverseIdentifierLookup[type->ast->name]); + // func->setUnmangledName(ctx->cache->reverseIdentifierLookup[type->ast->name]); } } // namespace ast diff --git a/codon/parser/visitors/translate/translate_ctx.cpp b/codon/parser/visitors/translate/translate_ctx.cpp index 690c0b84..7d7589c2 100644 --- a/codon/parser/visitors/translate/translate_ctx.cpp +++ b/codon/parser/visitors/translate/translate_ctx.cpp @@ -21,14 +21,14 @@ std::shared_ptr TranslateContext::find(const std::string &name) c return t; std::shared_ptr ret = nullptr; auto tt = cache->typeCtx->find(name); - if (tt->isType() && tt->type->canRealize()) { + if (tt && tt->isType() && tt->type->canRealize()) { ret = std::make_shared(TranslateItem::Type, bases[0]); seqassert(in(cache->classes, tt->type->getClass()->name) && in(cache->classes[tt->type->getClass()->name].realizations, name), "cannot find type realization {}", name); ret->handle.type = cache->classes[tt->type->getClass()->name].realizations[name]->ir; - } else if (tt->type->getFunc() && tt->type->canRealize()) { + } else if (tt && tt->type->getFunc() && tt->type->canRealize()) { ret = std::make_shared(TranslateItem::Func, bases[0]); seqassert( in(cache->functions, tt->type->getFunc()->ast->name) && diff --git a/codon/parser/visitors/translate/translate_ctx.h b/codon/parser/visitors/translate/translate_ctx.h index 6d182492..49c0e5b5 100644 --- a/codon/parser/visitors/translate/translate_ctx.h +++ b/codon/parser/visitors/translate/translate_ctx.h @@ -50,6 +50,8 @@ struct TranslateContext : public Context { std::vector bases; /// Stack of IR series (blocks). std::vector series; + /// Stack of sequence items for attribute initialization. + std::vector>> seqItems; public: TranslateContext(Cache *cache); diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index 77dc0ba0..315da1a6 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -38,6 +38,7 @@ TypePtr TypecheckVisitor::unify(TypePtr &a, const TypePtr &b, bool undoOnSuccess return a = b; seqassert(b, "rhs is nullptr"); types::Type::Unification undo; + undo.realizator = this; if (a->unify(b.get(), &undo) >= 0) { if (undoOnSuccess) undo.undo(); @@ -45,7 +46,6 @@ TypePtr TypecheckVisitor::unify(TypePtr &a, const TypePtr &b, bool undoOnSuccess } else { undo.undo(); } - // LOG("{} / {}", a->debugString(true), b->debugString(true)); if (!undoOnSuccess) a->unify(b.get(), &undo); error("cannot unify {} and {}", a->toString(), b->toString()); diff --git a/codon/parser/visitors/typecheck/typecheck_ctx.cpp b/codon/parser/visitors/typecheck/typecheck_ctx.cpp index aac6eb51..4863cbc2 100644 --- a/codon/parser/visitors/typecheck/typecheck_ctx.cpp +++ b/codon/parser/visitors/typecheck/typecheck_ctx.cpp @@ -78,8 +78,8 @@ std::shared_ptr TypeContext::addUnbound(const Expr *expr, int level, bool setActive, char staticType) { auto t = std::make_shared( types::LinkType::Unbound, cache->unboundCount++, level, nullptr, staticType); - // if (t->id == 7815) - // LOG("debug"); + // Keep it for debugging purposes: + // if (t->id == 7815) LOG("debug"); t->setSrcInfo(expr->getSrcInfo()); LOG_TYPECHECK("[ub] new {}: {} ({})", t->debugString(true), expr->toString(), setActive); @@ -202,17 +202,11 @@ int TypeContext::reorderNamedArgs(types::FuncType *func, int starArgIndex = -1, kwstarArgIndex = -1; for (int i = 0; i < func->ast->args.size(); i++) { - // if (!known.empty() && known[i] && !partial) - // continue; if (startswith(func->ast->args[i].name, "**")) kwstarArgIndex = i, score -= 2; else if (startswith(func->ast->args[i].name, "*")) starArgIndex = i, score -= 2; } - // seqassert(known.empty() || starArgIndex == -1 || !known[starArgIndex], - // "partial *args"); - // seqassert(known.empty() || kwstarArgIndex == -1 || !known[kwstarArgIndex], - // "partial **kwargs"); // 1. Assign positional arguments to slots // Each slot contains a list of arg's indices diff --git a/codon/parser/visitors/typecheck/typecheck_expr.cpp b/codon/parser/visitors/typecheck/typecheck_expr.cpp index f893af60..3337c691 100644 --- a/codon/parser/visitors/typecheck/typecheck_expr.cpp +++ b/codon/parser/visitors/typecheck/typecheck_expr.cpp @@ -10,6 +10,7 @@ #include "codon/parser/common.h" #include "codon/parser/visitors/simplify/simplify.h" #include "codon/parser/visitors/typecheck/typecheck.h" +#include "codon/sir/attribute.h" using fmt::format; @@ -35,8 +36,10 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr, bool allowTypes, bool allowVo ctx->allowActivation = false; v.setSrcInfo(expr->getSrcInfo()); expr->accept(v); - if (v.resultExpr) + if (v.resultExpr) { + v.resultExpr->attributes |= expr->attributes; expr = v.resultExpr; + } seqassert(expr->type, "type not set for {}", expr->toString()); unify(typ, expr->type); if (disableActivation) @@ -645,6 +648,8 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic, // Check if this is a "a is None" expression. If so, ... if (expr->op == "is" && expr->rexpr->getNone()) { + if (expr->lexpr->getType()->getClass()->name == "NoneType") + return transform(N(true)); if (expr->lexpr->getType()->getClass()->name != TYPE_OPTIONAL) // ... return False if lhs is not an Optional... return transform(N(false)); @@ -739,14 +744,7 @@ ExprPtr TypecheckVisitor::transformStaticTupleIndex(ClassType *tuple, ExprPtr &e if (!tuple->getRecord()) return nullptr; if (!startswith(tuple->name, TYPE_TUPLE) && !startswith(tuple->name, TYPE_PARTIAL)) - // in(std::set{"Ptr", "pyobj", "str", "Array"}, tuple->name)) - // Ptr, pyobj and str are internal types and have only one overloaded __getitem__ return nullptr; - // if (in(ctx->cache->classes[tuple->name].methods, "__getitem__")) { - // ctx->cache->overloads[ctx->cache->classes[tuple->name].methods["__getitem__"]] - // .size() != 1) - // return nullptr; - // } // Extract a static integer value from a compatible expression. auto getInt = [&](int64_t *o, const ExprPtr &e) { @@ -870,8 +868,7 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, return transform( N(N(expr->expr, "_getattr"), N(expr->member))); } else { - // For debugging purposes: - ctx->findMethod(typ->name, expr->member); + // For debugging purposes: ctx->findMethod(typ->name, expr->member); error("cannot find '{}' in {}", expr->member, typ->toString()); } } @@ -1021,7 +1018,6 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in ai--; } else { // Case 3: Normal argument - // LOG("-> {}", expr->args[ai].value->toString()); expr->args[ai].value = transform(expr->args[ai].value, true); // Unbound inType might become a generator that will need to be extracted, so // don't unify it yet. @@ -1047,8 +1043,6 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in ctx->bases.back().supers, expr->args); if (m.empty()) error("no matching superf methods are available"); - // LOG("found {} <- {}", ctx->bases.back().type->getFunc()->toString(), - // m[0]->toString()); ExprPtr e = N(N(m[0]->ast->name), expr->args); return transform(e, false, true); } @@ -1126,10 +1120,12 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in expr->expr = transform(N(N(clone(var), expr->expr), N(pc->func->ast->name))); calleeFn = expr->expr->type->getFunc(); + // Fill in generics for (int i = 0, j = 0; i < pc->known.size(); i++) if (pc->func->ast->args[i].generic) { if (pc->known[i]) - unify(calleeFn->funcGenerics[j].type, pc->func->funcGenerics[j].type); + unify(calleeFn->funcGenerics[j].type, + ctx->instantiate(expr, pc->func->funcGenerics[j].type)); j++; } known = pc->known; @@ -1144,7 +1140,6 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in std::vector args; std::vector typeArgs; int typeArgCount = 0; - // bool isPartial = false; int ellipsisStage = -1; auto newMask = std::vector(calleeFn->ast->args.size(), 1); auto getPartialArg = [&](int pi) { @@ -1248,6 +1243,8 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in return -1; }, known); + bool hasPartialArgs = partialStarArgs != nullptr, + hasPartialKwargs = partialKwstarArgs != nullptr; if (isPartial) { deactivateUnbounds(expr->args.back().value->getType().get()); expr->args.pop_back(); @@ -1283,7 +1280,6 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in } auto e = transform(expr->expr); unify(expr->type, e->getType()); - // LOG("-- {} / {}", e->toString(), e->type->debugString(true)); return e; } @@ -1345,8 +1341,6 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in if (replacements[si]) { if (replacements[si]->getFunc()) deactivateUnbounds(replacements[si].get()); - if (auto pt = replacements[si]->getPartial()) - deactivateUnbounds(pt->func.get()); calleeFn->generics[si + 1].type = calleeFn->args[si + 1] = replacements[si]; } if (!isPartial) { @@ -1365,10 +1359,16 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in deactivateUnbounds(calleeFn.get()); std::vector newArgs; for (auto &r : args) - if (!r.value->getEllipsis()) + if (!r.value->getEllipsis()) { newArgs.push_back(r.value); + newArgs.back()->setAttr(ExprAttr::SequenceItem); + } newArgs.push_back(partialStarArgs); + if (hasPartialArgs) + newArgs.back()->setAttr(ExprAttr::SequenceItem); newArgs.push_back(partialKwstarArgs); + if (hasPartialKwargs) + newArgs.back()->setAttr(ExprAttr::SequenceItem); std::string var = ctx->cache->getTemporaryVar("partial"); ExprPtr call = nullptr; @@ -1383,12 +1383,9 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in N(N(partialTypeName), newArgs)), N(var)); } + call->setAttr(ExprAttr::Partial); call = transform(call, false, allowVoidExpr); - seqassert(call->type->getRecord() && - startswith(call->type->getRecord()->name, partialTypeName) && - !call->type->getPartial(), - "bad partial transformation"); - call->type = N(call->type->getRecord(), calleeFn, newMask); + seqassert(call->type->getPartial(), "expected partial type"); return call; } else { // Case 2. Normal function call. @@ -1629,16 +1626,19 @@ std::string TypecheckVisitor::generateFunctionStub(int n) { std::string TypecheckVisitor::generatePartialStub(const std::vector &mask, types::FuncType *fn) { std::string strMask(mask.size(), '1'); - int tupleSize = 0; + int tupleSize = 0, genericSize = 0; for (int i = 0; i < mask.size(); i++) if (!mask[i]) strMask[i] = '0'; else if (!fn->ast->args[i].generic) tupleSize++; - auto typeName = format(TYPE_PARTIAL "{}.{}", strMask, fn->ast->name); - if (!ctx->find(typeName)) - // 2 for .starArgs and .kwstarArgs (empty tuples if fn does not have them) + else + genericSize++; + auto typeName = format(TYPE_PARTIAL "{}.{}", strMask, fn->toString()); + if (!ctx->find(typeName)) { + ctx->cache->partials[typeName] = {fn->generalize(0)->getFunc(), mask}; generateTupleStub(tupleSize + 2, typeName, {}, false); + } return typeName; } @@ -1710,12 +1710,9 @@ ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) { N(N(partialTypeName), N(), N(N(kwName)))), N(var)); + call->setAttr(ExprAttr::Partial); call = transform(call, false, allowVoidExpr); - seqassert(call->type->getRecord() && - startswith(call->type->getRecord()->name, partialTypeName) && - !call->type->getPartial(), - "bad partial transformation"); - call->type = N(call->type->getRecord(), fn, mask); + seqassert(call->type->getPartial(), "expected partial type"); return call; } @@ -1821,12 +1818,6 @@ TypecheckVisitor::findMatchingMethods(types::ClassType *typ, } } if (score != -1) { - // std::vector ar; - // for (auto &a: args) { - // if (a.first.empty()) ar.push_back(a.second->toString()); - // else ar.push_back(format("{}: {}", a.first, a.second->toString())); - // } - // LOG("- {} vs {}", m->toString(), join(ar, "; ")); results.push_back(methods[mi]); } } @@ -1861,6 +1852,8 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, TypePtr expectedType, // Case 7: wrap raw Seq functions into Partial(...) call for easy realization. expr = partializeFunction(expr); } + + // Special case: unify(expr->type, expectedType, undoOnSuccess); return true; } @@ -1951,7 +1944,6 @@ types::FuncTypePtr TypecheckVisitor::findDispatch(const std::string &fn) { ctx->cache->functions[name].ast = ast; ctx->cache->functions[name].type = typ; prependStmts->push_back(ast); - // LOG("dispatch: {}", ast->toString(1)); return typ; } diff --git a/codon/parser/visitors/typecheck/typecheck_infer.cpp b/codon/parser/visitors/typecheck/typecheck_infer.cpp index a6b54731..879783b6 100644 --- a/codon/parser/visitors/typecheck/typecheck_infer.cpp +++ b/codon/parser/visitors/typecheck/typecheck_infer.cpp @@ -32,12 +32,11 @@ types::TypePtr TypecheckVisitor::realize(types::TypePtr typ) { } else if (auto c = typ->getClass()) { auto t = realizeType(c.get()); if (auto p = typ->getPartial()) { - if (auto rt = realize(p->func)) - unify(rt, p->func); + // if (auto rt = realize(p->func)) + // unify(rt, p->func); return std::make_shared(t->getRecord(), p->func, p->known); - } else { - return t; } + return t; } else { return nullptr; } @@ -117,8 +116,9 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) { try { auto it = ctx->cache->functions[type->ast->name].realizations.find(type->realizedName()); - if (it != ctx->cache->functions[type->ast->name].realizations.end()) + if (it != ctx->cache->functions[type->ast->name].realizations.end()) { return it->second->type; + } // Set up bases. Ensure that we have proper parent bases even during a realization // of mutually recursive functions. @@ -150,6 +150,7 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) { LOG_REALIZE("[realize] fn {} -> {} : base {} ; depth = {}", type->ast->name, type->realizedName(), ctx->getBase(), depth); { + // Timer trx(fmt::format("fn {}", type->realizedName())); getLogger().level++; ctx->realizationDepth++; ctx->addBlock(); @@ -161,11 +162,6 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) { type->args[0], {}, findSuperMethods(type->getFunc())}); - // if (startswith(type->ast->name, "Foo")) { - // LOG(": {}", type->toString()); - // for (auto &s: ctx->bases.back().supers) - // LOG(" - {}", s->toString()); - // } auto clonedAst = ctx->cache->functions[type->ast->name].ast->clone(); auto *ast = (FunctionStmt *)clonedAst.get(); addFunctionGenerics(type); @@ -177,13 +173,16 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) { if (!isInternal) for (int i = 0, j = 1; i < ast->args.size(); i++) if (!ast->args[i].generic) { - seqassert(type->args[j] && type->args[j]->getUnbounds().empty(), - "unbound argument {}", type->args[j]->toString()); std::string varName = ast->args[i].name; trimStars(varName); ctx->add(TypecheckItem::Var, varName, - std::make_shared( - type->args[j++]->generalize(ctx->typecheckLevel))); + std::make_shared(type->args[j++])); + // N.B. this used to be: + // seqassert(type->args[j] && type->args[j]->getUnbounds().empty(), + // "unbound argument {}", type->args[j]->toString()); + // type->args[j++]->generalize(ctx->typecheckLevel) + // no idea why... most likely an old artefact, BUT if seq or sequre + // fail with weird type errors try returning this and see if it works } // Need to populate realization table in advance to make recursive functions @@ -223,7 +222,9 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) { // Realize the return type. if (auto t = realize(type->args[0])) unify(type->args[0], t); - LOG_REALIZE("done with {} / {}", type->realizedName(), oldKey); + LOG_REALIZE("[realize] done with {} / {} =>{}", type->realizedName(), oldKey, + time); + // trx.log(); // Create and store IR node and a realized AST to be used // during the code generation. @@ -239,6 +240,7 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) { } else { r->ir = ctx->cache->module->Nr(type->realizedName()); } + r->ir->setUnmangledName(ctx->cache->reverseIdentifierLookup[type->ast->name]); auto parent = type->funcParent; if (!ast->attributes.parentClass.empty() && @@ -408,6 +410,8 @@ std::pair TypecheckVisitor::inferTypes(StmtPtr result, bool keepLa ir::types::Type *TypecheckVisitor::getLLVMType(const types::ClassType *t) { auto realizedName = t->realizedTypeName(); + if (!in(ctx->cache->classes[t->name].realizations, realizedName)) + realizeType(const_cast(t)); if (auto l = ctx->cache->classes[t->name].realizations[realizedName]->ir) return l; auto getLLVM = [&](const TypePtr &tt) { diff --git a/codon/parser/visitors/typecheck/typecheck_stmt.cpp b/codon/parser/visitors/typecheck/typecheck_stmt.cpp index 42ece717..c7aa05e9 100644 --- a/codon/parser/visitors/typecheck/typecheck_stmt.cpp +++ b/codon/parser/visitors/typecheck/typecheck_stmt.cpp @@ -508,6 +508,13 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { else typ = std::make_shared( stmt->name, ctx->cache->reverseIdentifierLookup[stmt->name]); + if (stmt->isRecord() && startswith(stmt->name, TYPE_PARTIAL)) { + seqassert(in(ctx->cache->partials, stmt->name), + "invalid partial initialization: {}", stmt->name); + typ = std::make_shared(typ->getRecord(), + ctx->cache->partials[stmt->name].first, + ctx->cache->partials[stmt->name].second); + } typ->setSrcInfo(stmt->getSrcInfo()); ctx->add(TypecheckItem::Type, stmt->name, typ); ctx->bases[0].visitedAsts[stmt->name] = {TypecheckItem::Type, typ}; diff --git a/codon/sir/analyze/dataflow/cfg.h b/codon/sir/analyze/dataflow/cfg.h index 3627d32f..9946b259 100644 --- a/codon/sir/analyze/dataflow/cfg.h +++ b/codon/sir/analyze/dataflow/cfg.h @@ -481,6 +481,8 @@ public: void visit(const dsl::CustomInstr *v) override; template void process(const NodeType *v) { + if (!v) + return; if (seenIds.find(v->getId()) != seenIds.end()) return; seenIds.insert(v->getId()); diff --git a/codon/sir/attribute.cpp b/codon/sir/attribute.cpp index b27541e5..075d6b4a 100644 --- a/codon/sir/attribute.cpp +++ b/codon/sir/attribute.cpp @@ -1,5 +1,8 @@ -#include "value.h" +#include "attribute.h" +#include "codon/sir/func.h" +#include "codon/sir/util/cloning.h" +#include "codon/sir/value.h" #include "codon/util/fmt/ostream.h" namespace codon { @@ -36,5 +39,141 @@ std::ostream &MemberAttribute::doFormat(std::ostream &os) const { const std::string SrcInfoAttribute::AttributeName = "srcInfoAttribute"; +const std::string TupleLiteralAttribute::AttributeName = "tupleLiteralAttribute"; + +std::unique_ptr TupleLiteralAttribute::clone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto *val : elements) + elementsCloned.push_back(cv.clone(val)); + return std::make_unique(elementsCloned); +} + +std::unique_ptr +TupleLiteralAttribute::forceClone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto *val : elements) + elementsCloned.push_back(cv.forceClone(val)); + return std::make_unique(elementsCloned); +} + +std::ostream &TupleLiteralAttribute::doFormat(std::ostream &os) const { + std::vector strings; + for (auto *val : elements) + strings.push_back(fmt::format(FMT_STRING("{}"), *val)); + fmt::print(os, FMT_STRING("({})"), fmt::join(strings.begin(), strings.end(), ",")); + return os; +} + +const std::string ListLiteralAttribute::AttributeName = "listLiteralAttribute"; + +std::unique_ptr ListLiteralAttribute::clone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto &e : elements) + elementsCloned.push_back({cv.clone(e.value), e.star}); + return std::make_unique(elementsCloned); +} + +std::unique_ptr +ListLiteralAttribute::forceClone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto &e : elements) + elementsCloned.push_back({cv.forceClone(e.value), e.star}); + return std::make_unique(elementsCloned); +} + +std::ostream &ListLiteralAttribute::doFormat(std::ostream &os) const { + std::vector strings; + for (auto &e : elements) + strings.push_back(fmt::format(FMT_STRING("{}{}"), e.star ? "*" : "", *e.value)); + fmt::print(os, FMT_STRING("[{}]"), fmt::join(strings.begin(), strings.end(), ",")); + return os; +} + +const std::string SetLiteralAttribute::AttributeName = "setLiteralAttribute"; + +std::unique_ptr SetLiteralAttribute::clone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto &e : elements) + elementsCloned.push_back({cv.clone(e.value), e.star}); + return std::make_unique(elementsCloned); +} + +std::unique_ptr +SetLiteralAttribute::forceClone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto &e : elements) + elementsCloned.push_back({cv.forceClone(e.value), e.star}); + return std::make_unique(elementsCloned); +} + +std::ostream &SetLiteralAttribute::doFormat(std::ostream &os) const { + std::vector strings; + for (auto &e : elements) + strings.push_back(fmt::format(FMT_STRING("{}{}"), e.star ? "*" : "", *e.value)); + fmt::print(os, FMT_STRING("set([{}])"), + fmt::join(strings.begin(), strings.end(), ",")); + return os; +} + +const std::string DictLiteralAttribute::AttributeName = "dictLiteralAttribute"; + +std::unique_ptr DictLiteralAttribute::clone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto &val : elements) + elementsCloned.push_back( + {cv.clone(val.key), val.value ? cv.clone(val.value) : nullptr}); + return std::make_unique(elementsCloned); +} + +std::unique_ptr +DictLiteralAttribute::forceClone(util::CloneVisitor &cv) const { + std::vector elementsCloned; + for (auto &val : elements) + elementsCloned.push_back( + {cv.forceClone(val.key), val.value ? cv.forceClone(val.value) : nullptr}); + return std::make_unique(elementsCloned); +} + +std::ostream &DictLiteralAttribute::doFormat(std::ostream &os) const { + std::vector strings; + for (auto &val : elements) { + if (val.value) { + strings.push_back(fmt::format(FMT_STRING("{}:{}"), *val.key, *val.value)); + } else { + strings.push_back(fmt::format(FMT_STRING("**{}"), *val.key)); + } + } + fmt::print(os, FMT_STRING("dict([{}])"), + fmt::join(strings.begin(), strings.end(), ",")); + return os; +} + +const std::string PartialFunctionAttribute::AttributeName = "partialFunctionAttribute"; + +std::unique_ptr +PartialFunctionAttribute::clone(util::CloneVisitor &cv) const { + std::vector argsCloned; + for (auto *val : args) + argsCloned.push_back(cv.clone(val)); + return std::make_unique(name, argsCloned); +} + +std::unique_ptr +PartialFunctionAttribute::forceClone(util::CloneVisitor &cv) const { + std::vector argsCloned; + for (auto *val : args) + argsCloned.push_back(cv.forceClone(val)); + return std::make_unique(name, argsCloned); +} + +std::ostream &PartialFunctionAttribute::doFormat(std::ostream &os) const { + std::vector strings; + for (auto *val : args) + strings.push_back(val ? fmt::format(FMT_STRING("{}"), *val) : "..."); + fmt::print(os, FMT_STRING("{}({})"), name, + fmt::join(strings.begin(), strings.end(), ",")); + return os; +} + } // namespace ir } // namespace codon diff --git a/codon/sir/attribute.h b/codon/sir/attribute.h index 48beb5b7..eb2e945f 100644 --- a/codon/sir/attribute.h +++ b/codon/sir/attribute.h @@ -14,6 +14,13 @@ namespace codon { namespace ir { +class Func; +class Value; + +namespace util { +class CloneVisitor; +} + /// Base for SIR attributes. struct Attribute { virtual ~Attribute() noexcept = default; @@ -26,14 +33,15 @@ struct Attribute { } /// @return a clone of the attribute - std::unique_ptr clone() const { - return std::unique_ptr(doClone()); + virtual std::unique_ptr clone(util::CloneVisitor &cv) const = 0; + + /// @return a clone of the attribute + virtual std::unique_ptr forceClone(util::CloneVisitor &cv) const { + return clone(cv); } private: virtual std::ostream &doFormat(std::ostream &os) const = 0; - - virtual Attribute *doClone() const = 0; }; /// Attribute containing SrcInfo @@ -48,10 +56,12 @@ struct SrcInfoAttribute : public Attribute { /// @param info the source info explicit SrcInfoAttribute(codon::SrcInfo info) : info(std::move(info)) {} + std::unique_ptr clone(util::CloneVisitor &cv) const override { + return std::make_unique(*this); + } + private: std::ostream &doFormat(std::ostream &os) const override { return os << info; } - - Attribute *doClone() const override { return new SrcInfoAttribute(*this); } }; /// Attribute containing function information @@ -76,10 +86,12 @@ struct KeyValueAttribute : public Attribute { /// string if none std::string get(const std::string &key) const; + std::unique_ptr clone(util::CloneVisitor &cv) const override { + return std::make_unique(*this); + } + private: std::ostream &doFormat(std::ostream &os) const override; - - Attribute *doClone() const override { return new KeyValueAttribute(*this); } }; /// Attribute containing type member information @@ -95,10 +107,116 @@ struct MemberAttribute : public Attribute { explicit MemberAttribute(std::map memberSrcInfo) : memberSrcInfo(std::move(memberSrcInfo)) {} + std::unique_ptr clone(util::CloneVisitor &cv) const override { + return std::make_unique(*this); + } + private: std::ostream &doFormat(std::ostream &os) const override; +}; - Attribute *doClone() const override { return new MemberAttribute(*this); } +/// Attribute attached to IR structures corresponding to tuple literals +struct TupleLiteralAttribute : public Attribute { + static const std::string AttributeName; + + /// values contained in tuple literal + std::vector elements; + + explicit TupleLiteralAttribute(std::vector elements) + : elements(std::move(elements)) {} + + std::unique_ptr clone(util::CloneVisitor &cv) const override; + std::unique_ptr forceClone(util::CloneVisitor &cv) const override; + +private: + std::ostream &doFormat(std::ostream &os) const override; +}; + +/// Information about an element in a collection literal +struct LiteralElement { + /// the element value + Value *value; + /// true if preceded by "*", as in "[*x]" + bool star; +}; + +/// Attribute attached to IR structures corresponding to list literals +struct ListLiteralAttribute : public Attribute { + static const std::string AttributeName; + + /// elements contained in list literal + std::vector elements; + + explicit ListLiteralAttribute(std::vector elements) + : elements(std::move(elements)) {} + + std::unique_ptr clone(util::CloneVisitor &cv) const override; + std::unique_ptr forceClone(util::CloneVisitor &cv) const override; + +private: + std::ostream &doFormat(std::ostream &os) const override; +}; + +/// Attribute attached to IR structures corresponding to set literals +struct SetLiteralAttribute : public Attribute { + static const std::string AttributeName; + + /// elements contained in set literal + std::vector elements; + + explicit SetLiteralAttribute(std::vector elements) + : elements(std::move(elements)) {} + + std::unique_ptr clone(util::CloneVisitor &cv) const override; + std::unique_ptr forceClone(util::CloneVisitor &cv) const override; + +private: + std::ostream &doFormat(std::ostream &os) const override; +}; + +/// Attribute attached to IR structures corresponding to dict literals +struct DictLiteralAttribute : public Attribute { + struct KeyValuePair { + /// the key in the literal + Value *key; + /// the value in the literal, or null if key is being star-unpacked + Value *value; + }; + + static const std::string AttributeName; + + /// keys and values contained in dict literal + std::vector elements; + + explicit DictLiteralAttribute(std::vector elements) + : elements(std::move(elements)) {} + + std::unique_ptr clone(util::CloneVisitor &cv) const override; + std::unique_ptr forceClone(util::CloneVisitor &cv) const override; + +private: + std::ostream &doFormat(std::ostream &os) const override; +}; + +/// Attribute attached to IR structures corresponding to partial functions +struct PartialFunctionAttribute : public Attribute { + static const std::string AttributeName; + + /// base name of the function being used in the partial + std::string name; + + /// partial arguments, or null if none + /// e.g. "f(a, ..., b)" has elements [a, null, b] + std::vector args; + + PartialFunctionAttribute(const std::string &name, std::vector args) + : name(name), args(std::move(args)) {} + + std::unique_ptr clone(util::CloneVisitor &cv) const override; + std::unique_ptr forceClone(util::CloneVisitor &cv) const override; + +private: + std::ostream &doFormat(std::ostream &os) const override; }; } // namespace ir diff --git a/codon/sir/llvm/llvisitor.cpp b/codon/sir/llvm/llvisitor.cpp index cd0cf2ee..269094ea 100644 --- a/codon/sir/llvm/llvisitor.cpp +++ b/codon/sir/llvm/llvisitor.cpp @@ -88,8 +88,7 @@ void LLVMVisitor::registerGlobal(const Var *var) { return; if (auto *f = cast(var)) { - makeLLVMFunction(f); - insertFunc(f, func); + insertFunc(f, makeLLVMFunction(f)); } else { llvm::Type *llvmType = getLLVMType(var->getType()); if (llvmType->isVoidTy()) { @@ -145,6 +144,7 @@ llvm::Value *LLVMVisitor::getVar(const Var *var) { } } else { registerGlobal(var); + it = vars.find(var->getId()); return it->second; } } @@ -177,6 +177,7 @@ llvm::Function *LLVMVisitor::getFunc(const Func *func) { } } else { registerGlobal(func); + it = funcs.find(func->getId()); return it->second; } } @@ -210,7 +211,34 @@ std::unique_ptr LLVMVisitor::makeModule(llvm::LLVMContext &context } std::pair, std::unique_ptr> -LLVMVisitor::takeModule(const SrcInfo *src) { +LLVMVisitor::takeModule(Module *module, const SrcInfo *src) { + // process any new functions or globals + if (module) { + std::unordered_set funcsToProcess; + for (auto *var : *module) { + auto id = var->getId(); + if (auto *func = cast(var)) { + if (funcs.find(id) != funcs.end()) + continue; + else + funcsToProcess.insert(id); + } else { + if (vars.find(id) != vars.end()) + continue; + } + + registerGlobal(var); + } + + for (auto *var : *module) { + if (auto *func = cast(var)) { + if (funcsToProcess.find(func->getId()) != funcsToProcess.end()) { + process(func); + } + } + } + } + db.builder->finalize(); auto currentContext = std::move(context); auto currentModule = std::move(M); @@ -220,10 +248,25 @@ LLVMVisitor::takeModule(const SrcInfo *src) { func = nullptr; block = nullptr; value = nullptr; - for (auto &it : vars) - it.second = nullptr; - for (auto &it : funcs) - it.second = nullptr; + + for (auto it = funcs.begin(); it != funcs.end();) { + if (it->second && it->second->hasPrivateLinkage()) { + it = funcs.erase(it); + } else { + it->second = nullptr; + ++it; + } + } + + for (auto it = vars.begin(); it != vars.end();) { + if (it->second && !llvm::isa(it->second)) { + it = vars.erase(it); + } else { + it->second = nullptr; + ++it; + } + } + coro.reset(); loops.clear(); trycatch.clear(); @@ -575,8 +618,7 @@ void LLVMVisitor::visit(const Module *x) { } const Func *main = x->getMainFunc(); - makeLLVMFunction(main); - llvm::FunctionCallee realMain = func; + llvm::FunctionCallee realMain = makeLLVMFunction(main); process(main); setDebugInfoForNode(nullptr); @@ -712,12 +754,15 @@ llvm::DISubprogram *LLVMVisitor::getDISubprogramForFunc(const Func *x) { return subprogram; } -void LLVMVisitor::makeLLVMFunction(const Func *x) { +llvm::Function *LLVMVisitor::makeLLVMFunction(const Func *x) { // process LLVM functions in full immediately if (auto *llvmFunc = cast(x)) { + auto *oldFunc = func; process(llvmFunc); setDebugInfoForNode(nullptr); - return; + auto *newFunc = func; + func = oldFunc; + return newFunc; } auto *funcType = cast(x->getType()); @@ -730,11 +775,12 @@ void LLVMVisitor::makeLLVMFunction(const Func *x) { auto *llvmFuncType = llvm::FunctionType::get(returnType, argTypes, funcType->isVariadic()); const std::string functionName = getNameForFunction(x); - func = llvm::cast( + auto *f = llvm::cast( M->getOrInsertFunction(functionName, llvmFuncType).getCallee()); if (!cast(x)) { - func->setSubprogram(getDISubprogramForFunc(x)); + f->setSubprogram(getDISubprogramForFunc(x)); } + return f; } void LLVMVisitor::makeYield(llvm::Value *value, bool finalYield) { @@ -801,7 +847,7 @@ void LLVMVisitor::visit(const InternalFunc *x) { auto *funcType = cast(x->getType()); std::vector argTypes(funcType->begin(), funcType->end()); - func->setLinkage(getDefaultLinkage()); + func->setLinkage(llvm::GlobalValue::PrivateLinkage); func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline); std::vector args; for (auto it = func->arg_begin(); it != func->arg_end(); ++it) { @@ -964,7 +1010,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) { seqassert(!fail, "linking failed"); func = M->getFunction(getNameForFunction(x)); seqassert(func, "function not linked in"); - func->setLinkage(getDefaultLinkage()); + func->setLinkage(llvm::GlobalValue::PrivateLinkage); func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline); func->setSubprogram(getDISubprogramForFunc(x)); @@ -988,10 +1034,11 @@ void LLVMVisitor::visit(const BodiedFunc *x) { setDebugInfoForNode(x); auto *fnAttributes = x->getAttribute(); - if (fnAttributes && fnAttributes->has("std.internal.attributes.export")) { + if (x->isJIT() || + (fnAttributes && fnAttributes->has("std.internal.attributes.export"))) { func->setLinkage(llvm::GlobalValue::ExternalLinkage); } else { - func->setLinkage(getDefaultLinkage()); + func->setLinkage(llvm::GlobalValue::PrivateLinkage); } if (fnAttributes && fnAttributes->has("std.internal.attributes.inline")) { func->addFnAttr(llvm::Attribute::AttrKind::AlwaysInline); diff --git a/codon/sir/llvm/llvisitor.h b/codon/sir/llvm/llvisitor.h index 0191daa2..ab20fcc6 100644 --- a/codon/sir/llvm/llvisitor.h +++ b/codon/sir/llvm/llvisitor.h @@ -167,7 +167,7 @@ private: // General function helpers llvm::Value *call(llvm::FunctionCallee callee, llvm::ArrayRef args); - void makeLLVMFunction(const Func *); + llvm::Function *makeLLVMFunction(const Func *); void makeYield(llvm::Value *value = nullptr, bool finalYield = false); std::string buildLLVMCodeString(const LLVMFunc *); void callStage(const PipelineFlow::Stage *stage); @@ -290,10 +290,11 @@ public: /// Returns the current module/LLVM context and replaces them /// with new, fresh ones. References to variables or functions /// from the old module will be included as "external". + /// @param module the IR module /// @param src source information for the new module /// @return the current module/context, replaced internally std::pair, std::unique_ptr> - takeModule(const SrcInfo *src = nullptr); + takeModule(Module *module, const SrcInfo *src = nullptr); /// Sets current debug info based on a given node. /// @param node the node whose debug info to use diff --git a/codon/sir/module.cpp b/codon/sir/module.cpp index e01e8808..b5036f9a 100644 --- a/codon/sir/module.cpp +++ b/codon/sir/module.cpp @@ -142,7 +142,8 @@ Func *Module::getOrRealizeFunc(const std::string &funcName, try { return cache->realizeFunction(func, arg, gens); } catch (const exc::ParserException &e) { - LOG_IR("getOrRealizeFunc parser error: {}", e.what()); + for (int i = 0; i < e.messages.size(); i++) + LOG_IR("getOrRealizeFunc parser error at {}: {}", e.locations[i], e.messages[i]); return nullptr; } } diff --git a/codon/sir/transform/cleanup/canonical.cpp b/codon/sir/transform/cleanup/canonical.cpp index c0b962b3..6cf2da47 100644 --- a/codon/sir/transform/cleanup/canonical.cpp +++ b/codon/sir/transform/cleanup/canonical.cpp @@ -280,7 +280,7 @@ struct CanonConstSub : public RewriteRule { Value *newCall = nullptr; if (util::isConst(rhs)) { auto c = util::getConst(rhs); - if (c != -(static_cast(1) << 63)) // ensure no overflow + if (c != -(1ull << 63)) // ensure no overflow newCall = *lhs + *(M->getInt(-c)); } else if (util::isConst(rhs)) { auto c = util::getConst(rhs); diff --git a/codon/sir/transform/manager.cpp b/codon/sir/transform/manager.cpp index 8e5d12bb..a3fd51a4 100644 --- a/codon/sir/transform/manager.cpp +++ b/codon/sir/transform/manager.cpp @@ -139,12 +139,18 @@ void PassManager::invalidate(const std::string &key) { } } -void PassManager::registerStandardPasses(bool debug) { - if (debug) { +void PassManager::registerStandardPasses(PassManager::Init init) { + switch (init) { + case Init::EMPTY: + break; + case Init::DEBUG: { registerPass(std::make_unique()); registerPass(std::make_unique()); registerPass(std::make_unique()); - } else { + break; + } + case Init::RELEASE: + case Init::JIT: { // Pythonic registerPass(std::make_unique()); registerPass(std::make_unique()); @@ -174,10 +180,19 @@ void PassManager::registerStandardPasses(bool debug) { // parallel registerPass(std::make_unique()); - registerPass(std::make_unique(seKey2, rdKey, globalKey, - /*runGlobalDemoton=*/true), - /*insertBefore=*/"", {seKey2, rdKey, globalKey}, - {seKey2, rdKey, cfgKey, globalKey}); + if (init != Init::JIT) { + // Don't demote globals in JIT mode, since they might be used later + // by another user input. + registerPass( + std::make_unique(seKey2, rdKey, globalKey, + /*runGlobalDemoton=*/true), + /*insertBefore=*/"", {seKey2, rdKey, globalKey}, + {seKey2, rdKey, cfgKey, globalKey}); + } + break; + } + default: + seqassert(false, "unknown PassManager init value"); } } diff --git a/codon/sir/transform/manager.h b/codon/sir/transform/manager.h index 4a4f5317..5a231794 100644 --- a/codon/sir/transform/manager.h +++ b/codon/sir/transform/manager.h @@ -95,6 +95,7 @@ public: EMPTY, DEBUG, RELEASE, + JIT, }; static const int PASS_IT_MAX; @@ -102,16 +103,7 @@ public: explicit PassManager(Init init, std::vector disabled = {}) : km(), passes(), analyses(), executionOrder(), results(), disabled(std::move(disabled)) { - switch (init) { - case Init::EMPTY: - break; - case Init::DEBUG: - registerStandardPasses(true); - break; - case Init::RELEASE: - registerStandardPasses(false); - break; - } + registerStandardPasses(init); } explicit PassManager(bool debug = false, std::vector disabled = {}) @@ -156,7 +148,7 @@ public: private: void runPass(Module *module, const std::string &name); - void registerStandardPasses(bool debug = false); + void registerStandardPasses(Init init); void runAnalysis(Module *module, const std::string &name); void invalidate(const std::string &key); }; diff --git a/codon/sir/util/cloning.h b/codon/sir/util/cloning.h index c62ce17f..e3aaf01f 100644 --- a/codon/sir/util/cloning.h +++ b/codon/sir/util/cloning.h @@ -82,7 +82,7 @@ public: for (auto it = other->attributes_begin(); it != other->attributes_end(); ++it) { const auto *attr = other->getAttribute(*it); if (attr->needsClone()) { - ctx[id]->setAttribute(attr->clone(), *it); + ctx[id]->setAttribute(attr->clone(*this), *it); } } } @@ -125,7 +125,7 @@ public: for (auto it = other->attributes_begin(); it != other->attributes_end(); ++it) { const auto *attr = other->getAttribute(*it); if (attr->needsClone()) { - ctx[id]->setAttribute(attr->clone(), *it); + ctx[id]->setAttribute(attr->forceClone(*this), *it); } } } diff --git a/codon/sir/util/operator.h b/codon/sir/util/operator.h index 40ec34d5..41a6c81e 100644 --- a/codon/sir/util/operator.h +++ b/codon/sir/util/operator.h @@ -68,8 +68,10 @@ public: } void visit(BodiedFunc *f) override { - seen.insert(f->getBody()->getId()); - process(f->getBody()); + if (f->getBody()) { + seen.insert(f->getBody()->getId()); + process(f->getBody()); + } } LAMBDA_VISIT(VarValue); diff --git a/codon/util/common.h b/codon/util/common.h index 780ae25c..73fdc851 100644 --- a/codon/util/common.h +++ b/codon/util/common.h @@ -1,7 +1,7 @@ #pragma once +#include "llvm/Support/Path.h" #include -#include #include #include @@ -123,7 +123,7 @@ struct SrcInfo { SrcInfo() : SrcInfo("", 0, 0, 0) {} friend std::ostream &operator<<(std::ostream &out, const codon::SrcInfo &src) { - out << std::filesystem::path(src.file).filename() << ":" << src.line << ":" + out << llvm::sys::path::filename(src.file).str() << ":" << src.line << ":" << src.col; return out; } diff --git a/extra/jupyter/src/codon.cpp b/extra/jupyter/jupyter.cpp similarity index 88% rename from extra/jupyter/src/codon.cpp rename to extra/jupyter/jupyter.cpp index f6c37dec..eb068f28 100644 --- a/extra/jupyter/src/codon.cpp +++ b/extra/jupyter/jupyter.cpp @@ -1,9 +1,11 @@ -#include "codon.h" +#include "jupyter.h" #ifdef CODON_JUPYTER +#include #include #include #include +#include #include #include #include @@ -50,8 +52,24 @@ nl::json CodonJupyter::execute_request_impl(int execution_counter, const string ast::join(backtrace, " \n")); }); if (failed.empty()) { + std::string out = *result; nl::json pub_data; - pub_data["text/plain"] = *result; + if (ast::startswith(out, "\x00\x00__codon/mime__\x00")) { + std::string mime = ""; + int i = 17; + for (; i < out.size() && out[i]; i++) + mime += out[i]; + if (i < out.size() && !out[i]) { + i += 1; + } else { + mime = "text/plain"; + i = 0; + } + pub_data[mime] = out.substr(i); + LOG("> {}: {}", mime, out.substr(i)); + } else { + pub_data["text/plain"] = out; + } publish_execution_result(execution_counter, move(pub_data), nl::json::object()); return nl::json{{"status", "ok"}, {"payload", nl::json::array()}, diff --git a/extra/jupyter/src/codon.h b/extra/jupyter/jupyter.h similarity index 100% rename from extra/jupyter/src/codon.h rename to extra/jupyter/jupyter.h diff --git a/stdlib/collections.codon b/stdlib/collections.codon index b5a1a21a..75162ac4 100644 --- a/stdlib/collections.codon +++ b/stdlib/collections.codon @@ -339,6 +339,9 @@ class Counter[T](Dict[T,int]): result |= other return result + def __dict_do_op_throws__[F, Z](self, key: T, other: Z, op: F): + self.__dict_do_op__(key, other, 0, op) + @extend class Dict: diff --git a/stdlib/internal/attributes.codon b/stdlib/internal/attributes.codon index 464424da..19ddeea2 100644 --- a/stdlib/internal/attributes.codon +++ b/stdlib/internal/attributes.codon @@ -38,3 +38,10 @@ def distributive(): def C(): pass +@__attribute__ +def realize_without_self(): + pass + +@__attribute__ +def overload(): + pass diff --git a/stdlib/internal/builtin.codon b/stdlib/internal/builtin.codon index 8a1b8d62..5e1498fb 100644 --- a/stdlib/internal/builtin.codon +++ b/stdlib/internal/builtin.codon @@ -331,10 +331,13 @@ class int: return result -def _jit_display(x, s: Static[str]): - if hasattr(x, "__repr_pretty__") and s == "jupyter": - print x.__repr_pretty__() +def _jit_display(x, s: Static[str], bundle: Set[str] = Set[str]()): + if hasattr(x, "_repr_mimebundle_") and s == "jupyter": + d = x._repr_mimebundle_(bundle) + # TODO: pick appropriate mime + mime = next(d.keys()) # just pick first + print(f"\x00\x00__codon/mime__\x00{mime}\x00{d[mime]}", end='') elif hasattr(x, "__repr__"): - print x.__repr__() + print(x.__repr__(), end='') elif hasattr(x, "__str__"): - print x.__str__() + print(x.__str__(), end='') diff --git a/stdlib/internal/python.codon b/stdlib/internal/python.codon index 3301f318..e0d6142a 100644 --- a/stdlib/internal/python.codon +++ b/stdlib/internal/python.codon @@ -35,9 +35,43 @@ PyDict_SetItem = Function[[cobj, cobj, cobj], cobj](cobj()) PyDict_Next = Function[[cobj, Ptr[int], Ptr[cobj], Ptr[cobj]], int](cobj()) PyObject_GetIter = Function[[cobj], cobj](cobj()) PyIter_Next = Function[[cobj], cobj](cobj()) +PyObject_HasAttrString = Function[[cobj, cobj], int](cobj()) +PyImport_AddModule = Function[[cobj], cobj](cobj()) _PY_MODULE_CACHE = Dict[str, pyobj]() +_PY_INIT = """ +import io + +clsf = None +clsa = None +plt = None +try: + import matplotlib.figure + import matplotlib.pyplot + plt = matplotlib.pyplot + clsf = matplotlib.figure.Figure + clsa = matplotlib.artist.Artist +except ModuleNotFoundError: + pass + +def __codon_repr__(fig): + if clsf and isinstance(fig, clsf): + stream = io.StringIO() + fig.savefig(stream, format="svg") + return 'image/svg+xml', stream.getvalue() + elif clsa and isinstance(fig, list) and all( + isinstance(i, clsa) for i in fig + ): + stream = io.StringIO() + plt.gcf().savefig(stream, format="svg") + return 'image/svg+xml', stream.getvalue() + elif hasattr(fig, "_repr_html_"): + return 'text/html', fig._repr_html_() + else: + return 'text/plain', fig.__repr__() +""" + _PY_INITIALIZED = False def init(): global _PY_INITIALIZED @@ -115,8 +149,14 @@ def init(): PyObject_GetIter = dlsym(hnd, "PyObject_GetIter") global PyIter_Next PyIter_Next = dlsym(hnd, "PyIter_Next") + global PyObject_HasAttrString + PyObject_HasAttrString = dlsym(hnd, "PyObject_HasAttrString") + global PyImport_AddModule + PyImport_AddModule = dlsym(hnd, "PyImport_AddModule") + Py_Initialize() + PyRun_SimpleString(_PY_INIT.c_str()) _PY_INITIALIZED = True def ensure_initialized(): @@ -229,10 +269,19 @@ class pyobj: def get(self, T: type) -> T: return T.__from_py__(self) + def _main_module(): + m = PyImport_AddModule("__main__".c_str()) + return pyobj(m) + + def _repr_mimebundle_(self, bundle = Set[str]()): + fn = pyobj._main_module()._getattr("__codon_repr__") + assert fn.p != cobj(), "cannot find python.__codon_repr__" + mime, txt = fn.__call__(self).get(Tuple[str, str]) + return {mime: txt} + def none(): raise NotImplementedError() - # Type conversions def py(x) -> pyobj: diff --git a/test/parser/simplify_stmt.codon b/test/parser/simplify_stmt.codon index d849b7dd..598827d2 100644 --- a/test/parser/simplify_stmt.codon +++ b/test/parser/simplify_stmt.codon @@ -381,6 +381,8 @@ print log(5.5) #: 1.70475 #%% import_c_dylib,barebones from internal.dlopen import dlext RT = "./libcodonrt." + dlext() +if RT[-3:] == ".so": + RT = "build/" + RT[2:] from C import RT.seq_str_int(int) -> str as sp print sp(65) #: 65 @@ -994,4 +996,27 @@ def foo(return_, pass_, yield_, break_, continue_, print_, assert_): assert_.append(7) return return_, pass_, yield_, break_, continue_, print_, assert_ print foo([1], [1], [1], [1], [1], [1], [1]) -#: ([1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6], [1, 7]) \ No newline at end of file +#: ([1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6], [1, 7]) + + +#%% class_deduce,barebones +@deduce +class Foo: + def __init__(self, x): + self.x = [x] + self.y = 1, x + +f = Foo(1) +print(f.x, f.y, f.__class__) #: [1] (1, 1) Foo[List[int],Tuple[int,int]] + +f: Foo = Foo('s') +print(f.x, f.y, f.__class__) #: ['s'] (1, 's') Foo[List[str],Tuple[int,str]] + +@deduce +class Bar: + def __init__(self, y): + self.y = Foo(y) + +b = Bar(3.1) +print(b.y.x, b.__class__) #: [3.1] Bar[Foo[List[float],Tuple[int,float]]] + diff --git a/test/parser/types.codon b/test/parser/types.codon index 82b21a66..8f7f2f2d 100644 --- a/test/parser/types.codon +++ b/test/parser/types.codon @@ -1113,11 +1113,13 @@ print methodcaller('index')(v, 42) #: 1 def foo(x): return 1, x +@overload def foo(x, y): def foo(x, y): return f'{x}_{y}' return 2, foo(x, y) +@overload def foo(x): if x == '': return 3, 0 @@ -1126,9 +1128,19 @@ def foo(x): print foo('hi') #: (3, 2) print foo('hi', 1) #: (2, 'hi_1') +#%% fn_shadow,barebones +def foo(x): + return 1, x +print foo('hi') #: (1, 'hi') + +def foo(x): + return 2, x +print foo('hi') #: (2, 'hi') + #%% fn_overloads_error,barebones def foo(x): return 1, x +@overload def foo(x, y): return 2, x, y foo('hooooooooy!', 1, 2) #! cannot find an overload 'foo' with arguments = str, = int, = int