Skip to content

Commit

Permalink
refactor with added integration test
Browse files Browse the repository at this point in the history
Signed-off-by: Will Hwang <sang7239@gmail.com>
  • Loading branch information
will-hwang committed Feb 24, 2025
1 parent 710354d commit 869d686
Show file tree
Hide file tree
Showing 19 changed files with 658 additions and 238 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
Expand Down Expand Up @@ -55,31 +56,34 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor {
public static final String MODEL_ID_FIELD = "model_id";
public static final String FIELD_MAP_FIELD = "field_map";
private static final BiFunction<Object, Object, Object> REMAPPING_FUNCTION = (v1, v2) -> {
if (v1 instanceof List<?> && v2 instanceof List<?>) {
List<Object> v1List = new ArrayList<>((List<?>) v1);
List<?> v2List = (List<?>) v2;

Iterator<?> iterator = v2List.iterator();
for (int i = 0; i < v1List.size() && iterator.hasNext(); i++) {
if (v1List.get(i) == null) {
v1List.set(i, iterator.next());
/**
 * REMAPPING_FUNCTION is invoked when Map.merge is called and takes two arguments:
 *
 * v1: the current (existing) value in the map.
 * v2: the new value being merged in.
 *
 * The function returns the value that should be associated with the key.
 *
 * For Collection values (currently always of type List), REMAPPING_FUNCTION merges v1 and v2 by inserting
 * values from v2 into v1 at the indices marked null.
 * For Map values, REMAPPING_FUNCTION merges v1 and v2 by putting all entries from v2 into v1.
 */

if (v1 instanceof Collection && v2 instanceof Collection) {
Iterator<?> iterator = ((Collection) v2).iterator();
for (int i = 0; i < ((Collection) v1).size(); i++) {
if (((List) v1).get(i) == null) {
((List) v1).set(i, iterator.next());
}
}
return v1List;
}

if (v1 instanceof Map<?, ?> && v2 instanceof Map<?, ?>) {
Map<String, Object> v1Map = new LinkedHashMap<>((Map<String, Object>) v1);
Map<?, ?> v2Map = (Map<?, ?>) v2;

for (Map.Entry<?, ?> entry : v2Map.entrySet()) {
if (entry.getKey() instanceof String && !v1Map.containsKey(entry.getKey())) {
v1Map.put((String) entry.getKey(), entry.getValue());
}
}
return v1Map;
assert iterator.hasNext() == false;
return v1;
} else if (v1 instanceof Map && v2 instanceof Map) {
((Map) v1).putAll((Map) v2);
return v1;
} else {
return v2;
}
return v2 != null ? v2 : v1;

};

Expand Down Expand Up @@ -149,8 +153,9 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
/**
* This method will be invoked by PipelineService to make async inference and then delegate the handler to
* process the inference response or failure.
*
* @param ingestDocument {@link IngestDocument} which is the document passed to processor.
* @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done.
* @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done.
*/
@Override
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
Expand Down Expand Up @@ -180,9 +185,10 @@ void preprocessIngestDocument(IngestDocument ingestDocument) {

/**
* This is the function which does actual inference work for batchExecute interface.
*
* @param inferenceList a list of String for inference.
* @param handler a callback handler to handle inference results which is a list of objects.
* @param onException an exception callback to handle exception.
* @param handler a callback handler to handle inference results which is a list of objects.
* @param onException an exception callback to handle exception.
*/
protected abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);

Expand Down Expand Up @@ -215,7 +221,8 @@ public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers,
setVectorFieldsToDocument(
dataForInference.getIngestDocumentWrapper().getIngestDocument(),
dataForInference.getProcessMap(),
inferenceResults
inferenceResults,
false
);
}
handler.accept(ingestDocumentWrappers);
Expand Down Expand Up @@ -393,6 +400,7 @@ private Object normalizeSourceValue(Object value) {

/**
* Process the nested key, such as "a.b.c" to "a", "b.c"
*
* @param nestedFieldMapEntry
* @return A pair of the original key and the target key
*/
Expand Down Expand Up @@ -429,39 +437,39 @@ protected void setVectorFieldsToDocument(
IngestDocument ingestDocument,
Map<String, Object> processorMap,
List<?> results,
boolean update
boolean partialUpdate
) {
Objects.requireNonNull(results, "embedding failed, inference returns null result!");
log.debug("Model inference result fetched, starting build vector output!");
Map<String, Object> nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata());
if (update) {
// when partialUpdate is false, a full update is performed: every vector embedding in nlpResult
// is written directly to the ingestDocument
if (partialUpdate == false) {
Map<String, Object> nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata());
nlpResult.forEach(ingestDocument::setFieldValue);
} else {
// when partialUpdate is true, some embeddings were copied over from the existing document, so the
// vector embeddings in nlpResult are written only to the positions marked 'null'
Map<String, Object> nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata());
for (Map.Entry<String, Object> nlpEntry : nlpResult.entrySet()) {
String key = nlpEntry.getKey();
Object target = ingestDocument.getSourceAndMetadata().get(key);
Object nlpValues = nlpEntry.getValue();
if (target instanceof List<?> targetList && nlpValues instanceof List<?> nlpValueList) {
List<Object> list = new ArrayList<>(targetList);
ListIterator<Object> iterator = list.listIterator();
ListIterator<Object> nlpIterator = (ListIterator<Object>) nlpValueList.listIterator();
while (iterator.hasNext() && nlpIterator.hasNext()) {
if (target instanceof List && nlpValues instanceof List) {
ListIterator iterator = ((List) target).listIterator();
ListIterator nlpIterator = ((List) nlpValues).listIterator();
while (iterator.hasNext()) {
if (iterator.next() == null) {
iterator.set(nlpIterator.next());
}
}
ingestDocument.setFieldValue(key, list);
ingestDocument.setFieldValue(key, target);
} else {
ingestDocument.setFieldValue(key, nlpValues);
}
}
} else {
nlpResult.forEach(ingestDocument::setFieldValue);
}
}

protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap, List<?> results) {
setVectorFieldsToDocument(ingestDocument, processorMap, results, false);
}

@SuppressWarnings({ "unchecked" })
@VisibleForTesting
Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
Expand Down Expand Up @@ -560,18 +568,17 @@ private void processMapEntryValue(
) {
// build nlp output for object in sourceValue which is map type
Iterator<Map<String, Object>> iterator = sourceAndMetadataMapValueInList.iterator();
IndexWrapper listIndexWrapper = new IndexWrapper(0);
for (int i = 0; i < sourceAndMetadataMapValueInList.size(); i++) {
IntStream.range(0, sourceAndMetadataMapValueInList.size()).forEach(index -> {
Map<String, Object> nestedElement = iterator.next();
putNLPResultToSingleSourceMapInList(
inputNestedMapEntryKey,
inputNestedMapEntryValue,
results,
indexWrapper,
nestedElement,
listIndexWrapper
index
);
}
});
}

/**
Expand All @@ -591,7 +598,7 @@ private void putNLPResultToSingleSourceMapInList(
List<?> results,
IndexWrapper indexWrapper,
Map<String, Object> sourceAndMetadataMap,
IndexWrapper listIndexWrapper
int nestedIndex
) {
if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return;
if (sourceValue instanceof Map) {
Expand All @@ -604,15 +611,12 @@ private void putNLPResultToSingleSourceMapInList(
results,
indexWrapper,
sourceMap,
listIndexWrapper
nestedIndex
);
}
} else {
if (sourceValue instanceof List) {
if (sourceAndMetadataMap.containsKey(processorKey)) {
return;
}
if (((List<Object>) sourceValue).get(listIndexWrapper.index++) != null) {
if (((List<Object>) sourceValue).get(nestedIndex) != null) {
sourceAndMetadataMap.merge(processorKey, results.get(indexWrapper.index++), REMAPPING_FUNCTION);
}
}
Expand All @@ -632,37 +636,35 @@ private Map<String, Object> getSourceMapBySourceAndMetadataMap(String processorK

private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results, IndexWrapper indexWrapper) {
List<Map<String, Object>> keyToResult = new ArrayList<>();
IntStream.range(0, sourceValue.size())
.forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
IntStream.range(0, sourceValue.size()).forEachOrdered(x -> {
if (sourceValue.get(x) != null) { // only add to keyToResult when sourceValue.get(x) exists,
keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)));
}
});
return keyToResult;
}

/**
 * Invokes an inference call through mlCommonsClientAccessor and populates the retrieved embeddings into ingestDocument
 *
 * @param ingestDocument ingestDocument to populate embeddings into
 * @param ProcessMap map indicating the paths in ingestDocument at which to populate embeddings
 * @param inferenceList list of texts to run model inference on
 * @param handler {@link BiConsumer} callback invoked with the ingestDocument (or the exception) once the inference task completes
 * @param partialUpdate set to false if ingestDocument is ingested for the first time, set to true if ingestDocument is an update operation
 */
protected void makeInferenceCall(
IngestDocument ingestDocument,
Map<String, Object> ProcessMap,
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler,
boolean update
) {
mlCommonsClientAccessor.inferenceSentences(
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors, update);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
}

protected void makeInferenceCall(
IngestDocument ingestDocument,
Map<String, Object> ProcessMap,
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
boolean partialUpdate
) {
mlCommonsClientAccessor.inferenceSentences(
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors, partialUpdate);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void doExecute(
.stream()
.map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector))
.toList();
setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors);
setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors, false);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ public final class TextEmbeddingProcessor extends InferenceProcessor {

public static final String TYPE = "text_embedding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
public static final String IGNORE_EXISTING = "ignore_existing";
public static final boolean DEFAULT_IGNORE_EXISTING = Boolean.FALSE;
public static final String SKIP_EXISTING = "skip_existing";
public static final boolean DEFAULT_SKIP_EXISTING = Boolean.FALSE;

public TextEmbeddingProcessor(
String tag,
Expand All @@ -52,7 +52,7 @@ public void doExecute(
mlCommonsClientAccessor.inferenceSentences(
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors, false);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import static org.opensearch.ingest.ConfigurationUtils.readBooleanProperty;
import static org.opensearch.ingest.ConfigurationUtils.readMap;
import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.DEFAULT_IGNORE_EXISTING;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.IGNORE_EXISTING;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.SKIP_EXISTING;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.DEFAULT_SKIP_EXISTING;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD;
Expand All @@ -20,7 +20,7 @@
import org.opensearch.ingest.AbstractBatchingProcessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.optimization.OptimizedTextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.optimization.SelectiveTextEmbeddingProcessor;
import org.opensearch.transport.client.OpenSearchClient;

/**
Expand Down Expand Up @@ -53,9 +53,9 @@ public TextEmbeddingProcessorFactory(
protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) {
String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
Map<String, Object> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
boolean ignoreExisting = readBooleanProperty(TYPE, tag, config, IGNORE_EXISTING, DEFAULT_IGNORE_EXISTING);
if (ignoreExisting == true) {
return new OptimizedTextEmbeddingProcessor(
boolean skipExisting = readBooleanProperty(TYPE, tag, config, SKIP_EXISTING, DEFAULT_SKIP_EXISTING);
if (skipExisting == true) {
return new SelectiveTextEmbeddingProcessor(
tag,
description,
batchSize,
Expand Down
Loading

0 comments on commit 869d686

Please sign in to comment.