Skip to content

Commit

Permalink
FEATURE: Formalize support for matryoshka dimensions. (#1083)
Browse files Browse the repository at this point in the history
We have a flag to signal we are shortening the embeddings of a model.
Only used in Open AI's text-embedding-3-*, but we plan to use it for other services.
  • Loading branch information
romanrizzi authored Jan 22, 2025
1 parent 654f90f commit e2e753d
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def ai_embeddings_params
:tokenizer_class,
:embed_prompt,
:search_prompt,
:matryoshka_dimensions,
)

extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym)
Expand Down
39 changes: 20 additions & 19 deletions app/models/embedding_definition.rb
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def presets
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
url: "https://api.openai.com/v1/embeddings",
provider: OPEN_AI,
matryoshka_dimensions: true,
provider_params: {
model_name: "text-embedding-3-large",
},
Expand All @@ -97,6 +98,7 @@ def presets
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
url: "https://api.openai.com/v1/embeddings",
provider: OPEN_AI,
matryoshka_dimensions: true,
provider_params: {
model_name: "text-embedding-3-small",
},
Expand Down Expand Up @@ -200,9 +202,7 @@ def hugging_face_client
end

def open_ai_client
model_name = lookup_custom_param("model_name")
can_shorten_dimensions = %w[text-embedding-3-small text-embedding-3-large].include?(model_name)
client_dimensions = can_shorten_dimensions ? dimensions : nil
client_dimensions = matryoshka_dimensions ? dimensions : nil

DiscourseAi::Inference::OpenAiEmbeddings.new(
endpoint_url,
Expand All @@ -221,20 +221,21 @@ def gemini_client
#
# Table name: embedding_definitions
#
# id :bigint not null, primary key
# display_name :string not null
# dimensions :integer not null
# max_sequence_length :integer not null
# version :integer default(1), not null
# pg_function :string not null
# provider :string not null
# tokenizer_class :string not null
# url :string not null
# api_key :string
# seeded :boolean default(FALSE), not null
# provider_params :jsonb
# created_at :datetime not null
# updated_at :datetime not null
# embed_prompt :string default(""), not null
# search_prompt :string default(""), not null
# id :bigint not null, primary key
# display_name :string not null
# dimensions :integer not null
# max_sequence_length :integer not null
# version :integer default(1), not null
# pg_function :string not null
# provider :string not null
# tokenizer_class :string not null
# url :string not null
# api_key :string
# seeded :boolean default(FALSE), not null
# provider_params :jsonb
# created_at :datetime not null
# updated_at :datetime not null
# embed_prompt :string default(""), not null
# search_prompt :string default(""), not null
# matryoshka_dimensions :boolean default(FALSE), not null
#
1 change: 1 addition & 0 deletions app/serializers/ai_embedding_definition_serializer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class AiEmbeddingDefinitionSerializer < ApplicationSerializer
:tokenizer_class,
:embed_prompt,
:search_prompt,
:matryoshka_dimensions,
:provider_params

def api_key
Expand Down
3 changes: 2 additions & 1 deletion assets/javascripts/discourse/admin/models/ai-embedding.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ export default class AiEmbedding extends RestModel {
"provider_params",
"pg_function",
"embed_prompt",
"search_prompt"
"search_prompt",
"matryoshka_dimensions"
);
}

Expand Down
10 changes: 10 additions & 0 deletions assets/javascripts/discourse/components/ai-embedding-editor.gjs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,16 @@ export default class AiEmbeddingEditor extends Component {
{{/if}}
</div>

<div class="control-group ai-embedding-editor__matryoshka_dimensions">
<Input
@type="checkbox"
@checked={{this.editingModel.matryoshka_dimensions}}
/>
<label>{{i18n
"discourse_ai.embeddings.matryoshka_dimensions"
}}</label>
</div>

<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.embed_prompt"}}</label>
<Input
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,9 @@
display: flex;
align-items: center;
}

&__matryoshka_dimensions {
display: flex;
align-items: flex-start;
}
}
1 change: 1 addition & 0 deletions config/locales/client.en.yml
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ en:
max_sequence_length: "Sequence length"
embed_prompt: "Embed prompt"
search_prompt: "Search prompt"
matryoshka_dimensions: "Matryoshka dimensions"

distance_function: "Distance function"
distance_functions:
Expand Down
22 changes: 22 additions & 0 deletions db/migrate/20250122131007_matryoshka_dimensions_support.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# frozen_string_literal: true
class MatryoshkaDimensionsSupport < ActiveRecord::Migration[7.2]
def change
add_column :embedding_definitions, :matryoshka_dimensions, :boolean, null: false, default: false

execute <<~SQL
UPDATE embedding_definitions
SET matryoshka_dimensions = TRUE
WHERE
provider = 'open_ai' AND
provider_params IS NOT NULL AND
(
(provider_params->>'model_name') = 'text-embedding-3-large' OR
(provider_params->>'model_name') = 'text-embedding-3-small'
)
SQL
end

def down
raise ActiveRecord::IrreversibleMigration
end
end
9 changes: 2 additions & 7 deletions spec/lib/modules/embeddings/vector_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,10 @@ def stub_vector_mapping(text, expected_embedding)

it_behaves_like "generates and store embeddings using a vector definition"

context "when working with models that support shortening embeddings" do
context "when matryoshka_dimensions is enabled" do
it "passes the dimensions param" do
shorter_dimensions = 10
vdef.update!(
dimensions: shorter_dimensions,
provider_params: {
model_name: "text-embedding-3-small",
},
)
vdef.update!(dimensions: shorter_dimensions, matryoshka_dimensions: true)
text = "This is a piece of text"
short_expected_embedding = [0.0038493] * shorter_dimensions

Expand Down
2 changes: 2 additions & 0 deletions spec/requests/admin/ai_embeddings_controller_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
embed_prompt: "I come first:",
search_prompt: "prefix for search",
matryoshka_dimensions: true,
}
end

Expand All @@ -31,6 +32,7 @@
expect(created_def.display_name).to eq(valid_attrs[:display_name])
expect(created_def.embed_prompt).to eq(valid_attrs[:embed_prompt])
expect(created_def.search_prompt).to eq(valid_attrs[:search_prompt])
expect(created_def.matryoshka_dimensions).to eq(true)
end

it "stores provider-specific config params" do
Expand Down

0 comments on commit e2e753d

Please sign in to comment.