From d3a34f9778081b19d713fdc820ae06cf518addb8 Mon Sep 17 00:00:00 2001 From: Justin Hayes <52832301+justinh-rahb@users.noreply.github.com> Date: Wed, 9 Aug 2023 09:15:50 -0400 Subject: [PATCH] API_URL variable added (#41) --- main.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 17794e3..54c60c9 100644 --- a/main.py +++ b/main.py @@ -22,7 +22,7 @@ # Try to get the system prompt from an environment variable try: - system_prompt = os.getenv('SYSTEM_PROMPT', 'You are a helpful assistant.') + SYSTEM_PROMPT = os.getenv('SYSTEM_PROMPT', 'You are a helpful assistant.') except Exception as e: print(f"Error getting system prompt: {str(e)}") @@ -50,6 +50,10 @@ except Exception as e: print(f"Error getting MAX_TOKENS_OUTPUT: {str(e)}") +# Try to get the chat completions API endpoint from an environment variable +# Example: https://example.com:8000/v1/chat/completions +API_URL = os.getenv('API_URL') # Defaults to OpenAI API if not set + # Define globals user_sessions = {} # A dictionary to track the AIChat instances for each user turn_counts = {} # A dictionary to track the turn count for each user @@ -121,12 +125,12 @@ def handle_message(user_id, user_message): last_received_time = last_received_times.get(user_id) # Count the tokens in the user message - num_tokens = num_tokens_from_string(user_message + system_prompt) + num_tokens = num_tokens_from_string(user_message + SYSTEM_PROMPT) # If the user types '/reset', reset the session if user_message.strip().lower() == '/reset': - ai_chat = AIChat(api_key=openai_api_key, system=system_prompt, model=MODEL_NAME, params=params) - user_sessions[user_id] = ai_chat + if user_id in user_sessions: + del user_sessions[user_id] # Delete the user's chat session if it exists turn_count = 0 bot_message = "Your session has been reset. How can I assist you now?" @@ -173,7 +177,10 @@ def handle_message(user_id, user_message): # If it's not a slash command, handle it normally else: if ai_chat is None or turn_count >= MAX_TURNS or (last_received_time is not None and (current_time - last_received_time).total_seconds() > TTL): - ai_chat = AIChat(api_key=openai_api_key, system=system_prompt, model=MODEL_NAME, params=params) + if API_URL: + ai_chat = AIChat(api_key=None, api_url=API_URL, system=SYSTEM_PROMPT, params=params) + else: + ai_chat = AIChat(api_key=openai_api_key, system=SYSTEM_PROMPT, model=MODEL_NAME, params=params) user_sessions[user_id] = ai_chat turn_count = 0