Skip to content

Commit

Permalink
Add iterator for LinkedList
Browse files Browse the repository at this point in the history
Signed-off-by: Brian Grenier <grenierb96@gmail.com>
  • Loading branch information
bgreni committed Feb 18, 2025
1 parent 8022690 commit e061f8e
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 0 deletions.
10 changes: 10 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ what we publish.
- Added a `StringSlice.is_codepoint_boundary()` method for querying if a given
byte index is a boundary between encoded UTF-8 codepoints.

- Added an iterator to `LinkedList` ([PR #4005](https://github.com/modular/mojo/pull/4005))
- `LinkedList.__iter__()` to create a forward iterator.
- `LinkedList.__reversed__()` for a backward iterator.

```mojo
var ll = LinkedList[Int](1, 2, 3)
for element in ll:
print(element[])
```

### GPU changes

- `ctx.enqueue_function(compiled_func, ...)` is deprecated:
Expand Down
69 changes: 69 additions & 0 deletions stdlib/src/collections/linked_list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,47 @@ struct Node[
writer.write(self.value)


@value
struct _LinkedListIter[
mut: Bool, //,
ElementType: CollectionElement,
origin: Origin[mut],
forward: Bool = True,
]:
var src: Pointer[LinkedList[ElementType], origin]
var curr: UnsafePointer[Node[ElementType]]
var seen: Int

fn __init__(out self, src: Pointer[LinkedList[ElementType], origin]):
self.src = src

@parameter
if forward:
self.curr = self.src[]._head
else:
self.curr = self.src[]._tail
self.seen = 0

fn __iter__(self) -> Self:
return self

fn __next__(mut self, out p: Pointer[ElementType, origin]):
p = Pointer[ElementType, origin].address_of(self.curr[].value)

@parameter
if forward:
self.curr = self.curr[].next
else:
self.curr = self.curr[].prev
self.seen += 1

fn __has_next__(self) -> Bool:
return Bool(self.curr)

fn __len__(self) -> Int:
return len(self.src[]) - self.seen


struct LinkedList[
ElementType: CollectionElement,
]:
Expand Down Expand Up @@ -692,6 +733,34 @@ struct LinkedList[
"""
return self._size

fn __iter__(self) -> _LinkedListIter[ElementType, __origin_of(self)]:
"""Iterate over elements of the list, returning immutable references.
Time Complexity:
O(1) for iterator construction.
O(n) in len(self) for a complete iteration of the list.
Returns:
An iterator of immutable references to the list elements.
"""
return _LinkedListIter(Pointer.address_of(self))

fn __reversed__(
self,
) -> _LinkedListIter[ElementType, __origin_of(self), forward=False]:
"""Iterate backwards over the list, returning immutable references.
Time Complexity:
O(1) for iterator construction.
O(n) in len(self) for a complete iteration of the list.
Returns:
A reversed iterator of immutable references to the list elements.
"""
return _LinkedListIter[ElementType, __origin_of(self), forward=False](
Pointer.address_of(self)
)

fn __bool__(self) -> Bool:
"""Check if the list is non-empty.
Expand Down
34 changes: 34 additions & 0 deletions stdlib/test/collections/test_linked_list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,39 @@ def test_list_dtor():
assert_equal(g_dtor_count, 1)


def test_iter():
var l = LinkedList[Int](1, 2, 3)
var iter = l.__iter__()
assert_true(iter.__has_next__(), "Expected iter to have next")
assert_equal(len(iter), 3)
assert_equal(iter.__next__()[], 1)
assert_equal(iter.__next__()[], 2)
assert_equal(len(iter), 1)
assert_equal(iter.__next__()[], 3)
assert_equal(len(iter), 0)
assert_false(iter.__has_next__(), "Expected iter to not have next")

var riter = l.__reversed__()
assert_true(riter.__has_next__(), "Expected iter to have next")
assert_equal(len(riter), 3)
assert_equal(riter.__next__()[], 3)
assert_equal(riter.__next__()[], 2)
assert_equal(len(riter), 1)
assert_equal(riter.__next__()[], 1)
assert_equal(len(riter), 0)
assert_false(riter.__has_next__(), "Expected iter to not have next")

var i = 0
for el in l:
assert_equal(el[], l[i])
i += 1

i = 2
for el in l.__reversed__():
assert_equal(el[], l[i])
i -= 1


def main():
test_construction()
test_append()
Expand Down Expand Up @@ -574,3 +607,4 @@ def main():
test_list_dtor()
test_list_insert()
test_list_eq_ne()
test_iter()

0 comments on commit e061f8e

Please sign in to comment.