# Mirror of https://github.com/exaloop/codon.git
# 426 lines, 13 KiB, Python
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
|
||
|
||
import util
|
||
from .ndarray import ndarray
|
||
from .npdatetime import datetime64, timedelta64
|
||
from .dragon4 import format_float_positional, format_float_scientific
|
||
|
||
# Upper bound on the number of fractional digits used for float elements;
# FloatingFormat caps its derived precision / right padding with this value
# and passes it on as the precision when it builds settings from an array.
MAX_FLOAT_PREC = 4
|
||
|
||
# Fallback element formatter: used by get_formatter() for every dtype that
# has no dedicated formatter.
class DefaultFormat:
    def __call__(self, x):
        """Return the plain ``str`` conversion of ``x``."""
        return str(x)
|
||
|
||
# Element formatter returned by get_formatter() for timedelta64 dtypes.
class TimeDeltaFormat:
    def __call__(self, x):
        """Return ``"NaT"`` when ``x._nat`` is truthy, else ``str(x.value)``."""
        return "NaT" if x._nat else str(x.value)
|
||
|
||
# Element formatter for float16 / float32 / float arrays.
#
# The fields hold the keyword settings that __call__ forwards to
# format_float_scientific / format_float_positional. For the integer fields,
# -1 means "not set"; the matching property then reports None.
class FloatingFormat:
    _scientific: bool
    _sign: bool
    _unique: bool
    _trim: str
    _precision: int
    _min_digits: int
    _pad_left: int
    _pad_right: int
    _exp_size: int

    def __init__(self,
                 scientific: bool = False,
                 sign: bool = False,
                 unique: bool = True,
                 trim: str = '.',
                 precision: int = -1,
                 min_digits: int = -1,
                 pad_left: int = -1,
                 pad_right: int = -1,
                 exp_size: int = -1):
        """Store the given settings verbatim; every argument has a default."""
        self._scientific = scientific
        self._sign = sign
        self._unique = unique
        self._trim = trim
        self._precision = precision
        self._min_digits = min_digits
        self._pad_left = pad_left
        self._pad_right = pad_right
        self._exp_size = exp_size

    def __init__(self, a: ndarray, sign: bool = False):
        """Derive formatting settings by scanning the float array ``a``.

        Empty arrays and arrays without any finite element fall back to the
        default settings. Otherwise scientific notation is selected when,
        among the nonzero finite elements, the largest magnitude is >= 1e8,
        the smallest is < 1e-4, or their ratio exceeds 1000; positional
        notation is used in all other cases. ``sign`` is forwarded to the
        dragon4 routines both while measuring and in the stored settings.
        """
        dtype = a.dtype
        if not (dtype is float16 or
                dtype is float32 or
                dtype is float):
            compile_error("[internal error] cannot format non-float array")

        if a.size == 0:
            self.__init__()
            return

        min_val = None
        max_val = None
        finite = 0
        abs_non_zero = 0

        for idx in util.multirange(a.shape):
            x = a._ptr(idx)[0]
            if util.isfinite(x):
                finite += 1
                # Only nonzero magnitudes take part in the min/max tracking,
                # so the max_val/min_val ratio below never divides by zero.
                if x:
                    x = abs(x)

                    if max_val is None or x > max_val:
                        max_val = x

                    if min_val is None or x < min_val:
                        min_val = x

                    abs_non_zero += 1

        # The abs_non_zero guard short-circuits before min_val/max_val are
        # read, so they are never used while still None.
        exp_format = (abs_non_zero != 0 and (max_val >= dtype(1.e8) or
                      min_val < dtype(0.0001) or max_val/min_val > dtype(1000.)))
        precision = -1
        exp_size = -1
        pad_left = -1
        pad_right = -1

        if finite == 0:
            self.__init__()
            return
        elif exp_format:
            # Measure the widest mantissa fraction, exponent and integer part
            # over all finite elements in scientific notation.
            for x in a.flat:
                if util.isfinite(x):
                    s = format_float_scientific(x, trim='.', sign=sign)
                    frac_str, _, exp_str = s.partition('e')
                    int_part, frac_part = frac_str.split('.')
                    precision = max(precision, len(frac_part))
                    exp_size = max(exp_size, len(exp_str) - 1)
                    pad_left = max(pad_left, len(int_part))
            pad_right = exp_size + 2 + precision
        else:
            # Measure the widest integer and fractional parts over all finite
            # elements in positional notation.
            for x in a.flat:
                if util.isfinite(x):
                    s = format_float_positional(x, trim='.', sign=sign, fractional=True)
                    int_part, frac_part = s.split('.')
                    pad_left = max(pad_left, len(int_part))
                    pad_right = max(pad_right, len(frac_part))

        # Cap the measured widths so that min_digits never exceeds the
        # MAX_FLOAT_PREC precision passed on below.
        precision = min(precision, MAX_FLOAT_PREC)
        pad_right = min(pad_right, exp_size + 2 + MAX_FLOAT_PREC)

        return self.__init__(scientific=exp_format,
                             sign=sign,
                             unique=True,
                             trim='k',
                             precision=MAX_FLOAT_PREC,
                             min_digits=max(precision, 0),
                             pad_left=pad_left,
                             pad_right=pad_right,
                             exp_size=exp_size,)

    @property
    def scientific(self):
        return self._scientific

    @property
    def sign(self):
        return self._sign

    @property
    def unique(self):
        return self._unique

    @property
    def trim(self):
        return self._trim

    # The integer-valued properties below translate the internal "-1 = unset"
    # convention into None.

    @property
    def precision(self):
        return self._precision if self._precision >= 0 else None

    @property
    def min_digits(self):
        return self._min_digits if self._min_digits >= 0 else None

    @property
    def pad_left(self):
        return self._pad_left if self._pad_left >= 0 else None

    @property
    def pad_right(self):
        return self._pad_right if self._pad_right >= 0 else None

    @property
    def exp_size(self):
        return self._exp_size if self._exp_size >= 0 else None

    def __call__(self, x):
        """Format ``x`` with the stored settings.

        Uses format_float_scientific when ``scientific`` is set and
        format_float_positional (with ``fractional=True``) otherwise.
        """
        if self.scientific:
            s = format_float_scientific(x,
                                        sign=self.sign,
                                        trim=self.trim,
                                        unique=self.unique,
                                        precision=self.precision,
                                        pad_left=self.pad_left,
                                        exp_digits=self.exp_size,
                                        min_digits=self.min_digits)
        else:
            s = format_float_positional(x,
                                        sign=self.sign,
                                        trim=self.trim,
                                        unique=self.unique,
                                        precision=self.precision,
                                        fractional=True,
                                        pad_left=self.pad_left,
                                        pad_right=self.pad_right,
                                        min_digits=self.min_digits)
        return s
|
||
|
||
# Element formatter for complex / complex64 arrays, built from two
# FloatingFormat instances: one derived from the real parts and one (with
# sign=True) derived from the imaginary parts.
class ComplexFormat:
    _re_fmt: FloatingFormat
    _im_fmt: FloatingFormat

    def __init__(self, a: ndarray):
        """Build the real and imaginary formatters from the complex array ``a``."""
        dtype = a.dtype
        if not (dtype is complex or dtype is complex64):
            compile_error("[internal error] cannot format non-complex array")
        self._re_fmt = FloatingFormat(a.real, sign=False)
        self._im_fmt = FloatingFormat(a.imag, sign=True)

    def __call__(self, x):
        """Return the real part followed by the imaginary part and ``j``."""
        r = self._re_fmt(x.real)
        i = self._im_fmt(x.imag)
        # Insert the 'j' right after the last non-blank character so that any
        # trailing whitespace of the imaginary part stays at the very end.
        sp = len(i.rstrip())
        return f"{r}{i[:sp]}j{i[sp:]}"
|
||
|
||
def get_formatter(a: ndarray):
    """Pick the element formatter matching the dtype of ``a``.

    float16 / float32 / float -> FloatingFormat(a)
    complex / complex64       -> ComplexFormat(a)
    timedelta64               -> TimeDeltaFormat()
    anything else             -> DefaultFormat()
    """
    # Each branch returns a different formatter type, so the chain is kept as
    # a single if/elif/else on the dtype.
    if (a.dtype is float16 or
        a.dtype is float32 or
        a.dtype is float):
        return FloatingFormat(a)
    elif (a.dtype is complex or
          a.dtype is complex64):
        return ComplexFormat(a)
    elif isinstance(a.dtype, timedelta64):
        return TimeDeltaFormat()
    else:
        return DefaultFormat()
|
||
|
||
@extend
|
||
class ndarray:
|
||
|
||
def __repr__(self):
|
||
return f'array({repr(self.tolist())})'
|
||
|
||
def _str_2d(fmt, shape, front: ndarray[X,ndim], back: Optional[ndarray[X,ndim]] = None, ndim: Static[int], X: type):
|
||
BOX_VL = '│'
|
||
BOX_HL = '─'
|
||
BOX_TR = '╮'
|
||
BOX_BR = '╯'
|
||
BOX_TL = '╭'
|
||
BOX_BL = '╰'
|
||
MID_N = '┬'
|
||
MID_E = '┤'
|
||
MID_S = '┴'
|
||
MID_W = '├'
|
||
INNER_VL = '│'
|
||
INNER_HL = '─'
|
||
INNER_CR = '┼'
|
||
TRUNC = '...'
|
||
|
||
LIMIT = 10
|
||
TRUNC_LEN = 3
|
||
|
||
if staticlen(front.shape) != 2:
|
||
compile_error("[internal error] wrong array shape passed to str2d")
|
||
|
||
x, y = front.shape
|
||
v = _strbuf()
|
||
|
||
if util.count(shape) == 0:
|
||
v.append(str(shape[0]))
|
||
for a in shape[1:]:
|
||
v.append(' × ')
|
||
v.append(str(a))
|
||
v.append(' array of ')
|
||
v.append(X.__name__)
|
||
v.append('\n')
|
||
v.append('<empty>')
|
||
return v.__str__()
|
||
|
||
trunc_x = x > LIMIT
|
||
trunc_y = y > LIMIT
|
||
if not trunc_x and not trunc_y:
|
||
items = [fmt(front[i,j]) for i in range(x) for j in range(y)]
|
||
else:
|
||
x_real = x
|
||
y_real = y
|
||
|
||
if trunc_x:
|
||
x_real = 2*TRUNC_LEN + 1
|
||
if trunc_y:
|
||
y_real = 2*TRUNC_LEN + 1
|
||
|
||
items = []
|
||
i = 0
|
||
while i < x:
|
||
if trunc_x and (TRUNC_LEN <= i < x - TRUNC_LEN ):
|
||
i = x - TRUNC_LEN
|
||
for _ in range(y_real):
|
||
items.append(TRUNC)
|
||
continue
|
||
|
||
j = 0
|
||
while j < y:
|
||
if trunc_y and (TRUNC_LEN <= j < y - TRUNC_LEN):
|
||
j = y - TRUNC_LEN
|
||
items.append(TRUNC)
|
||
continue
|
||
items.append(fmt(front[i,j]))
|
||
j += 1
|
||
|
||
i += 1
|
||
|
||
if trunc_x:
|
||
x = 2*TRUNC_LEN + 1
|
||
if trunc_y:
|
||
y = 2*TRUNC_LEN + 1
|
||
|
||
back_items = []
|
||
W = max(len(s) for s in items)
|
||
|
||
if back is not None:
|
||
back_x, back_y = back.shape
|
||
i = 0
|
||
while i < back_x:
|
||
if trunc_x and (TRUNC_LEN <= i < back_x - TRUNC_LEN):
|
||
i = back_x - TRUNC_LEN
|
||
back_items.append(TRUNC)
|
||
continue
|
||
back_items.append(fmt(back[i, back_y - 1]))
|
||
i += 1
|
||
|
||
W = max(W, max(len(s) for s in back_items))
|
||
|
||
# tag
|
||
v.append(str(shape[0]))
|
||
for a in shape[1:]:
|
||
v.append(' × ')
|
||
v.append(str(a))
|
||
v.append(' array of ')
|
||
v.append(X.__name__)
|
||
v.append('\n')
|
||
|
||
# top border
|
||
v.append(BOX_TL)
|
||
for i in range(y):
|
||
for k in range(W + 2):
|
||
v.append(BOX_HL)
|
||
if i != y - 1:
|
||
v.append(MID_N)
|
||
v.append(BOX_TR)
|
||
v.append('\n')
|
||
|
||
k1, k2 = 0, 0
|
||
for i in range(x):
|
||
# row itself
|
||
v.append(BOX_VL)
|
||
for j in range(y):
|
||
e = items[k1].rjust(W)
|
||
k1 += 1
|
||
v.append(' ')
|
||
v.append(e)
|
||
v.append(' ')
|
||
if j != y - 1:
|
||
v.append(INNER_VL)
|
||
v.append(BOX_VL)
|
||
|
||
# behind array
|
||
if back is not None:
|
||
for _ in range(W + 2):
|
||
v.append(INNER_HL)
|
||
v.append(BOX_TR if i == 0 else MID_E)
|
||
|
||
v.append('\n')
|
||
if i != x - 1:
|
||
# inner border
|
||
v.append(MID_W)
|
||
for j in range(y):
|
||
for k in range(W + 2):
|
||
v.append(INNER_HL)
|
||
if j != y - 1:
|
||
v.append(INNER_CR)
|
||
v.append(MID_E)
|
||
|
||
# behind array
|
||
if back is not None:
|
||
e = back_items[k2].rjust(W)
|
||
k2 += 1
|
||
v.append(' ')
|
||
v.append(e)
|
||
v.append(' ')
|
||
v.append(BOX_VL)
|
||
|
||
v.append('\n')
|
||
|
||
# bottom border
|
||
v.append(BOX_BL)
|
||
for j in range(y):
|
||
for k in range(W+2):
|
||
v.append(BOX_HL)
|
||
if j != y - 1:
|
||
v.append(MID_S)
|
||
v.append(BOX_BR)
|
||
|
||
# behind array
|
||
if back is not None:
|
||
e = back_items[k2].rjust(W)
|
||
v.append(' ')
|
||
v.append(e)
|
||
v.append(' ')
|
||
v.append(BOX_VL)
|
||
|
||
v.append('\n')
|
||
if W >= 1:
|
||
v.append(' ···')
|
||
for _ in range(W - 1):
|
||
v.append(' ')
|
||
else:
|
||
v.append(' ··')
|
||
for _ in range(W):
|
||
v.append(' ')
|
||
v.append(BOX_BL)
|
||
for i in range(y):
|
||
if i > 0:
|
||
v.append(MID_S)
|
||
for _ in range(W + 2):
|
||
v.append(INNER_HL)
|
||
v.append(BOX_BR)
|
||
|
||
return v.__str__()
|
||
|
||
def __str__(self):
|
||
fmt = get_formatter(self)
|
||
if staticlen(self.shape) == 0:
|
||
return f'array({repr(self.data[0])})'
|
||
elif staticlen(self.shape) == 1:
|
||
return ndarray._str_2d(fmt, self.shape, self.reshape((1, self.size)))
|
||
elif staticlen(self.shape) == 2:
|
||
return ndarray._str_2d(fmt, self.shape, self)
|
||
else:
|
||
if self.size == 0:
|
||
return ndarray._str_2d(fmt, self.shape, self.reshape((0, 0)))
|
||
|
||
rem = self.shape[:-2]
|
||
top_idx = tuple(0 for _ in rem)
|
||
bot_idx = tuple(i - 1 for i in rem)
|
||
top = self[top_idx]
|
||
bot = self[bot_idx]
|
||
|
||
if util.count(rem) == 1:
|
||
return ndarray._str_2d(fmt, self.shape, top)
|
||
else:
|
||
return ndarray._str_2d(fmt, self.shape, top, bot)
|