refactor filter logic to a separate class
Signed-off-by: will-hwang <sang7239@gmail.com>
will-hwang committed Feb 27, 2025
1 parent aaf44d3 commit 90851b6
Showing 15 changed files with 1,825 additions and 1,886 deletions.

@@ -580,11 +580,10 @@ private Map<String, Object> getSourceMapBySourceAndMetadataMap(String processorK

private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results, IndexWrapper indexWrapper) {
List<Map<String, Object>> keyToResult = new ArrayList<>();
- IntStream.range(0, sourceValue.size()).forEachOrdered(x -> {
- if (sourceValue.get(x) != null) { // only add to keyToResult when sourceValue.get(x) exists,
- keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)));
- }
- });
+ sourceValue.stream()
+ .filter(Objects::nonNull) // explicit null check is required since sourceValue can contain null values in cases where
+ // sourceValue has been filtered
+ .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
return keyToResult;
}
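Both the removed IntStream loop and the added stream pipeline advance the shared index only for non-null source values, so each embedding result still lands next to the source entry that produced it. Below is a minimal, self-contained sketch of that alignment, using sample data, an AtomicInteger standing in for the processor's IndexWrapper, and the literal "knn" key from LIST_TYPE_NESTED_MAP_KEY; it is not the processor's actual code.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;

public class NullFilterAlignmentSketch {
    public static void main(String[] args) {
        // Source values after upstream filtering; nulls mark entries that were filtered out
        List<String> sourceValue = Arrays.asList("a", null, "c");
        // One inference result per non-null source value
        List<float[]> results = Arrays.asList(new float[] { 0.1f }, new float[] { 0.2f });

        AtomicInteger index = new AtomicInteger(0); // stand-in for IndexWrapper.index
        List<Map<String, Object>> keyToResult = new ArrayList<>();
        sourceValue.stream()
            .filter(Objects::nonNull) // a null never consumes a result, so alignment is preserved
            .forEachOrdered(x -> keyToResult.add(Map.of("knn", results.get(index.getAndIncrement()))));

        System.out.println(keyToResult.size()); // 2: "a" -> results[0], "c" -> results[1]
    }
}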

@@ -6,16 +6,22 @@

import java.util.List;
import java.util.Map;
+ import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
+ import java.util.stream.Collectors;

+ import org.opensearch.action.get.GetAction;
+ import org.opensearch.action.get.GetRequest;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import lombok.extern.log4j.Log4j2;
+ import org.opensearch.neuralsearch.processor.optimization.TextEmbeddingInferenceFilter;
+ import org.opensearch.transport.client.OpenSearchClient;

/**
* This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use,
@@ -28,18 +34,29 @@ public final class TextEmbeddingProcessor extends InferenceProcessor {
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
public static final String SKIP_EXISTING = "skip_existing";
public static final boolean DEFAULT_SKIP_EXISTING = Boolean.FALSE;
+ private static final String INDEX_FIELD = "_index";
+ private static final String ID_FIELD = "_id";
+ private final OpenSearchClient openSearchClient;
+ private final boolean skipExisting;
+ private final TextEmbeddingInferenceFilter textEmbeddingInferenceFilter;

public TextEmbeddingProcessor(
String tag,
String description,
int batchSize,
String modelId,
Map<String, Object> fieldMap,
+ boolean skipExisting,
+ TextEmbeddingInferenceFilter textEmbeddingInferenceFilter,
+ OpenSearchClient openSearchClient,
MLCommonsClientAccessor clientAccessor,
Environment environment,
ClusterService clusterService
) {
super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
+ this.skipExisting = skipExisting;
+ this.textEmbeddingInferenceFilter = textEmbeddingInferenceFilter;
+ this.openSearchClient = openSearchClient;
}

@Override
Expand All @@ -49,13 +66,41 @@ public void doExecute(
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
) {
- mlCommonsClientAccessor.inferenceSentences(
- TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
- ActionListener.wrap(vectors -> {
- setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
- handler.accept(ingestDocument, null);
- }, e -> { handler.accept(null, e); })
- );
+ if (skipExisting) { // if skipExisting flag is turned on, eligible inference texts will be compared and filtered after embeddings
+ // have been copied
+ String index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD).toString();
+ String id = ingestDocument.getSourceAndMetadata().get(ID_FIELD).toString();
+ openSearchClient.execute(GetAction.INSTANCE, new GetRequest(index, id), ActionListener.wrap(response -> {
+ final Map<String, Object> existingDocument = response.getSourceAsMap();
+ if (existingDocument == null || existingDocument.isEmpty()) {
+ makeInferenceCall(ingestDocument, ProcessMap, inferenceList, handler);
+ } else {
+ // filter given ProcessMap by comparing existing document with ingestDocument
+ Map<String, Object> filteredProcessMap = textEmbeddingInferenceFilter.filter(
+ existingDocument,
+ ingestDocument.getSourceAndMetadata(),
+ ProcessMap
+ );
+ // create inference list based on filtered ProcessMap
+ List<String> filteredInferenceList = createInferenceList(filteredProcessMap).stream()
+ .filter(Objects::nonNull)
+ .collect(Collectors.toList());
+ if (!filteredInferenceList.isEmpty()) {
+ makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler);
+ } else {
+ handler.accept(ingestDocument, null);
+ }
+ }
+ }, e -> { handler.accept(null, e); }));
+ } else { // skip existing flag is turned off. Call model inference without filtering
+ mlCommonsClientAccessor.inferenceSentences(
+ TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
+ ActionListener.wrap(vectors -> {
+ setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
+ handler.accept(ingestDocument, null);
+ }, e -> { handler.accept(null, e); })
+ );
+ }
}

@Override
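Condensed, the new doExecute path makes a three-way decision after looking up the previously indexed document. The following reduced sketch shows only that branching; the class, enum, and parameter names are invented here for illustration, and the real method works asynchronously through ActionListener callbacks and makeInferenceCall.

import java.util.List;
import java.util.Map;

public class SkipExistingDecisionSketch {
    enum Decision { FULL_INFERENCE, PARTIAL_INFERENCE, SKIP_INFERENCE }

    // existingDocument: source of the already-indexed document (null or empty if none was found)
    // filteredInferenceList: texts left over after the filter removed fields whose text is unchanged
    static Decision decide(Map<String, Object> existingDocument, List<String> filteredInferenceList) {
        if (existingDocument == null || existingDocument.isEmpty()) {
            return Decision.FULL_INFERENCE; // nothing to reuse, embed every configured field
        }
        return filteredInferenceList.isEmpty()
            ? Decision.SKIP_INFERENCE       // all texts unchanged, existing embeddings are reused
            : Decision.PARTIAL_INFERENCE;   // embed only the fields whose text changed
    }

    public static void main(String[] args) {
        System.out.println(decide(null, List.of("new text")));              // FULL_INFERENCE
        System.out.println(decide(Map.of("text", "unchanged"), List.of())); // SKIP_INFERENCE
    }
}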
@@ -20,7 +20,7 @@
import org.opensearch.ingest.AbstractBatchingProcessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
- import org.opensearch.neuralsearch.processor.optimization.SelectiveTextEmbeddingProcessor;
+ import org.opensearch.neuralsearch.processor.optimization.TextEmbeddingInferenceFilter;
import org.opensearch.transport.client.OpenSearchClient;

/**
@@ -54,19 +54,19 @@ protected AbstractBatchingProcessor newProcessor(String tag, String description,
String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
Map<String, Object> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
boolean skipExisting = readBooleanProperty(TYPE, tag, config, SKIP_EXISTING, DEFAULT_SKIP_EXISTING);
- if (skipExisting == true) {
- return new SelectiveTextEmbeddingProcessor(
- tag,
- description,
- batchSize,
- modelId,
- fieldMap,
- openSearchClient,
- clientAccessor,
- environment,
- clusterService
- );
- }
- return new TextEmbeddingProcessor(tag, description, batchSize, modelId, fieldMap, clientAccessor, environment, clusterService);
+ TextEmbeddingInferenceFilter textEmbeddingInferenceFilter = new TextEmbeddingInferenceFilter(fieldMap);
+ return new TextEmbeddingProcessor(
+ tag,
+ description,
+ batchSize,
+ modelId,
+ fieldMap,
+ skipExisting,
+ textEmbeddingInferenceFilter,
+ openSearchClient,
+ clientAccessor,
+ environment,
+ clusterService
+ );
}
}
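For context, here is a hypothetical, pipeline-level view of the configuration this factory consumes. The literal key strings ("model_id", "field_map", "skip_existing") and all values are illustrative assumptions mirroring the constants read above, not output copied from OpenSearch.

import java.util.HashMap;
import java.util.Map;

public class SkipExistingConfigSketch {
    public static void main(String[] args) {
        Map<String, Object> config = new HashMap<>();
        config.put("model_id", "my-model-id");                                 // MODEL_ID_FIELD
        config.put("field_map", Map.of("passage_text", "passage_embedding"));  // FIELD_MAP_FIELD
        config.put("skip_existing", true);                                     // SKIP_EXISTING, defaults to false
        System.out.println(config);
    }
}

With this change the factory no longer switches between two processor classes: it always builds a TextEmbeddingProcessor and hands it a TextEmbeddingInferenceFilter, and skip_existing only controls whether doExecute consults the previously indexed document.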
@@ -0,0 +1,233 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.optimization;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.util.ProcessorUtils;
import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
* Abstract class for selective text processing and embedding optimization.
* The InferenceFilter class is designed to optimize inference calls by selectively processing text data.
* It reuses existing embeddings when the text content remains unchanged, reducing redundant inference calls and
* improving performance. This is achieved by comparing the text in new and existing documents, copying embeddings
* when the text is identical.
* This class is intended to be extended for different text processing use cases. It provides a recursive filtering
* mechanism that navigates through nested map structures, comparing values and determining whether embeddings can be reused.
*/
@Log4j2
public abstract class InferenceFilter {
/**
* Stores the reverse mapping of field names to support efficient lookups for embedding keys.
* This is generated by flattening and flipping the provided field map.
*/
protected Map<String, String> reversedFieldMap;

/**
* Constructs an InferenceFilter instance and initializes the reversed field map.
*/
public InferenceFilter(Map<String, Object> fieldMap) {
this.reversedFieldMap = ProcessorDocumentUtils.flattenAndFlip(fieldMap);
}

/**
* Abstract method to filter individual values based on the existing and new metadata maps.
* Implementations should provide logic to compare values and determine if embeddings can be reused.
*
* @param currentPath The current dot-notation path for the value being processed.
* @param processValue The value to be checked for potential embedding reuse.
* @param sourceAndMetadataMap The metadata map of the new document.
* @param existingSourceAndMetadataMap The metadata map of the existing document.
* @param index The index of the value in the list, if applicable (-1 if not part of a list).
* @return The processed value or null if embeddings are reused.
*/

public abstract Object filterInferenceValue(
String currentPath,
Object processValue,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> existingSourceAndMetadataMap,
int index
);

/**
* Abstract method to filter and compare lists of values.
* If all elements in the list are identical between the new and existing metadata maps, embeddings are copied,
* and an empty list is returned to indicate no further processing is required.
*
* @param processList The list of values to be processed.
* @param existingList The list of existing values for comparison.
* @param embeddingList The list of existing embeddings.
* @param sourceAndMetadataMap The metadata map of the new document.
* @param existingSourceAndMetadataMap The metadata map of the existing document.
* @param fullEmbeddingKey The dot-notation path for the embedding field.
* @return A processed list or an empty list if embeddings are reused.
*/

public abstract List<Object> filterInferenceValuesInList(
List<Object> processList,
List<Object> existingList,
List<Object> embeddingList,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> existingSourceAndMetadataMap,
String fullEmbeddingKey
);

/**
* This method navigates through the nested structure, checking each key-value pair recursively. It supports:
* Map values: Processed recursively using this method.
* List values: Processed using filterListValue.
* Primitive values: Directly compared using filterInferenceValue.
*
* @param existingSourceAndMetadataMap The metadata map of the existing document.
* @param sourceAndMetadataMap The metadata map of the new document.
* @param processMap The current map being processed.
*
* @return A filtered map containing only elements that require new embeddings.
*
*/
public Map<String, Object> filter(
Map<String, Object> existingSourceAndMetadataMap,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> processMap
) {
return filter(existingSourceAndMetadataMap, sourceAndMetadataMap, processMap, "");
}

/**
* Helper method for filter
* @param existingSourceAndMetadataMap The metadata map of the existing document.
* @param sourceAndMetadataMap The metadata map of the new document.
* @param processMap The current map being processed.
* @param traversedPath The dot-notation path of the previously traversed elements.
* e.g:
* In a map structured like:
* level1
* level2
* level3
* traversedPath would be the dot-separated string level1.level2.level3
* @return A filtered map containing only elements that require new embeddings
*/
private Map<String, Object> filter(
Map<String, Object> existingSourceAndMetadataMap,
Map<String, Object> sourceAndMetadataMap,
Object processMap,
String traversedPath
) {
if (processMap instanceof Map == false) {
throw new IllegalArgumentException("processMap needs to be an instanceof Map");
}
Map<String, Object> filteredProcessMap = new HashMap<>();
Map<String, Object> castedProcessMap = ProcessorUtils.castToMap(processMap);
for (Map.Entry<?, ?> entry : castedProcessMap.entrySet()) {
if ((entry.getKey() instanceof String) == false) {
throw new IllegalArgumentException("key for processMap must be a string");
}
String key = (String) entry.getKey();
Object value = entry.getValue();
String currentPath = traversedPath.isEmpty() ? key : traversedPath + "." + key;
if (value instanceof Map<?, ?>) {
Map<String, Object> filteredInnerMap = filter(existingSourceAndMetadataMap, sourceAndMetadataMap, value, currentPath);
filteredProcessMap.put(key, filteredInnerMap.isEmpty() ? null : filteredInnerMap);
} else if (value instanceof List) {
List<Object> processedList = filterListValue(
currentPath,
(List<Object>) value,
sourceAndMetadataMap,
existingSourceAndMetadataMap
);
filteredProcessMap.put(key, processedList);
} else {
Object processedValue = filterInferenceValue(currentPath, value, sourceAndMetadataMap, existingSourceAndMetadataMap, -1);
filteredProcessMap.put(key, processedValue);
}
}
return filteredProcessMap;
}

/**
* Processes a list of values by comparing them against source and existing metadata.
*
* @param embeddingKey The current path in dot notation for the list being processed
* @param processList The list of values to process
* @param sourceAndMetadataMap SourceAndMetadataMap of ingestDocument Document
* @param existingSourceAndMetadataMap SourceAndMetadataMap of existing Document
* @return A processed list containing non-filtered elements
*/
protected List<Object> filterListValue(
String embeddingKey,
List<Object> processList,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> existingSourceAndMetadataMap
) {
String textKey = reversedFieldMap.get(embeddingKey);
if (Objects.isNull(textKey)) {
return processList;
}
Optional<Object> existingList = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, textKey);
if (existingList.isPresent() == false) {
// if list is not present with given text key, process each entry in processList individually
return filterMapValuesInList(processList, embeddingKey, sourceAndMetadataMap, existingSourceAndMetadataMap);
}
// if list is present with given textKey, but is not a list type, no valid comparison can be made, return processList
if (existingList.get() instanceof List == false) {
return processList;
}
// retrieve embedding for given embedding key, if not present or is not a list type, return processList, as no embedding can be
// copied over
Optional<Object> embeddingList = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, embeddingKey);
if (embeddingList.isPresent() == false || embeddingList.get() instanceof List == false) {
return processList;
}
// return an empty list if processList and existingList are equal and embeddings are copied; otherwise return the processed list
return filterInferenceValuesInList(
processList,
(List<Object>) existingList.get(),
(List<Object>) embeddingList.get(),
sourceAndMetadataMap,
existingSourceAndMetadataMap,
embeddingKey
);
}

/**
* Processes a list containing map values by iterating through each item and processing it individually.
*
* @param processList The list of Map items to process
* @param currentPath The current path in dot notation
* @param sourceAndMetadataMap SourceAndMetadataMap of ingestDocument Document
* @param existingSourceAndMetadataMap SourceAndMetadataMap of existing Document
* @return A processed list containing non-filtered elements
*/
public List<Object> filterMapValuesInList(
List<Object> processList,
String currentPath,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> existingSourceAndMetadataMap
) {
List<Object> filteredList = new ArrayList<>();
Iterator<Object> iterator = processList.iterator();
int index = 0;
while (iterator.hasNext()) {
Object processedItem = filterInferenceValue(
currentPath,
iterator.next(),
sourceAndMetadataMap,
existingSourceAndMetadataMap,
index++
);
filteredList.add(processedItem);
}
return filteredList;
}
}
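The reversedFieldMap that drives filterListValue is the processor's field_map flattened into dot notation and flipped, so that an embedding path looks up the source text path that produced it. The sketch below is a hand-built illustration under that assumption, with hypothetical field names; the real map is produced by ProcessorDocumentUtils.flattenAndFlip, which is not reproduced here.

import java.util.Map;

public class ReversedFieldMapSketch {
    public static void main(String[] args) {
        // field_map as it might be configured: nested source field -> embedding field
        Map<String, Object> fieldMap = Map.of("passage", Map.of("text", "text_embedding"));

        // Expected shape of the flattened-and-flipped map for that configuration
        Map<String, String> reversedFieldMap = Map.of("passage.text_embedding", "passage.text");

        // filterListValue resolves the text field behind an embedding key with this lookup,
        // then compares the existing and new text lists before deciding to copy embeddings
        System.out.println(reversedFieldMap.get("passage.text_embedding")); // passage.text
        System.out.println(fieldMap);
    }
}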