Skip to content

Commit

Permalink
Implement string-based vector ingestion to improve build time
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Apr 23, 2024
1 parent a84e13a commit 3421a67
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 52 deletions.
18 changes: 18 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNValidationUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,22 @@ public static void validateVectorDimension(int dimension, int vectorSize) {
throw new IllegalArgumentException(errorMessage);
}
}

/**
 * Validates that a comma-separated vector string has exactly {@code dimension} components.
 *
 * <p>The component count is the number of {@code ','} characters plus one. The text between
 * separators is not parsed or inspected here, so any leading marker character that is not a
 * comma does not affect the count.
 *
 * @param dimension expected number of components
 * @param vectors comma-separated vector text
 * @throws IllegalArgumentException if the number of components differs from {@code dimension}
 * @throws NullPointerException if {@code vectors} is {@code null}
 */
public static void validateVectorDimension(int dimension, String vectors) {
    // N separators delimit N + 1 components, so start the count at one.
    int count = 1;
    // Scan from the very first character so that a separator at position 0 is counted as well.
    for (int index = vectors.indexOf(','); index != -1; index = vectors.indexOf(',', index + 1)) {
        count++;
    }

    if (dimension != count) {
        String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, count);
        throw new IllegalArgumentException(errorMessage);
    }
}
}
8 changes: 2 additions & 6 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,7 @@ public class KNNSettings {
Setting.Property.Deprecated
);

public static final Setting<Boolean> CREATE_GRAPHS = Setting.boolSetting(
"knn.create_graphs",
false,
NodeScope, Dynamic
);
public static final Setting<Boolean> CREATE_GRAPHS = Setting.boolSetting("knn.create_graphs", false, NodeScope, Dynamic);

/**
* M - the number of bi-directional links created for every new element during construction.
Expand Down Expand Up @@ -375,7 +371,7 @@ private Setting<?> getSetting(String key) {
return KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING;
}

if("knn.create_graphs".equals(key)) {
if ("knn.create_graphs".equals(key)) {
return CREATE_GRAPHS;
}

Expand Down
1 change: 0 additions & 1 deletion src/main/java/org/opensearch/knn/index/VectorField.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ public VectorField(String name, float[] value, IndexableFieldType type) {
public VectorField(String name, String vector, IndexableFieldType type) {
super(name, new BytesRef(), type);
try {
//final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getDefaultSerializer();
final byte[] floatToByte = vector.getBytes();
this.setBytesValue(floatToByte);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) th
stopWatch.stop();
long time_in_millis = stopWatch.totalTime().millis();
KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis);
logger.warn("Refresh operation complete in " + time_in_millis + " ms");
logger.info("Refresh operation complete in " + time_in_millis + " ms");
}
}

Expand All @@ -106,14 +106,14 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) {

public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh)
throws IOException {
if(KNNSettings.canCreateGraphs() == false) {
if (KNNSettings.canCreateGraphs() == false) {
log.info("Not creating graphs as value is : {}", KNNSettings.canCreateGraphs());
return;
}
log.info("Creating graphs as value is : {}", KNNSettings.canCreateGraphs());
// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values, field.getAttribute("is_String"));
KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
if (pair.getVectorAddress() == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the segment");
return;
Expand Down Expand Up @@ -257,7 +257,7 @@ public void merge(MergeState mergeState) {
stopWatch.stop();
long time_in_millis = stopWatch.totalTime().millis();
KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis);
logger.warn("Merge operation complete in " + time_in_millis + " ms");
logger.info("Merge operation complete in " + time_in_millis + " ms");
}
}
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public static final class Pair {

}

public static KNNCodecUtil.Pair getFloats(BinaryDocValues values, String isString) throws IOException {
public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException {
List<float[]> vectorList = new ArrayList<>();
List<Integer> docIdList = new ArrayList<>();
long vectorAddress = 0;
Expand All @@ -63,23 +63,21 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values, String isStrin
final float[] vector;
StopWatch stopWatch = new StopWatch();
stopWatch.start();
if("true".equals(isString)) {
String vectorString = new String(byteStream.readAllBytes());
if (isVectorRepresentedAsString(bytesref)) {
String vectorString = new String(bytesref.bytes, 1, bytesref.bytes.length - 1);
String[] array = vectorString.split(",");
vector = new float[array.length];
for (int i = 0; i < array.length; i++) {
vector[i] = Float.parseFloat(array[i]);
}
stopWatch.stop();
log.info("Time taken to deserialize vector with string is : {} ms",
stopWatch.totalTime().millis());
log.info("Time taken to deserialize vector with string is : {} ms", stopWatch.totalTime().millis());
} else {
serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
vector = vectorSerializer.byteToFloatArray(byteStream);
stopWatch.stop();
log.info("Time taken to deserialize vector with float array is : {} ms",
stopWatch.totalTime().millis());
log.info("Time taken to deserialize vector with float array is : {} ms", stopWatch.totalTime().millis());
}
dimension = vector.length;

Expand Down Expand Up @@ -146,4 +144,9 @@ private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues)
}
return totalLiveDocs;
}

/**
 * Tells whether the given bytes start with the {@code '$'} marker byte.
 *
 * <p>NOTE(review): the marker appears to be the {@code "$"} prefix that the field mapper puts in
 * front of string-encoded vectors; confirm against the writer side.
 *
 * @param bytesref the stored value to inspect
 * @return {@code true} if the first valid byte is {@code '$'}; {@code false} otherwise,
 *         including when the reference holds no bytes
 */
private static boolean isVectorRepresentedAsString(BytesRef bytesref) {
    // Compare the raw byte rather than decoding a String: this avoids an allocation per document,
    // does not depend on the platform default charset, and honours the reference's offset.
    return bytesref.length > 0 && bytesref.bytes[bytesref.offset] == (byte) '$';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,7 @@
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
Expand Down Expand Up @@ -563,15 +555,19 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point);
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<String> floatsArrayOptional = getFloatsFromContextString(context, dimension,
methodComponentContext);
Optional<ParsedVector> optionalVector = getVectorFromContextString(context, dimension, methodComponentContext);

if (floatsArrayOptional.isEmpty()) {
if (optionalVector.isEmpty()) {
return;
}
//final float[] array = floatsArrayOptional.get();
//spaceType.validateVector(array);
VectorField point = new VectorField(name(), floatsArrayOptional.get(), fieldType);
// final float[] array = floatsArrayOptional.get();
// spaceType.validateVector(array);
VectorField point;
if (optionalVector.get().getVector() != null) {
point = new VectorField(name(), optionalVector.get().getVector(), fieldType);
} else {
point = new VectorField(name(), optionalVector.get().getVectorString(), fieldType);
}
context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point);
} else {
Expand Down Expand Up @@ -737,9 +733,8 @@ Optional<float[]> getFloatsFromContext(ParseContext context, int dimension, Meth
return Optional.of(array);
}

Optional<String> getFloatsFromContextString(ParseContext context, int dimension,
MethodComponentContext methodComponentContext)
throws IOException {
Optional<ParsedVector> getVectorFromContextString(ParseContext context, int dimension, MethodComponentContext methodComponentContext)
throws IOException {
context.path().add(simpleName());

// Returns an optional array of float values where each value in the vector is parsed as a float and validated
Expand All @@ -750,14 +745,16 @@ Optional<String> getFloatsFromContextString(ParseContext context, int dimension,
boolean clipVectorValueToFP16RangeFlag = false;
if (isFaissSQfp16Flag) {
clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled(
(MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER)
(MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER)
);
}

ArrayList<Float> vector = new ArrayList<>();
float[] vector = new float[dimension];
int vectorIndex = 0;
XContentParser.Token token = context.parser().currentToken();
float value;
String vectors = null;
String vectors = "$";
final ParsedVector.ParsedVectorBuilder parsedVector = ParsedVector.builder();
if (token == XContentParser.Token.START_ARRAY) {
token = context.parser().nextToken();
while (token != XContentParser.Token.END_ARRAY) {
Expand All @@ -771,8 +768,17 @@ Optional<String> getFloatsFromContextString(ParseContext context, int dimension,
} else {
validateFloatVectorValue(value);
}

vector.add(value);
if (vectorIndex >= dimension) {
String errorMessage = String.format(
Locale.ROOT,
"Vector dimension mismatch. Expected: %d, Given " + "value is greater than : %d",
dimension,
dimension
);
throw new IllegalArgumentException(errorMessage);
}
vector[vectorIndex] = value;
vectorIndex++;
token = context.parser().nextToken();
}
} else if (token == XContentParser.Token.VALUE_NUMBER) {
Expand All @@ -786,22 +792,33 @@ Optional<String> getFloatsFromContextString(ParseContext context, int dimension,
} else {
validateFloatVectorValue(value);
}
vector.add(value);
if (vectorIndex >= dimension) {
String errorMessage = String.format(
Locale.ROOT,
"Vector dimension mismatch. Expected: %d, Given " + "value is greater than : %d",
dimension,
dimension
);
throw new IllegalArgumentException(errorMessage);
}
vector[vectorIndex] = value;
vectorIndex++;
context.parser().nextToken();
} else if (token == XContentParser.Token.VALUE_NULL) {
context.path().remove();
return Optional.empty();
} else if(token == XContentParser.Token.VALUE_STRING) {
vectors = context.parser().text();
} else if (token == XContentParser.Token.VALUE_STRING) {
// `vectors` holds only the "$" marker at this point, so this concatenation is cheap
vectors = vectors + context.parser().text();
}
//validateVectorDimension(dimension, vector.size());

// float[] array = new float[vector.size()];
// int i = 0;
// for (Float f : vector) {
// array[i++] = f;
// }
return Optional.ofNullable(vectors);
if (vectors.equals("$")) {
validateVectorDimension(dimension, vectorIndex);
parsedVector.vector(vector);
} else {
validateVectorDimension(dimension, vectors);
parsedVector.vectorString(vectors);
}
return Optional.ofNullable(parsedVector.build());
}

@Override
Expand Down Expand Up @@ -850,7 +867,6 @@ public static class Defaults {
FIELD_TYPE.setIndexOptions(IndexOptions.NONE);
FIELD_TYPE.setDocValuesType(DocValuesType.BINARY);
FIELD_TYPE.putAttribute(KNN_FIELD, "true"); // This attribute helps to determine knn field type
FIELD_TYPE.putAttribute("is_String", "true");
FIELD_TYPE.freeze();
}
}
Expand Down
22 changes: 22 additions & 0 deletions src/main/java/org/opensearch/knn/index/mapper/ParsedVector.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.mapper;

import lombok.Builder;
import lombok.Value;

/**
 * Immutable holder for a vector read from a document, in one of two representations.
 *
 * <p>Lombok's {@code @Value} makes the fields private and final and generates the getters;
 * {@code @Builder} generates {@code ParsedVector.builder()}. A field that is not set on the
 * builder stays {@code null}.
 *
 * <p>NOTE(review): the mapper appears to set exactly one of the two fields and to check
 * {@code getVector() != null} to decide which representation to index; confirm that no caller
 * sets both.
 */
@Builder
@Value
public class ParsedVector {
// Vector as parsed float components; null when the string form is used instead.
float[] vector;
// Vector kept as its raw text form; null when the float array form is used instead.
String vectorString;
}

0 comments on commit 3421a67

Please sign in to comment.