From 409c376829fe1b86ce1fb778e7d6c0f9806b9044 Mon Sep 17 00:00:00 2001
From: Stephanie
Date: Wed, 22 Jan 2025 10:55:30 -0500
Subject: [PATCH] fix test errors

Signed-off-by: Stephanie
---
 ols/src/cache/redis_cache.py                |  8 +++--
 ols/src/prompts/prompt_generator.py         | 32 +++++++++++--------
 ols/utils/token_handler.py                  |  7 ++--
 tests/benchmarks/test_prompt_generator.py   |  9 +++---
 tests/unit/prompts/test_prompt_generator.py | 10 +++---
 .../query_helpers/test_docs_summarizer.py   |  3 +-
 tests/unit/utils/test_token_handler.py      | 13 ++++----
 7 files changed, 46 insertions(+), 36 deletions(-)

diff --git a/ols/src/cache/redis_cache.py b/ols/src/cache/redis_cache.py
index df7b1e1a..fe139493 100644
--- a/ols/src/cache/redis_cache.py
+++ b/ols/src/cache/redis_cache.py
@@ -115,11 +115,13 @@ def insert_or_append(
         old_value = self.get(user_id, conversation_id, skip_user_id_check)
         if old_value:
             old_value.append(cache_entry)
-            self.redis_client.set(
-                key, json.dumps(old_value, default=lambda o: o.to_dict(), cls=MessageEncoder)
-            )
+            # Convert cache entries to dicts before JSON-encoding with MessageEncoder
+            self.redis_client.set(
+                key, json.dumps([entry.to_dict() for entry in old_value], cls=MessageEncoder)
+            )
         else:
             self.redis_client.set(key, json.dumps([cache_entry.to_dict()], cls=MessageEncoder))
+
     def delete(self, user_id: str, conversation_id: str, skip_user_id_check: bool=False) -> bool:
diff --git a/ols/src/prompts/prompt_generator.py b/ols/src/prompts/prompt_generator.py
index df818372..5b4b12e4 100644
--- a/ols/src/prompts/prompt_generator.py
+++ b/ols/src/prompts/prompt_generator.py
@@ -8,7 +8,7 @@
     PromptTemplate,
     SystemMessagePromptTemplate,
 )
-
+from copy import copy
 from ols.constants import ModelFamily
 from ols.customize import prompts
 from langchain_core.messages import BaseMessage
@@ -28,16 +28,19 @@ def restructure_rag_context_post(text: str, model: str) -> str:
     return "\n" + text.lstrip("\n") + "\n"
 
 
-# def restructure_history(message: BaseMessage , model: str) -> BaseMessage:
-#     """Restructure history."""
-#     if ModelFamily.GRANITE not in model:
-#         # No processing required here for gpt.
-#         return message
+def restructure_history(message: BaseMessage, model: str) -> BaseMessage:
+    """Restructure history."""
+    if ModelFamily.GRANITE not in model:
+        # No processing required here for gpt.
+        return message
 
-#     # Granite specific formatting for history
-#     if isinstance(message, HumanMessage):
-#         return "\n<|user|>\n" + message.content
-#     return "\n<|assistant|>\n" + message.content
+    new_message = copy(message)
+    # Granite-specific formatting for history
+    if isinstance(message, HumanMessage):
+        new_message.content = "\n<|user|>\n" + message.content
+    else:
+        new_message.content = "\n<|assistant|>\n" + message.content
+    return new_message
 
 
 class GeneratePrompt:
@@ -107,7 +110,10 @@ def _generate_prompt_granite(self) -> tuple[PromptTemplate, dict]:
             prompt_message = (
                 prompt_message + "\n" + prompts.USE_HISTORY_INSTRUCTION.strip()
            )
-            llm_input_values["chat_history"] = "".join(self._history)
+            # Concatenate the restructured history messages into one string
+            llm_input_values["chat_history"] = "".join(
+                message.content for message in self._history
+            )
 
         if "context" in llm_input_values:
             prompt_message = prompt_message + "\n{context}"
@@ -122,6 +128,6 @@ def generate_prompt(
         self, model: str
     ) -> tuple[ChatPromptTemplate | PromptTemplate, dict]:
         """Generate prompt."""
-        # if ModelFamily.GRANITE in model:
-        #     return self._generate_prompt_granite()
+        if ModelFamily.GRANITE in model:
+            return self._generate_prompt_granite()
         return self._generate_prompt_gpt()
diff --git a/ols/utils/token_handler.py b/ols/utils/token_handler.py
index e971bc92..2b1b2a17 100644
--- a/ols/utils/token_handler.py
+++ b/ols/utils/token_handler.py
@@ -19,7 +19,7 @@
     restructure_rag_context_pre,
 )
 
-from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
+from langchain_core.messages import BaseMessage
 
 logger = logging.getLogger(__name__)
@@ -188,9 +188,8 @@ def limit_conversation_history(
 
         for original_message in reversed(history):
             # Restructure messages as per model
-            message = original_message #restructure_history(original_message, model)
-            print("Message is: ", message)
-            message_length = TokenHandler._get_token_count(self.text_to_tokens(message.content))
+            message = restructure_history(original_message, model)
+            message_length = TokenHandler._get_token_count(self.text_to_tokens(f"{message.type}: {message.content}"))
             total_length += message_length
             # if total length of already checked messages is higher than limit
             # then skip all remaining messages (we need to skip from top)
diff --git a/tests/benchmarks/test_prompt_generator.py b/tests/benchmarks/test_prompt_generator.py
index 389d7303..63ab65d2 100644
--- a/tests/benchmarks/test_prompt_generator.py
+++ b/tests/benchmarks/test_prompt_generator.py
@@ -16,6 +16,7 @@
     PROVIDER_WATSONX,
 )
 from ols.src.prompts.prompt_generator import GeneratePrompt
+from langchain_core.messages import AIMessage, HumanMessage
 
 # providers and models used by parametrized benchmarks
 provider_and_model = (
@@ -38,8 +39,8 @@ def empty_history():
 
 
 def conversation_history():
     """Non-empty conversation history."""
     return [
-        "First human message",
-        "First AI response",
+        HumanMessage("First human message"),
+        AIMessage("First AI response"),
     ] * 50
@@ -47,8 +48,8 @@ def long_history():
     """Long conversation history."""
     return [
-        "First human message",
-        "First AI response",
+        HumanMessage("First human message"),
+        AIMessage("First AI response"),
     ] * 10000
diff --git a/tests/unit/prompts/test_prompt_generator.py b/tests/unit/prompts/test_prompt_generator.py
index 1bd148ac..63cd3d5d 100644
--- a/tests/unit/prompts/test_prompt_generator.py
+++ b/tests/unit/prompts/test_prompt_generator.py
@@ -16,7 +16,7 @@
 )
 from ols.src.prompts.prompt_generator import (
     GeneratePrompt,
-    #restructure_history,
+    restructure_history,
     restructure_rag_context_post,
     restructure_rag_context_pre,
 )
@@ -39,10 +39,10 @@ def _restructure_prompt_input(rag_context, conversation_history, model):
         restructure_rag_context_post(restructure_rag_context_pre(text, model), model)
         for text in rag_context
     ]
-    # history_formatted = [
-    #     restructure_history(history, model) for history in conversation_history
-    # ]
-    return rag_formatted, conversation_history
+    history_formatted = [
+        restructure_history(history, model) for history in conversation_history
+    ]
+    return rag_formatted, history_formatted
 
 
 @pytest.mark.parametrize("model", model)
diff --git a/tests/unit/query_helpers/test_docs_summarizer.py b/tests/unit/query_helpers/test_docs_summarizer.py
index 7aeb9b21..2681d042 100644
--- a/tests/unit/query_helpers/test_docs_summarizer.py
+++ b/tests/unit/query_helpers/test_docs_summarizer.py
@@ -15,6 +15,7 @@
 from tests.mock_classes.mock_llama_index import MockLlamaIndex
 from tests.mock_classes.mock_llm_chain import mock_llm_chain
 from tests.mock_classes.mock_llm_loader import mock_llm_loader
+from langchain_core.messages import HumanMessage
 
 conversation_id = suid.get_suid()
@@ -114,7 +115,7 @@ def test_summarize_truncation():
     rag_index = MockLlamaIndex()
 
     # too long history
-    history = ["human: What is Kubernetes?"] * 10000
+    history = [HumanMessage("What is Kubernetes?")] * 10000
     summary = summarizer.create_response(question, rag_index, history)
 
     # truncation should be done
diff --git a/tests/unit/utils/test_token_handler.py b/tests/unit/utils/test_token_handler.py
index 7befa2aa..f1d7e404 100644
--- a/tests/unit/utils/test_token_handler.py
+++ b/tests/unit/utils/test_token_handler.py
@@ -8,6 +8,7 @@
 from ols.constants import TOKEN_BUFFER_WEIGHT, ModelFamily
 from ols.utils.token_handler import PromptTooLongError, TokenHandler
 from tests.mock_classes.mock_retrieved_node import MockRetrievedNode
+from langchain_core.messages import HumanMessage, AIMessage
 
 
 class TestTokenHandler(TestCase):
@@ -198,12 +199,12 @@ def test_limit_conversation_history_when_no_history_exists(self):
     def test_limit_conversation_history(self):
         """Check the behaviour of limiting long conversation history."""
         history = [
-            "human: first message from human",
-            "ai: first answer from AI",
-            "human: second message from human",
-            "ai: second answer from AI",
-            "human: third message from human",
-            "ai: third answer from AI",
+            HumanMessage("first message from human"),
+            AIMessage("first answer from AI"),
+            HumanMessage("second message from human"),
+            AIMessage("second answer from AI"),
+            HumanMessage("third message from human"),
+            AIMessage("third answer from AI"),
         ]
         # for each of the above actual messages the tokens count is 4.
         # then 2 tokens for the tags. Total tokens are 6.
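
Note, not part of the patch: the snippet below is a standalone sketch of the behavior the diff enables, handy for checking the expected strings locally. It re-implements restructure_history instead of importing ols.src.prompts.prompt_generator, substitutes a literal "granite" substring check for the ModelFamily.GRANITE constant, and uses illustrative model names; only langchain_core is required to run it.

from copy import copy

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage


def restructure_history(message: BaseMessage, model: str) -> BaseMessage:
    """Local mirror of the patched helper: wrap Granite history in role tags."""
    if "granite" not in model:  # stand-in for the ModelFamily.GRANITE check
        return message
    # Work on a copy so the original history entry is left unmodified.
    new_message = copy(message)
    tag = "<|user|>" if isinstance(message, HumanMessage) else "<|assistant|>"
    new_message.content = f"\n{tag}\n{message.content}"
    return new_message


history = [
    HumanMessage("What is Kubernetes?"),
    AIMessage("A container orchestration platform."),
]

# Non-Granite models: messages pass through unchanged.
assert restructure_history(history[0], "gpt-4").content == "What is Kubernetes?"

# Granite models: each message gains a role tag; joining the contents gives the
# chat_history string that _generate_prompt_granite now builds.
formatted = [restructure_history(m, "granite-13b-chat").content for m in history]
assert formatted[0] == "\n<|user|>\nWhat is Kubernetes?"
assert formatted[1] == "\n<|assistant|>\nA container orchestration platform."

# limit_conversation_history now counts tokens of "type: content" per message,
# e.g. "human: What is Kubernetes?", rather than of a bare string.
print(f"{history[0].type}: {history[0].content}")

Returning a copy rather than mutating message.content in place keeps the BaseMessage objects held in the conversation history untouched, which appears to be why the reworked helper copies before tagging.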