From 33751f28055c6273805311e945904e9fb75c240c Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 28 Oct 2024 13:34:25 +0000 Subject: [PATCH] [auth_broker]: regress test --- poetry.lock | 31 ++-- pyproject.toml | 2 + test_runner/conftest.py | 1 + test_runner/fixtures/h2server.py | 216 ++++++++++++++++++++++++ test_runner/fixtures/neon_fixtures.py | 111 +++++++++++- test_runner/regress/test_auth_broker.py | 39 +++++ 6 files changed, 384 insertions(+), 16 deletions(-) create mode 100644 test_runner/fixtures/h2server.py create mode 100644 test_runner/regress/test_auth_broker.py diff --git a/poetry.lock b/poetry.lock index 7abd79423593..57193ad3707a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1521,6 +1521,21 @@ files = [ [package.dependencies] six = "*" +[[package]] +name = "jwcrypto" +version = "1.5.6" +description = "Implementation of JOSE Web standards" +optional = false +python-versions = ">= 3.8" +files = [ + {file = "jwcrypto-1.5.6-py3-none-any.whl", hash = "sha256:150d2b0ebbdb8f40b77f543fb44ffd2baeff48788be71f67f03566692fd55789"}, + {file = "jwcrypto-1.5.6.tar.gz", hash = "sha256:771a87762a0c081ae6166958a954f80848820b2ab066937dc8b8379d65b1b039"}, +] + +[package.dependencies] +cryptography = ">=3.4" +typing-extensions = ">=4.5.0" + [[package]] name = "kafka-python" version = "2.0.2" @@ -2111,7 +2126,6 @@ files = [ {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"}, {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"}, @@ -2120,8 +2134,6 @@ files = [ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"}, @@ -2603,7 +2615,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3159,16 +3170,6 @@ files = [ {file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"}, {file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"}, {file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"}, - {file = "wrapt-1.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55"}, - {file = "wrapt-1.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9"}, - {file = "wrapt-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335"}, - {file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9"}, - {file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8"}, - {file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf"}, - {file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a"}, - {file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be"}, - {file = "wrapt-1.14.1-cp311-cp311-win32.whl", hash = "sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204"}, - {file = "wrapt-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224"}, {file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"}, {file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"}, {file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"}, @@ -3406,4 +3407,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "0f4804119f417edf8e1fbd6d715d2e8d70ad731334fa9570304a2203f83339cf" +content-hash = "f767eaa9cb906a47372540aef37446ae55d37011be844b652eec8fb27a49d866" diff --git a/pyproject.toml b/pyproject.toml index d4926cfb9a97..5c1df1d0c07e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,8 @@ pytest-repeat = "^0.9.3" websockets = "^12.0" clickhouse-connect = "^0.7.16" kafka-python = "^2.0.2" +jwcrypto = "^1.5.6" +h2 = "^4.1.0" [tool.poetry.group.dev.dependencies] mypy = "==1.3.0" diff --git a/test_runner/conftest.py b/test_runner/conftest.py index 4a3194c69102..84eda52d33ee 100644 --- a/test_runner/conftest.py +++ b/test_runner/conftest.py @@ -3,6 +3,7 @@ pytest_plugins = ( "fixtures.pg_version", "fixtures.parametrize", + "fixtures.h2server", "fixtures.httpserver", "fixtures.compute_reconfigure", "fixtures.storage_controller_proxy", diff --git a/test_runner/fixtures/h2server.py b/test_runner/fixtures/h2server.py new file mode 100644 index 000000000000..81295cb68e2a --- /dev/null +++ b/test_runner/fixtures/h2server.py @@ -0,0 +1,216 @@ +""" +https://python-hyper.org/projects/hyper-h2/en/stable/asyncio-example.html + +auth-broker -> local-proxy needs a h2 connection, so we need a h2 server :) +""" + +from __future__ import annotations + +import asyncio +import collections +import io +import json +from collections.abc import AsyncGenerator, Iterator +from typing import List, Tuple + +import pytest +import pytest_asyncio +from h2.config import H2Configuration +from h2.connection import H2Connection +from h2.errors import ErrorCodes +from h2.events import ( + ConnectionTerminated, + DataReceived, + RemoteSettingsChanged, + RequestReceived, + StreamEnded, + StreamReset, + WindowUpdated, +) +from h2.exceptions import ProtocolError, StreamClosedError +from h2.settings import SettingCodes + +from fixtures.port_distributor import PortDistributor + +RequestData = collections.namedtuple('RequestData', ['headers', 'data']) + +class H2Server: + def __init__(self, host, port) -> None: + self.host = host + self.port = port + + +class H2Protocol(asyncio.Protocol): + def __init__(self): + config = H2Configuration(client_side=False, header_encoding='utf-8') + self.conn = H2Connection(config=config) + self.transport = None + self.stream_data = {} + self.flow_control_futures = {} + + def connection_made(self, transport: asyncio.Transport): + self.transport = transport + self.conn.initiate_connection() + self.transport.write(self.conn.data_to_send()) + + def connection_lost(self, exc): + for future in self.flow_control_futures.values(): + future.cancel() + self.flow_control_futures = {} + + def data_received(self, data: bytes): + assert self.transport is not None + try: + events = self.conn.receive_data(data) + except ProtocolError as e: + self.transport.write(self.conn.data_to_send()) + self.transport.close() + else: + self.transport.write(self.conn.data_to_send()) + for event in events: + if isinstance(event, RequestReceived): + self.request_received(event.headers, event.stream_id) + elif isinstance(event, DataReceived): + self.receive_data(event.data, event.stream_id) + elif isinstance(event, StreamEnded): + self.stream_complete(event.stream_id) + elif isinstance(event, ConnectionTerminated): + self.transport.close() + elif isinstance(event, StreamReset): + self.stream_reset(event.stream_id) + elif isinstance(event, WindowUpdated): + self.window_updated(event.stream_id, event.delta) + elif isinstance(event, RemoteSettingsChanged): + if SettingCodes.INITIAL_WINDOW_SIZE in event.changed_settings: + self.window_updated(None, 0) + + self.transport.write(self.conn.data_to_send()) + + def request_received(self, headers: List[Tuple[str, str]], stream_id: int): + headers_map = collections.OrderedDict(headers) + # method = headers_map[':method'] + + # Store off the request data. + request_data = RequestData(headers_map, io.BytesIO()) + self.stream_data[stream_id] = request_data + + def stream_complete(self, stream_id: int): + """ + When a stream is complete, we can send our response. + """ + try: + request_data = self.stream_data[stream_id] + except KeyError: + # Just return, we probably 405'd this already + return + + headers = request_data.headers + body = request_data.data.getvalue().decode('utf-8') + + data = json.dumps( + {"headers": headers, "body": body}, indent=4 + ).encode("utf8") + + response_headers = ( + (':status', '200'), + ('content-type', 'application/json'), + ('content-length', str(len(data))), + ) + self.conn.send_headers(stream_id, response_headers) + asyncio.ensure_future(self.send_data(data, stream_id)) + + def receive_data(self, data: bytes, stream_id: int): + """ + We've received some data on a stream. If that stream is one we're + expecting data on, save it off. Otherwise, reset the stream. + """ + try: + stream_data = self.stream_data[stream_id] + except KeyError: + self.conn.reset_stream( + stream_id, error_code=ErrorCodes.PROTOCOL_ERROR + ) + else: + stream_data.data.write(data) + + def stream_reset(self, stream_id): + """ + A stream reset was sent. Stop sending data. + """ + if stream_id in self.flow_control_futures: + future = self.flow_control_futures.pop(stream_id) + future.cancel() + + async def send_data(self, data, stream_id): + """ + Send data according to the flow control rules. + """ + while data: + while self.conn.local_flow_control_window(stream_id) < 1: + try: + await self.wait_for_flow_control(stream_id) + except asyncio.CancelledError: + return + + chunk_size = min( + self.conn.local_flow_control_window(stream_id), + len(data), + self.conn.max_outbound_frame_size, + ) + + try: + self.conn.send_data( + stream_id, + data[:chunk_size], + end_stream=(chunk_size == len(data)) + ) + except (StreamClosedError, ProtocolError): + # The stream got closed and we didn't get told. We're done + # here. + break + + assert self.transport is not None + self.transport.write(self.conn.data_to_send()) + data = data[chunk_size:] + + async def wait_for_flow_control(self, stream_id): + """ + Waits for a Future that fires when the flow control window is opened. + """ + f = asyncio.Future() + self.flow_control_futures[stream_id] = f + await f + + def window_updated(self, stream_id, delta): + """ + A window update frame was received. Unblock some number of flow control + Futures. + """ + if stream_id and stream_id in self.flow_control_futures: + f = self.flow_control_futures.pop(stream_id) + f.set_result(delta) + elif not stream_id: + for f in self.flow_control_futures.values(): + f.set_result(delta) + + self.flow_control_futures = {} + + +@pytest_asyncio.fixture(scope="function") +async def http2_echoserver(http2_echoserver_listen_address: tuple[str, int]) -> AsyncGenerator[H2Server]: + host, port = http2_echoserver_listen_address + + loop = asyncio.get_event_loop() + serve = await loop.create_server(H2Protocol, host, port) + asyncio.create_task(serve.wait_closed()) + + server = H2Server(host, port) + yield server + + serve.close() + + +@pytest.fixture(scope="function") +def http2_echoserver_listen_address(port_distributor: PortDistributor) -> tuple[str, int]: + port = port_distributor.get_port() + return ("localhost", port) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 6491069f2084..f1334d968b13 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -13,7 +13,7 @@ import time import uuid from collections import defaultdict -from collections.abc import Iterable, Iterator +from collections.abc import AsyncGenerator, Iterable, Iterator from contextlib import closing, contextmanager from dataclasses import dataclass from datetime import datetime @@ -35,11 +35,13 @@ from _pytest.config import Config from _pytest.config.argparsing import Parser from _pytest.fixtures import FixtureRequest +from jwcrypto import jwk # Type-related stuff from psycopg2.extensions import connection as PgConnection from psycopg2.extensions import cursor as PgCursor from psycopg2.extensions import make_dsn, parse_dsn +from pytest_httpserver import HTTPServer from urllib3.util.retry import Retry from fixtures import overlayfs @@ -53,6 +55,7 @@ TimelineId, ) from fixtures.endpoint.http import EndpointHttpClient +from fixtures.h2server import H2Server from fixtures.log_helper import log from fixtures.metrics import Metrics, MetricsGetter, parse_metrics from fixtures.neon_cli import NeonLocalCli, Pagectl @@ -3139,6 +3142,20 @@ def extra_args(self) -> list[str]: ] return args + class AuthBroker(AuthBackend): + def __init__(self, endpoint: str): + self.endpoint = endpoint + + def extra_args(self) -> list[str]: + args = [ + # Console auth backend params + *["--auth-backend", "console"], + *["--auth-endpoint", self.endpoint], + *["--sql-over-http-pool-opt-in", "false"], + *["--is-auth-broker", "true"], + ] + return args + @dataclass(frozen=True) class Postgres(AuthBackend): pg_conn_url: str @@ -3311,6 +3328,29 @@ async def http2_query(self, query, args, **kwargs): assert response.status_code == expected_code, f"response: {response.json()}" return response.json() + async def auth_broker_query(self, query, args, **kwargs): + # TODO maybe use default values if not provided + user = kwargs["user"] + token = kwargs["token"] + expected_code = kwargs.get("expected_code") + + log.info(f"Executing http query: {query}") + + connstr = f"postgresql://{user}@{self.domain}:{self.proxy_port}/postgres" + async with httpx.AsyncClient(verify=str(self.test_output_dir / "proxy.crt")) as client: + response = await client.post( + f"https://{self.domain}:{self.external_http_port}/sql", + json={"query": query, "params": args}, + headers={ + "Neon-Connection-String": connstr, + "Authorization": f"Bearer {token}", + }, + ) + + if expected_code is not None: + assert response.status_code == expected_code, f"response: {response.json()}" + return response.json() + def get_metrics(self) -> str: request_result = requests.get(f"http://{self.host}:{self.http_port}/metrics") return request_result.text @@ -3456,6 +3496,75 @@ def static_proxy( yield proxy +@pytest.fixture(scope="function") +def neon_authorize_jwk() -> Iterator[jwk.JWK]: + kid = str(uuid.uuid4()) + key = jwk.JWK.generate(kty="RSA", size=2048, alg="RS256", use="sig", kid=kid) + yield key + + +@pytest.fixture(scope="function") +def static_auth_broker( + port_distributor: PortDistributor, + neon_binpath: Path, + test_output_dir: Path, + httpserver: HTTPServer, + neon_authorize_jwk: jwk.JWK, + http2_echoserver: H2Server, +) -> Iterable[NeonProxy]: + """Neon proxy that routes directly to vanilla postgres.""" + + # local_proxy_endpoint = httpserver.url_for("/sql") + # local_proxy_addr = local_proxy_endpoint.removeprefix("http://").removesuffix("/sql") + local_proxy_addr = f"{http2_echoserver.host}:{http2_echoserver.port}" + log.info(f"local_proxy {local_proxy_addr}") + + httpserver.expect_request("/cplane/proxy_wake_compute").respond_with_json( + { + "address": local_proxy_addr, + "aux": { + "endpoint_id": "ep-foo-bar-1234", + "branch_id": "br-foo-bar", + "project_id": "foo-bar", + }, + } + ) + httpserver.expect_request(re.compile("^/cplane/endpoints/.+/jwks$")).respond_with_json( + { + "jwks": [ + { + "id": "foo", + "jwks_url": httpserver.url_for("/authorize/jwks.json"), + "provider_name": "test", + "jwt_audience": None, + "role_names": ["anonymous", "authenticated"], + } + ] + } + ) + auth_endpoint = httpserver.url_for("/cplane") + + jwk = neon_authorize_jwk.export_public(as_dict=True) + httpserver.expect_request("/authorize/jwks.json").respond_with_json({"keys": [jwk]}) + + proxy_port = port_distributor.get_port() + mgmt_port = port_distributor.get_port() + http_port = port_distributor.get_port() + external_http_port = port_distributor.get_port() + + with NeonProxy( + neon_binpath=neon_binpath, + test_output_dir=test_output_dir, + proxy_port=proxy_port, + http_port=http_port, + mgmt_port=mgmt_port, + external_http_port=external_http_port, + auth_backend=NeonProxy.AuthBroker(auth_endpoint), + ) as proxy: + proxy.start() + yield proxy + + class Endpoint(PgProtocol, LogUtils): """An object representing a Postgres compute endpoint managed by the control plane.""" diff --git a/test_runner/regress/test_auth_broker.py b/test_runner/regress/test_auth_broker.py new file mode 100644 index 000000000000..de7964965cab --- /dev/null +++ b/test_runner/regress/test_auth_broker.py @@ -0,0 +1,39 @@ +import json + +import pytest +from fixtures.neon_fixtures import NeonProxy +from jwcrypto import jwk, jwt + + +@pytest.mark.asyncio +async def test_auth_broker_happy( + static_auth_broker: NeonProxy, + neon_authorize_jwk: jwk.JWK, +): + """ + Signs a JWT and uses it to authorize a query to local_proxy. + """ + + token = jwt.JWT( + header={"kid": neon_authorize_jwk.key_id, "alg": "RS256"}, claims={"sub": "user1"} + ) + token.make_signed_token(neon_authorize_jwk) + res = await static_auth_broker.auth_broker_query( + "foo", ["arg1"], user="anonymous", token=token.serialize() + ) + + # local proxy mock just echos back the request + # check that we forward the correct data + + assert ( + res["headers"]["authorization"] == f"Bearer {token.serialize()}" + ), "JWT should be forwarded" + + assert ( + "anonymous" in res["headers"]["neon-connection-string"] + ), "conn string should be forwarded" + + assert json.loads(res["body"]) == { + "query": "foo", + "params": ["arg1"], + }, "Query body should be forwarded"