From 8918709cebb73e10fda070f0a9ce9659314de276 Mon Sep 17 00:00:00 2001 From: James Chua Date: Wed, 22 May 2024 00:43:11 +0800 Subject: [PATCH] add product --- pyproject.toml | 2 +- slist/__init__.py | 7 +++++-- tests/test_slist.py | 37 ++++++++++++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d1094a1..d8b8006 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool] [tool.poetry] name = "slist" -version = "0.3.6" +version = "0.3.7" homepage = "https://github.com/thejaminator/slist" description = "A typesafe list with more method chaining!" authors = ["James Chua "] diff --git a/slist/__init__.py b/slist/__init__.py index 6872ed5..75a801a 100644 --- a/slist/__init__.py +++ b/slist/__init__.py @@ -90,7 +90,7 @@ def one_option(element: Optional[A]) -> Slist[A]: """Returns a list with one element, or an empty slist if the element is None Equal to Slist.one(element).flatten_option()""" return Slist([element]) if element is not None else Slist() - + def any(self, predicate: Callable[[A], bool]) -> bool: for x in self: if predicate(x): @@ -109,6 +109,9 @@ def filter(self, predicate: Callable[[A], bool]) -> Slist[A]: def map(self, func: Callable[[A], B]) -> Slist[B]: return Slist(func(item) for item in self) + def product(self: Sequence[A], other: Sequence[B]) -> Slist[Tuple[A, B]]: + return Slist((a, b) for a in self for b in other) + def map_2(self: Sequence[Tuple[B, C]], func: Callable[[B, C], D]) -> Slist[D]: return Slist(func(b, c) for b, c in self) @@ -451,7 +454,7 @@ def find_last_idx_or_raise( def take(self, n: int) -> Slist[A]: return Slist(self[:n]) - + def take_or_raise(self, n: int) -> Slist[A]: # raises if we end up having less elements than n if len(self) < n: diff --git a/tests/test_slist.py b/tests/test_slist.py index 4685c0d..6d422f6 100644 --- a/tests/test_slist.py +++ b/tests/test_slist.py @@ -242,6 +242,7 @@ def __init__(self, name: str, age: int): ] ) + def test_take_or_raise(): numbers = Slist([1, 2, 3, 4, 5]) assert numbers.take_or_raise(0) == Slist([]) @@ -249,4 +250,38 @@ def test_take_or_raise(): assert numbers.take_or_raise(2) == Slist([1, 2]) assert numbers.take_or_raise(5) == Slist([1, 2, 3, 4, 5]) with pytest.raises(ValueError): - numbers.take_or_raise(6) \ No newline at end of file + numbers.take_or_raise(6) + + +def test_product(): + numbers = Slist([1, 2, 3, 4, 5]) + # cartesian product + assert numbers.product(numbers) == Slist( + [ + (1, 1), + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (2, 1), + (2, 2), + (2, 3), + (2, 4), + (2, 5), + (3, 1), + (3, 2), + (3, 3), + (3, 4), + (3, 5), + (4, 1), + (4, 2), + (4, 3), + (4, 4), + (4, 5), + (5, 1), + (5, 2), + (5, 3), + (5, 4), + (5, 5), + ] + )