Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support different embedding types of model response #1007

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
- Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007))
### Bug Fixes
### Infrastructure
- [3.0] Update neural-search for OpenSearch 3.0 compatibility ([#1141](https://github.com/opensearch-project/neural-search/pull/1141))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exceptio
loadModel(sparseModelId);
MLModelState oldModelState = getModelState(sparseModelId);
logger.info("Model state in OLD phase: {}", oldModelState);
if (oldModelState != MLModelState.LOADED) {
logger.error("Model {} is not in LOADED state in OLD phase. Current state: {}", sparseModelId, oldModelState);
if (oldModelState != MLModelState.LOADED && oldModelState != MLModelState.DEPLOYED) {
logger.error(
"Model {} is not in LOADED or DEPLOYED state in OLD phase. Current state: {}",
sparseModelId,
oldModelState
);
waitForModelToLoad(sparseModelId);
}
createPipelineForSparseEncodingProcessor(sparseModelId, SPARSE_PIPELINE, 2);
Expand All @@ -52,8 +56,12 @@ public void testBatchIngestion_SparseEncodingProcessor_E2EFlow() throws Exceptio
loadModel(sparseModelId);
MLModelState mixedModelState = getModelState(sparseModelId);
logger.info("Model state in MIXED phase: {}", mixedModelState);
if (mixedModelState != MLModelState.LOADED) {
logger.error("Model {} is not in LOADED state in MIXED phase. Current state: {}", sparseModelId, mixedModelState);
if (mixedModelState != MLModelState.LOADED && mixedModelState != MLModelState.DEPLOYED) {
logger.error(
"Model {} is not in LOADED or DEPLOYED state in MIXED phase. Current state: {}",
sparseModelId,
mixedModelState
);
waitForModelToLoad(sparseModelId);
}
logger.info("Pipeline state in MIXED phase: {}", getIngestionPipeline(SPARSE_PIPELINE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ public class VectorUtil {
* @param vectorAsList {@link List} of {@link Float}'s representing the vector
* @return array of floats produced from input list
*/
public static float[] vectorAsListToArray(List<Float> vectorAsList) {
public static float[] vectorAsListToArray(List<Number> vectorAsList) {
float[] vector = new float[vectorAsList.size()];
for (int i = 0; i < vectorAsList.size(); i++) {
vector[i] = vectorAsList.get(i);
vector[i] = vectorAsList.get(i).floatValue();
}
return vector;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public class MLCommonsClientAccessor {
public void inferenceSentence(
@NonNull final String modelId,
@NonNull final String inputText,
@NonNull final ActionListener<List<Float>> listener
@NonNull final ActionListener<List<Number>> listener
) {

inferenceSentences(
Expand Down Expand Up @@ -87,7 +87,7 @@ public void inferenceSentence(
*/
public void inferenceSentences(
@NonNull final TextInferenceRequest inferenceRequest,
@NonNull final ActionListener<List<List<Float>>> listener
@NonNull final ActionListener<List<List<Number>>> listener
) {
retryableInferenceSentencesWithVectorResult(inferenceRequest, 0, listener);
}
Expand All @@ -107,7 +107,7 @@ public void inferenceSentencesWithMapResult(
* @param inferenceRequest {@link InferenceRequest}
* @param listener {@link ActionListener} which will be called when prediction is completed or errored out.
*/
public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull final ActionListener<List<Float>> listener) {
public void inferenceSentencesMap(@NonNull MapInferenceRequest inferenceRequest, @NonNull final ActionListener<List<Number>> listener) {
retryableInferenceSentencesWithSingleVectorResult(inferenceRequest, 0, listener);
}

Expand Down Expand Up @@ -148,11 +148,11 @@ private void retryableInferenceSentencesWithMapResult(
private void retryableInferenceSentencesWithVectorResult(
final TextInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<List<Float>>> listener
final ActionListener<List<List<Number>>> listener
) {
MLInput mlInput = createMLTextInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputTexts());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
final List<List<Number>> vector = buildVectorFromResponse(mlOutput);
listener.onResponse(vector);
},
e -> RetryUtil.handleRetryOrFailure(
Expand All @@ -171,7 +171,9 @@ private void retryableInferenceSimilarityWithVectorResult(
) {
MLInput mlInput = createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
final List<Float> scores = buildVectorFromResponse(mlOutput).stream()
.map(v -> v.getFirst().floatValue())
.collect(Collectors.toList());
listener.onResponse(scores);
},
e -> RetryUtil.handleRetryOrFailure(
Expand All @@ -194,14 +196,14 @@ private MLInput createMLTextPairsInput(final String query, final List<String> in
return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
}

private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<Float>> vector = new ArrayList<>();
private <T extends Number> List<List<T>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<T>> vector = new ArrayList<>();
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
for (final ModelTensors tensors : tensorOutputList) {
final List<ModelTensor> tensorsList = tensors.getMlModelTensors();
for (final ModelTensor tensor : tensorsList) {
vector.add(Arrays.stream(tensor.getData()).map(value -> (Float) value).collect(Collectors.toList()));
vector.add(Arrays.stream(tensor.getData()).map(value -> (T) value).collect(Collectors.toList()));
}
}
return vector;
Expand All @@ -225,19 +227,19 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
return resultMaps;
}

private List<Float> buildSingleVectorFromResponse(final MLOutput mlOutput) {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
private <T extends Number> List<T> buildSingleVectorFromResponse(final MLOutput mlOutput) {
final List<List<T>> vector = buildVectorFromResponse(mlOutput);
return vector.isEmpty() ? new ArrayList<>() : vector.get(0);
}

private void retryableInferenceSentencesWithSingleVectorResult(
final MapInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<Float>> listener
final ActionListener<List<Number>> listener
) {
MLInput mlInput = createMLMultimodalInput(inferenceRequest.getTargetResponseFilters(), inferenceRequest.getInputObjects());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> vector = buildSingleVectorFromResponse(mlOutput);
final List<Number> vector = buildSingleVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence is : {} ", vector);
listener.onResponse(vector);
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest

}

private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Float> vectors) {
private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List<Number> vectors) {
Objects.requireNonNull(vectors, "embedding failed, inference returns null result!");
log.debug("Text embedding result fetched, starting build vector output!");
Map<String, Object> textEmbeddingResult = buildTextEmbeddingResult(this.embedding, vectors);
Expand Down Expand Up @@ -167,7 +167,7 @@ Map<String, String> buildMapWithKnnKeyAndOriginalValue(final IngestDocument inge

@SuppressWarnings({ "unchecked" })
@VisibleForTesting
Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Float> modelTensorList) {
Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Number> modelTensorList) {
Map<String, Object> result = new LinkedHashMap<>();
result.put(knnKey, modelTensorList);
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
public class VectorUtilTests extends OpenSearchTestCase {

public void testVectorAsListToArray() {
    // Mixed Number input: each element must convert through floatValue().
    List<Number> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f);
    float[] vectorAsArray_withThreeElements = VectorUtil.vectorAsListToArray(vectorAsList_withThreeElements);

    assertEquals(vectorAsList_withThreeElements.size(), vectorAsArray_withThreeElements.length);
    for (int i = 0; i < vectorAsList_withThreeElements.size(); i++) {
        assertEquals(vectorAsList_withThreeElements.get(i).floatValue(), vectorAsArray_withThreeElements[i], 0.0f);
    }

    // Empty input must yield an empty array, not null.
    List<Number> vectorAsList_withNoElements = Collections.emptyList();
    float[] vectorAsArray_withNoElements = VectorUtil.vectorAsListToArray(vectorAsList_withNoElements);
    assertEquals(0, vectorAsArray_withNoElements.length);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@
public class MLCommonsClientAccessorTests extends OpenSearchTestCase {

@Mock
private ActionListener<List<List<Float>>> resultListener;
private ActionListener<List<List<Number>>> resultListener;

@Mock
private ActionListener<List<Float>> singleSentenceResultListener;
private ActionListener<List<Number>> singleSentenceResultListener;

@Mock
private ActionListener<List<Float>> similarityResultListener;

@Mock
private MachineLearningNodeClient client;
Expand All @@ -53,7 +56,7 @@ public void setup() {
}

public void testInferenceSentence_whenValidInput_thenSuccess() {
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Expand All @@ -69,7 +72,7 @@ public void testInferenceSentence_whenValidInput_thenSuccess() {
}

public void testInferenceSentences_whenValidInputThenSuccess() {
final List<List<Float>> vectorList = new ArrayList<>();
final List<List<Number>> vectorList = new ArrayList<>();
vectorList.add(Arrays.asList(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
Expand All @@ -85,7 +88,7 @@ public void testInferenceSentences_whenValidInputThenSuccess() {
}

public void testInferenceSentences_whenResultFromClient_thenEmptyVectorList() {
final List<List<Float>> vectorList = new ArrayList<>();
final List<List<Number>> vectorList = new ArrayList<>();
vectorList.add(Collections.emptyList());
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
Expand Down Expand Up @@ -127,17 +130,17 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenRetry() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener);
accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener);

// Verify client.predict is called 4 times (1 initial + 3 retries)
Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

// Verify failure is propagated to the listener after all retries
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
Mockito.verify(similarityResultListener).onFailure(nodeNodeConnectedException);

// Ensure no additional interactions with the listener
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
Mockito.verifyNoMoreInteractions(similarityResultListener);
}

public void testInferenceSentences_whenExceptionFromMLClient_thenRetry_thenFailure() {
Expand Down Expand Up @@ -288,7 +291,7 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa
}

public void testInferenceMultimodal_whenValidInput_thenSuccess() {
final List<Float> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
final List<Number> vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY));
Expand Down Expand Up @@ -353,12 +356,12 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener);
accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
Mockito.verify(similarityResultListener).onResponse(vector);
Mockito.verifyNoMoreInteractions(similarityResultListener);
}

public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
Expand All @@ -369,12 +372,12 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener);
accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(singleSentenceResultListener).onFailure(exception);
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
Mockito.verify(similarityResultListener).onFailure(exception);
Mockito.verifyNoMoreInteractions(similarityResultListener);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -771,10 +771,10 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() {
.modelId(MODEL_ID)
.k(K)
.build();
List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
doAnswer(invocation -> {
ActionListener<List<Float>> listener = invocation.getArgument(1);
ActionListener<List<Number>> listener = invocation.getArgument(1);
listener.onResponse(expectedVector);
return null;
}).when(mlCommonsClientAccessor)
Expand Down Expand Up @@ -810,10 +810,10 @@ public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSe
.modelId(MODEL_ID)
.k(K)
.build();
List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
List<Number> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
doAnswer(invocation -> {
ActionListener<List<Float>> listener = invocation.getArgument(1);
ActionListener<List<Number>> listener = invocation.getArgument(1);
listener.onResponse(expectedVector);
return null;
}).when(mlCommonsClientAccessor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ protected float[] runInference(final String modelId, final String queryText) {
List<Object> output = (List<Object>) result.get("output");
assertEquals(1, output.size());
Map<String, Object> map = (Map<String, Object>) output.get(0);
List<Float> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
List<Number> data = ((List<Double>) map.get("data")).stream().map(Double::floatValue).collect(Collectors.toList());
return vectorAsListToArray(data);
}

Expand Down
Loading