Skip to content

Commit

Permalink
FIX: Restore the accidentally deleted query prefix. (#1079)
Browse files Browse the repository at this point in the history
Additionally, we add a prefix for embedding generation.
Both are stored in the definitions table.
  • Loading branch information
romanrizzi authored Jan 21, 2025
1 parent f5cf101 commit 3b66fb3
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def ai_embeddings_params
:url,
:api_key,
:tokenizer_class,
:embed_prompt,
:search_prompt,
)

extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym)
Expand Down
3 changes: 3 additions & 0 deletions app/models/embedding_definition.rb
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def presets
pg_function: "<#>",
tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
provider: HUGGING_FACE,
search_prompt: "Represent this sentence for searching relevant passages:",
},
{
preset_id: "bge-m3",
Expand Down Expand Up @@ -228,4 +229,6 @@ def gemini_client
# provider_params :jsonb
# created_at :datetime not null
# updated_at :datetime not null
# embed_prompt :string default(""), not null
# search_prompt :string default(""), not null
#
2 changes: 2 additions & 0 deletions app/serializers/ai_embedding_definition_serializer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class AiEmbeddingDefinitionSerializer < ApplicationSerializer
:api_key,
:seeded,
:tokenizer_class,
:embed_prompt,
:search_prompt,
:provider_params

def api_key
Expand Down
4 changes: 3 additions & 1 deletion assets/javascripts/discourse/admin/models/ai-embedding.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ export default class AiEmbedding extends RestModel {
"api_key",
"max_sequence_length",
"provider_params",
"pg_function"
"pg_function",
"embed_prompt",
"search_prompt"
);
}

Expand Down
18 changes: 18 additions & 0 deletions assets/javascripts/discourse/components/ai-embedding-editor.gjs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,24 @@ export default class AiEmbeddingEditor extends Component {
{{/if}}
</div>

<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.embed_prompt"}}</label>
<Input
@type="text"
class="ai-embedding-editor-input ai-embedding-editor__embed_prompt"
@value={{this.editingModel.embed_prompt}}
/>
</div>

<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.search_prompt"}}</label>
<Input
@type="text"
class="ai-embedding-editor-input ai-embedding-editor__search_prompt"
@value={{this.editingModel.search_prompt}}
/>
</div>

<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.max_sequence_length"}}</label>
<Input
Expand Down
4 changes: 3 additions & 1 deletion config/locales/client.en.yml
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,9 @@ en:
tokenizer: "Tokenizer"
dimensions: "Embedding dimensions"
max_sequence_length: "Sequence length"

embed_prompt: "Embed prompt"
search_prompt: "Search prompt"

distance_function: "Distance function"
distance_functions:
<#>: "Negative inner product (<#>)"
Expand Down
18 changes: 18 additions & 0 deletions db/migrate/20250121162520_configurable_embeddings_prefixes.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# frozen_string_literal: true

# Adds configurable prompt prefixes to embedding definitions:
# - embed_prompt: prepended to text before generating embeddings
# - search_prompt: prepended to queries for asymmetric search
class ConfigurableEmbeddingsPrefixes < ActiveRecord::Migration[7.2]
  def up
    # Both columns default to the empty string so existing rows need no prefix.
    %i[embed_prompt search_prompt].each do |column|
      add_column :embedding_definitions, column, :string, null: false, default: ""
    end

    # Backfill: id 4 is bge-large-en — the default model and the only one
    # using a search prefix so far.
    execute <<~SQL
      UPDATE embedding_definitions
      SET search_prompt='Represent this sentence for searching relevant passages:'
      WHERE id = 4
    SQL
  end

  def down
    # Dropping the columns would lose operator-entered prompts; treat as one-way.
    raise ActiveRecord::IrreversibleMigration
  end
end
29 changes: 17 additions & 12 deletions lib/embeddings/strategies/truncation.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,28 @@ def version
def prepare_target_text(target, vdef)
max_length = vdef.max_sequence_length - 2

case target
when Topic
topic_truncation(target, vdef.tokenizer, max_length)
when Post
post_truncation(target, vdef.tokenizer, max_length)
when RagDocumentFragment
vdef.tokenizer.truncate(target.fragment, max_length)
else
raise ArgumentError, "Invalid target type"
end
prepared_text =
case target
when Topic
topic_truncation(target, vdef.tokenizer, max_length)
when Post
post_truncation(target, vdef.tokenizer, max_length)
when RagDocumentFragment
vdef.tokenizer.truncate(target.fragment, max_length)
else
raise ArgumentError, "Invalid target type"
end

return prepared_text if vdef.embed_prompt.blank?

[vdef.embed_prompt, prepared_text].join(" ")
end

def prepare_query_text(text, vdef, asymetric: false)
qtext = asymetric ? "#{vdef.asymmetric_query_prefix} #{text}" : text
qtext = asymetric ? "#{vdef.search_prompt} #{text}" : text
max_length = vdef.max_sequence_length - 2

vdef.tokenizer.truncate(text, max_length)
vdef.tokenizer.truncate(qtext, max_length)
end

private
Expand Down
62 changes: 42 additions & 20 deletions spec/lib/modules/embeddings/strategies/truncation_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,51 @@
RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
subject(:truncation) { described_class.new }

describe "#prepare_query_text" do
context "when using vector def from OpenAI" do
before { SiteSetting.max_post_length = 100_000 }
fab!(:open_ai_embedding_def)
let(:prefix) { "I come first:" }

fab!(:topic)
fab!(:post) do
Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500)
end
fab!(:post) do
Fabricate(
:post,
topic: topic,
raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400,
)
end
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
fab!(:open_ai_embedding_def)
describe "#prepare_target_text" do
before { SiteSetting.max_post_length = 100_000 }

fab!(:topic)
fab!(:post) do
Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500)
end
fab!(:post) do
Fabricate(
:post,
topic: topic,
raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400,
)
end
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
fab!(:open_ai_embedding_def)

it "truncates a topic" do
prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)

expect(open_ai_embedding_def.tokenizer.size(prepared_text)).to be <=
open_ai_embedding_def.max_sequence_length
end

it "includes embed prefix" do
open_ai_embedding_def.update!(embed_prompt: prefix)

prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)

expect(prepared_text.starts_with?(prefix)).to eq(true)
end
end

describe "#prepare_query_text" do
context "when search is asymetric" do
it "includes search prefix" do
open_ai_embedding_def.update!(search_prompt: prefix)

it "truncates a topic" do
prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)
prepared_query_text =
truncation.prepare_query_text("searching", open_ai_embedding_def, asymetric: true)

expect(open_ai_embedding_def.tokenizer.size(prepared_text)).to be <=
open_ai_embedding_def.max_sequence_length
expect(prepared_query_text.starts_with?(prefix)).to eq(true)
end
end
end
Expand Down
4 changes: 4 additions & 0 deletions spec/requests/admin/ai_embeddings_controller_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
url: "https://test.com/api/v1/embeddings",
api_key: "test",
tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
embed_prompt: "I come first:",
search_prompt: "prefix for search",
}
end

Expand All @@ -27,6 +29,8 @@

expect(response.status).to eq(201)
expect(created_def.display_name).to eq(valid_attrs[:display_name])
expect(created_def.embed_prompt).to eq(valid_attrs[:embed_prompt])
expect(created_def.search_prompt).to eq(valid_attrs[:search_prompt])
end

it "stores provider-specific config params" do
Expand Down
7 changes: 7 additions & 0 deletions spec/system/embeddings/ai_embedding_definition_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@
select_kit.expand
select_kit.select_row_by_value("DiscourseAi::Tokenizer::OpenAiTokenizer")

embed_prefix = "On creation:"
search_prefix = "On search:"
find("input.ai-embedding-editor__embed_prompt").fill_in(with: embed_prefix)
find("input.ai-embedding-editor__search_prompt").fill_in(with: search_prefix)

find("input.ai-embedding-editor__dimensions").fill_in(with: 1536)
find("input.ai-embedding-editor__max_sequence_length").fill_in(with: 8191)

Expand All @@ -83,5 +88,7 @@
expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length])
expect(embedding_def.pg_function).to eq(preset[:pg_function])
expect(embedding_def.provider).to eq(preset[:provider])
expect(embedding_def.embed_prompt).to eq(embed_prefix)
expect(embedding_def.search_prompt).to eq(search_prefix)
end
end

0 comments on commit 3b66fb3

Please sign in to comment.