From 769e985e16d316bf25b0d642658ed82d0f902060 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Wed, 22 Jan 2025 11:15:05 -0300 Subject: [PATCH] FEATURE: Formalize support for matryoshka dimensions. We have a flag to signal we are shortening the embeddings of a model. Only used in Open AI's text-embedding-3-*, but we plan to use it for other services. --- .../admin/ai_embeddings_controller.rb | 1 + app/models/embedding_definition.rb | 39 ++++++++++--------- .../ai_embedding_definition_serializer.rb | 1 + .../discourse/admin/models/ai-embedding.js | 3 +- .../components/ai-embedding-editor.gjs | 10 +++++ .../common/ai-embedding-editor.scss | 5 +++ config/locales/client.en.yml | 1 + ...122131007_matryoshka_dimensions_support.rb | 22 +++++++++++ spec/lib/modules/embeddings/vector_spec.rb | 9 +---- .../admin/ai_embeddings_controller_spec.rb | 2 + 10 files changed, 66 insertions(+), 27 deletions(-) create mode 100644 db/migrate/20250122131007_matryoshka_dimensions_support.rb diff --git a/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb b/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb index c3f21caf7..66a9c7f3c 100644 --- a/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb @@ -113,6 +113,7 @@ def ai_embeddings_params :tokenizer_class, :embed_prompt, :search_prompt, + :matryoshka_dimensions, ) extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym) diff --git a/app/models/embedding_definition.rb b/app/models/embedding_definition.rb index 0c7cbda4c..ba6dd362c 100644 --- a/app/models/embedding_definition.rb +++ b/app/models/embedding_definition.rb @@ -84,6 +84,7 @@ def presets tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer", url: "https://api.openai.com/v1/embeddings", provider: OPEN_AI, + matryoshka_dimensions: true, provider_params: { model_name: "text-embedding-3-large", }, @@ -97,6 +98,7 @@ def presets tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer", url: "https://api.openai.com/v1/embeddings", provider: OPEN_AI, + matryoshka_dimensions: true, provider_params: { model_name: "text-embedding-3-small", }, @@ -200,9 +202,7 @@ def hugging_face_client end def open_ai_client - model_name = lookup_custom_param("model_name") - can_shorten_dimensions = %w[text-embedding-3-small text-embedding-3-large].include?(model_name) - client_dimensions = can_shorten_dimensions ? dimensions : nil + client_dimensions = matryoshka_dimensions ? dimensions : nil DiscourseAi::Inference::OpenAiEmbeddings.new( endpoint_url, @@ -221,20 +221,21 @@ def gemini_client # # Table name: embedding_definitions # -# id :bigint not null, primary key -# display_name :string not null -# dimensions :integer not null -# max_sequence_length :integer not null -# version :integer default(1), not null -# pg_function :string not null -# provider :string not null -# tokenizer_class :string not null -# url :string not null -# api_key :string -# seeded :boolean default(FALSE), not null -# provider_params :jsonb -# created_at :datetime not null -# updated_at :datetime not null -# embed_prompt :string default(""), not null -# search_prompt :string default(""), not null +# id :bigint not null, primary key +# display_name :string not null +# dimensions :integer not null +# max_sequence_length :integer not null +# version :integer default(1), not null +# pg_function :string not null +# provider :string not null +# tokenizer_class :string not null +# url :string not null +# api_key :string +# seeded :boolean default(FALSE), not null +# provider_params :jsonb +# created_at :datetime not null +# updated_at :datetime not null +# embed_prompt :string default(""), not null +# search_prompt :string default(""), not null +# matryoshka_dimensions :boolean default(FALSE), not null # diff --git a/app/serializers/ai_embedding_definition_serializer.rb b/app/serializers/ai_embedding_definition_serializer.rb index a15adcf2a..8c5b17b3b 100644 --- a/app/serializers/ai_embedding_definition_serializer.rb +++ b/app/serializers/ai_embedding_definition_serializer.rb @@ -15,6 +15,7 @@ class AiEmbeddingDefinitionSerializer < ApplicationSerializer :tokenizer_class, :embed_prompt, :search_prompt, + :matryoshka_dimensions, :provider_params def api_key diff --git a/assets/javascripts/discourse/admin/models/ai-embedding.js b/assets/javascripts/discourse/admin/models/ai-embedding.js index ea312f250..d3d620fe3 100644 --- a/assets/javascripts/discourse/admin/models/ai-embedding.js +++ b/assets/javascripts/discourse/admin/models/ai-embedding.js @@ -16,7 +16,8 @@ export default class AiEmbedding extends RestModel { "provider_params", "pg_function", "embed_prompt", - "search_prompt" + "search_prompt", + "matryoshka_dimensions" ); } diff --git a/assets/javascripts/discourse/components/ai-embedding-editor.gjs b/assets/javascripts/discourse/components/ai-embedding-editor.gjs index 22a6a4dc7..2ba63c92c 100644 --- a/assets/javascripts/discourse/components/ai-embedding-editor.gjs +++ b/assets/javascripts/discourse/components/ai-embedding-editor.gjs @@ -290,6 +290,16 @@ export default class AiEmbeddingEditor extends Component { {{/if}} +
+ + +
+
>'model_name') = 'text-embedding-3-large' OR + (provider_params->>'model_name') = 'text-embedding-3-small' + ) + SQL + end + + def down + raise ActiveRecord::IrreversibleMigration + end +end diff --git a/spec/lib/modules/embeddings/vector_spec.rb b/spec/lib/modules/embeddings/vector_spec.rb index b69073c38..d2bc4bbc2 100644 --- a/spec/lib/modules/embeddings/vector_spec.rb +++ b/spec/lib/modules/embeddings/vector_spec.rb @@ -99,15 +99,10 @@ def stub_vector_mapping(text, expected_embedding) it_behaves_like "generates and store embeddings using a vector definition" - context "when working with models that support shortening embeddings" do + context "when matryoshka_dimensions is enabled" do it "passes the dimensions param" do shorter_dimensions = 10 - vdef.update!( - dimensions: shorter_dimensions, - provider_params: { - model_name: "text-embedding-3-small", - }, - ) + vdef.update!(dimensions: shorter_dimensions, matryoshka_dimensions: true) text = "This is a piece of text" short_expected_embedding = [0.0038493] * shorter_dimensions diff --git a/spec/requests/admin/ai_embeddings_controller_spec.rb b/spec/requests/admin/ai_embeddings_controller_spec.rb index ac3f0c232..40106b5d9 100644 --- a/spec/requests/admin/ai_embeddings_controller_spec.rb +++ b/spec/requests/admin/ai_embeddings_controller_spec.rb @@ -17,6 +17,7 @@ tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer", embed_prompt: "I come first:", search_prompt: "prefix for search", + matryoshka_dimensions: true, } end @@ -31,6 +32,7 @@ expect(created_def.display_name).to eq(valid_attrs[:display_name]) expect(created_def.embed_prompt).to eq(valid_attrs[:embed_prompt]) expect(created_def.search_prompt).to eq(valid_attrs[:search_prompt]) + expect(created_def.matryoshka_dimensions).to eq(true) end it "stores provider-specific config params" do