Skip to content

Commit

Permalink
Fix cherry-pick conflicts and compilation issue
Browse files Browse the repository at this point in the history
Signed-off-by: Songkan Tang <songkant@amazon.com>
  • Loading branch information
songkant-aws committed Aug 6, 2024
1 parent 1361127 commit f9d2c2d
Showing 1 changed file with 51 additions and 65 deletions.
116 changes: 51 additions & 65 deletions src/main/java/org/opensearch/agent/tools/RCATool.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import static org.apache.commons.text.StringEscapeUtils.unescapeJson;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
Expand Down Expand Up @@ -79,21 +78,6 @@ public RCATool(Client client, String modelId, String embeddingModelId, Boolean i
this.isLLMOption = isLLMOption;
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getVersion() {
return null;
}

@Override
public boolean validate(Map<String, String> parameters) {
return parameters != null;
}

public static final String TOOL_PROMPT =
"You are going to help find the root cause of the phenomenon from the several potential causes listed below. In this RCA process, for each cause, it usually needs to call an API to get some necessary information verify whether it's the right root cause. I've filled the related response for each cause, you should decide which cause are most possible to be the root cause based on these responses. \n\n"
+ "Human: PHENOMENON\n"
Expand Down Expand Up @@ -154,44 +138,57 @@ public <T> void runOption1(Map<String, String> parameters, ActionListener<T> lis
}

@SuppressWarnings("unchecked")
public <T> void runOption2(Map<String, ?> knowledgeBase, ActionListener<T> listener) {
String phenomenon = (String) knowledgeBase.get("phenomenon");

// API response embedded vectors
public <T> void runOption2(Map<String, String> parameters, ActionListener<T> listener) {
String knowledge = parameters.get(KNOWLEDGE_BASE_TOOL_OUTPUT_FIELD);
knowledge = unescapeJson(knowledge);
Map<String, ?> knowledgeBase = StringUtils.gson.fromJson(knowledge, Map.class);
List<Map<String, String>> causes = (List<Map<String, String>>) knowledgeBase.get("causes");
List<String> responses = causes.stream()
.map(cause -> cause.get("response"))
.collect(Collectors.toList());
List<RealVector> responseVectors = getEmbeddedVector(responses);

// expected API response embedded vectors
List<String> expectedResponses = causes.stream()
.map(cause -> cause.get("expected_response"))
.collect(Collectors.toList());
List<RealVector> expectedResponseVectors = getEmbeddedVector(expectedResponses);

Map<String, Double> dotProductMap = IntStream.range(0, causes.size())
.boxed()
.collect(Collectors.toMap(
i -> causes.get(i).get("reason"),
i -> responseVectors.get(i).dotProduct(expectedResponseVectors.get(i))
));

Optional<Map.Entry<String, Double>> mapEntry =
dotProductMap.entrySet().stream()
.max(Map.Entry.comparingByValue());

String rootCauseReason = "No root cause found";
if (mapEntry.isPresent()) {
Entry<String, Double> entry = mapEntry.get();
log.info("kNN RCA reason: {} with score: {} for the phenomenon: {}",
entry.getKey(), entry.getValue(), phenomenon);
rootCauseReason = entry.getKey();
} else {
log.warn("No root cause found for the phenomenon: {}", phenomenon);
}
List<String> apiList = causes.stream().map(cause -> cause.get(API_URL_FIELD)).distinct().collect(Collectors.toList());
final GroupedActionListener<Pair<String, String>> groupedListener = new GroupedActionListener<>(ActionListener.wrap(responses -> {
Map<String, String> apiToResponse = responses.stream().collect(Collectors.toMap(Pair::getKey, Pair::getValue));
String phenomenon = (String) knowledgeBase.get("phenomenon");

// API response embedded vectors
Map<String, RealVector> responseVectorMap = apiToResponse.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
entry -> {
List<RealVector> responseVector = getEmbeddedVector(List.of(entry.getValue()));
return responseVector.get(0);
}
));

// expected API response embedded vectors
List<String> expectedResponses = causes.stream()
.map(cause -> cause.get("expected_response"))
.collect(Collectors.toList());
List<RealVector> expectedResponseVectors = getEmbeddedVector(expectedResponses);

Map<String, Double> dotProductMap = IntStream.range(0, causes.size())
.boxed()
.collect(Collectors.toMap(
i -> causes.get(i).get("reason"),
i -> responseVectorMap.get(causes.get(i).get(API_URL_FIELD)).dotProduct(expectedResponseVectors.get(i))
));

Optional<Map.Entry<String, Double>> mapEntry =
dotProductMap.entrySet().stream()
.max(Map.Entry.comparingByValue());

String rootCauseReason = "No root cause found";
if (mapEntry.isPresent()) {
Entry<String, Double> entry = mapEntry.get();
log.info("kNN RCA reason: {} with score: {} for the phenomenon: {}",
entry.getKey(), entry.getValue(), phenomenon);
rootCauseReason = entry.getKey();
} else {
log.warn("No root cause found for the phenomenon: {}", phenomenon);
}

listener.onResponse((T) rootCauseReason);
listener.onResponse((T) rootCauseReason);
}, listener::onFailure), apiList.size());
// TODO: support different parameters for different apis
apiList.forEach(api -> invokeAPI(api, parameters, groupedListener));
}

/**
Expand All @@ -205,21 +202,10 @@ public <T> void runOption2(Map<String, ?> knowledgeBase, ActionListener<T> liste
@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
try {
String knowledge = parameters.get(KNOWLEDGE_BASE_TOOL_OUTPUT_FIELD);
knowledge = unescapeJson(knowledge);
Map<String, ?> knowledgeBase = StringUtils.gson.fromJson(knowledge, Map.class);
List<Map<String, String>> causes = (List<Map<String, String>>) knowledgeBase.get("causes");
Map<String, String> apiToResponse = causes
.stream()
.map(c -> c.get(API_URL_FIELD))
.distinct()
.collect(Collectors.toMap(url -> url, url -> invokeAPI(url, parameters)));
causes.forEach(cause -> cause.put("response", apiToResponse.get(cause.get(API_URL_FIELD))));

if (isLLMOption) {
runOption1(knowledgeBase, listener);
runOption1(parameters, listener);
} else {
runOption2(knowledgeBase, listener);
runOption2(parameters, listener);
}
} catch (Exception e) {
log.error("Failed to run RCA tool", e);
Expand Down

0 comments on commit f9d2c2d

Please sign in to comment.