feat: LLMClient improve model retrieval
bearlike committed Dec 19, 2024
1 parent 1ba3e08 commit a577cf9
Showing 1 changed file with 45 additions and 23 deletions.
68 changes: 45 additions & 23 deletions python/openwebui/pipe_mcts.py
@@ -29,13 +29,13 @@
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema import AIMessage, HumanMessage, BaseMessage


from pydantic import BaseModel, Field

from open_webui.constants import TASKS
from open_webui.apps.ollama import main as ollama

# * Patch for user-id missing in the request

from types import SimpleNamespace

# Import Langfuse for logging/tracing (optional)
@@ -63,6 +63,7 @@

# =============================================================================


class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
def __init__(self):
self.queue = asyncio.Queue()
@@ -86,11 +87,13 @@ async def __aiter__(self):
break
yield token


class LLMClient:
def __init__(self, valves: "Pipe.Valves"):
def __init__(self, valves: "Pipe.Valves", user_mod=None):
self.openai_api_key = valves.OPENAI_API_KEY
self.openai_api_base_url = valves.OPENAI_API_BASE_URL
self.ollama_api_base_url = valves.OLLAMA_API_BASE_URL
self.__user__ = user_mod

async def create_chat_completion(
self, messages: list, model: str, backend: str, stream: bool = False
@@ -132,7 +135,8 @@ async def create_chat_completion
return ai_message.content
elif backend == "ollama":
response = await ollama.generate_openai_chat_completion(
{"model": model, "messages": messages, "stream": stream}
{"model": model, "messages": messages, "stream": stream},
user=self.__user__,
)
return response
else:
@@ -184,10 +188,12 @@ def get_chunk_content(self, chunk):
except json.JSONDecodeError:
logger.error(f'ChunkDecodeError: unable to parse "{chunk_str[:100]}"')


# =============================================================================

# MCTS Classes


class Node:
def __init__(
self,
@@ -240,6 +246,7 @@ def mermaid(self, offset=0, selected=None):
logger.debug(f"Node Mermaid:\n{msg}")
return msg


class MCTSAgent:
def __init__(
self,
@@ -483,6 +490,7 @@ async def emit_replace(self, content: str):
if self.event_emitter:
await self.event_emitter({"type": "replace", "data": {"content": content}})


# =============================================================================

# Prompts
@@ -543,33 +551,32 @@

# =============================================================================

# Pipe Class


class Pipe:
class Valves(BaseModel):
# ! FIX: User Provided Valves not being used. Only defaults used.
# Manually set the default values for the valves
OPENAI_API_KEY: str = Field(
default="sk-CHANGE-ME", description="OpenAI API key"
default="sk-Change-Me", description="OpenAI API key"
)
OPENAI_API_BASE_URL: str = Field(
default="http://litellm:4000/v1", description="OpenAI API base URL"
)
OLLAMA_API_BASE_URL: str = Field(
default="http://avalanche.lan:11434", description="Ollama API base URL"
)
USE_OPENAI: bool = Field(
default=True, description="Whether to use OpenAI endpoints"
default="http://ollama.lan:11434", description="Ollama API base URL"
)
LANGFUSE_SECRET_KEY: str = Field(
default="sk-change-me",
default="sk-Change-Me",
description="Langfuse secret key",
)
LANGFUSE_PUBLIC_KEY: str = Field(
default="pk-change-me",
default="pk-Change-Me",
description="Langfuse public key",
)
LANGFUSE_URL: str = Field(
default="<http://langfuse-server:3000>", description="Langfuse URL"
default="http://langfuse-server:3000", description="Langfuse URL"
)
EXPLORATION_WEIGHT: float = Field(
default=1.414, description="Exploration weight for MCTS"
@@ -584,7 +591,7 @@ class Valves(BaseModel):
default=2, description="Maximum number of children per node in MCTS"
)
OLLAMA_MODELS: str = Field(
default="Ollama/.tulu3:8b,Ollama/.llama3.2-vision:11b",
default="Ollama/Avalanche/.tulu3:8b,Ollama/Avalanche/.llama3.2-vision:11b",
description="Comma-separated list of Ollama model IDs",
)
OPENAI_MODELS: str = Field(
@@ -616,29 +623,36 @@ def __init__(self):
)

def pipes(self) -> List[dict]:
# Hardcode models from valves
if self.valves.USE_OPENAI:
openai_models_str = self.valves.OPENAI_MODELS
# Collect models from both OpenAI and Ollama
model_list = []

# Get OpenAI models
openai_models_str = self.valves.OPENAI_MODELS
if openai_models_str:
openai_models = [
model.strip() for model in openai_models_str.split(",") if model.strip()
]
model_list = [
openai_model_list = [
{"id": f"mcts/openai/{model}", "name": f"MCTS/{model}"}
for model in openai_models
]
logger.debug(f"Available OpenAI models: {model_list}")
return model_list
else:
ollama_models_str = self.valves.OLLAMA_MODELS
logger.debug(f"Available OpenAI models: {openai_model_list}")
model_list.extend(openai_model_list)

# Get Ollama models
ollama_models_str = self.valves.OLLAMA_MODELS
if ollama_models_str:
ollama_models = [
model.strip() for model in ollama_models_str.split(",") if model.strip()
]
model_list = [
ollama_model_list = [
{"id": f"mcts/ollama/{model}", "name": f"MCTS/{model}"}
for model in ollama_models
]
logger.debug(f"Available Ollama models: {model_list}")
return model_list
logger.debug(f"Available Ollama models: {ollama_model_list}")
model_list.extend(ollama_model_list)

return model_list

async def pipe(
self,
Expand All @@ -664,6 +678,13 @@ async def pipe(

self.backend = backend
self.model = model_name
# To ensure __user__ is an object with 'id' and 'role' attributes
if __user__ is None or not isinstance(__user__, dict):
self.__user__ = SimpleNamespace(id=None, role="admin")
else:
self.__user__ = SimpleNamespace(**__user__)

self.llm_client.__user__ = self.__user__

messages = body.get("messages")
if not messages:
@@ -707,4 +728,5 @@

# Run MCTS search
best_answer = await mcts_agent.search()

return ""
