Skip to content

Commit

Permalink
use models
Browse files Browse the repository at this point in the history
  • Loading branch information
lfoppiano committed Dec 15, 2023
1 parent 3865d62 commit 01b5fcd
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
from grobid_client_generic import GrobidClientGeneric

OPENAI_MODELS = ['chatgpt-3.5-turbo',
"gpt-4",
"gpt-4-1106-preview"]

if 'rqa' not in st.session_state:
st.session_state['rqa'] = {}

Expand Down Expand Up @@ -117,17 +121,17 @@ def clear_memory():
# @st.cache_resource
def init_qa(model, api_key=None):
## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
if model == 'chatgpt-3.5-turbo':
if model in OPENAI_MODELS:
st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
if api_key:
chat = ChatOpenAI(model_name="gpt-3.5-turbo",
chat = ChatOpenAI(model_name=model,
temperature=0,
openai_api_key=api_key,
frequency_penalty=0.1)
embeddings = OpenAIEmbeddings(openai_api_key=api_key)

else:
chat = ChatOpenAI(model_name="gpt-3.5-turbo",
chat = ChatOpenAI(model_name=model,
temperature=0,
frequency_penalty=0.1)
embeddings = OpenAIEmbeddings()
Expand Down Expand Up @@ -241,7 +245,7 @@ def play_old_messages():
# os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
st.session_state['rqa'][model] = init_qa(model)

elif model == 'chatgpt-3.5-turbo' and model not in st.session_state['api_keys']:
elif model in OPENAI_MODELS and model not in st.session_state['api_keys']:
if 'OPENAI_API_KEY' not in os.environ:
api_key = st.text_input('OpenAI API Key', type="password")
st.markdown("Get it [here](https://platform.openai.com/account/api-keys)")
Expand Down

0 comments on commit 01b5fcd

Please sign in to comment.