Skip to content

Commit

Permalink
[External] [stdlib] Micro-optimize utf8 helper functions (#56579)
Browse files Browse the repository at this point in the history
[External] [stdlib] Micro-optimize utf8 helper functions

Micro-optimize utf8 helper functions

ORIGINAL_AUTHOR=martinvuyk
<110240700+martinvuyk@users.noreply.github.com>
PUBLIC_PR_LINK=#3896

Co-authored-by: martinvuyk <110240700+martinvuyk@users.noreply.github.com>
Closes #3896
MODULAR_ORIG_COMMIT_REV_ID: d24cb7cdf532b3f6a10f9ee19c95b66f22efc967
  • Loading branch information
modularbot and martinvuyk committed Feb 26, 2025
1 parent 2d53f46 commit bdaca0f
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 31 deletions.
61 changes: 37 additions & 24 deletions stdlib/src/collections/string/codepoint.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ from collections.string import StringSlice

from bit import count_leading_zeros
from memory import UnsafePointer
from sys.intrinsics import likely


@always_inline
Expand Down Expand Up @@ -450,13 +451,13 @@ struct Codepoint(CollectionElement, EqualityComparable, Intable, Stringable):
return self._scalar_value

@always_inline
fn unsafe_write_utf8(self, ptr: UnsafePointer[Byte]) -> UInt:
fn unsafe_write_utf8[
optimize_ascii: Bool = True
](self, ptr: UnsafePointer[Byte]) -> UInt:
"""Shift unicode to utf8 representation.
Safety:
`ptr` MUST point to at least `self.utf8_byte_length()` allocated
bytes or else an out-of-bounds write will occur, which is undefined
behavior.
Parameters:
optimize_ascii: Optimize for languages with mostly ASCII characters.
Args:
ptr: Pointer value to write the encoded UTF-8 bytes. Must validly
Expand All @@ -466,6 +467,11 @@ struct Codepoint(CollectionElement, EqualityComparable, Intable, Stringable):
Returns:
Returns the number of bytes written.
Safety:
`ptr` MUST point to at least `self.utf8_byte_length()` allocated
bytes or else an out-of-bounds write will occur, which is undefined
behavior.
### Unicode (represented as UInt32 BE) to UTF-8 conversion:
- 1: 00000000 00000000 00000000 0aaaaaaa -> 0aaaaaaa
- a
Expand All @@ -483,18 +489,30 @@ struct Codepoint(CollectionElement, EqualityComparable, Intable, Stringable):

var num_bytes = self.utf8_byte_length()

if num_bytes == 1:
ptr[0] = UInt8(c)
return 1

var shift = 6 * (num_bytes - 1)
var mask = UInt8(0xFF) >> UInt8(num_bytes + 1)
var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes)
ptr[0] = ((c >> shift) & mask) | num_bytes_marker
for i in range(1, num_bytes):
shift -= 6
ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000

@parameter
if optimize_ascii:
# FIXME(#933): can't run LLVM intrinsic at compile time
# if likely(num_bytes == 1):
if num_bytes == 1:
ptr[0] = UInt8(c)
return 1
var shift = 6 * (num_bytes - 1)
var mask = UInt8(0xFF) >> (num_bytes + 1)
var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes)
ptr[0] = ((c >> shift) & mask) | num_bytes_marker
for i in range(1, num_bytes):
shift -= 6
ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000
else:
var shift = 6 * (num_bytes - 1)
var mask = UInt8(0xFF) >> (num_bytes + Int(num_bytes > 1))
var num_bytes_marker = UInt8(0xFF) << (8 - num_bytes)
ptr[0] = ((c >> shift) & mask) | (
num_bytes_marker & -Int(num_bytes != 1)
)
for i in range(1, num_bytes):
shift -= 6
ptr[i] = ((c >> shift) & 0b0011_1111) | 0b1000_0000
return num_bytes

@always_inline
Expand All @@ -509,16 +527,11 @@ struct Codepoint(CollectionElement, EqualityComparable, Intable, Stringable):

# Minimum codepoint values (respectively) that can fit in a 1, 2, 3,
# and 4 byte encoded UTF-8 sequence.
alias sizes = SIMD[DType.int32, 4](
0,
2**7,
2**11,
2**16,
)
alias sizes = SIMD[DType.uint32, 4](0, 2**7, 2**11, 2**16)

# Count how many of the minimums this codepoint exceeds, which is equal
# to the number of bytes needed to encode it.
var lt = (sizes <= Int(self)).cast[DType.uint8]()
var lt = (sizes <= self.to_u32()).cast[DType.uint8]()

# TODO(MOCO-1537): Support `reduce_add()` at compile time.
# var count = Int(lt.reduce_add())
Expand Down
3 changes: 1 addition & 2 deletions stdlib/src/collections/string/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ fn chr(c: Int) -> String:
Examples:
```mojo
print(chr(97)) # "a"
print(chr(8364)) # "€"
print(chr(97), chr(8364)) # "a €"
```
.
"""
Expand Down
17 changes: 12 additions & 5 deletions stdlib/src/collections/string/string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -2196,6 +2196,13 @@ fn _memmem[
return UnsafePointer[Scalar[type]]()


@always_inline
fn _is_utf8_continuation_byte[
    w: Int
](vec: SIMD[DType.uint8, w]) -> SIMD[DType.bool, w]:
    """Flag which lanes of `vec` hold a UTF-8 continuation byte.

    A continuation byte has the bit pattern `0b10xx_xxxx` (0x80 through 0xBF).

    Parameters:
        w: The SIMD width of the input vector.

    Args:
        vec: The bytes to classify.

    Returns:
        A boolean vector that is True in every lane whose byte is a
        continuation byte.
    """
    # Viewed as signed 8-bit integers, 0x80..0xBF are -128..-65, while ASCII
    # bytes (0..127) and lead bytes 0xC0..0xFF (-64..-1) are all >= -64.
    # A single signed comparison therefore replaces the mask-and-compare
    # `(vec & 0b1100_0000) == 0b1000_0000`.
    # NOTE(review): relies on the uint8 -> int8 cast keeping the bit pattern
    # (two's complement wrap) rather than saturating.
    var as_signed = vec.cast[DType.int8]()
    return as_signed < -0b0100_0000


fn _count_utf8_continuation_bytes(str_slice: StringSlice) -> Int:
alias sizes = (256, 128, 64, 32, 16, 8)
var ptr = str_slice.unsafe_ptr()
Expand All @@ -2212,12 +2219,12 @@ fn _count_utf8_continuation_bytes(str_slice: StringSlice) -> Int:
var rest = num_bytes - processed
for _ in range(rest // s):
var vec = (ptr + processed).load[width=s]()
var comp = (vec & 0b1100_0000) == 0b1000_0000
var comp = _is_utf8_continuation_byte(vec)
amnt += Int(comp.cast[DType.uint8]().reduce_add())
processed += s

for i in range(num_bytes - processed):
amnt += Int((ptr[processed + i] & 0b1100_0000) == 0b1000_0000)
amnt += Int(_is_utf8_continuation_byte(ptr[processed + i]))

return amnt

Expand All @@ -2228,10 +2235,10 @@ fn _utf8_first_byte_sequence_length(b: Byte) -> Int:
this does not work correctly if given a continuation byte."""

debug_assert(
(b & 0b1100_0000) != 0b1000_0000,
not _is_utf8_continuation_byte(b),
"Function does not work correctly if given a continuation byte.",
)
return Int(count_leading_zeros(~b)) + Int(b < 0b1000_0000)
return Int(count_leading_zeros(~b) | (b < 0b1000_0000).cast[DType.uint8]())


fn _utf8_byte_type(b: SIMD[DType.uint8, _], /) -> __type_of(b):
Expand All @@ -2248,7 +2255,7 @@ fn _utf8_byte_type(b: SIMD[DType.uint8, _], /) -> __type_of(b):
- 3 -> start of 3 byte long sequence.
- 4 -> start of 4 byte long sequence.
"""
return count_leading_zeros(~(b & UInt8(0b1111_0000)))
return count_leading_zeros(~b)


@always_inline
Expand Down
1 change: 1 addition & 0 deletions stdlib/test/collections/string/test_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def test_ord():


def test_chr():
assert_equal("\0", chr(0))
assert_equal("A", chr(65))
assert_equal("a", chr(97))
assert_equal("!", chr(33))
Expand Down

0 comments on commit bdaca0f

Please sign in to comment.