diff --git a/src/lib.rs b/src/lib.rs index 9e2d678f..47cdd3f5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -108,20 +108,19 @@ impl RustNotify { }) } - pub fn watch(&self, py: Python, debounce_ms: u64, step_ms: u64, cancel_event: PyObject) -> PyResult { - let event_not_none = !cancel_event.is_none(py); + pub fn watch(&self, py: Python, debounce_ms: u64, step_ms: u64, stop_event: PyObject) -> PyResult { + let event_not_none = !stop_event.is_none(py); let mut max_time: Option = None; let step_time = Duration::from_millis(step_ms); let mut last_size: usize = 0; - let none: Option = None; loop { py.allow_threads(|| sleep(step_time)); match py.check_signals() { Ok(_) => (), Err(_) => { self.clear(); - return Ok(none.to_object(py)); + return Ok("signalled".to_object(py)); } }; @@ -130,9 +129,9 @@ impl RustNotify { return Err(WatchfilesRustInternalError::new_err(error.clone())); } - if event_not_none && cancel_event.getattr(py, "is_set")?.call0(py)?.is_true(py)? { + if event_not_none && stop_event.getattr(py, "is_set")?.call0(py)?.is_true(py)? { self.clear(); - return Ok(none.to_object(py)); + return Ok("stopped".to_object(py)); } let size = self.changes.lock().unwrap().len(); diff --git a/tests/conftest.py b/tests/conftest.py index a9c67d91..0d324645 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,7 +73,7 @@ def watch(self, debounce_ms: int, step_ms: int, cancel_event): try: change = next(self.iter_changes) except StopIteration: - return None + return 'signalled' else: self.watch_count += 1 return change diff --git a/tests/test_watch.py b/tests/test_watch.py index 73a186ef..7b409b04 100644 --- a/tests/test_watch.py +++ b/tests/test_watch.py @@ -1,4 +1,5 @@ import sys +import threading from contextlib import contextmanager from pathlib import Path from time import sleep @@ -23,6 +24,16 @@ def test_watch(tmp_path: Path, write_soon): assert changes == {(Change.added, str((tmp_path / 'foo.txt')))} +def test_wait_stop_event(tmp_path: Path, write_soon): + sleep(0.1) + write_soon(tmp_path / 'foo.txt') + + stop_event = threading.Event() + for changes in watch(tmp_path, watch_filter=None, stop_event=stop_event): + assert changes == {(Change.added, str((tmp_path / 'foo.txt')))} + stop_event.set() + + async def test_awatch(tmp_path: Path, write_soon): sleep(0.1) write_soon(tmp_path / 'foo.txt') @@ -31,7 +42,7 @@ async def test_awatch(tmp_path: Path, write_soon): break -async def test_await_stop(tmp_path: Path, write_soon): +async def test_await_stop_event(tmp_path: Path, write_soon): sleep(0.1) write_soon(tmp_path / 'foo.txt') stop_event = anyio.Event() diff --git a/watchfiles/_rust_notify.pyi b/watchfiles/_rust_notify.pyi index dd3435a5..bcc7591d 100644 --- a/watchfiles/_rust_notify.pyi +++ b/watchfiles/_rust_notify.pyi @@ -1,4 +1,4 @@ -from typing import List, Optional, Protocol, Set, Tuple +from typing import List, Literal, Optional, Protocol, Set, Tuple, Union __all__ = 'RustNotify', 'WatchfilesRustInternalError' @@ -25,8 +25,8 @@ class RustNotify: self, debounce_ms: int, step_ms: int, - cancel_event: Optional[AbstractEvent], - ) -> Optional[Set[Tuple[int, str]]]: + stop_event: Optional[AbstractEvent], + ) -> Union[Literal['signalled', 'stopped'], Set[Tuple[int, str]]]: """ Watch for changes and return a set of `(event_type, path)` tuples. @@ -40,11 +40,12 @@ class RustNotify: debounce_ms: maximum time in milliseconds to group changes over before returning. step_ms: time to wait for new changes in milliseconds, if no changes are detected in this time, and at least one change has been detected, the changes are yielded. - cancel_event: event to check on every iteration to see if this function should return early. + stop_event: event to check on every iteration to see if this function should return early. Returns: - A set of `(event_type, path)` tuples, - the event types are ints which match [`Change`][watchfiles.Change]. + Either a set of `(event_type, path)` tuples + (the event types are ints which match [`Change`][watchfiles.Change]), + `'signalled'` if a signal was received, or `'stopped'` if the `stop_event` was set. """ class WatchfilesRustInternalError(RuntimeError): diff --git a/watchfiles/main.py b/watchfiles/main.py index 6ebb0d6e..d541a07b 100644 --- a/watchfiles/main.py +++ b/watchfiles/main.py @@ -43,17 +43,23 @@ def raw_str(self) -> str: if TYPE_CHECKING: import asyncio + from typing import Protocol import trio AnyEvent = Union[anyio.Event, asyncio.Event, trio.Event] + class AbstractEvent(Protocol): + def is_set(self) -> bool: + ... + def watch( *paths: Union[Path, str], watch_filter: Optional[Callable[['Change', str], bool]] = DefaultFilter(), debounce: int = 1_600, step: int = 50, + stop_event: Optional['AbstractEvent'] = None, debug: bool = False, raise_interrupt: bool = True, ) -> Generator[Set[FileChange], None, None]: @@ -69,6 +75,8 @@ def watch( debounce: maximum time in milliseconds to group changes over before yielding them. step: time to wait for new changes in milliseconds, if no changes are detected in this time, and at least one change has been detected, the changes are yielded. + stop_event: event to stop watching, if this is set, the generator will stop yielding changes, + this can be anything with an `is_set()` method which returns a bool, e.g. `threading.Event()`. debug: whether to print information about all filesystem changes in rust to stdout. raise_interrupt: whether to re-raise `KeyboardInterrupt`s, or suppress the error and just stop iterating. @@ -84,18 +92,20 @@ def watch( """ watcher = RustNotify([str(p) for p in paths], debug) while True: - raw_changes = watcher.watch(debounce, step, None) - if raw_changes is None: + raw_changes = watcher.watch(debounce, step, stop_event) + if raw_changes == 'signalled': if raise_interrupt: raise KeyboardInterrupt else: logger.warning('KeyboardInterrupt caught, stopping watch') return - - changes = _prep_changes(raw_changes, watch_filter) - if changes: - _log_changes(changes) - yield changes + elif raw_changes == 'stopped': + return + else: + changes = _prep_changes(raw_changes, watch_filter) + if changes: + _log_changes(changes) + yield changes async def awatch( @@ -186,7 +196,8 @@ async def signal_handler() -> None: raw_changes = await anyio.to_thread.run_sync(watcher.watch, debounce, step, stop_event_) tg.cancel_scope.cancel() - if raw_changes is None: + # cover both cases here although in theory the watch thread should never get a signal + if raw_changes == 'stopped' or raw_changes == 'signalled': if interrupted: if raise_interrupt: raise KeyboardInterrupt