From f9d2c2d828ba9db9837c6d9b1ac8951f4bf9bee4 Mon Sep 17 00:00:00 2001 From: Songkan Tang Date: Tue, 6 Aug 2024 11:34:29 +0800 Subject: [PATCH] Fix cherry-pick conflicts and compilation issue Signed-off-by: Songkan Tang --- .../org/opensearch/agent/tools/RCATool.java | 116 ++++++++---------- 1 file changed, 51 insertions(+), 65 deletions(-) diff --git a/src/main/java/org/opensearch/agent/tools/RCATool.java b/src/main/java/org/opensearch/agent/tools/RCATool.java index 57068790..3dafb134 100644 --- a/src/main/java/org/opensearch/agent/tools/RCATool.java +++ b/src/main/java/org/opensearch/agent/tools/RCATool.java @@ -8,7 +8,6 @@ import static org.apache.commons.text.StringEscapeUtils.unescapeJson; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -79,21 +78,6 @@ public RCATool(Client client, String modelId, String embeddingModelId, Boolean i this.isLLMOption = isLLMOption; } - @Override - public String getType() { - return TYPE; - } - - @Override - public String getVersion() { - return null; - } - - @Override - public boolean validate(Map parameters) { - return parameters != null; - } - public static final String TOOL_PROMPT = "You are going to help find the root cause of the phenomenon from the several potential causes listed below. In this RCA process, for each cause, it usually needs to call an API to get some necessary information verify whether it's the right root cause. I've filled the related response for each cause, you should decide which cause are most possible to be the root cause based on these responses. \n\n" + "Human: PHENOMENON\n" @@ -154,44 +138,57 @@ public void runOption1(Map parameters, ActionListener lis } @SuppressWarnings("unchecked") - public void runOption2(Map knowledgeBase, ActionListener listener) { - String phenomenon = (String) knowledgeBase.get("phenomenon"); - - // API response embedded vectors + public void runOption2(Map parameters, ActionListener listener) { + String knowledge = parameters.get(KNOWLEDGE_BASE_TOOL_OUTPUT_FIELD); + knowledge = unescapeJson(knowledge); + Map knowledgeBase = StringUtils.gson.fromJson(knowledge, Map.class); List> causes = (List>) knowledgeBase.get("causes"); - List responses = causes.stream() - .map(cause -> cause.get("response")) - .collect(Collectors.toList()); - List responseVectors = getEmbeddedVector(responses); - - // expected API response embedded vectors - List expectedResponses = causes.stream() - .map(cause -> cause.get("expected_response")) - .collect(Collectors.toList()); - List expectedResponseVectors = getEmbeddedVector(expectedResponses); - - Map dotProductMap = IntStream.range(0, causes.size()) - .boxed() - .collect(Collectors.toMap( - i -> causes.get(i).get("reason"), - i -> responseVectors.get(i).dotProduct(expectedResponseVectors.get(i)) - )); - - Optional> mapEntry = - dotProductMap.entrySet().stream() - .max(Map.Entry.comparingByValue()); - - String rootCauseReason = "No root cause found"; - if (mapEntry.isPresent()) { - Entry entry = mapEntry.get(); - log.info("kNN RCA reason: {} with score: {} for the phenomenon: {}", - entry.getKey(), entry.getValue(), phenomenon); - rootCauseReason = entry.getKey(); - } else { - log.warn("No root cause found for the phenomenon: {}", phenomenon); - } + List apiList = causes.stream().map(cause -> cause.get(API_URL_FIELD)).distinct().collect(Collectors.toList()); + final GroupedActionListener> groupedListener = new GroupedActionListener<>(ActionListener.wrap(responses -> { + Map apiToResponse = responses.stream().collect(Collectors.toMap(Pair::getKey, Pair::getValue)); + String phenomenon = (String) knowledgeBase.get("phenomenon"); + + // API response embedded vectors + Map responseVectorMap = apiToResponse.entrySet().stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + entry -> { + List responseVector = getEmbeddedVector(List.of(entry.getValue())); + return responseVector.get(0); + } + )); + + // expected API response embedded vectors + List expectedResponses = causes.stream() + .map(cause -> cause.get("expected_response")) + .collect(Collectors.toList()); + List expectedResponseVectors = getEmbeddedVector(expectedResponses); + + Map dotProductMap = IntStream.range(0, causes.size()) + .boxed() + .collect(Collectors.toMap( + i -> causes.get(i).get("reason"), + i -> responseVectorMap.get(causes.get(i).get(API_URL_FIELD)).dotProduct(expectedResponseVectors.get(i)) + )); + + Optional> mapEntry = + dotProductMap.entrySet().stream() + .max(Map.Entry.comparingByValue()); + + String rootCauseReason = "No root cause found"; + if (mapEntry.isPresent()) { + Entry entry = mapEntry.get(); + log.info("kNN RCA reason: {} with score: {} for the phenomenon: {}", + entry.getKey(), entry.getValue(), phenomenon); + rootCauseReason = entry.getKey(); + } else { + log.warn("No root cause found for the phenomenon: {}", phenomenon); + } - listener.onResponse((T) rootCauseReason); + listener.onResponse((T) rootCauseReason); + }, listener::onFailure), apiList.size()); + // TODO: support different parameters for different apis + apiList.forEach(api -> invokeAPI(api, parameters, groupedListener)); } /** @@ -205,21 +202,10 @@ public void runOption2(Map knowledgeBase, ActionListener liste @Override public void run(Map parameters, ActionListener listener) { try { - String knowledge = parameters.get(KNOWLEDGE_BASE_TOOL_OUTPUT_FIELD); - knowledge = unescapeJson(knowledge); - Map knowledgeBase = StringUtils.gson.fromJson(knowledge, Map.class); - List> causes = (List>) knowledgeBase.get("causes"); - Map apiToResponse = causes - .stream() - .map(c -> c.get(API_URL_FIELD)) - .distinct() - .collect(Collectors.toMap(url -> url, url -> invokeAPI(url, parameters))); - causes.forEach(cause -> cause.put("response", apiToResponse.get(cause.get(API_URL_FIELD)))); - if (isLLMOption) { - runOption1(knowledgeBase, listener); + runOption1(parameters, listener); } else { - runOption2(knowledgeBase, listener); + runOption2(parameters, listener); } } catch (Exception e) { log.error("Failed to run RCA tool", e);