Skip to content

Commit

Permalink
Use fast_float library for str-to-int/float (#561)
Browse files Browse the repository at this point in the history
* Use fast_float library for str-to-int/float

* Fix str-to-int conversions
  • Loading branch information
arshajii authored May 22, 2024
1 parent acff5e3 commit ffeeca2
Show file tree
Hide file tree
Showing 6 changed files with 284 additions and 88 deletions.
5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,14 @@ set(CODONRT_FILES codon/runtime/lib.h codon/runtime/lib.cpp
codon/runtime/re.cpp codon/runtime/exc.cpp
codon/runtime/gpu.cpp)
add_library(codonrt SHARED ${CODONRT_FILES})
add_dependencies(codonrt zlibstatic gc backtrace bz2 liblzma re2)
add_dependencies(codonrt zlibstatic gc backtrace bz2 liblzma re2 fast_float)
if(APPLE AND APPLE_ARM)
add_dependencies(codonrt unwind_shared)
endif()
target_include_directories(codonrt PRIVATE ${backtrace_SOURCE_DIR}
${re2_SOURCE_DIR}
"${gc_SOURCE_DIR}/include" runtime)
"${gc_SOURCE_DIR}/include"
"${fast_float_SOURCE_DIR}/include" runtime)
target_link_libraries(codonrt PRIVATE fmt omp backtrace ${STATIC_LIBCPP}
LLVMSupport)
if(APPLE)
Expand Down
6 changes: 6 additions & 0 deletions cmake/deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ CPMAddPackage(
"BUILD_SHARED_LIBS OFF"
"RE2_BUILD_TESTING OFF")

CPMAddPackage(
NAME fast_float
GITHUB_REPOSITORY "fastfloat/fast_float"
GIT_TAG v6.1.1
EXCLUDE_FROM_ALL YES)

if(APPLE AND APPLE_ARM)
enable_language(ASM)
CPMAddPackage(
Expand Down
18 changes: 18 additions & 0 deletions codon/runtime/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
#include "codon/runtime/lib.h"
#include <gc.h>

#define FASTFLOAT_ALLOWS_LEADING_PLUS
#define FASTFLOAT_SKIP_WHITE_SPACE
#include "fast_float/fast_float.h"

/*
* General
*/
Expand Down Expand Up @@ -282,6 +286,20 @@ SEQ_FUNC seq_str_t seq_str_str(seq_str_t s, seq_str_t format, bool *error) {
return fmt_conv(t, format, error);
}

SEQ_FUNC seq_int_t seq_int_from_str(seq_str_t s, const char **e, int base) {
seq_int_t result;
auto r = fast_float::from_chars(s.str, s.str + s.len, result, base);
*e = (r.ec == std::errc()) ? r.ptr : s.str;
return result;
}

SEQ_FUNC double seq_float_from_str(seq_str_t s, const char **e) {
double result;
auto r = fast_float::from_chars(s.str, s.str + s.len, result);
*e = (r.ec == std::errc() || r.ec == std::errc::result_out_of_range) ? r.ptr : s.str;
return result;
}

/*
* General I/O
*/
Expand Down
183 changes: 107 additions & 76 deletions stdlib/internal/builtin.codon
Original file line number Diff line number Diff line change
Expand Up @@ -381,54 +381,101 @@ def pow(base: int, exp: int, mod: Optional[int] = None):
@extend
class int:
def _from_str(s: str, base: int):
from internal.gc import alloc_atomic, free
def parse_error(s: str, base: int):
raise ValueError(
f"invalid literal for int() with base {base}: {s.__repr__()}"
)

if base < 0 or base > 36 or base == 1:
raise ValueError("int() base must be >= 2 and <= 36, or 0")

s0 = s
base0 = base
s = s.strip()
buf = __array__[byte](32)
n = len(s)
need_dyn_alloc = n >= len(buf)

p = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr
str.memcpy(p, s.ptr, n)
p[n] = byte(0)
negate = False

if base == 0:
# skip leading sign
o = 0
if n >= 1 and (s.ptr[0] == byte(43) or s.ptr[0] == byte(45)):
o = 1

# detect base from prefix
if n >= o + 1 and s.ptr[o] == byte(48): # '0'
if n < o + 2:
parse_error(s0, base)

if s.ptr[o + 1] == byte(98) or s.ptr[o + 1] == byte(66): # 'b'/'B'
base = 2
elif s.ptr[o + 1] == byte(111) or s.ptr[o + 1] == byte(79): # 'o'/'O'
base = 8
elif s.ptr[o + 1] == byte(120) or s.ptr[o + 1] == byte(88): # 'x'/'X'
base = 16
else:
parse_error(s0, base)
else:
base = 10

if base == 2 or base == 8 or base == 16:
if base == 2:
C_LOWER = byte(98) # 'b'
C_UPPER = byte(66) # 'B'
elif base == 8:
C_LOWER = byte(111) # 'o'
C_UPPER = byte(79) # 'O'
else:
C_LOWER = byte(120) # 'x'
C_UPPER = byte(88) # 'X'

def check_digit(d: byte, base: int):
if base == 2:
return d == byte(48) or d == byte(49)
elif base == 8:
return byte(48) <= d <= byte(55)
elif base == 16:
return ((byte(48) <= d <= byte(57)) or
(byte(97) <= d <= byte(102)) or
(byte(65) <= d <= byte(70)))
return False

if (n >= 4 and
(s.ptr[0] == byte(43) or s.ptr[0] == byte(45)) and
s.ptr[1] == byte(48) and
(s.ptr[2] == C_LOWER or s.ptr[2] == C_UPPER)): # '+0b' etc.
if not check_digit(s.ptr[3], base):
parse_error(s0, base0)
negate = (s.ptr[0] == byte(45))
s = str(s.ptr + 3, n - 3)
elif (n >= 3 and
s.ptr[0] == byte(48) and
(s.ptr[1] == C_LOWER or s.ptr[1] == C_UPPER)): # '0b' etc.
if not check_digit(s.ptr[3], base):
parse_error(s0, base0)
s = str(s.ptr + 2, n - 2)

end = cobj()
result = _C.strtoll(p, __ptr__(end), i32(base))
result = _C.seq_int_from_str(s, __ptr__(end), i32(base))
n = len(s)

if need_dyn_alloc:
free(p)
if n == 0 or end != s.ptr + n:
parse_error(s0, base0)

if n == 0 or end != p + n:
raise ValueError(
f"invalid literal for int() with base {base}: {s0.__repr__()}"
)
if negate:
result = -result

return result

@extend
class float:
def _from_str(s: str) -> float:
s0 = s
s = s.strip()
buf = __array__[byte](32)
s = s.rstrip()
n = len(s)
need_dyn_alloc = n >= len(buf)

p = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr
str.memcpy(p, s.ptr, n)
p[n] = byte(0)

end = cobj()
result = _C.strtod(p, __ptr__(end))

if need_dyn_alloc:
free(p)
result = _C.seq_float_from_str(s, __ptr__(end))

if n == 0 or end != p + n:
if n == 0 or end != s.ptr + n:
raise ValueError(f"could not convert string to float: {s0.__repr__()}")

return result
Expand All @@ -439,89 +486,73 @@ class complex:
def parse_error():
raise ValueError("complex() arg is a malformed string")

buf = __array__[byte](32)
n = len(v)
need_dyn_alloc = n >= len(buf)

s = alloc_atomic(n + 1) if need_dyn_alloc else buf.ptr
str.memcpy(s, v.ptr, n)
s[n] = byte(0)

s = v.ptr
x = 0.0
y = 0.0
z = 0.0
got_bracket = False
start = s
end = cobj()
i = 0

while str._isspace(s[0]):
s += 1
while i < n and str._isspace(s[i]):
i += 1

if s[0] == byte(40): # '('
if i < n and s[i] == byte(40): # '('
got_bracket = True
s += 1
while str._isspace(s[0]):
s += 1
i += 1
while i < n and str._isspace(s[i]):
i += 1

z = _C.strtod(s, __ptr__(end))
z = _C.seq_float_from_str(str(s + i, n - i), __ptr__(end))

if end != s:
s = end
if end != s + i:
i = end - s

if s[0] == byte(43) or s[0] == byte(45): # '+' '-'
if i < n and (s[i] == byte(43) or s[i] == byte(45)): # '+' '-'
x = z
y = _C.strtod(s, __ptr__(end))
y = _C.seq_float_from_str(str(s + i, n - i), __ptr__(end))

if end != s:
s = end
if end != s + i:
i = end - s
else:
y = 1.0 if s[0] == byte(43) else -1.0
s += 1
y = 1.0 if s[i] == byte(43) else -1.0
i += 1

if not (s[0] == byte(106) or s[0] == byte(74)): # 'j' 'J'
if need_dyn_alloc:
free(s)
if not (i < n and (s[i] == byte(106) or s[i] == byte(74))): # 'j' 'J'
parse_error()

s += 1
elif s[0] == byte(106) or s[0] == byte(74): # 'j' 'J'
s += 1
i += 1
elif i < n and (s[i] == byte(106) or s[i] == byte(74)): # 'j' 'J'
i += 1
y = z
else:
x = z
else:
if s[0] == byte(43) or s[0] == byte(45): # '+' '-'
y = 1.0 if s[0] == byte(43) else -1.0
s += 1
if i < n and (s[i] == byte(43) or s[i] == byte(45)): # '+' '-'
y = 1.0 if s[i] == byte(43) else -1.0
i += 1
else:
y = 1.0

if not (s[0] == byte(106) or s[0] == byte(74)): # 'j' 'J'
if need_dyn_alloc:
free(s)
if not (i < n and (s[i] == byte(106) or s[i] == byte(74))): # 'j' 'J'
parse_error()

s += 1
i += 1

while str._isspace(s[0]):
s += 1
while i < n and str._isspace(s[i]):
i += 1

if got_bracket:
if s[0] != byte(41): # ')'
if need_dyn_alloc:
free(s)
if i < n and s[i] != byte(41): # ')'
parse_error()
s += 1
while str._isspace(s[0]):
s += 1
i += 1
while i < n and str._isspace(s[i]):
i += 1

if s - start != n:
if need_dyn_alloc:
free(s)
if i != n:
parse_error()

if need_dyn_alloc:
free(s)
return complex(x, y)

@extend
Expand Down
20 changes: 10 additions & 10 deletions stdlib/internal/c_stubs.codon
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ def seq_str_str(a: str, fmt: str, error: Ptr[bool]) -> str:
def seq_str_ptr(a: cobj, fmt: str, error: Ptr[bool]) -> str:
pass

@nocapture
@C
def seq_int_from_str(a: str, b: Ptr[cobj], c: i32) -> int:
pass

@nocapture
@C
def seq_float_from_str(a: str, b: Ptr[cobj]) -> float:
pass

@pure
@C
def seq_strdup(a: cobj) -> str:
Expand Down Expand Up @@ -647,16 +657,6 @@ def free(a: cobj) -> None:
def atoi(a: cobj) -> int:
pass

@nocapture
@C
def strtoll(a: cobj, b: Ptr[cobj], c: i32) -> int:
pass

@nocapture
@C
def strtod(a: cobj, b: Ptr[cobj]) -> float:
pass

# <zlib.h>
@nocapture
@C
Expand Down
Loading

0 comments on commit ffeeca2

Please sign in to comment.