Skip to content

Commit

Permalink
resolve code conflicts
Browse files Browse the repository at this point in the history
Signed-off-by: yuye-aws <yuyezhu@amazon.com>
  • Loading branch information
yuye-aws committed Aug 6, 2024
1 parent 927ba0f commit c2d3fb4
Showing 1 changed file with 85 additions and 98 deletions.
183 changes: 85 additions & 98 deletions src/main/java/org/opensearch/agent/tools/RCATool.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

import static org.apache.commons.text.StringEscapeUtils.unescapeJson;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
Expand Down Expand Up @@ -60,10 +61,10 @@ public class RCATool implements Tool {
private String description = DEFAULT_DESCRIPTION;

private final Client client;
private final String modelId;
private final String llmModelId;
private final String embeddingModelId;
private final Boolean isLLMOption;
private static final String MODEL_ID = "model_id";
private static final String LLM_MODEL_ID_FIELD = "llm_model_id";
private static final String EMBEDDING_MODEL_ID_FIELD = "embedding_model_id";
private static final String IS_LLM_OPTION = "is_llm_option";

Expand All @@ -74,7 +75,7 @@ public class RCATool implements Tool {

/**
 * Creates an RCA tool that keeps the supplied client, model ids and option flag for later use.
 *
 * @param client           client stored in {@code this.client}
 * @param modelId          id stored in both {@code this.modelId} and {@code this.llmModelId}
 * @param embeddingModelId id stored in {@code this.embeddingModelId}
 * @param isLLMOption      flag stored in {@code this.isLLMOption}
 */
public RCATool(Client client, String modelId, String embeddingModelId, Boolean isLLMOption) {
this.client = client;
// NOTE(review): the same argument is written to two fields; this looks like a modelId -> llmModelId
// rename where both sides were kept — confirm which field is still read and drop the other.
this.modelId = modelId;
this.llmModelId = modelId;
this.embeddingModelId = embeddingModelId;
this.isLLMOption = isLLMOption;
}
Expand Down Expand Up @@ -108,84 +109,65 @@ public boolean validate(Map<String, String> parameters) {
+ "Assistant: ";

@SuppressWarnings("unchecked")
public <T> void runOption1(Map<String, String> parameters, ActionListener<T> listener) {
String knowledge = parameters.get(KNOWLEDGE_BASE_TOOL_OUTPUT_FIELD);
knowledge = unescapeJson(knowledge);
Map<String, ?> knowledgeBase = StringUtils.gson.fromJson(knowledge, Map.class);
List<Map<String, String>> causes = (List<Map<String, String>>) knowledgeBase.get("causes");
List<String> apiList = causes.stream().map(cause -> cause.get(API_URL_FIELD)).distinct().collect(Collectors.toList());
final GroupedActionListener<Pair<String, String>> groupedListener = new GroupedActionListener<>(ActionListener.wrap(responses -> {
Map<String, String> apiToResponse = responses.stream().collect(Collectors.toMap(Pair::getKey, Pair::getValue));
Map<String, String> LLMParams = new java.util.HashMap<>(
Map
.of(
"phenomenon",
(String) knowledgeBase.get("phenomenon"),
"causes",
StringUtils.gson.toJson(causes),
"responses",
StringUtils.gson.toJson(apiToResponse)
)
);
StringSubstitutor substitute = new StringSubstitutor(LLMParams, "${parameters.", "}");
String finalToolPrompt = substitute.replace(TOOL_PROMPT);
LLMParams.put("prompt", finalToolPrompt);
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(LLMParams).build();
ActionRequest request = new MLPredictionTaskRequest(
modelId,
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()
);
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response.getOutput();
Map<String, ?> dataMap = Optional
.ofNullable(modelTensorOutput.getMlModelOutputs())
.flatMap(outputs -> outputs.stream().findFirst())
.flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst())
.map(ModelTensor::getDataAsMap)
.orElse(null);
if (dataMap == null) {
throw new IllegalArgumentException("No dataMap returned from LLM.");
}
listener.onResponse((T) dataMap.get("completion"));
}, listener::onFailure));
}, listener::onFailure), apiList.size());
// TODO: support different parameters for different apis
apiList.forEach(api -> invokeAPI(api, parameters, groupedListener));
/**
 * Asks the remote LLM identified by {@code llmModelId} for the root cause.
 *
 * The phenomenon, the JSON-serialized causes and the JSON-serialized API responses are substituted
 * into {@code TOOL_PROMPT} (placeholders of the form {@code ${parameters.<name>}}); the resulting
 * prompt is sent together with those parameters as a remote-inference prediction request. The
 * {@code "completion"} entry of the first returned tensor's data map is handed to the listener.
 *
 * @param phenomenon    phenomenon text placed under the {@code "phenomenon"} parameter
 * @param causes        candidate causes, serialized to JSON under {@code "causes"}
 * @param apiToResponse API-to-response map, serialized to JSON under {@code "responses"}
 * @param listener      receives the completion; prediction failures go to {@code listener::onFailure}
 * @param <T>           type the completion is cast to (unchecked)
 */
public <T> void runOption1(
    String phenomenon,
    List<Map<String, String>> causes,
    Map<String, String> apiToResponse,
    ActionListener<T> listener
) {
    String causesJson = StringUtils.gson.toJson(causes);
    String responsesJson = StringUtils.gson.toJson(apiToResponse);
    // Map.of is immutable; copy into a HashMap so the rendered prompt can be added afterwards.
    Map<String, String> llmParams = new java.util.HashMap<>(
        Map.of("phenomenon", phenomenon, "causes", causesJson, "responses", responsesJson)
    );
    String finalToolPrompt = new StringSubstitutor(llmParams, "${parameters.", "}").replace(TOOL_PROMPT);
    llmParams.put("prompt", finalToolPrompt);

    MLInput mlInput = MLInput
        .builder()
        .algorithm(FunctionName.REMOTE)
        .inputDataset(RemoteInferenceInputDataSet.builder().parameters(llmParams).build())
        .build();
    ActionRequest request = new MLPredictionTaskRequest(llmModelId, mlInput);

    client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> {
        ModelTensorOutput output = (ModelTensorOutput) response.getOutput();
        // First tensor of the first output carries the LLM answer; anything missing along the way is an error.
        Map<String, ?> dataMap = Optional
            .ofNullable(output.getMlModelOutputs())
            .flatMap(outputs -> outputs.stream().findFirst())
            .flatMap(modelTensors -> modelTensors.getMlModelTensors().stream().findFirst())
            .map(ModelTensor::getDataAsMap)
            .orElseThrow(() -> new IllegalArgumentException("No dataMap returned from LLM."));
        listener.onResponse((T) dataMap.get("completion"));
    }, listener::onFailure));
}

@SuppressWarnings("unchecked")
public <T> void runOption2(Map<String, ?> knowledgeBase, ActionListener<T> listener) {
String phenomenon = (String) knowledgeBase.get("phenomenon");

// API response embedded vectors
List<Map<String, String>> causes = (List<Map<String, String>>) knowledgeBase.get("causes");
List<String> responses = causes.stream()
.map(cause -> cause.get("response"))
.collect(Collectors.toList());
public <T> void runOption2(
String phenomenon,
List<Map<String, String>> causes,
Map<String, String> apiToResponse,
ActionListener<T> listener
) {
// Take one response per cause, in cause order, so that responseVectors.get(i) lines up with
// causes.get(i) when the dot products are computed; apiToResponse holds only one entry per
// distinct API, so its values() cannot be indexed by cause position.
List<String> responses = causes.stream().map(cause -> apiToResponse.get(cause.get(API_URL_FIELD))).collect(Collectors.toList());
List<RealVector> responseVectors = getEmbeddedVector(responses);

// expected API response embedded vectors
List<String> expectedResponses = causes.stream()
.map(cause -> cause.get("expected_response"))
.collect(Collectors.toList());
List<String> expectedResponses = causes.stream().map(cause -> cause.get("expected_response")).collect(Collectors.toList());
List<RealVector> expectedResponseVectors = getEmbeddedVector(expectedResponses);

Map<String, Double> dotProductMap = IntStream.range(0, causes.size())
Map<String, Double> dotProductMap = IntStream
.range(0, causes.size())
.boxed()
.collect(Collectors.toMap(
i -> causes.get(i).get("reason"),
i -> responseVectors.get(i).dotProduct(expectedResponseVectors.get(i))
));
.collect(
Collectors.toMap(i -> causes.get(i).get("reason"), i -> responseVectors.get(i).dotProduct(expectedResponseVectors.get(i)))
);

Optional<Map.Entry<String, Double>> mapEntry =
dotProductMap.entrySet().stream()
.max(Map.Entry.comparingByValue());
Optional<Map.Entry<String, Double>> mapEntry = dotProductMap.entrySet().stream().max(Map.Entry.comparingByValue());

String rootCauseReason = "No root cause found";
if (mapEntry.isPresent()) {
Entry<String, Double> entry = mapEntry.get();
log.info("kNN RCA reason: {} with score: {} for the phenomenon: {}",
entry.getKey(), entry.getValue(), phenomenon);
log.info("kNN RCA reason: {} with score: {} for the phenomenon: {}", entry.getKey(), entry.getValue(), phenomenon);
rootCauseReason = entry.getKey();
} else {
log.warn("No root cause found for the phenomenon: {}", phenomenon);
Expand All @@ -203,24 +185,29 @@ public <T> void runOption2(Map<String, ?> knowledgeBase, ActionListener<T> liste
* @param <T>
*/
@Override
@SuppressWarnings("unchecked")
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
try {
String knowledge = parameters.get(KNOWLEDGE_BASE_TOOL_OUTPUT_FIELD);
knowledge = unescapeJson(knowledge);
Map<String, ?> knowledgeBase = StringUtils.gson.fromJson(knowledge, Map.class);
String phenomenon = (String) knowledgeBase.get("phenomenon");
List<Map<String, String>> causes = (List<Map<String, String>>) knowledgeBase.get("causes");
Map<String, String> apiToResponse = causes
.stream()
.map(c -> c.get(API_URL_FIELD))
.distinct()
.collect(Collectors.toMap(url -> url, url -> invokeAPI(url, parameters)));
causes.forEach(cause -> cause.put("response", apiToResponse.get(cause.get(API_URL_FIELD))));

if (isLLMOption) {
runOption1(knowledgeBase, listener);
} else {
runOption2(knowledgeBase, listener);
}
List<String> apiList = causes.stream().map(cause -> cause.get(API_URL_FIELD)).distinct().collect(Collectors.toList());

// Fires once a (api, response) pair has arrived for every entry of apiList, then dispatches to the
// configured analysis: the LLM prompt (runOption1) or the embedding comparison (runOption2).
final GroupedActionListener<Pair<String, String>> groupedListener = new GroupedActionListener<>(
    ActionListener.wrap(responses -> {
        Map<String, String> apiToResponse = responses.stream().collect(Collectors.toMap(Pair::getKey, Pair::getValue));
        if (isLLMOption) {
            runOption1(phenomenon, causes, apiToResponse, listener);
        } else {
            // Previously this branch also called runOption1, which made the embedding option unreachable.
            runOption2(phenomenon, causes, apiToResponse, listener);
        }
    }, listener::onFailure),
    apiList.size()
);
// TODO: support different parameters for different apis
apiList.forEach(api -> invokeAPI(api, parameters, groupedListener));
} catch (Exception e) {
log.error("Failed to run RCA tool", e);
listener.onFailure(e);
Expand Down Expand Up @@ -270,27 +257,27 @@ public void onFailure(Exception e) {
}

private List<RealVector> getEmbeddedVector(List<String> docs) {
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder()
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet
.builder()
.docs(docs)
.resultFilter(ModelResultFilter.builder()
.returnNumber(true)
.targetResponse(List.of("sentence_embedding"))
.build())
.resultFilter(ModelResultFilter.builder().returnNumber(true).targetResponse(List.of("sentence_embedding")).build())
.build();
ActionRequest request = new MLPredictionTaskRequest(
embeddingModelId,
MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
ActionFuture<MLTaskResponse> mlTaskRspFuture = client.execute(MLPredictionTaskAction.INSTANCE, request);
MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()
);
ActionFuture<MLTaskResponse> mlTaskRspFuture = client.execute(MLPredictionTaskAction.INSTANCE, request);
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskRspFuture.actionGet().getOutput();
List<ModelTensor> mlModelOutputs = modelTensorOutput.getMlModelOutputs().stream()
List<ModelTensor> mlModelOutputs = modelTensorOutput
.getMlModelOutputs()
.stream()
.map(modelTensors -> modelTensors.getMlModelTensors().get(0))
.collect(Collectors.toList());
return mlModelOutputs.stream()
.map(tensor -> {
Number[] data = tensor.getData();
// Simplify the computation in POC, every MLResultDataType will use high precision FLOAT32, aka double in Java.
return new ArrayRealVector(Arrays.stream(data).mapToDouble(Number::doubleValue).toArray());
}).collect(Collectors.toList());
return mlModelOutputs.stream().map(tensor -> {
Number[] data = tensor.getData();
// Simplify the computation in POC, every MLResultDataType will use high precision FLOAT32, aka double in Java.
return new ArrayRealVector(Arrays.stream(data).mapToDouble(Number::doubleValue).toArray());
}).collect(Collectors.toList());
}

public static class Factory implements Tool.Factory<RCATool> {
Expand All @@ -317,14 +304,14 @@ public void init(Client client) {

/**
 * Builds an {@link RCATool} from agent-supplied parameters.
 *
 * The {@code IS_LLM_OPTION} flag defaults to {@code "true"}. When it is true the
 * {@code LLM_MODEL_ID_FIELD} parameter must be non-blank; when it is false the
 * {@code EMBEDDING_MODEL_ID_FIELD} parameter must be non-blank.
 *
 * @param parameters tool configuration
 * @return a tool bound to this factory's client
 * @throws IllegalArgumentException if the model id required by the selected option is null or blank
 */
@Override
public RCATool create(Map<String, Object> parameters) {
    // String.valueOf instead of a (String) cast: a Boolean-typed flag is accepted rather than
    // failing with ClassCastException; String inputs parse exactly as before.
    boolean isLLMOption = Boolean.parseBoolean(String.valueOf(parameters.getOrDefault(IS_LLM_OPTION, "true")));
    String modelId = (String) parameters.get(LLM_MODEL_ID_FIELD);
    if (isLLMOption && Strings.isBlank(modelId)) {
        throw new IllegalArgumentException(String.format(Locale.ROOT, "%s cannot be null or blank.", LLM_MODEL_ID_FIELD));
    }
    String embeddingModelId = (String) parameters.get(EMBEDDING_MODEL_ID_FIELD);
    if (!isLLMOption && Strings.isBlank(embeddingModelId)) {
        throw new IllegalArgumentException(String.format(Locale.ROOT, "%s cannot be null or blank.", EMBEDDING_MODEL_ID_FIELD));
    }
    return new RCATool(client, modelId, embeddingModelId, isLLMOption);
}
Expand Down

0 comments on commit c2d3fb4

Please sign in to comment.