Skip to content

Commit

Permalink
[auth_broker]: regress test
Browse files Browse the repository at this point in the history
  • Loading branch information
conradludgate committed Oct 28, 2024
1 parent 8dd555d commit 33751f2
Show file tree
Hide file tree
Showing 6 changed files with 384 additions and 16 deletions.
31 changes: 16 additions & 15 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ pytest-repeat = "^0.9.3"
websockets = "^12.0"
clickhouse-connect = "^0.7.16"
kafka-python = "^2.0.2"
jwcrypto = "^1.5.6"
h2 = "^4.1.0"

[tool.poetry.group.dev.dependencies]
mypy = "==1.3.0"
Expand Down
1 change: 1 addition & 0 deletions test_runner/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
pytest_plugins = (
"fixtures.pg_version",
"fixtures.parametrize",
"fixtures.h2server",
"fixtures.httpserver",
"fixtures.compute_reconfigure",
"fixtures.storage_controller_proxy",
Expand Down
216 changes: 216 additions & 0 deletions test_runner/fixtures/h2server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
"""
https://python-hyper.org/projects/hyper-h2/en/stable/asyncio-example.html
auth-broker -> local-proxy needs a h2 connection, so we need a h2 server :)
"""

from __future__ import annotations

import asyncio
import collections
import io
import json
from collections.abc import AsyncGenerator, Iterator
from typing import List, Tuple

import pytest
import pytest_asyncio
from h2.config import H2Configuration
from h2.connection import H2Connection
from h2.errors import ErrorCodes
from h2.events import (
ConnectionTerminated,
DataReceived,
RemoteSettingsChanged,
RequestReceived,
StreamEnded,
StreamReset,
WindowUpdated,
)
from h2.exceptions import ProtocolError, StreamClosedError
from h2.settings import SettingCodes

from fixtures.port_distributor import PortDistributor

RequestData = collections.namedtuple('RequestData', ['headers', 'data'])

class H2Server:
def __init__(self, host, port) -> None:
self.host = host
self.port = port


class H2Protocol(asyncio.Protocol):
def __init__(self):
config = H2Configuration(client_side=False, header_encoding='utf-8')
self.conn = H2Connection(config=config)
self.transport = None
self.stream_data = {}
self.flow_control_futures = {}

def connection_made(self, transport: asyncio.Transport):
self.transport = transport
self.conn.initiate_connection()
self.transport.write(self.conn.data_to_send())

def connection_lost(self, exc):
for future in self.flow_control_futures.values():
future.cancel()
self.flow_control_futures = {}

def data_received(self, data: bytes):
assert self.transport is not None
try:
events = self.conn.receive_data(data)
except ProtocolError as e:
self.transport.write(self.conn.data_to_send())
self.transport.close()
else:
self.transport.write(self.conn.data_to_send())
for event in events:
if isinstance(event, RequestReceived):
self.request_received(event.headers, event.stream_id)
elif isinstance(event, DataReceived):
self.receive_data(event.data, event.stream_id)
elif isinstance(event, StreamEnded):
self.stream_complete(event.stream_id)
elif isinstance(event, ConnectionTerminated):
self.transport.close()
elif isinstance(event, StreamReset):
self.stream_reset(event.stream_id)
elif isinstance(event, WindowUpdated):
self.window_updated(event.stream_id, event.delta)
elif isinstance(event, RemoteSettingsChanged):
if SettingCodes.INITIAL_WINDOW_SIZE in event.changed_settings:
self.window_updated(None, 0)

self.transport.write(self.conn.data_to_send())

def request_received(self, headers: List[Tuple[str, str]], stream_id: int):
headers_map = collections.OrderedDict(headers)
# method = headers_map[':method']

# Store off the request data.
request_data = RequestData(headers_map, io.BytesIO())
self.stream_data[stream_id] = request_data

def stream_complete(self, stream_id: int):
"""
When a stream is complete, we can send our response.
"""
try:
request_data = self.stream_data[stream_id]
except KeyError:
# Just return, we probably 405'd this already
return

headers = request_data.headers
body = request_data.data.getvalue().decode('utf-8')

data = json.dumps(
{"headers": headers, "body": body}, indent=4
).encode("utf8")

response_headers = (
(':status', '200'),
('content-type', 'application/json'),
('content-length', str(len(data))),
)
self.conn.send_headers(stream_id, response_headers)
asyncio.ensure_future(self.send_data(data, stream_id))

def receive_data(self, data: bytes, stream_id: int):
"""
We've received some data on a stream. If that stream is one we're
expecting data on, save it off. Otherwise, reset the stream.
"""
try:
stream_data = self.stream_data[stream_id]
except KeyError:
self.conn.reset_stream(
stream_id, error_code=ErrorCodes.PROTOCOL_ERROR
)
else:
stream_data.data.write(data)

def stream_reset(self, stream_id):
"""
A stream reset was sent. Stop sending data.
"""
if stream_id in self.flow_control_futures:
future = self.flow_control_futures.pop(stream_id)
future.cancel()

async def send_data(self, data, stream_id):
"""
Send data according to the flow control rules.
"""
while data:
while self.conn.local_flow_control_window(stream_id) < 1:
try:
await self.wait_for_flow_control(stream_id)
except asyncio.CancelledError:
return

chunk_size = min(
self.conn.local_flow_control_window(stream_id),
len(data),
self.conn.max_outbound_frame_size,
)

try:
self.conn.send_data(
stream_id,
data[:chunk_size],
end_stream=(chunk_size == len(data))
)
except (StreamClosedError, ProtocolError):
# The stream got closed and we didn't get told. We're done
# here.
break

assert self.transport is not None
self.transport.write(self.conn.data_to_send())
data = data[chunk_size:]

async def wait_for_flow_control(self, stream_id):
"""
Waits for a Future that fires when the flow control window is opened.
"""
f = asyncio.Future()
self.flow_control_futures[stream_id] = f
await f

def window_updated(self, stream_id, delta):
"""
A window update frame was received. Unblock some number of flow control
Futures.
"""
if stream_id and stream_id in self.flow_control_futures:
f = self.flow_control_futures.pop(stream_id)
f.set_result(delta)
elif not stream_id:
for f in self.flow_control_futures.values():
f.set_result(delta)

self.flow_control_futures = {}


@pytest_asyncio.fixture(scope="function")
async def http2_echoserver(http2_echoserver_listen_address: tuple[str, int]) -> AsyncGenerator[H2Server]:
host, port = http2_echoserver_listen_address

loop = asyncio.get_event_loop()
serve = await loop.create_server(H2Protocol, host, port)
asyncio.create_task(serve.wait_closed())

server = H2Server(host, port)
yield server

serve.close()


@pytest.fixture(scope="function")
def http2_echoserver_listen_address(port_distributor: PortDistributor) -> tuple[str, int]:
port = port_distributor.get_port()
return ("localhost", port)
Loading

0 comments on commit 33751f2

Please sign in to comment.