GPU compilation fixes (#496)

* Fix __from_gpu_new__

* Fix GPU tests

* Update GPU debug codegen

* Add will-return attribute for GPU compilation

* Fix isinstance on unresolved types

* Fix union type instantiation and pendingRealizations placement

* Add float16, bfloat16 and float128 IR types

* Add float16, bfloat16 and float128 types

* Mark complex64 as no-python

* Fix float methods

* Add float tests

* Disable some float tests

* Fix bitset in reaching definitions analysis

* Fix static bool unification

---------

Co-authored-by: Ibrahim Numanagić <ibrahimpasa@gmail.com>
pull/499/head
A. R. Shajii 2023-11-18 15:14:05 -05:00 committed by GitHub
parent 4eb641e3cb
commit 2c7440768d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1137 additions and 25 deletions

View File

@ -71,10 +71,10 @@ struct BitSet {
return res;
}
void set(unsigned bit) { words.data()[bit / B] |= (1 << (bit % B)); }
void set(unsigned bit) { words.data()[bit / B] |= (1UL << (bit % B)); }
bool get(unsigned bit) const {
return (words.data()[bit / B] & (1 << (bit % B))) != 0;
return (words.data()[bit / B] & (1UL << (bit % B))) != 0;
}
bool equals(const BitSet &other, unsigned size) {

View File

@ -500,6 +500,14 @@ void moduleToPTX(llvm::Module *M, const std::string &filename,
linkLibdevice(M, libdevice);
remapFunctions(M);
// Strip debug info and remove noinline from functions (added in debug mode).
// Also, tell LLVM that all functions will return.
for (auto &F : *M) {
F.removeFnAttr(llvm::Attribute::AttrKind::NoInline);
F.setWillReturn();
}
llvm::StripDebugInfo(*M);
// Run NVPTX passes and general opt pipeline.
{
llvm::LoopAnalysisManager lam;

View File

@ -2086,6 +2086,18 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) {
return B->getFloatTy();
}
if (auto *x = cast<types::Float16Type>(t)) {
return B->getHalfTy();
}
if (auto *x = cast<types::BFloat16Type>(t)) {
return B->getBFloatTy();
}
if (auto *x = cast<types::Float128Type>(t)) {
return llvm::Type::getFP128Ty(*context);
}
if (auto *x = cast<types::BoolType>(t)) {
return B->getInt8Ty();
}
@ -2203,6 +2215,22 @@ llvm::DIType *LLVMVisitor::getDITypeHelper(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}
if (auto *x = cast<types::Float16Type>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}
if (auto *x = cast<types::BFloat16Type>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
}
if (auto *x = cast<types::Float128Type>(t)) {
return db.builder->createBasicType(x->getName(),
layout.getTypeAllocSizeInBits(type),
llvm::dwarf::DW_ATE_HP_float128);
}
if (auto *x = cast<types::BoolType>(t)) {
return db.builder->createBasicType(
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_boolean);

View File

@ -62,6 +62,9 @@ const std::string Module::BYTE_NAME = "byte";
const std::string Module::INT_NAME = "int";
const std::string Module::FLOAT_NAME = "float";
const std::string Module::FLOAT32_NAME = "float32";
const std::string Module::FLOAT16_NAME = "float16";
const std::string Module::BFLOAT16_NAME = "bfloat16";
const std::string Module::FLOAT128_NAME = "float128";
const std::string Module::STRING_NAME = "str";
const std::string Module::EQ_MAGIC_NAME = "__eq__";
@ -239,6 +242,24 @@ types::Type *Module::getFloat32Type() {
return Nr<types::Float32Type>();
}
types::Type *Module::getFloat16Type() {
if (auto *rVal = getType(FLOAT16_NAME))
return rVal;
return Nr<types::Float16Type>();
}
types::Type *Module::getBFloat16Type() {
if (auto *rVal = getType(BFLOAT16_NAME))
return rVal;
return Nr<types::BFloat16Type>();
}
types::Type *Module::getFloat128Type() {
if (auto *rVal = getType(FLOAT128_NAME))
return rVal;
return Nr<types::Float128Type>();
}
types::Type *Module::getStringType() {
if (auto *rVal = getType(STRING_NAME))
return rVal;

View File

@ -34,6 +34,9 @@ public:
static const std::string INT_NAME;
static const std::string FLOAT_NAME;
static const std::string FLOAT32_NAME;
static const std::string FLOAT16_NAME;
static const std::string BFLOAT16_NAME;
static const std::string FLOAT128_NAME;
static const std::string STRING_NAME;
static const std::string EQ_MAGIC_NAME;
@ -338,6 +341,12 @@ public:
types::Type *getFloatType();
/// @return the float32 type
types::Type *getFloat32Type();
/// @return the float16 type
types::Type *getFloat16Type();
/// @return the bfloat16 type
types::Type *getBFloat16Type();
/// @return the float128 type
types::Type *getFloat128Type();
/// @return the string type
types::Type *getStringType();
/// Gets a pointer type.

View File

@ -69,6 +69,12 @@ const char FloatType::NodeId = 0;
const char Float32Type::NodeId = 0;
const char Float16Type::NodeId = 0;
const char BFloat16Type::NodeId = 0;
const char Float128Type::NodeId = 0;
const char BoolType::NodeId = 0;
const char ByteType::NodeId = 0;

View File

@ -169,6 +169,33 @@ public:
Float32Type() : AcceptorExtend("float32") {}
};
/// Float16 type (16-bit float)
class Float16Type : public AcceptorExtend<Float16Type, PrimitiveType> {
public:
static const char NodeId;
/// Constructs a float16 type.
Float16Type() : AcceptorExtend("float16") {}
};
/// BFloat16 type (16-bit brain float)
class BFloat16Type : public AcceptorExtend<BFloat16Type, PrimitiveType> {
public:
static const char NodeId;
/// Constructs a bfloat16 type.
BFloat16Type() : AcceptorExtend("bfloat16") {}
};
/// Float128 type (128-bit float)
class Float128Type : public AcceptorExtend<Float128Type, PrimitiveType> {
public:
static const char NodeId;
/// Constructs a float128 type.
Float128Type() : AcceptorExtend("float128") {}
};
/// Bool type (8-bit unsigned integer; either 0 or 1)
class BoolType : public AcceptorExtend<BoolType, PrimitiveType> {
public:

View File

@ -295,6 +295,15 @@ public:
void visit(const types::Float32Type *v) override {
fmt::print(os, FMT_STRING("(float32 '\"{}\")"), v->referenceString());
}
void visit(const types::Float16Type *v) override {
fmt::print(os, FMT_STRING("(float16 '\"{}\")"), v->referenceString());
}
void visit(const types::BFloat16Type *v) override {
fmt::print(os, FMT_STRING("(bfloat16 '\"{}\")"), v->referenceString());
}
void visit(const types::Float128Type *v) override {
fmt::print(os, FMT_STRING("(float128 '\"{}\")"), v->referenceString());
}
void visit(const types::BoolType *v) override {
fmt::print(os, FMT_STRING("(bool '\"{}\")"), v->referenceString());
}

View File

@ -54,6 +54,9 @@ void Visitor::visit(types::PrimitiveType *x) { defaultVisit(x); }
void Visitor::visit(types::IntType *x) { defaultVisit(x); }
void Visitor::visit(types::FloatType *x) { defaultVisit(x); }
void Visitor::visit(types::Float32Type *x) { defaultVisit(x); }
void Visitor::visit(types::Float16Type *x) { defaultVisit(x); }
void Visitor::visit(types::BFloat16Type *x) { defaultVisit(x); }
void Visitor::visit(types::Float128Type *x) { defaultVisit(x); }
void Visitor::visit(types::BoolType *x) { defaultVisit(x); }
void Visitor::visit(types::ByteType *x) { defaultVisit(x); }
void Visitor::visit(types::VoidType *x) { defaultVisit(x); }
@ -114,6 +117,9 @@ void ConstVisitor::visit(const types::PrimitiveType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::IntType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::FloatType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float32Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float16Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::BFloat16Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::Float128Type *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::BoolType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::ByteType *x) { defaultVisit(x); }
void ConstVisitor::visit(const types::VoidType *x) { defaultVisit(x); }

View File

@ -19,6 +19,9 @@ class PrimitiveType;
class IntType;
class FloatType;
class Float32Type;
class Float16Type;
class BFloat16Type;
class Float128Type;
class BoolType;
class ByteType;
class VoidType;
@ -152,6 +155,9 @@ public:
VISIT(types::IntType);
VISIT(types::FloatType);
VISIT(types::Float32Type);
VISIT(types::Float16Type);
VISIT(types::BFloat16Type);
VISIT(types::Float128Type);
VISIT(types::BoolType);
VISIT(types::ByteType);
VISIT(types::VoidType);
@ -229,6 +235,9 @@ public:
CONST_VISIT(types::IntType);
CONST_VISIT(types::FloatType);
CONST_VISIT(types::Float32Type);
CONST_VISIT(types::Float16Type);
CONST_VISIT(types::BFloat16Type);
CONST_VISIT(types::Float128Type);
CONST_VISIT(types::BoolType);
CONST_VISIT(types::ByteType);
CONST_VISIT(types::VoidType);

View File

@ -709,6 +709,7 @@ ExprPtr TypecheckVisitor::transformArray(CallExpr *expr) {
/// `isinstance(obj, ByRef)` is True if `type(obj)` is a reference type
ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
expr->setType(unify(expr->type, ctx->getType("bool")));
expr->staticValue.type = StaticValue::INT; // prevent branching until this is resolved
transform(expr->args[0].value);
auto typ = expr->args[0].value->type->getClass();
if (!typ || !typ->canRealize())
@ -947,10 +948,11 @@ ExprPtr TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) {
auto &args = expr->args[0].value->getCall()->args;
for (size_t i = 0; i < args.size(); i++) {
realize(args[i].value->type);
fmt::print(stderr, "[static_print] {}: {} := {}{}\n", getSrcInfo(),
fmt::print(stderr, "[static_print] {}: {} := {}{} (iter: {})\n", getSrcInfo(),
FormatVisitor::apply(args[i].value),
args[i].value->type ? args[i].value->type->debugString(1) : "-",
args[i].value->isStatic() ? " [static]" : "");
args[i].value->isStatic() ? " [static]" : "",
ctx->getRealizationBase()->iteration);
}
return nullptr;
}

View File

@ -100,13 +100,13 @@ types::TypePtr TypeContext::instantiate(const SrcInfo &srcInfo,
if (auto l = i.second->getLink()) {
i.second->setSrcInfo(srcInfo);
if (l->defaultType) {
pendingDefaults.insert(i.second);
getRealizationBase()->pendingDefaults.insert(i.second);
}
}
}
if (t->getUnion() && !t->getUnion()->isSealed()) {
t->setSrcInfo(srcInfo);
pendingDefaults.insert(t);
getRealizationBase()->pendingDefaults.insert(t);
}
if (auto r = t->getRecord())
if (r->repeats && r->repeats->canRealize())

View File

@ -50,12 +50,12 @@ struct TypeContext : public Context<TypecheckItem> {
types::TypePtr returnType = nullptr;
/// Typechecking iteration
int iteration = 0;
std::set<types::TypePtr> pendingDefaults;
};
std::vector<RealizationBase> realizationBases;
/// The current type-checking level (for type instantiation and generalization).
int typecheckLevel;
std::set<types::TypePtr> pendingDefaults;
int changedNodes;
/// The age of the currently parsed statement.

View File

@ -99,8 +99,9 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
bool anotherRound = false;
// Special case: return type might have default as well (e.g., Union)
if (ctx->getRealizationBase()->returnType)
ctx->pendingDefaults.insert(ctx->getRealizationBase()->returnType);
for (auto &unbound : ctx->pendingDefaults) {
ctx->getRealizationBase()->pendingDefaults.insert(
ctx->getRealizationBase()->returnType);
for (auto &unbound : ctx->getRealizationBase()->pendingDefaults) {
if (auto tu = unbound->getUnion()) {
// Seal all dynamic unions after the iteration is over
if (!tu->isSealed()) {
@ -113,7 +114,7 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
anotherRound = true;
}
}
ctx->pendingDefaults.clear();
ctx->getRealizationBase()->pendingDefaults.clear();
if (anotherRound)
continue;
@ -653,6 +654,12 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
handle = module->getFloatType();
} else if (t->name == "float32") {
handle = module->getFloat32Type();
} else if (t->name == "float16") {
handle = module->getFloat16Type();
} else if (t->name == "bfloat16") {
handle = module->getBFloat16Type();
} else if (t->name == "float128") {
handle = module->getFloat128Type();
} else if (t->name == "str") {
handle = module->getStringType();
} else if (t->name == "Int" || t->name == "UInt") {
@ -936,7 +943,9 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
N<ThrowStmt>(N<CallExpr>(N<IdExpr>("std.internal.types.error.TypeError"),
N<StringExpr>("invalid union call"))));
// suite->stmts.push_back(N<ReturnStmt>(N<NoneExpr>()));
unify(type->getRetType(), ctx->instantiate(ctx->getType("Union")));
auto ret = ctx->instantiate(ctx->getType("Union"));
unify(type->getRetType(), ret);
ast->suite = suite;
} else if (startswith(ast->name, "__internal__.get_union_first:0")) {
// def __internal__.get_union_first(union: Union[T0]):

View File

@ -49,6 +49,7 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) {
auto typ = expr->type;
if (!expr->done) {
bool isIntStatic = expr->staticValue.type == StaticValue::INT;
TypecheckVisitor v(ctx, prependStmts);
v.setSrcInfo(expr->getSrcInfo());
ctx->pushSrcInfo(expr->getSrcInfo());
@ -60,7 +61,8 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) {
expr = v.resultExpr;
}
seqassert(expr->type, "type not set for {}", expr);
unify(typ, expr->type);
if (!(isIntStatic && expr->type->is("bool")))
unify(typ, expr->type);
if (expr->done) {
ctx->changedNodes++;
}

View File

@ -45,6 +45,27 @@ class float32:
MIN_10_EXP = -37
pass
@tuple
@__internal__
@__notuple__
class float16:
MIN_10_EXP = -4
pass
@tuple
@__internal__
@__notuple__
class bfloat16:
MIN_10_EXP = -37
pass
@tuple
@__internal__
@__notuple__
class float128:
MIN_10_EXP = -4931
pass
@tuple
@__internal__
class type:

View File

@ -618,7 +618,7 @@ class __magic__:
# @dataclass parameter: gpu=True
def from_gpu_new(other: T, T: type) -> T:
__internal__.class_from_gpu_new(other)
return __internal__.class_from_gpu_new(other)
# @dataclass parameter: repr=True
def repr(slf) -> str:

View File

@ -1,6 +1,6 @@
# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
@tuple
@tuple(python=False)
class complex64:
real: float32
imag: float32

View File

@ -92,7 +92,7 @@ class float:
mod = self % other
div = (self - mod) / other
if mod:
if (other < 0) != (mod < 0):
if (other < 0.0) != (mod < 0.0):
mod += other
div -= 1.0
else:
@ -475,7 +475,7 @@ class float32:
%tmp = fmul float %a, %b
ret float %tmp
def __floordiv__(self, other: float32) -> float:
def __floordiv__(self, other: float32) -> float32:
return self.__truediv__(other).__floor__()
@pure
@ -494,19 +494,19 @@ class float32:
mod = self % other
div = (self - mod) / other
if mod:
if (other < 0) != (mod < 0):
if (other < float32(0.0)) != (mod < float32(0.0)):
mod += other
div -= 1.0
div -= float32(1.0)
else:
mod = (0.0).copysign(other)
mod = float32(0.0).copysign(other)
floordiv = 0.0
floordiv = float32(0.0)
if div:
floordiv = div.__floor__()
if div - floordiv > 0.5:
floordiv += 1.0
if div - floordiv > float32(0.5):
floordiv += float32(1.0)
else:
floordiv = (0.0).copysign(self / other)
floordiv = float32(0.0).copysign(self / other)
return (floordiv, mod)
@ -752,10 +752,913 @@ class float32:
def __match__(self, obj: float32) -> bool:
return self == obj
@extend
class float16:
@pure
@llvm
def __new__(self: float) -> float16:
%0 = fptrunc double %self to half
ret half %0
def __new__(what: float16) -> float16:
return what
def __new__() -> float16:
return float16.__new__(0.0)
def __repr__(self) -> str:
return self.__float__().__repr__()
def __format__(self, format_spec: str) -> str:
return self.__float__().__format(format_spec)
def __copy__(self) -> float16:
return self
def __deepcopy__(self) -> float16:
return self
@pure
@llvm
def __int__(self) -> int:
%0 = fptosi half %self to i64
ret i64 %0
@pure
@llvm
def __float__(self) -> float:
%0 = fpext half %self to double
ret double %0
@pure
@llvm
def __bool__(self) -> bool:
%0 = fcmp une half %self, 0.000000e+00
%1 = zext i1 %0 to i8
ret i8 %1
def __pos__(self) -> float16:
return self
@pure
@llvm
def __neg__(self) -> float16:
%0 = fneg half %self
ret half %0
@pure
@commutative
@llvm
def __add__(a: float16, b: float16) -> float16:
%tmp = fadd half %a, %b
ret half %tmp
@pure
@llvm
def __sub__(a: float16, b: float16) -> float16:
%tmp = fsub half %a, %b
ret half %tmp
@pure
@commutative
@llvm
def __mul__(a: float16, b: float16) -> float16:
%tmp = fmul half %a, %b
ret half %tmp
def __floordiv__(self, other: float16) -> float16:
return self.__truediv__(other).__floor__()
@pure
@llvm
def __truediv__(a: float16, b: float16) -> float16:
%tmp = fdiv half %a, %b
ret half %tmp
@pure
@llvm
def __mod__(a: float16, b: float16) -> float16:
%tmp = frem half %a, %b
ret half %tmp
def __divmod__(self, other: float16) -> Tuple[float16, float16]:
mod = self % other
div = (self - mod) / other
if mod:
if (other < float16(0.0)) != (mod < float16(0.0)):
mod += other
div -= float16(1.0)
else:
mod = float16(0.0).copysign(other)
floordiv = float16(0.0)
if div:
floordiv = div.__floor__()
if div - floordiv > float16(0.5):
floordiv += float16(1.0)
else:
floordiv = float16(0.0).copysign(self / other)
return (floordiv, mod)
@pure
@llvm
def __eq__(a: float16, b: float16) -> bool:
%tmp = fcmp oeq half %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __ne__(a: float16, b: float16) -> bool:
%tmp = fcmp une half %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __lt__(a: float16, b: float16) -> bool:
%tmp = fcmp olt half %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __gt__(a: float16, b: float16) -> bool:
%tmp = fcmp ogt half %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __le__(a: float16, b: float16) -> bool:
%tmp = fcmp ole half %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __ge__(a: float16, b: float16) -> bool:
%tmp = fcmp oge half %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def sqrt(a: float16) -> float16:
declare half @llvm.sqrt.f16(half %a)
%tmp = call half @llvm.sqrt.f16(half %a)
ret half %tmp
@pure
@llvm
def sin(a: float16) -> float16:
declare half @llvm.sin.f16(half %a)
%tmp = call half @llvm.sin.f16(half %a)
ret half %tmp
@pure
@llvm
def cos(a: float16) -> float16:
declare half @llvm.cos.f16(half %a)
%tmp = call half @llvm.cos.f16(half %a)
ret half %tmp
@pure
@llvm
def exp(a: float16) -> float16:
declare half @llvm.exp.f16(half %a)
%tmp = call half @llvm.exp.f16(half %a)
ret half %tmp
@pure
@llvm
def exp2(a: float16) -> float16:
declare half @llvm.exp2.f16(half %a)
%tmp = call half @llvm.exp2.f16(half %a)
ret half %tmp
@pure
@llvm
def log(a: float16) -> float16:
declare half @llvm.log.f16(half %a)
%tmp = call half @llvm.log.f16(half %a)
ret half %tmp
@pure
@llvm
def log10(a: float16) -> float16:
declare half @llvm.log10.f16(half %a)
%tmp = call half @llvm.log10.f16(half %a)
ret half %tmp
@pure
@llvm
def log2(a: float16) -> float16:
declare half @llvm.log2.f16(half %a)
%tmp = call half @llvm.log2.f16(half %a)
ret half %tmp
@pure
@llvm
def __abs__(a: float16) -> float16:
declare half @llvm.fabs.f16(half %a)
%tmp = call half @llvm.fabs.f16(half %a)
ret half %tmp
@pure
@llvm
def __floor__(a: float16) -> float16:
declare half @llvm.floor.f16(half %a)
%tmp = call half @llvm.floor.f16(half %a)
ret half %tmp
@pure
@llvm
def __ceil__(a: float16) -> float16:
declare half @llvm.ceil.f16(half %a)
%tmp = call half @llvm.ceil.f16(half %a)
ret half %tmp
@pure
@llvm
def __trunc__(a: float16) -> float16:
declare half @llvm.trunc.f16(half %a)
%tmp = call half @llvm.trunc.f16(half %a)
ret half %tmp
@pure
@llvm
def rint(a: float16) -> float16:
declare half @llvm.rint.f16(half %a)
%tmp = call half @llvm.rint.f16(half %a)
ret half %tmp
@pure
@llvm
def nearbyint(a: float16) -> float16:
declare half @llvm.nearbyint.f16(half %a)
%tmp = call half @llvm.nearbyint.f16(half %a)
ret half %tmp
@pure
@llvm
def __round__(a: float16) -> float16:
declare half @llvm.round.f16(half %a)
%tmp = call half @llvm.round.f16(half %a)
ret half %tmp
@pure
@llvm
def __pow__(a: float16, b: float16) -> float16:
declare half @llvm.pow.f16(half %a, half %b)
%tmp = call half @llvm.pow.f16(half %a, half %b)
ret half %tmp
@pure
@llvm
def min(a: float16, b: float16) -> float16:
declare half @llvm.minnum.f16(half %a, half %b)
%tmp = call half @llvm.minnum.f16(half %a, half %b)
ret half %tmp
@pure
@llvm
def max(a: float16, b: float16) -> float16:
declare half @llvm.maxnum.f16(half %a, half %b)
%tmp = call half @llvm.maxnum.f16(half %a, half %b)
ret half %tmp
@pure
@llvm
def copysign(a: float16, b: float16) -> float16:
declare half @llvm.copysign.f16(half %a, half %b)
%tmp = call half @llvm.copysign.f16(half %a, half %b)
ret half %tmp
@pure
@llvm
def fma(a: float16, b: float16, c: float16) -> float16:
declare half @llvm.fma.f16(half %a, half %b, half %c)
%tmp = call half @llvm.fma.f16(half %a, half %b, half %c)
ret half %tmp
def __hash__(self) -> int:
return self.__float__().__hash__()
def __match__(self, obj: float16) -> bool:
return self == obj
@extend
class bfloat16:
@pure
@llvm
def __new__(self: float) -> bfloat16:
%0 = fptrunc double %self to bfloat
ret bfloat %0
def __new__(what: bfloat16) -> bfloat16:
return what
def __new__() -> bfloat16:
return bfloat16.__new__(0.0)
def __repr__(self) -> str:
return self.__float__().__repr__()
def __format__(self, format_spec: str) -> str:
return self.__float__().__format(format_spec)
def __copy__(self) -> bfloat16:
return self
def __deepcopy__(self) -> bfloat16:
return self
@pure
@llvm
def __int__(self) -> int:
%0 = fptosi bfloat %self to i64
ret i64 %0
@pure
@llvm
def __float__(self) -> float:
%0 = fpext bfloat %self to double
ret double %0
@pure
@llvm
def __bool__(self) -> bool:
%0 = fcmp une bfloat %self, 0.000000e+00
%1 = zext i1 %0 to i8
ret i8 %1
def __pos__(self) -> bfloat16:
return self
@pure
@llvm
def __neg__(self) -> bfloat16:
%0 = fneg bfloat %self
ret bfloat %0
@pure
@commutative
@llvm
def __add__(a: bfloat16, b: bfloat16) -> bfloat16:
%tmp = fadd bfloat %a, %b
ret bfloat %tmp
@pure
@llvm
def __sub__(a: bfloat16, b: bfloat16) -> bfloat16:
%tmp = fsub bfloat %a, %b
ret bfloat %tmp
@pure
@commutative
@llvm
def __mul__(a: bfloat16, b: bfloat16) -> bfloat16:
%tmp = fmul bfloat %a, %b
ret bfloat %tmp
def __floordiv__(self, other: bfloat16) -> bfloat16:
return self.__truediv__(other).__floor__()
@pure
@llvm
def __truediv__(a: bfloat16, b: bfloat16) -> bfloat16:
%tmp = fdiv bfloat %a, %b
ret bfloat %tmp
@pure
@llvm
def __mod__(a: bfloat16, b: bfloat16) -> bfloat16:
%tmp = frem bfloat %a, %b
ret bfloat %tmp
def __divmod__(self, other: bfloat16) -> Tuple[bfloat16, bfloat16]:
mod = self % other
div = (self - mod) / other
if mod:
if (other < bfloat16(0.0)) != (mod < bfloat16(0.0)):
mod += other
div -= bfloat16(1.0)
else:
mod = bfloat16(0.0).copysign(other)
floordiv = bfloat16(0.0)
if div:
floordiv = div.__floor__()
if div - floordiv > bfloat16(0.5):
floordiv += bfloat16(1.0)
else:
floordiv = bfloat16(0.0).copysign(self / other)
return (floordiv, mod)
@pure
@llvm
def __eq__(a: bfloat16, b: bfloat16) -> bool:
%tmp = fcmp oeq bfloat %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __ne__(a: bfloat16, b: bfloat16) -> bool:
%tmp = fcmp une bfloat %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __lt__(a: bfloat16, b: bfloat16) -> bool:
%tmp = fcmp olt bfloat %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __gt__(a: bfloat16, b: bfloat16) -> bool:
%tmp = fcmp ogt bfloat %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __le__(a: bfloat16, b: bfloat16) -> bool:
%tmp = fcmp ole bfloat %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __ge__(a: bfloat16, b: bfloat16) -> bool:
%tmp = fcmp oge bfloat %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def sqrt(a: bfloat16) -> bfloat16:
declare bfloat @llvm.sqrt.bf16(bfloat %a)
%tmp = call bfloat @llvm.sqrt.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def sin(a: bfloat16) -> bfloat16:
declare bfloat @llvm.sin.bf16(bfloat %a)
%tmp = call bfloat @llvm.sin.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def cos(a: bfloat16) -> bfloat16:
declare bfloat @llvm.cos.bf16(bfloat %a)
%tmp = call bfloat @llvm.cos.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def exp(a: bfloat16) -> bfloat16:
declare bfloat @llvm.exp.bf16(bfloat %a)
%tmp = call bfloat @llvm.exp.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def exp2(a: bfloat16) -> bfloat16:
declare bfloat @llvm.exp2.bf16(bfloat %a)
%tmp = call bfloat @llvm.exp2.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def log(a: bfloat16) -> bfloat16:
declare bfloat @llvm.log.bf16(bfloat %a)
%tmp = call bfloat @llvm.log.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def log10(a: bfloat16) -> bfloat16:
declare bfloat @llvm.log10.bf16(bfloat %a)
%tmp = call bfloat @llvm.log10.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def log2(a: bfloat16) -> bfloat16:
declare bfloat @llvm.log2.bf16(bfloat %a)
%tmp = call bfloat @llvm.log2.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def __abs__(a: bfloat16) -> bfloat16:
declare bfloat @llvm.fabs.bf16(bfloat %a)
%tmp = call bfloat @llvm.fabs.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def __floor__(a: bfloat16) -> bfloat16:
declare bfloat @llvm.floor.bf16(bfloat %a)
%tmp = call bfloat @llvm.floor.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def __ceil__(a: bfloat16) -> bfloat16:
declare bfloat @llvm.ceil.bf16(bfloat %a)
%tmp = call bfloat @llvm.ceil.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def __trunc__(a: bfloat16) -> bfloat16:
declare bfloat @llvm.trunc.bf16(bfloat %a)
%tmp = call bfloat @llvm.trunc.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def rint(a: bfloat16) -> bfloat16:
declare bfloat @llvm.rint.bf16(bfloat %a)
%tmp = call bfloat @llvm.rint.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def nearbyint(a: bfloat16) -> bfloat16:
declare bfloat @llvm.nearbyint.bf16(bfloat %a)
%tmp = call bfloat @llvm.nearbyint.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def __round__(a: bfloat16) -> bfloat16:
declare bfloat @llvm.round.bf16(bfloat %a)
%tmp = call bfloat @llvm.round.bf16(bfloat %a)
ret bfloat %tmp
@pure
@llvm
def __pow__(a: bfloat16, b: bfloat16) -> bfloat16:
declare bfloat @llvm.pow.bf16(bfloat %a, bfloat %b)
%tmp = call bfloat @llvm.pow.bf16(bfloat %a, bfloat %b)
ret bfloat %tmp
@pure
@llvm
def min(a: bfloat16, b: bfloat16) -> bfloat16:
declare bfloat @llvm.minnum.bf16(bfloat %a, bfloat %b)
%tmp = call bfloat @llvm.minnum.bf16(bfloat %a, bfloat %b)
ret bfloat %tmp
@pure
@llvm
def max(a: bfloat16, b: bfloat16) -> bfloat16:
declare bfloat @llvm.maxnum.bf16(bfloat %a, bfloat %b)
%tmp = call bfloat @llvm.maxnum.bf16(bfloat %a, bfloat %b)
ret bfloat %tmp
@pure
@llvm
def copysign(a: bfloat16, b: bfloat16) -> bfloat16:
declare bfloat @llvm.copysign.bf16(bfloat %a, bfloat %b)
%tmp = call bfloat @llvm.copysign.bf16(bfloat %a, bfloat %b)
ret bfloat %tmp
@pure
@llvm
def fma(a: bfloat16, b: bfloat16, c: bfloat16) -> bfloat16:
declare bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %c)
%tmp = call bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %c)
ret bfloat %tmp
def __hash__(self) -> int:
return self.__float__().__hash__()
def __match__(self, obj: bfloat16) -> bool:
return self == obj
@extend
class float128:
@pure
@llvm
def __new__(self: float) -> float128:
%0 = fpext double %self to fp128
ret fp128 %0
def __new__(what: float128) -> float128:
return what
def __new__() -> float128:
return float128.__new__(0.0)
def __repr__(self) -> str:
return self.__float__().__repr__()
def __format__(self, format_spec: str) -> str:
return self.__float__().__format(format_spec)
def __copy__(self) -> float128:
return self
def __deepcopy__(self) -> float128:
return self
@pure
@llvm
def __int__(self) -> int:
%0 = fptosi fp128 %self to i64
ret i64 %0
@pure
@llvm
def __float__(self) -> float:
%0 = fptrunc fp128 %self to double
ret double %0
@pure
@llvm
def __bool__(self) -> bool:
%0 = fcmp une fp128 %self, 0xL00000000000000000000000000000000
%1 = zext i1 %0 to i8
ret i8 %1
def __pos__(self) -> float128:
return self
@pure
@llvm
def __neg__(self) -> float128:
%0 = fneg fp128 %self
ret fp128 %0
@pure
@commutative
@llvm
def __add__(a: float128, b: float128) -> float128:
%tmp = fadd fp128 %a, %b
ret fp128 %tmp
@pure
@llvm
def __sub__(a: float128, b: float128) -> float128:
%tmp = fsub fp128 %a, %b
ret fp128 %tmp
@pure
@commutative
@llvm
def __mul__(a: float128, b: float128) -> float128:
%tmp = fmul fp128 %a, %b
ret fp128 %tmp
def __floordiv__(self, other: float128) -> float128:
return self.__truediv__(other).__floor__()
@pure
@llvm
def __truediv__(a: float128, b: float128) -> float128:
%tmp = fdiv fp128 %a, %b
ret fp128 %tmp
@pure
@llvm
def __mod__(a: float128, b: float128) -> float128:
%tmp = frem fp128 %a, %b
ret fp128 %tmp
def __divmod__(self, other: float128) -> Tuple[float128, float128]:
mod = self % other
div = (self - mod) / other
if mod:
if (other < float128(0.0)) != (mod < float128(0)):
mod += other
div -= float128(1.0)
else:
mod = float128(0.0).copysign(other)
floordiv = float128(0.0)
if div:
floordiv = div.__floor__()
if div - floordiv > float128(0.5):
floordiv += float128(1.0)
else:
floordiv = float128(0.0).copysign(self / other)
return (floordiv, mod)
@pure
@llvm
def __eq__(a: float128, b: float128) -> bool:
%tmp = fcmp oeq fp128 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __ne__(a: float128, b: float128) -> bool:
%tmp = fcmp une fp128 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __lt__(a: float128, b: float128) -> bool:
%tmp = fcmp olt fp128 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __gt__(a: float128, b: float128) -> bool:
%tmp = fcmp ogt fp128 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __le__(a: float128, b: float128) -> bool:
%tmp = fcmp ole fp128 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def __ge__(a: float128, b: float128) -> bool:
%tmp = fcmp oge fp128 %a, %b
%res = zext i1 %tmp to i8
ret i8 %res
@pure
@llvm
def sqrt(a: float128) -> float128:
declare fp128 @llvm.sqrt.f128(fp128 %a)
%tmp = call fp128 @llvm.sqrt.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def sin(a: float128) -> float128:
declare fp128 @llvm.sin.f128(fp128 %a)
%tmp = call fp128 @llvm.sin.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def cos(a: float128) -> float128:
declare fp128 @llvm.cos.f128(fp128 %a)
%tmp = call fp128 @llvm.cos.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def exp(a: float128) -> float128:
declare fp128 @llvm.exp.f128(fp128 %a)
%tmp = call fp128 @llvm.exp.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def exp2(a: float128) -> float128:
declare fp128 @llvm.exp2.f128(fp128 %a)
%tmp = call fp128 @llvm.exp2.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def log(a: float128) -> float128:
declare fp128 @llvm.log.f128(fp128 %a)
%tmp = call fp128 @llvm.log.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def log10(a: float128) -> float128:
declare fp128 @llvm.log10.f128(fp128 %a)
%tmp = call fp128 @llvm.log10.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def log2(a: float128) -> float128:
declare fp128 @llvm.log2.f128(fp128 %a)
%tmp = call fp128 @llvm.log2.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def __abs__(a: float128) -> float128:
declare fp128 @llvm.fabs.f128(fp128 %a)
%tmp = call fp128 @llvm.fabs.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def __floor__(a: float128) -> float128:
declare fp128 @llvm.floor.f128(fp128 %a)
%tmp = call fp128 @llvm.floor.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def __ceil__(a: float128) -> float128:
declare fp128 @llvm.ceil.f128(fp128 %a)
%tmp = call fp128 @llvm.ceil.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def __trunc__(a: float128) -> float128:
declare fp128 @llvm.trunc.f128(fp128 %a)
%tmp = call fp128 @llvm.trunc.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def rint(a: float128) -> float128:
declare fp128 @llvm.rint.f128(fp128 %a)
%tmp = call fp128 @llvm.rint.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def nearbyint(a: float128) -> float128:
declare fp128 @llvm.nearbyint.f128(fp128 %a)
%tmp = call fp128 @llvm.nearbyint.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def __round__(a: float128) -> float128:
declare fp128 @llvm.round.f128(fp128 %a)
%tmp = call fp128 @llvm.round.f128(fp128 %a)
ret fp128 %tmp
@pure
@llvm
def __pow__(a: float128, b: float128) -> float128:
declare fp128 @llvm.pow.f128(fp128 %a, fp128 %b)
%tmp = call fp128 @llvm.pow.f128(fp128 %a, fp128 %b)
ret fp128 %tmp
@pure
@llvm
def min(a: float128, b: float128) -> float128:
declare fp128 @llvm.minnum.f128(fp128 %a, fp128 %b)
%tmp = call fp128 @llvm.minnum.f128(fp128 %a, fp128 %b)
ret fp128 %tmp
@pure
@llvm
def max(a: float128, b: float128) -> float128:
declare fp128 @llvm.maxnum.f128(fp128 %a, fp128 %b)
%tmp = call fp128 @llvm.maxnum.f128(fp128 %a, fp128 %b)
ret fp128 %tmp
@pure
@llvm
def copysign(a: float128, b: float128) -> float128:
declare fp128 @llvm.copysign.f128(fp128 %a, fp128 %b)
%tmp = call fp128 @llvm.copysign.f128(fp128 %a, fp128 %b)
ret fp128 %tmp
@pure
@llvm
def fma(a: float128, b: float128, c: float128) -> float128:
declare fp128 @llvm.fma.f128(fp128 %a, fp128 %b, fp128 %c)
%tmp = call fp128 @llvm.fma.f128(fp128 %a, fp128 %b, fp128 %c)
ret fp128 %tmp
def __hash__(self) -> int:
return self.__float__().__hash__()
def __match__(self, obj: float128) -> bool:
return self == obj
@extend
class float:
def __suffix_f32__(double) -> float32:
return float32.__new__(double)
def __suffix_f16__(double) -> float16:
return float16.__new__(double)
def __suffix_bf16__(double) -> bfloat16:
return bfloat16.__new__(double)
def __suffix_f128__(double) -> float128:
return float128.__new__(double)
f16 = float16
bf16 = bfloat16
f32 = float32
f64 = float
f128 = float128

View File

@ -142,3 +142,45 @@ def test_int_pow():
assert f(T2(0)) ** f(T2(0)) == T2(1)
assert str(f(T2(31)) ** f(T2(31))) == '17069174130723235958610643029059314756044734431'
test_int_pow()
@test
def test_float(F: type):
x = F(5.5)
assert str(x) == '5.5'
assert F(x) == x
assert F() == F(0.0)
assert x.__copy__() == x
assert x.__deepcopy__() == x
assert int(x) == 5
assert float(x) == 5.5
assert bool(x)
assert not bool(F())
assert +x == x
assert -x == F(-5.5)
assert x + x == F(11.0)
assert x - F(1.0) == F(4.5)
assert x * F(3.0) == F(16.5)
assert x / F(2.0) == F(2.75)
if F is not float128: # LLVM ops give wrong results for fp128
assert x // F(2.0) == F(2.0)
assert x % F(0.75) == F(0.25)
assert divmod(x, F(0.75)) == (F(7.0), F(0.25))
assert x == x
assert x != F()
assert x < F(6.5)
assert x > F(4.5)
assert x <= F(6.5)
assert x >= F(4.5)
assert x >= x
assert x <= x
assert abs(x) == x
assert abs(-x) == x
assert x.__match__(x)
assert not x.__match__(F())
assert hash(x) == hash(5.5)
test_float(float)
test_float(float32)
#test_float(float16)
#test_float(bfloat16)
#test_float(float128)

View File

@ -34,9 +34,19 @@ def test_conversions():
def kernel(x, v):
v[0] = x
def empty_tuple(x):
if staticlen(x) == 0:
return ()
else:
T = type(x[0])
return (T(),) + empty_tuple(x[1:])
def check(x):
T = type(x)
v = [T()]
if isinstance(x, Tuple):
e = empty_tuple(x)
else:
e = type(x)()
v = [e]
kernel(x, v, grid=1, block=1)
return v == [x]