Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
Signed-off-by: Stephanie <yangcao@redhat.com>
  • Loading branch information
yangcao77 committed Jan 22, 2025
1 parent 77b2a64 commit 5448fd3
Show file tree
Hide file tree
Showing 25 changed files with 376 additions and 207 deletions.
59 changes: 41 additions & 18 deletions ols/app/endpoints/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
ChatHistoryResponse,
ConversationDeletionResponse,
ListConversationsResponse,
CacheEntry
CacheEntry,
)
from ols.src.auth.auth import get_auth_dependency

Expand Down Expand Up @@ -56,8 +56,7 @@

@router.get("/conversations/{conversation_id}", responses=chat_history_response)
def get_conversation(
conversation_id: str,
auth: Any = Depends(auth_dependency)
conversation_id: str, auth: Any = Depends(auth_dependency)
) -> ChatHistoryResponse:
"""Get conversation history for a given conversation ID.
Expand All @@ -77,12 +76,22 @@ def get_conversation(
skip_user_id_check = retrieve_skip_user_id_check(auth)

# Log incoming request (after redaction)
logger.info("Getting chat history for user: %s with conversation_id: %s", user_id, conversation_id)
logger.info(
"Getting chat history for user: %s with conversation_id: %s",
user_id,
conversation_id,
)
try:
chat_history=CacheEntry.cache_entries_to_history(retrieve_previous_input(user_id, conversation_id, skip_user_id_check))
chat_history = CacheEntry.cache_entries_to_history(
retrieve_previous_input(user_id, conversation_id, skip_user_id_check)
)
if chat_history.__len__() == 0:
logger.info("No chat history found for user: %s with conversation_id: %s", user_id, conversation_id)
raise Exception( f"Conversation {conversation_id} not found")
logger.info(
"No chat history found for user: %s with conversation_id: %s",
user_id,
conversation_id,
)
raise Exception(f"Conversation {conversation_id} not found")
return ChatHistoryResponse(chat_history=chat_history)
except Exception as e:
logger.error("Error retrieving previous chat history: %s", e)
Expand Down Expand Up @@ -114,10 +123,12 @@ def get_conversation(
},
}

@router.delete("/conversations/{conversation_id}", responses=delete_conversation_response)

@router.delete(
"/conversations/{conversation_id}", responses=delete_conversation_response
)
def delete_conversation(
conversation_id: str,
auth: Any = Depends(auth_dependency)
conversation_id: str, auth: Any = Depends(auth_dependency)
) -> ConversationDeletionResponse:
"""Delete conversation history for a given conversation ID.
Expand All @@ -133,20 +144,29 @@ def delete_conversation(
skip_user_id_check = retrieve_skip_user_id_check(auth)

# Log incoming request (after redaction)
logger.info("Deleting chat history for user: %s with conversation_id: %s", user_id, conversation_id)
logger.info(
"Deleting chat history for user: %s with conversation_id: %s",
user_id,
conversation_id,
)

if config.conversation_cache.delete(user_id, conversation_id, skip_user_id_check):
return ConversationDeletionResponse(response=f"Conversation {conversation_id} successfully deleted")
return ConversationDeletionResponse(
response=f"Conversation {conversation_id} successfully deleted"
)
else:
logger.info("No chat history found for user: %s with conversation_id: %s", user_id, conversation_id)
logger.info(
"No chat history found for user: %s with conversation_id: %s",
user_id,
conversation_id,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"response": "Error deleting conversation",
"cause": f"Conversation {conversation_id} not found"
"cause": f"Conversation {conversation_id} not found",
},
)

)


list_conversations_response: dict[int | str, dict[str, Any]] = {
Expand All @@ -168,9 +188,10 @@ def delete_conversation(
},
}


@router.get("/conversations", responses=list_conversations_response)
def list_conversations(
auth: Any = Depends(auth_dependency)
auth: Any = Depends(auth_dependency),
) -> ListConversationsResponse:
"""List all conversations for a given user.
Expand All @@ -187,4 +208,6 @@ def list_conversations(
# Log incoming request (after redaction)
logger.info("Listing all conversations for user: %s ", user_id)

return ListConversationsResponse(conversations=config.conversation_cache.list(user_id, skip_user_id_check))
return ListConversationsResponse(
conversations=config.conversation_cache.list(user_id, skip_user_id_check)
)
26 changes: 19 additions & 7 deletions ols/app/endpoints/ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def conversation_request(
attachments,
valid,
timestamps,
skip_user_id_check
skip_user_id_check,
) = process_request(auth, llm_request)

summarizer_response: SummarizerResponse | Generator
Expand All @@ -111,7 +111,12 @@ def conversation_request(
timestamps["generate response"] = time.time()

store_conversation_history(
user_id, conversation_id, llm_request, summarizer_response.response, attachments, skip_user_id_check
user_id,
conversation_id,
llm_request,
summarizer_response.response,
attachments,
skip_user_id_check,
)

if config.ols_config.user_data_collection.transcripts_disabled:
Expand Down Expand Up @@ -158,7 +163,9 @@ def conversation_request(

def process_request(
auth: Any, llm_request: LLMRequest
) -> tuple[str, str, str, list[CacheEntry], list[Attachment], bool, dict[str, float], str]:
) -> tuple[
str, str, str, list[CacheEntry], list[Attachment], bool, dict[str, float], str
]:
"""Process incoming request.
Args:
Expand Down Expand Up @@ -191,7 +198,9 @@ def process_request(
# Log incoming request (after redaction)
logger.info("%s Incoming request: %s", conversation_id, llm_request.query)

previous_input = retrieve_previous_input(user_id, llm_request.conversation_id, skip_user_id_check)
previous_input = retrieve_previous_input(
user_id, llm_request.conversation_id, skip_user_id_check
)
timestamps["retrieve previous input"] = time.time()

# Retrieve attachments from the request
Expand Down Expand Up @@ -225,7 +234,7 @@ def process_request(
attachments,
valid,
timestamps,
skip_user_id_check
skip_user_id_check,
)


Expand Down Expand Up @@ -268,6 +277,7 @@ def retrieve_user_id(auth: Any) -> str:
# auth contains tuple with user ID (in UUID format) and user name
return auth[0]


def retrieve_skip_user_id_check(auth: Any) -> bool:
"""Retrieve skip user_id check from the token processed by auth. mechanism."""
return auth[2]
Expand All @@ -285,7 +295,9 @@ def retrieve_conversation_id(llm_request: LLMRequest) -> str:
return conversation_id


def retrieve_previous_input(user_id: str, conversation_id: str, skip_user_id_check: bool=False) -> list[CacheEntry]:
def retrieve_previous_input(
user_id: str, conversation_id: str, skip_user_id_check: bool = False
) -> list[CacheEntry]:
"""Retrieve previous user input, if exists."""
try:
previous_input = []
Expand Down Expand Up @@ -424,7 +436,7 @@ def store_conversation_history(
llm_request: LLMRequest,
response: Optional[str],
attachments: list[Attachment],
skip_user_id_check: bool=False,
skip_user_id_check: bool = False,
) -> None:
"""Store conversation history into selected cache.
Expand Down
4 changes: 2 additions & 2 deletions ols/app/endpoints/streaming_ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def conversation_request(
attachments,
valid,
timestamps,
skip_user_id_check
skip_user_id_check,
) = process_request(auth, llm_request)

summarizer_response = (
Expand Down Expand Up @@ -268,7 +268,7 @@ def store_data(
rag_chunks: list[RagChunk],
history_truncated: bool,
timestamps: dict[str, float],
skip_user_id_check: bool
skip_user_id_check: bool,
) -> None:
"""Store conversation history and transcript if enabled.
Expand Down
41 changes: 23 additions & 18 deletions ols/app/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ class LLMResponse(BaseModel):
}
}


class ChatHistoryResponse(BaseModel):
"""Model representing a response to a list conversation request.
Expand All @@ -227,20 +228,21 @@ class ChatHistoryResponse(BaseModel):
"examples": [
{
"chat_history": [
{
"content": "what is openshift",
"type": "human",
},
{
"content": " OpenShift is a container orchestration platform built by Red Hat...",
"type": "ai",
}
]
{
"content": "what is openshift",
"type": "human",
},
{
"content": " OpenShift is a container orchestration platform built by Red Hat...",
"type": "ai",
},
]
}
]
}
}


class ListConversationsResponse(BaseModel):
"""Model representing a response to a request to retrieve a conversation history.
Expand All @@ -256,15 +258,16 @@ class ListConversationsResponse(BaseModel):
"examples": [
{
"conversations": [
"15a78660-a18e-447b-9fea-9deb27b63b5f",
"c0a3bc27-77cc-46da-822f-93a9c0e0de4b",
"51984bb1-f3a3-4ab2-9df6-cf92c30bbb7f",
]
"15a78660-a18e-447b-9fea-9deb27b63b5f",
"c0a3bc27-77cc-46da-822f-93a9c0e0de4b",
"51984bb1-f3a3-4ab2-9df6-cf92c30bbb7f",
]
}
]
}
}


class ConversationDeletionResponse(BaseModel):
"""Model representing a response to a conversation deletion request.
Expand Down Expand Up @@ -662,7 +665,7 @@ class CacheEntry(BaseModel):
"""

query: HumanMessage
response: Optional[AIMessage]= AIMessage("")
response: Optional[AIMessage] = AIMessage("")
attachments: list[Attachment] = []

@field_validator("response")
Expand Down Expand Up @@ -693,7 +696,9 @@ def from_dict(cls, data: dict) -> Self:
)

@staticmethod
def cache_entries_to_history(cache_entries: list["CacheEntry"]) -> list[BaseMessage]:
def cache_entries_to_history(
cache_entries: list["CacheEntry"],
) -> list[BaseMessage]:
"""Convert cache entries to a history."""
history: list[BaseMessage] = []
for entry in cache_entries:
Expand All @@ -706,18 +711,18 @@ def cache_entries_to_history(cache_entries: list["CacheEntry"]) -> list[BaseMess
return history



class MessageEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, (HumanMessage, AIMessage)):
return {
"type": obj.type,
"content": obj.content,
"response_metadata": obj.response_metadata,
"additional_kwargs": obj.additional_kwargs
"additional_kwargs": obj.additional_kwargs,
}
return super().default(obj)


class MessageDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
super().__init__(object_hook=self.object_hook, *args, **kwargs)
Expand All @@ -731,4 +736,4 @@ def object_hook(self, dct):
message.additional_kwargs = dct["additional_kwargs"]
message.response_metadata = dct["response_metadata"]
return message
return dct
return dct
9 changes: 8 additions & 1 deletion ols/app/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

from fastapi import FastAPI

from ols.app.endpoints import authorized, feedback, health, ols, streaming_ols, conversations
from ols.app.endpoints import (
authorized,
feedback,
health,
ols,
streaming_ols,
conversations,
)
from ols.app.metrics import metrics


Expand Down
1 change: 1 addition & 0 deletions ols/src/auth/k8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def _extract_bearer_token(header: str) -> str:

class AuthDependency(AuthDependencyInterface):
"""Create an AuthDependency Class that allows customizing the acces Scope path to check."""

skip_userid_check = False

def __init__(self, virtual_path: str = "/ols-access") -> None:
Expand Down
Loading

0 comments on commit 5448fd3

Please sign in to comment.