Skip to content

Commit

Permalink
add product
Browse files Browse the repository at this point in the history
  • Loading branch information
thejaminator committed May 21, 2024
1 parent 7204e84 commit 8918709
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool]
[tool.poetry]
name = "slist"
version = "0.3.6"
version = "0.3.7"
homepage = "https://github.com/thejaminator/slist"
description = "A typesafe list with more method chaining!"
authors = ["James Chua <chuajamessh@gmail.com>"]
Expand Down
7 changes: 5 additions & 2 deletions slist/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def one_option(element: Optional[A]) -> Slist[A]:
"""Returns a list with one element, or an empty slist if the element is None
Equal to Slist.one(element).flatten_option()"""
return Slist([element]) if element is not None else Slist()

def any(self, predicate: Callable[[A], bool]) -> bool:
for x in self:
if predicate(x):
Expand All @@ -109,6 +109,9 @@ def filter(self, predicate: Callable[[A], bool]) -> Slist[A]:
def map(self, func: Callable[[A], B]) -> Slist[B]:
return Slist(func(item) for item in self)

def product(self: Sequence[A], other: Sequence[B]) -> Slist[Tuple[A, B]]:
return Slist((a, b) for a in self for b in other)

def map_2(self: Sequence[Tuple[B, C]], func: Callable[[B, C], D]) -> Slist[D]:
return Slist(func(b, c) for b, c in self)

Expand Down Expand Up @@ -451,7 +454,7 @@ def find_last_idx_or_raise(

def take(self, n: int) -> Slist[A]:
return Slist(self[:n])

def take_or_raise(self, n: int) -> Slist[A]:
# raises if we end up having less elements than n
if len(self) < n:
Expand Down
37 changes: 36 additions & 1 deletion tests/test_slist.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,46 @@ def __init__(self, name: str, age: int):
]
)


def test_take_or_raise():
numbers = Slist([1, 2, 3, 4, 5])
assert numbers.take_or_raise(0) == Slist([])
assert numbers.take_or_raise(1) == Slist([1])
assert numbers.take_or_raise(2) == Slist([1, 2])
assert numbers.take_or_raise(5) == Slist([1, 2, 3, 4, 5])
with pytest.raises(ValueError):
numbers.take_or_raise(6)
numbers.take_or_raise(6)


def test_product():
numbers = Slist([1, 2, 3, 4, 5])
# cartesian product
assert numbers.product(numbers) == Slist(
[
(1, 1),
(1, 2),
(1, 3),
(1, 4),
(1, 5),
(2, 1),
(2, 2),
(2, 3),
(2, 4),
(2, 5),
(3, 1),
(3, 2),
(3, 3),
(3, 4),
(3, 5),
(4, 1),
(4, 2),
(4, 3),
(4, 4),
(4, 5),
(5, 1),
(5, 2),
(5, 3),
(5, 4),
(5, 5),
]
)

0 comments on commit 8918709

Please sign in to comment.