From dedce22af834214892edd2f2c1622b78c791167f Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 12 Dec 2024 10:20:58 -0300 Subject: [PATCH 01/14] add _write_hex util Signed-off-by: martinvuyk --- stdlib/src/utils/write.mojo | 42 ++++++++++++++++ .../{test_format.mojo => test_write.mojo} | 50 +++++++++++++------ 2 files changed, 76 insertions(+), 16 deletions(-) rename stdlib/test/utils/{test_format.mojo => test_write.mojo} (73%) diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index 6d68951b9b..027632014f 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -382,3 +382,45 @@ fn write_buffered[ var buffer = _WriteBufferStack[buffer_size](writer^) write_args(buffer, args, sep=sep, end=end) buffer.flush() + + +# ===-----------------------------------------------------------------------===# +# Utils +# ===-----------------------------------------------------------------------===# + + +@always_inline +fn _hex_num_to_hex_string(b: Byte) -> Byte: + alias `0` = Byte(ord("0")) + alias `9` = Byte(ord("9")) + alias `a` = Byte(ord("a")) + return `0` + int(b > 9) * (`a` - `9` - 1) + b + + +@always_inline +fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int): + """Write a python compliant hexadecimal value into an uninitialized pointer + location, assumed to be large enough for the value to be written.""" + alias `\\` = Byte(ord("\\")) + alias `x` = Byte(ord("x")) + alias `u` = Byte(ord("u")) + alias `U` = Byte(ord("U")) + + constrained[amnt_hex_bytes in (2, 4, 8), "only 2 or 4 or 8 sequences"]() + p.init_pointee_copy(`\\`) + + @parameter + if amnt_hex_bytes == 2: + (p + 1).init_pointee_copy(`x`) + elif amnt_hex_bytes == 4: + (p + 1).init_pointee_copy(`u`) + else: + (p + 1).init_pointee_copy(`U`) + var idx = 2 + + @parameter + for i in reversed(range(amnt_hex_bytes)): + (p + idx).init_pointee_copy( + _hex_num_to_hex_string((decimal // (16**i)) % 16) + ) + idx += 1 diff --git a/stdlib/test/utils/test_format.mojo b/stdlib/test/utils/test_write.mojo similarity index 73% rename from stdlib/test/utils/test_format.mojo rename to stdlib/test/utils/test_write.mojo index 975d26464b..b4eefe7ea9 100644 --- a/stdlib/test/utils/test_format.mojo +++ b/stdlib/test/utils/test_write.mojo @@ -14,20 +14,12 @@ from testing import assert_equal -from utils import Writable, Writer +from memory.memory import memset_zero +from utils import StringSlice +from utils.write import Writable, Writer, _write_hex from utils.inline_string import _FixedString -fn main() raises: - test_writer_of_string() - test_string_format_seq() - test_stringable_based_on_format() - - test_writer_of_fixed_string() - - test_write_int_padded() - - @value struct Point(Writable, Stringable): var x: Int @@ -42,7 +34,7 @@ struct Point(Writable, Stringable): return String.write(self) -fn test_writer_of_string() raises: +def test_writer_of_string(): # # Test write_to(String) # @@ -58,7 +50,7 @@ fn test_writer_of_string() raises: assert_equal(s2, "Point(3, 8)") -fn test_string_format_seq() raises: +def test_string_write_seq(): var s1 = String.write("Hello, ", "World!") assert_equal(s1, "Hello, World!") @@ -69,17 +61,17 @@ fn test_string_format_seq() raises: assert_equal(s3, "") -fn test_stringable_based_on_format() raises: +def test_stringable_based_on_format(): assert_equal(str(Point(10, 11)), "Point(10, 11)") -fn test_writer_of_fixed_string() raises: +def test_writer_of_fixed_string(): var s1 = _FixedString[100]() s1.write("Hello, World!") assert_equal(str(s1), "Hello, World!") -fn test_write_int_padded() raises: +def test_write_int_padded(): var s1 = String() Int(5).write_padded(s1, width=5) @@ -99,3 +91,29 @@ fn test_write_int_padded() raises: Int(12345).write_padded(s2, width=3) assert_equal(s2, "12345") + + +def test_write_hex(): + items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) + alias S = StringSlice[__origin_of(items)] + ptr = items.unsafe_ptr() + _write_hex[8](ptr, ord("🔥")) + assert_equal(r"\U0001f525", S(ptr=ptr, length=10)) + memset_zero(ptr, len(items)) + _write_hex[4](ptr, ord("你")) + assert_equal(r"\u4f60", S(ptr=ptr, length=6)) + memset_zero(ptr, len(items)) + _write_hex[2](ptr, ord("Ö")) + assert_equal(r"\xd6", S(ptr=ptr, length=4)) + + +def main(): + test_writer_of_string() + test_string_write_seq() + test_stringable_based_on_format() + + test_writer_of_fixed_string() + + test_write_int_padded() + + test_write_hex() From 7cfb2ff89dcb1e4267ad3e4a0ac7acaefb6dd7b1 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 12 Dec 2024 10:37:03 -0300 Subject: [PATCH 02/14] vectorize implementation Signed-off-by: martinvuyk --- stdlib/src/utils/write.mojo | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index 027632014f..20e3baee77 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -390,37 +390,41 @@ fn write_buffered[ @always_inline -fn _hex_num_to_hex_string(b: Byte) -> Byte: +fn _hex_digit_to_hex_char(b: SIMD[DType.uint8, _]) -> __type_of(b): alias `0` = Byte(ord("0")) alias `9` = Byte(ord("9")) alias `a` = Byte(ord("a")) - return `0` + int(b > 9) * (`a` - `9` - 1) + b + return `0` + (b > 9).cast[DType.uint8]() * (`a` - `9` - 1) + b @always_inline fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int): """Write a python compliant hexadecimal value into an uninitialized pointer - location, assumed to be large enough for the value to be written.""" + location, assumed to be large enough for the value to be written. + """ + + constrained[amnt_hex_bytes in (2, 4, 8), "only 2 or 4 or 8 sequences"]() + alias `\\` = Byte(ord("\\")) alias `x` = Byte(ord("x")) alias `u` = Byte(ord("u")) alias `U` = Byte(ord("U")) - constrained[amnt_hex_bytes in (2, 4, 8), "only 2 or 4 or 8 sequences"]() - p.init_pointee_copy(`\\`) + p.init_pointee_move(`\\`) @parameter if amnt_hex_bytes == 2: - (p + 1).init_pointee_copy(`x`) + (p + 1).init_pointee_move(`x`) elif amnt_hex_bytes == 4: - (p + 1).init_pointee_copy(`u`) + (p + 1).init_pointee_move(`u`) else: - (p + 1).init_pointee_copy(`U`) - var idx = 2 + (p + 1).init_pointee_move(`U`) + + var idx = 0 + var digits = SIMD[DType.uint8, amnt_hex_bytes](0) @parameter for i in reversed(range(amnt_hex_bytes)): - (p + idx).init_pointee_copy( - _hex_num_to_hex_string((decimal // (16**i)) % 16) - ) + digits[idx] = Byte((decimal // (16**i)) % 16) idx += 1 + (p + 2).store(_hex_digit_to_hex_char(digits)) From 8f016644da856c472cf655e7046b28b0be3d27c6 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Mon, 16 Dec 2024 23:39:00 -0300 Subject: [PATCH 03/14] refactor implementation Signed-off-by: martinvuyk --- stdlib/src/utils/write.mojo | 39 +++++++++++++++++++++++++++---- stdlib/test/utils/test_write.mojo | 32 ++++++++++++++++++++++++- 2 files changed, 65 insertions(+), 6 deletions(-) diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index 20e3baee77..d90b805d86 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -390,11 +390,40 @@ fn write_buffered[ @always_inline -fn _hex_digit_to_hex_char(b: SIMD[DType.uint8, _]) -> __type_of(b): +fn _hex_digit_to_hex_char(b: Byte) -> __type_of(b): + alias values = SIMD[DType.uint8, 16]( + Byte(ord("0")), + Byte(ord("1")), + Byte(ord("2")), + Byte(ord("3")), + Byte(ord("4")), + Byte(ord("5")), + Byte(ord("6")), + Byte(ord("7")), + Byte(ord("8")), + Byte(ord("9")), + Byte(ord("a")), + Byte(ord("b")), + Byte(ord("c")), + Byte(ord("d")), + Byte(ord("e")), + Byte(ord("f")), + ) + return values[int(b)] + + +@always_inline +fn _hex_digits_to_hex_char(b: SIMD[DType.uint8, _]) -> __type_of(b): alias `0` = Byte(ord("0")) alias `9` = Byte(ord("9")) alias `a` = Byte(ord("a")) - return `0` + (b > 9).cast[DType.uint8]() * (`a` - `9` - 1) + b + alias I8 = DType.int8 + alias U8 = DType.uint8 + return ( + `0` + + b + + (((b <= 9).cast[I8]() - 1) & (`a` - `9` - 1).cast[I8]()).cast[U8]() + ) @always_inline @@ -421,10 +450,10 @@ fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int): (p + 1).init_pointee_move(`U`) var idx = 0 - var digits = SIMD[DType.uint8, amnt_hex_bytes](0) @parameter for i in reversed(range(amnt_hex_bytes)): - digits[idx] = Byte((decimal // (16**i)) % 16) + (p + 2 + idx).init_pointee_move( + _hex_digit_to_hex_char((decimal // (16**i)) % 16) + ) idx += 1 - (p + 2).store(_hex_digit_to_hex_char(digits)) diff --git a/stdlib/test/utils/test_write.mojo b/stdlib/test/utils/test_write.mojo index b4eefe7ea9..bd1f0354bf 100644 --- a/stdlib/test/utils/test_write.mojo +++ b/stdlib/test/utils/test_write.mojo @@ -16,7 +16,13 @@ from testing import assert_equal from memory.memory import memset_zero from utils import StringSlice -from utils.write import Writable, Writer, _write_hex +from utils.write import ( + Writable, + Writer, + _write_hex, + _hex_digit_to_hex_char, + _hex_digits_to_hex_char, +) from utils.inline_string import _FixedString @@ -94,6 +100,30 @@ def test_write_int_padded(): def test_write_hex(): + values = List[Byte]( + ord("0"), + ord("1"), + ord("2"), + ord("3"), + ord("4"), + ord("5"), + ord("6"), + ord("7"), + ord("8"), + ord("9"), + ord("a"), + ord("b"), + ord("c"), + ord("d"), + ord("e"), + ord("f"), + ) + idx = 0 + for value in values: + assert_equal(_hex_digit_to_hex_char(idx), value[]) + assert_equal(_hex_digits_to_hex_char(Byte(idx)), value[]) + idx += 1 + items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) alias S = StringSlice[__origin_of(items)] ptr = items.unsafe_ptr() From 690f27b3314d53c9431f7afcd39792d25c9ea765 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Mon, 16 Dec 2024 23:41:43 -0300 Subject: [PATCH 04/14] fix details Signed-off-by: martinvuyk --- stdlib/src/utils/write.mojo | 21 ++++++++++++++++++++- stdlib/test/utils/test_write.mojo | 2 +- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index d90b805d86..c85bcc9514 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -413,7 +413,7 @@ fn _hex_digit_to_hex_char(b: Byte) -> __type_of(b): @always_inline -fn _hex_digits_to_hex_char(b: SIMD[DType.uint8, _]) -> __type_of(b): +fn _hex_digits_to_hex_chars(b: SIMD[DType.uint8, _]) -> __type_of(b): alias `0` = Byte(ord("0")) alias `9` = Byte(ord("9")) alias `a` = Byte(ord("a")) @@ -430,6 +430,25 @@ fn _hex_digits_to_hex_char(b: SIMD[DType.uint8, _]) -> __type_of(b): fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int): """Write a python compliant hexadecimal value into an uninitialized pointer location, assumed to be large enough for the value to be written. + + Examples: + + ```mojo + %# from utils import StringSlice + %# from utils.write import _write_hex + items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) + alias S = StringSlice[__origin_of(items)] + ptr = items.unsafe_ptr() + _write_hex[8](ptr, ord("🔥")) + assert_equal(r"\U0001f525", S(ptr=ptr, length=10)) + memset_zero(ptr, len(items)) + _write_hex[4](ptr, ord("你")) + assert_equal(r"\u4f60", S(ptr=ptr, length=6)) + memset_zero(ptr, len(items)) + _write_hex[2](ptr, ord("Ö")) + assert_equal(r"\xd6", S(ptr=ptr, length=4)) + ``` + . """ constrained[amnt_hex_bytes in (2, 4, 8), "only 2 or 4 or 8 sequences"]() diff --git a/stdlib/test/utils/test_write.mojo b/stdlib/test/utils/test_write.mojo index bd1f0354bf..1293fcc1f6 100644 --- a/stdlib/test/utils/test_write.mojo +++ b/stdlib/test/utils/test_write.mojo @@ -121,7 +121,7 @@ def test_write_hex(): idx = 0 for value in values: assert_equal(_hex_digit_to_hex_char(idx), value[]) - assert_equal(_hex_digits_to_hex_char(Byte(idx)), value[]) + assert_equal(_hex_digits_to_hex_chars(Byte(idx)), value[]) idx += 1 items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) From 57831d61e63fae8080fe503bf5161dc9a6d38f14 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Mon, 16 Dec 2024 23:42:25 -0300 Subject: [PATCH 05/14] fix details Signed-off-by: martinvuyk --- stdlib/src/utils/write.mojo | 1 + 1 file changed, 1 insertion(+) diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index c85bcc9514..5a81f03d27 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -434,6 +434,7 @@ fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int): Examples: ```mojo + %# from memory import memset_zero %# from utils import StringSlice %# from utils.write import _write_hex items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) From c73f74159779fbb3ea999fbd7c09ea20ac270998 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Tue, 17 Dec 2024 11:15:19 -0300 Subject: [PATCH 06/14] add fixme for issue #3889 Signed-off-by: martinvuyk --- stdlib/src/utils/write.mojo | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index 5a81f03d27..60a57ca20c 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -390,7 +390,7 @@ fn write_buffered[ @always_inline -fn _hex_digit_to_hex_char(b: Byte) -> __type_of(b): +fn _hex_digit_to_hex_char(b: Byte) -> Byte: alias values = SIMD[DType.uint8, 16]( Byte(ord("0")), Byte(ord("1")), @@ -435,19 +435,21 @@ fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int): ```mojo %# from memory import memset_zero + %# from testing import assert_equal %# from utils import StringSlice %# from utils.write import _write_hex items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) alias S = StringSlice[__origin_of(items)] ptr = items.unsafe_ptr() _write_hex[8](ptr, ord("🔥")) - assert_equal(r"\U0001f525", S(ptr=ptr, length=10)) + # FIXME(#3889): this example should not need to be commented, docstrings issue + # assert_equal(r"\U0001f525", S(ptr=ptr, length=10)) memset_zero(ptr, len(items)) _write_hex[4](ptr, ord("你")) - assert_equal(r"\u4f60", S(ptr=ptr, length=6)) + # assert_equal(r"\u4f60", S(ptr=ptr, length=6)) memset_zero(ptr, len(items)) _write_hex[2](ptr, ord("Ö")) - assert_equal(r"\xd6", S(ptr=ptr, length=4)) + # assert_equal(r"\xd6", S(ptr=ptr, length=4)) ``` . """ From 033d7425a1be461e063c6babd1602606f834fade Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Tue, 17 Dec 2024 11:24:27 -0300 Subject: [PATCH 07/14] comment examples Signed-off-by: martinvuyk --- stdlib/src/utils/write.mojo | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index 60a57ca20c..b8e20ed284 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -430,29 +430,28 @@ fn _hex_digits_to_hex_chars(b: SIMD[DType.uint8, _]) -> __type_of(b): fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int): """Write a python compliant hexadecimal value into an uninitialized pointer location, assumed to be large enough for the value to be written. - - Examples: - - ```mojo - %# from memory import memset_zero - %# from testing import assert_equal - %# from utils import StringSlice - %# from utils.write import _write_hex - items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) - alias S = StringSlice[__origin_of(items)] - ptr = items.unsafe_ptr() - _write_hex[8](ptr, ord("🔥")) + """ # FIXME(#3889): this example should not need to be commented, docstrings issue + # Examples: + + # ```mojo + # %# from memory import memset_zero + # %# from testing import assert_equal + # %# from utils import StringSlice + # %# from utils.write import _write_hex + # items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) + # alias S = StringSlice[__origin_of(items)] + # ptr = items.unsafe_ptr() + # _write_hex[8](ptr, ord("🔥")) # assert_equal(r"\U0001f525", S(ptr=ptr, length=10)) - memset_zero(ptr, len(items)) - _write_hex[4](ptr, ord("你")) + # memset_zero(ptr, len(items)) + # _write_hex[4](ptr, ord("你")) # assert_equal(r"\u4f60", S(ptr=ptr, length=6)) - memset_zero(ptr, len(items)) - _write_hex[2](ptr, ord("Ö")) + # memset_zero(ptr, len(items)) + # _write_hex[2](ptr, ord("Ö")) # assert_equal(r"\xd6", S(ptr=ptr, length=4)) - ``` - . - """ + # ``` + # . constrained[amnt_hex_bytes in (2, 4, 8), "only 2 or 4 or 8 sequences"]() From 2324d7fc4c9dbe9419c25b2ae6b25bdb9b1f654f Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Tue, 17 Dec 2024 11:29:05 -0300 Subject: [PATCH 08/14] fix import Signed-off-by: martinvuyk --- stdlib/test/utils/test_write.mojo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/test/utils/test_write.mojo b/stdlib/test/utils/test_write.mojo index 1293fcc1f6..790ad6382b 100644 --- a/stdlib/test/utils/test_write.mojo +++ b/stdlib/test/utils/test_write.mojo @@ -21,7 +21,7 @@ from utils.write import ( Writer, _write_hex, _hex_digit_to_hex_char, - _hex_digits_to_hex_char, + _hex_digits_to_hex_chars, ) from utils.inline_string import _FixedString From 6b7c940a0cde60316d4635493566fc226e6b4bf5 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Tue, 17 Dec 2024 15:31:57 -0300 Subject: [PATCH 09/14] reintroduce examples escaping the sequences Signed-off-by: martinvuyk --- stdlib/src/utils/write.mojo | 42 ++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index b8e20ed284..e6b7814f73 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -430,28 +430,28 @@ fn _hex_digits_to_hex_chars(b: SIMD[DType.uint8, _]) -> __type_of(b): fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int): """Write a python compliant hexadecimal value into an uninitialized pointer location, assumed to be large enough for the value to be written. + + Examples: + + ```mojo + %# from memory import memset_zero + %# from testing import assert_equal + %# from utils import StringSlice + %# from utils.write import _write_hex + items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) + alias S = StringSlice[__origin_of(items)] + ptr = items.unsafe_ptr() + _write_hex[8](ptr, ord("🔥")) + assert_equal(r"\\U0001f525", S(ptr=ptr, length=10)) + memset_zero(ptr, len(items)) + _write_hex[4](ptr, ord("你")) + assert_equal(r"\\u4f60", S(ptr=ptr, length=6)) + memset_zero(ptr, len(items)) + _write_hex[2](ptr, ord("Ö")) + assert_equal(r"\\xd6", S(ptr=ptr, length=4)) + ``` + . """ - # FIXME(#3889): this example should not need to be commented, docstrings issue - # Examples: - - # ```mojo - # %# from memory import memset_zero - # %# from testing import assert_equal - # %# from utils import StringSlice - # %# from utils.write import _write_hex - # items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) - # alias S = StringSlice[__origin_of(items)] - # ptr = items.unsafe_ptr() - # _write_hex[8](ptr, ord("🔥")) - # assert_equal(r"\U0001f525", S(ptr=ptr, length=10)) - # memset_zero(ptr, len(items)) - # _write_hex[4](ptr, ord("你")) - # assert_equal(r"\u4f60", S(ptr=ptr, length=6)) - # memset_zero(ptr, len(items)) - # _write_hex[2](ptr, ord("Ö")) - # assert_equal(r"\xd6", S(ptr=ptr, length=4)) - # ``` - # . constrained[amnt_hex_bytes in (2, 4, 8), "only 2 or 4 or 8 sequences"]() From ac5c11b8e919410122876ca8995c833f0c567fdd Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 19 Dec 2024 12:08:42 -0300 Subject: [PATCH 10/14] add the implementation from #3694 by @soraros Signed-off-by: martinvuyk --- stdlib/src/utils/write.mojo | 65 +++++++++++-------------------- stdlib/test/utils/test_write.mojo | 25 ------------ 2 files changed, 23 insertions(+), 67 deletions(-) diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index e6b7814f73..b1ab1fda8a 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -12,10 +12,11 @@ # ===----------------------------------------------------------------------=== # """Establishes the contract between `Writer` and `Writable` types.""" +from bit import byte_swap from collections import InlineArray from sys.info import is_gpu -from memory import UnsafePointer, memcpy, Span +from memory import UnsafePointer, memcpy, Span, bitcast from utils import StaticString @@ -389,41 +390,27 @@ fn write_buffered[ # ===-----------------------------------------------------------------------===# -@always_inline -fn _hex_digit_to_hex_char(b: Byte) -> Byte: - alias values = SIMD[DType.uint8, 16]( - Byte(ord("0")), - Byte(ord("1")), - Byte(ord("2")), - Byte(ord("3")), - Byte(ord("4")), - Byte(ord("5")), - Byte(ord("6")), - Byte(ord("7")), - Byte(ord("8")), - Byte(ord("9")), - Byte(ord("a")), - Byte(ord("b")), - Byte(ord("c")), - Byte(ord("d")), - Byte(ord("e")), - Byte(ord("f")), - ) - return values[int(b)] +# fmt: off +alias _hex_table = SIMD[DType.uint8, 16]( + ord("0"), ord("1"), ord("2"), ord("3"), ord("4"), ord("5"), ord("6"), + ord("7"), ord("8"), ord("9"), ord("a"), ord("b"), ord("c"), ord("d"), + ord("e"), ord("f"), +) +# fmt: on @always_inline -fn _hex_digits_to_hex_chars(b: SIMD[DType.uint8, _]) -> __type_of(b): - alias `0` = Byte(ord("0")) - alias `9` = Byte(ord("9")) - alias `a` = Byte(ord("a")) - alias I8 = DType.int8 - alias U8 = DType.uint8 - return ( - `0` - + b - + (((b <= 9).cast[I8]() - 1) & (`a` - `9` - 1).cast[I8]()).cast[U8]() - ) +fn _hex_digits_to_hex_chars(x: Scalar, ptr: UnsafePointer[Byte]): + alias size = x.type.sizeof() + var data: SIMD[DType.uint8, size] + + @parameter + if size == 1: + data = bitcast[DType.uint8, size](x) + else: + data = bitcast[DType.uint8, size](byte_swap(x)) + var nibbles = (data >> 4).interleave(data & 0xF) + ptr.store(_hex_table._dynamic_shuffle(nibbles)) @always_inline @@ -465,16 +452,10 @@ fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int): @parameter if amnt_hex_bytes == 2: (p + 1).init_pointee_move(`x`) + _hex_digits_to_hex_chars(Scalar[DType.uint8](decimal), p + 2) elif amnt_hex_bytes == 4: (p + 1).init_pointee_move(`u`) + _hex_digits_to_hex_chars(Scalar[DType.uint16](decimal), p + 2) else: (p + 1).init_pointee_move(`U`) - - var idx = 0 - - @parameter - for i in reversed(range(amnt_hex_bytes)): - (p + 2 + idx).init_pointee_move( - _hex_digit_to_hex_char((decimal // (16**i)) % 16) - ) - idx += 1 + _hex_digits_to_hex_chars(Scalar[DType.uint32](decimal), p + 2) diff --git a/stdlib/test/utils/test_write.mojo b/stdlib/test/utils/test_write.mojo index 790ad6382b..3703599ba5 100644 --- a/stdlib/test/utils/test_write.mojo +++ b/stdlib/test/utils/test_write.mojo @@ -20,7 +20,6 @@ from utils.write import ( Writable, Writer, _write_hex, - _hex_digit_to_hex_char, _hex_digits_to_hex_chars, ) from utils.inline_string import _FixedString @@ -100,30 +99,6 @@ def test_write_int_padded(): def test_write_hex(): - values = List[Byte]( - ord("0"), - ord("1"), - ord("2"), - ord("3"), - ord("4"), - ord("5"), - ord("6"), - ord("7"), - ord("8"), - ord("9"), - ord("a"), - ord("b"), - ord("c"), - ord("d"), - ord("e"), - ord("f"), - ) - idx = 0 - for value in values: - assert_equal(_hex_digit_to_hex_char(idx), value[]) - assert_equal(_hex_digits_to_hex_chars(Byte(idx)), value[]) - idx += 1 - items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) alias S = StringSlice[__origin_of(items)] ptr = items.unsafe_ptr() From 4cbdf9bf3c0468cf79cb3a961492f10c95d0db6c Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 19 Dec 2024 12:18:47 -0300 Subject: [PATCH 11/14] add examples and unit test Signed-off-by: martinvuyk --- stdlib/src/utils/write.mojo | 39 +++++++++++++++++++++++++------ stdlib/test/utils/test_write.mojo | 27 +++++++++++++++++++++ 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index b1ab1fda8a..0d41dd7478 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -400,15 +400,40 @@ alias _hex_table = SIMD[DType.uint8, 16]( @always_inline -fn _hex_digits_to_hex_chars(x: Scalar, ptr: UnsafePointer[Byte]): - alias size = x.type.sizeof() +fn _hex_digits_to_hex_chars(ptr: UnsafePointer[Byte], decimal: Scalar): + """Write a fixed width hexadecimal value into an uninitialized pointer + location, assumed to be large enough for the value to be written. + + Examples: + + ```mojo + %# from memory import memset_zero + %# from testing import assert_equal + %# from utils import StringSlice + %# from utils.write import _write_hex + items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) + alias S = StringSlice[__origin_of(items)] + ptr = items.unsafe_ptr() + _hex_digits_to_hex_chars(ptr, UInt32(ord("🔥"))) + assert_equal("0001f525", S(ptr=ptr, length=10)) + memset_zero(ptr, len(items)) + _hex_digits_to_hex_chars(ptr, UInt16(ord("你"))) + assert_equal("4f60", S(ptr=ptr, length=6)) + memset_zero(ptr, len(items)) + _hex_digits_to_hex_chars(ptr, UInt8(ord("Ö"))) + assert_equal("xd6", S(ptr=ptr, length=4)) + ``` + . + """ + + alias size = decimal.type.sizeof() var data: SIMD[DType.uint8, size] @parameter if size == 1: - data = bitcast[DType.uint8, size](x) + data = bitcast[DType.uint8, size](decimal) else: - data = bitcast[DType.uint8, size](byte_swap(x)) + data = bitcast[DType.uint8, size](byte_swap(decimal)) var nibbles = (data >> 4).interleave(data & 0xF) ptr.store(_hex_table._dynamic_shuffle(nibbles)) @@ -452,10 +477,10 @@ fn _write_hex[amnt_hex_bytes: Int](p: UnsafePointer[Byte], decimal: Int): @parameter if amnt_hex_bytes == 2: (p + 1).init_pointee_move(`x`) - _hex_digits_to_hex_chars(Scalar[DType.uint8](decimal), p + 2) + _hex_digits_to_hex_chars(p + 2, UInt8(decimal)) elif amnt_hex_bytes == 4: (p + 1).init_pointee_move(`u`) - _hex_digits_to_hex_chars(Scalar[DType.uint16](decimal), p + 2) + _hex_digits_to_hex_chars(p + 2, UInt16(decimal)) else: (p + 1).init_pointee_move(`U`) - _hex_digits_to_hex_chars(Scalar[DType.uint32](decimal), p + 2) + _hex_digits_to_hex_chars(p + 2, UInt32(decimal)) diff --git a/stdlib/test/utils/test_write.mojo b/stdlib/test/utils/test_write.mojo index 3703599ba5..b4cfab20c0 100644 --- a/stdlib/test/utils/test_write.mojo +++ b/stdlib/test/utils/test_write.mojo @@ -98,6 +98,32 @@ def test_write_int_padded(): assert_equal(s2, "12345") +def test_hex_digits_to_hex_chars(): + items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) + alias S = StringSlice[__origin_of(items)] + ptr = items.unsafe_ptr() + _hex_digits_to_hex_chars(ptr, UInt32(ord("🔥"))) + assert_equal("0001f525", S(ptr=ptr, length=8)) + memset_zero(ptr, len(items)) + _hex_digits_to_hex_chars(ptr, UInt16(ord("你"))) + assert_equal("4f60", S(ptr=ptr, length=4)) + memset_zero(ptr, len(items)) + _hex_digits_to_hex_chars(ptr, UInt8(ord("Ö"))) + assert_equal("d6", S(ptr=ptr, length=2)) + _hex_digits_to_hex_chars(ptr, UInt8(0)) + assert_equal("00", S(ptr=ptr, length=2)) + _hex_digits_to_hex_chars(ptr, UInt16(0)) + assert_equal("0000", S(ptr=ptr, length=4)) + _hex_digits_to_hex_chars(ptr, UInt32(0)) + assert_equal("00000000", S(ptr=ptr, length=8)) + _hex_digits_to_hex_chars(ptr, ~UInt8(0)) + assert_equal("ff", S(ptr=ptr, length=2)) + _hex_digits_to_hex_chars(ptr, ~UInt16(0)) + assert_equal("ffff", S(ptr=ptr, length=4)) + _hex_digits_to_hex_chars(ptr, ~UInt32(0)) + assert_equal("ffffffff", S(ptr=ptr, length=8)) + + def test_write_hex(): items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) alias S = StringSlice[__origin_of(items)] @@ -121,4 +147,5 @@ def main(): test_write_int_padded() + test_hex_digits_to_hex_chars() test_write_hex() From 8f501c7646bff91cbb3d3fe818aa631c2f378c4a Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 19 Dec 2024 12:20:04 -0300 Subject: [PATCH 12/14] fix examples Signed-off-by: martinvuyk --- stdlib/src/utils/write.mojo | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index 0d41dd7478..06ed392029 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -415,13 +415,13 @@ fn _hex_digits_to_hex_chars(ptr: UnsafePointer[Byte], decimal: Scalar): alias S = StringSlice[__origin_of(items)] ptr = items.unsafe_ptr() _hex_digits_to_hex_chars(ptr, UInt32(ord("🔥"))) - assert_equal("0001f525", S(ptr=ptr, length=10)) + assert_equal("0001f525", S(ptr=ptr, length=8)) memset_zero(ptr, len(items)) _hex_digits_to_hex_chars(ptr, UInt16(ord("你"))) - assert_equal("4f60", S(ptr=ptr, length=6)) + assert_equal("4f60", S(ptr=ptr, length=4)) memset_zero(ptr, len(items)) _hex_digits_to_hex_chars(ptr, UInt8(ord("Ö"))) - assert_equal("xd6", S(ptr=ptr, length=4)) + assert_equal("xd6", S(ptr=ptr, length=2)) ``` . """ From 0df9c64101c02e3c796ed047fb652592aff44c76 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 19 Dec 2024 12:20:25 -0300 Subject: [PATCH 13/14] fix examples Signed-off-by: martinvuyk --- stdlib/src/utils/write.mojo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index 06ed392029..2a186dcd7b 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -421,7 +421,7 @@ fn _hex_digits_to_hex_chars(ptr: UnsafePointer[Byte], decimal: Scalar): assert_equal("4f60", S(ptr=ptr, length=4)) memset_zero(ptr, len(items)) _hex_digits_to_hex_chars(ptr, UInt8(ord("Ö"))) - assert_equal("xd6", S(ptr=ptr, length=2)) + assert_equal("d6", S(ptr=ptr, length=2)) ``` . """ From 493c6f9fa6418c7084f475a813538a7101f38cf0 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Thu, 19 Dec 2024 12:28:06 -0300 Subject: [PATCH 14/14] fix examples Signed-off-by: martinvuyk --- stdlib/src/utils/write.mojo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/src/utils/write.mojo b/stdlib/src/utils/write.mojo index 2a186dcd7b..4aeada7b87 100644 --- a/stdlib/src/utils/write.mojo +++ b/stdlib/src/utils/write.mojo @@ -410,7 +410,7 @@ fn _hex_digits_to_hex_chars(ptr: UnsafePointer[Byte], decimal: Scalar): %# from memory import memset_zero %# from testing import assert_equal %# from utils import StringSlice - %# from utils.write import _write_hex + %# from utils.write import _hex_digits_to_hex_chars items = List[Byte](0, 0, 0, 0, 0, 0, 0, 0, 0) alias S = StringSlice[__origin_of(items)] ptr = items.unsafe_ptr()