Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit

Permalink
add StreamingResponse support to AutoFastAPI (#71)
Browse files Browse the repository at this point in the history
* Implement support for StreamingResponse in AutoFastAPI

* minor fix
  • Loading branch information
SeeknnDestroy authored Nov 1, 2023
1 parent 230b770 commit e2e7608
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
13 changes: 12 additions & 1 deletion autollm/auto/fastapi_app.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
from typing import Optional, Sequence

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from llama_index import Document
from llama_index.indices.query.base import BaseQueryEngine
from pydantic import BaseModel, Field

from autollm.serve.docs import description, openapi_url, tags_metadata, terms_of_service, title, version
from autollm.serve.utils import load_config_and_initialize_engines
from autollm.serve.utils import load_config_and_initialize_engines, stream_text_data


class FromConfigQueryPayload(BaseModel):
    """Request body for querying a task-specific engine built from a config file."""

    task: str = Field(..., description="Task to execute")
    user_query: str = Field(..., description="User's query")
    streaming: Optional[bool] = Field(default=False, description="Flag to enable streaming of response")


class FromEngineQueryPayload(BaseModel):
    """Request body for querying an already-constructed query engine."""

    user_query: str = Field(..., description="User's query")
    streaming: Optional[bool] = Field(default=False, description="Flag to enable streaming of response")


class AutoFastAPI:
Expand Down Expand Up @@ -101,6 +104,10 @@ async def query(payload: FromConfigQueryPayload):
query_engine: BaseQueryEngine = task_name_to_query_engine[task]
response = query_engine.query(user_query)

# Check if the response should be streamed
if payload.streaming:
return StreamingResponse(stream_text_data(response.response))

return response.response

return app
Expand Down Expand Up @@ -162,6 +169,10 @@ async def query(payload: FromEngineQueryPayload):

response = query_engine.query(user_query)

# Check if the response should be streamed
if payload.streaming:
return StreamingResponse(stream_text_data(response.response))

return response.response

return app
11 changes: 11 additions & 0 deletions autollm/serve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

logging.basicConfig(level=logging.INFO)

STREAMING_CHUNK_SIZE = 16


def load_config_and_initialize_engines(
config_file_path: str,
Expand Down Expand Up @@ -36,3 +38,12 @@ def load_config_and_initialize_engines(
query_engines[task_name] = AutoQueryEngine.from_parameters(documents=documents, **task_params)

return query_engines


def stream_text_data(text_data: str, chunk_size=None):
    """Yield ``text_data`` in successive chunks of at most ``chunk_size`` characters.

    Intended as the generator fed to a ``StreamingResponse``: it is consumed
    lazily, so the response body is sent to the client piecewise.

    Args:
        text_data: The full text to stream.
        chunk_size: Number of characters per chunk. Defaults to the module's
            ``STREAMING_CHUNK_SIZE`` (resolved at call time rather than bound
            eagerly at definition time).

    Yields:
        str: Consecutive slices of ``text_data``; the final chunk may be
        shorter than ``chunk_size``. An empty string yields nothing.

    Raises:
        ValueError: If ``chunk_size`` is not positive — the original
        start/end bookkeeping would have looped forever in that case.
    """
    if chunk_size is None:
        chunk_size = STREAMING_CHUNK_SIZE
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    # Step through the text at chunk_size intervals; slicing clamps the
    # final slice automatically, so no end-of-string special case is needed.
    for start in range(0, len(text_data), chunk_size):
        yield text_data[start:start + chunk_size]

0 comments on commit e2e7608

Please sign in to comment.