# Copyright (C) 2022-2025 Exaloop Inc. from .blas import * from ..ndarray import ndarray from ..ndmath import divide, greater, isnan, multiply, sqrt from ..routines import atleast_2d, asarray, broadcast_shapes, broadcast_to, empty, \ empty_like, eye, expand_dims, moveaxis, reshape, swapaxes, \ zeros, array from ..util import cast, cdiv_int, coerce, eps, exp, free, inf, log, multirange, \ nan, normalize_axis_index, sizeof, sort, sqrt as util_sqrt, \ tuple_delete, tuple_insert, tuple_range, tuple_set, zero ############# # Utilities # ############# class LinAlgError(Static[Exception]): def __init__(self, message: str = ''): super().__init__("numpy.linalg.LinAlgError", message) def _square_rows(a): s = a.shape if staticlen(s) < 2: compile_error("Array must be at least 2-dimensional") n = s[-1] if n != s[-2]: raise LinAlgError("Last 2 dimensions of the array must be square") return n def _rows_cols(a): s = a.shape if staticlen(s) < 2: compile_error("Array must be at least 2-dimensional") m, n = s[-2:] return m, n def _asarray(a, dtype: type = NoneType): if dtype is NoneType: if isinstance(a, ndarray): if (a.dtype is float or a.dtype is float32 or a.dtype is complex or a.dtype is complex64): return a else: return a.astype(float) else: dtype1 = type(asarray(a).data[0]) if (dtype1 is float or dtype1 is float32 or dtype1 is complex or dtype1 is complex64): return asarray(a, dtype=dtype1) else: return asarray(a, dtype=float) elif (dtype is float or dtype is float32 or dtype is complex or dtype is complex64): return asarray(a, dtype=dtype) else: return asarray(a, dtype=float) def _basetype(dtype: type): if dtype is float or dtype is float32: return dtype() elif dtype is complex: return float() elif dtype is complex64: return float32() else: compile_error("[internal error] bad dtype") def _complextype(dtype: type): if dtype is float: return complex() elif dtype is float32: return complex64() elif dtype is complex or dtype is complex64: return dtype() else: compile_error("[internal error] bad dtype") def _copy(n: int, sx: Ptr[T], incx: int, sy: Ptr[T], incy: int, T: type): args = (fint(n), sx, fint(incx), sy, fint(incy)) if T is float: cblas_dcopy(*args) elif T is float32: cblas_scopy(*args) elif T is complex: cblas_zcopy(*args) elif T is complex64: cblas_ccopy(*args) else: compile_error("[internal error] bad type for BLAS copy: " + T.__name__) def _nan(T: type): if T is float or T is float32: return nan(T) elif T is complex: return complex(nan(float), nan(float)) elif T is complex64: return complex64(nan(float32), nan(float32)) else: compile_error("[internal error] bad dtype") @tuple class LinearizeData: rows: int cols: int row_strides: int col_strides: int out_lead_dim: int def __new__(rows: int, cols: int, row_strides: int, col_strides: int, out_lead_dim: int) -> LinearizeData: return (rows, cols, row_strides, col_strides, out_lead_dim) def __new__(rows: int, cols: int, row_strides: int, col_strides: int): return LinearizeData(rows, cols, row_strides, col_strides, cols) def linearize(self, dst: Ptr[T], src: Ptr[T], T: type): if dst: rv = dst cols = self.cols col_strides = cdiv_int(self.col_strides, sizeof(T)) for i in range(self.rows): if col_strides > 0: _copy(cols, src, col_strides, dst, 1) elif col_strides < 0: _copy(cols, src + (cols - 1)*col_strides, col_strides, dst, 1); else: for j in range(cols): dst[j] = src[0] src += cdiv_int(self.row_strides, sizeof(T)) dst += self.out_lead_dim return rv else: return src def delinearize(self, dst: Ptr[T], src: Ptr[T], T: type): if src: rv = src cols = self.cols col_strides = cdiv_int(self.col_strides, sizeof(T)) for i in range(self.rows): if col_strides > 0: _copy(cols, src, 1, dst, col_strides); elif col_strides < 0: _copy(cols, src, 1, dst + (cols - 1)*col_strides, col_strides) else: if cols > 0: dst[0] = src[cols - 1] src += self.out_lead_dim dst += cdiv_int(self.row_strides, sizeof(T)) return rv else: return src def nan_matrix(self, dst: Ptr[T], T: type): for i in range(self.rows): cp = dst cs = cdiv_int(self.col_strides, sizeof(T)) for j in range(self.cols): cp[0] = _nan(T) cp += cs dst += cdiv_int(self.row_strides, sizeof(T)) def zero_matrix(self, dst: Ptr[T], T: type): for i in range(self.rows): cp = dst cs = cdiv_int(self.col_strides, sizeof(T)) for j in range(self.cols): cp[0] = T() cp += cs dst += cdiv_int(self.row_strides, sizeof(T)) def identity_matrix(dst: Ptr[T], n: int, T: type): str.memset(dst.as_byte(), byte(0), n * n * sizeof(T)) for i in range(n): dst[0] = T(1) dst += n + 1 ############### # Determinant # ############### def _slogdet_from_factored_diagonal(src: Ptr[T], m: int, sign: T, T: type): if T is float or T is float32: acc_sign = sign acc_logdet = T(0.0) for i in range(m): abs_element = src[0] if abs_element < T(0.0): acc_sign = -acc_sign abs_element = -abs_element acc_logdet += log(abs_element) src += m + 1 return acc_sign, acc_logdet elif T is complex or T is complex64: B = type(_basetype(T)) acc_sign = sign acc_logdet = B() for i in range(m): abs_element = abs(src[0]) sign_element = src[0] / abs_element acc_sign *= sign_element acc_logdet += log(abs_element) src += m + 1 return acc_sign, acc_logdet else: compile_error("[internal error] invalid type for _slogdet_from_factored_diagonal: " + T.__name__) def _slogdet_single_element(m: int, src: Ptr[T], pivots: Ptr[fint], T: type): m0 = fint(m) lda = fint(max(m, 1)) info = fint(0) args = (__ptr__(m0), __ptr__(m0), src.as_byte(), __ptr__(lda), pivots, __ptr__(info)) if T is float: dgetrf_(*args) elif T is float32: sgetrf_(*args) elif T is complex: zgetrf_(*args) elif T is complex64: cgetrf_(*args) else: compile_error("[internal error] bad input type") sign = T() if info == fint(0): change_sign = False for i in range(m): if pivots[i] != fint(i + 1): change_sign = not change_sign sign = T(-1.0 if change_sign else 1.0) logdet = zero(_basetype(T)) sign, logdet = _slogdet_from_factored_diagonal(src, m, sign) return sign, logdet else: return T(0.0), -inf(type(_basetype(T))) def _det_from_slogdet(sign: T, logdet, T: type): if T is float or T is float32: return sign * exp(logdet) elif T is complex or T is complex64: return sign * T(exp(logdet)) else: compile_error("[internal error] invalid type for _det_from_slogdet: " + T.__name__) @tuple class SlogdetResult[A, B]: sign: A logabsdet: B def __getitem__(self, idx: Static[int]): if idx == 0 or idx == -2: return self.sign elif idx == 1 or idx == -1: return self.logabsdet else: compile_error("tuple ('SlogdetResult') index out of range") def slogdet(a): a = _asarray(a) T = a.dtype s = a.shape m = _square_rows(a) steps = a.strides tmp = Ptr[T](m * m) piv = Ptr[fint](m) lin_data = LinearizeData(m, m, steps[-1], steps[-2]) if a.ndim == 2: lin_data.linearize(tmp, a.data) sign, logabsdet = _slogdet_single_element(m, tmp, piv) else: ans_shape = s[:-2] sign = empty(ans_shape, dtype=a.dtype) logabsdet = empty(ans_shape, dtype=type(_basetype(a.dtype))) for idx in multirange(ans_shape): lin_data.linearize(tmp, a._ptr(idx + (0, 0))) sign0, logabsdet0 = _slogdet_single_element(m, tmp, piv) sign._ptr(idx)[0] = sign0 logabsdet._ptr(idx)[0] = logabsdet0 free(tmp) free(piv) return SlogdetResult(sign, logabsdet) def det(a): a = _asarray(a) T = a.dtype s = a.shape m = _square_rows(a) steps = a.strides tmp = Ptr[T](m * m) piv = Ptr[fint](m) lin_data = LinearizeData(m, m, steps[-1], steps[-2]) if a.ndim == 2: lin_data.linearize(tmp, a.data) sign, logabsdet = _slogdet_single_element(m, tmp, piv) ans = _det_from_slogdet(sign, logabsdet) else: ans_shape = s[:-2] ans = empty(ans_shape, dtype=a.dtype) for idx in multirange(ans_shape): lin_data.linearize(tmp, a._ptr(idx + (0, 0))) sign0, logabsdet0 = _slogdet_single_element(m, tmp, piv) ans._ptr(idx)[0] = _det_from_slogdet(sign0, logabsdet0) free(tmp) free(piv) return ans ######## # Eigh # ######## class EighParams: A: Ptr[T] W: Ptr[B] WORK: Ptr[T] RWORK: Ptr[B] IWORK: Ptr[fint] N: fint LWORK: fint LRWORK: fint LIWORK: fint JOBZ: byte UPLO: byte LDA: fint T: type B: type def _init_real(self, JOBZ: byte, UPLO: byte, N: fint): safe_N = int(N) alloc_size = safe_N * (safe_N + 1) * sizeof(T) lda = N if N else fint(1) mem_buff = cobj(alloc_size) a = mem_buff w = mem_buff + safe_N * safe_N * sizeof(T) self.A = Ptr[T](a) self.W = Ptr[B](w) self.RWORK = Ptr[B]() self.N = N self.LRWORK = fint(0) self.JOBZ = JOBZ self.UPLO = UPLO self.LDA = lda # work size query query_work_size = T() query_iwork_size = fint() self.LWORK = fint(-1) self.LIWORK = fint(-1) self.WORK = __ptr__(query_work_size) self.IWORK = __ptr__(query_iwork_size) if self.call(): free(mem_buff) raise LinAlgError("error in evd work size query") lwork = cast(query_work_size, int) liwork = int(query_iwork_size) mem_buff2 = cobj(lwork * sizeof(T) + liwork * sizeof(fint)) work = mem_buff2 iwork = mem_buff2 + lwork * sizeof(T) self.LWORK = fint(lwork) self.WORK = Ptr[T](work) self.LIWORK = fint(liwork) self.IWORK = Ptr[fint](iwork) def _init_complex(self, JOBZ: byte, UPLO: byte, N: fint): safe_N = int(N) lda = N if N else fint(1) mem_buff = cobj(safe_N * safe_N * sizeof(T) + safe_N * sizeof(B)) a = mem_buff w = mem_buff + safe_N * safe_N * sizeof(T) self.A = Ptr[T](a) self.W = Ptr[B](w) self.N = N self.JOBZ = JOBZ self.UPLO = UPLO self.LDA = lda # work size query query_work_size = T() query_rwork_size = B() query_iwork_size = fint() self.LWORK = fint(-1) self.LRWORK = fint(-1) self.LIWORK = fint(-1) self.WORK = __ptr__(query_work_size) self.RWORK = __ptr__(query_rwork_size) self.IWORK = __ptr__(query_iwork_size) if self.call(): free(mem_buff) raise LinAlgError("error in evd work size query") lwork = cast(Ptr[B](__ptr__(query_work_size).as_byte())[0], int) lrwork = cast(query_rwork_size, int) liwork = int(query_iwork_size) mem_buff2 = cobj(lwork * sizeof(T) + lrwork * sizeof(B) + liwork * sizeof(fint)) work = mem_buff2 rwork = work + lwork * sizeof(T) iwork = rwork + lrwork * sizeof(B) self.WORK = Ptr[T](work) self.RWORK = Ptr[B](rwork) self.IWORK = Ptr[fint](iwork) self.LWORK = fint(lwork) self.LRWORK = fint(lrwork) self.LIWORK = fint(liwork) def __init__(self, JOBZ: byte, UPLO: byte, N: fint): if T is complex or T is complex64: self._init_complex(JOBZ, UPLO, N) else: self._init_real(JOBZ, UPLO, N) def release(self): free(self.A) free(self.WORK) def call(self): JOBZ = self.JOBZ UPLO = self.UPLO N = self.N A = self.A LDA = self.LDA W = self.W WORK = self.WORK LWORK = self.LWORK RWORK = self.RWORK LRWORK = self.LRWORK IWORK = self.IWORK LIWORK = self.LIWORK rv = fint() args_real = (__ptr__(JOBZ), __ptr__(UPLO), __ptr__(N), A.as_byte(), __ptr__(LDA), W.as_byte(), WORK.as_byte(), __ptr__(LWORK), IWORK, __ptr__(LIWORK), __ptr__(rv)) args_cplx = (__ptr__(JOBZ), __ptr__(UPLO), __ptr__(N), A.as_byte(), __ptr__(LDA), W.as_byte(), WORK.as_byte(), __ptr__(LWORK), RWORK.as_byte(), __ptr__(LRWORK), IWORK, __ptr__(LIWORK), __ptr__(rv)) if T is float: dsyevd_(*args_real) elif T is float32: ssyevd_(*args_real) elif T is complex: zheevd_(*args_cplx) elif T is complex64: cheevd_(*args_cplx) else: compile_error("[internal error] bad dtype for eigh") return rv @tuple class EighResult[A, B]: eigenvalues: A eigenvectors: B def __getitem__(self, idx: Static[int]): if idx == 0 or idx == -2: return self.eigenvalues elif idx == 1 or idx == -1: return self.eigenvectors else: compile_error("tuple ('EighResult') index out of range") def _eigh(a, JOBZ: byte, UPLO: byte, compute_eigenvectors: Static[int]): a = _asarray(a) B = type(_basetype(a.dtype)) n = _square_rows(a) params = EighParams[a.dtype, B](JOBZ, UPLO, fint(n)) eigenvalues = empty(a.shape[:-2] + (n,), dtype=B) if compute_eigenvectors: eigenvectors = empty(a.shape, dtype=a.dtype) else: eigenvectors = None matrix_in_ld = LinearizeData(n, n, a.strides[-1], a.strides[-2]) eigenvalues_out_ld = LinearizeData(1, n, 0, eigenvalues.strides[-1]) if compute_eigenvectors: eigenvectors_out_ld = LinearizeData(n, n, eigenvectors.strides[-1], eigenvectors.strides[-2]) else: eigenvectors_out_ld = None error_occured = False for idx in multirange(a.shape[:-2]): matrix_ptr = a._ptr(idx + (0, 0)) eigval_ptr = eigenvalues._ptr(idx + (0,)) if compute_eigenvectors: eigvec_ptr = eigenvectors._ptr(idx + (0, 0)) else: eigvec_ptr = None matrix_in_ld.linearize(params.A, matrix_ptr) not_ok = params.call() if not not_ok: eigenvalues_out_ld.delinearize(eigval_ptr, params.W) if compute_eigenvectors: eigenvectors_out_ld.delinearize(eigvec_ptr, params.A) else: eigenvalues_out_ld.nan_matrix(eigval_ptr) if compute_eigenvectors: eigenvectors_out_ld.nan_matrix(eigvec_ptr) error_occured = True params.release() if error_occured: raise LinAlgError("Eigenvalues did not converge") if compute_eigenvectors: return EighResult(eigenvalues, eigenvectors) else: return eigenvalues def eigh(a, UPLO: str = 'L'): uplo_code = byte() if UPLO == 'L' or UPLO == 'l': uplo_code = byte(76) elif UPLO == 'U' or UPLO == 'u': uplo_code = byte(85) else: raise ValueError("UPLO argument must be 'L' or 'U'") jobz_code = byte(86) return _eigh(a, JOBZ=jobz_code, UPLO=uplo_code, compute_eigenvectors=True) def eigvalsh(a, UPLO: str = 'L'): uplo_code = byte() if UPLO == 'L' or UPLO == 'l': uplo_code = byte(76) elif UPLO == 'U' or UPLO == 'u': uplo_code = byte(85) else: raise ValueError("UPLO argument must be 'L' or 'U'") jobz_code = byte(78) return _eigh(a, JOBZ=jobz_code, UPLO=uplo_code, compute_eigenvectors=False) ######### # Solve # ######### class GesvParams: A: Ptr[T] B: Ptr[T] IPIV: Ptr[fint] N: fint NRHS: fint LDA: fint LDB: fint T: type def __init__(self, N: fint, NRHS: fint): safe_N = int(N) safe_NRHS = int(NRHS) ld = N if N else fint(1) mem_buff = cobj(safe_N * safe_N * sizeof(T) + safe_N * safe_NRHS * sizeof(T) + safe_N * sizeof(fint)) a = mem_buff b = a + safe_N * safe_N * sizeof(T) ipiv = b + safe_N * safe_NRHS * sizeof(T) self.A = Ptr[T](a) self.B = Ptr[T](b) self.IPIV = Ptr[fint](ipiv) self.N = N self.NRHS = NRHS self.LDA = ld self.LDB = ld def release(self): free(self.A) def call(self): A = self.A B = self.B IPIV = self.IPIV N = self.N NRHS = self.NRHS LDA = self.LDA LDB = self.LDB rv = fint() args = (__ptr__(N), __ptr__(NRHS), A.as_byte(), __ptr__(LDA), IPIV, B.as_byte(), __ptr__(LDB), __ptr__(rv)) if T is float: dgesv_(*args) elif T is float32: sgesv_(*args) elif T is complex: zgesv_(*args) elif T is complex64: cgesv_(*args) else: compile_error("[internal error] bad dtype for gesv") return rv def _solve(a, b): if a.ndim == b.ndim: if a.shape[:-2] != b.shape[:-2]: pre_broadcast = broadcast_shapes(a.shape[:-2], b.shape[:-2]) a = broadcast_to(a, pre_broadcast + a.shape[-2:]) b = broadcast_to(b, pre_broadcast + b.shape[-2:]) else: pre_broadcast = broadcast_shapes(a.shape[:-2], b.shape[:-2]) a1 = broadcast_to(a, pre_broadcast + a.shape[-2:]) b1 = broadcast_to(b, pre_broadcast + b.shape[-2:]) return _solve(a1, b1) n = _square_rows(a) m, k = _rows_cols(b) if m != n: raise ValueError("solve: 'a' and 'b' don't have the same number of rows") r = empty(b.shape, dtype=a.dtype) nrhs = b.shape[-1] params = GesvParams[a.dtype](fint(n), fint(nrhs)) a_in = LinearizeData(n, n, a.strides[-1], a.strides[-2]) b_in = LinearizeData(nrhs, n, b.strides[-1], b.strides[-2]) r_out = LinearizeData(nrhs, n, r.strides[-1], r.strides[-2]) error_occured = False for idx in multirange(a.shape[:-2]): a_ptr = a._ptr(idx + (0, 0)) b_ptr = b._ptr(idx + (0, 0)) r_ptr = r._ptr(idx + (0, 0)) a_in.linearize(params.A, a_ptr) b_in.linearize(params.B, b_ptr) not_ok = params.call() if not not_ok: r_out.delinearize(r_ptr, params.B) else: r_out.nan_matrix(r_ptr) error_occured = True params.release() if error_occured: raise LinAlgError("Singular matrix") return r def _solve1(a, b): n = _square_rows(a) m = b.shape[0] if m != n: raise ValueError("solve: 'a' and 'b' don't have the same number of rows") r = empty((m,), dtype=a.dtype) params = GesvParams[a.dtype](fint(n), fint(1)) a_in = LinearizeData(n, n, a.strides[-1], a.strides[-2]) b_in = LinearizeData(1, n, 1, b.strides[0]) r_out = LinearizeData(1, n, 1, r.strides[0]) error_occured = False for idx in multirange(a.shape[:-2]): a_ptr = a._ptr(idx + (0, 0)) b_ptr = b.data r_ptr = r.data a_in.linearize(params.A, a_ptr) b_in.linearize(params.B, b_ptr) not_ok = params.call() if not not_ok: r_out.delinearize(r_ptr, params.B) else: r_out.nan_matrix(r_ptr) error_occured = True params.release() if error_occured: raise LinAlgError("Singular matrix") return r def solve(a, b): dtype = type( coerce( type(asarray(a).data[0]), type(asarray(b).data[0]))) a = _asarray(a, dtype=dtype) b = _asarray(b, dtype=dtype) if a.ndim < 2: compile_error("'a' must be at least 2-dimensional") if b.ndim == 0: compile_error("'b' must be at least 1-dimensional") if b.ndim == 1: return _solve1(a, b) else: return _solve(a, b) def _inv(a, ignore_errors: Static[int]): a = _asarray(a) n = _square_rows(a) r = empty(a.shape, dtype=a.dtype) params = GesvParams[a.dtype](fint(n), fint(n)) a_in = LinearizeData(n, n, a.strides[-1], a.strides[-2]) r_out = LinearizeData(n, n, r.strides[-1], r.strides[-2]) error_occured = False for idx in multirange(a.shape[:-2]): a_ptr = a._ptr(idx + (0, 0)) r_ptr = r._ptr(idx + (0, 0)) a_in.linearize(params.A, a_ptr) LinearizeData.identity_matrix(params.B, n) not_ok = params.call() if not not_ok: r_out.delinearize(r_ptr, params.B) else: r_out.nan_matrix(r_ptr) error_occured = True params.release() if not ignore_errors: if error_occured: raise LinAlgError("Singular matrix") return r def inv(a): return _inv(a, ignore_errors=False) ############ # Cholesky # ############ class PotrfParams: A: Ptr[T] N: fint LDA: fint UPLO: byte T: type def __init__(self, UPLO: byte, N: fint): safe_N = int(N) lda = N if N else fint(1) mem_buff = cobj(safe_N * safe_N * sizeof(T)) a = mem_buff self.A = Ptr[T](a) self.N = N self.LDA = lda self.UPLO = UPLO def zero_lower_triangle(self): n = int(self.N) matrix = self.A for i in range(n - 1): for j in range(i + 1, n): matrix[j] = zero(T) matrix += n def zero_upper_triangle(self): n = int(self.N) matrix = self.A matrix += n for i in range(1, n): for j in range(i): matrix[j] = zero(T) matrix += n def call(self): UPLO = self.UPLO N = self.N A = self.A LDA = self.LDA rv = fint() args = (__ptr__(UPLO), __ptr__(N), A.as_byte(), __ptr__(LDA), __ptr__(rv)) if T is float: dpotrf_(*args) elif T is float32: spotrf_(*args) elif T is complex: zpotrf_(*args) elif T is complex64: cpotrf_(*args) else: compile_error("[internal error] bad dtype for potrf") return rv def release(self): free(self.A) def cholesky(a, upper: bool = False): a = _asarray(a) uplo = byte(85 if upper else 76) # 'U' / 'L' n = _square_rows(a) params = PotrfParams[a.dtype](uplo, fint(n)) r = empty(a.shape, dtype=a.dtype) a_in = LinearizeData(n, n, a.strides[-1], a.strides[-2]) r_out = LinearizeData(n, n, r.strides[-1], r.strides[-2]) error_occured = False for idx in multirange(a.shape[:-2]): a_ptr = a._ptr(idx + (0, 0)) r_ptr = r._ptr(idx + (0, 0)) a_in.linearize(params.A, a_ptr) not_ok = params.call() if not not_ok: if upper: params.zero_lower_triangle() else: params.zero_upper_triangle() r_out.delinearize(r_ptr, params.A) else: error_occured = True r_out.nan_matrix(r_ptr) params.release() if error_occured: raise LinAlgError("Matrix is not positive definite") return r ####### # Eig # ####### def _mk_complex_array_from_real(c: Ptr[C], re: Ptr[R], n: int, C: type, R: type): for i in range(n): c[i] = C(re[i], zero(R)) def _mk_complex_array(c: Ptr[C], re: Ptr[R], im: Ptr[R], n: int, C: type, R: type): for i in range(n): c[i] = C(re[i], im[i]) def _mk_complex_array_conjugate_pair(c: Ptr[C], r: Ptr[R], n: int, C: type, R: type): for i in range(n): re = r[i] im = r[i + n] c[i] = C(re, im) c[i + n] = C(re, -im) def _mk_geev_complex_eigenvectors(c: Ptr[C], r: Ptr[R], i: Ptr[R], n: int, C: type, R: type): it = 0 while it < n: if i[it] == zero(R): _mk_complex_array_from_real(c, r, n) c += n r += n it += 1 else: _mk_complex_array_conjugate_pair(c, r, n) c += 2 * n r += 2 * n it += 2 class GeevParams: A: Ptr[T] WR: Ptr[B] WI: Ptr[T] VLR: Ptr[T] VRR: Ptr[T] WORK: Ptr[T] W: Ptr[T] VL: Ptr[T] VR: Ptr[T] N: fint LDA: fint LDVL: fint LDVR: fint LWORK: fint JOBVL: byte JOBVR: byte T: type B: type def _init_real(self, jobvl: byte, jobvr: byte, n: fint): safe_n = int(n) a_size = safe_n * safe_n * sizeof(T) wr_size = safe_n * sizeof(T) wi_size = safe_n * sizeof(T) vlr_size = safe_n * safe_n * sizeof(T) if jobvl == byte(86) else 0 vrr_size = safe_n * safe_n * sizeof(T) if jobvr == byte(86) else 0 w_size = wr_size * 2 vl_size = vlr_size * 2 vr_size = vrr_size * 2 ld = n if n else fint(1) mem_buff = cobj(a_size + wr_size + wi_size + vlr_size + vrr_size + w_size + vl_size + vr_size) a = mem_buff wr = a + a_size wi = wr + wr_size vlr = wi + wi_size vrr = vlr + vlr_size w = vrr + vrr_size vl = w + w_size vr = vl + vl_size self.A = Ptr[T](a) self.WR = Ptr[T](wr) self.WI = Ptr[T](wi) self.VLR = Ptr[T](vlr) self.VRR = Ptr[T](vrr) self.W = Ptr[T](w) self.VL = Ptr[T](vl) self.VR = Ptr[T](vr) self.N = n self.LDA = ld self.LDVL = ld self.LDVR = ld self.JOBVL = jobvl self.JOBVR = jobvr # work size query work_size_query = T() self.LWORK = fint(-1) self.WORK = __ptr__(work_size_query) if self.call(): free(mem_buff) raise LinAlgError("error in evd work size query") work_count = cast(work_size_query, int) mem_buff2 = cobj(work_count * sizeof(T)) work = mem_buff2 self.LWORK = fint(work_count) self.WORK = Ptr[T](work) def _init_complex(self, jobvl: byte, jobvr: byte, n: fint): safe_n = int(n) a_size = safe_n * safe_n * sizeof(T) w_size = safe_n * sizeof(T) vl_size = safe_n * safe_n * sizeof(T) if jobvl == byte(86) else 0 vr_size = safe_n * safe_n * sizeof(T) if jobvr == byte(86) else 0 rwork_size = 2 * safe_n * sizeof(B) total_size = a_size + w_size + vl_size + vr_size + rwork_size ld = n if n else fint(1) mem_buff = cobj(total_size) a = mem_buff w = a + a_size vl = w + w_size vr = vl + vl_size rwork = vr + vr_size self.A = Ptr[T](a) self.WR = Ptr[B](rwork) self.WI = Ptr[T]() self.VLR = Ptr[T]() self.VRR = Ptr[T]() self.VL = Ptr[T](vl) self.VR = Ptr[T](vr) self.W = Ptr[T](w) self.N = n self.LDA = ld self.LDVL = ld self.LDVR = ld self.JOBVL = jobvl self.JOBVR = jobvr # work size query work_size_query = T() self.LWORK = fint(-1) self.WORK = __ptr__(work_size_query) if self.call(): free(mem_buff) raise LinAlgError("error in evd work size query") work_count = cast(work_size_query.real, int) if work_count == 0: work_count = 1 mem_buff2 = cobj(work_count * sizeof(T)) work = mem_buff2 self.LWORK = fint(work_count) self.WORK = Ptr[T](work) def __init__(self, jobvl: byte, jobvr: byte, n: fint): if T is complex or T is complex64: self._init_complex(jobvl, jobvr, n) else: self._init_real(jobvl, jobvr, n) def call(self): A = self.A WR = self.WR WI = self.WI VLR = self.VLR VRR = self.VRR WORK = self.WORK W = self.W VL = self.VL VR = self.VR N = self.N LDA = self.LDA LDVL = self.LDVL LDVR = self.LDVR LWORK = self.LWORK JOBVL = self.JOBVL JOBVR = self.JOBVR rv = fint() args_real = (__ptr__(JOBVL), __ptr__(JOBVR), __ptr__(N), A.as_byte(), __ptr__(LDA), WR.as_byte(), WI.as_byte(), VLR.as_byte(), __ptr__(LDVL), VRR.as_byte(), __ptr__(LDVR), WORK.as_byte(), __ptr__(LWORK), __ptr__(rv)) args_cplx = (__ptr__(JOBVL), __ptr__(JOBVR), __ptr__(N), A.as_byte(), __ptr__(LDA), W.as_byte(), VL.as_byte(), __ptr__(LDVL), VR.as_byte(), __ptr__(LDVR), WORK.as_byte(), __ptr__(LWORK), WR.as_byte(), __ptr__(rv)) if T is float: dgeev_(*args_real) elif T is float32: sgeev_(*args_real) elif T is complex: zgeev_(*args_cplx) elif T is complex64: cgeev_(*args_cplx) else: compile_error("[internal error] bad dtype for geev") return rv def release(self): free(self.WORK) free(self.A) def process_geev_results(self): if T is complex or T is complex64: return C = type(_complextype(T)) _mk_complex_array(Ptr[C](self.W.as_byte()), self.WR, self.WI, int(self.N)) if self.JOBVL == byte(86): _mk_geev_complex_eigenvectors(Ptr[C](self.VL.as_byte()), self.VLR, self.WI, int(self.N)) if self.JOBVR == byte(86): _mk_geev_complex_eigenvectors(Ptr[C](self.VR.as_byte()), self.VRR, self.WI, int(self.N)) @tuple class EigResult[A, B]: eigenvalues: A eigenvectors: B def __getitem__(self, idx: Static[int]): if idx == 0 or idx == -2: return self.eigenvalues elif idx == 1 or idx == -1: return self.eigenvectors else: compile_error("tuple ('EigResult') index out of range") def _eig(a, JOBVL: byte, JOBVR: byte, compute_eigenvectors: Static[int]): a = _asarray(a) B = type(_basetype(a.dtype)) C = type(_complextype(a.dtype)) n = _square_rows(a) params = GeevParams[a.dtype, B](JOBVL, JOBVR, fint(n)) eigenvalues = empty(a.shape[:-2] + (n,), dtype=C) if compute_eigenvectors: eigenvectors = empty(a.shape, dtype=C) else: eigenvectors = None a_in = LinearizeData(n, n, a.strides[-1], a.strides[-2]) w_out = LinearizeData(1, n, 0, eigenvalues.strides[-1]) if compute_eigenvectors: vr_out = LinearizeData(n, n, eigenvectors.strides[-1], eigenvectors.strides[-2]) else: vr_out = None error_occured = False for idx in multirange(a.shape[:-2]): matrix_ptr = a._ptr(idx + (0, 0)) eigval_ptr = eigenvalues._ptr(idx + (0,)) if compute_eigenvectors: eigvec_ptr = eigenvectors._ptr(idx + (0, 0)) else: eigvec_ptr = None a_in.linearize(params.A, matrix_ptr) not_ok = params.call() if not not_ok: params.process_geev_results() w_out.delinearize(eigval_ptr, Ptr[C](params.W.as_byte())) if compute_eigenvectors: vr_out.delinearize(eigvec_ptr, Ptr[C](params.VR.as_byte())) else: w_out.nan_matrix(eigval_ptr) if compute_eigenvectors: vr_out.nan_matrix(eigvec_ptr) error_occured = True params.release() if error_occured: raise LinAlgError("Eigenvalues did not converge") if compute_eigenvectors: return EigResult(eigenvalues, eigenvectors) else: return eigenvalues def eig(a): jobvl_code = byte(78) # 'N' jobvr_code = byte(86) # 'V' return _eig(a, jobvl_code, jobvr_code, compute_eigenvectors=True) def eigvals(a): jobvl_code = byte(78) # 'N' jobvr_code = byte(78) # 'N' return _eig(a, jobvl_code, jobvr_code, compute_eigenvectors=False) ####### # SVD # ####### def _compute_urows_vtcolumns(jobz: byte, m: fint, n: fint): if jobz == byte(78): # 'N' return (fint(0), fint(0)) elif jobz == byte(65): # 'A' return (m, n) else: # 'S' min_m_n = m if m < n else n return (min_m_n, min_m_n) class GesddParams: A: Ptr[T] S: Ptr[B] U: Ptr[T] VT: Ptr[T] WORK: Ptr[T] RWORK: Ptr[B] IWORK: Ptr[fint] M: fint N: fint LDA: fint LDU: fint LDVT: fint LWORK: fint JOBZ: byte T: type B: type def _init_real(self, jobz: byte, m: fint, n: fint): safe_m = int(m) safe_n = int(n) a_size = safe_m * safe_n * sizeof(T) min_m_n = m if m < n else n safe_min_m_n = int(min_m_n) s_size = safe_min_m_n * sizeof(T) iwork_size = 8 * safe_min_m_n * sizeof(fint) ld = m if m else fint(1) u_row_count, vt_column_count = _compute_urows_vtcolumns(jobz, m, n) safe_u_row_count = int(u_row_count) safe_vt_column_count = int(vt_column_count) u_size = safe_u_row_count * safe_m * sizeof(T) vt_size = safe_n * safe_vt_column_count * sizeof(T) mem_buff = cobj(a_size + s_size + u_size + vt_size + iwork_size) a = mem_buff s = a + a_size u = s + s_size vt = u + u_size iwork = vt + vt_size vt_column_count = vt_column_count if vt_column_count else fint(1) self.M = m self.N = n self.A = Ptr[T](a) self.S = Ptr[B](s) self.U = Ptr[T](u) self.VT = Ptr[T](vt) self.RWORK = Ptr[B]() self.IWORK = Ptr[fint](iwork) self.LDA = ld self.LDU = ld self.LDVT = vt_column_count self.JOBZ = jobz # work size query work_size_query = T() self.LWORK = fint(-1) self.WORK = __ptr__(work_size_query) if self.call(): free(mem_buff) raise LinAlgError("error in gesdd work size query") work_count = cast(work_size_query, int) if work_count == 0: work_count = 1 work_size = work_count * sizeof(T) mem_buff2 = cobj(work_size) work = mem_buff2 self.LWORK = fint(work_count) self.WORK = Ptr[T](work) def _init_complex(self, jobz: byte, m: fint, n: fint): safe_m = int(m) safe_n = int(n) min_m_n = m if m < n else n safe_min_m_n = int(min_m_n) ld = m if m else fint(1) u_row_count, vt_column_count = _compute_urows_vtcolumns(jobz, m, n) safe_u_row_count = int(u_row_count) safe_vt_column_count = int(vt_column_count) a_size = safe_m * safe_n * sizeof(T) s_size = safe_min_m_n * sizeof(B) u_size = safe_u_row_count * safe_m * sizeof(T) vt_size = safe_n * safe_vt_column_count * sizeof(T) rwork_size = (7 * safe_min_m_n if jobz == byte(78) else 5*safe_min_m_n * safe_min_m_n + 5*safe_min_m_n) rwork_size *= sizeof(T) iwork_size = 8 * safe_min_m_n * sizeof(fint) mem_buff = cobj(a_size + s_size + u_size + vt_size + rwork_size + iwork_size) a = mem_buff s = a + a_size u = s + s_size vt = u + u_size rwork = vt + vt_size iwork = rwork + rwork_size vt_column_count = vt_column_count if vt_column_count else fint(1) self.A = Ptr[T](a) self.S = Ptr[B](s) self.U = Ptr[T](u) self.VT = Ptr[T](vt) self.RWORK = Ptr[B](rwork) self.IWORK = Ptr[fint](iwork) self.M = m self.N = n self.LDA = ld self.LDU = ld self.LDVT = vt_column_count self.JOBZ = jobz # work size query work_size_query = T() self.LWORK = fint(-1) self.WORK = __ptr__(work_size_query) if self.call(): free(mem_buff) raise LinAlgError("error in gesdd work size query") work_count = cast(Ptr[B](__ptr__(work_size_query).as_byte())[0], int) if work_count == 0: work_count = 1 work_size = work_count * sizeof(T) mem_buff2 = cobj(work_size) work = mem_buff2 self.LWORK = fint(work_count) self.WORK = Ptr[T](work) def __init__(self, jobz: byte, m: fint, n: fint): if T is complex or T is complex64: self._init_complex(jobz, m, n) else: self._init_real(jobz, m, n) def call(self): A = self.A S = self.S U = self.U VT = self.VT WORK = self.WORK RWORK = self.RWORK IWORK = self.IWORK M = self.M N = self.N LDA = self.LDA LDU = self.LDU LDVT = self.LDVT LWORK = self.LWORK JOBZ = self.JOBZ rv = fint() args_real = (__ptr__(JOBZ), __ptr__(M), __ptr__(N), A.as_byte(), __ptr__(LDA), S.as_byte(), U.as_byte(), __ptr__(LDU), VT.as_byte(), __ptr__(LDVT), WORK.as_byte(), __ptr__(LWORK), IWORK, __ptr__(rv)) args_cplx = (__ptr__(JOBZ), __ptr__(M), __ptr__(N), A.as_byte(), __ptr__(LDA), S.as_byte(), U.as_byte(), __ptr__(LDU), VT.as_byte(), __ptr__(LDVT), WORK.as_byte(), __ptr__(LWORK), RWORK.as_byte(), IWORK, __ptr__(rv)) if T is float: dgesdd_(*args_real) elif T is float32: sgesdd_(*args_real) elif T is complex: zgesdd_(*args_cplx) elif T is complex64: cgesdd_(*args_cplx) else: compile_error("[internal error] bad dtype for gesdd") return rv def release(self): free(self.A) free(self.WORK) @tuple class SVDResult[A1, A2]: U: A1 S: A2 Vh: A1 def __getitem__(self, idx: Static[int]): if idx == 0 or idx == -3: return self.U elif idx == 1 or idx == -2: return self.S elif idx == 2 or idx == -1: return self.Vh else: compile_error("tuple ('SVDResult') index out of range") def _svd(a, JOBZ: byte, compute_uv: Static[int]): B = type(_basetype(a.dtype)) m, n = _rows_cols(a) min_m_n = min(m, n) params = GesddParams[a.dtype, B](JOBZ, fint(m), fint(n)) S = empty(a.shape[:-2] + (min_m_n,), dtype=B) a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2]) s_out = LinearizeData(1, min_m_n, 0, S.strides[-1]) if compute_uv: if JOBZ == byte(83): # 'S' u_columns = min_m_n v_rows = min_m_n else: u_columns = m v_rows = n U = empty(a.shape[:-2] + (m, u_columns), dtype=a.dtype) V = empty(a.shape[:-2] + (v_rows, n), dtype=a.dtype) u_out = LinearizeData(u_columns, m, U.strides[-1], U.strides[-2]) v_out = LinearizeData(n, v_rows, V.strides[-1], V.strides[-2]) else: U = None V = None u_out = None v_out = None error_occured = False for idx in multirange(a.shape[:-2]): a_ptr = a._ptr(idx + (0, 0)) s_ptr = S._ptr(idx + (0,)) if compute_uv: u_ptr = U._ptr(idx + (0, 0)) v_ptr = V._ptr(idx + (0, 0)) else: u_ptr = None v_ptr = None a_in.linearize(params.A, a_ptr) not_ok = params.call() if not not_ok: if compute_uv: if JOBZ == byte(65) and min_m_n == 0: LinearizeData.identity_matrix(params.U, m) LinearizeData.identity_matrix(params.VT, n) u_out.delinearize(u_ptr, params.U) s_out.delinearize(s_ptr, params.S) v_out.delinearize(v_ptr, params.VT) else: s_out.delinearize(s_ptr, params.S) else: if compute_uv: u_out.nan_matrix(u_ptr) s_out.nan_matrix(s_ptr) v_out.nan_matrix(v_ptr) else: s_out.nan_matrix(s_ptr) error_occured = True params.release() if error_occured: raise LinAlgError("SVD did not converge") if compute_uv: return SVDResult(U, S, V) else: return S def svd(a, full_matrices: bool = True, compute_uv: Static[int] = True, hermitian: bool = False): a = _asarray(a) dtype = a.dtype m, n = _rows_cols(a) k = min(m, n) if hermitian: pre_shape = a.shape[:-2] s_shape = pre_shape + (k,) u_shape = pre_shape + ((m, m) if full_matrices else (m, k)) v_shape = pre_shape + ((n, n) if full_matrices else (k, n)) if compute_uv: e_s, e_u = eigh(a) # note that `k == m == n` in this case sidx = Ptr[int](k) if dtype is complex: SH = empty(s_shape, float) elif dtype is complex64: SH = empty(s_shape, float32) else: SH = empty(s_shape, dtype) UH = empty(u_shape, dtype) VH = empty(v_shape, dtype) if a.ndim == 2: for r in range(k): sidx[r] = r sort(sidx, k, key=lambda x: -abs(e_s.data[x])) for j in range(k): mapping = sidx[j] s_elem = e_s.data[mapping] negative = s_elem < type(s_elem)() SH.data[j] = abs(s_elem) for i in range(k): u_elem = e_u._ptr((i, mapping))[0] dst = UH._ptr((i, j)) dst[0] = u_elem dst = VH._ptr((j, i)) # transpose if negative: u_elem = -u_elem if hasattr(u_elem, "conjugate"): u_elem = u_elem.conjugate() dst[0] = u_elem else: for idx in multirange(pre_shape): sub_e_s = e_s._ptr(idx + (0,)) sub_e_u = e_u._ptr(idx + (0, 0)) sub_SH = SH._ptr(idx + (0,)) sub_UH = UH._ptr(idx + (0, 0)) sub_VH = VH._ptr(idx + (0, 0)) for r in range(k): sidx[r] = r sort(sidx, k, key=lambda x: -abs(sub_e_s[x])) for j in range(k): mapping = sidx[j] s_elem = sub_e_s[mapping] negative = s_elem < type(s_elem)() sub_SH[j] = abs(s_elem) for i in range(k): u_elem = (sub_e_u + (i * k + mapping))[0] dst = sub_UH + (i * k + j) dst[0] = u_elem dst = sub_VH + (j * k + i) # transpose if negative: u_elem = -u_elem if hasattr(u_elem, "conjugate"): u_elem = u_elem.conjugate() dst[0] = u_elem return SVDResult(UH, SH, VH) else: evh = eigvalsh(a) if a.ndim == 2: evh = eigvalsh(a) for i in range(k): evh.data[i] = abs(evh.data[i]) sort(evh.data, k, key=lambda x: -x) else: for idx in multirange(pre_shape): sub_evh = evh._ptr(idx + (0,)) for i in range(k): sub_evh[i] = abs(sub_evh[i]) sort(sub_evh, k, key=lambda x: -x) return evh if compute_uv: jobz_code = byte(65 if full_matrices else 83) # 'A' / 'S' return _svd(a, jobz_code, compute_uv=True) else: jobz_code = byte(78) # 'N' return _svd(a, jobz_code, compute_uv=False) def svdvals(x): return svd(x, compute_uv=False, hermitian=False) ###### # QR # ###### class GeqrfParams: M: fint N: fint A: Ptr[T] LDA: fint TAU: Ptr[T] WORK: Ptr[T] LWORK: fint T: type def _init_real(self, m: fint, n: fint): min_m_n = m if m < n else n safe_min_m_n = int(min_m_n) safe_m = int(m) safe_n = int(n) a_size = safe_m * safe_n * sizeof(T) tau_size = safe_min_m_n * sizeof(T) lda = m if m else fint(1) mem_buff = cobj(a_size + tau_size) a = mem_buff tau = a + a_size str.memset(tau.as_byte(), byte(0), tau_size) self.M = m self.N = n self.A = Ptr[T](a) self.TAU = Ptr[T](tau) self.LDA = lda # work size query work_size_query = T() self.WORK = __ptr__(work_size_query) self.LWORK = fint(-1) if self.call(): free(mem_buff) raise LinAlgError("error in geqrf work size query") work_count = cast(work_size_query, fint) lw = n if n else fint(1) lw = lw if lw > work_count else work_count self.LWORK = lw work_size = int(lw) * sizeof(T) mem_buff2 = cobj(work_size) work = mem_buff2 self.WORK = Ptr[T](work) def _init_complex(self, m: fint, n: fint): min_m_n = m if m < n else n safe_min_m_n = int(min_m_n) safe_m = int(m) safe_n = int(n) a_size = safe_m * safe_n * sizeof(T) tau_size = safe_min_m_n * sizeof(T) lda = m if m else fint(1) mem_buff = cobj(a_size + tau_size) a = mem_buff tau = a + a_size str.memset(tau.as_byte(), byte(0), tau_size) self.M = m self.N = n self.A = Ptr[T](a) self.TAU = Ptr[T](tau) self.LDA = lda # work size query work_size_query = T() self.WORK = __ptr__(work_size_query) self.LWORK = fint(-1) if self.call(): free(mem_buff) raise LinAlgError("error in geqrf work size query") work_count = cast(work_size_query.real, fint) lw = n if n else fint(1) lw = lw if lw > work_count else work_count self.LWORK = lw work_size = int(lw) * sizeof(T) mem_buff2 = cobj(work_size) work = mem_buff2 self.WORK = Ptr[T](work) def __init__(self, m: fint, n: fint): if T is complex or T is complex64: self._init_complex(m, n) else: self._init_real(m, n) def call(self): M = self.M N = self.N A = self.A LDA = self.LDA TAU = self.TAU WORK = self.WORK LWORK = self.LWORK rv = fint() args = (__ptr__(M), __ptr__(N), A.as_byte(), __ptr__(LDA), TAU.as_byte(), WORK.as_byte(), __ptr__(LWORK), __ptr__(rv)) if T is float: dgeqrf_(*args) elif T is float32: sgeqrf_(*args) elif T is complex: zgeqrf_(*args) elif T is complex64: cgeqrf_(*args) else: compile_error("[internal error] bad dtype for geqrf") return rv def release(self): free(self.A) free(self.WORK) def _qr_r_raw(a): m, n = _rows_cols(a) params = GeqrfParams[a.dtype](fint(m), fint(n)) k = min(m, n) tau = empty(a.shape[:-2] + (k,), dtype=a.dtype) a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2]) tau_out = LinearizeData(1, k, 1, tau.strides[-1]) error_occured = False for idx in multirange(a.shape[:-2]): a_ptr = a._ptr(idx + (0, 0)) tau_ptr = tau._ptr(idx + (0,)) a_in.linearize(params.A, a_ptr) not_ok = params.call() if not not_ok: a_in.delinearize(a_ptr, params.A) tau_out.delinearize(tau_ptr, params.TAU) else: tau_out.nan_matrix(tau_ptr) error_occured = True params.release() if error_occured: raise LinAlgError("Incorrect argument found while performing " "QR factorization") return tau class GqrParams: M: fint MC: fint MN: fint A: Ptr[T] Q: Ptr[T] LDA: fint TAU: Ptr[T] WORK: Ptr[T] LWORK: fint T: type def _init_real(self, m: fint, n: fint, mc: fint): min_m_n = m if m < n else n safe_mc = int(mc) safe_min_m_n = int(min_m_n) safe_m = int(m) safe_n = int(n) a_size = safe_m * safe_n * sizeof(T) q_size = safe_m * safe_mc * sizeof(T) tau_size = safe_min_m_n * sizeof(T) lda = m if m else fint(1) mem_buff = cobj(q_size + tau_size + a_size) q = mem_buff tau = q + q_size a = tau + tau_size self.M = m self.MC = mc self.MN = min_m_n self.A = Ptr[T](a) self.Q = Ptr[T](q) self.TAU = Ptr[T](tau) self.LDA = lda # work size query work_size_query = T() self.WORK = __ptr__(work_size_query) self.LWORK = fint(-1) if self.call(): free(mem_buff) raise LinAlgError("error in gqr work size query") work_count = cast(work_size_query, fint) lw = n if n else fint(1) lw = lw if lw > work_count else work_count self.LWORK = lw work_size = int(lw) * sizeof(T) mem_buff2 = cobj(work_size) work = mem_buff2 self.WORK = Ptr[T](work) def _init_complex(self, m: fint, n: fint, mc: fint): min_m_n = m if m < n else n safe_mc = int(mc) safe_min_m_n = int(min_m_n) safe_m = int(m) safe_n = int(n) a_size = safe_m * safe_n * sizeof(T) q_size = safe_m * safe_mc * sizeof(T) tau_size = safe_min_m_n * sizeof(T) lda = m if m else fint(1) mem_buff = cobj(q_size + tau_size + a_size) q = mem_buff tau = q + q_size a = tau + tau_size self.M = m self.MC = mc self.MN = min_m_n self.A = Ptr[T](a) self.Q = Ptr[T](q) self.TAU = Ptr[T](tau) self.LDA = lda # work size query work_size_query = T() self.WORK = __ptr__(work_size_query) self.LWORK = fint(-1) if self.call(): free(mem_buff) raise LinAlgError("error in gqr work size query") work_count = cast(work_size_query.real, fint) lw = n if n else fint(1) lw = lw if lw > work_count else work_count self.LWORK = lw work_size = int(lw) * sizeof(T) mem_buff2 = cobj(work_size) work = mem_buff2 self.WORK = Ptr[T](work) def __init__(self, m: fint, n: fint, mc: fint): if T is complex or T is complex64: self._init_complex(m, n, mc) else: self._init_real(m, n, mc) def __init__(self, m: fint, n: fint): self.__init__(m, n, m if m < n else n) def call(self): M = self.M MC = self.MC MN = self.MN A = self.A Q = self.Q LDA = self.LDA TAU = self.TAU WORK = self.WORK LWORK = self.LWORK rv = fint() args = (__ptr__(M), __ptr__(MC), __ptr__(MN), Q.as_byte(), __ptr__(LDA), TAU.as_byte(), WORK.as_byte(), __ptr__(LWORK), __ptr__(rv)) if T is float: dorgqr_(*args) elif T is float32: sorgqr_(*args) elif T is complex: zungqr_(*args) elif T is complex64: cungqr_(*args) else: compile_error("[internal error] bad dtype for geqrf") return rv def release(self): free(self.Q) free(self.WORK) def _qr_reduced(a, tau): m, n = _rows_cols(a) params = GqrParams[a.dtype](fint(m), fint(n)) k = min(m, n) q = empty(a.shape[:-2] + (m, k), dtype=a.dtype) a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2]) tau_in = LinearizeData(1, k, 1, tau.strides[-1]) q_out = LinearizeData(k, m, q.strides[-1], q.strides[-2]) error_occured = False for idx in multirange(a.shape[:-2]): a_ptr = a._ptr(idx + (0, 0)) tau_ptr = tau._ptr(idx + (0,)) q_ptr = q._ptr(idx + (0, 0)) a_in.linearize(params.A, a_ptr) a_in.linearize(params.Q, a_ptr) tau_in.linearize(params.TAU, tau_ptr) not_ok = params.call() if not not_ok: q_out.delinearize(q_ptr, params.Q) else: q_out.nan_matrix(q_ptr) error_occured = True params.release() if error_occured: raise LinAlgError("Incorrect argument found while performing " "QR factorization") return q def _qr_complete(a, tau): m, n = _rows_cols(a) params = GqrParams[a.dtype](fint(m), fint(n), fint(m)) k = min(m, n) q = empty(a.shape[:-2] + (m, m), dtype=a.dtype) a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2]) tau_in = LinearizeData(1, k, 1, tau.strides[-1]) q_out = LinearizeData(m, m, q.strides[-1], q.strides[-2]) error_occured = False for idx in multirange(a.shape[:-2]): a_ptr = a._ptr(idx + (0, 0)) tau_ptr = tau._ptr(idx + (0,)) q_ptr = q._ptr(idx + (0, 0)) a_in.linearize(params.A, a_ptr) a_in.linearize(params.Q, a_ptr) tau_in.linearize(params.TAU, tau_ptr) not_ok = params.call() if not not_ok: q_out.delinearize(q_ptr, params.Q) else: q_out.nan_matrix(q_ptr) error_occured = True params.release() if error_occured: raise LinAlgError("Incorrect argument found while performing " "QR factorization") return q @tuple class QRResult[A]: Q: A R: A def __getitem__(self, idx: Static[int]): if idx == 0 or idx == -2: return self.Q elif idx == 1 or idx == -1: return self.R else: compile_error("tuple ('QRResult') index out of range") def _triu(a): m = a.shape[-2] n = a.shape[-1] sm = a.strides[-2] sn = a.strides[-1] for idx in multirange(a.shape[:-2]): a_ptr = a._ptr(idx + (0, 0)) for i in range(1, m): for j in range(min(i, n)): p = a_ptr.as_byte() + i*sm + j*sn Ptr[a.dtype](p)[0] = zero(a.dtype) def qr(a, mode: Static[str] = 'reduced'): if (mode != 'reduced' and mode != 'complete' and mode != 'r' and mode != 'raw'): compile_error("Unrecognized mode '" + mode + "'") a0 = a a = _asarray(a) # copy if _asarray() didn't do so already if isinstance(a0, ndarray): if a0.dtype is a.dtype: a = a.copy() tau = _qr_r_raw(a) m, n = _rows_cols(a) mn = min(m, n) if mode == 'r': r = a[..., :mn, :] _triu(r) return r if mode == 'raw': q = a.T return q, tau if mode == 'complete' and m > n: mc = m q = _qr_complete(a, tau) else: mc = mn q = _qr_reduced(a, tau) r = a[..., :mc, :] _triu(a) return QRResult(q, r) ################# # Least Squares # ################# class GelsdParams: M: fint N: fint NRHS: fint A: Ptr[T] LDA: fint B_: Ptr[T] LDB: fint S: Ptr[B] RCOND: Ptr[B] RANK: fint WORK: Ptr[T] LWORK: fint RWORK: Ptr[B] IWORK: Ptr[fint] T: type B: type def _init_real(self, m: fint, n: fint, nrhs: fint): min_m_n = m if m < n else n max_m_n = m if m > n else n safe_min_m_n = int(min_m_n) safe_max_m_n = int(max_m_n) safe_m = int(m) safe_n = int(n) safe_nrhs = int(nrhs) a_size = safe_m * safe_n * sizeof(T) b_size = safe_max_m_n * safe_nrhs * sizeof(T) s_size = safe_min_m_n * sizeof(T) lda = m if m else fint(1) ldb = max_m_n if max_m_n else fint(1) msize = a_size + b_size + s_size mem_buff = cobj(msize if msize else 1) a = mem_buff b = a + a_size s = b + b_size self.M = m self.N = n self.NRHS = nrhs self.A = Ptr[T](a) self.B_ = Ptr[T](b) self.S = Ptr[B](s) self.LDA = lda self.LDB = ldb # work size query work_size_query = T() iwork_size_query = fint() self.WORK = __ptr__(work_size_query) self.IWORK = __ptr__(iwork_size_query) self.RWORK = Ptr[B]() self.LWORK = fint(-1) if self.call(): free(mem_buff) raise LinAlgError("error in gelsd work size query") work_count = cast(work_size_query, fint) work_size = int(work_count) * sizeof(T) iwork_size = int(iwork_size_query) * sizeof(fint) mem_buff2 = cobj(work_size + iwork_size) work = mem_buff2 iwork = work + work_size self.WORK = Ptr[T](work) self.RWORK = Ptr[B]() self.IWORK = Ptr[fint](iwork) self.LWORK = work_count def _init_complex(self, m: fint, n: fint, nrhs: fint): min_m_n = m if m < n else n max_m_n = m if m > n else n safe_min_m_n = int(min_m_n) safe_max_m_n = int(max_m_n) safe_m = int(m) safe_n = int(n) safe_nrhs = int(nrhs) a_size = safe_m * safe_n * sizeof(T) b_size = safe_max_m_n * safe_nrhs * sizeof(T) s_size = safe_min_m_n * sizeof(T) lda = m if m else fint(1) ldb = max_m_n if max_m_n else fint(1) msize = a_size + b_size + s_size mem_buff = cobj(msize if msize else 1) a = mem_buff b = a + a_size s = b + b_size self.M = m self.N = n self.NRHS = nrhs self.A = Ptr[T](a) self.B_ = Ptr[T](b) self.S = Ptr[B](s) self.LDA = lda self.LDB = ldb # work size query work_size_query = T() rwork_size_query = B() iwork_size_query = fint() self.WORK = __ptr__(work_size_query) self.IWORK = __ptr__(iwork_size_query) self.RWORK = __ptr__(rwork_size_query) self.LWORK = fint(-1) if self.call(): free(mem_buff) raise LinAlgError("error in gelsd work size query") work_count = cast(work_size_query.real, fint) work_size = int(work_count) * sizeof(T) rwork_size = cast(rwork_size_query, int) * sizeof(B) iwork_size = int(iwork_size_query) * sizeof(fint) mem_buff2 = cobj(work_size + rwork_size + iwork_size) work = mem_buff2 rwork = work + work_size iwork = rwork + rwork_size self.WORK = Ptr[T](work) self.RWORK = Ptr[B](rwork) self.IWORK = Ptr[fint](iwork) self.LWORK = work_count def __init__(self, m: fint, n: fint, nrhs: fint): if T is complex or T is complex64: self._init_complex(m, n, nrhs) else: self._init_real(m, n, nrhs) def call(self): M = self.M N = self.N NRHS = self.NRHS A = self.A LDA = self.LDA B = self.B_ LDB = self.LDB S = self.S RCOND = self.RCOND RANK = self.RANK WORK = self.WORK LWORK = self.LWORK RWORK = self.RWORK IWORK = self.IWORK rv = fint() args_real = (__ptr__(M), __ptr__(N), __ptr__(NRHS), A.as_byte(), __ptr__(LDA), B.as_byte(), __ptr__(LDB), S.as_byte(), RCOND.as_byte(), __ptr__(RANK), WORK.as_byte(), __ptr__(LWORK), IWORK, __ptr__(rv)) args_cplx = (__ptr__(M), __ptr__(N), __ptr__(NRHS), A.as_byte(), __ptr__(LDA), B.as_byte(), __ptr__(LDB), S.as_byte(), RCOND.as_byte(), __ptr__(RANK), WORK.as_byte(), __ptr__(LWORK), RWORK.as_byte(), IWORK, __ptr__(rv)) if T is float: dgelsd_(*args_real) elif T is float32: sgelsd_(*args_real) elif T is complex: zgelsd_(*args_cplx) elif T is complex64: cgelsd_(*args_cplx) else: compile_error("[internal error] bad dtype for gelsd") self.RANK = RANK return rv def release(self): free(self.A) free(self.WORK) def _abs2(p: Ptr[T], n: int, T: type): B = type(_basetype(T)) res = B() for i in range(n): el = p[i] if T is complex or T is complex64: res += el.real*el.real + el.imag*el.imag else: res += el * el return res def lstsq(a, b, rcond = None): dtype = type( coerce( type(asarray(a).data[0]), type(asarray(b).data[0]))) a = _asarray(a, dtype=dtype) b = _asarray(b, dtype=dtype) if a.ndim != 2: compile_error("lstsq argument 'a' must be 2-dimensional") if b.ndim == 1: b2 = b.reshape(b.shape[0], 1) elif b.ndim == 2: b2 = b else: compile_error("lstsq argument 'b' must be 1- or 2-dimensional") B = type(_basetype(a.dtype)) m, n = a.shape m2, nrhs = b2.shape if m != m2: raise LinAlgError('Incompatible dimensions') if rcond is None: rc = eps(B) * cast(max(m, n), B) else: rc = cast(rcond, B) if nrhs == 0: b2 = ndarray((m, 1), Ptr[b2.dtype](m)) str.memset(b2.data.as_byte(), byte(0), b2.nbytes) safe_nrhs = 1 else: safe_nrhs = nrhs excess = m - n params = GelsdParams[a.dtype, B](fint(m), fint(n), fint(safe_nrhs)) x = empty((n, safe_nrhs), dtype=a.dtype) r = empty(safe_nrhs, dtype=B) s = empty(min(m, n), dtype=B) a_in = LinearizeData(n, m, a.strides[-1], a.strides[-2]) b_in = LinearizeData(safe_nrhs, m, b2.strides[-1], b2.strides[-2], max(m, n)) x_out = LinearizeData(safe_nrhs, n, x.strides[-1], x.strides[-2], max(m, n)) r_out = LinearizeData(1, safe_nrhs, 1, r.strides[-1]) s_out = LinearizeData(1, min(m, n), 1, s.strides[-1]) a_ptr = a.data b_ptr = b2.data x_ptr = x.data r_ptr = r.data s_ptr = s.data a_in.linearize(params.A, a_ptr) b_in.linearize(params.B_, b_ptr) params.RCOND = __ptr__(rc) not_ok = params.call() error_occured = False if not not_ok: x_out.delinearize(x_ptr, params.B_) rank = int(params.RANK) s_out.delinearize(s_ptr, params.S) if excess >= 0 and int(params.RANK) == n: components = params.B_ + n for i in range(safe_nrhs): vector = components + i*m r_ptr[i] = _abs2(vector, excess) else: r_out.nan_matrix(r_ptr) else: x_out.nan_matrix(x_ptr) r_out.nan_matrix(r_ptr) rank = -1 s_out.nan_matrix(s_ptr) error_occured = True params.release() if error_occured: raise LinAlgError("SVD did not converge in Linear Least Squares") if m == 0: str.memset(x.data.as_byte(), byte(0), x.nbytes) if nrhs == 0: free(x.data) free(r.data) x = ndarray((x.shape[0], 0), Ptr[x.dtype]()) r = ndarray((0,), Ptr[r.dtype]()) if b.ndim == 1: x1 = x.reshape(x.shape[0]) else: x1 = x if r.size and (rank != n or m <= n): free(r.data) r = ndarray((0,), Ptr[r.dtype]()) return x1, r, rank, s ################### # Matrix Multiply # ################### def _dot_noblas(_ip1: Ptr[T1], is1: int, _ip2: Ptr[T2], is2: int, op: Ptr[T3], n: int, T1: type, T2: type, T3: type): ip1 = _ip1.as_byte() ip2 = _ip2.as_byte() ans = zero(T3) for i in range(n): e1 = Ptr[T1](ip1)[0] e2 = Ptr[T2](ip2)[0] ans += cast(e1, T3) * cast(e2, T3) ip1 += is1 ip2 += is2 op[0] = ans _CBLAS_INT_MAX: Static[int] = 0x7fffffff def _blas_stride(stride: int, itemsize: int): if stride > 0 and stride % itemsize == 0: stride //= itemsize if stride <= _CBLAS_INT_MAX: return stride return 0 def _dot(ip1: Ptr[T1], is1: int, ip2: Ptr[T2], is2: int, op: Ptr[T3], n: int, T1: type, T2: type, T3: type): ib1 = _blas_stride(is1, sizeof(T1)) ib2 = _blas_stride(is2, sizeof(T2)) args = (fint(n), ip1.as_byte(), fint(ib1), ip2.as_byte(), fint(ib2)) if ib1 == 0 or ib2 == 0: _dot_noblas(ip1, is1, ip2, is2, op, n) else: if T1 is float and T2 is float and T3 is float: op[0] = cblas_ddot(*args) elif T1 is float32 and T2 is float32 and T3 is float32: op[0] = cblas_sdot(*args) elif T1 is complex and T2 is complex and T3 is complex: cblas_zdotu_sub(*args, op) elif T1 is complex64 and T2 is complex64 and T3 is complex64: cblas_cdotu_sub(*args, op) else: _dot_noblas(ip1, is1, ip2, is2, op, n) BLAS_MAXSIZE: Static[int] = 2147483646 # 2^31 - 1 def _is_blasable2d(byte_stride1: int, byte_stride2: int, d1: int, d2: int, itemsize: int): unit_stride1 = byte_stride1 // itemsize if byte_stride2 != itemsize: return False if (byte_stride1 % itemsize == 0 and unit_stride1 >= d2 and unit_stride1 <= BLAS_MAXSIZE): return True return False def _gemv(ip1: Ptr[T], is1_m: int, is1_n: int, ip2: Ptr[T], is2_n: int, op: Ptr[T], op_m: int, m: int, n: int, T: type): itemsize = sizeof(T) if _is_blasable2d(is1_m, is1_n, m, n, itemsize): order = CBLAS_COL_MAJOR lda = is1_m // itemsize else: order = CBLAS_ROW_MAJOR lda = is1_n // itemsize one_dtype = cast(1, T) zero_dtype = cast(0, T) if T is float or T is float32: one = one_dtype zero = zero_dtype else: one = __ptr__(one_dtype).as_byte() zero = __ptr__(zero_dtype).as_byte() args = (order, CBLAS_TRANS, fint(n), fint(m), one, ip1.as_byte(), fint(lda), ip2.as_byte(), fint(is2_n // itemsize), zero, op.as_byte(), fint(op_m // itemsize)) if T is float: cblas_dgemv(*args) elif T is float32: cblas_sgemv(*args) elif T is complex: cblas_zgemv(*args) elif T is complex64: cblas_cgemv(*args) else: compile_error("[internal error] bad input type") def _matmul_matrixmatrix(ip1: Ptr[T], is1_m: int, is1_n: int, ip2: Ptr[T], is2_n: int, is2_p: int, op: Ptr[T], os_m: int, os_p: int, m: int, n: int, p: int, T: type): order = CBLAS_ROW_MAJOR itemsize = sizeof(T) ldc = os_m // itemsize if _is_blasable2d(is1_m, is1_n, m, n, itemsize): trans1 = CBLAS_NO_TRANS lda = is1_m // itemsize else: trans1 = CBLAS_TRANS lda = is1_n // itemsize if _is_blasable2d(is2_n, is2_p, n, p, itemsize): trans2 = CBLAS_NO_TRANS ldb = is2_n // itemsize else: trans2 = CBLAS_TRANS ldb = is2_p // itemsize one_dtype = cast(1, T) zero_dtype = cast(0, T) if T is float or T is float32: one = one_dtype zero = zero_dtype else: one = __ptr__(one_dtype).as_byte() zero = __ptr__(zero_dtype).as_byte() if (ip1 == ip2 and m == p and is1_m == is2_p and is1_n == is2_n and trans1 != trans2): ld = lda if trans1 == CBLAS_NO_TRANS else ldb args = (order, CBLAS_UPPER, trans1, fint(p), fint(n), one, ip1, fint(ld), zero, op, fint(ldc)) if T is float: cblas_dsyrk(*args) elif T is float32: cblas_ssyrk(*args) elif T is complex: cblas_zsyrk(*args) elif T is complex64: cblas_csyrk(*args) else: compile_error("[internal error] bad input type") # Copy the triangle for i in range(p): for j in range(i + 1, p): op[j * ldc + i] = op[i * ldc + j] else: args = (order, trans1, trans2, fint(m), fint(p), fint(n), one, ip1, fint(lda), ip2, fint(ldb), zero, op, fint(ldc)) if T is float: cblas_dgemm(*args) elif T is float32: cblas_sgemm(*args) elif T is complex: cblas_zgemm(*args) elif T is complex64: cblas_cgemm(*args) else: compile_error("[internal error] bad input type") def _matmul_inner_noblas(_ip1: Ptr[T1], is1_m: int, is1_n: int, _ip2: Ptr[T2], is2_n: int, is2_p: int, _op: Ptr[T3], os_m: int, os_p: int, dm: int, dn: int, dp: int, T1: type, T2: type, T3: type): ib1_n = is1_n * dn ib2_n = is2_n * dn ib2_p = is2_p * dp ob_p = os_p * dp ip1 = _ip1.as_byte() ip2 = _ip2.as_byte() op = _op.as_byte() if T3 is bool: for m in range(dm): for p in range(dp): ip1tmp = ip1 ip2tmp = ip2 Ptr[T3](op)[0] = False for n in range(dn): val1 = Ptr[T1](ip1tmp)[0] val2 = Ptr[T2](ip2tmp)[0] if val1 and val2: Ptr[T3](op)[0] = True break ip2tmp += is2_n ip1tmp += is1_n op += os_p ip2 += is2_p op -= ob_p ip2 -= ib2_p ip1 += is1_m op += os_m else: for m in range(dm): for p in range(dp): for n in range(dn): val1 = cast(Ptr[T1](ip1)[0], T3) val2 = cast(Ptr[T2](ip2)[0], T3) Ptr[T3](op)[0] += val1 * val2 ip2 += is2_n ip1 += is1_n ip1 -= ib1_n ip2 -= ib2_n op += os_p ip2 += is2_p op -= ob_p ip2 -= ib2_p ip1 += is1_m op += os_m def _matmul(A, B, C): # Caller ensures arrays are broadcasted and of correct shape. # A -> (m, n) # B -> (n, p) # C -> (m, p) dm = A.shape[-2] dn = A.shape[-1] dp = B.shape[-1] is1_m = A.strides[-2] is1_n = A.strides[-1] is2_n = B.strides[-2] is2_p = B.strides[-1] os_m = C.strides[-2] os_p = C.strides[-1] sz = sizeof(A.dtype) special_case = (dm == 1 or dn == 1 or dp == 1) any_zero_dim = (dm == 0 or dn == 0 or dp == 0) scalar_out = (dm == 1 and dp == 1) scalar_vec = (dn == 1 and (dp == 1 or dm == 1)) too_big_for_blas = (dm > BLAS_MAXSIZE or dn > BLAS_MAXSIZE or dp > BLAS_MAXSIZE) i1_c_blasable = _is_blasable2d(is1_m, is1_n, dm, dn, sz) i2_c_blasable = _is_blasable2d(is2_n, is2_p, dn, dp, sz) i1_f_blasable = _is_blasable2d(is1_n, is1_m, dn, dm, sz) i2_f_blasable = _is_blasable2d(is2_p, is2_n, dp, dn, sz) i1blasable = i1_c_blasable or i1_f_blasable i2blasable = i2_c_blasable or i2_f_blasable o_c_blasable = _is_blasable2d(os_m, os_p, dm, dp, sz) o_f_blasable = _is_blasable2d(os_p, os_m, dp, dm, sz) vector_matrix = (dm == 1 and i2blasable and _is_blasable2d(is1_n, sz, dn, 1, sz)) matrix_vector = (dp == 1 and i1blasable and _is_blasable2d(is2_n, sz, dn, 1, sz)) for idx in multirange(C.shape[:-2]): ip1 = A._ptr(idx + (0, 0), broadcast=(A.ndim > 2)) ip2 = B._ptr(idx + (0, 0), broadcast=(B.ndim > 2)) op = C._ptr(idx + (0, 0)) if not ((A.dtype is B.dtype and B.dtype is C.dtype) and (A.dtype is float or A.dtype is float32 or A.dtype is complex or A.dtype is complex64)): _matmul_inner_noblas(ip1, is1_m, is1_n, ip2, is2_n, is2_p, op, os_m, os_p, dm, dn, dp) else: if too_big_for_blas or any_zero_dim: _matmul_inner_noblas(ip1, is1_m, is1_n, ip2, is2_n, is2_p, op, os_m, os_p, dm, dn, dp) elif special_case: if scalar_out: _dot(ip1, is1_n, ip2, is2_n, op, dn) elif scalar_vec: _matmul_inner_noblas(ip1, is1_m, is1_n, ip2, is2_n, is2_p, op, os_m, os_p, dm, dn, dp) elif vector_matrix: _gemv(ip2, is2_p, is2_n, ip1, is1_n, op, os_p, dp, dn) elif matrix_vector: _gemv(ip1, is1_m, is1_n, ip2, is2_n, op, os_m, dm, dn) else: _matmul_inner_noblas(ip1, is1_m, is1_n, ip2, is2_n, is2_p, op, os_m, os_p, dm, dn, dp) else: if i1blasable and i2blasable and o_c_blasable: _matmul_matrixmatrix(ip1, is1_m, is1_n, ip2, is2_n, is2_p, op, os_m, os_p, dm, dn, dp) elif i1blasable and i2blasable and o_f_blasable: _matmul_matrixmatrix(ip2, is2_p, is2_n, ip1, is1_n, is1_m, op, os_p, os_m, dp, dn, dm) else: _matmul_inner_noblas(ip1, is1_m, is1_n, ip2, is2_n, is2_p, op, os_m, os_p, dm, dn, dp) def _check_out(out, ans_shape, dtype: type): if out.dtype is not dtype: compile_error("'out' array has incorrect type '" + out.dtype.__name__ + "'") if staticlen(out.shape) != staticlen(ans_shape): compile_error("'out' array has incorrect number of dimensions") if out.shape != ans_shape: raise ValueError("'out' array has incorrect shape") def _matmul_ufunc(x1, x2, out = None, dtype: type = NoneType): if not isinstance(x1, ndarray) or not isinstance(x2, ndarray): y1 = asarray(x1) y2 = asarray(x2) return _matmul_ufunc(y1, y2, out=out, dtype=dtype) T1 = x1.dtype T2 = x2.dtype x1d: Static[int] = x1.ndim x2d: Static[int] = x2.ndim if dtype is NoneType: return _matmul_ufunc(x1, x2, out=out, dtype=type(coerce(T1, T2))) if x1d == 0: compile_error("first argument is 0-dimensional; must be at least 1-d") if x2d == 0: compile_error("second argument is 0-dimensional; must be at least 1-d") if x1d == 1 and x2d == 1: if x1.shape != x2.shape: raise ValueError("matmul: vectors have different lengths") op = zero(dtype) _dot(x1.data, x1.strides[0], x2.data, x2.strides[0], __ptr__(op), x1.size) return op if x1d == 1: y1 = x1.reshape((1, -1)) else: y1 = x1 if x2d == 1: y2 = x2.reshape((-1, 1)) else: y2 = x2 y1s = y1.shape y2s = y2.shape y1d: Static[int] = y1.ndim y2d: Static[int] = y2.ndim base1s = y1s[:-2] base2s = y2s[:-2] mat1s = y1s[-2:] mat2s = y2s[-2:] m = mat1s[0] k = mat1s[1] n = mat2s[1] if k != mat2s[0]: raise ValueError("matmul: last dimension of first argument does not " "match second-to-last dimension of second argument") ans_base = broadcast_shapes(base1s, base2s) if x1d == 1 and x2d == 1: ans_shape = ans_base elif x1d == 1: ans_shape = ans_base + (mat2s[1],) elif x2d == 1: ans_shape = ans_base + (mat1s[0],) else: ans_shape = ans_base + (mat1s[0], mat2s[1]) if out is None: ans = zeros(ans_shape, dtype=dtype) elif isinstance(out, ndarray): _check_out(out, ans_shape, dtype) ans = out ans.map(lambda _: cast(0, ans.dtype), inplace=True) else: compile_error("'out' must be an ndarray") if x1d == 1 and x2d == 1: _matmul(y1, y2, ans.reshape(1, 1)) elif x1d == 1: _matmul(y1.reshape((1,) * (x2d - 1) + y1.shape), y2, ans.reshape(ans.shape + (1,))) elif x2d == 1: _matmul(y1, y2.reshape((1,) * (x1d - 1) + y2.shape), ans.reshape(ans.shape + (1,))) elif x1d > x2d: _matmul(y1, y2.reshape((1,) * (x1d - x2d) + y2.shape), ans) elif x1d < x2d: _matmul(y1.reshape((1,) * (x2d - x1d) + y1.shape), y2, ans) else: _matmul(y1, y2, ans) if out is not None: return out else: if ans.ndim == 0: return ans.data[0] else: return ans def matmul(x1, x2): return _matmul_ufunc(x1, x2) def dot(a, b, out = None): x1 = asarray(a) x2 = asarray(b) T1 = x1.dtype T2 = x2.dtype x1d: Static[int] = staticlen(x1.shape) x2d: Static[int] = staticlen(x2.shape) if x1d == 0 or x2d == 0: return multiply(a, b, out=out) if x1d <= 2 and x2d <= 2: return _matmul_ufunc(a, b, out=out) # most general case x1s = x1.shape x2s = x2.shape m = x1s[-2] k = x1s[-1] n = x2s[-1] if k != x2s[-2]: raise ValueError("dot: last dimension of first argument does not " "match second-to-last dimension of second argument") dtype = type(coerce(T1, T2)) ans_shape = x1s[:-1] + x2s[:-2] + (x2s[-1],) if out is None: ans = zeros(ans_shape, dtype=dtype) elif isinstance(out, ndarray): _check_out(out, ans_shape, dtype) ans = out else: compile_error("'out' must be an ndarray") for idx in multirange(ans_shape): q = ans._ptr(idx) for i1 in range(k): x1_idx = idx[:x1d-1] + (i1,) x2_idx = idx[x1d-1:-1] + (i1, idx[-1]) p1 = x1._ptr(x1_idx) p2 = x2._ptr(x2_idx) q[0] += cast(p1[0], dtype) * cast(p2[0], dtype) return ans ################## # Other Routines # ################## def _matrix_power(a, n: int): a = asarray(a) dtype = a.dtype s = a.shape m = _square_rows(a) if n == 0: if staticlen(s) == 2: return eye(m, dtype=dtype) else: b = empty_like(a) for idx in multirange(s[:-2]): p = b._ptr(idx + (0, 0)) for i in range(m): for j in range(m): val = cast(1, dtype) if i == j else zero(dtype) (p + (i * m + j))[0] = val return b elif n < 0: if type(inv(a)) is type(a): a = inv(a) n = abs(n) else: raise ValueError( "cannot take integral matrix to non-constant negative " "power; use 'matrix_power(inv(a), abs(n))' instead" ) if n == 1: return a elif n == 2: return matmul(a, a) elif n == 3: return matmul(matmul(a, a), a) result = a z = a first = True have_result = False while n > 0: if first: first = False else: z = matmul(z, z) first = False n, bit = divmod(n, 2) if bit: if have_result: result = matmul(result, z) else: result = z have_result = True return result def matrix_power(a, n: int): return _matrix_power(a, n) @overload def matrix_power(a, n: Static[int]): nonstatic = lambda x: x # hack to get non-static argument if n >= 0: return _matrix_power(a, n) else: return _matrix_power(inv(a), -n) def _multi_dot_three(A, B, C, out=None): a0, a1b0 = A.shape b1c0, c1 = C.shape # cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1 cost1 = a0 * b1c0 * (a1b0 + c1) # cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1 cost2 = a1b0 * c1 * (a0 + b1c0) if cost1 < cost2: return dot(dot(A, B), C, out=out) else: return dot(A, dot(B, C), out=out) def _multi_dot(arrays, order, i, j, out=None): if i == j: #assert out is None return arrays[i] else: return dot(_multi_dot(arrays, order, i, order[i, j]), _multi_dot(arrays, order, order[i, j] + 1, j), out=out) def _multi_dot_matrix_chain_order(arrays, return_costs: Static[int] = False): n = len(arrays) p = [a.shape[0] for a in arrays] + [arrays[-1].shape[1]] m = zeros((n, n), dtype=float) s = empty((n, n), dtype=int) for l in range(1, n): for i in range(n - l): j = i + l m[i, j] = inf(float) for k in range(i, j): q = m[i, k] + m[k+1, j] + p[i]*p[k+1]*p[j+1] if q < m[i, j]: m[i, j] = q s[i, j] = k # Note that Cormen uses 1-based index if return_costs: return (s, m) else: return s def multi_dot(arrays, out = None): n = len(arrays) if isinstance(arrays, Tuple): if staticlen(arrays) < 2: compile_error("Expecting at least two arrays.") if staticlen(arrays) == 2: return dot(arrays[0], arrays[1], out=out) else: if n < 2: raise ValueError("Expecting at least two arrays.") if n == 2: return dot(arrays[0], arrays[1], out=out) ndim_first: Static[int] = staticlen(arrays[0].shape) ndim_last: Static[int] = staticlen(arrays[-1].shape) def fix_first(arr): arr = asarray(arr) if staticlen(arr.shape) == 1: return atleast_2d(arr) else: return arr def fix_last(arr): arr = asarray(arr) if staticlen(arr.shape) == 1: return atleast_2d(arr).T else: return arr if isinstance(arrays, Tuple): xarrays = (fix_first(arrays[0]),) + tuple(asarray(arr) for arr in arrays[1:-1]) + (fix_last(arrays[-1]),) else: xarrays = [asarray(arr) for arr in arrays] xarrays[0] = fix_first(xarrays[0]) xarrays[-1] = fix_last(xarrays[-1]) for arr in xarrays: if staticlen(arr.shape) != 2: compile_error("1-dimensional array given. Array must be two-dimensional") if n == 3: result = _multi_dot_three(xarrays[0], xarrays[1], xarrays[2], out=out) else: order = _multi_dot_matrix_chain_order(xarrays) result = _multi_dot(xarrays, order, 0, n - 1, out=out) if ndim_first == 1 and ndim_last == 1: return result[0, 0] # scalar elif ndim_first == 1 or ndim_last == 1: return result.ravel() # 1-D else: return result def _vdot_ptr(x: Ptr[T1], y: Ptr[T2], n: int, incx: int, incy: int, T1: type, T2: type): if T1 is float and T2 is float: return cblas_ddot(fint(n), x.as_byte(), fint(incx), y.as_byte(), fint(incy)) elif T1 is float32 and T2 is float32: return cblas_sdot(fint(n), x.as_byte(), fint(incx), y.as_byte(), fint(incy)) elif T1 is complex and T2 is complex: z = complex() cblas_zdotc_sub(fint(n), x.as_byte(), fint(incx), y.as_byte(), fint(incy), __ptr__(z).as_byte()) return z elif T1 is complex64 and T2 is complex64: z = complex64() cblas_cdotc_sub(fint(n), x.as_byte(), fint(incx), y.as_byte(), fint(incy), __ptr__(z).as_byte()) return z else: TR = type(coerce(T1, T2)) z = TR() for _ in range(n): a = x[0] b = y[0] if hasattr(a, "conjugate"): a = a.conjugate() z += cast(a, TR) * cast(b, TR) x += incx y += incy return z def vdot(a, b): def mismatch(): raise ValueError("vdot: inputs have different sizes") a = asarray(a) b = asarray(b) if staticlen(a.shape) == 0 and staticlen(b.shape) == 0: x = a.data[0] y = b.data[0] TR = coerce(type(x), type(y)) if hasattr(x, "conjugate"): x = x.conjugate() return cast(x, TR) * cast(y, TR) n = a.size if n != b.size: mismatch() if staticlen(a.shape) == 1 and staticlen(b.shape) == 1: inca = a.strides[0] // a.itemsize incb = b.strides[0] // b.itemsize return _vdot_ptr(a.data, b.data, n, inca, incb) if a._contig_match(b): return _vdot_ptr(a.data, b.data, n, 1, 1) T1 = a.dtype T2 = b.dtype TR = type(coerce(T1, T2)) z = TR() for e1, e2 in zip(a.flat, b.flat): if hasattr(e1, "conjugate"): e1 = e1.conjugate() z += cast(e1, TR) * cast(e2, TR) return z def outer(a, b, out = None): a = asarray(a) b = asarray(b) return multiply(a.ravel()[:, None], b.ravel()[None, :], out) def inner(a, b): a = asarray(a) b = asarray(b) shape_a = a.shape shape_b = b.shape if staticlen(shape_a) == 0 or staticlen(shape_b) == 0: return multiply(a, b) if staticlen(shape_a) == 1 and staticlen(shape_b) == 1: return dot(a, b) n = shape_a[-1] if n != shape_b[-1]: raise ValueError("inner: mismatch in last dimension of inputs") shape_cut_a = shape_a[:-1] shape_cut_b = shape_b[:-1] out_shape = shape_cut_a + shape_cut_b dtype = type(coerce(a.dtype, b.dtype)) out = empty(out_shape, dtype=dtype) for idx1 in multirange(shape_cut_a): for idx2 in multirange(shape_cut_b): q = out._ptr(idx1 + idx2) q[0] = zero(dtype) for r in range(n): xa = a._ptr(idx1 + (r,))[0] xb = b._ptr(idx2 + (r,))[0] xc = cast(xa, dtype) * cast(xb, dtype) q[0] += xc return out def _tensordot(a, b, axes, ndim: Static[int] = -1): def get_axes(axes): if isinstance(axes, int): axes_a = list(range(-axes, 0)) axes_b = list(range(0, axes)) na = len(axes_a) nb = len(axes_b) return axes_a, axes_b axes_a, axes_b = axes if isinstance(axes_a, int): xa = [axes_a] else: xa = list(axes_a) if isinstance(axes_b, int): xb = [axes_b] else: xb = list(axes_b) return xa, xb axes_a, axes_b = get_axes(axes) na = len(axes_a) nb = len(axes_b) as_ = List[int](a.shape) nda = a.ndim bs = List[int](b.shape) ndb = b.ndim equal = True if na != nb: equal = False else: for k in range(na): if as_[axes_a[k]] != bs[axes_b[k]]: equal = False break if axes_a[k] < 0: axes_a[k] += nda if axes_b[k] < 0: axes_b[k] += ndb if not equal: raise ValueError("shape-mismatch for sum") notin = [k for k in range(nda) if k not in axes_a] newaxes_a = notin + axes_a N2 = 1 for axis in axes_a: N2 *= as_[axis] M2 = 1 for ax in notin: M2 *= as_[ax] newshape_a = (M2, N2) if ndim >= 0: olda = [as_[axis] for axis in notin] else: olda = None notin = [k for k in range(ndb) if k not in axes_b] newaxes_b = axes_b + notin N2 = 1 for axis in axes_b: N2 *= bs[axis] M2 = 1 for ax in notin: M2 *= bs[ax] newshape_b = (N2, M2) if ndim >= 0: oldb = [bs[axis] for axis in notin] else: oldb = None at = a.transpose(newaxes_a).reshape(newshape_a) bt = b.transpose(newaxes_b).reshape(newshape_b) res = dot(at, bt) # NOTE: 'olda + oldb' length is not known at compile-time, # so cannot reshape unless 'ndim' is given. if ndim >= 0: newshape = (0,) * ndim pnewshape = Ptr[int](__ptr__(newshape).as_byte()) i = 0 for j in olda: pnewshape[i] = j i += 1 for j in oldb: pnewshape[i] = j i += 1 return res.reshape(newshape) else: return res def tensordot(a, b, axes): a = asarray(a) b = asarray(b) return _tensordot(a, b, axes) @overload def tensordot(a, b, axes: Static[int] = 2): if axes < 0: return tensordot(a, b, axes=0) a = asarray(a) b = asarray(b) if axes > a.ndim + 1 or axes > b.ndim + 1: compile_error("'axes' too large for given arrays") return _tensordot(a, b, axes=axes, ndim=(a.ndim + b.ndim - 2*axes)) def kron(a, b): b = asarray(b) ndb: Static[int] = staticlen(b.shape) a = array(a, copy=False, ndmin=ndb) nda: Static[int] = staticlen(a.shape) nd: Static[int] = ndb if ndb >= nda else nda if nda == 0 or ndb == 0: return multiply(a, b) as_ = a.shape bs = b.shape if not a.flags.contiguous: a = reshape(a, as_) if not b.flags.contiguous: b = reshape(b, bs) as_ = (1,)*(ndb-nda if ndb-nda >= 0 else 0) + as_ bs = (1,)*(nda-ndb if nda-ndb >= 0 else 0) + bs a_arr = expand_dims(a, axis=tuple(i for i in staticrange(ndb-nda))) b_arr = expand_dims(b, axis=tuple(i for i in staticrange(nda-ndb))) a_arr = expand_dims(a_arr, axis=tuple(i for i in staticrange(1, nd*2, 2))) b_arr = expand_dims(b_arr, axis=tuple(i for i in staticrange(0, nd*2, 2))) result = multiply(a_arr, b_arr) res_shape = tuple(as_[i] * bs[i] for i in staticrange(staticlen(as_))) result = result.reshape(res_shape) return result def _trace_ufunc(a, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: type = NoneType, out = None): a = asarray(a) if dtype is NoneType: return _trace_ufunc(a, offset=offset, axis1=axis1, axis2=axis2, dtype=a.dtype, out=out) m, n = _rows_cols(a) s = a.shape ndim = a.ndim axis1 = normalize_axis_index(axis1, a.ndim, "axis1") axis2 = normalize_axis_index(axis2, a.ndim, "axis2") if axis1 == axis2: raise ValueError("axis1 and axis2 cannot be the same") if staticlen(s) == 2: i = 0 t = zero(dtype) while ((offset >= 0 and (i < m and i + offset < n)) or (offset < 0 and (i - offset < m and i < n))): e = a._ptr((i, i + offset))[0] if offset >= 0 else a._ptr((i - offset, i))[0] t += cast(e, dtype) i += 1 if out is None: return t else: if staticlen(out.shape) != 0: compile_error("expected 0-dimensional output parameter") if out.dtype is not dtype: compile_error("output parameter has the wrong dtype") out.data[0] = t return out else: if axis1 > axis2: axis1, axis2 = axis2, axis2 ans_shape = tuple_delete(tuple_delete(s, axis2), axis1) m, n = s[axis1], s[axis2] if out is None: ans = empty(ans_shape, dtype=dtype) else: if staticlen(out.shape) != staticlen(ans_shape): compile_error("output parameter has the wrong number of dimensions") if out.dtype is not dtype: compile_error("output parameter has the wrong dtype") if out.shape != ans_shape: raise ValueError("output parameter has the wrong shape") ans = out for idx0 in multirange(ans_shape): i = 0 t = zero(dtype) while ((offset >= 0 and (i < m and i + offset < n)) or (offset < 0 and (i - offset < m and i < n))): idx = (tuple_insert( tuple_insert(idx0, axis1, i), axis2, i + offset) if offset >= 0 else tuple_insert( tuple_insert(idx0, axis1, i - offset), axis2, i)) e = a._ptr(idx)[0] t += cast(e, dtype) i += 1 q = ans._ptr(idx0) q[0] = t return ans def trace(x, offset: int = 0, dtype: type = NoneType): return _trace_ufunc(x, offset=offset, dtype=dtype) def tensorsolve(a, b, axes = None): a = asarray(a) b = asarray(b) an: Static[int] = staticlen(a.shape) bn: Static[int] = staticlen(b.shape) if axes is not None: allaxes = list(range(0, an)) for k in axes: allaxes.remove(k) allaxes.insert(an, k) a = a.transpose(allaxes) oldshape = a.shape[-(an-bn):] prod = 1 for k in oldshape: prod *= k if a.size != prod ** 2: raise LinAlgError( "Input arrays must satisfy the requirement" "prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])" ) a = a.reshape(prod, prod) b = b.ravel() res = solve(a, b) return res.reshape(oldshape) def tensorinv(a, ind: int = 2): a = asarray(a) oldshape = a.shape prod = 1 invshape = oldshape if ind > 0: # invshape = oldshape[ind:] + oldshape[:ind] poldshape = Ptr[int](__ptr__(oldshape).as_byte()) pinvshape = Ptr[int](__ptr__(invshape).as_byte()) i = ind j = 0 while i < len(oldshape): pinvshape[j] = poldshape[i] prod *= poldshape[i] i += 1 j += 1 i = 0 while i < min(ind, len(oldshape)): pinvshape[j] = poldshape[i] i += 1 j += 1 else: raise ValueError("Invalid ind argument.") a = a.reshape(prod, -1) ia = inv(a) return ia.reshape(*invshape) def _square(elem: T, T: type): if T is complex or T is complex64: return (elem.real * elem.real) + (elem.imag * elem.imag) else: return elem * elem def _norm_cast(elem: T, T: type): if T is float or T is float32 or T is complex or T is complex64: return elem else: return cast(elem, float) def _vector_norm_f(x: ndarray, axis: int, R: type): if staticlen(x.shape) == 1: dtype = x.dtype sh = x.shape[0] st = x.strides[0] data = x.data curr = R() for i in range(sh): elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0]) curr += _square(elem) return util_sqrt(curr) else: s = x.shape n = s[axis] ans_shape = tuple_delete(x.shape, axis) ans = empty(ans_shape, R) for idx in multirange(ans_shape): q = ans._ptr(idx) curr = R() for i in range(n): idx1 = tuple_insert(idx, axis, i) elem = _norm_cast(x._ptr(idx1)[0]) curr += _square(elem) q[0] = util_sqrt(curr) return ans def _norm_error_zero_max(): raise ValueError("zero-size array to reduction operation " "maximum which has no identity") def _norm_error_zero_min(): raise ValueError("zero-size array to reduction operation " "minimum which has no identity") def _vector_norm_inf(x: ndarray, axis: int, R: type): sh = x.shape[0] if sh == 0: _norm_error_zero_max() if staticlen(x.shape) == 1: dtype = x.dtype sh = x.shape[0] st = x.strides[0] data = x.data curr = abs(_norm_cast(data[0])) for i in range(1, sh): elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0]) e = abs(elem) if e > curr: curr = e return curr else: s = x.shape n = s[axis] ans_shape = tuple_delete(x.shape, axis) ans = empty(ans_shape, R) for idx in multirange(ans_shape): q = ans._ptr(idx) idx1 = tuple_insert(idx, axis, 0) curr = abs(_norm_cast(x._ptr(idx1)[0])) for i in range(1, n): idx1 = tuple_insert(idx, axis, i) elem = _norm_cast(x._ptr(idx1)[0]) e = abs(elem) if e > curr: curr = e q[0] = curr return ans def _vector_norm_ninf(x: ndarray, axis: int, R: type): sh = x.shape[0] if sh == 0: _norm_error_zero_min() if staticlen(x.shape) == 1: dtype = x.dtype st = x.strides[0] data = x.data curr = abs(_norm_cast(data[0])) for i in range(1, sh): elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0]) e = abs(elem) if e < curr: curr = e return curr else: s = x.shape n = s[axis] ans_shape = tuple_delete(x.shape, axis) ans = empty(ans_shape, R) for idx in multirange(ans_shape): q = ans._ptr(idx) idx1 = tuple_insert(idx, axis, 0) curr = abs(_norm_cast(x._ptr(idx1)[0])) for i in range(1, n): idx1 = tuple_insert(idx, axis, i) elem = _norm_cast(x._ptr(idx1)[0]) e = abs(elem) if e < curr: curr = e q[0] = curr return ans def _vector_norm_0(x: ndarray, axis: int, R: type): if staticlen(x.shape) == 1: dtype = x.dtype sh = x.shape[0] st = x.strides[0] data = x.data curr = R() for i in range(sh): elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0]) if elem: curr += R(1) return curr else: s = x.shape n = s[axis] ans_shape = tuple_delete(x.shape, axis) ans = empty(ans_shape, R) for idx in multirange(ans_shape): q = ans._ptr(idx) idx1 = tuple_insert(idx, axis, 0) curr = R() for i in range(n): idx1 = tuple_insert(idx, axis, i) elem = _norm_cast(x._ptr(idx1)[0]) if elem: curr += R(1) q[0] = curr return ans def _vector_norm_1(x: ndarray, axis: int, R: type): if staticlen(x.shape) == 1: dtype = x.dtype sh = x.shape[0] st = x.strides[0] data = x.data curr = R() for i in range(sh): elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0]) curr += abs(elem) return curr else: s = x.shape n = s[axis] ans_shape = tuple_delete(x.shape, axis) ans = empty(ans_shape, R) for idx in multirange(ans_shape): q = ans._ptr(idx) idx1 = tuple_insert(idx, axis, 0) curr = R() for i in range(n): idx1 = tuple_insert(idx, axis, i) elem = _norm_cast(x._ptr(idx1)[0]) curr += abs(elem) q[0] = curr return ans def _vector_norm_g(x: ndarray, axis: int, g, R: type): if staticlen(x.shape) == 1: dtype = x.dtype sh = x.shape[0] st = x.strides[0] data = x.data curr = R() for i in range(sh): elem = _norm_cast((Ptr[dtype](data.as_byte() + (st * i)))[0]) curr += abs(elem) ** R(g) return curr ** (R(1) / R(g)) else: s = x.shape n = s[axis] ans_shape = tuple_delete(x.shape, axis) ans = empty(ans_shape, R) for idx in multirange(ans_shape): q = ans._ptr(idx) idx1 = tuple_insert(idx, axis, 0) curr = R() for i in range(n): idx1 = tuple_insert(idx, axis, i) elem = _norm_cast(x._ptr(idx1)[0]) curr += abs(elem) ** R(g) q[0] = curr ** (R(1) / R(g)) return ans def _matrix_norm_f(x: ndarray, row_axis: int, col_axis: int, R: type): s = x.shape if staticlen(s) == 2: dtype = x.dtype sh1, sh2 = s data = x.data curr = R() for i in range(sh1): for j in range(sh2): elem = _norm_cast(x._ptr((i, j))[0]) curr += _square(elem) return util_sqrt(curr) else: axis1, axis2 = row_axis, col_axis if axis1 > axis2: axis1, axis2 = axis2, axis1 ans_shape = tuple_delete(tuple_delete(s, axis2), axis1) ans = empty(ans_shape, R) sh1 = s[axis1] sh2 = s[axis2] for idx in multirange(ans_shape): q = ans._ptr(idx) curr = R() for i in range(sh1): for j in range(sh2): idx1 = tuple_insert( tuple_insert(idx, axis1, i), axis2, j) elem = _norm_cast(x._ptr(idx1)[0]) curr += _square(elem) q[0] = util_sqrt(curr) return ans def _matrix_norm_1(x: ndarray, row_axis: int, col_axis: int, R: type): s = x.shape if s[col_axis] == 0: _norm_error_zero_max() if staticlen(s) == 2: dtype = x.dtype row_dim = s[row_axis] col_dim = s[col_axis] row_stride = x.strides[row_axis] col_stride = x.strides[col_axis] data = x.data curr = R() first = True for col_idx in range(col_dim): sub = R() for row_idx in range(row_dim): elem = _norm_cast((Ptr[dtype](data.as_byte() + (row_stride * row_idx) + (col_stride * col_idx)))[0]) sub += abs(elem) if first or sub > curr: curr = sub first = False return curr else: axis1, axis2 = row_axis, col_axis if axis1 > axis2: axis1, axis2 = axis2, axis1 ans_shape = tuple_delete(tuple_delete(s, axis2), axis1) ans = empty(ans_shape, R) row_dim = s[row_axis] col_dim = s[col_axis] for idx in multirange(ans_shape): q = ans._ptr(idx) curr = R() first = True for col_idx in range(col_dim): sub = R() for row_idx in range(row_dim): i, j = row_idx, col_idx if row_axis > col_axis: i, j = j, i idx1 = tuple_insert( tuple_insert(idx, axis1, i), axis2, j) elem = _norm_cast(x._ptr(idx1)[0]) sub += abs(elem) if first or sub > curr: curr = sub first = False q[0] = curr return ans def _matrix_norm_inf(x: ndarray, row_axis: int, col_axis: int, R: type): s = x.shape if s[row_axis] == 0: _norm_error_zero_max() if staticlen(s) == 2: dtype = x.dtype row_dim = s[row_axis] col_dim = s[col_axis] row_stride = x.strides[row_axis] col_stride = x.strides[col_axis] data = x.data curr = R() first = True for row_idx in range(row_dim): sub = R() for col_idx in range(col_dim): elem = _norm_cast((Ptr[dtype](data.as_byte() + (row_stride * row_idx) + (col_stride * col_idx)))[0]) sub += abs(elem) if first or sub > curr: curr = sub first = False return curr else: axis1, axis2 = row_axis, col_axis if axis1 > axis2: axis1, axis2 = axis2, axis1 ans_shape = tuple_delete(tuple_delete(s, axis2), axis1) ans = empty(ans_shape, R) row_dim = s[row_axis] col_dim = s[col_axis] for idx in multirange(ans_shape): q = ans._ptr(idx) curr = R() first = True for row_idx in range(row_dim): sub = R() for col_idx in range(col_dim): i, j = row_idx, col_idx if row_axis > col_axis: i, j = j, i idx1 = tuple_insert( tuple_insert(idx, axis1, i), axis2, j) elem = _norm_cast(x._ptr(idx1)[0]) sub += abs(elem) if first or sub > curr: curr = sub first = False q[0] = curr return ans def _matrix_norm_n1(x: ndarray, row_axis: int, col_axis: int, R: type): s = x.shape if s[col_axis] == 0: _norm_error_zero_min() if staticlen(s) == 2: dtype = x.dtype row_dim = s[row_axis] col_dim = s[col_axis] row_stride = x.strides[row_axis] col_stride = x.strides[col_axis] data = x.data curr = R() first = True for col_idx in range(col_dim): sub = R() for row_idx in range(row_dim): elem = _norm_cast((Ptr[dtype](data.as_byte() + (row_stride * row_idx) + (col_stride * col_idx)))[0]) sub += abs(elem) if first or sub < curr: curr = sub first = False return curr else: axis1, axis2 = row_axis, col_axis if axis1 > axis2: axis1, axis2 = axis2, axis1 ans_shape = tuple_delete(tuple_delete(s, axis2), axis1) ans = empty(ans_shape, R) row_dim = s[row_axis] col_dim = s[col_axis] for idx in multirange(ans_shape): q = ans._ptr(idx) curr = R() first = True for col_idx in range(col_dim): sub = R() for row_idx in range(row_dim): i, j = row_idx, col_idx if row_axis > col_axis: i, j = j, i idx1 = tuple_insert( tuple_insert(idx, axis1, i), axis2, j) elem = _norm_cast(x._ptr(idx1)[0]) sub += abs(elem) if first or sub < curr: curr = sub first = False q[0] = curr return ans def _matrix_norm_ninf(x: ndarray, row_axis: int, col_axis: int, R: type): s = x.shape if s[row_axis] == 0: _norm_error_zero_min() if staticlen(s) == 2: dtype = x.dtype row_dim = s[row_axis] col_dim = s[col_axis] row_stride = x.strides[row_axis] col_stride = x.strides[col_axis] data = x.data curr = R() first = True for row_idx in range(row_dim): sub = R() for col_idx in range(col_dim): elem = _norm_cast((Ptr[dtype](data.as_byte() + (row_stride * row_idx) + (col_stride * col_idx)))[0]) sub += abs(elem) if first or sub < curr: curr = sub first = False return curr else: axis1, axis2 = row_axis, col_axis if axis1 > axis2: axis1, axis2 = axis2, axis1 ans_shape = tuple_delete(tuple_delete(s, axis2), axis1) ans = empty(ans_shape, R) row_dim = s[row_axis] col_dim = s[col_axis] for idx in multirange(ans_shape): q = ans._ptr(idx) curr = R() first = True for row_idx in range(row_dim): sub = R() for col_idx in range(col_dim): i, j = row_idx, col_idx if row_axis > col_axis: i, j = j, i idx1 = tuple_insert( tuple_insert(idx, axis1, i), axis2, j) elem = _norm_cast(x._ptr(idx1)[0]) sub += abs(elem) if first or sub < curr: curr = sub first = False q[0] = curr return ans def _multi_svd_norm(x: ndarray, row_axis: int, col_axis: int, op): y = moveaxis(x, (row_axis, col_axis), (-2, -1)) return op(svd(y, compute_uv=False), axis=-1) def _norm_wrap(x, kd_shape, keepdims: Static[int]): if keepdims: return asarray(x).reshape(kd_shape) else: return x def norm(x, ord = None, axis = None, keepdims: Static[int] = False): x = asarray(x) if not (x.dtype is float or x.dtype is float32 or x.dtype is float16): return norm(x.astype(float), ord=ord, axis=axis, keepdims=keepdims) ndim: Static[int] = staticlen(x.shape) dtype = x.dtype if dtype is complex or dtype is float: r = float() elif dtype is complex64 or dtype is float32: r = float32() else: r = float() R = type(r) # Common cases if axis is None: handle = False if ord is None: handle = True elif isinstance(ord, str): handle = ((ord == 'f' or ord == 'fro') and ndim == 2) elif isinstance(ord, int): handle = (ord == 2 and ndim == 1) if handle: y = x.ravel(order='K') if dtype is complex or dtype is complex64: x_real = y.real x_imag = y.imag sqnorm = x_real.dot(x_real) + x_imag.dot(x_imag) else: sqnorm = y.dot(y) ret = sqrt(sqnorm) if keepdims: return asarray(ret).reshape((1,) * ndim) else: return ret if axis is None: ax = tuple_range(ndim) elif isinstance(axis, int): ax = (axis,) elif isinstance(axis, Tuple): ax = axis else: compile_error("'axis' must be None, an integer or a tuple of integers") if ndim == 0: # this matches NumPy behavior if ord is None: return abs(x.data[0]) else: compile_error("Improper number of dimensions to norm.") ax = tuple(normalize_axis_index(a, ndim) for a in ax) kd_shape = x.shape if keepdims: for a in ax: kd_shape = tuple_set(kd_shape, a, 1) if staticlen(ax) == 1: if ord is None: return _norm_wrap(_vector_norm_f(x, ax[0], R), kd_shape, keepdims) elif isinstance(ord, int): if ord == 0: return _norm_wrap(_vector_norm_0(x, ax[0], R), kd_shape, keepdims) elif ord == 1: return _norm_wrap(_vector_norm_1(x, ax[0], R), kd_shape, keepdims) elif ord == 2: return _norm_wrap(_vector_norm_f(x, ax[0], R), kd_shape, keepdims) else: return _norm_wrap(_vector_norm_g(x, ax[0], float(ord), R), kd_shape, keepdims) elif isinstance(ord, float): if ord == 0.: return _norm_wrap(_vector_norm_0(x, ax[0], R), kd_shape, keepdims) elif ord == 1.: return _norm_wrap(_vector_norm_1(x, ax[0], R), kd_shape, keepdims) elif ord == 2.: return _norm_wrap(_vector_norm_f(x, ax[0], R), kd_shape, keepdims) elif ord == inf(float): return _norm_wrap(_vector_norm_inf(x, ax[0], R), kd_shape, keepdims) elif ord == -inf(float): return _norm_wrap(_vector_norm_ninf(x, ax[0], R), kd_shape, keepdims) else: return _norm_wrap(_vector_norm_g(x, ax[0], ord, R), kd_shape, keepdims) else: compile_error("Invalid norm order for vectors") elif staticlen(ax) == 2: row_axis, col_axis = ax if row_axis == col_axis: raise ValueError("Duplicate axes given.") if ord is None: return _norm_wrap(_matrix_norm_f(x, ax[0], ax[1], R), kd_shape, keepdims) elif isinstance(ord, int): if ord == 2: return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.max), kd_shape, keepdims) elif ord == -2: return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.min), kd_shape, keepdims) elif ord == 1: return _norm_wrap(_matrix_norm_1(x, ax[0], ax[1], R), kd_shape, keepdims) elif ord == -1: return _norm_wrap(_matrix_norm_n1(x, ax[0], ax[1], R), kd_shape, keepdims) elif isinstance(ord, float): if ord == 2.: return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.max), kd_shape, keepdims) elif ord == -2.: return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.min), kd_shape, keepdims) elif ord == 1.: return _norm_wrap(_matrix_norm_1(x, ax[0], ax[1], R), kd_shape, keepdims) elif ord == -1.: return _norm_wrap(_matrix_norm_n1(x, ax[0], ax[1], R), kd_shape, keepdims) elif ord == inf(float): return _norm_wrap(_matrix_norm_inf(x, ax[0], ax[1], R), kd_shape, keepdims) elif ord == -inf(float): return _norm_wrap(_matrix_norm_ninf(x, ax[0], ax[1], R), kd_shape, keepdims) elif isinstance(ord, str): if ord == 'fro' or ord == 'f': return _norm_wrap(_matrix_norm_f(x, ax[0], ax[1], R), kd_shape, keepdims) elif ord == 'nuc': return _norm_wrap(_multi_svd_norm(x, ax[0], ax[1], ndarray.sum), kd_shape, keepdims) raise ValueError("Invalid norm order for matrices.") else: raise ValueError("Improper number of dimensions to norm.") def pinv(a, rcond = 1e-15, hermitian: bool = False): a = asarray(a) rcond = asarray(rcond) m, n = _rows_cols(a) if m == 0 or n == 0: ans_shape = a.shape[:-2] + (n, m) dtype = a.dtype if (dtype is float or dtype is float32 or dtype is complex or dtype is complex64): return empty(ans_shape, dtype=dtype) else: return empty(ans_shape, dtype=float) a = a.conjugate() u, s, vt = svd(a, full_matrices=False, hermitian=hermitian) cutoff = multiply(rcond[..., None], s.max(axis=-1, keepdims=True)) S = s.dtype for idx in multirange(broadcast_shapes(s.shape, cutoff.shape)): c = cutoff._ptr(idx, broadcast=True)[0] p = s._ptr(idx, broadcast=True) e = p[0] if greater(e, c): e = cast(divide(1, e), S) else: e = S() p[0] = e vt_t = swapaxes(vt, -1, -2) u_t = swapaxes(u, -1, -2) res = matmul(vt_t, multiply(s[..., None], u_t)) return res def matrix_rank(A, tol = None, hermitian: bool = False): A = asarray(A) dtype = A.dtype if staticlen(A.shape) == 0: return int(bool(A.data[0])) elif staticlen(A.shape) == 1: for a in A: if a: return 1 return 0 elif staticlen(A.shape) == 2: m, n = _rows_cols(A) if m == 0 or n == 0: raise ValueError("cannot take rank of empty matrix") S = svd(A, compute_uv=False, hermitian=hermitian) R = S.dtype if staticlen(asarray(tol).shape) != 0: compile_error("invalid tolerance dimension") s_max = -inf(R) if tol is None: for s in S: if s > s_max: s_max = s r = 0 for s in S: if tol is None: if s > s_max * cast(max(m, n), R) * eps(R): r += 1 else: if s > cast(tol, R): r += 1 return r else: m, n = _rows_cols(A) if m == 0 or n == 0: raise ValueError("cannot take rank of empty matrix") k = min(m, n) pre_shape = A.shape[:-2] ans = empty(pre_shape, int) S = svd(A, compute_uv=False, hermitian=hermitian) R = S.dtype if tol is None: t = 0 else: t = asarray(tol) if staticlen(t.shape) > 0: if staticlen(t.shape) != staticlen(A.shape) - 2: compile_error("invalid tolerance dimension") else: if t.shape != A.shape[:-2]: raise ValueError("tolerance does not broadcast against matrix_rank input") for idx in multirange(pre_shape): q = ans._ptr(idx) r = 0 s_max = -inf(R) if tol is None: for i in range(k): e = S._ptr(idx + (i,))[0] if e > s_max: s_max = e for i in range(k): e = S._ptr(idx + (i,))[0] if tol is None: if e > s_max * cast(max(m, n), R) * eps(R): r += 1 else: if staticlen(t.shape) == 0: if e > cast(t.data[0], R): r += 1 else: if e > cast(t._ptr(idx)[0], R): r += 1 q[0] = r return ans def cond(x, p = None): def cond_r_1(x, p: int): s = svd(x, compute_uv=False) if p == -2: return s[..., -1] / s[..., 0] else: return s[..., 0] / s[..., -1] def cond_r_2(x, p): invx = _inv(x, ignore_errors=True) return norm(x, p, axis=(-2, -1)) * norm(invx, p, axis=(-2, -1)) x = asarray(x) m, n = _rows_cols(x) if m == 0 or n == 0: raise LinAlgError("cond is not defined on empty arrays") if p is None: r = cond_r_1(x, 0) elif isinstance(p, int): if p == 2 or p == -2: r = cond_r_1(x, p) else: r = cond_r_2(x, p) else: r = cond_r_2(x, p) r = asarray(r) if staticlen(r.shape) == 0: any_nan = False for i in range(m): for j in range(n): if isnan(x._ptr((i, j))[0]): any_nan = True break rval = r.data[0] if isnan(rval) and not any_nan: return inf(type(rval)) else: return rval else: for idx in multirange(r.shape): any_nan = False for i in range(m): for j in range(n): if isnan(x._ptr(idx + (i, j))[0]): any_nan = True break p = r._ptr(idx) if isnan(p[0]) and not any_nan: p[0] = inf(type(p[0])) return r def matrix_transpose(x): x = asarray(x) if x.ndim < 2: compile_error("Input array must be at least 2-dimensional") return swapaxes(x, -1, -2) @extend class ndarray: def __matmul__(self, other): return matmul(self, other) def dot(self, other): return dot(self, other) def trace(self, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: type = NoneType, out = None): return _trace_ufunc(self, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype, out=out)