Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Make String and StringLiteral .splitlines() return List[StringSlice] #3894

Open
wants to merge 11 commits into
base: nightly
Choose a base branch
from
10 changes: 5 additions & 5 deletions stdlib/src/builtin/string_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -747,10 +747,10 @@ struct StringLiteral(
"""
return str(self).split(sep, maxsplit)

fn splitlines(self, keepends: Bool = False) -> List[String]:
"""Split the string literal at line boundaries. This corresponds to Python's
[universal newlines:](
https://docs.python.org/3/library/stdtypes.html#str.splitlines)
fn splitlines(self, keepends: Bool = False) -> List[StaticString]:
"""Split the string literal at line boundaries. This corresponds to
Python's [universal newlines:](
https://docs.python.org/3/library/stdtypes.html#str.splitlines)
`"\\r\\n"` and `"\\t\\n\\v\\f\\r\\x1c\\x1d\\x1e\\x85\\u2028\\u2029"`.

Args:
Expand All @@ -759,7 +759,7 @@ struct StringLiteral(
Returns:
A List of Strings containing the input split by line boundaries.
"""
return _to_string_list(self.as_string_slice().splitlines(keepends))
return self.as_string_slice().splitlines(keepends)

fn count(self, substr: String) -> Int:
"""Return the number of non-overlapping occurrences of substring
Expand Down
8 changes: 5 additions & 3 deletions stdlib/src/collections/string/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1694,10 +1694,12 @@ struct String(

return output

fn splitlines(self, keepends: Bool = False) -> List[String]:
fn splitlines(
ref self, keepends: Bool = False
) -> List[StringSlice[__origin_of(self)]]:
"""Split the string at line boundaries. This corresponds to Python's
[universal newlines:](
https://docs.python.org/3/library/stdtypes.html#str.splitlines)
https://docs.python.org/3/library/stdtypes.html#str.splitlines)
`"\\r\\n"` and `"\\t\\n\\v\\f\\r\\x1c\\x1d\\x1e\\x85\\u2028\\u2029"`.

Args:
Expand All @@ -1706,7 +1708,7 @@ struct String(
Returns:
A List of Strings containing the input split by line boundaries.
"""
return _to_string_list(self.as_string_slice().splitlines(keepends))
return self.as_string_slice().splitlines(keepends)

fn replace(self, old: String, new: String) -> String:
"""Return a copy of the string with all occurrences of substring `old`
Expand Down
38 changes: 17 additions & 21 deletions stdlib/src/collections/string/string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1111,29 +1111,23 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]](
offset += b_len
return length != 0

fn splitlines[
O: ImmutableOrigin, //
](self: StringSlice[O], keepends: Bool = False) -> List[StringSlice[O]]:
fn splitlines(self, keepends: Bool = False) -> List[Self]:
"""Split the string at line boundaries. This corresponds to Python's
[universal newlines:](
https://docs.python.org/3/library/stdtypes.html#str.splitlines)
`"\\r\\n"` and `"\\t\\n\\v\\f\\r\\x1c\\x1d\\x1e\\x85\\u2028\\u2029"`.

Parameters:
O: The immutable origin.

Args:
keepends: If True, line breaks are kept in the resulting strings.

Returns:
A List of Strings containing the input split by line boundaries.
"""

# highly performance sensitive code, benchmark before touching
alias `\r` = UInt8(ord("\r"))
alias `\n` = UInt8(ord("\n"))

output = List[StringSlice[O]](capacity=128) # guessing
output = List[Self](capacity=128) # guessing
var ptr = self.unsafe_ptr()
var length = self.byte_length()
var offset = 0
Expand Down Expand Up @@ -1163,7 +1157,7 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]](
eol_start += char_len

var str_len = eol_start - offset + int(keepends) * eol_length
var s = StringSlice[O](ptr=ptr + offset, length=str_len)
var s = Self(ptr=ptr + offset, length=str_len)
output.append(s)
offset = eol_start + eol_length

Expand Down Expand Up @@ -1209,29 +1203,30 @@ fn _to_string_list[
len_fn: fn (T) -> Int,
unsafe_ptr_fn: fn (T) -> UnsafePointer[Byte],
](items: List[T]) -> List[String]:
i_len = len(items)
i_ptr = items.unsafe_ptr()
out_ptr = UnsafePointer[String].alloc(i_len)
var i_len = len(items)
var i_ptr = items.unsafe_ptr()
var out_ptr = UnsafePointer[String].alloc(i_len)

for i in range(i_len):
og_len = len_fn(i_ptr[i])
f_len = og_len + 1 # null terminator
p = UnsafePointer[Byte].alloc(f_len)
og_ptr = unsafe_ptr_fn(i_ptr[i])
var og_len = len_fn(i_ptr[i])
var f_len = og_len + 1 # null terminator
var p = UnsafePointer[Byte].alloc(f_len)
var og_ptr = unsafe_ptr_fn(i_ptr[i])
memcpy(p, og_ptr, og_len)
p[og_len] = 0 # null terminator
buf = String._buffer_type(ptr=p, length=f_len, capacity=f_len)
var buf = String._buffer_type(ptr=p, length=f_len, capacity=f_len)
(out_ptr + i).init_pointee_move(String(buf^))
return List[String](ptr=out_ptr, length=i_len, capacity=i_len)


@always_inline
fn _to_string_list[
O: ImmutableOrigin, //
fn to_string_list[
mut: Bool, O: Origin[mut], //
](items: List[StringSlice[O]]) -> List[String]:
"""Create a list of Strings **copying** the existing data.

Parameters:
mut: The mutability of the origin.
O: The origin of the data.

Args:
Expand All @@ -1251,12 +1246,13 @@ fn _to_string_list[


@always_inline
fn _to_string_list[
O: ImmutableOrigin, //
fn to_string_list[
mut: Bool, O: Origin[mut], //
](items: List[Span[Byte, O]]) -> List[String]:
"""Create a list of Strings **copying** the existing data.

Parameters:
mut: The mutability of the origin.
O: The origin of the data.

Args:
Expand Down
20 changes: 19 additions & 1 deletion stdlib/test/builtin/test_string_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ from testing import (
assert_raises,
assert_true,
)
from collections.string.string_slice import StringSlice, to_string_list


def test_add():
Expand Down Expand Up @@ -442,7 +443,8 @@ def test_split():


def test_splitlines():
alias L = List[String]
alias L = List[StringSlice[StaticConstantOrigin]]

# Test with no line breaks
assert_equal("hello world".splitlines(), L("hello world"))

Expand Down Expand Up @@ -478,6 +480,22 @@ def test_splitlines():
L("hello\x1c", "world\x1d", "mojo\x1e", "language\x1e"),
)

# test \x85 \u2028 \u2029
var next_line = String(List[UInt8](0xC2, 0x85, 0))
"""TODO: \\x85"""
var unicode_line_sep = String(List[UInt8](0xE2, 0x80, 0xA8, 0))
"""TODO: \\u2028"""
var unicode_paragraph_sep = String(List[UInt8](0xE2, 0x80, 0xA9, 0))
"""TODO: \\u2029"""

for i in List(next_line, unicode_line_sep, unicode_paragraph_sep):
u = i[]
item = String("").join("hello", u, "world", u, "mojo", u, "language", u)
s = StringSlice(item)
assert_equal(s.splitlines(), hello_mojo)
items = List("hello" + u, "world" + u, "mojo" + u, "language" + u)
assert_equal(to_string_list(s.splitlines(keepends=True)), items)


def test_float_conversion():
assert_equal(("4.5").__float__(), 4.5)
Expand Down
38 changes: 19 additions & 19 deletions stdlib/test/collections/string/test_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ from testing import (
assert_true,
)

from collections.string import StringSlice
from collections.string.string import (
_calc_initial_buffer_size_int32,
_calc_initial_buffer_size_int64,
_isspace,
)
from collections.string.string_slice import StringSlice, to_string_list
from memory import UnsafePointer
from python import Python
from utils import StringRef
Expand Down Expand Up @@ -755,17 +755,18 @@ def test_split():


def test_splitlines():
alias L = List[String]
alias L = List[StringSlice[StaticConstantOrigin]]

# Test with no line breaks
assert_equal(String("hello world").splitlines(), L("hello world"))
assert_equal("hello world".splitlines(), L("hello world"))

# Test with line breaks
assert_equal(String("hello\nworld").splitlines(), L("hello", "world"))
assert_equal(String("hello\rworld").splitlines(), L("hello", "world"))
assert_equal(String("hello\r\nworld").splitlines(), L("hello", "world"))
assert_equal("hello\nworld".splitlines(), L("hello", "world"))
assert_equal("hello\rworld".splitlines(), L("hello", "world"))
assert_equal("hello\r\nworld".splitlines(), L("hello", "world"))

# Test with multiple different line breaks
s1 = String("hello\nworld\r\nmojo\rlanguage\r\n")
s1 = "hello\nworld\r\nmojo\rlanguage\r\n"
hello_mojo = L("hello", "world", "mojo", "language")
assert_equal(s1.splitlines(), hello_mojo)
assert_equal(
Expand All @@ -774,39 +775,38 @@ def test_splitlines():
)

# Test with an empty string
assert_equal(String("").splitlines(), L())
assert_equal("".splitlines(), L())
# test \v \f \x1c \x1d
s2 = String("hello\vworld\fmojo\x1clanguage\x1d")
s2 = "hello\vworld\fmojo\x1clanguage\x1d"
assert_equal(s2.splitlines(), hello_mojo)
assert_equal(
s2.splitlines(keepends=True),
L("hello\v", "world\f", "mojo\x1c", "language\x1d"),
)

# test \x1c \x1d \x1e
s3 = String("hello\x1cworld\x1dmojo\x1elanguage\x1e")
s3 = "hello\x1cworld\x1dmojo\x1elanguage\x1e"
assert_equal(s3.splitlines(), hello_mojo)
assert_equal(
s3.splitlines(keepends=True),
L("hello\x1c", "world\x1d", "mojo\x1e", "language\x1e"),
)

# test \x85 \u2028 \u2029
var next_line = List[UInt8](0xC2, 0x85, 0)
var next_line = String(List[UInt8](0xC2, 0x85, 0))
"""TODO: \\x85"""
var unicode_line_sep = List[UInt8](0xE2, 0x80, 0xA8, 0)
var unicode_line_sep = String(List[UInt8](0xE2, 0x80, 0xA8, 0))
"""TODO: \\u2028"""
var unicode_paragraph_sep = List[UInt8](0xE2, 0x80, 0xA9, 0)
var unicode_paragraph_sep = String(List[UInt8](0xE2, 0x80, 0xA9, 0))
"""TODO: \\u2029"""

for i in List(next_line, unicode_line_sep, unicode_paragraph_sep):
u = String(i[])
u = i[]
item = String("").join("hello", u, "world", u, "mojo", u, "language", u)
assert_equal(item.splitlines(), hello_mojo)
assert_equal(
item.splitlines(keepends=True),
L("hello" + u, "world" + u, "mojo" + u, "language" + u),
)
s = StringSlice(item)
assert_equal(s.splitlines(), hello_mojo)
items = List("hello" + u, "world" + u, "mojo" + u, "language" + u)
assert_equal(to_string_list(s.splitlines(keepends=True)), items)


def test_isupper():
Expand Down
28 changes: 10 additions & 18 deletions stdlib/test/collections/string/test_string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ from testing import assert_equal, assert_false, assert_true, assert_raises
from collections.string.string_slice import (
StringSlice,
_count_utf8_continuation_bytes,
to_string_list,
)
from collections.string._utf8_validation import _is_valid_utf8
from memory import Span, UnsafePointer
Expand Down Expand Up @@ -480,27 +481,18 @@ def test_count_utf8_continuation_bytes():


def test_splitlines():
alias S = StringSlice[StaticConstantOrigin]
alias L = List[StringSlice[StaticConstantOrigin]]

# FIXME: remove once StringSlice conforms to TestableCollectionElement
fn _assert_equal[
O1: ImmutableOrigin
](l1: List[StringSlice[O1]], l2: List[String]) raises:
assert_equal(len(l1), len(l2))
for i in range(len(l1)):
assert_equal(str(l1[i]), l2[i])

# Test with no line breaks
assert_equal(S("hello world").splitlines(), L("hello world"))
assert_equal("hello world".splitlines(), L("hello world"))

# Test with line breaks
assert_equal(S("hello\nworld").splitlines(), L("hello", "world"))
assert_equal(S("hello\rworld").splitlines(), L("hello", "world"))
assert_equal(S("hello\r\nworld").splitlines(), L("hello", "world"))
assert_equal("hello\nworld".splitlines(), L("hello", "world"))
assert_equal("hello\rworld".splitlines(), L("hello", "world"))
assert_equal("hello\r\nworld".splitlines(), L("hello", "world"))

# Test with multiple different line breaks
s1 = S("hello\nworld\r\nmojo\rlanguage\r\n")
s1 = "hello\nworld\r\nmojo\rlanguage\r\n"
hello_mojo = L("hello", "world", "mojo", "language")
assert_equal(s1.splitlines(), hello_mojo)
assert_equal(
Expand All @@ -509,17 +501,17 @@ def test_splitlines():
)

# Test with an empty string
assert_equal(S("").splitlines(), L())
assert_equal("".splitlines(), L())
# test \v \f \x1c \x1d
s2 = S("hello\vworld\fmojo\x1clanguage\x1d")
s2 = "hello\vworld\fmojo\x1clanguage\x1d"
assert_equal(s2.splitlines(), hello_mojo)
assert_equal(
s2.splitlines(keepends=True),
L("hello\v", "world\f", "mojo\x1c", "language\x1d"),
)

# test \x1c \x1d \x1e
s3 = S("hello\x1cworld\x1dmojo\x1elanguage\x1e")
s3 = "hello\x1cworld\x1dmojo\x1elanguage\x1e"
assert_equal(s3.splitlines(), hello_mojo)
assert_equal(
s3.splitlines(keepends=True),
Expand All @@ -540,7 +532,7 @@ def test_splitlines():
s = StringSlice(item)
assert_equal(s.splitlines(), hello_mojo)
items = List("hello" + u, "world" + u, "mojo" + u, "language" + u)
_assert_equal(s.splitlines(keepends=True), items)
assert_equal(to_string_list(s.splitlines(keepends=True)), items)


def test_rstrip():
Expand Down
Loading