Seed embedding definition from old settings
romanrizzi committed Jan 14, 2025
1 parent 3a845f6 commit ec2f544
Showing 2 changed files with 200 additions and 1 deletion.
@@ -124,7 +124,7 @@ export default class AiEmbeddingEditor extends Component {
     await this.editingModel.save();

     if (isNew) {
-      this.args.embeddings.addObject(this.args.model);
+      this.args.embeddings.addObject(this.editingModel);
       this.router.transitionTo(
         "adminPlugins.show.discourse-ai-embeddings.index"
       );
db/migrate/20250110114305_embedding_config_data_migration.rb (199 additions, 0 deletions)
@@ -0,0 +1,199 @@
# frozen_string_literal: true

class EmbeddingConfigDataMigration < ActiveRecord::Migration[7.0]
  def up
    current_model = fetch_setting("ai_embeddings_model") || "bge-large-en"
    provider = provider_for(current_model)

    if provider.present?
      attrs = creds_for(provider)

      if attrs.present?
        attrs = attrs.merge(model_attrs(current_model))
        attrs[:display_name] = current_model
        attrs[:provider] = provider
        persist_config(attrs)
      end
    end
  end

  def down
  end

  # Utils

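  # Reads a raw value straight from the site_settings table, falling back to the
  # corresponding DISCOURSE_* environment variable.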
  def fetch_setting(name)
    DB.query_single(
      "SELECT value FROM site_settings WHERE name = :setting_name",
      setting_name: name,
    ).first || ENV["DISCOURSE_#{name&.upcase}"]
  end

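  # Picks the provider the legacy model was served by; bge-large-en is treated as
  # Cloudflare Workers AI when a Cloudflare token exists, otherwise as TEI.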
  def provider_for(model)
    cloudflare_api_token = fetch_setting("ai_cloudflare_workers_api_token")

    return "cloudflare" if model == "bge-large-en" && cloudflare_api_token.present?

    tei_models = %w[bge-large-en bge-m3 multilingual-e5-large]
    return "hugging_face" if tei_models.include?(model)

    return "google" if model == "gemini"

    if %w[text-embedding-3-large text-embedding-3-small text-embedding-ada-002].include?(model)
      return "open_ai"
    end

    nil
  end

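  # Builds the endpoint URL and API key for the detected provider, returning nil
  # when the legacy credential settings are incomplete.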
  def creds_for(provider)
    # CF
    if provider == "cloudflare"
      api_key = fetch_setting("ai_cloudflare_workers_api_token")
      account_id = fetch_setting("ai_cloudflare_workers_account_id")

      return if api_key.blank? || account_id.blank?

      {
        url:
          "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/baai/bge-large-en-v1.5",
        api_key: api_key,
      }
    # TEI
    elsif provider == "hugging_face"
      endpoint = fetch_setting("ai_hugging_face_tei_endpoint")

      if endpoint.blank?
        endpoint = fetch_setting("ai_hugging_face_tei_endpoint_srv")
        endpoint = "srv://#{endpoint}" if endpoint.present?
      end

      api_key = fetch_setting("ai_hugging_face_tei_api_key")

      return if endpoint.blank? || api_key.blank?

      { url: endpoint, api_key: api_key }
    # Gemini
    elsif provider == "google"
      api_key = fetch_setting("ai_gemini_api_key")

      return if api_key.blank?

      {
        url: "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent",
        api_key: api_key,
      }

    # Open AI
    elsif provider == "open_ai"
      endpoint = fetch_setting("ai_openai_embeddings_url")
      api_key = fetch_setting("ai_openai_api_key")

      return if endpoint.blank? || api_key.blank?

      { url: endpoint, api_key: api_key }
    else
      nil
    end
  end

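  # Hardcoded attributes for each legacy model: dimensions, max sequence length,
  # pgvector distance operator, tokenizer, and the historical ID its embeddings
  # were generated under.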
  def model_attrs(model_name)
    if model_name == "bge-large-en"
      {
        dimensions: 1024,
        max_sequence_length: 512,
        id: 4,
        pg_function: "<#>",
        tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
      }
    elsif model_name == "bge-m3"
      {
        dimensions: 1024,
        max_sequence_length: 8192,
        id: 8,
        pg_function: "<#>",
        tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
      }
    elsif model_name == "gemini"
      {
        dimensions: 768,
        max_sequence_length: 1536,
        id: 5,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
      }
    elsif model_name == "multilingual-e5-large"
      {
        dimensions: 1024,
        max_sequence_length: 512,
        id: 3,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer",
      }
    elsif model_name == "text-embedding-3-large"
      {
        dimensions: 2000,
        max_sequence_length: 8191,
        id: 7,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
        provider_params: {
          model_name: "text-embedding-3-large",
        },
      }
    elsif model_name == "text-embedding-3-small"
      {
        dimensions: 1536,
        max_sequence_length: 8191,
        id: 6,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
        provider_params: {
          model_name: "text-embedding-3-small",
        },
      }
    else
      {
        dimensions: 1536,
        max_sequence_length: 8191,
        id: 2,
        pg_function: "<=>",
        tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
        provider_params: {
          model_name: "text-embedding-ada-002",
        },
      }
    end
  end

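  # Inserts the seeded definition with its fixed ID and points the
  # ai_embeddings_selected_model setting at it.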
  def persist_config(attrs)
    DB.exec(
      <<~SQL,
        INSERT INTO embedding_definitions (id, display_name, dimensions, max_sequence_length, version, pg_function, provider, tokenizer_class, url, api_key, provider_params, created_at, updated_at)
        VALUES (:id, :display_name, :dimensions, :max_sequence_length, 1, :pg_function, :provider, :tokenizer_class, :url, :api_key, :provider_params, :now, :now)
      SQL
      id: attrs[:id],
      display_name: attrs[:display_name],
      dimensions: attrs[:dimensions],
      max_sequence_length: attrs[:max_sequence_length],
      pg_function: attrs[:pg_function],
      provider: attrs[:provider],
      tokenizer_class: attrs[:tokenizer_class],
      url: attrs[:url],
      api_key: attrs[:api_key],
      provider_params: attrs[:provider_params],
      now: Time.zone.now,
    )

    # The ID is hardcoded to match already-generated embeddings. Restart the
    # sequence so future inserts don't conflict with it.
    DB.exec(
      "ALTER SEQUENCE embedding_definitions_id_seq RESTART WITH :new_seq",
      new_seq: attrs[:id].to_i + 1,
    )

    DB.exec(
      "UPDATE site_settings SET value=:id WHERE name = 'ai_embeddings_selected_model'",
      id: attrs[:id],
    )
  end
end
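
A quick way to sanity-check the seeded row after this migration runs — a minimal sketch, not part of the commit, assuming a Rails console on the same Discourse install; DB.query here is the mini_sql sibling of the DB.query_single / DB.exec calls used above:

# Hypothetical verification snippet, not shipped with the migration.
seeded = DB.query(
  "SELECT id, display_name, provider, dimensions, pg_function FROM embedding_definitions",
).first

selected = DB.query_single(
  "SELECT value FROM site_settings WHERE name = 'ai_embeddings_selected_model'",
).first

if seeded && selected == seeded.id.to_s
  puts "Seeded definition ##{seeded.id} (#{seeded.display_name}) via #{seeded.provider}"
else
  puts "Nothing seeded — no legacy embedding settings were found"
end

When no legacy ai_embeddings_model credentials exist, the migration is a no-op and the embedding_definitions query returns no rows.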
