From 1ba3e083724abc591ec2103375586dcc3366f6c5 Mon Sep 17 00:00:00 2001 From: Krishnakanth Alagiri Date: Thu, 19 Dec 2024 03:49:47 -0800 Subject: [PATCH] Migrated to langchain & improved rendering emits --- python/openwebui/pipe_mcts.py | 258 ++++++++++++++++++++++++++++------ 1 file changed, 217 insertions(+), 41 deletions(-) diff --git a/python/openwebui/pipe_mcts.py b/python/openwebui/pipe_mcts.py index 96e1f3b..b0237e8 100644 --- a/python/openwebui/pipe_mcts.py +++ b/python/openwebui/pipe_mcts.py @@ -1,6 +1,6 @@ """ title: MCTS Answer Generation Pipe -author: https://github.com/bearlike/scripts +author: KK description: Monte Carlo Tree Search Pipe Addon for OpenWebUI with support for OpenAI and Ollama endpoints. version: 1.0.0 """ @@ -23,9 +23,19 @@ Generator, Iterator, ) + +import openai +from langchain_openai import ChatOpenAI +from langchain.callbacks.base import AsyncCallbackHandler +from langchain.schema import AIMessage, HumanMessage, BaseMessage + from pydantic import BaseModel, Field + from open_webui.constants import TASKS from open_webui.apps.ollama import main as ollama + +# * Patch for user-id missing in the request + from types import SimpleNamespace # Import Langfuse for logging/tracing (optional) @@ -53,8 +63,28 @@ # ============================================================================= -# LLM Client +class AsyncIteratorCallbackHandler(AsyncCallbackHandler): + def __init__(self): + self.queue = asyncio.Queue() + self.done = False + + async def on_llm_new_token(self, token: str, **kwargs): + await self.queue.put(token) + + async def on_llm_end(self, response: AIMessage, **kwargs): + self.done = True + await self.queue.put(None) # Signal completion + async def on_llm_error(self, error: Exception, **kwargs): + self.done = True + await self.queue.put(None) # Signal completion + + async def __aiter__(self): + while not self.done: + token = await self.queue.get() + if token is None: + break + yield token class LLMClient: def __init__(self, valves: "Pipe.Valves"): @@ -66,15 +96,40 @@ async def create_chat_completion( self, messages: list, model: str, backend: str, stream: bool = False ): if backend == "openai": - import openai - - openai.api_key = self.openai_api_key - openai.api_base = self.openai_api_base_url - # ! 
FIX: ChatCompletion is no longer supported - response = await openai.ChatCompletion.acreate( - model=model, messages=messages, stream=stream - ) - return response + # Convert messages to LangChain's Message objects + lc_messages = [] + for msg in messages: + if msg["role"] == "user": + lc_messages.append(HumanMessage(content=msg["content"])) + elif msg["role"] == "assistant": + lc_messages.append(AIMessage(content=msg["content"])) + else: + lc_messages.append(HumanMessage(content=msg["content"])) + + if stream: + # Create a callback handler to capture streamed tokens + handler = AsyncIteratorCallbackHandler() + oai_model = ChatOpenAI( + base_url=self.openai_api_base_url, + api_key=self.openai_api_key, + model=model, + streaming=True, + callbacks=[handler], # Pass the handler here + ) + # Call agenerate with messages + asyncio.create_task(oai_model.agenerate([lc_messages])) + return handler # Return the handler to iterate over + else: + oai_model = ChatOpenAI( + base_url=self.openai_api_base_url, + api_key=self.openai_api_key, + model=model, + streaming=False, + ) + response = await oai_model.agenerate([lc_messages]) + # Extract the AIMessage from the response + ai_message = response.generations[0][0].message + return ai_message.content elif backend == "ollama": response = await ollama.generate_openai_chat_completion( {"model": model, "messages": messages, "stream": stream} @@ -90,11 +145,9 @@ async def get_streaming_completion( messages, model, backend=backend, stream=True ) if backend == "openai": - async for chunk in response: - if "choices" in chunk and len(chunk["choices"]) > 0: - delta = chunk["choices"][0]["delta"] - if "content" in delta: - yield delta["content"] + # response is the AsyncIteratorCallbackHandler + async for token in response: + yield token elif backend == "ollama": async for chunk in response.body_iterator: for part in self.get_chunk_content(chunk): @@ -105,9 +158,11 @@ async def get_completion(self, messages: list, model: str, backend: str) -> str: messages, model, backend=backend, stream=False ) if backend == "openai": - return response.choices[0].message.content + # response is a string containing the content + content = response elif backend == "ollama": - return response["choices"][0]["message"]["content"] + content = response["choices"][0]["message"]["content"] + return content def get_chunk_content(self, chunk): # For Ollama only @@ -129,12 +184,10 @@ def get_chunk_content(self, chunk): except json.JSONDecodeError: logger.error(f'ChunkDecodeError: unable to parse "{chunk_str[:100]}"') - # ============================================================================= # MCTS Classes - class Node: def __init__( self, @@ -184,9 +237,9 @@ def mermaid(self, offset=0, selected=None): msg += child.mermaid(offset + 4, selected) msg += f"{padding}{self.id} --> {child.id}\n" + logger.debug(f"Node Mermaid:\n{msg}") return msg - class MCTSAgent: def __init__( self, @@ -206,6 +259,7 @@ def __init__( self.selected = None self.model = model self.backend = backend + self.iteration_responses = [] # List to store responses per iteration async def search(self): max_iterations = self.valves.MAX_ITERATIONS @@ -213,17 +267,84 @@ async def search(self): best_answer = None best_score = -float("inf") - for i in range(max_iterations): - logger.debug(f"MCTS Iteration {i+1}/{max_iterations}") - await self.progress(f"Iteration {i+1}/{max_iterations}") + processed_node_ids = set() # Initialize without root node ID + + # Evaluate the root node's response + root_score = await 
self.evaluate_answer(self.root.content) + self.root.visits += 1 + self.root.value += root_score + processed_node_ids.add(self.root.id) # Add root node ID to processed + + # Add root node's response to iteration_responses as Iteration 0 + self.iteration_responses.append( + { + "iteration": 0, + "responses": [ + { + "node_id": self.root.id, + "content": self.root.content, + "score": root_score, + } + ], + } + ) + + # Emit the initial state (Iteration 0) + await self.emit_iteration_update(0) + + for i in range(1, max_iterations + 1): + logger.debug(f"MCTS Iteration {i}/{max_iterations}") + await self.progress(f"Iteration {i}/{max_iterations}") + + iteration_responses = [] # Responses for this iteration for _ in range(max_simulations): leaf = await self.select(self.root) if not leaf.fully_expanded(): - leaf = await self.expand(leaf) - score = await self.simulate(leaf) - self.backpropagate(leaf, score) - + # Expand the node and get the new child + child = await self.expand(leaf) + # If we haven't processed this child before, collect its response + if child.id not in processed_node_ids: + score = await self.simulate(child) + self.backpropagate(child, score) + iteration_responses.append( + { + "node_id": child.id, + "content": child.content, + "score": score, + } + ) + processed_node_ids.add(child.id) + else: + # If leaf is fully expanded and not processed, process it + if leaf.id not in processed_node_ids and leaf.id != self.root.id: + score = await self.simulate(leaf) + self.backpropagate(leaf, score) + iteration_responses.append( + { + "node_id": leaf.id, + "content": leaf.content, + "score": score, + } + ) + processed_node_ids.add(leaf.id) + else: + # Do nothing if leaf has been processed or is the root node + continue + + # Add the iteration responses to the overall list if there are any new responses + if iteration_responses: + self.iteration_responses.append( + { + "iteration": i, + "responses": iteration_responses, + } + ) + + # Emit the Mermaid diagram and collapsible section + await self.emit_iteration_update(i) + + # Update best answer if necessary current_best = self.root.best_child() current_score = ( current_best.value / current_best.visits @@ -234,16 +355,52 @@ async def search(self): best_score = current_score best_answer = current_best.content - await self.emit_message(f"Best Answer:\n{best_answer}") + await self.emit_message(f"## Best Answer:\n{best_answer}") await self.done() return best_answer + async def emit_iteration_update(self, iteration_number): + """method to emit the diagram and responses""" + # Generate the Mermaid diagram + mermaid_diagram = "```mermaid\ngraph TD\n" + self.root.mermaid() + "\n```\n" + + # Generate the collapsible section with agent responses + collapsible_content = self.generate_collapsible_content() + + # Combine the Mermaid diagram and collapsible content + full_content = mermaid_diagram + "\n\n" + collapsible_content + + # Emit the content to the client + await self.emit_replace(full_content) + + def generate_collapsible_content(self): + """Method to generate collapsible content""" + content = "" + for iteration_info in self.iteration_responses: + iteration = iteration_info["iteration"] + responses = iteration_info["responses"] + + content += "
<details>\n"
+            content += f"<summary>Expand to View Iteration {iteration}</summary>\n\n"
+
+            for resp in responses:
+                node_id = resp["node_id"]
+                response_content = resp["content"]
+                score = resp["score"]
+                content += f"- Node `{node_id}`: Score `{score}`\n"
+                content += f"  - **Response**: {response_content}\n"
+
+            content += "</details>
\n\n" + + return content + async def select(self, node: Node): while node.fully_expanded() and node.children: node = max(node.children, key=lambda n: n.uct_value()) return node async def expand(self, node: Node): + # Expand the node by adding one child thought = await self.generate_thought(node.content) new_content = await self.update_approach(node.content, thought) child = Node( @@ -289,6 +446,7 @@ async def evaluate_answer(self, answer: str): async def generate_completion(self, prompt: str): messages = [{"role": "user", "content": prompt}] content = "" + logger.debug(f"Attempting to stream completion for prompt: {prompt}") async for chunk in self.llm_client.get_streaming_completion( messages, model=self.model, backend=self.backend ): @@ -321,6 +479,9 @@ async def emit_status(self, level: str, message: str, done: bool): } ) + async def emit_replace(self, content: str): + if self.event_emitter: + await self.event_emitter({"type": "replace", "data": {"content": content}}) # ============================================================================= @@ -382,14 +543,16 @@ async def emit_status(self, level: str, message: str, done: bool): # ============================================================================= -# Pipe Class - class Pipe: class Valves(BaseModel): - OPENAI_API_KEY: str = Field(default="", description="OpenAI API key") + # ! FIX: User Provided Valves not being used. Only defaults used. + # Manually set the default values for the valves + OPENAI_API_KEY: str = Field( + default="sk-CHANGE-ME", description="OpenAI API key" + ) OPENAI_API_BASE_URL: str = Field( - default="https://api.openai.com/v1", description="OpenAI API base URL" + default="http://litellm:4000/v1", description="OpenAI API base URL" ) OLLAMA_API_BASE_URL: str = Field( default="http://avalanche.lan:11434", description="Ollama API base URL" @@ -397,9 +560,17 @@ class Valves(BaseModel): USE_OPENAI: bool = Field( default=True, description="Whether to use OpenAI endpoints" ) - LANGFUSE_SECRET_KEY: str = Field(default="", description="Langfuse secret key") - LANGFUSE_PUBLIC_KEY: str = Field(default="", description="Langfuse public key") - LANGFUSE_URL: str = Field(default="", description="Langfuse URL") + LANGFUSE_SECRET_KEY: str = Field( + default="sk-change-me", + description="Langfuse secret key", + ) + LANGFUSE_PUBLIC_KEY: str = Field( + default="pk-change-me", + description="Langfuse public key", + ) + LANGFUSE_URL: str = Field( + default="", description="Langfuse URL" + ) EXPLORATION_WEIGHT: float = Field( default=1.414, description="Exploration weight for MCTS" ) @@ -413,7 +584,7 @@ class Valves(BaseModel): default=2, description="Maximum number of children per node in MCTS" ) OLLAMA_MODELS: str = Field( - default="Ollama/Avalanche/.tulu3:8b,Ollama/Avalanche/.llama3.2-vision:11b", + default="Ollama/.tulu3:8b,Ollama/.llama3.2-vision:11b", description="Comma-separated list of Ollama model IDs", ) OPENAI_MODELS: str = Field( @@ -424,7 +595,10 @@ class Valves(BaseModel): def __init__(self): self.type = "manifold" self.valves = self.Valves( - **{k: os.getenv(k, v.default) for k, v in self.Valves.model_fields.items()} + **{ + k: v.default if v.default is not None else os.getenv(k) + for k, v in self.Valves.model_fields.items() + } ) logger.debug(f"Valves configuration: {self.valves}") self.llm_client = LLMClient(self.valves) @@ -479,7 +653,7 @@ async def pipe( logger.error("No model specified in the request") return "" - pattern = r'^(?:[a-zA-Z0-9_]+\.)?mcts/([^/]+)/(.+)$' + pattern = 
r"^(?:[a-zA-Z0-9_]+\.)?mcts/([^/]+)/(.+)$" match = re.match(pattern, model_id) if match: backend, model_name = match.groups() @@ -503,7 +677,10 @@ async def pipe( # Handle title generation task if __task__ == TASKS.TITLE_GENERATION: - logger.debug(f"Generating title for question: {question} using {self.model} and {self.backend}") + # return f"MCTS: {messages[0]['content']}" + logger.debug( + f"Generating title for question: {question} using {self.model} and {self.backend}" + ) content = await self.llm_client.get_completion( messages, self.model, backend=self.backend ) @@ -530,5 +707,4 @@ async def pipe( # Run MCTS search best_answer = await mcts_agent.search() - return ""
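The core of this patch is the move from the removed openai.ChatCompletion API to LangChain's ChatOpenAI, with streaming handled by a custom AsyncCallbackHandler that feeds tokens into an asyncio.Queue. Below is a minimal, self-contained sketch of that pattern outside the pipe; it assumes langchain-openai is installed and OPENAI_API_KEY is set in the environment, and QueueStreamHandler, the model name, and the prompt are illustrative placeholders rather than code from the patch.

import asyncio

from langchain_openai import ChatOpenAI
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema import HumanMessage


class QueueStreamHandler(AsyncCallbackHandler):
    """Collects streamed tokens into a queue and exposes them as an async iterator."""

    def __init__(self):
        self.queue: asyncio.Queue = asyncio.Queue()

    async def on_llm_new_token(self, token: str, **kwargs):
        await self.queue.put(token)

    async def on_llm_end(self, response, **kwargs):
        await self.queue.put(None)  # sentinel: stream finished

    async def on_llm_error(self, error: BaseException, **kwargs):
        await self.queue.put(None)  # sentinel: stream aborted

    async def __aiter__(self):
        while True:
            token = await self.queue.get()
            if token is None:
                break
            yield token


async def main():
    handler = QueueStreamHandler()
    # Placeholder model name; streaming=True routes tokens through the handler.
    llm = ChatOpenAI(model="gpt-4o-mini", streaming=True, callbacks=[handler])
    # Fire the generation in the background; tokens arrive via the handler.
    task = asyncio.create_task(llm.agenerate([[HumanMessage(content="Say hi")]]))
    async for token in handler:
        print(token, end="", flush=True)
    await task  # surface any exception raised by the generation


asyncio.run(main())

Returning the handler and iterating it with "async for" is what lets get_streaming_completion() yield tokens as they arrive instead of waiting for agenerate() to finish.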
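select() and backpropagate() in the diff depend on Node.uct_value(), Node.best_child(), and the visit/value bookkeeping, none of which appear in the changed hunks. A conventional UCT implementation consistent with the EXPLORATION_WEIGHT valve (1.414, roughly sqrt(2)) would look like the sketch below; this is an assumption based on the standard formula, not the file's actual code.

import math


def uct_value(node_value: float, node_visits: int, parent_visits: int,
              exploration_weight: float = 1.414) -> float:
    """Average reward plus an exploration bonus; unvisited nodes score +inf."""
    if node_visits == 0:
        return float("inf")
    exploitation = node_value / node_visits
    exploration = exploration_weight * math.sqrt(math.log(parent_visits) / node_visits)
    return exploitation + exploration


def backpropagate(node, score: float) -> None:
    """Walk back to the root, crediting every ancestor with the simulation score."""
    while node is not None:
        node.visits += 1
        node.value += score
        node = node.parent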
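The model-routing regex near the end of the patch splits a manifold model ID into a backend and a model name, optionally skipping a dotted prefix (presumably the function ID that OpenWebUI prepends to pipe models; that reading is an assumption). A quick check with made-up IDs:

import re

pattern = r"^(?:[a-zA-Z0-9_]+\.)?mcts/([^/]+)/(.+)$"

# Both IDs below are illustrative, not taken from the patch.
for model_id in ("mcts/openai/gpt-4o-mini", "pipe_mcts.mcts/ollama/tulu3:8b"):
    match = re.match(pattern, model_id)
    print(match.groups())
# ('openai', 'gpt-4o-mini')
# ('ollama', 'tulu3:8b')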