DEV: Build sentiment clients outside of promises (#1117)
romanrizzi authored Feb 6, 2025
1 parent e52045e commit 90bcb8b
Showing 2 changed files with 39 additions and 38 deletions.
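At a glance, the call site in the sentiment module changes shape as sketched below; `content` and `classifier` are stand-ins for the real locals, not names from this commit.

# Before: a class-level helper resolved the srv:// endpoint on every call,
# inside the promise that ran the classification.
DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, classifier, Discourse.base_url)

# After: the endpoint is resolved and the client is built once, up front;
# the promise only calls the prebuilt client.
classifier[:client].classify_by_sentiment!(content)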
38 changes: 12 additions & 26 deletions lib/inference/hugging_face_text_embeddings.rb
@@ -12,11 +12,6 @@ def initialize(endpoint, key, referer = Discourse.base_url)
       attr_reader :endpoint, :key, :referer
 
       class << self
-        def configured?
-          SiteSetting.ai_hugging_face_tei_endpoint.present? ||
-            SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
-        end
-
         def reranker_configured?
           SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
             SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
@@ -50,32 +45,23 @@ def rerank(content, candidates)
 
           JSON.parse(response.body, symbolize_names: true)
         end
+      end
 
-        def classify(content, model_config, base_url = Discourse.base_url)
-          headers = { "Referer" => base_url, "Content-Type" => "application/json" }
-          headers["X-API-KEY"] = model_config.api_key
-          headers["Authorization"] = "Bearer #{model_config.api_key}"
-
-          body = { inputs: content, truncate: true }.to_json
-
-          api_endpoint = model_config.endpoint
-          if api_endpoint.present? && api_endpoint.start_with?("srv://")
-            service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://"))
-            api_endpoint = "https://#{service.target}:#{service.port}"
-          end
+      def classify_by_sentiment!(content)
+        response = do_request!(content)
 
-          conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
-          response = conn.post(api_endpoint, body, headers)
+        JSON.parse(response.body, symbolize_names: true)
+      end
 
-          if response.status != 200
-            raise Net::HTTPBadResponse.new("Status: #{response.status}\n\n#{response.body}")
-          end
+      def perform!(content)
+        response = do_request!(content)
 
-          JSON.parse(response.body, symbolize_names: true)
-        end
+        JSON.parse(response.body, symbolize_names: true).first
       end
 
-      def perform!(content)
+      private
+
+      def do_request!(content)
         headers = { "Referer" => referer, "Content-Type" => "application/json" }
         body = { inputs: content, truncate: true }.to_json
 
@@ -89,7 +75,7 @@ def perform!(content)
 
         raise Net::HTTPBadResponse.new(response.body.to_s) if ![200].include?(response.status)
 
-        JSON.parse(response.body, symbolize_names: true).first
+        response
       end
     end
   end
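The upshot for this file: the class-level `classify` helper (and its per-call DNS SRV lookup) is gone, and callers work with an instance instead. A minimal usage sketch, assuming a reachable TEI endpoint and API key (both placeholders, not values from this commit):

# Build the client once and reuse it.
client =
  DiscourseAi::Inference::HuggingFaceTextEmbeddings.new(
    "https://tei.example.internal", # placeholder endpoint
    "example-api-key",              # placeholder key
  )

# Raw sentiment payload, replacing the old class-level `classify`.
sentiment = client.classify_by_sentiment!("I really enjoy this forum")

# `perform!` still returns the first element of the parsed response,
# now going through the shared `do_request!` helper.
embedding = client.perform!("Some text to embed")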
39 changes: 27 additions & 12 deletions lib/sentiment/post_classification.rb
@@ -55,7 +55,6 @@ def bulk_classify!(relation)
 
       available_classifiers = classifiers
       return if available_classifiers.blank?
-      base_url = Discourse.base_url
 
       promised_classifications =
         relation
@@ -70,12 +69,14 @@ def bulk_classify!(relation)
             already_classified = w_text[:target].sentiment_classifications.map(&:model_used)
 
             classifiers_for_target =
-              available_classifiers.reject { |ac| already_classified.include?(ac.model_name) }
+              available_classifiers.reject do |ac|
+                already_classified.include?(ac[:model_name])
+              end
 
             promised_target_results =
-              classifiers_for_target.map do |c|
+              classifiers_for_target.map do |cft|
                 Concurrent::Promises.future_on(pool) do
-                  results[c.model_name] = request_with(w_text[:text], c, base_url)
+                  results[cft[:model_name]] = request_with(cft[:client], w_text[:text])
                 end
               end
 
@@ -98,26 +99,40 @@ def bulk_classify!(relation)
 
     def classify!(target)
       return if target.blank?
-      return if classifiers.blank?
+      available_classifiers = classifiers
+      return if available_classifiers.blank?
 
       to_classify = prepare_text(target)
       return if to_classify.blank?
 
       already_classified = target.sentiment_classifications.map(&:model_used)
       classifiers_for_target =
-        classifiers.reject { |ac| already_classified.include?(ac.model_name) }
+        available_classifiers.reject { |ac| already_classified.include?(ac[:model_name]) }
 
       results =
-        classifiers_for_target.reduce({}) do |memo, model|
-          memo[model.model_name] = request_with(to_classify, model)
+        classifiers_for_target.reduce({}) do |memo, cft|
+          memo[cft[:model_name]] = request_with(cft[:client], to_classify)
           memo
         end
 
       store_classification(target, results)
     end
 
     def classifiers
-      DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values
+      DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.map do |config|
+        api_endpoint = config.endpoint
+
+        if api_endpoint.present? && api_endpoint.start_with?("srv://")
+          service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://"))
+          api_endpoint = "https://#{service.target}:#{service.port}"
+        end
+
+        {
+          model_name: config.model_name,
+          client:
+            DiscourseAi::Inference::HuggingFaceTextEmbeddings.new(api_endpoint, config.api_key),
+        }
+      end
     end
 
     def has_classifiers?
@@ -137,9 +152,9 @@ def prepare_text(target)
       Tokenizer::BertTokenizer.truncate(content, 512)
     end
 
-    def request_with(content, config, base_url = Discourse.base_url)
-      result =
-        DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url)
+    def request_with(client, content)
+      result = client.classify_by_sentiment!(content)
+
       transform_result(result)
     end
 
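The sentiment side now resolves any srv:// endpoint and builds the HuggingFaceTextEmbeddings client once in `classifiers`, before any future is scheduled, so the work inside the thread pool is only the HTTP call. An illustrative sketch of that pattern with concurrent-ruby (not code from this commit; `resolved_endpoint`, `api_key`, and `texts` are placeholders):

require "concurrent"

# Build (and DNS-resolve) the client once, outside any promise.
client = DiscourseAi::Inference::HuggingFaceTextEmbeddings.new(resolved_endpoint, api_key)

pool = Concurrent::FixedThreadPool.new(4)

# Each future only performs the HTTP request; no client construction happens here.
futures =
  texts.map do |text|
    Concurrent::Promises.future_on(pool) { client.classify_by_sentiment!(text) }
  end

results = Concurrent::Promises.zip(*futures).value!

pool.shutdown
pool.wait_for_termination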
