Skip to content

Commit

Permalink
Remove OpenAI semantic search
Browse files Browse the repository at this point in the history
  • Loading branch information
rce committed May 18, 2023
1 parent b2c53e6 commit 7e4a858
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 77 deletions.
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ version: "3"

services:
postgres:
image: ankane/pgvector
image: postgres:15
ports:
- 5432:5432
environment:
Expand Down
6 changes: 0 additions & 6 deletions src/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import asyncpg
import logger
import migration
from pgvector.asyncpg import register_vector

log = logger.get("DATABASE")

Expand All @@ -31,16 +30,11 @@ async def get_pool():
_connect_string,
min_size=MIN_CONNECTION_POOL_SIZE,
max_size=MAX_CONNECTION_POOL_SIZE,
init=init_connection,
)
setattr(_pool_holder, "pool", pool)
return pool


async def init_connection(conn):
    """Prepare a freshly opened connection by registering the vector type.

    Args:
        conn: The connection to set up; it is handed straight to
            ``register_vector``.
    """
    await register_vector(conn)


async def close_pool():
global _pool_holder
pool = getattr(_pool_holder, "pool", None)
Expand Down
5 changes: 5 additions & 0 deletions src/migration/00055-vector-for-openai-embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
async def exec(log, tx):
# The statements below fail if the pgvector extension is not installed,
# which means this migration cannot be run on a fresh database without
# the now-unnecessary pgvector extension.
return

await tx.execute("CREATE EXTENSION vector")
await tx.execute("""
CREATE TABLE openaiembedding (
Expand Down
3 changes: 3 additions & 0 deletions src/migration/00056-remove-pgvector-used-for-embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
async def exec(log, tx):
    """Drop the ``openaiembedding`` table and the ``vector`` extension if present.

    Args:
        log: Logger supplied by the caller; not used by this migration.
        tx: Object exposing an awaitable ``execute(sql)``; each statement is
            run through it in turn.
    """
    # Keep this order: table first, then the extension (the table presumably
    # depends on the extension's vector type -- verify before reordering).
    statements = (
        "DROP TABLE IF EXISTS openaiembedding",
        "DROP EXTENSION IF EXISTS vector",
    )
    for statement in statements:
        await tx.execute(statement)
71 changes: 2 additions & 69 deletions src/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import retry
import database as db
import util
import perf

log = logger.get("OPENAI")

Expand All @@ -29,66 +28,9 @@ def register(client):
return {
'setprompt': cmd_setprompt,
'genimage': cmd_genimage,
'search': cmd_search,
}


async def main():
    """Run ``generate_embeddings`` forever, pausing between passes.

    Any exception raised by a pass is reported through
    ``util.log_exception`` and does not stop the loop.
    """
    pause_seconds = 120
    while True:
        try:
            await generate_embeddings()
        except Exception:
            # Top-level boundary of a background task: log and keep going.
            await util.log_exception(log)
        await asyncio.sleep(pause_seconds)


async def cmd_search(client, message, arg):
    """Reply to ``message`` with the closest stored messages for query ``arg``.

    Matches are looked up for the message's guild via ``search_embedding``
    and sent as a chain of replies: the first chunk replies to ``message``
    and every later chunk replies to the chunk before it.

    Args:
        client: The bot client; not used by this command.
        message: The triggering message; must expose ``guild.id`` and an
            awaitable ``reply``.
        arg: The free-text search query.
    """
    results = await search_embedding(message.guild.id, arg)
    if not results:
        # With no hits the joined text would be empty; answer explicitly
        # instead of pushing an empty string through the splitter and reply.
        await message.reply("No results found.")
        return

    # One flat list of lines: a header line per hit followed by that hit's
    # content, itself broken on newlines.
    lines = [
        line
        for row in results
        for line in f"Result with distance {row['distance']}:\n{row['content']}".split("\n")
    ]

    reply_target = message
    for chunk in util.split_message_for_sending(lines):
        reply_target = await reply_target.reply(chunk)


async def search_embedding(guild_id, query):
    """Fetch the stored messages closest to ``query`` within one guild.

    The query text is embedded through ``embeddings`` and the resulting
    vector is compared against ``openaiembedding`` rows with the ``<#>``
    operator.

    Args:
        guild_id: Guild identifier; passed to the database as a string.
        query: Text to search for.

    Returns:
        Whatever ``db.fetch`` returns for the query: at most five rows with
        ``message_id``, ``content`` and ``distance``, ordered by distance.
    """
    api_response = await embeddings([query])
    query_vector = api_response["data"][0]["embedding"]
    rows = await db.fetch("""
SELECT message_id, content, embedding <#> $2 AS distance
FROM message JOIN openaiembedding USING (message_id)
WHERE guild_id = $1 -- AND (embedding <#> $2) < -0.85
-- OpenAI embeddings are normalized to length 1 so this is best performance for exact search
-- https://platform.openai.com/docs/guides/embeddings/which-distance-function-should-i-use
ORDER BY embedding <#> $2
LIMIT 5
""", str(guild_id), query_vector)
    return rows


@perf.time_async("generate_embeddings")
async def generate_embeddings():
    """Embed one batch of not-yet-embedded messages and store the vectors.

    Selects up to 1000 non-bot messages with non-empty content that does not
    start with ``!`` and has no ``openaiembedding`` row, requests embeddings
    for their contents, and inserts one row per returned embedding. Does
    nothing when no such messages exist.
    """
    pending = await db.fetch("""
SELECT message_id, content
FROM message
WHERE NOT bot and content != '' AND content NOT LIKE '!%'
AND NOT EXISTS (SELECT 1 FROM openaiembedding WHERE openaiembedding.message_id = message.message_id)
LIMIT 1000
""")
    if not pending:
        return

    response = await embeddings([row["content"] for row in pending])

    async with db.transaction() as tx:
        # Each output's "index" identifies the input row it belongs to.
        insert_params = [
            (pending[item["index"]]["message_id"], item["embedding"])
            for item in response["data"]
        ]
        await tx.executemany(
            "INSERT INTO openaiembedding(message_id, embedding) VALUES ($1, $2)",
            insert_params
        )


async def handle_message(client, message):
bot_mentioned = any(user for user in message.mentions if user.id == client.user.id)
if not bot_mentioned:
Expand Down Expand Up @@ -194,22 +136,13 @@ async def get_response_for_messages(messages):
})


async def embeddings(input, model="text-embedding-ada-002"):
    """Call the ``/v1/embeddings`` endpoint and return its response body.

    Args:
        input: Value sent as the request's ``input`` field.
        model: Value sent as the request's ``model`` field.

    Returns:
        The second element of the pair returned by ``_call_api``.
    """
    payload = {
        "model": model,
        "input": input,
    }
    # The status element of the returned pair is not needed here.
    _status, body = await _call_api("/v1/embeddings", json_body=payload, skip_log=True)
    return body


@retry.on_any_exception(max_attempts = 1, init_delay = 1, max_delay = 30)
async def _call_api(path, json_body=None, query=None, skip_log=False):
async def _call_api(path, json_body=None, query=None):
url = "https://api.openai.com{0}{1}".format(path, http_util.make_query_string(query))
async with aiohttp.ClientSession() as session:
for ratelimit_delay in retry.jitter(retry.exponential(1, 128)):
response = await session.post(url, headers=AUTH_HEADER, json=json_body)
log_fn = log.debug if skip_log and response.status == 200 else log.info
log_fn({
log.info({
"requestMethod": response.method,
"requestUrl": str(response.url),
"responseStatus": response.status,
Expand Down
1 change: 0 additions & 1 deletion src/run_lemon_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ async def main():
# Database schema has to be initialized before running the bot
await db.initialize_schema()
asyncio.create_task(archiver.main())
asyncio.create_task(openai.main())
await trophies.main()

for module in [casino, sqlcommands, osu, feed, reminder, youtube, lan, steam, anssicommands, trophies, laiva,
Expand Down

0 comments on commit 7e4a858

Please sign in to comment.