diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index ac68d7f84..fc9a3d680 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -5,22 +5,22 @@ package org.opensearch.knn.index.codec.util; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.Setter; -import lombok.extern.log4j.Log4j2; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BytesRef; -import org.opensearch.common.StopWatch; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues; import org.opensearch.knn.jni.JNICommons; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; @Log4j2 public class KNNCodecUtil { @@ -61,27 +61,29 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep BytesRef bytesref = values.binaryValue(); try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) { final float[] vector; - StopWatch stopWatch = new StopWatch(); - stopWatch.start(); - if (isVectorRepresentedAsString(bytesref)) { - String vectorString = new String(bytesref.bytes, 1, bytesref.bytes.length - 1); + char firstChar = (char) (byteStream.read()); + // resetting as we have just read all the bytes + byteStream.reset(); + if (isVectorRepresentedAsString(firstChar)) { + String vectorString = new String(byteStream.readAllBytes()); String[] array = vectorString.split(","); vector = new float[array.length]; try { for (int i = 0; i < array.length; i++) { + if (i == 0) { + array[i] = array[i].substring(1, array[i].length()); + } vector[i] = Float.parseFloat(array[i]); } } catch (Exception e) { - log.error("Error while converting floats for str: {}", vectorString, e); + log.error("Vector String: {}", vectorString); + // log.error("Error while converting floats for str: {}", vectorString, e); } - stopWatch.stop(); - log.info("Time taken to deserialize vector with string is : {} ms", stopWatch.totalTime().millis()); } else { + serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream); final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); vector = vectorSerializer.byteToFloatArray(byteStream); - stopWatch.stop(); - log.info("Time taken to deserialize vector with float array is : {} ms", stopWatch.totalTime().millis()); } dimension = vector.length; @@ -149,8 +151,9 @@ private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) return totalLiveDocs; } - private static boolean isVectorRepresentedAsString(BytesRef bytesref) { + private static boolean isVectorRepresentedAsString(char firstChar) { // Check if first bye is the special character that we have added or not. - return "N".equals(new String(bytesref.bytes, 0, 1)); + char n = 'N'; + return n == firstChar; } }