codon/stdlib/numpy/format.codon

426 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
import util
from .ndarray import ndarray
from .npdatetime import datetime64, timedelta64
from .dragon4 import format_float_positional, format_float_scientific
# Cap on fractional digits used by FloatingFormat when it derives options from
# an array: passed as `precision` to the dragon4 routines and used to clamp the
# measured `precision` / `pad_right` values.
MAX_FLOAT_PREC = 4
# Fallback element formatter: renders a value through `str`.
class DefaultFormat:
    def __call__(self, x):
        text = str(x)
        return text
# Element formatter for timedelta values: prints "NaT" when the value's `_nat`
# flag is set, otherwise the string form of its `value` field.
class TimeDeltaFormat:
    def __call__(self, x):
        if x._nat:
            return "NaT"
        return str(x.value)
# Formatter for floating-point array elements.
#
# Stores the options that are forwarded to the dragon4 routines
# (`format_float_positional` / `format_float_scientific`) when an instance is
# called on a value. For the integer options a negative value means "unset";
# the matching property then yields None instead of the number.
class FloatingFormat:
    _scientific: bool  # selects format_float_scientific over format_float_positional
    _sign: bool        # forwarded as `sign=`
    _unique: bool      # forwarded as `unique=`
    _trim: str         # forwarded as `trim=`
    _precision: int    # forwarded as `precision=`; negative means unset
    _min_digits: int   # forwarded as `min_digits=`; negative means unset
    _pad_left: int     # forwarded as `pad_left=`; negative means unset
    _pad_right: int    # forwarded as `pad_right=` (positional only); negative means unset
    _exp_size: int     # forwarded as `exp_digits=` (scientific only); negative means unset

    # Build a formatter from explicit options. With no arguments this yields
    # the default: positional notation, unique digits, trim '.', nothing padded.
    def __init__(self,
                 scientific: bool = False,
                 sign: bool = False,
                 unique: bool = True,
                 trim: str = '.',
                 precision: int = -1,
                 min_digits: int = -1,
                 pad_left: int = -1,
                 pad_right: int = -1,
                 exp_size: int = -1):
        self._scientific = scientific
        self._sign = sign
        self._unique = unique
        self._trim = trim
        self._precision = precision
        self._min_digits = min_digits
        self._pad_left = pad_left
        self._pad_right = pad_right
        self._exp_size = exp_size

    # Derive formatting options from the contents of the float array `a`.
    #
    # Falls back to the default options (keeping `sign`) when `a` is empty or
    # holds no finite values. Otherwise scientific notation is chosen when the
    # finite non-zero magnitudes reach 1e8, drop below 1e-4, or span a
    # max/min ratio above 1000; the finite elements are then formatted once
    # to measure the padding and digit counts needed to line them up.
    #
    # Compile-time error if `a.dtype` is not float16, float32 or float.
    def __init__(self, a: ndarray, sign: bool = False):
        dtype = a.dtype
        if not (dtype is float16 or
                dtype is float32 or
                dtype is float):
            compile_error("[internal error] cannot format non-float array")

        if a.size == 0:
            # Forward `sign` so the caller's choice survives the fallback
            # (ComplexFormat relies on sign=True for the imaginary part).
            self.__init__(sign=sign)
            return

        # Pass 1: count finite values and track the magnitude range of the
        # finite non-zero ones.
        min_val = None
        max_val = None
        finite = 0
        abs_non_zero = 0

        for idx in util.multirange(a.shape):
            x = a._ptr(idx)[0]
            if util.isfinite(x):
                finite += 1
                if x:
                    x = abs(x)
                    if max_val is None or x > max_val:
                        max_val = x
                    if min_val is None or x < min_val:
                        min_val = x
                    abs_non_zero += 1

        # The `abs_non_zero != 0` guard short-circuits before min_val/max_val
        # (still None when there are no non-zero finite values) are touched.
        exp_format = (abs_non_zero != 0 and (max_val >= dtype(1.e8) or
                      min_val < dtype(0.0001) or max_val/min_val > dtype(1000.)))

        precision = -1
        exp_size = -1
        pad_left = -1
        pad_right = -1

        if finite == 0:
            # Only nan/inf present: nothing to measure, same fallback as above.
            self.__init__(sign=sign)
            return
        elif exp_format:
            # Pass 2 (scientific): widest mantissa fraction, exponent and
            # integer part over all finite elements.
            for x in a.flat:
                if util.isfinite(x):
                    s = format_float_scientific(x, trim='.', sign=sign)
                    frac_str, _, exp_str = s.partition('e')
                    int_part, frac_part = frac_str.split('.')
                    precision = max(precision, len(frac_part))
                    # presumably the -1 discounts the exponent's sign character
                    exp_size = max(exp_size, len(exp_str) - 1)
                    pad_left = max(pad_left, len(int_part))
            # finite > 0 here, so the maxima above are set.
            pad_right = exp_size + 2 + precision
        else:
            # Pass 2 (positional): widest integer and fractional parts.
            for x in a.flat:
                if util.isfinite(x):
                    s = format_float_positional(x, trim='.', sign=sign, fractional=True)
                    int_part, frac_part = s.split('.')
                    pad_left = max(pad_left, len(int_part))
                    pad_right = max(pad_right, len(frac_part))

        # Clamp the measured values to the module-wide precision cap.
        precision = min(precision, MAX_FLOAT_PREC)
        pad_right = min(pad_right, exp_size + 2 + MAX_FLOAT_PREC)

        self.__init__(scientific=exp_format,
                      sign=sign,
                      unique=True,
                      trim='k',
                      precision=MAX_FLOAT_PREC,
                      min_digits=max(precision, 0),
                      pad_left=pad_left,
                      pad_right=pad_right,
                      exp_size=exp_size)

    @property
    def scientific(self):
        return self._scientific

    @property
    def sign(self):
        return self._sign

    @property
    def unique(self):
        return self._unique

    @property
    def trim(self):
        return self._trim

    # The integer options below map the negative "unset" sentinel to None.
    @property
    def precision(self):
        return self._precision if self._precision >= 0 else None

    @property
    def min_digits(self):
        return self._min_digits if self._min_digits >= 0 else None

    @property
    def pad_left(self):
        return self._pad_left if self._pad_left >= 0 else None

    @property
    def pad_right(self):
        return self._pad_right if self._pad_right >= 0 else None

    @property
    def exp_size(self):
        return self._exp_size if self._exp_size >= 0 else None

    # Format a single value `x` with the stored options and return the string.
    def __call__(self, x):
        if self.scientific:
            s = format_float_scientific(x,
                                        sign=self.sign,
                                        trim=self.trim,
                                        unique=self.unique,
                                        precision=self.precision,
                                        pad_left=self.pad_left,
                                        exp_digits=self.exp_size,
                                        min_digits=self.min_digits)
        else:
            s = format_float_positional(x,
                                        sign=self.sign,
                                        trim=self.trim,
                                        unique=self.unique,
                                        precision=self.precision,
                                        fractional=True,
                                        pad_left=self.pad_left,
                                        pad_right=self.pad_right,
                                        min_digits=self.min_digits)
        return s
# Formatter for complex array elements. Keeps one FloatingFormat for the real
# parts and one, built with sign=True, for the imaginary parts.
class ComplexFormat:
    _re_fmt: FloatingFormat
    _im_fmt: FloatingFormat

    # Build both component formatters from the complex array `a`.
    # Compile-time error if `a.dtype` is neither complex nor complex64.
    def __init__(self, a: ndarray):
        dtype = a.dtype
        if not (dtype is complex64 or dtype is complex):
            compile_error("[internal error] cannot format non-complex array")
        self._re_fmt = FloatingFormat(a.real, sign=False)
        self._im_fmt = FloatingFormat(a.imag, sign=True)

    # Format `x` as "<real><imag>j", placing the 'j' directly after the
    # imaginary digits and moving any trailing whitespace of the imaginary
    # text to after the 'j'.
    def __call__(self, x):
        real_text = self._re_fmt(x.real)
        imag_text = self._im_fmt(x.imag)
        imag_body = imag_text.rstrip()
        imag_pad = imag_text[len(imag_body):]
        return f"{real_text}{imag_body}j{imag_pad}"
# Choose the element formatter for array `a` from its dtype:
# float16/float32/float -> FloatingFormat, complex/complex64 -> ComplexFormat,
# timedelta64 -> TimeDeltaFormat, anything else -> DefaultFormat.
# The branches return different types, so the if/elif/else chain is kept as a
# single statement.
def get_formatter(a: ndarray):
    if a.dtype is float or a.dtype is float32 or a.dtype is float16:
        return FloatingFormat(a)
    elif a.dtype is complex64 or a.dtype is complex:
        return ComplexFormat(a)
    elif isinstance(a.dtype, timedelta64):
        return TimeDeltaFormat()
    else:
        return DefaultFormat()
@extend
class ndarray:
    # repr: the nested-list form of the data wrapped in `array(...)`.
    def __repr__(self):
        return f'array({repr(self.tolist())})'

    # Render the 2-D array `front` as a boxed table and return it as a string.
    #
    #   fmt    callable turning one element into its string form
    #   shape  shape of the array being printed; used for the tag line and
    #          the emptiness check (may differ from `front.shape`)
    #   front  the 2-D slice drawn as the table
    #   back   optional second 2-D slice; when given, its last column is drawn
    #          as a partially hidden table behind `front`, shifted one line
    #          down, followed by a ' ···' marker line
    #
    # More than LIMIT rows/columns are elided: TRUNC_LEN entries are kept at
    # each end with a single TRUNC row/column in between.
    def _str_2d(fmt, shape, front: ndarray[X,ndim], back: Optional[ndarray[X,ndim]] = None, ndim: Static[int], X: type):
        # NOTE(review): these thirteen constants were all empty strings, which
        # turns every border append below into a no-op. The ' ···' padding
        # arithmetic (4 + W-1 == 1 + W+2) shows each glyph is one column wide.
        # Restored with the light box-drawing set; confirm the exact glyphs
        # (e.g. square vs. rounded corners) against the upstream file.
        BOX_VL = '│'
        BOX_HL = '─'
        BOX_TR = '┐'
        BOX_BR = '┘'
        BOX_TL = '┌'
        BOX_BL = '└'
        MID_N = '┬'
        MID_E = '┤'
        MID_S = '┴'
        MID_W = '├'
        INNER_VL = '│'
        INNER_HL = '─'
        INNER_CR = '┼'
        TRUNC = '...'
        LIMIT = 10
        TRUNC_LEN = 3

        if staticlen(front.shape) != 2:
            compile_error("[internal error] wrong array shape passed to str2d")

        x, y = front.shape
        v = _strbuf()

        # tag line: "<d0> × <d1> × ... array of <element type name>"
        v.append(str(shape[0]))
        for a in shape[1:]:
            v.append(' × ')
            v.append(str(a))
        v.append(' array of ')
        v.append(X.__name__)
        v.append('\n')

        if util.count(shape) == 0:
            v.append('<empty>')
            return v.__str__()

        # Collect the cell strings of `front` in row-major order, eliding the
        # middle rows/columns when a dimension exceeds LIMIT.
        trunc_x = x > LIMIT
        trunc_y = y > LIMIT
        if not trunc_x and not trunc_y:
            items = [fmt(front[i,j]) for i in range(x) for j in range(y)]
        else:
            # number of cells per displayed row
            y_real = 2*TRUNC_LEN + 1 if trunc_y else y
            items = []
            i = 0
            while i < x:
                if trunc_x and (TRUNC_LEN <= i < x - TRUNC_LEN):
                    # one placeholder row stands in for all skipped rows
                    i = x - TRUNC_LEN
                    for _ in range(y_real):
                        items.append(TRUNC)
                    continue
                j = 0
                while j < y:
                    if trunc_y and (TRUNC_LEN <= j < y - TRUNC_LEN):
                        # one placeholder cell stands in for all skipped columns
                        j = y - TRUNC_LEN
                        items.append(TRUNC)
                        continue
                    items.append(fmt(front[i,j]))
                    j += 1
                i += 1
            # from here on x and y are the *displayed* row/column counts
            if trunc_x:
                x = 2*TRUNC_LEN + 1
            if trunc_y:
                y = 2*TRUNC_LEN + 1

        # W is the common cell width: the longest cell string of either table.
        back_items = []
        W = max(len(s) for s in items)
        if back is not None:
            # Only the last column of `back` is visible behind `front`.
            back_x, back_y = back.shape
            i = 0
            while i < back_x:
                if trunc_x and (TRUNC_LEN <= i < back_x - TRUNC_LEN):
                    i = back_x - TRUNC_LEN
                    back_items.append(TRUNC)
                    continue
                back_items.append(fmt(back[i, back_y - 1]))
                i += 1
            W = max(W, max(len(s) for s in back_items))

        # horizontal run spanning one cell (content width plus one space each side)
        outer_bar = BOX_HL * (W + 2)
        inner_bar = INNER_HL * (W + 2)

        # top border
        v.append(BOX_TL)
        for j in range(y):
            v.append(outer_bar)
            if j != y - 1:
                v.append(MID_N)
        v.append(BOX_TR)
        v.append('\n')

        k1, k2 = 0, 0  # next unread index into items / back_items
        for i in range(x):
            # row itself
            v.append(BOX_VL)
            for j in range(y):
                e = items[k1].rjust(W)
                k1 += 1
                v.append(' ')
                v.append(e)
                v.append(' ')
                if j != y - 1:
                    v.append(INNER_VL)
            v.append(BOX_VL)

            # behind array: its horizontal borders share lines with front rows
            if back is not None:
                v.append(inner_bar)
                v.append(BOX_TR if i == 0 else MID_E)
            v.append('\n')

            if i != x - 1:
                # inner border
                v.append(MID_W)
                for j in range(y):
                    v.append(inner_bar)
                    if j != y - 1:
                        v.append(INNER_CR)
                v.append(MID_E)

                # behind array: its cells share lines with front's borders
                if back is not None:
                    e = back_items[k2].rjust(W)
                    k2 += 1
                    v.append(' ')
                    v.append(e)
                    v.append(' ')
                    v.append(BOX_VL)
                v.append('\n')

        # bottom border
        v.append(BOX_BL)
        for j in range(y):
            v.append(outer_bar)
            if j != y - 1:
                v.append(MID_S)
        v.append(BOX_BR)

        # behind array: last cell, then its own bottom border one line lower
        if back is not None:
            e = back_items[k2].rjust(W)
            v.append(' ')
            v.append(e)
            v.append(' ')
            v.append(BOX_VL)
            v.append('\n')

            # marker line, padded to the width of front's first column
            # (left border included) so the border below is shifted one cell
            if W >= 1:
                v.append(' ···')
                v.append(' ' * (W - 1))
            else:
                v.append(' ··')
                v.append(' ' * W)

            v.append(BOX_BL)
            for j in range(y):
                if j > 0:
                    v.append(MID_S)
                v.append(inner_bar)
            v.append(BOX_BR)

        return v.__str__()

    # str: 0-d arrays print as `array(<repr of the element>)`; 1-d arrays are
    # drawn as a single-row table; 2-d arrays as a table; higher ranks draw the
    # first 2-D slice with, when there is more than one slice, the last slice
    # peeking out behind it.
    def __str__(self):
        fmt = get_formatter(self)
        if staticlen(self.shape) == 0:
            return f'array({repr(self.data[0])})'
        elif staticlen(self.shape) == 1:
            return ndarray._str_2d(fmt, self.shape, self.reshape((1, self.size)))
        elif staticlen(self.shape) == 2:
            return ndarray._str_2d(fmt, self.shape, self)
        else:
            if self.size == 0:
                # _str_2d only needs the real shape for the tag / <empty> output
                return ndarray._str_2d(fmt, self.shape, self.reshape((0, 0)))

            rem = self.shape[:-2]                  # leading (non-matrix) dims
            top_idx = tuple(0 for _ in rem)        # index of the first 2-D slice
            bot_idx = tuple(i - 1 for i in rem)    # index of the last 2-D slice
            top = self[top_idx]
            bot = self[bot_idx]

            if util.count(rem) == 1:
                # a single 2-D slice: nothing to draw behind it
                return ndarray._str_2d(fmt, self.shape, top)
            else:
                return ndarray._str_2d(fmt, self.shape, top, bot)