1
0
mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
codon/stdlib/numpy/linalg/linalg.codon
A. R. Shajii b8c1eeed36
2025 updates (#619)
* 2025 updates

* Update ci.yml
2025-01-29 15:41:43 -05:00

4341 lines
122 KiB
Python

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
from .blas import *
from ..ndarray import ndarray
from ..ndmath import divide, greater, isnan, multiply, sqrt
from ..routines import atleast_2d, asarray, broadcast_shapes, broadcast_to, empty, \
empty_like, eye, expand_dims, moveaxis, reshape, swapaxes, \
zeros, array
from ..util import cast, cdiv_int, coerce, eps, exp, free, inf, log, multirange, \
nan, normalize_axis_index, sizeof, sort, sqrt as util_sqrt, \
tuple_delete, tuple_insert, tuple_range, tuple_set, zero
#############
# Utilities #
#############
class LinAlgError(Static[Exception]):
    # Exception raised by the routines in this module (singular matrix,
    # non-convergence, failed LAPACK workspace queries, ...).
    # The first argument handed to the base constructor is the Python-visible
    # name, mirroring NumPy's `numpy.linalg.LinAlgError`.
    def __init__(self, message: str = ''):
        super().__init__("numpy.linalg.LinAlgError", message)
def _square_rows(a):
    # Return the order n of the trailing (n, n) matrices of `a`.
    # A compile-time error is emitted for arrays with fewer than two dims;
    # LinAlgError is raised at run time if the last two dims differ.
    shape = a.shape
    if staticlen(shape) < 2:
        compile_error("Array must be at least 2-dimensional")
    cols = shape[-1]
    rows = shape[-2]
    if rows != cols:
        raise LinAlgError("Last 2 dimensions of the array must be square")
    return cols
def _rows_cols(a):
    # Return (rows, cols) of the trailing matrix dimensions of `a`;
    # arrays with fewer than two dims are rejected at compile time.
    shape = a.shape
    if staticlen(shape) < 2:
        compile_error("Array must be at least 2-dimensional")
    return shape[-2], shape[-1]
def _asarray(a, dtype: type = NoneType):
    # Convert `a` to an ndarray whose dtype LAPACK can work with
    # (float, float32, complex or complex64).
    #  - No dtype given: an ndarray already of one of those dtypes is returned
    #    unchanged (no copy); any other ndarray is cast with astype(float).
    #    Non-ndarray input is converted with asarray, keeping its element type
    #    if it is one of the four, otherwise converting to float.
    #  - dtype given: used if it is one of the four, otherwise float is used.
    # All `is` tests on types here are compile-time, so each branch may return
    # a different array type.
    if dtype is NoneType:
        if isinstance(a, ndarray):
            if (a.dtype is float or a.dtype is float32 or
                a.dtype is complex or a.dtype is complex64):
                return a
            else:
                return a.astype(float)
        else:
            # Element type of the would-be array, used to pick the dtype.
            dtype1 = type(asarray(a).data[0])
            if (dtype1 is float or dtype1 is float32 or
                dtype1 is complex or dtype1 is complex64):
                return asarray(a, dtype=dtype1)
            else:
                return asarray(a, dtype=float)
    elif (dtype is float or dtype is float32 or
          dtype is complex or dtype is complex64):
        return asarray(a, dtype=dtype)
    else:
        return asarray(a, dtype=float)
def _basetype(dtype: type):
    # Return a default-constructed value of the real "base" type of `dtype`:
    # float/float32 map to themselves, complex -> float, complex64 -> float32.
    # Callers use `type(_basetype(T))` to obtain the type itself.
    if dtype is float or dtype is float32:
        return dtype()
    elif dtype is complex:
        return float()
    elif dtype is complex64:
        return float32()
    else:
        compile_error("[internal error] bad dtype")
def _complextype(dtype: type):
    # Return a default-constructed value of the complex type matching `dtype`:
    # float -> complex, float32 -> complex64; complex types map to themselves.
    # Callers use `type(_complextype(T))` to obtain the type itself.
    if dtype is float:
        return complex()
    elif dtype is float32:
        return complex64()
    elif dtype is complex or dtype is complex64:
        return dtype()
    else:
        compile_error("[internal error] bad dtype")
def _copy(n: int, sx: Ptr[T], incx: int, sy: Ptr[T], incy: int, T: type):
    # Copy `n` elements from `sx` (step `incx`) to `sy` (step `incy`) via the
    # CBLAS ?copy routine matching T. Steps are in elements, not bytes
    # (callers divide byte strides by sizeof(T) before calling).
    args = (fint(n), sx, fint(incx), sy, fint(incy))
    if T is float:
        cblas_dcopy(*args)
    elif T is float32:
        cblas_scopy(*args)
    elif T is complex:
        cblas_zcopy(*args)
    elif T is complex64:
        cblas_ccopy(*args)
    else:
        compile_error("[internal error] bad type for BLAS copy: " + T.__name__)
def _nan(T: type):
    # Return a NaN of type T; for complex types both the real and the
    # imaginary parts are NaN.
    if T is float or T is float32:
        return nan(T)
    elif T is complex:
        return complex(nan(float), nan(float))
    elif T is complex64:
        return complex64(nan(float32), nan(float32))
    else:
        compile_error("[internal error] bad dtype")
@tuple
class LinearizeData:
    # Describes how to copy one (possibly strided) matrix into / out of a
    # contiguous scratch buffer for LAPACK.
    #   rows, cols       - logical shape walked by linearize/delinearize
    #   row_strides      - byte step in the strided array between "rows"
    #   col_strides      - byte step in the strided array between "cols"
    #   out_lead_dim     - element step between rows in the contiguous buffer
    #                      (defaults to `cols`)
    # Callers in this file pass the source's last-axis stride as
    # `row_strides` and the second-to-last as `col_strides`, so each buffer
    # "row" holds one source column, i.e. the buffer is column-major as
    # LAPACK expects.
    rows: int
    cols: int
    row_strides: int
    col_strides: int
    out_lead_dim: int

    def __new__(rows: int, cols: int, row_strides: int, col_strides: int,
                out_lead_dim: int) -> LinearizeData:
        return (rows, cols, row_strides, col_strides, out_lead_dim)

    def __new__(rows: int, cols: int, row_strides: int, col_strides: int):
        # Tightly packed buffer: leading dimension equals the column count.
        return LinearizeData(rows, cols, row_strides, col_strides, cols)

    def linearize(self, dst: Ptr[T], src: Ptr[T], T: type):
        # Gather the strided matrix at `src` into the contiguous buffer `dst`.
        # Returns `dst` (or `src` unchanged if `dst` is null).
        if dst:
            rv = dst
            cols = self.cols
            # Convert the byte stride to an element stride.
            col_strides = cdiv_int(self.col_strides, sizeof(T))
            for i in range(self.rows):
                if col_strides > 0:
                    _copy(cols, src, col_strides, dst, 1)
                elif col_strides < 0:
                    # Negative step: pass the lowest address touched, as BLAS
                    # expects for a negative increment.
                    _copy(cols, src + (cols - 1)*col_strides, col_strides, dst, 1);
                else:
                    # Zero stride (broadcast): replicate the single element
                    # by hand instead of calling BLAS with a zero increment.
                    for j in range(cols):
                        dst[j] = src[0]
                src += cdiv_int(self.row_strides, sizeof(T))
                dst += self.out_lead_dim
            return rv
        else:
            return src

    def delinearize(self, dst: Ptr[T], src: Ptr[T], T: type):
        # Scatter the contiguous buffer `src` back into the strided matrix at
        # `dst`. Returns `src` (or `src` as-is if it is null).
        if src:
            rv = src
            cols = self.cols
            col_strides = cdiv_int(self.col_strides, sizeof(T))
            for i in range(self.rows):
                if col_strides > 0:
                    _copy(cols, src, 1, dst, col_strides);
                elif col_strides < 0:
                    _copy(cols, src, 1, dst + (cols - 1)*col_strides, col_strides)
                else:
                    # Zero stride: every column aliases one slot, so only the
                    # last buffer element of the row is stored.
                    if cols > 0:
                        dst[0] = src[cols - 1]
                src += self.out_lead_dim
                dst += cdiv_int(self.row_strides, sizeof(T))
            return rv
        else:
            return src

    def nan_matrix(self, dst: Ptr[T], T: type):
        # Fill the strided matrix at `dst` with NaN (used to mark results of a
        # failed LAPACK call).
        for i in range(self.rows):
            cp = dst
            cs = cdiv_int(self.col_strides, sizeof(T))
            for j in range(self.cols):
                cp[0] = _nan(T)
                cp += cs
            dst += cdiv_int(self.row_strides, sizeof(T))

    def zero_matrix(self, dst: Ptr[T], T: type):
        # Fill the strided matrix at `dst` with T() (zero).
        for i in range(self.rows):
            cp = dst
            cs = cdiv_int(self.col_strides, sizeof(T))
            for j in range(self.cols):
                cp[0] = T()
                cp += cs
            dst += cdiv_int(self.row_strides, sizeof(T))

    def identity_matrix(dst: Ptr[T], n: int, T: type):
        # Write an n x n identity into the contiguous buffer `dst`
        # (no `self`: called as LinearizeData.identity_matrix(...)).
        str.memset(dst.as_byte(), byte(0), n * n * sizeof(T))
        for i in range(n):
            dst[0] = T(1)
            dst += n + 1  # step along the diagonal
###############
# Determinant #
###############
def _slogdet_from_factored_diagonal(src: Ptr[T],
                                    m: int,
                                    sign: T,
                                    T: type):
    # Given the m x m contiguous LU factors in `src` and the permutation sign
    # `sign`, return (sign, logabsdet) of the product of the diagonal.
    # The diagonal is visited by stepping m + 1 elements at a time.
    if T is float or T is float32:
        acc_sign = sign
        acc_logdet = T(0.0)
        for i in range(m):
            abs_element = src[0]
            # Each negative pivot flips the overall sign.
            if abs_element < T(0.0):
                acc_sign = -acc_sign
                abs_element = -abs_element
            acc_logdet += log(abs_element)
            src += m + 1
        return acc_sign, acc_logdet
    elif T is complex or T is complex64:
        B = type(_basetype(T))
        acc_sign = sign
        acc_logdet = B()
        for i in range(m):
            abs_element = abs(src[0])
            # Accumulate the unit-modulus phase of each pivot into the sign.
            sign_element = src[0] / abs_element
            acc_sign *= sign_element
            acc_logdet += log(abs_element)
            src += m + 1
        return acc_sign, acc_logdet
    else:
        compile_error("[internal error] invalid type for _slogdet_from_factored_diagonal: " + T.__name__)
def _slogdet_single_element(m: int,
                            src: Ptr[T],
                            pivots: Ptr[fint],
                            T: type):
    # Compute (sign, logabsdet) of one m x m matrix held column-major in the
    # contiguous buffer `src`, using LAPACK ?getrf. `src` is overwritten in
    # place with the LU factors and `pivots` (length m) with the pivot indices.
    # If ?getrf reports a nonzero info, (0, -inf) is returned.
    m0 = fint(m)
    lda = fint(max(m, 1))  # LAPACK requires LDA >= 1 even for m == 0
    info = fint(0)
    args = (__ptr__(m0), __ptr__(m0), src.as_byte(), __ptr__(lda), pivots, __ptr__(info))
    if T is float:
        dgetrf_(*args)
    elif T is float32:
        sgetrf_(*args)
    elif T is complex:
        zgetrf_(*args)
    elif T is complex64:
        cgetrf_(*args)
    else:
        compile_error("[internal error] bad input type")
    sign = T()
    if info == fint(0):
        # Pivots are 1-based (Fortran); each index that differs from its
        # position is one row interchange, toggling the permutation parity.
        change_sign = False
        for i in range(m):
            if pivots[i] != fint(i + 1):
                change_sign = not change_sign
        sign = T(-1.0 if change_sign else 1.0)
        # Declares logdet's type; overwritten on the next line.
        logdet = zero(_basetype(T))
        sign, logdet = _slogdet_from_factored_diagonal(src, m, sign)
        return sign, logdet
    else:
        return T(0.0), -inf(type(_basetype(T)))
def _det_from_slogdet(sign: T, logdet, T: type):
    # Combine a (sign, logabsdet) pair back into the determinant
    # sign * exp(logabsdet); for complex T the real exponential is first
    # converted to T.
    if T is float or T is float32:
        return sign * exp(logdet)
    elif T is complex or T is complex64:
        return sign * T(exp(logdet))
    else:
        compile_error("[internal error] invalid type for _det_from_slogdet: " + T.__name__)
@tuple
class SlogdetResult[A, B]:
    # Named result of slogdet: `sign` and `logabsdet`, also indexable as a
    # 2-tuple with a compile-time index (0/-2 -> sign, 1/-1 -> logabsdet).
    sign: A
    logabsdet: B

    def __getitem__(self, idx: Static[int]):
        if idx == 0 or idx == -2:
            return self.sign
        elif idx == 1 or idx == -1:
            return self.logabsdet
        else:
            compile_error("tuple ('SlogdetResult') index out of range")
def slogdet(a):
    # Sign and natural log of |det| for each square matrix in `a`.
    # For a single 2-D matrix the result holds two scalars; for stacked input
    # it holds two arrays of shape a.shape[:-2].
    a = _asarray(a)
    T = a.dtype
    m = _square_rows(a)
    strides = a.strides
    # Column-major scratch copy of one matrix plus its pivot vector; both are
    # reused for every matrix in the stack.
    layout = LinearizeData(m, m, strides[-1], strides[-2])
    lu = Ptr[T](m * m)
    ipiv = Ptr[fint](m)
    if a.ndim == 2:
        layout.linearize(lu, a.data)
        sign, logabsdet = _slogdet_single_element(m, lu, ipiv)
    else:
        batch = a.shape[:-2]
        sign = empty(batch, dtype=a.dtype)
        logabsdet = empty(batch, dtype=type(_basetype(a.dtype)))
        for idx in multirange(batch):
            layout.linearize(lu, a._ptr(idx + (0, 0)))
            s0, l0 = _slogdet_single_element(m, lu, ipiv)
            sign._ptr(idx)[0] = s0
            logabsdet._ptr(idx)[0] = l0
    free(lu)
    free(ipiv)
    return SlogdetResult(sign, logabsdet)
def det(a):
    # Determinant of each square matrix in `a`, computed as
    # sign * exp(logabsdet) from an LU factorization. Returns a scalar for 2-D
    # input, otherwise an array of shape a.shape[:-2].
    a = _asarray(a)
    T = a.dtype
    m = _square_rows(a)
    strides = a.strides
    layout = LinearizeData(m, m, strides[-1], strides[-2])
    lu = Ptr[T](m * m)
    ipiv = Ptr[fint](m)
    if a.ndim == 2:
        layout.linearize(lu, a.data)
        s0, l0 = _slogdet_single_element(m, lu, ipiv)
        ans = _det_from_slogdet(s0, l0)
    else:
        batch = a.shape[:-2]
        ans = empty(batch, dtype=a.dtype)
        for idx in multirange(batch):
            layout.linearize(lu, a._ptr(idx + (0, 0)))
            s0, l0 = _slogdet_single_element(m, lu, ipiv)
            ans._ptr(idx)[0] = _det_from_slogdet(s0, l0)
    free(lu)
    free(ipiv)
    return ans
########
# Eigh #
########
class EighParams:
    # Scratch buffers and arguments for LAPACK ?syevd (real T) / ?heevd
    # (complex T). T is the matrix element type and B its real base type
    # (the eigenvalue type).
    A: Ptr[T]        # n x n input matrix; overwritten with eigenvectors
    W: Ptr[B]        # n eigenvalues
    WORK: Ptr[T]
    RWORK: Ptr[B]    # complex variant only; null for real T
    IWORK: Ptr[fint]
    N: fint
    LWORK: fint
    LRWORK: fint
    LIWORK: fint
    JOBZ: byte       # 'N' (78) eigenvalues only / 'V' (86) with vectors
    UPLO: byte       # 'L' (76) / 'U' (85): triangle of A that is read
    LDA: fint
    T: type
    B: type

    def _init_real(self, JOBZ: byte, UPLO: byte, N: fint):
        # One allocation holds A (N*N) followed by W (N); for real T, B == T.
        safe_N = int(N)
        alloc_size = safe_N * (safe_N + 1) * sizeof(T)
        lda = N if N else fint(1)  # LDA must be >= 1
        mem_buff = cobj(alloc_size)
        a = mem_buff
        w = mem_buff + safe_N * safe_N * sizeof(T)
        self.A = Ptr[T](a)
        self.W = Ptr[B](w)
        self.RWORK = Ptr[B]()
        self.N = N
        self.LRWORK = fint(0)
        self.JOBZ = JOBZ
        self.UPLO = UPLO
        self.LDA = lda
        # work size query
        # LWORK = LIWORK = -1 asks LAPACK to report the optimal sizes in
        # WORK[0] / IWORK[0] without computing anything.
        query_work_size = T()
        query_iwork_size = fint()
        self.LWORK = fint(-1)
        self.LIWORK = fint(-1)
        self.WORK = __ptr__(query_work_size)
        self.IWORK = __ptr__(query_iwork_size)
        if self.call():
            free(mem_buff)
            raise LinAlgError("error in evd work size query")
        lwork = cast(query_work_size, int)
        liwork = int(query_iwork_size)
        # Second allocation holds WORK followed by IWORK; freed via WORK.
        mem_buff2 = cobj(lwork * sizeof(T) + liwork * sizeof(fint))
        work = mem_buff2
        iwork = mem_buff2 + lwork * sizeof(T)
        self.LWORK = fint(lwork)
        self.WORK = Ptr[T](work)
        self.LIWORK = fint(liwork)
        self.IWORK = Ptr[fint](iwork)

    def _init_complex(self, JOBZ: byte, UPLO: byte, N: fint):
        # One allocation holds A (N*N of T) followed by W (N of B).
        safe_N = int(N)
        lda = N if N else fint(1)
        mem_buff = cobj(safe_N * safe_N * sizeof(T) + safe_N * sizeof(B))
        a = mem_buff
        w = mem_buff + safe_N * safe_N * sizeof(T)
        self.A = Ptr[T](a)
        self.W = Ptr[B](w)
        self.N = N
        self.JOBZ = JOBZ
        self.UPLO = UPLO
        self.LDA = lda
        # work size query
        query_work_size = T()
        query_rwork_size = B()
        query_iwork_size = fint()
        self.LWORK = fint(-1)
        self.LRWORK = fint(-1)
        self.LIWORK = fint(-1)
        self.WORK = __ptr__(query_work_size)
        self.RWORK = __ptr__(query_rwork_size)
        self.IWORK = __ptr__(query_iwork_size)
        if self.call():
            free(mem_buff)
            raise LinAlgError("error in evd work size query")
        # The optimal LWORK is reported in the real part of WORK[0].
        lwork = cast(Ptr[B](__ptr__(query_work_size).as_byte())[0], int)
        lrwork = cast(query_rwork_size, int)
        liwork = int(query_iwork_size)
        # Second allocation: WORK, then RWORK, then IWORK; freed via WORK.
        mem_buff2 = cobj(lwork * sizeof(T) + lrwork * sizeof(B) + liwork * sizeof(fint))
        work = mem_buff2
        rwork = work + lwork * sizeof(T)
        iwork = rwork + lrwork * sizeof(B)
        self.WORK = Ptr[T](work)
        self.RWORK = Ptr[B](rwork)
        self.IWORK = Ptr[fint](iwork)
        self.LWORK = fint(lwork)
        self.LRWORK = fint(lrwork)
        self.LIWORK = fint(liwork)

    def __init__(self, JOBZ: byte, UPLO: byte, N: fint):
        # Raises LinAlgError if the workspace query fails.
        if T is complex or T is complex64:
            self._init_complex(JOBZ, UPLO, N)
        else:
            self._init_real(JOBZ, UPLO, N)

    def release(self):
        # Free both allocations (A/W block and WORK/RWORK/IWORK block).
        free(self.A)
        free(self.WORK)

    def call(self):
        # Invoke the LAPACK routine for T and return its INFO value
        # (0 on success).
        JOBZ = self.JOBZ
        UPLO = self.UPLO
        N = self.N
        A = self.A
        LDA = self.LDA
        W = self.W
        WORK = self.WORK
        LWORK = self.LWORK
        RWORK = self.RWORK
        LRWORK = self.LRWORK
        IWORK = self.IWORK
        LIWORK = self.LIWORK
        rv = fint()
        args_real = (__ptr__(JOBZ),
                     __ptr__(UPLO),
                     __ptr__(N),
                     A.as_byte(),
                     __ptr__(LDA),
                     W.as_byte(),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     IWORK,
                     __ptr__(LIWORK),
                     __ptr__(rv))
        args_cplx = (__ptr__(JOBZ),
                     __ptr__(UPLO),
                     __ptr__(N),
                     A.as_byte(),
                     __ptr__(LDA),
                     W.as_byte(),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     RWORK.as_byte(),
                     __ptr__(LRWORK),
                     IWORK,
                     __ptr__(LIWORK),
                     __ptr__(rv))
        if T is float:
            dsyevd_(*args_real)
        elif T is float32:
            ssyevd_(*args_real)
        elif T is complex:
            zheevd_(*args_cplx)
        elif T is complex64:
            cheevd_(*args_cplx)
        else:
            compile_error("[internal error] bad dtype for eigh")
        return rv
@tuple
class EighResult[A, B]:
    # Named result of eigh: `eigenvalues` and `eigenvectors`, also indexable
    # as a 2-tuple with a compile-time index.
    eigenvalues: A
    eigenvectors: B

    def __getitem__(self, idx: Static[int]):
        if idx == 0 or idx == -2:
            return self.eigenvalues
        elif idx == 1 or idx == -1:
            return self.eigenvectors
        else:
            compile_error("tuple ('EighResult') index out of range")
def _eigh(a, JOBZ: byte, UPLO: byte, compute_eigenvectors: Static[int]):
    # Shared driver for eigh / eigvalsh. Runs ?syevd / ?heevd on every matrix
    # in the stack `a`. Eigenvalues have the real base dtype and shape
    # a.shape[:-2] + (n,). Eigenvectors (same shape and dtype as `a`) are
    # produced only when `compute_eigenvectors` is set.
    # Failed matrices are filled with NaN; LinAlgError is raised after all
    # matrices have been processed if any failed.
    a = _asarray(a)
    B = type(_basetype(a.dtype))
    n = _square_rows(a)
    params = EighParams[a.dtype, B](JOBZ, UPLO, fint(n))
    eigenvalues = empty(a.shape[:-2] + (n,), dtype=B)
    if compute_eigenvectors:
        eigenvectors = empty(a.shape, dtype=a.dtype)
    else:
        eigenvectors = None
    matrix_in_ld = LinearizeData(n, n, a.strides[-1], a.strides[-2])
    eigenvalues_out_ld = LinearizeData(1, n, 0, eigenvalues.strides[-1])
    if compute_eigenvectors:
        eigenvectors_out_ld = LinearizeData(n, n, eigenvectors.strides[-1], eigenvectors.strides[-2])
    else:
        eigenvectors_out_ld = None
    error_occured = False
    for idx in multirange(a.shape[:-2]):
        matrix_ptr = a._ptr(idx + (0, 0))
        eigval_ptr = eigenvalues._ptr(idx + (0,))
        if compute_eigenvectors:
            eigvec_ptr = eigenvectors._ptr(idx + (0, 0))
        else:
            eigvec_ptr = None
        matrix_in_ld.linearize(params.A, matrix_ptr)
        not_ok = params.call()
        if not not_ok:
            eigenvalues_out_ld.delinearize(eigval_ptr, params.W)
            if compute_eigenvectors:
                # With JOBZ='V' LAPACK leaves the eigenvectors in A.
                eigenvectors_out_ld.delinearize(eigvec_ptr, params.A)
        else:
            eigenvalues_out_ld.nan_matrix(eigval_ptr)
            if compute_eigenvectors:
                eigenvectors_out_ld.nan_matrix(eigvec_ptr)
            error_occured = True
    params.release()
    if error_occured:
        raise LinAlgError("Eigenvalues did not converge")
    if compute_eigenvectors:
        return EighResult(eigenvalues, eigenvectors)
    else:
        return eigenvalues
def _parse_uplo(UPLO: str) -> byte:
    # Translate NumPy's UPLO argument ('L'/'U', either case) into the ASCII
    # byte LAPACK expects; anything else raises ValueError.
    if UPLO == 'L' or UPLO == 'l':
        return byte(76)  # 'L'
    if UPLO == 'U' or UPLO == 'u':
        return byte(85)  # 'U'
    raise ValueError("UPLO argument must be 'L' or 'U'")

def eigh(a, UPLO: str = 'L'):
    # Eigenvalues and eigenvectors of Hermitian / real-symmetric matrices.
    # Only the triangle selected by UPLO is read. Returns EighResult.
    uplo_byte = _parse_uplo(UPLO)
    want_vectors = byte(86)  # JOBZ = 'V'
    return _eigh(a, JOBZ=want_vectors, UPLO=uplo_byte, compute_eigenvectors=True)

def eigvalsh(a, UPLO: str = 'L'):
    # Eigenvalues only of Hermitian / real-symmetric matrices (see eigh).
    uplo_byte = _parse_uplo(UPLO)
    values_only = byte(78)  # JOBZ = 'N'
    return _eigh(a, JOBZ=values_only, UPLO=uplo_byte, compute_eigenvectors=False)
#########
# Solve #
#########
class GesvParams:
    # Scratch buffers and arguments for LAPACK ?gesv (solve A X = B with LU
    # and partial pivoting). One allocation holds A (N*N), B (N*NRHS) and
    # IPIV (N), in that order.
    A: Ptr[T]
    B: Ptr[T]        # right-hand sides; overwritten with the solution
    IPIV: Ptr[fint]
    N: fint
    NRHS: fint
    LDA: fint
    LDB: fint
    T: type

    def __init__(self, N: fint, NRHS: fint):
        safe_N = int(N)
        safe_NRHS = int(NRHS)
        ld = N if N else fint(1)  # leading dimensions must be >= 1
        mem_buff = cobj(safe_N * safe_N * sizeof(T) +
                        safe_N * safe_NRHS * sizeof(T) +
                        safe_N * sizeof(fint))
        a = mem_buff
        b = a + safe_N * safe_N * sizeof(T)
        ipiv = b + safe_N * safe_NRHS * sizeof(T)
        self.A = Ptr[T](a)
        self.B = Ptr[T](b)
        self.IPIV = Ptr[fint](ipiv)
        self.N = N
        self.NRHS = NRHS
        self.LDA = ld
        self.LDB = ld

    def release(self):
        # A is the start of the single allocation.
        free(self.A)

    def call(self):
        # Invoke ?gesv for T; returns INFO (0 on success, nonzero e.g. for a
        # singular matrix).
        A = self.A
        B = self.B
        IPIV = self.IPIV
        N = self.N
        NRHS = self.NRHS
        LDA = self.LDA
        LDB = self.LDB
        rv = fint()
        args = (__ptr__(N),
                __ptr__(NRHS),
                A.as_byte(),
                __ptr__(LDA),
                IPIV,
                B.as_byte(),
                __ptr__(LDB),
                __ptr__(rv))
        if T is float:
            dgesv_(*args)
        elif T is float32:
            sgesv_(*args)
        elif T is complex:
            zgesv_(*args)
        elif T is complex64:
            cgesv_(*args)
        else:
            compile_error("[internal error] bad dtype for gesv")
        return rv
def _solve(a, b):
    # Solve a @ x = b where b is a stack of (M, K) matrices. The leading
    # (batch) dimensions of `a` and `b` are broadcast against each other.
    # Returns an array shaped like the (broadcast) `b`. Failed solves are
    # NaN-filled and LinAlgError("Singular matrix") is raised afterwards.
    if a.ndim == b.ndim:
        if a.shape[:-2] != b.shape[:-2]:
            pre_broadcast = broadcast_shapes(a.shape[:-2], b.shape[:-2])
            a = broadcast_to(a, pre_broadcast + a.shape[-2:])
            b = broadcast_to(b, pre_broadcast + b.shape[-2:])
    else:
        # Different ranks: broadcast to a common rank and recurse (new names,
        # because the broadcast arrays have a different static type).
        pre_broadcast = broadcast_shapes(a.shape[:-2], b.shape[:-2])
        a1 = broadcast_to(a, pre_broadcast + a.shape[-2:])
        b1 = broadcast_to(b, pre_broadcast + b.shape[-2:])
        return _solve(a1, b1)
    n = _square_rows(a)
    m, k = _rows_cols(b)
    if m != n:
        raise ValueError("solve: 'a' and 'b' don't have the same number of rows")
    r = empty(b.shape, dtype=a.dtype)
    nrhs = b.shape[-1]
    params = GesvParams[a.dtype](fint(n), fint(nrhs))
    a_in = LinearizeData(n, n, a.strides[-1], a.strides[-2])
    b_in = LinearizeData(nrhs, n, b.strides[-1], b.strides[-2])
    r_out = LinearizeData(nrhs, n, r.strides[-1], r.strides[-2])
    error_occured = False
    for idx in multirange(a.shape[:-2]):
        a_ptr = a._ptr(idx + (0, 0))
        b_ptr = b._ptr(idx + (0, 0))
        r_ptr = r._ptr(idx + (0, 0))
        a_in.linearize(params.A, a_ptr)
        b_in.linearize(params.B, b_ptr)
        not_ok = params.call()
        if not not_ok:
            # ?gesv overwrites B with the solution.
            r_out.delinearize(r_ptr, params.B)
        else:
            r_out.nan_matrix(r_ptr)
            error_occured = True
    params.release()
    if error_occured:
        raise LinAlgError("Singular matrix")
    return r
def _solve1(a, b):
    # Solve a @ x = b for a single right-hand-side vector `b` of shape (M,).
    # `b` is broadcast against every matrix in the (possibly stacked) `a`,
    # matching NumPy's `solve1` gufunc signature (m,m),(m)->(m): the result
    # has shape a.shape[:-1]. Failed solves are NaN-filled and
    # LinAlgError("Singular matrix") is raised after all matrices are done.
    n = _square_rows(a)
    m = b.shape[0]
    if m != n:
        raise ValueError("solve: 'a' and 'b' don't have the same number of rows")
    # One output vector per matrix in the stack. A result of shape (M,) alone
    # would let every batch element overwrite the previous one.
    r = empty(a.shape[:-1], dtype=a.dtype)
    params = GesvParams[a.dtype](fint(n), fint(1))
    a_in = LinearizeData(n, n, a.strides[-1], a.strides[-2])
    b_in = LinearizeData(1, n, 1, b.strides[0])
    r_out = LinearizeData(1, n, 1, r.strides[-1])
    error_occured = False
    for idx in multirange(a.shape[:-2]):
        a_ptr = a._ptr(idx + (0, 0))
        b_ptr = b.data
        r_ptr = r._ptr(idx + (0,))
        a_in.linearize(params.A, a_ptr)
        # ?gesv overwrites B, so the shared vector is re-copied every time.
        b_in.linearize(params.B, b_ptr)
        not_ok = params.call()
        if not not_ok:
            r_out.delinearize(r_ptr, params.B)
        else:
            r_out.nan_matrix(r_ptr)
            error_occured = True
    params.release()
    if error_occured:
        raise LinAlgError("Singular matrix")
    return r
def solve(a, b):
    # Solve the linear system a @ x = b.
    # Both operands are converted to a common LAPACK-compatible dtype.
    # A 1-D `b` is treated as a single vector; otherwise `b` is a stack of
    # (M, K) matrices. The rank checks are compile-time errors.
    # The if/else below is a compile-time branch; only the selected
    # solver is instantiated.
    dtype = type(
        coerce(
            type(asarray(a).data[0]),
            type(asarray(b).data[0])))
    a = _asarray(a, dtype=dtype)
    b = _asarray(b, dtype=dtype)
    if a.ndim < 2:
        compile_error("'a' must be at least 2-dimensional")
    if b.ndim == 0:
        compile_error("'b' must be at least 1-dimensional")
    if b.ndim == 1:
        return _solve1(a, b)
    else:
        return _solve(a, b)
def _inv(a, ignore_errors: Static[int]):
    # Invert every square matrix in `a` by solving A @ X = I with ?gesv.
    # Singular matrices produce NaN blocks. LinAlgError is raised afterwards
    # unless `ignore_errors` is set.
    a = _asarray(a)
    n = _square_rows(a)
    out = empty(a.shape, dtype=a.dtype)
    solver = GesvParams[a.dtype](fint(n), fint(n))
    src_layout = LinearizeData(n, n, a.strides[-1], a.strides[-2])
    dst_layout = LinearizeData(n, n, out.strides[-1], out.strides[-2])
    failed = False
    for idx in multirange(a.shape[:-2]):
        src_layout.linearize(solver.A, a._ptr(idx + (0, 0)))
        # Right-hand side is the identity; ?gesv overwrites it with A^-1.
        LinearizeData.identity_matrix(solver.B, n)
        dst = out._ptr(idx + (0, 0))
        if solver.call():
            dst_layout.nan_matrix(dst)
            failed = True
        else:
            dst_layout.delinearize(dst, solver.B)
    solver.release()
    if not ignore_errors:
        if failed:
            raise LinAlgError("Singular matrix")
    return out

def inv(a):
    # Public entry point: matrix inverse, raising LinAlgError on singularity.
    return _inv(a, ignore_errors=False)
############
# Cholesky #
############
class PotrfParams:
    # Scratch buffer and arguments for LAPACK ?potrf (Cholesky factorization)
    # of one N x N matrix stored column-major in A.
    A: Ptr[T]
    N: fint
    LDA: fint
    UPLO: byte       # 'L' (76) or 'U' (85): which factor LAPACK computes
    T: type

    def __init__(self, UPLO: byte, N: fint):
        safe_N = int(N)
        lda = N if N else fint(1)  # LDA must be >= 1
        mem_buff = cobj(safe_N * safe_N * sizeof(T))
        a = mem_buff
        self.A = Ptr[T](a)
        self.N = N
        self.LDA = lda
        self.UPLO = UPLO

    def zero_lower_triangle(self):
        # In the column-major buffer, zero entries j > i of each column i
        # (the strictly lower triangle).
        n = int(self.N)
        matrix = self.A
        for i in range(n - 1):
            for j in range(i + 1, n):
                matrix[j] = zero(T)
            matrix += n
    def zero_upper_triangle(self):
        # In the column-major buffer, zero entries j < i of each column i
        # (the strictly upper triangle); column 0 has none, so start at 1.
        n = int(self.N)
        matrix = self.A
        matrix += n
        for i in range(1, n):
            for j in range(i):
                matrix[j] = zero(T)
            matrix += n

    def call(self):
        # Invoke ?potrf for T; returns INFO (nonzero when the matrix is not
        # positive definite or on bad arguments).
        UPLO = self.UPLO
        N = self.N
        A = self.A
        LDA = self.LDA
        rv = fint()
        args = (__ptr__(UPLO),
                __ptr__(N),
                A.as_byte(),
                __ptr__(LDA),
                __ptr__(rv))
        if T is float:
            dpotrf_(*args)
        elif T is float32:
            spotrf_(*args)
        elif T is complex:
            zpotrf_(*args)
        elif T is complex64:
            cpotrf_(*args)
        else:
            compile_error("[internal error] bad dtype for potrf")
        return rv

    def release(self):
        free(self.A)
def cholesky(a, upper: bool = False):
    # Cholesky factor of each Hermitian positive-definite matrix in `a`:
    # lower-triangular by default, upper-triangular when `upper` is True.
    # Failed factorizations are NaN-filled, then LinAlgError is raised.
    a = _asarray(a)
    uplo = byte(85) if upper else byte(76)  # 'U' / 'L'
    n = _square_rows(a)
    potrf = PotrfParams[a.dtype](uplo, fint(n))
    out = empty(a.shape, dtype=a.dtype)
    src_layout = LinearizeData(n, n, a.strides[-1], a.strides[-2])
    dst_layout = LinearizeData(n, n, out.strides[-1], out.strides[-2])
    failed = False
    for idx in multirange(a.shape[:-2]):
        dst = out._ptr(idx + (0, 0))
        src_layout.linearize(potrf.A, a._ptr(idx + (0, 0)))
        if potrf.call():
            failed = True
            dst_layout.nan_matrix(dst)
            continue
        # ?potrf leaves the opposite triangle untouched (still holding input
        # data), so clear it before copying the factor out.
        if not upper:
            potrf.zero_upper_triangle()
        else:
            potrf.zero_lower_triangle()
        dst_layout.delinearize(dst, potrf.A)
    potrf.release()
    if failed:
        raise LinAlgError("Matrix is not positive definite")
    return out
#######
# Eig #
#######
def _mk_complex_array_from_real(c: Ptr[C], re: Ptr[R], n: int, C: type, R: type):
    # c[k] = re[k] + 0j for k in [0, n).
    imag_zero = zero(R)
    for k in range(n):
        c[k] = C(re[k], imag_zero)

def _mk_complex_array(c: Ptr[C], re: Ptr[R], im: Ptr[R], n: int, C: type, R: type):
    # c[k] = re[k] + im[k]*j for k in [0, n).
    k = 0
    while k < n:
        c[k] = C(re[k], im[k])
        k += 1

def _mk_complex_array_conjugate_pair(c: Ptr[C], r: Ptr[R], n: int, C: type, R: type):
    # Build a conjugate pair of length-n complex vectors from 2n reals:
    # r[0:n] holds the real parts and r[n:2n] the imaginary parts.
    # c[0:n] gets re + im*j and c[n:2n] gets re - im*j.
    second = c + n
    for k in range(n):
        real_part = r[k]
        imag_part = r[k + n]
        c[k] = C(real_part, imag_part)
        second[k] = C(real_part, -imag_part)
def _mk_geev_complex_eigenvectors(c: Ptr[C], r: Ptr[R], i: Ptr[R], n: int, C: type, R: type):
    # Expand real ?geev eigenvector output `r` (n columns of length n) into
    # complex vectors in `c`, driven by the eigenvalue imaginary parts `i`.
    # A zero imaginary part means one real eigenvector (one column). A
    # nonzero one means the next two columns hold the real and imaginary
    # parts of a conjugate pair, which yield two complex vectors.
    k = 0
    while k < n:
        if i[k] == zero(R):
            _mk_complex_array_from_real(c, r, n)
            span = 1
        else:
            _mk_complex_array_conjugate_pair(c, r, n)
            span = 2
        c += span * n
        r += span * n
        k += span
class GeevParams:
    # Scratch buffers and arguments for LAPACK ?geev (general eigenproblem).
    # T is the matrix element type and B its real base type. The real
    # variant writes real/imaginary eigenvalue parts into WR/WI and real
    # eigenvectors into VLR/VRR. process_geev_results then converts them to
    # complex values in W/VL/VR. The complex variant writes W/VL/VR directly
    # and reuses WR as its RWORK array.
    A: Ptr[T]
    WR: Ptr[B]
    WI: Ptr[T]
    VLR: Ptr[T]
    VRR: Ptr[T]
    WORK: Ptr[T]
    W: Ptr[T]
    VL: Ptr[T]
    VR: Ptr[T]
    N: fint
    LDA: fint
    LDVL: fint
    LDVR: fint
    LWORK: fint
    JOBVL: byte      # 'N' (78) / 'V' (86): compute left eigenvectors?
    JOBVR: byte      # 'N' (78) / 'V' (86): compute right eigenvectors?
    T: type
    B: type

    def _init_real(self, jobvl: byte, jobvr: byte, n: fint):
        # Single allocation, in order: A, WR, WI, VLR, VRR, W, VL, VR.
        # W/VL/VR are sized at twice the real arrays because they hold
        # complex values (re, im).
        safe_n = int(n)
        a_size = safe_n * safe_n * sizeof(T)
        wr_size = safe_n * sizeof(T)
        wi_size = safe_n * sizeof(T)
        vlr_size = safe_n * safe_n * sizeof(T) if jobvl == byte(86) else 0
        vrr_size = safe_n * safe_n * sizeof(T) if jobvr == byte(86) else 0
        w_size = wr_size * 2
        vl_size = vlr_size * 2
        vr_size = vrr_size * 2
        ld = n if n else fint(1)
        mem_buff = cobj(a_size + wr_size + wi_size +
                        vlr_size + vrr_size +
                        w_size + vl_size + vr_size)
        a = mem_buff
        wr = a + a_size
        wi = wr + wr_size
        vlr = wi + wi_size
        vrr = vlr + vlr_size
        w = vrr + vrr_size
        vl = w + w_size
        vr = vl + vl_size
        self.A = Ptr[T](a)
        self.WR = Ptr[T](wr)
        self.WI = Ptr[T](wi)
        self.VLR = Ptr[T](vlr)
        self.VRR = Ptr[T](vrr)
        self.W = Ptr[T](w)
        self.VL = Ptr[T](vl)
        self.VR = Ptr[T](vr)
        self.N = n
        self.LDA = ld
        self.LDVL = ld
        self.LDVR = ld
        self.JOBVL = jobvl
        self.JOBVR = jobvr
        # work size query
        # LWORK = -1: LAPACK reports the optimal size in WORK[0].
        work_size_query = T()
        self.LWORK = fint(-1)
        self.WORK = __ptr__(work_size_query)
        if self.call():
            free(mem_buff)
            raise LinAlgError("error in evd work size query")
        work_count = cast(work_size_query, int)
        mem_buff2 = cobj(work_count * sizeof(T))
        work = mem_buff2
        self.LWORK = fint(work_count)
        self.WORK = Ptr[T](work)

    def _init_complex(self, jobvl: byte, jobvr: byte, n: fint):
        # Single allocation, in order: A, W, VL, VR, RWORK (2n reals).
        safe_n = int(n)
        a_size = safe_n * safe_n * sizeof(T)
        w_size = safe_n * sizeof(T)
        vl_size = safe_n * safe_n * sizeof(T) if jobvl == byte(86) else 0
        vr_size = safe_n * safe_n * sizeof(T) if jobvr == byte(86) else 0
        rwork_size = 2 * safe_n * sizeof(B)
        total_size = a_size + w_size + vl_size + vr_size + rwork_size
        ld = n if n else fint(1)
        mem_buff = cobj(total_size)
        a = mem_buff
        w = a + a_size
        vl = w + w_size
        vr = vl + vl_size
        rwork = vr + vr_size
        self.A = Ptr[T](a)
        self.WR = Ptr[B](rwork)  # WR doubles as RWORK in the complex variant
        self.WI = Ptr[T]()
        self.VLR = Ptr[T]()
        self.VRR = Ptr[T]()
        self.VL = Ptr[T](vl)
        self.VR = Ptr[T](vr)
        self.W = Ptr[T](w)
        self.N = n
        self.LDA = ld
        self.LDVL = ld
        self.LDVR = ld
        self.JOBVL = jobvl
        self.JOBVR = jobvr
        # work size query
        work_size_query = T()
        self.LWORK = fint(-1)
        self.WORK = __ptr__(work_size_query)
        if self.call():
            free(mem_buff)
            raise LinAlgError("error in evd work size query")
        # Optimal size is reported in the real part; guard against 0.
        work_count = cast(work_size_query.real, int)
        if work_count == 0:
            work_count = 1
        mem_buff2 = cobj(work_count * sizeof(T))
        work = mem_buff2
        self.LWORK = fint(work_count)
        self.WORK = Ptr[T](work)

    def __init__(self, jobvl: byte, jobvr: byte, n: fint):
        # Raises LinAlgError if the workspace query fails.
        if T is complex or T is complex64:
            self._init_complex(jobvl, jobvr, n)
        else:
            self._init_real(jobvl, jobvr, n)

    def call(self):
        # Invoke ?geev for T; returns INFO (0 on success).
        A = self.A
        WR = self.WR
        WI = self.WI
        VLR = self.VLR
        VRR = self.VRR
        WORK = self.WORK
        W = self.W
        VL = self.VL
        VR = self.VR
        N = self.N
        LDA = self.LDA
        LDVL = self.LDVL
        LDVR = self.LDVR
        LWORK = self.LWORK
        JOBVL = self.JOBVL
        JOBVR = self.JOBVR
        rv = fint()
        args_real = (__ptr__(JOBVL),
                     __ptr__(JOBVR),
                     __ptr__(N),
                     A.as_byte(),
                     __ptr__(LDA),
                     WR.as_byte(),
                     WI.as_byte(),
                     VLR.as_byte(),
                     __ptr__(LDVL),
                     VRR.as_byte(),
                     __ptr__(LDVR),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     __ptr__(rv))
        args_cplx = (__ptr__(JOBVL),
                     __ptr__(JOBVR),
                     __ptr__(N),
                     A.as_byte(),
                     __ptr__(LDA),
                     W.as_byte(),
                     VL.as_byte(),
                     __ptr__(LDVL),
                     VR.as_byte(),
                     __ptr__(LDVR),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     WR.as_byte(),  # RWORK
                     __ptr__(rv))
        if T is float:
            dgeev_(*args_real)
        elif T is float32:
            sgeev_(*args_real)
        elif T is complex:
            zgeev_(*args_cplx)
        elif T is complex64:
            cgeev_(*args_cplx)
        else:
            compile_error("[internal error] bad dtype for geev")
        return rv

    def release(self):
        # Free the work buffer and the main allocation (starts at A).
        free(self.WORK)
        free(self.A)

    def process_geev_results(self):
        # Real variant only: convert WR/WI (and VLR/VRR when requested) into
        # complex eigenvalues/eigenvectors in W/VL/VR. The complex variant
        # already wrote W/VL/VR directly.
        if T is complex or T is complex64:
            return
        C = type(_complextype(T))
        _mk_complex_array(Ptr[C](self.W.as_byte()), self.WR, self.WI, int(self.N))
        if self.JOBVL == byte(86):
            _mk_geev_complex_eigenvectors(Ptr[C](self.VL.as_byte()), self.VLR, self.WI, int(self.N))
        if self.JOBVR == byte(86):
            _mk_geev_complex_eigenvectors(Ptr[C](self.VR.as_byte()), self.VRR, self.WI, int(self.N))
@tuple
class EigResult[A, B]:
    # Named result of eig: `eigenvalues` and `eigenvectors`, also indexable
    # as a 2-tuple with a compile-time index.
    eigenvalues: A
    eigenvectors: B

    def __getitem__(self, idx: Static[int]):
        if idx == 0 or idx == -2:
            return self.eigenvalues
        elif idx == 1 or idx == -1:
            return self.eigenvectors
        else:
            compile_error("tuple ('EigResult') index out of range")
def _eig(a, JOBVL: byte, JOBVR: byte, compute_eigenvectors: Static[int]):
    # Shared driver for eig / eigvals. Runs ?geev on every matrix in `a` and
    # always returns complex results: eigenvalues of shape
    # a.shape[:-2] + (n,), plus right eigenvectors shaped like `a` when
    # requested. Failed matrices are NaN-filled and LinAlgError is raised
    # after all matrices have been processed.
    a = _asarray(a)
    B = type(_basetype(a.dtype))
    C = type(_complextype(a.dtype))
    n = _square_rows(a)
    params = GeevParams[a.dtype, B](JOBVL, JOBVR, fint(n))
    eigenvalues = empty(a.shape[:-2] + (n,), dtype=C)
    if compute_eigenvectors:
        eigenvectors = empty(a.shape, dtype=C)
    else:
        eigenvectors = None
    a_in = LinearizeData(n, n, a.strides[-1], a.strides[-2])
    w_out = LinearizeData(1, n, 0, eigenvalues.strides[-1])
    if compute_eigenvectors:
        vr_out = LinearizeData(n, n, eigenvectors.strides[-1], eigenvectors.strides[-2])
    else:
        vr_out = None
    error_occured = False
    for idx in multirange(a.shape[:-2]):
        matrix_ptr = a._ptr(idx + (0, 0))
        eigval_ptr = eigenvalues._ptr(idx + (0,))
        if compute_eigenvectors:
            eigvec_ptr = eigenvectors._ptr(idx + (0, 0))
        else:
            eigvec_ptr = None
        a_in.linearize(params.A, matrix_ptr)
        not_ok = params.call()
        if not not_ok:
            # For real input, build the complex W/VR from the real outputs.
            params.process_geev_results()
            w_out.delinearize(eigval_ptr, Ptr[C](params.W.as_byte()))
            if compute_eigenvectors:
                vr_out.delinearize(eigvec_ptr, Ptr[C](params.VR.as_byte()))
        else:
            w_out.nan_matrix(eigval_ptr)
            if compute_eigenvectors:
                vr_out.nan_matrix(eigvec_ptr)
            error_occured = True
    params.release()
    if error_occured:
        raise LinAlgError("Eigenvalues did not converge")
    if compute_eigenvectors:
        return EigResult(eigenvalues, eigenvectors)
    else:
        return eigenvalues
def eig(a):
    # Eigenvalues and right eigenvectors of general square matrices
    # (JOBVL = 'N', JOBVR = 'V'). Returns EigResult with complex arrays.
    skip_left = byte(78)   # 'N'
    want_right = byte(86)  # 'V'
    return _eig(a, skip_left, want_right, compute_eigenvectors=True)

def eigvals(a):
    # Eigenvalues only of general square matrices (JOBVL = JOBVR = 'N').
    no_vectors = byte(78)  # 'N'
    return _eig(a, no_vectors, no_vectors, compute_eigenvectors=False)
#######
# SVD #
#######
def _compute_urows_vtcolumns(jobz: byte, m: fint, n: fint):
    # Sizes of the U and VT buffers ?gesdd needs for a given JOBZ:
    #   'A' -> (m, n), 'N' -> (0, 0), anything else ('S') -> (min, min).
    if jobz == byte(65):  # 'A'
        return (m, n)
    if jobz == byte(78):  # 'N'
        return (fint(0), fint(0))
    smaller = m if m < n else n  # 'S'
    return (smaller, smaller)
class GesddParams:
    # Scratch buffers and arguments for LAPACK ?gesdd (divide-and-conquer
    # SVD) of one M x N column-major matrix. T is the element type and B its
    # real base type (singular values are B).
    A: Ptr[T]
    S: Ptr[B]
    U: Ptr[T]
    VT: Ptr[T]
    WORK: Ptr[T]
    RWORK: Ptr[B]    # complex variant only; null for real T
    IWORK: Ptr[fint]
    M: fint
    N: fint
    LDA: fint
    LDU: fint
    LDVT: fint
    LWORK: fint
    JOBZ: byte       # 'A' (65) / 'S' (83) / 'N' (78)
    T: type
    B: type

    def _init_real(self, jobz: byte, m: fint, n: fint):
        # Single allocation, in order: A, S, U, VT, IWORK (8*min(m, n) ints).
        safe_m = int(m)
        safe_n = int(n)
        a_size = safe_m * safe_n * sizeof(T)
        min_m_n = m if m < n else n
        safe_min_m_n = int(min_m_n)
        s_size = safe_min_m_n * sizeof(T)
        iwork_size = 8 * safe_min_m_n * sizeof(fint)
        ld = m if m else fint(1)
        u_row_count, vt_column_count = _compute_urows_vtcolumns(jobz, m, n)
        safe_u_row_count = int(u_row_count)
        safe_vt_column_count = int(vt_column_count)
        u_size = safe_u_row_count * safe_m * sizeof(T)
        vt_size = safe_n * safe_vt_column_count * sizeof(T)
        mem_buff = cobj(a_size + s_size + u_size + vt_size + iwork_size)
        a = mem_buff
        s = a + a_size
        u = s + s_size
        vt = u + u_size
        iwork = vt + vt_size
        # Clamp to at least 1 for use as LDVT (a leading dimension of 0 is
        # not a valid LAPACK argument).
        vt_column_count = vt_column_count if vt_column_count else fint(1)
        self.M = m
        self.N = n
        self.A = Ptr[T](a)
        self.S = Ptr[B](s)
        self.U = Ptr[T](u)
        self.VT = Ptr[T](vt)
        self.RWORK = Ptr[B]()
        self.IWORK = Ptr[fint](iwork)
        self.LDA = ld
        self.LDU = ld
        self.LDVT = vt_column_count
        self.JOBZ = jobz
        # work size query
        # LWORK = -1: LAPACK reports the optimal size in WORK[0].
        work_size_query = T()
        self.LWORK = fint(-1)
        self.WORK = __ptr__(work_size_query)
        if self.call():
            free(mem_buff)
            raise LinAlgError("error in gesdd work size query")
        work_count = cast(work_size_query, int)
        if work_count == 0:
            work_count = 1
        work_size = work_count * sizeof(T)
        mem_buff2 = cobj(work_size)
        work = mem_buff2
        self.LWORK = fint(work_count)
        self.WORK = Ptr[T](work)

    def _init_complex(self, jobz: byte, m: fint, n: fint):
        # Single allocation, in order: A, S, U, VT, RWORK, IWORK.
        safe_m = int(m)
        safe_n = int(n)
        min_m_n = m if m < n else n
        safe_min_m_n = int(min_m_n)
        ld = m if m else fint(1)
        u_row_count, vt_column_count = _compute_urows_vtcolumns(jobz, m, n)
        safe_u_row_count = int(u_row_count)
        safe_vt_column_count = int(vt_column_count)
        a_size = safe_m * safe_n * sizeof(T)
        s_size = safe_min_m_n * sizeof(B)
        u_size = safe_u_row_count * safe_m * sizeof(T)
        vt_size = safe_n * safe_vt_column_count * sizeof(T)
        # RWORK element count depends on JOBZ. It is multiplied by sizeof(T),
        # not sizeof(B), so twice as many reals are allocated as counted.
        rwork_size = (7 * safe_min_m_n if jobz == byte(78) else
                      5*safe_min_m_n * safe_min_m_n + 5*safe_min_m_n)
        rwork_size *= sizeof(T)
        iwork_size = 8 * safe_min_m_n * sizeof(fint)
        mem_buff = cobj(a_size +
                        s_size +
                        u_size +
                        vt_size +
                        rwork_size +
                        iwork_size)
        a = mem_buff
        s = a + a_size
        u = s + s_size
        vt = u + u_size
        rwork = vt + vt_size
        iwork = rwork + rwork_size
        vt_column_count = vt_column_count if vt_column_count else fint(1)
        self.A = Ptr[T](a)
        self.S = Ptr[B](s)
        self.U = Ptr[T](u)
        self.VT = Ptr[T](vt)
        self.RWORK = Ptr[B](rwork)
        self.IWORK = Ptr[fint](iwork)
        self.M = m
        self.N = n
        self.LDA = ld
        self.LDU = ld
        self.LDVT = vt_column_count
        self.JOBZ = jobz
        # work size query
        work_size_query = T()
        self.LWORK = fint(-1)
        self.WORK = __ptr__(work_size_query)
        if self.call():
            free(mem_buff)
            raise LinAlgError("error in gesdd work size query")
        # Optimal size is reported in the real part of WORK[0].
        work_count = cast(Ptr[B](__ptr__(work_size_query).as_byte())[0], int)
        if work_count == 0:
            work_count = 1
        work_size = work_count * sizeof(T)
        mem_buff2 = cobj(work_size)
        work = mem_buff2
        self.LWORK = fint(work_count)
        self.WORK = Ptr[T](work)

    def __init__(self, jobz: byte, m: fint, n: fint):
        # Raises LinAlgError if the workspace query fails.
        if T is complex or T is complex64:
            self._init_complex(jobz, m, n)
        else:
            self._init_real(jobz, m, n)

    def call(self):
        # Invoke ?gesdd for T; returns INFO (0 on success).
        A = self.A
        S = self.S
        U = self.U
        VT = self.VT
        WORK = self.WORK
        RWORK = self.RWORK
        IWORK = self.IWORK
        M = self.M
        N = self.N
        LDA = self.LDA
        LDU = self.LDU
        LDVT = self.LDVT
        LWORK = self.LWORK
        JOBZ = self.JOBZ
        rv = fint()
        args_real = (__ptr__(JOBZ),
                     __ptr__(M),
                     __ptr__(N),
                     A.as_byte(),
                     __ptr__(LDA),
                     S.as_byte(),
                     U.as_byte(),
                     __ptr__(LDU),
                     VT.as_byte(),
                     __ptr__(LDVT),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     IWORK,
                     __ptr__(rv))
        args_cplx = (__ptr__(JOBZ),
                     __ptr__(M),
                     __ptr__(N),
                     A.as_byte(),
                     __ptr__(LDA),
                     S.as_byte(),
                     U.as_byte(),
                     __ptr__(LDU),
                     VT.as_byte(),
                     __ptr__(LDVT),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     RWORK.as_byte(),
                     IWORK,
                     __ptr__(rv))
        if T is float:
            dgesdd_(*args_real)
        elif T is float32:
            sgesdd_(*args_real)
        elif T is complex:
            zgesdd_(*args_cplx)
        elif T is complex64:
            cgesdd_(*args_cplx)
        else:
            compile_error("[internal error] bad dtype for gesdd")
        return rv

    def release(self):
        free(self.A)
        free(self.WORK)
@tuple
class SVDResult[A1, A2]:
    # Named result of svd: U, S (singular values) and Vh. It is also
    # indexable as a 3-tuple with a compile-time index; negative indices
    # count from the end.
    U: A1
    S: A2
    Vh: A1

    def __getitem__(self, idx: Static[int]):
        if idx == 0 or idx == -3:
            return self.U
        elif idx == 1 or idx == -2:
            return self.S
        elif idx == 2 or idx == -1:
            return self.Vh
        else:
            compile_error("tuple ('SVDResult') index out of range")
def _svd(a, JOBZ: byte, compute_uv: Static[int]):
    # Run ?gesdd on every M x N matrix in `a`. JOBZ is 'A' (65), 'S' (83)
    # or 'N' (78). This function does not call _asarray itself; `a` is
    # presumably already LAPACK-compatible (TODO confirm with all callers).
    # Returns SVDResult(U, S, Vh) when `compute_uv` is set, else S alone.
    # Failed matrices are NaN-filled and LinAlgError is raised afterwards.
    B = type(_basetype(a.dtype))
    m, n = _rows_cols(a)
    min_m_n = min(m, n)
    params = GesddParams[a.dtype, B](JOBZ, fint(m), fint(n))
    S = empty(a.shape[:-2] + (min_m_n,), dtype=B)
    a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2])
    s_out = LinearizeData(1, min_m_n, 0, S.strides[-1])
    if compute_uv:
        if JOBZ == byte(83): # 'S'
            u_columns = min_m_n
            v_rows = min_m_n
        else:
            u_columns = m
            v_rows = n
        U = empty(a.shape[:-2] + (m, u_columns), dtype=a.dtype)
        V = empty(a.shape[:-2] + (v_rows, n), dtype=a.dtype)
        u_out = LinearizeData(u_columns, m, U.strides[-1], U.strides[-2])
        v_out = LinearizeData(n, v_rows, V.strides[-1], V.strides[-2])
    else:
        U = None
        V = None
        u_out = None
        v_out = None
    error_occured = False
    for idx in multirange(a.shape[:-2]):
        a_ptr = a._ptr(idx + (0, 0))
        s_ptr = S._ptr(idx + (0,))
        if compute_uv:
            u_ptr = U._ptr(idx + (0, 0))
            v_ptr = V._ptr(idx + (0, 0))
        else:
            u_ptr = None
            v_ptr = None
        a_in.linearize(params.A, a_ptr)
        not_ok = params.call()
        if not not_ok:
            if compute_uv:
                if JOBZ == byte(65) and min_m_n == 0:
                    # Empty input with full matrices requested: the U/VT
                    # buffers are explicitly set to identities here rather
                    # than relying on what ?gesdd leaves in them.
                    LinearizeData.identity_matrix(params.U, m)
                    LinearizeData.identity_matrix(params.VT, n)
                u_out.delinearize(u_ptr, params.U)
                s_out.delinearize(s_ptr, params.S)
                v_out.delinearize(v_ptr, params.VT)
            else:
                s_out.delinearize(s_ptr, params.S)
        else:
            if compute_uv:
                u_out.nan_matrix(u_ptr)
                s_out.nan_matrix(s_ptr)
                v_out.nan_matrix(v_ptr)
            else:
                s_out.nan_matrix(s_ptr)
            error_occured = True
    params.release()
    if error_occured:
        raise LinAlgError("SVD did not converge")
    if compute_uv:
        return SVDResult(U, S, V)
    else:
        return S
def svd(a,
        full_matrices: bool = True,
        compute_uv: Static[int] = True,
        hermitian: bool = False):
    # Singular value decomposition of the trailing two dims of `a`.
    #   full_matrices: full (m x m / n x n) vs reduced (k = min(m, n)) factors
    #   compute_uv:    static flag; when false only the singular values are
    #                  returned
    #   hermitian:     treat `a` as Hermitian/symmetric and derive the SVD
    #                  from `eigh`/`eigvalsh` instead of LAPACK ?gesdd
    # Returns SVDResult(U, S, Vh) when compute_uv, otherwise the array of
    # singular values, sorted in descending order along the last axis.
    a = _asarray(a)
    dtype = a.dtype
    m, n = _rows_cols(a)
    k = min(m, n)
    if hermitian:
        pre_shape = a.shape[:-2]
        s_shape = pre_shape + (k,)
        u_shape = pre_shape + ((m, m) if full_matrices else (m, k))
        v_shape = pre_shape + ((n, n) if full_matrices else (k, n))
        if compute_uv:
            e_s, e_u = eigh(a) # note that `k == m == n` in this case
            # Permutation that orders eigenvalues by decreasing magnitude.
            sidx = Ptr[int](k)
            if dtype is complex:
                SH = empty(s_shape, float)
            elif dtype is complex64:
                SH = empty(s_shape, float32)
            else:
                SH = empty(s_shape, dtype)
            UH = empty(u_shape, dtype)
            VH = empty(v_shape, dtype)
            if a.ndim == 2:
                for r in range(k):
                    sidx[r] = r
                sort(sidx, k, key=lambda x: -abs(e_s.data[x]))
                for j in range(k):
                    mapping = sidx[j]
                    s_elem = e_s.data[mapping]
                    negative = s_elem < type(s_elem)()
                    SH.data[j] = abs(s_elem)
                    for i in range(k):
                        u_elem = e_u._ptr((i, mapping))[0]
                        dst = UH._ptr((i, j))
                        dst[0] = u_elem
                        dst = VH._ptr((j, i)) # transpose
                        # A negative eigenvalue's sign is folded into Vh so
                        # that U @ diag(|s|) @ Vh reproduces `a`.
                        if negative:
                            u_elem = -u_elem
                        if hasattr(u_elem, "conjugate"):
                            u_elem = u_elem.conjugate()
                        dst[0] = u_elem
            else:
                for idx in multirange(pre_shape):
                    sub_e_s = e_s._ptr(idx + (0,))
                    sub_e_u = e_u._ptr(idx + (0, 0))
                    sub_SH = SH._ptr(idx + (0,))
                    sub_UH = UH._ptr(idx + (0, 0))
                    sub_VH = VH._ptr(idx + (0, 0))
                    for r in range(k):
                        sidx[r] = r
                    sort(sidx, k, key=lambda x: -abs(sub_e_s[x]))
                    for j in range(k):
                        mapping = sidx[j]
                        s_elem = sub_e_s[mapping]
                        negative = s_elem < type(s_elem)()
                        sub_SH[j] = abs(s_elem)
                        for i in range(k):
                            # assumes each k x k slice is contiguous (row-major)
                            u_elem = (sub_e_u + (i * k + mapping))[0]
                            dst = sub_UH + (i * k + j)
                            dst[0] = u_elem
                            dst = sub_VH + (j * k + i) # transpose
                            if negative:
                                u_elem = -u_elem
                            if hasattr(u_elem, "conjugate"):
                                u_elem = u_elem.conjugate()
                            dst[0] = u_elem
            return SVDResult(UH, SH, VH)
        else:
            # Singular values of a Hermitian matrix are |eigenvalues|,
            # reported in descending order. Computed once (the 2-D branch
            # used to recompute the same decomposition a second time).
            evh = eigvalsh(a)
            if a.ndim == 2:
                for i in range(k):
                    evh.data[i] = abs(evh.data[i])
                sort(evh.data, k, key=lambda x: -x)
            else:
                for idx in multirange(pre_shape):
                    sub_evh = evh._ptr(idx + (0,))
                    for i in range(k):
                        sub_evh[i] = abs(sub_evh[i])
                    sort(sub_evh, k, key=lambda x: -x)
            return evh
    if compute_uv:
        jobz_code = byte(65 if full_matrices else 83) # 'A' / 'S'
        return _svd(a, jobz_code, compute_uv=True)
    else:
        jobz_code = byte(78) # 'N'
        return _svd(a, jobz_code, compute_uv=False)
def svdvals(x):
    # Singular values of `x` only: a general (non-Hermitian) `svd` call
    # that skips computing the U/Vh factors.
    return svd(x, full_matrices=True, compute_uv=False, hermitian=False)
######
# QR #
######
class GeqrfParams:
    # Argument block for LAPACK ?geqrf (QR factorization); `call`
    # dispatches on T to dgeqrf_/sgeqrf_/zgeqrf_/cgeqrf_.
    M: fint       # rows of the matrix
    N: fint       # columns of the matrix
    A: Ptr[T]     # in/out matrix buffer; start of the first allocation (freed in `release`)
    LDA: fint     # leading dimension of A (M, or 1 when M is 0)
    TAU: Ptr[T]   # reflector scalars written by ?geqrf (zeroed at init)
    WORK: Ptr[T]  # workspace; the second allocation (freed in `release`)
    LWORK: fint   # workspace length; -1 during the size query
    T: type
    def _init_real(self, m: fint, n: fint):
        # Real-dtype setup: allocate [A (m*n) | TAU (min(m, n))] as one
        # buffer, run a workspace-size query, then allocate WORK.
        min_m_n = m if m < n else n
        safe_min_m_n = int(min_m_n)
        safe_m = int(m)
        safe_n = int(n)
        a_size = safe_m * safe_n * sizeof(T)
        tau_size = safe_min_m_n * sizeof(T)
        lda = m if m else fint(1)
        mem_buff = cobj(a_size + tau_size)
        a = mem_buff
        tau = a + a_size
        str.memset(tau.as_byte(), byte(0), tau_size)
        self.M = m
        self.N = n
        self.A = Ptr[T](a)
        self.TAU = Ptr[T](tau)
        self.LDA = lda
        # work size query
        # With LWORK = -1 LAPACK only reports the optimal workspace length,
        # written into WORK[0] (the local `work_size_query`).
        work_size_query = T()
        self.WORK = __ptr__(work_size_query)
        self.LWORK = fint(-1)
        if self.call():
            free(mem_buff)
            raise LinAlgError("error in geqrf work size query")
        work_count = cast(work_size_query, fint)
        # LWORK = max(N or 1, queried length)
        lw = n if n else fint(1)
        lw = lw if lw > work_count else work_count
        self.LWORK = lw
        work_size = int(lw) * sizeof(T)
        mem_buff2 = cobj(work_size)
        work = mem_buff2
        self.WORK = Ptr[T](work)
    def _init_complex(self, m: fint, n: fint):
        # Complex-dtype setup: identical to `_init_real` except that the
        # queried workspace length is read from the real part of WORK[0].
        min_m_n = m if m < n else n
        safe_min_m_n = int(min_m_n)
        safe_m = int(m)
        safe_n = int(n)
        a_size = safe_m * safe_n * sizeof(T)
        tau_size = safe_min_m_n * sizeof(T)
        lda = m if m else fint(1)
        mem_buff = cobj(a_size + tau_size)
        a = mem_buff
        tau = a + a_size
        str.memset(tau.as_byte(), byte(0), tau_size)
        self.M = m
        self.N = n
        self.A = Ptr[T](a)
        self.TAU = Ptr[T](tau)
        self.LDA = lda
        # work size query
        work_size_query = T()
        self.WORK = __ptr__(work_size_query)
        self.LWORK = fint(-1)
        if self.call():
            free(mem_buff)
            raise LinAlgError("error in geqrf work size query")
        work_count = cast(work_size_query.real, fint)
        lw = n if n else fint(1)
        lw = lw if lw > work_count else work_count
        self.LWORK = lw
        work_size = int(lw) * sizeof(T)
        mem_buff2 = cobj(work_size)
        work = mem_buff2
        self.WORK = Ptr[T](work)
    def __init__(self, m: fint, n: fint):
        # Allocate buffers for an m x n factorization of element type T.
        # Raises LinAlgError if the workspace query fails.
        if T is complex or T is complex64:
            self._init_complex(m, n)
        else:
            self._init_real(m, n)
    def call(self):
        # Invoke the ?geqrf routine matching T on the current buffers and
        # return LAPACK's INFO value (nonzero signals failure).
        M = self.M
        N = self.N
        A = self.A
        LDA = self.LDA
        TAU = self.TAU
        WORK = self.WORK
        LWORK = self.LWORK
        rv = fint()
        args = (__ptr__(M),
                __ptr__(N),
                A.as_byte(),
                __ptr__(LDA),
                TAU.as_byte(),
                WORK.as_byte(),
                __ptr__(LWORK),
                __ptr__(rv))
        if T is float:
            dgeqrf_(*args)
        elif T is float32:
            sgeqrf_(*args)
        elif T is complex:
            zgeqrf_(*args)
        elif T is complex64:
            cgeqrf_(*args)
        else:
            compile_error("[internal error] bad dtype for geqrf")
        return rv
    def release(self):
        # Free the [A | TAU] buffer and the workspace.
        free(self.A)
        free(self.WORK)
def _qr_r_raw(a):
    # Run ?geqrf on every matrix in the stack `a`, IN PLACE: each matrix is
    # overwritten with LAPACK's packed output (R in the upper triangle,
    # reflectors below). Returns the `tau` array of shape a.shape[:-2] + (k,)
    # with k = min(m, n). Raises LinAlgError if any call reports failure.
    m, n = _rows_cols(a)
    params = GeqrfParams[a.dtype](fint(m), fint(n))
    k = min(m, n)
    tau = empty(a.shape[:-2] + (k,), dtype=a.dtype)
    a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2])
    tau_out = LinearizeData(1, k, 1, tau.strides[-1])
    error_occured = False
    for idx in multirange(a.shape[:-2]):
        a_ptr = a._ptr(idx + (0, 0))
        tau_ptr = tau._ptr(idx + (0,))
        a_in.linearize(params.A, a_ptr)
        not_ok = params.call()
        if not not_ok:
            # write the factored matrix back into `a` (same LinearizeData)
            a_in.delinearize(a_ptr, params.A)
            tau_out.delinearize(tau_ptr, params.TAU)
        else:
            tau_out.nan_matrix(tau_ptr)
            error_occured = True
    params.release()
    if error_occured:
        raise LinAlgError("Incorrect argument found while performing "
                          "QR factorization")
    return tau
class GqrParams:
    # Argument block for LAPACK ?orgqr / ?ungqr, which forms an explicit Q
    # from the reflectors produced by ?geqrf. `call` dispatches on T to
    # dorgqr_/sorgqr_/zungqr_/cungqr_.
    M: fint       # rows of Q
    MC: fint      # columns of Q to generate (passed as LAPACK's N)
    MN: fint      # number of reflectors, min(m, n) (passed as LAPACK's K)
    A: Ptr[T]     # copy of the factored matrix (not passed to LAPACK by `call`)
    Q: Ptr[T]     # in/out Q buffer; start of the first allocation (freed in `release`)
    LDA: fint     # leading dimension of Q (M, or 1 when M is 0)
    TAU: Ptr[T]   # reflector scalars from ?geqrf
    WORK: Ptr[T]  # workspace; the second allocation (freed in `release`)
    LWORK: fint   # workspace length; -1 during the size query
    T: type
    def _init_real(self, m: fint, n: fint, mc: fint):
        # Real-dtype setup: allocate [Q (m*mc) | TAU (min(m, n)) | A (m*n)]
        # as one buffer, run a workspace-size query, then allocate WORK.
        min_m_n = m if m < n else n
        safe_mc = int(mc)
        safe_min_m_n = int(min_m_n)
        safe_m = int(m)
        safe_n = int(n)
        a_size = safe_m * safe_n * sizeof(T)
        q_size = safe_m * safe_mc * sizeof(T)
        tau_size = safe_min_m_n * sizeof(T)
        lda = m if m else fint(1)
        mem_buff = cobj(q_size + tau_size + a_size)
        q = mem_buff
        tau = q + q_size
        a = tau + tau_size
        self.M = m
        self.MC = mc
        self.MN = min_m_n
        self.A = Ptr[T](a)
        self.Q = Ptr[T](q)
        self.TAU = Ptr[T](tau)
        self.LDA = lda
        # work size query: with LWORK = -1 LAPACK only reports the optimal
        # workspace length, written into WORK[0] (`work_size_query`).
        work_size_query = T()
        self.WORK = __ptr__(work_size_query)
        self.LWORK = fint(-1)
        if self.call():
            free(mem_buff)
            raise LinAlgError("error in gqr work size query")
        work_count = cast(work_size_query, fint)
        # LWORK = max(n or 1, queried length)
        lw = n if n else fint(1)
        lw = lw if lw > work_count else work_count
        self.LWORK = lw
        work_size = int(lw) * sizeof(T)
        mem_buff2 = cobj(work_size)
        work = mem_buff2
        self.WORK = Ptr[T](work)
    def _init_complex(self, m: fint, n: fint, mc: fint):
        # Complex-dtype setup: identical to `_init_real` except that the
        # queried workspace length is read from the real part of WORK[0].
        min_m_n = m if m < n else n
        safe_mc = int(mc)
        safe_min_m_n = int(min_m_n)
        safe_m = int(m)
        safe_n = int(n)
        a_size = safe_m * safe_n * sizeof(T)
        q_size = safe_m * safe_mc * sizeof(T)
        tau_size = safe_min_m_n * sizeof(T)
        lda = m if m else fint(1)
        mem_buff = cobj(q_size + tau_size + a_size)
        q = mem_buff
        tau = q + q_size
        a = tau + tau_size
        self.M = m
        self.MC = mc
        self.MN = min_m_n
        self.A = Ptr[T](a)
        self.Q = Ptr[T](q)
        self.TAU = Ptr[T](tau)
        self.LDA = lda
        # work size query
        work_size_query = T()
        self.WORK = __ptr__(work_size_query)
        self.LWORK = fint(-1)
        if self.call():
            free(mem_buff)
            raise LinAlgError("error in gqr work size query")
        work_count = cast(work_size_query.real, fint)
        lw = n if n else fint(1)
        lw = lw if lw > work_count else work_count
        self.LWORK = lw
        work_size = int(lw) * sizeof(T)
        mem_buff2 = cobj(work_size)
        work = mem_buff2
        self.WORK = Ptr[T](work)
    def __init__(self, m: fint, n: fint, mc: fint):
        # Allocate buffers to generate an m x mc Q from the QR factors of an
        # m x n matrix. Raises LinAlgError if the workspace query fails.
        if T is complex or T is complex64:
            self._init_complex(m, n, mc)
        else:
            self._init_real(m, n, mc)
    def __init__(self, m: fint, n: fint):
        # Reduced Q: generate min(m, n) columns.
        self.__init__(m, n, m if m < n else n)
    def call(self):
        # Invoke the ?orgqr/?ungqr routine matching T on Q/TAU and return
        # LAPACK's INFO value (nonzero signals failure).
        M = self.M
        MC = self.MC
        MN = self.MN
        A = self.A
        Q = self.Q
        LDA = self.LDA
        TAU = self.TAU
        WORK = self.WORK
        LWORK = self.LWORK
        rv = fint()
        args = (__ptr__(M),
                __ptr__(MC),
                __ptr__(MN),
                Q.as_byte(),
                __ptr__(LDA),
                TAU.as_byte(),
                WORK.as_byte(),
                __ptr__(LWORK),
                __ptr__(rv))
        if T is float:
            dorgqr_(*args)
        elif T is float32:
            sorgqr_(*args)
        elif T is complex:
            zungqr_(*args)
        elif T is complex64:
            cungqr_(*args)
        else:
            # this class wraps ?orgqr/?ungqr, not ?geqrf
            compile_error("[internal error] bad dtype for gqr")
        return rv
    def release(self):
        # Free the [Q | TAU | A] buffer and the workspace.
        free(self.Q)
        free(self.WORK)
def _qr_reduced(a, tau):
    # Build the reduced Q factor (m x k, k = min(m, n)) for each matrix in
    # the stack from the ?geqrf output `a` and reflector scalars `tau`
    # (as produced by `_qr_r_raw`). Raises LinAlgError on LAPACK failure.
    m, n = _rows_cols(a)
    params = GqrParams[a.dtype](fint(m), fint(n))
    k = min(m, n)
    q = empty(a.shape[:-2] + (m, k), dtype=a.dtype)
    a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2])
    tau_in = LinearizeData(1, k, 1, tau.strides[-1])
    q_out = LinearizeData(k, m, q.strides[-1], q.strides[-2])
    error_occured = False
    for idx in multirange(a.shape[:-2]):
        a_ptr = a._ptr(idx + (0, 0))
        tau_ptr = tau._ptr(idx + (0,))
        q_ptr = q._ptr(idx + (0, 0))
        a_in.linearize(params.A, a_ptr)
        # ?orgqr works in place on Q, so it is seeded with the reflectors.
        # NOTE(review): when n > m the Q region holds only m*m elements
        # while `a_in` presumably writes m*n; the excess spills into the
        # TAU/A regions of the same allocation (still in bounds) and TAU is
        # re-filled on the next line — confirm this is intended.
        a_in.linearize(params.Q, a_ptr)
        tau_in.linearize(params.TAU, tau_ptr)
        not_ok = params.call()
        if not not_ok:
            q_out.delinearize(q_ptr, params.Q)
        else:
            q_out.nan_matrix(q_ptr)
            error_occured = True
    params.release()
    if error_occured:
        raise LinAlgError("Incorrect argument found while performing "
                          "QR factorization")
    return q
def _qr_complete(a, tau):
    # Build the complete square Q factor (m x m) for each matrix in the
    # stack from the ?geqrf output `a` and reflector scalars `tau`.
    # `qr` only uses this when m > n. Raises LinAlgError on LAPACK failure.
    m, n = _rows_cols(a)
    params = GqrParams[a.dtype](fint(m), fint(n), fint(m))
    k = min(m, n)
    q = empty(a.shape[:-2] + (m, m), dtype=a.dtype)
    a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2])
    tau_in = LinearizeData(1, k, 1, tau.strides[-1])
    q_out = LinearizeData(m, m, q.strides[-1], q.strides[-2])
    error_occured = False
    for idx in multirange(a.shape[:-2]):
        a_ptr = a._ptr(idx + (0, 0))
        tau_ptr = tau._ptr(idx + (0,))
        q_ptr = q._ptr(idx + (0, 0))
        a_in.linearize(params.A, a_ptr)
        # ?orgqr works in place on Q, so it is seeded with the reflectors
        a_in.linearize(params.Q, a_ptr)
        tau_in.linearize(params.TAU, tau_ptr)
        not_ok = params.call()
        if not not_ok:
            q_out.delinearize(q_ptr, params.Q)
        else:
            q_out.nan_matrix(q_ptr)
            error_occured = True
    params.release()
    if error_occured:
        raise LinAlgError("Incorrect argument found while performing "
                          "QR factorization")
    return q
@tuple
class QRResult[A]:
    # Named result of `qr` in 'reduced'/'complete' mode: orthonormal factor
    # `Q` and upper-triangular factor `R`, both of array type A.
    Q: A
    R: A
    def __getitem__(self, idx: Static[int]):
        # Tuple-style access with a compile-time index: 0/-2 -> Q,
        # 1/-1 -> R; any other index is a compile error.
        if idx == 0 or idx == -2:
            return self.Q
        elif idx == 1 or idx == -1:
            return self.R
        else:
            compile_error("tuple ('QRResult') index out of range")
def _triu(a):
    # Zero, in place, every element strictly below the main diagonal of each
    # trailing 2-D matrix of `a` (works on strided views via byte strides).
    rows, cols = a.shape[-2], a.shape[-1]
    row_stride, col_stride = a.strides[-2], a.strides[-1]
    for base_idx in multirange(a.shape[:-2]):
        base = a._ptr(base_idx + (0, 0)).as_byte()
        for i in range(1, rows):
            row = base + i * row_stride
            for j in range(min(i, cols)):
                Ptr[a.dtype](row + j * col_stride)[0] = zero(a.dtype)
def qr(a, mode: Static[str] = 'reduced'):
    # QR factorization of the trailing two dims of `a`.
    #   'reduced'  -> QRResult(Q (m x k), R (k x n)),  k = min(m, n)
    #   'complete' -> QRResult(Q (m x m), R (m x n))
    #   'r'        -> R only (k x n)
    #   'raw'      -> (h, tau): the transposed ?geqrf output and its
    #                 reflector scalars
    # Any other mode is a compile-time error. Raises LinAlgError on LAPACK
    # failure.
    if (mode != 'reduced' and mode != 'complete' and
        mode != 'r' and mode != 'raw'):
        compile_error("Unrecognized mode '" + mode + "'")
    a0 = a
    a = _asarray(a)
    # copy if _asarray() didn't do so already (the factorization below
    # overwrites `a` in place)
    if isinstance(a0, ndarray):
        if a0.dtype is a.dtype:
            a = a.copy()
    tau = _qr_r_raw(a)
    m, n = _rows_cols(a)
    mn = min(m, n)
    if mode == 'r':
        r = a[..., :mn, :]
        _triu(r)
        return r
    if mode == 'raw':
        q = a.T
        return q, tau
    if mode == 'complete' and m > n:
        mc = m
        q = _qr_complete(a, tau)
    else:
        mc = mn
        q = _qr_reduced(a, tau)
    # Q has been built from the reflectors stored below the diagonal; now
    # clear them from the returned R. Triangularize `r` itself (as in the
    # 'r' branch) so the result does not depend on slicing returning a view
    # and rows outside R are left alone.
    r = a[..., :mc, :]
    _triu(r)
    return QRResult(q, r)
#################
# Least Squares #
#################
class GelsdParams:
    # Argument block for LAPACK ?gelsd (least squares via SVD). `call`
    # dispatches on T to dgelsd_/sgelsd_ (real argument list) or
    # zgelsd_/cgelsd_ (complex list, which adds RWORK). B is the real base
    # type of T, used for the singular values and RCOND.
    M: fint           # rows of A
    N: fint           # columns of A
    NRHS: fint        # number of right-hand-side columns
    A: Ptr[T]         # copy of A; start of the first allocation (freed in `release`)
    LDA: fint         # leading dimension of A (M, or 1 when M is 0)
    B_: Ptr[T]        # right-hand sides on input; solution (and residual rows) on output
    LDB: fint         # leading dimension of B_ (max(M, N), or 1 when that is 0)
    S: Ptr[B]         # singular values output
    RCOND: Ptr[B]     # set by the caller before `call`; not assigned before the size query
    RANK: fint        # effective rank, copied back from LAPACK by `call`
    WORK: Ptr[T]      # workspace; start of the second allocation (freed in `release`)
    LWORK: fint       # workspace length; -1 during the size query
    RWORK: Ptr[B]     # real workspace (complex dtypes only; null for real dtypes)
    IWORK: Ptr[fint]  # integer workspace
    T: type
    B: type
    def _init_real(self, m: fint, n: fint, nrhs: fint):
        # Real-dtype setup: allocate [A (m*n) | B_ (max(m,n)*nrhs) | S] as one
        # buffer (at least 1 byte), query workspace sizes, then allocate
        # [WORK | IWORK] as a second buffer.
        # NOTE(review): RCOND is still unset when the size query calls
        # LAPACK; presumably ?gelsd ignores it when LWORK = -1 — confirm.
        min_m_n = m if m < n else n
        max_m_n = m if m > n else n
        safe_min_m_n = int(min_m_n)
        safe_max_m_n = int(max_m_n)
        safe_m = int(m)
        safe_n = int(n)
        safe_nrhs = int(nrhs)
        a_size = safe_m * safe_n * sizeof(T)
        b_size = safe_max_m_n * safe_nrhs * sizeof(T)
        s_size = safe_min_m_n * sizeof(T)
        lda = m if m else fint(1)
        ldb = max_m_n if max_m_n else fint(1)
        msize = a_size + b_size + s_size
        mem_buff = cobj(msize if msize else 1)
        a = mem_buff
        b = a + a_size
        s = b + b_size
        self.M = m
        self.N = n
        self.NRHS = nrhs
        self.A = Ptr[T](a)
        self.B_ = Ptr[T](b)
        self.S = Ptr[B](s)
        self.LDA = lda
        self.LDB = ldb
        # work size query: with LWORK = -1 LAPACK reports the WORK length in
        # WORK[0] and the IWORK length in IWORK[0]
        work_size_query = T()
        iwork_size_query = fint()
        self.WORK = __ptr__(work_size_query)
        self.IWORK = __ptr__(iwork_size_query)
        self.RWORK = Ptr[B]()
        self.LWORK = fint(-1)
        if self.call():
            free(mem_buff)
            raise LinAlgError("error in gelsd work size query")
        work_count = cast(work_size_query, fint)
        work_size = int(work_count) * sizeof(T)
        iwork_size = int(iwork_size_query) * sizeof(fint)
        mem_buff2 = cobj(work_size + iwork_size)
        work = mem_buff2
        iwork = work + work_size
        self.WORK = Ptr[T](work)
        self.RWORK = Ptr[B]()
        self.IWORK = Ptr[fint](iwork)
        self.LWORK = work_count
    def _init_complex(self, m: fint, n: fint, nrhs: fint):
        # Complex-dtype setup: like `_init_real`, but also queries RWORK and
        # allocates [WORK | RWORK | IWORK]. The WORK length is read from the
        # real part of WORK[0]. (S is sized with sizeof(T) although it holds
        # B elements, which over-allocates but is harmless.)
        min_m_n = m if m < n else n
        max_m_n = m if m > n else n
        safe_min_m_n = int(min_m_n)
        safe_max_m_n = int(max_m_n)
        safe_m = int(m)
        safe_n = int(n)
        safe_nrhs = int(nrhs)
        a_size = safe_m * safe_n * sizeof(T)
        b_size = safe_max_m_n * safe_nrhs * sizeof(T)
        s_size = safe_min_m_n * sizeof(T)
        lda = m if m else fint(1)
        ldb = max_m_n if max_m_n else fint(1)
        msize = a_size + b_size + s_size
        mem_buff = cobj(msize if msize else 1)
        a = mem_buff
        b = a + a_size
        s = b + b_size
        self.M = m
        self.N = n
        self.NRHS = nrhs
        self.A = Ptr[T](a)
        self.B_ = Ptr[T](b)
        self.S = Ptr[B](s)
        self.LDA = lda
        self.LDB = ldb
        # work size query
        work_size_query = T()
        rwork_size_query = B()
        iwork_size_query = fint()
        self.WORK = __ptr__(work_size_query)
        self.IWORK = __ptr__(iwork_size_query)
        self.RWORK = __ptr__(rwork_size_query)
        self.LWORK = fint(-1)
        if self.call():
            free(mem_buff)
            raise LinAlgError("error in gelsd work size query")
        work_count = cast(work_size_query.real, fint)
        work_size = int(work_count) * sizeof(T)
        rwork_size = cast(rwork_size_query, int) * sizeof(B)
        iwork_size = int(iwork_size_query) * sizeof(fint)
        mem_buff2 = cobj(work_size + rwork_size + iwork_size)
        work = mem_buff2
        rwork = work + work_size
        iwork = rwork + rwork_size
        self.WORK = Ptr[T](work)
        self.RWORK = Ptr[B](rwork)
        self.IWORK = Ptr[fint](iwork)
        self.LWORK = work_count
    def __init__(self, m: fint, n: fint, nrhs: fint):
        # Allocate buffers for an m x n system with nrhs right-hand sides.
        # Raises LinAlgError if the workspace query fails.
        if T is complex or T is complex64:
            self._init_complex(m, n, nrhs)
        else:
            self._init_real(m, n, nrhs)
    def call(self):
        # Invoke the ?gelsd routine matching T, store the reported rank in
        # self.RANK, and return LAPACK's INFO value (nonzero = failure).
        M = self.M
        N = self.N
        NRHS = self.NRHS
        A = self.A
        LDA = self.LDA
        B = self.B_
        LDB = self.LDB
        S = self.S
        RCOND = self.RCOND
        RANK = self.RANK
        WORK = self.WORK
        LWORK = self.LWORK
        RWORK = self.RWORK
        IWORK = self.IWORK
        rv = fint()
        args_real = (__ptr__(M),
                     __ptr__(N),
                     __ptr__(NRHS),
                     A.as_byte(),
                     __ptr__(LDA),
                     B.as_byte(),
                     __ptr__(LDB),
                     S.as_byte(),
                     RCOND.as_byte(),
                     __ptr__(RANK),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     IWORK,
                     __ptr__(rv))
        args_cplx = (__ptr__(M),
                     __ptr__(N),
                     __ptr__(NRHS),
                     A.as_byte(),
                     __ptr__(LDA),
                     B.as_byte(),
                     __ptr__(LDB),
                     S.as_byte(),
                     RCOND.as_byte(),
                     __ptr__(RANK),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     RWORK.as_byte(),
                     IWORK,
                     __ptr__(rv))
        if T is float:
            dgelsd_(*args_real)
        elif T is float32:
            sgelsd_(*args_real)
        elif T is complex:
            zgelsd_(*args_cplx)
        elif T is complex64:
            cgelsd_(*args_cplx)
        else:
            compile_error("[internal error] bad dtype for gelsd")
        # LAPACK wrote the rank into the local copy; publish it
        self.RANK = RANK
        return rv
    def release(self):
        # Free the [A | B_ | S] buffer and the workspace buffer.
        free(self.A)
        free(self.WORK)
def _abs2(p: Ptr[T], n: int, T: type):
    # Sum of squared magnitudes of the `n` consecutive elements at `p`,
    # returned in the real base type of T (|z|^2 = re^2 + im^2 for complex).
    B = type(_basetype(T))
    acc = B()
    if T is complex or T is complex64:
        for i in range(n):
            z = p[i]
            acc += z.real*z.real + z.imag*z.imag
    else:
        for i in range(n):
            v = p[i]
            acc += v * v
    return acc
def lstsq(a, b, rcond = None):
    # Least-squares solution of a @ x = b via LAPACK ?gelsd.
    #   a: 2-D (m x n); b: 1-D (m,) or 2-D (m x nrhs); other ranks are
    #      compile errors
    #   rcond: singular-value cutoff ratio; None -> eps(B) * max(m, n)
    # Returns (x, residuals, rank, s). Residuals are the squared norms of
    # the excess rows only when m > n and rank == n, otherwise an empty
    # array.
    # Raises LinAlgError on a row mismatch or when LAPACK reports failure.
    dtype = type(
        coerce(
            type(asarray(a).data[0]),
            type(asarray(b).data[0])))
    a = _asarray(a, dtype=dtype)
    b = _asarray(b, dtype=dtype)
    if a.ndim != 2:
        compile_error("lstsq argument 'a' must be 2-dimensional")
    if b.ndim == 1:
        b2 = b.reshape(b.shape[0], 1)
    elif b.ndim == 2:
        b2 = b
    else:
        compile_error("lstsq argument 'b' must be 1- or 2-dimensional")
    B = type(_basetype(a.dtype))
    m, n = a.shape
    m2, nrhs = b2.shape
    if m != m2:
        raise LinAlgError('Incompatible dimensions')
    if rcond is None:
        rc = eps(B) * cast(max(m, n), B)
    else:
        rc = cast(rcond, B)
    # With zero right-hand sides, solve against a single zero column so
    # LAPACK still gets a valid problem; the outputs are trimmed below.
    if nrhs == 0:
        b2 = ndarray((m, 1), Ptr[b2.dtype](m))
        str.memset(b2.data.as_byte(), byte(0), b2.nbytes)
        safe_nrhs = 1
    else:
        safe_nrhs = nrhs
    excess = m - n
    params = GelsdParams[a.dtype, B](fint(m), fint(n), fint(safe_nrhs))
    x = empty((n, safe_nrhs), dtype=a.dtype)
    r = empty(safe_nrhs, dtype=B)
    s = empty(min(m, n), dtype=B)
    a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2])
    # B_ has leading dimension max(m, n), hence the extra argument here
    b_in = LinearizeData(safe_nrhs, m, b2.strides[-1], b2.strides[-2], max(m, n))
    x_out = LinearizeData(safe_nrhs, n, x.strides[-1], x.strides[-2], max(m, n))
    r_out = LinearizeData(1, safe_nrhs, 1, r.strides[-1])
    s_out = LinearizeData(1, min(m, n), 1, s.strides[-1])
    a_ptr = a.data
    b_ptr = b2.data
    x_ptr = x.data
    r_ptr = r.data
    s_ptr = s.data
    a_in.linearize(params.A, a_ptr)
    b_in.linearize(params.B_, b_ptr)
    params.RCOND = __ptr__(rc)
    not_ok = params.call()
    error_occured = False
    if not not_ok:
        x_out.delinearize(x_ptr, params.B_)
        rank = int(params.RANK)
        s_out.delinearize(s_ptr, params.S)
        if excess >= 0 and int(params.RANK) == n:
            # Rows n..m-1 of each solved column hold the residual
            # components; columns are m apart (LDB == m when m >= n).
            components = params.B_ + n
            for i in range(safe_nrhs):
                vector = components + i*m
                r_ptr[i] = _abs2(vector, excess)
        else:
            r_out.nan_matrix(r_ptr)
    else:
        x_out.nan_matrix(x_ptr)
        r_out.nan_matrix(r_ptr)
        rank = -1
        s_out.nan_matrix(s_ptr)
        error_occured = True
    params.release()
    if error_occured:
        raise LinAlgError("SVD did not converge in Linear Least Squares")
    if m == 0:
        str.memset(x.data.as_byte(), byte(0), x.nbytes)
    if nrhs == 0:
        # drop the placeholder right-hand-side column again
        free(x.data)
        free(r.data)
        x = ndarray((x.shape[0], 0), Ptr[x.dtype]())
        r = ndarray((0,), Ptr[r.dtype]())
    if b.ndim == 1:
        x1 = x.reshape(x.shape[0])
    else:
        x1 = x
    # residuals are only meaningful for full-rank overdetermined systems
    if r.size and (rank != n or m <= n):
        free(r.data)
        r = ndarray((0,), Ptr[r.dtype]())
    return x1, r, rank, s
###################
# Matrix Multiply #
###################
def _dot_noblas(_ip1: Ptr[T1], is1: int, _ip2: Ptr[T2], is2: int, op: Ptr[T3], n: int,
                T1: type, T2: type, T3: type):
    # Portable dot product: op[0] = sum_i cast(x_i, T3) * cast(y_i, T3),
    # where x_i / y_i sit at byte offsets i*is1 / i*is2 from the inputs.
    base1 = _ip1.as_byte()
    base2 = _ip2.as_byte()
    total = zero(T3)
    for step in range(n):
        lhs = Ptr[T1](base1 + step * is1)[0]
        rhs = Ptr[T2](base2 + step * is2)[0]
        total += cast(lhs, T3) * cast(rhs, T3)
    op[0] = total
# Largest element increment `_blas_stride` will hand to CBLAS: the 32-bit
# signed int maximum (2^31 - 1). Presumably CBLAS_INT is 32-bit here — confirm.
_CBLAS_INT_MAX: Static[int] = 0x7fffffff
def _blas_stride(stride: int, itemsize: int):
    # Translate a byte stride into a BLAS element increment. Returns 0
    # ("not BLAS-usable") unless the stride is positive, a whole number of
    # elements, and at most _CBLAS_INT_MAX elements.
    if stride <= 0 or stride % itemsize != 0:
        return 0
    elems = stride // itemsize
    if elems > _CBLAS_INT_MAX:
        return 0
    return elems
def _dot(ip1: Ptr[T1], is1: int, ip2: Ptr[T2], is2: int, op: Ptr[T3], n: int,
         T1: type, T2: type, T3: type):
    # Non-conjugating dot product of two strided vectors (byte strides
    # is1/is2), written to op[0]. Uses CBLAS ?dot / ?dotu_sub when T1, T2
    # and T3 are the same BLAS type and both strides are BLAS-usable;
    # otherwise falls back to `_dot_noblas`. Non-positive strides make
    # `_blas_stride` return 0 and so always take the fallback.
    ib1 = _blas_stride(is1, sizeof(T1))
    ib2 = _blas_stride(is2, sizeof(T2))
    args = (fint(n), ip1.as_byte(), fint(ib1), ip2.as_byte(), fint(ib2))
    if ib1 == 0 or ib2 == 0:
        _dot_noblas(ip1, is1, ip2, is2, op, n)
    else:
        if T1 is float and T2 is float and T3 is float:
            op[0] = cblas_ddot(*args)
        elif T1 is float32 and T2 is float32 and T3 is float32:
            op[0] = cblas_sdot(*args)
        elif T1 is complex and T2 is complex and T3 is complex:
            # complex results are returned through the output pointer
            cblas_zdotu_sub(*args, op)
        elif T1 is complex64 and T2 is complex64 and T3 is complex64:
            cblas_cdotu_sub(*args, op)
        else:
            _dot_noblas(ip1, is1, ip2, is2, op, n)
# Largest dimension / leading dimension passed to BLAS: INT_MAX - 1 (2^31 - 2).
BLAS_MAXSIZE: Static[int] = 2147483646 # 2^31 - 2 (INT_MAX - 1)
def _is_blasable2d(byte_stride1: int,
                   byte_stride2: int,
                   d1: int,
                   d2: int,
                   itemsize: int):
    # True when a d1 x d2 matrix with these byte strides can be described to
    # BLAS as row-major: unit stride along the second axis, and a first-axis
    # stride that is a whole number of elements in [d2, BLAS_MAXSIZE].
    # (`d1` is not consulted.)
    if byte_stride2 != itemsize:
        return False
    if byte_stride1 % itemsize != 0:
        return False
    lead = byte_stride1 // itemsize
    return lead >= d2 and lead <= BLAS_MAXSIZE
def _gemv(ip1: Ptr[T], is1_m: int, is1_n: int,
          ip2: Ptr[T], is2_n: int,
          op: Ptr[T], op_m: int,
          m: int, n: int, T: type):
    # Matrix-vector product via CBLAS ?gemv: op = A @ x with alpha = 1 and
    # beta = 0 (op is overwritten). A is m x n with byte strides
    # (is1_m, is1_n); x has byte stride is2_n; op has byte stride op_m.
    # Only BLAS dtypes (float, float32, complex, complex64) compile.
    itemsize = sizeof(T)
    if _is_blasable2d(is1_m, is1_n, m, n, itemsize):
        # Row-contiguous m x n is column-major n x m; CBLAS_TRANS below then
        # yields the m x n operator.
        order = CBLAS_COL_MAJOR
        lda = is1_m // itemsize
    else:
        # assumes A is contiguous along its first axis instead (caller's
        # responsibility — not checked here)
        order = CBLAS_ROW_MAJOR
        lda = is1_n // itemsize
    one_dtype = cast(1, T)
    zero_dtype = cast(0, T)
    # real BLAS takes scalars by value, complex BLAS by pointer
    if T is float or T is float32:
        one = one_dtype
        zero = zero_dtype
    else:
        one = __ptr__(one_dtype).as_byte()
        zero = __ptr__(zero_dtype).as_byte()
    args = (order, CBLAS_TRANS, fint(n), fint(m),
            one, ip1.as_byte(), fint(lda), ip2.as_byte(),
            fint(is2_n // itemsize), zero, op.as_byte(),
            fint(op_m // itemsize))
    if T is float:
        cblas_dgemv(*args)
    elif T is float32:
        cblas_sgemv(*args)
    elif T is complex:
        cblas_zgemv(*args)
    elif T is complex64:
        cblas_cgemv(*args)
    else:
        compile_error("[internal error] bad input type")
def _matmul_matrixmatrix(ip1: Ptr[T], is1_m: int, is1_n: int,
                         ip2: Ptr[T], is2_n: int, is2_p: int,
                         op: Ptr[T], os_m: int, os_p: int,
                         m: int, n: int, p: int, T: type):
    # Matrix-matrix product via CBLAS: C (m x p) = A (m x n) @ B (n x p),
    # alpha = 1, beta = 0, all strides in bytes. Uses ?syrk plus a mirror
    # copy when B is exactly A's transpose, otherwise ?gemm. Each input is
    # passed untransposed if row-contiguous, else transposed.
    # Assumes C is row-contiguous (os_p == itemsize); only os_m is used here.
    order = CBLAS_ROW_MAJOR
    itemsize = sizeof(T)
    ldc = os_m // itemsize
    if _is_blasable2d(is1_m, is1_n, m, n, itemsize):
        trans1 = CBLAS_NO_TRANS
        lda = is1_m // itemsize
    else:
        trans1 = CBLAS_TRANS
        lda = is1_n // itemsize
    if _is_blasable2d(is2_n, is2_p, n, p, itemsize):
        trans2 = CBLAS_NO_TRANS
        ldb = is2_n // itemsize
    else:
        trans2 = CBLAS_TRANS
        ldb = is2_p // itemsize
    one_dtype = cast(1, T)
    zero_dtype = cast(0, T)
    # real BLAS takes scalars by value, complex BLAS by pointer
    if T is float or T is float32:
        one = one_dtype
        zero = zero_dtype
    else:
        one = __ptr__(one_dtype).as_byte()
        zero = __ptr__(zero_dtype).as_byte()
    # Same buffer with swapped strides means B == A^T, so C = A @ A^T is
    # symmetric and ?syrk can compute just the upper triangle.
    if (ip1 == ip2 and
        m == p and
        is1_m == is2_p and
        is1_n == is2_n and
        trans1 != trans2):
        ld = lda if trans1 == CBLAS_NO_TRANS else ldb
        args = (order, CBLAS_UPPER, trans1, fint(p),
                fint(n), one, ip1, fint(ld),
                zero, op, fint(ldc))
        if T is float:
            cblas_dsyrk(*args)
        elif T is float32:
            cblas_ssyrk(*args)
        elif T is complex:
            cblas_zsyrk(*args)
        elif T is complex64:
            cblas_csyrk(*args)
        else:
            compile_error("[internal error] bad input type")
        # Copy the triangle
        # (upper -> lower, so the full symmetric result is stored)
        for i in range(p):
            for j in range(i + 1, p):
                op[j * ldc + i] = op[i * ldc + j]
    else:
        args = (order, trans1, trans2, fint(m), fint(p), fint(n),
                one, ip1, fint(lda), ip2,
                fint(ldb), zero, op, fint(ldc))
        if T is float:
            cblas_dgemm(*args)
        elif T is float32:
            cblas_sgemm(*args)
        elif T is complex:
            cblas_zgemm(*args)
        elif T is complex64:
            cblas_cgemm(*args)
        else:
            compile_error("[internal error] bad input type")
def _matmul_inner_noblas(_ip1: Ptr[T1], is1_m: int, is1_n: int,
                         _ip2: Ptr[T2], is2_n: int, is2_p: int,
                         _op: Ptr[T3], os_m: int, os_p: int,
                         dm: int, dn: int, dp: int, T1: type,
                         T2: type, T3: type):
    # Portable triple-loop matmul: C (dm x dp) from A (dm x dn) and
    # B (dn x dp), all strides in bytes, walking raw byte pointers.
    # For bool output each C element is set to "any(A[i,k] and B[k,j])".
    # For other types products are ADDED into C without zeroing it first,
    # so the caller must pass a zero-initialized C.
    # Total byte spans used to rewind the pointers after each inner loop.
    ib1_n = is1_n * dn
    ib2_n = is2_n * dn
    ib2_p = is2_p * dp
    ob_p = os_p * dp
    ip1 = _ip1.as_byte()
    ip2 = _ip2.as_byte()
    op = _op.as_byte()
    if T3 is bool:
        for m in range(dm):
            for p in range(dp):
                ip1tmp = ip1
                ip2tmp = ip2
                Ptr[T3](op)[0] = False
                for n in range(dn):
                    val1 = Ptr[T1](ip1tmp)[0]
                    val2 = Ptr[T2](ip2tmp)[0]
                    if val1 and val2:
                        Ptr[T3](op)[0] = True
                        break  # result already known
                    ip2tmp += is2_n
                    ip1tmp += is1_n
                op += os_p
                ip2 += is2_p
            op -= ob_p
            ip2 -= ib2_p
            ip1 += is1_m
            op += os_m
    else:
        for m in range(dm):
            for p in range(dp):
                for n in range(dn):
                    val1 = cast(Ptr[T1](ip1)[0], T3)
                    val2 = cast(Ptr[T2](ip2)[0], T3)
                    Ptr[T3](op)[0] += val1 * val2
                    ip2 += is2_n
                    ip1 += is1_n
                # rewind along the contracted axis, advance to next column
                ip1 -= ib1_n
                ip2 -= ib2_n
                op += os_p
                ip2 += is2_p
            # rewind to column 0, advance to next row
            op -= ob_p
            ip2 -= ib2_p
            ip1 += is1_m
            op += os_m
def _matmul(A, B, C):
    # Stacked matmul kernel: for each index of C's leading dims, compute
    # C[idx] (m x p) from A[idx] (m x n) and B[idx] (n x p), picking among
    # dot / gemv / gemm(syrk) BLAS paths and the portable loop based on the
    # dtypes, sizes and byte strides (the same decision tree as NumPy's
    # matmul inner loop). The portable loop accumulates, so C must already
    # be zeroed by the caller.
    # Caller ensures arrays are broadcasted and of correct shape.
    # A -> (m, n)
    # B -> (n, p)
    # C -> (m, p)
    dm = A.shape[-2]
    dn = A.shape[-1]
    dp = B.shape[-1]
    is1_m = A.strides[-2]
    is1_n = A.strides[-1]
    is2_n = B.strides[-2]
    is2_p = B.strides[-1]
    os_m = C.strides[-2]
    os_p = C.strides[-1]
    sz = sizeof(A.dtype)
    special_case = (dm == 1 or dn == 1 or dp == 1)
    any_zero_dim = (dm == 0 or dn == 0 or dp == 0)
    scalar_out = (dm == 1 and dp == 1)
    scalar_vec = (dn == 1 and (dp == 1 or dm == 1))
    too_big_for_blas = (dm > BLAS_MAXSIZE or dn > BLAS_MAXSIZE or dp > BLAS_MAXSIZE)
    # *_c_blasable: row-contiguous; *_f_blasable: column-contiguous
    i1_c_blasable = _is_blasable2d(is1_m, is1_n, dm, dn, sz)
    i2_c_blasable = _is_blasable2d(is2_n, is2_p, dn, dp, sz)
    i1_f_blasable = _is_blasable2d(is1_n, is1_m, dn, dm, sz)
    i2_f_blasable = _is_blasable2d(is2_p, is2_n, dp, dn, sz)
    i1blasable = i1_c_blasable or i1_f_blasable
    i2blasable = i2_c_blasable or i2_f_blasable
    o_c_blasable = _is_blasable2d(os_m, os_p, dm, dp, sz)
    o_f_blasable = _is_blasable2d(os_p, os_m, dp, dm, sz)
    vector_matrix = (dm == 1 and i2blasable and
                     _is_blasable2d(is1_n, sz, dn, 1, sz))
    matrix_vector = (dp == 1 and i1blasable and
                     _is_blasable2d(is2_n, sz, dn, 1, sz))
    for idx in multirange(C.shape[:-2]):
        ip1 = A._ptr(idx + (0, 0), broadcast=(A.ndim > 2))
        ip2 = B._ptr(idx + (0, 0), broadcast=(B.ndim > 2))
        op = C._ptr(idx + (0, 0))
        # BLAS only applies when all three dtypes are one BLAS type
        if not ((A.dtype is B.dtype and B.dtype is C.dtype) and
                (A.dtype is float or A.dtype is float32 or
                 A.dtype is complex or A.dtype is complex64)):
            _matmul_inner_noblas(ip1, is1_m, is1_n,
                                 ip2, is2_n, is2_p,
                                 op, os_m, os_p,
                                 dm, dn, dp)
        else:
            if too_big_for_blas or any_zero_dim:
                _matmul_inner_noblas(ip1, is1_m, is1_n,
                                     ip2, is2_n, is2_p,
                                     op, os_m, os_p,
                                     dm, dn, dp)
            elif special_case:
                if scalar_out:
                    # (1 x n) @ (n x 1): plain dot product
                    _dot(ip1, is1_n, ip2, is2_n, op, dn)
                elif scalar_vec:
                    _matmul_inner_noblas(ip1, is1_m, is1_n,
                                         ip2, is2_n, is2_p,
                                         op, os_m, os_p,
                                         dm, dn, dp)
                elif vector_matrix:
                    # x @ B computed as B^T @ x (operands and m/p swapped)
                    _gemv(ip2, is2_p, is2_n, ip1, is1_n, op, os_p, dp, dn)
                elif matrix_vector:
                    _gemv(ip1, is1_m, is1_n, ip2, is2_n, op, os_m, dm, dn)
                else:
                    _matmul_inner_noblas(ip1, is1_m, is1_n,
                                         ip2, is2_n, is2_p,
                                         op, os_m, os_p,
                                         dm, dn, dp)
            else:
                if i1blasable and i2blasable and o_c_blasable:
                    _matmul_matrixmatrix(ip1, is1_m, is1_n,
                                         ip2, is2_n, is2_p,
                                         op, os_m, os_p,
                                         dm, dn, dp)
                elif i1blasable and i2blasable and o_f_blasable:
                    # column-contiguous C: compute C^T = B^T @ A^T instead
                    _matmul_matrixmatrix(ip2, is2_p, is2_n,
                                         ip1, is1_n, is1_m,
                                         op, os_p, os_m,
                                         dp, dn, dm)
                else:
                    _matmul_inner_noblas(ip1, is1_m, is1_n,
                                         ip2, is2_n, is2_p,
                                         op, os_m, os_p,
                                         dm, dn, dp)
def _check_out(out, ans_shape, dtype: type):
    # Validate a user-supplied `out` array against the expected result:
    # dtype and number of dimensions are checked at compile time, the
    # shape at run time (ValueError on mismatch).
    if out.dtype is not dtype:
        compile_error("'out' array has incorrect type '" + out.dtype.__name__ + "'")
    if staticlen(out.shape) != staticlen(ans_shape):
        compile_error("'out' array has incorrect number of dimensions")
    if out.shape != ans_shape:
        raise ValueError("'out' array has incorrect shape")
def _matmul_ufunc(x1, x2, out = None, dtype: type = NoneType):
    # Implementation of `matmul` (x1 @ x2) with NumPy broadcasting rules:
    # 1-d operands are promoted to a row (x1) or column (x2) vector and the
    # added axis is dropped from the result; leading dims are broadcast.
    # Non-ndarray inputs are converted with `asarray`; dtype defaults to
    # the coerced type of the operands. 0-d operands are compile errors;
    # mismatched core dims raise ValueError. Returns `out` when given, a
    # scalar for a 0-d result, else a new array.
    if not isinstance(x1, ndarray) or not isinstance(x2, ndarray):
        y1 = asarray(x1)
        y2 = asarray(x2)
        return _matmul_ufunc(y1, y2, out=out, dtype=dtype)
    T1 = x1.dtype
    T2 = x2.dtype
    x1d: Static[int] = x1.ndim
    x2d: Static[int] = x2.ndim
    if dtype is NoneType:
        return _matmul_ufunc(x1, x2, out=out, dtype=type(coerce(T1, T2)))
    if x1d == 0:
        compile_error("first argument is 0-dimensional; must be at least 1-d")
    if x2d == 0:
        compile_error("second argument is 0-dimensional; must be at least 1-d")
    # vector . vector -> scalar
    # NOTE(review): `out` is ignored on this path.
    if x1d == 1 and x2d == 1:
        if x1.shape != x2.shape:
            raise ValueError("matmul: vectors have different lengths")
        op = zero(dtype)
        _dot(x1.data, x1.strides[0], x2.data, x2.strides[0], __ptr__(op), x1.size)
        return op
    if x1d == 1:
        y1 = x1.reshape((1, -1))
    else:
        y1 = x1
    if x2d == 1:
        y2 = x2.reshape((-1, 1))
    else:
        y2 = x2
    y1s = y1.shape
    y2s = y2.shape
    y1d: Static[int] = y1.ndim
    y2d: Static[int] = y2.ndim
    base1s = y1s[:-2]
    base2s = y2s[:-2]
    mat1s = y1s[-2:]
    mat2s = y2s[-2:]
    m = mat1s[0]
    k = mat1s[1]
    n = mat2s[1]
    if k != mat2s[0]:
        raise ValueError("matmul: last dimension of first argument does not "
                         "match second-to-last dimension of second argument")
    ans_base = broadcast_shapes(base1s, base2s)
    # (the `x1d == 1 and x2d == 1` branches below are unreachable because
    # of the early return above)
    if x1d == 1 and x2d == 1:
        ans_shape = ans_base
    elif x1d == 1:
        ans_shape = ans_base + (mat2s[1],)
    elif x2d == 1:
        ans_shape = ans_base + (mat1s[0],)
    else:
        ans_shape = ans_base + (mat1s[0], mat2s[1])
    # `_matmul` may accumulate, so the output must start zeroed
    if out is None:
        ans = zeros(ans_shape, dtype=dtype)
    elif isinstance(out, ndarray):
        _check_out(out, ans_shape, dtype)
        ans = out
        ans.map(lambda _: cast(0, ans.dtype), inplace=True)
    else:
        compile_error("'out' must be an ndarray")
    # NOTE(review): in the two 1-d branches the promoted operand gets
    # (x?d - 1) leading ones, i.e. one more dim than C, and for x1d == 1 C's
    # trailing dims are (n, 1) rather than (1, n). This appears to rely on
    # `_ptr(..., broadcast=True)` and on C being contiguous — confirm.
    if x1d == 1 and x2d == 1:
        _matmul(y1, y2, ans.reshape(1, 1))
    elif x1d == 1:
        _matmul(y1.reshape((1,) * (x2d - 1) + y1.shape), y2, ans.reshape(ans.shape + (1,)))
    elif x2d == 1:
        _matmul(y1, y2.reshape((1,) * (x1d - 1) + y2.shape), ans.reshape(ans.shape + (1,)))
    elif x1d > x2d:
        _matmul(y1, y2.reshape((1,) * (x1d - x2d) + y2.shape), ans)
    elif x1d < x2d:
        _matmul(y1.reshape((1,) * (x2d - x1d) + y1.shape), y2, ans)
    else:
        _matmul(y1, y2, ans)
    if out is not None:
        return out
    else:
        if ans.ndim == 0:
            return ans.data[0]
        else:
            return ans
def matmul(x1, x2):
    # Public matrix product x1 @ x2: forwards to the shared kernel with no
    # `out` array and the result dtype coerced from the operands.
    return _matmul_ufunc(x1, x2, out=None, dtype=NoneType)
def dot(a, b, out = None):
    # Dot product with NumPy `dot` semantics:
    #   * either operand 0-d      -> elementwise `multiply`
    #   * both operands <= 2-d    -> delegated to `_matmul_ufunc`
    #   * otherwise               -> sum product over the last axis of `a`
    #     and the second-to-last axis of `b` (or the only axis of a 1-d `b`):
    #       dot(a, b)[i..., j..., l] = sum_k a[i..., k] * b[j..., k, l]
    # `out`, if given, must be an ndarray with exactly the result's shape
    # and dtype; it is overwritten (not added to) and returned.
    # Raises ValueError if the contracted dimensions differ.
    x1 = asarray(a)
    x2 = asarray(b)
    T1 = x1.dtype
    T2 = x2.dtype
    x1d: Static[int] = staticlen(x1.shape)
    x2d: Static[int] = staticlen(x2.shape)
    if x1d == 0 or x2d == 0:
        return multiply(a, b, out=out)
    if x1d <= 2 and x2d <= 2:
        return _matmul_ufunc(a, b, out=out)
    # most general case (one operand has more than 2 dims). `a` may be 1-d
    # here, so only its last axis is inspected.
    x1s = x1.shape
    x2s = x2.shape
    k = x1s[-1]
    if x2d == 1:
        if k != x2s[0]:
            raise ValueError("dot: last dimension of first argument does not "
                             "match length of second argument")
    else:
        if k != x2s[-2]:
            raise ValueError("dot: last dimension of first argument does not "
                             "match second-to-last dimension of second argument")
    dtype = type(coerce(T1, T2))
    if x2d == 1:
        ans_shape = x1s[:-1]
    else:
        ans_shape = x1s[:-1] + x2s[:-2] + (x2s[-1],)
    if out is None:
        ans = zeros(ans_shape, dtype=dtype)
    elif isinstance(out, ndarray):
        _check_out(out, ans_shape, dtype)
        ans = out
    else:
        compile_error("'out' must be an ndarray")
    for idx in multirange(ans_shape):
        # Accumulate locally so a user-supplied `out` is overwritten rather
        # than having the products added to its previous contents.
        acc = zero(dtype)
        for i1 in range(k):
            x1_idx = idx[:x1d-1] + (i1,)
            if x2d == 1:
                x2_idx = (i1,)
            else:
                x2_idx = idx[x1d-1:-1] + (i1, idx[-1])
            acc += cast(x1._ptr(x1_idx)[0], dtype) * cast(x2._ptr(x2_idx)[0], dtype)
        ans._ptr(idx)[0] = acc
    return ans
##################
# Other Routines #
##################
def _matrix_power(a, n: int):
    # Raise the square matrix (or stack of matrices) `a` to the integer
    # power `n` by repeated squaring.
    #   n == 0 -> identity of matching shape/dtype
    #   n < 0  -> power of inv(a); raises ValueError when inv(a) has a
    #             different type than `a` (presumably integer input —
    #             confirm against `inv`)
    #   n == 1 -> `a` itself (not a copy)
    a = asarray(a)
    dtype = a.dtype
    s = a.shape
    m = _square_rows(a)
    if n == 0:
        if staticlen(s) == 2:
            return eye(m, dtype=dtype)
        else:
            b = empty_like(a)
            for idx in multirange(s[:-2]):
                p = b._ptr(idx + (0, 0))
                for i in range(m):
                    for j in range(m):
                        val = cast(1, dtype) if i == j else zero(dtype)
                        # assumes each m x m slice of `b` is contiguous (row-major)
                        (p + (i * m + j))[0] = val
            return b
    elif n < 0:
        if type(inv(a)) is type(a):
            a = inv(a)
            n = abs(n)
        else:
            raise ValueError(
                "cannot take integral matrix to non-constant negative "
                "power; use 'matrix_power(inv(a), abs(n))' instead"
            )
    if n == 1:
        return a
    elif n == 2:
        return matmul(a, a)
    elif n == 3:
        return matmul(matmul(a, a), a)
    # Binary exponentiation: walk the bits of n from least significant;
    # z holds a^(2^i) and is multiplied into `result` for each set bit.
    result = a
    z = a
    first = True
    have_result = False
    while n > 0:
        if first:
            first = False
        else:
            z = matmul(z, z)
            first = False
        n, bit = divmod(n, 2)
        if bit:
            if have_result:
                result = matmul(result, z)
            else:
                result = z
                have_result = True
    return result
def matrix_power(a, n: int):
    # Runtime-exponent form of `matrix_power`: raise the square matrix (or
    # stack) `a` to the integer power `n`.
    return _matrix_power(a, n=n)
@overload
def matrix_power(a, n: Static[int]):
    # Compile-time-exponent form: a negative constant exponent is resolved
    # statically to a power of inv(a), so integer inputs are accepted here
    # (the runtime form raises for them).
    nonstatic = lambda x: x # hack to get non-static argument
    if n >= 0:
        return _matrix_power(a, n)
    else:
        return _matrix_power(inv(a), -n)
def _multi_dot_three(A, B, C, out=None):
    # Product of three 2-D arrays using the cheaper association, measured
    # in scalar multiplications:
    #   (AB)C costs rows_a*shared_ab*shared_bc + rows_a*shared_bc*cols_c
    #   A(BC) costs shared_ab*shared_bc*cols_c + rows_a*shared_ab*cols_c
    # Ties go to A(BC). `out` receives the final product.
    rows_a, shared_ab = A.shape
    shared_bc, cols_c = C.shape
    cost_left = rows_a * shared_bc * (shared_ab + cols_c)
    cost_right = shared_ab * cols_c * (rows_a + shared_bc)
    if cost_left < cost_right:
        return dot(dot(A, B), C, out=out)
    return dot(A, dot(B, C), out=out)
def _multi_dot(arrays, order, i, j, out=None):
    # Recursively multiply arrays[i..j] using the split points in `order`
    # (from `_multi_dot_matrix_chain_order`): order[i, j] is the index of
    # the last array in the left sub-chain. `out` is used only for the
    # outermost product.
    if i == j:
        #assert out is None
        return arrays[i]
    else:
        return dot(_multi_dot(arrays, order, i, order[i, j]),
                   _multi_dot(arrays, order, order[i, j] + 1, j),
                   out=out)
def _multi_dot_matrix_chain_order(arrays, return_costs: Static[int] = False):
    # Classic O(n^3) matrix-chain-order dynamic program (Cormen et al.) for
    # 2-D arrays. Returns the split table `s` (s[i, j] = index of the last
    # array in the optimal left sub-chain of arrays[i..j]) and, when
    # return_costs, also the table `m` of minimal multiplication counts.
    n = len(arrays)
    # dimension list: arrays[i] is p[i] x p[i+1]
    p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]]
    m = zeros((n, n), dtype=float)
    s = empty((n, n), dtype=int)
    for l in range(1, n):
        for i in range(n - l):
            j = i + l
            m[i, j] = inf(float)
            for k in range(i, j):
                q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1]
                # strict `<` keeps the first (leftmost) split on ties
                if q < m[i, j]:
                    m[i, j] = q
                    s[i, j] = k # Note that Cormen uses 1-based index
    if return_costs:
        return (s, m)
    else:
        return s
def multi_dot(arrays, out = None):
    # Multiply a sequence (tuple or list) of two or more arrays, choosing
    # the cheapest multiplication order. As in NumPy, the first array may be
    # 1-d (treated as a row vector) and the last may be 1-d (treated as a
    # column vector); all others must be 2-d. The result is a scalar if both
    # ends are 1-d, 1-d if one end is, otherwise 2-d.
    # Fewer than two arrays is a compile error for tuples and a ValueError
    # for lists.
    n = len(arrays)
    if isinstance(arrays, Tuple):
        if staticlen(arrays) < 2:
            compile_error("Expecting at least two arrays.")
        if staticlen(arrays) == 2:
            return dot(arrays[0], arrays[1], out=out)
    else:
        if n < 2:
            raise ValueError("Expecting at least two arrays.")
        if n == 2:
            return dot(arrays[0], arrays[1], out=out)
    ndim_first: Static[int] = staticlen(arrays[0].shape)
    ndim_last: Static[int] = staticlen(arrays[-1].shape)
    # promote a 1-d first array to a row vector
    def fix_first(arr):
        arr = asarray(arr)
        if staticlen(arr.shape) == 1:
            return atleast_2d(arr)
        else:
            return arr
    # promote a 1-d last array to a column vector
    def fix_last(arr):
        arr = asarray(arr)
        if staticlen(arr.shape) == 1:
            return atleast_2d(arr).T
        else:
            return arr
    if isinstance(arrays, Tuple):
        xarrays = (fix_first(arrays[0]),) + tuple(asarray(arr) for arr in arrays[1:-1]) + (fix_last(arrays[-1]),)
    else:
        xarrays = [asarray(arr) for arr in arrays]
        xarrays[0] = fix_first(xarrays[0])
        xarrays[-1] = fix_last(xarrays[-1])
    for arr in xarrays:
        if staticlen(arr.shape) != 2:
            compile_error("1-dimensional array given. Array must be two-dimensional")
    if n == 3:
        result = _multi_dot_three(xarrays[0], xarrays[1], xarrays[2], out=out)
    else:
        order = _multi_dot_matrix_chain_order(xarrays)
        result = _multi_dot(xarrays, order, 0, n - 1, out=out)
    if ndim_first == 1 and ndim_last == 1:
        return result[0, 0] # scalar
    elif ndim_first == 1 or ndim_last == 1:
        return result.ravel() # 1-D
    else:
        return result
def _vdot_ptr(x: Ptr[T1], y: Ptr[T2], n: int, incx: int, incy: int, T1: type, T2: type):
    """
    Return ``sum(conj(x[k * incx]) * y[k * incy] for k in range(n))``.

    ``incx``/``incy`` are strides measured in elements (not bytes). When both
    element types are the same float64/float32/complex128/complex64 type, the
    matching CBLAS routine is used (``?dot`` for reals, ``?dotc_sub`` for
    complex, which conjugates ``x``). Otherwise a generic loop accumulates in
    the coerced type of ``T1`` and ``T2``.
    """
    if T1 is float and T2 is float:
        return cblas_ddot(fint(n), x.as_byte(), fint(incx), y.as_byte(), fint(incy))
    elif T1 is float32 and T2 is float32:
        return cblas_sdot(fint(n), x.as_byte(), fint(incx), y.as_byte(), fint(incy))
    elif T1 is complex and T2 is complex:
        z = complex()
        # The *_sub variant writes its result through the last pointer argument.
        cblas_zdotc_sub(fint(n), x.as_byte(), fint(incx), y.as_byte(), fint(incy), __ptr__(z).as_byte())
        return z
    elif T1 is complex64 and T2 is complex64:
        z = complex64()
        cblas_cdotc_sub(fint(n), x.as_byte(), fint(incx), y.as_byte(), fint(incy), __ptr__(z).as_byte())
        return z
    else:
        # Mixed/other element types: plain strided loop.
        TR = type(coerce(T1, T2))
        z = TR()
        for _ in range(n):
            a = x[0]
            b = y[0]
            if hasattr(a, "conjugate"):
                a = a.conjugate()
            z += cast(a, TR) * cast(b, TR)
            # Advance the local pointer copies by one stride.
            x += incx
            y += incy
        return z
def vdot(a, b):
    """
    Return the dot product of two arrays, conjugating the first argument.

    Both inputs are converted with ``asarray`` and treated as flat sequences
    regardless of shape; they must contain the same number of elements.

    Raises:
        ValueError: if ``a`` and ``b`` have different sizes.
    """
    def mismatch():
        raise ValueError("vdot: inputs have different sizes")

    a = asarray(a)
    b = asarray(b)

    # Two 0-d inputs: a single conjugated product.
    if staticlen(a.shape) == 0 and staticlen(b.shape) == 0:
        x = a.data[0]
        y = b.data[0]
        # `coerce` yields a *value* of the common type; take its type so that
        # `cast` receives a type argument (same as the generic path below).
        TR = type(coerce(type(x), type(y)))
        if hasattr(x, "conjugate"):
            x = x.conjugate()
        return cast(x, TR) * cast(y, TR)

    n = a.size
    if n != b.size:
        mismatch()

    # Two 1-D inputs: strided walk; byte strides are converted to element strides.
    if staticlen(a.shape) == 1 and staticlen(b.shape) == 1:
        inca = a.strides[0] // a.itemsize
        incb = b.strides[0] // b.itemsize
        return _vdot_ptr(a.data, b.data, n, inca, incb)

    # Matching contiguous layouts (per `_contig_match`): unit-stride walk over raw data.
    if a._contig_match(b):
        return _vdot_ptr(a.data, b.data, n, 1, 1)

    # General case: iterate both arrays element by element.
    T1 = a.dtype
    T2 = b.dtype
    TR = type(coerce(T1, T2))
    z = TR()
    for e1, e2 in zip(a.flat, b.flat):
        if hasattr(e1, "conjugate"):
            e1 = e1.conjugate()
        z += cast(e1, TR) * cast(e2, TR)
    return z
def outer(a, b, out = None):
    """
    Outer product of two arrays.

    Both inputs are flattened; the result has shape ``(a.size, b.size)`` with
    ``result[i, j] = a_flat[i] * b_flat[j]``. ``out`` is forwarded to
    ``multiply``.
    """
    column = asarray(a).ravel()[:, None]
    row = asarray(b).ravel()[None, :]
    return multiply(column, row, out)
def inner(a, b):
    """
    Inner product of two arrays over their last axes.

    A 0-d operand falls back to elementwise ``multiply``, and two 1-D operands
    use ``dot``. Otherwise the result has shape
    ``a.shape[:-1] + b.shape[:-1]``, and entry ``[I, J]`` is
    ``sum_r a[I, r] * b[J, r]``, computed in the common dtype of the inputs.

    Raises:
        ValueError: if the last dimensions of ``a`` and ``b`` differ.
    """
    a = asarray(a)
    b = asarray(b)
    sa = a.shape
    sb = b.shape

    if staticlen(sa) == 0 or staticlen(sb) == 0:
        return multiply(a, b)
    if staticlen(sa) == 1 and staticlen(sb) == 1:
        return dot(a, b)

    length = sa[-1]
    if sb[-1] != length:
        raise ValueError("inner: mismatch in last dimension of inputs")

    lead_a = sa[:-1]
    lead_b = sb[:-1]
    out_dtype = type(coerce(a.dtype, b.dtype))
    result = empty(lead_a + lead_b, dtype=out_dtype)
    for ia in multirange(lead_a):
        for ib in multirange(lead_b):
            acc = zero(out_dtype)
            for r in range(length):
                lhs = cast(a._ptr(ia + (r,))[0], out_dtype)
                rhs = cast(b._ptr(ib + (r,))[0], out_dtype)
                acc += lhs * rhs
            slot = result._ptr(ia + ib)
            slot[0] = acc
    return result
def _tensordot(a, b, axes, ndim: Static[int] = -1):
    """
    Core of ``tensordot``: contract ``a`` and ``b`` over the given axes.

    ``axes`` is either an int N (contract the last N axes of ``a`` with the
    first N axes of ``b``) or a pair ``(axes_a, axes_b)`` whose items are
    ints or sequences of ints. Both operands are transposed and reshaped to
    2-D so the contraction becomes a single ``dot``.

    ``ndim`` is the number of result dimensions. When it is ``>= 0`` the 2-D
    product is reshaped to the free axes of ``a`` followed by those of ``b``.
    Otherwise the 2-D product is returned as-is, because the result rank is
    not known at compile time.

    Raises:
        ValueError: if the contracted axes differ in number or length.
    """
    # Normalize `axes` into two lists of axis indices.
    def get_axes(axes):
        if isinstance(axes, int):
            axes_a = list(range(-axes, 0))
            axes_b = list(range(0, axes))
            # NOTE(review): `na`/`nb` below are computed but unused here.
            na = len(axes_a)
            nb = len(axes_b)
            return axes_a, axes_b
        axes_a, axes_b = axes
        if isinstance(axes_a, int):
            xa = [axes_a]
        else:
            xa = list(axes_a)
        if isinstance(axes_b, int):
            xb = [axes_b]
        else:
            xb = list(axes_b)
        return xa, xb
    axes_a, axes_b = get_axes(axes)
    na = len(axes_a)
    nb = len(axes_b)
    as_ = List[int](a.shape)
    nda = a.ndim
    bs = List[int](b.shape)
    ndb = b.ndim
    # Check the contracted axes pair up with equal lengths, normalizing
    # negative indices in place as we go.
    equal = True
    if na != nb:
        equal = False
    else:
        for k in range(na):
            if as_[axes_a[k]] != bs[axes_b[k]]:
                equal = False
                break
            if axes_a[k] < 0:
                axes_a[k] += nda
            if axes_b[k] < 0:
                axes_b[k] += ndb
    if not equal:
        raise ValueError("shape-mismatch for sum")
    # Move a's free axes first and contracted axes last -> (M2, N2).
    notin = [k for k in range(nda) if k not in axes_a]
    newaxes_a = notin + axes_a
    N2 = 1
    for axis in axes_a:
        N2 *= as_[axis]
    M2 = 1
    for ax in notin:
        M2 *= as_[ax]
    newshape_a = (M2, N2)
    if ndim >= 0:
        olda = [as_[axis] for axis in notin]
    else:
        olda = None
    # Move b's contracted axes first and free axes last -> (N2, M2).
    notin = [k for k in range(ndb) if k not in axes_b]
    newaxes_b = axes_b + notin
    N2 = 1
    for axis in axes_b:
        N2 *= bs[axis]
    M2 = 1
    for ax in notin:
        M2 *= bs[ax]
    newshape_b = (N2, M2)
    if ndim >= 0:
        oldb = [bs[axis] for axis in notin]
    else:
        oldb = None
    at = a.transpose(newaxes_a).reshape(newshape_a)
    bt = b.transpose(newaxes_b).reshape(newshape_b)
    res = dot(at, bt)
    # NOTE: 'olda + oldb' length is not known at compile-time,
    # so cannot reshape unless 'ndim' is given.
    if ndim >= 0:
        # Build a static-length shape tuple and fill it through a pointer.
        # NOTE(review): relies on len(olda) + len(oldb) == ndim; otherwise
        # this writes past the tuple.
        newshape = (0,) * ndim
        pnewshape = Ptr[int](__ptr__(newshape).as_byte())
        i = 0
        for j in olda:
            pnewshape[i] = j
            i += 1
        for j in oldb:
            pnewshape[i] = j
            i += 1
        return res.reshape(newshape)
    else:
        return res
def tensordot(a, b, axes):
    """
    Tensor dot product of ``a`` and ``b`` over run-time ``axes``.

    ``axes`` is an int or a pair ``(axes_a, axes_b)`` of ints or sequences of
    ints. Because the result rank is not known at compile time in this form,
    the 2-D product computed by ``_tensordot`` is returned without the final
    reshape.
    """
    return _tensordot(asarray(a), asarray(b), axes)
@overload
def tensordot(a, b, axes: Static[int] = 2):
    """
    Tensor dot product contracting the last ``axes`` axes of ``a`` with the
    first ``axes`` axes of ``b`` (``axes`` known at compile time).

    The result has ``a.ndim + b.ndim - 2*axes`` dimensions. A negative
    ``axes`` behaves like ``axes=0`` (outer product).
    """
    if axes < 0:
        return tensordot(a, b, axes=0)
    a = asarray(a)
    b = asarray(b)
    # Contracting `axes` axes requires each operand to have at least that
    # many dimensions; `axes == ndim` (contract everything) is still valid.
    if axes > a.ndim or axes > b.ndim:
        compile_error("'axes' too large for given arrays")
    return _tensordot(a, b, axes=axes, ndim=(a.ndim + b.ndim - 2*axes))
def kron(a, b):
    """
    Kronecker product of ``a`` and ``b``.

    ``a`` is promoted to at least ``b``'s rank, then both operands get
    interleaved singleton axes so that a broadcasting ``multiply`` forms
    every block product. The result is reshaped to
    ``(as_[0]*bs[0], as_[1]*bs[1], ...)``. If either operand is 0-d the
    result is a plain ``multiply``.
    """
    b = asarray(b)
    ndb: Static[int] = staticlen(b.shape)
    # ndmin guarantees a.ndim >= b.ndim.
    a = array(a, copy=False, ndmin=ndb)
    nda: Static[int] = staticlen(a.shape)
    nd: Static[int] = ndb if ndb >= nda else nda
    if nda == 0 or ndb == 0:
        return multiply(a, b)
    as_ = a.shape
    bs = b.shape
    if not a.flags.contiguous:
        a = reshape(a, as_)
    if not b.flags.contiguous:
        b = reshape(b, bs)
    # Equalise the shapes by prepending 1s to the shorter one.
    as_ = (1,)*(ndb-nda if ndb-nda >= 0 else 0) + as_
    bs = (1,)*(nda-ndb if nda-ndb >= 0 else 0) + bs
    # Insert leading singleton axes so both operands have rank `nd`.
    a_arr = expand_dims(a, axis=tuple(i for i in staticrange(ndb-nda)))
    b_arr = expand_dims(b, axis=tuple(i for i in staticrange(nda-ndb)))
    # Interleave: a's axes land on even positions, b's on odd positions.
    a_arr = expand_dims(a_arr, axis=tuple(i for i in staticrange(1, nd*2, 2)))
    b_arr = expand_dims(b_arr, axis=tuple(i for i in staticrange(0, nd*2, 2)))
    result = multiply(a_arr, b_arr)
    res_shape = tuple(as_[i] * bs[i] for i in staticrange(staticlen(as_)))
    result = result.reshape(res_shape)
    return result
def _trace_ufunc(a, offset: int = 0, axis1: int = 0, axis2: int = 1,
                 dtype: type = NoneType, out = None):
    """
    Sum along a diagonal of ``a`` (implementation behind ``trace``).

    Element ``i`` of the diagonal has index ``i`` along ``axis1`` and
    ``i + offset`` along ``axis2``. For inputs with more than two dimensions,
    one sum is produced for every index of the remaining axes.

    Args:
        a: array-like input (converted with ``asarray``); at least 2-D.
        offset: diagonal offset relative to ``(axis1, axis2)``.
        axis1, axis2: the axes to take the diagonal over; may be negative.
        dtype: accumulation/result type; defaults to ``a.dtype``.
        out: optional output. It must be 0-d for 2-D input; otherwise it must
            have the result shape. In both cases its dtype must be ``dtype``.

    Raises:
        ValueError: if ``axis1`` and ``axis2`` are the same axis, or ``out``
            has the wrong shape.
    """
    a = asarray(a)
    if dtype is NoneType:
        return _trace_ufunc(a, offset=offset, axis1=axis1, axis2=axis2,
                            dtype=a.dtype, out=out)
    m, n = _rows_cols(a)  # also enforces ndim >= 2 at compile time
    s = a.shape
    axis1 = normalize_axis_index(axis1, a.ndim, "axis1")
    axis2 = normalize_axis_index(axis2, a.ndim, "axis2")
    if axis1 == axis2:
        raise ValueError("axis1 and axis2 cannot be the same")
    # Work with axis1 < axis2. Exchanging the roles of the two axes mirrors
    # the diagonal, so the offset must change sign to keep selecting the
    # elements (i along the original axis1, i + offset along axis2).
    if axis1 > axis2:
        axis1, axis2 = axis2, axis1
        offset = -offset
    if staticlen(s) == 2:
        # Here axis1 == 0 and axis2 == 1, so (m, n) are the matching lengths.
        i = 0
        t = zero(dtype)
        while ((offset >= 0 and (i < m and i + offset < n)) or
               (offset < 0 and (i - offset < m and i < n))):
            e = a._ptr((i, i + offset))[0] if offset >= 0 else a._ptr((i - offset, i))[0]
            t += cast(e, dtype)
            i += 1
        if out is None:
            return t
        else:
            if staticlen(out.shape) != 0:
                compile_error("expected 0-dimensional output parameter")
            if out.dtype is not dtype:
                compile_error("output parameter has the wrong dtype")
            out.data[0] = t
            return out
    else:
        # Result shape: input shape with both diagonal axes removed
        # (delete the larger axis first so the smaller index stays valid).
        ans_shape = tuple_delete(tuple_delete(s, axis2), axis1)
        m, n = s[axis1], s[axis2]
        if out is None:
            ans = empty(ans_shape, dtype=dtype)
        else:
            if staticlen(out.shape) != staticlen(ans_shape):
                compile_error("output parameter has the wrong number of dimensions")
            if out.dtype is not dtype:
                compile_error("output parameter has the wrong dtype")
            if out.shape != ans_shape:
                raise ValueError("output parameter has the wrong shape")
            ans = out
        for idx0 in multirange(ans_shape):
            i = 0
            t = zero(dtype)
            while ((offset >= 0 and (i < m and i + offset < n)) or
                   (offset < 0 and (i - offset < m and i < n))):
                # Re-insert the diagonal coordinates (smaller axis first).
                idx = (tuple_insert(
                         tuple_insert(idx0, axis1, i),
                           axis2, i + offset) if offset >= 0 else
                       tuple_insert(
                         tuple_insert(idx0, axis1, i - offset),
                           axis2, i))
                e = a._ptr(idx)[0]
                t += cast(e, dtype)
                i += 1
            q = ans._ptr(idx0)
            q[0] = t
        return ans
def trace(x, offset: int = 0, dtype: type = NoneType):
    """
    Sum along the diagonals of a matrix or a stack of matrices.

    As in ``numpy.linalg.trace``, the diagonal is always taken over the last
    two axes. An input of shape ``(..., M, N)`` therefore yields a result of
    shape ``(...)``, or a scalar for 2-D input.

    Args:
        x: array-like input, at least 2-D.
        offset: diagonal offset (positive: above the main diagonal).
        dtype: accumulation/result type; defaults to the input dtype.
    """
    return _trace_ufunc(x, offset=offset, axis1=-2, axis2=-1, dtype=dtype)
def tensorsolve(a, b, axes = None):
    """
    Solve a tensor linear system by flattening it to a square matrix system.

    If ``axes`` is given, those axes of ``a`` are first moved to the end.
    The trailing ``a.ndim - b.ndim`` dimensions of ``a`` give the shape of the
    solution. ``a`` is reshaped to ``(prod, prod)``, where ``prod`` is the
    product of those dimensions. It is then solved against ``b.ravel()``, and
    the solution is reshaped back.

    Raises:
        LinAlgError: if ``a`` cannot be viewed as a square ``prod x prod``
            matrix.
    """
    a = asarray(a)
    b = asarray(b)
    an: Static[int] = staticlen(a.shape)
    bn: Static[int] = staticlen(b.shape)
    if axes is not None:
        # Move the requested axes to the end, preserving their order.
        allaxes = list(range(0, an))
        for k in axes:
            allaxes.remove(k)
            allaxes.insert(an, k)
        a = a.transpose(allaxes)
    oldshape = a.shape[-(an-bn):]
    prod = 1
    for k in oldshape:
        prod *= k
    if a.size != prod ** 2:
        # The two literals are joined by implicit concatenation, so the
        # first needs a trailing space.
        raise LinAlgError(
            "Input arrays must satisfy the requirement "
            "prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
        )
    a = a.reshape(prod, prod)
    b = b.ravel()
    res = solve(a, b)
    return res.reshape(oldshape)
def tensorinv(a, ind: int = 2):
    """
    Inverse of a tensor with respect to ``tensordot`` over ``ind`` axes.

    ``a`` is reshaped to a 2-D matrix with ``prod(a.shape[ind:])`` rows and
    inverted with ``inv``. The inverse is reshaped to
    ``a.shape[ind:] + a.shape[:ind]``.

    Raises:
        ValueError: if ``ind`` is not positive.
    """
    a = asarray(a)
    oldshape = a.shape
    prod = 1
    # Same static length as `oldshape`; overwritten in place below.
    invshape = oldshape
    if ind > 0:
        # invshape = oldshape[ind:] + oldshape[:ind]
        # Done through raw pointers because `ind` is a run-time value and the
        # result must keep a compile-time-known tuple length.
        poldshape = Ptr[int](__ptr__(oldshape).as_byte())
        pinvshape = Ptr[int](__ptr__(invshape).as_byte())
        i = ind
        j = 0
        while i < len(oldshape):
            pinvshape[j] = poldshape[i]
            prod *= poldshape[i]
            i += 1
            j += 1
        i = 0
        while i < min(ind, len(oldshape)):
            pinvshape[j] = poldshape[i]
            i += 1
            j += 1
    else:
        raise ValueError("Invalid ind argument.")
    a = a.reshape(prod, -1)
    ia = inv(a)
    return ia.reshape(*invshape)
def _square(elem: T, T: type):
    # Squared magnitude |elem|**2. Complex values yield re*re + im*im (a
    # real number); any other type is simply multiplied by itself.
    if T is complex or T is complex64:
        re = elem.real
        im = elem.imag
        return re * re + im * im
    else:
        return elem * elem
def _norm_cast(elem: T, T: type):
    # Prepare an element for norm arithmetic: float64/float32/complex128/
    # complex64 values pass through unchanged; every other element type
    # (including float16) is converted to float via `cast`.
    if T is float or T is float32 or T is complex or T is complex64:
        return elem
    else:
        return cast(elem, float)
def _vector_norm_f(x: ndarray, axis: int, R: type):
    """
    Euclidean (2-)norm of ``x`` along ``axis``, accumulated in type ``R``.

    For 1-D ``x`` a scalar is returned; ``axis`` is not consulted on that
    path. Otherwise an array of type ``R`` is returned whose shape is
    ``x.shape`` with ``axis`` removed.
    """
    if staticlen(x.shape) == 1:
        dtype = x.dtype
        sh = x.shape[0]
        st = x.strides[0]
        data = x.data
        curr = R()
        for i in range(sh):
            # Strided element access (strides are in bytes).
            elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0])
            curr += _square(elem)
        return util_sqrt(curr)
    else:
        s = x.shape
        n = s[axis]
        ans_shape = tuple_delete(x.shape, axis)
        ans = empty(ans_shape, R)
        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()
            for i in range(n):
                idx1 = tuple_insert(idx, axis, i)
                elem = _norm_cast(x._ptr(idx1)[0])
                curr += _square(elem)
            q[0] = util_sqrt(curr)
        return ans
def _norm_error_zero_max():
    # Raised when a max-based norm is requested over an empty axis.
    message = ("zero-size array to reduction operation "
               "maximum which has no identity")
    raise ValueError(message)
def _norm_error_zero_min():
    # Raised when a min-based norm is requested over an empty axis.
    message = ("zero-size array to reduction operation "
               "minimum which has no identity")
    raise ValueError(message)
def _vector_norm_inf(x: ndarray, axis: int, R: type):
    """
    Infinity norm, ``max(abs(x))``, of ``x`` along ``axis``.

    For 1-D ``x`` a scalar is returned; ``axis`` is not consulted on that
    path. Otherwise an array of type ``R`` is returned whose shape is
    ``x.shape`` with ``axis`` removed.

    Raises:
        ValueError: if the reduced axis has length zero.
    """
    # Only the *reduced* axis must be non-empty. Checking shape[0] instead
    # would reject e.g. shape (0, 3) with axis=1 (a valid empty result).
    # It would also let shape (3, 0) with axis=1 read past the empty axis.
    if x.shape[axis] == 0:
        _norm_error_zero_max()
    if staticlen(x.shape) == 1:
        dtype = x.dtype
        sh = x.shape[0]
        st = x.strides[0]
        data = x.data
        curr = abs(_norm_cast(data[0]))
        for i in range(1, sh):
            # Strided element access (strides are in bytes).
            elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0])
            e = abs(elem)
            if e > curr:
                curr = e
        return curr
    else:
        s = x.shape
        n = s[axis]
        ans_shape = tuple_delete(x.shape, axis)
        ans = empty(ans_shape, R)
        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            idx1 = tuple_insert(idx, axis, 0)
            curr = abs(_norm_cast(x._ptr(idx1)[0]))
            for i in range(1, n):
                idx1 = tuple_insert(idx, axis, i)
                elem = _norm_cast(x._ptr(idx1)[0])
                e = abs(elem)
                if e > curr:
                    curr = e
            q[0] = curr
        return ans
def _vector_norm_ninf(x: ndarray, axis: int, R: type):
    """
    Negative-infinity "norm", ``min(abs(x))``, of ``x`` along ``axis``.

    For 1-D ``x`` a scalar is returned; ``axis`` is not consulted on that
    path. Otherwise an array of type ``R`` is returned whose shape is
    ``x.shape`` with ``axis`` removed.

    Raises:
        ValueError: if the reduced axis has length zero.
    """
    # Only the *reduced* axis must be non-empty (see _vector_norm_inf's
    # contract): an empty non-reduced axis just gives an empty result.
    if x.shape[axis] == 0:
        _norm_error_zero_min()
    if staticlen(x.shape) == 1:
        dtype = x.dtype
        sh = x.shape[0]
        st = x.strides[0]
        data = x.data
        curr = abs(_norm_cast(data[0]))
        for i in range(1, sh):
            # Strided element access (strides are in bytes).
            elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0])
            e = abs(elem)
            if e < curr:
                curr = e
        return curr
    else:
        s = x.shape
        n = s[axis]
        ans_shape = tuple_delete(x.shape, axis)
        ans = empty(ans_shape, R)
        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            idx1 = tuple_insert(idx, axis, 0)
            curr = abs(_norm_cast(x._ptr(idx1)[0]))
            for i in range(1, n):
                idx1 = tuple_insert(idx, axis, i)
                elem = _norm_cast(x._ptr(idx1)[0])
                e = abs(elem)
                if e < curr:
                    curr = e
            q[0] = curr
        return ans
def _vector_norm_0(x: ndarray, axis: int, R: type):
    """
    "0-norm" of ``x`` along ``axis``: the count of non-zero elements, as ``R``.

    For 1-D ``x`` a scalar is returned; ``axis`` is not consulted on that
    path. Otherwise an array of type ``R`` is returned whose shape is
    ``x.shape`` with ``axis`` removed.
    """
    if staticlen(x.shape) == 1:
        dtype = x.dtype
        sh = x.shape[0]
        st = x.strides[0]
        data = x.data
        curr = R()
        for i in range(sh):
            # Strided element access (strides are in bytes).
            elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0])
            if elem:
                curr += R(1)
        return curr
    else:
        s = x.shape
        n = s[axis]
        ans_shape = tuple_delete(x.shape, axis)
        ans = empty(ans_shape, R)
        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            # NOTE(review): this initial idx1 is overwritten in the loop below
            # and never read.
            idx1 = tuple_insert(idx, axis, 0)
            curr = R()
            for i in range(n):
                idx1 = tuple_insert(idx, axis, i)
                elem = _norm_cast(x._ptr(idx1)[0])
                if elem:
                    curr += R(1)
            q[0] = curr
        return ans
def _vector_norm_1(x: ndarray, axis: int, R: type):
    """
    1-norm of ``x`` along ``axis``: the sum of absolute values, in type ``R``.

    For 1-D ``x`` a scalar is returned; ``axis`` is not consulted on that
    path. Otherwise an array of type ``R`` is returned whose shape is
    ``x.shape`` with ``axis`` removed.
    """
    if staticlen(x.shape) == 1:
        dtype = x.dtype
        sh = x.shape[0]
        st = x.strides[0]
        data = x.data
        curr = R()
        for i in range(sh):
            # Strided element access (strides are in bytes).
            elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0])
            curr += abs(elem)
        return curr
    else:
        s = x.shape
        n = s[axis]
        ans_shape = tuple_delete(x.shape, axis)
        ans = empty(ans_shape, R)
        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            # NOTE(review): this initial idx1 is overwritten in the loop below
            # and never read.
            idx1 = tuple_insert(idx, axis, 0)
            curr = R()
            for i in range(n):
                idx1 = tuple_insert(idx, axis, i)
                elem = _norm_cast(x._ptr(idx1)[0])
                curr += abs(elem)
            q[0] = curr
        return ans
def _vector_norm_g(x: ndarray, axis: int, g, R: type):
    """
    General p-norm of ``x`` along ``axis``: ``sum(abs(x)**g) ** (1/g)``.

    ``g`` is converted to ``R`` for the power operations. For 1-D ``x`` a
    scalar is returned; ``axis`` is not consulted on that path. Otherwise an
    array of type ``R`` is returned whose shape is ``x.shape`` with ``axis``
    removed.
    """
    if staticlen(x.shape) == 1:
        dtype = x.dtype
        sh = x.shape[0]
        st = x.strides[0]
        data = x.data
        curr = R()
        for i in range(sh):
            # Strided element access (strides are in bytes).
            elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0])
            curr += abs(elem) ** R(g)
        return curr ** (R(1) / R(g))
    else:
        s = x.shape
        n = s[axis]
        ans_shape = tuple_delete(x.shape, axis)
        ans = empty(ans_shape, R)
        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            # NOTE(review): this initial idx1 is overwritten in the loop below
            # and never read.
            idx1 = tuple_insert(idx, axis, 0)
            curr = R()
            for i in range(n):
                idx1 = tuple_insert(idx, axis, i)
                elem = _norm_cast(x._ptr(idx1)[0])
                curr += abs(elem) ** R(g)
            q[0] = curr ** (R(1) / R(g))
        return ans
def _matrix_norm_f(x: ndarray, row_axis: int, col_axis: int, R: type):
    """
    Frobenius norm over the axis pair ``(row_axis, col_axis)``.

    This is the square root of the sum of squared magnitudes, accumulated in
    ``R``. For 2-D ``x`` a scalar is returned; otherwise an array of type
    ``R`` is returned with both axes removed.
    """
    s = x.shape
    if staticlen(s) == 2:
        # NOTE(review): `dtype` and `data` are unused on this path.
        dtype = x.dtype
        sh1, sh2 = s
        data = x.data
        curr = R()
        for i in range(sh1):
            for j in range(sh2):
                elem = _norm_cast(x._ptr((i, j))[0])
                curr += _square(elem)
        return util_sqrt(curr)
    else:
        # Order the axes so the larger one is deleted/inserted consistently.
        # The sum is symmetric, so which axis is "row" does not matter here.
        axis1, axis2 = row_axis, col_axis
        if axis1 > axis2:
            axis1, axis2 = axis2, axis1
        ans_shape = tuple_delete(tuple_delete(s, axis2), axis1)
        ans = empty(ans_shape, R)
        sh1 = s[axis1]
        sh2 = s[axis2]
        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()
            for i in range(sh1):
                for j in range(sh2):
                    idx1 = tuple_insert(
                             tuple_insert(idx, axis1, i), axis2, j)
                    elem = _norm_cast(x._ptr(idx1)[0])
                    curr += _square(elem)
            q[0] = util_sqrt(curr)
        return ans
def _matrix_norm_1(x: ndarray, row_axis: int, col_axis: int, R: type):
    """
    Matrix 1-norm: the maximum over columns of the column's absolute sum.

    For 2-D ``x`` a scalar is returned; otherwise an array of type ``R`` is
    returned with both matrix axes removed.

    Raises:
        ValueError: if there are no columns (maximum has no identity).
    """
    s = x.shape
    if s[col_axis] == 0:
        _norm_error_zero_max()
    if staticlen(s) == 2:
        dtype = x.dtype
        row_dim = s[row_axis]
        col_dim = s[col_axis]
        row_stride = x.strides[row_axis]
        col_stride = x.strides[col_axis]
        data = x.data
        curr = R()
        first = True
        for col_idx in range(col_dim):
            sub = R()
            for row_idx in range(row_dim):
                # Byte-offset addressing using the axis strides.
                elem = _norm_cast((Ptr[dtype](data.as_byte() +
                                                (row_stride * row_idx) +
                                                (col_stride * col_idx)))[0])
                sub += abs(elem)
            if first or sub > curr:
                curr = sub
                first = False
        return curr
    else:
        axis1, axis2 = row_axis, col_axis
        if axis1 > axis2:
            axis1, axis2 = axis2, axis1
        ans_shape = tuple_delete(tuple_delete(s, axis2), axis1)
        ans = empty(ans_shape, R)
        row_dim = s[row_axis]
        col_dim = s[col_axis]
        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()
            first = True
            for col_idx in range(col_dim):
                sub = R()
                for row_idx in range(row_dim):
                    # (i, j) are the coordinates for (axis1, axis2), i.e. in
                    # ascending axis order.
                    i, j = row_idx, col_idx
                    if row_axis > col_axis:
                        i, j = j, i
                    idx1 = tuple_insert(
                             tuple_insert(idx, axis1, i), axis2, j)
                    elem = _norm_cast(x._ptr(idx1)[0])
                    sub += abs(elem)
                if first or sub > curr:
                    curr = sub
                    first = False
            q[0] = curr
        return ans
def _matrix_norm_inf(x: ndarray, row_axis: int, col_axis: int, R: type):
    """
    Matrix infinity-norm: the maximum over rows of the row's absolute sum.

    For 2-D ``x`` a scalar is returned; otherwise an array of type ``R`` is
    returned with both matrix axes removed.

    Raises:
        ValueError: if there are no rows (maximum has no identity).
    """
    s = x.shape
    if s[row_axis] == 0:
        _norm_error_zero_max()
    if staticlen(s) == 2:
        dtype = x.dtype
        row_dim = s[row_axis]
        col_dim = s[col_axis]
        row_stride = x.strides[row_axis]
        col_stride = x.strides[col_axis]
        data = x.data
        curr = R()
        first = True
        for row_idx in range(row_dim):
            sub = R()
            for col_idx in range(col_dim):
                # Byte-offset addressing using the axis strides.
                elem = _norm_cast((Ptr[dtype](data.as_byte() +
                                                (row_stride * row_idx) +
                                                (col_stride * col_idx)))[0])
                sub += abs(elem)
            if first or sub > curr:
                curr = sub
                first = False
        return curr
    else:
        axis1, axis2 = row_axis, col_axis
        if axis1 > axis2:
            axis1, axis2 = axis2, axis1
        ans_shape = tuple_delete(tuple_delete(s, axis2), axis1)
        ans = empty(ans_shape, R)
        row_dim = s[row_axis]
        col_dim = s[col_axis]
        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()
            first = True
            for row_idx in range(row_dim):
                sub = R()
                for col_idx in range(col_dim):
                    # (i, j) are the coordinates for (axis1, axis2), i.e. in
                    # ascending axis order.
                    i, j = row_idx, col_idx
                    if row_axis > col_axis:
                        i, j = j, i
                    idx1 = tuple_insert(
                             tuple_insert(idx, axis1, i), axis2, j)
                    elem = _norm_cast(x._ptr(idx1)[0])
                    sub += abs(elem)
                if first or sub > curr:
                    curr = sub
                    first = False
            q[0] = curr
        return ans
def _matrix_norm_n1(x: ndarray, row_axis: int, col_axis: int, R: type):
    """
    Matrix (-1)-"norm": the minimum over columns of the column's absolute sum.

    For 2-D ``x`` a scalar is returned; otherwise an array of type ``R`` is
    returned with both matrix axes removed.

    Raises:
        ValueError: if there are no columns (minimum has no identity).
    """
    s = x.shape
    if s[col_axis] == 0:
        _norm_error_zero_min()
    if staticlen(s) == 2:
        dtype = x.dtype
        row_dim = s[row_axis]
        col_dim = s[col_axis]
        row_stride = x.strides[row_axis]
        col_stride = x.strides[col_axis]
        data = x.data
        curr = R()
        first = True
        for col_idx in range(col_dim):
            sub = R()
            for row_idx in range(row_dim):
                # Byte-offset addressing using the axis strides.
                elem = _norm_cast((Ptr[dtype](data.as_byte() +
                                                (row_stride * row_idx) +
                                                (col_stride * col_idx)))[0])
                sub += abs(elem)
            if first or sub < curr:
                curr = sub
                first = False
        return curr
    else:
        axis1, axis2 = row_axis, col_axis
        if axis1 > axis2:
            axis1, axis2 = axis2, axis1
        ans_shape = tuple_delete(tuple_delete(s, axis2), axis1)
        ans = empty(ans_shape, R)
        row_dim = s[row_axis]
        col_dim = s[col_axis]
        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()
            first = True
            for col_idx in range(col_dim):
                sub = R()
                for row_idx in range(row_dim):
                    # (i, j) are the coordinates for (axis1, axis2), i.e. in
                    # ascending axis order.
                    i, j = row_idx, col_idx
                    if row_axis > col_axis:
                        i, j = j, i
                    idx1 = tuple_insert(
                             tuple_insert(idx, axis1, i), axis2, j)
                    elem = _norm_cast(x._ptr(idx1)[0])
                    sub += abs(elem)
                if first or sub < curr:
                    curr = sub
                    first = False
            q[0] = curr
        return ans
def _matrix_norm_ninf(x: ndarray, row_axis: int, col_axis: int, R: type):
    """
    Matrix (-inf)-"norm": the minimum over rows of the row's absolute sum.

    For 2-D ``x`` a scalar is returned; otherwise an array of type ``R`` is
    returned with both matrix axes removed.

    Raises:
        ValueError: if there are no rows (minimum has no identity).
    """
    s = x.shape
    if s[row_axis] == 0:
        _norm_error_zero_min()
    if staticlen(s) == 2:
        dtype = x.dtype
        row_dim = s[row_axis]
        col_dim = s[col_axis]
        row_stride = x.strides[row_axis]
        col_stride = x.strides[col_axis]
        data = x.data
        curr = R()
        first = True
        for row_idx in range(row_dim):
            sub = R()
            for col_idx in range(col_dim):
                # Byte-offset addressing using the axis strides.
                elem = _norm_cast((Ptr[dtype](data.as_byte() +
                                                (row_stride * row_idx) +
                                                (col_stride * col_idx)))[0])
                sub += abs(elem)
            if first or sub < curr:
                curr = sub
                first = False
        return curr
    else:
        axis1, axis2 = row_axis, col_axis
        if axis1 > axis2:
            axis1, axis2 = axis2, axis1
        ans_shape = tuple_delete(tuple_delete(s, axis2), axis1)
        ans = empty(ans_shape, R)
        row_dim = s[row_axis]
        col_dim = s[col_axis]
        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()
            first = True
            for row_idx in range(row_dim):
                sub = R()
                for col_idx in range(col_dim):
                    # (i, j) are the coordinates for (axis1, axis2), i.e. in
                    # ascending axis order.
                    i, j = row_idx, col_idx
                    if row_axis > col_axis:
                        i, j = j, i
                    idx1 = tuple_insert(
                             tuple_insert(idx, axis1, i), axis2, j)
                    elem = _norm_cast(x._ptr(idx1)[0])
                    sub += abs(elem)
                if first or sub < curr:
                    curr = sub
                    first = False
            q[0] = curr
        return ans
def _multi_svd_norm(x: ndarray, row_axis: int, col_axis: int, op):
    # Singular-value based matrix norm: move the two matrix axes to the end,
    # compute the singular values of every matrix, then reduce them over the
    # last axis with `op` (e.g. ndarray.max / ndarray.min / ndarray.sum).
    stacked = moveaxis(x, (row_axis, col_axis), (-2, -1))
    singular_values = svd(stacked, compute_uv=False)
    return op(singular_values, axis=-1)
def _norm_wrap(x, kd_shape, keepdims: Static[int]):
    # Apply `keepdims`: when set, reshape the (possibly scalar) result `x` to
    # `kd_shape` (the input shape with reduced axes set to 1); otherwise
    # return `x` unchanged. `keepdims` is static, so only one branch is
    # compiled and the two different return types never have to unify.
    if keepdims:
        return asarray(x).reshape(kd_shape)
    else:
        return x
def norm(x, ord = None, axis = None, keepdims: Static[int] = False):
    """
    Matrix or vector norm.

    Args:
        x: array-like input. Integer/bool (and other non-inexact) inputs are
            converted to float; float and complex inputs keep their dtype.
        ord: norm order. Vectors accept None, ints, and floats (including
            +/-inf). Matrices accept None, +/-1, +/-2, +/-inf, 'fro'/'f' and
            'nuc'.
        axis: None, an int (vector norm), or a 2-tuple (matrix norm).
        keepdims: keep the reduced axes as size-1 dimensions.

    Returns:
        A scalar or an array of norms, real-valued in the precision matching
        the input (float32 for float32/complex64, float otherwise).

    Raises:
        ValueError: for an invalid matrix norm order, duplicate axes, or more
            than two axes.
    """
    x = asarray(x)
    # Only non-inexact inputs are converted to float (as in NumPy). Complex
    # inputs must not be cast, which would drop the imaginary part; the
    # complex-aware code paths below handle them.
    if not (x.dtype is float or x.dtype is float32 or x.dtype is float16 or
            x.dtype is complex or x.dtype is complex64):
        return norm(x.astype(float), ord=ord, axis=axis, keepdims=keepdims)
    ndim: Static[int] = staticlen(x.shape)
    dtype = x.dtype
    # Real result/accumulator type matching the input precision.
    if dtype is complex or dtype is float:
        r = float()
    elif dtype is complex64 or dtype is float32:
        r = float32()
    else:
        r = float()
    R = type(r)

    # Common cases: a flat 2-norm (vector 2-norm or matrix Frobenius norm).
    if axis is None:
        handle = False
        if ord is None:
            handle = True
        elif isinstance(ord, str):
            handle = ((ord == 'f' or ord == 'fro') and ndim == 2)
        elif isinstance(ord, int):
            handle = (ord == 2 and ndim == 1)
        if handle:
            y = x.ravel(order='K')
            if dtype is complex or dtype is complex64:
                x_real = y.real
                x_imag = y.imag
                sqnorm = x_real.dot(x_real) + x_imag.dot(x_imag)
            else:
                sqnorm = y.dot(y)
            ret = sqrt(sqnorm)
            if keepdims:
                return asarray(ret).reshape((1,) * ndim)
            else:
                return ret

    if axis is None:
        ax = tuple_range(ndim)
    elif isinstance(axis, int):
        ax = (axis,)
    elif isinstance(axis, Tuple):
        ax = axis
    else:
        compile_error("'axis' must be None, an integer or a tuple of integers")

    if ndim == 0:
        # this matches NumPy behavior
        if ord is None:
            return abs(x.data[0])
        else:
            compile_error("Improper number of dimensions to norm.")

    ax = tuple(normalize_axis_index(a, ndim) for a in ax)
    # Shape used when keepdims is set: reduced axes become size 1.
    kd_shape = x.shape
    if keepdims:
        for a in ax:
            kd_shape = tuple_set(kd_shape, a, 1)

    if staticlen(ax) == 1:
        # Vector norms.
        if ord is None:
            return _norm_wrap(_vector_norm_f(x, ax[0], R), kd_shape, keepdims)
        elif isinstance(ord, int):
            if ord == 0:
                return _norm_wrap(_vector_norm_0(x, ax[0], R), kd_shape, keepdims)
            elif ord == 1:
                return _norm_wrap(_vector_norm_1(x, ax[0], R), kd_shape, keepdims)
            elif ord == 2:
                return _norm_wrap(_vector_norm_f(x, ax[0], R), kd_shape, keepdims)
            else:
                return _norm_wrap(_vector_norm_g(x, ax[0], float(ord), R), kd_shape, keepdims)
        elif isinstance(ord, float):
            if ord == 0.:
                return _norm_wrap(_vector_norm_0(x, ax[0], R), kd_shape, keepdims)
            elif ord == 1.:
                return _norm_wrap(_vector_norm_1(x, ax[0], R), kd_shape, keepdims)
            elif ord == 2.:
                return _norm_wrap(_vector_norm_f(x, ax[0], R), kd_shape, keepdims)
            elif ord == inf(float):
                return _norm_wrap(_vector_norm_inf(x, ax[0], R), kd_shape, keepdims)
            elif ord == -inf(float):
                return _norm_wrap(_vector_norm_ninf(x, ax[0], R), kd_shape, keepdims)
            else:
                return _norm_wrap(_vector_norm_g(x, ax[0], ord, R), kd_shape, keepdims)
        else:
            compile_error("Invalid norm order for vectors")
    elif staticlen(ax) == 2:
        # Matrix norms.
        row_axis, col_axis = ax
        if row_axis == col_axis:
            raise ValueError("Duplicate axes given.")
        if ord is None:
            return _norm_wrap(_matrix_norm_f(x, ax[0], ax[1], R), kd_shape, keepdims)
        elif isinstance(ord, int):
            if ord == 2:
                return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.max), kd_shape, keepdims)
            elif ord == -2:
                return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.min), kd_shape, keepdims)
            elif ord == 1:
                return _norm_wrap(_matrix_norm_1(x, ax[0], ax[1], R), kd_shape, keepdims)
            elif ord == -1:
                return _norm_wrap(_matrix_norm_n1(x, ax[0], ax[1], R), kd_shape, keepdims)
        elif isinstance(ord, float):
            if ord == 2.:
                return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.max), kd_shape, keepdims)
            elif ord == -2.:
                return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.min), kd_shape, keepdims)
            elif ord == 1.:
                return _norm_wrap(_matrix_norm_1(x, ax[0], ax[1], R), kd_shape, keepdims)
            elif ord == -1.:
                return _norm_wrap(_matrix_norm_n1(x, ax[0], ax[1], R), kd_shape, keepdims)
            elif ord == inf(float):
                return _norm_wrap(_matrix_norm_inf(x, ax[0], ax[1], R), kd_shape, keepdims)
            elif ord == -inf(float):
                return _norm_wrap(_matrix_norm_ninf(x, ax[0], ax[1], R), kd_shape, keepdims)
        elif isinstance(ord, str):
            if ord == 'fro' or ord == 'f':
                return _norm_wrap(_matrix_norm_f(x, ax[0], ax[1], R), kd_shape, keepdims)
            elif ord == 'nuc':
                return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.sum), kd_shape, keepdims)
        # Any order not matched above is invalid for matrices.
        raise ValueError("Invalid norm order for matrices.")
    else:
        raise ValueError("Improper number of dimensions to norm.")
def pinv(a, rcond = 1e-15, hermitian: bool = False):
    """
    Moore-Penrose pseudo-inverse of a matrix (or stack of matrices) via SVD.

    Singular values greater than ``rcond * max(s)`` are inverted and the rest
    are set to zero. ``rcond`` may be an array that broadcasts against the
    stack dimensions. For an ``(..., M, N)`` input the result has shape
    ``(..., N, M)``. Empty matrices yield an empty ``(..., N, M)`` result
    with the input's float/complex dtype, or float for other dtypes.
    """
    a = asarray(a)
    rcond = asarray(rcond)
    m, n = _rows_cols(a)
    if m == 0 or n == 0:
        ans_shape = a.shape[:-2] + (n, m)
        dtype = a.dtype
        if (dtype is float or
            dtype is float32 or
            dtype is complex or
            dtype is complex64):
            return empty(ans_shape, dtype=dtype)
        else:
            return empty(ans_shape, dtype=float)
    a = a.conjugate()
    u, s, vt = svd(a, full_matrices=False, hermitian=hermitian)
    # Per-matrix cutoff: rcond scaled by the largest singular value.
    cutoff = multiply(rcond[..., None], s.max(axis=-1, keepdims=True))
    S = s.dtype
    # Invert `s` in place: 1/s above the cutoff, 0 otherwise.
    # NOTE(review): if `rcond` broadcasts to a larger shape than `s`, the same
    # element of `s` may be visited (and inverted) more than once — confirm
    # callers never do this.
    for idx in multirange(broadcast_shapes(s.shape, cutoff.shape)):
        c = cutoff._ptr(idx, broadcast=True)[0]
        p = s._ptr(idx, broadcast=True)
        e = p[0]
        if greater(e, c):
            e = cast(divide(1, e), S)
        else:
            e = S()
        p[0] = e
    # pinv = V * diag(1/s) * U^T, built from the factors of conj(a).
    vt_t = swapaxes(vt, -1, -2)
    u_t = swapaxes(u, -1, -2)
    res = matmul(vt_t, multiply(s[..., None], u_t))
    return res
def matrix_rank(A, tol = None, hermitian: bool = False):
    """
    Rank of a matrix (or stack of matrices), counted via SVD.

    - 0-d input: 1 if the value is truthy, else 0.
    - 1-D input: 1 if any element is truthy, else 0.
    - 2-D input: the number of singular values above the tolerance.
    - N-D input: an int array holding one rank per matrix in the stack.

    The default tolerance is ``S.max() * max(M, N) * eps``. An explicit
    ``tol`` is a scalar, or (for stacks) an array whose shape equals the
    stack shape ``A.shape[:-2]``.

    Raises:
        ValueError: for empty matrices, or if ``tol`` has the wrong shape.
    """
    A = asarray(A)
    dtype = A.dtype
    if staticlen(A.shape) == 0:
        return int(bool(A.data[0]))
    elif staticlen(A.shape) == 1:
        for a in A:
            if a:
                return 1
        return 0
    elif staticlen(A.shape) == 2:
        m, n = _rows_cols(A)
        if m == 0 or n == 0:
            raise ValueError("cannot take rank of empty matrix")
        S = svd(A, compute_uv=False, hermitian=hermitian)
        R = S.dtype
        if staticlen(asarray(tol).shape) != 0:
            compile_error("invalid tolerance dimension")
        # Largest singular value (only needed for the default tolerance).
        s_max = -inf(R)
        if tol is None:
            for s in S:
                if s > s_max:
                    s_max = s
        r = 0
        for s in S:
            if tol is None:
                if s > s_max * cast(max(m, n), R) * eps(R):
                    r += 1
            else:
                if s > cast(tol, R):
                    r += 1
        return r
    else:
        m, n = _rows_cols(A)
        if m == 0 or n == 0:
            raise ValueError("cannot take rank of empty matrix")
        # Number of singular values per matrix.
        k = min(m, n)
        pre_shape = A.shape[:-2]
        ans = empty(pre_shape, int)
        S = svd(A, compute_uv=False, hermitian=hermitian)
        R = S.dtype
        if tol is None:
            t = 0
        else:
            t = asarray(tol)
            if staticlen(t.shape) > 0:
                if staticlen(t.shape) != staticlen(A.shape) - 2:
                    compile_error("invalid tolerance dimension")
                else:
                    if t.shape != A.shape[:-2]:
                        raise ValueError("tolerance does not broadcast against matrix_rank input")
        for idx in multirange(pre_shape):
            q = ans._ptr(idx)
            r = 0
            s_max = -inf(R)
            if tol is None:
                for i in range(k):
                    e = S._ptr(idx + (i,))[0]
                    if e > s_max:
                        s_max = e
            for i in range(k):
                e = S._ptr(idx + (i,))[0]
                if tol is None:
                    if e > s_max * cast(max(m, n), R) * eps(R):
                        r += 1
                else:
                    # Scalar tolerance vs. per-matrix tolerance array.
                    if staticlen(t.shape) == 0:
                        if e > cast(t.data[0], R):
                            r += 1
                    else:
                        if e > cast(t._ptr(idx)[0], R):
                            r += 1
            q[0] = r
        return ans
def cond(x, p = None):
    """
    Condition number of a matrix (or stack of matrices).

    ``p`` selects the norm:
    - ``None`` or ``2``: ratio of the first to the last singular value.
    - ``-2``: the inverse of that ratio.
    - anything else: ``norm(x, p) * norm(inv(x), p)`` over the last two axes.

    NaN results become ``inf`` unless the corresponding input matrix itself
    contains a NaN.

    Raises:
        LinAlgError: if the matrices are empty.
    """
    def cond_r_1(x, p: int):
        # NOTE(review): assumes `svd` returns singular values in descending
        # order along the last axis (as NumPy does) — confirm.
        s = svd(x, compute_uv=False)
        if p == -2:
            return s[..., -1] / s[..., 0]
        else:
            return s[..., 0] / s[..., -1]

    def cond_r_2(x, p):
        # `ignore_errors=True` presumably lets singular matrices produce
        # non-finite values instead of raising — confirm.
        invx = _inv(x, ignore_errors=True)
        return norm(x, p, axis=(-2, -1)) * norm(invx, p, axis=(-2, -1))

    def has_nan(x, idx, m: int, n: int):
        # True if the m-by-n matrix of `x` at leading index `idx` contains a
        # NaN; stops scanning at the first one found.
        for i in range(m):
            for j in range(n):
                if isnan(x._ptr(idx + (i, j))[0]):
                    return True
        return False

    x = asarray(x)
    m, n = _rows_cols(x)
    if m == 0 or n == 0:
        raise LinAlgError("cond is not defined on empty arrays")
    if p is None:
        r = cond_r_1(x, 0)
    elif isinstance(p, int):
        if p == 2 or p == -2:
            r = cond_r_1(x, p)
        else:
            r = cond_r_2(x, p)
    else:
        r = cond_r_2(x, p)
    r = asarray(r)

    # Convert NaNs to inf unless the input matrix itself had NaN entries.
    # The (possibly large) input scan only runs when the result is NaN.
    if staticlen(r.shape) == 0:
        rval = r.data[0]
        if isnan(rval) and not has_nan(x, (), m, n):
            return inf(type(rval))
        else:
            return rval
    else:
        for idx in multirange(r.shape):
            q = r._ptr(idx)
            if isnan(q[0]) and not has_nan(x, idx, m, n):
                q[0] = inf(type(q[0]))
        return r
def matrix_transpose(x):
    """Transpose each matrix of ``x`` by exchanging its last two axes."""
    arr = asarray(x)
    if arr.ndim < 2:
        compile_error("Input array must be at least 2-dimensional")
    return swapaxes(arr, -2, -1)
@extend
class ndarray:
    # Linear-algebra methods attached to `ndarray`; each delegates to the
    # module-level function of the same purpose.

    def __matmul__(self, other):
        """``self @ other``; delegates to ``matmul``."""
        return matmul(self, other)

    def dot(self, other):
        """Dot product with ``other``; delegates to ``dot``."""
        return dot(self, other)

    def trace(self, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: type = NoneType, out = None):
        """Sum along the diagonal over ``(axis1, axis2)``; see ``_trace_ufunc``."""
        return _trace_ufunc(self, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype, out=out)