From fc832198110f5619bc13589965b013dbbbf3ac86 Mon Sep 17 00:00:00 2001 From: will-hwang Date: Thu, 27 Feb 2025 18:00:56 -0800 Subject: [PATCH] implement single document update scenario in text embedding processor Signed-off-by: will-hwang --- CHANGELOG.md | 1 + .../neuralsearch/plugin/NeuralSearch.java | 7 +- .../processor/InferenceProcessor.java | 33 +- .../processor/TextEmbeddingProcessor.java | 58 +- .../TextEmbeddingProcessorFactory.java | 27 +- .../optimization/InferenceFilter.java | 228 ++++ .../TextEmbeddingInferenceFilter.java | 84 ++ .../processor/util/ProcessorUtils.java | 96 +- .../util/ProcessorDocumentUtils.java | 68 + .../processor/InferenceProcessorTestCase.java | 58 + .../processor/InferenceProcessorTests.java | 2 +- .../processor/TextEmbeddingProcessorIT.java | 99 ++ .../TextEmbeddingProcessorTests.java | 1159 +++++++++++++++-- .../TextEmbeddingInferenceFilterTests.java | 235 ++++ .../util/ProcessorDocumentUtilsTests.java | 47 + .../util/ProcessorUtilsTests.java | 16 + ...thNestedFieldsMappingWithSkipExisting.json | 20 + ...PipelineConfigurationWithSkipExisting.json | 23 + src/test/resources/processor/update_doc1.json | 25 + src/test/resources/processor/update_doc2.json | 23 + src/test/resources/processor/update_doc3.json | 23 + src/test/resources/processor/update_doc4.json | 20 + src/test/resources/processor/update_doc5.json | 24 + .../neuralsearch/BaseNeuralSearchIT.java | 36 +- 24 files changed, 2280 insertions(+), 132 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/optimization/InferenceFilter.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilter.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilterTests.java create mode 100644 src/test/resources/processor/PipelineConfigurationWithNestedFieldsMappingWithSkipExisting.json create mode 100644 src/test/resources/processor/PipelineConfigurationWithSkipExisting.json create mode 100644 src/test/resources/processor/update_doc1.json create mode 100644 src/test/resources/processor/update_doc2.json create mode 100644 src/test/resources/processor/update_doc3.json create mode 100644 src/test/resources/processor/update_doc4.json create mode 100644 src/test/resources/processor/update_doc5.json diff --git a/CHANGELOG.md b/CHANGELOG.md index caa7a1965..c84cf6830 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD) ### Features +- Add Optimized Text Embedding Processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191)) ### Enhancements - Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838)) - Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007)) diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index fad22ba95..2e04ff1d7 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -127,7 +127,12 @@ public Map getProcessors(Processor.Parameters paramet clientAccessor = new MLCommonsClientAccessor(new 
MachineLearningNodeClient(parameters.client)); return Map.of( TextEmbeddingProcessor.TYPE, - new TextEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()), + new TextEmbeddingProcessorFactory( + parameters.client, + clientAccessor, + parameters.env, + parameters.ingestService.getClusterService() + ), SparseEncodingProcessor.TYPE, new SparseEncodingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()), TextImageEmbeddingProcessor.TYPE, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index 6ee54afe7..7d0b683b4 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -27,6 +27,7 @@ import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.common.collect.Tuple; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.cluster.service.ClusterService; import org.opensearch.env.Environment; @@ -118,7 +119,7 @@ private void validateEmbeddingConfiguration(Map fieldMap) { public abstract void doExecute( IngestDocument ingestDocument, - Map ProcessMap, + Map processMap, List inferenceList, BiConsumer handler ); @@ -278,7 +279,7 @@ private static class DataForInference { } @SuppressWarnings({ "unchecked" }) - private List createInferenceList(Map knnKeyMap) { + protected List createInferenceList(Map knnKeyMap) { List texts = new ArrayList<>(); knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { Object sourceValue = knnMapEntry.getValue(); @@ -579,11 +580,37 @@ private Map getSourceMapBySourceAndMetadataMap(String processorK private List> buildNLPResultForListType(List sourceValue, List results, IndexWrapper indexWrapper) { List> keyToResult = new ArrayList<>(); - IntStream.range(0, sourceValue.size()) + sourceValue.stream() + .filter(Objects::nonNull) // explicit null check is required since sourceValue can contain null values in cases where + // sourceValue has been filtered .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); return keyToResult; } + /** + * This method invokes an inference call through mlCommonsClientAccessor and populates the retrieved embeddings into the ingestDocument + * + * @param ingestDocument the ingest document to populate embeddings into + * @param processMap map indicating the paths in the ingest document at which to populate embeddings + * @param inferenceList list of texts to send for model inference + * @param handler callback invoked with the updated ingest document on success, or with the exception on failure + * + */ + protected void makeInferenceCall( + IngestDocument ingestDocument, + Map processMap, + List inferenceList, + BiConsumer handler + ) { + mlCommonsClientAccessor.inferenceSentences( + TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(), + ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, processMap, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); }) + ); + } + @Override public String getType() { return type; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java
b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 52904476f..cf525d4c3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -6,9 +6,13 @@ import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.opensearch.action.get.GetAction; +import org.opensearch.action.get.GetRequest; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; @@ -16,6 +20,8 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.optimization.TextEmbeddingInferenceFilter; +import org.opensearch.transport.client.OpenSearchClient; /** * This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use, @@ -26,6 +32,13 @@ public final class TextEmbeddingProcessor extends InferenceProcessor { public static final String TYPE = "text_embedding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; + public static final String SKIP_EXISTING = "skip_existing"; + public static final boolean DEFAULT_SKIP_EXISTING = false; + private static final String INDEX_FIELD = "_index"; + private static final String ID_FIELD = "_id"; + private final OpenSearchClient openSearchClient; + private final boolean skipExisting; + private final TextEmbeddingInferenceFilter textEmbeddingInferenceFilter; public TextEmbeddingProcessor( String tag, @@ -33,27 +46,56 @@ public TextEmbeddingProcessor( int batchSize, String modelId, Map fieldMap, + boolean skipExisting, + TextEmbeddingInferenceFilter textEmbeddingInferenceFilter, + OpenSearchClient openSearchClient, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService ) { super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService); + this.skipExisting = skipExisting; + this.textEmbeddingInferenceFilter = textEmbeddingInferenceFilter; + this.openSearchClient = openSearchClient; } @Override public void doExecute( IngestDocument ingestDocument, - Map ProcessMap, + Map processMap, List inferenceList, BiConsumer handler ) { - mlCommonsClientAccessor.inferenceSentences( - TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(), - ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); }) - ); + // skip existing flag is turned off. 
Call model inference without filtering + if (skipExisting == false) { + makeInferenceCall(ingestDocument, processMap, inferenceList, handler); + return; + } + // skipExisting flag is turned on; compare eligible inference texts against the existing document and filter out those whose embeddings can be copied over + String index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD).toString(); + String id = ingestDocument.getSourceAndMetadata().get(ID_FIELD).toString(); + openSearchClient.execute(GetAction.INSTANCE, new GetRequest(index, id), ActionListener.wrap(response -> { + final Map existingDocument = response.getSourceAsMap(); + if (existingDocument == null || existingDocument.isEmpty()) { + makeInferenceCall(ingestDocument, processMap, inferenceList, handler); + } else { + // filter the given processMap by comparing the existing document with the ingestDocument + Map filteredProcessMap = textEmbeddingInferenceFilter.filter( + existingDocument, + ingestDocument.getSourceAndMetadata(), + processMap + ); + // create the inference list based on the filtered processMap + List filteredInferenceList = createInferenceList(filteredProcessMap).stream() + .filter(Objects::nonNull) + .collect(Collectors.toList()); + if (filteredInferenceList.isEmpty()) { + handler.accept(ingestDocument, null); + } else { + makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler); + } + } + }, e -> { handler.accept(null, e); })); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java index 6b442b56c..68de02b6f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java @@ -4,8 +4,11 @@ */ package org.opensearch.neuralsearch.processor.factory; +import static org.opensearch.ingest.ConfigurationUtils.readBooleanProperty; import static org.opensearch.ingest.ConfigurationUtils.readMap; import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.SKIP_EXISTING; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.DEFAULT_SKIP_EXISTING; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; @@ -17,12 +20,16 @@ import org.opensearch.ingest.AbstractBatchingProcessor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.optimization.TextEmbeddingInferenceFilter; +import org.opensearch.transport.client.OpenSearchClient; /** * Factory for text embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. 
*/ public final class TextEmbeddingProcessorFactory extends AbstractBatchingProcessor.Factory { + private final OpenSearchClient openSearchClient; + private final MLCommonsClientAccessor clientAccessor; private final Environment environment; @@ -30,11 +37,13 @@ public final class TextEmbeddingProcessorFactory extends AbstractBatchingProcess private final ClusterService clusterService; public TextEmbeddingProcessorFactory( + final OpenSearchClient openSearchClient, final MLCommonsClientAccessor clientAccessor, final Environment environment, final ClusterService clusterService ) { super(TYPE); + this.openSearchClient = openSearchClient; this.clientAccessor = clientAccessor; this.environment = environment; this.clusterService = clusterService; @@ -43,7 +52,21 @@ public TextEmbeddingProcessorFactory( @Override protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map config) { String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); - Map filedMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); - return new TextEmbeddingProcessor(tag, description, batchSize, modelId, filedMap, clientAccessor, environment, clusterService); + Map fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); + boolean skipExisting = readBooleanProperty(TYPE, tag, config, SKIP_EXISTING, DEFAULT_SKIP_EXISTING); + TextEmbeddingInferenceFilter textEmbeddingInferenceFilter = new TextEmbeddingInferenceFilter(fieldMap); + return new TextEmbeddingProcessor( + tag, + description, + batchSize, + modelId, + fieldMap, + skipExisting, + textEmbeddingInferenceFilter, + openSearchClient, + clientAccessor, + environment, + clusterService + ); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/optimization/InferenceFilter.java b/src/main/java/org/opensearch/neuralsearch/processor/optimization/InferenceFilter.java new file mode 100644 index 000000000..dbd515be2 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/optimization/InferenceFilter.java @@ -0,0 +1,228 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.optimization; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.util.ProcessorUtils; +import org.opensearch.neuralsearch.util.ProcessorDocumentUtils; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * Abstract class for selective text processing and embedding optimization. + * The InferenceFilter class is designed to optimize inference calls by selectively processing text data. + * It reuses existing embeddings when the text content remains unchanged, reducing redundant inference calls and + * improving performance. This is achieved by comparing the text in new and existing documents, copying embeddings + * when the text is identical. + * This class is intended to be extended for different text processing use cases. It provides a recursive filtering + * mechanism that navigates through nested map structures, compares values, and determines whether embeddings can be reused. + */ +@Log4j2 +public abstract class InferenceFilter { + /** + * Stores the reverse mapping of field names to support efficient lookups for embedding keys. + * This is generated by flattening and flipping the provided field map. 
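+ * For example: a field map of {"passage_text": "passage_embedding"} is flattened and flipped into + * {"passage_embedding": "passage_text"}, so an embedding key can be mapped back to its source text key.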
+ */ + protected Map reversedFieldMap; + + /** + * Constructs an InferenceFilter instance and initializes the reversed field map. + */ + public InferenceFilter(Map fieldMap) { + this.reversedFieldMap = ProcessorDocumentUtils.flattenAndFlip(fieldMap); + } + + /** + * Abstract method to filter individual values based on the existing and new metadata maps. + * Implementations should provide logic to compare values and determine if embeddings can be reused. + * + * @param currentPath The current dot-notation path for the value being processed. + * @param processValue The value to be checked for potential embedding reuse. + * @param sourceAndMetadataMap The metadata map of the new document. + * @param existingSourceAndMetadataMap The metadata map of the existing document. + * @param index The index of the value in the list, if applicable (-1 if not part of a list). + * @return The processed value or null if embeddings are reused. + */ + + public abstract Object filterInferenceValue( + String currentPath, + Object processValue, + Map sourceAndMetadataMap, + Map existingSourceAndMetadataMap, + int index + ); + + /** + * Abstract method to filter and compare lists of values. + * If all elements in the list are identical between the new and existing metadata maps, embeddings are copied, + * and an empty list is returned to indicate no further processing is required. + * + * @param processList The list of values to be processed. + * @param existingList The list of existing values for comparison. + * @param embeddingList The list of existing embeddings. + * @param sourceAndMetadataMap The metadata map of the new document. + * @param existingSourceAndMetadataMap The metadata map of the existing document. + * @param fullEmbeddingKey The dot-notation path for the embedding field. + * @return A processed list or an empty list if embeddings are reused. + */ + + public abstract List filterInferenceValuesInList( + List processList, + List existingList, + List embeddingList, + Map sourceAndMetadataMap, + Map existingSourceAndMetadataMap, + String fullEmbeddingKey + ); + + /** + * This method navigates through the nested structure, checking each key-value pair recursively. It supports: + * Map values: Processed recursively using this method. + * List values: Processed using filterListValue. + * Primitive values: Directly compared using filterInferenceValue. + * + * @param existingSourceAndMetadataMap The metadata map of the existing document. + * @param sourceAndMetadataMap The metadata map of the new document. + * @param processMap The current map being processed. + * @return A filtered map containing only elements that require new embeddings. + */ + public Map filter( + Map existingSourceAndMetadataMap, + Map sourceAndMetadataMap, + Map processMap + ) { + return filter(existingSourceAndMetadataMap, sourceAndMetadataMap, processMap, ""); + } + + /** + * Helper method for filter + * @param existingSourceAndMetadataMap The metadata map of the existing document. + * @param sourceAndMetadataMap The metadata map of the new document. + * @param processMap The current map being processed. + * @param traversedPath The dot-notation path of the previously traversed elements. 
+ * e.g.: + * In a map structured like: + * level1 + * level2 + * level3 + * traversedPath would be the dot-separated string level1.level2.level3 + * @return A filtered map containing only elements that require new embeddings + */ + private Map filter( + Map existingSourceAndMetadataMap, + Map sourceAndMetadataMap, + Object processMap, + String traversedPath + ) { + if (processMap instanceof Map == false) { + throw new IllegalArgumentException("processMap needs to be an instanceof Map"); + } + Map filteredProcessMap = new HashMap<>(); + Map castedProcessMap = ProcessorUtils.unsafeCastToObjectMap(processMap); + for (Map.Entry entry : castedProcessMap.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + String currentPath = traversedPath.isEmpty() ? key : traversedPath + "." + key; + if (value instanceof Map) { + Map filteredInnerMap = filter(existingSourceAndMetadataMap, sourceAndMetadataMap, value, currentPath); + filteredProcessMap.put(key, filteredInnerMap.isEmpty() ? null : filteredInnerMap); + } else if (value instanceof List) { + List processedList = filterListValue( + currentPath, + ProcessorUtils.unsafeCastToObjectList(value), + sourceAndMetadataMap, + existingSourceAndMetadataMap + ); + filteredProcessMap.put(key, processedList); + } else { + Object processedValue = filterInferenceValue(currentPath, value, sourceAndMetadataMap, existingSourceAndMetadataMap, -1); + filteredProcessMap.put(key, processedValue); + } + } + return filteredProcessMap; + } + + /** + * Processes a list of values by comparing them against source and existing metadata. + * + * @param embeddingKey The current path in dot notation for the list being processed + * @param processList The list of values to process + * @param sourceAndMetadataMap SourceAndMetadataMap of ingestDocument Document + * @param existingSourceAndMetadataMap SourceAndMetadataMap of existing Document + * @return A processed list containing non-filtered elements + */ + protected List filterListValue( + String embeddingKey, + List processList, + Map sourceAndMetadataMap, + Map existingSourceAndMetadataMap + ) { + String textKey = reversedFieldMap.get(embeddingKey); + if (Objects.isNull(textKey)) { + return processList; + } + Optional existingList = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, textKey); + if (existingList.isPresent() == false) { + // if list is not present with given text key, process each entry in processList individually + return filterMapValuesInList(processList, embeddingKey, sourceAndMetadataMap, existingSourceAndMetadataMap); + } + // if list is present with given textKey, but is not a list type, no valid comparison can be made, return processList + if (existingList.get() instanceof List == false) { + return processList; + } + // retrieve embedding for given embedding key, if not present or is not a list type, return processList, as no embedding can be + // copied over + Optional embeddingList = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, embeddingKey); + if (embeddingList.isPresent() == false || embeddingList.get() instanceof List == false) { + return processList; + } + // return an empty list if processList and existingList are equal and embeddings are copied; return processList otherwise + return filterInferenceValuesInList( + processList, + ProcessorUtils.unsafeCastToObjectList(existingList.get()), + ProcessorUtils.unsafeCastToObjectList(embeddingList.get()), + sourceAndMetadataMap, + existingSourceAndMetadataMap, + embeddingKey + ); + } + + /** + *
Processes a list containing map values by iterating through each item and processing it individually. + * + * @param processList The list of Map items to process + * @param currentPath The current path in dot notation + * @param sourceAndMetadataMap SourceAndMetadataMap of ingestDocument Document + * @param existingSourceAndMetadataMap SourceAndMetadataMap of existing Document + * @return A processed list containing non-filtered elements + */ + public List filterMapValuesInList( + List processList, + String currentPath, + Map sourceAndMetadataMap, + Map existingSourceAndMetadataMap + ) { + List filteredList = new ArrayList<>(); + Iterator iterator = processList.iterator(); + int index = 0; + while (iterator.hasNext()) { + Object processedItem = filterInferenceValue( + currentPath, + iterator.next(), + sourceAndMetadataMap, + existingSourceAndMetadataMap, + index++ + ); + filteredList.add(processedItem); + } + return filteredList; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilter.java b/src/main/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilter.java new file mode 100644 index 000000000..d82f6ac21 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilter.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.optimization; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.util.ProcessorUtils; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * TextEmbeddingInferenceFilter optimizes text embedding inference by selectively processing text data. + * This class extends InferenceFilter to provide efficient text embedding processing by comparing text + * between existing and new documents. If the text is identical, the corresponding embeddings are copied over, + * avoiding redundant inference calls and improving performance. + */ +@Log4j2 +public class TextEmbeddingInferenceFilter extends InferenceFilter { + /** + * Constructs a TextEmbeddingInferenceFilter instance with the specified field map. + */ + public TextEmbeddingInferenceFilter(Map fieldMap) { + super(fieldMap); + } + + /** + * Filters a single value by checking if the text is identical in both the existing and new document. + * If the text matches, the corresponding embedding is copied, and null is returned, indicating no further + * processing is required. + * + * @return Null if embeddings are reused; the original value otherwise. 
+ */ + @Override + public Object filterInferenceValue( + String embeddingPath, + Object processValue, + Map sourceAndMetadataMap, + Map existingSourceAndMetadataMap, + int index + ) { + String textPath = reversedFieldMap.get(embeddingPath); + if (Objects.isNull(textPath)) { + return processValue; + } + Optional existingValue = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, textPath, index); + Optional embeddingValue = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, embeddingPath, index); + + if (existingValue.isPresent() && embeddingValue.isPresent() && existingValue.get().equals(processValue)) { + ProcessorUtils.setValueToSource(sourceAndMetadataMap, embeddingPath, embeddingValue.get(), index); + // if successfully copied, return null to be filtered out from process map + return null; + } + // processValue and existingValue are different, return processValue to be included in process map + return processValue; + } + + /** + * Filters a List value by checking if the texts in the list are identical in both the existing and new documents. + * If the lists are equal, the corresponding embeddings are copied. + * @return empty list if embeddings are reused; the original list otherwise. + */ + @Override + public List filterInferenceValuesInList( + List processList, + List existingList, + List embeddingList, + Map sourceAndMetadataMap, + Map existingSourceAndMetadataMap, + String fullEmbeddingKey + ) { + if (processList.equals(existingList)) { + ProcessorUtils.setValueToSource(sourceAndMetadataMap, fullEmbeddingKey, embeddingList); + // if successfully copied, return empty list to be filtered out from process map + return Collections.emptyList(); + } + // source list and existing list are different, return processList to be included in process map + return processList; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java index d799f323f..57ddf0c1c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java @@ -8,8 +8,11 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.search.SearchHit; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Stack; @@ -115,7 +118,7 @@ public static void removeTargetFieldFromSource(final Map sourceA if (key.equals(lastKey)) { break; } - currentMap = (Map) currentMap.get(key); + currentMap = unsafeCastToObjectMap(currentMap.get(key)); } // Remove the last key this is guaranteed @@ -130,8 +133,7 @@ public static void removeTargetFieldFromSource(final Map sourceA parentMap = currentParentMapWithChild.v1(); key = currentParentMapWithChild.v2(); - @SuppressWarnings("unchecked") - Map innerMap = (Map) parentMap.get(key); + Map innerMap = unsafeCastToObjectMap(parentMap.get(key)); if (innerMap != null && innerMap.isEmpty()) { parentMap.remove(key); @@ -139,38 +141,96 @@ public static void removeTargetFieldFromSource(final Map sourceA } } + public static Optional getValueFromSource(final Map sourceAsMap, final String targetField) { + return getValueFromSource(sourceAsMap, targetField, -1); + } + /** * Returns the mapping associated with a path to a value, otherwise * returns an empty optional when it encounters a dead end. *
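+ * For example: given a sourceAsMap of {"a": [{"b": "x"}, {"b": "y"}]}, getValueFromSource(sourceAsMap, "a.b", 1) + * returns Optional.of("y"), while the two-argument overload getValueFromSource(sourceAsMap, "a.b") returns + * Optional.empty() because the list cannot be traversed without an index. + *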
* When the targetField has the form (key[.key]) it will iterate through * the map to see if a mapping exists. + * When a List is encountered during traversal, the index parameter is used to select + * the appropriate element from the list. If the selected element is a map, traversal continues with + * that map using the next key in the path. * * @param sourceAsMap The Source map (a map of maps) to iterate through * @param targetField The path to take to get the desired mapping + * @param index the index to use when a list is encountered during traversal; if list processing is not needed, + * -1 is passed in * @return A possible result within an optional */ - public static Optional getValueFromSource(final Map sourceAsMap, final String targetField) { + public static Optional getValueFromSource(final Map sourceAsMap, final String targetField, int index) { String[] keys = targetField.split("\\."); - Optional currentValue = Optional.of(sourceAsMap); + Optional currentValue = Optional.ofNullable(sourceAsMap); for (String key : keys) { currentValue = currentValue.flatMap(value -> { - if (!(value instanceof Map)) { + if (value instanceof ArrayList && index != -1) { + Object listValue = (unsafeCastToObjectList(value)).get(index); + if (listValue instanceof Map) { + Map currentMap = unsafeCastToObjectMap(listValue); + return Optional.ofNullable(currentMap.get(key)); + } else { + return Optional.empty(); + } + } else if (value instanceof Map) { + Map currentMap = unsafeCastToObjectMap(value); + return Optional.ofNullable(currentMap.get(key)); + } else { return Optional.empty(); } - Map currentMap = (Map) value; - return Optional.ofNullable(currentMap.get(key)); }); - if (currentValue.isEmpty()) { - return Optional.empty(); + break; } } return currentValue; } + public static void setValueToSource(Map sourceAsMap, String targetKey, Object targetValue) { + setValueToSource(sourceAsMap, targetKey, targetValue, -1); + } + + /** + * Inserts or updates a value in a nested map structure, with optional support for list traversal. + * This method navigates through the provided sourceAsMap using the dot-delimited key path + * specified by targetKey. Intermediate maps are created as needed. When a List is encountered, + * the provided index is used to select the element from the list. The selected element must be a map to + * continue the traversal. + * Once the final map in the path is reached, the method sets the value for the last key. 
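+ * For example: with sourceAsMap = {"a": [{"x": 1}, {"x": 2}]}, calling setValueToSource(sourceAsMap, "a.y", "v", 1) + * sets the second list element to {"x": 2, "y": "v"}.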
+ * + * @param sourceAsMap The Source map (a map of maps) to iterate through + * @param targetKey the dot-delimited key path at which to insert the desired targetValue + * @param targetValue the value to set at the specified key path + * @param index the index to use when a list is encountered during traversal; if list processing is not needed, + * -1 is passed in + */ + + public static void setValueToSource(Map sourceAsMap, String targetKey, Object targetValue, int index) { + if (Objects.isNull(sourceAsMap) || Objects.isNull(targetKey)) return; + + String[] keys = targetKey.split("\\."); + Map current = sourceAsMap; + + for (int i = 0; i < keys.length - 1; i++) { + Object next = current.computeIfAbsent(keys[i], k -> new HashMap<>()); + if (next instanceof ArrayList list) { + if (index < 0 || index >= list.size()) return; + if (list.get(index) instanceof Map) { + current = unsafeCastToObjectMap(list.get(index)); + } + } else if (next instanceof Map) { + current = unsafeCastToObjectMap(next); + } else { + throw new IllegalStateException("Unexpected data structure at " + keys[i]); + } + } + String lastKey = keys[keys.length - 1]; + current.put(lastKey, targetValue); + } + /** * Determines whether there exists a value that has a mapping according to the pathToValue. This is particularly * useful when the source map is a map of maps and when the pathToValue is of the form key[.key]. @@ -209,4 +269,18 @@ public static boolean isNumeric(Object value) { return false; } + + // This method should be used only when you are certain the object is a `Map`. + // It is recommended to use this method as a last resort. + @SuppressWarnings("unchecked") + public static Map unsafeCastToObjectMap(Object obj) { + return (Map) obj; + } + + // This method should be used only when you are certain the object is a `List`. + // It is recommended to use this method as a last resort. + @SuppressWarnings("unchecked") + public static List unsafeCastToObjectList(Object obj) { + return (List) obj; + } } diff --git a/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java b/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java index 0cbf4534d..51bd54ce8 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtils.java @@ -231,6 +231,74 @@ public static Map unflattenJson(Map originalJson return result; } + /** + * Flattens a nested map and then flips each key/value pair. + * For a leaf node, the parent's path is prepended to its value. + * Finally, the flipped map will have the transformed value as the key, + * and the original flattened key as its value. + * + * e.g.: + * map: + * { + * "parent": { + * "level1": { + * "key": "value" + * } + * } + * } + * returns + * { + * "parent.level1.value": "parent.level1.key" + * } + * + * @param map the nested map to process + * @return a flattened map with flipped key–value pairs + */ + public static Map flattenAndFlip(Map map) { + Map flippedMap = new HashMap<>(); + flattenAndFlip("", map, flippedMap); + return flippedMap; + } + + /** + * Recursive helper method that processes the nested map. + *

+ * When a leaf value is encountered, the parent's path is computed + * and prepended to the value. The final mapping is flipped, so that + * the new key becomes the computed value and the new value is the flattened key. + * + * @param prefix the current key prefix (initially empty) + * @param map the current map to process + * @param flippedMap the resulting map with flipped key–value pairs + */ + private static void flattenAndFlip(String prefix, Map map, Map flippedMap) { + for (Map.Entry entry : map.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + String newKey = prefix.isEmpty() ? key : prefix + "." + key; + + if (value instanceof Map) { + flattenAndFlip(newKey, (Map) value, flippedMap); + } else if (newKey != null) { + String parentPath = ""; + int lastDotIndex = newKey.lastIndexOf('.'); + if (lastDotIndex != -1) { + parentPath = newKey.substring(0, lastDotIndex); + } + + String transformedValue = parentPath.isEmpty() ? value.toString() : parentPath + "." + value.toString(); + if (flippedMap.containsKey(transformedValue)) { + int index = 1; + while (flippedMap.containsKey(transformedValue + "_" + index)) { + index++; + } + transformedValue = transformedValue + "_" + index; + } + flippedMap.put(transformedValue, newKey); + } + } + } + private static List handleList(List list) { List result = new ArrayList<>(); Stack stack = new Stack<>(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java index caac962e7..1a7276aa0 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTestCase.java @@ -6,11 +6,18 @@ import com.google.common.collect.ImmutableList; import org.apache.commons.lang.math.RandomUtils; +import org.opensearch.action.get.GetResponse; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.get.GetResult; import org.opensearch.index.mapper.IndexFieldMapper; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.IngestDocumentWrapper; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -72,4 +79,55 @@ protected List> createRandomOneDimensionalMockVector(int numOfVector } return result; } + + protected Map deepCopy(Map original) { + Map copy = new HashMap<>(); + for (Map.Entry entry : original.entrySet()) { + copy.put(entry.getKey(), deepCopyValue(entry.getValue())); + } + return copy; + } + + protected Object deepCopyValue(Object value) { + if (value instanceof Map) { + Map newMap = new HashMap<>(); + for (Map.Entry entry : ((Map) value).entrySet()) { + newMap.put((String) entry.getKey(), deepCopyValue(entry.getValue())); + } + return newMap; + } else if (value instanceof List) { + List newList = new ArrayList<>(); + for (Object item : (List) value) { + newList.add(deepCopyValue(item)); + } + return newList; + } else if (value instanceof String) { + return new String((String) value); + } else { + return value; + } + } + + protected GetResponse mockEmptyGetResponse() throws IOException { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field("_index", 
"my_index") + .field("_id", "1") + .field("found", false) + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + return GetResponse.fromXContent(contentParser); + } + + protected GetResponse convertToGetResponse(IngestDocument ingestDocument) throws IOException { + String index = ingestDocument.getSourceAndMetadata().get("_index").toString(); + String id = ingestDocument.getSourceAndMetadata().get("_id").toString(); + Map source = ingestDocument.getSourceAndMetadata(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.map(source); + BytesReference bytes = BytesReference.bytes(builder); + GetResult result = new GetResult(index, id, 0, 1, 1, true, bytes, null, null); + return new GetResponse(result); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java index 99f5bf7b0..e0b1ce076 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java @@ -200,7 +200,7 @@ public void doExecute( ) {} @Override - void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { + protected void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { // use to verify if doBatchExecute is called from InferenceProcessor clientAccessor.inferenceSentences(TEXT_INFERENCE_REQUEST, ActionListener.wrap(results -> {}, ex -> {})); allInferenceInputs.add(inferenceList); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java index 156e25ad7..71c65d3e1 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorIT.java @@ -46,11 +46,19 @@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT { protected static final String TEXT_FIELD_VALUE_1 = "hello"; protected static final String TEXT_FIELD_VALUE_2 = "clown"; protected static final String TEXT_FIELD_VALUE_3 = "abc"; + protected static final String TEXT_FIELD_VALUE_4 = "def"; + protected static final String TEXT_FIELD_VALUE_5 = "joker"; + private final String INGEST_DOC1 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc1.json").toURI())); private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI())); private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI())); private final String INGEST_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc4.json").toURI())); private final String INGEST_DOC5 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc5.json").toURI())); + private final String UPDATE_DOC1 = Files.readString(Path.of(classLoader.getResource("processor/update_doc1.json").toURI())); + private final String UPDATE_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/update_doc3.json").toURI())); + private final String UPDATE_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/update_doc4.json").toURI())); + private final String UPDATE_DOC5 = Files.readString(Path.of(classLoader.getResource("processor/update_doc5.json").toURI())); + private final String BULK_ITEM_TEMPLATE = 
Files.readString( Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI()) ); @@ -72,6 +80,17 @@ public void testTextEmbeddingProcessor() throws Exception { assertEquals(1, getDocCount(INDEX_NAME)); } + public void testTextEmbeddingProcessorWithSkipExisting() throws Exception { + String modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_SKIP_EXISTING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME); + ingestDocument(INDEX_NAME, INGEST_DOC1, "1"); + updateDocument(INDEX_NAME, UPDATE_DOC1, "1"); + assertEquals(1, getDocCount(INDEX_NAME)); + assertEquals(2, getDocById(INDEX_NAME, "1").get("_version")); + } + public void testTextEmbeddingProcessor_batch() throws Exception { String modelId = uploadTextEmbeddingModel(); loadModel(modelId); @@ -132,6 +151,86 @@ public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws assertTrue((double) innerHitDetails.get("_score") <= 1.0); } + public void testNestedFieldMapping_whenDocumentsIngested_WithSkipExisting_thenSuccessful() throws Exception { + String modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING_WITH_SKIP_EXISTING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME); + ingestDocument(INDEX_NAME, INGEST_DOC3, "3"); + updateDocument(INDEX_NAME, UPDATE_DOC3, "3"); + ingestDocument(INDEX_NAME, INGEST_DOC4, "4"); + updateDocument(INDEX_NAME, UPDATE_DOC4, "4"); + + assertDoc((Map) getDocById(INDEX_NAME, "3").get("_source"), TEXT_FIELD_VALUE_1, Optional.of(TEXT_FIELD_VALUE_4)); + assertDoc((Map) getDocById(INDEX_NAME, "4").get("_source"), TEXT_FIELD_VALUE_5, Optional.empty()); + + NeuralQueryBuilder neuralQueryBuilderQuery = NeuralQueryBuilder.builder() + .fieldName(LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_CONTAINER + "." + LEVEL_3_FIELD_EMBEDDING) + .queryText(QUERY_TEXT) + .modelId(modelId) + .k(10) + .build(); + + QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery( + LEVEL_1_FIELD + "." 
+ LEVEL_2_FIELD, + neuralQueryBuilderQuery, + ScoreMode.Total + ); + QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total); + + Map searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 2); + assertNotNull(searchResponseAsMap); + + Map hits = (Map) searchResponseAsMap.get("hits"); + assertNotNull(hits); + + assertEquals(1.0, hits.get("max_score")); + List> listOfHits = (List>) hits.get("hits"); + assertNotNull(listOfHits); + assertEquals(2, listOfHits.size()); + + Map innerHitDetails = listOfHits.get(0); + assertEquals("3", innerHitDetails.get("_id")); + assertEquals(1.0, innerHitDetails.get("_score")); + + innerHitDetails = listOfHits.get(1); + assertEquals("4", innerHitDetails.get("_id")); + assertTrue((double) innerHitDetails.get("_score") <= 1.0); + } + + public void testNestedFieldMapping_whenDocumentInListIngestedAndUpdated_WithSkipExisting_thenSuccessful() throws Exception { + String modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING_WITH_SKIP_EXISTING); + createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME); + ingestDocument(INDEX_NAME, INGEST_DOC5, "5"); + updateDocument(INDEX_NAME, UPDATE_DOC5, "5"); + + assertDocWithLevel2AsList((Map) getDocById(INDEX_NAME, "5").get("_source")); + + NeuralQueryBuilder neuralQueryBuilderQuery = NeuralQueryBuilder.builder() + .fieldName(LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_CONTAINER + "." + LEVEL_3_FIELD_EMBEDDING) + .queryText(QUERY_TEXT) + .modelId(modelId) + .k(10) + .build(); + + QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery( + LEVEL_1_FIELD + "." + LEVEL_2_FIELD, + neuralQueryBuilderQuery, + ScoreMode.Total + ); + QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total); + + Map searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 2); + assertNotNull(searchResponseAsMap); + + assertEquals(1, getHitCount(searchResponseAsMap)); + + Map innerHitDetails = getFirstInnerHit(searchResponseAsMap); + assertEquals("5", innerHitDetails.get("_id")); + } + private void assertDoc(Map sourceMap, String textFieldValue, Optional level3ExpectedValue) { assertNotNull(sourceMap); assertTrue(sourceMap.containsKey(LEVEL_1_FIELD)); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 5e89b5d55..e8f2d2a0b 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -10,6 +10,7 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.when; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.isNull; @@ -32,10 +33,14 @@ import org.apache.commons.lang3.tuple.Pair; import org.junit.Before; import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchParseException; +import org.opensearch.action.get.GetAction; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import
org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; @@ -52,6 +57,7 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import org.opensearch.transport.client.OpenSearchClient; public class TextEmbeddingProcessorTests extends InferenceProcessorTestCase { @@ -67,6 +73,10 @@ public class TextEmbeddingProcessorTests extends InferenceProcessorTestCase { protected static final String TEXT_VALUE_2 = "text_value2"; protected static final String TEXT_VALUE_3 = "text_value3"; protected static final String TEXT_FIELD_2 = "abc"; + + @Mock + private OpenSearchClient openSearchClient; + @Mock private MLCommonsClientAccessor mlCommonsClientAccessor; @@ -77,6 +87,10 @@ public class TextEmbeddingProcessorTests extends InferenceProcessorTestCase { @InjectMocks private TextEmbeddingProcessorFactory textEmbeddingProcessorFactory; + + @Captor + private ArgumentCaptor inferenceRequestCaptor; + private static final String PROCESSOR_TAG = "mockTag"; private static final String DESCRIPTION = "mockDescription"; @@ -88,7 +102,7 @@ public void setup() { } @SneakyThrows - private TextEmbeddingProcessor createInstanceWithLevel2MapConfig() { + private TextEmbeddingProcessor createInstanceWithLevel2MapConfig(boolean skipExisting) { Map registry = new HashMap<>(); Map config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); @@ -96,15 +110,17 @@ private TextEmbeddingProcessor createInstanceWithLevel2MapConfig() { TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", ImmutableMap.of("test1", "test1_knn"), "key2", ImmutableMap.of("test3", CHILD_LEVEL_2_KNN_FIELD)) ); + config.put(TextEmbeddingProcessor.SKIP_EXISTING, skipExisting); return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @SneakyThrows - private TextEmbeddingProcessor createInstanceWithLevel1MapConfig() { + private TextEmbeddingProcessor createInstanceWithLevel1MapConfig(boolean skipExisting) { Map registry = new HashMap<>(); Map config = new HashMap<>(); config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1_knn", "key2", "key2_knn")); + config.put(TextEmbeddingProcessor.SKIP_EXISTING, skipExisting); return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @@ -118,6 +134,92 @@ private TextEmbeddingProcessor createInstanceWithLevel1MapConfig(int batchSize) return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } + @SneakyThrows + private TextEmbeddingProcessor createInstanceWithNestedLevelConfig(boolean skipExisting) { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextEmbeddingProcessor.SKIP_EXISTING, skipExisting); + config.put( + TextEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of( + String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_FIELD_LEVEL_2)), + CHILD_LEVEL_2_KNN_FIELD + ) + ); + return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + private TextEmbeddingProcessor createInstanceWithNestedMappingsConfig(boolean skipExisting) { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + 
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextEmbeddingProcessor.SKIP_EXISTING, skipExisting); + config.put( + TextEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of( + String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_FIELD_LEVEL_2)), + CHILD_LEVEL_2_KNN_FIELD + ) + ); + return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + private TextEmbeddingProcessor createInstanceWithNestedSourceAndDestinationConfig(boolean skipExisting) { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextEmbeddingProcessor.SKIP_EXISTING, skipExisting); + config.put( + TextEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of( + String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_1_TEXT_FIELD)), + CHILD_FIELD_LEVEL_2 + "." + CHILD_LEVEL_2_KNN_FIELD + ) + ); + return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + private TextEmbeddingProcessor createInstanceWithNestedMapConfig(boolean skipExisting) { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextEmbeddingProcessor.SKIP_EXISTING, skipExisting); + config.put( + TextEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of( + String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1)), + Map.of(CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_KNN_FIELD) + ) + ); + return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + private TextEmbeddingProcessor createInstanceWithNestedSourceAndDestinationMapConfig(boolean skipExisting) { + Map registry = new HashMap<>(); + Map config = buildObjMap( + Pair.of(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"), + Pair.of( + TextEmbeddingProcessor.FIELD_MAP_FIELD, + buildObjMap( + Pair.of( + PARENT_FIELD, + Map.of( + CHILD_FIELD_LEVEL_1, + Map.of(CHILD_1_TEXT_FIELD, String.join(".", CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_KNN_FIELD)) + ) + ) + ) + ), + Pair.of(TextEmbeddingProcessor.SKIP_EXISTING, skipExisting) + ); + return (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + @SneakyThrows public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalArgumentException() { Map registry = new HashMap<>(); @@ -152,7 +254,7 @@ public void testExecute_successful() { sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(false); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -174,7 +276,9 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); Map registry = new HashMap<>(); MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class); + OpenSearchClient openSearchClient = mock(OpenSearchClient.class); TextEmbeddingProcessorFactory textEmbeddingProcessorFactory = new TextEmbeddingProcessorFactory( + openSearchClient, accessor, 
environment, clusterService @@ -203,7 +307,9 @@ public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() { IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); Map registry = new HashMap<>(); MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class); + OpenSearchClient openSearchClient = mock(OpenSearchClient.class); TextEmbeddingProcessorFactory textEmbeddingProcessorFactory = new TextEmbeddingProcessorFactory( + openSearchClient, accessor, environment, clusterService @@ -233,7 +339,7 @@ public void testExecute_withListTypeInput_successful() { sourceAndMetadata.put("key1", list1); sourceAndMetadata.put("key2", list2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(false); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -252,7 +358,7 @@ public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentExcep sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", " "); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(false); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -265,7 +371,7 @@ public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key1", list1); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(false); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); @@ -278,7 +384,7 @@ public void testExecute_listHasNonStringValue_throwIllegalArgumentException() { sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key2", list2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(false); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -293,7 +399,7 @@ public void testExecute_listHasNull_throwIllegalArgumentException() { sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); sourceAndMetadata.put("key2", list); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(false); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -309,7 +415,7 @@ public void testExecute_withMapTypeInput_successful() { sourceAndMetadata.put("key1", map1); sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = 
createInstanceWithLevel2MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(false); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -335,22 +441,7 @@ public void testNestedFieldInMapping_withMapTypeInput_successful() { sourceAndMetadata.put(PARENT_FIELD, childLevel1); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - Map registry = new HashMap<>(); - Map config = new HashMap<>(); - config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put( - TextEmbeddingProcessor.FIELD_MAP_FIELD, - ImmutableMap.of( - String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_FIELD_LEVEL_2)), - CHILD_LEVEL_2_KNN_FIELD - ) - ); - TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create( - registry, - PROCESSOR_TAG, - DESCRIPTION, - config - ); + TextEmbeddingProcessor processor = createInstanceWithNestedLevelConfig(false); List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { @@ -395,22 +486,7 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentHa sourceAndMetadata.put(PARENT_FIELD, childLevel1); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - Map registry = new HashMap<>(); - Map config = new HashMap<>(); - config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put( - TextEmbeddingProcessor.FIELD_MAP_FIELD, - ImmutableMap.of( - String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_1_TEXT_FIELD)), - CHILD_FIELD_LEVEL_2 + "." + CHILD_LEVEL_2_KNN_FIELD - ) - ); - TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create( - registry, - PROCESSOR_TAG, - DESCRIPTION, - config - ); + TextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(false); List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { @@ -453,22 +529,7 @@ public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWi sourceAndMetadata.put(PARENT_FIELD, childLevel1); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - Map registry = new HashMap<>(); - Map config = new HashMap<>(); - config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put( - TextEmbeddingProcessor.FIELD_MAP_FIELD, - ImmutableMap.of( - String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_1_TEXT_FIELD)), - CHILD_FIELD_LEVEL_2 + "." 
+ CHILD_LEVEL_2_KNN_FIELD - ) - ); - TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create( - registry, - PROCESSOR_TAG, - DESCRIPTION, - config - ); + TextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(false); List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { @@ -526,28 +587,7 @@ public void testNestedFieldInMappingForListWithNestedObj_withIngestDocumentWitho ); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - Map registry = new HashMap<>(); - Map config = buildObjMap( - Pair.of(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"), - Pair.of( - TextEmbeddingProcessor.FIELD_MAP_FIELD, - buildObjMap( - Pair.of( - PARENT_FIELD, - Map.of( - CHILD_FIELD_LEVEL_1, - Map.of(CHILD_1_TEXT_FIELD, String.join(".", CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_KNN_FIELD)) - ) - ) - ) - ) - ); - TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create( - registry, - PROCESSOR_TAG, - DESCRIPTION, - config - ); + TextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationMapConfig(false); List> modelTensorList = createRandomOneDimensionalMockVector(2, 100, 0.0f, 1.0f); doAnswer(invocation -> { @@ -588,22 +628,7 @@ public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_successful() { sourceAndMetadata.put(PARENT_FIELD, childLevel1); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - Map registry = new HashMap<>(); - Map config = new HashMap<>(); - config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put( - TextEmbeddingProcessor.FIELD_MAP_FIELD, - ImmutableMap.of( - String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1)), - Map.of(CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_KNN_FIELD) - ) - ); - TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create( - registry, - PROCESSOR_TAG, - DESCRIPTION, - config - ); + TextEmbeddingProcessor processor = createInstanceWithNestedMapConfig(false); List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); doAnswer(invocation -> { @@ -634,7 +659,7 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() { sourceAndMetadata.put("key1", map1); sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(false); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -648,7 +673,7 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() { sourceAndMetadata.put("key1", map1); sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(false); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -661,7 +686,7 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { sourceAndMetadata.put("key1", "hello world"); 
sourceAndMetadata.put("key2", ret); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(false); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -673,7 +698,7 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(false); doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(1); @@ -704,7 +729,7 @@ public void testExecute_hybridTypeInput_successful() throws Exception { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key2", map1); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(false); IngestDocument document = processor.execute(ingestDocument); assert document.getSourceAndMetadata().containsKey("key2"); } @@ -722,13 +747,13 @@ public void testExecute_simpleTypeInputWithNonStringValue_handleIllegalArgumentE }).when(mlCommonsClientAccessor).inferenceSentences(argThat(request -> request.getInputTexts() != null), isA(ActionListener.class)); BiConsumer handler = mock(BiConsumer.class); - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(false); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testGetType_successful() { - TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(false); assert processor.getType().equals(TextEmbeddingProcessor.TYPE); } @@ -1115,6 +1140,876 @@ public void test_batchExecute_exception() { } } + @SneakyThrows + public void testExecute_when_initial_ingest_with_skip_existing_flag_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("_id", "1"); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + List inferenceList = Arrays.asList("value1", "value2"); + TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(true); + + GetResponse response = mockEmptyGetResponse(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(openSearchClient).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(1); + listener.onResponse(modelTensorList); + return null; + 
}).when(mlCommonsClientAccessor).inferenceSentences(isA(TextInferenceRequest.class), isA(ActionListener.class)); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + verify(openSearchClient, times(1)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(1)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + assertEquals(request.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); + } + + @SneakyThrows + public void testExecute_with_no_update_with_skip_existing_flag_successful() { + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + ingestSourceAndMetadata.put("_id", "1"); + ingestSourceAndMetadata.put("key1", "value1"); + ingestSourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + List inferenceList = Arrays.asList("value1", "value2"); + TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); // no change + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(true); + mockUpdateDocument(ingestDocument); + mockVectorCreation(ingestRequest, null); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); // insert document + processor.execute(updateDocument, handler); // update document + + verify(handler, times(2)).accept(any(IngestDocument.class), isNull()); + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(1)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + assertEquals(ingestRequest.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); + List key1insertVectors = (List) ingestDocument.getSourceAndMetadata().get("key1_knn"); + List key1updateVectors = (List) updateDocument.getSourceAndMetadata().get("key1_knn"); + List key2insertVectors = (List) ingestDocument.getSourceAndMetadata().get("key2_knn"); + List key2updateVectors = (List) updateDocument.getSourceAndMetadata().get("key2_knn"); + verifyEqualEmbedding(key1insertVectors, key1updateVectors); + verifyEqualEmbedding(key2insertVectors, key2updateVectors); + } + + public void testExecute_with_updated_field_skip_existing_flag_successful() { + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + ingestSourceAndMetadata.put("_id", "1"); + ingestSourceAndMetadata.put("key1", "value1"); + ingestSourceAndMetadata.put("key2", "value2"); + List inferenceList = Arrays.asList("value1", "value2"); + TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); + updateSourceAndMetadata.put("key2", "newValue"); // updated + List filteredInferenceList = List.of("newValue"); + TextInferenceRequest updateRequest = TextInferenceRequest.builder() + 
.modelId("mockModelId") + .inputTexts(filteredInferenceList) + .build(); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(true); + mockUpdateDocument(ingestDocument); + mockVectorCreation(ingestRequest, updateRequest); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); // insert + processor.execute(updateDocument, handler); // update + + verify(handler, times(2)).accept(any(IngestDocument.class), isNull()); + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(2)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + List requests = inferenceRequestCaptor.getAllValues(); + assertEquals(ingestRequest.getInputTexts(), requests.get(0).getInputTexts()); + assertEquals(updateRequest.getInputTexts(), requests.get(1).getInputTexts()); + + List key1insertVectors = (List) ingestDocument.getSourceAndMetadata().get("key1_knn"); + List key1updateVectors = (List) updateDocument.getSourceAndMetadata().get("key1_knn"); + List key2insertVectors = (List) ingestDocument.getSourceAndMetadata().get("key2_knn"); + List key2updateVectors = (List) updateDocument.getSourceAndMetadata().get("key2_knn"); + verifyEqualEmbedding(key1insertVectors, key1updateVectors); + assertEquals(key2insertVectors.size(), key2updateVectors.size()); + } + + public void testExecute_withListTypeInput_no_update_skip_existing_flag_successful() { + List list1 = ImmutableList.of("test1", "test2", "test3"); + List list2 = ImmutableList.of("test4", "test5", "test6"); + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + ingestSourceAndMetadata.put("_id", "1"); + ingestSourceAndMetadata.put("key1", list1); + ingestSourceAndMetadata.put("key2", list2); + List inferenceList = Arrays.asList("test1", "test2", "test3", "test4", "test5", "test6"); + TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(true); + mockUpdateDocument(ingestDocument); + mockVectorCreation(ingestRequest, null); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + processor.execute(updateDocument, handler); + + verify(handler, times(2)).accept(any(IngestDocument.class), isNull()); + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(1)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + assertEquals(ingestRequest.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); + List key1insertVectors = (List) ingestDocument.getSourceAndMetadata().get("key1_knn"); + List key1updateVectors = (List) updateDocument.getSourceAndMetadata().get("key1_knn"); + List key2insertVectors = (List) ingestDocument.getSourceAndMetadata().get("key2_knn"); + List key2updateVectors = (List) updateDocument.getSourceAndMetadata().get("key2_knn"); + 
verifyEqualEmbeddingInMap(key1insertVectors, key1updateVectors); + verifyEqualEmbeddingInMap(key2insertVectors, key2updateVectors); + } + + public void testExecute_withListTypeInput_with_update_skip_existing_flag_successful() { + List list1 = ImmutableList.of("test1", "test2", "test3"); + List list2 = ImmutableList.of("test4", "test5", "test6"); + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + ingestSourceAndMetadata.put("_id", "1"); + ingestSourceAndMetadata.put("key1", list1); + ingestSourceAndMetadata.put("key2", list2); + List inferenceList = Arrays.asList("test1", "test2", "test3", "test4", "test5", "test6"); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); + updateSourceAndMetadata.put("key1", ImmutableList.of("test1", "newValue1", "newValue2")); + updateSourceAndMetadata.put("key2", ImmutableList.of("newValue3", "test5", "test6")); + List filteredInferenceList = Arrays.asList("test1", "newValue1", "newValue2", "newValue3", "test5", "test6"); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + TextInferenceRequest updateRequest = TextInferenceRequest.builder() + .modelId("mockModelId") + .inputTexts(filteredInferenceList) + .build(); + + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(true); + mockUpdateDocument(ingestDocument); + mockVectorCreation(ingestRequest, updateRequest); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + processor.execute(updateDocument, handler); + + verify(handler, times(2)).accept(any(IngestDocument.class), isNull()); + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(2)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + List requests = inferenceRequestCaptor.getAllValues(); + assertEquals(ingestRequest.getInputTexts(), requests.get(0).getInputTexts()); + assertEquals(updateRequest.getInputTexts(), requests.get(1).getInputTexts()); + List key1insertVectors = (List) ingestDocument.getSourceAndMetadata().get("key1_knn"); + List key1updateVectors = (List) updateDocument.getSourceAndMetadata().get("key1_knn"); + List key2insertVectors = (List) ingestDocument.getSourceAndMetadata().get("key2_knn"); + List key2updateVectors = (List) updateDocument.getSourceAndMetadata().get("key2_knn"); + assertEquals(key1insertVectors.size(), key1updateVectors.size()); + assertEquals(key2insertVectors.size(), key2updateVectors.size()); + } + + public void testExecute_withNestedListTypeInput_no_update_skip_existing_flag_successful() { + Map> map1 = new HashMap<>(); + map1.put("test1", ImmutableList.of("test1", "test2", "test3")); + Map> map2 = new HashMap<>(); + map2.put("test3", ImmutableList.of("test4", "test5", "test6")); + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + ingestSourceAndMetadata.put("_id", "1"); + ingestSourceAndMetadata.put("key1", map1); + ingestSourceAndMetadata.put("key2", map2); + List inferenceList = Arrays.asList("test1", "test2", "test3", "test4", "test5", "test6"); + IngestDocument ingestDocument = new 
IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(true); + mockUpdateDocument(ingestDocument); + mockVectorCreation(ingestRequest, null); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + processor.execute(updateDocument, handler); + + verify(handler, times(2)).accept(any(IngestDocument.class), isNull()); + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(1)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + assertEquals(ingestRequest.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); + List key1IngestVectors = ((List) ((Map) ingestDocument.getSourceAndMetadata().get("key1")).get("test1_knn")); + List key1UpdateVectors = ((List) ((Map) updateDocument.getSourceAndMetadata().get("key1")).get("test1_knn")); + List key2IngestVectors = ((List) ((Map) ingestDocument.getSourceAndMetadata().get("key2")).get("test3_knn")); + List key2UpdateVectors = ((List) ((Map) updateDocument.getSourceAndMetadata().get("key2")).get("test3_knn")); + assertEquals(key1IngestVectors.size(), key1UpdateVectors.size()); + assertEquals(key2IngestVectors.size(), key2UpdateVectors.size()); + verifyEqualEmbeddingInMap(key1IngestVectors, key1UpdateVectors); + verifyEqualEmbeddingInMap(key2IngestVectors, key2UpdateVectors); + } + + public void testExecute_withNestedListTypeInput_with_update_skip_existing_flag_successful() { + Map> map1 = new HashMap<>(); + map1.put("test1", ImmutableList.of("test1", "test2", "test3")); + Map> map2 = new HashMap<>(); + map2.put("test3", ImmutableList.of("test4", "test5", "test6")); + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + ingestSourceAndMetadata.put("_id", "1"); + ingestSourceAndMetadata.put("key1", map1); + ingestSourceAndMetadata.put("key2", map2); + List inferenceList = Arrays.asList("test1", "test2", "test3", "test4", "test5", "test6"); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); + ((Map) updateSourceAndMetadata.get("key1")).put("test1", ImmutableList.of("test1", "newValue1", "newValue2")); + ((Map) updateSourceAndMetadata.get("key2")).put("test3", ImmutableList.of("newValue3", "test5", "test6")); + + List filteredInferenceList = Arrays.asList("test1", "newValue1", "newValue2", "newValue3", "test5", "test6"); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + TextInferenceRequest updateRequest = TextInferenceRequest.builder() + .modelId("mockModelId") + .inputTexts(filteredInferenceList) + .build(); + + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(true); + mockUpdateDocument(ingestDocument); + mockVectorCreation(ingestRequest, updateRequest); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, 
handler); + processor.execute(updateDocument, handler); + + verify(handler, times(2)).accept(any(IngestDocument.class), isNull()); + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(2)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + List requests = inferenceRequestCaptor.getAllValues(); + assertEquals(ingestRequest.getInputTexts(), requests.get(0).getInputTexts()); + assertEquals(updateRequest.getInputTexts(), requests.get(1).getInputTexts()); + + List key1IngestVectors = ((List) ((Map) ingestDocument.getSourceAndMetadata().get("key1")).get("test1_knn")); + List key1UpdateVectors = ((List) ((Map) updateDocument.getSourceAndMetadata().get("key1")).get("test1_knn")); + List key2IngestVectors = ((List) ((Map) ingestDocument.getSourceAndMetadata().get("key2")).get("test3_knn")); + List key2UpdateVectors = ((List) ((Map) updateDocument.getSourceAndMetadata().get("key2")).get("test3_knn")); + assertEquals(key1IngestVectors.size(), key1UpdateVectors.size()); + assertEquals(key2IngestVectors.size(), key2UpdateVectors.size()); + } + + public void testExecute_withMapTypeInput_no_update_skip_existing_flag_successful() { + Map map1 = new HashMap<>(); + map1.put("test1", "test2"); + Map map2 = new HashMap<>(); + map2.put("test3", "test4"); + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + ingestSourceAndMetadata.put("_id", "1"); + ingestSourceAndMetadata.put("key1", map1); + ingestSourceAndMetadata.put("key2", map2); + List inferenceList = Arrays.asList("test2", "test4"); + TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(true); + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + + mockUpdateDocument(ingestDocument); + mockVectorCreation(request, null); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + processor.execute(updateDocument, handler); + + verify(handler, times(2)).accept(any(IngestDocument.class), isNull()); + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(1)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + assertEquals(request.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); + List test1KnnInsertVectors = (List) ((Map) ingestDocument.getSourceAndMetadata().get("key1")).get("test1_knn"); + List test1KnnUpdateVectors = (List) ((Map) updateDocument.getSourceAndMetadata().get("key1")).get("test1_knn"); + List test3KnnInsertVectors = (List) ((Map) ingestDocument.getSourceAndMetadata().get("key2")).get("test3_knn"); + List test3KnnUpdateVectors = (List) ((Map) updateDocument.getSourceAndMetadata().get("key2")).get("test3_knn"); + verifyEqualEmbedding(test1KnnInsertVectors, test1KnnUpdateVectors); + verifyEqualEmbedding(test3KnnInsertVectors, test3KnnUpdateVectors); + } + + public void testExecute_withMapTypeInput_with_update_successful() { + Map map1 = new HashMap<>(); + map1.put("test1", "test2"); + Map map2 = new HashMap<>(); + 
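+ // "key2"/"test3" text is left unchanged by the update below, so its embedding should be copied rather than re-generated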
map2.put("test3", "test4"); + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + ingestSourceAndMetadata.put("_id", "1"); + ingestSourceAndMetadata.put("key1", map1); + ingestSourceAndMetadata.put("key2", map2); + List inferenceList = Arrays.asList("test2", "test4"); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(true); + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); + ((Map) updateSourceAndMetadata.get("key1")).put("test1", "newValue1"); + List filteredInferenceList = Arrays.asList("newValue1"); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + TextInferenceRequest updateRequest = TextInferenceRequest.builder() + .modelId("mockModelId") + .inputTexts(filteredInferenceList) + .build(); + + mockUpdateDocument(ingestDocument); + mockVectorCreation(ingestRequest, updateRequest); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + processor.execute(updateDocument, handler); + + verify(handler, times(2)).accept(any(IngestDocument.class), isNull()); + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(2)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + List requests = inferenceRequestCaptor.getAllValues(); + assertEquals(ingestRequest.getInputTexts(), requests.get(0).getInputTexts()); + assertEquals(updateRequest.getInputTexts(), requests.get(1).getInputTexts()); + List test1KnnInsertVectors = (List) ((Map) ingestDocument.getSourceAndMetadata().get("key2")).get("test3_knn"); + List test1KnnUpdateVectors = (List) ((Map) updateDocument.getSourceAndMetadata().get("key2")).get("test3_knn"); + verifyEqualEmbedding(test1KnnInsertVectors, test1KnnUpdateVectors); + } + + @SneakyThrows + public void testNestedFieldInMapping_withMapTypeInput_no_update_successful() { + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + ingestSourceAndMetadata.put("_id", "1"); + Map childLevel2 = new HashMap<>(); + childLevel2.put(CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_TEXT_FIELD_VALUE); + Map childLevel1 = new HashMap<>(); + childLevel1.put(CHILD_FIELD_LEVEL_1, childLevel2); + ingestSourceAndMetadata.put(PARENT_FIELD, childLevel1); + List inferenceList = Arrays.asList(CHILD_LEVEL_2_TEXT_FIELD_VALUE); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + TextEmbeddingProcessor processor = createInstanceWithNestedLevelConfig(true); + + mockUpdateDocument(ingestDocument); + mockVectorCreation(request, null); + + processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); + processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); + + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + 
verify(mlCommonsClientAccessor, times(1)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + assertEquals(request.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); + List test3KnnInsertVectors = (List) (((Map) ((Map) ingestDocument.getSourceAndMetadata().get(PARENT_FIELD)).get( + CHILD_FIELD_LEVEL_1 + ))).get(CHILD_LEVEL_2_KNN_FIELD); + List test3KnnUpdateVectors = (List) (((Map) ((Map) updateDocument.getSourceAndMetadata().get(PARENT_FIELD)).get( + CHILD_FIELD_LEVEL_1 + ))).get(CHILD_LEVEL_2_KNN_FIELD); + verifyEqualEmbedding(test3KnnInsertVectors, test3KnnUpdateVectors); + } + + @SneakyThrows + public void testNestedFieldInMapping_withMapTypeInput_with_update_successful() { + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + ingestSourceAndMetadata.put("_id", "1"); + Map childLevel2 = new HashMap<>(); + childLevel2.put(CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_TEXT_FIELD_VALUE); + Map childLevel1 = new HashMap<>(); + childLevel1.put(CHILD_FIELD_LEVEL_1, childLevel2); + ingestSourceAndMetadata.put(PARENT_FIELD, childLevel1); + List inferenceList = Arrays.asList(CHILD_LEVEL_2_TEXT_FIELD_VALUE); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); + ((Map) ((Map) updateSourceAndMetadata.get(PARENT_FIELD)).get(CHILD_FIELD_LEVEL_1)).put(CHILD_FIELD_LEVEL_2, "newValue"); + List filteredInferenceList = Arrays.asList("newValue"); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + TextInferenceRequest updateRequest = TextInferenceRequest.builder() + .modelId("mockModelId") + .inputTexts(filteredInferenceList) + .build(); + + TextEmbeddingProcessor processor = createInstanceWithNestedMappingsConfig(true); + + mockUpdateDocument(ingestDocument); + mockVectorCreation(ingestRequest, updateRequest); + + processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); + processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); + + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(2)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + List requests = inferenceRequestCaptor.getAllValues(); + assertEquals(ingestRequest.getInputTexts(), requests.get(0).getInputTexts()); + assertEquals(updateRequest.getInputTexts(), requests.get(1).getInputTexts()); + List test3KnnInsertVectors = (List) (((Map) ((Map) ingestDocument.getSourceAndMetadata().get(PARENT_FIELD)).get( + CHILD_FIELD_LEVEL_1 + ))).get(CHILD_LEVEL_2_KNN_FIELD); + List test3KnnUpdateVectors = (List) (((Map) ((Map) updateDocument.getSourceAndMetadata().get(PARENT_FIELD)).get( + CHILD_FIELD_LEVEL_1 + ))).get(CHILD_LEVEL_2_KNN_FIELD); + assertEquals(test3KnnInsertVectors.size(), test3KnnUpdateVectors.size()); + } + + @SneakyThrows + public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentHasTheDestinationStructure_no_update_thenSuccessful() { + /* + modeling following document: + parent: + child_level_1: + child_level_1_text_field: "text" + child_level_2: + child_level_2_text_field: "abc" + */ + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, 
"my_index"); + ingestSourceAndMetadata.put("_id", "1"); + Map childLevel2NestedField = new HashMap<>(); + childLevel2NestedField.put(CHILD_LEVEL_2_TEXT_FIELD_VALUE, TEXT_FIELD_2); + Map childLevel2 = new HashMap<>(); + childLevel2.put(CHILD_FIELD_LEVEL_2, childLevel2NestedField); + childLevel2.put(CHILD_1_TEXT_FIELD, TEXT_VALUE_1); + Map childLevel1 = new HashMap<>(); + childLevel1.put(CHILD_FIELD_LEVEL_1, childLevel2); + ingestSourceAndMetadata.put(PARENT_FIELD, childLevel1); + List inferenceList = Arrays.asList(TEXT_VALUE_1); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + + TextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(true); + + mockUpdateDocument(ingestDocument); + mockVectorCreation(request, null); + + processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); + processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); + + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(1)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + assertEquals(request.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); + Map nestedIngestMap = (Map) (((Map) ((Map) ingestDocument.getSourceAndMetadata().get(PARENT_FIELD)).get(CHILD_FIELD_LEVEL_1))).get( + CHILD_FIELD_LEVEL_2 + ); + Map nestedUpdateMap = (Map) (((Map) ((Map) updateDocument.getSourceAndMetadata().get(PARENT_FIELD)).get(CHILD_FIELD_LEVEL_1))).get( + CHILD_FIELD_LEVEL_2 + ); + List test3KnnIngestVectors = (List) nestedIngestMap.get(CHILD_LEVEL_2_KNN_FIELD); + List test3KnnUpdateVectors = (List) nestedUpdateMap.get(CHILD_LEVEL_2_KNN_FIELD); + + verifyEqualEmbedding(test3KnnIngestVectors, test3KnnUpdateVectors); + } + + @SneakyThrows + public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentHasTheDestinationStructure_with_update_thenSuccessful() { + /* + modeling following document: + parent: + child_level_1: + child_level_1_text_field: "text" + child_level_2: + child_level_2_text_field: "abc" + */ + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + ingestSourceAndMetadata.put("_id", "1"); + Map childLevel2NestedField = new HashMap<>(); + childLevel2NestedField.put(CHILD_LEVEL_2_TEXT_FIELD_VALUE, TEXT_FIELD_2); + Map childLevel2 = new HashMap<>(); + childLevel2.put(CHILD_FIELD_LEVEL_2, childLevel2NestedField); + childLevel2.put(CHILD_1_TEXT_FIELD, TEXT_VALUE_1); + Map childLevel1 = new HashMap<>(); + childLevel1.put(CHILD_FIELD_LEVEL_1, childLevel2); + ingestSourceAndMetadata.put(PARENT_FIELD, childLevel1); + List inferenceList = Arrays.asList(TEXT_VALUE_1); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); + ((Map) ((Map) updateSourceAndMetadata.get(PARENT_FIELD)).get(CHILD_FIELD_LEVEL_1)).put(CHILD_1_TEXT_FIELD, "newValue"); + IngestDocument updateDocument = new 
IngestDocument(updateSourceAndMetadata, new HashMap<>()); + List filteredInferenceList = Arrays.asList("newValue"); + TextInferenceRequest updateRequest = TextInferenceRequest.builder() + .modelId("mockModelId") + .inputTexts(filteredInferenceList) + .build(); + TextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(true); + + mockUpdateDocument(ingestDocument); + mockVectorCreation(ingestRequest, updateRequest); + + processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); + processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); + + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(2)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + List requests = inferenceRequestCaptor.getAllValues(); + assertEquals(ingestRequest.getInputTexts(), requests.get(0).getInputTexts()); + assertEquals(updateRequest.getInputTexts(), requests.get(1).getInputTexts()); + Map nestedIngestMap = (Map) (((Map) ((Map) ingestDocument.getSourceAndMetadata().get(PARENT_FIELD)).get(CHILD_FIELD_LEVEL_1))).get( + CHILD_FIELD_LEVEL_2 + ); + Map nestedUpdateMap = (Map) (((Map) ((Map) updateDocument.getSourceAndMetadata().get(PARENT_FIELD)).get(CHILD_FIELD_LEVEL_1))).get( + CHILD_FIELD_LEVEL_2 + ); + List test3KnnIngestVectors = (List) nestedIngestMap.get(CHILD_LEVEL_2_KNN_FIELD); + List test3KnnUpdateVectors = (List) nestedUpdateMap.get(CHILD_LEVEL_2_KNN_FIELD); + + assertEquals(test3KnnIngestVectors.size(), test3KnnUpdateVectors.size()); + } + + @SneakyThrows + public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWithoutDestinationStructure_no_update_thenSuccessful() { + /* + modeling following document: + parent: + child_level_1: + child_level_1_text_field: "text" + */ + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + ingestSourceAndMetadata.put("_id", "1"); + Map childLevel2 = new HashMap<>(); + childLevel2.put(CHILD_1_TEXT_FIELD, TEXT_VALUE_1); + Map childLevel1 = new HashMap<>(); + childLevel1.put(CHILD_FIELD_LEVEL_1, childLevel2); + ingestSourceAndMetadata.put(PARENT_FIELD, childLevel1); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + List inferenceList = Arrays.asList(TEXT_VALUE_1); + TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + TextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(true); + mockUpdateDocument(ingestDocument); + mockVectorCreation(request, null); + + processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); + processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); + + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(1)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + assertEquals(request.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); + Map parent1AfterProcessor = (Map) ingestDocument.getSourceAndMetadata().get(PARENT_FIELD); + Map childLevel1Actual = (Map) parent1AfterProcessor.get(CHILD_FIELD_LEVEL_1); + Map child2Actual = (Map) 
childLevel1Actual.get(CHILD_FIELD_LEVEL_2); + Map updateParent1AfterProcessor = (Map) updateDocument.getSourceAndMetadata().get(PARENT_FIELD); + Map updateChildLevel1Actual = (Map) updateParent1AfterProcessor.get(CHILD_FIELD_LEVEL_1); + Map updateChild2Actual = (Map) updateChildLevel1Actual.get(CHILD_FIELD_LEVEL_2); + List ingestVectors = (List) child2Actual.get(CHILD_LEVEL_2_KNN_FIELD); + List updateVectors = (List) updateChild2Actual.get(CHILD_LEVEL_2_KNN_FIELD); + verifyEqualEmbedding(ingestVectors, updateVectors); + + } + + @SneakyThrows + public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWithoutDestinationStructure_with_update_thenSuccessful() { + /* + modeling following document: + parent: + child_level_1: + child_level_1_text_field: "text" + */ + Map ingestSourceAndMetadata = new HashMap<>(); + ingestSourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + ingestSourceAndMetadata.put("_id", "1"); + Map childLevel2 = new HashMap<>(); + childLevel2.put(CHILD_1_TEXT_FIELD, TEXT_VALUE_1); + Map childLevel1 = new HashMap<>(); + childLevel1.put(CHILD_FIELD_LEVEL_1, childLevel2); + ingestSourceAndMetadata.put(PARENT_FIELD, childLevel1); + IngestDocument ingestDocument = new IngestDocument(ingestSourceAndMetadata, new HashMap<>()); + List inferenceList = Arrays.asList(TEXT_VALUE_1); + TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + Map updateSourceAndMetadata = deepCopy(ingestSourceAndMetadata); + ((Map) ((Map) updateSourceAndMetadata.get(PARENT_FIELD)).get(CHILD_FIELD_LEVEL_1)).put(CHILD_1_TEXT_FIELD, "newValue"); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + List filteredInferenceList = Arrays.asList("newValue"); + TextInferenceRequest updateRequest = TextInferenceRequest.builder() + .modelId("mockModelId") + .inputTexts(filteredInferenceList) + .build(); + TextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationConfig(true); + mockUpdateDocument(ingestDocument); + mockVectorCreation(ingestRequest, updateRequest); + + processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); + processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); + + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(2)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + List requests = inferenceRequestCaptor.getAllValues(); + assertEquals(ingestRequest.getInputTexts(), requests.get(0).getInputTexts()); + assertEquals(updateRequest.getInputTexts(), requests.get(1).getInputTexts()); + Map parent1AfterProcessor = (Map) ingestDocument.getSourceAndMetadata().get(PARENT_FIELD); + Map childLevel1Actual = (Map) parent1AfterProcessor.get(CHILD_FIELD_LEVEL_1); + Map child2Actual = (Map) childLevel1Actual.get(CHILD_FIELD_LEVEL_2); + Map updateParent1AfterProcessor = (Map) updateDocument.getSourceAndMetadata().get(PARENT_FIELD); + Map updateChildLevel1Actual = (Map) updateParent1AfterProcessor.get(CHILD_FIELD_LEVEL_1); + Map updateChild2Actual = (Map) updateChildLevel1Actual.get(CHILD_FIELD_LEVEL_2); + List ingestVectors = (List) child2Actual.get(CHILD_LEVEL_2_KNN_FIELD); + List updateVectors = (List) updateChild2Actual.get(CHILD_LEVEL_2_KNN_FIELD); + assertEquals(ingestVectors.size(), updateVectors.size()); + + } + + @SneakyThrows + @SuppressWarnings("unchecked") + public void 
testNestedFieldInMappingForListWithNestedObj_withIngestDocumentWithoutDestinationStructure_no_update_thenSuccessful() { + /* + modeling following document: + parent: [ + { + child_level_1: + child_1_text_field: "text_value", + }, + { + child_level_1: + child_1_text_field: "text_value", + child_2_text_field: "text_value2", + child_3_text_field: "text_value3", + } + + ] + */ + Map child1Level2 = buildObjMap(Pair.of(CHILD_1_TEXT_FIELD, TEXT_VALUE_1)); + Map child1Level1 = buildObjMap(Pair.of(CHILD_FIELD_LEVEL_1, child1Level2)); + Map child2Level2 = buildObjMap( + Pair.of(CHILD_1_TEXT_FIELD, TEXT_VALUE_1), + Pair.of(CHILD_2_TEXT_FIELD, TEXT_VALUE_2), + Pair.of(CHILD_3_TEXT_FIELD, TEXT_VALUE_3) + ); + Map child2Level1 = buildObjMap(Pair.of(CHILD_FIELD_LEVEL_1, child2Level2)); + Map sourceAndMetadata = buildObjMap( + Pair.of(PARENT_FIELD, Arrays.asList(child1Level1, child2Level1)), + Pair.of(IndexFieldMapper.NAME, "my_index"), + Pair.of("_id", "1") + ); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + Map updateSourceAndMetadata = deepCopy(sourceAndMetadata); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + + TextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationMapConfig(true); + List inferenceList = Arrays.asList(TEXT_VALUE_1, TEXT_VALUE_1); + TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelId").inputTexts(inferenceList).build(); + mockUpdateDocument(ingestDocument); + mockVectorCreation(request, null); + + processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); + processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); + + List> parentAfterIngestProcessor = (List>) ingestDocument.getSourceAndMetadata() + .get(PARENT_FIELD); + + List> parentAfterUpdateProcessor = (List>) updateDocument.getSourceAndMetadata() + .get(PARENT_FIELD); + List> insertVectors = new ArrayList<>(); + List> updateVectors = new ArrayList<>(); + for (Map childActual : parentAfterIngestProcessor) { + Map childLevel1Actual = (Map) childActual.get(CHILD_FIELD_LEVEL_1); + assertEquals(TEXT_VALUE_1, childLevel1Actual.get(CHILD_1_TEXT_FIELD)); + assertNotNull(childLevel1Actual.get(CHILD_FIELD_LEVEL_2)); + Map childLevel2Actual = (Map) childLevel1Actual.get(CHILD_FIELD_LEVEL_2); + insertVectors.add((List) childLevel2Actual.get(CHILD_LEVEL_2_KNN_FIELD)); + } + + for (Map childActual : parentAfterUpdateProcessor) { + Map childLevel1Actual = (Map) childActual.get(CHILD_FIELD_LEVEL_1); + assertEquals(TEXT_VALUE_1, childLevel1Actual.get(CHILD_1_TEXT_FIELD)); + assertNotNull(childLevel1Actual.get(CHILD_FIELD_LEVEL_2)); + Map childLevel2Actual = (Map) childLevel1Actual.get(CHILD_FIELD_LEVEL_2); + updateVectors.add((List) childLevel2Actual.get(CHILD_LEVEL_2_KNN_FIELD)); + } + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(1)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + assertEquals(request.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); + verifyEqualEmbeddingInNestedList(insertVectors, updateVectors); + } + + @SneakyThrows + @SuppressWarnings("unchecked") + public void testNestedFieldInMappingForListWithNestedObj_withIngestDocumentWithoutDestinationStructure_with_update_thenSuccessful() { + /* + modeling following document: + parent: [ + { + child_level_1: + child_1_text_field: "text_value", + }, + {
child_level_1: + child_1_text_field: "text_value", + child_2_text_field: "text_value2", + child_3_text_field: "text_value3", + } + + ] + */ + Map child1Level2 = buildObjMap(Pair.of(CHILD_1_TEXT_FIELD, TEXT_VALUE_1)); + Map child1Level1 = buildObjMap(Pair.of(CHILD_FIELD_LEVEL_1, child1Level2)); + Map child2Level2 = buildObjMap( + Pair.of(CHILD_1_TEXT_FIELD, TEXT_VALUE_1), + Pair.of(CHILD_2_TEXT_FIELD, TEXT_VALUE_2), + Pair.of(CHILD_3_TEXT_FIELD, TEXT_VALUE_3) + ); + Map child2Level1 = buildObjMap(Pair.of(CHILD_FIELD_LEVEL_1, child2Level2)); + Map sourceAndMetadata = buildObjMap( + Pair.of(PARENT_FIELD, Arrays.asList(child1Level1, child2Level1)), + Pair.of(IndexFieldMapper.NAME, "my_index"), + Pair.of("_id", "1") + ); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + Map updateSourceAndMetadata = deepCopy(sourceAndMetadata); + ((Map) ((Map) ((List) updateSourceAndMetadata.get(PARENT_FIELD)).get(0)).get(CHILD_FIELD_LEVEL_1)).put( + CHILD_1_TEXT_FIELD, + "newValue" + ); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + + TextEmbeddingProcessor processor = createInstanceWithNestedSourceAndDestinationMapConfig(true); + List inferenceList = Arrays.asList(TEXT_VALUE_1, TEXT_VALUE_1); + TextInferenceRequest ingestRequest = TextInferenceRequest.builder().modelId("mockModelID").inputTexts(inferenceList).build(); + List filteredInferenceList = Arrays.asList("newValue"); + TextInferenceRequest updateRequest = TextInferenceRequest.builder() + .modelId("mockModelID") + .inputTexts(filteredInferenceList) + .build(); + mockUpdateDocument(ingestDocument); + mockVectorCreation(ingestRequest, updateRequest); + + processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); + processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); + + List> parentAfterIngestProcessor = (List>) ingestDocument.getSourceAndMetadata() + .get(PARENT_FIELD); + + List> parentAfterUpdateProcessor = (List>) updateDocument.getSourceAndMetadata() + .get(PARENT_FIELD); + List> insertVectors = new ArrayList<>(); + List> updateVectors = new ArrayList<>(); + for (Map childActual : parentAfterIngestProcessor) { + Map childLevel1Actual = (Map) childActual.get(CHILD_FIELD_LEVEL_1); + assertNotNull(childLevel1Actual.get(CHILD_FIELD_LEVEL_2)); + Map childLevel2Actual = (Map) childLevel1Actual.get(CHILD_FIELD_LEVEL_2); + insertVectors.add((List) childLevel2Actual.get(CHILD_LEVEL_2_KNN_FIELD)); + } + + for (Map childActual : parentAfterUpdateProcessor) { + Map childLevel1Actual = (Map) childActual.get(CHILD_FIELD_LEVEL_1); + assertNotNull(childLevel1Actual.get(CHILD_FIELD_LEVEL_2)); + Map childLevel2Actual = (Map) childLevel1Actual.get(CHILD_FIELD_LEVEL_2); + updateVectors.add((List) childLevel2Actual.get(CHILD_LEVEL_2_KNN_FIELD)); + } + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(2)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + List requests = inferenceRequestCaptor.getAllValues(); + assertEquals(ingestRequest.getInputTexts(), requests.get(0).getInputTexts()); + assertEquals(updateRequest.getInputTexts(), requests.get(1).getInputTexts()); + assertEquals(insertVectors.get(0).size(), updateVectors.get(0).size()); + verifyEqualEmbedding(insertVectors.get(1), updateVectors.get(1)); + } + + @SneakyThrows + public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_no_update_successful() { + Map 
sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("_id", "1"); + Map childLevel2 = new HashMap<>(); + childLevel2.put(CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_TEXT_FIELD_VALUE); + Map childLevel1 = new HashMap<>(); + childLevel1.put(CHILD_FIELD_LEVEL_1, childLevel2); + sourceAndMetadata.put(PARENT_FIELD, childLevel1); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + Map updateSourceAndMetadata = deepCopy(sourceAndMetadata); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + + TextEmbeddingProcessor processor = createInstanceWithNestedMapConfig(true); + List inferenceList = Arrays.asList(CHILD_LEVEL_2_TEXT_FIELD_VALUE); + TextInferenceRequest request = TextInferenceRequest.builder().modelId("mockModelID").inputTexts(inferenceList).build(); + mockUpdateDocument(ingestDocument); + mockVectorCreation(request, null); + + processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); + processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); + Map childLevel1AfterIngestProcessor = (Map) ingestDocument.getSourceAndMetadata().get(PARENT_FIELD); + Map childLevel2AfterIngestProcessor = (Map) childLevel1AfterIngestProcessor.get(CHILD_FIELD_LEVEL_1); + assertEquals(CHILD_LEVEL_2_TEXT_FIELD_VALUE, childLevel2AfterIngestProcessor.get(CHILD_FIELD_LEVEL_2)); + assertNotNull(childLevel2AfterIngestProcessor.get(CHILD_LEVEL_2_KNN_FIELD)); + List ingestVectors = (List) childLevel2AfterIngestProcessor.get(CHILD_LEVEL_2_KNN_FIELD); + Map childLevel1AfterUpdateProcessor = (Map) updateDocument.getSourceAndMetadata().get(PARENT_FIELD); + Map childLevel2AfterUpdateProcessor = (Map) childLevel1AfterUpdateProcessor.get(CHILD_FIELD_LEVEL_1); + assertEquals(CHILD_LEVEL_2_TEXT_FIELD_VALUE, childLevel2AfterUpdateProcessor.get(CHILD_FIELD_LEVEL_2)); + assertNotNull(childLevel2AfterUpdateProcessor.get(CHILD_LEVEL_2_KNN_FIELD)); + List updateVectors = (List) childLevel2AfterUpdateProcessor.get(CHILD_LEVEL_2_KNN_FIELD); + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(1)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + assertEquals(request.getInputTexts(), inferenceRequestCaptor.getValue().getInputTexts()); + verifyEqualEmbedding(ingestVectors, updateVectors); + } + + @SneakyThrows + public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_with_update_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("_id", "1"); + Map childLevel2 = new HashMap<>(); + childLevel2.put(CHILD_FIELD_LEVEL_2, CHILD_LEVEL_2_TEXT_FIELD_VALUE); + Map childLevel1 = new HashMap<>(); + childLevel1.put(CHILD_FIELD_LEVEL_1, childLevel2); + sourceAndMetadata.put(PARENT_FIELD, childLevel1); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + Map updateSourceAndMetadata = deepCopy(sourceAndMetadata); + ((Map) ((Map) updateSourceAndMetadata.get(PARENT_FIELD)).get(CHILD_FIELD_LEVEL_1)).put(CHILD_FIELD_LEVEL_2, "newValue"); + IngestDocument updateDocument = new IngestDocument(updateSourceAndMetadata, new HashMap<>()); + + TextEmbeddingProcessor processor = createInstanceWithNestedMapConfig(true); + List inferenceList = Arrays.asList(CHILD_LEVEL_2_TEXT_FIELD_VALUE); + TextInferenceRequest ingestRequest =
TextInferenceRequest.builder().modelId("mockModelID").inputTexts(inferenceList).build(); + List filteredInferenceList = Arrays.asList("newValue"); + TextInferenceRequest updateRequest = TextInferenceRequest.builder() + .modelId("mockModelID") + .inputTexts(filteredInferenceList) + .build(); + mockUpdateDocument(ingestDocument); + mockVectorCreation(ingestRequest, updateRequest); + + processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {}); + processor.execute(updateDocument, (BiConsumer) (doc, ex) -> {}); + Map childLevel1AfterIngestProcessor = (Map) ingestDocument.getSourceAndMetadata().get(PARENT_FIELD); + Map childLevel2AfterIngestProcessor = (Map) childLevel1AfterIngestProcessor.get(CHILD_FIELD_LEVEL_1); + assertEquals(CHILD_LEVEL_2_TEXT_FIELD_VALUE, childLevel2AfterIngestProcessor.get(CHILD_FIELD_LEVEL_2)); + assertNotNull(childLevel2AfterIngestProcessor.get(CHILD_LEVEL_2_KNN_FIELD)); + List ingestVectors = (List) childLevel2AfterIngestProcessor.get(CHILD_LEVEL_2_KNN_FIELD); + Map childLevel1AfterUpdateProcessor = (Map) updateDocument.getSourceAndMetadata().get(PARENT_FIELD); + Map childLevel2AfterUpdateProcessor = (Map) childLevel1AfterUpdateProcessor.get(CHILD_FIELD_LEVEL_1); + assertEquals("newValue", childLevel2AfterUpdateProcessor.get(CHILD_FIELD_LEVEL_2)); + assertNotNull(childLevel2AfterUpdateProcessor.get(CHILD_LEVEL_2_KNN_FIELD)); + List updateVectors = (List) childLevel2AfterUpdateProcessor.get(CHILD_LEVEL_2_KNN_FIELD); + verify(openSearchClient, times(2)).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + verify(mlCommonsClientAccessor, times(2)).inferenceSentences(inferenceRequestCaptor.capture(), isA(ActionListener.class)); + List requests = inferenceRequestCaptor.getAllValues(); + assertEquals(ingestRequest.getInputTexts(), requests.get(0).getInputTexts()); + assertEquals(updateRequest.getInputTexts(), requests.get(1).getInputTexts()); + assertEquals(ingestVectors.size(), updateVectors.size()); + } + public void testParsingNestedField_whenNestedFieldsConfigured_thenSuccessful() { Map config = createNestedMapConfiguration(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); @@ -1404,4 +2299,64 @@ private IngestDocument create2LevelNestedListWithNotEmbeddingFieldIngestDocument Map nestedList1 = buildObjMap(Pair.of("nestedField", nestedList)); return new IngestDocument(nestedList1, new HashMap<>()); } + + private void mockUpdateDocument(IngestDocument ingestDocument) { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mockEmptyGetResponse()); // returns empty result for ingest action + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(convertToGetResponse(ingestDocument)); // returns previously ingested document for update action + return null; + }).when(openSearchClient).execute(isA(GetAction.class), isA(GetRequest.class), isA(ActionListener.class)); + } + + private void mockVectorCreation(TextInferenceRequest ingestRequest, TextInferenceRequest updateRequest) { + doAnswer(invocation -> { + int numVectors = ingestRequest.getInputTexts().size(); + ActionListener>> listener = invocation.getArgument(1); + listener.onResponse(createRandomOneDimensionalMockVector(numVectors, 2, 0.0f, 1.0f)); + return null; + }).doAnswer(invocation -> { + int numVectors = updateRequest.getInputTexts().size(); + ActionListener>> listener = invocation.getArgument(1);
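+ // second stubbed answer serves the update call; the vector count matches the filtered inference list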
+            listener.onResponse(createRandomOneDimensionalMockVector(numVectors, 2, 0.0f, 1.0f));
+            return null;
+        }).when(mlCommonsClientAccessor).inferenceSentences(isA(TextInferenceRequest.class), isA(ActionListener.class));
+    }
+
+    private void verifyEqualEmbedding(List<Number> insertVectors, List<Number> updateVectors) {
+        assertEquals(insertVectors.size(), updateVectors.size());
+        for (int i = 0; i < insertVectors.size(); i++) {
+            assertEquals(insertVectors.get(i).floatValue(), updateVectors.get(i).floatValue(), 0.0000001f);
+        }
+    }
+
+    private void verifyEqualEmbeddingInMap(List<Map<String, List<Number>>> insertVectors, List<Map<String, List<Number>>> updateVectors) {
+        assertEquals(insertVectors.size(), updateVectors.size());
+
+        for (int i = 0; i < insertVectors.size(); i++) {
+            Map<String, List<Number>> insertMap = insertVectors.get(i);
+            Map<String, List<Number>> updateMap = updateVectors.get(i);
+            for (Map.Entry<String, List<Number>> entry : insertMap.entrySet()) {
+                List<Number> insertValue = entry.getValue();
+                List<Number> updateValue = updateMap.get(entry.getKey());
+                for (int j = 0; j < insertValue.size(); j++) {
+                    assertEquals(insertValue.get(j).floatValue(), updateValue.get(j).floatValue(), 0.0000001f);
+                }
+            }
+        }
+    }
+
+    private void verifyEqualEmbeddingInNestedList(List<List<Number>> insertVectors, List<List<Number>> updateVectors) {
+        assertEquals(insertVectors.size(), updateVectors.size());
+        for (int i = 0; i < insertVectors.size(); i++) {
+            List<Number> insertVector = insertVectors.get(i);
+            List<Number> updateVector = updateVectors.get(i);
+            for (int j = 0; j < insertVector.size(); j++) {
+                assertEquals(insertVector.get(j).floatValue(), updateVector.get(j).floatValue(), 0.0000001f);
+            }
+        }
+    }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilterTests.java b/src/test/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilterTests.java
new file mode 100644
index 000000000..b9bc38c4b
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilterTests.java
@@ -0,0 +1,235 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.processor.optimization;
+
+import org.junit.Before;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class TextEmbeddingInferenceFilterTests extends OpenSearchTestCase {
+
+    private Map<String, Object> sourceAndMetadataMap;
+    private Map<String, Object> existingSourceAndMetadataMap;
+    private TextEmbeddingInferenceFilter textEmbeddingInferenceFilter;
+    private TextEmbeddingInferenceFilter nestedTextEmbeddingInferenceFilter;
+
+    @Before
+    public void setup() {
+        Map<String, Object> fieldMap = new HashMap<>();
+        fieldMap.put("textField", "embeddingField");
+        Map<String, Object> nestedFieldMap = new HashMap<>();
+        nestedFieldMap.put("outerField", fieldMap);
+        textEmbeddingInferenceFilter = new TextEmbeddingInferenceFilter(fieldMap);
+        sourceAndMetadataMap = new HashMap<>();
+        existingSourceAndMetadataMap = new HashMap<>();
+        nestedTextEmbeddingInferenceFilter = new TextEmbeddingInferenceFilter(nestedFieldMap);
+    }
+
+    public void test_filterInferenceValue_TextUnchanged_ShouldCopyEmbedding() {
+        String textPath = "textField";
+        String embeddingPath = "embeddingField";
+        String textValue = "Hello World";
+        List embeddingValue = Arrays.asList(0.1, 0.2, 0.3);
+
+        sourceAndMetadataMap.put(textPath, textValue);
+        existingSourceAndMetadataMap.put(textPath, textValue);
+        existingSourceAndMetadataMap.put(embeddingPath, embeddingValue);
+
+        Object result = textEmbeddingInferenceFilter.filterInferenceValue(
+            embeddingPath,
+            textValue,
+            sourceAndMetadataMap,
+            existingSourceAndMetadataMap,
+            -1
+        );
+        assertNull(result);
+        assertEquals(embeddingValue, sourceAndMetadataMap.get(embeddingPath));
+    }
+
+    public void test_filterInferenceValue_TextChanged_ShouldNotCopyEmbedding() {
+        String textPath = "textField";
+        String embeddingPath = "embeddingField";
+        String newText = "New Text";
+        String oldText = "Old Text";
+        List embeddingValue = Arrays.asList(0.1, 0.2, 0.3);
+
+        sourceAndMetadataMap.put(textPath, newText);
+        existingSourceAndMetadataMap.put(textPath, oldText);
+        existingSourceAndMetadataMap.put(embeddingPath, embeddingValue);
+
+        Object result = textEmbeddingInferenceFilter.filterInferenceValue(
+            embeddingPath,
+            newText,
+            sourceAndMetadataMap,
+            existingSourceAndMetadataMap,
+            -1
+        );
+
+        assertEquals(newText, result);
+        assertNull(sourceAndMetadataMap.get(embeddingPath));
+    }
+
+    public void test_filterInferenceValue_NoExistingEmbedding_ShouldNotCopy() {
+        String textPath = "textField";
+        String embeddingPath = "embeddingField";
+        String textValue = "Hello World";
+
+        sourceAndMetadataMap.put(textPath, textValue);
+        existingSourceAndMetadataMap.put(textPath, textValue);
+        existingSourceAndMetadataMap.put(embeddingPath, null);
+
+        Object result = textEmbeddingInferenceFilter.filterInferenceValue(
+            embeddingPath,
+            textValue,
+            sourceAndMetadataMap,
+            existingSourceAndMetadataMap,
+            -1
+        );
+
+        assertEquals(textValue, result);
+        assertNull(sourceAndMetadataMap.get(embeddingPath));
+    }
+
+    public void test_filterInferenceValuesInList_ListUnchanged_ShouldCopyAllEmbeddings() {
+        List processList = Arrays.asList("Text A", "Text B");
+        List existingList = Arrays.asList("Text A", "Text B");
+        List embeddingList = Arrays.asList(Arrays.asList(0.1, 0.2), Arrays.asList(0.3, 0.4));
+
+        String fullEmbeddingKey = "embeddingField";
+
+        List result = textEmbeddingInferenceFilter.filterInferenceValuesInList(
+            processList,
+            existingList,
+            embeddingList,
+            sourceAndMetadataMap,
+            existingSourceAndMetadataMap,
+            fullEmbeddingKey
+        );
+
+        assertTrue(result.isEmpty());
+        assertEquals(embeddingList, sourceAndMetadataMap.get(fullEmbeddingKey));
+    }
+
+    public void test_filterInferenceValuesInList_ListPartiallyChanged_ShouldNotCopyEmbeddings() {
+        List processList = Arrays.asList("Text A", "New Text");
+        List existingList = Arrays.asList("Text A", "Text B");
+        List embeddingList = Arrays.asList(Arrays.asList(0.1, 0.2), Arrays.asList(0.3, 0.4));
+
+        String fullEmbeddingKey = "embeddingField";
+
+        List result = textEmbeddingInferenceFilter.filterInferenceValuesInList(
+            processList,
+            existingList,
+            embeddingList,
+            sourceAndMetadataMap,
+            existingSourceAndMetadataMap,
+            fullEmbeddingKey
+        );
+
+        assertEquals(processList.size(), result.size());
+        assertNull(sourceAndMetadataMap.get(fullEmbeddingKey));
+    }
+
+    public void test_filterInferenceValuesInList_NoMatchingField_ShouldNotCopyEmbeddings() {
+        List processList = Arrays.asList("Text A", "Text B");
+        String fullEmbeddingKey = "embeddingField";
+
+        List result = textEmbeddingInferenceFilter.filterInferenceValuesInList(
+            processList,
+            Collections.emptyList(),
+            Collections.emptyList(),
+            sourceAndMetadataMap,
+            existingSourceAndMetadataMap,
+            fullEmbeddingKey
+        );
+
+        assertEquals(processList, result);
+        assertNull(sourceAndMetadataMap.get(fullEmbeddingKey));
+    }
+
+    public void test_filter_nestedMapValue_Unchanged_ShouldCopyEmbeddings() {
+        Map<String, Object> nestedMap = new HashMap<>();
nestedMap.put("embeddingField", "Hello World"); + Map processMap = new HashMap<>(); + processMap.put("outerField", nestedMap); + + Map existingNestedMap = new HashMap<>(); + existingNestedMap.put("textField", "Hello World"); + existingNestedMap.put("embeddingField", Arrays.asList(0.1, 0.2, 0.3)); + + Map existingMap = new HashMap<>(); + existingMap.put("outerField", existingNestedMap); + + Map result = nestedTextEmbeddingInferenceFilter.filter(existingMap, sourceAndMetadataMap, processMap); + + assertNull(((Map) result.get("outerField")).get("embeddingField")); + assertEquals(Arrays.asList(0.1, 0.2, 0.3), ((Map) sourceAndMetadataMap.get("outerField")).get("embeddingField")); + } + + public void testFilter_nestedMapValue_PartiallyChanged_ShouldNotCopyEmbeddings() { + Map nestedMap = new HashMap<>(); + nestedMap.put("textField", "New Text"); + + Map processMap = new HashMap<>(); + processMap.put("outerField", nestedMap); + + Map existingNestedMap = new HashMap<>(); + existingNestedMap.put("textField", "Old Text"); + + Map existingMap = new HashMap<>(); + existingMap.put("outerField", existingNestedMap); + existingMap.put("embeddingField", Arrays.asList(0.1, 0.2, 0.3)); + + Map result = nestedTextEmbeddingInferenceFilter.filter(existingMap, sourceAndMetadataMap, processMap); + + assertFalse(result.isEmpty()); + assertEquals("New Text", ((Map) result.get("outerField")).get("textField")); + assertNull(sourceAndMetadataMap.get("outerField")); + } + + public void test_filter_nestedListValue_Unchanged_ShouldCopyEmbeddings() { + Map nestedMap = new HashMap<>(); + nestedMap.put("embeddingField", Arrays.asList("Hello World", "Bye World")); + Map processMap = new HashMap<>(); + processMap.put("outerField", nestedMap); + + Map existingNestedMap = new HashMap<>(); + existingNestedMap.put("textField", Arrays.asList("Hello World", "Bye World")); + existingNestedMap.put("embeddingField", Arrays.asList(Arrays.asList(0.1, 0.2, 0.3), Arrays.asList(0.4, 0.5, 0.6))); + + Map existingMap = new HashMap<>(); + existingMap.put("outerField", existingNestedMap); + + Map result = nestedTextEmbeddingInferenceFilter.filter(existingMap, sourceAndMetadataMap, processMap); + + assertEquals(0, ((List) ((Map) result.get("outerField")).get("embeddingField")).size()); + assertEquals(Arrays.asList(0.1, 0.2, 0.3), ((List) ((Map) sourceAndMetadataMap.get("outerField")).get("embeddingField")).get(0)); + assertEquals(Arrays.asList(0.4, 0.5, 0.6), ((List) ((Map) sourceAndMetadataMap.get("outerField")).get("embeddingField")).get(1)); + } + + public void test_filter_nestedListValue_PartiallyChanged_ShouldNotCopyEmbeddings() { + Map nestedMap = new HashMap<>(); + nestedMap.put("embeddingField", Arrays.asList("Hello World", "Bye World")); + Map processMap = new HashMap<>(); + processMap.put("outerField", nestedMap); + + Map existingNestedMap = new HashMap<>(); + existingNestedMap.put("textField", Arrays.asList("Hello World", "Goodbye World")); + existingNestedMap.put("embeddingField", Arrays.asList(Arrays.asList(0.1, 0.2, 0.3), Arrays.asList(0.4, 0.5, 0.6))); + + Map existingMap = new HashMap<>(); + existingMap.put("outerField", existingNestedMap); + + Map result = nestedTextEmbeddingInferenceFilter.filter(existingMap, sourceAndMetadataMap, processMap); + + assertEquals(2, ((List) ((Map) result.get("outerField")).get("embeddingField")).size()); + assertNull(sourceAndMetadataMap.get("outerField")); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtilsTests.java 
index 2a08a350d..075cbdf19 100644
--- a/src/test/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtilsTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtilsTests.java
@@ -224,4 +224,51 @@ private void testUnflatten_withInvalidUsageOfDots_thenFail(String fieldName, Map
         assert (illegalArgumentException.getMessage()
             .contains(String.format(Locale.ROOT, "Field name '%s' contains invalid dot usage", fieldName)));
     }
+
+    public void testFlattenAndFlip_withMultipleLevelsSeparatedByDots_thenSuccess() {
+        /*
+         * parent
+         *   child_level1
+         *     child_level1_text_field: child_level2.text_field_knn
+         */
+        Map childLevel1 = Map.of("child_level1_text_field", "child_level2.text_field_knn");
+        Map parentMap = Map.of("child_level1", childLevel1);
+        Map nestedMap = Map.of("parent", parentMap);
+
+        Map expected = Map.of(
+            "parent.child_level1.child_level2.text_field_knn",
+            "parent.child_level1.child_level1_text_field"
+        );
+
+        Map actual = ProcessorDocumentUtils.flattenAndFlip(nestedMap);
+        assertEquals(expected, actual);
+    }
+
+    public void testFlattenAndFlip_withMultipleLevelsWithNestedMaps_thenSuccess() {
+        /*
+         * parent
+         *   child_level1
+         *     child_level2
+         *       child_level2_text: child_level2_knn
+         *     child_level3
+         *       child_level4
+         *         child_level4_text: child_level4_knn
+         */
+        Map childLevel4 = Map.of("child_level4_text", "child_level4_knn");
+        Map childLevel3 = Map.of("child_level4", childLevel4);
+        Map childLevel2 = Map.of("child_level2_text", "child_level2_knn");
+        Map childLevel1 = Map.of("child_level2", childLevel2, "child_level3", childLevel3);
+        Map parentMap = Map.of("child_level1", childLevel1);
+        Map nestedMap = Map.of("parent", parentMap);
+
+        Map expected = Map.of(
+            "parent.child_level1.child_level2.child_level2_knn",
+            "parent.child_level1.child_level2.child_level2_text",
+            "parent.child_level1.child_level3.child_level4.child_level4_knn",
+            "parent.child_level1.child_level3.child_level4.child_level4_text"
+        );
+
+        Map actual = ProcessorDocumentUtils.flattenAndFlip(nestedMap);
+        assertEquals(expected, actual);
+    }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/util/ProcessorUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/ProcessorUtilsTests.java
index 101060ca1..54a30fe8b 100644
--- a/src/test/java/org/opensearch/neuralsearch/util/ProcessorUtilsTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/util/ProcessorUtilsTests.java
@@ -27,6 +27,7 @@ import static org.mockito.Mockito.verify;
 import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getValueFromSource;
 import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.mappingExistsInSource;
+import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.setValueToSource;
 import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.validateRerankCriteria;
 
 public class ProcessorUtilsTests extends OpenSearchTestCase {
@@ -213,4 +214,19 @@ public void testRemoveTargetFieldFromSource_successfullyDeletesTargetField_WithE
         assertFalse("The ml map no longer has the score `info` mapping ", innerMLMap.containsKey("info"));
     }
 
+    public void testSetValueToSource_returnsExpectedValues_WithExistingKeys() {
+        String targetPath = "ml.newField";
+        String targetValue = "newValue";
+        setUpValidSourceMap();
+        setValueToSource(sourceMap, targetPath, targetValue);
+        Map innerMLMap = (Map) sourceMap.get("ml");
+        assertEquals(targetValue, innerMLMap.get("newField"));
+    }
+
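+    // Hedged sketch: assumes setValueToSource and getValueFromSource share the same
+    // dot-notation path convention, as the tests above and below suggest;
+    // "ml.roundTripField" is a hypothetical key used only for illustration.
+    public void testSetThenGetValueFromSource_roundTrip() {
+        setUpValidSourceMap();
+        setValueToSource(sourceMap, "ml.roundTripField", "roundTripValue");
+        Optional result = getValueFromSource(sourceMap, "ml.roundTripField");
+        assertTrue(result.isPresent());
+        assertEquals("roundTripValue", result.get());
+    }
+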
+    public void testGetValueFromSource_returnsExpectedValues_WithNonExistingPath() {
+        String targetField = "ml.wrong";
+        setUpValidSourceMap();
+        Optional result = getValueFromSource(sourceMap, targetField);
+        assertTrue(result.isEmpty());
+    }
 }
diff --git a/src/test/resources/processor/PipelineConfigurationWithNestedFieldsMappingWithSkipExisting.json b/src/test/resources/processor/PipelineConfigurationWithNestedFieldsMappingWithSkipExisting.json
new file mode 100644
index 000000000..19f58db36
--- /dev/null
+++ b/src/test/resources/processor/PipelineConfigurationWithNestedFieldsMappingWithSkipExisting.json
@@ -0,0 +1,20 @@
+{
+  "description": "text embedding pipeline with nested fields mapping and skip_existing enabled",
+  "processors": [
+    {
+      "text_embedding": {
+        "model_id": "%s",
+        "field_map": {
+          "title": "title_knn",
+          "favor_list": "favor_list_knn",
+          "favorites": {
+            "game": "game_knn",
+            "movie": "movie_knn"
+          },
+          "nested_passages.level_2.level_3_text": "level_3_container.level_3_embedding"
+        },
+        "skip_existing": true
+      }
+    }
+  ]
+}
diff --git a/src/test/resources/processor/PipelineConfigurationWithSkipExisting.json b/src/test/resources/processor/PipelineConfigurationWithSkipExisting.json
new file mode 100644
index 000000000..058d861f4
--- /dev/null
+++ b/src/test/resources/processor/PipelineConfigurationWithSkipExisting.json
@@ -0,0 +1,23 @@
+{
+  "description": "text embedding pipeline for optimized inference call",
+  "processors": [
+    {
+      "text_embedding": {
+        "model_id": "%s",
+        "batch_size": "%d",
+        "field_map": {
+          "title": "title_knn",
+          "favor_list": "favor_list_knn",
+          "favorites": {
+            "game": "game_knn",
+            "movie": "movie_knn"
+          },
+          "nested_passages": {
+            "text": "embedding"
+          }
+        },
+        "skip_existing": true
+      }
+    }
+  ]
+}
diff --git a/src/test/resources/processor/update_doc1.json b/src/test/resources/processor/update_doc1.json
new file mode 100644
index 000000000..da4d8fd5a
--- /dev/null
+++ b/src/test/resources/processor/update_doc1.json
@@ -0,0 +1,25 @@
+{
+  "title": "This is a good day",
+  "text": "%s",
+  "description": "daily logging",
+  "favor_list": [
+    "test",
+    "hello",
+    "mock"
+  ],
+  "favorites": {
+    "game": "overwatch",
+    "movie": null
+  },
+  "nested_passages": [
+    {
+      "text_not_for_embedding": "test"
+    },
+    {
+      "text": "bye"
+    },
+    {
+      "text": "world"
+    }
+  ]
+}
diff --git a/src/test/resources/processor/update_doc2.json b/src/test/resources/processor/update_doc2.json
new file mode 100644
index 000000000..455955d72
--- /dev/null
+++ b/src/test/resources/processor/update_doc2.json
@@ -0,0 +1,23 @@
+{
+  "title": "this is a second doc",
+  "text": "%s",
+  "description": "the description is not very long",
+  "favor_list": [
+    "favor"
+  ],
+  "favorites": {
+    "game": "silver state",
+    "movie": null
+  },
+  "nested_passages": [
+    {
+      "text_not_for_embedding": "test"
+    },
+    {
+      "text": "apple"
+    },
+    {
+      "text": "banana"
+    }
+  ]
+}
diff --git a/src/test/resources/processor/update_doc3.json b/src/test/resources/processor/update_doc3.json
new file mode 100644
index 000000000..597001b8f
--- /dev/null
+++ b/src/test/resources/processor/update_doc3.json
@@ -0,0 +1,23 @@
+{
+  "title": "This is a good day",
+  "description": "daily logging",
+  "favor_list": [
+    "test",
+    "hello",
+    "mock"
+  ],
+  "favorites": {
+    "game": "overwatch",
+    "movie": null
+  },
+  "nested_passages":
+    {
+      "level_2":
+        {
+          "level_3_text": "hello",
+          "level_3_container": {
+            "level_4_text_field": "def"
+          }
+        }
+    }
+}
diff --git a/src/test/resources/processor/update_doc4.json b/src/test/resources/processor/update_doc4.json
new file mode 100644
index 000000000..9cecc27a3
--- /dev/null
+++ b/src/test/resources/processor/update_doc4.json
@@ -0,0 +1,20 @@
+{
+  "title": "This is a good day",
+  "description": "daily logging",
+  "favor_list": [
+    "key",
+    "hey",
+    "click"
+  ],
+  "favorites": {
+    "game": "cossacks",
+    "movie": "matrix"
+  },
+  "nested_passages":
+    {
+      "level_2":
+        {
+          "level_3_text": "joker"
+        }
+    }
+}
diff --git a/src/test/resources/processor/update_doc5.json b/src/test/resources/processor/update_doc5.json
new file mode 100644
index 000000000..65f2c13f2
--- /dev/null
+++ b/src/test/resources/processor/update_doc5.json
@@ -0,0 +1,24 @@
+{
+  "title": "This is a good day",
+  "description": "daily logging",
+  "favor_list": [
+    "key",
+    "hey",
+    "click"
+  ],
+  "favorites": {
+    "game": "cossacks",
+    "movie": "matrix"
+  },
+  "nested_passages": [
+    {
+      "level_2":
+        {
+          "level_3_text": "joker"
+        }
+    },
+    {
+      "level_2.level_3_text": "superman"
+    }
+  ]
+}
diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
index 6045599c8..6819606fa 100644
--- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
+++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java
@@ -106,6 +106,10 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {
         "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json",
         ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING,
         "processor/PipelineConfigurationWithNestedFieldsMapping.json",
+        ProcessorType.TEXT_EMBEDDING_WITH_SKIP_EXISTING,
+        "processor/PipelineConfigurationWithSkipExisting.json",
+        ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING_WITH_SKIP_EXISTING,
+        "processor/PipelineConfigurationWithNestedFieldsMappingWithSkipExisting.json",
         ProcessorType.SPARSE_ENCODING_PRUNE,
         "processor/SparseEncodingPipelineConfigurationWithPrune.json"
     );
@@ -248,7 +252,10 @@ protected void loadModel(final String modelId) throws Exception {
             isComplete = checkComplete(taskQueryResult);
             Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND);
         }
-        assertTrue(String.format(Locale.ROOT, "failed to load the model, last task finished with status %s", taskQueryResult.get("state")), isComplete);
+        assertTrue(
+            String.format(Locale.ROOT, "failed to load the model, last task finished with status %s", taskQueryResult.get("state")),
+            isComplete
+        );
     }
 
     /**
@@ -1378,7 +1385,8 @@ protected void createIndexWithPipeline(String indexName, String indexMappingFile
      * @param id nullable optional id
+     * @param isUpdate if true, expects the indexing result to be "updated" rather than "created"
      * @throws Exception
      */
-    protected String ingestDocument(String indexName, String ingestDocument, String id) throws Exception {
+    protected String ingestDocument(String indexName, String ingestDocument, String id, boolean isUpdate) throws Exception {
         String endpoint;
         if (StringUtils.isEmpty(id)) {
             endpoint = indexName + "/_doc?refresh";
         }
@@ -1400,7 +1407,11 @@ protected String ingestDocument(String indexName, String ingestDocument, String
         );
 
         String result = (String) map.get("result");
-        assertEquals("created", result);
+        if (isUpdate) {
+            assertEquals("updated", result);
+        } else {
+            assertEquals("created", result);
+        }
         return result;
     }
 
@@ -1411,7 +1422,22 @@ protected String ingestDocument(String indexName, String ingestDocument, String
      * @throws Exception
     */
     protected String ingestDocument(String indexName, String ingestDocument) throws Exception {
-        return ingestDocument(indexName, ingestDocument, null);
+        return ingestDocument(indexName, ingestDocument, null, false);
+    }
+
+    protected String ingestDocument(String indexName, String ingestDocument, String id) throws Exception {
+        return ingestDocument(indexName, ingestDocument, id, false);
+    }
+
+    /**
+     * Update an existing document in the index with the given id
+     * @param indexName name of the index
+     * @param ingestDocument document contents to update with
+     * @param id id of the document to update
+     * @throws Exception
+     */
+    protected String updateDocument(String indexName, String ingestDocument, String id) throws Exception {
+        return ingestDocument(indexName, ingestDocument, id, true);
+    }
 
     /**
@@ -1731,6 +1757,8 @@ protected Object validateDocCountAndInfo(
     protected enum ProcessorType {
         TEXT_EMBEDDING,
         TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING,
+        TEXT_EMBEDDING_WITH_SKIP_EXISTING,
+        TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING_WITH_SKIP_EXISTING,
         TEXT_IMAGE_EMBEDDING,
         SPARSE_ENCODING,
         SPARSE_ENCODING_PRUNE
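+        // Hedged usage sketch: with a pipeline created from
+        // processor/PipelineConfigurationWithSkipExisting.json, an integration test built
+        // on the helpers above would exercise the skip_existing path end to end
+        // (INDEX_NAME and doc are hypothetical placeholders):
+        //
+        //   ingestDocument(INDEX_NAME, doc, "1");  // first write: expects "created", embeddings computed
+        //   updateDocument(INDEX_NAME, doc, "1");  // re-write of unchanged text: expects "updated",
+        //                                          // and existing embeddings are reused instead of re-inferred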