From 3421a67339e7043523b3a865c3dccce7fa0bec44 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Mon, 22 Apr 2024 17:18:47 -0700 Subject: [PATCH] Implement string based vector ingestion for improving the build time Signed-off-by: Navneet Verma --- .../knn/common/KNNValidationUtil.java | 18 ++++ .../org/opensearch/knn/index/KNNSettings.java | 8 +- .../org/opensearch/knn/index/VectorField.java | 1 - .../KNN80Codec/KNN80DocValuesConsumer.java | 8 +- .../knn/index/codec/util/KNNCodecUtil.java | 17 ++-- .../index/mapper/KNNVectorFieldMapper.java | 84 +++++++++++-------- .../knn/index/mapper/ParsedVector.java | 22 +++++ 7 files changed, 106 insertions(+), 52 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/mapper/ParsedVector.java diff --git a/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java b/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java index ca8e1459a..9fc18abe7 100644 --- a/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java +++ b/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java @@ -80,4 +80,22 @@ public static void validateVectorDimension(int dimension, int vectorSize) { throw new IllegalArgumentException(errorMessage); } } + + public static void validateVectorDimension(int dimension, String vectors) { + int index = 0; + int count = 0; + while (index != -1) { + index = vectors.indexOf(',', index + 1); // Slight improvement + if (index != -1) { + count++; + } + } + // ensuring we are counting floats + count++; + + if (dimension != count) { + String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, count); + throw new IllegalArgumentException(errorMessage); + } + } } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 442535550..2cbfc5b37 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -119,11 +119,7 @@ public class KNNSettings { Setting.Property.Deprecated ); - public static final Setting CREATE_GRAPHS = Setting.boolSetting( - "knn.create_graphs", - false, - NodeScope, Dynamic - ); + public static final Setting CREATE_GRAPHS = Setting.boolSetting("knn.create_graphs", false, NodeScope, Dynamic); /** * M - the number of bi-directional links created for every new element during construction. @@ -375,7 +371,7 @@ private Setting getSetting(String key) { return KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING; } - if("knn.create_graphs".equals(key)) { + if ("knn.create_graphs".equals(key)) { return CREATE_GRAPHS; } diff --git a/src/main/java/org/opensearch/knn/index/VectorField.java b/src/main/java/org/opensearch/knn/index/VectorField.java index e65201999..3770ac059 100644 --- a/src/main/java/org/opensearch/knn/index/VectorField.java +++ b/src/main/java/org/opensearch/knn/index/VectorField.java @@ -27,7 +27,6 @@ public VectorField(String name, float[] value, IndexableFieldType type) { public VectorField(String name, String vector, IndexableFieldType type) { super(name, new BytesRef(), type); try { - //final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getDefaultSerializer(); final byte[] floatToByte = vector.getBytes(); this.setBytesValue(floatToByte); } catch (Exception e) { diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 0402719db..9556343dd 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -83,7 +83,7 @@ public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) th stopWatch.stop(); long time_in_millis = stopWatch.totalTime().millis(); KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis); - logger.warn("Refresh operation complete in " + time_in_millis + " ms"); + logger.info("Refresh operation complete in " + time_in_millis + " ms"); } } @@ -106,14 +106,14 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) { public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) throws IOException { - if(KNNSettings.canCreateGraphs() == false) { + if (KNNSettings.canCreateGraphs() == false) { log.info("Not creating graphs as value is : {}", KNNSettings.canCreateGraphs()); return; } log.info("Creating graphs as value is : {}", KNNSettings.canCreateGraphs()); // Get values to be indexed BinaryDocValues values = valuesProducer.getBinary(field); - KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values, field.getAttribute("is_String")); + KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values); if (pair.getVectorAddress() == 0 || pair.docs.length == 0) { logger.info("Skipping engine index creation as there are no vectors or docs in the segment"); return; @@ -257,7 +257,7 @@ public void merge(MergeState mergeState) { stopWatch.stop(); long time_in_millis = stopWatch.totalTime().millis(); KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis); - logger.warn("Merge operation complete in " + time_in_millis + " ms"); + logger.info("Merge operation complete in " + time_in_millis + " ms"); } } } catch (Exception e) { diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 4abe3067e..721c6c135 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -46,7 +46,7 @@ public static final class Pair { } - public static KNNCodecUtil.Pair getFloats(BinaryDocValues values, String isString) throws IOException { + public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException { List vectorList = new ArrayList<>(); List docIdList = new ArrayList<>(); long vectorAddress = 0; @@ -63,23 +63,21 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values, String isStrin final float[] vector; StopWatch stopWatch = new StopWatch(); stopWatch.start(); - if("true".equals(isString)) { - String vectorString = new String(byteStream.readAllBytes()); + if (isVectorRepresentedAsString(bytesref)) { + String vectorString = new String(bytesref.bytes, 1, bytesref.bytes.length - 1); String[] array = vectorString.split(","); vector = new float[array.length]; for (int i = 0; i < array.length; i++) { vector[i] = Float.parseFloat(array[i]); } stopWatch.stop(); - log.info("Time taken to deserialize vector with string is : {} ms", - stopWatch.totalTime().millis()); + log.info("Time taken to deserialize vector with string is : {} ms", stopWatch.totalTime().millis()); } else { serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream); final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); vector = vectorSerializer.byteToFloatArray(byteStream); stopWatch.stop(); - log.info("Time taken to deserialize vector with float array is : {} ms", - stopWatch.totalTime().millis()); + log.info("Time taken to deserialize vector with float array is : {} ms", stopWatch.totalTime().millis()); } dimension = vector.length; @@ -146,4 +144,9 @@ private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) } return totalLiveDocs; } + + private static boolean isVectorRepresentedAsString(BytesRef bytesref) { + // Check if first bye is the special character that we have added or not. + return "$".equals(new String(bytesref.bytes, 0, 1)); + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 45f0372d0..a1a157099 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -52,15 +52,7 @@ import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Locale; -import java.util.Map; import java.util.Objects; -import java.util.Optional; -import java.util.function.Supplier; import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; @@ -563,15 +555,19 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s context.doc().add(point); addStoredFieldForVectorField(context, fieldType, name(), point); } else if (VectorDataType.FLOAT == vectorDataType) { - Optional floatsArrayOptional = getFloatsFromContextString(context, dimension, - methodComponentContext); + Optional optionalVector = getVectorFromContextString(context, dimension, methodComponentContext); - if (floatsArrayOptional.isEmpty()) { + if (optionalVector.isEmpty()) { return; } - //final float[] array = floatsArrayOptional.get(); - //spaceType.validateVector(array); - VectorField point = new VectorField(name(), floatsArrayOptional.get(), fieldType); + // final float[] array = floatsArrayOptional.get(); + // spaceType.validateVector(array); + VectorField point; + if (optionalVector.get().getVector() != null) { + point = new VectorField(name(), optionalVector.get().getVector(), fieldType); + } else { + point = new VectorField(name(), optionalVector.get().getVectorString(), fieldType); + } context.doc().add(point); addStoredFieldForVectorField(context, fieldType, name(), point); } else { @@ -737,9 +733,8 @@ Optional getFloatsFromContext(ParseContext context, int dimension, Meth return Optional.of(array); } - Optional getFloatsFromContextString(ParseContext context, int dimension, - MethodComponentContext methodComponentContext) - throws IOException { + Optional getVectorFromContextString(ParseContext context, int dimension, MethodComponentContext methodComponentContext) + throws IOException { context.path().add(simpleName()); // Returns an optional array of float values where each value in the vector is parsed as a float and validated @@ -750,14 +745,16 @@ Optional getFloatsFromContextString(ParseContext context, int dimension, boolean clipVectorValueToFP16RangeFlag = false; if (isFaissSQfp16Flag) { clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled( - (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER) + (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER) ); } - ArrayList vector = new ArrayList<>(); + float[] vector = new float[dimension]; + int vectorIndex = 0; XContentParser.Token token = context.parser().currentToken(); float value; - String vectors = null; + String vectors = "$"; + final ParsedVector.ParsedVectorBuilder parsedVector = ParsedVector.builder(); if (token == XContentParser.Token.START_ARRAY) { token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { @@ -771,8 +768,17 @@ Optional getFloatsFromContextString(ParseContext context, int dimension, } else { validateFloatVectorValue(value); } - - vector.add(value); + if (vectorIndex >= dimension) { + String errorMessage = String.format( + Locale.ROOT, + "Vector dimension mismatch. Expected: %d, Given " + "value is greater than : %d", + dimension, + dimension + ); + throw new IllegalArgumentException(errorMessage); + } + vector[vectorIndex] = value; + vectorIndex++; token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { @@ -786,22 +792,33 @@ Optional getFloatsFromContextString(ParseContext context, int dimension, } else { validateFloatVectorValue(value); } - vector.add(value); + if (vectorIndex >= dimension) { + String errorMessage = String.format( + Locale.ROOT, + "Vector dimension mismatch. Expected: %d, Given " + "value is greater than : %d", + dimension, + dimension + ); + throw new IllegalArgumentException(errorMessage); + } + vector[vectorIndex] = value; + vectorIndex++; context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { context.path().remove(); return Optional.empty(); - } else if(token == XContentParser.Token.VALUE_STRING) { - vectors = context.parser().text(); + } else if (token == XContentParser.Token.VALUE_STRING) { + // as vectors have only $ so it will be fast + vectors = vectors + context.parser().text(); } - //validateVectorDimension(dimension, vector.size()); - -// float[] array = new float[vector.size()]; -// int i = 0; -// for (Float f : vector) { -// array[i++] = f; -// } - return Optional.ofNullable(vectors); + if (vectors.equals("$")) { + validateVectorDimension(dimension, vectorIndex); + parsedVector.vector(vector); + } else { + validateVectorDimension(dimension, vectors); + parsedVector.vectorString(vectors); + } + return Optional.ofNullable(parsedVector.build()); } @Override @@ -850,7 +867,6 @@ public static class Defaults { FIELD_TYPE.setIndexOptions(IndexOptions.NONE); FIELD_TYPE.setDocValuesType(DocValuesType.BINARY); FIELD_TYPE.putAttribute(KNN_FIELD, "true"); // This attribute helps to determine knn field type - FIELD_TYPE.putAttribute("is_String", "true"); FIELD_TYPE.freeze(); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/ParsedVector.java b/src/main/java/org/opensearch/knn/index/mapper/ParsedVector.java new file mode 100644 index 000000000..62933b25b --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/ParsedVector.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.mapper; + +import lombok.Builder; +import lombok.Value; + +@Builder +@Value +public class ParsedVector { + float[] vector; + String vectorString; +}