diff --git a/stdlib/src/builtin/string_literal.mojo b/stdlib/src/builtin/string_literal.mojo index c5af9cef8d..5917daf5cb 100644 --- a/stdlib/src/builtin/string_literal.mojo +++ b/stdlib/src/builtin/string_literal.mojo @@ -747,10 +747,10 @@ struct StringLiteral( """ return str(self).split(sep, maxsplit) - fn splitlines(self, keepends: Bool = False) -> List[String]: - """Split the string literal at line boundaries. This corresponds to Python's - [universal newlines:]( - https://docs.python.org/3/library/stdtypes.html#str.splitlines) + fn splitlines(self, keepends: Bool = False) -> List[StaticString]: + """Split the string literal at line boundaries. This corresponds to + Python's [universal newlines:]( + https://docs.python.org/3/library/stdtypes.html#str.splitlines) `"\\r\\n"` and `"\\t\\n\\v\\f\\r\\x1c\\x1d\\x1e\\x85\\u2028\\u2029"`. Args: @@ -759,7 +759,7 @@ struct StringLiteral( Returns: A List of Strings containing the input split by line boundaries. """ - return _to_string_list(self.as_string_slice().splitlines(keepends)) + return self.as_string_slice().splitlines(keepends) fn count(self, substr: String) -> Int: """Return the number of non-overlapping occurrences of substring diff --git a/stdlib/src/collections/string/string.mojo b/stdlib/src/collections/string/string.mojo index 7cbbd2a884..56e34c9b9d 100644 --- a/stdlib/src/collections/string/string.mojo +++ b/stdlib/src/collections/string/string.mojo @@ -1694,10 +1694,12 @@ struct String( return output - fn splitlines(self, keepends: Bool = False) -> List[String]: + fn splitlines( + ref self, keepends: Bool = False + ) -> List[StringSlice[__origin_of(self)]]: """Split the string at line boundaries. This corresponds to Python's [universal newlines:]( - https://docs.python.org/3/library/stdtypes.html#str.splitlines) + https://docs.python.org/3/library/stdtypes.html#str.splitlines) `"\\r\\n"` and `"\\t\\n\\v\\f\\r\\x1c\\x1d\\x1e\\x85\\u2028\\u2029"`. Args: @@ -1706,7 +1708,7 @@ struct String( Returns: A List of Strings containing the input split by line boundaries. """ - return _to_string_list(self.as_string_slice().splitlines(keepends)) + return self.as_string_slice().splitlines(keepends) fn replace(self, old: String, new: String) -> String: """Return a copy of the string with all occurrences of substring `old` diff --git a/stdlib/src/collections/string/string_slice.mojo b/stdlib/src/collections/string/string_slice.mojo index 977f5b7f76..9e2d738215 100644 --- a/stdlib/src/collections/string/string_slice.mojo +++ b/stdlib/src/collections/string/string_slice.mojo @@ -1111,29 +1111,23 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]]( offset += b_len return length != 0 - fn splitlines[ - O: ImmutableOrigin, // - ](self: StringSlice[O], keepends: Bool = False) -> List[StringSlice[O]]: + fn splitlines(self, keepends: Bool = False) -> List[Self]: """Split the string at line boundaries. This corresponds to Python's [universal newlines:]( https://docs.python.org/3/library/stdtypes.html#str.splitlines) `"\\r\\n"` and `"\\t\\n\\v\\f\\r\\x1c\\x1d\\x1e\\x85\\u2028\\u2029"`. - Parameters: - O: The immutable origin. - Args: keepends: If True, line breaks are kept in the resulting strings. Returns: A List of Strings containing the input split by line boundaries. """ - # highly performance sensitive code, benchmark before touching alias `\r` = UInt8(ord("\r")) alias `\n` = UInt8(ord("\n")) - output = List[StringSlice[O]](capacity=128) # guessing + output = List[Self](capacity=128) # guessing var ptr = self.unsafe_ptr() var length = self.byte_length() var offset = 0 @@ -1163,7 +1157,7 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]]( eol_start += char_len var str_len = eol_start - offset + int(keepends) * eol_length - var s = StringSlice[O](ptr=ptr + offset, length=str_len) + var s = Self(ptr=ptr + offset, length=str_len) output.append(s) offset = eol_start + eol_length @@ -1209,29 +1203,30 @@ fn _to_string_list[ len_fn: fn (T) -> Int, unsafe_ptr_fn: fn (T) -> UnsafePointer[Byte], ](items: List[T]) -> List[String]: - i_len = len(items) - i_ptr = items.unsafe_ptr() - out_ptr = UnsafePointer[String].alloc(i_len) + var i_len = len(items) + var i_ptr = items.unsafe_ptr() + var out_ptr = UnsafePointer[String].alloc(i_len) for i in range(i_len): - og_len = len_fn(i_ptr[i]) - f_len = og_len + 1 # null terminator - p = UnsafePointer[Byte].alloc(f_len) - og_ptr = unsafe_ptr_fn(i_ptr[i]) + var og_len = len_fn(i_ptr[i]) + var f_len = og_len + 1 # null terminator + var p = UnsafePointer[Byte].alloc(f_len) + var og_ptr = unsafe_ptr_fn(i_ptr[i]) memcpy(p, og_ptr, og_len) p[og_len] = 0 # null terminator - buf = String._buffer_type(ptr=p, length=f_len, capacity=f_len) + var buf = String._buffer_type(ptr=p, length=f_len, capacity=f_len) (out_ptr + i).init_pointee_move(String(buf^)) return List[String](ptr=out_ptr, length=i_len, capacity=i_len) @always_inline -fn _to_string_list[ - O: ImmutableOrigin, // +fn to_string_list[ + mut: Bool, O: Origin[mut], // ](items: List[StringSlice[O]]) -> List[String]: """Create a list of Strings **copying** the existing data. Parameters: + mut: The mutability of the origin. O: The origin of the data. Args: @@ -1251,12 +1246,13 @@ fn _to_string_list[ @always_inline -fn _to_string_list[ - O: ImmutableOrigin, // +fn to_string_list[ + mut: Bool, O: Origin[mut], // ](items: List[Span[Byte, O]]) -> List[String]: """Create a list of Strings **copying** the existing data. Parameters: + mut: The mutability of the origin. O: The origin of the data. Args: diff --git a/stdlib/test/builtin/test_string_literal.mojo b/stdlib/test/builtin/test_string_literal.mojo index c549a9423c..5d6d08ad05 100644 --- a/stdlib/test/builtin/test_string_literal.mojo +++ b/stdlib/test/builtin/test_string_literal.mojo @@ -22,6 +22,7 @@ from testing import ( assert_raises, assert_true, ) +from collections.string.string_slice import StringSlice, to_string_list def test_add(): @@ -442,7 +443,8 @@ def test_split(): def test_splitlines(): - alias L = List[String] + alias L = List[StringSlice[StaticConstantOrigin]] + # Test with no line breaks assert_equal("hello world".splitlines(), L("hello world")) @@ -478,6 +480,22 @@ def test_splitlines(): L("hello\x1c", "world\x1d", "mojo\x1e", "language\x1e"), ) + # test \x85 \u2028 \u2029 + var next_line = String(List[UInt8](0xC2, 0x85, 0)) + """TODO: \\x85""" + var unicode_line_sep = String(List[UInt8](0xE2, 0x80, 0xA8, 0)) + """TODO: \\u2028""" + var unicode_paragraph_sep = String(List[UInt8](0xE2, 0x80, 0xA9, 0)) + """TODO: \\u2029""" + + for i in List(next_line, unicode_line_sep, unicode_paragraph_sep): + u = i[] + item = String("").join("hello", u, "world", u, "mojo", u, "language", u) + s = StringSlice(item) + assert_equal(s.splitlines(), hello_mojo) + items = List("hello" + u, "world" + u, "mojo" + u, "language" + u) + assert_equal(to_string_list(s.splitlines(keepends=True)), items) + def test_float_conversion(): assert_equal(("4.5").__float__(), 4.5) diff --git a/stdlib/test/collections/string/test_string.mojo b/stdlib/test/collections/string/test_string.mojo index 53f9f6ace8..9f83fc3ef6 100644 --- a/stdlib/test/collections/string/test_string.mojo +++ b/stdlib/test/collections/string/test_string.mojo @@ -20,12 +20,12 @@ from testing import ( assert_true, ) -from collections.string import StringSlice from collections.string.string import ( _calc_initial_buffer_size_int32, _calc_initial_buffer_size_int64, _isspace, ) +from collections.string.string_slice import StringSlice, to_string_list from memory import UnsafePointer from python import Python from utils import StringRef @@ -755,17 +755,18 @@ def test_split(): def test_splitlines(): - alias L = List[String] + alias L = List[StringSlice[StaticConstantOrigin]] + # Test with no line breaks - assert_equal(String("hello world").splitlines(), L("hello world")) + assert_equal("hello world".splitlines(), L("hello world")) # Test with line breaks - assert_equal(String("hello\nworld").splitlines(), L("hello", "world")) - assert_equal(String("hello\rworld").splitlines(), L("hello", "world")) - assert_equal(String("hello\r\nworld").splitlines(), L("hello", "world")) + assert_equal("hello\nworld".splitlines(), L("hello", "world")) + assert_equal("hello\rworld".splitlines(), L("hello", "world")) + assert_equal("hello\r\nworld".splitlines(), L("hello", "world")) # Test with multiple different line breaks - s1 = String("hello\nworld\r\nmojo\rlanguage\r\n") + s1 = "hello\nworld\r\nmojo\rlanguage\r\n" hello_mojo = L("hello", "world", "mojo", "language") assert_equal(s1.splitlines(), hello_mojo) assert_equal( @@ -774,9 +775,9 @@ def test_splitlines(): ) # Test with an empty string - assert_equal(String("").splitlines(), L()) + assert_equal("".splitlines(), L()) # test \v \f \x1c \x1d - s2 = String("hello\vworld\fmojo\x1clanguage\x1d") + s2 = "hello\vworld\fmojo\x1clanguage\x1d" assert_equal(s2.splitlines(), hello_mojo) assert_equal( s2.splitlines(keepends=True), @@ -784,7 +785,7 @@ def test_splitlines(): ) # test \x1c \x1d \x1e - s3 = String("hello\x1cworld\x1dmojo\x1elanguage\x1e") + s3 = "hello\x1cworld\x1dmojo\x1elanguage\x1e" assert_equal(s3.splitlines(), hello_mojo) assert_equal( s3.splitlines(keepends=True), @@ -792,21 +793,20 @@ def test_splitlines(): ) # test \x85 \u2028 \u2029 - var next_line = List[UInt8](0xC2, 0x85, 0) + var next_line = String(List[UInt8](0xC2, 0x85, 0)) """TODO: \\x85""" - var unicode_line_sep = List[UInt8](0xE2, 0x80, 0xA8, 0) + var unicode_line_sep = String(List[UInt8](0xE2, 0x80, 0xA8, 0)) """TODO: \\u2028""" - var unicode_paragraph_sep = List[UInt8](0xE2, 0x80, 0xA9, 0) + var unicode_paragraph_sep = String(List[UInt8](0xE2, 0x80, 0xA9, 0)) """TODO: \\u2029""" for i in List(next_line, unicode_line_sep, unicode_paragraph_sep): - u = String(i[]) + u = i[] item = String("").join("hello", u, "world", u, "mojo", u, "language", u) - assert_equal(item.splitlines(), hello_mojo) - assert_equal( - item.splitlines(keepends=True), - L("hello" + u, "world" + u, "mojo" + u, "language" + u), - ) + s = StringSlice(item) + assert_equal(s.splitlines(), hello_mojo) + items = List("hello" + u, "world" + u, "mojo" + u, "language" + u) + assert_equal(to_string_list(s.splitlines(keepends=True)), items) def test_isupper(): diff --git a/stdlib/test/collections/string/test_string_slice.mojo b/stdlib/test/collections/string/test_string_slice.mojo index 32b6da2bf9..d183d51c10 100644 --- a/stdlib/test/collections/string/test_string_slice.mojo +++ b/stdlib/test/collections/string/test_string_slice.mojo @@ -17,6 +17,7 @@ from testing import assert_equal, assert_false, assert_true, assert_raises from collections.string.string_slice import ( StringSlice, _count_utf8_continuation_bytes, + to_string_list, ) from collections.string._utf8_validation import _is_valid_utf8 from memory import Span, UnsafePointer @@ -480,27 +481,18 @@ def test_count_utf8_continuation_bytes(): def test_splitlines(): - alias S = StringSlice[StaticConstantOrigin] alias L = List[StringSlice[StaticConstantOrigin]] - # FIXME: remove once StringSlice conforms to TestableCollectionElement - fn _assert_equal[ - O1: ImmutableOrigin - ](l1: List[StringSlice[O1]], l2: List[String]) raises: - assert_equal(len(l1), len(l2)) - for i in range(len(l1)): - assert_equal(str(l1[i]), l2[i]) - # Test with no line breaks - assert_equal(S("hello world").splitlines(), L("hello world")) + assert_equal("hello world".splitlines(), L("hello world")) # Test with line breaks - assert_equal(S("hello\nworld").splitlines(), L("hello", "world")) - assert_equal(S("hello\rworld").splitlines(), L("hello", "world")) - assert_equal(S("hello\r\nworld").splitlines(), L("hello", "world")) + assert_equal("hello\nworld".splitlines(), L("hello", "world")) + assert_equal("hello\rworld".splitlines(), L("hello", "world")) + assert_equal("hello\r\nworld".splitlines(), L("hello", "world")) # Test with multiple different line breaks - s1 = S("hello\nworld\r\nmojo\rlanguage\r\n") + s1 = "hello\nworld\r\nmojo\rlanguage\r\n" hello_mojo = L("hello", "world", "mojo", "language") assert_equal(s1.splitlines(), hello_mojo) assert_equal( @@ -509,9 +501,9 @@ def test_splitlines(): ) # Test with an empty string - assert_equal(S("").splitlines(), L()) + assert_equal("".splitlines(), L()) # test \v \f \x1c \x1d - s2 = S("hello\vworld\fmojo\x1clanguage\x1d") + s2 = "hello\vworld\fmojo\x1clanguage\x1d" assert_equal(s2.splitlines(), hello_mojo) assert_equal( s2.splitlines(keepends=True), @@ -519,7 +511,7 @@ def test_splitlines(): ) # test \x1c \x1d \x1e - s3 = S("hello\x1cworld\x1dmojo\x1elanguage\x1e") + s3 = "hello\x1cworld\x1dmojo\x1elanguage\x1e" assert_equal(s3.splitlines(), hello_mojo) assert_equal( s3.splitlines(keepends=True), @@ -540,7 +532,7 @@ def test_splitlines(): s = StringSlice(item) assert_equal(s.splitlines(), hello_mojo) items = List("hello" + u, "world" + u, "mojo" + u, "language" + u) - _assert_equal(s.splitlines(keepends=True), items) + assert_equal(to_string_list(s.splitlines(keepends=True)), items) def test_rstrip():