Skip to content

Commit

Permalink
Fixed the deserialization issue to ensure that vectors created using …
Browse files Browse the repository at this point in the history
…string are correctly deserialized

Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed May 1, 2024
1 parent e0dd60c commit 4face75
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@

package org.opensearch.knn.index.codec.util;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues;
import org.opensearch.knn.jni.JNICommons;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;

@Log4j2
public class KNNCodecUtil {
Expand Down Expand Up @@ -61,27 +61,29 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep
BytesRef bytesref = values.binaryValue();
try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) {
final float[] vector;
StopWatch stopWatch = new StopWatch();
stopWatch.start();
if (isVectorRepresentedAsString(bytesref)) {
String vectorString = new String(bytesref.bytes, 1, bytesref.bytes.length - 1);
char firstChar = (char) (byteStream.read());
// reset the stream to its start, since the read() above consumed the first byte
byteStream.reset();
if (isVectorRepresentedAsString(firstChar)) {
String vectorString = new String(byteStream.readAllBytes());
String[] array = vectorString.split(",");
vector = new float[array.length];
try {
for (int i = 0; i < array.length; i++) {
if (i == 0) {
array[i] = array[i].substring(1, array[i].length());
}
vector[i] = Float.parseFloat(array[i]);
}
} catch (Exception e) {
log.error("Error while converting floats for str: {}", vectorString, e);
log.error("Vector String: {}", vectorString);
// log.error("Error while converting floats for str: {}", vectorString, e);
}
stopWatch.stop();
log.info("Time taken to deserialize vector with string is : {} ms", stopWatch.totalTime().millis());
} else {

serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
vector = vectorSerializer.byteToFloatArray(byteStream);
stopWatch.stop();
log.info("Time taken to deserialize vector with float array is : {} ms", stopWatch.totalTime().millis());
}
dimension = vector.length;

Expand Down Expand Up @@ -149,8 +151,9 @@ private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues)
return totalLiveDocs;
}

private static boolean isVectorRepresentedAsString(BytesRef bytesref) {
private static boolean isVectorRepresentedAsString(char firstChar) {
// Check whether the first byte is the special marker character that we added.
return "N".equals(new String(bytesref.bytes, 0, 1));
char n = 'N';
return n == firstChar;
}
}

0 comments on commit 4face75

Please sign in to comment.