diff --git a/mojo/docs/changelog.md b/mojo/docs/changelog.md index 750d44f331..e51eba30c2 100644 --- a/mojo/docs/changelog.md +++ b/mojo/docs/changelog.md @@ -126,6 +126,11 @@ what we publish. an issue with the any origin parameter extending the lifetime of unrelated local variables for this common method. +- `Span` now has `find()` and `rfind()` methods which work for any + `Span[Scalar[D]]` e.g. `Span[Byte]`. The `rfind()` implementation is + now vectorized. PR [#3548](https://github.com/modularml/mojo/pull/3548) by + [@martinvuyk](https://github.com/martinvuyk). + ### GPU changes - You can now skip compiling a GPU kernel first and then enqueueing it: diff --git a/mojo/stdlib/src/builtin/string_literal.mojo b/mojo/stdlib/src/builtin/string_literal.mojo index f72938df4b..afbdc09c9a 100644 --- a/mojo/stdlib/src/builtin/string_literal.mojo +++ b/mojo/stdlib/src/builtin/string_literal.mojo @@ -624,7 +624,7 @@ struct StringLiteral( Returns: The offset of `substr` relative to the beginning of the string. """ - return self.as_string_slice().find(substr, start=start) + return self.as_string_slice().find(substr.as_string_slice(), start) fn rfind(self, substr: StringLiteral, start: Int = 0) -> Int: """Finds the offset of the last occurrence of `substr` starting at @@ -637,7 +637,7 @@ struct StringLiteral( Returns: The offset of `substr` relative to the beginning of the string. """ - return self.as_string_slice().rfind(substr, start=start) + return self.as_string_slice().rfind(substr.as_string_slice(), start) fn replace(self, old: StringLiteral, new: StringLiteral) -> StringLiteral: """Return a copy of the string with all occurrences of substring `old` diff --git a/mojo/stdlib/src/collections/string/string.mojo b/mojo/stdlib/src/collections/string/string.mojo index 85e121f004..ae87033817 100644 --- a/mojo/stdlib/src/collections/string/string.mojo +++ b/mojo/stdlib/src/collections/string/string.mojo @@ -1513,7 +1513,7 @@ struct String( return self._interleave(new) var occurrences = self.count(old) - if occurrences == -1: + if occurrences == len(self) + 1: return self var self_start = self.unsafe_ptr() diff --git a/mojo/stdlib/src/collections/string/string_slice.mojo b/mojo/stdlib/src/collections/string/string_slice.mojo index 7ec78dc993..1d123e646e 100644 --- a/mojo/stdlib/src/collections/string/string_slice.mojo +++ b/mojo/stdlib/src/collections/string/string_slice.mojo @@ -32,14 +32,15 @@ from collections.string._unicode import ( ) from hashlib._hasher import _HashableWithHasher, _Hasher from os import PathLike, abort -from sys import bitwidthof, simdwidthof +from sys import simdwidthof from sys.ffi import c_char from sys.intrinsics import likely, unlikely -from bit import count_leading_zeros, count_trailing_zeros -from memory import Span, UnsafePointer, memcmp, memcpy, pack_bits +from bit import count_leading_zeros +from memory import Span, UnsafePointer, memcmp, memcpy from memory.memory import _memcmp_impl_unconstrained + alias StaticString = StringSlice[StaticConstantOrigin] """An immutable static string slice.""" @@ -1635,28 +1636,13 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]]( The offset in bytes of `substr` relative to the beginning of the string. """ - if not substr: - return 0 - - if self.byte_length() < substr.byte_length() + start: - return -1 - - # The substring to search within, offset from the beginning if `start` - # is positive, and offset from the end if `start` is negative. - var haystack_str = self._from_start(start) - - var loc = _memmem( - haystack_str.unsafe_ptr(), - haystack_str.byte_length(), - substr.unsafe_ptr(), - substr.byte_length(), + # FIXME(#3526): this should return unicode codepoint offsets + return ( + self.as_bytes() + .get_immutable() + .find(substr.as_bytes().get_immutable(), start) ) - if not loc: - return -1 - - return Int(loc) - Int(self.unsafe_ptr()) - fn rfind(self, substr: StringSlice, start: Int = 0) -> Int: """Finds the offset in bytes of the last occurrence of `substr` starting at `start`. If not found, returns `-1`. @@ -1670,28 +1656,13 @@ struct StringSlice[mut: Bool, //, origin: Origin[mut]]( The offset in bytes of `substr` relative to the beginning of the string. """ - if not substr: - return len(self) - - if len(self) < len(substr) + start: - return -1 - - # The substring to search within, offset from the beginning if `start` - # is positive, and offset from the end if `start` is negative. - var haystack_str = self._from_start(start) - - var loc = _memrmem( - haystack_str.unsafe_ptr(), - len(haystack_str), - substr.unsafe_ptr(), - len(substr), + # FIXME(#3526): this should return unicode codepoint offsets + return ( + self.as_bytes() + .get_immutable() + .rfind(substr.as_bytes().get_immutable(), start) ) - if not loc: - return -1 - - return Int(loc) - Int(self.unsafe_ptr()) - fn isspace(self) -> Bool: """Determines whether every character in the given StringSlice is a python whitespace String. This corresponds to Python's @@ -2109,89 +2080,6 @@ fn _unsafe_strlen(owned ptr: UnsafePointer[Byte]) -> Int: return len -@always_inline -fn _align_down(value: Int, alignment: Int) -> Int: - return value._positive_div(alignment) * alignment - - -@always_inline -fn _memchr[ - type: DType -]( - source: UnsafePointer[Scalar[type]], char: Scalar[type], len: Int -) -> UnsafePointer[Scalar[type]]: - if not len: - return UnsafePointer[Scalar[type]]() - alias bool_mask_width = simdwidthof[DType.bool]() - var first_needle = SIMD[type, bool_mask_width](char) - var vectorized_end = _align_down(len, bool_mask_width) - - for i in range(0, vectorized_end, bool_mask_width): - var bool_mask = source.load[width=bool_mask_width](i) == first_needle - var mask = pack_bits(bool_mask) - if mask: - return source + Int(i + count_trailing_zeros(mask)) - - for i in range(vectorized_end, len): - if source[i] == char: - return source + i - return UnsafePointer[Scalar[type]]() - - -@always_inline -fn _memmem[ - type: DType -]( - haystack: UnsafePointer[Scalar[type]], - haystack_len: Int, - needle: UnsafePointer[Scalar[type]], - needle_len: Int, -) -> UnsafePointer[Scalar[type]]: - if not needle_len: - return haystack - if needle_len > haystack_len: - return UnsafePointer[Scalar[type]]() - if needle_len == 1: - return _memchr[type](haystack, needle[0], haystack_len) - - alias bool_mask_width = simdwidthof[DType.bool]() - var vectorized_end = _align_down( - haystack_len - needle_len + 1, bool_mask_width - ) - - var first_needle = SIMD[type, bool_mask_width](needle[0]) - var last_needle = SIMD[type, bool_mask_width](needle[needle_len - 1]) - - for i in range(0, vectorized_end, bool_mask_width): - var first_block = haystack.load[width=bool_mask_width](i) - var last_block = haystack.load[width=bool_mask_width]( - i + needle_len - 1 - ) - - var eq_first = first_needle == first_block - var eq_last = last_needle == last_block - - var bool_mask = eq_first & eq_last - var mask = pack_bits(bool_mask) - - while mask: - var offset = Int(i + count_trailing_zeros(mask)) - if memcmp(haystack + offset + 1, needle + 1, needle_len - 1) == 0: - return haystack + offset - mask = mask & (mask - 1) - - # remaining partial block compare using byte-by-byte - # - for i in range(vectorized_end, haystack_len - needle_len + 1): - if haystack[i] != needle[0]: - continue - - if memcmp(haystack + i + 1, needle + 1, needle_len - 1) == 0: - return haystack + i - - return UnsafePointer[Scalar[type]]() - - @always_inline fn _is_utf8_continuation_byte[ w: Int diff --git a/mojo/stdlib/src/memory/span.mojo b/mojo/stdlib/src/memory/span.mojo index 70a0482a52..fc6d1ffe91 100644 --- a/mojo/stdlib/src/memory/span.mojo +++ b/mojo/stdlib/src/memory/span.mojo @@ -20,8 +20,11 @@ from memory import Span ``` """ +from bit import count_trailing_zeros, count_leading_zeros +from builtin.dtype import _uint_type_of_width from collections import InlineArray -from sys.info import simdwidthof +from memory import Pointer, UnsafePointer, memcmp, pack_bits +from sys import simdwidthof from memory import Pointer, UnsafePointer from memory.unsafe_pointer import _default_alignment @@ -441,8 +444,7 @@ struct Span[ return not self == rhs fn fill[origin: MutableOrigin, //](self: Span[T, origin], value: T): - """ - Fill the memory that a span references with a given value. + """Fill the memory that a span references with a given value. Parameters: origin: The inferred mutable origin of the data within the Span. @@ -461,8 +463,7 @@ struct Span[ address_space=address_space, alignment=alignment, ]: - """ - Return an immutable version of this span. + """Return an immutable version of this span. Returns: A span covering the same elements, but without mutability. @@ -473,3 +474,346 @@ struct Span[ address_space=address_space, alignment=alignment, ](ptr=self._data, length=self._len) + + fn find[ + O1: ImmutableOrigin, + O2: ImmutableOrigin, + D: DType, //, + from_left: Bool = True, + single_value: Bool = False, + unsafe_dont_normalize: Bool = False, + ]( + self: Span[Scalar[D], O1], subseq: Span[Scalar[D], O2], start: Int + ) -> Int: + """Finds the offset of the first occurrence of `subseq` starting at + `start`. If not found, returns `-1`. + + Parameters: + O1: The immutable origin of `self`. + O2: The immutable origin of `subseq`. + D: The `DType` of the Scalar. + from_left: Whether to search the first occurrence from the left. + single_value: Whether to search with the `subseq`s first value. + unsafe_dont_normalize: Whether to not normalize the index (no + negative indexing, no bounds checks at runtime. There is still + a `debug_assert(0 <= start < len(self))`). + + Args: + subseq: The sub sequence to find. + start: The offset from which to find. + + Returns: + The offset of `subseq` relative to the beginning of the `Span`. + + Notes: + The function works on an empty span, always returning `-1`. + """ + var _len = len(self) + + if not subseq: + + @parameter + if from_left: + return 0 + else: + return _len + + var n_s: Int + + # _memXXX implementations already handle when haystack_len == 0 + @parameter + if unsafe_dont_normalize: + debug_assert(0 <= start < _len + Int(_len == 0), "out of bounds") + n_s = start + else: + var v = start + _len * Int(start < 0) + n_s = v * Int(v < _len and v > 0) + _len * Int(v >= _len) + var s_ptr = self.unsafe_ptr() + var haystack = __type_of(self)(ptr=s_ptr + n_s, length=_len - n_s) + var loc: UnsafePointer[Scalar[D]] + + @parameter + if from_left and not single_value: + loc = _memmem(haystack, subseq) + elif from_left: + loc = _memchr(haystack, subseq.unsafe_ptr()[0]) + elif not single_value: + loc = _memrmem(haystack, subseq) + else: + loc = _memrchr(haystack, subseq.unsafe_ptr()[0]) + + return (Int(loc) - Int(s_ptr) + 1) * Int(Bool(loc)) - 1 + + fn find[ + O1: ImmutableOrigin, + O2: ImmutableOrigin, + D: DType, //, + single_value: Bool = False, + ](self: Span[Scalar[D], O1], subseq: Span[Scalar[D], O2]) -> Int: + """Finds the offset of the first occurrence of `subseq`. If not found, + returns `-1`. + + Parameters: + O1: The immutable origin of `self`. + O2: The immutable origin of `subseq`. + D: The `DType` of the Scalar. + single_value: Whether to search with the `subseq`s first value. + + Args: + subseq: The sub sequence to find. + + Returns: + The offset of `subseq` relative to the beginning of the `Span`. + + Notes: + The function works on an empty span, always returning `-1`. + """ + return self.find[single_value=single_value, unsafe_dont_normalize=True]( + subseq, 0 + ) + + @always_inline + fn rfind[ + O1: ImmutableOrigin, + O2: ImmutableOrigin, + D: DType, //, + single_value: Bool = False, + ]( + self: Span[Scalar[D], O1], subseq: Span[Scalar[D], O2], start: Int + ) -> Int: + """Finds the offset of the last occurrence of `subseq` starting at + `start`. If not found, returns `-1`. + + Parameters: + O1: The immutable origin of `self`. + O2: The immutable origin of `subseq`. + D: The `DType` of the Scalar. + single_value: Whether to search with the `subseq`s first value. + + Args: + subseq: The sub sequence to find. + start: The offset from which to find. + + Returns: + The offset of `subseq` relative to the beginning of the `Span`. + + Notes: + The function works on an empty span, always returning `-1`. + """ + return self.find[from_left=False, single_value=single_value]( + subseq, start + ) + + @always_inline + fn rfind[ + O1: ImmutableOrigin, + O2: ImmutableOrigin, + D: DType, //, + single_value: Bool = False, + ](self: Span[Scalar[D], O1], subseq: Span[Scalar[D], O2]) -> Int: + """Finds the offset of the last occurrence of `subseq`. If not found, + returns `-1`. + + Parameters: + O1: The immutable origin of `self`. + O2: The immutable origin of `subseq`. + D: The `DType` of the Scalar. + single_value: Whether to search with the `subseq`s first value. + + Args: + subseq: The sub sequence to find. + + Returns: + The offset of `subseq` relative to the beginning of the `Span`. + + Notes: + The function works on an empty span, always returning `-1`. + """ + return self.find[ + from_left=False, + single_value=single_value, + unsafe_dont_normalize=True, + ](subseq, 0) + + +# ===----------------------------------------------------------------------===# +# Utilities +# ===----------------------------------------------------------------------===# + + +@always_inline +fn _align_down(value: Int, alignment: Int) -> Int: + return value._positive_div(alignment) * alignment + + +@always_inline +fn _memchr[ + O: ImmutableOrigin, D: DType, // +]( + span: Span[Scalar[D], O], + char: Scalar[D], + out output: UnsafePointer[Scalar[D]], +): + var haystack = span.unsafe_ptr() + var length = len(span) + alias bool_mask_width = simdwidthof[DType.bool]() + var first_needle = SIMD[D, bool_mask_width](char) + var vectorized_end = _align_down(length, bool_mask_width) + + for i in range(0, vectorized_end, bool_mask_width): + var bool_mask = haystack.load[width=bool_mask_width](i) == first_needle + var mask = pack_bits(bool_mask) + if mask: + output = haystack + Int(i + count_trailing_zeros(mask)) + return + + for i in range(vectorized_end, length): + if haystack[i] == char: + output = haystack + i + return + + output = UnsafePointer[Scalar[D]]() + + +@always_inline +fn _memmem[ + O1: ImmutableOrigin, O2: ImmutableOrigin, D: DType, // +]( + haystack_span: Span[Scalar[D], O1], + needle_span: Span[Scalar[D], O2], + out output: UnsafePointer[Scalar[D]], +): + var haystack = haystack_span.unsafe_ptr() + var haystack_len = len(haystack_span) + var needle = needle_span.unsafe_ptr() + var needle_len = len(needle_span) + debug_assert(needle_len > 0, "needle_len must be > 0") + if needle_len == 1: + output = _memchr(haystack_span, needle[0]) + return + elif needle_len > haystack_len: + output = UnsafePointer[Scalar[D]]() + return + + alias bool_mask_width = simdwidthof[DType.bool]() + var vectorized_end = _align_down( + haystack_len - needle_len + 1, bool_mask_width + ) + + var first_needle = SIMD[D, bool_mask_width](needle[0]) + var last_needle = SIMD[D, bool_mask_width](needle[needle_len - 1]) + + for i in range(0, vectorized_end, bool_mask_width): + var first_block = haystack.load[width=bool_mask_width](i) + var last_block = haystack.load[width=bool_mask_width]( + i + needle_len - 1 + ) + + var bool_mask = (first_needle == first_block) & ( + last_needle == last_block + ) + var mask = pack_bits(bool_mask) + + while mask: + var offset = Int(i + count_trailing_zeros(mask)) + if memcmp(haystack + offset + 1, needle + 1, needle_len - 1) == 0: + output = haystack + offset + return + mask = mask & (mask - 1) + + for i in range(vectorized_end, haystack_len - needle_len + 1): + if haystack[i] != needle[0]: + continue + + if memcmp(haystack + i + 1, needle + 1, needle_len - 1) == 0: + output = haystack + i + return + output = UnsafePointer[Scalar[D]]() + + +@always_inline +fn _memrchr[ + O: ImmutableOrigin, D: DType, // +]( + span: Span[Scalar[D], O], + char: Scalar[D], + out output: UnsafePointer[Scalar[D]], +): + var haystack = span.unsafe_ptr() + var length = len(span) + alias bool_mask_width = simdwidthof[DType.bool]() + var first_needle = SIMD[D, bool_mask_width](char) + var vectorized_end = _align_down(length, bool_mask_width) + + for i in reversed(range(vectorized_end, length)): + if haystack[i] == char: + output = haystack + i + return + + for i in reversed(range(0, vectorized_end, bool_mask_width)): + var bool_mask = haystack.load[width=bool_mask_width](i) == first_needle + var mask = pack_bits(bool_mask) + if mask: + var zeros = Int(count_leading_zeros(mask)) + 1 + output = haystack + (i + bool_mask_width - zeros) + return + + output = UnsafePointer[Scalar[D]]() + + +@always_inline +fn _memrmem[ + O1: ImmutableOrigin, O2: ImmutableOrigin, D: DType, // +]( + haystack_span: Span[Scalar[D], O1], + needle_span: Span[Scalar[D], O2], + out output: UnsafePointer[Scalar[D]], +): + var haystack = haystack_span.unsafe_ptr() + var haystack_len = len(haystack_span) + var needle = needle_span.unsafe_ptr() + var needle_len = len(needle_span) + debug_assert(needle_len > 0, "needle_len must be > 0") + + if needle_len == 1: + output = _memrchr(haystack_span, needle[0]) + return + elif needle_len > haystack_len: + output = UnsafePointer[Scalar[D]]() + return + + alias bool_mask_width = simdwidthof[DType.bool]() + var vectorized_end = _align_down( + haystack_len - needle_len + 1, bool_mask_width + ) + + for i in reversed(range(vectorized_end, haystack_len - needle_len + 1)): + if haystack[i] != needle[0]: + continue + + if memcmp(haystack + i + 1, needle + 1, needle_len - 1) == 0: + output = haystack + i + return + + var first_needle = SIMD[D, bool_mask_width](needle[0]) + var last_needle = SIMD[D, bool_mask_width](needle[needle_len - 1]) + + for i in reversed(range(0, vectorized_end, bool_mask_width)): + var first_block = haystack.load[width=bool_mask_width](i) + var last_block = haystack.load[width=bool_mask_width]( + i + needle_len - 1 + ) + + var bool_mask = (first_needle == first_block) & ( + last_needle == last_block + ) + var mask = pack_bits(bool_mask) + + while mask: + var offset = i + bool_mask_width - Int(count_leading_zeros(mask)) + if memcmp(haystack + offset, needle + 1, needle_len - 1) == 0: + output = haystack + offset - 1 + return + mask = mask & (mask - 1) + + output = UnsafePointer[Scalar[D]]() diff --git a/mojo/stdlib/src/os/path/path.mojo b/mojo/stdlib/src/os/path/path.mojo index bb9a14a9ca..ba77cb2f60 100644 --- a/mojo/stdlib/src/os/path/path.mojo +++ b/mojo/stdlib/src/os/path/path.mojo @@ -347,7 +347,7 @@ fn join(owned path: String, *paths: String) -> String: # ===----------------------------------------------------------------------=== # -def split[PathLike: os.PathLike, //](path: PathLike) -> (String, String): +fn split[PathLike: os.PathLike, //](path: PathLike) -> (String, String): """ Split a given pathname into two components: head and tail. This is useful for separating the directory path from the filename. If the input path ends @@ -365,8 +365,8 @@ def split[PathLike: os.PathLike, //](path: PathLike) -> (String, String): Returns: A tuple containing two strings: (head, tail). """ - fspath = path.__fspath__() - i = fspath.rfind(os.sep) + 1 + var fspath = path.__fspath__() + var i = fspath.rfind(os.sep) + 1 head, tail = fspath[:i], fspath[i:] if head and head != String(os.sep) * len(head): head = String(head.rstrip(sep)) diff --git a/mojo/stdlib/test/collections/string/test_string.mojo b/mojo/stdlib/test/collections/string/test_string.mojo index 42e6adf1cc..da0ce37691 100644 --- a/mojo/stdlib/test/collections/string/test_string.mojo +++ b/mojo/stdlib/test/collections/string/test_string.mojo @@ -564,6 +564,10 @@ def test_find(): assert_equal(6, str.find("world")) assert_equal(-1, str.find("universe")) + # Empty string and substring. + assert_equal(String("").find("ab"), -1) + assert_equal(String("foo").find(""), 0) + # Test find() offset is absolute, not relative (issue mojo/#1355) var str2 = String("...a") assert_equal(3, str2.find("a", 0)) @@ -576,6 +580,24 @@ def test_find(): assert_equal(7, str.find("o", -5)) assert_equal(-1, String("abc").find("abcd")) + assert_equal(0, (String("0") * 100).find("0")) + assert_equal(0, (String("0") * 100).find("00")) + assert_equal(9, String.DIGITS.find("9")) + assert_equal(25, String.ASCII_LETTERS.find("z")) + assert_equal(25, String.ASCII_LETTERS.find("zA")) + assert_equal(25, String.ASCII_LETTERS.find("zABCD")) + for i in range(2, 50): + assert_equal(51, (String.ASCII_LETTERS * i).find("Za")) + + # Special characters. + # FIXME: once find works by unicode codepoints + # assert_equal(String("こんにちは").find("にち"), 2) + # assert_equal(String("🔥🔥").find("🔥"), 0) + + # pathlike operations + p = "/subdir/subdir/subdir/subdir/mojo/stdlib/test/os/path/test_split.mojo" + file_name = "test_split.mojo" + assert_equal(len(p) - len(file_name), p.find("test_split.mojo")) def test_replace(): @@ -618,12 +640,25 @@ def test_rfind(): assert_equal(String("hello world").rfind("w", -5), 6) assert_equal(-1, String("abc").rfind("abcd")) + assert_equal(99, (String("0") * 100).rfind("0")) + assert_equal(98, (String("0") * 100).rfind("00")) + assert_equal(9, String.DIGITS.rfind("9")) + assert_equal(25, String.ASCII_LETTERS.rfind("z")) + assert_equal(25, String.ASCII_LETTERS.rfind("zA")) + assert_equal(25, String.ASCII_LETTERS.rfind("zABCD")) + for i in range(2, 50): + assert_equal(0, ("123" + String.ASCII_LETTERS * i).rfind("123")) # Special characters. - # TODO(#26444): Support unicode strings. + # FIXME: once find works by unicode codepoints # assert_equal(String("こんにちは").rfind("にち"), 2) # assert_equal(String("🔥🔥").rfind("🔥"), 1) + # pathlike operations + p = "/subdir/subdir/subdir/subdir/mojo/stdlib/test/os/path/test_split.mojo" + file_name = "test_split.mojo" + assert_equal(len(p) - len(file_name) - 1, p.rfind("/")) + def test_split(): # empty separators default to whitespace @@ -928,6 +963,7 @@ def test_rstrip(): var str1 = String("string") assert_true(str1.rstrip() == "string") + assert_true(str1.rstrip("g") == "strin") var str2 = String("something \t\n\t\v\f") assert_true(str2.rstrip() == "something") @@ -954,6 +990,7 @@ def test_lstrip(): var str1 = String("string") assert_true(str1.lstrip() == "string") + assert_true(str1.lstrip("s") == "tring") var str2 = String(" \t\n\t\v\fsomething") assert_true(str2.lstrip() == "something") diff --git a/mojo/stdlib/test/os/path/test_split.mojo b/mojo/stdlib/test/os/path/test_split.mojo index 261238f472..a28e636f53 100644 --- a/mojo/stdlib/test/os/path/test_split.mojo +++ b/mojo/stdlib/test/os/path/test_split.mojo @@ -63,5 +63,11 @@ def main(): # Test with __source_location() source_location = __source_location().file_name + s_len = source_location.byte_length() + file_name = "test_split.mojo" + assert_equal( + source_location.__fspath__().rfind(os.sep), s_len - len(file_name) - 1 + ) head, tail = split(source_location) + assert_equal(tail, file_name) assert_equal(head + os.sep + tail, source_location)