Skip to content

Commit

Permalink
Enabled transfering of vectors from Java to jni layer for Nmslib duri…
Browse files Browse the repository at this point in the history
…ng index creation to reduce the jvm memory footprint

Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Feb 29, 2024
1 parent ba09ac2 commit d2e07c6
Show file tree
Hide file tree
Showing 8 changed files with 216 additions and 7 deletions.
3 changes: 3 additions & 0 deletions jni/include/nmslib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ namespace knn_jni {
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ,
jstring indexPathJ, jobject parametersJ);

void CreateIndex_With_Memory_Address(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorAddressJ, jint dim, jstring indexPathJ, jobject parametersJ);

// Load an index from indexPathJ into memory. Use parametersJ to set any query time parameters
//
// Return a pointer to the loaded index
Expand Down
3 changes: 3 additions & 0 deletions jni/include/org_opensearch_knn_jni_NmslibService.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ extern "C" {
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex
(JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject);

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndexWithMemoryAddress
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_NmslibService
* Method: loadIndex
Expand Down
146 changes: 146 additions & 0 deletions jni/src/nmslib_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,152 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J
}
}

void knn_jni::nmslib_wrapper::CreateIndex_With_Memory_Address(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorAddressJ, jint dim, jstring indexPathJ, jobject parametersJ) {

if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorAddressJ == 0) {
throw std::runtime_error("Vectors Address cannot be 0");
}

if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
}

if (parametersJ == nullptr) {
throw std::runtime_error("Parameters cannot be null");
}

// Handle parameters
auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
std::vector<std::string> indexParameters;

// Algorithm parameters will be in a sub map
if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) {
jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS];
auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ);

if(subParametersCpp.find(knn_jni::EF_CONSTRUCTION) != subParametersCpp.end()) {
auto efConstruction = jniUtil->ConvertJavaObjectToCppInteger(env, subParametersCpp[knn_jni::EF_CONSTRUCTION]);
indexParameters.push_back(knn_jni::EF_CONSTRUCTION_NMSLIB + "=" + std::to_string(efConstruction));
}

if(subParametersCpp.find(knn_jni::M) != subParametersCpp.end()) {
auto m = jniUtil->ConvertJavaObjectToCppInteger(env, subParametersCpp[knn_jni::M]);
indexParameters.push_back(knn_jni::M_NMSLIB + "=" + std::to_string(m));
}

jniUtil->DeleteLocalRef(env, subParametersJ);
}

if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto indexThreadQty = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
indexParameters.push_back(knn_jni::INDEX_THREAD_QUANTITY + "=" + std::to_string(indexThreadQty));
}

jniUtil->DeleteLocalRef(env, parametersJ);

// Get the path to save the index
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));

// Get space type for this index
jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
spaceTypeCpp = TranslateSpaceType(spaceTypeCpp);

std::unique_ptr<similarity::Space<float>> space;
space.reset(similarity::SpaceFactoryRegistry<float>::Instance().CreateSpace(spaceTypeCpp,similarity::AnyParams()));

std::vector<float> *inputVectors = reinterpret_cast<std::vector<float>*>(vectorAddressJ);

// Get number of ids and vectors and dimension
//int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ);
int numVectors = inputVectors->size() / dim;
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}
//int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ);

// Read dataset
similarity::ObjectVector dataset;
dataset.reserve(numVectors);
int* idsCpp;
int topLevelPointer = 0;
try {
// Read in data set
idsCpp = jniUtil->GetIntArrayElements(env, idsJ, nullptr);

//float* floatArrayCpp;
jfloatArray floatArrayJ;
size_t vectorSizeInBytes = dim*sizeof(float);

// Allocate a large buffer that will contain all the vectors. Allocating the objects in one large buffer as
// opposed to individually will prevent heap fragmentation. We have observed that allocating individual
// objects causes RSS to rise throughout the lifetime of a process
// (see https://github.com/opensearch-project/k-NN/issues/772 and
// https://github.com/opensearch-project/k-NN/issues/72). This is because, in typical systems, small
// allocations will reside on some kind of heap managed by an allocator. Once freed, the allocator does not
// always return the memory to the OS. If the heap gets fragmented, this will cause the allocator
// to ask for more memory, causing RSS to grow. On large allocations (> 128 kb), most allocators will
// internally use mmap. Once freed, unmap will be called, which will immediately return memory to the OS
// which in turn prevents RSS from growing out of control. Wrap with a smart pointer so that buffer will be
// freed once variable goes out of scope. For reference, the code that specifies the layout of the buffer can be
// found: https://github.com/nmslib/nmslib/blob/v2.1.1/similarity_search/include/object.h#L61-L75
std::unique_ptr<char[]> objectBuffer(new char[(similarity::ID_SIZE + similarity::LABEL_SIZE + similarity::DATALENGTH_SIZE + vectorSizeInBytes) * numVectors]);
char* ptr = objectBuffer.get();
for (int i = 0; i < numVectors; i++) {
dataset.push_back(new similarity::Object(ptr));

memcpy(ptr, &idsCpp[i], similarity::ID_SIZE);
ptr += similarity::ID_SIZE;
memcpy(ptr, &DEFAULT_LABEL, similarity::LABEL_SIZE);
ptr += similarity::LABEL_SIZE;
memcpy(ptr, &vectorSizeInBytes, similarity::DATALENGTH_SIZE);
ptr += similarity::DATALENGTH_SIZE;

// floatArrayJ = (jfloatArray)jniUtil->GetObjectArrayElement(env, vectorsJ, i);
// if (dim != jniUtil->GetJavaFloatArrayLength(env, floatArrayJ)) {
// throw std::runtime_error("Dimension of vectors is inconsistent");
// }

//floatArrayCpp = jniUtil->GetFloatArrayElements(env, floatArrayJ, nullptr);
float floatArrayCpp[dim];
for(int j = 0 ; j < dim; j ++) {
floatArrayCpp[j] = inputVectors->at(topLevelPointer);
topLevelPointer++;
}

memcpy(ptr, &floatArrayCpp, vectorSizeInBytes);
//jniUtil->ReleaseFloatArrayElements(env, floatArrayJ, floatArrayCpp, JNI_ABORT);
ptr += vectorSizeInBytes;
}
jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT);

std::unique_ptr<similarity::Index<float>> index;
index.reset(similarity::MethodFactoryRegistry<float>::Instance().CreateMethod(false, "hnsw", spaceTypeCpp, *(space), dataset));
index->CreateIndex(similarity::AnyParams(indexParameters));
index->SaveIndex(indexPathCpp);

for (auto & it : dataset) {
delete it;
}
delete inputVectors;
} catch (...) {
for (auto & it : dataset) {
delete it;
}
delete inputVectors;

jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT);
throw;
}
}


jlong knn_jni::nmslib_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ,
jobject parametersJ) {

Expand Down
13 changes: 13 additions & 0 deletions jni/src/org_opensearch_knn_jni_NmslibService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex(JNI
}
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndexWithMemoryAddress(JNIEnv * env, jclass cls, jintArray idsJ,
jlong vectorAddressJ, jint dim, jstring indexPathJ,
jobject parametersJ)
{
try {
knn_jni::nmslib_wrapper::CreateIndex_With_Memory_Address(&jniUtil, env, idsJ, vectorAddressJ, dim ,indexPathJ, parametersJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
}



JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_NmslibService_loadIndex(JNIEnv * env, jclass cls,
jstring indexPathJ, jobject parametersJ)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,15 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
if (pair.vectors.length == 0 || pair.docs.length == 0) {
if (pair.vectorsAddress == 0 && (pair.vectors.length == 0 || pair.docs.length == 0)) {
logger.info("Skipping engine index creation as there are no vectors or docs in the documents");
return;
}
long arraySize = calculateArraySize(pair.vectors, pair.serializationMode);
//long arraySize = calculateArraySize(pair.vectors, pair.serializationMode);
if (isMerge) {
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length);
KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize);
//KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize);
}
// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
Expand Down Expand Up @@ -150,7 +150,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
}

if (isMerge) {
recordMergeStats(pair.docs.length, arraySize);
//recordMergeStats(pair.docs.length, arraySize);
}

if (isRefresh) {
Expand Down Expand Up @@ -223,7 +223,8 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa

// Pass the path for the nms library to save the file
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName());
JNIService.createIndex(pair.docs, pair.vectors, pair.vectorsAddress, pair.dim, indexPath, parameters,
knnEngine.getName());
return null;
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@

package org.opensearch.knn.index.codec.util;

import org.apache.commons.lang.ArrayUtils;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.jni.JNIService;
import org.apache.commons.lang.ArrayUtils;

import java.io.ByteArrayInputStream;
import java.io.IOException;
Expand All @@ -26,37 +29,59 @@ public class KNNCodecUtil {
// Java rounds each array size up to multiples of 8 bytes
public static final int JAVA_ROUNDING_NUMBER = 8;

private static final int MAX_SIZE_OF_VECTOR_ARRAY = 10000;

public static final class Pair {
public Pair(int[] docs, float[][] vectors, SerializationMode serializationMode) {
this.docs = docs;
this.vectors = vectors;
this.serializationMode = serializationMode;
this.vectorsAddress = 0;
this.dim = 0;
}

public int[] docs;
public float[][] vectors;
public SerializationMode serializationMode;
public long vectorsAddress;
public int dim;
}

public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException {
ArrayList<float[]> vectorList = new ArrayList<>();
ArrayList<Integer> docIdList = new ArrayList<>();
SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS;
long vectorAddress = 0;
int dim = 0;
for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) {
BytesRef bytesref = values.binaryValue();
try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) {
serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
dim = vector.length;
vectorList.add(vector);
}
if(vectorList.size() == MAX_SIZE_OF_VECTOR_ARRAY) {
vectorAddress = JNIService.transferVectors(vectorAddress, vectorList.toArray(new float[][] {}));
vectorList = new ArrayList<>();
}
docIdList.add(doc);
}
return new KNNCodecUtil.Pair(

if(vectorList.size() > 0) {
vectorAddress = JNIService.transferVectors(vectorAddress, vectorList.toArray(new float[][] {}));
vectorList = new ArrayList<>();
}
KNNCodecUtil.Pair pair = new KNNCodecUtil.Pair(
docIdList.stream().mapToInt(Integer::intValue).toArray(),
vectorList.toArray(new float[][] {}),
new float[0][],
serializationMode
);
pair.vectorsAddress = vectorAddress;
pair.dim = dim;
return pair;

}

public static long calculateArraySize(float[][] vectors, SerializationMode serializationMode) {
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,21 @@ public static void createIndex(int[] ids, float[][] data, String indexPath, Map<
throw new IllegalArgumentException("CreateIndex not supported for provided engine");
}

public static void createIndex(int[] ids, float[][] data, long vectorAddress, int dim, String indexPath,
Map<String, Object> parameters, String engineName) {
if (KNNEngine.NMSLIB.getName().equals(engineName)) {
NmslibService.createIndexWithMemoryAddress(ids, vectorAddress, dim, indexPath, parameters);
return;
}

if (KNNEngine.FAISS.getName().equals(engineName)) {
FaissService.createIndex(ids, data, indexPath, parameters);
return;
}

throw new IllegalArgumentException("CreateIndex not supported for provided engine");
}

/**
* Create an index for the native library with a provided template index
*
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/jni/NmslibService.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class NmslibService {
*/
public static native void createIndex(int[] ids, float[][] data, String indexPath, Map<String, Object> parameters);

public static native void createIndexWithMemoryAddress(int[] ids, long address, int dim, String indexPath,
Map<String, Object> parameters);

/**
* Load an index into memory
*
Expand Down

0 comments on commit d2e07c6

Please sign in to comment.