Skip to content

Commit

Permalink
Added code to parse the vectors as string and not as an array
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Apr 6, 2024
1 parent b6d626a commit ae20eb0
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 24 deletions.
11 changes: 11 additions & 0 deletions src/main/java/org/opensearch/knn/index/VectorField.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ public VectorField(String name, float[] value, IndexableFieldType type) {
}
}

/**
 * Creates a vector field whose binary value is the encoded text of {@code vector}.
 *
 * <p>The field is first initialised with an empty {@link BytesRef}; its value is then replaced with the
 * UTF-8 bytes of the given string.
 *
 * @param name   FieldType name
 * @param vector textual form of the vector. NOTE(review): the codec-side reader appears to expect
 *               comma-separated float literals; confirm against the reader before relying on it.
 * @param type   FieldType to build DocValues
 * @throws RuntimeException wrapping any failure raised while encoding or setting the value
 *                          (including a {@code null} vector)
 */
public VectorField(String name, String vector, IndexableFieldType type) {
    super(name, new BytesRef(), type);
    try {
        // Pin the charset explicitly: the no-arg getBytes() uses the JVM's platform default, which can
        // differ between machines and would make the stored bytes environment-dependent.
        final byte[] vectorBytes = vector.getBytes(java.nio.charset.StandardCharsets.UTF_8);
        this.setBytesValue(vectorBytes);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}

/**
* @param name FieldType name
* @param value an array of byte vector values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
BinaryDocValues values = valuesProducer.getBinary(field);
KNNCodecUtil.Pair pair = null;
try {
pair = KNNCodecUtil.getFloats(values);
pair = KNNCodecUtil.getFloats(values, field.getAttribute("is_String"));
if (pair.getVectorAddress() == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the segment");
return;
Expand Down
41 changes: 22 additions & 19 deletions src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public static final class Pair {

}

public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException {
public static KNNCodecUtil.Pair getFloats(BinaryDocValues values, String isString) throws IOException {
List<float[]> vectorList = new ArrayList<>();
List<Integer> docIdList = new ArrayList<>();
long vectorAddress = 0;
Expand All @@ -57,26 +57,29 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep
for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) {
BytesRef bytesref = values.binaryValue();
try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) {
serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
dimension = vector.length;
// serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream);
// final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);

if (vectorsPerTransfer == Integer.MIN_VALUE) {
vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit;
if("true".equals(isString)) {
String vectorString = new String(byteStream.readAllBytes());
String[] array = vectorString.split(",");
final float[] vector = new float[array.length];
for (int i = 0; i < array.length; i++) {
vector[i] = Float.parseFloat(array[i]);
}
dimension = vector.length;
if (vectorsPerTransfer == Integer.MIN_VALUE) {
vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit;
}
if (vectorList.size() == vectorsPerTransfer) {
vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][]{}), totalLiveDocs * dimension);
// We should probably come up with a better way to reuse the vectorList memory which we have
// created. Problem here is doing like this can lead to a lot of list memory which is of no use and
// will be garbage collected later on, but it creates pressure on JVM. We should revisit this.
vectorList = new ArrayList<>();
}
vectorList.add(vector);
}
if (vectorList.size() == vectorsPerTransfer) {
vectorAddress = JNICommons.storeVectorData(
vectorAddress,
vectorList.toArray(new float[][] {}),
totalLiveDocs * dimension
);
// We should probably come up with a better way to reuse the vectorList memory which we have
// created. Problem here is doing like this can lead to a lot of list memory which is of no use and
// will be garbage collected later on, but it creates pressure on JVM. We should revisit this.
vectorList = new ArrayList<>();
}
vectorList.add(vector);
}
docIdList.add(doc);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -563,14 +563,15 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);
Optional<String> floatsArrayOptional = getFloatsFromContextString(context, dimension,
methodComponentContext);

if (floatsArrayOptional.isEmpty()) {
return;
}
final float[] array = floatsArrayOptional.get();
spaceType.validateVector(array);
VectorField point = new VectorField(name(), array, fieldType);
//final float[] array = floatsArrayOptional.get();
//spaceType.validateVector(array);
VectorField point = new VectorField(name(), floatsArrayOptional.get(), fieldType);
context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else {
Expand Down Expand Up @@ -691,6 +692,7 @@ Optional<float[]> getFloatsFromContext(ParseContext context, int dimension, Meth
ArrayList<Float> vector = new ArrayList<>();
XContentParser.Token token = context.parser().currentToken();
float value;
String vectors;
if (token == XContentParser.Token.START_ARRAY) {
token = context.parser().nextToken();
while (token != XContentParser.Token.END_ARRAY) {
Expand Down Expand Up @@ -735,6 +737,73 @@ Optional<float[]> getFloatsFromContext(ParseContext context, int dimension, Meth
return Optional.of(array);
}

/**
 * Reads the value at the parser's current position and returns it as vector text.
 *
 * <p>Supported inputs:
 * <ul>
 *   <li>a string token: returned as-is, without per-value validation;</li>
 *   <li>an array of numbers or a single number: each value is validated (or clipped to the FP16 range
 *       when the Faiss SQ fp16 encoder has clipping enabled) and the values are joined with {@code ','};</li>
 *   <li>a null token, an empty array, or any other token: yields an empty {@link Optional}.</li>
 * </ul>
 *
 * <p>The field's simple name is pushed onto the context path on entry. Every empty return pops it again
 * (as the null branch always did), because the caller stops processing the field when it receives an
 * empty result.
 *
 * <p>NOTE(review): {@code dimension} is currently unused because dimension validation is disabled for
 * this string-based path; confirm whether it should be re-enabled.
 *
 * @param context                parse context positioned on the vector field value
 * @param dimension              configured vector dimension (currently not enforced here)
 * @param methodComponentContext method configuration used to detect the Faiss SQ fp16 encoder
 * @return the vector as text, or an empty Optional when there is no vector to index
 * @throws IOException if the underlying parser fails
 */
Optional<String> getFloatsFromContextString(ParseContext context, int dimension,
        MethodComponentContext methodComponentContext)
        throws IOException {
    context.path().add(simpleName());

    // Values are validated as finite floats, or against the FP16 range [-65504, 65504] when the Faiss
    // encoder is SQ with type 'fp16'. If the encoder's "clip" parameter is enabled, out-of-range values
    // are clipped to the FP16 range instead of being rejected.
    boolean isFaissSQfp16Flag = isFaissSQfp16(methodComponentContext);
    boolean clipVectorValueToFP16RangeFlag = false;
    if (isFaissSQfp16Flag) {
        clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled(
            (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER)
        );
    }

    // Numeric input is serialized to comma-separated text so it is indexed rather than silently dropped.
    final StringBuilder joined = new StringBuilder();
    XContentParser.Token token = context.parser().currentToken();
    float value;
    if (token == XContentParser.Token.START_ARRAY) {
        token = context.parser().nextToken();
        while (token != XContentParser.Token.END_ARRAY) {
            value = context.parser().floatValue();
            if (isFaissSQfp16Flag) {
                if (clipVectorValueToFP16RangeFlag) {
                    value = clipVectorValueToFP16Range(value);
                } else {
                    validateFP16VectorValue(value);
                }
            } else {
                validateFloatVectorValue(value);
            }

            if (joined.length() > 0) {
                joined.append(',');
            }
            joined.append(value);
            token = context.parser().nextToken();
        }
    } else if (token == XContentParser.Token.VALUE_NUMBER) {
        value = context.parser().floatValue();
        if (isFaissSQfp16Flag) {
            if (clipVectorValueToFP16RangeFlag) {
                value = clipVectorValueToFP16Range(value);
            } else {
                validateFP16VectorValue(value);
            }
        } else {
            validateFloatVectorValue(value);
        }
        joined.append(value);
        context.parser().nextToken();
    } else if (token == XContentParser.Token.VALUE_STRING) {
        final String text = context.parser().text();
        if (text != null) {
            return Optional.of(text);
        }
    }

    if (joined.length() == 0) {
        // Nothing to index (null token, empty array, unsupported token). Pop the path pushed above so
        // it does not leak into the parsing of subsequent fields.
        context.path().remove();
        return Optional.empty();
    }
    return Optional.of(joined.toString());
}

@Override
protected boolean docValuesByDefault() {
return true;
Expand Down Expand Up @@ -781,6 +850,7 @@ public static class Defaults {
FIELD_TYPE.setIndexOptions(IndexOptions.NONE);
FIELD_TYPE.setDocValuesType(DocValuesType.BINARY);
FIELD_TYPE.putAttribute(KNN_FIELD, "true"); // This attribute helps to determine knn field type
FIELD_TYPE.putAttribute("is_String", "true");
FIELD_TYPE.freeze();
}
}
Expand Down

0 comments on commit ae20eb0

Please sign in to comment.