Optimize UTF-8 helpers in Codepoint and StringSlice.

- Codepoint.unsafe_write_utf8 gains an `optimize_ascii` parameter (default
  True) that selects between an early-return ASCII path and a branch-free
  encoding path.
- Codepoint.utf8_byte_length compares uint32 thresholds against `to_u32()`.
- New `_is_utf8_continuation_byte` helper replaces the repeated
  `(b & 0b1100_0000) == 0b1000_0000` checks in string_slice.mojo.
- `chr` docstring example is condensed and `test_chr` also covers chr(0).

Review note: the hunk that removed the `& UInt8(0b1111_0000)` mask from
`_utf8_byte_type` has been dropped. Without the mask, bytes 0xF8..0xFF yield
5..8 instead of 4, which contradicts the 0..4 types listed in that function's
docstring. As a result the post-image hash on the string_slice.mojo `index`
line no longer describes the patched file.

diff --git a/stdlib/src/collections/string/codepoint.mojo b/stdlib/src/collections/string/codepoint.mojo
index 0452f15621..934d74ce49 100644
--- a/stdlib/src/collections/string/codepoint.mojo
+++ b/stdlib/src/collections/string/codepoint.mojo
@@ -18,6 +18,7 @@ from collections.string import StringSlice
 
 from bit import count_leading_zeros
 from memory import UnsafePointer
+from sys.intrinsics import likely
 
 
 @always_inline
@@ -450,13 +451,13 @@ struct Codepoint(CollectionElement, EqualityComparable, Intable, Stringable):
         return self._scalar_value
 
     @always_inline
-    fn unsafe_write_utf8(self, ptr: UnsafePointer[Byte]) -> UInt:
+    fn unsafe_write_utf8[
+        optimize_ascii: Bool = True
+    ](self, ptr: UnsafePointer[Byte]) -> UInt:
         """Shift unicode to utf8 representation.
 
-        Safety:
-            `ptr` MUST point to at least `self.utf8_byte_length()` allocated
-            bytes or else an out-of-bounds write will occur, which is undefined
-            behavior.
+        Parameters:
+            optimize_ascii: Optimize for languages with mostly ASCII characters.
 
         Args:
             ptr: Pointer value to write the encoded UTF-8 bytes. Must validly
@@ -466,6 +467,11 @@ struct Codepoint(CollectionElement, EqualityComparable, Intable, Stringable):
         Returns:
             Returns the number of bytes written.
 
+        Safety:
+            `ptr` MUST point to at least `self.utf8_byte_length()` allocated
+            bytes or else an out-of-bounds write will occur, which is undefined
+            behavior.
+
         ### Unicode (represented as UInt32 BE) to UTF-8 conversion:
         - 1: 00000000 00000000 00000000 0aaaaaaa -> 0aaaaaaa
             - a
@@ -483,18 +489,30 @@ struct Codepoint(CollectionElement, EqualityComparable, Intable, Stringable):
 
         var num_bytes = self.utf8_byte_length()
 
-        if num_bytes == 1:
-            ptr[0] = UInt8(c)
-            return 1
-
-        var shift = 6 * (num_bytes - 1)
-        var mask = UInt8(0xFF) >> UInt8(num_bytes + 1)
-        var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes)
-        ptr[0] = ((c >> shift) & mask) | num_bytes_marker
-        for i in range(1, num_bytes):
-            shift -= 6
-            ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000
-
+        @parameter
+        if optimize_ascii:
+            # FIXME(#933): can't run LLVM intrinsic at compile time
+            # if likely(num_bytes == 1):
+            if num_bytes == 1:
+                ptr[0] = UInt8(c)
+                return 1
+            var shift = 6 * (num_bytes - 1)
+            var mask = UInt8(0xFF) >> (num_bytes + 1)
+            var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes)
+            ptr[0] = ((c >> shift) & mask) | num_bytes_marker
+            for i in range(1, num_bytes):
+                shift -= 6
+                ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000
+        else:
+            var shift = 6 * (num_bytes - 1)
+            var mask = UInt8(0xFF) >> (num_bytes + Int(num_bytes > 1))
+            var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes)
+            ptr[0] = ((c >> shift) & mask) | (
+                num_bytes_marker & -Int(num_bytes != 1)
+            )
+            for i in range(1, num_bytes):
+                shift -= 6
+                ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000
         return num_bytes
 
     @always_inline
@@ -509,16 +527,11 @@ struct Codepoint(CollectionElement, EqualityComparable, Intable, Stringable):
 
         # Minimum codepoint values (respectively) that can fit in a 1, 2, 3,
         # and 4 byte encoded UTF-8 sequence.
-        alias sizes = SIMD[DType.int32, 4](
-            0,
-            2**7,
-            2**11,
-            2**16,
-        )
+        alias sizes = SIMD[DType.uint32, 4](0, 2**7, 2**11, 2**16)
 
         # Count how many of the minimums this codepoint exceeds, which is equal
         # to the number of bytes needed to encode it.
-        var lt = (sizes <= Int(self)).cast[DType.uint8]()
+        var lt = (sizes <= self.to_u32()).cast[DType.uint8]()
 
         # TODO(MOCO-1537): Support `reduce_add()` at compile time.
         #     var count = Int(lt.reduce_add())
diff --git a/stdlib/src/collections/string/string.mojo b/stdlib/src/collections/string/string.mojo
index 395dcdce10..304d33b6cc 100644
--- a/stdlib/src/collections/string/string.mojo
+++ b/stdlib/src/collections/string/string.mojo
@@ -81,8 +81,7 @@ fn chr(c: Int) -> String:
 
     Examples:
     ```mojo
-    print(chr(97)) # "a"
-    print(chr(8364)) # "€"
+    print(chr(97), chr(8364)) # "a €"
     ```
     .
     """
diff --git a/stdlib/src/collections/string/string_slice.mojo b/stdlib/src/collections/string/string_slice.mojo
index cd97b8ed95..d0c5019a3f 100644
--- a/stdlib/src/collections/string/string_slice.mojo
+++ b/stdlib/src/collections/string/string_slice.mojo
@@ -2178,6 +2178,13 @@ fn _memmem[
     return UnsafePointer[Scalar[type]]()
 
 
+@always_inline
+fn _is_utf8_continuation_byte[
+    w: Int
+](vec: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
+    return vec.cast[DType.int8]() < -(0b1000_0000 >> 1)
+
+
 fn _count_utf8_continuation_bytes(str_slice: StringSlice) -> Int:
     alias sizes = (256, 128, 64, 32, 16, 8)
     var ptr = str_slice.unsafe_ptr()
@@ -2194,12 +2201,12 @@ fn _count_utf8_continuation_bytes(str_slice: StringSlice) -> Int:
             var rest = num_bytes - processed
             for _ in range(rest // s):
                 var vec = (ptr + processed).load[width=s]()
-                var comp = (vec & 0b1100_0000) == 0b1000_0000
+                var comp = _is_utf8_continuation_byte(vec)
                 amnt += Int(comp.cast[DType.uint8]().reduce_add())
                 processed += s
 
     for i in range(num_bytes - processed):
-        amnt += Int((ptr[processed + i] & 0b1100_0000) == 0b1000_0000)
+        amnt += Int(_is_utf8_continuation_byte(ptr[processed + i]))
 
     return amnt
 
@@ -2210,10 +2217,10 @@ fn _utf8_first_byte_sequence_length(b: Byte) -> Int:
     this does not work correctly if given a continuation byte."""
 
     debug_assert(
-        (b & 0b1100_0000) != 0b1000_0000,
+        not _is_utf8_continuation_byte(b),
         "Function does not work correctly if given a continuation byte.",
     )
-    return Int(count_leading_zeros(~b)) + Int(b < 0b1000_0000)
+    return Int(count_leading_zeros(~b) | (b < 0b1000_0000).cast[DType.uint8]())
 
 
 fn _utf8_byte_type(b: SIMD[DType.uint8, _], /) -> __type_of(b):
diff --git a/stdlib/test/collections/string/test_string.mojo b/stdlib/test/collections/string/test_string.mojo
index 26d14bdeab..42e6adf1cc 100644
--- a/stdlib/test/collections/string/test_string.mojo
+++ b/stdlib/test/collections/string/test_string.mojo
@@ -263,6 +263,7 @@ def test_ord():
 
 
 def test_chr():
+    assert_equal("\0", chr(0))
     assert_equal("A", chr(65))
     assert_equal("a", chr(97))
     assert_equal("!", chr(33))