diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 2629eea43..3e1adeac4 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -19,13 +19,13 @@ namespace knn_jni { namespace faiss_wrapper { // Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ. // The index is serialized to indexPathJ. - void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ, + void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jobject parametersJ); // Create an index with ids and vectors. Instead of creating a new index, this function creates the index // based off of the template index passed in. The index is serialized to indexPathJ. void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jobjectArray vectorsJ, jstring indexPathJ, jbyteArray templateIndexJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ); // Load an index from indexPathJ into memory. diff --git a/jni/include/nmslib_wrapper.h b/jni/include/nmslib_wrapper.h index 9b555580a..08494644f 100644 --- a/jni/include/nmslib_wrapper.h +++ b/jni/include/nmslib_wrapper.h @@ -25,7 +25,7 @@ namespace knn_jni { namespace nmslib_wrapper { // Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ. // The index is serialized to indexPathJ. - void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ, + void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddress, jint dim, jstring indexPathJ, jobject parametersJ); // Load an index from indexPathJ into memory. Use parametersJ to set any query time parameters diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 64a858f84..32b6f22f1 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -21,18 +21,18 @@ extern "C" { /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndex - * Signature: ([I[[FLjava/lang/String;Ljava/util/Map;)V + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex - (JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject); + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndexFromTemplate - * Signature: ([I[[FLjava/lang/String;[BLjava/util/Map;)V + * Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate - (JNIEnv *, jclass, jintArray, jobjectArray, jstring, jbyteArray, jobject); + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); /* * Class: org_opensearch_knn_jni_FaissService @@ -122,14 +122,6 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors (JNIEnv *, jclass, jlong, jobjectArray); -/* - * Class: org_opensearch_knn_jni_FaissService - * Method: freeVectors - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors - (JNIEnv *, jclass, jlong); - #ifdef __cplusplus } #endif diff --git a/jni/include/org_opensearch_knn_jni_NmslibService.h b/jni/include/org_opensearch_knn_jni_NmslibService.h index 02f58d20f..31422955f 100644 --- a/jni/include/org_opensearch_knn_jni_NmslibService.h +++ b/jni/include/org_opensearch_knn_jni_NmslibService.h @@ -21,10 +21,10 @@ extern "C" { /* * Class: org_opensearch_knn_jni_NmslibService * Method: createIndex - * Signature: ([I[[FLjava/lang/String;Ljava/util/Map;)V + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex - (JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject); + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); /* * Class: org_opensearch_knn_jni_NmslibService diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index a7075740e..03022e0d3 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -81,15 +81,19 @@ bool isIndexIVFPQL2(faiss::Index * index); // IndexIDMap which has member that will point to underlying index that stores the data faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index); -void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) { +void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, + jstring indexPathJ, jobject parametersJ) { if (idsJ == nullptr) { throw std::runtime_error("IDs cannot be null"); } - if (vectorsJ == nullptr) { - throw std::runtime_error("Vectors cannot be null"); + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); } if (indexPathJ == nullptr) { @@ -109,16 +113,20 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); - // Read data set - int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ); + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int)dimJ; + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) (inputVectors->size() / (uint64_t) dim); + if(numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); if (numIds != numVectors) { throw std::runtime_error("Number of IDs does not match number of vectors"); } - int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ); - auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim); - // Create faiss index jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); @@ -148,7 +156,7 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); - idMap.add_with_ids(numVectors, dataset.data(), idVector.data()); + idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data()); // Write the index to disk std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); @@ -156,14 +164,18 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN } void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jobjectArray vectorsJ, jstring indexPathJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ) { if (idsJ == nullptr) { throw std::runtime_error("IDs cannot be null"); } - if (vectorsJ == nullptr) { - throw std::runtime_error("Vectors cannot be null"); + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); } if (indexPathJ == nullptr) { @@ -183,15 +195,15 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil->DeleteLocalRef(env, parametersJ); // Read data set - int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ); + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int)dimJ; + int numVectors = (int) (inputVectors->size() / (uint64_t) dim); int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); if (numIds != numVectors) { throw std::runtime_error("Number of IDs does not match number of vectors"); } - int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ); - auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim); - // Get vector of bytes from jbytearray int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); @@ -208,7 +220,7 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); - idMap.add_with_ids(numVectors, dataset.data(), idVector.data()); + idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data()); // Write the index to disk std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); diff --git a/jni/src/nmslib_wrapper.cpp b/jni/src/nmslib_wrapper.cpp index f63fd2b01..00b7f4311 100644 --- a/jni/src/nmslib_wrapper.cpp +++ b/jni/src/nmslib_wrapper.cpp @@ -32,14 +32,19 @@ std::string TranslateSpaceType(const std::string& spaceType); const similarity::LabelType DEFAULT_LABEL = -1; void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) { + jlong vectorsAddressJ, jint dimJ, + jstring indexPathJ, jobject parametersJ) { if (idsJ == nullptr) { throw std::runtime_error("IDs cannot be null"); } - if (vectorsJ == nullptr) { - throw std::runtime_error("Vectors cannot be null"); + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); } if (indexPathJ == nullptr) { @@ -91,12 +96,18 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J space.reset(similarity::SpaceFactoryRegistry::Instance().CreateSpace(spaceTypeCpp,similarity::AnyParams())); // Get number of ids and vectors and dimension - int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ); + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int)dimJ; + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) ( inputVectors->size() / (uint64_t) dim); + if(numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); if (numIds != numVectors) { throw std::runtime_error("Number of IDs does not match number of vectors"); } - int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ); // Read dataset similarity::ObjectVector dataset; @@ -105,10 +116,12 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J try { // Read in data set idsCpp = jniUtil->GetIntArrayElements(env, idsJ, nullptr); - - float* floatArrayCpp; - jfloatArray floatArrayJ; size_t vectorSizeInBytes = dim*sizeof(float); + // vectorPointer needs to be unsigned long long, this will ensure that out of range doesn't happen for this pointer + // when the values of numVectors * dim becomes very large. + // Example: for 10M vectors of 1536 dim vectorPointer max value will be ~15.3B which is already > range of ints. + // keeping it unsigned long long we will never go above the range. + unsigned long long vectorPointer = 0; // Allocate a large buffer that will contain all the vectors. Allocating the objects in one large buffer as // opposed to individually will prevent heap fragmentation. We have observed that allocating individual @@ -134,15 +147,9 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J memcpy(ptr, &vectorSizeInBytes, similarity::DATALENGTH_SIZE); ptr += similarity::DATALENGTH_SIZE; - floatArrayJ = (jfloatArray)jniUtil->GetObjectArrayElement(env, vectorsJ, i); - if (dim != jniUtil->GetJavaFloatArrayLength(env, floatArrayJ)) { - throw std::runtime_error("Dimension of vectors is inconsistent"); - } - - floatArrayCpp = jniUtil->GetFloatArrayElements(env, floatArrayJ, nullptr); - memcpy(ptr, floatArrayCpp, vectorSizeInBytes); - jniUtil->ReleaseFloatArrayElements(env, floatArrayJ, floatArrayCpp, JNI_ABORT); + memcpy(ptr, &(inputVectors->at(vectorPointer)), vectorSizeInBytes); ptr += vectorSizeInBytes; + vectorPointer += dim; } jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index ae9d461ff..3249ed872 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -40,11 +40,11 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { } JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jobjectArray vectorsJ, jstring indexPathJ, - jobject parametersJ) + jlong vectorsAddressJ, jint dimJ, + jstring indexPathJ, jobject parametersJ) { try { - knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsJ, indexPathJ, parametersJ); + knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -52,13 +52,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIE JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, jclass cls, jintArray idsJ, - jobjectArray vectorsJ, + jlong vectorsAddressJ, + jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ) { try { - knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsJ, indexPathJ, templateIndexJ, parametersJ); + knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -189,12 +190,3 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors return (jlong) vect; } - -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors(JNIEnv * env, jclass cls, - jlong vectorsPointerJ) -{ - if (vectorsPointerJ != 0) { - auto *vect = reinterpret_cast*>(vectorsPointerJ); - delete vect; - } -} diff --git a/jni/src/org_opensearch_knn_jni_NmslibService.cpp b/jni/src/org_opensearch_knn_jni_NmslibService.cpp index 11dd885b1..d037d3337 100644 --- a/jni/src/org_opensearch_knn_jni_NmslibService.cpp +++ b/jni/src/org_opensearch_knn_jni_NmslibService.cpp @@ -38,11 +38,11 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { } JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jobjectArray vectorsJ, jstring indexPathJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jobject parametersJ) { try { - knn_jni::nmslib_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsJ, indexPathJ, parametersJ); + knn_jni::nmslib_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 2b1684cfb..ff9f63893 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -30,17 +30,14 @@ TEST(FaissCreateIndexTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 200; std::vector ids; - std::vector> vectors; + std::vector vectors; int dim = 2; + vectors.reserve(dim * numIds); for (int64_t i = 0; i < numIds; ++i) { ids.push_back(i); - - std::vector vect; - vect.reserve(dim); for (int j = 0; j < dim; ++j) { - vect.push_back(test_util::RandomFloat(-500.0, 500.0)); + vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); } - vectors.push_back(vect); } std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); @@ -63,7 +60,7 @@ TEST(FaissCreateIndexTest, BasicAssertions) { // Create the index knn_jni::faiss_wrapper::CreateIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - reinterpret_cast(&vectors), (jstring)&indexPath, + (jlong) &vectors, dim , (jstring)&indexPath, (jobject)¶metersMap); // Make sure index can be loaded @@ -77,17 +74,14 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 100; std::vector ids; - std::vector> vectors; + std::vector vectors; int dim = 2; + vectors.reserve(dim * numIds); for (int64_t i = 0; i < numIds; ++i) { ids.push_back(i); - - std::vector vect; - vect.reserve(dim); for (int j = 0; j < dim; ++j) { - vect.push_back(test_util::RandomFloat(-500.0, 500.0)); + vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); } - vectors.push_back(vect); } std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); @@ -113,7 +107,7 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { knn_jni::faiss_wrapper::CreateIndexFromTemplate( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - reinterpret_cast(&vectors), (jstring)&indexPath, + (jlong)&vectors, dim, (jstring)&indexPath, reinterpret_cast(&(vectorIoWriter.data)), (jobject) ¶metersMap ); @@ -486,17 +480,14 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 200; std::vector ids; - std::vector> vectors; + std::vector vectors; int dim = 2; + vectors.reserve(dim * numIds); for (int64_t i = 0; i < numIds; ++i) { ids.push_back(i); - - std::vector vect; - vect.reserve(dim); for (int j = 0; j < dim; ++j) { - vect.push_back(test_util::RandomFloat(-500.0, 500.0)); + vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); } - vectors.push_back(vect); } std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); @@ -519,7 +510,7 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { // Create the index knn_jni::faiss_wrapper::CreateIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - reinterpret_cast(&vectors), (jstring)&indexPath, + (jlong)&vectors, dim, (jstring)&indexPath, (jobject)¶metersMap); // Make sure index can be loaded diff --git a/jni/tests/nmslib_wrapper_test.cpp b/jni/tests/nmslib_wrapper_test.cpp index 3a21e7401..cb9c14e23 100644 --- a/jni/tests/nmslib_wrapper_test.cpp +++ b/jni/tests/nmslib_wrapper_test.cpp @@ -39,17 +39,14 @@ TEST(NmslibCreateIndexTest, BasicAssertions) { // Define index data int numIds = 100; std::vector ids; - std::vector> vectors; + std::vector vectors; int dim = 2; - for (int i = 0; i < numIds; ++i) { + vectors.reserve(dim * numIds); + for (int64_t i = 0; i < numIds; ++i) { ids.push_back(i); - - std::vector vect; - vect.reserve(dim); for (int j = 0; j < dim; ++j) { - vect.push_back(test_util::RandomFloat(-500.0, 500.0)); + vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); } - vectors.push_back(vect); } std::string indexPath = test_util::RandomString(10, "tmp/", ".nmslib"); @@ -79,7 +76,7 @@ TEST(NmslibCreateIndexTest, BasicAssertions) { // Create the index knn_jni::nmslib_wrapper::CreateIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - reinterpret_cast(&vectors), (jstring)&indexPath, + (jlong) &vectors, dim, (jstring)&indexPath, (jobject)¶metersMap); // Make sure index can be loaded diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 57192beca..2b0e29b55 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -16,6 +16,7 @@ import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.KNNCodecUtil; @@ -110,61 +111,67 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, throws IOException { // Get values to be indexed BinaryDocValues values = valuesProducer.getBinary(field); - KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values); - if (pair.vectors.length == 0 || pair.docs.length == 0) { - logger.info("Skipping engine index creation as there are no vectors or docs in the documents"); - return; - } - long arraySize = calculateArraySize(pair.vectors, pair.serializationMode); - if (isMerge) { - KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); - KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); - } - // Increment counter for number of graph index requests - KNNCounter.GRAPH_INDEX_REQUESTS.increment(); - // Create library index either from model or from scratch - String engineFileName; - String indexPath; - NativeIndexCreator indexCreator; - final KNNEngine knnEngine = getKNNEngine(field); - if (field.attributes().containsKey(MODEL_ID)) { - - String modelId = field.attributes().get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); - - engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getVersion(), field.name, knnEngine.getExtension()); - indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName) - .toString(); - - if (model.getModelBlob() == null) { - throw new RuntimeException("There is no trained model with id \"" + modelId + "\""); + KNNCodecUtil.Pair pair = null; + try { + pair = KNNCodecUtil.getFloats(values); + if (pair.getVectorAddress() == 0 || pair.docs.length == 0) { + logger.info("Skipping engine index creation as there are no vectors or docs in the segment"); + return; + } + long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); + if (isMerge) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); + KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); + } + // Increment counter for number of graph index requests + KNNCounter.GRAPH_INDEX_REQUESTS.increment(); + final KNNEngine knnEngine = getKNNEngine(field); + final String engineFileName = buildEngineFileName( + state.segmentInfo.name, + knnEngine.getVersion(), + field.name, + knnEngine.getExtension() + ); + final String indexPath = Paths.get( + ((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), + engineFileName + ).toString(); + KNNCodecUtil.Pair finalPair = pair; + NativeIndexCreator indexCreator; + // Create library index either from model or from scratch + if (field.attributes().containsKey(MODEL_ID)) { + String modelId = field.attributes().get(MODEL_ID); + Model model = ModelCache.getInstance().get(modelId); + if (model.getModelBlob() == null) { + throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); + } + indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), finalPair, knnEngine, indexPath); + } else { + indexCreator = () -> createKNNIndexFromScratch(field, finalPair, knnEngine, indexPath); } - indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath); - } else { - - engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getVersion(), field.name, knnEngine.getExtension()); - indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName) - .toString(); - - indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); - } + if (isMerge) { + recordMergeStats(pair.docs.length, arraySize); + } - if (isMerge) { - recordMergeStats(pair.docs.length, arraySize); - } + if (isRefresh) { + recordRefreshStats(); + } - if (isRefresh) { - recordRefreshStats(); + // This is a bit of a hack. We have to create an output here and then immediately close it to ensure that + // engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will + // not be marked as added to the directory. + state.directory.createOutput(engineFileName, state.context).close(); + indexCreator.createIndex(); + writeFooter(indexPath, engineFileName); + } finally { + // Freeing up the Native memory where vectors was stored. We added a try block here to ensure that even + // in case of exceptions we are freeing up the space to avoid memory leaks. + if (pair != null) { + JNICommons.freeVectorData(pair.getVectorAddress()); + } } - - // This is a bit of a hack. We have to create an output here and then immediately close it to ensure that - // engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will - // not be marked as added to the directory. - state.directory.createOutput(engineFileName, state.context).close(); - indexCreator.createIndex(); - writeFooter(indexPath, engineFileName); } private void recordMergeStats(int length, long arraySize) { @@ -186,7 +193,15 @@ private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KN KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) ); AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, knnEngine); + JNIService.createIndexFromTemplate( + pair.docs, + pair.getVectorAddress(), + pair.getDimension(), + indexPath, + model, + parameters, + knnEngine + ); return null; }); } @@ -228,7 +243,7 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa // Pass the path for the nms library to save the file AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine); + JNIService.createIndex(pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, parameters, knnEngine); return null; }); } diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 02ab2d833..eecec6b50 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -5,9 +5,13 @@ package org.opensearch.knn.index.codec.util; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.jni.JNICommons; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -26,21 +30,24 @@ public class KNNCodecUtil { // Java rounds each array size up to multiples of 8 bytes public static final int JAVA_ROUNDING_NUMBER = 8; + @AllArgsConstructor public static final class Pair { - public Pair(int[] docs, float[][] vectors, SerializationMode serializationMode) { - this.docs = docs; - this.vectors = vectors; - this.serializationMode = serializationMode; - } - public int[] docs; - public float[][] vectors; + @Getter + @Setter + private long vectorAddress; + @Getter + @Setter + private int dimension; public SerializationMode serializationMode; + } public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException { ArrayList vectorList = new ArrayList<>(); ArrayList docIdList = new ArrayList<>(); + long vectorAddress = 0; + int dimension = 0; SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS; for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { BytesRef bytesref = values.binaryValue(); @@ -48,20 +55,22 @@ public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOExcep serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream); final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); final float[] vector = vectorSerializer.byteToFloatArray(byteStream); + dimension = vector.length; vectorList.add(vector); } docIdList.add(doc); } - return new KNNCodecUtil.Pair( - docIdList.stream().mapToInt(Integer::intValue).toArray(), - vectorList.toArray(new float[][] {}), - serializationMode - ); + if (vectorList.isEmpty() == false) { + vectorAddress = JNICommons.storeVectorData( + vectorAddress, + vectorList.toArray(new float[][] {}), + (long) vectorList.size() * dimension + ); + } + return new KNNCodecUtil.Pair(docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorAddress, dimension, serializationMode); } - public static long calculateArraySize(float[][] vectors, SerializationMode serializationMode) { - int vectorLength = vectors[0].length; - int numVectors = vectors.length; + public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) { if (serializationMode == SerializationMode.ARRAY) { int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE; if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java index 286e6265c..c8b56436b 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.opensearch.knn.index.query.KNNWeight; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.watcher.FileWatcher; @@ -313,7 +314,7 @@ private void cleanup() { closed = true; if (this.memoryAddress != 0) { - JNIService.freeVectors(this.memoryAddress); + JNICommons.freeVectorData(this.memoryAddress); } } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 517945968..9b7c545a5 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -53,24 +53,27 @@ class FaissService { * Create an index for the native library * * @param ids array of ids mapping to the data passed in - * @param data array of float arrays to be indexed + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed * @param indexPath path to save index file to * @param parameters parameters to build index */ - public static native void createIndex(int[] ids, float[][] data, String indexPath, Map parameters); + public static native void createIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map parameters); /** * Create an index for the native library with a provided template index * * @param ids array of ids mapping to the data passed in - * @param data array of float arrays to be indexed + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed * @param indexPath path to save index file to * @param templateIndex empty template index * @param parameters additional build time parameters */ public static native void createIndexFromTemplate( int[] ids, - float[][] data, + long vectorsAddress, + int dim, String indexPath, byte[] templateIndex, Map parameters @@ -184,15 +187,4 @@ public static native KNNQueryResult[] queryIndexWithFilter( */ @Deprecated(since = "2.14.0", forRemoval = true) public static native long transferVectors(long vectorsPointer, float[][] trainingData); - - /** - *

- * The function is deprecated. Use {@link JNICommons#freeVectorData(long)} - *

- * Free vectors from memory - * - * @param vectorsPointer to be freed - */ - @Deprecated(since = "2.14.0", forRemoval = true) - public static native void freeVectors(long vectorsPointer); } diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index a85089517..70a8f5b71 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -26,19 +26,28 @@ public class JNIService { * Create an index for the native library * * @param ids array of ids mapping to the data passed in - * @param data array of float arrays to be indexed + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed * @param indexPath path to save index file to * @param parameters parameters to build index * @param knnEngine engine to build index for */ - public static void createIndex(int[] ids, float[][] data, String indexPath, Map parameters, KNNEngine knnEngine) { + public static void createIndex( + int[] ids, + long vectorsAddress, + int dim, + String indexPath, + Map parameters, + KNNEngine knnEngine + ) { + if (KNNEngine.NMSLIB == knnEngine) { - NmslibService.createIndex(ids, data, indexPath, parameters); + NmslibService.createIndex(ids, vectorsAddress, dim, indexPath, parameters); return; } if (KNNEngine.FAISS == knnEngine) { - FaissService.createIndex(ids, data, indexPath, parameters); + FaissService.createIndex(ids, vectorsAddress, dim, indexPath, parameters); return; } @@ -49,7 +58,8 @@ public static void createIndex(int[] ids, float[][] data, String indexPath, Map< * Create an index for the native library with a provided template index * * @param ids array of ids mapping to the data passed in - * @param data array of float arrays to be indexed + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of vectors to be indexed * @param indexPath path to save index file to * @param templateIndex empty template index * @param parameters parameters to build index @@ -57,14 +67,15 @@ public static void createIndex(int[] ids, float[][] data, String indexPath, Map< */ public static void createIndexFromTemplate( int[] ids, - float[][] data, + long vectorsAddress, + int dim, String indexPath, byte[] templateIndex, Map parameters, KNNEngine knnEngine ) { if (KNNEngine.FAISS == knnEngine) { - FaissService.createIndexFromTemplate(ids, data, indexPath, templateIndex, parameters); + FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); return; } @@ -248,17 +259,4 @@ public static byte[] trainIndex(Map indexParameters, int dimensi public static long transferVectors(long vectorsPointer, float[][] trainingData) { return FaissService.transferVectors(vectorsPointer, trainingData); } - - /** - *

- * The function is deprecated. Use {@link JNICommons#freeVectorData(long)} - *

- * Free vectors from memory - * - * @param vectorsPointer to be freed - */ - @Deprecated(since = "2.14.0", forRemoval = true) - public static void freeVectors(long vectorsPointer) { - FaissService.freeVectors(vectorsPointer); - } } diff --git a/src/main/java/org/opensearch/knn/jni/NmslibService.java b/src/main/java/org/opensearch/knn/jni/NmslibService.java index 77896822a..46d9266e8 100644 --- a/src/main/java/org/opensearch/knn/jni/NmslibService.java +++ b/src/main/java/org/opensearch/knn/jni/NmslibService.java @@ -42,11 +42,12 @@ class NmslibService { * Create an index for the native library * * @param ids array of ids mapping to the data passed in - * @param data array of float arrays to be indexed + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed * @param indexPath path to save index file to * @param parameters parameters to build index */ - public static native void createIndex(int[] ids, float[][] data, String indexPath, Map parameters); + public static native void createIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map parameters); /** * Load an index into memory diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index cdf3109a4..2ce3a7c83 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -38,6 +38,7 @@ import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.knn.plugin.stats.KNNGraphValue; @@ -357,7 +358,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio modelBytes, modelId ); - JNIService.freeVectors(trainingPtr); + JNICommons.freeVectorData(trainingPtr); // Setup the model cache to return the correct model ModelDao modelDao = mock(ModelDao.class); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java index fabffab46..7573a4394 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java @@ -15,6 +15,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.IndexUtil; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; @@ -52,7 +53,8 @@ public void testIndexAllocation_close() throws InterruptedException { Arrays.fill(vectors[i], 1f); } Map parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); - JNIService.createIndex(ids, vectors, path, parameters, knnEngine); + long vectorMemoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); + JNIService.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); // Load index into memory long memoryAddress = JNIService.loadIndex(path, parameters, knnEngine); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 207b6373b..a84974202 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -16,6 +16,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.SpaceType; @@ -53,7 +54,8 @@ public void testIndexLoadStrategy_load() throws IOException { Arrays.fill(vectors[i], 1f); } Map parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); - JNIService.createIndex(ids, vectors, path, parameters, knnEngine); + long memoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); + JNIService.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); // Setup mock resource manager ResourceWatcherService resourceWatcherService = mock(ResourceWatcherService.class); diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index 36a3d93be..3d4b63549 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -83,7 +83,8 @@ public void testCreateIndex_invalid_engineNotSupported() { IllegalArgumentException.class, () -> JNIService.createIndex( new int[] {}, - new float[][] {}, + 0, + 0, "test", ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.LUCENE @@ -96,7 +97,8 @@ public void testCreateIndex_invalid_engineNull() { Exception.class, () -> JNIService.createIndex( new int[] {}, - new float[][] {}, + 0, + 0, "test", ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), null @@ -105,12 +107,12 @@ public void testCreateIndex_invalid_engineNull() { } public void testCreateIndex_nmslib_invalid_noSpaceType() { - expectThrows( Exception.class, () -> JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), "something", Collections.emptyMap(), KNNEngine.NMSLIB @@ -122,13 +124,14 @@ public void testCreateIndex_nmslib_invalid_vectorDocIDMismatch() throws IOExcept int[] docIds = new int[] { 1, 2, 3 }; float[][] vectors1 = new float[][] { { 1, 2 }, { 3, 4 } }; - + long memoryAddress = JNICommons.storeVectorData(0, vectors1, vectors1.length * vectors1[0].length); Path tmpFile1 = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors1, + memoryAddress, + vectors1[0].length, tmpFile1.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -136,13 +139,15 @@ public void testCreateIndex_nmslib_invalid_vectorDocIDMismatch() throws IOExcept ); float[][] vectors2 = new float[][] { { 1, 2 }, { 3, 4 }, { 4, 5 }, { 6, 7 }, { 8, 9 } }; + long memoryAddress2 = JNICommons.storeVectorData(0, vectors2, vectors2.length * vectors2[0].length); Path tmpFile2 = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors2, + memoryAddress2, + vectors2[0].length, tmpFile2.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -154,13 +159,14 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { int[] docIds = new int[] {}; float[][] vectors = new float[][] {}; - + long memoryAddress = JNICommons.storeVectorData(0, vectors, vectors.length); Path tmpFile = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( null, - vectors, + memoryAddress, + 0, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -171,7 +177,8 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { Exception.class, () -> JNIService.createIndex( docIds, - null, + 0, + 0, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -182,7 +189,8 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + 0, null, ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -191,14 +199,15 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), null, KNNEngine.NMSLIB) + () -> JNIService.createIndex(docIds, memoryAddress, 0, tmpFile.toAbsolutePath().toString(), null, KNNEngine.NMSLIB) ); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + 0, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), null @@ -210,13 +219,14 @@ public void testCreateIndex_nmslib_invalid_badSpace() throws IOException { int[] docIds = new int[] { 1 }; float[][] vectors = new float[][] { { 2, 3 } }; - + long memoryAddress = JNICommons.storeVectorData(0, vectors, vectors.length * vectors[0].length); Path tmpFile = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, "invalid"), KNNEngine.NMSLIB @@ -224,28 +234,11 @@ public void testCreateIndex_nmslib_invalid_badSpace() throws IOException { ); } - public void testCreateIndex_nmslib_invalid_inconsistentDimensions() throws IOException { - - int[] docIds = new int[] { 1, 2 }; - float[][] vectors = new float[][] { { 2, 3 }, { 2, 3, 4 } }; - - Path tmpFile = createTempFile(); - expectThrows( - Exception.class, - () -> JNIService.createIndex( - docIds, - vectors, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ) - ); - } - public void testCreateIndex_nmslib_invalid_badParameterType() throws IOException { - int[] docIds = new int[] {}; - float[][] vectors = new float[][] {}; + int[] docIds = new int[] { 1 }; + float[][] vectors = new float[][] { { 2, 3 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, vectors.length * vectors[0].length); Map parametersMap = ImmutableMap.of( KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, @@ -258,7 +251,8 @@ public void testCreateIndex_nmslib_invalid_badParameterType() throws IOException Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue(), KNNConstants.PARAMETERS, parametersMap), KNNEngine.NMSLIB @@ -273,7 +267,8 @@ public void testCreateIndex_nmslib_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.NMSLIB @@ -284,7 +279,8 @@ public void testCreateIndex_nmslib_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of( KNNConstants.SPACE_TYPE, @@ -301,15 +297,14 @@ public void testCreateIndex_nmslib_valid() throws IOException { } public void testCreateIndex_faiss_invalid_noSpaceType() { - int[] docIds = new int[] {}; - float[][] vectors = new float[][] {}; expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), "something", ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod), KNNEngine.FAISS @@ -321,13 +316,14 @@ public void testCreateIndex_faiss_invalid_vectorDocIDMismatch() throws IOExcepti int[] docIds = new int[] { 1, 2, 3 }; float[][] vectors1 = new float[][] { { 1, 2 }, { 3, 4 } }; - + long memoryAddress = JNICommons.storeVectorData(0, vectors1, vectors1.length * vectors1[0].length); Path tmpFile1 = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors1, + memoryAddress, + vectors1[0].length, tmpFile1.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -335,13 +331,14 @@ public void testCreateIndex_faiss_invalid_vectorDocIDMismatch() throws IOExcepti ); float[][] vectors2 = new float[][] { { 1, 2 }, { 3, 4 }, { 4, 5 }, { 6, 7 }, { 8, 9 } }; - + long memoryAddress2 = JNICommons.storeVectorData(0, vectors2, vectors2.length * vectors2[0].length); Path tmpFile2 = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors2, + memoryAddress, + vectors2[0].length, tmpFile2.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -353,13 +350,15 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { int[] docIds = new int[] {}; float[][] vectors = new float[][] {}; + long memoryAddress = JNICommons.storeVectorData(0, vectors, 0); Path tmpFile = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( null, - vectors, + memoryAddress, + 0, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -370,7 +369,8 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { Exception.class, () -> JNIService.createIndex( docIds, - null, + 0, + 0, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -381,7 +381,8 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { Exception.class, () -> JNIService.createIndex( docIds, - vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), null, ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -390,14 +391,22 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), null, KNNEngine.FAISS) + () -> JNIService.createIndex( + docIds, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), + tmpFile.toAbsolutePath().toString(), + null, + KNNEngine.FAISS + ) ); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), null @@ -409,13 +418,15 @@ public void testCreateIndex_faiss_invalid_invalidSpace() throws IOException { int[] docIds = new int[] { 1 }; float[][] vectors = new float[][] { { 2, 3 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); Path tmpFile = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, "invalid"), KNNEngine.FAISS @@ -423,35 +434,19 @@ public void testCreateIndex_faiss_invalid_invalidSpace() throws IOException { ); } - public void testCreateIndex_faiss_invalid_inconsistentDimensions() throws IOException { - - int[] docIds = new int[] { 1, 2 }; - float[][] vectors = new float[][] { { 2, 3 }, { 2, 3, 4 } }; - - Path tmpFile = createTempFile(); - expectThrows( - Exception.class, - () -> JNIService.createIndex( - docIds, - vectors, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ) - ); - } - public void testCreateIndex_faiss_invalid_noIndexDescription() throws IOException { int[] docIds = new int[] { 1, 2 }; - float[][] vectors = new float[][] { { 2, 3 }, { 2, 3, 4 } }; + float[][] vectors = new float[][] { { 2, 3 }, { 2, 3 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); Path tmpFile = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -460,16 +455,16 @@ public void testCreateIndex_faiss_invalid_noIndexDescription() throws IOExceptio } public void testCreateIndex_faiss_invalid_invalidIndexDescription() throws IOException { - int[] docIds = new int[] { 1, 2 }; - float[][] vectors = new float[][] { { 2, 3 }, { 2, 3, 4 } }; - + float[][] vectors = new float[][] { { 2, 3 }, { 2, 3 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); Path tmpFile = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, "invalid", KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -482,6 +477,8 @@ public void testCreateIndex_faiss_sqfp16_invalidIndexDescription() { int[] docIds = new int[] { 1, 2 }; float[][] vectors = new float[][] { { 2, 3 }, { 3, 4 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); + String sqfp16InvalidIndexDescription = "HNSW16,SQfp1655"; Path tmpFile = createTempFile(); @@ -489,7 +486,8 @@ public void testCreateIndex_faiss_sqfp16_invalidIndexDescription() { Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of( INDEX_DESCRIPTION_PARAMETER, @@ -508,11 +506,12 @@ public void testLoadIndex_faiss_sqfp16_valid() { int[] docIds = new int[] { 1, 2 }; float[][] vectors = new float[][] { { 2, 3 }, { 3, 4 } }; String sqfp16IndexDescription = "HNSW16,SQfp16"; - + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); Path tmpFile = createTempFile(); JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, sqfp16IndexDescription, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -528,11 +527,13 @@ public void testQueryIndex_faiss_sqfp16_valid() { String sqfp16IndexDescription = "HNSW16,SQfp16"; int k = 10; - + float[][] truncatedVectors = truncateToFp16Range(testData.indexData.vectors); + long memoryAddress = JNICommons.storeVectorData(0, truncatedVectors, (long) truncatedVectors.length * truncatedVectors[0].length); Path tmpFile = createTempFile(); JNIService.createIndex( testData.indexData.docs, - truncateToFp16Range(testData.indexData.vectors), + memoryAddress, + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, sqfp16IndexDescription, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -596,7 +597,7 @@ public void testTrain_whenConfigurationIsIVFSQFP16_thenSucceed() { byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); assertNotEquals(0, faissIndex.length); - JNIService.freeVectors(trainPointer); + JNICommons.freeVectorData(trainPointer); } public void testCreateIndex_faiss_invalid_invalidParameterType() throws IOException { @@ -609,7 +610,8 @@ public void testCreateIndex_faiss_invalid_invalidParameterType() throws IOExcept Exception.class, () -> JNIService.createIndex( docIds, - vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of( INDEX_DESCRIPTION_PARAMETER, @@ -634,7 +636,8 @@ public void testCreateIndex_faiss_valid() throws IOException { Path tmpFile1 = createTempFile(); JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile1.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.FAISS @@ -684,7 +687,8 @@ public void testLoadIndex_nmslib_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -719,7 +723,8 @@ public void testLoadIndex_faiss_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -745,7 +750,8 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -770,7 +776,8 @@ public void testQueryIndex_nmslib_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.NMSLIB @@ -802,7 +809,8 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -826,7 +834,8 @@ public void testQueryIndex_faiss_valid() throws IOException { Path tmpFile = createTempFile(); JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.FAISS @@ -867,7 +876,8 @@ public void testQueryIndex_faiss_parentIds() throws IOException { Path tmpFile = createTempFile(); JNIService.createIndex( testDataNested.indexData.docs, - testDataNested.indexData.vectors, + testDataNested.indexData.getVectorAddress(), + testDataNested.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.FAISS @@ -941,7 +951,8 @@ public void testFree_nmslib_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -964,7 +975,8 @@ public void testFree_faiss_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -987,7 +999,7 @@ public void testTransferVectors() { assertEquals(trainPointer1, trainPointer2); } - JNIService.freeVectors(trainPointer1); + JNICommons.freeVectorData(trainPointer1); } public void testTrain_whenConfigurationIsIVFFlat_thenSucceed() throws IOException { @@ -1008,7 +1020,7 @@ public void testTrain_whenConfigurationIsIVFFlat_thenSucceed() throws IOExceptio byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); assertNotEquals(0, faissIndex.length); - JNIService.freeVectors(trainPointer); + JNICommons.freeVectorData(trainPointer); } public void testTrain_whenConfigurationIsIVFPQ_thenSucceed() throws IOException { @@ -1038,7 +1050,7 @@ public void testTrain_whenConfigurationIsIVFPQ_thenSucceed() throws IOException byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); assertNotEquals(0, faissIndex.length); - JNIService.freeVectors(trainPointer); + JNICommons.freeVectorData(trainPointer); } public void testTrain_whenConfigurationIsHNSWPQ_thenSucceed() throws IOException { @@ -1065,7 +1077,7 @@ public void testTrain_whenConfigurationIsHNSWPQ_thenSucceed() throws IOException byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); assertNotEquals(0, faissIndex.length); - JNIService.freeVectors(trainPointer); + JNICommons.freeVectorData(trainPointer); } private long transferVectors(int numDuplicates) { @@ -1120,12 +1132,13 @@ public void testCreateIndexFromTemplate() throws IOException { byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer1, KNNEngine.FAISS); assertNotEquals(0, faissIndex.length); - JNIService.freeVectors(trainPointer1); + JNICommons.freeVectorData(trainPointer1); Path tmpFile1 = createTempFile(); JNIService.createIndexFromTemplate( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile1.toAbsolutePath().toString(), faissIndex, ImmutableMap.of(INDEX_THREAD_QTY, 1), @@ -1255,11 +1268,12 @@ private String createFaissIVFPQIndex(int ivfNlist, int pqM, int pqCodeSize, Spac byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); assertNotEquals(0, faissIndex.length); - JNIService.freeVectors(trainPointer); + JNICommons.freeVectorData(trainPointer); Path tmpFile = createTempFile(); JNIService.createIndexFromTemplate( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), faissIndex, ImmutableMap.of(INDEX_THREAD_QTY, 1), @@ -1274,7 +1288,8 @@ private String createFaissHNSWIndex(SpaceType spaceType) throws IOException { Path tmpFile = createTempFile(); JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.indexData.getVectorAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.FAISS diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index e4c01a8ec..06b96c57c 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -25,6 +25,7 @@ import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import java.io.File; @@ -170,7 +171,7 @@ public void testRun_success() throws IOException, ExecutionException { when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation); doAnswer(invocationOnMock -> { - JNIService.freeVectors(memoryAddress); + JNICommons.freeVectorData(memoryAddress); return null; }).when(nativeMemoryCacheManager).invalidate(tdataKey); @@ -197,11 +198,12 @@ public void testRun_success() throws IOException, ExecutionException { int[] ids = { 1, 2, 3, 4 }; float[][] vectors = new float[ids.length][dimension]; fillFloatArrayRandomly(vectors); - + long vectorsMemoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); Path indexPath = createTempFile(); JNIService.createIndexFromTemplate( ids, - vectors, + vectorsMemoryAddress, + vectors[0].length, indexPath.toString(), model.getModelBlob(), ImmutableMap.of(INDEX_THREAD_QTY, 1), @@ -456,7 +458,7 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException { when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation); doAnswer(invocationOnMock -> { - JNIService.freeVectors(memoryAddress); + JNICommons.freeVectorData(memoryAddress); return null; }).when(nativeMemoryCacheManager).invalidate(tdataKey); diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index 0de11b2f2..20868d54c 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -6,17 +6,20 @@ package org.opensearch.knn; import com.google.common.collect.ImmutableMap; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.MediaTypeRegistry; -import org.opensearch.knn.index.codec.util.KNNCodecUtil; import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.SerializationMode; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.plugin.script.KNNScoringUtil; import java.util.Comparator; import java.util.Random; @@ -248,7 +251,7 @@ public static PriorityQueue insertWithOverflow(PriorityQueue idsList = new ArrayList<>(); List vectorsList = new ArrayList<>(); @@ -292,8 +295,9 @@ private KNNCodecUtil.Pair readIndexData(String path) throws IOException { vectorsArray[i][j] = vectorsList.get(i)[j]; } } + long memoryAddress = JNICommons.storeVectorData(0, vectorsArray, (long) vectorsArray.length * vectorsArray[0].length); - return new KNNCodecUtil.Pair(idsArray, vectorsArray, SerializationMode.COLLECTION_OF_FLOATS); + return new Pair(idsArray, memoryAddress, vectorsArray[0].length, SerializationMode.COLLECTION_OF_FLOATS, vectorsArray); } private float[][] readQueries(String path) throws IOException { @@ -324,5 +328,18 @@ private float[][] readQueries(String path) throws IOException { } return queryArray; } + + @AllArgsConstructor + public static class Pair { + public int[] docs; + @Getter + @Setter + private long vectorAddress; + @Getter + @Setter + private int dimension; + public SerializationMode serializationMode; + public float[][] vectors; + } } }