From 1f05a6b6b853a2fa400982f6d5dbf7f4fdc52fd4 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Tue, 10 Dec 2024 19:53:00 -0300 Subject: [PATCH 1/5] Add List[Scalar[D]] append SIMD and Span[Scalar[D]] Signed-off-by: martinvuyk --- stdlib/src/base64/_b64encode.mojo | 28 +------------ stdlib/src/collections/list.mojo | 61 ++++++++++++++++++++++++++--- stdlib/src/utils/inline_string.mojo | 25 +++--------- stdlib/test/python/my_module.py | 3 +- 4 files changed, 66 insertions(+), 51 deletions(-) diff --git a/stdlib/src/base64/_b64encode.mojo b/stdlib/src/base64/_b64encode.mojo index 74b8c31501..02329029fe 100644 --- a/stdlib/src/base64/_b64encode.mojo +++ b/stdlib/src/base64/_b64encode.mojo @@ -195,21 +195,6 @@ fn load_incomplete_simd[ return result -fn store_incomplete_simd[ - simd_width: Int -]( - pointer: UnsafePointer[UInt8], - owned simd_vector: SIMD[DType.uint8, simd_width], - nb_of_elements_to_store: Int, -): - var tmp_buffer_pointer = UnsafePointer.address_of(simd_vector).bitcast[ - UInt8 - ]() - - memcpy(dest=pointer, src=tmp_buffer_pointer, count=nb_of_elements_to_store) - _ = simd_vector # We make it live long enough - - # TODO: Use Span instead of List as input when Span is easier to use @no_inline fn b64encode_with_buffers( @@ -229,11 +214,7 @@ fn b64encode_with_buffers( var input_vector = start_of_input_chunk.load[width=simd_width]() - result_vector = _to_b64_ascii(input_vector) - - (result.unsafe_ptr() + len(result)).store(result_vector) - - result.size += simd_width + result.append(_to_b64_ascii(input_vector)) input_index += input_simd_width # We handle the last 0, 1 or 2 chunks @@ -268,12 +249,7 @@ fn b64encode_with_buffers( ]( nb_of_elements_to_load ) - store_incomplete_simd( - result.unsafe_ptr() + len(result), - result_vector_with_equals, - nb_of_elements_to_store, - ) - result.size += nb_of_elements_to_store + result.append(result_vector_with_equals, nb_of_elements_to_store) input_index += input_simd_width diff --git a/stdlib/src/collections/list.mojo b/stdlib/src/collections/list.mojo index bcfca0c2fa..2c51c81cd0 100644 --- a/stdlib/src/collections/list.mojo +++ b/stdlib/src/collections/list.mojo @@ -493,15 +493,66 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False]( self.capacity = new_capacity fn append(mut self, owned value: T): - """Appends a value to this list. + """Appends a value to this list. If there is no capacity left, resizes + to twice the current capacity. Except for 0 capacity where it sets 1. Args: value: The value to append. """ - if self.size >= self.capacity: - self._realloc(max(1, self.capacity * 2)) - (self.data + self.size).init_pointee_move(value^) - self.size += 1 + if len(self) >= self.capacity: + self._realloc(self.capacity * 2 + int(self.capacity == 0)) + (self.data + len(self)).init_pointee_move(value^) + self._len += 1 + + fn append[ + D: DType, // + ](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _]): + """Appends a vector to this list. If there is no capacity left, resizes + to `len(self) + value.size`. + + Parameters: + D: The DType. + + Args: + value: The value to append. + """ + self.reserve(len(self) + value.size) + (self.data + len(self)).store(value) + self._len += value.size + + fn append[ + D: DType, // + ](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _], count: Int): + """Appends a vector to this list. If there is no capacity left, resizes + to `len(self) + count`. + + Parameters: + D: The DType. + + Args: + value: The value to append. + count: The ammount of items to append. + """ + self.reserve(len(self) + count) + var v_ptr = UnsafePointer.address_of(value).bitcast[Scalar[D]]() + memcpy(self.data + len(self), v_ptr, count) + self._len += count + + fn append[ + D: DType, // + ](mut self: List[Scalar[D], *_, **_], value: Span[Scalar[D]]): + """Appends a Span to this list. If there is no capacity left, resizes + to `len(self) + len(value)`. + + Parameters: + D: The DType. + + Args: + value: The value to append. + """ + self.reserve(len(self) + len(value)) + memcpy(self.data + len(self), value.unsafe_ptr(), len(value)) + self._len += len(value) fn insert(mut self, i: Int, owned value: T): """Inserts a value to the list at the given index. diff --git a/stdlib/src/utils/inline_string.mojo b/stdlib/src/utils/inline_string.mojo index 8c6cfb3166..2bbda77571 100644 --- a/stdlib/src/utils/inline_string.mojo +++ b/stdlib/src/utils/inline_string.mojo @@ -147,28 +147,15 @@ struct InlineString(Sized, Stringable, CollectionElement, CollectionElementNew): # Begin by heap allocating enough space to store the combined # string. var buffer = List[UInt8](capacity=total_len) - # Copy the bytes from the current small string layout - memcpy( - dest=buffer.unsafe_ptr(), - src=self._storage[_FixedString[Self.SMALL_CAP]].unsafe_ptr(), - count=len(self), + var span_self = Span[Byte, __origin_of(self)]( + ptr=self._storage[_FixedString[Self.SMALL_CAP]].unsafe_ptr(), + length=len(self), ) - + buffer.append(span_self) # Copy the bytes from the additional string. - memcpy( - dest=buffer.unsafe_ptr() + len(self), - src=str_slice.unsafe_ptr(), - count=str_slice.byte_length(), - ) - - # Record that we've initialized `total_len` count of elements - # in `buffer` - buffer.size = total_len - - # Add the NUL byte - buffer.append(0) - + buffer.append(str_slice.as_bytes()) + buffer.append(0) # Add the NUL byte self._storage = Self.Layout(String(buffer^)) fn __add__(self, other: StringLiteral) -> Self: diff --git a/stdlib/test/python/my_module.py b/stdlib/test/python/my_module.py index 8147b0a382..c78c39556e 100644 --- a/stdlib/test/python/my_module.py +++ b/stdlib/test/python/my_module.py @@ -25,7 +25,8 @@ def __init__(self, bar): class AbstractPerson(ABC): @abstractmethod - def method(self): ... + def method(self): + ... def my_function(name): From ac0d2528e54c93ebd373243317e91d05818b0a3b Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Tue, 10 Dec 2024 19:55:53 -0300 Subject: [PATCH 2/5] fix use size instead of _len Signed-off-by: martinvuyk --- stdlib/src/collections/list.mojo | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stdlib/src/collections/list.mojo b/stdlib/src/collections/list.mojo index 2c51c81cd0..a6104796c4 100644 --- a/stdlib/src/collections/list.mojo +++ b/stdlib/src/collections/list.mojo @@ -502,7 +502,7 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False]( if len(self) >= self.capacity: self._realloc(self.capacity * 2 + int(self.capacity == 0)) (self.data + len(self)).init_pointee_move(value^) - self._len += 1 + self.size += 1 fn append[ D: DType, // @@ -518,7 +518,7 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False]( """ self.reserve(len(self) + value.size) (self.data + len(self)).store(value) - self._len += value.size + self.size += value.size fn append[ D: DType, // @@ -536,7 +536,7 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False]( self.reserve(len(self) + count) var v_ptr = UnsafePointer.address_of(value).bitcast[Scalar[D]]() memcpy(self.data + len(self), v_ptr, count) - self._len += count + self.size += count fn append[ D: DType, // @@ -552,7 +552,7 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False]( """ self.reserve(len(self) + len(value)) memcpy(self.data + len(self), value.unsafe_ptr(), len(value)) - self._len += len(value) + self.size += len(value) fn insert(mut self, i: Int, owned value: T): """Inserts a value to the list at the given index. From 922a18e691701d5c4e2bb53af1e3c251c464d934 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Tue, 10 Dec 2024 21:18:57 -0300 Subject: [PATCH 3/5] apply suggestions by @ConnorGray Signed-off-by: martinvuyk --- stdlib/src/collections/list.mojo | 41 +++++++++++++++++++------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/stdlib/src/collections/list.mojo b/stdlib/src/collections/list.mojo index a6104796c4..8de2b2d2fd 100644 --- a/stdlib/src/collections/list.mojo +++ b/stdlib/src/collections/list.mojo @@ -493,38 +493,42 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False]( self.capacity = new_capacity fn append(mut self, owned value: T): - """Appends a value to this list. If there is no capacity left, resizes - to twice the current capacity. Except for 0 capacity where it sets 1. + """Appends a value to this list. Args: value: The value to append. + + Notes: + If there is no capacity left, resizes to twice the current capacity. + Except for 0 capacity where it sets 1. """ - if len(self) >= self.capacity: + if self.size >= self.capacity: self._realloc(self.capacity * 2 + int(self.capacity == 0)) - (self.data + len(self)).init_pointee_move(value^) + (self.data + self.size).init_pointee_move(value^) self.size += 1 fn append[ D: DType, // ](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _]): - """Appends a vector to this list. If there is no capacity left, resizes - to `len(self) + value.size`. + """Appends a vector to this list. Parameters: D: The DType. Args: value: The value to append. + + Notes: + If there is no capacity left, resizes to `len(self) + value.size`. """ - self.reserve(len(self) + value.size) - (self.data + len(self)).store(value) + self.reserve(self.size + value.size) + (self.data + self.size).store(value) self.size += value.size fn append[ D: DType, // ](mut self: List[Scalar[D], *_, **_], value: SIMD[D, _], count: Int): - """Appends a vector to this list. If there is no capacity left, resizes - to `len(self) + count`. + """Appends a vector to this list. Parameters: D: The DType. @@ -532,26 +536,31 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False]( Args: value: The value to append. count: The ammount of items to append. + + Notes: + If there is no capacity left, resizes to `len(self) + count`. """ - self.reserve(len(self) + count) + self.reserve(self.size + count) var v_ptr = UnsafePointer.address_of(value).bitcast[Scalar[D]]() - memcpy(self.data + len(self), v_ptr, count) + memcpy(self.data + self.size, v_ptr, count) self.size += count fn append[ D: DType, // ](mut self: List[Scalar[D], *_, **_], value: Span[Scalar[D]]): - """Appends a Span to this list. If there is no capacity left, resizes - to `len(self) + len(value)`. + """Appends a Span to this list. Parameters: D: The DType. Args: value: The value to append. + + Notes: + If there is no capacity left, resizes to `len(self) + len(value)`. """ - self.reserve(len(self) + len(value)) - memcpy(self.data + len(self), value.unsafe_ptr(), len(value)) + self.reserve(self.size + len(value)) + memcpy(self.data + self.size, value.unsafe_ptr(), len(value)) self.size += len(value) fn insert(mut self, i: Int, owned value: T): From b0b2da9326de7937caad605cccca89a0b43f0592 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Tue, 10 Dec 2024 22:35:14 -0300 Subject: [PATCH 4/5] add tests Signed-off-by: martinvuyk --- stdlib/src/collections/list.mojo | 1 + stdlib/src/testing/testing.mojo | 39 +++++++++++++++++++- stdlib/test/collections/test_list.mojo | 50 +++++++++++++------------- 3 files changed, 65 insertions(+), 25 deletions(-) diff --git a/stdlib/src/collections/list.mojo b/stdlib/src/collections/list.mojo index 8de2b2d2fd..6a3850da91 100644 --- a/stdlib/src/collections/list.mojo +++ b/stdlib/src/collections/list.mojo @@ -540,6 +540,7 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False]( Notes: If there is no capacity left, resizes to `len(self) + count`. """ + debug_assert(count <= value.size, "count must be <= value.size") self.reserve(self.size + count) var v_ptr = UnsafePointer.address_of(value).bitcast[Scalar[D]]() memcpy(self.data + self.size, v_ptr, count) diff --git a/stdlib/src/testing/testing.mojo b/stdlib/src/testing/testing.mojo index 20173be736..ef39769ff8 100644 --- a/stdlib/src/testing/testing.mojo +++ b/stdlib/src/testing/testing.mojo @@ -32,7 +32,7 @@ def main(): """ from collections import Optional from math import isclose - +from memory import memcmp from builtin._location import __call_location, _SourceLocation # ===----------------------------------------------------------------------=== # @@ -236,6 +236,43 @@ fn assert_equal[ ) +@always_inline +fn assert_equal[ + D: DType +]( + lhs: List[Scalar[D]], + rhs: List[Scalar[D]], + msg: String = "", + *, + location: Optional[_SourceLocation] = None, +) raises: + """Asserts that two lists are equal. + + Parameters: + D: A DType. + + Args: + lhs: The left-hand side list. + rhs: The right-hand side list. + msg: The message to be printed if the assertion fails. + location: The location of the error (default to the `__call_location`). + + Raises: + An Error with the provided message if assert fails and `None` otherwise. + """ + var length = len(lhs) + if ( + length != len(rhs) + or memcmp(lhs.unsafe_ptr(), rhs.unsafe_ptr(), length) != 0 + ): + raise _assert_cmp_error["`left == right` comparison"]( + lhs.__str__(), + rhs.__str__(), + msg=msg, + loc=location.or_else(__call_location()), + ) + + @always_inline fn assert_not_equal[ T: Testable diff --git a/stdlib/test/collections/test_list.mojo b/stdlib/test/collections/test_list.mojo index 56dab6510b..805a38ef39 100644 --- a/stdlib/test/collections/test_list.mojo +++ b/stdlib/test/collections/test_list.mojo @@ -437,32 +437,33 @@ def test_list_index(): _ = test_list_b.index(20, start=4, stop=5) -def test_list_extend(): - # - # Test extending the list [1, 2, 3] with itself - # +def test_list_append(): + items = List[UInt32]() + items.append(1) + items.append(2) + items.append(3) + assert_equal(items, List[UInt32](1, 2, 3)) + + # append span + copy = items + items.append(Span(copy)) + assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3)) + + # whole SIMD + items = List[UInt32](1, 2, 3) + items.append(SIMD[DType.uint32, 4](1, 2, 3, 4)) + assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3, 4)) + # part of SIMD + items = List[UInt32](1, 2, 3) + items.append(SIMD[DType.uint32, 4](1, 2, 3, 4), 3) + assert_equal(items, List[UInt32](1, 2, 3, 1, 2, 3)) - vec = List[Int]() - vec.append(1) - vec.append(2) - vec.append(3) - assert_equal(len(vec), 3) - assert_equal(vec[0], 1) - assert_equal(vec[1], 2) - assert_equal(vec[2], 3) - - var copy = vec - vec.extend(copy) - - # vec == [1, 2, 3, 1, 2, 3] - assert_equal(len(vec), 6) - assert_equal(vec[0], 1) - assert_equal(vec[1], 2) - assert_equal(vec[2], 3) - assert_equal(vec[3], 1) - assert_equal(vec[4], 2) - assert_equal(vec[5], 3) +def test_list_extend(): + items = List[Int](1, 2, 3) + copy = items + items.extend(copy) + assert_equal(items, List[Int](1, 2, 3, 1, 2, 3)) def test_list_extend_non_trivial(): @@ -952,6 +953,7 @@ def main(): test_list_reverse_move_count() test_list_insert() test_list_index() + test_list_append() test_list_extend() test_list_extend_non_trivial() test_list_explicit_copy() From 7cfe7e0b07fd412310f131c7fea238f312a637e2 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Sun, 15 Dec 2024 11:09:20 -0300 Subject: [PATCH 5/5] add a maybe better optimization, since current might not be as good (pointed out by @soraros) Signed-off-by: martinvuyk --- stdlib/src/collections/list.mojo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/src/collections/list.mojo b/stdlib/src/collections/list.mojo index 7c26272986..efcf159dd8 100644 --- a/stdlib/src/collections/list.mojo +++ b/stdlib/src/collections/list.mojo @@ -502,7 +502,7 @@ struct List[T: CollectionElement, hint_trivial_type: Bool = False]( Except for 0 capacity where it sets 1. """ if self.size >= self.capacity: - self._realloc(self.capacity * 2 + int(self.capacity == 0)) + self._realloc(self.capacity * 2 | int(self.capacity == 0)) (self.data + self.size).init_pointee_move(value^) self.size += 1