Compare commits

...

16 Commits

Author SHA1 Message Date
Jonas Neubert dcb41dcfc9
codon build command: add --cir output type option (#649) 2025-04-22 11:46:03 -04:00
A. R. Shajii c1dae7d87d Update OpenBLAS 2025-04-04 14:59:13 -04:00
A. R. Shajii 984974b40d Support CMake 4.0 2025-04-04 11:27:35 -04:00
A. R. Shajii 915cb4e9f0
Support converting bytes object to Codon str (#646) 2025-04-03 10:41:19 -04:00
A. R. Shajii ce5c49edb5 Fix typo in docs and README 2025-04-03 10:39:45 -04:00
A. R. Shajii 59f5bbb73b Bump versions 2025-03-18 10:46:58 -04:00
A. R. Shajii 93fb3d53e3
JIT argument order fix (#639)
* Fix argument ordering in JIT

* Format

* Update JIT tests

* Fix JIT test
2025-03-18 10:45:34 -04:00
A. R. Shajii b3f6c12d57 Fix 0d array conversions from Python 2025-03-03 11:31:49 -05:00
A. R. Shajii b17d21513d Remove -static-libstdc++ compilation flag 2025-02-18 14:49:45 -05:00
Ibrahim Numanagić d035f1dc97
C-based Cython Backend (#629)
* Move to C-based Cython backend (to avoid all those C++ ABI issues with std::string)

* Fix CI
2025-02-18 10:22:03 -05:00
A. R. Shajii dc5e5ac7a6 Bump version 2025-02-11 22:04:22 -05:00
A. R. Shajii 01a7503762 Bump version 2025-02-11 17:41:16 -05:00
A. R. Shajii f1ab7116d8 Fix np.pad() casting 2025-02-11 15:49:15 -05:00
A. R. Shajii b58b1ee767 Update OpenMP reduction detection for new ops 2025-02-07 12:04:12 -05:00
A. R. Shajii 56c00d36c2 Add additional int-float operators 2025-02-06 14:11:52 -05:00
A. R. Shajii 4521182aa8 Update np.correlate() 2025-02-04 17:32:54 -05:00
28 changed files with 1402 additions and 247 deletions

View File

@ -187,7 +187,7 @@ jobs:
- name: Prepare Artifacts - name: Prepare Artifacts
run: | run: |
cp -rf codon-deploy/python/dist . cp -rf codon-deploy/python/dist .
rm -rf codon-deploy/lib/libfmt.a codon-deploy/lib/pkgconfig codon-deploy/lib/cmake codon-deploy/python/codon.egg-info codon-deploy/python/dist codon-deploy/python/build rm -rf codon-deploy/lib/libfmt.a codon-deploy/lib/pkgconfig codon-deploy/lib/cmake codon-deploy/python/codon_jit.egg-info codon-deploy/python/build
tar -czf ${CODON_BUILD_ARCHIVE} codon-deploy tar -czf ${CODON_BUILD_ARCHIVE} codon-deploy
du -sh codon-deploy du -sh codon-deploy

View File

@ -1,10 +1,10 @@
cmake_minimum_required(VERSION 3.14) cmake_minimum_required(VERSION 3.14)
project( project(
Codon Codon
VERSION "0.18.0" VERSION "0.18.2"
HOMEPAGE_URL "https://github.com/exaloop/codon" HOMEPAGE_URL "https://github.com/exaloop/codon"
DESCRIPTION "high-performance, extensible Python compiler") DESCRIPTION "high-performance, extensible Python compiler")
set(CODON_JIT_PYTHON_VERSION "0.3.0") set(CODON_JIT_PYTHON_VERSION "0.3.2")
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in" configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in"
"${PROJECT_SOURCE_DIR}/codon/config/config.h") "${PROJECT_SOURCE_DIR}/codon/config/config.h")
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in" configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in"
@ -48,10 +48,8 @@ include(${CMAKE_SOURCE_DIR}/cmake/deps.cmake)
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON) set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
if(APPLE) if(APPLE)
set(CMAKE_INSTALL_RPATH "@loader_path;@loader_path/../lib/codon") set(CMAKE_INSTALL_RPATH "@loader_path;@loader_path/../lib/codon")
set(STATIC_LIBCPP "")
else() else()
set(CMAKE_INSTALL_RPATH "$ORIGIN:$ORIGIN/../lib/codon") set(CMAKE_INSTALL_RPATH "$ORIGIN:$ORIGIN/../lib/codon")
set(STATIC_LIBCPP "-static-libstdc++")
endif() endif()
add_executable(peg2cpp codon/util/peg2cpp.cpp) add_executable(peg2cpp codon/util/peg2cpp.cpp)
@ -138,7 +136,7 @@ target_include_directories(codonrt PRIVATE ${backtrace_SOURCE_DIR}
${highway_SOURCE_DIR} ${highway_SOURCE_DIR}
"${gc_SOURCE_DIR}/include" "${gc_SOURCE_DIR}/include"
"${fast_float_SOURCE_DIR}/include" runtime) "${fast_float_SOURCE_DIR}/include" runtime)
target_link_libraries(codonrt PRIVATE fmt omp backtrace ${STATIC_LIBCPP} LLVMSupport) target_link_libraries(codonrt PRIVATE fmt omp backtrace LLVMSupport)
if(APPLE) if(APPLE)
target_link_libraries( target_link_libraries(
codonrt codonrt
@ -434,11 +432,7 @@ llvm_map_components_to_libnames(
TransformUtils TransformUtils
Vectorize Vectorize
Passes) Passes)
if(APPLE) target_link_libraries(codonc PRIVATE ${LLVM_LIBS} fmt dl codonrt)
target_link_libraries(codonc PRIVATE ${LLVM_LIBS} fmt dl codonrt)
else()
target_link_libraries(codonc PRIVATE ${STATIC_LIBCPP} ${LLVM_LIBS} fmt dl codonrt)
endif()
# Gather headers # Gather headers
add_custom_target( add_custom_target(
@ -482,13 +476,13 @@ add_dependencies(libs codonrt codonc)
# Codon command-line tool # Codon command-line tool
add_executable(codon codon/app/main.cpp) add_executable(codon codon/app/main.cpp)
target_link_libraries(codon PUBLIC ${STATIC_LIBCPP} fmt codonc codon_jupyter Threads::Threads) target_link_libraries(codon PUBLIC fmt codonc codon_jupyter Threads::Threads)
# Codon test Download and unpack googletest at configure time # Codon test Download and unpack googletest at configure time
include(FetchContent) include(FetchContent)
FetchContent_Declare( FetchContent_Declare(
googletest googletest
URL https://github.com/google/googletest/archive/609281088cfefc76f9d0ce82e1ff6c30cc3591e5.zip URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip
) )
# For Windows: Prevent overriding the parent project's compiler/linker settings # For Windows: Prevent overriding the parent project's compiler/linker settings
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)

View File

@ -149,7 +149,7 @@ print(total)
``` ```
Note that Codon automatically turns the `total += 1` statement in the loop body into an atomic Note that Codon automatically turns the `total += 1` statement in the loop body into an atomic
reduction to avoid race conditions. Learn more in the [multitheading docs](advanced/parallel.md). reduction to avoid race conditions. Learn more in the [multithreading docs](https://docs.exaloop.io/codon/advanced/parallel).
Codon also supports writing and executing GPU kernels. Here's an example that computes the Codon also supports writing and executing GPU kernels. Here's an example that computes the
[Mandelbrot set](https://en.wikipedia.org/wiki/Mandelbrot_set): [Mandelbrot set](https://en.wikipedia.org/wiki/Mandelbrot_set):

View File

@ -1,8 +1,8 @@
set(CPM_DOWNLOAD_VERSION 0.32.3) set(CPM_DOWNLOAD_VERSION 0.40.8)
set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake") set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
if(NOT (EXISTS ${CPM_DOWNLOAD_LOCATION})) if(NOT (EXISTS ${CPM_DOWNLOAD_LOCATION}))
message(STATUS "Downloading CPM.cmake...") message(STATUS "Downloading CPM.cmake...")
file(DOWNLOAD https://github.com/TheLartians/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake ${CPM_DOWNLOAD_LOCATION}) file(DOWNLOAD https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake ${CPM_DOWNLOAD_LOCATION})
endif() endif()
include(${CPM_DOWNLOAD_LOCATION}) include(${CPM_DOWNLOAD_LOCATION})
@ -77,9 +77,9 @@ endif()
CPMAddPackage( CPMAddPackage(
NAME bdwgc NAME bdwgc
GITHUB_REPOSITORY "ivmai/bdwgc" GITHUB_REPOSITORY "exaloop/bdwgc"
VERSION 8.0.5 VERSION 8.0.5
GIT_TAG d0ba209660ea8c663e06d9a68332ba5f42da54ba GIT_TAG e16c67244aff26802203060422545d38305e0160
EXCLUDE_FROM_ALL YES EXCLUDE_FROM_ALL YES
OPTIONS "CMAKE_POSITION_INDEPENDENT_CODE ON" OPTIONS "CMAKE_POSITION_INDEPENDENT_CODE ON"
"BUILD_SHARED_LIBS OFF" "BUILD_SHARED_LIBS OFF"
@ -169,7 +169,7 @@ if(NOT APPLE)
CPMAddPackage( CPMAddPackage(
NAME openblas NAME openblas
GITHUB_REPOSITORY "OpenMathLib/OpenBLAS" GITHUB_REPOSITORY "OpenMathLib/OpenBLAS"
GIT_TAG v0.3.28 GIT_TAG v0.3.29
EXCLUDE_FROM_ALL YES EXCLUDE_FROM_ALL YES
OPTIONS "DYNAMIC_ARCH ON" OPTIONS "DYNAMIC_ARCH ON"
"BUILD_TESTING OFF" "BUILD_TESTING OFF"

View File

@ -11,6 +11,7 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "codon/cir/util/format.h"
#include "codon/compiler/compiler.h" #include "codon/compiler/compiler.h"
#include "codon/compiler/error.h" #include "codon/compiler/error.h"
#include "codon/compiler/jit.h" #include "codon/compiler/jit.h"
@ -87,7 +88,7 @@ void initLogFlags(const llvm::cl::opt<std::string> &log) {
codon::getLogger().parse(std::string(d)); codon::getLogger().parse(std::string(d));
} }
enum BuildKind { LLVM, Bitcode, Object, Executable, Library, PyExtension, Detect }; enum BuildKind { LLVM, Bitcode, Object, Executable, Library, PyExtension, Detect, CIR };
enum OptMode { Debug, Release }; enum OptMode { Debug, Release };
enum Numerics { C, Python }; enum Numerics { C, Python };
} // namespace } // namespace
@ -333,6 +334,7 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
clEnumValN(Executable, "exe", "Generate executable"), clEnumValN(Executable, "exe", "Generate executable"),
clEnumValN(Library, "lib", "Generate shared library"), clEnumValN(Library, "lib", "Generate shared library"),
clEnumValN(PyExtension, "pyext", "Generate Python extension module"), clEnumValN(PyExtension, "pyext", "Generate Python extension module"),
clEnumValN(CIR, "cir", "Generate Codon Intermediate Representation"),
clEnumValN(Detect, "detect", clEnumValN(Detect, "detect",
"Detect output type based on output file extension")), "Detect output type based on output file extension")),
llvm::cl::init(Detect)); llvm::cl::init(Detect));
@ -372,6 +374,9 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
case BuildKind::Detect: case BuildKind::Detect:
extension = ""; extension = "";
break; break;
case BuildKind::CIR:
extension = ".cir";
break;
default: default:
seqassertn(0, "unknown build kind"); seqassertn(0, "unknown build kind");
} }
@ -401,6 +406,11 @@ int buildMode(const std::vector<const char *> &args, const std::string &argv0) {
compiler->getLLVMVisitor()->writeToPythonExtension(*compiler->getCache()->pyModule, compiler->getLLVMVisitor()->writeToPythonExtension(*compiler->getCache()->pyModule,
filename); filename);
break; break;
case BuildKind::CIR: {
std::ofstream out(filename);
codon::ir::util::format(out, compiler->getModule());
break;
}
case BuildKind::Detect: case BuildKind::Detect:
compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags); compiler->getLLVMVisitor()->compile(filename, argv0, libsVec, lflags);
break; break;

View File

@ -402,7 +402,8 @@ struct ReductionIdentifier : public util::Operator {
static void extractAssociativeOpChain(Value *v, const std::string &op, static void extractAssociativeOpChain(Value *v, const std::string &op,
types::Type *type, types::Type *type,
std::vector<Value *> &result) { std::vector<Value *> &result) {
if (util::isCallOf(v, op, {type, type}, type, /*method=*/true)) { if (util::isCallOf(v, op, {type, nullptr}, type, /*method=*/true) ||
util::isCallOf(v, op, {nullptr, type}, type, /*method=*/true)) {
auto *call = cast<CallInstr>(v); auto *call = cast<CallInstr>(v);
extractAssociativeOpChain(call->front(), op, type, result); extractAssociativeOpChain(call->front(), op, type, result);
extractAssociativeOpChain(call->back(), op, type, result); extractAssociativeOpChain(call->back(), op, type, result);
@ -450,7 +451,8 @@ struct ReductionIdentifier : public util::Operator {
for (auto &rf : reductionFunctions) { for (auto &rf : reductionFunctions) {
if (rf.method) { if (rf.method) {
if (!util::isCallOf(item, rf.name, {type, type}, type, /*method=*/true)) if (!(util::isCallOf(item, rf.name, {type, nullptr}, type, /*method=*/true) ||
util::isCallOf(item, rf.name, {nullptr, type}, type, /*method=*/true)))
continue; continue;
} else { } else {
if (!util::isCallOf(item, rf.name, if (!util::isCallOf(item, rf.name,
@ -464,8 +466,7 @@ struct ReductionIdentifier : public util::Operator {
if (rf.method) { if (rf.method) {
std::vector<Value *> opChain; std::vector<Value *> opChain;
extractAssociativeOpChain(callRHS, rf.name, callRHS->front()->getType(), extractAssociativeOpChain(callRHS, rf.name, type, opChain);
opChain);
if (opChain.size() < 2) if (opChain.size() < 2)
continue; continue;

View File

@ -38,16 +38,21 @@ bool isCallOf(const Value *value, const std::string &name,
unsigned i = 0; unsigned i = 0;
for (auto *arg : *call) { for (auto *arg : *call) {
if (!arg->getType()->is(inputs[i++])) if (inputs[i] && !arg->getType()->is(inputs[i]))
return false; return false;
++i;
} }
if (output && !value->getType()->is(output)) if (output && !value->getType()->is(output))
return false; return false;
if (method && if (method) {
(inputs.empty() || !fn->getParentType() || !fn->getParentType()->is(inputs[0]))) if (inputs.empty() || !fn->getParentType())
return false; return false;
if (inputs[0] && !fn->getParentType()->is(inputs[0]))
return false;
}
return true; return true;
} }

View File

@ -263,21 +263,21 @@ ir::types::Type *JIT::PythonData::getCObjType(ir::Module *M) {
return cobj; return cobj;
} }
JITResult JIT::executeSafe(const std::string &code, const std::string &file, int line, JIT::JITResult JIT::executeSafe(const std::string &code, const std::string &file,
bool debug) { int line, bool debug) {
auto result = execute(code, file, line, debug); auto result = execute(code, file, line, debug);
if (auto err = result.takeError()) { if (auto err = result.takeError()) {
auto errorInfo = llvm::toString(std::move(err)); auto errorInfo = llvm::toString(std::move(err));
return JITResult::error(errorInfo); return JITResult::error(errorInfo);
} }
return JITResult::success(nullptr); return JITResult::success();
} }
JITResult JIT::executePython(const std::string &name, JIT::JITResult JIT::executePython(const std::string &name,
const std::vector<std::string> &types, const std::vector<std::string> &types,
const std::string &pyModule, const std::string &pyModule,
const std::vector<std::string> &pyVars, void *arg, const std::vector<std::string> &pyVars, void *arg,
bool debug) { bool debug) {
auto key = buildKey(name, types); auto key = buildKey(name, types);
auto &cache = pydata->cache; auto &cache = pydata->cache;
auto it = cache.find(key); auto it = cache.find(key);
@ -322,26 +322,48 @@ JITResult JIT::executePython(const std::string &name,
} }
} }
JIT *jitInit(const std::string &name) { } // namespace jit
auto jit = new JIT(name); } // namespace codon
void *jit_init(char *name) {
auto jit = new codon::jit::JIT(std::string(name));
llvm::cantFail(jit->init()); llvm::cantFail(jit->init());
return jit; return jit;
} }
JITResult jitExecutePython(JIT *jit, const std::string &name, void jit_exit(void *jit) { delete ((codon::jit::JIT *)jit); }
const std::vector<std::string> &types,
const std::string &pyModule, CJITResult jit_execute_python(void *jit, char *name, char **types, size_t types_size,
const std::vector<std::string> &pyVars, void *arg, char *pyModule, char **py_vars, size_t py_vars_size,
bool debug) { void *arg, uint8_t debug) {
return jit->executePython(name, types, pyModule, pyVars, arg, debug); std::vector<std::string> cppTypes;
cppTypes.reserve(types_size);
for (size_t i = 0; i < types_size; i++)
cppTypes.emplace_back(types[i]);
std::vector<std::string> cppPyVars;
cppPyVars.reserve(py_vars_size);
for (size_t i = 0; i < py_vars_size; i++)
cppPyVars.emplace_back(py_vars[i]);
auto t = ((codon::jit::JIT *)jit)
->executePython(std::string(name), cppTypes, std::string(pyModule),
cppPyVars, arg, bool(debug));
void *result = t.result;
char *message =
t.message.empty() ? nullptr : strndup(t.message.c_str(), t.message.size());
return {result, message};
} }
JITResult jitExecuteSafe(JIT *jit, const std::string &code, const std::string &file, CJITResult jit_execute_safe(void *jit, char *code, char *file, int32_t line,
int line, bool debug) { uint8_t debug) {
return jit->executeSafe(code, file, line, debug); auto t = ((codon::jit::JIT *)jit)
->executeSafe(std::string(code), std::string(file), line, bool(debug));
void *result = t.result;
char *message =
t.message.empty() ? nullptr : strndup(t.message.c_str(), t.message.size());
return {result, message};
} }
std::string getJITLibrary() { return ast::library_path(); } char *get_jit_library() {
auto t = codon::ast::library_path();
} // namespace jit return strndup(t.c_str(), t.size());
} // namespace codon }

View File

@ -31,6 +31,15 @@ public:
ir::types::Type *getCObjType(ir::Module *M); ir::types::Type *getCObjType(ir::Module *M);
}; };
struct JITResult {
void *result;
std::string message;
operator bool() const { return message.empty(); }
static JITResult success(void *result = nullptr) { return {result, ""}; }
static JITResult error(const std::string &message) { return {nullptr, message}; }
};
private: private:
std::unique_ptr<Compiler> compiler; std::unique_ptr<Compiler> compiler;
std::unique_ptr<Engine> engine; std::unique_ptr<Engine> engine;

View File

@ -2,35 +2,30 @@
#pragma once #pragma once
#include <string> #include <stddef.h>
#include <vector> #include <stdint.h>
namespace codon { #ifdef __cplusplus
namespace jit { extern "C" {
#endif
class JIT; struct CJITResult {
struct JITResult {
void *result; void *result;
std::string message; char *error;
operator bool() const { return message.empty(); }
static JITResult success(void *result) { return {result, ""}; }
static JITResult error(const std::string &message) { return {nullptr, message}; }
}; };
JIT *jitInit(const std::string &name); void *jit_init(char *name);
void jit_exit(void *jit);
JITResult jitExecutePython(JIT *jit, const std::string &name, struct CJITResult jit_execute_python(void *jit, char *name, char **types,
const std::vector<std::string> &types, size_t types_size, char *pyModule, char **py_vars,
const std::string &pyModule, size_t py_vars_size, void *arg, uint8_t debug);
const std::vector<std::string> &pyVars, void *arg,
bool debug);
JITResult jitExecuteSafe(JIT *jit, const std::string &code, const std::string &file, struct CJITResult jit_execute_safe(void *jit, char *code, char *file, int32_t line,
int line, bool debug); uint8_t debug);
std::string getJITLibrary(); char *get_jit_library();
} // namespace jit #ifdef __cplusplus
} // namespace codon }
#endif

View File

@ -124,7 +124,7 @@ print(total)
``` ```
Note that Codon automatically turns the `total += 1` statement in the loop body into an atomic Note that Codon automatically turns the `total += 1` statement in the loop body into an atomic
reduction to avoid race conditions. Learn more in the [multitheading docs](advanced/parallel.md). reduction to avoid race conditions. Learn more in the [multithreading docs](advanced/parallel.md).
Codon also supports writing and executing GPU kernels. Here's an example that computes the Codon also supports writing and executing GPU kernels. Here's an example that computes the
[Mandelbrot set](https://en.wikipedia.org/wiki/Mandelbrot_set): [Mandelbrot set](https://en.wikipedia.org/wiki/Mandelbrot_set):

View File

@ -1,5 +1,7 @@
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io> # Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
__all__ = ["jit", "convert", "JITError"] __all__ = [
"jit", "convert", "JITError", "JITWrapper", "_jit_register_fn", "_jit"
]
from .decorator import jit, convert, execute, JITError from .decorator import jit, convert, execute, JITError, JITWrapper, _jit_register_fn, _jit_callback_fn, _jit

View File

@ -23,16 +23,14 @@ if "CODON_PATH" not in os.environ:
if codon_lib_path: if codon_lib_path:
codon_path.append(Path(codon_lib_path).parent / "stdlib") codon_path.append(Path(codon_lib_path).parent / "stdlib")
codon_path.append( codon_path.append(
Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib" Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib")
)
for path in codon_path: for path in codon_path:
if path.exists(): if path.exists():
os.environ["CODON_PATH"] = str(path.resolve()) os.environ["CODON_PATH"] = str(path.resolve())
break break
else: else:
raise RuntimeError( raise RuntimeError(
"Cannot locate Codon. Please install Codon or set CODON_PATH." "Cannot locate Codon. Please install Codon or set CODON_PATH.")
)
pod_conversions = { pod_conversions = {
type(None): "pyobj", type(None): "pyobj",
@ -61,7 +59,6 @@ pod_conversions = {
custom_conversions = {} custom_conversions = {}
_error_msgs = set() _error_msgs = set()
def _common_type(t, debug, sample_size): def _common_type(t, debug, sample_size):
sub, is_optional = None, False sub, is_optional = None, False
for i in itertools.islice(t, sample_size): for i in itertools.islice(t, sample_size):
@ -76,7 +73,6 @@ def _common_type(t, debug, sample_size):
sub = "Optional[{}]".format(sub) sub = "Optional[{}]".format(sub)
return sub if sub else "pyobj" return sub if sub else "pyobj"
def _codon_type(arg, **kwargs): def _codon_type(arg, **kwargs):
t = type(arg) t = type(arg)
@ -88,11 +84,11 @@ def _codon_type(arg, **kwargs):
if issubclass(t, set): if issubclass(t, set):
return "Set[{}]".format(_common_type(arg, **kwargs)) return "Set[{}]".format(_common_type(arg, **kwargs))
if issubclass(t, dict): if issubclass(t, dict):
return "Dict[{},{}]".format( return "Dict[{},{}]".format(_common_type(arg.keys(), **kwargs),
_common_type(arg.keys(), **kwargs), _common_type(arg.values(), **kwargs) _common_type(arg.values(), **kwargs))
)
if issubclass(t, tuple): if issubclass(t, tuple):
return "Tuple[{}]".format(",".join(_codon_type(a, **kwargs) for a in arg)) return "Tuple[{}]".format(",".join(
_codon_type(a, **kwargs) for a in arg))
if issubclass(t, np.ndarray): if issubclass(t, np.ndarray):
if arg.dtype == np.bool_: if arg.dtype == np.bool_:
dtype = "bool" dtype = "bool"
@ -134,7 +130,8 @@ def _codon_type(arg, **kwargs):
s = custom_conversions.get(t, "") s = custom_conversions.get(t, "")
if s: if s:
j = ",".join(_codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__) j = ",".join(
_codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__)
return "{}[{}]".format(s, j) return "{}[{}]".format(s, j)
debug = kwargs.get("debug", None) debug = kwargs.get("debug", None)
@ -145,28 +142,22 @@ def _codon_type(arg, **kwargs):
_error_msgs.add(msg) _error_msgs.add(msg)
return "pyobj" return "pyobj"
def _codon_types(args, **kwargs): def _codon_types(args, **kwargs):
return tuple(_codon_type(arg, **kwargs) for arg in args) return tuple(_codon_type(arg, **kwargs) for arg in args)
def _reset_jit(): def _reset_jit():
global _jit global _jit
_jit = JITWrapper() _jit = JITWrapper()
init_code = ( init_code = ("from internal.python import "
"from internal.python import " "setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n"
"setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n" "setup_decorator()\n"
"setup_decorator()\n" "import numpy as np\n"
"import numpy as np\n" "import numpy.pybridge\n")
"import numpy.pybridge\n"
)
_jit.execute(init_code, "", 0, False) _jit.execute(init_code, "", 0, False)
return _jit return _jit
_jit = _reset_jit() _jit = _reset_jit()
class RewriteFunctionArgs(ast.NodeTransformer): class RewriteFunctionArgs(ast.NodeTransformer):
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
@ -176,7 +167,6 @@ class RewriteFunctionArgs(ast.NodeTransformer):
node.args.args.append(ast.arg(arg=a, annotation=None)) node.args.args.append(ast.arg(arg=a, annotation=None))
return node return node
def _obj_to_str(obj, **kwargs) -> str: def _obj_to_str(obj, **kwargs) -> str:
if inspect.isclass(obj): if inspect.isclass(obj):
lines = inspect.getsourcelines(obj)[0] lines = inspect.getsourcelines(obj)[0]
@ -185,8 +175,10 @@ def _obj_to_str(obj, **kwargs) -> str:
obj_name = obj.__name__ obj_name = obj.__name__
elif callable(obj) or isinstance(obj, str): elif callable(obj) or isinstance(obj, str):
is_str = isinstance(obj, str) is_str = isinstance(obj, str)
lines = [i + '\n' for i in obj.split('\n')] if is_str else inspect.getsourcelines(obj)[0] lines = [i + '\n' for i in obj.split('\n')
if not is_str: lines = lines[1:] ] if is_str else inspect.getsourcelines(obj)[0]
if not is_str:
lines = lines[1:]
obj_str = textwrap.dedent(''.join(lines)) obj_str = textwrap.dedent(''.join(lines))
pyvars = kwargs.get("pyvars", None) pyvars = kwargs.get("pyvars", None)
@ -195,8 +187,7 @@ def _obj_to_str(obj, **kwargs) -> str:
if not isinstance(i, str): if not isinstance(i, str):
raise ValueError("pyvars only takes string literals") raise ValueError("pyvars only takes string literals")
node = ast.fix_missing_locations( node = ast.fix_missing_locations(
RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str)) RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str)))
)
obj_str = astunparse.unparse(node) obj_str = astunparse.unparse(node)
if is_str: if is_str:
try: try:
@ -206,28 +197,23 @@ def _obj_to_str(obj, **kwargs) -> str:
else: else:
obj_name = obj.__name__ obj_name = obj.__name__
else: else:
raise TypeError("Function or class expected, got " + type(obj).__name__) raise TypeError("Function or class expected, got " +
type(obj).__name__)
return obj_name, obj_str.replace("_@par", "@par") return obj_name, obj_str.replace("_@par", "@par")
def _parse_decorated(obj, **kwargs): def _parse_decorated(obj, **kwargs):
return _obj_to_str(obj, **kwargs) return _obj_to_str(obj, **kwargs)
def convert(t): def convert(t):
if not hasattr(t, "__slots__"): if not hasattr(t, "__slots__"):
raise JITError("class '{}' does not have '__slots__' attribute".format(str(t))) raise JITError("class '{}' does not have '__slots__' attribute".format(
str(t)))
name = t.__name__ name = t.__name__
slots = t.__slots__ slots = t.__slots__
code = ( code = ("@tuple\n"
"@tuple\n" "class " + name + "[" +
"class " ",".join("T{}".format(i) for i in range(len(slots))) + "]:\n")
+ name
+ "["
+ ",".join("T{}".format(i) for i in range(len(slots)))
+ "]:\n"
)
for i, slot in enumerate(slots): for i, slot in enumerate(slots):
code += " {}: T{}\n".format(slot, i) code += " {}: T{}\n".format(slot, i)
@ -235,17 +221,14 @@ def convert(t):
code += " def __from_py__(p: cobj):\n" code += " def __from_py__(p: cobj):\n"
for i, slot in enumerate(slots): for i, slot in enumerate(slots):
code += " a{} = T{}.__from_py__(PyObject_GetAttrString(p, '{}'.ptr))\n".format( code += " a{} = T{}.__from_py__(PyObject_GetAttrString(p, '{}'.ptr))\n".format(
i, i, slot i, i, slot)
)
code += " return {}({})\n".format( code += " return {}({})\n".format(
name, ", ".join("a{}".format(i) for i in range(len(slots))) name, ", ".join("a{}".format(i) for i in range(len(slots))))
)
_jit.execute(code, "", 0, False) _jit.execute(code, "", 0, False)
custom_conversions[t] = name custom_conversions[t] = name
return t return t
def _jit_register_fn(f, pyvars, debug): def _jit_register_fn(f, pyvars, debug):
try: try:
obj_name, obj_str = _parse_decorated(f, pyvars=pyvars) obj_name, obj_str = _parse_decorated(f, pyvars=pyvars)
@ -258,29 +241,46 @@ def _jit_register_fn(f, pyvars, debug):
_reset_jit() _reset_jit()
raise raise
def _jit_callback_fn(obj_name, module, debug=None, sample_size=5, pyvars=None, *args, **kwargs): def _jit_callback_fn(fn,
try: obj_name,
module,
debug=None,
sample_size=5,
pyvars=None,
*args,
**kwargs):
if fn is not None:
sig = inspect.signature(fn)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
args = tuple(bound_args.arguments[param] for param in sig.parameters)
else:
args = (*args, *kwargs.values()) args = (*args, *kwargs.values())
try:
types = _codon_types(args, debug=debug, sample_size=sample_size) types = _codon_types(args, debug=debug, sample_size=sample_size)
if debug: if debug:
print("[python] {}({})".format(obj_name, list(types)), file=sys.stderr) print("[python] {}({})".format(obj_name, list(types)),
return _jit.run_wrapper( file=sys.stderr)
obj_name, list(types), module, list(pyvars), args, 1 if debug else 0 return _jit.run_wrapper(obj_name, list(types), module, list(pyvars),
) args, 1 if debug else 0)
except JITError: except JITError:
_reset_jit() _reset_jit()
raise raise
def _jit_str_fn(fstr, debug=None, sample_size=5, pyvars=None): def _jit_str_fn(fstr, debug=None, sample_size=5, pyvars=None):
obj_name = _jit_register_fn(fstr, pyvars, debug) obj_name = _jit_register_fn(fstr, pyvars, debug)
def wrapped(*args, **kwargs):
return _jit_callback_fn(obj_name, "__main__", debug, sample_size, pyvars, *args, **kwargs)
return wrapped
def wrapped(*args, **kwargs):
return _jit_callback_fn(None, obj_name, "__main__", debug, sample_size,
pyvars, *args, **kwargs)
return wrapped
def jit(fn=None, debug=None, sample_size=5, pyvars=None): def jit(fn=None, debug=None, sample_size=5, pyvars=None):
if not pyvars: if not pyvars:
pyvars = [] pyvars = []
if not isinstance(pyvars, list): if not isinstance(pyvars, list):
raise ArgumentError("pyvars must be a list") raise ArgumentError("pyvars must be a list")
@ -289,12 +289,15 @@ def jit(fn=None, debug=None, sample_size=5, pyvars=None):
def _decorate(f): def _decorate(f):
obj_name = _jit_register_fn(f, pyvars, debug) obj_name = _jit_register_fn(f, pyvars, debug)
@functools.wraps(f) @functools.wraps(f)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
return _jit_callback_fn(obj_name, f.__module__, debug, sample_size, pyvars, *args, **kwargs) return _jit_callback_fn(f, obj_name, f.__module__, debug,
return wrapped sample_size, pyvars, *args, **kwargs)
return _decorate(fn) if fn else _decorate
return wrapped
return _decorate(fn) if fn else _decorate
def execute(code, debug=False): def execute(code, debug=False):
try: try:

View File

@ -1,16 +1,22 @@
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io> # Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
from libcpp.string cimport string from libc.stdint cimport int32_t, uint8_t
from libcpp.vector cimport vector
cdef extern from "codon/compiler/jit_extern.h" namespace "codon::jit": cdef extern from "codon/compiler/jit_extern.h":
cdef cppclass JIT cdef struct CJITResult:
cdef cppclass JITResult:
void *result void *result
string message char *error
bint operator bool()
JIT *jitInit(string) void *jit_init(char *name)
JITResult jitExecuteSafe(JIT*, string, string, int, char) void jit_exit(void *jit)
JITResult jitExecutePython(JIT*, string, vector[string], string, vector[string], object, char)
string getJITLibrary() cdef char *get_jit_library()
cdef CJITResult jit_execute_safe(
void *jit, char *code, char *file, int32_t line, uint8_t debug
)
cdef CJITResult jit_execute_python(
void *jit, char *name, char **types, size_t types_size,
char *pyModule, char **py_vars, size_t py_vars_size,
void *arg, uint8_t debug
)

View File

@ -1,45 +1,82 @@
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io> # Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
# distutils: language=c++ # distutils: language=c
# cython: language_level=3 # cython: language_level=3
# cython: c_string_type=unicode # cython: c_string_type=unicode
# cython: c_string_encoding=utf8 # cython: c_string_encoding=utf8
from libcpp.string cimport string
from libcpp.vector cimport vector
cimport codon.jit cimport codon.jit
from libc.stdlib cimport malloc, calloc, free
from libc.string cimport strcpy
from libc.stdint cimport int32_t, uint8_t
class JITError(Exception): class JITError(Exception):
pass pass
cdef str get_free_str(char *s):
cdef bytes py_s
try:
py_s = s
return py_s.decode('utf-8')
finally:
free(s)
cdef class JITWrapper: cdef class JITWrapper:
cdef codon.jit.JIT* jit cdef void* jit
def __cinit__(self): def __cinit__(self):
self.jit = codon.jit.jitInit(b"codon jit") self.jit = codon.jit.jit_init(b"codon jit")
def __dealloc__(self): def __dealloc__(self):
del self.jit codon.jit.jit_exit(self.jit)
def execute(self, code: str, filename: str, fileno: int, debug: char) -> str: def execute(self, code: str, filename: str, fileno: int, debug) -> str:
result = codon.jit.jitExecuteSafe(self.jit, code, filename, fileno, <char>debug) result = codon.jit.jit_execute_safe(
if <bint>result: self.jit, code.encode('utf-8'), filename.encode('utf-8'), fileno, <uint8_t>debug
)
if result.error is NULL:
return None return None
else: else:
raise JITError(result.message) msg = get_free_str(result.error)
raise JITError(msg)
def run_wrapper(self, name: str, types: list[str], module: str, pyvars: list[str], args, debug) -> object:
cdef char** c_types = <char**>calloc(len(types), sizeof(char*))
cdef char** c_pyvars = <char**>calloc(len(pyvars), sizeof(char*))
if not c_types or not c_pyvars:
raise JITError("Cython allocation failed")
try:
for i, s in enumerate(types):
bytes = s.encode('utf-8')
c_types[i] = <char*>malloc(len(bytes) + 1)
strcpy(c_types[i], bytes)
for i, s in enumerate(pyvars):
bytes = s.encode('utf-8')
c_pyvars[i] = <char*>malloc(len(bytes) + 1)
strcpy(c_pyvars[i], bytes)
result = codon.jit.jit_execute_python(
self.jit, name.encode('utf-8'), c_types, len(types),
module.encode('utf-8'), c_pyvars, len(pyvars),
<void *>args, <uint8_t>debug
)
if result.error is NULL:
return <object>result.result
else:
msg = get_free_str(result.error)
raise JITError(msg)
finally:
for i in range(len(types)):
free(c_types[i])
free(c_types)
for i in range(len(pyvars)):
free(c_pyvars[i])
free(c_pyvars)
def run_wrapper(self, name: str, types: list[str], module: str, pyvars: list[str], args, debug: char) -> object:
cdef vector[string] types_vec = types
cdef vector[string] pyvars_vec = pyvars
result = codon.jit.jitExecutePython(
self.jit, name, types_vec, module, pyvars_vec, <object>args, <char>debug
)
if <bint>result:
return <object>result.result
else:
raise JITError(result.message)
def codon_library(): def codon_library():
return codon.jit.getJITLibrary() cdef char* c = codon.jit.get_jit_library()
return get_free_str(c)

View File

@ -67,7 +67,7 @@ jit_extension = Extension(
"codon.codon_jit", "codon.codon_jit",
sources=["codon/jit.pyx"], sources=["codon/jit.pyx"],
libraries=libraries, libraries=libraries,
language="c++", language="c",
extra_compile_args=["-w"], extra_compile_args=["-w"],
extra_link_args=linker_args, extra_link_args=linker_args,
include_dirs=[str(codon_path / "include")], include_dirs=[str(codon_path / "include")],

View File

@ -24,6 +24,7 @@ PyFloat_AsDouble = Function[[cobj], float](cobj())
PyFloat_FromDouble = Function[[float], cobj](cobj()) PyFloat_FromDouble = Function[[float], cobj](cobj())
PyBool_FromLong = Function[[int], cobj](cobj()) PyBool_FromLong = Function[[int], cobj](cobj())
PyBytes_AsString = Function[[cobj], cobj](cobj()) PyBytes_AsString = Function[[cobj], cobj](cobj())
PyBytes_Size = Function[[cobj], int](cobj())
PyList_New = Function[[int], cobj](cobj()) PyList_New = Function[[int], cobj](cobj())
PyList_Size = Function[[cobj], int](cobj()) PyList_Size = Function[[cobj], int](cobj())
PyList_GetItem = Function[[cobj, int], cobj](cobj()) PyList_GetItem = Function[[cobj, int], cobj](cobj())
@ -130,6 +131,7 @@ PyLong_Type = cobj()
PyFloat_Type = cobj() PyFloat_Type = cobj()
PyBool_Type = cobj() PyBool_Type = cobj()
PyUnicode_Type = cobj() PyUnicode_Type = cobj()
PyBytes_Type = cobj()
PyComplex_Type = cobj() PyComplex_Type = cobj()
PyList_Type = cobj() PyList_Type = cobj()
PyDict_Type = cobj() PyDict_Type = cobj()
@ -213,6 +215,7 @@ def init_handles_dlopen(py_handle: cobj):
global PyFloat_FromDouble global PyFloat_FromDouble
global PyBool_FromLong global PyBool_FromLong
global PyBytes_AsString global PyBytes_AsString
global PyBytes_Size
global PyList_New global PyList_New
global PyList_Size global PyList_Size
global PyList_GetItem global PyList_GetItem
@ -303,6 +306,7 @@ def init_handles_dlopen(py_handle: cobj):
global PyFloat_Type global PyFloat_Type
global PyBool_Type global PyBool_Type
global PyUnicode_Type global PyUnicode_Type
global PyBytes_Type
global PyComplex_Type global PyComplex_Type
global PyList_Type global PyList_Type
global PyDict_Type global PyDict_Type
@ -347,6 +351,7 @@ def init_handles_dlopen(py_handle: cobj):
PyFloat_FromDouble = dlsym(py_handle, "PyFloat_FromDouble") PyFloat_FromDouble = dlsym(py_handle, "PyFloat_FromDouble")
PyBool_FromLong = dlsym(py_handle, "PyBool_FromLong") PyBool_FromLong = dlsym(py_handle, "PyBool_FromLong")
PyBytes_AsString = dlsym(py_handle, "PyBytes_AsString") PyBytes_AsString = dlsym(py_handle, "PyBytes_AsString")
PyBytes_Size = dlsym(py_handle, "PyBytes_Size")
PyList_New = dlsym(py_handle, "PyList_New") PyList_New = dlsym(py_handle, "PyList_New")
PyList_Size = dlsym(py_handle, "PyList_Size") PyList_Size = dlsym(py_handle, "PyList_Size")
PyList_GetItem = dlsym(py_handle, "PyList_GetItem") PyList_GetItem = dlsym(py_handle, "PyList_GetItem")
@ -437,6 +442,7 @@ def init_handles_dlopen(py_handle: cobj):
PyFloat_Type = dlsym(py_handle, "PyFloat_Type") PyFloat_Type = dlsym(py_handle, "PyFloat_Type")
PyBool_Type = dlsym(py_handle, "PyBool_Type") PyBool_Type = dlsym(py_handle, "PyBool_Type")
PyUnicode_Type = dlsym(py_handle, "PyUnicode_Type") PyUnicode_Type = dlsym(py_handle, "PyUnicode_Type")
PyBytes_Type = dlsym(py_handle, "PyBytes_Type")
PyComplex_Type = dlsym(py_handle, "PyComplex_Type") PyComplex_Type = dlsym(py_handle, "PyComplex_Type")
PyList_Type = dlsym(py_handle, "PyList_Type") PyList_Type = dlsym(py_handle, "PyList_Type")
PyDict_Type = dlsym(py_handle, "PyDict_Type") PyDict_Type = dlsym(py_handle, "PyDict_Type")
@ -482,6 +488,7 @@ def init_handles_static():
from C import PyFloat_FromDouble(float) -> cobj as _PyFloat_FromDouble from C import PyFloat_FromDouble(float) -> cobj as _PyFloat_FromDouble
from C import PyBool_FromLong(int) -> cobj as _PyBool_FromLong from C import PyBool_FromLong(int) -> cobj as _PyBool_FromLong
from C import PyBytes_AsString(cobj) -> cobj as _PyBytes_AsString from C import PyBytes_AsString(cobj) -> cobj as _PyBytes_AsString
from C import PyBytes_Size(cobj) -> int as _PyBytes_Size
from C import PyList_New(int) -> cobj as _PyList_New from C import PyList_New(int) -> cobj as _PyList_New
from C import PyList_Size(cobj) -> int as _PyList_Size from C import PyList_Size(cobj) -> int as _PyList_Size
from C import PyList_GetItem(cobj, int) -> cobj as _PyList_GetItem from C import PyList_GetItem(cobj, int) -> cobj as _PyList_GetItem
@ -572,6 +579,7 @@ def init_handles_static():
from C import PyFloat_Type: cobj as _PyFloat_Type from C import PyFloat_Type: cobj as _PyFloat_Type
from C import PyBool_Type: cobj as _PyBool_Type from C import PyBool_Type: cobj as _PyBool_Type
from C import PyUnicode_Type: cobj as _PyUnicode_Type from C import PyUnicode_Type: cobj as _PyUnicode_Type
from C import PyBytes_Type: cobj as _PyBytes_Type
from C import PyComplex_Type: cobj as _PyComplex_Type from C import PyComplex_Type: cobj as _PyComplex_Type
from C import PyList_Type: cobj as _PyList_Type from C import PyList_Type: cobj as _PyList_Type
from C import PyDict_Type: cobj as _PyDict_Type from C import PyDict_Type: cobj as _PyDict_Type
@ -616,6 +624,7 @@ def init_handles_static():
global PyFloat_FromDouble global PyFloat_FromDouble
global PyBool_FromLong global PyBool_FromLong
global PyBytes_AsString global PyBytes_AsString
global PyBytes_Size
global PyList_New global PyList_New
global PyList_Size global PyList_Size
global PyList_GetItem global PyList_GetItem
@ -706,6 +715,7 @@ def init_handles_static():
global PyFloat_Type global PyFloat_Type
global PyBool_Type global PyBool_Type
global PyUnicode_Type global PyUnicode_Type
global PyBytes_Type
global PyComplex_Type global PyComplex_Type
global PyList_Type global PyList_Type
global PyDict_Type global PyDict_Type
@ -750,6 +760,7 @@ def init_handles_static():
PyFloat_FromDouble = _PyFloat_FromDouble PyFloat_FromDouble = _PyFloat_FromDouble
PyBool_FromLong = _PyBool_FromLong PyBool_FromLong = _PyBool_FromLong
PyBytes_AsString = _PyBytes_AsString PyBytes_AsString = _PyBytes_AsString
PyBytes_Size = _PyBytes_Size
PyList_New = _PyList_New PyList_New = _PyList_New
PyList_Size = _PyList_Size PyList_Size = _PyList_Size
PyList_GetItem = _PyList_GetItem PyList_GetItem = _PyList_GetItem
@ -840,6 +851,7 @@ def init_handles_static():
PyFloat_Type = __ptr__(_PyFloat_Type).as_byte() PyFloat_Type = __ptr__(_PyFloat_Type).as_byte()
PyBool_Type = __ptr__(_PyBool_Type).as_byte() PyBool_Type = __ptr__(_PyBool_Type).as_byte()
PyUnicode_Type = __ptr__(_PyUnicode_Type).as_byte() PyUnicode_Type = __ptr__(_PyUnicode_Type).as_byte()
PyBytes_Type = __ptr__(_PyBytes_Type).as_byte()
PyComplex_Type = __ptr__(_PyComplex_Type).as_byte() PyComplex_Type = __ptr__(_PyComplex_Type).as_byte()
PyList_Type = __ptr__(_PyList_Type).as_byte() PyList_Type = __ptr__(_PyList_Type).as_byte()
PyDict_Type = __ptr__(_PyDict_Type).as_byte() PyDict_Type = __ptr__(_PyDict_Type).as_byte()
@ -1174,7 +1186,7 @@ class pyobj:
return pyobj.to_str(self.p, errors, empty) return pyobj.to_str(self.p, errors, empty)
def to_str(p: cobj, errors: str, empty: str = "") -> str: def to_str(p: cobj, errors: str, empty: str = "") -> str:
obj = PyUnicode_AsEncodedString(p, "utf-8".c_str(), errors.c_str()) obj = PyUnicode_AsEncodedString(p, "utf-8".ptr, errors.c_str() if errors else "".ptr)
if obj == cobj(): if obj == cobj():
return empty return empty
bts = PyBytes_AsString(obj) bts = PyBytes_AsString(obj)
@ -1292,8 +1304,11 @@ class _PyObject_Struct:
def _conversion_error(name: Static[str]): def _conversion_error(name: Static[str]):
raise PyError("conversion error: Python object did not have type '" + name + "'") raise PyError("conversion error: Python object did not have type '" + name + "'")
def _get_type(o: cobj):
return Ptr[_PyObject_Struct](o)[0].pytype
def _ensure_type(o: cobj, t: cobj, name: Static[str]): def _ensure_type(o: cobj, t: cobj, name: Static[str]):
if Ptr[_PyObject_Struct](o)[0].pytype != t: if _get_type(o) != t:
_conversion_error(name) _conversion_error(name)
@ -1350,7 +1365,14 @@ class str:
return pyobj.exc_wrap(PyUnicode_DecodeFSDefaultAndSize(self.ptr, self.len)) return pyobj.exc_wrap(PyUnicode_DecodeFSDefaultAndSize(self.ptr, self.len))
def __from_py__(s: cobj) -> str: def __from_py__(s: cobj) -> str:
return pyobj.exc_wrap(pyobj.to_str(s, "strict")) if _get_type(s) == PyBytes_Type:
n = PyBytes_Size(s)
p0 = PyBytes_AsString(s)
p1 = cobj(n)
str.memcpy(p1, p0, n)
return str(p1, n)
else:
return pyobj.exc_wrap(pyobj.to_str(s, "strict"))
@extend @extend
class complex: class complex:

View File

@ -4,6 +4,22 @@ from internal.attributes import commutative
from internal.gc import alloc_atomic, free from internal.gc import alloc_atomic, free
from internal.types.complex import complex from internal.types.complex import complex
def _float_int_pow(a: F, b: int, F: type) -> F:
abs_exp = b.__abs__()
result = F(1)
factor = a
while abs_exp:
if abs_exp & 1:
result *= factor
factor *= factor
abs_exp >>= 1
if b < 0:
result = F(1) / result
return result
@extend @extend
class float: class float:
def __new__() -> float: def __new__() -> float:
@ -401,6 +417,50 @@ class float:
def imag(self) -> float: def imag(self) -> float:
return 0.0 return 0.0
@commutative
def __add__(self: float, b: int) -> float:
return self + float(b)
def __sub__(self: float, b: int) -> float:
return self - float(b)
@commutative
def __mul__(self: float, b: int) -> float:
return self * float(b)
def __floordiv__(self, b: int) -> float:
return self // float(b)
def __truediv__(self: float, b: int) -> float:
return self / float(b)
def __mod__(self: float, b: int) -> float:
return self % float(b)
def __divmod__(self, b: int):
return self.__divmod__(float(b))
def __eq__(self: float, b: int) -> bool:
return self == float(b)
def __ne__(self: float, b: int) -> bool:
return self != float(b)
def __lt__(self: float, b: int) -> bool:
return self < float(b)
def __gt__(self: float, b: int) -> bool:
return self > float(b)
def __le__(self: float, b: int) -> bool:
return self <= float(b)
def __ge__(self: float, b: int) -> bool:
return self >= float(b)
def __pow__(self: float, b: int) -> float:
return _float_int_pow(self, b)
@extend @extend
class float32: class float32:
@pure @pure
@ -755,6 +815,50 @@ class float32:
def __match__(self, obj: float32) -> bool: def __match__(self, obj: float32) -> bool:
return self == obj return self == obj
@commutative
def __add__(self: float32, b: int) -> float32:
return self + float32(b)
def __sub__(self: float32, b: int) -> float32:
return self - float32(b)
@commutative
def __mul__(self: float32, b: int) -> float32:
return self * float32(b)
def __floordiv__(self, b: int) -> float32:
return self // float32(b)
def __truediv__(self: float32, b: int) -> float32:
return self / float32(b)
def __mod__(self: float32, b: int) -> float32:
return self % float32(b)
def __divmod__(self, b: int):
return self.__divmod__(float32(b))
def __eq__(self: float32, b: int) -> bool:
return self == float32(b)
def __ne__(self: float32, b: int) -> bool:
return self != float32(b)
def __lt__(self: float32, b: int) -> bool:
return self < float32(b)
def __gt__(self: float32, b: int) -> bool:
return self > float32(b)
def __le__(self: float32, b: int) -> bool:
return self <= float32(b)
def __ge__(self: float32, b: int) -> bool:
return self >= float32(b)
def __pow__(self: float32, b: int) -> float32:
return _float_int_pow(self, b)
@extend @extend
class float16: class float16:
@pure @pure
@ -1055,6 +1159,50 @@ class float16:
def __match__(self, obj: float16) -> bool: def __match__(self, obj: float16) -> bool:
return self == obj return self == obj
@commutative
def __add__(self: float16, b: int) -> float16:
return self + float16(b)
def __sub__(self: float16, b: int) -> float16:
return self - float16(b)
@commutative
def __mul__(self: float16, b: int) -> float16:
return self * float16(b)
def __floordiv__(self, b: int) -> float16:
return self // float16(b)
def __truediv__(self: float16, b: int) -> float16:
return self / float16(b)
def __mod__(self: float16, b: int) -> float16:
return self % float16(b)
def __divmod__(self, b: int):
return self.__divmod__(float16(b))
def __eq__(self: float16, b: int) -> bool:
return self == float16(b)
def __ne__(self: float16, b: int) -> bool:
return self != float16(b)
def __lt__(self: float16, b: int) -> bool:
return self < float16(b)
def __gt__(self: float16, b: int) -> bool:
return self > float16(b)
def __le__(self: float16, b: int) -> bool:
return self <= float16(b)
def __ge__(self: float16, b: int) -> bool:
return self >= float16(b)
def __pow__(self: float16, b: int) -> float16:
return _float_int_pow(self, b)
@extend @extend
class bfloat16: class bfloat16:
@pure @pure
@ -1355,6 +1503,50 @@ class bfloat16:
def __match__(self, obj: bfloat16) -> bool: def __match__(self, obj: bfloat16) -> bool:
return self == obj return self == obj
@commutative
def __add__(self: bfloat16, b: int) -> bfloat16:
return self + bfloat16(b)
def __sub__(self: bfloat16, b: int) -> bfloat16:
return self - bfloat16(b)
@commutative
def __mul__(self: bfloat16, b: int) -> bfloat16:
return self * bfloat16(b)
def __floordiv__(self, b: int) -> bfloat16:
return self // bfloat16(b)
def __truediv__(self: bfloat16, b: int) -> bfloat16:
return self / bfloat16(b)
def __mod__(self: bfloat16, b: int) -> bfloat16:
return self % bfloat16(b)
def __divmod__(self, b: int):
return self.__divmod__(bfloat16(b))
def __eq__(self: bfloat16, b: int) -> bool:
return self == bfloat16(b)
def __ne__(self: bfloat16, b: int) -> bool:
return self != bfloat16(b)
def __lt__(self: bfloat16, b: int) -> bool:
return self < bfloat16(b)
def __gt__(self: bfloat16, b: int) -> bool:
return self > bfloat16(b)
def __le__(self: bfloat16, b: int) -> bool:
return self <= bfloat16(b)
def __ge__(self: bfloat16, b: int) -> bool:
return self >= bfloat16(b)
def __pow__(self: bfloat16, b: int) -> bfloat16:
return _float_int_pow(self, b)
@extend @extend
class float128: class float128:
@pure @pure
@ -1652,6 +1844,50 @@ class float128:
def __match__(self, obj: float128) -> bool: def __match__(self, obj: float128) -> bool:
return self == obj return self == obj
@commutative
def __add__(self: float128, b: int) -> float128:
return self + float128(b)
def __sub__(self: float128, b: int) -> float128:
return self - float128(b)
@commutative
def __mul__(self: float128, b: int) -> float128:
return self * float128(b)
def __floordiv__(self, b: int) -> float128:
return self // float128(b)
def __truediv__(self: float128, b: int) -> float128:
return self / float128(b)
def __mod__(self: float128, b: int) -> float128:
return self % float128(b)
def __divmod__(self, b: int):
return self.__divmod__(float128(b))
def __eq__(self: float128, b: int) -> bool:
return self == float128(b)
def __ne__(self: float128, b: int) -> bool:
return self != float128(b)
def __lt__(self: float128, b: int) -> bool:
return self < float128(b)
def __gt__(self: float128, b: int) -> bool:
return self > float128(b)
def __le__(self: float128, b: int) -> bool:
return self <= float128(b)
def __ge__(self: float128, b: int) -> bool:
return self >= float128(b)
def __pow__(self: float128, b: int) -> float128:
return _float_int_pow(self, b)
@extend @extend
class float: class float:
def __suffix_f32__(double) -> float32: def __suffix_f32__(double) -> float32:
@ -1666,6 +1902,184 @@ class float:
def __suffix_f128__(double) -> float128: def __suffix_f128__(double) -> float128:
return float128.__new__(double) return float128.__new__(double)
@extend
class int:
@commutative
def __add__(self, b: float32) -> float32:
return float32(self) + b
def __sub__(self, b: float32) -> float32:
return float32(self) - b
@commutative
def __mul__(self, b: float32) -> float32:
return float32(self) * b
def __floordiv__(self, b: float32) -> float32:
return float32(self) // b
def __truediv__(self, b: float32) -> float32:
return float32(self) / b
def __mod__(self, b: float32) -> float32:
return float32(self) % b
def __divmod__(self, b: float32):
return float32(self).__divmod__(b)
def __pow__(self, b: float32) -> float32:
return float32(self) ** b
def __eq__(self, b: float32) -> bool:
return float32(self) == b
def __ne__(self, b: float32) -> bool:
return float32(self) != b
def __lt__(self, b: float32) -> bool:
return float32(self) < b
def __gt__(self, b: float32) -> bool:
return float32(self) > b
def __le__(self, b: float32) -> bool:
return float32(self) <= b
def __ge__(self, b: float32) -> bool:
return float32(self) >= b
@commutative
def __add__(self, b: float16) -> float16:
return float16(self) + b
def __sub__(self, b: float16) -> float16:
return float16(self) - b
@commutative
def __mul__(self, b: float16) -> float16:
return float16(self) * b
def __floordiv__(self, b: float16) -> float16:
return float16(self) // b
def __truediv__(self, b: float16) -> float16:
return float16(self) / b
def __mod__(self, b: float16) -> float16:
return float16(self) % b
def __divmod__(self, b: float16):
return float16(self).__divmod__(b)
def __pow__(self, b: float16) -> float16:
return float16(self) ** b
def __eq__(self, b: float16) -> bool:
return float16(self) == b
def __ne__(self, b: float16) -> bool:
return float16(self) != b
def __lt__(self, b: float16) -> bool:
return float16(self) < b
def __gt__(self, b: float16) -> bool:
return float16(self) > b
def __le__(self, b: float16) -> bool:
return float16(self) <= b
def __ge__(self, b: float16) -> bool:
return float16(self) >= b
@commutative
def __add__(self, b: bfloat16) -> bfloat16:
return bfloat16(self) + b
def __sub__(self, b: bfloat16) -> bfloat16:
return bfloat16(self) - b
@commutative
def __mul__(self, b: bfloat16) -> bfloat16:
return bfloat16(self) * b
def __floordiv__(self, b: bfloat16) -> bfloat16:
return bfloat16(self) // b
def __truediv__(self, b: bfloat16) -> bfloat16:
return bfloat16(self) / b
def __mod__(self, b: bfloat16) -> bfloat16:
return bfloat16(self) % b
def __divmod__(self, b: bfloat16):
return bfloat16(self).__divmod__(b)
def __pow__(self, b: bfloat16) -> bfloat16:
return bfloat16(self) ** b
def __eq__(self, b: bfloat16) -> bool:
return bfloat16(self) == b
def __ne__(self, b: bfloat16) -> bool:
return bfloat16(self) != b
def __lt__(self, b: bfloat16) -> bool:
return bfloat16(self) < b
def __gt__(self, b: bfloat16) -> bool:
return bfloat16(self) > b
def __le__(self, b: bfloat16) -> bool:
return bfloat16(self) <= b
def __ge__(self, b: bfloat16) -> bool:
return bfloat16(self) >= b
@commutative
def __add__(self, b: float128) -> float128:
return float128(self) + b
def __sub__(self, b: float128) -> float128:
return float128(self) - b
@commutative
def __mul__(self, b: float128) -> float128:
return float128(self) * b
def __floordiv__(self, b: float128) -> float128:
return float128(self) // b
def __truediv__(self, b: float128) -> float128:
return float128(self) / b
def __mod__(self, b: float128) -> float128:
return float128(self) % b
def __divmod__(self, b: float128):
return float128(self).__divmod__(b)
def __pow__(self, b: float128) -> float128:
return float128(self) ** b
def __eq__(self, b: float128) -> bool:
return float128(self) == b
def __ne__(self, b: float128) -> bool:
return float128(self) != b
def __lt__(self, b: float128) -> bool:
return float128(self) < b
def __gt__(self, b: float128) -> bool:
return float128(self) > b
def __le__(self, b: float128) -> bool:
return float128(self) <= b
def __ge__(self, b: float128) -> bool:
return float128(self) >= b
f16 = float16 f16 = float16
bf16 = bfloat16 bf16 = bfloat16
f32 = float32 f32 = float32

View File

@ -57,8 +57,8 @@ class FloatingFormat:
self.__init__() self.__init__()
return return
min_val = None min_val: Optional[a.dtype] = None
max_val = None max_val: Optional[a.dtype] = None
finite = 0 finite = 0
abs_non_zero = 0 abs_non_zero = 0
exp_format = False exp_format = False

View File

@ -245,7 +245,7 @@ class ndarray:
k = 0 k = 0
for idx in util.multirange(shape): for idx in util.multirange(shape):
off = 0 off = 0
for i in range(ndim): for i in staticrange(ndim):
off += idx[i] * strides[i] off += idx[i] * strides[i]
e = Ptr[cobj](arr_data + off)[0] e = Ptr[cobj](arr_data + off)[0]
if hasattr(dtype, "__from_py__"): if hasattr(dtype, "__from_py__"):
@ -263,7 +263,7 @@ class ndarray:
k = 0 k = 0
for idx in util.multirange(shape): for idx in util.multirange(shape):
off = 0 off = 0
for i in range(ndim): for i in staticrange(ndim):
off += idx[i] * strides[i] off += idx[i] * strides[i]
e = Ptr[dtype](arr_data + off)[0] e = Ptr[dtype](arr_data + off)[0]
data[k] = e data[k] = e

View File

@ -2661,7 +2661,7 @@ def pad(array, pad_width, mode = 'constant', **kwargs):
if dtype is int or isinstance(dtype, Int) or isinstance(dtype, UInt): if dtype is int or isinstance(dtype, Int) or isinstance(dtype, UInt):
return util.cast(util.rint(x), dtype) return util.cast(util.rint(x), dtype)
else: else:
return x return util.cast(x, dtype)
def pad_from_function(a: ndarray, pw, padding_func, kwargs, extra = None): def pad_from_function(a: ndarray, pw, padding_func, kwargs, extra = None):
shape = a.shape shape = a.shape
@ -2915,7 +2915,7 @@ def pad(array, pad_width, mode = 'constant', **kwargs):
fill_linear(vector, fill_linear(vector,
offset=(n - p2), offset=(n - p2),
start=util.cast(end2, float), start=util.cast(end2, float),
stop=util.cast(start1, float), stop=util.cast(start2, float),
num=p2, num=p2,
rev=True) rev=True)

View File

@ -270,84 +270,174 @@ def corrcoef(x, y=None, rowvar=True, dtype: type = NoneType):
return c return c
def correlate(a, b, mode: Static[str] = 'valid'): def _correlate(a, b, mode: str):
def kernel(d, dstride: int, nd: int, dtype: type,
k, kstride: int, nk: Static[int], ktype: type,
out, ostride: int):
for i in range(nd):
acc = util.zero(dtype)
for j in staticrange(nk):
acc += d[(i + j) * dstride] * k[j * kstride]
out[i * ostride] = acc
def small_correlate(d, dstride: int, nd: int, dtype: type,
k, kstride: int, nk: int, ktype: type,
out, ostride: int):
if dtype is not ktype:
return False
dstride //= util.sizeof(dtype)
kstride //= util.sizeof(dtype)
ostride //= util.sizeof(dtype)
if nk == 1:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=1, ktype=ktype,
out=out, ostride=ostride)
elif nk == 2:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=2, ktype=ktype,
out=out, ostride=ostride)
elif nk == 3:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=3, ktype=ktype,
out=out, ostride=ostride)
elif nk == 4:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=4, ktype=ktype,
out=out, ostride=ostride)
elif nk == 5:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=5, ktype=ktype,
out=out, ostride=ostride)
elif nk == 6:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=6, ktype=ktype,
out=out, ostride=ostride)
elif nk == 7:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=7, ktype=ktype,
out=out, ostride=ostride)
elif nk == 8:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=8, ktype=ktype,
out=out, ostride=ostride)
elif nk == 9:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=9, ktype=ktype,
out=out, ostride=ostride)
elif nk == 10:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=10, ktype=ktype,
out=out, ostride=ostride)
elif nk == 11:
kernel(d=d, dstride=dstride, nd=nd, dtype=dtype,
k=k, kstride=kstride, nk=11, ktype=ktype,
out=out, ostride=ostride)
else:
return False
return True
def dot(_ip1: Ptr[T1], is1: int, _ip2: Ptr[T2], is2: int, op: Ptr[T3], n: int,
T1: type, T2: type, T3: type):
ip1 = _ip1.as_byte()
ip2 = _ip2.as_byte()
ans = util.zero(T3)
for i in range(n):
e1 = Ptr[T1](ip1)[0]
e2 = Ptr[T2](ip2)[0]
ans += util.cast(e1, T3) * util.cast(e2, T3)
ip1 += is1
ip2 += is2
op[0] = ans
def incr(p: Ptr[T], s: int, T: type):
return Ptr[T](p.as_byte() + s)
n1 = a.size
n2 = b.size
length = n1
n = n2
if mode == 'valid':
length = length = length - n + 1
n_left = 0
n_right = 0
elif mode == 'same':
n_left = n >> 1
n_right = n - n_left - 1
elif mode == 'full':
n_right = n - 1
n_left = n - 1
length = length + n - 1
else:
raise ValueError(
f"mode must be one of 'valid', 'same', or 'full' (got {repr(mode)})"
)
dt = type(util.coerce(a.dtype, b.dtype))
ret = empty(length, dtype=dt)
is1 = a.strides[0]
is2 = b.strides[0]
op = ret.data
os = ret.itemsize
ip1 = a.data
ip2 = Ptr[b.dtype](b.data.as_byte() + n_left * is2)
n = n - n_left
for i in range(n_left):
dot(ip1, is1, ip2, is2, op, n)
n += 1
ip2 = incr(ip2, -is2)
op = incr(op, os)
if small_correlate(ip1, is1, n1 - n2 + 1, a.dtype,
ip2, is2, n, b.dtype,
op, os):
ip1 = incr(ip1, is1 * (n1 - n2 + 1))
op = incr(op, os * (n1 - n2 + 1))
else:
for i in range(n1 - n2 + 1):
dot(ip1, is1, ip2, is2, op, n)
ip1 = incr(ip1, is1)
op = incr(op, os)
for i in range(n_right):
n -= 1
dot(ip1, is1, ip2, is2, op, n)
ip1 = incr(ip1, is1)
op = incr(op, os)
return ret
def correlate(a, b, mode: str = 'valid'):
a = asarray(a) a = asarray(a)
b = asarray(b) b = asarray(b)
if a.ndim != 1 or b.ndim != 1: if a.ndim != 1 or b.ndim != 1:
compile_error('object too deep for desired array') compile_error('object too deep for desired array')
n1 = len(a) n1 = a.size
n2 = len(b) n2 = b.size
if n1 == 0:
raise ValueError("first argument cannot be empty")
if n2 == 0:
raise ValueError("second argument cannot be empty")
if b.dtype is complex or b.dtype is complex64:
b = b.conjugate()
if n1 < n2: if n1 < n2:
inverted = 1 return _correlate(b, a, mode=mode)[::-1]
inv = n1
n1 = n2
n2 = inv
correlate(b, a, mode)
else: else:
inverted = 0 return _correlate(a, b, mode=mode)
length = n1
n = n2
if mode == 'valid':
length = length - n + 1
if (a.dtype is complex or b.dtype is complex or a.dtype is complex64
or b.dtype is complex64):
ret = zeros(length, dtype=complex)
else:
ret = empty(length)
for i in range(length):
for j in range(n):
if inverted == 0:
ret.data[i] += a._ptr(
(j + i, ))[0] * conjugate(b._ptr((j, ))[0])
else:
ret.data[i] += a._ptr((j, ))[0] * b._ptr((j + i, ))[0]
elif mode == 'same':
if (a.dtype is complex or b.dtype is complex or a.dtype is complex64
or b.dtype is complex64):
ret = zeros(length, dtype=complex)
else:
ret = empty(length)
for i in range(length):
for j in range(n):
signal_index = i - int(n / 2) + j
if signal_index >= 0 and signal_index < length:
if inverted == 0:
ret.data[i] += a._ptr(
(signal_index, ))[0] * conjugate(b._ptr((j, ))[0])
else:
ret.data[i] += a._ptr((j, ))[0] * b._ptr(
(signal_index, ))[0]
elif mode == 'full':
full_length = length + n - 1
if (a.dtype is complex or b.dtype is complex or a.dtype is complex64
or b.dtype is complex64):
ret = zeros(full_length, dtype=complex)
else:
ret = empty(full_length)
for i in range(full_length):
for j in range(n):
signal_index = i + j - 2
if signal_index >= 0 and signal_index < length:
if inverted == 0:
ret.data[i] += a._ptr(
(signal_index, ))[0] * conjugate(b._ptr((j, ))[0])
else:
ret.data[i] += a._ptr((j, ))[0] * b._ptr(
(signal_index, ))[0]
else:
raise ValueError(
f"mode must be one of 'valid', 'same', or 'full' (got {repr(mode)})"
)
if inverted:
ret = ret[::-1]
if ret.dtype is complex or ret.dtype is complex64:
ret.map(conjugate, inplace=True)
return ret
def bincount(x, weights=None, minlength: int = 0): def bincount(x, weights=None, minlength: int = 0):
x = asarray(x).astype(int) x = asarray(x).astype(int)

View File

@ -199,3 +199,55 @@ def test_float_out_of_range_parse():
assert 1e10000 == float('inf') assert 1e10000 == float('inf')
test_float_out_of_range_parse() test_float_out_of_range_parse()
@test
def test_int_float_ops(F: type):
def check(got, exp=True):
return (exp == got) and (type(exp) is type(got))
# standard
assert check(F(1.5) + 1, F(2.5))
assert check(F(1.5) - 1, F(0.5))
assert check(F(1.5) * 2, F(3.0))
assert check(F(1.5) / 2, F(0.75))
assert check(F(3.5) // 2, F(1.0))
assert check(F(3.5) % 2, F(1.5))
assert check(F(3.5) ** 2, F(12.25))
assert check(divmod(F(3.5), 2), (F(1.0), F(1.5)))
# right-hand ops
assert check(1 + F(1.5), F(2.5))
assert check(1 - F(1.5), F(-0.5))
assert check(2 * F(1.5), F(3.0))
assert check(2 / F(2.5), F(0.8))
assert check(2 // F(1.5), F(1.0))
assert check(2 % F(1.5), F(0.5))
assert check(4 ** F(2.5), F(32.0))
assert check(divmod(4, F(2.5)), (F(1.0), F(1.5)))
# comparisons
assert check(F(1.0) == 1)
assert check(F(2.0) != 1)
assert check(F(0.0) < 1)
assert check(F(2.0) > 1)
assert check(F(0.0) <= 1)
assert check(F(2.0) >= 1)
assert check(1 == F(1.0))
assert check(1 != F(2.0))
assert check(1 < F(2.0))
assert check(1 > F(0.0))
assert check(1 <= F(2.0))
assert check(1 >= F(0.0))
# power
assert check(F(3.5) ** 1, F(3.5))
assert check(F(3.5) ** 2, F(12.25))
assert check(F(3.5) ** 3, F(42.875))
assert check(F(4.0) ** -1, F(0.25))
assert check(F(4.0) ** -2, F(0.0625))
assert check(F(4.0) ** -3, F(0.015625))
assert check(F(3.5) ** 0, F(1.0))
test_int_float_ops(float)
test_int_float_ops(float32)
test_int_float_ops(float16)

View File

@ -1238,7 +1238,7 @@ test_fill_diagonal(np.zeros((3, 5), int),
@test @test
def test_pad(array, pad_width, expected, mode='constant', **kwargs): def test_pad(array, pad_width, expected, mode='constant', **kwargs):
assert (np.pad(array, pad_width, mode, **kwargs) == expected).all() assert np.allclose(np.pad(array, pad_width, mode, **kwargs), expected)
test_pad([1, 2, 3, 4, 5], (2, 3), test_pad([1, 2, 3, 4, 5], (2, 3),
np.array([4, 4, 1, 2, 3, 4, 5, 6, 6, 6]), np.array([4, 4, 1, 2, 3, 4, 5, 6, 6, 6]),
@ -1270,6 +1270,400 @@ test_pad([1, 2, 3, 4, 5], (2, 3),
test_pad([1, 2, 3, 4, 5], (2, 3), np.array([4, 5, 1, 2, 3, 4, 5, 1, 2, 3]), test_pad([1, 2, 3, 4, 5], (2, 3), np.array([4, 5, 1, 2, 3, 4, 5, 1, 2, 3]),
'wrap') 'wrap')
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
np.array([4, 4, 1, 2, 3, 4, 5, 6, 6, 6], dtype=np.float32),
'constant',
constant_values=(4, 6))
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
np.array([1, 1, 1, 2, 3, 4, 5, 5, 5, 5], dtype=np.float32),
'edge')
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
np.array([5, 3, 1, 2, 3, 4, 5, 2, -1, -4], dtype=np.float32),
'linear_ramp',
end_values=(5, -4))
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, ),
np.array([5, 5, 1, 2, 3, 4, 5, 5, 5], dtype=np.float32),
'maximum')
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, ),
np.array([3, 3, 1, 2, 3, 4, 5, 3, 3], dtype=np.float32),
'mean')
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, ),
np.array([3, 3, 1, 2, 3, 4, 5, 3, 3], dtype=np.float32),
'median')
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
np.array([3, 2, 1, 2, 3, 4, 5, 4, 3, 2], dtype=np.float32),
'reflect')
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
np.array([-1, 0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float32),
'reflect',
reflect_type='odd')
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
np.array([2, 1, 1, 2, 3, 4, 5, 5, 4, 3], dtype=np.float32),
'symmetric')
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
np.array([0, 1, 1, 2, 3, 4, 5, 5, 6, 7], dtype=np.float32),
'symmetric',
reflect_type='odd')
test_pad(np.array([1, 2, 3, 4, 5], dtype=np.float32), (2, 3),
np.array([4, 5, 1, 2, 3, 4, 5, 1, 2, 3], dtype=np.float32),
'wrap')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 2., 0., 0., 0.],
[0., 0., 3., 4., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.]], dtype=np.float32),
'constant')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[1., 1., 1., 2., 2., 2., 2.],
[1., 1., 1., 2., 2., 2., 2.],
[1., 1., 1., 2., 2., 2., 2.],
[1., 1., 1., 2., 2., 2., 2.],
[3., 3., 3., 4., 4., 4., 4.],
[3., 3., 3., 4., 4., 4., 4.],
[3., 3., 3., 4., 4., 4., 4.]], dtype=np.float32),
'edge')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[0. , 0. , 0. , 0. , 0. ,
0. , 0. ],
[0. , 0.16666667, 0.33333334, 0.6666667 , 0.44444448,
0.22222224, 0. ],
[0. , 0.33333334, 0.6666667 , 1.3333334 , 0.88888896,
0.44444448, 0. ],
[0. , 0.5 , 1. , 2. , 1.3333334 ,
0.6666667 , 0. ],
[0. , 1.5 , 3. , 4. , 2.6666667 ,
1.3333334 , 0. ],
[0. , 0.75 , 1.5 , 2. , 1.3333334 ,
0.6666667 , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. ]], dtype=np.float32),
'linear_ramp')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[4., 4., 3., 4., 4., 4., 4.],
[4., 4., 3., 4., 4., 4., 4.],
[4., 4., 3., 4., 4., 4., 4.],
[2., 2., 1., 2., 2., 2., 2.],
[4., 4., 3., 4., 4., 4., 4.],
[4., 4., 3., 4., 4., 4., 4.],
[4., 4., 3., 4., 4., 4., 4.]], dtype=np.float32),
'maximum')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[1.5, 1.5, 1. , 2. , 1.5, 1.5, 1.5],
[3.5, 3.5, 3. , 4. , 3.5, 3.5, 3.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5]], dtype=np.float32),
'mean')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[1.5, 1.5, 1. , 2. , 1.5, 1.5, 1.5],
[3.5, 3.5, 3. , 4. , 3.5, 3.5, 3.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5]], dtype=np.float32),
'median')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[1., 1., 1., 2., 1., 1., 1.],
[1., 1., 1., 2., 1., 1., 1.],
[1., 1., 1., 2., 1., 1., 1.],
[1., 1., 1., 2., 1., 1., 1.],
[3., 3., 3., 4., 3., 3., 3.],
[1., 1., 1., 2., 1., 1., 1.],
[1., 1., 1., 2., 1., 1., 1.]], dtype=np.float32),
'minimum')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[3., 4., 3., 4., 3., 4., 3.],
[1., 2., 1., 2., 1., 2., 1.],
[3., 4., 3., 4., 3., 4., 3.],
[1., 2., 1., 2., 1., 2., 1.],
[3., 4., 3., 4., 3., 4., 3.],
[1., 2., 1., 2., 1., 2., 1.],
[3., 4., 3., 4., 3., 4., 3.]], dtype=np.float32),
'reflect')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[4., 3., 3., 4., 4., 3., 3.],
[4., 3., 3., 4., 4., 3., 3.],
[2., 1., 1., 2., 2., 1., 1.],
[2., 1., 1., 2., 2., 1., 1.],
[4., 3., 3., 4., 4., 3., 3.],
[4., 3., 3., 4., 4., 3., 3.],
[2., 1., 1., 2., 2., 1., 1.]], dtype=np.float32),
'symmetric')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[3., 4., 3., 4., 3., 4., 3.],
[1., 2., 1., 2., 1., 2., 1.],
[3., 4., 3., 4., 3., 4., 3.],
[1., 2., 1., 2., 1., 2., 1.],
[3., 4., 3., 4., 3., 4., 3.],
[1., 2., 1., 2., 1., 2., 1.],
[3., 4., 3., 4., 3., 4., 3.]], dtype=np.float32),
'wrap')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 2., 0., 0., 0.],
[0., 0., 3., 4., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.]], dtype=np.float32),
'constant')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[-1., -1., -1., -1., -1., -1., -1.],
[-1., -1., -1., -1., -1., -1., -1.],
[-1., -1., -1., -1., -1., -1., -1.],
[-1., -1., 1., 2., -1., -1., -1.],
[-1., -1., 3., 4., -1., -1., -1.],
[-1., -1., -1., -1., -1., -1., -1.],
[-1., -1., -1., -1., -1., -1., -1.]], dtype=np.float32),
'constant', constant_values=-1)
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[-1., -1., -1., -1., -2., -2., -2.],
[-1., -1., -1., -1., -2., -2., -2.],
[-1., -1., -1., -1., -2., -2., -2.],
[-1., -1., 1., 2., -2., -2., -2.],
[-1., -1., 3., 4., -2., -2., -2.],
[-1., -1., -2., -2., -2., -2., -2.],
[-1., -1., -2., -2., -2., -2., -2.]], dtype=np.float32),
'constant', constant_values=((-1, -2)))
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[-3., -3., -1., -1., -4., -4., -4.],
[-3., -3., -1., -1., -4., -4., -4.],
[-3., -3., -1., -1., -4., -4., -4.],
[-3., -3., 1., 2., -4., -4., -4.],
[-3., -3., 3., 4., -4., -4., -4.],
[-3., -3., -2., -2., -4., -4., -4.],
[-3., -3., -2., -2., -4., -4., -4.]], dtype=np.float32),
'constant', constant_values=((-1, -2), (-3, -4)))
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[1., 1., 1., 2., 2., 2., 2.],
[1., 1., 1., 2., 2., 2., 2.],
[1., 1., 1., 2., 2., 2., 2.],
[1., 1., 1., 2., 2., 2., 2.],
[3., 3., 3., 4., 4., 4., 4.],
[3., 3., 3., 4., 4., 4., 4.],
[3., 3., 3., 4., 4., 4., 4.]], dtype=np.float32),
'edge')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[0. , 0. , 0. , 0. , 0. ,
0. , 0. ],
[0. , 0.16666667, 0.33333334, 0.6666667 , 0.44444448,
0.22222224, 0. ],
[0. , 0.33333334, 0.6666667 , 1.3333334 , 0.88888896,
0.44444448, 0. ],
[0. , 0.5 , 1. , 2. , 1.3333334 ,
0.6666667 , 0. ],
[0. , 1.5 , 3. , 4. , 2.6666667 ,
1.3333334 , 0. ],
[0. , 0.75 , 1.5 , 2. , 1.3333334 ,
0.6666667 , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. ]], dtype=np.float32),
'linear_ramp')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[1. , 1. , 1. , 1. , 1. , 1. ,
1. ],
[1. , 1. , 1. , 1.3333334, 1.2222222, 1.1111112,
1. ],
[1. , 1. , 1. , 1.6666667, 1.4444445, 1.2222222,
1. ],
[1. , 1. , 1. , 2. , 1.6666667, 1.3333334,
1. ],
[1. , 2. , 3. , 4. , 3. , 2. ,
1. ],
[1. , 1.5 , 2. , 2.5 , 2. , 1.5 ,
1. ],
[1. , 1. , 1. , 1. , 1. , 1. ,
1. ]], dtype=np.float32),
'linear_ramp', end_values=1)
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[1. , 1. , 1. , 1. , 1.3333333, 1.6666666,
2. ],
[1. , 1. , 1. , 1.3333334, 1.5555556, 1.7777778,
2. ],
[1. , 1. , 1. , 1.6666667, 1.7777778, 1.888889 ,
2. ],
[1. , 1. , 1. , 2. , 2. , 2. ,
2. ],
[1. , 2. , 3. , 4. , 3.3333335, 2.6666667,
2. ],
[1. , 1.75 , 2.5 , 3. , 2.6666667, 2.3333333,
2. ],
[1. , 1.5 , 2. , 2. , 2. , 2. ,
2. ]], dtype=np.float32),
'linear_ramp', end_values=(1, 2))
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[2. , 1.5 , 1. , 1. , 1. , 1. ,
1. ],
[2. , 1.5 , 1. , 1.3333334, 1.2222222, 1.1111112,
1. ],
[2. , 1.5 , 1. , 1.6666667, 1.4444445, 1.2222222,
1. ],
[2. , 1.5 , 1. , 2. , 1.6666667, 1.3333334,
1. ],
[2. , 2.5 , 3. , 4. , 3. , 2. ,
1. ],
[2. , 2.25 , 2.5 , 3. , 2.3333335, 1.6666667,
1. ],
[2. , 2. , 2. , 2. , 1.6666667, 1.3333334,
1. ]], dtype=np.float32),
'linear_ramp', end_values=((1, 2), (2, 1)))
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[4., 4., 3., 4., 4., 4., 4.],
[4., 4., 3., 4., 4., 4., 4.],
[4., 4., 3., 4., 4., 4., 4.],
[2., 2., 1., 2., 2., 2., 2.],
[4., 4., 3., 4., 4., 4., 4.],
[4., 4., 3., 4., 4., 4., 4.],
[4., 4., 3., 4., 4., 4., 4.]], dtype=np.float32),
'maximum')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[1.5, 1.5, 1. , 2. , 1.5, 1.5, 1.5],
[3.5, 3.5, 3. , 4. , 3.5, 3.5, 3.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5]], dtype=np.float32),
'mean')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[1., 1., 1., 2., 2., 2., 2.],
[1., 1., 1., 2., 2., 2., 2.],
[1., 1., 1., 2., 2., 2., 2.],
[1., 1., 1., 2., 2., 2., 2.],
[3., 3., 3., 4., 4., 4., 4.],
[3., 3., 3., 4., 4., 4., 4.],
[3., 3., 3., 4., 4., 4., 4.]], dtype=np.float32),
'mean', stat_length=1)
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[1. , 1. , 1. , 2. , 1.5, 1.5, 1.5],
[1. , 1. , 1. , 2. , 1.5, 1.5, 1.5],
[1. , 1. , 1. , 2. , 1.5, 1.5, 1.5],
[1. , 1. , 1. , 2. , 1.5, 1.5, 1.5],
[3. , 3. , 3. , 4. , 3.5, 3.5, 3.5],
[2. , 2. , 2. , 3. , 2.5, 2.5, 2.5],
[2. , 2. , 2. , 3. , 2.5, 2.5, 2.5]], dtype=np.float32),
'mean', stat_length=(1, 2))
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[1.5, 1.5, 1. , 2. , 2. , 2. , 2. ],
[1.5, 1.5, 1. , 2. , 2. , 2. , 2. ],
[1.5, 1.5, 1. , 2. , 2. , 2. , 2. ],
[1.5, 1.5, 1. , 2. , 2. , 2. , 2. ],
[3.5, 3.5, 3. , 4. , 4. , 4. , 4. ],
[2.5, 2.5, 2. , 3. , 3. , 3. , 3. ],
[2.5, 2.5, 2. , 3. , 3. , 3. , 3. ]], dtype=np.float32),
'mean', stat_length=((1, 2), (2, 1)))
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[1.5, 1.5, 1. , 2. , 1.5, 1.5, 1.5],
[3.5, 3.5, 3. , 4. , 3.5, 3.5, 3.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5],
[2.5, 2.5, 2. , 3. , 2.5, 2.5, 2.5]], dtype=np.float32),
'median')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[1., 1., 1., 2., 1., 1., 1.],
[1., 1., 1., 2., 1., 1., 1.],
[1., 1., 1., 2., 1., 1., 1.],
[1., 1., 1., 2., 1., 1., 1.],
[3., 3., 3., 4., 3., 3., 3.],
[1., 1., 1., 2., 1., 1., 1.],
[1., 1., 1., 2., 1., 1., 1.]], dtype=np.float32),
'minimum')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[3., 4., 3., 4., 3., 4., 3.],
[1., 2., 1., 2., 1., 2., 1.],
[3., 4., 3., 4., 3., 4., 3.],
[1., 2., 1., 2., 1., 2., 1.],
[3., 4., 3., 4., 3., 4., 3.],
[1., 2., 1., 2., 1., 2., 1.],
[3., 4., 3., 4., 3., 4., 3.]], dtype=np.float32),
'reflect')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[4., 3., 3., 4., 4., 3., 3.],
[4., 3., 3., 4., 4., 3., 3.],
[2., 1., 1., 2., 2., 1., 1.],
[2., 1., 1., 2., 2., 1., 1.],
[4., 3., 3., 4., 4., 3., 3.],
[4., 3., 3., 4., 4., 3., 3.],
[2., 1., 1., 2., 2., 1., 1.]], dtype=np.float32),
'symmetric')
test_pad(np.array([[1, 2], [3, 4]], np.float32), ((3, 2), (2, 3)),
np.array([[3., 4., 3., 4., 3., 4., 3.],
[1., 2., 1., 2., 1., 2., 1.],
[3., 4., 3., 4., 3., 4., 3.],
[1., 2., 1., 2., 1., 2., 1.],
[3., 4., 3., 4., 3., 4., 3.],
[1., 2., 1., 2., 1., 2., 1.],
[3., 4., 3., 4., 3., 4., 3.]], dtype=np.float32),
'wrap')
test_pad([[[[1]]]], 1,
np.array([[[[2, 2, 2],
[2, 2, 2],
[2, 2, 2]],
[[2, 2, 2],
[2, 2, 2],
[2, 2, 2]],
[[2, 2, 2],
[2, 2, 2],
[2, 2, 2]]],
[[[2, 2, 2],
[2, 2, 2],
[2, 2, 2]],
[[2, 2, 2],
[2, 1, 2],
[2, 2, 2]],
[[2, 2, 2],
[2, 2, 2],
[2, 2, 2]]],
[[[2, 2, 2],
[2, 2, 2],
[2, 2, 2]],
[[2, 2, 2],
[2, 2, 2],
[2, 2, 2]],
[[2, 2, 2],
[2, 2, 2],
[2, 2, 2]]]]), 'constant', constant_values=2)
test_pad([[[[1]]]], 0, np.array([[[[1]]]]), 'constant', constant_values=2)
############# #############
# searching # # searching #
############# #############

View File

@ -247,6 +247,67 @@ test_correlate(np.array([1 + 0j, 2 + 0j, 3 + 0j, 4 + 1j]),
np.array([-1 + 0j, -2j, 3 + 1j]), np.array([-1 + 0j, -2j, 3 + 1j]),
np.array([8. + 1.j, 11. + 5.j])) np.array([8. + 1.j, 11. + 5.j]))
@test
def test_correlate2():
# Integer inputs
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
assert np.allclose(np.correlate(a, b, mode="valid"), [32])
assert np.allclose(np.correlate(a, b, mode="same"), [17, 32, 23])
assert np.allclose(np.correlate(a, b, mode="full"), [6, 17, 32, 23, 12])
# Floating-point inputs
a = np.array([1.5, 2.5, 3.5])
b = np.array([4.0, 5.0, 6.0])
assert np.allclose(np.correlate(a, b, mode="valid"), [39.5])
assert np.allclose(np.correlate(a, b, mode="same"), [22.5, 39.5, 27.5])
assert np.allclose(np.correlate(a, b, mode="full"), [9.0, 22.5, 39.5, 27.5, 14.0])
# Complex numbers
a = np.array([1+2j, 3+4j])
b = np.array([5+6j, 7+8j])
assert np.allclose(np.correlate(a, b, mode="valid"), [70+8j])
assert np.allclose(np.correlate(a, b, mode="same"), [23+6j, 70+8j])
assert np.allclose(np.correlate(a, b, mode="full"), [23+6j, 70+8j, 39+2j])
# Different-length arrays
a = np.array([1, 2, 3, 4])
b = np.array([0, 1])
assert np.allclose(np.correlate(a, b, mode="valid"), [2, 3, 4])
assert np.allclose(np.correlate(a, b, mode="same"), [1, 2, 3, 4])
assert np.allclose(np.correlate(a, b, mode="full"), [1, 2, 3, 4, 0])
a = np.array([0, 1])
b = np.array([1, 2, 3, 4])
assert np.allclose(np.correlate(a, b, mode="valid")[::-1], [2, 3, 4])
assert np.allclose(np.correlate(a, b, mode="same")[::-1], [1, 2, 3, 4])
assert np.allclose(np.correlate(a, b, mode="full")[::-1], [1, 2, 3, 4, 0])
# Large array test
a = np.arange(20)
b = np.arange(10)
expected_valid = np.array([np.sum(a[i : i + len(b)] * b) for i in range(len(a) - len(b) + 1)])
expected_full = np.correlate(a, b, mode="full")
expected_same = np.correlate(a, b, mode="same")
assert np.allclose(np.correlate(a, b, mode="valid"), expected_valid)
assert np.allclose(np.correlate(a, b, mode="full"), expected_full)
assert np.allclose(np.correlate(a, b, mode="same"), expected_same)
# Different dtypes (int and float)
a = np.array([1, 2, 3], dtype=int)
b = np.array([1.5, 2.5, 3.5], dtype=float)
assert np.allclose(np.correlate(a, b, mode="valid"), [17.0])
assert np.allclose(np.correlate(a, b, mode="same"), [9.5, 17.0, 10.5])
assert np.allclose(np.correlate(a, b, mode="full"), [3.5, 9.5, 17.0, 10.5, 4.5])
# Edge case: Single-element arrays
a = np.array([5])
b = np.array([10])
assert np.allclose(np.correlate(a, b, mode="valid"), [50])
assert np.allclose(np.correlate(a, b, mode="same"), [50])
assert np.allclose(np.correlate(a, b, mode="full"), [50])
test_correlate2()
@test @test
def test_bincount(x, expected, weights=None, minlength=0): def test_bincount(x, expected, weights=None, minlength=0):
assert (np.bincount(x, weights=weights, assert (np.bincount(x, weights=weights,

View File

@ -181,3 +181,16 @@ def test_ndarray():
assert np.datetime_data(y.dtype) == ('s', 2) assert np.datetime_data(y.dtype) == ('s', 2)
test_ndarray() test_ndarray()
@codon.jit
def e(x=2, y=99):
return 2*x + y
def test_arg_order():
assert e(1, 2) == 4
assert e(1) == 101
assert e(y=10, x=1) == 12
assert e(x=1) == 101
assert e() == 103
test_arg_order()

View File

@ -88,6 +88,7 @@ def test_codon_extensions(m):
assert m.f4(a=2.2) == (2.2, 2.22) assert m.f4(a=2.2) == (2.2, 2.22)
assert m.f4(b=3.3) == (1.11, 3.3) assert m.f4(b=3.3) == (1.11, 3.3)
assert m.f4('foo') == ('foo', 'foo') assert m.f4('foo') == ('foo', 'foo')
assert m.f4(b'foo') == ('foo', 'foo')
assert m.f4({1}) == {1} assert m.f4({1}) == {1}
assert m.f5() is None assert m.f5() is None
assert equal(m.f6(1.9, 't'), 1.9, 1.9, 't') assert equal(m.f6(1.9, 't'), 1.9, 1.9, 't')

View File

@ -450,6 +450,18 @@ def test_omp_reductions():
c = min(b, c) c = min(b, c)
assert c == -1. assert c == -1.
c = 0.
@par
for i in L:
c += i # float-int op
assert c == expected(N, 0., float.__add__)
c = 0.
@par
for i in L:
c = i + c # int-float op
assert c == expected(N, 0., float.__add__)
# float32s # float32s
c = f32(0.) c = f32(0.)
# this one can give different results due to # this one can give different results due to
@ -479,6 +491,18 @@ def test_omp_reductions():
c = min(b, c) c = min(b, c)
assert c == f32(-1.) assert c == f32(-1.)
c = f32(0.)
@par
for i in L[:12]:
c += i # float-int op
assert c == f32(1+2+3+4+5+6+7+8+9+10+11)
c = f32(0.)
@par
for i in L[:12]:
c = i + c # int-float op
assert c == f32(1+2+3+4+5+6+7+8+9+10+11)
x_add = 10. x_add = 10.
x_min = inf x_min = inf
x_max = -inf x_max = -inf