Implements inheritance pattern to break down inference request
Signed-off-by: Fen Qin <mfenqin@amazon.com>
Fen Qin authored and fen-qin committed Feb 12, 2025
1 parent 55a6627 commit bd53055
Showing 14 changed files with 196 additions and 132 deletions.
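At a high level, this commit replaces the single catch-all InferenceRequest POJO with an abstract base class plus concrete subclasses (TextInferenceRequest, MapInferenceRequest, SimilarityInferenceRequest), so each MLCommonsClientAccessor entry point can declare exactly the request shape it needs. A minimal caller-side sketch of the new builders follows; the model id and the inputObjects keys are illustrative placeholders, and TextInferenceRequest's inputTexts field is inferred from the builder calls in the hunks below rather than shown in this commit.

import java.util.List;
import java.util.Map;

import org.opensearch.neuralsearch.processor.MapInferenceRequest;
import org.opensearch.neuralsearch.processor.SimilarityInferenceRequest;
import org.opensearch.neuralsearch.processor.TextInferenceRequest;

public class TypedInferenceRequestSketch {
    public static void main(String[] args) {
        // Plain text request, as built by the text embedding and sparse encoding processors below.
        TextInferenceRequest textRequest = TextInferenceRequest.builder()
            .modelId("my-model-id")                      // hypothetical model id
            .inputTexts(List.of("hello neural search"))
            .build();

        // Key-value request, as built by the text/image embedding processor below.
        MapInferenceRequest mapRequest = MapInferenceRequest.builder()
            .modelId("my-model-id")
            .inputObjects(Map.of("inputText", "a caption", "inputImage", "<base64 image>")) // illustrative keys
            .build();

        // Similarity request: one query text scored against several candidate texts.
        SimilarityInferenceRequest similarityRequest = SimilarityInferenceRequest.builder()
            .modelId("my-model-id")
            .queryText("how does semantic search work?")
            .inputTexts(List.of("candidate passage one", "candidate passage two"))
            .build();
    }
}

Because all three builders come from Lombok's @SuperBuilder, the shared modelId and targetResponseFilters fields are set through the same chain on every subclass.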
@@ -27,6 +27,9 @@
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.neuralsearch.processor.InferenceRequest;
import org.opensearch.neuralsearch.processor.MapInferenceRequest;
import org.opensearch.neuralsearch.processor.SimilarityInferenceRequest;
import org.opensearch.neuralsearch.processor.TextInferenceRequest;
import org.opensearch.neuralsearch.util.RetryUtil;

import lombok.NonNull;
@@ -39,37 +42,37 @@
@RequiredArgsConstructor
@Log4j2
public class MLCommonsClientAccessor {
private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
private final MachineLearningNodeClient mlClient;

/**
* Wrapper around {@link #inferenceSentences} that expects a single input text and produces a single floating
* point vector as a response.
*
* @param inferenceRequest {@link InferenceRequest}
* @param modelId {@link String}
* @param inputText {@link String}
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out
*/
public void inferenceSentence(@NonNull final InferenceRequest inferenceRequest, @NonNull final ActionListener<List<Float>> listener) {
if (inferenceRequest.getInputTexts().size() != 1) {
listener.onFailure(
new IllegalArgumentException(
"Unexpected number of input texts. Expected 1 input text, but got [" + inferenceRequest.getInputTexts().size() + "]"
)
);
return;
}
inferenceSentences(inferenceRequest, ActionListener.wrap(response -> {
if (response.size() != 1) {
listener.onFailure(
new IllegalStateException(
"Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response.size() + "]"
)
);
return;
}
public void inferenceSentence(
@NonNull final String modelId,
@NonNull final String inputText,
@NonNull final ActionListener<List<Float>> listener
) {

inferenceSentences(
TextInferenceRequest.builder().modelId(modelId).inputTexts(List.of(inputText)).build(),
ActionListener.wrap(response -> {
if (response.size() != 1) {
listener.onFailure(
new IllegalStateException(
"Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response.size() + "]"
)
);
return;
}

listener.onResponse(response.getFirst());
}, listener::onFailure));
listener.onResponse(response.getFirst());
}, listener::onFailure)
);
}

/**
@@ -83,25 +86,17 @@ public void inferenceSentence(@NonNull final InferenceRequest inferenceRequest,
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
*/
public void inferenceSentences(
@NonNull final InferenceRequest inferenceRequest,
@NonNull final TextInferenceRequest inferenceRequest,
@NonNull final ActionListener<List<List<Float>>> listener
) {
retryableInferenceSentencesWithVectorResult(
(inferenceRequest.getTargetResponseFilters() == null || inferenceRequest.getTargetResponseFilters().isEmpty())
? TARGET_RESPONSE_FILTERS
: inferenceRequest.getTargetResponseFilters(),
inferenceRequest.getModelId(),
inferenceRequest.getInputTexts(),
0,
listener
);
retryableInferenceSentencesWithVectorResult(inferenceRequest, 0, listener);
}

public void inferenceSentencesWithMapResult(
@NonNull final InferenceRequest inferenceRequest,
@NonNull final TextInferenceRequest inferenceRequest,
@NonNull final ActionListener<List<Map<String, ?>>> listener
) {
retryableInferenceSentencesWithMapResult(inferenceRequest.getModelId(), inferenceRequest.getInputTexts(), 0, listener);
retryableInferenceSentencesWithMapResult(inferenceRequest, 0, listener);
}

/**
@@ -112,16 +107,8 @@ public void inferenceSentencesWithMapResult(
* @param inferenceRequest {@link InferenceRequest}
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
*/
public void inferenceSentencesMap(@NonNull InferenceRequest inferenceRequest, @NonNull final ActionListener<List<Float>> listener) {
retryableInferenceSentencesWithSingleVectorResult(
(inferenceRequest.getTargetResponseFilters() == null || inferenceRequest.getTargetResponseFilters().isEmpty())
? TARGET_RESPONSE_FILTERS
: inferenceRequest.getTargetResponseFilters(),
inferenceRequest.getModelId(),
inferenceRequest.getInputObjects(),
0,
listener
);
public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull final ActionListener<List<Float>> listener) {
retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, 0, listener);
}

/**
@@ -132,73 +119,65 @@ public void inferenceSentencesMap(@NonNull InferenceRequest inferenceRequest, @N
* @param inferenceRequest {@link InferenceRequest}
* @param listener {@link ActionListener} receives the result of the inference
*/
public void inferenceSimilarity(@NonNull InferenceRequest inferenceRequest, @NonNull final ActionListener<List<Float>> listener) {
retryableInferenceSimilarityWithVectorResult(
inferenceRequest.getModelId(),
inferenceRequest.getQueryText(),
inferenceRequest.getInputTexts(),
0,
listener
);
public void inferenceSimilarity(
@NonNull SimilarityInferenceRequest inferenceRequest,
@NonNull final ActionListener<List<Float>> listener
) {
retryableInferenceSimilarityWithVectorResult(inferenceRequest, 0, listener);
}

private void retryableInferenceSentencesWithMapResult(
final String modelId,
final List<String> inputText,
final TextInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<Map<String, ?>>> listener
) {
MLInput mlInput = createMLTextInput(null, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
MLInput mlInput = createMLTextInput(null, inferenceRequest.getInputTexts());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Map<String, ?>> result = buildMapResultFromResponse(mlOutput);
listener.onResponse(result);
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSentencesWithMapResult(modelId, inputText, retryTime + 1, listener),
() -> retryableInferenceSentencesWithMapResult(inferenceRequest, retryTime + 1, listener),
listener
)
));
}

private void retryableInferenceSentencesWithVectorResult(
final List<String> targetResponseFilters,
final String modelId,
final List<String> inputText,
final TextInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<List<Float>>> listener
) {
MLInput mlInput = createMLTextInput(targetResponseFilters, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
MLInput mlInput = createMLTextInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputTexts());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
listener.onResponse(vector);
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTime + 1, listener),
() -> retryableInferenceSentencesWithVectorResult(inferenceRequest, retryTime + 1, listener),
listener
)
));
}

private void retryableInferenceSimilarityWithVectorResult(
final String modelId,
final String queryText,
final List<String> inputText,
final SimilarityInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<Float>> listener
) {
MLInput mlInput = createMLTextPairsInput(queryText, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
MLInput mlInput = createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
listener.onResponse(scores);
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener),
() -> retryableInferenceSimilarityWithVectorResult(inferenceRequest, retryTime + 1, listener),
listener
)
));
@@ -252,28 +231,20 @@ private List<Float> buildSingleVectorFromResponse(final MLOutput mlOutput) {
}

private void retryableInferenceSentencesWithSingleVectorResult(
final List<String> targetResponseFilters,
final String modelId,
final Map<String, String> inputObjects,
final MapInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<Float>> listener
) {
MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
MLInput mlInput = createMLMultimodalInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputObjects());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence is : {} ", vector);
listener.onResponse(vector);
},
e -> RetryUtil.handleRetryOrFailure(
e,
retryTime,
() -> retryableInferenceSentencesWithSingleVectorResult(
targetResponseFilters,
modelId,
inputObjects,
retryTime + 1,
listener
),
() -> retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, retryTime + 1, listener),
listener
)
));
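For orientation, here is a hedged sketch of a caller using the reworked single-sentence wrapper shown near the top of this file; the accessor wiring, logger replacement, and the ActionListener import path are assumptions, while the signature (model id, one input text, listener receiving exactly one vector) matches the method above.

import java.util.List;

import org.opensearch.core.action.ActionListener; // import path assumed

// This sketch assumes it lives in the same package as MLCommonsClientAccessor, so no import is needed for it.
public class InferenceSentenceCallerSketch {
    private final MLCommonsClientAccessor clientAccessor;

    public InferenceSentenceCallerSketch(MLCommonsClientAccessor clientAccessor) {
        this.clientAccessor = clientAccessor; // assumed to be injected by the caller
    }

    public void embedOneSentence() {
        clientAccessor.inferenceSentence(
            "my-model-id",                     // hypothetical model id
            "a single sentence to embed",
            ActionListener.wrap(
                (List<Float> vector) -> System.out.println("embedding length: " + vector.size()), // exactly one vector comes back
                e -> e.printStackTrace()       // retries are already handled inside the accessor
            )
        );
    }
}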
@@ -6,25 +6,32 @@

import java.util.List;

import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.Setter;
import lombok.experimental.SuperBuilder;

@Builder
@AllArgsConstructor
@SuperBuilder
@NoArgsConstructor
@Getter
@Setter
/**
* POJO class to hold request parameters to call ml commons client accessor.
* Base abstract class for inference requests.
* This class contains common fields and behaviors shared across different types of inference requests.
*/
public class InferenceRequest {
public abstract class InferenceRequest {
/**
* Unique identifier for the model to be used for inference.
* This field is required and cannot be null.
*/
@NonNull
private String modelId; // required
private List<String> inputTexts; // on which inference needs to happen
private Map<String, String> inputObjects;
private List<String> targetResponseFilters;
private String queryText;
private String modelId;
/**
* List of targetResponseFilters to be applied.
* Defaults to ["sentence_embedding"] if not specified.
*/
@Builder.Default
private List<String> targetResponseFilters = List.of("sentence_embedding");
}
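TextInferenceRequest is referenced throughout this commit but its own file is not among the hunks shown here. Based on how SimilarityInferenceRequest extends it and how the processors call its builder, it presumably looks roughly like the sketch below; the inputTexts field name is inferred from the inputTexts(...) builder calls, so treat this as a reconstruction rather than the committed file.

/*
 * Hedged reconstruction of TextInferenceRequest -- not part of the hunks shown above.
 */
package org.opensearch.neuralsearch.processor;

import java.util.List;

import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.experimental.SuperBuilder;

/**
 * Implementation of InferenceRequest for plain text inference requests.
 *
 * @see InferenceRequest
 */
@SuperBuilder
@NoArgsConstructor
@Getter
@Setter
public class TextInferenceRequest extends InferenceRequest {
    private List<String> inputTexts; // texts on which inference needs to happen
}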
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import java.util.Map;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import lombok.experimental.SuperBuilder;

/**
* Implementation of InferenceRequest for inputObjects based inference requests.
* Use this class when the input data consists of key-value pairs.
*
* @see InferenceRequest
*/
@SuperBuilder
@NoArgsConstructor
@Getter
@Setter
public class MapInferenceRequest extends InferenceRequest {
private Map<String, String> inputObjects;
}
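A short end-to-end sketch of the map-based path: the request built here mirrors the text/image embedding processor change later in this diff, but the accessor handling, map keys, model id, and the ActionListener import path are illustrative assumptions.

import java.util.List;
import java.util.Map;

import org.opensearch.core.action.ActionListener; // import path assumed
import org.opensearch.neuralsearch.processor.MapInferenceRequest;

// Assumed to live in the same package as MLCommonsClientAccessor, so no import is needed for it.
public class MapInferenceCallerSketch {
    public void embedTextAndImage(MLCommonsClientAccessor clientAccessor) {
        MapInferenceRequest request = MapInferenceRequest.builder()
            .modelId("my-multimodal-model")                                                      // hypothetical model id
            .inputObjects(Map.of("inputText", "a red bicycle", "inputImage", "<base64 image>"))  // illustrative keys
            .build();

        // inferenceSentencesMap returns a single vector for the whole key-value input.
        clientAccessor.inferenceSentencesMap(request, ActionListener.wrap(
            (List<Float> vector) -> System.out.println("single vector of length " + vector.size()),
            e -> e.printStackTrace()
        ));
    }
}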
@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.NoArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.experimental.SuperBuilder;

/**
* Implementation of InferenceRequest for similarity based text inference requests.
*
* @see TextInferenceRequest
*/
@SuperBuilder
@NoArgsConstructor
@Getter
@Setter
public class SimilarityInferenceRequest extends TextInferenceRequest {
private String queryText;
}
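Because the whole hierarchy uses Lombok's @SuperBuilder, a SimilarityInferenceRequest builder chain can set fields declared at every level, and the targetResponseFilters default declared on the base class above still applies. A small sketch, with getter names following Lombok's defaults:

import java.util.List;

import org.opensearch.neuralsearch.processor.SimilarityInferenceRequest;

public class SuperBuilderChainSketch {
    public static void main(String[] args) {
        SimilarityInferenceRequest request = SimilarityInferenceRequest.builder()
            .modelId("my-model-id")                            // declared on InferenceRequest
            .inputTexts(List.of("passage one", "passage two")) // declared on TextInferenceRequest
            .queryText("example query")                        // declared on SimilarityInferenceRequest
            .build();

        System.out.println(request.getModelId());               // my-model-id
        System.out.println(request.getTargetResponseFilters()); // [sentence_embedding], from @Builder.Default on the base class
        System.out.println(request.getQueryText());             // example query
    }
}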
@@ -60,7 +60,7 @@ public void doExecute(
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
ActionListener.wrap(resultMaps -> {
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
.stream()
@@ -75,7 +75,7 @@
@Override
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
ActionListener.wrap(resultMaps -> {
List<Map<String, Float>> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
.stream()
@@ -48,7 +48,7 @@ public void doExecute(
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentences(
InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
handler.accept(ingestDocument, null);
@@ -59,7 +59,7 @@
@Override
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
mlCommonsClientAccessor.inferenceSentences(
InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
ActionListener.wrap(handler::accept, onException)
);
}
@@ -114,7 +114,7 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest
handler.accept(ingestDocument, null);
} else {
mlCommonsClientAccessor.inferenceSentencesMap(
InferenceRequest.builder().modelId(this.modelId).inputObjects(inferenceMap).build(),
MapInferenceRequest.builder().modelId(this.modelId).inputObjects(inferenceMap).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, vectors);
handler.accept(ingestDocument, null);