mirror of https://github.com/exaloop/codon.git
GPU compilation fixes (#496)
* Fix __from_gpu_new__ * Fix GPU tests * Update GPU debug codegen * Add will-return attribute for GPU compilation * Fix isinstance on unresolved types * Fix union type instantiation and pendingRealizations placement * Add float16, bfloat16 and float128 IR types * Add float16, bfloat16 and float128 types * Mark complex64 as no-python * Fix float methods * Add float tests * Disable some float tests * Fix bitset in reaching definitions analysis * Fix static bool unification --------- Co-authored-by: Ibrahim Numanagić <ibrahimpasa@gmail.com>pull/499/head
parent
4eb641e3cb
commit
2c7440768d
codon
cir
analyze/dataflow
parser/visitors/typecheck
stdlib/internal
test
core
transform
|
@ -71,10 +71,10 @@ struct BitSet {
|
|||
return res;
|
||||
}
|
||||
|
||||
void set(unsigned bit) { words.data()[bit / B] |= (1 << (bit % B)); }
|
||||
void set(unsigned bit) { words.data()[bit / B] |= (1UL << (bit % B)); }
|
||||
|
||||
bool get(unsigned bit) const {
|
||||
return (words.data()[bit / B] & (1 << (bit % B))) != 0;
|
||||
return (words.data()[bit / B] & (1UL << (bit % B))) != 0;
|
||||
}
|
||||
|
||||
bool equals(const BitSet &other, unsigned size) {
|
||||
|
|
|
@ -500,6 +500,14 @@ void moduleToPTX(llvm::Module *M, const std::string &filename,
|
|||
linkLibdevice(M, libdevice);
|
||||
remapFunctions(M);
|
||||
|
||||
// Strip debug info and remove noinline from functions (added in debug mode).
|
||||
// Also, tell LLVM that all functions will return.
|
||||
for (auto &F : *M) {
|
||||
F.removeFnAttr(llvm::Attribute::AttrKind::NoInline);
|
||||
F.setWillReturn();
|
||||
}
|
||||
llvm::StripDebugInfo(*M);
|
||||
|
||||
// Run NVPTX passes and general opt pipeline.
|
||||
{
|
||||
llvm::LoopAnalysisManager lam;
|
||||
|
|
|
@ -2086,6 +2086,18 @@ llvm::Type *LLVMVisitor::getLLVMType(types::Type *t) {
|
|||
return B->getFloatTy();
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::Float16Type>(t)) {
|
||||
return B->getHalfTy();
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::BFloat16Type>(t)) {
|
||||
return B->getBFloatTy();
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::Float128Type>(t)) {
|
||||
return llvm::Type::getFP128Ty(*context);
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::BoolType>(t)) {
|
||||
return B->getInt8Ty();
|
||||
}
|
||||
|
@ -2203,6 +2215,22 @@ llvm::DIType *LLVMVisitor::getDITypeHelper(
|
|||
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::Float16Type>(t)) {
|
||||
return db.builder->createBasicType(
|
||||
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::BFloat16Type>(t)) {
|
||||
return db.builder->createBasicType(
|
||||
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_float);
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::Float128Type>(t)) {
|
||||
return db.builder->createBasicType(x->getName(),
|
||||
layout.getTypeAllocSizeInBits(type),
|
||||
llvm::dwarf::DW_ATE_HP_float128);
|
||||
}
|
||||
|
||||
if (auto *x = cast<types::BoolType>(t)) {
|
||||
return db.builder->createBasicType(
|
||||
x->getName(), layout.getTypeAllocSizeInBits(type), llvm::dwarf::DW_ATE_boolean);
|
||||
|
|
|
@ -62,6 +62,9 @@ const std::string Module::BYTE_NAME = "byte";
|
|||
const std::string Module::INT_NAME = "int";
|
||||
const std::string Module::FLOAT_NAME = "float";
|
||||
const std::string Module::FLOAT32_NAME = "float32";
|
||||
const std::string Module::FLOAT16_NAME = "float16";
|
||||
const std::string Module::BFLOAT16_NAME = "bfloat16";
|
||||
const std::string Module::FLOAT128_NAME = "float128";
|
||||
const std::string Module::STRING_NAME = "str";
|
||||
|
||||
const std::string Module::EQ_MAGIC_NAME = "__eq__";
|
||||
|
@ -239,6 +242,24 @@ types::Type *Module::getFloat32Type() {
|
|||
return Nr<types::Float32Type>();
|
||||
}
|
||||
|
||||
types::Type *Module::getFloat16Type() {
|
||||
if (auto *rVal = getType(FLOAT16_NAME))
|
||||
return rVal;
|
||||
return Nr<types::Float16Type>();
|
||||
}
|
||||
|
||||
types::Type *Module::getBFloat16Type() {
|
||||
if (auto *rVal = getType(BFLOAT16_NAME))
|
||||
return rVal;
|
||||
return Nr<types::BFloat16Type>();
|
||||
}
|
||||
|
||||
types::Type *Module::getFloat128Type() {
|
||||
if (auto *rVal = getType(FLOAT128_NAME))
|
||||
return rVal;
|
||||
return Nr<types::Float128Type>();
|
||||
}
|
||||
|
||||
types::Type *Module::getStringType() {
|
||||
if (auto *rVal = getType(STRING_NAME))
|
||||
return rVal;
|
||||
|
|
|
@ -34,6 +34,9 @@ public:
|
|||
static const std::string INT_NAME;
|
||||
static const std::string FLOAT_NAME;
|
||||
static const std::string FLOAT32_NAME;
|
||||
static const std::string FLOAT16_NAME;
|
||||
static const std::string BFLOAT16_NAME;
|
||||
static const std::string FLOAT128_NAME;
|
||||
static const std::string STRING_NAME;
|
||||
|
||||
static const std::string EQ_MAGIC_NAME;
|
||||
|
@ -338,6 +341,12 @@ public:
|
|||
types::Type *getFloatType();
|
||||
/// @return the float32 type
|
||||
types::Type *getFloat32Type();
|
||||
/// @return the float16 type
|
||||
types::Type *getFloat16Type();
|
||||
/// @return the bfloat16 type
|
||||
types::Type *getBFloat16Type();
|
||||
/// @return the float128 type
|
||||
types::Type *getFloat128Type();
|
||||
/// @return the string type
|
||||
types::Type *getStringType();
|
||||
/// Gets a pointer type.
|
||||
|
|
|
@ -69,6 +69,12 @@ const char FloatType::NodeId = 0;
|
|||
|
||||
const char Float32Type::NodeId = 0;
|
||||
|
||||
const char Float16Type::NodeId = 0;
|
||||
|
||||
const char BFloat16Type::NodeId = 0;
|
||||
|
||||
const char Float128Type::NodeId = 0;
|
||||
|
||||
const char BoolType::NodeId = 0;
|
||||
|
||||
const char ByteType::NodeId = 0;
|
||||
|
|
|
@ -169,6 +169,33 @@ public:
|
|||
Float32Type() : AcceptorExtend("float32") {}
|
||||
};
|
||||
|
||||
/// Float16 type (16-bit float)
|
||||
class Float16Type : public AcceptorExtend<Float16Type, PrimitiveType> {
|
||||
public:
|
||||
static const char NodeId;
|
||||
|
||||
/// Constructs a float16 type.
|
||||
Float16Type() : AcceptorExtend("float16") {}
|
||||
};
|
||||
|
||||
/// BFloat16 type (16-bit brain float)
|
||||
class BFloat16Type : public AcceptorExtend<BFloat16Type, PrimitiveType> {
|
||||
public:
|
||||
static const char NodeId;
|
||||
|
||||
/// Constructs a bfloat16 type.
|
||||
BFloat16Type() : AcceptorExtend("bfloat16") {}
|
||||
};
|
||||
|
||||
/// Float128 type (128-bit float)
|
||||
class Float128Type : public AcceptorExtend<Float128Type, PrimitiveType> {
|
||||
public:
|
||||
static const char NodeId;
|
||||
|
||||
/// Constructs a float128 type.
|
||||
Float128Type() : AcceptorExtend("float128") {}
|
||||
};
|
||||
|
||||
/// Bool type (8-bit unsigned integer; either 0 or 1)
|
||||
class BoolType : public AcceptorExtend<BoolType, PrimitiveType> {
|
||||
public:
|
||||
|
|
|
@ -295,6 +295,15 @@ public:
|
|||
void visit(const types::Float32Type *v) override {
|
||||
fmt::print(os, FMT_STRING("(float32 '\"{}\")"), v->referenceString());
|
||||
}
|
||||
void visit(const types::Float16Type *v) override {
|
||||
fmt::print(os, FMT_STRING("(float16 '\"{}\")"), v->referenceString());
|
||||
}
|
||||
void visit(const types::BFloat16Type *v) override {
|
||||
fmt::print(os, FMT_STRING("(bfloat16 '\"{}\")"), v->referenceString());
|
||||
}
|
||||
void visit(const types::Float128Type *v) override {
|
||||
fmt::print(os, FMT_STRING("(float128 '\"{}\")"), v->referenceString());
|
||||
}
|
||||
void visit(const types::BoolType *v) override {
|
||||
fmt::print(os, FMT_STRING("(bool '\"{}\")"), v->referenceString());
|
||||
}
|
||||
|
|
|
@ -54,6 +54,9 @@ void Visitor::visit(types::PrimitiveType *x) { defaultVisit(x); }
|
|||
void Visitor::visit(types::IntType *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::FloatType *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::Float32Type *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::Float16Type *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::BFloat16Type *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::Float128Type *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::BoolType *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::ByteType *x) { defaultVisit(x); }
|
||||
void Visitor::visit(types::VoidType *x) { defaultVisit(x); }
|
||||
|
@ -114,6 +117,9 @@ void ConstVisitor::visit(const types::PrimitiveType *x) { defaultVisit(x); }
|
|||
void ConstVisitor::visit(const types::IntType *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::FloatType *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::Float32Type *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::Float16Type *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::BFloat16Type *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::Float128Type *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::BoolType *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::ByteType *x) { defaultVisit(x); }
|
||||
void ConstVisitor::visit(const types::VoidType *x) { defaultVisit(x); }
|
||||
|
|
|
@ -19,6 +19,9 @@ class PrimitiveType;
|
|||
class IntType;
|
||||
class FloatType;
|
||||
class Float32Type;
|
||||
class Float16Type;
|
||||
class BFloat16Type;
|
||||
class Float128Type;
|
||||
class BoolType;
|
||||
class ByteType;
|
||||
class VoidType;
|
||||
|
@ -152,6 +155,9 @@ public:
|
|||
VISIT(types::IntType);
|
||||
VISIT(types::FloatType);
|
||||
VISIT(types::Float32Type);
|
||||
VISIT(types::Float16Type);
|
||||
VISIT(types::BFloat16Type);
|
||||
VISIT(types::Float128Type);
|
||||
VISIT(types::BoolType);
|
||||
VISIT(types::ByteType);
|
||||
VISIT(types::VoidType);
|
||||
|
@ -229,6 +235,9 @@ public:
|
|||
CONST_VISIT(types::IntType);
|
||||
CONST_VISIT(types::FloatType);
|
||||
CONST_VISIT(types::Float32Type);
|
||||
CONST_VISIT(types::Float16Type);
|
||||
CONST_VISIT(types::BFloat16Type);
|
||||
CONST_VISIT(types::Float128Type);
|
||||
CONST_VISIT(types::BoolType);
|
||||
CONST_VISIT(types::ByteType);
|
||||
CONST_VISIT(types::VoidType);
|
||||
|
|
|
@ -709,6 +709,7 @@ ExprPtr TypecheckVisitor::transformArray(CallExpr *expr) {
|
|||
/// `isinstance(obj, ByRef)` is True if `type(obj)` is a reference type
|
||||
ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) {
|
||||
expr->setType(unify(expr->type, ctx->getType("bool")));
|
||||
expr->staticValue.type = StaticValue::INT; // prevent branching until this is resolved
|
||||
transform(expr->args[0].value);
|
||||
auto typ = expr->args[0].value->type->getClass();
|
||||
if (!typ || !typ->canRealize())
|
||||
|
@ -947,10 +948,11 @@ ExprPtr TypecheckVisitor::transformStaticPrintFn(CallExpr *expr) {
|
|||
auto &args = expr->args[0].value->getCall()->args;
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
realize(args[i].value->type);
|
||||
fmt::print(stderr, "[static_print] {}: {} := {}{}\n", getSrcInfo(),
|
||||
fmt::print(stderr, "[static_print] {}: {} := {}{} (iter: {})\n", getSrcInfo(),
|
||||
FormatVisitor::apply(args[i].value),
|
||||
args[i].value->type ? args[i].value->type->debugString(1) : "-",
|
||||
args[i].value->isStatic() ? " [static]" : "");
|
||||
args[i].value->isStatic() ? " [static]" : "",
|
||||
ctx->getRealizationBase()->iteration);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -100,13 +100,13 @@ types::TypePtr TypeContext::instantiate(const SrcInfo &srcInfo,
|
|||
if (auto l = i.second->getLink()) {
|
||||
i.second->setSrcInfo(srcInfo);
|
||||
if (l->defaultType) {
|
||||
pendingDefaults.insert(i.second);
|
||||
getRealizationBase()->pendingDefaults.insert(i.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (t->getUnion() && !t->getUnion()->isSealed()) {
|
||||
t->setSrcInfo(srcInfo);
|
||||
pendingDefaults.insert(t);
|
||||
getRealizationBase()->pendingDefaults.insert(t);
|
||||
}
|
||||
if (auto r = t->getRecord())
|
||||
if (r->repeats && r->repeats->canRealize())
|
||||
|
|
|
@ -50,12 +50,12 @@ struct TypeContext : public Context<TypecheckItem> {
|
|||
types::TypePtr returnType = nullptr;
|
||||
/// Typechecking iteration
|
||||
int iteration = 0;
|
||||
std::set<types::TypePtr> pendingDefaults;
|
||||
};
|
||||
std::vector<RealizationBase> realizationBases;
|
||||
|
||||
/// The current type-checking level (for type instantiation and generalization).
|
||||
int typecheckLevel;
|
||||
std::set<types::TypePtr> pendingDefaults;
|
||||
int changedNodes;
|
||||
|
||||
/// The age of the currently parsed statement.
|
||||
|
|
|
@ -99,8 +99,9 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
|
|||
bool anotherRound = false;
|
||||
// Special case: return type might have default as well (e.g., Union)
|
||||
if (ctx->getRealizationBase()->returnType)
|
||||
ctx->pendingDefaults.insert(ctx->getRealizationBase()->returnType);
|
||||
for (auto &unbound : ctx->pendingDefaults) {
|
||||
ctx->getRealizationBase()->pendingDefaults.insert(
|
||||
ctx->getRealizationBase()->returnType);
|
||||
for (auto &unbound : ctx->getRealizationBase()->pendingDefaults) {
|
||||
if (auto tu = unbound->getUnion()) {
|
||||
// Seal all dynamic unions after the iteration is over
|
||||
if (!tu->isSealed()) {
|
||||
|
@ -113,7 +114,7 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) {
|
|||
anotherRound = true;
|
||||
}
|
||||
}
|
||||
ctx->pendingDefaults.clear();
|
||||
ctx->getRealizationBase()->pendingDefaults.clear();
|
||||
if (anotherRound)
|
||||
continue;
|
||||
|
||||
|
@ -653,6 +654,12 @@ ir::types::Type *TypecheckVisitor::makeIRType(types::ClassType *t) {
|
|||
handle = module->getFloatType();
|
||||
} else if (t->name == "float32") {
|
||||
handle = module->getFloat32Type();
|
||||
} else if (t->name == "float16") {
|
||||
handle = module->getFloat16Type();
|
||||
} else if (t->name == "bfloat16") {
|
||||
handle = module->getBFloat16Type();
|
||||
} else if (t->name == "float128") {
|
||||
handle = module->getFloat128Type();
|
||||
} else if (t->name == "str") {
|
||||
handle = module->getStringType();
|
||||
} else if (t->name == "Int" || t->name == "UInt") {
|
||||
|
@ -936,7 +943,9 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) {
|
|||
N<ThrowStmt>(N<CallExpr>(N<IdExpr>("std.internal.types.error.TypeError"),
|
||||
N<StringExpr>("invalid union call"))));
|
||||
// suite->stmts.push_back(N<ReturnStmt>(N<NoneExpr>()));
|
||||
unify(type->getRetType(), ctx->instantiate(ctx->getType("Union")));
|
||||
|
||||
auto ret = ctx->instantiate(ctx->getType("Union"));
|
||||
unify(type->getRetType(), ret);
|
||||
ast->suite = suite;
|
||||
} else if (startswith(ast->name, "__internal__.get_union_first:0")) {
|
||||
// def __internal__.get_union_first(union: Union[T0]):
|
||||
|
|
|
@ -49,6 +49,7 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) {
|
|||
|
||||
auto typ = expr->type;
|
||||
if (!expr->done) {
|
||||
bool isIntStatic = expr->staticValue.type == StaticValue::INT;
|
||||
TypecheckVisitor v(ctx, prependStmts);
|
||||
v.setSrcInfo(expr->getSrcInfo());
|
||||
ctx->pushSrcInfo(expr->getSrcInfo());
|
||||
|
@ -60,7 +61,8 @@ ExprPtr TypecheckVisitor::transform(ExprPtr &expr) {
|
|||
expr = v.resultExpr;
|
||||
}
|
||||
seqassert(expr->type, "type not set for {}", expr);
|
||||
unify(typ, expr->type);
|
||||
if (!(isIntStatic && expr->type->is("bool")))
|
||||
unify(typ, expr->type);
|
||||
if (expr->done) {
|
||||
ctx->changedNodes++;
|
||||
}
|
||||
|
|
|
@ -45,6 +45,27 @@ class float32:
|
|||
MIN_10_EXP = -37
|
||||
pass
|
||||
|
||||
@tuple
|
||||
@__internal__
|
||||
@__notuple__
|
||||
class float16:
|
||||
MIN_10_EXP = -4
|
||||
pass
|
||||
|
||||
@tuple
|
||||
@__internal__
|
||||
@__notuple__
|
||||
class bfloat16:
|
||||
MIN_10_EXP = -37
|
||||
pass
|
||||
|
||||
@tuple
|
||||
@__internal__
|
||||
@__notuple__
|
||||
class float128:
|
||||
MIN_10_EXP = -4931
|
||||
pass
|
||||
|
||||
@tuple
|
||||
@__internal__
|
||||
class type:
|
||||
|
|
|
@ -618,7 +618,7 @@ class __magic__:
|
|||
|
||||
# @dataclass parameter: gpu=True
|
||||
def from_gpu_new(other: T, T: type) -> T:
|
||||
__internal__.class_from_gpu_new(other)
|
||||
return __internal__.class_from_gpu_new(other)
|
||||
|
||||
# @dataclass parameter: repr=True
|
||||
def repr(slf) -> str:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
|
||||
|
||||
@tuple
|
||||
@tuple(python=False)
|
||||
class complex64:
|
||||
real: float32
|
||||
imag: float32
|
||||
|
|
|
@ -92,7 +92,7 @@ class float:
|
|||
mod = self % other
|
||||
div = (self - mod) / other
|
||||
if mod:
|
||||
if (other < 0) != (mod < 0):
|
||||
if (other < 0.0) != (mod < 0.0):
|
||||
mod += other
|
||||
div -= 1.0
|
||||
else:
|
||||
|
@ -475,7 +475,7 @@ class float32:
|
|||
%tmp = fmul float %a, %b
|
||||
ret float %tmp
|
||||
|
||||
def __floordiv__(self, other: float32) -> float:
|
||||
def __floordiv__(self, other: float32) -> float32:
|
||||
return self.__truediv__(other).__floor__()
|
||||
|
||||
@pure
|
||||
|
@ -494,19 +494,19 @@ class float32:
|
|||
mod = self % other
|
||||
div = (self - mod) / other
|
||||
if mod:
|
||||
if (other < 0) != (mod < 0):
|
||||
if (other < float32(0.0)) != (mod < float32(0.0)):
|
||||
mod += other
|
||||
div -= 1.0
|
||||
div -= float32(1.0)
|
||||
else:
|
||||
mod = (0.0).copysign(other)
|
||||
mod = float32(0.0).copysign(other)
|
||||
|
||||
floordiv = 0.0
|
||||
floordiv = float32(0.0)
|
||||
if div:
|
||||
floordiv = div.__floor__()
|
||||
if div - floordiv > 0.5:
|
||||
floordiv += 1.0
|
||||
if div - floordiv > float32(0.5):
|
||||
floordiv += float32(1.0)
|
||||
else:
|
||||
floordiv = (0.0).copysign(self / other)
|
||||
floordiv = float32(0.0).copysign(self / other)
|
||||
|
||||
return (floordiv, mod)
|
||||
|
||||
|
@ -752,10 +752,913 @@ class float32:
|
|||
def __match__(self, obj: float32) -> bool:
|
||||
return self == obj
|
||||
|
||||
@extend
|
||||
class float16:
|
||||
@pure
|
||||
@llvm
|
||||
def __new__(self: float) -> float16:
|
||||
%0 = fptrunc double %self to half
|
||||
ret half %0
|
||||
|
||||
def __new__(what: float16) -> float16:
|
||||
return what
|
||||
|
||||
def __new__() -> float16:
|
||||
return float16.__new__(0.0)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__float__().__repr__()
|
||||
|
||||
def __format__(self, format_spec: str) -> str:
|
||||
return self.__float__().__format(format_spec)
|
||||
|
||||
def __copy__(self) -> float16:
|
||||
return self
|
||||
|
||||
def __deepcopy__(self) -> float16:
|
||||
return self
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __int__(self) -> int:
|
||||
%0 = fptosi half %self to i64
|
||||
ret i64 %0
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __float__(self) -> float:
|
||||
%0 = fpext half %self to double
|
||||
ret double %0
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __bool__(self) -> bool:
|
||||
%0 = fcmp une half %self, 0.000000e+00
|
||||
%1 = zext i1 %0 to i8
|
||||
ret i8 %1
|
||||
|
||||
def __pos__(self) -> float16:
|
||||
return self
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __neg__(self) -> float16:
|
||||
%0 = fneg half %self
|
||||
ret half %0
|
||||
|
||||
@pure
|
||||
@commutative
|
||||
@llvm
|
||||
def __add__(a: float16, b: float16) -> float16:
|
||||
%tmp = fadd half %a, %b
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __sub__(a: float16, b: float16) -> float16:
|
||||
%tmp = fsub half %a, %b
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@commutative
|
||||
@llvm
|
||||
def __mul__(a: float16, b: float16) -> float16:
|
||||
%tmp = fmul half %a, %b
|
||||
ret half %tmp
|
||||
|
||||
def __floordiv__(self, other: float16) -> float16:
|
||||
return self.__truediv__(other).__floor__()
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __truediv__(a: float16, b: float16) -> float16:
|
||||
%tmp = fdiv half %a, %b
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __mod__(a: float16, b: float16) -> float16:
|
||||
%tmp = frem half %a, %b
|
||||
ret half %tmp
|
||||
|
||||
def __divmod__(self, other: float16) -> Tuple[float16, float16]:
|
||||
mod = self % other
|
||||
div = (self - mod) / other
|
||||
if mod:
|
||||
if (other < float16(0.0)) != (mod < float16(0.0)):
|
||||
mod += other
|
||||
div -= float16(1.0)
|
||||
else:
|
||||
mod = float16(0.0).copysign(other)
|
||||
|
||||
floordiv = float16(0.0)
|
||||
if div:
|
||||
floordiv = div.__floor__()
|
||||
if div - floordiv > float16(0.5):
|
||||
floordiv += float16(1.0)
|
||||
else:
|
||||
floordiv = float16(0.0).copysign(self / other)
|
||||
|
||||
return (floordiv, mod)
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __eq__(a: float16, b: float16) -> bool:
|
||||
%tmp = fcmp oeq half %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __ne__(a: float16, b: float16) -> bool:
|
||||
%tmp = fcmp une half %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __lt__(a: float16, b: float16) -> bool:
|
||||
%tmp = fcmp olt half %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __gt__(a: float16, b: float16) -> bool:
|
||||
%tmp = fcmp ogt half %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __le__(a: float16, b: float16) -> bool:
|
||||
%tmp = fcmp ole half %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __ge__(a: float16, b: float16) -> bool:
|
||||
%tmp = fcmp oge half %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def sqrt(a: float16) -> float16:
|
||||
declare half @llvm.sqrt.f16(half %a)
|
||||
%tmp = call half @llvm.sqrt.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def sin(a: float16) -> float16:
|
||||
declare half @llvm.sin.f16(half %a)
|
||||
%tmp = call half @llvm.sin.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def cos(a: float16) -> float16:
|
||||
declare half @llvm.cos.f16(half %a)
|
||||
%tmp = call half @llvm.cos.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def exp(a: float16) -> float16:
|
||||
declare half @llvm.exp.f16(half %a)
|
||||
%tmp = call half @llvm.exp.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def exp2(a: float16) -> float16:
|
||||
declare half @llvm.exp2.f16(half %a)
|
||||
%tmp = call half @llvm.exp2.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def log(a: float16) -> float16:
|
||||
declare half @llvm.log.f16(half %a)
|
||||
%tmp = call half @llvm.log.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def log10(a: float16) -> float16:
|
||||
declare half @llvm.log10.f16(half %a)
|
||||
%tmp = call half @llvm.log10.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def log2(a: float16) -> float16:
|
||||
declare half @llvm.log2.f16(half %a)
|
||||
%tmp = call half @llvm.log2.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __abs__(a: float16) -> float16:
|
||||
declare half @llvm.fabs.f16(half %a)
|
||||
%tmp = call half @llvm.fabs.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __floor__(a: float16) -> float16:
|
||||
declare half @llvm.floor.f16(half %a)
|
||||
%tmp = call half @llvm.floor.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __ceil__(a: float16) -> float16:
|
||||
declare half @llvm.ceil.f16(half %a)
|
||||
%tmp = call half @llvm.ceil.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __trunc__(a: float16) -> float16:
|
||||
declare half @llvm.trunc.f16(half %a)
|
||||
%tmp = call half @llvm.trunc.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def rint(a: float16) -> float16:
|
||||
declare half @llvm.rint.f16(half %a)
|
||||
%tmp = call half @llvm.rint.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def nearbyint(a: float16) -> float16:
|
||||
declare half @llvm.nearbyint.f16(half %a)
|
||||
%tmp = call half @llvm.nearbyint.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __round__(a: float16) -> float16:
|
||||
declare half @llvm.round.f16(half %a)
|
||||
%tmp = call half @llvm.round.f16(half %a)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __pow__(a: float16, b: float16) -> float16:
|
||||
declare half @llvm.pow.f16(half %a, half %b)
|
||||
%tmp = call half @llvm.pow.f16(half %a, half %b)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def min(a: float16, b: float16) -> float16:
|
||||
declare half @llvm.minnum.f16(half %a, half %b)
|
||||
%tmp = call half @llvm.minnum.f16(half %a, half %b)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def max(a: float16, b: float16) -> float16:
|
||||
declare half @llvm.maxnum.f16(half %a, half %b)
|
||||
%tmp = call half @llvm.maxnum.f16(half %a, half %b)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def copysign(a: float16, b: float16) -> float16:
|
||||
declare half @llvm.copysign.f16(half %a, half %b)
|
||||
%tmp = call half @llvm.copysign.f16(half %a, half %b)
|
||||
ret half %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def fma(a: float16, b: float16, c: float16) -> float16:
|
||||
declare half @llvm.fma.f16(half %a, half %b, half %c)
|
||||
%tmp = call half @llvm.fma.f16(half %a, half %b, half %c)
|
||||
ret half %tmp
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.__float__().__hash__()
|
||||
|
||||
def __match__(self, obj: float16) -> bool:
|
||||
return self == obj
|
||||
|
||||
@extend
|
||||
class bfloat16:
|
||||
@pure
|
||||
@llvm
|
||||
def __new__(self: float) -> bfloat16:
|
||||
%0 = fptrunc double %self to bfloat
|
||||
ret bfloat %0
|
||||
|
||||
def __new__(what: bfloat16) -> bfloat16:
|
||||
return what
|
||||
|
||||
def __new__() -> bfloat16:
|
||||
return bfloat16.__new__(0.0)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__float__().__repr__()
|
||||
|
||||
def __format__(self, format_spec: str) -> str:
|
||||
return self.__float__().__format(format_spec)
|
||||
|
||||
def __copy__(self) -> bfloat16:
|
||||
return self
|
||||
|
||||
def __deepcopy__(self) -> bfloat16:
|
||||
return self
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __int__(self) -> int:
|
||||
%0 = fptosi bfloat %self to i64
|
||||
ret i64 %0
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __float__(self) -> float:
|
||||
%0 = fpext bfloat %self to double
|
||||
ret double %0
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __bool__(self) -> bool:
|
||||
%0 = fcmp une bfloat %self, 0.000000e+00
|
||||
%1 = zext i1 %0 to i8
|
||||
ret i8 %1
|
||||
|
||||
def __pos__(self) -> bfloat16:
|
||||
return self
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __neg__(self) -> bfloat16:
|
||||
%0 = fneg bfloat %self
|
||||
ret bfloat %0
|
||||
|
||||
@pure
|
||||
@commutative
|
||||
@llvm
|
||||
def __add__(a: bfloat16, b: bfloat16) -> bfloat16:
|
||||
%tmp = fadd bfloat %a, %b
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __sub__(a: bfloat16, b: bfloat16) -> bfloat16:
|
||||
%tmp = fsub bfloat %a, %b
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@commutative
|
||||
@llvm
|
||||
def __mul__(a: bfloat16, b: bfloat16) -> bfloat16:
|
||||
%tmp = fmul bfloat %a, %b
|
||||
ret bfloat %tmp
|
||||
|
||||
def __floordiv__(self, other: bfloat16) -> bfloat16:
|
||||
return self.__truediv__(other).__floor__()
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __truediv__(a: bfloat16, b: bfloat16) -> bfloat16:
|
||||
%tmp = fdiv bfloat %a, %b
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __mod__(a: bfloat16, b: bfloat16) -> bfloat16:
|
||||
%tmp = frem bfloat %a, %b
|
||||
ret bfloat %tmp
|
||||
|
||||
def __divmod__(self, other: bfloat16) -> Tuple[bfloat16, bfloat16]:
|
||||
mod = self % other
|
||||
div = (self - mod) / other
|
||||
if mod:
|
||||
if (other < bfloat16(0.0)) != (mod < bfloat16(0.0)):
|
||||
mod += other
|
||||
div -= bfloat16(1.0)
|
||||
else:
|
||||
mod = bfloat16(0.0).copysign(other)
|
||||
|
||||
floordiv = bfloat16(0.0)
|
||||
if div:
|
||||
floordiv = div.__floor__()
|
||||
if div - floordiv > bfloat16(0.5):
|
||||
floordiv += bfloat16(1.0)
|
||||
else:
|
||||
floordiv = bfloat16(0.0).copysign(self / other)
|
||||
|
||||
return (floordiv, mod)
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __eq__(a: bfloat16, b: bfloat16) -> bool:
|
||||
%tmp = fcmp oeq bfloat %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __ne__(a: bfloat16, b: bfloat16) -> bool:
|
||||
%tmp = fcmp une bfloat %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __lt__(a: bfloat16, b: bfloat16) -> bool:
|
||||
%tmp = fcmp olt bfloat %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __gt__(a: bfloat16, b: bfloat16) -> bool:
|
||||
%tmp = fcmp ogt bfloat %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __le__(a: bfloat16, b: bfloat16) -> bool:
|
||||
%tmp = fcmp ole bfloat %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __ge__(a: bfloat16, b: bfloat16) -> bool:
|
||||
%tmp = fcmp oge bfloat %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def sqrt(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.sqrt.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.sqrt.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def sin(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.sin.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.sin.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def cos(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.cos.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.cos.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def exp(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.exp.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.exp.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def exp2(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.exp2.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.exp2.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def log(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.log.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.log.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def log10(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.log10.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.log10.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def log2(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.log2.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.log2.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __abs__(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.fabs.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.fabs.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __floor__(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.floor.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.floor.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __ceil__(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.ceil.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.ceil.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __trunc__(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.trunc.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.trunc.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def rint(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.rint.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.rint.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def nearbyint(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.nearbyint.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.nearbyint.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __round__(a: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.round.bf16(bfloat %a)
|
||||
%tmp = call bfloat @llvm.round.bf16(bfloat %a)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __pow__(a: bfloat16, b: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.pow.bf16(bfloat %a, bfloat %b)
|
||||
%tmp = call bfloat @llvm.pow.bf16(bfloat %a, bfloat %b)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def min(a: bfloat16, b: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.minnum.bf16(bfloat %a, bfloat %b)
|
||||
%tmp = call bfloat @llvm.minnum.bf16(bfloat %a, bfloat %b)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def max(a: bfloat16, b: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.maxnum.bf16(bfloat %a, bfloat %b)
|
||||
%tmp = call bfloat @llvm.maxnum.bf16(bfloat %a, bfloat %b)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def copysign(a: bfloat16, b: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.copysign.bf16(bfloat %a, bfloat %b)
|
||||
%tmp = call bfloat @llvm.copysign.bf16(bfloat %a, bfloat %b)
|
||||
ret bfloat %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def fma(a: bfloat16, b: bfloat16, c: bfloat16) -> bfloat16:
|
||||
declare bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %c)
|
||||
%tmp = call bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %c)
|
||||
ret bfloat %tmp
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.__float__().__hash__()
|
||||
|
||||
def __match__(self, obj: bfloat16) -> bool:
|
||||
return self == obj
|
||||
|
||||
@extend
|
||||
class float128:
|
||||
@pure
|
||||
@llvm
|
||||
def __new__(self: float) -> float128:
|
||||
%0 = fpext double %self to fp128
|
||||
ret fp128 %0
|
||||
|
||||
def __new__(what: float128) -> float128:
|
||||
return what
|
||||
|
||||
def __new__() -> float128:
|
||||
return float128.__new__(0.0)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__float__().__repr__()
|
||||
|
||||
def __format__(self, format_spec: str) -> str:
|
||||
return self.__float__().__format(format_spec)
|
||||
|
||||
def __copy__(self) -> float128:
|
||||
return self
|
||||
|
||||
def __deepcopy__(self) -> float128:
|
||||
return self
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __int__(self) -> int:
|
||||
%0 = fptosi fp128 %self to i64
|
||||
ret i64 %0
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __float__(self) -> float:
|
||||
%0 = fptrunc fp128 %self to double
|
||||
ret double %0
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __bool__(self) -> bool:
|
||||
%0 = fcmp une fp128 %self, 0xL00000000000000000000000000000000
|
||||
%1 = zext i1 %0 to i8
|
||||
ret i8 %1
|
||||
|
||||
def __pos__(self) -> float128:
|
||||
return self
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __neg__(self) -> float128:
|
||||
%0 = fneg fp128 %self
|
||||
ret fp128 %0
|
||||
|
||||
@pure
|
||||
@commutative
|
||||
@llvm
|
||||
def __add__(a: float128, b: float128) -> float128:
|
||||
%tmp = fadd fp128 %a, %b
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __sub__(a: float128, b: float128) -> float128:
|
||||
%tmp = fsub fp128 %a, %b
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@commutative
|
||||
@llvm
|
||||
def __mul__(a: float128, b: float128) -> float128:
|
||||
%tmp = fmul fp128 %a, %b
|
||||
ret fp128 %tmp
|
||||
|
||||
def __floordiv__(self, other: float128) -> float128:
|
||||
return self.__truediv__(other).__floor__()
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __truediv__(a: float128, b: float128) -> float128:
|
||||
%tmp = fdiv fp128 %a, %b
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __mod__(a: float128, b: float128) -> float128:
|
||||
%tmp = frem fp128 %a, %b
|
||||
ret fp128 %tmp
|
||||
|
||||
def __divmod__(self, other: float128) -> Tuple[float128, float128]:
|
||||
mod = self % other
|
||||
div = (self - mod) / other
|
||||
if mod:
|
||||
if (other < float128(0.0)) != (mod < float128(0)):
|
||||
mod += other
|
||||
div -= float128(1.0)
|
||||
else:
|
||||
mod = float128(0.0).copysign(other)
|
||||
|
||||
floordiv = float128(0.0)
|
||||
if div:
|
||||
floordiv = div.__floor__()
|
||||
if div - floordiv > float128(0.5):
|
||||
floordiv += float128(1.0)
|
||||
else:
|
||||
floordiv = float128(0.0).copysign(self / other)
|
||||
|
||||
return (floordiv, mod)
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __eq__(a: float128, b: float128) -> bool:
|
||||
%tmp = fcmp oeq fp128 %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __ne__(a: float128, b: float128) -> bool:
|
||||
%tmp = fcmp une fp128 %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __lt__(a: float128, b: float128) -> bool:
|
||||
%tmp = fcmp olt fp128 %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __gt__(a: float128, b: float128) -> bool:
|
||||
%tmp = fcmp ogt fp128 %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __le__(a: float128, b: float128) -> bool:
|
||||
%tmp = fcmp ole fp128 %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __ge__(a: float128, b: float128) -> bool:
|
||||
%tmp = fcmp oge fp128 %a, %b
|
||||
%res = zext i1 %tmp to i8
|
||||
ret i8 %res
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def sqrt(a: float128) -> float128:
|
||||
declare fp128 @llvm.sqrt.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.sqrt.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def sin(a: float128) -> float128:
|
||||
declare fp128 @llvm.sin.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.sin.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def cos(a: float128) -> float128:
|
||||
declare fp128 @llvm.cos.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.cos.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def exp(a: float128) -> float128:
|
||||
declare fp128 @llvm.exp.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.exp.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def exp2(a: float128) -> float128:
|
||||
declare fp128 @llvm.exp2.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.exp2.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def log(a: float128) -> float128:
|
||||
declare fp128 @llvm.log.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.log.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def log10(a: float128) -> float128:
|
||||
declare fp128 @llvm.log10.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.log10.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def log2(a: float128) -> float128:
|
||||
declare fp128 @llvm.log2.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.log2.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __abs__(a: float128) -> float128:
|
||||
declare fp128 @llvm.fabs.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.fabs.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __floor__(a: float128) -> float128:
|
||||
declare fp128 @llvm.floor.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.floor.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __ceil__(a: float128) -> float128:
|
||||
declare fp128 @llvm.ceil.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.ceil.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __trunc__(a: float128) -> float128:
|
||||
declare fp128 @llvm.trunc.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.trunc.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def rint(a: float128) -> float128:
|
||||
declare fp128 @llvm.rint.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.rint.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def nearbyint(a: float128) -> float128:
|
||||
declare fp128 @llvm.nearbyint.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.nearbyint.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __round__(a: float128) -> float128:
|
||||
declare fp128 @llvm.round.f128(fp128 %a)
|
||||
%tmp = call fp128 @llvm.round.f128(fp128 %a)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def __pow__(a: float128, b: float128) -> float128:
|
||||
declare fp128 @llvm.pow.f128(fp128 %a, fp128 %b)
|
||||
%tmp = call fp128 @llvm.pow.f128(fp128 %a, fp128 %b)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def min(a: float128, b: float128) -> float128:
|
||||
declare fp128 @llvm.minnum.f128(fp128 %a, fp128 %b)
|
||||
%tmp = call fp128 @llvm.minnum.f128(fp128 %a, fp128 %b)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def max(a: float128, b: float128) -> float128:
|
||||
declare fp128 @llvm.maxnum.f128(fp128 %a, fp128 %b)
|
||||
%tmp = call fp128 @llvm.maxnum.f128(fp128 %a, fp128 %b)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def copysign(a: float128, b: float128) -> float128:
|
||||
declare fp128 @llvm.copysign.f128(fp128 %a, fp128 %b)
|
||||
%tmp = call fp128 @llvm.copysign.f128(fp128 %a, fp128 %b)
|
||||
ret fp128 %tmp
|
||||
|
||||
@pure
|
||||
@llvm
|
||||
def fma(a: float128, b: float128, c: float128) -> float128:
|
||||
declare fp128 @llvm.fma.f128(fp128 %a, fp128 %b, fp128 %c)
|
||||
%tmp = call fp128 @llvm.fma.f128(fp128 %a, fp128 %b, fp128 %c)
|
||||
ret fp128 %tmp
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.__float__().__hash__()
|
||||
|
||||
def __match__(self, obj: float128) -> bool:
|
||||
return self == obj
|
||||
|
||||
@extend
|
||||
class float:
|
||||
def __suffix_f32__(double) -> float32:
|
||||
return float32.__new__(double)
|
||||
|
||||
def __suffix_f16__(double) -> float16:
|
||||
return float16.__new__(double)
|
||||
|
||||
def __suffix_bf16__(double) -> bfloat16:
|
||||
return bfloat16.__new__(double)
|
||||
|
||||
def __suffix_f128__(double) -> float128:
|
||||
return float128.__new__(double)
|
||||
|
||||
f16 = float16
|
||||
bf16 = bfloat16
|
||||
f32 = float32
|
||||
f64 = float
|
||||
f128 = float128
|
||||
|
|
|
@ -142,3 +142,45 @@ def test_int_pow():
|
|||
assert f(T2(0)) ** f(T2(0)) == T2(1)
|
||||
assert str(f(T2(31)) ** f(T2(31))) == '17069174130723235958610643029059314756044734431'
|
||||
test_int_pow()
|
||||
|
||||
@test
|
||||
def test_float(F: type):
|
||||
x = F(5.5)
|
||||
assert str(x) == '5.5'
|
||||
assert F(x) == x
|
||||
assert F() == F(0.0)
|
||||
assert x.__copy__() == x
|
||||
assert x.__deepcopy__() == x
|
||||
assert int(x) == 5
|
||||
assert float(x) == 5.5
|
||||
assert bool(x)
|
||||
assert not bool(F())
|
||||
assert +x == x
|
||||
assert -x == F(-5.5)
|
||||
assert x + x == F(11.0)
|
||||
assert x - F(1.0) == F(4.5)
|
||||
assert x * F(3.0) == F(16.5)
|
||||
assert x / F(2.0) == F(2.75)
|
||||
if F is not float128: # LLVM ops give wrong results for fp128
|
||||
assert x // F(2.0) == F(2.0)
|
||||
assert x % F(0.75) == F(0.25)
|
||||
assert divmod(x, F(0.75)) == (F(7.0), F(0.25))
|
||||
assert x == x
|
||||
assert x != F()
|
||||
assert x < F(6.5)
|
||||
assert x > F(4.5)
|
||||
assert x <= F(6.5)
|
||||
assert x >= F(4.5)
|
||||
assert x >= x
|
||||
assert x <= x
|
||||
assert abs(x) == x
|
||||
assert abs(-x) == x
|
||||
assert x.__match__(x)
|
||||
assert not x.__match__(F())
|
||||
assert hash(x) == hash(5.5)
|
||||
|
||||
test_float(float)
|
||||
test_float(float32)
|
||||
#test_float(float16)
|
||||
#test_float(bfloat16)
|
||||
#test_float(float128)
|
||||
|
|
|
@ -34,9 +34,19 @@ def test_conversions():
|
|||
def kernel(x, v):
|
||||
v[0] = x
|
||||
|
||||
def empty_tuple(x):
|
||||
if staticlen(x) == 0:
|
||||
return ()
|
||||
else:
|
||||
T = type(x[0])
|
||||
return (T(),) + empty_tuple(x[1:])
|
||||
|
||||
def check(x):
|
||||
T = type(x)
|
||||
v = [T()]
|
||||
if isinstance(x, Tuple):
|
||||
e = empty_tuple(x)
|
||||
else:
|
||||
e = type(x)()
|
||||
v = [e]
|
||||
kernel(x, v, grid=1, block=1)
|
||||
return v == [x]
|
||||
|
||||
|
|
Loading…
Reference in New Issue