mirror of https://github.com/exaloop/codon.git
Generator argument optimization (and more) (#175)
* Fix ABI incompatibilities * Fix codon-jit on macOS * Fix scoping bugs * Fix .codon detection * Handle static arguments in magic methods; Update simd; Fix misc. bugs * Avoid partial calls with generators * clang-format * Add generator-argument optimization * Fix typo * Fix omp test * Make sure sum() does not call __iadd__ * Clarify difference in docs * Fix any/all generator pass * Fix InstantiateExpr simplification; Support .py as module extension * clang-format * Bump version Co-authored-by: Ibrahim Numanagić <ibrahimpasa@gmail.com>pull/182/head v0.15.4
parent
fc70c830d0
commit
bac6ae58dd
|
@ -1,10 +1,10 @@
|
|||
cmake_minimum_required(VERSION 3.14)
|
||||
project(
|
||||
Codon
|
||||
VERSION "0.15.3"
|
||||
VERSION "0.15.4"
|
||||
HOMEPAGE_URL "https://github.com/exaloop/codon"
|
||||
DESCRIPTION "high-performance, extensible Python compiler")
|
||||
set(CODON_JIT_PYTHON_VERSION "0.1.1")
|
||||
set(CODON_JIT_PYTHON_VERSION "0.1.2")
|
||||
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in"
|
||||
"${PROJECT_SOURCE_DIR}/codon/config/config.h")
|
||||
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in"
|
||||
|
@ -197,6 +197,7 @@ set(CODON_HPPFILES
|
|||
codon/cir/transform/parallel/schedule.h
|
||||
codon/cir/transform/pass.h
|
||||
codon/cir/transform/pythonic/dict.h
|
||||
codon/cir/transform/pythonic/generator.h
|
||||
codon/cir/transform/pythonic/io.h
|
||||
codon/cir/transform/pythonic/list.h
|
||||
codon/cir/transform/pythonic/str.h
|
||||
|
@ -304,6 +305,7 @@ set(CODON_CPPFILES
|
|||
codon/cir/transform/parallel/schedule.cpp
|
||||
codon/cir/transform/pass.cpp
|
||||
codon/cir/transform/pythonic/dict.cpp
|
||||
codon/cir/transform/pythonic/generator.cpp
|
||||
codon/cir/transform/pythonic/io.cpp
|
||||
codon/cir/transform/pythonic/list.cpp
|
||||
codon/cir/transform/pythonic/str.cpp
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
|
||||
#include <algorithm>
|
||||
|
||||
#include "codon/parser/common.h"
|
||||
#include "codon/cir/module.h"
|
||||
#include "codon/cir/util/iterators.h"
|
||||
#include "codon/cir/util/operator.h"
|
||||
#include "codon/cir/util/visitor.h"
|
||||
#include "codon/cir/var.h"
|
||||
#include "codon/parser/common.h"
|
||||
|
||||
namespace codon {
|
||||
namespace ir {
|
||||
|
|
|
@ -10,13 +10,13 @@
|
|||
#include <unistd.h>
|
||||
#include <utility>
|
||||
|
||||
#include "codon/cir/dsl/codegen.h"
|
||||
#include "codon/cir/llvm/optimize.h"
|
||||
#include "codon/cir/util/irtools.h"
|
||||
#include "codon/compiler/debug_listener.h"
|
||||
#include "codon/compiler/memory_manager.h"
|
||||
#include "codon/parser/common.h"
|
||||
#include "codon/runtime/lib.h"
|
||||
#include "codon/cir/dsl/codegen.h"
|
||||
#include "codon/cir/llvm/optimize.h"
|
||||
#include "codon/cir/util/irtools.h"
|
||||
#include "codon/util/common.h"
|
||||
|
||||
namespace codon {
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "codon/dsl/plugins.h"
|
||||
#include "codon/cir/llvm/llvm.h"
|
||||
#include "codon/cir/cir.h"
|
||||
#include "codon/cir/llvm/llvm.h"
|
||||
#include "codon/dsl/plugins.h"
|
||||
#include "codon/util/common.h"
|
||||
|
||||
#include <string>
|
||||
|
|
|
@ -4,8 +4,8 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "codon/dsl/plugins.h"
|
||||
#include "codon/cir/llvm/llvm.h"
|
||||
#include "codon/dsl/plugins.h"
|
||||
|
||||
namespace codon {
|
||||
namespace ir {
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "codon/parser/cache.h"
|
||||
#include "codon/cir/func.h"
|
||||
#include "codon/parser/cache.h"
|
||||
|
||||
namespace codon {
|
||||
namespace ir {
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "codon/cir/transform/parallel/openmp.h"
|
||||
#include "codon/cir/transform/pass.h"
|
||||
#include "codon/cir/transform/pythonic/dict.h"
|
||||
#include "codon/cir/transform/pythonic/generator.h"
|
||||
#include "codon/cir/transform/pythonic/io.h"
|
||||
#include "codon/cir/transform/pythonic/list.h"
|
||||
#include "codon/cir/transform/pythonic/str.h"
|
||||
|
@ -162,6 +163,7 @@ void PassManager::registerStandardPasses(PassManager::Init init) {
|
|||
registerPass(std::make_unique<pythonic::DictArithmeticOptimization>());
|
||||
registerPass(std::make_unique<pythonic::ListAdditionOptimization>());
|
||||
registerPass(std::make_unique<pythonic::StrAdditionOptimization>());
|
||||
registerPass(std::make_unique<pythonic::GeneratorArgumentOptimization>());
|
||||
registerPass(std::make_unique<pythonic::IOCatOptimization>());
|
||||
|
||||
// lowering
|
||||
|
|
|
@ -0,0 +1,235 @@
|
|||
// Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
|
||||
|
||||
#include "generator.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "codon/cir/util/cloning.h"
|
||||
#include "codon/cir/util/irtools.h"
|
||||
#include "codon/cir/util/matching.h"
|
||||
|
||||
namespace codon {
|
||||
namespace ir {
|
||||
namespace transform {
|
||||
namespace pythonic {
|
||||
namespace {
|
||||
bool isSum(Func *f) {
|
||||
return f && f->getName().rfind("std.internal.builtin.sum:", 0) == 0;
|
||||
}
|
||||
|
||||
bool isAny(Func *f) {
|
||||
return f && f->getName().rfind("std.internal.builtin.any:", 0) == 0;
|
||||
}
|
||||
|
||||
bool isAll(Func *f) {
|
||||
return f && f->getName().rfind("std.internal.builtin.all:", 0) == 0;
|
||||
}
|
||||
|
||||
// Replaces yields with updates to the accumulator variable.
|
||||
struct GeneratorSumTransformer : public util::Operator {
|
||||
Var *accumulator;
|
||||
bool valid;
|
||||
|
||||
explicit GeneratorSumTransformer(Var *accumulator)
|
||||
: util::Operator(), accumulator(accumulator), valid(true) {}
|
||||
|
||||
void handle(YieldInstr *v) override {
|
||||
auto *M = v->getModule();
|
||||
auto *val = v->getValue();
|
||||
if (!val) {
|
||||
valid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
Value *rhs = val;
|
||||
if (val->getType()->is(M->getBoolType())) {
|
||||
rhs = M->Nr<TernaryInstr>(rhs, M->getInt(1), M->getInt(0));
|
||||
}
|
||||
|
||||
Value *add = *M->Nr<VarValue>(accumulator) + *rhs;
|
||||
if (!add || !add->getType()->is(accumulator->getType())) {
|
||||
valid = false;
|
||||
return;
|
||||
}
|
||||
|
||||
auto *assign = M->Nr<AssignInstr>(accumulator, add);
|
||||
v->replaceAll(assign);
|
||||
}
|
||||
|
||||
void handle(ReturnInstr *v) override {
|
||||
auto *M = v->getModule();
|
||||
auto *newReturn = M->Nr<ReturnInstr>(M->Nr<VarValue>(accumulator));
|
||||
see(newReturn);
|
||||
v->replaceAll(util::series(v->getValue(), newReturn));
|
||||
}
|
||||
|
||||
void handle(YieldInInstr *v) override { valid = false; }
|
||||
};
|
||||
|
||||
// Replaces yields with conditional returns of the any/all answer.
|
||||
struct GeneratorAnyAllTransformer : public util::Operator {
|
||||
bool any; // true=any, false=all
|
||||
bool valid;
|
||||
|
||||
explicit GeneratorAnyAllTransformer(bool any)
|
||||
: util::Operator(), any(any), valid(true) {}
|
||||
|
||||
void handle(YieldInstr *v) override {
|
||||
auto *M = v->getModule();
|
||||
auto *val = v->getValue();
|
||||
auto *valBool = val ? (*M->getBoolType())(*val) : nullptr;
|
||||
if (!valBool) {
|
||||
valid = false;
|
||||
return;
|
||||
} else if (!any) {
|
||||
valBool = M->Nr<TernaryInstr>(valBool, M->getBool(false), M->getBool(true));
|
||||
}
|
||||
|
||||
auto *newReturn = M->Nr<ReturnInstr>(M->getBool(any));
|
||||
see(newReturn);
|
||||
auto *rep = M->Nr<IfFlow>(valBool, util::series(newReturn));
|
||||
v->replaceAll(rep);
|
||||
}
|
||||
|
||||
void handle(ReturnInstr *v) override {
|
||||
if (saw(v))
|
||||
return;
|
||||
auto *M = v->getModule();
|
||||
auto *newReturn = M->Nr<ReturnInstr>(M->getBool(!any));
|
||||
see(newReturn);
|
||||
v->replaceAll(util::series(v->getValue(), newReturn));
|
||||
}
|
||||
|
||||
void handle(YieldInInstr *v) override { valid = false; }
|
||||
};
|
||||
|
||||
Func *genToSum(BodiedFunc *gen, types::Type *startType, types::Type *outType) {
|
||||
if (!gen || !gen->isGenerator())
|
||||
return nullptr;
|
||||
|
||||
auto *M = gen->getModule();
|
||||
auto *fn = M->Nr<BodiedFunc>("__sum_wrapper");
|
||||
auto *genType = cast<types::FuncType>(gen->getType());
|
||||
if (!genType)
|
||||
return nullptr;
|
||||
|
||||
std::vector<types::Type *> argTypes(genType->begin(), genType->end());
|
||||
argTypes.push_back(startType);
|
||||
|
||||
std::vector<std::string> names;
|
||||
for (auto it = gen->arg_begin(); it != gen->arg_end(); ++it) {
|
||||
names.push_back((*it)->getName());
|
||||
}
|
||||
names.push_back("start");
|
||||
|
||||
auto *fnType = M->getFuncType(outType, argTypes);
|
||||
fn->realize(fnType, names);
|
||||
|
||||
std::unordered_map<id_t, Var *> argRemap;
|
||||
for (auto it1 = gen->arg_begin(), it2 = fn->arg_begin();
|
||||
it1 != gen->arg_end() && it2 != fn->arg_end(); ++it1, ++it2) {
|
||||
argRemap.emplace((*it1)->getId(), *it2);
|
||||
}
|
||||
|
||||
util::CloneVisitor cv(M);
|
||||
auto *body = cast<SeriesFlow>(cv.clone(gen->getBody(), fn, argRemap));
|
||||
fn->setBody(body);
|
||||
|
||||
Value *init = M->Nr<VarValue>(fn->arg_back());
|
||||
if (startType->is(M->getIntType()) && outType->is(M->getFloatType()))
|
||||
init = (*M->getFloatType())(*init);
|
||||
|
||||
if (!init || !init->getType()->is(outType))
|
||||
return nullptr;
|
||||
|
||||
auto *accumulator = util::makeVar(init, body, fn, /*prepend=*/true)->getVar();
|
||||
GeneratorSumTransformer xgen(accumulator);
|
||||
fn->accept(xgen);
|
||||
body->push_back(M->Nr<ReturnInstr>(M->Nr<VarValue>(accumulator)));
|
||||
|
||||
if (!xgen.valid)
|
||||
return nullptr;
|
||||
|
||||
return fn;
|
||||
}
|
||||
|
||||
Func *genToAnyAll(BodiedFunc *gen, bool any) {
|
||||
if (!gen || !gen->isGenerator())
|
||||
return nullptr;
|
||||
|
||||
auto *M = gen->getModule();
|
||||
auto *fn = M->Nr<BodiedFunc>(any ? "__any_wrapper" : "__all_wrapper");
|
||||
auto *genType = cast<types::FuncType>(gen->getType());
|
||||
|
||||
std::vector<types::Type *> argTypes(genType->begin(), genType->end());
|
||||
std::vector<std::string> names;
|
||||
for (auto it = gen->arg_begin(); it != gen->arg_end(); ++it) {
|
||||
names.push_back((*it)->getName());
|
||||
}
|
||||
|
||||
auto *fnType = M->getFuncType(M->getBoolType(), argTypes);
|
||||
fn->realize(fnType, names);
|
||||
|
||||
std::unordered_map<id_t, Var *> argRemap;
|
||||
for (auto it1 = gen->arg_begin(), it2 = fn->arg_begin();
|
||||
it1 != gen->arg_end() && it2 != fn->arg_end(); ++it1, ++it2) {
|
||||
argRemap.emplace((*it1)->getId(), *it2);
|
||||
}
|
||||
|
||||
util::CloneVisitor cv(M);
|
||||
auto *body = cast<SeriesFlow>(cv.clone(gen->getBody(), fn, argRemap));
|
||||
fn->setBody(body);
|
||||
|
||||
GeneratorAnyAllTransformer xgen(any);
|
||||
fn->accept(xgen);
|
||||
body->push_back(M->Nr<ReturnInstr>(M->getBool(!any)));
|
||||
|
||||
if (!xgen.valid)
|
||||
return nullptr;
|
||||
|
||||
return fn;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const std::string GeneratorArgumentOptimization::KEY =
|
||||
"core-pythonic-generator-argument-opt";
|
||||
|
||||
void GeneratorArgumentOptimization::handle(CallInstr *v) {
|
||||
auto *M = v->getModule();
|
||||
auto *func = util::getFunc(v->getCallee());
|
||||
|
||||
if (isSum(func) && v->numArgs() == 2) {
|
||||
auto *call = cast<CallInstr>(v->front());
|
||||
if (!call)
|
||||
return;
|
||||
|
||||
auto *gen = util::getFunc(call->getCallee());
|
||||
auto *start = v->back();
|
||||
|
||||
if (auto *fn = genToSum(cast<BodiedFunc>(gen), start->getType(), v->getType())) {
|
||||
std::vector<Value *> args(call->begin(), call->end());
|
||||
args.push_back(start);
|
||||
v->replaceAll(util::call(fn, args));
|
||||
}
|
||||
} else {
|
||||
bool any = isAny(func), all = isAll(func);
|
||||
if (!(any || all) || v->numArgs() != 1 || !v->getType()->is(M->getBoolType()))
|
||||
return;
|
||||
|
||||
auto *call = cast<CallInstr>(v->front());
|
||||
if (!call)
|
||||
return;
|
||||
|
||||
auto *gen = util::getFunc(call->getCallee());
|
||||
|
||||
if (auto *fn = genToAnyAll(cast<BodiedFunc>(gen), any)) {
|
||||
std::vector<Value *> args(call->begin(), call->end());
|
||||
v->replaceAll(util::call(fn, args));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace pythonic
|
||||
} // namespace transform
|
||||
} // namespace ir
|
||||
} // namespace codon
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "codon/cir/transform/pass.h"
|
||||
|
||||
namespace codon {
|
||||
namespace ir {
|
||||
namespace transform {
|
||||
namespace pythonic {
|
||||
|
||||
/// Pass to optimize passing a generator to some built-in functions
|
||||
/// like sum(), any() or all(), which will be converted to regular
|
||||
/// for-loops.
|
||||
class GeneratorArgumentOptimization : public OperatorPass {
|
||||
public:
|
||||
static const std::string KEY;
|
||||
std::string getKey() const override { return KEY; }
|
||||
void handle(CallInstr *v) override;
|
||||
};
|
||||
|
||||
} // namespace pythonic
|
||||
} // namespace transform
|
||||
} // namespace ir
|
||||
} // namespace codon
|
|
@ -6,12 +6,12 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "codon/parser/cache.h"
|
||||
#include "codon/cir/module.h"
|
||||
#include "codon/cir/util/irtools.h"
|
||||
#include "codon/cir/util/iterators.h"
|
||||
#include "codon/cir/util/visitor.h"
|
||||
#include "codon/cir/value.h"
|
||||
#include "codon/parser/cache.h"
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace codon {
|
||||
|
|
|
@ -8,10 +8,10 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/parser/ast.h"
|
||||
#include "codon/cir/base.h"
|
||||
#include "codon/cir/util/packs.h"
|
||||
#include "codon/cir/util/visitor.h"
|
||||
#include "codon/parser/ast.h"
|
||||
#include <fmt/format.h>
|
||||
#include <fmt/ostream.h>
|
||||
|
||||
|
|
|
@ -7,12 +7,12 @@
|
|||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/compiler/error.h"
|
||||
#include "codon/dsl/plugins.h"
|
||||
#include "codon/parser/cache.h"
|
||||
#include "codon/cir/llvm/llvisitor.h"
|
||||
#include "codon/cir/module.h"
|
||||
#include "codon/cir/transform/manager.h"
|
||||
#include "codon/compiler/error.h"
|
||||
#include "codon/dsl/plugins.h"
|
||||
#include "codon/parser/cache.h"
|
||||
|
||||
namespace codon {
|
||||
|
||||
|
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
#include "engine.h"
|
||||
|
||||
#include "codon/compiler/memory_manager.h"
|
||||
#include "codon/cir/llvm/optimize.h"
|
||||
#include "codon/compiler/memory_manager.h"
|
||||
|
||||
namespace codon {
|
||||
namespace jit {
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/compiler/debug_listener.h"
|
||||
#include "codon/cir/llvm/llvm.h"
|
||||
#include "codon/compiler/debug_listener.h"
|
||||
|
||||
namespace codon {
|
||||
namespace jit {
|
||||
|
|
|
@ -7,14 +7,14 @@
|
|||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/cir/llvm/llvisitor.h"
|
||||
#include "codon/cir/transform/manager.h"
|
||||
#include "codon/cir/var.h"
|
||||
#include "codon/compiler/compiler.h"
|
||||
#include "codon/compiler/engine.h"
|
||||
#include "codon/compiler/error.h"
|
||||
#include "codon/parser/cache.h"
|
||||
#include "codon/runtime/lib.h"
|
||||
#include "codon/cir/llvm/llvisitor.h"
|
||||
#include "codon/cir/transform/manager.h"
|
||||
#include "codon/cir/var.h"
|
||||
|
||||
#include "codon/compiler/jit_extern.h"
|
||||
|
||||
|
|
|
@ -2,10 +2,10 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "codon/parser/cache.h"
|
||||
#include "codon/cir/cir.h"
|
||||
#include "codon/cir/transform/manager.h"
|
||||
#include "codon/cir/transform/pass.h"
|
||||
#include "codon/parser/cache.h"
|
||||
#include "llvm/Passes/PassBuilder.h"
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
|
|
@ -7,9 +7,9 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/cir/util/iterators.h"
|
||||
#include "codon/compiler/error.h"
|
||||
#include "codon/dsl/dsl.h"
|
||||
#include "codon/cir/util/iterators.h"
|
||||
#include "llvm/Support/DynamicLibrary.h"
|
||||
|
||||
namespace codon {
|
||||
|
|
|
@ -8,10 +8,10 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/cir/cir.h"
|
||||
#include "codon/parser/ast.h"
|
||||
#include "codon/parser/common.h"
|
||||
#include "codon/parser/ctx.h"
|
||||
#include "codon/cir/cir.h"
|
||||
|
||||
#define FILE_GENERATED "<generated>"
|
||||
#define MODULE_MAIN "__main__"
|
||||
|
|
|
@ -207,9 +207,12 @@ std::string library_path() {
|
|||
|
||||
namespace {
|
||||
|
||||
void addPath(std::vector<std::string> &paths, const std::string &path) {
|
||||
if (llvm::sys::fs::exists(path))
|
||||
bool addPath(std::vector<std::string> &paths, const std::string &path) {
|
||||
if (llvm::sys::fs::exists(path)) {
|
||||
paths.push_back(getAbsolutePath(path));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> getStdLibPaths(const std::string &argv0,
|
||||
|
@ -244,7 +247,9 @@ ImportFile getRoot(const std::string argv0, const std::vector<std::string> &plug
|
|||
}
|
||||
if (!isStdLib && startswith(s, module0Root))
|
||||
root = module0Root;
|
||||
const std::string ext = ".codon";
|
||||
std::string ext = ".codon";
|
||||
if (!((root.empty() || startswith(s, root)) && endswith(s, ext)))
|
||||
ext = ".py";
|
||||
seqassertn((root.empty() || startswith(s, root)) && endswith(s, ext),
|
||||
"bad path substitution: {}, {}", s, root);
|
||||
auto module = s.substr(root.size() + 1, s.size() - root.size() - ext.size() - 1);
|
||||
|
@ -280,6 +285,14 @@ std::shared_ptr<ImportFile> getImportFile(const std::string &argv0,
|
|||
path = llvm::SmallString<128>(parentRelativeTo);
|
||||
llvm::sys::path::append(path, what, "__init__.codon");
|
||||
addPath(paths, std::string(path));
|
||||
|
||||
path = llvm::SmallString<128>(parentRelativeTo);
|
||||
llvm::sys::path::append(path, what);
|
||||
llvm::sys::path::replace_extension(path, "py");
|
||||
addPath(paths, std::string(path));
|
||||
path = llvm::SmallString<128>(parentRelativeTo);
|
||||
llvm::sys::path::append(path, what, "__init__.py");
|
||||
addPath(paths, std::string(path));
|
||||
}
|
||||
}
|
||||
for (auto &p : getStdLibPaths(argv0, plugins)) {
|
||||
|
|
|
@ -14,6 +14,10 @@ using namespace codon::error;
|
|||
namespace codon::ast {
|
||||
|
||||
void SimplifyVisitor::visit(IdExpr *expr) {
|
||||
if (startswith(expr->value, TYPE_TUPLE)) {
|
||||
expr->markType();
|
||||
return;
|
||||
}
|
||||
auto val = ctx->findDominatingBinding(expr->value);
|
||||
if (!val)
|
||||
E(Error::ID_NOT_FOUND, expr, expr->value);
|
||||
|
|
|
@ -53,15 +53,19 @@ void SimplifyVisitor::visit(GeneratorExpr *expr) {
|
|||
|
||||
auto loops = clone_nop(expr->loops); // Clone as loops will be modified
|
||||
|
||||
std::string optimizeVar;
|
||||
if (expr->kind == GeneratorExpr::ListGenerator && loops.size() == 1 &&
|
||||
loops[0].conds.empty()) {
|
||||
// List comprehension optimization:
|
||||
// Use `iter.__len__()` when creating list if there is a single for loop
|
||||
// without any if conditions in the comprehension
|
||||
optimizeVar = ctx->cache->getTemporaryVar("i");
|
||||
stmts.push_back(transform(N<AssignStmt>(N<IdExpr>(optimizeVar), loops[0].gen)));
|
||||
loops[0].gen = N<IdExpr>(optimizeVar);
|
||||
// List comprehension optimization:
|
||||
// Use `iter.__len__()` when creating list if there is a single for loop
|
||||
// without any if conditions in the comprehension
|
||||
bool canOptimize = expr->kind == GeneratorExpr::ListGenerator && loops.size() == 1 &&
|
||||
loops[0].conds.empty();
|
||||
if (canOptimize) {
|
||||
auto iter = transform(loops[0].gen);
|
||||
IdExpr *id;
|
||||
if (iter->getCall() && (id = iter->getCall()->expr->getId())) {
|
||||
// Turn off this optimization for static items
|
||||
canOptimize &= !startswith(id->value, "std.internal.types.range.staticrange");
|
||||
canOptimize &= !startswith(id->value, "statictuple");
|
||||
}
|
||||
}
|
||||
|
||||
SuiteStmt *prev = nullptr;
|
||||
|
@ -72,16 +76,32 @@ void SimplifyVisitor::visit(GeneratorExpr *expr) {
|
|||
if (expr->kind == GeneratorExpr::ListGenerator) {
|
||||
// List comprehensions
|
||||
std::vector<ExprPtr> args;
|
||||
if (!optimizeVar.empty()) {
|
||||
// Use special List.__init__(bool, [optimizeVar]) constructor
|
||||
args = {N<BoolExpr>(true), N<IdExpr>(optimizeVar)};
|
||||
}
|
||||
stmts.push_back(
|
||||
transform(N<AssignStmt>(clone(var), N<CallExpr>(N<IdExpr>("List"), args))));
|
||||
prev->stmts.push_back(
|
||||
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "append"), clone(expr->expr))));
|
||||
stmts.push_back(transform(suite));
|
||||
resultExpr = N<StmtExpr>(stmts, transform(var));
|
||||
auto noOptStmt =
|
||||
N<SuiteStmt>(N<AssignStmt>(clone(var), N<CallExpr>(N<IdExpr>("List"))), suite);
|
||||
if (canOptimize) {
|
||||
seqassert(suite->getSuite() && !suite->getSuite()->stmts.empty() &&
|
||||
CAST(suite->getSuite()->stmts[0], ForStmt),
|
||||
"bad comprehension transformation");
|
||||
auto optimizeVar = ctx->cache->getTemporaryVar("i");
|
||||
auto optSuite = clone(suite);
|
||||
CAST(optSuite->getSuite()->stmts[0], ForStmt)->iter = N<IdExpr>(optimizeVar);
|
||||
|
||||
auto optStmt = N<SuiteStmt>(
|
||||
N<AssignStmt>(N<IdExpr>(optimizeVar), clone(expr->loops[0].gen)),
|
||||
N<AssignStmt>(
|
||||
clone(var),
|
||||
N<CallExpr>(N<IdExpr>("List"),
|
||||
N<CallExpr>(N<DotExpr>(N<IdExpr>(optimizeVar), "__len__")))),
|
||||
optSuite);
|
||||
resultExpr = transform(
|
||||
N<IfExpr>(N<CallExpr>(N<IdExpr>("hasattr"), clone(expr->loops[0].gen),
|
||||
N<StringExpr>("__len__")),
|
||||
N<StmtExpr>(optStmt, clone(var)), N<StmtExpr>(noOptStmt, var)));
|
||||
} else {
|
||||
resultExpr = transform(N<StmtExpr>(noOptStmt, var));
|
||||
}
|
||||
} else if (expr->kind == GeneratorExpr::SetGenerator) {
|
||||
// Set comprehensions
|
||||
stmts.push_back(
|
||||
|
@ -94,7 +114,16 @@ void SimplifyVisitor::visit(GeneratorExpr *expr) {
|
|||
// Generators: converted to lambda functions that yield the target expression
|
||||
prev->stmts.push_back(N<YieldStmt>(clone(expr->expr)));
|
||||
stmts.push_back(suite);
|
||||
resultExpr = N<CallExpr>(N<DotExpr>(N<CallExpr>(makeAnonFn(stmts)), "__iter__"));
|
||||
|
||||
auto anon = makeAnonFn(stmts);
|
||||
if (auto call = anon->getCall()) {
|
||||
seqassert(!call->args.empty() && call->args.back().value->getEllipsis(),
|
||||
"bad lambda: {}", *call);
|
||||
call->args.pop_back();
|
||||
} else {
|
||||
anon = N<CallExpr>(anon);
|
||||
}
|
||||
resultExpr = anon;
|
||||
}
|
||||
std::swap(avoidDomination, ctx->avoidDomination);
|
||||
}
|
||||
|
|
|
@ -322,8 +322,8 @@ ExprPtr SimplifyVisitor::makeAnonFn(std::vector<StmtPtr> suite,
|
|||
prependStmts->push_back(fs->stmts[0]);
|
||||
for (StmtPtr s = fs->stmts[1]; s;) {
|
||||
if (auto suite = s->getSuite()) {
|
||||
// Suites can only occur when __internal__.undef is inserted for a partial call
|
||||
// argument. Extract __internal__.undef checks and prepend them
|
||||
// Suites can only occur when captures are inserted for a partial call
|
||||
// argument.
|
||||
seqassert(suite->stmts.size() == 2, "invalid function transform");
|
||||
prependStmts->push_back(suite->stmts[0]);
|
||||
s = suite->stmts[1];
|
||||
|
|
|
@ -95,8 +95,12 @@ void SimplifyVisitor::visit(IndexExpr *expr) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Ignore it. Already transformed. Sometimes called again
|
||||
/// during class extension.
|
||||
void SimplifyVisitor::visit(InstantiateExpr *expr) {}
|
||||
/// Already transformed. Sometimes needed again
|
||||
/// for identifier analysis.
|
||||
void SimplifyVisitor::visit(InstantiateExpr *expr) {
|
||||
transformType(expr->typeExpr);
|
||||
for (auto &tp : expr->typeParams)
|
||||
transform(tp, true);
|
||||
}
|
||||
|
||||
} // namespace codon::ast
|
||||
|
|
|
@ -7,11 +7,11 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/cir/transform/parallel/schedule.h"
|
||||
#include "codon/cir/util/cloning.h"
|
||||
#include "codon/parser/ast.h"
|
||||
#include "codon/parser/common.h"
|
||||
#include "codon/parser/visitors/translate/translate_ctx.h"
|
||||
#include "codon/cir/transform/parallel/schedule.h"
|
||||
#include "codon/cir/util/cloning.h"
|
||||
|
||||
using codon::ir::cast;
|
||||
using codon::ir::transform::parallel::OMPSched;
|
||||
|
|
|
@ -8,12 +8,12 @@
|
|||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/cir/cir.h"
|
||||
#include "codon/parser/ast.h"
|
||||
#include "codon/parser/cache.h"
|
||||
#include "codon/parser/common.h"
|
||||
#include "codon/parser/visitors/translate/translate_ctx.h"
|
||||
#include "codon/parser/visitors/visitor.h"
|
||||
#include "codon/cir/cir.h"
|
||||
|
||||
namespace codon::ast {
|
||||
|
||||
|
|
|
@ -7,11 +7,11 @@
|
|||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/cir/cir.h"
|
||||
#include "codon/cir/types/types.h"
|
||||
#include "codon/parser/cache.h"
|
||||
#include "codon/parser/common.h"
|
||||
#include "codon/parser/ctx.h"
|
||||
#include "codon/cir/cir.h"
|
||||
#include "codon/cir/types/types.h"
|
||||
|
||||
namespace codon::ast {
|
||||
|
||||
|
|
|
@ -909,22 +909,30 @@ void TypecheckVisitor::addFunctionGenerics(const FuncType *t) {
|
|||
for (auto parent = t->funcParent; parent;) {
|
||||
if (auto f = parent->getFunc()) {
|
||||
// Add parent function generics
|
||||
for (auto &g : f->funcGenerics)
|
||||
for (auto &g : f->funcGenerics) {
|
||||
// LOG(" -> {} := {}", g.name, g.type->debugString(true));
|
||||
ctx->add(TypecheckItem::Type, g.name, g.type);
|
||||
}
|
||||
parent = f->funcParent;
|
||||
} else {
|
||||
// Add parent class generics
|
||||
seqassert(parent->getClass(), "not a class: {}", parent);
|
||||
for (auto &g : parent->getClass()->generics)
|
||||
for (auto &g : parent->getClass()->generics) {
|
||||
// LOG(" => {} := {}", g.name, g.type->debugString(true));
|
||||
ctx->add(TypecheckItem::Type, g.name, g.type);
|
||||
for (auto &g : parent->getClass()->hiddenGenerics)
|
||||
}
|
||||
for (auto &g : parent->getClass()->hiddenGenerics) {
|
||||
// LOG(" :> {} := {}", g.name, g.type->debugString(true));
|
||||
ctx->add(TypecheckItem::Type, g.name, g.type);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Add function generics
|
||||
for (auto &g : t->funcGenerics)
|
||||
for (auto &g : t->funcGenerics) {
|
||||
// LOG(" >> {} := {}", g.name, g.type->debugString(true));
|
||||
ctx->add(TypecheckItem::Type, g.name, g.type);
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a partial type `Partial.N<mask>` for a given function.
|
||||
|
|
|
@ -155,7 +155,8 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
|
|||
// Generalize generics and remove them from the context
|
||||
for (const auto &g : generics) {
|
||||
for (auto &u : g->getUnbounds())
|
||||
u->getUnbound()->kind = LinkType::Generic;
|
||||
if (u->getUnbound())
|
||||
u->getUnbound()->kind = LinkType::Generic;
|
||||
}
|
||||
|
||||
// Construct the type
|
||||
|
@ -163,8 +164,9 @@ void TypecheckVisitor::visit(FunctionStmt *stmt) {
|
|||
baseType, ctx->cache->functions[stmt->name].ast.get(), explicits);
|
||||
|
||||
funcTyp->setSrcInfo(getSrcInfo());
|
||||
if (isClassMember && stmt->attributes.has(Attr::Method))
|
||||
if (isClassMember && stmt->attributes.has(Attr::Method)) {
|
||||
funcTyp->funcParent = ctx->find(stmt->attributes.parentClass)->type;
|
||||
}
|
||||
funcTyp =
|
||||
std::static_pointer_cast<FuncType>(funcTyp->generalize(ctx->typecheckLevel));
|
||||
|
||||
|
|
|
@ -7,11 +7,11 @@
|
|||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/cir/types/types.h"
|
||||
#include "codon/parser/ast.h"
|
||||
#include "codon/parser/common.h"
|
||||
#include "codon/parser/visitors/simplify/simplify.h"
|
||||
#include "codon/parser/visitors/typecheck/typecheck.h"
|
||||
#include "codon/cir/types/types.h"
|
||||
|
||||
using fmt::format;
|
||||
using namespace codon::error;
|
||||
|
|
|
@ -636,7 +636,7 @@ ExprPtr TypecheckVisitor::transformBinaryInplaceMagic(BinaryExpr *expr, bool isA
|
|||
|
||||
// In-place operations: check if `lhs.__iop__(lhs, rhs)` exists
|
||||
if (!method && expr->inPlace) {
|
||||
method = findBestMethod(lt, format("__i{}__", magic), {lt, rt});
|
||||
method = findBestMethod(lt, format("__i{}__", magic), {expr->lexpr, expr->rexpr});
|
||||
}
|
||||
|
||||
if (method)
|
||||
|
@ -667,11 +667,11 @@ ExprPtr TypecheckVisitor::transformBinaryMagic(BinaryExpr *expr) {
|
|||
}
|
||||
|
||||
// Normal operations: check if `lhs.__magic__(lhs, rhs)` exists
|
||||
auto method = findBestMethod(lt, format("__{}__", magic), {lt, rt});
|
||||
auto method = findBestMethod(lt, format("__{}__", magic), {expr->lexpr, expr->rexpr});
|
||||
|
||||
// Right-side magics: check if `rhs.__rmagic__(rhs, lhs)` exists
|
||||
if (!method &&
|
||||
(method = findBestMethod(rt, format("__{}__", rightMagic), {rt, lt}))) {
|
||||
if (!method && (method = findBestMethod(rt, format("__{}__", rightMagic),
|
||||
{expr->rexpr, expr->lexpr}))) {
|
||||
swap(expr->lexpr, expr->rexpr);
|
||||
}
|
||||
|
||||
|
|
|
@ -187,6 +187,19 @@ TypecheckVisitor::findBestMethod(const ClassTypePtr &typ, const std::string &mem
|
|||
return m.empty() ? nullptr : m[0];
|
||||
}
|
||||
|
||||
/// Select the best method indicated of an object that matches the given argument
|
||||
/// types. See @c findMatchingMethods for details.
|
||||
types::FuncTypePtr TypecheckVisitor::findBestMethod(const ClassTypePtr &typ,
|
||||
const std::string &member,
|
||||
const std::vector<ExprPtr> &args) {
|
||||
std::vector<CallExpr::Arg> callArgs;
|
||||
for (auto &a : args)
|
||||
callArgs.push_back({"", a});
|
||||
auto methods = ctx->findMethod(typ->name, member, false);
|
||||
auto m = findMatchingMethods(typ, methods, callArgs);
|
||||
return m.empty() ? nullptr : m[0];
|
||||
}
|
||||
|
||||
/// Select the best method among the provided methods given the list of arguments.
|
||||
/// See @c reorderNamedArgs for details.
|
||||
std::vector<types::FuncTypePtr>
|
||||
|
|
|
@ -207,6 +207,9 @@ private:
|
|||
types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ,
|
||||
const std::string &member,
|
||||
const std::vector<types::TypePtr> &args);
|
||||
types::FuncTypePtr findBestMethod(const types::ClassTypePtr &typ,
|
||||
const std::string &member,
|
||||
const std::vector<ExprPtr> &args);
|
||||
std::vector<types::FuncTypePtr>
|
||||
findMatchingMethods(const types::ClassTypePtr &typ,
|
||||
const std::vector<types::FuncTypePtr> &methods,
|
||||
|
|
|
@ -14,8 +14,8 @@ in mind.
|
|||
- **Strings:** Codon currently uses ASCII strings unlike
|
||||
Python's unicode strings.
|
||||
|
||||
- **Dictionaries:** Codon's dictionary type is not sorted
|
||||
internally, unlike Python's.
|
||||
- **Dictionaries:** Codon's dictionary type does not preserve
|
||||
insertion order, unlike Python's as of 3.6.
|
||||
|
||||
# Type checking
|
||||
|
||||
|
|
|
@ -17,13 +17,21 @@ sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
|
|||
from .codon_jit import JITWrapper, JITError, codon_library
|
||||
|
||||
if "CODON_PATH" not in os.environ:
|
||||
codon_path = []
|
||||
codon_lib_path = codon_library()
|
||||
if not codon_lib_path:
|
||||
if codon_lib_path:
|
||||
codon_path.append(Path(codon_lib_path).parent / "stdlib")
|
||||
codon_path.append(
|
||||
Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib"
|
||||
)
|
||||
for path in codon_path:
|
||||
if path.exists():
|
||||
os.environ["CODON_PATH"] = str(path.resolve())
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Cannot locate Codon. Please install Codon or set CODON_PATH."
|
||||
)
|
||||
codon_path = (Path(codon_lib_path).parent / "stdlib").resolve()
|
||||
os.environ["CODON_PATH"] = str(codon_path)
|
||||
|
||||
pod_conversions = {
|
||||
type(None): "pyobj",
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
|
||||
|
||||
@tuple
|
||||
@tuple(container=False) # disallow default __getitem__
|
||||
class Vec[T, N: Static[int]]:
|
||||
ZERO_16x8i = Vec[u8,16](u8(0))
|
||||
FF_16x8i = Vec[u8,16](u8(0xff))
|
||||
|
@ -307,6 +307,10 @@ class Vec[T, N: Static[int]]:
|
|||
else:
|
||||
return "?"
|
||||
|
||||
def scatter(self: Vec[T, N]) -> List[T]:
|
||||
return [self[i] for i in staticrange(N)]
|
||||
|
||||
|
||||
u8x16 = Vec[u8, 16]
|
||||
u8x32 = Vec[u8, 32]
|
||||
f32x8 = Vec[f32, 8]
|
||||
|
|
|
@ -248,19 +248,26 @@ def round(x, n=0):
|
|||
nx = float.__pow__(10.0, n)
|
||||
return float.__round__(x * nx) / nx
|
||||
|
||||
def sum(xi):
|
||||
"""
|
||||
Return the sum of the items added together from xi
|
||||
"""
|
||||
x = iter(xi)
|
||||
if not x.done():
|
||||
s = x.next()
|
||||
while not x.done():
|
||||
s += x.next()
|
||||
x.destroy()
|
||||
return s
|
||||
def _sum_start(x, start):
|
||||
if isinstance(x.__iter__(), Generator[float]) and isinstance(start, int):
|
||||
return float(start)
|
||||
else:
|
||||
x.destroy()
|
||||
return start
|
||||
|
||||
def sum(x, start=0):
|
||||
"""
|
||||
Return the sum of the items added together from x
|
||||
"""
|
||||
s = _sum_start(x, start)
|
||||
|
||||
for a in x:
|
||||
# don't use += to avoid calling iadd
|
||||
if isinstance(a, bool):
|
||||
s = s + (1 if a else 0)
|
||||
else:
|
||||
s = s + a
|
||||
|
||||
return s
|
||||
|
||||
def repr(x):
|
||||
"""Return the string representation of x"""
|
||||
|
|
|
@ -24,13 +24,6 @@ class List:
|
|||
self.arr = Array[T](capacity)
|
||||
self.len = 0
|
||||
|
||||
def __init__(self, dummy: bool, other):
|
||||
"""Dummy __init__ used for list comprehension optimization"""
|
||||
if hasattr(other, "__len__"):
|
||||
self.__init__(other.__len__())
|
||||
else:
|
||||
self.__init__()
|
||||
|
||||
def __init__(self, arr: Array[T], len: int):
|
||||
self.arr = arr
|
||||
self.len = len
|
||||
|
|
|
@ -85,6 +85,70 @@ def test_map_filter():
|
|||
|
||||
assert list(filter(lambda i: i%2 == 0, map(lambda i: i*i, range(10)))) == [0, 4, 16, 36, 64]
|
||||
|
||||
@test
|
||||
def test_gen_builtins():
|
||||
assert sum([1, 2, 3]) == 6
|
||||
assert sum([1, 2, 3], 0.5) == 6.5
|
||||
assert sum([True, False, True, False, True], 0.5) == 3.5
|
||||
assert sum(List[float]()) == 0
|
||||
assert sum(i/2 for i in range(10)) == 22.5
|
||||
|
||||
def g1():
|
||||
yield 1.5
|
||||
yield 2.5
|
||||
return
|
||||
yield 3.5
|
||||
|
||||
assert sum(g1(), 10) == 14.0
|
||||
|
||||
def g2():
|
||||
yield True
|
||||
yield False
|
||||
yield True
|
||||
|
||||
assert sum(g2()) == 2
|
||||
|
||||
class A:
|
||||
iadd_count = 0
|
||||
n: int
|
||||
|
||||
def __init__(self, n):
|
||||
self.n = n
|
||||
|
||||
def __add__(self, other):
|
||||
return A(self.n + other.n)
|
||||
|
||||
def __iadd__(self, other):
|
||||
A.iadd_count += 1
|
||||
self.n += other.n
|
||||
return self
|
||||
|
||||
assert sum((A(i) for i in range(5)), A(100)).n == 110
|
||||
assert A.iadd_count == 0
|
||||
|
||||
def g3(a, b):
|
||||
for i in range(10):
|
||||
yield a
|
||||
yield b
|
||||
|
||||
assert all([True, True])
|
||||
assert all(i for i in range(0))
|
||||
assert not all([True, False])
|
||||
assert all(List[str]())
|
||||
assert all(g3(True, True))
|
||||
assert not all(g3(True, False))
|
||||
assert not all(g3(False, True))
|
||||
assert not all(g3(False, False))
|
||||
|
||||
assert any([True, True])
|
||||
assert not any(i for i in range(0))
|
||||
assert not any([False, False])
|
||||
assert not any(List[bool]())
|
||||
assert any(g3(True, True))
|
||||
assert any(g3(True, False))
|
||||
assert any(g3(False, True))
|
||||
assert not any(g3(False, False))
|
||||
|
||||
@test
|
||||
def test_int_format():
|
||||
n = 0
|
||||
|
@ -269,6 +333,7 @@ def test_files(open_fn):
|
|||
|
||||
test_min_max()
|
||||
test_map_filter()
|
||||
test_gen_builtins()
|
||||
test_int_format()
|
||||
test_reversed()
|
||||
test_divmod()
|
||||
|
|
|
@ -15,15 +15,15 @@
|
|||
#include <unistd.h>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/compiler/compiler.h"
|
||||
#include "codon/compiler/error.h"
|
||||
#include "codon/parser/common.h"
|
||||
#include "codon/cir/analyze/dataflow/capture.h"
|
||||
#include "codon/cir/analyze/dataflow/reaching.h"
|
||||
#include "codon/cir/util/inlining.h"
|
||||
#include "codon/cir/util/irtools.h"
|
||||
#include "codon/cir/util/operator.h"
|
||||
#include "codon/cir/util/outlining.h"
|
||||
#include "codon/compiler/compiler.h"
|
||||
#include "codon/compiler/error.h"
|
||||
#include "codon/parser/common.h"
|
||||
#include "codon/util/common.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
|
|
@ -157,12 +157,10 @@ print d #: {0: 9}
|
|||
#%% comprehension_opt,barebones
|
||||
@extend
|
||||
class List:
|
||||
def __init__(self, dummy: bool, other):
|
||||
if hasattr(other, '__len__'):
|
||||
print 'optimize', other.__len__()
|
||||
self.__init__(other.__len__())
|
||||
else:
|
||||
self.__init__()
|
||||
def __init__(self, cap: int):
|
||||
print 'optimize', cap
|
||||
self.arr = Array[T](cap)
|
||||
self.len = 0
|
||||
def foo():
|
||||
yield 0
|
||||
yield 1
|
||||
|
|
|
@ -457,7 +457,7 @@ def test_omp_reductions():
|
|||
@par
|
||||
for i in L[1:1001]:
|
||||
c += f32(i)
|
||||
assert c == sum(f32(i) for i in range(1001))
|
||||
assert c == sum((f32(i) for i in range(1001)), f32(0))
|
||||
|
||||
c = f32(1.)
|
||||
@par
|
||||
|
|
Loading…
Reference in New Issue