Skip to content

Commit

Permalink
Adding json_deserialize parameter to aiohttp and httpx transports (#465)
Browse files Browse the repository at this point in the history
  • Loading branch information
leszekhanusz authored Feb 8, 2024
1 parent 3a641b1 commit a3f0bd9
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 11 deletions.
18 changes: 9 additions & 9 deletions gql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums. Default: False.
:param parse_results: Whether gql will try to parse the serialized output
sent by the backend. Can be used to unserialize custom scalars or enums.
sent by the backend. Can be used to deserialize custom scalars or enums.
:param batch_interval: Time to wait in seconds for batching requests together.
Batching is disabled (by default) if 0.
:param batch_max: Maximum number of requests in a single batch.
Expand Down Expand Up @@ -892,7 +892,7 @@ def _execute(
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will unserialize the result.
:param parse_result: Whether gql will deserialize the result.
By default use the parse_results argument of the client.
The extra arguments are passed to the transport execute method."""
Expand Down Expand Up @@ -1006,7 +1006,7 @@ def execute(
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will unserialize the result.
:param parse_result: Whether gql will deserialize the result.
By default use the parse_results argument of the client.
:param get_execution_result: return the full ExecutionResult instance instead of
only the "data" field. Necessary if you want to get the "extensions" field.
Expand Down Expand Up @@ -1057,7 +1057,7 @@ def _execute_batch(
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will unserialize the result.
:param parse_result: Whether gql will deserialize the result.
By default use the parse_results argument of the client.
:param validate_document: Whether we still need to validate the document.
Expand Down Expand Up @@ -1151,7 +1151,7 @@ def execute_batch(
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will unserialize the result.
:param parse_result: Whether gql will deserialize the result.
By default use the parse_results argument of the client.
:param get_execution_result: return the full ExecutionResult instance instead of
only the "data" field. Necessary if you want to get the "extensions" field.
Expand Down Expand Up @@ -1333,7 +1333,7 @@ async def _subscribe(
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will unserialize the result.
:param parse_result: Whether gql will deserialize the result.
By default use the parse_results argument of the client.
The extra arguments are passed to the transport subscribe method."""
Expand Down Expand Up @@ -1454,7 +1454,7 @@ async def subscribe(
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will unserialize the result.
:param parse_result: Whether gql will deserialize the result.
By default use the parse_results argument of the client.
:param get_execution_result: yield the full ExecutionResult instance instead of
only the "data" field. Necessary if you want to get the "extensions" field.
Expand Down Expand Up @@ -1511,7 +1511,7 @@ async def _execute(
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will unserialize the result.
:param parse_result: Whether gql will deserialize the result.
By default use the parse_results argument of the client.
The extra arguments are passed to the transport execute method."""
Expand Down Expand Up @@ -1617,7 +1617,7 @@ async def execute(
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will unserialize the result.
:param parse_result: Whether gql will deserialize the result.
By default use the parse_results argument of the client.
:param get_execution_result: return the full ExecutionResult instance instead of
only the "data" field. Necessary if you want to get the "extensions" field.
Expand Down
6 changes: 5 additions & 1 deletion gql/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
timeout: Optional[int] = None,
ssl_close_timeout: Optional[Union[int, float]] = 10,
json_serialize: Callable = json.dumps,
json_deserialize: Callable = json.loads,
client_session_args: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the transport with the given aiohttp parameters.
Expand All @@ -64,6 +65,8 @@ def __init__(
to close properly
:param json_serialize: Json serializer callable.
By default json.dumps() function
:param json_deserialize: Json deserializer callable.
By default json.loads() function
:param client_session_args: Dict of extra args passed to
`aiohttp.ClientSession`_
Expand All @@ -81,6 +84,7 @@ def __init__(
self.session: Optional[aiohttp.ClientSession] = None
self.response_headers: Optional[CIMultiDictProxy[str]]
self.json_serialize: Callable = json_serialize
self.json_deserialize: Callable = json_deserialize

async def connect(self) -> None:
"""Coroutine which will create an aiohttp ClientSession() as self.session.
Expand Down Expand Up @@ -328,7 +332,7 @@ async def raise_response_error(resp: aiohttp.ClientResponse, reason: str):
)

try:
result = await resp.json(content_type=None)
result = await resp.json(loads=self.json_deserialize, content_type=None)

if log.isEnabledFor(logging.INFO):
result_text = await resp.text()
Expand Down
6 changes: 5 additions & 1 deletion gql/transport/httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,21 @@ def __init__(
self,
url: Union[str, httpx.URL],
json_serialize: Callable = json.dumps,
json_deserialize: Callable = json.loads,
**kwargs,
):
"""Initialize the transport with the given httpx parameters.
:param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'.
:param json_serialize: Json serializer callable.
By default json.dumps() function.
:param json_deserialize: Json deserializer callable.
By default json.loads() function.
:param kwargs: Extra args passed to the `httpx` client.
"""
self.url = url
self.json_serialize = json_serialize
self.json_deserialize = json_deserialize
self.kwargs = kwargs

def _prepare_request(
Expand Down Expand Up @@ -145,7 +149,7 @@ def _prepare_result(self, response: httpx.Response) -> ExecutionResult:
log.debug("<<< %s", response.text)

try:
result: Dict[str, Any] = response.json()
result: Dict[str, Any] = self.json_deserialize(response.content)

except Exception:
self._raise_response_error(response, "Not a JSON answer")
Expand Down
50 changes: 50 additions & 0 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,6 +1511,56 @@ async def handler(request):
assert expected_log in caplog.text


query_float_str = """
query getPi {
pi
}
"""

query_float_server_answer_data = '{"pi": 3.141592653589793238462643383279502884197}'

query_float_server_answer = f'{{"data":{query_float_server_answer_data}}}'


@pytest.mark.asyncio
async def test_aiohttp_json_deserializer(event_loop, aiohttp_server):
from aiohttp import web
from decimal import Decimal
from functools import partial
from gql.transport.aiohttp import AIOHTTPTransport

async def handler(request):
return web.Response(
text=query_float_server_answer,
content_type="application/json",
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

json_loads = partial(json.loads, parse_float=Decimal)

transport = AIOHTTPTransport(
url=url,
timeout=10,
json_deserialize=json_loads,
)

async with Client(transport=transport) as session:

query = gql(query_float_str)

# Execute query asynchronously
result = await session.execute(query)

pi = result["pi"]

assert pi == Decimal("3.141592653589793238462643383279502884197")


@pytest.mark.asyncio
async def test_aiohttp_connector_owner_false(event_loop, aiohttp_server):
from aiohttp import web, TCPConnector
Expand Down
51 changes: 51 additions & 0 deletions tests/test_httpx_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,3 +1389,54 @@ async def handler(request):
# Checking that there is no space after the colon in the log
expected_log = '"query":"query getContinents'
assert expected_log in caplog.text


query_float_str = """
query getPi {
pi
}
"""

query_float_server_answer_data = '{"pi": 3.141592653589793238462643383279502884197}'

query_float_server_answer = f'{{"data":{query_float_server_answer_data}}}'


@pytest.mark.aiohttp
@pytest.mark.asyncio
async def test_httpx_json_deserializer(event_loop, aiohttp_server):
from aiohttp import web
from decimal import Decimal
from functools import partial
from gql.transport.httpx import HTTPXAsyncTransport

async def handler(request):
return web.Response(
text=query_float_server_answer,
content_type="application/json",
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = str(server.make_url("/"))

json_loads = partial(json.loads, parse_float=Decimal)

transport = HTTPXAsyncTransport(
url=url,
timeout=10,
json_deserialize=json_loads,
)

async with Client(transport=transport) as session:

query = gql(query_float_str)

# Execute query asynchronously
result = await session.execute(query)

pi = result["pi"]

assert pi == Decimal("3.141592653589793238462643383279502884197")

0 comments on commit a3f0bd9

Please sign in to comment.