Skip to content

Commit

Permalink
Enabled faiss engine with vector streaming from Java to jni layer
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Mar 2, 2024
1 parent c5287a3 commit 443a0ac
Show file tree
Hide file tree
Showing 9 changed files with 262 additions and 5 deletions.
9 changes: 9 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,21 @@ namespace knn_jni {
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ,
jstring indexPathJ, jobject parametersJ);

// Create an index with ids and vectors, reading the vectors from native memory instead of a Java array.
// vectorAddress is reinterpreted by the implementation as a pointer to a std::vector<float>; the number of
// vectors is derived as (vector size / dim) and must match the number of ids. The index is serialized to
// indexPathJ. On success the implementation deletes the std::vector<float> behind vectorAddress.
void CreateIndex_With_Memory_Address(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorAddress, jint dim, jstring indexPathJ, jobject parametersJ);


// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Create an index with ids and vectors based off of the template index passed in, reading the vectors from
// native memory instead of a Java array. vectorsAddress is reinterpreted by the implementation as a pointer to
// a std::vector<float>; the number of vectors is derived as (vector size / dim) and must match the number of
// ids. The index is serialized to indexPathJ. On success the implementation deletes the std::vector<float>
// behind vectorsAddress.
void CreateIndexFromTemplate_With_Memory_Address(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddress, jint dim, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ);


// Load an index from indexPathJ into memory.
//
// Return a pointer to the loaded index
Expand Down
10 changes: 10 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ extern "C" {
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex
(JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject);


/*
 * Class:     org_opensearch_knn_jni_FaissService
 * Method:    createIndexWithMemoryAddress
 * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
 */
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexWithMemoryAddress
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexFromTemplate
Expand All @@ -34,6 +38,9 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate
(JNIEnv *, jclass, jintArray, jobjectArray, jstring, jbyteArray, jobject);

/*
 * Class:     org_opensearch_knn_jni_FaissService
 * Method:    createIndexFromTemplateWithMemoryAddress
 * Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V
 */
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplateWithMemoryAddress
(JNIEnv * , jclass, jintArray, jlong , jint , jstring , jbyteArray ,jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadIndex
Expand Down Expand Up @@ -90,6 +97,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
(JNIEnv *, jclass, jlong, jobjectArray);

/*
 * Class:     org_opensearch_knn_jni_FaissService
 * Method:    transferVectorsV2
 * Signature: (J[[F)J
 */
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2
(JNIEnv *, jclass, jlong, jobjectArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: freeVectors
Expand Down
153 changes: 153 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,92 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN
faiss::write_index(&idMap, indexPathCpp.c_str());
}


// Create an index with ids and vectors, reading the vectors from native memory rather than from a Java array.
//
// vectorAddressJ is the address of a heap-allocated std::vector<float> that stores the vectors back to back
// (dim floats per vector). The index described by parametersJ is created through the faiss index factory, the
// vectors are added under the ids in idsJ and the resulting index is serialized to indexPathJ.
//
// Ownership: the std::vector<float> behind vectorAddressJ is deleted only after the index has been written
// successfully. If anything throws, it is left untouched.
// NOTE(review): this means the buffer leaks on failure unless the Java caller frees it - confirm the caller's
// error handling before moving the delete into an RAII guard.
//
// Throws std::runtime_error when a required argument is null/zero, dim is not positive, the number of ids does
// not match the number of vectors, or the created index would need training.
void knn_jni::faiss_wrapper::CreateIndex_With_Memory_Address(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
                                                             jlong vectorAddressJ, jint dim, jstring indexPathJ, jobject parametersJ) {

    if (idsJ == nullptr) {
        throw std::runtime_error("IDs cannot be null");
    }

    // A zero address would be dereferenced as a null std::vector<float>* below
    if (vectorAddressJ == 0) {
        throw std::runtime_error("Vectors cannot be null");
    }

    if (indexPathJ == nullptr) {
        throw std::runtime_error("Index path cannot be null");
    }

    if (parametersJ == nullptr) {
        throw std::runtime_error("Parameters cannot be null");
    }

    // dim is used as a divisor below and is handed to the index factory
    if (dim <= 0) {
        throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
    }

    // parametersJ is a Java Map<String, Object>. ConvertJavaMapToCppMap converts it to a c++ map<string, jobject>
    // so that it is easier to access.
    auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);

    // Get space type for this index
    jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
    std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
    faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);

    // Read data set. The vectors are stored flat, so the vector count is the float count divided by dim.
    auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorAddressJ);
    int numVectors = static_cast<int>(inputVectors->size() / dim);
    int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
    if (numIds != numVectors) {
        throw std::runtime_error("Number of IDs does not match number of vectors");
    }

    // Create faiss index
    jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
    std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));

    std::unique_ptr<faiss::Index> indexWriter;
    indexWriter.reset(faiss::index_factory(dim, indexDescriptionCpp.c_str(), metric));

    // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
    if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
        auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
        omp_set_num_threads(threadCount);
    }

    // Add extra parameters that cant be configured with the index factory
    if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) {
        jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS];
        auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ);
        SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter.get());
        jniUtil->DeleteLocalRef(env, subParametersJ);
    }
    jniUtil->DeleteLocalRef(env, parametersJ);

    // Check that the index does not need to be trained
    if(!indexWriter->is_trained) {
        throw std::runtime_error("Index is not trained");
    }

    // Feed the native buffer to faiss directly. add_with_ids reads numVectors * dim floats, which is exactly
    // what the flat buffer provides; an intermediate copy is both unnecessary and easy to size incorrectly.
    auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
    faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
    idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());

    // Write the index to disk
    std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
    faiss::write_index(&idMap, indexPathCpp.c_str());

    // The vectors have been consumed; release the native buffer that was handed over by address
    delete inputVectors;
}


void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ) {
Expand Down Expand Up @@ -184,6 +270,73 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *
faiss::write_index(&idMap, indexPathCpp.c_str());
}

// Create an index with ids and vectors based off of a serialized template index, reading the vectors from
// native memory rather than from a Java array.
//
// vectorsAddress is the address of a heap-allocated std::vector<float> that stores the vectors back to back
// (dim floats per vector). The template index is deserialized from templateIndexJ, the vectors are added under
// the ids in idsJ and the resulting index is serialized to indexPathJ.
//
// Ownership: the std::vector<float> behind vectorsAddress is deleted only after the index has been written
// successfully. If anything throws, it is left untouched.
// NOTE(review): this means the buffer leaks on failure unless the Java caller frees it - confirm the caller's
// error handling before moving the delete into an RAII guard.
//
// Throws std::runtime_error when a required argument is null/zero, dim is not positive, or the number of ids
// does not match the number of vectors.
void knn_jni::faiss_wrapper::CreateIndexFromTemplate_With_Memory_Address(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
                                                                         jlong vectorsAddress, jint dim, jstring indexPathJ,
                                                                         jbyteArray templateIndexJ, jobject parametersJ) {
    if (idsJ == nullptr) {
        throw std::runtime_error("IDs cannot be null");
    }

    // A zero address would be dereferenced as a null std::vector<float>* below
    if (vectorsAddress == 0) {
        throw std::runtime_error("Vectors cannot be null");
    }

    if (indexPathJ == nullptr) {
        throw std::runtime_error("Index path cannot be null");
    }

    if (templateIndexJ == nullptr) {
        throw std::runtime_error("Template index cannot be null");
    }

    // dim is used as a divisor below
    if (dim <= 0) {
        throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
    }

    // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
    auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
    if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
        auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
        omp_set_num_threads(threadCount);
    }
    jniUtil->DeleteLocalRef(env, parametersJ);

    // Read data set. The vectors are stored flat, so the vector count is the float count divided by dim.
    auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddress);
    int numVectors = static_cast<int>(inputVectors->size() / dim);
    int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
    if (numIds != numVectors) {
        throw std::runtime_error("Number of IDs does not match number of vectors");
    }

    // Get vector of bytes from jbytearray
    int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
    jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

    // Copy the template bytes in one shot; each jbyte is converted to uint8_t exactly like a per-element cast
    faiss::VectorIOReader vectorIoReader;
    vectorIoReader.data.assign(indexBytesJ, indexBytesJ + indexBytesCount);
    jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

    // Create faiss index
    std::unique_ptr<faiss::Index> indexWriter;
    indexWriter.reset(faiss::read_index(&vectorIoReader, 0));

    // Feed the native buffer to faiss directly. add_with_ids reads numVectors * dim floats, which is exactly
    // what the flat buffer provides; an intermediate copy is both unnecessary and easy to size incorrectly.
    auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
    faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
    idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());

    // Write the index to disk
    std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
    faiss::write_index(&idMap, indexPathCpp.c_str());

    // The vectors have been consumed; release the native buffer that was handed over by address
    delete inputVectors;
}

jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
Expand Down
4 changes: 3 additions & 1 deletion jni/src/nmslib_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ void knn_jni::nmslib_wrapper::CreateIndex_With_Memory_Address(knn_jni::JNIUtilIn
// found: https://github.com/nmslib/nmslib/blob/v2.1.1/similarity_search/include/object.h#L61-L75
std::unique_ptr<char[]> objectBuffer(new char[(similarity::ID_SIZE + similarity::LABEL_SIZE + similarity::DATALENGTH_SIZE + vectorSizeInBytes) * numVectors]);
char* ptr = objectBuffer.get();
std::cout<<"Number of vectors: "<<numVectors<<std::endl;
for (int i = 0; i < numVectors; i++) {
dataset.push_back(new similarity::Object(ptr));

Expand All @@ -277,12 +278,13 @@ void knn_jni::nmslib_wrapper::CreateIndex_With_Memory_Address(knn_jni::JNIUtilIn
// }

//floatArrayCpp = jniUtil->GetFloatArrayElements(env, floatArrayJ, nullptr);
std::cout<<"Top Level Pointer before getting vectors"<<topLevelPointer<<std::endl;
float floatArrayCpp[dim];
for(int j = 0 ; j < dim; j ++) {
floatArrayCpp[j] = inputVectors->at(topLevelPointer);
topLevelPointer++;
}

std::cout<<"Top Level Pointer after getting vectors"<<topLevelPointer<<std::endl;
memcpy(ptr, &floatArrayCpp, vectorSizeInBytes);
//jniUtil->ReleaseFloatArrayElements(env, floatArrayJ, floatArrayCpp, JNI_ABORT);
ptr += vectorSizeInBytes;
Expand Down
44 changes: 44 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIE
}
}

// JNI entry point for FaissService.createIndexWithMemoryAddress. Forwards every argument to
// knn_jni::faiss_wrapper::CreateIndex_With_Memory_Address and hands any C++ exception to
// jniUtil.CatchCppExceptionAndThrowJava so that it does not propagate out of the native call.
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexWithMemoryAddress(JNIEnv * env,
                                                                                             jclass cls,
                                                                                             jintArray idsJ,
                                                                                             jlong vectorAddress,
                                                                                             jint dim,
                                                                                             jstring indexPathJ,
                                                                                             jobject parametersJ)
{
    try {
        knn_jni::faiss_wrapper::CreateIndex_With_Memory_Address(
            &jniUtil,
            env,
            idsJ,
            vectorAddress,
            dim,
            indexPathJ,
            parametersJ);
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, jclass cls,
jintArray idsJ,
jobjectArray vectorsJ,
Expand All @@ -65,6 +76,21 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
}
}

// JNI entry point for FaissService.createIndexFromTemplateWithMemoryAddress. Forwards every argument to
// knn_jni::faiss_wrapper::CreateIndexFromTemplate_With_Memory_Address and hands any C++ exception to
// jniUtil.CatchCppExceptionAndThrowJava so that it does not propagate out of the native call.
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplateWithMemoryAddress(JNIEnv * env, jclass cls,
                                                                                                         jintArray idsJ,
                                                                                                         jlong vectorsAddress,
                                                                                                         jint dim,
                                                                                                         jstring indexPathJ,
                                                                                                         jbyteArray templateIndexJ,
                                                                                                         jobject parametersJ)
{
    try {
        knn_jni::faiss_wrapper::CreateIndexFromTemplate_With_Memory_Address(
            &jniUtil,
            env,
            idsJ,
            vectorsAddress,
            dim,
            indexPathJ,
            templateIndexJ,
            parametersJ);
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ)
{
try {
Expand Down Expand Up @@ -142,6 +168,24 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
vect = reinterpret_cast<std::vector<float>*>(vectorsPointerJ);
}

int dim = jniUtil.GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ);
auto dataset = jniUtil.Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim);
vect->insert(vect->begin(), dataset.begin(), dataset.end());

return (jlong) vect;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2(JNIEnv * env, jclass cls,
jlong vectorsPointerJ,
jobjectArray vectorsJ)
{
std::vector<float> *vect;
if ((long) vectorsPointerJ == 0) {
vect = new std::vector<float>;
} else {
vect = reinterpret_cast<std::vector<float>*>(vectorsPointerJ);
}

int dim = jniUtil.GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ);
auto dataset = jniUtil.Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim);
vect->insert(vect->end(), dataset.begin(), dataset.end());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KN
KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)
);
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, knnEngine.getName());
JNIService.createIndexFromTemplate(pair.docs, pair.vectors, pair.vectorsAddress, pair.dim, indexPath, model,
parameters, knnEngine.getName());
return null;
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep
vectorList.add(vector);
}
if(vectorList.size() == MAX_SIZE_OF_VECTOR_ARRAY) {
vectorAddress = JNIService.transferVectors(vectorAddress, vectorList.toArray(new float[][] {}));
vectorAddress = JNIService.transferVectorsV2(vectorAddress, vectorList.toArray(new float[][] {}));
vectorList = new ArrayList<>();
}
docIdList.add(doc);
}

if(vectorList.size() > 0) {
vectorAddress = JNIService.transferVectors(vectorAddress, vectorList.toArray(new float[][] {}));
vectorAddress = JNIService.transferVectorsV2(vectorAddress, vectorList.toArray(new float[][] {}));
vectorList = new ArrayList<>();
}
KNNCodecUtil.Pair pair = new KNNCodecUtil.Pair(
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/org/opensearch/knn/jni/FaissService.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class FaissService {
*/
public static native void createIndex(int[] ids, float[][] data, String indexPath, Map<String, Object> parameters);


/**
 * Create an index for the native library from vectors that already live in native memory
 *
 * @param ids array of ids mapping to the vectors
 * @param vectorAddress address of the native vector of floats holding the data; the native side deletes it
 *                      after the index has been written (presumably an address returned by
 *                      transferVectorsV2 - confirm against the caller)
 * @param dim dimension of each vector; the native side derives the vector count as (float count / dim)
 * @param indexPath path to save index file to
 * @param parameters parameters to build index
 */
public static native void createIndexWithMemoryAddress(int[] ids, long vectorAddress, int dim, String indexPath,
Map<String, Object> parameters);

/**
* Create an index for the native library with a provided template index
*
Expand All @@ -65,6 +69,15 @@ public static native void createIndexFromTemplate(
Map<String, Object> parameters
);

/**
 * Create an index for the native library with a provided template index, from vectors that already live in
 * native memory
 *
 * @param ids array of ids mapping to the vectors
 * @param vectorAddress address of the native vector of floats holding the data; the native side deletes it
 *                      after the index has been written (presumably an address returned by
 *                      transferVectorsV2 - confirm against the caller)
 * @param dim dimension of each vector; the native side derives the vector count as (float count / dim)
 * @param indexPath path to save index file to
 * @param templateIndex serialized template index that the native side deserializes and adds the vectors to
 * @param parameters additional build time parameters; the native side reads the indexing thread quantity from it
 */
public static native void createIndexFromTemplateWithMemoryAddress(
int[] ids,
long vectorAddress,
int dim,
String indexPath,
byte[] templateIndex,
Map<String, Object> parameters
);

/**
* Load an index into memory
*
Expand Down Expand Up @@ -115,6 +128,8 @@ public static native void createIndexFromTemplate(
*/
public static native long transferVectors(long vectorsPointer, float[][] trainingData);

/**
 * Transfer vectors from Java to native memory. Unlike {@link #transferVectors}, which inserts at the
 * beginning of the native vector, the native implementation of this method appends the new data to the end.
 *
 * @param vectorsPointer address of the native vector to extend; 0 makes the native side allocate a new one
 * @param trainingData vectors to transfer
 * @return presumably the address of the native vector, as with transferVectors - the end of the native
 *         implementation is not visible here; confirm
 */
public static native long transferVectorsV2(long vectorsPointer, float[][] trainingData);

/**
* Free vectors from memory
*
Expand Down
Loading

0 comments on commit 443a0ac

Please sign in to comment.