mirror of https://github.com/exaloop/codon.git
Compare commits
16 Commits
Author | SHA1 | Date |
---|---|---|
|
dcb41dcfc9 | |
|
c1dae7d87d | |
|
984974b40d | |
|
915cb4e9f0 | |
|
ce5c49edb5 | |
|
59f5bbb73b | |
|
93fb3d53e3 | |
|
b3f6c12d57 | |
|
b17d21513d | |
|
d035f1dc97 | |
|
dc5e5ac7a6 | |
|
01a7503762 | |
|
f1ab7116d8 | |
|
b58b1ee767 | |
|
56c00d36c2 | |
|
4521182aa8 |
|
@ -187,7 +187,7 @@ jobs:
|
|||
- name: Prepare Artifacts
|
||||
run: |
|
||||
cp -rf codon-deploy/python/dist .
|
||||
rm -rf codon-deploy/lib/libfmt.a codon-deploy/lib/pkgconfig codon-deploy/lib/cmake codon-deploy/python/codon.egg-info codon-deploy/python/dist codon-deploy/python/build
|
||||
rm -rf codon-deploy/lib/libfmt.a codon-deploy/lib/pkgconfig codon-deploy/lib/cmake codon-deploy/python/codon_jit.egg-info codon-deploy/python/build
|
||||
tar -czf ${CODON_BUILD_ARCHIVE} codon-deploy
|
||||
du -sh codon-deploy
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
cmake_minimum_required(VERSION 3.14)
|
||||
project(
|
||||
Codon
|
||||
VERSION "0.18.0"
|
||||
VERSION "0.18.2"
|
||||
HOMEPAGE_URL "https://github.com/exaloop/codon"
|
||||
DESCRIPTION "high-performance, extensible Python compiler")
|
||||
set(CODON_JIT_PYTHON_VERSION "0.3.0")
|
||||
set(CODON_JIT_PYTHON_VERSION "0.3.2")
|
||||
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in"
|
||||
"${PROJECT_SOURCE_DIR}/codon/config/config.h")
|
||||
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in"
|
||||
|
@ -48,10 +48,8 @@ include(${CMAKE_SOURCE_DIR}/cmake/deps.cmake)
|
|||
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
|
||||
if(APPLE)
|
||||
set(CMAKE_INSTALL_RPATH "@loader_path;@loader_path/../lib/codon")
|
||||
set(STATIC_LIBCPP "")
|
||||
else()
|
||||
set(CMAKE_INSTALL_RPATH "$ORIGIN:$ORIGIN/../lib/codon")
|
||||
set(STATIC_LIBCPP "-static-libstdc++")
|
||||
endif()
|
||||
|
||||
add_executable(peg2cpp codon/util/peg2cpp.cpp)
|
||||
|
@ -138,7 +136,7 @@ target_include_directories(codonrt PRIVATE ${backtrace_SOURCE_DIR}
|
|||
${highway_SOURCE_DIR}
|
||||
"${gc_SOURCE_DIR}/include"
|
||||
"${fast_float_SOURCE_DIR}/include" runtime)
|
||||
target_link_libraries(codonrt PRIVATE fmt omp backtrace ${STATIC_LIBCPP} LLVMSupport)
|
||||
target_link_libraries(codonrt PRIVATE fmt omp backtrace LLVMSupport)
|
||||
if(APPLE)
|
||||
target_link_libraries(
|
||||
codonrt
|
||||
|
@ -434,11 +432,7 @@ llvm_map_components_to_libnames(
|
|||
TransformUtils
|
||||
Vectorize
|
||||
Passes)
|
||||
if(APPLE)
|
||||
target_link_libraries(codonc PRIVATE ${LLVM_LIBS} fmt dl codonrt)
|
||||
else()
|
||||
target_link_libraries(codonc PRIVATE ${STATIC_LIBCPP} ${LLVM_LIBS} fmt dl codonrt)
|
||||
endif()
|
||||
target_link_libraries(codonc PRIVATE ${LLVM_LIBS} fmt dl codonrt)
|
||||
|
||||
# Gather headers
|
||||
add_custom_target(
|
||||
|
@ -482,13 +476,13 @@ add_dependencies(libs codonrt codonc)
|
|||
|
||||
# Codon command-line tool
|
||||
add_executable(codon codon/app/main.cpp)
|
||||
target_link_libraries(codon PUBLIC ${STATIC_LIBCPP} fmt codonc codon_jupyter Threads::Threads)
|
||||
target_link_libraries(codon PUBLIC fmt codonc codon_jupyter Threads::Threads)
|
||||
|
||||
# Codon test Download and unpack googletest at configure time
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
googletest
|
||||
URL https://github.com/google/googletest/archive/609281088cfefc76f9d0ce82e1ff6c30cc3591e5.zip
|
||||
URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip
|
||||
)
|
||||
# For Windows: Prevent overriding the parent project's compiler/linker settings
|
||||
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
|
||||
|
|
|
@ -149,7 +149,7 @@ print(total)
|
|||
```
|
||||
|
||||
Note that Codon automatically turns the `total += 1` statement in the loop body into an atomic
|
||||
reduction to avoid race conditions. Learn more in the [multitheading docs](advanced/parallel.md).
|
||||
reduction to avoid race conditions. Learn more in the [multithreading docs](https://docs.exaloop.io/codon/advanced/parallel).
|
||||
|
||||
Codon also supports writing and executing GPU kernels. Here's an example that computes the
|
||||
[Mandelbrot set](https://en.wikipedia.org/wiki/Mandelbrot_set):
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
set(CPM_DOWNLOAD_VERSION 0.32.3)
|
||||
set(CPM_DOWNLOAD_VERSION 0.40.8)
|
||||
set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
|
||||
if(NOT (EXISTS ${CPM_DOWNLOAD_LOCATION}))
|
||||
message(STATUS "Downloading CPM.cmake...")
|
||||
file(DOWNLOAD https://github.com/TheLartians/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake ${CPM_DOWNLOAD_LOCATION})
|
||||
file(DOWNLOAD https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake ${CPM_DOWNLOAD_LOCATION})
|
||||
endif()
|
||||
include(${CPM_DOWNLOAD_LOCATION})
|
||||
|
||||
|
@ -77,9 +77,9 @@ endif()
|
|||
|
||||
CPMAddPackage(
|
||||
NAME bdwgc
|
||||
GITHUB_REPOSITORY "ivmai/bdwgc"
|
||||
GITHUB_REPOSITORY "exaloop/bdwgc"
|
||||
VERSION 8.0.5
|
||||
GIT_TAG d0ba209660ea8c663e06d9a68332ba5f42da54ba
|
||||
GIT_TAG e16c67244aff26802203060422545d38305e0160
|
||||
EXCLUDE_FROM_ALL YES
|
||||
OPTIONS "CMAKE_POSITION_INDEPENDENT_CODE ON"
|
||||
"BUILD_SHARED_LIBS OFF"
|
||||
|
@ -169,7 +169,7 @@ if(NOT APPLE)
|
|||
CPMAddPackage(
|
||||
NAME openblas
|
||||
GITHUB_REPOSITORY "OpenMathLib/OpenBLAS"
|
||||
GIT_TAG v0.3.28
|
||||
GIT_TAG v0.3.29
|
||||
EXCLUDE_FROM_ALL YES
|
||||
OPTIONS "DYNAMIC_ARCH ON"
|
||||
"BUILD_TESTING OFF"
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/cir/util/format.h"
|
||||
#include "codon/compiler/compiler.h"
|
||||
#include "codon/compiler/error.h"
|
||||
#include "codon/compiler/jit.h"
|
||||
|
@ -87,7 +88,7 @@ void initLogFlags(const llvm::cl::opt<std::string> &log) {
|
|||
codon::getLogger().parse(std::string(d));
|
||||
}
|
||||
|
||||
enum BuildKind { LLVM, Bitcode, Object, Executable, Library, PyExtension, Detect };
|
||||
enum BuildKind { LLVM, Bitcode, Object, Executable, Library, PyExtension, Detect, CIR };
|
||||
enum OptMode { Debug, Release };
|
||||
enum Numerics { C, Python };
|
||||
} // namespace
|
||||
|
@ -333,6 +334,7 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
|
|||
clEnumValN(Executable, "exe", "Generate executable"),
|
||||
clEnumValN(Library, "lib", "Generate shared library"),
|
||||
clEnumValN(PyExtension, "pyext", "Generate Python extension module"),
|
||||
clEnumValN(CIR, "cir", "Generate Codon Intermediate Representation"),
|
||||
clEnumValN(Detect, "detect",
|
||||
"Detect output type based on output file extension")),
|
||||
llvm::cl::init(Detect));
|
||||
|
@ -372,6 +374,9 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
|
|||
case BuildKind::Detect:
|
||||
extension = "";
|
||||
break;
|
||||
case BuildKind::CIR:
|
||||
extension = ".cir";
|
||||
break;
|
||||
default:
|
||||
seqassertn(0, "unknown build kind");
|
||||
}
|
||||
|
@ -401,6 +406,11 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
|
|||
compiler->getLLVMVisitor()->writeToPythonExtension(*compiler->getCache()->pyModule,
|
||||
filename);
|
||||
break;
|
||||
case BuildKind::CIR: {
|
||||
std::ofstream out(filename);
|
||||
codon::ir::util::format(out, compiler->getModule());
|
||||
break;
|
||||
}
|
||||
case BuildKind::Detect:
|
||||
compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags);
|
||||
break;
|
||||
|
|
|
@ -402,7 +402,8 @@ struct ReductionIdentifier : public util::Operator {
|
|||
static void extractAssociativeOpChain(Value *v, const std::string &op,
|
||||
types::Type *type,
|
||||
std::vector<Value *> &result) {
|
||||
if (util::isCallOf(v, op, {type, type}, type, /*method=*/true)) {
|
||||
if (util::isCallOf(v, op, {type, nullptr}, type, /*method=*/true) ||
|
||||
util::isCallOf(v, op, {nullptr, type}, type, /*method=*/true)) {
|
||||
auto *call = cast<CallInstr>(v);
|
||||
extractAssociativeOpChain(call->front(), op, type, result);
|
||||
extractAssociativeOpChain(call->back(), op, type, result);
|
||||
|
@ -450,7 +451,8 @@ struct ReductionIdentifier : public util::Operator {
|
|||
|
||||
for (auto &rf : reductionFunctions) {
|
||||
if (rf.method) {
|
||||
if (!util::isCallOf(item, rf.name, {type, type}, type, /*method=*/true))
|
||||
if (!(util::isCallOf(item, rf.name, {type, nullptr}, type, /*method=*/true) ||
|
||||
util::isCallOf(item, rf.name, {nullptr, type}, type, /*method=*/true)))
|
||||
continue;
|
||||
} else {
|
||||
if (!util::isCallOf(item, rf.name,
|
||||
|
@ -464,8 +466,7 @@ struct ReductionIdentifier : public util::Operator {
|
|||
|
||||
if (rf.method) {
|
||||
std::vector<Value *> opChain;
|
||||
extractAssociativeOpChain(callRHS, rf.name, callRHS->front()->getType(),
|
||||
opChain);
|
||||
extractAssociativeOpChain(callRHS, rf.name, type, opChain);
|
||||
if (opChain.size() < 2)
|
||||
continue;
|
||||
|
||||
|
|
|
@ -38,16 +38,21 @@ bool isCallOf(const Value *value, const std::string &name,
|
|||
|
||||
unsigned i = 0;
|
||||
for (auto *arg : *call) {
|
||||
if (!arg->getType()->is(inputs[i++]))
|
||||
if (inputs[i] && !arg->getType()->is(inputs[i]))
|
||||
return false;
|
||||
++i;
|
||||
}
|
||||
|
||||
if (output && !value->getType()->is(output))
|
||||
return false;
|
||||
|
||||
if (method &&
|
||||
(inputs.empty() || !fn->getParentType() || !fn->getParentType()->is(inputs[0])))
|
||||
return false;
|
||||
if (method) {
|
||||
if (inputs.empty() || !fn->getParentType())
|
||||
return false;
|
||||
|
||||
if (inputs[0] && !fn->getParentType()->is(inputs[0]))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -263,21 +263,21 @@ ir::types::Type *JIT::PythonData::getCObjType(ir::Module *M) {
|
|||
return cobj;
|
||||
}
|
||||
|
||||
JITResult JIT::executeSafe(const std::string &code, const std::string &file, int line,
|
||||
bool debug) {
|
||||
JIT::JITResult JIT::executeSafe(const std::string &code, const std::string &file,
|
||||
int line, bool debug) {
|
||||
auto result = execute(code, file, line, debug);
|
||||
if (auto err = result.takeError()) {
|
||||
auto errorInfo = llvm::toString(std::move(err));
|
||||
return JITResult::error(errorInfo);
|
||||
}
|
||||
return JITResult::success(nullptr);
|
||||
return JITResult::success();
|
||||
}
|
||||
|
||||
JITResult JIT::executePython(const std::string &name,
|
||||
const std::vector<std::string> &types,
|
||||
const std::string &pyModule,
|
||||
const std::vector<std::string> &pyVars, void *arg,
|
||||
bool debug) {
|
||||
JIT::JITResult JIT::executePython(const std::string &name,
|
||||
const std::vector<std::string> &types,
|
||||
const std::string &pyModule,
|
||||
const std::vector<std::string> &pyVars, void *arg,
|
||||
bool debug) {
|
||||
auto key = buildKey(name, types);
|
||||
auto &cache = pydata->cache;
|
||||
auto it = cache.find(key);
|
||||
|
@ -322,26 +322,48 @@ JITResult JIT::executePython(const std::string &name,
|
|||
}
|
||||
}
|
||||
|
||||
JIT *jitInit(const std::string &name) {
|
||||
auto jit = new JIT(name);
|
||||
} // namespace jit
|
||||
} // namespace codon
|
||||
|
||||
void *jit_init(char *name) {
|
||||
auto jit = new codon::jit::JIT(std::string(name));
|
||||
llvm::cantFail(jit->init());
|
||||
return jit;
|
||||
}
|
||||
|
||||
JITResult jitExecutePython(JIT *jit, const std::string &name,
|
||||
const std::vector<std::string> &types,
|
||||
const std::string &pyModule,
|
||||
const std::vector<std::string> &pyVars, void *arg,
|
||||
bool debug) {
|
||||
return jit->executePython(name, types, pyModule, pyVars, arg, debug);
|
||||
void jit_exit(void *jit) { delete ((codon::jit::JIT *)jit); }
|
||||
|
||||
CJITResult jit_execute_python(void *jit, char *name, char **types, size_t types_size,
|
||||
char *pyModule, char **py_vars, size_t py_vars_size,
|
||||
void *arg, uint8_t debug) {
|
||||
std::vector<std::string> cppTypes;
|
||||
cppTypes.reserve(types_size);
|
||||
for (size_t i = 0; i < types_size; i++)
|
||||
cppTypes.emplace_back(types[i]);
|
||||
std::vector<std::string> cppPyVars;
|
||||
cppPyVars.reserve(py_vars_size);
|
||||
for (size_t i = 0; i < py_vars_size; i++)
|
||||
cppPyVars.emplace_back(py_vars[i]);
|
||||
auto t = ((codon::jit::JIT *)jit)
|
||||
->executePython(std::string(name), cppTypes, std::string(pyModule),
|
||||
cppPyVars, arg, bool(debug));
|
||||
void *result = t.result;
|
||||
char *message =
|
||||
t.message.empty() ? nullptr : strndup(t.message.c_str(), t.message.size());
|
||||
return {result, message};
|
||||
}
|
||||
|
||||
JITResult jitExecuteSafe(JIT *jit, const std::string &code, const std::string &file,
|
||||
int line, bool debug) {
|
||||
return jit->executeSafe(code, file, line, debug);
|
||||
CJITResult jit_execute_safe(void *jit, char *code, char *file, int32_t line,
|
||||
uint8_t debug) {
|
||||
auto t = ((codon::jit::JIT *)jit)
|
||||
->executeSafe(std::string(code), std::string(file), line, bool(debug));
|
||||
void *result = t.result;
|
||||
char *message =
|
||||
t.message.empty() ? nullptr : strndup(t.message.c_str(), t.message.size());
|
||||
return {result, message};
|
||||
}
|
||||
|
||||
std::string getJITLibrary() { return ast::library_path(); }
|
||||
|
||||
} // namespace jit
|
||||
} // namespace codon
|
||||
char *get_jit_library() {
|
||||
auto t = codon::ast::library_path();
|
||||
return strndup(t.c_str(), t.size());
|
||||
}
|
||||
|
|
|
@ -31,6 +31,15 @@ public:
|
|||
ir::types::Type *getCObjType(ir::Module *M);
|
||||
};
|
||||
|
||||
struct JITResult {
|
||||
void *result;
|
||||
std::string message;
|
||||
|
||||
operator bool() const { return message.empty(); }
|
||||
static JITResult success(void *result = nullptr) { return {result, ""}; }
|
||||
static JITResult error(const std::string &message) { return {nullptr, message}; }
|
||||
};
|
||||
|
||||
private:
|
||||
std::unique_ptr<Compiler> compiler;
|
||||
std::unique_ptr<Engine> engine;
|
||||
|
|
|
@ -2,35 +2,30 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace codon {
|
||||
namespace jit {
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
class JIT;
|
||||
|
||||
struct JITResult {
|
||||
struct CJITResult {
|
||||
void *result;
|
||||
std::string message;
|
||||
|
||||
operator bool() const { return message.empty(); }
|
||||
static JITResult success(void *result) { return {result, ""}; }
|
||||
static JITResult error(const std::string &message) { return {nullptr, message}; }
|
||||
char *error;
|
||||
};
|
||||
|
||||
JIT *jitInit(const std::string &name);
|
||||
void *jit_init(char *name);
|
||||
void jit_exit(void *jit);
|
||||
|
||||
JITResult jitExecutePython(JIT *jit, const std::string &name,
|
||||
const std::vector<std::string> &types,
|
||||
const std::string &pyModule,
|
||||
const std::vector<std::string> &pyVars, void *arg,
|
||||
bool debug);
|
||||
struct CJITResult jit_execute_python(void *jit, char *name, char **types,
|
||||
size_t types_size, char *pyModule, char **py_vars,
|
||||
size_t py_vars_size, void *arg, uint8_t debug);
|
||||
|
||||
JITResult jitExecuteSafe(JIT *jit, const std::string &code, const std::string &file,
|
||||
int line, bool debug);
|
||||
struct CJITResult jit_execute_safe(void *jit, char *code, char *file, int32_t line,
|
||||
uint8_t debug);
|
||||
|
||||
std::string getJITLibrary();
|
||||
char *get_jit_library();
|
||||
|
||||
} // namespace jit
|
||||
} // namespace codon
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -124,7 +124,7 @@ print(total)
|
|||
```
|
||||
|
||||
Note that Codon automatically turns the `total += 1` statement in the loop body into an atomic
|
||||
reduction to avoid race conditions. Learn more in the [multitheading docs](advanced/parallel.md).
|
||||
reduction to avoid race conditions. Learn more in the [multithreading docs](advanced/parallel.md).
|
||||
|
||||
Codon also supports writing and executing GPU kernels. Here's an example that computes the
|
||||
[Mandelbrot set](https://en.wikipedia.org/wiki/Mandelbrot_set):
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
||||
|
||||
__all__ = ["jit", "convert", "JITError"]
|
||||
__all__ = [
|
||||
"jit", "convert", "JITError", "JITWrapper", "_jit_register_fn", "_jit"
|
||||
]
|
||||
|
||||
from .decorator import jit, convert, execute, JITError
|
||||
from .decorator import jit, convert, execute, JITError, JITWrapper, _jit_register_fn, _jit_callback_fn, _jit
|
||||
|
|
|
@ -23,16 +23,14 @@ if "CODON_PATH" not in os.environ:
|
|||
if codon_lib_path:
|
||||
codon_path.append(Path(codon_lib_path).parent / "stdlib")
|
||||
codon_path.append(
|
||||
Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib"
|
||||
)
|
||||
Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib")
|
||||
for path in codon_path:
|
||||
if path.exists():
|
||||
os.environ["CODON_PATH"] = str(path.resolve())
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Cannot locate Codon. Please install Codon or set CODON_PATH."
|
||||
)
|
||||
"Cannot locate Codon. Please install Codon or set CODON_PATH.")
|
||||
|
||||
pod_conversions = {
|
||||
type(None): "pyobj",
|
||||
|
@ -61,7 +59,6 @@ pod_conversions = {
|
|||
custom_conversions = {}
|
||||
_error_msgs = set()
|
||||
|
||||
|
||||
def _common_type(t, debug, sample_size):
|
||||
sub, is_optional = None, False
|
||||
for i in itertools.islice(t, sample_size):
|
||||
|
@ -76,7 +73,6 @@ def _common_type(t, debug, sample_size):
|
|||
sub = "Optional[{}]".format(sub)
|
||||
return sub if sub else "pyobj"
|
||||
|
||||
|
||||
def _codon_type(arg, **kwargs):
|
||||
t = type(arg)
|
||||
|
||||
|
@ -88,11 +84,11 @@ def _codon_type(arg, **kwargs):
|
|||
if issubclass(t, set):
|
||||
return "Set[{}]".format(_common_type(arg, **kwargs))
|
||||
if issubclass(t, dict):
|
||||
return "Dict[{},{}]".format(
|
||||
_common_type(arg.keys(), **kwargs), _common_type(arg.values(), **kwargs)
|
||||
)
|
||||
return "Dict[{},{}]".format(_common_type(arg.keys(), **kwargs),
|
||||
_common_type(arg.values(), **kwargs))
|
||||
if issubclass(t, tuple):
|
||||
return "Tuple[{}]".format(",".join(_codon_type(a, **kwargs) for a in arg))
|
||||
return "Tuple[{}]".format(",".join(
|
||||
_codon_type(a, **kwargs) for a in arg))
|
||||
if issubclass(t, np.ndarray):
|
||||
if arg.dtype == np.bool_:
|
||||
dtype = "bool"
|
||||
|
@ -134,7 +130,8 @@ def _codon_type(arg, **kwargs):
|
|||
|
||||
s = custom_conversions.get(t, "")
|
||||
if s:
|
||||
j = ",".join(_codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__)
|
||||
j = ",".join(
|
||||
_codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__)
|
||||
return "{}[{}]".format(s, j)
|
||||
|
||||
debug = kwargs.get("debug", None)
|
||||
|
@ -145,28 +142,22 @@ def _codon_type(arg, **kwargs):
|
|||
_error_msgs.add(msg)
|
||||
return "pyobj"
|
||||
|
||||
|
||||
def _codon_types(args, **kwargs):
|
||||
return tuple(_codon_type(arg, **kwargs) for arg in args)
|
||||
|
||||
|
||||
def _reset_jit():
|
||||
global _jit
|
||||
_jit = JITWrapper()
|
||||
init_code = (
|
||||
"from internal.python import "
|
||||
"setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n"
|
||||
"setup_decorator()\n"
|
||||
"import numpy as np\n"
|
||||
"import numpy.pybridge\n"
|
||||
)
|
||||
init_code = ("from internal.python import "
|
||||
"setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n"
|
||||
"setup_decorator()\n"
|
||||
"import numpy as np\n"
|
||||
"import numpy.pybridge\n")
|
||||
_jit.execute(init_code, "", 0, False)
|
||||
return _jit
|
||||
|
||||
|
||||
_jit = _reset_jit()
|
||||
|
||||
|
||||
class RewriteFunctionArgs(ast.NodeTransformer):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
@ -176,7 +167,6 @@ class RewriteFunctionArgs(ast.NodeTransformer):
|
|||
node.args.args.append(ast.arg(arg=a, annotation=None))
|
||||
return node
|
||||
|
||||
|
||||
def _obj_to_str(obj, **kwargs) -> str:
|
||||
if inspect.isclass(obj):
|
||||
lines = inspect.getsourcelines(obj)[0]
|
||||
|
@ -185,8 +175,10 @@ def _obj_to_str(obj, **kwargs) -> str:
|
|||
obj_name = obj.__name__
|
||||
elif callable(obj) or isinstance(obj, str):
|
||||
is_str = isinstance(obj, str)
|
||||
lines = [i + '\n' for i in obj.split('\n')] if is_str else inspect.getsourcelines(obj)[0]
|
||||
if not is_str: lines = lines[1:]
|
||||
lines = [i + '\n' for i in obj.split('\n')
|
||||
] if is_str else inspect.getsourcelines(obj)[0]
|
||||
if not is_str:
|
||||
lines = lines[1:]
|
||||
obj_str = textwrap.dedent(''.join(lines))
|
||||
|
||||
pyvars = kwargs.get("pyvars", None)
|
||||
|
@ -195,8 +187,7 @@ def _obj_to_str(obj, **kwargs) -> str:
|
|||
if not isinstance(i, str):
|
||||
raise ValueError("pyvars only takes string literals")
|
||||
node = ast.fix_missing_locations(
|
||||
RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str))
|
||||
)
|
||||
RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str)))
|
||||
obj_str = astunparse.unparse(node)
|
||||
if is_str:
|
||||
try:
|
||||
|
@ -206,28 +197,23 @@ def _obj_to_str(obj, **kwargs) -> str:
|
|||
else:
|
||||
obj_name = obj.__name__
|
||||
else:
|
||||
raise TypeError("Function or class expected, got " + type(obj).__name__)
|
||||
raise TypeError("Function or class expected, got " +
|
||||
type(obj).__name__)
|
||||
return obj_name, obj_str.replace("_@par", "@par")
|
||||
|
||||
|
||||
def _parse_decorated(obj, **kwargs):
|
||||
return _obj_to_str(obj, **kwargs)
|
||||
|
||||
return _obj_to_str(obj, **kwargs)
|
||||
|
||||
def convert(t):
|
||||
if not hasattr(t, "__slots__"):
|
||||
raise JITError("class '{}' does not have '__slots__' attribute".format(str(t)))
|
||||
raise JITError("class '{}' does not have '__slots__' attribute".format(
|
||||
str(t)))
|
||||
|
||||
name = t.__name__
|
||||
slots = t.__slots__
|
||||
code = (
|
||||
"@tuple\n"
|
||||
"class "
|
||||
+ name
|
||||
+ "["
|
||||
+ ",".join("T{}".format(i) for i in range(len(slots)))
|
||||
+ "]:\n"
|
||||
)
|
||||
code = ("@tuple\n"
|
||||
"class " + name + "[" +
|
||||
",".join("T{}".format(i) for i in range(len(slots))) + "]:\n")
|
||||
for i, slot in enumerate(slots):
|
||||
code += " {}: T{}\n".format(slot, i)
|
||||
|
||||
|
@ -235,17 +221,14 @@ def convert(t):
|
|||
code += " def __from_py__(p: cobj):\n"
|
||||
for i, slot in enumerate(slots):
|
||||
code += " a{} = T{}.__from_py__(PyObject_GetAttrString(p, '{}'.ptr))\n".format(
|
||||
i, i, slot
|
||||
)
|
||||
i, i, slot)
|
||||
code += " return {}({})\n".format(
|
||||
name, ", ".join("a{}".format(i) for i in range(len(slots)))
|
||||
)
|
||||
name, ", ".join("a{}".format(i) for i in range(len(slots))))
|
||||
|
||||
_jit.execute(code, "", 0, False)
|
||||
custom_conversions[t] = name
|
||||
return t
|
||||
|
||||
|
||||
def _jit_register_fn(f, pyvars, debug):
|
||||
try:
|
||||
obj_name, obj_str = _parse_decorated(f, pyvars=pyvars)
|
||||
|
@ -258,29 +241,46 @@ def _jit_register_fn(f, pyvars, debug):
|
|||
_reset_jit()
|
||||
raise
|
||||
|
||||
def _jit_callback_fn(obj_name, module, debug=None, sample_size=5, pyvars=None, *args, **kwargs):
|
||||
try:
|
||||
def _jit_callback_fn(fn,
|
||||
obj_name,
|
||||
module,
|
||||
debug=None,
|
||||
sample_size=5,
|
||||
pyvars=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
if fn is not None:
|
||||
sig = inspect.signature(fn)
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
args = tuple(bound_args.arguments[param] for param in sig.parameters)
|
||||
else:
|
||||
args = (*args, *kwargs.values())
|
||||
|
||||
try:
|
||||
types = _codon_types(args, debug=debug, sample_size=sample_size)
|
||||
if debug:
|
||||
print("[python] {}({})".format(obj_name, list(types)), file=sys.stderr)
|
||||
return _jit.run_wrapper(
|
||||
obj_name, list(types), module, list(pyvars), args, 1 if debug else 0
|
||||
)
|
||||
print("[python] {}({})".format(obj_name, list(types)),
|
||||
file=sys.stderr)
|
||||
return _jit.run_wrapper(obj_name, list(types), module, list(pyvars),
|
||||
args, 1 if debug else 0)
|
||||
except JITError:
|
||||
_reset_jit()
|
||||
raise
|
||||
|
||||
def _jit_str_fn(fstr, debug=None, sample_size=5, pyvars=None):
|
||||
obj_name = _jit_register_fn(fstr, pyvars, debug)
|
||||
def wrapped(*args, **kwargs):
|
||||
return _jit_callback_fn(obj_name, "__main__", debug, sample_size, pyvars, *args, **kwargs)
|
||||
return wrapped
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
return _jit_callback_fn(None, obj_name, "__main__", debug, sample_size,
|
||||
pyvars, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
def jit(fn=None, debug=None, sample_size=5, pyvars=None):
|
||||
if not pyvars:
|
||||
pyvars = []
|
||||
|
||||
if not isinstance(pyvars, list):
|
||||
raise ArgumentError("pyvars must be a list")
|
||||
|
||||
|
@ -289,12 +289,15 @@ def jit(fn=None, debug=None, sample_size=5, pyvars=None):
|
|||
|
||||
def _decorate(f):
|
||||
obj_name = _jit_register_fn(f, pyvars, debug)
|
||||
|
||||
@functools.wraps(f)
|
||||
def wrapped(*args, **kwargs):
|
||||
return _jit_callback_fn(obj_name, f.__module__, debug, sample_size, pyvars, *args, **kwargs)
|
||||
return wrapped
|
||||
return _decorate(fn) if fn else _decorate
|
||||
return _jit_callback_fn(f, obj_name, f.__module__, debug,
|
||||
sample_size, pyvars, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
return _decorate(fn) if fn else _decorate
|
||||
|
||||
def execute(code, debug=False):
|
||||
try:
|
||||
|
|
|
@ -1,16 +1,22 @@
|
|||
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
||||
|
||||
from libcpp.string cimport string
|
||||
from libcpp.vector cimport vector
|
||||
from libc.stdint cimport int32_t, uint8_t
|
||||
|
||||
cdef extern from "codon/compiler/jit_extern.h" namespace "codon::jit":
|
||||
cdef cppclass JIT
|
||||
cdef cppclass JITResult:
|
||||
cdef extern from "codon/compiler/jit_extern.h":
|
||||
cdef struct CJITResult:
|
||||
void *result
|
||||
string message
|
||||
bint operator bool()
|
||||
char *error
|
||||
|
||||
JIT *jitInit(string)
|
||||
JITResult jitExecuteSafe(JIT*, string, string, int, char)
|
||||
JITResult jitExecutePython(JIT*, string, vector[string], string, vector[string], object, char)
|
||||
string getJITLibrary()
|
||||
void *jit_init(char *name)
|
||||
void jit_exit(void *jit)
|
||||
|
||||
cdef char *get_jit_library()
|
||||
|
||||
cdef CJITResult jit_execute_safe(
|
||||
void *jit, char *code, char *file, int32_t line, uint8_t debug
|
||||
)
|
||||
cdef CJITResult jit_execute_python(
|
||||
void *jit, char *name, char **types, size_t types_size,
|
||||
char *pyModule, char **py_vars, size_t py_vars_size,
|
||||
void *arg, uint8_t debug
|
||||
)
|
||||
|
|
|
@ -1,45 +1,82 @@
|
|||
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
||||
|
||||
# distutils: language=c++
|
||||
# distutils: language=c
|
||||
# cython: language_level=3
|
||||
# cython: c_string_type=unicode
|
||||
# cython: c_string_encoding=utf8
|
||||
|
||||
from libcpp.string cimport string
|
||||
from libcpp.vector cimport vector
|
||||
cimport codon.jit
|
||||
from libc.stdlib cimport malloc, calloc, free
|
||||
from libc.string cimport strcpy
|
||||
from libc.stdint cimport int32_t, uint8_t
|
||||
|
||||
|
||||
class JITError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
cdef str get_free_str(char *s):
|
||||
cdef bytes py_s
|
||||
try:
|
||||
py_s = s
|
||||
return py_s.decode('utf-8')
|
||||
finally:
|
||||
free(s)
|
||||
|
||||
|
||||
cdef class JITWrapper:
|
||||
cdef codon.jit.JIT* jit
|
||||
cdef void* jit
|
||||
|
||||
def __cinit__(self):
|
||||
self.jit = codon.jit.jitInit(b"codon jit")
|
||||
self.jit = codon.jit.jit_init(b"codon jit")
|
||||
|
||||
def __dealloc__(self):
|
||||
del self.jit
|
||||
codon.jit.jit_exit(self.jit)
|
||||
|
||||
def execute(self, code: str, filename: str, fileno: int, debug: char) -> str:
|
||||
result = codon.jit.jitExecuteSafe(self.jit, code, filename, fileno, <char>debug)
|
||||
if <bint>result:
|
||||
def execute(self, code: str, filename: str, fileno: int, debug) -> str:
|
||||
result = codon.jit.jit_execute_safe(
|
||||
self.jit, code.encode('utf-8'), filename.encode('utf-8'), fileno, <uint8_t>debug
|
||||
)
|
||||
if result.error is NULL:
|
||||
return None
|
||||
else:
|
||||
raise JITError(result.message)
|
||||
msg = get_free_str(result.error)
|
||||
raise JITError(msg)
|
||||
|
||||
def run_wrapper(self, name: str, types: list[str], module: str, pyvars: list[str], args, debug) -> object:
|
||||
cdef char** c_types = <char**>calloc(len(types), sizeof(char*))
|
||||
cdef char** c_pyvars = <char**>calloc(len(pyvars), sizeof(char*))
|
||||
if not c_types or not c_pyvars:
|
||||
raise JITError("Cython allocation failed")
|
||||
try:
|
||||
for i, s in enumerate(types):
|
||||
bytes = s.encode('utf-8')
|
||||
c_types[i] = <char*>malloc(len(bytes) + 1)
|
||||
strcpy(c_types[i], bytes)
|
||||
for i, s in enumerate(pyvars):
|
||||
bytes = s.encode('utf-8')
|
||||
c_pyvars[i] = <char*>malloc(len(bytes) + 1)
|
||||
strcpy(c_pyvars[i], bytes)
|
||||
|
||||
result = codon.jit.jit_execute_python(
|
||||
self.jit, name.encode('utf-8'), c_types, len(types),
|
||||
module.encode('utf-8'), c_pyvars, len(pyvars),
|
||||
<void *>args, <uint8_t>debug
|
||||
)
|
||||
if result.error is NULL:
|
||||
return <object>result.result
|
||||
else:
|
||||
msg = get_free_str(result.error)
|
||||
raise JITError(msg)
|
||||
finally:
|
||||
for i in range(len(types)):
|
||||
free(c_types[i])
|
||||
free(c_types)
|
||||
for i in range(len(pyvars)):
|
||||
free(c_pyvars[i])
|
||||
free(c_pyvars)
|
||||
|
||||
def run_wrapper(self, name: str, types: list[str], module: str, pyvars: list[str], args, debug: char) -> object:
|
||||
cdef vector[string] types_vec = types
|
||||
cdef vector[string] pyvars_vec = pyvars
|
||||
result = codon.jit.jitExecutePython(
|
||||
self.jit, name, types_vec, module, pyvars_vec, <object>args, <char>debug
|
||||
)
|
||||
if <bint>result:
|
||||
return <object>result.result
|
||||
else:
|
||||
raise JITError(result.message)
|
||||
|
||||
def codon_library():
|
||||
return codon.jit.getJITLibrary()
|
||||
cdef char* c = codon.jit.get_jit_library()
|
||||
return get_free_str(c)
|
||||
|
|
|
@ -67,7 +67,7 @@ jit_extension = Extension(
|
|||
"codon.codon_jit",
|
||||
sources=["codon/jit.pyx"],
|
||||
libraries=libraries,
|
||||
language="c++",
|
||||
language="c",
|
||||
extra_compile_args=["-w"],
|
||||
extra_link_args=linker_args,
|
||||
include_dirs=[str(codon_path / "include")],
|
||||
|
|
|
@ -24,6 +24,7 @@ PyFloat_AsDouble = Function[[cobj], float](cobj())
|
|||
PyFloat_FromDouble = Function[[float], cobj](cobj())
|
||||
PyBool_FromLong = Function[[int], cobj](cobj())
|
||||
PyBytes_AsString = Function[[cobj], cobj](cobj())
|
||||
PyBytes_Size = Function[[cobj], int](cobj())
|
||||
PyList_New = Function[[int], cobj](cobj())
|
||||
PyList_Size = Function[[cobj], int](cobj())
|
||||
PyList_GetItem = Function[[cobj, int], cobj](cobj())
|
||||
|
@ -130,6 +131,7 @@ PyLong_Type = cobj()
|
|||
PyFloat_Type = cobj()
|
||||
PyBool_Type = cobj()
|
||||
PyUnicode_Type = cobj()
|
||||
PyBytes_Type = cobj()
|
||||
PyComplex_Type = cobj()
|
||||
PyList_Type = cobj()
|
||||
PyDict_Type = cobj()
|
||||
|
@ -213,6 +215,7 @@ def init_handles_dlopen(py_handle: cobj):
|
|||
global PyFloat_FromDouble
|
||||
global PyBool_FromLong
|
||||
global PyBytes_AsString
|
||||
global PyBytes_Size
|
||||
global PyList_New
|
||||
global PyList_Size
|
||||
global PyList_GetItem
|
||||
|
@ -303,6 +306,7 @@ def init_handles_dlopen(py_handle: cobj):
|
|||
global PyFloat_Type
|
||||
global PyBool_Type
|
||||
global PyUnicode_Type
|
||||
global PyBytes_Type
|
||||
global PyComplex_Type
|
||||
global PyList_Type
|
||||
global PyDict_Type
|
||||
|
@ -347,6 +351,7 @@ def init_handles_dlopen(py_handle: cobj):
|
|||
PyFloat_FromDouble = dlsym(py_handle, "PyFloat_FromDouble")
|
||||
PyBool_FromLong = dlsym(py_handle, "PyBool_FromLong")
|
||||
PyBytes_AsString = dlsym(py_handle, "PyBytes_AsString")
|
||||
PyBytes_Size = dlsym(py_handle, "PyBytes_Size")
|
||||
PyList_New = dlsym(py_handle, "PyList_New")
|
||||
PyList_Size = dlsym(py_handle, "PyList_Size")
|
||||
PyList_GetItem = dlsym(py_handle, "PyList_GetItem")
|
||||
|
@ -437,6 +442,7 @@ def init_handles_dlopen(py_handle: cobj):
|
|||
PyFloat_Type = dlsym(py_handle, "PyFloat_Type")
|
||||
PyBool_Type = dlsym(py_handle, "PyBool_Type")
|
||||
PyUnicode_Type = dlsym(py_handle, "PyUnicode_Type")
|
||||
PyBytes_Type = dlsym(py_handle, "PyBytes_Type")
|
||||
PyComplex_Type = dlsym(py_handle, "PyComplex_Type")
|
||||
PyList_Type = dlsym(py_handle, "PyList_Type")
|
||||
PyDict_Type = dlsym(py_handle, "PyDict_Type")
|
||||
|
@ -482,6 +488,7 @@ def init_handles_static():
|
|||
from C import PyFloat_FromDouble(float) -> cobj as _PyFloat_FromDouble
|
||||
from C import PyBool_FromLong(int) -> cobj as _PyBool_FromLong
|
||||
from C import PyBytes_AsString(cobj) -> cobj as _PyBytes_AsString
|
||||
from C import PyBytes_Size(cobj) -> int as _PyBytes_Size
|
||||
from C import PyList_New(int) -> cobj as _PyList_New
|
||||
from C import PyList_Size(cobj) -> int as _PyList_Size
|
||||
from C import PyList_GetItem(cobj, int) -> cobj as _PyList_GetItem
|
||||
|
@ -572,6 +579,7 @@ def init_handles_static():
|
|||
from C import PyFloat_Type: cobj as _PyFloat_Type
|
||||
from C import PyBool_Type: cobj as _PyBool_Type
|
||||
from C import PyUnicode_Type: cobj as _PyUnicode_Type
|
||||
from C import PyBytes_Type: cobj as _PyBytes_Type
|
||||
from C import PyComplex_Type: cobj as _PyComplex_Type
|
||||
from C import PyList_Type: cobj as _PyList_Type
|
||||
from C import PyDict_Type: cobj as _PyDict_Type
|
||||
|
@ -616,6 +624,7 @@ def init_handles_static():
|
|||
global PyFloat_FromDouble
|
||||
global PyBool_FromLong
|
||||
global PyBytes_AsString
|
||||
global PyBytes_Size
|
||||
global PyList_New
|
||||
global PyList_Size
|
||||
global PyList_GetItem
|
||||
|
@ -706,6 +715,7 @@ def init_handles_static():
|
|||
global PyFloat_Type
|
||||
global PyBool_Type
|
||||
global PyUnicode_Type
|
||||
global PyBytes_Type
|
||||
global PyComplex_Type
|
||||
global PyList_Type
|
||||
global PyDict_Type
|
||||
|
@ -750,6 +760,7 @@ def init_handles_static():
|
|||
PyFloat_FromDouble = _PyFloat_FromDouble
|
||||
PyBool_FromLong = _PyBool_FromLong
|
||||
PyBytes_AsString = _PyBytes_AsString
|
||||
PyBytes_Size = _PyBytes_Size
|
||||
PyList_New = _PyList_New
|
||||
PyList_Size = _PyList_Size
|
||||
PyList_GetItem = _PyList_GetItem
|
||||
|
@ -840,6 +851,7 @@ def init_handles_static():
|
|||
PyFloat_Type = __ptr__(_PyFloat_Type).as_byte()
|
||||
PyBool_Type = __ptr__(_PyBool_Type).as_byte()
|
||||
PyUnicode_Type = __ptr__(_PyUnicode_Type).as_byte()
|
||||
PyBytes_Type = __ptr__(_PyBytes_Type).as_byte()
|
||||
PyComplex_Type = __ptr__(_PyComplex_Type).as_byte()
|
||||
PyList_Type = __ptr__(_PyList_Type).as_byte()
|
||||
PyDict_Type = __ptr__(_PyDict_Type).as_byte()
|
||||
|
@ -1174,7 +1186,7 @@ class pyobj:
|
|||
return pyobj.to_str(self.p, errors, empty)
|
||||
|
||||
def to_str(p: cobj, errors: str, empty: str = "") -> str:
|
||||
obj = PyUnicode_AsEncodedString(p, "utf-8".c_str(), errors.c_str())
|
||||
obj = PyUnicode_AsEncodedString(p, "utf-8".ptr, errors.c_str() if errors else "".ptr)
|
||||
if obj == cobj():
|
||||
return empty
|
||||
bts = PyBytes_AsString(obj)
|
||||
|
@ -1292,8 +1304,11 @@ class _PyObject_Struct:
|
|||
def _conversion_error(name: Static[str]):
|
||||
raise PyError("conversion error: Python object did not have type '" + name + "'")
|
||||
|
||||
def _get_type(o: cobj):
|
||||
return Ptr[_PyObject_Struct](o)[0].pytype
|
||||
|
||||
def _ensure_type(o: cobj, t: cobj, name: Static[str]):
|
||||
if Ptr[_PyObject_Struct](o)[0].pytype != t:
|
||||
if _get_type(o) != t:
|
||||
_conversion_error(name)
|
||||
|
||||
|
||||
|
@ -1350,7 +1365,14 @@ class str:
|
|||
return pyobj.exc_wrap(PyUnicode_DecodeFSDefaultAndSize(self.ptr, self.len))
|
||||
|
||||
def __from_py__(s: cobj) -> str:
|
||||
return pyobj.exc_wrap(pyobj.to_str(s, "strict"))
|
||||
if _get_type(s) == PyBytes_Type:
|
||||
n = PyBytes_Size(s)
|
||||
p0 = PyBytes_AsString(s)
|
||||
p1 = cobj(n)
|
||||
str.memcpy(p1, p0, n)
|
||||
return str(p1, n)
|
||||
else:
|
||||
return pyobj.exc_wrap(pyobj.to_str(s, "strict"))
|
||||
|
||||
@extend
|
||||
class complex:
|
||||
|
|
|
@ -4,6 +4,22 @@ from internal.attributes import commutative
|
|||
from internal.gc import alloc_atomic, free
|
||||
from internal.types.complex import complex
|
||||
|
||||
def _float_int_pow(a: F, b: int, F: type) -> F:
|
||||
abs_exp = b.__abs__()
|
||||
result = F(1)
|
||||
factor = a
|
||||
|
||||
while abs_exp:
|
||||
if abs_exp & 1:
|
||||
result *= factor
|
||||
factor *= factor
|
||||
abs_exp >>= 1
|
||||
|
||||
if b < 0:
|
||||
result = F(1) / result
|
||||
|
||||
return result
|
||||
|
||||
@extend
|
||||
class float:
|
||||
def __new__() -> float:
|
||||
|
@ -401,6 +417,50 @@ class float:
|
|||
def imag(self) -> float:
|
||||
return 0.0
|
||||
|
||||
@commutative
|
||||
def __add__(self: float, b: int) -> float:
|
||||
return self + float(b)
|
||||
|
||||
def __sub__(self: float, b: int) -> float:
|
||||
return self - float(b)
|
||||
|
||||
@commutative
|
||||
def __mul__(self: float, b: int) -> float:
|
||||
return self * float(b)
|
||||
|
||||
def __floordiv__(self, b: int) -> float:
|
||||
return self // float(b)
|
||||
|
||||
def __truediv__(self: float, b: int) -> float:
|
||||
return self / float(b)
|
||||
|
||||
def __mod__(self: float, b: int) -> float:
|
||||
return self % float(b)
|
||||
|
||||
def __divmod__(self, b: int):
|
||||
return self.__divmod__(float(b))
|
||||
|
||||
def __eq__(self: float, b: int) -> bool:
|
||||
return self == float(b)
|
||||
|
||||
def __ne__(self: float, b: int) -> bool:
|
||||
return self != float(b)
|
||||
|
||||
def __lt__(self: float, b: int) -> bool:
|
||||
return self < float(b)
|
||||
|
||||
def __gt__(self: float, b: int) -> bool:
|
||||
return self > float(b)
|
||||
|
||||
def __le__(self: float, b: int) -> bool:
|
||||
return self <= float(b)
|
||||
|
||||
def __ge__(self: float, b: int) -> bool:
|
||||
return self >= float(b)
|
||||
|
||||
def __pow__(self: float, b: int) -> float:
|
||||
return _float_int_pow(self, b)
|
||||
|
||||
@extend
|
||||
class float32:
|
||||
@pure
|
||||
|
@ -755,6 +815,50 @@ class float32:
|
|||
def __match__(self, obj: float32) -> bool:
|
||||
return self == obj
|
||||
|
||||
@commutative
|
||||
def __add__(self: float32, b: int) -> float32:
|
||||
return self + float32(b)
|
||||
|
||||
def __sub__(self: float32, b: int) -> float32:
|
||||
return self - float32(b)
|
||||
|
||||
@commutative
|
||||
def __mul__(self: float32, b: int) -> float32:
|
||||
return self * float32(b)
|
||||
|
||||
def __floordiv__(self, b: int) -> float32:
|
||||
return self // float32(b)
|
||||
|
||||
def __truediv__(self: float32, b: int) -> float32:
|
||||
return self / float32(b)
|
||||
|
||||
def __mod__(self: float32, b: int) -> float32:
|
||||
return self % float32(b)
|
||||
|
||||
def __divmod__(self, b: int):
|
||||
return self.__divmod__(float32(b))
|
||||
|
||||
def __eq__(self: float32, b: int) -> bool:
|
||||
return self == float32(b)
|
||||
|
||||
def __ne__(self: float32, b: int) -> bool:
|
||||
return self != float32(b)
|
||||
|
||||
def __lt__(self: float32, b: int) -> bool:
|
||||
return self < float32(b)
|
||||
|
||||
def __gt__(self: float32, b: int) -> bool:
|
||||
return self > float32(b)
|
||||
|
||||
def __le__(self: float32, b: int) -> bool:
|
||||
return self <= float32(b)
|
||||
|
||||
def __ge__(self: float32, b: int) -> bool:
|
||||
return self >= float32(b)
|
||||
|
||||
def __pow__(self: float32, b: int) -> float32:
|
||||
return _float_int_pow(self, b)
|
||||
|
||||
@extend
|
||||
class float16:
|
||||
@pure
|
||||
|
@ -1055,6 +1159,50 @@ class float16:
|
|||
def __match__(self, obj: float16) -> bool:
|
||||
return self == obj
|
||||
|
||||
@commutative
|
||||
def __add__(self: float16, b: int) -> float16:
|
||||
return self + float16(b)
|
||||
|
||||
def __sub__(self: float16, b: int) -> float16:
|
||||
return self - float16(b)
|
||||
|
||||
@commutative
|
||||
def __mul__(self: float16, b: int) -> float16:
|
||||
return self * float16(b)
|
||||
|
||||
def __floordiv__(self, b: int) -> float16:
|
||||
return self // float16(b)
|
||||
|
||||
def __truediv__(self: float16, b: int) -> float16:
|
||||
return self / float16(b)
|
||||
|
||||
def __mod__(self: float16, b: int) -> float16:
|
||||
return self % float16(b)
|
||||
|
||||
def __divmod__(self, b: int):
|
||||
return self.__divmod__(float16(b))
|
||||
|
||||
def __eq__(self: float16, b: int) -> bool:
|
||||
return self == float16(b)
|
||||
|
||||
def __ne__(self: float16, b: int) -> bool:
|
||||
return self != float16(b)
|
||||
|
||||
def __lt__(self: float16, b: int) -> bool:
|
||||
return self < float16(b)
|
||||
|
||||
def __gt__(self: float16, b: int) -> bool:
|
||||
return self > float16(b)
|
||||
|
||||
def __le__(self: float16, b: int) -> bool:
|
||||
return self <= float16(b)
|
||||
|
||||
def __ge__(self: float16, b: int) -> bool:
|
||||
return self >= float16(b)
|
||||
|
||||
def __pow__(self: float16, b: int) -> float16:
|
||||
return _float_int_pow(self, b)
|
||||
|
||||
@extend
|
||||
class bfloat16:
|
||||
@pure
|
||||
|
@ -1355,6 +1503,50 @@ class bfloat16:
|
|||
def __match__(self, obj: bfloat16) -> bool:
|
||||
return self == obj
|
||||
|
||||
@commutative
|
||||
def __add__(self: bfloat16, b: int) -> bfloat16:
|
||||
return self + bfloat16(b)
|
||||
|
||||
def __sub__(self: bfloat16, b: int) -> bfloat16:
|
||||
return self - bfloat16(b)
|
||||
|
||||
@commutative
|
||||
def __mul__(self: bfloat16, b: int) -> bfloat16:
|
||||
return self * bfloat16(b)
|
||||
|
||||
def __floordiv__(self, b: int) -> bfloat16:
|
||||
return self // bfloat16(b)
|
||||
|
||||
def __truediv__(self: bfloat16, b: int) -> bfloat16:
|
||||
return self / bfloat16(b)
|
||||
|
||||
def __mod__(self: bfloat16, b: int) -> bfloat16:
|
||||
return self % bfloat16(b)
|
||||
|
||||
def __divmod__(self, b: int):
|
||||
return self.__divmod__(bfloat16(b))
|
||||
|
||||
def __eq__(self: bfloat16, b: int) -> bool:
|
||||
return self == bfloat16(b)
|
||||
|
||||
def __ne__(self: bfloat16, b: int) -> bool:
|
||||
return self != bfloat16(b)
|
||||
|
||||
def __lt__(self: bfloat16, b: int) -> bool:
|
||||
return self < bfloat16(b)
|
||||
|
||||
def __gt__(self: bfloat16, b: int) -> bool:
|
||||
return self > bfloat16(b)
|
||||
|
||||
def __le__(self: bfloat16, b: int) -> bool:
|
||||
return self <= bfloat16(b)
|
||||
|
||||
def __ge__(self: bfloat16, b: int) -> bool:
|
||||
return self >= bfloat16(b)
|
||||
|
||||
def __pow__(self: bfloat16, b: int) -> bfloat16:
|
||||
return _float_int_pow(self, b)
|
||||
|
||||
@extend
|
||||
class float128:
|
||||
@pure
|
||||
|
@ -1652,6 +1844,50 @@ class float128:
|
|||
def __match__(self, obj: float128) -> bool:
|
||||
return self == obj
|
||||
|
||||
@commutative
|
||||
def __add__(self: float128, b: int) -> float128:
|
||||
return self + float128(b)
|
||||
|
||||
def __sub__(self: float128, b: int) -> float128:
|
||||
return self - float128(b)
|
||||
|
||||
@commutative
|
||||
def __mul__(self: float128, b: int) -> float128:
|
||||
return self * float128(b)
|
||||
|
||||
def __floordiv__(self, b: int) -> float128:
|
||||
return self // float128(b)
|
||||
|
||||
def __truediv__(self: float128, b: int) -> float128:
|
||||
return self / float128(b)
|
||||
|
||||
def __mod__(self: float128, b: int) -> float128:
|
||||
return self % float128(b)
|
||||
|
||||
def __divmod__(self, b: int):
|
||||
return self.__divmod__(float128(b))
|
||||
|
||||
def __eq__(self: float128, b: int) -> bool:
|
||||
return self == float128(b)
|
||||
|
||||
def __ne__(self: float128, b: int) -> bool:
|
||||
return self != float128(b)
|
||||
|
||||
def __lt__(self: float128, b: int) -> bool:
|
||||
return self < float128(b)
|
||||
|
||||
def __gt__(self: float128, b: int) -> bool:
|
||||
return self > float128(b)
|
||||
|
||||
def __le__(self: float128, b: int) -> bool:
|
||||
return self <= float128(b)
|
||||
|
||||
def __ge__(self: float128, b: int) -> bool:
|
||||
return self >= float128(b)
|
||||
|
||||
def __pow__(self: float128, b: int) -> float128:
|
||||
return _float_int_pow(self, b)
|
||||
|
||||
@extend
|
||||
class float:
|
||||
def __suffix_f32__(double) -> float32:
|
||||
|
@ -1666,6 +1902,184 @@ class float:
|
|||
def __suffix_f128__(double) -> float128:
|
||||
return float128.__new__(double)
|
||||
|
||||
@extend
|
||||
class int:
|
||||
@commutative
|
||||
def __add__(self, b: float32) -> float32:
|
||||
return float32(self) + b
|
||||
|
||||
def __sub__(self, b: float32) -> float32:
|
||||
return float32(self) - b
|
||||
|
||||
@commutative
|
||||
def __mul__(self, b: float32) -> float32:
|
||||
return float32(self) * b
|
||||
|
||||
def __floordiv__(self, b: float32) -> float32:
|
||||
return float32(self) // b
|
||||
|
||||
def __truediv__(self, b: float32) -> float32:
|
||||
return float32(self) / b
|
||||
|
||||
def __mod__(self, b: float32) -> float32:
|
||||
return float32(self) % b
|
||||
|
||||
def __divmod__(self, b: float32):
|
||||
return float32(self).__divmod__(b)
|
||||
|
||||
def __pow__(self, b: float32) -> float32:
|
||||
return float32(self) ** b
|
||||
|
||||
def __eq__(self, b: float32) -> bool:
|
||||
return float32(self) == b
|
||||
|
||||
def __ne__(self, b: float32) -> bool:
|
||||
return float32(self) != b
|
||||
|
||||
def __lt__(self, b: float32) -> bool:
|
||||
return float32(self) < b
|
||||
|
||||
def __gt__(self, b: float32) -> bool:
|
||||
return float32(self) > b
|
||||
|
||||
def __le__(self, b: float32) -> bool:
|
||||
return float32(self) <= b
|
||||
|
||||
def __ge__(self, b: float32) -> bool:
|
||||
return float32(self) >= b
|
||||
|
||||
@commutative
|
||||
def __add__(self, b: float16) -> float16:
|
||||
return float16(self) + b
|
||||
|
||||
def __sub__(self, b: float16) -> float16:
|
||||
return float16(self) - b
|
||||
|
||||
@commutative
|
||||
def __mul__(self, b: float16) -> float16:
|
||||
return float16(self) * b
|
||||
|
||||
def __floordiv__(self, b: float16) -> float16:
|
||||
return float16(self) // b
|
||||
|
||||
def __truediv__(self, b: float16) -> float16:
|
||||
return float16(self) / b
|
||||
|
||||
def __mod__(self, b: float16) -> float16:
|
||||
return float16(self) % b
|
||||
|
||||
def __divmod__(self, b: float16):
|
||||
return float16(self).__divmod__(b)
|
||||
|
||||
def __pow__(self, b: float16) -> float16:
|
||||
return float16(self) ** b
|
||||
|
||||
def __eq__(self, b: float16) -> bool:
|
||||
return float16(self) == b
|
||||
|
||||
def __ne__(self, b: float16) -> bool:
|
||||
return float16(self) != b
|
||||
|
||||
def __lt__(self, b: float16) -> bool:
|
||||
return float16(self) < b
|
||||
|
||||
def __gt__(self, b: float16) -> bool:
|
||||
return float16(self) > b
|
||||
|
||||
def __le__(self, b: float16) -> bool:
|
||||
return float16(self) <= b
|
||||
|
||||
def __ge__(self, b: float16) -> bool:
|
||||
return float16(self) >= b
|
||||
|
||||
@commutative
|
||||
def __add__(self, b: bfloat16) -> bfloat16:
|
||||
return bfloat16(self) + b
|
||||
|
||||
def __sub__(self, b: bfloat16) -> bfloat16:
|
||||
return bfloat16(self) - b
|
||||
|
||||
@commutative
|
||||
def __mul__(self, b: bfloat16) -> bfloat16:
|
||||
return bfloat16(self) * b
|
||||
|
||||
def __floordiv__(self, b: bfloat16) -> bfloat16:
|
||||
return bfloat16(self) // b
|
||||
|
||||
def __truediv__(self, b: bfloat16) -> bfloat16:
|
||||
return bfloat16(self) / b
|
||||
|
||||
def __mod__(self, b: bfloat16) -> bfloat16:
|
||||
return bfloat16(self) % b
|
||||
|
||||
def __divmod__(self, b: bfloat16):
|
||||
return bfloat16(self).__divmod__(b)
|
||||
|
||||
def __pow__(self, b: bfloat16) -> bfloat16:
|
||||
return bfloat16(self) ** b
|
||||
|
||||
def __eq__(self, b: bfloat16) -> bool:
|
||||
return bfloat16(self) == b
|
||||
|
||||
def __ne__(self, b: bfloat16) -> bool:
|
||||
return bfloat16(self) != b
|
||||
|
||||
def __lt__(self, b: bfloat16) -> bool:
|
||||
return bfloat16(self) < b
|
||||
|
||||
def __gt__(self, b: bfloat16) -> bool:
|
||||
return bfloat16(self) > b
|
||||
|
||||
def __le__(self, b: bfloat16) -> bool:
|
||||
return bfloat16(self) <= b
|
||||
|
||||
def __ge__(self, b: bfloat16) -> bool:
|
||||
return bfloat16(self) >= b
|
||||
|
||||
@commutative
|
||||
def __add__(self, b: float128) -> float128:
|
||||
return float128(self) + b
|
||||
|
||||
def __sub__(self, b: float128) -> float128:
|
||||
return float128(self) - b
|
||||
|
||||
@commutative
|
||||
def __mul__(self, b: float128) -> float128:
|
||||
return float128(self) * b
|
||||
|
||||
def __floordiv__(self, b: float128) -> float128:
|
||||
return float128(self) // b
|
||||
|
||||
def __truediv__(self, b: float128) -> float128:
|
||||
return float128(self) / b
|
||||
|
||||
def __mod__(self, b: float128) -> float128:
|
||||
return float128(self) % b
|
||||
|
||||
def __divmod__(self, b: float128):
|
||||
return float128(self).__divmod__(b)
|
||||
|
||||
def __pow__(self, b: float128) -> float128:
|
||||
return float128(self) ** b
|
||||
|
||||
def __eq__(self, b: float128) -> bool:
|
||||
return float128(self) == b
|
||||
|
||||
def __ne__(self, b: float128) -> bool:
|
||||
return float128(self) != b
|
||||
|
||||
def __lt__(self, b: float128) -> bool:
|
||||
return float128(self) < b
|
||||
|
||||
def __gt__(self, b: float128) -> bool:
|
||||
return float128(self) > b
|
||||
|
||||
def __le__(self, b: float128) -> bool:
|
||||
return float128(self) <= b
|
||||
|
||||
def __ge__(self, b: float128) -> bool:
|
||||
return float128(self) >= b
|
||||
|
||||
f16 = float16
|
||||
bf16 = bfloat16
|
||||
f32 = float32
|
||||
|
|
|
@ -57,8 +57,8 @@ class FloatingFormat:
|
|||
self.__init__()
|
||||
return
|
||||
|
||||
min_val = None
|
||||
max_val = None
|
||||
min_val: Optional[a.dtype] = None
|
||||
max_val: Optional[a.dtype] = None
|
||||
finite = 0
|
||||
abs_non_zero = 0
|
||||
exp_format = False
|
||||
|
|
|
@ -245,7 +245,7 @@ class ndarray:
|
|||
k = 0
|
||||
for idx in util.multirange(shape):
|
||||
off = 0
|
||||
for i in range(ndim):
|
||||
for i in staticrange(ndim):
|
||||
off += idx[i] * strides[i]
|
||||
e = Ptr[cobj](arr_data + off)[0]
|
||||
if hasattr(dtype, "__from_py__"):
|
||||
|
@ -263,7 +263,7 @@ class ndarray:
|
|||
k = 0
|
||||
for idx in util.multirange(shape):
|
||||
off = 0
|
||||
for i in range(ndim):
|
||||
for i in staticrange(ndim):
|
||||
off += idx[i] * strides[i]
|
||||
e = Ptr[dtype](arr_data + off)[0]
|
||||
data[k] = e
|
||||
|
|
|
@ -2661,7 +2661,7 @@ def pad(array, pad_width, mode = 'constant', **kwargs):
|
|||
if dtype is int or isinstance(dtype, Int) or isinstance(dtype, UInt):
|
||||
return util.cast(util.rint(x), dtype)
|
||||
else:
|
||||
return x
|
||||
return util.cast(x, dtype)
|
||||
|
||||
def pad_from_function(a: ndarray, pw, padding_func, kwargs, extra = None):
|
||||
shape = a.shape
|
||||
|
@ -2915,7 +2915,7 @@ def pad(array, pad_width, mode = 'constant', **kwargs):
|
|||
fill_linear(vector,
|
||||
offset=(n - p2),
|
||||
start=util.cast(end2, float),
|
||||
stop=util.cast(start1, float),
|
||||
stop=util.cast(start2, float),
|
||||
num=p2,
|
||||
rev=True)
|
||||
|
||||
|
|
|
@ -270,84 +270,174 @@ def corrcoef(x, y=None, rowvar=True, dtype: type = NoneType):
|
|||
|
||||
return c
|
||||
|
||||
def correlate(a, b, mode: Static[str] = 'valid'):
|
||||
def _correlate(a, b, mode: str):
|
||||
|
||||
def kernel(d, dstride: int, nd: int, dtype: type,
|
||||
k, kstride: int, nk: Static[int], ktype: type,
|
||||
out, ostride: int):
|
||||
for i in range(nd):
|
||||
acc = util.zero(dtype)
|
||||
for j in staticrange(nk):
|
||||
acc += d[(i + j) * dstride] * k[j * kstride]
|
||||
out[i * ostride] = acc
|
||||
|
||||
def small_correlate(d, dstride: int, nd: int, dtype: type,
|
||||
k, kstride: int, nk: int, ktype: type,
|
||||
out, ostride: int):
|
||||
if dtype is not ktype:
|
||||
return False
|
||||
|
||||
dstride //= util.sizeof(dtype)
|
||||
kstride //= util.sizeof(dtype)
|
||||
ostride //= util.sizeof(dtype)
|
||||
|
||||
if nk == 1:
|
||||
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
|
||||
k=k, kstride=kstride, nk=1, ktype=ktype,
|
||||
out=out, ostride=ostride)
|
||||
elif nk == 2:
|
||||
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
|
||||
k=k, kstride=kstride, nk=2, ktype=ktype,
|
||||
out=out, ostride=ostride)
|
||||
elif nk == 3:
|
||||
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
|
||||
k=k, kstride=kstride, nk=3, ktype=ktype,
|
||||
out=out, ostride=ostride)
|
||||
elif nk == 4:
|
||||
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
|
||||
k=k, kstride=kstride, nk=4, ktype=ktype,
|
||||
out=out, ostride=ostride)
|
||||
elif nk == 5:
|
||||
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
|
||||
k=k, kstride=kstride, nk=5, ktype=ktype,
|
||||
out=out, ostride=ostride)
|
||||
elif nk == 6:
|
||||
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
|
||||
k=k, kstride=kstride, nk=6, ktype=ktype,
|
||||
out=out, ostride=ostride)
|
||||
elif nk == 7:
|
||||
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
|
||||
k=k, kstride=kstride, nk=7, ktype=ktype,
|
||||
out=out, ostride=ostride)
|
||||
elif nk == 8:
|
||||
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
|
||||
k=k, kstride=kstride, nk=8, ktype=ktype,
|
||||
out=out, ostride=ostride)
|
||||
elif nk == 9:
|
||||
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
|
||||
k=k, kstride=kstride, nk=9, ktype=ktype,
|
||||
out=out, ostride=ostride)
|
||||
elif nk == 10:
|
||||
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
|
||||
k=k, kstride=kstride, nk=10, ktype=ktype,
|
||||
out=out, ostride=ostride)
|
||||
elif nk == 11:
|
||||
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
|
||||
k=k, kstride=kstride, nk=11, ktype=ktype,
|
||||
out=out, ostride=ostride)
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def dot(_ip1: Ptr[T1], is1: int, _ip2: Ptr[T2], is2: int, op: Ptr[T3], n: int,
|
||||
T1: type, T2: type, T3: type):
|
||||
ip1 = _ip1.as_byte()
|
||||
ip2 = _ip2.as_byte()
|
||||
ans = util.zero(T3)
|
||||
|
||||
for i in range(n):
|
||||
e1 = Ptr[T1](ip1)[0]
|
||||
e2 = Ptr[T2](ip2)[0]
|
||||
ans += util.cast(e1, T3) * util.cast(e2, T3)
|
||||
ip1 += is1
|
||||
ip2 += is2
|
||||
|
||||
op[0] = ans
|
||||
|
||||
def incr(p: Ptr[T], s: int, T: type):
|
||||
return Ptr[T](p.as_byte() + s)
|
||||
|
||||
n1 = a.size
|
||||
n2 = b.size
|
||||
length = n1
|
||||
n = n2
|
||||
|
||||
if mode == 'valid':
|
||||
length = length = length - n + 1
|
||||
n_left = 0
|
||||
n_right = 0
|
||||
elif mode == 'same':
|
||||
n_left = n >> 1
|
||||
n_right = n - n_left - 1
|
||||
elif mode == 'full':
|
||||
n_right = n - 1
|
||||
n_left = n - 1
|
||||
length = length + n - 1
|
||||
else:
|
||||
raise ValueError(
|
||||
f"mode must be one of 'valid', 'same', or 'full' (got {repr(mode)})"
|
||||
)
|
||||
|
||||
dt = type(util.coerce(a.dtype, b.dtype))
|
||||
ret = empty(length, dtype=dt)
|
||||
|
||||
is1 = a.strides[0]
|
||||
is2 = b.strides[0]
|
||||
op = ret.data
|
||||
os = ret.itemsize
|
||||
ip1 = a.data
|
||||
ip2 = Ptr[b.dtype](b.data.as_byte() + n_left * is2)
|
||||
n = n - n_left
|
||||
|
||||
for i in range(n_left):
|
||||
dot(ip1, is1, ip2, is2, op, n)
|
||||
n += 1
|
||||
ip2 = incr(ip2, -is2)
|
||||
op = incr(op, os)
|
||||
|
||||
if small_correlate(ip1, is1, n1 - n2 + 1, a.dtype,
|
||||
ip2, is2, n, b.dtype,
|
||||
op, os):
|
||||
ip1 = incr(ip1, is1 * (n1 - n2 + 1))
|
||||
op = incr(op, os * (n1 - n2 + 1))
|
||||
else:
|
||||
for i in range(n1 - n2 + 1):
|
||||
dot(ip1, is1, ip2, is2, op, n)
|
||||
ip1 = incr(ip1, is1)
|
||||
op = incr(op, os)
|
||||
|
||||
for i in range(n_right):
|
||||
n -= 1
|
||||
dot(ip1, is1, ip2, is2, op, n)
|
||||
ip1 = incr(ip1, is1)
|
||||
op = incr(op, os)
|
||||
|
||||
return ret
|
||||
|
||||
def correlate(a, b, mode: str = 'valid'):
|
||||
a = asarray(a)
|
||||
b = asarray(b)
|
||||
|
||||
if a.ndim != 1 or b.ndim != 1:
|
||||
compile_error('object too deep for desired array')
|
||||
|
||||
n1 = len(a)
|
||||
n2 = len(b)
|
||||
n1 = a.size
|
||||
n2 = b.size
|
||||
|
||||
if n1 == 0:
|
||||
raise ValueError("first argument cannot be empty")
|
||||
|
||||
if n2 == 0:
|
||||
raise ValueError("second argument cannot be empty")
|
||||
|
||||
if b.dtype is complex or b.dtype is complex64:
|
||||
b = b.conjugate()
|
||||
|
||||
if n1 < n2:
|
||||
inverted = 1
|
||||
inv = n1
|
||||
n1 = n2
|
||||
n2 = inv
|
||||
correlate(b, a, mode)
|
||||
return _correlate(b, a, mode=mode)[::-1]
|
||||
else:
|
||||
inverted = 0
|
||||
|
||||
length = n1
|
||||
n = n2
|
||||
if mode == 'valid':
|
||||
length = length - n + 1
|
||||
if (a.dtype is complex or b.dtype is complex or a.dtype is complex64
|
||||
or b.dtype is complex64):
|
||||
ret = zeros(length, dtype=complex)
|
||||
else:
|
||||
ret = empty(length)
|
||||
for i in range(length):
|
||||
for j in range(n):
|
||||
if inverted == 0:
|
||||
ret.data[i] += a._ptr(
|
||||
(j + i, ))[0] * conjugate(b._ptr((j, ))[0])
|
||||
else:
|
||||
ret.data[i] += a._ptr((j, ))[0] * b._ptr((j + i, ))[0]
|
||||
elif mode == 'same':
|
||||
if (a.dtype is complex or b.dtype is complex or a.dtype is complex64
|
||||
or b.dtype is complex64):
|
||||
ret = zeros(length, dtype=complex)
|
||||
else:
|
||||
ret = empty(length)
|
||||
for i in range(length):
|
||||
for j in range(n):
|
||||
signal_index = i - int(n / 2) + j
|
||||
if signal_index >= 0 and signal_index < length:
|
||||
if inverted == 0:
|
||||
ret.data[i] += a._ptr(
|
||||
(signal_index, ))[0] * conjugate(b._ptr((j, ))[0])
|
||||
else:
|
||||
ret.data[i] += a._ptr((j, ))[0] * b._ptr(
|
||||
(signal_index, ))[0]
|
||||
elif mode == 'full':
|
||||
full_length = length + n - 1
|
||||
if (a.dtype is complex or b.dtype is complex or a.dtype is complex64
|
||||
or b.dtype is complex64):
|
||||
ret = zeros(full_length, dtype=complex)
|
||||
else:
|
||||
ret = empty(full_length)
|
||||
for i in range(full_length):
|
||||
for j in range(n):
|
||||
signal_index = i + j - 2
|
||||
if signal_index >= 0 and signal_index < length:
|
||||
if inverted == 0:
|
||||
ret.data[i] += a._ptr(
|
||||
(signal_index, ))[0] * conjugate(b._ptr((j, ))[0])
|
||||
else:
|
||||
ret.data[i] += a._ptr((j, ))[0] * b._ptr(
|
||||
(signal_index, ))[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"mode must be one of 'valid', 'same', or 'full' (got {repr(mode)})"
|
||||
)
|
||||
|
||||
if inverted:
|
||||
ret = ret[::-1]
|
||||
if ret.dtype is complex or ret.dtype is complex64:
|
||||
ret.map(conjugate, inplace=True)
|
||||
|
||||
return ret
|
||||
return _correlate(a, b, mode=mode)
|
||||
|
||||
def bincount(x, weights=None, minlength: int = 0):
|
||||
x = asarray(x).astype(int)
|
||||
|
|
|
@ -199,3 +199,55 @@ def test_float_out_of_range_parse():
|
|||
assert 1e10000 == float('inf')
|
||||
|
||||
test_float_out_of_range_parse()
|
||||
|
||||
@test
|
||||
def test_int_float_ops(F: type):
|
||||
def check(got, exp=True):
|
||||
return (exp == got) and (type(exp) is type(got))
|
||||
|
||||
# standard
|
||||
assert check(F(1.5) + 1, F(2.5))
|
||||
assert check(F(1.5) - 1, F(0.5))
|
||||
assert check(F(1.5) * 2, F(3.0))
|
||||
assert check(F(1.5) / 2, F(0.75))
|
||||
assert check(F(3.5) // 2, F(1.0))
|
||||
assert check(F(3.5) % 2, F(1.5))
|
||||
assert check(F(3.5) ** 2, F(12.25))
|
||||
assert check(divmod(F(3.5), 2), (F(1.0), F(1.5)))
|
||||
|
||||
# right-hand ops
|
||||
assert check(1 + F(1.5), F(2.5))
|
||||
assert check(1 - F(1.5), F(-0.5))
|
||||
assert check(2 * F(1.5), F(3.0))
|
||||
assert check(2 / F(2.5), F(0.8))
|
||||
assert check(2 // F(1.5), F(1.0))
|
||||
assert check(2 % F(1.5), F(0.5))
|
||||
assert check(4 ** F(2.5), F(32.0))
|
||||
assert check(divmod(4, F(2.5)), (F(1.0), F(1.5)))
|
||||
|
||||
# comparisons
|
||||
assert check(F(1.0) == 1)
|
||||
assert check(F(2.0) != 1)
|
||||
assert check(F(0.0) < 1)
|
||||
assert check(F(2.0) > 1)
|
||||
assert check(F(0.0) <= 1)
|
||||
assert check(F(2.0) >= 1)
|
||||
assert check(1 == F(1.0))
|
||||
assert check(1 != F(2.0))
|
||||
assert check(1 < F(2.0))
|
||||
assert check(1 > F(0.0))
|
||||
assert check(1 <= F(2.0))
|
||||
assert check(1 >= F(0.0))
|
||||
|
||||
# power
|
||||
assert check(F(3.5) ** 1, F(3.5))
|
||||
assert check(F(3.5) ** 2, F(12.25))
|
||||
assert check(F(3.5) ** 3, F(42.875))
|
||||
assert check(F(4.0) ** -1, F(0.25))
|
||||
assert check(F(4.0) ** -2, F(0.0625))
|
||||
assert check(F(4.0) ** -3, F(0.015625))
|
||||
assert check(F(3.5) ** 0, F(1.0))
|
||||
|
||||
test_int_float_ops(float)
|
||||
test_int_float_ops(float32)
|
||||
test_int_float_ops(float16)
|
||||
|
|
|
@ -1238,7 +1238,7 @@ test_fill_diagonal(np.zeros((3, 5), int),
|
|||
|
||||
@test
|
||||
def test_pad(array, pad_width, expected, mode='constant', **kwargs):
|
||||
assert (np.pad(array, pad_width, mode, **kwargs) == expected).all()
|
||||
assert np.allclose(np.pad(array, pad_width, mode, **kwargs), expected)
|
||||
|
||||
test_pad([1, 2, 3, 4, 5], (2, 3),
|
||||
np.array([4, 4, 1, 2, 3, 4, 5, 6, 6, 6]),
|
||||
|
@ -1270,6 +1270,400 @@ test_pad([1, 2, 3, 4, 5], (2, 3),
|
|||
test_pad([1, 2, 3, 4, 5], (2, 3), np.array([4, 5, 1, 2, 3, 4, 5, 1, 2, 3]),
|
||||
'wrap')
|
||||
|
||||
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
|
||||
np.array([4, 4, 1, 2, 3, 4, 5, 6, 6, 6], dtype=np.float32),
|
||||
'constant',
|
||||
constant_values=(4, 6))
|
||||
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
|
||||
np.array([1, 1, 1, 2, 3, 4, 5, 5, 5, 5], dtype=np.float32),
|
||||
'edge')
|
||||
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
|
||||
np.array([5, 3, 1, 2, 3, 4, 5, 2, -1, -4], dtype=np.float32),
|
||||
'linear_ramp',
|
||||
end_values=(5, -4))
|
||||
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, ),
|
||||
np.array([5, 5, 1, 2, 3, 4, 5, 5, 5], dtype=np.float32),
|
||||
'maximum')
|
||||
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, ),
|
||||
np.array([3, 3, 1, 2, 3, 4, 5, 3, 3], dtype=np.float32),
|
||||
'mean')
|
||||
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, ),
|
||||
np.array([3, 3, 1, 2, 3, 4, 5, 3, 3], dtype=np.float32),
|
||||
'median')
|
||||
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
|
||||
np.array([3, 2, 1, 2, 3, 4, 5, 4, 3, 2], dtype=np.float32),
|
||||
'reflect')
|
||||
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
|
||||
np.array([-1, 0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float32),
|
||||
'reflect',
|
||||
reflect_type='odd')
|
||||
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
|
||||
np.array([2, 1, 1, 2, 3, 4, 5, 5, 4, 3], dtype=np.float32),
|
||||
'symmetric')
|
||||
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
|
||||
np.array([0, 1, 1, 2, 3, 4, 5, 5, 6, 7], dtype=np.float32),
|
||||
'symmetric',
|
||||
reflect_type='odd')
|
||||
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
|
||||
np.array([4, 5, 1, 2, 3, 4, 5, 1, 2, 3], dtype=np.float32),
|
||||
'wrap')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 0., 1., 2., 0., 0., 0.],
|
||||
[0., 0., 3., 4., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.]], dtype=np.float32),
|
||||
'constant')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[1., 1., 1., 2., 2., 2., 2.],
|
||||
[1., 1., 1., 2., 2., 2., 2.],
|
||||
[1., 1., 1., 2., 2., 2., 2.],
|
||||
[1., 1., 1., 2., 2., 2., 2.],
|
||||
[3., 3., 3., 4., 4., 4., 4.],
|
||||
[3., 3., 3., 4., 4., 4., 4.],
|
||||
[3., 3., 3., 4., 4., 4., 4.]], dtype=np.float32),
|
||||
'edge')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[0. , 0. , 0. , 0. , 0. ,
|
||||
0. , 0. ],
|
||||
[0. , 0.16666667, 0.33333334, 0.6666667 , 0.44444448,
|
||||
0.22222224, 0. ],
|
||||
[0. , 0.33333334, 0.6666667 , 1.3333334 , 0.88888896,
|
||||
0.44444448, 0. ],
|
||||
[0. , 0.5 , 1. , 2. , 1.3333334 ,
|
||||
0.6666667 , 0. ],
|
||||
[0. , 1.5 , 3. , 4. , 2.6666667 ,
|
||||
1.3333334 , 0. ],
|
||||
[0. , 0.75 , 1.5 , 2. , 1.3333334 ,
|
||||
0.6666667 , 0. ],
|
||||
[0. , 0. , 0. , 0. , 0. ,
|
||||
0. , 0. ]], dtype=np.float32),
|
||||
'linear_ramp')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[4., 4., 3., 4., 4., 4., 4.],
|
||||
[4., 4., 3., 4., 4., 4., 4.],
|
||||
[4., 4., 3., 4., 4., 4., 4.],
|
||||
[2., 2., 1., 2., 2., 2., 2.],
|
||||
[4., 4., 3., 4., 4., 4., 4.],
|
||||
[4., 4., 3., 4., 4., 4., 4.],
|
||||
[4., 4., 3., 4., 4., 4., 4.]], dtype=np.float32),
|
||||
'maximum')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[1.5, 1.5, 1. , 2. , 1.5, 1.5, 1.5],
|
||||
[3.5, 3.5, 3. , 4. , 3.5, 3.5, 3.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5]], dtype=np.float32),
|
||||
'mean')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[1.5, 1.5, 1. , 2. , 1.5, 1.5, 1.5],
|
||||
[3.5, 3.5, 3. , 4. , 3.5, 3.5, 3.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5]], dtype=np.float32),
|
||||
'median')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[1., 1., 1., 2., 1., 1., 1.],
|
||||
[1., 1., 1., 2., 1., 1., 1.],
|
||||
[1., 1., 1., 2., 1., 1., 1.],
|
||||
[1., 1., 1., 2., 1., 1., 1.],
|
||||
[3., 3., 3., 4., 3., 3., 3.],
|
||||
[1., 1., 1., 2., 1., 1., 1.],
|
||||
[1., 1., 1., 2., 1., 1., 1.]], dtype=np.float32),
|
||||
'minimum')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[3., 4., 3., 4., 3., 4., 3.],
|
||||
[1., 2., 1., 2., 1., 2., 1.],
|
||||
[3., 4., 3., 4., 3., 4., 3.],
|
||||
[1., 2., 1., 2., 1., 2., 1.],
|
||||
[3., 4., 3., 4., 3., 4., 3.],
|
||||
[1., 2., 1., 2., 1., 2., 1.],
|
||||
[3., 4., 3., 4., 3., 4., 3.]], dtype=np.float32),
|
||||
'reflect')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[4., 3., 3., 4., 4., 3., 3.],
|
||||
[4., 3., 3., 4., 4., 3., 3.],
|
||||
[2., 1., 1., 2., 2., 1., 1.],
|
||||
[2., 1., 1., 2., 2., 1., 1.],
|
||||
[4., 3., 3., 4., 4., 3., 3.],
|
||||
[4., 3., 3., 4., 4., 3., 3.],
|
||||
[2., 1., 1., 2., 2., 1., 1.]], dtype=np.float32),
|
||||
'symmetric')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[3., 4., 3., 4., 3., 4., 3.],
|
||||
[1., 2., 1., 2., 1., 2., 1.],
|
||||
[3., 4., 3., 4., 3., 4., 3.],
|
||||
[1., 2., 1., 2., 1., 2., 1.],
|
||||
[3., 4., 3., 4., 3., 4., 3.],
|
||||
[1., 2., 1., 2., 1., 2., 1.],
|
||||
[3., 4., 3., 4., 3., 4., 3.]], dtype=np.float32),
|
||||
'wrap')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 0., 1., 2., 0., 0., 0.],
|
||||
[0., 0., 3., 4., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.],
|
||||
[0., 0., 0., 0., 0., 0., 0.]], dtype=np.float32),
|
||||
'constant')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[-1., -1., -1., -1., -1., -1., -1.],
|
||||
[-1., -1., -1., -1., -1., -1., -1.],
|
||||
[-1., -1., -1., -1., -1., -1., -1.],
|
||||
[-1., -1., 1., 2., -1., -1., -1.],
|
||||
[-1., -1., 3., 4., -1., -1., -1.],
|
||||
[-1., -1., -1., -1., -1., -1., -1.],
|
||||
[-1., -1., -1., -1., -1., -1., -1.]], dtype=np.float32),
|
||||
'constant', constant_values=-1)
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[-1., -1., -1., -1., -2., -2., -2.],
|
||||
[-1., -1., -1., -1., -2., -2., -2.],
|
||||
[-1., -1., -1., -1., -2., -2., -2.],
|
||||
[-1., -1., 1., 2., -2., -2., -2.],
|
||||
[-1., -1., 3., 4., -2., -2., -2.],
|
||||
[-1., -1., -2., -2., -2., -2., -2.],
|
||||
[-1., -1., -2., -2., -2., -2., -2.]], dtype=np.float32),
|
||||
'constant', constant_values=((-1, -2)))
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[-3., -3., -1., -1., -4., -4., -4.],
|
||||
[-3., -3., -1., -1., -4., -4., -4.],
|
||||
[-3., -3., -1., -1., -4., -4., -4.],
|
||||
[-3., -3., 1., 2., -4., -4., -4.],
|
||||
[-3., -3., 3., 4., -4., -4., -4.],
|
||||
[-3., -3., -2., -2., -4., -4., -4.],
|
||||
[-3., -3., -2., -2., -4., -4., -4.]], dtype=np.float32),
|
||||
'constant', constant_values=((-1, -2), (-3, -4)))
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[1., 1., 1., 2., 2., 2., 2.],
|
||||
[1., 1., 1., 2., 2., 2., 2.],
|
||||
[1., 1., 1., 2., 2., 2., 2.],
|
||||
[1., 1., 1., 2., 2., 2., 2.],
|
||||
[3., 3., 3., 4., 4., 4., 4.],
|
||||
[3., 3., 3., 4., 4., 4., 4.],
|
||||
[3., 3., 3., 4., 4., 4., 4.]], dtype=np.float32),
|
||||
'edge')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[0. , 0. , 0. , 0. , 0. ,
|
||||
0. , 0. ],
|
||||
[0. , 0.16666667, 0.33333334, 0.6666667 , 0.44444448,
|
||||
0.22222224, 0. ],
|
||||
[0. , 0.33333334, 0.6666667 , 1.3333334 , 0.88888896,
|
||||
0.44444448, 0. ],
|
||||
[0. , 0.5 , 1. , 2. , 1.3333334 ,
|
||||
0.6666667 , 0. ],
|
||||
[0. , 1.5 , 3. , 4. , 2.6666667 ,
|
||||
1.3333334 , 0. ],
|
||||
[0. , 0.75 , 1.5 , 2. , 1.3333334 ,
|
||||
0.6666667 , 0. ],
|
||||
[0. , 0. , 0. , 0. , 0. ,
|
||||
0. , 0. ]], dtype=np.float32),
|
||||
'linear_ramp')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[1. , 1. , 1. , 1. , 1. , 1. ,
|
||||
1. ],
|
||||
[1. , 1. , 1. , 1.3333334, 1.2222222, 1.1111112,
|
||||
1. ],
|
||||
[1. , 1. , 1. , 1.6666667, 1.4444445, 1.2222222,
|
||||
1. ],
|
||||
[1. , 1. , 1. , 2. , 1.6666667, 1.3333334,
|
||||
1. ],
|
||||
[1. , 2. , 3. , 4. , 3. , 2. ,
|
||||
1. ],
|
||||
[1. , 1.5 , 2. , 2.5 , 2. , 1.5 ,
|
||||
1. ],
|
||||
[1. , 1. , 1. , 1. , 1. , 1. ,
|
||||
1. ]], dtype=np.float32),
|
||||
'linear_ramp', end_values=1)
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[1. , 1. , 1. , 1. , 1.3333333, 1.6666666,
|
||||
2. ],
|
||||
[1. , 1. , 1. , 1.3333334, 1.5555556, 1.7777778,
|
||||
2. ],
|
||||
[1. , 1. , 1. , 1.6666667, 1.7777778, 1.888889 ,
|
||||
2. ],
|
||||
[1. , 1. , 1. , 2. , 2. , 2. ,
|
||||
2. ],
|
||||
[1. , 2. , 3. , 4. , 3.3333335, 2.6666667,
|
||||
2. ],
|
||||
[1. , 1.75 , 2.5 , 3. , 2.6666667, 2.3333333,
|
||||
2. ],
|
||||
[1. , 1.5 , 2. , 2. , 2. , 2. ,
|
||||
2. ]], dtype=np.float32),
|
||||
'linear_ramp', end_values=(1, 2))
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[2. , 1.5 , 1. , 1. , 1. , 1. ,
|
||||
1. ],
|
||||
[2. , 1.5 , 1. , 1.3333334, 1.2222222, 1.1111112,
|
||||
1. ],
|
||||
[2. , 1.5 , 1. , 1.6666667, 1.4444445, 1.2222222,
|
||||
1. ],
|
||||
[2. , 1.5 , 1. , 2. , 1.6666667, 1.3333334,
|
||||
1. ],
|
||||
[2. , 2.5 , 3. , 4. , 3. , 2. ,
|
||||
1. ],
|
||||
[2. , 2.25 , 2.5 , 3. , 2.3333335, 1.6666667,
|
||||
1. ],
|
||||
[2. , 2. , 2. , 2. , 1.6666667, 1.3333334,
|
||||
1. ]], dtype=np.float32),
|
||||
'linear_ramp', end_values=((1, 2), (2, 1)))
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[4., 4., 3., 4., 4., 4., 4.],
|
||||
[4., 4., 3., 4., 4., 4., 4.],
|
||||
[4., 4., 3., 4., 4., 4., 4.],
|
||||
[2., 2., 1., 2., 2., 2., 2.],
|
||||
[4., 4., 3., 4., 4., 4., 4.],
|
||||
[4., 4., 3., 4., 4., 4., 4.],
|
||||
[4., 4., 3., 4., 4., 4., 4.]], dtype=np.float32),
|
||||
'maximum')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[1.5, 1.5, 1. , 2. , 1.5, 1.5, 1.5],
|
||||
[3.5, 3.5, 3. , 4. , 3.5, 3.5, 3.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5]], dtype=np.float32),
|
||||
'mean')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[1., 1., 1., 2., 2., 2., 2.],
|
||||
[1., 1., 1., 2., 2., 2., 2.],
|
||||
[1., 1., 1., 2., 2., 2., 2.],
|
||||
[1., 1., 1., 2., 2., 2., 2.],
|
||||
[3., 3., 3., 4., 4., 4., 4.],
|
||||
[3., 3., 3., 4., 4., 4., 4.],
|
||||
[3., 3., 3., 4., 4., 4., 4.]], dtype=np.float32),
|
||||
'mean', stat_length=1)
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[1. , 1. , 1. , 2. , 1.5, 1.5, 1.5],
|
||||
[1. , 1. , 1. , 2. , 1.5, 1.5, 1.5],
|
||||
[1. , 1. , 1. , 2. , 1.5, 1.5, 1.5],
|
||||
[1. , 1. , 1. , 2. , 1.5, 1.5, 1.5],
|
||||
[3. , 3. , 3. , 4. , 3.5, 3.5, 3.5],
|
||||
[2. , 2. , 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[2. , 2. , 2. , 3. , 2.5, 2.5, 2.5]], dtype=np.float32),
|
||||
'mean', stat_length=(1, 2))
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[1.5, 1.5, 1. , 2. , 2. , 2. , 2. ],
|
||||
[1.5, 1.5, 1. , 2. , 2. , 2. , 2. ],
|
||||
[1.5, 1.5, 1. , 2. , 2. , 2. , 2. ],
|
||||
[1.5, 1.5, 1. , 2. , 2. , 2. , 2. ],
|
||||
[3.5, 3.5, 3. , 4. , 4. , 4. , 4. ],
|
||||
[2.5, 2.5, 2. , 3. , 3. , 3. , 3. ],
|
||||
[2.5, 2.5, 2. , 3. , 3. , 3. , 3. ]], dtype=np.float32),
|
||||
'mean', stat_length=((1, 2), (2, 1)))
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[1.5, 1.5, 1. , 2. , 1.5, 1.5, 1.5],
|
||||
[3.5, 3.5, 3. , 4. , 3.5, 3.5, 3.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
|
||||
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5]], dtype=np.float32),
|
||||
'median')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[1., 1., 1., 2., 1., 1., 1.],
|
||||
[1., 1., 1., 2., 1., 1., 1.],
|
||||
[1., 1., 1., 2., 1., 1., 1.],
|
||||
[1., 1., 1., 2., 1., 1., 1.],
|
||||
[3., 3., 3., 4., 3., 3., 3.],
|
||||
[1., 1., 1., 2., 1., 1., 1.],
|
||||
[1., 1., 1., 2., 1., 1., 1.]], dtype=np.float32),
|
||||
'minimum')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[3., 4., 3., 4., 3., 4., 3.],
|
||||
[1., 2., 1., 2., 1., 2., 1.],
|
||||
[3., 4., 3., 4., 3., 4., 3.],
|
||||
[1., 2., 1., 2., 1., 2., 1.],
|
||||
[3., 4., 3., 4., 3., 4., 3.],
|
||||
[1., 2., 1., 2., 1., 2., 1.],
|
||||
[3., 4., 3., 4., 3., 4., 3.]], dtype=np.float32),
|
||||
'reflect')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[4., 3., 3., 4., 4., 3., 3.],
|
||||
[4., 3., 3., 4., 4., 3., 3.],
|
||||
[2., 1., 1., 2., 2., 1., 1.],
|
||||
[2., 1., 1., 2., 2., 1., 1.],
|
||||
[4., 3., 3., 4., 4., 3., 3.],
|
||||
[4., 3., 3., 4., 4., 3., 3.],
|
||||
[2., 1., 1., 2., 2., 1., 1.]], dtype=np.float32),
|
||||
'symmetric')
|
||||
|
||||
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
|
||||
np.array([[3., 4., 3., 4., 3., 4., 3.],
|
||||
[1., 2., 1., 2., 1., 2., 1.],
|
||||
[3., 4., 3., 4., 3., 4., 3.],
|
||||
[1., 2., 1., 2., 1., 2., 1.],
|
||||
[3., 4., 3., 4., 3., 4., 3.],
|
||||
[1., 2., 1., 2., 1., 2., 1.],
|
||||
[3., 4., 3., 4., 3., 4., 3.]], dtype=np.float32),
|
||||
'wrap')
|
||||
|
||||
test_pad([[[[1]]]], 1,
|
||||
np.array([[[[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]],
|
||||
[[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]],
|
||||
[[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]]],
|
||||
[[[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]],
|
||||
[[2, 2, 2],
|
||||
[2, 1, 2],
|
||||
[2, 2, 2]],
|
||||
[[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]]],
|
||||
[[[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]],
|
||||
[[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]],
|
||||
[[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]]]]), 'constant', constant_values=2)
|
||||
|
||||
test_pad([[[[1]]]], 0, np.array([[[[1]]]]), 'constant', constant_values=2)
|
||||
|
||||
#############
|
||||
# searching #
|
||||
#############
|
||||
|
|
|
@ -247,6 +247,67 @@ test_correlate(np.array([1 + 0j, 2 + 0j, 3 + 0j, 4 + 1j]),
|
|||
np.array([-1 + 0j, -2j, 3 + 1j]),
|
||||
np.array([8. + 1.j, 11. + 5.j]))
|
||||
|
||||
@test
def test_correlate2():
    # Checks np.correlate in all three modes ("valid", "same", "full") against
    # hand-computed values for integer, floating-point, complex,
    # different-length, mixed-dtype and single-element inputs, plus one larger
    # case whose reference values are computed with explicit loops.

    # Integer inputs
    a = np.array([1, 2, 3])
    b = np.array([4, 5, 6])
    assert np.allclose(np.correlate(a, b, mode="valid"), [32])
    assert np.allclose(np.correlate(a, b, mode="same"), [17, 32, 23])
    assert np.allclose(np.correlate(a, b, mode="full"), [6, 17, 32, 23, 12])

    # Floating-point inputs
    a = np.array([1.5, 2.5, 3.5])
    b = np.array([4.0, 5.0, 6.0])
    assert np.allclose(np.correlate(a, b, mode="valid"), [39.5])
    assert np.allclose(np.correlate(a, b, mode="same"), [22.5, 39.5, 27.5])
    assert np.allclose(np.correlate(a, b, mode="full"), [9.0, 22.5, 39.5, 27.5, 14.0])

    # Complex numbers (the second operand is conjugated by correlate)
    a = np.array([1+2j, 3+4j])
    b = np.array([5+6j, 7+8j])
    assert np.allclose(np.correlate(a, b, mode="valid"), [70+8j])
    assert np.allclose(np.correlate(a, b, mode="same"), [23+6j, 70+8j])
    assert np.allclose(np.correlate(a, b, mode="full"), [23+6j, 70+8j, 39+2j])

    # Different-length arrays
    a = np.array([1, 2, 3, 4])
    b = np.array([0, 1])
    assert np.allclose(np.correlate(a, b, mode="valid"), [2, 3, 4])
    assert np.allclose(np.correlate(a, b, mode="same"), [1, 2, 3, 4])
    assert np.allclose(np.correlate(a, b, mode="full"), [1, 2, 3, 4, 0])
    # Same pair with the operands swapped: the result comes out reversed.
    a = np.array([0, 1])
    b = np.array([1, 2, 3, 4])
    assert np.allclose(np.correlate(a, b, mode="valid")[::-1], [2, 3, 4])
    assert np.allclose(np.correlate(a, b, mode="same")[::-1], [1, 2, 3, 4])
    assert np.allclose(np.correlate(a, b, mode="full")[::-1], [1, 2, 3, 4, 0])

    # Large array test
    a = np.arange(20)
    b = np.arange(10)
    n = len(a)
    m = len(b)
    expected_valid = np.array([np.sum(a[i : i + len(b)] * b) for i in range(len(a) - len(b) + 1)])
    # Reference values are computed independently of np.correlate; comparing
    # np.correlate with its own output would pass no matter what it returns.
    # Entry for lag k is sum_j a[k + j] * b[j], skipping out-of-range terms.
    expected_full = np.array([
        sum(a[k + j] * b[j] for j in range(m) if 0 <= k + j < n)
        for k in range(-(m - 1), n)
    ])
    # "same" keeps len(a) entries of the full result starting at offset
    # (len(b) - 1) // 2, consistent with the hand-written cases above
    # (offset 1 for len(b) == 3, offset 0 for len(b) == 2).
    first = (m - 1) // 2
    expected_same = expected_full[first:first + n]
    assert np.allclose(np.correlate(a, b, mode="valid"), expected_valid)
    assert np.allclose(np.correlate(a, b, mode="full"), expected_full)
    assert np.allclose(np.correlate(a, b, mode="same"), expected_same)

    # Different dtypes (int and float)
    a = np.array([1, 2, 3], dtype=int)
    b = np.array([1.5, 2.5, 3.5], dtype=float)
    assert np.allclose(np.correlate(a, b, mode="valid"), [17.0])
    assert np.allclose(np.correlate(a, b, mode="same"), [9.5, 17.0, 10.5])
    assert np.allclose(np.correlate(a, b, mode="full"), [3.5, 9.5, 17.0, 10.5, 4.5])

    # Edge case: Single-element arrays
    a = np.array([5])
    b = np.array([10])
    assert np.allclose(np.correlate(a, b, mode="valid"), [50])
    assert np.allclose(np.correlate(a, b, mode="same"), [50])
    assert np.allclose(np.correlate(a, b, mode="full"), [50])

test_correlate2()
|
||||
|
||||
@test
|
||||
def test_bincount(x, expected, weights=None, minlength=0):
|
||||
assert (np.bincount(x, weights=weights,
|
||||
|
|
|
@ -181,3 +181,16 @@ def test_ndarray():
|
|||
assert np.datetime_data(y.dtype) == ('s', 2)
|
||||
|
||||
test_ndarray()
|
||||
|
||||
# JIT-compiled helper with two defaulted parameters; the result weights x and
# y differently so that swapped arguments produce a different value.
@codon.jit
def e(x=2, y=99):
    return 2 * x + y

def test_arg_order():
    # Each row is (positional args, keyword args, expected result) for a call
    # to e(): positional, partially defaulted, keyword-only in reversed order,
    # single keyword, and fully defaulted.
    calls = [
        ((1, 2), {}, 4),
        ((1,), {}, 101),
        ((), {'y': 10, 'x': 1}, 12),
        ((), {'x': 1}, 101),
        ((), {}, 103),
    ]
    for positional, keywords, wanted in calls:
        assert e(*positional, **keywords) == wanted

test_arg_order()
|
||||
|
|
|
@ -88,6 +88,7 @@ def test_codon_extensions(m):
|
|||
assert m.f4(a=2.2) == (2.2, 2.22)
|
||||
assert m.f4(b=3.3) == (1.11, 3.3)
|
||||
assert m.f4('foo') == ('foo', 'foo')
|
||||
assert m.f4(b'foo') == ('foo', 'foo')
|
||||
assert m.f4({1}) == {1}
|
||||
assert m.f5() is None
|
||||
assert equal(m.f6(1.9, 't'), 1.9, 1.9, 't')
|
||||
|
|
|
@ -450,6 +450,18 @@ def test_omp_reductions():
|
|||
c = min(b, c)
|
||||
assert c == -1.
|
||||
|
||||
c = 0.
|
||||
@par
|
||||
for i in L:
|
||||
c += i # float-int op
|
||||
assert c == expected(N, 0., float.__add__)
|
||||
|
||||
c = 0.
|
||||
@par
|
||||
for i in L:
|
||||
c = i + c # int-float op
|
||||
assert c == expected(N, 0., float.__add__)
|
||||
|
||||
# float32s
|
||||
c = f32(0.)
|
||||
# this one can give different results due to
|
||||
|
@ -479,6 +491,18 @@ def test_omp_reductions():
|
|||
c = min(b, c)
|
||||
assert c == f32(-1.)
|
||||
|
||||
c = f32(0.)
|
||||
@par
|
||||
for i in L[:12]:
|
||||
c += i # float-int op
|
||||
assert c == f32(1+2+3+4+5+6+7+8+9+10+11)
|
||||
|
||||
c = f32(0.)
|
||||
@par
|
||||
for i in L[:12]:
|
||||
c = i + c # int-float op
|
||||
assert c == f32(1+2+3+4+5+6+7+8+9+10+11)
|
||||
|
||||
x_add = 10.
|
||||
x_min = inf
|
||||
x_max = -inf
|
||||
|
|
Loading…
Reference in New Issue