Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Result filter methods #224

Merged
merged 4 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions expression/core/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,38 @@ def is_ok(self) -> bool:
"""Return `True` if the result is an `Ok` value."""
return self.tag == "ok"

def filter(self, predicate: Callable[[_TSource], bool], default: _TError) -> Result[_TSource, _TError]:
"""Filter result.

Returns the input if the predicate evaluates to true, otherwise
returns the `default`
"""
match self:
case Result(tag="ok", ok=value) if predicate(value):
return self
case Result(tag="error"):
return self
case _:
return Error(default)

def filter_with(
self,
predicate: Callable[[_TSource], bool],
default: Callable[[_TSource], _TError],
) -> Result[_TSource, _TError]:
"""Filter result.

Returns the input if the predicate evaluates to true, otherwise
returns the `default` using the value as input
"""
match self:
case Result(tag="ok", ok=value) if predicate(value):
return self
case Result(tag="ok", ok=value):
return Error(default(value))
case Result():
return self

def dict(self) -> builtins.dict[str, _TSource | _TError | Literal["ok", "error"]]:
"""Return a json serializable representation of the result."""
match self:
Expand Down Expand Up @@ -352,6 +384,11 @@ def map2(
return x.map2(y, mapper)


@curry_flip(1)
def map_error(result: Result[_TSource, _TError], mapper: Callable[[_TError], _TResult]) -> Result[_TSource, _TResult]:
return result.map_error(mapper)


@curry_flip(1)
def bind(
result: Result[_TSource, _TError],
Expand All @@ -374,11 +411,46 @@ def is_error(result: Result[_TSource, _TError]) -> TypeGuard[Result[_TSource, _T
return result.is_error()


@curry_flip(1)
def filter(
result: Result[_TSource, _TError],
predicate: Callable[[_TSource], bool],
default: _TError,
) -> Result[_TSource, _TError]:
return result.filter(predicate, default)


@curry_flip(1)
def filter_with(
result: Result[_TSource, _TError],
predicate: Callable[[_TSource], bool],
default: Callable[[_TSource], _TError],
) -> Result[_TSource, _TError]:
return result.filter_with(predicate, default)


def swap(result: Result[_TSource, _TError]) -> Result[_TError, _TSource]:
"""Swaps the value in the result so an Ok becomes an Error and an Error becomes an Ok."""
return result.swap()


@curry_flip(1)
def or_else(result: Result[_TSource, _TError], other: Result[_TSource, _TError]) -> Result[_TSource, _TError]:
return result.or_else(other)


@curry_flip(1)
def or_else_with(
result: Result[_TSource, _TError],
other: Callable[[_TError], Result[_TSource, _TError]],
) -> Result[_TSource, _TError]:
return result.or_else_with(other)


def merge(result: Result[_TSource, _TSource]) -> _TSource:
return result.merge()


def to_option(result: Result[_TSource, Any]) -> Option[_TSource]:
from expression.core.option import Nothing, Some

Expand Down Expand Up @@ -406,9 +478,17 @@ def of_option_with(value: Option[_TSource], error: Callable[[], _TError]) -> Res
"map",
"bind",
"dict",
"filter",
"filter_with",
"is_ok",
"is_error",
"map2",
"map_error",
"merge",
"to_option",
"of_option",
"of_option_with",
"or_else",
"or_else_with",
"swap",
]
68 changes: 68 additions & 0 deletions tests/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,14 @@ def test_result_error_chained_map(msg: str, y: int):
case _:
assert False

@given(st.text())
def test_map_error(msg: str):
assert Error(msg).map_error(lambda x: f"more {x}") == Error("more " + msg)

@given(st.text())
def test_map_error_piped(msg: str):
assert Error(msg).pipe(result.map_error(lambda x: f"more {x}")) == Error(f"more {msg}")


@given(st.integers(), st.integers()) # type: ignore
def test_result_bind_piped(x: int, y: int):
Expand Down Expand Up @@ -362,6 +370,54 @@ def test_pipeline_error():
assert hn(42) == error


def test_filter_ok_passing_predicate():
xs: Result[int, str] = Ok(42)
ys = xs.filter(lambda x: x > 10, "error")

assert ys == xs


def test_filter_ok_failing_predicate():
xs: Result[int, str] = Ok(5)
ys = xs.filter(lambda x: x > 10, "error")

assert ys == Error("error")


def test_filter_error():
error = Error("original error")
ys = error.filter(lambda x: x > 10, "error")

assert ys == error

def test_filter_piped():
assert Ok(42).pipe(result.filter(lambda x: x > 10, "error")) == Ok(42)


def test_filter_with_ok_passing_predicate():
xs: Result[int, str] = Ok(42)
ys = xs.filter_with(lambda x: x > 10, lambda value: f"error {value}")

assert ys == xs


def test_filter_with_ok_failing_predicate():
xs: Result[int, str] = Ok(5)
ys = xs.filter_with(lambda x: x > 10, lambda value: f"error {value}")

assert ys == Error("error 5")


def test_filter_with_error():
error = Error("original error")
ys = error.filter_with(lambda x: x > 10, lambda value: f"error {value}")

assert ys == error

def test_filter_with_piped():
assert Ok(42).pipe(result.filter_with(lambda x: x > 10, lambda value: f"error {value}")) == Ok(42)


class MyError(BaseModel):
message: str

Expand Down Expand Up @@ -525,6 +581,8 @@ def test_result_swap_with_error():
xs = result.swap(error)
assert xs == Ok(1)

def test_swap_piped():
assert Ok(42).pipe(result.swap) == Error(42)

def test_ok_or_else_ok():
xs: Result[int, str] = Ok(42)
Expand All @@ -549,6 +607,8 @@ def test_error_or_else_error():
ys = xs.or_else(Error("new error"))
assert ys == Error("new error")

def test_or_else_piped():
assert Ok(42).pipe(result.or_else(Ok(0))) == Ok(42)

def test_ok_or_else_with_ok():
xs: Result[str, str] = Ok("good")
Expand All @@ -574,6 +634,10 @@ def test_error_or_else_with_error():
assert ys == Error("new error from original error")


def test_or_else_with_piped():
assert Ok(42).pipe(result.or_else_with(lambda _: Ok(0))) == Ok(42)


def test_merge_ok():
assert Result.Ok(42).merge() == 42

Expand Down Expand Up @@ -601,3 +665,7 @@ class Child2(Parent):
def test_merge_subclasses():
xs: Result[Parent, Parent] = Result.Ok(Child1(x=42))
assert xs.merge() == Child1(x=42)


def test_merge_piped():
assert Ok(42).pipe(result.merge) == 42
Loading