From ebd344f8949858f7cbaf30956f0e57653a2b79e4 Mon Sep 17 00:00:00 2001 From: "A. R. Shajii" Date: Thu, 15 Sep 2022 15:40:00 -0400 Subject: [PATCH] GPU and other updates (#52) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add nvptx pass * Fix spaces * Don't change name * Add runtime support * Add init call * Add more runtime functions * Add launch function * Add intrinsics * Fix codegen * Run GPU pass between general opt passes * Set data layout * Create context * Link libdevice * Add function remapping * Fix linkage * Fix libdevice link * Fix linking * Fix personality * Fix linking * Fix linking * Fix linking * Add internalize pass * Add more math conversions * Add more re-mappings * Fix conversions * Fix __str__ * Add decorator attribute for any decorator * Update kernel decorator * Fix kernel decorator * Fix kernel decorator * Fix kernel decorator * Fix kernel decorator * Remove old decorator * Fix pointer calc * Fix fill-in codegen * Fix linkage * Add comment * Update list conversion * Add more conversions * Add dict and set conversions * Add float32 type to IR/LLVM * Add float32 * Add float32 stdlib * Keep required global values in PTX module * Fix PTX module pruning * Fix malloc * Set will-return * Fix name cleanup * Fix access * Fix name cleanup * Fix function renaming * Update dimension API * Fix args * Clean up API * Move GPU transformations to end of opt pipeline * Fix alloc replacements * Fix naming * Target PTX 4.2 * Fix global renaming * Fix early return in static blocks; Add __realized__ function * Format * Add __llvm_name__ for functions * Add vector type to IR * SIMD support [wip] * Update kernel naming * Fix early returns; Fix SIMD calls * Fix kernel naming * Fix IR matcher * Remove module print * Update realloc * Add overloads for 32-bit float math ops * Add gpu.Pointer type for working with raw pointers * Add float32 conversion * Add to_gpu and from_gpu * clang-format * Add f32 reduction support to 
OpenMP * Fix automatic GPU class conversions * Fix conversion functions * Fix conversions * Rename self * Fix tuple conversion * Fix conversions * Fix conversions * Update PTX filename * Fix filename * Add raw function * Add GPU docs * Allow nested object conversions * Add tests (WIP) * Update SIMD * Add staticrange and statictuple loop support * SIMD updates * Add new Vec constructors * Fix UInt conversion * Fix size-0 allocs * Add more tests * Add matmul test * Rename gpu test file * Add more tests * Add alloc cache * Fix object_to_gpu * Fix frees * Fix str conversion * Fix set conversion * Fix conversions * Fix class conversion * Fix str conversion * Fix byte conversion * Fix list conversion * Fix pointer conversions * Fix conversions * Fix conversions * Update tests * Fix conversions * Fix tuple conversion * Fix tuple conversion * Fix auto conversions * Fix conversion * Fix magics * Update tests * Support GPU in JIT mode * Fix GPU+JIT * Fix kernel filename in JIT mode * Add __static_print__; Add earlyDefines; Various domination bugfixes; SimplifyContext RAII base handling * Fix global static handling * Fix float32 tests * FIx gpu module * Support OpenMP "collapse" option * Add more collapse tests * Capture generics and statics * TraitVar handling * Python exceptions / isinstance [wip; no_ci] * clang-format * Add list comparison operators * Support empty raise in IR * Add dict 'or' operator * Fix repr * Add copy module * Fix spacing * Use sm_30 * Python exceptions * TypeTrait support; Fix defaultDict * Fix earlyDefines * Add defaultdict * clang-format * Fix invalid canonicalizations * Fix empty raise * Fix copyright * Add Python numerics option * Support py-numerics in math module * Update docs * Add static Python division / modulus * Add static py numerics tests * Fix staticrange/tuple; Add KwTuple.__getitem__ * clang-format * Add gpu parameter to par * Fix globals * Don't init loop vars on loop collapse * Add par-gpu tests * Update gpu docs * Fix isinstance 
check * Remove invalid test * Add -libdevice to set custom path [skip ci] * Add release notes; bump version [skip ci] * Add libdevice docs [skip ci] Co-authored-by: Ibrahim Numanagić --- CMakeLists.txt | 13 +- codon/app/main.cpp | 13 +- codon/compiler/compiler.cpp | 20 +- codon/compiler/compiler.h | 11 +- codon/compiler/jit.cpp | 5 +- codon/parser/ast/expr.cpp | 6 +- codon/parser/ast/expr.h | 3 + codon/parser/ast/stmt.cpp | 22 +- codon/parser/ast/stmt.h | 3 + codon/parser/ast/types.cpp | 20 +- codon/parser/ast/types.h | 12 + codon/parser/cache.cpp | 6 + codon/parser/cache.h | 7 + codon/parser/peg/openmp.peg | 8 +- codon/parser/peg/peg.cpp | 1 + codon/parser/visitors/simplify/access.cpp | 82 +- codon/parser/visitors/simplify/assign.cpp | 9 +- codon/parser/visitors/simplify/class.cpp | 333 ++++---- .../parser/visitors/simplify/collections.cpp | 2 +- codon/parser/visitors/simplify/ctx.cpp | 4 +- codon/parser/visitors/simplify/ctx.h | 20 +- codon/parser/visitors/simplify/error.cpp | 2 +- codon/parser/visitors/simplify/function.cpp | 198 ++--- codon/parser/visitors/simplify/import.cpp | 10 +- codon/parser/visitors/simplify/loops.cpp | 4 +- codon/parser/visitors/simplify/simplify.cpp | 19 +- codon/parser/visitors/simplify/simplify.h | 10 +- codon/parser/visitors/translate/translate.cpp | 7 +- codon/parser/visitors/typecheck/access.cpp | 15 +- codon/parser/visitors/typecheck/assign.cpp | 6 +- codon/parser/visitors/typecheck/call.cpp | 74 +- codon/parser/visitors/typecheck/class.cpp | 21 + codon/parser/visitors/typecheck/cond.cpp | 6 +- codon/parser/visitors/typecheck/ctx.cpp | 3 +- codon/parser/visitors/typecheck/ctx.h | 2 + codon/parser/visitors/typecheck/error.cpp | 100 ++- codon/parser/visitors/typecheck/function.cpp | 10 +- codon/parser/visitors/typecheck/infer.cpp | 13 +- codon/parser/visitors/typecheck/loops.cpp | 101 ++- codon/parser/visitors/typecheck/op.cpp | 81 +- codon/parser/visitors/typecheck/typecheck.cpp | 1 + codon/parser/visitors/typecheck/typecheck.h | 3 
+ codon/runtime/gpu.cpp | 137 +++ codon/runtime/lib.cpp | 15 +- codon/runtime/lib.h | 2 +- codon/sir/llvm/gpu.cpp | 550 ++++++++++++ codon/sir/llvm/gpu.h | 19 + codon/sir/llvm/llvisitor.cpp | 68 +- codon/sir/llvm/llvisitor.h | 9 + codon/sir/llvm/llvm.h | 1 + codon/sir/llvm/optimize.cpp | 7 +- codon/sir/module.cpp | 27 +- codon/sir/module.h | 14 + codon/sir/transform/cleanup/canonical.cpp | 3 + codon/sir/transform/folding/const_fold.cpp | 69 +- codon/sir/transform/folding/const_fold.h | 11 +- codon/sir/transform/folding/folding.cpp | 4 +- codon/sir/transform/folding/folding.h | 4 +- codon/sir/transform/manager.cpp | 22 +- codon/sir/transform/manager.h | 14 +- codon/sir/transform/parallel/openmp.cpp | 358 +++++++- codon/sir/transform/parallel/schedule.cpp | 10 +- codon/sir/transform/parallel/schedule.h | 10 +- codon/sir/types/types.cpp | 8 + codon/sir/types/types.h | 35 +- codon/sir/util/format.cpp | 7 + codon/sir/util/matching.cpp | 6 + codon/sir/util/visitor.cpp | 4 + codon/sir/util/visitor.h | 6 + docs/SUMMARY.md | 1 + docs/advanced/gpu.md | 249 ++++++ docs/advanced/parallel.md | 2 + docs/intro/differences.md | 9 + docs/intro/releases.md | 44 + docs/language/extra.md | 6 + stdlib/algorithms/heapsort.codon | 2 +- stdlib/algorithms/insertionsort.codon | 2 +- stdlib/algorithms/pdqsort.codon | 2 +- stdlib/collections.codon | 89 +- stdlib/copy.codon | 14 + stdlib/gpu.codon | 732 ++++++++++++++++ stdlib/internal/__init__.codon | 7 +- stdlib/internal/builtin.codon | 7 - stdlib/internal/c_stubs.codon | 232 +++++- stdlib/internal/core.codon | 24 +- stdlib/internal/file.codon | 3 +- stdlib/internal/gc.codon | 6 +- stdlib/internal/pynumerics.codon | 168 ++++ stdlib/internal/python.codon | 37 +- stdlib/internal/types/array.codon | 1 + stdlib/internal/types/collections/dict.codon | 35 +- stdlib/internal/types/collections/list.codon | 40 +- stdlib/internal/types/collections/set.codon | 8 +- stdlib/internal/types/error.codon | 4 +- stdlib/internal/types/float.codon | 323 +++++++ 
stdlib/internal/types/int.codon | 3 + stdlib/internal/types/range.codon | 8 + stdlib/math.codon | 752 ++++++++++++++++- stdlib/openmp.codon | 32 + stdlib/pickle.codon | 9 + stdlib/simd.codon | 310 +++++++ test/core/containers.codon | 227 ++++- test/core/exceptions.codon | 36 + test/core/numerics.codon | 786 ++++++++++++++++++ test/core/serialization.codon | 1 + test/main.cpp | 36 +- test/parser/simplify_expr.codon | 2 +- test/parser/simplify_stmt.codon | 95 ++- test/parser/types.codon | 124 +++ test/python/pybridge.codon | 4 +- test/stdlib/math_test.codon | 507 +++++++++++ test/stdlib/random_test.codon | 1 + test/transform/canonical.codon | 13 - test/transform/kernels.codon | 261 ++++++ test/transform/omp.codon | 162 ++++ 115 files changed, 7505 insertions(+), 617 deletions(-) create mode 100644 codon/runtime/gpu.cpp create mode 100644 codon/sir/llvm/gpu.cpp create mode 100644 codon/sir/llvm/gpu.h create mode 100644 docs/advanced/gpu.md create mode 100644 stdlib/copy.codon create mode 100644 stdlib/gpu.codon create mode 100644 stdlib/internal/pynumerics.codon create mode 100644 stdlib/simd.codon create mode 100644 test/core/numerics.codon create mode 100644 test/transform/kernels.codon diff --git a/CMakeLists.txt b/CMakeLists.txt index ae35b788..a476070a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.14) project( Codon - VERSION "0.13.0" + VERSION "0.14.0" HOMEPAGE_URL "https://github.com/exaloop/codon" DESCRIPTION "high-performance, extensible Python compiler") configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in" @@ -10,6 +10,7 @@ configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in" "${PROJECT_SOURCE_DIR}/extra/python/config/config.py") option(CODON_JUPYTER "build Codon Jupyter server" OFF) +option(CODON_GPU "build Codon GPU backend" OFF) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") @@ -59,7 +60,8 @@ add_custom_command( # Codon runtime library set(CODONRT_FILES codon/runtime/lib.h 
codon/runtime/lib.cpp - codon/runtime/re.cpp codon/runtime/exc.cpp) + codon/runtime/re.cpp codon/runtime/exc.cpp + codon/runtime/gpu.cpp) add_library(codonrt SHARED ${CODONRT_FILES}) add_dependencies(codonrt zlibstatic gc backtrace bz2 liblzma re2) target_include_directories(codonrt PRIVATE ${backtrace_SOURCE_DIR} @@ -90,6 +92,11 @@ if(ASAN) codonrt PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address" "-fsanitize-recover=address") endif() +if(CODON_GPU) + add_compile_definitions(CODON_GPU) + find_package(CUDAToolkit REQUIRED) + target_link_libraries(codonrt PRIVATE CUDA::cudart_static CUDA::cuda_driver) +endif() add_custom_command( TARGET codonrt POST_BUILD @@ -148,6 +155,7 @@ set(CODON_HPPFILES codon/sir/llvm/coro/CoroInternal.h codon/sir/llvm/coro/CoroSplit.h codon/sir/llvm/coro/Coroutines.h + codon/sir/llvm/gpu.h codon/sir/llvm/llvisitor.h codon/sir/llvm/llvm.h codon/sir/llvm/optimize.h @@ -295,6 +303,7 @@ set(CODON_CPPFILES codon/sir/llvm/coro/CoroFrame.cpp codon/sir/llvm/coro/CoroSplit.cpp codon/sir/llvm/coro/Coroutines.cpp + codon/sir/llvm/gpu.cpp codon/sir/llvm/llvisitor.cpp codon/sir/llvm/optimize.cpp codon/sir/module.cpp diff --git a/codon/app/main.cpp b/codon/app/main.cpp index 6b3f6daf..42fe328e 100644 --- a/codon/app/main.cpp +++ b/codon/app/main.cpp @@ -72,6 +72,7 @@ void initLogFlags(const llvm::cl::opt &log) { enum BuildKind { LLVM, Bitcode, Object, Executable, Library, Detect }; enum OptMode { Debug, Release }; +enum Numerics { C, Python }; } // namespace int docMode(const std::vector &args, const std::string &argv0) { @@ -116,6 +117,14 @@ std::unique_ptr processSource(const std::vector & llvm::cl::list plugins("plugin", llvm::cl::desc("Load specified plugin")); llvm::cl::opt log("log", llvm::cl::desc("Enable given log streams")); + llvm::cl::opt numerics( + "numerics", llvm::cl::desc("numerical semantics"), + llvm::cl::values( + clEnumValN(C, "c", "C semantics: best performance but deviates from Python"), + clEnumValN(Python, "py", + "Python 
semantics: mirrors Python but might disable optimizations " + "like vectorization")), + llvm::cl::init(C)); llvm::cl::ParseCommandLineOptions(args.size(), args.data()); initLogFlags(log); @@ -148,7 +157,9 @@ std::unique_ptr processSource(const std::vector & const bool isDebug = (optMode == OptMode::Debug); std::vector disabledOptsVec(disabledOpts); - auto compiler = std::make_unique(args[0], isDebug, disabledOptsVec); + auto compiler = std::make_unique(args[0], isDebug, disabledOptsVec, + /*isTest=*/false, + (numerics == Numerics::Python)); compiler->getLLVMVisitor()->setStandalone(standalone); // load plugins diff --git a/codon/compiler/compiler.cpp b/codon/compiler/compiler.cpp index fd48fc1e..c3bcc4f6 100644 --- a/codon/compiler/compiler.cpp +++ b/codon/compiler/compiler.cpp @@ -29,15 +29,17 @@ ir::transform::PassManager::Init getPassManagerInit(Compiler::Mode mode, bool is } // namespace Compiler::Compiler(const std::string &argv0, Compiler::Mode mode, - const std::vector &disabledPasses, bool isTest) - : argv0(argv0), debug(mode == Mode::DEBUG), input(), + const std::vector &disabledPasses, bool isTest, + bool pyNumerics) + : argv0(argv0), debug(mode == Mode::DEBUG), pyNumerics(pyNumerics), input(), plm(std::make_unique()), cache(std::make_unique(argv0)), module(std::make_unique()), pm(std::make_unique(getPassManagerInit(mode, isTest), - disabledPasses)), + disabledPasses, pyNumerics)), llvisitor(std::make_unique()) { cache->module = module.get(); + cache->pythonCompat = pyNumerics; module->setCache(cache.get()); llvisitor->setDebug(debug); llvisitor->setPluginManager(plm.get()); @@ -77,8 +79,9 @@ Compiler::parse(bool isCode, const std::string &file, const std::string &code, Timer t2("simplify"); t2.logged = true; - auto transformed = ast::SimplifyVisitor::apply(cache.get(), std::move(codeStmt), - abspath, defines, (testFlags > 1)); + auto transformed = + ast::SimplifyVisitor::apply(cache.get(), std::move(codeStmt), abspath, defines, + getEarlyDefines(), 
(testFlags > 1)); LOG_TIME("[T] parse = {:.1f}", totalPeg); LOG_TIME("[T] simplify = {:.1f}", t2.elapsed() - totalPeg); @@ -171,4 +174,11 @@ llvm::Expected Compiler::docgen(const std::vector &fil } } +std::unordered_map Compiler::getEarlyDefines() { + std::unordered_map earlyDefines; + earlyDefines.emplace("__debug__", debug ? "1" : "0"); + earlyDefines.emplace("__py_numerics__", pyNumerics ? "1" : "0"); + return earlyDefines; +} + } // namespace codon diff --git a/codon/compiler/compiler.h b/codon/compiler/compiler.h index 8d0a8001..d11bae32 100644 --- a/codon/compiler/compiler.h +++ b/codon/compiler/compiler.h @@ -25,6 +25,7 @@ public: private: std::string argv0; bool debug; + bool pyNumerics; std::string input; std::unique_ptr plm; std::unique_ptr cache; @@ -38,12 +39,14 @@ private: public: Compiler(const std::string &argv0, Mode mode, - const std::vector &disabledPasses = {}, bool isTest = false); + const std::vector &disabledPasses = {}, bool isTest = false, + bool pyNumerics = false); explicit Compiler(const std::string &argv0, bool debug = false, const std::vector &disabledPasses = {}, - bool isTest = false) - : Compiler(argv0, debug ? Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest) {} + bool isTest = false, bool pyNumerics = false) + : Compiler(argv0, debug ? 
Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest, + pyNumerics) {} std::string getInput() const { return input; } PluginManager *getPluginManager() const { return plm.get(); } @@ -62,6 +65,8 @@ public: const std::unordered_map &defines = {}); llvm::Error compile(); llvm::Expected docgen(const std::vector &files); + + std::unordered_map getEarlyDefines(); }; } // namespace codon diff --git a/codon/compiler/jit.cpp b/codon/compiler/jit.cpp index c1c1d5a6..32640003 100644 --- a/codon/compiler/jit.cpp +++ b/codon/compiler/jit.cpp @@ -37,8 +37,9 @@ llvm::Error JIT::init() { auto *pm = compiler->getPassManager(); auto *llvisitor = compiler->getLLVMVisitor(); - auto transformed = ast::SimplifyVisitor::apply( - cache, std::make_shared(), JIT_FILENAME, {}); + auto transformed = + ast::SimplifyVisitor::apply(cache, std::make_shared(), + JIT_FILENAME, {}, compiler->getEarlyDefines()); auto typechecked = ast::TypecheckVisitor::apply(cache, std::move(transformed)); ast::TranslateVisitor::apply(cache, std::move(typechecked)); diff --git a/codon/parser/ast/expr.cpp b/codon/parser/ast/expr.cpp index 6c9f8744..967475c7 100644 --- a/codon/parser/ast/expr.cpp +++ b/codon/parser/ast/expr.cpp @@ -5,6 +5,7 @@ #include #include "codon/parser/ast.h" +#include "codon/parser/cache.h" #include "codon/parser/visitors/visitor.h" #define ACCEPT_IMPL(T, X) \ @@ -17,7 +18,7 @@ namespace codon::ast { Expr::Expr() : type(nullptr), isTypeExpr(false), staticValue(StaticValue::NOT_STATIC), - done(false), attributes(0) {} + done(false), attributes(0), origExpr(nullptr) {} void Expr::validate() const {} types::TypePtr Expr::getType() const { return type; } void Expr::setType(types::TypePtr t) { this->type = std::move(t); } @@ -65,7 +66,8 @@ Param::Param(std::string name, ExprPtr type, ExprPtr defaultValue, int status) : name(std::move(name)), type(std::move(type)), defaultValue(std::move(defaultValue)) { if (status == 0 && this->type && - (this->type->isId("type") || this->type->isId("TypeVar") || + 
(this->type->isId("type") || this->type->isId(TYPE_TYPEVAR) || + (this->type->getIndex() && this->type->getIndex()->expr->isId(TYPE_TYPEVAR)) || getStaticGeneric(this->type.get()))) this->status = Generic; else diff --git a/codon/parser/ast/expr.h b/codon/parser/ast/expr.h index 9e67e3e3..c4bf2f3b 100644 --- a/codon/parser/ast/expr.h +++ b/codon/parser/ast/expr.h @@ -80,6 +80,9 @@ struct Expr : public codon::SrcObject { /// Set of attributes. int attributes; + /// Original (pre-transformation) expression + std::shared_ptr origExpr; + public: Expr(); Expr(const Expr &expr) = default; diff --git a/codon/parser/ast/stmt.cpp b/codon/parser/ast/stmt.cpp index 84899c5b..49d49de6 100644 --- a/codon/parser/ast/stmt.cpp +++ b/codon/parser/ast/stmt.cpp @@ -339,11 +339,15 @@ std::string FunctionStmt::toString(int indent) const { std::vector as; for (auto &a : args) as.push_back(a.toString()); - std::vector attr; + std::vector dec, attr; for (auto &a : decorators) - attr.push_back(format("(dec {})", a->toString())); - return format("(fn '{} ({}){}{}{}{})", name, join(as, " "), + if (a) + dec.push_back(format("(dec {})", a->toString())); + for (auto &a : attributes.customAttr) + attr.push_back(format("'{}'", a)); + return format("(fn '{} ({}){}{}{}{}{})", name, join(as, " "), ret ? " #:ret " + ret->toString() : "", + dec.empty() ? "" : format(" (dec {})", join(dec, " ")), attr.empty() ? "" : format(" (attr {})", join(attr, " ")), pad, suite ? suite->toString(indent >= 0 ? 
indent + INDENT_SIZE : -1) : "(suite)"); @@ -493,10 +497,11 @@ void ClassStmt::parseDecorators() { // @extend std::map tupleMagics = { - {"new", true}, {"repr", false}, {"hash", false}, {"eq", false}, - {"ne", false}, {"lt", false}, {"le", false}, {"gt", false}, - {"ge", false}, {"pickle", true}, {"unpickle", true}, {"to_py", false}, - {"from_py", false}, {"iter", false}, {"getitem", false}, {"len", false}}; + {"new", true}, {"repr", false}, {"hash", false}, {"eq", false}, + {"ne", false}, {"lt", false}, {"le", false}, {"gt", false}, + {"ge", false}, {"pickle", true}, {"unpickle", true}, {"to_py", false}, + {"from_py", false}, {"iter", false}, {"getitem", false}, {"len", false}, + {"to_gpu", false}, {"from_gpu", false}, {"from_gpu_new", false}}; for (auto &d : decorators) { if (d->isId("deduce")) { @@ -531,6 +536,9 @@ void ClassStmt::parseDecorators() { tupleMagics["pickle"] = tupleMagics["unpickle"] = val; } else if (a.name == "python") { tupleMagics["to_py"] = tupleMagics["from_py"] = val; + } else if (a.name == "gpu") { + tupleMagics["to_gpu"] = tupleMagics["from_gpu"] = + tupleMagics["from_gpu_new"] = val; } else if (a.name == "container") { tupleMagics["iter"] = tupleMagics["getitem"] = val; } else { diff --git a/codon/parser/ast/stmt.h b/codon/parser/ast/stmt.h index 8fdf0e65..3aa4d912 100644 --- a/codon/parser/ast/stmt.h +++ b/codon/parser/ast/stmt.h @@ -238,6 +238,9 @@ struct WhileStmt : public Stmt { StmtPtr suite; /// nullptr if there is no else suite. StmtPtr elseSuite; + /// Set if a while loop is used to emulate goto statement + /// (as `while gotoVar: ...`). 
+ std::string gotoVar = ""; WhileStmt(ExprPtr cond, StmtPtr suite, StmtPtr elseSuite = nullptr); WhileStmt(const WhileStmt &stmt); diff --git a/codon/parser/ast/types.cpp b/codon/parser/ast/types.cpp index d8618b02..3fae7f7a 100644 --- a/codon/parser/ast/types.cpp +++ b/codon/parser/ast/types.cpp @@ -312,10 +312,9 @@ std::string ClassType::debugString(bool debug) const { if (!a.name.empty()) gs.push_back(a.type->debugString(debug)); if (debug && !hiddenGenerics.empty()) { - gs.emplace_back("//"); for (auto &a : hiddenGenerics) if (!a.name.empty()) - gs.push_back(a.type->debugString(debug)); + gs.push_back("-" + a.type->debugString(debug)); } // Special formatting for Functions and Tuples auto n = niceName; @@ -829,4 +828,21 @@ std::string CallableTrait::debugString(bool debug) const { return fmt::format("Callable[{}]", join(gs, ",")); } +TypeTrait::TypeTrait(TypePtr typ) : type(std::move(typ)) {} +int TypeTrait::unify(Type *typ, Unification *us) { return typ->unify(type.get(), us); } +TypePtr TypeTrait::generalize(int atLevel) { + auto c = std::make_shared(type->generalize(atLevel)); + c->setSrcInfo(getSrcInfo()); + return c; +} +TypePtr TypeTrait::instantiate(int atLevel, int *unboundCount, + std::unordered_map *cache) { + auto c = std::make_shared(type->instantiate(atLevel, unboundCount, cache)); + c->setSrcInfo(getSrcInfo()); + return c; +} +std::string TypeTrait::debugString(bool debug) const { + return fmt::format("Trait[{}]", type->debugString(debug)); +} + } // namespace codon::ast::types diff --git a/codon/parser/ast/types.h b/codon/parser/ast/types.h index af16ee6e..10b0d563 100644 --- a/codon/parser/ast/types.h +++ b/codon/parser/ast/types.h @@ -422,5 +422,17 @@ public: std::string debugString(bool debug) const override; }; +struct TypeTrait : public Trait { + TypePtr type; + +public: + explicit TypeTrait(TypePtr type); + int unify(Type *typ, Unification *undo) override; + TypePtr generalize(int atLevel) override; + TypePtr instantiate(int atLevel, int 
*unboundCount, + std::unordered_map *cache) override; + std::string debugString(bool debug) const override; +}; + } // namespace types } // namespace codon::ast diff --git a/codon/parser/cache.cpp b/codon/parser/cache.cpp index ad3cbf95..8fab7311 100644 --- a/codon/parser/cache.cpp +++ b/codon/parser/cache.cpp @@ -31,6 +31,12 @@ std::string Cache::rev(const std::string &s) { return ""; } +void Cache::addGlobal(const std::string &name, ir::Var *var) { + if (!in(globals, name)) { + globals[name] = var; + } +} + SrcInfo Cache::generateSrcInfo() { return {FILE_GENERATED, generatedSrcInfoCount, generatedSrcInfoCount++, 0}; } diff --git a/codon/parser/cache.h b/codon/parser/cache.h index 0e9477a3..d75a6b03 100644 --- a/codon/parser/cache.h +++ b/codon/parser/cache.h @@ -19,6 +19,7 @@ #define TYPE_TUPLE "Tuple.N" #define TYPE_KWTUPLE "KwTuple.N" +#define TYPE_TYPEVAR "TypeVar" #define TYPE_CALLABLE "Callable" #define TYPE_PARTIAL "Partial.N" #define TYPE_OPTIONAL "Optional" @@ -28,6 +29,7 @@ #define MAX_INT_WIDTH 10000 #define MAX_REALIZATION_DEPTH 200 +#define MAX_STATIC_ITER 1024 namespace codon::ast { @@ -207,6 +209,9 @@ struct Cache : public std::enable_shared_from_this { std::unordered_map generatedTuples; std::vector errors; + /// Set if Codon operates in Python compatibility mode (e.g., with Python numerics) + bool pythonCompat = false; + public: explicit Cache(std::string argv0 = ""); @@ -220,6 +225,8 @@ public: SrcInfo generateSrcInfo(); /// Get file contents at the given location. std::string getContent(const SrcInfo &info); + /// Register a global identifier. + void addGlobal(const std::string &name, ir::Var *var = nullptr); /// Realization API. 
diff --git a/codon/parser/peg/openmp.peg b/codon/parser/peg/openmp.peg index c8bcad17..2d2893a4 100644 --- a/codon/parser/peg/openmp.peg +++ b/codon/parser/peg/openmp.peg @@ -29,9 +29,15 @@ clause <- / "num_threads" _ "(" _ int _ ")" { return vector{{"num_threads", make_shared(ac(V0))}}; } - / "ordered" { + / "ordered" { return vector{{"ordered", make_shared(true)}}; } + / "collapse" { + return vector{{"collapse", make_shared(ac(V0))}}; + } + / "gpu" { + return vector{{"gpu", make_shared(true)}}; + } schedule_kind <- ("static" / "dynamic" / "guided" / "auto" / "runtime") { return VS.token_to_string(); } diff --git a/codon/parser/peg/peg.cpp b/codon/parser/peg/peg.cpp index e78cdcd0..308a2e41 100644 --- a/codon/parser/peg/peg.cpp +++ b/codon/parser/peg/peg.cpp @@ -34,6 +34,7 @@ std::shared_ptr initParser() { x.second.accept(v); } (*g)["program"].enablePackratParsing = true; + (*g)["fstring"].enablePackratParsing = true; for (auto &rule : std::vector{ "arguments", "slices", "genexp", "parentheses", "star_parens", "generics", "with_parens_item", "params", "from_as_parens", "from_params"}) { diff --git a/codon/parser/visitors/simplify/access.cpp b/codon/parser/visitors/simplify/access.cpp index 7cba0dad..d957640a 100644 --- a/codon/parser/visitors/simplify/access.cpp +++ b/codon/parser/visitors/simplify/access.cpp @@ -82,8 +82,9 @@ void SimplifyVisitor::visit(IdExpr *expr) { /// Flatten imports. /// @example -/// `a.b.c` -> canonical name of `c` in `a.b` if `a.b` is an import -/// `a.B.c` -> canonical name of `c` in class `a.B` +/// `a.b.c` -> canonical name of `c` in `a.b` if `a.b` is an import +/// `a.B.c` -> canonical name of `c` in class `a.B` +/// `python.foo` -> internal.python._get_identifier("foo") /// Other cases are handled during the type checking. 
void SimplifyVisitor::visit(DotExpr *expr) { // First flatten the imports: @@ -99,7 +100,11 @@ void SimplifyVisitor::visit(DotExpr *expr) { std::reverse(chain.begin(), chain.end()); auto p = getImport(chain); - if (p.second->getModule() == ctx->getModule() && p.first == 1) { + if (p.second->getModule() == "std.python") { + resultExpr = transform(N( + N(N(N("internal"), "python"), "_get_identifier"), + N(chain[p.first++]))); + } else if (p.second->getModule() == ctx->getModule() && p.first == 1) { resultExpr = transform(N(chain[0]), true); } else { resultExpr = N(p.second->canonicalName); @@ -120,57 +125,74 @@ void SimplifyVisitor::visit(DotExpr *expr) { bool SimplifyVisitor::checkCapture(const SimplifyContext::Item &val) { if (!ctx->isOuter(val)) return false; + if ((val->isType() && !val->isGeneric()) || val->isFunc()) + return false; // Ensure that outer variables can be captured (i.e., do not cross no-capture // boundary). Example: // def foo(): // x = 1 - // class T: # <- boundary (class methods cannot capture locals) - // def bar(): + // class T: # <- boundary (classes cannot capture locals) + // t: int = x # x cannot be accessed + // def bar(): # <- another boundary + // # (class methods cannot capture locals except class generics) // print(x) # x cannot be accessed bool crossCaptureBoundary = false; + bool localGeneric = val->isGeneric() && val->getBaseName() == ctx->getBaseName(); + bool parentClassGeneric = + val->isGeneric() && !ctx->getBase()->isType() && + (ctx->bases.size() > 1 && ctx->bases[ctx->bases.size() - 2].isType() && + ctx->bases[ctx->bases.size() - 2].name == val->getBaseName()); auto i = ctx->bases.size(); for (; i-- > 0;) { if (ctx->bases[i].name == val->getBaseName()) break; - if (!ctx->bases[i].captures) + if (!localGeneric && !parentClassGeneric && !ctx->bases[i].captures) crossCaptureBoundary = true; } seqassert(i < ctx->bases.size(), "invalid base for '{}'", val->canonicalName); - // Disallow outer generics except for class generics in 
methods - if (val->isGeneric() && !(ctx->bases[i].isType() && i + 2 == ctx->bases.size())) - error("cannot access nonlocal variable '{}'", ctx->cache->rev(val->canonicalName)); - // Mark methods (class functions that access class generics) - if (val->isGeneric() && ctx->bases[i].isType() && i + 2 == ctx->bases.size() && - ctx->getBase()->attributes) + if (parentClassGeneric) ctx->getBase()->attributes->set(Attr::Method); - // Check if a real variable (not a static) is defined outside the current scope - if (!val->isVar() || val->isGeneric()) + // Ignore generics + if (parentClassGeneric || localGeneric) return false; // Case: a global variable that has not been marked with `global` statement - if (val->getBaseName().empty()) { /// TODO: use isGlobal instead? + if (val->isVar() && val->getBaseName().empty()) { val->noShadow = true; - if (val->scope.size() == 1 && !in(ctx->cache->globals, val->canonicalName)) - ctx->cache->globals[val->canonicalName] = nullptr; + if (val->scope.size() == 1) + ctx->cache->addGlobal(val->canonicalName); return false; } + // Check if a real variable (not a static) is defined outside the current scope + if (crossCaptureBoundary) + error("cannot access nonlocal variable '{}'", ctx->cache->rev(val->canonicalName)); + // Case: a nonlocal variable that has not been marked with `nonlocal` statement // and capturing is enabled auto captures = ctx->getBase()->captures; - if (!crossCaptureBoundary && captures && !in(*captures, val->canonicalName)) { + if (captures && !in(*captures, val->canonicalName)) { // Captures are transformed to function arguments; generate new name for that // argument - auto newName = (*captures)[val->canonicalName] = - ctx->generateCanonicalName(val->canonicalName); + ExprPtr typ = nullptr; + if (val->isType()) + typ = N("type"); + if (auto st = val->isStatic()) + typ = N(N("Static"), + N(st == StaticValue::INT ? 
"int" : "str")); + auto [newName, _] = (*captures)[val->canonicalName] = { + ctx->generateCanonicalName(val->canonicalName), typ}; ctx->cache->reverseIdentifierLookup[newName] = newName; // Add newly generated argument to the context - auto newVal = - ctx->addVar(ctx->cache->rev(val->canonicalName), newName, getSrcInfo()); + std::shared_ptr newVal = nullptr; + if (val->isType()) + newVal = ctx->addType(ctx->cache->rev(val->canonicalName), newName, getSrcInfo()); + else + newVal = ctx->addVar(ctx->cache->rev(val->canonicalName), newName, getSrcInfo()); newVal->baseName = ctx->getBaseName(); newVal->noShadow = true; return true; @@ -206,10 +228,18 @@ SimplifyVisitor::getImport(const std::vector &chain) { size_t itemEnd = 0; auto fctx = importName.empty() ? ctx : ctx->cache->imports[importName].ctx; for (auto i = chain.size(); i-- > importEnd;) { - val = fctx->find(join(chain, ".", importEnd, i + 1)); - if (val && (importName.empty() || val->isType() || !val->isConditional())) { - itemName = val->canonicalName, itemEnd = i + 1; - break; + if (fctx->getModule() == "std.python" && importEnd < chain.size()) { + // Special case: importing from Python. 
+ // Fake SimplifyItem that indicates std.python access + val = std::make_shared(SimplifyItem::Var, "", "", + fctx->getModule(), std::vector{}); + return {importEnd, val}; + } else { + val = fctx->find(join(chain, ".", importEnd, i + 1)); + if (val && (importName.empty() || val->isType() || !val->isConditional())) { + itemName = val->canonicalName, itemEnd = i + 1; + break; + } + } } if (itemName.empty() && importName.empty()) diff --git a/codon/parser/visitors/simplify/assign.cpp b/codon/parser/visitors/simplify/assign.cpp index 31be553e..e3d2f4a7 100644 --- a/codon/parser/visitors/simplify/assign.cpp +++ b/codon/parser/visitors/simplify/assign.cpp @@ -147,14 +147,15 @@ StmtPtr SimplifyVisitor::transformAssignment(ExprPtr lhs, ExprPtr rhs, ExprPtr t if (rhs && rhs->isType()) { ctx->addType(e->value, canonical, lhs->getSrcInfo()); } else { - ctx->addVar(e->value, canonical, lhs->getSrcInfo()); + auto val = ctx->addVar(e->value, canonical, lhs->getSrcInfo()); + if (auto st = getStaticGeneric(type.get())) + val->staticType = st; } // Register all toplevel variables as global in JIT mode bool isGlobal = (ctx->cache->isJit && ctx->isGlobal()) || (canonical == VAR_ARGV); - if (isGlobal && !in(ctx->cache->globals, canonical)) { - ctx->cache->globals[canonical] = nullptr; - } + if (isGlobal) + ctx->cache->addGlobal(canonical); return assign; } diff --git a/codon/parser/visitors/simplify/class.cpp b/codon/parser/visitors/simplify/class.cpp index 3d8ac691..15fa3de4 100644 --- a/codon/parser/visitors/simplify/class.cpp +++ b/codon/parser/visitors/simplify/class.cpp @@ -50,169 +50,170 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { argsToParse = astIter->second.ast->args; } - // Add the class base - ctx->bases.emplace_back(SimplifyContext::Base(canonicalName)); - ctx->addBlock(); - - // Parse and add class generics - std::vector args; - std::pair autoDeducedInit{nullptr, nullptr}; - if (stmt->attributes.has("deduce") && args.empty()) { - // Auto-detect generics and fields - 
autoDeducedInit = autoDeduceMembers(stmt, args); - } else { - // Add all generics before parent classes, fields and methods - for (auto &a : argsToParse) { - if (a.status != Param::Generic) - continue; - std::string genName, varName; - if (stmt->attributes.has(Attr::Extend)) - varName = a.name, genName = ctx->cache->rev(a.name); - else - varName = ctx->generateCanonicalName(a.name), genName = a.name; - if (getStaticGeneric(a.type.get())) - ctx->addVar(genName, varName, a.type->getSrcInfo())->generic = true; - else - ctx->addType(genName, varName, a.type->getSrcInfo())->generic = true; - args.emplace_back(Param{varName, transformType(clone(a.type), false), - transformType(clone(a.defaultValue), false), a.status}); - } - } - - // Form class type node (e.g. `Foo`, or `Foo[T, U]` for generic classes) - ExprPtr typeAst = N(name); - for (auto &a : args) { - if (a.status == Param::Generic) { - if (!typeAst->getIndex()) - typeAst = N(N(name), N()); - typeAst->getIndex()->index->getTuple()->items.push_back(N(a.name)); - } - } - - // Collect classes (and their fields) that are to be statically inherited - auto baseASTs = parseBaseClasses(stmt->baseClasses, args, stmt->attributes); - - // A ClassStmt will be separated into class variable assignments, method-free - // ClassStmts (that include nested classes) and method FunctionStmts std::vector clsStmts; // Will be filled later! std::vector varStmts; // Will be filled later! std::vector fnStmts; // Will be filled later! - transformNestedClasses(stmt, clsStmts, varStmts, fnStmts); + std::vector addLater; + { + // Add the class base + SimplifyContext::BaseGuard br(ctx.get(), canonicalName); - // Collect class fields - for (auto &a : argsToParse) { - if (a.status == Param::Normal) { - if (!ClassStmt::isClassVar(a)) { - args.emplace_back(Param{a.name, transformType(clone(a.type), false), - transform(clone(a.defaultValue), true)}); - } else if (!stmt->attributes.has(Attr::Extend)) { - // Handle class variables. 
Transform them later to allow self-references - auto name = format("{}.{}", canonicalName, a.name); - preamble->push_back(N(N(name), nullptr, nullptr)); - if (!in(ctx->cache->globals, name)) - ctx->cache->globals[name] = nullptr; - auto assign = N(N(name), a.defaultValue, - a.type ? a.type->getIndex()->index : nullptr); - assign->setUpdate(); - varStmts.push_back(assign); - ctx->cache->classes[canonicalName].classVars[a.name] = name; + // Parse and add class generics + std::vector args; + std::pair autoDeducedInit{nullptr, nullptr}; + if (stmt->attributes.has("deduce") && args.empty()) { + // Auto-detect generics and fields + autoDeducedInit = autoDeduceMembers(stmt, args); + } else { + // Add all generics before parent classes, fields and methods + for (auto &a : argsToParse) { + if (a.status != Param::Generic) + continue; + std::string genName, varName; + if (stmt->attributes.has(Attr::Extend)) + varName = a.name, genName = ctx->cache->rev(a.name); + else + varName = ctx->generateCanonicalName(a.name), genName = a.name; + if (auto st = getStaticGeneric(a.type.get())) { + auto val = ctx->addVar(genName, varName, a.type->getSrcInfo()); + val->generic = true; + val->staticType = st; + } else { + ctx->addType(genName, varName, a.type->getSrcInfo())->generic = true; + } + args.emplace_back(Param{varName, transformType(clone(a.type), false), + transformType(clone(a.defaultValue), false), a.status}); } } - } - // ASTs for member arguments to be used for populating magic methods - std::vector memberArgs; - for (auto &a : args) { - if (a.status == Param::Normal) - memberArgs.push_back(a.clone()); - } - - // Ensure that all fields and class variables are registered - if (!stmt->attributes.has(Attr::Extend)) { - for (size_t ai = 0; ai < args.size();) { - if (args[ai].status == Param::Normal) - ctx->cache->classes[canonicalName].fields.push_back({args[ai].name, nullptr}); - ai++; + // Form class type node (e.g. 
`Foo`, or `Foo[T, U]` for generic classes) + ExprPtr typeAst = N(name); + for (auto &a : args) { + if (a.status == Param::Generic) { + if (!typeAst->getIndex()) + typeAst = N(N(name), N()); + typeAst->getIndex()->index->getTuple()->items.push_back(N(a.name)); + } } - } - // Parse class members (arguments) and methods - if (!stmt->attributes.has(Attr::Extend)) { - // Now that we are done with arguments, add record type to the context - if (stmt->attributes.has(Attr::Tuple)) { - // Ensure that class binding does not shadow anything. - // Class bindings cannot be dominated either - auto v = ctx->find(name); - if (v && v->noShadow) - error("cannot update global/nonlocal"); - ctx->add(name, classItem); - ctx->addAlwaysVisible(classItem); - } - // Create a cached AST. - stmt->attributes.module = - format("{}{}", ctx->moduleName.status == ImportFile::STDLIB ? "std::" : "::", - ctx->moduleName.module); - ctx->cache->classes[canonicalName].ast = - N(canonicalName, args, N(), stmt->attributes); - for (auto &b : baseASTs) - ctx->cache->classes[canonicalName].parentClasses.emplace_back(b->name); - ctx->cache->classes[canonicalName].ast->validate(); + // Collect classes (and their fields) that are to be statically inherited + auto baseASTs = parseBaseClasses(stmt->baseClasses, args, stmt->attributes); - // Codegen default magic methods - for (auto &m : stmt->attributes.magics) { - fnStmts.push_back(transform( - codegenMagic(m, typeAst, memberArgs, stmt->attributes.has(Attr::Tuple)))); - } - // Add inherited methods - for (auto &base : baseASTs) { - for (auto &mm : ctx->cache->classes[base->name].methods) - for (auto &mf : ctx->cache->overloads[mm.second]) { - auto f = ctx->cache->functions[mf.name].ast; - if (!f->attributes.has("autogenerated")) { - std::string rootName; - auto &mts = ctx->cache->classes[ctx->getBase()->name].methods; - auto it = mts.find(ctx->cache->rev(f->name)); - if (it != mts.end()) - rootName = it->second; - else - rootName = 
ctx->generateCanonicalName(ctx->cache->rev(f->name), true); - auto newCanonicalName = - format("{}:{}", rootName, ctx->cache->overloads[rootName].size()); - ctx->cache->overloads[rootName].push_back( - {newCanonicalName, ctx->cache->age}); - ctx->cache->reverseIdentifierLookup[newCanonicalName] = - ctx->cache->rev(f->name); - auto nf = std::dynamic_pointer_cast(f->clone()); - nf->name = newCanonicalName; - nf->attributes.parentClass = ctx->getBase()->name; - ctx->cache->functions[newCanonicalName].ast = nf; - ctx->cache->classes[ctx->getBase()->name] - .methods[ctx->cache->rev(f->name)] = rootName; - fnStmts.push_back(nf); - } + // A ClassStmt will be separated into class variable assignments, method-free + // ClassStmts (that include nested classes) and method FunctionStmts + transformNestedClasses(stmt, clsStmts, varStmts, fnStmts); + + // Collect class fields + for (auto &a : argsToParse) { + if (a.status == Param::Normal) { + if (!ClassStmt::isClassVar(a)) { + args.emplace_back(Param{a.name, transformType(clone(a.type), false), + transform(clone(a.defaultValue), true)}); + } else if (!stmt->attributes.has(Attr::Extend)) { + // Handle class variables. Transform them later to allow self-references + auto name = format("{}.{}", canonicalName, a.name); + preamble->push_back(N(N(name), nullptr, nullptr)); + ctx->cache->addGlobal(name); + auto assign = N(N(name), a.defaultValue, + a.type ? a.type->getIndex()->index : nullptr); + assign->setUpdate(); + varStmts.push_back(assign); + ctx->cache->classes[canonicalName].classVars[a.name] = name; } - } - // Add auto-deduced __init__ (if available) - if (autoDeducedInit.first) - fnStmts.push_back(autoDeducedInit.first); - } - // Add class methods - for (const auto &sp : getClassMethods(stmt->suite)) - if (sp && sp->getFunction()) { - if (sp.get() != autoDeducedInit.second) - fnStmts.push_back(transform(sp)); + } } - // After popping context block, record types and nested classes will disappear. 
- // Store their references and re-add them to the context after popping - std::vector addLater; - addLater.reserve(clsStmts.size() + 1); - for (auto &c : clsStmts) - addLater.push_back(ctx->find(c->getClass()->name)); - if (stmt->attributes.has(Attr::Tuple)) - addLater.push_back(ctx->forceFind(name)); - ctx->bases.pop_back(); - ctx->popBlock(); + // ASTs for member arguments to be used for populating magic methods + std::vector memberArgs; + for (auto &a : args) { + if (a.status == Param::Normal) + memberArgs.push_back(a.clone()); + } + + // Ensure that all fields and class variables are registered + if (!stmt->attributes.has(Attr::Extend)) { + for (size_t ai = 0; ai < args.size();) { + if (args[ai].status == Param::Normal) + ctx->cache->classes[canonicalName].fields.push_back({args[ai].name, nullptr}); + ai++; + } + } + + // Parse class members (arguments) and methods + if (!stmt->attributes.has(Attr::Extend)) { + // Now that we are done with arguments, add record type to the context + if (stmt->attributes.has(Attr::Tuple)) { + // Ensure that class binding does not shadow anything. + // Class bindings cannot be dominated either + auto v = ctx->find(name); + if (v && v->noShadow) + error("cannot update global/nonlocal"); + ctx->add(name, classItem); + ctx->addAlwaysVisible(classItem); + } + // Create a cached AST. + stmt->attributes.module = + format("{}{}", ctx->moduleName.status == ImportFile::STDLIB ? 
"std::" : "::", + ctx->moduleName.module); + ctx->cache->classes[canonicalName].ast = + N(canonicalName, args, N(), stmt->attributes); + for (auto &b : baseASTs) + ctx->cache->classes[canonicalName].parentClasses.emplace_back(b->name); + ctx->cache->classes[canonicalName].ast->validate(); + + // Codegen default magic methods + for (auto &m : stmt->attributes.magics) { + fnStmts.push_back(transform( + codegenMagic(m, typeAst, memberArgs, stmt->attributes.has(Attr::Tuple)))); + } + // Add inherited methods + for (auto &base : baseASTs) { + for (auto &mm : ctx->cache->classes[base->name].methods) + for (auto &mf : ctx->cache->overloads[mm.second]) { + auto f = ctx->cache->functions[mf.name].ast; + if (!f->attributes.has("autogenerated")) { + std::string rootName; + auto &mts = ctx->cache->classes[ctx->getBase()->name].methods; + auto it = mts.find(ctx->cache->rev(f->name)); + if (it != mts.end()) + rootName = it->second; + else + rootName = ctx->generateCanonicalName(ctx->cache->rev(f->name), true); + auto newCanonicalName = + format("{}:{}", rootName, ctx->cache->overloads[rootName].size()); + ctx->cache->overloads[rootName].push_back( + {newCanonicalName, ctx->cache->age}); + ctx->cache->reverseIdentifierLookup[newCanonicalName] = + ctx->cache->rev(f->name); + auto nf = std::dynamic_pointer_cast(f->clone()); + nf->name = newCanonicalName; + nf->attributes.parentClass = ctx->getBase()->name; + ctx->cache->functions[newCanonicalName].ast = nf; + ctx->cache->classes[ctx->getBase()->name] + .methods[ctx->cache->rev(f->name)] = rootName; + fnStmts.push_back(nf); + } + } + } + // Add auto-deduced __init__ (if available) + if (autoDeducedInit.first) + fnStmts.push_back(autoDeducedInit.first); + } + // Add class methods + for (const auto &sp : getClassMethods(stmt->suite)) + if (sp && sp->getFunction()) { + if (sp.get() != autoDeducedInit.second) + fnStmts.push_back(transform(sp)); + } + + // After popping context block, record types and nested classes will disappear. 
+ // Store their references and re-add them to the context after popping + addLater.reserve(clsStmts.size() + 1); + for (auto &c : clsStmts) + addLater.push_back(ctx->find(c->getClass()->name)); + if (stmt->attributes.has(Attr::Tuple)) + addLater.push_back(ctx->forceFind(name)); + } for (auto &i : addLater) ctx->add(ctx->cache->rev(i->canonicalName), i); @@ -279,10 +280,13 @@ SimplifyVisitor::parseBaseClasses(const std::vector &baseClasses, args.emplace_back(a); } if (a.status != Param::Normal) { - if (getStaticGeneric(a.type.get())) - ctx->addVar(a.name, a.name, a.type->getSrcInfo())->generic = true; - else + if (auto st = getStaticGeneric(a.type.get())) { + auto val = ctx->addVar(a.name, a.name, a.type->getSrcInfo()); + val->generic = true; + val->staticType = st; + } else { ctx->addType(a.name, a.name, a.type->getSrcInfo())->generic = true; + } } } if (si != subs.size()) @@ -632,6 +636,29 @@ StmtPtr SimplifyVisitor::codegenMagic(const std::string &op, const ExprPtr &typE N(clone(args[i].type), "__from_py__"), N(N(I("pyobj"), "_tuple_get"), I("src"), N(i)))); stmts.emplace_back(N(N(typExpr->clone(), ar))); + } else if (op == "to_gpu") { + // def __to_gpu__(self: T, cache) -> T: + // return __internal__.class_to_gpu(self, cache) + fargs.emplace_back(Param{"self", typExpr->clone()}); + fargs.emplace_back(Param{"cache"}); + ret = typExpr->clone(); + stmts.emplace_back(N(N( + N(I("__internal__"), "class_to_gpu"), I("self"), I("cache")))); + } else if (op == "from_gpu") { + // def __from_gpu__(self: T, other: T) -> None: + // __internal__.class_from_gpu(self, other) + fargs.emplace_back(Param{"self", typExpr->clone()}); + fargs.emplace_back(Param{"other", typExpr->clone()}); + ret = I("NoneType"); + stmts.emplace_back(N(N( + N(I("__internal__"), "class_from_gpu"), I("self"), I("other")))); + } else if (op == "from_gpu_new") { + // def __from_gpu_new__(other: T) -> T: + // return __internal__.class_from_gpu_new(other) + fargs.emplace_back(Param{"other", 
typExpr->clone()}); + ret = typExpr->clone(); + stmts.emplace_back(N( + N(N(I("__internal__"), "class_from_gpu_new"), I("other")))); } else if (op == "repr") { // def __repr__(self: T) -> str: // a = __array__[str](N) (number of args) diff --git a/codon/parser/visitors/simplify/collections.cpp b/codon/parser/visitors/simplify/collections.cpp index 31b1a72d..3583af86 100644 --- a/codon/parser/visitors/simplify/collections.cpp +++ b/codon/parser/visitors/simplify/collections.cpp @@ -13,7 +13,7 @@ namespace codon::ast { /// The rest will be handled during the type-checking stage. void SimplifyVisitor::visit(TupleExpr *expr) { for (auto &i : expr->items) - transform(i); + transform(i, true); // types needed for some constructs (e.g., isinstance) } /// Transform a list `[a1, ..., aN]` to the corresponding statement expression. diff --git a/codon/parser/visitors/simplify/ctx.cpp b/codon/parser/visitors/simplify/ctx.cpp index ed1c2b47..e6414067 100644 --- a/codon/parser/visitors/simplify/ctx.cpp +++ b/codon/parser/visitors/simplify/ctx.cpp @@ -160,8 +160,8 @@ SimplifyContext::Item SimplifyContext::findDominatingBinding(const std::string & std::make_unique(false), nullptr)); // Reached the toplevel? Register the binding as global. if (prefix == 1) { - cache->globals[canonicalName] = nullptr; - cache->globals[fmt::format("{}.__used__", canonicalName)] = nullptr; + cache->addGlobal(canonicalName); + cache->addGlobal(fmt::format("{}.__used__", canonicalName)); } hasUsed = true; } diff --git a/codon/parser/visitors/simplify/ctx.h b/codon/parser/visitors/simplify/ctx.h index baee4f03..f3b92455 100644 --- a/codon/parser/visitors/simplify/ctx.h +++ b/codon/parser/visitors/simplify/ctx.h @@ -40,6 +40,8 @@ struct SimplifyItem : public SrcObject { bool noShadow = false; /// Set if an identifier is a class or a function generic bool generic = false; + /// Set if an identifier is a static variable. 
+ char staticType = 0; public: SimplifyItem(Kind kind, std::string baseName, std::string canonicalName, @@ -61,6 +63,7 @@ public: /// (i.e., a block that might not be executed during the runtime) bool isConditional() const { return scope.size() > 1; } bool isGeneric() const { return generic; } + char isStatic() const { return staticType; } }; /** Context class that tracks identifiers during the simplification. **/ @@ -99,8 +102,9 @@ struct SimplifyContext : public Context { /// Map of captured identifiers (i.e., identifiers not defined in a function). /// Captured (canonical) identifiers are mapped to the new canonical names /// (representing the canonical function argument names that are appended to the - /// function after processing). - std::unordered_map *captures; + /// function after processing) and their types (indicating if they are a type, a + /// static or a variable). + std::unordered_map> *captures; /// A stack of nested loops enclosing the current statement used for transforming /// "break" statement in loop-else constructs. Each loop is defined by a "break" @@ -123,6 +127,18 @@ struct SimplifyContext : public Context { /// Current base stack (the last enclosing base is the last base in the stack). std::vector bases; + struct BaseGuard { + SimplifyContext *holder; + BaseGuard(SimplifyContext *holder, const std::string &name) : holder(holder) { + holder->bases.emplace_back(Base(name)); + holder->addBlock(); + } + ~BaseGuard() { + holder->bases.pop_back(); + holder->popBlock(); + } + }; + /// Set of seen global identifiers used to prevent later creation of local variables /// with the same name. 
std::unordered_map> diff --git a/codon/parser/visitors/simplify/error.cpp b/codon/parser/visitors/simplify/error.cpp index f0eea9b1..0542b46a 100644 --- a/codon/parser/visitors/simplify/error.cpp +++ b/codon/parser/visitors/simplify/error.cpp @@ -42,7 +42,7 @@ void SimplifyVisitor::visit(TryStmt *stmt) { c.var = ctx->generateCanonicalName(c.var); ctx->addVar(ctx->cache->rev(c.var), c.var, c.suite->getSrcInfo()); } - transformType(c.exc); + transform(c.exc, true); transformConditionalScope(c.suite); ctx->leaveConditionalBlock(); } diff --git a/codon/parser/visitors/simplify/function.cpp b/codon/parser/visitors/simplify/function.cpp index 2806ea16..784ae912 100644 --- a/codon/parser/visitors/simplify/function.cpp +++ b/codon/parser/visitors/simplify/function.cpp @@ -74,8 +74,7 @@ void SimplifyVisitor::visit(GlobalStmt *stmt) { stmt->var); // Register as global if needed - if (!in(ctx->cache->globals, val->canonicalName)) - ctx->cache->globals[val->canonicalName] = nullptr; + ctx->cache->addGlobal(val->canonicalName); val = ctx->addVar(stmt->var, val->canonicalName, stmt->getSrcInfo()); val->baseName = ctx->getBaseName(); @@ -111,9 +110,11 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { // Parse attributes for (auto i = stmt->decorators.size(); i-- > 0;) { - if (auto n = isAttribute(stmt->decorators[i])) { - stmt->attributes.set(*n); - stmt->decorators[i] = nullptr; // remove it from further consideration + auto [isAttr, attrName] = getDecorator(stmt->decorators[i]); + if (!attrName.empty()) { + stmt->attributes.set(attrName); + if (isAttr) + stmt->decorators[i] = nullptr; // remove it from further consideration } } @@ -153,89 +154,92 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { ctx->addAlwaysVisible(funcVal); } - // Set up the base - ctx->bases.emplace_back(SimplifyContext::Base{canonicalName}); - ctx->addBlock(); - ctx->getBase()->attributes = &(stmt->attributes); - - // Parse arguments and add them to the context std::vector args; - for (auto &a : 
stmt->args) { - std::string varName = a.name; - int stars = trimStars(varName); - auto name = ctx->generateCanonicalName(varName); - - // Mark as method if the first argument is self - if (isClassMember && stmt->attributes.has(Attr::HasSelf) && a.name == "self") { - ctx->getBase()->selfName = name; - stmt->attributes.set(Attr::Method); - } - - // Handle default values - auto defaultValue = a.defaultValue; - if (a.type && defaultValue && defaultValue->getNone()) { - // Special case: `arg: Callable = None` -> `arg: Callable = NoneType()` - if (a.type->getIndex() && a.type->getIndex()->expr->isId(TYPE_CALLABLE)) - defaultValue = N(N("NoneType")); - // Special case: `arg: type = None` -> `arg: type = NoneType` - if (a.type->isId("type") || a.type->isId("TypeVar")) - defaultValue = N("NoneType"); - } - /// TODO: Uncomment for Python-style defaults - // if (defaultValue) { - // auto defaultValueCanonicalName = - // ctx->generateCanonicalName(format("{}.{}", canonicalName, name)); - // prependStmts->push_back(N(N(defaultValueCanonicalName), - // defaultValue)); - // defaultValue = N(defaultValueCanonicalName); - // } - args.emplace_back( - Param{std::string(stars, '*') + name, a.type, defaultValue, a.status}); - - // Add generics to the context - if (a.status != Param::Normal) { - if (getStaticGeneric(a.type.get())) - ctx->addVar(varName, name, stmt->getSrcInfo())->generic = true; - else - ctx->addType(varName, name, stmt->getSrcInfo())->generic = true; - } - } - // Parse arguments to the context. Needs to be done after adding generics - // to support cases like `foo(a: T, T: type)` - for (auto &a : args) { - a.type = transformType(a.type, false); - a.defaultValue = transform(a.defaultValue, true); - } - // Add non-generic arguments to the context. 
Delayed to prevent cases like - // `def foo(a, b=a)` - for (auto &a : args) { - if (a.status == Param::Normal) { - std::string canName = a.name; - trimStars(canName); - ctx->addVar(ctx->cache->rev(canName), canName, stmt->getSrcInfo()); - } - } - - // Parse the return type - auto ret = transformType(stmt->ret, false); - - // Parse function body StmtPtr suite = nullptr; - std::unordered_map captures; - if (!stmt->attributes.has(Attr::Internal) && !stmt->attributes.has(Attr::C)) { - if (stmt->attributes.has(Attr::LLVM)) { - suite = transformLLVMDefinition(stmt->suite->firstInBlock()); - } else if (stmt->attributes.has(Attr::C)) { - // Do nothing - } else { - if ((isEnclosedFunc || stmt->attributes.has(Attr::Capture)) && !isClassMember) - ctx->getBase()->captures = &captures; - suite = SimplifyVisitor(ctx, preamble).transformConditionalScope(stmt->suite); + ExprPtr ret = nullptr; + std::unordered_map> captures; + { + // Set up the base + SimplifyContext::BaseGuard br(ctx.get(), canonicalName); + ctx->getBase()->attributes = &(stmt->attributes); + + // Parse arguments and add them to the context + for (auto &a : stmt->args) { + std::string varName = a.name; + int stars = trimStars(varName); + auto name = ctx->generateCanonicalName(varName); + + // Mark as method if the first argument is self + if (isClassMember && stmt->attributes.has(Attr::HasSelf) && a.name == "self") { + ctx->getBase()->selfName = name; + stmt->attributes.set(Attr::Method); + } + + // Handle default values + auto defaultValue = a.defaultValue; + if (a.type && defaultValue && defaultValue->getNone()) { + // Special case: `arg: Callable = None` -> `arg: Callable = NoneType()` + if (a.type->getIndex() && a.type->getIndex()->expr->isId(TYPE_CALLABLE)) + defaultValue = N(N("NoneType")); + // Special case: `arg: type = None` -> `arg: type = NoneType` + if (a.type->isId("type") || a.type->isId(TYPE_TYPEVAR)) + defaultValue = N("NoneType"); + } + /// TODO: Uncomment for Python-style defaults + // if 
(defaultValue) { + // auto defaultValueCanonicalName = + // ctx->generateCanonicalName(format("{}.{}", canonicalName, name)); + // prependStmts->push_back(N(N(defaultValueCanonicalName), + // defaultValue)); + // defaultValue = N(defaultValueCanonicalName); + // } + args.emplace_back( + Param{std::string(stars, '*') + name, a.type, defaultValue, a.status}); + + // Add generics to the context + if (a.status != Param::Normal) { + if (auto st = getStaticGeneric(a.type.get())) { + auto val = ctx->addVar(varName, name, stmt->getSrcInfo()); + val->generic = true; + val->staticType = st; + } else { + ctx->addType(varName, name, stmt->getSrcInfo())->generic = true; + } + } + } + + // Parse arguments to the context. Needs to be done after adding generics + // to support cases like `foo(a: T, T: type)` + for (auto &a : args) { + a.type = transformType(a.type, false); + a.defaultValue = transform(a.defaultValue, true); + } + // Add non-generic arguments to the context. Delayed to prevent cases like + // `def foo(a, b=a)` + for (auto &a : args) { + if (a.status == Param::Normal) { + std::string canName = a.name; + trimStars(canName); + ctx->addVar(ctx->cache->rev(canName), canName, stmt->getSrcInfo()); + } + } + + // Parse the return type + ret = transformType(stmt->ret, false); + + // Parse function body + if (!stmt->attributes.has(Attr::Internal) && !stmt->attributes.has(Attr::C)) { + if (stmt->attributes.has(Attr::LLVM)) { + suite = transformLLVMDefinition(stmt->suite->firstInBlock()); + } else if (stmt->attributes.has(Attr::C)) { + // Do nothing + } else { + if ((isEnclosedFunc || stmt->attributes.has(Attr::Capture)) && !isClassMember) + ctx->getBase()->captures = &captures; + suite = SimplifyVisitor(ctx, preamble).transformConditionalScope(stmt->suite); + } } } - - ctx->bases.pop_back(); - ctx->popBlock(); stmt->attributes.module = format("{}{}", ctx->moduleName.status == ImportFile::STDLIB ? 
"std::" : "::", ctx->moduleName.module); @@ -259,8 +263,8 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) { args.pop_back(); } for (auto &c : captures) { - args.emplace_back(Param{c.second, nullptr, nullptr}); - partialArgs.push_back({c.second, N(ctx->cache->rev(c.first))}); + args.emplace_back(Param{c.second.first, c.second.second, nullptr}); + partialArgs.push_back({c.second.first, N(ctx->cache->rev(c.first))}); } if (!kw.name.empty()) args.push_back(kw); @@ -417,20 +421,22 @@ StmtPtr SimplifyVisitor::transformLLVMDefinition(Stmt *codeStmt) { return N(items); } -/// Check if a decorator is actually an attribute (a function with `@__attribute__`) -std::string *SimplifyVisitor::isAttribute(const ExprPtr &e) { +/// Fetch a decorator canonical name. The first pair member indicates if a decorator is +/// actually an attribute (a function with `@__attribute__`). +std::pair SimplifyVisitor::getDecorator(const ExprPtr &e) { auto dt = transform(clone(e)); - if (dt && dt->getId()) { - auto ci = ctx->find(dt->getId()->value); + auto id = dt->getCall() ? 
dt->getCall()->expr : dt; + if (id && id->getId()) { + auto ci = ctx->find(id->getId()->value); if (ci && ci->isFunc()) { - if (ctx->cache->overloads[ci->canonicalName].size() == 1) - if (ctx->cache->functions[ctx->cache->overloads[ci->canonicalName][0].name] - .ast->attributes.isAttribute) { - return &(ci->canonicalName); - } + if (ctx->cache->overloads[ci->canonicalName].size() == 1) { + return {ctx->cache->functions[ctx->cache->overloads[ci->canonicalName][0].name] + .ast->attributes.isAttribute, + ci->canonicalName}; + } } } - return nullptr; + return {false, ""}; } } // namespace codon::ast diff --git a/codon/parser/visitors/simplify/import.cpp b/codon/parser/visitors/simplify/import.cpp index 45ca3702..e47a8b3c 100644 --- a/codon/parser/visitors/simplify/import.cpp +++ b/codon/parser/visitors/simplify/import.cpp @@ -319,8 +319,7 @@ StmtPtr SimplifyVisitor::transformNewImport(const ImportFile &file) { // `import_[I]_done = False` (set to True upon successful import) preamble->push_back(N(N(importDoneVar = importVar + "_done"), N(false))); - if (!in(ctx->cache->globals, importDoneVar)) - ctx->cache->globals[importDoneVar] = nullptr; + ctx->cache->addGlobal(importDoneVar); // Wrap all imported top-level statements into a function. // Make sure to register the global variables and set their assignments as updates. 
@@ -332,11 +331,8 @@ StmtPtr SimplifyVisitor::transformNewImport(const ImportFile &file) { if (!a->isUpdate() && a->lhs->getId()) { // Global `a = ...` auto val = ictx->forceFind(a->lhs->getId()->value); - if (val->isVar() && val->isGlobal() && !getStaticGeneric(a->type.get())) { - // Register global - if (!in(ctx->cache->globals, val->canonicalName)) - ctx->cache->globals[val->canonicalName] = nullptr; - } + if (val->isVar() && val->isGlobal() && !getStaticGeneric(a->type.get())) + ctx->cache->addGlobal(val->canonicalName); } } stmts.push_back(s); diff --git a/codon/parser/visitors/simplify/loops.cpp b/codon/parser/visitors/simplify/loops.cpp index 6bfd371d..2080f4a6 100644 --- a/codon/parser/visitors/simplify/loops.cpp +++ b/codon/parser/visitors/simplify/loops.cpp @@ -97,7 +97,7 @@ void SimplifyVisitor::visit(ForStmt *stmt) { ctx->addVar(i->value, varName = ctx->generateCanonicalName(i->value), stmt->var->getSrcInfo()); transform(stmt->var); - transform(stmt->suite); + stmt->suite = transform(N(stmt->suite)); } else { varName = ctx->cache->getTemporaryVar("for"); ctx->addVar(varName, varName, stmt->var->getSrcInfo()); @@ -109,7 +109,7 @@ void SimplifyVisitor::visit(ForStmt *stmt) { stmts.push_back(stmt->suite); stmt->suite = transform(N(stmts)); } - ctx->leaveConditionalBlock(); + ctx->leaveConditionalBlock(&(stmt->suite->getSuite()->stmts)); // Dominate loop variables for (auto &var : ctx->getBase()->getLoop()->seenVars) ctx->findDominatingBinding(var); diff --git a/codon/parser/visitors/simplify/simplify.cpp b/codon/parser/visitors/simplify/simplify.cpp index 02a495f6..0726b9e9 100644 --- a/codon/parser/visitors/simplify/simplify.cpp +++ b/codon/parser/visitors/simplify/simplify.cpp @@ -20,12 +20,12 @@ using namespace types; /// @param cache Pointer to the shared cache ( @c Cache ) /// @param file Filename to be used for error reporting /// @param barebones Use the bare-bones standard library for faster testing -/// @param defines User-defined static values 
(typically passed as `codon run -DX=Y -/// ...`). +/// @param defines User-defined static values (typically passed as `codon run -DX=Y`). /// Each value is passed as a string. StmtPtr SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &file, const std::unordered_map &defines, + const std::unordered_map &earlyDefines, bool barebones) { auto preamble = std::make_shared>(); seqassertn(cache->module, "cache's module is not set"); @@ -53,6 +53,20 @@ SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &fil stdlib->moduleName = {ImportFile::STDLIB, stdlibPath->path, "__init__"}; // Load the standard library stdlib->setFilename(stdlibPath->path); + // Core definitions + preamble->push_back(SimplifyVisitor(stdlib, preamble) + .transform(parseCode(stdlib->cache, stdlibPath->path, + "from internal.core import *"))); + for (auto &d : earlyDefines) { + // Load early compile-time defines (for standard library) + preamble->push_back( + SimplifyVisitor(stdlib, preamble) + .transform(std::make_shared( + std::make_shared(d.first), + std::make_shared(d.second), + std::make_shared(std::make_shared("Static"), + std::make_shared("int"))))); + } preamble->push_back(SimplifyVisitor(stdlib, preamble) .transform(parseFile(stdlib->cache, stdlibPath->path))); stdlib->isStdlibLoading = false; @@ -181,6 +195,7 @@ StmtPtr SimplifyVisitor::transform(StmtPtr &stmt) { stmt->accept(v); } catch (const exc::ParserException &e) { ctx->cache->errors.push_back(e); + // throw; } ctx->popSrcInfo(); if (v.resultStmt) diff --git a/codon/parser/visitors/simplify/simplify.h b/codon/parser/visitors/simplify/simplify.h index 90c61e41..007b715c 100644 --- a/codon/parser/visitors/simplify/simplify.h +++ b/codon/parser/visitors/simplify/simplify.h @@ -43,9 +43,11 @@ class SimplifyVisitor : public CallbackASTVisitor { StmtPtr resultStmt; public: - static StmtPtr apply(Cache *cache, const StmtPtr &node, const std::string &file, - const std::unordered_map &defines, - 
bool barebones = false); + static StmtPtr + apply(Cache *cache, const StmtPtr &node, const std::string &file, + const std::unordered_map &defines = {}, + const std::unordered_map &earlyDefines = {}, + bool barebones = false); static StmtPtr apply(const std::shared_ptr &cache, const StmtPtr &node, const std::string &file, int atAge = -1); @@ -169,7 +171,7 @@ private: // Node simplification rules StmtPtr transformPythonDefinition(const std::string &, const std::vector &, const Expr *, Stmt *); StmtPtr transformLLVMDefinition(Stmt *); - std::string *isAttribute(const ExprPtr &); + std::pair getDecorator(const ExprPtr &); /* Classes (class.cpp) */ void visit(ClassStmt *) override; diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index 6c45d269..8d0cb762 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -409,7 +409,10 @@ void TranslateVisitor::visit(ForStmt *stmt) { bool ordered = fc->funcGenerics[1].type->getStatic()->expr->staticValue.getInt(); auto threads = transform(c->args[0].value); auto chunk = transform(c->args[1].value); - os = std::make_unique(schedule, threads, chunk, ordered); + int64_t collapse = + fc->funcGenerics[2].type->getStatic()->expr->staticValue.getInt(); + bool gpu = fc->funcGenerics[3].type->getStatic()->expr->staticValue.getInt(); + os = std::make_unique(schedule, threads, chunk, ordered, collapse, gpu); LOG_TYPECHECK("parsed {}", stmt->decorator->toString()); } @@ -487,7 +490,7 @@ void TranslateVisitor::visit(TryStmt *stmt) { } void TranslateVisitor::visit(ThrowStmt *stmt) { - result = make(stmt, transform(stmt->expr)); + result = make(stmt, stmt->expr ? 
transform(stmt->expr) : nullptr); } void TranslateVisitor::visit(FunctionStmt *stmt) { diff --git a/codon/parser/visitors/typecheck/access.cpp b/codon/parser/visitors/typecheck/access.cpp index dba240ef..b0e11e0e 100644 --- a/codon/parser/visitors/typecheck/access.cpp +++ b/codon/parser/visitors/typecheck/access.cpp @@ -24,15 +24,6 @@ void TypecheckVisitor::visit(IdExpr *expr) { if (isTuple(expr->value)) generateTuple(std::stoi(expr->value.substr(sizeof(TYPE_TUPLE) - 1))); - // Handle empty callable references - if (expr->value == TYPE_CALLABLE) { - auto typ = ctx->getUnbound(); - typ->getLink()->trait = std::make_shared(std::vector{}); - unify(expr->type, typ); - expr->markType(); - return; - } - // Replace identifiers that have been superseded by domination analysis during the // simplification while (auto s = in(ctx->cache->replacements, expr->value)) @@ -160,6 +151,12 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, if (expr->expr->type->getFunc() && expr->member == "__name__") { return transform(N(expr->expr->type->toString())); } + // Special case: fn.__llvm_name__ + if (expr->expr->type->getFunc() && expr->member == "__llvm_name__") { + if (realize(expr->expr->type)) + return transform(N(expr->expr->type->realizedName())); + return nullptr; + } // Special case: cls.__name__ if (expr->expr->isType() && expr->member == "__name__") { if (realize(expr->expr->type)) diff --git a/codon/parser/visitors/typecheck/assign.cpp b/codon/parser/visitors/typecheck/assign.cpp index 8b82983e..04ab82f2 100644 --- a/codon/parser/visitors/typecheck/assign.cpp +++ b/codon/parser/visitors/typecheck/assign.cpp @@ -69,7 +69,11 @@ void TypecheckVisitor::visit(AssignStmt *stmt) { seqassert(stmt->rhs->staticValue.evaluated, "static not evaluated"); unify(stmt->lhs->type, unify(stmt->type->type, std::make_shared(stmt->rhs, ctx))); - ctx->add(TypecheckItem::Var, lhs, stmt->lhs->type); + auto val = ctx->add(TypecheckItem::Var, lhs, stmt->lhs->type); + if (in(ctx->cache->globals, 
lhs)) { + // Make globals always visible! + ctx->addToplevel(lhs, val); + } if (realize(stmt->lhs->type)) stmt->setDone(); } else { diff --git a/codon/parser/visitors/typecheck/call.cpp b/codon/parser/visitors/typecheck/call.cpp index 35a06168..f945ac06 100644 --- a/codon/parser/visitors/typecheck/call.cpp +++ b/codon/parser/visitors/typecheck/call.cpp @@ -320,7 +320,7 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e if (!part.known.empty()) { auto e = getPartialArg(-1); auto t = e->getType()->getRecord(); - seqassert(t && startswith(t->name, "KwTuple"), "{} not a kwtuple", + seqassert(t && startswith(t->name, TYPE_KWTUPLE), "{} not a kwtuple", e->toString()); auto &ff = ctx->cache->classes[t->name].fields; for (int i = 0; i < t->getRecord()->args.size(); i++) { @@ -410,8 +410,9 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e if (typeArgs[si]) { auto typ = typeArgs[si]->type; if (calleeFn->funcGenerics[si].type->isStaticType()) { - if (!typeArgs[si]->isStatic()) + if (!typeArgs[si]->isStatic()) { error("expected static expression"); + } typ = std::make_shared(typeArgs[si], ctx); } unify(typ, calleeFn->funcGenerics[si].type); @@ -543,6 +544,10 @@ std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) return {true, transformCompileError(expr)}; } else if (val == "tuple") { return {true, transformTupleFn(expr)}; + } else if (val == "__realized__") { + return {true, transformRealizedFn(expr)}; + } else if (val == "__static_print__") { + return {false, transformStaticPrintFn(expr)}; } else { return {false, nullptr}; } @@ -668,7 +673,6 @@ ExprPtr TypecheckVisitor::transformArray(CallExpr *expr) { /// `isinstance(obj, ByVal)` is True if `type(obj)` is a tuple type /// `isinstance(obj, ByRef)` is True if `type(obj)` is a reference type ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) { - expr->staticValue.type = StaticValue::INT; expr->setType(unify(expr->type, 
ctx->getType("bool"))); transform(expr->args[0].value); auto typ = expr->args[0].value->type->getClass(); @@ -678,21 +682,44 @@ ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) { transform(expr->args[0].value); // transform again to realize it auto &typExpr = expr->args[1].value; + if (auto c = typExpr->getCall()) { + // Handle `isinstance(obj, (type1, type2, ...))` + if (typExpr->origExpr && typExpr->origExpr->getTuple()) { + ExprPtr result = transform(N(false)); + for (auto &i : typExpr->origExpr->getTuple()->items) { + result = transform(N( + result, "||", + N(N("isinstance"), expr->args[0].value, i))); + } + return result; + } + } + + expr->staticValue.type = StaticValue::INT; if (typExpr->isId("Tuple") || typExpr->isId("tuple")) { return transform(N(startswith(typ->name, TYPE_TUPLE))); } else if (typExpr->isId("ByVal")) { return transform(N(typ->getRecord() != nullptr)); } else if (typExpr->isId("ByRef")) { return transform(N(typ->getRecord() == nullptr)); - } else { - transformType(typExpr); - // Check super types (i.e., statically inherited) as well - for (auto &tx : getSuperTypes(typ->getClass())) { - if (tx->unify(typExpr->type.get(), nullptr) >= 0) - return transform(N(true)); + } else if (typExpr->type->is("pyobj") && !typExpr->isType()) { + if (typ->is("pyobj")) { + expr->staticValue.type = StaticValue::NOT_STATIC; + return transform(N(N("std.internal.python._isinstance:0"), + expr->args[0].value, expr->args[1].value)); + } else { + return transform(N(false)); } - return transform(N(false)); } + + transformType(typExpr); + + // Check super types (i.e., statically inherited) as well + for (auto &tx : getSuperTypes(typ->getClass())) { + if (tx->unify(typExpr->type.get(), nullptr) >= 0) + return transform(N(true)); + } + return transform(N(false)); } /// Transform staticlen method to a static integer expression. 
This method supports only @@ -801,6 +828,33 @@ ExprPtr TypecheckVisitor::transformTypeFn(CallExpr *expr) { return e; } +/// Transform __realized__ function to a fully realized type identifier. +ExprPtr TypecheckVisitor::transformRealizedFn(CallExpr *expr) { + auto call = + transform(N(expr->args[0].value, N(expr->args[1].value))); + if (!call->getCall()->expr->type->getFunc()) + error("the first argument must be a function"); + if (auto f = realize(call->getCall()->expr->type)) { + auto e = N(f->getFunc()->realizedName()); + e->setType(f); + e->setDone(); + return e; + } + return nullptr; +} + +/// Transform __static_print__ function to a fully realized type identifier. +ExprPtr TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) { + auto &args = expr->args[0].value->getCall()->args; + for (size_t i = 0; i < args.size(); i++) { + realize(args[i].value->type); + fmt::print(stderr, "[static_print] {}: {} := {}\n", getSrcInfo(), + FormatVisitor::apply(args[i].value), + args[i].value->type ? args[i].value->type->debugString(1) : "-"); + } + return nullptr; +} + /// Get the list that describes the inheritance hierarchy of a given type. /// The first type in the list is the most recently inherited type. 
std::vector TypecheckVisitor::getSuperTypes(const ClassTypePtr &cls) { diff --git a/codon/parser/visitors/typecheck/class.cpp b/codon/parser/visitors/typecheck/class.cpp index 0163495a..9879708d 100644 --- a/codon/parser/visitors/typecheck/class.cpp +++ b/codon/parser/visitors/typecheck/class.cpp @@ -55,6 +55,15 @@ void TypecheckVisitor::visit(ClassStmt *stmt) { unify(defType->type, generic); } } + if (auto ti = CAST(a.type, InstantiateExpr)) { + // Parse TraitVar + seqassert(ti->typeExpr->isId(TYPE_TYPEVAR), "not a TypeVar instantiation"); + auto l = transformType(ti->typeParams[0])->type; + if (l->getLink() && l->getLink()->trait) + generic->getLink()->trait = l->getLink()->trait; + else + generic->getLink()->trait = std::make_shared(l); + } ctx->add(TypecheckItem::Type, a.name, generic); ClassType::Generic g{a.name, ctx->cache->rev(a.name), generic->generalize(ctx->typecheckLevel), typId}; @@ -135,6 +144,18 @@ std::string TypecheckVisitor::generateTuple(size_t len, const std::string &name, args.emplace_back(Param(format("T{}", i + 1), N("type"), nullptr, true)); StmtPtr stmt = N(ctx->cache->generateSrcInfo(), typeName, args, nullptr, std::vector{N("tuple")}); + + // Add getItem for KwArgs: + // `def __getitem__(self, key: Static[str]): return getattr(self, key)` + auto getItem = N( + "__getitem__", nullptr, + std::vector{Param{"self"}, Param{"key", N(N("Static"), + N("str"))}}, + N(N( + N(N("getattr"), N("self"), N("key"))))); + if (startswith(typeName, TYPE_KWTUPLE)) + stmt->getClass()->suite = getItem; + // Simplify in the standard library context and type check stmt = SimplifyVisitor::apply(ctx->cache->imports[STDLIB_IMPORT].ctx, stmt, FILE_GENERATED, 0); diff --git a/codon/parser/visitors/typecheck/cond.cpp b/codon/parser/visitors/typecheck/cond.cpp index 581dae4a..bfd95344 100644 --- a/codon/parser/visitors/typecheck/cond.cpp +++ b/codon/parser/visitors/typecheck/cond.cpp @@ -36,7 +36,10 @@ void TypecheckVisitor::visit(IfExpr *expr) { if 
(expr->cond->isStatic()) { resultExpr = evaluateStaticCondition( expr->cond, - [&](bool isTrue) { return transform(isTrue ? expr->ifexpr : expr->elsexpr); }, + [&](bool isTrue) { + LOG_TYPECHECK("[static::cond] {}: {}", getSrcInfo(), isTrue); + return transform(isTrue ? expr->ifexpr : expr->elsexpr); + }, [&]() -> ExprPtr { // Check if both subexpressions are static; if so, this if expression is also // static and should be marked as such @@ -83,6 +86,7 @@ void TypecheckVisitor::visit(IfStmt *stmt) { resultStmt = evaluateStaticCondition( stmt->cond, [&](bool isTrue) { + LOG_TYPECHECK("[static::cond] {}: {}", getSrcInfo(), isTrue); auto t = transform(isTrue ? stmt->ifSuite : stmt->elseSuite); return t ? t : transform(N()); }, diff --git a/codon/parser/visitors/typecheck/ctx.cpp b/codon/parser/visitors/typecheck/ctx.cpp index 6bdb2ca0..677dec0a 100644 --- a/codon/parser/visitors/typecheck/ctx.cpp +++ b/codon/parser/visitors/typecheck/ctx.cpp @@ -83,12 +83,13 @@ types::TypePtr TypeContext::instantiate(const SrcInfo &srcInfo, const types::ClassTypePtr &generics) { seqassert(type, "type is null"); std::unordered_map genericCache; - if (generics) + if (generics) { for (auto &g : generics->generics) if (g.type && !(g.type->getLink() && g.type->getLink()->kind == types::LinkType::Generic)) { genericCache[g.id] = g.type; } + } auto t = type->instantiate(typecheckLevel, &(cache->unboundCount), &genericCache); for (auto &i : genericCache) { if (auto l = i.second->getLink()) { diff --git a/codon/parser/visitors/typecheck/ctx.h b/codon/parser/visitors/typecheck/ctx.h index 977073e1..06725ae3 100644 --- a/codon/parser/visitors/typecheck/ctx.h +++ b/codon/parser/visitors/typecheck/ctx.h @@ -66,6 +66,8 @@ struct TypeContext : public Context { int blockLevel; /// True if an early return is found (anything afterwards won't be typechecked) bool returnEarly; + /// Stack of static loop control variables (used to emulate goto statements). 
+ std::vector staticLoops; public: explicit TypeContext(Cache *cache); diff --git a/codon/parser/visitors/typecheck/error.cpp b/codon/parser/visitors/typecheck/error.cpp index 95a107e5..cb6a7897 100644 --- a/codon/parser/visitors/typecheck/error.cpp +++ b/codon/parser/visitors/typecheck/error.cpp @@ -9,38 +9,97 @@ namespace codon::ast { using namespace types; -/// Typecheck try-except statements. +/// Typecheck try-except statements. Handle Python exceptions separately. +/// @example +/// ```try: ... +/// except python.Error as e: ... +/// except PyExc as f: ... +/// except ValueError as g: ... +/// ``` -> ``` +/// try: ... +/// except ValueError as g: ... # ValueError +/// except PyExc as exc: +/// while True: +/// if isinstance(exc.pytype, python.Error): # python.Error +/// e = exc.pytype; ...; break +/// f = exc; ...; break # PyExc +/// raise``` void TypecheckVisitor::visit(TryStmt *stmt) { ctx->blockLevel++; transform(stmt->suite); ctx->blockLevel--; + std::vector catches; + auto pyVar = ctx->cache->getTemporaryVar("pyexc"); + auto pyCatchStmt = N(N(true), N()); + auto done = stmt->suite->isDone(); for (auto &c : stmt->catches) { - transformType(c.exc); - if (!c.var.empty()) { - // Handle dominated except bindings - auto changed = in(ctx->cache->replacements, c.var); - while (auto s = in(ctx->cache->replacements, c.var)) - c.var = s->first, changed = s; - if (changed && changed->second) { - auto update = - N(N(format("{}.__used__", c.var)), N(true)); - update->setUpdate(); - c.suite = N(update, c.suite); + transform(c.exc); + if (c.exc && c.exc->type->is("pyobj")) { + // Transform python.Error exceptions + if (!c.var.empty()) { + c.suite = N( + N(N(c.var), N(N(pyVar), "pytype")), + c.suite); } - if (changed) - c.exc->setAttr(ExprAttr::Dominated); - auto val = ctx->find(c.var); - if (!changed) - val = ctx->add(TypecheckItem::Var, c.var, c.exc->getType()); - unify(val->type, c.exc->getType()); + c.suite = + N(N(N("isinstance"), + N(N(pyVar), "pytype"), 
clone(c.exc)), + N(c.suite, N()), nullptr); + pyCatchStmt->suite->getSuite()->stmts.push_back(c.suite); + } else if (c.exc && c.exc->type->is("std.internal.types.error.PyError")) { + // Transform PyExc exceptions + if (!c.var.empty()) { + c.suite = + N(N(N(c.var), N(pyVar)), c.suite); + } + c.suite = N(c.suite, N()); + pyCatchStmt->suite->getSuite()->stmts.push_back(c.suite); + } else { + // Handle all other exceptions + transformType(c.exc); + if (!c.var.empty()) { + // Handle dominated except bindings + auto changed = in(ctx->cache->replacements, c.var); + while (auto s = in(ctx->cache->replacements, c.var)) + c.var = s->first, changed = s; + if (changed && changed->second) { + auto update = + N(N(format("{}.__used__", c.var)), N(true)); + update->setUpdate(); + c.suite = N(update, c.suite); + } + if (changed) + c.exc->setAttr(ExprAttr::Dominated); + auto val = ctx->find(c.var); + if (!changed) + val = ctx->add(TypecheckItem::Var, c.var, c.exc->getType()); + unify(val->type, c.exc->getType()); + } + ctx->blockLevel++; + transform(c.suite); + ctx->blockLevel--; + done &= (!c.exc || c.exc->isDone()) && c.suite->isDone(); + catches.push_back(c); } + } + if (!pyCatchStmt->suite->getSuite()->stmts.empty()) { + // Process PyError catches + auto exc = N("std.internal.types.error.PyError"); + exc->markType(); + pyCatchStmt->suite->getSuite()->stmts.push_back(N(nullptr)); + TryStmt::Catch c{pyVar, transformType(exc), pyCatchStmt}; + + auto val = ctx->add(TypecheckItem::Var, pyVar, c.exc->getType()); + unify(val->type, c.exc->getType()); ctx->blockLevel++; transform(c.suite); ctx->blockLevel--; done &= (!c.exc || c.exc->isDone()) && c.suite->isDone(); + catches.push_back(c); } + stmt->catches = catches; if (stmt->finally) { ctx->blockLevel++; transform(stmt->finally); @@ -56,6 +115,11 @@ void TypecheckVisitor::visit(TryStmt *stmt) { /// @example /// `raise exc` -> ```raise __internal__.set_header(exc, "fn", "file", line, col)``` void TypecheckVisitor::visit(ThrowStmt 
*stmt) { + if (!stmt->expr) { + stmt->setDone(); + return; + } + transform(stmt->expr); if (!(stmt->expr->getCall() && diff --git a/codon/parser/visitors/typecheck/function.cpp b/codon/parser/visitors/typecheck/function.cpp index f5561ab4..8503e617 100644 --- a/codon/parser/visitors/typecheck/function.cpp +++ b/codon/parser/visitors/typecheck/function.cpp @@ -40,16 +40,16 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) { unify(ctx->getRealizationBase()->returnType, stmt->expr->type); } else { - // If we are not within conditional block, ignore later statements in this function. - // Useful with static if statements. - if (!ctx->blockLevel) - ctx->returnEarly = true; - // Just set the expr for the translation stage. However, do not unify the return // type! This might be a `return` in a generator. stmt->expr = transform(N(N("NoneType"))); } + // If we are not within conditional block, ignore later statements in this function. + // Useful with static if statements. + if (!ctx->blockLevel) + ctx->returnEarly = true; + if (stmt->expr->isDone()) stmt->setDone(); } diff --git a/codon/parser/visitors/typecheck/infer.cpp b/codon/parser/visitors/typecheck/infer.cpp index 5291c252..37e67533 100644 --- a/codon/parser/visitors/typecheck/infer.cpp +++ b/codon/parser/visitors/typecheck/infer.cpp @@ -52,8 +52,11 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) { ctx->typecheckLevel++; auto changedNodes = ctx->changedNodes; ctx->changedNodes = 0; + auto returnEarly = ctx->returnEarly; + ctx->returnEarly = false; TypecheckVisitor(ctx).transform(result); std::swap(ctx->changedNodes, changedNodes); + std::swap(ctx->returnEarly, returnEarly); ctx->typecheckLevel--; if (iteration == 1 && isToplevel) { @@ -282,10 +285,7 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) { if (hasAst) { auto oldBlockLevel = ctx->blockLevel; - auto oldReturnEarly = ctx->returnEarly; ctx->blockLevel = 0; - ctx->returnEarly = false; - if 
(startswith(type->ast->name, "Function.__call__")) { // Special case: Function.__call__ /// TODO: move to IR @@ -313,7 +313,6 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) { } inferTypes(ast->suite); ctx->blockLevel = oldBlockLevel; - ctx->returnEarly = oldReturnEarly; // Use NoneType as the return type when the return type is not specified and // function has no return statement @@ -382,6 +381,7 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) { // Get the IR type auto *module = ctx->cache->module; ir::types::Type *handle = nullptr; + if (t->name == "bool") { handle = module->getBoolType(); } else if (t->name == "byte") { @@ -390,6 +390,8 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) { handle = module->getIntType(); } else if (t->name == "float") { handle = module->getFloatType(); + } else if (t->name == "float32") { + handle = module->getFloat32Type(); } else if (t->name == "str") { handle = module->getStringType(); } else if (t->name == "Int" || t->name == "UInt") { @@ -415,6 +417,9 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) { types.push_back(forceFindIRType(m)); auto ret = forceFindIRType(t->generics[1].type); handle = module->unsafeGetFuncType(realizedName, ret, types); + } else if (t->name == "std.simd.Vec") { + seqassert(types.size() == 1 && statics.size() == 1, "bad generics/statics"); + handle = module->unsafeGetVectorType(statics[0]->getInt(), types[0]); } else if (auto tr = t->getRecord()) { std::vector typeArgs; std::vector names; diff --git a/codon/parser/visitors/typecheck/loops.cpp b/codon/parser/visitors/typecheck/loops.cpp index 62bfbb83..fd056504 100644 --- a/codon/parser/visitors/typecheck/loops.cpp +++ b/codon/parser/visitors/typecheck/loops.cpp @@ -13,17 +13,32 @@ namespace codon::ast { using namespace types; /// Nothing to typecheck; just call setDone -void TypecheckVisitor::visit(BreakStmt *stmt) { stmt->setDone(); } +void 
TypecheckVisitor::visit(BreakStmt *stmt) { + stmt->setDone(); + if (!ctx->staticLoops.back().empty()) { + auto a = N(N(ctx->staticLoops.back()), N(false)); + a->setUpdate(); + resultStmt = transform(N(a, stmt->clone())); + } +} /// Nothing to typecheck; just call setDone -void TypecheckVisitor::visit(ContinueStmt *stmt) { stmt->setDone(); } +void TypecheckVisitor::visit(ContinueStmt *stmt) { + stmt->setDone(); + if (!ctx->staticLoops.back().empty()) { + resultStmt = N(); + resultStmt->setDone(); + } +} /// Typecheck while statements. void TypecheckVisitor::visit(WhileStmt *stmt) { + ctx->staticLoops.push_back(stmt->gotoVar.empty() ? "" : stmt->gotoVar); transform(stmt->cond); ctx->blockLevel++; transform(stmt->suite); ctx->blockLevel--; + ctx->staticLoops.pop_back(); if (stmt->cond->isDone() && stmt->suite->isDone()) stmt->setDone(); @@ -40,6 +55,9 @@ void TypecheckVisitor::visit(ForStmt *stmt) { if (!iterType) return; // wait until the iterator is known + if ((resultStmt = transformStaticForLoop(stmt))) + return; + bool maybeHeterogenous = startswith(iterType->name, TYPE_TUPLE) || startswith(iterType->name, TYPE_KWTUPLE); if (maybeHeterogenous && !iterType->canRealize()) { @@ -83,9 +101,11 @@ void TypecheckVisitor::visit(ForStmt *stmt) { unify(stmt->var->type, iterType ? unify(val->type, iterType->generics[0].type) : val->type); + ctx->staticLoops.push_back(""); ctx->blockLevel++; transform(stmt->suite); ctx->blockLevel--; + ctx->staticLoops.pop_back(); if (stmt->iter->isDone() && stmt->suite->isDone()) stmt->setDone(); @@ -132,4 +152,81 @@ StmtPtr TypecheckVisitor::transformHeterogenousTupleFor(ForStmt *stmt) { return block; } +/// Handle static for constructs. +/// @example +/// `for i in statictuple(1, x): ` -> +/// ```loop = True +/// while loop: +/// while loop: +/// i: Static[int] = 1; ; break +/// while loop: +/// i = x; ; break +/// loop = False # also set to False on break +/// A separate suite is generated for each static iteration. 
+StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) { + auto loopVar = ctx->cache->getTemporaryVar("loop"); + auto fn = [&](const std::string &var, const ExprPtr &expr) { + bool staticInt = expr->isStatic(); + auto t = N( + N("Static"), + N(expr->staticValue.type == StaticValue::INT ? "int" : "str")); + t->markType(); + auto brk = N(); + brk->setDone(); // Avoid transforming this one to continue + // var [: Static] := expr; suite... + auto loop = N(N(loopVar), + N(N(N(var), expr->clone(), + staticInt ? t : nullptr), + clone(stmt->suite), brk)); + loop->gotoVar = loopVar; + return loop; + }; + + auto var = stmt->var->getId()->value; + if (!stmt->iter->getCall() || !stmt->iter->getCall()->expr->getId()) + return nullptr; + auto iter = stmt->iter->getCall()->expr->getId(); + auto block = N(); + if (iter && startswith(iter->value, "statictuple:0")) { + auto &args = stmt->iter->getCall()->args[0].value->getCall()->args; + for (size_t i = 0; i < args.size(); i++) + block->stmts.push_back(fn(var, args[i].value)); + } else if (iter && + startswith(iter->value, "std.internal.types.range.staticrange:0")) { + int st = + iter->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt(); + int ed = + iter->type->getFunc()->funcGenerics[1].type->getStatic()->evaluate().getInt(); + int step = + iter->type->getFunc()->funcGenerics[2].type->getStatic()->evaluate().getInt(); + if (abs(st - ed) / abs(step) > MAX_STATIC_ITER) + error("staticrange out of bounds ({} > {})", abs(st - ed) / abs(step), + MAX_STATIC_ITER); + for (int i = st; step > 0 ? 
i < ed : i > ed; i += step) + block->stmts.push_back(fn(var, N(i))); + } else if (iter && + startswith(iter->value, "std.internal.types.range.staticrange:1")) { + int ed = + iter->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt(); + if (ed > MAX_STATIC_ITER) + error("staticrange out of bounds ({} > {})", ed, MAX_STATIC_ITER); + for (int i = 0; i < ed; i++) + block->stmts.push_back(fn(var, N(i))); + } else { + return nullptr; + } + ctx->blockLevel++; + + // Close the loop + auto a = N(N(loopVar), N(false)); + a->setUpdate(); + block->stmts.push_back(a); + + auto loop = + transform(N(N(N(loopVar), N(true)), + N(N(loopVar), block))); + ctx->blockLevel--; + return loop; +} + } // namespace codon::ast diff --git a/codon/parser/visitors/typecheck/op.cpp b/codon/parser/visitors/typecheck/op.cpp index ca558a60..d91f8cd4 100644 --- a/codon/parser/visitors/typecheck/op.cpp +++ b/codon/parser/visitors/typecheck/op.cpp @@ -262,15 +262,20 @@ void TypecheckVisitor::visit(IndexExpr *expr) { /// @example /// Instantiate(foo, [bar]) -> Id("foo[bar]") void TypecheckVisitor::visit(InstantiateExpr *expr) { - // Infer the expression type + transformType(expr->typeExpr); + TypePtr typ = + ctx->instantiate(expr->typeExpr->getSrcInfo(), expr->typeExpr->getType()); + seqassert(typ->getClass(), "unknown type: {}", expr->typeExpr->toString()); + + auto &generics = typ->getClass()->generics; + if (expr->typeParams.size() != generics.size()) + error("expected {} generics and/or statics", generics.size()); + if (expr->typeExpr->isId(TYPE_CALLABLE)) { - // Case: Callable[...] instantiation + // Case: Callable[...] trait instantiation std::vector types; // Callable error checking. - /// TODO: move to Codon? 
- if (expr->typeParams.size() != 2) - error("invalid Callable type declaration"); for (auto &typeParam : expr->typeParams) { transformType(typeParam); if (typeParam->type->isStaticType()) @@ -281,16 +286,13 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) { // Set up the Callable trait typ->getLink()->trait = std::make_shared(types); unify(expr->type, typ); + } else if (expr->typeExpr->isId(TYPE_TYPEVAR)) { + // Case: TypeVar[...] trait instantiation + transformType(expr->typeParams[0]); + auto typ = ctx->getUnbound(); + typ->getLink()->trait = std::make_shared(expr->typeParams[0]->type); + unify(expr->type, typ); } else { - transformType(expr->typeExpr); - TypePtr typ = - ctx->instantiate(expr->typeExpr->getSrcInfo(), expr->typeExpr->getType()); - seqassert(typ->getClass(), "unknown type"); - - auto &generics = typ->getClass()->generics; - if (expr->typeParams.size() != generics.size()) - error("expected {} generics and/or statics", generics.size()); - for (size_t i = 0; i < expr->typeParams.size(); i++) { transform(expr->typeParams[i]); TypePtr t = nullptr; @@ -339,7 +341,9 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) { if (expr->expr->staticValue.type == StaticValue::STRING) { if (expr->op == "!") { if (expr->expr->staticValue.evaluated) { - return transform(N(expr->expr->staticValue.getString().empty())); + bool value = expr->expr->staticValue.getString().empty(); + LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value); + return transform(N(value)); } else { // Cannot be evaluated yet: just set the type unify(expr->type, ctx->getType("bool")); @@ -360,6 +364,7 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) { value = -value; else value = !bool(value); + LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value); if (expr->op == "!") return transform(N(bool(value))); else @@ -375,6 +380,25 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) { return nullptr; } +/// Division and modulus implementations. 
+std::pair divMod(const std::shared_ptr &ctx, int a, int b) { + if (!b) + error(ctx->getSrcInfo(), "static division by zero"); + if (ctx->cache->pythonCompat) { + // Use Python implementation. + int d = a / b; + int m = a - d * b; + if (m && ((b ^ m) < 0)) { + m += b; + d -= 1; + } + return {d, m}; + } else { + // Use C implementation. + return {a / b, a % b}; + } +} + /// Evaluate a static binary expression and return the resulting static expression. /// If the expression cannot be evaluated yet, return nullptr. /// Supported operators: (strings) +, ==, != @@ -385,8 +409,10 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) { if (expr->op == "+") { // `"a" + "b"` -> `"ab"` if (expr->lexpr->staticValue.evaluated && expr->rexpr->staticValue.evaluated) { - return transform(N(expr->lexpr->staticValue.getString() + - expr->rexpr->staticValue.getString())); + auto value = + expr->lexpr->staticValue.getString() + expr->rexpr->staticValue.getString(); + LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value); + return transform(N(value)); } else { // Cannot be evaluated yet: just set the type if (!expr->isStatic()) @@ -398,7 +424,9 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) { if (expr->lexpr->staticValue.evaluated && expr->rexpr->staticValue.evaluated) { bool eq = expr->lexpr->staticValue.getString() == expr->rexpr->staticValue.getString(); - return transform(N(expr->op == "==" ? eq : !eq)); + bool value = expr->op == "==" ? 
eq : !eq; + LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value); + return transform(N(value)); } else { // Cannot be evaluated yet: just set the type if (!expr->isStatic()) @@ -441,18 +469,13 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) { lvalue = lvalue & rvalue; else if (expr->op == "|") lvalue = lvalue | rvalue; - else if (expr->op == "//") { - if (!rvalue) - error("static division by zero"); - lvalue = lvalue / rvalue; - } else if (expr->op == "%") { - if (!rvalue) - error("static division by zero"); - lvalue = lvalue % rvalue; - } else { + else if (expr->op == "//") + lvalue = divMod(ctx, lvalue, rvalue).first; + else if (expr->op == "%") + lvalue = divMod(ctx, lvalue, rvalue).second; + else seqassert(false, "unknown static operator {}", expr->op); - } - + LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), lvalue); if (in(std::set{"==", "!=", "<", "<=", ">", ">=", "&&", "||"}, expr->op)) return transform(N(bool(lvalue))); diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index eb0b68bf..786c3154 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -47,6 +47,7 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) { ctx->popSrcInfo(); if (v.resultExpr) { v.resultExpr->attributes |= expr->attributes; + v.resultExpr->origExpr = expr; expr = v.resultExpr; } seqassert(expr->type, "type not set for {}", expr->toString()); diff --git a/codon/parser/visitors/typecheck/typecheck.h b/codon/parser/visitors/typecheck/typecheck.h index f78293d9..696d9902 100644 --- a/codon/parser/visitors/typecheck/typecheck.h +++ b/codon/parser/visitors/typecheck/typecheck.h @@ -135,6 +135,8 @@ private: // Node typechecking rules ExprPtr transformCompileError(CallExpr *expr); ExprPtr transformTupleFn(CallExpr *expr); ExprPtr transformTypeFn(CallExpr *expr); + ExprPtr transformRealizedFn(CallExpr *expr); + ExprPtr transformStaticPrintFn(CallExpr 
*expr); std::vector getSuperTypes(const types::ClassTypePtr &cls); void addFunctionGenerics(const types::FuncType *t); std::string generatePartialStub(const std::vector &mask, types::FuncType *fn); @@ -151,6 +153,7 @@ private: // Node typechecking rules void visit(WhileStmt *) override; void visit(ForStmt *) override; StmtPtr transformHeterogenousTupleFor(ForStmt *); + StmtPtr transformStaticForLoop(ForStmt *); /* Errors and exceptions (error.cpp) */ void visit(TryStmt *) override; diff --git a/codon/runtime/gpu.cpp b/codon/runtime/gpu.cpp new file mode 100644 index 00000000..f757d2c8 --- /dev/null +++ b/codon/runtime/gpu.cpp @@ -0,0 +1,137 @@ +#include +#include +#include +#include + +#include "lib.h" + +#ifdef CODON_GPU + +#include "cuda.h" + +#define fail(err) \ + do { \ + const char *msg; \ + cuGetErrorString((err), &msg); \ + fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, msg); \ + abort(); \ + } while (0) + +#define check(call) \ + do { \ + auto err = (call); \ + if (err != CUDA_SUCCESS) { \ + fail(err); \ + } \ + } while (0) + +static std::vector modules; +static CUcontext context; + +void seq_nvptx_init() { + CUdevice device; + check(cuInit(0)); + check(cuDeviceGet(&device, 0)); + check(cuCtxCreate(&context, 0, device)); +} + +SEQ_FUNC void seq_nvptx_load_module(const char *filename) { + CUmodule module; + check(cuModuleLoad(&module, filename)); + modules.push_back(module); +} + +SEQ_FUNC seq_int_t seq_nvptx_device_count() { + int devCount; + check(cuDeviceGetCount(&devCount)); + return devCount; +} + +SEQ_FUNC seq_str_t seq_nvptx_device_name(CUdevice device) { + char name[128]; + check(cuDeviceGetName(name, sizeof(name) - 1, device)); + auto sz = static_cast(strlen(name)); + auto *p = (char *)seq_alloc_atomic(sz); + memcpy(p, name, sz); + return {sz, p}; +} + +SEQ_FUNC seq_int_t seq_nvptx_device_capability(CUdevice device) { + int devMajor, devMinor; + check(cuDeviceComputeCapability(&devMajor, &devMinor, device)); + return 
((seq_int_t)devMajor << 32) | (seq_int_t)devMinor; +} + +SEQ_FUNC CUdevice seq_nvptx_device(seq_int_t idx) { + CUdevice device; + check(cuDeviceGet(&device, idx)); + return device; +} + +static bool name_char_valid(char c, bool first) { + bool ok = ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || (c == '_'); + if (!first) + ok = ok || ('0' <= c && c <= '9'); + return ok; +} + +SEQ_FUNC CUfunction seq_nvptx_function(seq_str_t name) { + CUfunction function; + CUresult result; + + std::vector clean(name.len + 1); + for (unsigned i = 0; i < name.len; i++) { + char c = name.str[i]; + clean[i] = (name_char_valid(c, i == 0) ? c : '_'); + } + clean[name.len] = '\0'; + + for (auto it = modules.rbegin(); it != modules.rend(); ++it) { + result = cuModuleGetFunction(&function, *it, clean.data()); + if (result == CUDA_SUCCESS) { + return function; + } else if (result == CUDA_ERROR_NOT_FOUND) { + continue; + } else { + break; + } + } + + fail(result); + return {}; +} + +SEQ_FUNC void seq_nvptx_invoke(CUfunction f, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, + void **kernelParams) { + check(cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, + sharedMemBytes, nullptr, kernelParams, nullptr)); +} + +SEQ_FUNC CUdeviceptr seq_nvptx_device_alloc(seq_int_t size) { + if (size == 0) + return {}; + + CUdeviceptr devp; + check(cuMemAlloc(&devp, size)); + return devp; +} + +SEQ_FUNC void seq_nvptx_memcpy_h2d(CUdeviceptr devp, char *hostp, seq_int_t size) { + if (size) + check(cuMemcpyHtoD(devp, hostp, size)); +} + +SEQ_FUNC void seq_nvptx_memcpy_d2h(char *hostp, CUdeviceptr devp, seq_int_t size) { + if (size) + check(cuMemcpyDtoH(hostp, devp, size)); +} + +SEQ_FUNC void seq_nvptx_device_free(CUdeviceptr devp) { + if (devp) + check(cuMemFree(devp)); +} + +#endif /* CODON_GPU */ diff --git a/codon/runtime/lib.cpp 
b/codon/runtime/lib.cpp index 62be4721..e0bc1ae8 100644 --- a/codon/runtime/lib.cpp +++ b/codon/runtime/lib.cpp @@ -38,6 +38,10 @@ extern "C" void __kmpc_set_gc_callbacks(gc_setup_callback get_stack_base, void seq_exc_init(); +#ifdef CODON_GPU +void seq_nvptx_init(); +#endif + int seq_flags; SEQ_FUNC void seq_init(int flags) { @@ -47,6 +51,9 @@ SEQ_FUNC void seq_init(int flags) { __kmpc_set_gc_callbacks(GC_get_stack_base, (gc_setup_callback)GC_register_my_thread, GC_add_roots, GC_remove_roots); seq_exc_init(); +#ifdef CODON_GPU + seq_nvptx_init(); +#endif seq_flags = flags; } @@ -176,11 +183,11 @@ SEQ_FUNC void *seq_calloc_atomic(size_t m, size_t n) { #endif } -SEQ_FUNC void *seq_realloc(void *p, size_t n) { +SEQ_FUNC void *seq_realloc(void *p, size_t newsize, size_t oldsize) { #if USE_STANDARD_MALLOC - return realloc(p, n); + return realloc(p, newsize); #else - return GC_REALLOC(p, n); + return GC_REALLOC(p, newsize); #endif } @@ -231,7 +238,7 @@ static seq_str_t string_conv(const char *fmt, const size_t size, T t) { int n = snprintf(p, size, fmt, t); if (n >= size) { auto n2 = (size_t)n + 1; - p = (char *)seq_realloc((void *)p, n2); + p = (char *)seq_realloc((void *)p, n2, size); n = snprintf(p, n2, fmt, t); } return {(seq_int_t)n, p}; diff --git a/codon/runtime/lib.h b/codon/runtime/lib.h index 1d8518f3..32b35838 100644 --- a/codon/runtime/lib.h +++ b/codon/runtime/lib.h @@ -54,7 +54,7 @@ SEQ_FUNC void *seq_alloc(size_t n); SEQ_FUNC void *seq_alloc_atomic(size_t n); SEQ_FUNC void *seq_calloc(size_t m, size_t n); SEQ_FUNC void *seq_calloc_atomic(size_t m, size_t n); -SEQ_FUNC void *seq_realloc(void *p, size_t n); +SEQ_FUNC void *seq_realloc(void *p, size_t newsize, size_t oldsize); SEQ_FUNC void seq_free(void *p); SEQ_FUNC void seq_register_finalizer(void *p, void (*f)(void *obj, void *data)); diff --git a/codon/sir/llvm/gpu.cpp b/codon/sir/llvm/gpu.cpp new file mode 100644 index 00000000..76727d01 --- /dev/null +++ b/codon/sir/llvm/gpu.cpp @@ -0,0 +1,550 @@ 
+#include "gpu.h" + +#include +#include +#include + +#include "codon/util/common.h" +#include "llvm/CodeGen/CommandFlags.h" + +namespace codon { +namespace ir { +namespace { +const std::string GPU_TRIPLE = "nvptx64-nvidia-cuda"; +const std::string GPU_DL = + "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-" + "f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"; +llvm::cl::opt + libdevice("libdevice", llvm::cl::desc("libdevice path for GPU kernels"), + llvm::cl::init("/usr/local/cuda/nvvm/libdevice/libdevice.10.bc")); + +std::string cleanUpName(llvm::StringRef name) { + std::string validName; + llvm::raw_string_ostream validNameStream(validName); + + auto valid = [](char c, bool first) { + bool ok = ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || (c == '_'); + if (!first) + ok = ok || ('0' <= c && c <= '9'); + return ok; + }; + + bool first = true; + for (char c : name) { + validNameStream << (valid(c, first) ? c : '_'); + first = false; + } + + return validNameStream.str(); +} + +void linkLibdevice(llvm::Module *M, const std::string &path) { + llvm::SMDiagnostic err; + auto libdevice = llvm::parseIRFile(path, err, M->getContext()); + if (!libdevice) + compilationError(err.getMessage().str(), err.getFilename().str(), err.getLineNo(), + err.getColumnNo()); + libdevice->setDataLayout(M->getDataLayout()); + libdevice->setTargetTriple(M->getTargetTriple()); + + llvm::Linker L(*M); + const bool fail = L.linkInModule(std::move(libdevice)); + seqassertn(!fail, "linking libdevice failed"); +} + +llvm::Function *copyPrototype(llvm::Function *F, const std::string &name) { + auto *M = F->getParent(); + return llvm::Function::Create(F->getFunctionType(), llvm::GlobalValue::PrivateLinkage, + name.empty() ? F->getName() : name, *M); +} + +llvm::Function *makeNoOp(llvm::Function *F) { + auto *M = F->getParent(); + auto &context = M->getContext(); + auto dummyName = (".codon.gpu.dummy." 
+ F->getName()).str(); + auto *dummy = M->getFunction(dummyName); + if (!dummy) { + dummy = copyPrototype(F, dummyName); + auto *entry = llvm::BasicBlock::Create(context, "entry", dummy); + llvm::IRBuilder<> B(entry); + + auto *retType = F->getReturnType(); + if (retType->isVoidTy()) { + B.CreateRetVoid(); + } else { + B.CreateRet(llvm::UndefValue::get(retType)); + } + } + return dummy; +} + +using Codegen = + std::function &, const std::vector &)>; + +llvm::Function *makeFillIn(llvm::Function *F, Codegen codegen) { + auto *M = F->getParent(); + auto &context = M->getContext(); + auto fillInName = (".codon.gpu.fillin." + F->getName()).str(); + auto *fillIn = M->getFunction(fillInName); + if (!fillIn) { + fillIn = copyPrototype(F, fillInName); + std::vector args; + for (auto it = fillIn->arg_begin(); it != fillIn->arg_end(); ++it) { + args.push_back(it); + } + auto *entry = llvm::BasicBlock::Create(context, "entry", fillIn); + llvm::IRBuilder<> B(entry); + codegen(B, args); + } + return fillIn; +} + +llvm::Function *makeMalloc(llvm::Module *M) { + auto &context = M->getContext(); + auto F = M->getOrInsertFunction("malloc", llvm::Type::getInt8PtrTy(context), + llvm::Type::getInt64Ty(context)); + auto *G = llvm::cast(F.getCallee()); + G->setLinkage(llvm::GlobalValue::ExternalLinkage); + G->setDoesNotThrow(); + G->setReturnDoesNotAlias(); + G->setOnlyAccessesInaccessibleMemory(); + G->setWillReturn(); + return G; +} + +void remapFunctions(llvm::Module *M) { + // simple name-to-name remappings + static const std::vector> remapping = { + // 64-bit float intrinsics + {"llvm.ceil.f64", "__nv_ceil"}, + {"llvm.floor.f64", "__nv_floor"}, + {"llvm.fabs.f64", "__nv_fabs"}, + {"llvm.exp.f64", "__nv_exp"}, + {"llvm.log.f64", "__nv_log"}, + {"llvm.log2.f64", "__nv_log2"}, + {"llvm.log10.f64", "__nv_log10"}, + {"llvm.sqrt.f64", "__nv_sqrt"}, + {"llvm.pow.f64", "__nv_pow"}, + {"llvm.sin.f64", "__nv_sin"}, + {"llvm.cos.f64", "__nv_cos"}, + {"llvm.copysign.f64", "__nv_copysign"}, + 
{"llvm.trunc.f64", "__nv_trunc"}, + {"llvm.rint.f64", "__nv_rint"}, + {"llvm.nearbyint.f64", "__nv_nearbyint"}, + {"llvm.round.f64", "__nv_round"}, + {"llvm.minnum.f64", "__nv_fmin"}, + {"llvm.maxnum.f64", "__nv_fmax"}, + {"llvm.copysign.f64", "__nv_copysign"}, + {"llvm.fma.f64", "__nv_fma"}, + + // 64-bit float math functions + {"expm1", "__nv_expm1"}, + {"ldexp", "__nv_ldexp"}, + {"acos", "__nv_acos"}, + {"asin", "__nv_asin"}, + {"atan", "__nv_atan"}, + {"atan2", "__nv_atan2"}, + {"hypot", "__nv_hypot"}, + {"tan", "__nv_tan"}, + {"cosh", "__nv_cosh"}, + {"sinh", "__nv_sinh"}, + {"tanh", "__nv_tanh"}, + {"acosh", "__nv_acosh"}, + {"asinh", "__nv_asinh"}, + {"atanh", "__nv_atanh"}, + {"erf", "__nv_erf"}, + {"erfc", "__nv_erfc"}, + {"tgamma", "__nv_tgamma"}, + {"lgamma", "__nv_lgamma"}, + {"remainder", "__nv_remainder"}, + {"frexp", "__nv_frexp"}, + {"modf", "__nv_modf"}, + + // 32-bit float intrinsics + {"llvm.ceil.f32", "__nv_ceilf"}, + {"llvm.floor.f32", "__nv_floorf"}, + {"llvm.fabs.f32", "__nv_fabsf"}, + {"llvm.exp.f32", "__nv_expf"}, + {"llvm.log.f32", "__nv_logf"}, + {"llvm.log2.f32", "__nv_log2f"}, + {"llvm.log10.f32", "__nv_log10f"}, + {"llvm.sqrt.f32", "__nv_sqrtf"}, + {"llvm.pow.f32", "__nv_powf"}, + {"llvm.sin.f32", "__nv_sinf"}, + {"llvm.cos.f32", "__nv_cosf"}, + {"llvm.copysign.f32", "__nv_copysignf"}, + {"llvm.trunc.f32", "__nv_truncf"}, + {"llvm.rint.f32", "__nv_rintf"}, + {"llvm.nearbyint.f32", "__nv_nearbyintf"}, + {"llvm.round.f32", "__nv_roundf"}, + {"llvm.minnum.f32", "__nv_fminf"}, + {"llvm.maxnum.f32", "__nv_fmaxf"}, + {"llvm.copysign.f32", "__nv_copysignf"}, + {"llvm.fma.f32", "__nv_fmaf"}, + + // 32-bit float math functions + {"expm1f", "__nv_expm1f"}, + {"ldexpf", "__nv_ldexpf"}, + {"acosf", "__nv_acosf"}, + {"asinf", "__nv_asinf"}, + {"atanf", "__nv_atanf"}, + {"atan2f", "__nv_atan2f"}, + {"hypotf", "__nv_hypotf"}, + {"tanf", "__nv_tanf"}, + {"coshf", "__nv_coshf"}, + {"sinhf", "__nv_sinhf"}, + {"tanhf", "__nv_tanhf"}, + {"acoshf", 
"__nv_acoshf"}, + {"asinhf", "__nv_asinhf"}, + {"atanhf", "__nv_atanhf"}, + {"erff", "__nv_erff"}, + {"erfcf", "__nv_erfcf"}, + {"tgammaf", "__nv_tgammaf"}, + {"lgammaf", "__nv_lgammaf"}, + {"remainderf", "__nv_remainderf"}, + {"frexpf", "__nv_frexpf"}, + {"modff", "__nv_modff"}, + + // runtime library functions + {"seq_free", "free"}, + {"seq_register_finalizer", ""}, + {"seq_gc_add_roots", ""}, + {"seq_gc_remove_roots", ""}, + {"seq_gc_clear_roots", ""}, + {"seq_gc_exclude_static_roots", ""}, + }; + + // functions that need to be generated as they're not available on GPU + static const std::vector> fillins = { + {"seq_alloc", + [](llvm::IRBuilder<> &B, const std::vector &args) { + auto *M = B.GetInsertBlock()->getModule(); + llvm::Value *mem = B.CreateCall(makeMalloc(M), args[0]); + B.CreateRet(mem); + }}, + + {"seq_alloc_atomic", + [](llvm::IRBuilder<> &B, const std::vector &args) { + auto *M = B.GetInsertBlock()->getModule(); + llvm::Value *mem = B.CreateCall(makeMalloc(M), args[0]); + B.CreateRet(mem); + }}, + + {"seq_realloc", + [](llvm::IRBuilder<> &B, const std::vector &args) { + auto *M = B.GetInsertBlock()->getModule(); + llvm::Value *mem = B.CreateCall(makeMalloc(M), args[1]); + auto F = llvm::Intrinsic::getDeclaration( + M, llvm::Intrinsic::memcpy, + {B.getInt8PtrTy(), B.getInt8PtrTy(), B.getInt64Ty()}); + B.CreateCall(F, {mem, args[0], args[2], B.getFalse()}); + B.CreateRet(mem); + }}, + + {"seq_calloc", + [](llvm::IRBuilder<> &B, const std::vector &args) { + auto *M = B.GetInsertBlock()->getModule(); + llvm::Value *size = B.CreateMul(args[0], args[1]); + llvm::Value *mem = B.CreateCall(makeMalloc(M), size); + auto F = llvm::Intrinsic::getDeclaration(M, llvm::Intrinsic::memset, + {B.getInt8PtrTy(), B.getInt64Ty()}); + B.CreateCall(F, {mem, B.getInt8(0), size, B.getFalse()}); + B.CreateRet(mem); + }}, + + {"seq_calloc_atomic", + [](llvm::IRBuilder<> &B, const std::vector &args) { + auto *M = B.GetInsertBlock()->getModule(); + llvm::Value *size = 
B.CreateMul(args[0], args[1]); + llvm::Value *mem = B.CreateCall(makeMalloc(M), size); + auto F = llvm::Intrinsic::getDeclaration(M, llvm::Intrinsic::memset, + {B.getInt8PtrTy(), B.getInt64Ty()}); + B.CreateCall(F, {mem, B.getInt8(0), size, B.getFalse()}); + B.CreateRet(mem); + }}, + + {"seq_alloc_exc", + [](llvm::IRBuilder<> &B, const std::vector &args) { + // TODO: print error message and abort if in debug mode + B.CreateUnreachable(); + }}, + + {"seq_throw", + [](llvm::IRBuilder<> &B, + const std::vector &args) { B.CreateUnreachable(); }}, + }; + + for (auto &pair : remapping) { + if (auto *F = M->getFunction(pair.first)) { + llvm::Function *G = nullptr; + if (pair.second.empty()) { + G = makeNoOp(F); + } else { + G = M->getFunction(pair.second); + if (!G) + G = copyPrototype(F, pair.second); + } + + G->setWillReturn(); + F->replaceAllUsesWith(G); + F->dropAllReferences(); + F->eraseFromParent(); + } + } + + for (auto &pair : fillins) { + if (auto *F = M->getFunction(pair.first)) { + llvm::Function *G = makeFillIn(F, pair.second); + F->replaceAllUsesWith(G); + F->dropAllReferences(); + F->eraseFromParent(); + } + } +} + +void exploreGV(llvm::GlobalValue *G, llvm::SmallPtrSetImpl &keep) { + if (keep.contains(G)) + return; + + keep.insert(G); + if (auto *F = llvm::dyn_cast(G)) { + for (auto I = llvm::inst_begin(F), E = inst_end(F); I != E; ++I) { + for (auto &U : I->operands()) { + if (auto *G2 = llvm::dyn_cast(U.get())) + exploreGV(G2, keep); + } + } + } +} + +std::vector +getRequiredGVs(const std::vector &kernels) { + llvm::SmallPtrSet keep; + for (auto *G : kernels) { + exploreGV(G, keep); + } + return std::vector(keep.begin(), keep.end()); +} + +void moduleToPTX(llvm::Module *M, const std::string &filename, + std::vector &kernels, + const std::string &cpuStr = "sm_30", + const std::string &featuresStr = "+ptx42") { + llvm::Triple triple(llvm::Triple::normalize(GPU_TRIPLE)); + llvm::TargetLibraryInfoImpl tlii(triple); + + std::string err; + const llvm::Target 
*target = + llvm::TargetRegistry::lookupTarget("nvptx64", triple, err); + seqassertn(target, "couldn't lookup target: {}", err); + + const llvm::TargetOptions options = + llvm::codegen::InitTargetOptionsFromCodeGenFlags(triple); + + std::unique_ptr machine(target->createTargetMachine( + triple.getTriple(), cpuStr, featuresStr, options, + llvm::codegen::getExplicitRelocModel(), llvm::codegen::getExplicitCodeModel(), + llvm::CodeGenOpt::Aggressive)); + + M->setDataLayout(machine->createDataLayout()); + auto keep = getRequiredGVs(kernels); + + auto prune = [&](std::vector keep) { + auto pm = std::make_unique(); + pm->add(new llvm::TargetLibraryInfoWrapperPass(tlii)); + // Delete everything but kernel functions. + pm->add(llvm::createGVExtractionPass(keep)); + // Delete unreachable globals. + pm->add(llvm::createGlobalDCEPass()); + // Remove dead debug info. + pm->add(llvm::createStripDeadDebugInfoPass()); + // Remove dead func decls. + pm->add(llvm::createStripDeadPrototypesPass()); + pm->run(*M); + }; + + // Remove non-kernel functions. + prune(keep); + + // Link libdevice and other cleanup. + linkLibdevice(M, libdevice); + remapFunctions(M); + + // Run NVPTX passes and general opt pipeline. + { + auto pm = std::make_unique(); + auto fpm = std::make_unique(M); + pm->add(new llvm::TargetLibraryInfoWrapperPass(tlii)); + + pm->add(llvm::createTargetTransformInfoWrapperPass( + machine ? machine->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); + fpm->add(llvm::createTargetTransformInfoWrapperPass( + machine ? 
machine->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); + + if (machine) { + auto <m = dynamic_cast(*machine); + llvm::Pass *tpc = ltm.createPassConfig(*pm); + pm->add(tpc); + } + + pm->add(llvm::createInternalizePass([&](const llvm::GlobalValue &gv) { + return std::find(keep.begin(), keep.end(), &gv) != keep.end(); + })); + + llvm::PassManagerBuilder pmb; + unsigned optLevel = 3, sizeLevel = 0; + pmb.OptLevel = optLevel; + pmb.SizeLevel = sizeLevel; + pmb.Inliner = llvm::createFunctionInliningPass(optLevel, sizeLevel, false); + pmb.DisableUnrollLoops = false; + pmb.LoopVectorize = true; + pmb.SLPVectorize = true; + + if (machine) { + machine->adjustPassManager(pmb); + } + + pmb.populateModulePassManager(*pm); + pmb.populateFunctionPassManager(*fpm); + + fpm->doInitialization(); + for (llvm::Function &f : *M) { + fpm->run(f); + } + fpm->doFinalization(); + pm->run(*M); + } + + // Prune again after optimizations. + keep = getRequiredGVs(kernels); + prune(keep); + + // Clean up names. + { + for (auto &G : M->globals()) { + G.setName(cleanUpName(G.getName())); + } + + for (auto &F : M->functions()) { + if (F.getInstructionCount() > 0) + F.setName(cleanUpName(F.getName())); + } + + for (auto *S : M->getIdentifiedStructTypes()) { + S->setName(cleanUpName(S->getName())); + } + } + + // Generate PTX file. 
+ { + std::error_code errcode; + auto out = std::make_unique(filename, errcode, + llvm::sys::fs::OF_Text); + if (errcode) + compilationError(errcode.message()); + llvm::raw_pwrite_stream *os = &out->os(); + + auto &llvmtm = static_cast(*machine); + auto *mmiwp = new llvm::MachineModuleInfoWrapperPass(&llvmtm); + llvm::legacy::PassManager pm; + + pm.add(new llvm::TargetLibraryInfoWrapperPass(tlii)); + seqassertn(!machine->addPassesToEmitFile(pm, *os, nullptr, llvm::CGFT_AssemblyFile, + /*DisableVerify=*/false, mmiwp), + "could not add passes"); + const_cast(llvmtm.getObjFileLowering()) + ->Initialize(mmiwp->getMMI().getContext(), *machine); + pm.run(*M); + out->keep(); + } +} + +void addInitCall(llvm::Module *M, const std::string &filename) { + llvm::LLVMContext &context = M->getContext(); + auto f = + M->getOrInsertFunction("seq_nvptx_load_module", llvm::Type::getVoidTy(context), + llvm::Type::getInt8PtrTy(context)); + auto *g = llvm::cast(f.getCallee()); + g->setDoesNotThrow(); + + auto *filenameVar = new llvm::GlobalVariable( + *M, llvm::ArrayType::get(llvm::Type::getInt8Ty(context), filename.length() + 1), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, + llvm::ConstantDataArray::getString(context, filename), ".nvptx.filename"); + filenameVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global); + llvm::IRBuilder<> B(context); + + if (auto *init = M->getFunction("seq_init")) { + seqassertn(init->hasOneUse(), "seq_init used more than once"); + auto *use = llvm::dyn_cast(init->use_begin()->getUser()); + seqassertn(use, "seq_init use was not a call"); + B.SetInsertPoint(use->getNextNode()); + B.CreateCall(g, B.CreateBitCast(filenameVar, B.getInt8PtrTy())); + } + + for (auto &F : M->functions()) { + if (F.hasFnAttribute("jit")) { + B.SetInsertPoint(F.getEntryBlock().getFirstNonPHI()); + B.CreateCall(g, B.CreateBitCast(filenameVar, B.getInt8PtrTy())); + } + } +} + +void cleanUpIntrinsics(llvm::Module *M) { + llvm::LLVMContext &context = M->getContext(); 
+ llvm::SmallVector remove; + for (auto &F : *M) { + if (F.getIntrinsicID() != llvm::Intrinsic::not_intrinsic && + F.getName().startswith("llvm.nvvm")) + remove.push_back(&F); + } + + for (auto *F : remove) { + F->replaceAllUsesWith(makeNoOp(F)); + F->dropAllReferences(); + F->eraseFromParent(); + } +} +} // namespace + +void applyGPUTransformations(llvm::Module *M, const std::string &ptxFilename) { + llvm::LLVMContext &context = M->getContext(); + std::unique_ptr clone = llvm::CloneModule(*M); + clone->setTargetTriple(llvm::Triple::normalize(GPU_TRIPLE)); + clone->setDataLayout(GPU_DL); + + llvm::NamedMDNode *nvvmAnno = clone->getOrInsertNamedMetadata("nvvm.annotations"); + std::vector kernels; + + for (auto &F : *clone) { + if (!F.hasFnAttribute("kernel")) + continue; + + llvm::Metadata *nvvmElem[] = { + llvm::ConstantAsMetadata::get(&F), + llvm::MDString::get(context, "kernel"), + llvm::ConstantAsMetadata::get( + llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 1)), + }; + + nvvmAnno->addOperand(llvm::MDNode::get(context, nvvmElem)); + kernels.push_back(&F); + } + + if (kernels.empty()) + return; + + std::string filename = ptxFilename.empty() ? M->getSourceFileName() : ptxFilename; + if (filename.empty() || filename[0] == '<') + filename = "kernel"; + llvm::SmallString<128> path(filename); + llvm::sys::path::replace_extension(path, "ptx"); + filename = path.str(); + + moduleToPTX(clone.get(), filename, kernels); + cleanUpIntrinsics(M); + addInitCall(M, filename); +} + +} // namespace ir +} // namespace codon diff --git a/codon/sir/llvm/gpu.h b/codon/sir/llvm/gpu.h new file mode 100644 index 00000000..b45fd28b --- /dev/null +++ b/codon/sir/llvm/gpu.h @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include "codon/sir/llvm/llvm.h" + +namespace codon { +namespace ir { + +/// Applies GPU-specific transformations and generates PTX +/// code from kernel functions in the given LLVM module. 
+/// @param module LLVM module containing GPU kernel functions (marked with "kernel" +/// annotation) +/// @param ptxFilename Filename for output PTX code; empty to use filename based on +/// module +void applyGPUTransformations(llvm::Module *module, const std::string &ptxFilename = ""); + +} // namespace ir +} // namespace codon diff --git a/codon/sir/llvm/llvisitor.cpp b/codon/sir/llvm/llvisitor.cpp index 488c71db..d1c09fc9 100644 --- a/codon/sir/llvm/llvisitor.cpp +++ b/codon/sir/llvm/llvisitor.cpp @@ -22,6 +22,7 @@ namespace { const std::string EXPORT_ATTR = "std.internal.attributes.export"; const std::string INLINE_ATTR = "std.internal.attributes.inline"; const std::string NOINLINE_ATTR = "std.internal.attributes.noinline"; +const std::string GPU_KERNEL_ATTR = "std.gpu.kernel"; } // namespace llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) { @@ -41,6 +42,8 @@ llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) { std::string LLVMVisitor::getNameForFunction(const Func *x) { if (isA(x) || util::hasAttribute(x, EXPORT_ATTR)) { return x->getUnmangledName(); + } else if (util::hasAttribute(x, GPU_KERNEL_ATTR)) { + return x->getName(); } else { return x->referenceString(); } @@ -49,7 +52,7 @@ std::string LLVMVisitor::getNameForFunction(const Func *x) { LLVMVisitor::LLVMVisitor() : util::ConstVisitor(), context(std::make_unique()), M(), B(std::make_unique>(*context)), func(nullptr), block(nullptr), - value(nullptr), vars(), funcs(), coro(), loops(), trycatch(), db(), + value(nullptr), vars(), funcs(), coro(), loops(), trycatch(), catches(), db(), plugins(nullptr) { llvm::InitializeAllTargets(); llvm::InitializeAllTargetMCs(); @@ -287,6 +290,7 @@ LLVMVisitor::takeModule(Module *module, const SrcInfo *src) { coro.reset(); loops.clear(); trycatch.clear(); + catches.clear(); db.reset(); context = std::make_unique(); @@ -697,6 +701,16 @@ void LLVMVisitor::exitTryCatch() { trycatch.pop_back(); } +void 
LLVMVisitor::enterCatch(CatchData data) { + catches.push_back(std::move(data)); + catches.back().sequenceNumber = nextSequenceNumber++; +} + +void LLVMVisitor::exitCatch() { + seqassertn(!catches.empty(), "no catches present"); + catches.pop_back(); +} + LLVMVisitor::TryCatchData *LLVMVisitor::getInnermostTryCatch() { return trycatch.empty() ? nullptr : &trycatch.back(); } @@ -1119,7 +1133,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) { err.print("LLVM", buf); // LOG("-> ERR {}", x->referenceString()); // LOG(" {}", code); - compilationError(buf.str()); + compilationError(fmt::format("{} ({})", buf.str(), x->getName())); } sub->setDataLayout(M->getDataLayout()); @@ -1152,6 +1166,9 @@ void LLVMVisitor::visit(const BodiedFunc *x) { setDebugInfoForNode(x); auto *fnAttributes = x->getAttribute(); + if (x->isJIT()) { + func->addFnAttr(llvm::Attribute::get(*context, "jit")); + } if (x->isJIT() || (fnAttributes && fnAttributes->has(EXPORT_ATTR))) { func->setLinkage(llvm::GlobalValue::ExternalLinkage); } else { @@ -1163,6 +1180,11 @@ void LLVMVisitor::visit(const BodiedFunc *x) { if (fnAttributes && fnAttributes->has(NOINLINE_ATTR)) { func->addFnAttr(llvm::Attribute::AttrKind::NoInline); } + if (fnAttributes && fnAttributes->has(GPU_KERNEL_ATTR)) { + func->addFnAttr(llvm::Attribute::AttrKind::NoInline); + func->addFnAttr(llvm::Attribute::get(*context, "kernel")); + func->setLinkage(llvm::GlobalValue::ExternalLinkage); + } func->setPersonalityFn(llvm::cast(makePersonalityFunc().getCallee())); auto *funcType = cast(x->getType()); @@ -1352,6 +1374,10 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) { return B->getDoubleTy(); } + if (auto *x = cast(t)) { + return B->getFloatTy(); + } + if (auto *x = cast(t)) { return B->getInt8Ty(); } @@ -1406,6 +1432,11 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) { return B->getIntNTy(x->getLen()); } + if (auto *x = cast(t)) { + return llvm::VectorType::get(getLLVMType(x->getBase()), x->getCount(), + /*Scalable=*/false); 
+ } + if (auto *x = cast(t)) { return x->getBuilder()->buildType(this); } @@ -1429,6 +1460,11 @@ llvm::DIType *LLVMVisitor::getDITypeHelper( x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float); } + if (auto *x = cast(t)) { + return db.builder->createBasicType( + x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float); + } + if (auto *x = cast(t)) { return db.builder->createBasicType( x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_boolean); @@ -1554,6 +1590,12 @@ llvm::DIType *LLVMVisitor::getDITypeHelper( x->isSigned() ? llvm::dwarf::DW_ATE_signed : llvm::dwarf::DW_ATE_unsigned); } + if (auto *x = cast(t)) { + return db.builder->createBasicType(x->getName(), + layout.getTypeAllocSizeInBits(type), + llvm::dwarf::DW_ATE_unsigned); + } + if (auto *x = cast(t)) { return x->getBuilder()->buildDebugType(this); } @@ -2094,7 +2136,12 @@ void LLVMVisitor::visit(const TryCatchFlow *x) { } B->CreateStore(excStateCaught, tc.excFlag); + CatchData cd; + cd.exception = objPtr; + cd.typeId = objType; + enterCatch(cd); process(catches[i]->getHandler()); + exitCatch(); B->SetInsertPoint(block); B->CreateBr(tc.finallyBlock); } @@ -2455,10 +2502,21 @@ void LLVMVisitor::visit(const ThrowInstr *x) { // note: exception header should be set in the frontend auto excAllocFunc = makeExcAllocFunc(); auto throwFunc = makeThrowFunc(); - process(x->getValue()); + llvm::Value *obj = nullptr; + llvm::Value *typ = nullptr; + + if (x->getValue()) { + process(x->getValue()); + obj = value; + typ = B->getInt32(getTypeIdx(x->getValue()->getType())); + } else { + seqassertn(!catches.empty(), "empty raise outside of except block"); + obj = catches.back().exception; + typ = catches.back().typeId; + } + B->SetInsertPoint(block); - llvm::Value *exc = B->CreateCall( - excAllocFunc, {B->getInt32(getTypeIdx(x->getValue()->getType())), value}); + llvm::Value *exc = B->CreateCall(excAllocFunc, {typ, obj}); call(throwFunc, exc); } diff --git 
a/codon/sir/llvm/llvisitor.h b/codon/sir/llvm/llvisitor.h index 443dae62..08dd43b5 100644 --- a/codon/sir/llvm/llvisitor.h +++ b/codon/sir/llvm/llvisitor.h @@ -89,6 +89,11 @@ private: } }; + struct CatchData : NestableData { + llvm::Value *exception; + llvm::Value *typeId; + }; + struct DebugInfo { /// LLVM debug info builder std::unique_ptr builder; @@ -139,6 +144,8 @@ private: std::vector loops; /// Try-catch data stack std::vector trycatch; + /// Catch-block data stack + std::vector catches; /// Debug information DebugInfo db; /// Plugin manager @@ -181,6 +188,8 @@ private: void exitLoop(); void enterTryCatch(TryCatchData data); void exitTryCatch(); + void enterCatch(CatchData data); + void exitCatch(); TryCatchData *getInnermostTryCatch(); TryCatchData *getInnermostTryCatchBeforeLoop(); diff --git a/codon/sir/llvm/llvm.h b/codon/sir/llvm/llvm.h index f8c2b21f..9c1220fd 100644 --- a/codon/sir/llvm/llvm.h +++ b/codon/sir/llvm/llvm.h @@ -32,6 +32,7 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" diff --git a/codon/sir/llvm/optimize.cpp b/codon/sir/llvm/optimize.cpp index 2013e213..2e6ae32c 100644 --- a/codon/sir/llvm/optimize.cpp +++ b/codon/sir/llvm/optimize.cpp @@ -3,6 +3,7 @@ #include #include "codon/sir/llvm/coro/Coroutines.h" +#include "codon/sir/llvm/gpu.h" #include "codon/util/common.h" #include "llvm/CodeGen/CommandFlags.h" @@ -642,13 +643,17 @@ void verify(llvm::Module *module) { void optimize(llvm::Module *module, bool debug, bool jit, PluginManager *plugins) { verify(module); { - TIME("llvm/opt"); + TIME("llvm/opt1"); runLLVMOptimizationPasses(module, debug, jit, plugins); } if (!debug) { TIME("llvm/opt2"); runLLVMOptimizationPasses(module, debug, jit, plugins); } + { + TIME("llvm/gpu"); + applyGPUTransformations(module); + } verify(module); } diff --git 
a/codon/sir/module.cpp b/codon/sir/module.cpp index f9616a8c..bec395bf 100644 --- a/codon/sir/module.cpp +++ b/codon/sir/module.cpp @@ -57,6 +57,7 @@ const std::string Module::BOOL_NAME = "bool"; const std::string Module::BYTE_NAME = "byte"; const std::string Module::INT_NAME = "int"; const std::string Module::FLOAT_NAME = "float"; +const std::string Module::FLOAT32_NAME = "float32"; const std::string Module::STRING_NAME = "str"; const std::string Module::EQ_MAGIC_NAME = "__eq__"; @@ -126,7 +127,9 @@ Func *Module::getOrRealizeMethod(types::Type *parent, const std::string &methodN return cache->realizeFunction(method, translateArgs(args), translateGenerics(generics), cls); } catch (const exc::ParserException &e) { - LOG_IR("getOrRealizeMethod parser error: {}", e.what()); + for (int i = 0; i < e.messages.size(); i++) + LOG_IR("getOrRealizeMethod parser error at {}: {}", e.locations[i], + e.messages[i]); return nullptr; } } @@ -162,7 +165,8 @@ types::Type *Module::getOrRealizeType(const std::string &typeName, try { return cache->realizeType(type, translateGenerics(generics)); } catch (const exc::ParserException &e) { - LOG_IR("getOrRealizeType parser error: {}", e.what()); + for (int i = 0; i < e.messages.size(); i++) + LOG_IR("getOrRealizeType parser error at {}: {}", e.locations[i], e.messages[i]); return nullptr; } } @@ -197,6 +201,12 @@ types::Type *Module::getFloatType() { return Nr(); } +types::Type *Module::getFloat32Type() { + if (auto *rVal = getType(FLOAT32_NAME)) + return rVal; + return Nr(); +} + types::Type *Module::getStringType() { if (auto *rVal = getType(STRING_NAME)) return rVal; @@ -243,6 +253,10 @@ types::Type *Module::getIntNType(unsigned int len, bool sign) { return getOrRealizeType(sign ? 
"Int" : "UInt", {len}); } +types::Type *Module::getVectorType(unsigned count, types::Type *base) { + return getOrRealizeType("Vec", {base, count}); +} + types::Type *Module::getTupleType(std::vector args) { std::vector argTypes; for (auto *t : args) { @@ -332,5 +346,14 @@ types::Type *Module::unsafeGetIntNType(unsigned int len, bool sign) { return Nr(len, sign); } +types::Type *Module::unsafeGetVectorType(unsigned int count, types::Type *base) { + auto *primitive = cast(base); + auto name = types::VectorType::getInstanceName(count, primitive); + if (auto *rVal = getType(name)) + return rVal; + seqassertn(primitive, "base type must be a primitive type"); + return Nr(count, primitive); +} + } // namespace ir } // namespace codon diff --git a/codon/sir/module.h b/codon/sir/module.h index 28bbd21d..d48edaba 100644 --- a/codon/sir/module.h +++ b/codon/sir/module.h @@ -31,6 +31,7 @@ public: static const std::string BYTE_NAME; static const std::string INT_NAME; static const std::string FLOAT_NAME; + static const std::string FLOAT32_NAME; static const std::string STRING_NAME; static const std::string EQ_MAGIC_NAME; @@ -305,6 +306,8 @@ public: types::Type *getIntType(); /// @return the float type types::Type *getFloatType(); + /// @return the float32 type + types::Type *getFloat32Type(); /// @return the string type types::Type *getStringType(); /// Gets a pointer type. @@ -335,6 +338,11 @@ public: /// @param sign true if signed /// @return a variable length integer type types::Type *getIntNType(unsigned len, bool sign); + /// Gets a vector type. + /// @param count the vector size + /// @param base the vector base type (MUST be a primitive type) + /// @return a vector type + types::Type *getVectorType(unsigned count, types::Type *base); /// Gets a tuple type. 
/// @param args the arg types /// @return the tuple type @@ -401,6 +409,12 @@ public: /// @param sign true if signed /// @return a variable length integer type types::Type *unsafeGetIntNType(unsigned len, bool sign); + /// Gets a vector type. Should generally not be used as no + /// type-checker information is generated. + /// @param count the vector size + /// @param base the vector base type (MUST be a primitive type) + /// @return a vector type + types::Type *unsafeGetVectorType(unsigned count, types::Type *base); private: void store(types::Type *t) { diff --git a/codon/sir/transform/cleanup/canonical.cpp b/codon/sir/transform/cleanup/canonical.cpp index b00fdeb4..3fd3fd82 100644 --- a/codon/sir/transform/cleanup/canonical.cpp +++ b/codon/sir/transform/cleanup/canonical.cpp @@ -281,6 +281,9 @@ struct CanonConstSub : public RewriteRule { Value *lhs = v->front(); Value *rhs = v->back(); + if (!lhs->getType()->is(rhs->getType())) + return; + Value *newCall = nullptr; if (util::isConst(rhs)) { auto c = util::getConst(rhs); diff --git a/codon/sir/transform/folding/const_fold.cpp b/codon/sir/transform/folding/const_fold.cpp index 3701c8b0..511273b1 100644 --- a/codon/sir/transform/folding/const_fold.cpp +++ b/codon/sir/transform/folding/const_fold.cpp @@ -1,6 +1,7 @@ #include "const_fold.h" #include +#include #include "codon/sir/util/cloning.h" #include "codon/sir/util/irtools.h" @@ -15,15 +16,28 @@ namespace ir { namespace transform { namespace folding { namespace { +auto pyDivmod(int64_t self, int64_t other) { + auto d = self / other; + auto m = self - d * other; + if (m && ((other ^ m) < 0)) { + m += other; + d -= 1; + } + return std::make_pair(d, m); +} + template class IntFloatBinaryRule : public RewriteRule { private: Func f; std::string magic; types::Type *out; + bool excludeRHSZero; public: - IntFloatBinaryRule(Func f, std::string magic, types::Type *out) - : f(std::move(f)), magic(std::move(magic)), out(out) {} + IntFloatBinaryRule(Func f, std::string magic, 
types::Type *out, + bool excludeRHSZero = false) + : f(std::move(f)), magic(std::move(magic)), out(out), + excludeRHSZero(excludeRHSZero) {} virtual ~IntFloatBinaryRule() noexcept = default; @@ -41,11 +55,15 @@ public: if (isA(leftConst) && isA(rightConst)) { auto left = cast(leftConst)->getVal(); auto right = cast(rightConst)->getVal(); + if (excludeRHSZero && right == 0) + return; return setResult(M->template N>(v->getSrcInfo(), f(left, (double)right), out)); } else if (isA(leftConst) && isA(rightConst)) { auto left = cast(leftConst)->getVal(); auto right = cast(rightConst)->getVal(); + if (excludeRHSZero && right == 0.0) + return; return setResult(M->template N>(v->getSrcInfo(), f((double)left, right), out)); } @@ -140,15 +158,22 @@ template auto floatToFloatBinary(Module *m, Func f, std::string std::move(f), std::move(magic), m->getFloatType(), m->getFloatType()); } +template +auto floatToFloatBinaryNoZeroRHS(Module *m, Func f, std::string magic) { + return std::make_unique>( + std::move(f), std::move(magic), m->getFloatType(), m->getFloatType()); +} + template auto floatToBoolBinary(Module *m, Func f, std::string magic) { return std::make_unique>( std::move(f), std::move(magic), m->getFloatType(), m->getBoolType()); } template -auto intFloatToFloatBinary(Module *m, Func f, std::string magic) { +auto intFloatToFloatBinary(Module *m, Func f, std::string magic, + bool excludeRHSZero = false) { return std::make_unique>( - std::move(f), std::move(magic), m->getFloatType()); + std::move(f), std::move(magic), m->getFloatType(), excludeRHSZero); } template @@ -222,8 +247,15 @@ void FoldingPass::registerStandardRules(Module *m) { intToIntBinary(m, BINOP(+), Module::ADD_MAGIC_NAME)); registerRule("int-constant-subtraction", intToIntBinary(m, BINOP(-), Module::SUB_MAGIC_NAME)); - registerRule("int-constant-floor-div", - intToIntBinaryNoZeroRHS(m, BINOP(/), Module::FLOOR_DIV_MAGIC_NAME)); + if (pyNumerics) { + registerRule("int-constant-floor-div", + 
intToIntBinaryNoZeroRHS( + m, [](auto x, auto y) -> auto{ return pyDivmod(x, y).first; }, + Module::FLOOR_DIV_MAGIC_NAME)); + } else { + registerRule("int-constant-floor-div", + intToIntBinaryNoZeroRHS(m, BINOP(/), Module::FLOOR_DIV_MAGIC_NAME)); + } registerRule("int-constant-mul", intToIntBinary(m, BINOP(*), Module::MUL_MAGIC_NAME)); registerRule("int-constant-lshift", intToIntBinary(m, BINOP(<<), Module::LSHIFT_MAGIC_NAME)); @@ -233,8 +265,15 @@ void FoldingPass::registerStandardRules(Module *m) { registerRule("int-constant-xor", intToIntBinary(m, BINOP(^), Module::XOR_MAGIC_NAME)); registerRule("int-constant-or", intToIntBinary(m, BINOP(|), Module::OR_MAGIC_NAME)); registerRule("int-constant-and", intToIntBinary(m, BINOP(&), Module::AND_MAGIC_NAME)); - registerRule("int-constant-mod", - intToIntBinaryNoZeroRHS(m, BINOP(%), Module::MOD_MAGIC_NAME)); + if (pyNumerics) { + registerRule("int-constant-mod", + intToIntBinaryNoZeroRHS( + m, [](auto x, auto y) -> auto{ return pyDivmod(x, y).second; }, + Module::MOD_MAGIC_NAME)); + } else { + registerRule("int-constant-mod", + intToIntBinaryNoZeroRHS(m, BINOP(%), Module::MOD_MAGIC_NAME)); + } // binary, double constant, int->bool registerRule("int-constant-eq", intToBoolBinary(m, BINOP(==), Module::EQ_MAGIC_NAME)); @@ -273,8 +312,13 @@ void FoldingPass::registerStandardRules(Module *m) { floatToFloatBinary(m, BINOP(+), Module::ADD_MAGIC_NAME)); registerRule("float-constant-subtraction", floatToFloatBinary(m, BINOP(-), Module::SUB_MAGIC_NAME)); - registerRule("float-constant-floor-div", - floatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME)); + if (pyNumerics) { + registerRule("float-constant-floor-div", + floatToFloatBinaryNoZeroRHS(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME)); + } else { + registerRule("float-constant-floor-div", + floatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME)); + } registerRule("float-constant-mul", floatToFloatBinary(m, BINOP(*), Module::MUL_MAGIC_NAME)); registerRule( @@ -301,8 
+345,9 @@ void FoldingPass::registerStandardRules(Module *m) { intFloatToFloatBinary(m, BINOP(+), Module::ADD_MAGIC_NAME)); registerRule("int-float-constant-subtraction", intFloatToFloatBinary(m, BINOP(-), Module::SUB_MAGIC_NAME)); - registerRule("int-float-constant-floor-div", - intFloatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME)); + registerRule( + "int-float-constant-floor-div", + intFloatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME, pyNumerics)); registerRule("int-float-constant-mul", intFloatToFloatBinary(m, BINOP(*), Module::MUL_MAGIC_NAME)); diff --git a/codon/sir/transform/folding/const_fold.h b/codon/sir/transform/folding/const_fold.h index 5670f10f..a9456240 100644 --- a/codon/sir/transform/folding/const_fold.h +++ b/codon/sir/transform/folding/const_fold.h @@ -12,18 +12,21 @@ namespace transform { namespace folding { class FoldingPass : public OperatorPass, public Rewriter { +private: + bool pyNumerics; + + void registerStandardRules(Module *m); + public: /// Constructs a folding pass. 
- FoldingPass() : OperatorPass(/*childrenFirst=*/true) {} + FoldingPass(bool pyNumerics = false) + : OperatorPass(/*childrenFirst=*/true), pyNumerics(pyNumerics) {} static const std::string KEY; std::string getKey() const override { return KEY; } void run(Module *m) override; void handle(CallInstr *v) override; - -private: - void registerStandardRules(Module *m); }; } // namespace folding diff --git a/codon/sir/transform/folding/folding.cpp b/codon/sir/transform/folding/folding.cpp index ea1d7839..ce2f064a 100644 --- a/codon/sir/transform/folding/folding.cpp +++ b/codon/sir/transform/folding/folding.cpp @@ -13,12 +13,12 @@ const std::string FoldingPassGroup::KEY = "core-folding-pass-group"; FoldingPassGroup::FoldingPassGroup(const std::string &sideEffectsPass, const std::string &reachingDefPass, const std::string &globalVarPass, int repeat, - bool runGlobalDemotion) + bool runGlobalDemotion, bool pyNumerics) : PassGroup(repeat) { auto gdUnique = runGlobalDemotion ? std::make_unique() : std::unique_ptr(); auto canonUnique = std::make_unique(sideEffectsPass); - auto fpUnique = std::make_unique(); + auto fpUnique = std::make_unique(pyNumerics); auto dceUnique = std::make_unique(sideEffectsPass); gd = gdUnique.get(); diff --git a/codon/sir/transform/folding/folding.h b/codon/sir/transform/folding/folding.h index 45addefb..fcab7db7 100644 --- a/codon/sir/transform/folding/folding.h +++ b/codon/sir/transform/folding/folding.h @@ -29,9 +29,11 @@ public: /// @param globalVarPass the key of the global variables pass /// @param repeat default number of times to repeat the pass /// @param runGlobalDemotion whether to demote globals if possible + /// @param pyNumerics whether to use Python (vs. 
C) semantics when folding FoldingPassGroup(const std::string &sideEffectsPass, const std::string &reachingDefPass, const std::string &globalVarPass, - int repeat = 5, bool runGlobalDemotion = true); + int repeat = 5, bool runGlobalDemotion = true, + bool pyNumerics = false); bool shouldRepeat(int num) const override; }; diff --git a/codon/sir/transform/manager.cpp b/codon/sir/transform/manager.cpp index 61ed2eb8..ba47dbbf 100644 --- a/codon/sir/transform/manager.cpp +++ b/codon/sir/transform/manager.cpp @@ -185,11 +185,11 @@ void PassManager::registerStandardPasses(PassManager::Init init) { capKey, /*globalAssignmentHasSideEffects=*/false), {capKey}); - registerPass( - std::make_unique( - seKey1, rdKey, globalKey, /*repeat=*/5, /*runGlobalDemoton=*/false), - /*insertBefore=*/"", {seKey1, rdKey, globalKey}, - {seKey1, rdKey, cfgKey, globalKey, capKey}); + registerPass(std::make_unique( + seKey1, rdKey, globalKey, /*repeat=*/5, /*runGlobalDemoton=*/false, + pyNumerics), + /*insertBefore=*/"", {seKey1, rdKey, globalKey}, + {seKey1, rdKey, cfgKey, globalKey, capKey}); // parallel registerPass(std::make_unique(), /*insertBefore=*/"", {}, @@ -198,12 +198,12 @@ void PassManager::registerStandardPasses(PassManager::Init init) { if (init != Init::JIT) { // Don't demote globals in JIT mode, since they might be used later // by another user input. 
- registerPass( - std::make_unique(seKey2, rdKey, globalKey, - /*repeat=*/5, - /*runGlobalDemoton=*/true), - /*insertBefore=*/"", {seKey2, rdKey, globalKey}, - {seKey2, rdKey, cfgKey, globalKey}); + registerPass(std::make_unique( + seKey2, rdKey, globalKey, + /*repeat=*/5, + /*runGlobalDemoton=*/true, pyNumerics), + /*insertBefore=*/"", {seKey2, rdKey, globalKey}, + {seKey2, rdKey, cfgKey, globalKey}); } break; } diff --git a/codon/sir/transform/manager.h b/codon/sir/transform/manager.h index 0180810c..e925b9b0 100644 --- a/codon/sir/transform/manager.h +++ b/codon/sir/transform/manager.h @@ -89,6 +89,9 @@ private: /// passes to avoid registering std::vector disabled; + /// whether to use Python (vs. C) numeric semantics in passes + bool pyNumerics; + public: /// PassManager initialization mode. enum Init { @@ -98,14 +101,17 @@ public: JIT, }; - explicit PassManager(Init init, std::vector disabled = {}) + explicit PassManager(Init init, std::vector disabled = {}, + bool pyNumerics = false) : km(), passes(), analyses(), executionOrder(), results(), - disabled(std::move(disabled)) { + disabled(std::move(disabled)), pyNumerics(pyNumerics) { registerStandardPasses(init); } - explicit PassManager(bool debug = false, std::vector disabled = {}) - : PassManager(debug ? Init::DEBUG : Init::RELEASE, std::move(disabled)) {} + explicit PassManager(bool debug = false, std::vector disabled = {}, + bool pyNumerics = false) + : PassManager(debug ? Init::DEBUG : Init::RELEASE, std::move(disabled), + pyNumerics) {} /// Checks if the given pass is included in this manager. 
/// @param key the pass key diff --git a/codon/sir/transform/parallel/openmp.cpp b/codon/sir/transform/parallel/openmp.cpp index 629ecc39..0a4b397b 100644 --- a/codon/sir/transform/parallel/openmp.cpp +++ b/codon/sir/transform/parallel/openmp.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "codon/sir/transform/parallel/schedule.h" #include "codon/sir/util/cloning.h" @@ -15,14 +16,22 @@ namespace transform { namespace parallel { namespace { const std::string ompModule = "std.openmp"; +const std::string gpuModule = "std.gpu"; const std::string builtinModule = "std.internal.builtin"; +void warn(const std::string &msg, const Value *v) { + auto src = v->getSrcInfo(); + compilationWarning(msg, src.file, src.line, src.col); +} + struct OMPTypes { + types::Type *i64 = nullptr; types::Type *i32 = nullptr; types::Type *i8ptr = nullptr; types::Type *i32ptr = nullptr; explicit OMPTypes(Module *M) { + i64 = M->getIntType(); i32 = M->getIntNType(32, /*sign=*/true); i8ptr = M->getPointerType(M->getByteType()); i32ptr = M->getPointerType(i32); @@ -142,6 +151,28 @@ struct Reduction { default: return nullptr; } + } else if (isA(type)) { + auto *f32 = M->getOrRealizeType("float32"); + float value = 0.0; + + switch (kind) { + case Kind::ADD: + value = 0.0; + break; + case Kind::MUL: + value = 1.0; + break; + case Kind::MIN: + value = std::numeric_limits::max(); + break; + case Kind::MAX: + value = std::numeric_limits::min(); + break; + default: + return nullptr; + } + + return (*f32)(*M->getFloat(value)); } auto *init = (*type)(); @@ -239,6 +270,23 @@ struct Reduction { default: break; } + } else if (isA(type)) { + switch (kind) { + case Kind::ADD: + func = "_atomic_float32_add"; + break; + case Kind::MUL: + func = "_atomic_float32_mul"; + break; + case Kind::MIN: + func = "_atomic_float32_min"; + break; + case Kind::MAX: + func = "_atomic_float32_max"; + break; + default: + break; + } } if (!func.empty()) { @@ -476,10 +524,16 @@ struct SharedInfo { Reduction 
reduction; // the reduction we're performing, or empty if none }; -struct ParallelLoopTemplateReplacer : public util::Operator { +struct LoopTemplateReplacer : public util::Operator { BodiedFunc *parent; CallInstr *replacement; Var *loopVar; + + LoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement, Var *loopVar) + : util::Operator(), parent(parent), replacement(replacement), loopVar(loopVar) {} +}; + +struct ParallelLoopTemplateReplacer : public LoopTemplateReplacer { ReductionIdentifier *reds; std::vector sharedInfo; ReductionLocks locks; @@ -489,9 +543,8 @@ struct ParallelLoopTemplateReplacer : public util::Operator { ParallelLoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement, Var *loopVar, ReductionIdentifier *reds) - : util::Operator(), parent(parent), replacement(replacement), loopVar(loopVar), - reds(reds), sharedInfo(), locks(), locRef(nullptr), reductionLocRef(nullptr), - gtid(nullptr) {} + : LoopTemplateReplacer(parent, replacement, loopVar), reds(reds), sharedInfo(), + locks(), locRef(nullptr), reductionLocRef(nullptr), gtid(nullptr) {} unsigned numReductions() { unsigned num = 0; @@ -1102,6 +1155,72 @@ struct TaskLoopRoutineStubReplacer : public ParallelLoopTemplateReplacer { } }; +struct GPULoopBodyStubReplacer : public util::Operator { + CallInstr *replacement; + Var *loopVar; + int64_t step; + + GPULoopBodyStubReplacer(CallInstr *replacement, Var *loopVar, int64_t step) + : util::Operator(), replacement(replacement), loopVar(loopVar), step(step) {} + + void handle(CallInstr *v) override { + auto *M = v->getModule(); + auto *func = util::getFunc(v->getCallee()); + if (!func) + return; + auto name = func->getUnmangledName(); + + if (name == "_gpu_loop_body_stub") { + seqassertn(replacement, "unexpected double replacement"); + seqassertn(v->numArgs() == 2, "unexpected loop body stub"); + + // the template passes gtid, privs and shareds to the body stub for convenience + auto *idx = v->front(); + auto *args = v->back(); + unsigned 
next = 0; + + std::vector newArgs; + for (auto *arg : *replacement) { + // std::cout << "A: " << *arg << std::endl; + if (getVarFromOutlinedArg(arg)->getId() == loopVar->getId()) { + // std::cout << "(loop var)" << std::endl; + newArgs.push_back(idx); + } else { + newArgs.push_back(util::tupleGet(args, next++)); + } + } + + auto *outlinedFunc = cast(util::getFunc(replacement->getCallee())); + v->replaceAll(util::call(outlinedFunc, newArgs)); + replacement = nullptr; + } + + if (name == "_loop_step") { + v->replaceAll(M->getInt(step)); + } + } +}; + +struct GPULoopTemplateReplacer : public LoopTemplateReplacer { + int64_t step; + + GPULoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement, Var *loopVar, + int64_t step) + : LoopTemplateReplacer(parent, replacement, loopVar), step(step) {} + + void handle(CallInstr *v) override { + auto *M = v->getModule(); + auto *func = util::getFunc(v->getCallee()); + if (!func) + return; + auto name = func->getUnmangledName(); + + if (name == "_loop_step") { + v->replaceAll(M->getInt(step)); + } + } +}; + struct OpenMPTransformData { util::OutlineResult outline; std::vector sharedVars; @@ -1164,15 +1283,118 @@ ForkCallData createForkCall(Module *M, OMPTypes &types, Value *rawTemplateFunc, seqassertn(forkFunc, "fork call function not found"); result.fork = util::call(forkFunc, {rawTemplateFunc, forkExtra}); - auto *intType = M->getIntType(); - if (sched->threads && sched->threads->getType()->is(intType)) { + if (sched->threads && sched->threads->getType()->is(types.i64)) { auto *pushNumThreadsFunc = - M->getOrRealizeFunc("_push_num_threads", {intType}, {}, ompModule); + M->getOrRealizeFunc("_push_num_threads", {types.i64}, {}, ompModule); seqassertn(pushNumThreadsFunc, "push num threads func not found"); result.pushNumThreads = util::call(pushNumThreadsFunc, {sched->threads}); } return result; } + +struct CollapseResult { + ImperativeForFlow *collapsed = nullptr; + SeriesFlow *setup = nullptr; + std::string error; + + 
operator bool() const { return collapsed != nullptr; } +}; + +struct LoopRange { + ImperativeForFlow *loop; + Var *start; + Var *stop; + int64_t step; + Var *len; +}; + +CollapseResult collapseLoop(BodiedFunc *parent, ImperativeForFlow *v, int64_t levels) { + auto fail = [](const std::string &error) { + CollapseResult bad; + bad.error = error; + return bad; + }; + + auto *M = v->getModule(); + CollapseResult res; + if (levels < 1) + return fail("'collapse' must be at least 1"); + + std::vector loopNests = {v}; + ImperativeForFlow *curr = v; + + for (auto i = 0; i < levels - 1; i++) { + auto *body = cast(curr->getBody()); + seqassertn(body, "unexpected loop body"); + if (std::distance(body->begin(), body->end()) != 1 || + !isA(body->front())) + return fail("loop nest not collapsible"); + + curr = cast(body->front()); + loopNests.push_back(curr); + } + + std::vector ranges; + auto *setup = M->Nr(); + + auto *intType = M->getIntType(); + auto *lenCalc = + M->getOrRealizeFunc("_range_len", {intType, intType, intType}, {}, ompModule); + seqassertn(lenCalc, "range length calculation function not found"); + + for (auto *loop : loopNests) { + LoopRange range; + range.loop = loop; + range.start = util::makeVar(loop->getStart(), setup, parent)->getVar(); + range.stop = util::makeVar(loop->getEnd(), setup, parent)->getVar(); + range.step = loop->getStep(); + range.len = util::makeVar(util::call(lenCalc, {M->Nr(range.start), + M->Nr(range.stop), + M->getInt(range.step)}), + setup, parent) + ->getVar(); + ranges.push_back(range); + } + + auto *numIters = M->getInt(1); + for (auto &range : ranges) { + numIters = (*numIters) * (*M->Nr(range.len)); + } + + auto *collapsedVar = M->Nr(M->getIntType(), /*global=*/false); + parent->push_back(collapsedVar); + auto *body = M->Nr(); + auto sched = std::make_unique(*v->getSchedule()); + sched->collapse = 0; + auto *collapsed = M->Nr(M->getInt(0), 1, numIters, body, + collapsedVar, std::move(sched)); + + // reconstruct indices by 
successive divmods + Var *lastDiv = nullptr; + for (auto it = ranges.rbegin(); it != ranges.rend(); ++it) { + auto *k = lastDiv ? lastDiv : collapsedVar; + auto *div = + util::makeVar(*M->Nr(k) / *M->Nr(it->len), body, parent) + ->getVar(); + auto *mod = + util::makeVar(*M->Nr(k) % *M->Nr(it->len), body, parent) + ->getVar(); + auto *i = + *M->Nr(it->start) + *(*M->Nr(mod) * *M->getInt(it->step)); + body->push_back(M->Nr(it->loop->getVar(), i)); + lastDiv = div; + } + + auto *oldBody = cast(loopNests.back()->getBody()); + for (auto *x : *oldBody) { + body->push_back(x); + } + + res.collapsed = collapsed; + res.setup = setup; + + return res; +} } // namespace const std::string OpenMPPass::KEY = "core-parallel-openmp"; @@ -1279,7 +1501,22 @@ void OpenMPPass::handle(ForFlow *v) { } void OpenMPPass::handle(ImperativeForFlow *v) { - auto data = setupOpenMPTransform(v, cast(getParentFunc())); + auto *parent = cast(getParentFunc()); + + if (v->isParallel() && v->getSchedule()->collapse != 0) { + auto levels = v->getSchedule()->collapse; + auto collapse = collapseLoop(parent, v, levels); + + if (collapse) { + v->replaceAll(collapse.collapsed); + v = collapse.collapsed; + insertBefore(collapse.setup); + } else if (!collapse.error.empty()) { + warn("could not collapse loop: " + collapse.error, v); + } + } + + auto data = setupOpenMPTransform(v, parent); if (!v->isParallel()) return; @@ -1292,6 +1529,12 @@ void OpenMPPass::handle(ImperativeForFlow *v) { auto *sched = v->getSchedule(); OMPTypes types(M); + if (sched->gpu && !sharedVars.empty()) { + warn("GPU-parallel loop cannot modify external variables; ignoring", v); + v->setParallel(false); + return; + } + // gather extra arguments std::vector extraArgs; std::vector extraArgTypes; @@ -1304,41 +1547,88 @@ void OpenMPPass::handle(ImperativeForFlow *v) { // template call std::string templateFuncName; - if (sched->dynamic) { + if (sched->gpu) { + templateFuncName = "_gpu_loop_outline_template"; + } else if (sched->dynamic) { 
templateFuncName = "_dynamic_loop_outline_template"; } else if (sched->chunk) { templateFuncName = "_static_chunked_loop_outline_template"; } else { templateFuncName = "_static_loop_outline_template"; } - auto *intType = M->getIntType(); - std::vector templateFuncArgs = { - types.i32ptr, types.i32ptr, - M->getPointerType(M->getTupleType( - {intType, intType, intType, M->getTupleType(extraArgTypes)}))}; - auto *templateFunc = - M->getOrRealizeFunc(templateFuncName, templateFuncArgs, {}, ompModule); - seqassertn(templateFunc, "imperative loop outline template not found"); - util::CloneVisitor cv(M); - templateFunc = cast(cv.forceClone(templateFunc)); - ImperativeLoopTemplateReplacer rep(cast(templateFunc), outline.call, - loopVar, &reds, sched, v->getStep()); - templateFunc->accept(rep); - auto *rawTemplateFunc = ptrFromFunc(templateFunc); + if (sched->gpu) { + std::unordered_set kernels; + const std::string gpuAttr = "std.gpu.kernel"; + for (auto *var : *M) { + if (auto *func = cast(var)) { + if (util::hasAttribute(func, gpuAttr)) + kernels.insert(func->getId()); + } + } - auto *chunk = (sched->chunk && sched->chunk->getType()->is(intType)) ? 
sched->chunk - : M->getInt(1); - std::vector forkExtraArgs = {chunk, v->getStart(), v->getEnd()}; - for (auto *arg : extraArgs) { - forkExtraArgs.push_back(arg); + std::vector templateFuncArgs = {types.i64, types.i64, + M->getTupleType(extraArgTypes)}; + static int64_t instance = 0; + auto *templateFunc = M->getOrRealizeFunc(templateFuncName, templateFuncArgs, + {instance++}, gpuModule); + + if (!templateFunc) { + warn("loop not compilable for GPU; ignoring", v); + v->setParallel(false); + return; + } + + BodiedFunc *kernel = nullptr; + for (auto *var : *M) { + if (auto *func = cast(var)) { + if (util::hasAttribute(func, gpuAttr) && kernels.count(func->getId()) == 0) { + seqassertn(!kernel, "multiple new kernels found after instantiation"); + kernel = func; + } + } + } + seqassertn(kernel, "no new kernel found"); + GPULoopBodyStubReplacer brep(outline.call, loopVar, v->getStep()); + kernel->accept(brep); + + util::CloneVisitor cv(M); + templateFunc = cast(cv.forceClone(templateFunc)); + GPULoopTemplateReplacer rep(cast(templateFunc), outline.call, loopVar, + v->getStep()); + templateFunc->accept(rep); + v->replaceAll(util::call( + templateFunc, {v->getStart(), v->getEnd(), util::makeTuple(extraArgs, M)})); + } else { + std::vector templateFuncArgs = { + types.i32ptr, types.i32ptr, + M->getPointerType(M->getTupleType( + {types.i64, types.i64, types.i64, M->getTupleType(extraArgTypes)}))}; + auto *templateFunc = + M->getOrRealizeFunc(templateFuncName, templateFuncArgs, {}, ompModule); + seqassertn(templateFunc, "imperative loop outline template not found"); + + util::CloneVisitor cv(M); + templateFunc = cast(cv.forceClone(templateFunc)); + ImperativeLoopTemplateReplacer rep(cast(templateFunc), outline.call, + loopVar, &reds, sched, v->getStep()); + templateFunc->accept(rep); + auto *rawTemplateFunc = ptrFromFunc(templateFunc); + + auto *chunk = (sched->chunk && sched->chunk->getType()->is(types.i64)) + ? 
sched->chunk + : M->getInt(1); + std::vector forkExtraArgs = {chunk, v->getStart(), v->getEnd()}; + for (auto *arg : extraArgs) { + forkExtraArgs.push_back(arg); + } + + // fork call + auto forkData = createForkCall(M, types, rawTemplateFunc, forkExtraArgs, sched); + if (forkData.pushNumThreads) + insertBefore(forkData.pushNumThreads); + v->replaceAll(forkData.fork); } - - // fork call - auto forkData = createForkCall(M, types, rawTemplateFunc, forkExtraArgs, sched); - if (forkData.pushNumThreads) - insertBefore(forkData.pushNumThreads); - v->replaceAll(forkData.fork); } } // namespace parallel diff --git a/codon/sir/transform/parallel/schedule.cpp b/codon/sir/transform/parallel/schedule.cpp index 0306e7ec..c1b2c59b 100644 --- a/codon/sir/transform/parallel/schedule.cpp +++ b/codon/sir/transform/parallel/schedule.cpp @@ -47,17 +47,19 @@ Value *nullIfNeg(Value *v) { } } // namespace -OMPSched::OMPSched(int code, bool dynamic, Value *threads, Value *chunk, bool ordered) +OMPSched::OMPSched(int code, bool dynamic, Value *threads, Value *chunk, bool ordered, + int64_t collapse, bool gpu) : code(code), dynamic(dynamic), threads(nullIfNeg(threads)), - chunk(nullIfNeg(chunk)), ordered(ordered) { + chunk(nullIfNeg(chunk)), ordered(ordered), collapse(collapse), gpu(gpu) { if (code < 0) this->code = getScheduleCode(); } OMPSched::OMPSched(const std::string &schedule, Value *threads, Value *chunk, - bool ordered) + bool ordered, int64_t collapse, bool gpu) : OMPSched(getScheduleCode(schedule, nullIfNeg(chunk) != nullptr, ordered), - (schedule != "static") || ordered, threads, chunk, ordered) {} + (schedule != "static") || ordered, threads, chunk, ordered, collapse, + gpu) {} std::vector OMPSched::getUsedValues() const { std::vector ret; diff --git a/codon/sir/transform/parallel/schedule.h b/codon/sir/transform/parallel/schedule.h index fe956f73..44763bf6 100644 --- a/codon/sir/transform/parallel/schedule.h +++ b/codon/sir/transform/parallel/schedule.h @@ -16,14 +16,18 @@ 
struct OMPSched { Value *threads; Value *chunk; bool ordered; + int64_t collapse; + bool gpu; explicit OMPSched(int code = -1, bool dynamic = false, Value *threads = nullptr, - Value *chunk = nullptr, bool ordered = false); + Value *chunk = nullptr, bool ordered = false, int64_t collapse = 0, + bool gpu = false); explicit OMPSched(const std::string &code, Value *threads = nullptr, - Value *chunk = nullptr, bool ordered = false); + Value *chunk = nullptr, bool ordered = false, int64_t collapse = 0, + bool gpu = false); OMPSched(const OMPSched &s) : code(s.code), dynamic(s.dynamic), threads(s.threads), chunk(s.chunk), - ordered(s.ordered) {} + ordered(s.ordered), collapse(s.collapse), gpu(s.gpu) {} std::vector getUsedValues() const; int replaceUsedValue(id_t id, Value *newValue); diff --git a/codon/sir/types/types.cpp b/codon/sir/types/types.cpp index c373140c..970c1944 100644 --- a/codon/sir/types/types.cpp +++ b/codon/sir/types/types.cpp @@ -65,6 +65,8 @@ const char IntType::NodeId = 0; const char FloatType::NodeId = 0; +const char Float32Type::NodeId = 0; + const char BoolType::NodeId = 0; const char ByteType::NodeId = 0; @@ -188,6 +190,12 @@ std::string IntNType::getInstanceName(unsigned int len, bool sign) { return fmt::format(FMT_STRING("{}Int{}"), sign ? "" : "U", len); } +const char VectorType::NodeId = 0; + +std::string VectorType::getInstanceName(unsigned int count, PrimitiveType *base) { + return fmt::format(FMT_STRING("Vector[{}, {}]"), count, base->referenceString()); +} + } // namespace types } // namespace ir } // namespace codon diff --git a/codon/sir/types/types.h b/codon/sir/types/types.h index 2b2257c2..e2ef9567 100644 --- a/codon/sir/types/types.h +++ b/codon/sir/types/types.h @@ -138,6 +138,15 @@ public: FloatType() : AcceptorExtend("float") {} }; +/// Float32 type (32-bit float) +class Float32Type : public AcceptorExtend { +public: + static const char NodeId; + + /// Constructs a float32 type. 
+ Float32Type() : AcceptorExtend("float32") {} +}; + /// Bool type (8-bit unsigned integer; either 0 or 1) class BoolType : public AcceptorExtend { public: @@ -424,7 +433,7 @@ private: }; /// Type of a variably sized integer -class IntNType : public AcceptorExtend { +class IntNType : public AcceptorExtend { private: /// length of the integer unsigned len; @@ -451,9 +460,31 @@ public: std::string oppositeSignName() const { return getInstanceName(len, !sign); } static std::string getInstanceName(unsigned len, bool sign); +}; +/// Type of a vector of primitives +class VectorType : public AcceptorExtend { private: - bool doIsAtomic() const override { return true; } + /// number of elements + unsigned count; + /// base type + PrimitiveType *base; + +public: + static const char NodeId; + + /// Constructs a vector type. + /// @param count the number of elements + /// @param base the base type + VectorType(unsigned count, PrimitiveType *base) + : AcceptorExtend(getInstanceName(count, base)), count(count), base(base) {} + + /// @return the count of the vector + unsigned getCount() const { return count; } + /// @return the base type of the vector + PrimitiveType *getBase() const { return base; } + + static std::string getInstanceName(unsigned count, PrimitiveType *base); }; } // namespace types diff --git a/codon/sir/util/format.cpp b/codon/sir/util/format.cpp index 14b59ca7..de56a75a 100644 --- a/codon/sir/util/format.cpp +++ b/codon/sir/util/format.cpp @@ -288,6 +288,9 @@ public: void visit(const types::FloatType *v) override { fmt::print(os, FMT_STRING("(float '\"{}\")"), v->referenceString()); } + void visit(const types::Float32Type *v) override { + fmt::print(os, FMT_STRING("(float32 '\"{}\")"), v->referenceString()); + } void visit(const types::BoolType *v) override { fmt::print(os, FMT_STRING("(bool '\"{}\")"), v->referenceString()); } @@ -334,6 +337,10 @@ public: fmt::print(os, FMT_STRING("(intn '\"{}\" {} (signed {}))"), v->referenceString(), v->getLen(), 
v->isSigned()); } + void visit(const types::VectorType *v) override { + fmt::print(os, FMT_STRING("(vector '\"{}\" {} (count {}))"), v->referenceString(), + makeFormatter(v->getBase()), v->getCount()); + } void visit(const dsl::types::CustomType *v) override { v->doFormat(os); } void format(const Node *n) { diff --git a/codon/sir/util/matching.cpp b/codon/sir/util/matching.cpp index 82e51a9f..447c38c1 100644 --- a/codon/sir/util/matching.cpp +++ b/codon/sir/util/matching.cpp @@ -233,6 +233,8 @@ public: void handle(const types::IntType *, const types::IntType *) { result = true; } VISIT(types::FloatType); void handle(const types::FloatType *, const types::FloatType *) { result = true; } + VISIT(types::Float32Type); + void handle(const types::Float32Type *, const types::Float32Type *) { result = true; } VISIT(types::BoolType); void handle(const types::BoolType *, const types::BoolType *) { result = true; } VISIT(types::ByteType); @@ -272,6 +274,10 @@ public: void handle(const types::IntNType *x, const types::IntNType *y) { result = x->getLen() == y->getLen() && x->isSigned() == y->isSigned(); } + VISIT(types::VectorType); + void handle(const types::VectorType *x, const types::VectorType *y) { + result = x->getCount() == y->getCount() && process(x->getBase(), y->getBase()); + } VISIT(dsl::types::CustomType); void handle(const dsl::types::CustomType *x, const dsl::types::CustomType *y) { result = x->match(y); diff --git a/codon/sir/util/visitor.cpp b/codon/sir/util/visitor.cpp index 33c9caa6..f0273eb1 100644 --- a/codon/sir/util/visitor.cpp +++ b/codon/sir/util/visitor.cpp @@ -51,6 +51,7 @@ void Visitor::visit(types::Type *x) { defaultVisit(x); } void Visitor::visit(types::PrimitiveType *x) { defaultVisit(x); } void Visitor::visit(types::IntType *x) { defaultVisit(x); } void Visitor::visit(types::FloatType *x) { defaultVisit(x); } +void Visitor::visit(types::Float32Type *x) { defaultVisit(x); } void Visitor::visit(types::BoolType *x) { defaultVisit(x); } void 
Visitor::visit(types::ByteType *x) { defaultVisit(x); } void Visitor::visit(types::VoidType *x) { defaultVisit(x); } @@ -61,6 +62,7 @@ void Visitor::visit(types::OptionalType *x) { defaultVisit(x); } void Visitor::visit(types::PointerType *x) { defaultVisit(x); } void Visitor::visit(types::GeneratorType *x) { defaultVisit(x); } void Visitor::visit(types::IntNType *x) { defaultVisit(x); } +void Visitor::visit(types::VectorType *x) { defaultVisit(x); } void Visitor::visit(dsl::types::CustomType *x) { defaultVisit(x); } void ConstVisitor::visit(const Module *x) { defaultVisit(x); } @@ -108,6 +110,7 @@ void ConstVisitor::visit(const types::Type *x) { defaultVisit(x); } void ConstVisitor::visit(const types::PrimitiveType *x) { defaultVisit(x); } void ConstVisitor::visit(const types::IntType *x) { defaultVisit(x); } void ConstVisitor::visit(const types::FloatType *x) { defaultVisit(x); } +void ConstVisitor::visit(const types::Float32Type *x) { defaultVisit(x); } void ConstVisitor::visit(const types::BoolType *x) { defaultVisit(x); } void ConstVisitor::visit(const types::ByteType *x) { defaultVisit(x); } void ConstVisitor::visit(const types::VoidType *x) { defaultVisit(x); } @@ -118,6 +121,7 @@ void ConstVisitor::visit(const types::OptionalType *x) { defaultVisit(x); } void ConstVisitor::visit(const types::PointerType *x) { defaultVisit(x); } void ConstVisitor::visit(const types::GeneratorType *x) { defaultVisit(x); } void ConstVisitor::visit(const types::IntNType *x) { defaultVisit(x); } +void ConstVisitor::visit(const types::VectorType *x) { defaultVisit(x); } void ConstVisitor::visit(const dsl::types::CustomType *x) { defaultVisit(x); } } // namespace util diff --git a/codon/sir/util/visitor.h b/codon/sir/util/visitor.h index b27462b3..750c59b5 100644 --- a/codon/sir/util/visitor.h +++ b/codon/sir/util/visitor.h @@ -16,6 +16,7 @@ class Type; class PrimitiveType; class IntType; class FloatType; +class Float32Type; class BoolType; class ByteType; class VoidType; @@ -26,6 
+27,7 @@ class OptionalType; class PointerType; class GeneratorType; class IntNType; +class VectorType; } // namespace types namespace dsl { @@ -146,6 +148,7 @@ public: VISIT(types::PrimitiveType); VISIT(types::IntType); VISIT(types::FloatType); + VISIT(types::Float32Type); VISIT(types::BoolType); VISIT(types::ByteType); VISIT(types::VoidType); @@ -156,6 +159,7 @@ public: VISIT(types::PointerType); VISIT(types::GeneratorType); VISIT(types::IntNType); + VISIT(types::VectorType); VISIT(dsl::types::CustomType); }; @@ -220,6 +224,7 @@ public: CONST_VISIT(types::PrimitiveType); CONST_VISIT(types::IntType); CONST_VISIT(types::FloatType); + CONST_VISIT(types::Float32Type); CONST_VISIT(types::BoolType); CONST_VISIT(types::ByteType); CONST_VISIT(types::VoidType); @@ -230,6 +235,7 @@ public: CONST_VISIT(types::PointerType); CONST_VISIT(types::GeneratorType); CONST_VISIT(types::IntNType); + CONST_VISIT(types::VectorType); CONST_VISIT(dsl::types::CustomType); }; diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index cfdd28b5..1ac2ca22 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -31,6 +31,7 @@ ## Advanced * [Parallelism and multithreading](advanced/parallel.md) +* [GPU programming](advanced/gpu.md) * [Pipelines](advanced/pipelines.md) * [Intermediate representation](advanced/ir.md) * [Building from source](advanced/build.md) diff --git a/docs/advanced/gpu.md b/docs/advanced/gpu.md new file mode 100644 index 00000000..180d280c --- /dev/null +++ b/docs/advanced/gpu.md @@ -0,0 +1,249 @@ +Codon supports GPU programming through a native GPU backend. +Currently, only Nvidia devices are supported. 
+Here is a simple example: + +``` python +import gpu + +@gpu.kernel +def hello(a, b, c): + i = gpu.thread.x + c[i] = a[i] + b[i] + +a = [i for i in range(16)] +b = [2*i for i in range(16)] +c = [0 for _ in range(16)] + +hello(a, b, c, grid=1, block=16) +print(c) +``` + +which outputs: + +``` +[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45] +``` + +The same code can be written using Codon's `@par` syntax: + +``` python +a = [i for i in range(16)] +b = [2*i for i in range(16)] +c = [0 for _ in range(16)] + +@par(gpu=True) +for i in range(16): + c[i] = a[i] + b[i] + +print(c) +``` + +Below is a more comprehensive example for computing the [Mandelbrot +set](https://en.wikipedia.org/wiki/Mandelbrot_set), and plotting it +using NumPy/Matplotlib: + +``` python +from python import numpy as np +from python import matplotlib.pyplot as plt +import gpu + +MAX = 1000 # maximum Mandelbrot iterations +N = 4096 # width and height of image +pixels = [0 for _ in range(N * N)] + +def scale(x, a, b): + return a + (x/N)*(b - a) + +@gpu.kernel +def mandelbrot(pixels): + idx = (gpu.block.x * gpu.block.dim.x) + gpu.thread.x + i, j = divmod(idx, N) + c = complex(scale(j, -2.00, 0.47), scale(i, -1.12, 1.12)) + z = 0j + iteration = 0 + + while abs(z) <= 2 and iteration < MAX: + z = z**2 + c + iteration += 1 + + pixels[idx] = int(255 * iteration/MAX) + +mandelbrot(pixels, grid=(N*N)//1024, block=1024) +plt.imshow(np.array(pixels).reshape(N, N)) +plt.show() +``` + +The GPU version of the Mandelbrot code is about 450 times faster +than an equivalent CPU version. + +GPU kernels are marked with the `@gpu.kernel` annotation, and +compiled specially in Codon's backend. Kernel functions can +use the vast majority of features supported in Codon, with a +couple notable exceptions: + +- Exception handling is not supported inside the kernel, meaning + kernel code should not throw or catch exceptions. `raise` + statements inside the kernel are marked as unreachable and + optimized out. 
+
+- Functionality related to I/O is not supported (e.g. you can't
+  open a file in the kernel).
+
+- A few other modules and functions are not allowed, such as the
+  `re` module (which uses an external regex library) or the `os`
+  module.
+
+{% hint style="warning" %}
+The GPU module is under active development. APIs and semantics
+might change between Codon releases.
+{% endhint %}
+
+# Invoking the kernel
+
+The kernel can be invoked via a simple call with added `grid` and
+`block` parameters. These parameters define the grid and block
+dimensions, respectively. Recall that GPU execution involves a *grid*
+of (`X` x `Y` x `Z`) *blocks* where each block contains (`x` x `y` x `z`)
+executing threads. Device-specific restrictions on grid and block sizes
+apply.
+
+The `grid` and `block` parameters can be one of:
+
+- Single integer `x`, giving dimensions `(x, 1, 1)`
+- Tuple of two integers `(x, y)`, giving dimensions `(x, y, 1)`
+- Tuple of three integers `(x, y, z)`, giving dimensions `(x, y, z)`
+- Instance of `gpu.Dim3` as in `Dim3(x, y, z)`, specifying the three dimensions
+
+# GPU intrinsics
+
+Codon's GPU module provides many of the same intrinsics that CUDA does:
+
+| Codon | Description | CUDA equivalent |
+|-------------------|-----------------------------------------|-----------------|
+| `gpu.thread.x` | x-coordinate of current thread in block | `threadIdx.x` |
+| `gpu.block.x` | x-coordinate of current block in grid | `blockIdx.x` |
+| `gpu.block.dim.x` | x-dimension of block | `blockDim.x` |
+| `gpu.grid.dim.x` | x-dimension of grid | `gridDim.x` |
+
+The same applies for the `y` and `z` coordinates. The `*.dim` objects are instances
+of `gpu.Dim3`.
+ +# Math functions + +All the functions in the `math` module are supported in kernel functions, and +are automatically replaced with GPU-optimized versions: + +``` python +import math +import gpu + +@gpu.kernel +def hello(x): + i = gpu.thread.x + x[i] = math.sqrt(x[i]) # uses __nv_sqrt from libdevice + +x = [float(i) for i in range(10)] +hello(x, grid=1, block=10) +print(x) +``` + +gives: + +``` +[0, 1, 1.41421, 1.73205, 2, 2.23607, 2.44949, 2.64575, 2.82843, 3] +``` + +# Libdevice + +Codon uses [libdevice](https://docs.nvidia.com/cuda/libdevice-users-guide/index.html) +for GPU-optimized math functions. The default libdevice path is +`/usr/local/cuda/nvvm/libdevice/libdevice.10.bc`. An alternative path can be specified +via the `-libdevice` compiler flag. + +# Working with raw pointers + +By default, objects are converted entirely to their GPU counterparts, which have +the same data layout as the original objects (although the Codon compiler might perform +optimizations by swapping a CPU implementation of a data type with a GPU-optimized +implementation that exposes the same API). This preserves all of Codon/Python's +standard semantics within the kernel. + +It is possible to use a kernel with raw pointers via `gpu.raw`, which corresponds +to how the kernel would be written in C++/CUDA: + +``` python +import gpu + +@gpu.kernel +def hello(a, b, c): + i = gpu.thread.x + c[i] = a[i] + b[i] + +a = [i for i in range(16)] +b = [2*i for i in range(16)] +c = [0 for _ in range(16)] + +# call the kernel with three int-pointer arguments: +hello(gpu.raw(a), gpu.raw(b), gpu.raw(c), grid=1, block=16) +print(c) # output same as first snippet's +``` + +`gpu.raw` can avoid an extra pointer indirection, but outputs a Codon `Ptr` object, +meaning the corresponding kernel parameters will not have the full list API, instead +having the more limited `Ptr` API (which primarily just supports indexing/assignment). 
+ +# Object conversions + +A hidden API is used to copy objects to and from the GPU device. This API consists of +two new *magic methods*: + +- `__to_gpu__(self)`: Allocates the necessary GPU memory and copies the object `self` to + the device. + +- `__from_gpu__(self, gpu_object)`: Copies the GPU memory of `gpu_object` (which is + a value returned by `__to_gpu__`) back to the CPU object `self`. + +For primitive types like `int` and `float`, `__to_gpu__` simply returns `self` and +`__from_gpu__` does nothing. These methods are defined for all the built-in types *and* +are automatically generated for user-defined classes, so most objects can be transferred +back and forth from the GPU seamlessly. A user-defined class that makes use of raw pointers +or other low-level constructs will have to define these methods for GPU use. Please refer +to the `gpu` module for implementation examples. + +# `@par(gpu=True)` + +Codon's `@par` syntax can be used to seamlessly parallelize existing loops on the GPU, +without needing to explicitly write them as kernels. For loop nests, the `collapse` argument +can be used to cover the entire iteration space on the GPU. For example, here is the Mandelbrot +code above written using `@par`: + +``` python +MAX = 1000 # maximum Mandelbrot iterations +N = 4096 # width and height of image +pixels = [0 for _ in range(N * N)] + +def scale(x, a, b): + return a + (x/N)*(b - a) + +@par(gpu=True, collapse=2) +for i in range(N): + for j in range(N): + c = complex(scale(j, -2.00, 0.47), scale(i, -1.12, 1.12)) + z = 0j + iteration = 0 + + while abs(z) <= 2 and iteration < MAX: + z = z**2 + c + iteration += 1 + + pixels[i*N + j] = int(255 * iteration/MAX) +``` + +Note that the `gpu=True` option disallows shared variables (i.e. assigning out-of-loop +variables in the loop body) as well as reductions. The other GPU-specific restrictions +described here apply as well. 
+ +# Troubleshooting + +CUDA errors resulting in kernel abortion are printed, and typically arise from invalid +code in the kernel, either via using exceptions or using unsupported modules/objects. diff --git a/docs/advanced/parallel.md b/docs/advanced/parallel.md index 703086d2..36a08544 100644 --- a/docs/advanced/parallel.md +++ b/docs/advanced/parallel.md @@ -29,6 +29,8 @@ for i in range(10): - `chunk_size` (int): chunk size when partitioning loop iterations - `ordered` (bool): whether the loop iterations should be executed in the same order +- `collapse` (int): number of loop nests to collapse into a single + iteration space Other OpenMP parameters like `private`, `shared` or `reduction`, are inferred automatically by the compiler. For example, the following loop diff --git a/docs/intro/differences.md b/docs/intro/differences.md index 1d1b5de6..b363748c 100644 --- a/docs/intro/differences.md +++ b/docs/intro/differences.md @@ -30,6 +30,15 @@ compile to native code without any runtime performance overhead. Future versions of Codon will lift some of these restrictions by the introduction of e.g. implicit union types. +# Numerics + +For performance reasons, some numeric operations use C semantics +rather than Python semantics. This includes, for example, raising +an exception when dividing by zero, or other checks done by `math` +functions. Strict adherence to Python semantics can be achieved by +using the `-numerics=py` flag of the Codon compiler. Note that this +does *not* change `int`s from 64-bit. + # Modules While most of the commonly used builtin modules have Codon-native diff --git a/docs/intro/releases.md b/docs/intro/releases.md index 49c3ac24..9d10b032 100644 --- a/docs/intro/releases.md +++ b/docs/intro/releases.md @@ -2,6 +2,50 @@ Below you can find release notes for each major Codon release, listing improvements, updates, optimizations and more for each new version. +# v0.14 + +## GPU support + +GPU kernels can now be written and called in Codon. 
Existing +loops can be parallelized on the GPU with the `@par(gpu=True)` +annotation. Please see the [docs](../advanced/gpu.md) for +more information and examples. + +## Semantics + +Added `-numerics` flag, which specifies semantics of various +numeric operations: + +- `-numerics=c` (default): C semantics; best performance +- `-numerics=py`: Python semantics (checks for zero divisors + and raises `ZeroDivisionError`, and adds domain checks to `math` + functions); might slightly decrease performance. + +## Types + +Added `float32` type to represent 32-bit floats (equivalent to C's +`float`). All `math` functions now have `float32` overloads. + +## Parallelism + +Added `collapse` option to `@par`: + +``` python +@par(collapse=2) # parallelize entire iteration space of 2 loops +for i in range(N): + for j in range(N): + do_work(i, j) +``` + +## Standard library + +Added `collections.defaultdict`. + +## Python interoperability + +Various Python interoperability improvements: can now use `isinstance` +on Python objects/types and can now catch Python exceptions by name. + # v0.13 ## Language diff --git a/docs/language/extra.md b/docs/language/extra.md index 6e2b606f..4f6ac694 100644 --- a/docs/language/extra.md +++ b/docs/language/extra.md @@ -19,6 +19,12 @@ variants: - `i32`/`u32`: signed/unsigned 32-bit integer - `i64`/`u64`: signed/unsigned 64-bit integer +# 32-bit float + +Codon's `float` type is a 64-bit floating point value. Codon +also supports `float32` (or `f32` as a shorthand), representing +a 32-bit floating point value (like C's `float`). + # Pointers Codon has a `Ptr[T]` type that represents a pointer to an object diff --git a/stdlib/algorithms/heapsort.codon b/stdlib/algorithms/heapsort.codon index 31783427..35dc0639 100644 --- a/stdlib/algorithms/heapsort.codon +++ b/stdlib/algorithms/heapsort.codon @@ -69,6 +69,6 @@ def heap_sort(collection: List[T], keyf: Callable[[T], S], T: type, S: type) -> Heap Sort Returns a sorted list. 
""" - newlst = copy(collection) + newlst = collection.__copy__() heap_sort_inplace(newlst, keyf) return newlst diff --git a/stdlib/algorithms/insertionsort.codon b/stdlib/algorithms/insertionsort.codon index ea130325..1cf469ab 100644 --- a/stdlib/algorithms/insertionsort.codon +++ b/stdlib/algorithms/insertionsort.codon @@ -42,6 +42,6 @@ def insertion_sort( Insertion Sort Returns the sorted list. """ - newlst = copy(collection) + newlst = collection.__copy__() insertion_sort_inplace(newlst, keyf) return newlst diff --git a/stdlib/algorithms/pdqsort.codon b/stdlib/algorithms/pdqsort.codon index 8c8ba628..e3f72fa1 100644 --- a/stdlib/algorithms/pdqsort.codon +++ b/stdlib/algorithms/pdqsort.codon @@ -321,6 +321,6 @@ def pdq_sort(collection: List[T], keyf: Callable[[T], S], T: type, S: type) -> L Returns a sorted list. """ - newlst = copy(collection) + newlst = collection.__copy__() pdq_sort_inplace(newlst, keyf) return newlst diff --git a/stdlib/collections.codon b/stdlib/collections.codon index bf6a16fb..9634e2f5 100644 --- a/stdlib/collections.codon +++ b/stdlib/collections.codon @@ -114,13 +114,13 @@ class deque: return deque(i.__deepcopy__() for i in self) def __copy__(self) -> deque[T]: - return deque[T](copy(self._arr), self._head, self._tail, self._maxlen) + return deque[T](self._arr.__copy__(), self._head, self._tail, self._maxlen) def copy(self) -> deque[T]: return self.__copy__() def __repr__(self) -> str: - return repr(List[T](iter(self))) + return f"deque({repr(List[T](iter(self)))})" def _idx_check(self, idx: int, msg: str): if self._head == self._tail or idx >= len(self) or idx < 0: @@ -323,25 +323,28 @@ class Counter(Dict[T, int]): return result def __add__(self, other: Counter[T]) -> Counter[T]: - result = copy(self) + result = self.__copy__() result += other return result def __sub__(self, other: Counter[T]) -> Counter[T]: - result = copy(self) + result = self.__copy__() result -= other return result def __and__(self, other: Counter[T]) -> Counter[T]: 
- result = copy(self) + result = self.__copy__() result &= other return result def __or__(self, other: Counter[T]) -> Counter[T]: - result = copy(self) + result = self.__copy__() result |= other return result + def __repr__(self): + return f"Counter({super().__repr__()})" + def __dict_do_op_throws__(self, key: T, other: Z, op: F, F: type, Z: type): self.__dict_do_op__(key, other, 0, op) @@ -357,5 +360,79 @@ class Dict: self._init_from(other) +class defaultdict(Dict[K,V]): + default_factory: S + K: type + V: type + S: TypeVar[Callable[[], V]] + + def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V]): + super().__init__() + self.default_factory = lambda: VV() + + def __init__(self, f: S): + super().__init__() + self.default_factory = f + + def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V], other: Dict[K, V]): + super().__init__(other) + self.default_factory = lambda: VV() + + def __init__(self, f: S, other: Dict[K, V]): + super().__init__(other) + self.default_factory = f + + def __missing__(self, key: K): + default_value = self.default_factory() + self.__setitem__(key, default_value) + return default_value + + def __getitem__(self, key: K) -> V: + if key not in self: + return self.__missing__(key) + return super().__getitem__(key) + + def __dict_do_op_throws__(self, key: K, other: Z, op: F, F: type, Z: type): + x = self._kh_get(key) + if x == self._kh_end(): + self.__missing__(key) + x = self._kh_get(key) + self._vals[x] = op(self._vals[x], other) + + def copy(self): + d = defaultdict[K,V,S](self.default_factory) + d._init_from(self) + return d + + def __copy__(self): + return self.copy() + + def __deepcopy__(self): + d = defaultdict[K,V,S](self.default_factory) + for k,v in self.items(): + d[k.__deepcopy__()] = v.__deepcopy__() + return d + + def __eq__(self, other: defaultdict[K,V,S]) -> bool: + if self.__len__() != other.__len__(): + return False + for k, v in self.items(): + if k not in other or other[k] != v: + return 
False + return True + + def __ne__(self, other: defaultdict[K,V,S]) -> bool: + return not (self == other) + + def __repr__(self): + return f"defaultdict(, {super().__repr__()})" + + +@extend +class Dict: + def __init__(self: Dict[K, V], other: defaultdict[K, V, S], S: type): + self._init_from(other) + + def namedtuple(name: Static[str], args): # internal pass diff --git a/stdlib/copy.codon b/stdlib/copy.codon new file mode 100644 index 00000000..0e4da585 --- /dev/null +++ b/stdlib/copy.codon @@ -0,0 +1,14 @@ +# (c) 2022 Exaloop Inc. All rights reserved. + + +class Error(Exception): + def __init__(self, message: str = ""): + super().__init__("copy.Error", message) + + +def copy(x): + return x.__copy__() + + +def deepcopy(x): + return x.__deepcopy__() diff --git a/stdlib/gpu.codon b/stdlib/gpu.codon new file mode 100644 index 00000000..9810dcf5 --- /dev/null +++ b/stdlib/gpu.codon @@ -0,0 +1,732 @@ +# (c) 2022 Exaloop Inc. All rights reserved. + +from internal.gc import sizeof as _sizeof + +@tuple +class Device: + _device: i32 + + def __new__(device: int): + from C import seq_nvptx_device(int) -> i32 + return Device(seq_nvptx_device(device)) + + @staticmethod + def count(): + from C import seq_nvptx_device_count() -> int + return seq_nvptx_device_count() + + def __str__(self): + from C import seq_nvptx_device_name(i32) -> str + return seq_nvptx_device_name(self._device) + + def __index__(self): + return int(self._device) + + def __bool__(self): + return True + + @property + def compute_capability(self): + from C import seq_nvptx_device_capability(i32) -> int + c = seq_nvptx_device_capability(self._device) + return (c >> 32, c & 0xffffffff) + +@tuple +class Memory[T]: + _ptr: Ptr[byte] + + def _alloc(n: int, T: type): + from C import seq_nvptx_device_alloc(int) -> Ptr[byte] + return Memory[T](seq_nvptx_device_alloc(n * _sizeof(T))) + + def _read(self, p: Ptr[T], n: int): + from C import seq_nvptx_memcpy_d2h(Ptr[byte], Ptr[byte], int) + 
seq_nvptx_memcpy_d2h(p.as_byte(), self._ptr, n * _sizeof(T)) + + def _write(self, p: Ptr[T], n: int): + from C import seq_nvptx_memcpy_h2d(Ptr[byte], Ptr[byte], int) + seq_nvptx_memcpy_h2d(self._ptr, p.as_byte(), n * _sizeof(T)) + + def _free(self): + from C import seq_nvptx_device_free(Ptr[byte]) + seq_nvptx_device_free(self._ptr) + +@llvm +def syncthreads() -> None: + declare void @llvm.nvvm.barrier0() + call void @llvm.nvvm.barrier0() + ret {} {} + +@tuple +class Dim3: + _x: u32 + _y: u32 + _z: u32 + + def __new__(x: int, y: int, z: int): + return Dim3(u32(x), u32(y), u32(z)) + + @property + def x(self): + return int(self._x) + + @property + def y(self): + return int(self._y) + + @property + def z(self): + return int(self._z) + +@tuple +class Thread: + @property + def x(self): + @pure + @llvm + def get_x() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() + %res = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + ret i32 %res + + return int(get_x()) + + @property + def y(self): + @pure + @llvm + def get_y() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.tid.y() + %res = call i32 @llvm.nvvm.read.ptx.sreg.tid.y() + ret i32 %res + + return int(get_y()) + + @property + def z(self): + @pure + @llvm + def get_z() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.tid.z() + %res = call i32 @llvm.nvvm.read.ptx.sreg.tid.z() + ret i32 %res + + return int(get_z()) + +@tuple +class Block: + @property + def x(self): + @pure + @llvm + def get_x() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() + %res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() + ret i32 %res + + return int(get_x()) + + @property + def y(self): + @pure + @llvm + def get_y() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() + %res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() + ret i32 %res + + return int(get_y()) + + @property + def z(self): + @pure + @llvm + def get_z() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.z() + %res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z() + ret i32 %res + + 
return int(get_z()) + + @property + def dim(self): + @pure + @llvm + def get_x() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + %res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + ret i32 %res + + @pure + @llvm + def get_y() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.ntid.y() + %res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y() + ret i32 %res + + @pure + @llvm + def get_z() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.ntid.z() + %res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z() + ret i32 %res + + return Dim3(get_x(), get_y(), get_z()) + +@tuple +class Grid: + @property + def dim(self): + @pure + @llvm + def get_x() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.x() + %res = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x() + ret i32 %res + + @pure + @llvm + def get_y() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.y() + %res = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y() + ret i32 %res + + @pure + @llvm + def get_z() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.z() + %res = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z() + ret i32 %res + + return Dim3(get_x(), get_y(), get_z()) + +@tuple +class Warp: + def __len__(self): + @pure + @llvm + def get_warpsize() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.warpsize() + %res = call i32 @llvm.nvvm.read.ptx.sreg.warpsize() + ret i32 %res + + return int(get_warpsize()) + +thread = Thread() +block = Block() +grid = Grid() +warp = Warp() + +def _catch(): + return (thread, block, grid, warp) + +_catch() + +@tuple +class AllocCache: + v: List[Ptr[byte]] + + def add(self, p: Ptr[byte]): + self.v.append(p) + + def free(self): + for p in self.v: + Memory[byte](p)._free() + +def _tuple_from_gpu(args, gpu_args): + if staticlen(args) > 0: + a = args[0] + g = gpu_args[0] + a.__from_gpu__(g) + _tuple_from_gpu(args[1:], gpu_args[1:]) + +def kernel(fn): + from C import seq_nvptx_function(str) -> cobj + from C import seq_nvptx_invoke(cobj, u32, u32, u32, u32, u32, u32, u32, cobj) + + def 
canonical_dim(dim): + if isinstance(dim, NoneType): + return (1, 1, 1) + elif isinstance(dim, int): + return (dim, 1, 1) + elif isinstance(dim, Tuple[int,int]): + return (dim[0], dim[1], 1) + elif isinstance(dim, Tuple[int,int,int]): + return dim + elif isinstance(dim, Dim3): + return (dim.x, dim.y, dim.z) + else: + compile_error("bad dimension argument") + + def offsets(t): + @pure + @llvm + def offsetof(t: T, i: Static[int], T: type, S: type) -> int: + %p = getelementptr {=T}, {=T}* null, i64 0, i32 {=i} + %s = ptrtoint {=S}* %p to i64 + ret i64 %s + + if staticlen(t) == 0: + return () + else: + T = type(t) + S = type(t[-1]) + return (*offsets(t[:-1]), offsetof(t, staticlen(t) - 1, T, S)) + + def wrapper(*args, grid, block): + grid = canonical_dim(grid) + block = canonical_dim(block) + cache = AllocCache([]) + shared_mem = 0 + gpu_args = tuple(arg.__to_gpu__(cache) for arg in args) + kernel_ptr = seq_nvptx_function(__realized__(fn, gpu_args).__llvm_name__) + p = __ptr__(gpu_args).as_byte() + arg_ptrs = tuple((p + offset) for offset in offsets(gpu_args)) + seq_nvptx_invoke(kernel_ptr, u32(grid[0]), u32(grid[1]), u32(grid[2]), u32(block[0]), + u32(block[1]), u32(block[2]), u32(shared_mem), __ptr__(arg_ptrs).as_byte()) + _tuple_from_gpu(args, gpu_args) + cache.free() + + return wrapper + +def _ptr_to_gpu(p: Ptr[T], n: int, cache: AllocCache, index_filter = lambda i: True, T: type): + from internal.gc import atomic + + if not atomic(T): + tmp = Ptr[T](n) + for i in range(n): + if index_filter(i): + tmp[i] = p[i].__to_gpu__(cache) + p = tmp + + mem = Memory._alloc(n, T) + cache.add(mem._ptr) + mem._write(p, n) + return Ptr[T](mem._ptr) + +def _ptr_from_gpu(p: Ptr[T], q: Ptr[T], n: int, index_filter = lambda i: True, T: type): + from internal.gc import atomic + + mem = Memory[T](q.as_byte()) + if not atomic(T): + tmp = Ptr[T](n) + mem._read(tmp, n) + for i in range(n): + if index_filter(i): + p[i] = T.__from_gpu_new__(tmp[i]) + else: + mem._read(p, n) + +@pure +@llvm 
+def _ptr_to_type(p: cobj, T: type) -> T: + ret i8* %p + +def _object_to_gpu(obj: T, cache: AllocCache, T: type): + s = tuple(obj) + gpu_mem = Memory._alloc(1, type(s)) + cache.add(gpu_mem._ptr) + gpu_mem._write(__ptr__(s), 1) + return _ptr_to_type(gpu_mem._ptr, T) + +def _object_from_gpu(obj): + T = type(obj) + S = type(tuple(obj)) + + tmp = T.__new__() + p = Ptr[S](tmp.__raw__()) + q = Ptr[S](obj.__raw__()) + + mem = Memory[S](q.as_byte()) + mem._read(p, 1) + return tmp + +@tuple +class Pointer[T]: + _ptr: Ptr[T] + _len: int + + def __to_gpu__(self, cache: AllocCache): + return _ptr_to_gpu(self._ptr, self._len, cache) + + def __from_gpu__(self, other: Ptr[T]): + _ptr_from_gpu(self._ptr, other, self._len) + + def __from_gpu_new__(other: Ptr[T]): + return other + +def raw(v: List[T], T: type): + return Pointer(v.arr.ptr, len(v)) + +@extend +class Ptr: + def __to_gpu__(self, cache: AllocCache): + return self + + def __from_gpu__(self, other: Ptr[T]): + pass + + def __from_gpu_new__(other: Ptr[T]): + return other + +@extend +class NoneType: + def __to_gpu__(self, cache: AllocCache): + return self + + def __from_gpu__(self, other: NoneType): + pass + + def __from_gpu_new__(other: NoneType): + return other + +@extend +class int: + def __to_gpu__(self, cache: AllocCache): + return self + + def __from_gpu__(self, other: int): + pass + + def __from_gpu_new__(other: int): + return other + +@extend +class float: + def __to_gpu__(self, cache: AllocCache): + return self + + def __from_gpu__(self, other: float): + pass + + def __from_gpu_new__(other: float): + return other + +@extend +class float32: + def __to_gpu__(self, cache: AllocCache): + return self + + def __from_gpu__(self, other: float32): + pass + + def __from_gpu_new__(other: float32): + return other + +@extend +class bool: + def __to_gpu__(self, cache: AllocCache): + return self + + def __from_gpu__(self, other: bool): + pass + + def __from_gpu_new__(other: bool): + return other + +@extend +class byte: + def 
__to_gpu__(self, cache: AllocCache): + return self + + def __from_gpu__(self, other: byte): + pass + + def __from_gpu_new__(other: byte): + return other + +@extend +class Int: + def __to_gpu__(self, cache: AllocCache): + return self + + def __from_gpu__(self, other: Int[N]): + pass + + def __from_gpu_new__(other: Int[N]): + return other + +@extend +class UInt: + def __to_gpu__(self, cache: AllocCache): + return self + + def __from_gpu__(self, other: UInt[N]): + pass + + def __from_gpu_new__(other: UInt[N]): + return other + +@extend +class str: + def __to_gpu__(self, cache: AllocCache): + n = self.len + return str(_ptr_to_gpu(self.ptr, n, cache), n) + + def __from_gpu__(self, other: str): + pass + + def __from_gpu_new__(other: str): + n = other.len + p = Ptr[byte](n) + _ptr_from_gpu(p, other.ptr, n) + return str(p, n) + +@extend +class List: + @inline + def __to_gpu__(self, cache: AllocCache): + mem = List[T].__new__() + n = self.len + gpu_ptr = _ptr_to_gpu(self.arr.ptr, n, cache) + mem.arr = Array[T](gpu_ptr, n) + mem.len = n + return _object_to_gpu(mem, cache) + + @inline + def __from_gpu__(self, other: List[T]): + mem = _object_from_gpu(other) + my_cap = self.arr.len + other_cap = mem.arr.len + + if other_cap > my_cap: + self._resize(other_cap) + + _ptr_from_gpu(self.arr.ptr, mem.arr.ptr, mem.len) + self.len = mem.len + + @inline + def __from_gpu_new__(other: List[T]): + mem = _object_from_gpu(other) + arr = Array[T](mem.arr.len) + _ptr_from_gpu(arr.ptr, mem.arr.ptr, arr.len) + mem.arr = arr + return mem + +@extend +class Dict: + def __to_gpu__(self, cache: AllocCache): + from internal.khash import __ac_fsize + mem = Dict[K,V].__new__() + n = self._n_buckets + f = __ac_fsize(n) if n else 0 + + mem._n_buckets = n + mem._size = self._size + mem._n_occupied = self._n_occupied + mem._upper_bound = self._upper_bound + mem._flags = _ptr_to_gpu(self._flags, f, cache) + mem._keys = _ptr_to_gpu(self._keys, n, cache, lambda i: self._kh_exist(i)) + mem._vals = 
_ptr_to_gpu(self._vals, n, cache, lambda i: self._kh_exist(i)) + + return _object_to_gpu(mem, cache) + + def __from_gpu__(self, other: Dict[K,V]): + from internal.khash import __ac_fsize + mem = _object_from_gpu(other) + my_n = self._n_buckets + n = mem._n_buckets + f = __ac_fsize(n) if n else 0 + + if my_n != n: + self._flags = Ptr[u32](f) + self._keys = Ptr[K](n) + self._vals = Ptr[V](n) + + _ptr_from_gpu(self._flags, mem._flags, f) + _ptr_from_gpu(self._keys, mem._keys, n, lambda i: self._kh_exist(i)) + _ptr_from_gpu(self._vals, mem._vals, n, lambda i: self._kh_exist(i)) + + self._n_buckets = n + self._size = mem._size + self._n_occupied = mem._n_occupied + self._upper_bound = mem._upper_bound + + def __from_gpu_new__(other: Dict[K,V]): + from internal.khash import __ac_fsize + mem = _object_from_gpu(other) + + n = mem._n_buckets + f = __ac_fsize(n) if n else 0 + flags = Ptr[u32](f) + keys = Ptr[K](n) + vals = Ptr[V](n) + + _ptr_from_gpu(flags, mem._flags, f) + mem._flags = flags + _ptr_from_gpu(keys, mem._keys, n, lambda i: mem._kh_exist(i)) + mem._keys = keys + _ptr_from_gpu(vals, mem._vals, n, lambda i: mem._kh_exist(i)) + mem._vals = vals + return mem + +@extend +class Set: + def __to_gpu__(self, cache: AllocCache): + from internal.khash import __ac_fsize + mem = Set[K].__new__() + n = self._n_buckets + f = __ac_fsize(n) if n else 0 + + mem._n_buckets = n + mem._size = self._size + mem._n_occupied = self._n_occupied + mem._upper_bound = self._upper_bound + mem._flags = _ptr_to_gpu(self._flags, f, cache) + mem._keys = _ptr_to_gpu(self._keys, n, cache, lambda i: self._kh_exist(i)) + + return _object_to_gpu(mem, cache) + + def __from_gpu__(self, other: Set[K]): + from internal.khash import __ac_fsize + mem = _object_from_gpu(other) + + my_n = self._n_buckets + n = mem._n_buckets + f = __ac_fsize(n) if n else 0 + + if my_n != n: + self._flags = Ptr[u32](f) + self._keys = Ptr[K](n) + + _ptr_from_gpu(self._flags, mem._flags, f) + _ptr_from_gpu(self._keys, 
mem._keys, n, lambda i: self._kh_exist(i)) + + self._n_buckets = n + self._size = mem._size + self._n_occupied = mem._n_occupied + self._upper_bound = mem._upper_bound + + def __from_gpu_new__(other: Set[K]): + from internal.khash import __ac_fsize + mem = _object_from_gpu(other) + + n = mem._n_buckets + f = __ac_fsize(n) if n else 0 + flags = Ptr[u32](f) + keys = Ptr[K](n) + + _ptr_from_gpu(flags, mem._flags, f) + mem._flags = flags + _ptr_from_gpu(keys, mem._keys, n, lambda i: mem._kh_exist(i)) + mem._keys = keys + return mem + +@extend +class Optional: + def __to_gpu__(self, cache: AllocCache): + if self is None: + return self + else: + return Optional[T](self.__val__().__to_gpu__(cache)) + + def __from_gpu__(self, other: Optional[T]): + if self is not None and other is not None: + self.__val__().__from_gpu__(other.__val__()) + + def __from_gpu_new__(other: Optional[T]): + if other is None: + return Optional[T]() + else: + return Optional[T](T.__from_gpu_new__(other.__val__())) + +@extend +class __internal__: + def class_to_gpu(obj, cache: AllocCache): + if isinstance(obj, Tuple): + return tuple(a.__to_gpu__(cache) for a in obj) + elif isinstance(obj, ByVal): + T = type(obj) + return T(*tuple(a.__to_gpu__(cache) for a in tuple(obj))) + else: + T = type(obj) + S = type(tuple(obj)) + mem = T.__new__() + Ptr[S](mem.__raw__())[0] = tuple(obj).__to_gpu__(cache) + return _object_to_gpu(mem, cache) + + def class_from_gpu(obj, other): + if isinstance(obj, Tuple): + _tuple_from_gpu(obj, other) + elif isinstance(obj, ByVal): + _tuple_from_gpu(tuple(obj), tuple(other)) + else: + S = type(tuple(obj)) + Ptr[S](obj.__raw__())[0] = S.__from_gpu_new__(tuple(_object_from_gpu(other))) + + def class_from_gpu_new(other): + if isinstance(other, Tuple): + return tuple(type(a).__from_gpu_new__(a) for a in other) + elif isinstance(other, ByVal): + T = type(other) + return T(*tuple(type(a).__from_gpu_new__(a) for a in tuple(other))) + else: + S = type(tuple(other)) + mem = 
_object_from_gpu(other) + Ptr[S](mem.__raw__())[0] = S.__from_gpu_new__(tuple(mem)) + return mem + +# @par(gpu=True) support + +@pure +@llvm +def _gpu_thread_x() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() + %res = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + ret i32 %res + +@pure +@llvm +def _gpu_block_x() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() + %res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() + ret i32 %res + +@pure +@llvm +def _gpu_block_dim_x() -> u32: + declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + %res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() + ret i32 %res + +def _gpu_loop_outline_template(start, stop, args, instance: Static[int]): + @nonpure + def _loop_step(): + return 1 + + @kernel + def _kernel_stub(start: int, count: int, args): + @nonpure + def _gpu_loop_body_stub(idx, args): + pass + + @nonpure + def _dummy_use(n): + pass + + _dummy_use(instance) + idx = (int(_gpu_block_dim_x()) * int(_gpu_block_x())) + int(_gpu_thread_x()) + step = _loop_step() + if idx < count: + _gpu_loop_body_stub(start + (idx * step), args) + + step = _loop_step() + loop = range(start, stop, step) + + MAX_BLOCK = 1024 + MAX_GRID = 2147483647 + G = MAX_BLOCK * MAX_GRID + n = len(loop) + + if n == 0: + return + elif n > G: + raise ValueError(f'loop exceeds GPU iteration limit of {G}') + + block = n + grid = 1 + if n > MAX_BLOCK: + block = MAX_BLOCK + grid = (n // MAX_BLOCK) + (0 if n % MAX_BLOCK == 0 else 1) + + _kernel_stub(start, n, args, grid=grid, block=block) diff --git a/stdlib/internal/__init__.codon b/stdlib/internal/__init__.codon index 6f1f7c6a..09ff1362 100644 --- a/stdlib/internal/__init__.codon +++ b/stdlib/internal/__init__.codon @@ -1,6 +1,7 @@ +# (c) 2022 Exaloop Inc. All rights reserved. 
+ # Core library -from internal.core import * from internal.attributes import * from internal.types.ptr import * from internal.types.str import * @@ -34,7 +35,11 @@ from internal.str import * from internal.sort import sorted from openmp import Ident as __OMPIdent, for_par +from gpu import _gpu_loop_outline_template from internal.file import File, gzFile, open, gzopen from pickle import pickle, unpickle from internal.dlopen import dlsym as _dlsym import internal.python + +if __py_numerics__: + import internal.pynumerics diff --git a/stdlib/internal/builtin.codon b/stdlib/internal/builtin.codon index befea33a..8b96368f 100644 --- a/stdlib/internal/builtin.codon +++ b/stdlib/internal/builtin.codon @@ -108,13 +108,6 @@ def iter(x): return x.__iter__() -def copy(x): - """ - Return a copy of x - """ - return x.__copy__() - - def abs(x): """ Return the absolute value of x diff --git a/stdlib/internal/c_stubs.codon b/stdlib/internal/c_stubs.codon index bf0b9e7b..7bdd5716 100644 --- a/stdlib/internal/c_stubs.codon +++ b/stdlib/internal/c_stubs.codon @@ -1,6 +1,6 @@ # (c) 2022 Exaloop Inc. All rights reserved. 
-# Seq runtime functions +# runtime functions from C import seq_print(str) from C import seq_print_full(str, cobj) @@ -216,7 +216,7 @@ def expm1(a: float) -> float: @pure @C -def ldexp(a: float, b: int) -> float: +def ldexp(a: float, b: i32) -> float: pass @@ -406,6 +406,234 @@ def modf(a: float, b: Ptr[float]) -> float: pass +@pure +@C +def ceilf(a: float32) -> float32: + pass + + +@pure +@C +def floorf(a: float32) -> float32: + pass + + +@pure +@C +def fabsf(a: float32) -> float32: + pass + + +@pure +@C +def fmodf(a: float32, b: float32) -> float32: + pass + + +@pure +@C +def expf(a: float32) -> float32: + pass + + +@pure +@C +def expm1f(a: float32) -> float32: + pass + + +@pure +@C +def ldexpf(a: float32, b: i32) -> float32: + pass + + +@pure +@C +def logf(a: float32) -> float32: + pass + + +@pure +@C +def log2f(a: float32) -> float32: + pass + + +@pure +@C +def log10f(a: float32) -> float32: + pass + + +@pure +@C +def sqrtf(a: float32) -> float32: + pass + + +@pure +@C +def powf(a: float32, b: float32) -> float32: + pass + + +@pure +@C +def roundf(a: float32) -> float32: + pass + + +@pure +@C +def acosf(a: float32) -> float32: + pass + + +@pure +@C +def asinf(a: float32) -> float32: + pass + + +@pure +@C +def atanf(a: float32) -> float32: + pass + + +@pure +@C +def atan2f(a: float32, b: float32) -> float32: + pass + + +@pure +@C +def cosf(a: float32) -> float32: + pass + + +@pure +@C +def sinf(a: float32) -> float32: + pass + + +@pure +@C +def tanf(a: float32) -> float32: + pass + + +@pure +@C +def coshf(a: float32) -> float32: + pass + + +@pure +@C +def sinhf(a: float32) -> float32: + pass + + +@pure +@C +def tanhf(a: float32) -> float32: + pass + + +@pure +@C +def acoshf(a: float32) -> float32: + pass + + +@pure +@C +def asinhf(a: float32) -> float32: + pass + + +@pure +@C +def atanhf(a: float32) -> float32: + pass + + +@pure +@C +def copysignf(a: float32, b: float32) -> float32: + pass + + +@pure +@C +def log1pf(a: float32) -> float32: + pass + + +@pure +@C 
+def truncf(a: float32) -> float32: + pass + + +@pure +@C +def log2f(a: float32) -> float32: + pass + + +@pure +@C +def erff(a: float32) -> float32: + pass + + +@pure +@C +def erfcf(a: float32) -> float32: + pass + + +@pure +@C +def tgammaf(a: float32) -> float32: + pass + + +@pure +@C +def lgammaf(a: float32) -> float32: + pass + + +@pure +@C +def remainderf(a: float32, b: float32) -> float32: + pass + + +@pure +@C +def hypotf(a: float32, b: float32) -> float32: + pass + + +@nocapture +@C +def frexpf(a: float32, b: Ptr[Int[32]]) -> float32: + pass + + +@nocapture +@C +def modff(a: float32, b: Ptr[float32]) -> float32: + pass + + # @pure @C diff --git a/stdlib/internal/core.codon b/stdlib/internal/core.codon index f45828e0..fbdfc287 100644 --- a/stdlib/internal/core.codon +++ b/stdlib/internal/core.codon @@ -1,3 +1,5 @@ +# (c) 2022 Exaloop Inc. All rights reserved. + @__internal__ class __internal__: pass @@ -15,13 +17,23 @@ class byte: @tuple @__internal__ class int: + MAX = 9223372036854775807 pass @tuple @__internal__ class float: + MIN_10_EXP = -307 pass + +@tuple +@__internal__ +class float32: + MIN_10_EXP = -37 + pass + + @tuple @__internal__ class NoneType: @@ -93,7 +105,7 @@ function = Function # dummy @__internal__ -class TypeVar: pass +class TypeVar[T]: pass @__internal__ class ByVal: pass @__internal__ @@ -154,3 +166,13 @@ def super(): def superf(*args): """Special handling""" pass + +def __realized__(fn, args): + pass + +def statictuple(*args): + return args + +def __static_print__(*args): + pass + diff --git a/stdlib/internal/file.codon b/stdlib/internal/file.codon index 4ae18d0f..9d5e1ca6 100644 --- a/stdlib/internal/file.codon +++ b/stdlib/internal/file.codon @@ -145,8 +145,9 @@ class gzFile: if self.buf[offset - 1] == byte(10): # '\n' break + oldsz = self.sz self.sz *= 2 - self.buf = realloc(self.buf, self.sz) + self.buf = realloc(self.buf, self.sz, oldsz) return offset diff --git a/stdlib/internal/gc.codon b/stdlib/internal/gc.codon index 
f8b3fddc..51d260fd 100644 --- a/stdlib/internal/gc.codon +++ b/stdlib/internal/gc.codon @@ -18,7 +18,7 @@ def seq_alloc_atomic(a: int) -> cobj: @nocapture @derives @C -def seq_realloc(p: cobj, a: int) -> cobj: +def seq_realloc(p: cobj, newsize: int, oldsize: int) -> cobj: pass @@ -76,8 +76,8 @@ def alloc_atomic(sz: int): return seq_alloc_atomic(sz) -def realloc(p: cobj, sz: int): - return seq_realloc(p, sz) +def realloc(p: cobj, newsz: int, oldsz: int): + return seq_realloc(p, newsz, oldsz) def free(p: cobj): diff --git a/stdlib/internal/pynumerics.codon b/stdlib/internal/pynumerics.codon new file mode 100644 index 00000000..813a31ed --- /dev/null +++ b/stdlib/internal/pynumerics.codon @@ -0,0 +1,168 @@ +# (c) 2022 Exaloop Inc. All rights reserved. + +@pure +@llvm +def _floordiv_int_float(self: int, other: float) -> float: + declare double @llvm.floor.f64(double) + %0 = sitofp i64 %self to double + %1 = fdiv double %0, %other + %2 = call double @llvm.floor.f64(double %1) + ret double %2 + +@pure +@llvm +def _floordiv_int_int(self: int, other: int) -> int: + %0 = sdiv i64 %self, %other + ret i64 %0 + +@pure +@llvm +def _truediv_int_float(self: int, other: float) -> float: + %0 = sitofp i64 %self to double + %1 = fdiv double %0, %other + ret double %1 + +@pure +@llvm +def _truediv_int_int(self: int, other: int) -> float: + %0 = sitofp i64 %self to double + %1 = sitofp i64 %other to double + %2 = fdiv double %0, %1 + ret double %2 + +@pure +@llvm +def _mod_int_float(self: int, other: float) -> float: + %0 = sitofp i64 %self to double + %1 = frem double %0, %other + ret double %1 + +@pure +@llvm +def _mod_int_int(self: int, other: int) -> int: + %0 = srem i64 %self, %other + ret i64 %0 + +@pure +@llvm +def _truediv_float_float(self: float, other: float) -> float: + %0 = fdiv double %self, %other + ret double %0 + +@pure +@llvm +def _mod_float_float(self: float, other: float) -> float: + %0 = frem double %self, %other + ret double %0 + +def _divmod_int_int(self: int, 
other: int): + d = _floordiv_int_int(self, other) + m = self - d * other + if m and ((other ^ m) < 0): + m += other + d -= 1 + return (d, m) + +def _divmod_float_float(self: float, other: float): + mod = _mod_float_float(self, other) + div = _truediv_float_float(self - mod, other) + if mod: + if (other < 0) != (mod < 0): + mod += other + div -= 1.0 + else: + mod = (0.0).copysign(other) + + floordiv = 0.0 + if div: + floordiv = div.__floor__() + if div - floordiv > 0.5: + floordiv += 1.0 + else: + floordiv = (0.0).copysign(self / other) + + return (floordiv, mod) + +@extend +class int: + def __floordiv__(self, other: float): + if other == 0.0: + raise ZeroDivisionError("float floor division by zero") + return _divmod_float_float(float(self), other)[0] + + def __floordiv__(self, other: int): + if other == 0: + raise ZeroDivisionError("integer division or modulo by zero") + return _divmod_int_int(self, other)[0] + + def __truediv__(self, other: float): + if other == 0.0: + raise ZeroDivisionError("float division by zero") + return _truediv_int_float(self, other) + + def __truediv__(self, other: int): + if other == 0: + raise ZeroDivisionError("division by zero") + return _truediv_int_int(self, other) + + def __mod__(self, other: float): + if other == 0.0: + raise ZeroDivisionError("float modulo") + return _divmod_float_float(self, other)[1] + + def __mod__(self, other: int): + if other == 0: + raise ZeroDivisionError("integer division or modulo by zero") + return _divmod_int_int(self, other)[1] + + def __divmod__(self, other: float): + if other == 0.0: + raise ZeroDivisionError("float divmod()") + return _divmod_float_float(float(self), other) + + def __divmod__(self, other: int): + if other == 0: + raise ZeroDivisionError("integer division or modulo by zero") + return _divmod_int_int(self, other) + +@extend +class float: + def __floordiv__(self, other: float): + if other == 0.0: + raise ZeroDivisionError("float floor division by zero") + return 
_divmod_float_float(self, other)[0] + + def __floordiv__(self, other: int): + if other == 0: + raise ZeroDivisionError("float floor division by zero") + return _divmod_float_float(self, float(other))[0] + + def __truediv__(self, other: float): + if other == 0.0: + raise ZeroDivisionError("float division by zero") + return _truediv_float_float(self, other) + + def __truediv__(self, other: int): + if other == 0: + raise ZeroDivisionError("float division by zero") + return _truediv_float_float(self, float(other)) + + def __mod__(self, other: float): + if other == 0.0: + raise ZeroDivisionError("float modulo") + return _divmod_float_float(self, other)[1] + + def __mod__(self, other: int): + if other == 0: + raise ZeroDivisionError("float modulo") + return _divmod_float_float(self, float(other))[1] + + def __divmod__(self, other: float): + if other == 0.0: + raise ZeroDivisionError("float divmod()") + return _divmod_float_float(self, other) + + def __divmod__(self, other: int): + if other == 0: + raise ZeroDivisionError("float divmod()") + return _divmod_float_float(self, float(other)) diff --git a/stdlib/internal/python.codon b/stdlib/internal/python.codon index 26acae30..b23cd151 100644 --- a/stdlib/internal/python.codon +++ b/stdlib/internal/python.codon @@ -12,8 +12,10 @@ PyImport_AddModule = Function[[cobj], cobj](cobj()) PyImport_AddModuleObject = Function[[cobj], cobj](cobj()) PyImport_ImportModule = Function[[cobj], cobj](cobj()) PyErr_Fetch = Function[[Ptr[cobj], Ptr[cobj], Ptr[cobj]], NoneType](cobj()) +PyErr_NormalizeException = Function[[Ptr[cobj], Ptr[cobj], Ptr[cobj]], NoneType](cobj()) PyRun_SimpleString = Function[[cobj], NoneType](cobj()) PyEval_GetGlobals = Function[[], cobj](cobj()) +PyEval_GetBuiltins = Function[[], cobj](cobj()) # conversions PyLong_AsLong = Function[[cobj], int](cobj()) @@ -97,6 +99,7 @@ PyObject_GetItem = Function[[cobj, cobj], cobj](cobj()) PyObject_SetItem = Function[[cobj, cobj, cobj], int](cobj()) PyObject_DelItem = 
Function[[cobj, cobj], int](cobj()) PyObject_RichCompare = Function[[cobj, cobj, i32], cobj](cobj()) +PyObject_IsInstance = Function[[cobj, cobj], i32](cobj()) # constants Py_None = cobj() @@ -154,8 +157,10 @@ def init_dl_handles(py_handle: cobj): global PyImport_AddModuleObject global PyImport_ImportModule global PyErr_Fetch + global PyErr_NormalizeException global PyRun_SimpleString global PyEval_GetGlobals + global PyEval_GetBuiltins global PyLong_AsLong global PyLong_FromLong global PyFloat_AsDouble @@ -233,6 +238,7 @@ def init_dl_handles(py_handle: cobj): global PyObject_SetItem global PyObject_DelItem global PyObject_RichCompare + global PyObject_IsInstance global Py_None global Py_True global Py_False @@ -244,8 +250,10 @@ def init_dl_handles(py_handle: cobj): PyImport_AddModuleObject = dlsym(py_handle, "PyImport_AddModuleObject") PyImport_ImportModule = dlsym(py_handle, "PyImport_ImportModule") PyErr_Fetch = dlsym(py_handle, "PyErr_Fetch") + PyErr_NormalizeException = dlsym(py_handle, "PyErr_NormalizeException") PyRun_SimpleString = dlsym(py_handle, "PyRun_SimpleString") PyEval_GetGlobals = dlsym(py_handle, "PyEval_GetGlobals") + PyEval_GetBuiltins = dlsym(py_handle, "PyEval_GetBuiltins") PyLong_AsLong = dlsym(py_handle, "PyLong_AsLong") PyLong_FromLong = dlsym(py_handle, "PyLong_FromLong") PyFloat_AsDouble = dlsym(py_handle, "PyFloat_AsDouble") @@ -323,6 +331,7 @@ def init_dl_handles(py_handle: cobj): PyObject_SetItem = dlsym(py_handle, "PyObject_SetItem") PyObject_DelItem = dlsym(py_handle, "PyObject_DelItem") PyObject_RichCompare = dlsym(py_handle, "PyObject_RichCompare") + PyObject_IsInstance = dlsym(py_handle, "PyObject_IsInstance") Py_None = dlsym(py_handle, "_Py_NoneStruct") Py_True = dlsym(py_handle, "_Py_TrueStruct") Py_False = dlsym(py_handle, "_Py_FalseStruct") @@ -603,19 +612,17 @@ class pyobj: def exc_check(): ptype, pvalue, ptraceback = cobj(), cobj(), cobj() PyErr_Fetch(__ptr__(ptype), __ptr__(pvalue), __ptr__(ptraceback)) + 
PyErr_NormalizeException(__ptr__(ptype), __ptr__(pvalue), __ptr__(ptraceback)) if ptype != cobj(): py_msg = PyObject_Str(pvalue) if pvalue != cobj() else pvalue msg = pyobj.to_str(py_msg, "ignore", "") - py_typ = PyObject_GetAttrString(ptype, "__name__".c_str()) - typ = pyobj.to_str(py_typ, "ignore") pyobj.decref(ptype) - pyobj.decref(pvalue) pyobj.decref(ptraceback) pyobj.decref(py_msg) - pyobj.decref(py_typ) - raise PyError(msg, typ) + # pyobj.decref(pvalue) + raise PyError(msg, pyobj(pvalue)) def exc_wrap(_retval: T, T: type) -> T: pyobj.exc_check() @@ -667,7 +674,14 @@ class pyobj: PyRun_SimpleString(code.c_str()) def _globals() -> pyobj: - return pyobj(PyEval_GetGlobals()) + p = PyEval_GetGlobals() + if p == cobj(): + Py_IncRef(Py_None) + return pyobj(Py_None) + return pyobj(p) + + def _builtins() -> pyobj: + return pyobj(PyEval_GetBuiltins()) def _get_module(name: str) -> pyobj: p = pyobj(pyobj.exc_wrap(PyImport_AddModule(name.c_str()))) @@ -686,6 +700,17 @@ class pyobj: return bool(pyobj.exc_wrap(PyObject_IsTrue(self.p) == 1)) +def _get_identifier(typ: str) -> pyobj: + t = pyobj._builtins()[typ] + if t.p == cobj(): + t = pyobj._main_module()[typ] + return t + + +def _isinstance(what: pyobj, typ: pyobj) -> bool: + return bool(pyobj.exc_wrap(PyObject_IsInstance(what.p, typ.p))) + + # Type conversions @extend diff --git a/stdlib/internal/types/array.codon b/stdlib/internal/types/array.codon index 6f883dd8..c9aecd2f 100644 --- a/stdlib/internal/types/array.codon +++ b/stdlib/internal/types/array.codon @@ -1,3 +1,4 @@ +# (c) 2022 Exaloop Inc. All rights reserved. 
from internal.gc import sizeof diff --git a/stdlib/internal/types/collections/dict.codon b/stdlib/internal/types/collections/dict.codon index 4665edf6..4ce1141a 100644 --- a/stdlib/internal/types/collections/dict.codon +++ b/stdlib/internal/types/collections/dict.codon @@ -106,6 +106,15 @@ class Dict: def __len__(self) -> int: return self._size + def __or__(self, other): + new = self.__copy__() + new.update(other) + return new + + def __ior__(self, other): + self.update(other) + return self + def __copy__(self): if self.__len__() == 0: return Dict[K, V]() @@ -183,9 +192,13 @@ class Dict: ret, x = self._kh_put(key) self._vals[x] = op(dflt if ret != 0 else self._vals[x], other) - def update(self, other: Dict[K, V]): - for k, v in other.items(): - self[k] = v + def update(self, other): + if isinstance(other, Dict[K, V]): + for k, v in other.items(): + self[k] = v + else: + for k, v in other: + self[k] = v def pop(self, key: K) -> V: x = self._kh_get(key) @@ -284,10 +297,14 @@ class Dict: if self._n_buckets < new_n_buckets: self._keys = Ptr[K]( - gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K)) + gc.realloc(self._keys.as_byte(), + new_n_buckets * gc.sizeof(K), + self._n_buckets * gc.sizeof(K)) ) self._vals = Ptr[V]( - gc.realloc(self._vals.as_byte(), new_n_buckets * gc.sizeof(V)) + gc.realloc(self._vals.as_byte(), + new_n_buckets * gc.sizeof(V), + self._n_buckets * gc.sizeof(V)) ) if j: @@ -324,10 +341,14 @@ class Dict: if self._n_buckets > new_n_buckets: self._keys = Ptr[K]( - gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K)) + gc.realloc(self._keys.as_byte(), + new_n_buckets * gc.sizeof(K), + self._n_buckets * gc.sizeof(K)) ) self._vals = Ptr[V]( - gc.realloc(self._vals.as_byte(), new_n_buckets * gc.sizeof(V)) + gc.realloc(self._vals.as_byte(), + new_n_buckets * gc.sizeof(V), + self._n_buckets * gc.sizeof(V)) ) self._flags = new_flags diff --git a/stdlib/internal/types/collections/list.codon b/stdlib/internal/types/collections/list.codon 
index f14fba79..d44ae48c 100644 --- a/stdlib/internal/types/collections/list.codon +++ b/stdlib/internal/types/collections/list.codon @@ -238,7 +238,9 @@ class List: len0 = self.__len__() new_cap = n * len0 if self.arr.len < new_cap: - p = Ptr[T](gc.realloc(self.arr.ptr.as_byte(), new_cap * gc.sizeof(T))) + p = Ptr[T](gc.realloc(self.arr.ptr.as_byte(), + new_cap * gc.sizeof(T), + self.arr.len * gc.sizeof(T))) self.arr = Array[T](p, new_cap) idx = len0 @@ -349,7 +351,9 @@ class List: raise IndexError(msg) def _resize(self, new_cap: int): - p = Ptr[T](gc.realloc(self.arr.ptr.as_byte(), new_cap * gc.sizeof(T))) + p = Ptr[T](gc.realloc(self.arr.ptr.as_byte(), + new_cap * gc.sizeof(T), + self.arr.len * gc.sizeof(T))) self.arr = Array[T](p, new_cap) def _resize_if_full(self): @@ -414,5 +418,37 @@ class List: return Array[T](Ptr[T](), 0) return self.arr.slice(start, stop).__copy__() + def _cmp(self, other: List[T]): + n1 = self.__len__() + n2 = other.__len__() + nmin = n1 if n1 < n2 else n2 + for i in range(nmin): + a = self.arr[i] + b = other.arr[i] + + if a < b: + return -1 + elif a == b: + continue + else: + return 1 + if n1 < n2: + return -1 + elif n1 == n2: + return 0 + else: + return 1 + + def __lt__(self, other: List[T]): + return self._cmp(other) < 0 + + def __gt__(self, other: List[T]): + return self._cmp(other) > 0 + + def __le__(self, other: List[T]): + return self._cmp(other) <= 0 + + def __ge__(self, other: List[T]): + return self._cmp(other) >= 0 list = List diff --git a/stdlib/internal/types/collections/set.codon b/stdlib/internal/types/collections/set.codon index 2dcf3825..5922026a 100644 --- a/stdlib/internal/types/collections/set.codon +++ b/stdlib/internal/types/collections/set.codon @@ -316,7 +316,9 @@ class Set: if self._n_buckets < new_n_buckets: self._keys = Ptr[K]( - gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K)) + gc.realloc(self._keys.as_byte(), + new_n_buckets * gc.sizeof(K), + self._n_buckets * gc.sizeof(K)) ) if j: @@ -350,7 
+352,9 @@ class Set: if self._n_buckets > new_n_buckets: self._keys = Ptr[K]( - gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K)) + gc.realloc(self._keys.as_byte(), + new_n_buckets * gc.sizeof(K), + self._n_buckets * gc.sizeof(K)) ) self._flags = new_flags diff --git a/stdlib/internal/types/error.codon b/stdlib/internal/types/error.codon index e80ca690..f533c3e8 100644 --- a/stdlib/internal/types/error.codon +++ b/stdlib/internal/types/error.codon @@ -72,9 +72,9 @@ class CError(Exception): class PyError(Exception): - pytype: str + pytype: pyobj - def __init__(self, message: str, pytype: str): + def __init__(self, message: str, pytype: pyobj): super().__init__("PyError", message) self.pytype = pytype diff --git a/stdlib/internal/types/float.codon b/stdlib/internal/types/float.codon index 837d7096..0664affc 100644 --- a/stdlib/internal/types/float.codon +++ b/stdlib/internal/types/float.codon @@ -395,3 +395,326 @@ class float: @property def imag(self) -> float: return 0.0 + + +@extend +class float32: + @pure + @llvm + def __new__(self: float) -> float32: + %0 = fptrunc double %self to float + ret float %0 + + def __new__(what: float32) -> float32: + return what + + def __new__() -> float32: + return float32.__new__(0.0) + + def __repr__(self) -> str: + s = seq_str_float(self.__float__()) + return s if s != "-nan" else "nan" + + def __copy__(self) -> float32: + return self + + def __deepcopy__(self) -> float32: + return self + + @pure + @llvm + def __int__(self) -> int: + %0 = fptosi float %self to i64 + ret i64 %0 + + @pure + @llvm + def __float__(self) -> float: + %0 = fpext float %self to double + ret double %0 + + @pure + @llvm + def __bool__(self) -> bool: + %0 = fcmp one float %self, 0.000000e+00 + %1 = zext i1 %0 to i8 + ret i8 %1 + + def __pos__(self) -> float32: + return self + + @pure + @llvm + def __neg__(self) -> float32: + %0 = fneg float %self + ret float %0 + + @pure + @commutative + @llvm + def __add__(a: float32, b: float32) -> float32: + 
%tmp = fadd float %a, %b + ret float %tmp + + @pure + @llvm + def __sub__(a: float32, b: float32) -> float32: + %tmp = fsub float %a, %b + ret float %tmp + + @pure + @commutative + @llvm + def __mul__(a: float32, b: float32) -> float32: + %tmp = fmul float %a, %b + ret float %tmp + + def __floordiv__(self, other: float32) -> float: + return self.__truediv__(other).__floor__() + + @pure + @llvm + def __truediv__(a: float32, b: float32) -> float32: + %tmp = fdiv float %a, %b + ret float %tmp + + @pure + @llvm + def __mod__(a: float32, b: float32) -> float32: + %tmp = frem float %a, %b + ret float %tmp + + def __divmod__(self, other: float32) -> Tuple[float32, float32]: + mod = self % other + div = (self - mod) / other + if mod: + if (other < 0) != (mod < 0): + mod += other + div -= 1.0 + else: + mod = (0.0).copysign(other) + + floordiv = 0.0 + if div: + floordiv = div.__floor__() + if div - floordiv > 0.5: + floordiv += 1.0 + else: + floordiv = (0.0).copysign(self / other) + + return (floordiv, mod) + + @pure + @llvm + def __eq__(a: float32, b: float32) -> bool: + %tmp = fcmp oeq float %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __ne__(a: float32, b: float32) -> bool: + entry: + %tmp = fcmp one float %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __lt__(a: float32, b: float32) -> bool: + %tmp = fcmp olt float %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __gt__(a: float32, b: float32) -> bool: + %tmp = fcmp ogt float %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __le__(a: float32, b: float32) -> bool: + %tmp = fcmp ole float %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def __ge__(a: float32, b: float32) -> bool: + %tmp = fcmp oge float %a, %b + %res = zext i1 %tmp to i8 + ret i8 %res + + @pure + @llvm + def sqrt(a: float32) -> float32: + declare float @llvm.sqrt.f32(float %a) + %tmp = call float @llvm.sqrt.f32(float %a) + ret float %tmp 
+ + @pure + @llvm + def sin(a: float32) -> float32: + declare float @llvm.sin.f32(float %a) + %tmp = call float @llvm.sin.f32(float %a) + ret float %tmp + + @pure + @llvm + def cos(a: float32) -> float32: + declare float @llvm.cos.f32(float %a) + %tmp = call float @llvm.cos.f32(float %a) + ret float %tmp + + @pure + @llvm + def exp(a: float32) -> float32: + declare float @llvm.exp.f32(float %a) + %tmp = call float @llvm.exp.f32(float %a) + ret float %tmp + + @pure + @llvm + def exp2(a: float32) -> float32: + declare float @llvm.exp2.f32(float %a) + %tmp = call float @llvm.exp2.f32(float %a) + ret float %tmp + + @pure + @llvm + def log(a: float32) -> float32: + declare float @llvm.log.f32(float %a) + %tmp = call float @llvm.log.f32(float %a) + ret float %tmp + + @pure + @llvm + def log10(a: float32) -> float32: + declare float @llvm.log10.f32(float %a) + %tmp = call float @llvm.log10.f32(float %a) + ret float %tmp + + @pure + @llvm + def log2(a: float32) -> float32: + declare float @llvm.log2.f32(float %a) + %tmp = call float @llvm.log2.f32(float %a) + ret float %tmp + + @pure + @llvm + def __abs__(a: float32) -> float32: + declare float @llvm.fabs.f32(float %a) + %tmp = call float @llvm.fabs.f32(float %a) + ret float %tmp + + @pure + @llvm + def __floor__(a: float32) -> float32: + declare float @llvm.floor.f32(float %a) + %tmp = call float @llvm.floor.f32(float %a) + ret float %tmp + + @pure + @llvm + def __ceil__(a: float32) -> float32: + declare float @llvm.ceil.f32(float %a) + %tmp = call float @llvm.ceil.f32(float %a) + ret float %tmp + + @pure + @llvm + def __trunc__(a: float32) -> float32: + declare float @llvm.trunc.f32(float %a) + %tmp = call float @llvm.trunc.f32(float %a) + ret float %tmp + + @pure + @llvm + def rint(a: float32) -> float32: + declare float @llvm.rint.f32(float %a) + %tmp = call float @llvm.rint.f32(float %a) + ret float %tmp + + @pure + @llvm + def nearbyint(a: float32) -> float32: + declare float @llvm.nearbyint.f32(float %a) + %tmp = 
call float @llvm.nearbyint.f32(float %a) + ret float %tmp + + @pure + @llvm + def __round__(a: float32) -> float32: + declare float @llvm.round.f32(float %a) + %tmp = call float @llvm.round.f32(float %a) + ret float %tmp + + @pure + @llvm + def __pow__(a: float32, b: float32) -> float32: + declare float @llvm.pow.f32(float %a, float %b) + %tmp = call float @llvm.pow.f32(float %a, float %b) + ret float %tmp + + @pure + @llvm + def min(a: float32, b: float32) -> float32: + declare float @llvm.minnum.f32(float %a, float %b) + %tmp = call float @llvm.minnum.f32(float %a, float %b) + ret float %tmp + + @pure + @llvm + def max(a: float32, b: float32) -> float32: + declare float @llvm.maxnum.f32(float %a, float %b) + %tmp = call float @llvm.maxnum.f32(float %a, float %b) + ret float %tmp + + @pure + @llvm + def copysign(a: float32, b: float32) -> float32: + declare float @llvm.copysign.f32(float %a, float %b) + %tmp = call float @llvm.copysign.f32(float %a, float %b) + ret float %tmp + + @pure + @llvm + def fma(a: float32, b: float32, c: float32) -> float32: + declare float @llvm.fma.f32(float %a, float %b, float %c) + %tmp = call float @llvm.fma.f32(float %a, float %b, float %c) + ret float %tmp + + @nocapture + @llvm + def __atomic_xchg__(d: Ptr[float32], b: float32) -> None: + %tmp = atomicrmw xchg float* %d, float %b seq_cst + ret {} {} + + @nocapture + @llvm + def __atomic_add__(d: Ptr[float32], b: float32) -> float32: + %tmp = atomicrmw fadd float* %d, float %b seq_cst + ret float %tmp + + @nocapture + @llvm + def __atomic_sub__(d: Ptr[float32], b: float32) -> float32: + %tmp = atomicrmw fsub float* %d, float %b seq_cst + ret float %tmp + + def __hash__(self) -> int: + return self.__float__().__hash__() + + def __match__(self, i: float32) -> bool: + return self == i + + +@extend +class float: + def __suffix_f32__(double) -> float32: + return float32.__new__(double) + +f32 = float32 diff --git a/stdlib/internal/types/int.codon b/stdlib/internal/types/int.codon index 
563005d0..3a43b83b 100644 --- a/stdlib/internal/types/int.codon +++ b/stdlib/internal/types/int.codon @@ -182,6 +182,9 @@ class int: %tmp = srem i64 %a, %b ret i64 %tmp + def __divmod__(self, other: float) -> Tuple[float, float]: + return float(self).__divmod__(other) + def __divmod__(self, other: int) -> Tuple[int, int]: d = self // other m = self - d * other diff --git a/stdlib/internal/types/range.codon b/stdlib/internal/types/range.codon index 3ee8fda3..0c2c0d3e 100644 --- a/stdlib/internal/types/range.codon +++ b/stdlib/internal/types/range.codon @@ -95,3 +95,11 @@ class range: return f"range({self.start}, {self.stop})" else: return f"range({self.start}, {self.stop}, {self.step})" + +@overload +def staticrange(start: Static[int], stop: Static[int], step: Static[int] = 1): + return range(start, stop, step) + +@overload +def staticrange(stop: Static[int]): + return range(0, stop, 1) diff --git a/stdlib/math.codon b/stdlib/math.codon index af7324b3..3c248775 100644 --- a/stdlib/math.codon +++ b/stdlib/math.codon @@ -1,11 +1,23 @@ # (c) 2022 Exaloop Inc. All rights reserved. -e = 2.718281828459045 -pi = 3.141592653589793 -tau = 6.283185307179586 -inf = 1.0 / 0.0 -nan = 0.0 / 0.0 +@pure +@llvm +def _inf() -> float: + ret double 0x7FF0000000000000 + + +@pure +@llvm +def _nan() -> float: + ret double 0x7FF8000000000000 + + +e = 2.7182818284590452354 +pi = 3.14159265358979323846 +tau = 6.28318530717958647693 +inf = _inf() +nan = _nan() def factorial(x: int) -> int: @@ -43,11 +55,14 @@ def isnan(x: float) -> bool: Return True if float arg is a NaN, else False. """ - test = x == x - # if it is true then it is a number - if test: - return False - return True + @pure + @llvm + def f(x: float) -> bool: + %y = fcmp uno double %x, 0.000000e+00 + %z = zext i1 %y to i8 + ret i8 %z + + return f(x) def isinf(x: float) -> bool: @@ -56,7 +71,16 @@ def isinf(x: float) -> bool: Return True if float arg is an INF, else False. 
""" - return x == inf or x == -inf + @pure + @llvm + def f(x: float) -> bool: + declare double @llvm.fabs.f64(double) + %a = call double @llvm.fabs.f64(double %x) + %b = fcmp oeq double %a, 0x7FF0000000000000 + %c = zext i1 %b to i8 + ret i8 %c + + return f(x) def isfinite(x: float) -> bool: @@ -66,9 +90,35 @@ def isfinite(x: float) -> bool: Return True if x is neither an infinity nor a NaN, and False otherwise. """ - if isnan(x) or isinf(x): - return False - return True + return not (isnan(x) or isinf(x)) + + +def _check1(arg: float, r: float, can_overflow: bool = False): + if __py_numerics__: + if isnan(r) and not isnan(arg): + raise ValueError("math domain error") + + if isinf(r) and isfinite(arg): + if can_overflow: + raise OverflowError("math range error") + else: + raise ValueError("math domain error") + + return r + + +def _check2(x: float, y: float, r: float, can_overflow: bool = False): + if __py_numerics__: + if isnan(r) and not isnan(x) and not isnan(y): + raise ValueError("math domain error") + + if isinf(r) and isfinite(x) and isfinite(y): + if can_overflow: + raise OverflowError("math range error") + else: + raise ValueError("math domain error") + + return r def ceil(x: float) -> float: @@ -149,7 +199,7 @@ def exp(x: float) -> float: %y = call double @llvm.exp.f64(double %x) ret double %y - return f(x) + return _check1(x, f(x), True) def expm1(x: float) -> float: @@ -159,7 +209,7 @@ def expm1(x: float) -> float: Return e raised to the power x, minus 1. expm1 provides a way to compute this quantity to full precision. """ - return _C.expm1(x) + return _check1(x, _C.expm1(x), True) def ldexp(x: float, i: int) -> float: @@ -168,7 +218,7 @@ def ldexp(x: float, i: int) -> float: Returns x multiplied by 2 raised to the power of exponent. 
""" - return _C.ldexp(x, i) + return _check1(x, _C.ldexp(x, i32(i)), True) def log(x: float, base: float = e) -> float: @@ -185,9 +235,9 @@ def log(x: float, base: float = e) -> float: ret double %y if base == e: - return f(x) + return _check1(x, f(x)) else: - return f(x) / f(base) + return _check1(x, f(x)) / _check1(base, f(base)) def log2(x: float) -> float: @@ -203,7 +253,7 @@ def log2(x: float) -> float: %y = call double @llvm.log2.f64(double %x) ret double %y - return f(x) + return _check1(x, f(x)) def log10(x: float) -> float: @@ -219,7 +269,7 @@ def log10(x: float) -> float: %y = call double @llvm.log10.f64(double %x) ret double %y - return f(x) + return _check1(x, f(x)) def degrees(x: float) -> float: @@ -255,7 +305,7 @@ def sqrt(x: float) -> float: %y = call double @llvm.sqrt.f64(double %x) ret double %y - return f(x) + return _check1(x, f(x)) def pow(x: float, y: float) -> float: @@ -271,7 +321,7 @@ def pow(x: float, y: float) -> float: %z = call double @llvm.pow.f64(double %x, double %y) ret double %z - return f(x, y) + return _check2(x, y, f(x, y), True) def acos(x: float) -> float: @@ -280,7 +330,7 @@ def acos(x: float) -> float: Returns the arc cosine of x in radians. """ - return _C.acos(x) + return _check1(x, _C.acos(x)) def asin(x: float) -> float: @@ -289,7 +339,7 @@ def asin(x: float) -> float: Returns the arc sine of x in radians. """ - return _C.asin(x) + return _check1(x, _C.asin(x)) def atan(x: float) -> float: @@ -298,7 +348,7 @@ def atan(x: float) -> float: Returns the arc tangent of x in radians. """ - return _C.atan(x) + return _check1(x, _C.atan(x)) def atan2(y: float, x: float) -> float: @@ -309,7 +359,7 @@ def atan2(y: float, x: float) -> float: on the signs of both values to determine the correct quadrant. 
""" - return _C.atan2(y, x) + return _check2(x, y, _C.atan2(y, x)) def cos(x: float) -> float: @@ -325,7 +375,7 @@ def cos(x: float) -> float: %y = call double @llvm.cos.f64(double %x) ret double %y - return f(x) + return _check1(x, f(x)) def sin(x: float) -> float: @@ -341,7 +391,7 @@ def sin(x: float) -> float: %y = call double @llvm.sin.f64(double %x) ret double %y - return f(x) + return _check1(x, f(x)) def hypot(x: float, y: float) -> float: @@ -352,7 +402,7 @@ def hypot(x: float, y: float) -> float: This is the length of the vector from the origin to point (x, y). """ - return _C.hypot(x, y) + return _check2(x, y, _C.hypot(x, y), True) def tan(x: float) -> float: @@ -361,7 +411,7 @@ def tan(x: float) -> float: Return the tangent of a radian angle x. """ - return _C.tan(x) + return _check1(x, _C.tan(x)) def cosh(x: float) -> float: @@ -370,7 +420,7 @@ def cosh(x: float) -> float: Returns the hyperbolic cosine of x. """ - return _C.cosh(x) + return _check1(x, _C.cosh(x), True) def sinh(x: float) -> float: @@ -379,7 +429,7 @@ def sinh(x: float) -> float: Returns the hyperbolic sine of x. """ - return _C.sinh(x) + return _check1(x, _C.sinh(x), True) def tanh(x: float) -> float: @@ -388,7 +438,7 @@ def tanh(x: float) -> float: Returns the hyperbolic tangent of x. """ - return _C.tanh(x) + return _check1(x, _C.tanh(x)) def acosh(x: float) -> float: @@ -397,7 +447,7 @@ def acosh(x: float) -> float: Return the inverse hyperbolic cosine of x. """ - return _C.acosh(x) + return _check1(x, _C.acosh(x)) def asinh(x: float) -> float: @@ -406,7 +456,7 @@ def asinh(x: float) -> float: Return the inverse hyperbolic sine of x. """ - return _C.asinh(x) + return _check1(x, _C.asinh(x)) def atanh(x: float) -> float: @@ -415,7 +465,7 @@ def atanh(x: float) -> float: Return the inverse hyperbolic tangent of x. 
""" - return _C.atanh(x) + return _check1(x, _C.atanh(x)) def copysign(x: float, y: float) -> float: @@ -432,7 +482,7 @@ def copysign(x: float, y: float) -> float: %z = call double @llvm.copysign.f64(double %x, double %y) ret double %z - return f(x, y) + return _check2(x, y, f(x, y)) def log1p(x: float) -> float: @@ -441,7 +491,7 @@ def log1p(x: float) -> float: Return the natural logarithm of 1+x (base e). """ - return _C.log1p(x) + return _check1(x, _C.log1p(x)) def trunc(x: float) -> float: @@ -458,7 +508,7 @@ def trunc(x: float) -> float: %y = call double @llvm.trunc.f64(double %x) ret double %y - return f(x) + return _check1(x, f(x)) def erf(x: float) -> float: @@ -467,7 +517,7 @@ def erf(x: float) -> float: Return the error function at x. """ - return _C.erf(x) + return _check1(x, _C.erf(x)) def erfc(x: float) -> float: @@ -476,7 +526,7 @@ def erfc(x: float) -> float: Return the complementary error function at x. """ - return _C.erfc(x) + return _check1(x, _C.erfc(x)) def gamma(x: float) -> float: @@ -485,7 +535,7 @@ def gamma(x: float) -> float: Return the Gamma function at x. """ - return _C.tgamma(x) + return _check1(x, _C.tgamma(x), True) def lgamma(x: float) -> float: @@ -495,7 +545,7 @@ def lgamma(x: float) -> float: Return the natural logarithm of the absolute value of the Gamma function at x. """ - return _C.lgamma(x) + return _check1(x, _C.lgamma(x), True) def remainder(x: float, y: float) -> float: @@ -509,7 +559,7 @@ def remainder(x: float, y: float) -> float: two consecutive integers, the nearest even integer is used for n. 
""" - return _C.remainder(x, y) + return _check2(x, y, _C.remainder(x, y)) def gcd(a: float, b: float) -> float: @@ -585,3 +635,615 @@ def isclose(a: float, b: float, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> return ((diff <= fabs(rel_tol * b)) or (diff <= fabs(rel_tol * a))) or ( diff <= abs_tol ) + + +# 32-bit float ops + +e32 = float32(e) +pi32 = float32(pi) +tau32 = float32(tau) + +inf32 = float32(inf) +nan32 = float32(nan) + + +@overload +def isnan(x: float32) -> bool: + """ + isnan(float32) -> bool + + Return True if float arg is a NaN, else False. + """ + @pure + @llvm + def f(x: float32) -> bool: + %y = fcmp uno float %x, 0.000000e+00 + %z = zext i1 %y to i8 + ret i8 %z + + return f(x) + + +@overload +def isinf(x: float32) -> bool: + """ + isinf(float32) -> bool: + + Return True if float arg is an INF, else False. + """ + @pure + @llvm + def f(x: float32) -> bool: + declare float @llvm.fabs.f32(float) + %a = call float @llvm.fabs.f32(float %x) + %b = fcmp oeq float %a, 0x7FF0000000000000 + %c = zext i1 %b to i8 + ret i8 %c + + return f(x) + + +@overload +def isfinite(x: float32) -> bool: + """ + isfinite(float32) -> bool + + Return True if x is neither an infinity nor a NaN, + and False otherwise. + """ + return not (isnan(x) or isinf(x)) + + +@overload +def ceil(x: float32) -> float32: + """ + ceil(float32) -> float32 + + Return the ceiling of x as an Integral. + This is the smallest integer >= x. + """ + @pure + @llvm + def f(x: float32) -> float32: + declare float @llvm.ceil.f32(float) + %y = call float @llvm.ceil.f32(float %x) + ret float %y + + return f(x) + + +@overload +def floor(x: float32) -> float32: + """ + floor(float32) -> float32 + + Return the floor of x as an Integral. + This is the largest integer <= x. 
+ """ + @pure + @llvm + def f(x: float32) -> float32: + declare float @llvm.floor.f32(float) + %y = call float @llvm.floor.f32(float %x) + ret float %y + + return f(x) + + +@overload +def fabs(x: float32) -> float32: + """ + fabs(float32) -> float32 + + Returns the absolute value of a float32ing point number. + """ + @pure + @llvm + def f(x: float32) -> float32: + declare float @llvm.fabs.f32(float) + %y = call float @llvm.fabs.f32(float %x) + ret float %y + + return f(x) + + +@overload +def fmod(x: float32, y: float32) -> float32: + """ + fmod(float32, float32) -> float32 + + Returns the remainder of x divided by y. + """ + @pure + @llvm + def f(x: float32, y: float32) -> float32: + %z = frem float %x, %y + ret float %z + + return f(x, y) + + +@overload +def exp(x: float32) -> float32: + """ + exp(float32) -> float32 + + Returns the value of e raised to the xth power. + """ + @pure + @llvm + def f(x: float32) -> float32: + declare float @llvm.exp.f32(float) + %y = call float @llvm.exp.f32(float %x) + ret float %y + + return f(x) + + +@overload +def expm1(x: float32) -> float32: + """ + expm1(float32) -> float32 + + Return e raised to the power x, minus 1. expm1 provides + a way to compute this quantity to full precision. + """ + return _C.expm1f(x) + + +@overload +def ldexp(x: float32, i: int) -> float32: + """ + ldexp(float32, int) -> float32 + + Returns x multiplied by 2 raised to the power of exponent. + """ + return _C.ldexpf(x, i32(i)) + + +@overload +def log(x: float32, base: float32 = e32) -> float32: + """ + log(float32) -> float32 + + Returns the natural logarithm (base-e logarithm) of x. + """ + @pure + @llvm + def f(x: float32) -> float32: + declare float @llvm.log.f32(float) + %y = call float @llvm.log.f32(float %x) + ret float %y + + if base == e32: + return f(x) + else: + return f(x) / f(base) + + +@overload +def log2(x: float32) -> float32: + """ + log2(float32) -> float32 + + Return the base-2 logarithm of x. 
+ """ + @pure + @llvm + def f(x: float32) -> float32: + declare float @llvm.log2.f32(float) + %y = call float @llvm.log2.f32(float %x) + ret float %y + + return f(x) + + +@overload +def log10(x: float32) -> float32: + """ + log10(float32) -> float32 + + Returns the common logarithm (base-10 logarithm) of x. + """ + @pure + @llvm + def f(x: float32) -> float32: + declare float @llvm.log10.f32(float) + %y = call float @llvm.log10.f32(float %x) + ret float %y + + return f(x) + + +@overload +def degrees(x: float32) -> float32: + """ + degrees(float32) -> float32 + + Convert angle x from radians to degrees. + """ + radToDeg = float32(180.0) / pi32 + return x * radToDeg + + +@overload +def radians(x: float32) -> float32: + """ + radians(float32) -> float32 + + Convert angle x from degrees to radians. + """ + degToRad = pi32 / float32(180.0) + return x * degToRad + +@overload +def sqrt(x: float32) -> float32: + """ + sqrt(float32) -> float32 + + Returns the square root of x. + """ + @pure + @llvm + def f(x: float32) -> float32: + declare float @llvm.sqrt.f32(float) + %y = call float @llvm.sqrt.f32(float %x) + ret float %y + + return f(x) + + +@overload +def pow(x: float32, y: float32) -> float32: + """ + pow(float32, float32) -> float32 + + Returns x raised to the power of y. + """ + @pure + @llvm + def f(x: float32, y: float32) -> float32: + declare float @llvm.pow.f32(float, float) + %z = call float @llvm.pow.f32(float %x, float %y) + ret float %z + + return f(x, y) + + +@overload +def acos(x: float32) -> float32: + """ + acos(float32) -> float32 + + Returns the arc cosine of x in radians. + """ + return _C.acosf(x) + + +@overload +def asin(x: float32) -> float32: + """ + asin(float32) -> float32 + + Returns the arc sine of x in radians. + """ + return _C.asinf(x) + + +@overload +def atan(x: float32) -> float32: + """ + atan(float32) -> float32 + + Returns the arc tangent of x in radians. 
+ """ + return _C.atanf(x) + + +@overload +def atan2(y: float32, x: float32) -> float32: + """ + atan2(float32, float32) -> float32 + + Returns the arc tangent in radians of y/x based + on the signs of both values to determine the + correct quadrant. + """ + return _C.atan2f(y, x) + + +@overload +def cos(x: float32) -> float32: + """ + cos(float32) -> float32 + + Returns the cosine of a radian angle x. + """ + @pure + @llvm + def f(x: float32) -> float32: + declare float @llvm.cos.f32(float) + %y = call float @llvm.cos.f32(float %x) + ret float %y + + return f(x) + + +@overload +def sin(x: float32) -> float32: + """ + sin(float32) -> float32 + + Returns the sine of a radian angle x. + """ + @pure + @llvm + def f(x: float32) -> float32: + declare float @llvm.sin.f32(float) + %y = call float @llvm.sin.f32(float %x) + ret float %y + + return f(x) + + +@overload +def hypot(x: float32, y: float32) -> float32: + """ + hypot(float32, float32) -> float32 + + Return the Euclidean norm. + This is the length of the vector from the + origin to point (x, y). + """ + return _C.hypotf(x, y) + + +@overload +def tan(x: float32) -> float32: + """ + tan(float32) -> float32 + + Return the tangent of a radian angle x. + """ + return _C.tanf(x) + + +@overload +def cosh(x: float32) -> float32: + """ + cosh(float32) -> float32 + + Returns the hyperbolic cosine of x. + """ + return _C.coshf(x) + + +@overload +def sinh(x: float32) -> float32: + """ + sinh(float32) -> float32 + + Returns the hyperbolic sine of x. + """ + return _C.sinhf(x) + + +@overload +def tanh(x: float32) -> float32: + """ + tanh(float32) -> float32 + + Returns the hyperbolic tangent of x. + """ + return _C.tanhf(x) + + +@overload +def acosh(x: float32) -> float32: + """ + acosh(float32) -> float32 + + Return the inverse hyperbolic cosine of x. + """ + return _C.acoshf(x) + + +@overload +def asinh(x: float32) -> float32: + """ + asinh(float32) -> float32 + + Return the inverse hyperbolic sine of x. 
+ """ + return _C.asinhf(x) + + +@overload +def atanh(x: float32) -> float32: + """ + atanh(float32) -> float32 + + Return the inverse hyperbolic tangent of x. + """ + return _C.atanhf(x) + + +@overload +def copysign(x: float32, y: float32) -> float32: + """ + copysign(float32, float32) -> float32 + + Return a float32 with the magnitude (absolute value) of + x but the sign of y. + """ + @pure + @llvm + def f(x: float32, y: float32) -> float32: + declare float @llvm.copysign.f32(float, float) + %z = call float @llvm.copysign.f32(float %x, float %y) + ret float %z + + return f(x, y) + + +@overload +def log1p(x: float32) -> float32: + """ + log1p(float32) -> float32 + + Return the natural logarithm of 1+x (base e). + """ + return _C.log1pf(x) + + +@overload +def trunc(x: float32) -> float32: + """ + trunc(float32) -> float32 + + Return the Real value x truncated to an Integral + (usually an integer). + """ + @pure + @llvm + def f(x: float32) -> float32: + declare float @llvm.trunc.f32(float) + %y = call float @llvm.trunc.f32(float %x) + ret float %y + + return f(x) + + +@overload +def erf(x: float32) -> float32: + """ + erf(float32) -> float32 + + Return the error function at x. + """ + return _C.erff(x) + + +@overload +def erfc(x: float32) -> float32: + """ + erfc(float32) -> float32 + + Return the complementary error function at x. + """ + return _C.erfcf(x) + + +@overload +def gamma(x: float32) -> float32: + """ + gamma(float32) -> float32 + + Return the Gamma function at x. + """ + return _C.tgammaf(x) + + +@overload +def lgamma(x: float32) -> float32: + """ + lgamma(float32) -> float32 + + Return the natural logarithm of + the absolute value of the Gamma function at x. + """ + return _C.lgammaf(x) + + +@overload +def remainder(x: float32, y: float32) -> float32: + """ + remainder(float32, float32) -> float32 + + Return the IEEE 754-style remainder of x with respect to y. 
+ For finite x and finite nonzero y, this is the difference + x - n*y, where n is the closest integer to the exact value + of the quotient x / y. If x / y is exactly halfway between + two consecutive integers, the nearest even integer is used + for n. + """ + return _C.remainderf(x, y) + + +@overload +def gcd(a: float32, b: float32) -> float32: + """ + gcd(float32, float32) -> float32 + + returns greatest common divisor of x and y. + """ + a = abs(a) + b = abs(b) + while a: + a, b = b % a, a + return b + + +@overload +@pure +def frexp(x: float32) -> Tuple[float32, int]: + """ + frexp(float32) -> Tuple[float32, int] + + The returned value is the mantissa and the integer pointed + to by exponent is the exponent. The resultant value is + x = mantissa * 2 ^ exponent. + """ + tmp = i32(0) + res = _C.frexpf(float32(x), __ptr__(tmp)) + return (res, int(tmp)) + + +@overload +@pure +def modf(x: float32) -> Tuple[float32, float32]: + """ + modf(float32) -> Tuple[float32, float32] + + The returned value is the fraction component (part after + the decimal), and sets integer to the integer component. + """ + tmp = float32(0.0) + res = _C.modff(float32(x), __ptr__(tmp)) + return (res, tmp) + + +@overload +def isclose(a: float32, b: float32, rel_tol: float32 = float32(1e-09), abs_tol: float32 = float32(0.0)) -> bool: + """ + isclose(float32, float32) -> bool + + Return True if a is close in value to b, and False otherwise. + For the values to be considered close, the difference between them + must be smaller than at least one of the tolerances. + """ + + # short circuit exact equality -- needed to catch two + # infinities of the same sign. And perhaps speeds things + # up a bit sometimes. + if a == b: + return True + + # This catches the case of two infinities of opposite sign, or + # one infinity and one finite number. Two infinities of opposite + # sign would otherwise have an infinite relative tolerance. 
+ # Two infinities of the same sign are caught by the equality check + # above. + if a == inf32 or b == inf32: + return False + + # NAN is not close to anything, not even itself + if a == nan32 or b == nan32: + return False + + # regular computation + diff = fabs(b - a) + + return ((diff <= fabs(rel_tol * b)) or (diff <= fabs(rel_tol * a))) or ( + diff <= abs_tol + ) diff --git a/stdlib/openmp.codon b/stdlib/openmp.codon index d9014b84..ee65512f 100644 --- a/stdlib/openmp.codon +++ b/stdlib/openmp.codon @@ -643,6 +643,7 @@ def _task_loop_outline_template(gtid_ptr: Ptr[i32], btid_ptr: Ptr[i32], args): _loop_reductions(shared) _barrier(loc_ref, gtid) + @pure def get_num_threads(): from C import omp_get_num_threads() -> i32 @@ -952,10 +953,41 @@ def _atomic_float_max(a: Ptr[float], b: float): __kmpc_atomic_float8_max(_default_loc(), i32(0), a, b) +def _atomic_float32_add(a: Ptr[float32], b: float32) -> None: + from C import __kmpc_atomic_float4_add(Ptr[Ident], i32, Ptr[float32], float32) + __kmpc_atomic_float4_add(_default_loc(), i32(0), a, b) + + +def _atomic_float32_mul(a: Ptr[float32], b: float32): + from C import __kmpc_atomic_float4_mul(Ptr[Ident], i32, Ptr[float32], float32) + __kmpc_atomic_float4_mul(_default_loc(), i32(0), a, b) + + +def _atomic_float32_min(a: Ptr[float32], b: float32) -> None: + from C import __kmpc_atomic_float4_min(Ptr[Ident], i32, Ptr[float32], float32) + __kmpc_atomic_float4_min(_default_loc(), i32(0), a, b) + + +def _atomic_float32_max(a: Ptr[float32], b: float32) -> None: + from C import __kmpc_atomic_float4_max(Ptr[Ident], i32, Ptr[float32], float32) + __kmpc_atomic_float4_max(_default_loc(), i32(0), a, b) + + +def _range_len(start: int, stop: int, step: int): + if step > 0 and start < stop: + return 1 + (stop - 1 - start) // step + elif step < 0 and start > stop: + return 1 + (start - 1 - stop) // (-step) + else: + return 0 + + def for_par( num_threads: int = -1, chunk_size: int = -1, schedule: Static[str] = "static", ordered: 
Static[int] = False, + collapse: Static[int] = 0, + gpu: Static[int] = False, ): pass diff --git a/stdlib/pickle.codon b/stdlib/pickle.codon index c511f866..a838d517 100644 --- a/stdlib/pickle.codon +++ b/stdlib/pickle.codon @@ -77,6 +77,15 @@ class float: return _read(jar, float) +@extend +class float32: + def __pickle__(self, jar: Jar): + _write(jar, self) + + def __unpickle__(jar: Jar) -> float32: + return _read(jar, float32) + + @extend class bool: def __pickle__(self, jar: Jar): diff --git a/stdlib/simd.codon b/stdlib/simd.codon new file mode 100644 index 00000000..ff4d046f --- /dev/null +++ b/stdlib/simd.codon @@ -0,0 +1,310 @@ +@tuple +class Vec[T, N: Static[int]]: + ZERO_16x8i = Vec[u8,16](u8(0)) + FF_16x8i = Vec[u8,16](u8(0xff)) + ZERO_32x8i = Vec[u8,32](u8(0)) + FF_32x8i = Vec[u8,32](u8(0xff)) + + @llvm + def _mm_set1_epi8(val: u8) -> Vec[u8, 16]: + %0 = insertelement <16 x i8> undef, i8 %val, i32 0 + %1 = shufflevector <16 x i8> %0, <16 x i8> undef, <16 x i32> zeroinitializer + ret <16 x i8> %1 + + @llvm + def _mm256_set1_epi8(val: u8) -> Vec[u8, 32]: + %0 = insertelement <32 x i8> undef, i8 %val, i32 0 + %1 = shufflevector <32 x i8> %0, <32 x i8> undef, <32 x i32> zeroinitializer + ret <32 x i8> %1 + + @llvm + def _mm_loadu_si128(data) -> Vec[u8, 16]: + %0 = bitcast i8* %data to <16 x i8>* + %1 = load <16 x i8>, <16 x i8>* %0, align 1 + ret <16 x i8> %1 + + @llvm + def _mm256_loadu_si256(data) -> Vec[u8, 32]: + %0 = bitcast i8* %data to <32 x i8>* + %1 = load <32 x i8>, <32 x i8>* %0, align 1 + ret <32 x i8> %1 + + @llvm + def _mm256_set1_ps(val: f32) -> Vec[f32, 8]: + %0 = insertelement <8 x float> undef, float %val, i32 0 + %1 = shufflevector <8 x float> %0, <8 x float> undef, <8 x i32> zeroinitializer + ret <8 x float> %1 + + @llvm + def _mm512_set1_ps(val: f32) -> Vec[f32, 16]: + %0 = insertelement <16 x float> undef, float %val, i32 0 + %1 = shufflevector <16 x float> %0, <16 x float> undef, <16 x i32> zeroinitializer + ret <16 x float> %1 + + 
@llvm + def _mm256_loadu_ps(data: Ptr[f32]) -> Vec[f32, 8]: + %0 = bitcast float* %data to <8 x float>* + %1 = load <8 x float>, <8 x float>* %0 + ret <8 x float> %1 + + @llvm + def _mm512_loadu_ps(data: Ptr[f32]) -> Vec[f32, 16]: + %0 = bitcast float* %data to <16 x float>* + %1 = load <16 x float>, <16 x float>* %0 + ret <16 x float> %1 + + @llvm + def _mm256_cvtepi8_epi32(vec: Vec[u8, 16]) -> Vec[u32, 8]: + %0 = shufflevector <16 x i8> %vec, <16 x i8> undef, <8 x i32> + %1 = sext <8 x i8> %0 to <8 x i32> + ret <8 x i32> %1 + + @llvm + def _mm512_cvtepi8_epi64(vec: Vec[u8, 32]) -> Vec[u32, 16]: + %0 = shufflevector <32 x i8> %vec, <32 x i8> undef, <16 x i32> + %1 = sext <16 x i8> %0 to <16 x i32> + ret <16 x i32> %1 + + @llvm + def _mm256_castsi256_ps(vec: Vec[u32, 8]) -> Vec[f32, 8]: + %0 = bitcast <8 x i32> %vec to <8 x float> + ret <8 x float> %0 + + @llvm + def _mm512_castsi512_ps(vec: Vec[u32, 16]) -> Vec[f32, 16]: + %0 = bitcast <16 x i32> %vec to <16 x float> + ret <16 x float> %0 + + def __new__(x, T: type, N: Static[int]) -> Vec[T, N]: + if isinstance(T, u8) and N == 16: + if isinstance(x, u8) or isinstance(x, byte): # TODO: u8<->byte + return Vec._mm_set1_epi8(x) + if isinstance(x, Ptr[u8]) or isinstance(x, Ptr[byte]): + return Vec._mm_loadu_si128(x) + if isinstance(x, str): + return Vec._mm_loadu_si128(x.ptr) + if isinstance(T, u8) and N == 32: + if isinstance(x, u8) or isinstance(x, byte): # TODO: u8<->byte + return Vec._mm256_set1_epi8(x) + if isinstance(x, Ptr[u8]) or isinstance(x, Ptr[byte]): + return Vec._mm256_loadu_si256(x) + if isinstance(x, str): + return Vec._mm256_loadu_si256(x.ptr) + if isinstance(T, f32) and N == 8: + if isinstance(x, f32): + return Vec._mm256_set1_ps(x) + if isinstance(x, Ptr[f32]): # TODO: multi-elif does NOT work with statics [why?!] 
+ return Vec._mm256_loadu_ps(x) + if isinstance(x, List[f32]): + return Vec._mm256_loadu_ps(x.arr.ptr) + if isinstance(x, Vec[u8, 16]): + return Vec._mm256_castsi256_ps(Vec._mm256_cvtepi8_epi32(x)) + if isinstance(T, f32) and N == 16: + if isinstance(x, f32): + return Vec._mm512_set1_ps(x) + if isinstance(x, Ptr[f32]): # TODO: multi-elif does NOT work with statics [why?!] + return Vec._mm512_loadu_ps(x) + if isinstance(x, List[f32]): + return Vec._mm512_loadu_ps(x.arr.ptr) + if isinstance(x, Vec[u8, 32]): + return Vec._mm512_castsi512_ps(Vec._mm512_cvtepi8_epi64(x)) + compile_error("invalid SIMD vector constructor") + + def __new__(x: str, offset: int = 0) -> Vec[u8, N]: + return Vec(x.ptr + offset, u8, N) + + def __new__(x: List[T], offset: int = 0) -> Vec[T, N]: + return Vec(x.arr.ptr + offset, T, N) + + def __new__(x) -> Vec[T, N]: + return Vec(x, T, N) + + @llvm + def _mm_cmpeq_epi8(x: Vec[u8, 16], y: Vec[u8, 16]) -> Vec[u8, 16]: + %0 = icmp eq <16 x i8> %x, %y + %1 = sext <16 x i1> %0 to <16 x i8> + ret <16 x i8> %1 + + def __eq__(self: Vec[u8, 16], other: Vec[u8, 16]) -> Vec[u8, 16]: + return Vec._mm_cmpeq_epi8(self, other) + + @llvm + def _mm256_cmpeq_epi8(x: Vec[u8, 32], y: Vec[u8, 32]) -> Vec[u8, 32]: + %0 = icmp eq <32 x i8> %x, %y + %1 = sext <32 x i1> %0 to <32 x i8> + ret <32 x i8> %1 + + def __eq__(self: Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]: + return Vec._mm256_cmpeq_epi8(self, other) + + @llvm + def _mm_andnot_si128(x: Vec[u8, 16], y: Vec[u8, 16]) -> Vec[u8, 16]: + %0 = xor <16 x i8> %x, + %1 = and <16 x i8> %y, %0 + ret <16 x i8> %1 + + def __ne__(self: Vec[u8, 16], other: Vec[u8, 16]) -> Vec[u8, 16]: + return Vec._mm_andnot_si128((self == other), Vec.FF_16x8i) + + @llvm + def _mm256_andnot_si256(x: Vec[u8, 32], y: Vec[u8, 32]) -> Vec[u8, 32]: + %0 = xor <32 x i8> %x, + %1 = and <32 x i8> %y, %0 + ret <32 x i8> %1 + + def __ne__(self: Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]: + return Vec._mm256_andnot_si256((self == other), 
Vec.FF_32x8i) + + def __eq__(self: Vec[u8, 16], other: bool) -> Vec[u8, 16]: + if not other: + return Vec._mm_andnot_si128(self, Vec.FF_16x8i) + else: + return Vec._mm_andnot_si128(self, Vec.ZERO_16x8i) + + def __eq__(self: Vec[u8, 32], other: bool) -> Vec[u8, 32]: + if not other: + return Vec._mm256_andnot_si256(self, Vec.FF_32x8i) + else: + return Vec._mm256_andnot_si256(self, Vec.ZERO_32x8i) + + @llvm + def _mm_and_si128(x: Vec[u8, 16], y: Vec[u8, 16]) -> Vec[u8, 16]: + %0 = and <16 x i8> %x, %y + ret <16 x i8> %0 + + def __and__(self: Vec[u8, 16], other: Vec[u8, 16]) -> Vec[u8, 16]: + return Vec._mm_and_si128(self, other) + + @llvm + def _mm_and_si256(x: Vec[u8, 32], y: Vec[u8, 32]) -> Vec[u8, 32]: + %0 = and <32 x i8> %x, %y + ret <32 x i8> %0 + + def __and__(self: Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]: + return Vec._mm_and_si256(self, other) + + @llvm + def _mm256_and_ps(x: Vec[f32, 8], y: Vec[f32, 8]) -> Vec[f32, 8]: + %0 = bitcast <8 x float> %x to <8 x i32> + %1 = bitcast <8 x float> %y to <8 x i32> + %2 = and <8 x i32> %0, %1 + %3 = bitcast <8 x i32> %2 to <8 x float> + ret <8 x float> %3 + + def __and__(self: Vec[f32, 8], other: Vec[f32, 8]) -> Vec[f32, 8]: + return Vec._mm256_and_ps(self, other) + + @llvm + def _mm512_and_ps(x: Vec[f32, 16], y: Vec[f32, 16]) -> Vec[f32, 16]: + %0 = bitcast <16 x float> %x to <16 x i32> + %1 = bitcast <16 x float> %y to <16 x i32> + %2 = and <16 x i32> %0, %1 + %3 = bitcast <16 x i32> %2 to <16 x float> + ret <16 x float> %3 + + def __and__(self: Vec[f32, 16], other: Vec[f32, 16]) -> Vec[f32, 16]: + return Vec._mm512_and_ps(self, other) + + @llvm + def _mm_or_si128(x: Vec[u8, 16], y: Vec[u8, 16]) -> Vec[u8, 16]: + %0 = or <16 x i8> %x, %y + ret <16 x i8> %0 + + def __or__(self: Vec[u8, 16], other: Vec[u8, 16]) -> Vec[u8, 16]: + return Vec._mm_or_si128(self, other) + + @llvm + def _mm_or_si256(x: Vec[u8, 32], y: Vec[u8, 32]) -> Vec[u8, 32]: + %0 = or <32 x i8> %x, %y + ret <32 x i8> %0 + + def __or__(self: 
Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]: + return Vec._mm_or_si256(self, other) + + @llvm + def _mm256_or_ps(x: Vec[f32, 8], y: Vec[f32, 8]) -> Vec[f32, 8]: + %0 = bitcast <8 x float> %x to <8 x i32> + %1 = bitcast <8 x float> %y to <8 x i32> + %2 = or <8 x i32> %0, %1 + %3 = bitcast <8 x i32> %2 to <8 x float> + ret <8 x float> %3 + + def __or__(self: Vec[f32, 8], other: Vec[f32, 8]) -> Vec[f32, 8]: + return Vec._mm256_or_ps(self, other) + + @llvm + def _mm512_or_ps(x: Vec[f32, 16], y: Vec[f32, 16]) -> Vec[f32, 16]: + %0 = bitcast <16 x float> %x to <16 x i32> + %1 = bitcast <16 x float> %y to <16 x i32> + %2 = or <16 x i32> %0, %1 + %3 = bitcast <16 x i32> %2 to <16 x float> + ret <16 x float> %3 + + def __or__(self: Vec[f32, 16], other: Vec[f32, 16]) -> Vec[f32, 16]: + return Vec._mm512_or_ps(self, other) + + @llvm + def _mm_bsrli_si128_8(vec: Vec[u8, 16]) -> Vec[u8, 16]: + %0 = shufflevector <16 x i8> %vec, <16 x i8> zeroinitializer, <16 x i32> + ret <16 x i8> %0 + + @llvm + def _mm256_add_ps(x: Vec[f32, 8], y: Vec[f32, 8]) -> Vec[f32, 8]: + %0 = fadd <8 x float> %x, %y + ret <8 x float> %0 + + def __add__(self: Vec[f32, 8], other: Vec[f32, 8]) -> Vec[f32, 8]: + return Vec._mm256_add_ps(self, other) + + def __rshift__(self: Vec[u8, 16], shift: Static[int]) -> Vec[u8, 16]: + if shift == 0: + return self + elif shift == 8: + return Vec._mm_bsrli_si128_8(self) + else: + compile_error("invalid bitshift") + + @llvm + def _mm_bsrli_256(vec: Vec[u8, 32]) -> Vec[u8, 32]: + %0 = shufflevector <32 x i8> %vec, <32 x i8> zeroinitializer, <32 x i32> + ret <32 x i8> %0 + + def __rshift__(self: Vec[u8, 32], shift: Static[int]) -> Vec[u8, 32]: + if shift == 0: + return self + elif shift == 16: + return Vec._mm_bsrli_256(self) + else: + compile_error("invalid bitshift") + + # @llvm # https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction + # def sum(self: Vec[f32, 8]) -> f32: + # %0 = shufflevector <8 x float> %self, 
<8 x float> undef, <4 x i32> + # %1 = shufflevector <8 x float> %self, <8 x float> poison, <4 x i32> + # %2 = fadd <4 x float> %0, %1 + # %3 = shufflevector <4 x float> %2, <4 x float> undef, <4 x i32> + # %4 = fadd <4 x float> %2, %3 + # %5 = shufflevector <4 x float> %4, <4 x float> poison, <4 x i32> + # %6 = fadd <4 x float> %4, %5 + # %7 = extractelement <4 x float> %6, i32 0 + # ret float %7 + + def sum(self: Vec[f32, 8], x: f32 = f32(0.0)) -> f32: + return x + self[0] + self[1] + self[2] + self[3] + self[4] + self[5] + self[6] + self[7] + + @llvm + def __getitem__(self, n: Static[int]) -> T: + %0 = extractelement <{=N} x {=T}> %self, i32 {=n} + ret {=T} %0 + + def __repr__(self): + if N == 8: + return f"<{self[0]}, {self[1]}, {self[2]}, {self[3]}, {self[4]}, {self[5]}, {self[6]}, {self[7]}>" + elif N == 16: + return f"<{self[0]}, {self[1]}, {self[2]}, {self[3]}, {self[4]}, {self[5]}, {self[6]}, {self[7]}, {self[8]}, {self[9]}, {self[10]}, {self[11]}, {self[12]}, {self[13]}, {self[14]}, {self[15]}>" + else: + return "?" 
+ +u8x16 = Vec[u8, 16] +u8x32 = Vec[u8, 32] +f32x8 = Vec[f32, 8] \ No newline at end of file diff --git a/test/core/containers.codon b/test/core/containers.codon index e906f5b4..ed6d58e6 100644 --- a/test/core/containers.codon +++ b/test/core/containers.codon @@ -1,3 +1,5 @@ +from copy import copy, deepcopy + @tuple class A: a: int @@ -118,6 +120,24 @@ def test_list(): assert List[int]().copy() == List[int]() assert [1,2,3].copy() == [1,2,3] + + def test_cmp[T](a: T, b: T): + yield 'EQ', a == b + yield 'NE', a != b + yield 'LT', a < b + yield 'GT', a > b + yield 'LE', a <= b + yield 'GE', a >= b + + assert list(test_cmp([1,2], [1,2])) == [('EQ', True), ('NE', False), ('LT', False), ('GT', False), ('LE', True), ('GE', True)] + assert list(test_cmp([1,2,2], [1,2,3])) == [('EQ', False), ('NE', True), ('LT', True), ('GT', False), ('LE', True), ('GE', False)] + assert list(test_cmp([1,2,-1], [1,0,1])) == [('EQ', False), ('NE', True), ('LT', False), ('GT', True), ('LE', False), ('GE', True)] + assert list(test_cmp(List[int](), List[int]())) == [('EQ', True), ('NE', False), ('LT', False), ('GT', False), ('LE', True), ('GE', True)] + assert list(test_cmp([1], List[int]())) == [('EQ', False), ('NE', True), ('LT', False), ('GT', True), ('LE', False), ('GE', True)] + assert list(test_cmp(List[int](), [1])) == [('EQ', False), ('NE', True), ('LT', True), ('GT', False), ('LE', True), ('GE', False)] + assert list(test_cmp([1,2,-1], [2])) == [('EQ', False), ('NE', True), ('LT', True), ('GT', False), ('LE', True), ('GE', False)] + assert list(test_cmp([1,2,-1], [1,2,-1,3])) == [('EQ', False), ('NE', True), ('LT', True), ('GT', False), ('LE', True), ('GE', False)] + assert list(test_cmp([1,2,-1,3], [1,2,-1])) == [('EQ', False), ('NE', True), ('LT', False), ('GT', True), ('LE', False), ('GE', True)] test_list() @test @@ -373,6 +393,12 @@ def test_dict(): assert d2['y'] == -1 assert d2['z'] == 2 assert d2 == {'x': 11, 'y': -1, 'z': 2} + + d3 = {1: 2, 42: 42} + d4 = {1: 5, 2: 9} + 
assert d3 | d4 == {1: 5, 42: 42, 2: 9} + d3 |= d4 + assert d3 == {1: 5, 42: 42, 2: 9} test_dict() @test @@ -385,7 +411,7 @@ def test_deque(): dq.append(3) dq.appendleft(11) dq.appendleft(22) - assert str(dq) == '[22, 11, 1, 2, 3]' + assert str(dq) == 'deque([22, 11, 1, 2, 3])' assert bool(dq) == True # test cap increase: @@ -394,29 +420,29 @@ def test_deque(): for i in range(20): dq.append(i) dq.appendleft(i) - assert str(dq) == '[19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]' + assert str(dq) == 'deque([19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])' assert len(dq) == 40 for i in range(19): dq.pop() dq.popleft() - assert str(dq) == '[0, 0]' + assert str(dq) == 'deque([0, 0])' for a in dq: assert a == 0 assert (0 in dq) == True assert (1 in dq) == False - assert str(copy(dq)) == '[0, 0]' + assert str(copy(dq)) == 'deque([0, 0])' # test maxlen: dq = deque[int](5) for i in range(100): dq.append(i) - assert str(dq) == '[95, 96, 97, 98, 99]' + assert str(dq) == 'deque([95, 96, 97, 98, 99])' for i in range(5): dq.append(i) - assert str(dq) == '[0, 1, 2, 3, 4]' + assert str(dq) == 'deque([0, 1, 2, 3, 4])' test_deque() @test @@ -635,4 +661,193 @@ def test_counter(): exp = sorted(d.values(), reverse=True) got = [v for k,v in d.most_common()] assert exp == got + + assert repr(Counter('abcabc')) == "Counter({'a': 2, 'b': 2, 'c': 2})" test_counter() + +@test +def test_defaultdict(): + from collections import defaultdict + + # basic + #d1 = defaultdict() + #self.assertEqual(d1.default_factory, None) + #d1.default_factory = list + d1 = defaultdict(list) + d1[12].append(42) + assert d1 == {12: [42]} + d1[12].append(24) + assert d1 == {12: [42, 24]} + d1[13] + d1[14] + assert d1 == {12: [42, 24], 13: [], 14: []} + assert d1[12] is not d1[13] is not d1[14] + #d2 = defaultdict(list, foo=1, bar=2) 
+ #self.assertEqual(d2.default_factory, list) + #self.assertEqual(d2, {"foo": 1, "bar": 2}) + #self.assertEqual(d2["foo"], 1) + #self.assertEqual(d2["bar"], 2) + #self.assertEqual(d2[42], []) + #self.assertIn("foo", d2) + #self.assertIn("foo", d2.keys()) + #self.assertIn("bar", d2) + #self.assertIn("bar", d2.keys()) + #self.assertIn(42, d2) + #self.assertIn(42, d2.keys()) + #self.assertNotIn(12, d2) + #self.assertNotIn(12, d2.keys()) + #d2.default_factory = None + #self.assertEqual(d2.default_factory, None) + #try: + # d2[15] + #except KeyError as err: + # self.assertEqual(err.args, (15,)) + #else: + # self.fail("d2[15] didn't raise KeyError") + #self.assertRaises(TypeError, defaultdict, 1) + + # missing + #d1 = defaultdict() + #self.assertRaises(KeyError, d1.__missing__, 42) + #d1.default_factory = list + #d1 = defaultdict(list) + assert d1.__missing__(42) == [] + + # repr + d1 = defaultdict(lambda: 0) + #self.assertEqual(d1.default_factory, None) + #self.assertEqual(repr(d1), "defaultdict(None, {})") + #self.assertEqual(eval(repr(d1)), d1) + d1[11] = 41 + assert repr(d1) == "defaultdict(, {11: 41})" + d2 = defaultdict(lambda: 0) # TODO: use 'int' when it's fixed... 
+ #self.assertEqual(d2.default_factory, int) + d2[12] = 42 + assert repr(d2) == "defaultdict(, {12: 42})" + def foo(): return 43 + d3 = defaultdict(foo) + #self.assertTrue(d3.default_factory is foo) + d3[13] + assert repr(d3) == "defaultdict(, {13: 43})" + + + # copy + d1 = defaultdict(list) + d2 = d1.copy() + #self.assertEqual(type(d2), defaultdict) + #self.assertEqual(d2.default_factory, None) + assert d2 == {} + #d1.default_factory = list + #d3 = d1.copy() + #self.assertEqual(type(d3), defaultdict) + #self.assertEqual(d3.default_factory, list) + #self.assertEqual(d3, {}) + d1[42].append(0) + #d4 = d1.copy() + #assert d4 == {42: [0]} + #d4[12] + #assert d4 == {42: [], 12: []} + + # Issue 6637: Copy fails for empty default dict + #d = defaultdict() + #d['a'] = 42 + #e = d.copy() + #assert e['a'] == 42 + + + # shallow copy + foobar = list + d1 = defaultdict(foobar) + d1[1] += [1] + d2 = copy(d1) + #self.assertEqual(d2.default_factory, foobar) + assert d2 == d1 + #d1.default_factory = list + d2 = copy(d1) + #self.assertEqual(d2.default_factory, list) + assert d2 == d1 + + # deep copy + d1 = defaultdict(foobar) + d1[1].append(1) + d2 = deepcopy(d1) + #self.assertEqual(d2.default_factory, foobar) + assert d2 == d1 + assert d1[1] is not d2[1] + #d1.default_factory = list + d2 = deepcopy(d1) + #self.assertEqual(d2.default_factory, list) + assert d2 == d1 + + # KeyError without factory + #d1 = defaultdict() + #try: + # d1[(1,)] + #except KeyError as err: + # self.assertEqual(err.args[0], (1,)) + #else: + # self.fail("expected KeyError") + + # pickling + #d = defaultdict(int) + #d[1] + #for proto in range(pickle.HIGHEST_PROTOCOL + 1): + # s = pickle.dumps(d, proto) + # o = pickle.loads(s) + # self.assertEqual(d, o) + + # union + i = defaultdict(int, {1: 1, 2: 2}) + s = defaultdict(int, {0: 0, 1: 111}) + + i_s = i | s + #self.assertIs(i_s.default_factory, int) + assert i_s == {1: 111, 2: 2, 0: 0} + assert sorted(i_s) == [0, 1, 2] + + s_i = s | i + 
#self.assertIs(s_i.default_factory, str) + assert s_i == {0: 0, 1: 1, 2: 2} + assert sorted(s_i) == [0, 1, 2] + + i_ds = i | dict(s) + #self.assertIs(i_ds.default_factory, int) + assert i_ds == {1: 111, 2: 2, 0: 0} + assert sorted(i_ds) == [0, 1, 2] + + ds_i = dict(s) | i + #self.assertIs(ds_i.default_factory, int) + assert ds_i == {0: 0, 1: 1, 2: 2} + assert sorted(ds_i) == [0, 1, 2] + + # We inherit a fine |= from dict, so just a few sanity checks here: + i |= list(s.items()) + #self.assertIs(i.default_factory, int) + assert i == {1: 111, 2: 2, 0: 0} + assert sorted(i), [1, 2, 0] + + # general + s = 'mississippi' + d = defaultdict(int) + for k in s: + d[k] += 1 + assert sorted(d.items()) == [('i', 4), ('m', 1), ('p', 2), ('s', 4)] + + s = 'mississippi' + d = defaultdict(int) + for k in s: + d[k] = d.get(k, 0) + 1 + assert sorted(d.items()) == [('i', 4), ('m', 1), ('p', 2), ('s', 4)] + + s = [('yellow', 1), ('blue', 2), ('yellow', 3), ('blue', 4), ('red', 1)] + d = defaultdict(list) + for k, v in s: + d[k].append(v) + assert sorted(d.items()) == [('blue', [2, 4]), ('red', [1]), ('yellow', [1, 3])] + + def constant_factory(value): + return lambda: value + + d = defaultdict(constant_factory('')) + assert d[10] == '' +test_defaultdict() diff --git a/test/core/exceptions.codon b/test/core/exceptions.codon index 620e595b..35f29f0f 100644 --- a/test/core/exceptions.codon +++ b/test/core/exceptions.codon @@ -501,3 +501,39 @@ def test_property_exceptions(): # EXPECT: foo test_property_exceptions() + +def test_empty_raise(): + def foo(b): + if b: + raise ValueError('A') + else: + raise IOError('B') + + def bar(b): + try: + foo(b) + print('X') + except IOError as e: + print(e) + raise + except: + raise + + def baz(b): + try: + bar(b) + except ValueError as e: + print(e) + raise + + for b in (False, True): + try: + baz(b) + except: + print('C') + +# EXPECT: B +# EXPECT: C +# EXPECT: A +# EXPECT: C +test_empty_raise() diff --git a/test/core/numerics.codon 
b/test/core/numerics.codon new file mode 100644 index 00000000..232af5c5 --- /dev/null +++ b/test/core/numerics.codon @@ -0,0 +1,786 @@ +import operator as op +import math + +NAN = math.nan +INF = math.inf +NINF = -math.inf + +@test +def test_py_numerics_int(): + one = 1 + iz = 0 + fz = 0.0 + n = 0 + + # through function (not optimized / pre-evaluated) + assert op.floordiv(-5, 2) == -3 + assert op.floordiv(-5, 2.0) == -3.0 + assert op.truediv(-5, 2) == -2.5 + assert op.truediv(-5, 2.0) == -2.5 + assert op.mod(-10, 3) == 2 + assert op.mod(-1, 0.3) == 0.19999999999999996 + assert divmod(-10, 3) == (-4, 2) + assert divmod(-1, 0.3) == (-4.0, 0.19999999999999996) + + # with vars (evaluated in IR) + a = -5 + b = 2 + c = 2.0 + d = -10 + e = 3 + f = -1 + g = 0.3 + assert a // b == -3 + assert a // c == -3.0 + assert a / b == -2.5 + assert a / c == -2.5 + assert d % e == 2 + assert f % g == 0.19999999999999996 + + # constant (evaluated statically by parser) + assert -5 // 2 == -3 + assert -10 % 3 == 2 + + # errors + try: + print(one // fz) + assert False + except ZeroDivisionError as e: + assert str(e) == 'float floor division by zero' + n += 1 + + try: + print(one // iz) + assert False + except ZeroDivisionError as e: + assert str(e) == 'integer division or modulo by zero' + n += 1 + + try: + print(one / fz) + assert False + except ZeroDivisionError as e: + assert str(e) == 'float division by zero' + n += 1 + + try: + print(one / iz) + assert False + except ZeroDivisionError as e: + assert str(e) == 'division by zero' + n += 1 + + try: + print(one % fz) + assert False + except ZeroDivisionError as e: + assert str(e) == 'float modulo' + n += 1 + + try: + print(one % iz) + assert False + except ZeroDivisionError as e: + assert str(e) == 'integer division or modulo by zero' + n += 1 + + try: + print(divmod(one, iz)) + assert False + except ZeroDivisionError as e: + assert str(e) == 'integer division or modulo by zero' + n += 1 + + try: + print(divmod(one, fz)) + assert False 
+ except ZeroDivisionError as e: + assert str(e) == 'float divmod()' + n += 1 + + assert n == 8 + +@test +def test_py_numerics_float(): + one = 1.0 + iz = 0 + fz = 0.0 + n = 0 + + # through function (not optimized / pre-evaluated) + assert op.floordiv(-5.6, 2) == -3.0 + assert op.floordiv(-5.6, 2.0) == -3.0 + assert op.truediv(-5.6, 2) == -2.8 + assert op.truediv(-5.6, 2.0) == -2.8 + assert op.mod(-10.0, 3) == 2.0 + assert op.mod(-1.0, 0.3) == 0.19999999999999996 + assert divmod(-10.0, 3) == (-4.0, 2.0) + assert divmod(-1.0, 0.3) == (-4.0, 0.19999999999999996) + + # with vars (evaluated in IR) + a = -5.6 + b = 2 + c = 2.0 + d = -10.0 + e = 3 + f = -1.0 + g = 0.3 + assert a // b == -3 + assert a // c == -3.0 + assert a / b == -2.8 + assert a / c == -2.8 + assert d % e == 2 + assert f % g == 0.19999999999999996 + + try: + print(one // fz) + assert False + except ZeroDivisionError as e: + assert str(e) == 'float floor division by zero' + n += 1 + + try: + print(one // iz) + assert False + except ZeroDivisionError as e: + assert str(e) == 'float floor division by zero' + n += 1 + + try: + print(one / fz) + assert False + except ZeroDivisionError as e: + assert str(e) == 'float division by zero' + n += 1 + + try: + print(one / iz) + assert False + except ZeroDivisionError as e: + assert str(e) == 'float division by zero' + n += 1 + + try: + print(one % fz) + assert False + except ZeroDivisionError as e: + assert str(e) == 'float modulo' + n += 1 + + try: + print(one % iz) + assert False + except ZeroDivisionError as e: + assert str(e) == 'float modulo' + n += 1 + + try: + print(divmod(one, iz)) + assert False + except ZeroDivisionError as e: + assert str(e) == 'float divmod()' + n += 1 + + try: + print(divmod(one, fz)) + assert False + except ZeroDivisionError as e: + assert str(e) == 'float divmod()' + n += 1 + + assert n == 8 + + +import math + +NAN = math.nan +INF = math.inf +NINF = -math.inf + + +def close(a: float, b: float, epsilon: float = 1e-7): + return abs(a - 
b) <= epsilon + + +@test +def test_isnan(): + assert math.isnan(float("nan")) == True + assert math.isnan(4.0) == False + + +@test +def test_isinf(): + assert math.isinf(float("inf")) == True + assert math.isinf(7.0) == False + + +@test +def test_isfinite(): + assert math.isfinite(1.4) == True + assert math.isfinite(0.0) == True + assert math.isfinite(NAN) == False + assert math.isfinite(INF) == False + assert math.isfinite(NINF) == False + + +@test +def test_ceil(): + assert math.ceil(3.3) == 4 + assert math.ceil(0.5) == 1 + assert math.ceil(1.0) == 1 + assert math.ceil(1.5) == 2 + assert math.ceil(-0.5) == 0 + assert math.ceil(-1.0) == -1 + assert math.ceil(-1.5) == -1 + + +@test +def test_floor(): + assert math.floor(3.3) == 3 + assert math.floor(0.5) == 0 + assert math.floor(1.0) == 1 + assert math.floor(1.5) == 1 + assert math.floor(-0.5) == -1 + assert math.floor(-1.0) == -1 + assert math.floor(-1.5) == -2 + + +@test +def test_fabs(): + assert math.fabs(-1.0) == 1 + assert math.fabs(0.0) == 0 + assert math.fabs(1.0) == 1 + + +@test +def test_fmod(): + assert math.fmod(10.0, 1.0) == 0.0 + assert math.fmod(10.0, 0.5) == 0.0 + assert math.fmod(10.0, 1.5) == 1.0 + assert math.fmod(-10.0, 1.0) == -0.0 + assert math.fmod(-10.0, 0.5) == -0.0 + assert math.fmod(-10.0, 1.5) == -1.0 + + +@test +def test_exp(): + assert math.exp(0.0) == 1 + assert math.exp(-1.0) == 1 / math.e + assert math.exp(1.0) == math.e + + +@test +def test_expm1(): + assert math.expm1(0.0) == 0 + assert close(math.expm1(1.0), 1.7182818284590453) + assert close(math.expm1(3.0), 19.085536923187668) + assert close(math.expm1(5.0), 147.4131591025766) + assert math.expm1(INF) == INF + assert math.expm1(NINF) == -1 + assert math.isnan(math.expm1(NAN)) == True + + +@test +def test_ldexp(): + assert math.ldexp(0.0, 1) == 0.0 + assert math.ldexp(1.0, 1) == 2.0 + assert math.ldexp(1.0, -1) == 0.5 + assert math.ldexp(-1.0, 1) == -2.0 + assert math.ldexp(0.0, 1) == 0.0 + assert math.ldexp(1.0, -1000000) == 
0.0 + assert math.ldexp(-1.0, -1000000) == -0.0 + assert math.ldexp(INF, 30) == INF + assert math.ldexp(NINF, -213) == NINF + assert math.isnan(math.ldexp(NAN, 0)) == True + + +@test +def test_log(): + assert math.log(1.0 / math.e) == -1 + assert math.log(1.0) == 0 + assert math.log(math.e) == 1 + + +@test +def test_log2(): + assert math.log2(1.0) == 0.0 + assert math.log2(2.0) == 1.0 + assert math.log2(4.0) == 2.0 + assert math.log2(2.0 ** 1023) == 1023.0 + try: + math.log2(-1.5) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + try: + math.log2(NINF) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + assert math.isnan(math.log2(NAN)) == True + + +@test +def test_log10(): + assert math.log10(0.1) == -1 + assert math.log10(1.0) == 0 + assert math.log10(10.0) == 1 + assert math.log10(10000.0) == 4 + + +@test +def test_degrees(): + assert math.degrees(math.pi) == 180.0 + assert math.degrees(math.pi / 2) == 90.0 + assert math.degrees(-math.pi / 4) == -45.0 + assert math.degrees(0.0) == 0.0 + + +@test +def test_radians(): + assert math.radians(180.0) == math.pi + assert math.radians(90.0) == math.pi / 2 + assert math.radians(-45.0) == -math.pi / 4 + assert math.radians(0.0) == 0.0 + + +@test +def test_sqrt(): + assert math.sqrt(4.0) == 2 + assert math.sqrt(0.0) == 0 + assert math.sqrt(1.0) == 1 + try: + math.sqrt(-1.0) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + + +@test +def test_pow(): + assert math.pow(0.0, 1.0) == 0 + assert math.pow(1.0, 0.0) == 1 + assert math.pow(2.0, 1.0) == 2 + assert math.pow(2.0, -1.0) == 0.5 + assert math.pow(-0.0, 3.0) == -0.0 + assert math.pow(-0.0, 2.3) == 0.0 + assert math.pow(-0.0, 0.0) == 1 + assert math.pow(-0.0, -0.0) == 1 + assert math.pow(-2.0, 2.0) == 4.0 + assert math.pow(-2.0, 3.0) == -8.0 + assert math.pow(-2.0, -3.0) == -0.125 + assert math.pow(INF, 1.0) == INF + assert math.pow(NINF, 1.0) == NINF + assert math.pow(1.0, INF) 
== 1 + assert math.pow(1.0, NINF) == 1 + assert math.isnan(math.pow(NAN, 1.0)) == True + assert math.isnan(math.pow(2.0, NAN)) == True + assert math.isnan(math.pow(0.0, NAN)) == True + assert math.pow(1.0, NAN) == 1 + try: + math.pow(10.0, 400.0) + assert False + except OverflowError as e: + assert str(e) == 'math range error' + + +@test +def test_acos(): + assert math.acos(-1.0) == math.pi + assert math.acos(0.0) == math.pi / 2 + assert math.acos(1.0) == 0 + assert math.isnan(math.acos(NAN)) == True + + +@test +def test_asin(): + assert math.asin(-1.0) == -math.pi / 2 + assert math.asin(0.0) == 0 + assert math.asin(1.0) == math.pi / 2 + assert math.isnan(math.asin(NAN)) == True + + +@test +def test_atan(): + assert math.atan(-1.0) == -math.pi / 4 + assert math.atan(0.0) == 0 + assert math.atan(1.0) == math.pi / 4 + assert math.atan(INF) == math.pi / 2 + assert math.atan(NINF) == -math.pi / 2 + assert math.isnan(math.atan(NAN)) == True + + +@test +def test_atan2(): + assert math.atan2(-1.0, 0.0) == -math.pi / 2 + assert math.atan2(-1.0, 1.0) == -math.pi / 4 + assert math.atan2(0.0, 1.0) == 0 + assert math.atan2(1.0, 1.0) == math.pi / 4 + assert math.atan2(1.0, 0.0) == math.pi / 2 + assert math.atan2(-0.0, 0.0) == -0 + assert math.atan2(-0.0, 2.3) == -0 + assert math.atan2(0.0, -2.3) == math.pi + assert math.atan2(INF, NINF) == math.pi * 3 / 4 + assert math.atan2(INF, 2.3) == math.pi / 2 + assert math.isnan(math.atan2(NAN, 0.0)) == True + + +@test +def test_cos(): + assert math.cos(0.0) == 1 + assert close(math.cos(math.pi / 2), 6.123233995736766e-17) + assert close(math.cos(-math.pi / 2), 6.123233995736766e-17) + assert math.cos(math.pi) == -1 + + try: + math.cos(INF) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + + try: + math.cos(NINF) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + + assert math.isnan(math.cos(NAN)) == True + + +@test +def test_sin(): + assert math.sin(0.0) == 0 + assert 
math.sin(math.pi / 2) == 1 + assert math.sin(-math.pi / 2) == -1 + + try: + math.sin(INF) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + + try: + math.sin(NINF) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + + assert math.isnan(math.sin(NAN)) == True + + +@test +def test_hypot(): + assert math.hypot(12.0, 5.0) == 13 + assert math.hypot(12.0 / 32.0, 5.0 / 32) == 13 / 32 + assert math.hypot(0.0, 0.0) == 0 + assert math.hypot(-3.0, 4.0) == 5 + assert math.hypot(3.0, 4.0) == 5 + + +@test +def test_tan(): + assert math.tan(0.0) == 0 + assert close(math.tan(math.pi / 4), 0.9999999999999999) + assert close(math.tan(-math.pi / 4), -0.9999999999999999) + + try: + math.tan(INF) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + + try: + math.tan(NINF) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + + assert math.isnan(math.tan(NAN)) == True + + +@test +def test_cosh(): + assert math.cosh(0.0) == 1 + assert math.cosh(2.0) - 2 * math.cosh(1.0) ** 2 == -1 + assert math.cosh(INF) == INF + assert math.cosh(NINF) == INF + assert math.isnan(math.cosh(NAN)) == True + + +@test +def test_sinh(): + assert math.sinh(0.0) == 0 + assert math.sinh(1.0) + math.sinh(-1.0) == 0 + assert math.sinh(INF) == INF + assert math.sinh(NINF) == NINF + assert math.isnan(math.sinh(NAN)) == True + + +@test +def test_tanh(): + assert math.tanh(0.0) == 0 + assert math.tanh(1.0) + math.tanh(-1.0) == 0 + assert math.tanh(INF) == 1 + assert math.tanh(NINF) == -1 + assert math.isnan(math.tanh(NAN)) == True + + +@test +def test_acosh(): + assert math.acosh(1.0) == 0 + assert close(math.acosh(2.0), 1.3169578969248166) + assert math.acosh(INF) == INF + assert math.isnan(math.acosh(NAN)) == True + try: + math.acosh(-1.0) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + + +@test +def test_asinh(): + assert math.asinh(0.0) == 0 + assert 
close(math.asinh(1.0), 0.881373587019543) + assert close(math.asinh(-1.0), -0.881373587019543) + assert math.asinh(INF) == INF + assert math.isnan(math.asinh(NAN)) == True + assert math.asinh(NINF) == NINF + + +@test +def test_atanh(): + assert math.atanh(0.0) == 0 + assert close(math.atanh(0.5), 0.5493061443340549) + assert close(math.atanh(-0.5), -0.5493061443340549) + + try: + math.atanh(INF) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + + try: + math.atanh(NINF) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + + assert math.isnan(math.atanh(NAN)) == True + + +@test +def test_copysign(): + assert math.copysign(1.0, -0.0) == -1 + assert math.copysign(1.0, 42.0) == 1 + assert math.copysign(1.0, -42.0) == -1 + assert math.copysign(3.0, 0.0) == 3 + assert math.copysign(INF, 0.0) == INF + assert math.copysign(INF, -0.0) == NINF + assert math.copysign(NINF, 0.0) == INF + assert math.copysign(NINF, -0.0) == NINF + assert math.copysign(1.0, INF) == 1 + assert math.copysign(1.0, NINF) == -1 + assert math.copysign(INF, INF) == INF + assert math.copysign(INF, NINF) == NINF + assert math.copysign(NINF, INF) == INF + assert math.copysign(NINF, NINF) == NINF + assert math.isnan(math.copysign(NAN, 1.0)) == True + assert math.isnan(math.copysign(NAN, INF)) == True + assert math.isnan(math.copysign(NAN, NINF)) == True + assert math.isnan(math.copysign(NAN, NAN)) == True + + +@test +def test_log1p(): + assert close(math.log1p(2.0), 1.0986122886681098) + assert close(math.log1p(2.0 ** 90), 62.383246250395075) + assert close(math.log1p(2.0 ** 300), 207.94415416798358) + assert math.log1p(INF) == INF + try: + math.log1p(-1.0) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + + +@test +def test_trunc(): + assert math.trunc(1.0) == 1 + assert math.trunc(-1.0) == -1 + assert math.trunc(1.5) == 1 + assert math.trunc(-1.5) == -1 + assert math.trunc(1.99999999) == 1 + assert 
math.trunc(-1.99999999) == -1 + assert math.trunc(0.99999999) == 0 + assert math.trunc(-100.999) == -100 + + +@test +def test_erf(): + assert close(math.erf(1.0), 0.8427007929497148) + assert math.erf(0.0) == 0 + assert close(math.erf(3.0), 0.9999779095030015) + assert math.erf(256.0) == 1.0 + assert math.erf(INF) == 1.0 + assert math.erf(NINF) == -1.0 + assert math.isnan(math.erf(NAN)) == True + + +@test +def test_erfc(): + assert math.erfc(0.0) == 1.0 + assert close(math.erfc(1.0), 0.15729920705028516) + assert close(math.erfc(2.0), 0.0046777349810472645) + assert close(math.erfc(-1.0), 1.8427007929497148) + assert math.erfc(INF) == 0.0 + assert math.erfc(NINF) == 2.0 + assert math.isnan(math.erfc(NAN)) == True + + +@test +def test_gamma(): + assert close(math.gamma(6.0), 120.0) + assert close(math.gamma(1.0), 1.0) + assert close(math.gamma(2.0), 1.0) + assert close(math.gamma(3.0), 2.0) + try: + math.gamma(-1.0) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + assert math.gamma(INF) == INF + try: + math.gamma(NINF) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + assert math.isnan(math.gamma(NAN)) == True + + +@test +def test_lgamma(): + assert math.lgamma(1.0) == 0.0 + assert math.lgamma(2.0) == 0.0 + #assert math.lgamma(-1.0) == INF # Python's custom lgamma gives math domain error + assert math.lgamma(INF) == INF + assert math.lgamma(NINF) == INF + assert math.isnan(math.lgamma(NAN)) == True + + +@test +def test_remainder(): + assert math.remainder(2.0, 2.0) == 0.0 + assert math.remainder(-4.0, 1.0) == -0.0 + assert close(math.remainder(-3.8, 1.0), 0.20000000000000018) + assert close(math.remainder(3.8, 1.0), -0.20000000000000018) + try: + math.remainder(INF, 1.0) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + try: + math.remainder(NINF, 1.0) + assert False + except ValueError as e: + assert str(e) == 'math domain error' + assert math.isnan(math.remainder(NAN, 
1.0)) == True + + +@test +def test_gcd(): + assert math.gcd(0.0, 0.0) == 0 + assert math.gcd(1.0, 0.0) == 1 + assert math.gcd(-1.0, 0.0) == 1 + assert math.gcd(0.0, -1.0) == 1 + assert math.gcd(0.0, 1.0) == 1 + assert math.gcd(7.0, 1.0) == 1 + assert math.gcd(7.0, -1.0) == 1 + assert math.gcd(-23.0, 15.0) == 1 + assert math.gcd(120.0, 84.0) == 12 + assert math.gcd(84.0, -120.0) == 12 + + +@test +def test_frexp(): + assert math.frexp(-2.0) == (-0.5, 2) + assert math.frexp(-1.0) == (-0.5, 1) + assert math.frexp(0.0) == (0.0, 0) + assert math.frexp(1.0) == (0.5, 1) + assert math.frexp(2.0) == (0.5, 2) + assert math.frexp(INF)[0] == INF + assert math.frexp(NINF)[0] == NINF + assert math.isnan(math.frexp(NAN)[0]) == True + + +@test +def test_modf(): + assert math.modf(1.5) == (0.5, 1.0) + assert math.modf(-1.5) == (-0.5, -1.0) + assert math.modf(math.inf) == (0.0, INF) + assert math.modf(-math.inf) == (-0.0, NINF) + modf_nan = math.modf(NAN) + assert math.isnan(modf_nan[0]) == True + assert math.isnan(modf_nan[1]) == True + + +@test +def test_isclose(): + assert math.isclose(1.0 + 1.0, 1.000000000001 + 1.0) == True + assert math.isclose(2.90909324093284, 2.909093240932844234234234234) == True + assert math.isclose(2.90909324093284, 2.9) == False + assert math.isclose(2.90909324093284, 2.90909324) == True + assert math.isclose(2.90909324, 2.90909325) == False + assert math.isclose(NAN, 2.9) == False + assert math.isclose(2.9, NAN) == False + assert math.isclose(INF, INF) == True + assert math.isclose(NINF, NINF) == True + assert math.isclose(NINF, INF) == False + assert math.isclose(INF, NINF) == False + +test_py_numerics_int() +test_py_numerics_float() + +test_isnan() +test_isinf() +test_isfinite() +test_ceil() +test_floor() +test_fabs() +test_fmod() +test_exp() +test_expm1() +test_ldexp() +test_log() +test_log2() +test_log10() +test_degrees() +test_radians() +test_sqrt() +test_pow() +test_acos() +test_asin() +test_atan() +test_atan2() +test_cos() +test_sin() 
+test_hypot() +test_tan() +test_cosh() +test_sinh() +test_tanh() +test_acosh() +test_asinh() +test_atanh() +test_copysign() +test_log1p() +test_trunc() +test_erf() +test_erfc() +test_gamma() +test_lgamma() +test_remainder() +test_gcd() +test_frexp() +test_modf() +test_isclose() diff --git a/test/core/serialization.codon b/test/core/serialization.codon index 91727aba..fc91d5dc 100644 --- a/test/core/serialization.codon +++ b/test/core/serialization.codon @@ -1,4 +1,5 @@ import pickle +from copy import copy @tuple class MyType: diff --git a/test/main.cpp b/test/main.cpp index 4218cccb..fb175503 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -235,10 +235,10 @@ static pair, bool> findExpects(const string &filename, bool isCod string argv0; -class SeqTest - : public testing::TestWithParam> { +class SeqTest : public testing::TestWithParam< + tuple> { vector buf; int out_pipe[2]; pid_t pid; @@ -265,9 +265,11 @@ public: auto code = get<3>(GetParam()); auto startLine = get<4>(GetParam()); int testFlags = 1 + get<5>(GetParam()); + bool pyNumerics = get<6>(GetParam()); auto compiler = std::make_unique( - argv0, debug, /*disabledPasses=*/std::vector{}, /*isTest=*/true); + argv0, debug, /*disabledPasses=*/std::vector{}, /*isTest=*/true, + pyNumerics); compiler->getLLVMVisitor()->setStandalone( true); // make sure we abort() on runtime error llvm::handleAllErrors(code.empty() @@ -364,7 +366,7 @@ TEST_P(SeqTest, Run) { } } auto getTypeTests(const vector &files) { - vector> cases; + vector> cases; for (auto &f : files) { bool barebones = false; string l; @@ -377,7 +379,7 @@ auto getTypeTests(const vector &files) { if (l.substr(0, 3) == "#%%") { if (line) cases.emplace_back(make_tuple(f, true, to_string(line) + "_" + testName, code, - codeLine, barebones)); + codeLine, barebones, false)); auto t = ast::split(l.substr(4), ','); barebones = (t.size() > 1 && t[1] == "barebones"); testName = t[0]; @@ -391,7 +393,7 @@ auto getTypeTests(const vector &files) { } if (line) 
cases.emplace_back(make_tuple(f, true, to_string(line) + "_" + testName, code, - codeLine, barebones)); + codeLine, barebones, false)); } return cases; } @@ -433,10 +435,26 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(""), testing::Values(""), testing::Values(0), + testing::Values(false), testing::Values(false) ), getTestNameFromParam); +INSTANTIATE_TEST_SUITE_P( + NumericsTests, SeqTest, + testing::Combine( + testing::Values( + "core/numerics.codon" + ), + testing::Values(true, false), + testing::Values(""), + testing::Values(""), + testing::Values(0), + testing::Values(false), + testing::Values(true) + ), + getTestNameFromParam); + INSTANTIATE_TEST_SUITE_P( StdlibTests, SeqTest, testing::Combine( @@ -459,6 +477,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(""), testing::Values(""), testing::Values(0), + testing::Values(false), testing::Values(false) ), getTestNameFromParam); @@ -482,6 +501,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(""), testing::Values(""), testing::Values(0), + testing::Values(false), testing::Values(false) ), getTestNameFromParam); diff --git a/test/parser/simplify_expr.codon b/test/parser/simplify_expr.codon index 6c1565f6..cd48fa5c 100644 --- a/test/parser/simplify_expr.codon +++ b/test/parser/simplify_expr.codon @@ -418,7 +418,7 @@ except NameError: 1 ... 3 #! unexpected pattern range expression #%% callable_error,barebones -def foo(x: Callable[[]]): pass #! invalid Callable type declaration +def foo(x: Callable[[]]): pass #! 
expected 2 generics and/or statics #%% unpack_specials,barebones x, = 1, diff --git a/test/parser/simplify_stmt.codon b/test/parser/simplify_stmt.codon index fd12eaa4..d83bdde2 100644 --- a/test/parser/simplify_stmt.codon +++ b/test/parser/simplify_stmt.codon @@ -580,12 +580,22 @@ def f(x): return g(x) print f(5) #: 6 -#%% function_err_nested,barebones +##% nested_generic_static,barebones +def foo(): + N: Static[int] = 5 + Z: Static[int] = 15 + T = Int[Z] + def bar(): + x = __array__[T](N) + print(x.__class__.__name__) + return bar +foo()() #: Array[Int[15]] + def f[T](): def g(): return T() - g() -f(int) #! cannot access nonlocal variable 'T' + return g() +print f(int) #: 0 #%% class_err_1,barebones @extend @@ -933,6 +943,66 @@ a = {1: "s", 2: "t"} a[3] = foo["str"] print(sorted(a.items())) #: [(1, 's'), (2, 't'), (3, 'hai')] + +#%% python_isinstance +import python + +@python +def foo(): + return 1 + +z = foo() +print(z.__class__.__name__) #: pyobj + +print isinstance(z, pyobj) #: True +print isinstance(z, int) #: False +print isinstance(z, python.int) #: True +print isinstance(z, python.ValueError) #: False + +print isinstance(z, (int, str, python.int)) #: True +print isinstance(z, (int, str, python.AttributeError)) #: False + +try: + foo().x +except python.ValueError: + pass +except python.AttributeError as e: + print('caught', e, e.__class__.__name__) #: caught 'int' object has no attribute 'x' pyobj + + +#%% python_exceptions +import python + +@python +def foo(): + return 1 + +try: + foo().x +except python.AttributeError as f: + print 'py.Att', f #: py.Att 'int' object has no attribute 'x' +except ValueError: + print 'Val' +except PyError as e: + print 'PyError', e +try: + foo().x +except python.ValueError as f: + print 'py.Att', f +except ValueError: + print 'Val' +except PyError as e: + print 'PyError', e #: PyError 'int' object has no attribute 'x' +try: + raise ValueError("ho") +except python.ValueError as f: + print 'py.Att', f +except ValueError: + print 
'Val' #: Val +except PyError as e: + print 'PyError', e + + #%% typeof_definition_error,barebones a = 1 class X: @@ -1120,3 +1190,22 @@ def match(match): match match(1) + +#%% loop_domination,barebones +for i in range(2): + try: dat = 1 + except: pass + print(dat) +#: 1 +#: 1 + +#%% block_unroll,barebones +# Ensure that block unrolling is done in RAII manner on error +def foo(): + while True: + def magic(a: x): + return + print b +foo() +#! identifier 'x' not found +#! identifier 'b' not found diff --git a/test/parser/types.codon b/test/parser/types.codon index 6167c2dc..8cf5634d 100644 --- a/test/parser/types.codon +++ b/test/parser/types.codon @@ -1208,3 +1208,127 @@ foo('hooooooooy!', 1, 2) #! cannot find a method 'foo' with arguments 'str', 'in from C import seq_print(str) x = seq_print("not ") print x #: not None + + +#%% static_for,barebones +def foo(i: Static[int]): + print('static', i, Int[i].__class__.__name__) + +for i in statictuple(1, 2, 3, 4, 5): + foo(i) + if i == 3: break +#: static 1 Int[1] +#: static 2 Int[2] +#: static 3 Int[3] +for i in staticrange(9, 4, -2): + foo(i) + if i == 3: + break +#: static 9 Int[9] +#: static 7 Int[7] +#: static 5 Int[5] +for i in statictuple("x", 1, 3.3, 2): + print(i) +#: x +#: 1 +#: 3.3 +#: 2 + +for i in staticrange(0, 10): + if i % 2 == 0: continue + if i > 8: break + print('xyz', Int[i].__class__.__name__) +print('whoa') +#: xyz Int[1] +#: xyz Int[3] +#: xyz Int[5] +#: xyz Int[7] +#: whoa + +for i in staticrange(15): + if i % 2 == 0: continue + if i > 8: break + print('xyz', Int[i].__class__.__name__) +print('whoa') +#: xyz Int[1] +#: xyz Int[3] +#: xyz Int[5] +#: xyz Int[7] +#: whoa + +for i in statictuple(0, 2, 4, 7, 11, 12, 13): + if i % 2 == 0: continue + if i > 8: break + print('xyz', Int[i].__class__.__name__) +print('whoa') +#: xyz Int[7] +#: whoa + +for i in staticrange(10): # TODO: large values are too slow! 
+ pass +print('done') +#: done + +#%% static_range_error,barebones +for i in staticrange(1000, -2000, -2): + pass +#! staticrange out of bounds (1500 > 1024) + +#%% trait_defdict +class dd(Dict[K,V]): + fn: S + K: type + V: type + S: TypeVar[Callable[[], V]] + + def __init__(self: dd[K, VV, Function[[], V]], VV: TypeVar[V]): + self.fn = lambda: VV() + + def __init__(self, f: S): + self.fn = f + + def __getitem__(self, key: K) -> V: + if key not in self: + self.__setitem__(key, self.fn()) + return super().__getitem__(key) + + +x = dd(list) +x[1] = [1, 2] +print(x[2]) +#: [] +print(x) +#: {1: [1, 2], 2: []} + +z = 5 +y = dd(lambda: z+1) +y.update({'a': 5}) +print(y['b']) +#: 6 +z = 6 +print(y['c']) +#: 7 +print(y) +#: {'a': 5, 'b': 6, 'c': 7} + +xx = dd(lambda: 'empty') +xx.update({1: 's', 2: 'b'}) +print(xx[1], xx[44]) +#: s empty +print(xx) +#: {44: 'empty', 1: 's', 2: 'b'} + +s = 'mississippi' +d = dd(int) +for k in s: + d[k] = d["x" + k] +print(sorted(d.items())) +#: [('i', 0), ('m', 0), ('p', 0), ('s', 0), ('xi', 0), ('xm', 0), ('xp', 0), ('xs', 0)] + + +#%% kwargs_getattr,barebones +def foo(**kwargs): + print kwargs['foo'], kwargs['bar'] + +foo(foo=1, bar='s') +#: 1 s diff --git a/test/python/pybridge.codon b/test/python/pybridge.codon index 32a7eef3..2a3f091c 100644 --- a/test/python/pybridge.codon +++ b/test/python/pybridge.codon @@ -1,3 +1,5 @@ +import python + @test def test_basic(): from python import mymodule @@ -32,7 +34,7 @@ def test_pythrow(): try: te() except PyError as e: - assert e.pytype + ":" + e.message == "ValueError:foo" + assert python.type(e.pytype)._getattr('__name__') + ":" + e.message == "ValueError:foo" return assert False test_pythrow() diff --git a/test/stdlib/math_test.codon b/test/stdlib/math_test.codon index 9e653728..c72bd53f 100644 --- a/test/stdlib/math_test.codon +++ b/test/stdlib/math_test.codon @@ -504,3 +504,510 @@ test_gcd() test_frexp() test_modf() test_isclose() + + +# 32-bit float ops + +NAN32 = math.nan32 +INF32 = 
math.inf32 +NINF32 = -math.inf32 + + +def close32(a: float32, b: float32, epsilon: float32 = 1e-5f32): + return abs(a - b) <= epsilon + + +@test +def test_float32_isnan(): + assert math.isnan(float32(float("nan"))) == True + assert math.isnan(4.0f32) == False + + +@test +def test_float32_isinf(): + assert math.isinf(float32(float("inf"))) == True + assert math.isinf(7.0f32) == False + + +@test +def test_float32_isfinite(): + assert math.isfinite(1.4f32) == True + assert math.isfinite(0.0f32) == True + assert math.isfinite(NAN32) == False + assert math.isfinite(INF32) == False + assert math.isfinite(NINF32) == False + + +@test +def test_float32_ceil(): + assert math.ceil(3.3f32) == 4.0f32 + assert math.ceil(0.5f32) == 1.0f32 + assert math.ceil(1.0f32) == 1.0f32 + assert math.ceil(1.5f32) == 2.0f32 + assert math.ceil(-0.5f32) == 0.0f32 + assert math.ceil(-1.0f32) == -1.0f32 + assert math.ceil(-1.5f32) == -1.0f32 + + +@test +def test_float32_floor(): + assert math.floor(3.3f32) == 3.0f32 + assert math.floor(0.5f32) == 0.0f32 + assert math.floor(1.0f32) == 1.0f32 + assert math.floor(1.5f32) == 1.0f32 + assert math.floor(-0.5f32) == -1.0f32 + assert math.floor(-1.0f32) == -1.0f32 + assert math.floor(-1.5f32) == -2.0f32 + + +@test +def test_float32_fabs(): + assert math.fabs(-1.0f32) == 1.0f32 + assert math.fabs(0.0f32) == 0.0f32 + assert math.fabs(1.0f32) == 1.0f32 + + +@test +def test_float32_fmod(): + assert math.fmod(10.0f32, 1.0f32) == 0.0f32 + assert math.fmod(10.0f32, 0.5f32) == 0.0f32 + assert math.fmod(10.0f32, 1.5f32) == 1.0f32 + assert math.fmod(-10.0f32, 1.0f32) == -0.0f32 + assert math.fmod(-10.0f32, 0.5f32) == -0.0f32 + assert math.fmod(-10.0f32, 1.5f32) == -1.0f32 + + +@test +def test_float32_exp(): + assert math.exp(0.0f32) == 1.0f32 + assert math.exp(-1.0f32) == 1.0f32 / math.e32 + assert math.exp(1.0f32) == math.e32 + + +@test +def test_float32_expm1(): + assert math.expm1(0.0f32) == 0.0f32 + assert close32(math.expm1(1.0f32), 1.7182818284590453f32) + 
assert close32(math.expm1(3.0f32), 19.085536923187668f32) + assert close32(math.expm1(5.0f32), 147.4131591025766f32) + assert math.expm1(INF32) == INF32 + assert math.expm1(NINF32) == -1.0f32 + assert math.isnan(math.expm1(NAN32)) == True + + +@test +def test_float32_ldexp(): + assert math.ldexp(0.0f32, 1) == 0.0f32 + assert math.ldexp(1.0f32, 1) == 2.0f32 + assert math.ldexp(1.0f32, -1) == 0.5f32 + assert math.ldexp(-1.0f32, 1) == -2.0f32 + assert math.ldexp(0.0f32, 1) == 0.0f32 + assert math.ldexp(1.0f32, -1000000) == 0.0f32 + assert math.ldexp(-1.0f32, -1000000) == -0.0f32 + assert math.ldexp(INF32, 30) == INF32 + assert math.ldexp(NINF32, -213) == NINF32 + assert math.isnan(math.ldexp(NAN32, 0)) == True + + +@test +def test_float32_log(): + assert math.log(1.0f32 / math.e32) == -1.0f32 + assert math.log(1.0f32) == 0.0f32 + assert close32(math.log(math.e32), 1.0f32) + + +@test +def test_float32_log2(): + assert math.log2(1.0f32) == 0.0f32 + assert math.log2(2.0f32) == 1.0f32 + assert math.log2(4.0f32) == 2.0f32 + assert math.log2(2.0f32 ** 50.0f32) == 50.0f32 + assert math.isnan(math.log2(-1.5f32)) == True + assert math.isnan(math.log2(NINF32)) == True + assert math.isnan(math.log2(NAN32)) == True + + +@test +def test_float32_log10(): + assert math.log10(0.1f32) == -1.0f32 + assert math.log10(1.0f32) == 0.0f32 + assert math.log10(10.0f32) == 1.0f32 + assert math.log10(10000.0f32) == 4.0f32 + + +@test +def test_float32_degrees(): + assert math.degrees(math.pi32) == 180.0f32 + assert math.degrees(math.pi32 / 2.0f32) == 90.0f32 + assert math.degrees(-math.pi32 / 4.0f32) == -45.0f32 + assert math.degrees(0.0f32) == 0.0f32 + + +@test +def test_float32_radians(): + assert math.radians(180.0f32) == math.pi32 + assert math.radians(90.0f32) == math.pi32 / 2.0f32 + assert math.radians(-45.0f32) == -math.pi32 / 4.0f32 + assert math.radians(0.0f32) == 0.0f32 + + +@test +def test_float32_sqrt(): + assert math.sqrt(4.0f32) == 2.0f32 + assert math.sqrt(0.0f32) == 0.0f32 + 
assert math.sqrt(1.0f32) == 1.0f32 + assert math.isnan(math.sqrt(-1.0f32)) == True + + +@test +def test_float32_pow(): + assert math.pow(0.0f32, 1.0f32) == 0.0f32 + assert math.pow(1.0f32, 0.0f32) == 1.0f32 + assert math.pow(2.0f32, 1.0f32) == 2.0f32 + assert math.pow(2.0f32, -1.0f32) == 0.5f32 + assert math.pow(-0.0f32, 3.0f32) == -0.0f32 + assert math.pow(-0.0f32, 2.3f32) == 0.0f32 + assert math.pow(-0.0f32, 0.0f32) == 1.0f32 + assert math.pow(-0.0f32, -0.0f32) == 1.0f32 + assert math.pow(-2.0f32, 2.0f32) == 4.0f32 + assert math.pow(-2.0f32, 3.0f32) == -8.0f32 + assert math.pow(-2.0f32, -3.0f32) == -0.125f32 + assert math.pow(INF32, 1.0f32) == INF32 + assert math.pow(NINF32, 1.0f32) == NINF32 + assert math.pow(1.0f32, INF32) == 1.0f32 + assert math.pow(1.0f32, NINF32) == 1.0f32 + assert math.isnan(math.pow(NAN32, 1.0f32)) == True + assert math.isnan(math.pow(2.0f32, NAN32)) == True + assert math.isnan(math.pow(0.0f32, NAN32)) == True + assert math.pow(1.0f32, NAN32) == 1.0f32 + + +@test +def test_float32_acos(): + assert close32(math.acos(-1.0f32), math.pi32) + assert close32(math.acos(0.0f32), math.pi32 / 2.0f32) + assert math.acos(1.0f32) == 0.0f32 + assert math.isnan(math.acos(NAN32)) == True + + +@test +def test_float32_asin(): + assert close32(math.asin(-1.0f32), -math.pi32 / 2.0f32) + assert math.asin(0.0f32) == 0.0f32 + assert close32(math.asin(1.0f32), math.pi32 / 2.0f32) + assert math.isnan(math.asin(NAN32)) == True + + +@test +def test_float32_atan(): + assert math.atan(-1.0f32) == -math.pi32 / 4.0f32 + assert math.atan(0.0f32) == 0.0f32 + assert close32(math.atan(1.0f32), math.pi32 / 4.0f32) + assert close32(math.atan(INF32), math.pi32 / 2.0f32) + assert close32(math.atan(NINF32), -math.pi32 / 2.0f32) + assert math.isnan(math.atan(NAN32)) == True + + +@test +def test_float32_atan2(): + assert math.atan2(-1.0f32, 0.0f32) == -math.pi32 / 2.0f32 + assert math.atan2(-1.0f32, 1.0f32) == -math.pi32 / 4.0f32 + assert math.atan2(0.0f32, 1.0f32) == 0.0f32 + 
assert math.atan2(1.0f32, 1.0f32) == math.pi32 / 4.0f32 + assert math.atan2(1.0f32, 0.0f32) == math.pi32 / 2.0f32 + assert math.atan2(-0.0f32, 0.0f32) == -0.0f32 + assert math.atan2(-0.0f32, 2.3f32) == -0.0f32 + assert close32(math.atan2(0.0f32, -2.3f32), math.pi32) + assert math.atan2(INF32, NINF32) == math.pi32 * 3.0f32 / 4.0f32 + assert math.atan2(INF32, 2.3f32) == math.pi32 / 2.0f32 + assert math.isnan(math.atan2(NAN32, 0.0f32)) == True + + +@test +def test_float32_cos(): + assert math.cos(0.0f32) == 1.0f32 + assert close32(math.cos(math.pi32 / 2.0f32), 6.123233995736766e-17f32) + assert close32(math.cos(-math.pi32 / 2.0f32), 6.123233995736766e-17f32) + assert math.cos(math.pi32) == -1.0f32 + assert math.isnan(math.cos(INF32)) == True + assert math.isnan(math.cos(NINF32)) == True + assert math.isnan(math.cos(NAN32)) == True + + +@test +def test_float32_sin(): + assert math.sin(0.0f32) == 0.0f32 + assert math.sin(math.pi32 / 2.0f32) == 1.0f32 + assert math.sin(-math.pi32 / 2.0f32) == -1.0f32 + assert math.isnan(math.sin(INF32)) == True + assert math.isnan(math.sin(NINF32)) == True + assert math.isnan(math.sin(NAN32)) == True + + +@test +def test_float32_hypot(): + assert math.hypot(12.0f32, 5.0f32) == 13.0f32 + assert math.hypot(12.0f32 / 32.0f32, 5.0f32 / 32.0f32) == 13.0f32 / 32.0f32 + assert math.hypot(0.0f32, 0.0f32) == 0.0f32 + assert math.hypot(-3.0f32, 4.0f32) == 5.0f32 + assert math.hypot(3.0f32, 4.0f32) == 5.0f32 + + +@test +def test_float32_tan(): + assert math.tan(0.0f32) == 0.0f32 + assert close32(math.tan(math.pi32 / 4.0f32), 0.9999999999999999f32) + assert close32(math.tan(-math.pi32 / 4.0f32), -0.9999999999999999f32) + assert math.isnan(math.tan(INF32)) == True + assert math.isnan(math.tan(NINF32)) == True + assert math.isnan(math.tan(NAN32)) == True + + +@test +def test_float32_cosh(): + assert math.cosh(0.0f32) == 1.0f32 + assert close32(math.cosh(2.0f32) - 2.0f32 * math.cosh(1.0f32) ** 2.0f32, -1.0f32) + assert math.cosh(INF32) == INF32 + 
assert math.cosh(NINF32) == INF32 + assert math.isnan(math.cosh(NAN32)) == True + + +@test +def test_float32_sinh(): + assert math.sinh(0.0f32) == 0.0f32 + assert math.sinh(1.0f32) + math.sinh(-1.0f32) == 0.0f32 + assert math.sinh(INF32) == INF32 + assert math.sinh(NINF32) == NINF32 + assert math.isnan(math.sinh(NAN32)) == True + + +@test +def test_float32_tanh(): + assert math.tanh(0.0f32) == 0.0f32 + assert math.tanh(1.0f32) + math.tanh(-1.0f32) == 0.0f32 + assert math.tanh(INF32) == 1.0f32 + assert math.tanh(NINF32) == -1.0f32 + assert math.isnan(math.tanh(NAN32)) == True + + +@test +def test_float32_acosh(): + assert math.acosh(1.0f32) == 0.0f32 + assert close32(math.acosh(2.0f32), 1.3169578969248166f32) + assert math.acosh(INF32) == INF32 + assert math.isnan(math.acosh(NAN32)) == True + assert math.isnan(math.acosh(-1.0f32)) == True + + +@test +def test_float32_asinh(): + assert math.asinh(0.0f32) == 0.0f32 + assert close32(math.asinh(1.0f32), 0.881373587019543f32) + assert close32(math.asinh(-1.0f32), -0.881373587019543f32) + assert math.asinh(INF32) == INF32 + assert math.isnan(math.asinh(NAN32)) == True + assert math.asinh(NINF32) == NINF32 + + +@test +def test_float32_atanh(): + assert math.atanh(0.0f32) == 0.0f32 + assert close32(math.atanh(0.5f32), 0.5493061443340549f32) + assert close32(math.atanh(-0.5f32), -0.5493061443340549f32) + assert math.isnan(math.atanh(INF32)) == True + assert math.isnan(math.atanh(NAN32)) == True + assert math.isnan(math.atanh(NINF32)) == True + + +@test +def test_float32_copysign(): + assert math.copysign(1.0f32, -0.0f32) == -1.0f32 + assert math.copysign(1.0f32, 42.0f32) == 1.0f32 + assert math.copysign(1.0f32, -42.0f32) == -1.0f32 + assert math.copysign(3.0f32, 0.0f32) == 3.0f32 + assert math.copysign(INF32, 0.0f32) == INF32 + assert math.copysign(INF32, -0.0f32) == NINF32 + assert math.copysign(NINF32, 0.0f32) == INF32 + assert math.copysign(NINF32, -0.0f32) == NINF32 + assert math.copysign(1.0f32, INF32) == 1.0f32 + 
assert math.copysign(1.0f32, NINF32) == -1.0f32 + assert math.copysign(INF32, INF32) == INF32 + assert math.copysign(INF32, NINF32) == NINF32 + assert math.copysign(NINF32, INF32) == INF32 + assert math.copysign(NINF32, NINF32) == NINF32 + assert math.isnan(math.copysign(NAN32, 1.0f32)) == True + assert math.isnan(math.copysign(NAN32, INF32)) == True + assert math.isnan(math.copysign(NAN32, NINF32)) == True + assert math.isnan(math.copysign(NAN32, NAN32)) == True + + +@test +def test_float32_log1p(): + assert close32(math.log1p(2.0f32), 1.0986122886681098f32) + assert close32(math.log1p(2.0f32 ** 90.0f32), 62.383246250395075f32) + assert math.log1p(INF32) == INF32 + assert math.log1p(-1.0f32) == NINF32 + + +@test +def test_float32_trunc(): + assert math.trunc(1.0f32) == 1.0f32 + assert math.trunc(-1.0f32) == -1.0f32 + assert math.trunc(1.5f32) == 1.0f32 + assert math.trunc(-1.5f32) == -1.0f32 + assert math.trunc(1.99999f32) == 1.0f32 + assert math.trunc(-1.99999f32) == -1.0f32 + assert math.trunc(0.99999f32) == 0.0f32 + assert math.trunc(-100.999f32) == -100.0f32 + + +@test +def test_float32_erf(): + assert close32(math.erf(1.0f32), 0.8427007929497148f32) + assert math.erf(0.0f32) == 0.0f32 + assert close32(math.erf(3.0f32), 0.9999779095030015f32) + assert math.erf(256.0f32) == 1.0f32 + assert math.erf(INF32) == 1.0f32 + assert math.erf(NINF32) == -1.0f32 + assert math.isnan(math.erf(NAN32)) == True + + +@test +def test_float32_erfc(): + assert math.erfc(0.0f32) == 1.0f32 + assert close32(math.erfc(1.0f32), 0.15729920705028516f32) + assert close32(math.erfc(2.0f32), 0.0046777349810472645f32) + assert close32(math.erfc(-1.0f32), 1.8427007929497148f32) + assert math.erfc(INF32) == 0.0f32 + assert math.erfc(NINF32) == 2.0f32 + assert math.isnan(math.erfc(NAN32)) == True + + +@test +def test_float32_gamma(): + assert close32(math.gamma(6.0f32), 120.0f32) + assert close32(math.gamma(1.0f32), 1.0f32) + assert close32(math.gamma(2.0f32), 1.0f32) + assert 
close32(math.gamma(3.0f32), 2.0f32) + assert math.isnan(math.gamma(-1.0f32)) == True + assert math.gamma(INF32) == INF32 + assert math.isnan(math.gamma(NINF32)) == True + assert math.isnan(math.gamma(NAN32)) == True + + +@test +def test_float32_lgamma(): + assert math.lgamma(1.0f32) == 0.0f32 + assert math.lgamma(2.0f32) == 0.0f32 + assert math.lgamma(-1.0f32) == INF32 + assert math.lgamma(INF32) == INF32 + assert math.lgamma(NINF32) == INF32 + assert math.isnan(math.lgamma(NAN32)) == True + + +@test +def test_float32_remainder(): + assert math.remainder(2.0f32, 2.0f32) == 0.0f32 + assert math.remainder(-4.0f32, 1.0f32) == -0.0f32 + assert close32(math.remainder(-3.8f32, 1.0f32), 0.20000000000000018f32) + assert close32(math.remainder(3.8f32, 1.0f32), -0.20000000000000018f32) + assert math.isnan(math.remainder(INF32, 1.0f32)) == True + assert math.isnan(math.remainder(NINF32, 1.0f32)) == True + assert math.isnan(math.remainder(NAN32, 1.0f32)) == True + + +@test +def test_float32_gcd(): + assert math.gcd(0.0f32, 0.0f32) == 0.0f32 + assert math.gcd(1.0f32, 0.0f32) == 1.0f32 + assert math.gcd(-1.0f32, 0.0f32) == 1.0f32 + assert math.gcd(0.0f32, -1.0f32) == 1.0f32 + assert math.gcd(0.0f32, 1.0f32) == 1.0f32 + assert math.gcd(7.0f32, 1.0f32) == 1.0f32 + assert math.gcd(7.0f32, -1.0f32) == 1.0f32 + assert math.gcd(-23.0f32, 15.0f32) == 1.0f32 + assert math.gcd(120.0f32, 84.0f32) == 12.0f32 + assert math.gcd(84.0f32, -120.0f32) == 12.0f32 + + +@test +def test_float32_frexp(): + assert math.frexp(-2.0f32) == (-0.5f32, 2) + assert math.frexp(-1.0f32) == (-0.5f32, 1) + assert math.frexp(0.0f32) == (0.0f32, 0) + assert math.frexp(1.0f32) == (0.5f32, 1) + assert math.frexp(2.0f32) == (0.5f32, 2) + assert math.frexp(INF32)[0] == INF32 + assert math.frexp(NINF32)[0] == NINF32 + assert math.isnan(math.frexp(NAN32)[0]) == True + + +@test +def test_float32_modf(): + assert math.modf(1.5f32) == (0.5f32, 1.0f32) + assert math.modf(-1.5f32) == (-0.5f32, -1.0f32) + assert 
math.modf(math.inf32) == (0.0f32, INF32) + assert math.modf(-math.inf32) == (-0.0f32, NINF32) + modf_nan = math.modf(NAN32) + assert math.isnan(modf_nan[0]) == True + assert math.isnan(modf_nan[1]) == True + + +@test +def test_float32_isclose(): + assert math.isclose(1.0f32 + 1.0f32, 1.000000000001f32 + 1.0f32) == True + assert math.isclose(2.90909324093284f32, 2.909093240932844234234234234f32) == True + assert math.isclose(2.90909324093284f32, 2.9f32) == False + assert math.isclose(2.90909324093284f32, 2.90909324f32) == True + assert math.isclose(2.909094f32, 2.909095f32) == False + assert math.isclose(NAN32, 2.9f32) == False + assert math.isclose(2.9f32, NAN32) == False + assert math.isclose(INF32, INF32) == True + assert math.isclose(NINF32, NINF32) == True + assert math.isclose(NINF32, INF32) == False + assert math.isclose(INF32, NINF32) == False + + +test_float32_isnan() +test_float32_isinf() +test_float32_isfinite() +test_float32_ceil() +test_float32_floor() +test_float32_fabs() +test_float32_fmod() +test_float32_exp() +test_float32_expm1() +test_float32_ldexp() +test_float32_log() +test_float32_log2() +test_float32_log10() +test_float32_degrees() +test_float32_radians() +test_float32_sqrt() +test_float32_pow() +test_float32_acos() +test_float32_asin() +test_float32_atan() +test_float32_atan2() +test_float32_cos() +test_float32_sin() +test_float32_hypot() +test_float32_tan() +test_float32_cosh() +test_float32_sinh() +test_float32_tanh() +test_float32_acosh() +test_float32_asinh() +test_float32_atanh() +test_float32_copysign() +test_float32_log1p() +test_float32_trunc() +test_float32_erf() +test_float32_erfc() +test_float32_gamma() +test_float32_lgamma() +test_float32_remainder() +test_float32_gcd() +test_float32_frexp() +test_float32_modf() +test_float32_isclose() diff --git a/test/stdlib/random_test.codon b/test/stdlib/random_test.codon index 4d2954c3..61ed60fe 100644 --- a/test/stdlib/random_test.codon +++ b/test/stdlib/random_test.codon @@ -1,6 +1,7 @@ 
import random as R import time import sys +from copy import copy seed = int(time.time()) # sys.stderr.write('seed: ' + str(seed) + '\n') diff --git a/test/transform/canonical.codon b/test/transform/canonical.codon index c6051522..3e288c3b 100644 --- a/test/transform/canonical.codon +++ b/test/transform/canonical.codon @@ -198,16 +198,3 @@ def test_add_mul_canon(): x = f(100.) # don't distribute float ops assert (x * 0.1) + (x * 0.2) == 30. test_add_mul_canon() - -@test -def test_const_sub(): - a = Vec(1,1) - b = a - -1 - assert (b.x, b.y) == (2, 2) - # EXPECT: vec add (x: 1, y: 1) 1 - - x = Vec(1.5, 1.5) - y = x - .5 - assert (y.x, y.y) == (1., 1.) - # EXPECT: vec add (x: 1.5, y: 1.5) -0.5 -test_const_sub() diff --git a/test/transform/kernels.codon b/test/transform/kernels.codon new file mode 100644 index 00000000..feac1f32 --- /dev/null +++ b/test/transform/kernels.codon @@ -0,0 +1,261 @@ +import gpu + +@test +def test_hello_world(): + @gpu.kernel + def kernel(a, b, c): + i = gpu.thread.x + c[i] = a[i] + b[i] + + a = [i for i in range(16)] + b = [2*i for i in range(16)] + c = [0 for _ in range(16)] + kernel(a, b, c, grid=1, block=16) + + assert c == [3*i for i in range(16)] + +@test +def test_raw(): + @gpu.kernel + def kernel(a, b, c): + i = gpu.thread.x + c[i] = a[i] + b[i] + + a = [i for i in range(16)] + b = [2*i for i in range(16)] + c = [0 for _ in range(16)] + kernel(gpu.raw(a), gpu.raw(b), gpu.raw(c), grid=1, block=16) + + assert c == [3*i for i in range(16)] + +@test +def test_conversions(): + @gpu.kernel + def kernel(x, v): + v[0] = x + + def check(x): + T = type(x) + v = [T()] + kernel(x, v, grid=1, block=1) + return v == [x] + + assert check(None) + assert check(42) + assert check(3.14) + assert check(f32(2.718)) + assert check(byte(99)) + assert check(Int[128](123123)) + assert check(UInt[128](321321)) + assert check(Optional[int]()) + assert check(Optional(111)) + assert check((1, 2, 3)) + assert check(([1], [2], [3])) + assert check('hello world') + 
assert check([1, 2, 3]) + assert check([[1], [2], [3]]) + assert check({1: [1.1], 2: [2.2]}) + assert check({'a', 'b', 'c'}) + assert check(Optional([1, 2, 3])) + +@test +def test_user_classes(): + @dataclass(gpu=True, eq=True) + class A: + x: int + y: List[int] + + @tuple + class B: + x: int + y: List[int] + + @gpu.kernel + def kernel(a, b, c): + a.x += b.x + c[0] + c[1][0][0] = 9999 + a.y[0] = c[0] + 1 + b.y[0] = c[0] + 2 + + a = A(42, [-1]) + b = B(100, [-2]) + c = (1000, [[-1]]) + kernel(a, b, c, grid=1, block=1) + + assert a == A(1142, [1001]) + assert b == B(100, [1002]) + assert c == (1000, [[9999]]) + + @gpu.kernel + def kernel2(a, b, c): + a[0].x += b[0].x + c[0][0] + c[0][1][0][0] = 9999 + a[0].y[0] = c[0][0] + 1 + b[0].y[0] = c[0][0] + 2 + + a = [A(42, [-1])] + b = [B(100, [-2])] + c = [(1000, [[-1]])] + kernel2(a, b, c, grid=1, block=1) + + assert a == [A(1142, [1001])] + assert b == [B(100, [1002])] + assert c == [(1000, [[9999]])] + +@test +def test_intrinsics(): + @gpu.kernel + def kernel(v): + block_id = (gpu.block.x + gpu.block.y*gpu.grid.dim.x + + gpu.block.z*gpu.grid.dim.x*gpu.grid.dim.y) + thread_id = (block_id*gpu.block.dim.x*gpu.block.dim.y*gpu.block.dim.z + + gpu.thread.z*gpu.block.dim.x*gpu.block.dim.y + + gpu.thread.y*gpu.block.dim.x + + gpu.thread.x) + v[thread_id] = thread_id + gpu.syncthreads() + + grid = gpu.Dim3(3, 4, 5) + block = gpu.Dim3(6, 7, 8) + N = grid.x * grid.y * grid.z * block.x * block.y * block.z + v = [0 for _ in range(N)] + kernel(v, grid=grid, block=block) + assert v == list(range(N)) + +@test +def test_matmul(): + A = [[12, 7, 3], + [4, 5, 6], + [7, 8, 9]] + + B = [[5, 8, 1, 2], + [6, 7, 3, 0], + [4, 5, 9, 1]] + + def mmz(A, B): + return [[0]*len(B[0]) for _ in range(len(A))] + + def matmul(A, B): + result = mmz(A, B) + for i in range(len(A)): + for j in range(len(B[0])): + for k in range(len(B)): + result[i][j] += A[i][k] * B[k][j] + return result + + expected = matmul(A, B) + + @gpu.kernel + def kernel(A, B, result): 
+ i = gpu.thread.x + j = gpu.thread.y + result[i][j] = sum(A[i][k]*B[k][j] for k in range(len(A[0]))) + + result = mmz(A, B) + kernel(A, B, result, grid=1, block=(len(result), len(result[0]))) + assert result == expected + +MAX = 1000 # maximum Mandelbrot iterations +N = 256 # width and height of image + +@test +def test_mandelbrot(): + pixels = [0 for _ in range(N * N)] + + def scale(x, a, b): + return a + (x/N)*(b - a) + + expected = [0 for _ in range(N * N)] + for i in range(N): + for j in range(N): + c = complex(scale(j, -2.00, 0.47), scale(i, -1.12, 1.12)) + z = 0j + iteration = 0 + + while abs(z) <= 2 and iteration < MAX: + z = z**2 + c + iteration += 1 + + expected[N*i + j] = int(255 * iteration/MAX) + + @gpu.kernel + def kernel(pixels): + idx = (gpu.block.x * gpu.block.dim.x) + gpu.thread.x + i, j = divmod(idx, N) + c = complex(scale(j, -2.00, 0.47), scale(i, -1.12, 1.12)) + z = 0j + iteration = 0 + + while abs(z) <= 2 and iteration < MAX: + z = z**2 + c + iteration += 1 + + pixels[idx] = int(255 * iteration/MAX) + + kernel(pixels, grid=(N*N)//1024, block=1024) + assert pixels == expected + +@test +def test_kitchen_sink(): + @gpu.kernel + def kernel(x): + i = gpu.thread.x + d = {1: 2.1, 2: 3.5, 3: 4.2} + s = {4, 5, 6} + z = sum( + d.get(x[i], j) + (j if i in s else -j) + for j in range(i) + ) + x[i] = int(z) + + x = [i for i in range(16)] + kernel(x, grid=1, block=16) + assert x == [0, 2, 6, 9, 12, 20, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0] + +@test +def test_auto_par(): + a = [i for i in range(16)] + b = [2*i for i in range(16)] + c = [0 for _ in range(16)] + + @par(gpu=True) + for i in range(16): + c[i] = a[i] + b[i] + + assert c == [3*i for i in range(16)] + + @par(gpu=True) + for i in range(16): + c[i] += a[i] + b[i] + + assert c == [6*i for i in range(16)] + + N = 200 + Z = 42 + x = [0] * (N*N) + y = [0] * (N*N) + + for i in range(2, N - 1, 3): + for j in range(3, N, 2): + x[i*N + j] = i + j + Z + + @par(gpu=True, collapse=2) + for i in range(2, N - 1, 3): + 
for j in range(3, N, 2): + y[i*N + j] = i + j + Z + + assert x == y + + @par(gpu=True) + for i in range(1): + pass + +test_hello_world() +test_raw() +test_conversions() +test_user_classes() +test_intrinsics() +test_matmul() +test_mandelbrot() +test_kitchen_sink() +test_auto_par() diff --git a/test/transform/omp.codon b/test/transform/omp.codon index d15fd6cb..6626594d 100644 --- a/test/transform/omp.codon +++ b/test/transform/omp.codon @@ -450,6 +450,35 @@ def test_omp_reductions(): c = min(b, c) assert c == -1. + # float32s + c = f32(0.) + # this one can give different results due to + # non-commutativity of floats; so limit to 1001 + @par + for i in L[1:1001]: + c += f32(i) + assert c == sum(f32(i) for i in range(1001)) + + c = f32(1.) + @par + for i in L[1:10]: + c *= f32(i) + assert c == f32(1*2*3*4*5*6*7*8*9) + + c = f32(0.) + @par + for i in L: + b = f32(N+1 if i == N//2 else i) + c = max(b, c) + assert c == f32(N+1) + + c = f32(0.) + @par + for i in L: + b = f32(-1 if i == N//2 else i) + c = min(b, c) + assert c == f32(-1.) + x_add = 10. 
x_min = inf x_max = -inf @@ -728,6 +757,138 @@ def test_omp_corner_cases(): for i in squares(10): i += i +@test +def test_omp_collapse(): + # trivial + A0 = [] + B0 = [] + + for i in range(10): + A0.append(i) + + @par(num_threads=4, collapse=1) + for i in range(10): + with lock: + B0.append(i) + + assert sorted(A0) == sorted(B0) + + # basic + A1 = [] + B1 = [] + + for i in range(10): + for j in range(10): + A1.append((i,j)) + + @par(num_threads=4, collapse=2) + for i in range(10): + for j in range(10): + with lock: + B1.append((i,j)) + + assert sorted(A1) == sorted(B1) + + # deep + A2 = [] + B2 = [] + + for a in range(3): + for b in range(4): + for c in range(5): + for d in range(6): + A2.append((a,b,c,d)) + + @par(num_threads=4, collapse=4) + for a in range(3): + for b in range(4): + for c in range(5): + for d in range(6): + with lock: + B2.append((a,b,c,d)) + + assert sorted(A2) == sorted(B2) + + # ranges 1 + A3 = [] + B3 = [] + + for a in range(-5,5,2): + for b in range(5,-7,-2): + for c in range(0,17,3): + for d in range(5): + A3.append((a,b,c,d)) + + @par(num_threads=4, collapse=4) + for a in range(-5,5,2): + for b in range(5,-7,-2): + for c in range(0,17,3): + for d in range(5): + with lock: + B3.append((a,b,c,d)) + + assert sorted(A3) == sorted(B3) + + # ranges 2 + A4 = [] + B4 = [] + + for i in range(10): + for j in range(7,-5,-2): + for k in range(-5,10,3): + A4.append((i,j,k)) + + @par(num_threads=4, collapse=3) + for i in range(10): + for j in range(7,-5,-2): + for k in range(-5,10,3): + with lock: + B4.append((i,j,k)) + + assert sorted(A4) == sorted(B4) + + # zero + B5 = [] + + @noinline + def zstart(): + return 5 + + @noinline + def zstop(): + return -5 + + start = zstart() + stop = zstop() + + @par(num_threads=4, collapse=3) + for i in range(10): + for j in range(start, stop, 1): + for k in range(-5,10,3): + with lock: + B5.append((i,j,k)) + + assert len(B5) == 0 + + # order + A6 = [] + B6 = [] + + for a in range(-5,5,2): + for b in range(5,-7,-2): + 
for c in range(0,17,3): + for d in range(5): + A6.append((a,b,c,d)) + + @par(num_threads=1, collapse=4) + for a in range(-5,5,2): + for b in range(5,-7,-2): + for c in range(0,17,3): + for d in range(5): + B6.append((a,b,c,d)) # no lock since threads=1 + + assert A6 == B6 + test_omp_api() test_omp_schedules() test_omp_ranges() @@ -739,3 +900,4 @@ test_omp_transform(111, 222, 333) test_omp_transform(111.1, 222.2, 333.3) test_omp_nested() test_omp_corner_cases() +test_omp_collapse()