Skip to content

Commit

Permalink
implement single document update scenario in text embedding processor
Browse files Browse the repository at this point in the history
Signed-off-by: will-hwang <sang7239@gmail.com>
  • Loading branch information
will-hwang committed Feb 28, 2025
1 parent 628cb64 commit fc83219
Show file tree
Hide file tree
Showing 24 changed files with 2,280 additions and 132 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
### Features
- Add Optimized Text Embedding Processor ([#1191](https://github.com/opensearch-project/neural-search/pull/1191))
### Enhancements
- Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
- Support different embedding types in model's response ([#1007](https://github.com/opensearch-project/neural-search/pull/1007))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,12 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Map.of(
TextEmbeddingProcessor.TYPE,
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
new TextEmbeddingProcessorFactory(
parameters.client,
clientAccessor,
parameters.env,
parameters.ingestService.getClusterService()
),
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
TextImageEmbeddingProcessor.TYPE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
Expand Down Expand Up @@ -118,7 +119,7 @@ private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {

/**
 * Performs model inference for the given document and writes the resulting
 * embeddings back into it.
 *
 * @param ingestDocument document being ingested
 * @param processMap map describing the paths in the document to populate with embeddings
 * @param inferenceList texts to send to the model for inference
 * @param handler callback invoked with the updated document on success, or the failure cause
 */
public abstract void doExecute(
    IngestDocument ingestDocument,
    Map<String, Object> processMap,
    List<String> inferenceList,
    BiConsumer<IngestDocument, Exception> handler
);
Expand Down Expand Up @@ -278,7 +279,7 @@ private static class DataForInference {
}

@SuppressWarnings({ "unchecked" })
private List<String> createInferenceList(Map<String, Object> knnKeyMap) {
protected List<String> createInferenceList(Map<String, Object> knnKeyMap) {
List<String> texts = new ArrayList<>();
knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> {
Object sourceValue = knnMapEntry.getValue();
Expand Down Expand Up @@ -579,11 +580,37 @@ private Map<String, Object> getSourceMapBySourceAndMetadataMap(String processorK

private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results, IndexWrapper indexWrapper) {
List<Map<String, Object>> keyToResult = new ArrayList<>();
IntStream.range(0, sourceValue.size())
sourceValue.stream()
.filter(Objects::nonNull) // explicit null check is required since sourceValue can contain null values in cases where
// sourceValue has been filtered
.forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
return keyToResult;
}

/**
 * Runs model inference through the ML Commons client accessor and writes the
 * returned embeddings into the ingest document at the paths described by processMap.
 *
 * @param ingestDocument document to populate with embeddings
 * @param processMap map indicating the paths in ingestDocument to populate embeddings
 * @param inferenceList texts sent to the model for inference
 * @param handler callback invoked with the updated document on success, or with the failure cause
 */
protected void makeInferenceCall(
    IngestDocument ingestDocument,
    Map<String, Object> processMap,
    List<String> inferenceList,
    BiConsumer<IngestDocument, Exception> handler
) {
    final TextInferenceRequest inferenceRequest = TextInferenceRequest.builder()
        .modelId(this.modelId)
        .inputTexts(inferenceList)
        .build();
    mlCommonsClientAccessor.inferenceSentences(inferenceRequest, ActionListener.wrap(embeddings -> {
        setVectorFieldsToDocument(ingestDocument, processMap, embeddings);
        handler.accept(ingestDocument, null);
    }, exception -> handler.accept(null, exception)));
}

@Override
public String getType() {
return type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,22 @@

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import org.opensearch.action.get.GetAction;
import org.opensearch.action.get.GetRequest;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.optimization.TextEmbeddingInferenceFilter;
import org.opensearch.transport.client.OpenSearchClient;

/**
* This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use,
Expand All @@ -26,34 +32,70 @@ public final class TextEmbeddingProcessor extends InferenceProcessor {

public static final String TYPE = "text_embedding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
public static final String SKIP_EXISTING = "skip_existing";
public static final boolean DEFAULT_SKIP_EXISTING = false;
private static final String INDEX_FIELD = "_index";
private static final String ID_FIELD = "_id";
private final OpenSearchClient openSearchClient;
private final boolean skipExisting;
private final TextEmbeddingInferenceFilter textEmbeddingInferenceFilter;

/**
 * Creates a text embedding processor.
 *
 * @param tag processor tag from the pipeline definition
 * @param description processor description from the pipeline definition
 * @param batchSize number of documents processed per batch
 * @param modelId id of the model used for inference
 * @param fieldMap mapping from source fields to embedding destination fields
 * @param skipExisting when true, reuse embeddings from the existing document and only infer changed texts
 * @param textEmbeddingInferenceFilter filter that compares the existing document against the incoming one
 * @param openSearchClient client used to fetch the existing document for the skip-existing path
 * @param clientAccessor accessor for ML Commons inference calls
 * @param environment node environment
 * @param clusterService cluster service
 */
public TextEmbeddingProcessor(
    String tag,
    String description,
    int batchSize,
    String modelId,
    Map<String, Object> fieldMap,
    boolean skipExisting,
    TextEmbeddingInferenceFilter textEmbeddingInferenceFilter,
    OpenSearchClient openSearchClient,
    MLCommonsClientAccessor clientAccessor,
    Environment environment,
    ClusterService clusterService
) {
    super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
    this.openSearchClient = openSearchClient;
    this.skipExisting = skipExisting;
    this.textEmbeddingInferenceFilter = textEmbeddingInferenceFilter;
}

@Override
public void doExecute(
    IngestDocument ingestDocument,
    Map<String, Object> processMap,
    List<String> inferenceList,
    BiConsumer<IngestDocument, Exception> handler
) {
    // skip_existing flag is turned off: call model inference without filtering
    if (skipExisting == false) {
        makeInferenceCall(ingestDocument, processMap, inferenceList, handler);
        return;
    }
    // skip_existing is on: fetch the stored document, copy reusable embeddings, and only
    // run inference on texts that actually changed.
    // NOTE(review): assumes _index and _id are present in source-and-metadata — a missing
    // value would throw NPE on toString(); confirm upstream guarantees.
    String index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD).toString();
    String id = ingestDocument.getSourceAndMetadata().get(ID_FIELD).toString();
    openSearchClient.execute(GetAction.INSTANCE, new GetRequest(index, id), ActionListener.wrap(response -> {
        final Map<String, Object> existingDocument = response.getSourceAsMap();
        if (existingDocument == null || existingDocument.isEmpty()) {
            // no prior version of the document — nothing to reuse, infer everything
            makeInferenceCall(ingestDocument, processMap, inferenceList, handler);
        } else {
            // filter given processMap by comparing existing document with ingestDocument;
            // matching embeddings are copied over inside the filter
            Map<String, Object> filteredProcessMap = textEmbeddingInferenceFilter.filter(
                existingDocument,
                ingestDocument.getSourceAndMetadata(),
                processMap
            );
            // create inference list based on filtered processMap
            List<String> filteredInferenceList = createInferenceList(filteredProcessMap).stream()
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
            if (filteredInferenceList.isEmpty()) {
                // every embedding was reused — no model call needed
                handler.accept(ingestDocument, null);
            } else {
                makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler);
            }
        }
    }, e -> handler.accept(null, e)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
*/
package org.opensearch.neuralsearch.processor.factory;

import static org.opensearch.ingest.ConfigurationUtils.readBooleanProperty;
import static org.opensearch.ingest.ConfigurationUtils.readMap;
import static org.opensearch.ingest.ConfigurationUtils.readStringProperty;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.SKIP_EXISTING;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.DEFAULT_SKIP_EXISTING;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD;
Expand All @@ -17,24 +20,30 @@
import org.opensearch.ingest.AbstractBatchingProcessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.optimization.TextEmbeddingInferenceFilter;
import org.opensearch.transport.client.OpenSearchClient;

/**
* Factory for text embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
*/
public final class TextEmbeddingProcessorFactory extends AbstractBatchingProcessor.Factory {

private final OpenSearchClient openSearchClient;

private final MLCommonsClientAccessor clientAccessor;

private final Environment environment;

private final ClusterService clusterService;

public TextEmbeddingProcessorFactory(
final OpenSearchClient openSearchClient,
final MLCommonsClientAccessor clientAccessor,
final Environment environment,
final ClusterService clusterService
) {
super(TYPE);
this.openSearchClient = openSearchClient;
this.clientAccessor = clientAccessor;
this.environment = environment;
this.clusterService = clusterService;
Expand All @@ -43,7 +52,21 @@ public TextEmbeddingProcessorFactory(
@Override
protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) {
    // required configuration
    String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
    Map<String, Object> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
    // optional: when true, embeddings already present on the stored document are reused
    boolean skipExisting = readBooleanProperty(TYPE, tag, config, SKIP_EXISTING, DEFAULT_SKIP_EXISTING);
    TextEmbeddingInferenceFilter textEmbeddingInferenceFilter = new TextEmbeddingInferenceFilter(fieldMap);
    return new TextEmbeddingProcessor(
        tag,
        description,
        batchSize,
        modelId,
        fieldMap,
        skipExisting,
        textEmbeddingInferenceFilter,
        openSearchClient,
        clientAccessor,
        environment,
        clusterService
    );
}
}
Loading

0 comments on commit fc83219

Please sign in to comment.