New Tuple machinery (#462)

* Refactor Tuple class

* Add Tuple[N,...] support; Fix inline operators with list brackets

* Fix Tuple[N,...] support

* Merge Sequre SIMD changes

* Fix repeat-tuple unification

* Fix staticlen in OpenMP

* Fix isinstance unification

* Fix delayed unification with static realization

* Fix CI

* Cleanup

* Use "fcmp une" in float.__ne__()

---------

Co-authored-by: A. R. Shajii <ars@ars.me>
pull/484/head
Ibrahim Numanagić 2023-09-26 07:49:14 -07:00 committed by GitHub
parent 9933954e30
commit ce459c5667
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 1218 additions and 254 deletions

View File

@ -1745,7 +1745,7 @@ void LLVMVisitor::visit(const InternalFunc *x) {
}
else if (internalFuncMatchesIgnoreArgs<RecordType>("__new__", x)) {
auto *recordType = cast<RecordType>(parentType);
auto *recordType = cast<RecordType>(cast<FuncType>(x->getType())->getReturnType());
seqassertn(args.size() == std::distance(recordType->begin(), recordType->end()),
"args size does not match");
result = llvm::UndefValue::get(getLLVMType(recordType));

View File

@ -672,8 +672,6 @@ void ClassStmt::parseDecorators() {
E(Error::CLASS_BAD_DECORATOR, d);
}
}
if (startswith(name, TYPE_TUPLE))
tupleMagics["contains"] = true;
if (attributes.has("deduce"))
tupleMagics["new"] = false;
if (!attributes.has(Attr::Tuple)) {
@ -681,12 +679,7 @@ void ClassStmt::parseDecorators() {
tupleMagics["new"] = tupleMagics["raw"] = true;
tupleMagics["len"] = false;
}
if (startswith(name, TYPE_TUPLE)) {
tupleMagics["add"] = true;
tupleMagics["mul"] = true;
} else {
tupleMagics["dict"] = true;
}
tupleMagics["dict"] = true;
// Internal classes do not get any auto-generated members.
attributes.magics.clear();
if (!attributes.has(Attr::Internal)) {

View File

@ -104,8 +104,6 @@ std::string ClassType::debugString(char mode) const {
}
// Special formatting for Functions and Tuples
auto n = mode == 0 ? niceName : name;
if (startswith(n, TYPE_TUPLE))
n = "Tuple";
return fmt::format("{}{}", n, gs.empty() ? "" : fmt::format("[{}]", join(gs, ",")));
}
@ -130,13 +128,13 @@ std::string ClassType::realizedTypeName() const {
RecordType::RecordType(Cache *cache, std::string name, std::string niceName,
std::vector<Generic> generics, std::vector<TypePtr> args,
bool noTuple)
bool noTuple, const std::shared_ptr<StaticType> &repeats)
: ClassType(cache, std::move(name), std::move(niceName), std::move(generics)),
args(std::move(args)), noTuple(false) {}
args(std::move(args)), noTuple(false), repeats(repeats) {}
RecordType::RecordType(const ClassTypePtr &base, std::vector<TypePtr> args,
bool noTuple)
: ClassType(base), args(std::move(args)), noTuple(noTuple) {}
bool noTuple, const std::shared_ptr<StaticType> &repeats)
: ClassType(base), args(std::move(args)), noTuple(noTuple), repeats(repeats) {}
int RecordType::unify(Type *typ, Unification *us) {
if (auto tr = typ->getRecord()) {
@ -148,6 +146,27 @@ int RecordType::unify(Type *typ, Unification *us) {
return generics[0].type->unify(t64.get(), us);
}
// TODO: we now support very limited unification strategy where repetitions must
// match. We should expand this later on...
if (repeats || tr->repeats) {
if (!repeats && tr->repeats) {
auto n = std::make_shared<StaticType>(cache, args.size());
if (tr->repeats->unify(n.get(), us) == -1)
return -1;
} else if (!tr->repeats) {
auto n = std::make_shared<StaticType>(cache, tr->args.size());
if (repeats->unify(n.get(), us) == -1)
return -1;
} else {
if (repeats->unify(tr->repeats.get(), us) == -1)
return -1;
}
}
if (getRepeats() != -1)
flatten();
if (tr->getRepeats() != -1)
tr->flatten();
int s1 = 2, s = 0;
if (args.size() != tr->args.size())
return -1;
@ -158,7 +177,7 @@ int RecordType::unify(Type *typ, Unification *us) {
return -1;
}
// Handle Tuple<->@tuple: when unifying tuples, only record members matter.
if (startswith(name, TYPE_TUPLE) || startswith(tr->name, TYPE_TUPLE)) {
if (name == TYPE_TUPLE || tr->name == TYPE_TUPLE) {
if (!args.empty() || (!noTuple && !tr->noTuple)) // prevent POD<->() unification
return s1 + int(name == tr->name);
else
@ -177,7 +196,8 @@ TypePtr RecordType::generalize(int atLevel) {
auto a = args;
for (auto &t : a)
t = t->generalize(atLevel);
return std::make_shared<RecordType>(c, a, noTuple);
auto r = repeats ? repeats->generalize(atLevel)->getStatic() : nullptr;
return std::make_shared<RecordType>(c, a, noTuple, r);
}
TypePtr RecordType::instantiate(int atLevel, int *unboundCount,
@ -187,11 +207,17 @@ TypePtr RecordType::instantiate(int atLevel, int *unboundCount,
auto a = args;
for (auto &t : a)
t = t->instantiate(atLevel, unboundCount, cache);
return std::make_shared<RecordType>(c, a, noTuple);
auto r = repeats ? repeats->instantiate(atLevel, unboundCount, cache)->getStatic()
: nullptr;
return std::make_shared<RecordType>(c, a, noTuple, r);
}
std::vector<TypePtr> RecordType::getUnbounds() const {
std::vector<TypePtr> u;
if (repeats) {
auto tu = repeats->getUnbounds();
u.insert(u.begin(), tu.begin(), tu.end());
}
for (auto &a : args) {
auto tu = a->getUnbounds();
u.insert(u.begin(), tu.begin(), tu.end());
@ -202,21 +228,58 @@ std::vector<TypePtr> RecordType::getUnbounds() const {
}
bool RecordType::canRealize() const {
return std::all_of(args.begin(), args.end(),
return getRepeats() >= 0 &&
std::all_of(args.begin(), args.end(),
[](auto &a) { return a->canRealize(); }) &&
this->ClassType::canRealize();
}
bool RecordType::isInstantiated() const {
return std::all_of(args.begin(), args.end(),
return (!repeats || repeats->isInstantiated()) &&
std::all_of(args.begin(), args.end(),
[](auto &a) { return a->isInstantiated(); }) &&
this->ClassType::isInstantiated();
}
std::string RecordType::debugString(char mode) const {
return fmt::format("{}", this->ClassType::debugString(mode));
std::string RecordType::realizedName() const {
if (!_rn.empty())
return _rn;
if (name == TYPE_TUPLE) {
std::vector<std::string> gs;
auto n = getRepeats();
if (n == -1)
gs.push_back(repeats->realizedName());
for (int i = 0; i < std::max(n, int64_t(0)); i++)
for (auto &a : args)
gs.push_back(a->realizedName());
std::string s = join(gs, ",");
if (canRealize())
const_cast<RecordType *>(this)->_rn =
fmt::format("{}{}", name, s.empty() ? "" : fmt::format("[{}]", s));
return _rn;
}
return ClassType::realizedName();
}
std::string RecordType::debugString(char mode) const {
if (name == TYPE_TUPLE) {
std::vector<std::string> gs;
auto n = getRepeats();
if (n == -1)
gs.push_back(repeats->debugString(mode));
for (int i = 0; i < std::max(n, int64_t(0)); i++)
for (auto &a : args)
gs.push_back(a->debugString(mode));
return fmt::format("{}{}", name,
gs.empty() ? "" : fmt::format("[{}]", join(gs, ",")));
} else {
return fmt::format("{}{}", repeats ? repeats->debugString(mode) + "," : "",
this->ClassType::debugString(mode));
}
}
std::string RecordType::realizedTypeName() const { return realizedName(); }
std::shared_ptr<RecordType> RecordType::getHeterogenousTuple() {
seqassert(canRealize(), "{} not realizable", toString());
if (args.size() > 1) {
@ -228,4 +291,25 @@ std::shared_ptr<RecordType> RecordType::getHeterogenousTuple() {
return nullptr;
}
/// Returns -1 if the type cannot be realized yet
int64_t RecordType::getRepeats() const {
if (!repeats)
return 1;
if (repeats->canRealize())
return std::max(repeats->evaluate().getInt(), int64_t(0));
return -1;
}
void RecordType::flatten() {
auto n = getRepeats();
seqassert(n >= 0, "bad call to flatten");
auto a = args;
args.clear();
for (int64_t i = 0; i < n; i++)
args.insert(args.end(), a.begin(), a.end());
repeats = nullptr;
}
} // namespace codon::ast::types

View File

@ -78,12 +78,15 @@ struct RecordType : public ClassType {
/// List of tuple arguments.
std::vector<TypePtr> args;
bool noTuple;
std::shared_ptr<StaticType> repeats = nullptr;
explicit RecordType(
Cache *cache, std::string name, std::string niceName,
std::vector<ClassType::Generic> generics = std::vector<ClassType::Generic>(),
std::vector<TypePtr> args = std::vector<TypePtr>(), bool noTuple = false);
RecordType(const ClassTypePtr &base, std::vector<TypePtr> args, bool noTuple = false);
std::vector<TypePtr> args = std::vector<TypePtr>(), bool noTuple = false,
const std::shared_ptr<StaticType> &repeats = nullptr);
RecordType(const ClassTypePtr &base, std::vector<TypePtr> args, bool noTuple = false,
const std::shared_ptr<StaticType> &repeats = nullptr);
public:
int unify(Type *typ, Unification *undo) override;
@ -96,11 +99,16 @@ public:
bool canRealize() const override;
bool isInstantiated() const override;
std::string debugString(char mode) const override;
std::string realizedName() const override;
std::string realizedTypeName() const override;
std::shared_ptr<RecordType> getRecord() override {
return std::static_pointer_cast<RecordType>(shared_from_this());
}
std::shared_ptr<RecordType> getHeterogenousTuple() override;
int64_t getRepeats() const;
void flatten();
};
} // namespace codon::ast::types

View File

@ -99,6 +99,10 @@ bool FuncType::canRealize() const {
return generics;
}
std::string FuncType::realizedTypeName() const {
return this->ClassType::realizedName();
}
bool FuncType::isInstantiated() const {
TypePtr removed = nullptr;
auto retType = getRetType();

View File

@ -48,6 +48,7 @@ public:
bool isInstantiated() const override;
std::string debugString(char mode) const override;
std::string realizedName() const override;
std::string realizedTypeName() const override;
std::shared_ptr<FuncType> getFunc() override {
return std::static_pointer_cast<FuncType>(shared_from_this());

View File

@ -140,7 +140,9 @@ StaticValue StaticType::evaluate() const {
cache->typeCtx->addBlock();
for (auto &g : generics)
cache->typeCtx->add(TypecheckItem::Type, g.name, g.type);
auto oldChangedNodes = cache->typeCtx->changedNodes;
auto en = TypecheckVisitor(cache->typeCtx).transform(expr->clone());
cache->typeCtx->changedNodes = oldChangedNodes;
seqassert(en->isStatic() && en->staticValue.evaluated, "{} cannot be evaluated", en);
cache->typeCtx->popBlock();
return en->staticValue;

View File

@ -94,25 +94,23 @@ int CallableTrait::unify(Type *typ, Unification *us) {
starArgTypes.insert(starArgTypes.end(), inArgs.begin() + i, inArgs.end());
auto tv = TypecheckVisitor(cache->typeCtx);
auto name = tv.generateTuple(starArgTypes.size());
auto t = cache->typeCtx->forceFind(name)->type;
t = cache->typeCtx->instantiateGeneric(t, starArgTypes)->getClass();
auto t = cache->typeCtx->instantiateTuple(starArgTypes)->getClass();
if (t->unify(trInArgs[star].get(), us) == -1)
return -1;
}
if (kwStar < trInArgs.size()) {
auto tv = TypecheckVisitor(cache->typeCtx);
std::vector<std::string> names;
std::vector<TypePtr> starArgTypes;
if (auto tp = tr->getPartial()) {
auto ts = tp->args.back()->getRecord();
seqassert(ts, "bad partial *args/**kwargs");
auto &ff = cache->classes[ts->name].fields;
auto ff = tv.getClassFields(ts.get());
for (size_t i = 0; i < ts->args.size(); i++) {
names.emplace_back(ff[i].name);
starArgTypes.emplace_back(ts->args[i]);
}
}
auto tv = TypecheckVisitor(cache->typeCtx);
auto name = tv.generateTuple(starArgTypes.size(), TYPE_KWTUPLE, names);
auto t = cache->typeCtx->forceFind(name)->type;
t = cache->typeCtx->instantiateGeneric(t, starArgTypes)->getClass();

View File

@ -46,4 +46,16 @@ public:
std::string debugString(char mode) const override;
};
struct VariableTupleTrait : public Trait {
TypePtr size;
public:
explicit VariableTupleTrait(TypePtr size);
int unify(Type *typ, Unification *undo) override;
TypePtr generalize(int atLevel) override;
TypePtr instantiate(int atLevel, int *unboundCount,
std::unordered_map<int, TypePtr> *cache) override;
std::string debugString(char mode) const override;
};
} // namespace codon::ast::types

View File

@ -150,9 +150,7 @@ void UnionType::seal() {
pendingTypes[i]->getLink()->kind == LinkType::Unbound)
break;
std::vector<TypePtr> typeSet(pendingTypes.begin(), pendingTypes.begin() + i);
auto name = tv.generateTuple(typeSet.size());
auto t = cache->typeCtx->instantiateGeneric(
cache->typeCtx->forceFind(name)->type->getClass(), typeSet);
auto t = cache->typeCtx->instantiateTuple(typeSet);
Unification us;
generics[0].type->unify(t.get(), &us);
}

View File

@ -152,20 +152,16 @@ ir::Func *Cache::realizeFunction(types::FuncTypePtr type,
ir::types::Type *Cache::makeTuple(const std::vector<types::TypePtr> &types) {
auto tv = TypecheckVisitor(typeCtx);
auto name = tv.generateTuple(types.size());
auto t = typeCtx->find(name);
seqassertn(t && t->type, "cannot find {}", name);
return realizeType(t->type->getClass(), types);
auto t = typeCtx->instantiateTuple(types);
return realizeType(t, types);
}
ir::types::Type *Cache::makeFunction(const std::vector<types::TypePtr> &types) {
auto tv = TypecheckVisitor(typeCtx);
seqassertn(!types.empty(), "types must have at least one argument");
auto tup = tv.generateTuple(types.size() - 1);
const auto &ret = types[0];
auto argType = typeCtx->instantiateGeneric(
typeCtx->find(tup)->type,
auto argType = typeCtx->instantiateTuple(
std::vector<types::TypePtr>(types.begin() + 1, types.end()));
auto t = typeCtx->find("Function");
seqassertn(t && t->type, "cannot find 'Function'");
@ -175,8 +171,7 @@ ir::types::Type *Cache::makeFunction(const std::vector<types::TypePtr> &types) {
ir::types::Type *Cache::makeUnion(const std::vector<types::TypePtr> &types) {
auto tv = TypecheckVisitor(typeCtx);
auto tup = tv.generateTuple(types.size());
auto argType = typeCtx->instantiateGeneric(typeCtx->find(tup)->type, types);
auto argType = typeCtx->instantiateTuple(types);
auto t = typeCtx->find("Union");
seqassertn(t && t->type, "cannot find 'Union'");
return realizeType(t->type->getClass(), {argType});

View File

@ -20,7 +20,7 @@
#define STDLIB_IMPORT ":stdlib:"
#define STDLIB_INTERNAL_MODULE "internal"
#define TYPE_TUPLE "Tuple.N"
#define TYPE_TUPLE "Tuple"
#define TYPE_KWTUPLE "KwTuple.N"
#define TYPE_TYPEVAR "TypeVar"
#define TYPE_CALLABLE "Callable"

View File

@ -14,10 +14,6 @@ using namespace codon::error;
namespace codon::ast {
void SimplifyVisitor::visit(IdExpr *expr) {
if (startswith(expr->value, TYPE_TUPLE)) {
expr->markType();
return;
}
auto val = ctx->findDominatingBinding(expr->value);
if (!val && ctx->getBase()->pyCaptures) {

View File

@ -95,6 +95,18 @@ StmtPtr SimplifyVisitor::transformAssignment(ExprPtr lhs, ExprPtr rhs, ExprPtr t
if (auto idx = lhs->getIndex()) {
// Case: a[x] = b
seqassert(!type, "unexpected type annotation");
if (auto b = rhs->getBinary()) {
if (mustExist && b->inPlace && !b->rexpr->getId()) {
auto var = ctx->cache->getTemporaryVar("assign");
seqassert(rhs->getBinary(), "not a bin");
return transform(N<SuiteStmt>(
N<AssignStmt>(N<IdExpr>(var), idx->index),
N<ExprStmt>(N<CallExpr>(
N<DotExpr>(idx->expr, "__setitem__"), N<IdExpr>(var),
N<BinaryExpr>(N<IndexExpr>(idx->expr->clone(), N<IdExpr>(var)), b->op,
b->rexpr, true)))));
}
}
return transform(N<ExprStmt>(
N<CallExpr>(N<DotExpr>(idx->expr, "__setitem__"), idx->index, rhs)));
}

View File

@ -528,7 +528,23 @@ StmtPtr SimplifyVisitor::codegenMagic(const std::string &op, const ExprPtr &typE
// Classes: def __new__() -> T
stmts.emplace_back(N<ReturnStmt>(N<CallExpr>(NS(op), typExpr->clone())));
}
} else if (op == "init") {
}
// else if (startswith(op, "new.")) {
// // special handle for tuple[t1, t2, ...]
// int sz = atoi(op.substr(4).c_str());
// std::vector<ExprPtr> ts;
// for (int i = 0; i < sz; i++) {
// fargs.emplace_back(format("a{}", i + 1), I(format("T{}", i + 1)));
// ts.emplace_back(I(format("T{}", i + 1)));
// }
// for (int i = 0; i < sz; i++) {
// fargs.emplace_back(format("T{}", i + 1), I("type"));
// }
// ret = N<InstantiateExpr>(I(TYPE_TUPLE), ts);
// ret->markType();
// attr.set(Attr::Internal);
// }
else if (op == "init") {
// Classes: def __init__(self: T, a1: T1, ..., aN: TN) -> None:
// self.aI = aI ...
ret = I("NoneType");

View File

@ -54,13 +54,12 @@ void SimplifyVisitor::visit(ChainBinaryExpr *expr) {
}
/// Transform index into an instantiation @c InstantiateExpr if possible.
/// Generate tuple class `Tuple.N` for `Tuple[T1, ... TN]` (and `tuple[...]`).
/// Generate tuple class `Tuple` for `Tuple[T1, ... TN]` (and `tuple[...]`).
/// The rest is handled during the type checking.
void SimplifyVisitor::visit(IndexExpr *expr) {
if (expr->expr->isId("tuple") || expr->expr->isId("Tuple")) {
// Special case: tuples. Change to Tuple.N
if (expr->expr->isId("tuple") || expr->expr->isId(TYPE_TUPLE)) {
auto t = expr->index->getTuple();
expr->expr = NT<IdExpr>(format(TYPE_TUPLE "{}", t ? t->items.size() : 1));
expr->expr = NT<IdExpr>(TYPE_TUPLE);
} else if (expr->expr->isId("Static")) {
// Special case: static types. Ensure that static is supported
if (!expr->index->isId("int") && !expr->index->isId("str"))
@ -84,7 +83,7 @@ void SimplifyVisitor::visit(IndexExpr *expr) {
if (i->getList() && expr->expr->isType()) {
// Special case: `A[[A, B], C]` -> `A[Tuple[A, B], C]` (e.g., in
// `Function[...]`)
i = N<IndexExpr>(N<IdExpr>("Tuple"), N<TupleExpr>(i->getList()->items));
i = N<IndexExpr>(N<IdExpr>(TYPE_TUPLE), N<TupleExpr>(i->getList()->items));
}
transform(i, true);
}

View File

@ -20,7 +20,7 @@ namespace codon::ast {
* - All imports are flattened resulting in a single self-containing
* (and fairly large) AST
* - All identifiers are normalized (no two distinct objects share the same name)
* - Variadic classes (e.g., Tuple.N) are generated
* - Variadic classes (e.g., Tuple) are generated
* - Any AST node that can be trivially expressed as a set of "simpler" nodes
* type is simplified. If a transformation requires a type information,
* it is done during the type checking.

View File

@ -55,7 +55,12 @@ ir::Func *TranslateVisitor::apply(Cache *cache, const StmtPtr &stmts) {
cache->codegenCtx->add(TranslateItem::Var, g.first, g.second);
}
TranslateVisitor(cache->codegenCtx).transform(stmts);
auto tv = TranslateVisitor(cache->codegenCtx);
tv.transform(stmts);
for (auto &[fn, f] : cache->functions)
if (startswith(fn, TYPE_TUPLE)) {
tv.transformFunctionRealizations(fn, f.ast->attributes.has(Attr::LLVM));
}
cache->populatePythonModule();
return main;
}
@ -223,9 +228,10 @@ void TranslateVisitor::visit(CallExpr *expr) {
seqassert(!expr->args[i].value->getEllipsis(), "ellipsis not elided");
if (i + 1 == expr->args.size() && isVariadic) {
auto call = expr->args[i].value->getCall();
seqassert(call && call->expr->getId() &&
startswith(call->expr->getId()->value, TYPE_TUPLE),
"expected *args tuple");
seqassert(
call && call->expr->getId() &&
startswith(call->expr->getId()->value, std::string(TYPE_TUPLE) + "["),
"expected *args tuple: '{}'", call->toString());
for (auto &arg : call->args)
items.emplace_back(transform(arg.value));
} else {
@ -524,20 +530,7 @@ void TranslateVisitor::visit(ThrowStmt *stmt) {
void TranslateVisitor::visit(FunctionStmt *stmt) {
// Process all realizations.
for (auto &real : ctx->cache->functions[stmt->name].realizations) {
if (!in(ctx->cache->pendingRealizations, make_pair(stmt->name, real.first)))
continue;
ctx->cache->pendingRealizations.erase(make_pair(stmt->name, real.first));
LOG_TYPECHECK("[translate] generating fn {}", real.first);
real.second->ir->setSrcInfo(getSrcInfo());
const auto &ast = real.second->ast;
seqassert(ast, "AST not set for {}", real.first);
if (!stmt->attributes.has(Attr::LLVM))
transformFunction(real.second->type.get(), ast.get(), real.second->ir);
else
transformLLVMFunction(real.second->type.get(), ast.get(), real.second->ir);
}
transformFunctionRealizations(stmt->name, stmt->attributes.has(Attr::LLVM));
}
void TranslateVisitor::visit(ClassStmt *stmt) {
@ -555,6 +548,24 @@ codon::ir::types::Type *TranslateVisitor::getType(const types::TypePtr &t) {
return i->getType();
}
void TranslateVisitor::transformFunctionRealizations(const std::string &name,
bool isLLVM) {
for (auto &real : ctx->cache->functions[name].realizations) {
if (!in(ctx->cache->pendingRealizations, make_pair(name, real.first)))
continue;
ctx->cache->pendingRealizations.erase(make_pair(name, real.first));
LOG_TYPECHECK("[translate] generating fn {}", real.first);
real.second->ir->setSrcInfo(getSrcInfo());
const auto &ast = real.second->ast;
seqassert(ast, "AST not set for {}", real.first);
if (!isLLVM)
transformFunction(real.second->type.get(), ast.get(), real.second->ir);
else
transformLLVMFunction(real.second->type.get(), ast.get(), real.second->ir);
}
}
void TranslateVisitor::transformFunction(types::FuncType *type, FunctionStmt *ast,
ir::Func *func) {
std::vector<std::string> names;

View File

@ -66,6 +66,7 @@ public:
private:
ir::types::Type *getType(const types::TypePtr &t);
void transformFunctionRealizations(const std::string &name, bool isLLVM);
void transformFunction(types::FuncType *type, FunctionStmt *ast, ir::Func *func);
void transformLLVMFunction(types::FuncType *type, FunctionStmt *ast, ir::Func *func);

View File

@ -20,13 +20,7 @@ using namespace types;
/// replace it with its value (e.g., a @c IntExpr ). Also ensure that the identifier of
/// a generic function or a type is fully qualified (e.g., replace `Ptr` with
/// `Ptr[byte]`).
/// For tuple identifiers, generate appropriate class. See @c generateTuple for
/// details.
void TypecheckVisitor::visit(IdExpr *expr) {
// Generate tuple stubs if needed
if (isTuple(expr->value))
generateTuple(std::stoi(expr->value.substr(sizeof(TYPE_TUPLE) - 1)));
// Replace identifiers that have been superseded by domination analysis during the
// simplification
while (auto s = in(ctx->cache->replacements, expr->value))
@ -187,7 +181,7 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
return nullptr;
// Check if this is a method or member access
if (ctx->findMethod(typ->name, expr->member).empty())
if (ctx->findMethod(typ.get(), expr->member).empty())
return getClassMember(expr, args);
auto bestMethod = getBestOverload(expr, args);
@ -210,10 +204,9 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
std::vector<ExprPtr> ids;
for (auto &t : fn->getArgTypes())
ids.push_back(NT<IdExpr>(t->realizedName()));
auto name = generateTuple(ids.size());
auto fnType = NT<InstantiateExpr>(
NT<IdExpr>("Function"),
std::vector<ExprPtr>{NT<InstantiateExpr>(NT<IdExpr>(name), ids),
std::vector<ExprPtr>{NT<InstantiateExpr>(NT<IdExpr>(TYPE_TUPLE), ids),
NT<IdExpr>(fn->getRetType()->realizedName())});
// Function[Tuple[TArg1, TArg2, ...],TRet](
// __internal__.class_get_rtti_vtable(expr)[T[VIRTUAL_ID]]
@ -264,7 +257,7 @@ ExprPtr TypecheckVisitor::getClassMember(DotExpr *expr,
// Case: object member access (`obj.member`)
if (!expr->expr->isType()) {
if (auto member = ctx->findMember(typ->name, expr->member)) {
if (auto member = ctx->findMember(typ, expr->member)) {
unify(expr->type, ctx->instantiate(member, typ));
if (expr->expr->isDone() && realize(expr->type))
expr->setDone();
@ -340,7 +333,8 @@ ExprPtr TypecheckVisitor::getClassMember(DotExpr *expr,
{"", N<EllipsisExpr>(EllipsisExpr::PARTIAL)}}));
}
// For debugging purposes: ctx->findMethod(typ->name, expr->member);
// For debugging purposes:
// ctx->findMethod(typ.get(), expr->member);
E(Error::DOT_NO_ATTR, expr, typ->prettyString(), expr->member);
return nullptr;
}
@ -372,7 +366,7 @@ FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr,
bool addSelf = true;
if (auto dot = expr->getDot()) {
auto methods =
ctx->findMethod(dot->expr->type->getClass()->name, dot->member, false);
ctx->findMethod(dot->expr->type->getClass().get(), dot->member, false);
if (!methods.empty() && methods.front()->ast->attributes.has(Attr::StaticMethod))
addSelf = false;
}
@ -410,7 +404,7 @@ FuncTypePtr TypecheckVisitor::getBestOverload(Expr *expr,
if (auto dot = expr->getDot()) {
// Case: method overloads (DotExpr)
auto methods =
ctx->findMethod(dot->expr->type->getClass()->name, dot->member, false);
ctx->findMethod(dot->expr->type->getClass().get(), dot->member, false);
auto m = findMatchingMethods(dot->expr->type->getClass(), methods, *methodArgs);
bestMethod = m.empty() ? nullptr : m[0];
} else if (auto id = expr->getId()) {

View File

@ -161,7 +161,7 @@ void TypecheckVisitor::visit(AssignMemberStmt *stmt) {
transform(stmt->lhs);
if (auto lhsClass = stmt->lhs->getType()->getClass()) {
auto member = ctx->findMember(lhsClass->name, stmt->member);
auto member = ctx->findMember(lhsClass, stmt->member);
if (!member && stmt->lhs->isType()) {
// Case: class variables

View File

@ -160,7 +160,7 @@ bool TypecheckVisitor::transformCallArgs(std::vector<CallExpr::Arg> &args) {
return false;
if (!typ->getRecord())
E(Error::CALL_BAD_UNPACK, args[ai], typ->prettyString());
auto &fields = ctx->cache->classes[typ->name].fields;
auto fields = getClassFields(typ.get());
for (size_t i = 0; i < typ->getRecord()->args.size(); i++, ai++) {
args.insert(args.begin() + ai,
{"", transform(N<DotExpr>(clone(star->what), fields[i].name))});
@ -176,9 +176,9 @@ bool TypecheckVisitor::transformCallArgs(std::vector<CallExpr::Arg> &args) {
}
if (!typ)
return false;
if (!typ->getRecord() || startswith(typ->name, TYPE_TUPLE))
if (!typ->getRecord() || typ->name == TYPE_TUPLE)
E(Error::CALL_BAD_KWUNPACK, args[ai], typ->prettyString());
auto &fields = ctx->cache->classes[typ->name].fields;
auto fields = getClassFields(typ.get());
for (size_t i = 0; i < typ->getRecord()->args.size(); i++, ai++) {
args.insert(args.begin() + ai,
{fields[i].name,
@ -338,7 +338,7 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e
auto e = getPartialArg(-1);
auto t = e->getType()->getRecord();
seqassert(t && startswith(t->name, TYPE_KWTUPLE), "{} not a kwtuple", e);
auto &ff = ctx->cache->classes[t->name].fields;
auto ff = getClassFields(t.get());
for (int i = 0; i < t->getRecord()->args.size(); i++) {
names.emplace_back(ff[i].name);
values.emplace_back(
@ -663,10 +663,9 @@ ExprPtr TypecheckVisitor::transformSuper() {
if (typ->getRecord()) {
// Case: tuple types. Return `tuple(obj.args...)`
std::vector<ExprPtr> members;
for (auto &field : ctx->cache->classes[name].fields)
for (auto &field : getClassFields(superTyp.get()))
members.push_back(N<DotExpr>(N<IdExpr>(funcTyp->ast->args[0].name), field.name));
ExprPtr e = transform(
N<CallExpr>(N<IdExpr>(format(TYPE_TUPLE "{}", members.size())), members));
ExprPtr e = transform(N<TupleExpr>(members));
e->type = unify(superTyp, e->type); // see super_tuple test for this line
return e;
} else {
@ -732,8 +731,8 @@ ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
}
expr->staticValue.type = StaticValue::INT;
if (typExpr->isId("Tuple") || typExpr->isId("tuple")) {
return transform(N<BoolExpr>(startswith(typ->name, TYPE_TUPLE)));
if (typExpr->isId(TYPE_TUPLE) || typExpr->isId("tuple")) {
return transform(N<BoolExpr>(typ->name == TYPE_TUPLE));
} else if (typExpr->isId("ByVal")) {
return transform(N<BoolExpr>(typ->getRecord() != nullptr));
} else if (typExpr->isId("ByRef")) {
@ -766,7 +765,10 @@ ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
// Check super types (i.e., statically inherited) as well
for (auto &tx : getSuperTypes(typ->getClass())) {
if (tx->unify(typExpr->type.get(), nullptr) >= 0)
types::Type::Unification us;
auto s = tx->unify(typExpr->type.get(), &us);
us.undo();
if (s >= 0)
return transform(N<BoolExpr>(true));
}
return transform(N<BoolExpr>(false));
@ -841,8 +843,8 @@ ExprPtr TypecheckVisitor::transformHasAttr(CallExpr *expr) {
}
}
bool exists = !ctx->findMethod(typ->getClass()->name, member).empty() ||
ctx->findMember(typ->getClass()->name, member);
bool exists = !ctx->findMethod(typ->getClass().get(), member).empty() ||
ctx->findMember(typ->getClass(), member);
if (exists && args.size() > 1)
exists &= findBestMethod(typ, member, args) != nullptr;
return transform(N<BoolExpr>(exists));
@ -890,26 +892,23 @@ ExprPtr TypecheckVisitor::transformTupleFn(CallExpr *expr) {
return expr->clone();
std::vector<ExprPtr> items;
auto tn = generateTuple(ctx->cache->classes[cls->name].fields.size());
for (auto &ft : ctx->cache->classes[cls->name].fields) {
for (auto &ft : getClassFields(cls.get())) {
auto t = ctx->instantiate(ft.type, cls);
auto rt = realize(t);
seqassert(rt, "cannot realize '{}' in {}", t, ft.name);
items.push_back(NT<IdExpr>(t->realizedName()));
}
auto e = transform(NT<InstantiateExpr>(N<IdExpr>(tn), items));
auto e = transform(NT<InstantiateExpr>(N<IdExpr>(TYPE_TUPLE), items));
return e;
}
std::vector<ExprPtr> args;
args.reserve(ctx->cache->classes[cls->name].fields.size());
std::string var = ctx->cache->getTemporaryVar("tup");
for (auto &field : ctx->cache->classes[cls->name].fields)
for (auto &field : getClassFields(cls.get()))
args.emplace_back(N<DotExpr>(N<IdExpr>(var), field.name));
return transform(N<StmtExpr>(
N<AssignStmt>(N<IdExpr>(var), expr->args.front().value),
N<CallExpr>(N<IdExpr>(format("{}{}", TYPE_TUPLE, args.size())), args)));
return transform(N<StmtExpr>(N<AssignStmt>(N<IdExpr>(var), expr->args.front().value),
N<TupleExpr>(args)));
}
/// Transform type function to a type IdExpr identifier.
@ -1092,7 +1091,7 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
return {true, nullptr};
size_t idx = 0;
for (auto &f : ctx->cache->classes[typ->name].fields) {
for (auto &f : getClassFields(typ.get())) {
auto k = N<StringExpr>(f.name);
auto v = N<DotExpr>(expr->args[0].value, f.name);
if (withIdx) {
@ -1119,10 +1118,10 @@ std::pair<bool, ExprPtr> TypecheckVisitor::transformInternalStaticFn(CallExpr *e
error("invalid index");
typ = t->getRecord()->args[n];
} else {
if (n < 0 || n >= ctx->cache->classes[t->getClass()->name].fields.size())
auto f = getClassFields(t->getClass().get());
if (n < 0 || n >= f.size())
error("invalid index");
typ = ctx->instantiate(ctx->cache->classes[t->getClass()->name].fields[n].type,
t->getClass());
typ = ctx->instantiate(f[n].type, t->getClass());
}
typ = realize(typ);
return {true, transform(NT<IdExpr>(typ->realizedName()))};
@ -1141,8 +1140,8 @@ std::vector<ClassTypePtr> TypecheckVisitor::getSuperTypes(const ClassTypePtr &cl
result.push_back(cls);
for (auto &name : ctx->cache->classes[cls->name].staticParentClasses) {
auto parentTyp = ctx->instantiate(ctx->forceFind(name)->type)->getClass();
for (auto &field : ctx->cache->classes[cls->name].fields) {
for (auto &parentField : ctx->cache->classes[name].fields)
for (auto &field : getClassFields(cls.get())) {
for (auto &parentField : getClassFields(parentTyp.get()))
if (field.name == parentField.name) {
unify(ctx->instantiate(field.type, cls),
ctx->instantiate(parentField.type, parentTyp));

View File

@ -117,7 +117,7 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
LOG_REALIZE(" - member: {}: {}", m.name, m.type);
}
/// Generate a tuple class `Tuple.N[T1,...,TN]`.
/// Generate a tuple class `Tuple[T1,...,TN]`.
/// @param len Tuple length (`N`)
/// @param name Tuple name. `Tuple` by default.
/// Can be something else (e.g., `KwTuple`)

View File

@ -126,8 +126,7 @@ ExprPtr TypecheckVisitor::transformComprehension(const std::string &type,
seqassert(collectionTyp->getRecord() &&
collectionTyp->getRecord()->args.size() == 2,
"bad dict");
auto tname = generateTuple(2);
auto tt = unify(typ, ctx->instantiate(ctx->getType(tname)))->getRecord();
auto tt = unify(typ, ctx->instantiateTuple(2))->getRecord();
auto nt = collectionTyp->getRecord()->args;
for (int di = 0; di < 2; di++) {
if (!nt[di]->getClass())
@ -135,7 +134,7 @@ ExprPtr TypecheckVisitor::transformComprehension(const std::string &type,
else if (auto dt = superTyp(nt[di]->getClass(), tt->args[di]->getClass()))
nt[di] = dt;
}
collectionTyp = ctx->instantiateGeneric(ctx->getType(tname), nt);
collectionTyp = ctx->instantiateTuple(nt);
}
}
if (!done)
@ -195,9 +194,9 @@ ExprPtr TypecheckVisitor::transformComprehension(const std::string &type,
}
/// Transform tuples.
/// Generate tuple classes (e.g., `Tuple.N`) if not available.
/// Generate tuple classes (e.g., `Tuple`) if not available.
/// @example
/// `(a1, ..., aN)` -> `Tuple.N.__new__(a1, ..., aN)`
/// `(a1, ..., aN)` -> `Tuple.__new__(a1, ..., aN)`
void TypecheckVisitor::visit(TupleExpr *expr) {
expr->setType(ctx->getUnbound());
for (int ai = 0; ai < expr->items.size(); ai++)
@ -213,7 +212,7 @@ void TypecheckVisitor::visit(TupleExpr *expr) {
return; // continue later when the type becomes known
if (!typ->getRecord())
E(Error::CALL_BAD_UNPACK, star, typ->prettyString());
auto &ff = ctx->cache->classes[typ->name].fields;
auto ff = getClassFields(typ.get());
for (int i = 0; i < typ->getRecord()->args.size(); i++, ai++) {
expr->items.insert(expr->items.begin() + ai,
transform(N<DotExpr>(clone(star->what), ff[i].name)));
@ -224,15 +223,14 @@ void TypecheckVisitor::visit(TupleExpr *expr) {
} else {
expr->items[ai] = transform(expr->items[ai]);
}
auto tupleName = generateTuple(expr->items.size());
resultExpr =
transform(N<CallExpr>(N<DotExpr>(tupleName, "__new__"), clone(expr->items)));
auto s = ctx->generateTuple(expr->items.size());
resultExpr = transform(N<CallExpr>(N<IdExpr>(s), clone(expr->items)));
unify(expr->type, resultExpr->type);
}
/// Transform a tuple generator expression.
/// @example
/// `tuple(expr for i in tuple_generator)` -> `Tuple.N.__new__(expr...)`
/// `tuple(expr for i in tuple_generator)` -> `Tuple.__new__(expr...)`
void TypecheckVisitor::visit(GeneratorExpr *expr) {
seqassert(expr->kind == GeneratorExpr::Generator && expr->loops.size() == 1 &&
expr->loops[0].conds.empty(),
@ -282,8 +280,7 @@ void TypecheckVisitor::visit(GeneratorExpr *expr) {
}
auto tuple = gen->type->getRecord();
if (!tuple ||
!(startswith(tuple->name, TYPE_TUPLE) || startswith(tuple->name, TYPE_KWTUPLE)))
if (!tuple || !(tuple->name == TYPE_TUPLE || startswith(tuple->name, TYPE_KWTUPLE)))
E(Error::CALL_BAD_ITER, gen, gen->type->prettyString());
// `a := tuple[i]; expr...` for each i

View File

@ -10,6 +10,8 @@
#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/visitors/format/format.h"
#include "codon/parser/visitors/simplify/ctx.h"
#include "codon/parser/visitors/typecheck/typecheck.h"
using fmt::format;
using namespace codon::error;
@ -97,14 +99,18 @@ types::TypePtr TypeContext::instantiate(const SrcInfo &srcInfo,
for (auto &i : genericCache) {
if (auto l = i.second->getLink()) {
i.second->setSrcInfo(srcInfo);
if (l->defaultType)
if (l->defaultType) {
pendingDefaults.insert(i.second);
}
}
}
if (t->getUnion() && !t->getUnion()->isSealed()) {
t->setSrcInfo(srcInfo);
pendingDefaults.insert(t);
}
if (auto r = t->getRecord())
if (r->repeats && r->repeats->canRealize())
r->flatten();
return t;
}
@ -126,9 +132,112 @@ TypeContext::instantiateGeneric(const SrcInfo &srcInfo, const types::TypePtr &ro
return instantiate(srcInfo, root, g);
}
std::vector<types::FuncTypePtr> TypeContext::findMethod(const std::string &typeName,
std::shared_ptr<types::RecordType>
TypeContext::instantiateTuple(const SrcInfo &srcInfo,
const std::vector<types::TypePtr> &generics) {
auto key = generateTuple(generics.size());
auto root = forceFind(key)->type->getRecord();
return instantiateGeneric(srcInfo, root, generics)->getRecord();
}
std::string TypeContext::generateTuple(size_t n) {
auto key = format("_{}:{}", TYPE_TUPLE, n);
if (!in(cache->classes, key)) {
cache->classes[key].fields.clear();
cache->classes[key].ast =
std::static_pointer_cast<ClassStmt>(clone(cache->classes[TYPE_TUPLE].ast));
auto root = std::make_shared<types::RecordType>(cache, TYPE_TUPLE, TYPE_TUPLE);
for (size_t i = 0; i < n; i++) { // generate unique ID
auto g = getUnbound()->getLink();
g->kind = types::LinkType::Generic;
g->genericName = format("T{}", i + 1);
auto gn = cache->imports[MAIN_IMPORT].ctx->generateCanonicalName(g->genericName);
root->generics.emplace_back(gn, g->genericName, g, g->id);
root->args.emplace_back(g);
cache->classes[key].ast->args.emplace_back(
g->genericName, std::make_shared<IdExpr>("type"), nullptr, Param::Generic);
cache->classes[key].fields.push_back(
Cache::Class::ClassField{format("item{}", i + 1), g, ""});
}
std::vector<ExprPtr> eTypeArgs;
for (size_t i = 0; i < n; i++)
eTypeArgs.push_back(std::make_shared<IdExpr>(format("T{}", i + 1)));
auto eType = std::make_shared<InstantiateExpr>(std::make_shared<IdExpr>(TYPE_TUPLE),
eTypeArgs);
eType->type = root;
cache->classes[key].mro = {eType};
addToplevel(key, std::make_shared<TypecheckItem>(TypecheckItem::Type, root));
}
return key;
}
std::shared_ptr<types::RecordType> TypeContext::instantiateTuple(size_t n) {
std::vector<types::TypePtr> t(n);
for (size_t i = 0; i < n; i++) {
auto g = getUnbound()->getLink();
g->genericName = format("T{}", i + 1);
t[i] = g;
}
return instantiateTuple(getSrcInfo(), t);
}
std::vector<types::FuncTypePtr> TypeContext::findMethod(types::ClassType *type,
const std::string &method,
bool hideShadowed) const {
bool hideShadowed) {
auto typeName = type->name;
if (type->is(TYPE_TUPLE)) {
auto sz = type->getRecord()->getRepeats();
if (sz != -1)
type->getRecord()->flatten();
sz = int64_t(type->getRecord()->args.size());
typeName = format("_{}:{}", TYPE_TUPLE, sz);
if (in(cache->classes[TYPE_TUPLE].methods, method) &&
!in(cache->classes[typeName].methods, method)) {
auto type = forceFind(typeName)->type;
cache->classes[typeName].methods[method] =
cache->classes[TYPE_TUPLE].methods[method];
auto &o = cache->overloads[cache->classes[typeName].methods[method]];
auto f = cache->functions[o[0].name];
f.realizations.clear();
seqassert(f.type, "tuple fn type not yet set");
f.ast->attributes.parentClass = typeName;
f.ast = std::static_pointer_cast<FunctionStmt>(clone(f.ast));
f.ast->name = format("{}{}", f.ast->name.substr(0, f.ast->name.size() - 1), sz);
f.ast->attributes.set(Attr::Method);
auto eType = clone(cache->classes[typeName].mro[0]);
eType->type = nullptr;
for (auto &a : f.ast->args)
if (a.type && a.type->isId(TYPE_TUPLE)) {
a.type = eType;
}
if (f.ast->ret && f.ast->ret->isId(TYPE_TUPLE))
f.ast->ret = eType;
// TODO: resurrect Tuple[N].__new__(defaults...)
if (method == "__new__") {
for (size_t i = 0; i < sz; i++) {
auto n = format("item{}", i + 1);
f.ast->args.emplace_back(
cache->imports[MAIN_IMPORT].ctx->generateCanonicalName(n),
std::make_shared<IdExpr>(format("T{}", i + 1))
// std::make_shared<CallExpr>(
// std::make_shared<IdExpr>(format("T{}", i + 1)))
);
}
}
cache->reverseIdentifierLookup[f.ast->name] = method;
cache->functions[f.ast->name] = f;
cache->functions[f.ast->name].type =
TypecheckVisitor(cache->typeCtx).makeFunctionType(f.ast.get());
addToplevel(f.ast->name,
std::make_shared<TypecheckItem>(TypecheckItem::Func,
cache->functions[f.ast->name].type));
o.push_back(Cache::Overload{f.ast->name, 0});
}
}
std::vector<types::FuncTypePtr> vv;
std::unordered_set<std::string> signatureLoci;
@ -166,9 +275,26 @@ std::vector<types::FuncTypePtr> TypeContext::findMethod(const std::string &typeN
return vv;
}
types::TypePtr TypeContext::findMember(const std::string &typeName,
types::TypePtr TypeContext::findMember(const types::ClassTypePtr &type,
const std::string &member) const {
if (auto cls = in(cache->classes, typeName)) {
if (type->is(TYPE_TUPLE)) {
if (!startswith(member, "item") || member.size() < 5)
return nullptr;
int id = 0;
for (int i = 4; i < member.size(); i++) {
if (member[i] >= '0' + (i == 4) && member[i] <= '9')
id = id * 10 + member[i] - '0';
else
return nullptr;
}
auto sz = type->getRecord()->getRepeats();
if (sz != -1)
type->getRecord()->flatten();
if (id < 1 || id > type->getRecord()->args.size())
return nullptr;
return type->getRecord()->args[id - 1];
}
if (auto cls = in(cache->classes, type->name)) {
for (auto &pt : cls->mro) {
if (auto pc = pt->type->getClass()) {
auto mc = in(cache->classes, pc->name);

View File

@ -132,14 +132,22 @@ public:
return instantiateGeneric(getSrcInfo(), std::move(root), generics);
}
std::shared_ptr<types::RecordType>
instantiateTuple(const SrcInfo &info, const std::vector<types::TypePtr> &generics);
std::shared_ptr<types::RecordType>
instantiateTuple(const std::vector<types::TypePtr> &generics) {
return instantiateTuple(getSrcInfo(), generics);
}
std::shared_ptr<types::RecordType> instantiateTuple(size_t n);
std::string generateTuple(size_t n);
/// Returns the list of generic methods that correspond to typeName.method.
std::vector<types::FuncTypePtr> findMethod(const std::string &typeName,
std::vector<types::FuncTypePtr> findMethod(types::ClassType *type,
const std::string &method,
bool hideShadowed = true) const;
bool hideShadowed = true);
/// Returns the generic type of typeName.member, if it exists (nullptr otherwise).
/// Special cases: __elemsize__ and __atomic__.
types::TypePtr findMember(const std::string &typeName,
const std::string &member) const;
types::TypePtr findMember(const types::ClassTypePtr &, const std::string &) const;
using ReorderDoneFn =
std::function<int(int, int, const std::vector<std::vector<int>> &, bool)>;

View File

@ -75,6 +75,42 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
// Function should be constructed only once
stmt->setDone();
auto funcTyp = makeFunctionType(stmt);
// If this is a class method, update the method lookup table
bool isClassMember = !stmt->attributes.parentClass.empty();
if (isClassMember) {
auto m =
ctx->cache->getMethod(ctx->find(stmt->attributes.parentClass)->type->getClass(),
ctx->cache->rev(stmt->name));
bool found = false;
for (auto &i : ctx->cache->overloads[m])
if (i.name == stmt->name) {
ctx->cache->functions[i.name].type = funcTyp;
found = true;
break;
}
seqassert(found, "cannot find matching class method for {}", stmt->name);
}
// Update the visited table
// Functions should always be visible, so add them to the toplevel
ctx->addToplevel(stmt->name,
std::make_shared<TypecheckItem>(TypecheckItem::Func, funcTyp));
ctx->cache->functions[stmt->name].type = funcTyp;
// Ensure that functions with @C, @force_realize, and @export attributes can be
// realized
if (stmt->attributes.has(Attr::ForceRealize) || stmt->attributes.has(Attr::Export) ||
(stmt->attributes.has(Attr::C) && !stmt->attributes.has(Attr::CVarArg))) {
if (!funcTyp->canRealize())
E(Error::FN_REALIZE_BUILTIN, stmt);
}
// Debug information
LOG_REALIZE("[stmt] added func {}: {}", stmt->name, funcTyp);
}
types::FuncTypePtr TypecheckVisitor::makeFunctionType(FunctionStmt *stmt) {
// Handle generics
bool isClassMember = !stmt->attributes.parentClass.empty();
auto explicits = std::vector<ClassType::Generic>();
@ -169,38 +205,7 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
}
funcTyp =
std::static_pointer_cast<FuncType>(funcTyp->generalize(ctx->typecheckLevel));
// If this is a class method, update the method lookup table
if (isClassMember) {
auto m =
ctx->cache->getMethod(ctx->find(stmt->attributes.parentClass)->type->getClass(),
ctx->cache->rev(stmt->name));
bool found = false;
for (auto &i : ctx->cache->overloads[m])
if (i.name == stmt->name) {
ctx->cache->functions[i.name].type = funcTyp;
found = true;
break;
}
seqassert(found, "cannot find matching class method for {}", stmt->name);
}
// Update the visited table
// Functions should always be visible, so add them to the toplevel
ctx->addToplevel(stmt->name,
std::make_shared<TypecheckItem>(TypecheckItem::Func, funcTyp));
ctx->cache->functions[stmt->name].type = funcTyp;
// Ensure that functions with @C, @force_realize, and @export attributes can be
// realized
if (stmt->attributes.has(Attr::ForceRealize) || stmt->attributes.has(Attr::Export) ||
(stmt->attributes.has(Attr::C) && !stmt->attributes.has(Attr::CVarArg))) {
if (!funcTyp->canRealize())
E(Error::FN_REALIZE_BUILTIN, stmt);
}
// Debug information
LOG_REALIZE("[stmt] added func {}: {}", stmt->name, funcTyp);
return funcTyp;
}
/// Make an empty partial call `fn(...)` for a given function.
@ -232,11 +237,10 @@ ExprPtr TypecheckVisitor::partializeFunction(const types::FuncTypePtr &fn) {
return call;
}
/// Generate and return `Function[Tuple.N[args...], ret]` type
/// Generate and return `Function[Tuple[args...], ret]` type
std::shared_ptr<RecordType> TypecheckVisitor::getFuncTypeBase(size_t nargs) {
auto baseType = ctx->instantiate(ctx->forceFind("Function")->type)->getRecord();
unify(baseType->generics[0].type,
ctx->instantiate(ctx->forceFind(generateTuple(nargs))->type)->getRecord());
unify(baseType->generics[0].type, ctx->instantiateTuple(nargs)->getRecord());
return baseType;
}

View File

@ -53,9 +53,12 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
ctx->getRealizationBase()->iteration++) {
LOG_TYPECHECK("[iter] {} :: {}", ctx->getRealizationBase()->name,
ctx->getRealizationBase()->iteration);
if (ctx->getRealizationBase()->iteration >= MAX_TYPECHECK_ITER)
if (ctx->getRealizationBase()->iteration >= MAX_TYPECHECK_ITER) {
error(result, "cannot typecheck '{}' in reasonable time",
ctx->cache->rev(ctx->getRealizationBase()->name));
ctx->getRealizationBase()->name.empty()
? "toplevel"
: ctx->cache->rev(ctx->getRealizationBase()->name));
}
// Keep iterating until:
// (1) success: the statement is marked as done; or
@ -183,7 +186,6 @@ types::TypePtr TypecheckVisitor::realize(types::TypePtr typ) {
}
e.trackRealize(fmt::format("{}{}", name, name_args), getSrcInfo());
}
} else {
e.trackRealize(typ->prettyString(), getSrcInfo());
}
@ -198,9 +200,12 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) {
if (!type || !type->canRealize())
return nullptr;
if (auto tr = type->getRecord())
tr->flatten();
// Check if the type fields are all initialized
// (sometimes that's not the case: e.g., `class X: x: List[X]`)
for (auto &field : ctx->cache->classes[type->name].fields) {
for (auto field : getClassFields(type)) {
if (!field.type)
return nullptr;
}
@ -247,16 +252,24 @@ types::TypePtr TypecheckVisitor::realizeType(types::ClassType *type) {
std::vector<ir::types::Type *> typeArgs; // needed for IR
std::vector<std::string> names; // needed for IR
std::map<std::string, SrcInfo> memberInfo; // needed for IR
for (auto &field : ctx->cache->classes[realized->name].fields) {
if (realized->is(TYPE_TUPLE))
realized->getRecord()->flatten();
int i = 0;
for (auto &field : getClassFields(realized.get())) {
auto ftyp = ctx->instantiate(field.type, realized);
// HACK: repeated tuples have no generics so this is needed to fix the instantiation
// above
if (realized->is(TYPE_TUPLE))
unify(ftyp, realized->getRecord()->args[i]);
if (!realize(ftyp))
E(Error::TYPE_CANNOT_REALIZE_ATTR, getSrcInfo(), field.name,
ftyp->prettyString());
LOG_REALIZE("- member: {} -> {}: {}", field.name, field.type, ftyp);
realization->fields.emplace_back(field.name, ftyp);
names.emplace_back(field.name);
typeArgs.emplace_back(makeIRType(ftyp->getClass().get()));
memberInfo[field.name] = field.type->getSrcInfo();
i++;
}
// Set IR attributes
@ -431,9 +444,7 @@ StmtPtr TypecheckVisitor::prepareVTables() {
NT<InstantiateExpr>(
NT<IdExpr>("Function"),
std::vector<ExprPtr>{
NT<InstantiateExpr>(
NT<IdExpr>(format("{}{}", TYPE_TUPLE, ids.size())),
ids),
NT<InstantiateExpr>(NT<IdExpr>(TYPE_TUPLE), ids),
NT<IdExpr>(fn->getRetType()->realizedName())}),
N<IdExpr>(fn->realizedName())),
"__raw__")),
@ -459,7 +470,7 @@ StmtPtr TypecheckVisitor::prepareVTables() {
auto baseTyp = t->funcGenerics[0].type->getClass();
auto derivedTyp = t->funcGenerics[1].type->getClass();
const auto &fields = ctx->cache->classes[derivedTyp->name].fields;
const auto &fields = getClassFields(derivedTyp.get());
auto types = std::vector<ExprPtr>{};
auto found = false;
for (auto &f : fields) {
@ -471,13 +482,11 @@ StmtPtr TypecheckVisitor::prepareVTables() {
types.push_back(NT<IdExpr>(ft->realizedName()));
}
}
seqassert(found || ctx->cache->classes[baseTyp->name].fields.empty(),
seqassert(found || getClassFields(baseTyp.get()).empty(),
"cannot find distance between {} and {}", derivedTyp->name,
baseTyp->name);
StmtPtr suite = N<ReturnStmt>(
N<DotExpr>(NT<InstantiateExpr>(
NT<IdExpr>(format("{}{}", TYPE_TUPLE, types.size())), types),
"__elemsize__"));
N<DotExpr>(NT<InstantiateExpr>(NT<IdExpr>(TYPE_TUPLE), types), "__elemsize__"));
LOG_REALIZE("[poly] {} : {}", t, *suite);
initDist.ast->suite = suite;
t->ast = initDist.ast.get();
@ -518,7 +527,7 @@ size_t TypecheckVisitor::getRealizationID(types::ClassType *cp, types::FuncType
->vtables[cp->realizedName()];
// Add or extract thunk ID
size_t vid;
size_t vid = 0;
if (auto i = in(vt.table, key)) {
vid = i->second;
} else {
@ -681,14 +690,19 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
seqassert(types.size() == 1 && statics.size() == 1, "bad generics/statics");
handle = module->unsafeGetVectorType(statics[0]->getInt(), types[0]);
} else if (auto tr = t->getRecord()) {
seqassert(tr->getRepeats() >= 0, "repeats not resolved: '{}'", tr->debugString(2));
tr->flatten();
std::vector<ir::types::Type *> typeArgs;
std::vector<std::string> names;
std::map<std::string, SrcInfo> memberInfo;
for (int ai = 0; ai < tr->args.size(); ai++) {
names.emplace_back(ctx->cache->classes[t->name].fields[ai].name);
auto n = t->name == TYPE_TUPLE ? format("item{}", ai + 1)
: ctx->cache->classes[t->name].fields[ai].name;
names.emplace_back(n);
typeArgs.emplace_back(forceFindIRType(tr->args[ai]));
memberInfo[ctx->cache->classes[t->name].fields[ai].name] =
ctx->cache->classes[t->name].fields[ai].type->getSrcInfo();
memberInfo[n] = t->name == TYPE_TUPLE
? tr->getSrcInfo()
: ctx->cache->classes[t->name].fields[ai].type->getSrcInfo();
}
auto record =
ir::cast<ir::types::RecordType>(module->unsafeGetMemberedType(realizedName));

View File

@ -60,8 +60,8 @@ void TypecheckVisitor::visit(ForStmt *stmt) {
if ((resultStmt = transformStaticForLoop(stmt)))
return;
bool maybeHeterogenous = startswith(iterType->name, TYPE_TUPLE) ||
startswith(iterType->name, TYPE_KWTUPLE);
bool maybeHeterogenous =
iterType->name == TYPE_TUPLE || startswith(iterType->name, TYPE_KWTUPLE);
if (maybeHeterogenous && !iterType->canRealize()) {
return; // wait until the tuple is fully realizable
} else if (maybeHeterogenous && iterType->getHeterogenousTuple()) {
@ -336,7 +336,7 @@ TypecheckVisitor::transformStaticLoopCall(
error("expected three items");
auto typ = args[0]->getClass();
size_t idx = 0;
for (auto &f : ctx->cache->classes[typ->name].fields) {
for (auto &f : getClassFields(typ.get())) {
std::vector<StmtPtr> stmts;
if (withIdx) {
stmts.push_back(
@ -370,7 +370,7 @@ TypecheckVisitor::transformStaticLoopCall(
seqassert(typ, "vars_types expects a realizable type, got '{}' instead",
generics[0]);
size_t idx = 0;
for (auto &f : ctx->cache->classes[typ->getClass()->name].fields) {
for (auto &f : getClassFields(typ->getClass().get())) {
auto ta = realize(ctx->instantiate(f.type, typ->getClass()));
seqassert(ta, "cannot realize '{}'", f.type->debugString(1));
std::vector<StmtPtr> stmts;

View File

@ -271,15 +271,29 @@ void TypecheckVisitor::visit(IndexExpr *expr) {
/// Instantiate(foo, [bar]) -> Id("foo[bar]")
void TypecheckVisitor::visit(InstantiateExpr *expr) {
transformType(expr->typeExpr);
TypePtr typ =
ctx->instantiate(expr->typeExpr->getSrcInfo(), expr->typeExpr->getType());
std::shared_ptr<types::StaticType> repeats = nullptr;
if (expr->typeExpr->isId(TYPE_TUPLE) && !expr->typeParams.empty()) {
transform(expr->typeParams[0]);
if (expr->typeParams[0]->staticValue.type == StaticValue::INT) {
repeats = Type::makeStatic(ctx->cache, expr->typeParams[0]);
}
}
TypePtr typ = nullptr;
size_t typeParamsSize = expr->typeParams.size() - (repeats != nullptr);
if (expr->typeExpr->isId(TYPE_TUPLE)) {
typ = ctx->instantiateTuple(typeParamsSize);
} else {
typ = ctx->instantiate(expr->typeExpr->getSrcInfo(), expr->typeExpr->getType());
}
seqassert(typ->getClass(), "unknown type: {}", expr->typeExpr);
auto &generics = typ->getClass()->generics;
bool isUnion = typ->getUnion() != nullptr;
if (!isUnion && expr->typeParams.size() != generics.size())
if (!isUnion && typeParamsSize != generics.size())
E(Error::GENERICS_MISMATCH, expr, ctx->cache->rev(typ->getClass()->name),
generics.size(), expr->typeParams.size());
generics.size(), typeParamsSize);
if (expr->typeExpr->isId(TYPE_CALLABLE)) {
// Case: Callable[...] trait instantiation
@ -303,7 +317,7 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) {
typ->getLink()->trait = std::make_shared<TypeTrait>(expr->typeParams[0]->type);
unify(expr->type, typ);
} else {
for (size_t i = 0; i < expr->typeParams.size(); i++) {
for (size_t i = (repeats != nullptr); i < expr->typeParams.size(); i++) {
transform(expr->typeParams[i]);
TypePtr t = nullptr;
if (expr->typeParams[i]->isStatic()) {
@ -319,7 +333,10 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) {
if (isUnion)
typ->getUnion()->addType(t);
else
unify(t, generics[i].type);
unify(t, generics[i - (repeats != nullptr)].type);
}
if (repeats) {
typ->getRecord()->repeats = repeats;
}
if (isUnion) {
typ->getUnion()->seal();
@ -714,7 +731,7 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
const ExprPtr &expr, const ExprPtr &index) {
if (!tuple->getRecord())
return {false, nullptr};
if (!startswith(tuple->name, TYPE_TUPLE) && !startswith(tuple->name, TYPE_KWTUPLE) &&
if (tuple->name != TYPE_TUPLE && !startswith(tuple->name, TYPE_KWTUPLE) &&
!startswith(tuple->name, TYPE_PARTIAL)) {
if (tuple->is(TYPE_OPTIONAL)) {
if (auto newTuple = tuple->generics[0].type->getClass()) {
@ -743,16 +760,15 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
return false;
};
auto classItem = in(ctx->cache->classes, tuple->name);
seqassert(classItem, "cannot find class '{}'", tuple->name);
auto sz = classItem->fields.size();
auto classFields = getClassFields(tuple.get());
auto sz = int64_t(tuple->getRecord()->args.size());
int64_t start = 0, stop = sz, step = 1;
if (getInt(&start, index)) {
// Case: `tuple[int]`
auto i = translateIndex(start, stop);
if (i < 0 || i >= stop)
E(Error::TUPLE_RANGE_BOUNDS, index, stop - 1, i);
return {true, transform(N<DotExpr>(expr, classItem->fields[i].name))};
return {true, transform(N<DotExpr>(expr, classFields[i].name))};
} else if (auto slice = CAST(index, SliceExpr)) {
// Case: `tuple[int:int:int]`
if (!getInt(&start, slice->start) || !getInt(&stop, slice->stop) ||
@ -773,11 +789,11 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple,
for (auto i = start; (step > 0) ? (i < stop) : (i > stop); i += step) {
if (i < 0 || i >= sz)
E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i);
te.push_back(N<DotExpr>(clone(var), classItem->fields[i].name));
te.push_back(N<DotExpr>(clone(var), classFields[i].name));
}
ExprPtr e = transform(N<StmtExpr>(
std::vector<StmtPtr>{ass},
N<CallExpr>(N<DotExpr>(format(TYPE_TUPLE "{}", te.size()), "__new__"), te)));
ExprPtr e = transform(
N<StmtExpr>(std::vector<StmtPtr>{ass},
N<CallExpr>(N<DotExpr>(N<IdExpr>(TYPE_TUPLE), "__new__"), te)));
return {true, e};
}

View File

@ -61,8 +61,9 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) {
}
seqassert(expr->type, "type not set for {}", expr);
unify(typ, expr->type);
if (expr->done)
if (expr->done) {
ctx->changedNodes++;
}
}
realize(typ);
LOG_TYPECHECK("[expr] {}: {}{}", getSrcInfo(), expr, expr->isDone() ? "[done]" : "");
@ -182,7 +183,7 @@ TypecheckVisitor::findBestMethod(const ClassTypePtr &typ, const std::string &mem
callArgs.push_back({"", std::make_shared<NoneExpr>()}); // dummy expression
callArgs.back().value->setType(a);
}
auto methods = ctx->findMethod(typ->name, member, false);
auto methods = ctx->findMethod(typ.get(), member, false);
auto m = findMatchingMethods(typ, methods, callArgs);
return m.empty() ? nullptr : m[0];
}
@ -195,7 +196,7 @@ types::FuncTypePtr TypecheckVisitor::findBestMethod(const ClassTypePtr &typ,
std::vector<CallExpr::Arg> callArgs;
for (auto &a : args)
callArgs.push_back({"", a});
auto methods = ctx->findMethod(typ->name, member, false);
auto methods = ctx->findMethod(typ.get(), member, false);
auto m = findMatchingMethods(typ, methods, callArgs);
return m.empty() ? nullptr : m[0];
}
@ -210,7 +211,7 @@ types::FuncTypePtr TypecheckVisitor::findBestMethod(
callArgs.push_back({n, std::make_shared<NoneExpr>()}); // dummy expression
callArgs.back().value->setType(a);
}
auto methods = ctx->findMethod(typ->name, member, false);
auto methods = ctx->findMethod(typ.get(), member, false);
auto m = findMatchingMethods(typ, methods, callArgs);
return m.empty() ? nullptr : m[0];
}
@ -451,8 +452,8 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType,
ExprPtr TypecheckVisitor::castToSuperClass(ExprPtr expr, ClassTypePtr superTyp,
bool isVirtual) {
ClassTypePtr typ = expr->type->getClass();
for (auto &field : ctx->cache->classes[typ->name].fields) {
for (auto &parentField : ctx->cache->classes[superTyp->name].fields)
for (auto &field : getClassFields(typ.get())) {
for (auto &parentField : getClassFields(superTyp.get()))
if (field.name == parentField.name) {
unify(ctx->instantiate(field.type, typ),
ctx->instantiate(parentField.type, superTyp));
@ -493,4 +494,16 @@ TypecheckVisitor::unpackTupleTypes(ExprPtr expr) {
return ret;
}
std::vector<Cache::Class::ClassField> &
TypecheckVisitor::getClassFields(types::ClassType *t) {
seqassert(t && in(ctx->cache->classes, t->name), "cannot find '{}'",
t ? t->name : "<null>");
if (t->is(TYPE_TUPLE) && !t->getRecord()->args.empty()) {
auto key = ctx->generateTuple(t->getRecord()->args.size());
return ctx->cache->classes[key].fields;
} else {
return ctx->cache->classes[t->name].fields;
}
}
} // namespace codon::ast

View File

@ -177,6 +177,10 @@ private: // Node typechecking rules
ExprPtr partializeFunction(const types::FuncTypePtr &);
std::shared_ptr<types::RecordType> getFuncTypeBase(size_t);
public:
types::FuncTypePtr makeFunctionType(FunctionStmt *);
private:
/* Classes (class.cpp) */
void visit(ClassStmt *) override;
void parseBaseClasses(ClassStmt *);
@ -227,7 +231,8 @@ private:
StmtPtr prepareVTables();
public:
bool isTuple(const std::string &s) const { return startswith(s, TYPE_TUPLE); }
bool isTuple(const std::string &s) const { return s == TYPE_TUPLE; }
std::vector<Cache::Class::ClassField> &getClassFields(types::ClassType *);
friend class Cache;
friend class types::CallableTrait;

View File

@ -19,7 +19,8 @@
#define DBG(c, ...) \
fmt::print(codon::getLogger().log, "{}" c "\n", \
std::string(2 * codon::getLogger().level, ' '), ##__VA_ARGS__)
std::string(size_t(2) * size_t(codon::getLogger().level), ' '), \
##__VA_ARGS__)
#define LOG(c, ...) DBG(c, ##__VA_ARGS__)
#define LOG_TIME(c, ...) \
{ \

View File

@ -1,5 +1,6 @@
# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
@tuple(container=False) # disallow default __getitem__
class Vec[T, N: Static[int]]:
ZERO_16x8i = Vec[u8,16](u8(0))
@ -7,18 +8,50 @@ class Vec[T, N: Static[int]]:
ZERO_32x8i = Vec[u8,32](u8(0))
FF_32x8i = Vec[u8,32](u8(0xff))
ZERO_4x64f = Vec[f64,4](0.0)
@llvm
def _mm_set1_epi8(val: u8) -> Vec[u8, 16]:
%0 = insertelement <16 x i8> undef, i8 %val, i32 0
%1 = shufflevector <16 x i8> %0, <16 x i8> undef, <16 x i32> zeroinitializer
ret <16 x i8> %1
@llvm
def _mm_set1_epi32(val: u32) -> Vec[u32, 4]:
%0 = insertelement <4 x i32> undef, i32 %val, i32 0
%1 = shufflevector <4 x i32> %0, <4 x i32> undef, <4 x i32> zeroinitializer
ret <4 x i32> %1
@llvm
def _mm256_set1_epi8(val: u8) -> Vec[u8, 32]:
%0 = insertelement <32 x i8> undef, i8 %val, i32 0
%1 = shufflevector <32 x i8> %0, <32 x i8> undef, <32 x i32> zeroinitializer
ret <32 x i8> %1
@llvm
def _mm256_set1_epi32(val: u32) -> Vec[u32, 8]:
%0 = insertelement <8 x i32> undef, i32 %val, i32 0
%1 = shufflevector <8 x i32> %0, <8 x i32> undef, <8 x i32> zeroinitializer
ret <8 x i32> %1
@llvm
def _mm256_set1_epi64x(val: u64) -> Vec[u64, 4]:
%0 = insertelement <4 x i64> undef, i64 %val, i32 0
%1 = shufflevector <4 x i64> %0, <4 x i64> undef, <4 x i32> zeroinitializer
ret <4 x i64> %1
@llvm
def _mm512_set1_epi64(val: u64) -> Vec[u64, 8]:
%0 = insertelement <8 x i64> undef, i64 %val, i32 0
%1 = shufflevector <8 x i64> %0, <8 x i64> undef, <8 x i32> zeroinitializer
ret <8 x i64> %1
@llvm
def _mm_load_epi32(data) -> Vec[u32, 4]:
%0 = bitcast i32* %data to <4 x i32>*
%1 = load <4 x i32>, <4 x i32>* %0, align 1
ret <4 x i32> %1
@llvm
def _mm_loadu_si128(data) -> Vec[u8, 16]:
%0 = bitcast i8* %data to <16 x i8>*
@ -31,12 +64,42 @@ class Vec[T, N: Static[int]]:
%1 = load <32 x i8>, <32 x i8>* %0, align 1
ret <32 x i8> %1
@llvm
def _mm256_load_epi32(data) -> Vec[u32, 8]:
%0 = bitcast i32* %data to <8 x i32>*
%1 = load <8 x i32>, <8 x i32>* %0
ret <8 x i32> %1
@llvm
def _mm256_load_epi64(data) -> Vec[u64, 4]:
%0 = bitcast i64* %data to <4 x i64>*
%1 = load <4 x i64>, <4 x i64>* %0
ret <4 x i64> %1
@llvm
def _mm512_load_epi64(data) -> Vec[u64, 8]:
%0 = bitcast i64* %data to <8 x i64>*
%1 = load <8 x i64>, <8 x i64>* %0
ret <8 x i64> %1
@llvm
def _mm256_set1_ps(val: f32) -> Vec[f32, 8]:
%0 = insertelement <8 x float> undef, float %val, i32 0
%1 = shufflevector <8 x float> %0, <8 x float> undef, <8 x i32> zeroinitializer
ret <8 x float> %1
@llvm
def _mm256_set1_pd(val: f64) -> Vec[f64, 4]:
%0 = insertelement <4 x double> undef, double %val, i64 0
%1 = shufflevector <4 x double> %0, <4 x double> undef, <4 x i32> zeroinitializer
ret <4 x double> %1
@llvm
def _mm512_set1_pd(val: f64) -> Vec[f64, 8]:
%0 = insertelement <8 x double> undef, double %val, i64 0
%1 = shufflevector <8 x double> %0, <8 x double> undef, <8 x i32> zeroinitializer
ret <8 x double> %1
@llvm
def _mm512_set1_ps(val: f32) -> Vec[f32, 16]:
%0 = insertelement <16 x float> undef, float %val, i32 0
@ -49,6 +112,18 @@ class Vec[T, N: Static[int]]:
%1 = load <8 x float>, <8 x float>* %0
ret <8 x float> %1
@llvm
def _mm256_loadu_pd(data: Ptr[f64]) -> Vec[f64, 4]:
%0 = bitcast double* %data to <4 x double>*
%1 = load <4 x double>, <4 x double>* %0
ret <4 x double> %1
@llvm
def _mm512_loadu_pd(data: Ptr[f64]) -> Vec[f64, 8]:
%0 = bitcast double* %data to <8 x double>*
%1 = load <8 x double>, <8 x double>* %0
ret <8 x double> %1
@llvm
def _mm512_loadu_ps(data: Ptr[f32]) -> Vec[f32, 16]:
%0 = bitcast float* %data to <16 x float>*
@ -72,6 +147,11 @@ class Vec[T, N: Static[int]]:
%0 = bitcast <8 x i32> %vec to <8 x float>
ret <8 x float> %0
@llvm
def _mm256_castsi256_pd(vec: Vec[u64, 4]) -> Vec[f64, 4]:
%0 = bitcast <4 x i64> %vec to <4 x double>
ret <4 x double> %0
@llvm
def _mm512_castsi512_ps(vec: Vec[u32, 16]) -> Vec[f32, 16]:
%0 = bitcast <16 x i32> %vec to <16 x float>
@ -92,10 +172,58 @@ class Vec[T, N: Static[int]]:
return Vec._mm256_loadu_si256(x)
if isinstance(x, str):
return Vec._mm256_loadu_si256(x.ptr)
if isinstance(T, u32) and N == 4:
if isinstance(x, int):
assert x >= 0, "SIMD: No support for negative int vectors added yet."
return Vec._mm_set1_epi32(u32(x))
if isinstance(x, u32):
return Vec._mm_set1_epi32(x)
if isinstance(x, Ptr[u32]):
return Vec._mm_load_epi32(x)
if isinstance(x, List[u32]):
return Vec._mm_load_epi32(x.arr.ptr)
if isinstance(x, str):
return Vec._mm_load_epi32(x.ptr)
if isinstance(T, u32) and N == 8:
if isinstance(x, int):
assert x >= 0, "SIMD: No support for negative int vectors added yet."
return Vec._mm256_set1_epi32(u32(x))
if isinstance(x, u64):
return Vec._mm256_set1_epi32(x)
if isinstance(x, Ptr[u64]):
return Vec._mm256_load_epi32(x)
if isinstance(x, List[u64]):
return Vec._mm256_load_epi32(x.arr.ptr)
if isinstance(x, str):
return Vec._mm256_load_epi32(x.ptr)
if isinstance(T, u64) and N == 4:
if isinstance(x, int):
assert x >= 0, "SIMD: No support for negative int vectors added yet."
return Vec._mm256_set1_epi64x(u64(x))
if isinstance(x, u64):
return Vec._mm256_set1_epi64x(x)
if isinstance(x, Ptr[u64]):
return Vec._mm256_load_epi64(x)
if isinstance(x, List[u64]):
return Vec._mm256_load_epi64(x.arr.ptr)
if isinstance(x, str):
return Vec._mm256_load_epi64(x.ptr)
if isinstance(T, u64) and N == 8:
if isinstance(x, int):
assert x >= 0, "SIMD: No support for negative int vectors added yet."
return Vec._mm512_set1_epi64(u64(x))
if isinstance(x, u64):
return Vec._mm512_set1_epi64(x)
if isinstance(x, Ptr[u64]):
return Vec._mm512_load_epi64(x)
if isinstance(x, List[u64]):
return Vec._mm512_load_epi64(x.arr.ptr)
if isinstance(x, str):
return Vec._mm512_load_epi64(x.ptr)
if isinstance(T, f32) and N == 8:
if isinstance(x, f32):
return Vec._mm256_set1_ps(x)
if isinstance(x, Ptr[f32]): # TODO: multi-elif does NOT work with statics [why?!]
if isinstance(x, Ptr[f32]):
return Vec._mm256_loadu_ps(x)
if isinstance(x, List[f32]):
return Vec._mm256_loadu_ps(x.arr.ptr)
@ -104,12 +232,26 @@ class Vec[T, N: Static[int]]:
if isinstance(T, f32) and N == 16:
if isinstance(x, f32):
return Vec._mm512_set1_ps(x)
if isinstance(x, Ptr[f32]): # TODO: multi-elif does NOT work with statics [why?!]
if isinstance(x, Ptr[f32]):
return Vec._mm512_loadu_ps(x)
if isinstance(x, List[f32]):
return Vec._mm512_loadu_ps(x.arr.ptr)
if isinstance(x, Vec[u8, 32]):
return Vec._mm512_castsi512_ps(Vec._mm512_cvtepi8_epi64(x))
if isinstance(T, f64) and N == 4:
if isinstance(x, f64):
return Vec._mm256_set1_pd(x)
if isinstance(x, Ptr[f64]):
return Vec._mm256_loadu_pd(x)
if isinstance(x, List[f64]):
return Vec._mm256_loadu_pd(x.arr.ptr)
if isinstance(T, f64) and N == 8:
if isinstance(x, f64):
return Vec._mm512_set1_pd(x)
if isinstance(x, Ptr[f64]):
return Vec._mm512_loadu_pd(x)
if isinstance(x, List[f64]):
return Vec._mm512_loadu_pd(x.arr.ptr)
compile_error("invalid SIMD vector constructor")
def __new__(x: str, offset: int = 0) -> Vec[u8, N]:
@ -185,6 +327,14 @@ class Vec[T, N: Static[int]]:
def __and__(self: Vec[u8, 32], other: Vec[u8, 32]) -> Vec[u8, 32]:
return Vec._mm_and_si256(self, other)
@llvm
def _mm256_and_si256(x: Vec[u64, 4], y: Vec[u64, 4]) -> Vec[u64, 4]:
%0 = and <4 x i64> %x, %y
ret <4 x i64> %0
def __and__(self: Vec[u64, 4], other: Vec[u64, 4]) -> Vec[u64, 4]:
return Vec._mm256_and_si256(self, other)
@llvm
def _mm256_and_ps(x: Vec[f32, 8], y: Vec[f32, 8]) -> Vec[f32, 8]:
%0 = bitcast <8 x float> %x to <8 x i32>
@ -255,9 +405,6 @@ class Vec[T, N: Static[int]]:
%0 = fadd <8 x float> %x, %y
ret <8 x float> %0
def __add__(self: Vec[f32, 8], other: Vec[f32, 8]) -> Vec[f32, 8]:
return Vec._mm256_add_ps(self, other)
def __rshift__(self: Vec[u8, 16], shift: Static[int]) -> Vec[u8, 16]:
if shift == 0:
return self
@ -294,18 +441,428 @@ class Vec[T, N: Static[int]]:
def sum(self: Vec[f32, 8], x: f32 = f32(0.0)) -> f32:
return x + self[0] + self[1] + self[2] + self[3] + self[4] + self[5] + self[6] + self[7]
def __repr__(self):
return f"<{','.join(self.scatter())}>"
# Methods below added ad-hoc. TODO: Add Intel intrinsics equivalents for them and integrate them neatly above.
# Constructors
@llvm
def __getitem__(self, n: Static[int]) -> T:
%0 = extractelement <{=N} x {=T}> %self, i32 {=n}
def generic_init[SLS: Static[int]](val: T) -> Vec[T, SLS]:
%0 = insertelement <{=SLS} x {=T}> undef, {=T} %val, i32 0
%1 = shufflevector <{=SLS} x {=T}> %0, <{=SLS} x {=T}> undef, <{=SLS} x i32> zeroinitializer
ret <{=SLS} x {=T}> %1
@llvm
def generic_load[SLS: Static[int]](data: Ptr[T]) -> Vec[T, SLS]:
%0 = bitcast T* %data to <{=SLS} x {=T}>*
%1 = load <{=SLS} x {=T}>, <{=SLS} x {=T}>* %0
ret <{=SLS} x {=T}> %1
# Bitwise intrinsics
@llvm
def __and__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
%0 = and <{=N} x {=T}> %self, %other
ret <{=N} x {=T}> %0
@llvm
def __and__(self: Vec[T, N], other: T) -> Vec[T, N]:
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
%2 = and <{=N} x {=T}> %self, %1
ret <{=N} x {=T}> %2
@llvm
def __or__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
%0 = or <{=N} x {=T}> %self, %other
ret <{=N} x {=T}> %0
@llvm
def __or__(self: Vec[T, N], other: T) -> Vec[T, N]:
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
%2 = or <{=N} x {=T}> %self, %1
ret <{=N} x {=T}> %2
@llvm
def __xor__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
%0 = xor <{=N} x {=T}> %self, %other
ret <{=N} x {=T}> %0
@llvm
def __xor__(self: Vec[T, N], other: T) -> Vec[T, N]:
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
%2 = xor <{=N} x {=T}> %self, %1
ret <{=N} x {=T}> %2
@llvm
def __lshift__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
%0 = shl <{=N} x {=T}> %self, %other
ret <{=N} x {=T}> %0
@llvm
def __lshift__(self: Vec[T, N], other: T) -> Vec[T, N]:
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
%2 = shl <{=N} x {=T}> %self, %1
ret <{=N} x {=T}> %2
@llvm
def __rshift__(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
%0 = lshr <{=N} x {=T}> %self, %other
ret <{=N} x {=T}> %0
@llvm
def __rshift__(self: Vec[T, N], other: T) -> Vec[T, N]:
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
%2 = lshr <{=N} x {=T}> %self, %1
ret <{=N} x {=T}> %2
@llvm
def bit_flip(self: Vec[u1, N]) -> Vec[u1, N]:
%0 = insertelement <{=N} x i1> undef, i1 1, i32 0
%1 = shufflevector <{=N} x i1> %0, <{=N} x i1> undef, <{=N} x i32> zeroinitializer
%2 = xor <{=N} x i1> %self, %1
ret <{=N} x i1> %2
@llvm
def shift_half(self: Vec[u128, N]) -> Vec[u128, N]:
%0 = insertelement <{=N} x i128> undef, i128 64, i32 0
%1 = shufflevector <{=N} x i128> %0, <{=N} x i128> undef, <{=N} x i32> zeroinitializer
%2 = lshr <{=N} x i128> %self, %1
ret <{=N} x i128> %2
# Comparisons
@llvm
def __ge__(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u1, N]:
%0 = icmp uge <{=N} x i64> %self, %other
ret <{=N} x i1> %0
@llvm
def __ge__(self: Vec[u64, N], other: u64) -> Vec[u1, N]:
%0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
%1 = shufflevector <{=N} x i64> %0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
%2 = icmp uge <{=N} x i64> %self, %1
ret <{=N} x i1> %2
@llvm
def __ge__(self: Vec[f64, N], other: Vec[f64, N]) -> Vec[u1, N]:
%0 = fcmp oge <{=N} x double> %self, %other
ret <{=N} x i1> %0
@llvm
def __ge__(self: Vec[f64, N], other: f64) -> Vec[u1, N]:
%0 = insertelement <{=N} x double> undef, double %other, i32 0
%1 = shufflevector <{=N} x double> %0, <{=N} x double> undef, <{=N} x i32> zeroinitializer
%2 = fcmp oge <{=N} x double> %self, %1
ret <{=N} x i1> %2
@llvm
def __le__(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u1, N]:
%0 = icmp ule <{=N} x i64> %self, %other
ret <{=N} x i1> %0
@llvm
def __le__(self: Vec[u64, N], other: u64) -> Vec[u1, N]:
%0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
%1 = shufflevector <{=N} x i64> %0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
%2 = icmp ule <{=N} x i64> %self, %1
ret <{=N} x i1> %2
# Arithmetic intrinsics
@llvm
def __neg__(self: Vec[T, N]) -> Vec[T, N]:
%0 = sub <{=N} x {=T}> zeroinitializer, %self
ret <{=N} x {=T}> %0
@llvm
def __mod__(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u64, N]:
%0 = urem <{=N} x i64> %self, %other
ret <{=N} x i64> %0
@llvm
def __mod__(self: Vec[u64, N], other: u64) -> Vec[u64, N]:
%0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
%1 = shufflevector <{=N} x i64> %0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
%2 = urem <{=N} x i64> %self, %1
ret <{=N} x i64> %2
@llvm
def add(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
%0 = add <{=N} x {=T}> %self, %other
ret <{=N} x {=T}> %0
@llvm
def add(self: Vec[T, N], other: T) -> Vec[T, N]:
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
%2 = add <{=N} x {=T}> %self, %1
ret <{=N} x {=T}> %2
@llvm
def fadd(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
%0 = fadd <{=N} x {=T}> %self, %other
ret <{=N} x {=T}> %0
@llvm
def fadd(self: Vec[T, N], other: T) -> Vec[T, N]:
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
%2 = fadd <{=N} x {=T}> %self, %1
ret <{=N} x {=T}> %2
def __add__(self: Vec[T, N], other) -> Vec[T, N]:
if isinstance(T, u8) or isinstance(T, u64) or isinstance(T, u128):
return self.add(other)
if isinstance(T, f32) or isinstance(T, f64):
return self.fadd(other)
compile_error("invalid SIMD vector addition")
@llvm
def sub(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
%0 = sub <{=N} x {=T}> %self, %other
ret <{=N} x {=T}> %0
@llvm
def sub(self: Vec[T, N], other: T) -> Vec[T, N]:
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
%2 = sub <{=N} x {=T}> %self, %1
ret <{=N} x {=T}> %2
@llvm
def fsub(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
%0 = fsub <{=N} x {=T}> %self, %other
ret <{=N} x {=T}> %0
@llvm
def fsub(self: Vec[T, N], other: T) -> Vec[T, N]:
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
%2 = fsub <{=N} x {=T}> %self, %1
ret <{=N} x {=T}> %2
def __sub__(self: Vec[T, N], other) -> Vec[T, N]:
if isinstance(T, u8) or isinstance(T, u64):
return self.sub(other)
if isinstance(T, f32) or isinstance(T, f64):
return self.fsub(other)
compile_error("invalid SIMD vector subtraction")
@llvm
def mul(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
%0 = mul <{=N} x {=T}> %self, %other
ret <{=N} x {=T}> %0
@llvm
def fmul(self: Vec[T, N], other: Vec[T, N]) -> Vec[T, N]:
%0 = fmul <{=N} x {=T}> %self, %other
ret <{=N} x {=T}> %0
@llvm
def mul(self: Vec[T, N], other: T) -> Vec[T, N]:
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
%2 = mul <{=N} x {=T}> %self, %1
ret <{=N} x {=T}> %2
@llvm
def fmul(self: Vec[T, N], other: T) -> Vec[T, N]:
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
%2 = fmul <{=N} x {=T}> %self, %1
ret <{=N} x {=T}> %2
def __mul__(self: Vec[T, N], other) -> Vec[T, N]:
if isinstance(T, u8) or isinstance(T, u64) or isinstance(T, u128):
return self.mul(other)
if isinstance(T, f32) or isinstance(T, f64):
return self.fmul(other)
compile_error("invalid SIMD vector multiplication")
@llvm
def __truediv__(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[f64, N]:
%0 = uitofp <{=N} x i64> %self to <{=N} x double>
%1 = uitofp <{=N} x i64> %other to <{=N} x double>
%2 = fdiv <{=N} x double> %0, %1
ret <{=N} x double> %2
@llvm
def __truediv__(self: Vec[u64, 4], other: Vec[u64, 4]) -> Vec[f64, 4]:
%0 = uitofp <4 x i64> %self to <4 x double>
%1 = uitofp <4 x i64> %other to <4 x double>
%2 = fdiv <4 x double> %0, %1
ret <4 x double> %2
@llvm
def __truediv__(self: Vec[u64, 8], other: Vec[u64, 8]) -> Vec[f64, 8]:
%0 = uitofp <8 x i64> %self to <8 x double>
%1 = uitofp <8 x i64> %other to <8 x double>
%2 = fdiv <8 x double> %0, %1
ret <8 x double> %2
@llvm
def __truediv__(self: Vec[u64, 8], other: u64) -> Vec[f64, 8]:
%0 = uitofp <8 x i64> %self to <8 x double>
%1 = uitofp i64 %other to double
%2 = insertelement <8 x double> undef, double %1, i32 0
%3 = shufflevector <8 x double> %0, <8 x double> undef, <8 x i32> zeroinitializer
%4 = fdiv <8 x double> %0, %3
ret <8 x double> %4
@llvm
def add_overflow(self: Vec[u64, N], other: Vec[u64, N]) -> Tuple[Vec[u64, N], Vec[u1, N]]:
declare {<{=N} x i64>, <{=N} x i1>} @llvm.uadd.with.overflow.v{=N}i64(<{=N} x i64>, <{=N} x i64>)
%0 = call {<{=N} x i64>, <{=N} x i1>} @llvm.uadd.with.overflow.v{=N}i64(<{=N} x i64> %self, <{=N} x i64> %other)
ret {<{=N} x i64>, <{=N} x i1>} %0
@llvm
def add_overflow(self: Vec[u64, N], other: u64) -> Tuple[Vec[u64, N], Vec[u1, N]]:
declare {<{=N} x i64>, <{=N} x i1>} @llvm.uadd.with.overflow.v{=N}i64(<{=N} x i64>, <{=N} x i64>)
%0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
%1 = shufflevector <{=N} x i64> %0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
%2 = call {<{=N} x i64>, <{=N} x i1>} @llvm.uadd.with.overflow.v{=N}i64(<{=N} x i64> %self, <{=N} x i64> %1)
ret {<{=N} x i64>, <{=N} x i1>} %2
@llvm
def sub_overflow(self: Vec[u64, N], other: Vec[u64, N]) -> Tuple[Vec[u64, N], Vec[u1, N]]:
declare {<{=N} x i64>, <{=N} x i1>} @llvm.usub.with.overflow.v{=N}i64(<{=N} x i64>, <{=N} x i64>)
%0 = call {<{=N} x i64>, <{=N} x i1>} @llvm.usub.with.overflow.v{=N}i64(<{=N} x i64> %self, <{=N} x i64> %other)
ret {<{=N} x i64>, <{=N} x i1>} %0
def sub_overflow_commutative(self: Vec[u64, N], other: u64) -> Tuple[Vec[u64, N], Vec[u1, N]]:
return Vec[u64, N](other).sub_overflow(self)
@llvm
def sub_overflow(self: Vec[u64, N], other: u64) -> Tuple[Vec[u64, N], Vec[u1, N]]:
declare {<{=N} x i64>, <{=N} x i1>} @llvm.usub.with.overflow.v{=N}i64(<{=N} x i64>, <{=N} x i64>)
%0 = insertelement <{=N} x i64> undef, i64 %other, i32 0
%1 = shufflevector <{=N} x i64> %0, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
%2 = call {<{=N} x i64>, <{=N} x i1>} @llvm.usub.with.overflow.v{=N}i64(<{=N} x i64> %self, <{=N} x i64> %1)
ret {<{=N} x i64>, <{=N} x i1>} %2
@llvm
def zext_mul(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u128, N]:
%0 = zext <{=N} x i64> %self to <{=N} x i128>
%1 = zext <{=N} x i64> %other to <{=N} x i128>
%2 = mul nuw <{=N} x i128> %0, %1
ret <{=N} x i128> %2
@llvm
def zext_mul(self: Vec[u64, N], other: u64) -> Vec[u128, N]:
%0 = zext <{=N} x i64> %self to <{=N} x i128>
%1 = insertelement <{=N} x i64> undef, i64 %other, i32 0
%2 = shufflevector <{=N} x i64> %1, <{=N} x i64> undef, <{=N} x i32> zeroinitializer
%3 = zext <{=N} x i64> %2 to <{=N} x i128>
%4 = mul nuw <{=N} x i128> %0, %3
ret <{=N} x i128> %4
@nocapture
@llvm
def mulx(self: Vec[u64, N], other: Vec[u64, N], hi: Ptr[Vec[u64, N]]) -> Vec[u64, N]:
%0 = zext <{=N} x i64> %self to <{=N} x i128>
%1 = zext <{=N} x i64> %other to <{=N} x i128>
%2 = mul nuw <{=N} x i128> %0, %1
%3 = lshr <{=N} x i128> %2, <i128 64, i128 64, i128 64, i128 64, i128 64, i128 64, i128 64, i128 64>
%4 = trunc <{=N} x i128> %3 to <{=N} x i64>
store <{=N} x i64> %4, <{=N} x i64>* %hi, align 8
%5 = trunc <{=N} x i128> %2 to <{=N} x i64>
ret <{=N} x i64> %5
def mulhi(self: Vec[u64, N], other: Vec[u64, N]) -> Vec[u64, N]:
hi = Ptr[Vec[u64, N]](1)
self.mulx(other, hi)
return hi[0]
@llvm
def sqrt(self: Vec[f64, N]) -> Vec[f64, N]:
declare <{=N} x double> @llvm.sqrt.v{=N}f64(<{=N} x double>)
%0 = call <{=N} x double> @llvm.sqrt.v{=N}f64(<{=N} x double> %self)
ret <{=N} x double> %0
@llvm
def log(self: Vec[f64, N]) -> Vec[f64, N]:
declare <{=N} x double> @llvm.log.v{=N}f64(<{=N} x double>)
%0 = call <{=N} x double> @llvm.log.v{=N}f64(<{=N} x double> %self)
ret <{=N} x double> %0
@llvm
def cos(self: Vec[f64, N]) -> Vec[f64, N]:
declare <{=N} x double> @llvm.cos.v{=N}f64(<{=N} x double>)
%0 = call <{=N} x double> @llvm.cos.v{=N}f64(<{=N} x double> %self)
ret <{=N} x double> %0
@llvm
def fabs(self: Vec[f64, N]) -> Vec[f64, N]:
declare <{=N} x double> @llvm.fabs.v{=N}f64(<{=N} x double>)
%0 = call <{=N} x double> @llvm.fabs.v{=N}f64(<{=N} x double> %self)
ret <{=N} x double> %0
# Conversion intrinsics
@llvm
def zext_double(self: Vec[u64, N]) -> Vec[u128, N]:
%0 = zext <{=N} x i64> %self to <{=N} x i128>
ret <{=N} x i128> %0
@llvm
def trunc_half(self: Vec[u128, N]) -> Vec[u64, N]:
%0 = trunc <{=N} x i128> %self to <{=N} x i64>
ret <{=N} x i64> %0
@llvm
def shift_trunc_half(self: Vec[u128, N]) -> Vec[u64, N]:
%0 = insertelement <{=N} x i128> undef, i128 64, i32 0
%1 = shufflevector <{=N} x i128> %0, <{=N} x i128> undef, <{=N} x i32> zeroinitializer
%2 = lshr <{=N} x i128> %self, %1
%3 = trunc <{=N} x i128> %2 to <{=N} x i64>
ret <{=N} x i64> %3
@llvm
def to_u64(self: Vec[f64, N]) -> Vec[u64, N]:
%0 = fptoui <{=N} x double> %self to <{=N} x i64>
ret <{=N} x i64> %0
@llvm
def to_u64(self: Vec[u1, N]) -> Vec[u64, N]:
%0 = zext <{=N} x i1> %self to <{=N} x i64>
ret <{=N} x i64> %0
@llvm
def to_float(self: Vec[T, N]) -> Vec[f64, N]:
%0 = uitofp <{=N} x {=T}> %self to <{=N} x double>
ret <{=N} x double> %0
# Predication intrinsics
@llvm
def sub_if(self: Vec[T, N], other: Vec[T, N], mask: Vec[u1, N]) -> Vec[T, N]:
%0 = sub <{=N} x {=T}> %self, %other
%1 = select <{=N} x i1> %mask, <{=N} x {=T}> %0, <{=N} x {=T}> %self
ret <{=N} x {=T}> %1
@llvm
def sub_if(self: Vec[T, N], other: T, mask: Vec[u1, N]) -> Vec[T, N]:
%0 = insertelement <{=N} x {=T}> undef, {=T} %other, i32 0
%1 = shufflevector <{=N} x {=T}> %0, <{=N} x {=T}> undef, <{=N} x i32> zeroinitializer
%2 = sub <{=N} x {=T}> %self, %1
%3 = select <{=N} x i1> %mask, <{=N} x {=T}> %2, <{=N} x {=T}> %self
ret <{=N} x {=T}> %3
# Gather-scatter
@llvm
def __getitem__(self: Vec[T, N], idx) -> T:
%0 = extractelement <{=N} x {=T}> %self, i64 %idx
ret {=T} %0
def __repr__(self):
if N == 8:
return f"<{self[0]}, {self[1]}, {self[2]}, {self[3]}, {self[4]}, {self[5]}, {self[6]}, {self[7]}>"
elif N == 16:
return f"<{self[0]}, {self[1]}, {self[2]}, {self[3]}, {self[4]}, {self[5]}, {self[6]}, {self[7]}, {self[8]}, {self[9]}, {self[10]}, {self[11]}, {self[12]}, {self[13]}, {self[14]}, {self[15]}>"
else:
return "?"
# Misc
def copy(self: Vec[T, N]) -> Vec[T, N]:
return self
@llvm
def mask(self: Vec[T, N], mask: Vec[u1, N], other: Vec[T, N]) -> Vec[T, N]:
%0 = select <{=N} x i1> %mask, <{=N} x {=T}> %self, <{=N} x {=T}> %other
ret <{=N} x {=T}> %0
def scatter(self: Vec[T, N]) -> List[T]:
return [self[i] for i in staticrange(N)]
@ -314,3 +871,15 @@ class Vec[T, N: Static[int]]:
u8x16 = Vec[u8, 16]
u8x32 = Vec[u8, 32]
f32x8 = Vec[f32, 8]
@llvm
def bitcast_scatter[N: Static[int]](ptr_in: Ptr[Vec[u64, N]]) -> Ptr[u64]:
%0 = bitcast <{=N} x i64>* %ptr_in to i64*
ret i64* %0
@llvm
def bitcast_vectorize[N: Static[int]](ptr_in: Ptr[u64]) -> Ptr[Vec[u64, N]]:
%0 = bitcast i64* %ptr_in to <{=N} x i64>*
ret <{=N} x i64>* %0

View File

@ -102,6 +102,59 @@ class str:
ptr: Ptr[byte]
len: int
@tuple
@__internal__
class Tuple:
@__internal__
def __new__() -> Tuple:
pass
def __add__(self, obj):
return __magic__.add(self, obj)
def __mul__(self, n: Static[int]):
return __magic__.mul(self, n)
def __contains__(self, obj) -> bool:
return __magic__.contains(self, obj)
def __getitem__(self, idx: int):
return __magic__.getitem(self, idx)
def __iter__(self):
yield from __magic__.iter(self)
def __hash__(self) -> int:
return __magic__.hash(self)
def __repr__(self) -> str:
return __magic__.repr(self)
def __len__(self) -> int:
return __magic__.len(self)
def __eq__(self, obj: Tuple) -> bool:
return __magic__.eq(self, obj)
def __ne__(self, obj: Tuple) -> bool:
return __magic__.ne(self, obj)
def __gt__(self, obj: Tuple) -> bool:
return __magic__.gt(self, obj)
def __ge__(self, obj: Tuple) -> bool:
return __magic__.ge(self, obj)
def __lt__(self, obj: Tuple) -> bool:
return __magic__.lt(self, obj)
def __le__(self, obj: Tuple) -> bool:
return __magic__.le(self, obj)
def __pickle__(self, dest: Ptr[byte]):
return __magic__.pickle(self, dest)
def __unpickle__(src: Ptr[byte]) -> Tuple:
return __magic__.unpickle(src)
def __to_py__(self) -> Ptr[byte]:
return __magic__.to_py(self)
def __from_py__(src: Ptr[byte]) -> Tuple:
return __magic__.from_py(src)
def __to_gpu__(self, cache) -> Tuple:
return __magic__.to_gpu(self, cache)
def __from_gpu__(self, other: Tuple):
return __magic__.from_gpu(self, other)
def __from_gpu_new__(other: Tuple) -> Tuple:
return __magic__.from_gpu_new(other)
def __tuplesize__(self) -> int:
return __magic__.tuplesize(self)
@tuple
@__internal__
class Array:
@ -134,8 +187,6 @@ class TypeVar[T]: pass
class ByVal: pass
@__internal__
class ByRef: pass
@__internal__
class Tuple: pass
@__internal__
class ClassVar[T]:
@ -147,10 +198,7 @@ class RTTI:
@__internal__
@tuple
class ellipsis:
def __new__() -> ellipsis:
return ()
Ellipsis = ellipsis()
class ellipsis: pass
@tuple
@__internal__

View File

@ -6,14 +6,18 @@ from internal.gc import (
)
from internal.static import vars_types, tuple_type, vars as _vars
def vars(obj, with_index: Static[int] = 0):
return _vars(obj, with_index)
__vtables__ = Ptr[Ptr[cobj]]()
__vtable_size__ = 0
@extend
class ellipsis:
def __new__() -> ellipsis:
return ()
Ellipsis = ellipsis()
@extend
class __internal__:
@ -441,7 +445,6 @@ class __internal__:
else:
return default
@extend
class __magic__:
# always present
@ -635,7 +638,6 @@ class __magic__:
def str(slf) -> str:
return slf.__repr__()
@dataclass(init=True)
@tuple
class Import:
@ -670,19 +672,16 @@ class Function:
def __call__(self, *args) -> TR:
return Function.__call_internal__(self, args)
@tuple
class PyObject:
refcnt: int
pytype: Ptr[byte]
@tuple
class PyWrapper[T]:
head: PyObject
data: T
@extend
class RTTI:
def __new__() -> RTTI:

View File

@ -37,7 +37,7 @@ class float:
@pure
@llvm
def __bool__(self) -> bool:
%0 = fcmp one double %self, 0.000000e+00
%0 = fcmp une double %self, 0.000000e+00
%1 = zext i1 %0 to i8
ret i8 %1
@ -118,10 +118,9 @@ class float:
@pure
@llvm
def __ne__(a: float, b: float) -> bool:
entry:
%tmp = fcmp one double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
%tmp = fcmp une double %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
@ -443,7 +442,7 @@ class float32:
@pure
@llvm
def __bool__(self) -> bool:
%0 = fcmp one float %self, 0.000000e+00
%0 = fcmp une float %self, 0.000000e+00
%1 = zext i1 %0 to i8
ret i8 %1
@ -521,10 +520,9 @@ class float32:
@pure
@llvm
def __ne__(a: float32, b: float32) -> bool:
entry:
%tmp = fcmp one float %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
%tmp = fcmp une float %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
@ -760,3 +758,4 @@ class float:
return float32.__new__(double)
f32 = float32
f64 = float

View File

@ -507,11 +507,16 @@ class UInt:
def len() -> int:
return N
i1 = Int[1]
i8 = Int[8]
i16 = Int[16]
i32 = Int[32]
i64 = Int[64]
i128 = Int[128]
u1 = UInt[1]
u8 = UInt[8]
u16 = UInt[16]
u32 = UInt[32]
u64 = UInt[64]
u128 = UInt[128]

View File

@ -538,7 +538,7 @@ def _task_loop_outline_template(gtid_ptr: Ptr[i32], btid_ptr: Ptr[i32], args):
task = Ptr[TaskWithPrivates[P]](data)[0]
priv = task.data
gtid64 = int(gtid)
if staticlen(S()) != 0:
if staticlen(S) != 0:
shared = Ptr[S](task.task.shareds)[0]
_task_loop_body_stub(gtid64, priv, shared)
else:

View File

@ -67,6 +67,16 @@ print f.a #: 20
f[1] = -8
print f.a #: 12
def foo():
print('foo')
return 0
v = [0]
v[foo()] += 1
#: foo
print(v)
#: [1]
#%% assign_err_1,barebones
a, *b, c, *d = 1,2,3,4,5 #! multiple starred expressions in assignment
@ -163,12 +173,12 @@ assert True, "blah"
try:
assert False
except AssertionError as e:
print e.message[:15], e.message[-24:] #: Assert failed ( simplify_stmt.codon:164)
print e.message[:15], e.message[-24:] #: Assert failed ( simplify_stmt.codon:174)
try:
assert False, f"hehe {1}"
except AssertionError as e:
print e.message[:23], e.message[-24:] #: Assert failed: hehe 1 ( simplify_stmt.codon:169)
print e.message[:23], e.message[-24:] #: Assert failed: hehe 1 ( simplify_stmt.codon:179)
#%% print,barebones
print 1,

View File

@ -773,7 +773,7 @@ z = tuple(c)
print z, z.__class__.__name__ #: (3, 'heh') Tuple[int,str]
#%% static_unify,barebones
def foo(x: Callable[[1,2], 3]): pass #! '1' does not match expected type 'T1'
def foo(x: Callable[[1,2], 3]): pass #! '2' does not match expected type 'T1'
#%% static_unify_2,barebones
def foo(x: List[1]): pass #! '1' does not match expected type 'T'

View File

@ -1971,3 +1971,30 @@ foo(5) #! generic 'b' not provided
def f(a, b, T: type):
print(a, b)
f(1, 2) #! generic 'T' not provided
#%% variardic_tuples,barebones
class Foo[N: Static[int]]:
x: Tuple[N, str]
def __init__(self):
self.x = ("hi", ) * N
f = Foo[5]()
print(f.__class__.__name__)
#: Foo[5]
print(f.x.__class__.__name__)
#: Tuple[str,str,str,str,str]
print(f.x)
#: ('hi', 'hi', 'hi', 'hi', 'hi')
print(Tuple[int, int].__class__.__name__)
#: Tuple[int,int]
print(Tuple[3, int].__class__.__name__)
#: Tuple[int,int,int]
print(Tuple[0].__class__.__name__)
#: Tuple
print(Tuple[-5, int].__class__.__name__)
#: Tuple
print(Tuple[5, int, str].__class__.__name__)
#: Tuple[int,str,int,str,int,str,int,str,int,str]