From eccf8245df28dbbe3e83f3249bfb6f1c092fb1ce Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Mon, 29 Apr 2024 17:57:57 -0700 Subject: [PATCH] Fixed the deserialization issue to ensure that vectors created using string are correctly de-searilized Signed-off-by: Navneet Verma --- .../knn/index/codec/util/KNNCodecUtil.java | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index ac68d7f848..d385a0bbed 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -5,10 +5,11 @@ package org.opensearch.knn.index.codec.util; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.Setter; -import lombok.extern.log4j.Log4j2; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BytesRef; @@ -17,10 +18,10 @@ import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues; import org.opensearch.knn.jni.JNICommons; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; @Log4j2 public class KNNCodecUtil { @@ -63,20 +64,27 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep final float[] vector; StopWatch stopWatch = new StopWatch(); stopWatch.start(); - if (isVectorRepresentedAsString(bytesref)) { - String vectorString = new String(bytesref.bytes, 1, bytesref.bytes.length - 1); + String vectorString = new String(byteStream.readAllBytes()); + // resetting as we have just read all the bytes + byteStream.reset(); + if (isVectorRepresentedAsString(vectorString)) { String[] array = vectorString.split(","); vector = new float[array.length]; try { for (int i = 0; i < array.length; i++) { + if (i == 0) { + array[i] = array[i].substring(1, array[i].length()); + } vector[i] = Float.parseFloat(array[i]); } } catch (Exception e) { - log.error("Error while converting floats for str: {}", vectorString, e); + log.error("Vector String: {}", vectorString); + // log.error("Error while converting floats for str: {}", vectorString, e); } stopWatch.stop(); log.info("Time taken to deserialize vector with string is : {} ms", stopWatch.totalTime().millis()); } else { + serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream); final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); vector = vectorSerializer.byteToFloatArray(byteStream); @@ -149,8 +157,9 @@ private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) return totalLiveDocs; } - private static boolean isVectorRepresentedAsString(BytesRef bytesref) { + private static boolean isVectorRepresentedAsString(String vectorString) { // Check if first bye is the special character that we have added or not. - return "N".equals(new String(bytesref.bytes, 0, 1)); + char n = 'N'; + return n == vectorString.charAt(0); } }