Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Stephanie <yangcao@redhat.com>
  • Loading branch information
yangcao77 committed Jan 22, 2025
1 parent 409c376 commit 77b2a64
Show file tree
Hide file tree
Showing 8 changed files with 559 additions and 23 deletions.
31 changes: 10 additions & 21 deletions ols/app/endpoints/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ols.app.endpoints.ols import (
retrieve_user_id,
retrieve_previous_input,
retrieve_skip_user_id_check
retrieve_skip_user_id_check,
)
from ols.app.models.models import (
ErrorResponse,
Expand All @@ -22,6 +22,7 @@
ChatHistoryResponse,
ConversationDeletionResponse,
ListConversationsResponse,
CacheEntry
)
from ols.src.auth.auth import get_auth_dependency

Expand Down Expand Up @@ -69,8 +70,7 @@ def get_conversation(
List of conversation messages.
"""
# Initialize variables
previous_input = []
chat_history: list[BaseMessage] = []
chat_history = []

user_id = retrieve_user_id(auth)
logger.info("User ID %s", user_id)
Expand All @@ -79,7 +79,11 @@ def get_conversation(
# Log incoming request (after redaction)
logger.info("Getting chat history for user: %s with conversation_id: %s", user_id, conversation_id)
try:
previous_input = retrieve_previous_input(user_id, conversation_id, skip_user_id_check)
chat_history=CacheEntry.cache_entries_to_history(retrieve_previous_input(user_id, conversation_id, skip_user_id_check))
if chat_history.__len__() == 0:
logger.info("No chat history found for user: %s with conversation_id: %s", user_id, conversation_id)
raise Exception( f"Conversation {conversation_id} not found")
return ChatHistoryResponse(chat_history=chat_history)
except Exception as e:
logger.error("Error retrieving previous chat history: %s", e)
raise HTTPException(
Expand All @@ -89,20 +93,6 @@ def get_conversation(
"cause": str(e),
},
)
if previous_input.__len__() == 0:
logger.info("No chat history found for user: %s with conversation_id: %s", user_id, conversation_id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"response": "Error retrieving previous chat history",
"cause": f"Conversation {conversation_id} not found"
},
)
for entry in previous_input:
chat_history.append(entry.query)
chat_history.append(entry.response)

return ChatHistoryResponse(chat_history=chat_history)


delete_conversation_response: dict[int | str, dict[str, Any]] = {
Expand All @@ -125,7 +115,7 @@ def get_conversation(
}

@router.delete("/conversations/{conversation_id}", responses=delete_conversation_response)
def get_conversation(
def delete_conversation(
conversation_id: str,
auth: Any = Depends(auth_dependency)
) -> ConversationDeletionResponse:
Expand Down Expand Up @@ -179,7 +169,7 @@ def get_conversation(
}

@router.get("/conversations", responses=list_conversations_response)
def get_conversation(
def list_conversations(
auth: Any = Depends(auth_dependency)
) -> ListConversationsResponse:
"""List all conversations for a given user.
Expand All @@ -198,4 +188,3 @@ def get_conversation(
logger.info("Listing all conversations for user: %s ", user_id)

return ListConversationsResponse(conversations=config.conversation_cache.list(user_id, skip_user_id_check))

2 changes: 2 additions & 0 deletions tests/integration/test_authorized.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def test_post_authorized_disabled(caplog):
assert response.json() == {
"user_id": constants.DEFAULT_USER_UID,
"username": constants.DEFAULT_USER_NAME,
"skip_user_id_check": False,
}

# check if the auth checks warning message is found in the log
Expand All @@ -83,6 +84,7 @@ def test_post_authorized_disabled_with_logging_suppressed(caplog):
assert response.json() == {
"user_id": constants.DEFAULT_USER_UID,
"username": constants.DEFAULT_USER_NAME,
"skip_user_id_check": False
}

# check if the auth checks warning message is NOT found in the log
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_authorized_noop.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ def test_authorized():
assert response is not None
assert response.user_id == user_id_in_request
assert response.username == constants.DEFAULT_USER_NAME
assert response.skip_user_id_check == True
198 changes: 198 additions & 0 deletions tests/integration/test_conversations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch

import requests

from ols import config
from ols.utils import suid
from tests.mock_classes.mock_langchain_interface import mock_langchain_interface
from tests.mock_classes.mock_llm_chain import mock_llm_chain
from tests.mock_classes.mock_llm_loader import mock_llm_loader

@pytest.fixture(scope="function")
def _setup():
"""Setups the test client."""
config.reload_from_yaml_file("tests/config/config_for_integration_tests.yaml")

# app.main need to be imported after the configuration is read
from ols.app.main import app # pylint: disable=C0415

pytest.client = TestClient(app)


@pytest.mark.parametrize("endpoint", ("/conversations/{conversation_id}",))
def test_get_conversation_with_history(_setup, endpoint):
"""Test getting conversation history after creating some chat history."""
# we need to import it here because these modules triggers config
# load too -> causes exception in auth module because of missing config
# values
from ols.app.models.models import CacheEntry

ml = mock_langchain_interface("test response")
with (
patch(
"ols.src.query_helpers.docs_summarizer.LLMChain",
new=mock_llm_chain(None),
),
patch(
"ols.src.query_helpers.query_helper.load_llm",
new=mock_llm_loader(ml()),
),
):
# First create some conversation history
conversation_id = suid.get_suid()

# Make first query to create conversation
response = pytest.client.post(
"/v1/query",
json={
"conversation_id": conversation_id,
"query": "First question",
},
)
assert response.status_code == requests.codes.ok

# Make second query to add to conversation
response = pytest.client.post(
"/v1/query",
json={
"conversation_id": conversation_id,
"query": "Second question",
},
)
assert response.status_code == requests.codes.ok

# Now test getting the conversation history
response = pytest.client.get(endpoint.format(conversation_id=conversation_id))
assert response.status_code == requests.codes.ok

history = response.json()["chat_history"]
assert len(history) == 4 # 2 query + 2 response

# Verify first message
assert history[0]["content"] == "First question"
assert history[0]["type"] == "human"
# Verify first response
assert history[1]["type"] == "ai"

# Verify second message
assert history[2]["content"] == "Second question"
assert history[2]["type"] == "human"
# Verify second response
assert history[3]["type"] == "ai"

@pytest.mark.parametrize("endpoint", ("/conversations",))
def test_list_conversations_with_history(_setup, endpoint):
"""Test listing conversations after creating multiple conversations."""
ml = mock_langchain_interface("test response")
with (
patch(
"ols.src.query_helpers.docs_summarizer.LLMChain",
new=mock_llm_chain(None),
),
patch(
"ols.src.query_helpers.query_helper.load_llm",
new=mock_llm_loader(ml()),
),
):
# Create first conversation
conv_id_1 = suid.get_suid()
response = pytest.client.post(
"/v1/query",
json={
"conversation_id": conv_id_1,
"query": "Question for conversation 1",
},
)
assert response.status_code == requests.codes.ok

# Create second conversation
conv_id_2 = suid.get_suid()
response = pytest.client.post(
"/v1/query",
json={
"conversation_id": conv_id_2,
"query": "Question for conversation 2",
},
)
assert response.status_code == requests.codes.ok

# Test listing conversations
response = pytest.client.get(endpoint)
assert response.status_code == requests.codes.ok

conversations = response.json()["conversations"]
assert len(conversations) >= 2 # May have more from other tests
assert conv_id_1 in conversations
assert conv_id_2 in conversations

@pytest.mark.parametrize("endpoint", ("/conversations/{conversation_id}",))
def test_delete_conversation_with_history(_setup, endpoint):
"""Test deleting a conversation after creating chat history."""
ml = mock_langchain_interface("test response")
with (
patch(
"ols.src.query_helpers.docs_summarizer.LLMChain",
new=mock_llm_chain(None),
),
patch(
"ols.src.query_helpers.query_helper.load_llm",
new=mock_llm_loader(ml()),
),
):
# First create a conversation
conversation_id = suid.get_suid()
response = pytest.client.post(
"/v1/query",
json={
"conversation_id": conversation_id,
"query": "Question to create conversation",
},
)
assert response.status_code == requests.codes.ok

# Verify conversation exists
response = pytest.client.get(endpoint.format(conversation_id=conversation_id))
assert response.status_code == requests.codes.ok
assert len(response.json()["chat_history"]) == 2

# Delete the conversation
response = pytest.client.delete(endpoint.format(conversation_id=conversation_id))
assert response.status_code == requests.codes.ok
assert f"Conversation {conversation_id} successfully deleted" in response.json()["response"]

# Verify conversation is gone
response = pytest.client.get(endpoint.format(conversation_id=conversation_id))
assert response.status_code == requests.codes.internal_server_error
assert "Error retrieving previous chat history" in response.json()["detail"]["response"]

def test_get_conversation_not_found(_setup):
"""Test conversation not found scenario"""
conversation_id = suid.get_suid()

with patch('ols.app.endpoints.ols.retrieve_previous_input', return_value=[]):
response = pytest.client.get(f"/conversations/{conversation_id}")

assert response.status_code == 500
assert response.json()["detail"]["cause"] == f"Conversation {conversation_id} not found"


def test_delete_conversation_not_found(_setup):
"""Test deletion of non-existent conversation"""
conversation_id = suid.get_suid()

with patch('ols.config.conversation_cache.delete', return_value=False):
response = pytest.client.delete(f"/conversations/{conversation_id}")

assert response.status_code == 500
assert response.json()["detail"]["cause"] == f"Conversation {conversation_id} not found"


def test_invalid_conversation_id(_setup):
"""Test handling of invalid conversation ID format"""
invalid_id = "not-a-valid-uuid"
response = pytest.client.get(f"/conversations/{invalid_id}")

assert response.status_code == 500
assert "Invalid conversation ID" in response.json()["detail"]["cause"]
21 changes: 20 additions & 1 deletion tests/mock_classes/mock_redis_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Mock for StrictRedis client."""


import fnmatch
class MockRedisClient:
"""Mock for StrictRedis client.
Expand Down Expand Up @@ -41,3 +41,22 @@ def set(self, key, value, *args, **kwargs):
assert isinstance(value, (str, bytes, int, float))

self.cache[key] = value

def delete(self, key):
"""Return item from cache (implementation of DELETE command)."""
# real Redis accepts keys as strings only
assert isinstance(key, str)

if key in self.cache:
del self.cache[key]
return True # successfuly deleted, return True
return False # Key did not exist, return False

def keys(self, pattern):
"""List keys matching a given pattern (implementation of KEYS command)."""
# real Redis accepts patterns as strings only
assert isinstance(pattern, str)

# Use fnmatch to match keys against the pattern
matching_keys = [key for key in self.cache.keys() if fnmatch.fnmatch(key, pattern)]
return matching_keys
Loading

0 comments on commit 77b2a64

Please sign in to comment.