Commit
fix test errors
Signed-off-by: Stephanie <yangcao@redhat.com>
yangcao77 committed Jan 22, 2025
1 parent 3421986 commit 409c376
Showing 7 changed files with 46 additions and 36 deletions.
8 changes: 5 additions & 3 deletions ols/src/cache/redis_cache.py
@@ -115,11 +115,13 @@ def insert_or_append(
old_value = self.get(user_id, conversation_id, skip_user_id_check)
if old_value:
old_value.append(cache_entry)
self.redis_client.set(
key, json.dumps(old_value, default=lambda o: o.to_dict(), cls=MessageEncoder)
)
# self.redis_client.set(
# key, json.dumps(old_value, default=lambda o: o.to_dict(), cls=MessageEncoder)
# )
self.redis_client.set(key, json.dumps([entry.to_dict() for entry in old_value], cls=MessageEncoder))
else:
self.redis_client.set(key, json.dumps([cache_entry.to_dict()], cls=MessageEncoder))



def delete(self, user_id: str, conversation_id: str, skip_user_id_check: bool=False) -> bool:
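For context, the updated insert_or_append serializes the conversation to a JSON array of plain dicts (via each entry's to_dict()) before writing it to Redis, rather than leaning on a default= hook inside json.dumps. A minimal sketch of that serialization pattern, with a hypothetical FakeCacheEntry standing in for the real CacheEntry and MessageEncoder classes:

    import json
    from dataclasses import dataclass

    @dataclass
    class FakeCacheEntry:
        """Illustrative stand-in; like CacheEntry it only needs a to_dict() method."""
        query: str
        response: str

        def to_dict(self) -> dict:
            return {"query": self.query, "response": self.response}

    history = [FakeCacheEntry("What is Kubernetes?", "A container orchestration platform.")]
    # Convert every entry to a plain dict first, then dump the list once;
    # the resulting string is what redis_client.set() would store under the key.
    payload = json.dumps([entry.to_dict() for entry in history])
    print(json.loads(payload))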
32 changes: 19 additions & 13 deletions ols/src/prompts/prompt_generator.py
@@ -8,7 +8,7 @@
PromptTemplate,
SystemMessagePromptTemplate,
)

from copy import copy
from ols.constants import ModelFamily
from ols.customize import prompts
from langchain_core.messages import BaseMessage
@@ -28,16 +28,19 @@ def restructure_rag_context_post(text: str, model: str) -> str:
return "\n" + text.lstrip("\n") + "\n"


# def restructure_history(message: BaseMessage , model: str) -> BaseMessage:
# """Restructure history."""
# if ModelFamily.GRANITE not in model:
# # No processing required here for gpt.
# return message
def restructure_history(message: BaseMessage , model: str) -> BaseMessage:
"""Restructure history."""
if ModelFamily.GRANITE not in model:
# No processing required here for gpt.
return message

# # Granite specific formatting for history
# if isinstance(message, HumanMessage):
# return "\n<|user|>\n" + message.content
# return "\n<|assistant|>\n" + message.content
newMessage = copy(message)
# Granite specific formatting for history
if isinstance(message, HumanMessage):
newMessage.content = "\n<|user|>\n" + message.content
else:
newMessage.content = "\n<|assistant|>\n" + message.content
return newMessage


class GeneratePrompt:
@@ -107,7 +110,10 @@ def _generate_prompt_granite(self) -> tuple[PromptTemplate, dict]:
prompt_message = (
prompt_message + "\n" + prompts.USE_HISTORY_INSTRUCTION.strip()
)
llm_input_values["chat_history"] = "".join(self._history)
llm_input_values["chat_history"] = ""
for message in self._history:
llm_input_values["chat_history"] += message.content
# llm_input_values["chat_history"] = "".join(self._history)

if "context" in llm_input_values:
prompt_message = prompt_message + "\n{context}"
@@ -122,6 +128,6 @@ def generate_prompt(
self, model: str
) -> tuple[ChatPromptTemplate | PromptTemplate, dict]:
"""Generate prompt."""
# if ModelFamily.GRANITE in model:
# return self._generate_prompt_granite()
if ModelFamily.GRANITE in model:
return self._generate_prompt_granite()
return self._generate_prompt_gpt()
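For context, restructure_history now returns a copy of each message with Granite role tags prepended, and _generate_prompt_granite concatenates the tagged contents into a single chat_history string. A minimal sketch of that flow, assuming langchain_core message objects (tag_for_granite is an illustrative name, not the function in the repo):

    from copy import copy
    from langchain_core.messages import AIMessage, BaseMessage, HumanMessage

    def tag_for_granite(message: BaseMessage) -> BaseMessage:
        """Return a copy of the message with a Granite chat-role tag prepended."""
        tagged = copy(message)  # copy so the original cached message is not mutated
        role = "<|user|>" if isinstance(message, HumanMessage) else "<|assistant|>"
        tagged.content = "\n" + role + "\n" + message.content
        return tagged

    history = [HumanMessage("What is Kubernetes?"), AIMessage("A container platform.")]
    chat_history = "".join(tag_for_granite(m).content for m in history)
    print(chat_history)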
7 changes: 3 additions & 4 deletions ols/utils/token_handler.py
@@ -19,7 +19,7 @@
restructure_rag_context_pre,
)

from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
from langchain_core.messages import BaseMessage

logger = logging.getLogger(__name__)

@@ -188,9 +188,8 @@ def limit_conversation_history(

for original_message in reversed(history):
# Restructure messages as per model
message = original_message #restructure_history(original_message, model)
print("Message is: ", message)
message_length = TokenHandler._get_token_count(self.text_to_tokens(message.content))
message = restructure_history(original_message, model)
message_length = TokenHandler._get_token_count(self.text_to_tokens(f"{message.type}: {message.content}"))
total_length += message_length
# if total length of already checked messages is higher than limit
# then skip all remaining messages (we need to skip from top)
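For context, limit_conversation_history now restructures each message for the target model and budgets it as "<type>: <content>", walking the history from newest to oldest until the token limit is reached. A rough sketch of that budgeting loop, with a crude whitespace tokenizer standing in for the real TokenHandler:

    from langchain_core.messages import AIMessage, HumanMessage

    def count_tokens(text: str) -> int:
        """Crude stand-in tokenizer: one token per whitespace-separated word."""
        return len(text.split())

    history = [
        HumanMessage("first message from human"),
        AIMessage("first answer from AI"),
        HumanMessage("second message from human"),
        AIMessage("second answer from AI"),
    ]

    limit = 11  # hypothetical token budget
    total = 0
    kept: list = []
    # Walk newest-to-oldest; each message is counted as "<type>: <content>",
    # so the role prefix is charged against the budget as well.
    for message in reversed(history):
        length = count_tokens(f"{message.type}: {message.content}")
        if total + length > limit:
            break
        total += length
        kept.insert(0, message)

    print(f"kept {len(kept)} of {len(history)} messages ({total} tokens)")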
9 changes: 5 additions & 4 deletions tests/benchmarks/test_prompt_generator.py
@@ -16,6 +16,7 @@
PROVIDER_WATSONX,
)
from ols.src.prompts.prompt_generator import GeneratePrompt
from langchain_core.messages import AIMessage, HumanMessage

# providers and models used by parametrized benchmarks
provider_and_model = (
@@ -38,17 +39,17 @@ def empty_history():
def conversation_history():
"""Non-empty conversation history."""
return [
"First human message",
"First AI response",
HumanMessage("First human message"),
AIMessage("First AI response"),
] * 50


@pytest.fixture
def long_history():
"""Long conversation history."""
return [
"First human message",
"First AI response",
HumanMessage("First human message"),
AIMessage("First AI response"),
] * 10000


10 changes: 5 additions & 5 deletions tests/unit/prompts/test_prompt_generator.py
@@ -16,7 +16,7 @@
)
from ols.src.prompts.prompt_generator import (
GeneratePrompt,
#restructure_history,
restructure_history,
restructure_rag_context_post,
restructure_rag_context_pre,
)
@@ -39,10 +39,10 @@ def _restructure_prompt_input(rag_context, conversation_history, model):
restructure_rag_context_post(restructure_rag_context_pre(text, model), model)
for text in rag_context
]
# history_formatted = [
# restructure_history(history, model) for history in conversation_history
# ]
return rag_formatted, conversation_history
history_formatted = [
restructure_history(history, model) for history in conversation_history
]
return rag_formatted, history_formatted


@pytest.mark.parametrize("model", model)
3 changes: 2 additions & 1 deletion tests/unit/query_helpers/test_docs_summarizer.py
@@ -15,6 +15,7 @@
from tests.mock_classes.mock_llama_index import MockLlamaIndex
from tests.mock_classes.mock_llm_chain import mock_llm_chain
from tests.mock_classes.mock_llm_loader import mock_llm_loader
from langchain_core.messages import HumanMessage

conversation_id = suid.get_suid()

@@ -114,7 +115,7 @@ def test_summarize_truncation():
rag_index = MockLlamaIndex()

# too long history
history = ["human: What is Kubernetes?"] * 10000
history = [HumanMessage("What is Kubernetes?")] * 10000
summary = summarizer.create_response(question, rag_index, history)

# truncation should be done
13 changes: 7 additions & 6 deletions tests/unit/utils/test_token_handler.py
@@ -8,6 +8,7 @@
from ols.constants import TOKEN_BUFFER_WEIGHT, ModelFamily
from ols.utils.token_handler import PromptTooLongError, TokenHandler
from tests.mock_classes.mock_retrieved_node import MockRetrievedNode
from langchain_core.messages import HumanMessage, AIMessage


class TestTokenHandler(TestCase):
@@ -198,12 +199,12 @@ def test_limit_conversation_history_when_no_history_exists(self):
def test_limit_conversation_history(self):
"""Check the behaviour of limiting long conversation history."""
history = [
"human: first message from human",
"ai: first answer from AI",
"human: second message from human",
"ai: second answer from AI",
"human: third message from human",
"ai: third answer from AI",
HumanMessage("first message from human"),
AIMessage("first answer from AI"),
HumanMessage("second message from human"),
AIMessage("second answer from AI"),
HumanMessage("third message from human"),
AIMessage("third answer from AI"),
]
# for each of the above messages the content token count is 4,
# plus 2 tokens for the role tag, so 6 tokens per message.
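To make the token arithmetic in the comment above concrete (illustrative numbers only, assuming 4 content tokens plus 2 tokens of role-tag overhead per message, as the comment states):

    per_message = 4 + 2              # 4 content tokens + 2 for the "human"/"ai" tag
    print(6 * per_message)           # 36 tokens for the whole six-message fixture
    print(16 // per_message)         # with a 16-token budget, only 2 messages fit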
