Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle network being temporarily unavailable #57

Merged
merged 18 commits into from
Feb 2, 2025
Merged
59 changes: 53 additions & 6 deletions src/aiodhcpwatcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import logging
import os
import socket
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Iterable
Expand All @@ -17,6 +18,7 @@

FILTER = "udp and (port 67 or 68)"
DHCP_REQUEST = 3
AUTO_RECOVER_TIME = 30


@dataclass(slots=True)
Expand Down Expand Up @@ -85,12 +87,46 @@
def __init__(self, callback: Callable[[DHCPRequest], None]) -> None:
"""Initialize watcher."""
self._loop = asyncio.get_running_loop()
self._sock: Any | None = None
self._sock: socket.socket | None = None
self._fileno: int | None = None
self._callback = callback
self._shutdown: bool = False
self._restart_timer: asyncio.TimerHandle | None = None
self._restart_task: asyncio.Task[None] | None = None

def restart_soon(self) -> None:
"""Restart the watcher soon."""
if not self._restart_timer:
_LOGGER.debug("Restarting watcher in %s seconds", AUTO_RECOVER_TIME)
self._restart_timer = self._loop.call_later(
AUTO_RECOVER_TIME, self._execute_restart
)

def _clear_restart_task(self, task: asyncio.Task[None]) -> None:
"""Clear the restart task."""
self._restart_task = None

def _execute_restart(self) -> None:
"""Execute the restart."""
self._restart_timer = None
if not self._shutdown:
_LOGGER.debug("Restarting watcher")
self._restart_task = self._loop.create_task(self.async_start())
self._restart_task.add_done_callback(self._clear_restart_task)

def shutdown(self) -> None:
"""Shutdown the watcher."""
self._shutdown = True
self.stop()

def stop(self) -> None:
"""Stop watching for DHCP packets."""
if self._restart_timer:
self._restart_timer.cancel()
self._restart_timer = None
if self._restart_task:
self._restart_task.cancel()
self._restart_task = None

Check warning on line 129 in src/aiodhcpwatcher/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/aiodhcpwatcher/__init__.py#L128-L129

Added lines #L128 - L129 were not covered by tests
if self._sock and self._fileno:
self._loop.remove_reader(self._fileno)
self._sock.close()
Expand Down Expand Up @@ -129,12 +165,16 @@

async def async_start(self) -> None:
"""Start watching for dhcp packets."""
if self._shutdown:
_LOGGER.debug("Not starting watcher because it is shutdown")
return

Check warning on line 170 in src/aiodhcpwatcher/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/aiodhcpwatcher/__init__.py#L169-L170

Added lines #L169 - L170 were not covered by tests
if not (
_handle_dhcp_packet := await asyncio.get_running_loop().run_in_executor(
None, self._start
)
_handle_dhcp_packet := await self._loop.run_in_executor(None, self._start)
):
return
if self._shutdown: # may change during the executor call
_LOGGER.debug("Not starting watcher because it is shutdown after init") # type: ignore[unreachable]
return

Check warning on line 177 in src/aiodhcpwatcher/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/aiodhcpwatcher/__init__.py#L176-L177

Added lines #L176 - L177 were not covered by tests
sock = self._sock
fileno = self._fileno
if TYPE_CHECKING:
Expand All @@ -149,6 +189,7 @@
sock.close()
self._sock = None
self._fileno = None
_LOGGER.debug("Started watching for dhcp packets")

def _on_data(
self, handle_dhcp_packet: Callable[["Packet"], None], sock: Any
Expand All @@ -158,9 +199,15 @@
data = sock.recv()
except (BlockingIOError, InterruptedError):
return
except OSError as ex:
_LOGGER.error("Error while processing dhcp packet: %s", ex)
self.stop()
self.restart_soon()
return
except BaseException as ex: # pylint: disable=broad-except
_LOGGER.exception("Fatal error while processing dhcp packet: %s", ex)
self.stop()
self.shutdown()
return

if data:
handle_dhcp_packet(data)
Expand Down Expand Up @@ -209,7 +256,7 @@
"""Listen for DHCP requests."""
watcher = AIODHCPWatcher(callback)
await watcher.async_start()
return watcher.stop
return watcher.shutdown


async def async_init() -> None:
Expand Down
195 changes: 179 additions & 16 deletions tests/test_init.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
import logging
import os
import time
from datetime import datetime, timedelta, timezone
from functools import partial
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -13,11 +16,31 @@
from scapy.layers.l2 import Ether
from scapy.packet import Packet

from aiodhcpwatcher import DHCPRequest, async_init, async_start
from aiodhcpwatcher import AUTO_RECOVER_TIME, DHCPRequest, async_init, async_start

utcnow = partial(datetime.now, timezone.utc)
_MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution

logging.basicConfig(level=logging.DEBUG)


def async_fire_time_changed(utc_datetime: datetime) -> None:
timestamp = utc_datetime.timestamp()
loop = asyncio.get_running_loop()
for task in list(loop._scheduled): # type: ignore[attr-defined]
if not isinstance(task, asyncio.TimerHandle):
continue
if task.cancelled():
continue

mock_seconds_into_future = timestamp - time.time()
future_seconds = task.when() - (loop.time() + _MONOTONIC_RESOLUTION)

if mock_seconds_into_future >= future_seconds:
task._run()
task.cancel()


# connect b8:b7:f1:6d:b5:33 192.168.210.56
RAW_DHCP_REQUEST = (
b"\xff\xff\xff\xff\xff\xff\xb8\xb7\xf1m\xb53\x08\x00E\x00\x01P\x06E"
Expand Down Expand Up @@ -161,14 +184,34 @@
)


async def _write_test_packets_to_pipe(w: int) -> None:
for test_packet in (
RAW_DHCP_REQUEST_WITHOUT_HOSTNAME,
RAW_DHCP_REQUEST,
RAW_DHCP_RENEWAL,
RAW_DHCP_REQUEST_WITHOUT_HOSTNAME,
DHCP_REQUEST_BAD_UTF8,
DHCP_REQUEST_IDNA,
):
os.write(w, test_packet)
for _ in range(3):
await asyncio.sleep(0)
os.write(w, b"garbage")
for _ in range(3):
await asyncio.sleep(0)


class MockSocket:

def __init__(self, reader: int) -> None:
def __init__(self, reader: int, exc: type[Exception] | None = None) -> None:
self._fileno = reader
self.close = MagicMock()
self.buffer = b""
self.exc = exc

def recv(self) -> Packet:
if self.exc:
raise self.exc
raw = os.read(self._fileno, 1000000)
try:
packet = Ether(raw)
Expand Down Expand Up @@ -206,20 +249,7 @@ def _handle_dhcp_packet(data: DHCPRequest) -> None:
"aiodhcpwatcher.AIODHCPWatcher._make_listen_socket", return_value=mock_socket
), patch("aiodhcpwatcher.AIODHCPWatcher._verify_working_pcap"):
stop = await async_start(_handle_dhcp_packet)
for test_packet in (
RAW_DHCP_REQUEST_WITHOUT_HOSTNAME,
RAW_DHCP_REQUEST,
RAW_DHCP_RENEWAL,
RAW_DHCP_REQUEST_WITHOUT_HOSTNAME,
DHCP_REQUEST_BAD_UTF8,
DHCP_REQUEST_IDNA,
):
os.write(w, test_packet)
for _ in range(3):
await asyncio.sleep(0)
os.write(w, b"garbage")
for _ in range(3):
await asyncio.sleep(0)
await _write_test_packets_to_pipe(w)

stop()

Expand Down Expand Up @@ -255,6 +285,139 @@ def _handle_dhcp_packet(data: DHCPRequest) -> None:
]


@pytest.mark.asyncio
async def test_watcher_fatal_exception(caplog: pytest.LogCaptureFixture) -> None:
"""Test mocking a dhcp packet to the watcher."""
requests: list[DHCPRequest] = []

def _handle_dhcp_packet(data: DHCPRequest) -> None:
requests.append(data)

r, w = os.pipe()

mock_socket = MockSocket(r, ValueError)
with patch(
"aiodhcpwatcher.AIODHCPWatcher._make_listen_socket", return_value=mock_socket
), patch("aiodhcpwatcher.AIODHCPWatcher._verify_working_pcap"):
stop = await async_start(_handle_dhcp_packet)
await _write_test_packets_to_pipe(w)

stop()

os.close(r)
os.close(w)
assert requests == []
assert "Fatal error while processing dhcp packet" in caplog.text


@pytest.mark.asyncio
async def test_watcher_temp_exception(caplog: pytest.LogCaptureFixture) -> None:
"""Test mocking a dhcp packet to the watcher."""
requests: list[DHCPRequest] = []

def _handle_dhcp_packet(data: DHCPRequest) -> None:
requests.append(data)

r, w = os.pipe()

mock_socket = MockSocket(r, OSError)
with patch(
"aiodhcpwatcher.AIODHCPWatcher._make_listen_socket", return_value=mock_socket
), patch("aiodhcpwatcher.AIODHCPWatcher._verify_working_pcap"):
stop = await async_start(_handle_dhcp_packet)
await _write_test_packets_to_pipe(w)
os.close(r)
os.close(w)
assert requests == []
assert "Error while processing dhcp packet" in caplog.text

r, w = os.pipe()
mock_socket = MockSocket(r)
with patch(
"aiodhcpwatcher.AIODHCPWatcher._make_listen_socket", return_value=mock_socket
), patch("aiodhcpwatcher.AIODHCPWatcher._verify_working_pcap"):

async_fire_time_changed(utcnow() + timedelta(seconds=AUTO_RECOVER_TIME))
await asyncio.sleep(0.1)

await _write_test_packets_to_pipe(w)

stop()

os.close(r)
os.close(w)
assert requests == [
DHCPRequest(
ip_address="192.168.107.151", hostname="", mac_address="60:6b:bd:59:e4:b4"
),
DHCPRequest(
ip_address="192.168.210.56",
hostname="connect",
mac_address="b8:b7:f1:6d:b5:33",
),
DHCPRequest(
ip_address="192.168.1.120",
hostname="iRobot-AE9EC12DD3B04885BCBFA36AFB01E1CC",
mac_address="50:14:79:03:85:2c",
),
DHCPRequest(
ip_address="192.168.107.151", hostname="", mac_address="60:6b:bd:59:e4:b4"
),
DHCPRequest(
ip_address="192.168.210.56",
hostname="connec�",
mac_address="b8:b7:f1:6d:b5:33",
),
DHCPRequest(
ip_address="192.168.210.56",
hostname="ó",
mac_address="b8:b7:f1:6d:b5:33",
),
]


@pytest.mark.asyncio
async def test_watcher_stop_after_temp_exception(
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test mocking a dhcp packet to the watcher."""
requests: list[DHCPRequest] = []

def _handle_dhcp_packet(data: DHCPRequest) -> None:
requests.append(data)

r, w = os.pipe()

mock_socket = MockSocket(r, OSError)
with patch(
"aiodhcpwatcher.AIODHCPWatcher._make_listen_socket", return_value=mock_socket
), patch("aiodhcpwatcher.AIODHCPWatcher._verify_working_pcap"):
stop = await async_start(_handle_dhcp_packet)
await _write_test_packets_to_pipe(w)

os.close(r)
os.close(w)
assert requests == []
assert "Error while processing dhcp packet" in caplog.text
stop()

r, w = os.pipe()
mock_socket = MockSocket(r)
with patch(
"aiodhcpwatcher.AIODHCPWatcher._make_listen_socket", return_value=mock_socket
), patch("aiodhcpwatcher.AIODHCPWatcher._verify_working_pcap"):

async_fire_time_changed(utcnow() + timedelta(seconds=30))
await asyncio.sleep(0)
await _write_test_packets_to_pipe(w)

stop()

os.close(r)
os.close(w)
assert requests == []


@pytest.mark.asyncio
async def test_setup_fails_broken_filtering(caplog: pytest.LogCaptureFixture) -> None:
"""Test that the setup fails when filtering is broken."""
Expand Down