Skip to content

Commit

Permalink
Merge pull request #2183 from opentensor/fix/streaming-synapse
Browse files Browse the repository at this point in the history
Merge streaming fix to staging
  • Loading branch information
ibraheem-opentensor authored Aug 7, 2024
2 parents f9677da + 45bc4ea commit 40d453e
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 24 deletions.
82 changes: 63 additions & 19 deletions bittensor/axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@
import traceback
import typing
import uuid
from inspect import Parameter, Signature, signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
import warnings
from inspect import signature, Signature, Parameter
from typing import List, Optional, Tuple, Callable, Any, Dict, Awaitable

import uvicorn
from fastapi import APIRouter, Depends, FastAPI
Expand Down Expand Up @@ -485,17 +486,50 @@ def verify_custom(synapse: MyCustomSynapse):

async def endpoint(*args, **kwargs):
start_time = time.time()
response_synapse = forward_fn(*args, **kwargs)
if isinstance(response_synapse, Awaitable):
response_synapse = await response_synapse
return await self.middleware_cls.synapse_to_response(
synapse=response_synapse, start_time=start_time
)
response = forward_fn(*args, **kwargs)
if isinstance(response, Awaitable):
response = await response
if isinstance(response, bittensor.Synapse):
return await self.middleware_cls.synapse_to_response(
synapse=response, start_time=start_time
)
else:
response_synapse = getattr(response, "synapse", None)
if response_synapse is None:
warnings.warn(
"The response synapse is None. The input synapse will be used as the response synapse. "
"Reliance on forward_fn modifying input synapse as a side-effects is deprecated. "
"Explicitly set `synapse` on response object instead.",
DeprecationWarning,
)
# Replace with `return response` in next major version
response_synapse = args[0]

return await self.middleware_cls.synapse_to_response(
synapse=response_synapse,
start_time=start_time,
response_override=response,
)

return_annotation = forward_sig.return_annotation

if isinstance(return_annotation, type) and issubclass(
return_annotation, bittensor.Synapse
):
if issubclass(
return_annotation,
bittensor.StreamingSynapse,
):
warnings.warn(
"The forward_fn return annotation is a subclass of bittensor.StreamingSynapse. "
"Most likely the correct return annotation would be BTStreamingResponse."
)
else:
return_annotation = JSONResponse

# replace the endpoint signature, but set return annotation to JSONResponse
endpoint.__signature__ = Signature( # type: ignore
parameters=list(forward_sig.parameters.values()),
return_annotation=JSONResponse,
return_annotation=return_annotation,
)

# Add the endpoint to the router, making it available on both GET and POST methods
Expand Down Expand Up @@ -1433,14 +1467,21 @@ async def run(

@classmethod
async def synapse_to_response(
cls, synapse: bittensor.Synapse, start_time: float
) -> JSONResponse:
cls,
synapse: bittensor.Synapse,
start_time: float,
*,
response_override: Optional[Response] = None,
) -> Response:
"""
Converts the Synapse object into a JSON response with HTTP headers.
Args:
synapse (bittensor.Synapse): The Synapse object representing the request.
start_time (float): The timestamp when the request processing started.
synapse: The Synapse object representing the request.
start_time: The timestamp when the request processing started.
response_override:
Instead of serializing the synapse, mutate the provided response object.
This is only really useful for StreamingSynapse responses.
Returns:
Response: The final HTTP response, with updated headers, ready to be sent back to the client.
Expand All @@ -1459,11 +1500,14 @@ async def synapse_to_response(

synapse.axon.process_time = time.time() - start_time

serialized_synapse = await serialize_response(response_content=synapse)
response = JSONResponse(
status_code=synapse.axon.status_code,
content=serialized_synapse,
)
if response_override:
response = response_override
else:
serialized_synapse = await serialize_response(response_content=synapse)
response = JSONResponse(
status_code=synapse.axon.status_code,
content=serialized_synapse,
)

try:
updated_headers = synapse.to_headers()
Expand Down
14 changes: 12 additions & 2 deletions bittensor/stream.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import typing

from aiohttp import ClientResponse
import bittensor

Expand Down Expand Up @@ -49,16 +51,24 @@ class BTStreamingResponse(_StreamingResponse):
provided by the subclass.
"""

def __init__(self, model: BTStreamingResponseModel, **kwargs):
def __init__(
self,
model: BTStreamingResponseModel,
*,
synapse: typing.Optional["StreamingSynapse"] = None,
**kwargs,
):
"""
Initializes the BTStreamingResponse with the given token streamer model.
Args:
model: A BTStreamingResponseModel instance containing the token streamer callable, which is responsible for generating the content of the response.
synapse: The response Synapse to be used to update the response headers etc.
**kwargs: Additional keyword arguments passed to the parent StreamingResponse class.
"""
super().__init__(content=iter(()), **kwargs)
self.token_streamer = model.token_streamer
self.synapse = synapse

async def stream_response(self, send: Send):
"""
Expand Down Expand Up @@ -139,4 +149,4 @@ def create_streaming_response(
"""
model_instance = BTStreamingResponseModel(token_streamer=token_streamer)

return self.BTStreamingResponse(model_instance)
return self.BTStreamingResponse(model_instance, synapse=self)
107 changes: 104 additions & 3 deletions tests/unit_tests/test_axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,21 @@
import time
from dataclasses import dataclass

from typing import Any
from typing import Any, Optional
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch

# Third Party
import fastapi
import netaddr

import pydantic
import pytest
from starlette.requests import Request
from fastapi.testclient import TestClient

# Bittensor
import bittensor
from bittensor import Synapse, RunException
from bittensor import Synapse, RunException, StreamingSynapse
from bittensor.axon import AxonMiddleware
from bittensor.axon import axon as Axon
from bittensor.utils.axon_utils import allowed_nonce_window_ns, calculate_diff_seconds
Expand Down Expand Up @@ -538,6 +539,39 @@ def http_client(self, axon):
async def no_verify_fn(self, synapse):
    """Verification stub that accepts every request unconditionally (no-op)."""
    return

class NonDeterministicHeaders(pydantic.BaseModel):
    """
    Helper class to verify headers.

    Size headers are non-deterministic: for example, header_size depends on the
    non-deterministic processing-time value, so they are range-validated here
    instead of compared for exact equality.
    """

    # Server-side processing time; must be positive and below the 30s bound.
    bt_header_axon_process_time: float = pydantic.Field(gt=0, lt=30)
    # Request timeout echoed back in the headers; same bounds as process time.
    timeout: float = pydantic.Field(gt=0, lt=30)
    # Serialized header size; optional (defaults to None when absent).
    header_size: int = pydantic.Field(None, gt=10, lt=400)
    # Total serialized payload size.
    total_size: int = pydantic.Field(gt=100, lt=10000)
    # Standard HTTP Content-Length; optional — validated only when present.
    content_length: Optional[int] = pydantic.Field(
        None, alias="content-length", gt=100, lt=10000
    )

def assert_headers(self, response, expected_headers):
    """Assert response headers: exact equality for deterministic headers,
    pydantic range-validation for the non-deterministic (size/time) ones.

    Args:
        response: The HTTP response whose headers are checked.
        expected_headers: Expected deterministic headers; merged over the
            default success status headers (caller values win).
    """
    merged_expected = {
        "bt_header_axon_status_code": "200",
        "bt_header_axon_status_message": "Success",
    }
    merged_expected.update(expected_headers)

    actual_headers = dict(response.headers)

    # Collect the header names covered by the validation model, preferring
    # the declared alias (e.g. "content-length") over the field name.
    nondet_names = set()
    for field_name, field in self.NonDeterministicHeaders.model_fields.items():
        nondet_names.add(field.alias or field_name)

    # Remove non-deterministic headers from the exact comparison; missing
    # ones map to None so optional fields still validate.
    nondet_values = {}
    for name in nondet_names:
        nondet_values[name] = actual_headers.pop(name, None)

    assert actual_headers == merged_expected
    self.NonDeterministicHeaders.model_validate(nondet_values)

async def test_unknown_path(self, http_client):
response = http_client.get("/no_such_path")
assert (response.status_code, response.json()) == (
Expand All @@ -563,6 +597,14 @@ async def test_ping__without_verification(self, http_client, axon):
assert response.status_code == 200
response_synapse = Synapse(**response.json())
assert response_synapse.axon.status_code == 200
self.assert_headers(
response,
{
"computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
"content-type": "application/json",
"name": "Synapse",
},
)

@pytest.fixture
def custom_synapse_cls(self):
Expand All @@ -571,6 +613,17 @@ class CustomSynapse(Synapse):

return CustomSynapse

@pytest.fixture
def streaming_synapse_cls(self):
    """Fixture providing a concrete StreamingSynapse subclass for tests.

    Both overridden hooks are no-ops: the visible tests exercise only the
    server (axon) side, so client-side response processing is never invoked.
    Note: the class name leaks into the response "name" header asserted by
    streaming tests — do not rename.
    """

    class CustomStreamingSynapse(StreamingSynapse):
        # NOTE(review): presumably an abstract hook of StreamingSynapse for
        # consuming the streamed body client-side — confirm against base class.
        async def process_streaming_response(self, response):
            pass

        # Client-side JSON extraction hook; unused by these tests.
        def extract_response_json(self, response) -> dict:
            return {}

    return CustomStreamingSynapse

async def test_synapse__explicitly_set_status_code(
self, http_client, axon, custom_synapse_cls, no_verify_axon
):
Expand Down Expand Up @@ -678,3 +731,51 @@ def test_nonce_within_allowed_window(nonce_offset_seconds, expected_result):
result = is_nonce_within_allowed_window(synapse_nonce, allowed_window_ns)

assert result == expected_result, f"Expected {expected_result} but got {result}"

# Parametrize over return annotations: unannotated, a Response subclass, and a
# StreamingSynapse subclass — all must stream correctly through the axon.
@pytest.mark.parametrize(
    "forward_fn_return_annotation",
    [
        None,
        fastapi.Response,
        bittensor.StreamingSynapse,
    ],
)
async def test_streaming_synapse(
    self,
    http_client,
    axon,
    streaming_synapse_cls,
    no_verify_axon,
    forward_fn_return_annotation,
):
    """End-to-end test: a forward_fn that returns a streaming response must
    deliver every chunk to the client with streaming headers intact."""
    tokens = [f"data{i}\n" for i in range(10)]

    async def streamer(send):
        # Emit each token as its own ASGI http.response.body event, keeping
        # the stream open via more_body=True.
        for token in tokens:
            await send(
                {
                    "type": "http.response.body",
                    "body": token.encode(),
                    "more_body": True,
                }
            )
        # Terminate the stream with an empty final body event.
        await send({"type": "http.response.body", "body": b"", "more_body": False})

    async def forward_fn(synapse: streaming_synapse_cls):
        return synapse.create_streaming_response(token_streamer=streamer)

    # Override the return annotation after definition so the axon's
    # annotation-based routing logic is exercised for each parametrized case.
    if forward_fn_return_annotation is not None:
        forward_fn.__annotations__["return"] = forward_fn_return_annotation

    axon.attach(forward_fn)

    response = http_client.post_synapse(streaming_synapse_cls())
    # The client must receive all tokens concatenated, in order.
    assert (response.status_code, response.text) == (200, "".join(tokens))
    self.assert_headers(
        response,
        {
            "content-type": "text/event-stream",
            "name": "CustomStreamingSynapse",
            "computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
        },
    )

0 comments on commit 40d453e

Please sign in to comment.