Skip to content

Commit

Permalink
Fix conflicts after rebase main
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Feb 14, 2025
1 parent 6402127 commit 264abc4
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Set neural-search plugin 3.0.0 baseline JDK version to JDK-2 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
- Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007))
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,13 @@ private void retryableInferenceSentencesWithVectorResult(
private void retryableInferenceSimilarityWithVectorResult(
final SimilarityInferenceRequest inferenceRequest,
final int retryTime,
final ActionListener<List<Number>> listener
final ActionListener<List<Float>> listener
) {
MLInput mlInput = createMLTextPairsInput(inferenceRequest.getQueryText(), inferenceRequest.getInputTexts());
mlClient.predict(inferenceRequest.getModelId(), mlInput, ActionListener.wrap(mlOutput -> {
final List<Number> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
final List<Float> scores = buildVectorFromResponse(mlOutput).stream()
.map(v -> v.getFirst().floatValue())
.collect(Collectors.toList());
listener.onResponse(scores);
},
e -> RetryUtil.handleRetryOrFailure(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,17 @@ public void testInferenceSimilarity_whenNodeNotConnectedException_ThenRetry() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener);
accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener);

// Verify client.predict is called 4 times (1 initial + 3 retries)
Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

// Verify failure is propagated to the listener after all retries
Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException);
Mockito.verify(similarityResultListener).onFailure(nodeNodeConnectedException);

// Ensure no additional interactions with the listener
Mockito.verifyNoMoreInteractions(singleSentenceResultListener);
Mockito.verifyNoMoreInteractions(similarityResultListener);
}

public void testInferenceSentences_whenExceptionFromMLClient_thenRetry_thenFailure() {
Expand Down Expand Up @@ -356,7 +356,7 @@ public void testInferenceSimilarity_whenValidInput_thenSuccess() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener);
accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Expand All @@ -372,7 +372,7 @@ public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() {
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));

accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, singleSentenceResultListener);
accessor.inferenceSimilarity(TestCommonConstants.SIMILARITY_INFERENCE_REQUEST, similarityResultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Expand Down

0 comments on commit 264abc4

Please sign in to comment.