diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index b4ada3925..7d0b683b4 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -119,7 +119,7 @@ private void validateEmbeddingConfiguration(Map fieldMap) { public abstract void doExecute( IngestDocument ingestDocument, - Map ProcessMap, + Map processMap, List inferenceList, BiConsumer handler ); @@ -167,7 +167,7 @@ void preprocessIngestDocument(IngestDocument ingestDocument) { * @param handler a callback handler to handle inference results which is a list of objects. * @param onException an exception callback to handle exception. */ - protected abstract void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException); + abstract void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException); @Override public void subBatchExecute(List ingestDocumentWrappers, Consumer> handler) { @@ -591,21 +591,21 @@ private List> buildNLPResultForListType(List sourceV * This method invokes inference call through mlCommonsClientAccessor and populates retrieved embeddings to ingestDocument * * @param ingestDocument ingestDocument to populate embeddings to - * @param ProcessMap map indicating the path in ingestDocument to populate embeddings + * @param processMap map indicating the path in ingestDocument to populate embeddings * @param inferenceList list of texts to be model inference * @param handler SourceAndMetadataMap of ingestDocument Document * */ protected void makeInferenceCall( IngestDocument ingestDocument, - Map ProcessMap, + Map processMap, List inferenceList, BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentences( TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(), ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); + setVectorFieldsToDocument(ingestDocument, processMap, vectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); }) ); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 62acdf2b1..0bddfde17 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -33,7 +33,7 @@ public final class TextEmbeddingProcessor extends InferenceProcessor { public static final String TYPE = "text_embedding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; public static final String SKIP_EXISTING = "skip_existing"; - public static final boolean DEFAULT_SKIP_EXISTING = Boolean.FALSE; + public static final boolean DEFAULT_SKIP_EXISTING = false; private static final String INDEX_FIELD = "_index"; private static final String ID_FIELD = "_id"; private final OpenSearchClient openSearchClient; @@ -62,44 +62,39 @@ public TextEmbeddingProcessor( @Override public void doExecute( IngestDocument ingestDocument, - Map ProcessMap, + Map processMap, List inferenceList, BiConsumer handler ) { - if (skipExisting) { // if skipExisting flag is turned on, eligible inference texts will be compared and filtered after embeddings - // have been copied + if (skipExisting) { + // if skipExisting flag is turned on, eligible inference texts will be compared and filtered after embeddings are copied String index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD).toString(); String id = ingestDocument.getSourceAndMetadata().get(ID_FIELD).toString(); openSearchClient.execute(GetAction.INSTANCE, new GetRequest(index, id), ActionListener.wrap(response -> { final Map existingDocument = response.getSourceAsMap(); if (existingDocument == null || existingDocument.isEmpty()) { - makeInferenceCall(ingestDocument, ProcessMap, inferenceList, handler); + makeInferenceCall(ingestDocument, processMap, inferenceList, handler); } else { // filter given ProcessMap by comparing existing document with ingestDocument Map filteredProcessMap = textEmbeddingInferenceFilter.filter( existingDocument, ingestDocument.getSourceAndMetadata(), - ProcessMap + processMap ); // create inference list based on filtered ProcessMap List filteredInferenceList = createInferenceList(filteredProcessMap).stream() .filter(Objects::nonNull) .collect(Collectors.toList()); - if (!filteredInferenceList.isEmpty()) { - makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler); - } else { + if (filteredInferenceList.isEmpty()) { handler.accept(ingestDocument, null); + } else { + makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler); } } }, e -> { handler.accept(null, e); })); - } else { // skip existing flag is turned off. Call model inference without filtering - mlCommonsClientAccessor.inferenceSentences( - TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(), - ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); }) - ); + } else { + // skip existing flag is turned off. Call model inference without filtering + makeInferenceCall(ingestDocument, processMap, inferenceList, handler); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/optimization/InferenceFilter.java b/src/main/java/org/opensearch/neuralsearch/processor/optimization/InferenceFilter.java index 31b8ad984..bd9765e7b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/optimization/InferenceFilter.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/optimization/InferenceFilter.java @@ -92,9 +92,7 @@ public abstract List filterInferenceValuesInList( * @param existingSourceAndMetadataMap The metadata map of the existing document. * @param sourceAndMetadataMap The metadata map of the new document. * @param processMap The current map being processed. - * * @return A filtered map containing only elements that require new embeddings. - * */ public Map filter( Map existingSourceAndMetadataMap, @@ -129,11 +127,8 @@ private Map filter( } Map filteredProcessMap = new HashMap<>(); Map castedProcessMap = ProcessorUtils.castToMap(processMap); - for (Map.Entry entry : castedProcessMap.entrySet()) { - if ((entry.getKey() instanceof String) == false) { - throw new IllegalArgumentException("key for processMap must be a string"); - } - String key = (String) entry.getKey(); + for (Map.Entry entry : castedProcessMap.entrySet()) { + String key = entry.getKey(); Object value = entry.getValue(); String currentPath = traversedPath.isEmpty() ? key : traversedPath + "." + key; if (value instanceof Map) { @@ -142,7 +137,7 @@ private Map filter( } else if (value instanceof List) { List processedList = filterListValue( currentPath, - (List) value, + ProcessorUtils.castToObjectList(value), sourceAndMetadataMap, existingSourceAndMetadataMap ); @@ -192,8 +187,8 @@ protected List filterListValue( // return empty list if processList and existingList are equal and embeddings are copied, return empty list otherwise return filterInferenceValuesInList( processList, - (List) existingList.get(), - (List) embeddingList.get(), + ProcessorUtils.castToObjectList(existingList.get()), + ProcessorUtils.castToObjectList(embeddingList.get()), sourceAndMetadataMap, existingSourceAndMetadataMap, embeddingKey diff --git a/src/main/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilter.java b/src/main/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilter.java index e9d50cef8..d82f6ac21 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilter.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/optimization/TextEmbeddingInferenceFilter.java @@ -10,6 +10,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; /** @@ -43,7 +44,7 @@ public Object filterInferenceValue( int index ) { String textPath = reversedFieldMap.get(embeddingPath); - if (textPath == null) { + if (Objects.isNull(textPath)) { return processValue; } Optional existingValue = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, textPath, index); @@ -51,15 +52,16 @@ public Object filterInferenceValue( if (existingValue.isPresent() && embeddingValue.isPresent() && existingValue.get().equals(processValue)) { ProcessorUtils.setValueToSource(sourceAndMetadataMap, embeddingPath, embeddingValue.get(), index); - return null; // if successfully copied, return null to be filtered out from process map + // if successfully copied, return null to be filtered out from process map + return null; } - return processValue; // processValue and existingValue are different, return processValue to be included in process map + // processValue and existingValue are different, return processValue to be included in process map + return processValue; } /** * Filters List value by checking if the texts in list are identical in both the existing and new document. * If lists are equal, the corresponding embeddings are copied - * * @return empty list if embeddings are reused; the original list otherwise. */ @Override @@ -73,8 +75,10 @@ public List filterInferenceValuesInList( ) { if (processList.equals(existingList)) { ProcessorUtils.setValueToSource(sourceAndMetadataMap, fullEmbeddingKey, embeddingList); - return Collections.emptyList(); // if successfully copied, return empty list to be filtered out from process map + // if successfully copied, return empty list to be filtered out from process map + return Collections.emptyList(); } - return processList; // source list and existing list are different, return processList to be included in process map + // source list and existing list are different, return processList to be included in process map + return processList; } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java index 6823990a6..e4c861f09 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java @@ -12,6 +12,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Stack; @@ -117,7 +118,7 @@ public static void removeTargetFieldFromSource(final Map sourceA if (key.equals(lastKey)) { break; } - currentMap = (Map) currentMap.get(key); + currentMap = castToMap(currentMap.get(key)); } // Remove the last key this is guaranteed @@ -132,8 +133,7 @@ public static void removeTargetFieldFromSource(final Map sourceA parentMap = currentParentMapWithChild.v1(); key = currentParentMapWithChild.v2(); - @SuppressWarnings("unchecked") - Map innerMap = (Map) parentMap.get(key); + Map innerMap = castToMap(parentMap.get(key)); if (innerMap != null && innerMap.isEmpty()) { parentMap.remove(key); @@ -167,13 +167,13 @@ public static Optional getValueFromSource(final Map sour for (String key : keys) { currentValue = currentValue.flatMap(value -> { if (value instanceof ArrayList && index != -1) { - Object listValue = ((ArrayList) value).get(index); + Object listValue = (castToObjectList(value)).get(index); if (listValue instanceof Map) { - Map currentMap = (Map) listValue; + Map currentMap = castToMap(listValue); return Optional.ofNullable(currentMap.get(key)); } } else if (value instanceof Map) { - Map currentMap = (Map) value; + Map currentMap = castToMap(value); return Optional.ofNullable(currentMap.get(key)); } return Optional.empty(); @@ -207,7 +207,7 @@ public static void setValueToSource(Map sourceAsMap, String targ */ public static void setValueToSource(Map sourceAsMap, String targetKey, Object targetValue, int index) { - if (sourceAsMap == null || targetKey == null) return; + if (Objects.isNull(sourceAsMap) || Objects.isNull(targetKey)) return; String[] keys = targetKey.split("\\."); Map current = sourceAsMap; @@ -217,10 +217,10 @@ public static void setValueToSource(Map sourceAsMap, String targ if (next instanceof ArrayList list) { if (index < 0 || index >= list.size()) return; if (list.get(index) instanceof Map) { - current = (Map) list.get(index); + current = castToMap(list.get(index)); } } else if (next instanceof Map) { - current = (Map) next; + current = castToMap(next); } else { throw new IllegalStateException("Unexpected data structure at " + keys[i]); } @@ -274,4 +274,11 @@ public static boolean isNumeric(Object value) { public static Map castToMap(Object obj) { return (Map) obj; } + + // This method should be used only when you are certain the object is a `List`. + // It is recommended to use this method as a last resort. + @SuppressWarnings("unchecked") + public static List castToObjectList(Object obj) { + return (List) obj; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtilsTests.java index 075cbdf19..ea6e07c00 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtilsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/ProcessorDocumentUtilsTests.java @@ -271,4 +271,23 @@ public void testFlattenAndFlip_withMultipleLevelsWithNestedMaps_thenSuccess() { Map actual = ProcessorDocumentUtils.flattenAndFlip(nestedMap); assertEquals(expected, actual); } + + public void testUnflatten_withListOfObject_thenSuccess() { + Map map1 = Map.of("b.c", "d", "f", "h"); + Map map2 = Map.of("b.c", "e", "f", "i"); + List> list = Arrays.asList(map1, map2); + Map input = Map.of("a", list); + + Map nestedB1 = Map.of("c", "d"); + Map expectedMap1 = Map.of("b", nestedB1, "f", "h"); + Map nestedB2 = Map.of("c", "e"); + Map expectedMap2 = Map.of("b", nestedB2, "f", "i"); + + List> expectedList = Arrays.asList(expectedMap1, expectedMap2); + + Map expected = Map.of("a", expectedList); + + Map result = ProcessorDocumentUtils.unflattenJson(input); + assertEquals(expected, result); + } }