GPU and other updates (#52)

* Add nvptx pass

* Fix spaces

* Don't change name

* Add runtime support

* Add init call

* Add more runtime functions

* Add launch function

* Add intrinsics

* Fix codegen

* Run GPU pass between general opt passes

* Set data layout

* Create context

* Link libdevice

* Add function remapping

* Fix linkage

* Fix libdevice link

* Fix linking

* Fix personality

* Fix linking

* Fix linking

* Fix linking

* Add internalize pass

* Add more math conversions

* Add more re-mappings

* Fix conversions

* Fix __str__

* Add decorator attribute for any decorator

* Update kernel decorator

* Fix kernel decorator

* Fix kernel decorator

* Fix kernel decorator

* Fix kernel decorator

* Remove old decorator

* Fix pointer calc

* Fix fill-in codegen

* Fix linkage

* Add comment

* Update list conversion

* Add more conversions

* Add dict and set conversions

* Add float32 type to IR/LLVM

* Add float32

* Add float32 stdlib

* Keep required global values in PTX module

* Fix PTX module pruning

* Fix malloc

* Set will-return

* Fix name cleanup

* Fix access

* Fix name cleanup

* Fix function renaming

* Update dimension API

* Fix args

* Clean up API

* Move GPU transformations to end of opt pipeline

* Fix alloc replacements

* Fix naming

* Target PTX 4.2

* Fix global renaming

* Fix early return in static blocks; Add __realized__ function

* Format

* Add __llvm_name__ for functions

* Add vector type to IR

* SIMD support [wip]

* Update kernel naming

* Fix early returns; Fix SIMD calls

* Fix kernel naming

* Fix IR matcher

* Remove module print

* Update realloc

* Add overloads for 32-bit float math ops

* Add gpu.Pointer type for working with raw pointers

* Add float32 conversion

* Add to_gpu and from_gpu

* clang-format

* Add f32 reduction support to OpenMP

* Fix automatic GPU class conversions

* Fix conversion functions

* Fix conversions

* Rename self

* Fix tuple conversion

* Fix conversions

* Fix conversions

* Update PTX filename

* Fix filename

* Add raw function

* Add GPU docs

* Allow nested object conversions

* Add tests (WIP)

* Update SIMD

* Add staticrange and statictuple loop support

* SIMD updates

* Add new Vec constructors

* Fix UInt conversion

* Fix size-0 allocs

* Add more tests

* Add matmul test

* Rename gpu test file

* Add more tests

* Add alloc cache

* Fix object_to_gpu

* Fix frees

* Fix str conversion

* Fix set conversion

* Fix conversions

* Fix class conversion

* Fix str conversion

* Fix byte conversion

* Fix list conversion

* Fix pointer conversions

* Fix conversions

* Fix conversions

* Update tests

* Fix conversions

* Fix tuple conversion

* Fix tuple conversion

* Fix auto conversions

* Fix conversion

* Fix magics

* Update tests

* Support GPU in JIT mode

* Fix GPU+JIT

* Fix kernel filename in JIT mode

* Add __static_print__; Add earlyDefines; Various domination bugfixes; SimplifyContext RAII base handling

* Fix global static handling

* Fix float32 tests

* Fix gpu module

* Support OpenMP "collapse" option

* Add more collapse tests

* Capture generics and statics

* TraitVar handling

* Python exceptions / isinstance [wip; no_ci]

* clang-format

* Add list comparison operators

* Support empty raise in IR

* Add dict 'or' operator

* Fix repr

* Add copy module

* Fix spacing

* Use sm_30

* Python exceptions

* TypeTrait support; Fix defaultDict

* Fix earlyDefines

* Add defaultdict

* clang-format

* Fix invalid canonicalizations

* Fix empty raise

* Fix copyright

* Add Python numerics option

* Support py-numerics in math module

* Update docs

* Add static Python division / modulus

* Add static py numerics tests

* Fix staticrange/tuple; Add KwTuple.__getitem__

* clang-format

* Add gpu parameter to par

* Fix globals

* Don't init loop vars on loop collapse

* Add par-gpu tests

* Update gpu docs

* Fix isinstance check

* Remove invalid test

* Add -libdevice to set custom path [skip ci]

* Add release notes; bump version [skip ci]

* Add libdevice docs [skip ci]

Co-authored-by: Ibrahim Numanagić <ibrahimpasa@gmail.com>
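Taken together, these commits add a CUDA-backed GPU pipeline (an NVPTX pass, libdevice linking, a kernel-launch runtime) plus user-facing features such as @gpu.kernel, GPU-enabled @par loops, float32, and SIMD types. A minimal sketch of the kernel API, assuming the gpu module and index intrinsics follow the GPU docs added in this PR (names like gpu.thread.x and the grid/block launch parameters should be treated as assumptions):

import gpu

@gpu.kernel
def vector_add(a, b, c):
    i = gpu.thread.x + gpu.block.x * gpu.block.dim.x  # assumed index intrinsics
    c[i] = a[i] + b[i]

# assumed launch syntax: grid and block sizes given at the call site
vector_add(a, b, c, grid=16, block=64)

The same backend powers the new gpu flag on @par (see the OpenMP grammar change below).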
A. R. Shajii 2022-09-15 15:40:00 -04:00 committed by GitHub
parent 3379c064eb
commit ebd344f894
115 changed files with 7505 additions and 617 deletions

View File

@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.14)
project(
Codon
VERSION "0.13.0"
VERSION "0.14.0"
HOMEPAGE_URL "https://github.com/exaloop/codon"
DESCRIPTION "high-performance, extensible Python compiler")
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in"
@ -10,6 +10,7 @@ configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in"
"${PROJECT_SOURCE_DIR}/extra/python/config/config.py")
option(CODON_JUPYTER "build Codon Jupyter server" OFF)
option(CODON_GPU "build Codon GPU backend" OFF)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
@ -59,7 +60,8 @@ add_custom_command(
# Codon runtime library
set(CODONRT_FILES codon/runtime/lib.h codon/runtime/lib.cpp
codon/runtime/re.cpp codon/runtime/exc.cpp)
codon/runtime/re.cpp codon/runtime/exc.cpp
codon/runtime/gpu.cpp)
add_library(codonrt SHARED ${CODONRT_FILES})
add_dependencies(codonrt zlibstatic gc backtrace bz2 liblzma re2)
target_include_directories(codonrt PRIVATE ${backtrace_SOURCE_DIR}
@ -90,6 +92,11 @@ if(ASAN)
codonrt PRIVATE "-fno-omit-frame-pointer" "-fsanitize=address"
"-fsanitize-recover=address")
endif()
if(CODON_GPU)
add_compile_definitions(CODON_GPU)
find_package(CUDAToolkit REQUIRED)
target_link_libraries(codonrt PRIVATE CUDA::cudart_static CUDA::cuda_driver)
endif()
add_custom_command(
TARGET codonrt
POST_BUILD
@ -148,6 +155,7 @@ set(CODON_HPPFILES
codon/sir/llvm/coro/CoroInternal.h
codon/sir/llvm/coro/CoroSplit.h
codon/sir/llvm/coro/Coroutines.h
codon/sir/llvm/gpu.h
codon/sir/llvm/llvisitor.h
codon/sir/llvm/llvm.h
codon/sir/llvm/optimize.h
@ -295,6 +303,7 @@ set(CODON_CPPFILES
codon/sir/llvm/coro/CoroFrame.cpp
codon/sir/llvm/coro/CoroSplit.cpp
codon/sir/llvm/coro/Coroutines.cpp
codon/sir/llvm/gpu.cpp
codon/sir/llvm/llvisitor.cpp
codon/sir/llvm/optimize.cpp
codon/sir/module.cpp

View File

@ -72,6 +72,7 @@ void initLogFlags(const llvm::cl::opt<std::string> &log) {
enum BuildKind { LLVM, Bitcode, Object, Executable, Library, Detect };
enum OptMode { Debug, Release };
enum Numerics { C, Python };
} // namespace
int docMode(const std::vector<const char *> &args, const std::string &argv0) {
@ -116,6 +117,14 @@ std::unique_ptr<codon::Compiler> processSource(const std::vector<const char *> &
llvm::cl::list<std::string> plugins("plugin",
llvm::cl::desc("Load specified plugin"));
llvm::cl::opt<std::string> log("log", llvm::cl::desc("Enable given log streams"));
llvm::cl::opt<Numerics> numerics(
"numerics", llvm::cl::desc("numerical semantics"),
llvm::cl::values(
clEnumValN(C, "c", "C semantics: best performance but deviates from Python"),
clEnumValN(Python, "py",
"Python semantics: mirrors Python but might disable optimizations "
"like vectorization")),
llvm::cl::init(C));
llvm::cl::ParseCommandLineOptions(args.size(), args.data());
initLogFlags(log);
@ -148,7 +157,9 @@ std::unique_ptr<codon::Compiler> processSource(const std::vector<const char *> &
const bool isDebug = (optMode == OptMode::Debug);
std::vector<std::string> disabledOptsVec(disabledOpts);
auto compiler = std::make_unique<codon::Compiler>(args[0], isDebug, disabledOptsVec);
auto compiler = std::make_unique<codon::Compiler>(args[0], isDebug, disabledOptsVec,
/*isTest=*/false,
(numerics == Numerics::Python));
compiler->getLLVMVisitor()->setStandalone(standalone);
// load plugins
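The new -numerics flag selects between C and Python semantics for integer division and modulus. The observable difference follows standard C (truncating) versus Python (flooring) behavior; the values below come from those language rules, not from this diff:

print(-7 // 2)  # C semantics: -3 (truncates toward zero); Python semantics: -4 (floors)
print(-7 % 2)   # C semantics: -1; Python semantics: 1 (result takes the divisor's sign)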

View File

@ -29,15 +29,17 @@ ir::transform::PassManager::Init getPassManagerInit(Compiler::Mode mode, bool is
} // namespace
Compiler::Compiler(const std::string &argv0, Compiler::Mode mode,
const std::vector<std::string> &disabledPasses, bool isTest)
: argv0(argv0), debug(mode == Mode::DEBUG), input(),
const std::vector<std::string> &disabledPasses, bool isTest,
bool pyNumerics)
: argv0(argv0), debug(mode == Mode::DEBUG), pyNumerics(pyNumerics), input(),
plm(std::make_unique<PluginManager>()),
cache(std::make_unique<ast::Cache>(argv0)),
module(std::make_unique<ir::Module>()),
pm(std::make_unique<ir::transform::PassManager>(getPassManagerInit(mode, isTest),
disabledPasses)),
disabledPasses, pyNumerics)),
llvisitor(std::make_unique<ir::LLVMVisitor>()) {
cache->module = module.get();
cache->pythonCompat = pyNumerics;
module->setCache(cache.get());
llvisitor->setDebug(debug);
llvisitor->setPluginManager(plm.get());
@ -77,8 +79,9 @@ Compiler::parse(bool isCode, const std::string &file, const std::string &code,
Timer t2("simplify");
t2.logged = true;
auto transformed = ast::SimplifyVisitor::apply(cache.get(), std::move(codeStmt),
abspath, defines, (testFlags > 1));
auto transformed =
ast::SimplifyVisitor::apply(cache.get(), std::move(codeStmt), abspath, defines,
getEarlyDefines(), (testFlags > 1));
LOG_TIME("[T] parse = {:.1f}", totalPeg);
LOG_TIME("[T] simplify = {:.1f}", t2.elapsed() - totalPeg);
@ -171,4 +174,11 @@ llvm::Expected<std::string> Compiler::docgen(const std::vector<std::string> &fil
}
}
std::unordered_map<std::string, std::string> Compiler::getEarlyDefines() {
std::unordered_map<std::string, std::string> earlyDefines;
earlyDefines.emplace("__debug__", debug ? "1" : "0");
earlyDefines.emplace("__py_numerics__", pyNumerics ? "1" : "0");
return earlyDefines;
}
} // namespace codon
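getEarlyDefines feeds __debug__ and __py_numerics__ into the simplifier as compile-time static integers (see the SimplifyVisitor::apply change further below), so code can branch on them statically. A sketch, assuming they behave like other Static[int] defines:

if __debug__:            # Static[int]; the branch is resolved at compile time
    print("debug build")
if __py_numerics__:
    print("Python numerics enabled")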

View File

@ -25,6 +25,7 @@ public:
private:
std::string argv0;
bool debug;
bool pyNumerics;
std::string input;
std::unique_ptr<PluginManager> plm;
std::unique_ptr<ast::Cache> cache;
@ -38,12 +39,14 @@ private:
public:
Compiler(const std::string &argv0, Mode mode,
const std::vector<std::string> &disabledPasses = {}, bool isTest = false);
const std::vector<std::string> &disabledPasses = {}, bool isTest = false,
bool pyNumerics = false);
explicit Compiler(const std::string &argv0, bool debug = false,
const std::vector<std::string> &disabledPasses = {},
bool isTest = false)
: Compiler(argv0, debug ? Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest) {}
bool isTest = false, bool pyNumerics = false)
: Compiler(argv0, debug ? Mode::DEBUG : Mode::RELEASE, disabledPasses, isTest,
pyNumerics) {}
std::string getInput() const { return input; }
PluginManager *getPluginManager() const { return plm.get(); }
@ -62,6 +65,8 @@ public:
const std::unordered_map<std::string, std::string> &defines = {});
llvm::Error compile();
llvm::Expected<std::string> docgen(const std::vector<std::string> &files);
std::unordered_map<std::string, std::string> getEarlyDefines();
};
} // namespace codon

View File

@ -37,8 +37,9 @@ llvm::Error JIT::init() {
auto *pm = compiler->getPassManager();
auto *llvisitor = compiler->getLLVMVisitor();
auto transformed = ast::SimplifyVisitor::apply(
cache, std::make_shared<ast::SuiteStmt>(), JIT_FILENAME, {});
auto transformed =
ast::SimplifyVisitor::apply(cache, std::make_shared<ast::SuiteStmt>(),
JIT_FILENAME, {}, compiler->getEarlyDefines());
auto typechecked = ast::TypecheckVisitor::apply(cache, std::move(transformed));
ast::TranslateVisitor::apply(cache, std::move(typechecked));

View File

@ -5,6 +5,7 @@
#include <vector>
#include "codon/parser/ast.h"
#include "codon/parser/cache.h"
#include "codon/parser/visitors/visitor.h"
#define ACCEPT_IMPL(T, X) \
@ -17,7 +18,7 @@ namespace codon::ast {
Expr::Expr()
: type(nullptr), isTypeExpr(false), staticValue(StaticValue::NOT_STATIC),
done(false), attributes(0) {}
done(false), attributes(0), origExpr(nullptr) {}
void Expr::validate() const {}
types::TypePtr Expr::getType() const { return type; }
void Expr::setType(types::TypePtr t) { this->type = std::move(t); }
@ -65,7 +66,8 @@ Param::Param(std::string name, ExprPtr type, ExprPtr defaultValue, int status)
: name(std::move(name)), type(std::move(type)),
defaultValue(std::move(defaultValue)) {
if (status == 0 && this->type &&
(this->type->isId("type") || this->type->isId("TypeVar") ||
(this->type->isId("type") || this->type->isId(TYPE_TYPEVAR) ||
(this->type->getIndex() && this->type->getIndex()->expr->isId(TYPE_TYPEVAR)) ||
getStaticGeneric(this->type.get())))
this->status = Generic;
else

View File

@ -80,6 +80,9 @@ struct Expr : public codon::SrcObject {
/// Set of attributes.
int attributes;
/// Original (pre-transformation) expression
std::shared_ptr<Expr> origExpr;
public:
Expr();
Expr(const Expr &expr) = default;

View File

@ -339,11 +339,15 @@ std::string FunctionStmt::toString(int indent) const {
std::vector<std::string> as;
for (auto &a : args)
as.push_back(a.toString());
std::vector<std::string> attr;
std::vector<std::string> dec, attr;
for (auto &a : decorators)
attr.push_back(format("(dec {})", a->toString()));
return format("(fn '{} ({}){}{}{}{})", name, join(as, " "),
if (a)
dec.push_back(format("(dec {})", a->toString()));
for (auto &a : attributes.customAttr)
attr.push_back(format("'{}'", a));
return format("(fn '{} ({}){}{}{}{}{})", name, join(as, " "),
ret ? " #:ret " + ret->toString() : "",
dec.empty() ? "" : format(" (dec {})", join(dec, " ")),
attr.empty() ? "" : format(" (attr {})", join(attr, " ")), pad,
suite ? suite->toString(indent >= 0 ? indent + INDENT_SIZE : -1)
: "(suite)");
@ -496,7 +500,8 @@ void ClassStmt::parseDecorators() {
{"new", true}, {"repr", false}, {"hash", false}, {"eq", false},
{"ne", false}, {"lt", false}, {"le", false}, {"gt", false},
{"ge", false}, {"pickle", true}, {"unpickle", true}, {"to_py", false},
{"from_py", false}, {"iter", false}, {"getitem", false}, {"len", false}};
{"from_py", false}, {"iter", false}, {"getitem", false}, {"len", false},
{"to_gpu", false}, {"from_gpu", false}, {"from_gpu_new", false}};
for (auto &d : decorators) {
if (d->isId("deduce")) {
@ -531,6 +536,9 @@ void ClassStmt::parseDecorators() {
tupleMagics["pickle"] = tupleMagics["unpickle"] = val;
} else if (a.name == "python") {
tupleMagics["to_py"] = tupleMagics["from_py"] = val;
} else if (a.name == "gpu") {
tupleMagics["to_gpu"] = tupleMagics["from_gpu"] =
tupleMagics["from_gpu_new"] = val;
} else if (a.name == "container") {
tupleMagics["iter"] = tupleMagics["getitem"] = val;
} else {

View File

@ -238,6 +238,9 @@ struct WhileStmt : public Stmt {
StmtPtr suite;
/// nullptr if there is no else suite.
StmtPtr elseSuite;
/// Set if a while loop is used to emulate a goto statement
/// (as `while gotoVar: ...`).
std::string gotoVar = "";
WhileStmt(ExprPtr cond, StmtPtr suite, StmtPtr elseSuite = nullptr);
WhileStmt(const WhileStmt &stmt);

View File

@ -312,10 +312,9 @@ std::string ClassType::debugString(bool debug) const {
if (!a.name.empty())
gs.push_back(a.type->debugString(debug));
if (debug && !hiddenGenerics.empty()) {
gs.emplace_back("//");
for (auto &a : hiddenGenerics)
if (!a.name.empty())
gs.push_back(a.type->debugString(debug));
gs.push_back("-" + a.type->debugString(debug));
}
// Special formatting for Functions and Tuples
auto n = niceName;
@ -829,4 +828,21 @@ std::string CallableTrait::debugString(bool debug) const {
return fmt::format("Callable[{}]", join(gs, ","));
}
TypeTrait::TypeTrait(TypePtr typ) : type(std::move(typ)) {}
int TypeTrait::unify(Type *typ, Unification *us) { return typ->unify(type.get(), us); }
TypePtr TypeTrait::generalize(int atLevel) {
auto c = std::make_shared<TypeTrait>(type->generalize(atLevel));
c->setSrcInfo(getSrcInfo());
return c;
}
TypePtr TypeTrait::instantiate(int atLevel, int *unboundCount,
std::unordered_map<int, TypePtr> *cache) {
auto c = std::make_shared<TypeTrait>(type->instantiate(atLevel, unboundCount, cache));
c->setSrcInfo(getSrcInfo());
return c;
}
std::string TypeTrait::debugString(bool debug) const {
return fmt::format("Trait[{}]", type->debugString(debug));
}
} // namespace codon::ast::types

View File

@ -422,5 +422,17 @@ public:
std::string debugString(bool debug) const override;
};
struct TypeTrait : public Trait {
TypePtr type;
public:
explicit TypeTrait(TypePtr type);
int unify(Type *typ, Unification *undo) override;
TypePtr generalize(int atLevel) override;
TypePtr instantiate(int atLevel, int *unboundCount,
std::unordered_map<int, TypePtr> *cache) override;
std::string debugString(bool debug) const override;
};
} // namespace types
} // namespace codon::ast

View File

@ -31,6 +31,12 @@ std::string Cache::rev(const std::string &s) {
return "";
}
void Cache::addGlobal(const std::string &name, ir::Var *var) {
if (!in(globals, name)) {
globals[name] = var;
}
}
SrcInfo Cache::generateSrcInfo() {
return {FILE_GENERATED, generatedSrcInfoCount, generatedSrcInfoCount++, 0};
}

View File

@ -19,6 +19,7 @@
#define TYPE_TUPLE "Tuple.N"
#define TYPE_KWTUPLE "KwTuple.N"
#define TYPE_TYPEVAR "TypeVar"
#define TYPE_CALLABLE "Callable"
#define TYPE_PARTIAL "Partial.N"
#define TYPE_OPTIONAL "Optional"
@ -28,6 +29,7 @@
#define MAX_INT_WIDTH 10000
#define MAX_REALIZATION_DEPTH 200
#define MAX_STATIC_ITER 1024
namespace codon::ast {
@ -207,6 +209,9 @@ struct Cache : public std::enable_shared_from_this<Cache> {
std::unordered_map<std::string, int> generatedTuples;
std::vector<exc::ParserException> errors;
/// Set if Codon operates in Python compatibility mode (e.g., with Python numerics)
bool pythonCompat = false;
public:
explicit Cache(std::string argv0 = "");
@ -220,6 +225,8 @@ public:
SrcInfo generateSrcInfo();
/// Get file contents at the given location.
std::string getContent(const SrcInfo &info);
/// Register a global identifier.
void addGlobal(const std::string &name, ir::Var *var = nullptr);
/// Realization API.

View File

@ -32,6 +32,12 @@ clause <-
/ "ordered" {
return vector<CallExpr::Arg>{{"ordered", make_shared<BoolExpr>(true)}};
}
/ "collapse" {
return vector<CallExpr::Arg>{{"collapse", make_shared<IntExpr>(ac<int>(V0))}};
}
/ "gpu" {
return vector<CallExpr::Arg>{{"gpu", make_shared<BoolExpr>(true)}};
}
schedule_kind <- ("static" / "dynamic" / "guided" / "auto" / "runtime") {
return VS.token_to_string();
}
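These grammar rules accept the new collapse and gpu clauses on @par annotations. A sketch of the surface syntax (loop bodies are illustrative):

@par(collapse=2)   # fuse the two nested loops into a single parallel loop
for i in range(n):
    for j in range(m):
        out[i * m + j] = a[i] * b[j]

@par(gpu=True)     # offload the loop to the GPU backend
for i in range(n):
    x[i] *= 2.0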

View File

@ -34,6 +34,7 @@ std::shared_ptr<peg::Grammar> initParser() {
x.second.accept(v);
}
(*g)["program"].enablePackratParsing = true;
(*g)["fstring"].enablePackratParsing = true;
for (auto &rule : std::vector<std::string>{
"arguments", "slices", "genexp", "parentheses", "star_parens", "generics",
"with_parens_item", "params", "from_as_parens", "from_params"}) {

View File

@ -84,6 +84,7 @@ void SimplifyVisitor::visit(IdExpr *expr) {
/// @example
/// `a.b.c` -> canonical name of `c` in `a.b` if `a.b` is an import
/// `a.B.c` -> canonical name of `c` in class `a.B`
/// `python.foo` -> internal.python._get_identifier("foo")
/// Other cases are handled during the type checking.
void SimplifyVisitor::visit(DotExpr *expr) {
// First flatten the imports:
@ -99,7 +100,11 @@ void SimplifyVisitor::visit(DotExpr *expr) {
std::reverse(chain.begin(), chain.end());
auto p = getImport(chain);
if (p.second->getModule() == ctx->getModule() && p.first == 1) {
if (p.second->getModule() == "std.python") {
resultExpr = transform(N<CallExpr>(
N<DotExpr>(N<DotExpr>(N<IdExpr>("internal"), "python"), "_get_identifier"),
N<StringExpr>(chain[p.first++])));
} else if (p.second->getModule() == ctx->getModule() && p.first == 1) {
resultExpr = transform(N<IdExpr>(chain[0]), true);
} else {
resultExpr = N<IdExpr>(p.second->canonicalName);
@ -120,57 +125,74 @@ void SimplifyVisitor::visit(DotExpr *expr) {
bool SimplifyVisitor::checkCapture(const SimplifyContext::Item &val) {
if (!ctx->isOuter(val))
return false;
if ((val->isType() && !val->isGeneric()) || val->isFunc())
return false;
// Ensure that outer variables can be captured (i.e., do not cross no-capture
// boundary). Example:
// def foo():
// x = 1
// class T: # <- boundary (class methods cannot capture locals)
// def bar():
// class T: # <- boundary (classes cannot capture locals)
// t: int = x # x cannot be accessed
// def bar(): # <- another boundary
// # (class methods cannot capture locals except class generics)
// print(x) # x cannot be accessed
bool crossCaptureBoundary = false;
bool localGeneric = val->isGeneric() && val->getBaseName() == ctx->getBaseName();
bool parentClassGeneric =
val->isGeneric() && !ctx->getBase()->isType() &&
(ctx->bases.size() > 1 && ctx->bases[ctx->bases.size() - 2].isType() &&
ctx->bases[ctx->bases.size() - 2].name == val->getBaseName());
auto i = ctx->bases.size();
for (; i-- > 0;) {
if (ctx->bases[i].name == val->getBaseName())
break;
if (!ctx->bases[i].captures)
if (!localGeneric && !parentClassGeneric && !ctx->bases[i].captures)
crossCaptureBoundary = true;
}
seqassert(i < ctx->bases.size(), "invalid base for '{}'", val->canonicalName);
// Disallow outer generics except for class generics in methods
if (val->isGeneric() && !(ctx->bases[i].isType() && i + 2 == ctx->bases.size()))
error("cannot access nonlocal variable '{}'", ctx->cache->rev(val->canonicalName));
// Mark methods (class functions that access class generics)
if (val->isGeneric() && ctx->bases[i].isType() && i + 2 == ctx->bases.size() &&
ctx->getBase()->attributes)
if (parentClassGeneric)
ctx->getBase()->attributes->set(Attr::Method);
// Check if a real variable (not a static) is defined outside the current scope
if (!val->isVar() || val->isGeneric())
// Ignore generics
if (parentClassGeneric || localGeneric)
return false;
// Case: a global variable that has not been marked with `global` statement
if (val->getBaseName().empty()) { /// TODO: use isGlobal instead?
if (val->isVar() && val->getBaseName().empty()) {
val->noShadow = true;
if (val->scope.size() == 1 && !in(ctx->cache->globals, val->canonicalName))
ctx->cache->globals[val->canonicalName] = nullptr;
if (val->scope.size() == 1)
ctx->cache->addGlobal(val->canonicalName);
return false;
}
// Check if a real variable (not a static) is defined outside the current scope
if (crossCaptureBoundary)
error("cannot access nonlocal variable '{}'", ctx->cache->rev(val->canonicalName));
// Case: a nonlocal variable that has not been marked with `nonlocal` statement
// and capturing is enabled
auto captures = ctx->getBase()->captures;
if (!crossCaptureBoundary && captures && !in(*captures, val->canonicalName)) {
if (captures && !in(*captures, val->canonicalName)) {
// Captures are transformed to function arguments; generate new name for that
// argument
auto newName = (*captures)[val->canonicalName] =
ctx->generateCanonicalName(val->canonicalName);
ExprPtr typ = nullptr;
if (val->isType())
typ = N<IdExpr>("type");
if (auto st = val->isStatic())
typ = N<IndexExpr>(N<IdExpr>("Static"),
N<IdExpr>(st == StaticValue::INT ? "int" : "str"));
auto [newName, _] = (*captures)[val->canonicalName] = {
ctx->generateCanonicalName(val->canonicalName), typ};
ctx->cache->reverseIdentifierLookup[newName] = newName;
// Add newly generated argument to the context
auto newVal =
ctx->addVar(ctx->cache->rev(val->canonicalName), newName, getSrcInfo());
std::shared_ptr<SimplifyItem> newVal = nullptr;
if (val->isType())
newVal = ctx->addType(ctx->cache->rev(val->canonicalName), newName, getSrcInfo());
else
newVal = ctx->addVar(ctx->cache->rev(val->canonicalName), newName, getSrcInfo());
newVal->baseName = ctx->getBaseName();
newVal->noShadow = true;
return true;
@ -206,12 +228,20 @@ SimplifyVisitor::getImport(const std::vector<std::string> &chain) {
size_t itemEnd = 0;
auto fctx = importName.empty() ? ctx : ctx->cache->imports[importName].ctx;
for (auto i = chain.size(); i-- > importEnd;) {
if (fctx->getModule() == "std.python" && importEnd < chain.size()) {
// Special case: importing from Python.
// Fake SimplifyItem that indicates std.python access
val = std::make_shared<SimplifyItem>(SimplifyItem::Var, "", "",
fctx->getModule(), std::vector<int>{});
return {importEnd, val};
} else {
val = fctx->find(join(chain, ".", importEnd, i + 1));
if (val && (importName.empty() || val->isType() || !val->isConditional())) {
itemName = val->canonicalName, itemEnd = i + 1;
break;
}
}
}
if (itemName.empty() && importName.empty())
error("identifier '{}' not found", chain[importEnd]);
if (itemName.empty())
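Per the DotExpr comment above, python.foo now lowers to internal.python._get_identifier("foo"). A sketch of the user-facing form, assuming the std.python module is reachable as python (the Python-side name here is hypothetical):

sys = python.sys      # fetched at runtime via _get_identifier("sys")
print(sys.version)    # further attribute access proceeds through pyobj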

View File

@ -147,14 +147,15 @@ StmtPtr SimplifyVisitor::transformAssignment(ExprPtr lhs, ExprPtr rhs, ExprPtr t
if (rhs && rhs->isType()) {
ctx->addType(e->value, canonical, lhs->getSrcInfo());
} else {
ctx->addVar(e->value, canonical, lhs->getSrcInfo());
auto val = ctx->addVar(e->value, canonical, lhs->getSrcInfo());
if (auto st = getStaticGeneric(type.get()))
val->staticType = st;
}
// Register all toplevel variables as global in JIT mode
bool isGlobal = (ctx->cache->isJit && ctx->isGlobal()) || (canonical == VAR_ARGV);
if (isGlobal && !in(ctx->cache->globals, canonical)) {
ctx->cache->globals[canonical] = nullptr;
}
if (isGlobal)
ctx->cache->addGlobal(canonical);
return assign;
}

View File

@ -50,9 +50,13 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
argsToParse = astIter->second.ast->args;
}
std::vector<StmtPtr> clsStmts; // Will be filled later!
std::vector<StmtPtr> varStmts; // Will be filled later!
std::vector<StmtPtr> fnStmts; // Will be filled later!
std::vector<SimplifyContext::Item> addLater;
{
// Add the class base
ctx->bases.emplace_back(SimplifyContext::Base(canonicalName));
ctx->addBlock();
SimplifyContext::BaseGuard br(ctx.get(), canonicalName);
// Parse and add class generics
std::vector<Param> args;
@ -70,10 +74,13 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
varName = a.name, genName = ctx->cache->rev(a.name);
else
varName = ctx->generateCanonicalName(a.name), genName = a.name;
if (getStaticGeneric(a.type.get()))
ctx->addVar(genName, varName, a.type->getSrcInfo())->generic = true;
else
if (auto st = getStaticGeneric(a.type.get())) {
auto val = ctx->addVar(genName, varName, a.type->getSrcInfo());
val->generic = true;
val->staticType = st;
} else {
ctx->addType(genName, varName, a.type->getSrcInfo())->generic = true;
}
args.emplace_back(Param{varName, transformType(clone(a.type), false),
transformType(clone(a.defaultValue), false), a.status});
}
@ -94,9 +101,6 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
// A ClassStmt will be separated into class variable assignments, method-free
// ClassStmts (that include nested classes) and method FunctionStmts
std::vector<StmtPtr> clsStmts; // Will be filled later!
std::vector<StmtPtr> varStmts; // Will be filled later!
std::vector<StmtPtr> fnStmts; // Will be filled later!
transformNestedClasses(stmt, clsStmts, varStmts, fnStmts);
// Collect class fields
@ -109,8 +113,7 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
// Handle class variables. Transform them later to allow self-references
auto name = format("{}.{}", canonicalName, a.name);
preamble->push_back(N<AssignStmt>(N<IdExpr>(name), nullptr, nullptr));
if (!in(ctx->cache->globals, name))
ctx->cache->globals[name] = nullptr;
ctx->cache->addGlobal(name);
auto assign = N<AssignStmt>(N<IdExpr>(name), a.defaultValue,
a.type ? a.type->getIndex()->index : nullptr);
assign->setUpdate();
@ -205,14 +208,12 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
// After popping context block, record types and nested classes will disappear.
// Store their references and re-add them to the context after popping
std::vector<SimplifyContext::Item> addLater;
addLater.reserve(clsStmts.size() + 1);
for (auto &c : clsStmts)
addLater.push_back(ctx->find(c->getClass()->name));
if (stmt->attributes.has(Attr::Tuple))
addLater.push_back(ctx->forceFind(name));
ctx->bases.pop_back();
ctx->popBlock();
}
for (auto &i : addLater)
ctx->add(ctx->cache->rev(i->canonicalName), i);
@ -279,12 +280,15 @@ SimplifyVisitor::parseBaseClasses(const std::vector<ExprPtr> &baseClasses,
args.emplace_back(a);
}
if (a.status != Param::Normal) {
if (getStaticGeneric(a.type.get()))
ctx->addVar(a.name, a.name, a.type->getSrcInfo())->generic = true;
else
if (auto st = getStaticGeneric(a.type.get())) {
auto val = ctx->addVar(a.name, a.name, a.type->getSrcInfo());
val->generic = true;
val->staticType = st;
} else {
ctx->addType(a.name, a.name, a.type->getSrcInfo())->generic = true;
}
}
}
if (si != subs.size())
error(cls.get(), "wrong number of generics");
}
@ -632,6 +636,29 @@ StmtPtr SimplifyVisitor::codegenMagic(const std::string &op, const ExprPtr &typE
N<DotExpr>(clone(args[i].type), "__from_py__"),
N<CallExpr>(N<DotExpr>(I("pyobj"), "_tuple_get"), I("src"), N<IntExpr>(i))));
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(typExpr->clone(), ar)));
} else if (op == "to_gpu") {
// def __to_gpu__(self: T, cache) -> T:
// return __internal__.class_to_gpu(self, cache)
fargs.emplace_back(Param{"self", typExpr->clone()});
fargs.emplace_back(Param{"cache"});
ret = typExpr->clone();
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(
N<DotExpr>(I("__internal__"), "class_to_gpu"), I("self"), I("cache"))));
} else if (op == "from_gpu") {
// def __from_gpu__(self: T, other: T) -> None:
// __internal__.class_from_gpu(self, other)
fargs.emplace_back(Param{"self", typExpr->clone()});
fargs.emplace_back(Param{"other", typExpr->clone()});
ret = I("NoneType");
stmts.emplace_back(N<ExprStmt>(N<CallExpr>(
N<DotExpr>(I("__internal__"), "class_from_gpu"), I("self"), I("other"))));
} else if (op == "from_gpu_new") {
// def __from_gpu_new__(other: T) -> T:
// return __internal__.class_from_gpu_new(other)
fargs.emplace_back(Param{"other", typExpr->clone()});
ret = typExpr->clone();
stmts.emplace_back(N<ReturnStmt>(
N<CallExpr>(N<DotExpr>(I("__internal__"), "class_from_gpu_new"), I("other"))));
} else if (op == "repr") {
// def __repr__(self: T) -> str:
// a = __array__[str](N) (number of args)
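The generated magics match the Codon signatures in the comments above, so user classes can move across the host/device boundary automatically. A sketch, assuming a reference-class kernel argument triggers the conversions:

import gpu

class Particle:
    x: float
    v: float

@gpu.kernel
def step(p: Particle):   # launch converts p via __to_gpu__ and writes back via __from_gpu__
    p.x += p.v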

View File

@ -13,7 +13,7 @@ namespace codon::ast {
/// The rest will be handled during the type-checking stage.
void SimplifyVisitor::visit(TupleExpr *expr) {
for (auto &i : expr->items)
transform(i);
transform(i, true); // types needed for some constructs (e.g., isinstance)
}
/// Transform a list `[a1, ..., aN]` to the corresponding statement expression.

View File

@ -160,8 +160,8 @@ SimplifyContext::Item SimplifyContext::findDominatingBinding(const std::string &
std::make_unique<BoolExpr>(false), nullptr));
// Reached the toplevel? Register the binding as global.
if (prefix == 1) {
cache->globals[canonicalName] = nullptr;
cache->globals[fmt::format("{}.__used__", canonicalName)] = nullptr;
cache->addGlobal(canonicalName);
cache->addGlobal(fmt::format("{}.__used__", canonicalName));
}
hasUsed = true;
}

View File

@ -40,6 +40,8 @@ struct SimplifyItem : public SrcObject {
bool noShadow = false;
/// Set if an identifier is a class or a function generic
bool generic = false;
/// Set if an identifier is a static variable.
char staticType = 0;
public:
SimplifyItem(Kind kind, std::string baseName, std::string canonicalName,
@ -61,6 +63,7 @@ public:
/// (i.e., a block that might not be executed during the runtime)
bool isConditional() const { return scope.size() > 1; }
bool isGeneric() const { return generic; }
char isStatic() const { return staticType; }
};
/** Context class that tracks identifiers during the simplification. **/
@ -99,8 +102,9 @@ struct SimplifyContext : public Context<SimplifyItem> {
/// Map of captured identifiers (i.e., identifiers not defined in a function).
/// Captured (canonical) identifiers are mapped to the new canonical names
/// (representing the canonical function argument names that are appended to the
/// function after processing).
std::unordered_map<std::string, std::string> *captures;
/// function after processing) and their types (indicating if they are a type, a
/// static or a variable).
std::unordered_map<std::string, std::pair<std::string, ExprPtr>> *captures;
/// A stack of nested loops enclosing the current statement used for transforming
/// "break" statement in loop-else constructs. Each loop is defined by a "break"
@ -123,6 +127,18 @@ struct SimplifyContext : public Context<SimplifyItem> {
/// Current base stack (the last enclosing base is the last base in the stack).
std::vector<Base> bases;
struct BaseGuard {
SimplifyContext *holder;
BaseGuard(SimplifyContext *holder, const std::string &name) : holder(holder) {
holder->bases.emplace_back(Base(name));
holder->addBlock();
}
~BaseGuard() {
holder->bases.pop_back();
holder->popBlock();
}
};
/// Set of seen global identifiers used to prevent later creation of local variables
/// with the same name.
std::unordered_map<std::string, std::unordered_map<std::string, ExprPtr>>

View File

@ -42,7 +42,7 @@ void SimplifyVisitor::visit(TryStmt *stmt) {
c.var = ctx->generateCanonicalName(c.var);
ctx->addVar(ctx->cache->rev(c.var), c.var, c.suite->getSrcInfo());
}
transformType(c.exc);
transform(c.exc, true);
transformConditionalScope(c.suite);
ctx->leaveConditionalBlock();
}

View File

@ -74,8 +74,7 @@ void SimplifyVisitor::visit(GlobalStmt *stmt) {
stmt->var);
// Register as global if needed
if (!in(ctx->cache->globals, val->canonicalName))
ctx->cache->globals[val->canonicalName] = nullptr;
ctx->cache->addGlobal(val->canonicalName);
val = ctx->addVar(stmt->var, val->canonicalName, stmt->getSrcInfo());
val->baseName = ctx->getBaseName();
@ -111,8 +110,10 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
// Parse attributes
for (auto i = stmt->decorators.size(); i-- > 0;) {
if (auto n = isAttribute(stmt->decorators[i])) {
stmt->attributes.set(*n);
auto [isAttr, attrName] = getDecorator(stmt->decorators[i]);
if (!attrName.empty()) {
stmt->attributes.set(attrName);
if (isAttr)
stmt->decorators[i] = nullptr; // remove it from further consideration
}
}
@ -153,13 +154,16 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
ctx->addAlwaysVisible(funcVal);
}
std::vector<Param> args;
StmtPtr suite = nullptr;
ExprPtr ret = nullptr;
std::unordered_map<std::string, std::pair<std::string, ExprPtr>> captures;
{
// Set up the base
ctx->bases.emplace_back(SimplifyContext::Base{canonicalName});
ctx->addBlock();
SimplifyContext::BaseGuard br(ctx.get(), canonicalName);
ctx->getBase()->attributes = &(stmt->attributes);
// Parse arguments and add them to the context
std::vector<Param> args;
for (auto &a : stmt->args) {
std::string varName = a.name;
int stars = trimStars(varName);
@ -178,7 +182,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
if (a.type->getIndex() && a.type->getIndex()->expr->isId(TYPE_CALLABLE))
defaultValue = N<CallExpr>(N<IdExpr>("NoneType"));
// Special case: `arg: type = None` -> `arg: type = NoneType`
if (a.type->isId("type") || a.type->isId("TypeVar"))
if (a.type->isId("type") || a.type->isId(TYPE_TYPEVAR))
defaultValue = N<IdExpr>("NoneType");
}
/// TODO: Uncomment for Python-style defaults
@ -194,12 +198,16 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
// Add generics to the context
if (a.status != Param::Normal) {
if (getStaticGeneric(a.type.get()))
ctx->addVar(varName, name, stmt->getSrcInfo())->generic = true;
else
if (auto st = getStaticGeneric(a.type.get())) {
auto val = ctx->addVar(varName, name, stmt->getSrcInfo());
val->generic = true;
val->staticType = st;
} else {
ctx->addType(varName, name, stmt->getSrcInfo())->generic = true;
}
}
}
// Parse arguments to the context. Needs to be done after adding generics
// to support cases like `foo(a: T, T: type)`
for (auto &a : args) {
@ -217,11 +225,9 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
}
// Parse the return type
auto ret = transformType(stmt->ret, false);
ret = transformType(stmt->ret, false);
// Parse function body
StmtPtr suite = nullptr;
std::unordered_map<std::string, std::string> captures;
if (!stmt->attributes.has(Attr::Internal) && !stmt->attributes.has(Attr::C)) {
if (stmt->attributes.has(Attr::LLVM)) {
suite = transformLLVMDefinition(stmt->suite->firstInBlock());
@ -233,9 +239,7 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
suite = SimplifyVisitor(ctx, preamble).transformConditionalScope(stmt->suite);
}
}
ctx->bases.pop_back();
ctx->popBlock();
}
stmt->attributes.module =
format("{}{}", ctx->moduleName.status == ImportFile::STDLIB ? "std::" : "::",
ctx->moduleName.module);
@ -259,8 +263,8 @@ void SimplifyVisitor::visit(FunctionStmt *stmt) {
args.pop_back();
}
for (auto &c : captures) {
args.emplace_back(Param{c.second, nullptr, nullptr});
partialArgs.push_back({c.second, N<IdExpr>(ctx->cache->rev(c.first))});
args.emplace_back(Param{c.second.first, c.second.second, nullptr});
partialArgs.push_back({c.second.first, N<IdExpr>(ctx->cache->rev(c.first))});
}
if (!kw.name.empty())
args.push_back(kw);
@ -417,20 +421,22 @@ StmtPtr SimplifyVisitor::transformLLVMDefinition(Stmt *codeStmt) {
return N<SuiteStmt>(items);
}
/// Check if a decorator is actually an attribute (a function with `@__attribute__`)
std::string *SimplifyVisitor::isAttribute(const ExprPtr &e) {
/// Fetch a decorator's canonical name. The first pair member indicates if a decorator is
/// actually an attribute (a function with `@__attribute__`).
std::pair<bool, std::string> SimplifyVisitor::getDecorator(const ExprPtr &e) {
auto dt = transform(clone(e));
if (dt && dt->getId()) {
auto ci = ctx->find(dt->getId()->value);
auto id = dt->getCall() ? dt->getCall()->expr : dt;
if (id && id->getId()) {
auto ci = ctx->find(id->getId()->value);
if (ci && ci->isFunc()) {
if (ctx->cache->overloads[ci->canonicalName].size() == 1)
if (ctx->cache->functions[ctx->cache->overloads[ci->canonicalName][0].name]
.ast->attributes.isAttribute) {
return &(ci->canonicalName);
if (ctx->cache->overloads[ci->canonicalName].size() == 1) {
return {ctx->cache->functions[ctx->cache->overloads[ci->canonicalName][0].name]
.ast->attributes.isAttribute,
ci->canonicalName};
}
}
}
return nullptr;
return {false, ""};
}
} // namespace codon::ast

View File

@ -319,8 +319,7 @@ StmtPtr SimplifyVisitor::transformNewImport(const ImportFile &file) {
// `import_[I]_done = False` (set to True upon successful import)
preamble->push_back(N<AssignStmt>(N<IdExpr>(importDoneVar = importVar + "_done"),
N<BoolExpr>(false)));
if (!in(ctx->cache->globals, importDoneVar))
ctx->cache->globals[importDoneVar] = nullptr;
ctx->cache->addGlobal(importDoneVar);
// Wrap all imported top-level statements into a function.
// Make sure to register the global variables and set their assignments as updates.
@ -332,11 +331,8 @@ StmtPtr SimplifyVisitor::transformNewImport(const ImportFile &file) {
if (!a->isUpdate() && a->lhs->getId()) {
// Global `a = ...`
auto val = ictx->forceFind(a->lhs->getId()->value);
if (val->isVar() && val->isGlobal() && !getStaticGeneric(a->type.get())) {
// Register global
if (!in(ctx->cache->globals, val->canonicalName))
ctx->cache->globals[val->canonicalName] = nullptr;
}
if (val->isVar() && val->isGlobal() && !getStaticGeneric(a->type.get()))
ctx->cache->addGlobal(val->canonicalName);
}
}
stmts.push_back(s);

View File

@ -97,7 +97,7 @@ void SimplifyVisitor::visit(ForStmt *stmt) {
ctx->addVar(i->value, varName = ctx->generateCanonicalName(i->value),
stmt->var->getSrcInfo());
transform(stmt->var);
transform(stmt->suite);
stmt->suite = transform(N<SuiteStmt>(stmt->suite));
} else {
varName = ctx->cache->getTemporaryVar("for");
ctx->addVar(varName, varName, stmt->var->getSrcInfo());
@ -109,7 +109,7 @@ void SimplifyVisitor::visit(ForStmt *stmt) {
stmts.push_back(stmt->suite);
stmt->suite = transform(N<SuiteStmt>(stmts));
}
ctx->leaveConditionalBlock();
ctx->leaveConditionalBlock(&(stmt->suite->getSuite()->stmts));
// Dominate loop variables
for (auto &var : ctx->getBase()->getLoop()->seenVars)
ctx->findDominatingBinding(var);

View File

@ -20,12 +20,12 @@ using namespace types;
/// @param cache Pointer to the shared cache ( @c Cache )
/// @param file Filename to be used for error reporting
/// @param barebones Use the bare-bones standard library for faster testing
/// @param defines User-defined static values (typically passed as `codon run -DX=Y
/// ...`).
/// @param defines User-defined static values (typically passed as `codon run -DX=Y`).
/// Each value is passed as a string.
StmtPtr
SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &file,
const std::unordered_map<std::string, std::string> &defines,
const std::unordered_map<std::string, std::string> &earlyDefines,
bool barebones) {
auto preamble = std::make_shared<std::vector<StmtPtr>>();
seqassertn(cache->module, "cache's module is not set");
@ -53,6 +53,20 @@ SimplifyVisitor::apply(Cache *cache, const StmtPtr &node, const std::string &fil
stdlib->moduleName = {ImportFile::STDLIB, stdlibPath->path, "__init__"};
// Load the standard library
stdlib->setFilename(stdlibPath->path);
// Core definitions
preamble->push_back(SimplifyVisitor(stdlib, preamble)
.transform(parseCode(stdlib->cache, stdlibPath->path,
"from internal.core import *")));
for (auto &d : earlyDefines) {
// Load early compile-time defines (for standard library)
preamble->push_back(
SimplifyVisitor(stdlib, preamble)
.transform(std::make_shared<AssignStmt>(
std::make_shared<IdExpr>(d.first),
std::make_shared<IntExpr>(d.second),
std::make_shared<IndexExpr>(std::make_shared<IdExpr>("Static"),
std::make_shared<IdExpr>("int")))));
}
preamble->push_back(SimplifyVisitor(stdlib, preamble)
.transform(parseFile(stdlib->cache, stdlibPath->path)));
stdlib->isStdlibLoading = false;
@ -181,6 +195,7 @@ StmtPtr SimplifyVisitor::transform(StmtPtr &stmt) {
stmt->accept(v);
} catch (const exc::ParserException &e) {
ctx->cache->errors.push_back(e);
// throw;
}
ctx->popSrcInfo();
if (v.resultStmt)

View File

@ -43,8 +43,10 @@ class SimplifyVisitor : public CallbackASTVisitor<ExprPtr, StmtPtr> {
StmtPtr resultStmt;
public:
static StmtPtr apply(Cache *cache, const StmtPtr &node, const std::string &file,
const std::unordered_map<std::string, std::string> &defines,
static StmtPtr
apply(Cache *cache, const StmtPtr &node, const std::string &file,
const std::unordered_map<std::string, std::string> &defines = {},
const std::unordered_map<std::string, std::string> &earlyDefines = {},
bool barebones = false);
static StmtPtr apply(const std::shared_ptr<SimplifyContext> &cache,
const StmtPtr &node, const std::string &file, int atAge = -1);
@ -169,7 +171,7 @@ private: // Node simplification rules
StmtPtr transformPythonDefinition(const std::string &, const std::vector<Param> &,
const Expr *, Stmt *);
StmtPtr transformLLVMDefinition(Stmt *);
std::string *isAttribute(const ExprPtr &);
std::pair<bool, std::string> getDecorator(const ExprPtr &);
/* Classes (class.cpp) */
void visit(ClassStmt *) override;

View File

@ -409,7 +409,10 @@ void TranslateVisitor::visit(ForStmt *stmt) {
bool ordered = fc->funcGenerics[1].type->getStatic()->expr->staticValue.getInt();
auto threads = transform(c->args[0].value);
auto chunk = transform(c->args[1].value);
os = std::make_unique<OMPSched>(schedule, threads, chunk, ordered);
int64_t collapse =
fc->funcGenerics[2].type->getStatic()->expr->staticValue.getInt();
bool gpu = fc->funcGenerics[3].type->getStatic()->expr->staticValue.getInt();
os = std::make_unique<OMPSched>(schedule, threads, chunk, ordered, collapse, gpu);
LOG_TYPECHECK("parsed {}", stmt->decorator->toString());
}
@ -487,7 +490,7 @@ void TranslateVisitor::visit(TryStmt *stmt) {
}
void TranslateVisitor::visit(ThrowStmt *stmt) {
result = make<ir::ThrowInstr>(stmt, transform(stmt->expr));
result = make<ir::ThrowInstr>(stmt, stmt->expr ? transform(stmt->expr) : nullptr);
}
void TranslateVisitor::visit(FunctionStmt *stmt) {

View File

@ -24,15 +24,6 @@ void TypecheckVisitor::visit(IdExpr *expr) {
if (isTuple(expr->value))
generateTuple(std::stoi(expr->value.substr(sizeof(TYPE_TUPLE) - 1)));
// Handle empty callable references
if (expr->value == TYPE_CALLABLE) {
auto typ = ctx->getUnbound();
typ->getLink()->trait = std::make_shared<CallableTrait>(std::vector<TypePtr>{});
unify(expr->type, typ);
expr->markType();
return;
}
// Replace identifiers that have been superseded by domination analysis during the
// simplification
while (auto s = in(ctx->cache->replacements, expr->value))
@ -160,6 +151,12 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
if (expr->expr->type->getFunc() && expr->member == "__name__") {
return transform(N<StringExpr>(expr->expr->type->toString()));
}
// Special case: fn.__llvm_name__
if (expr->expr->type->getFunc() && expr->member == "__llvm_name__") {
if (realize(expr->expr->type))
return transform(N<StringExpr>(expr->expr->type->realizedName()));
return nullptr;
}
// Special case: cls.__name__
if (expr->expr->isType() && expr->member == "__name__") {
if (realize(expr->expr->type))

View File

@ -69,7 +69,11 @@ void TypecheckVisitor::visit(AssignStmt *stmt) {
seqassert(stmt->rhs->staticValue.evaluated, "static not evaluated");
unify(stmt->lhs->type,
unify(stmt->type->type, std::make_shared<StaticType>(stmt->rhs, ctx)));
ctx->add(TypecheckItem::Var, lhs, stmt->lhs->type);
auto val = ctx->add(TypecheckItem::Var, lhs, stmt->lhs->type);
if (in(ctx->cache->globals, lhs)) {
// Make globals always visible!
ctx->addToplevel(lhs, val);
}
if (realize(stmt->lhs->type))
stmt->setDone();
} else {

View File

@ -320,7 +320,7 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
if (!part.known.empty()) {
auto e = getPartialArg(-1);
auto t = e->getType()->getRecord();
seqassert(t && startswith(t->name, "KwTuple"), "{} not a kwtuple",
seqassert(t && startswith(t->name, TYPE_KWTUPLE), "{} not a kwtuple",
e->toString());
auto &ff = ctx->cache->classes[t->name].fields;
for (int i = 0; i < t->getRecord()->args.size(); i++) {
@ -410,8 +410,9 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
if (typeArgs[si]) {
auto typ = typeArgs[si]->type;
if (calleeFn->funcGenerics[si].type->isStaticType()) {
if (!typeArgs[si]->isStatic())
if (!typeArgs[si]->isStatic()) {
error("expected static expression");
}
typ = std::make_shared<StaticType>(typeArgs[si], ctx);
}
unify(typ, calleeFn->funcGenerics[si].type);
@ -543,6 +544,10 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformSpecialCall(CallExpr *expr)
return {true, transformCompileError(expr)};
} else if (val == "tuple") {
return {true, transformTupleFn(expr)};
} else if (val == "__realized__") {
return {true, transformRealizedFn(expr)};
} else if (val == "__static_print__") {
return {false, transformStaticPrintFn(expr)};
} else {
return {false, nullptr};
}
@ -668,7 +673,6 @@ ExprPtr TypecheckVisitor::transformArray(CallExpr *expr) {
/// `isinstance(obj, ByVal)` is True if `type(obj)` is a tuple type
/// `isinstance(obj, ByRef)` is True if `type(obj)` is a reference type
ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
expr->staticValue.type = StaticValue::INT;
expr->setType(unify(expr->type, ctx->getType("bool")));
transform(expr->args[0].value);
auto typ = expr->args[0].value->type->getClass();
@ -678,21 +682,44 @@ ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
transform(expr->args[0].value); // transform again to realize it
auto &typExpr = expr->args[1].value;
if (auto c = typExpr->getCall()) {
// Handle `isinstance(obj, (type1, type2, ...))`
if (typExpr->origExpr && typExpr->origExpr->getTuple()) {
ExprPtr result = transform(N<BoolExpr>(false));
for (auto &i : typExpr->origExpr->getTuple()->items) {
result = transform(N<BinaryExpr>(
result, "||",
N<CallExpr>(N<IdExpr>("isinstance"), expr->args[0].value, i)));
}
return result;
}
}
expr->staticValue.type = StaticValue::INT;
if (typExpr->isId("Tuple") || typExpr->isId("tuple")) {
return transform(N<BoolExpr>(startswith(typ->name, TYPE_TUPLE)));
} else if (typExpr->isId("ByVal")) {
return transform(N<BoolExpr>(typ->getRecord() != nullptr));
} else if (typExpr->isId("ByRef")) {
return transform(N<BoolExpr>(typ->getRecord() == nullptr));
} else if (typExpr->type->is("pyobj") && !typExpr->isType()) {
if (typ->is("pyobj")) {
expr->staticValue.type = StaticValue::NOT_STATIC;
return transform(N<CallExpr>(N<IdExpr>("std.internal.python._isinstance:0"),
expr->args[0].value, expr->args[1].value));
} else {
return transform(N<BoolExpr>(false));
}
}
transformType(typExpr);
// Check super types (i.e., statically inherited) as well
for (auto &tx : getSuperTypes(typ->getClass())) {
if (tx->unify(typExpr->type.get(), nullptr) >= 0)
return transform(N<BoolExpr>(true));
}
return transform(N<BoolExpr>(false));
}
}
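The tuple branch uses the new origExpr field to recover the pre-transformation tuple and expands the check into a chain of ors, each usually resolved statically. A sketch:

x = 3.14
if isinstance(x, (int, float)):   # expands to isinstance(x, int) or isinstance(x, float)
    print("numeric")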
/// Transform staticlen method to a static integer expression. This method supports only
@ -801,6 +828,33 @@ ExprPtr TypecheckVisitor::transformTypeFn(CallExpr *expr) {
return e;
}
/// Transform the __realized__ function into a fully realized function identifier.
ExprPtr TypecheckVisitor::transformRealizedFn(CallExpr *expr) {
auto call =
transform(N<CallExpr>(expr->args[0].value, N<StarExpr>(expr->args[1].value)));
if (!call->getCall()->expr->type->getFunc())
error("the first argument must be a function");
if (auto f = realize(call->getCall()->expr->type)) {
auto e = N<IdExpr>(f->getFunc()->realizedName());
e->setType(f);
e->setDone();
return e;
}
return nullptr;
}
/// Handle the __static_print__ function: realize its arguments and print their static
/// types to stderr at compile time.
ExprPtr TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) {
auto &args = expr->args[0].value->getCall()->args;
for (size_t i = 0; i < args.size(); i++) {
realize(args[i].value->type);
fmt::print(stderr, "[static_print] {}: {} := {}\n", getSrcInfo(),
FormatVisitor::apply(args[i].value),
args[i].value->type ? args[i].value->type->debugString(1) : "-");
}
return nullptr;
}
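__realized__ forces realization of a function for concrete argument values and yields an identifier for that instance, which pairs with the new fn.__llvm_name__ special case above. A hedged sketch (any API of the result beyond __llvm_name__ is an assumption):

def add(a, b):
    return a + b

f = __realized__(add, (1, 2))   # realized add(int, int); errors if the first argument is not a function
print(f.__llvm_name__)          # fully realized (mangled) name of this instance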
/// Get the list that describes the inheritance hierarchy of a given type.
/// The first type in the list is the most recently inherited type.
std::vector<ClassTypePtr> TypecheckVisitor::getSuperTypes(const ClassTypePtr &cls) {

View File

@ -55,6 +55,15 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
unify(defType->type, generic);
}
}
if (auto ti = CAST(a.type, InstantiateExpr)) {
// Parse TraitVar
seqassert(ti->typeExpr->isId(TYPE_TYPEVAR), "not a TypeVar instantiation");
auto l = transformType(ti->typeParams[0])->type;
if (l->getLink() && l->getLink()->trait)
generic->getLink()->trait = l->getLink()->trait;
else
generic->getLink()->trait = std::make_shared<types::TypeTrait>(l);
}
ctx->add(TypecheckItem::Type, a.name, generic);
ClassType::Generic g{a.name, ctx->cache->rev(a.name),
generic->generalize(ctx->typecheckLevel), typId};
@ -135,6 +144,18 @@ std::string TypecheckVisitor::generateTuple(size_t len, const std::string &name,
args.emplace_back(Param(format("T{}", i + 1), N<IdExpr>("type"), nullptr, true));
StmtPtr stmt = N<ClassStmt>(ctx->cache->generateSrcInfo(), typeName, args, nullptr,
std::vector<ExprPtr>{N<IdExpr>("tuple")});
// Add getItem for KwArgs:
// `def __getitem__(self, key: Static[str]): return getattr(self, key)`
auto getItem = N<FunctionStmt>(
"__getitem__", nullptr,
std::vector<Param>{Param{"self"}, Param{"key", N<IndexExpr>(N<IdExpr>("Static"),
N<IdExpr>("str"))}},
N<SuiteStmt>(N<ReturnStmt>(
N<CallExpr>(N<IdExpr>("getattr"), N<IdExpr>("self"), N<IdExpr>("key")))));
if (startswith(typeName, TYPE_KWTUPLE))
stmt->getClass()->suite = getItem;
// Simplify in the standard library context and type check
stmt = SimplifyVisitor::apply(ctx->cache->imports[STDLIB_IMPORT].ctx, stmt,
FILE_GENERATED, 0);
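The generated __getitem__ takes a Static[str] key, so keyword-tuple subscripts reduce to getattr at compile time. A sketch:

def pick(**kwargs):
    return kwargs["name"]         # KwTuple.__getitem__ with a static string key

print(pick(name="codon", n=1))    # codon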

View File

@ -36,7 +36,10 @@ void TypecheckVisitor::visit(IfExpr *expr) {
if (expr->cond->isStatic()) {
resultExpr = evaluateStaticCondition(
expr->cond,
[&](bool isTrue) { return transform(isTrue ? expr->ifexpr : expr->elsexpr); },
[&](bool isTrue) {
LOG_TYPECHECK("[static::cond] {}: {}", getSrcInfo(), isTrue);
return transform(isTrue ? expr->ifexpr : expr->elsexpr);
},
[&]() -> ExprPtr {
// Check if both subexpressions are static; if so, this if expression is also
// static and should be marked as such
@ -83,6 +86,7 @@ void TypecheckVisitor::visit(IfStmt *stmt) {
resultStmt = evaluateStaticCondition(
stmt->cond,
[&](bool isTrue) {
LOG_TYPECHECK("[static::cond] {}: {}", getSrcInfo(), isTrue);
auto t = transform(isTrue ? stmt->ifSuite : stmt->elseSuite);
return t ? t : transform(N<SuiteStmt>());
},

View File

@ -83,12 +83,13 @@ types::TypePtr TypeContext::instantiate(const SrcInfo &srcInfo,
const types::ClassTypePtr &generics) {
seqassert(type, "type is null");
std::unordered_map<int, types::TypePtr> genericCache;
if (generics)
if (generics) {
for (auto &g : generics->generics)
if (g.type &&
!(g.type->getLink() && g.type->getLink()->kind == types::LinkType::Generic)) {
genericCache[g.id] = g.type;
}
}
auto t = type->instantiate(typecheckLevel, &(cache->unboundCount), &genericCache);
for (auto &i : genericCache) {
if (auto l = i.second->getLink()) {

View File

@ -66,6 +66,8 @@ struct TypeContext : public Context<TypecheckItem> {
int blockLevel;
/// True if an early return is found (anything afterwards won't be typechecked)
bool returnEarly;
/// Stack of static loop control variables (used to emulate goto statements).
std::vector<std::string> staticLoops;
public:
explicit TypeContext(Cache *cache);

View File

@ -9,14 +9,55 @@ namespace codon::ast {
using namespace types;
/// Typecheck try-except statements.
/// Typecheck try-except statements. Handle Python exceptions separately.
/// @example
/// ```try: ...
/// except python.Error as e: ...
/// except PyExc as f: ...
/// except ValueError as g: ...
/// ``` -> ```
/// try: ...
/// except ValueError as g: ... # ValueError
/// except PyExc as exc:
/// while True:
/// if isinstance(exc.pytype, python.Error): # python.Error
/// e = exc.pytype; ...; break
/// f = exc; ...; break # PyExc
/// raise```
void TypecheckVisitor::visit(TryStmt *stmt) {
ctx->blockLevel++;
transform(stmt->suite);
ctx->blockLevel--;
std::vector<TryStmt::Catch> catches;
auto pyVar = ctx->cache->getTemporaryVar("pyexc");
auto pyCatchStmt = N<WhileStmt>(N<BoolExpr>(true), N<SuiteStmt>());
auto done = stmt->suite->isDone();
for (auto &c : stmt->catches) {
transform(c.exc);
if (c.exc && c.exc->type->is("pyobj")) {
// Transform python.Error exceptions
if (!c.var.empty()) {
c.suite = N<SuiteStmt>(
N<AssignStmt>(N<IdExpr>(c.var), N<DotExpr>(N<IdExpr>(pyVar), "pytype")),
c.suite);
}
c.suite =
N<IfStmt>(N<CallExpr>(N<IdExpr>("isinstance"),
N<DotExpr>(N<IdExpr>(pyVar), "pytype"), clone(c.exc)),
N<SuiteStmt>(c.suite, N<BreakStmt>()), nullptr);
pyCatchStmt->suite->getSuite()->stmts.push_back(c.suite);
} else if (c.exc && c.exc->type->is("std.internal.types.error.PyError")) {
// Transform PyExc exceptions
if (!c.var.empty()) {
c.suite =
N<SuiteStmt>(N<AssignStmt>(N<IdExpr>(c.var), N<IdExpr>(pyVar)), c.suite);
}
c.suite = N<SuiteStmt>(c.suite, N<BreakStmt>());
pyCatchStmt->suite->getSuite()->stmts.push_back(c.suite);
} else {
// Handle all other exceptions
transformType(c.exc);
if (!c.var.empty()) {
// Handle dominated except bindings
@ -40,7 +81,25 @@ void TypecheckVisitor::visit(TryStmt *stmt) {
transform(c.suite);
ctx->blockLevel--;
done &= (!c.exc || c.exc->isDone()) && c.suite->isDone();
catches.push_back(c);
}
}
if (!pyCatchStmt->suite->getSuite()->stmts.empty()) {
// Process PyError catches
auto exc = N<IdExpr>("std.internal.types.error.PyError");
exc->markType();
pyCatchStmt->suite->getSuite()->stmts.push_back(N<ThrowStmt>(nullptr));
TryStmt::Catch c{pyVar, transformType(exc), pyCatchStmt};
auto val = ctx->add(TypecheckItem::Var, pyVar, c.exc->getType());
unify(val->type, c.exc->getType());
ctx->blockLevel++;
transform(c.suite);
ctx->blockLevel--;
done &= (!c.exc || c.exc->isDone()) && c.suite->isDone();
catches.push_back(c);
}
stmt->catches = catches;
if (stmt->finally) {
ctx->blockLevel++;
transform(stmt->finally);
@ -56,6 +115,11 @@ void TypecheckVisitor::visit(TryStmt *stmt) {
/// @example
/// `raise exc` -> ```raise __internal__.set_header(exc, "fn", "file", line, col)```
void TypecheckVisitor::visit(ThrowStmt *stmt) {
if (!stmt->expr) {
stmt->setDone();
return;
}
transform(stmt->expr);
if (!(stmt->expr->getCall() &&

View File

@ -40,16 +40,16 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) {
unify(ctx->getRealizationBase()->returnType, stmt->expr->type);
} else {
// If we are not within conditional block, ignore later statements in this function.
// Useful with static if statements.
if (!ctx->blockLevel)
ctx->returnEarly = true;
// Just set the expr for the translation stage. However, do not unify the return
// type! This might be a `return` in a generator.
stmt->expr = transform(N<CallExpr>(N<IdExpr>("NoneType")));
}
// If we are not within conditional block, ignore later statements in this function.
// Useful with static if statements.
if (!ctx->blockLevel)
ctx->returnEarly = true;
if (stmt->expr->isDone())
stmt->setDone();
}

View File

@ -52,8 +52,11 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
ctx->typecheckLevel++;
auto changedNodes = ctx->changedNodes;
ctx->changedNodes = 0;
auto returnEarly = ctx->returnEarly;
ctx->returnEarly = false;
TypecheckVisitor(ctx).transform(result);
std::swap(ctx->changedNodes, changedNodes);
std::swap(ctx->returnEarly, returnEarly);
ctx->typecheckLevel--;
if (iteration == 1 && isToplevel) {
@ -282,10 +285,7 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) {
if (hasAst) {
auto oldBlockLevel = ctx->blockLevel;
auto oldReturnEarly = ctx->returnEarly;
ctx->blockLevel = 0;
ctx->returnEarly = false;
if (startswith(type->ast->name, "Function.__call__")) {
// Special case: Function.__call__
/// TODO: move to IR
@ -313,7 +313,6 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) {
}
inferTypes(ast->suite);
ctx->blockLevel = oldBlockLevel;
ctx->returnEarly = oldReturnEarly;
// Use NoneType as the return type when the return type is not specified and
// function has no return statement
@ -382,6 +381,7 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
// Get the IR type
auto *module = ctx->cache->module;
ir::types::Type *handle = nullptr;
if (t->name == "bool") {
handle = module->getBoolType();
} else if (t->name == "byte") {
@ -390,6 +390,8 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
handle = module->getIntType();
} else if (t->name == "float") {
handle = module->getFloatType();
} else if (t->name == "float32") {
handle = module->getFloat32Type();
} else if (t->name == "str") {
handle = module->getStringType();
} else if (t->name == "Int" || t->name == "UInt") {
@ -415,6 +417,9 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
types.push_back(forceFindIRType(m));
auto ret = forceFindIRType(t->generics[1].type);
handle = module->unsafeGetFuncType(realizedName, ret, types);
} else if (t->name == "std.simd.Vec") {
seqassert(types.size() == 1 && statics.size() == 1, "bad generics/statics");
handle = module->unsafeGetVectorType(statics[0]->getInt(), types[0]);
} else if (auto tr = t->getRecord()) {
std::vector<ir::types::Type *> typeArgs;
std::vector<std::string> names;


@ -13,17 +13,32 @@ namespace codon::ast {
using namespace types;
/// Nothing to typecheck; just call setDone
void TypecheckVisitor::visit(BreakStmt *stmt) { stmt->setDone(); }
void TypecheckVisitor::visit(BreakStmt *stmt) {
stmt->setDone();
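// Inside a static loop, `break` must also clear the shared loop variable so
// that the remaining unrolled iterations (each guarded by it) are skipped.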
if (!ctx->staticLoops.back().empty()) {
auto a = N<AssignStmt>(N<IdExpr>(ctx->staticLoops.back()), N<BoolExpr>(false));
a->setUpdate();
resultStmt = transform(N<SuiteStmt>(a, stmt->clone()));
}
}
/// Nothing to typecheck; just call setDone
void TypecheckVisitor::visit(ContinueStmt *stmt) { stmt->setDone(); }
void TypecheckVisitor::visit(ContinueStmt *stmt) {
stmt->setDone();
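// Under static unrolling each iteration is its own one-shot `while` block,
// so `continue` simply becomes a `break` out of the current block.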
if (!ctx->staticLoops.back().empty()) {
resultStmt = N<BreakStmt>();
resultStmt->setDone();
}
}
/// Typecheck while statements.
void TypecheckVisitor::visit(WhileStmt *stmt) {
ctx->staticLoops.push_back(stmt->gotoVar.empty() ? "" : stmt->gotoVar);
transform(stmt->cond);
ctx->blockLevel++;
transform(stmt->suite);
ctx->blockLevel--;
ctx->staticLoops.pop_back();
if (stmt->cond->isDone() && stmt->suite->isDone())
stmt->setDone();
@ -40,6 +55,9 @@ void TypecheckVisitor::visit(ForStmt *stmt) {
if (!iterType)
return; // wait until the iterator is known
if ((resultStmt = transformStaticForLoop(stmt)))
return;
bool maybeHeterogenous = startswith(iterType->name, TYPE_TUPLE) ||
startswith(iterType->name, TYPE_KWTUPLE);
if (maybeHeterogenous && !iterType->canRealize()) {
@ -83,9 +101,11 @@ void TypecheckVisitor::visit(ForStmt *stmt) {
unify(stmt->var->type,
iterType ? unify(val->type, iterType->generics[0].type) : val->type);
ctx->staticLoops.push_back("");
ctx->blockLevel++;
transform(stmt->suite);
ctx->blockLevel--;
ctx->staticLoops.pop_back();
if (stmt->iter->isDone() && stmt->suite->isDone())
stmt->setDone();
@ -132,4 +152,81 @@ StmtPtr TypecheckVisitor::transformHeterogenousTupleFor(ForStmt *stmt) {
return block;
}
/// Handle static for constructs.
/// @example
/// `for i in statictuple(1, x): <suite>` ->
/// ```loop = True
///    while loop:
///      while loop:
///        i: Static[int] = 1; <suite>; break
///      while loop:
///        i = x; <suite>; break
///      loop = False # also set to False on break```
/// A separate suite is generated for each static iteration.
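/// Similarly, `for i in staticrange(a, b, step): <suite>` (with static bounds)
/// unrolls into one such `while loop:` block per value of `i`.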
StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) {
auto loopVar = ctx->cache->getTemporaryVar("loop");
auto fn = [&](const std::string &var, const ExprPtr &expr) {
bool staticInt = expr->isStatic();
auto t = N<IndexExpr>(
N<IdExpr>("Static"),
N<IdExpr>(expr->staticValue.type == StaticValue::INT ? "int" : "str"));
t->markType();
auto brk = N<BreakStmt>();
brk->setDone(); // Mark done so the BreakStmt transform does not rewrite this synthetic break
// var [: Static] := expr; suite...
auto loop = N<WhileStmt>(N<IdExpr>(loopVar),
N<SuiteStmt>(N<AssignStmt>(N<IdExpr>(var), expr->clone(),
staticInt ? t : nullptr),
clone(stmt->suite), brk));
loop->gotoVar = loopVar;
return loop;
};
auto var = stmt->var->getId()->value;
if (!stmt->iter->getCall() || !stmt->iter->getCall()->expr->getId())
return nullptr;
auto iter = stmt->iter->getCall()->expr->getId();
auto block = N<SuiteStmt>();
if (iter && startswith(iter->value, "statictuple:0")) {
auto &args = stmt->iter->getCall()->args[0].value->getCall()->args;
for (size_t i = 0; i < args.size(); i++)
block->stmts.push_back(fn(var, args[i].value));
} else if (iter &&
startswith(iter->value, "std.internal.types.range.staticrange:0")) {
int st =
iter->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt();
int ed =
iter->type->getFunc()->funcGenerics[1].type->getStatic()->evaluate().getInt();
int step =
iter->type->getFunc()->funcGenerics[2].type->getStatic()->evaluate().getInt();
if (abs(st - ed) / abs(step) > MAX_STATIC_ITER)
error("staticrange out of bounds ({} > {})", abs(st - ed) / abs(step),
MAX_STATIC_ITER);
for (int i = st; step > 0 ? i < ed : i > ed; i += step)
block->stmts.push_back(fn(var, N<IntExpr>(i)));
} else if (iter &&
startswith(iter->value, "std.internal.types.range.staticrange:1")) {
int ed =
iter->type->getFunc()->funcGenerics[0].type->getStatic()->evaluate().getInt();
if (ed > MAX_STATIC_ITER)
error("staticrange out of bounds ({} > {})", ed, MAX_STATIC_ITER);
for (int i = 0; i < ed; i++)
block->stmts.push_back(fn(var, N<IntExpr>(i)));
} else {
return nullptr;
}
ctx->blockLevel++;
// Close the loop
auto a = N<AssignStmt>(N<IdExpr>(loopVar), N<BoolExpr>(false));
a->setUpdate();
block->stmts.push_back(a);
auto loop =
transform(N<SuiteStmt>(N<AssignStmt>(N<IdExpr>(loopVar), N<BoolExpr>(true)),
N<WhileStmt>(N<IdExpr>(loopVar), block)));
ctx->blockLevel--;
return loop;
}
} // namespace codon::ast


@ -262,15 +262,20 @@ void TypecheckVisitor::visit(IndexExpr *expr) {
/// @example
/// Instantiate(foo, [bar]) -> Id("foo[bar]")
void TypecheckVisitor::visit(InstantiateExpr *expr) {
// Infer the expression type
transformType(expr->typeExpr);
TypePtr typ =
ctx->instantiate(expr->typeExpr->getSrcInfo(), expr->typeExpr->getType());
seqassert(typ->getClass(), "unknown type: {}", expr->typeExpr->toString());
auto &generics = typ->getClass()->generics;
if (expr->typeParams.size() != generics.size())
error("expected {} generics and/or statics", generics.size());
if (expr->typeExpr->isId(TYPE_CALLABLE)) {
// Case: Callable[...] instantiation
// Case: Callable[...] trait instantiation
std::vector<TypePtr> types;
// Callable error checking.
/// TODO: move to Codon?
if (expr->typeParams.size() != 2)
error("invalid Callable type declaration");
for (auto &typeParam : expr->typeParams) {
transformType(typeParam);
if (typeParam->type->isStaticType())
@ -281,16 +286,13 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) {
// Set up the Callable trait
typ->getLink()->trait = std::make_shared<CallableTrait>(types);
unify(expr->type, typ);
} else if (expr->typeExpr->isId(TYPE_TYPEVAR)) {
// Case: TypeVar[...] trait instantiation
transformType(expr->typeParams[0]);
auto typ = ctx->getUnbound();
typ->getLink()->trait = std::make_shared<TypeTrait>(expr->typeParams[0]->type);
unify(expr->type, typ);
} else {
transformType(expr->typeExpr);
TypePtr typ =
ctx->instantiate(expr->typeExpr->getSrcInfo(), expr->typeExpr->getType());
seqassert(typ->getClass(), "unknown type");
auto &generics = typ->getClass()->generics;
if (expr->typeParams.size() != generics.size())
error("expected {} generics and/or statics", generics.size());
for (size_t i = 0; i < expr->typeParams.size(); i++) {
transform(expr->typeParams[i]);
TypePtr t = nullptr;
@ -339,7 +341,9 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
if (expr->expr->staticValue.type == StaticValue::STRING) {
if (expr->op == "!") {
if (expr->expr->staticValue.evaluated) {
return transform(N<BoolExpr>(expr->expr->staticValue.getString().empty()));
bool value = expr->expr->staticValue.getString().empty();
LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
return transform(N<BoolExpr>(value));
} else {
// Cannot be evaluated yet: just set the type
unify(expr->type, ctx->getType("bool"));
@ -360,6 +364,7 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
value = -value;
else
value = !bool(value);
LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
if (expr->op == "!")
return transform(N<BoolExpr>(bool(value)));
else
@ -375,6 +380,25 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
return nullptr;
}
/// Division and modulus implementations.
std::pair<int, int> divMod(const std::shared_ptr<TypeContext> &ctx, int a, int b) {
if (!b)
error(ctx->getSrcInfo(), "static division by zero");
if (ctx->cache->pythonCompat) {
// Use Python implementation.
int d = a / b;
int m = a - d * b;
if (m && ((b ^ m) < 0)) {
m += b;
d -= 1;
}
return {d, m};
} else {
// Use C implementation.
return {a / b, a % b};
}
}
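// Illustrative values: divMod(ctx, -7, 2) yields {-4, 1} under Python
// semantics (floor division) but {-3, -1} under C semantics (truncation
// toward zero); the two agree whenever the operands share a sign.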
/// Evaluate a static binary expression and return the resulting static expression.
/// If the expression cannot be evaluated yet, return nullptr.
/// Supported operators: (strings) +, ==, !=
@ -385,8 +409,10 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
if (expr->op == "+") {
// `"a" + "b"` -> `"ab"`
if (expr->lexpr->staticValue.evaluated && expr->rexpr->staticValue.evaluated) {
return transform(N<StringExpr>(expr->lexpr->staticValue.getString() +
expr->rexpr->staticValue.getString()));
auto value =
expr->lexpr->staticValue.getString() + expr->rexpr->staticValue.getString();
LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value);
return transform(N<StringExpr>(value));
} else {
// Cannot be evaluated yet: just set the type
if (!expr->isStatic())
@ -398,7 +424,9 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
if (expr->lexpr->staticValue.evaluated && expr->rexpr->staticValue.evaluated) {
bool eq = expr->lexpr->staticValue.getString() ==
expr->rexpr->staticValue.getString();
return transform(N<BoolExpr>(expr->op == "==" ? eq : !eq));
bool value = expr->op == "==" ? eq : !eq;
LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), value);
return transform(N<BoolExpr>(value));
} else {
// Cannot be evaluated yet: just set the type
if (!expr->isStatic())
@ -441,18 +469,13 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
lvalue = lvalue & rvalue;
else if (expr->op == "|")
lvalue = lvalue | rvalue;
else if (expr->op == "//") {
if (!rvalue)
error("static division by zero");
lvalue = lvalue / rvalue;
} else if (expr->op == "%") {
if (!rvalue)
error("static division by zero");
lvalue = lvalue % rvalue;
} else {
else if (expr->op == "//")
lvalue = divMod(ctx, lvalue, rvalue).first;
else if (expr->op == "%")
lvalue = divMod(ctx, lvalue, rvalue).second;
else
seqassert(false, "unknown static operator {}", expr->op);
}
LOG_TYPECHECK("[cond::bin] {}: {}", getSrcInfo(), lvalue);
if (in(std::set<std::string>{"==", "!=", "<", "<=", ">", ">=", "&&", "||"},
expr->op))
return transform(N<BoolExpr>(bool(lvalue)));


@ -47,6 +47,7 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) {
ctx->popSrcInfo();
if (v.resultExpr) {
v.resultExpr->attributes |= expr->attributes;
v.resultExpr->origExpr = expr;
expr = v.resultExpr;
}
seqassert(expr->type, "type not set for {}", expr->toString());


@ -135,6 +135,8 @@ private: // Node typechecking rules
ExprPtr transformCompileError(CallExpr *expr);
ExprPtr transformTupleFn(CallExpr *expr);
ExprPtr transformTypeFn(CallExpr *expr);
ExprPtr transformRealizedFn(CallExpr *expr);
ExprPtr transformStaticPrintFn(CallExpr *expr);
std::vector<types::ClassTypePtr> getSuperTypes(const types::ClassTypePtr &cls);
void addFunctionGenerics(const types::FuncType *t);
std::string generatePartialStub(const std::vector<char> &mask, types::FuncType *fn);
@ -151,6 +153,7 @@ private: // Node typechecking rules
void visit(WhileStmt *) override;
void visit(ForStmt *) override;
StmtPtr transformHeterogenousTupleFor(ForStmt *);
StmtPtr transformStaticForLoop(ForStmt *);
/* Errors and exceptions (error.cpp) */
void visit(TryStmt *) override;


@ -0,0 +1,137 @@
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
#include "lib.h"
#ifdef CODON_GPU
#include "cuda.h"
#define fail(err) \
do { \
const char *msg; \
cuGetErrorString((err), &msg); \
fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, msg); \
abort(); \
} while (0)
#define check(call) \
do { \
auto err = (call); \
if (err != CUDA_SUCCESS) { \
fail(err); \
} \
} while (0)
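// Usage sketch: wrap every driver call, e.g. check(cuMemAlloc(&devp, size));
// any result other than CUDA_SUCCESS prints the CUDA error string and aborts.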
static std::vector<CUmodule> modules;
static CUcontext context;
void seq_nvptx_init() {
CUdevice device;
check(cuInit(0));
check(cuDeviceGet(&device, 0));
check(cuCtxCreate(&context, 0, device));
}
SEQ_FUNC void seq_nvptx_load_module(const char *filename) {
CUmodule module;
check(cuModuleLoad(&module, filename));
modules.push_back(module);
}
SEQ_FUNC seq_int_t seq_nvptx_device_count() {
int devCount;
check(cuDeviceGetCount(&devCount));
return devCount;
}
SEQ_FUNC seq_str_t seq_nvptx_device_name(CUdevice device) {
char name[128];
check(cuDeviceGetName(name, sizeof(name) - 1, device));
auto sz = static_cast<seq_int_t>(strlen(name));
auto *p = (char *)seq_alloc_atomic(sz);
memcpy(p, name, sz);
return {sz, p};
}
SEQ_FUNC seq_int_t seq_nvptx_device_capability(CUdevice device) {
int devMajor, devMinor;
check(cuDeviceComputeCapability(&devMajor, &devMinor, device));
return ((seq_int_t)devMajor << 32) | (seq_int_t)devMinor;
}
SEQ_FUNC CUdevice seq_nvptx_device(seq_int_t idx) {
CUdevice device;
check(cuDeviceGet(&device, idx));
return device;
}
static bool name_char_valid(char c, bool first) {
bool ok = ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || (c == '_');
if (!first)
ok = ok || ('0' <= c && c <= '9');
return ok;
}
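// Note: this must mirror the compile-time cleanup applied to kernel names
// when the PTX module is generated (cleanUpName in the GPU pass), or the
// lookups below will miss.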
SEQ_FUNC CUfunction seq_nvptx_function(seq_str_t name) {
CUfunction function;
CUresult result;
std::vector<char> clean(name.len + 1);
for (unsigned i = 0; i < name.len; i++) {
char c = name.str[i];
clean[i] = (name_char_valid(c, i == 0) ? c : '_');
}
clean[name.len] = '\0';
for (auto it = modules.rbegin(); it != modules.rend(); ++it) {
result = cuModuleGetFunction(&function, *it, clean.data());
if (result == CUDA_SUCCESS) {
return function;
} else if (result == CUDA_ERROR_NOT_FOUND) {
continue;
} else {
break;
}
}
fail(result);
return {};
}
SEQ_FUNC void seq_nvptx_invoke(CUfunction f, unsigned int gridDimX,
unsigned int gridDimY, unsigned int gridDimZ,
unsigned int blockDimX, unsigned int blockDimY,
unsigned int blockDimZ, unsigned int sharedMemBytes,
void **kernelParams) {
check(cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ,
sharedMemBytes, nullptr, kernelParams, nullptr));
}
SEQ_FUNC CUdeviceptr seq_nvptx_device_alloc(seq_int_t size) {
if (size == 0)
return {};
CUdeviceptr devp;
check(cuMemAlloc(&devp, size));
return devp;
}
SEQ_FUNC void seq_nvptx_memcpy_h2d(CUdeviceptr devp, char *hostp, seq_int_t size) {
if (size)
check(cuMemcpyHtoD(devp, hostp, size));
}
SEQ_FUNC void seq_nvptx_memcpy_d2h(char *hostp, CUdeviceptr devp, seq_int_t size) {
if (size)
check(cuMemcpyDtoH(hostp, devp, size));
}
SEQ_FUNC void seq_nvptx_device_free(CUdeviceptr devp) {
if (devp)
check(cuMemFree(devp));
}
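// Typical call sequence from the generated host code (sketch; argument values
// are illustrative):
//   seq_nvptx_load_module("kernel.ptx");
//   CUfunction f = seq_nvptx_function(name);
//   CUdeviceptr buf = seq_nvptx_device_alloc(n);
//   seq_nvptx_memcpy_h2d(buf, host, n);
//   seq_nvptx_invoke(f, gx, gy, gz, bx, by, bz, 0, params);
//   seq_nvptx_memcpy_d2h(host, buf, n);
//   seq_nvptx_device_free(buf);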
#endif /* CODON_GPU */


@ -38,6 +38,10 @@ extern "C" void __kmpc_set_gc_callbacks(gc_setup_callback get_stack_base,
void seq_exc_init();
#ifdef CODON_GPU
void seq_nvptx_init();
#endif
int seq_flags;
SEQ_FUNC void seq_init(int flags) {
@ -47,6 +51,9 @@ SEQ_FUNC void seq_init(int flags) {
__kmpc_set_gc_callbacks(GC_get_stack_base, (gc_setup_callback)GC_register_my_thread,
GC_add_roots, GC_remove_roots);
seq_exc_init();
#ifdef CODON_GPU
seq_nvptx_init();
#endif
seq_flags = flags;
}
@ -176,11 +183,11 @@ SEQ_FUNC void *seq_calloc_atomic(size_t m, size_t n) {
#endif
}
SEQ_FUNC void *seq_realloc(void *p, size_t n) {
SEQ_FUNC void *seq_realloc(void *p, size_t newsize, size_t oldsize) {
#if USE_STANDARD_MALLOC
return realloc(p, n);
return realloc(p, newsize);
#else
return GC_REALLOC(p, n);
return GC_REALLOC(p, newsize);
#endif
}
@ -231,7 +238,7 @@ static seq_str_t string_conv(const char *fmt, const size_t size, T t) {
int n = snprintf(p, size, fmt, t);
if (n >= size) {
auto n2 = (size_t)n + 1;
p = (char *)seq_realloc((void *)p, n2);
p = (char *)seq_realloc((void *)p, n2, size);
n = snprintf(p, n2, fmt, t);
}
return {(seq_int_t)n, p};


@ -54,7 +54,7 @@ SEQ_FUNC void *seq_alloc(size_t n);
SEQ_FUNC void *seq_alloc_atomic(size_t n);
SEQ_FUNC void *seq_calloc(size_t m, size_t n);
SEQ_FUNC void *seq_calloc_atomic(size_t m, size_t n);
SEQ_FUNC void *seq_realloc(void *p, size_t n);
SEQ_FUNC void *seq_realloc(void *p, size_t newsize, size_t oldsize);
SEQ_FUNC void seq_free(void *p);
SEQ_FUNC void seq_register_finalizer(void *p, void (*f)(void *obj, void *data));


@ -0,0 +1,550 @@
#include "gpu.h"
#include <algorithm>
#include <memory>
#include <string>
#include "codon/util/common.h"
#include "llvm/CodeGen/CommandFlags.h"
namespace codon {
namespace ir {
namespace {
const std::string GPU_TRIPLE = "nvptx64-nvidia-cuda";
const std::string GPU_DL =
"e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-"
"f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64";
llvm::cl::opt<std::string>
libdevice("libdevice", llvm::cl::desc("libdevice path for GPU kernels"),
llvm::cl::init("/usr/local/cuda/nvvm/libdevice/libdevice.10.bc"));
std::string cleanUpName(llvm::StringRef name) {
std::string validName;
llvm::raw_string_ostream validNameStream(validName);
auto valid = [](char c, bool first) {
bool ok = ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || (c == '_');
if (!first)
ok = ok || ('0' <= c && c <= '9');
return ok;
};
bool first = true;
for (char c : name) {
validNameStream << (valid(c, first) ? c : '_');
first = false;
}
return validNameStream.str();
}
void linkLibdevice(llvm::Module *M, const std::string &path) {
llvm::SMDiagnostic err;
auto libdevice = llvm::parseIRFile(path, err, M->getContext());
if (!libdevice)
compilationError(err.getMessage().str(), err.getFilename().str(), err.getLineNo(),
err.getColumnNo());
libdevice->setDataLayout(M->getDataLayout());
libdevice->setTargetTriple(M->getTargetTriple());
llvm::Linker L(*M);
const bool fail = L.linkInModule(std::move(libdevice));
seqassertn(!fail, "linking libdevice failed");
}
llvm::Function *copyPrototype(llvm::Function *F, const std::string &name) {
auto *M = F->getParent();
return llvm::Function::Create(F->getFunctionType(), llvm::GlobalValue::PrivateLinkage,
name.empty() ? F->getName() : name, *M);
}
llvm::Function *makeNoOp(llvm::Function *F) {
auto *M = F->getParent();
auto &context = M->getContext();
auto dummyName = (".codon.gpu.dummy." + F->getName()).str();
auto *dummy = M->getFunction(dummyName);
if (!dummy) {
dummy = copyPrototype(F, dummyName);
auto *entry = llvm::BasicBlock::Create(context, "entry", dummy);
llvm::IRBuilder<> B(entry);
auto *retType = F->getReturnType();
if (retType->isVoidTy()) {
B.CreateRetVoid();
} else {
B.CreateRet(llvm::UndefValue::get(retType));
}
}
return dummy;
}
using Codegen =
std::function<void(llvm::IRBuilder<> &, const std::vector<llvm::Value *> &)>;
llvm::Function *makeFillIn(llvm::Function *F, Codegen codegen) {
auto *M = F->getParent();
auto &context = M->getContext();
auto fillInName = (".codon.gpu.fillin." + F->getName()).str();
auto *fillIn = M->getFunction(fillInName);
if (!fillIn) {
fillIn = copyPrototype(F, fillInName);
std::vector<llvm::Value *> args;
for (auto it = fillIn->arg_begin(); it != fillIn->arg_end(); ++it) {
args.push_back(it);
}
auto *entry = llvm::BasicBlock::Create(context, "entry", fillIn);
llvm::IRBuilder<> B(entry);
codegen(B, args);
}
return fillIn;
}
llvm::Function *makeMalloc(llvm::Module *M) {
auto &context = M->getContext();
auto F = M->getOrInsertFunction("malloc", llvm::Type::getInt8PtrTy(context),
llvm::Type::getInt64Ty(context));
auto *G = llvm::cast<llvm::Function>(F.getCallee());
G->setLinkage(llvm::GlobalValue::ExternalLinkage);
G->setDoesNotThrow();
G->setReturnDoesNotAlias();
G->setOnlyAccessesInaccessibleMemory();
G->setWillReturn();
return G;
}
void remapFunctions(llvm::Module *M) {
// simple name-to-name remappings
static const std::vector<std::pair<std::string, std::string>> remapping = {
// 64-bit float intrinsics
{"llvm.ceil.f64", "__nv_ceil"},
{"llvm.floor.f64", "__nv_floor"},
{"llvm.fabs.f64", "__nv_fabs"},
{"llvm.exp.f64", "__nv_exp"},
{"llvm.log.f64", "__nv_log"},
{"llvm.log2.f64", "__nv_log2"},
{"llvm.log10.f64", "__nv_log10"},
{"llvm.sqrt.f64", "__nv_sqrt"},
{"llvm.pow.f64", "__nv_pow"},
{"llvm.sin.f64", "__nv_sin"},
{"llvm.cos.f64", "__nv_cos"},
{"llvm.copysign.f64", "__nv_copysign"},
{"llvm.trunc.f64", "__nv_trunc"},
{"llvm.rint.f64", "__nv_rint"},
{"llvm.nearbyint.f64", "__nv_nearbyint"},
{"llvm.round.f64", "__nv_round"},
{"llvm.minnum.f64", "__nv_fmin"},
{"llvm.maxnum.f64", "__nv_fmax"},
{"llvm.copysign.f64", "__nv_copysign"},
{"llvm.fma.f64", "__nv_fma"},
// 64-bit float math functions
{"expm1", "__nv_expm1"},
{"ldexp", "__nv_ldexp"},
{"acos", "__nv_acos"},
{"asin", "__nv_asin"},
{"atan", "__nv_atan"},
{"atan2", "__nv_atan2"},
{"hypot", "__nv_hypot"},
{"tan", "__nv_tan"},
{"cosh", "__nv_cosh"},
{"sinh", "__nv_sinh"},
{"tanh", "__nv_tanh"},
{"acosh", "__nv_acosh"},
{"asinh", "__nv_asinh"},
{"atanh", "__nv_atanh"},
{"erf", "__nv_erf"},
{"erfc", "__nv_erfc"},
{"tgamma", "__nv_tgamma"},
{"lgamma", "__nv_lgamma"},
{"remainder", "__nv_remainder"},
{"frexp", "__nv_frexp"},
{"modf", "__nv_modf"},
// 32-bit float intrinsics
{"llvm.ceil.f32", "__nv_ceilf"},
{"llvm.floor.f32", "__nv_floorf"},
{"llvm.fabs.f32", "__nv_fabsf"},
{"llvm.exp.f32", "__nv_expf"},
{"llvm.log.f32", "__nv_logf"},
{"llvm.log2.f32", "__nv_log2f"},
{"llvm.log10.f32", "__nv_log10f"},
{"llvm.sqrt.f32", "__nv_sqrtf"},
{"llvm.pow.f32", "__nv_powf"},
{"llvm.sin.f32", "__nv_sinf"},
{"llvm.cos.f32", "__nv_cosf"},
{"llvm.copysign.f32", "__nv_copysignf"},
{"llvm.trunc.f32", "__nv_truncf"},
{"llvm.rint.f32", "__nv_rintf"},
{"llvm.nearbyint.f32", "__nv_nearbyintf"},
{"llvm.round.f32", "__nv_roundf"},
{"llvm.minnum.f32", "__nv_fminf"},
{"llvm.maxnum.f32", "__nv_fmaxf"},
{"llvm.copysign.f32", "__nv_copysignf"},
{"llvm.fma.f32", "__nv_fmaf"},
// 32-bit float math functions
{"expm1f", "__nv_expm1f"},
{"ldexpf", "__nv_ldexpf"},
{"acosf", "__nv_acosf"},
{"asinf", "__nv_asinf"},
{"atanf", "__nv_atanf"},
{"atan2f", "__nv_atan2f"},
{"hypotf", "__nv_hypotf"},
{"tanf", "__nv_tanf"},
{"coshf", "__nv_coshf"},
{"sinhf", "__nv_sinhf"},
{"tanhf", "__nv_tanhf"},
{"acoshf", "__nv_acoshf"},
{"asinhf", "__nv_asinhf"},
{"atanhf", "__nv_atanhf"},
{"erff", "__nv_erff"},
{"erfcf", "__nv_erfcf"},
{"tgammaf", "__nv_tgammaf"},
{"lgammaf", "__nv_lgammaf"},
{"remainderf", "__nv_remainderf"},
{"frexpf", "__nv_frexpf"},
{"modff", "__nv_modff"},
// runtime library functions
{"seq_free", "free"},
{"seq_register_finalizer", ""},
{"seq_gc_add_roots", ""},
{"seq_gc_remove_roots", ""},
{"seq_gc_clear_roots", ""},
{"seq_gc_exclude_static_roots", ""},
};
// functions that need to be generated as they're not available on GPU
static const std::vector<std::pair<std::string, Codegen>> fillins = {
{"seq_alloc",
[](llvm::IRBuilder<> &B, const std::vector<llvm::Value *> &args) {
auto *M = B.GetInsertBlock()->getModule();
llvm::Value *mem = B.CreateCall(makeMalloc(M), args[0]);
B.CreateRet(mem);
}},
{"seq_alloc_atomic",
[](llvm::IRBuilder<> &B, const std::vector<llvm::Value *> &args) {
auto *M = B.GetInsertBlock()->getModule();
llvm::Value *mem = B.CreateCall(makeMalloc(M), args[0]);
B.CreateRet(mem);
}},
{"seq_realloc",
[](llvm::IRBuilder<> &B, const std::vector<llvm::Value *> &args) {
auto *M = B.GetInsertBlock()->getModule();
llvm::Value *mem = B.CreateCall(makeMalloc(M), args[1]);
auto F = llvm::Intrinsic::getDeclaration(
M, llvm::Intrinsic::memcpy,
{B.getInt8PtrTy(), B.getInt8PtrTy(), B.getInt64Ty()});
B.CreateCall(F, {mem, args[0], args[2], B.getFalse()});
B.CreateRet(mem);
}},
{"seq_calloc",
[](llvm::IRBuilder<> &B, const std::vector<llvm::Value *> &args) {
auto *M = B.GetInsertBlock()->getModule();
llvm::Value *size = B.CreateMul(args[0], args[1]);
llvm::Value *mem = B.CreateCall(makeMalloc(M), size);
auto F = llvm::Intrinsic::getDeclaration(M, llvm::Intrinsic::memset,
{B.getInt8PtrTy(), B.getInt64Ty()});
B.CreateCall(F, {mem, B.getInt8(0), size, B.getFalse()});
B.CreateRet(mem);
}},
{"seq_calloc_atomic",
[](llvm::IRBuilder<> &B, const std::vector<llvm::Value *> &args) {
auto *M = B.GetInsertBlock()->getModule();
llvm::Value *size = B.CreateMul(args[0], args[1]);
llvm::Value *mem = B.CreateCall(makeMalloc(M), size);
auto F = llvm::Intrinsic::getDeclaration(M, llvm::Intrinsic::memset,
{B.getInt8PtrTy(), B.getInt64Ty()});
B.CreateCall(F, {mem, B.getInt8(0), size, B.getFalse()});
B.CreateRet(mem);
}},
{"seq_alloc_exc",
[](llvm::IRBuilder<> &B, const std::vector<llvm::Value *> &args) {
// TODO: print error message and abort if in debug mode
B.CreateUnreachable();
}},
{"seq_throw",
[](llvm::IRBuilder<> &B,
const std::vector<llvm::Value *> &args) { B.CreateUnreachable(); }},
};
for (auto &pair : remapping) {
if (auto *F = M->getFunction(pair.first)) {
llvm::Function *G = nullptr;
if (pair.second.empty()) {
G = makeNoOp(F);
} else {
G = M->getFunction(pair.second);
if (!G)
G = copyPrototype(F, pair.second);
}
G->setWillReturn();
F->replaceAllUsesWith(G);
F->dropAllReferences();
F->eraseFromParent();
}
}
for (auto &pair : fillins) {
if (auto *F = M->getFunction(pair.first)) {
llvm::Function *G = makeFillIn(F, pair.second);
F->replaceAllUsesWith(G);
F->dropAllReferences();
F->eraseFromParent();
}
}
}
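// After remapping, a device-side call to e.g. llvm.sqrt.f64 resolves to
// libdevice's __nv_sqrt, the allocation routines are lowered onto device
// malloc, and GC bookkeeping calls become no-ops.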
void exploreGV(llvm::GlobalValue *G, llvm::SmallPtrSetImpl<llvm::GlobalValue *> &keep) {
if (keep.contains(G))
return;
keep.insert(G);
if (auto *F = llvm::dyn_cast<llvm::Function>(G)) {
for (auto I = llvm::inst_begin(F), E = llvm::inst_end(F); I != E; ++I) {
for (auto &U : I->operands()) {
if (auto *G2 = llvm::dyn_cast<llvm::GlobalValue>(U.get()))
exploreGV(G2, keep);
}
}
}
}
std::vector<llvm::GlobalValue *>
getRequiredGVs(const std::vector<llvm::GlobalValue *> &kernels) {
llvm::SmallPtrSet<llvm::GlobalValue *, 32> keep;
for (auto *G : kernels) {
exploreGV(G, keep);
}
return std::vector<llvm::GlobalValue *>(keep.begin(), keep.end());
}
void moduleToPTX(llvm::Module *M, const std::string &filename,
std::vector<llvm::GlobalValue *> &kernels,
const std::string &cpuStr = "sm_30",
const std::string &featuresStr = "+ptx42") {
llvm::Triple triple(llvm::Triple::normalize(GPU_TRIPLE));
llvm::TargetLibraryInfoImpl tlii(triple);
std::string err;
const llvm::Target *target =
llvm::TargetRegistry::lookupTarget("nvptx64", triple, err);
seqassertn(target, "couldn't lookup target: {}", err);
const llvm::TargetOptions options =
llvm::codegen::InitTargetOptionsFromCodeGenFlags(triple);
std::unique_ptr<llvm::TargetMachine> machine(target->createTargetMachine(
triple.getTriple(), cpuStr, featuresStr, options,
llvm::codegen::getExplicitRelocModel(), llvm::codegen::getExplicitCodeModel(),
llvm::CodeGenOpt::Aggressive));
M->setDataLayout(machine->createDataLayout());
auto keep = getRequiredGVs(kernels);
auto prune = [&](std::vector<llvm::GlobalValue *> keep) {
auto pm = std::make_unique<llvm::legacy::PassManager>();
pm->add(new llvm::TargetLibraryInfoWrapperPass(tlii));
// Delete everything but kernel functions.
pm->add(llvm::createGVExtractionPass(keep));
// Delete unreachable globals.
pm->add(llvm::createGlobalDCEPass());
// Remove dead debug info.
pm->add(llvm::createStripDeadDebugInfoPass());
// Remove dead func decls.
pm->add(llvm::createStripDeadPrototypesPass());
pm->run(*M);
};
// Remove non-kernel functions.
prune(keep);
// Link libdevice and other cleanup.
linkLibdevice(M, libdevice);
remapFunctions(M);
// Run NVPTX passes and general opt pipeline.
{
auto pm = std::make_unique<llvm::legacy::PassManager>();
auto fpm = std::make_unique<llvm::legacy::FunctionPassManager>(M);
pm->add(new llvm::TargetLibraryInfoWrapperPass(tlii));
pm->add(llvm::createTargetTransformInfoWrapperPass(
machine ? machine->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));
fpm->add(llvm::createTargetTransformInfoWrapperPass(
machine ? machine->getTargetIRAnalysis() : llvm::TargetIRAnalysis()));
if (machine) {
auto &ltm = dynamic_cast<llvm::LLVMTargetMachine &>(*machine);
llvm::Pass *tpc = ltm.createPassConfig(*pm);
pm->add(tpc);
}
pm->add(llvm::createInternalizePass([&](const llvm::GlobalValue &gv) {
return std::find(keep.begin(), keep.end(), &gv) != keep.end();
}));
llvm::PassManagerBuilder pmb;
unsigned optLevel = 3, sizeLevel = 0;
pmb.OptLevel = optLevel;
pmb.SizeLevel = sizeLevel;
pmb.Inliner = llvm::createFunctionInliningPass(optLevel, sizeLevel, false);
pmb.DisableUnrollLoops = false;
pmb.LoopVectorize = true;
pmb.SLPVectorize = true;
if (machine) {
machine->adjustPassManager(pmb);
}
pmb.populateModulePassManager(*pm);
pmb.populateFunctionPassManager(*fpm);
fpm->doInitialization();
for (llvm::Function &f : *M) {
fpm->run(f);
}
fpm->doFinalization();
pm->run(*M);
}
// Prune again after optimizations.
keep = getRequiredGVs(kernels);
prune(keep);
// Clean up names.
{
for (auto &G : M->globals()) {
G.setName(cleanUpName(G.getName()));
}
for (auto &F : M->functions()) {
if (F.getInstructionCount() > 0)
F.setName(cleanUpName(F.getName()));
}
for (auto *S : M->getIdentifiedStructTypes()) {
S->setName(cleanUpName(S->getName()));
}
}
// Generate PTX file.
{
std::error_code errcode;
auto out = std::make_unique<llvm::ToolOutputFile>(filename, errcode,
llvm::sys::fs::OF_Text);
if (errcode)
compilationError(errcode.message());
llvm::raw_pwrite_stream *os = &out->os();
auto &llvmtm = static_cast<llvm::LLVMTargetMachine &>(*machine);
auto *mmiwp = new llvm::MachineModuleInfoWrapperPass(&llvmtm);
llvm::legacy::PassManager pm;
pm.add(new llvm::TargetLibraryInfoWrapperPass(tlii));
seqassertn(!machine->addPassesToEmitFile(pm, *os, nullptr, llvm::CGFT_AssemblyFile,
/*DisableVerify=*/false, mmiwp),
"could not add passes");
const_cast<llvm::TargetLoweringObjectFile *>(llvmtm.getObjFileLowering())
->Initialize(mmiwp->getMMI().getContext(), *machine);
pm.run(*M);
out->keep();
}
}
void addInitCall(llvm::Module *M, const std::string &filename) {
llvm::LLVMContext &context = M->getContext();
auto f =
M->getOrInsertFunction("seq_nvptx_load_module", llvm::Type::getVoidTy(context),
llvm::Type::getInt8PtrTy(context));
auto *g = llvm::cast<llvm::Function>(f.getCallee());
g->setDoesNotThrow();
auto *filenameVar = new llvm::GlobalVariable(
*M, llvm::ArrayType::get(llvm::Type::getInt8Ty(context), filename.length() + 1),
/*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
llvm::ConstantDataArray::getString(context, filename), ".nvptx.filename");
filenameVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
llvm::IRBuilder<> B(context);
if (auto *init = M->getFunction("seq_init")) {
seqassertn(init->hasOneUse(), "seq_init used more than once");
auto *use = llvm::dyn_cast<llvm::CallBase>(init->use_begin()->getUser());
seqassertn(use, "seq_init use was not a call");
B.SetInsertPoint(use->getNextNode());
B.CreateCall(g, B.CreateBitCast(filenameVar, B.getInt8PtrTy()));
}
for (auto &F : M->functions()) {
if (F.hasFnAttribute("jit")) {
B.SetInsertPoint(F.getEntryBlock().getFirstNonPHI());
B.CreateCall(g, B.CreateBitCast(filenameVar, B.getInt8PtrTy()));
}
}
}
void cleanUpIntrinsics(llvm::Module *M) {
llvm::LLVMContext &context = M->getContext();
llvm::SmallVector<llvm::Function *, 16> remove;
for (auto &F : *M) {
if (F.getIntrinsicID() != llvm::Intrinsic::not_intrinsic &&
F.getName().startswith("llvm.nvvm"))
remove.push_back(&F);
}
for (auto *F : remove) {
F->replaceAllUsesWith(makeNoOp(F));
F->dropAllReferences();
F->eraseFromParent();
}
}
} // namespace
void applyGPUTransformations(llvm::Module *M, const std::string &ptxFilename) {
llvm::LLVMContext &context = M->getContext();
std::unique_ptr<llvm::Module> clone = llvm::CloneModule(*M);
clone->setTargetTriple(llvm::Triple::normalize(GPU_TRIPLE));
clone->setDataLayout(GPU_DL);
llvm::NamedMDNode *nvvmAnno = clone->getOrInsertNamedMetadata("nvvm.annotations");
std::vector<llvm::GlobalValue *> kernels;
for (auto &F : *clone) {
if (!F.hasFnAttribute("kernel"))
continue;
llvm::Metadata *nvvmElem[] = {
llvm::ConstantAsMetadata::get(&F),
llvm::MDString::get(context, "kernel"),
llvm::ConstantAsMetadata::get(
llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 1)),
};
nvvmAnno->addOperand(llvm::MDNode::get(context, nvvmElem));
kernels.push_back(&F);
}
if (kernels.empty())
return;
std::string filename = ptxFilename.empty() ? M->getSourceFileName() : ptxFilename;
if (filename.empty() || filename[0] == '<')
filename = "kernel";
llvm::SmallString<128> path(filename);
llvm::sys::path::replace_extension(path, "ptx");
filename = path.str();
moduleToPTX(clone.get(), filename, kernels);
cleanUpIntrinsics(M);
addInitCall(M, filename);
}
} // namespace ir
} // namespace codon


@ -0,0 +1,19 @@
#pragma once
#include <string>
#include "codon/sir/llvm/llvm.h"
namespace codon {
namespace ir {
/// Applies GPU-specific transformations and generates PTX
/// code from kernel functions in the given LLVM module.
/// @param module LLVM module containing GPU kernel functions (marked with "kernel"
/// annotation)
/// @param ptxFilename Filename for output PTX code; empty to use filename based on
/// module
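/// Example (sketch): applyGPUTransformations(module, "out.ptx") clones the
/// module, compiles the functions carrying the "kernel" attribute to out.ptx,
/// and rewrites the host module to load that PTX file at startup.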
void applyGPUTransformations(llvm::Module *module, const std::string &ptxFilename = "");
} // namespace ir
} // namespace codon


@ -22,6 +22,7 @@ namespace {
const std::string EXPORT_ATTR = "std.internal.attributes.export";
const std::string INLINE_ATTR = "std.internal.attributes.inline";
const std::string NOINLINE_ATTR = "std.internal.attributes.noinline";
const std::string GPU_KERNEL_ATTR = "std.gpu.kernel";
} // namespace
llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
@ -41,6 +42,8 @@ llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
std::string LLVMVisitor::getNameForFunction(const Func *x) {
if (isA<ExternalFunc>(x) || util::hasAttribute(x, EXPORT_ATTR)) {
return x->getUnmangledName();
} else if (util::hasAttribute(x, GPU_KERNEL_ATTR)) {
return x->getName();
} else {
return x->referenceString();
}
@ -49,7 +52,7 @@ std::string LLVMVisitor::getNameForFunction(const Func *x) {
LLVMVisitor::LLVMVisitor()
: util::ConstVisitor(), context(std::make_unique<llvm::LLVMContext>()), M(),
B(std::make_unique<llvm::IRBuilder<>>(*context)), func(nullptr), block(nullptr),
value(nullptr), vars(), funcs(), coro(), loops(), trycatch(), db(),
value(nullptr), vars(), funcs(), coro(), loops(), trycatch(), catches(), db(),
plugins(nullptr) {
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
@ -287,6 +290,7 @@ LLVMVisitor::takeModule(Module *module, const SrcInfo *src) {
coro.reset();
loops.clear();
trycatch.clear();
catches.clear();
db.reset();
context = std::make_unique<llvm::LLVMContext>();
@ -697,6 +701,16 @@ void LLVMVisitor::exitTryCatch() {
trycatch.pop_back();
}
void LLVMVisitor::enterCatch(CatchData data) {
catches.push_back(std::move(data));
catches.back().sequenceNumber = nextSequenceNumber++;
}
void LLVMVisitor::exitCatch() {
seqassertn(!catches.empty(), "no catches present");
catches.pop_back();
}
LLVMVisitor::TryCatchData *LLVMVisitor::getInnermostTryCatch() {
return trycatch.empty() ? nullptr : &trycatch.back();
}
@ -1119,7 +1133,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) {
err.print("LLVM", buf);
// LOG("-> ERR {}", x->referenceString());
// LOG(" {}", code);
compilationError(buf.str());
compilationError(fmt::format("{} ({})", buf.str(), x->getName()));
}
sub->setDataLayout(M->getDataLayout());
@ -1152,6 +1166,9 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
setDebugInfoForNode(x);
auto *fnAttributes = x->getAttribute<KeyValueAttribute>();
if (x->isJIT()) {
func->addFnAttr(llvm::Attribute::get(*context, "jit"));
}
if (x->isJIT() || (fnAttributes && fnAttributes->has(EXPORT_ATTR))) {
func->setLinkage(llvm::GlobalValue::ExternalLinkage);
} else {
@ -1163,6 +1180,11 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
if (fnAttributes && fnAttributes->has(NOINLINE_ATTR)) {
func->addFnAttr(llvm::Attribute::AttrKind::NoInline);
}
if (fnAttributes && fnAttributes->has(GPU_KERNEL_ATTR)) {
func->addFnAttr(llvm::Attribute::AttrKind::NoInline);
func->addFnAttr(llvm::Attribute::get(*context, "kernel"));
func->setLinkage(llvm::GlobalValue::ExternalLinkage);
}
func->setPersonalityFn(llvm::cast<llvm::Constant>(makePersonalityFunc().getCallee()));
auto *funcType = cast<types::FuncType>(x->getType());
@ -1352,6 +1374,10 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) {
return B->getDoubleTy();
}
if (auto *x = cast<types::Float32Type>(t)) {
return B->getFloatTy();
}
if (auto *x = cast<types::BoolType>(t)) {
return B->getInt8Ty();
}
@ -1406,6 +1432,11 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) {
return B->getIntNTy(x->getLen());
}
if (auto *x = cast<types::VectorType>(t)) {
return llvm::VectorType::get(getLLVMType(x->getBase()), x->getCount(),
/*Scalable=*/false);
}
if (auto *x = cast<dsl::types::CustomType>(t)) {
return x->getBuilder()->buildType(this);
}
@ -1429,6 +1460,11 @@ llvm::DIType *LLVMVisitor::getDITypeHelper(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}
if (auto *x = cast<types::Float32Type>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}
if (auto *x = cast<types::BoolType>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_boolean);
@ -1554,6 +1590,12 @@ llvm::DIType *LLVMVisitor::getDITypeHelper(
x->isSigned() ? llvm::dwarf::DW_ATE_signed : llvm::dwarf::DW_ATE_unsigned);
}
if (auto *x = cast<types::VectorType>(t)) {
return db.builder->createBasicType(x->getName(),
layout.getTypeAllocSizeInBits(type),
llvm::dwarf::DW_ATE_unsigned);
}
if (auto *x = cast<dsl::types::CustomType>(t)) {
return x->getBuilder()->buildDebugType(this);
}
@ -2094,7 +2136,12 @@ void LLVMVisitor::visit(const TryCatchFlow *x) {
}
B->CreateStore(excStateCaught, tc.excFlag);
CatchData cd;
cd.exception = objPtr;
cd.typeId = objType;
enterCatch(cd);
process(catches[i]->getHandler());
exitCatch();
B->SetInsertPoint(block);
B->CreateBr(tc.finallyBlock);
}
@ -2455,10 +2502,21 @@ void LLVMVisitor::visit(const ThrowInstr *x) {
// note: exception header should be set in the frontend
auto excAllocFunc = makeExcAllocFunc();
auto throwFunc = makeThrowFunc();
llvm::Value *obj = nullptr;
llvm::Value *typ = nullptr;
if (x->getValue()) {
process(x->getValue());
obj = value;
typ = B->getInt32(getTypeIdx(x->getValue()->getType()));
} else {
seqassertn(!catches.empty(), "empty raise outside of except block");
obj = catches.back().exception;
typ = catches.back().typeId;
}
B->SetInsertPoint(block);
llvm::Value *exc = B->CreateCall(
excAllocFunc, {B->getInt32(getTypeIdx(x->getValue()->getType())), value});
llvm::Value *exc = B->CreateCall(excAllocFunc, {typ, obj});
call(throwFunc, exc);
}
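// This is what makes a bare `raise` inside an except block work: a typeless
// ThrowInstr re-throws the exception object and type ID recorded by the
// innermost enclosing catch.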


@ -89,6 +89,11 @@ private:
}
};
struct CatchData : NestableData {
/// Caught exception object
llvm::Value *exception;
/// Runtime type ID of the caught exception
llvm::Value *typeId;
};
struct DebugInfo {
/// LLVM debug info builder
std::unique_ptr<llvm::DIBuilder> builder;
@ -139,6 +144,8 @@ private:
std::vector<LoopData> loops;
/// Try-catch data stack
std::vector<TryCatchData> trycatch;
/// Catch-block data stack
std::vector<CatchData> catches;
/// Debug information
DebugInfo db;
/// Plugin manager
@ -181,6 +188,8 @@ private:
void exitLoop();
void enterTryCatch(TryCatchData data);
void exitTryCatch();
void enterCatch(CatchData data);
void exitCatch();
TryCatchData *getInnermostTryCatch();
TryCatchData *getInnermostTryCatchBeforeLoop();


@ -32,6 +32,7 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"


@ -3,6 +3,7 @@
#include <algorithm>
#include "codon/sir/llvm/coro/Coroutines.h"
#include "codon/sir/llvm/gpu.h"
#include "codon/util/common.h"
#include "llvm/CodeGen/CommandFlags.h"
@ -642,13 +643,17 @@ void verify(llvm::Module *module) {
void optimize(llvm::Module *module, bool debug, bool jit, PluginManager *plugins) {
verify(module);
{
TIME("llvm/opt");
TIME("llvm/opt1");
runLLVMOptimizationPasses(module, debug, jit, plugins);
}
if (!debug) {
TIME("llvm/opt2");
runLLVMOptimizationPasses(module, debug, jit, plugins);
}
{
TIME("llvm/gpu");
applyGPUTransformations(module);
}
verify(module);
}


@ -57,6 +57,7 @@ const std::string Module::BOOL_NAME = "bool";
const std::string Module::BYTE_NAME = "byte";
const std::string Module::INT_NAME = "int";
const std::string Module::FLOAT_NAME = "float";
const std::string Module::FLOAT32_NAME = "float32";
const std::string Module::STRING_NAME = "str";
const std::string Module::EQ_MAGIC_NAME = "__eq__";
@ -126,7 +127,9 @@ Func *Module::getOrRealizeMethod(types::Type *parent, const std::string &methodN
return cache->realizeFunction(method, translateArgs(args),
translateGenerics(generics), cls);
} catch (const exc::ParserException &e) {
LOG_IR("getOrRealizeMethod parser error: {}", e.what());
for (int i = 0; i < e.messages.size(); i++)
LOG_IR("getOrRealizeMethod parser error at {}: {}", e.locations[i],
e.messages[i]);
return nullptr;
}
}
@ -162,7 +165,8 @@ types::Type *Module::getOrRealizeType(const std::string &typeName,
try {
return cache->realizeType(type, translateGenerics(generics));
} catch (const exc::ParserException &e) {
LOG_IR("getOrRealizeType parser error: {}", e.what());
for (int i = 0; i < e.messages.size(); i++)
LOG_IR("getOrRealizeType parser error at {}: {}", e.locations[i], e.messages[i]);
return nullptr;
}
}
@ -197,6 +201,12 @@ types::Type *Module::getFloatType() {
return Nr<types::FloatType>();
}
types::Type *Module::getFloat32Type() {
if (auto *rVal = getType(FLOAT32_NAME))
return rVal;
return Nr<types::Float32Type>();
}
types::Type *Module::getStringType() {
if (auto *rVal = getType(STRING_NAME))
return rVal;
@ -243,6 +253,10 @@ types::Type *Module::getIntNType(unsigned int len, bool sign) {
return getOrRealizeType(sign ? "Int" : "UInt", {len});
}
types::Type *Module::getVectorType(unsigned count, types::Type *base) {
return getOrRealizeType("Vec", {base, count});
}
types::Type *Module::getTupleType(std::vector<types::Type *> args) {
std::vector<ast::types::TypePtr> argTypes;
for (auto *t : args) {
@ -332,5 +346,14 @@ types::Type *Module::unsafeGetIntNType(unsigned int len, bool sign) {
return Nr<types::IntNType>(len, sign);
}
types::Type *Module::unsafeGetVectorType(unsigned int count, types::Type *base) {
auto *primitive = cast<types::PrimitiveType>(base);
seqassertn(primitive, "base type must be a primitive type");
auto name = types::VectorType::getInstanceName(count, primitive);
if (auto *rVal = getType(name))
return rVal;
return Nr<types::VectorType>(count, primitive);
}
} // namespace ir
} // namespace codon


@ -31,6 +31,7 @@ public:
static const std::string BYTE_NAME;
static const std::string INT_NAME;
static const std::string FLOAT_NAME;
static const std::string FLOAT32_NAME;
static const std::string STRING_NAME;
static const std::string EQ_MAGIC_NAME;
@ -305,6 +306,8 @@ public:
types::Type *getIntType();
/// @return the float type
types::Type *getFloatType();
/// @return the float32 type
types::Type *getFloat32Type();
/// @return the string type
types::Type *getStringType();
/// Gets a pointer type.
@ -335,6 +338,11 @@ public:
/// @param sign true if signed
/// @return a variable length integer type
types::Type *getIntNType(unsigned len, bool sign);
/// Gets a vector type.
/// @param count the vector size
/// @param base the vector base type (MUST be a primitive type)
/// @return a vector type
types::Type *getVectorType(unsigned count, types::Type *base);
/// Gets a tuple type.
/// @param args the arg types
/// @return the tuple type
@ -401,6 +409,12 @@ public:
/// @param sign true if signed
/// @return a variable length integer type
types::Type *unsafeGetIntNType(unsigned len, bool sign);
/// Gets a vector type. Should generally not be used as no
/// type-checker information is generated.
/// @param count the vector size
/// @param base the vector base type (MUST be a primitive type)
/// @return a vector type
types::Type *unsafeGetVectorType(unsigned count, types::Type *base);
private:
void store(types::Type *t) {


@ -281,6 +281,9 @@ struct CanonConstSub : public RewriteRule {
Value *lhs = v->front();
Value *rhs = v->back();
if (!lhs->getType()->is(rhs->getType()))
return;
Value *newCall = nullptr;
if (util::isConst<int64_t>(rhs)) {
auto c = util::getConst<int64_t>(rhs);


@ -1,6 +1,7 @@
#include "const_fold.h"
#include <cmath>
#include <utility>
#include "codon/sir/util/cloning.h"
#include "codon/sir/util/irtools.h"
@ -15,15 +16,28 @@ namespace ir {
namespace transform {
namespace folding {
namespace {
auto pyDivmod(int64_t self, int64_t other) {
auto d = self / other;
auto m = self - d * other;
if (m && ((other ^ m) < 0)) {
m += other;
d -= 1;
}
return std::make_pair(d, m);
}
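// Mirrors the typechecker's static divMod: e.g. pyDivmod(7, -2) returns
// {-4, -1}, matching Python's 7 // -2 and 7 % -2, where C gives {-3, 1}.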
template <typename Func, typename Out> class IntFloatBinaryRule : public RewriteRule {
private:
Func f;
std::string magic;
types::Type *out;
bool excludeRHSZero;
public:
IntFloatBinaryRule(Func f, std::string magic, types::Type *out)
: f(std::move(f)), magic(std::move(magic)), out(out) {}
IntFloatBinaryRule(Func f, std::string magic, types::Type *out,
bool excludeRHSZero = false)
: f(std::move(f)), magic(std::move(magic)), out(out),
excludeRHSZero(excludeRHSZero) {}
virtual ~IntFloatBinaryRule() noexcept = default;
@ -41,11 +55,15 @@ public:
if (isA<FloatConst>(leftConst) && isA<IntConst>(rightConst)) {
auto left = cast<FloatConst>(leftConst)->getVal();
auto right = cast<IntConst>(rightConst)->getVal();
if (excludeRHSZero && right == 0)
return;
return setResult(M->template N<TemplatedConst<Out>>(v->getSrcInfo(),
f(left, (double)right), out));
} else if (isA<IntConst>(leftConst) && isA<FloatConst>(rightConst)) {
auto left = cast<IntConst>(leftConst)->getVal();
auto right = cast<FloatConst>(rightConst)->getVal();
if (excludeRHSZero && right == 0.0)
return;
return setResult(M->template N<TemplatedConst<Out>>(v->getSrcInfo(),
f((double)left, right), out));
}
@ -140,15 +158,22 @@ template <typename Func> auto floatToFloatBinary(Module *m, Func f, std::string
std::move(f), std::move(magic), m->getFloatType(), m->getFloatType());
}
template <typename Func>
auto floatToFloatBinaryNoZeroRHS(Module *m, Func f, std::string magic) {
return std::make_unique<DoubleConstantBinaryRuleExcludeRHSZero<double, Func, double>>(
std::move(f), std::move(magic), m->getFloatType(), m->getFloatType());
}
template <typename Func> auto floatToBoolBinary(Module *m, Func f, std::string magic) {
return std::make_unique<DoubleConstantBinaryRule<double, Func, bool>>(
std::move(f), std::move(magic), m->getFloatType(), m->getBoolType());
}
template <typename Func>
auto intFloatToFloatBinary(Module *m, Func f, std::string magic) {
auto intFloatToFloatBinary(Module *m, Func f, std::string magic,
bool excludeRHSZero = false) {
return std::make_unique<IntFloatBinaryRule<Func, double>>(
std::move(f), std::move(magic), m->getFloatType());
std::move(f), std::move(magic), m->getFloatType(), excludeRHSZero);
}
template <typename Func>
@ -222,8 +247,15 @@ void FoldingPass::registerStandardRules(Module *m) {
intToIntBinary(m, BINOP(+), Module::ADD_MAGIC_NAME));
registerRule("int-constant-subtraction",
intToIntBinary(m, BINOP(-), Module::SUB_MAGIC_NAME));
if (pyNumerics) {
registerRule("int-constant-floor-div",
intToIntBinaryNoZeroRHS(
m, [](auto x, auto y) -> auto{ return pyDivmod(x, y).first; },
Module::FLOOR_DIV_MAGIC_NAME));
} else {
registerRule("int-constant-floor-div",
intToIntBinaryNoZeroRHS(m, BINOP(/), Module::FLOOR_DIV_MAGIC_NAME));
}
registerRule("int-constant-mul", intToIntBinary(m, BINOP(*), Module::MUL_MAGIC_NAME));
registerRule("int-constant-lshift",
intToIntBinary(m, BINOP(<<), Module::LSHIFT_MAGIC_NAME));
@ -233,8 +265,15 @@ void FoldingPass::registerStandardRules(Module *m) {
registerRule("int-constant-xor", intToIntBinary(m, BINOP(^), Module::XOR_MAGIC_NAME));
registerRule("int-constant-or", intToIntBinary(m, BINOP(|), Module::OR_MAGIC_NAME));
registerRule("int-constant-and", intToIntBinary(m, BINOP(&), Module::AND_MAGIC_NAME));
if (pyNumerics) {
registerRule("int-constant-mod",
intToIntBinaryNoZeroRHS(
m, [](auto x, auto y) -> auto{ return pyDivmod(x, y).second; },
Module::MOD_MAGIC_NAME));
} else {
registerRule("int-constant-mod",
intToIntBinaryNoZeroRHS(m, BINOP(%), Module::MOD_MAGIC_NAME));
}
// binary, double constant, int->bool
registerRule("int-constant-eq", intToBoolBinary(m, BINOP(==), Module::EQ_MAGIC_NAME));
@ -273,8 +312,13 @@ void FoldingPass::registerStandardRules(Module *m) {
floatToFloatBinary(m, BINOP(+), Module::ADD_MAGIC_NAME));
registerRule("float-constant-subtraction",
floatToFloatBinary(m, BINOP(-), Module::SUB_MAGIC_NAME));
if (pyNumerics) {
registerRule("float-constant-floor-div",
floatToFloatBinaryNoZeroRHS(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME));
} else {
registerRule("float-constant-floor-div",
floatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME));
}
registerRule("float-constant-mul",
floatToFloatBinary(m, BINOP(*), Module::MUL_MAGIC_NAME));
registerRule(
@ -301,8 +345,9 @@ void FoldingPass::registerStandardRules(Module *m) {
intFloatToFloatBinary(m, BINOP(+), Module::ADD_MAGIC_NAME));
registerRule("int-float-constant-subtraction",
intFloatToFloatBinary(m, BINOP(-), Module::SUB_MAGIC_NAME));
registerRule("int-float-constant-floor-div",
intFloatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME));
registerRule(
"int-float-constant-floor-div",
intFloatToFloatBinary(m, BINOP(/), Module::TRUE_DIV_MAGIC_NAME, pyNumerics));
registerRule("int-float-constant-mul",
intFloatToFloatBinary(m, BINOP(*), Module::MUL_MAGIC_NAME));


@ -12,18 +12,21 @@ namespace transform {
namespace folding {
class FoldingPass : public OperatorPass, public Rewriter {
private:
bool pyNumerics;
void registerStandardRules(Module *m);
public:
/// Constructs a folding pass.
FoldingPass() : OperatorPass(/*childrenFirst=*/true) {}
FoldingPass(bool pyNumerics = false)
: OperatorPass(/*childrenFirst=*/true), pyNumerics(pyNumerics) {}
static const std::string KEY;
std::string getKey() const override { return KEY; }
void run(Module *m) override;
void handle(CallInstr *v) override;
private:
void registerStandardRules(Module *m);
};
} // namespace folding


@ -13,12 +13,12 @@ const std::string FoldingPassGroup::KEY = "core-folding-pass-group";
FoldingPassGroup::FoldingPassGroup(const std::string &sideEffectsPass,
const std::string &reachingDefPass,
const std::string &globalVarPass, int repeat,
bool runGlobalDemotion)
bool runGlobalDemotion, bool pyNumerics)
: PassGroup(repeat) {
auto gdUnique = runGlobalDemotion ? std::make_unique<cleanup::GlobalDemotionPass>()
: std::unique_ptr<cleanup::GlobalDemotionPass>();
auto canonUnique = std::make_unique<cleanup::CanonicalizationPass>(sideEffectsPass);
auto fpUnique = std::make_unique<FoldingPass>();
auto fpUnique = std::make_unique<FoldingPass>(pyNumerics);
auto dceUnique = std::make_unique<cleanup::DeadCodeCleanupPass>(sideEffectsPass);
gd = gdUnique.get();


@ -29,9 +29,11 @@ public:
/// @param globalVarPass the key of the global variables pass
/// @param repeat default number of times to repeat the pass
/// @param runGlobalDemotion whether to demote globals if possible
/// @param pyNumerics whether to use Python (vs. C) semantics when folding
FoldingPassGroup(const std::string &sideEffectsPass,
const std::string &reachingDefPass, const std::string &globalVarPass,
int repeat = 5, bool runGlobalDemotion = true);
int repeat = 5, bool runGlobalDemotion = true,
bool pyNumerics = false);
bool shouldRepeat(int num) const override;
};


@ -185,9 +185,9 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
capKey,
/*globalAssignmentHasSideEffects=*/false),
{capKey});
registerPass(
std::make_unique<folding::FoldingPassGroup>(
seKey1, rdKey, globalKey, /*repeat=*/5, /*runGlobalDemotion=*/false),
registerPass(std::make_unique<folding::FoldingPassGroup>(
seKey1, rdKey, globalKey, /*repeat=*/5, /*runGlobalDemotion=*/false,
pyNumerics),
/*insertBefore=*/"", {seKey1, rdKey, globalKey},
{seKey1, rdKey, cfgKey, globalKey, capKey});
@ -198,10 +198,10 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
if (init != Init::JIT) {
// Don't demote globals in JIT mode, since they might be used later
// by another user input.
registerPass(
std::make_unique<folding::FoldingPassGroup>(seKey2, rdKey, globalKey,
registerPass(std::make_unique<folding::FoldingPassGroup>(
seKey2, rdKey, globalKey,
/*repeat=*/5,
/*runGlobalDemotion=*/true),
/*runGlobalDemotion=*/true, pyNumerics),
/*insertBefore=*/"", {seKey2, rdKey, globalKey},
{seKey2, rdKey, cfgKey, globalKey});
}


@ -89,6 +89,9 @@ private:
/// passes to avoid registering
std::vector<std::string> disabled;
/// whether to use Python (vs. C) numeric semantics in passes
bool pyNumerics;
public:
/// PassManager initialization mode.
enum Init {
@ -98,14 +101,17 @@ public:
JIT,
};
explicit PassManager(Init init, std::vector<std::string> disabled = {})
explicit PassManager(Init init, std::vector<std::string> disabled = {},
bool pyNumerics = false)
: km(), passes(), analyses(), executionOrder(), results(),
disabled(std::move(disabled)) {
disabled(std::move(disabled)), pyNumerics(pyNumerics) {
registerStandardPasses(init);
}
explicit PassManager(bool debug = false, std::vector<std::string> disabled = {})
: PassManager(debug ? Init::DEBUG : Init::RELEASE, std::move(disabled)) {}
explicit PassManager(bool debug = false, std::vector<std::string> disabled = {},
bool pyNumerics = false)
: PassManager(debug ? Init::DEBUG : Init::RELEASE, std::move(disabled),
pyNumerics) {}
/// Checks if the given pass is included in this manager.
/// @param key the pass key


@ -3,6 +3,7 @@
#include <algorithm>
#include <iterator>
#include <limits>
#include <unordered_set>
#include "codon/sir/transform/parallel/schedule.h"
#include "codon/sir/util/cloning.h"
@ -15,14 +16,22 @@ namespace transform {
namespace parallel {
namespace {
const std::string ompModule = "std.openmp";
const std::string gpuModule = "std.gpu";
const std::string builtinModule = "std.internal.builtin";
void warn(const std::string &msg, const Value *v) {
auto src = v->getSrcInfo();
compilationWarning(msg, src.file, src.line, src.col);
}
struct OMPTypes {
types::Type *i64 = nullptr;
types::Type *i32 = nullptr;
types::Type *i8ptr = nullptr;
types::Type *i32ptr = nullptr;
explicit OMPTypes(Module *M) {
i64 = M->getIntType();
i32 = M->getIntNType(32, /*sign=*/true);
i8ptr = M->getPointerType(M->getByteType());
i32ptr = M->getPointerType(i32);
@ -142,6 +151,28 @@ struct Reduction {
default:
return nullptr;
}
} else if (isA<types::Float32Type>(type)) {
auto *f32 = M->getOrRealizeType("float32");
float value = 0.0;
switch (kind) {
case Kind::ADD:
value = 0.0;
break;
case Kind::MUL:
value = 1.0;
break;
case Kind::MIN:
value = std::numeric_limits<float>::max();
break;
case Kind::MAX:
value = std::numeric_limits<float>::min();
break;
default:
return nullptr;
}
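// The IR has no float32 literal, so the identity is materialized as a double
// constant and converted through the float32 constructor below.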
return (*f32)(*M->getFloat(value));
}
auto *init = (*type)();
@ -239,6 +270,23 @@ struct Reduction {
default:
break;
}
} else if (isA<types::Float32Type>(type)) {
switch (kind) {
case Kind::ADD:
func = "_atomic_float32_add";
break;
case Kind::MUL:
func = "_atomic_float32_mul";
break;
case Kind::MIN:
func = "_atomic_float32_min";
break;
case Kind::MAX:
func = "_atomic_float32_max";
break;
default:
break;
}
}
if (!func.empty()) {
@ -476,10 +524,16 @@ struct SharedInfo {
Reduction reduction; // the reduction we're performing, or empty if none
};
struct ParallelLoopTemplateReplacer : public util::Operator {
struct LoopTemplateReplacer : public util::Operator {
BodiedFunc *parent;
CallInstr *replacement;
Var *loopVar;
LoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement, Var *loopVar)
: util::Operator(), parent(parent), replacement(replacement), loopVar(loopVar) {}
};
struct ParallelLoopTemplateReplacer : public LoopTemplateReplacer {
ReductionIdentifier *reds;
std::vector<SharedInfo> sharedInfo;
ReductionLocks locks;
@ -489,9 +543,8 @@ struct ParallelLoopTemplateReplacer : public util::Operator {
ParallelLoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement, Var *loopVar,
ReductionIdentifier *reds)
: util::Operator(), parent(parent), replacement(replacement), loopVar(loopVar),
reds(reds), sharedInfo(), locks(), locRef(nullptr), reductionLocRef(nullptr),
gtid(nullptr) {}
: LoopTemplateReplacer(parent, replacement, loopVar), reds(reds), sharedInfo(),
locks(), locRef(nullptr), reductionLocRef(nullptr), gtid(nullptr) {}
unsigned numReductions() {
unsigned num = 0;
@ -1102,6 +1155,72 @@ struct TaskLoopRoutineStubReplacer : public ParallelLoopTemplateReplacer {
}
};
struct GPULoopBodyStubReplacer : public util::Operator {
CallInstr *replacement;
Var *loopVar;
int64_t step;
GPULoopBodyStubReplacer(CallInstr *replacement, Var *loopVar, int64_t step)
: util::Operator(), replacement(replacement), loopVar(loopVar), step(step) {}
void handle(CallInstr *v) override {
auto *M = v->getModule();
auto *func = util::getFunc(v->getCallee());
if (!func)
return;
auto name = func->getUnmangledName();
if (name == "_gpu_loop_body_stub") {
seqassertn(replacement, "unexpected double replacement");
seqassertn(v->numArgs() == 2, "unexpected loop body stub");
// the template passes the loop index and the captured-arguments tuple to the body stub
auto *idx = v->front();
auto *args = v->back();
unsigned next = 0;
std::vector<Value *> newArgs;
for (auto *arg : *replacement) {
// std::cout << "A: " << *arg << std::endl;
if (getVarFromOutlinedArg(arg)->getId() == loopVar->getId()) {
// std::cout << "(loop var)" << std::endl;
newArgs.push_back(idx);
} else {
newArgs.push_back(util::tupleGet(args, next++));
}
}
auto *outlinedFunc = cast<BodiedFunc>(util::getFunc(replacement->getCallee()));
v->replaceAll(util::call(outlinedFunc, newArgs));
replacement = nullptr;
}
if (name == "_loop_step") {
v->replaceAll(M->getInt(step));
}
}
};
struct GPULoopTemplateReplacer : public LoopTemplateReplacer {
int64_t step;
GPULoopTemplateReplacer(BodiedFunc *parent, CallInstr *replacement, Var *loopVar,
int64_t step)
: LoopTemplateReplacer(parent, replacement, loopVar), step(step) {}
void handle(CallInstr *v) override {
auto *M = v->getModule();
auto *func = util::getFunc(v->getCallee());
if (!func)
return;
auto name = func->getUnmangledName();
if (name == "_loop_step") {
v->replaceAll(M->getInt(step));
}
}
};
struct OpenMPTransformData {
util::OutlineResult outline;
std::vector<Var *> sharedVars;
@ -1164,15 +1283,118 @@ ForkCallData createForkCall(Module *M, OMPTypes &types, Value *rawTemplateFunc,
seqassertn(forkFunc, "fork call function not found");
result.fork = util::call(forkFunc, {rawTemplateFunc, forkExtra});
auto *intType = M->getIntType();
if (sched->threads && sched->threads->getType()->is(intType)) {
if (sched->threads && sched->threads->getType()->is(types.i64)) {
auto *pushNumThreadsFunc =
M->getOrRealizeFunc("_push_num_threads", {intType}, {}, ompModule);
M->getOrRealizeFunc("_push_num_threads", {types.i64}, {}, ompModule);
seqassertn(pushNumThreadsFunc, "push num threads func not found");
result.pushNumThreads = util::call(pushNumThreadsFunc, {sched->threads});
}
return result;
}
struct CollapseResult {
ImperativeForFlow *collapsed = nullptr;
SeriesFlow *setup = nullptr;
std::string error;
operator bool() const { return collapsed != nullptr; }
};
struct LoopRange {
ImperativeForFlow *loop;
Var *start;
Var *stop;
int64_t step;
Var *len;
};
CollapseResult collapseLoop(BodiedFunc *parent, ImperativeForFlow *v, int64_t levels) {
auto fail = [](const std::string &error) {
CollapseResult bad;
bad.error = error;
return bad;
};
auto *M = v->getModule();
CollapseResult res;
if (levels < 1)
return fail("'collapse' must be at least 1");
std::vector<ImperativeForFlow *> loopNests = {v};
ImperativeForFlow *curr = v;
for (auto i = 0; i < levels - 1; i++) {
auto *body = cast<SeriesFlow>(curr->getBody());
seqassertn(body, "unexpected loop body");
if (std::distance(body->begin(), body->end()) != 1 ||
!isA<ImperativeForFlow>(body->front()))
return fail("loop nest not collapsible");
curr = cast<ImperativeForFlow>(body->front());
loopNests.push_back(curr);
}
std::vector<LoopRange> ranges;
auto *setup = M->Nr<SeriesFlow>();
auto *intType = M->getIntType();
auto *lenCalc =
M->getOrRealizeFunc("_range_len", {intType, intType, intType}, {}, ompModule);
seqassertn(lenCalc, "range length calculation function not found");
for (auto *loop : loopNests) {
LoopRange range;
range.loop = loop;
range.start = util::makeVar(loop->getStart(), setup, parent)->getVar();
range.stop = util::makeVar(loop->getEnd(), setup, parent)->getVar();
range.step = loop->getStep();
range.len = util::makeVar(util::call(lenCalc, {M->Nr<VarValue>(range.start),
M->Nr<VarValue>(range.stop),
M->getInt(range.step)}),
setup, parent)
->getVar();
ranges.push_back(range);
}
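// trip count of the collapsed loop = product of the individual loop lengths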
auto *numIters = M->getInt(1);
for (auto &range : ranges) {
numIters = (*numIters) * (*M->Nr<VarValue>(range.len));
}
auto *collapsedVar = M->Nr<Var>(M->getIntType(), /*global=*/false);
parent->push_back(collapsedVar);
auto *body = M->Nr<SeriesFlow>();
auto sched = std::make_unique<OMPSched>(*v->getSchedule());
sched->collapse = 0;
auto *collapsed = M->Nr<ImperativeForFlow>(M->getInt(0), 1, numIters, body,
collapsedVar, std::move(sched));
// reconstruct indices by successive divmods
Var *lastDiv = nullptr;
for (auto it = ranges.rbegin(); it != ranges.rend(); ++it) {
auto *k = lastDiv ? lastDiv : collapsedVar;
auto *div =
util::makeVar(*M->Nr<VarValue>(k) / *M->Nr<VarValue>(it->len), body, parent)
->getVar();
auto *mod =
util::makeVar(*M->Nr<VarValue>(k) % *M->Nr<VarValue>(it->len), body, parent)
->getVar();
auto *i =
*M->Nr<VarValue>(it->start) + *(*M->Nr<VarValue>(mod) * *M->getInt(it->step));
body->push_back(M->Nr<AssignInstr>(it->loop->getVar(), i));
lastDiv = div;
}
auto *oldBody = cast<SeriesFlow>(loopNests.back()->getBody());
for (auto *x : *oldBody) {
body->push_back(x);
}
res.collapsed = collapsed;
res.setup = setup;
return res;
}
} // namespace
const std::string OpenMPPass::KEY = "core-parallel-openmp";
@ -1279,7 +1501,22 @@ void OpenMPPass::handle(ForFlow *v) {
}
void OpenMPPass::handle(ImperativeForFlow *v) {
auto data = setupOpenMPTransform(v, cast<BodiedFunc>(getParentFunc()));
auto *parent = cast<BodiedFunc>(getParentFunc());
if (v->isParallel() && v->getSchedule()->collapse != 0) {
auto levels = v->getSchedule()->collapse;
auto collapse = collapseLoop(parent, v, levels);
if (collapse) {
v->replaceAll(collapse.collapsed);
v = collapse.collapsed;
insertBefore(collapse.setup);
} else if (!collapse.error.empty()) {
warn("could not collapse loop: " + collapse.error, v);
}
}
auto data = setupOpenMPTransform(v, parent);
if (!v->isParallel())
return;
@ -1292,6 +1529,12 @@ void OpenMPPass::handle(ImperativeForFlow *v) {
auto *sched = v->getSchedule();
OMPTypes types(M);
if (sched->gpu && !sharedVars.empty()) {
warn("GPU-parallel loop cannot modify external variables; ignoring", v);
v->setParallel(false);
return;
}
// gather extra arguments
std::vector<Value *> extraArgs;
std::vector<types::Type *> extraArgTypes;
@ -1304,18 +1547,63 @@ void OpenMPPass::handle(ImperativeForFlow *v) {
// template call
std::string templateFuncName;
if (sched->dynamic) {
if (sched->gpu) {
templateFuncName = "_gpu_loop_outline_template";
} else if (sched->dynamic) {
templateFuncName = "_dynamic_loop_outline_template";
} else if (sched->chunk) {
templateFuncName = "_static_chunked_loop_outline_template";
} else {
templateFuncName = "_static_loop_outline_template";
}
auto *intType = M->getIntType();
if (sched->gpu) {
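// record pre-existing GPU kernels so the new kernel created by
// instantiating the template below can be identified afterwards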
std::unordered_set<id_t> kernels;
const std::string gpuAttr = "std.gpu.kernel";
for (auto *var : *M) {
if (auto *func = cast<BodiedFunc>(var)) {
if (util::hasAttribute(func, gpuAttr))
kernels.insert(func->getId());
}
}
std::vector<types::Type *> templateFuncArgs = {types.i64, types.i64,
M->getTupleType(extraArgTypes)};
static int64_t instance = 0;
auto *templateFunc = M->getOrRealizeFunc(templateFuncName, templateFuncArgs,
{instance++}, gpuModule);
if (!templateFunc) {
warn("loop not compilable for GPU; ignoring", v);
v->setParallel(false);
return;
}
BodiedFunc *kernel = nullptr;
for (auto *var : *M) {
if (auto *func = cast<BodiedFunc>(var)) {
if (util::hasAttribute(func, gpuAttr) && kernels.count(func->getId()) == 0) {
seqassertn(!kernel, "multiple new kernels found after instantiation");
kernel = func;
}
}
}
seqassertn(kernel, "no new kernel found");
GPULoopBodyStubReplacer brep(outline.call, loopVar, v->getStep());
kernel->accept(brep);
util::CloneVisitor cv(M);
templateFunc = cast<Func>(cv.forceClone(templateFunc));
GPULoopTemplateReplacer rep(cast<BodiedFunc>(templateFunc), outline.call, loopVar,
v->getStep());
templateFunc->accept(rep);
v->replaceAll(util::call(
templateFunc, {v->getStart(), v->getEnd(), util::makeTuple(extraArgs, M)}));
} else {
std::vector<types::Type *> templateFuncArgs = {
types.i32ptr, types.i32ptr,
M->getPointerType(M->getTupleType(
{intType, intType, intType, M->getTupleType(extraArgTypes)}))};
{types.i64, types.i64, types.i64, M->getTupleType(extraArgTypes)}))};
auto *templateFunc =
M->getOrRealizeFunc(templateFuncName, templateFuncArgs, {}, ompModule);
seqassertn(templateFunc, "imperative loop outline template not found");
@ -1327,7 +1615,8 @@ void OpenMPPass::handle(ImperativeForFlow *v) {
templateFunc->accept(rep);
auto *rawTemplateFunc = ptrFromFunc(templateFunc);
auto *chunk = (sched->chunk && sched->chunk->getType()->is(intType)) ? sched->chunk
auto *chunk = (sched->chunk && sched->chunk->getType()->is(types.i64))
? sched->chunk
: M->getInt(1);
std::vector<Value *> forkExtraArgs = {chunk, v->getStart(), v->getEnd()};
for (auto *arg : extraArgs) {
@ -1339,6 +1628,7 @@ void OpenMPPass::handle(ImperativeForFlow *v) {
if (forkData.pushNumThreads)
insertBefore(forkData.pushNumThreads);
v->replaceAll(forkData.fork);
}
}
} // namespace parallel

View File

@ -47,17 +47,19 @@ Value *nullIfNeg(Value *v) {
}
} // namespace
OMPSched::OMPSched(int code, bool dynamic, Value *threads, Value *chunk, bool ordered)
OMPSched::OMPSched(int code, bool dynamic, Value *threads, Value *chunk, bool ordered,
int64_t collapse, bool gpu)
: code(code), dynamic(dynamic), threads(nullIfNeg(threads)),
chunk(nullIfNeg(chunk)), ordered(ordered) {
chunk(nullIfNeg(chunk)), ordered(ordered), collapse(collapse), gpu(gpu) {
if (code < 0)
this->code = getScheduleCode();
}
OMPSched::OMPSched(const std::string &schedule, Value *threads, Value *chunk,
bool ordered)
bool ordered, int64_t collapse, bool gpu)
: OMPSched(getScheduleCode(schedule, nullIfNeg(chunk) != nullptr, ordered),
(schedule != "static") || ordered, threads, chunk, ordered) {}
(schedule != "static") || ordered, threads, chunk, ordered, collapse,
gpu) {}
std::vector<Value *> OMPSched::getUsedValues() const {
std::vector<Value *> ret;

View File

@ -16,14 +16,18 @@ struct OMPSched {
Value *threads;
Value *chunk;
bool ordered;
int64_t collapse;
bool gpu;
explicit OMPSched(int code = -1, bool dynamic = false, Value *threads = nullptr,
Value *chunk = nullptr, bool ordered = false);
Value *chunk = nullptr, bool ordered = false, int64_t collapse = 0,
bool gpu = false);
explicit OMPSched(const std::string &code, Value *threads = nullptr,
Value *chunk = nullptr, bool ordered = false);
Value *chunk = nullptr, bool ordered = false, int64_t collapse = 0,
bool gpu = false);
OMPSched(const OMPSched &s)
: code(s.code), dynamic(s.dynamic), threads(s.threads), chunk(s.chunk),
ordered(s.ordered) {}
ordered(s.ordered), collapse(s.collapse), gpu(s.gpu) {}
std::vector<Value *> getUsedValues() const;
int replaceUsedValue(id_t id, Value *newValue);

View File

@ -65,6 +65,8 @@ const char IntType::NodeId = 0;
const char FloatType::NodeId = 0;
const char Float32Type::NodeId = 0;
const char BoolType::NodeId = 0;
const char ByteType::NodeId = 0;
@ -188,6 +190,12 @@ std::string IntNType::getInstanceName(unsigned int len, bool sign) {
return fmt::format(FMT_STRING("{}Int{}"), sign ? "" : "U", len);
}
const char VectorType::NodeId = 0;
std::string VectorType::getInstanceName(unsigned int count, PrimitiveType *base) {
return fmt::format(FMT_STRING("Vector[{}, {}]"), count, base->referenceString());
}
} // namespace types
} // namespace ir
} // namespace codon

View File

@ -138,6 +138,15 @@ public:
FloatType() : AcceptorExtend("float") {}
};
/// Float32 type (32-bit float)
class Float32Type : public AcceptorExtend<Float32Type, PrimitiveType> {
public:
static const char NodeId;
/// Constructs a float32 type.
Float32Type() : AcceptorExtend("float32") {}
};
/// Bool type (8-bit unsigned integer; either 0 or 1)
class BoolType : public AcceptorExtend<BoolType, PrimitiveType> {
public:
@ -424,7 +433,7 @@ private:
};
/// Type of a variably sized integer
class IntNType : public AcceptorExtend<IntNType, Type> {
class IntNType : public AcceptorExtend<IntNType, PrimitiveType> {
private:
/// length of the integer
unsigned len;
@ -451,9 +460,31 @@ public:
std::string oppositeSignName() const { return getInstanceName(len, !sign); }
static std::string getInstanceName(unsigned len, bool sign);
};
/// Type of a vector of primitives
class VectorType : public AcceptorExtend<VectorType, PrimitiveType> {
private:
bool doIsAtomic() const override { return true; }
/// number of elements
unsigned count;
/// base type
PrimitiveType *base;
public:
static const char NodeId;
/// Constructs a vector type.
/// @param count the number of elements
/// @param base the base type
VectorType(unsigned count, PrimitiveType *base)
: AcceptorExtend(getInstanceName(count, base)), count(count), base(base) {}
/// @return the count of the vector
unsigned getCount() const { return count; }
/// @return the base type of the vector
PrimitiveType *getBase() const { return base; }
static std::string getInstanceName(unsigned count, PrimitiveType *base);
};
} // namespace types

View File

@ -288,6 +288,9 @@ public:
void visit(const types::FloatType *v) override {
fmt::print(os, FMT_STRING("(float '\"{}\")"), v->referenceString());
}
void visit(const types::Float32Type *v) override {
fmt::print(os, FMT_STRING("(float32 '\"{}\")"), v->referenceString());
}
void visit(const types::BoolType *v) override {
fmt::print(os, FMT_STRING("(bool '\"{}\")"), v->referenceString());
}
@ -334,6 +337,10 @@ public:
fmt::print(os, FMT_STRING("(intn '\"{}\" {} (signed {}))"), v->referenceString(),
v->getLen(), v->isSigned());
}
void visit(const types::VectorType *v) override {
fmt::print(os, FMT_STRING("(vector '\"{}\" {} (count {}))"), v->referenceString(),
makeFormatter(v->getBase()), v->getCount());
}
void visit(const dsl::types::CustomType *v) override { v->doFormat(os); }
void format(const Node *n) {

View File

@ -233,6 +233,8 @@ public:
void handle(const types::IntType *, const types::IntType *) { result = true; }
VISIT(types::FloatType);
void handle(const types::FloatType *, const types::FloatType *) { result = true; }
VISIT(types::Float32Type);
void handle(const types::Float32Type *, const types::Float32Type *) { result = true; }
VISIT(types::BoolType);
void handle(const types::BoolType *, const types::BoolType *) { result = true; }
VISIT(types::ByteType);
@ -272,6 +274,10 @@ public:
void handle(const types::IntNType *x, const types::IntNType *y) {
result = x->getLen() == y->getLen() && x->isSigned() == y->isSigned();
}
VISIT(types::VectorType);
void handle(const types::VectorType *x, const types::VectorType *y) {
result = x->getCount() == y->getCount() && process(x->getBase(), y->getBase());
}
VISIT(dsl::types::CustomType);
void handle(const dsl::types::CustomType *x, const dsl::types::CustomType *y) {
result = x->match(y);

View File

@ -51,6 +51,7 @@ void Visitor::visit(types::Type *x) { defaultVisit(x); }
void Visitor::visit(types::PrimitiveType *x) { defaultVisit(x); }
void Visitor::visit(types::IntType *x) { defaultVisit(x); }
void Visitor::visit(types::FloatType *x) { defaultVisit(x); }
void Visitor::visit(types::Float32Type *x) { defaultVisit(x); }
void Visitor::visit(types::BoolType *x) { defaultVisit(x); }
void Visitor::visit(types::ByteType *x) { defaultVisit(x); }
void Visitor::visit(types::VoidType *x) { defaultVisit(x); }
@ -61,6 +62,7 @@ void Visitor::visit(types::OptionalType *x) { defaultVisit(x); }
void Visitor::visit(types::PointerType *x) { defaultVisit(x); }
void Visitor::visit(types::GeneratorType *x) { defaultVisit(x); }
void Visitor::visit(types::IntNType *x) { defaultVisit(x); }
void Visitor::visit(types::VectorType *x) { defaultVisit(x); }
void Visitor::visit(dsl::types::CustomType *x) { defaultVisit(x); }
void ConstVisitor::visit(const Module *x) { defaultVisit(x); }
@ -108,6 +110,7 @@ void ConstVisitor::visit(const types::Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::PrimitiveType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::IntType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::FloatType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float32Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::BoolType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::ByteType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::VoidType *x) { defaultVisit(x); }
@ -118,6 +121,7 @@ void ConstVisitor::visit(const types::OptionalType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::PointerType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::GeneratorType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::IntNType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::VectorType *x) { defaultVisit(x); }
void ConstVisitor::visit(const dsl::types::CustomType *x) { defaultVisit(x); }
} // namespace util

View File

@ -16,6 +16,7 @@ class Type;
class PrimitiveType;
class IntType;
class FloatType;
class Float32Type;
class BoolType;
class ByteType;
class VoidType;
@ -26,6 +27,7 @@ class OptionalType;
class PointerType;
class GeneratorType;
class IntNType;
class VectorType;
} // namespace types
namespace dsl {
@ -146,6 +148,7 @@ public:
VISIT(types::PrimitiveType);
VISIT(types::IntType);
VISIT(types::FloatType);
VISIT(types::Float32Type);
VISIT(types::BoolType);
VISIT(types::ByteType);
VISIT(types::VoidType);
@ -156,6 +159,7 @@ public:
VISIT(types::PointerType);
VISIT(types::GeneratorType);
VISIT(types::IntNType);
VISIT(types::VectorType);
VISIT(dsl::types::CustomType);
};
@ -220,6 +224,7 @@ public:
CONST_VISIT(types::PrimitiveType);
CONST_VISIT(types::IntType);
CONST_VISIT(types::FloatType);
CONST_VISIT(types::Float32Type);
CONST_VISIT(types::BoolType);
CONST_VISIT(types::ByteType);
CONST_VISIT(types::VoidType);
@ -230,6 +235,7 @@ public:
CONST_VISIT(types::PointerType);
CONST_VISIT(types::GeneratorType);
CONST_VISIT(types::IntNType);
CONST_VISIT(types::VectorType);
CONST_VISIT(dsl::types::CustomType);
};

View File

@ -31,6 +31,7 @@
## Advanced
* [Parallelism and multithreading](advanced/parallel.md)
* [GPU programming](advanced/gpu.md)
* [Pipelines](advanced/pipelines.md)
* [Intermediate representation](advanced/ir.md)
* [Building from source](advanced/build.md)

View File

@ -0,0 +1,249 @@
Codon supports GPU programming through a native GPU backend.
Currently, only Nvidia devices are supported.
Here is a simple example:
``` python
import gpu

@gpu.kernel
def hello(a, b, c):
    i = gpu.thread.x
    c[i] = a[i] + b[i]

a = [i for i in range(16)]
b = [2*i for i in range(16)]
c = [0 for _ in range(16)]

hello(a, b, c, grid=1, block=16)
print(c)
```
which outputs:
```
[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45]
```
The same code can be written using Codon's `@par` syntax:
``` python
a = [i for i in range(16)]
b = [2*i for i in range(16)]
c = [0 for _ in range(16)]

@par(gpu=True)
for i in range(16):
    c[i] = a[i] + b[i]

print(c)
```
Below is a more comprehensive example for computing the [Mandelbrot
set](https://en.wikipedia.org/wiki/Mandelbrot_set), and plotting it
using NumPy/Matplotlib:
``` python
from python import numpy as np
from python import matplotlib.pyplot as plt
import gpu

MAX = 1000  # maximum Mandelbrot iterations
N = 4096    # width and height of image

pixels = [0 for _ in range(N * N)]

def scale(x, a, b):
    return a + (x/N)*(b - a)

@gpu.kernel
def mandelbrot(pixels):
    idx = (gpu.block.x * gpu.block.dim.x) + gpu.thread.x
    i, j = divmod(idx, N)
    c = complex(scale(j, -2.00, 0.47), scale(i, -1.12, 1.12))
    z = 0j
    iteration = 0

    while abs(z) <= 2 and iteration < MAX:
        z = z**2 + c
        iteration += 1

    pixels[idx] = int(255 * iteration/MAX)

mandelbrot(pixels, grid=(N*N)//1024, block=1024)
plt.imshow(np.array(pixels).reshape(N, N))
plt.show()
```
The GPU version of the Mandelbrot code is about 450 times faster
than an equivalent CPU version.
GPU kernels are marked with the `@gpu.kernel` annotation and
compiled specially by Codon's backend. Kernel functions can
use the vast majority of features supported in Codon, with a
few notable exceptions:
- Exception handling is not supported inside the kernel, meaning
kernel code should not throw or catch exceptions. `raise`
statements inside the kernel are marked as unreachable and
optimized out.
- Functionality related to I/O is not supported (e.g. you can't
open a file in the kernel).
- A few other modules and functions are not allowed, such as the
`re` module (which uses an external regex library) or the `os`
module.
{% hint style="warning" %}
The GPU module is under active development. APIs and semantics
might change between Codon releases.
{% endhint %}
# Invoking the kernel
The kernel can be invoked via a simple call with added `grid` and
`block` parameters. These parameters define the grid and block
dimensions, respectively. Recall that GPU execution involves a *grid*
of (`X` x `Y` x `Z`) *blocks* where each block contains (`x` x `y` x `z`)
executing threads. Device-specific restrictions on grid and block sizes
apply.
The `grid` and `block` parameters can each be one of the following (an example follows the list):
- Single integer `x`, giving dimensions `(x, 1, 1)`
- Tuple of two integers `(x, y)`, giving dimensions `(x, y, 1)`
- Tuple of three integers `(x, y, z)`, giving dimensions `(x, y, z)`
- Instance of `gpu.Dim3` as in `Dim3(x, y, z)`, specifying the three dimensions
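For instance, based on the rules above, the following launches of the `hello` kernel from the first snippet are equivalent ways of specifying a `(1, 1, 1)` grid of `(16, 1, 1)` blocks:

``` python
import gpu

hello(a, b, c, grid=1, block=16)
hello(a, b, c, grid=(1, 1), block=(16, 1))
hello(a, b, c, grid=gpu.Dim3(1, 1, 1), block=gpu.Dim3(16, 1, 1))
```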
# GPU intrinsics
Codon's GPU module provides many of the same intrinsics that CUDA does:
| Codon | Description | CUDA equivalent |
|-------------------|-----------------------------------------|-----------------|
| `gpu.thread.x` | x-coordinate of current thread in block | `threadIdx.x` |
| `gpu.block.x` | x-coordinate of current block in grid | `blockIdx.x` |
| `gpu.block.dim.x` | x-dimension of block | `blockDim.x` |
| `gpu.grid.dim.x` | x-dimension of grid | `gridDim.x` |
The same applies to the `y` and `z` coordinates. The `*.dim` objects are instances
of `gpu.Dim3`.
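Kernels that span multiple blocks typically combine these intrinsics to compute a
global thread index, as the Mandelbrot example above does. A minimal sketch:

``` python
import gpu

@gpu.kernel
def double_all(v):
    i = (gpu.block.x * gpu.block.dim.x) + gpu.thread.x
    if i < len(v):  # guard: the last block may have excess threads
        v[i] = 2 * v[i]

v = [float(i) for i in range(1000)]
double_all(v, grid=4, block=256)  # 4 * 256 = 1024 threads >= 1000 elements
print(v[999])  # 1998
```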
# Math functions
All the functions in the `math` module are supported in kernel functions, and
are automatically replaced with GPU-optimized versions:
``` python
import math
import gpu

@gpu.kernel
def hello(x):
    i = gpu.thread.x
    x[i] = math.sqrt(x[i])  # uses __nv_sqrt from libdevice

x = [float(i) for i in range(10)]
hello(x, grid=1, block=10)
print(x)
```
gives:
```
[0, 1, 1.41421, 1.73205, 2, 2.23607, 2.44949, 2.64575, 2.82843, 3]
```
# Libdevice
Codon uses [libdevice](https://docs.nvidia.com/cuda/libdevice-users-guide/index.html)
for GPU-optimized math functions. The default libdevice path is
`/usr/local/cuda/nvvm/libdevice/libdevice.10.bc`. An alternative path can be specified
via the `-libdevice` compiler flag.
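For example, with a hypothetical CUDA installation path (shown only to illustrate the
flag; consult the compiler's help for the exact invocation):

```
codon build -release -libdevice=/opt/cuda/nvvm/libdevice/libdevice.10.bc kernels.codon
```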
# Working with raw pointers
By default, objects are converted entirely to their GPU counterparts, which have
the same data layout as the original objects (although the Codon compiler might perform
optimizations by swapping a CPU implementation of a data type with a GPU-optimized
implementation that exposes the same API). This preserves all of Codon/Python's
standard semantics within the kernel.
It is possible to use a kernel with raw pointers via `gpu.raw`, which corresponds
to how the kernel would be written in C++/CUDA:
``` python
import gpu

@gpu.kernel
def hello(a, b, c):
    i = gpu.thread.x
    c[i] = a[i] + b[i]

a = [i for i in range(16)]
b = [2*i for i in range(16)]
c = [0 for _ in range(16)]

# call the kernel with three int-pointer arguments:
hello(gpu.raw(a), gpu.raw(b), gpu.raw(c), grid=1, block=16)
print(c)  # output same as first snippet's
```
`gpu.raw` can avoid an extra pointer indirection, but it yields a Codon `Ptr` object,
so the corresponding kernel parameters will not expose the full list API; instead they
have the more limited `Ptr` API, which primarily supports indexing and assignment.
# Object conversions
A hidden API is used to copy objects to and from the GPU device. This API consists of
two new *magic methods*:
- `__to_gpu__(self)`: Allocates the necessary GPU memory and copies the object `self` to
the device.
- `__from_gpu__(self, gpu_object)`: Copies the GPU memory of `gpu_object` (which is
a value returned by `__to_gpu__`) back to the CPU object `self`.
For primitive types like `int` and `float`, `__to_gpu__` simply returns `self` and
`__from_gpu__` does nothing. These methods are defined for all the built-in types *and*
are automatically generated for user-defined classes, so most objects can be transferred
to and from the GPU seamlessly. A user-defined class that makes use of raw pointers
or other low-level constructs will have to define these methods itself for GPU use.
Please refer to the `gpu` module for implementation examples.
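As an illustration, here is a minimal, hypothetical sketch of such a class
(`MyBuffer` and its fields are invented for this example; note that the standard
library's conversion methods also thread through an allocation cache, as
`gpu.codon` shows):

``` python
import gpu

class MyBuffer:
    n: int
    p: Ptr[float]  # raw pointer: the default conversion cannot handle this

    def __to_gpu__(self, cache):
        # copy the n-element buffer into device memory; gpu.Pointer
        # implements this pattern for a raw pointer/length pair
        return MyBuffer(self.n, gpu.Pointer(self.p, self.n).__to_gpu__(cache))

    def __from_gpu__(self, other):
        # copy the device buffer back into this object's host allocation
        gpu.Pointer(self.p, self.n).__from_gpu__(other.p)
```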
# `@par(gpu=True)`
Codon's `@par` syntax can be used to seamlessly parallelize existing loops on the GPU,
without needing to explicitly write them as kernels. For loop nests, the `collapse` argument
can be used to cover the entire iteration space on the GPU. For example, here is the Mandelbrot
code above written using `@par`:
``` python
MAX = 1000  # maximum Mandelbrot iterations
N = 4096    # width and height of image

pixels = [0 for _ in range(N * N)]

def scale(x, a, b):
    return a + (x/N)*(b - a)

@par(gpu=True, collapse=2)
for i in range(N):
    for j in range(N):
        c = complex(scale(j, -2.00, 0.47), scale(i, -1.12, 1.12))
        z = 0j
        iteration = 0

        while abs(z) <= 2 and iteration < MAX:
            z = z**2 + c
            iteration += 1

        pixels[i*N + j] = int(255 * iteration/MAX)
```
Note that the `gpu=True` option disallows shared variables (i.e. assigning to out-of-loop
variables in the loop body) as well as reductions. The other GPU-specific restrictions
described here apply as well.
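For instance, the compiler warns about (and serializes) a loop like this:

``` python
total = 0

@par(gpu=True)
for i in range(100):
    total += i  # assigns an out-of-loop variable; not allowed on the GPU
```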
# Troubleshooting
CUDA errors that cause the kernel to abort are printed, and typically arise from invalid
code in the kernel, such as using exceptions or unsupported modules/objects.

View File

@ -29,6 +29,8 @@ for i in range(10):
- `chunk_size` (int): chunk size when partitioning loop iterations
- `ordered` (bool): whether the loop iterations should be executed in
the same order
- `collapse` (int): number of loop nests to collapse into a single
iteration space
Other OpenMP parameters like `private`, `shared` or `reduction`, are
inferred automatically by the compiler. For example, the following loop

View File

@ -30,6 +30,15 @@ compile to native code without any runtime performance overhead.
Future versions of Codon will lift some of these restrictions,
e.g. through the introduction of implicit union types.
# Numerics
For performance reasons, some numeric operations follow C semantics
rather than Python semantics. For example, dividing by zero does not
raise an exception by default, and some of the domain checks done by
`math` functions are omitted. Strict adherence to Python semantics
can be achieved by using the `-numerics=py` flag of the Codon
compiler. Note that this does *not* change `int` from being 64-bit.
# Modules
While most of the commonly used builtin modules have Codon-native

View File

@ -2,6 +2,50 @@ Below you can find release notes for each major Codon release,
listing improvements, updates, optimizations and more for each
new version.
# v0.14
## GPU support
GPU kernels can now be written and called in Codon. Existing
loops can be parallelized on the GPU with the `@par(gpu=True)`
annotation. Please see the [docs](../advanced/gpu.md) for
more information and examples.
## Semantics
Added `-numerics` flag, which specifies semantics of various
numeric operations:
- `-numerics=c` (default): C semantics; best performance
- `-numerics=py`: Python semantics (checks for zero divisors
and raises `ZeroDivisionError`, and adds domain checks to `math`
functions); might slightly decrease performance.
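For instance (a sketch; under the default C semantics the zero divisor is simply
not checked):

``` python
def div(a: int, b: int):
    return a / b

print(div(1, 0))  # -numerics=py: raises ZeroDivisionError("division by zero")
                  # -numerics=c (default): no check; result is undefined
```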
## Types
Added `float32` type to represent 32-bit floats (equivalent to C's
`float`). All `math` functions now have `float32` overloads.
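For example (a brief sketch):

``` python
import math

x = float32(2.0)
y = math.sqrt(x)  # resolves to the float32 overload of sqrt
```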
## Parallelism
Added `collapse` option to `@par`:
``` python
@par(collapse=2)  # parallelize entire iteration space of 2 loops
for i in range(N):
    for j in range(N):
        do_work(i, j)
```
## Standard library
Added `collections.defaultdict`.
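For example (a minimal sketch):

``` python
from collections import defaultdict

counts = defaultdict(int)  # missing keys default to int() == 0
for word in ('a', 'b', 'a'):
    counts[word] += 1
print(counts['a'])  # 2
```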
## Python interoperability
Various Python interoperability improvements: can now use `isinstance`
on Python objects/types and can now catch Python exceptions by name.
# v0.13
## Language

View File

@ -19,6 +19,12 @@ variants:
- `i32`/`u32`: signed/unsigned 32-bit integer
- `i64`/`u64`: signed/unsigned 64-bit integer
# 32-bit float
Codon's `float` type is a 64-bit floating point value. Codon
also supports `float32` (or `f32` as a shorthand), representing
a 32-bit floating point value (like C's `float`).
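For example (a brief sketch):

``` python
x = float32(0.25)
y = f32(0.75)  # f32 is shorthand for float32
print(x + y)   # arithmetic stays in 32-bit precision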
# Pointers
Codon has a `Ptr[T]` type that represents a pointer to an object

View File

@ -69,6 +69,6 @@ def heap_sort(collection: List[T], keyf: Callable[[T], S], T: type, S: type) ->
Heap Sort
Returns a sorted list.
"""
newlst = copy(collection)
newlst = collection.__copy__()
heap_sort_inplace(newlst, keyf)
return newlst

View File

@ -42,6 +42,6 @@ def insertion_sort(
Insertion Sort
Returns the sorted list.
"""
newlst = copy(collection)
newlst = collection.__copy__()
insertion_sort_inplace(newlst, keyf)
return newlst

View File

@ -321,6 +321,6 @@ def pdq_sort(collection: List[T], keyf: Callable[[T], S], T: type, S: type) -> L
Returns a sorted list.
"""
newlst = copy(collection)
newlst = collection.__copy__()
pdq_sort_inplace(newlst, keyf)
return newlst

View File

@ -114,13 +114,13 @@ class deque:
return deque(i.__deepcopy__() for i in self)
def __copy__(self) -> deque[T]:
return deque[T](copy(self._arr), self._head, self._tail, self._maxlen)
return deque[T](self._arr.__copy__(), self._head, self._tail, self._maxlen)
def copy(self) -> deque[T]:
return self.__copy__()
def __repr__(self) -> str:
return repr(List[T](iter(self)))
return f"deque({repr(List[T](iter(self)))})"
def _idx_check(self, idx: int, msg: str):
if self._head == self._tail or idx >= len(self) or idx < 0:
@ -323,25 +323,28 @@ class Counter(Dict[T, int]):
return result
def __add__(self, other: Counter[T]) -> Counter[T]:
result = copy(self)
result = self.__copy__()
result += other
return result
def __sub__(self, other: Counter[T]) -> Counter[T]:
result = copy(self)
result = self.__copy__()
result -= other
return result
def __and__(self, other: Counter[T]) -> Counter[T]:
result = copy(self)
result = self.__copy__()
result &= other
return result
def __or__(self, other: Counter[T]) -> Counter[T]:
result = copy(self)
result = self.__copy__()
result |= other
return result
def __repr__(self):
return f"Counter({super().__repr__()})"
def __dict_do_op_throws__(self, key: T, other: Z, op: F, F: type, Z: type):
self.__dict_do_op__(key, other, 0, op)
@ -357,5 +360,79 @@ class Dict:
self._init_from(other)
class defaultdict(Dict[K,V]):
default_factory: S
K: type
V: type
S: TypeVar[Callable[[], V]]
def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V]):
super().__init__()
self.default_factory = lambda: VV()
def __init__(self, f: S):
super().__init__()
self.default_factory = f
def __init__(self: defaultdict[K, VV, Function[[], V]], VV: TypeVar[V], other: Dict[K, V]):
super().__init__(other)
self.default_factory = lambda: VV()
def __init__(self, f: S, other: Dict[K, V]):
super().__init__(other)
self.default_factory = f
def __missing__(self, key: K):
default_value = self.default_factory()
self.__setitem__(key, default_value)
return default_value
def __getitem__(self, key: K) -> V:
if key not in self:
return self.__missing__(key)
return super().__getitem__(key)
def __dict_do_op_throws__(self, key: K, other: Z, op: F, F: type, Z: type):
x = self._kh_get(key)
if x == self._kh_end():
self.__missing__(key)
x = self._kh_get(key)
self._vals[x] = op(self._vals[x], other)
def copy(self):
d = defaultdict[K,V,S](self.default_factory)
d._init_from(self)
return d
def __copy__(self):
return self.copy()
def __deepcopy__(self):
d = defaultdict[K,V,S](self.default_factory)
for k,v in self.items():
d[k.__deepcopy__()] = v.__deepcopy__()
return d
def __eq__(self, other: defaultdict[K,V,S]) -> bool:
if self.__len__() != other.__len__():
return False
for k, v in self.items():
if k not in other or other[k] != v:
return False
return True
def __ne__(self, other: defaultdict[K,V,S]) -> bool:
return not (self == other)
def __repr__(self):
return f"defaultdict(<default factory of '{V.__name__}'>, {super().__repr__()})"
@extend
class Dict:
def __init__(self: Dict[K, V], other: defaultdict[K, V, S], S: type):
self._init_from(other)
def namedtuple(name: Static[str], args): # internal
pass

stdlib/copy.codon 100644
View File

@ -0,0 +1,14 @@
# (c) 2022 Exaloop Inc. All rights reserved.
class Error(Exception):
    def __init__(self, message: str = ""):
        super().__init__("copy.Error", message)

def copy(x):
    """Return a shallow copy of x"""
    return x.__copy__()

def deepcopy(x):
    """Return a deep copy of x"""
    return x.__deepcopy__()

stdlib/gpu.codon 100644
View File

@ -0,0 +1,732 @@
# (c) 2022 Exaloop Inc. All rights reserved.
from internal.gc import sizeof as _sizeof
@tuple
class Device:
_device: i32
def __new__(device: int):
from C import seq_nvptx_device(int) -> i32
return Device(seq_nvptx_device(device))
@staticmethod
def count():
from C import seq_nvptx_device_count() -> int
return seq_nvptx_device_count()
def __str__(self):
from C import seq_nvptx_device_name(i32) -> str
return seq_nvptx_device_name(self._device)
def __index__(self):
return int(self._device)
def __bool__(self):
return True
@property
def compute_capability(self):
from C import seq_nvptx_device_capability(i32) -> int
c = seq_nvptx_device_capability(self._device)
return (c >> 32, c & 0xffffffff)
@tuple
class Memory[T]:
_ptr: Ptr[byte]
def _alloc(n: int, T: type):
from C import seq_nvptx_device_alloc(int) -> Ptr[byte]
return Memory[T](seq_nvptx_device_alloc(n * _sizeof(T)))
def _read(self, p: Ptr[T], n: int):
from C import seq_nvptx_memcpy_d2h(Ptr[byte], Ptr[byte], int)
seq_nvptx_memcpy_d2h(p.as_byte(), self._ptr, n * _sizeof(T))
def _write(self, p: Ptr[T], n: int):
from C import seq_nvptx_memcpy_h2d(Ptr[byte], Ptr[byte], int)
seq_nvptx_memcpy_h2d(self._ptr, p.as_byte(), n * _sizeof(T))
def _free(self):
from C import seq_nvptx_device_free(Ptr[byte])
seq_nvptx_device_free(self._ptr)
@llvm
def syncthreads() -> None:
declare void @llvm.nvvm.barrier0()
call void @llvm.nvvm.barrier0()
ret {} {}
@tuple
class Dim3:
_x: u32
_y: u32
_z: u32
def __new__(x: int, y: int, z: int):
return Dim3(u32(x), u32(y), u32(z))
@property
def x(self):
return int(self._x)
@property
def y(self):
return int(self._y)
@property
def z(self):
return int(self._z)
@tuple
class Thread:
@property
def x(self):
@pure
@llvm
def get_x() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%res = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
ret i32 %res
return int(get_x())
@property
def y(self):
@pure
@llvm
def get_y() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.tid.y()
%res = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
ret i32 %res
return int(get_y())
@property
def z(self):
@pure
@llvm
def get_z() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.tid.z()
%res = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
ret i32 %res
return int(get_z())
@tuple
class Block:
@property
def x(self):
@pure
@llvm
def get_x() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
ret i32 %res
return int(get_x())
@property
def y(self):
@pure
@llvm
def get_y() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
%res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
ret i32 %res
return int(get_y())
@property
def z(self):
@pure
@llvm
def get_z() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
%res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
ret i32 %res
return int(get_z())
@property
def dim(self):
@pure
@llvm
def get_x() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
ret i32 %res
@pure
@llvm
def get_y() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
%res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
ret i32 %res
@pure
@llvm
def get_z() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
%res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
ret i32 %res
return Dim3(get_x(), get_y(), get_z())
@tuple
class Grid:
@property
def dim(self):
@pure
@llvm
def get_x() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
%res = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
ret i32 %res
@pure
@llvm
def get_y() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
%res = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
ret i32 %res
@pure
@llvm
def get_z() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
%res = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
ret i32 %res
return Dim3(get_x(), get_y(), get_z())
@tuple
class Warp:
def __len__(self):
@pure
@llvm
def get_warpsize() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.warpsize()
%res = call i32 @llvm.nvvm.read.ptx.sreg.warpsize()
ret i32 %res
return int(get_warpsize())
thread = Thread()
block = Block()
grid = Grid()
warp = Warp()
def _catch():
return (thread, block, grid, warp)
_catch()
@tuple
class AllocCache:
v: List[Ptr[byte]]
def add(self, p: Ptr[byte]):
self.v.append(p)
def free(self):
for p in self.v:
Memory[byte](p)._free()
def _tuple_from_gpu(args, gpu_args):
if staticlen(args) > 0:
a = args[0]
g = gpu_args[0]
a.__from_gpu__(g)
_tuple_from_gpu(args[1:], gpu_args[1:])
def kernel(fn):
from C import seq_nvptx_function(str) -> cobj
from C import seq_nvptx_invoke(cobj, u32, u32, u32, u32, u32, u32, u32, cobj)
def canonical_dim(dim):
if isinstance(dim, NoneType):
return (1, 1, 1)
elif isinstance(dim, int):
return (dim, 1, 1)
elif isinstance(dim, Tuple[int,int]):
return (dim[0], dim[1], 1)
elif isinstance(dim, Tuple[int,int,int]):
return dim
elif isinstance(dim, Dim3):
return (dim.x, dim.y, dim.z)
else:
compile_error("bad dimension argument")
def offsets(t):
@pure
@llvm
def offsetof(t: T, i: Static[int], T: type, S: type) -> int:
%p = getelementptr {=T}, {=T}* null, i64 0, i32 {=i}
%s = ptrtoint {=S}* %p to i64
ret i64 %s
if staticlen(t) == 0:
return ()
else:
T = type(t)
S = type(t[-1])
return (*offsets(t[:-1]), offsetof(t, staticlen(t) - 1, T, S))
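# wrapper: copy arguments to the device, look up the compiled PTX kernel
# by its LLVM name, launch it with the given grid/block dimensions, then
# copy results back and free the cached device allocations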
def wrapper(*args, grid, block):
grid = canonical_dim(grid)
block = canonical_dim(block)
cache = AllocCache([])
shared_mem = 0
gpu_args = tuple(arg.__to_gpu__(cache) for arg in args)
kernel_ptr = seq_nvptx_function(__realized__(fn, gpu_args).__llvm_name__)
p = __ptr__(gpu_args).as_byte()
arg_ptrs = tuple((p + offset) for offset in offsets(gpu_args))
seq_nvptx_invoke(kernel_ptr, u32(grid[0]), u32(grid[1]), u32(grid[2]), u32(block[0]),
u32(block[1]), u32(block[2]), u32(shared_mem), __ptr__(arg_ptrs).as_byte())
_tuple_from_gpu(args, gpu_args)
cache.free()
return wrapper
def _ptr_to_gpu(p: Ptr[T], n: int, cache: AllocCache, index_filter = lambda i: True, T: type):
from internal.gc import atomic
if not atomic(T):
tmp = Ptr[T](n)
for i in range(n):
if index_filter(i):
tmp[i] = p[i].__to_gpu__(cache)
p = tmp
mem = Memory._alloc(n, T)
cache.add(mem._ptr)
mem._write(p, n)
return Ptr[T](mem._ptr)
def _ptr_from_gpu(p: Ptr[T], q: Ptr[T], n: int, index_filter = lambda i: True, T: type):
from internal.gc import atomic
mem = Memory[T](q.as_byte())
if not atomic(T):
tmp = Ptr[T](n)
mem._read(tmp, n)
for i in range(n):
if index_filter(i):
p[i] = T.__from_gpu_new__(tmp[i])
else:
mem._read(p, n)
@pure
@llvm
def _ptr_to_type(p: cobj, T: type) -> T:
ret i8* %p
def _object_to_gpu(obj: T, cache: AllocCache, T: type):
s = tuple(obj)
gpu_mem = Memory._alloc(1, type(s))
cache.add(gpu_mem._ptr)
gpu_mem._write(__ptr__(s), 1)
return _ptr_to_type(gpu_mem._ptr, T)
def _object_from_gpu(obj):
T = type(obj)
S = type(tuple(obj))
tmp = T.__new__()
p = Ptr[S](tmp.__raw__())
q = Ptr[S](obj.__raw__())
mem = Memory[S](q.as_byte())
mem._read(p, 1)
return tmp
@tuple
class Pointer[T]:
_ptr: Ptr[T]
_len: int
def __to_gpu__(self, cache: AllocCache):
return _ptr_to_gpu(self._ptr, self._len, cache)
def __from_gpu__(self, other: Ptr[T]):
_ptr_from_gpu(self._ptr, other, self._len)
def __from_gpu_new__(other: Ptr[T]):
return other
def raw(v: List[T], T: type):
return Pointer(v.arr.ptr, len(v))
@extend
class Ptr:
def __to_gpu__(self, cache: AllocCache):
return self
def __from_gpu__(self, other: Ptr[T]):
pass
def __from_gpu_new__(other: Ptr[T]):
return other
@extend
class NoneType:
def __to_gpu__(self, cache: AllocCache):
return self
def __from_gpu__(self, other: NoneType):
pass
def __from_gpu_new__(other: NoneType):
return other
@extend
class int:
def __to_gpu__(self, cache: AllocCache):
return self
def __from_gpu__(self, other: int):
pass
def __from_gpu_new__(other: int):
return other
@extend
class float:
def __to_gpu__(self, cache: AllocCache):
return self
def __from_gpu__(self, other: float):
pass
def __from_gpu_new__(other: float):
return other
@extend
class float32:
def __to_gpu__(self, cache: AllocCache):
return self
def __from_gpu__(self, other: float32):
pass
def __from_gpu_new__(other: float32):
return other
@extend
class bool:
def __to_gpu__(self, cache: AllocCache):
return self
def __from_gpu__(self, other: bool):
pass
def __from_gpu_new__(other: bool):
return other
@extend
class byte:
def __to_gpu__(self, cache: AllocCache):
return self
def __from_gpu__(self, other: byte):
pass
def __from_gpu_new__(other: byte):
return other
@extend
class Int:
def __to_gpu__(self, cache: AllocCache):
return self
def __from_gpu__(self, other: Int[N]):
pass
def __from_gpu_new__(other: Int[N]):
return other
@extend
class UInt:
def __to_gpu__(self, cache: AllocCache):
return self
def __from_gpu__(self, other: UInt[N]):
pass
def __from_gpu_new__(other: UInt[N]):
return other
@extend
class str:
def __to_gpu__(self, cache: AllocCache):
n = self.len
return str(_ptr_to_gpu(self.ptr, n, cache), n)
def __from_gpu__(self, other: str):
pass
def __from_gpu_new__(other: str):
n = other.len
p = Ptr[byte](n)
_ptr_from_gpu(p, other.ptr, n)
return str(p, n)
@extend
class List:
@inline
def __to_gpu__(self, cache: AllocCache):
mem = List[T].__new__()
n = self.len
gpu_ptr = _ptr_to_gpu(self.arr.ptr, n, cache)
mem.arr = Array[T](gpu_ptr, n)
mem.len = n
return _object_to_gpu(mem, cache)
@inline
def __from_gpu__(self, other: List[T]):
mem = _object_from_gpu(other)
my_cap = self.arr.len
other_cap = mem.arr.len
if other_cap > my_cap:
self._resize(other_cap)
_ptr_from_gpu(self.arr.ptr, mem.arr.ptr, mem.len)
self.len = mem.len
@inline
def __from_gpu_new__(other: List[T]):
mem = _object_from_gpu(other)
arr = Array[T](mem.arr.len)
_ptr_from_gpu(arr.ptr, mem.arr.ptr, arr.len)
mem.arr = arr
return mem
@extend
class Dict:
def __to_gpu__(self, cache: AllocCache):
from internal.khash import __ac_fsize
mem = Dict[K,V].__new__()
n = self._n_buckets
f = __ac_fsize(n) if n else 0
mem._n_buckets = n
mem._size = self._size
mem._n_occupied = self._n_occupied
mem._upper_bound = self._upper_bound
mem._flags = _ptr_to_gpu(self._flags, f, cache)
mem._keys = _ptr_to_gpu(self._keys, n, cache, lambda i: self._kh_exist(i))
mem._vals = _ptr_to_gpu(self._vals, n, cache, lambda i: self._kh_exist(i))
return _object_to_gpu(mem, cache)
def __from_gpu__(self, other: Dict[K,V]):
from internal.khash import __ac_fsize
mem = _object_from_gpu(other)
my_n = self._n_buckets
n = mem._n_buckets
f = __ac_fsize(n) if n else 0
if my_n != n:
self._flags = Ptr[u32](f)
self._keys = Ptr[K](n)
self._vals = Ptr[V](n)
_ptr_from_gpu(self._flags, mem._flags, f)
_ptr_from_gpu(self._keys, mem._keys, n, lambda i: self._kh_exist(i))
_ptr_from_gpu(self._vals, mem._vals, n, lambda i: self._kh_exist(i))
self._n_buckets = n
self._size = mem._size
self._n_occupied = mem._n_occupied
self._upper_bound = mem._upper_bound
def __from_gpu_new__(other: Dict[K,V]):
from internal.khash import __ac_fsize
mem = _object_from_gpu(other)
n = mem._n_buckets
f = __ac_fsize(n) if n else 0
flags = Ptr[u32](f)
keys = Ptr[K](n)
vals = Ptr[V](n)
_ptr_from_gpu(flags, mem._flags, f)
mem._flags = flags
_ptr_from_gpu(keys, mem._keys, n, lambda i: mem._kh_exist(i))
mem._keys = keys
_ptr_from_gpu(vals, mem._vals, n, lambda i: mem._kh_exist(i))
mem._vals = vals
return mem
@extend
class Set:
def __to_gpu__(self, cache: AllocCache):
from internal.khash import __ac_fsize
mem = Set[K].__new__()
n = self._n_buckets
f = __ac_fsize(n) if n else 0
mem._n_buckets = n
mem._size = self._size
mem._n_occupied = self._n_occupied
mem._upper_bound = self._upper_bound
mem._flags = _ptr_to_gpu(self._flags, f, cache)
mem._keys = _ptr_to_gpu(self._keys, n, cache, lambda i: self._kh_exist(i))
return _object_to_gpu(mem, cache)
def __from_gpu__(self, other: Set[K]):
from internal.khash import __ac_fsize
mem = _object_from_gpu(other)
my_n = self._n_buckets
n = mem._n_buckets
f = __ac_fsize(n) if n else 0
if my_n != n:
self._flags = Ptr[u32](f)
self._keys = Ptr[K](n)
_ptr_from_gpu(self._flags, mem._flags, f)
_ptr_from_gpu(self._keys, mem._keys, n, lambda i: self._kh_exist(i))
self._n_buckets = n
self._size = mem._size
self._n_occupied = mem._n_occupied
self._upper_bound = mem._upper_bound
def __from_gpu_new__(other: Set[K]):
from internal.khash import __ac_fsize
mem = _object_from_gpu(other)
n = mem._n_buckets
f = __ac_fsize(n) if n else 0
flags = Ptr[u32](f)
keys = Ptr[K](n)
_ptr_from_gpu(flags, mem._flags, f)
mem._flags = flags
_ptr_from_gpu(keys, mem._keys, n, lambda i: mem._kh_exist(i))
mem._keys = keys
return mem
@extend
class Optional:
def __to_gpu__(self, cache: AllocCache):
if self is None:
return self
else:
return Optional[T](self.__val__().__to_gpu__(cache))
def __from_gpu__(self, other: Optional[T]):
if self is not None and other is not None:
self.__val__().__from_gpu__(other.__val__())
def __from_gpu_new__(other: Optional[T]):
if other is None:
return Optional[T]()
else:
return Optional[T](T.__from_gpu_new__(other.__val__()))
@extend
class __internal__:
def class_to_gpu(obj, cache: AllocCache):
if isinstance(obj, Tuple):
return tuple(a.__to_gpu__(cache) for a in obj)
elif isinstance(obj, ByVal):
T = type(obj)
return T(*tuple(a.__to_gpu__(cache) for a in tuple(obj)))
else:
T = type(obj)
S = type(tuple(obj))
mem = T.__new__()
Ptr[S](mem.__raw__())[0] = tuple(obj).__to_gpu__(cache)
return _object_to_gpu(mem, cache)
def class_from_gpu(obj, other):
if isinstance(obj, Tuple):
_tuple_from_gpu(obj, other)
elif isinstance(obj, ByVal):
_tuple_from_gpu(tuple(obj), tuple(other))
else:
S = type(tuple(obj))
Ptr[S](obj.__raw__())[0] = S.__from_gpu_new__(tuple(_object_from_gpu(other)))
def class_from_gpu_new(other):
if isinstance(other, Tuple):
return tuple(type(a).__from_gpu_new__(a) for a in other)
elif isinstance(other, ByVal):
T = type(other)
return T(*tuple(type(a).__from_gpu_new__(a) for a in tuple(other)))
else:
S = type(tuple(other))
mem = _object_from_gpu(other)
Ptr[S](mem.__raw__())[0] = S.__from_gpu_new__(tuple(mem))
return mem
# @par(gpu=True) support
@pure
@llvm
def _gpu_thread_x() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%res = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
ret i32 %res
@pure
@llvm
def _gpu_block_x() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%res = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
ret i32 %res
@pure
@llvm
def _gpu_block_dim_x() -> u32:
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%res = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
ret i32 %res
def _gpu_loop_outline_template(start, stop, args, instance: Static[int]):
@nonpure
def _loop_step():
return 1
@kernel
def _kernel_stub(start: int, count: int, args):
@nonpure
def _gpu_loop_body_stub(idx, args):
pass
@nonpure
def _dummy_use(n):
pass
_dummy_use(instance)
idx = (int(_gpu_block_dim_x()) * int(_gpu_block_x())) + int(_gpu_thread_x())
step = _loop_step()
if idx < count:
_gpu_loop_body_stub(start + (idx * step), args)
step = _loop_step()
loop = range(start, stop, step)
MAX_BLOCK = 1024
MAX_GRID = 2147483647
G = MAX_BLOCK * MAX_GRID
n = len(loop)
if n == 0:
return
elif n > G:
raise ValueError(f'loop exceeds GPU iteration limit of {G}')
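# launch configuration: one thread per iteration, at most MAX_BLOCK threads
# per block, and enough blocks to cover all n iterations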
block = n
grid = 1
if n > MAX_BLOCK:
block = MAX_BLOCK
grid = (n // MAX_BLOCK) + (0 if n % MAX_BLOCK == 0 else 1)
_kernel_stub(start, n, args, grid=grid, block=block)

View File

@ -1,6 +1,7 @@
# (c) 2022 Exaloop Inc. All rights reserved.
# Core library
from internal.core import *
from internal.attributes import *
from internal.types.ptr import *
from internal.types.str import *
@ -34,7 +35,11 @@ from internal.str import *
from internal.sort import sorted
from openmp import Ident as __OMPIdent, for_par
from gpu import _gpu_loop_outline_template
from internal.file import File, gzFile, open, gzopen
from pickle import pickle, unpickle
from internal.dlopen import dlsym as _dlsym
import internal.python
if __py_numerics__:
import internal.pynumerics

View File

@ -108,13 +108,6 @@ def iter(x):
return x.__iter__()
def copy(x):
"""
Return a copy of x
"""
return x.__copy__()
def abs(x):
"""
Return the absolute value of x

View File

@ -1,6 +1,6 @@
# (c) 2022 Exaloop Inc. All rights reserved.
# Seq runtime functions
# runtime functions
from C import seq_print(str)
from C import seq_print_full(str, cobj)
@ -216,7 +216,7 @@ def expm1(a: float) -> float:
@pure
@C
def ldexp(a: float, b: int) -> float:
def ldexp(a: float, b: i32) -> float:
pass
@ -406,6 +406,234 @@ def modf(a: float, b: Ptr[float]) -> float:
pass
@pure
@C
def ceilf(a: float32) -> float32:
pass
@pure
@C
def floorf(a: float32) -> float32:
pass
@pure
@C
def fabsf(a: float32) -> float32:
pass
@pure
@C
def fmodf(a: float32, b: float32) -> float32:
pass
@pure
@C
def expf(a: float32) -> float32:
pass
@pure
@C
def expm1f(a: float32) -> float32:
pass
@pure
@C
def ldexpf(a: float32, b: i32) -> float32:
pass
@pure
@C
def logf(a: float32) -> float32:
pass
@pure
@C
def log2f(a: float32) -> float32:
pass
@pure
@C
def log10f(a: float32) -> float32:
pass
@pure
@C
def sqrtf(a: float32) -> float32:
pass
@pure
@C
def powf(a: float32, b: float32) -> float32:
pass
@pure
@C
def roundf(a: float32) -> float32:
pass
@pure
@C
def acosf(a: float32) -> float32:
pass
@pure
@C
def asinf(a: float32) -> float32:
pass
@pure
@C
def atanf(a: float32) -> float32:
pass
@pure
@C
def atan2f(a: float32, b: float32) -> float32:
pass
@pure
@C
def cosf(a: float32) -> float32:
pass
@pure
@C
def sinf(a: float32) -> float32:
pass
@pure
@C
def tanf(a: float32) -> float32:
pass
@pure
@C
def coshf(a: float32) -> float32:
pass
@pure
@C
def sinhf(a: float32) -> float32:
pass
@pure
@C
def tanhf(a: float32) -> float32:
pass
@pure
@C
def acoshf(a: float32) -> float32:
pass
@pure
@C
def asinhf(a: float32) -> float32:
pass
@pure
@C
def atanhf(a: float32) -> float32:
pass
@pure
@C
def copysignf(a: float32, b: float32) -> float32:
pass
@pure
@C
def log1pf(a: float32) -> float32:
pass
@pure
@C
def truncf(a: float32) -> float32:
pass
@pure
@C
def erff(a: float32) -> float32:
pass
@pure
@C
def erfcf(a: float32) -> float32:
pass
@pure
@C
def tgammaf(a: float32) -> float32:
pass
@pure
@C
def lgammaf(a: float32) -> float32:
pass
@pure
@C
def remainderf(a: float32, b: float32) -> float32:
pass
@pure
@C
def hypotf(a: float32, b: float32) -> float32:
pass
@nocapture
@C
def frexpf(a: float32, b: Ptr[Int[32]]) -> float32:
pass
@nocapture
@C
def modff(a: float32, b: Ptr[float32]) -> float32:
pass
# <stdio.h>
@pure
@C

View File

@ -1,3 +1,5 @@
# (c) 2022 Exaloop Inc. All rights reserved.
@__internal__
class __internal__:
pass
@ -15,13 +17,23 @@ class byte:
@tuple
@__internal__
class int:
MAX = 9223372036854775807
pass
@tuple
@__internal__
class float:
MIN_10_EXP = -307
pass
@tuple
@__internal__
class float32:
MIN_10_EXP = -37
pass
@tuple
@__internal__
class NoneType:
@ -93,7 +105,7 @@ function = Function
# dummy
@__internal__
class TypeVar: pass
class TypeVar[T]: pass
@__internal__
class ByVal: pass
@__internal__
@ -154,3 +166,13 @@ def super():
def superf(*args):
"""Special handling"""
pass
def __realized__(fn, args):
pass
def statictuple(*args):
return args
def __static_print__(*args):
pass

View File

@ -145,8 +145,9 @@ class gzFile:
if self.buf[offset - 1] == byte(10): # '\n'
break
oldsz = self.sz
self.sz *= 2
self.buf = realloc(self.buf, self.sz)
self.buf = realloc(self.buf, self.sz, oldsz)
return offset

View File

@ -18,7 +18,7 @@ def seq_alloc_atomic(a: int) -> cobj:
@nocapture
@derives
@C
def seq_realloc(p: cobj, a: int) -> cobj:
def seq_realloc(p: cobj, newsize: int, oldsize: int) -> cobj:
pass
@ -76,8 +76,8 @@ def alloc_atomic(sz: int):
return seq_alloc_atomic(sz)
def realloc(p: cobj, sz: int):
return seq_realloc(p, sz)
def realloc(p: cobj, newsz: int, oldsz: int):
return seq_realloc(p, newsz, oldsz)
def free(p: cobj):

View File

@ -0,0 +1,168 @@
# (c) 2022 Exaloop Inc. All rights reserved.
@pure
@llvm
def _floordiv_int_float(self: int, other: float) -> float:
declare double @llvm.floor.f64(double)
%0 = sitofp i64 %self to double
%1 = fdiv double %0, %other
%2 = call double @llvm.floor.f64(double %1)
ret double %2
@pure
@llvm
def _floordiv_int_int(self: int, other: int) -> int:
%0 = sdiv i64 %self, %other
ret i64 %0
@pure
@llvm
def _truediv_int_float(self: int, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = fdiv double %0, %other
ret double %1
@pure
@llvm
def _truediv_int_int(self: int, other: int) -> float:
%0 = sitofp i64 %self to double
%1 = sitofp i64 %other to double
%2 = fdiv double %0, %1
ret double %2
@pure
@llvm
def _mod_int_float(self: int, other: float) -> float:
%0 = sitofp i64 %self to double
%1 = frem double %0, %other
ret double %1
@pure
@llvm
def _mod_int_int(self: int, other: int) -> int:
%0 = srem i64 %self, %other
ret i64 %0
@pure
@llvm
def _truediv_float_float(self: float, other: float) -> float:
%0 = fdiv double %self, %other
ret double %0
@pure
@llvm
def _mod_float_float(self: float, other: float) -> float:
%0 = frem double %self, %other
ret double %0
def _divmod_int_int(self: int, other: int):
d = _floordiv_int_int(self, other)
m = self - d * other
if m and ((other ^ m) < 0):
m += other
d -= 1
return (d, m)
def _divmod_float_float(self: float, other: float):
mod = _mod_float_float(self, other)
div = _truediv_float_float(self - mod, other)
if mod:
if (other < 0) != (mod < 0):
mod += other
div -= 1.0
else:
mod = (0.0).copysign(other)
floordiv = 0.0
if div:
floordiv = div.__floor__()
if div - floordiv > 0.5:
floordiv += 1.0
else:
floordiv = (0.0).copysign(self / other)
return (floordiv, mod)
@extend
class int:
def __floordiv__(self, other: float):
if other == 0.0:
raise ZeroDivisionError("float floor division by zero")
return _divmod_float_float(float(self), other)[0]
def __floordiv__(self, other: int):
if other == 0:
raise ZeroDivisionError("integer division or modulo by zero")
return _divmod_int_int(self, other)[0]
def __truediv__(self, other: float):
if other == 0.0:
raise ZeroDivisionError("float division by zero")
return _truediv_int_float(self, other)
def __truediv__(self, other: int):
if other == 0:
raise ZeroDivisionError("division by zero")
return _truediv_int_int(self, other)
def __mod__(self, other: float):
if other == 0.0:
raise ZeroDivisionError("float modulo")
return _divmod_float_float(float(self), other)[1]
def __mod__(self, other: int):
if other == 0:
raise ZeroDivisionError("integer division or modulo by zero")
return _divmod_int_int(self, other)[1]
def __divmod__(self, other: float):
if other == 0.0:
raise ZeroDivisionError("float divmod()")
return _divmod_float_float(float(self), other)
def __divmod__(self, other: int):
if other == 0:
raise ZeroDivisionError("integer division or modulo by zero")
return _divmod_int_int(self, other)
@extend
class float:
def __floordiv__(self, other: float):
if other == 0.0:
raise ZeroDivisionError("float floor division by zero")
return _divmod_float_float(self, other)[0]
def __floordiv__(self, other: int):
if other == 0:
raise ZeroDivisionError("float floor division by zero")
return _divmod_float_float(self, float(other))[0]
def __truediv__(self, other: float):
if other == 0.0:
raise ZeroDivisionError("float division by zero")
return _truediv_float_float(self, other)
def __truediv__(self, other: int):
if other == 0:
raise ZeroDivisionError("float division by zero")
return _truediv_float_float(self, float(other))
def __mod__(self, other: float):
if other == 0.0:
raise ZeroDivisionError("float modulo")
return _divmod_float_float(self, other)[1]
def __mod__(self, other: int):
if other == 0:
raise ZeroDivisionError("float modulo")
return _divmod_float_float(self, float(other))[1]
def __divmod__(self, other: float):
if other == 0.0:
raise ZeroDivisionError("float divmod()")
return _divmod_float_float(self, other)
def __divmod__(self, other: int):
if other == 0:
raise ZeroDivisionError("float divmod()")
return _divmod_float_float(self, float(other))
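Together these overloads reproduce CPython's sign conventions: the remainder takes the sign of the divisor, floor division rounds toward negative infinity, and true division always yields a float. A small illustration of the intended semantics (plain assertions, not part of the diff):

assert divmod(7, 2) == (3, 1)
assert divmod(-7, 2) == (-4, 1)     # truncating division would give (-3, -1)
assert divmod(7, -2) == (-4, -1)    # remainder follows the divisor's sign
assert 7.5 % -2.0 == -0.5
assert 7 / 2 == 3.5                 # __truediv__ always returns a float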

View File

@@ -12,8 +12,10 @@ PyImport_AddModule = Function[[cobj], cobj](cobj())
PyImport_AddModuleObject = Function[[cobj], cobj](cobj())
PyImport_ImportModule = Function[[cobj], cobj](cobj())
PyErr_Fetch = Function[[Ptr[cobj], Ptr[cobj], Ptr[cobj]], NoneType](cobj())
PyErr_NormalizeException = Function[[Ptr[cobj], Ptr[cobj], Ptr[cobj]], NoneType](cobj())
PyRun_SimpleString = Function[[cobj], NoneType](cobj())
PyEval_GetGlobals = Function[[], cobj](cobj())
PyEval_GetBuiltins = Function[[], cobj](cobj())
# conversions
PyLong_AsLong = Function[[cobj], int](cobj())
@@ -97,6 +99,7 @@ PyObject_GetItem = Function[[cobj, cobj], cobj](cobj())
PyObject_SetItem = Function[[cobj, cobj, cobj], int](cobj())
PyObject_DelItem = Function[[cobj, cobj], int](cobj())
PyObject_RichCompare = Function[[cobj, cobj, i32], cobj](cobj())
PyObject_IsInstance = Function[[cobj, cobj], i32](cobj())
# constants
Py_None = cobj()
@@ -154,8 +157,10 @@ def init_dl_handles(py_handle: cobj):
global PyImport_AddModuleObject
global PyImport_ImportModule
global PyErr_Fetch
global PyErr_NormalizeException
global PyRun_SimpleString
global PyEval_GetGlobals
global PyEval_GetBuiltins
global PyLong_AsLong
global PyLong_FromLong
global PyFloat_AsDouble
@@ -233,6 +238,7 @@ def init_dl_handles(py_handle: cobj):
global PyObject_SetItem
global PyObject_DelItem
global PyObject_RichCompare
global PyObject_IsInstance
global Py_None
global Py_True
global Py_False
@@ -244,8 +250,10 @@ def init_dl_handles(py_handle: cobj):
PyImport_AddModuleObject = dlsym(py_handle, "PyImport_AddModuleObject")
PyImport_ImportModule = dlsym(py_handle, "PyImport_ImportModule")
PyErr_Fetch = dlsym(py_handle, "PyErr_Fetch")
PyErr_NormalizeException = dlsym(py_handle, "PyErr_NormalizeException")
PyRun_SimpleString = dlsym(py_handle, "PyRun_SimpleString")
PyEval_GetGlobals = dlsym(py_handle, "PyEval_GetGlobals")
PyEval_GetBuiltins = dlsym(py_handle, "PyEval_GetBuiltins")
PyLong_AsLong = dlsym(py_handle, "PyLong_AsLong")
PyLong_FromLong = dlsym(py_handle, "PyLong_FromLong")
PyFloat_AsDouble = dlsym(py_handle, "PyFloat_AsDouble")
@@ -323,6 +331,7 @@ def init_dl_handles(py_handle: cobj):
PyObject_SetItem = dlsym(py_handle, "PyObject_SetItem")
PyObject_DelItem = dlsym(py_handle, "PyObject_DelItem")
PyObject_RichCompare = dlsym(py_handle, "PyObject_RichCompare")
PyObject_IsInstance = dlsym(py_handle, "PyObject_IsInstance")
Py_None = dlsym(py_handle, "_Py_NoneStruct")
Py_True = dlsym(py_handle, "_Py_TrueStruct")
Py_False = dlsym(py_handle, "_Py_FalseStruct")
@@ -603,19 +612,17 @@ class pyobj:
def exc_check():
ptype, pvalue, ptraceback = cobj(), cobj(), cobj()
PyErr_Fetch(__ptr__(ptype), __ptr__(pvalue), __ptr__(ptraceback))
PyErr_NormalizeException(__ptr__(ptype), __ptr__(pvalue), __ptr__(ptraceback))
if ptype != cobj():
py_msg = PyObject_Str(pvalue) if pvalue != cobj() else pvalue
msg = pyobj.to_str(py_msg, "ignore", "<empty Python message>")
py_typ = PyObject_GetAttrString(ptype, "__name__".c_str())
typ = pyobj.to_str(py_typ, "ignore")
pyobj.decref(ptype)
pyobj.decref(pvalue)
pyobj.decref(ptraceback)
pyobj.decref(py_msg)
pyobj.decref(py_typ)
raise PyError(msg, typ)
# pyobj.decref(pvalue)
raise PyError(msg, pyobj(pvalue))
def exc_wrap(_retval: T, T: type) -> T:
pyobj.exc_check()
@@ -667,7 +674,14 @@ class pyobj:
PyRun_SimpleString(code.c_str())
def _globals() -> pyobj:
return pyobj(PyEval_GetGlobals())
p = PyEval_GetGlobals()
if p == cobj():
Py_IncRef(Py_None)
return pyobj(Py_None)
return pyobj(p)
def _builtins() -> pyobj:
return pyobj(PyEval_GetBuiltins())
def _get_module(name: str) -> pyobj:
p = pyobj(pyobj.exc_wrap(PyImport_AddModule(name.c_str())))
@@ -686,6 +700,17 @@ class pyobj:
return bool(pyobj.exc_wrap(PyObject_IsTrue(self.p) == 1))
def _get_identifier(typ: str) -> pyobj:
t = pyobj._builtins()[typ]
if t.p == cobj():
t = pyobj._main_module()[typ]
return t
def _isinstance(what: pyobj, typ: pyobj) -> bool:
return bool(pyobj.exc_wrap(PyObject_IsInstance(what.p, typ.p)))
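Because PyError now carries the normalized exception object as a live pyobj rather than a type-name string, matching a raised Python exception against a Python class reduces to a PyObject_IsInstance call. A hypothetical sketch of the check this enables (helper names as defined above; _matches_python_type itself is illustrative, not part of the diff):

def _matches_python_type(err: PyError, name: str) -> bool:
    # Resolve the class in builtins, falling back to __main__,
    # then let Python decide whether the stored exception matches.
    typ = pyobj._get_identifier(name)
    return pyobj._isinstance(err.pytype, typ)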
# Type conversions
@extend

View File

@@ -1,3 +1,4 @@
# (c) 2022 Exaloop Inc. All rights reserved.
from internal.gc import sizeof

View File

@@ -106,6 +106,15 @@ class Dict:
def __len__(self) -> int:
return self._size
def __or__(self, other):
new = self.__copy__()
new.update(other)
return new
def __ior__(self, other):
self.update(other)
return self
def __copy__(self):
if self.__len__() == 0:
return Dict[K, V]()
@@ -183,9 +192,13 @@ class Dict:
ret, x = self._kh_put(key)
self._vals[x] = op(dflt if ret != 0 else self._vals[x], other)
def update(self, other: Dict[K, V]):
def update(self, other):
if isinstance(other, Dict[K, V]):
for k, v in other.items():
self[k] = v
else:
for k, v in other:
self[k] = v
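These additions mirror PEP 584's dict union: `|` copies the left operand and updates it with the right (so the right side wins on duplicate keys), `|=` updates in place, and update now also accepts any iterable of key-value pairs. For illustration:

d1 = {"a": 1, "b": 2}
d2 = {"b": 3, "c": 4}
assert (d1 | d2) == {"a": 1, "b": 3, "c": 4}   # right operand wins on "b"
d1 |= d2                                       # in-place union via update()
d1.update([("d", 5)])                          # iterable-of-pairs form
assert d1["d"] == 5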
def pop(self, key: K) -> V:
x = self._kh_get(key)
@@ -284,10 +297,14 @@ class Dict:
if self._n_buckets < new_n_buckets:
self._keys = Ptr[K](
gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K))
gc.realloc(self._keys.as_byte(),
new_n_buckets * gc.sizeof(K),
self._n_buckets * gc.sizeof(K))
)
self._vals = Ptr[V](
gc.realloc(self._vals.as_byte(), new_n_buckets * gc.sizeof(V))
gc.realloc(self._vals.as_byte(),
new_n_buckets * gc.sizeof(V),
self._n_buckets * gc.sizeof(V))
)
if j:
@@ -324,10 +341,14 @@ class Dict:
if self._n_buckets > new_n_buckets:
self._keys = Ptr[K](
gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K))
gc.realloc(self._keys.as_byte(),
new_n_buckets * gc.sizeof(K),
self._n_buckets * gc.sizeof(K))
)
self._vals = Ptr[V](
gc.realloc(self._vals.as_byte(), new_n_buckets * gc.sizeof(V))
gc.realloc(self._vals.as_byte(),
new_n_buckets * gc.sizeof(V),
self._n_buckets * gc.sizeof(V))
)
self._flags = new_flags

View File

@@ -238,7 +238,9 @@ class List:
len0 = self.__len__()
new_cap = n * len0
if self.arr.len < new_cap:
p = Ptr[T](gc.realloc(self.arr.ptr.as_byte(), new_cap * gc.sizeof(T)))
p = Ptr[T](gc.realloc(self.arr.ptr.as_byte(),
new_cap * gc.sizeof(T),
self.arr.len * gc.sizeof(T)))
self.arr = Array[T](p, new_cap)
idx = len0
@@ -349,7 +351,9 @@ class List:
raise IndexError(msg)
def _resize(self, new_cap: int):
p = Ptr[T](gc.realloc(self.arr.ptr.as_byte(), new_cap * gc.sizeof(T)))
p = Ptr[T](gc.realloc(self.arr.ptr.as_byte(),
new_cap * gc.sizeof(T),
self.arr.len * gc.sizeof(T)))
self.arr = Array[T](p, new_cap)
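Dict above, and Set below, make the same mechanical change as List here: every gc.realloc call now passes the current allocation size after the requested one, presumably so the collector can copy the live bytes when the block has to move. A sketch under that assumption (the actual internal.gc declaration is not shown in this diff, and gc.alloc here is only illustrative):

# Assumed shape: gc.realloc(ptr: cobj, new_size: int, old_size: int) -> cobj
old_cap, new_cap = 8, 16
p = gc.alloc(old_cap * gc.sizeof(int))
p = gc.realloc(p, new_cap * gc.sizeof(int),   # bytes requested
               old_cap * gc.sizeof(int))      # bytes currently held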
def _resize_if_full(self):
@@ -414,5 +418,37 @@ class List:
return Array[T](Ptr[T](), 0)
return self.arr.slice(start, stop).__copy__()
def _cmp(self, other: List[T]):
n1 = self.__len__()
n2 = other.__len__()
nmin = n1 if n1 < n2 else n2
for i in range(nmin):
a = self.arr[i]
b = other.arr[i]
if a < b:
return -1
elif a == b:
continue
else:
return 1
if n1 < n2:
return -1
elif n1 == n2:
return 0
else:
return 1
def __lt__(self, other: List[T]):
return self._cmp(other) < 0
def __gt__(self, other: List[T]):
return self._cmp(other) > 0
def __le__(self, other: List[T]):
return self._cmp(other) <= 0
def __ge__(self, other: List[T]):
return self._cmp(other) >= 0
list = List
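The four rich comparisons all funnel through the three-way _cmp, giving lexicographic order: the first unequal element decides, and a list that is a strict prefix of another compares less. For example:

assert [1, 2, 3] < [1, 2, 4]     # decided at the first differing element
assert [1, 2] < [1, 2, 0]        # a strict prefix compares less
assert not ([3] < [1, 2, 4])     # later elements are never consulted
assert [1, 2] >= [1, 2]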

View File

@@ -316,7 +316,9 @@ class Set:
if self._n_buckets < new_n_buckets:
self._keys = Ptr[K](
gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K))
gc.realloc(self._keys.as_byte(),
new_n_buckets * gc.sizeof(K),
self._n_buckets * gc.sizeof(K))
)
if j:
@@ -350,7 +352,9 @@ class Set:
if self._n_buckets > new_n_buckets:
self._keys = Ptr[K](
gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K))
gc.realloc(self._keys.as_byte(),
new_n_buckets * gc.sizeof(K),
self._n_buckets * gc.sizeof(K))
)
self._flags = new_flags

View File

@@ -72,9 +72,9 @@ class CError(Exception):
class PyError(Exception):
pytype: str
pytype: pyobj
def __init__(self, message: str, pytype: str):
def __init__(self, message: str, pytype: pyobj):
super().__init__("PyError", message)
self.pytype = pytype

View File

@@ -395,3 +395,326 @@ class float:
@property
def imag(self) -> float:
return 0.0
@extend
class float32:
@pure
@llvm
def __new__(self: float) -> float32:
%0 = fptrunc double %self to float
ret float %0
def __new__(what: float32) -> float32:
return what
def __new__() -> float32:
return float32.__new__(0.0)
def __repr__(self) -> str:
s = seq_str_float(self.__float__())
return s if s != "-nan" else "nan"
def __copy__(self) -> float32:
return self
def __deepcopy__(self) -> float32:
return self
@pure
@llvm
def __int__(self) -> int:
%0 = fptosi float %self to i64
ret i64 %0
@pure
@llvm
def __float__(self) -> float:
%0 = fpext float %self to double
ret double %0
@pure
@llvm
def __bool__(self) -> bool:
%0 = fcmp one float %self, 0.000000e+00
%1 = zext i1 %0 to i8
ret i8 %1
def __pos__(self) -> float32:
return self
@pure
@llvm
def __neg__(self) -> float32:
%0 = fneg float %self
ret float %0
@pure
@commutative
@llvm
def __add__(a: float32, b: float32) -> float32:
%tmp = fadd float %a, %b
ret float %tmp
@pure
@llvm
def __sub__(a: float32, b: float32) -> float32:
%tmp = fsub float %a, %b
ret float %tmp
@pure
@commutative
@llvm
def __mul__(a: float32, b: float32) -> float32:
%tmp = fmul float %a, %b
ret float %tmp
def __floordiv__(self, other: float32) -> float32:
return self.__truediv__(other).__floor__()
@pure
@llvm
def __truediv__(a: float32, b: float32) -> float32:
%tmp = fdiv float %a, %b
ret float %tmp
@pure
@llvm
def __mod__(a: float32, b: float32) -> float32:
%tmp = frem float %a, %b
ret float %tmp
def __divmod__(self, other: float32) -> Tuple[float32, float32]:
mod = self % other
div = (self - mod) / other
if mod:
if (other < float32(0.0)) != (mod < float32(0.0)):
mod += other
div -= float32(1.0)
else:
mod = float32(0.0).copysign(other)
floordiv = float32(0.0)
if div:
floordiv = div.__floor__()
if div - floordiv > float32(0.5):
floordiv += float32(1.0)
else:
floordiv = float32(0.0).copysign(self / other)
return (floordiv, mod)
@pure
@llvm
def __eq__(a: float32, b: float32) -> bool:
%tmp = fcmp oeq float %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __ne__(a: float32, b: float32) -> bool:
%tmp = fcmp one float %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __lt__(a: float32, b: float32) -> bool:
%tmp = fcmp olt float %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __gt__(a: float32, b: float32) -> bool:
%tmp = fcmp ogt float %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __le__(a: float32, b: float32) -> bool:
%tmp = fcmp ole float %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __ge__(a: float32, b: float32) -> bool:
%tmp = fcmp oge float %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def sqrt(a: float32) -> float32:
declare float @llvm.sqrt.f32(float %a)
%tmp = call float @llvm.sqrt.f32(float %a)
ret float %tmp
@pure
@llvm
def sin(a: float32) -> float32:
declare float @llvm.sin.f32(float %a)
%tmp = call float @llvm.sin.f32(float %a)
ret float %tmp
@pure
@llvm
def cos(a: float32) -> float32:
declare float @llvm.cos.f32(float %a)
%tmp = call float @llvm.cos.f32(float %a)
ret float %tmp
@pure
@llvm
def exp(a: float32) -> float32:
declare float @llvm.exp.f32(float %a)
%tmp = call float @llvm.exp.f32(float %a)
ret float %tmp
@pure
@llvm
def exp2(a: float32) -> float32:
declare float @llvm.exp2.f32(float %a)
%tmp = call float @llvm.exp2.f32(float %a)
ret float %tmp
@pure
@llvm
def log(a: float32) -> float32:
declare float @llvm.log.f32(float %a)
%tmp = call float @llvm.log.f32(float %a)
ret float %tmp
@pure
@llvm
def log10(a: float32) -> float32:
declare float @llvm.log10.f32(float %a)
%tmp = call float @llvm.log10.f32(float %a)
ret float %tmp
@pure
@llvm
def log2(a: float32) -> float32:
declare float @llvm.log2.f32(float %a)
%tmp = call float @llvm.log2.f32(float %a)
ret float %tmp
@pure
@llvm
def __abs__(a: float32) -> float32:
declare float @llvm.fabs.f32(float %a)
%tmp = call float @llvm.fabs.f32(float %a)
ret float %tmp
@pure
@llvm
def __floor__(a: float32) -> float32:
declare float @llvm.floor.f32(float %a)
%tmp = call float @llvm.floor.f32(float %a)
ret float %tmp
@pure
@llvm
def __ceil__(a: float32) -> float32:
declare float @llvm.ceil.f32(float %a)
%tmp = call float @llvm.ceil.f32(float %a)
ret float %tmp
@pure
@llvm
def __trunc__(a: float32) -> float32:
declare float @llvm.trunc.f32(float %a)
%tmp = call float @llvm.trunc.f32(float %a)
ret float %tmp
@pure
@llvm
def rint(a: float32) -> float32:
declare float @llvm.rint.f32(float %a)
%tmp = call float @llvm.rint.f32(float %a)
ret float %tmp
@pure
@llvm
def nearbyint(a: float32) -> float32:
declare float @llvm.nearbyint.f32(float %a)
%tmp = call float @llvm.nearbyint.f32(float %a)
ret float %tmp
@pure
@llvm
def __round__(a: float32) -> float32:
declare float @llvm.round.f32(float %a)
%tmp = call float @llvm.round.f32(float %a)
ret float %tmp
@pure
@llvm
def __pow__(a: float32, b: float32) -> float32:
declare float @llvm.pow.f32(float %a, float %b)
%tmp = call float @llvm.pow.f32(float %a, float %b)
ret float %tmp
@pure
@llvm
def min(a: float32, b: float32) -> float32:
declare float @llvm.minnum.f32(float %a, float %b)
%tmp = call float @llvm.minnum.f32(float %a, float %b)
ret float %tmp
@pure
@llvm
def max(a: float32, b: float32) -> float32:
declare float @llvm.maxnum.f32(float %a, float %b)
%tmp = call float @llvm.maxnum.f32(float %a, float %b)
ret float %tmp
@pure
@llvm
def copysign(a: float32, b: float32) -> float32:
declare float @llvm.copysign.f32(float %a, float %b)
%tmp = call float @llvm.copysign.f32(float %a, float %b)
ret float %tmp
@pure
@llvm
def fma(a: float32, b: float32, c: float32) -> float32:
declare float @llvm.fma.f32(float %a, float %b, float %c)
%tmp = call float @llvm.fma.f32(float %a, float %b, float %c)
ret float %tmp
@nocapture
@llvm
def __atomic_xchg__(d: Ptr[float32], b: float32) -> None:
%tmp = atomicrmw xchg float* %d, float %b seq_cst
ret {} {}
@nocapture
@llvm
def __atomic_add__(d: Ptr[float32], b: float32) -> float32:
%tmp = atomicrmw fadd float* %d, float %b seq_cst
ret float %tmp
@nocapture
@llvm
def __atomic_sub__(d: Ptr[float32], b: float32) -> float32:
%tmp = atomicrmw fsub float* %d, float %b seq_cst
ret float %tmp
def __hash__(self) -> int:
return self.__float__().__hash__()
def __match__(self, i: float32) -> bool:
return self == i
@extend
class float:
def __suffix_f32__(double) -> float32:
return float32.__new__(double)
f32 = float32
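With __suffix_f32__ in place, 32-bit floats can be written as suffixed literals and stay in single precision until explicitly widened. A brief usage sketch (assuming the suffix-literal syntax that __suffix_f32__ enables):

x = 1.5f32                     # literal routed through __suffix_f32__
y = float32(2.0)               # explicit construction from a double
z = x * y + float32(0.25)      # fmul/fadd on float, not double
assert isinstance(z, float32)
assert z.__float__() == 3.25   # widen back to 64-bit only when asked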

View File

@@ -182,6 +182,9 @@ class int:
%tmp = srem i64 %a, %b
ret i64 %tmp
def __divmod__(self, other: float) -> Tuple[float, float]:
return float(self).__divmod__(other)
def __divmod__(self, other: int) -> Tuple[int, int]:
d = self // other
m = self - d * other

View File

@@ -95,3 +95,11 @@ class range:
return f"range({self.start}, {self.stop})"
else:
return f"range({self.start}, {self.stop}, {self.step})"
@overload
def staticrange(start: Static[int], stop: Static[int], step: Static[int] = 1):
return range(start, stop, step)
@overload
def staticrange(stop: Static[int]):
return range(0, stop, 1)
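These overloads let staticrange degrade gracefully to an ordinary range when no compiler specialization applies; with Static[int] bounds the trip count is a compile-time constant, so loops over staticrange can be fully unrolled (which the SIMD and GPU loop handling relies on). Illustrative use:

total = 0
for i in staticrange(4):          # bounds known at compile time
    total += i
for i in staticrange(0, 8, 2):    # start/stop/step form
    total += i
assert total == 6 + 12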

View File

@@ -1,11 +1,23 @@
# (c) 2022 Exaloop Inc. All rights reserved.
e = 2.718281828459045
pi = 3.141592653589793
tau = 6.283185307179586
inf = 1.0 / 0.0
nan = 0.0 / 0.0
@pure
@llvm
def _inf() -> float:
ret double 0x7FF0000000000000
@pure
@llvm
def _nan() -> float:
ret double 0x7FF8000000000000
e = 2.7182818284590452354
pi = 3.14159265358979323846
tau = 6.28318530717958647693
inf = _inf()
nan = _nan()
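Spelling the constants as LLVM bit patterns avoids evaluating 1.0 / 0.0 and 0.0 / 0.0 at load time: 0x7FF0000000000000 is the IEEE 754 double with an all-ones exponent and zero mantissa (+inf), and 0x7FF8000000000000 additionally sets the quiet-NaN bit. Quick sanity checks, using the isnan/isinf helpers defined below:

assert inf > 1.79e308      # larger than any finite double
assert inf + inf == inf
assert nan != nan          # NaN is unordered, even against itself
assert isnan(nan) and isinf(inf)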
def factorial(x: int) -> int:
@@ -43,11 +55,14 @@ def isnan(x: float) -> bool:
Return True if float arg is a NaN, else False.
"""
test = x == x
# if it is true then it is a number
if test:
return False
return True
@pure
@llvm
def f(x: float) -> bool:
%y = fcmp uno double %x, 0.000000e+00
%z = zext i1 %y to i8
ret i8 %z
return f(x)
def isinf(x: float) -> bool:
@@ -56,7 +71,16 @@ def isinf(x: float) -> bool:
Return True if float arg is an INF, else False.
"""
return x == inf or x == -inf
@pure
@llvm
def f(x: float) -> bool:
declare double @llvm.fabs.f64(double)
%a = call double @llvm.fabs.f64(double %x)
%b = fcmp oeq double %a, 0x7FF0000000000000
%c = zext i1 %b to i8
ret i8 %c
return f(x)
def isfinite(x: float) -> bool:
@@ -66,9 +90,35 @@ def isfinite(x: float) -> bool:
Return True if x is neither an infinity nor a NaN,
and False otherwise.
"""
if isnan(x) or isinf(x):
return False
return True
return not (isnan(x) or isinf(x))
def _check1(arg: float, r: float, can_overflow: bool = False):
if __py_numerics__:
if isnan(r) and not isnan(arg):
raise ValueError("math domain error")
if isinf(r) and isfinite(arg):
if can_overflow:
raise OverflowError("math range error")
else:
raise ValueError("math domain error")
return r
def _check2(x: float, y: float, r: float, can_overflow: bool = False):
if __py_numerics__:
if isnan(r) and not isnan(x) and not isnan(y):
raise ValueError("math domain error")
if isinf(r) and isfinite(x) and isfinite(y):
if can_overflow:
raise OverflowError("math range error")
else:
raise ValueError("math domain error")
return r
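_check1 and _check2 emulate CPython's errno-style error reporting: a NaN result from non-NaN inputs signals a domain error, while an infinite result from finite inputs signals a range error for functions that can legitimately overflow (can_overflow=True) and a domain error otherwise. A sketch of the behavior this is meant to match when __py_numerics__ is enabled (not part of the diff):

def _demo():
    try:
        sqrt(-1.0)              # NaN from a non-NaN input
    except ValueError:
        pass                    # "math domain error"
    try:
        exp(1000.0)             # inf from a finite input, can_overflow=True
    except OverflowError:
        pass                    # "math range error"
    assert exp(inf) == inf      # infinite input propagates without error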
def ceil(x: float) -> float:
@@ -149,7 +199,7 @@ def exp(x: float) -> float:
%y = call double @llvm.exp.f64(double %x)
ret double %y
return f(x)
return _check1(x, f(x), True)
def expm1(x: float) -> float:
@@ -159,7 +209,7 @@ def expm1(x: float) -> float:
Return e raised to the power x, minus 1. expm1 provides
a way to compute this quantity to full precision.
"""
return _C.expm1(x)
return _check1(x, _C.expm1(x), True)
def ldexp(x: float, i: int) -> float:
@@ -168,7 +218,7 @@ def ldexp(x: float, i: int) -> float:
Returns x multiplied by 2 raised to the power of exponent.
"""
return _C.ldexp(x, i)
return _check1(x, _C.ldexp(x, i32(i)), True)
def log(x: float, base: float = e) -> float:
@@ -185,9 +235,9 @@ def log(x: float, base: float = e) -> float:
ret double %y
if base == e:
return f(x)
return _check1(x, f(x))
else:
return f(x) / f(base)
return _check1(x, f(x)) / _check1(base, f(base))
def log2(x: float) -> float:
@@ -203,7 +253,7 @@ def log2(x: float) -> float:
%y = call double @llvm.log2.f64(double %x)
ret double %y
return f(x)
return _check1(x, f(x))
def log10(x: float) -> float:
@@ -219,7 +269,7 @@ def log10(x: float) -> float:
%y = call double @llvm.log10.f64(double %x)
ret double %y
return f(x)
return _check1(x, f(x))
def degrees(x: float) -> float:
@@ -255,7 +305,7 @@ def sqrt(x: float) -> float:
%y = call double @llvm.sqrt.f64(double %x)
ret double %y
return f(x)
return _check1(x, f(x))
def pow(x: float, y: float) -> float:
@@ -271,7 +321,7 @@ def pow(x: float, y: float) -> float:
%z = call double @llvm.pow.f64(double %x, double %y)
ret double %z
return f(x, y)
return _check2(x, y, f(x, y), True)
def acos(x: float) -> float:
@@ -280,7 +330,7 @@ def acos(x: float) -> float:
Returns the arc cosine of x in radians.
"""
return _C.acos(x)
return _check1(x, _C.acos(x))
def asin(x: float) -> float:
@@ -289,7 +339,7 @@ def asin(x: float) -> float:
Returns the arc sine of x in radians.
"""
return _C.asin(x)
return _check1(x, _C.asin(x))
def atan(x: float) -> float:
@@ -298,7 +348,7 @@ def atan(x: float) -> float:
Returns the arc tangent of x in radians.
"""
return _C.atan(x)
return _check1(x, _C.atan(x))
def atan2(y: float, x: float) -> float:
@@ -309,7 +359,7 @@ def atan2(y: float, x: float) -> float:
on the signs of both values to determine the
correct quadrant.
"""
return _C.atan2(y, x)
return _check2(x, y, _C.atan2(y, x))
def cos(x: float) -> float:
@@ -325,7 +375,7 @@ def cos(x: float) -> float:
%y = call double @llvm.cos.f64(double %x)
ret double %y
return f(x)
return _check1(x, f(x))
def sin(x: float) -> float:
@@ -341,7 +391,7 @@ def sin(x: float) -> float:
%y = call double @llvm.sin.f64(double %x)
ret double %y
return f(x)
return _check1(x, f(x))
def hypot(x: float, y: float) -> float:
@@ -352,7 +402,7 @@ def hypot(x: float, y: float) -> float:
This is the length of the vector from the
origin to point (x, y).
"""
return _C.hypot(x, y)
return _check2(x, y, _C.hypot(x, y), True)
def tan(x: float) -> float:
@@ -361,7 +411,7 @@ def tan(x: float) -> float:
Return the tangent of a radian angle x.
"""
return _C.tan(x)
return _check1(x, _C.tan(x))
def cosh(x: float) -> float:
@@ -370,7 +420,7 @@ def cosh(x: float) -> float:
Returns the hyperbolic cosine of x.
"""
return _C.cosh(x)
return _check1(x, _C.cosh(x), True)
def sinh(x: float) -> float:
@@ -379,7 +429,7 @@ def sinh(x: float) -> float:
Returns the hyperbolic sine of x.
"""
return _C.sinh(x)
return _check1(x, _C.sinh(x), True)
def tanh(x: float) -> float:
@@ -388,7 +438,7 @@ def tanh(x: float) -> float:
Returns the hyperbolic tangent of x.
"""
return _C.tanh(x)
return _check1(x, _C.tanh(x))
def acosh(x: float) -> float:
@@ -397,7 +447,7 @@ def acosh(x: float) -> float:
Return the inverse hyperbolic cosine of x.
"""
return _C.acosh(x)
return _check1(x, _C.acosh(x))
def asinh(x: float) -> float:
@@ -406,7 +456,7 @@ def asinh(x: float) -> float:
Return the inverse hyperbolic sine of x.
"""
return _C.asinh(x)
return _check1(x, _C.asinh(x))
def atanh(x: float) -> float:
@@ -415,7 +465,7 @@ def atanh(x: float) -> float:
Return the inverse hyperbolic tangent of x.
"""
return _C.atanh(x)
return _check1(x, _C.atanh(x))
def copysign(x: float, y: float) -> float:
@@ -432,7 +482,7 @@ def copysign(x: float, y: float) -> float:
%z = call double @llvm.copysign.f64(double %x, double %y)
ret double %z
return f(x, y)
return _check2(x, y, f(x, y))
def log1p(x: float) -> float:
@@ -441,7 +491,7 @@ def log1p(x: float) -> float:
Return the natural logarithm of 1+x (base e).
"""
return _C.log1p(x)
return _check1(x, _C.log1p(x))
def trunc(x: float) -> float:
@@ -458,7 +508,7 @@ def trunc(x: float) -> float:
%y = call double @llvm.trunc.f64(double %x)
ret double %y
return f(x)
return _check1(x, f(x))
def erf(x: float) -> float:
@@ -467,7 +517,7 @@ def erf(x: float) -> float:
Return the error function at x.
"""
return _C.erf(x)
return _check1(x, _C.erf(x))
def erfc(x: float) -> float:
@@ -476,7 +526,7 @@ def erfc(x: float) -> float:
Return the complementary error function at x.
"""
return _C.erfc(x)
return _check1(x, _C.erfc(x))
def gamma(x: float) -> float:
@@ -485,7 +535,7 @@ def gamma(x: float) -> float:
Return the Gamma function at x.
"""
return _C.tgamma(x)
return _check1(x, _C.tgamma(x), True)
def lgamma(x: float) -> float:
@@ -495,7 +545,7 @@ def lgamma(x: float) -> float:
Return the natural logarithm of
the absolute value of the Gamma function at x.
"""
return _C.lgamma(x)
return _check1(x, _C.lgamma(x), True)
def remainder(x: float, y: float) -> float:
@@ -509,7 +559,7 @@ def remainder(x: float, y: float) -> float:
two consecutive integers, the nearest even integer is used
for n.
"""
return _C.remainder(x, y)
return _check2(x, y, _C.remainder(x, y))
def gcd(a: float, b: float) -> float:
@@ -585,3 +635,615 @@ def isclose(a: float, b: float, rel_tol: float = 1e-09, abs_tol: float = 0.0) ->
return ((diff <= fabs(rel_tol * b)) or (diff <= fabs(rel_tol * a))) or (
diff <= abs_tol
)
# 32-bit float ops
e32 = float32(e)
pi32 = float32(pi)
tau32 = float32(tau)
inf32 = float32(inf)
nan32 = float32(nan)
@overload
def isnan(x: float32) -> bool:
"""
isnan(float32) -> bool
Return True if float arg is a NaN, else False.
"""
@pure
@llvm
def f(x: float32) -> bool:
%y = fcmp uno float %x, 0.000000e+00
%z = zext i1 %y to i8
ret i8 %z
return f(x)
@overload
def isinf(x: float32) -> bool:
"""
isinf(float32) -> bool
Return True if float arg is an INF, else False.
"""
@pure
@llvm
def f(x: float32) -> bool:
declare float @llvm.fabs.f32(float)
%a = call float @llvm.fabs.f32(float %x)
%b = fcmp oeq float %a, 0x7FF0000000000000
%c = zext i1 %b to i8
ret i8 %c
return f(x)
@overload
def isfinite(x: float32) -> bool:
"""
isfinite(float32) -> bool
Return True if x is neither an infinity nor a NaN,
and False otherwise.
"""
return not (isnan(x) or isinf(x))
@overload
def ceil(x: float32) -> float32:
"""
ceil(float32) -> float32
Return the ceiling of x as an Integral.
This is the smallest integer >= x.
"""
@pure
@llvm
def f(x: float32) -> float32:
declare float @llvm.ceil.f32(float)
%y = call float @llvm.ceil.f32(float %x)
ret float %y
return f(x)
@overload
def floor(x: float32) -> float32:
"""
floor(float32) -> float32
Return the floor of x as an Integral.
This is the largest integer <= x.
"""
@pure
@llvm
def f(x: float32) -> float32:
declare float @llvm.floor.f32(float)
%y = call float @llvm.floor.f32(float %x)
ret float %y
return f(x)
@overload
def fabs(x: float32) -> float32:
"""
fabs(float32) -> float32
Returns the absolute value of a floating-point number.
"""
@pure
@llvm
def f(x: float32) -> float32:
declare float @llvm.fabs.f32(float)
%y = call float @llvm.fabs.f32(float %x)
ret float %y
return f(x)
@overload
def fmod(x: float32, y: float32) -> float32:
"""
fmod(float32, float32) -> float32
Returns the remainder of x divided by y.
"""
@pure
@llvm
def f(x: float32, y: float32) -> float32:
%z = frem float %x, %y
ret float %z
return f(x, y)
@overload
def exp(x: float32) -> float32:
"""
exp(float32) -> float32
Returns the value of e raised to the xth power.
"""
@pure
@llvm
def f(x: float32) -> float32:
declare float @llvm.exp.f32(float)
%y = call float @llvm.exp.f32(float %x)
ret float %y
return f(x)
@overload
def expm1(x: float32) -> float32:
"""
expm1(float32) -> float32
Return e raised to the power x, minus 1. expm1 provides
a way to compute this quantity to full precision.
"""
return _C.expm1f(x)
@overload
def ldexp(x: float32, i: int) -> float32:
"""
ldexp(float32, int) -> float32
Returns x multiplied by 2 raised to the power of exponent.
"""
return _C.ldexpf(x, i32(i))
@overload
def log(x: float32, base: float32 = e32) -> float32:
"""
log(float32) -> float32
Returns the natural logarithm (base-e logarithm) of x.
"""
@pure
@llvm
def f(x: float32) -> float32:
declare float @llvm.log.f32(float)
%y = call float @llvm.log.f32(float %x)
ret float %y
if base == e32:
return f(x)
else:
return f(x) / f(base)
@overload
def log2(x: float32) -> float32:
"""
log2(float32) -> float32
Return the base-2 logarithm of x.
"""
@pure
@llvm
def f(x: float32) -> float32:
declare float @llvm.log2.f32(float)
%y = call float @llvm.log2.f32(float %x)
ret float %y
return f(x)
@overload
def log10(x: float32) -> float32:
"""
log10(float32) -> float32
Returns the common logarithm (base-10 logarithm) of x.
"""
@pure
@llvm
def f(x: float32) -> float32:
declare float @llvm.log10.f32(float)
%y = call float @llvm.log10.f32(float %x)
ret float %y
return f(x)
@overload
def degrees(x: float32) -> float32:
"""
degrees(float32) -> float32
Convert angle x from radians to degrees.
"""
radToDeg = float32(180.0) / pi32
return x * radToDeg
@overload
def radians(x: float32) -> float32:
"""
radians(float32) -> float32
Convert angle x from degrees to radians.
"""
degToRad = pi32 / float32(180.0)
return x * degToRad
@overload
def sqrt(x: float32) -> float32:
"""
sqrt(float32) -> float32
Returns the square root of x.
"""
@pure
@llvm
def f(x: float32) -> float32:
declare float @llvm.sqrt.f32(float)
%y = call float @llvm.sqrt.f32(float %x)
ret float %y
return f(x)
@overload
def pow(x: float32, y: float32) -> float32:
"""
pow(float32, float32) -> float32
Returns x raised to the power of y.
"""
@pure
@llvm
def f(x: float32, y: float32) -> float32:
declare float @llvm.pow.f32(float, float)
%z = call float @llvm.pow.f32(float %x, float %y)
ret float %z
return f(x, y)
@overload
def acos(x: float32) -> float32:
"""
acos(float32) -> float32
Returns the arc cosine of x in radians.
"""
return _C.acosf(x)
@overload
def asin(x: float32) -> float32:
"""
asin(float32) -> float32
Returns the arc sine of x in radians.
"""
return _C.asinf(x)
@overload
def atan(x: float32) -> float32:
"""
atan(float32) -> float32
Returns the arc tangent of x in radians.
"""
return _C.atanf(x)
@overload
def atan2(y: float32, x: float32) -> float32:
"""
atan2(float32, float32) -> float32
Returns the arc tangent in radians of y/x based
on the signs of both values to determine the
correct quadrant.
"""
return _C.atan2f(y, x)
@overload
def cos(x: float32) -> float32:
"""
cos(float32) -> float32
Returns the cosine of a radian angle x.
"""
@pure
@llvm
def f(x: float32) -> float32:
declare float @llvm.cos.f32(float)
%y = call float @llvm.cos.f32(float %x)
ret float %y
return f(x)
@overload
def sin(x: float32) -> float32:
"""
sin(float32) -> float32
Returns the sine of a radian angle x.
"""
@pure
@llvm
def f(x: float32) -> float32:
declare float @llvm.sin.f32(float)
%y = call float @llvm.sin.f32(float %x)
ret float %y
return f(x)
@overload
def hypot(x: float32, y: float32) -> float32:
"""
hypot(float32, float32) -> float32
Return the Euclidean norm.
This is the length of the vector from the
origin to point (x, y).
"""
return _C.hypotf(x, y)
@overload
def tan(x: float32) -> float32:
"""
tan(float32) -> float32
Return the tangent of a radian angle x.
"""
return _C.tanf(x)
@overload
def cosh(x: float32) -> float32:
"""
cosh(float32) -> float32
Returns the hyperbolic cosine of x.
"""
return _C.coshf(x)
@overload
def sinh(x: float32) -> float32:
"""
sinh(float32) -> float32
Returns the hyperbolic sine of x.
"""
return _C.sinhf(x)
@overload
def tanh(x: float32) -> float32:
"""
tanh(float32) -> float32
Returns the hyperbolic tangent of x.
"""
return _C.tanhf(x)
@overload
def acosh(x: float32) -> float32:
"""
acosh(float32) -> float32
Return the inverse hyperbolic cosine of x.
"""
return _C.acoshf(x)
@overload
def asinh(x: float32) -> float32:
"""
asinh(float32) -> float32
Return the inverse hyperbolic sine of x.
"""
return _C.asinhf(x)
@overload
def atanh(x: float32) -> float32:
"""
atanh(float32) -> float32
Return the inverse hyperbolic tangent of x.
"""
return _C.atanhf(x)
@overload
def copysign(x: float32, y: float32) -> float32:
"""
copysign(float32, float32) -> float32
Return a float32 with the magnitude (absolute value) of
x but the sign of y.
"""
@pure
@llvm
def f(x: float32, y: float32) -> float32:
declare float @llvm.copysign.f32(float, float)
%z = call float @llvm.copysign.f32(float %x, float %y)
ret float %z
return f(x, y)
@overload
def log1p(x: float32) -> float32:
"""
log1p(float32) -> float32
Return the natural logarithm of 1+x (base e).
"""
return _C.log1pf(x)
@overload
def trunc(x: float32) -> float32:
"""
trunc(float32) -> float32
Return the Real value x truncated to an Integral
(usually an integer).
"""
@pure
@llvm
def f(x: float32) -> float32:
declare float @llvm.trunc.f32(float)
%y = call float @llvm.trunc.f32(float %x)
ret float %y
return f(x)
@overload
def erf(x: float32) -> float32:
"""
erf(float32) -> float32
Return the error function at x.
"""
return _C.erff(x)
@overload
def erfc(x: float32) -> float32:
"""
erfc(float32) -> float32
Return the complementary error function at x.
"""
return _C.erfcf(x)
@overload
def gamma(x: float32) -> float32:
"""
gamma(float32) -> float32
Return the Gamma function at x.
"""
return _C.tgammaf(x)
@overload
def lgamma(x: float32) -> float32:
"""
lgamma(float32) -> float32
Return the natural logarithm of
the absolute value of the Gamma function at x.
"""
return _C.lgammaf(x)
@overload
def remainder(x: float32, y: float32) -> float32:
"""
remainder(float32, float32) -> float32
Return the IEEE 754-style remainder of x with respect to y.
For finite x and finite nonzero y, this is the difference
x - n*y, where n is the closest integer to the exact value
of the quotient x / y. If x / y is exactly halfway between
two consecutive integers, the nearest even integer is used
for n.
"""
return _C.remainderf(x, y)
@overload
def gcd(a: float32, b: float32) -> float32:
"""
gcd(float32, float32) -> float32
returns greatest common divisor of x and y.
"""
a = abs(a)
b = abs(b)
while a:
a, b = b % a, a
return b
@overload
@pure
def frexp(x: float32) -> Tuple[float32, int]:
"""
frexp(float32) -> Tuple[float32, int]
The returned value is the mantissa and the integer pointed
to by exponent is the exponent. The resultant value is
x = mantissa * 2 ^ exponent.
"""
tmp = i32(0)
res = _C.frexpf(x, __ptr__(tmp))
return (res, int(tmp))
@overload
@pure
def modf(x: float32) -> Tuple[float32, float32]:
"""
modf(float32) -> Tuple[float32, float32]
The returned value is the fraction component (part after
the decimal), and sets integer to the integer component.
"""
tmp = float32(0.0)
res = _C.modff(x, __ptr__(tmp))
return (res, tmp)
@overload
def isclose(a: float32, b: float32, rel_tol: float32 = float32(1e-09), abs_tol: float32 = float32(0.0)) -> bool:
"""
isclose(float32, float32) -> bool
Return True if a is close in value to b, and False otherwise.
For the values to be considered close, the difference between them
must be smaller than at least one of the tolerances.
"""
# short circuit exact equality -- needed to catch two
# infinities of the same sign. And perhaps speeds things
# up a bit sometimes.
if a == b:
return True
# This catches the case of two infinities of opposite sign, or
# one infinity and one finite number. Two infinities of opposite
# sign would otherwise have an infinite relative tolerance.
# Two infinities of the same sign are caught by the equality check
# above.
if isinf(a) or isinf(b):
return False
# NaN is not close to anything, not even itself
if isnan(a) or isnan(b):
return False
# regular computation
diff = fabs(b - a)
return ((diff <= fabs(rel_tol * b)) or (diff <= fabs(rel_tol * a))) or (
diff <= abs_tol
)
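As in the double-precision version, this follows PEP 485: a and b are close when |a - b| is within rel_tol of either magnitude or within abs_tol absolutely. Note that the default rel_tol of 1e-09 is far below float32's roughly 1e-7 epsilon, so callers comparing computed values will usually want a looser tolerance, e.g.:

assert isclose(float32(1.0), float32(1.0))
assert not isclose(float32(1.0), float32(1.001))     # outside both tolerances
assert isclose(float32(1.0), float32(1.0000001), rel_tol=float32(1e-6))
assert isclose(float32(0.0), float32(1e-12), abs_tol=float32(1e-9))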

View File

@@ -643,6 +643,7 @@ def _task_loop_outline_template(gtid_ptr: Ptr[i32], btid_ptr: Ptr[i32], args):
_loop_reductions(shared)
_barrier(loc_ref, gtid)
@pure
def get_num_threads():
from C import omp_get_num_threads() -> i32
@@ -952,10 +953,41 @@ def _atomic_float_max(a: Ptr[float], b: float):
__kmpc_atomic_float8_max(_default_loc(), i32(0), a, b)
def _atomic_float32_add(a: Ptr[float32], b: float32) -> None:
from C import __kmpc_atomic_float4_add(Ptr[Ident], i32, Ptr[float32], float32)
__kmpc_atomic_float4_add(_default_loc(), i32(0), a, b)
def _atomic_float32_mul(a: Ptr[float32], b: float32) -> None:
from C import __kmpc_atomic_float4_mul(Ptr[Ident], i32, Ptr[float32], float32)
__kmpc_atomic_float4_mul(_default_loc(), i32(0), a, b)
def _atomic_float32_min(a: Ptr[float32], b: float32) -> None:
from C import __kmpc_atomic_float4_min(Ptr[Ident], i32, Ptr[float32], float32)
__kmpc_atomic_float4_min(_default_loc(), i32(0), a, b)
def _atomic_float32_max(a: Ptr[float32], b: float32) -> None:
from C import __kmpc_atomic_float4_max(Ptr[Ident], i32, Ptr[float32], float32)
__kmpc_atomic_float4_max(_default_loc(), i32(0), a, b)
def _range_len(start: int, stop: int, step: int):
if step > 0 and start < stop:
return 1 + (stop - 1 - start) // step
elif step < 0 and start > stop:
return 1 + (start - 1 - stop) // (-step)
else:
return 0
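_range_len computes len(range(start, stop, step)) in closed form, including empty and negative-step cases; the OpenMP scheduler needs this trip count up front to divide iterations among threads (and, with collapse, to flatten nested loops into one iteration space). For example:

assert _range_len(0, 10, 3) == 4     # 0, 3, 6, 9
assert _range_len(10, 0, -3) == 4    # 10, 7, 4, 1
assert _range_len(0, 10, -1) == 0    # direction mismatch: empty
assert _range_len(5, 5, 1) == 0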
def for_par(
num_threads: int = -1,
chunk_size: int = -1,
schedule: Static[str] = "static",
ordered: Static[int] = False,
collapse: Static[int] = 0,
gpu: Static[int] = False,
):
pass
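for_par is deliberately a stub: it only pins down the parameter names and Static types that the OpenMP lowering pass recognizes; the pass rewrites @par loops directly rather than calling this function. A hedged sketch of the corresponding surface syntax with the new collapse option (n, m, a, and work are placeholders):

@par(schedule="dynamic", chunk_size=64, num_threads=8)
for i in range(n):
    work(i)

@par(collapse=2)    # fuse the two-deep nest into a single iteration space
for i in range(n):
    for j in range(m):
        a[i * m + j] += 1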

View File

@@ -77,6 +77,15 @@ class float:
return _read(jar, float)
@extend
class float32:
def __pickle__(self, jar: Jar):
_write(jar, self)
def __unpickle__(jar: Jar) -> float32:
return _read(jar, float32)
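float32 picks up the same raw-byte (de)serialization as the other scalar types via the generic _write/_read helpers, so a round trip preserves the exact 4-byte payload. Illustrative use, assuming jar is an already-open Jar as elsewhere in this module:

x = float32(3.25)
x.__pickle__(jar)                  # writes the raw 4 bytes
y = float32.__unpickle__(jar)      # reads them back
assert x == y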
@extend
class bool:
def __pickle__(self, jar: Jar):

Some files were not shown because too many files have changed in this diff.