diff --git a/CHANGELOG.md b/CHANGELOG.md index 897ffb7..05b05a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## Next version + +### ✨ Improved + +* Replaced `TCPSource` internals with the use of `lvmopstools.AsyncSocketHandler` which includes retrying and better error handling. + + ## 1.2.0 - November 24, 2023 ### 🚀 Added diff --git a/cerebro/sources/lvm.py b/cerebro/sources/lvm.py index edbef32..1a440a5 100644 --- a/cerebro/sources/lvm.py +++ b/cerebro/sources/lvm.py @@ -53,14 +53,17 @@ def __init__( self.bucket = self.bucket or "sensors" - async def _read_internal(self) -> list[dict] | None: - if not self.writer or not self.reader: - return + async def _read_internal( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> list[dict] | None: + """Queries the TCP server and returns a list of points.""" - self.writer.write((f"status {self.address}\n").encode()) - await self.writer.drain() + writer.write((f"status {self.address}\n").encode()) + await writer.drain() - data = await asyncio.wait_for(self.reader.readline(), timeout=5) + data = await asyncio.wait_for(reader.readline(), timeout=5) # Not found if data == b"?\n": @@ -131,14 +134,17 @@ def __init__( self.bucket = self.bucket or "sensors" - async def _read_internal(self) -> list[dict] | None: - if not self.writer or not self.reader: - return + async def _read_internal( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> list[dict] | None: + """Queries the TCP server and returns a list of points.""" - self.writer.write((f"@{self.device_id:d}Q?\\").encode()) - await self.writer.drain() + writer.write((f"@{self.device_id:d}Q?\\").encode()) + await writer.drain() - data = await asyncio.wait_for(self.reader.readuntil(b"\\"), timeout=5) + data = await asyncio.wait_for(reader.readuntil(b"\\"), timeout=5) data = data.decode() m = re.match( @@ -201,14 +207,17 @@ def __init__( self.delay = delay self.bucket = self.bucket or "sensors" - async def _read_internal(self) -> list[dict] | None: - if not self.writer or not self.reader: - return + async def _read_internal( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> list[dict] | None: + """Queries the TCP server and returns a list of points.""" - self.writer.write(("~*P*~\n").encode()) - await self.writer.drain() + writer.write(("~*P*~\n").encode()) + await writer.drain() - data = await asyncio.wait_for(self.reader.readline(), timeout=5) + data = await asyncio.wait_for(reader.readline(), timeout=5) data = data.decode() m = re.search(r"\s([\-0-9.]+)\slb", data) diff --git a/cerebro/sources/source.py b/cerebro/sources/source.py index 9fbc7f2..13b74a9 100644 --- a/cerebro/sources/source.py +++ b/cerebro/sources/source.py @@ -10,14 +10,16 @@ import abc import asyncio -from contextlib import suppress from typing import Any, Dict, List, NamedTuple, Optional, Type import rx +from lvmopstools.socket import AsyncSocketHandler from rx.disposable.disposable import Disposable from rx.subject.subject import Subject +from sdsstools.utils import cancel_task + from cerebro import log @@ -156,6 +158,8 @@ def __init__( host: str, port: int, delay: Optional[float] = None, + retry: bool = True, + retrier_params: dict = {}, **kwargs, ): super().__init__(name, **kwargs) @@ -168,6 +172,12 @@ def __init__( self.reader: asyncio.StreamReader | None = None self.writer: asyncio.StreamWriter | None = None + self.handler = AsyncSocketHandler( + self.host, + self.port, + retry=retry, + retrier_params=retrier_params, + ) self._runner: asyncio.Task | None = None async def start(self): @@ -185,16 +195,17 @@ async def stop(self): if not self.running: raise RuntimeError(f"{self.name}: source is not running.") - with suppress(asyncio.CancelledError): - if self._runner: - self._runner.cancel() - await self._runner - self._runner = None + await cancel_task(self._runner) + self._runner = None super().stop() @abc.abstractmethod - async def _read_internal(self) -> list[dict] | None: + async def _read_internal( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> list[dict] | None: """Queries the TCP server and returns a list of points.""" pass @@ -207,41 +218,12 @@ async def read(self, delay=None): while True: # Connect to server try: - if self.writer and self.writer.is_closing(): - self.writer.close() - await self.writer.wait_closed() - - self.reader, self.writer = await asyncio.open_connection( - self.host, - self.port, - ) - - except ( - OSError, - ConnectionError, - ConnectionResetError, - ConnectionRefusedError, - ) as err: - log.warning(f"{self.name}: {err}. Reconnecting in {delay} seconds.") - await asyncio.sleep(delay) - continue - - except BaseException as err: - log.warning(f"{self.name}: Stopping after unknown error {err}.") - await self.stop() - return - - # Communicate with server - try: - points = await self._read_internal() - if not self.reader.at_eof() and points is not None: + points = await self.handler(self._read_internal) + if points is not None: self.on_next(DataPoints(data=points, bucket=self.bucket)) - except asyncio.TimeoutError: - log.warning(f"{self.name}: timed out waiting for the server to reply.") - except Exception as err: - log.warning(f"{self.name}: {str(err)}") + log.warning(f"{self.name}: Error while reading device: {str(err)}") finally: await asyncio.sleep(delay) diff --git a/poetry.lock b/poetry.lock index 0379659..9aaf3de 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1014,6 +1014,28 @@ files = [ {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, ] +[[package]] +name = "lvmopstools" +version = "0.1.0a0" +description = "LVM tools and utilities for operations" +optional = false +python-versions = "^3.10,<3.13" +files = [] +develop = false + +[package.dependencies] +astropy = "^6.0.0" +click = "^8.1.7" +pyds9 = "^1.8.1" +sdss-clu = "^2.2.2" +sdsstools = "^1.3.1" + +[package.source] +type = "git" +url = "https://github.com/sdss/lvmopstools" +reference = "main" +resolved_reference = "d4e504cf27d3d4b710b9c2ed7d7d7431aefe5849" + [[package]] name = "makefun" version = "1.15.2" @@ -1736,6 +1758,19 @@ files = [ {file = "pycodestyle-2.11.1.tar.gz", hash = "sha256:41ba0e7afc9752dfb53ced5489e89f8186be00e599e712660695b7a75ff2663f"}, ] +[[package]] +name = "pyds9" +version = "1.8.1" +description = "Python/DS9 connection via XPA (with numpy and pyfits support)" +optional = false +python-versions = "*" +files = [ + {file = "pyds9-1.8.1.tar.gz", hash = "sha256:b4f198f5d29b749f721c491f8384f6293e43ec417bd0492be36bffb5c3904b2a"}, +] + +[package.dependencies] +six = "*" + [[package]] name = "pyerfa" version = "2.0.1.1" @@ -2692,4 +2727,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10,<3.13" -content-hash = "083c7919b01bc39d8e798420f10570b67abcec389a596bf02f9b25a5c5fe5753" +content-hash = "b054c4984293aa4770b12272eb01ae96c62b318b65ecc2533fe7674191008dfc" diff --git a/pyproject.toml b/pyproject.toml index 5be38ee..c1dca42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ rx = "^3.2.0" pymysql = "^1.0.2" peewee = "^3.15.4" asyncudp = "^0.11.0" +lvmopstools = {git="https://github.com/sdss/lvmopstools", branch="main"} [tool.poetry.dev-dependencies] ipython = ">=8.0.0"