diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb
index dd8247353..2d07b7bbe 100644
--- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb
+++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb
@@ -27,9 +27,9 @@ def index
           }
         end
       llms =
-        DiscourseAi::Configuration::LlmEnumerator.values.map do |hash|
-          { id: hash[:value], name: hash[:name] }
-        end
+        DiscourseAi::Configuration::LlmEnumerator
+          .values(allowed_seeded_llms: SiteSetting.ai_bot_allowed_seeded_models)
+          .map { |hash| { id: hash[:value], name: hash[:name] } }
 
       render json: { ai_personas: ai_personas, meta: { tools: tools, llms: llms } }
     end
diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb
index e0cea9f34..681ea7a8b 100644
--- a/app/models/ai_persona.rb
+++ b/app/models/ai_persona.rb
@@ -12,6 +12,7 @@ class AiPersona < ActiveRecord::Base
   validates :system_prompt, presence: true, length: { maximum: 10_000_000 }
   validate :system_persona_unchangeable, on: :update, if: :system
   validate :chat_preconditions
+  validate :allowed_seeded_model, if: :default_llm
   validates :max_context_posts, numericality: { greater_than: 0 }, allow_nil: true
   # leaves some room for growth but sets a maximum to avoid memory issues
   # we may want to revisit this in the future
@@ -275,6 +276,18 @@ def ensure_not_system
       throw :abort
     end
   end
+
+  def allowed_seeded_model
+    return if default_llm.blank?
+
+    llm = LlmModel.find_by(id: default_llm.split(":").last.to_i)
+    return if llm.nil?
+    return if !llm.seeded?
+
+    return if SiteSetting.ai_bot_allowed_seeded_models.include?(llm.id.to_s)
+
+    errors.add(:default_llm, I18n.t("discourse_ai.llm.configuration.invalid_seeded_model"))
+  end
 end
 
 # == Schema Information
diff --git a/config/settings.yml b/config/settings.yml
index 9a88429af..1d95e7a43 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -26,9 +26,8 @@ discourse_ai:
 
     default: 60
     hidden: true
-
   ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations"
-  ai_openai_embeddings_url: 
+  ai_openai_embeddings_url:
    hidden: true
     default: "https://api.openai.com/v1/embeddings"
   ai_openai_organization:
@@ -57,7 +56,7 @@ discourse_ai:
   ai_hugging_face_tei_endpoint_srv:
     default: ""
     hidden: true
-  ai_hugging_face_tei_api_key: 
+  ai_hugging_face_tei_api_key:
     default: ""
     hidden: true
   ai_hugging_face_tei_reranker_endpoint:
@@ -203,7 +202,7 @@ discourse_ai:
     client: true
     hidden: true
-  ai_embeddings_discourse_service_api_endpoint: 
+  ai_embeddings_discourse_service_api_endpoint:
     default: ""
     hidden: true
   ai_embeddings_discourse_service_api_endpoint_srv:
@@ -307,6 +306,11 @@ discourse_ai:
   ai_bot_github_access_token:
     default: ""
     secret: true
+  ai_bot_allowed_seeded_models:
+    default: ""
+    hidden: true
+    type: list
+    list_type: compact
   ai_automation_max_triage_per_minute:
     default: 60
     hidden: true
diff --git a/spec/models/ai_persona_spec.rb b/spec/models/ai_persona_spec.rb
index 4783dbfb7..da7caf8ce 100644
--- a/spec/models/ai_persona_spec.rb
+++ b/spec/models/ai_persona_spec.rb
@@ -172,6 +172,29 @@
     )
   end
 
+  it "validates allowed seeded model" do
+    persona =
+      AiPersona.new(
+        name: "test",
+        description: "test",
+        system_prompt: "test",
+        tools: [],
+        allowed_group_ids: [],
+        default_llm: "seeded_model:-1",
+      )
+
+    llm_model = Fabricate(:llm_model, id: -1)
+    SiteSetting.ai_bot_allowed_seeded_models = ""
+
+    expect(persona.valid?).to eq(false)
+    expect(persona.errors[:default_llm]).to include(
+      I18n.t("discourse_ai.llm.configuration.invalid_seeded_model"),
+    )
+
+    SiteSetting.ai_bot_allowed_seeded_models = "-1"
+    expect(persona.valid?).to eq(true)
+  end
+
   it "does not leak caches between sites" do
     AiPersona.create!(
       name: "pun_bot",