mirror of https://github.com/exaloop/codon.git
GPU and other updates (#52)
* Add nvptx pass * Fix spaces * Don't change name * Add runtime support * Add init call * Add more runtime functions * Add launch function * Add intrinsics * Fix codegen * Run GPU pass between general opt passes * Set data layout * Create context * Link libdevice * Add function remapping * Fix linkage * Fix libdevice link * Fix linking * Fix personality * Fix linking * Fix linking * Fix linking * Add internalize pass * Add more math conversions * Add more re-mappings * Fix conversions * Fix __str__ * Add decorator attribute for any decorator * Update kernel decorator * Fix kernel decorator * Fix kernel decorator * Fix kernel decorator * Fix kernel decorator * Remove old decorator * Fix pointer calc * Fix fill-in codegen * Fix linkage * Add comment * Update list conversion * Add more conversions * Add dict and set conversions * Add float32 type to IR/LLVM * Add float32 * Add float32 stdlib * Keep required global values in PTX module * Fix PTX module pruning * Fix malloc * Set will-return * Fix name cleanup * Fix access * Fix name cleanup * Fix function renaming * Update dimension API * Fix args * Clean up API * Move GPU transformations to end of opt pipeline * Fix alloc replacements * Fix naming * Target PTX 4.2 * Fix global renaming * Fix early return in static blocks; Add __realized__ function * Format * Add __llvm_name__ for functions * Add vector type to IR * SIMD support [wip] * Update kernel naming * Fix early returns; Fix SIMD calls * Fix kernel naming * Fix IR matcher * Remove module print * Update realloc * Add overloads for 32-bit float math ops * Add gpu.Pointer type for working with raw pointers * Add float32 conversion * Add to_gpu and from_gpu * clang-format * Add f32 reduction support to OpenMP * Fix automatic GPU class conversions * Fix conversion functions * Fix conversions * Rename self * Fix tuple conversion * Fix conversions * Fix conversions * Update PTX filename * Fix filename * Add raw function * Add GPU docs * Allow nested object conversions 
* Add tests (WIP) * Update SIMD * Add staticrange and statictuple loop support * SIMD updates * Add new Vec constructors * Fix UInt conversion * Fix size-0 allocs * Add more tests * Add matmul test * Rename gpu test file * Add more tests * Add alloc cache * Fix object_to_gpu * Fix frees * Fix str conversion * Fix set conversion * Fix conversions * Fix class conversion * Fix str conversion * Fix byte conversion * Fix list conversion * Fix pointer conversions * Fix conversions * Fix conversions * Update tests * Fix conversions * Fix tuple conversion * Fix tuple conversion * Fix auto conversions * Fix conversion * Fix magics * Update tests * Support GPU in JIT mode * Fix GPU+JIT * Fix kernel filename in JIT mode * Add __static_print__; Add earlyDefines; Various domination bugfixes; SimplifyContext RAII base handling * Fix global static handling * Fix float32 tests * FIx gpu module * Support OpenMP "collapse" option * Add more collapse tests * Capture generics and statics * TraitVar handling * Python exceptions / isinstance [wip; no_ci] * clang-format * Add list comparison operators * Support empty raise in IR * Add dict 'or' operator * Fix repr * Add copy module * Fix spacing * Use sm_30 * Python exceptions * TypeTrait support; Fix defaultDict * Fix earlyDefines * Add defaultdict * clang-format * Fix invalid canonicalizations * Fix empty raise * Fix copyright * Add Python numerics option * Support py-numerics in math module * Update docs * Add static Python division / modulus * Add static py numerics tests * Fix staticrange/tuple; Add KwTuple.__getitem__ * clang-format * Add gpu parameter to par * Fix globals * Don't init loop vars on loop collapse * Add par-gpu tests * Update gpu docs * Fix isinstance check * Remove invalid test * Add -libdevice to set custom path [skip ci] * Add release notes; bump version [skip ci] * Add libdevice docs [skip ci] Co-authored-by: Ibrahim Numanagić <ibrahimpasa@gmail.com>pull/48/head^2
parent
3379c064eb
commit
ebd344f894
codon
app
compiler
parser
docs
stdlib
|
@ -1,7 +1,7 @@
|
|||
cmake_minimum_required(VERSION 3.14)
|
||||
project(
|
||||
Codon
|
||||
VERSION "0.13.0"
|
||||
VERSION "0.14.0"
|
||||
HOMEPAGE_URL "https://github.com/exaloop/codon"
|
||||
DESCRIPTION "high-performance, extensible Python compiler")
|
||||
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in"
|
||||
|
@ -10,6 +10,7 @@ configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in"
|
|||
"${PROJECT_SOURCE_DIR}/extra/python/config/config.py")
|
||||
|
||||
option(CODON_JUPYTER "build Codon Jupyter server" OFF)
|
||||
option(CODON_GPU "build Codon GPU backend" OFF)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||
|
@ -59,7 +60,8 @@ add_custom_command(
|
|||
|
||||
# Codon runtime library
|
||||
set(CODONRT_FILES codon/runtime/lib.h codon/runtime/lib.cpp
|
||||
codon/runtime/re.cpp codon/runtime/exc.cpp)
|
||||
codon/runtime/re.cpp codon/runtime/exc.cpp
|
||||
codon/runtime/gpu.cpp)
|
||||
add_library(codonrt SHARED ${CODONRT_FILES})
|
||||
add_dependencies(codonrt zlibstatic gc backtrace bz2 liblzma re2)
|
||||
target_include_directories(codonrt PRIVATE ${backtrace_SOURCE_DIR}
|
||||
|
@ -90,6 +92,11 @@ if(ASAN)
|
|||
codonrt PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address"
|
||||
"-fsanitize-recover=address")
|
||||
endif()
|
||||
if(CODON_GPU)
|
||||
add_compile_definitions(CODON_GPU)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
target_link_libraries(codonrt PRIVATE CUDA::cudart_static CUDA::cuda_driver)
|
||||
endif()
|
||||
add_custom_command(
|
||||
TARGET codonrt
|
||||
POST_BUILD
|
||||
|
@ -148,6 +155,7 @@ set(CODON_HPPFILES
|
|||
codon/sir/llvm/coro/CoroInternal.h
|
||||
codon/sir/llvm/coro/CoroSplit.h
|
||||
codon/sir/llvm/coro/Coroutines.h
|
||||
codon/sir/llvm/gpu.h
|
||||
codon/sir/llvm/llvisitor.h
|
||||
codon/sir/llvm/llvm.h
|
||||
codon/sir/llvm/optimize.h
|
||||
|
@ -295,6 +303,7 @@ set(CODON_CPPFILES
|
|||
codon/sir/llvm/coro/CoroFrame.cpp
|
||||
codon/sir/llvm/coro/CoroSplit.cpp
|
||||
codon/sir/llvm/coro/Coroutines.cpp
|
||||
codon/sir/llvm/gpu.cpp
|
||||
codon/sir/llvm/llvisitor.cpp
|
||||
codon/sir/llvm/optimize.cpp
|
||||
codon/sir/module.cpp
|
||||
|
|
|
@ -72,6 +72,7 @@ void initLogFlags(const llvm::cl::opt<std::string> &log) {
|
|||
|
||||
enum BuildKind { LLVM, Bitcode, Object, Executable, Library, Detect };
|
||||
enum OptMode { Debug, Release };
|
||||
enum Numerics { C, Python };
|
||||
} // namespace
|
||||
|
||||
int docMode(const std::vector<const char *> &args, const std::string &argv0) {
|
||||
|
@ -116,6 +117,14 @@ std::unique_ptr<codon::Compiler> processSource(const std::vector<const char *> &
|
|||
llvm::cl::list<std::string> plugins("plugin",
|
||||
llvm::cl::desc("Load specified plugin"));
|
||||
llvm::cl::opt<std::string> log("log", llvm::cl::desc("Enable given log streams"));
|
||||
llvm::cl::opt<Numerics> numerics(
|
||||
"numerics", llvm::cl::desc("numerical semantics"),
|
||||
llvm::cl::values(
|
||||
clEnumValN(C, "c", "C semantics: best performance but deviates from Python"),
|
||||
clEnumValN(Python, "py",
|
||||
"Python semantics: mirrors Python but might disable optimizations "
|
||||
"like vectorization")),
|
||||
llvm::cl::init(C));
|
||||
|
||||
llvm::cl::ParseCommandLineOptions(args.size(), args.data());
|
||||
initLogFlags(log);
|
||||
|
@ -148,7 +157,9 @@ std::unique_ptr<codon::Compiler> processSource(const std::vector<const char *> &
|
|||
|
||||
const bool isDebug = (optMode == OptMode::Debug);
|
||||
std::vector<std::string> disabledOptsVec(disabledOpts);
|
||||
auto compiler = std::make_unique<codon::Compiler>(args[0], isDebug, disabledOptsVec);
|
||||
auto compiler = std::make_unique<codon::Compiler>(args[0], isDebug, disabledOptsVec,
|
||||
/*isTest=*/false,
|
||||
(numerics == Numerics::Python));
|
||||
compiler->getLLVMVisitor()->setStandalone(standalone);
|
||||
|
||||
// load plugins
|
||||
|
|
|
@ -29,15 +29,17 @@ ir::transform::PassManager::Init getPassManagerInit(Compiler::Mode mode, bool is
|
|||
} // namespace
|
||||
|
||||
Compiler::Compiler(const std::string &argv0, Compiler::Mode mode,
|
||||
const std::vector<std::string> &disabledPasses, bool isTest)
|
||||
: argv0(argv0), debug(mode == Mode::DEBUG), input(),
|
||||
const std::vector<std::string> &disabledPasses, bool isTest,
|
||||
bool pyNumerics)
|
||||
: argv0(argv0), debug(mode == Mode::DEBUG), pyNumerics(pyNumerics), input(),
|
||||
plm(std::make_unique<PluginManager>()),
|
||||
cache(std::make_unique<ast::Cache>(argv0)),
|
||||
module(std::make_unique<ir::Module>()),
|
||||
pm(std::make_unique<ir::transform::PassManager>(getPassManagerInit(mode, isTest),
|
||||
disabledPasses)),
|
||||
disabledPasses, pyNumerics)),
|
||||
llvisitor(std::make_unique<ir::LLVMVisitor>()) {
|
||||
cache->module = module.get();
|
||||
cache->pythonCompat = pyNumerics;
|
||||
module->setCache(cache.get());
|
||||
llvisitor->setDebug(debug);
|
||||
llvisitor->setPluginManager(plm.get());
|
||||
|
@ -77,8 +79,9 @@ Compiler::parse(bool isCode, const std::string &file, const std::string &code,
|
|||
|
||||
Timer t2("simplify");
|
||||
t2.logged = true;
|
||||
auto transformed = ast::SimplifyVisitor::apply(cache.get(), std::move(codeStmt),
|
||||
abspath, defines, (testFlags > 1));
|
||||
auto transformed =
|
||||
ast::SimplifyVisitor::apply(cache.get(), std::move(codeStmt), abspath, defines,
|
||||
getEarlyDefines(), (testFlags > 1));
|
||||
LOG_TIME("[T] parse = {:.1f}", totalPeg);
|
||||
LOG_TIME("[T] simplify = {:.1f}", t2.elapsed() - totalPeg);
|
||||
|
||||
|
@ -171,4 +174,11 @@ llvm::Expected<std::string> Compiler::docgen(const std::vector<std::string> &fil
|
|||
}
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::string> Compiler::getEarlyDefines() {
|
||||
std::unordered_map<std::string, std::string> earlyDefines;
|
||||
earlyDefines.emplace("__debug__", debug ? "1" : "0");
|
||||
earlyDefines.emplace("__py_numerics__", pyNumerics ? "1" : "0");
|
||||
return earlyDefines;
|
||||
}
|
||||
|
||||
} // namespace codon
|
||||
|
|
|
@ -25,6 +25,7 @@ public:
|
|||
private:
|
||||
std::string argv0;
|
||||
bool debug;
|
||||
bool pyNumerics;
|
||||
std::string input;
|
||||
std::unique_ptr<PluginManager> plm;
|
||||
std::unique_ptr<ast::Cache> cache;
|
||||
|
@ -38,12 +39,14 @@ private:
|
|||
|
||||
public:
|
||||
Compiler(const std::string &argv0, Mode mode,
|
||||
const std::vector<std::string> &disabledPasses = {}, bool isTest = false);
|
||||
const std::vector<std::string> &disabledPasses = {}, bool isTest = false,
|
||||
bool pyNumerics = false);
|
||||
|
||||
explicit Compiler(const std::string &argv0, bool debug = false,
|
||||
const std::vector<std::string> &disabledPasses = {},
|
||||
bool isTest = false)
|
||||
: Compiler(argv0, debug ? Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest) {}
|
||||
bool isTest = false, bool pyNumerics = false)
|
||||
: Compiler(argv0, debug ? Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest,
|
||||
pyNumerics) {}
|
||||
|
||||
std::string getInput() const { return input; }
|
||||
PluginManager *getPluginManager() const { return plm.get(); }
|
||||
|
@ -62,6 +65,8 @@ public:
|
|||
const std::unordered_map<std::string, std::string> &defines = {});
|
||||
llvm::Error compile();
|
||||
llvm::Expected<std::string> docgen(const std::vector<std::string> &files);
|
||||
|
||||
std::unordered_map<std::string, std::string> getEarlyDefines();
|
||||
};
|
||||
|
||||
} // namespace codon
|
||||
|
|
|
@ -37,8 +37,9 @@ llvm::Error JIT::init() {
|
|||
auto *pm = compiler->getPassManager();
|
||||
auto *llvisitor = compiler->getLLVMVisitor();
|
||||
|
||||
auto transformed = ast::SimplifyVisitor::apply(
|
||||
cache, std::make_shared<ast::SuiteStmt>(), JIT_FILENAME, {});
|
||||
auto transformed =
|
||||
ast::SimplifyVisitor::apply(cache, std::make_shared<ast::SuiteStmt>(),
|
||||
JIT_FILENAME, {}, compiler->getEarlyDefines());
|
||||
|
||||
auto typechecked = ast::TypecheckVisitor::apply(cache, std::move(transformed));
|
||||
ast::TranslateVisitor::apply(cache, std::move(typechecked));
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "codon/parser/ast.h"
|
||||
#include "codon/parser/cache.h"
|
||||
#include "codon/parser/visitors/visitor.h"
|
||||
|
||||
#define ACCEPT_IMPL(T, X) \
|
||||
|
@ -17,7 +18,7 @@ namespace codon::ast {
|
|||
|
||||
Expr::Expr()
|
||||
: type(nullptr), isTypeExpr(false), staticValue(StaticValue::NOT_STATIC),
|
||||
done(false), attributes(0) {}
|
||||
done(false), attributes(0), origExpr(nullptr) {}
|
||||
void Expr::validate() const {}
|
||||
types::TypePtr Expr::getType() const { return type; }
|
||||
void Expr::setType(types::TypePtr t) { this->type = std::move(t); }
|
||||
|
@ -65,7 +66,8 @@ Param::Param(std::string name, ExprPtr type, ExprPtr defaultValue, int status)
|
|||
: name(std::move(name)), type(std::move(type)),
|
||||
defaultValue(std::move(defaultValue)) {
|
||||
if (status == 0 && this->type &&
|
||||
(this->type->isId("type") || this->type->isId("TypeVar") ||
|
||||
(this->type->isId("type") || this->type->isId(TYPE_TYPEVAR) ||
|
||||
(this->type->getIndex() && this->type->getIndex()->expr->isId(TYPE_TYPEVAR)) ||
|
||||
getStaticGeneric(this->type.get())))
|
||||
this->status = Generic;
|
||||
else
|
||||
|
|
|
@ -80,6 +80,9 @@ struct Expr : public codon::SrcObject {
|
|||
/// Set of attributes.
|
||||
int attributes;
|
||||
|
||||
/// Original (pre-transformation) expression
|
||||
std::shared_ptr<Expr> origExpr;
|
||||
|
||||
public:
|
||||
Expr();
|
||||
Expr(const Expr &expr) = default;
|
||||
|
|
|
@ -339,11 +339,15 @@ std::string FunctionStmt::toString(int indent) const {
|
|||
std::vector<std::string> as;
|
||||
for (auto &a : args)
|
||||
as.push_back(a.toString());
|
||||
std::vector<std::string> attr;
|
||||
std::vector<std::string> dec, attr;
|
||||
for (auto &a : decorators)
|
||||
attr.push_back(format("(dec {})", a->toString()));
|
||||
return format("(fn '{} ({}){}{}{}{})", name, join(as, " "),
|
||||
if (a)
|
||||
dec.push_back(format("(dec {})", a->toString()));
|
||||
for (auto &a : attributes.customAttr)
|
||||
attr.push_back(format("'{}'", a));
|
||||
return format("(fn '{} ({}){}{}{}{}{})", name, join(as, " "),
|
||||
ret ? " #:ret " + ret->toString() : "",
|
||||
dec.empty() ? "" : format(" (dec {})", join(dec, " ")),
|
||||
attr.empty() ? "" : format(" (attr {})", join(attr, " ")), pad,
|
||||
suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)
|
||||
: "(suite)");
|
||||
|
@ -493,10 +497,11 @@ void ClassStmt::parseDecorators() {
|
|||
// @extend
|
||||
|
||||
std::map<std::string, bool> tupleMagics = {
|
||||
{"new", true}, {"repr", false}, {"hash", false}, {"eq", false},
|
||||
{"ne", false}, {"lt", false}, {"le", false}, {"gt", false},
|
||||
{"ge", false}, {"pickle", true}, {"unpickle", true}, {"to_py", false},
|
||||
{"from_py", false}, {"iter", false}, {"getitem", false}, {"len", false}};
|
||||
{"new", true}, {"repr", false}, {"hash", false}, {"eq", false},
|
||||
{"ne", false}, {"lt", false}, {"le", false}, {"gt", false},
|
||||
{"ge", false}, {"pickle", true}, {"unpickle", true}, {"to_py", false},
|
||||
{"from_py", false}, {"iter", false}, {"getitem", false}, {"len", false},
|
||||
{"to_gpu", false}, {"from_gpu", false}, {"from_gpu_new", false}};
|
||||
|
||||
for (auto &d : decorators) {
|
||||
if (d->isId("deduce")) {
|
||||
|
@ -531,6 +536,9 @@ void ClassStmt::parseDecorators() {
|
|||
tupleMagics["pickle"] = tupleMagics["unpickle"] = val;
|
||||
} else if (a.name == "python") {
|
||||
tupleMagics["to_py"] = tupleMagics["from_py"] = val;
|
||||
} else if (a.name == "gpu") {
|
||||
tupleMagics["to_gpu"] = tupleMagics["from_gpu"] =
|
||||
tupleMagics["from_gpu_new"] = val;
|
||||
} else if (a.name == "container") {
|
||||
tupleMagics["iter"] = tupleMagics["getitem"] = val;
|
||||
} else {
|
||||
|
|
|
@ -238,6 +238,9 @@ struct WhileStmt : public Stmt {
|
|||
StmtPtr suite;
|
||||
/// nullptr if there is no else suite.
|
||||
StmtPtr elseSuite;
|
||||
/// Set if a while loop is used to emulate goto statement
|
||||
/// (as `while gotoVar: ...`).
|
||||
std::string gotoVar = "";
|
||||
|
||||
WhileStmt(ExprPtr cond, StmtPtr suite, StmtPtr elseSuite = nullptr);
|
||||
WhileStmt(const WhileStmt &stmt);
|
||||
|
|
|
@ -312,10 +312,9 @@ std::string ClassType::debugString(bool debug) const {
|
|||
if (!a.name.empty())
|
||||
gs.push_back(a.type->debugString(debug));
|
||||
if (debug && !hiddenGenerics.empty()) {
|
||||
gs.emplace_back("//");
|
||||
for (auto &a : hiddenGenerics)
|
||||
if (!a.name.empty())
|
||||
gs.push_back(a.type->debugString(debug));
|
||||
gs.push_back("-" + a.type->debugString(debug));
|
||||
}
|
||||
// Special formatting for Functions and Tuples
|
||||
auto n = niceName;
|
||||
|
@ -829,4 +828,21 @@ std::string CallableTrait::debugString(bool debug) const {
|
|||
return fmt::format("Callable[{}]", join(gs, ","));
|
||||
}
|
||||
|
||||
TypeTrait::TypeTrait(TypePtr typ) : type(std::move(typ)) {}
|
||||
int TypeTrait::unify(Type *typ, Unification *us) { return typ->unify(type.get(), us); }
|
||||
TypePtr TypeTrait::generalize(int atLevel) {
|
||||
auto c = std::make_shared<TypeTrait>(type->generalize(atLevel));
|
||||
c->setSrcInfo(getSrcInfo());
|
||||
return c;
|
||||
}
|
||||
TypePtr TypeTrait::instantiate(int atLevel, int *unboundCount,
|
||||
std::unordered_map<int, TypePtr> *cache) {
|
||||
auto c = std::make_shared<TypeTrait>(type->instantiate(atLevel, unboundCount, cache));
|
||||
c->setSrcInfo(getSrcInfo());
|
||||
return c;
|
||||
}
|
||||
std::string TypeTrait::debugString(bool debug) const {
|
||||
return fmt::format("Trait[{}]", type->debugString(debug));
|
||||
}
|
||||
|
||||
} // namespace codon::ast::types
|
||||
|
|
|
@ -422,5 +422,17 @@ public:
|
|||
std::string debugString(bool debug) const override;
|
||||
};
|
||||
|
||||
struct TypeTrait : public Trait {
|
||||
TypePtr type;
|
||||
|
||||
public:
|
||||
explicit TypeTrait(TypePtr type);
|
||||
int unify(Type *typ, Unification *undo) override;
|
||||
TypePtr generalize(int atLevel) override;
|
||||
TypePtr instantiate(int atLevel, int *unboundCount,
|
||||
std::unordered_map<int, TypePtr> *cache) override;
|
||||
std::string debugString(bool debug) const override;
|
||||
};
|
||||
|
||||
} // namespace types
|
||||
} // namespace codon::ast
|
||||
|
|
|
@ -31,6 +31,12 @@ std::string Cache::rev(const std::string &s) {
|
|||
return "";
|
||||
}
|
||||
|
||||
void Cache::addGlobal(const std::string &name, ir::Var *var) {
|
||||
if (!in(globals, name)) {
|
||||
globals[name] = var;
|
||||
}
|
||||
}
|
||||
|
||||
SrcInfo Cache::generateSrcInfo() {
|
||||
return {FILE_GENERATED, generatedSrcInfoCount, generatedSrcInfoCount++, 0};
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#define TYPE_TUPLE "Tuple.N"
|
||||
#define TYPE_KWTUPLE "KwTuple.N"
|
||||
#define TYPE_TYPEVAR "TypeVar"
|
||||
#define TYPE_CALLABLE "Callable"
|
||||
#define TYPE_PARTIAL "Partial.N"
|
||||
#define TYPE_OPTIONAL "Optional"
|
||||
|
@ -28,6 +29,7 @@
|
|||
|
||||
#define MAX_INT_WIDTH 10000
|
||||
#define MAX_REALIZATION_DEPTH 200
|
||||
#define MAX_STATIC_ITER 1024
|
||||
|
||||
namespace codon::ast {
|
||||
|
||||
|
@ -207,6 +209,9 @@ struct Cache : public std::enable_shared_from_this<Cache> {
|
|||
std::unordered_map<std::string, int> generatedTuples;
|
||||
std::vector<exc::ParserException> errors;
|
||||
|
||||
/// Set if Codon operates in Python compatibility mode (e.g., with Python numerics)
|
||||
bool pythonCompat = false;
|
||||
|
||||
public:
|
||||
explicit Cache(std::string argv0 = "");
|
||||
|
||||
|
@ -220,6 +225,8 @@ public:
|
|||
SrcInfo generateSrcInfo();
|
||||
/// Get file contents at the given location.
|
||||
std::string getContent(const SrcInfo &info);
|
||||
/// Register a global identifier.
|
||||
void addGlobal(const std::string &name, ir::Var *var = nullptr);
|
||||
|
||||
/// Realization API.
|
||||
|
||||
|
|
|
@ -29,9 +29,15 @@ clause <-
|
|||
/ "num_threads" _ "(" _ int _ ")" {
|
||||
return vector<CallExpr::Arg>{{"num_threads", make_shared<IntExpr>(ac<int>(V0))}};
|
||||
}
|
||||
/ "ordered" {
|
||||
/ "ordered" {
|
||||
return vector<CallExpr::Arg>{{"ordered", make_shared<BoolExpr>(true)}};
|
||||
}
|
||||
/ "collapse" {
|
||||
return vector<CallExpr::Arg>{{"collapse", make_shared<IntExpr>(ac<int>(V0))}};
|
||||
}
|
||||
/ "gpu" {
|
||||
return vector<CallExpr::Arg>{{"gpu", make_shared<BoolExpr>(true)}};
|
||||
}
|
||||
schedule_kind <- ("static" / "dynamic" / "guided" / "auto" / "runtime") {
|
||||
return VS.token_to_string();
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ std::shared_ptr<peg::Grammar> initParser() {
|
|||
x.second.accept(v);
|
||||
}
|
||||
(*g)["program"].enablePackratParsing = true;
|
||||
(*g)["fstring"].enablePackratParsing = true;
|
||||
for (auto &rule : std::vector<std::string>{
|
||||
"arguments", "slices", "genexp", "parentheses", "star_parens", "generics",
|
||||
"with_parens_item", "params", "from_as_parens", "from_params"}) {
|
||||
|
|
|
@ -82,8 +82,9 @@ void SimplifyVisitor::visit(IdExpr *expr) {
|
|||
|
||||
/// Flatten imports.
|
||||
/// @example
|
||||
/// `a.b.c` -> canonical name of `c` in `a.b` if `a.b` is an import
|
||||
/// `a.B.c` -> canonical name of `c` in class `a.B`
|
||||
/// `a.b.c` -> canonical name of `c` in `a.b` if `a.b` is an import
|
||||
/// `a.B.c` -> canonical name of `c` in class `a.B`
|
||||
/// `python.foo` -> internal.python._get_identifier("foo")
|
||||
/// Other cases are handled during the type checking.
|
||||
void SimplifyVisitor::visit(DotExpr *expr) {
|
||||
// First flatten the imports:
|
||||
|
@ -99,7 +100,11 @@ void SimplifyVisitor::visit(DotExpr *expr) {
|
|||
std::reverse(chain.begin(), chain.end());
|
||||
auto p = getImport(chain);
|
||||
|
||||
if (p.second->getModule() == ctx->getModule() && p.first == 1) {
|
||||
if (p.second->getModule() == "std.python") {
|
||||
resultExpr = transform(N<CallExpr>(
|
||||
N<DotExpr>(N<DotExpr>(N<IdExpr>("internal"), "python"), "_get_identifier"),
|
||||
N<StringExpr>(chain[p.first++])));
|
||||
} else if (p.second->getModule() == ctx->getModule() && p.first == 1) {
|
||||
resultExpr = transform(N<IdExpr>(chain[0]), true);
|
||||
} else {
|
||||
resultExpr = N<IdExpr>(p.second->canonicalName);
|
||||
|
@ -120,57 +125,74 @@ void SimplifyVisitor::visit(DotExpr *expr) {
|
|||
bool SimplifyVisitor::checkCapture(const SimplifyContext::Item &val) {
|
||||
if (!ctx->isOuter(val))
|
||||
return false;
|
||||
if ((val->isType() && !val->isGeneric()) || val->isFunc())
|
||||
return false;
|
||||
|
||||
// Ensure that outer variables can be captured (i.e., do not cross no-capture
|
||||
// boundary). Example:
|
||||
// def foo():
|
||||
// x = 1
|
||||
// class T: # <- boundary (class methods cannot capture locals)
|
||||
// def bar():
|
||||
// class T: # <- boundary (classes cannot capture locals)
|
||||
// t: int = x # x cannot be accessed
|
||||
// def bar(): # <- another boundary
|
||||
// # (class methods cannot capture locals except class generics)
|
||||
// print(x) # x cannot be accessed
|
||||
bool crossCaptureBoundary = false;
|
||||
bool localGeneric = val->isGeneric() && val->getBaseName() == ctx->getBaseName();
|
||||
bool parentClassGeneric =
|
||||
val->isGeneric() && !ctx->getBase()->isType() &&
|
||||
(ctx->bases.size() > 1 && ctx->bases[ctx->bases.size() - 2].isType() &&
|
||||
ctx->bases[ctx->bases.size() - 2].name == val->getBaseName());
|
||||
auto i = ctx->bases.size();
|
||||
for (; i-- > 0;) {
|
||||
if (ctx->bases[i].name == val->getBaseName())
|
||||
break;
|
||||
if (!ctx->bases[i].captures)
|
||||
if (!localGeneric && !parentClassGeneric && !ctx->bases[i].captures)
|
||||
crossCaptureBoundary = true;
|
||||
}
|
||||
seqassert(i < ctx->bases.size(), "invalid base for '{}'", val->canonicalName);
|
||||
|
||||
// Disallow outer generics except for class generics in methods
|
||||
if (val->isGeneric() && !(ctx->bases[i].isType() && i + 2 == ctx->bases.size()))
|
||||
error("cannot access nonlocal variable '{}'", ctx->cache->rev(val->canonicalName));
|
||||
|
||||
// Mark methods (class functions that access class generics)
|
||||
if (val->isGeneric() && ctx->bases[i].isType() && i + 2 == ctx->bases.size() &&
|
||||
ctx->getBase()->attributes)
|
||||
if (parentClassGeneric)
|
||||
ctx->getBase()->attributes->set(Attr::Method);
|
||||
|
||||
// Check if a real variable (not a static) is defined outside the current scope
|
||||
if (!val->isVar() || val->isGeneric())
|
||||
// Ignore generics
|
||||
if (parentClassGeneric || localGeneric)
|
||||
return false;
|
||||
|
||||
// Case: a global variable that has not been marked with `global` statement
|
||||
if (val->getBaseName().empty()) { /// TODO: use isGlobal instead?
|
||||
if (val->isVar() && val->getBaseName().empty()) {
|
||||
val->noShadow = true;
|
||||
if (val->scope.size() == 1 && !in(ctx->cache->globals, val->canonicalName))
|
||||
ctx->cache->globals[val->canonicalName] = nullptr;
|
||||
if (val->scope.size() == 1)
|
||||
ctx->cache->addGlobal(val->canonicalName);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if a real variable (not a static) is defined outside the current scope
|
||||
if (crossCaptureBoundary)
|
||||
error("cannot access nonlocal variable '{}'", ctx->cache->rev(val->canonicalName));
|
||||
|
||||
// Case: a nonlocal variable that has not been marked with `nonlocal` statement
|
||||
// and capturing is enabled
|
||||
auto captures = ctx->getBase()->captures;
|
||||
if (!crossCaptureBoundary && captures && !in(*captures, val->canonicalName)) {
|
||||
if (captures && !in(*captures, val->canonicalName)) {
|
||||
// Captures are transformed to function arguments; generate new name for that
|
||||
// argument
|
||||
auto newName = (*captures)[val->canonicalName] =
|
||||
ctx->generateCanonicalName(val->canonicalName);
|
||||
ExprPtr typ = nullptr;
|
||||
if (val->isType())
|
||||
typ = N<IdExpr>("type");
|
||||
if (auto st = val->isStatic())
|
||||
typ = N<IndexExpr>(N<IdExpr>("Static"),
|
||||
N<IdExpr>(st == StaticValue::INT ? "int" : "str"));
|
||||
auto [newName, _] = (*captures)[val->canonicalName] = {
|
||||
ctx->generateCanonicalName(val->canonicalName), typ};
|
||||
ctx->cache->reverseIdentifierLookup[newName] = newName;
|
||||
// Add newly generated argument to the context
|
||||
auto newVal =
|
||||
ctx->addVar(ctx->cache->rev(val->canonicalName), newName, getSrcInfo());
|
||||
std::shared_ptr<SimplifyItem> newVal = nullptr;
|
||||
if (val->isType())
|
||||
newVal = ctx->addType(ctx->cache->rev(val->canonicalName), newName, getSrcInfo());
|
||||
else
|
||||
newVal = ctx->addVar(ctx->cache->rev(val->canonicalName), newName, getSrcInfo());
|
||||
newVal->baseName = ctx->getBaseName();
|
||||
newVal->noShadow = true;
|
||||
return true;
|
||||
|
@ -206,10 +228,18 @@ SimplifyVisitor::getImport(const std::vector<std::string> &chain) {
|
|||
size_t itemEnd = 0;
|
||||
auto fctx = importName.empty() ? ctx : ctx->cache->imports[importName].ctx;
|
||||
for (auto i = chain.size(); i-- > importEnd;) {
|
||||
val = fctx->find(join(chain, ".", importEnd, i + 1));
|
||||
if (val && (importName.empty() || val->isType() || !val->isConditional())) {
|
||||
itemName = val->canonicalName, itemEnd = i + 1;
|
||||
break;
|
||||
if (fctx->getModule() == "std.python" && importEnd < chain.size()) {
|
||||
// Special case: importing from Python.
|
||||
// Fake SimplifyItem that inidcates std.python access
|
||||
val = std::make_shared<SimplifyItem>(SimplifyItem::Var, "", "",
|
||||
fctx->getModule(), std::vector<int>{});
|
||||
return {importEnd, val};
|
||||
} else {
|
||||
val = fctx->find(join(chain, ".", importEnd, i + 1));
|
||||
if (val && (importName.empty() || val->isType() || !val->isConditional())) {
|
||||
itemName = val->canonicalName, itemEnd = i + 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (itemName.empty() && importName.empty())
|
||||
|
|
|
@ -147,14 +147,15 @@ StmtPtr SimplifyVisitor::transformAssignment(ExprPtr lhs, ExprPtr rhs, ExprPtr t
|
|||
if (rhs && rhs->isType()) {
|
||||
ctx->addType(e->value, canonical, lhs->getSrcInfo());
|
||||
} else {
|
||||
ctx->addVar(e->value, canonical, lhs->getSrcInfo());
|
||||
auto val = ctx->addVar(e->value, canonical, lhs->getSrcInfo());
|
||||
if (auto st = getStaticGeneric(type.get()))
|
||||
val->staticType = st;
|
||||
}
|
||||
|
||||
// Register all toplevel variables as global in JIT mode
|
||||
bool isGlobal = (ctx->cache->isJit && ctx->isGlobal()) || (canonical == VAR_ARGV);
|
||||
if (isGlobal && !in(ctx->cache->globals, canonical)) {
|
||||
ctx->cache->globals[canonical] = nullptr;
|
||||
}
|
||||
if (isGlobal)
|
||||
ctx->cache->addGlobal(canonical);
|
||||
|
||||
return assign;
|
||||
}
|
||||
|
|
|
@ -50,169 +50,170 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
|
|||
argsToParse = astIter->second.ast->args;
|
||||
}
|
||||
|
||||
// Add the class base
|
||||
ctx->bases.emplace_back(SimplifyContext::Base(canonicalName));
|
||||
ctx->addBlock();
|
||||
|
||||
// Parse and add class generics
|
||||
std::vector<Param> args;
|
||||
std::pair<StmtPtr, FunctionStmt *> autoDeducedInit{nullptr, nullptr};
|
||||
if (stmt->attributes.has("deduce") && args.empty()) {
|
||||
// Auto-detect generics and fields
|
||||
autoDeducedInit = autoDeduceMembers(stmt, args);
|
||||
} else {
|
||||
// Add all generics before parent classes, fields and methods
|
||||
for (auto &a : argsToParse) {
|
||||
if (a.status != Param::Generic)
|
||||
continue;
|
||||
std::string genName, varName;
|
||||
if (stmt->attributes.has(Attr::Extend))
|
||||
varName = a.name, genName = ctx->cache->rev(a.name);
|
||||
else
|
||||
varName = ctx->generateCanonicalName(a.name), genName = a.name;
|
||||
if (getStaticGeneric(a.type.get()))
|
||||
ctx->addVar(genName, varName, a.type->getSrcInfo())->generic = true;
|
||||
else
|
||||
ctx->addType(genName, varName, a.type->getSrcInfo())->generic = true;
|
||||
args.emplace_back(Param{varName, transformType(clone(a.type), false),
|
||||
transformType(clone(a.defaultValue), false), a.status});
|
||||
}
|
||||
}
|
||||
|
||||
// Form class type node (e.g. `Foo`, or `Foo[T, U]` for generic classes)
|
||||
ExprPtr typeAst = N<IdExpr>(name);
|
||||
for (auto &a : args) {
|
||||
if (a.status == Param::Generic) {
|
||||
if (!typeAst->getIndex())
|
||||
typeAst = N<IndexExpr>(N<IdExpr>(name), N<TupleExpr>());
|
||||
typeAst->getIndex()->index->getTuple()->items.push_back(N<IdExpr>(a.name));
|
||||
}
|
||||
}
|
||||
|
||||
// Collect classes (and their fields) that are to be statically inherited
|
||||
auto baseASTs = parseBaseClasses(stmt->baseClasses, args, stmt->attributes);
|
||||
|
||||
// A ClassStmt will be separated into class variable assignments, method-free
|
||||
// ClassStmts (that include nested classes) and method FunctionStmts
|
||||
std::vector<StmtPtr> clsStmts; // Will be filled later!
|
||||
std::vector<StmtPtr> varStmts; // Will be filled later!
|
||||
std::vector<StmtPtr> fnStmts; // Will be filled later!
|
||||
transformNestedClasses(stmt, clsStmts, varStmts, fnStmts);
|
||||
std::vector<SimplifyContext::Item> addLater;
|
||||
{
|
||||
// Add the class base
|
||||
SimplifyContext::BaseGuard br(ctx.get(), canonicalName);
|
||||
|
||||
// Collect class fields
|
||||
for (auto &a : argsToParse) {
|
||||
if (a.status == Param::Normal) {
|
||||
if (!ClassStmt::isClassVar(a)) {
|
||||
args.emplace_back(Param{a.name, transformType(clone(a.type), false),
|
||||
transform(clone(a.defaultValue), true)});
|
||||
} else if (!stmt->attributes.has(Attr::Extend)) {
|
||||
// Handle class variables. Transform them later to allow self-references
|
||||
auto name = format("{}.{}", canonicalName, a.name);
|
||||
preamble->push_back(N<AssignStmt>(N<IdExpr>(name), nullptr, nullptr));
|
||||
if (!in(ctx->cache->globals, name))
|
||||
ctx->cache->globals[name] = nullptr;
|
||||
auto assign = N<AssignStmt>(N<IdExpr>(name), a.defaultValue,
|
||||
a.type ? a.type->getIndex()->index : nullptr);
|
||||
assign->setUpdate();
|
||||
varStmts.push_back(assign);
|
||||
ctx->cache->classes[canonicalName].classVars[a.name] = name;
|
||||
// Parse and add class generics
|
||||
std::vector<Param> args;
|
||||
std::pair<StmtPtr, FunctionStmt *> autoDeducedInit{nullptr, nullptr};
|
||||
if (stmt->attributes.has("deduce") && args.empty()) {
|
||||
// Auto-detect generics and fields
|
||||
autoDeducedInit = autoDeduceMembers(stmt, args);
|
||||
} else {
|
||||
// Add all generics before parent classes, fields and methods
|
||||
for (auto &a : argsToParse) {
|
||||
if (a.status != Param::Generic)
|
||||
continue;
|
||||
std::string genName, varName;
|
||||
if (stmt->attributes.has(Attr::Extend))
|
||||
varName = a.name, genName = ctx->cache->rev(a.name);
|
||||
else
|
||||
varName = ctx->generateCanonicalName(a.name), genName = a.name;
|
||||
if (auto st = getStaticGeneric(a.type.get())) {
|
||||
auto val = ctx->addVar(genName, varName, a.type->getSrcInfo());
|
||||
val->generic = true;
|
||||
val->staticType = st;
|
||||
} else {
|
||||
ctx->addType(genName, varName, a.type->getSrcInfo())->generic = true;
|
||||
}
|
||||
args.emplace_back(Param{varName, transformType(clone(a.type), false),
|
||||
transformType(clone(a.defaultValue), false), a.status});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ASTs for member arguments to be used for populating magic methods
|
||||
std::vector<Param> memberArgs;
|
||||
for (auto &a : args) {
|
||||
if (a.status == Param::Normal)
|
||||
memberArgs.push_back(a.clone());
|
||||
}
|
||||
|
||||
// Ensure that all fields and class variables are registered
|
||||
if (!stmt->attributes.has(Attr::Extend)) {
|
||||
for (size_t ai = 0; ai < args.size();) {
|
||||
if (args[ai].status == Param::Normal)
|
||||
ctx->cache->classes[canonicalName].fields.push_back({args[ai].name, nullptr});
|
||||
ai++;
|
||||
// Form class type node (e.g. `Foo`, or `Foo[T, U]` for generic classes)
|
||||
ExprPtr typeAst = N<IdExpr>(name);
|
||||
for (auto &a : args) {
|
||||
if (a.status == Param::Generic) {
|
||||
if (!typeAst->getIndex())
|
||||
typeAst = N<IndexExpr>(N<IdExpr>(name), N<TupleExpr>());
|
||||
typeAst->getIndex()->index->getTuple()->items.push_back(N<IdExpr>(a.name));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse class members (arguments) and methods
|
||||
if (!stmt->attributes.has(Attr::Extend)) {
|
||||
// Now that we are done with arguments, add record type to the context
|
||||
if (stmt->attributes.has(Attr::Tuple)) {
|
||||
// Ensure that class binding does not shadow anything.
|
||||
// Class bindings cannot be dominated either
|
||||
auto v = ctx->find(name);
|
||||
if (v && v->noShadow)
|
||||
error("cannot update global/nonlocal");
|
||||
ctx->add(name, classItem);
|
||||
ctx->addAlwaysVisible(classItem);
|
||||
}
|
||||
// Create a cached AST.
|
||||
stmt->attributes.module =
|
||||
format("{}{}", ctx->moduleName.status == ImportFile::STDLIB ? "std::" : "::",
|
||||
ctx->moduleName.module);
|
||||
ctx->cache->classes[canonicalName].ast =
|
||||
N<ClassStmt>(canonicalName, args, N<SuiteStmt>(), stmt->attributes);
|
||||
for (auto &b : baseASTs)
|
||||
ctx->cache->classes[canonicalName].parentClasses.emplace_back(b->name);
|
||||
ctx->cache->classes[canonicalName].ast->validate();
|
||||
// Collect classes (and their fields) that are to be statically inherited
|
||||
auto baseASTs = parseBaseClasses(stmt->baseClasses, args, stmt->attributes);
|
||||
|
||||
// Codegen default magic methods
|
||||
for (auto &m : stmt->attributes.magics) {
|
||||
fnStmts.push_back(transform(
|
||||
codegenMagic(m, typeAst, memberArgs, stmt->attributes.has(Attr::Tuple))));
|
||||
}
|
||||
// Add inherited methods
|
||||
for (auto &base : baseASTs) {
|
||||
for (auto &mm : ctx->cache->classes[base->name].methods)
|
||||
for (auto &mf : ctx->cache->overloads[mm.second]) {
|
||||
auto f = ctx->cache->functions[mf.name].ast;
|
||||
if (!f->attributes.has("autogenerated")) {
|
||||
std::string rootName;
|
||||
auto &mts = ctx->cache->classes[ctx->getBase()->name].methods;
|
||||
auto it = mts.find(ctx->cache->rev(f->name));
|
||||
if (it != mts.end())
|
||||
rootName = it->second;
|
||||
else
|
||||
rootName = ctx->generateCanonicalName(ctx->cache->rev(f->name), true);
|
||||
auto newCanonicalName =
|
||||
format("{}:{}", rootName, ctx->cache->overloads[rootName].size());
|
||||
ctx->cache->overloads[rootName].push_back(
|
||||
{newCanonicalName, ctx->cache->age});
|
||||
ctx->cache->reverseIdentifierLookup[newCanonicalName] =
|
||||
ctx->cache->rev(f->name);
|
||||
auto nf = std::dynamic_pointer_cast<FunctionStmt>(f->clone());
|
||||
nf->name = newCanonicalName;
|
||||
nf->attributes.parentClass = ctx->getBase()->name;
|
||||
ctx->cache->functions[newCanonicalName].ast = nf;
|
||||
ctx->cache->classes[ctx->getBase()->name]
|
||||
.methods[ctx->cache->rev(f->name)] = rootName;
|
||||
fnStmts.push_back(nf);
|
||||
}
|
||||
// A ClassStmt will be separated into class variable assignments, method-free
|
||||
// ClassStmts (that include nested classes) and method FunctionStmts
|
||||
transformNestedClasses(stmt, clsStmts, varStmts, fnStmts);
|
||||
|
||||
// Collect class fields
|
||||
for (auto &a : argsToParse) {
|
||||
if (a.status == Param::Normal) {
|
||||
if (!ClassStmt::isClassVar(a)) {
|
||||
args.emplace_back(Param{a.name, transformType(clone(a.type), false),
|
||||
transform(clone(a.defaultValue), true)});
|
||||
} else if (!stmt->attributes.has(Attr::Extend)) {
|
||||
// Handle class variables. Transform them later to allow self-references
|
||||
auto name = format("{}.{}", canonicalName, a.name);
|
||||
preamble->push_back(N<AssignStmt>(N<IdExpr>(name), nullptr, nullptr));
|
||||
ctx->cache->addGlobal(name);
|
||||
auto assign = N<AssignStmt>(N<IdExpr>(name), a.defaultValue,
|
||||
a.type ? a.type->getIndex()->index : nullptr);
|
||||
assign->setUpdate();
|
||||
varStmts.push_back(assign);
|
||||
ctx->cache->classes[canonicalName].classVars[a.name] = name;
|
||||
}
|
||||
}
|
||||
// Add auto-deduced __init__ (if available)
|
||||
if (autoDeducedInit.first)
|
||||
fnStmts.push_back(autoDeducedInit.first);
|
||||
}
|
||||
// Add class methods
|
||||
for (const auto &sp : getClassMethods(stmt->suite))
|
||||
if (sp && sp->getFunction()) {
|
||||
if (sp.get() != autoDeducedInit.second)
|
||||
fnStmts.push_back(transform(sp));
|
||||
}
|
||||
}
|
||||
|
||||
// After popping context block, record types and nested classes will disappear.
|
||||
// Store their references and re-add them to the context after popping
|
||||
std::vector<SimplifyContext::Item> addLater;
|
||||
addLater.reserve(clsStmts.size() + 1);
|
||||
for (auto &c : clsStmts)
|
||||
addLater.push_back(ctx->find(c->getClass()->name));
|
||||
if (stmt->attributes.has(Attr::Tuple))
|
||||
addLater.push_back(ctx->forceFind(name));
|
||||
ctx->bases.pop_back();
|
||||
ctx->popBlock();
|
||||
// ASTs for member arguments to be used for populating magic methods
|
||||
std::vector<Param> memberArgs;
|
||||
for (auto &a : args) {
|
||||
if (a.status == Param::Normal)
|
||||
memberArgs.push_back(a.clone());
|
||||
}
|
||||
|
||||
// Ensure that all fields and class variables are registered
|
||||
if (!stmt->attributes.has(Attr::Extend)) {
|
||||
for (size_t ai = 0; ai < args.size();) {
|
||||
if (args[ai].status == Param::Normal)
|
||||
ctx->cache->classes[canonicalName].fields.push_back({args[ai].name, nullptr});
|
||||
ai++;
|
||||
}
|
||||
}
|
||||
|
||||
// Parse class members (arguments) and methods
|
||||
if (!stmt->attributes.has(Attr::Extend)) {
|
||||
// Now that we are done with arguments, add record type to the context
|
||||
if (stmt->attributes.has(Attr::Tuple)) {
|
||||
// Ensure that class binding does not shadow anything.
|
||||
// Class bindings cannot be dominated either
|
||||
auto v = ctx->find(name);
|
||||
if (v && v->noShadow)
|
||||
error("cannot update global/nonlocal");
|
||||
ctx->add(name, classItem);
|
||||
ctx->addAlwaysVisible(classItem);
|
||||
}
|
||||
// Create a cached AST.
|
||||
stmt->attributes.module =
|
||||
format("{}{}", ctx->moduleName.status == ImportFile::STDLIB ? "std::" : "::",
|
||||
ctx->moduleName.module);
|
||||
ctx->cache->classes[canonicalName].ast =
|
||||
N<ClassStmt>(canonicalName, args, N<SuiteStmt>(), stmt->attributes);
|
||||
for (auto &b : baseASTs)
|
||||
ctx->cache->classes[canonicalName].parentClasses.emplace_back(b->name);
|
||||
ctx->cache->classes[canonicalName].ast->validate();
|
||||
|
||||
// Codegen default magic methods
|
||||
for (auto &m : stmt->attributes.magics) {
|
||||
fnStmts.push_back(transform(
|
||||
codegenMagic(m, typeAst, memberArgs, stmt->attributes.has(Attr::Tuple))));
|
||||
}
|
||||
// Add inherited methods
|
||||
for (auto &base : baseASTs) {
|
||||
for (auto &mm : ctx->cache->classes[base->name].methods)
|
||||
for (auto &mf : ctx->cache->overloads[mm.second]) {
|
||||
auto f = ctx->cache->functions[mf.name].ast;
|
||||
if (!f->attributes.has("autogenerated")) {
|
||||
std::string rootName;
|
||||
auto &mts = ctx->cache->classes[ctx->getBase()->name].methods;
|
||||
auto it = mts.find(ctx->cache->rev(f->name));
|
||||
if (it != mts.end())
|
||||
rootName = it->second;
|
||||
else
|
||||
rootName = ctx->generateCanonicalName(ctx->cache->rev(f->name), true);
|
||||
auto newCanonicalName =
|
||||
format("{}:{}", rootName, ctx->cache->overloads[rootName].size());
|
||||
ctx->cache->overloads[rootName].push_back(
|
||||
{newCanonicalName, ctx->cache->age});
|
||||
ctx->cache->reverseIdentifierLookup[newCanonicalName] =
|
||||
ctx->cache->rev(f->name);
|
||||
auto nf = std::dynamic_pointer_cast<FunctionStmt>(f->clone());
|
||||
nf->name = newCanonicalName;
|
||||
nf->attributes.parentClass = ctx->getBase()->name;
|
||||
ctx->cache->functions[newCanonicalName].ast = nf;
|
||||
ctx->cache->classes[ctx->getBase()->name]
|
||||
.methods[ctx->cache->rev(f->name)] = rootName;
|
||||
fnStmts.push_back(nf);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add auto-deduced __init__ (if available)
|
||||
if (autoDeducedInit.first)
|
||||
fnStmts.push_back(autoDeducedInit.first);
|
||||
}
|
||||
// Add class methods
|
||||
for (const auto &sp : getClassMethods(stmt->suite))
|
||||
if (sp && sp->getFunction()) {
|
||||
if (sp.get() != autoDeducedInit.second)
|
||||
fnStmts.push_back(transform(sp));
|
||||
}
|
||||
|
||||
// After popping context block, record types and nested classes will disappear.
|
||||
// Store their references and re-add them to the context after popping
|
||||
addLater.reserve(clsStmts.size() + 1);
|
||||
for (auto &c : clsStmts)
|
||||
addLater.push_back(ctx->find(c->getClass()->name));
|
||||
if (stmt->attributes.has(Attr::Tuple))
|
||||
addLater.push_back(ctx->forceFind(name));
|
||||
}
|
||||
for (auto &i : addLater)
|
||||
ctx->add(ctx->cache->rev(i->canonicalName), i);
|
||||
|
||||
|
@ -279,10 +280,13 @@ SimplifyVisitor::parseBaseClasses(const std::vector<ExprPtr> &baseClasses,
|
|||
args.emplace_back(a);
|
||||
}
|
||||
if (a.status != Param::Normal) {
|
||||
if (getStaticGeneric(a.type.get()))
|
||||
ctx->addVar(a.name, a.name, a.type->getSrcInfo())->generic = true;
|
||||
else
|
||||
if (auto st = getStaticGeneric(a.type.get())) {
|
||||
auto val = ctx->addVar(a.name, a.name, a.type->getSrcInfo());
|
||||
val->generic = true;
|
||||
val->staticType = st;
|
||||
} else {
|
||||
ctx->addType(a.name, a.name, a.type->getSrcInfo())->generic = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (si != subs.size())
|
||||
|
@ -632,6 +636,29 @@ StmtPtr SimplifyVisitor::codegenMagic(const std::string &op, const ExprPtr &typE
|
|||
N<DotExpr>(clone(args[i].type), "__from_py__"),
|
||||
N<CallExpr>(N<DotExpr>(I("pyobj"), "_tuple_get"), I("src"), N<IntExpr>(i))));
|
||||
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(typExpr->clone(), ar)));
|
||||
} else if (op == "to_gpu") {
|
||||
// def __to_gpu__(self: T, cache) -> T:
|
||||
// return __internal__.class_to_gpu(self, cache)
|
||||
fargs.emplace_back(Param{"self", typExpr->clone()});
|
||||
fargs.emplace_back(Param{"cache"});
|
||||
ret = typExpr->clone();
|
||||
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(
|
||||
N<DotExpr>(I("__internal__"), "class_to_gpu"), I("self"), I("cache"))));
|
||||
} else if (op == "from_gpu") {
|
||||
// def __from_gpu__(self: T, other: T) -> None:
|
||||
// __internal__.class_from_gpu(self, other)
|
||||
fargs.emplace_back(Param{"self", typExpr->clone()});
|
||||
fargs.emplace_back(Param{"other", typExpr->clone()});
|
||||
ret = I("NoneType");
|
||||
stmts.emplace_back(N<ExprStmt>(N<CallExpr>(
|
||||
N<DotExpr>(I("__internal__"), "class_from_gpu"), I("self"), I("other"))));
|
||||
} else if (op == "from_gpu_new") {
|
||||
// def __from_gpu_new__(other: T) -> T:
|
||||
// return __internal__.class_from_gpu_new(other)
|
||||
fargs.emplace_back(Param{"other", typExpr->clone()});
|
||||
ret = typExpr->clone();
|
||||
stmts.emplace_back(N<ReturnStmt>(
|
||||
N<CallExpr>(N<DotExpr>(I("__internal__"), "class_from_gpu_new"), I("other"))));
|
||||
} else if (op == "repr") {
|
||||
// def __repr__(self: T) -> str:
|
||||
// a = __array__[str](N) (number of args)
|
||||
|
|
|
@ -13,7 +13,7 @@ namespace codon::ast {
|
|||
/// The rest will be handled during the type-checking stage.
|
||||
void SimplifyVisitor::visit(TupleExpr *expr) {
|
||||
for (auto &i : expr->items)
|
||||
transform(i);
|
||||
transform(i, true); // types needed for some constructs (e.g., isinstance)
|
||||
}
|
||||
|
||||
/// Transform a list `[a1, ..., aN]` to the corresponding statement expression.
|
||||
|
|
|
@ -160,8 +160,8 @@ SimplifyContext::Item SimplifyContext::findDominatingBinding(const std::string &
|
|||
std::make_unique<BoolExpr>(false), nullptr));
|
||||
// Reached the toplevel? Register the binding as global.
|
||||
if (prefix == 1) {
|
||||
cache->globals[canonicalName] = nullptr;
|
||||
cache->globals[fmt::format("{}.__used__", canonicalName)] = nullptr;
|
||||
cache->addGlobal(canonicalName);
|
||||
cache->addGlobal(fmt::format("{}.__used__", canonicalName));
|
||||
}
|
||||
hasUsed = true;
|
||||
}
|
||||
|
|
|
@ -40,6 +40,8 @@ struct SimplifyItem : public SrcObject {
|
|||
bool noShadow = false;
|
||||
/// Set if an identifier is a class or a function generic
|
||||
bool generic = false;
|
||||
/// Set if an identifier is a static variable.
|
||||
char staticType = 0;
|
||||
|
||||
public:
|
||||
SimplifyItem(Kind kind, std::string baseName, std::string canonicalName,
|
||||
|
@ -61,6 +63,7 @@ public:
|
|||
/// (i.e., a block that might not be executed during the runtime)
|
||||
bool isConditional() const { return scope.size() > 1; }
|
||||
bool isGeneric() const { return generic; }
|
||||
char isStatic() const { return staticType; }
|
||||
};
|
||||
|
||||
/** Context class that tracks identifiers during the simplification. **/
|
||||
|
@ -99,8 +102,9 @@ struct SimplifyContext : public Context<SimplifyItem> {
|
|||
/// Map of captured identifiers (i.e., identifiers not defined in a function).
|
||||
/// Captured (canonical) identifiers are mapped to the new canonical names
|
||||
/// (representing the canonical function argument names that are appended to the
|
||||
/// function after processing).
|
||||
std::unordered_map<std::string, std::string> *captures;
|
||||
/// function after processing) and their types (indicating if they are a type, a
|
||||
/// static or a variable).
|
||||
std::unordered_map<std::string, std::pair<std::string, ExprPtr>> *captures;
|
||||
|
||||
/// A stack of nested loops enclosing the current statement used for transforming
|
||||
/// "break" statement in loop-else constructs. Each loop is defined by a "break"
|
||||
|
@ -123,6 +127,18 @@ struct SimplifyContext : public Context<SimplifyItem> {
|
|||
/// Current base stack (the last enclosing base is the last base in the stack).
|
||||
std::vector<Base> bases;
|
||||
|
||||
struct BaseGuard {
|
||||
SimplifyContext *holder;
|
||||
BaseGuard(SimplifyContext *holder, const std::string &name) : holder(holder) {
|
||||
holder->bases.emplace_back(Base(name));
|
||||
holder->addBlock();
|
||||
}
|
||||
~BaseGuard() {
|
||||
holder->bases.pop_back();
|
||||
holder->popBlock();
|
||||
}
|
||||
};
|
||||
|
||||
/// Set of seen global identifiers used to prevent later creation of local variables
|
||||
/// with the same name.
|
||||
std::unordered_map<std::string, std::unordered_map<std::string, ExprPtr>>
|
||||
|
|
|
@ -42,7 +42,7 @@ void SimplifyVisitor::visit(TryStmt *stmt) {
|
|||
c.var = ctx->generateCanonicalName(c.var);
|
||||
ctx->addVar(ctx->cache->rev(c.var), c.var, c.suite->getSrcInfo());
|
||||
}
|
||||
transformType(c.exc);
|
||||
transform(c.exc, true);
|
||||
transformConditionalScope(c.suite);
|
||||
ctx->leaveConditionalBlock();
|
||||
}
|
||||
|
|
|
@ -74,8 +74,7 @@ void SimplifyVisitor::visit(GlobalStmt *stmt) {
|
|||
stmt->var);
|
||||
|
||||
// Register as global if needed
|
||||
if (!in(ctx->cache->globals, val->canonicalName))
|
||||
ctx->cache->globals[val->canonicalName] = nullptr;
|
||||
ctx->cache->addGlobal(val->canonicalName);
|
||||
|
||||
val = ctx->addVar(stmt->var, val->canonicalName, stmt->getSrcInfo());
|
||||
val->baseName = ctx->getBaseName();
|
||||
|
@ -111,9 +110,11 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
|
|||
|
||||
// Parse attributes
|
||||
for (auto i = stmt->decorators.size(); i-- > 0;) {
|
||||
if (auto n = isAttribute(stmt->decorators[i])) {
|
||||
stmt->attributes.set(*n);
|
||||
stmt->decorators[i] = nullptr; // remove it from further consideration
|
||||
auto [isAttr, attrName] = getDecorator(stmt->decorators[i]);
|
||||
if (!attrName.empty()) {
|
||||
stmt->attributes.set(attrName);
|
||||
if (isAttr)
|
||||
stmt->decorators[i] = nullptr; // remove it from further consideration
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -153,89 +154,92 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
|
|||
ctx->addAlwaysVisible(funcVal);
|
||||
}
|
||||
|
||||
// Set up the base
|
||||
ctx->bases.emplace_back(SimplifyContext::Base{canonicalName});
|
||||
ctx->addBlock();
|
||||
ctx->getBase()->attributes = &(stmt->attributes);
|
||||
|
||||
// Parse arguments and add them to the context
|
||||
std::vector<Param> args;
|
||||
for (auto &a : stmt->args) {
|
||||
std::string varName = a.name;
|
||||
int stars = trimStars(varName);
|
||||
auto name = ctx->generateCanonicalName(varName);
|
||||
|
||||
// Mark as method if the first argument is self
|
||||
if (isClassMember && stmt->attributes.has(Attr::HasSelf) && a.name == "self") {
|
||||
ctx->getBase()->selfName = name;
|
||||
stmt->attributes.set(Attr::Method);
|
||||
}
|
||||
|
||||
// Handle default values
|
||||
auto defaultValue = a.defaultValue;
|
||||
if (a.type && defaultValue && defaultValue->getNone()) {
|
||||
// Special case: `arg: Callable = None` -> `arg: Callable = NoneType()`
|
||||
if (a.type->getIndex() && a.type->getIndex()->expr->isId(TYPE_CALLABLE))
|
||||
defaultValue = N<CallExpr>(N<IdExpr>("NoneType"));
|
||||
// Special case: `arg: type = None` -> `arg: type = NoneType`
|
||||
if (a.type->isId("type") || a.type->isId("TypeVar"))
|
||||
defaultValue = N<IdExpr>("NoneType");
|
||||
}
|
||||
/// TODO: Uncomment for Python-style defaults
|
||||
// if (defaultValue) {
|
||||
// auto defaultValueCanonicalName =
|
||||
// ctx->generateCanonicalName(format("{}.{}", canonicalName, name));
|
||||
// prependStmts->push_back(N<AssignStmt>(N<IdExpr>(defaultValueCanonicalName),
|
||||
// defaultValue));
|
||||
// defaultValue = N<IdExpr>(defaultValueCanonicalName);
|
||||
// }
|
||||
args.emplace_back(
|
||||
Param{std::string(stars, '*') + name, a.type, defaultValue, a.status});
|
||||
|
||||
// Add generics to the context
|
||||
if (a.status != Param::Normal) {
|
||||
if (getStaticGeneric(a.type.get()))
|
||||
ctx->addVar(varName, name, stmt->getSrcInfo())->generic = true;
|
||||
else
|
||||
ctx->addType(varName, name, stmt->getSrcInfo())->generic = true;
|
||||
}
|
||||
}
|
||||
// Parse arguments to the context. Needs to be done after adding generics
|
||||
// to support cases like `foo(a: T, T: type)`
|
||||
for (auto &a : args) {
|
||||
a.type = transformType(a.type, false);
|
||||
a.defaultValue = transform(a.defaultValue, true);
|
||||
}
|
||||
// Add non-generic arguments to the context. Delayed to prevent cases like
|
||||
// `def foo(a, b=a)`
|
||||
for (auto &a : args) {
|
||||
if (a.status == Param::Normal) {
|
||||
std::string canName = a.name;
|
||||
trimStars(canName);
|
||||
ctx->addVar(ctx->cache->rev(canName), canName, stmt->getSrcInfo());
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the return type
|
||||
auto ret = transformType(stmt->ret, false);
|
||||
|
||||
// Parse function body
|
||||
StmtPtr suite = nullptr;
|
||||
std::unordered_map<std::string, std::string> captures;
|
||||
if (!stmt->attributes.has(Attr::Internal) && !stmt->attributes.has(Attr::C)) {
|
||||
if (stmt->attributes.has(Attr::LLVM)) {
|
||||
suite = transformLLVMDefinition(stmt->suite->firstInBlock());
|
||||
} else if (stmt->attributes.has(Attr::C)) {
|
||||
// Do nothing
|
||||
} else {
|
||||
if ((isEnclosedFunc || stmt->attributes.has(Attr::Capture)) && !isClassMember)
|
||||
ctx->getBase()->captures = &captures;
|
||||
suite = SimplifyVisitor(ctx, preamble).transformConditionalScope(stmt->suite);
|
||||
ExprPtr ret = nullptr;
|
||||
std::unordered_map<std::string, std::pair<std::string, ExprPtr>> captures;
|
||||
{
|
||||
// Set up the base
|
||||
SimplifyContext::BaseGuard br(ctx.get(), canonicalName);
|
||||
ctx->getBase()->attributes = &(stmt->attributes);
|
||||
|
||||
// Parse arguments and add them to the context
|
||||
for (auto &a : stmt->args) {
|
||||
std::string varName = a.name;
|
||||
int stars = trimStars(varName);
|
||||
auto name = ctx->generateCanonicalName(varName);
|
||||
|
||||
// Mark as method if the first argument is self
|
||||
if (isClassMember && stmt->attributes.has(Attr::HasSelf) && a.name == "self") {
|
||||
ctx->getBase()->selfName = name;
|
||||
stmt->attributes.set(Attr::Method);
|
||||
}
|
||||
|
||||
// Handle default values
|
||||
auto defaultValue = a.defaultValue;
|
||||
if (a.type && defaultValue && defaultValue->getNone()) {
|
||||
// Special case: `arg: Callable = None` -> `arg: Callable = NoneType()`
|
||||
if (a.type->getIndex() && a.type->getIndex()->expr->isId(TYPE_CALLABLE))
|
||||
defaultValue = N<CallExpr>(N<IdExpr>("NoneType"));
|
||||
// Special case: `arg: type = None` -> `arg: type = NoneType`
|
||||
if (a.type->isId("type") || a.type->isId(TYPE_TYPEVAR))
|
||||
defaultValue = N<IdExpr>("NoneType");
|
||||
}
|
||||
/// TODO: Uncomment for Python-style defaults
|
||||
// if (defaultValue) {
|
||||
// auto defaultValueCanonicalName =
|
||||
// ctx->generateCanonicalName(format("{}.{}", canonicalName, name));
|
||||
// prependStmts->push_back(N<AssignStmt>(N<IdExpr>(defaultValueCanonicalName),
|
||||
// defaultValue));
|
||||
// defaultValue = N<IdExpr>(defaultValueCanonicalName);
|
||||
// }
|
||||
args.emplace_back(
|
||||
Param{std::string(stars, '*') + name, a.type, defaultValue, a.status});
|
||||
|
||||
// Add generics to the context
|
||||
if (a.status != Param::Normal) {
|
||||
if (auto st = getStaticGeneric(a.type.get())) {
|
||||
auto val = ctx->addVar(varName, name, stmt->getSrcInfo());
|
||||
val->generic = true;
|
||||
val->staticType = st;
|
||||
} else {
|
||||
ctx->addType(varName, name, stmt->getSrcInfo())->generic = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse arguments to the context. Needs to be done after adding generics
|
||||
// to support cases like `foo(a: T, T: type)`
|
||||
for (auto &a : args) {
|
||||
a.type = transformType(a.type, false);
|
||||
a.defaultValue = transform(a.defaultValue, true);
|
||||
}
|
||||
// Add non-generic arguments to the context. Delayed to prevent cases like
|
||||
// `def foo(a, b=a)`
|
||||
for (auto &a : args) {
|
||||
if (a.status == Param::Normal) {
|
||||
std::string canName = a.name;
|
||||
trimStars(canName);
|
||||
ctx->addVar(ctx->cache->rev(canName), canName, stmt->getSrcInfo());
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the return type
|
||||
ret = transformType(stmt->ret, false);
|
||||
|
||||
// Parse function body
|
||||
if (!stmt->attributes.has(Attr::Internal) && !stmt->attributes.has(Attr::C)) {
|
||||
if (stmt->attributes.has(Attr::LLVM)) {
|
||||
suite = transformLLVMDefinition(stmt->suite->firstInBlock());
|
||||
} else if (stmt->attributes.has(Attr::C)) {
|
||||
// Do nothing
|
||||
} else {
|
||||
if ((isEnclosedFunc || stmt->attributes.has(Attr::Capture)) && !isClassMember)
|
||||
ctx->getBase()->captures = &captures;
|
||||
suite = SimplifyVisitor(ctx, preamble).transformConditionalScope(stmt->suite);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx->bases.pop_back();
|
||||
ctx->popBlock();
|
||||
stmt->attributes.module =
|
||||
format("{}{}", ctx->moduleName.status == ImportFile::STDLIB ? "std::" : "::",
|
||||
ctx->moduleName.module);
|
||||
|
@ -259,8 +263,8 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
|
|||
args.pop_back();
|
||||
}
|
||||
for (auto &c : captures) {
|
||||
args.emplace_back(Param{c.second, nullptr, nullptr});
|
||||
partialArgs.push_back({c.second, N<IdExpr>(ctx->cache->rev(c.first))});
|
||||
args.emplace_back(Param{c.second.first, c.second.second, nullptr});
|
||||
partialArgs.push_back({c.second.first, N<IdExpr>(ctx->cache->rev(c.first))});
|
||||
}
|
||||
if (!kw.name.empty())
|
||||
args.push_back(kw);
|
||||
|
@ -417,20 +421,22 @@ StmtPtr SimplifyVisitor::transformLLVMDefinition(Stmt *codeStmt) {
|
|||
return N<SuiteStmt>(items);
|
||||
}
|
||||
|
||||
/// Check if a decorator is actually an attribute (a function with `@__attribute__`)
|
||||
std::string *SimplifyVisitor::isAttribute(const ExprPtr &e) {
|
||||
/// Fetch a decorator canonical name. The first pair member indicates if a decorator is
|
||||
/// actually an attribute (a function with `@__attribute__`).
|
||||
std::pair<bool, std::string> SimplifyVisitor::getDecorator(const ExprPtr &e) {
|
||||
auto dt = transform(clone(e));
|
||||
if (dt && dt->getId()) {
|
||||
auto ci = ctx->find(dt->getId()->value);
|
||||
auto id = dt->getCall() ? dt->getCall()->expr : dt;
|
||||
if (id && id->getId()) {
|
||||
auto ci = ctx->find(id->getId()->value);
|
||||
if (ci && ci->isFunc()) {
|
||||
if (ctx->cache->overloads[ci->canonicalName].size() == 1)
|
||||
if (ctx->cache->functions[ctx->cache->overloads[ci->canonicalName][0].name]
|
||||
.ast->attributes.isAttribute) {
|
||||
return &(ci->canonicalName);
|
||||
}
|
||||
if (ctx->cache->overloads[ci->canonicalName].size() == 1) {
|
||||
return {ctx->cache->functions[ctx->cache->overloads[ci->canonicalName][0].name]
|
||||
.ast->attributes.isAttribute,
|
||||
ci->canonicalName};
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
return {false, ""};
|
||||
}
|
||||
|
||||
} // namespace codon::ast
|
||||
|
|
|
@ -319,8 +319,7 @@ StmtPtr SimplifyVisitor::transformNewImport(const ImportFile &file) {
|
|||
// `import_[I]_done = False` (set to True upon successful import)
|
||||
preamble->push_back(N<AssignStmt>(N<IdExpr>(importDoneVar = importVar + "_done"),
|
||||
N<BoolExpr>(false)));
|
||||
if (!in(ctx->cache->globals, importDoneVar))
|
||||
ctx->cache->globals[importDoneVar] = nullptr;
|
||||
ctx->cache->addGlobal(importDoneVar);
|
||||
|
||||
// Wrap all imported top-level statements into a function.
|
||||
// Make sure to register the global variables and set their assignments as updates.
|
||||
|
@ -332,11 +331,8 @@ StmtPtr SimplifyVisitor::transformNewImport(const ImportFile &file) {
|
|||
if (!a->isUpdate() && a->lhs->getId()) {
|
||||
// Global `a = ...`
|
||||
auto val = ictx->forceFind(a->lhs->getId()->value);
|
||||
if (val->isVar() && val->isGlobal() && !getStaticGeneric(a->type.get())) {
|
||||
// Register global
|
||||
if (!in(ctx->cache->globals, val->canonicalName))
|
||||
ctx->cache->globals[val->canonicalName] = nullptr;
|
||||
}
|
||||
if (val->isVar() && val->isGlobal() && !getStaticGeneric(a->type.get()))
|
||||
ctx->cache->addGlobal(val->canonicalName);
|
||||
}
|
||||
}
|
||||
stmts.push_back(s);
|
||||
|
|
|
@ -97,7 +97,7 @@ void SimplifyVisitor::visit(ForStmt *stmt) {
|
|||
ctx->addVar(i->value, varName = ctx->generateCanonicalName(i->value),
|
||||
stmt->var->getSrcInfo());
|
||||
transform(stmt->var);
|
||||
transform(stmt->suite);
|
||||
stmt->suite = transform(N<SuiteStmt>(stmt->suite));
|
||||
} else {
|
||||
varName = ctx->cache->getTemporaryVar("for");
|
||||
ctx->addVar(varName, varName, stmt->var->getSrcInfo());
|
||||
|
@ -109,7 +109,7 @@ void SimplifyVisitor::visit(ForStmt *stmt) {
|
|||
stmts.push_back(stmt->suite);
|
||||
stmt->suite = transform(N<SuiteStmt>(stmts));
|
||||
}
|
||||
ctx->leaveConditionalBlock();
|
||||
ctx->leaveConditionalBlock(&(stmt->suite->getSuite()->stmts));
|
||||
// Dominate loop variables
|
||||
for (auto &var : ctx->getBase()->getLoop()->seenVars)
|
||||
ctx->findDominatingBinding(var);
|
||||
|
|
|
@ -20,12 +20,12 @@ using namespace types;
|
|||
/// @param cache Pointer to the shared cache ( @c Cache )
|
||||
/// @param file Filename to be used for error reporting
|
||||
/// @param barebones Use the bare-bones standard library for faster testing
|
||||
/// @param defines User-defined static values (typically passed as `codon run -DX=Y
|
||||
/// ...`).
|
||||
/// @param defines User-defined static values (typically passed as `codon run -DX=Y`).
|
||||
/// Each value is passed as a string.
|
||||
StmtPtr
|
||||
SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &file,
|
||||
const std::unordered_map<std::string, std::string> &defines,
|
||||
const std::unordered_map<std::string, std::string> &earlyDefines,
|
||||
bool barebones) {
|
||||
auto preamble = std::make_shared<std::vector<StmtPtr>>();
|
||||
seqassertn(cache->module, "cache's module is not set");
|
||||
|
@ -53,6 +53,20 @@ SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &fil
|
|||
stdlib->moduleName = {ImportFile::STDLIB, stdlibPath->path, "__init__"};
|
||||
// Load the standard library
|
||||
stdlib->setFilename(stdlibPath->path);
|
||||
// Core definitions
|
||||
preamble->push_back(SimplifyVisitor(stdlib, preamble)
|
||||
.transform(parseCode(stdlib->cache, stdlibPath->path,
|
||||
"from internal.core import *")));
|
||||
for (auto &d : earlyDefines) {
|
||||
// Load early compile-time defines (for standard library)
|
||||
preamble->push_back(
|
||||
SimplifyVisitor(stdlib, preamble)
|
||||
.transform(std::make_shared<AssignStmt>(
|
||||
std::make_shared<IdExpr>(d.first),
|
||||
std::make_shared<IntExpr>(d.second),
|
||||
std::make_shared<IndexExpr>(std::make_shared<IdExpr>("Static"),
|
||||
std::make_shared<IdExpr>("int")))));
|
||||
}
|
||||
preamble->push_back(SimplifyVisitor(stdlib, preamble)
|
||||
.transform(parseFile(stdlib->cache, stdlibPath->path)));
|
||||
stdlib->isStdlibLoading = false;
|
||||
|
@ -181,6 +195,7 @@ StmtPtr SimplifyVisitor::transform(StmtPtr &stmt) {
|
|||
stmt->accept(v);
|
||||
} catch (const exc::ParserException &e) {
|
||||
ctx->cache->errors.push_back(e);
|
||||
// throw;
|
||||
}
|
||||
ctx->popSrcInfo();
|
||||
if (v.resultStmt)
|
||||
|
|
|
@ -43,9 +43,11 @@ class SimplifyVisitor : public CallbackASTVisitor<ExprPtr, StmtPtr> {
|
|||
StmtPtr resultStmt;
|
||||
|
||||
public:
|
||||
static StmtPtr apply(Cache *cache, const StmtPtr &node, const std::string &file,
|
||||
const std::unordered_map<std::string, std::string> &defines,
|
||||
bool barebones = false);
|
||||
static StmtPtr
|
||||
apply(Cache *cache, const StmtPtr &node, const std::string &file,
|
||||
const std::unordered_map<std::string, std::string> &defines = {},
|
||||
const std::unordered_map<std::string, std::string> &earlyDefines = {},
|
||||
bool barebones = false);
|
||||
static StmtPtr apply(const std::shared_ptr<SimplifyContext> &cache,
|
||||
const StmtPtr &node, const std::string &file, int atAge = -1);
|
||||
|
||||
|
@ -169,7 +171,7 @@ private: // Node simplification rules
|
|||
StmtPtr transformPythonDefinition(const std::string &, const std::vector<Param> &,
|
||||
const Expr *, Stmt *);
|
||||
StmtPtr transformLLVMDefinition(Stmt *);
|
||||
std::string *isAttribute(const ExprPtr &);
|
||||
std::pair<bool, std::string> getDecorator(const ExprPtr &);
|
||||
|
||||
/* Classes (class.cpp) */
|
||||
void visit(ClassStmt *) override;
|
||||
|
|
|
@ -409,7 +409,10 @@ void TranslateVisitor::visit(ForStmt *stmt) {
|
|||
bool ordered = fc->funcGenerics[1].type->getStatic()->expr->staticValue.getInt();
|
||||
auto threads = transform(c->args[0].value);
|
||||
auto chunk = transform(c->args[1].value);
|
||||
os = std::make_unique<OMPSched>(schedule, threads, chunk, ordered);
|
||||
int64_t collapse =
|
||||
fc->funcGenerics[2].type->getStatic()->expr->staticValue.getInt();
|
||||
bool gpu = fc->funcGenerics[3].type->getStatic()->expr->staticValue.getInt();
|
||||
os = std::make_unique<OMPSched>(schedule, threads, chunk, ordered, collapse, gpu);
|
||||
LOG_TYPECHECK("parsed {}", stmt->decorator->toString());
|
||||
}
|
||||
|
||||
|
@ -487,7 +490,7 @@ void TranslateVisitor::visit(TryStmt *stmt) {
|
|||
}
|
||||
|
||||
void TranslateVisitor::visit(ThrowStmt *stmt) {
|
||||
result = make<ir::ThrowInstr>(stmt, transform(stmt->expr));
|
||||
result = make<ir::ThrowInstr>(stmt, stmt->expr ? transform(stmt->expr) : nullptr);
|
||||
}
|
||||
|
||||
void TranslateVisitor::visit(FunctionStmt *stmt) {
|
||||
|
|
|
@ -24,15 +24,6 @@ void TypecheckVisitor::visit(IdExpr *expr) {
|
|||
if (isTuple(expr->value))
|
||||
generateTuple(std::stoi(expr->value.substr(sizeof(TYPE_TUPLE) - 1)));
|
||||
|
||||
// Handle empty callable references
|
||||
if (expr->value == TYPE_CALLABLE) {
|
||||
auto typ = ctx->getUnbound();
|
||||
typ->getLink()->trait = std::make_shared<CallableTrait>(std::vector<TypePtr>{});
|
||||
unify(expr->type, typ);
|
||||
expr->markType();
|
||||
return;
|
||||
}
|
||||
|
||||
// Replace identifiers that have been superseded by domination analysis during the
|
||||
// simplification
|
||||
while (auto s = in(ctx->cache->replacements, expr->value))
|
||||
|
@ -160,6 +151,12 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
|
|||
if (expr->expr->type->getFunc() && expr->member == "__name__") {
|
||||
return transform(N<StringExpr>(expr->expr->type->toString()));
|
||||
}
|
||||
// Special case: fn.__llvm_name__
|
||||
if (expr->expr->type->getFunc() && expr->member == "__llvm_name__") {
|
||||
if (realize(expr->expr->type))
|
||||
return transform(N<StringExpr>(expr->expr->type->realizedName()));
|
||||
return nullptr;
|
||||
}
|
||||
// Special case: cls.__name__
|
||||
if (expr->expr->isType() && expr->member == "__name__") {
|
||||
if (realize(expr->expr->type))
|
||||
|
|
|
@ -69,7 +69,11 @@ void TypecheckVisitor::visit(AssignStmt *stmt) {
|
|||
seqassert(stmt->rhs->staticValue.evaluated, "static not evaluated");
|
||||
unify(stmt->lhs->type,
|
||||
unify(stmt->type->type, std::make_shared<StaticType>(stmt->rhs, ctx)));
|
||||
ctx->add(TypecheckItem::Var, lhs, stmt->lhs->type);
|
||||
auto val = ctx->add(TypecheckItem::Var, lhs, stmt->lhs->type);
|
||||
if (in(ctx->cache->globals, lhs)) {
|
||||
// Make globals always visible!
|
||||
ctx->addToplevel(lhs, val);
|
||||
}
|
||||
if (realize(stmt->lhs->type))
|
||||
stmt->setDone();
|
||||
} else {
|
||||
|
|
|
@ -320,7 +320,7 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
|
|||
if (!part.known.empty()) {
|
||||
auto e = getPartialArg(-1);
|
||||
auto t = e->getType()->getRecord();
|
||||
seqassert(t && startswith(t->name, "KwTuple"), "{} not a kwtuple",
|
||||
seqassert(t && startswith(t->name, TYPE_KWTUPLE), "{} not a kwtuple",
|
||||
e->toString());
|
||||
auto &ff = ctx->cache->classes[t->name].fields;
|
||||
for (int i = 0; i < t->getRecord()->args.size(); i++) {
|
||||
|
@ -410,8 +410,9 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
|
|||
if (typeArgs[si]) {
|
||||
auto typ = typeArgs[si]->type;
|
||||
if (calleeFn->funcGenerics[si].type->isStaticType()) {
|
||||
if (!typeArgs[si]->isStatic())
|
||||
if (!typeArgs[si]->isStatic()) {
|
||||
error("expected static expression");
|
||||
}
|
||||
typ = std::make_shared<StaticType>(typeArgs[si], ctx);
|
||||
}
|
||||
unify(typ, calleeFn->funcGenerics[si].type);
|
||||
|
@ -543,6 +544,10 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformSpecialCall(CallExpr *expr)
|
|||
return {true, transformCompileError(expr)};
|
||||
} else if (val == "tuple") {
|
||||
return {true, transformTupleFn(expr)};
|
||||
} else if (val == "__realized__") {
|
||||
return {true, transformRealizedFn(expr)};
|
||||
} else if (val == "__static_print__") {
|
||||
return {false, transformStaticPrintFn(expr)};
|
||||
} else {
|
||||
return {false, nullptr};
|
||||
}
|
||||
|
@ -668,7 +673,6 @@ ExprPtr TypecheckVisitor::transformArray(CallExpr *expr) {
|
|||
/// `isinstance(obj, ByVal)` is True if `type(obj)` is a tuple type
|
||||
/// `isinstance(obj, ByRef)` is True if `type(obj)` is a reference type
|
||||
ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
|
||||
expr->staticValue.type = StaticValue::INT;
|
||||
expr->setType(unify(expr->type, ctx->getType("bool")));
|
||||
transform(expr->args[0].value);
|
||||
auto typ = expr->args[0].value->type->getClass();
|
||||
|
@ -678,21 +682,44 @@ ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
|
|||
transform(expr->args[0].value); // transform again to realize it
|
||||
|
||||
auto &typExpr = expr->args[1].value;
|
||||
if (auto c = typExpr->getCall()) {
|
||||
// Handle `isinstance(obj, (type1, type2, ...))`
|
||||
if (typExpr->origExpr && typExpr->origExpr->getTuple()) {
|
||||
ExprPtr result = transform(N<BoolExpr>(false));
|
||||
for (auto &i : typExpr->origExpr->getTuple()->items) {
|
||||
result = transform(N<BinaryExpr>(
|
||||
result, "||",
|
||||
N<CallExpr>(N<IdExpr>("isinstance"), expr->args[0].value, i)));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
expr->staticValue.type = StaticValue::INT;
|
||||
if (typExpr->isId("Tuple") || typExpr->isId("tuple")) {
|
||||
return transform(N<BoolExpr>(startswith(typ->name, TYPE_TUPLE)));
|
||||
} else if (typExpr->isId("ByVal")) {
|
||||
return transform(N<BoolExpr>(typ->getRecord() != nullptr));
|
||||
} else if (typExpr->isId("ByRef")) {
|
||||
return transform(N<BoolExpr>(typ->getRecord() == nullptr));
|
||||
} else {
|
||||
transformType(typExpr);
|
||||
// Check super types (i.e., statically inherited) as well
|
||||
for (auto &tx : getSuperTypes(typ->getClass())) {
|
||||
if (tx->unify(typExpr->type.get(), nullptr) >= 0)
|
||||
return transform(N<BoolExpr>(true));
|
||||
} else if (typExpr->type->is("pyobj") && !typExpr->isType()) {
|
||||
if (typ->is("pyobj")) {
|
||||
expr->staticValue.type = StaticValue::NOT_STATIC;
|
||||
return transform(N<CallExpr>(N<IdExpr>("std.internal.python._isinstance:0"),
|
||||
expr->args[0].value, expr->args[1].value));
|
||||
} else {
|
||||
return transform(N<BoolExpr>(false));
|
||||
}
|
||||
return transform(N<BoolExpr>(false));
|
||||
}
|
||||
|
||||
transformType(typExpr);
|
||||
|
||||
// Check super types (i.e., statically inherited) as well
|
||||
for (auto &tx : getSuperTypes(typ->getClass())) {
|
||||
if (tx->unify(typExpr->type.get(), nullptr) >= 0)
|
||||
return transform(N<BoolExpr>(true));
|
||||
}
|
||||
return transform(N<BoolExpr>(false));
|
||||
}
|
||||
|
||||
/// Transform staticlen method to a static integer expression. This method supports only
|
||||
|
@ -801,6 +828,33 @@ ExprPtr TypecheckVisitor::transformTypeFn(CallExpr *expr) {
|
|||
return e;
|
||||
}
|
||||
|
||||
/// Transform __realized__ function to a fully realized type identifier.
|
||||
ExprPtr TypecheckVisitor::transformRealizedFn(CallExpr *expr) {
|
||||
auto call =
|
||||
transform(N<CallExpr>(expr->args[0].value, N<StarExpr>(expr->args[1].value)));
|
||||
if (!call->getCall()->expr->type->getFunc())
|
||||
error("the first argument must be a function");
|
||||
if (auto f = realize(call->getCall()->expr->type)) {
|
||||
auto e = N<IdExpr>(f->getFunc()->realizedName());
|
||||
e->setType(f);
|
||||
e->setDone();
|
||||
return e;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Transform __static_print__ function to a fully realized type identifier.
|
||||
ExprPtr TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) {
|
||||
auto &args = expr->args[0].value->getCall()->args;
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
realize(args[i].value->type);
|
||||
fmt::print(stderr, "[static_print] {}: {} := {}\n", getSrcInfo(),
|
||||
FormatVisitor::apply(args[i].value),
|
||||
args[i].value->type ? args[i].value->type->debugString(1) : "-");
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Get the list that describes the inheritance hierarchy of a given type.
|
||||
/// The first type in the list is the most recently inherited type.
|
||||
std::vector<ClassTypePtr> TypecheckVisitor::getSuperTypes(const ClassTypePtr &cls) {
|
||||
|
|
|
@ -55,6 +55,15 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
|
|||
unify(defType->type, generic);
|
||||
}
|
||||
}
|
||||
if (auto ti = CAST(a.type, InstantiateExpr)) {
|
||||
// Parse TraitVar
|
||||
seqassert(ti->typeExpr->isId(TYPE_TYPEVAR), "not a TypeVar instantiation");
|
||||
auto l = transformType(ti->typeParams[0])->type;
|
||||
if (l->getLink() && l->getLink()->trait)
|
||||
generic->getLink()->trait = l->getLink()->trait;
|
||||
else
|
||||
generic->getLink()->trait = std::make_shared<types::TypeTrait>(l);
|
||||
}
|
||||
ctx->add(TypecheckItem::Type, a.name, generic);
|
||||
ClassType::Generic g{a.name, ctx->cache->rev(a.name),
|
||||
generic->generalize(ctx->typecheckLevel), typId};
|
||||
|
@ -135,6 +144,18 @@ std::string TypecheckVisitor::generateTuple(size_t len, const std::string &name,
|
|||
args.emplace_back(Param(format("T{}", i + 1), N<IdExpr>("type"), nullptr, true));
|
||||
StmtPtr stmt = N<ClassStmt>(ctx->cache->generateSrcInfo(), typeName, args, nullptr,
|
||||
std::vector<ExprPtr>{N<IdExpr>("tuple")});
|
||||
|
||||
// Add getItem for KwArgs:
|
||||
// `def __getitem__(self, key: Static[str]): return getattr(self, key)`
|
||||
auto getItem = N<FunctionStmt>(
|
||||
"__getitem__", nullptr,
|
||||
std::vector<Param>{Param{"self"}, Param{"key", N<IndexExpr>(N<IdExpr>("Static"),
|
||||
N<IdExpr>("str"))}},
|
||||
N<SuiteStmt>(N<ReturnStmt>(
|
||||
N<CallExpr>(N<IdExpr>("getattr"), N<IdExpr>("self"), N<IdExpr>("key")))));
|
||||
if (startswith(typeName, TYPE_KWTUPLE))
|
||||
stmt->getClass()->suite = getItem;
|
||||
|
||||
// Simplify in the standard library context and type check
|
||||
stmt = SimplifyVisitor::apply(ctx->cache->imports[STDLIB_IMPORT].ctx, stmt,
|
||||
FILE_GENERATED, 0);
|
||||
|
|
|
@ -36,7 +36,10 @@ void TypecheckVisitor::visit(IfExpr *expr) {
|
|||
if (expr->cond->isStatic()) {
|
||||
resultExpr = evaluateStaticCondition(
|
||||
expr->cond,
|
||||
[&](bool isTrue) { return transform(isTrue ? expr->ifexpr : expr->elsexpr); },
|
||||
[&](bool isTrue) {
|
||||
LOG_TYPECHECK("[static::cond] {}: {}", getSrcInfo(), isTrue);
|
||||
return transform(isTrue ? expr->ifexpr : expr->elsexpr);
|
||||
},
|
||||
[&]() -> ExprPtr {
|
||||
// Check if both subexpressions are static; if so, this if expression is also
|
||||
// static and should be marked as such
|
||||
|
@ -83,6 +86,7 @@ void TypecheckVisitor::visit(IfStmt *stmt) {
|
|||
resultStmt = evaluateStaticCondition(
|
||||
stmt->cond,
|
||||
[&](bool isTrue) {
|
||||
LOG_TYPECHECK("[static::cond] {}: {}", getSrcInfo(), isTrue);
|
||||
auto t = transform(isTrue ? stmt->ifSuite : stmt->elseSuite);
|
||||
return t ? t : transform(N<SuiteStmt>());
|
||||
},
|
||||
|
|
|
@ -83,12 +83,13 @@ types::TypePtr TypeContext::instantiate(const SrcInfo &srcInfo,
|
|||
const types::ClassTypePtr &generics) {
|
||||
seqassert(type, "type is null");
|
||||
std::unordered_map<int, types::TypePtr> genericCache;
|
||||
if (generics)
|
||||
if (generics) {
|
||||
for (auto &g : generics->generics)
|
||||
if (g.type &&
|
||||
!(g.type->getLink() && g.type->getLink()->kind == types::LinkType::Generic)) {
|
||||
genericCache[g.id] = g.type;
|
||||
}
|
||||
}
|
||||
auto t = type->instantiate(typecheckLevel, &(cache->unboundCount), &genericCache);
|
||||
for (auto &i : genericCache) {
|
||||
if (auto l = i.second->getLink()) {
|
||||
|
|
|
@ -66,6 +66,8 @@ struct TypeContext : public Context<TypecheckItem> {
|
|||
int blockLevel;
|
||||
/// True if an early return is found (anything afterwards won't be typechecked)
|
||||
bool returnEarly;
|
||||
/// Stack of static loop control variables (used to emulate goto statements).
|
||||
std::vector<std::string> staticLoops;
|
||||
|
||||
public:
|
||||
explicit TypeContext(Cache *cache);
|
||||
|
|
|
@ -9,38 +9,97 @@ namespace codon::ast {
|
|||
|
||||
using namespace types;
|
||||
|
||||
/// Typecheck try-except statements.
|
||||
/// Typecheck try-except statements. Handle Python exceptions separately.
|
||||
/// @example
|
||||
/// ```try: ...
|
||||
/// except python.Error as e: ...
|
||||
/// except PyExc as f: ...
|
||||
/// except ValueError as g: ...
|
||||
/// ``` -> ```
|
||||
/// try: ...
|
||||
/// except ValueError as g: ... # ValueError
|
||||
/// except PyExc as exc:
|
||||
/// while True:
|
||||
/// if isinstance(exc.pytype, python.Error): # python.Error
|
||||
/// e = exc.pytype; ...; break
|
||||
/// f = exc; ...; break # PyExc
|
||||
/// raise```
|
||||
void TypecheckVisitor::visit(TryStmt *stmt) {
|
||||
ctx->blockLevel++;
|
||||
transform(stmt->suite);
|
||||
ctx->blockLevel--;
|
||||
|
||||
std::vector<TryStmt::Catch> catches;
|
||||
auto pyVar = ctx->cache->getTemporaryVar("pyexc");
|
||||
auto pyCatchStmt = N<WhileStmt>(N<BoolExpr>(true), N<SuiteStmt>());
|
||||
|
||||
auto done = stmt->suite->isDone();
|
||||
for (auto &c : stmt->catches) {
|
||||
transformType(c.exc);
|
||||
if (!c.var.empty()) {
|
||||
// Handle dominated except bindings
|
||||
auto changed = in(ctx->cache->replacements, c.var);
|
||||
while (auto s = in(ctx->cache->replacements, c.var))
|
||||
c.var = s->first, changed = s;
|
||||
if (changed && changed->second) {
|
||||
auto update =
|
||||
N<AssignStmt>(N<IdExpr>(format("{}.__used__", c.var)), N<BoolExpr>(true));
|
||||
update->setUpdate();
|
||||
c.suite = N<SuiteStmt>(update, c.suite);
|
||||
transform(c.exc);
|
||||
if (c.exc && c.exc->type->is("pyobj")) {
|
||||
// Transform python.Error exceptions
|
||||
if (!c.var.empty()) {
|
||||
c.suite = N<SuiteStmt>(
|
||||
N<AssignStmt>(N<IdExpr>(c.var), N<DotExpr>(N<IdExpr>(pyVar), "pytype")),
|
||||
c.suite);
|
||||
}
|
||||
if (changed)
|
||||
c.exc->setAttr(ExprAttr::Dominated);
|
||||
auto val = ctx->find(c.var);
|
||||
if (!changed)
|
||||
val = ctx->add(TypecheckItem::Var, c.var, c.exc->getType());
|
||||
unify(val->type, c.exc->getType());
|
||||
c.suite =
|
||||
N<IfStmt>(N<CallExpr>(N<IdExpr>("isinstance"),
|
||||
N<DotExpr>(N<IdExpr>(pyVar), "pytype"), clone(c.exc)),
|
||||
N<SuiteStmt>(c.suite, N<BreakStmt>()), nullptr);
|
||||
pyCatchStmt->suite->getSuite()->stmts.push_back(c.suite);
|
||||
} else if (c.exc && c.exc->type->is("std.internal.types.error.PyError")) {
|
||||
// Transform PyExc exceptions
|
||||
if (!c.var.empty()) {
|
||||
c.suite =
|
||||
N<SuiteStmt>(N<AssignStmt>(N<IdExpr>(c.var), N<IdExpr>(pyVar)), c.suite);
|
||||
}
|
||||
c.suite = N<SuiteStmt>(c.suite, N<BreakStmt>());
|
||||
pyCatchStmt->suite->getSuite()->stmts.push_back(c.suite);
|
||||
} else {
|
||||
// Handle all other exceptions
|
||||
transformType(c.exc);
|
||||
if (!c.var.empty()) {
|
||||
// Handle dominated except bindings
|
||||
auto changed = in(ctx->cache->replacements, c.var);
|
||||
while (auto s = in(ctx->cache->replacements, c.var))
|
||||
c.var = s->first, changed = s;
|
||||
if (changed && changed->second) {
|
||||
auto update =
|
||||
N<AssignStmt>(N<IdExpr>(format("{}.__used__", c.var)), N<BoolExpr>(true));
|
||||
update->setUpdate();
|
||||
c.suite = N<SuiteStmt>(update, c.suite);
|
||||
}
|
||||
if (changed)
|
||||
c.exc->setAttr(ExprAttr::Dominated);
|
||||
auto val = ctx->find(c.var);
|
||||
if (!changed)
|
||||
val = ctx->add(TypecheckItem::Var, c.var, c.exc->getType());
|
||||
unify(val->type, c.exc->getType());
|
||||
}
|
||||
ctx->blockLevel++;
|
||||
transform(c.suite);
|
||||
ctx->blockLevel--;
|
||||
done &= (!c.exc || c.exc->isDone()) && c.suite->isDone();
|
||||
catches.push_back(c);
|
||||
}
|
||||
}
|
||||
if (!pyCatchStmt->suite->getSuite()->stmts.empty()) {
|
||||
// Process PyError catches
|
||||
auto exc = N<IdExpr>("std.internal.types.error.PyError");
|
||||
exc->markType();
|
||||
pyCatchStmt->suite->getSuite()->stmts.push_back(N<ThrowStmt>(nullptr));
|
||||
TryStmt::Catch c{pyVar, transformType(exc), pyCatchStmt};
|
||||
|
||||
auto val = ctx->add(TypecheckItem::Var, pyVar, c.exc->getType());
|
||||
unify(val->type, c.exc->getType());
|
||||
ctx->blockLevel++;
|
||||
transform(c.suite);
|
||||
ctx->blockLevel--;
|
||||
done &= (!c.exc || c.exc->isDone()) && c.suite->isDone();
|
||||
catches.push_back(c);
|
||||
}
|
||||
stmt->catches = catches;
|
||||
if (stmt->finally) {
|
||||
ctx->blockLevel++;
|
||||
transform(stmt->finally);
|
||||
|
@ -56,6 +115,11 @@ void TypecheckVisitor::visit(TryStmt *stmt) {
|
|||
/// @example
|
||||
/// `raise exc` -> ```raise __internal__.set_header(exc, "fn", "file", line, col)```
|
||||
void TypecheckVisitor::visit(ThrowStmt *stmt) {
|
||||
if (!stmt->expr) {
|
||||
stmt->setDone();
|
||||
return;
|
||||
}
|
||||
|
||||
transform(stmt->expr);
|
||||
|
||||
if (!(stmt->expr->getCall() &&
|
||||
|
|
|
@ -40,16 +40,16 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) {
|
|||
|
||||
unify(ctx->getRealizationBase()->returnType, stmt->expr->type);
|
||||
} else {
|
||||
// If we are not within conditional block, ignore later statements in this function.
|
||||
// Useful with static if statements.
|
||||
if (!ctx->blockLevel)
|
||||
ctx->returnEarly = true;
|
||||
|
||||
// Just set the expr for the translation stage. However, do not unify the return
|
||||
// type! This might be a `return` in a generator.
|
||||
stmt->expr = transform(N<CallExpr>(N<IdExpr>("NoneType")));
|
||||
}
|
||||
|
||||
// If we are not within conditional block, ignore later statements in this function.
|
||||
// Useful with static if statements.
|
||||
if (!ctx->blockLevel)
|
||||
ctx->returnEarly = true;
|
||||
|
||||
if (stmt->expr->isDone())
|
||||
stmt->setDone();
|
||||
}
|
||||
|
|
|
@ -52,8 +52,11 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
|
|||
ctx->typecheckLevel++;
|
||||
auto changedNodes = ctx->changedNodes;
|
||||
ctx->changedNodes = 0;
|
||||
auto returnEarly = ctx->returnEarly;
|
||||
ctx->returnEarly = false;
|
||||
TypecheckVisitor(ctx).transform(result);
|
||||
std::swap(ctx->changedNodes, changedNodes);
|
||||
std::swap(ctx->returnEarly, returnEarly);
|
||||
ctx->typecheckLevel--;
|
||||
|
||||
if (iteration == 1 && isToplevel) {
|
||||
|
@ -282,10 +285,7 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) {
|
|||
|
||||
if (hasAst) {
|
||||
auto oldBlockLevel = ctx->blockLevel;
|
||||
auto oldReturnEarly = ctx->returnEarly;
|
||||
ctx->blockLevel = 0;
|
||||
ctx->returnEarly = false;
|
||||
|
||||
if (startswith(type->ast->name, "Function.__call__")) {
|
||||
// Special case: Function.__call__
|
||||
/// TODO: move to IR
|
||||
|
@ -313,7 +313,6 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) {
|
|||
}
|
||||
inferTypes(ast->suite);
|
||||
ctx->blockLevel = oldBlockLevel;
|
||||
ctx->returnEarly = oldReturnEarly;
|
||||
|
||||
// Use NoneType as the return type when the return type is not specified and
|
||||
// function has no return statement
|
||||
|
@ -382,6 +381,7 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
|
|||
// Get the IR type
|
||||
auto *module = ctx->cache->module;
|
||||
ir::types::Type *handle = nullptr;
|
||||
|
||||
if (t->name == "bool") {
|
||||
handle = module->getBoolType();
|
||||
} else if (t->name == "byte") {
|
||||
|
@ -390,6 +390,8 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
|
|||
handle = module->getIntType();
|
||||
} else if (t->name == "float") {
|
||||
handle = module->getFloatType();
|
||||
} else if (t->name == "float32") {
|
||||
handle = module->getFloat32Type();
|
||||
} else if (t->name == "str") {
|
||||
handle = module->getStringType();
|
||||
} else if (t->name == "Int" || t->name == "UInt") {
|
||||
|
@ -415,6 +417,9 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
|
|||
types.push_back(forceFindIRType(m));
|
||||
auto ret = forceFindIRType(t->generics[1].type);
|
||||
handle = module->unsafeGetFuncType(realizedName, ret, types);
|
||||
} else if (t->name == "std.simd.Vec") {
|
||||
seqassert(types.size() == 1 && statics.size() == 1, "bad generics/statics");
|
||||
handle = module->unsafeGetVectorType(statics[0]->getInt(), types[0]);
|
||||
} else if (auto tr = t->getRecord()) {
|
||||
std::vector<ir::types::Type *> typeArgs;
|
||||
std::vector<std::string> names;
|
||||
|
|
|
@ -13,17 +13,32 @@ namespace codon::ast {
|
|||
using namespace types;
|
||||
|
||||
/// Nothing to typecheck; just call setDone
|
||||
void TypecheckVisitor::visit(BreakStmt *stmt) { stmt->setDone(); }
|
||||
void TypecheckVisitor::visit(BreakStmt *stmt) {
|
||||
stmt->setDone();
|
||||
if (!ctx->staticLoops.back().empty()) {
|
||||
auto a = N<AssignStmt>(N<IdExpr>(ctx->staticLoops.back()), N<BoolExpr>(false));
|
||||
a->setUpdate();
|
||||
resultStmt = transform(N<SuiteStmt>(a, stmt->clone()));
|
||||
}
|
||||
}
|
||||
|
||||
/// Nothing to typecheck; just call setDone
|
||||
void TypecheckVisitor::visit(ContinueStmt *stmt) { stmt->setDone(); }
|
||||
void TypecheckVisitor::visit(ContinueStmt *stmt) {
|
||||
stmt->setDone();
|
||||
if (!ctx->staticLoops.back().empty()) {
|
||||
resultStmt = N<BreakStmt>();
|
||||
resultStmt->setDone();
|
||||
}
|
||||
}
|
||||
|
||||
/// Typecheck while statements.
|
||||
void TypecheckVisitor::visit(WhileStmt *stmt) {
|
||||
ctx->staticLoops.push_back(stmt->gotoVar.empty() ? "" : stmt->gotoVar);
|
||||
transform(stmt->cond);
|
||||
ctx->blockLevel++;
|
||||
transform(stmt->suite);
|
||||
ctx->blockLevel--;
|
||||
ctx->staticLoops.pop_back();
|
||||
|
||||
if (stmt->cond->isDone() && stmt->suite->isDone())
|
||||
stmt->setDone();
|
||||
|
@ -40,6 +55,9 @@ void TypecheckVisitor::visit(ForStmt *stmt) {
|
|||
if (!iterType)
|
||||
return; // wait until the iterator is known
|
||||
|
||||
if ((resultStmt = transformStaticForLoop(stmt)))
|
||||
return;
|
||||
|
||||
bool maybeHeterogenous = startswith(iterType->name, TYPE_TUPLE) ||
|
||||
startswith(iterType->name, TYPE_KWTUPLE);
|
||||
if (maybeHeterogenous && !iterType->canRealize()) {
|
||||
|
@ -83,9 +101,11 @@ void TypecheckVisitor::visit(ForStmt *stmt) {
|
|||
unify(stmt->var->type,
|
||||
iterType ? unify(val->type, iterType->generics[0].type) : val->type);
|
||||
|
||||
ctx->staticLoops.push_back("");
|
||||
ctx->blockLevel++;
|
||||
transform(stmt->suite);
|
||||
ctx->blockLevel--;
|
||||
ctx->staticLoops.pop_back();
|
||||
|
||||
if (stmt->iter->isDone() && stmt->suite->isDone())
|
||||
stmt->setDone();
|
||||
|
@ -132,4 +152,81 @@ StmtPtr TypecheckVisitor::transformHeterogenousTupleFor(ForStmt *stmt) {
|
|||
return block;
|
||||
}
|
||||
|
||||
/// Handle static for constructs.
|
||||
/// @example
|
||||
/// `for i in statictuple(1, x): <suite>` ->
|
||||
/// ```loop = True
|
||||
/// while loop:
|
||||
/// while loop:
|
||||
/// i: Static[int] = 1; <suite>; break
|
||||
/// while loop:
|
||||
/// i = x; <suite>; break
|
||||
/// loop = False # also set to False on break
|
||||
/// A separate suite is generated for each static iteration.
|
||||
StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) {
|
||||
auto loopVar = ctx->cache->getTemporaryVar("loop");
|
||||
auto fn = [&](const std::string &var, const ExprPtr &expr) {
|
||||
bool staticInt = expr->isStatic();
|
||||
auto t = N<IndexExpr>(
|
||||
N<IdExpr>("Static"),
|
||||
N<IdExpr>(expr->staticValue.type == StaticValue::INT ? "int" : "str"));
|
||||
t->markType();
|
||||
auto brk = N<BreakStmt>();
|
||||
brk->setDone(); // Avoid transforming this one to continue
|
||||
// var [: Static] := expr; suite...
|
||||
auto loop = N<WhileStmt>(N<IdExpr>(loopVar),
|
||||
N<SuiteStmt>(N<AssignStmt>(N<IdExpr>(var), expr->clone(),
|
||||
staticInt ? t : nullptr),
|
||||
clone(stmt->suite), brk));
|
||||
loop->gotoVar = loopVar;
|
||||
return loop;
|
||||
};
|
||||
|
||||
auto var = stmt->var->getId()->value;
|
||||
if (!stmt->iter->getCall() || !stmt->iter->getCall()->expr->getId())
|
||||
return nullptr;
|
||||
auto iter = stmt->iter->getCall()->expr->getId();
|
||||
auto block = N<SuiteStmt>();
|
||||
if (iter && startswith(iter->value, "statictuple:0")) {
|
||||
auto &args = stmt->iter->getCall()->args[0].value->getCall()->args;
|
||||
for (size_t i = 0; i < args.size(); i++)
|
||||
block->stmts.push_back(fn(var, args[i].value));
|
||||
} else if (iter &&
|
||||
startswith(iter->value, "std.internal.types.range.staticrange:0")) {
|
||||
int st =
|
||||
iter->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt();
|
||||
int ed =
|
||||
iter->type->getFunc()->funcGenerics[1].type->getStatic()->evaluate().getInt();
|
||||
int step =
|
||||
iter->type->getFunc()->funcGenerics[2].type->getStatic()->evaluate().getInt();
|
||||
if (abs(st - ed) / abs(step) > MAX_STATIC_ITER)
|
||||
error("staticrange out of bounds ({} > {})", abs(st - ed) / abs(step),
|
||||
MAX_STATIC_ITER);
|
||||
for (int i = st; step > 0 ? i < ed : i > ed; i += step)
|
||||
block->stmts.push_back(fn(var, N<IntExpr>(i)));
|
||||
} else if (iter &&
|
||||
startswith(iter->value, "std.internal.types.range.staticrange:1")) {
|
||||
int ed =
|
||||
iter->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt();
|
||||
if (ed > MAX_STATIC_ITER)
|
||||
error("staticrange out of bounds ({} > {})", ed, MAX_STATIC_ITER);
|
||||
for (int i = 0; i < ed; i++)
|
||||
block->stmts.push_back(fn(var, N<IntExpr>(i)));
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
ctx->blockLevel++;
|
||||
|
||||
// Close the loop
|
||||
auto a = N<AssignStmt>(N<IdExpr>(loopVar), N<BoolExpr>(false));
|
||||
a->setUpdate();
|
||||
block->stmts.push_back(a);
|
||||
|
||||
auto loop =
|
||||
transform(N<SuiteStmt>(N<AssignStmt>(N<IdExpr>(loopVar), N<BoolExpr>(true)),
|
||||
N<WhileStmt>(N<IdExpr>(loopVar), block)));
|
||||
ctx->blockLevel--;
|
||||
return loop;
|
||||
}
|
||||
|
||||
} // namespace codon::ast
|
||||
|
|
|
@ -262,15 +262,20 @@ void TypecheckVisitor::visit(IndexExpr *expr) {
|
|||
/// @example
|
||||
/// Instantiate(foo, [bar]) -> Id("foo[bar]")
|
||||
void TypecheckVisitor::visit(InstantiateExpr *expr) {
|
||||
// Infer the expression type
|
||||
transformType(expr->typeExpr);
|
||||
TypePtr typ =
|
||||
ctx->instantiate(expr->typeExpr->getSrcInfo(), expr->typeExpr->getType());
|
||||
seqassert(typ->getClass(), "unknown type: {}", expr->typeExpr->toString());
|
||||
|
||||
auto &generics = typ->getClass()->generics;
|
||||
if (expr->typeParams.size() != generics.size())
|
||||
error("expected {} generics and/or statics", generics.size());
|
||||
|
||||
if (expr->typeExpr->isId(TYPE_CALLABLE)) {
|
||||
// Case: Callable[...] instantiation
|
||||
// Case: Callable[...] trait instantiation
|
||||
std::vector<TypePtr> types;
|
||||
|
||||
// Callable error checking.
|
||||
/// TODO: move to Codon?
|
||||
if (expr->typeParams.size() != 2)
|
||||
error("invalid Callable type declaration");
|
||||
for (auto &typeParam : expr->typeParams) {
|
||||
transformType(typeParam);
|
||||
if (typeParam->type->isStaticType())
|
||||
|
@ -281,16 +286,13 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) {
|
|||
// Set up the Callable trait
|
||||
typ->getLink()->trait = std::make_shared<CallableTrait>(types);
|
||||
unify(expr->type, typ);
|
||||
} else if (expr->typeExpr->isId(TYPE_TYPEVAR)) {
|
||||
// Case: TypeVar[...] trait instantiation
|
||||
transformType(expr->typeParams[0]);
|
||||
auto typ = ctx->getUnbound();
|
||||
typ->getLink()->trait = std::make_shared<TypeTrait>(expr->typeParams[0]->type);
|
||||
unify(expr->type, typ);
|
||||
} else {
|
||||
transformType(expr->typeExpr);
|
||||
TypePtr typ =
|
||||
ctx->instantiate(expr->typeExpr->getSrcInfo(), expr->typeExpr->getType());
|
||||
seqassert(typ->getClass(), "unknown type");
|
||||
|
||||
auto &generics = typ->getClass()->generics;
|
||||
if (expr->typeParams.size() != generics.size())
|
||||
error("expected {} generics and/or statics", generics.size());
|
||||
|
||||
for (size_t i = 0; i < expr->typeParams.size(); i++) {
|
||||
transform(expr->typeParams[i]);
|
||||
TypePtr t = nullptr;
|
||||
|
@ -339,7 +341,9 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
|
|||
if (expr->expr->staticValue.type == StaticValue::STRING) {
|
||||
if (expr->op == "!") {
|
||||
if (expr->expr->staticValue.evaluated) {
|
||||
return transform(N<BoolExpr>(expr->expr->staticValue.getString().empty()));
|
||||
bool value = expr->expr->staticValue.getString().empty();
|
||||
LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
|
||||
return transform(N<BoolExpr>(value));
|
||||
} else {
|
||||
// Cannot be evaluated yet: just set the type
|
||||
unify(expr->type, ctx->getType("bool"));
|
||||
|
@ -360,6 +364,7 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
|
|||
value = -value;
|
||||
else
|
||||
value = !bool(value);
|
||||
LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
|
||||
if (expr->op == "!")
|
||||
return transform(N<BoolExpr>(bool(value)));
|
||||
else
|
||||
|
@ -375,6 +380,25 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
/// Division and modulus implementations.
|
||||
std::pair<int, int> divMod(const std::shared_ptr<TypeContext> &ctx, int a, int b) {
|
||||
if (!b)
|
||||
error(ctx->getSrcInfo(), "static division by zero");
|
||||
if (ctx->cache->pythonCompat) {
|
||||
// Use Python implementation.
|
||||
int d = a / b;
|
||||
int m = a - d * b;
|
||||
if (m && ((b ^ m) < 0)) {
|
||||
m += b;
|
||||
d -= 1;
|
||||
}
|
||||
return {d, m};
|
||||
} else {
|
||||
// Use C implementation.
|
||||
return {a / b, a % b};
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate a static binary expression and return the resulting static expression.
|
||||
/// If the expression cannot be evaluated yet, return nullptr.
|
||||
/// Supported operators: (strings) +, ==, !=
|
||||
|
@ -385,8 +409,10 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
|
|||
if (expr->op == "+") {
|
||||
// `"a" + "b"` -> `"ab"`
|
||||
if (expr->lexpr->staticValue.evaluated && expr->rexpr->staticValue.evaluated) {
|
||||
return transform(N<StringExpr>(expr->lexpr->staticValue.getString() +
|
||||
expr->rexpr->staticValue.getString()));
|
||||
auto value =
|
||||
expr->lexpr->staticValue.getString() + expr->rexpr->staticValue.getString();
|
||||
LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value);
|
||||
return transform(N<StringExpr>(value));
|
||||
} else {
|
||||
// Cannot be evaluated yet: just set the type
|
||||
if (!expr->isStatic())
|
||||
|
@ -398,7 +424,9 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
|
|||
if (expr->lexpr->staticValue.evaluated && expr->rexpr->staticValue.evaluated) {
|
||||
bool eq = expr->lexpr->staticValue.getString() ==
|
||||
expr->rexpr->staticValue.getString();
|
||||
return transform(N<BoolExpr>(expr->op == "==" ? eq : !eq));
|
||||
bool value = expr->op == "==" ? eq : !eq;
|
||||
LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value);
|
||||
return transform(N<BoolExpr>(value));
|
||||
} else {
|
||||
// Cannot be evaluated yet: just set the type
|
||||
if (!expr->isStatic())
|
||||
|
@ -441,18 +469,13 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
|
|||
lvalue = lvalue & rvalue;
|
||||
else if (expr->op == "|")
|
||||
lvalue = lvalue | rvalue;
|
||||
else if (expr->op == "//") {
|
||||
if (!rvalue)
|
||||
error("static division by zero");
|
||||
lvalue = lvalue / rvalue;
|
||||
} else if (expr->op == "%") {
|
||||
if (!rvalue)
|
||||
error("static division by zero");
|
||||
lvalue = lvalue % rvalue;
|
||||
} else {
|
||||
else if (expr->op == "//")
|
||||
lvalue = divMod(ctx, lvalue, rvalue).first;
|
||||
else if (expr->op == "%")
|
||||
lvalue = divMod(ctx, lvalue, rvalue).second;
|
||||
else
|
||||
seqassert(false, "unknown static operator {}", expr->op);
|
||||
}
|
||||
|
||||
LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), lvalue);
|
||||
if (in(std::set<std::string>{"==", "!=", "<", "<=", ">", ">=", "&&", "||"},
|
||||
expr->op))
|
||||
return transform(N<BoolExpr>(bool(lvalue)));
|
||||
|
|
|
@ -47,6 +47,7 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) {
|
|||
ctx->popSrcInfo();
|
||||
if (v.resultExpr) {
|
||||
v.resultExpr->attributes |= expr->attributes;
|
||||
v.resultExpr->origExpr = expr;
|
||||
expr = v.resultExpr;
|
||||
}
|
||||
seqassert(expr->type, "type not set for {}", expr->toString());
|
||||
|
|
|
@ -135,6 +135,8 @@ private: // Node typechecking rules
|
|||
ExprPtr transformCompileError(CallExpr *expr);
|
||||
ExprPtr transformTupleFn(CallExpr *expr);
|
||||
ExprPtr transformTypeFn(CallExpr *expr);
|
||||
ExprPtr transformRealizedFn(CallExpr *expr);
|
||||
ExprPtr transformStaticPrintFn(CallExpr *expr);
|
||||
std::vector<types::ClassTypePtr> getSuperTypes(const types::ClassTypePtr &cls);
|
||||
void addFunctionGenerics(const types::FuncType *t);
|
||||
std::string generatePartialStub(const std::vector<char> &mask, types::FuncType *fn);
|
||||
|
@ -151,6 +153,7 @@ private: // Node typechecking rules
|
|||
void visit(WhileStmt *) override;
|
||||
void visit(ForStmt *) override;
|
||||
StmtPtr transformHeterogenousTupleFor(ForStmt *);
|
||||
StmtPtr transformStaticForLoop(ForStmt *);
|
||||
|
||||
/* Errors and exceptions (error.cpp) */
|
||||
void visit(TryStmt *) override;
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
|
||||
#include "lib.h"
|
||||
|
||||
#ifdef CODON_GPU
|
||||
|
||||
#include "cuda.h"
|
||||
|
||||
#define fail(err) \
|
||||
do { \
|
||||
const char *msg; \
|
||||
cuGetErrorString((err), &msg); \
|
||||
fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, msg); \
|
||||
abort(); \
|
||||
} while (0)
|
||||
|
||||
#define check(call) \
|
||||
do { \
|
||||
auto err = (call); \
|
||||
if (err != CUDA_SUCCESS) { \
|
||||
fail(err); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
static std::vector<CUmodule> modules;
|
||||
static CUcontext context;
|
||||
|
||||
// One-time CUDA driver initialization: bring up the driver API, take
// device 0 and create a context for it (stored in the file-static
// `context`). Any CUDA error aborts the process via check()/fail().
void seq_nvptx_init() {
  CUdevice device;
  check(cuInit(0));
  check(cuDeviceGet(&device, 0));
  check(cuCtxCreate(&context, 0, device));
}
|
||||
|
||||
// Load a compiled PTX/cubin module from `filename` and append it to the
// global module list searched by seq_nvptx_function(). Aborts on error.
SEQ_FUNC void seq_nvptx_load_module(const char *filename) {
  CUmodule module;
  check(cuModuleLoad(&module, filename));
  modules.push_back(module);
}
|
||||
|
||||
// Return the number of CUDA-capable devices visible to the driver.
SEQ_FUNC seq_int_t seq_nvptx_device_count() {
  int devCount;
  check(cuDeviceGetCount(&devCount));
  return devCount;
}
|
||||
|
||||
// Return the device's name as a runtime string. The result is copied
// into GC-managed, pointer-free memory (seq_alloc_atomic) and is NOT
// NUL-terminated; the length travels in the seq_str_t.
SEQ_FUNC seq_str_t seq_nvptx_device_name(CUdevice device) {
  char name[128];
  // Reserve one byte so the driver can always NUL-terminate its write.
  check(cuDeviceGetName(name, sizeof(name) - 1, device));
  auto sz = static_cast<seq_int_t>(strlen(name));
  auto *p = (char *)seq_alloc_atomic(sz);
  memcpy(p, name, sz);
  return {sz, p};
}
|
||||
|
||||
// Return the device's compute capability packed into one 64-bit value:
// major version in the high 32 bits, minor in the low 32 bits.
// NOTE(review): cuDeviceComputeCapability is deprecated in newer CUDA
// drivers in favor of cuDeviceGetAttribute — confirm minimum driver.
SEQ_FUNC seq_int_t seq_nvptx_device_capability(CUdevice device) {
  int devMajor, devMinor;
  check(cuDeviceComputeCapability(&devMajor, &devMinor, device));
  return ((seq_int_t)devMajor << 32) | (seq_int_t)devMinor;
}
|
||||
|
||||
// Return the driver handle for device ordinal `idx`. Aborts if the
// ordinal is out of range (driver reports an error).
SEQ_FUNC CUdevice seq_nvptx_device(seq_int_t idx) {
  CUdevice device;
  check(cuDeviceGet(&device, idx));
  return device;
}
|
||||
|
||||
// True if `c` may appear in a sanitized kernel symbol name:
// letters and '_' anywhere, digits everywhere except the first position.
static bool name_char_valid(char c, bool first) {
  if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '_')
    return true;
  return !first && ('0' <= c && c <= '9');
}
|
||||
|
||||
// Look up a kernel function by name, searching the most recently loaded
// modules first (newer modules shadow older ones). The name is sanitized
// the same way the compiler sanitizes kernel symbols (invalid characters
// become '_'). Aborts on any lookup error other than "not found", and
// also aborts if no module contains the function.
SEQ_FUNC CUfunction seq_nvptx_function(seq_str_t name) {
  CUfunction function;
  // Initialize so fail() reports a sensible error even when no modules
  // are loaded (previously `result` could be read uninitialized).
  CUresult result = CUDA_ERROR_NOT_FOUND;

  // Build a NUL-terminated, sanitized copy of the name.
  std::vector<char> clean(name.len + 1);
  for (seq_int_t i = 0; i < name.len; i++) {
    char c = name.str[i];
    clean[i] = (name_char_valid(c, i == 0) ? c : '_');
  }
  clean[name.len] = '\0';

  for (auto it = modules.rbegin(); it != modules.rend(); ++it) {
    result = cuModuleGetFunction(&function, *it, clean.data());
    if (result == CUDA_SUCCESS) {
      return function;
    } else if (result == CUDA_ERROR_NOT_FOUND) {
      continue; // try the next (older) module
    } else {
      break; // real error: report it below
    }
  }

  fail(result);
  return {};
}
|
||||
|
||||
// Launch kernel `f` with the given grid/block dimensions and dynamic
// shared memory size. Uses the default stream (nullptr) and passes
// arguments via kernelParams; the `extra` launch mechanism is unused.
SEQ_FUNC void seq_nvptx_invoke(CUfunction f, unsigned int gridDimX,
                               unsigned int gridDimY, unsigned int gridDimZ,
                               unsigned int blockDimX, unsigned int blockDimY,
                               unsigned int blockDimZ, unsigned int sharedMemBytes,
                               void **kernelParams) {
  check(cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ,
                       sharedMemBytes, nullptr, kernelParams, nullptr));
}
|
||||
|
||||
// Allocate `size` bytes of device memory. A zero-size request returns a
// null device pointer without touching the driver (cuMemAlloc rejects 0).
SEQ_FUNC CUdeviceptr seq_nvptx_device_alloc(seq_int_t size) {
  if (size == 0)
    return {};

  CUdeviceptr devp;
  check(cuMemAlloc(&devp, size));
  return devp;
}
|
||||
|
||||
// Copy `size` bytes from host memory to device memory.
// Zero-byte copies are skipped to match zero-size allocations.
SEQ_FUNC void seq_nvptx_memcpy_h2d(CUdeviceptr devp, char *hostp, seq_int_t size) {
  if (size)
    check(cuMemcpyHtoD(devp, hostp, size));
}
|
||||
|
||||
// Copy `size` bytes from device memory back to host memory.
// Zero-byte copies are skipped to match zero-size allocations.
SEQ_FUNC void seq_nvptx_memcpy_d2h(char *hostp, CUdeviceptr devp, seq_int_t size) {
  if (size)
    check(cuMemcpyDtoH(hostp, devp, size));
}
|
||||
|
||||
// Free device memory previously returned by seq_nvptx_device_alloc.
// Null pointers (from zero-size allocations) are ignored.
SEQ_FUNC void seq_nvptx_device_free(CUdeviceptr devp) {
  if (devp)
    check(cuMemFree(devp));
}
|
||||
|
||||
#endif /* CODON_GPU */
|
|
@ -38,6 +38,10 @@ extern "C" void __kmpc_set_gc_callbacks(gc_setup_callback get_stack_base,
|
|||
|
||||
void seq_exc_init();
|
||||
|
||||
#ifdef CODON_GPU
|
||||
void seq_nvptx_init();
|
||||
#endif
|
||||
|
||||
int seq_flags;
|
||||
|
||||
SEQ_FUNC void seq_init(int flags) {
|
||||
|
@ -47,6 +51,9 @@ SEQ_FUNC void seq_init(int flags) {
|
|||
__kmpc_set_gc_callbacks(GC_get_stack_base, (gc_setup_callback)GC_register_my_thread,
|
||||
GC_add_roots, GC_remove_roots);
|
||||
seq_exc_init();
|
||||
#ifdef CODON_GPU
|
||||
seq_nvptx_init();
|
||||
#endif
|
||||
seq_flags = flags;
|
||||
}
|
||||
|
||||
|
@ -176,11 +183,11 @@ SEQ_FUNC void *seq_calloc_atomic(size_t m, size_t n) {
|
|||
#endif
|
||||
}
|
||||
|
||||
SEQ_FUNC void *seq_realloc(void *p, size_t n) {
|
||||
SEQ_FUNC void *seq_realloc(void *p, size_t newsize, size_t oldsize) {
|
||||
#if USE_STANDARD_MALLOC
|
||||
return realloc(p, n);
|
||||
return realloc(p, newsize);
|
||||
#else
|
||||
return GC_REALLOC(p, n);
|
||||
return GC_REALLOC(p, newsize);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -231,7 +238,7 @@ static seq_str_t string_conv(const char *fmt, const size_t size, T t) {
|
|||
int n = snprintf(p, size, fmt, t);
|
||||
if (n >= size) {
|
||||
auto n2 = (size_t)n + 1;
|
||||
p = (char *)seq_realloc((void *)p, n2);
|
||||
p = (char *)seq_realloc((void *)p, n2, size);
|
||||
n = snprintf(p, n2, fmt, t);
|
||||
}
|
||||
return {(seq_int_t)n, p};
|
||||
|
|
|
@ -54,7 +54,7 @@ SEQ_FUNC void *seq_alloc(size_t n);
|
|||
SEQ_FUNC void *seq_alloc_atomic(size_t n);
|
||||
SEQ_FUNC void *seq_calloc(size_t m, size_t n);
|
||||
SEQ_FUNC void *seq_calloc_atomic(size_t m, size_t n);
|
||||
SEQ_FUNC void *seq_realloc(void *p, size_t n);
|
||||
SEQ_FUNC void *seq_realloc(void *p, size_t newsize, size_t oldsize);
|
||||
SEQ_FUNC void seq_free(void *p);
|
||||
SEQ_FUNC void seq_register_finalizer(void *p, void (*f)(void *obj, void *data));
|
||||
|
||||
|
|
|
@ -0,0 +1,550 @@
|
|||
#include "gpu.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "codon/util/common.h"
|
||||
#include "llvm/CodeGen/CommandFlags.h"
|
||||
|
||||
namespace codon {
|
||||
namespace ir {
|
||||
namespace {
|
||||
const std::string GPU_TRIPLE = "nvptx64-nvidia-cuda";
|
||||
const std::string GPU_DL =
|
||||
"e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-"
|
||||
"f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64";
|
||||
llvm::cl::opt<std::string>
|
||||
libdevice("libdevice", llvm::cl::desc("libdevice path for GPU kernels"),
|
||||
llvm::cl::init("/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"));
|
||||
|
||||
// Sanitize a symbol name for PTX output: keep [A-Za-z_] anywhere and
// digits after the first character; replace everything else with '_'.
std::string cleanUpName(llvm::StringRef name) {
  auto isOk = [](char c, bool first) {
    if (('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '_')
      return true;
    return !first && ('0' <= c && c <= '9');
  };

  std::string result;
  result.reserve(name.size());
  bool first = true;
  for (char c : name) {
    result.push_back(isOk(c, first) ? c : '_');
    first = false;
  }
  return result;
}
|
||||
|
||||
// Parse NVIDIA's libdevice bitcode from `path` and link it into module M
// so __nv_* math functions resolve. The libdevice module's data layout
// and triple are overwritten to match M before linking; a parse failure
// is a compilation error, a link failure is an internal assertion.
void linkLibdevice(llvm::Module *M, const std::string &path) {
  llvm::SMDiagnostic err;
  auto libdevice = llvm::parseIRFile(path, err, M->getContext());
  if (!libdevice)
    compilationError(err.getMessage().str(), err.getFilename().str(), err.getLineNo(),
                     err.getColumnNo());
  libdevice->setDataLayout(M->getDataLayout());
  libdevice->setTargetTriple(M->getTargetTriple());

  llvm::Linker L(*M);
  const bool fail = L.linkInModule(std::move(libdevice));
  seqassertn(!fail, "linking libdevice failed");
}
|
||||
|
||||
// Create a new private-linkage function declaration in F's module with
// F's exact signature. Uses `name` if non-empty, otherwise reuses F's
// name (LLVM will uniquify it). The body is left empty.
llvm::Function *copyPrototype(llvm::Function *F, const std::string &name) {
  auto *M = F->getParent();
  return llvm::Function::Create(F->getFunctionType(), llvm::GlobalValue::PrivateLinkage,
                                name.empty() ? F->getName() : name, *M);
}
|
||||
|
||||
// Build (or fetch a cached) do-nothing replacement for F: same signature,
// body immediately returns (undef for non-void returns). Used to stub out
// host-runtime functions that have no GPU equivalent. The dummy is cached
// per original-function-name via its ".codon.gpu.dummy." prefix.
llvm::Function *makeNoOp(llvm::Function *F) {
  auto *M = F->getParent();
  auto &context = M->getContext();
  auto dummyName = (".codon.gpu.dummy." + F->getName()).str();
  auto *dummy = M->getFunction(dummyName);
  if (!dummy) {
    dummy = copyPrototype(F, dummyName);
    auto *entry = llvm::BasicBlock::Create(context, "entry", dummy);
    llvm::IRBuilder<> B(entry);

    auto *retType = F->getReturnType();
    if (retType->isVoidTy()) {
      B.CreateRetVoid();
    } else {
      // Callers never rely on stubbed return values, so undef is safe here.
      B.CreateRet(llvm::UndefValue::get(retType));
    }
  }
  return dummy;
}
|
||||
|
||||
using Codegen =
|
||||
std::function<void(llvm::IRBuilder<> &, const std::vector<llvm::Value *> &)>;
|
||||
|
||||
// Build (or fetch a cached) replacement for F whose body is generated by
// `codegen`, which receives an IRBuilder positioned at the entry block and
// the function's arguments. `codegen` is responsible for emitting the
// terminator (ret/unreachable). Cached via the ".codon.gpu.fillin." prefix.
llvm::Function *makeFillIn(llvm::Function *F, Codegen codegen) {
  auto *M = F->getParent();
  auto &context = M->getContext();
  auto fillInName = (".codon.gpu.fillin." + F->getName()).str();
  auto *fillIn = M->getFunction(fillInName);
  if (!fillIn) {
    fillIn = copyPrototype(F, fillInName);
    // Collect the new function's arguments to hand to the generator.
    std::vector<llvm::Value *> args;
    for (auto it = fillIn->arg_begin(); it != fillIn->arg_end(); ++it) {
      args.push_back(it);
    }
    auto *entry = llvm::BasicBlock::Create(context, "entry", fillIn);
    llvm::IRBuilder<> B(entry);
    codegen(B, args);
  }
  return fillIn;
}
|
||||
|
||||
// Get or insert a declaration of `i8* malloc(i64)` in M with attributes
// that let the optimizer treat it as a well-behaved allocator (nounwind,
// noalias return, inaccessiblememonly, willreturn). On the GPU, device
// malloc replaces the host GC allocator.
llvm::Function *makeMalloc(llvm::Module *M) {
  auto &context = M->getContext();
  auto F = M->getOrInsertFunction("malloc", llvm::Type::getInt8PtrTy(context),
                                  llvm::Type::getInt64Ty(context));
  auto *G = llvm::cast<llvm::Function>(F.getCallee());
  G->setLinkage(llvm::GlobalValue::ExternalLinkage);
  G->setDoesNotThrow();
  G->setReturnDoesNotAlias();
  G->setOnlyAccessesInaccessibleMemory();
  G->setWillReturn();
  return G;
}
|
||||
|
||||
/// Rewrite host-runtime and libm/intrinsic references in the GPU module:
///  - name-to-name remaps redirect math intrinsics/functions to their
///    libdevice (__nv_*) equivalents; an empty target name means the
///    function is replaced with a no-op stub;
///  - "fill-in" functions are regenerated with device-appropriate bodies
///    (GC allocators -> device malloc, exceptions -> unreachable).
/// Each remapped/filled-in original is erased from the module afterwards.
void remapFunctions(llvm::Module *M) {
  // simple name-to-name remappings
  // (fix: dropped the duplicate llvm.copysign.f64 / llvm.copysign.f32
  // entries — the second copy of each could never match since the first
  // erases the function from the module)
  static const std::vector<std::pair<std::string, std::string>> remapping = {
      // 64-bit float intrinsics
      {"llvm.ceil.f64", "__nv_ceil"},
      {"llvm.floor.f64", "__nv_floor"},
      {"llvm.fabs.f64", "__nv_fabs"},
      {"llvm.exp.f64", "__nv_exp"},
      {"llvm.log.f64", "__nv_log"},
      {"llvm.log2.f64", "__nv_log2"},
      {"llvm.log10.f64", "__nv_log10"},
      {"llvm.sqrt.f64", "__nv_sqrt"},
      {"llvm.pow.f64", "__nv_pow"},
      {"llvm.sin.f64", "__nv_sin"},
      {"llvm.cos.f64", "__nv_cos"},
      {"llvm.copysign.f64", "__nv_copysign"},
      {"llvm.trunc.f64", "__nv_trunc"},
      {"llvm.rint.f64", "__nv_rint"},
      {"llvm.nearbyint.f64", "__nv_nearbyint"},
      {"llvm.round.f64", "__nv_round"},
      {"llvm.minnum.f64", "__nv_fmin"},
      {"llvm.maxnum.f64", "__nv_fmax"},
      {"llvm.fma.f64", "__nv_fma"},

      // 64-bit float math functions
      {"expm1", "__nv_expm1"},
      {"ldexp", "__nv_ldexp"},
      {"acos", "__nv_acos"},
      {"asin", "__nv_asin"},
      {"atan", "__nv_atan"},
      {"atan2", "__nv_atan2"},
      {"hypot", "__nv_hypot"},
      {"tan", "__nv_tan"},
      {"cosh", "__nv_cosh"},
      {"sinh", "__nv_sinh"},
      {"tanh", "__nv_tanh"},
      {"acosh", "__nv_acosh"},
      {"asinh", "__nv_asinh"},
      {"atanh", "__nv_atanh"},
      {"erf", "__nv_erf"},
      {"erfc", "__nv_erfc"},
      {"tgamma", "__nv_tgamma"},
      {"lgamma", "__nv_lgamma"},
      {"remainder", "__nv_remainder"},
      {"frexp", "__nv_frexp"},
      {"modf", "__nv_modf"},

      // 32-bit float intrinsics
      {"llvm.ceil.f32", "__nv_ceilf"},
      {"llvm.floor.f32", "__nv_floorf"},
      {"llvm.fabs.f32", "__nv_fabsf"},
      {"llvm.exp.f32", "__nv_expf"},
      {"llvm.log.f32", "__nv_logf"},
      {"llvm.log2.f32", "__nv_log2f"},
      {"llvm.log10.f32", "__nv_log10f"},
      {"llvm.sqrt.f32", "__nv_sqrtf"},
      {"llvm.pow.f32", "__nv_powf"},
      {"llvm.sin.f32", "__nv_sinf"},
      {"llvm.cos.f32", "__nv_cosf"},
      {"llvm.copysign.f32", "__nv_copysignf"},
      {"llvm.trunc.f32", "__nv_truncf"},
      {"llvm.rint.f32", "__nv_rintf"},
      {"llvm.nearbyint.f32", "__nv_nearbyintf"},
      {"llvm.round.f32", "__nv_roundf"},
      {"llvm.minnum.f32", "__nv_fminf"},
      {"llvm.maxnum.f32", "__nv_fmaxf"},
      {"llvm.fma.f32", "__nv_fmaf"},

      // 32-bit float math functions
      {"expm1f", "__nv_expm1f"},
      {"ldexpf", "__nv_ldexpf"},
      {"acosf", "__nv_acosf"},
      {"asinf", "__nv_asinf"},
      {"atanf", "__nv_atanf"},
      {"atan2f", "__nv_atan2f"},
      {"hypotf", "__nv_hypotf"},
      {"tanf", "__nv_tanf"},
      {"coshf", "__nv_coshf"},
      {"sinhf", "__nv_sinhf"},
      {"tanhf", "__nv_tanhf"},
      {"acoshf", "__nv_acoshf"},
      {"asinhf", "__nv_asinhf"},
      {"atanhf", "__nv_atanhf"},
      {"erff", "__nv_erff"},
      {"erfcf", "__nv_erfcf"},
      {"tgammaf", "__nv_tgammaf"},
      {"lgammaf", "__nv_lgammaf"},
      {"remainderf", "__nv_remainderf"},
      {"frexpf", "__nv_frexpf"},
      {"modff", "__nv_modff"},

      // runtime library functions (empty target -> no-op stub)
      {"seq_free", "free"},
      {"seq_register_finalizer", ""},
      {"seq_gc_add_roots", ""},
      {"seq_gc_remove_roots", ""},
      {"seq_gc_clear_roots", ""},
      {"seq_gc_exclude_static_roots", ""},
  };

  // functions that need to be generated as they're not available on GPU
  static const std::vector<std::pair<std::string, Codegen>> fillins = {
      // GC allocators map to plain device malloc; atomic (pointer-free)
      // allocation has no special meaning without a GC.
      {"seq_alloc",
       [](llvm::IRBuilder<> &B, const std::vector<llvm::Value *> &args) {
         auto *M = B.GetInsertBlock()->getModule();
         llvm::Value *mem = B.CreateCall(makeMalloc(M), args[0]);
         B.CreateRet(mem);
       }},

      {"seq_alloc_atomic",
       [](llvm::IRBuilder<> &B, const std::vector<llvm::Value *> &args) {
         auto *M = B.GetInsertBlock()->getModule();
         llvm::Value *mem = B.CreateCall(makeMalloc(M), args[0]);
         B.CreateRet(mem);
       }},

      // realloc(p, new, old) -> malloc(new) + memcpy(old bytes); the old
      // block is intentionally leaked (no device free tracking here).
      {"seq_realloc",
       [](llvm::IRBuilder<> &B, const std::vector<llvm::Value *> &args) {
         auto *M = B.GetInsertBlock()->getModule();
         llvm::Value *mem = B.CreateCall(makeMalloc(M), args[1]);
         auto F = llvm::Intrinsic::getDeclaration(
             M, llvm::Intrinsic::memcpy,
             {B.getInt8PtrTy(), B.getInt8PtrTy(), B.getInt64Ty()});
         B.CreateCall(F, {mem, args[0], args[2], B.getFalse()});
         B.CreateRet(mem);
       }},

      // calloc(m, n) -> malloc(m*n) + memset(0).
      {"seq_calloc",
       [](llvm::IRBuilder<> &B, const std::vector<llvm::Value *> &args) {
         auto *M = B.GetInsertBlock()->getModule();
         llvm::Value *size = B.CreateMul(args[0], args[1]);
         llvm::Value *mem = B.CreateCall(makeMalloc(M), size);
         auto F = llvm::Intrinsic::getDeclaration(M, llvm::Intrinsic::memset,
                                                  {B.getInt8PtrTy(), B.getInt64Ty()});
         B.CreateCall(F, {mem, B.getInt8(0), size, B.getFalse()});
         B.CreateRet(mem);
       }},

      {"seq_calloc_atomic",
       [](llvm::IRBuilder<> &B, const std::vector<llvm::Value *> &args) {
         auto *M = B.GetInsertBlock()->getModule();
         llvm::Value *size = B.CreateMul(args[0], args[1]);
         llvm::Value *mem = B.CreateCall(makeMalloc(M), size);
         auto F = llvm::Intrinsic::getDeclaration(M, llvm::Intrinsic::memset,
                                                  {B.getInt8PtrTy(), B.getInt64Ty()});
         B.CreateCall(F, {mem, B.getInt8(0), size, B.getFalse()});
         B.CreateRet(mem);
       }},

      // Exceptions cannot propagate on the GPU: trap instead.
      {"seq_alloc_exc",
       [](llvm::IRBuilder<> &B, const std::vector<llvm::Value *> &args) {
         // TODO: print error message and abort if in debug mode
         B.CreateUnreachable();
       }},

      {"seq_throw",
       [](llvm::IRBuilder<> &B,
          const std::vector<llvm::Value *> &args) { B.CreateUnreachable(); }},
  };

  for (auto &pair : remapping) {
    if (auto *F = M->getFunction(pair.first)) {
      llvm::Function *G = nullptr;
      if (pair.second.empty()) {
        G = makeNoOp(F);
      } else {
        G = M->getFunction(pair.second);
        if (!G)
          G = copyPrototype(F, pair.second);
      }

      G->setWillReturn();
      F->replaceAllUsesWith(G);
      F->dropAllReferences();
      F->eraseFromParent();
    }
  }

  for (auto &pair : fillins) {
    if (auto *F = M->getFunction(pair.first)) {
      llvm::Function *G = makeFillIn(F, pair.second);
      F->replaceAllUsesWith(G);
      F->dropAllReferences();
      F->eraseFromParent();
    }
  }
}
|
||||
|
||||
// Recursively add G and every global value reachable from it to `keep`.
// For functions this walks all instruction operands; the `keep` set
// doubles as the visited set so cycles terminate.
// NOTE(review): only direct instruction operands are followed — globals
// reachable only through constant expressions / initializers of other
// globals would not be visited; confirm that is intentional.
void exploreGV(llvm::GlobalValue *G, llvm::SmallPtrSetImpl<llvm::GlobalValue *> &keep) {
  if (keep.contains(G))
    return;

  keep.insert(G);
  if (auto *F = llvm::dyn_cast<llvm::Function>(G)) {
    for (auto I = llvm::inst_begin(F), E = inst_end(F); I != E; ++I) {
      for (auto &U : I->operands()) {
        if (auto *G2 = llvm::dyn_cast<llvm::GlobalValue>(U.get()))
          exploreGV(G2, keep);
      }
    }
  }
}
|
||||
|
||||
std::vector<llvm::GlobalValue *>
|
||||
getRequiredGVs(const std::vector<llvm::GlobalValue *> &kernels) {
|
||||
llvm::SmallPtrSet<llvm::GlobalValue *, 32> keep;
|
||||
for (auto *G : kernels) {
|
||||
exploreGV(G, keep);
|
||||
}
|
||||
return std::vector<llvm::GlobalValue *>(keep.begin(), keep.end());
|
||||
}
|
||||
|
||||
/// Compile module M to a PTX assembly file.
/// Pipeline: prune to kernel-reachable globals -> link libdevice and remap
/// runtime/libm calls -> NVPTX + general -O3 optimization -> prune again ->
/// sanitize symbol names -> emit assembly.
/// @param M module to compile (mutated heavily in the process)
/// @param filename output path for the PTX file
/// @param kernels kernel functions that must be preserved
/// @param cpuStr target GPU architecture (default sm_30)
/// @param featuresStr target features (default PTX ISA 4.2)
void moduleToPTX(llvm::Module *M, const std::string &filename,
                 std::vector<llvm::GlobalValue *> &kernels,
                 const std::string &cpuStr = "sm_30",
                 const std::string &featuresStr = "+ptx42") {
  llvm::Triple triple(llvm::Triple::normalize(GPU_TRIPLE));
  llvm::TargetLibraryInfoImpl tlii(triple);

  std::string err;
  const llvm::Target *target =
      llvm::TargetRegistry::lookupTarget("nvptx64", triple, err);
  seqassertn(target, "couldn't lookup target: {}", err);

  const llvm::TargetOptions options =
      llvm::codegen::InitTargetOptionsFromCodeGenFlags(triple);

  std::unique_ptr<llvm::TargetMachine> machine(target->createTargetMachine(
      triple.getTriple(), cpuStr, featuresStr, options,
      llvm::codegen::getExplicitRelocModel(), llvm::codegen::getExplicitCodeModel(),
      llvm::CodeGenOpt::Aggressive));

  M->setDataLayout(machine->createDataLayout());
  auto keep = getRequiredGVs(kernels);

  // Strip everything the kernels don't need; used once before and once
  // after optimization.
  auto prune = [&](std::vector<llvm::GlobalValue *> keep) {
    auto pm = std::make_unique<llvm::legacy::PassManager>();
    pm->add(new llvm::TargetLibraryInfoWrapperPass(tlii));
    // Delete everything but kernel functions.
    pm->add(llvm::createGVExtractionPass(keep));
    // Delete unreachable globals.
    pm->add(llvm::createGlobalDCEPass());
    // Remove dead debug info.
    pm->add(llvm::createStripDeadDebugInfoPass());
    // Remove dead func decls.
    pm->add(llvm::createStripDeadPrototypesPass());
    pm->run(*M);
  };

  // Remove non-kernel functions.
  prune(keep);

  // Link libdevice and other cleanup.
  linkLibdevice(M, libdevice);
  remapFunctions(M);

  // Run NVPTX passes and general opt pipeline.
  {
    auto pm = std::make_unique<llvm::legacy::PassManager>();
    auto fpm = std::make_unique<llvm::legacy::FunctionPassManager>(M);
    pm->add(new llvm::TargetLibraryInfoWrapperPass(tlii));

    pm->add(llvm::createTargetTransformInfoWrapperPass(
        machine ? machine->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));
    fpm->add(llvm::createTargetTransformInfoWrapperPass(
        machine ? machine->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));

    if (machine) {
      // Fixed mangled declaration (`auto <m = ...`): bind the target
      // machine by reference to install its codegen pass configuration.
      auto &ltm = dynamic_cast<llvm::LLVMTargetMachine &>(*machine);
      llvm::Pass *tpc = ltm.createPassConfig(*pm);
      pm->add(tpc);
    }

    // Internalize everything except the kernels so the optimizer can
    // inline and discard aggressively.
    pm->add(llvm::createInternalizePass([&](const llvm::GlobalValue &gv) {
      return std::find(keep.begin(), keep.end(), &gv) != keep.end();
    }));

    llvm::PassManagerBuilder pmb;
    unsigned optLevel = 3, sizeLevel = 0;
    pmb.OptLevel = optLevel;
    pmb.SizeLevel = sizeLevel;
    pmb.Inliner = llvm::createFunctionInliningPass(optLevel, sizeLevel, false);
    pmb.DisableUnrollLoops = false;
    pmb.LoopVectorize = true;
    pmb.SLPVectorize = true;

    if (machine) {
      machine->adjustPassManager(pmb);
    }

    pmb.populateModulePassManager(*pm);
    pmb.populateFunctionPassManager(*fpm);

    fpm->doInitialization();
    for (llvm::Function &f : *M) {
      fpm->run(f);
    }
    fpm->doFinalization();
    pm->run(*M);
  }

  // Prune again after optimizations.
  keep = getRequiredGVs(kernels);
  prune(keep);

  // Clean up names so every surviving symbol is PTX-friendly.
  {
    for (auto &G : M->globals()) {
      G.setName(cleanUpName(G.getName()));
    }

    for (auto &F : M->functions()) {
      // Only rename definitions; declarations (e.g. __nv_* / intrinsics)
      // must keep their exact names to resolve.
      if (F.getInstructionCount() > 0)
        F.setName(cleanUpName(F.getName()));
    }

    for (auto *S : M->getIdentifiedStructTypes()) {
      S->setName(cleanUpName(S->getName()));
    }
  }

  // Generate PTX file.
  {
    std::error_code errcode;
    auto out = std::make_unique<llvm::ToolOutputFile>(filename, errcode,
                                                      llvm::sys::fs::OF_Text);
    if (errcode)
      compilationError(errcode.message());
    llvm::raw_pwrite_stream *os = &out->os();

    auto &llvmtm = static_cast<llvm::LLVMTargetMachine &>(*machine);
    auto *mmiwp = new llvm::MachineModuleInfoWrapperPass(&llvmtm);
    llvm::legacy::PassManager pm;

    pm.add(new llvm::TargetLibraryInfoWrapperPass(tlii));
    seqassertn(!machine->addPassesToEmitFile(pm, *os, nullptr, llvm::CGFT_AssemblyFile,
                                             /*DisableVerify=*/false, mmiwp),
               "could not add passes");
    const_cast<llvm::TargetLoweringObjectFile *>(llvmtm.getObjFileLowering())
        ->Initialize(mmiwp->getMMI().getContext(), *machine);
    pm.run(*M);
    out->keep();
  }
}
|
||||
|
||||
// Insert calls to seq_nvptx_load_module(<ptx filename>) into the host
// module so the generated PTX is loaded at startup: immediately after
// the single seq_init() call, and at the entry of every "jit"-attributed
// function (JIT entry points have no seq_init of their own).
void addInitCall(llvm::Module *M, const std::string &filename) {
  llvm::LLVMContext &context = M->getContext();
  // void seq_nvptx_load_module(i8* filename)
  auto f =
      M->getOrInsertFunction("seq_nvptx_load_module", llvm::Type::getVoidTy(context),
                             llvm::Type::getInt8PtrTy(context));
  auto *g = llvm::cast<llvm::Function>(f.getCallee());
  g->setDoesNotThrow();

  // Private constant holding the NUL-terminated PTX path.
  auto *filenameVar = new llvm::GlobalVariable(
      *M, llvm::ArrayType::get(llvm::Type::getInt8Ty(context), filename.length() + 1),
      /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
      llvm::ConstantDataArray::getString(context, filename), ".nvptx.filename");
  filenameVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
  llvm::IRBuilder<> B(context);

  if (auto *init = M->getFunction("seq_init")) {
    seqassertn(init->hasOneUse(), "seq_init used more than once");
    auto *use = llvm::dyn_cast<llvm::CallBase>(init->use_begin()->getUser());
    seqassertn(use, "seq_init use was not a call");
    // Load the PTX module right after runtime initialization completes.
    B.SetInsertPoint(use->getNextNode());
    B.CreateCall(g, B.CreateBitCast(filenameVar, B.getInt8PtrTy()));
  }

  for (auto &F : M->functions()) {
    if (F.hasFnAttribute("jit")) {
      B.SetInsertPoint(F.getEntryBlock().getFirstNonPHI());
      B.CreateCall(g, B.CreateBitCast(filenameVar, B.getInt8PtrTy()));
    }
  }
}
|
||||
|
||||
// Replace all llvm.nvvm.* intrinsic declarations in the HOST module with
// no-op stubs and erase them: these intrinsics only exist for the GPU
// clone, and leaving them in the host module would fail host codegen.
void cleanUpIntrinsics(llvm::Module *M) {
  llvm::LLVMContext &context = M->getContext();
  // Collect first, then erase — erasing while iterating *M is invalid.
  llvm::SmallVector<llvm::Function *, 16> remove;
  for (auto &F : *M) {
    if (F.getIntrinsicID() != llvm::Intrinsic::not_intrinsic &&
        F.getName().startswith("llvm.nvvm"))
      remove.push_back(&F);
  }

  for (auto *F : remove) {
    F->replaceAllUsesWith(makeNoOp(F));
    F->dropAllReferences();
    F->eraseFromParent();
  }
}
|
||||
} // namespace
|
||||
|
||||
// Entry point of the GPU pipeline: clone the host module, retarget the
// clone to NVPTX, tag all "kernel"-attributed functions with nvvm.annotations
// metadata, compile the clone to a PTX file, then patch the host module
// (strip nvvm intrinsics, add a startup call that loads the PTX).
// No-op if the module contains no kernels.
void applyGPUTransformations(llvm::Module *M, const std::string &ptxFilename) {
  llvm::LLVMContext &context = M->getContext();
  std::unique_ptr<llvm::Module> clone = llvm::CloneModule(*M);
  clone->setTargetTriple(llvm::Triple::normalize(GPU_TRIPLE));
  clone->setDataLayout(GPU_DL);

  llvm::NamedMDNode *nvvmAnno = clone->getOrInsertNamedMetadata("nvvm.annotations");
  std::vector<llvm::GlobalValue *> kernels;

  for (auto &F : *clone) {
    if (!F.hasFnAttribute("kernel"))
      continue;

    // NVPTX backend requires the {fn, "kernel", 1} annotation to emit
    // an .entry (kernel) rather than a device function.
    llvm::Metadata *nvvmElem[] = {
        llvm::ConstantAsMetadata::get(&F),
        llvm::MDString::get(context, "kernel"),
        llvm::ConstantAsMetadata::get(
            llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 1)),
    };

    nvvmAnno->addOperand(llvm::MDNode::get(context, nvvmElem));
    kernels.push_back(&F);
  }

  if (kernels.empty())
    return;

  // Derive the PTX path: explicit name, else source filename with a .ptx
  // extension; "<stdin>"-style pseudo-names fall back to "kernel.ptx".
  std::string filename = ptxFilename.empty() ? M->getSourceFileName() : ptxFilename;
  if (filename.empty() || filename[0] == '<')
    filename = "kernel";
  llvm::SmallString<128> path(filename);
  llvm::sys::path::replace_extension(path, "ptx");
  filename = path.str();

  moduleToPTX(clone.get(), filename, kernels);
  cleanUpIntrinsics(M);
  addInitCall(M, filename);
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace codon
|
|
@ -0,0 +1,19 @@
|
|||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "codon/sir/llvm/llvm.h"
|
||||
|
||||
namespace codon {
|
||||
namespace ir {
|
||||
|
||||
/// Applies GPU-specific transformations and generates PTX
|
||||
/// code from kernel functions in the given LLVM module.
|
||||
/// @param module LLVM module containing GPU kernel functions (marked with "kernel"
|
||||
/// annotation)
|
||||
/// @param ptxFilename Filename for output PTX code; empty to use filename based on
|
||||
/// module
|
||||
void applyGPUTransformations(llvm::Module *module, const std::string &ptxFilename = "");
|
||||
|
||||
} // namespace ir
|
||||
} // namespace codon
|
|
@ -22,6 +22,7 @@ namespace {
|
|||
const std::string EXPORT_ATTR = "std.internal.attributes.export";
|
||||
const std::string INLINE_ATTR = "std.internal.attributes.inline";
|
||||
const std::string NOINLINE_ATTR = "std.internal.attributes.noinline";
|
||||
const std::string GPU_KERNEL_ATTR = "std.gpu.kernel";
|
||||
} // namespace
|
||||
|
||||
llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
|
||||
|
@ -41,6 +42,8 @@ llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
|
|||
std::string LLVMVisitor::getNameForFunction(const Func *x) {
|
||||
if (isA<ExternalFunc>(x) || util::hasAttribute(x, EXPORT_ATTR)) {
|
||||
return x->getUnmangledName();
|
||||
} else if (util::hasAttribute(x, GPU_KERNEL_ATTR)) {
|
||||
return x->getName();
|
||||
} else {
|
||||
return x->referenceString();
|
||||
}
|
||||
|
@ -49,7 +52,7 @@ std::string LLVMVisitor::getNameForFunction(const Func *x) {
|
|||
LLVMVisitor::LLVMVisitor()
|
||||
: util::ConstVisitor(), context(std::make_unique<llvm::LLVMContext>()), M(),
|
||||
B(std::make_unique<llvm::IRBuilder<>>(*context)), func(nullptr), block(nullptr),
|
||||
value(nullptr), vars(), funcs(), coro(), loops(), trycatch(), db(),
|
||||
value(nullptr), vars(), funcs(), coro(), loops(), trycatch(), catches(), db(),
|
||||
plugins(nullptr) {
|
||||
llvm::InitializeAllTargets();
|
||||
llvm::InitializeAllTargetMCs();
|
||||
|
@ -287,6 +290,7 @@ LLVMVisitor::takeModule(Module *module, const SrcInfo *src) {
|
|||
coro.reset();
|
||||
loops.clear();
|
||||
trycatch.clear();
|
||||
catches.clear();
|
||||
db.reset();
|
||||
|
||||
context = std::make_unique<llvm::LLVMContext>();
|
||||
|
@ -697,6 +701,16 @@ void LLVMVisitor::exitTryCatch() {
|
|||
trycatch.pop_back();
|
||||
}
|
||||
|
||||
void LLVMVisitor::enterCatch(CatchData data) {
|
||||
catches.push_back(std::move(data));
|
||||
catches.back().sequenceNumber = nextSequenceNumber++;
|
||||
}
|
||||
|
||||
void LLVMVisitor::exitCatch() {
|
||||
seqassertn(!catches.empty(), "no catches present");
|
||||
catches.pop_back();
|
||||
}
|
||||
|
||||
LLVMVisitor::TryCatchData *LLVMVisitor::getInnermostTryCatch() {
|
||||
return trycatch.empty() ? nullptr : &trycatch.back();
|
||||
}
|
||||
|
@ -1119,7 +1133,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) {
|
|||
err.print("LLVM", buf);
|
||||
// LOG("-> ERR {}", x->referenceString());
|
||||
// LOG(" {}", code);
|
||||
compilationError(buf.str());
|
||||
compilationError(fmt::format("{} ({})", buf.str(), x->getName()));
|
||||
}
|
||||
sub->setDataLayout(M->getDataLayout());
|
||||
|
||||
|
@ -1152,6 +1166,9 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
|
|||
setDebugInfoForNode(x);
|
||||
|
||||
auto *fnAttributes = x->getAttribute<KeyValueAttribute>();
|
||||
if (x->isJIT()) {
|
||||
func->addFnAttr(llvm::Attribute::get(*context, "jit"));
|
||||
}
|
||||
if (x->isJIT() || (fnAttributes && fnAttributes->has(EXPORT_ATTR))) {
|
||||
func->setLinkage(llvm::GlobalValue::ExternalLinkage);
|
||||
} else {
|
||||
|
@ -1163,6 +1180,11 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
|
|||
if (fnAttributes && fnAttributes->has(NOINLINE_ATTR)) {
|
||||
func->addFnAttr(llvm::Attribute::AttrKind::NoInline);
|
||||
}
|
||||
if (fnAttributes && fnAttributes->has(GPU_KERNEL_ATTR)) {
|
||||
func->addFnAttr(llvm::Attribute::AttrKind::NoInline);
|
||||
func->addFnAttr(llvm::Attribute::get(*context, "kernel"));
|
||||
func->setLinkage(llvm::GlobalValue::ExternalLinkage);
|
||||
}
|
||||
func->setPersonalityFn(llvm::cast<llvm::Constant>(makePersonalityFunc().getCallee()));
|
||||
|
||||
auto *funcType = cast<types::FuncType>(x->getType());
|
||||
|
@ -1352,6 +1374,10 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) {
|
|||
return B->getDoubleTy();
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::Float32Type>(t)) {
|
||||
return B->getFloatTy();
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::BoolType>(t)) {
|
||||
return B->getInt8Ty();
|
||||
}
|
||||
|
@ -1406,6 +1432,11 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) {
|
|||
return B->getIntNTy(x->getLen());
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::VectorType>(t)) {
|
||||
return llvm::VectorType::get(getLLVMType(x->getBase()), x->getCount(),
|
||||
/*Scalable=*/false);
|
||||
}
|
||||
|
||||
if (auto *x = cast<dsl::types::CustomType>(t)) {
|
||||
return x->getBuilder()->buildType(this);
|
||||
}
|
||||
|
@ -1429,6 +1460,11 @@ llvm::DIType *LLVMVisitor::getDITypeHelper(
|
|||
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::Float32Type>(t)) {
|
||||
return db.builder->createBasicType(
|
||||
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::BoolType>(t)) {
|
||||
return db.builder->createBasicType(
|
||||
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_boolean);
|
||||
|
@ -1554,6 +1590,12 @@ llvm::DIType *LLVMVisitor::getDITypeHelper(
|
|||
x->isSigned() ? llvm::dwarf::DW_ATE_signed : llvm::dwarf::DW_ATE_unsigned);
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::VectorType>(t)) {
|
||||
return db.builder->createBasicType(x->getName(),
|
||||
layout.getTypeAllocSizeInBits(type),
|
||||
llvm::dwarf::DW_ATE_unsigned);
|
||||
}
|
||||
|
||||
if (auto *x = cast<dsl::types::CustomType>(t)) {
|
||||
return x->getBuilder()->buildDebugType(this);
|
||||
}
|
||||
|
@ -2094,7 +2136,12 @@ void LLVMVisitor::visit(const TryCatchFlow *x) {
|
|||
}
|
||||
|
||||
B->CreateStore(excStateCaught, tc.excFlag);
|
||||
CatchData cd;
|
||||
cd.exception = objPtr;
|
||||
cd.typeId = objType;
|
||||
enterCatch(cd);
|
||||
process(catches[i]->getHandler());
|
||||
exitCatch();
|
||||
B->SetInsertPoint(block);
|
||||
B->CreateBr(tc.finallyBlock);
|
||||
}
|
||||
|
@ -2455,10 +2502,21 @@ void LLVMVisitor::visit(const ThrowInstr *x) {
|
|||
// note: exception header should be set in the frontend
|
||||
auto excAllocFunc = makeExcAllocFunc();
|
||||
auto throwFunc = makeThrowFunc();
|
||||
process(x->getValue());
|
||||
llvm::Value *obj = nullptr;
|
||||
llvm::Value *typ = nullptr;
|
||||
|
||||
if (x->getValue()) {
|
||||
process(x->getValue());
|
||||
obj = value;
|
||||
typ = B->getInt32(getTypeIdx(x->getValue()->getType()));
|
||||
} else {
|
||||
seqassertn(!catches.empty(), "empty raise outside of except block");
|
||||
obj = catches.back().exception;
|
||||
typ = catches.back().typeId;
|
||||
}
|
||||
|
||||
B->SetInsertPoint(block);
|
||||
llvm::Value *exc = B->CreateCall(
|
||||
excAllocFunc, {B->getInt32(getTypeIdx(x->getValue()->getType())), value});
|
||||
llvm::Value *exc = B->CreateCall(excAllocFunc, {typ, obj});
|
||||
call(throwFunc, exc);
|
||||
}
|
||||
|
||||
|
|
|
@ -89,6 +89,11 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
struct CatchData : NestableData {
|
||||
llvm::Value *exception;
|
||||
llvm::Value *typeId;
|
||||
};
|
||||
|
||||
struct DebugInfo {
|
||||
/// LLVM debug info builder
|
||||
std::unique_ptr<llvm::DIBuilder> builder;
|
||||
|
@ -139,6 +144,8 @@ private:
|
|||
std::vector<LoopData> loops;
|
||||
/// Try-catch data stack
|
||||
std::vector<TryCatchData> trycatch;
|
||||
/// Catch-block data stack
|
||||
std::vector<CatchData> catches;
|
||||
/// Debug information
|
||||
DebugInfo db;
|
||||
/// Plugin manager
|
||||
|
@ -181,6 +188,8 @@ private:
|
|||
void exitLoop();
|
||||
void enterTryCatch(TryCatchData data);
|
||||
void exitTryCatch();
|
||||
void enterCatch(CatchData data);
|
||||
void exitCatch();
|
||||
TryCatchData *getInnermostTryCatch();
|
||||
TryCatchData *getInnermostTryCatchBeforeLoop();
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "llvm/IR/DerivedTypes.h"
|
||||
#include "llvm/IR/Function.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/InstIterator.h"
|
||||
#include "llvm/IR/InstrTypes.h"
|
||||
#include "llvm/IR/Instructions.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include <algorithm>
|
||||
|
||||
#include "codon/sir/llvm/coro/Coroutines.h"
|
||||
#include "codon/sir/llvm/gpu.h"
|
||||
#include "codon/util/common.h"
|
||||
#include "llvm/CodeGen/CommandFlags.h"
|
||||
|
||||
|
@ -642,13 +643,17 @@ void verify(llvm::Module *module) {
|
|||
void optimize(llvm::Module *module, bool debug, bool jit, PluginManager *plugins) {
|
||||
verify(module);
|
||||
{
|
||||
TIME("llvm/opt");
|
||||
TIME("llvm/opt1");
|
||||
runLLVMOptimizationPasses(module, debug, jit, plugins);
|
||||
}
|
||||
if (!debug) {
|
||||
TIME("llvm/opt2");
|
||||
runLLVMOptimizationPasses(module, debug, jit, plugins);
|
||||
}
|
||||
{
|
||||
TIME("llvm/gpu");
|
||||
applyGPUTransformations(module);
|
||||
}
|
||||
verify(module);
|
||||
}
|
||||
|
||||
|
|
|
@ -57,6 +57,7 @@ const std::string Module::BOOL_NAME = "bool";
|
|||
const std::string Module::BYTE_NAME = "byte";
|
||||
const std::string Module::INT_NAME = "int";
|
||||
const std::string Module::FLOAT_NAME = "float";
|
||||
const std::string Module::FLOAT32_NAME = "float32";
|
||||
const std::string Module::STRING_NAME = "str";
|
||||
|
||||
const std::string Module::EQ_MAGIC_NAME = "__eq__";
|
||||
|
@ -126,7 +127,9 @@ Func *Module::getOrRealizeMethod(types::Type *parent, const std::string &methodN
|
|||
return cache->realizeFunction(method, translateArgs(args),
|
||||
translateGenerics(generics), cls);
|
||||
} catch (const exc::ParserException &e) {
|
||||
LOG_IR("getOrRealizeMethod parser error: {}", e.what());
|
||||
for (int i = 0; i < e.messages.size(); i++)
|
||||
LOG_IR("getOrRealizeMethod parser error at {}: {}", e.locations[i],
|
||||
e.messages[i]);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
@ -162,7 +165,8 @@ types::Type *Module::getOrRealizeType(const std::string &typeName,
|
|||
try {
|
||||
return cache->realizeType(type, translateGenerics(generics));
|
||||
} catch (const exc::ParserException &e) {
|
||||
LOG_IR("getOrRealizeType parser error: {}", e.what());
|
||||
for (int i = 0; i < e.messages.size(); i++)
|
||||
LOG_IR("getOrRealizeType parser error at {}: {}", e.locations[i], e.messages[i]);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
@ -197,6 +201,12 @@ types::Type *Module::getFloatType() {
|
|||
return Nr<types::FloatType>();
|
||||
}
|
||||
|
||||
types::Type *Module::getFloat32Type() {
|
||||
if (auto *rVal = getType(FLOAT32_NAME))
|
||||
return rVal;
|
||||
return Nr<types::Float32Type>();
|
||||
}
|
||||
|
||||
types::Type *Module::getStringType() {
|
||||
if (auto *rVal = getType(STRING_NAME))
|
||||
return rVal;
|
||||
|
@ -243,6 +253,10 @@ types::Type *Module::getIntNType(unsigned int len, bool sign) {
|
|||
return getOrRealizeType(sign ? "Int" : "UInt", {len});
|
||||
}
|
||||
|
||||
types::Type *Module::getVectorType(unsigned count, types::Type *base) {
|
||||
return getOrRealizeType("Vec", {base, count});
|
||||
}
|
||||
|
||||
types::Type *Module::getTupleType(std::vector<types::Type *> args) {
|
||||
std::vector<ast::types::TypePtr> argTypes;
|
||||
for (auto *t : args) {
|
||||
|
@ -332,5 +346,14 @@ types::Type *Module::unsafeGetIntNType(unsigned int len, bool sign) {
|
|||
return Nr<types::IntNType>(len, sign);
|
||||
}
|
||||
|
||||
types::Type *Module::unsafeGetVectorType(unsigned int count, types::Type *base) {
|
||||
auto *primitive = cast<types::PrimitiveType>(base);
|
||||
auto name = types::VectorType::getInstanceName(count, primitive);
|
||||
if (auto *rVal = getType(name))
|
||||
return rVal;
|
||||
seqassertn(primitive, "base type must be a primitive type");
|
||||
return Nr<types::VectorType>(count, primitive);
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace codon
|
||||
|
|
|
@ -31,6 +31,7 @@ public:
|
|||
static const std::string BYTE_NAME;
|
||||
static const std::string INT_NAME;
|
||||
static const std::string FLOAT_NAME;
|
||||
static const std::string FLOAT32_NAME;
|
||||
static const std::string STRING_NAME;
|
||||
|
||||
static const std::string EQ_MAGIC_NAME;
|
||||
|
@ -305,6 +306,8 @@ public:
|
|||
types::Type *getIntType();
|
||||
/// @return the float type
|
||||
types::Type *getFloatType();
|
||||
/// @return the float32 type
|
||||
types::Type *getFloat32Type();
|
||||
/// @return the string type
|
||||
types::Type *getStringType();
|
||||
/// Gets a pointer type.
|
||||
|
@ -335,6 +338,11 @@ public:
|
|||
/// @param sign true if signed
|
||||
/// @return a variable length integer type
|
||||
types::Type *getIntNType(unsigned len, bool sign);
|
||||
/// Gets a vector type.
|
||||
/// @param count the vector size
|
||||
/// @param base the vector base type (MUST be a primitive type)
|
||||
/// @return a vector type
|
||||
types::Type *getVectorType(unsigned count, types::Type *base);
|
||||
/// Gets a tuple type.
|
||||
/// @param args the arg types
|
||||
/// @return the tuple type
|
||||
|
@ -401,6 +409,12 @@ public:
|
|||
/// @param sign true if signed
|
||||
/// @return a variable length integer type
|
||||
types::Type *unsafeGetIntNType(unsigned len, bool sign);
|
||||
/// Gets a vector type. Should generally not be used as no
|
||||
/// type-checker information is generated.
|
||||
/// @param count the vector size
|
||||
/// @param base the vector base type (MUST be a primitive type)
|
||||
/// @return a vector type
|
||||
types::Type *unsafeGetVectorType(unsigned count, types::Type *base);
|
||||
|
||||
private:
|
||||
void store(types::Type *t) {
|
||||
|
|
|
@ -281,6 +281,9 @@ struct CanonConstSub : public RewriteRule {
|
|||
Value *lhs = v->front();
|
||||
Value *rhs = v->back();
|
||||
|
||||
if (!lhs->getType()->is(rhs->getType()))
|
||||
return;
|
||||
|
||||
Value *newCall = nullptr;
|
||||
if (util::isConst<int64_t>(rhs)) {
|
||||
auto c = util::getConst<int64_t>(rhs);
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "const_fold.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <utility>
|
||||
|
||||
#include "codon/sir/util/cloning.h"
|
||||
#include "codon/sir/util/irtools.h"
|
||||
|
@ -15,15 +16,28 @@ namespace ir {
|
|||
namespace transform {
|
||||
namespace folding {
|
||||
namespace {
|
||||
auto pyDivmod(int64_t self, int64_t other) {
|
||||
auto d = self / other;
|
||||
auto m = self - d * other;
|
||||
if (m && ((other ^ m) < 0)) {
|
||||
m += other;
|
||||
d -= 1;
|
||||
}
|
||||
return std::make_pair(d, m);
|
||||
}
|
||||
|
||||
template <typename Func, typename Out> class IntFloatBinaryRule : public RewriteRule {
|
||||
private:
|
||||
Func f;
|
||||
std::string magic;
|
||||
types::Type *out;
|
||||
bool excludeRHSZero;
|
||||
|
||||
public:
|
||||
IntFloatBinaryRule(Func f, std::string magic, types::Type *out)
|
||||
: f(std::move(f)), magic(std::move(magic)), out(out) {}
|
||||
IntFloatBinaryRule(Func f, std::string magic, types::Type *out,
|
||||
bool excludeRHSZero = false)
|
||||
: f(std::move(f)), magic(std::move(magic)), out(out),
|
||||
excludeRHSZero(excludeRHSZero) {}
|
||||
|
||||
virtual ~IntFloatBinaryRule() noexcept = default;
|
||||
|
||||
|
@ -41,11 +55,15 @@ public:
|
|||
if (isA<FloatConst>(leftConst) && isA<IntConst>(rightConst)) {
|
||||
auto left = cast<FloatConst>(leftConst)->getVal();
|
||||
auto right = cast<IntConst>(rightConst)->getVal();
|
||||
if (excludeRHSZero && right == 0)
|
||||
return;
|
||||
return setResult(M->template N<TemplatedConst<Out>>(v->getSrcInfo(),
|
||||
f(left, (double)right), out));
|
||||
} else if (isA<IntConst>(leftConst) && isA<FloatConst>(rightConst)) {
|
||||
auto left = cast<IntConst>(leftConst)->getVal();
|
||||
auto right = cast<FloatConst>(rightConst)->getVal();
|
||||
if (excludeRHSZero && right == 0.0)
|
||||
return;
|
||||
return setResult(M->template N<TemplatedConst<Out>>(v->getSrcInfo(),
|
||||
f((double)left, right), out));
|
||||
}
|
||||
|
@ -140,15 +158,22 @@ template <typename Func> auto floatToFloatBinary(Module *m, Func f, std::string
|
|||
std::move(f), std::move(magic), m->getFloatType(), m->getFloatType());
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
auto floatToFloatBinaryNoZeroRHS(Module *m, Func f, std::string magic) {
|
||||
return std::make_unique<DoubleConstantBinaryRuleExcludeRHSZero<double, Func, double>>(
|
||||
std::move(f), std::move(magic), m->getFloatType(), m->getFloatType());
|
||||
}
|
||||
|
||||
template <typename Func> auto floatToBoolBinary(Module *m, Func f, std::string magic) {
|
||||
return std::make_unique<DoubleConstantBinaryRule<double, Func, bool>>(
|
||||
std::move(f), std::move(magic), m->getFloatType(), m->getBoolType());
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
auto intFloatToFloatBinary(Module *m, Func f, std::string magic) {
|
||||
auto intFloatToFloatBinary(Module *m, Func f, std::string magic,
|
||||
bool excludeRHSZero = false) {
|
||||
return std::make_unique<IntFloatBinaryRule<Func, double>>(
|
||||
std::move(f), std::move(magic), m->getFloatType());
|
||||
std::move(f), std::move(magic), m->getFloatType(), excludeRHSZero);
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
|
@ -222,8 +247,15 @@ void FoldingPass::registerStandardRules(Module *m) {
|
|||
intToIntBinary(m, BINOP(+), Module::ADD_MAGIC_NAME));
|
||||
registerRule("int-constant-subtraction",
|
||||
intToIntBinary(m, BINOP(-), Module::SUB_MAGIC_NAME));
|
||||
registerRule("int-constant-floor-div",
|
||||
intToIntBinaryNoZeroRHS(m, BINOP(/), Module::FLOOR_DIV_MAGIC_NAME));
|
||||
if (pyNumerics) {
|
||||
registerRule("int-constant-floor-div",
|
||||
intToIntBinaryNoZeroRHS(
|
||||
m, [](auto x, auto y) -> auto{ return pyDivmod(x, y).first; },
|
||||
Module::FLOOR_DIV_MAGIC_NAME));
|
||||
} else {
|
||||
registerRule("int-constant-floor-div",
|
||||
intToIntBinaryNoZeroRHS(m, BINOP(/), Module::FLOOR_DIV_MAGIC_NAME));
|
||||
}
|
||||
registerRule("int-constant-mul", intToIntBinary(m, BINOP(*), Module::MUL_MAGIC_NAME));
|
||||
registerRule("int-constant-lshift",
|
||||
intToIntBinary(m, BINOP(<<), Module::LSHIFT_MAGIC_NAME));
|
||||
|
@ -233,8 +265,15 @@ void FoldingPass::registerStandardRules(Module *m) {
|
|||
registerRule("int-constant-xor", intToIntBinary(m, BINOP(^), Module::XOR_MAGIC_NAME));
|
||||
registerRule("int-constant-or", intToIntBinary(m, BINOP(|), Module::OR_MAGIC_NAME));
|
||||
registerRule("int-constant-and", intToIntBinary(m, BINOP(&), Module::AND_MAGIC_NAME));
|
||||
registerRule("int-constant-mod",
|
||||
intToIntBinaryNoZeroRHS(m, BINOP(%), Module::MOD_MAGIC_NAME));
|
||||
if (pyNumerics) {
|
||||
registerRule("int-constant-mod",
|
||||
intToIntBinaryNoZeroRHS(
|
||||
m, [](auto x, auto y) -> auto{ return pyDivmod(x, y).second; },
|
||||
Module::MOD_MAGIC_NAME));
|
||||
} else {
|
||||
registerRule("int-constant-mod",
|
||||
intToIntBinaryNoZeroRHS(m, BINOP(%), Module::MOD_MAGIC_NAME));
|
||||
}
|
||||
|
||||
// binary, double constant, int->bool
|
||||
registerRule("int-constant-eq", intToBoolBinary(m, BINOP(==), Module::EQ_MAGIC_NAME));
|
||||
|
@ -273,8 +312,13 @@ void FoldingPass::registerStandardRules(Module *m) {
|
|||
floatToFloatBinary(m, BINOP(+), Module::ADD_MAGIC_NAME));
|
||||
registerRule("float-constant-subtraction",
|
||||
floatToFloatBinary(m, BINOP(-), Module::SUB_MAGIC_NAME));
|
||||
registerRule("float-constant-floor-div",
|
||||
floatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME));
|
||||
if (pyNumerics) {
|
||||
registerRule("float-constant-floor-div",
|
||||
floatToFloatBinaryNoZeroRHS(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME));
|
||||
} else {
|
||||
registerRule("float-constant-floor-div",
|
||||
floatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME));
|
||||
}
|
||||
registerRule("float-constant-mul",
|
||||
floatToFloatBinary(m, BINOP(*), Module::MUL_MAGIC_NAME));
|
||||
registerRule(
|
||||
|
@ -301,8 +345,9 @@ void FoldingPass::registerStandardRules(Module *m) {
|
|||
intFloatToFloatBinary(m, BINOP(+), Module::ADD_MAGIC_NAME));
|
||||
registerRule("int-float-constant-subtraction",
|
||||
intFloatToFloatBinary(m, BINOP(-), Module::SUB_MAGIC_NAME));
|
||||
registerRule("int-float-constant-floor-div",
|
||||
intFloatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME));
|
||||
registerRule(
|
||||
"int-float-constant-floor-div",
|
||||
intFloatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME, pyNumerics));
|
||||
registerRule("int-float-constant-mul",
|
||||
intFloatToFloatBinary(m, BINOP(*), Module::MUL_MAGIC_NAME));
|
||||
|
||||
|
|
|
@ -12,18 +12,21 @@ namespace transform {
|
|||
namespace folding {
|
||||
|
||||
class FoldingPass : public OperatorPass, public Rewriter {
|
||||
private:
|
||||
bool pyNumerics;
|
||||
|
||||
void registerStandardRules(Module *m);
|
||||
|
||||
public:
|
||||
/// Constructs a folding pass.
|
||||
FoldingPass() : OperatorPass(/*childrenFirst=*/true) {}
|
||||
FoldingPass(bool pyNumerics = false)
|
||||
: OperatorPass(/*childrenFirst=*/true), pyNumerics(pyNumerics) {}
|
||||
|
||||
static const std::string KEY;
|
||||
std::string getKey() const override { return KEY; }
|
||||
|
||||
void run(Module *m) override;
|
||||
void handle(CallInstr *v) override;
|
||||
|
||||
private:
|
||||
void registerStandardRules(Module *m);
|
||||
};
|
||||
|
||||
} // namespace folding
|
||||
|
|
|
@ -13,12 +13,12 @@ const std::string FoldingPassGroup::KEY = "core-folding-pass-group";
|
|||
FoldingPassGroup::FoldingPassGroup(const std::string &sideEffectsPass,
|
||||
const std::string &reachingDefPass,
|
||||
const std::string &globalVarPass, int repeat,
|
||||
bool runGlobalDemotion)
|
||||
bool runGlobalDemotion, bool pyNumerics)
|
||||
: PassGroup(repeat) {
|
||||
auto gdUnique = runGlobalDemotion ? std::make_unique<cleanup::GlobalDemotionPass>()
|
||||
: std::unique_ptr<cleanup::GlobalDemotionPass>();
|
||||
auto canonUnique = std::make_unique<cleanup::CanonicalizationPass>(sideEffectsPass);
|
||||
auto fpUnique = std::make_unique<FoldingPass>();
|
||||
auto fpUnique = std::make_unique<FoldingPass>(pyNumerics);
|
||||
auto dceUnique = std::make_unique<cleanup::DeadCodeCleanupPass>(sideEffectsPass);
|
||||
|
||||
gd = gdUnique.get();
|
||||
|
|
|
@ -29,9 +29,11 @@ public:
|
|||
/// @param globalVarPass the key of the global variables pass
|
||||
/// @param repeat default number of times to repeat the pass
|
||||
/// @param runGlobalDemotion whether to demote globals if possible
|
||||
/// @param pyNumerics whether to use Python (vs. C) semantics when folding
|
||||
FoldingPassGroup(const std::string &sideEffectsPass,
|
||||
const std::string &reachingDefPass, const std::string &globalVarPass,
|
||||
int repeat = 5, bool runGlobalDemotion = true);
|
||||
int repeat = 5, bool runGlobalDemotion = true,
|
||||
bool pyNumerics = false);
|
||||
|
||||
bool shouldRepeat(int num) const override;
|
||||
};
|
||||
|
|
|
@ -185,11 +185,11 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
|
|||
capKey,
|
||||
/*globalAssignmentHasSideEffects=*/false),
|
||||
{capKey});
|
||||
registerPass(
|
||||
std::make_unique<folding::FoldingPassGroup>(
|
||||
seKey1, rdKey, globalKey, /*repeat=*/5, /*runGlobalDemoton=*/false),
|
||||
/*insertBefore=*/"", {seKey1, rdKey, globalKey},
|
||||
{seKey1, rdKey, cfgKey, globalKey, capKey});
|
||||
registerPass(std::make_unique<folding::FoldingPassGroup>(
|
||||
seKey1, rdKey, globalKey, /*repeat=*/5, /*runGlobalDemoton=*/false,
|
||||
pyNumerics),
|
||||
/*insertBefore=*/"", {seKey1, rdKey, globalKey},
|
||||
{seKey1, rdKey, cfgKey, globalKey, capKey});
|
||||
|
||||
// parallel
|
||||
registerPass(std::make_unique<parallel::OpenMPPass>(), /*insertBefore=*/"", {},
|
||||
|
@ -198,12 +198,12 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
|
|||
if (init != Init::JIT) {
|
||||
// Don't demote globals in JIT mode, since they might be used later
|
||||
// by another user input.
|
||||
registerPass(
|
||||
std::make_unique<folding::FoldingPassGroup>(seKey2, rdKey, globalKey,
|
||||
/*repeat=*/5,
|
||||
/*runGlobalDemoton=*/true),
|
||||
/*insertBefore=*/"", {seKey2, rdKey, globalKey},
|
||||
{seKey2, rdKey, cfgKey, globalKey});
|
||||
registerPass(std::make_unique<folding::FoldingPassGroup>(
|
||||
seKey2, rdKey, globalKey,
|
||||
/*repeat=*/5,
|
||||
/*runGlobalDemoton=*/true, pyNumerics),
|
||||
/*insertBefore=*/"", {seKey2, rdKey, globalKey},
|
||||
{seKey2, rdKey, cfgKey, globalKey});
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -89,6 +89,9 @@ private:
|
|||
/// passes to avoid registering
|
||||
std::vector<std::string> disabled;
|
||||
|
||||
/// whether to use Python (vs. C) numeric semantics in passes
|
||||
bool pyNumerics;
|
||||
|
||||
public:
|
||||
/// PassManager initialization mode.
|
||||
enum Init {
|
||||
|
@ -98,14 +101,17 @@ public:
|
|||
JIT,
|
||||
};
|
||||
|
||||
explicit PassManager(Init init, std::vector<std::string> disabled = {})
|
||||
explicit PassManager(Init init, std::vector<std::string> disabled = {},
|
||||
bool pyNumerics = false)
|
||||
: km(), passes(), analyses(), executionOrder(), results(),
|
||||
disabled(std::move(disabled)) {
|
||||
disabled(std::move(disabled)), pyNumerics(pyNumerics) {
|
||||
registerStandardPasses(init);
|
||||
}
|
||||
|
||||
explicit PassManager(bool debug = false, std::vector<std::string> disabled = {})
|
||||
: PassManager(debug ? Init::DEBUG : Init::RELEASE, std::move(disabled)) {}
|
||||
explicit PassManager(bool debug = false, std::vector<std::string> disabled = {},
|
||||
bool pyNumerics = false)
|
||||
: PassManager(debug ? Init::DEBUG : Init::RELEASE, std::move(disabled),
|
||||
pyNumerics) {}
|
||||
|
||||
/// Checks if the given pass is included in this manager.
|
||||
/// @param key the pass key
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <limits>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "codon/sir/transform/parallel/schedule.h"
|
||||
#include "codon/sir/util/cloning.h"
|
||||
|
@ -15,14 +16,22 @@ namespace transform {
|
|||
namespace parallel {
|
||||
namespace {
|
||||
const std::string ompModule = "std.openmp";
|
||||
const std::string gpuModule = "std.gpu";
|
||||
const std::string builtinModule = "std.internal.builtin";
|
||||
|
||||
void warn(const std::string &msg, const Value *v) {
|
||||
auto src = v->getSrcInfo();
|
||||
compilationWarning(msg, src.file, src.line, src.col);
|
||||
}
|
||||
|
||||
struct OMPTypes {
|
||||
types::Type *i64 = nullptr;
|
||||
types::Type *i32 = nullptr;
|
||||
types::Type *i8ptr = nullptr;
|
||||
types::Type *i32ptr = nullptr;
|
||||
|
||||
explicit OMPTypes(Module *M) {
|
||||
i64 = M->getIntType();
|
||||
i32 = M->getIntNType(32, /*sign=*/true);
|
||||
i8ptr = M->getPointerType(M->getByteType());
|
||||
i32ptr = M->getPointerType(i32);
|
||||
|
@ -142,6 +151,28 @@ struct Reduction {
|
|||
default:
|
||||
return nullptr;
|
||||
}
|
||||
} else if (isA<types::Float32Type>(type)) {
|
||||
auto *f32 = M->getOrRealizeType("float32");
|
||||
float value = 0.0;
|
||||
|
||||
switch (kind) {
|
||||
case Kind::ADD:
|
||||
value = 0.0;
|
||||
break;
|
||||
case Kind::MUL:
|
||||
value = 1.0;
|
||||
break;
|
||||
case Kind::MIN:
|
||||
value = std::numeric_limits<float>::max();
|
||||
break;
|
||||
case Kind::MAX:
|
||||
value = std::numeric_limits<float>::min();
|
||||
break;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return (*f32)(*M->getFloat(value));
|
||||
}
|
||||
|
||||
auto *init = (*type)();
|
||||
|
@ -239,6 +270,23 @@ struct Reduction {
|
|||
default:
|
||||
break;
|
||||
}
|
||||
} else if (isA<types::Float32Type>(type)) {
|
||||
switch (kind) {
|
||||
case Kind::ADD:
|
||||
func = "_atomic_float32_add";
|
||||
break;
|
||||
case Kind::MUL:
|
||||
func = "_atomic_float32_mul";
|
||||
break;
|
||||
case Kind::MIN:
|
||||
func = "_atomic_float32_min";
|
||||
break;
|
||||
case Kind::MAX:
|
||||
func = "_atomic_float32_max";
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!func.empty()) {
|
||||
|
@ -476,10 +524,16 @@ struct SharedInfo {
|
|||
Reduction reduction; // the reduction we're performing, or empty if none
|
||||
};
|
||||
|
||||
struct ParallelLoopTemplateReplacer : public util::Operator {
|
||||
struct LoopTemplateReplacer : public util::Operator {
|
||||
BodiedFunc *parent;
|
||||
CallInstr *replacement;
|
||||
Var *loopVar;
|
||||
|
||||
LoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement, Var *loopVar)
|
||||
: util::Operator(), parent(parent), replacement(replacement), loopVar(loopVar) {}
|
||||
};
|
||||
|
||||
struct ParallelLoopTemplateReplacer : public LoopTemplateReplacer {
|
||||
ReductionIdentifier *reds;
|
||||
std::vector<SharedInfo> sharedInfo;
|
||||
ReductionLocks locks;
|
||||
|
@ -489,9 +543,8 @@ struct ParallelLoopTemplateReplacer : public util::Operator {
|
|||
|
||||
ParallelLoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement, Var *loopVar,
|
||||
ReductionIdentifier *reds)
|
||||
: util::Operator(), parent(parent), replacement(replacement), loopVar(loopVar),
|
||||
reds(reds), sharedInfo(), locks(), locRef(nullptr), reductionLocRef(nullptr),
|
||||
gtid(nullptr) {}
|
||||
: LoopTemplateReplacer(parent, replacement, loopVar), reds(reds), sharedInfo(),
|
||||
locks(), locRef(nullptr), reductionLocRef(nullptr), gtid(nullptr) {}
|
||||
|
||||
unsigned numReductions() {
|
||||
unsigned num = 0;
|
||||
|
@ -1102,6 +1155,72 @@ struct TaskLoopRoutineStubReplacer : public ParallelLoopTemplateReplacer {
|
|||
}
|
||||
};
|
||||
|
||||
struct GPULoopBodyStubReplacer : public util::Operator {
|
||||
CallInstr *replacement;
|
||||
Var *loopVar;
|
||||
int64_t step;
|
||||
|
||||
GPULoopBodyStubReplacer(CallInstr *replacement, Var *loopVar, int64_t step)
|
||||
: util::Operator(), replacement(replacement), loopVar(loopVar), step(step) {}
|
||||
|
||||
void handle(CallInstr *v) override {
|
||||
auto *M = v->getModule();
|
||||
auto *func = util::getFunc(v->getCallee());
|
||||
if (!func)
|
||||
return;
|
||||
auto name = func->getUnmangledName();
|
||||
|
||||
if (name == "_gpu_loop_body_stub") {
|
||||
seqassertn(replacement, "unexpected double replacement");
|
||||
seqassertn(v->numArgs() == 2, "unexpected loop body stub");
|
||||
|
||||
// the template passes gtid, privs and shareds to the body stub for convenience
|
||||
auto *idx = v->front();
|
||||
auto *args = v->back();
|
||||
unsigned next = 0;
|
||||
|
||||
std::vector<Value *> newArgs;
|
||||
for (auto *arg : *replacement) {
|
||||
// std::cout << "A: " << *arg << std::endl;
|
||||
if (getVarFromOutlinedArg(arg)->getId() == loopVar->getId()) {
|
||||
// std::cout << "(loop var)" << std::endl;
|
||||
newArgs.push_back(idx);
|
||||
} else {
|
||||
newArgs.push_back(util::tupleGet(args, next++));
|
||||
}
|
||||
}
|
||||
|
||||
auto *outlinedFunc = cast<BodiedFunc>(util::getFunc(replacement->getCallee()));
|
||||
v->replaceAll(util::call(outlinedFunc, newArgs));
|
||||
replacement = nullptr;
|
||||
}
|
||||
|
||||
if (name == "_loop_step") {
|
||||
v->replaceAll(M->getInt(step));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct GPULoopTemplateReplacer : public LoopTemplateReplacer {
|
||||
int64_t step;
|
||||
|
||||
GPULoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement, Var *loopVar,
|
||||
int64_t step)
|
||||
: LoopTemplateReplacer(parent, replacement, loopVar), step(step) {}
|
||||
|
||||
void handle(CallInstr *v) override {
|
||||
auto *M = v->getModule();
|
||||
auto *func = util::getFunc(v->getCallee());
|
||||
if (!func)
|
||||
return;
|
||||
auto name = func->getUnmangledName();
|
||||
|
||||
if (name == "_loop_step") {
|
||||
v->replaceAll(M->getInt(step));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct OpenMPTransformData {
|
||||
util::OutlineResult outline;
|
||||
std::vector<Var *> sharedVars;
|
||||
|
@ -1164,15 +1283,118 @@ ForkCallData createForkCall(Module *M, OMPTypes &types, Value *rawTemplateFunc,
|
|||
seqassertn(forkFunc, "fork call function not found");
|
||||
result.fork = util::call(forkFunc, {rawTemplateFunc, forkExtra});
|
||||
|
||||
auto *intType = M->getIntType();
|
||||
if (sched->threads && sched->threads->getType()->is(intType)) {
|
||||
if (sched->threads && sched->threads->getType()->is(types.i64)) {
|
||||
auto *pushNumThreadsFunc =
|
||||
M->getOrRealizeFunc("_push_num_threads", {intType}, {}, ompModule);
|
||||
M->getOrRealizeFunc("_push_num_threads", {types.i64}, {}, ompModule);
|
||||
seqassertn(pushNumThreadsFunc, "push num threads func not found");
|
||||
result.pushNumThreads = util::call(pushNumThreadsFunc, {sched->threads});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
struct CollapseResult {
|
||||
ImperativeForFlow *collapsed = nullptr;
|
||||
SeriesFlow *setup = nullptr;
|
||||
std::string error;
|
||||
|
||||
operator bool() const { return collapsed != nullptr; }
|
||||
};
|
||||
|
||||
struct LoopRange {
|
||||
ImperativeForFlow *loop;
|
||||
Var *start;
|
||||
Var *stop;
|
||||
int64_t step;
|
||||
Var *len;
|
||||
};
|
||||
|
||||
CollapseResult collapseLoop(BodiedFunc *parent, ImperativeForFlow *v, int64_t levels) {
|
||||
auto fail = [](const std::string &error) {
|
||||
CollapseResult bad;
|
||||
bad.error = error;
|
||||
return bad;
|
||||
};
|
||||
|
||||
auto *M = v->getModule();
|
||||
CollapseResult res;
|
||||
if (levels < 1)
|
||||
return fail("'collapse' must be at least 1");
|
||||
|
||||
std::vector<ImperativeForFlow *> loopNests = {v};
|
||||
ImperativeForFlow *curr = v;
|
||||
|
||||
for (auto i = 0; i < levels - 1; i++) {
|
||||
auto *body = cast<SeriesFlow>(curr->getBody());
|
||||
seqassertn(body, "unexpected loop body");
|
||||
if (std::distance(body->begin(), body->end()) != 1 ||
|
||||
!isA<ImperativeForFlow>(body->front()))
|
||||
return fail("loop nest not collapsible");
|
||||
|
||||
curr = cast<ImperativeForFlow>(body->front());
|
||||
loopNests.push_back(curr);
|
||||
}
|
||||
|
||||
std::vector<LoopRange> ranges;
|
||||
auto *setup = M->Nr<SeriesFlow>();
|
||||
|
||||
auto *intType = M->getIntType();
|
||||
auto *lenCalc =
|
||||
M->getOrRealizeFunc("_range_len", {intType, intType, intType}, {}, ompModule);
|
||||
seqassertn(lenCalc, "range length calculation function not found");
|
||||
|
||||
for (auto *loop : loopNests) {
|
||||
LoopRange range;
|
||||
range.loop = loop;
|
||||
range.start = util::makeVar(loop->getStart(), setup, parent)->getVar();
|
||||
range.stop = util::makeVar(loop->getEnd(), setup, parent)->getVar();
|
||||
range.step = loop->getStep();
|
||||
range.len = util::makeVar(util::call(lenCalc, {M->Nr<VarValue>(range.start),
|
||||
M->Nr<VarValue>(range.stop),
|
||||
M->getInt(range.step)}),
|
||||
setup, parent)
|
||||
->getVar();
|
||||
ranges.push_back(range);
|
||||
}
|
||||
|
||||
auto *numIters = M->getInt(1);
|
||||
for (auto &range : ranges) {
|
||||
numIters = (*numIters) * (*M->Nr<VarValue>(range.len));
|
||||
}
|
||||
|
||||
auto *collapsedVar = M->Nr<Var>(M->getIntType(), /*global=*/false);
|
||||
parent->push_back(collapsedVar);
|
||||
auto *body = M->Nr<SeriesFlow>();
|
||||
auto sched = std::make_unique<OMPSched>(*v->getSchedule());
|
||||
sched->collapse = 0;
|
||||
auto *collapsed = M->Nr<ImperativeForFlow>(M->getInt(0), 1, numIters, body,
|
||||
collapsedVar, std::move(sched));
|
||||
|
||||
// reconstruct indices by successive divmods
|
||||
Var *lastDiv = nullptr;
|
||||
for (auto it = ranges.rbegin(); it != ranges.rend(); ++it) {
|
||||
auto *k = lastDiv ? lastDiv : collapsedVar;
|
||||
auto *div =
|
||||
util::makeVar(*M->Nr<VarValue>(k) / *M->Nr<VarValue>(it->len), body, parent)
|
||||
->getVar();
|
||||
auto *mod =
|
||||
util::makeVar(*M->Nr<VarValue>(k) % *M->Nr<VarValue>(it->len), body, parent)
|
||||
->getVar();
|
||||
auto *i =
|
||||
*M->Nr<VarValue>(it->start) + *(*M->Nr<VarValue>(mod) * *M->getInt(it->step));
|
||||
body->push_back(M->Nr<AssignInstr>(it->loop->getVar(), i));
|
||||
lastDiv = div;
|
||||
}
|
||||
|
||||
auto *oldBody = cast<SeriesFlow>(loopNests.back()->getBody());
|
||||
for (auto *x : *oldBody) {
|
||||
body->push_back(x);
|
||||
}
|
||||
|
||||
res.collapsed = collapsed;
|
||||
res.setup = setup;
|
||||
|
||||
return res;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const std::string OpenMPPass::KEY = "core-parallel-openmp";
|
||||
|
@ -1279,7 +1501,22 @@ void OpenMPPass::handle(ForFlow *v) {
|
|||
}
|
||||
|
||||
void OpenMPPass::handle(ImperativeForFlow *v) {
|
||||
auto data = setupOpenMPTransform(v, cast<BodiedFunc>(getParentFunc()));
|
||||
auto *parent = cast<BodiedFunc>(getParentFunc());
|
||||
|
||||
if (v->isParallel() && v->getSchedule()->collapse != 0) {
|
||||
auto levels = v->getSchedule()->collapse;
|
||||
auto collapse = collapseLoop(parent, v, levels);
|
||||
|
||||
if (collapse) {
|
||||
v->replaceAll(collapse.collapsed);
|
||||
v = collapse.collapsed;
|
||||
insertBefore(collapse.setup);
|
||||
} else if (!collapse.error.empty()) {
|
||||
warn("could not collapse loop: " + collapse.error, v);
|
||||
}
|
||||
}
|
||||
|
||||
auto data = setupOpenMPTransform(v, parent);
|
||||
if (!v->isParallel())
|
||||
return;
|
||||
|
||||
|
@ -1292,6 +1529,12 @@ void OpenMPPass::handle(ImperativeForFlow *v) {
|
|||
auto *sched = v->getSchedule();
|
||||
OMPTypes types(M);
|
||||
|
||||
if (sched->gpu && !sharedVars.empty()) {
|
||||
warn("GPU-parallel loop cannot modify external variables; ignoring", v);
|
||||
v->setParallel(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// gather extra arguments
|
||||
std::vector<Value *> extraArgs;
|
||||
std::vector<types::Type *> extraArgTypes;
|
||||
|
@ -1304,41 +1547,88 @@ void OpenMPPass::handle(ImperativeForFlow *v) {
|
|||
|
||||
// template call
|
||||
std::string templateFuncName;
|
||||
if (sched->dynamic) {
|
||||
if (sched->gpu) {
|
||||
templateFuncName = "_gpu_loop_outline_template";
|
||||
} else if (sched->dynamic) {
|
||||
templateFuncName = "_dynamic_loop_outline_template";
|
||||
} else if (sched->chunk) {
|
||||
templateFuncName = "_static_chunked_loop_outline_template";
|
||||
} else {
|
||||
templateFuncName = "_static_loop_outline_template";
|
||||
}
|
||||
auto *intType = M->getIntType();
|
||||
std::vector<types::Type *> templateFuncArgs = {
|
||||
types.i32ptr, types.i32ptr,
|
||||
M->getPointerType(M->getTupleType(
|
||||
{intType, intType, intType, M->getTupleType(extraArgTypes)}))};
|
||||
auto *templateFunc =
|
||||
M->getOrRealizeFunc(templateFuncName, templateFuncArgs, {}, ompModule);
|
||||
seqassertn(templateFunc, "imperative loop outline template not found");
|
||||
|
||||
util::CloneVisitor cv(M);
|
||||
templateFunc = cast<Func>(cv.forceClone(templateFunc));
|
||||
ImperativeLoopTemplateReplacer rep(cast<BodiedFunc>(templateFunc), outline.call,
|
||||
loopVar, &reds, sched, v->getStep());
|
||||
templateFunc->accept(rep);
|
||||
auto *rawTemplateFunc = ptrFromFunc(templateFunc);
|
||||
if (sched->gpu) {
|
||||
std::unordered_set<id_t> kernels;
|
||||
const std::string gpuAttr = "std.gpu.kernel";
|
||||
for (auto *var : *M) {
|
||||
if (auto *func = cast<BodiedFunc>(var)) {
|
||||
if (util::hasAttribute(func, gpuAttr))
|
||||
kernels.insert(func->getId());
|
||||
}
|
||||
}
|
||||
|
||||
auto *chunk = (sched->chunk && sched->chunk->getType()->is(intType)) ? sched->chunk
|
||||
: M->getInt(1);
|
||||
std::vector<Value *> forkExtraArgs = {chunk, v->getStart(), v->getEnd()};
|
||||
for (auto *arg : extraArgs) {
|
||||
forkExtraArgs.push_back(arg);
|
||||
std::vector<types::Type *> templateFuncArgs = {types.i64, types.i64,
|
||||
M->getTupleType(extraArgTypes)};
|
||||
static int64_t instance = 0;
|
||||
auto *templateFunc = M->getOrRealizeFunc(templateFuncName, templateFuncArgs,
|
||||
{instance++}, gpuModule);
|
||||
|
||||
if (!templateFunc) {
|
||||
warn("loop not compilable for GPU; ignoring", v);
|
||||
v->setParallel(false);
|
||||
return;
|
||||
}
|
||||
|
||||
BodiedFunc *kernel = nullptr;
|
||||
for (auto *var : *M) {
|
||||
if (auto *func = cast<BodiedFunc>(var)) {
|
||||
if (util::hasAttribute(func, gpuAttr) && kernels.count(func->getId()) == 0) {
|
||||
seqassertn(!kernel, "multiple new kernels found after instantiation");
|
||||
kernel = func;
|
||||
}
|
||||
}
|
||||
}
|
||||
seqassertn(kernel, "no new kernel found");
|
||||
GPULoopBodyStubReplacer brep(outline.call, loopVar, v->getStep());
|
||||
kernel->accept(brep);
|
||||
|
||||
util::CloneVisitor cv(M);
|
||||
templateFunc = cast<Func>(cv.forceClone(templateFunc));
|
||||
GPULoopTemplateReplacer rep(cast<BodiedFunc>(templateFunc), outline.call, loopVar,
|
||||
v->getStep());
|
||||
templateFunc->accept(rep);
|
||||
v->replaceAll(util::call(
|
||||
templateFunc, {v->getStart(), v->getEnd(), util::makeTuple(extraArgs, M)}));
|
||||
} else {
|
||||
std::vector<types::Type *> templateFuncArgs = {
|
||||
types.i32ptr, types.i32ptr,
|
||||
M->getPointerType(M->getTupleType(
|
||||
{types.i64, types.i64, types.i64, M->getTupleType(extraArgTypes)}))};
|
||||
auto *templateFunc =
|
||||
M->getOrRealizeFunc(templateFuncName, templateFuncArgs, {}, ompModule);
|
||||
seqassertn(templateFunc, "imperative loop outline template not found");
|
||||
|
||||
util::CloneVisitor cv(M);
|
||||
templateFunc = cast<Func>(cv.forceClone(templateFunc));
|
||||
ImperativeLoopTemplateReplacer rep(cast<BodiedFunc>(templateFunc), outline.call,
|
||||
loopVar, &reds, sched, v->getStep());
|
||||
templateFunc->accept(rep);
|
||||
auto *rawTemplateFunc = ptrFromFunc(templateFunc);
|
||||
|
||||
auto *chunk = (sched->chunk && sched->chunk->getType()->is(types.i64))
|
||||
? sched->chunk
|
||||
: M->getInt(1);
|
||||
std::vector<Value *> forkExtraArgs = {chunk, v->getStart(), v->getEnd()};
|
||||
for (auto *arg : extraArgs) {
|
||||
forkExtraArgs.push_back(arg);
|
||||
}
|
||||
|
||||
// fork call
|
||||
auto forkData = createForkCall(M, types, rawTemplateFunc, forkExtraArgs, sched);
|
||||
if (forkData.pushNumThreads)
|
||||
insertBefore(forkData.pushNumThreads);
|
||||
v->replaceAll(forkData.fork);
|
||||
}
|
||||
|
||||
// fork call
|
||||
auto forkData = createForkCall(M, types, rawTemplateFunc, forkExtraArgs, sched);
|
||||
if (forkData.pushNumThreads)
|
||||
insertBefore(forkData.pushNumThreads);
|
||||
v->replaceAll(forkData.fork);
|
||||
}
|
||||
|
||||
} // namespace parallel
|
||||
|
|
|
@ -47,17 +47,19 @@ Value *nullIfNeg(Value *v) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
OMPSched::OMPSched(int code, bool dynamic, Value *threads, Value *chunk, bool ordered)
|
||||
OMPSched::OMPSched(int code, bool dynamic, Value *threads, Value *chunk, bool ordered,
|
||||
int64_t collapse, bool gpu)
|
||||
: code(code), dynamic(dynamic), threads(nullIfNeg(threads)),
|
||||
chunk(nullIfNeg(chunk)), ordered(ordered) {
|
||||
chunk(nullIfNeg(chunk)), ordered(ordered), collapse(collapse), gpu(gpu) {
|
||||
if (code < 0)
|
||||
this->code = getScheduleCode();
|
||||
}
|
||||
|
||||
OMPSched::OMPSched(const std::string &schedule, Value *threads, Value *chunk,
|
||||
bool ordered)
|
||||
bool ordered, int64_t collapse, bool gpu)
|
||||
: OMPSched(getScheduleCode(schedule, nullIfNeg(chunk) != nullptr, ordered),
|
||||
(schedule != "static") || ordered, threads, chunk, ordered) {}
|
||||
(schedule != "static") || ordered, threads, chunk, ordered, collapse,
|
||||
gpu) {}
|
||||
|
||||
std::vector<Value *> OMPSched::getUsedValues() const {
|
||||
std::vector<Value *> ret;
|
||||
|
|
|
@ -16,14 +16,18 @@ struct OMPSched {
|
|||
Value *threads;
|
||||
Value *chunk;
|
||||
bool ordered;
|
||||
int64_t collapse;
|
||||
bool gpu;
|
||||
|
||||
explicit OMPSched(int code = -1, bool dynamic = false, Value *threads = nullptr,
|
||||
Value *chunk = nullptr, bool ordered = false);
|
||||
Value *chunk = nullptr, bool ordered = false, int64_t collapse = 0,
|
||||
bool gpu = false);
|
||||
explicit OMPSched(const std::string &code, Value *threads = nullptr,
|
||||
Value *chunk = nullptr, bool ordered = false);
|
||||
Value *chunk = nullptr, bool ordered = false, int64_t collapse = 0,
|
||||
bool gpu = false);
|
||||
OMPSched(const OMPSched &s)
|
||||
: code(s.code), dynamic(s.dynamic), threads(s.threads), chunk(s.chunk),
|
||||
ordered(s.ordered) {}
|
||||
ordered(s.ordered), collapse(s.collapse), gpu(s.gpu) {}
|
||||
|
||||
std::vector<Value *> getUsedValues() const;
|
||||
int replaceUsedValue(id_t id, Value *newValue);
|
||||
|
|
|
@ -65,6 +65,8 @@ const char IntType::NodeId = 0;
|
|||
|
||||
const char FloatType::NodeId = 0;
|
||||
|
||||
const char Float32Type::NodeId = 0;
|
||||
|
||||
const char BoolType::NodeId = 0;
|
||||
|
||||
const char ByteType::NodeId = 0;
|
||||
|
@ -188,6 +190,12 @@ std::string IntNType::getInstanceName(unsigned int len, bool sign) {
|
|||
return fmt::format(FMT_STRING("{}Int{}"), sign ? "" : "U", len);
|
||||
}
|
||||
|
||||
const char VectorType::NodeId = 0;
|
||||
|
||||
std::string VectorType::getInstanceName(unsigned int count, PrimitiveType *base) {
|
||||
return fmt::format(FMT_STRING("Vector[{}, {}]"), count, base->referenceString());
|
||||
}
|
||||
|
||||
} // namespace types
|
||||
} // namespace ir
|
||||
} // namespace codon
|
||||
|
|
|
@ -138,6 +138,15 @@ public:
|
|||
FloatType() : AcceptorExtend("float") {}
|
||||
};
|
||||
|
||||
/// Float32 type (32-bit float)
|
||||
class Float32Type : public AcceptorExtend<Float32Type, PrimitiveType> {
|
||||
public:
|
||||
static const char NodeId;
|
||||
|
||||
/// Constructs a float32 type.
|
||||
Float32Type() : AcceptorExtend("float32") {}
|
||||
};
|
||||
|
||||
/// Bool type (8-bit unsigned integer; either 0 or 1)
|
||||
class BoolType : public AcceptorExtend<BoolType, PrimitiveType> {
|
||||
public:
|
||||
|
@ -424,7 +433,7 @@ private:
|
|||
};
|
||||
|
||||
/// Type of a variably sized integer
|
||||
class IntNType : public AcceptorExtend<IntNType, Type> {
|
||||
class IntNType : public AcceptorExtend<IntNType, PrimitiveType> {
|
||||
private:
|
||||
/// length of the integer
|
||||
unsigned len;
|
||||
|
@ -451,9 +460,31 @@ public:
|
|||
std::string oppositeSignName() const { return getInstanceName(len, !sign); }
|
||||
|
||||
static std::string getInstanceName(unsigned len, bool sign);
|
||||
};
|
||||
|
||||
/// Type of a vector of primitives
|
||||
class VectorType : public AcceptorExtend<VectorType, PrimitiveType> {
|
||||
private:
|
||||
bool doIsAtomic() const override { return true; }
|
||||
/// number of elements
|
||||
unsigned count;
|
||||
/// base type
|
||||
PrimitiveType *base;
|
||||
|
||||
public:
|
||||
static const char NodeId;
|
||||
|
||||
/// Constructs a vector type.
|
||||
/// @param count the number of elements
|
||||
/// @param base the base type
|
||||
VectorType(unsigned count, PrimitiveType *base)
|
||||
: AcceptorExtend(getInstanceName(count, base)), count(count), base(base) {}
|
||||
|
||||
/// @return the count of the vector
|
||||
unsigned getCount() const { return count; }
|
||||
/// @return the base type of the vector
|
||||
PrimitiveType *getBase() const { return base; }
|
||||
|
||||
static std::string getInstanceName(unsigned count, PrimitiveType *base);
|
||||
};
|
||||
|
||||
} // namespace types
|
||||
|
|
|
@ -288,6 +288,9 @@ public:
|
|||
void visit(const types::FloatType *v) override {
|
||||
fmt::print(os, FMT_STRING("(float '\"{}\")"), v->referenceString());
|
||||
}
|
||||
void visit(const types::Float32Type *v) override {
|
||||
fmt::print(os, FMT_STRING("(float32 '\"{}\")"), v->referenceString());
|
||||
}
|
||||
void visit(const types::BoolType *v) override {
|
||||
fmt::print(os, FMT_STRING("(bool '\"{}\")"), v->referenceString());
|
||||
}
|
||||
|
@ -334,6 +337,10 @@ public:
|
|||
fmt::print(os, FMT_STRING("(intn '\"{}\" {} (signed {}))"), v->referenceString(),
|
||||
v->getLen(), v->isSigned());
|
||||
}
|
||||
void visit(const types::VectorType *v) override {
|
||||
fmt::print(os, FMT_STRING("(vector '\"{}\" {} (count {}))"), v->referenceString(),
|
||||
makeFormatter(v->getBase()), v->getCount());
|
||||
}
|
||||
void visit(const dsl::types::CustomType *v) override { v->doFormat(os); }
|
||||
|
||||
void format(const Node *n) {
|
||||
|
|
|
@ -233,6 +233,8 @@ public:
|
|||
void handle(const types::IntType *, const types::IntType *) { result = true; }
|
||||
VISIT(types::FloatType);
|
||||
void handle(const types::FloatType *, const types::FloatType *) { result = true; }
|
||||
VISIT(types::Float32Type);
|
||||
void handle(const types::Float32Type *, const types::Float32Type *) { result = true; }
|
||||
VISIT(types::BoolType);
|
||||
void handle(const types::BoolType *, const types::BoolType *) { result = true; }
|
||||
VISIT(types::ByteType);
|
||||
|
@ -272,6 +274,10 @@ public:
|
|||
void handle(const types::IntNType *x, const types::IntNType *y) {
|
||||
result = x->getLen() == y->getLen() && x->isSigned() == y->isSigned();
|
||||
}
|
||||
VISIT(types::VectorType);
|
||||
void handle(const types::VectorType *x, const types::VectorType *y) {
|
||||
result = x->getCount() == y->getCount() && process(x->getBase(), y->getBase());
|
||||
}
|
||||
VISIT(dsl::types::CustomType);
|
||||
void handle(const dsl::types::CustomType *x, const dsl::types::CustomType *y) {
|
||||
result = x->match(y);
|
||||
|
|
|
@ -51,6 +51,7 @@ void Visitor::visit(types::Type *x) { defaultVisit(x); }
|
|||
void Visitor::visit(types::PrimitiveType *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::IntType *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::FloatType *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::Float32Type *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::BoolType *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::ByteType *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::VoidType *x) { defaultVisit(x); }
|
||||
|
@ -61,6 +62,7 @@ void Visitor::visit(types::OptionalType *x) { defaultVisit(x); }
|
|||
void Visitor::visit(types::PointerType *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::GeneratorType *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::IntNType *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::VectorType *x) { defaultVisit(x); }
|
||||
void Visitor::visit(dsl::types::CustomType *x) { defaultVisit(x); }
|
||||
|
||||
void ConstVisitor::visit(const Module *x) { defaultVisit(x); }
|
||||
|
@ -108,6 +110,7 @@ void ConstVisitor::visit(const types::Type *x) { defaultVisit(x); }
|
|||
void ConstVisitor::visit(const types::PrimitiveType *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::IntType *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::FloatType *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::Float32Type *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::BoolType *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::ByteType *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::VoidType *x) { defaultVisit(x); }
|
||||
|
@ -118,6 +121,7 @@ void ConstVisitor::visit(const types::OptionalType *x) { defaultVisit(x); }
|
|||
void ConstVisitor::visit(const types::PointerType *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::GeneratorType *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::IntNType *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::VectorType *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const dsl::types::CustomType *x) { defaultVisit(x); }
|
||||
|
||||
} // namespace util
|
||||
|
|
|
@ -16,6 +16,7 @@ class Type;
|
|||
class PrimitiveType;
|
||||
class IntType;
|
||||
class FloatType;
|
||||
class Float32Type;
|
||||
class BoolType;
|
||||
class ByteType;
|
||||
class VoidType;
|
||||
|
@ -26,6 +27,7 @@ class OptionalType;
|
|||
class PointerType;
|
||||
class GeneratorType;
|
||||
class IntNType;
|
||||
class VectorType;
|
||||
} // namespace types
|
||||
|
||||
namespace dsl {
|
||||
|
@ -146,6 +148,7 @@ public:
|
|||
VISIT(types::PrimitiveType);
|
||||
VISIT(types::IntType);
|
||||
VISIT(types::FloatType);
|
||||
VISIT(types::Float32Type);
|
||||
VISIT(types::BoolType);
|
||||
VISIT(types::ByteType);
|
||||
VISIT(types::VoidType);
|
||||
|
@ -156,6 +159,7 @@ public:
|
|||
VISIT(types::PointerType);
|
||||
VISIT(types::GeneratorType);
|
||||
VISIT(types::IntNType);
|
||||
VISIT(types::VectorType);
|
||||
VISIT(dsl::types::CustomType);
|
||||
};
|
||||
|
||||
|
@ -220,6 +224,7 @@ public:
|
|||
CONST_VISIT(types::PrimitiveType);
|
||||
CONST_VISIT(types::IntType);
|
||||
CONST_VISIT(types::FloatType);
|
||||
CONST_VISIT(types::Float32Type);
|
||||
CONST_VISIT(types::BoolType);
|
||||
CONST_VISIT(types::ByteType);
|
||||
CONST_VISIT(types::VoidType);
|
||||
|
@ -230,6 +235,7 @@ public:
|
|||
CONST_VISIT(types::PointerType);
|
||||
CONST_VISIT(types::GeneratorType);
|
||||
CONST_VISIT(types::IntNType);
|
||||
CONST_VISIT(types::VectorType);
|
||||
CONST_VISIT(dsl::types::CustomType);
|
||||
};
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
## Advanced
|
||||
|
||||
* [Parallelism and multithreading](advanced/parallel.md)
|
||||
* [GPU programming](advanced/gpu.md)
|
||||
* [Pipelines](advanced/pipelines.md)
|
||||
* [Intermediate representation](advanced/ir.md)
|
||||
* [Building from source](advanced/build.md)
|
||||
|
|
|
@ -0,0 +1,249 @@
|
|||
Codon supports GPU programming through a native GPU backend.
|
||||
Currently, only Nvidia devices are supported.
|
||||
Here is a simple example:
|
||||
|
||||
``` python
|
||||
import gpu
|
||||
|
||||
@gpu.kernel
|
||||
def hello(a, b, c):
|
||||
i = gpu.thread.x
|
||||
c[i] = a[i] + b[i]
|
||||
|
||||
a = [i for i in range(16)]
|
||||
b = [2*i for i in range(16)]
|
||||
c = [0 for _ in range(16)]
|
||||
|
||||
hello(a, b, c, grid=1, block=16)
|
||||
print(c)
|
||||
```
|
||||
|
||||
which outputs:
|
||||
|
||||
```
|
||||
[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45]
|
||||
```
|
||||
|
||||
The same code can be written using Codon's `@par` syntax:
|
||||
|
||||
``` python
|
||||
a = [i for i in range(16)]
|
||||
b = [2*i for i in range(16)]
|
||||
c = [0 for _ in range(16)]
|
||||
|
||||
@par(gpu=True)
|
||||
for i in range(16):
|
||||
c[i] = a[i] + b[i]
|
||||
|
||||
print(c)
|
||||
```
|
||||
|
||||
Below is a more comprehensive example for computing the [Mandelbrot
|
||||
set](https://en.wikipedia.org/wiki/Mandelbrot_set), and plotting it
|
||||
using NumPy/Matplotlib:
|
||||
|
||||
``` python
|
||||
from python import numpy as np
|
||||
from python import matplotlib.pyplot as plt
|
||||
import gpu
|
||||
|
||||
MAX = 1000 # maximum Mandelbrot iterations
|
||||
N = 4096 # width and height of image
|
||||
pixels = [0 for _ in range(N * N)]
|
||||
|
||||
def scale(x, a, b):
|
||||
return a + (x/N)*(b - a)
|
||||
|
||||
@gpu.kernel
|
||||
def mandelbrot(pixels):
|
||||
idx = (gpu.block.x * gpu.block.dim.x) + gpu.thread.x
|
||||
i, j = divmod(idx, N)
|
||||
c = complex(scale(j, -2.00, 0.47), scale(i, -1.12, 1.12))
|
||||
z = 0j
|
||||
iteration = 0
|
||||
|
||||
while abs(z) <= 2 and iteration < MAX:
|
||||
z = z**2 + c
|
||||
iteration += 1
|
||||
|
||||
pixels[idx] = int(255 * iteration/MAX)
|
||||
|
||||
mandelbrot(pixels, grid=(N*N)//1024, block=1024)
|
||||
plt.imshow(np.array(pixels).reshape(N, N))
|
||||
plt.show()
|
||||
```
|
||||
|
||||
The GPU version of the Mandelbrot code is about 450 times faster
|
||||
than an equivalent CPU version.
|
||||
|
||||
GPU kernels are marked with the `@gpu.kernel` annotation, and
|
||||
compiled specially in Codon's backend. Kernel functions can
|
||||
use the vast majority of features supported in Codon, with a
|
||||
couple of notable exceptions:
|
||||
|
||||
- Exception handling is not supported inside the kernel, meaning
|
||||
kernel code should not throw or catch exceptions. `raise`
|
||||
statements inside the kernel are marked as unreachable and
|
||||
optimized out.
|
||||
|
||||
- Functionality related to I/O is not supported (e.g. you can't
|
||||
open a file in the kernel).
|
||||
|
||||
- A few other modules and functions are not allowed, such as the
|
||||
`re` module (which uses an external regex library) or the `os`
|
||||
module.
|
||||
|
||||
{% hint style="warning" %}
|
||||
The GPU module is under active development. APIs and semantics
|
||||
might change between Codon releases.
|
||||
{% endhint %}
|
||||
|
||||
# Invoking the kernel
|
||||
|
||||
The kernel can be invoked via a simple call with added `grid` and
|
||||
`block` parameters. These parameters define the grid and block
|
||||
dimensions, respectively. Recall that GPU execution involves a *grid*
|
||||
of (`X` x `Y` x `Z`) *blocks* where each block contains (`x` x `y` x `z`)
|
||||
executing threads. Device-specific restrictions on grid and block sizes
|
||||
apply.
|
||||
|
||||
The `grid` and `block` parameters can be one of:
|
||||
|
||||
- Single integer `x`, giving dimensions `(x, 1, 1)`
|
||||
- Tuple of two integers `(x, y)`, giving dimensions `(x, y, 1)`
|
||||
- Tuple of three integers `(x, y, z)`, giving dimensions `(x, y, z)`
|
||||
- Instance of `gpu.Dim3` as in `Dim3(x, y, z)`, specifying the three dimensions
|
||||
|
||||
# GPU intrinsics
|
||||
|
||||
Codon's GPU module provides many of the same intrinsics that CUDA does:
|
||||
|
||||
| Codon | Description | CUDA equivalent |
|
||||
|-------------------|-----------------------------------------|-----------------|
|
||||
| `gpu.thread.x`   | x-coordinate of current thread in block | `threadIdx.x`   |
|
||||
| `gpu.block.x` | x-coordinate of current block in grid | `blockIdx.x` |
|
||||
| `gpu.block.dim.x` | x-dimension of block | `blockDim.x` |
|
||||
| `gpu.grid.dim.x` | x-dimension of grid | `gridDim.x` |
|
||||
|
||||
The same applies for the `y` and `z` coordinates. The `*.dim` objects are instances
|
||||
of `gpu.Dim3`.
|
||||
|
||||
# Math functions
|
||||
|
||||
All the functions in the `math` module are supported in kernel functions, and
|
||||
are automatically replaced with GPU-optimized versions:
|
||||
|
||||
``` python
|
||||
import math
|
||||
import gpu
|
||||
|
||||
@gpu.kernel
|
||||
def hello(x):
|
||||
i = gpu.thread.x
|
||||
x[i] = math.sqrt(x[i]) # uses __nv_sqrt from libdevice
|
||||
|
||||
x = [float(i) for i in range(10)]
|
||||
hello(x, grid=1, block=10)
|
||||
print(x)
|
||||
```
|
||||
|
||||
gives:
|
||||
|
||||
```
|
||||
[0, 1, 1.41421, 1.73205, 2, 2.23607, 2.44949, 2.64575, 2.82843, 3]
|
||||
```
|
||||
|
||||
# Libdevice
|
||||
|
||||
Codon uses [libdevice](https://docs.nvidia.com/cuda/libdevice-users-guide/index.html)
|
||||
for GPU-optimized math functions. The default libdevice path is
|
||||
`/usr/local/cuda/nvvm/libdevice/libdevice.10.bc`. An alternative path can be specified
|
||||
via the `-libdevice` compiler flag.
|
||||
|
||||
# Working with raw pointers
|
||||
|
||||
By default, objects are converted entirely to their GPU counterparts, which have
|
||||
the same data layout as the original objects (although the Codon compiler might perform
|
||||
optimizations by swapping a CPU implementation of a data type with a GPU-optimized
|
||||
implementation that exposes the same API). This preserves all of Codon/Python's
|
||||
standard semantics within the kernel.
|
||||
|
||||
It is possible to use a kernel with raw pointers via `gpu.raw`, which corresponds
|
||||
to how the kernel would be written in C++/CUDA:
|
||||
|
||||
``` python
|
||||
import gpu
|
||||
|
||||
@gpu.kernel
|
||||
def hello(a, b, c):
|
||||
i = gpu.thread.x
|
||||
c[i] = a[i] + b[i]
|
||||
|
||||
a = [i for i in range(16)]
|
||||
b = [2*i for i in range(16)]
|
||||
c = [0 for _ in range(16)]
|
||||
|
||||
# call the kernel with three int-pointer arguments:
|
||||
hello(gpu.raw(a), gpu.raw(b), gpu.raw(c), grid=1, block=16)
|
||||
print(c) # output same as first snippet's
|
||||
```
|
||||
|
||||
`gpu.raw` can avoid an extra pointer indirection, but outputs a Codon `Ptr` object,
|
||||
meaning the corresponding kernel parameters will not have the full list API, instead
|
||||
having the more limited `Ptr` API (which primarily just supports indexing/assignment).
|
||||
|
||||
# Object conversions
|
||||
|
||||
A hidden API is used to copy objects to and from the GPU device. This API consists of
|
||||
two new *magic methods*:
|
||||
|
||||
- `__to_gpu__(self)`: Allocates the necessary GPU memory and copies the object `self` to
|
||||
the device.
|
||||
|
||||
- `__from_gpu__(self, gpu_object)`: Copies the GPU memory of `gpu_object` (which is
|
||||
a value returned by `__to_gpu__`) back to the CPU object `self`.
|
||||
|
||||
For primitive types like `int` and `float`, `__to_gpu__` simply returns `self` and
|
||||
`__from_gpu__` does nothing. These methods are defined for all the built-in types *and*
|
||||
are automatically generated for user-defined classes, so most objects can be transferred
|
||||
back and forth from the GPU seamlessly. A user-defined class that makes use of raw pointers
|
||||
or other low-level constructs will have to define these methods for GPU use. Please refer
|
||||
to the `gpu` module for implementation examples.
|
||||
|
||||
# `@par(gpu=True)`
|
||||
|
||||
Codon's `@par` syntax can be used to seamlessly parallelize existing loops on the GPU,
|
||||
without needing to explicitly write them as kernels. For loop nests, the `collapse` argument
|
||||
can be used to cover the entire iteration space on the GPU. For example, here is the Mandelbrot
|
||||
code above written using `@par`:
|
||||
|
||||
``` python
|
||||
MAX = 1000 # maximum Mandelbrot iterations
|
||||
N = 4096 # width and height of image
|
||||
pixels = [0 for _ in range(N * N)]
|
||||
|
||||
def scale(x, a, b):
|
||||
return a + (x/N)*(b - a)
|
||||
|
||||
@par(gpu=True, collapse=2)
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
c = complex(scale(j, -2.00, 0.47), scale(i, -1.12, 1.12))
|
||||
z = 0j
|
||||
iteration = 0
|
||||
|
||||
while abs(z) <= 2 and iteration < MAX:
|
||||
z = z**2 + c
|
||||
iteration += 1
|
||||
|
||||
pixels[i*N + j] = int(255 * iteration/MAX)
|
||||
```
|
||||
|
||||
Note that the `gpu=True` option disallows shared variables (i.e. assigning out-of-loop
|
||||
variables in the loop body) as well as reductions. The other GPU-specific restrictions
|
||||
described here apply as well.
|
||||
|
||||
# Troubleshooting
|
||||
|
||||
CUDA errors resulting in kernel abortion are printed, and typically arise from invalid
|
||||
code in the kernel, either via using exceptions or using unsupported modules/objects.
|
|
@ -29,6 +29,8 @@ for i in range(10):
|
|||
- `chunk_size` (int): chunk size when partitioning loop iterations
|
||||
- `ordered` (bool): whether the loop iterations should be executed in
|
||||
the same order
|
||||
- `collapse` (int): number of loop nests to collapse into a single
|
||||
iteration space
|
||||
|
||||
Other OpenMP parameters like `private`, `shared` or `reduction`, are
|
||||
inferred automatically by the compiler. For example, the following loop
|
||||
|
|
|
@ -30,6 +30,15 @@ compile to native code without any runtime performance overhead.
|
|||
Future versions of Codon will lift some of these restrictions
|
||||
by the introduction of e.g. implicit union types.
|
||||
|
||||
# Numerics
|
||||
|
||||
For performance reasons, some numeric operations use C semantics
|
||||
rather than Python semantics. This includes, for example, raising
|
||||
an exception when dividing by zero, or other checks done by `math`
|
||||
functions. Strict adherence to Python semantics can be achieved by
|
||||
using the `-numerics=py` flag of the Codon compiler. Note that this
|
||||
does *not* change the `int` type, which remains 64-bit.
|
||||
|
||||
# Modules
|
||||
|
||||
While most of the commonly used builtin modules have Codon-native
|
||||
|
|
|
@ -2,6 +2,50 @@ Below you can find release notes for each major Codon release,
|
|||
listing improvements, updates, optimizations and more for each
|
||||
new version.
|
||||
|
||||
# v0.14
|
||||
|
||||
## GPU support
|
||||
|
||||
GPU kernels can now be written and called in Codon. Existing
|
||||
loops can be parallelized on the GPU with the `@par(gpu=True)`
|
||||
annotation. Please see the [docs](../advanced/gpu.md) for
|
||||
more information and examples.
|
||||
|
||||
## Semantics
|
||||
|
||||
Added `-numerics` flag, which specifies semantics of various
|
||||
numeric operations:
|
||||
|
||||
- `-numerics=c` (default): C semantics; best performance
|
||||
- `-numerics=py`: Python semantics (checks for zero divisors
|
||||
and raises `ZeroDivisionError`, and adds domain checks to `math`
|
||||
functions); might slightly decrease performance.
|
||||
|
||||
## Types
|
||||
|
||||
Added `float32` type to represent 32-bit floats (equivalent to C's
|
||||
`float`). All `math` functions now have `float32` overloads.
|
||||
|
||||
## Parallelism
|
||||
|
||||
Added `collapse` option to `@par`:
|
||||
|
||||
``` python
|
||||
@par(collapse=2) # parallelize entire iteration space of 2 loops
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
do_work(i, j)
|
||||
```
|
||||
|
||||
## Standard library
|
||||
|
||||
Added `collections.defaultdict`.
|
||||
|
||||
## Python interoperability
|
||||
|
||||
Various Python interoperability improvements: can now use `isinstance`
|
||||
on Python objects/types and can now catch Python exceptions by name.
|
||||
|
||||
# v0.13
|
||||
|
||||
## Language
|
||||
|
|
|
@ -19,6 +19,12 @@ variants:
|
|||
- `i32`/`u32`: signed/unsigned 32-bit integer
|
||||
- `i64`/`u64`: signed/unsigned 64-bit integer
|
||||
|
||||
# 32-bit float
|
||||
|
||||
Codon's `float` type is a 64-bit floating point value. Codon
|
||||
also supports `float32` (or `f32` as a shorthand), representing
|
||||
a 32-bit floating point value (like C's `float`).
|
||||
|
||||
# Pointers
|
||||
|
||||
Codon has a `Ptr[T]` type that represents a pointer to an object
|
||||
|
|
|
@ -69,6 +69,6 @@ def heap_sort(collection: List[T], keyf: Callable[[T], S], T: type, S: type) ->
|
|||
Heap Sort
|
||||
Returns a sorted list.
|
||||
"""
|
||||
newlst = copy(collection)
|
||||
newlst = collection.__copy__()
|
||||
heap_sort_inplace(newlst, keyf)
|
||||
return newlst
|
||||
|
|
|
@ -42,6 +42,6 @@ def insertion_sort(
|
|||
Insertion Sort
|
||||
Returns the sorted list.
|
||||
"""
|
||||
newlst = copy(collection)
|
||||
newlst = collection.__copy__()
|
||||
insertion_sort_inplace(newlst, keyf)
|
||||
return newlst
|
||||
|
|
|
@ -321,6 +321,6 @@ def pdq_sort(collection: List[T], keyf: Callable[[T], S], T: type, S: type) -> L
|
|||
|
||||
Returns a sorted list.
|
||||
"""
|
||||
newlst = copy(collection)
|
||||
newlst = collection.__copy__()
|
||||
pdq_sort_inplace(newlst, keyf)
|
||||
return newlst
|
||||
|
|
|
@ -114,13 +114,13 @@ class deque:
|
|||
return deque(i.__deepcopy__() for i in self)
|
||||
|
||||
def __copy__(self) -> deque[T]:
|
||||
return deque[T](copy(self._arr), self._head, self._tail, self._maxlen)
|
||||
return deque[T](self._arr.__copy__(), self._head, self._tail, self._maxlen)
|
||||
|
||||
def copy(self) -> deque[T]:
|
||||
return self.__copy__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(List[T](iter(self)))
|
||||
return f"deque({repr(List[T](iter(self)))})"
|
||||
|
||||
def _idx_check(self, idx: int, msg: str):
|
||||
if self._head == self._tail or idx >= len(self) or idx < 0:
|
||||
|
@ -323,25 +323,28 @@ class Counter(Dict[T, int]):
|
|||
return result
|
||||
|
||||
def __add__(self, other: Counter[T]) -> Counter[T]:
|
||||
result = copy(self)
|
||||
result = self.__copy__()
|
||||
result += other
|
||||
return result
|
||||
|
||||
def __sub__(self, other: Counter[T]) -> Counter[T]:
|
||||
result = copy(self)
|
||||
result = self.__copy__()
|
||||
result -= other
|
||||
return result
|
||||
|
||||
def __and__(self, other: Counter[T]) -> Counter[T]:
|
||||
result = copy(self)
|
||||
result = self.__copy__()
|
||||
result &= other
|
||||
return result
|
||||
|
||||
def __or__(self, other: Counter[T]) -> Counter[T]:
|
||||
result = copy(self)
|
||||
result = self.__copy__()
|
||||
result |= other
|
||||
return result
|
||||
|
||||
def __repr__(self):
|
||||
return f"Counter({super().__repr__()})"
|
||||
|
||||
def __dict_do_op_throws__(self, key: T, other: Z, op: F, F: type, Z: type):
|
||||
self.__dict_do_op__(key, other, 0, op)
|
||||
|
||||
|
@ -357,5 +360,79 @@ class Dict:
|
|||
self._init_from(other)
|
||||
|
||||
|
||||
class defaultdict(Dict[K,V]):
|
||||
default_factory: S
|
||||
K: type
|
||||
V: type
|
||||
S: TypeVar[Callable[[], V]]
|
||||
|
||||
def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V]):
|
||||
super().__init__()
|
||||
self.default_factory = lambda: VV()
|
||||
|
||||
def __init__(self, f: S):
|
||||
super().__init__()
|
||||
self.default_factory = f
|
||||
|
||||
def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V], other: Dict[K, V]):
|
||||
super().__init__(other)
|
||||
self.default_factory = lambda: VV()
|
||||
|
||||
def __init__(self, f: S, other: Dict[K, V]):
|
||||
super().__init__(other)
|
||||
self.default_factory = f
|
||||
|
||||
def __missing__(self, key: K):
|
||||
default_value = self.default_factory()
|
||||
self.__setitem__(key, default_value)
|
||||
return default_value
|
||||
|
||||
def __getitem__(self, key: K) -> V:
|
||||
if key not in self:
|
||||
return self.__missing__(key)
|
||||
return super().__getitem__(key)
|
||||
|
||||
def __dict_do_op_throws__(self, key: K, other: Z, op: F, F: type, Z: type):
|
||||
x = self._kh_get(key)
|
||||
if x == self._kh_end():
|
||||
self.__missing__(key)
|
||||
x = self._kh_get(key)
|
||||
self._vals[x] = op(self._vals[x], other)
|
||||
|
||||
def copy(self):
|
||||
d = defaultdict[K,V,S](self.default_factory)
|
||||
d._init_from(self)
|
||||
return d
|
||||
|
||||
def __copy__(self):
|
||||
return self.copy()
|
||||
|
||||
def __deepcopy__(self):
|
||||
d = defaultdict[K,V,S](self.default_factory)
|
||||
for k,v in self.items():
|
||||
d[k.__deepcopy__()] = v.__deepcopy__()
|
||||
return d
|
||||
|
||||
def __eq__(self, other: defaultdict[K,V,S]) -> bool:
|
||||
if self.__len__() != other.__len__():
|
||||
return False
|
||||
for k, v in self.items():
|
||||
if k not in other or other[k] != v:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __ne__(self, other: defaultdict[K,V,S]) -> bool:
|
||||
return not (self == other)
|
||||
|
||||
def __repr__(self):
|
||||
return f"defaultdict(<default factory of '{V.__name__}'>, {super().__repr__()})"
|
||||
|
||||
|
||||
@extend
|
||||
class Dict:
|
||||
def __init__(self: Dict[K, V], other: defaultdict[K, V, S], S: type):
|
||||
self._init_from(other)
|
||||
|
||||
|
||||
def namedtuple(name: Static[str], args): # internal
|
||||
pass
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
# (c) 2022 Exaloop Inc. All rights reserved.
|
||||
|
||||
|
||||
class Error(Exception):
|
||||
def __init__(self, message: str = ""):
|
||||
super().__init__("copy.Error", message)
|
||||
|
||||
|
||||
def copy(x):
|
||||
return x.__copy__()
|
||||
|
||||
|
||||
def deepcopy(x):
|
||||
return x.__deepcopy__()
|
|
@ -0,0 +1,732 @@
|
|||
# (c) 2022 Exaloop Inc. All rights reserved.
|
||||
|
||||
from internal.gc import sizeof as _sizeof
|
||||
|
||||
@tuple
|
||||
class Device:
|
||||
_device: i32
|
||||
|
||||
def __new__(device: int):
|
||||
from C import seq_nvptx_device(int) -> i32
|
||||
return Device(seq_nvptx_device(device))
|
||||
|
||||
@staticmethod
|
||||
def count():
|
||||
from C import seq_nvptx_device_count() -> int
|
||||
return seq_nvptx_device_count()
|
||||
|
||||
def __str__(self):
|
||||
from C import seq_nvptx_device_name(i32) -> str
|
||||
return seq_nvptx_device_name(self._device)
|
||||
|
||||
def __index__(self):
|
||||
return int(self._device)
|
||||
|
||||
def __bool__(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def compute_capability(self):
|
||||
from C import seq_nvptx_device_capability(i32) -> int
|
||||
c = seq_nvptx_device_capability(self._device)
|
||||
return (c >> 32, c & 0xffffffff)
|
||||
|
||||
@tuple
|
||||
class Memory[T]:
|
||||
_ptr: Ptr[byte]
|
||||
|
||||
def _alloc(n: int, T: type):
|
||||
from C import seq_nvptx_device_alloc(int) -> Ptr[byte]
|
||||
return Memory[T](seq_nvptx_device_alloc(n * _sizeof(T)))
|
||||
|
||||
def _read(self, p: Ptr[T], n: int):
|
||||
from C import seq_nvptx_memcpy_d2h(Ptr[byte], Ptr[byte], int)
|
||||
seq_nvptx_memcpy_d2h(p.as_byte(), self._ptr, n * _sizeof(T))
|
||||
|
||||
def _write(self, p: Ptr[T], n: int):
|
||||
from C import seq_nvptx_memcpy_h2d(Ptr[byte], Ptr[byte], int)
|
||||
seq_nvptx_memcpy_h2d(self._ptr, p.as_byte(), n * _sizeof(T))
|
||||
|
||||
def _free(self):
|
||||
from C import seq_nvptx_device_free(Ptr[byte])
|
||||
seq_nvptx_device_free(self._ptr)
|
||||
|
||||
@llvm
|
||||
def syncthreads() -> None:
|
||||
declare void @llvm.nvvm.barrier0()
|
||||
call void @llvm.nvvm.barrier0()
|
||||
ret {} {}
|
||||
|
||||
@tuple
|
||||
class Dim3:
|
||||
_x: u32
|
||||
_y: u32
|
||||
_z: u32
|
||||
|
||||
def __new__(x: int, y: int, z: int):
|
||||
return Dim3(u32(x), u32(y), u32(z))
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
return int(self._x)
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
return int(self._y)
|
||||
|
||||
@property
|
||||
def z(self):
|
||||
return int(self._z)
|
||||
|
||||
@tuple
|
||||
class Thread:
|
||||
@property
|
||||
def x(self):
|
||||
@pure
|
||||
@llvm
|
||||
def get_x() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
||||
ret i32 %res
|
||||
|
||||
return int(get_x())
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
@pure
|
||||
@llvm
|
||||
def get_y() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.tid.y()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
|
||||
ret i32 %res
|
||||
|
||||
return int(get_y())
|
||||
|
||||
@property
|
||||
def z(self):
|
||||
@pure
|
||||
@llvm
|
||||
def get_z() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.tid.z()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
|
||||
ret i32 %res
|
||||
|
||||
return int(get_z())
|
||||
|
||||
@tuple
|
||||
class Block:
|
||||
@property
|
||||
def x(self):
|
||||
@pure
|
||||
@llvm
|
||||
def get_x() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
|
||||
ret i32 %res
|
||||
|
||||
return int(get_x())
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
@pure
|
||||
@llvm
|
||||
def get_y() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
|
||||
ret i32 %res
|
||||
|
||||
return int(get_y())
|
||||
|
||||
@property
|
||||
def z(self):
|
||||
@pure
|
||||
@llvm
|
||||
def get_z() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
|
||||
ret i32 %res
|
||||
|
||||
return int(get_z())
|
||||
|
||||
@property
|
||||
def dim(self):
|
||||
@pure
|
||||
@llvm
|
||||
def get_x() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
|
||||
ret i32 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def get_y() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
|
||||
ret i32 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def get_z() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
|
||||
ret i32 %res
|
||||
|
||||
return Dim3(get_x(), get_y(), get_z())
|
||||
|
||||
@tuple
|
||||
class Grid:
|
||||
@property
|
||||
def dim(self):
|
||||
@pure
|
||||
@llvm
|
||||
def get_x() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
|
||||
ret i32 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def get_y() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
|
||||
ret i32 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def get_z() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
|
||||
ret i32 %res
|
||||
|
||||
return Dim3(get_x(), get_y(), get_z())
|
||||
|
||||
@tuple
|
||||
class Warp:
|
||||
def __len__(self):
|
||||
@pure
|
||||
@llvm
|
||||
def get_warpsize() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.warpsize()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.warpsize()
|
||||
ret i32 %res
|
||||
|
||||
return int(get_warpsize())
|
||||
|
||||
thread = Thread()
|
||||
block = Block()
|
||||
grid = Grid()
|
||||
warp = Warp()
|
||||
|
||||
def _catch():
|
||||
return (thread, block, grid, warp)
|
||||
|
||||
_catch()
|
||||
|
||||
@tuple
|
||||
class AllocCache:
|
||||
v: List[Ptr[byte]]
|
||||
|
||||
def add(self, p: Ptr[byte]):
|
||||
self.v.append(p)
|
||||
|
||||
def free(self):
|
||||
for p in self.v:
|
||||
Memory[byte](p)._free()
|
||||
|
||||
def _tuple_from_gpu(args, gpu_args):
|
||||
if staticlen(args) > 0:
|
||||
a = args[0]
|
||||
g = gpu_args[0]
|
||||
a.__from_gpu__(g)
|
||||
_tuple_from_gpu(args[1:], gpu_args[1:])
|
||||
|
||||
def kernel(fn):
|
||||
from C import seq_nvptx_function(str) -> cobj
|
||||
from C import seq_nvptx_invoke(cobj, u32, u32, u32, u32, u32, u32, u32, cobj)
|
||||
|
||||
def canonical_dim(dim):
|
||||
if isinstance(dim, NoneType):
|
||||
return (1, 1, 1)
|
||||
elif isinstance(dim, int):
|
||||
return (dim, 1, 1)
|
||||
elif isinstance(dim, Tuple[int,int]):
|
||||
return (dim[0], dim[1], 1)
|
||||
elif isinstance(dim, Tuple[int,int,int]):
|
||||
return dim
|
||||
elif isinstance(dim, Dim3):
|
||||
return (dim.x, dim.y, dim.z)
|
||||
else:
|
||||
compile_error("bad dimension argument")
|
||||
|
||||
def offsets(t):
|
||||
@pure
|
||||
@llvm
|
||||
def offsetof(t: T, i: Static[int], T: type, S: type) -> int:
|
||||
%p = getelementptr {=T}, {=T}* null, i64 0, i32 {=i}
|
||||
%s = ptrtoint {=S}* %p to i64
|
||||
ret i64 %s
|
||||
|
||||
if staticlen(t) == 0:
|
||||
return ()
|
||||
else:
|
||||
T = type(t)
|
||||
S = type(t[-1])
|
||||
return (*offsets(t[:-1]), offsetof(t, staticlen(t) - 1, T, S))
|
||||
|
||||
def wrapper(*args, grid, block):
|
||||
grid = canonical_dim(grid)
|
||||
block = canonical_dim(block)
|
||||
cache = AllocCache([])
|
||||
shared_mem = 0
|
||||
gpu_args = tuple(arg.__to_gpu__(cache) for arg in args)
|
||||
kernel_ptr = seq_nvptx_function(__realized__(fn, gpu_args).__llvm_name__)
|
||||
p = __ptr__(gpu_args).as_byte()
|
||||
arg_ptrs = tuple((p + offset) for offset in offsets(gpu_args))
|
||||
seq_nvptx_invoke(kernel_ptr, u32(grid[0]), u32(grid[1]), u32(grid[2]), u32(block[0]),
|
||||
u32(block[1]), u32(block[2]), u32(shared_mem), __ptr__(arg_ptrs).as_byte())
|
||||
_tuple_from_gpu(args, gpu_args)
|
||||
cache.free()
|
||||
|
||||
return wrapper
|
||||
|
||||
def _ptr_to_gpu(p: Ptr[T], n: int, cache: AllocCache, index_filter = lambda i: True, T: type):
|
||||
from internal.gc import atomic
|
||||
|
||||
if not atomic(T):
|
||||
tmp = Ptr[T](n)
|
||||
for i in range(n):
|
||||
if index_filter(i):
|
||||
tmp[i] = p[i].__to_gpu__(cache)
|
||||
p = tmp
|
||||
|
||||
mem = Memory._alloc(n, T)
|
||||
cache.add(mem._ptr)
|
||||
mem._write(p, n)
|
||||
return Ptr[T](mem._ptr)
|
||||
|
||||
def _ptr_from_gpu(p: Ptr[T], q: Ptr[T], n: int, index_filter = lambda i: True, T: type):
|
||||
from internal.gc import atomic
|
||||
|
||||
mem = Memory[T](q.as_byte())
|
||||
if not atomic(T):
|
||||
tmp = Ptr[T](n)
|
||||
mem._read(tmp, n)
|
||||
for i in range(n):
|
||||
if index_filter(i):
|
||||
p[i] = T.__from_gpu_new__(tmp[i])
|
||||
else:
|
||||
mem._read(p, n)
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _ptr_to_type(p: cobj, T: type) -> T:
|
||||
ret i8* %p
|
||||
|
||||
def _object_to_gpu(obj: T, cache: AllocCache, T: type):
|
||||
s = tuple(obj)
|
||||
gpu_mem = Memory._alloc(1, type(s))
|
||||
cache.add(gpu_mem._ptr)
|
||||
gpu_mem._write(__ptr__(s), 1)
|
||||
return _ptr_to_type(gpu_mem._ptr, T)
|
||||
|
||||
def _object_from_gpu(obj):
|
||||
T = type(obj)
|
||||
S = type(tuple(obj))
|
||||
|
||||
tmp = T.__new__()
|
||||
p = Ptr[S](tmp.__raw__())
|
||||
q = Ptr[S](obj.__raw__())
|
||||
|
||||
mem = Memory[S](q.as_byte())
|
||||
mem._read(p, 1)
|
||||
return tmp
|
||||
|
||||
@tuple
|
||||
class Pointer[T]:
|
||||
_ptr: Ptr[T]
|
||||
_len: int
|
||||
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
return _ptr_to_gpu(self._ptr, self._len, cache)
|
||||
|
||||
def __from_gpu__(self, other: Ptr[T]):
|
||||
_ptr_from_gpu(self._ptr, other, self._len)
|
||||
|
||||
def __from_gpu_new__(other: Ptr[T]):
|
||||
return other
|
||||
|
||||
def raw(v: List[T], T: type):
|
||||
return Pointer(v.arr.ptr, len(v))
|
||||
|
||||
@extend
|
||||
class Ptr:
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
return self
|
||||
|
||||
def __from_gpu__(self, other: Ptr[T]):
|
||||
pass
|
||||
|
||||
def __from_gpu_new__(other: Ptr[T]):
|
||||
return other
|
||||
|
||||
@extend
|
||||
class NoneType:
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
return self
|
||||
|
||||
def __from_gpu__(self, other: NoneType):
|
||||
pass
|
||||
|
||||
def __from_gpu_new__(other: NoneType):
|
||||
return other
|
||||
|
||||
@extend
|
||||
class int:
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
return self
|
||||
|
||||
def __from_gpu__(self, other: int):
|
||||
pass
|
||||
|
||||
def __from_gpu_new__(other: int):
|
||||
return other
|
||||
|
||||
@extend
|
||||
class float:
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
return self
|
||||
|
||||
def __from_gpu__(self, other: float):
|
||||
pass
|
||||
|
||||
def __from_gpu_new__(other: float):
|
||||
return other
|
||||
|
||||
@extend
|
||||
class float32:
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
return self
|
||||
|
||||
def __from_gpu__(self, other: float32):
|
||||
pass
|
||||
|
||||
def __from_gpu_new__(other: float32):
|
||||
return other
|
||||
|
||||
@extend
|
||||
class bool:
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
return self
|
||||
|
||||
def __from_gpu__(self, other: bool):
|
||||
pass
|
||||
|
||||
def __from_gpu_new__(other: bool):
|
||||
return other
|
||||
|
||||
@extend
|
||||
class byte:
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
return self
|
||||
|
||||
def __from_gpu__(self, other: byte):
|
||||
pass
|
||||
|
||||
def __from_gpu_new__(other: byte):
|
||||
return other
|
||||
|
||||
@extend
|
||||
class Int:
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
return self
|
||||
|
||||
def __from_gpu__(self, other: Int[N]):
|
||||
pass
|
||||
|
||||
def __from_gpu_new__(other: Int[N]):
|
||||
return other
|
||||
|
||||
@extend
|
||||
class UInt:
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
return self
|
||||
|
||||
def __from_gpu__(self, other: UInt[N]):
|
||||
pass
|
||||
|
||||
def __from_gpu_new__(other: UInt[N]):
|
||||
return other
|
||||
|
||||
@extend
|
||||
class str:
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
n = self.len
|
||||
return str(_ptr_to_gpu(self.ptr, n, cache), n)
|
||||
|
||||
def __from_gpu__(self, other: str):
|
||||
pass
|
||||
|
||||
def __from_gpu_new__(other: str):
|
||||
n = other.len
|
||||
p = Ptr[byte](n)
|
||||
_ptr_from_gpu(p, other.ptr, n)
|
||||
return str(p, n)
|
||||
|
||||
@extend
|
||||
class List:
|
||||
@inline
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
mem = List[T].__new__()
|
||||
n = self.len
|
||||
gpu_ptr = _ptr_to_gpu(self.arr.ptr, n, cache)
|
||||
mem.arr = Array[T](gpu_ptr, n)
|
||||
mem.len = n
|
||||
return _object_to_gpu(mem, cache)
|
||||
|
||||
@inline
|
||||
def __from_gpu__(self, other: List[T]):
|
||||
mem = _object_from_gpu(other)
|
||||
my_cap = self.arr.len
|
||||
other_cap = mem.arr.len
|
||||
|
||||
if other_cap > my_cap:
|
||||
self._resize(other_cap)
|
||||
|
||||
_ptr_from_gpu(self.arr.ptr, mem.arr.ptr, mem.len)
|
||||
self.len = mem.len
|
||||
|
||||
@inline
|
||||
def __from_gpu_new__(other: List[T]):
|
||||
mem = _object_from_gpu(other)
|
||||
arr = Array[T](mem.arr.len)
|
||||
_ptr_from_gpu(arr.ptr, mem.arr.ptr, arr.len)
|
||||
mem.arr = arr
|
||||
return mem
|
||||
|
||||
@extend
|
||||
class Dict:
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
from internal.khash import __ac_fsize
|
||||
mem = Dict[K,V].__new__()
|
||||
n = self._n_buckets
|
||||
f = __ac_fsize(n) if n else 0
|
||||
|
||||
mem._n_buckets = n
|
||||
mem._size = self._size
|
||||
mem._n_occupied = self._n_occupied
|
||||
mem._upper_bound = self._upper_bound
|
||||
mem._flags = _ptr_to_gpu(self._flags, f, cache)
|
||||
mem._keys = _ptr_to_gpu(self._keys, n, cache, lambda i: self._kh_exist(i))
|
||||
mem._vals = _ptr_to_gpu(self._vals, n, cache, lambda i: self._kh_exist(i))
|
||||
|
||||
return _object_to_gpu(mem, cache)
|
||||
|
||||
def __from_gpu__(self, other: Dict[K,V]):
|
||||
from internal.khash import __ac_fsize
|
||||
mem = _object_from_gpu(other)
|
||||
my_n = self._n_buckets
|
||||
n = mem._n_buckets
|
||||
f = __ac_fsize(n) if n else 0
|
||||
|
||||
if my_n != n:
|
||||
self._flags = Ptr[u32](f)
|
||||
self._keys = Ptr[K](n)
|
||||
self._vals = Ptr[V](n)
|
||||
|
||||
_ptr_from_gpu(self._flags, mem._flags, f)
|
||||
_ptr_from_gpu(self._keys, mem._keys, n, lambda i: self._kh_exist(i))
|
||||
_ptr_from_gpu(self._vals, mem._vals, n, lambda i: self._kh_exist(i))
|
||||
|
||||
self._n_buckets = n
|
||||
self._size = mem._size
|
||||
self._n_occupied = mem._n_occupied
|
||||
self._upper_bound = mem._upper_bound
|
||||
|
||||
def __from_gpu_new__(other: Dict[K,V]):
|
||||
from internal.khash import __ac_fsize
|
||||
mem = _object_from_gpu(other)
|
||||
|
||||
n = mem._n_buckets
|
||||
f = __ac_fsize(n) if n else 0
|
||||
flags = Ptr[u32](f)
|
||||
keys = Ptr[K](n)
|
||||
vals = Ptr[V](n)
|
||||
|
||||
_ptr_from_gpu(flags, mem._flags, f)
|
||||
mem._flags = flags
|
||||
_ptr_from_gpu(keys, mem._keys, n, lambda i: mem._kh_exist(i))
|
||||
mem._keys = keys
|
||||
_ptr_from_gpu(vals, mem._vals, n, lambda i: mem._kh_exist(i))
|
||||
mem._vals = vals
|
||||
return mem
|
||||
|
||||
@extend
|
||||
class Set:
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
from internal.khash import __ac_fsize
|
||||
mem = Set[K].__new__()
|
||||
n = self._n_buckets
|
||||
f = __ac_fsize(n) if n else 0
|
||||
|
||||
mem._n_buckets = n
|
||||
mem._size = self._size
|
||||
mem._n_occupied = self._n_occupied
|
||||
mem._upper_bound = self._upper_bound
|
||||
mem._flags = _ptr_to_gpu(self._flags, f, cache)
|
||||
mem._keys = _ptr_to_gpu(self._keys, n, cache, lambda i: self._kh_exist(i))
|
||||
|
||||
return _object_to_gpu(mem, cache)
|
||||
|
||||
def __from_gpu__(self, other: Set[K]):
|
||||
from internal.khash import __ac_fsize
|
||||
mem = _object_from_gpu(other)
|
||||
|
||||
my_n = self._n_buckets
|
||||
n = mem._n_buckets
|
||||
f = __ac_fsize(n) if n else 0
|
||||
|
||||
if my_n != n:
|
||||
self._flags = Ptr[u32](f)
|
||||
self._keys = Ptr[K](n)
|
||||
|
||||
_ptr_from_gpu(self._flags, mem._flags, f)
|
||||
_ptr_from_gpu(self._keys, mem._keys, n, lambda i: self._kh_exist(i))
|
||||
|
||||
self._n_buckets = n
|
||||
self._size = mem._size
|
||||
self._n_occupied = mem._n_occupied
|
||||
self._upper_bound = mem._upper_bound
|
||||
|
||||
def __from_gpu_new__(other: Set[K]):
|
||||
from internal.khash import __ac_fsize
|
||||
mem = _object_from_gpu(other)
|
||||
|
||||
n = mem._n_buckets
|
||||
f = __ac_fsize(n) if n else 0
|
||||
flags = Ptr[u32](f)
|
||||
keys = Ptr[K](n)
|
||||
|
||||
_ptr_from_gpu(flags, mem._flags, f)
|
||||
mem._flags = flags
|
||||
_ptr_from_gpu(keys, mem._keys, n, lambda i: mem._kh_exist(i))
|
||||
mem._keys = keys
|
||||
return mem
|
||||
|
||||
@extend
|
||||
class Optional:
|
||||
def __to_gpu__(self, cache: AllocCache):
|
||||
if self is None:
|
||||
return self
|
||||
else:
|
||||
return Optional[T](self.__val__().__to_gpu__(cache))
|
||||
|
||||
def __from_gpu__(self, other: Optional[T]):
|
||||
if self is not None and other is not None:
|
||||
self.__val__().__from_gpu__(other.__val__())
|
||||
|
||||
def __from_gpu_new__(other: Optional[T]):
|
||||
if other is None:
|
||||
return Optional[T]()
|
||||
else:
|
||||
return Optional[T](T.__from_gpu_new__(other.__val__()))
|
||||
|
||||
@extend
|
||||
class __internal__:
|
||||
def class_to_gpu(obj, cache: AllocCache):
|
||||
if isinstance(obj, Tuple):
|
||||
return tuple(a.__to_gpu__(cache) for a in obj)
|
||||
elif isinstance(obj, ByVal):
|
||||
T = type(obj)
|
||||
return T(*tuple(a.__to_gpu__(cache) for a in tuple(obj)))
|
||||
else:
|
||||
T = type(obj)
|
||||
S = type(tuple(obj))
|
||||
mem = T.__new__()
|
||||
Ptr[S](mem.__raw__())[0] = tuple(obj).__to_gpu__(cache)
|
||||
return _object_to_gpu(mem, cache)
|
||||
|
||||
def class_from_gpu(obj, other):
|
||||
if isinstance(obj, Tuple):
|
||||
_tuple_from_gpu(obj, other)
|
||||
elif isinstance(obj, ByVal):
|
||||
_tuple_from_gpu(tuple(obj), tuple(other))
|
||||
else:
|
||||
S = type(tuple(obj))
|
||||
Ptr[S](obj.__raw__())[0] = S.__from_gpu_new__(tuple(_object_from_gpu(other)))
|
||||
|
||||
def class_from_gpu_new(other):
|
||||
if isinstance(other, Tuple):
|
||||
return tuple(type(a).__from_gpu_new__(a) for a in other)
|
||||
elif isinstance(other, ByVal):
|
||||
T = type(other)
|
||||
return T(*tuple(type(a).__from_gpu_new__(a) for a in tuple(other)))
|
||||
else:
|
||||
S = type(tuple(other))
|
||||
mem = _object_from_gpu(other)
|
||||
Ptr[S](mem.__raw__())[0] = S.__from_gpu_new__(tuple(mem))
|
||||
return mem
|
||||
|
||||
# @par(gpu=True) support
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _gpu_thread_x() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
||||
ret i32 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _gpu_block_x() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
|
||||
ret i32 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _gpu_block_dim_x() -> u32:
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
|
||||
%res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
|
||||
ret i32 %res
|
||||
|
||||
def _gpu_loop_outline_template(start, stop, args, instance: Static[int]):
|
||||
@nonpure
|
||||
def _loop_step():
|
||||
return 1
|
||||
|
||||
@kernel
|
||||
def _kernel_stub(start: int, count: int, args):
|
||||
@nonpure
|
||||
def _gpu_loop_body_stub(idx, args):
|
||||
pass
|
||||
|
||||
@nonpure
|
||||
def _dummy_use(n):
|
||||
pass
|
||||
|
||||
_dummy_use(instance)
|
||||
idx = (int(_gpu_block_dim_x()) * int(_gpu_block_x())) + int(_gpu_thread_x())
|
||||
step = _loop_step()
|
||||
if idx < count:
|
||||
_gpu_loop_body_stub(start + (idx * step), args)
|
||||
|
||||
step = _loop_step()
|
||||
loop = range(start, stop, step)
|
||||
|
||||
MAX_BLOCK = 1024
|
||||
MAX_GRID = 2147483647
|
||||
G = MAX_BLOCK * MAX_GRID
|
||||
n = len(loop)
|
||||
|
||||
if n == 0:
|
||||
return
|
||||
elif n > G:
|
||||
raise ValueError(f'loop exceeds GPU iteration limit of {G}')
|
||||
|
||||
block = n
|
||||
grid = 1
|
||||
if n > MAX_BLOCK:
|
||||
block = MAX_BLOCK
|
||||
grid = (n // MAX_BLOCK) + (0 if n % MAX_BLOCK == 0 else 1)
|
||||
|
||||
_kernel_stub(start, n, args, grid=grid, block=block)
|
|
@ -1,6 +1,7 @@
|
|||
# (c) 2022 Exaloop Inc. All rights reserved.
|
||||
|
||||
# Core library
|
||||
|
||||
from internal.core import *
|
||||
from internal.attributes import *
|
||||
from internal.types.ptr import *
|
||||
from internal.types.str import *
|
||||
|
@ -34,7 +35,11 @@ from internal.str import *
|
|||
from internal.sort import sorted
|
||||
|
||||
from openmp import Ident as __OMPIdent, for_par
|
||||
from gpu import _gpu_loop_outline_template
|
||||
from internal.file import File, gzFile, open, gzopen
|
||||
from pickle import pickle, unpickle
|
||||
from internal.dlopen import dlsym as _dlsym
|
||||
import internal.python
|
||||
|
||||
if __py_numerics__:
|
||||
import internal.pynumerics
|
||||
|
|
|
@ -108,13 +108,6 @@ def iter(x):
|
|||
return x.__iter__()
|
||||
|
||||
|
||||
def copy(x):
|
||||
"""
|
||||
Return a copy of x
|
||||
"""
|
||||
return x.__copy__()
|
||||
|
||||
|
||||
def abs(x):
|
||||
"""
|
||||
Return the absolute value of x
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# (c) 2022 Exaloop Inc. All rights reserved.
|
||||
|
||||
# Seq runtime functions
|
||||
# runtime functions
|
||||
from C import seq_print(str)
|
||||
from C import seq_print_full(str, cobj)
|
||||
|
||||
|
@ -216,7 +216,7 @@ def expm1(a: float) -> float:
|
|||
|
||||
@pure
|
||||
@C
|
||||
def ldexp(a: float, b: int) -> float:
|
||||
def ldexp(a: float, b: i32) -> float:
|
||||
pass
|
||||
|
||||
|
||||
|
@ -406,6 +406,234 @@ def modf(a: float, b: Ptr[float]) -> float:
|
|||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def ceilf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def floorf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def fabsf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def fmodf(a: float32, b: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def expf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def expm1f(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def ldexpf(a: float32, b: i32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def logf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def log2f(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def log10f(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def sqrtf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def powf(a: float32, b: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def roundf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def acosf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def asinf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def atanf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def atan2f(a: float32, b: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def cosf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def sinf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def tanf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def coshf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def sinhf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def tanhf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def acoshf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def asinhf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def atanhf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def copysignf(a: float32, b: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def log1pf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def truncf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def log2f(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def erff(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def erfcf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def tgammaf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def lgammaf(a: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def remainderf(a: float32, b: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@pure
|
||||
@C
|
||||
def hypotf(a: float32, b: float32) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@nocapture
|
||||
@C
|
||||
def frexpf(a: float32, b: Ptr[Int[32]]) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
@nocapture
|
||||
@C
|
||||
def modff(a: float32, b: Ptr[float32]) -> float32:
|
||||
pass
|
||||
|
||||
|
||||
# <stdio.h>
|
||||
@pure
|
||||
@C
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
# (c) 2022 Exaloop Inc. All rights reserved.
|
||||
|
||||
@__internal__
|
||||
class __internal__:
|
||||
pass
|
||||
|
@ -15,13 +17,23 @@ class byte:
|
|||
@tuple
|
||||
@__internal__
|
||||
class int:
|
||||
MAX = 9223372036854775807
|
||||
pass
|
||||
|
||||
@tuple
|
||||
@__internal__
|
||||
class float:
|
||||
MIN_10_EXP = -307
|
||||
pass
|
||||
|
||||
|
||||
@tuple
|
||||
@__internal__
|
||||
class float32:
|
||||
MIN_10_EXP = -37
|
||||
pass
|
||||
|
||||
|
||||
@tuple
|
||||
@__internal__
|
||||
class NoneType:
|
||||
|
@ -93,7 +105,7 @@ function = Function
|
|||
|
||||
# dummy
|
||||
@__internal__
|
||||
class TypeVar: pass
|
||||
class TypeVar[T]: pass
|
||||
@__internal__
|
||||
class ByVal: pass
|
||||
@__internal__
|
||||
|
@ -154,3 +166,13 @@ def super():
|
|||
def superf(*args):
|
||||
"""Special handling"""
|
||||
pass
|
||||
|
||||
def __realized__(fn, args):
|
||||
pass
|
||||
|
||||
def statictuple(*args):
|
||||
return args
|
||||
|
||||
def __static_print__(*args):
|
||||
pass
|
||||
|
||||
|
|
|
@ -145,8 +145,9 @@ class gzFile:
|
|||
if self.buf[offset - 1] == byte(10): # '\n'
|
||||
break
|
||||
|
||||
oldsz = self.sz
|
||||
self.sz *= 2
|
||||
self.buf = realloc(self.buf, self.sz)
|
||||
self.buf = realloc(self.buf, self.sz, oldsz)
|
||||
|
||||
return offset
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ def seq_alloc_atomic(a: int) -> cobj:
|
|||
@nocapture
|
||||
@derives
|
||||
@C
|
||||
def seq_realloc(p: cobj, a: int) -> cobj:
|
||||
def seq_realloc(p: cobj, newsize: int, oldsize: int) -> cobj:
|
||||
pass
|
||||
|
||||
|
||||
|
@ -76,8 +76,8 @@ def alloc_atomic(sz: int):
|
|||
return seq_alloc_atomic(sz)
|
||||
|
||||
|
||||
def realloc(p: cobj, sz: int):
|
||||
return seq_realloc(p, sz)
|
||||
def realloc(p: cobj, newsz: int, oldsz: int):
|
||||
return seq_realloc(p, newsz, oldsz)
|
||||
|
||||
|
||||
def free(p: cobj):
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
# (c) 2022 Exaloop Inc. All rights reserved.
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _floordiv_int_float(self: int, other: float) -> float:
|
||||
declare double @llvm.floor.f64(double)
|
||||
%0 = sitofp i64 %self to double
|
||||
%1 = fdiv double %0, %other
|
||||
%2 = call double @llvm.floor.f64(double %1)
|
||||
ret double %2
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _floordiv_int_int(self: int, other: int) -> int:
|
||||
%0 = sdiv i64 %self, %other
|
||||
ret i64 %0
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _truediv_int_float(self: int, other: float) -> float:
|
||||
%0 = sitofp i64 %self to double
|
||||
%1 = fdiv double %0, %other
|
||||
ret double %1
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _truediv_int_int(self: int, other: int) -> float:
|
||||
%0 = sitofp i64 %self to double
|
||||
%1 = sitofp i64 %other to double
|
||||
%2 = fdiv double %0, %1
|
||||
ret double %2
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _mod_int_float(self: int, other: float) -> float:
|
||||
%0 = sitofp i64 %self to double
|
||||
%1 = frem double %0, %other
|
||||
ret double %1
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _mod_int_int(self: int, other: int) -> int:
|
||||
%0 = srem i64 %self, %other
|
||||
ret i64 %0
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _truediv_float_float(self: float, other: float) -> float:
|
||||
%0 = fdiv double %self, %other
|
||||
ret double %0
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _mod_float_float(self: float, other: float) -> float:
|
||||
%0 = frem double %self, %other
|
||||
ret double %0
|
||||
|
||||
def _divmod_int_int(self: int, other: int):
|
||||
d = _floordiv_int_int(self, other)
|
||||
m = self - d * other
|
||||
if m and ((other ^ m) < 0):
|
||||
m += other
|
||||
d -= 1
|
||||
return (d, m)
|
||||
|
||||
def _divmod_float_float(self: float, other: float):
|
||||
mod = _mod_float_float(self, other)
|
||||
div = _truediv_float_float(self - mod, other)
|
||||
if mod:
|
||||
if (other < 0) != (mod < 0):
|
||||
mod += other
|
||||
div -= 1.0
|
||||
else:
|
||||
mod = (0.0).copysign(other)
|
||||
|
||||
floordiv = 0.0
|
||||
if div:
|
||||
floordiv = div.__floor__()
|
||||
if div - floordiv > 0.5:
|
||||
floordiv += 1.0
|
||||
else:
|
||||
floordiv = (0.0).copysign(self / other)
|
||||
|
||||
return (floordiv, mod)
|
||||
|
||||
@extend
|
||||
class int:
|
||||
def __floordiv__(self, other: float):
|
||||
if other == 0.0:
|
||||
raise ZeroDivisionError("float floor division by zero")
|
||||
return _divmod_float_float(float(self), other)[0]
|
||||
|
||||
def __floordiv__(self, other: int):
|
||||
if other == 0:
|
||||
raise ZeroDivisionError("integer division or modulo by zero")
|
||||
return _divmod_int_int(self, other)[0]
|
||||
|
||||
def __truediv__(self, other: float):
|
||||
if other == 0.0:
|
||||
raise ZeroDivisionError("float division by zero")
|
||||
return _truediv_int_float(self, other)
|
||||
|
||||
def __truediv__(self, other: int):
|
||||
if other == 0:
|
||||
raise ZeroDivisionError("division by zero")
|
||||
return _truediv_int_int(self, other)
|
||||
|
||||
def __mod__(self, other: float):
|
||||
if other == 0.0:
|
||||
raise ZeroDivisionError("float modulo")
|
||||
return _divmod_float_float(self, other)[1]
|
||||
|
||||
def __mod__(self, other: int):
|
||||
if other == 0:
|
||||
raise ZeroDivisionError("integer division or modulo by zero")
|
||||
return _divmod_int_int(self, other)[1]
|
||||
|
||||
def __divmod__(self, other: float):
|
||||
if other == 0.0:
|
||||
raise ZeroDivisionError("float divmod()")
|
||||
return _divmod_float_float(float(self), other)
|
||||
|
||||
def __divmod__(self, other: int):
|
||||
if other == 0:
|
||||
raise ZeroDivisionError("integer division or modulo by zero")
|
||||
return _divmod_int_int(self, other)
|
||||
|
||||
@extend
|
||||
class float:
|
||||
def __floordiv__(self, other: float):
|
||||
if other == 0.0:
|
||||
raise ZeroDivisionError("float floor division by zero")
|
||||
return _divmod_float_float(self, other)[0]
|
||||
|
||||
def __floordiv__(self, other: int):
|
||||
if other == 0:
|
||||
raise ZeroDivisionError("float floor division by zero")
|
||||
return _divmod_float_float(self, float(other))[0]
|
||||
|
||||
def __truediv__(self, other: float):
|
||||
if other == 0.0:
|
||||
raise ZeroDivisionError("float division by zero")
|
||||
return _truediv_float_float(self, other)
|
||||
|
||||
def __truediv__(self, other: int):
|
||||
if other == 0:
|
||||
raise ZeroDivisionError("float division by zero")
|
||||
return _truediv_float_float(self, float(other))
|
||||
|
||||
def __mod__(self, other: float):
|
||||
if other == 0.0:
|
||||
raise ZeroDivisionError("float modulo")
|
||||
return _divmod_float_float(self, other)[1]
|
||||
|
||||
def __mod__(self, other: int):
|
||||
if other == 0:
|
||||
raise ZeroDivisionError("float modulo")
|
||||
return _divmod_float_float(self, float(other))[1]
|
||||
|
||||
def __divmod__(self, other: float):
|
||||
if other == 0.0:
|
||||
raise ZeroDivisionError("float divmod()")
|
||||
return _divmod_float_float(self, other)
|
||||
|
||||
def __divmod__(self, other: int):
|
||||
if other == 0:
|
||||
raise ZeroDivisionError("float divmod()")
|
||||
return _divmod_float_float(self, float(other))
|
|
@ -12,8 +12,10 @@ PyImport_AddModule = Function[[cobj], cobj](cobj())
|
|||
PyImport_AddModuleObject = Function[[cobj], cobj](cobj())
|
||||
PyImport_ImportModule = Function[[cobj], cobj](cobj())
|
||||
PyErr_Fetch = Function[[Ptr[cobj], Ptr[cobj], Ptr[cobj]], NoneType](cobj())
|
||||
PyErr_NormalizeException = Function[[Ptr[cobj], Ptr[cobj], Ptr[cobj]], NoneType](cobj())
|
||||
PyRun_SimpleString = Function[[cobj], NoneType](cobj())
|
||||
PyEval_GetGlobals = Function[[], cobj](cobj())
|
||||
PyEval_GetBuiltins = Function[[], cobj](cobj())
|
||||
|
||||
# conversions
|
||||
PyLong_AsLong = Function[[cobj], int](cobj())
|
||||
|
@ -97,6 +99,7 @@ PyObject_GetItem = Function[[cobj, cobj], cobj](cobj())
|
|||
PyObject_SetItem = Function[[cobj, cobj, cobj], int](cobj())
|
||||
PyObject_DelItem = Function[[cobj, cobj], int](cobj())
|
||||
PyObject_RichCompare = Function[[cobj, cobj, i32], cobj](cobj())
|
||||
PyObject_IsInstance = Function[[cobj, cobj], i32](cobj())
|
||||
|
||||
# constants
|
||||
Py_None = cobj()
|
||||
|
@ -154,8 +157,10 @@ def init_dl_handles(py_handle: cobj):
|
|||
global PyImport_AddModuleObject
|
||||
global PyImport_ImportModule
|
||||
global PyErr_Fetch
|
||||
global PyErr_NormalizeException
|
||||
global PyRun_SimpleString
|
||||
global PyEval_GetGlobals
|
||||
global PyEval_GetBuiltins
|
||||
global PyLong_AsLong
|
||||
global PyLong_FromLong
|
||||
global PyFloat_AsDouble
|
||||
|
@ -233,6 +238,7 @@ def init_dl_handles(py_handle: cobj):
|
|||
global PyObject_SetItem
|
||||
global PyObject_DelItem
|
||||
global PyObject_RichCompare
|
||||
global PyObject_IsInstance
|
||||
global Py_None
|
||||
global Py_True
|
||||
global Py_False
|
||||
|
@ -244,8 +250,10 @@ def init_dl_handles(py_handle: cobj):
|
|||
PyImport_AddModuleObject = dlsym(py_handle, "PyImport_AddModuleObject")
|
||||
PyImport_ImportModule = dlsym(py_handle, "PyImport_ImportModule")
|
||||
PyErr_Fetch = dlsym(py_handle, "PyErr_Fetch")
|
||||
PyErr_NormalizeException = dlsym(py_handle, "PyErr_NormalizeException")
|
||||
PyRun_SimpleString = dlsym(py_handle, "PyRun_SimpleString")
|
||||
PyEval_GetGlobals = dlsym(py_handle, "PyEval_GetGlobals")
|
||||
PyEval_GetBuiltins = dlsym(py_handle, "PyEval_GetBuiltins")
|
||||
PyLong_AsLong = dlsym(py_handle, "PyLong_AsLong")
|
||||
PyLong_FromLong = dlsym(py_handle, "PyLong_FromLong")
|
||||
PyFloat_AsDouble = dlsym(py_handle, "PyFloat_AsDouble")
|
||||
|
@ -323,6 +331,7 @@ def init_dl_handles(py_handle: cobj):
|
|||
PyObject_SetItem = dlsym(py_handle, "PyObject_SetItem")
|
||||
PyObject_DelItem = dlsym(py_handle, "PyObject_DelItem")
|
||||
PyObject_RichCompare = dlsym(py_handle, "PyObject_RichCompare")
|
||||
PyObject_IsInstance = dlsym(py_handle, "PyObject_IsInstance")
|
||||
Py_None = dlsym(py_handle, "_Py_NoneStruct")
|
||||
Py_True = dlsym(py_handle, "_Py_TrueStruct")
|
||||
Py_False = dlsym(py_handle, "_Py_FalseStruct")
|
||||
|
@ -603,19 +612,17 @@ class pyobj:
|
|||
def exc_check():
|
||||
ptype, pvalue, ptraceback = cobj(), cobj(), cobj()
|
||||
PyErr_Fetch(__ptr__(ptype), __ptr__(pvalue), __ptr__(ptraceback))
|
||||
PyErr_NormalizeException(__ptr__(ptype), __ptr__(pvalue), __ptr__(ptraceback))
|
||||
if ptype != cobj():
|
||||
py_msg = PyObject_Str(pvalue) if pvalue != cobj() else pvalue
|
||||
msg = pyobj.to_str(py_msg, "ignore", "<empty Python message>")
|
||||
py_typ = PyObject_GetAttrString(ptype, "__name__".c_str())
|
||||
typ = pyobj.to_str(py_typ, "ignore")
|
||||
|
||||
pyobj.decref(ptype)
|
||||
pyobj.decref(pvalue)
|
||||
pyobj.decref(ptraceback)
|
||||
pyobj.decref(py_msg)
|
||||
pyobj.decref(py_typ)
|
||||
|
||||
raise PyError(msg, typ)
|
||||
# pyobj.decref(pvalue)
|
||||
raise PyError(msg, pyobj(pvalue))
|
||||
|
||||
def exc_wrap(_retval: T, T: type) -> T:
|
||||
pyobj.exc_check()
|
||||
|
@ -667,7 +674,14 @@ class pyobj:
|
|||
PyRun_SimpleString(code.c_str())
|
||||
|
||||
def _globals() -> pyobj:
|
||||
return pyobj(PyEval_GetGlobals())
|
||||
p = PyEval_GetGlobals()
|
||||
if p == cobj():
|
||||
Py_IncRef(Py_None)
|
||||
return pyobj(Py_None)
|
||||
return pyobj(p)
|
||||
|
||||
def _builtins() -> pyobj:
|
||||
return pyobj(PyEval_GetBuiltins())
|
||||
|
||||
def _get_module(name: str) -> pyobj:
|
||||
p = pyobj(pyobj.exc_wrap(PyImport_AddModule(name.c_str())))
|
||||
|
@ -686,6 +700,17 @@ class pyobj:
|
|||
return bool(pyobj.exc_wrap(PyObject_IsTrue(self.p) == 1))
|
||||
|
||||
|
||||
def _get_identifier(typ: str) -> pyobj:
|
||||
t = pyobj._builtins()[typ]
|
||||
if t.p == cobj():
|
||||
t = pyobj._main_module()[typ]
|
||||
return t
|
||||
|
||||
|
||||
def _isinstance(what: pyobj, typ: pyobj) -> bool:
|
||||
return bool(pyobj.exc_wrap(PyObject_IsInstance(what.p, typ.p)))
|
||||
|
||||
|
||||
# Type conversions
|
||||
|
||||
@extend
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
# (c) 2022 Exaloop Inc. All rights reserved.
|
||||
from internal.gc import sizeof
|
||||
|
||||
|
||||
|
|
|
@ -106,6 +106,15 @@ class Dict:
|
|||
def __len__(self) -> int:
|
||||
return self._size
|
||||
|
||||
def __or__(self, other):
|
||||
new = self.__copy__()
|
||||
new.update(other)
|
||||
return new
|
||||
|
||||
def __ior__(self, other):
|
||||
self.update(other)
|
||||
return self
|
||||
|
||||
def __copy__(self):
|
||||
if self.__len__() == 0:
|
||||
return Dict[K, V]()
|
||||
|
@ -183,9 +192,13 @@ class Dict:
|
|||
ret, x = self._kh_put(key)
|
||||
self._vals[x] = op(dflt if ret != 0 else self._vals[x], other)
|
||||
|
||||
def update(self, other: Dict[K, V]):
|
||||
for k, v in other.items():
|
||||
self[k] = v
|
||||
def update(self, other):
|
||||
if isinstance(other, Dict[K, V]):
|
||||
for k, v in other.items():
|
||||
self[k] = v
|
||||
else:
|
||||
for k, v in other:
|
||||
self[k] = v
|
||||
|
||||
def pop(self, key: K) -> V:
|
||||
x = self._kh_get(key)
|
||||
|
@ -284,10 +297,14 @@ class Dict:
|
|||
|
||||
if self._n_buckets < new_n_buckets:
|
||||
self._keys = Ptr[K](
|
||||
gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K))
|
||||
gc.realloc(self._keys.as_byte(),
|
||||
new_n_buckets * gc.sizeof(K),
|
||||
self._n_buckets * gc.sizeof(K))
|
||||
)
|
||||
self._vals = Ptr[V](
|
||||
gc.realloc(self._vals.as_byte(), new_n_buckets * gc.sizeof(V))
|
||||
gc.realloc(self._vals.as_byte(),
|
||||
new_n_buckets * gc.sizeof(V),
|
||||
self._n_buckets * gc.sizeof(V))
|
||||
)
|
||||
|
||||
if j:
|
||||
|
@ -324,10 +341,14 @@ class Dict:
|
|||
|
||||
if self._n_buckets > new_n_buckets:
|
||||
self._keys = Ptr[K](
|
||||
gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K))
|
||||
gc.realloc(self._keys.as_byte(),
|
||||
new_n_buckets * gc.sizeof(K),
|
||||
self._n_buckets * gc.sizeof(K))
|
||||
)
|
||||
self._vals = Ptr[V](
|
||||
gc.realloc(self._vals.as_byte(), new_n_buckets * gc.sizeof(V))
|
||||
gc.realloc(self._vals.as_byte(),
|
||||
new_n_buckets * gc.sizeof(V),
|
||||
self._n_buckets * gc.sizeof(V))
|
||||
)
|
||||
|
||||
self._flags = new_flags
|
||||
|
|
|
@ -238,7 +238,9 @@ class List:
|
|||
len0 = self.__len__()
|
||||
new_cap = n * len0
|
||||
if self.arr.len < new_cap:
|
||||
p = Ptr[T](gc.realloc(self.arr.ptr.as_byte(), new_cap * gc.sizeof(T)))
|
||||
p = Ptr[T](gc.realloc(self.arr.ptr.as_byte(),
|
||||
new_cap * gc.sizeof(T),
|
||||
self.arr.len * gc.sizeof(T)))
|
||||
self.arr = Array[T](p, new_cap)
|
||||
|
||||
idx = len0
|
||||
|
@ -349,7 +351,9 @@ class List:
|
|||
raise IndexError(msg)
|
||||
|
||||
def _resize(self, new_cap: int):
|
||||
p = Ptr[T](gc.realloc(self.arr.ptr.as_byte(), new_cap * gc.sizeof(T)))
|
||||
p = Ptr[T](gc.realloc(self.arr.ptr.as_byte(),
|
||||
new_cap * gc.sizeof(T),
|
||||
self.arr.len * gc.sizeof(T)))
|
||||
self.arr = Array[T](p, new_cap)
|
||||
|
||||
def _resize_if_full(self):
|
||||
|
@ -414,5 +418,37 @@ class List:
|
|||
return Array[T](Ptr[T](), 0)
|
||||
return self.arr.slice(start, stop).__copy__()
|
||||
|
||||
def _cmp(self, other: List[T]):
|
||||
n1 = self.__len__()
|
||||
n2 = other.__len__()
|
||||
nmin = n1 if n1 < n2 else n2
|
||||
for i in range(nmin):
|
||||
a = self.arr[i]
|
||||
b = other.arr[i]
|
||||
|
||||
if a < b:
|
||||
return -1
|
||||
elif a == b:
|
||||
continue
|
||||
else:
|
||||
return 1
|
||||
if n1 < n2:
|
||||
return -1
|
||||
elif n1 == n2:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
def __lt__(self, other: List[T]):
|
||||
return self._cmp(other) < 0
|
||||
|
||||
def __gt__(self, other: List[T]):
|
||||
return self._cmp(other) > 0
|
||||
|
||||
def __le__(self, other: List[T]):
|
||||
return self._cmp(other) <= 0
|
||||
|
||||
def __ge__(self, other: List[T]):
|
||||
return self._cmp(other) >= 0
|
||||
|
||||
list = List
|
||||
|
|
|
@ -316,7 +316,9 @@ class Set:
|
|||
|
||||
if self._n_buckets < new_n_buckets:
|
||||
self._keys = Ptr[K](
|
||||
gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K))
|
||||
gc.realloc(self._keys.as_byte(),
|
||||
new_n_buckets * gc.sizeof(K),
|
||||
self._n_buckets * gc.sizeof(K))
|
||||
)
|
||||
|
||||
if j:
|
||||
|
@ -350,7 +352,9 @@ class Set:
|
|||
|
||||
if self._n_buckets > new_n_buckets:
|
||||
self._keys = Ptr[K](
|
||||
gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K))
|
||||
gc.realloc(self._keys.as_byte(),
|
||||
new_n_buckets * gc.sizeof(K),
|
||||
self._n_buckets * gc.sizeof(K))
|
||||
)
|
||||
|
||||
self._flags = new_flags
|
||||
|
|
|
@ -72,9 +72,9 @@ class CError(Exception):
|
|||
|
||||
|
||||
class PyError(Exception):
|
||||
pytype: str
|
||||
pytype: pyobj
|
||||
|
||||
def __init__(self, message: str, pytype: str):
|
||||
def __init__(self, message: str, pytype: pyobj):
|
||||
super().__init__("PyError", message)
|
||||
self.pytype = pytype
|
||||
|
||||
|
|
|
@ -395,3 +395,326 @@ class float:
|
|||
@property
|
||||
def imag(self) -> float:
|
||||
return 0.0
|
||||
|
||||
|
||||
@extend
|
||||
class float32:
|
||||
@pure
|
||||
@llvm
|
||||
def __new__(self: float) -> float32:
|
||||
%0 = fptrunc double %self to float
|
||||
ret float %0
|
||||
|
||||
def __new__(what: float32) -> float32:
|
||||
return what
|
||||
|
||||
def __new__() -> float32:
|
||||
return float32.__new__(0.0)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = seq_str_float(self.__float__())
|
||||
return s if s != "-nan" else "nan"
|
||||
|
||||
def __copy__(self) -> float32:
|
||||
return self
|
||||
|
||||
def __deepcopy__(self) -> float32:
|
||||
return self
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __int__(self) -> int:
|
||||
%0 = fptosi float %self to i64
|
||||
ret i64 %0
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __float__(self) -> float:
|
||||
%0 = fpext float %self to double
|
||||
ret double %0
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __bool__(self) -> bool:
|
||||
%0 = fcmp one float %self, 0.000000e+00
|
||||
%1 = zext i1 %0 to i8
|
||||
ret i8 %1
|
||||
|
||||
def __pos__(self) -> float32:
|
||||
return self
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __neg__(self) -> float32:
|
||||
%0 = fneg float %self
|
||||
ret float %0
|
||||
|
||||
@pure
|
||||
@commutative
|
||||
@llvm
|
||||
def __add__(a: float32, b: float32) -> float32:
|
||||
%tmp = fadd float %a, %b
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __sub__(a: float32, b: float32) -> float32:
|
||||
%tmp = fsub float %a, %b
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@commutative
|
||||
@llvm
|
||||
def __mul__(a: float32, b: float32) -> float32:
|
||||
%tmp = fmul float %a, %b
|
||||
ret float %tmp
|
||||
|
||||
def __floordiv__(self, other: float32) -> float:
|
||||
return self.__truediv__(other).__floor__()
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __truediv__(a: float32, b: float32) -> float32:
|
||||
%tmp = fdiv float %a, %b
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __mod__(a: float32, b: float32) -> float32:
|
||||
%tmp = frem float %a, %b
|
||||
ret float %tmp
|
||||
|
||||
def __divmod__(self, other: float32) -> Tuple[float32, float32]:
|
||||
mod = self % other
|
||||
div = (self - mod) / other
|
||||
if mod:
|
||||
if (other < 0) != (mod < 0):
|
||||
mod += other
|
||||
div -= 1.0
|
||||
else:
|
||||
mod = (0.0).copysign(other)
|
||||
|
||||
floordiv = 0.0
|
||||
if div:
|
||||
floordiv = div.__floor__()
|
||||
if div - floordiv > 0.5:
|
||||
floordiv += 1.0
|
||||
else:
|
||||
floordiv = (0.0).copysign(self / other)
|
||||
|
||||
return (floordiv, mod)
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __eq__(a: float32, b: float32) -> bool:
|
||||
%tmp = fcmp oeq float %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __ne__(a: float32, b: float32) -> bool:
|
||||
entry:
|
||||
%tmp = fcmp one float %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __lt__(a: float32, b: float32) -> bool:
|
||||
%tmp = fcmp olt float %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __gt__(a: float32, b: float32) -> bool:
|
||||
%tmp = fcmp ogt float %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __le__(a: float32, b: float32) -> bool:
|
||||
%tmp = fcmp ole float %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __ge__(a: float32, b: float32) -> bool:
|
||||
%tmp = fcmp oge float %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def sqrt(a: float32) -> float32:
|
||||
declare float @llvm.sqrt.f32(float %a)
|
||||
%tmp = call float @llvm.sqrt.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def sin(a: float32) -> float32:
|
||||
declare float @llvm.sin.f32(float %a)
|
||||
%tmp = call float @llvm.sin.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def cos(a: float32) -> float32:
|
||||
declare float @llvm.cos.f32(float %a)
|
||||
%tmp = call float @llvm.cos.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def exp(a: float32) -> float32:
|
||||
declare float @llvm.exp.f32(float %a)
|
||||
%tmp = call float @llvm.exp.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def exp2(a: float32) -> float32:
|
||||
declare float @llvm.exp2.f32(float %a)
|
||||
%tmp = call float @llvm.exp2.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def log(a: float32) -> float32:
|
||||
declare float @llvm.log.f32(float %a)
|
||||
%tmp = call float @llvm.log.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def log10(a: float32) -> float32:
|
||||
declare float @llvm.log10.f32(float %a)
|
||||
%tmp = call float @llvm.log10.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def log2(a: float32) -> float32:
|
||||
declare float @llvm.log2.f32(float %a)
|
||||
%tmp = call float @llvm.log2.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __abs__(a: float32) -> float32:
|
||||
declare float @llvm.fabs.f32(float %a)
|
||||
%tmp = call float @llvm.fabs.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __floor__(a: float32) -> float32:
|
||||
declare float @llvm.floor.f32(float %a)
|
||||
%tmp = call float @llvm.floor.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __ceil__(a: float32) -> float32:
|
||||
declare float @llvm.ceil.f32(float %a)
|
||||
%tmp = call float @llvm.ceil.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __trunc__(a: float32) -> float32:
|
||||
declare float @llvm.trunc.f32(float %a)
|
||||
%tmp = call float @llvm.trunc.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def rint(a: float32) -> float32:
|
||||
declare float @llvm.rint.f32(float %a)
|
||||
%tmp = call float @llvm.rint.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def nearbyint(a: float32) -> float32:
|
||||
declare float @llvm.nearbyint.f32(float %a)
|
||||
%tmp = call float @llvm.nearbyint.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __round__(a: float32) -> float32:
|
||||
declare float @llvm.round.f32(float %a)
|
||||
%tmp = call float @llvm.round.f32(float %a)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __pow__(a: float32, b: float32) -> float32:
|
||||
declare float @llvm.pow.f32(float %a, float %b)
|
||||
%tmp = call float @llvm.pow.f32(float %a, float %b)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def min(a: float32, b: float32) -> float32:
|
||||
declare float @llvm.minnum.f32(float %a, float %b)
|
||||
%tmp = call float @llvm.minnum.f32(float %a, float %b)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def max(a: float32, b: float32) -> float32:
|
||||
declare float @llvm.maxnum.f32(float %a, float %b)
|
||||
%tmp = call float @llvm.maxnum.f32(float %a, float %b)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def copysign(a: float32, b: float32) -> float32:
|
||||
declare float @llvm.copysign.f32(float %a, float %b)
|
||||
%tmp = call float @llvm.copysign.f32(float %a, float %b)
|
||||
ret float %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def fma(a: float32, b: float32, c: float32) -> float32:
|
||||
declare float @llvm.fma.f32(float %a, float %b, float %c)
|
||||
%tmp = call float @llvm.fma.f32(float %a, float %b, float %c)
|
||||
ret float %tmp
|
||||
|
||||
@nocapture
|
||||
@llvm
|
||||
def __atomic_xchg__(d: Ptr[float32], b: float32) -> None:
|
||||
%tmp = atomicrmw xchg float* %d, float %b seq_cst
|
||||
ret {} {}
|
||||
|
||||
@nocapture
|
||||
@llvm
|
||||
def __atomic_add__(d: Ptr[float32], b: float32) -> float32:
|
||||
%tmp = atomicrmw fadd float* %d, float %b seq_cst
|
||||
ret float %tmp
|
||||
|
||||
@nocapture
|
||||
@llvm
|
||||
def __atomic_sub__(d: Ptr[float32], b: float32) -> float32:
|
||||
%tmp = atomicrmw fsub float* %d, float %b seq_cst
|
||||
ret float %tmp
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.__float__().__hash__()
|
||||
|
||||
def __match__(self, i: float32) -> bool:
|
||||
return self == i
|
||||
|
||||
|
||||
@extend
|
||||
class float:
|
||||
def __suffix_f32__(double) -> float32:
|
||||
return float32.__new__(double)
|
||||
|
||||
f32 = float32
|
||||
|
|
|
@ -182,6 +182,9 @@ class int:
|
|||
%tmp = srem i64 %a, %b
|
||||
ret i64 %tmp
|
||||
|
||||
def __divmod__(self, other: float) -> Tuple[float, float]:
|
||||
return float(self).__divmod__(other)
|
||||
|
||||
def __divmod__(self, other: int) -> Tuple[int, int]:
|
||||
d = self // other
|
||||
m = self - d * other
|
||||
|
|
|
@ -95,3 +95,11 @@ class range:
|
|||
return f"range({self.start}, {self.stop})"
|
||||
else:
|
||||
return f"range({self.start}, {self.stop}, {self.step})"
|
||||
|
||||
@overload
|
||||
def staticrange(start: Static[int], stop: Static[int], step: Static[int] = 1):
|
||||
return range(start, stop, step)
|
||||
|
||||
@overload
|
||||
def staticrange(stop: Static[int]):
|
||||
return range(0, stop, 1)
|
||||
|
|
|
@ -1,11 +1,23 @@
|
|||
# (c) 2022 Exaloop Inc. All rights reserved.
|
||||
|
||||
e = 2.718281828459045
|
||||
pi = 3.141592653589793
|
||||
tau = 6.283185307179586
|
||||
|
||||
inf = 1.0 / 0.0
|
||||
nan = 0.0 / 0.0
|
||||
@pure
|
||||
@llvm
|
||||
def _inf() -> float:
|
||||
ret double 0x7FF0000000000000
|
||||
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def _nan() -> float:
|
||||
ret double 0x7FF8000000000000
|
||||
|
||||
|
||||
e = 2.7182818284590452354
|
||||
pi = 3.14159265358979323846
|
||||
tau = 6.28318530717958647693
|
||||
inf = _inf()
|
||||
nan = _nan()
|
||||
|
||||
|
||||
def factorial(x: int) -> int:
|
||||
|
@ -43,11 +55,14 @@ def isnan(x: float) -> bool:
|
|||
|
||||
Return True if float arg is a NaN, else False.
|
||||
"""
|
||||
test = x == x
|
||||
# if it is true then it is a number
|
||||
if test:
|
||||
return False
|
||||
return True
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float) -> bool:
|
||||
%y = fcmp uno double %x, 0.000000e+00
|
||||
%z = zext i1 %y to i8
|
||||
ret i8 %z
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
def isinf(x: float) -> bool:
|
||||
|
@ -56,7 +71,16 @@ def isinf(x: float) -> bool:
|
|||
|
||||
Return True if float arg is an INF, else False.
|
||||
"""
|
||||
return x == inf or x == -inf
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float) -> bool:
|
||||
declare double @llvm.fabs.f64(double)
|
||||
%a = call double @llvm.fabs.f64(double %x)
|
||||
%b = fcmp oeq double %a, 0x7FF0000000000000
|
||||
%c = zext i1 %b to i8
|
||||
ret i8 %c
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
def isfinite(x: float) -> bool:
|
||||
|
@ -66,9 +90,35 @@ def isfinite(x: float) -> bool:
|
|||
Return True if x is neither an infinity nor a NaN,
|
||||
and False otherwise.
|
||||
"""
|
||||
if isnan(x) or isinf(x):
|
||||
return False
|
||||
return True
|
||||
return not (isnan(x) or isinf(x))
|
||||
|
||||
|
||||
def _check1(arg: float, r: float, can_overflow: bool = False):
|
||||
if __py_numerics__:
|
||||
if isnan(r) and not isnan(arg):
|
||||
raise ValueError("math domain error")
|
||||
|
||||
if isinf(r) and isfinite(arg):
|
||||
if can_overflow:
|
||||
raise OverflowError("math range error")
|
||||
else:
|
||||
raise ValueError("math domain error")
|
||||
|
||||
return r
|
||||
|
||||
|
||||
def _check2(x: float, y: float, r: float, can_overflow: bool = False):
|
||||
if __py_numerics__:
|
||||
if isnan(r) and not isnan(x) and not isnan(y):
|
||||
raise ValueError("math domain error")
|
||||
|
||||
if isinf(r) and isfinite(x) and isfinite(y):
|
||||
if can_overflow:
|
||||
raise OverflowError("math range error")
|
||||
else:
|
||||
raise ValueError("math domain error")
|
||||
|
||||
return r
|
||||
|
||||
|
||||
def ceil(x: float) -> float:
|
||||
|
@ -149,7 +199,7 @@ def exp(x: float) -> float:
|
|||
%y = call double @llvm.exp.f64(double %x)
|
||||
ret double %y
|
||||
|
||||
return f(x)
|
||||
return _check1(x, f(x), True)
|
||||
|
||||
|
||||
def expm1(x: float) -> float:
|
||||
|
@ -159,7 +209,7 @@ def expm1(x: float) -> float:
|
|||
Return e raised to the power x, minus 1. expm1 provides
|
||||
a way to compute this quantity to full precision.
|
||||
"""
|
||||
return _C.expm1(x)
|
||||
return _check1(x, _C.expm1(x), True)
|
||||
|
||||
|
||||
def ldexp(x: float, i: int) -> float:
|
||||
|
@ -168,7 +218,7 @@ def ldexp(x: float, i: int) -> float:
|
|||
|
||||
Returns x multiplied by 2 raised to the power of exponent.
|
||||
"""
|
||||
return _C.ldexp(x, i)
|
||||
return _check1(x, _C.ldexp(x, i32(i)), True)
|
||||
|
||||
|
||||
def log(x: float, base: float = e) -> float:
|
||||
|
@ -185,9 +235,9 @@ def log(x: float, base: float = e) -> float:
|
|||
ret double %y
|
||||
|
||||
if base == e:
|
||||
return f(x)
|
||||
return _check1(x, f(x))
|
||||
else:
|
||||
return f(x) / f(base)
|
||||
return _check1(x, f(x)) / _check1(base, f(base))
|
||||
|
||||
|
||||
def log2(x: float) -> float:
|
||||
|
@ -203,7 +253,7 @@ def log2(x: float) -> float:
|
|||
%y = call double @llvm.log2.f64(double %x)
|
||||
ret double %y
|
||||
|
||||
return f(x)
|
||||
return _check1(x, f(x))
|
||||
|
||||
|
||||
def log10(x: float) -> float:
|
||||
|
@ -219,7 +269,7 @@ def log10(x: float) -> float:
|
|||
%y = call double @llvm.log10.f64(double %x)
|
||||
ret double %y
|
||||
|
||||
return f(x)
|
||||
return _check1(x, f(x))
|
||||
|
||||
|
||||
def degrees(x: float) -> float:
|
||||
|
@ -255,7 +305,7 @@ def sqrt(x: float) -> float:
|
|||
%y = call double @llvm.sqrt.f64(double %x)
|
||||
ret double %y
|
||||
|
||||
return f(x)
|
||||
return _check1(x, f(x))
|
||||
|
||||
|
||||
def pow(x: float, y: float) -> float:
|
||||
|
@ -271,7 +321,7 @@ def pow(x: float, y: float) -> float:
|
|||
%z = call double @llvm.pow.f64(double %x, double %y)
|
||||
ret double %z
|
||||
|
||||
return f(x, y)
|
||||
return _check2(x, y, f(x, y), True)
|
||||
|
||||
|
||||
def acos(x: float) -> float:
|
||||
|
@ -280,7 +330,7 @@ def acos(x: float) -> float:
|
|||
|
||||
Returns the arc cosine of x in radians.
|
||||
"""
|
||||
return _C.acos(x)
|
||||
return _check1(x, _C.acos(x))
|
||||
|
||||
|
||||
def asin(x: float) -> float:
|
||||
|
@ -289,7 +339,7 @@ def asin(x: float) -> float:
|
|||
|
||||
Returns the arc sine of x in radians.
|
||||
"""
|
||||
return _C.asin(x)
|
||||
return _check1(x, _C.asin(x))
|
||||
|
||||
|
||||
def atan(x: float) -> float:
|
||||
|
@ -298,7 +348,7 @@ def atan(x: float) -> float:
|
|||
|
||||
Returns the arc tangent of x in radians.
|
||||
"""
|
||||
return _C.atan(x)
|
||||
return _check1(x, _C.atan(x))
|
||||
|
||||
|
||||
def atan2(y: float, x: float) -> float:
|
||||
|
@ -309,7 +359,7 @@ def atan2(y: float, x: float) -> float:
|
|||
on the signs of both values to determine the
|
||||
correct quadrant.
|
||||
"""
|
||||
return _C.atan2(y, x)
|
||||
return _check2(x, y, _C.atan2(y, x))
|
||||
|
||||
|
||||
def cos(x: float) -> float:
|
||||
|
@ -325,7 +375,7 @@ def cos(x: float) -> float:
|
|||
%y = call double @llvm.cos.f64(double %x)
|
||||
ret double %y
|
||||
|
||||
return f(x)
|
||||
return _check1(x, f(x))
|
||||
|
||||
|
||||
def sin(x: float) -> float:
|
||||
|
@ -341,7 +391,7 @@ def sin(x: float) -> float:
|
|||
%y = call double @llvm.sin.f64(double %x)
|
||||
ret double %y
|
||||
|
||||
return f(x)
|
||||
return _check1(x, f(x))
|
||||
|
||||
|
||||
def hypot(x: float, y: float) -> float:
|
||||
|
@ -352,7 +402,7 @@ def hypot(x: float, y: float) -> float:
|
|||
This is the length of the vector from the
|
||||
origin to point (x, y).
|
||||
"""
|
||||
return _C.hypot(x, y)
|
||||
return _check2(x, y, _C.hypot(x, y), True)
|
||||
|
||||
|
||||
def tan(x: float) -> float:
|
||||
|
@ -361,7 +411,7 @@ def tan(x: float) -> float:
|
|||
|
||||
Return the tangent of a radian angle x.
|
||||
"""
|
||||
return _C.tan(x)
|
||||
return _check1(x, _C.tan(x))
|
||||
|
||||
|
||||
def cosh(x: float) -> float:
|
||||
|
@ -370,7 +420,7 @@ def cosh(x: float) -> float:
|
|||
|
||||
Returns the hyperbolic cosine of x.
|
||||
"""
|
||||
return _C.cosh(x)
|
||||
return _check1(x, _C.cosh(x), True)
|
||||
|
||||
|
||||
def sinh(x: float) -> float:
|
||||
|
@ -379,7 +429,7 @@ def sinh(x: float) -> float:
|
|||
|
||||
Returns the hyperbolic sine of x.
|
||||
"""
|
||||
return _C.sinh(x)
|
||||
return _check1(x, _C.sinh(x), True)
|
||||
|
||||
|
||||
def tanh(x: float) -> float:
|
||||
|
@ -388,7 +438,7 @@ def tanh(x: float) -> float:
|
|||
|
||||
Returns the hyperbolic tangent of x.
|
||||
"""
|
||||
return _C.tanh(x)
|
||||
return _check1(x, _C.tanh(x))
|
||||
|
||||
|
||||
def acosh(x: float) -> float:
|
||||
|
@ -397,7 +447,7 @@ def acosh(x: float) -> float:
|
|||
|
||||
Return the inverse hyperbolic cosine of x.
|
||||
"""
|
||||
return _C.acosh(x)
|
||||
return _check1(x, _C.acosh(x))
|
||||
|
||||
|
||||
def asinh(x: float) -> float:
|
||||
|
@ -406,7 +456,7 @@ def asinh(x: float) -> float:
|
|||
|
||||
Return the inverse hyperbolic sine of x.
|
||||
"""
|
||||
return _C.asinh(x)
|
||||
return _check1(x, _C.asinh(x))
|
||||
|
||||
|
||||
def atanh(x: float) -> float:
|
||||
|
@ -415,7 +465,7 @@ def atanh(x: float) -> float:
|
|||
|
||||
Return the inverse hyperbolic tangent of x.
|
||||
"""
|
||||
return _C.atanh(x)
|
||||
return _check1(x, _C.atanh(x))
|
||||
|
||||
|
||||
def copysign(x: float, y: float) -> float:
|
||||
|
@ -432,7 +482,7 @@ def copysign(x: float, y: float) -> float:
|
|||
%z = call double @llvm.copysign.f64(double %x, double %y)
|
||||
ret double %z
|
||||
|
||||
return f(x, y)
|
||||
return _check2(x, y, f(x, y))
|
||||
|
||||
|
||||
def log1p(x: float) -> float:
|
||||
|
@ -441,7 +491,7 @@ def log1p(x: float) -> float:
|
|||
|
||||
Return the natural logarithm of 1+x (base e).
|
||||
"""
|
||||
return _C.log1p(x)
|
||||
return _check1(x, _C.log1p(x))
|
||||
|
||||
|
||||
def trunc(x: float) -> float:
|
||||
|
@ -458,7 +508,7 @@ def trunc(x: float) -> float:
|
|||
%y = call double @llvm.trunc.f64(double %x)
|
||||
ret double %y
|
||||
|
||||
return f(x)
|
||||
return _check1(x, f(x))
|
||||
|
||||
|
||||
def erf(x: float) -> float:
|
||||
|
@ -467,7 +517,7 @@ def erf(x: float) -> float:
|
|||
|
||||
Return the error function at x.
|
||||
"""
|
||||
return _C.erf(x)
|
||||
return _check1(x, _C.erf(x))
|
||||
|
||||
|
||||
def erfc(x: float) -> float:
|
||||
|
@ -476,7 +526,7 @@ def erfc(x: float) -> float:
|
|||
|
||||
Return the complementary error function at x.
|
||||
"""
|
||||
return _C.erfc(x)
|
||||
return _check1(x, _C.erfc(x))
|
||||
|
||||
|
||||
def gamma(x: float) -> float:
|
||||
|
@ -485,7 +535,7 @@ def gamma(x: float) -> float:
|
|||
|
||||
Return the Gamma function at x.
|
||||
"""
|
||||
return _C.tgamma(x)
|
||||
return _check1(x, _C.tgamma(x), True)
|
||||
|
||||
|
||||
def lgamma(x: float) -> float:
|
||||
|
@ -495,7 +545,7 @@ def lgamma(x: float) -> float:
|
|||
Return the natural logarithm of
|
||||
the absolute value of the Gamma function at x.
|
||||
"""
|
||||
return _C.lgamma(x)
|
||||
return _check1(x, _C.lgamma(x), True)
|
||||
|
||||
|
||||
def remainder(x: float, y: float) -> float:
|
||||
|
@ -509,7 +559,7 @@ def remainder(x: float, y: float) -> float:
|
|||
two consecutive integers, the nearest even integer is used
|
||||
for n.
|
||||
"""
|
||||
return _C.remainder(x, y)
|
||||
return _check2(x, y, _C.remainder(x, y))
|
||||
|
||||
|
||||
def gcd(a: float, b: float) -> float:
|
||||
|
@ -585,3 +635,615 @@ def isclose(a: float, b: float, rel_tol: float = 1e-09, abs_tol: float = 0.0) ->
|
|||
return ((diff <= fabs(rel_tol * b)) or (diff <= fabs(rel_tol * a))) or (
|
||||
diff <= abs_tol
|
||||
)
|
||||
|
||||
|
||||
# 32-bit float ops
|
||||
|
||||
e32 = float32(e)
|
||||
pi32 = float32(pi)
|
||||
tau32 = float32(tau)
|
||||
|
||||
inf32 = float32(inf)
|
||||
nan32 = float32(nan)
|
||||
|
||||
|
||||
@overload
|
||||
def isnan(x: float32) -> bool:
|
||||
"""
|
||||
isnan(float32) -> bool
|
||||
|
||||
Return True if float arg is a NaN, else False.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32) -> bool:
|
||||
%y = fcmp uno float %x, 0.000000e+00
|
||||
%z = zext i1 %y to i8
|
||||
ret i8 %z
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
@overload
|
||||
def isinf(x: float32) -> bool:
|
||||
"""
|
||||
isinf(float32) -> bool:
|
||||
|
||||
Return True if float arg is an INF, else False.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32) -> bool:
|
||||
declare float @llvm.fabs.f32(float)
|
||||
%a = call float @llvm.fabs.f32(float %x)
|
||||
%b = fcmp oeq float %a, 0x7FF0000000000000
|
||||
%c = zext i1 %b to i8
|
||||
ret i8 %c
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
@overload
|
||||
def isfinite(x: float32) -> bool:
|
||||
"""
|
||||
isfinite(float32) -> bool
|
||||
|
||||
Return True if x is neither an infinity nor a NaN,
|
||||
and False otherwise.
|
||||
"""
|
||||
return not (isnan(x) or isinf(x))
|
||||
|
||||
|
||||
@overload
|
||||
def ceil(x: float32) -> float32:
|
||||
"""
|
||||
ceil(float32) -> float32
|
||||
|
||||
Return the ceiling of x as an Integral.
|
||||
This is the smallest integer >= x.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32) -> float32:
|
||||
declare float @llvm.ceil.f32(float)
|
||||
%y = call float @llvm.ceil.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
@overload
|
||||
def floor(x: float32) -> float32:
|
||||
"""
|
||||
floor(float32) -> float32
|
||||
|
||||
Return the floor of x as an Integral.
|
||||
This is the largest integer <= x.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32) -> float32:
|
||||
declare float @llvm.floor.f32(float)
|
||||
%y = call float @llvm.floor.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
@overload
|
||||
def fabs(x: float32) -> float32:
|
||||
"""
|
||||
fabs(float32) -> float32
|
||||
|
||||
Returns the absolute value of a float32ing point number.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32) -> float32:
|
||||
declare float @llvm.fabs.f32(float)
|
||||
%y = call float @llvm.fabs.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
@overload
|
||||
def fmod(x: float32, y: float32) -> float32:
|
||||
"""
|
||||
fmod(float32, float32) -> float32
|
||||
|
||||
Returns the remainder of x divided by y.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32, y: float32) -> float32:
|
||||
%z = frem float %x, %y
|
||||
ret float %z
|
||||
|
||||
return f(x, y)
|
||||
|
||||
|
||||
@overload
|
||||
def exp(x: float32) -> float32:
|
||||
"""
|
||||
exp(float32) -> float32
|
||||
|
||||
Returns the value of e raised to the xth power.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32) -> float32:
|
||||
declare float @llvm.exp.f32(float)
|
||||
%y = call float @llvm.exp.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
@overload
|
||||
def expm1(x: float32) -> float32:
|
||||
"""
|
||||
expm1(float32) -> float32
|
||||
|
||||
Return e raised to the power x, minus 1. expm1 provides
|
||||
a way to compute this quantity to full precision.
|
||||
"""
|
||||
return _C.expm1f(x)
|
||||
|
||||
|
||||
@overload
|
||||
def ldexp(x: float32, i: int) -> float32:
|
||||
"""
|
||||
ldexp(float32, int) -> float32
|
||||
|
||||
Returns x multiplied by 2 raised to the power of exponent.
|
||||
"""
|
||||
return _C.ldexpf(x, i32(i))
|
||||
|
||||
|
||||
@overload
|
||||
def log(x: float32, base: float32 = e32) -> float32:
|
||||
"""
|
||||
log(float32) -> float32
|
||||
|
||||
Returns the natural logarithm (base-e logarithm) of x.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32) -> float32:
|
||||
declare float @llvm.log.f32(float)
|
||||
%y = call float @llvm.log.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
if base == e32:
|
||||
return f(x)
|
||||
else:
|
||||
return f(x) / f(base)
|
||||
|
||||
|
||||
@overload
|
||||
def log2(x: float32) -> float32:
|
||||
"""
|
||||
log2(float32) -> float32
|
||||
|
||||
Return the base-2 logarithm of x.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32) -> float32:
|
||||
declare float @llvm.log2.f32(float)
|
||||
%y = call float @llvm.log2.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
@overload
|
||||
def log10(x: float32) -> float32:
|
||||
"""
|
||||
log10(float32) -> float32
|
||||
|
||||
Returns the common logarithm (base-10 logarithm) of x.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32) -> float32:
|
||||
declare float @llvm.log10.f32(float)
|
||||
%y = call float @llvm.log10.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
@overload
|
||||
def degrees(x: float32) -> float32:
|
||||
"""
|
||||
degrees(float32) -> float32
|
||||
|
||||
Convert angle x from radians to degrees.
|
||||
"""
|
||||
radToDeg = float32(180.0) / pi32
|
||||
return x * radToDeg
|
||||
|
||||
|
||||
@overload
|
||||
def radians(x: float32) -> float32:
|
||||
"""
|
||||
radians(float32) -> float32
|
||||
|
||||
Convert angle x from degrees to radians.
|
||||
"""
|
||||
degToRad = pi32 / float32(180.0)
|
||||
return x * degToRad
|
||||
|
||||
@overload
|
||||
def sqrt(x: float32) -> float32:
|
||||
"""
|
||||
sqrt(float32) -> float32
|
||||
|
||||
Returns the square root of x.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32) -> float32:
|
||||
declare float @llvm.sqrt.f32(float)
|
||||
%y = call float @llvm.sqrt.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
@overload
|
||||
def pow(x: float32, y: float32) -> float32:
|
||||
"""
|
||||
pow(float32, float32) -> float32
|
||||
|
||||
Returns x raised to the power of y.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32, y: float32) -> float32:
|
||||
declare float @llvm.pow.f32(float, float)
|
||||
%z = call float @llvm.pow.f32(float %x, float %y)
|
||||
ret float %z
|
||||
|
||||
return f(x, y)
|
||||
|
||||
|
||||
@overload
|
||||
def acos(x: float32) -> float32:
|
||||
"""
|
||||
acos(float32) -> float32
|
||||
|
||||
Returns the arc cosine of x in radians.
|
||||
"""
|
||||
return _C.acosf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def asin(x: float32) -> float32:
|
||||
"""
|
||||
asin(float32) -> float32
|
||||
|
||||
Returns the arc sine of x in radians.
|
||||
"""
|
||||
return _C.asinf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def atan(x: float32) -> float32:
|
||||
"""
|
||||
atan(float32) -> float32
|
||||
|
||||
Returns the arc tangent of x in radians.
|
||||
"""
|
||||
return _C.atanf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def atan2(y: float32, x: float32) -> float32:
|
||||
"""
|
||||
atan2(float32, float32) -> float32
|
||||
|
||||
Returns the arc tangent in radians of y/x based
|
||||
on the signs of both values to determine the
|
||||
correct quadrant.
|
||||
"""
|
||||
return _C.atan2f(y, x)
|
||||
|
||||
|
||||
@overload
|
||||
def cos(x: float32) -> float32:
|
||||
"""
|
||||
cos(float32) -> float32
|
||||
|
||||
Returns the cosine of a radian angle x.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32) -> float32:
|
||||
declare float @llvm.cos.f32(float)
|
||||
%y = call float @llvm.cos.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
@overload
|
||||
def sin(x: float32) -> float32:
|
||||
"""
|
||||
sin(float32) -> float32
|
||||
|
||||
Returns the sine of a radian angle x.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32) -> float32:
|
||||
declare float @llvm.sin.f32(float)
|
||||
%y = call float @llvm.sin.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
@overload
|
||||
def hypot(x: float32, y: float32) -> float32:
|
||||
"""
|
||||
hypot(float32, float32) -> float32
|
||||
|
||||
Return the Euclidean norm.
|
||||
This is the length of the vector from the
|
||||
origin to point (x, y).
|
||||
"""
|
||||
return _C.hypotf(x, y)
|
||||
|
||||
|
||||
@overload
|
||||
def tan(x: float32) -> float32:
|
||||
"""
|
||||
tan(float32) -> float32
|
||||
|
||||
Return the tangent of a radian angle x.
|
||||
"""
|
||||
return _C.tanf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def cosh(x: float32) -> float32:
|
||||
"""
|
||||
cosh(float32) -> float32
|
||||
|
||||
Returns the hyperbolic cosine of x.
|
||||
"""
|
||||
return _C.coshf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def sinh(x: float32) -> float32:
|
||||
"""
|
||||
sinh(float32) -> float32
|
||||
|
||||
Returns the hyperbolic sine of x.
|
||||
"""
|
||||
return _C.sinhf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def tanh(x: float32) -> float32:
|
||||
"""
|
||||
tanh(float32) -> float32
|
||||
|
||||
Returns the hyperbolic tangent of x.
|
||||
"""
|
||||
return _C.tanhf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def acosh(x: float32) -> float32:
|
||||
"""
|
||||
acosh(float32) -> float32
|
||||
|
||||
Return the inverse hyperbolic cosine of x.
|
||||
"""
|
||||
return _C.acoshf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def asinh(x: float32) -> float32:
|
||||
"""
|
||||
asinh(float32) -> float32
|
||||
|
||||
Return the inverse hyperbolic sine of x.
|
||||
"""
|
||||
return _C.asinhf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def atanh(x: float32) -> float32:
|
||||
"""
|
||||
atanh(float32) -> float32
|
||||
|
||||
Return the inverse hyperbolic tangent of x.
|
||||
"""
|
||||
return _C.atanhf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def copysign(x: float32, y: float32) -> float32:
|
||||
"""
|
||||
copysign(float32, float32) -> float32
|
||||
|
||||
Return a float32 with the magnitude (absolute value) of
|
||||
x but the sign of y.
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32, y: float32) -> float32:
|
||||
declare float @llvm.copysign.f32(float, float)
|
||||
%z = call float @llvm.copysign.f32(float %x, float %y)
|
||||
ret float %z
|
||||
|
||||
return f(x, y)
|
||||
|
||||
|
||||
@overload
|
||||
def log1p(x: float32) -> float32:
|
||||
"""
|
||||
log1p(float32) -> float32
|
||||
|
||||
Return the natural logarithm of 1+x (base e).
|
||||
"""
|
||||
return _C.log1pf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def trunc(x: float32) -> float32:
|
||||
"""
|
||||
trunc(float32) -> float32
|
||||
|
||||
Return the Real value x truncated to an Integral
|
||||
(usually an integer).
|
||||
"""
|
||||
@pure
|
||||
@llvm
|
||||
def f(x: float32) -> float32:
|
||||
declare float @llvm.trunc.f32(float)
|
||||
%y = call float @llvm.trunc.f32(float %x)
|
||||
ret float %y
|
||||
|
||||
return f(x)
|
||||
|
||||
|
||||
@overload
|
||||
def erf(x: float32) -> float32:
|
||||
"""
|
||||
erf(float32) -> float32
|
||||
|
||||
Return the error function at x.
|
||||
"""
|
||||
return _C.erff(x)
|
||||
|
||||
|
||||
@overload
|
||||
def erfc(x: float32) -> float32:
|
||||
"""
|
||||
erfc(float32) -> float32
|
||||
|
||||
Return the complementary error function at x.
|
||||
"""
|
||||
return _C.erfcf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def gamma(x: float32) -> float32:
|
||||
"""
|
||||
gamma(float32) -> float32
|
||||
|
||||
Return the Gamma function at x.
|
||||
"""
|
||||
return _C.tgammaf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def lgamma(x: float32) -> float32:
|
||||
"""
|
||||
lgamma(float32) -> float32
|
||||
|
||||
Return the natural logarithm of
|
||||
the absolute value of the Gamma function at x.
|
||||
"""
|
||||
return _C.lgammaf(x)
|
||||
|
||||
|
||||
@overload
|
||||
def remainder(x: float32, y: float32) -> float32:
|
||||
"""
|
||||
remainder(float32, float32) -> float32
|
||||
|
||||
Return the IEEE 754-style remainder of x with respect to y.
|
||||
For finite x and finite nonzero y, this is the difference
|
||||
x - n*y, where n is the closest integer to the exact value
|
||||
of the quotient x / y. If x / y is exactly halfway between
|
||||
two consecutive integers, the nearest even integer is used
|
||||
for n.
|
||||
"""
|
||||
return _C.remainderf(x, y)
|
||||
|
||||
|
||||
@overload
|
||||
def gcd(a: float32, b: float32) -> float32:
|
||||
"""
|
||||
gcd(float32, float32) -> float32
|
||||
|
||||
returns greatest common divisor of x and y.
|
||||
"""
|
||||
a = abs(a)
|
||||
b = abs(b)
|
||||
while a:
|
||||
a, b = b % a, a
|
||||
return b
|
||||
|
||||
|
||||
@overload
|
||||
@pure
|
||||
def frexp(x: float32) -> Tuple[float32, int]:
|
||||
"""
|
||||
frexp(float32) -> Tuple[float32, int]
|
||||
|
||||
The returned value is the mantissa and the integer pointed
|
||||
to by exponent is the exponent. The resultant value is
|
||||
x = mantissa * 2 ^ exponent.
|
||||
"""
|
||||
tmp = i32(0)
|
||||
res = _C.frexpf(float32(x), __ptr__(tmp))
|
||||
return (res, int(tmp))
|
||||
|
||||
|
||||
@overload
|
||||
@pure
|
||||
def modf(x: float32) -> Tuple[float32, float32]:
|
||||
"""
|
||||
modf(float32) -> Tuple[float32, float32]
|
||||
|
||||
The returned value is the fraction component (part after
|
||||
the decimal), and sets integer to the integer component.
|
||||
"""
|
||||
tmp = float32(0.0)
|
||||
res = _C.modff(float32(x), __ptr__(tmp))
|
||||
return (res, tmp)
|
||||
|
||||
|
||||
@overload
|
||||
def isclose(a: float32, b: float32, rel_tol: float32 = float32(1e-09), abs_tol: float32 = float32(0.0)) -> bool:
|
||||
"""
|
||||
isclose(float32, float32) -> bool
|
||||
|
||||
Return True if a is close in value to b, and False otherwise.
|
||||
For the values to be considered close, the difference between them
|
||||
must be smaller than at least one of the tolerances.
|
||||
"""
|
||||
|
||||
# short circuit exact equality -- needed to catch two
|
||||
# infinities of the same sign. And perhaps speeds things
|
||||
# up a bit sometimes.
|
||||
if a == b:
|
||||
return True
|
||||
|
||||
# This catches the case of two infinities of opposite sign, or
|
||||
# one infinity and one finite number. Two infinities of opposite
|
||||
# sign would otherwise have an infinite relative tolerance.
|
||||
# Two infinities of the same sign are caught by the equality check
|
||||
# above.
|
||||
if a == inf32 or b == inf32:
|
||||
return False
|
||||
|
||||
# NAN is not close to anything, not even itself
|
||||
if a == nan32 or b == nan32:
|
||||
return False
|
||||
|
||||
# regular computation
|
||||
diff = fabs(b - a)
|
||||
|
||||
return ((diff <= fabs(rel_tol * b)) or (diff <= fabs(rel_tol * a))) or (
|
||||
diff <= abs_tol
|
||||
)
|
||||
|
|
|
@ -643,6 +643,7 @@ def _task_loop_outline_template(gtid_ptr: Ptr[i32], btid_ptr: Ptr[i32], args):
|
|||
_loop_reductions(shared)
|
||||
_barrier(loc_ref, gtid)
|
||||
|
||||
|
||||
@pure
|
||||
def get_num_threads():
|
||||
from C import omp_get_num_threads() -> i32
|
||||
|
@ -952,10 +953,41 @@ def _atomic_float_max(a: Ptr[float], b: float):
|
|||
__kmpc_atomic_float8_max(_default_loc(), i32(0), a, b)
|
||||
|
||||
|
||||
def _atomic_float32_add(a: Ptr[float32], b: float32) -> None:
|
||||
from C import __kmpc_atomic_float4_add(Ptr[Ident], i32, Ptr[float32], float32)
|
||||
__kmpc_atomic_float4_add(_default_loc(), i32(0), a, b)
|
||||
|
||||
|
||||
def _atomic_float32_mul(a: Ptr[float32], b: float32):
|
||||
from C import __kmpc_atomic_float4_mul(Ptr[Ident], i32, Ptr[float32], float32)
|
||||
__kmpc_atomic_float4_mul(_default_loc(), i32(0), a, b)
|
||||
|
||||
|
||||
def _atomic_float32_min(a: Ptr[float32], b: float32) -> None:
|
||||
from C import __kmpc_atomic_float4_min(Ptr[Ident], i32, Ptr[float32], float32)
|
||||
__kmpc_atomic_float4_min(_default_loc(), i32(0), a, b)
|
||||
|
||||
|
||||
def _atomic_float32_max(a: Ptr[float32], b: float32) -> None:
|
||||
from C import __kmpc_atomic_float4_max(Ptr[Ident], i32, Ptr[float32], float32)
|
||||
__kmpc_atomic_float4_max(_default_loc(), i32(0), a, b)
|
||||
|
||||
|
||||
def _range_len(start: int, stop: int, step: int):
|
||||
if step > 0 and start < stop:
|
||||
return 1 + (stop - 1 - start) // step
|
||||
elif step < 0 and start > stop:
|
||||
return 1 + (start - 1 - stop) // (-step)
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def for_par(
|
||||
num_threads: int = -1,
|
||||
chunk_size: int = -1,
|
||||
schedule: Static[str] = "static",
|
||||
ordered: Static[int] = False,
|
||||
collapse: Static[int] = 0,
|
||||
gpu: Static[int] = False,
|
||||
):
|
||||
pass
|
||||
|
|
|
@ -77,6 +77,15 @@ class float:
|
|||
return _read(jar, float)
|
||||
|
||||
|
||||
@extend
|
||||
class float32:
|
||||
def __pickle__(self, jar: Jar):
|
||||
_write(jar, self)
|
||||
|
||||
def __unpickle__(jar: Jar) -> float32:
|
||||
return _read(jar, float32)
|
||||
|
||||
|
||||
@extend
|
||||
class bool:
|
||||
def __pickle__(self, jar: Jar):
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue