mirror of
https://github.com/exaloop/codon.git
synced 2025-06-03 15:03:52 +08:00
218 lines
5.5 KiB
Python
218 lines
5.5 KiB
Python
# Copyright (C) 2022-2024 Exaloop Inc. <https://exaloop.io>
|
|
|
|
def _format_error(ret: str):
|
|
raise ValueError(f"invalid format specifier: {ret}")
|
|
|
|
def python_to_fmt_format(s):
    """
    Translate a Python format specifier into the spec string handed to
    the runtime formatting routines.

    The input is scanned left to right following Python's format-spec
    grammar:

        [[fill]align][sign]["z"]["#"]["0"][width][grouping]["." precision][type]

    and the recognised pieces are re-emitted in the order
    fill/align, sign, '#', '0', width, precision, 'L' (when a grouping
    character was given) and type.

    Raises:
        NotImplementedError: for '=' alignment, the 'z' flag, or '_' grouping.
        ValueError: if characters remain after the (single) type character.
    """
    n = len(s)
    i = 0

    # [[fill]align] -- a fill character is only present when the *second*
    # character is an alignment character.
    fill, align = None, None
    if i + 1 < n and (s[i + 1] == '<' or s[i + 1] == '>' or s[i + 1] == '=' or s[i + 1] == '^'):
        fill = s[i]
        align = s[i + 1]
        i += 2
    elif i < n and (s[i] == '<' or s[i] == '>' or s[i] == '=' or s[i] == '^'):
        align = s[i]
        i += 1
    if align and align == '=':
        raise NotImplementedError("'=' alignment not yet supported")

    # [sign]
    sign = None
    if i < n and (s[i] == '+' or s[i] == '-' or s[i] == ' '):
        sign = s[i]
        i += 1

    # ["z"]
    coerce_negative_float = False
    if i < n and s[i] == 'z':
        coerce_negative_float = True
        i += 1
    if coerce_negative_float:
        raise NotImplementedError("'z' not yet supported")

    # ["#"]
    alternate_form = False
    if i < n and s[i] == '#':
        alternate_form = True
        i += 1

    # ["0"] -- zero-padding flag. This must be tested against '0' (not '#',
    # which was already consumed above); otherwise the flag can never be set
    # and the leading zero is silently swallowed by the width digits below.
    width_pre_zero = False
    if i < n and s[i] == '0':
        width_pre_zero = True
        i += 1

    # [width]
    width = 0
    while i < n and str._isdigit(s.ptr[i]):
        width = 10 * width + ord(s[i]) - ord('0')
        i += 1

    # [grouping]
    grouping = None
    if i < n and (s[i] == '_' or s[i] == ','):
        grouping = s[i]
        i += 1
    if grouping == '_':
        raise NotImplementedError("'_' grouping not yet supported")

    # ["." precision]
    precision = None
    if i < n and s[i] == '.':
        i += 1
        precision = 0
        while i < n and str._isdigit(s.ptr[i]):
            precision = 10 * precision + ord(s[i]) - ord('0')
            i += 1

    # [type] -- a single trailing character (named so as not to shadow the
    # builtin `type`).
    type_char = None
    if i < n:
        type_char = s[i]
        i += 1

    # Anything left over means the specifier did not match the grammar.
    if i != n:
        raise ValueError("bad format string")

    # Construct fmt::format-compatible string
    ns = ""
    if align:
        if fill: ns += fill
        ns += align
    if sign: ns += sign
    if alternate_form: ns += "#"
    if width_pre_zero: ns += "0"
    if width: ns += str(width)
    if precision is not None: ns += f".{precision}"
    if grouping: ns += "L"
    if type_char: ns += type_char
    return ns
|
|
|
|
@extend
class int:
    def __format__(self, format_spec: str) -> str:
        """
        Format this integer according to the Python-style `format_spec`.

        The spec is translated with python_to_fmt_format and passed to
        _C.seq_str_int together with a pointer to a local error flag.
        If that flag is set and `format_spec` is non-empty, the returned
        text is forwarded to _format_error, which raises ValueError.
        """
        failed = False
        fmt_spec = python_to_fmt_format(format_spec)
        text = _C.seq_str_int(self, fmt_spec, __ptr__(failed))
        if failed and format_spec:
            _format_error(text)
        return text
|
|
|
|
@extend
class Int:
    def __format__(self, format_spec: str) -> str:
        """
        Format this value according to the Python-style `format_spec`.

        The value is converted with __int__ and the spec is translated
        with python_to_fmt_format; both go to _C.seq_str_int together
        with a pointer to a local error flag. If that flag is set and
        `format_spec` is non-empty, the returned text is forwarded to
        _format_error, which raises ValueError.
        """
        failed = False
        value = self.__int__()
        fmt_spec = python_to_fmt_format(format_spec)
        text = _C.seq_str_int(value, fmt_spec, __ptr__(failed))
        if failed and format_spec:
            _format_error(text)
        return text
|
|
|
|
@extend
class UInt:
    def __format__(self, format_spec: str) -> str:
        """
        Format this value according to the Python-style `format_spec`.

        The value is converted with __int__ and the spec is translated
        with python_to_fmt_format; both go to _C.seq_str_uint together
        with a pointer to a local error flag. If that flag is set and
        `format_spec` is non-empty, the returned text is forwarded to
        _format_error, which raises ValueError.
        """
        failed = False
        value = self.__int__()
        fmt_spec = python_to_fmt_format(format_spec)
        text = _C.seq_str_uint(value, fmt_spec, __ptr__(failed))
        if failed and format_spec:
            _format_error(text)
        return text
|
|
|
|
@extend
class float:
    def __format__(self, format_spec: str) -> str:
        """
        Format this float according to the Python-style `format_spec`.

        The spec is translated with python_to_fmt_format and passed to
        _C.seq_str_float together with a pointer to a local error flag.
        If that flag is set and `format_spec` is non-empty, the returned
        text is forwarded to _format_error, which raises ValueError.

        A result that is exactly "-nan" is returned as "nan".
        """
        failed = False
        fmt_spec = python_to_fmt_format(format_spec)
        text = _C.seq_str_float(self, fmt_spec, __ptr__(failed))
        if failed and format_spec:
            _format_error(text)
        if text == "-nan":
            return "nan"
        return text
|
|
|
|
@extend
class str:
    def __format__(self, format_spec: str) -> str:
        """
        Format this string according to the Python-style `format_spec`.

        The spec is translated with python_to_fmt_format and passed to
        _C.seq_str_str together with a pointer to a local error flag.
        If that flag is set and `format_spec` is non-empty, the returned
        text is forwarded to _format_error, which raises ValueError.
        """
        failed = False
        fmt_spec = python_to_fmt_format(format_spec)
        text = _C.seq_str_str(self, fmt_spec, __ptr__(failed))
        if failed and format_spec:
            _format_error(text)
        return text
|
|
|
|
@extend
class Ptr:
    def __format__(self, format_spec: str) -> str:
        """
        Format this pointer according to the Python-style `format_spec`.

        The pointer is converted with as_byte and the spec is translated
        with python_to_fmt_format; both go to _C.seq_str_ptr together
        with a pointer to a local error flag. If that flag is set and
        `format_spec` is non-empty, the returned text is forwarded to
        _format_error, which raises ValueError.
        """
        failed = False
        raw = self.as_byte()
        fmt_spec = python_to_fmt_format(format_spec)
        text = _C.seq_str_ptr(raw, fmt_spec, __ptr__(failed))
        if failed and format_spec:
            _format_error(text)
        return text
|
|
|
|
def _divmod_10(dividend, N: Static[int]):
|
|
T = type(dividend)
|
|
zero, one = T(0), T(1)
|
|
neg = dividend < zero
|
|
dvd = dividend.__abs__()
|
|
|
|
remainder = 0
|
|
quotient = zero
|
|
|
|
# Euclidean division
|
|
for bit_idx in range(N - 1, -1, -1):
|
|
mask = int((dvd & (one << T(bit_idx))) != zero)
|
|
remainder = (remainder << 1) + mask
|
|
if remainder >= 10:
|
|
quotient = (quotient << one) + one
|
|
remainder -= 10
|
|
else:
|
|
quotient = quotient << one
|
|
|
|
if neg:
|
|
quotient = -quotient
|
|
remainder = -remainder
|
|
|
|
return quotient, remainder
|
|
|
|
@extend
class Int:
    def __str__(self) -> str:
        """
        Return the decimal representation of this integer.

        Widths of at most 64 bits go through `str(int(self))`. Wider
        values are converted by repeatedly calling _divmod_10, collecting
        one digit per step (least significant first), appending '-' for
        negative values, and reversing the buffer at the end.
        """
        if N <= 64:
            return str(int(self))

        if not self:
            return '0'

        buf = _strbuf()
        rest = self
        non_negative = rest >= Int[N](0)

        while True:
            rest, digit = _divmod_10(rest, N)
            # For a negative value _divmod_10 returns a negated remainder,
            # so subtract it to get the digit; 48 == ord('0')
            ch = byte(48 + digit) if non_negative else byte(48 - digit)
            buf.append(str(__ptr__(ch), 1))
            if not rest:
                break

        if not non_negative:
            buf.append('-')

        buf.reverse()
        return buf.__str__()
|
|
|
|
@extend
class UInt:
    def __str__(self) -> str:
        """
        Return the decimal representation of this unsigned integer.

        Widths of at most 64 bits go through `__format__("")`. Wider
        values are converted by repeatedly calling _divmod_10, collecting
        one digit per step (least significant first) and reversing the
        buffer at the end.
        """
        if N <= 64:
            return self.__format__("")

        buf = _strbuf()
        rest = self
        finished = False

        # Runs at least once, so a zero value still yields a single '0'.
        while not finished:
            rest, digit = _divmod_10(rest, N)
            ch = byte(48 + int(digit))  # 48 == ord('0')
            buf.append(str(__ptr__(ch), 1))
            finished = not rest

        buf.reverse()
        return buf.__str__()
|
|
|
|
@extend
class __magic__:
    def repr_partial(slf) -> str:
        """
        Return the text "partial(<name>)", where <name> is the class
        name of `slf`.
        """
        class_name = slf.__class__.__name__
        return f"partial({class_name})"
|