Generator argument optimization (and more) (#175)

* Fix ABI incompatibilities

* Fix codon-jit on macOS

* Fix scoping bugs

* Fix .codon detection

* Handle static arguments in magic methods; Update simd; Fix misc. bugs

* Avoid partial calls with generators

* clang-format

* Add generator-argument optimization

* Fix typo

* Fix omp test

* Make sure sum() does not call __iadd__

* Clarify difference in docs

* Fix any/all generator pass

* Fix  InstantiateExpr simplification; Support .py as module extension

* clang-format

* Bump version

Co-authored-by: Ibrahim Numanagić <ibrahimpasa@gmail.com>
pull/182/head v0.15.4
A. R. Shajii 2023-01-17 10:21:59 -05:00 committed by GitHub
parent fc70c830d0
commit bac6ae58dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 515 additions and 100 deletions

View File

@ -1,10 +1,10 @@
cmake_minimum_required(VERSION 3.14)
project(
Codon
VERSION "0.15.3"
VERSION "0.15.4"
HOMEPAGE_URL "https://github.com/exaloop/codon"
DESCRIPTION "high-performance, extensible Python compiler")
set(CODON_JIT_PYTHON_VERSION "0.1.1")
set(CODON_JIT_PYTHON_VERSION "0.1.2")
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in"
"${PROJECT_SOURCE_DIR}/codon/config/config.h")
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in"
@ -197,6 +197,7 @@ set(CODON_HPPFILES
codon/cir/transform/parallel/schedule.h
codon/cir/transform/pass.h
codon/cir/transform/pythonic/dict.h
codon/cir/transform/pythonic/generator.h
codon/cir/transform/pythonic/io.h
codon/cir/transform/pythonic/list.h
codon/cir/transform/pythonic/str.h
@ -304,6 +305,7 @@ set(CODON_CPPFILES
codon/cir/transform/parallel/schedule.cpp
codon/cir/transform/pass.cpp
codon/cir/transform/pythonic/dict.cpp
codon/cir/transform/pythonic/generator.cpp
codon/cir/transform/pythonic/io.cpp
codon/cir/transform/pythonic/list.cpp
codon/cir/transform/pythonic/str.cpp

View File

@ -4,12 +4,12 @@
#include <algorithm>
#include "codon/parser/common.h"
#include "codon/cir/module.h"
#include "codon/cir/util/iterators.h"
#include "codon/cir/util/operator.h"
#include "codon/cir/util/visitor.h"
#include "codon/cir/var.h"
#include "codon/parser/common.h"
namespace codon {
namespace ir {

View File

@ -10,13 +10,13 @@
#include <unistd.h>
#include <utility>
#include "codon/cir/dsl/codegen.h"
#include "codon/cir/llvm/optimize.h"
#include "codon/cir/util/irtools.h"
#include "codon/compiler/debug_listener.h"
#include "codon/compiler/memory_manager.h"
#include "codon/parser/common.h"
#include "codon/runtime/lib.h"
#include "codon/cir/dsl/codegen.h"
#include "codon/cir/llvm/optimize.h"
#include "codon/cir/util/irtools.h"
#include "codon/util/common.h"
namespace codon {

View File

@ -2,9 +2,9 @@
#pragma once
#include "codon/dsl/plugins.h"
#include "codon/cir/llvm/llvm.h"
#include "codon/cir/cir.h"
#include "codon/cir/llvm/llvm.h"
#include "codon/dsl/plugins.h"
#include "codon/util/common.h"
#include <string>

View File

@ -4,8 +4,8 @@
#include <memory>
#include "codon/dsl/plugins.h"
#include "codon/cir/llvm/llvm.h"
#include "codon/dsl/plugins.h"
namespace codon {
namespace ir {

View File

@ -5,8 +5,8 @@
#include <algorithm>
#include <memory>
#include "codon/parser/cache.h"
#include "codon/cir/func.h"
#include "codon/parser/cache.h"
namespace codon {
namespace ir {

View File

@ -18,6 +18,7 @@
#include "codon/cir/transform/parallel/openmp.h"
#include "codon/cir/transform/pass.h"
#include "codon/cir/transform/pythonic/dict.h"
#include "codon/cir/transform/pythonic/generator.h"
#include "codon/cir/transform/pythonic/io.h"
#include "codon/cir/transform/pythonic/list.h"
#include "codon/cir/transform/pythonic/str.h"
@ -162,6 +163,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
registerPass(std::make_unique<pythonic::DictArithmeticOptimization>());
registerPass(std::make_unique<pythonic::ListAdditionOptimization>());
registerPass(std::make_unique<pythonic::StrAdditionOptimization>());
registerPass(std::make_unique<pythonic::GeneratorArgumentOptimization>());
registerPass(std::make_unique<pythonic::IOCatOptimization>());
// lowering

View File

@ -0,0 +1,235 @@
// Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
#include "generator.h"
#include <algorithm>
#include "codon/cir/util/cloning.h"
#include "codon/cir/util/irtools.h"
#include "codon/cir/util/matching.h"
namespace codon {
namespace ir {
namespace transform {
namespace pythonic {
namespace {
bool isSum(Func *f) {
return f && f->getName().rfind("std.internal.builtin.sum:", 0) == 0;
}
bool isAny(Func *f) {
return f && f->getName().rfind("std.internal.builtin.any:", 0) == 0;
}
bool isAll(Func *f) {
return f && f->getName().rfind("std.internal.builtin.all:", 0) == 0;
}
// Replaces yields with updates to the accumulator variable.
struct GeneratorSumTransformer : public util::Operator {
Var *accumulator;
bool valid;
explicit GeneratorSumTransformer(Var *accumulator)
: util::Operator(), accumulator(accumulator), valid(true) {}
void handle(YieldInstr *v) override {
auto *M = v->getModule();
auto *val = v->getValue();
if (!val) {
valid = false;
return;
}
Value *rhs = val;
if (val->getType()->is(M->getBoolType())) {
rhs = M->Nr<TernaryInstr>(rhs, M->getInt(1), M->getInt(0));
}
Value *add = *M->Nr<VarValue>(accumulator) + *rhs;
if (!add || !add->getType()->is(accumulator->getType())) {
valid = false;
return;
}
auto *assign = M->Nr<AssignInstr>(accumulator, add);
v->replaceAll(assign);
}
void handle(ReturnInstr *v) override {
auto *M = v->getModule();
auto *newReturn = M->Nr<ReturnInstr>(M->Nr<VarValue>(accumulator));
see(newReturn);
v->replaceAll(util::series(v->getValue(), newReturn));
}
void handle(YieldInInstr *v) override { valid = false; }
};
// Replaces yields with conditional returns of the any/all answer.
struct GeneratorAnyAllTransformer : public util::Operator {
bool any; // true=any, false=all
bool valid;
explicit GeneratorAnyAllTransformer(bool any)
: util::Operator(), any(any), valid(true) {}
void handle(YieldInstr *v) override {
auto *M = v->getModule();
auto *val = v->getValue();
auto *valBool = val ? (*M->getBoolType())(*val) : nullptr;
if (!valBool) {
valid = false;
return;
} else if (!any) {
valBool = M->Nr<TernaryInstr>(valBool, M->getBool(false), M->getBool(true));
}
auto *newReturn = M->Nr<ReturnInstr>(M->getBool(any));
see(newReturn);
auto *rep = M->Nr<IfFlow>(valBool, util::series(newReturn));
v->replaceAll(rep);
}
void handle(ReturnInstr *v) override {
if (saw(v))
return;
auto *M = v->getModule();
auto *newReturn = M->Nr<ReturnInstr>(M->getBool(!any));
see(newReturn);
v->replaceAll(util::series(v->getValue(), newReturn));
}
void handle(YieldInInstr *v) override { valid = false; }
};
Func *genToSum(BodiedFunc *gen, types::Type *startType, types::Type *outType) {
if (!gen || !gen->isGenerator())
return nullptr;
auto *M = gen->getModule();
auto *fn = M->Nr<BodiedFunc>("__sum_wrapper");
auto *genType = cast<types::FuncType>(gen->getType());
if (!genType)
return nullptr;
std::vector<types::Type *> argTypes(genType->begin(), genType->end());
argTypes.push_back(startType);
std::vector<std::string> names;
for (auto it = gen->arg_begin(); it != gen->arg_end(); ++it) {
names.push_back((*it)->getName());
}
names.push_back("start");
auto *fnType = M->getFuncType(outType, argTypes);
fn->realize(fnType, names);
std::unordered_map<id_t, Var *> argRemap;
for (auto it1 = gen->arg_begin(), it2 = fn->arg_begin();
it1 != gen->arg_end() && it2 != fn->arg_end(); ++it1, ++it2) {
argRemap.emplace((*it1)->getId(), *it2);
}
util::CloneVisitor cv(M);
auto *body = cast<SeriesFlow>(cv.clone(gen->getBody(), fn, argRemap));
fn->setBody(body);
Value *init = M->Nr<VarValue>(fn->arg_back());
if (startType->is(M->getIntType()) && outType->is(M->getFloatType()))
init = (*M->getFloatType())(*init);
if (!init || !init->getType()->is(outType))
return nullptr;
auto *accumulator = util::makeVar(init, body, fn, /*prepend=*/true)->getVar();
GeneratorSumTransformer xgen(accumulator);
fn->accept(xgen);
body->push_back(M->Nr<ReturnInstr>(M->Nr<VarValue>(accumulator)));
if (!xgen.valid)
return nullptr;
return fn;
}
Func *genToAnyAll(BodiedFunc *gen, bool any) {
if (!gen || !gen->isGenerator())
return nullptr;
auto *M = gen->getModule();
auto *fn = M->Nr<BodiedFunc>(any ? "__any_wrapper" : "__all_wrapper");
auto *genType = cast<types::FuncType>(gen->getType());
std::vector<types::Type *> argTypes(genType->begin(), genType->end());
std::vector<std::string> names;
for (auto it = gen->arg_begin(); it != gen->arg_end(); ++it) {
names.push_back((*it)->getName());
}
auto *fnType = M->getFuncType(M->getBoolType(), argTypes);
fn->realize(fnType, names);
std::unordered_map<id_t, Var *> argRemap;
for (auto it1 = gen->arg_begin(), it2 = fn->arg_begin();
it1 != gen->arg_end() && it2 != fn->arg_end(); ++it1, ++it2) {
argRemap.emplace((*it1)->getId(), *it2);
}
util::CloneVisitor cv(M);
auto *body = cast<SeriesFlow>(cv.clone(gen->getBody(), fn, argRemap));
fn->setBody(body);
GeneratorAnyAllTransformer xgen(any);
fn->accept(xgen);
body->push_back(M->Nr<ReturnInstr>(M->getBool(!any)));
if (!xgen.valid)
return nullptr;
return fn;
}
} // namespace
const std::string GeneratorArgumentOptimization::KEY =
"core-pythonic-generator-argument-opt";
void GeneratorArgumentOptimization::handle(CallInstr *v) {
auto *M = v->getModule();
auto *func = util::getFunc(v->getCallee());
if (isSum(func) && v->numArgs() == 2) {
auto *call = cast<CallInstr>(v->front());
if (!call)
return;
auto *gen = util::getFunc(call->getCallee());
auto *start = v->back();
if (auto *fn = genToSum(cast<BodiedFunc>(gen), start->getType(), v->getType())) {
std::vector<Value *> args(call->begin(), call->end());
args.push_back(start);
v->replaceAll(util::call(fn, args));
}
} else {
bool any = isAny(func), all = isAll(func);
if (!(any || all) || v->numArgs() != 1 || !v->getType()->is(M->getBoolType()))
return;
auto *call = cast<CallInstr>(v->front());
if (!call)
return;
auto *gen = util::getFunc(call->getCallee());
if (auto *fn = genToAnyAll(cast<BodiedFunc>(gen), any)) {
std::vector<Value *> args(call->begin(), call->end());
v->replaceAll(util::call(fn, args));
}
}
}
} // namespace pythonic
} // namespace transform
} // namespace ir
} // namespace codon

View File

@ -0,0 +1,25 @@
// Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
#pragma once
#include "codon/cir/transform/pass.h"
namespace codon {
namespace ir {
namespace transform {
namespace pythonic {
/// Pass to optimize passing a generator to some built-in functions
/// like sum(), any() or all(), which will be converted to regular
/// for-loops.
class GeneratorArgumentOptimization : public OperatorPass {
public:
static const std::string KEY;
std::string getKey() const override { return KEY; }
void handle(CallInstr *v) override;
};
} // namespace pythonic
} // namespace transform
} // namespace ir
} // namespace codon

View File

@ -6,12 +6,12 @@
#include <memory>
#include <utility>
#include "codon/parser/cache.h"
#include "codon/cir/module.h"
#include "codon/cir/util/irtools.h"
#include "codon/cir/util/iterators.h"
#include "codon/cir/util/visitor.h"
#include "codon/cir/value.h"
#include "codon/parser/cache.h"
#include <fmt/format.h>
namespace codon {

View File

@ -8,10 +8,10 @@
#include <utility>
#include <vector>
#include "codon/parser/ast.h"
#include "codon/cir/base.h"
#include "codon/cir/util/packs.h"
#include "codon/cir/util/visitor.h"
#include "codon/parser/ast.h"
#include <fmt/format.h>
#include <fmt/ostream.h>

View File

@ -7,12 +7,12 @@
#include <unordered_map>
#include <vector>
#include "codon/compiler/error.h"
#include "codon/dsl/plugins.h"
#include "codon/parser/cache.h"
#include "codon/cir/llvm/llvisitor.h"
#include "codon/cir/module.h"
#include "codon/cir/transform/manager.h"
#include "codon/compiler/error.h"
#include "codon/dsl/plugins.h"
#include "codon/parser/cache.h"
namespace codon {

View File

@ -2,8 +2,8 @@
#include "engine.h"
#include "codon/compiler/memory_manager.h"
#include "codon/cir/llvm/optimize.h"
#include "codon/compiler/memory_manager.h"
namespace codon {
namespace jit {

View File

@ -5,8 +5,8 @@
#include <memory>
#include <vector>
#include "codon/compiler/debug_listener.h"
#include "codon/cir/llvm/llvm.h"
#include "codon/compiler/debug_listener.h"
namespace codon {
namespace jit {

View File

@ -7,14 +7,14 @@
#include <unordered_map>
#include <vector>
#include "codon/cir/llvm/llvisitor.h"
#include "codon/cir/transform/manager.h"
#include "codon/cir/var.h"
#include "codon/compiler/compiler.h"
#include "codon/compiler/engine.h"
#include "codon/compiler/error.h"
#include "codon/parser/cache.h"
#include "codon/runtime/lib.h"
#include "codon/cir/llvm/llvisitor.h"
#include "codon/cir/transform/manager.h"
#include "codon/cir/var.h"
#include "codon/compiler/jit_extern.h"

View File

@ -2,10 +2,10 @@
#pragma once
#include "codon/parser/cache.h"
#include "codon/cir/cir.h"
#include "codon/cir/transform/manager.h"
#include "codon/cir/transform/pass.h"
#include "codon/parser/cache.h"
#include "llvm/Passes/PassBuilder.h"
#include <functional>
#include <string>

View File

@ -7,9 +7,9 @@
#include <string>
#include <vector>
#include "codon/cir/util/iterators.h"
#include "codon/compiler/error.h"
#include "codon/dsl/dsl.h"
#include "codon/cir/util/iterators.h"
#include "llvm/Support/DynamicLibrary.h"
namespace codon {

View File

@ -8,10 +8,10 @@
#include <string>
#include <vector>
#include "codon/cir/cir.h"
#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/ctx.h"
#include "codon/cir/cir.h"
#define FILE_GENERATED "<generated>"
#define MODULE_MAIN "__main__"

View File

@ -207,9 +207,12 @@ std::string library_path() {
namespace {
void addPath(std::vector<std::string> &paths, const std::string &path) {
if (llvm::sys::fs::exists(path))
bool addPath(std::vector<std::string> &paths, const std::string &path) {
if (llvm::sys::fs::exists(path)) {
paths.push_back(getAbsolutePath(path));
return true;
}
return false;
}
std::vector<std::string> getStdLibPaths(const std::string &argv0,
@ -244,7 +247,9 @@ ImportFile getRoot(const std::string argv0, const std::vector<std::string> &plug
}
if (!isStdLib && startswith(s, module0Root))
root = module0Root;
const std::string ext = ".codon";
std::string ext = ".codon";
if (!((root.empty() || startswith(s, root)) && endswith(s, ext)))
ext = ".py";
seqassertn((root.empty() || startswith(s, root)) && endswith(s, ext),
"bad path substitution: {}, {}", s, root);
auto module = s.substr(root.size() + 1, s.size() - root.size() - ext.size() - 1);
@ -280,6 +285,14 @@ std::shared_ptr<ImportFile> getImportFile(const std::string &argv0,
path = llvm::SmallString<128>(parentRelativeTo);
llvm::sys::path::append(path, what, "__init__.codon");
addPath(paths, std::string(path));
path = llvm::SmallString<128>(parentRelativeTo);
llvm::sys::path::append(path, what);
llvm::sys::path::replace_extension(path, "py");
addPath(paths, std::string(path));
path = llvm::SmallString<128>(parentRelativeTo);
llvm::sys::path::append(path, what, "__init__.py");
addPath(paths, std::string(path));
}
}
for (auto &p : getStdLibPaths(argv0, plugins)) {

View File

@ -14,6 +14,10 @@ using namespace codon::error;
namespace codon::ast {
void SimplifyVisitor::visit(IdExpr *expr) {
if (startswith(expr->value, TYPE_TUPLE)) {
expr->markType();
return;
}
auto val = ctx->findDominatingBinding(expr->value);
if (!val)
E(Error::ID_NOT_FOUND, expr, expr->value);

View File

@ -53,15 +53,19 @@ void SimplifyVisitor::visit(GeneratorExpr *expr) {
auto loops = clone_nop(expr->loops); // Clone as loops will be modified
std::string optimizeVar;
if (expr->kind == GeneratorExpr::ListGenerator && loops.size() == 1 &&
loops[0].conds.empty()) {
// List comprehension optimization:
// Use `iter.__len__()` when creating list if there is a single for loop
// without any if conditions in the comprehension
optimizeVar = ctx->cache->getTemporaryVar("i");
stmts.push_back(transform(N<AssignStmt>(N<IdExpr>(optimizeVar), loops[0].gen)));
loops[0].gen = N<IdExpr>(optimizeVar);
// List comprehension optimization:
// Use `iter.__len__()` when creating list if there is a single for loop
// without any if conditions in the comprehension
bool canOptimize = expr->kind == GeneratorExpr::ListGenerator && loops.size() == 1 &&
loops[0].conds.empty();
if (canOptimize) {
auto iter = transform(loops[0].gen);
IdExpr *id;
if (iter->getCall() && (id = iter->getCall()->expr->getId())) {
// Turn off this optimization for static items
canOptimize &= !startswith(id->value, "std.internal.types.range.staticrange");
canOptimize &= !startswith(id->value, "statictuple");
}
}
SuiteStmt *prev = nullptr;
@ -72,16 +76,32 @@ void SimplifyVisitor::visit(GeneratorExpr *expr) {
if (expr->kind == GeneratorExpr::ListGenerator) {
// List comprehensions
std::vector<ExprPtr> args;
if (!optimizeVar.empty()) {
// Use special List.__init__(bool, [optimizeVar]) constructor
args = {N<BoolExpr>(true), N<IdExpr>(optimizeVar)};
}
stmts.push_back(
transform(N<AssignStmt>(clone(var), N<CallExpr>(N<IdExpr>("List"), args))));
prev->stmts.push_back(
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "append"), clone(expr->expr))));
stmts.push_back(transform(suite));
resultExpr = N<StmtExpr>(stmts, transform(var));
auto noOptStmt =
N<SuiteStmt>(N<AssignStmt>(clone(var), N<CallExpr>(N<IdExpr>("List"))), suite);
if (canOptimize) {
seqassert(suite->getSuite() && !suite->getSuite()->stmts.empty() &&
CAST(suite->getSuite()->stmts[0], ForStmt),
"bad comprehension transformation");
auto optimizeVar = ctx->cache->getTemporaryVar("i");
auto optSuite = clone(suite);
CAST(optSuite->getSuite()->stmts[0], ForStmt)->iter = N<IdExpr>(optimizeVar);
auto optStmt = N<SuiteStmt>(
N<AssignStmt>(N<IdExpr>(optimizeVar), clone(expr->loops[0].gen)),
N<AssignStmt>(
clone(var),
N<CallExpr>(N<IdExpr>("List"),
N<CallExpr>(N<DotExpr>(N<IdExpr>(optimizeVar), "__len__")))),
optSuite);
resultExpr = transform(
N<IfExpr>(N<CallExpr>(N<IdExpr>("hasattr"), clone(expr->loops[0].gen),
N<StringExpr>("__len__")),
N<StmtExpr>(optStmt, clone(var)), N<StmtExpr>(noOptStmt, var)));
} else {
resultExpr = transform(N<StmtExpr>(noOptStmt, var));
}
} else if (expr->kind == GeneratorExpr::SetGenerator) {
// Set comprehensions
stmts.push_back(
@ -94,7 +114,16 @@ void SimplifyVisitor::visit(GeneratorExpr *expr) {
// Generators: converted to lambda functions that yield the target expression
prev->stmts.push_back(N<YieldStmt>(clone(expr->expr)));
stmts.push_back(suite);
resultExpr = N<CallExpr>(N<DotExpr>(N<CallExpr>(makeAnonFn(stmts)), "__iter__"));
auto anon = makeAnonFn(stmts);
if (auto call = anon->getCall()) {
seqassert(!call->args.empty() && call->args.back().value->getEllipsis(),
"bad lambda: {}", *call);
call->args.pop_back();
} else {
anon = N<CallExpr>(anon);
}
resultExpr = anon;
}
std::swap(avoidDomination, ctx->avoidDomination);
}

View File

@ -322,8 +322,8 @@ ExprPtr SimplifyVisitor::makeAnonFn(std::vector<StmtPtr> suite,
prependStmts->push_back(fs->stmts[0]);
for (StmtPtr s = fs->stmts[1]; s;) {
if (auto suite = s->getSuite()) {
// Suites can only occur when __internal__.undef is inserted for a partial call
// argument. Extract __internal__.undef checks and prepend them
// Suites can only occur when captures are inserted for a partial call
// argument.
seqassert(suite->stmts.size() == 2, "invalid function transform");
prependStmts->push_back(suite->stmts[0]);
s = suite->stmts[1];

View File

@ -95,8 +95,12 @@ void SimplifyVisitor::visit(IndexExpr *expr) {
}
}
/// Ignore it. Already transformed. Sometimes called again
/// during class extension.
void SimplifyVisitor::visit(InstantiateExpr *expr) {}
/// Already transformed. Sometimes needed again
/// for identifier analysis.
void SimplifyVisitor::visit(InstantiateExpr *expr) {
transformType(expr->typeExpr);
for (auto &tp : expr->typeParams)
transform(tp, true);
}
} // namespace codon::ast

View File

@ -7,11 +7,11 @@
#include <string>
#include <vector>
#include "codon/cir/transform/parallel/schedule.h"
#include "codon/cir/util/cloning.h"
#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/visitors/translate/translate_ctx.h"
#include "codon/cir/transform/parallel/schedule.h"
#include "codon/cir/util/cloning.h"
using codon::ir::cast;
using codon::ir::transform::parallel::OMPSched;

View File

@ -8,12 +8,12 @@
#include <unordered_set>
#include <vector>
#include "codon/cir/cir.h"
#include "codon/parser/ast.h"
#include "codon/parser/cache.h"
#include "codon/parser/common.h"
#include "codon/parser/visitors/translate/translate_ctx.h"
#include "codon/parser/visitors/visitor.h"
#include "codon/cir/cir.h"
namespace codon::ast {

View File

@ -7,11 +7,11 @@
#include <unordered_set>
#include <vector>
#include "codon/cir/cir.h"
#include "codon/cir/types/types.h"
#include "codon/parser/cache.h"
#include "codon/parser/common.h"
#include "codon/parser/ctx.h"
#include "codon/cir/cir.h"
#include "codon/cir/types/types.h"
namespace codon::ast {

View File

@ -909,22 +909,30 @@ void TypecheckVisitor::addFunctionGenerics(const FuncType *t) {
for (auto parent = t->funcParent; parent;) {
if (auto f = parent->getFunc()) {
// Add parent function generics
for (auto &g : f->funcGenerics)
for (auto &g : f->funcGenerics) {
// LOG(" -> {} := {}", g.name, g.type->debugString(true));
ctx->add(TypecheckItem::Type, g.name, g.type);
}
parent = f->funcParent;
} else {
// Add parent class generics
seqassert(parent->getClass(), "not a class: {}", parent);
for (auto &g : parent->getClass()->generics)
for (auto &g : parent->getClass()->generics) {
// LOG(" => {} := {}", g.name, g.type->debugString(true));
ctx->add(TypecheckItem::Type, g.name, g.type);
for (auto &g : parent->getClass()->hiddenGenerics)
}
for (auto &g : parent->getClass()->hiddenGenerics) {
// LOG(" :> {} := {}", g.name, g.type->debugString(true));
ctx->add(TypecheckItem::Type, g.name, g.type);
}
break;
}
}
// Add function generics
for (auto &g : t->funcGenerics)
for (auto &g : t->funcGenerics) {
// LOG(" >> {} := {}", g.name, g.type->debugString(true));
ctx->add(TypecheckItem::Type, g.name, g.type);
}
}
/// Generate a partial type `Partial.N<mask>` for a given function.

View File

@ -155,7 +155,8 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
// Generalize generics and remove them from the context
for (const auto &g : generics) {
for (auto &u : g->getUnbounds())
u->getUnbound()->kind = LinkType::Generic;
if (u->getUnbound())
u->getUnbound()->kind = LinkType::Generic;
}
// Construct the type
@ -163,8 +164,9 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
baseType, ctx->cache->functions[stmt->name].ast.get(), explicits);
funcTyp->setSrcInfo(getSrcInfo());
if (isClassMember && stmt->attributes.has(Attr::Method))
if (isClassMember && stmt->attributes.has(Attr::Method)) {
funcTyp->funcParent = ctx->find(stmt->attributes.parentClass)->type;
}
funcTyp =
std::static_pointer_cast<FuncType>(funcTyp->generalize(ctx->typecheckLevel));

View File

@ -7,11 +7,11 @@
#include <tuple>
#include <vector>
#include "codon/cir/types/types.h"
#include "codon/parser/ast.h"
#include "codon/parser/common.h"
#include "codon/parser/visitors/simplify/simplify.h"
#include "codon/parser/visitors/typecheck/typecheck.h"
#include "codon/cir/types/types.h"
using fmt::format;
using namespace codon::error;

View File

@ -636,7 +636,7 @@ ExprPtr TypecheckVisitor::transformBinaryInplaceMagic(BinaryExpr *expr, bool isA
// In-place operations: check if `lhs.__iop__(lhs, rhs)` exists
if (!method && expr->inPlace) {
method = findBestMethod(lt, format("__i{}__", magic), {lt, rt});
method = findBestMethod(lt, format("__i{}__", magic), {expr->lexpr, expr->rexpr});
}
if (method)
@ -667,11 +667,11 @@ ExprPtr TypecheckVisitor::transformBinaryMagic(BinaryExpr *expr) {
}
// Normal operations: check if `lhs.__magic__(lhs, rhs)` exists
auto method = findBestMethod(lt, format("__{}__", magic), {lt, rt});
auto method = findBestMethod(lt, format("__{}__", magic), {expr->lexpr, expr->rexpr});
// Right-side magics: check if `rhs.__rmagic__(rhs, lhs)` exists
if (!method &&
(method = findBestMethod(rt, format("__{}__", rightMagic), {rt, lt}))) {
if (!method && (method = findBestMethod(rt, format("__{}__", rightMagic),
{expr->rexpr, expr->lexpr}))) {
swap(expr->lexpr, expr->rexpr);
}

View File

@ -187,6 +187,19 @@ TypecheckVisitor::findBestMethod(const ClassTypePtr &typ, const std::string &mem
return m.empty() ? nullptr : m[0];
}
/// Select the best method indicated of an object that matches the given argument
/// types. See @c findMatchingMethods for details.
types::FuncTypePtr TypecheckVisitor::findBestMethod(const ClassTypePtr &typ,
const std::string &member,
const std::vector<ExprPtr> &args) {
std::vector<CallExpr::Arg> callArgs;
for (auto &a : args)
callArgs.push_back({"", a});
auto methods = ctx->findMethod(typ->name, member, false);
auto m = findMatchingMethods(typ, methods, callArgs);
return m.empty() ? nullptr : m[0];
}
/// Select the best method among the provided methods given the list of arguments.
/// See @c reorderNamedArgs for details.
std::vector<types::FuncTypePtr>

View File

@ -207,6 +207,9 @@ private:
types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ,
const std::string &member,
const std::vector<types::TypePtr> &args);
types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ,
const std::string &member,
const std::vector<ExprPtr> &args);
std::vector<types::FuncTypePtr>
findMatchingMethods(const types::ClassTypePtr &typ,
const std::vector<types::FuncTypePtr> &methods,

View File

@ -14,8 +14,8 @@ in mind.
- **Strings:** Codon currently uses ASCII strings unlike
Python's unicode strings.
- **Dictionaries:** Codon's dictionary type is not sorted
internally, unlike Python's.
- **Dictionaries:** Codon's dictionary type does not preserve
insertion order, unlike Python's as of 3.6.
# Type checking

View File

@ -17,13 +17,21 @@ sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
from .codon_jit import JITWrapper, JITError, codon_library
if "CODON_PATH" not in os.environ:
codon_path = []
codon_lib_path = codon_library()
if not codon_lib_path:
if codon_lib_path:
codon_path.append(Path(codon_lib_path).parent / "stdlib")
codon_path.append(
Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib"
)
for path in codon_path:
if path.exists():
os.environ["CODON_PATH"] = str(path.resolve())
break
else:
raise RuntimeError(
"Cannot locate Codon. Please install Codon or set CODON_PATH."
)
codon_path = (Path(codon_lib_path).parent / "stdlib").resolve()
os.environ["CODON_PATH"] = str(codon_path)
pod_conversions = {
type(None): "pyobj",

View File

@ -1,6 +1,6 @@
# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
@tuple
@tuple(container=False) # disallow default __getitem__
class Vec[T, N: Static[int]]:
ZERO_16x8i = Vec[u8,16](u8(0))
FF_16x8i = Vec[u8,16](u8(0xff))
@ -307,6 +307,10 @@ class Vec[T, N: Static[int]]:
else:
return "?"
def scatter(self: Vec[T, N]) -> List[T]:
return [self[i] for i in staticrange(N)]
u8x16 = Vec[u8, 16]
u8x32 = Vec[u8, 32]
f32x8 = Vec[f32, 8]

View File

@ -248,19 +248,26 @@ def round(x, n=0):
nx = float.__pow__(10.0, n)
return float.__round__(x * nx) / nx
def sum(xi):
"""
Return the sum of the items added together from xi
"""
x = iter(xi)
if not x.done():
s = x.next()
while not x.done():
s += x.next()
x.destroy()
return s
def _sum_start(x, start):
if isinstance(x.__iter__(), Generator[float]) and isinstance(start, int):
return float(start)
else:
x.destroy()
return start
def sum(x, start=0):
"""
Return the sum of the items added together from x
"""
s = _sum_start(x, start)
for a in x:
# don't use += to avoid calling iadd
if isinstance(a, bool):
s = s + (1 if a else 0)
else:
s = s + a
return s
def repr(x):
"""Return the string representation of x"""

View File

@ -24,13 +24,6 @@ class List:
self.arr = Array[T](capacity)
self.len = 0
def __init__(self, dummy: bool, other):
"""Dummy __init__ used for list comprehension optimization"""
if hasattr(other, "__len__"):
self.__init__(other.__len__())
else:
self.__init__()
def __init__(self, arr: Array[T], len: int):
self.arr = arr
self.len = len

View File

@ -85,6 +85,70 @@ def test_map_filter():
assert list(filter(lambda i: i%2 == 0, map(lambda i: i*i, range(10)))) == [0, 4, 16, 36, 64]
@test
def test_gen_builtins():
assert sum([1, 2, 3]) == 6
assert sum([1, 2, 3], 0.5) == 6.5
assert sum([True, False, True, False, True], 0.5) == 3.5
assert sum(List[float]()) == 0
assert sum(i/2 for i in range(10)) == 22.5
def g1():
yield 1.5
yield 2.5
return
yield 3.5
assert sum(g1(), 10) == 14.0
def g2():
yield True
yield False
yield True
assert sum(g2()) == 2
class A:
iadd_count = 0
n: int
def __init__(self, n):
self.n = n
def __add__(self, other):
return A(self.n + other.n)
def __iadd__(self, other):
A.iadd_count += 1
self.n += other.n
return self
assert sum((A(i) for i in range(5)), A(100)).n == 110
assert A.iadd_count == 0
def g3(a, b):
for i in range(10):
yield a
yield b
assert all([True, True])
assert all(i for i in range(0))
assert not all([True, False])
assert all(List[str]())
assert all(g3(True, True))
assert not all(g3(True, False))
assert not all(g3(False, True))
assert not all(g3(False, False))
assert any([True, True])
assert not any(i for i in range(0))
assert not any([False, False])
assert not any(List[bool]())
assert any(g3(True, True))
assert any(g3(True, False))
assert any(g3(False, True))
assert not any(g3(False, False))
@test
def test_int_format():
n = 0
@ -269,6 +333,7 @@ def test_files(open_fn):
test_min_max()
test_map_filter()
test_gen_builtins()
test_int_format()
test_reversed()
test_divmod()

View File

@ -15,15 +15,15 @@
#include <unistd.h>
#include <vector>
#include "codon/compiler/compiler.h"
#include "codon/compiler/error.h"
#include "codon/parser/common.h"
#include "codon/cir/analyze/dataflow/capture.h"
#include "codon/cir/analyze/dataflow/reaching.h"
#include "codon/cir/util/inlining.h"
#include "codon/cir/util/irtools.h"
#include "codon/cir/util/operator.h"
#include "codon/cir/util/outlining.h"
#include "codon/compiler/compiler.h"
#include "codon/compiler/error.h"
#include "codon/parser/common.h"
#include "codon/util/common.h"
#include "gtest/gtest.h"

View File

@ -157,12 +157,10 @@ print d #: {0: 9}
#%% comprehension_opt,barebones
@extend
class List:
def __init__(self, dummy: bool, other):
if hasattr(other, '__len__'):
print 'optimize', other.__len__()
self.__init__(other.__len__())
else:
self.__init__()
def __init__(self, cap: int):
print 'optimize', cap
self.arr = Array[T](cap)
self.len = 0
def foo():
yield 0
yield 1

View File

@ -457,7 +457,7 @@ def test_omp_reductions():
@par
for i in L[1:1001]:
c += f32(i)
assert c == sum(f32(i) for i in range(1001))
assert c == sum((f32(i) for i in range(1001)), f32(0))
c = f32(1.)
@par