From 869d6868366a667a4f8900b633c0c825aa2d8e5d Mon Sep 17 00:00:00 2001 From: Will Hwang Date: Mon, 24 Feb 2025 00:51:00 -0800 Subject: [PATCH] refactor with added integration test Signed-off-by: Will Hwang --- .../processor/InferenceProcessor.java | 142 +++++++------- .../processor/SparseEncodingProcessor.java | 2 +- .../processor/TextEmbeddingProcessor.java | 6 +- .../TextEmbeddingProcessorFactory.java | 12 +- ....java => SelectiveInferenceProcessor.java} | 37 ++-- ...a => SelectiveTextEmbeddingProcessor.java} | 67 ++++--- .../processor/util/ProcessorUtils.java | 73 ++++--- .../TextEmbeddingProcessorTests.java | 6 +- .../SelectiveTextEmbeddingProcessorIT.java | 183 ++++++++++++++++++ ...SelectiveTextEmbeddingProcessorTests.java} | 168 ++++++++-------- .../util/ProcessorUtilsTests.java | 11 ++ ...thNestedFieldsMappingWithSkipExisting.json | 20 ++ ...PipelineConfigurationWithSkipExisting.json | 23 +++ src/test/resources/processor/update_doc1.json | 25 +++ src/test/resources/processor/update_doc2.json | 23 +++ src/test/resources/processor/update_doc3.json | 23 +++ src/test/resources/processor/update_doc4.json | 20 ++ src/test/resources/processor/update_doc5.json | 24 +++ .../neuralsearch/BaseNeuralSearchIT.java | 31 ++- 19 files changed, 658 insertions(+), 238 deletions(-) rename src/main/java/org/opensearch/neuralsearch/processor/optimization/{OptimizedInferenceProcessor.java => SelectiveInferenceProcessor.java} (87%) rename src/main/java/org/opensearch/neuralsearch/processor/optimization/{OptimizedTextEmbeddingProcessor.java => SelectiveTextEmbeddingProcessor.java} (70%) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/optimization/SelectiveTextEmbeddingProcessorIT.java rename src/test/java/org/opensearch/neuralsearch/processor/optimization/{OptimizedTextEmbeddingProcessorTests.java => SelectiveTextEmbeddingProcessorTests.java} (92%) create mode 100644 src/test/resources/processor/PipelineConfigurationWithNestedFieldsMappingWithSkipExisting.json create mode 100644 src/test/resources/processor/PipelineConfigurationWithSkipExisting.json create mode 100644 src/test/resources/processor/update_doc1.json create mode 100644 src/test/resources/processor/update_doc2.json create mode 100644 src/test/resources/processor/update_doc3.json create mode 100644 src/test/resources/processor/update_doc4.json create mode 100644 src/test/resources/processor/update_doc5.json diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index 27d99a80c..a60cc3392 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -10,6 +10,7 @@ import java.util.Comparator; import java.util.HashMap; import java.util.Iterator; +import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.ListIterator; @@ -55,31 +56,34 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor { public static final String MODEL_ID_FIELD = "model_id"; public static final String FIELD_MAP_FIELD = "field_map"; private static final BiFunction REMAPPING_FUNCTION = (v1, v2) -> { - if (v1 instanceof List && v2 instanceof List) { - List v1List = new ArrayList<>((List) v1); - List v2List = (List) v2; - - Iterator iterator = v2List.iterator(); - for (int i = 0; i < v1List.size() && iterator.hasNext(); i++) { - if (v1List.get(i) == null) { - v1List.set(i, 
iterator.next()); + /** + * REMAPPING_FUNCTION is invoked when merge is called on a Map and takes two arguments: + * + * v1: The current (existing) value in the map. + * v2: The new value being merged in. + * + * The function returns the value that should be associated with the key. + * + * For Collection values (currently always of type List), REMAPPING_FUNCTION merges v2 into v1 by inserting the values from v2 at + * the indices of v1 that are marked null. + * For Map values, REMAPPING_FUNCTION merges v2 into v1 by putting all entries of v2 into v1. + * */ + + if (v1 instanceof Collection && v2 instanceof Collection) { + Iterator iterator = ((Collection) v2).iterator(); + for (int i = 0; i < ((Collection) v1).size(); i++) { + if (((List) v1).get(i) == null) { + ((List) v1).set(i, iterator.next()); } } - return v1List; - } - - if (v1 instanceof Map && v2 instanceof Map) { - Map v1Map = new LinkedHashMap<>((Map) v1); - Map v2Map = (Map) v2; - - for (Map.Entry entry : v2Map.entrySet()) { - if (entry.getKey() instanceof String && !v1Map.containsKey(entry.getKey())) { - v1Map.put((String) entry.getKey(), entry.getValue()); - } - } - return v1Map; + assert iterator.hasNext() == false; + return v1; + } else if (v1 instanceof Map && v2 instanceof Map) { + ((Map) v1).putAll((Map) v2); + return v1; + } else { + return v2; } - return v2 != null ? v2 : v1; }; @@ -149,8 +153,9 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception { /** * This method will be invoked by PipelineService to make async inference and then delegate the handler to * process the inference response or failure. + * * @param ingestDocument {@link IngestDocument} which is the document passed to processor. - * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. + * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. */ @Override public void execute(IngestDocument ingestDocument, BiConsumer handler) { @@ -180,9 +185,10 @@ void preprocessIngestDocument(IngestDocument ingestDocument) { /** * This is the function which does actual inference work for batchExecute interface. + * + * @param inferenceList a list of String for inference. - * @param handler a callback handler to handle inference results which is a list of objects. - * @param onException an exception callback to handle exception. + * @param handler a callback handler to handle inference results which is a list of objects. + * @param onException an exception callback to handle exception.
*/ protected abstract void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException); @@ -215,7 +221,8 @@ public void subBatchExecute(List ingestDocumentWrappers, setVectorFieldsToDocument( dataForInference.getIngestDocumentWrapper().getIngestDocument(), dataForInference.getProcessMap(), - inferenceResults + inferenceResults, + false ); } handler.accept(ingestDocumentWrappers); @@ -393,6 +400,7 @@ private Object normalizeSourceValue(Object value) { /** * Process the nested key, such as "a.b.c" to "a", "b.c" + * @param nestedFieldMapEntry * @return A pair of the original key and the target key */ @@ -429,39 +437,39 @@ protected void setVectorFieldsToDocument( IngestDocument ingestDocument, Map processorMap, List results, - boolean update + boolean partialUpdate ) { Objects.requireNonNull(results, "embedding failed, inference returns null result!"); log.debug("Model inference result fetched, starting build vector output!"); - Map nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata()); - if (update) { + // when partialUpdate is false, this is a full update: every vector embedding in nlpResult + // can be written directly to ingestDocument + if (partialUpdate == false) { + Map nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata()); + nlpResult.forEach(ingestDocument::setFieldValue); + } else { + // when partialUpdate is true, some embeddings were already copied over from the existing document, so embeddings from + // nlpResult are only written to the positions marked as 'null' + Map nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata()); for (Map.Entry nlpEntry : nlpResult.entrySet()) { String key = nlpEntry.getKey(); Object target = ingestDocument.getSourceAndMetadata().get(key); Object nlpValues = nlpEntry.getValue(); - if (target instanceof List targetList && nlpValues instanceof List nlpValueList) { - List list = new ArrayList<>(targetList); - ListIterator iterator = list.listIterator(); - ListIterator nlpIterator = (ListIterator) nlpValueList.listIterator(); - while (iterator.hasNext() && nlpIterator.hasNext()) { + if (target instanceof List && nlpValues instanceof List) { + ListIterator iterator = ((List) target).listIterator(); + ListIterator nlpIterator = ((List) nlpValues).listIterator(); + while (iterator.hasNext()) { if (iterator.next() == null) { iterator.set(nlpIterator.next()); } } - ingestDocument.setFieldValue(key, list); + ingestDocument.setFieldValue(key, target); } else { ingestDocument.setFieldValue(key, nlpValues); } } - } else { - nlpResult.forEach(ingestDocument::setFieldValue); } } - protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List results) { - setVectorFieldsToDocument(ingestDocument, processorMap, results, false); - } - @SuppressWarnings({ "unchecked" }) @VisibleForTesting Map buildNLPResult(Map processorMap, List results, Map sourceAndMetadataMap) { @@ -560,8 +568,7 @@ private void processMapEntryValue( ) { // build nlp output for object in sourceValue which is map type Iterator> iterator = sourceAndMetadataMapValueInList.iterator(); - IndexWrapper listIndexWrapper = new IndexWrapper(0); - for (int i = 0; i < sourceAndMetadataMapValueInList.size(); i++) { + IntStream.range(0, sourceAndMetadataMapValueInList.size()).forEach(index -> { Map nestedElement = iterator.next(); putNLPResultToSingleSourceMapInList( inputNestedMapEntryKey, @@ -569,9 +576,9 @@ private void
processMapEntryValue( results, indexWrapper, nestedElement, - listIndexWrapper + index ); - } + }); } /** @@ -591,7 +598,7 @@ private void putNLPResultToSingleSourceMapInList( List results, IndexWrapper indexWrapper, Map sourceAndMetadataMap, - IndexWrapper listIndexWrapper + int nestedIndex ) { if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; if (sourceValue instanceof Map) { @@ -604,15 +611,12 @@ private void putNLPResultToSingleSourceMapInList( results, indexWrapper, sourceMap, - listIndexWrapper + nestedIndex ); } } else { if (sourceValue instanceof List) { - if (sourceAndMetadataMap.containsKey(processorKey)) { - return; - } - if (((List) sourceValue).get(listIndexWrapper.index++) != null) { + if (((List) sourceValue).get(nestedIndex) != null) { sourceAndMetadataMap.merge(processorKey, results.get(indexWrapper.index++), REMAPPING_FUNCTION); } } @@ -632,37 +636,35 @@ private Map getSourceMapBySourceAndMetadataMap(String processorK private List> buildNLPResultForListType(List sourceValue, List results, IndexWrapper indexWrapper) { List> keyToResult = new ArrayList<>(); - IntStream.range(0, sourceValue.size()) - .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); + IntStream.range(0, sourceValue.size()).forEachOrdered(x -> { + if (sourceValue.get(x) != null) { // only add to keyToResult when sourceValue.get(x) is non-null + keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))); + } + }); return keyToResult; } + /** + * This method invokes an inference call through mlCommonsClientAccessor and populates the retrieved embeddings into ingestDocument + * + * @param ingestDocument ingestDocument to populate embeddings to + * @param ProcessMap map indicating the path in ingestDocument to populate embeddings + * @param inferenceList list of texts to run model inference on + * @param handler {@link BiConsumer} callback invoked with the ingestDocument on success, or with the exception on failure + * @param partialUpdate set to false if ingestDocument is ingested for the first time, set to true if ingestDocument is an update operation + * + */ protected void makeInferenceCall( IngestDocument ingestDocument, Map ProcessMap, List inferenceList, BiConsumer handler, - boolean update - ) { - mlCommonsClientAccessor.inferenceSentences( - TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(), - ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors, update); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); }) - ); - } - - protected void makeInferenceCall( - IngestDocument ingestDocument, - Map ProcessMap, - List inferenceList, - BiConsumer handler + boolean partialUpdate ) { mlCommonsClientAccessor.inferenceSentences( TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(), ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); + setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors, partialUpdate); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); }) ); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 438f9bac5..095cf4d5b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@
-66,7 +66,7 @@ public void doExecute( .stream() .map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)) .toList(); - setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors); + setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors, false); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); }) ); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index aa24a7c92..233403cdc 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -26,8 +26,8 @@ public final class TextEmbeddingProcessor extends InferenceProcessor { public static final String TYPE = "text_embedding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; - public static final String IGNORE_EXISTING = "ignore_existing"; - public static final boolean DEFAULT_IGNORE_EXISTING = Boolean.FALSE; + public static final String SKIP_EXISTING = "skip_existing"; + public static final boolean DEFAULT_SKIP_EXISTING = Boolean.FALSE; public TextEmbeddingProcessor( String tag, @@ -52,7 +52,7 @@ public void doExecute( mlCommonsClientAccessor.inferenceSentences( TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(), ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); + setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors, false); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); }) ); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java index d1355c843..e396f5f12 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java @@ -7,8 +7,8 @@ import static org.opensearch.ingest.ConfigurationUtils.readBooleanProperty; import static org.opensearch.ingest.ConfigurationUtils.readMap; import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; -import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.DEFAULT_IGNORE_EXISTING; -import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.IGNORE_EXISTING; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.SKIP_EXISTING; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.DEFAULT_SKIP_EXISTING; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; @@ -20,7 +20,7 @@ import org.opensearch.ingest.AbstractBatchingProcessor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; -import org.opensearch.neuralsearch.processor.optimization.OptimizedTextEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.optimization.SelectiveTextEmbeddingProcessor; import org.opensearch.transport.client.OpenSearchClient; /** @@ -53,9 +53,9 @@ public TextEmbeddingProcessorFactory( protected AbstractBatchingProcessor 
newProcessor(String tag, String description, int batchSize, Map config) { String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); Map fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); - boolean ignoreExisting = readBooleanProperty(TYPE, tag, config, IGNORE_EXISTING, DEFAULT_IGNORE_EXISTING); - if (ignoreExisting == true) { - return new OptimizedTextEmbeddingProcessor( + boolean skipExisting = readBooleanProperty(TYPE, tag, config, SKIP_EXISTING, DEFAULT_SKIP_EXISTING); + if (skipExisting == true) { + return new SelectiveTextEmbeddingProcessor( tag, description, batchSize, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/optimization/OptimizedInferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/optimization/SelectiveInferenceProcessor.java similarity index 87% rename from src/main/java/org/opensearch/neuralsearch/processor/optimization/OptimizedInferenceProcessor.java rename to src/main/java/org/opensearch/neuralsearch/processor/optimization/SelectiveInferenceProcessor.java index 9c1565d6d..bbb08f4fd 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/optimization/OptimizedInferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/optimization/SelectiveInferenceProcessor.java @@ -19,13 +19,13 @@ import java.util.Optional; /** - * The abstract class for optimized text processing use cases. On update operation, the optimized inference processor will attempt to - * optimize inference calls by copying over existing embeddings for the same text + * The abstract class for selective text processing use cases. On update operation, the selective inference processor will attempt to + * skip inference calls by copying over existing embeddings for the same text */ @Log4j2 -public abstract class OptimizedInferenceProcessor extends InferenceProcessor { - public OptimizedInferenceProcessor( +public abstract class SelectiveInferenceProcessor extends InferenceProcessor { + public SelectiveInferenceProcessor( String tag, String description, int batchSize, @@ -37,7 +37,18 @@ public OptimizedInferenceProcessor( Environment environment, ClusterService clusterService ) { - super(tag, description, batchSize, type, listTypeNestedMapKey, modelId, fieldMap, clientAccessor, environment, clusterService); + super( + tag, + description, + batchSize, + type, + listTypeNestedMapKey, + modelId, + ProcessorDocumentUtils.unflattenJson(fieldMap), + clientAccessor, + environment, + clusterService + ); } public abstract Object processValue( @@ -100,9 +111,7 @@ protected Map filterProcessMap( currentPath, currLevel ); - if (!filteredInnerMap.isEmpty()) { - filteredProcessMap.put(key, filteredInnerMap); - } + filteredProcessMap.put(key, filteredInnerMap.isEmpty() ? 
null : filteredInnerMap); } else if (value instanceof List) { List processedList = processListValue( currentPath, @@ -116,9 +125,7 @@ protected Map filterProcessMap( } } else { Object processedValue = processValue(currentPath, value, currLevel, sourceAndMetadataMap, existingSourceAndMetadataMap, -1); - if (processedValue != null) { - filteredProcessMap.put(key, processedValue); - } + filteredProcessMap.put(key, processedValue); } } return filteredProcessMap; @@ -141,9 +148,9 @@ protected List processListValue( Map sourceAndMetadataMap, Map existingSourceAndMetadataMap ) { - String textKey = ProcessorUtils.findKeyFromFromValue(ProcessorDocumentUtils.unflattenJson(fieldMap), currentPath, level); + String textKey = ProcessorUtils.findKeyFromFromValue(fieldMap, currentPath, level); if (textKey == null) { - return new ArrayList<>(processList); + return (List) processList; } String fullTextKey = ProcessorUtils.computeFullTextKey(currentPath, textKey, level); @@ -194,9 +201,7 @@ private List processMapValuesInList( existingSourceAndMetadataMap, i ); - if (processedItem != null) { - filteredList.add(processedItem); - } + filteredList.add(processedItem); } return filteredList; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/optimization/OptimizedTextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/optimization/SelectiveTextEmbeddingProcessor.java similarity index 70% rename from src/main/java/org/opensearch/neuralsearch/processor/optimization/OptimizedTextEmbeddingProcessor.java rename to src/main/java/org/opensearch/neuralsearch/processor/optimization/SelectiveTextEmbeddingProcessor.java index 8e7df9e77..e20438dc5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/optimization/OptimizedTextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/optimization/SelectiveTextEmbeddingProcessor.java @@ -12,23 +12,25 @@ import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.TextInferenceRequest; import org.opensearch.neuralsearch.processor.util.ProcessorUtils; -import org.opensearch.neuralsearch.util.ProcessorDocumentUtils; import org.opensearch.transport.client.OpenSearchClient; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.stream.Collectors; /** - * This processor is used for optimizing text embedding processing. This processor will skip redundant inference calls by comparing existing document and new document. - * If inference texts stay the same, the OptimizedTextEmbeddingProcessor will copy over existing embeddings and filter out the inference text from process map and inference list + * This processor is used for selective text embedding processing. This processor will skip redundant inference calls by comparing existing document and new document. 
+ * If inference texts stay the same, the SelectiveTextEmbeddingProcessor will copy over existing embeddings and mark null inference text from process map */ @Log4j2 -public class OptimizedTextEmbeddingProcessor extends OptimizedInferenceProcessor { +public class SelectiveTextEmbeddingProcessor extends SelectiveInferenceProcessor { public static final String TYPE = "text_embedding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; private static final String INDEX_FIELD = "_index"; @@ -36,7 +38,7 @@ public class OptimizedTextEmbeddingProcessor extends OptimizedInferenceProcessor private final OpenSearchClient openSearchClient; - public OptimizedTextEmbeddingProcessor( + public SelectiveTextEmbeddingProcessor( String tag, String description, int batchSize, @@ -63,10 +65,12 @@ public void doExecute( openSearchClient.execute(GetAction.INSTANCE, new GetRequest(index, id), ActionListener.wrap(response -> { final Map document = response.getSourceAsMap(); if (document == null || document.isEmpty()) { - makeInferenceCall(ingestDocument, ProcessMap, inferenceList, handler); + makeInferenceCall(ingestDocument, ProcessMap, inferenceList, handler, false); } else { Map filteredProcessMap = filterProcessMap(document, ingestDocument.getSourceAndMetadata(), ProcessMap); - List filteredInferenceList = createInferenceList(filteredProcessMap); + List filteredInferenceList = createInferenceList(filteredProcessMap).stream() + .filter(Objects::nonNull) + .collect(Collectors.toList()); if (!filteredInferenceList.isEmpty()) { makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler, true); } else { @@ -77,8 +81,11 @@ public void doExecute( } @Override - public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { - // TODO: invoke Opensearch Client's Multi-Get request to enable batch execution + protected void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { + mlCommonsClientAccessor.inferenceSentences( + TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(), + ActionListener.wrap(handler::accept, onException) + ); } /** @@ -102,7 +109,7 @@ public Object processValue( Map existingSourceAndMetadataMap, int index ) { - String textKey = ProcessorUtils.findKeyFromFromValue(ProcessorDocumentUtils.unflattenJson(fieldMap), currentPath, level); + String textKey = ProcessorUtils.findKeyFromFromValue(fieldMap, currentPath, level); if (textKey == null) { return processValue; } @@ -126,9 +133,9 @@ public Object processValue( * - In the given list of inference texts, inference texts in the same index is the same between existingSourceAndMetadataMap and sourceAndMetadataMap * - existing existingSourceAndMetadataMap has embeddings for corresponding inference texts * @param processList list of inference texts to check for - * @param sourceList list of inference texts in sourceAndMetadataMap - * @param existingList list of inference texts in existingSourceAndMetadataMap - * @param embeddingList list of embeddings for inference texts + * @param sourceListOptional Optional list of inference texts in sourceAndMetadataMap + * @param existingListOptional Optional list of inference texts in existingSourceAndMetadataMap + * @param embeddingListOptional Optional list of embeddings for inference texts * @param sourceAndMetadataMap SourceAndMetadataMap of ingestDocument Document * @param existingSourceAndMetadataMap SourceAndMetadataMap of existing Document * @param fullEmbeddingKey path to embedding key @@ 
-138,28 +145,36 @@ public Object processValue( @Override public List processValues( List processList, - Optional sourceList, - Optional existingList, - Optional embeddingList, + Optional sourceListOptional, + Optional existingListOptional, + Optional embeddingListOptional, Map sourceAndMetadataMap, Map existingSourceAndMetadataMap, String fullEmbeddingKey ) { List filteredList = new ArrayList<>(); List updatedEmbeddings = new ArrayList<>(); - if (sourceList.isPresent() && existingList.isPresent()) { - int min = Math.min(((List) sourceList.get()).size(), ((List) existingList.get()).size()); - for (int j = 0; j < min; j++) { - if (((List) sourceList.get()).get(j).equals(((List) existingList.get()).get(j))) { - updatedEmbeddings.add(((List) embeddingList.get()).get(j)); - } else { - filteredList.add(processList.get(j)); - updatedEmbeddings.add(null); + if (sourceListOptional.isPresent() && existingListOptional.isPresent()) { + if (sourceListOptional.get() instanceof List + && existingListOptional.get() instanceof List + && embeddingListOptional.get() instanceof List) { + List sourceList = (ArrayList) sourceListOptional.get(); + List existingList = (ArrayList) existingListOptional.get(); + List embeddingList = (ArrayList) embeddingListOptional.get(); + int min = Math.min(sourceList.size(), existingList.size()); + for (int j = 0; j < min; j++) { + if (sourceList.get(j).equals(existingList.get(j))) { + updatedEmbeddings.add((embeddingList).get(j)); + filteredList.add(null); + } else { + filteredList.add(processList.get(j)); + updatedEmbeddings.add(null); + } } + ProcessorUtils.setValueToSource(sourceAndMetadataMap, fullEmbeddingKey, updatedEmbeddings); } - ProcessorUtils.setValueToSource(sourceAndMetadataMap, fullEmbeddingKey, updatedEmbeddings); } else { - return new ArrayList<>(processList); + return (List) processList; } return filteredList; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java index c37dd3eec..4807db6cc 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java @@ -140,21 +140,25 @@ public static void removeTargetFieldFromSource(final Map sourceA } } + public static Optional getValueFromSource(final Map sourceAsMap, final String targetField) { + return getValueFromSource(sourceAsMap, targetField, -1); + } + /** * Returns the mapping associated with a path to a value, otherwise * returns an empty optional when it encounters a dead end. - *
* When the targetField has the form (key[.key]) it will iterate through * the map to see if a mapping exists. + * When a List is encountered during traversal, the index parameter is used to select + * the appropriate element from the list. If the selected element is a map, traversal continues with + * that map using the next key in the path. * * @param sourceAsMap The Source map (a map of maps) to iterate through * @param targetField The path to take to get the desired mapping + * @param index the index to use when a list is encountered during traversal; if list processing is not needed, + * -1 is passed in * @return A possible result within an optional */ - public static Optional getValueFromSource(final Map sourceAsMap, final String targetField) { - return getValueFromSource(sourceAsMap, targetField, -1); - } - public static Optional getValueFromSource(final Map sourceAsMap, final String targetField, int index) { String[] keys = targetField.split("\\."); Optional currentValue = Optional.ofNullable(sourceAsMap); @@ -164,13 +168,13 @@ public static Optional getValueFromSource(final Map sour if (value instanceof List && index != -1) { Object listValue = ((List) value).get(index); if (listValue instanceof Map) { - Map currentMap = (Map) listValue; + Map currentMap = (Map) listValue; return Optional.ofNullable(currentMap.get(key)); } } else if (!(value instanceof Map)) { return Optional.empty(); } - Map currentMap = (Map) value; + Map currentMap = (Map) value; return Optional.ofNullable(currentMap.get(key)); }); @@ -182,30 +186,38 @@ public static Optional getValueFromSource(final Map sour return currentValue; } + public static void setValueToSource(Map sourceAsMap, String targetKey, Object targetValue) { + setValueToSource(sourceAsMap, targetKey, targetValue, -1); + } + /** - * Given the path to targetKey in sourceAsMap, sets targetValue in targetKey + * Inserts or updates a value in a nested map structure, with optional support for list traversal. + * This method navigates through the provided sourceAsMap using the dot-delimited key path + * specified by targetKey. Intermediate maps are created as needed. When a List is encountered, + * the provided index is used to select the element from the list. The selected element must be a map to + * continue the traversal. + * Once the final map in the path is reached, the method sets the value for the last key. 
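+ * For example (illustrative): with targetKey "a.b" and index 1, if sourceAsMap.get("a") is a list, the element at index 1 of that + * list is selected and "b" is set on it; if the value already stored at "b" is itself a list, its element at index 1 is replaced instead.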
* * @param sourceAsMap The Source map (a map of maps) to iterate through - * @param targetKey The path to key to insert the desired targetValue - * @param targetValue The value to insert to targetKey + * @param targetKey The path to the key at which to insert the desired targetValue + * @param targetValue the value to set at the specified key path + * @param index the index to use when a list is encountered during traversal; if list processing is not needed, + * -1 is passed in */ - public static void setValueToSource(Map sourceAsMap, String targetKey, Object targetValue) { - setValueToSource(sourceAsMap, targetKey, targetValue, -1); - } public static void setValueToSource(Map sourceAsMap, String targetKey, Object targetValue, int index) { if (sourceAsMap == null || targetKey == null) return; String[] keys = targetKey.split("\\."); - Map current = sourceAsMap; + Map current = sourceAsMap; for (int i = 0; i < keys.length - 1; i++) { Object next = current.computeIfAbsent(keys[i], k -> new HashMap<>()); if (next instanceof List list) { if (index < 0 || index >= list.size()) return; - current = (Map) list.get(index); + current = (Map) list.get(index); } else if (next instanceof Map) { - current = (Map) next; + current = (Map) next; } else { throw new IllegalStateException("Unexpected data structure at " + keys[i]); } @@ -214,11 +226,11 @@ public static void setValueToSource(Map sourceAsMap, String targ String lastKey = keys[keys.length - 1]; Object existingValue = current.get(lastKey); - if (existingValue instanceof List existingList) { - if (index >= 0 && index < existingList.size()) { - ((List) existingList).set(index, targetValue); + if (existingValue instanceof List) { + if (index >= 0 && index < ((List) existingValue).size()) { + ((List) existingValue).set(index, targetValue); } else if (index == -1) { - ((List) existingList).add(targetValue); + ((List) existingValue).add(targetValue); } } else { current.put(lastKey, targetValue); } @@ -269,7 +281,7 @@ public static boolean isNumeric(Object value) { * e.g: * path: level1.level2.oldKey * textKey: newKey - * level: 2 + * level: 3 * returns level1.level2.newKey * * @param path path to old key @@ -306,17 +318,20 @@ public static String computeFullTextKey(String path, String textKey, int level) */ public static String findKeyFromFromValue(Map sourceAsMap, String path, int level) { String[] keys = path.split("\\.", level); - Map currentMap = sourceAsMap; + String targetValue = keys[keys.length - 1]; - for (String key : keys) { - if (key.equals(targetValue)) { - break; - } - if (currentMap.containsKey(key)) { - Object value = currentMap.get(key); + Map currentMap = sourceAsMap; + + for (int i = 0; i < keys.length - 1; i++) { + if (currentMap.containsKey(keys[i])) { + Object value = currentMap.get(keys[i]); if (value instanceof Map) { - currentMap = (Map) value; + currentMap = (Map) value; + } else { + return null; } + } else { + return null; } } String lastFoundKey = null; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 259cff24e..9a24e493e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -745,7 +745,7 @@ public void testProcessResponse_successful() throws Exception { Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); List> modelTensorList = createMockVectorResult(); -
processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); + processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList, false); assertEquals(12, ingestDocument.getSourceAndMetadata().size()); } @@ -1046,10 +1046,10 @@ public void test_updateDocument_appendVectorFieldsToDocument_successful() { TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); List> modelTensorList = createMockVectorResult(); - processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); + processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList, false); List> modelTensorList1 = createMockVectorResult(); - processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList1); + processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList1, false); assertEquals(12, ingestDocument.getSourceAndMetadata().size()); assertEquals(2, ((List) ingestDocument.getSourceAndMetadata().get("oriKey6_knn")).size()); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/optimization/SelectiveTextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/optimization/SelectiveTextEmbeddingProcessorIT.java new file mode 100644 index 000000000..329b9dce1 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/optimization/SelectiveTextEmbeddingProcessorIT.java @@ -0,0 +1,183 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.optimization; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.apache.lucene.search.join.ScoreMode; +import org.junit.Before; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +public class SelectiveTextEmbeddingProcessorIT extends BaseNeuralSearchIT { + + private static final String INDEX_NAME = "selective_text_embedding_index"; + + private static final String PIPELINE_NAME = "selective_pipeline"; + protected static final String QUERY_TEXT = "hello"; + protected static final String LEVEL_1_FIELD = "nested_passages"; + protected static final String LEVEL_2_FIELD = "level_2"; + protected static final String LEVEL_3_FIELD_TEXT = "level_3_text"; + protected static final String LEVEL_3_FIELD_CONTAINER = "level_3_container"; + protected static final String LEVEL_3_FIELD_EMBEDDING = "level_3_embedding"; + protected static final String TEXT_FIELD_VALUE_1 = "hello"; + protected static final String TEXT_FIELD_VALUE_2 = "joker"; + protected static final String TEXT_FIELD_VALUE_3 = "def"; + private final String INGEST_DOC1 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc1.json").toURI())); + private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI())); + private final String INGEST_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc4.json").toURI())); + private final String INGEST_DOC5 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc5.json").toURI())); + private final String UPDATE_DOC1 = 
Files.readString(Path.of(classLoader.getResource("processor/update_doc1.json").toURI())); + private final String UPDATE_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/update_doc3.json").toURI())); + private final String UPDATE_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/update_doc4.json").toURI())); + private final String UPDATE_DOC5 = Files.readString(Path.of(classLoader.getResource("processor/update_doc5.json").toURI())); + + public SelectiveTextEmbeddingProcessorIT() throws IOException, URISyntaxException {} + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + } + + public void testTextEmbeddingProcessorWithSkipExisting() throws Exception { + String modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_SKIP_EXISTING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME); + ingestDocument(INDEX_NAME, INGEST_DOC1, "1"); + updateDocument(INDEX_NAME, UPDATE_DOC1, "1"); + assertEquals(1, getDocCount(INDEX_NAME)); + assertEquals(2, getDocById(INDEX_NAME, "1").get("_version")); + } + + public void testNestedFieldMapping_whenDocumentsIngestedAndUpdated_thenSuccessful() throws Exception { + String modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING_WITH_SKIP_EXISTING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME); + ingestDocument(INDEX_NAME, INGEST_DOC3, "3"); + updateDocument(INDEX_NAME, UPDATE_DOC3, "3"); + ingestDocument(INDEX_NAME, INGEST_DOC4, "4"); + updateDocument(INDEX_NAME, UPDATE_DOC4, "4"); + + assertDoc((Map) getDocById(INDEX_NAME, "3").get("_source"), TEXT_FIELD_VALUE_1, Optional.of(TEXT_FIELD_VALUE_3)); + assertDoc((Map) getDocById(INDEX_NAME, "4").get("_source"), TEXT_FIELD_VALUE_2, Optional.empty()); + + NeuralQueryBuilder neuralQueryBuilderQuery = NeuralQueryBuilder.builder() + .fieldName(LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_CONTAINER + "." + LEVEL_3_FIELD_EMBEDDING) + .queryText(QUERY_TEXT) + .modelId(modelId) + .k(10) + .build(); + + QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery( + LEVEL_1_FIELD + "." 
+ LEVEL_2_FIELD, + neuralQueryBuilderQuery, + ScoreMode.Total + ); + QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total); + + Map searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 2); + assertNotNull(searchResponseAsMap); + + Map hits = (Map) searchResponseAsMap.get("hits"); + assertNotNull(hits); + + assertEquals(1.0, hits.get("max_score")); + List> listOfHits = (List>) hits.get("hits"); + assertNotNull(listOfHits); + assertEquals(2, listOfHits.size()); + + Map innerHitDetails = listOfHits.get(0); + assertEquals("3", innerHitDetails.get("_id")); + assertEquals(1.0, innerHitDetails.get("_score")); + + innerHitDetails = listOfHits.get(1); + assertEquals("4", innerHitDetails.get("_id")); + assertTrue((double) innerHitDetails.get("_score") <= 1.0); + } + + public void testNestedFieldMapping_whenDocumentInListIngestedAndUpdated_thenSuccessful() throws Exception { + String modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME); + ingestDocument(INDEX_NAME, INGEST_DOC5, "5"); + updateDocument(INDEX_NAME, UPDATE_DOC5, "5"); + + assertDocWithLevel2AsList((Map) getDocById(INDEX_NAME, "5").get("_source")); + + NeuralQueryBuilder neuralQueryBuilderQuery = NeuralQueryBuilder.builder() + .fieldName(LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_CONTAINER + "." + LEVEL_3_FIELD_EMBEDDING) + .queryText(QUERY_TEXT) + .modelId(modelId) + .k(10) + .build(); + + QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery( + LEVEL_1_FIELD + "." + LEVEL_2_FIELD, + neuralQueryBuilderQuery, + ScoreMode.Total + ); + QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total); + + Map searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 2); + assertNotNull(searchResponseAsMap); + + assertEquals(1, getHitCount(searchResponseAsMap)); + + Map innerHitDetails = getFirstInnerHit(searchResponseAsMap); + assertEquals("5", innerHitDetails.get("_id")); + } + + private String uploadTextEmbeddingModel() throws Exception { + String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())); + return registerModelGroupAndUploadModel(requestBody); + } + + private void assertDoc(Map sourceMap, String textFieldValue, Optional level3ExpectedValue) { + assertNotNull(sourceMap); + assertTrue(sourceMap.containsKey(LEVEL_1_FIELD)); + Map nestedPassages = (Map) sourceMap.get(LEVEL_1_FIELD); + assertTrue(nestedPassages.containsKey(LEVEL_2_FIELD)); + Map level2 = (Map) nestedPassages.get(LEVEL_2_FIELD); + assertEquals(textFieldValue, level2.get(LEVEL_3_FIELD_TEXT)); + Map level3 = (Map) level2.get(LEVEL_3_FIELD_CONTAINER); + List embeddings = (List) level3.get(LEVEL_3_FIELD_EMBEDDING); + assertEquals(768, embeddings.size()); + for (Double embedding : embeddings) { + assertTrue(embedding >= 0.0 && embedding <= 1.0); + } + if (level3ExpectedValue.isPresent()) { + assertEquals(level3ExpectedValue.get(), level3.get("level_4_text_field")); + } + } + + private void assertDocWithLevel2AsList(Map sourceMap) { + assertNotNull(sourceMap); + assertTrue(sourceMap.containsKey(LEVEL_1_FIELD)); + assertTrue(sourceMap.get(LEVEL_1_FIELD) instanceof List); + List> nestedPassages = (List>) sourceMap.get(LEVEL_1_FIELD); + nestedPassages.forEach(nestedPassage -> { 
+ assertTrue(nestedPassage.containsKey(LEVEL_2_FIELD)); + Map level2 = (Map) nestedPassage.get(LEVEL_2_FIELD); + Map level3 = (Map) level2.get(LEVEL_3_FIELD_CONTAINER); + List embeddings = (List) level3.get(LEVEL_3_FIELD_EMBEDDING); + assertEquals(768, embeddings.size()); + for (Double embedding : embeddings) { + assertTrue(embedding >= 0.0 && embedding <= 1.0); + } + }); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/optimization/OptimizedTextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/optimization/SelectiveTextEmbeddingProcessorTests.java similarity index 92% rename from src/test/java/org/opensearch/neuralsearch/processor/optimization/OptimizedTextEmbeddingProcessorTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/optimization/SelectiveTextEmbeddingProcessorTests.java index 89849866f..90a153067 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/optimization/OptimizedTextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/optimization/SelectiveTextEmbeddingProcessorTests.java @@ -57,7 +57,7 @@ import lombok.SneakyThrows; import org.opensearch.transport.client.OpenSearchClient; -public class OptimizedTextEmbeddingProcessorTests extends InferenceProcessorTestCase { +public class SelectiveTextEmbeddingProcessorTests extends InferenceProcessorTestCase { protected static final String PARENT_FIELD = "parent"; protected static final String CHILD_FIELD_LEVEL_1 = "child_level1"; @@ -100,99 +100,99 @@ public void setup() { } @SneakyThrows - private OptimizedTextEmbeddingProcessor createInstanceWithLevel2MapConfig() { + private SelectiveTextEmbeddingProcessor createInstanceWithLevel2MapConfig() { Map registry = new HashMap<>(); Map config = new HashMap<>(); - config.put(OptimizedTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(SelectiveTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put( - OptimizedTextEmbeddingProcessor.FIELD_MAP_FIELD, + SelectiveTextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", ImmutableMap.of("test1", "test1_knn"), "key2", ImmutableMap.of("test3", CHILD_LEVEL_2_KNN_FIELD)) ); - config.put(TextEmbeddingProcessor.IGNORE_EXISTING, true); - return (OptimizedTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + config.put(TextEmbeddingProcessor.SKIP_EXISTING, true); + return (SelectiveTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @SneakyThrows - private OptimizedTextEmbeddingProcessor createInstanceWithLevel1MapConfig() { + private SelectiveTextEmbeddingProcessor createInstanceWithLevel1MapConfig() { Map registry = new HashMap<>(); Map config = new HashMap<>(); - config.put(OptimizedTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put(OptimizedTextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1_knn", "key2", "key2_knn")); - config.put(TextEmbeddingProcessor.IGNORE_EXISTING, true); - return (OptimizedTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + config.put(SelectiveTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(SelectiveTextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1_knn", "key2", "key2_knn")); + config.put(TextEmbeddingProcessor.SKIP_EXISTING, true); + return (SelectiveTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, 
PROCESSOR_TAG, DESCRIPTION, config); } @SneakyThrows - private OptimizedTextEmbeddingProcessor createInstanceWithNestedLevelConfig() { + private SelectiveTextEmbeddingProcessor createInstanceWithNestedLevelConfig() { Map registry = new HashMap<>(); Map config = new HashMap<>(); - config.put(OptimizedTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put(TextEmbeddingProcessor.IGNORE_EXISTING, true); + config.put(SelectiveTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextEmbeddingProcessor.SKIP_EXISTING, true); config.put( - OptimizedTextEmbeddingProcessor.FIELD_MAP_FIELD, + SelectiveTextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of( String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_FIELD_LEVEL_2)), CHILD_LEVEL_2_KNN_FIELD ) ); - return (OptimizedTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + return (SelectiveTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @SneakyThrows - private OptimizedTextEmbeddingProcessor createInstanceWithNestedMappingsConfig() { + private SelectiveTextEmbeddingProcessor createInstanceWithNestedMappingsConfig() { Map registry = new HashMap<>(); Map config = new HashMap<>(); - config.put(OptimizedTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put(TextEmbeddingProcessor.IGNORE_EXISTING, true); + config.put(SelectiveTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextEmbeddingProcessor.SKIP_EXISTING, true); config.put( - OptimizedTextEmbeddingProcessor.FIELD_MAP_FIELD, + SelectiveTextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of( String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_FIELD_LEVEL_2)), CHILD_LEVEL_2_KNN_FIELD ) ); - return (OptimizedTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + return (SelectiveTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @SneakyThrows - private OptimizedTextEmbeddingProcessor createInstanceWithNestedMapConfig() { + private SelectiveTextEmbeddingProcessor createInstanceWithNestedMapConfig() { Map registry = new HashMap<>(); Map config = new HashMap<>(); - config.put(OptimizedTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put(TextEmbeddingProcessor.IGNORE_EXISTING, true); + config.put(SelectiveTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextEmbeddingProcessor.SKIP_EXISTING, true); config.put( - OptimizedTextEmbeddingProcessor.FIELD_MAP_FIELD, + SelectiveTextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of( String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1)), Map.of(CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_KNN_FIELD) ) ); - return (OptimizedTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + return (SelectiveTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @SneakyThrows - private OptimizedTextEmbeddingProcessor createInstanceWithNestedSourceAndDestinationConfig() { + private SelectiveTextEmbeddingProcessor createInstanceWithNestedSourceAndDestinationConfig() { Map registry = new HashMap<>(); Map config = new HashMap<>(); - config.put(OptimizedTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put(TextEmbeddingProcessor.IGNORE_EXISTING, true); + 
config.put(SelectiveTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextEmbeddingProcessor.SKIP_EXISTING, true); config.put( - OptimizedTextEmbeddingProcessor.FIELD_MAP_FIELD, + SelectiveTextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of( String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_1_TEXT_FIELD)), CHILD_FIELD_LEVEL_2 + "." + CHILD_LEVEL_2_KNN_FIELD ) ); - return (OptimizedTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + return (SelectiveTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @SneakyThrows - private OptimizedTextEmbeddingProcessor createInstanceWithNestedSourceAndDestinationMapConfig() { + private SelectiveTextEmbeddingProcessor createInstanceWithNestedSourceAndDestinationMapConfig() { Map registry = new HashMap<>(); Map config = buildObjMap( - Pair.of(OptimizedTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"), + Pair.of(SelectiveTextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"), Pair.of( - OptimizedTextEmbeddingProcessor.FIELD_MAP_FIELD, + SelectiveTextEmbeddingProcessor.FIELD_MAP_FIELD, buildObjMap( Pair.of( PARENT_FIELD, @@ -203,9 +203,9 @@ private OptimizedTextEmbeddingProcessor createInstanceWithNestedSourceAndDestina ) ) ), - Pair.of(TextEmbeddingProcessor.IGNORE_EXISTING, true) + Pair.of(TextEmbeddingProcessor.SKIP_EXISTING, true) ); - return (OptimizedTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + return (SelectiveTextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } public void testExecute_when_initial_ingest_successful() throws IOException { @@ -217,7 +217,7 @@ public void testExecute_when_initial_ingest_successful() throws IOException { List inferenceList = Arrays.asList("value1", "value2"); TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - OptimizedInferenceProcessor processor = createInstanceWithLevel1MapConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); GetResponse response = mockEmptyGetResponse(); doAnswer(invocation -> { @@ -256,9 +256,9 @@ public void testExecute_whenGetDocumentThrowsException_throwRuntimeException() { ); Map config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put(TextEmbeddingProcessor.IGNORE_EXISTING, true); + config.put(TextEmbeddingProcessor.SKIP_EXISTING, true); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); - OptimizedTextEmbeddingProcessor processor = (OptimizedTextEmbeddingProcessor) textEmbeddingProcessorFactory.create( + SelectiveTextEmbeddingProcessor processor = (SelectiveTextEmbeddingProcessor) textEmbeddingProcessorFactory.create( registry, PROCESSOR_TAG, DESCRIPTION, @@ -279,13 +279,13 @@ public void testExecute_with_no_update_successful() { ingestSourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); List inferenceList = Arrays.asList("value1", "value2"); - TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + TextInferenceRequest ingestRequest = 
TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); // no change IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); - OptimizedInferenceProcessor processor = createInstanceWithLevel1MapConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); mockUpdateDocument(ingestDocument); - mockVectorCreation(2, 2, 0.0f, 1.0f); + mockVectorCreation(ingestRequest, null); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); // insert document processor.execute(updateDocument, handler); // update document @@ -293,7 +293,7 @@ public void testExecute_with_no_update_successful() { verify(handler, times(2)).accept(any(IngestDocument.class), isNull()); verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); verify(mlCommonsClientAccessor, times(1)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); - assertEquals(request.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); + assertEquals(ingestRequest.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); List key1insertVectors = (List) ingestDocument.getSourceAndMetadata().get("key1_knn"); List key1updateVectors = (List) updateDocument.getSourceAndMetadata().get("key1_knn"); List key2insertVectors = (List) ingestDocument.getSourceAndMetadata().get("key2_knn"); @@ -321,9 +321,9 @@ public void testExecute_with_updated_field_successful() { .build(); IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); - OptimizedInferenceProcessor processor = createInstanceWithLevel1MapConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); mockUpdateDocument(ingestDocument); - mockVectorCreation(2, 2, 0.0f, 1.0f); + mockVectorCreation(ingestRequest, updateRequest); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); // insert @@ -353,15 +353,15 @@ public void testExecute_withListTypeInput_no_update_successful() { ingestSourceAndMetadata.put("key1", list1); ingestSourceAndMetadata.put("key2", list2); List inferenceList = Arrays.asList("test1", "test2", "test3", "test4", "test5", "test6"); - TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); - OptimizedInferenceProcessor processor = createInstanceWithLevel1MapConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); mockUpdateDocument(ingestDocument); - mockVectorCreation(6, 2, 0.0f, 1.0f); + mockVectorCreation(ingestRequest, null); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -370,7 +370,7 @@ public void testExecute_withListTypeInput_no_update_successful() { verify(handler, times(2)).accept(any(IngestDocument.class), isNull()); verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); 
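// only one inference call is expected below: the unchanged update copied the existing embeddings instead of re-running inference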
verify(mlCommonsClientAccessor, times(1)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); - assertEquals(request.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); + assertEquals(ingestRequest.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); List key1insertVectors = (List) ingestDocument.getSourceAndMetadata().get("key1_knn"); List key1updateVectors = (List) updateDocument.getSourceAndMetadata().get("key1_knn"); List key2insertVectors = (List) ingestDocument.getSourceAndMetadata().get("key2_knn"); @@ -401,9 +401,9 @@ public void testExecute_withListTypeInput_with_update_successful() { .inputTexts(filteredInferenceList) .build(); - OptimizedInferenceProcessor processor = createInstanceWithLevel1MapConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); mockUpdateDocument(ingestDocument); - mockVectorCreation(6, 2, 0.0f, 1.0f); + mockVectorCreation(ingestRequest, updateRequest); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -442,9 +442,9 @@ public void testExecute_withNestedListTypeInput_no_update_successful() { Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); - OptimizedInferenceProcessor processor = createInstanceWithLevel2MapConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); mockUpdateDocument(ingestDocument); - mockVectorCreation(6, 2, 0.0f, 1.0f); + mockVectorCreation(ingestRequest, null); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -489,9 +489,9 @@ public void testExecute_withNestedListTypeInput_with_update_successful() { .inputTexts(filteredInferenceList) .build(); - OptimizedInferenceProcessor processor = createInstanceWithLevel2MapConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); mockUpdateDocument(ingestDocument); - mockVectorCreation(6, 2, 0.0f, 1.0f); + mockVectorCreation(ingestRequest, updateRequest); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -527,12 +527,12 @@ public void testExecute_withMapTypeInput_no_update_successful() { List inferenceList = Arrays.asList("test2", "test4"); TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); - OptimizedTextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); mockUpdateDocument(ingestDocument); - mockVectorCreation(2, 6, 0.0f, 0.1f); + mockVectorCreation(request, null); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -563,7 +563,7 @@ public void testExecute_withMapTypeInput_with_update_successful() { List inferenceList = Arrays.asList("test2", "test4"); IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); - OptimizedTextEmbeddingProcessor processor = 
createInstanceWithLevel2MapConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); ((Map) updateSourceAndMetadata.get("key1")).put("test1", "newValue1"); List filteredInferenceList = Arrays.asList("newValue1"); @@ -574,7 +574,7 @@ public void testExecute_withMapTypeInput_with_update_successful() { .build(); mockUpdateDocument(ingestDocument); - mockVectorCreation(2, 6, 0.0f, 0.1f); + mockVectorCreation(ingestRequest, updateRequest); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -606,10 +606,10 @@ public void testNestedFieldInMapping_withMapTypeInput_no_update_successful() { TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); - OptimizedInferenceProcessor processor = createInstanceWithNestedLevelConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithNestedLevelConfig(); mockUpdateDocument(ingestDocument); - mockVectorCreation(1, 4, 0.0f, 1.0f); + mockVectorCreation(request, null); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); @@ -648,10 +648,10 @@ public void testNestedFieldInMapping_withMapTypeInput_with_update_successful() { .inputTexts(filteredInferenceList) .build(); - OptimizedTextEmbeddingProcessor processor = createInstanceWithNestedMappingsConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithNestedMappingsConfig(); mockUpdateDocument(ingestDocument); - mockVectorCreation(1, 4, 0.0f, 1.0f); + mockVectorCreation(ingestRequest, updateRequest); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); @@ -697,10 +697,10 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentHa Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); - OptimizedTextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(); mockUpdateDocument(ingestDocument); - mockVectorCreation(1, 100, 0.0f, 1.0f); + mockVectorCreation(request, null); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); @@ -752,10 +752,10 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentHa .modelId("mockModelId") .inputTexts(filteredInferenceList) .build(); - OptimizedTextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(); mockUpdateDocument(ingestDocument); - mockVectorCreation(1, 100, 0.0f, 1.0f); + mockVectorCreation(ingestRequest, updateRequest); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); @@ -798,9 +798,9 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWi IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); List inferenceList = 
Arrays.asList(TEXT_VALUE_1); TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); - OptimizedTextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(); mockUpdateDocument(ingestDocument); - mockVectorCreation(1, 100, 0.0f, 1.0f); + mockVectorCreation(request, null); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); @@ -847,9 +847,9 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWi .modelId("mockModelId") .inputTexts(filteredInferenceList) .build(); - OptimizedTextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(); mockUpdateDocument(ingestDocument); - mockVectorCreation(1, 100, 0.0f, 1.0f); + mockVectorCreation(ingestRequest, updateRequest); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); @@ -907,11 +907,11 @@ public void testNestedFieldInMappingForListWithNestedObj_withIngestDocumentWitho Map updateSourceAndMetadata = deepCopy(sourceAndMetadata); IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); - OptimizedTextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationMapConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationMapConfig(); List inferenceList = Arrays.asList(TEXT_VALUE_1, TEXT_VALUE_1); TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); mockUpdateDocument(ingestDocument); - mockVectorCreation(2, 100, 0.0f, 1.0f); + mockVectorCreation(request, null); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); @@ -984,7 +984,7 @@ public void testNestedFieldInMappingForListWithNestedObj_withIngestDocumentWitho ); IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); - OptimizedTextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationMapConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationMapConfig(); List inferenceList = Arrays.asList(TEXT_VALUE_1, TEXT_VALUE_1); TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelID").inputTexts(inferenceList).build(); List filteredInferenceList = Arrays.asList("newValue"); @@ -993,7 +993,7 @@ public void testNestedFieldInMappingForListWithNestedObj_withIngestDocumentWitho .inputTexts(filteredInferenceList) .build(); mockUpdateDocument(ingestDocument); - mockVectorCreation(2, 100, 0.0f, 1.0f); + mockVectorCreation(ingestRequest, updateRequest); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); @@ -1041,11 +1041,11 @@ public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_no_update_succe Map updateSourceAndMetadata = deepCopy(sourceAndMetadata); IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); - OptimizedTextEmbeddingProcessor processor = createInstanceWithNestedMapConfig(); + SelectiveTextEmbeddingProcessor 
processor = createInstanceWithNestedMapConfig(); List inferenceList = Arrays.asList(CHILD_LEVEL_2_TEXT_FIELD_VALUE); TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelID").inputTexts(inferenceList).build(); mockUpdateDocument(ingestDocument); - mockVectorCreation(1, 100, 0.0f, 1.0f); + mockVectorCreation(request, null); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); @@ -1080,7 +1080,7 @@ public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_with_update_suc ((Map) ((Map) updateSourceAndMetadata.get(PARENT_FIELD)).get(CHILD_FIELD_LEVEL_1)).put(CHILD_FIELD_LEVEL_2, "newValue"); IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); - OptimizedTextEmbeddingProcessor processor = createInstanceWithNestedMapConfig(); + SelectiveTextEmbeddingProcessor processor = createInstanceWithNestedMapConfig(); List inferenceList = Arrays.asList(CHILD_LEVEL_2_TEXT_FIELD_VALUE); TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelID").inputTexts(inferenceList).build(); List filteredInferenceList = Arrays.asList("newValue"); @@ -1089,7 +1089,7 @@ public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_with_update_suc .inputTexts(filteredInferenceList) .build(); mockUpdateDocument(ingestDocument); - mockVectorCreation(2, 100, 0.0f, 1.0f); + mockVectorCreation(ingestRequest, updateRequest); processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); @@ -1183,10 +1183,16 @@ private void verifyEqualEmbeddingInMap(List insertVectors, List update } } - private void mockVectorCreation(int numVectors, int vectorDimension, float min, float max) { + private void mockVectorCreation(TextInferenceRequest ingestRequest, TextInferenceRequest updateRequest) { doAnswer(invocation -> { + int numVectors = ingestRequest.getInputTexts().size(); + ActionListener>> listener = invocation.getArgument(1); + listener.onResponse(createRandomOneDimensionalMockVector(numVectors, 2, 0.0f, 1.0f)); + return null; + }).doAnswer(invocation -> { + int numVectors = updateRequest.getInputTexts().size(); ActionListener>> listener = invocation.getArgument(1); - listener.onResponse(createRandomOneDimensionalMockVector(numVectors, vectorDimension, min, max)); + listener.onResponse(createRandomOneDimensionalMockVector(numVectors, 2, 0.0f, 1.0f)); return null; }).when(mlCommonsClientAccessor).inferenceSentences(isA(TextInferenceRequest.class), isA(ActionListener.class)); } diff --git a/src/test/java/org/opensearch/neuralsearch/util/ProcessorUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/ProcessorUtilsTests.java index 9affb550a..262bf63f3 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/ProcessorUtilsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/ProcessorUtilsTests.java @@ -270,4 +270,15 @@ public void testComputeFullTextKey_returnsExpectedValues_WithExistingKeys() { String newKey = computeFullTextKey(targetPath, textKey, 2); assertEquals("ml.model", newKey); } + + public void testComputeFullTextKey_returnsExpectedValuesWithDuplicate_WithExistingKeys() { + String setTargetPath = "ml.info.text"; + String targetValue = "ml"; + setUpValidSourceMap(); + setValueToSource(sourceMap, setTargetPath, targetValue); + System.out.println(sourceMap); + String findPath = "ml.info.ml"; + String key = findKeyFromFromValue(sourceMap, findPath, 3); + 
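// the looked-up value "ml" duplicates an existing top-level key in the source map; resolving the path at depth 3 should still return the "text" key that was set under ml.info rather than the duplicate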
assertEquals("text", key); + } } diff --git a/src/test/resources/processor/PipelineConfigurationWithNestedFieldsMappingWithSkipExisting.json b/src/test/resources/processor/PipelineConfigurationWithNestedFieldsMappingWithSkipExisting.json new file mode 100644 index 000000000..19f58db36 --- /dev/null +++ b/src/test/resources/processor/PipelineConfigurationWithNestedFieldsMappingWithSkipExisting.json @@ -0,0 +1,20 @@ +{ + "description": "text embedding pipeline for hybrid", + "processors": [ + { + "text_embedding": { + "model_id": "%s", + "field_map": { + "title": "title_knn", + "favor_list": "favor_list_knn", + "favorites": { + "game": "game_knn", + "movie": "movie_knn" + }, + "nested_passages.level_2.level_3_text": "level_3_container.level_3_embedding" + }, + "skip_existing": true + } + } + ] +} diff --git a/src/test/resources/processor/PipelineConfigurationWithSkipExisting.json b/src/test/resources/processor/PipelineConfigurationWithSkipExisting.json new file mode 100644 index 000000000..058d861f4 --- /dev/null +++ b/src/test/resources/processor/PipelineConfigurationWithSkipExisting.json @@ -0,0 +1,23 @@ +{ + "description": "text embedding pipeline for optimized inference call", + "processors": [ + { + "text_embedding": { + "model_id": "%s", + "batch_size": "%d", + "field_map": { + "title": "title_knn", + "favor_list": "favor_list_knn", + "favorites": { + "game": "game_knn", + "movie": "movie_knn" + }, + "nested_passages": { + "text": "embedding" + } + }, + "skip_existing": true + } + } + ] +} diff --git a/src/test/resources/processor/update_doc1.json b/src/test/resources/processor/update_doc1.json new file mode 100644 index 000000000..da4d8fd5a --- /dev/null +++ b/src/test/resources/processor/update_doc1.json @@ -0,0 +1,25 @@ +{ + "title": "This is a good day", + "text": "%s", + "description": "daily logging", + "favor_list": [ + "test", + "hello", + "mock" + ], + "favorites": { + "game": "overwatch", + "movie": null + }, + "nested_passages": [ + { + "text_not_for_embedding": "test" + }, + { + "text": "bye" + }, + { + "text": "world" + } + ] +} diff --git a/src/test/resources/processor/update_doc2.json b/src/test/resources/processor/update_doc2.json new file mode 100644 index 000000000..455955d72 --- /dev/null +++ b/src/test/resources/processor/update_doc2.json @@ -0,0 +1,23 @@ +{ + "title": "this is a second doc", + "text": "%s", + "description": "the description is not very long", + "favor_list": [ + "favor" + ], + "favorites": { + "game": "silver state", + "movie": null + }, + "nested_passages": [ + { + "text_not_for_embedding": "test" + }, + { + "text": "apple" + }, + { + "text": "banana" + } + ] +} diff --git a/src/test/resources/processor/update_doc3.json b/src/test/resources/processor/update_doc3.json new file mode 100644 index 000000000..597001b8f --- /dev/null +++ b/src/test/resources/processor/update_doc3.json @@ -0,0 +1,23 @@ +{ + "title": "This is a good day", + "description": "daily logging", + "favor_list": [ + "test", + "hello", + "mock" + ], + "favorites": { + "game": "overwatch", + "movie": null + }, + "nested_passages": + { + "level_2": + { + "level_3_text": "hello", + "level_3_container": { + "level_4_text_field": "def" + } + } + } +} diff --git a/src/test/resources/processor/update_doc4.json b/src/test/resources/processor/update_doc4.json new file mode 100644 index 000000000..9cecc27a3 --- /dev/null +++ b/src/test/resources/processor/update_doc4.json @@ -0,0 +1,20 @@ +{ + "title": "This is a good day", + "description": "daily logging", + "favor_list": [ + "key", + 
"hey", + "click" + ], + "favorites": { + "game": "cossacks", + "movie": "matrix" + }, + "nested_passages": + { + "level_2": + { + "level_3_text": "joker" + } + } +} diff --git a/src/test/resources/processor/update_doc5.json b/src/test/resources/processor/update_doc5.json new file mode 100644 index 000000000..65f2c13f2 --- /dev/null +++ b/src/test/resources/processor/update_doc5.json @@ -0,0 +1,24 @@ +{ + "title": "This is a good day", + "description": "daily logging", + "favor_list": [ + "key", + "hey", + "click" + ], + "favorites": { + "game": "cossacks", + "movie": "matrix" + }, + "nested_passages":[ + { + "level_2": + { + "level_3_text": "joker" + } + }, + { + "level_2.level_3_text": "superman" + } + ] +} diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index bc9762397..6819606fa 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -106,6 +106,10 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json", ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING, "processor/PipelineConfigurationWithNestedFieldsMapping.json", + ProcessorType.TEXT_EMBEDDING_WITH_SKIP_EXISTING, + "processor/PipelineConfigurationWithSkipExisting.json", + ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING_WITH_SKIP_EXISTING, + "processor/PipelineConfigurationWithNestedFieldsMappingWithSkipExisting.json", ProcessorType.SPARSE_ENCODING_PRUNE, "processor/SparseEncodingPipelineConfigurationWithPrune.json" ); @@ -1381,7 +1385,7 @@ protected void createIndexWithPipeline(String indexName, String indexMappingFile * @param id nullable optional id * @throws Exception */ - protected String ingestDocument(String indexName, String ingestDocument, String id) throws Exception { + protected String ingestDocument(String indexName, String ingestDocument, String id, boolean isUpdate) throws Exception { String endpoint; if (StringUtils.isEmpty(id)) { endpoint = indexName + "/_doc?refresh"; @@ -1403,7 +1407,11 @@ protected String ingestDocument(String indexName, String ingestDocument, String ); String result = (String) map.get("result"); - assertEquals("created", result); + if (isUpdate) { + assertEquals("updated", result); + } else { + assertEquals("created", result); + } return result; } @@ -1414,7 +1422,22 @@ protected String ingestDocument(String indexName, String ingestDocument, String * @throws Exception */ protected String ingestDocument(String indexName, String ingestDocument) throws Exception { - return ingestDocument(indexName, ingestDocument, null); + return ingestDocument(indexName, ingestDocument, null, false); + } + + protected String ingestDocument(String indexName, String ingestDocument, String id) throws Exception { + return ingestDocument(indexName, ingestDocument, id, false); + } + + /** + * Update a document to index using auto generated id + * @param indexName name of the index + * @param ingestDocument + * @param id + * @throws Exception + */ + protected String updateDocument(String indexName, String ingestDocument, String id) throws Exception { + return ingestDocument(indexName, ingestDocument, id, true); } /** @@ -1734,6 +1757,8 @@ protected Object validateDocCountAndInfo( protected enum ProcessorType { TEXT_EMBEDDING, TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING, + 
TEXT_EMBEDDING_WITH_SKIP_EXISTING, + TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING_WITH_SKIP_EXISTING, TEXT_IMAGE_EMBEDDING, SPARSE_ENCODING, SPARSE_ENCODING_PRUNE
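// the two *_SKIP_EXISTING processor types resolve to the skip_existing pipeline configuration
// JSON files registered in the pipeline-configuration map above; a minimal IT flow using the
// new helpers (sketch only, index name and document JSON are illustrative) would be:
//   String created = ingestDocument("skip-existing-index", docJson, "1");        // expects "created"
//   String updated = updateDocument("skip-existing-index", updatedDocJson, "1"); // expects "updated"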