fix formatting

Signed-off-by: Stephanie <yangcao@redhat.com>
yangcao77 committed Jan 22, 2025
1 parent 5448fd3 commit ea52586

Showing 12 changed files with 86 additions and 59 deletions.
30 changes: 14 additions & 16 deletions ols/app/endpoints/conversations.py
@@ -26,8 +26,6 @@
 )
 from ols.src.auth.auth import get_auth_dependency
 
-from langchain_core.messages import BaseMessage
-
 logger = logging.getLogger(__name__)
 
 router = APIRouter(tags=["conversations"])
@@ -85,7 +83,7 @@ def get_conversation(
     chat_history = CacheEntry.cache_entries_to_history(
         retrieve_previous_input(user_id, conversation_id, skip_user_id_check)
     )
-    if chat_history.__len__() == 0:
+    if len(chat_history) == 0:
         logger.info(
             "No chat history found for user: %s with conversation_id: %s",
             user_id,
@@ -154,19 +152,19 @@ def delete_conversation(
         return ConversationDeletionResponse(
             response=f"Conversation {conversation_id} successfully deleted"
         )
-    else:
-        logger.info(
-            "No chat history found for user: %s with conversation_id: %s",
-            user_id,
-            conversation_id,
-        )
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail={
-                "response": "Error deleting conversation",
-                "cause": f"Conversation {conversation_id} not found",
-            },
-        )
+
+    logger.info(
+        "No chat history found for user: %s with conversation_id: %s",
+        user_id,
+        conversation_id,
+    )
+    raise HTTPException(
+        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+        detail={
+            "response": "Error deleting conversation",
+            "cause": f"Conversation {conversation_id} not found",
+        },
+    )
 
 
 list_conversations_response: dict[int | str, dict[str, Any]] = {
5 changes: 3 additions & 2 deletions ols/app/endpoints/ols.py
@@ -11,6 +11,7 @@
 
 import pytz
 from fastapi import APIRouter, Depends, HTTPException, status
+from langchain_core.messages import AIMessage, HumanMessage
 
 from ols import config, constants
 from ols.app import metrics
@@ -28,14 +29,14 @@
     UnauthorizedResponse,
 )
 from ols.customize import keywords, prompts
-from ols.src.auth.auth import get_auth_dependency, noop
+from ols.src.auth.auth import get_auth_dependency
 from ols.src.llms.llm_loader import LLMConfigurationError, resolve_provider_config
 from ols.src.query_helpers.attachment_appender import append_attachments_to_query
 from ols.src.query_helpers.docs_summarizer import DocsSummarizer
 from ols.src.query_helpers.question_validator import QuestionValidator
 from ols.utils import errors_parsing, suid
 from ols.utils.token_handler import PromptTooLongError
-from langchain_core.messages import AIMessage, HumanMessage
 
 
 logger = logging.getLogger(__name__)
 
58 changes: 47 additions & 11 deletions ols/app/models/models.py
@@ -8,13 +8,13 @@
 from pydantic import BaseModel, field_validator, model_validator
 from pydantic.dataclasses import dataclass
 
+from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
+
 from ols.constants import MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT
 from ols.customize import prompts
 from ols.utils import suid
 
 
-from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
-
 
 class Attachment(BaseModel):
     """Model representing an attachment that can be send from UI as part of query.
@@ -712,27 +712,63 @@ def cache_entries_to_history(


 class MessageEncoder(json.JSONEncoder):
-    def default(self, obj):
-        if isinstance(obj, (HumanMessage, AIMessage)):
+    """Convert Message objects to serializable dictionaries.
+
+    Args:
+        o: The object to serialize. Expected to be either a HumanMessage
+            or AIMessage instance.
+
+    Returns:
+        dict: A dictionary containing the message attributes if the input is
+            a Message object.
+    """
+
+    def default(self, o):
+        if isinstance(o, (HumanMessage, AIMessage)):
             return {
-                "type": obj.type,
-                "content": obj.content,
-                "response_metadata": obj.response_metadata,
-                "additional_kwargs": obj.additional_kwargs,
+                "type": o.type,
+                "content": o.content,
+                "response_metadata": o.response_metadata,
+                "additional_kwargs": o.additional_kwargs,
             }
-        return super().default(obj)
+        return super().default(o)


 class MessageDecoder(json.JSONDecoder):
+    """Custom JSON decoder for deserializing Message objects.
+
+    This decoder extends the default JSONDecoder to handle JSON representations of
+    HumanMessage and AIMessage objects, converting them back into their respective
+    Python objects. It processes JSON objects containing 'type', 'content',
+    'response_metadata', and 'additional_kwargs' fields.
+
+    Example:
+        >>> decoder = MessageDecoder()
+        >>> json.loads('{"type": "human", "content": "Hello", ...}', cls=MessageDecoder)
+        HumanMessage(content="Hello", ...)
+    """
+
     def __init__(self, *args, **kwargs):
-        super().__init__(object_hook=self.object_hook, *args, **kwargs)
+        """Initialize the MessageDecoder with custom object hook."""
+        super().__init__(object_hook=self._decode_message, *args, **kwargs)
 
-    def object_hook(self, dct):
+    def _decode_message(self, dct):
+        """Decode JSON dictionary into Message objects if applicable.
+
+        Args:
+            dct (dict): Dictionary to decode, potentially representing a Message.
+
+        Returns:
+            Union[HumanMessage, AIMessage, dict]: A Message object if the input
+            dictionary represents a message, otherwise returns the original dictionary.
+        """
         if "type" in dct:
             if dct["type"] == "human":
                 message = HumanMessage(content=dct["content"])
             elif dct["type"] == "ai":
                 message = AIMessage(content=dct["content"])
             else:
                 return dct
             message.additional_kwargs = dct["additional_kwargs"]
             message.response_metadata = dct["response_metadata"]
             return message
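
A quick round-trip sketch of the encoder/decoder pair above (not part of the commit; it assumes MessageEncoder and MessageDecoder are exported from ols.app.models.models as defined in this diff):

    # Serialize a message with MessageEncoder, then restore it with
    # MessageDecoder; the metadata fields survive the round trip.
    import json

    from langchain_core.messages import HumanMessage

    from ols.app.models.models import MessageDecoder, MessageEncoder

    original = HumanMessage(content="Hello")
    original.response_metadata = {"model": "granite"}

    encoded = json.dumps(original, cls=MessageEncoder)
    decoded = json.loads(encoded, cls=MessageDecoder)

    assert isinstance(decoded, HumanMessage)
    assert decoded.content == "Hello"
    assert decoded.response_metadata == {"model": "granite"}
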
2 changes: 1 addition & 1 deletion ols/src/cache/in_memory_cache.py
@@ -128,7 +128,7 @@ def list(self, user_id: str, skip_user_id_check: bool = False) -> list[str]:
         prefix = f"{user_id}{Cache.COMPOUND_KEY_SEPARATOR}"
 
         with self._lock:
-            for key in self.cache.keys():
+            for key in self.cache:
                 if key.startswith(prefix):
                     # Extract conversation_id from the key
                     conversation_id = key[len(prefix) :]
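
For context, a minimal sketch of the listing logic above (not part of the commit; the ":" separator is an illustrative assumption, the real value comes from Cache.COMPOUND_KEY_SEPARATOR):

    # Keys are compound "user_id<sep>conversation_id" strings; listing a
    # user's conversations is a prefix scan that strips the prefix.
    cache = {
        "user-1:conv-a": "entry",
        "user-1:conv-b": "entry",
        "user-2:conv-c": "entry",
    }
    prefix = "user-1:"  # assumed ":" separator for illustration
    conversation_ids = [key[len(prefix):] for key in cache if key.startswith(prefix)]
    assert conversation_ids == ["conv-a", "conv-b"]
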
23 changes: 8 additions & 15 deletions ols/src/prompts/prompt_generator.py
@@ -1,17 +1,17 @@
 """Prompt generator based on model / context."""
 
-from langchain_core.messages import AIMessage, HumanMessage
+from copy import copy
+
+from langchain_core.messages import HumanMessage, BaseMessage
 from langchain_core.prompts import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
     MessagesPlaceholder,
     PromptTemplate,
     SystemMessagePromptTemplate,
 )
-from copy import copy
 
 from ols.constants import ModelFamily
 from ols.customize import prompts
-from langchain_core.messages import BaseMessage
 
 
 def restructure_rag_context_pre(text: str, model: str) -> str:
@@ -34,13 +34,13 @@ def restructure_history(message: BaseMessage, model: str) -> BaseMessage:
         # No processing required here for gpt.
         return message
 
-    newMessage = copy(message)
+    new_message = copy(message)
     # Granite specific formatting for history
     if isinstance(message, HumanMessage):
-        newMessage.content = "\n<|user|>\n" + message.content
+        new_message.content = "\n<|user|>\n" + message.content
     else:
-        newMessage.content = "\n<|assistant|>\n" + message.content
-    return newMessage
+        new_message.content = "\n<|assistant|>\n" + message.content
+    return new_message
 
 
 class GeneratePrompt:
@@ -72,12 +72,6 @@ def _generate_prompt_gpt(self) -> tuple[ChatPromptTemplate, dict]:
         )
 
         if len(self._history) > 0:
-            # chat_history = []
-            # for h in self._history:
-            #     if h.type == "human":
-            #         chat_history.append(HumanMessage(content=h.removeprefix("human: ")))
-            #     else:
-            #         chat_history.append(AIMessage(content=h.removeprefix("ai: ")))
             llm_input_values["chat_history"] = self._history
 
         sys_intruction = (
@@ -113,7 +107,6 @@ def _generate_prompt_granite(self) -> tuple[PromptTemplate, dict]:
             llm_input_values["chat_history"] = ""
             for message in self._history:
                 llm_input_values["chat_history"] += message.content
-            # llm_input_values["chat_history"] = "".join(self._history)
 
         if "context" in llm_input_values:
             prompt_message = prompt_message + "\n{context}"
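
A short usage sketch for the renamed helper above (not part of the commit; passing "granite" to reach the Granite branch is an assumption, since the model-family check sits in the unexpanded lines of the hunk):

    from langchain_core.messages import AIMessage, HumanMessage

    from ols.src.prompts.prompt_generator import restructure_history

    user_turn = HumanMessage(content="How do I scale a deployment?")
    ai_turn = AIMessage(content="Use the scale subcommand.")

    # Granite history gets role tags prepended; the inputs stay untouched
    # because restructure_history mutates a copy of each message.
    formatted = restructure_history(user_turn, "granite")
    assert formatted.content == "\n<|user|>\nHow do I scale a deployment?"
    assert user_turn.content == "How do I scale a deployment?"

    formatted_ai = restructure_history(ai_turn, "granite")
    assert formatted_ai.content == "\n<|assistant|>\nUse the scale subcommand."
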
3 changes: 2 additions & 1 deletion ols/src/query_helpers/docs_summarizer.py
@@ -5,6 +5,7 @@
 
 from langchain.chains import LLMChain
 from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.messages import AIMessage, BaseMessage
 from llama_index.core import VectorStoreIndex
 
 from ols import config
@@ -15,7 +16,7 @@
 from ols.src.prompts.prompt_generator import GeneratePrompt
 from ols.src.query_helpers.query_helper import QueryHelper
 from ols.utils.token_handler import TokenHandler
-from langchain_core.messages import AIMessage, BaseMessage
+
 
 logger = logging.getLogger(__name__)
 
3 changes: 1 addition & 2 deletions ols/utils/token_handler.py
@@ -5,6 +5,7 @@
 
 from llama_index.core.schema import NodeWithScore
 from tiktoken import get_encoding
+from langchain_core.messages import BaseMessage
 
 from ols.app.models.models import RagChunk
 from ols.constants import (
@@ -19,8 +20,6 @@
     restructure_rag_context_pre,
 )
 
-from langchain_core.messages import BaseMessage
-
 logger = logging.getLogger(__name__)
 
2 changes: 1 addition & 1 deletion tests/integration/test_authorized_noop.py
@@ -46,4 +46,4 @@ def test_authorized():
     assert response is not None
     assert response.user_id == user_id_in_request
     assert response.username == constants.DEFAULT_USER_NAME
-    assert response.skip_user_id_check == True
+    assert response.skip_user_id_check is True
11 changes: 4 additions & 7 deletions tests/integration/test_conversations.py
@@ -1,8 +1,9 @@
-import pytest
-from fastapi.testclient import TestClient
-from unittest.mock import patch
+"""Integration tests for /conversations REST API endpoints."""
+
+from unittest.mock import patch
+
+import pytest
+import requests
+from fastapi.testclient import TestClient
 
 from ols import config
 from ols.utils import suid
@@ -25,10 +26,6 @@ def _setup():
 @pytest.mark.parametrize("endpoint", ("/conversations/{conversation_id}",))
 def test_get_conversation_with_history(_setup, endpoint):
     """Test getting conversation history after creating some chat history."""
-    # we need to import it here because these modules triggers config
-    # load too -> causes exception in auth module because of missing config
-    # values
-    from ols.app.models.models import CacheEntry
 
     ml = mock_langchain_interface("test response")
     with (
3 changes: 2 additions & 1 deletion tests/integration/test_ols.py
@@ -6,6 +6,7 @@
 import pytest
 import requests
 from fastapi.testclient import TestClient
+from langchain_core.messages import AIMessage, HumanMessage
 
 from ols import config, constants
 from ols.app.models.config import (
@@ -20,7 +21,7 @@
 from tests.mock_classes.mock_langchain_interface import mock_langchain_interface
 from tests.mock_classes.mock_llm_chain import mock_llm_chain
 from tests.mock_classes.mock_llm_loader import mock_llm_loader
-from langchain_core.messages import AIMessage, HumanMessage
+
 
 
 @pytest.fixture(scope="function")
3 changes: 2 additions & 1 deletion tests/integration/test_redis.py
@@ -1,11 +1,12 @@
 """Integration tests for real Redis behaviour."""
 
 import pytest
+from langchain_core.messages import AIMessage, HumanMessage
 
 from ols.app.models.config import RedisConfig
 from ols.app.models.models import CacheEntry
 from ols.src.cache.redis_cache import RedisCache
-from langchain_core.messages import AIMessage, HumanMessage
+
 
 USER_ID = "00000000-0000-0000-0000-000000000001"
 CONVERSATION_ID = "00000000-0000-0000-0000-000000000002"
2 changes: 1 addition & 1 deletion tests/mock_classes/mock_redis_client.py
@@ -61,6 +61,6 @@ def keys(self, pattern):
 
         # Use fnmatch to match keys against the pattern
         matching_keys = [
-            key for key in self.cache.keys() if fnmatch.fnmatch(key, pattern)
+            key for key in self.cache if fnmatch.fnmatch(key, pattern)
         ]
         return matching_keys
