# Changes to make this compatible with running purely on a GPU #56
````diff
@@ -0,0 +1,11 @@
+### Instructions
+Generate a SQL query to answer the following question:
+`{user_question}`
+
+### Schema
+The query will run on a database with the following schema:
+{table_metadata_string}
+
+### SQL
+Here is a query that answers the question `{user_question}`
+```
````
```diff
@@ -10,7 +10,7 @@
 if os.getenv("TOKENIZERS_PARALLELISM") is None:
     os.environ["TOKENIZERS_PARALLELISM"] = "false"

-encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")
+encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

 nlp = spacy.load("en_core_web_sm")
```

**Review comment:** I was thinking of having a variable at the top:

```python
# at the top
user_device = "cuda" if torch.cuda.is_available() else "cpu"
encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=user_device)
...
query_emb = encoder.encode(query, convert_to_tensor=True, device=user_device)
...
column_emb = column_emb.to(user_device)
```

The other benefit is that we can easily override `user_device`.
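A runnable sketch of the reviewer's suggestion, with a simple override hook added on top. Note that `pick_device` and the `USER_DEVICE` environment variable are hypothetical names introduced here for illustration, not part of the PR:

```python
import os

def pick_device() -> str:
    """Choose a device string for the encoder.

    Honors a USER_DEVICE environment variable (hypothetical name) as an
    explicit override, otherwise falls back to CUDA when available.
    """
    override = os.getenv("USER_DEVICE")
    if override:
        return override
    try:
        import torch
        return "cuda" if torch.cuda.is_available() else "cpu"
    except ImportError:
        # torch not installed; CPU is the only option
        return "cpu"
```

The encoder would then be constructed once with `SentenceTransformer(..., device=pick_device())`, and every `encode` call can reuse the same value.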
```diff
@@ -23,7 +23,7 @@ def knn(
     """
     Get top most similar columns' embeddings to query using cosine similarity.
     """
-    query_emb = encoder.encode(query, convert_to_tensor=True, device="cpu").unsqueeze(0)
+    query_emb = encoder.encode(query, convert_to_tensor=True).unsqueeze(0)
     similarity_scores = F.cosine_similarity(query_emb, all_emb)
     top_results = torch.nonzero(similarity_scores > threshold).squeeze()
     # if top_results is empty, return empty tensors
```
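The thresholded cosine-similarity lookup inside `knn` can be sketched with NumPy; this is a simplified stand-in (the real code operates on torch tensors produced by the encoder, and `knn_sketch` is a hypothetical name):

```python
import numpy as np

def knn_sketch(query_emb, all_emb, k, threshold):
    # cosine similarity between one query vector and each row of all_emb
    q = query_emb / np.linalg.norm(query_emb)
    rows = all_emb / np.linalg.norm(all_emb, axis=1, keepdims=True)
    scores = rows @ q
    # keep only indices whose similarity clears the threshold
    idx = np.nonzero(scores > threshold)[0]
    # sort the survivors by score, descending, then truncate to k
    top = idx[np.argsort(scores[idx])[::-1]][:k]
    return scores[top], top

rng = np.random.default_rng(0)
scores, idx = knn_sketch(rng.normal(size=4), rng.normal(size=(10, 4)), k=3, threshold=0.0)
```

Because both inputs are normalized, the dot product is exactly cosine similarity, matching what `F.cosine_similarity` computes in the diff above.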
```diff
@@ -95,6 +95,10 @@ def get_md_emb(
     3. Generate the metadata string using the column info so far.
     4. Get joinable columns between tables in topk_table_columns and add to final metadata string.
     """
+    if torch.cuda.is_available():
+        column_emb = column_emb.to("cuda")
+    else:
+        column_emb = column_emb.to("cpu")
     # 1) get top k columns
     top_k_scores, top_k_indices = knn(question, column_emb, k, threshold)
     topk_table_columns = {}
```
**Review comment:** Checking: is this an improved version of prompt.md? If so, shall we just replace the existing prompt.md with this?