Skip to content

Commit

Permalink
Fixed the deserialization issue to ensure that vectors created using …
Browse files Browse the repository at this point in the history
…string are correctly deserialized

Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Apr 30, 2024
1 parent e0dd60c commit a4bed4f
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

package org.opensearch.knn.index.codec.util;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;
Expand All @@ -17,10 +18,10 @@
import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues;
import org.opensearch.knn.jni.JNICommons;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

@Log4j2
public class KNNCodecUtil {
Expand Down Expand Up @@ -63,20 +64,27 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep
final float[] vector;
StopWatch stopWatch = new StopWatch();
stopWatch.start();
if (isVectorRepresentedAsString(bytesref)) {
String vectorString = new String(bytesref.bytes, 1, bytesref.bytes.length - 1);
String vectorString = new String(byteStream.readAllBytes());
// resetting as we have just read all the bytes
byteStream.reset();
if (isVectorRepresentedAsString(vectorString)) {
String[] array = vectorString.split(",");
vector = new float[array.length];
try {
for (int i = 0; i < array.length; i++) {
if(i == 0) {
array[i] = array[i].substring(1, array[i].length());
}
vector[i] = Float.parseFloat(array[i]);
}
} catch (Exception e) {
log.error("Error while converting floats for str: {}", vectorString, e);
log.error("Vector String: {}", vectorString);
//log.error("Error while converting floats for str: {}", vectorString, e);
}
stopWatch.stop();
log.info("Time taken to deserialize vector with string is : {} ms", stopWatch.totalTime().millis());
} else {

serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
vector = vectorSerializer.byteToFloatArray(byteStream);
Expand Down Expand Up @@ -149,8 +157,9 @@ private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues)
return totalLiveDocs;
}

private static boolean isVectorRepresentedAsString(BytesRef bytesref) {
private static boolean isVectorRepresentedAsString(String vectorString) {
// Check if the first byte is the special character that we have added or not.
return "N".equals(new String(bytesref.bytes, 0, 1));
char n = 'N';
return n == vectorString.charAt(0);
}
}

0 comments on commit a4bed4f

Please sign in to comment.