Skip to content

Commit

Permalink
address casting issues
Browse files Browse the repository at this point in the history
Signed-off-by: will-hwang <sang7239@gmail.com>
  • Loading branch information
will-hwang committed Feb 28, 2025
1 parent 90851b6 commit 9bf5cce
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {

public abstract void doExecute(
IngestDocument ingestDocument,
Map<String, Object> ProcessMap,
Map<String, Object> processMap,
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
);
Expand Down Expand Up @@ -167,7 +167,7 @@ void preprocessIngestDocument(IngestDocument ingestDocument) {
* @param handler a callback handler to handle inference results which is a list of objects.
* @param onException an exception callback to handle exception.
*/
protected abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);
abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);

@Override
public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
Expand Down Expand Up @@ -591,21 +591,21 @@ private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceV
* This method invokes inference call through mlCommonsClientAccessor and populates retrieved embeddings to ingestDocument
*
* @param ingestDocument ingestDocument to populate embeddings to
* @param ProcessMap map indicating the path in ingestDocument to populate embeddings
* @param processMap map indicating the path in ingestDocument to populate embeddings
* @param inferenceList list of texts to run model inference on
* @param handler callback invoked with the updated ingestDocument on success, or with the exception on failure
*
*/
protected void makeInferenceCall(
IngestDocument ingestDocument,
Map<String, Object> ProcessMap,
Map<String, Object> processMap,
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentences(
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
setVectorFieldsToDocument(ingestDocument, processMap, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public final class TextEmbeddingProcessor extends InferenceProcessor {
public static final String TYPE = "text_embedding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
public static final String SKIP_EXISTING = "skip_existing";
public static final boolean DEFAULT_SKIP_EXISTING = Boolean.FALSE;
public static final boolean DEFAULT_SKIP_EXISTING = false;
private static final String INDEX_FIELD = "_index";
private static final String ID_FIELD = "_id";
private final OpenSearchClient openSearchClient;
Expand Down Expand Up @@ -62,44 +62,39 @@ public TextEmbeddingProcessor(
@Override
public void doExecute(
IngestDocument ingestDocument,
Map<String, Object> ProcessMap,
Map<String, Object> processMap,
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
) {
if (skipExisting) { // if skipExisting flag is turned on, eligible inference texts will be compared and filtered after embeddings
// have been copied
if (skipExisting) {
// if skipExisting flag is turned on, eligible inference texts will be compared and filtered after embeddings are copied
String index = ingestDocument.getSourceAndMetadata().get(INDEX_FIELD).toString();
String id = ingestDocument.getSourceAndMetadata().get(ID_FIELD).toString();
openSearchClient.execute(GetAction.INSTANCE, new GetRequest(index, id), ActionListener.wrap(response -> {
final Map<String, Object> existingDocument = response.getSourceAsMap();
if (existingDocument == null || existingDocument.isEmpty()) {
makeInferenceCall(ingestDocument, ProcessMap, inferenceList, handler);
makeInferenceCall(ingestDocument, processMap, inferenceList, handler);
} else {
// filter the given processMap by comparing the existing document with ingestDocument
Map<String, Object> filteredProcessMap = textEmbeddingInferenceFilter.filter(
existingDocument,
ingestDocument.getSourceAndMetadata(),
ProcessMap
processMap
);
// create the inference list based on the filtered processMap
List<String> filteredInferenceList = createInferenceList(filteredProcessMap).stream()
.filter(Objects::nonNull)
.collect(Collectors.toList());
if (!filteredInferenceList.isEmpty()) {
makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler);
} else {
if (filteredInferenceList.isEmpty()) {
handler.accept(ingestDocument, null);
} else {
makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler);
}
}
}, e -> { handler.accept(null, e); }));
} else { // skip existing flag is turned off. Call model inference without filtering
mlCommonsClientAccessor.inferenceSentences(
TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
} else {
// skip existing flag is turned off. Call model inference without filtering
makeInferenceCall(ingestDocument, processMap, inferenceList, handler);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ public abstract List<Object> filterInferenceValuesInList(
* @param existingSourceAndMetadataMap The metadata map of the existing document.
* @param sourceAndMetadataMap The metadata map of the new document.
* @param processMap The current map being processed.
*
* @return A filtered map containing only elements that require new embeddings.
*
*/
public Map<String, Object> filter(
Map<String, Object> existingSourceAndMetadataMap,
Expand Down Expand Up @@ -129,11 +127,8 @@ private Map<String, Object> filter(
}
Map<String, Object> filteredProcessMap = new HashMap<>();
Map<String, Object> castedProcessMap = ProcessorUtils.castToMap(processMap);
for (Map.Entry<?, ?> entry : castedProcessMap.entrySet()) {
if ((entry.getKey() instanceof String) == false) {
throw new IllegalArgumentException("key for processMap must be a string");
}
String key = (String) entry.getKey();
for (Map.Entry<String, Object> entry : castedProcessMap.entrySet()) {
String key = entry.getKey();
Object value = entry.getValue();
String currentPath = traversedPath.isEmpty() ? key : traversedPath + "." + key;
if (value instanceof Map<?, ?>) {
Expand All @@ -142,7 +137,7 @@ private Map<String, Object> filter(
} else if (value instanceof List) {
List<Object> processedList = filterListValue(
currentPath,
(List<Object>) value,
ProcessorUtils.castToObjectList(value),
sourceAndMetadataMap,
existingSourceAndMetadataMap
);
Expand Down Expand Up @@ -192,8 +187,8 @@ protected List<Object> filterListValue(
// return an empty list if processList and existingList are equal and the embeddings are copied; return processList otherwise
return filterInferenceValuesInList(
processList,
(List<Object>) existingList.get(),
(List<Object>) embeddingList.get(),
ProcessorUtils.castToObjectList(existingList.get()),
ProcessorUtils.castToObjectList(embeddingList.get()),
sourceAndMetadataMap,
existingSourceAndMetadataMap,
embeddingKey
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
Expand Down Expand Up @@ -43,23 +44,24 @@ public Object filterInferenceValue(
int index
) {
String textPath = reversedFieldMap.get(embeddingPath);
if (textPath == null) {
if (Objects.isNull(textPath)) {
return processValue;
}
Optional<Object> existingValue = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, textPath, index);
Optional<Object> embeddingValue = ProcessorUtils.getValueFromSource(existingSourceAndMetadataMap, embeddingPath, index);

if (existingValue.isPresent() && embeddingValue.isPresent() && existingValue.get().equals(processValue)) {
ProcessorUtils.setValueToSource(sourceAndMetadataMap, embeddingPath, embeddingValue.get(), index);
return null; // if successfully copied, return null to be filtered out from process map
// if successfully copied, return null to be filtered out from process map
return null;
}
return processValue; // processValue and existingValue are different, return processValue to be included in process map
// processValue and existingValue are different, return processValue to be included in process map
return processValue;
}

/**
* Filters List value by checking if the texts in list are identical in both the existing and new document.
* If lists are equal, the corresponding embeddings are copied
*
* @return empty list if embeddings are reused; the original list otherwise.
*/
@Override
Expand All @@ -73,8 +75,10 @@ public List<Object> filterInferenceValuesInList(
) {
if (processList.equals(existingList)) {
ProcessorUtils.setValueToSource(sourceAndMetadataMap, fullEmbeddingKey, embeddingList);
return Collections.emptyList(); // if successfully copied, return empty list to be filtered out from process map
// if successfully copied, return empty list to be filtered out from process map
return Collections.emptyList();
}
return processList; // source list and existing list are different, return processList to be included in process map
// source list and existing list are different, return processList to be included in process map
return processList;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Stack;

Expand Down Expand Up @@ -117,7 +118,7 @@ public static void removeTargetFieldFromSource(final Map<String, Object> sourceA
if (key.equals(lastKey)) {
break;
}
currentMap = (Map<String, Object>) currentMap.get(key);
currentMap = castToMap(currentMap.get(key));
}

// Remove the last key this is guaranteed
Expand All @@ -132,8 +133,7 @@ public static void removeTargetFieldFromSource(final Map<String, Object> sourceA
parentMap = currentParentMapWithChild.v1();
key = currentParentMapWithChild.v2();

@SuppressWarnings("unchecked")
Map<String, Object> innerMap = (Map<String, Object>) parentMap.get(key);
Map<String, Object> innerMap = castToMap(parentMap.get(key));

if (innerMap != null && innerMap.isEmpty()) {
parentMap.remove(key);
Expand Down Expand Up @@ -167,13 +167,13 @@ public static Optional<Object> getValueFromSource(final Map<String, Object> sour
for (String key : keys) {
currentValue = currentValue.flatMap(value -> {
if (value instanceof ArrayList<?> && index != -1) {
Object listValue = ((ArrayList) value).get(index);
Object listValue = (castToObjectList(value)).get(index);
if (listValue instanceof Map) {
Map<String, Object> currentMap = (Map<String, Object>) listValue;
Map<String, Object> currentMap = castToMap(listValue);
return Optional.ofNullable(currentMap.get(key));
}
} else if (value instanceof Map<?, ?>) {
Map<String, Object> currentMap = (Map<String, Object>) value;
Map<String, Object> currentMap = castToMap(value);
return Optional.ofNullable(currentMap.get(key));
}
return Optional.empty();
Expand Down Expand Up @@ -207,7 +207,7 @@ public static void setValueToSource(Map<String, Object> sourceAsMap, String targ
*/

public static void setValueToSource(Map<String, Object> sourceAsMap, String targetKey, Object targetValue, int index) {
if (sourceAsMap == null || targetKey == null) return;
if (Objects.isNull(sourceAsMap) || Objects.isNull(targetKey)) return;

String[] keys = targetKey.split("\\.");
Map<String, Object> current = sourceAsMap;
Expand All @@ -217,10 +217,10 @@ public static void setValueToSource(Map<String, Object> sourceAsMap, String targ
if (next instanceof ArrayList<?> list) {
if (index < 0 || index >= list.size()) return;
if (list.get(index) instanceof Map) {
current = (Map<String, Object>) list.get(index);
current = castToMap(list.get(index));
}
} else if (next instanceof Map) {
current = (Map<String, Object>) next;
current = castToMap(next);
} else {
throw new IllegalStateException("Unexpected data structure at " + keys[i]);
}
Expand Down Expand Up @@ -274,4 +274,11 @@ public static boolean isNumeric(Object value) {
public static Map<String, Object> castToMap(Object obj) {
return (Map<String, Object>) obj;
}

/**
 * Views the given object as a {@code List<Object>} through an unchecked cast.
 * Element types are not verified, so use this only when you are certain the
 * object is a {@code List<Object>}, and treat it as a last resort.
 *
 * @param obj the object to view as a list
 * @return the same reference, typed as {@code List<Object>}
 */
@SuppressWarnings("unchecked")
public static List<Object> castToObjectList(Object obj) {
    final List<Object> objectList = (List<Object>) obj;
    return objectList;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,23 @@ public void testFlattenAndFlip_withMultipleLevelsWithNestedMaps_thenSuccess() {
Map<String, String> actual = ProcessorDocumentUtils.flattenAndFlip(nestedMap);
assertEquals(expected, actual);
}

/**
 * Checks that unflattenJson expands the dotted key "b.c" inside every map of a list value
 * into a nested map, leaving the plain key "f" untouched.
 */
public void testUnflatten_withListOfObject_thenSuccess() {
    // Input: a list whose elements still carry the flattened "b.c" key
    Map<String, Object> firstFlat = Map.of("b.c", "d", "f", "h");
    Map<String, Object> secondFlat = Map.of("b.c", "e", "f", "i");
    List<Map<String, Object>> flatElements = Arrays.asList(firstFlat, secondFlat);
    Map<String, Object> input = Map.of("a", flatElements);

    // Expected: the same elements with "b.c" expanded to b -> { c -> value }
    Map<String, Object> firstNested = Map.of("b", Map.of("c", "d"), "f", "h");
    Map<String, Object> secondNested = Map.of("b", Map.of("c", "e"), "f", "i");
    List<Map<String, Object>> nestedElements = Arrays.asList(firstNested, secondNested);
    Map<String, Object> expected = Map.of("a", nestedElements);

    Map<String, Object> result = ProcessorDocumentUtils.unflattenJson(input);
    assertEquals(expected, result);
}
}

0 comments on commit 9bf5cce

Please sign in to comment.