Skip to content

Commit 55d24dc

Browse files
authored
fix: handle network being temporarily unavailable (#57)
1 parent 0934a7d commit 55d24dc

File tree

2 files changed

+232
-22
lines changed

2 files changed

+232
-22
lines changed

src/aiodhcpwatcher/__init__.py

+53-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import logging
55
import os
6+
import socket
67
from dataclasses import dataclass
78
from functools import partial
89
from typing import TYPE_CHECKING, Any, Callable, Iterable
@@ -17,6 +18,7 @@
1718

1819
FILTER = "udp and (port 67 or 68)"
1920
DHCP_REQUEST = 3
21+
AUTO_RECOVER_TIME = 30
2022

2123

2224
@dataclass(slots=True)
@@ -85,12 +87,46 @@ class AIODHCPWatcher:
8587
def __init__(self, callback: Callable[[DHCPRequest], None]) -> None:
8688
"""Initialize watcher."""
8789
self._loop = asyncio.get_running_loop()
88-
self._sock: Any | None = None
90+
self._sock: socket.socket | None = None
8991
self._fileno: int | None = None
9092
self._callback = callback
93+
self._shutdown: bool = False
94+
self._restart_timer: asyncio.TimerHandle | None = None
95+
self._restart_task: asyncio.Task[None] | None = None
96+
97+
def restart_soon(self) -> None:
98+
"""Restart the watcher soon."""
99+
if not self._restart_timer:
100+
_LOGGER.debug("Restarting watcher in %s seconds", AUTO_RECOVER_TIME)
101+
self._restart_timer = self._loop.call_later(
102+
AUTO_RECOVER_TIME, self._execute_restart
103+
)
104+
105+
def _clear_restart_task(self, task: asyncio.Task[None]) -> None:
106+
"""Clear the restart task."""
107+
self._restart_task = None
108+
109+
def _execute_restart(self) -> None:
110+
"""Execute the restart."""
111+
self._restart_timer = None
112+
if not self._shutdown:
113+
_LOGGER.debug("Restarting watcher")
114+
self._restart_task = self._loop.create_task(self.async_start())
115+
self._restart_task.add_done_callback(self._clear_restart_task)
116+
117+
def shutdown(self) -> None:
118+
"""Shutdown the watcher."""
119+
self._shutdown = True
120+
self.stop()
91121

92122
def stop(self) -> None:
93123
"""Stop watching for DHCP packets."""
124+
if self._restart_timer:
125+
self._restart_timer.cancel()
126+
self._restart_timer = None
127+
if self._restart_task:
128+
self._restart_task.cancel()
129+
self._restart_task = None
94130
if self._sock and self._fileno:
95131
self._loop.remove_reader(self._fileno)
96132
self._sock.close()
@@ -129,12 +165,16 @@ def _start(self) -> Callable[["Packet"], None] | None:
129165

130166
async def async_start(self) -> None:
131167
"""Start watching for dhcp packets."""
168+
if self._shutdown:
169+
_LOGGER.debug("Not starting watcher because it is shutdown")
170+
return
132171
if not (
133-
_handle_dhcp_packet := await asyncio.get_running_loop().run_in_executor(
134-
None, self._start
135-
)
172+
_handle_dhcp_packet := await self._loop.run_in_executor(None, self._start)
136173
):
137174
return
175+
if self._shutdown: # may change during the executor call
176+
_LOGGER.debug("Not starting watcher because it is shutdown after init") # type: ignore[unreachable]
177+
return
138178
sock = self._sock
139179
fileno = self._fileno
140180
if TYPE_CHECKING:
@@ -149,6 +189,7 @@ async def async_start(self) -> None:
149189
sock.close()
150190
self._sock = None
151191
self._fileno = None
192+
_LOGGER.debug("Started watching for dhcp packets")
152193

153194
def _on_data(
154195
self, handle_dhcp_packet: Callable[["Packet"], None], sock: Any
@@ -158,9 +199,15 @@ def _on_data(
158199
data = sock.recv()
159200
except (BlockingIOError, InterruptedError):
160201
return
202+
except OSError as ex:
203+
_LOGGER.error("Error while processing dhcp packet: %s", ex)
204+
self.stop()
205+
self.restart_soon()
206+
return
161207
except BaseException as ex: # pylint: disable=broad-except
162208
_LOGGER.exception("Fatal error while processing dhcp packet: %s", ex)
163-
self.stop()
209+
self.shutdown()
210+
return
164211

165212
if data:
166213
handle_dhcp_packet(data)
@@ -209,7 +256,7 @@ async def async_start(callback: Callable[[DHCPRequest], None]) -> Callable[[], N
209256
"""Listen for DHCP requests."""
210257
watcher = AIODHCPWatcher(callback)
211258
await watcher.async_start()
212-
return watcher.stop
259+
return watcher.shutdown
213260

214261

215262
async def async_init() -> None:

tests/test_init.py

+179-16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import asyncio
22
import logging
33
import os
4+
import time
5+
from datetime import datetime, timedelta, timezone
6+
from functools import partial
47
from unittest.mock import MagicMock, patch
58

69
import pytest
@@ -13,11 +16,31 @@
1316
from scapy.layers.l2 import Ether
1417
from scapy.packet import Packet
1518

16-
from aiodhcpwatcher import DHCPRequest, async_init, async_start
19+
from aiodhcpwatcher import AUTO_RECOVER_TIME, DHCPRequest, async_init, async_start
20+
21+
utcnow = partial(datetime.now, timezone.utc)
22+
_MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution
1723

1824
logging.basicConfig(level=logging.DEBUG)
1925

2026

27+
def async_fire_time_changed(utc_datetime: datetime) -> None:
28+
timestamp = utc_datetime.timestamp()
29+
loop = asyncio.get_running_loop()
30+
for task in list(loop._scheduled): # type: ignore[attr-defined]
31+
if not isinstance(task, asyncio.TimerHandle):
32+
continue
33+
if task.cancelled():
34+
continue
35+
36+
mock_seconds_into_future = timestamp - time.time()
37+
future_seconds = task.when() - (loop.time() + _MONOTONIC_RESOLUTION)
38+
39+
if mock_seconds_into_future >= future_seconds:
40+
task._run()
41+
task.cancel()
42+
43+
2144
# connect b8:b7:f1:6d:b5:33 192.168.210.56
2245
RAW_DHCP_REQUEST = (
2346
b"\xff\xff\xff\xff\xff\xff\xb8\xb7\xf1m\xb53\x08\x00E\x00\x01P\x06E"
@@ -161,14 +184,34 @@
161184
)
162185

163186

187+
async def _write_test_packets_to_pipe(w: int) -> None:
188+
for test_packet in (
189+
RAW_DHCP_REQUEST_WITHOUT_HOSTNAME,
190+
RAW_DHCP_REQUEST,
191+
RAW_DHCP_RENEWAL,
192+
RAW_DHCP_REQUEST_WITHOUT_HOSTNAME,
193+
DHCP_REQUEST_BAD_UTF8,
194+
DHCP_REQUEST_IDNA,
195+
):
196+
os.write(w, test_packet)
197+
for _ in range(3):
198+
await asyncio.sleep(0)
199+
os.write(w, b"garbage")
200+
for _ in range(3):
201+
await asyncio.sleep(0)
202+
203+
164204
class MockSocket:
165205

166-
def __init__(self, reader: int) -> None:
206+
def __init__(self, reader: int, exc: type[Exception] | None = None) -> None:
167207
self._fileno = reader
168208
self.close = MagicMock()
169209
self.buffer = b""
210+
self.exc = exc
170211

171212
def recv(self) -> Packet:
213+
if self.exc:
214+
raise self.exc
172215
raw = os.read(self._fileno, 1000000)
173216
try:
174217
packet = Ether(raw)
@@ -206,20 +249,7 @@ def _handle_dhcp_packet(data: DHCPRequest) -> None:
206249
"aiodhcpwatcher.AIODHCPWatcher._make_listen_socket", return_value=mock_socket
207250
), patch("aiodhcpwatcher.AIODHCPWatcher._verify_working_pcap"):
208251
stop = await async_start(_handle_dhcp_packet)
209-
for test_packet in (
210-
RAW_DHCP_REQUEST_WITHOUT_HOSTNAME,
211-
RAW_DHCP_REQUEST,
212-
RAW_DHCP_RENEWAL,
213-
RAW_DHCP_REQUEST_WITHOUT_HOSTNAME,
214-
DHCP_REQUEST_BAD_UTF8,
215-
DHCP_REQUEST_IDNA,
216-
):
217-
os.write(w, test_packet)
218-
for _ in range(3):
219-
await asyncio.sleep(0)
220-
os.write(w, b"garbage")
221-
for _ in range(3):
222-
await asyncio.sleep(0)
252+
await _write_test_packets_to_pipe(w)
223253

224254
stop()
225255

@@ -255,6 +285,139 @@ def _handle_dhcp_packet(data: DHCPRequest) -> None:
255285
]
256286

257287

288+
@pytest.mark.asyncio
289+
async def test_watcher_fatal_exception(caplog: pytest.LogCaptureFixture) -> None:
290+
"""Test mocking a dhcp packet to the watcher."""
291+
requests: list[DHCPRequest] = []
292+
293+
def _handle_dhcp_packet(data: DHCPRequest) -> None:
294+
requests.append(data)
295+
296+
r, w = os.pipe()
297+
298+
mock_socket = MockSocket(r, ValueError)
299+
with patch(
300+
"aiodhcpwatcher.AIODHCPWatcher._make_listen_socket", return_value=mock_socket
301+
), patch("aiodhcpwatcher.AIODHCPWatcher._verify_working_pcap"):
302+
stop = await async_start(_handle_dhcp_packet)
303+
await _write_test_packets_to_pipe(w)
304+
305+
stop()
306+
307+
os.close(r)
308+
os.close(w)
309+
assert requests == []
310+
assert "Fatal error while processing dhcp packet" in caplog.text
311+
312+
313+
@pytest.mark.asyncio
314+
async def test_watcher_temp_exception(caplog: pytest.LogCaptureFixture) -> None:
315+
"""Test mocking a dhcp packet to the watcher."""
316+
requests: list[DHCPRequest] = []
317+
318+
def _handle_dhcp_packet(data: DHCPRequest) -> None:
319+
requests.append(data)
320+
321+
r, w = os.pipe()
322+
323+
mock_socket = MockSocket(r, OSError)
324+
with patch(
325+
"aiodhcpwatcher.AIODHCPWatcher._make_listen_socket", return_value=mock_socket
326+
), patch("aiodhcpwatcher.AIODHCPWatcher._verify_working_pcap"):
327+
stop = await async_start(_handle_dhcp_packet)
328+
await _write_test_packets_to_pipe(w)
329+
os.close(r)
330+
os.close(w)
331+
assert requests == []
332+
assert "Error while processing dhcp packet" in caplog.text
333+
334+
r, w = os.pipe()
335+
mock_socket = MockSocket(r)
336+
with patch(
337+
"aiodhcpwatcher.AIODHCPWatcher._make_listen_socket", return_value=mock_socket
338+
), patch("aiodhcpwatcher.AIODHCPWatcher._verify_working_pcap"):
339+
340+
async_fire_time_changed(utcnow() + timedelta(seconds=AUTO_RECOVER_TIME))
341+
await asyncio.sleep(0.1)
342+
343+
await _write_test_packets_to_pipe(w)
344+
345+
stop()
346+
347+
os.close(r)
348+
os.close(w)
349+
assert requests == [
350+
DHCPRequest(
351+
ip_address="192.168.107.151", hostname="", mac_address="60:6b:bd:59:e4:b4"
352+
),
353+
DHCPRequest(
354+
ip_address="192.168.210.56",
355+
hostname="connect",
356+
mac_address="b8:b7:f1:6d:b5:33",
357+
),
358+
DHCPRequest(
359+
ip_address="192.168.1.120",
360+
hostname="iRobot-AE9EC12DD3B04885BCBFA36AFB01E1CC",
361+
mac_address="50:14:79:03:85:2c",
362+
),
363+
DHCPRequest(
364+
ip_address="192.168.107.151", hostname="", mac_address="60:6b:bd:59:e4:b4"
365+
),
366+
DHCPRequest(
367+
ip_address="192.168.210.56",
368+
hostname="connec�",
369+
mac_address="b8:b7:f1:6d:b5:33",
370+
),
371+
DHCPRequest(
372+
ip_address="192.168.210.56",
373+
hostname="ó",
374+
mac_address="b8:b7:f1:6d:b5:33",
375+
),
376+
]
377+
378+
379+
@pytest.mark.asyncio
380+
async def test_watcher_stop_after_temp_exception(
381+
caplog: pytest.LogCaptureFixture,
382+
) -> None:
383+
"""Test mocking a dhcp packet to the watcher."""
384+
requests: list[DHCPRequest] = []
385+
386+
def _handle_dhcp_packet(data: DHCPRequest) -> None:
387+
requests.append(data)
388+
389+
r, w = os.pipe()
390+
391+
mock_socket = MockSocket(r, OSError)
392+
with patch(
393+
"aiodhcpwatcher.AIODHCPWatcher._make_listen_socket", return_value=mock_socket
394+
), patch("aiodhcpwatcher.AIODHCPWatcher._verify_working_pcap"):
395+
stop = await async_start(_handle_dhcp_packet)
396+
await _write_test_packets_to_pipe(w)
397+
398+
os.close(r)
399+
os.close(w)
400+
assert requests == []
401+
assert "Error while processing dhcp packet" in caplog.text
402+
stop()
403+
404+
r, w = os.pipe()
405+
mock_socket = MockSocket(r)
406+
with patch(
407+
"aiodhcpwatcher.AIODHCPWatcher._make_listen_socket", return_value=mock_socket
408+
), patch("aiodhcpwatcher.AIODHCPWatcher._verify_working_pcap"):
409+
410+
async_fire_time_changed(utcnow() + timedelta(seconds=30))
411+
await asyncio.sleep(0)
412+
await _write_test_packets_to_pipe(w)
413+
414+
stop()
415+
416+
os.close(r)
417+
os.close(w)
418+
assert requests == []
419+
420+
258421
@pytest.mark.asyncio
259422
async def test_setup_fails_broken_filtering(caplog: pytest.LogCaptureFixture) -> None:
260423
"""Test that the setup fails when filtering is broken."""

0 commit comments

Comments
 (0)