1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00

Fix POD type unification with tuples

This commit is contained in:
Ibrahim Numanagić 2023-02-22 12:24:06 -08:00
parent 259aec30be
commit f3f3e7ee91
6 changed files with 29 additions and 15 deletions

View File

@ -564,6 +564,8 @@ void ClassStmt::parseDecorators() {
for (auto &d : decorators) { for (auto &d : decorators) {
if (d->isId("deduce")) { if (d->isId("deduce")) {
attributes.customAttr.insert("deduce"); attributes.customAttr.insert("deduce");
} else if (d->isId("__notuple__")) {
attributes.customAttr.insert("__notuple__");
} else if (auto c = d->getCall()) { } else if (auto c = d->getCall()) {
if (c->expr->isId(Attr::Tuple)) { if (c->expr->isId(Attr::Tuple)) {
attributes.set(Attr::Tuple); attributes.set(Attr::Tuple);

View File

@ -130,13 +130,13 @@ std::string ClassType::realizedTypeName() const {
RecordType::RecordType(Cache *cache, std::string name, std::string niceName, RecordType::RecordType(Cache *cache, std::string name, std::string niceName,
std::vector<Generic> generics, std::vector<TypePtr> args, std::vector<Generic> generics, std::vector<TypePtr> args,
bool isInternal) bool noTuple)
: ClassType(cache, std::move(name), std::move(niceName), std::move(generics)), : ClassType(cache, std::move(name), std::move(niceName), std::move(generics)),
args(std::move(args)), isInternal(false) {} args(std::move(args)), noTuple(false) {}
RecordType::RecordType(const ClassTypePtr &base, std::vector<TypePtr> args, RecordType::RecordType(const ClassTypePtr &base, std::vector<TypePtr> args,
bool isInternal) bool noTuple)
: ClassType(base), args(std::move(args)), isInternal(isInternal) {} : ClassType(base), args(std::move(args)), noTuple(noTuple) {}
int RecordType::unify(Type *typ, Unification *us) { int RecordType::unify(Type *typ, Unification *us) {
if (auto tr = typ->getRecord()) { if (auto tr = typ->getRecord()) {
@ -159,8 +159,7 @@ int RecordType::unify(Type *typ, Unification *us) {
} }
// Handle Tuple<->@tuple: when unifying tuples, only record members matter. // Handle Tuple<->@tuple: when unifying tuples, only record members matter.
if (startswith(name, TYPE_TUPLE) || startswith(tr->name, TYPE_TUPLE)) { if (startswith(name, TYPE_TUPLE) || startswith(tr->name, TYPE_TUPLE)) {
if (!args.empty() || if (!args.empty() || (!noTuple && !tr->noTuple)) // prevent POD<->() unification
(!isInternal && !tr->isInternal)) // prevent int<->() unification
return s1 + int(name == tr->name); return s1 + int(name == tr->name);
else else
return -1; return -1;
@ -178,7 +177,7 @@ TypePtr RecordType::generalize(int atLevel) {
auto a = args; auto a = args;
for (auto &t : a) for (auto &t : a)
t = t->generalize(atLevel); t = t->generalize(atLevel);
return std::make_shared<RecordType>(c, a, isInternal); return std::make_shared<RecordType>(c, a, noTuple);
} }
TypePtr RecordType::instantiate(int atLevel, int *unboundCount, TypePtr RecordType::instantiate(int atLevel, int *unboundCount,
@ -188,7 +187,7 @@ TypePtr RecordType::instantiate(int atLevel, int *unboundCount,
auto a = args; auto a = args;
for (auto &t : a) for (auto &t : a)
t = t->instantiate(atLevel, unboundCount, cache); t = t->instantiate(atLevel, unboundCount, cache);
return std::make_shared<RecordType>(c, a, isInternal); return std::make_shared<RecordType>(c, a, noTuple);
} }
std::vector<TypePtr> RecordType::getUnbounds() const { std::vector<TypePtr> RecordType::getUnbounds() const {

View File

@ -77,14 +77,14 @@ using ClassTypePtr = std::shared_ptr<ClassType>;
struct RecordType : public ClassType { struct RecordType : public ClassType {
/// List of tuple arguments. /// List of tuple arguments.
std::vector<TypePtr> args; std::vector<TypePtr> args;
bool isInternal; bool noTuple;
explicit RecordType( explicit RecordType(
Cache *cache, std::string name, std::string niceName, Cache *cache, std::string name, std::string niceName,
std::vector<ClassType::Generic> generics = std::vector<ClassType::Generic>(), std::vector<ClassType::Generic> generics = std::vector<ClassType::Generic>(),
std::vector<TypePtr> args = std::vector<TypePtr>(), bool isInternal = false); std::vector<TypePtr> args = std::vector<TypePtr>(), bool noTuple = false);
RecordType(const ClassTypePtr &base, std::vector<TypePtr> args, RecordType(const ClassTypePtr &base, std::vector<TypePtr> args,
bool isInternal = false); bool noTuple = false);
public: public:
int unify(Type *typ, Unification *undo) override; int unify(Type *typ, Unification *undo) override;

View File

@ -26,8 +26,8 @@ void TypecheckVisitor::visit(ClassStmt *stmt) {
auto typ = Type::makeType(ctx->cache, stmt->name, ctx->cache->rev(stmt->name), auto typ = Type::makeType(ctx->cache, stmt->name, ctx->cache->rev(stmt->name),
stmt->isRecord()) stmt->isRecord())
->getClass(); ->getClass();
if (stmt->isRecord() && stmt->hasAttr(Attr::Internal)) if (stmt->isRecord() && stmt->hasAttr("__notuple__"))
typ->getRecord()->isInternal = true; typ->getRecord()->noTuple = true;
if (stmt->isRecord() && startswith(stmt->name, TYPE_PARTIAL)) { if (stmt->isRecord() && startswith(stmt->name, TYPE_PARTIAL)) {
// Special handling of partial types (e.g., `Partial.0001.foo`) // Special handling of partial types (e.g., `Partial.0001.foo`)
if (auto p = in(ctx->cache->partials, stmt->name)) if (auto p = in(ctx->cache->partials, stmt->name))

View File

@ -6,28 +6,33 @@ class __internal__:
@tuple @tuple
@__internal__ @__internal__
@__notuple__
class bool: class bool:
pass pass
@tuple @tuple
@__internal__ @__internal__
@__notuple__
class byte: class byte:
pass pass
@tuple @tuple
@__internal__ @__internal__
@__notuple__
class int: class int:
MAX = 9223372036854775807 MAX = 9223372036854775807
pass pass
@tuple @tuple
@__internal__ @__internal__
@__notuple__
class float: class float:
MIN_10_EXP = -307 MIN_10_EXP = -307
pass pass
@tuple @tuple
@__internal__ @__internal__
@__notuple__
class float32: class float32:
MIN_10_EXP = -37 MIN_10_EXP = -37
pass pass
@ -44,6 +49,7 @@ class type:
@tuple @tuple
@__internal__ @__internal__
@__notuple__
class Function[T, TR]: class Function[T, TR]:
pass pass
@ -54,27 +60,32 @@ class Callable[T, TR]:
@tuple @tuple
@__internal__ @__internal__
@__notuple__
class Ptr[T]: class Ptr[T]:
pass pass
cobj = Ptr[byte] cobj = Ptr[byte]
@tuple @tuple
@__internal__ @__internal__
@__notuple__
class Generator[T]: class Generator[T]:
pass pass
@tuple @tuple
@__internal__ @__internal__
@__notuple__
class Optional: class Optional:
T: type = NoneType T: type = NoneType
@tuple @tuple
@__internal__ @__internal__
@__notuple__
class Int[N: Static[int]]: class Int[N: Static[int]]:
pass pass
@tuple @tuple
@__internal__ @__internal__
@__notuple__
class UInt[N: Static[int]]: class UInt[N: Static[int]]:
pass pass
@ -105,8 +116,9 @@ function = Function
class Ref[T]: class Ref[T]:
pass pass
@__internal__
@tuple @tuple
@__internal__
@__notuple__
class Union[TU]: class Union[TU]:
# compiler-generated # compiler-generated
def __new__(val): def __new__(val):

View File

@ -1019,7 +1019,8 @@ print [a for a in ()] #: []
def foo(*args): def foo(*args):
return [a for a in args] return [a for a in args]
args, result = ((), [()]) args, result = ((), [()])
print list(foo(*args)) == result #: False print list(foo(*args)) #: []
print result #: [()]
#%% type_error_reporting #%% type_error_reporting