diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 284214631f..f4d417f805 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -23,12 +23,21 @@ namespace knn_jni { void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ); + void CreateIndex_With_Memory_Address(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorAddress, jint dim, jstring indexPathJ, jobject parametersJ); + + // Create an index with ids and vectors. Instead of creating a new index, this function creates the index // based off of the template index passed in. The index is serialized to indexPathJ. void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ); + void CreateIndexFromTemplate_With_Memory_Address(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddress, jint dim, jstring indexPathJ, + jbyteArray templateIndexJ, jobject parametersJ); + + // Load an index from indexPathJ into memory. // // Return a pointer to the loaded index diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index a252643355..ae20ce2836 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -26,6 +26,10 @@ extern "C" { JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex (JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject); + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexWithMemoryAddress + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); + /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndexFromTemplate @@ -34,6 +38,9 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate (JNIEnv *, jclass, jintArray, jobjectArray, jstring, jbyteArray, jobject); +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplateWithMemoryAddress + (JNIEnv * , jclass, jintArray, jlong , jint , jstring , jbyteArray ,jobject); + /* * Class: org_opensearch_knn_jni_FaissService * Method: loadIndex @@ -90,6 +97,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors (JNIEnv *, jclass, jlong, jobjectArray); +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2 + (JNIEnv *, jclass, jlong, jobjectArray); + /* * Class: org_opensearch_knn_jni_FaissService * Method: freeVectors diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index e8fb4de201..c302bad82a 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -124,6 +124,92 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN faiss::write_index(&idMap, indexPathCpp.c_str()); } + +void knn_jni::faiss_wrapper::CreateIndex_With_Memory_Address(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorAddressJ, jint dim, jstring indexPathJ, jobject parametersJ) { + + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + +// if (vectorsJ == nullptr) { +// throw std::runtime_error("Vectors cannot be null"); +// } + + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } + + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + // parametersJ is a Java Map. ConvertJavaMapToCppMap converts it to a c++ map + // so that it is easier to access. + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + // Get space type for this index + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + + auto *inputVectors = reinterpret_cast*>(vectorAddressJ); + + // Read data set + //int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ); + int numVectors = inputVectors->size() / dim; + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + //int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ); + //auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim); + //auto dataset = jniUtil->GetFloatArrayElements(env, ) + std::vector dataset; + for (int i = 0; i < numVectors; i++) { + dataset.push_back(inputVectors->at(i)); + } + + + // Create faiss index + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + + std::unique_ptr indexWriter; + indexWriter.reset(faiss::index_factory(dim, indexDescriptionCpp.c_str(), metric)); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS]; + auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ); + SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter.get()); + jniUtil->DeleteLocalRef(env, subParametersJ); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Check that the index does not need to be trained + if(!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); + idMap.add_with_ids(numVectors, dataset.data(), idVector.data()); + + // Write the index to disk + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + faiss::write_index(&idMap, indexPathCpp.c_str()); + delete inputVectors; +} + + void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ) { @@ -184,6 +270,73 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * faiss::write_index(&idMap, indexPathCpp.c_str()); } +void knn_jni::faiss_wrapper::CreateIndexFromTemplate_With_Memory_Address(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddress, jint dim, jstring indexPathJ, + jbyteArray templateIndexJ, jobject parametersJ) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + +// if (vectorsJ == nullptr) { +// throw std::runtime_error("Vectors cannot be null"); +// } + + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } + + if (templateIndexJ == nullptr) { + throw std::runtime_error("Template index cannot be null"); + } + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Read data set + //int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ); + auto *inputVectors = reinterpret_cast*>(vectorsAddress); + int numVectors = inputVectors->size() / dim; + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + //int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ); + //auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim); + std::vector dataset; + for (int i = 0; i < numVectors; i++) { + dataset.push_back(inputVectors->at(i)); + } + + // Get vector of bytes from jbytearray + int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); + jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); + + faiss::VectorIOReader vectorIoReader; + for (int i = 0; i < indexBytesCount; i++) { + vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]); + } + jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + + // Create faiss index + std::unique_ptr indexWriter; + indexWriter.reset(faiss::read_index(&vectorIoReader, 0)); + + auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); + idMap.add_with_ids(numVectors, dataset.data(), idVector.data()); + + // Write the index to disk + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + faiss::write_index(&idMap, indexPathCpp.c_str()); + delete inputVectors; +} + jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) { if (indexPathJ == nullptr) { throw std::runtime_error("Index path cannot be null"); diff --git a/jni/src/nmslib_wrapper.cpp b/jni/src/nmslib_wrapper.cpp index 3301217364..ab9e383c86 100644 --- a/jni/src/nmslib_wrapper.cpp +++ b/jni/src/nmslib_wrapper.cpp @@ -261,6 +261,7 @@ void knn_jni::nmslib_wrapper::CreateIndex_With_Memory_Address(knn_jni::JNIUtilIn // found: https://github.com/nmslib/nmslib/blob/v2.1.1/similarity_search/include/object.h#L61-L75 std::unique_ptr objectBuffer(new char[(similarity::ID_SIZE + similarity::LABEL_SIZE + similarity::DATALENGTH_SIZE + vectorSizeInBytes) * numVectors]); char* ptr = objectBuffer.get(); + std::cout<<"Number of vectors: "<GetFloatArrayElements(env, floatArrayJ, nullptr); + std::cout<<"Top Level Pointer before getting vectors"<at(topLevelPointer); topLevelPointer++; } - + std::cout<<"Top Level Pointer after getting vectors"<ReleaseFloatArrayElements(env, floatArrayJ, floatArrayCpp, JNI_ABORT); ptr += vectorSizeInBytes; diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 908d557a3e..a47936e986 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -51,6 +51,17 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIE } } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexWithMemoryAddress(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorAddress, jint dim, jstring indexPathJ, + jobject parametersJ) +{ + try { + knn_jni::faiss_wrapper::CreateIndex_With_Memory_Address(&jniUtil, env, idsJ, vectorAddress, dim, indexPathJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, jclass cls, jintArray idsJ, jobjectArray vectorsJ, @@ -65,6 +76,21 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT } } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplateWithMemoryAddress(JNIEnv * env, jclass cls, + jintArray idsJ, + jlong vectorsAddress, + jint dim, + jstring indexPathJ, + jbyteArray templateIndexJ, + jobject parametersJ) +{ + try { + knn_jni::faiss_wrapper::CreateIndexFromTemplate_With_Memory_Address(&jniUtil, env, idsJ, vectorsAddress, dim, indexPathJ, templateIndexJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ) { try { @@ -142,6 +168,24 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors vect = reinterpret_cast*>(vectorsPointerJ); } + int dim = jniUtil.GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ); + auto dataset = jniUtil.Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim); + vect->insert(vect->begin(), dataset.begin(), dataset.end()); + + return (jlong) vect; +} + +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2(JNIEnv * env, jclass cls, + jlong vectorsPointerJ, + jobjectArray vectorsJ) +{ + std::vector *vect; + if ((long) vectorsPointerJ == 0) { + vect = new std::vector; + } else { + vect = reinterpret_cast*>(vectorsPointerJ); + } + int dim = jniUtil.GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ); auto dataset = jniUtil.Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim); vect->insert(vect->end(), dataset.begin(), dataset.end()); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index dbd18c8c62..7bbd3838ff 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -184,7 +184,8 @@ private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KN KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) ); AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, knnEngine.getName()); + JNIService.createIndexFromTemplate(pair.docs, pair.vectors, pair.vectorsAddress, pair.dim, indexPath, model, + parameters, knnEngine.getName()); return null; }); } diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index ccf2d217e9..744de3ae21 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -63,14 +63,14 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep vectorList.add(vector); } if(vectorList.size() == MAX_SIZE_OF_VECTOR_ARRAY) { - vectorAddress = JNIService.transferVectors(vectorAddress, vectorList.toArray(new float[][] {})); + vectorAddress = JNIService.transferVectorsV2(vectorAddress, vectorList.toArray(new float[][] {})); vectorList = new ArrayList<>(); } docIdList.add(doc); } if(vectorList.size() > 0) { - vectorAddress = JNIService.transferVectors(vectorAddress, vectorList.toArray(new float[][] {})); + vectorAddress = JNIService.transferVectorsV2(vectorAddress, vectorList.toArray(new float[][] {})); vectorList = new ArrayList<>(); } KNNCodecUtil.Pair pair = new KNNCodecUtil.Pair( diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 5dce15d6e0..91bb9b2d1c 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -48,6 +48,10 @@ class FaissService { */ public static native void createIndex(int[] ids, float[][] data, String indexPath, Map parameters); + + public static native void createIndexWithMemoryAddress(int[] ids, long vectorAddress, int dim, String indexPath, + Map parameters); + /** * Create an index for the native library with a provided template index * @@ -65,6 +69,15 @@ public static native void createIndexFromTemplate( Map parameters ); + public static native void createIndexFromTemplateWithMemoryAddress( + int[] ids, + long vectorAddress, + int dim, + String indexPath, + byte[] templateIndex, + Map parameters + ); + /** * Load an index into memory * @@ -115,6 +128,8 @@ public static native void createIndexFromTemplate( */ public static native long transferVectors(long vectorsPointer, float[][] trainingData); + public static native long transferVectorsV2(long vectorsPointer, float[][] trainingData); + /** * Free vectors from memory * diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 75be038154..c935065ddc 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -53,7 +53,7 @@ public static void createIndex(int[] ids, float[][] data, long vectorAddress, in } if (KNNEngine.FAISS.getName().equals(engineName)) { - FaissService.createIndex(ids, data, indexPath, parameters); + FaissService.createIndexWithMemoryAddress(ids, vectorAddress, dim, indexPath, parameters); return; } @@ -86,6 +86,25 @@ public static void createIndexFromTemplate( throw new IllegalArgumentException("CreateIndexFromTemplate not supported for provided engine"); } + public static void createIndexFromTemplate( + int[] ids, + float[][] data, + long vectorAddress, + int dim, + String indexPath, + byte[] templateIndex, + Map parameters, + String engineName + ) { + if (KNNEngine.FAISS.getName().equals(engineName)) { + FaissService.createIndexFromTemplateWithMemoryAddress(ids, vectorAddress, dim, indexPath, templateIndex, + parameters); + return; + } + + throw new IllegalArgumentException("CreateIndexFromTemplate not supported for provided engine"); + } + /** * Load an index into memory * @@ -182,6 +201,10 @@ public static long transferVectors(long vectorsPointer, float[][] trainingData) return FaissService.transferVectors(vectorsPointer, trainingData); } + public static long transferVectorsV2(long vectorsPointer, float[][] trainingData) { + return FaissService.transferVectors(vectorsPointer, trainingData); + } + /** * Free vectors from memory *