diff --git a/src/main/java/org/opensearch/knn/index/VectorField.java b/src/main/java/org/opensearch/knn/index/VectorField.java index f28ef6238..e65201999 100644 --- a/src/main/java/org/opensearch/knn/index/VectorField.java +++ b/src/main/java/org/opensearch/knn/index/VectorField.java @@ -24,6 +24,17 @@ public VectorField(String name, float[] value, IndexableFieldType type) { } } + public VectorField(String name, String vector, IndexableFieldType type) { + super(name, new BytesRef(), type); + try { + //final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getDefaultSerializer(); + final byte[] floatToByte = vector.getBytes(); + this.setBytesValue(floatToByte); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + /** * @param name FieldType name * @param value an array of byte vector values diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index db332c5d1..2d9276f0c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -118,7 +118,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, BinaryDocValues values = valuesProducer.getBinary(field); KNNCodecUtil.Pair pair = null; try { - pair = KNNCodecUtil.getFloats(values); + pair = KNNCodecUtil.getFloats(values, field.getAttribute("is_String")); if (pair.getVectorAddress() == 0 || pair.docs.length == 0) { logger.info("Skipping engine index creation as there are no vectors or docs in the segment"); return; diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index c5ae469e0..abf23744e 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -43,7 +43,7 @@ public static final class Pair { } - public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException { + public static KNNCodecUtil.Pair getFloats(BinaryDocValues values, String isString) throws IOException { List vectorList = new ArrayList<>(); List docIdList = new ArrayList<>(); long vectorAddress = 0; @@ -57,26 +57,29 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { BytesRef bytesref = values.binaryValue(); try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) { - serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream); - final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); - final float[] vector = vectorSerializer.byteToFloatArray(byteStream); - dimension = vector.length; +// serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream); +// final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); - if (vectorsPerTransfer == Integer.MIN_VALUE) { - vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit; + if("true".equals(isString)) { + String vectorString = new String(byteStream.readAllBytes()); + String[] array = vectorString.split(","); + final float[] vector = new float[array.length]; + for (int i = 0; i < array.length; i++) { + vector[i] = Float.parseFloat(array[i]); + } + dimension = vector.length; + if (vectorsPerTransfer == Integer.MIN_VALUE) { + vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit; + } + if (vectorList.size() == vectorsPerTransfer) { + vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][]{}), totalLiveDocs * dimension); + // We should probably come up with a better way to reuse the vectorList memory which we have + // created. Problem here is doing like this can lead to a lot of list memory which is of no use and + // will be garbage collected later on, but it creates pressure on JVM. We should revisit this. + vectorList = new ArrayList<>(); + } + vectorList.add(vector); } - if (vectorList.size() == vectorsPerTransfer) { - vectorAddress = JNICommons.storeVectorData( - vectorAddress, - vectorList.toArray(new float[][] {}), - totalLiveDocs * dimension - ); - // We should probably come up with a better way to reuse the vectorList memory which we have - // created. Problem here is doing like this can lead to a lot of list memory which is of no use and - // will be garbage collected later on, but it creates pressure on JVM. We should revisit this. - vectorList = new ArrayList<>(); - } - vectorList.add(vector); } docIdList.add(doc); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index a36a4222b..305b4f9f3 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -563,14 +563,15 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s context.doc().add(point); addStoredFieldForVectorField(context, fieldType, name(), point.toString()); } else if (VectorDataType.FLOAT == vectorDataType) { - Optional floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext); + Optional floatsArrayOptional = getFloatsFromContextString(context, dimension, + methodComponentContext); if (floatsArrayOptional.isEmpty()) { return; } - final float[] array = floatsArrayOptional.get(); - spaceType.validateVector(array); - VectorField point = new VectorField(name(), array, fieldType); + //final float[] array = floatsArrayOptional.get(); + //spaceType.validateVector(array); + VectorField point = new VectorField(name(), floatsArrayOptional.get(), fieldType); context.doc().add(point); addStoredFieldForVectorField(context, fieldType, name(), point.toString()); } else { @@ -691,6 +692,7 @@ Optional getFloatsFromContext(ParseContext context, int dimension, Meth ArrayList vector = new ArrayList<>(); XContentParser.Token token = context.parser().currentToken(); float value; + String vectors; if (token == XContentParser.Token.START_ARRAY) { token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { @@ -735,6 +737,73 @@ Optional getFloatsFromContext(ParseContext context, int dimension, Meth return Optional.of(array); } + Optional getFloatsFromContextString(ParseContext context, int dimension, + MethodComponentContext methodComponentContext) + throws IOException { + context.path().add(simpleName()); + + // Returns an optional array of float values where each value in the vector is parsed as a float and validated + // if it is a finite number and within the fp16 range of [-65504 to 65504] by default if Faiss encoder is SQ and type is 'fp16'. + // If the encoder parameter, "clip" is set to True, if the vector value is outside the FP16 range then it will be + // clipped to FP16 range. + boolean isFaissSQfp16Flag = isFaissSQfp16(methodComponentContext); + boolean clipVectorValueToFP16RangeFlag = false; + if (isFaissSQfp16Flag) { + clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled( + (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER) + ); + } + + ArrayList vector = new ArrayList<>(); + XContentParser.Token token = context.parser().currentToken(); + float value; + String vectors = null; + if (token == XContentParser.Token.START_ARRAY) { + token = context.parser().nextToken(); + while (token != XContentParser.Token.END_ARRAY) { + value = context.parser().floatValue(); + if (isFaissSQfp16Flag) { + if (clipVectorValueToFP16RangeFlag) { + value = clipVectorValueToFP16Range(value); + } else { + validateFP16VectorValue(value); + } + } else { + validateFloatVectorValue(value); + } + + vector.add(value); + token = context.parser().nextToken(); + } + } else if (token == XContentParser.Token.VALUE_NUMBER) { + value = context.parser().floatValue(); + if (isFaissSQfp16Flag) { + if (clipVectorValueToFP16RangeFlag) { + value = clipVectorValueToFP16Range(value); + } else { + validateFP16VectorValue(value); + } + } else { + validateFloatVectorValue(value); + } + vector.add(value); + context.parser().nextToken(); + } else if (token == XContentParser.Token.VALUE_NULL) { + context.path().remove(); + return Optional.empty(); + } else if(token == XContentParser.Token.VALUE_STRING) { + vectors = context.parser().text(); + } + //validateVectorDimension(dimension, vector.size()); + +// float[] array = new float[vector.size()]; +// int i = 0; +// for (Float f : vector) { +// array[i++] = f; +// } + return Optional.ofNullable(vectors); + } + @Override protected boolean docValuesByDefault() { return true; @@ -781,6 +850,7 @@ public static class Defaults { FIELD_TYPE.setIndexOptions(IndexOptions.NONE); FIELD_TYPE.setDocValuesType(DocValuesType.BINARY); FIELD_TYPE.putAttribute(KNN_FIELD, "true"); // This attribute helps to determine knn field type + FIELD_TYPE.putAttribute("is_String", "true"); FIELD_TYPE.freeze(); } }