mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
4341 lines
122 KiB
Python
4341 lines
122 KiB
Python
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
|
|
|
from .blas import *
|
|
from ..ndarray import ndarray
|
|
from ..ndmath import divide, greater, isnan, multiply, sqrt
|
|
from ..routines import atleast_2d, asarray, broadcast_shapes, broadcast_to, empty, \
|
|
empty_like, eye, expand_dims, moveaxis, reshape, swapaxes, \
|
|
zeros, array
|
|
from ..util import cast, cdiv_int, coerce, eps, exp, free, inf, log, multirange, \
|
|
nan, normalize_axis_index, sizeof, sort, sqrt as util_sqrt, \
|
|
tuple_delete, tuple_insert, tuple_range, tuple_set, zero
|
|
|
|
|
|
#############
|
|
# Utilities #
|
|
#############
|
|
|
|
class LinAlgError(Static[Exception]):
    # Error raised by the linear-algebra routines in this module, e.g. for a
    # non-square input, a singular matrix, a matrix that is not positive
    # definite, a failed LAPACK work-size query, or eigenvalues that did not
    # converge.
    def __init__(self, message: str = ''):
        # Reported under the NumPy-compatible name "numpy.linalg.LinAlgError".
        super().__init__("numpy.linalg.LinAlgError", message)
|
|
|
|
def _square_rows(a):
    """Return the size of the trailing square matrix dimensions of `a`.

    Fails at compile time when `a` has fewer than two dimensions, and raises
    `LinAlgError` at run time when the last two dimensions differ.
    """
    shape = a.shape
    if staticlen(shape) < 2:
        compile_error("Array must be at least 2-dimensional")

    nrows = shape[-2]
    ncols = shape[-1]
    if nrows != ncols:
        raise LinAlgError("Last 2 dimensions of the array must be square")
    return ncols
|
|
|
|
def _rows_cols(a):
    """Return `(rows, cols)` taken from the last two dimensions of `a`.

    Fails at compile time when `a` has fewer than two dimensions.
    """
    shape = a.shape
    if staticlen(shape) < 2:
        compile_error("Array must be at least 2-dimensional")
    return shape[-2], shape[-1]
|
|
|
|
def _asarray(a, dtype: type = NoneType):
    """Convert `a` to an ndarray whose dtype LAPACK can work with.

    Supported element types are float, float32, complex and complex64.
    - An explicit supported `dtype` is used as given.
    - An explicit unsupported `dtype` falls back to float.
    - With no `dtype`, an ndarray of a supported type is returned unchanged,
      and anything else is converted to float. Non-ndarray input keeps its
      inferred element type when that type is supported.
    """
    if dtype is float or dtype is float32 or dtype is complex or dtype is complex64:
        return asarray(a, dtype=dtype)
    elif dtype is NoneType:
        if isinstance(a, ndarray):
            if (a.dtype is float or a.dtype is float32 or
                a.dtype is complex or a.dtype is complex64):
                return a
            else:
                return a.astype(float)
        else:
            # Element type that asarray would pick for this input.
            elem_type = type(asarray(a).data[0])
            if (elem_type is float or elem_type is float32 or
                elem_type is complex or elem_type is complex64):
                return asarray(a, dtype=elem_type)
            else:
                return asarray(a, dtype=float)
    else:
        return asarray(a, dtype=float)
|
|
|
|
def _basetype(dtype: type):
    """Return a default-constructed value of the real type underlying `dtype`.

    complex -> float and complex64 -> float32; real types map to themselves.
    Callers use `type(_basetype(T))` to obtain the type itself.
    """
    if dtype is complex64:
        return float32()
    elif dtype is complex:
        return float()
    elif dtype is float32 or dtype is float:
        return dtype()
    else:
        compile_error("[internal error] bad dtype")
|
|
|
|
def _complextype(dtype: type):
    """Return a default-constructed value of the complex type matching `dtype`.

    float -> complex and float32 -> complex64; complex types map to themselves.
    Callers use `type(_complextype(T))` to obtain the type itself.
    """
    if dtype is complex or dtype is complex64:
        return dtype()
    elif dtype is float32:
        return complex64()
    elif dtype is float:
        return complex()
    else:
        compile_error("[internal error] bad dtype")
|
|
|
|
def _copy(n: int, sx: Ptr[T], incx: int, sy: Ptr[T], incy: int, T: type):
    """Copy `n` strided elements from `sx` to `sy` with the matching BLAS ?copy.

    The routine is chosen by `T`: float -> dcopy, float32 -> scopy,
    complex -> zcopy, complex64 -> ccopy. `incx`/`incy` are BLAS increments,
    counted in elements.
    """
    blas_args = (fint(n), sx, fint(incx), sy, fint(incy))
    if T is complex64:
        cblas_ccopy(*blas_args)
    elif T is complex:
        cblas_zcopy(*blas_args)
    elif T is float32:
        cblas_scopy(*blas_args)
    elif T is float:
        cblas_dcopy(*blas_args)
    else:
        compile_error("[internal error] bad type for BLAS copy: " + T.__name__)
|
|
|
|
def _nan(T: type):
    """Return a NaN of type `T`.

    Complex NaNs have NaN in both the real and imaginary parts.
    """
    if T is complex64:
        return complex64(nan(float32), nan(float32))
    elif T is complex:
        return complex(nan(float), nan(float))
    elif T is float32 or T is float:
        return nan(T)
    else:
        compile_error("[internal error] bad dtype")
|
|
|
|
@tuple
class LinearizeData:
    # Describes how to copy a rows x cols matrix between a (possibly strided)
    # ndarray view and a contiguous scratch buffer used by BLAS/LAPACK.
    # row_strides / col_strides are in bytes; they are divided by sizeof(T)
    # before use. out_lead_dim is the distance, in elements, between
    # consecutive rows of the contiguous buffer.
    rows: int
    cols: int
    row_strides: int
    col_strides: int
    out_lead_dim: int

    def __new__(rows: int, cols: int, row_strides: int, col_strides: int,
                out_lead_dim: int) -> LinearizeData:
        return (rows, cols, row_strides, col_strides, out_lead_dim)

    # Convenience constructor: the contiguous buffer is tightly packed, so its
    # leading dimension equals `cols`.
    def __new__(rows: int, cols: int, row_strides: int, col_strides: int):
        return LinearizeData(rows, cols, row_strides, col_strides, cols)

    def linearize(self, dst: Ptr[T], src: Ptr[T], T: type):
        # Gather the strided matrix at `src` into the contiguous buffer `dst`
        # and return `dst`. If `dst` is null, nothing is copied and `src` is
        # returned instead.
        if dst:
            rv = dst
            cols = self.cols
            col_strides = cdiv_int(self.col_strides, sizeof(T))
            for i in range(self.rows):
                if col_strides > 0:
                    _copy(cols, src, col_strides, dst, 1)
                elif col_strides < 0:
                    # For a negative BLAS increment the pointer passed must be
                    # the lowest-addressed element of the vector.
                    _copy(cols, src + (cols - 1)*col_strides, col_strides, dst, 1);
                else:
                    # Zero column stride (e.g. a broadcast dimension):
                    # replicate the single source element across the row.
                    for j in range(cols):
                        dst[j] = src[0]
                src += cdiv_int(self.row_strides, sizeof(T))
                dst += self.out_lead_dim
            return rv
        else:
            return src

    def delinearize(self, dst: Ptr[T], src: Ptr[T], T: type):
        # Inverse of `linearize`: scatter the contiguous buffer `src` back into
        # the strided matrix at `dst`. Returns `src`; if `src` is null,
        # nothing is copied.
        if src:
            rv = src
            cols = self.cols
            col_strides = cdiv_int(self.col_strides, sizeof(T))
            for i in range(self.rows):
                if col_strides > 0:
                    _copy(cols, src, 1, dst, col_strides);
                elif col_strides < 0:
                    # Negative increment: start from the lowest address.
                    _copy(cols, src, 1, dst + (cols - 1)*col_strides, col_strides)
                else:
                    # Zero destination stride: every column maps to the same
                    # element, so only the last column's value is stored.
                    if cols > 0:
                        dst[0] = src[cols - 1]
                src += self.out_lead_dim
                dst += cdiv_int(self.row_strides, sizeof(T))
            return rv
        else:
            return src

    def nan_matrix(self, dst: Ptr[T], T: type):
        # Fill the strided rows x cols matrix at `dst` with NaN (see `_nan`).
        for i in range(self.rows):
            cp = dst
            cs = cdiv_int(self.col_strides, sizeof(T))
            for j in range(self.cols):
                cp[0] = _nan(T)
                cp += cs
            dst += cdiv_int(self.row_strides, sizeof(T))

    def zero_matrix(self, dst: Ptr[T], T: type):
        # Fill the strided rows x cols matrix at `dst` with T().
        for i in range(self.rows):
            cp = dst
            cs = cdiv_int(self.col_strides, sizeof(T))
            for j in range(self.cols):
                cp[0] = T()
                cp += cs
            dst += cdiv_int(self.row_strides, sizeof(T))

    def identity_matrix(dst: Ptr[T], n: int, T: type):
        # Write an n x n identity matrix into the contiguous buffer `dst`.
        # Takes no `self`; it is called as LinearizeData.identity_matrix(...).
        str.memset(dst.as_byte(), byte(0), n * n * sizeof(T))
        for i in range(n):
            dst[0] = T(1)
            dst += n + 1  # step to the next diagonal element
|
|
|
|
|
|
###############
|
|
# Determinant #
|
|
###############
|
|
|
|
def _slogdet_from_factored_diagonal(src: Ptr[T],
                                    m: int,
                                    sign: T,
                                    T: type):
    """Fold the diagonal of an LU-factored contiguous m x m matrix into slogdet.

    Starts from `sign` (the permutation sign) and returns the tuple
    `(sign, logabsdet)`. For real `T` each negative diagonal entry flips the
    sign; for complex `T` the sign accumulates each entry's unit phase.
    """
    step = m + 1  # distance between consecutive diagonal elements
    if T is complex or T is complex64:
        B = type(_basetype(T))
        out_sign = sign
        out_logdet = B()
        for k in range(m):
            d = src[k * step]
            mag = abs(d)
            out_sign *= d / mag
            out_logdet += log(mag)
        return out_sign, out_logdet
    elif T is float or T is float32:
        out_sign = sign
        out_logdet = T(0.0)
        for k in range(m):
            d = src[k * step]
            if d < T(0.0):
                out_sign = -out_sign
                d = -d
            out_logdet += log(d)
        return out_sign, out_logdet
    else:
        compile_error("[internal error] invalid type for _slogdet_from_factored_diagonal: " + T.__name__)
|
|
|
|
def _slogdet_single_element(m: int,
                            src: Ptr[T],
                            pivots: Ptr[fint],
                            T: type):
    """LU-factor the contiguous m x m matrix `src` in place with ?getrf.

    Returns `(sign, logabsdet)` of its determinant. `pivots` must have room
    for `m` entries. If ?getrf reports a nonzero INFO, `(0, -inf)` is returned.
    """
    order = fint(m)
    ld = fint(max(m, 1))
    status = fint(0)
    lapack_args = (__ptr__(order), __ptr__(order), src.as_byte(), __ptr__(ld), pivots, __ptr__(status))
    if T is complex64:
        cgetrf_(*lapack_args)
    elif T is complex:
        zgetrf_(*lapack_args)
    elif T is float32:
        sgetrf_(*lapack_args)
    elif T is float:
        dgetrf_(*lapack_args)
    else:
        compile_error("[internal error] bad input type")

    if status != fint(0):
        return T(0.0), -inf(type(_basetype(T)))

    # A pivot that differs from its (1-based) row index is a row swap, and an
    # odd number of swaps negates the determinant.
    swaps = 0
    for row in range(m):
        if pivots[row] != fint(row + 1):
            swaps += 1
    perm_sign = T(-1.0 if swaps % 2 == 1 else 1.0)
    return _slogdet_from_factored_diagonal(src, m, perm_sign)
|
|
|
|
def _det_from_slogdet(sign: T, logdet, T: type):
    """Recombine `sign` and the real `logdet` into the determinant sign * e**logdet."""
    if T is complex or T is complex64:
        # Lift the real magnitude into T before multiplying by the complex phase.
        return sign * T(exp(logdet))
    elif T is float or T is float32:
        return sign * exp(logdet)
    else:
        compile_error("[internal error] invalid type for _det_from_slogdet: " + T.__name__)
|
|
|
|
@tuple
class SlogdetResult[A, B]:
    # Named result of `slogdet`: `sign` (type A) and `logabsdet` (type B).
    # It can also be indexed like a 2-tuple:
    # [0] or [-2] -> sign, [1] or [-1] -> logabsdet.
    sign: A
    logabsdet: B

    def __getitem__(self, idx: Static[int]):
        # idx is a compile-time constant; out-of-range indices fail to compile.
        if idx == 0 or idx == -2:
            return self.sign
        elif idx == 1 or idx == -1:
            return self.logabsdet
        else:
            compile_error("tuple ('SlogdetResult') index out of range")
|
|
|
|
def slogdet(a):
    """Sign and natural log of the absolute determinant of square matrices.

    Each matrix in `a` is copied to a column-major scratch buffer and
    LU-factored. Returns `SlogdetResult(sign, logabsdet)`: scalars for 2-D
    input, arrays of shape `a.shape[:-2]` otherwise. `logabsdet` has the real
    base dtype.
    """
    a = _asarray(a)
    T = a.dtype
    n = _square_rows(a)
    strides = a.strides

    work = Ptr[T](n * n)     # scratch copy of one matrix
    pivots = Ptr[fint](n)    # ?getrf pivot indices
    lin = LinearizeData(n, n, strides[-1], strides[-2])

    if a.ndim == 2:
        lin.linearize(work, a.data)
        sign, logabsdet = _slogdet_single_element(n, work, pivots)
    else:
        batch = a.shape[:-2]
        sign = empty(batch, dtype=a.dtype)
        logabsdet = empty(batch, dtype=type(_basetype(a.dtype)))

        for idx in multirange(batch):
            lin.linearize(work, a._ptr(idx + (0, 0)))
            s0, l0 = _slogdet_single_element(n, work, pivots)
            sign._ptr(idx)[0] = s0
            logabsdet._ptr(idx)[0] = l0

    free(work)
    free(pivots)
    return SlogdetResult(sign, logabsdet)
|
|
|
|
def det(a):
    """Determinant of a square matrix, or of each matrix in a stack.

    Computed as sign * exp(logabsdet) from an LU factorization of a
    column-major copy of each matrix. Returns a scalar for 2-D input and an
    array of shape `a.shape[:-2]` otherwise.
    """
    a = _asarray(a)
    T = a.dtype
    n = _square_rows(a)
    strides = a.strides

    work = Ptr[T](n * n)     # scratch copy of one matrix
    pivots = Ptr[fint](n)    # ?getrf pivot indices
    lin = LinearizeData(n, n, strides[-1], strides[-2])

    if a.ndim == 2:
        lin.linearize(work, a.data)
        sgn, logabs = _slogdet_single_element(n, work, pivots)
        ans = _det_from_slogdet(sgn, logabs)
    else:
        batch = a.shape[:-2]
        ans = empty(batch, dtype=a.dtype)

        for idx in multirange(batch):
            lin.linearize(work, a._ptr(idx + (0, 0)))
            s0, l0 = _slogdet_single_element(n, work, pivots)
            ans._ptr(idx)[0] = _det_from_slogdet(s0, l0)

    free(work)
    free(pivots)
    return ans
|
|
|
|
|
|
########
|
|
# Eigh #
|
|
########
|
|
|
|
class EighParams:
    # Argument block and scratch buffers for the LAPACK symmetric/Hermitian
    # eigen-solvers used by `_eigh`: ?syevd for real T (dsyevd/ssyevd) and
    # ?heevd for complex T (zheevd/cheevd). B is the real base type of T.
    #
    # Buffer ownership: A (N*N elements of T) and W (N values of B) share one
    # allocation that starts at A. WORK, RWORK (complex only) and IWORK share
    # a second allocation that starts at WORK. `release` frees both.
    A: Ptr[T]
    W: Ptr[B]
    WORK: Ptr[T]
    RWORK: Ptr[B]
    IWORK: Ptr[fint]
    N: fint
    LWORK: fint
    LRWORK: fint
    LIWORK: fint
    JOBZ: byte
    UPLO: byte
    LDA: fint
    T: type
    B: type

    def _init_real(self, JOBZ: byte, UPLO: byte, N: fint):
        # Real case: no RWORK is used.
        safe_N = int(N)
        # N*N elements for A plus N for W (B is the same type as T here).
        alloc_size = safe_N * (safe_N + 1) * sizeof(T)
        lda = N if N else fint(1)  # leading dimension must be at least 1
        mem_buff = cobj(alloc_size)

        a = mem_buff
        w = mem_buff + safe_N * safe_N * sizeof(T)

        self.A = Ptr[T](a)
        self.W = Ptr[B](w)
        self.RWORK = Ptr[B]()
        self.N = N
        self.LRWORK = fint(0)
        self.JOBZ = JOBZ
        self.UPLO = UPLO
        self.LDA = lda

        # work size query
        # With LWORK = LIWORK = -1 the routine only reports the required
        # WORK / IWORK sizes through the single query elements below.
        query_work_size = T()
        query_iwork_size = fint()

        self.LWORK = fint(-1)
        self.LIWORK = fint(-1)
        self.WORK = __ptr__(query_work_size)
        self.IWORK = __ptr__(query_iwork_size)

        if self.call():
            free(mem_buff)
            raise LinAlgError("error in evd work size query")

        lwork = cast(query_work_size, int)
        liwork = int(query_iwork_size)

        # WORK followed by IWORK in one allocation (freed through WORK).
        mem_buff2 = cobj(lwork * sizeof(T) + liwork * sizeof(fint))
        work = mem_buff2
        iwork = mem_buff2 + lwork * sizeof(T)

        self.LWORK = fint(lwork)
        self.WORK = Ptr[T](work)
        self.LIWORK = fint(liwork)
        self.IWORK = Ptr[fint](iwork)

    def _init_complex(self, JOBZ: byte, UPLO: byte, N: fint):
        # Complex case: A holds T values, while W and RWORK hold real B values.
        safe_N = int(N)
        lda = N if N else fint(1)  # leading dimension must be at least 1
        mem_buff = cobj(safe_N * safe_N * sizeof(T) + safe_N * sizeof(B))

        a = mem_buff
        w = mem_buff + safe_N * safe_N * sizeof(T)

        self.A = Ptr[T](a)
        self.W = Ptr[B](w)
        self.N = N
        self.JOBZ = JOBZ
        self.UPLO = UPLO
        self.LDA = lda

        # work size query
        # LWORK = LRWORK = LIWORK = -1 asks only for the required sizes.
        query_work_size = T()
        query_rwork_size = B()
        query_iwork_size = fint()

        self.LWORK = fint(-1)
        self.LRWORK = fint(-1)
        self.LIWORK = fint(-1)
        self.WORK = __ptr__(query_work_size)
        self.RWORK = __ptr__(query_rwork_size)
        self.IWORK = __ptr__(query_iwork_size)

        if self.call():
            free(mem_buff)
            raise LinAlgError("error in evd work size query")

        # Reinterpret the complex WORK query result and read its leading
        # real-typed component.
        lwork = cast(Ptr[B](__ptr__(query_work_size).as_byte())[0], int)
        lrwork = cast(query_rwork_size, int)
        liwork = int(query_iwork_size)

        # WORK, RWORK and IWORK in one allocation (freed through WORK).
        mem_buff2 = cobj(lwork * sizeof(T) + lrwork * sizeof(B) + liwork * sizeof(fint))
        work = mem_buff2
        rwork = work + lwork * sizeof(T)
        iwork = rwork + lrwork * sizeof(B)

        self.WORK = Ptr[T](work)
        self.RWORK = Ptr[B](rwork)
        self.IWORK = Ptr[fint](iwork)
        self.LWORK = fint(lwork)
        self.LRWORK = fint(lrwork)
        self.LIWORK = fint(liwork)

    def __init__(self, JOBZ: byte, UPLO: byte, N: fint):
        # Allocates all buffers and sizes the workspaces; raises LinAlgError
        # if the LAPACK work-size query fails.
        if T is complex or T is complex64:
            self._init_complex(JOBZ, UPLO, N)
        else:
            self._init_real(JOBZ, UPLO, N)

    def release(self):
        # Free both allocations (A heads the first, WORK heads the second).
        free(self.A)
        free(self.WORK)

    def call(self):
        # Run the LAPACK routine for T with the current field values and
        # return its INFO output (nonzero indicates failure).
        JOBZ = self.JOBZ
        UPLO = self.UPLO
        N = self.N
        A = self.A
        LDA = self.LDA
        W = self.W
        WORK = self.WORK
        LWORK = self.LWORK
        RWORK = self.RWORK
        LRWORK = self.LRWORK
        IWORK = self.IWORK
        LIWORK = self.LIWORK
        rv = fint()

        # ?syevd(JOBZ, UPLO, N, A, LDA, W, WORK, LWORK, IWORK, LIWORK, INFO)
        args_real = (__ptr__(JOBZ),
                     __ptr__(UPLO),
                     __ptr__(N),
                     A.as_byte(),
                     __ptr__(LDA),
                     W.as_byte(),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     IWORK,
                     __ptr__(LIWORK),
                     __ptr__(rv))

        # ?heevd additionally takes RWORK, LRWORK before IWORK.
        args_cplx = (__ptr__(JOBZ),
                     __ptr__(UPLO),
                     __ptr__(N),
                     A.as_byte(),
                     __ptr__(LDA),
                     W.as_byte(),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     RWORK.as_byte(),
                     __ptr__(LRWORK),
                     IWORK,
                     __ptr__(LIWORK),
                     __ptr__(rv))

        if T is float:
            dsyevd_(*args_real)
        elif T is float32:
            ssyevd_(*args_real)
        elif T is complex:
            zheevd_(*args_cplx)
        elif T is complex64:
            cheevd_(*args_cplx)
        else:
            compile_error("[internal error] bad dtype for eigh")

        return rv
|
|
|
|
@tuple
class EighResult[A, B]:
    # Named result of `eigh`: `eigenvalues` (type A) and `eigenvectors`
    # (type B). It can also be indexed like a 2-tuple:
    # [0] or [-2] -> eigenvalues, [1] or [-1] -> eigenvectors.
    eigenvalues: A
    eigenvectors: B

    def __getitem__(self, idx: Static[int]):
        # idx is a compile-time constant; out-of-range indices fail to compile.
        if idx == 0 or idx == -2:
            return self.eigenvalues
        elif idx == 1 or idx == -1:
            return self.eigenvectors
        else:
            compile_error("tuple ('EighResult') index out of range")
|
|
|
|
def _eigh(a, JOBZ: byte, UPLO: byte, compute_eigenvectors: Static[int]):
    """Shared implementation of `eigh` and `eigvalsh`.

    Runs ?syevd/?heevd (through EighParams) on each matrix of the stack `a`.
    The JOBZ and UPLO byte codes are passed straight through to LAPACK.
    Eigenvalues have the real base dtype; eigenvectors, when requested, keep
    a's dtype. Matrices for which LAPACK fails are filled with NaN, and
    LinAlgError is raised once the loop finishes.

    Returns EighResult(eigenvalues, eigenvectors) when compute_eigenvectors is
    true, otherwise just the eigenvalues array.
    """
    a = _asarray(a)
    B = type(_basetype(a.dtype))

    n = _square_rows(a)
    params = EighParams[a.dtype, B](JOBZ, UPLO, fint(n))

    eigenvalues = empty(a.shape[:-2] + (n,), dtype=B)
    if compute_eigenvectors:
        eigenvectors = empty(a.shape, dtype=a.dtype)
    else:
        eigenvectors = None

    # Swapped strides give column-major scratch copies for LAPACK.
    matrix_in_ld = LinearizeData(n, n, a.strides[-1], a.strides[-2])
    eigenvalues_out_ld = LinearizeData(1, n, 0, eigenvalues.strides[-1])
    if compute_eigenvectors:
        eigenvectors_out_ld = LinearizeData(n, n, eigenvectors.strides[-1], eigenvectors.strides[-2])
    else:
        eigenvectors_out_ld = None
    error_occured = False

    for idx in multirange(a.shape[:-2]):
        matrix_ptr = a._ptr(idx + (0, 0))
        eigval_ptr = eigenvalues._ptr(idx + (0,))
        if compute_eigenvectors:
            eigvec_ptr = eigenvectors._ptr(idx + (0, 0))
        else:
            eigvec_ptr = None

        matrix_in_ld.linearize(params.A, matrix_ptr)
        not_ok = params.call()

        if not not_ok:
            eigenvalues_out_ld.delinearize(eigval_ptr, params.W)
            if compute_eigenvectors:
                # With JOBZ='V', LAPACK overwrites A with the eigenvectors.
                eigenvectors_out_ld.delinearize(eigvec_ptr, params.A)
        else:
            eigenvalues_out_ld.nan_matrix(eigval_ptr)
            if compute_eigenvectors:
                eigenvectors_out_ld.nan_matrix(eigvec_ptr)
            error_occured = True

    params.release()

    if error_occured:
        raise LinAlgError("Eigenvalues did not converge")

    if compute_eigenvectors:
        return EighResult(eigenvalues, eigenvectors)
    else:
        return eigenvalues
|
|
|
|
def eigh(a, UPLO: str = 'L'):
    """Eigenvalues and eigenvectors of real-symmetric / complex-Hermitian matrices.

    UPLO must be 'L' or 'U' (either case); anything else raises ValueError.
    Returns EighResult(eigenvalues, eigenvectors).
    """
    uplo_code = byte()
    if UPLO == 'U' or UPLO == 'u':
        uplo_code = byte(85)  # 'U'
    elif UPLO == 'L' or UPLO == 'l':
        uplo_code = byte(76)  # 'L'
    else:
        raise ValueError("UPLO argument must be 'L' or 'U'")

    with_vectors = byte(86)  # JOBZ = 'V'
    return _eigh(a, JOBZ=with_vectors, UPLO=uplo_code, compute_eigenvectors=True)
|
|
|
|
def eigvalsh(a, UPLO: str = 'L'):
    """Eigenvalues of real-symmetric / complex-Hermitian matrices.

    Same as `eigh`, but no eigenvectors are computed (JOBZ='N'). UPLO must be
    'L' or 'U' (either case); anything else raises ValueError.
    """
    uplo_code = byte()
    if UPLO == 'U' or UPLO == 'u':
        uplo_code = byte(85)  # 'U'
    elif UPLO == 'L' or UPLO == 'l':
        uplo_code = byte(76)  # 'L'
    else:
        raise ValueError("UPLO argument must be 'L' or 'U'")

    values_only = byte(78)  # JOBZ = 'N'
    return _eigh(a, JOBZ=values_only, UPLO=uplo_code, compute_eigenvectors=False)
|
|
|
|
|
|
#########
|
|
# Solve #
|
|
#########
|
|
|
|
class GesvParams:
    # Argument block and scratch buffers for LAPACK ?gesv, which solves
    # A @ X = B for an N x N matrix A and NRHS right-hand sides.
    # A (N*N), B (N*NRHS) and IPIV (N pivots) share one allocation that starts
    # at A; `release` frees it.
    A: Ptr[T]
    B: Ptr[T]
    IPIV: Ptr[fint]
    N: fint
    NRHS: fint
    LDA: fint
    LDB: fint
    T: type

    def __init__(self, N: fint, NRHS: fint):
        safe_N = int(N)
        safe_NRHS = int(NRHS)
        ld = N if N else fint(1)  # leading dimensions must be at least 1
        mem_buff = cobj(safe_N * safe_N * sizeof(T) +
                        safe_N * safe_NRHS * sizeof(T) +
                        safe_N * sizeof(fint))

        a = mem_buff
        b = a + safe_N * safe_N * sizeof(T)
        ipiv = b + safe_N * safe_NRHS * sizeof(T)

        self.A = Ptr[T](a)
        self.B = Ptr[T](b)
        self.IPIV = Ptr[fint](ipiv)
        self.N = N
        self.NRHS = NRHS
        self.LDA = ld
        self.LDB = ld

    def release(self):
        # A heads the single allocation holding A, B and IPIV.
        free(self.A)

    def call(self):
        # Run ?gesv for T. It overwrites A with its LU factors and B with the
        # solution X. Returns INFO (nonzero indicates failure).
        A = self.A
        B = self.B
        IPIV = self.IPIV
        N = self.N
        NRHS = self.NRHS
        LDA = self.LDA
        LDB = self.LDB
        rv = fint()

        # ?gesv(N, NRHS, A, LDA, IPIV, B, LDB, INFO)
        args = (__ptr__(N),
                __ptr__(NRHS),
                A.as_byte(),
                __ptr__(LDA),
                IPIV,
                B.as_byte(),
                __ptr__(LDB),
                __ptr__(rv))

        if T is float:
            dgesv_(*args)
        elif T is float32:
            sgesv_(*args)
        elif T is complex:
            zgesv_(*args)
        elif T is complex64:
            cgesv_(*args)
        else:
            compile_error("[internal error] bad dtype for gesv")

        return rv
|
|
|
|
def _solve(a, b):
    """Solve a @ x = b for stacked square `a` and matrix right-hand sides `b`.

    The leading (batch) dimensions of `a` and `b` are broadcast against each
    other, and `b` has shape (..., n, k). Raises ValueError if the row counts
    disagree, and LinAlgError if any matrix is singular (the corresponding
    output is filled with NaN before the error is raised).
    """
    if a.ndim == b.ndim:
        if a.shape[:-2] != b.shape[:-2]:
            batch = broadcast_shapes(a.shape[:-2], b.shape[:-2])
            a = broadcast_to(a, batch + a.shape[-2:])
            b = broadcast_to(b, batch + b.shape[-2:])
    else:
        # Ranks differ: broadcast into fresh names (the broadcast arrays have a
        # different rank from the inputs) and recurse with equal-rank operands.
        batch = broadcast_shapes(a.shape[:-2], b.shape[:-2])
        a_bc = broadcast_to(a, batch + a.shape[-2:])
        b_bc = broadcast_to(b, batch + b.shape[-2:])
        return _solve(a_bc, b_bc)

    n = _square_rows(a)
    b_rows = _rows_cols(b)[0]

    if b_rows != n:
        raise ValueError("solve: 'a' and 'b' don't have the same number of rows")

    out = empty(b.shape, dtype=a.dtype)
    nrhs = b.shape[-1]
    gesv = GesvParams[a.dtype](fint(n), fint(nrhs))
    # Swapped strides give column-major scratch copies for LAPACK.
    a_lin = LinearizeData(n, n, a.strides[-1], a.strides[-2])
    b_lin = LinearizeData(nrhs, n, b.strides[-1], b.strides[-2])
    out_lin = LinearizeData(nrhs, n, out.strides[-1], out.strides[-2])
    failed = False

    for idx in multirange(a.shape[:-2]):
        out_ptr = out._ptr(idx + (0, 0))
        a_lin.linearize(gesv.A, a._ptr(idx + (0, 0)))
        b_lin.linearize(gesv.B, b._ptr(idx + (0, 0)))

        if gesv.call():
            out_lin.nan_matrix(out_ptr)
            failed = True
        else:
            out_lin.delinearize(out_ptr, gesv.B)

    gesv.release()

    if failed:
        raise LinAlgError("Singular matrix")

    return out
|
|
|
|
def _solve1(a, b):
    """Solve a @ x = b where `b` is a single 1-D right-hand-side vector.

    `a` has shape (..., n, n) and `b` has shape (n,). The vector `b` is
    broadcast against every matrix in `a`'s batch, so the result has shape
    `a.shape[:-1]`, which is just `(n,)` for 2-D `a`. Raises ValueError if the
    sizes disagree, and LinAlgError if any matrix is singular (its output
    vector is filled with NaN before the error is raised).
    """
    n = _square_rows(a)
    m = b.shape[0]

    if m != n:
        raise ValueError("solve: 'a' and 'b' don't have the same number of rows")

    # One solution vector per matrix in the batch. A single shared (n,) output
    # would be overwritten on every iteration for batched `a`.
    r = empty(a.shape[:-1], dtype=a.dtype)
    params = GesvParams[a.dtype](fint(n), fint(1))
    a_in = LinearizeData(n, n, a.strides[-1], a.strides[-2])
    b_in = LinearizeData(1, n, 1, b.strides[0])
    r_out = LinearizeData(1, n, 1, r.strides[-1])
    error_occured = False

    for idx in multirange(a.shape[:-2]):
        a_ptr = a._ptr(idx + (0, 0))
        r_ptr = r._ptr(idx + (0,))

        a_in.linearize(params.A, a_ptr)
        # gesv overwrites B with the solution, so b must be reloaded each time.
        b_in.linearize(params.B, b.data)
        not_ok = params.call()

        if not not_ok:
            r_out.delinearize(r_ptr, params.B)
        else:
            r_out.nan_matrix(r_ptr)
            error_occured = True

    params.release()

    if error_occured:
        raise LinAlgError("Singular matrix")

    return r
|
|
|
|
def solve(a, b):
    """Solve the linear system(s) a @ x = b.

    Both operands are converted to their coerced common element type, which
    `_asarray` maps to float unless it is already float, float32, complex or
    complex64. `a` must be at least 2-D. A 1-D `b` is treated as a single
    vector; otherwise `b` is a stack of matrices whose batch dimensions
    broadcast with those of `a`.
    """
    common = type(
        coerce(
            type(asarray(a).data[0]),
            type(asarray(b).data[0])))
    a = _asarray(a, dtype=common)
    b = _asarray(b, dtype=common)

    # Rank requirements are checked at compile time.
    if a.ndim < 2:
        compile_error("'a' must be at least 2-dimensional")

    if b.ndim == 0:
        compile_error("'b' must be at least 1-dimensional")

    if b.ndim != 1:
        return _solve(a, b)
    else:
        return _solve1(a, b)
|
|
|
|
def _inv(a, ignore_errors: Static[int]):
    """Invert each square matrix in `a` by solving a @ X = I with ?gesv.

    Singular matrices produce NaN-filled outputs. Unless `ignore_errors` is
    set, LinAlgError("Singular matrix") is raised afterwards.
    """
    a = _asarray(a)
    n = _square_rows(a)

    result = empty(a.shape, dtype=a.dtype)
    gesv = GesvParams[a.dtype](fint(n), fint(n))
    a_lin = LinearizeData(n, n, a.strides[-1], a.strides[-2])
    out_lin = LinearizeData(n, n, result.strides[-1], result.strides[-2])
    singular = False

    for idx in multirange(a.shape[:-2]):
        out_ptr = result._ptr(idx + (0, 0))
        a_lin.linearize(gesv.A, a._ptr(idx + (0, 0)))
        # With the identity as the right-hand side, the solution is the inverse.
        LinearizeData.identity_matrix(gesv.B, n)

        if gesv.call():
            out_lin.nan_matrix(out_ptr)
            singular = True
        else:
            out_lin.delinearize(out_ptr, gesv.B)

    gesv.release()

    if not ignore_errors:
        if singular:
            raise LinAlgError("Singular matrix")

    return result
|
|
|
|
def inv(a):
    """Return the inverse of a square matrix, or of each matrix in a stack.

    Raises LinAlgError if any matrix is singular.
    """
    return _inv(a, ignore_errors=False)
|
|
|
|
|
|
############
|
|
# Cholesky #
|
|
############
|
|
|
|
class PotrfParams:
    # Argument block and scratch buffer for LAPACK ?potrf (Cholesky
    # factorization). A is an N x N contiguous buffer that `release` frees.
    A: Ptr[T]
    N: fint
    LDA: fint
    UPLO: byte
    T: type

    def __init__(self, UPLO: byte, N: fint):
        safe_N = int(N)
        lda = N if N else fint(1)  # leading dimension must be at least 1
        mem_buff = cobj(safe_N * safe_N * sizeof(T))

        a = mem_buff
        self.A = Ptr[T](a)
        self.N = N
        self.LDA = lda
        self.UPLO = UPLO

    def zero_lower_triangle(self):
        # In buffer row i (n contiguous elements), zero entries j > i. In the
        # column-major view of A that is the strictly lower triangle.
        n = int(self.N)
        matrix = self.A
        for i in range(n - 1):
            for j in range(i + 1, n):
                matrix[j] = zero(T)
            matrix += n

    def zero_upper_triangle(self):
        # In buffer row i, zero entries j < i. In the column-major view of A
        # that is the strictly upper triangle.
        n = int(self.N)
        matrix = self.A
        matrix += n
        for i in range(1, n):
            for j in range(i):
                matrix[j] = zero(T)
            matrix += n

    def call(self):
        # Run ?potrf for T in place on A and return INFO (nonzero indicates
        # failure).
        UPLO = self.UPLO
        N = self.N
        A = self.A
        LDA = self.LDA
        rv = fint()

        # ?potrf(UPLO, N, A, LDA, INFO)
        args = (__ptr__(UPLO),
                __ptr__(N),
                A.as_byte(),
                __ptr__(LDA),
                __ptr__(rv))

        if T is float:
            dpotrf_(*args)
        elif T is float32:
            spotrf_(*args)
        elif T is complex:
            zpotrf_(*args)
        elif T is complex64:
            cpotrf_(*args)
        else:
            compile_error("[internal error] bad dtype for potrf")

        return rv

    def release(self):
        free(self.A)
|
|
|
|
def cholesky(a, upper: bool = False):
    """Cholesky factor of each matrix in `a` via LAPACK ?potrf.

    Returns the lower-triangular factor by default, or the upper factor when
    `upper` is true. The opposite triangle of each result is zeroed. Raises
    LinAlgError if ?potrf fails for any matrix; that matrix's output is
    filled with NaN first.
    """
    a = _asarray(a)
    uplo = byte(76 if not upper else 85)  # 'L' / 'U'
    n = _square_rows(a)

    potrf = PotrfParams[a.dtype](uplo, fint(n))
    out = empty(a.shape, dtype=a.dtype)
    a_lin = LinearizeData(n, n, a.strides[-1], a.strides[-2])
    out_lin = LinearizeData(n, n, out.strides[-1], out.strides[-2])
    failed = False

    for idx in multirange(a.shape[:-2]):
        out_ptr = out._ptr(idx + (0, 0))
        a_lin.linearize(potrf.A, a._ptr(idx + (0, 0)))

        if potrf.call():
            failed = True
            out_lin.nan_matrix(out_ptr)
            continue

        # Clear the triangle that does not belong to the returned factor.
        if upper:
            potrf.zero_lower_triangle()
        else:
            potrf.zero_upper_triangle()
        out_lin.delinearize(out_ptr, potrf.A)

    potrf.release()

    if failed:
        raise LinAlgError("Matrix is not positive definite")

    return out
|
|
|
|
|
|
#######
|
|
# Eig #
|
|
#######
|
|
|
|
def _mk_complex_array_from_real(c: Ptr[C], re: Ptr[R], n: int, C: type, R: type):
    """Write c[k] = C(re[k], 0) for k in [0, n)."""
    imag_zero = zero(R)
    k = 0
    while k < n:
        c[k] = C(re[k], imag_zero)
        k += 1
|
|
|
|
def _mk_complex_array(c: Ptr[C], re: Ptr[R], im: Ptr[R], n: int, C: type, R: type):
    """Write c[k] = C(re[k], im[k]) for k in [0, n)."""
    k = 0
    while k < n:
        c[k] = C(re[k], im[k])
        k += 1
|
|
|
|
def _mk_complex_array_conjugate_pair(c: Ptr[C], r: Ptr[R], n: int, C: type, R: type):
    """Expand two packed real vectors into a pair of conjugate complex vectors.

    `r[0:n]` holds the real parts and `r[n:2n]` the imaginary parts. This
    writes c[k] = C(r[k], r[n+k]) and c[n+k] = C(r[k], -r[n+k]).
    """
    real_part = r
    imag_part = r + n
    for k in range(n):
        x = real_part[k]
        y = imag_part[k]
        c[k] = C(x, y)
        c[k + n] = C(x, -y)
|
|
|
|
def _mk_geev_complex_eigenvectors(c: Ptr[C], r: Ptr[R], i: Ptr[R], n: int, C: type, R: type):
    """Unpack real ?geev eigenvector output `r` into complex vectors in `c`.

    Vectors are n elements long and stored consecutively. A zero imaginary
    eigenvalue part i[col] marks a real eigenvector. Otherwise vectors col
    and col+1 hold the real and imaginary parts of a conjugate pair, and both
    complex vectors are produced.
    """
    col = 0
    while col < n:
        offset = col * n
        if i[col] == zero(R):
            _mk_complex_array_from_real(c + offset, r + offset, n)
            col += 1
        else:
            _mk_complex_array_conjugate_pair(c + offset, r + offset, n)
            col += 2
|
|
|
|
class GeevParams:
    # Argument block and scratch buffers for LAPACK ?geev (general
    # eigenproblem) used by `_eig`. B is the real base type of T.
    #
    # Real T: A, WR/WI (real and imaginary eigenvalue parts), VLR/VRR (packed
    # real eigenvectors) and their complex counterparts W/VL/VR all live in
    # one allocation starting at A. process_geev_results fills W/VL/VR.
    # Complex T: A, W, VL, VR and the real RWORK workspace (stored in the WR
    # field) share one allocation; WI/VLR/VRR are null.
    # WORK is a separate allocation. `release` frees both.
    A: Ptr[T]
    WR: Ptr[B]
    WI: Ptr[T]
    VLR: Ptr[T]
    VRR: Ptr[T]
    WORK: Ptr[T]
    W: Ptr[T]
    VL: Ptr[T]
    VR: Ptr[T]
    N: fint
    LDA: fint
    LDVL: fint
    LDVR: fint
    LWORK: fint
    JOBVL: byte
    JOBVR: byte
    T: type
    B: type

    def _init_real(self, jobvl: byte, jobvr: byte, n: fint):
        safe_n = int(n)
        a_size = safe_n * safe_n * sizeof(T)
        wr_size = safe_n * sizeof(T)
        wi_size = safe_n * sizeof(T)
        # Eigenvector buffers are only allocated when requested (byte 86 = 'V').
        vlr_size = safe_n * safe_n * sizeof(T) if jobvl == byte(86) else 0
        vrr_size = safe_n * safe_n * sizeof(T) if jobvr == byte(86) else 0
        # Complex outputs need two real slots per element.
        w_size = wr_size * 2
        vl_size = vlr_size * 2
        vr_size = vrr_size * 2
        ld = n if n else fint(1)  # leading dimensions must be at least 1

        mem_buff = cobj(a_size + wr_size + wi_size +
                        vlr_size + vrr_size +
                        w_size + vl_size + vr_size)
        a = mem_buff
        wr = a + a_size
        wi = wr + wr_size
        vlr = wi + wi_size
        vrr = vlr + vlr_size
        w = vrr + vrr_size
        vl = w + w_size
        vr = vl + vl_size

        self.A = Ptr[T](a)
        self.WR = Ptr[T](wr)
        self.WI = Ptr[T](wi)
        self.VLR = Ptr[T](vlr)
        self.VRR = Ptr[T](vrr)
        self.W = Ptr[T](w)
        self.VL = Ptr[T](vl)
        self.VR = Ptr[T](vr)
        self.N = n
        self.LDA = ld
        self.LDVL = ld
        self.LDVR = ld
        self.JOBVL = jobvl
        self.JOBVR = jobvr

        # work size query
        # LWORK = -1 asks LAPACK only for the required WORK size.
        work_size_query = T()
        self.LWORK = fint(-1)
        self.WORK = __ptr__(work_size_query)

        if self.call():
            free(mem_buff)
            raise LinAlgError("error in evd work size query")

        work_count = cast(work_size_query, int)

        mem_buff2 = cobj(work_count * sizeof(T))
        work = mem_buff2

        self.LWORK = fint(work_count)
        self.WORK = Ptr[T](work)

    def _init_complex(self, jobvl: byte, jobvr: byte, n: fint):
        safe_n = int(n)
        a_size = safe_n * safe_n * sizeof(T)
        w_size = safe_n * sizeof(T)
        # Eigenvector buffers are only allocated when requested (byte 86 = 'V').
        vl_size = safe_n * safe_n * sizeof(T) if jobvl == byte(86) else 0
        vr_size = safe_n * safe_n * sizeof(T) if jobvr == byte(86) else 0
        rwork_size = 2 * safe_n * sizeof(B)
        total_size = a_size + w_size + vl_size + vr_size + rwork_size
        ld = n if n else fint(1)  # leading dimensions must be at least 1

        mem_buff = cobj(total_size)
        a = mem_buff
        w = a + a_size
        vl = w + w_size
        vr = vl + vl_size
        rwork = vr + vr_size

        self.A = Ptr[T](a)
        self.WR = Ptr[B](rwork)  # WR doubles as RWORK in the complex case
        self.WI = Ptr[T]()
        self.VLR = Ptr[T]()
        self.VRR = Ptr[T]()
        self.VL = Ptr[T](vl)
        self.VR = Ptr[T](vr)
        self.W = Ptr[T](w)
        self.N = n
        self.LDA = ld
        self.LDVL = ld
        self.LDVR = ld
        self.JOBVL = jobvl
        self.JOBVR = jobvr

        # work size query
        # LWORK = -1 asks LAPACK only for the required WORK size.
        work_size_query = T()
        self.LWORK = fint(-1)
        self.WORK = __ptr__(work_size_query)

        if self.call():
            free(mem_buff)
            raise LinAlgError("error in evd work size query")

        work_count = cast(work_size_query.real, int)
        # Guard against a reported size of zero by allocating at least one
        # element.
        if work_count == 0:
            work_count = 1

        mem_buff2 = cobj(work_count * sizeof(T))
        work = mem_buff2

        self.LWORK = fint(work_count)
        self.WORK = Ptr[T](work)

    def __init__(self, jobvl: byte, jobvr: byte, n: fint):
        # Allocates all buffers and sizes WORK; raises LinAlgError if the
        # work-size query fails.
        if T is complex or T is complex64:
            self._init_complex(jobvl, jobvr, n)
        else:
            self._init_real(jobvl, jobvr, n)

    def call(self):
        # Run ?geev for T with the current field values and return INFO
        # (nonzero indicates failure).
        A = self.A
        WR = self.WR
        WI = self.WI
        VLR = self.VLR
        VRR = self.VRR
        WORK = self.WORK
        W = self.W
        VL = self.VL
        VR = self.VR
        N = self.N
        LDA = self.LDA
        LDVL = self.LDVL
        LDVR = self.LDVR
        LWORK = self.LWORK
        JOBVL = self.JOBVL
        JOBVR = self.JOBVR
        rv = fint()

        # ?geev real: (JOBVL, JOBVR, N, A, LDA, WR, WI, VL, LDVL, VR, LDVR,
        #              WORK, LWORK, INFO)
        args_real = (__ptr__(JOBVL),
                     __ptr__(JOBVR),
                     __ptr__(N),
                     A.as_byte(),
                     __ptr__(LDA),
                     WR.as_byte(),
                     WI.as_byte(),
                     VLR.as_byte(),
                     __ptr__(LDVL),
                     VRR.as_byte(),
                     __ptr__(LDVR),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     __ptr__(rv))

        # ?geev complex: (JOBVL, JOBVR, N, A, LDA, W, VL, LDVL, VR, LDVR,
        #                 WORK, LWORK, RWORK, INFO); RWORK is stored in WR.
        args_cplx = (__ptr__(JOBVL),
                     __ptr__(JOBVR),
                     __ptr__(N),
                     A.as_byte(),
                     __ptr__(LDA),
                     W.as_byte(),
                     VL.as_byte(),
                     __ptr__(LDVL),
                     VR.as_byte(),
                     __ptr__(LDVR),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     WR.as_byte(),
                     __ptr__(rv))

        if T is float:
            dgeev_(*args_real)
        elif T is float32:
            sgeev_(*args_real)
        elif T is complex:
            zgeev_(*args_cplx)
        elif T is complex64:
            cgeev_(*args_cplx)
        else:
            compile_error("[internal error] bad dtype for geev")

        return rv

    def release(self):
        free(self.WORK)
        free(self.A)  # A heads the main allocation

    def process_geev_results(self):
        # For real T, convert LAPACK's real output into the complex buffers:
        # W from WR/WI, and VL/VR from the packed VLR/VRR when they were
        # requested. For complex T, ?geev already wrote W/VL/VR, so this is a
        # no-op.
        if T is complex or T is complex64:
            return

        C = type(_complextype(T))

        _mk_complex_array(Ptr[C](self.W.as_byte()), self.WR, self.WI, int(self.N))

        if self.JOBVL == byte(86):
            _mk_geev_complex_eigenvectors(Ptr[C](self.VL.as_byte()), self.VLR, self.WI, int(self.N))

        if self.JOBVR == byte(86):
            _mk_geev_complex_eigenvectors(Ptr[C](self.VR.as_byte()), self.VRR, self.WI, int(self.N))
|
|
|
|
@tuple
class EigResult[A, B]:
    # Named result of `eig`: `eigenvalues` (type A) and `eigenvectors`
    # (type B). It can also be indexed like a 2-tuple:
    # [0] or [-2] -> eigenvalues, [1] or [-1] -> eigenvectors.
    eigenvalues: A
    eigenvectors: B

    def __getitem__(self, idx: Static[int]):
        # idx is a compile-time constant; out-of-range indices fail to compile.
        if idx == 0 or idx == -2:
            return self.eigenvalues
        elif idx == 1 or idx == -1:
            return self.eigenvectors
        else:
            compile_error("tuple ('EigResult') index out of range")
|
|
|
|
def _eig(a, JOBVL: byte, JOBVR: byte, compute_eigenvectors: Static[int]):
    """Shared implementation of `eig` and `eigvals` using LAPACK ?geev.

    Results are always complex (the dtype comes from `_complextype`). For real
    input, GeevParams.process_geev_results converts LAPACK's real output to
    complex before it is copied out. Only right eigenvectors (VR) are copied
    out. Matrices for which ?geev fails are filled with NaN, and LinAlgError
    is raised once the loop finishes.

    Returns EigResult(eigenvalues, eigenvectors) when compute_eigenvectors is
    true, otherwise just the eigenvalues array.
    """
    a = _asarray(a)
    B = type(_basetype(a.dtype))
    C = type(_complextype(a.dtype))

    n = _square_rows(a)
    params = GeevParams[a.dtype, B](JOBVL, JOBVR, fint(n))

    eigenvalues = empty(a.shape[:-2] + (n,), dtype=C)
    if compute_eigenvectors:
        eigenvectors = empty(a.shape, dtype=C)
    else:
        eigenvectors = None

    # Swapped strides give a column-major scratch copy for LAPACK.
    a_in = LinearizeData(n, n, a.strides[-1], a.strides[-2])
    w_out = LinearizeData(1, n, 0, eigenvalues.strides[-1])
    if compute_eigenvectors:
        vr_out = LinearizeData(n, n, eigenvectors.strides[-1], eigenvectors.strides[-2])
    else:
        vr_out = None
    error_occured = False

    for idx in multirange(a.shape[:-2]):
        matrix_ptr = a._ptr(idx + (0, 0))
        eigval_ptr = eigenvalues._ptr(idx + (0,))
        if compute_eigenvectors:
            eigvec_ptr = eigenvectors._ptr(idx + (0, 0))
        else:
            eigvec_ptr = None

        a_in.linearize(params.A, matrix_ptr)
        not_ok = params.call()

        if not not_ok:
            params.process_geev_results()
            w_out.delinearize(eigval_ptr, Ptr[C](params.W.as_byte()))
            if compute_eigenvectors:
                vr_out.delinearize(eigvec_ptr, Ptr[C](params.VR.as_byte()))
        else:
            w_out.nan_matrix(eigval_ptr)
            if compute_eigenvectors:
                vr_out.nan_matrix(eigvec_ptr)
            error_occured = True

    params.release()

    if error_occured:
        raise LinAlgError("Eigenvalues did not converge")

    if compute_eigenvectors:
        return EigResult(eigenvalues, eigenvectors)
    else:
        return eigenvalues
|
|
|
|
def eig(a):
    """Eigenvalues and right eigenvectors of general square matrices.

    Results are always complex and returned as EigResult(eigenvalues,
    eigenvectors). Raises LinAlgError if ?geev fails for any matrix.
    """
    skip_left = byte(78)    # JOBVL = 'N'
    want_right = byte(86)   # JOBVR = 'V'
    return _eig(a, skip_left, want_right, compute_eigenvectors=True)
|
|
|
|
def eigvals(a):
    """Eigenvalues of general square matrices (always complex).

    No eigenvectors are computed. Raises LinAlgError if ?geev fails for any
    matrix.
    """
    no_vectors = byte(78)  # 'N' for both JOBVL and JOBVR
    return _eig(a, no_vectors, no_vectors, compute_eigenvectors=False)
|
|
|
|
|
|
#######
|
|
# SVD #
|
|
#######
|
|
|
|
def _compute_urows_vtcolumns(jobz: byte, m: fint, n: fint):
    """Return (rows of U, columns of VT) used for ?gesdd with mode `jobz`.

    'N' -> (0, 0) and 'A' -> (m, n). Any other code, expected to be 'S',
    gives (min(m, n), min(m, n)).
    """
    if jobz == byte(65):  # 'A'
        return (m, n)
    if jobz == byte(78):  # 'N'
        return (fint(0), fint(0))
    k = n if n <= m else m
    return (k, k)
|
|
|
|
class GesddParams:
    # Arguments and workspace for the LAPACK ?gesdd SVD driver; call()
    # dispatches on T to dgesdd_/sgesdd_/zgesdd_/cgesdd_.
    # T is the matrix element type; B is the real type used for S and RWORK.
    A: Ptr[T]         # matrix buffer; also the start of the main allocation freed by release()
    S: Ptr[B]         # singular values (min(M, N) entries)
    U: Ptr[T]         # left singular vectors
    VT: Ptr[T]        # (transposed) right singular vectors
    WORK: Ptr[T]      # workspace, allocated separately after the size query
    RWORK: Ptr[B]     # real workspace; null for real T (only passed to the complex routines)
    IWORK: Ptr[fint]  # integer workspace (8 * min(M, N) entries)
    M: fint
    N: fint
    LDA: fint
    LDU: fint
    LDVT: fint
    LWORK: fint
    JOBZ: byte        # job code: 'N' (78), 'S' (83) or 'A' (65)
    T: type
    B: type

    def _init_real(self, jobz: byte, m: fint, n: fint):
        # Set up buffers for real T. A single allocation holds, in order:
        # A (m*n) | S (min(m, n)) | U | VT | IWORK (8*min(m, n) fints).
        safe_m = int(m)
        safe_n = int(n)
        a_size = safe_m * safe_n * sizeof(T)
        min_m_n = m if m < n else n
        safe_min_m_n = int(min_m_n)
        s_size = safe_min_m_n * sizeof(T)
        iwork_size = 8 * safe_min_m_n * sizeof(fint)
        ld = m if m else fint(1)  # LAPACK requires leading dimensions >= 1

        # U / VT extents depend on the job code
        u_row_count, vt_column_count = _compute_urows_vtcolumns(jobz, m, n)
        safe_u_row_count = int(u_row_count)
        safe_vt_column_count = int(vt_column_count)

        u_size = safe_u_row_count * safe_m * sizeof(T)
        vt_size = safe_n * safe_vt_column_count * sizeof(T)
        mem_buff = cobj(a_size + s_size + u_size + vt_size + iwork_size)

        # carve the single buffer into its sub-arrays
        a = mem_buff
        s = a + a_size
        u = s + s_size
        vt = u + u_size
        iwork = vt + vt_size
        vt_column_count = vt_column_count if vt_column_count else fint(1)  # LDVT >= 1

        self.M = m
        self.N = n
        self.A = Ptr[T](a)
        self.S = Ptr[B](s)
        self.U = Ptr[T](u)
        self.VT = Ptr[T](vt)
        self.RWORK = Ptr[B]()
        self.IWORK = Ptr[fint](iwork)
        self.LDA = ld
        self.LDU = ld
        self.LDVT = vt_column_count
        self.JOBZ = jobz

        # work size query: with LWORK = -1, ?gesdd reports the optimal
        # workspace size in WORK[0] instead of computing the SVD
        work_size_query = T()
        self.LWORK = fint(-1)
        self.WORK = __ptr__(work_size_query)

        if self.call():
            free(mem_buff)
            raise LinAlgError("error in gesdd work size query")

        work_count = cast(work_size_query, int)
        if work_count == 0:
            work_count = 1

        work_size = work_count * sizeof(T)
        mem_buff2 = cobj(work_size)
        work = mem_buff2

        self.LWORK = fint(work_count)
        self.WORK = Ptr[T](work)

    def _init_complex(self, jobz: byte, m: fint, n: fint):
        # Set up buffers for complex T. A single allocation holds, in order:
        # A | S (real, B) | U | VT | RWORK | IWORK.
        safe_m = int(m)
        safe_n = int(n)
        min_m_n = m if m < n else n
        safe_min_m_n = int(min_m_n)
        ld = m if m else fint(1)  # LAPACK requires leading dimensions >= 1

        u_row_count, vt_column_count = _compute_urows_vtcolumns(jobz, m, n)
        safe_u_row_count = int(u_row_count)
        safe_vt_column_count = int(vt_column_count)

        a_size = safe_m * safe_n * sizeof(T)
        s_size = safe_min_m_n * sizeof(B)
        u_size = safe_u_row_count * safe_m * sizeof(T)
        vt_size = safe_n * safe_vt_column_count * sizeof(T)
        # RWORK element count: 7*min(m, n) for JOBZ='N', else 5*min^2 + 5*min
        # (presumably following the LAPACK ?gesdd LRWORK requirement)
        rwork_size = (7 * safe_min_m_n if jobz == byte(78) else
                      5*safe_min_m_n * safe_min_m_n + 5*safe_min_m_n)
        # NOTE: scaled by sizeof(T) although RWORK holds B elements, so this
        # over-allocates (harmless)
        rwork_size *= sizeof(T)
        iwork_size = 8 * safe_min_m_n * sizeof(fint)
        mem_buff = cobj(a_size +
                        s_size +
                        u_size +
                        vt_size +
                        rwork_size +
                        iwork_size)

        # carve the single buffer into its sub-arrays
        a = mem_buff
        s = a + a_size
        u = s + s_size
        vt = u + u_size
        rwork = vt + vt_size
        iwork = rwork + rwork_size

        vt_column_count = vt_column_count if vt_column_count else fint(1)  # LDVT >= 1

        self.A = Ptr[T](a)
        self.S = Ptr[B](s)
        self.U = Ptr[T](u)
        self.VT = Ptr[T](vt)
        self.RWORK = Ptr[B](rwork)
        self.IWORK = Ptr[fint](iwork)
        self.M = m
        self.N = n
        self.LDA = ld
        self.LDU = ld
        self.LDVT = vt_column_count
        self.JOBZ = jobz

        # work size query (LWORK = -1); the size comes back in WORK[0]
        work_size_query = T()
        self.LWORK = fint(-1)
        self.WORK = __ptr__(work_size_query)

        if self.call():
            free(mem_buff)
            raise LinAlgError("error in gesdd work size query")

        # read the real part of the complex size value
        work_count = cast(Ptr[B](__ptr__(work_size_query).as_byte())[0], int)
        if work_count == 0:
            work_count = 1

        work_size = work_count * sizeof(T)
        mem_buff2 = cobj(work_size)
        work = mem_buff2

        self.LWORK = fint(work_count)
        self.WORK = Ptr[T](work)

    def __init__(self, jobz: byte, m: fint, n: fint):
        # Allocate all buffers for an m x n problem with job code `jobz`;
        # raises LinAlgError if the workspace query fails.
        if T is complex or T is complex64:
            self._init_complex(jobz, m, n)
        else:
            self._init_real(jobz, m, n)

    def call(self):
        # Invoke the ?gesdd routine matching T with the stored arguments and
        # return its INFO value (callers treat nonzero as failure).
        A = self.A
        S = self.S
        U = self.U
        VT = self.VT
        WORK = self.WORK
        RWORK = self.RWORK
        IWORK = self.IWORK
        M = self.M
        N = self.N
        LDA = self.LDA
        LDU = self.LDU
        LDVT = self.LDVT
        LWORK = self.LWORK
        JOBZ = self.JOBZ
        rv = fint()

        # real routines take no RWORK argument
        args_real = (__ptr__(JOBZ),
                     __ptr__(M),
                     __ptr__(N),
                     A.as_byte(),
                     __ptr__(LDA),
                     S.as_byte(),
                     U.as_byte(),
                     __ptr__(LDU),
                     VT.as_byte(),
                     __ptr__(LDVT),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     IWORK,
                     __ptr__(rv))

        # complex routines additionally take RWORK before IWORK
        args_cplx = (__ptr__(JOBZ),
                     __ptr__(M),
                     __ptr__(N),
                     A.as_byte(),
                     __ptr__(LDA),
                     S.as_byte(),
                     U.as_byte(),
                     __ptr__(LDU),
                     VT.as_byte(),
                     __ptr__(LDVT),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     RWORK.as_byte(),
                     IWORK,
                     __ptr__(rv))

        if T is float:
            dgesdd_(*args_real)
        elif T is float32:
            sgesdd_(*args_real)
        elif T is complex:
            zgesdd_(*args_cplx)
        elif T is complex64:
            cgesdd_(*args_cplx)
        else:
            compile_error("[internal error] bad dtype for gesdd")

        return rv

    def release(self):
        # Free the main buffer (which starts at A) and the separate WORK buffer.
        free(self.A)
        free(self.WORK)
|
|
|
|
@tuple
class SVDResult[A1, A2]:
    # Result of svd(): U and Vh share array type A1; the singular values S
    # have type A2. Indexable/unpackable like the 3-tuple (U, S, Vh).
    U: A1
    S: A2
    Vh: A1

    def __getitem__(self, idx: Static[int]):
        if idx == 2 or idx == -1:
            return self.Vh
        elif idx == 1 or idx == -2:
            return self.S
        elif idx == 0 or idx == -3:
            return self.U
        else:
            compile_error("tuple ('SVDResult') index out of range")
|
|
|
|
def _svd(a, JOBZ: byte, compute_uv: Static[int]):
    # Batched SVD driver built on LAPACK ?gesdd (via GesddParams).
    # JOBZ: 'N' (78) singular values only, 'S' (83) reduced U/Vh, 'A' (65) full U/Vh.
    # Returns S, or SVDResult(U, S, V) when compute_uv is set. A matrix whose
    # decomposition fails gets NaN-filled outputs, and LinAlgError is raised
    # only after the whole batch has been processed.
    B = type(_basetype(a.dtype))
    m, n = _rows_cols(a)
    min_m_n = min(m, n)

    params = GesddParams[a.dtype, B](JOBZ, fint(m), fint(n))
    S = empty(a.shape[:-2] + (min_m_n,), dtype=B)
    # NOTE(review): row/column strides are passed swapped, presumably so the
    # linearized copy has the column-major layout LAPACK expects; confirm
    # against LinearizeData.
    a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2])
    s_out = LinearizeData(1, min_m_n, 0, S.strides[-1])

    if compute_uv:
        if JOBZ == byte(83): # 'S'
            u_columns = min_m_n
            v_rows = min_m_n
        else:
            u_columns = m
            v_rows = n

        U = empty(a.shape[:-2] + (m, u_columns), dtype=a.dtype)
        V = empty(a.shape[:-2] + (v_rows, n), dtype=a.dtype)
        u_out = LinearizeData(u_columns, m, U.strides[-1], U.strides[-2])
        v_out = LinearizeData(n, v_rows, V.strides[-1], V.strides[-2])
    else:
        U = None
        V = None
        u_out = None
        v_out = None

    error_occured = False

    # one ?gesdd call per matrix in the leading (batch) dimensions
    for idx in multirange(a.shape[:-2]):
        a_ptr = a._ptr(idx + (0, 0))
        s_ptr = S._ptr(idx + (0,))
        if compute_uv:
            u_ptr = U._ptr(idx + (0, 0))
            v_ptr = V._ptr(idx + (0, 0))
        else:
            u_ptr = None
            v_ptr = None

        a_in.linearize(params.A, a_ptr)
        not_ok = params.call()  # nonzero INFO means failure

        if not not_ok:
            if compute_uv:
                # full-matrices job on an empty matrix: fill U and VT via
                # identity_matrix (presumably identity) before copying out
                if JOBZ == byte(65) and min_m_n == 0:
                    LinearizeData.identity_matrix(params.U, m)
                    LinearizeData.identity_matrix(params.VT, n)

                u_out.delinearize(u_ptr, params.U)
                s_out.delinearize(s_ptr, params.S)
                v_out.delinearize(v_ptr, params.VT)
            else:
                s_out.delinearize(s_ptr, params.S)
        else:
            if compute_uv:
                u_out.nan_matrix(u_ptr)
                s_out.nan_matrix(s_ptr)
                v_out.nan_matrix(v_ptr)
            else:
                s_out.nan_matrix(s_ptr)
            error_occured = True

    params.release()

    if error_occured:
        raise LinAlgError("SVD did not converge")

    if compute_uv:
        return SVDResult(U, S, V)
    else:
        return S
|
|
|
|
def svd(a,
        full_matrices: bool = True,
        compute_uv: Static[int] = True,
        hermitian: bool = False):
    # Singular value decomposition of `a` (a matrix or a stack of matrices).
    #
    # full_matrices: if True, U/Vh are (m, m)/(n, n); otherwise (m, k)/(k, n)
    #                with k = min(m, n).
    # compute_uv:    if True, return SVDResult(U, S, Vh); otherwise only S.
    # hermitian:     if True, `a` is assumed Hermitian (not checked here). The
    #                result is then derived from eigh()/eigvalsh() instead of
    #                LAPACK ?gesdd, with singular values sorted in descending order.
    # The ?gesdd path raises LinAlgError if the SVD does not converge (see _svd).
    a = _asarray(a)
    dtype = a.dtype
    m, n = _rows_cols(a)
    k = min(m, n)

    if hermitian:
        pre_shape = a.shape[:-2]

        if compute_uv:
            s_shape = pre_shape + (k,)
            u_shape = pre_shape + ((m, m) if full_matrices else (m, k))
            v_shape = pre_shape + ((n, n) if full_matrices else (k, n))

            e_s, e_u = eigh(a)  # note that `k == m == n` in this case
            sidx = Ptr[int](k)  # permutation ordering eigenvalues by |value|, descending
            if dtype is complex:
                SH = empty(s_shape, float)
            elif dtype is complex64:
                SH = empty(s_shape, float32)
            else:
                SH = empty(s_shape, dtype)
            UH = empty(u_shape, dtype)
            VH = empty(v_shape, dtype)

            if a.ndim == 2:
                for r in range(k):
                    sidx[r] = r
                sort(sidx, k, key=lambda x: -abs(e_s.data[x]))

                for j in range(k):
                    mapping = sidx[j]
                    s_elem = e_s.data[mapping]
                    negative = s_elem < type(s_elem)()
                    SH.data[j] = abs(s_elem)

                    for i in range(k):
                        u_elem = e_u._ptr((i, mapping))[0]
                        dst = UH._ptr((i, j))
                        dst[0] = u_elem

                        # singular values are unsigned: fold the eigenvalue's
                        # sign into Vh, which is conj(U * sign)^T
                        dst = VH._ptr((j, i)) # transpose
                        if negative:
                            u_elem = -u_elem
                        if hasattr(u_elem, "conjugate"):
                            u_elem = u_elem.conjugate()
                        dst[0] = u_elem
            else:
                # NOTE(review): the raw pointer arithmetic below assumes each
                # trailing matrix of e_s/e_u/SH/UH/VH is C-contiguous.
                for idx in multirange(pre_shape):
                    sub_e_s = e_s._ptr(idx + (0,))
                    sub_e_u = e_u._ptr(idx + (0, 0))
                    sub_SH = SH._ptr(idx + (0,))
                    sub_UH = UH._ptr(idx + (0, 0))
                    sub_VH = VH._ptr(idx + (0, 0))

                    for r in range(k):
                        sidx[r] = r
                    sort(sidx, k, key=lambda x: -abs(sub_e_s[x]))

                    for j in range(k):
                        mapping = sidx[j]
                        s_elem = sub_e_s[mapping]
                        negative = s_elem < type(s_elem)()
                        sub_SH[j] = abs(s_elem)

                        for i in range(k):
                            u_elem = (sub_e_u + (i * k + mapping))[0]
                            dst = sub_UH + (i * k + j)
                            dst[0] = u_elem

                            dst = sub_VH + (j * k + i) # transpose
                            if negative:
                                u_elem = -u_elem
                            if hasattr(u_elem, "conjugate"):
                                u_elem = u_elem.conjugate()
                            dst[0] = u_elem

            return SVDResult(UH, SH, VH)
        else:
            # singular values = |eigenvalues|, sorted in descending order.
            # Computed exactly once (the 2-D path previously called eigvalsh twice).
            evh = eigvalsh(a)

            if a.ndim == 2:
                for i in range(k):
                    evh.data[i] = abs(evh.data[i])
                sort(evh.data, k, key=lambda x: -x)
            else:
                for idx in multirange(pre_shape):
                    sub_evh = evh._ptr(idx + (0,))
                    for i in range(k):
                        sub_evh[i] = abs(sub_evh[i])
                    sort(sub_evh, k, key=lambda x: -x)

            return evh

    if compute_uv:
        jobz_code = byte(65 if full_matrices else 83) # 'A' / 'S'
        return _svd(a, jobz_code, compute_uv=True)
    else:
        jobz_code = byte(78) # 'N'
        return _svd(a, jobz_code, compute_uv=False)
|
|
|
|
def svdvals(x):
    # Singular values of `x` only: no U/Vh, and no Hermitian shortcut.
    return svd(x, hermitian=False, compute_uv=False)
|
|
|
|
|
|
######
|
|
# QR #
|
|
######
|
|
|
|
class GeqrfParams:
    # Arguments and workspace for LAPACK ?geqrf (QR factorization in
    # Householder form); call() dispatches on T to d/s/z/cgeqrf_.
    M: fint
    N: fint
    A: Ptr[T]     # matrix buffer (factorized in place); start of the allocation freed by release()
    LDA: fint
    TAU: Ptr[T]   # Householder scalars, min(M, N) entries
    WORK: Ptr[T]  # workspace, allocated separately after the size query
    LWORK: fint
    T: type

    def _init_real(self, m: fint, n: fint):
        # One buffer holds A (m*n) followed by TAU (min(m, n), zero-filled);
        # WORK is sized by a workspace query.
        min_m_n = m if m < n else n
        safe_min_m_n = int(min_m_n)
        safe_m = int(m)
        safe_n = int(n)

        a_size = safe_m * safe_n * sizeof(T)
        tau_size = safe_min_m_n * sizeof(T)
        lda = m if m else fint(1)  # LDA must be >= 1

        mem_buff = cobj(a_size + tau_size)
        a = mem_buff
        tau = a + a_size
        str.memset(tau.as_byte(), byte(0), tau_size)

        self.M = m
        self.N = n
        self.A = Ptr[T](a)
        self.TAU = Ptr[T](tau)
        self.LDA = lda

        # work size query: LWORK = -1 makes ?geqrf store the optimal size in WORK[0]
        work_size_query = T()

        self.WORK = __ptr__(work_size_query)
        self.LWORK = fint(-1)

        if self.call():
            free(mem_buff)
            raise LinAlgError("error in geqrf work size query")

        work_count = cast(work_size_query, fint)

        # LWORK = max(1, n, queried size)
        lw = n if n else fint(1)
        lw = lw if lw > work_count else work_count
        self.LWORK = lw

        work_size = int(lw) * sizeof(T)
        mem_buff2 = cobj(work_size)

        work = mem_buff2
        self.WORK = Ptr[T](work)

    def _init_complex(self, m: fint, n: fint):
        # Same as _init_real, except the queried size is read from the real
        # part of the complex WORK[0].
        min_m_n = m if m < n else n
        safe_min_m_n = int(min_m_n)
        safe_m = int(m)
        safe_n = int(n)

        a_size = safe_m * safe_n * sizeof(T)
        tau_size = safe_min_m_n * sizeof(T)
        lda = m if m else fint(1)  # LDA must be >= 1

        mem_buff = cobj(a_size + tau_size)
        a = mem_buff
        tau = a + a_size
        str.memset(tau.as_byte(), byte(0), tau_size)

        self.M = m
        self.N = n
        self.A = Ptr[T](a)
        self.TAU = Ptr[T](tau)
        self.LDA = lda

        # work size query (LWORK = -1)
        work_size_query = T()

        self.WORK = __ptr__(work_size_query)
        self.LWORK = fint(-1)

        if self.call():
            free(mem_buff)
            raise LinAlgError("error in geqrf work size query")

        work_count = cast(work_size_query.real, fint)

        # LWORK = max(1, n, queried size)
        lw = n if n else fint(1)
        lw = lw if lw > work_count else work_count
        self.LWORK = lw

        work_size = int(lw) * sizeof(T)
        mem_buff2 = cobj(work_size)

        work = mem_buff2
        self.WORK = Ptr[T](work)

    def __init__(self, m: fint, n: fint):
        # Allocate buffers for an m x n factorization; raises LinAlgError if
        # the workspace query fails.
        if T is complex or T is complex64:
            self._init_complex(m, n)
        else:
            self._init_real(m, n)

    def call(self):
        # Invoke the ?geqrf routine matching T; returns its INFO value
        # (callers treat nonzero as failure).
        M = self.M
        N = self.N
        A = self.A
        LDA = self.LDA
        TAU = self.TAU
        WORK = self.WORK
        LWORK = self.LWORK
        rv = fint()

        args = (__ptr__(M),
                __ptr__(N),
                A.as_byte(),
                __ptr__(LDA),
                TAU.as_byte(),
                WORK.as_byte(),
                __ptr__(LWORK),
                __ptr__(rv))

        if T is float:
            dgeqrf_(*args)
        elif T is float32:
            sgeqrf_(*args)
        elif T is complex:
            zgeqrf_(*args)
        elif T is complex64:
            cgeqrf_(*args)
        else:
            compile_error("[internal error] bad dtype for geqrf")

        return rv

    def release(self):
        # Free the main buffer (which starts at A) and the WORK buffer.
        free(self.A)
        free(self.WORK)
|
|
|
|
def _qr_r_raw(a):
    # Run LAPACK ?geqrf on every trailing (m, n) matrix of `a`, writing the
    # factorized matrix back into `a` in place. Returns the stacked TAU
    # scalars with shape a.shape[:-2] + (min(m, n),). Raises LinAlgError
    # (after processing the whole batch) if any call fails.
    m, n = _rows_cols(a)
    params = GeqrfParams[a.dtype](fint(m), fint(n))
    k = min(m, n)

    batch = a.shape[:-2]
    tau = empty(batch + (k,), dtype=a.dtype)
    a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2])
    tau_out = LinearizeData(1, k, 1, tau.strides[-1])
    failed = False

    for idx in multirange(batch):
        mat = a._ptr(idx + (0, 0))
        tau_dst = tau._ptr(idx + (0,))
        a_in.linearize(params.A, mat)
        if params.call():
            tau_out.nan_matrix(tau_dst)
            failed = True
        else:
            a_in.delinearize(mat, params.A)
            tau_out.delinearize(tau_dst, params.TAU)

    params.release()

    if failed:
        raise LinAlgError("Incorrect argument found while performing "
                          "QR factorization")
    return tau
|
|
|
|
class GqrParams:
    # Arguments and workspace for LAPACK ?orgqr (real) / ?ungqr (complex),
    # which form the explicit Q factor from ?geqrf's Householder output.
    # call() dispatches on T to d/sorgqr_ or z/cungqr_.
    M: fint       # rows of Q
    MC: fint      # columns of Q to generate (min(m, n) for reduced QR, m for complete)
    MN: fint      # number of reflectors, min(m, n)
    A: Ptr[T]     # scratch copy of the input (not passed to LAPACK by call())
    Q: Ptr[T]     # Q, overwritten in place; start of the allocation freed by release()
    LDA: fint
    TAU: Ptr[T]   # Householder scalars, MN entries
    WORK: Ptr[T]  # workspace, allocated separately after the size query
    LWORK: fint
    T: type

    def _init_real(self, m: fint, n: fint, mc: fint):
        # One buffer holds Q (m*mc) | TAU (min(m, n)) | A (m*n); WORK is sized
        # by a workspace query.
        min_m_n = m if m < n else n
        safe_mc = int(mc)
        safe_min_m_n = int(min_m_n)
        safe_m = int(m)
        safe_n = int(n)
        a_size = safe_m * safe_n * sizeof(T)
        q_size = safe_m * safe_mc * sizeof(T)
        tau_size = safe_min_m_n * sizeof(T)

        lda = m if m else fint(1)  # LDA must be >= 1
        mem_buff = cobj(q_size + tau_size + a_size)

        q = mem_buff
        tau = q + q_size
        a = tau + tau_size

        self.M = m
        self.MC = mc
        self.MN = min_m_n
        self.A = Ptr[T](a)
        self.Q = Ptr[T](q)
        self.TAU = Ptr[T](tau)
        self.LDA = lda

        # work size query: LWORK = -1 stores the optimal size in WORK[0]
        work_size_query = T()

        self.WORK = __ptr__(work_size_query)
        self.LWORK = fint(-1)

        if self.call():
            free(mem_buff)
            raise LinAlgError("error in gqr work size query")

        work_count = cast(work_size_query, fint)

        # LWORK = max(1, n, queried size)
        lw = n if n else fint(1)
        lw = lw if lw > work_count else work_count
        self.LWORK = lw

        work_size = int(lw) * sizeof(T)
        mem_buff2 = cobj(work_size)

        work = mem_buff2
        self.WORK = Ptr[T](work)

    def _init_complex(self, m: fint, n: fint, mc: fint):
        # Same layout as _init_real; the queried size is read from the real
        # part of the complex WORK[0].
        min_m_n = m if m < n else n
        safe_mc = int(mc)
        safe_min_m_n = int(min_m_n)
        safe_m = int(m)
        safe_n = int(n)
        a_size = safe_m * safe_n * sizeof(T)
        q_size = safe_m * safe_mc * sizeof(T)
        tau_size = safe_min_m_n * sizeof(T)

        lda = m if m else fint(1)  # LDA must be >= 1
        mem_buff = cobj(q_size + tau_size + a_size)

        q = mem_buff
        tau = q + q_size
        a = tau + tau_size

        self.M = m
        self.MC = mc
        self.MN = min_m_n
        self.A = Ptr[T](a)
        self.Q = Ptr[T](q)
        self.TAU = Ptr[T](tau)
        self.LDA = lda

        # work size query (LWORK = -1)
        work_size_query = T()

        self.WORK = __ptr__(work_size_query)
        self.LWORK = fint(-1)

        if self.call():
            free(mem_buff)
            raise LinAlgError("error in gqr work size query")

        work_count = cast(work_size_query.real, fint)

        # LWORK = max(1, n, queried size)
        lw = n if n else fint(1)
        lw = lw if lw > work_count else work_count
        self.LWORK = lw

        work_size = int(lw) * sizeof(T)
        mem_buff2 = cobj(work_size)

        work = mem_buff2
        self.WORK = Ptr[T](work)

    def __init__(self, m: fint, n: fint, mc: fint):
        # Allocate buffers to generate an m x mc Q from an m x n factorization;
        # raises LinAlgError if the workspace query fails.
        if T is complex or T is complex64:
            self._init_complex(m, n, mc)
        else:
            self._init_real(m, n, mc)

    def __init__(self, m: fint, n: fint):
        # Reduced-QR convenience overload: mc = min(m, n).
        self.__init__(m, n, m if m < n else n)

    def call(self):
        # Invoke ?orgqr / ?ungqr on Q in place; returns the LAPACK INFO value
        # (callers treat nonzero as failure).
        M = self.M
        MC = self.MC
        MN = self.MN
        A = self.A
        Q = self.Q
        LDA = self.LDA
        TAU = self.TAU
        WORK = self.WORK
        LWORK = self.LWORK
        rv = fint()

        args = (__ptr__(M),
                __ptr__(MC),
                __ptr__(MN),
                Q.as_byte(),
                __ptr__(LDA),
                TAU.as_byte(),
                WORK.as_byte(),
                __ptr__(LWORK),
                __ptr__(rv))

        if T is float:
            dorgqr_(*args)
        elif T is float32:
            sorgqr_(*args)
        elif T is complex:
            zungqr_(*args)
        elif T is complex64:
            cungqr_(*args)
        else:
            # this class drives ?orgqr/?ungqr, not ?geqrf
            compile_error("[internal error] bad dtype for gqr")

        return rv

    def release(self):
        # Free the main buffer (which starts at Q) and the WORK buffer.
        free(self.Q)
        free(self.WORK)
|
|
|
|
def _qr_reduced(a, tau):
    # Build the reduced Q factor, shape a.shape[:-2] + (m, min(m, n)), from
    # the ?geqrf output left in `a` and `tau` by _qr_r_raw, using ?orgqr /
    # ?ungqr. Raises LinAlgError (after the whole batch) if any call fails.
    m, n = _rows_cols(a)
    params = GqrParams[a.dtype](fint(m), fint(n))
    k = min(m, n)

    batch = a.shape[:-2]
    q = empty(batch + (m, k), dtype=a.dtype)
    a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2])
    tau_in = LinearizeData(1, k, 1, tau.strides[-1])
    q_out = LinearizeData(k, m, q.strides[-1], q.strides[-2])
    failed = False

    for idx in multirange(batch):
        src = a._ptr(idx + (0, 0))
        dst = q._ptr(idx + (0, 0))
        a_in.linearize(params.A, src)
        # the routine works on Q in place, so seed Q with the factorized matrix
        a_in.linearize(params.Q, src)
        tau_in.linearize(params.TAU, tau._ptr(idx + (0,)))
        if params.call():
            q_out.nan_matrix(dst)
            failed = True
        else:
            q_out.delinearize(dst, params.Q)

    params.release()

    if failed:
        raise LinAlgError("Incorrect argument found while performing "
                          "QR factorization")
    return q
|
|
|
|
def _qr_complete(a, tau):
    # Build the complete (square) Q factor, shape a.shape[:-2] + (m, m), from
    # the ?geqrf output left in `a` and `tau` by _qr_r_raw, using ?orgqr /
    # ?ungqr. Raises LinAlgError (after the whole batch) if any call fails.
    m, n = _rows_cols(a)
    params = GqrParams[a.dtype](fint(m), fint(n), fint(m))
    k = min(m, n)

    batch = a.shape[:-2]
    q = empty(batch + (m, m), dtype=a.dtype)
    a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2])
    tau_in = LinearizeData(1, k, 1, tau.strides[-1])
    q_out = LinearizeData(m, m, q.strides[-1], q.strides[-2])
    failed = False

    for idx in multirange(batch):
        src = a._ptr(idx + (0, 0))
        dst = q._ptr(idx + (0, 0))
        a_in.linearize(params.A, src)
        # the routine works on Q in place, so seed Q with the factorized matrix
        a_in.linearize(params.Q, src)
        tau_in.linearize(params.TAU, tau._ptr(idx + (0,)))
        if params.call():
            q_out.nan_matrix(dst)
            failed = True
        else:
            q_out.delinearize(dst, params.Q)

    params.release()

    if failed:
        raise LinAlgError("Incorrect argument found while performing "
                          "QR factorization")
    return q
|
|
|
|
@tuple
class QRResult[A]:
    # Result of qr(): the Q and R factors, indexable/unpackable like (Q, R).
    Q: A
    R: A

    def __getitem__(self, idx: Static[int]):
        if idx == 1 or idx == -1:
            return self.R
        elif idx == 0 or idx == -2:
            return self.Q
        else:
            compile_error("tuple ('QRResult') index out of range")
|
|
|
|
def _triu(a):
    # In place, zero the strictly-lower triangle of every trailing (m, n)
    # matrix of `a`. Addressing uses a's byte strides, so views are handled.
    rows = a.shape[-2]
    cols = a.shape[-1]
    row_step = a.strides[-2]
    col_step = a.strides[-1]
    z = zero(a.dtype)
    for idx in multirange(a.shape[:-2]):
        base = a._ptr(idx + (0, 0)).as_byte()
        for i in range(1, rows):
            row = base + i * row_step
            for j in range(min(i, cols)):
                Ptr[a.dtype](row + j * col_step)[0] = z
|
|
|
|
def qr(a, mode: Static[str] = 'reduced'):
    # QR factorization of a matrix (or a stack of matrices).
    #
    # mode: 'reduced'  -> QRResult(Q, R) with Q (..., m, k), R (..., k, n), k = min(m, n)
    #       'complete' -> QRResult(Q, R) with Q (..., m, m), R (..., m, n)
    #       'r'        -> only R (..., k, n)
    #       'raw'      -> (h, tau): ?geqrf's Householder output with the last two
    #                     axes swapped, plus the reflector scalars
    # Raises LinAlgError if a LAPACK call reports an invalid argument.
    if (mode != 'reduced' and mode != 'complete' and
        mode != 'r' and mode != 'raw'):
        compile_error("Unrecognized mode '" + mode + "'")

    a0 = a
    a = _asarray(a)
    # ?geqrf overwrites its input, so work on a private copy. _asarray()
    # already copied unless it returned the input ndarray unchanged.
    if isinstance(a0, ndarray):
        if a0.dtype is a.dtype:
            a = a.copy()

    tau = _qr_r_raw(a)
    m, n = _rows_cols(a)
    mn = min(m, n)

    if mode == 'r':
        r = a[..., :mn, :]
        _triu(r)
        return r

    if mode == 'raw':
        # Swap only the last two axes so stacked inputs keep their batch
        # dimensions in front (a.T would reverse every axis when ndim > 2).
        q = swapaxes(a, -1, -2)
        return q, tau

    # mc = number of columns of Q (m for 'complete' on tall input, else min(m, n))
    if mode == 'complete' and m > n:
        mc = m
        q = _qr_complete(a, tau)
    else:
        mc = mn
        q = _qr_reduced(a, tau)

    r = a[..., :mc, :]
    # only the returned rows need zeroing below the diagonal
    _triu(r)
    return QRResult(q, r)
|
|
|
|
|
|
#################
|
|
# Least Squares #
|
|
#################
|
|
|
|
class GelsdParams:
    # Arguments and workspace for the LAPACK ?gelsd least-squares driver;
    # call() dispatches on T to d/s/z/cgelsd_.
    # T is the element type; B is its real base type (used for S, RCOND, RWORK).
    M: fint
    N: fint
    NRHS: fint         # number of right-hand-side columns
    A: Ptr[T]          # coefficient matrix; start of the main allocation freed by release()
    LDA: fint
    B_: Ptr[T]         # right-hand sides in; the caller reads the solution back from here
    LDB: fint          # max(M, N), at least 1
    S: Ptr[B]          # singular values, min(M, N) entries
    RCOND: Ptr[B]      # assigned by the caller, not by __init__
    RANK: fint         # written back by call()
    WORK: Ptr[T]       # start of the second allocation freed by release()
    LWORK: fint
    RWORK: Ptr[B]      # only passed to the complex routines
    IWORK: Ptr[fint]
    T: type
    B: type

    def _init_real(self, m: fint, n: fint, nrhs: fint):
        # One buffer holds A (m*n) | B (max(m, n)*nrhs) | S (min(m, n)); a
        # second buffer (sized by the workspace query) holds WORK | IWORK.
        min_m_n = m if m < n else n
        max_m_n = m if m > n else n
        safe_min_m_n = int(min_m_n)
        safe_max_m_n = int(max_m_n)
        safe_m = int(m)
        safe_n = int(n)
        safe_nrhs = int(nrhs)

        a_size = safe_m * safe_n * sizeof(T)
        b_size = safe_max_m_n * safe_nrhs * sizeof(T)
        s_size = safe_min_m_n * sizeof(T)

        lda = m if m else fint(1)              # leading dimensions must be >= 1
        ldb = max_m_n if max_m_n else fint(1)

        msize = a_size + b_size + s_size
        mem_buff = cobj(msize if msize else 1)  # avoid a zero-byte allocation

        a = mem_buff
        b = a + a_size
        s = b + b_size

        self.M = m
        self.N = n
        self.NRHS = nrhs
        self.A = Ptr[T](a)
        self.B_ = Ptr[T](b)
        self.S = Ptr[B](s)
        self.LDA = lda
        self.LDB = ldb

        # work size query: with LWORK = -1, ?gelsd reports the optimal WORK
        # size in WORK[0] and the IWORK size in IWORK[0].
        # NOTE(review): self.RCOND has not been assigned yet when call() runs here.
        work_size_query = T()
        iwork_size_query = fint()

        self.WORK = __ptr__(work_size_query)
        self.IWORK = __ptr__(iwork_size_query)
        self.RWORK = Ptr[B]()
        self.LWORK = fint(-1)

        if self.call():
            free(mem_buff)
            raise LinAlgError("error in gelsd work size query")

        work_count = cast(work_size_query, fint)
        work_size = int(work_count) * sizeof(T)
        iwork_size = int(iwork_size_query) * sizeof(fint)

        mem_buff2 = cobj(work_size + iwork_size)
        work = mem_buff2
        iwork = work + work_size

        self.WORK = Ptr[T](work)
        self.RWORK = Ptr[B]()
        self.IWORK = Ptr[fint](iwork)
        self.LWORK = work_count

    def _init_complex(self, m: fint, n: fint, nrhs: fint):
        # Same first buffer as _init_real; the second buffer holds
        # WORK | RWORK | IWORK, each sized by the workspace query.
        min_m_n = m if m < n else n
        max_m_n = m if m > n else n
        safe_min_m_n = int(min_m_n)
        safe_max_m_n = int(max_m_n)
        safe_m = int(m)
        safe_n = int(n)
        safe_nrhs = int(nrhs)

        a_size = safe_m * safe_n * sizeof(T)
        b_size = safe_max_m_n * safe_nrhs * sizeof(T)
        s_size = safe_min_m_n * sizeof(T)

        lda = m if m else fint(1)              # leading dimensions must be >= 1
        ldb = max_m_n if max_m_n else fint(1)

        msize = a_size + b_size + s_size
        mem_buff = cobj(msize if msize else 1)  # avoid a zero-byte allocation

        a = mem_buff
        b = a + a_size
        s = b + b_size

        self.M = m
        self.N = n
        self.NRHS = nrhs
        self.A = Ptr[T](a)
        self.B_ = Ptr[T](b)
        self.S = Ptr[B](s)
        self.LDA = lda
        self.LDB = ldb

        # work size query (LWORK = -1): sizes come back in WORK[0], RWORK[0], IWORK[0]
        work_size_query = T()
        rwork_size_query = B()
        iwork_size_query = fint()

        self.WORK = __ptr__(work_size_query)
        self.IWORK = __ptr__(iwork_size_query)
        self.RWORK = __ptr__(rwork_size_query)
        self.LWORK = fint(-1)

        if self.call():
            free(mem_buff)
            raise LinAlgError("error in gelsd work size query")

        work_count = cast(work_size_query.real, fint)
        work_size = int(work_count) * sizeof(T)
        rwork_size = cast(rwork_size_query, int) * sizeof(B)
        iwork_size = int(iwork_size_query) * sizeof(fint)

        mem_buff2 = cobj(work_size + rwork_size + iwork_size)
        work = mem_buff2
        rwork = work + work_size
        iwork = rwork + rwork_size

        self.WORK = Ptr[T](work)
        self.RWORK = Ptr[B](rwork)
        self.IWORK = Ptr[fint](iwork)
        self.LWORK = work_count

    def __init__(self, m: fint, n: fint, nrhs: fint):
        # Allocate buffers for an m x n system with nrhs right-hand sides;
        # raises LinAlgError if the workspace query fails.
        if T is complex or T is complex64:
            self._init_complex(m, n, nrhs)
        else:
            self._init_real(m, n, nrhs)

    def call(self):
        # Invoke the ?gelsd routine matching T, store the reported rank in
        # self.RANK, and return the INFO value (callers treat nonzero as failure).
        M = self.M
        N = self.N
        NRHS = self.NRHS
        A = self.A
        LDA = self.LDA
        B = self.B_
        LDB = self.LDB
        S = self.S
        RCOND = self.RCOND
        RANK = self.RANK
        WORK = self.WORK
        LWORK = self.LWORK
        RWORK = self.RWORK
        IWORK = self.IWORK
        rv = fint()

        # real routines take no RWORK argument
        args_real = (__ptr__(M),
                     __ptr__(N),
                     __ptr__(NRHS),
                     A.as_byte(),
                     __ptr__(LDA),
                     B.as_byte(),
                     __ptr__(LDB),
                     S.as_byte(),
                     RCOND.as_byte(),
                     __ptr__(RANK),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     IWORK,
                     __ptr__(rv))

        # complex routines additionally take RWORK before IWORK
        args_cplx = (__ptr__(M),
                     __ptr__(N),
                     __ptr__(NRHS),
                     A.as_byte(),
                     __ptr__(LDA),
                     B.as_byte(),
                     __ptr__(LDB),
                     S.as_byte(),
                     RCOND.as_byte(),
                     __ptr__(RANK),
                     WORK.as_byte(),
                     __ptr__(LWORK),
                     RWORK.as_byte(),
                     IWORK,
                     __ptr__(rv))

        if T is float:
            dgelsd_(*args_real)
        elif T is float32:
            sgelsd_(*args_real)
        elif T is complex:
            zgelsd_(*args_cplx)
        elif T is complex64:
            cgelsd_(*args_cplx)
        else:
            compile_error("[internal error] bad dtype for gelsd")

        # RANK was passed by pointer to a local copy; publish it
        self.RANK = RANK

        return rv

    def release(self):
        # Free the first buffer (starts at A) and the second (starts at WORK).
        free(self.A)
        free(self.WORK)
|
|
|
|
def _abs2(p: Ptr[T], n: int, T: type):
    # Sum of squared magnitudes |p[i]|^2 for i in [0, n), returned in T's
    # real base type.
    B = type(_basetype(T))
    total = B()
    if T is complex or T is complex64:
        for i in range(n):
            z = p[i]
            total += z.real*z.real + z.imag*z.imag
    else:
        for i in range(n):
            v = p[i]
            total += v * v
    return total
|
|
|
|
def lstsq(a, b, rcond = None):
    # Least-squares solution of a @ x = b via LAPACK ?gelsd (GelsdParams).
    # `a` must be 2-D and `b` 1-D or 2-D (enforced at compile time).
    # Returns (x, residuals, rank, s):
    #   x         -- solution; shape (n,) if b is 1-D, else (n, nrhs)
    #   residuals -- per-column squared residual norms; empty unless
    #                m > n and rank == n (also empty when nrhs == 0)
    #   rank      -- rank reported by ?gelsd
    #   s         -- singular values of a (min(m, n) entries)
    # rcond defaults to eps * max(m, n). Raises LinAlgError on a row-count
    # mismatch or if the call fails ("SVD did not converge ...").
    dtype = type(
        coerce(
            type(asarray(a).data[0]),
            type(asarray(b).data[0])))
    a = _asarray(a, dtype=dtype)
    b = _asarray(b, dtype=dtype)

    if a.ndim != 2:
        compile_error("lstsq argument 'a' must be 2-dimensional")

    # work with b as a 2-D (m, nrhs) array
    if b.ndim == 1:
        b2 = b.reshape(b.shape[0], 1)
    elif b.ndim == 2:
        b2 = b
    else:
        compile_error("lstsq argument 'b' must be 1- or 2-dimensional")

    B = type(_basetype(a.dtype))
    m, n = a.shape
    m2, nrhs = b2.shape

    if m != m2:
        raise LinAlgError('Incompatible dimensions')

    if rcond is None:
        rc = eps(B) * cast(max(m, n), B)
    else:
        rc = cast(rcond, B)

    # with zero RHS columns, solve against a single zero column instead;
    # x and r are trimmed back to zero columns further below
    if nrhs == 0:
        b2 = ndarray((m, 1), Ptr[b2.dtype](m))
        str.memset(b2.data.as_byte(), byte(0), b2.nbytes)
        safe_nrhs = 1
    else:
        safe_nrhs = nrhs

    excess = m - n
    params = GelsdParams[a.dtype, B](fint(m), fint(n), fint(safe_nrhs))

    x = empty((n, safe_nrhs), dtype=a.dtype)
    r = empty(safe_nrhs, dtype=B)
    s = empty(min(m, n), dtype=B)

    # B and the solution share a buffer with max(m, n) rows per column
    a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2])
    b_in = LinearizeData(safe_nrhs, m, b2.strides[-1], b2.strides[-2], max(m, n))
    x_out = LinearizeData(safe_nrhs, n, x.strides[-1], x.strides[-2], max(m, n))
    r_out = LinearizeData(1, safe_nrhs, 1, r.strides[-1])
    s_out = LinearizeData(1, min(m, n), 1, s.strides[-1])

    a_ptr = a.data
    b_ptr = b2.data
    x_ptr = x.data
    r_ptr = r.data
    s_ptr = s.data

    a_in.linearize(params.A, a_ptr)
    b_in.linearize(params.B_, b_ptr)
    params.RCOND = __ptr__(rc)
    not_ok = params.call()
    error_occured = False

    if not not_ok:
        x_out.delinearize(x_ptr, params.B_)
        rank = int(params.RANK)
        s_out.delinearize(s_ptr, params.S)

        # for tall, full-rank problems, rows n..m-1 of each output column hold
        # the residual components; sum their squared magnitudes
        if excess >= 0 and int(params.RANK) == n:
            components = params.B_ + n
            for i in range(safe_nrhs):
                vector = components + i*m
                r_ptr[i] = _abs2(vector, excess)
        else:
            r_out.nan_matrix(r_ptr)
    else:
        x_out.nan_matrix(x_ptr)
        r_out.nan_matrix(r_ptr)
        rank = -1
        s_out.nan_matrix(s_ptr)
        error_occured = True

    params.release()

    if error_occured:
        raise LinAlgError("SVD did not converge in Linear Least Squares")

    # an empty system has the zero solution
    if m == 0:
        str.memset(x.data.as_byte(), byte(0), x.nbytes)

    # drop the placeholder RHS column introduced above
    if nrhs == 0:
        free(x.data)
        free(r.data)
        x = ndarray((x.shape[0], 0), Ptr[x.dtype]())
        r = ndarray((0,), Ptr[r.dtype]())

    if b.ndim == 1:
        x1 = x.reshape(x.shape[0])
    else:
        x1 = x

    # residuals are only reported for full-rank, strictly tall systems
    if r.size and (rank != n or m <= n):
        free(r.data)
        r = ndarray((0,), Ptr[r.dtype]())

    return x1, r, rank, s
|
|
|
|
|
|
###################
|
|
# Matrix Multiply #
|
|
###################
|
|
|
|
def _dot_noblas(_ip1: Ptr[T1], is1: int, _ip2: Ptr[T2], is2: int, op: Ptr[T3], n: int,
                T1: type, T2: type, T3: type):
    # Plain-loop dot product: op[0] = sum_i x[i] * y[i], with both operands
    # cast to T3. is1/is2 are byte strides between consecutive elements.
    acc = zero(T3)
    p1 = _ip1.as_byte()
    p2 = _ip2.as_byte()
    for _ in range(n):
        acc += cast(Ptr[T1](p1)[0], T3) * cast(Ptr[T2](p2)[0], T3)
        p1 += is1
        p2 += is2
    op[0] = acc
|
|
|
|
_CBLAS_INT_MAX: Static[int] = 2147483647  # largest 32-bit signed int (0x7fffffff)
|
|
|
|
def _blas_stride(stride: int, itemsize: int):
    # Convert a byte stride into a BLAS element increment. Returns 0 (meaning
    # "not usable with BLAS") for non-positive strides, strides that are not a
    # whole number of elements, or increments above _CBLAS_INT_MAX.
    if stride <= 0 or stride % itemsize != 0:
        return 0
    units = stride // itemsize
    return units if units <= _CBLAS_INT_MAX else 0
|
|
|
|
def _dot(ip1: Ptr[T1], is1: int, ip2: Ptr[T2], is2: int, op: Ptr[T3], n: int,
         T1: type, T2: type, T3: type):
    # Dot product of two strided vectors (byte strides is1/is2) into op[0].
    # Uses CBLAS ?dot / ?dotu_sub when both strides map to BLAS increments and
    # all three types are the same float/complex type; otherwise _dot_noblas.
    inc1 = _blas_stride(is1, sizeof(T1))
    inc2 = _blas_stride(is2, sizeof(T2))

    if inc1 == 0 or inc2 == 0:
        _dot_noblas(ip1, is1, ip2, is2, op, n)
        return

    blas_args = (fint(n), ip1.as_byte(), fint(inc1), ip2.as_byte(), fint(inc2))
    if T1 is float and T2 is float and T3 is float:
        op[0] = cblas_ddot(*blas_args)
    elif T1 is float32 and T2 is float32 and T3 is float32:
        op[0] = cblas_sdot(*blas_args)
    elif T1 is complex and T2 is complex and T3 is complex:
        cblas_zdotu_sub(*blas_args, op)
    elif T1 is complex64 and T2 is complex64 and T3 is complex64:
        cblas_cdotu_sub(*blas_args, op)
    else:
        _dot_noblas(ip1, is1, ip2, is2, op, n)
|
|
|
|
BLAS_MAXSIZE: Static[int] = 2147483646 # INT_MAX - 1 (2^31 - 2): upper bound for dimensions/strides handed to BLAS
|
|
|
|
def _is_blasable2d(byte_stride1: int,
                   byte_stride2: int,
                   d1: int,
                   d2: int,
                   itemsize: int):
    # True when a 2-D operand with byte strides (byte_stride1, byte_stride2)
    # and shape (d1, d2) can go to BLAS as a row-major matrix. The inner stride
    # must be exactly one element. The outer stride must be a whole number of
    # elements, at least d2, and at most BLAS_MAXSIZE. d1 is not consulted.
    outer = byte_stride1 // itemsize
    if byte_stride2 != itemsize:
        return False
    aligned = byte_stride1 % itemsize == 0
    return aligned and outer >= d2 and outer <= BLAS_MAXSIZE
|
|
|
|
def _gemv(ip1: Ptr[T], is1_m: int, is1_n: int,
          ip2: Ptr[T], is2_n: int,
          op: Ptr[T], op_m: int,
          m: int, n: int, T: type):
    # Matrix-vector product via CBLAS ?gemv.
    #   A:   (m, n) at ip1 with byte strides (is1_m, is1_n)
    #   x:   length n at ip2 with byte stride is2_n
    #   out: length m at op with byte stride op_m
    # Always calls with CBLAS_TRANS and dimensions (n, m). The layout is
    # column-major when A passes _is_blasable2d, otherwise row-major.
    sz = sizeof(T)
    if _is_blasable2d(is1_m, is1_n, m, n, sz):
        layout = CBLAS_COL_MAJOR
        lda = is1_m // sz
    else:
        layout = CBLAS_ROW_MAJOR
        lda = is1_n // sz

    alpha_val = cast(1, T)
    beta_val = cast(0, T)

    if T is float or T is float32:
        alpha = alpha_val
        beta = beta_val
    else:
        # the complex routines take alpha/beta by pointer
        alpha = __ptr__(alpha_val).as_byte()
        beta = __ptr__(beta_val).as_byte()

    gemv_args = (layout, CBLAS_TRANS, fint(n), fint(m),
                 alpha, ip1.as_byte(), fint(lda), ip2.as_byte(),
                 fint(is2_n // sz), beta, op.as_byte(),
                 fint(op_m // sz))

    if T is float:
        cblas_dgemv(*gemv_args)
    elif T is float32:
        cblas_sgemv(*gemv_args)
    elif T is complex:
        cblas_zgemv(*gemv_args)
    elif T is complex64:
        cblas_cgemv(*gemv_args)
    else:
        compile_error("[internal error] bad input type")
|
|
|
|
def _matmul_matrixmatrix(ip1: Ptr[T], is1_m: int, is1_n: int,
                         ip2: Ptr[T], is2_n: int, is2_p: int,
                         op: Ptr[T], os_m: int, os_p: int,
                         m: int, n: int, p: int, T: type):
    # Matrix product out = A @ B via CBLAS, in row-major order.
    #   A: (m, n) at ip1, byte strides (is1_m, is1_n)
    #   B: (n, p) at ip2, byte strides (is2_n, is2_p)
    #   out: (m, p) at op; ldc is derived from os_m only.
    # NOTE(review): os_p is not used; this presumably assumes the caller
    # passes an output with unit (itemsize) column stride. Confirm in _matmul.
    # Each operand is passed as-is if it is row-major BLAS-compatible,
    # otherwise as a transposed view.
    order = CBLAS_ROW_MAJOR
    itemsize = sizeof(T)
    ldc = os_m // itemsize

    if _is_blasable2d(is1_m, is1_n, m, n, itemsize):
        trans1 = CBLAS_NO_TRANS
        lda = is1_m // itemsize
    else:
        trans1 = CBLAS_TRANS
        lda = is1_n // itemsize

    if _is_blasable2d(is2_n, is2_p, n, p, itemsize):
        trans2 = CBLAS_NO_TRANS
        ldb = is2_n // itemsize
    else:
        trans2 = CBLAS_TRANS
        ldb = is2_p // itemsize

    one_dtype = cast(1, T)
    zero_dtype = cast(0, T)

    # real routines take alpha/beta by value, complex ones by pointer
    if T is float or T is float32:
        one = one_dtype
        zero = zero_dtype
    else:
        one = __ptr__(one_dtype).as_byte()
        zero = __ptr__(zero_dtype).as_byte()

    # B is the transpose of A (same buffer, mirrored strides, opposite transpose
    # flags): use ?syrk, which fills only the upper triangle of the result
    if (ip1 == ip2 and
        m == p and
        is1_m == is2_p and
        is1_n == is2_n and
        trans1 != trans2):

        ld = lda if trans1 == CBLAS_NO_TRANS else ldb
        args = (order, CBLAS_UPPER, trans1, fint(p),
                fint(n), one, ip1, fint(ld),
                zero, op, fint(ldc))

        if T is float:
            cblas_dsyrk(*args)
        elif T is float32:
            cblas_ssyrk(*args)
        elif T is complex:
            cblas_zsyrk(*args)
        elif T is complex64:
            cblas_csyrk(*args)
        else:
            compile_error("[internal error] bad input type")

        # Copy the triangle: mirror the upper triangle into the lower one
        for i in range(p):
            for j in range(i + 1, p):
                op[j * ldc + i] = op[i * ldc + j]
    else:
        # general case: ?gemm
        args = (order, trans1, trans2, fint(m), fint(p), fint(n),
                one, ip1, fint(lda), ip2,
                fint(ldb), zero, op, fint(ldc))

        if T is float:
            cblas_dgemm(*args)
        elif T is float32:
            cblas_sgemm(*args)
        elif T is complex:
            cblas_zgemm(*args)
        elif T is complex64:
            cblas_cgemm(*args)
        else:
            compile_error("[internal error] bad input type")
|
|
|
|
def _matmul_inner_noblas(_ip1: Ptr[T1], is1_m: int, is1_n: int,
                         _ip2: Ptr[T2], is2_n: int, is2_p: int,
                         _op: Ptr[T3], os_m: int, os_p: int,
                         dm: int, dn: int, dp: int, T1: type,
                         T2: type, T3: type):
    """
    Generic strided matrix product op (dm x dp) = ip1 (dm x dn) @ ip2 (dn x dp).

    Pointers are walked as raw bytes, so every stride is a byte stride.
    Inputs may have different element types; products are formed in the
    output type T3.

    - bool output: each entry is set to False, then to True as soon as one
      pair (val1, val2) is both truthy (logical OR of ANDs).
    - other outputs: entries are ACCUMULATED with `+=`, so the caller must
      pass a zero-initialised output.
    """
    # Byte extents of one full pass along n (for rewinding) and along p.
    ib1_n = is1_n * dn
    ib2_n = is2_n * dn
    ib2_p = is2_p * dp
    ob_p = os_p * dp
    ip1 = _ip1.as_byte()
    ip2 = _ip2.as_byte()
    op = _op.as_byte()

    if T3 is bool:
        for m in range(dm):
            for p in range(dp):
                ip1tmp = ip1
                ip2tmp = ip2
                Ptr[T3](op)[0] = False
                for n in range(dn):
                    val1 = Ptr[T1](ip1tmp)[0]
                    val2 = Ptr[T2](ip2tmp)[0]
                    if val1 and val2:
                        Ptr[T3](op)[0] = True
                        break
                    ip2tmp += is2_n
                    ip1tmp += is1_n
                op += os_p
                ip2 += is2_p
            # rewind to column 0 and advance to the next row
            op -= ob_p
            ip2 -= ib2_p
            ip1 += is1_m
            op += os_m
    else:
        for m in range(dm):
            for p in range(dp):
                for n in range(dn):
                    val1 = cast(Ptr[T1](ip1)[0], T3)
                    val2 = cast(Ptr[T2](ip2)[0], T3)
                    Ptr[T3](op)[0] += val1 * val2
                    ip2 += is2_n
                    ip1 += is1_n
                # undo the walk along n before moving to the next column
                ip1 -= ib1_n
                ip2 -= ib2_n
                op += os_p
                ip2 += is2_p
            # rewind to column 0 and advance to the next row
            op -= ob_p
            ip2 -= ib2_p
            ip1 += is1_m
            op += os_m
|
|
|
|
def _matmul(A, B, C):
    """
    Batched matrix product C[..., :, :] = A[..., :, :] @ B[..., :, :].

    For every batch index of C, pick a kernel:
    - if A, B and C do not all share one float/float32/complex/complex64
      dtype: the generic loop _matmul_inner_noblas;
    - too large for BLAS or any zero dimension: generic loop;
    - a dimension equal to 1: _dot for a 1x1 result, _gemv for
      vector @ matrix / matrix @ vector with blasable layouts, otherwise
      the generic loop;
    - general case: _matmul_matrixmatrix when the layouts are blasable
      (using (A @ B).T == B.T @ A.T for an F-ordered output), otherwise
      the generic loop.
    The generic loop accumulates into C, so C must be zero-initialised.
    """
    # Caller ensures arrays are broadcasted and of correct shape.
    # A -> (m, n)
    # B -> (n, p)
    # C -> (m, p)
    dm = A.shape[-2]
    dn = A.shape[-1]
    dp = B.shape[-1]

    # byte strides of the trailing two axes
    is1_m = A.strides[-2]
    is1_n = A.strides[-1]
    is2_n = B.strides[-2]
    is2_p = B.strides[-1]
    os_m = C.strides[-2]
    os_p = C.strides[-1]

    sz = sizeof(A.dtype)
    special_case = (dm == 1 or dn == 1 or dp == 1)
    any_zero_dim = (dm == 0 or dn == 0 or dp == 0)
    scalar_out = (dm == 1 and dp == 1)
    scalar_vec = (dn == 1 and (dp == 1 or dm == 1))
    too_big_for_blas = (dm > BLAS_MAXSIZE or dn > BLAS_MAXSIZE or dp > BLAS_MAXSIZE)
    i1_c_blasable = _is_blasable2d(is1_m, is1_n, dm, dn, sz)
    i2_c_blasable = _is_blasable2d(is2_n, is2_p, dn, dp, sz)
    i1_f_blasable = _is_blasable2d(is1_n, is1_m, dn, dm, sz)
    i2_f_blasable = _is_blasable2d(is2_p, is2_n, dp, dn, sz)
    i1blasable = i1_c_blasable or i1_f_blasable
    i2blasable = i2_c_blasable or i2_f_blasable
    o_c_blasable = _is_blasable2d(os_m, os_p, dm, dp, sz)
    o_f_blasable = _is_blasable2d(os_p, os_m, dp, dm, sz)
    vector_matrix = (dm == 1 and i2blasable and
                     _is_blasable2d(is1_n, sz, dn, 1, sz))
    matrix_vector = (dp == 1 and i1blasable and
                     _is_blasable2d(is2_n, sz, dn, 1, sz))

    for idx in multirange(C.shape[:-2]):
        # A and B may be broadcast along the batch axes
        ip1 = A._ptr(idx + (0, 0), broadcast=(A.ndim > 2))
        ip2 = B._ptr(idx + (0, 0), broadcast=(B.ndim > 2))
        op = C._ptr(idx + (0, 0))

        if not ((A.dtype is B.dtype and B.dtype is C.dtype) and
                (A.dtype is float or A.dtype is float32 or
                 A.dtype is complex or A.dtype is complex64)):
            _matmul_inner_noblas(ip1, is1_m, is1_n,
                                 ip2, is2_n, is2_p,
                                 op, os_m, os_p,
                                 dm, dn, dp)
        else:
            if too_big_for_blas or any_zero_dim:
                _matmul_inner_noblas(ip1, is1_m, is1_n,
                                     ip2, is2_n, is2_p,
                                     op, os_m, os_p,
                                     dm, dn, dp)
            elif special_case:
                if scalar_out:
                    # row @ column -> 1x1 result
                    _dot(ip1, is1_n, ip2, is2_n, op, dn)
                elif scalar_vec:
                    _matmul_inner_noblas(ip1, is1_m, is1_n,
                                         ip2, is2_n, is2_p,
                                         op, os_m, os_p,
                                         dm, dn, dp)
                elif vector_matrix:
                    # vector @ matrix: swap the roles of the operands
                    _gemv(ip2, is2_p, is2_n, ip1, is1_n, op, os_p, dp, dn)
                elif matrix_vector:
                    _gemv(ip1, is1_m, is1_n, ip2, is2_n, op, os_m, dm, dn)
                else:
                    _matmul_inner_noblas(ip1, is1_m, is1_n,
                                         ip2, is2_n, is2_p,
                                         op, os_m, os_p,
                                         dm, dn, dp)
            else:
                if i1blasable and i2blasable and o_c_blasable:
                    _matmul_matrixmatrix(ip1, is1_m, is1_n,
                                         ip2, is2_n, is2_p,
                                         op, os_m, os_p,
                                         dm, dn, dp)
                elif i1blasable and i2blasable and o_f_blasable:
                    # F-ordered output: compute C.T = B.T @ A.T instead
                    _matmul_matrixmatrix(ip2, is2_p, is2_n,
                                         ip1, is1_n, is1_m,
                                         op, os_p, os_m,
                                         dp, dn, dm)
                else:
                    _matmul_inner_noblas(ip1, is1_m, is1_n,
                                         ip2, is2_n, is2_p,
                                         op, os_m, os_p,
                                         dm, dn, dp)
|
|
|
|
def _check_out(out, ans_shape, dtype: type):
    """
    Validate a user-supplied `out` array against the expected result.

    A wrong dtype or a wrong number of dimensions is reported at compile
    time (compile_error); a wrong shape can only be detected at runtime and
    raises ValueError.
    """
    if out.dtype is not dtype:
        compile_error("'out' array has incorrect type '" + out.dtype.__name__ + "'")

    if staticlen(out.shape) != staticlen(ans_shape):
        compile_error("'out' array has incorrect number of dimensions")

    if out.shape != ans_shape:
        raise ValueError("'out' array has incorrect shape")
|
|
|
|
def _matmul_ufunc(x1, x2, out = None, dtype: type = NoneType):
    """
    Implementation of `matmul` with optional `out` and result `dtype`.

    Non-ndarray inputs are converted with asarray. The default dtype is the
    coercion of both input dtypes. 0-d operands are compile errors. Two 1-d
    operands give a scalar. A 1-d first (second) operand is treated as a
    row (column) vector and that axis is dropped from the result. Batch axes
    are broadcast. ValueError is raised on mismatched inner dimensions.
    A supplied `out` is validated with _check_out, zero-filled, and returned.
    """
    if not isinstance(x1, ndarray) or not isinstance(x2, ndarray):
        y1 = asarray(x1)
        y2 = asarray(x2)
        return _matmul_ufunc(y1, y2, out=out, dtype=dtype)

    T1 = x1.dtype
    T2 = x2.dtype
    x1d: Static[int] = x1.ndim
    x2d: Static[int] = x2.ndim

    if dtype is NoneType:
        return _matmul_ufunc(x1, x2, out=out, dtype=type(coerce(T1, T2)))

    if x1d == 0:
        compile_error("first argument is 0-dimensional; must be at least 1-d")

    if x2d == 0:
        compile_error("second argument is 0-dimensional; must be at least 1-d")

    # vector @ vector -> scalar
    if x1d == 1 and x2d == 1:
        if x1.shape != x2.shape:
            raise ValueError("matmul: vectors have different lengths")
        op = zero(dtype)
        _dot(x1.data, x1.strides[0], x2.data, x2.strides[0], __ptr__(op), x1.size)
        return op

    # promote 1-d operands to row / column matrices
    if x1d == 1:
        y1 = x1.reshape((1, -1))
    else:
        y1 = x1

    if x2d == 1:
        y2 = x2.reshape((-1, 1))
    else:
        y2 = x2

    y1s = y1.shape
    y2s = y2.shape
    y1d: Static[int] = y1.ndim  # NOTE(review): unused
    y2d: Static[int] = y2.ndim  # NOTE(review): unused

    base1s = y1s[:-2]
    base2s = y2s[:-2]
    mat1s = y1s[-2:]
    mat2s = y2s[-2:]

    m = mat1s[0]
    k = mat1s[1]
    n = mat2s[1]

    if k != mat2s[0]:
        raise ValueError("matmul: last dimension of first argument does not "
                         "match second-to-last dimension of second argument")

    ans_base = broadcast_shapes(base1s, base2s)
    if x1d == 1 and x2d == 1:
        ans_shape = ans_base
    elif x1d == 1:
        ans_shape = ans_base + (mat2s[1],)
    elif x2d == 1:
        ans_shape = ans_base + (mat1s[0],)
    else:
        ans_shape = ans_base + (mat1s[0], mat2s[1])

    # _matmul may accumulate into the output, so it must start zeroed
    if out is None:
        ans = zeros(ans_shape, dtype=dtype)
    elif isinstance(out, ndarray):
        _check_out(out, ans_shape, dtype)
        ans = out
        ans.map(lambda _: cast(0, ans.dtype), inplace=True)
    else:
        compile_error("'out' must be an ndarray")

    # Reshape operands/output so that all have 2 trailing matrix axes.
    # NOTE(review): in the two 1-d branches, (1,) * (x?d - 1) + shape gives the
    # reshaped operand one more leading axis than the other operand; this
    # relies on _ptr(..., broadcast=True) inside _matmul tolerating it -- confirm.
    if x1d == 1 and x2d == 1:
        # unreachable: the vector @ vector case returned above
        _matmul(y1, y2, ans.reshape(1, 1))
    elif x1d == 1:
        _matmul(y1.reshape((1,) * (x2d - 1) + y1.shape), y2, ans.reshape(ans.shape + (1,)))
    elif x2d == 1:
        _matmul(y1, y2.reshape((1,) * (x1d - 1) + y2.shape), ans.reshape(ans.shape + (1,)))
    elif x1d > x2d:
        _matmul(y1, y2.reshape((1,) * (x1d - x2d) + y2.shape), ans)
    elif x1d < x2d:
        _matmul(y1.reshape((1,) * (x2d - x1d) + y1.shape), y2, ans)
    else:
        _matmul(y1, y2, ans)

    if out is not None:
        return out
    else:
        if ans.ndim == 0:
            return ans.data[0]
        else:
            return ans
|
|
|
|
def matmul(x1, x2):
    """
    Matrix product of `x1` and `x2` with matmul broadcasting rules.

    Thin wrapper over _matmul_ufunc with no `out` array and the default
    (coerced) result dtype.
    """
    result = _matmul_ufunc(x1, x2, out=None)
    return result
|
|
|
|
def dot(a, b, out = None):
    """
    Dot product of two arrays.

    - If either operand is 0-d this is ``multiply(a, b, out=out)``.
    - If both operands have at most 2 dimensions this is matmul.
    - If `b` is 1-d (and `a` has more than 2 dimensions) it is a sum product
      over the last axis of `a` and `b`.
    - Otherwise it is a sum product over the last axis of `a` and the
      second-to-last axis of `b`:
      ``dot(a, b)[i, j, k, m] = sum(a[i, j, :] * b[k, :, m])``.

    Raises ValueError if the contracted dimensions differ, or (via
    _check_out) if `out` has the wrong shape. When `out` is given, its
    previous contents are overwritten, not accumulated into.
    """
    x1 = asarray(a)
    x2 = asarray(b)
    T1 = x1.dtype
    T2 = x2.dtype
    x1d: Static[int] = staticlen(x1.shape)
    x2d: Static[int] = staticlen(x2.shape)

    if x1d == 0 or x2d == 0:
        return multiply(a, b, out=out)

    if x1d <= 2 and x2d <= 2:
        return _matmul_ufunc(a, b, out=out)

    # most general case: at least one operand has more than 2 dimensions
    x1s = x1.shape
    x2s = x2.shape
    k = x1s[-1]
    dtype = type(coerce(T1, T2))

    if x2d == 1:
        # contract a's last axis with b's only axis
        if k != x2s[0]:
            raise ValueError("dot: last dimension of first argument does not "
                             "match length of second argument")
        ans_shape = x1s[:-1]
    else:
        if k != x2s[-2]:
            raise ValueError("dot: last dimension of first argument does not "
                             "match second-to-last dimension of second argument")
        ans_shape = x1s[:-1] + x2s[:-2] + (x2s[-1],)

    if out is None:
        ans = zeros(ans_shape, dtype=dtype)
    elif isinstance(out, ndarray):
        _check_out(out, ans_shape, dtype)
        ans = out
    else:
        compile_error("'out' must be an ndarray")

    for idx in multirange(ans_shape):
        # idx = (a's leading indices) + (b's indices without the contracted axis)
        acc = zero(dtype)
        for i1 in range(k):
            x1_idx = idx[:x1d-1] + (i1,)
            if x2d == 1:
                x2_idx = (i1,)
            else:
                x2_idx = idx[x1d-1:-1] + (i1, idx[-1])
            acc += cast(x1._ptr(x1_idx)[0], dtype) * cast(x2._ptr(x2_idx)[0], dtype)
        # Store (rather than accumulate into) the destination so that a
        # user-supplied `out` with arbitrary prior contents is correct.
        ans._ptr(idx)[0] = acc

    return ans
|
|
|
|
|
|
##################
|
|
# Other Routines #
|
|
##################
|
|
|
|
def _matrix_power(a, n: int):
    """
    Raise the square matrix (or stack of matrices) `a` to the integer power `n`.

    n == 0 returns the identity (same dtype as `a`). A negative `n` inverts
    `a` first, which is only possible when inv(a) has the same type as `a`;
    otherwise ValueError is raised (e.g. integral matrices). Positive powers
    use binary exponentiation via repeated squaring.
    """
    a = asarray(a)
    dtype = a.dtype
    s = a.shape
    m = _square_rows(a)

    if n == 0:
        if staticlen(s) == 2:
            return eye(m, dtype=dtype)
        else:
            b = empty_like(a)
            one = cast(1, dtype)
            for idx in multirange(s[:-2]):
                for i in range(m):
                    for j in range(m):
                        # Index through _ptr so the fill is correct whatever
                        # memory layout empty_like chose for `b`.
                        b._ptr(idx + (i, j))[0] = one if i == j else zero(dtype)
            return b
    elif n < 0:
        if type(inv(a)) is type(a):
            a = inv(a)
            n = abs(n)
        else:
            raise ValueError(
                "cannot take integral matrix to non-constant negative "
                "power; use 'matrix_power(inv(a), abs(n))' instead"
            )

    # small exponents: multiply directly
    if n == 1:
        return a
    elif n == 2:
        return matmul(a, a)
    elif n == 3:
        return matmul(matmul(a, a), a)

    # binary exponentiation: z runs through a, a^2, a^4, ...; result collects
    # the factors selected by the set bits of n
    result = a
    z = a
    first = True
    have_result = False

    while n > 0:
        if first:
            first = False
        else:
            z = matmul(z, z)

        n, bit = divmod(n, 2)
        if bit:
            if have_result:
                result = matmul(result, z)
            else:
                result = z
                have_result = True

    return result
|
|
|
|
def matrix_power(a, n: int):
    """
    Raise a square matrix (or stack of matrices) to a runtime integer power.

    Delegates to _matrix_power; negative powers require inv(a) to keep the
    type of `a`.
    """
    powered = _matrix_power(a, n)
    return powered
|
|
|
|
@overload
def matrix_power(a, n: Static[int]):
    """
    matrix_power with a compile-time exponent.

    For negative `n` the inverse is computed first and then raised to -n,
    so, unlike the runtime-int overload, this path does not require inv(a)
    to have the same type as `a`.
    """
    if n >= 0:
        return _matrix_power(a, n)
    else:
        return _matrix_power(inv(a), -n)
|
|
|
|
def _multi_dot_three(A, B, C, out=None):
    """
    Multiply three 2-d matrices, choosing the cheaper association.

    The two candidate orders are compared by their number of scalar
    multiplications; ties go to A(BC). `out` receives the final product.
    """
    rows_a, inner_ab = A.shape
    inner_bc, cols_c = C.shape

    # (AB)C: rows_a*inner_ab*inner_bc + rows_a*inner_bc*cols_c
    cost_left = rows_a * inner_bc * (inner_ab + cols_c)
    # A(BC): inner_ab*inner_bc*cols_c + rows_a*inner_ab*cols_c
    cost_right = inner_ab * cols_c * (rows_a + inner_bc)

    if cost_left >= cost_right:
        return dot(A, dot(B, C), out=out)
    return dot(dot(A, B), C, out=out)
|
|
|
|
def _multi_dot(arrays, order, i, j, out=None):
    """
    Multiply arrays[i..j] (inclusive) using the split table `order`.

    `order[i, j]` is the index after which the chain i..j is split, as
    produced by _multi_dot_matrix_chain_order. `out` is only passed to the
    outermost product.
    """
    if i == j:
        # a single matrix: nothing to multiply (`out` is expected to be None)
        return arrays[i]
    split = order[i, j]
    left = _multi_dot(arrays, order, i, split)
    right = _multi_dot(arrays, order, split + 1, j)
    return dot(left, right, out=out)
|
|
|
|
def _multi_dot_matrix_chain_order(arrays, return_costs: Static[int] = False):
    """
    Dynamic-programming matrix-chain ordering (Cormen et al., 0-based).

    Returns the (n, n) int table whose entry [i, j] is the split index for
    the sub-chain arrays[i..j]; if `return_costs` is set, also returns the
    (n, n) float table of minimal multiplication counts as (split, cost).
    """
    count = len(arrays)
    # arrays[i] has shape (dims[i], dims[i + 1])
    dims = [arr.shape[0] for arr in arrays] + [arrays[-1].shape[1]]
    cost = zeros((count, count), dtype=float)
    split = empty((count, count), dtype=int)

    for span in range(1, count):
        for lo in range(count - span):
            hi = lo + span
            cost[lo, hi] = inf(float)
            for mid in range(lo, hi):
                candidate = (cost[lo, mid] + cost[mid + 1, hi]
                             + dims[lo] * dims[mid + 1] * dims[hi + 1])
                if candidate < cost[lo, hi]:
                    cost[lo, hi] = candidate
                    split[lo, hi] = mid  # Note that Cormen uses 1-based index

    if return_costs:
        return (split, cost)
    else:
        return split
|
|
|
|
def multi_dot(arrays, out = None):
    """
    Multiply a sequence (tuple or list) of arrays with the cheapest ordering.

    Fewer than two arrays is an error: a compile error for tuples, a
    ValueError for lists. Two arrays are passed straight to dot. Otherwise,
    a 1-d first array is treated as a row vector and a 1-d last array as a
    column vector. Every (adjusted) array must be 2-d (compile error
    otherwise). Three arrays use the cost comparison in _multi_dot_three;
    longer chains use _multi_dot_matrix_chain_order + _multi_dot. The
    result is reduced back to a scalar / 1-d array when the first and/or
    last input was 1-d.
    """
    n = len(arrays)

    if isinstance(arrays, Tuple):
        if staticlen(arrays) < 2:
            compile_error("Expecting at least two arrays.")

        if staticlen(arrays) == 2:
            return dot(arrays[0], arrays[1], out=out)
    else:
        if n < 2:
            raise ValueError("Expecting at least two arrays.")

        if n == 2:
            return dot(arrays[0], arrays[1], out=out)

    ndim_first: Static[int] = staticlen(arrays[0].shape)
    ndim_last: Static[int] = staticlen(arrays[-1].shape)

    # 1-d first array -> (1, k) row vector
    def fix_first(arr):
        arr = asarray(arr)
        if staticlen(arr.shape) == 1:
            return atleast_2d(arr)
        else:
            return arr

    # 1-d last array -> (k, 1) column vector
    def fix_last(arr):
        arr = asarray(arr)
        if staticlen(arr.shape) == 1:
            return atleast_2d(arr).T
        else:
            return arr

    if isinstance(arrays, Tuple):
        xarrays = (fix_first(arrays[0]),) + tuple(asarray(arr) for arr in arrays[1:-1]) + (fix_last(arrays[-1]),)
    else:
        xarrays = [asarray(arr) for arr in arrays]
        xarrays[0] = fix_first(xarrays[0])
        xarrays[-1] = fix_last(xarrays[-1])

    # NOTE(review): this message says "1-dimensional" even when an array of
    # another non-2 rank (e.g. 3-d) triggers it.
    for arr in xarrays:
        if staticlen(arr.shape) != 2:
            compile_error("1-dimensional array given. Array must be two-dimensional")

    if n == 3:
        result = _multi_dot_three(xarrays[0], xarrays[1], xarrays[2], out=out)
    else:
        order = _multi_dot_matrix_chain_order(xarrays)
        result = _multi_dot(xarrays, order, 0, n - 1, out=out)

    if ndim_first == 1 and ndim_last == 1:
        return result[0, 0]  # scalar
    elif ndim_first == 1 or ndim_last == 1:
        return result.ravel()  # 1-D
    else:
        return result
|
|
|
|
def _vdot_ptr(x: Ptr[T1], y: Ptr[T2], n: int, incx: int, incy: int, T1: type, T2: type):
    """
    Conjugating dot product sum(conj(x[i]) * y[i]) over `n` strided elements.

    `incx` / `incy` are element (not byte) increments and may be negative.
    `x` / `y` address the logically-first element, which for a negative
    increment is presumably the highest address (as with `a.data` of a
    reversed view in vdot) -- the scalar loop below walks from there.
    Matching float / float32 / complex / complex64 operands use BLAS
    (?dot, ?dotc_sub). All other type pairs use a scalar loop in the
    coerced result type, conjugating `x` elements that have `conjugate`.
    """
    # (C)BLAS expects the pointer of a vector with a negative increment to
    # address its lowest-address element and traverses it from the far end.
    # Rebase such pointers so BLAS visits the same logical elements as the
    # scalar loop instead of reading past the buffer.
    bx = x + (n - 1) * incx if incx < 0 else x
    by = y + (n - 1) * incy if incy < 0 else y

    if T1 is float and T2 is float:
        return cblas_ddot(fint(n), bx.as_byte(), fint(incx), by.as_byte(), fint(incy))
    elif T1 is float32 and T2 is float32:
        return cblas_sdot(fint(n), bx.as_byte(), fint(incx), by.as_byte(), fint(incy))
    elif T1 is complex and T2 is complex:
        z = complex()
        cblas_zdotc_sub(fint(n), bx.as_byte(), fint(incx), by.as_byte(), fint(incy), __ptr__(z).as_byte())
        return z
    elif T1 is complex64 and T2 is complex64:
        z = complex64()
        cblas_cdotc_sub(fint(n), bx.as_byte(), fint(incx), by.as_byte(), fint(incy), __ptr__(z).as_byte())
        return z
    else:
        TR = type(coerce(T1, T2))
        z = TR()
        for _ in range(n):
            a = x[0]
            b = y[0]

            if hasattr(a, "conjugate"):
                a = a.conjugate()

            z += cast(a, TR) * cast(b, TR)
            x += incx
            y += incy

        return z
|
|
|
|
def vdot(a, b):
    """
    Dot product of the flattened inputs, conjugating the first argument.

    Both inputs must have the same number of elements (ValueError
    otherwise). Two 1-d inputs and inputs with matching contiguous layouts
    go through _vdot_ptr (BLAS where possible); other layouts are iterated
    element-wise in C order.
    """
    def mismatch():
        raise ValueError("vdot: inputs have different sizes")

    a = asarray(a)
    b = asarray(b)

    if staticlen(a.shape) == 0 and staticlen(b.shape) == 0:
        x = a.data[0]
        y = b.data[0]
        # `cast` needs a type, not a coerced value (cf. the general path below)
        TR = type(coerce(type(x), type(y)))

        if hasattr(x, "conjugate"):
            x = x.conjugate()

        return cast(x, TR) * cast(y, TR)

    n = a.size
    if n != b.size:
        mismatch()

    if staticlen(a.shape) == 1 and staticlen(b.shape) == 1:
        # element strides (possibly negative for reversed views)
        inca = a.strides[0] // a.itemsize
        incb = b.strides[0] // b.itemsize
        return _vdot_ptr(a.data, b.data, n, inca, incb)

    if a._contig_match(b):
        return _vdot_ptr(a.data, b.data, n, 1, 1)

    T1 = a.dtype
    T2 = b.dtype
    TR = type(coerce(T1, T2))
    z = TR()

    for e1, e2 in zip(a.flat, b.flat):
        if hasattr(e1, "conjugate"):
            e1 = e1.conjugate()
        z += cast(e1, TR) * cast(e2, TR)

    return z
|
|
|
|
def outer(a, b, out = None):
    """
    Outer product of the flattened inputs: result[i, j] = a.flat[i] * b.flat[j].

    `out`, if given, is forwarded to multiply.
    """
    column = asarray(a).ravel()[:, None]
    row = asarray(b).ravel()[None, :]
    return multiply(column, row, out)
|
|
|
|
def inner(a, b):
    """
    Inner product over the last axes of `a` and `b` (no conjugation).

    A 0-d operand reduces to multiply, two 1-d operands to dot. Otherwise
    result[ia + ib] = sum_r a[ia + (r,)] * b[ib + (r,)], with shape
    a.shape[:-1] + b.shape[:-1] and the coerced dtype. Raises ValueError
    when the last dimensions differ.
    """
    a = asarray(a)
    b = asarray(b)
    sa = a.shape
    sb = b.shape

    if staticlen(sa) == 0 or staticlen(sb) == 0:
        return multiply(a, b)

    if staticlen(sa) == 1 and staticlen(sb) == 1:
        return dot(a, b)

    length = sa[-1]
    if length != sb[-1]:
        raise ValueError("inner: mismatch in last dimension of inputs")

    lead_a = sa[:-1]
    lead_b = sb[:-1]
    dtype = type(coerce(a.dtype, b.dtype))
    result = empty(lead_a + lead_b, dtype=dtype)

    for ia in multirange(lead_a):
        for ib in multirange(lead_b):
            acc = zero(dtype)
            for r in range(length):
                lhs = cast(a._ptr(ia + (r,))[0], dtype)
                rhs = cast(b._ptr(ib + (r,))[0], dtype)
                acc += lhs * rhs
            result._ptr(ia + ib)[0] = acc

    return result
|
|
|
|
def _tensordot(a, b, axes, ndim: Static[int] = -1):
    """
    Tensor dot product of `a` and `b` over the given axes.

    `axes` is either an int N (last N axes of `a` with first N axes of `b`)
    or a pair (axes_a, axes_b), each an int or a sequence of ints. The
    contracted axes are moved together, both arrays are reshaped to 2-d,
    multiplied with dot, and -- only when the result rank `ndim` is known
    statically (>= 0) -- reshaped to the free axes of `a` followed by those
    of `b`. Otherwise the 2-d product is returned. Raises ValueError
    ("shape-mismatch for sum") if the axis counts or sizes disagree.
    """
    def get_axes(axes):
        if isinstance(axes, int):
            axes_a = list(range(-axes, 0))
            axes_b = list(range(0, axes))
            return axes_a, axes_b

        axes_a, axes_b = axes

        if isinstance(axes_a, int):
            xa = [axes_a]
        else:
            xa = list(axes_a)

        if isinstance(axes_b, int):
            xb = [axes_b]
        else:
            xb = list(axes_b)

        return xa, xb

    axes_a, axes_b = get_axes(axes)
    na = len(axes_a)
    nb = len(axes_b)

    as_ = List[int](a.shape)
    nda = a.ndim
    bs = List[int](b.shape)
    ndb = b.ndim
    # check that the contracted dimensions match, normalising negative axes
    equal = True
    if na != nb:
        equal = False
    else:
        for k in range(na):
            if as_[axes_a[k]] != bs[axes_b[k]]:
                equal = False
                break
            if axes_a[k] < 0:
                axes_a[k] += nda
            if axes_b[k] < 0:
                axes_b[k] += ndb
    if not equal:
        raise ValueError("shape-mismatch for sum")

    # a: free axes first, contracted axes last -> (M2, N2)
    notin = [k for k in range(nda) if k not in axes_a]
    newaxes_a = notin + axes_a

    N2 = 1
    for axis in axes_a:
        N2 *= as_[axis]

    M2 = 1
    for ax in notin:
        M2 *= as_[ax]

    newshape_a = (M2, N2)
    if ndim >= 0:
        olda = [as_[axis] for axis in notin]
    else:
        olda = None

    # b: contracted axes first, free axes last -> (N2, M2)
    notin = [k for k in range(ndb) if k not in axes_b]
    newaxes_b = axes_b + notin

    N2 = 1
    for axis in axes_b:
        N2 *= bs[axis]

    M2 = 1
    for ax in notin:
        M2 *= bs[ax]

    newshape_b = (N2, M2)
    if ndim >= 0:
        oldb = [bs[axis] for axis in notin]
    else:
        oldb = None

    at = a.transpose(newaxes_a).reshape(newshape_a)
    bt = b.transpose(newaxes_b).reshape(newshape_b)
    res = dot(at, bt)
    # NOTE: 'olda + oldb' length is not known at compile-time,
    # so cannot reshape unless 'ndim' is given.
    if ndim >= 0:
        # fill a static-length tuple in place from the runtime lists
        newshape = (0,) * ndim
        pnewshape = Ptr[int](__ptr__(newshape).as_byte())
        i = 0
        for j in olda:
            pnewshape[i] = j
            i += 1
        for j in oldb:
            pnewshape[i] = j
            i += 1
        return res.reshape(newshape)
    else:
        return res
|
|
|
|
def tensordot(a, b, axes):
    """
    Tensor dot product with runtime `axes` (an int or a pair of axis lists).

    The result rank is not known statically here, so the 2-d product from
    _tensordot is returned without the final reshape.
    """
    return _tensordot(asarray(a), asarray(b), axes)
|
|
|
|
@overload
def tensordot(a, b, axes: Static[int] = 2):
    """
    Tensor dot product over the last `axes` axes of `a` and the first `axes`
    axes of `b`, with a compile-time `axes`.

    Negative `axes` behaves like 0 (outer product). Since the result rank
    a.ndim + b.ndim - 2*axes is static, the result is reshaped to its full
    N-d shape. Contracting more axes than either operand has is a compile
    error.
    """
    if axes < 0:
        return tensordot(a, b, axes=0)

    a = asarray(a)
    b = asarray(b)

    # At most ndim axes can be contracted; anything larger would index out
    # of range in _tensordot and give a negative result rank.
    if axes > a.ndim or axes > b.ndim:
        compile_error("'axes' too large for given arrays")

    return _tensordot(a, b, axes=axes, ndim=(a.ndim + b.ndim - 2*axes))
|
|
|
|
def kron(a, b):
    """
    Kronecker product of `a` and `b`.

    `a` is promoted to at least b's rank; the lower-rank operand is padded
    with leading length-1 axes. 0-d operands reduce to multiply. Each
    operand is expanded so that their axes interleave (a's axes at odd
    positions, b's at even positions) and the broadcast product is
    reshaped to (a.shape[i] * b.shape[i], ...).
    """
    b = asarray(b)
    ndb: Static[int] = staticlen(b.shape)
    a = array(a, copy=False, ndmin=ndb)
    nda: Static[int] = staticlen(a.shape)
    nd: Static[int] = ndb if ndb >= nda else nda

    if nda == 0 or ndb == 0:
        return multiply(a, b)

    as_ = a.shape
    bs = b.shape
    if not a.flags.contiguous:
        a = reshape(a, as_)
    if not b.flags.contiguous:
        b = reshape(b, bs)

    # pad the shorter shape with leading 1s so both have `nd` entries
    as_ = (1,)*(ndb-nda if ndb-nda >= 0 else 0) + as_
    bs = (1,)*(nda-ndb if nda-ndb >= 0 else 0) + bs

    a_arr = expand_dims(a, axis=tuple(i for i in staticrange(ndb-nda)))
    b_arr = expand_dims(b, axis=tuple(i for i in staticrange(nda-ndb)))

    # interleave: a -> (a0, 1, a1, 1, ...), b -> (1, b0, 1, b1, ...)
    a_arr = expand_dims(a_arr, axis=tuple(i for i in staticrange(1, nd*2, 2)))
    b_arr = expand_dims(b_arr, axis=tuple(i for i in staticrange(0, nd*2, 2)))
    result = multiply(a_arr, b_arr)

    res_shape = tuple(as_[i] * bs[i] for i in staticrange(staticlen(as_)))
    result = result.reshape(res_shape)

    return result
|
|
|
|
def _trace_ufunc(a, offset: int = 0, axis1: int = 0, axis2: int = 1,
                 dtype: type = NoneType, out = None):
    """
    Sum along the diagonal(s) defined by `axis1`, `axis2` and `offset`.

    The diagonal consists of the elements with index i along `axis1` and
    i + offset along `axis2`. A 2-d input gives a scalar (or fills a 0-d
    `out`); higher-rank input gives an array over the remaining axes.
    `dtype` defaults to the input dtype. Raises ValueError if the axes
    coincide or `out` has the wrong shape; wrong `out` rank/dtype are
    compile errors.
    """
    a = asarray(a)

    if dtype is NoneType:
        return _trace_ufunc(a, offset=offset, axis1=axis1, axis2=axis2,
                            dtype=a.dtype, out=out)

    m, n = _rows_cols(a)
    s = a.shape

    axis1 = normalize_axis_index(axis1, a.ndim, "axis1")
    axis2 = normalize_axis_index(axis2, a.ndim, "axis2")
    if axis1 == axis2:
        raise ValueError("axis1 and axis2 cannot be the same")

    # The indexing below requires axis1 < axis2. Element (axis1=i,
    # axis2=i+k) equals element (axis2=i, axis1=i+k) of the swapped pair,
    # i.e. the diagonal with offset -k, so swap the axes AND negate offset.
    if axis1 > axis2:
        axis1, axis2 = axis2, axis1
        offset = -offset

    if staticlen(s) == 2:
        i = 0
        t = zero(dtype)
        while ((offset >= 0 and (i < m and i + offset < n)) or
               (offset < 0 and (i - offset < m and i < n))):
            e = a._ptr((i, i + offset))[0] if offset >= 0 else a._ptr((i - offset, i))[0]
            t += cast(e, dtype)
            i += 1

        if out is None:
            return t
        else:
            if staticlen(out.shape) != 0:
                compile_error("expected 0-dimensional output parameter")

            if out.dtype is not dtype:
                compile_error("output parameter has the wrong dtype")

            out.data[0] = t
            return out
    else:
        # delete the higher axis first so the lower index stays valid
        ans_shape = tuple_delete(tuple_delete(s, axis2), axis1)
        m, n = s[axis1], s[axis2]

        if out is None:
            ans = empty(ans_shape, dtype=dtype)
        else:
            if staticlen(out.shape) != staticlen(ans_shape):
                compile_error("output parameter has the wrong number of dimensions")

            if out.dtype is not dtype:
                compile_error("output parameter has the wrong dtype")

            if out.shape != ans_shape:
                raise ValueError("output parameter has the wrong shape")

            ans = out

        for idx0 in multirange(ans_shape):
            i = 0
            t = zero(dtype)

            while ((offset >= 0 and (i < m and i + offset < n)) or
                   (offset < 0 and (i - offset < m and i < n))):
                # insert the lower axis first, then the higher one
                idx = (tuple_insert(
                           tuple_insert(idx0, axis1, i),
                           axis2, i + offset) if offset >= 0 else
                       tuple_insert(
                           tuple_insert(idx0, axis1, i - offset),
                           axis2, i))
                e = a._ptr(idx)[0]
                t += cast(e, dtype)
                i += 1

            q = ans._ptr(idx0)
            q[0] = t

        return ans
|
|
|
|
def trace(x, offset: int = 0, dtype: type = NoneType):
    """
    Sum along the diagonals of the matrices formed by the LAST two axes of
    `x` (numpy.linalg.trace semantics, unlike numpy.trace, which defaults
    to axes 0 and 1).

    2-d input gives a scalar; input of shape (..., M, N) gives an array of
    shape (...). `offset` selects the diagonal; `dtype` defaults to x's.
    """
    return _trace_ufunc(x, offset=offset, axis1=-2, axis2=-1, dtype=dtype)
|
|
|
|
def tensorsolve(a, b, axes = None):
    """
    Solve a tensor linear system `a` x = `b` for x.

    If `axes` is given, those axes of `a` are first moved to the end, in
    the given order. With prod = the product of a's trailing
    (a.ndim - b.ndim) dimensions, `a` is reshaped to (prod, prod) and solved
    against b.ravel(). The solution is reshaped to those trailing
    dimensions. Raises LinAlgError if a.size != prod ** 2.
    """
    a = asarray(a)
    b = asarray(b)
    an: Static[int] = staticlen(a.shape)
    bn: Static[int] = staticlen(b.shape)

    if axes is not None:
        allaxes = list(range(0, an))
        for k in axes:
            allaxes.remove(k)
            allaxes.insert(an, k)
        a = a.transpose(allaxes)

    oldshape = a.shape[-(an-bn):]
    prod = 1
    for k in oldshape:
        prod *= k

    if a.size != prod ** 2:
        raise LinAlgError(
            "Input arrays must satisfy the requirement "
            "prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
        )

    a = a.reshape(prod, prod)
    b = b.ravel()
    res = solve(a, b)
    return res.reshape(oldshape)
|
|
|
|
def tensorinv(a, ind: int = 2):
    """
    'Inverse' of an N-d array with respect to tensordot.

    The first `ind` axes are the output indices. `a` is reshaped to
    (prod(a.shape[ind:]), -1) and inverted with inv. The inverse is
    reshaped to a.shape[ind:] + a.shape[:ind]. Raises ValueError if
    `ind` <= 0.
    """
    a = asarray(a)
    oldshape = a.shape
    prod = 1
    invshape = oldshape

    if ind > 0:
        # invshape = oldshape[ind:] + oldshape[:ind]
        # (the rotation is written element-wise through int pointers because
        # `ind` is a runtime value while the tuple length is static)
        poldshape = Ptr[int](__ptr__(oldshape).as_byte())
        pinvshape = Ptr[int](__ptr__(invshape).as_byte())

        i = ind
        j = 0
        while i < len(oldshape):
            pinvshape[j] = poldshape[i]
            prod *= poldshape[i]
            i += 1
            j += 1

        i = 0
        while i < min(ind, len(oldshape)):
            pinvshape[j] = poldshape[i]
            i += 1
            j += 1
    else:
        raise ValueError("Invalid ind argument.")

    a = a.reshape(prod, -1)
    ia = inv(a)
    return ia.reshape(*invshape)
|
|
|
|
def _square(elem: T, T: type):
    """
    Squared magnitude: re*re + im*im for complex/complex64 values,
    elem*elem for every other type.
    """
    if T is complex or T is complex64:
        re = elem.real
        im = elem.imag
        return (re * re) + (im * im)
    else:
        return elem * elem
|
|
|
|
def _norm_cast(elem: T, T: type):
    """
    Prepare an element for norm computations: float, float32, complex and
    complex64 values are returned unchanged; values of any other type are
    cast to float.
    """
    # The static type test selects exactly one branch per instantiation,
    # so the two returns never need a common result type.
    if T is float or T is float32 or T is complex or T is complex64:
        return elem
    else:
        return cast(elem, float)
|
|
|
|
def _vector_norm_f(x: ndarray, axis: int, R: type):
    """
    Euclidean (2-)norm sqrt(sum(|x_i|**2)) accumulated in type R.

    1-d input yields a scalar; higher-rank input yields an array of R over
    the axes other than `axis`.
    """
    if staticlen(x.shape) == 1:
        T = x.dtype
        step = x.strides[0]  # byte stride
        base = x.data.as_byte()
        total = R()
        for k in range(x.shape[0]):
            total += _square(_norm_cast(Ptr[T](base + step * k)[0]))
        return util_sqrt(total)
    else:
        length = x.shape[axis]
        out_shape = tuple_delete(x.shape, axis)
        result = empty(out_shape, R)
        for pos in multirange(out_shape):
            total = R()
            for k in range(length):
                total += _square(_norm_cast(x._ptr(tuple_insert(pos, axis, k))[0]))
            result._ptr(pos)[0] = util_sqrt(total)
        return result
|
|
|
|
def _norm_error_zero_max():
    """Raise ValueError for a maximum-reduction over an empty axis."""
    message = ("zero-size array to reduction operation "
               "maximum which has no identity")
    raise ValueError(message)
|
|
|
|
def _norm_error_zero_min():
    """Raise ValueError for a minimum-reduction over an empty axis."""
    message = ("zero-size array to reduction operation "
               "minimum which has no identity")
    raise ValueError(message)
|
|
|
|
def _vector_norm_inf(x: ndarray, axis: int, R: type):
    """
    Infinity norm max(|x_i|) of a vector, or along `axis` of an N-d array.

    1-d input yields a scalar; higher-rank input yields an array of R over
    the other axes. A NaN anywhere in a reduction makes that result NaN,
    as with NumPy's max. Raises ValueError when the reduced axis is empty.
    """
    if staticlen(x.shape) == 1:
        dtype = x.dtype
        sh = x.shape[0]
        if sh == 0:
            _norm_error_zero_max()

        st = x.strides[0]
        data = x.data
        curr = abs(_norm_cast(data[0]))

        for i in range(1, sh):
            elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0])
            e = abs(elem)
            # `e != e` only holds for NaN: let it win like NumPy's max
            if e > curr or e != e:
                curr = e

        return curr
    else:
        s = x.shape
        n = s[axis]
        # check the axis actually being reduced (not shape[0]); this also
        # prevents reading element 0 of an empty axis below
        if n == 0:
            _norm_error_zero_max()

        ans_shape = tuple_delete(x.shape, axis)
        ans = empty(ans_shape, R)

        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            idx1 = tuple_insert(idx, axis, 0)
            curr = abs(_norm_cast(x._ptr(idx1)[0]))

            for i in range(1, n):
                idx1 = tuple_insert(idx, axis, i)
                elem = _norm_cast(x._ptr(idx1)[0])
                e = abs(elem)
                if e > curr or e != e:
                    curr = e

            q[0] = curr

        return ans
|
|
|
|
def _vector_norm_ninf(x: ndarray, axis: int, R: type):
    """
    Negative-infinity "norm" min(|x_i|) of a vector, or along `axis`.

    1-d input yields a scalar; higher-rank input yields an array of R over
    the other axes. A NaN anywhere in a reduction makes that result NaN,
    as with NumPy's min. Raises ValueError when the reduced axis is empty.
    """
    if staticlen(x.shape) == 1:
        dtype = x.dtype
        sh = x.shape[0]
        if sh == 0:
            _norm_error_zero_min()

        st = x.strides[0]
        data = x.data
        curr = abs(_norm_cast(data[0]))

        for i in range(1, sh):
            elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0])
            e = abs(elem)
            # `e != e` only holds for NaN: let it win like NumPy's min
            if e < curr or e != e:
                curr = e

        return curr
    else:
        s = x.shape
        n = s[axis]
        # check the axis actually being reduced (not shape[0]); this also
        # prevents reading element 0 of an empty axis below
        if n == 0:
            _norm_error_zero_min()

        ans_shape = tuple_delete(x.shape, axis)
        ans = empty(ans_shape, R)

        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            idx1 = tuple_insert(idx, axis, 0)
            curr = abs(_norm_cast(x._ptr(idx1)[0]))

            for i in range(1, n):
                idx1 = tuple_insert(idx, axis, i)
                elem = _norm_cast(x._ptr(idx1)[0])
                e = abs(elem)
                if e < curr or e != e:
                    curr = e

            q[0] = curr

        return ans
|
|
|
|
def _vector_norm_0(x: ndarray, axis: int, R: type):
    """
    "0-norm": the number of nonzero entries, counted in type R.

    1-d input yields a scalar; higher-rank input yields an array of R over
    the axes other than `axis`.
    """
    if staticlen(x.shape) == 1:
        dtype = x.dtype
        sh = x.shape[0]
        st = x.strides[0]
        data = x.data
        curr = R()

        for i in range(sh):
            elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0])
            if elem:
                curr += R(1)

        return curr
    else:
        s = x.shape
        n = s[axis]
        ans_shape = tuple_delete(x.shape, axis)
        ans = empty(ans_shape, R)

        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()

            for i in range(n):
                idx1 = tuple_insert(idx, axis, i)
                elem = _norm_cast(x._ptr(idx1)[0])
                if elem:
                    curr += R(1)

            q[0] = curr

        return ans
|
|
|
|
def _vector_norm_1(x: ndarray, axis: int, R: type):
    """
    1-norm sum(|x_i|) of a vector, or along `axis`, accumulated in type R.

    1-d input yields a scalar; higher-rank input yields an array of R over
    the axes other than `axis`.
    """
    if staticlen(x.shape) == 1:
        dtype = x.dtype
        sh = x.shape[0]
        st = x.strides[0]
        data = x.data
        curr = R()

        for i in range(sh):
            elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0])
            curr += abs(elem)

        return curr
    else:
        s = x.shape
        n = s[axis]
        ans_shape = tuple_delete(x.shape, axis)
        ans = empty(ans_shape, R)

        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()

            for i in range(n):
                idx1 = tuple_insert(idx, axis, i)
                elem = _norm_cast(x._ptr(idx1)[0])
                curr += abs(elem)

            q[0] = curr

        return ans
|
|
|
|
def _vector_norm_g(x: ndarray, axis: int, g, R: type):
    """
    General p-norm (sum(|x_i| ** g)) ** (1 / g), computed in type R.

    1-d input yields a scalar; higher-rank input yields an array of R over
    the axes other than `axis`.
    """
    if staticlen(x.shape) == 1:
        dtype = x.dtype
        sh = x.shape[0]
        st = x.strides[0]
        data = x.data
        curr = R()

        for i in range(sh):
            elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0])
            curr += abs(elem) ** R(g)

        return curr ** (R(1) / R(g))
    else:
        s = x.shape
        n = s[axis]
        ans_shape = tuple_delete(x.shape, axis)
        ans = empty(ans_shape, R)

        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()

            for i in range(n):
                idx1 = tuple_insert(idx, axis, i)
                elem = _norm_cast(x._ptr(idx1)[0])
                curr += abs(elem) ** R(g)

            q[0] = curr ** (R(1) / R(g))

        return ans
|
|
|
|
def _matrix_norm_f(x: ndarray, row_axis: int, col_axis: int, R: type):
    """
    Frobenius norm sqrt(sum(|x_ij|**2)) over the (row_axis, col_axis) plane.

    2-d input yields a scalar of type R; higher-rank input yields an array
    of R over the remaining axes. The order of the two axes does not
    matter for this norm.
    """
    s = x.shape
    if staticlen(s) == 2:
        sh1, sh2 = s
        curr = R()

        for i in range(sh1):
            for j in range(sh2):
                elem = _norm_cast(x._ptr((i, j))[0])
                curr += _square(elem)

        return util_sqrt(curr)
    else:
        axis1, axis2 = row_axis, col_axis
        if axis1 > axis2:
            axis1, axis2 = axis2, axis1

        # delete the higher axis first so the lower index stays valid
        ans_shape = tuple_delete(tuple_delete(s, axis2), axis1)
        ans = empty(ans_shape, R)
        sh1 = s[axis1]
        sh2 = s[axis2]

        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()

            for i in range(sh1):
                for j in range(sh2):
                    idx1 = tuple_insert(
                        tuple_insert(idx, axis1, i), axis2, j)
                    elem = _norm_cast(x._ptr(idx1)[0])
                    curr += _square(elem)

            q[0] = util_sqrt(curr)

        return ans
|
|
|
|
def _matrix_norm_1(x: ndarray, row_axis: int, col_axis: int, R: type):
    """
    Matrix 1-norm: the largest column sum of |x_ij|, where rows run along
    `row_axis` and columns along `col_axis`.

    2-d input yields a scalar of type R; higher-rank input yields an array
    of R over the remaining axes. A NaN column sum propagates to the
    result (as with NumPy's max). Raises ValueError if there are no columns.
    """
    s = x.shape
    if s[col_axis] == 0:
        _norm_error_zero_max()

    if staticlen(s) == 2:
        dtype = x.dtype
        row_dim = s[row_axis]
        col_dim = s[col_axis]
        row_stride = x.strides[row_axis]
        col_stride = x.strides[col_axis]
        data = x.data
        curr = R()
        first = True

        for col_idx in range(col_dim):
            sub = R()
            for row_idx in range(row_dim):
                # element (row_idx, col_idx) via byte strides
                elem = _norm_cast((Ptr[dtype](data.as_byte() +
                                              (row_stride * row_idx) +
                                              (col_stride * col_idx)))[0])
                sub += abs(elem)

            # `sub != sub` only holds for NaN: let it win like NumPy's max
            if first or sub > curr or sub != sub:
                curr = sub
            first = False

        return curr
    else:
        axis1, axis2 = row_axis, col_axis
        if axis1 > axis2:
            axis1, axis2 = axis2, axis1

        ans_shape = tuple_delete(tuple_delete(s, axis2), axis1)
        ans = empty(ans_shape, R)
        row_dim = s[row_axis]
        col_dim = s[col_axis]

        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()
            first = True

            for col_idx in range(col_dim):
                sub = R()
                for row_idx in range(row_dim):
                    # (i, j) are the values for (axis1, axis2) = sorted axes
                    i, j = row_idx, col_idx
                    if row_axis > col_axis:
                        i, j = j, i
                    idx1 = tuple_insert(
                        tuple_insert(idx, axis1, i), axis2, j)
                    elem = _norm_cast(x._ptr(idx1)[0])
                    sub += abs(elem)

                if first or sub > curr or sub != sub:
                    curr = sub
                first = False

            q[0] = curr

        return ans
|
|
|
|
def _matrix_norm_inf(x: ndarray, row_axis: int, col_axis: int, R: type):
    """
    Matrix infinity-norm: the largest row sum of |x_ij|, where rows run
    along `row_axis` and columns along `col_axis`.

    2-d input yields a scalar of type R; higher-rank input yields an array
    of R over the remaining axes. A NaN row sum propagates to the result
    (as with NumPy's max). Raises ValueError if there are no rows.
    """
    s = x.shape
    if s[row_axis] == 0:
        _norm_error_zero_max()

    if staticlen(s) == 2:
        dtype = x.dtype
        row_dim = s[row_axis]
        col_dim = s[col_axis]
        row_stride = x.strides[row_axis]
        col_stride = x.strides[col_axis]
        data = x.data
        curr = R()
        first = True

        for row_idx in range(row_dim):
            sub = R()
            for col_idx in range(col_dim):
                # element (row_idx, col_idx) via byte strides
                elem = _norm_cast((Ptr[dtype](data.as_byte() +
                                              (row_stride * row_idx) +
                                              (col_stride * col_idx)))[0])
                sub += abs(elem)

            # `sub != sub` only holds for NaN: let it win like NumPy's max
            if first or sub > curr or sub != sub:
                curr = sub
            first = False

        return curr
    else:
        axis1, axis2 = row_axis, col_axis
        if axis1 > axis2:
            axis1, axis2 = axis2, axis1

        ans_shape = tuple_delete(tuple_delete(s, axis2), axis1)
        ans = empty(ans_shape, R)
        row_dim = s[row_axis]
        col_dim = s[col_axis]

        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()
            first = True

            for row_idx in range(row_dim):
                sub = R()
                for col_idx in range(col_dim):
                    # (i, j) are the values for (axis1, axis2) = sorted axes
                    i, j = row_idx, col_idx
                    if row_axis > col_axis:
                        i, j = j, i
                    idx1 = tuple_insert(
                        tuple_insert(idx, axis1, i), axis2, j)
                    elem = _norm_cast(x._ptr(idx1)[0])
                    sub += abs(elem)

                if first or sub > curr or sub != sub:
                    curr = sub
                first = False

            q[0] = curr

        return ans
|
|
|
|
def _matrix_norm_n1(x: ndarray, row_axis: int, col_axis: int, R: type):
    """
    Matrix norm with ``ord=-1``: the minimum, over the columns, of the sum of
    absolute values taken along ``row_axis``.

    For a 2-D input a scalar of type ``R`` is returned. For higher-rank input
    the reduction is done independently for every index of the remaining
    (non-matrix) axes and an array of dtype ``R`` with those axes is returned.

    An empty column axis is reported through ``_norm_error_zero_min()``
    because a minimum over zero columns is undefined. A NaN column sum
    propagates to the result, matching ``numpy.min``.
    """
    s = x.shape
    if s[col_axis] == 0:
        _norm_error_zero_min()

    if staticlen(s) == 2:
        # Fast path: walk the raw buffer directly using byte strides.
        dtype = x.dtype
        row_dim = s[row_axis]
        col_dim = s[col_axis]
        row_stride = x.strides[row_axis]
        col_stride = x.strides[col_axis]
        data = x.data
        curr = R()
        first = True

        for col_idx in range(col_dim):
            sub = R()
            for row_idx in range(row_dim):
                # strides are in bytes, so offset from a byte pointer
                elem = _norm_cast((Ptr[dtype](data.as_byte() +
                                              (row_stride * row_idx) +
                                              (col_stride * col_idx)))[0])
                sub += abs(elem)

            # `sub != sub` is true only for NaN: NaN fails every ordered
            # comparison, so without it a NaN column would be silently skipped.
            if first or sub < curr or sub != sub:
                curr = sub
            first = False

        return curr
    else:
        # Sort the two matrix axes so indices can be re-inserted in order.
        axis1, axis2 = row_axis, col_axis
        if axis1 > axis2:
            axis1, axis2 = axis2, axis1

        ans_shape = tuple_delete(tuple_delete(s, axis2), axis1)
        ans = empty(ans_shape, R)
        row_dim = s[row_axis]
        col_dim = s[col_axis]

        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()
            first = True

            for col_idx in range(col_dim):
                sub = R()
                for row_idx in range(row_dim):
                    # (i, j) are the indices for (axis1, axis2) in order
                    i, j = row_idx, col_idx
                    if row_axis > col_axis:
                        i, j = j, i
                    idx1 = tuple_insert(
                        tuple_insert(idx, axis1, i), axis2, j)
                    elem = _norm_cast(x._ptr(idx1)[0])
                    sub += abs(elem)

                # NaN-propagating min (see the 2-D branch)
                if first or sub < curr or sub != sub:
                    curr = sub
                first = False

            q[0] = curr

        return ans
|
|
|
|
def _matrix_norm_ninf(x: ndarray, row_axis: int, col_axis: int, R: type):
    """
    Matrix norm with ``ord=-inf``: the minimum, over the rows, of the sum of
    absolute values taken along ``col_axis``.

    For a 2-D input a scalar of type ``R`` is returned. For higher-rank input
    the reduction is done independently for every index of the remaining
    (non-matrix) axes and an array of dtype ``R`` with those axes is returned.

    An empty row axis is reported through ``_norm_error_zero_min()`` because
    a minimum over zero rows is undefined. A NaN row sum propagates to the
    result, matching ``numpy.min``.
    """
    s = x.shape
    if s[row_axis] == 0:
        _norm_error_zero_min()

    if staticlen(s) == 2:
        # Fast path: walk the raw buffer directly using byte strides.
        dtype = x.dtype
        row_dim = s[row_axis]
        col_dim = s[col_axis]
        row_stride = x.strides[row_axis]
        col_stride = x.strides[col_axis]
        data = x.data
        curr = R()
        first = True

        for row_idx in range(row_dim):
            sub = R()
            for col_idx in range(col_dim):
                # strides are in bytes, so offset from a byte pointer
                elem = _norm_cast((Ptr[dtype](data.as_byte() +
                                              (row_stride * row_idx) +
                                              (col_stride * col_idx)))[0])
                sub += abs(elem)

            # `sub != sub` is true only for NaN: NaN fails every ordered
            # comparison, so without it a NaN row would be silently skipped.
            if first or sub < curr or sub != sub:
                curr = sub
            first = False

        return curr
    else:
        # Sort the two matrix axes so indices can be re-inserted in order.
        axis1, axis2 = row_axis, col_axis
        if axis1 > axis2:
            axis1, axis2 = axis2, axis1

        ans_shape = tuple_delete(tuple_delete(s, axis2), axis1)
        ans = empty(ans_shape, R)
        row_dim = s[row_axis]
        col_dim = s[col_axis]

        for idx in multirange(ans_shape):
            q = ans._ptr(idx)
            curr = R()
            first = True

            for row_idx in range(row_dim):
                sub = R()
                for col_idx in range(col_dim):
                    # (i, j) are the indices for (axis1, axis2) in order
                    i, j = row_idx, col_idx
                    if row_axis > col_axis:
                        i, j = j, i
                    idx1 = tuple_insert(
                        tuple_insert(idx, axis1, i), axis2, j)
                    elem = _norm_cast(x._ptr(idx1)[0])
                    sub += abs(elem)

                # NaN-propagating min (see the 2-D branch)
                if first or sub < curr or sub != sub:
                    curr = sub
                first = False

            q[0] = curr

        return ans
|
|
|
|
def _multi_svd_norm(x: ndarray, row_axis: int, col_axis: int, op):
    """
    Reduce the singular values of the matrices spanned by ``row_axis`` and
    ``col_axis`` with ``op``. ``norm`` passes ``ndarray.max`` (ord=2),
    ``ndarray.min`` (ord=-2) or ``ndarray.sum`` (ord='nuc').
    """
    # Move the two matrix axes to the end so svd sees a stack of matrices.
    stacked = moveaxis(x, (row_axis, col_axis), (-2, -1))
    singular_values = svd(stacked, compute_uv=False)
    return op(singular_values, axis=-1)
|
|
|
|
def _norm_wrap(x, kd_shape, keepdims: Static[int]):
    """
    Apply ``norm``'s ``keepdims`` option to a computed result.

    When ``keepdims`` is set, ``x`` (a scalar or an array) is converted with
    ``asarray`` and reshaped to ``kd_shape``, the input shape with the
    reduced axes set to 1. Otherwise ``x`` is returned unchanged.
    """
    # `keepdims` is a Static[int], so this branch is resolved at compile
    # time and the two returns are allowed to have different types.
    if keepdims:
        return asarray(x).reshape(kd_shape)
    else:
        return x
|
|
|
|
def norm(x, ord = None, axis = None, keepdims: Static[int] = False):
    """
    Matrix or vector norm, following ``numpy.linalg.norm``.

    Parameters
    ----------
    x : array_like
        Input. ``float``, ``float32``, ``float16``, ``complex`` and
        ``complex64`` arrays are used as-is; any other dtype is first
        converted to ``float``.
    ord : None, int, float or str
        Norm order. For vectors: ``None``/2 (Euclidean), 0 (count of
        non-zeros), 1, ``inf``, ``-inf`` or any other power. For matrices:
        ``None``/``'fro'``/``'f'`` (Frobenius), ``'nuc'``, ``±1``, ``±2``,
        ``±inf``.
    axis : None, int or tuple of ints
        One axis (vector norm) or two axes (matrix norm). ``None`` reduces
        over all axes. With ``ord=None`` the input is then flattened;
        otherwise it must be 1-D or 2-D.
    keepdims : static bool
        Keep the reduced axes in the result with size 1.

    Returns
    -------
    A scalar of the real result type, or an array when only some axes are
    reduced.

    Raises
    ------
    ValueError
        For duplicate matrix axes, an invalid matrix order, or more than two
        reduction axes.
    """
    x = asarray(x)

    # Like numpy, keep every inexact dtype (real or complex floating point)
    # and promote the rest to float. Complex dtypes must be listed here,
    # otherwise they would be converted with astype(float) and never reach
    # the complex-aware code below.
    if not (x.dtype is float or x.dtype is float32 or x.dtype is float16 or
            x.dtype is complex or x.dtype is complex64):
        return norm(x.astype(float), ord=ord, axis=axis, keepdims=keepdims)

    ndim: Static[int] = staticlen(x.shape)
    dtype = x.dtype

    # R: real scalar type of the result (single precision stays single).
    if dtype is complex or dtype is float:
        r = float()
    elif dtype is complex64 or dtype is float32:
        r = float32()
    else:
        r = float()

    R = type(r)

    # Common cases: 2-norm of the flattened input (ord=None), Frobenius norm
    # of a matrix, or 2-norm of a vector, all via a single dot product.
    if axis is None:
        handle = False
        if ord is None:
            handle = True
        elif isinstance(ord, str):
            handle = ((ord == 'f' or ord == 'fro') and ndim == 2)
        elif isinstance(ord, int):
            handle = (ord == 2 and ndim == 1)

        if handle:
            y = x.ravel(order='K')
            if dtype is complex or dtype is complex64:
                # |z|^2 = re^2 + im^2
                x_real = y.real
                x_imag = y.imag
                sqnorm = x_real.dot(x_real) + x_imag.dot(x_imag)
            else:
                sqnorm = y.dot(y)

            ret = sqrt(sqnorm)
            if keepdims:
                return asarray(ret).reshape((1,) * ndim)
            else:
                return ret

    if axis is None:
        ax = tuple_range(ndim)
    elif isinstance(axis, int):
        ax = (axis,)
    elif isinstance(axis, Tuple):
        ax = axis
    else:
        compile_error("'axis' must be None, an integer or a tuple of integers")

    if ndim == 0:
        # this matches NumPy behavior
        if ord is None:
            return abs(x.data[0])
        else:
            compile_error("Improper number of dimensions to norm.")

    ax = tuple(normalize_axis_index(a, ndim) for a in ax)
    kd_shape = x.shape
    if keepdims:
        for a in ax:
            kd_shape = tuple_set(kd_shape, a, 1)

    if staticlen(ax) == 1:
        # Vector norms
        if ord is None:
            return _norm_wrap(_vector_norm_f(x, ax[0], R), kd_shape, keepdims)
        elif isinstance(ord, int):
            if ord == 0:
                return _norm_wrap(_vector_norm_0(x, ax[0], R), kd_shape, keepdims)
            elif ord == 1:
                return _norm_wrap(_vector_norm_1(x, ax[0], R), kd_shape, keepdims)
            elif ord == 2:
                return _norm_wrap(_vector_norm_f(x, ax[0], R), kd_shape, keepdims)
            else:
                return _norm_wrap(_vector_norm_g(x, ax[0], float(ord), R), kd_shape, keepdims)
        elif isinstance(ord, float):
            if ord == 0.:
                return _norm_wrap(_vector_norm_0(x, ax[0], R), kd_shape, keepdims)
            elif ord == 1.:
                return _norm_wrap(_vector_norm_1(x, ax[0], R), kd_shape, keepdims)
            elif ord == 2.:
                return _norm_wrap(_vector_norm_f(x, ax[0], R), kd_shape, keepdims)
            elif ord == inf(float):
                return _norm_wrap(_vector_norm_inf(x, ax[0], R), kd_shape, keepdims)
            elif ord == -inf(float):
                return _norm_wrap(_vector_norm_ninf(x, ax[0], R), kd_shape, keepdims)
            else:
                return _norm_wrap(_vector_norm_g(x, ax[0], ord, R), kd_shape, keepdims)
        else:
            compile_error("Invalid norm order for vectors")
    elif staticlen(ax) == 2:
        # Matrix norms
        row_axis, col_axis = ax
        if row_axis == col_axis:
            raise ValueError("Duplicate axes given.")

        if ord is None:
            return _norm_wrap(_matrix_norm_f(x, ax[0], ax[1], R), kd_shape, keepdims)
        elif isinstance(ord, int):
            if ord == 2:
                return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.max), kd_shape, keepdims)
            elif ord == -2:
                return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.min), kd_shape, keepdims)
            elif ord == 1:
                return _norm_wrap(_matrix_norm_1(x, ax[0], ax[1], R), kd_shape, keepdims)
            elif ord == -1:
                return _norm_wrap(_matrix_norm_n1(x, ax[0], ax[1], R), kd_shape, keepdims)
        elif isinstance(ord, float):
            if ord == 2.:
                return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.max), kd_shape, keepdims)
            elif ord == -2.:
                return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.min), kd_shape, keepdims)
            elif ord == 1.:
                return _norm_wrap(_matrix_norm_1(x, ax[0], ax[1], R), kd_shape, keepdims)
            elif ord == -1.:
                return _norm_wrap(_matrix_norm_n1(x, ax[0], ax[1], R), kd_shape, keepdims)
            elif ord == inf(float):
                return _norm_wrap(_matrix_norm_inf(x, ax[0], ax[1], R), kd_shape, keepdims)
            elif ord == -inf(float):
                return _norm_wrap(_matrix_norm_ninf(x, ax[0], ax[1], R), kd_shape, keepdims)
        elif isinstance(ord, str):
            if ord == 'fro' or ord == 'f':
                return _norm_wrap(_matrix_norm_f(x, ax[0], ax[1], R), kd_shape, keepdims)
            elif ord == 'nuc':
                return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.sum), kd_shape, keepdims)

        # Any order not matched above is not a valid matrix norm.
        raise ValueError("Invalid norm order for matrices.")
    else:
        raise ValueError("Improper number of dimensions to norm.")
|
|
|
|
def pinv(a, rcond = 1e-15, hermitian: bool = False):
    """
    Moore-Penrose pseudo-inverse of a matrix or a stack of matrices.

    The SVD of ``conj(a)`` is taken. Per matrix, singular values greater
    than ``rcond * max(singular values)`` are inverted and the rest are set
    to zero. The factors are then recombined as ``V @ diag(s^+) @ U^H``.

    If either of the last two dimensions (m, n) is zero, an uninitialized
    array of shape ``a.shape[:-2] + (n, m)`` is returned. It keeps the input
    dtype when that is float/float32/complex/complex64 and uses ``float``
    otherwise.
    """
    arr = asarray(a)
    tol = asarray(rcond)
    m, n = _rows_cols(arr)

    if m == 0 or n == 0:
        out_shape = arr.shape[:-2] + (n, m)
        in_dtype = arr.dtype
        if (in_dtype is float or in_dtype is float32 or
                in_dtype is complex or in_dtype is complex64):
            return empty(out_shape, dtype=in_dtype)
        else:
            return empty(out_shape, dtype=float)

    # Decomposing conj(a) means plain (non-conjugating) transposes of the
    # factors below yield V and U^H of the original matrix.
    u_conj, sing, vt_conj = svd(arr.conjugate(), full_matrices=False,
                                hermitian=hermitian)
    cutoff = multiply(tol[..., None], sing.max(axis=-1, keepdims=True))
    S = sing.dtype

    # Replace singular values in place: 1/s above the cutoff, 0 otherwise.
    for idx in multirange(broadcast_shapes(sing.shape, cutoff.shape)):
        limit = cutoff._ptr(idx, broadcast=True)[0]
        slot = sing._ptr(idx, broadcast=True)
        value = slot[0]
        slot[0] = cast(divide(1, value), S) if greater(value, limit) else S()

    return matmul(swapaxes(vt_conj, -1, -2),
                  multiply(sing[..., None], swapaxes(u_conj, -1, -2)))
|
|
|
|
def matrix_rank(A, tol = None, hermitian: bool = False):
    """
    Rank of a matrix (or of each matrix in a stack), computed with the SVD
    as the number of singular values greater than a tolerance.

    Parameters
    ----------
    A : array_like
        0-D: returns 1 if the element is truthy, else 0.
        1-D: returns 1 if any element is truthy, else 0.
        2-D or more: rank of the matrix/matrices in the last two dimensions.
    tol : None, scalar or array_like
        Singular values must exceed this to be counted. If None, the
        threshold is ``max(S) * max(m, n) * eps(R)``. For stacked input a
        non-scalar ``tol`` must have shape ``A.shape[:-2]``.
    hermitian : bool
        Forwarded to ``svd``.

    Returns
    -------
    An ``int``, or an int array of shape ``A.shape[:-2]`` for stacked input.

    Raises
    ------
    ValueError
        If m or n is zero, or if an array ``tol`` has the wrong shape.
    """
    A = asarray(A)
    dtype = A.dtype  # NOTE(review): appears unused

    if staticlen(A.shape) == 0:
        return int(bool(A.data[0]))
    elif staticlen(A.shape) == 1:
        for a in A:
            if a:
                return 1
        return 0
    elif staticlen(A.shape) == 2:
        m, n = _rows_cols(A)
        if m == 0 or n == 0:
            raise ValueError("cannot take rank of empty matrix")
        S = svd(A, compute_uv=False, hermitian=hermitian)
        R = S.dtype

        # A single matrix only accepts a scalar tolerance.
        if staticlen(asarray(tol).shape) != 0:
            compile_error("invalid tolerance dimension")

        # Largest singular value, needed only for the default tolerance.
        s_max = -inf(R)
        if tol is None:
            for s in S:
                if s > s_max:
                    s_max = s

        r = 0
        for s in S:
            if tol is None:
                if s > s_max * cast(max(m, n), R) * eps(R):
                    r += 1
            else:
                if s > cast(tol, R):
                    r += 1

        return r
    else:
        m, n = _rows_cols(A)
        if m == 0 or n == 0:
            raise ValueError("cannot take rank of empty matrix")
        k = min(m, n)  # number of singular values per matrix
        pre_shape = A.shape[:-2]
        ans = empty(pre_shape, int)
        S = svd(A, compute_uv=False, hermitian=hermitian)
        R = S.dtype

        # `t` is the tolerance as an array. It is only read when tol is not
        # None; the placeholder 0 keeps the name bound otherwise.
        if tol is None:
            t = 0
        else:
            t = asarray(tol)
            if staticlen(t.shape) > 0:
                # Per-matrix tolerances must match the stack shape exactly
                # (no general broadcasting).
                if staticlen(t.shape) != staticlen(A.shape) - 2:
                    compile_error("invalid tolerance dimension")
                else:
                    if t.shape != A.shape[:-2]:
                        raise ValueError("tolerance does not broadcast against matrix_rank input")

        for idx in multirange(pre_shape):
            q = ans._ptr(idx)
            r = 0
            s_max = -inf(R)

            if tol is None:
                for i in range(k):
                    e = S._ptr(idx + (i,))[0]
                    if e > s_max:
                        s_max = e

            for i in range(k):
                e = S._ptr(idx + (i,))[0]
                if tol is None:
                    if e > s_max * cast(max(m, n), R) * eps(R):
                        r += 1
                else:
                    if staticlen(t.shape) == 0:
                        if e > cast(t.data[0], R):
                            r += 1
                    else:
                        if e > cast(t._ptr(idx)[0], R):
                            r += 1
            q[0] = r

        return ans
|
|
|
|
def cond(x, p = None):
    """
    Condition number of a matrix, or of each matrix in a stack.

    Parameters
    ----------
    x : array_like
        Array whose last two dimensions hold the matrices.
    p : None, int, float or str
        Norm order. ``None``, ``2`` and ``-2`` (int or float, as in numpy)
        use the ratio of the extreme singular values, which also works for
        non-square input. Any other order computes
        ``norm(x, p) * norm(inv(x), p)`` through ``_inv``.

    Returns
    -------
    A scalar for a single matrix, otherwise an array over the leading
    dimensions. A NaN result is replaced by ``inf`` unless the
    corresponding input matrix itself contains a NaN.

    Raises
    ------
    LinAlgError
        If either matrix dimension is zero.
    """
    def cond_r_1(x, p: int):
        # sigma_max / sigma_min, or its reciprocal for p == -2
        s = svd(x, compute_uv=False)
        if p == -2:
            return s[..., -1] / s[..., 0]
        else:
            return s[..., 0] / s[..., -1]

    def cond_r_2(x, p):
        # ignore_errors=True: presumably lets singular matrices produce
        # inf/nan instead of raising; those values are handled below.
        invx = _inv(x, ignore_errors=True)
        return norm(x, p, axis=(-2, -1)) * norm(invx, p, axis=(-2, -1))

    x = asarray(x)
    m, n = _rows_cols(x)

    if m == 0 or n == 0:
        raise LinAlgError("cond is not defined on empty arrays")

    if p is None:
        r = cond_r_1(x, 0)
    elif isinstance(p, int) or isinstance(p, float):
        # Like numpy, both int and float orders of +/-2 go through the SVD.
        if p == 2 or p == -2:
            r = cond_r_1(x, -2 if p == -2 else 2)
        else:
            r = cond_r_2(x, p)
    else:
        r = cond_r_2(x, p)

    # Turn NaN results into inf unless the source matrix contained NaN.
    # The input is only scanned when the result is actually NaN.
    r = asarray(r)
    if staticlen(r.shape) == 0:
        # A 0-d result means x is a single 2-D matrix.
        rval = r.data[0]
        if isnan(rval):
            any_nan = False
            for i in range(m):
                for j in range(n):
                    if isnan(x._ptr((i, j))[0]):
                        any_nan = True
                        break
                if any_nan:
                    break
            if not any_nan:
                return inf(type(rval))
        return rval
    else:
        for idx in multirange(r.shape):
            rp = r._ptr(idx)
            if not isnan(rp[0]):
                continue

            any_nan = False
            for i in range(m):
                for j in range(n):
                    if isnan(x._ptr(idx + (i, j))[0]):
                        any_nan = True
                        break
                if any_nan:
                    break

            if not any_nan:
                rp[0] = inf(type(rp[0]))

        return r
|
|
|
|
def matrix_transpose(x):
    """
    Swap the last two axes of ``x``, i.e. transpose every matrix in a stack.

    Inputs with fewer than two dimensions are rejected at compile time.
    """
    arr = asarray(x)
    if staticlen(arr.shape) < 2:
        compile_error("Input array must be at least 2-dimensional")
    return swapaxes(arr, -1, -2)
|
|
|
|
@extend
class ndarray:
    # Attaches linear-algebra methods to ndarray; each one forwards its
    # arguments unchanged to a function defined in this module.

    def __matmul__(self, other):
        """Implement ``self @ other`` by delegating to ``matmul``."""
        return matmul(self, other)

    def dot(self, other):
        """Dot product of ``self`` and ``other``; delegates to ``dot``."""
        return dot(self, other)

    def trace(self, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: type = NoneType, out = None):
        """Trace of the array; all arguments are forwarded to ``_trace_ufunc``."""
        return _trace_ufunc(self, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype, out=out)
|