Skip to content

Commit

Permalink
Fix pyright errors
Browse files Browse the repository at this point in the history
  • Loading branch information
dbrattli committed Jan 13, 2024
1 parent a3c8b83 commit 1e6c329
Show file tree
Hide file tree
Showing 14 changed files with 145 additions and 115 deletions.
110 changes: 42 additions & 68 deletions aioreactive/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,11 @@
from __future__ import annotations

from collections.abc import AsyncIterable, Awaitable, Callable, Iterable
from typing import (
Any,
Optional,
TypeVar,
Union,
overload,
)
from typing import Any, TypeVar, TypeVarTuple

from expression import Option, curry_flip, pipe
from expression.system.disposable import AsyncDisposable
from typing_extensions import Unpack

from .observables import AsyncAnonymousObservable, AsyncIterableObservable
from .observers import (
Expand All @@ -42,6 +37,7 @@
_B = TypeVar("_B")
_C = TypeVar("_C")
_D = TypeVar("_D")
_V = TypeVarTuple("_V")
_TSource = TypeVar("_TSource")
_TResult = TypeVar("_TResult")
_TOther = TypeVar("_TOther")
Expand All @@ -64,17 +60,19 @@ def __init__(self, source: AsyncObservable[_TSource]) -> None:

async def subscribe_async(
self,
send: Optional[Union[SendAsync[_TSource], AsyncObserver[_TSource]]] = None,
throw: Optional[ThrowAsync] = None,
close: Optional[CloseAsync] = None,
send: SendAsync[_TSource] | AsyncObserver[_TSource] | None = None,
throw: ThrowAsync | None = None,
close: CloseAsync | None = None,
) -> AsyncDisposable:
"""Subscribe to the async observable.
Uses the given observer to subscribe asynchronously to the async
observable.
Args:
observer: The async observer to subscribe.
send: The async observer or the send function to subscribe.
throw: The throw function to subscribe.
close: The close function to subscribe.
Returns:
An async disposable that can be used to dispose the
Expand Down Expand Up @@ -356,7 +354,9 @@ def skip_last(self, count: int) -> AsyncRx[_TSource]:
"""
return AsyncRx(pipe(self, skip_last(count)))

def starfilter(self: AsyncObservable[Any], predicate: Callable[..., bool]) -> AsyncRx[Any]:
def starfilter(
self: AsyncObservable[tuple[Unpack[_V]]], predicate: Callable[[Unpack[_V]], bool]
) -> AsyncRx[_TSource]:
"""Filter and spread the arguments to the predicate.
Filters the elements of an observable sequence based on a predicate.
Expand All @@ -368,7 +368,7 @@ def starfilter(self: AsyncObservable[Any], predicate: Callable[..., bool]) -> As
xs = pipe(self, starfilter(predicate))
return AsyncRx.create(xs)

def starmap(self: AsyncRx[tuple[Any, ...]], mapper: Callable[..., _TResult]) -> AsyncRx[_TResult]:
def starmap(self: AsyncRx[tuple[Unpack[_V]]], mapper: Callable[[Unpack[_V]], _TResult]) -> AsyncRx[_TResult]:
"""Map and spread the arguments to the mapper.
Returns:
Expand All @@ -384,7 +384,7 @@ def take(self, count: int) -> AsyncObservable[_TSource]:
an observable sequence.
Args:
count Number of elements to take.
count: Number of elements to take.
Returns:
An observable sequence that contains the specified number of
Expand Down Expand Up @@ -530,7 +530,9 @@ def catch(
def concat(
other: AsyncObservable[_TSource],
) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TSource]]:
"""Concatenates an observable sequence with another observable
"""Concatenate observables.
Concatenates an observable sequence with another observable
sequence.
"""

Expand All @@ -552,7 +554,9 @@ def concat_seq(


def defer(factory: Callable[[], AsyncObservable[_TSource]]) -> AsyncObservable[_TSource]:
"""Returns an observable sequence that invokes the specified factory
"""Defer observable.
Returns an observable sequence that invokes the specified factory
function whenever a new observer subscribes.
"""
from .create import defer
Expand Down Expand Up @@ -676,10 +680,8 @@ def flat_map_async(
an observable sequence and merges the resulting observable sequences
back into one observable sequence.
Args:
mapperCallable ([type]): [description]
Awaitable ([type]): [description]
mapper: A transform function to apply to each element or an
Returns:
Stream[TSource, TResult]: [description]
Expand Down Expand Up @@ -726,7 +728,9 @@ def from_async_iterable(iter: AsyncIterable[_TSource]) -> AsyncObservable[_TSour


def interval(seconds: float, period: int) -> AsyncObservable[int]:
"""Returns an observable sequence that triggers the increasing
"""Observable interval.
Returns an observable sequence that triggers the increasing
sequence starting with 0 after the given msecs, and the after each
period.
"""
Expand Down Expand Up @@ -758,7 +762,9 @@ def map_async(
def mapi_async(
mapper: Callable[[_TSource, int], Awaitable[_TResult]],
) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TResult]]:
"""Returns an observable sequence whose elements are the result of
"""Map indexed asynchronously.
Returns an observable sequence whose elements are the result of
invoking the async mapper function by incorporating the element's
index on each element of the source.
"""
Expand All @@ -770,7 +776,9 @@ def mapi_async(
def mapi(
mapper: Callable[[_TSource, int], _TResult],
) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TResult]]:
"""Returns an observable sequence whose elements are the result of
"""Map indexed.
Returns an observable sequence whose elements are the result of
invoking the mapper function and incorporating the element's index
on each element of the source.
"""
Expand Down Expand Up @@ -842,7 +850,7 @@ def scan(
accumulator: Callable[[_TResult, _TSource], _TResult],
initial: _TResult,
) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TResult]]:
"""The scan operator
"""The scan operator.
This operator runs the accumulator for every value from the source with the current state. After every run, the new
computed value is returned.
Expand Down Expand Up @@ -918,7 +926,9 @@ def skip(
def skip_last(
count: int,
) -> Callable[[AsyncObservable[_TSource]], AsyncObservable[_TSource]]:
"""Bypasses a specified number of elements at the end of an
"""Skip the last items of the observable sequence.
Bypasses a specified number of elements at the end of an
observable sequence.
This operator accumulates a queue with a length enough to store
Expand All @@ -939,28 +949,9 @@ def skip_last(
return skip_last(count)


@overload
def starfilter(
predicate: Callable[[tuple[_A, _B]], bool],
) -> Callable[[AsyncObservable[tuple[_A, _B]]], AsyncObservable[tuple[_A, _B]]]:
...


@overload
def starfilter(
predicate: Callable[[tuple[_A, _B, _C]], bool],
) -> Callable[[AsyncObservable[tuple[_A, _B, _C]]], AsyncObservable[tuple[_A, _B, _C]]]:
...


@overload
def starfilter(
predicate: Callable[[tuple[_A, _B, _C, _D]], bool],
) -> Callable[[AsyncObservable[tuple[_A, _B, _C, _D]]], AsyncObservable[tuple[_A, _B, _C, _D]]]:
...


def starfilter(predicate: Callable[..., bool]) -> Callable[[AsyncObservable[Any]], AsyncObservable[Any]]:
predicate: Callable[[Unpack[_V]], bool],
) -> Callable[[AsyncObservable[tuple[Unpack[_V]]]], AsyncObservable[Any]]:
"""Filter and spread the arguments to the predicate.
Filters the elements of an observable sequence based on a predicate.
Expand All @@ -974,28 +965,9 @@ def starfilter(predicate: Callable[..., bool]) -> Callable[[AsyncObservable[Any]
return starfilter(predicate)


@overload
def starmap(
mapper: Callable[[_A, _B], _TResult],
) -> Callable[[AsyncObservable[tuple[_A, _B]]], AsyncObservable[_TResult]]:
...


@overload
def starmap(
mapper: Callable[[_A, _B, _C], _TResult],
) -> Callable[[AsyncObservable[tuple[_A, _B, _C]]], AsyncObservable[_TResult]]:
...


@overload
def starmap(
mapper: Callable[[_A, _B, _C, _D], _TResult],
) -> Callable[[AsyncObservable[tuple[_A, _B, _C, _D]]], AsyncObservable[_TResult]]:
...


def starmap(mapper: Callable[..., _TResult]) -> Callable[[AsyncObservable[Any]], AsyncObservable[_TResult]]:
mapper: Callable[[Unpack[_V]], _TResult],
) -> Callable[[AsyncObservable[tuple[Unpack[_V]]]], AsyncObservable[_TResult]]:
"""Map and spread the arguments to the mapper.
Returns an observable sequence whose elements are the result of
Expand All @@ -1021,7 +993,7 @@ def take(
an observable sequence.
Args:
count Number of elements to take.
count: Number of elements to take.
Returns:
An observable sequence that contains the specified number of
Expand Down Expand Up @@ -1072,7 +1044,9 @@ def take_until(


def timer(due_time: float) -> AsyncObservable[int]:
"""Returns an observable sequence that triggers the value 0
"""Observable timer.
Returns an observable sequence that triggers the value 0
after the given duetime in milliseconds.
"""
from .create import timer
Expand Down
9 changes: 6 additions & 3 deletions aioreactive/create.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import logging
from asyncio import Future
from asyncio import Future, Task
from collections.abc import AsyncIterable, Awaitable, Callable, Iterable
from typing import Any, TypeVar

Expand Down Expand Up @@ -77,9 +77,9 @@ async def worker(obv: AsyncObserver[TSource], _: CancellationToken) -> None:


def of_async_iterable(iterable: AsyncIterable[TSource]) -> AsyncObservable[TSource]:
async def subscribe_async(observer: AsyncObserver[TSource]) -> AsyncDisposable:
task: Future[None] | None = None
tasks: set[Task[Any]] = set()

async def subscribe_async(observer: AsyncObserver[TSource]) -> AsyncDisposable:
async def cancel() -> None:
if task:
task.cancel()
Expand All @@ -95,12 +95,15 @@ async def worker() -> None:
return

await observer.aclose()
tasks.remove(task)

try:
task = asyncio.create_task(worker())
except Exception as ex:
log.debug("FromIterable:worker(), Exception: %s" % ex)
await observer.athrow(ex)
else:
tasks.add(task)
return sub

return AsyncAnonymousObservable(subscribe_async)
Expand Down
2 changes: 1 addition & 1 deletion aioreactive/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ async def get_latest() -> Notification[_TSource]:
latest = await get_latest()
return TailCall[Notification[_TSource]](latest)

await message_loop(OnCompleted) # Use as sentinel value as it will not match any OnNext value
await message_loop(OnCompleted())

agent = MailboxProcessor.start(worker)

Expand Down
5 changes: 5 additions & 0 deletions aioreactive/iterable/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Iterable module.
This module contains functions to create async observables from
iterables.
"""
2 changes: 1 addition & 1 deletion aioreactive/iterable/to_async_observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def to_async_observable(source: AsyncIterable[TSource]) -> AsyncObservable[TSour
"""Convert to async observable.
Keyword Arguments:
source -- Async iterable to convert to async observable.
source: Async iterable to convert to async observable.
Returns async observable
"""
Expand Down
23 changes: 14 additions & 9 deletions aioreactive/leave.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from collections.abc import AsyncIterable, Iterable
from typing import TypeVar
from collections.abc import AsyncIterable
from typing import Any, TypeVar

import reactivex
from expression.system.disposable import AsyncDisposable
Expand Down Expand Up @@ -31,9 +31,10 @@ def to_async_iterable(source: AsyncObservable[_TSource]) -> AsyncIterable[_TSour

def to_observable(source: AsyncObservable[_TSource]) -> Observable[_TSource]:
"""Convert async observable to observable."""
tasks: set[asyncio.Task[Any]] = set()

def subscribe(obv: ObserverBase[_TSource], scheduler: Optional[SchedulerBase] = None) -> DisposableBase:
subscription: Optional[AsyncDisposable] = None
def subscribe(obv: ObserverBase[_TSource], scheduler: SchedulerBase | None = None) -> DisposableBase:
subscription: AsyncDisposable | None = None

async def start() -> None:
nonlocal subscription
Expand All @@ -48,18 +49,22 @@ async def aclose() -> None:
obv.on_completed()

subscription = await source.subscribe_async(AsyncAnonymousObserver(asend, athrow, aclose))
tasks.remove(task)

asyncio.create_task(start())
task = asyncio.create_task(start())
tasks.add(task)
task.add_done_callback(lambda _: tasks.remove(task))

def dispose() -> None:
if subscription:
asyncio.create_task(subscription.dispose_async())
task = asyncio.create_task(subscription.dispose_async())
tasks.add(task)

return Disposable(dispose)

return reactivex.create(subscribe)


def to_iterable(source: AsyncObservable[_TSource]) -> Iterable[_TSource]:
"""Convert async observable to iterable."""
return to_observable(source).to_iterable()
# def to_iterable(source: AsyncObservable[_TSource]) -> Iterable[_TSource]:
# """Convert async observable to iterable."""
# return to_observable(source).to_iterable()
2 changes: 1 addition & 1 deletion aioreactive/testing/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def aclose(self) -> None:
log.debug("AsyncAnonymousObserver:aclose()")

time = self.time()
self._values.append((time, OnCompleted))
self._values.append((time, OnCompleted()))

await self._close()
await super().aclose()
Expand Down
Loading

0 comments on commit 1e6c329

Please sign in to comment.