# [repository web-view header, not part of the program] 1 / 0
# mirror of https://github.com/exaloop/codon.git synced 2025-06-03 15:03:52 +08:00
# codon/stdlib/numpy/format.codon
#
# 426 lines, 13 KiB, Python
# Raw Normal View History

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
import util
from .ndarray import ndarray
from .npdatetime import datetime64, timedelta64
from .dragon4 import format_float_positional, format_float_scientific
# Precision given to the formatter that FloatingFormat derives from an array's
# contents; the same value also caps the digit count and the right padding
# computed there.
MAX_FLOAT_PREC = 4
class DefaultFormat:
    # Fallback element formatter: renders a value through its plain `str`
    # conversion.

    def __call__(self, x):
        # Returns the string form of `x`.
        text = str(x)
        return text
class TimeDeltaFormat:
    # Element formatter used for timedelta values.

    def __call__(self, x):
        # Returns "NaT" when the element's `_nat` flag is set, otherwise the
        # string form of its `value` field.
        if x._nat:
            return "NaT"
        return str(x.value)
class FloatingFormat:
    # Formats a single floating-point value with the dragon4 helpers
    # (`format_float_positional` / `format_float_scientific`).
    #
    # The settings are either supplied directly (keyword constructor) or
    # derived from the contents of an array (array constructor). Integer
    # settings use -1 to mean "not set"; the read-only properties below report
    # a negative stored value as None.
    _scientific: bool
    _sign: bool
    _unique: bool
    _trim: str
    _precision: int
    _min_digits: int
    _pad_left: int
    _pad_right: int
    _exp_size: int

    def __init__(self,
                 scientific: bool = False,
                 sign: bool = False,
                 unique: bool = True,
                 trim: str = '.',
                 precision: int = -1,
                 min_digits: int = -1,
                 pad_left: int = -1,
                 pad_right: int = -1,
                 exp_size: int = -1):
        # Store the given settings unchanged.
        self._scientific = scientific
        self._sign = sign
        self._unique = unique
        self._trim = trim
        self._precision = precision
        self._min_digits = min_digits
        self._pad_left = pad_left
        self._pad_right = pad_right
        self._exp_size = exp_size

    def __init__(self, a: ndarray, sign: bool = False):
        # Derive the settings from the values stored in `a`.
        #   a    - array whose dtype is float16, float32 or float; any other
        #          dtype is rejected with compile_error
        #   sign - forwarded to the dragon4 helpers while measuring and stored
        #          as the `sign` setting
        # An empty array, or one without any finite element, gets the defaults
        # of the keyword constructor.
        dtype = a.dtype
        if not (dtype is float16 or
                dtype is float32 or
                dtype is float):
            compile_error("[internal error] cannot format non-float array")

        if a.size == 0:
            self.__init__()
            return

        # Track the smallest and largest magnitude among the finite, non-zero
        # elements, and count the finite / finite-non-zero elements.
        min_val: Optional[a.dtype] = None
        max_val: Optional[a.dtype] = None
        finite = 0
        abs_non_zero = 0
        exp_format = False
        for idx in util.multirange(a.shape):
            x = a._ptr(idx)[0]
            if util.isfinite(x):
                finite += 1
                if x:
                    x = abs(x)
                    if max_val is None or x > max_val:
                        max_val = x
                    if min_val is None or x < min_val:
                        min_val = x
                    abs_non_zero += 1

        # Scientific notation is used when a finite non-zero element exists and
        # the largest magnitude is >= 1e8, the smallest is < 1e-4, or their
        # ratio exceeds 1000.
        exp_format = (abs_non_zero != 0 and (max_val >= dtype(1.e8) or
                      min_val < dtype(0.0001) or max_val/min_val > dtype(1000.)))

        precision = -1
        exp_size = -1
        pad_left = -1
        pad_right = -1

        if finite == 0:
            self.__init__()
            return
        elif exp_format:
            # Format every finite element once and record the widest
            # fraction, exponent and integer parts.
            for x in a.flat:
                if util.isfinite(x):
                    s = format_float_scientific(x, trim='.', sign=sign)
                    frac_str, _, exp_str = s.partition('e')
                    int_part, frac_part = frac_str.split('.')
                    precision = max(precision, len(frac_part))
                    # presumably the -1 discounts the exponent's sign
                    # character -- TODO confirm against the dragon4 output
                    exp_size = max(exp_size, len(exp_str) - 1)
                    pad_left = max(pad_left, len(int_part))
            # presumably room for the fraction digits plus the exponent suffix
            # ('e', sign, exponent digits) -- TODO confirm
            pad_right = exp_size + 2 + precision
        else:
            # Positional notation: record the widest integer and fraction
            # parts over the finite elements.
            for x in a.flat:
                if util.isfinite(x):
                    s = format_float_positional(x, trim='.', sign=sign, fractional=True)
                    int_part, frac_part = s.split('.')
                    pad_left = max(pad_left, len(int_part))
                    pad_right = max(pad_right, len(frac_part))

        # Cap the measured widths. In the positional branch `precision` and
        # `exp_size` are still -1 here, so min_digits below becomes 0 and
        # pad_right is capped at MAX_FLOAT_PREC + 1.
        precision = min(precision, MAX_FLOAT_PREC)
        pad_right = min(pad_right, exp_size + 2 + MAX_FLOAT_PREC)

        return self.__init__(scientific=exp_format,
                             sign=sign,
                             unique=True,
                             trim='k',
                             precision=MAX_FLOAT_PREC,
                             min_digits=max(precision, 0),
                             pad_left=pad_left,
                             pad_right=pad_right,
                             exp_size=exp_size)

    @property
    def scientific(self):
        # True when values are rendered with format_float_scientific.
        return self._scientific

    @property
    def sign(self):
        # `sign` argument passed to the dragon4 helpers.
        return self._sign

    @property
    def unique(self):
        # `unique` argument passed to the dragon4 helpers.
        return self._unique

    @property
    def trim(self):
        # `trim` argument passed to the dragon4 helpers.
        return self._trim

    @property
    def precision(self):
        # Stored precision, or None when it is negative (not set).
        return self._precision if self._precision >= 0 else None

    @property
    def min_digits(self):
        # Stored min_digits, or None when it is negative (not set).
        return self._min_digits if self._min_digits >= 0 else None

    @property
    def pad_left(self):
        # Stored left padding, or None when it is negative (not set).
        return self._pad_left if self._pad_left >= 0 else None

    @property
    def pad_right(self):
        # Stored right padding, or None when it is negative (not set).
        return self._pad_right if self._pad_right >= 0 else None

    @property
    def exp_size(self):
        # Stored exponent size, or None when it is negative (not set).
        return self._exp_size if self._exp_size >= 0 else None

    def __call__(self, x):
        # Format `x` with the stored settings and return the resulting string.
        # The scientific path passes exp_size as `exp_digits` and no
        # pad_right; the positional path passes pad_right and fractional=True.
        if self.scientific:
            s = format_float_scientific(x,
                                        sign=self.sign,
                                        trim=self.trim,
                                        unique=self.unique,
                                        precision=self.precision,
                                        pad_left=self.pad_left,
                                        exp_digits=self.exp_size,
                                        min_digits=self.min_digits)
        else:
            s = format_float_positional(x,
                                        sign=self.sign,
                                        trim=self.trim,
                                        unique=self.unique,
                                        precision=self.precision,
                                        fractional=True,
                                        pad_left=self.pad_left,
                                        pad_right=self.pad_right,
                                        min_digits=self.min_digits)
        return s
class ComplexFormat:
    # Formats complex values as "<real><imag>j", using one FloatingFormat for
    # each component.
    _re_fmt: FloatingFormat
    _im_fmt: FloatingFormat

    def __init__(self, a: ndarray):
        # Build the component formatters from the real and imaginary parts of
        # `a`. Only complex / complex64 arrays are accepted; anything else is
        # rejected with compile_error.
        dtype = a.dtype
        if not (dtype is complex or dtype is complex64):
            compile_error("[internal error] cannot format non-complex array")
        real_part = a.real
        imag_part = a.imag
        # The imaginary formatter is built with sign=True, the real one with
        # sign=False.
        self._re_fmt = FloatingFormat(real_part, sign=False)
        self._im_fmt = FloatingFormat(imag_part, sign=True)

    def __call__(self, x):
        # Returns the formatted real part followed by the formatted imaginary
        # part, with 'j' inserted before any trailing whitespace of the
        # imaginary text.
        real_text = self._re_fmt(x.real)
        imag_text = self._im_fmt(x.imag)
        cut = len(imag_text.rstrip())
        digits = imag_text[:cut]
        padding = imag_text[cut:]
        return f"{real_text}{digits}j{padding}"
def get_formatter(a: ndarray):
    # Pick the element formatter that matches the dtype of `a`:
    #   float16 / float32 / float -> FloatingFormat built from `a`
    #   complex / complex64       -> ComplexFormat built from `a`
    #   timedelta64               -> TimeDeltaFormat
    #   anything else             -> DefaultFormat
    # The branches return different types, so the selection is kept as a
    # single if/elif/else chain over dtype tests.
    dtype = a.dtype
    if dtype is float16 or dtype is float32 or dtype is float:
        return FloatingFormat(a)
    elif dtype is complex or dtype is complex64:
        return ComplexFormat(a)
    elif isinstance(dtype, timedelta64):
        return TimeDeltaFormat()
    else:
        return DefaultFormat()
@extend
class ndarray:
    # String conversions for ndarray: `__repr__` wraps the nested-list form,
    # `__str__` draws the array as a bordered text grid through `_str_2d`.

    def __repr__(self):
        # Returns 'array(<repr of self.tolist()>)'.
        return f'array({repr(self.tolist())})'

    def _str_2d(fmt, shape, front: ndarray[X,ndim], back: Optional[ndarray[X,ndim]] = None, ndim: Static[int], X: type):
        # Render the 2-D array `front` as a text grid and return it as a str.
        #   fmt   - callable mapping one element to its display string
        #   shape - shape printed in the header line; it may differ from
        #           front.shape (see the calls in `__str__`)
        #   front - array drawn in the grid; must have exactly two dimensions,
        #           otherwise compile_error is raised
        #   back  - optional second 2-D array; only its last column is drawn,
        #           to the right of the grid on the separator / bottom lines
        # A dimension longer than LIMIT shows only its first and last TRUNC_LEN
        # entries, separated by a TRUNC placeholder.
        #
        # NOTE(review): the border glyph literals below are empty strings in
        # this copy of the file -- confirm against upstream that the original
        # characters were not lost.
        BOX_VL = ''
        BOX_HL = ''
        BOX_TR = ''
        BOX_BR = ''
        BOX_TL = ''
        BOX_BL = ''
        MID_N = ''
        MID_E = ''
        MID_S = ''
        MID_W = ''
        INNER_VL = ''
        INNER_HL = ''
        INNER_CR = ''
        TRUNC = '...'
        LIMIT = 10
        TRUNC_LEN = 3

        if staticlen(front.shape) != 2:
            compile_error("[internal error] wrong array shape passed to str2d")

        x, y = front.shape
        v = _strbuf()

        # Empty array: emit only the header line followed by '<empty>'.
        # presumably util.count(shape) is the total element count -- TODO confirm
        if util.count(shape) == 0:
            v.append(str(shape[0]))
            for a in shape[1:]:
                v.append(' × ')
                v.append(str(a))
            v.append(' array of ')
            v.append(X.__name__)
            v.append('\n')
            v.append('<empty>')
            return v.__str__()

        trunc_x = x > LIMIT
        trunc_y = y > LIMIT

        # Collect the cell strings in row-major order.
        if not trunc_x and not trunc_y:
            items = [fmt(front[i,j]) for i in range(x) for j in range(y)]
        else:
            x_real = x
            y_real = y
            if trunc_x:
                x_real = 2*TRUNC_LEN + 1
            if trunc_y:
                y_real = 2*TRUNC_LEN + 1

            items = []
            i = 0
            while i < x:
                # The skipped middle rows collapse into one row of TRUNC cells.
                if trunc_x and (TRUNC_LEN <= i < x - TRUNC_LEN ):
                    i = x - TRUNC_LEN
                    for _ in range(y_real):
                        items.append(TRUNC)
                    continue

                j = 0
                while j < y:
                    # The skipped middle columns collapse into one TRUNC cell.
                    if trunc_y and (TRUNC_LEN <= j < y - TRUNC_LEN):
                        j = y - TRUNC_LEN
                        items.append(TRUNC)
                        continue
                    items.append(fmt(front[i,j]))
                    j += 1
                i += 1

            # From here on x and y are the numbers of rows / columns drawn.
            if trunc_x:
                x = 2*TRUNC_LEN + 1
            if trunc_y:
                y = 2*TRUNC_LEN + 1

        back_items = []
        # W is the common cell width: the longest cell string.
        W = max(len(s) for s in items)
        if back is not None:
            # Collect the last column of `back`, truncated like the rows above.
            back_x, back_y = back.shape
            i = 0
            while i < back_x:
                if trunc_x and (TRUNC_LEN <= i < back_x - TRUNC_LEN):
                    i = back_x - TRUNC_LEN
                    back_items.append(TRUNC)
                    continue
                back_items.append(fmt(back[i, back_y - 1]))
                i += 1
            W = max(W, max(len(s) for s in back_items))

        # tag: "<d0> × <d1> × ... array of <element type name>"
        v.append(str(shape[0]))
        for a in shape[1:]:
            v.append(' × ')
            v.append(str(a))
        v.append(' array of ')
        v.append(X.__name__)
        v.append('\n')

        # top border
        v.append(BOX_TL)
        for i in range(y):
            for k in range(W + 2):
                v.append(BOX_HL)
            if i != y - 1:
                v.append(MID_N)
        v.append(BOX_TR)
        v.append('\n')

        # k1 indexes `items`, k2 indexes `back_items`.
        k1, k2 = 0, 0
        for i in range(x):
            # row itself: each cell is right-justified to width W
            v.append(BOX_VL)
            for j in range(y):
                e = items[k1].rjust(W)
                k1 += 1
                v.append(' ')
                v.append(e)
                v.append(' ')
                if j != y - 1:
                    v.append(INNER_VL)
            v.append(BOX_VL)

            # behind array: horizontal edge of the back column next to the row
            if back is not None:
                for _ in range(W + 2):
                    v.append(INNER_HL)
                v.append(BOX_TR if i == 0 else MID_E)
            v.append('\n')

            if i != x - 1:
                # inner border
                v.append(MID_W)
                for j in range(y):
                    for k in range(W + 2):
                        v.append(INNER_HL)
                    if j != y - 1:
                        v.append(INNER_CR)
                v.append(MID_E)

                # behind array: the back value is written on the separator line
                if back is not None:
                    e = back_items[k2].rjust(W)
                    k2 += 1
                    v.append(' ')
                    v.append(e)
                    v.append(' ')
                    v.append(BOX_VL)
                v.append('\n')

        # bottom border
        v.append(BOX_BL)
        for j in range(y):
            for k in range(W+2):
                v.append(BOX_HL)
            if j != y - 1:
                v.append(MID_S)
        v.append(BOX_BR)

        # behind array: last back value, then one more line with a '···'
        # marker and the bottom edge of the back array
        if back is not None:
            e = back_items[k2].rjust(W)
            v.append(' ')
            v.append(e)
            v.append(' ')
            v.append(BOX_VL)
            v.append('\n')
            if W >= 1:
                v.append(' ···')
                for _ in range(W - 1):
                    v.append(' ')
            else:
                v.append(' ··')
                for _ in range(W):
                    v.append(' ')
            v.append(BOX_BL)
            for i in range(y):
                if i > 0:
                    v.append(MID_S)
                for _ in range(W + 2):
                    v.append(INNER_HL)
            v.append(BOX_BR)

        return v.__str__()

    def __str__(self):
        # Text form of the array, selected by its number of dimensions:
        #   0-d  -> 'array(<repr of the single element>)'
        #   1-d  -> drawn as a single-row grid
        #   2-d  -> drawn as a grid
        #   more -> the 2-D slice at index 0 of every leading axis is drawn as
        #           the grid; unless util.count of the leading dimensions is 1,
        #           the slice at the last index of every leading axis is
        #           passed as `back`
        fmt = get_formatter(self)
        if staticlen(self.shape) == 0:
            return f'array({repr(self.data[0])})'
        elif staticlen(self.shape) == 1:
            return ndarray._str_2d(fmt, self.shape, self.reshape((1, self.size)))
        elif staticlen(self.shape) == 2:
            return ndarray._str_2d(fmt, self.shape, self)
        else:
            # An empty array is passed on as a 0 x 0 array; `_str_2d` prints
            # '<empty>' for it.
            if self.size == 0:
                return ndarray._str_2d(fmt, self.shape, self.reshape((0, 0)))
            # Leading dimensions (all but the last two).
            rem = self.shape[:-2]
            top_idx = tuple(0 for _ in rem)
            bot_idx = tuple(i - 1 for i in rem)
            top = self[top_idx]
            bot = self[bot_idx]
            # presumably util.count(rem) == 1 means there is only one 2-D
            # slice, so no back array is drawn -- TODO confirm
            if util.count(rem) == 1:
                return ndarray._str_2d(fmt, self.shape, top)
            else:
                return ndarray._str_2d(fmt, self.shape, top, bot)