code refactoring on MLCommonsClientAccessor request #1178

Merged · 3 commits · Feb 12, 2025

Changes from all commits
MLCommonsClientAccessor.java
@@ -26,6 +26,10 @@
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
+import org.opensearch.neuralsearch.processor.InferenceRequest;
+import org.opensearch.neuralsearch.processor.MapInferenceRequest;
+import org.opensearch.neuralsearch.processor.SimilarityInferenceRequest;
+import org.opensearch.neuralsearch.processor.TextInferenceRequest;
import org.opensearch.neuralsearch.util.RetryUtil;

import lombok.NonNull;
@@ -38,53 +42,37 @@
@RequiredArgsConstructor
@Log4j2
public class MLCommonsClientAccessor {
-private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
private final MachineLearningNodeClient mlClient;

/**
* Wrapper around {@link #inferenceSentences} that expects a single input text and produces a single floating
* point vector as a response.
*
* @param modelId {@link String}
-* @param inputText {@link List} of {@link String} on which inference needs to happen
+* @param inputText {@link String}
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out
*/
public void inferenceSentence(
@NonNull final String modelId,
@NonNull final String inputText,
@NonNull final ActionListener<List<Float>> listener
) {
-inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> {
-if (response.size() != 1) {
-listener.onFailure(
-new IllegalStateException(
-"Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response.size() + "]"
-)
-);
-return;
-}
-
-listener.onResponse(response.get(0));
-}, listener::onFailure));
-}
+inferenceSentences(
+TextInferenceRequest.builder().modelId(modelId).inputTexts(List.of(inputText)).build(),
+ActionListener.wrap(response -> {
+if (response.size() != 1) {
+listener.onFailure(
+new IllegalStateException(
+"Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response.size() + "]"
+)
+);
+return;
+}

-/**
-* Abstraction to call predict function of api of MLClient with default targetResponse filters. It uses the
-* custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent
-* using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of
-* inputText. We are not making this function generic enough to take any function or TaskType as currently we
-* need to run only TextEmbedding tasks only.
-*
-* @param modelId {@link String}
-* @param inputText {@link List} of {@link String} on which inference needs to happen
-* @param listener {@link ActionListener} which will be called when prediction is completed or errored out
-*/
-public void inferenceSentences(
-@NonNull final String modelId,
-@NonNull final List<String> inputText,
-@NonNull final ActionListener<List<List<Float>>> listener
-) {
-inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener);
+listener.onResponse(response.getFirst());
+}, listener::onFailure)
+);
}

/**
@@ -94,121 +82,102 @@ public void inferenceSentences(
* inputText. We are not making this function generic enough to take any function or TaskType as currently we
* need to run only TextEmbedding tasks.
*
-* @param targetResponseFilters {@link List} of {@link String} which filters out the responses
-* @param modelId {@link String}
-* @param inputText {@link List} of {@link String} on which inference needs to happen
+* @param inferenceRequest {@link InferenceRequest}
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
*/
public void inferenceSentences(
-@NonNull final List<String> targetResponseFilters,
-@NonNull final String modelId,
-@NonNull final List<String> inputText,
+@NonNull final TextInferenceRequest inferenceRequest,
@NonNull final ActionListener<List<List<Float>>> listener
) {
-retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
+retryableInferenceSentencesWithVectorResult(inferenceRequest, 0, listener);
}
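For illustration: after this refactor a caller builds a TextInferenceRequest through its Lombok builder instead of passing modelId and inputText separately. A minimal sketch, with a placeholder model id and accessor instance; the inputTexts builder method is inferred from the getInputTexts() calls below in this file:

    TextInferenceRequest request = TextInferenceRequest.builder()
        .modelId("my-text-embedding-model")        // placeholder id; @NonNull in InferenceRequest
        .inputTexts(List.of("hello", "world"))     // one output vector per input text
        .build();                                  // targetResponseFilters defaults to ["sentence_embedding"]
    accessor.inferenceSentences(request, ActionListener.wrap(
        vectors -> log.info("received {} vectors", vectors.size()),
        e -> log.error("text embedding failed", e)
    ));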

public void inferenceSentencesWithMapResult(
-@NonNull final String modelId,
-@NonNull final List<String> inputText,
+@NonNull final TextInferenceRequest inferenceRequest,
@NonNull final ActionListener<List<Map<String, ?>>> listener
) {
-retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener);
+retryableInferenceSentencesWithMapResult(inferenceRequest, 0, listener);
}

/**
* Abstraction to call the predict API of MLClient with the provided targetResponse filters. It uses the
* custom model provided as modelId and runs {@link FunctionName#TEXT_EMBEDDING}. The return will be sent
* using the actionListener which will have a list of floats in the order of inputText.
*
-* @param modelId {@link String}
-* @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to happen
+* @param inferenceRequest {@link InferenceRequest}
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
*/
-public void inferenceSentences(
-@NonNull final String modelId,
-@NonNull final Map<String, String> inputObjects,
-@NonNull final ActionListener<List<Float>> listener
-) {
-retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener);
+public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull final ActionListener<List<Float>> listener) {
+retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, 0, listener);
}

/**
* Abstraction to call predict function of api of MLClient. It uses the custom model provided as modelId and the
* {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via actionListener as a list of floats representing
* the similarity scores of the texts w.r.t. the query text, in the order of the input texts.
*
-* @param modelId {@link String} ML-Commons Model Id
-* @param queryText {@link String} The query to compare all the inputText to
-* @param inputText {@link List} of {@link String} The texts to compare to the query
+* @param inferenceRequest {@link InferenceRequest}
* @param listener {@link ActionListener} receives the result of the inference
*/
public void inferenceSimilarity(
-@NonNull final String modelId,
-@NonNull final String queryText,
-@NonNull final List<String> inputText,
+@NonNull SimilarityInferenceRequest inferenceRequest,
@NonNull final ActionListener<List<Float>> listener
) {
-retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener);
+retryableInferenceSimilarityWithVectorResult(inferenceRequest, 0, listener);
}

private void retryableInferenceSentencesWithMapResult(
-final String modelId,
-final List<String> inputText,
+final TextInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<Map<String, ?>>> listener
) {
-MLInput mlInput = createMLTextInput(null, inputText);
-mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
+MLInput mlInput = createMLTextInput(null, inferenceRequest.getInputTexts());
+mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Map<String, ?>> result = buildMapResultFromResponse(mlOutput);
listener.onResponse(result);
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
-() -> retryableInferenceSentencesWithMapResult(modelId, inputText, retryTime + 1, listener),
+() -> retryableInferenceSentencesWithMapResult(inferenceRequest, retryTime + 1, listener),
listener
)
));
}
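Each retryable method hands failures to RetryUtil.handleRetryOrFailure together with a Runnable that re-invokes itself with retryTime + 1. RetryUtil's implementation is not part of this diff; below is a minimal sketch of a helper with the same call shape, where the retry cap and the retryable-exception check are assumptions:

    import org.opensearch.core.action.ActionListener;
    import org.opensearch.transport.NodeNotConnectedException;

    public final class RetryUtilSketch {
        private static final int MAX_RETRY = 3; // assumed cap

        public static void handleRetryOrFailure(
            final Exception e,
            final int retryTime,
            final Runnable retryAction,
            final ActionListener<?> listener
        ) {
            if (retryTime < MAX_RETRY && isRetryable(e)) {
                retryAction.run();     // re-runs the same predict call with retryTime + 1
            } else {
                listener.onFailure(e); // retries exhausted or error not worth retrying
            }
        }

        private static boolean isRetryable(final Exception e) {
            // Assumption: only transient transport failures are retried.
            return e instanceof NodeNotConnectedException;
        }
    }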

private void retryableInferenceSentencesWithVectorResult(
-final List<String> targetResponseFilters,
-final String modelId,
-final List<String> inputText,
+final TextInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<List<Float>>> listener
) {
-MLInput mlInput = createMLTextInput(targetResponseFilters, inputText);
-mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
+MLInput mlInput = createMLTextInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputTexts());
+mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
listener.onResponse(vector);
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
-() -> retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTime + 1, listener),
+() -> retryableInferenceSentencesWithVectorResult(inferenceRequest, retryTime + 1, listener),
listener
)
));
}

private void retryableInferenceSimilarityWithVectorResult(
-final String modelId,
-final String queryText,
-final List<String> inputText,
+final SimilarityInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<Float>> listener
) {
-MLInput mlInput = createMLTextPairsInput(queryText, inputText);
-mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
+MLInput mlInput = createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts());
+mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
listener.onResponse(scores);
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
-() -> retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener),
+() -> retryableInferenceSimilarityWithVectorResult(inferenceRequest, retryTime + 1, listener),
listener
)
));
@@ -262,28 +231,20 @@ private List<Float> buildSingleVectorFromResponse(final MLOutput mlOutput) {
}

private void retryableInferenceSentencesWithSingleVectorResult(
-final List<String> targetResponseFilters,
-final String modelId,
-final Map<String, String> inputObjects,
+final MapInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<Float>> listener
) {
-MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects);
-mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
+MLInput mlInput = createMLMultimodalInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputObjects());
+mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence is : {} ", vector);
listener.onResponse(vector);
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
-() -> retryableInferenceSentencesWithSingleVectorResult(
-targetResponseFilters,
-modelId,
-inputObjects,
-retryTime + 1,
-listener
-),
+() -> retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, retryTime + 1, listener),
listener
)
));
InferenceRequest.java (new file)
@@ -0,0 +1,37 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import java.util.List;

import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.Setter;
import lombok.experimental.SuperBuilder;

@SuperBuilder
@NoArgsConstructor
@Getter
@Setter
/**
* Base abstract class for inference requests.
* This class contains common fields and behaviors shared across different types of inference requests.
*/
public abstract class InferenceRequest {
/**
* Unique identifier for the model to be used for inference.
* This field is required and cannot be null.
*/
@NonNull
private String modelId;
/**
* List of targetResponseFilters to be applied.
* Defaults to ["sentence_embedding"] if not specified.
*/
@Builder.Default
private List<String> targetResponseFilters = List.of("sentence_embedding");
}
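TextInferenceRequest is referenced throughout the accessor but is not part of this diff. From the getInputTexts() call sites and the pattern of the sibling classes below, it presumably looks roughly like this (a sketch, not the actual file from the PR):

    package org.opensearch.neuralsearch.processor;

    import java.util.List;

    import lombok.Getter;
    import lombok.NoArgsConstructor;
    import lombok.Setter;
    import lombok.experimental.SuperBuilder;

    /**
     * Sketch of the text-based request; field name inferred from getInputTexts().
     */
    @SuperBuilder
    @NoArgsConstructor
    @Getter
    @Setter
    public class TextInferenceRequest extends InferenceRequest {
        private List<String> inputTexts;
    }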
MapInferenceRequest.java (new file)
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import java.util.Map;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.experimental.SuperBuilder;

/**
* Implementation of InferenceRequest for inputObjects based inference requests.
* Use this class when the input data consists of key-value pairs.
*
* @see InferenceRequest
*/
@SuperBuilder
@NoArgsConstructor
@Getter
@Setter
public class MapInferenceRequest extends InferenceRequest {
private Map<String, String> inputObjects;
}
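For illustration, a caller of the map-based entry point might look like the sketch below; the model id and map keys are placeholders, and accessor stands for an MLCommonsClientAccessor instance:

    MapInferenceRequest request = MapInferenceRequest.builder()
        .modelId("my-multimodal-embedding-model")          // placeholder id
        .inputObjects(Map.of(
            "inputText", "a photo of a golden retriever",  // placeholder keys and values
            "inputImage", "<base64-encoded image>"))
        .build();
    accessor.inferenceSentencesMap(request, ActionListener.wrap(
        vector -> log.info("embedding dimension: {}", vector.size()),
        e -> log.error("multimodal inference failed", e)
    ));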
SimilarityInferenceRequest.java (new file)
@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.NoArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.experimental.SuperBuilder;

/**
* Implementation of InferenceRequest for similarity based text inference requests.
*
* @see TextInferenceRequest
*/
@SuperBuilder
@NoArgsConstructor
@Getter
@Setter
public class SimilarityInferenceRequest extends TextInferenceRequest {
private String queryText;
}
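Because @SuperBuilder chains through TextInferenceRequest and InferenceRequest, modelId, inputTexts, and queryText are all settable on one builder. A hypothetical reranking call, with placeholder model id and texts:

    SimilarityInferenceRequest request = SimilarityInferenceRequest.builder()
        .modelId("my-cross-encoder-model")   // placeholder id
        .queryText("what is the capital of France?")
        .inputTexts(List.of("Paris is the capital of France.", "Berlin is the capital of Germany."))
        .build();
    accessor.inferenceSimilarity(request, ActionListener.wrap(
        scores -> log.info("one similarity score per input text: {}", scores),
        e -> log.error("similarity inference failed", e)
    ));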
SparseEncodingProcessor.java
@@ -59,24 +59,30 @@ public void doExecute(
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
) {
-mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
-List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
-.stream()
-.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
-.toList();
-setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors);
-handler.accept(ingestDocument, null);
-}, e -> { handler.accept(null, e); }));
+mlCommonsClientAccessor.inferenceSentencesWithMapResult(
+TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
+ActionListener.wrap(resultMaps -> {
+List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
+.stream()
+.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
+.toList();
+setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors);
+handler.accept(ingestDocument, null);
+}, e -> { handler.accept(null, e); })
+);
}

@Override
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
-mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
-List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
-.stream()
-.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
-.toList();
-handler.accept(sparseVectors);
-}, onException));
+mlCommonsClientAccessor.inferenceSentencesWithMapResult(
+TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
+ActionListener.wrap(resultMaps -> {
+List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
+.stream()
+.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
+.toList();
+handler.accept(sparseVectors);
+}, onException)
+);
}
}
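The behavior change in this file is only in how the request reaches the accessor; the prune pipeline is untouched. To make the pruning step concrete, a small sketch using the pruneSparseVector signature seen above — the PruneType.MAX_RATIO constant and its below-ratio-of-max semantics are assumptions for the example:

    Map<String, Float> tokenWeights = Map.of("hello", 0.9f, "world", 0.5f, "the", 0.05f);
    Map<String, Float> pruned = PruneUtils.pruneSparseVector(PruneType.MAX_RATIO, 0.1f, tokenWeights);
    // Under a max-ratio prune, tokens with weight below 0.1 * max(weight) = 0.09
    // would be dropped: "the" (0.05f) disappears while "hello" and "world" survive.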