diff --git a/jni/include/nmslib_wrapper.h b/jni/include/nmslib_wrapper.h index 6d862048a..1bce7125b 100644 --- a/jni/include/nmslib_wrapper.h +++ b/jni/include/nmslib_wrapper.h @@ -28,6 +28,9 @@ namespace knn_jni { void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ); + void CreateIndex_With_Memory_Address(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorAddressJ, jint dim, jstring indexPathJ, jobject parametersJ); + // Load an index from indexPathJ into memory. Use parametersJ to set any query time parameters // // Return a pointer to the loaded index diff --git a/jni/include/org_opensearch_knn_jni_NmslibService.h b/jni/include/org_opensearch_knn_jni_NmslibService.h index 02f58d20f..67079ae11 100644 --- a/jni/include/org_opensearch_knn_jni_NmslibService.h +++ b/jni/include/org_opensearch_knn_jni_NmslibService.h @@ -26,6 +26,9 @@ extern "C" { JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex (JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject); +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndexWithMemoryAddress + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); + /* * Class: org_opensearch_knn_jni_NmslibService * Method: loadIndex diff --git a/jni/src/nmslib_wrapper.cpp b/jni/src/nmslib_wrapper.cpp index f63fd2b01..330121736 100644 --- a/jni/src/nmslib_wrapper.cpp +++ b/jni/src/nmslib_wrapper.cpp @@ -164,6 +164,152 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J } } +void knn_jni::nmslib_wrapper::CreateIndex_With_Memory_Address(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorAddressJ, jint dim, jstring indexPathJ, jobject parametersJ) { + + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + + if (vectorAddressJ == 0) { + throw std::runtime_error("Vectors Address cannot be 0"); + } + + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } + + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + // Handle parameters + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + std::vector indexParameters; + + // Algorithm parameters will be in a sub map + if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS]; + auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ); + + if(subParametersCpp.find(knn_jni::EF_CONSTRUCTION) != subParametersCpp.end()) { + auto efConstruction = jniUtil->ConvertJavaObjectToCppInteger(env, subParametersCpp[knn_jni::EF_CONSTRUCTION]); + indexParameters.push_back(knn_jni::EF_CONSTRUCTION_NMSLIB + "=" + std::to_string(efConstruction)); + } + + if(subParametersCpp.find(knn_jni::M) != subParametersCpp.end()) { + auto m = jniUtil->ConvertJavaObjectToCppInteger(env, subParametersCpp[knn_jni::M]); + indexParameters.push_back(knn_jni::M_NMSLIB + "=" + std::to_string(m)); + } + + jniUtil->DeleteLocalRef(env, subParametersJ); + } + + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto indexThreadQty = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + indexParameters.push_back(knn_jni::INDEX_THREAD_QUANTITY + "=" + std::to_string(indexThreadQty)); + } + + jniUtil->DeleteLocalRef(env, parametersJ); + + // Get the path to save the index + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + + // Get space type for this index + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + spaceTypeCpp = TranslateSpaceType(spaceTypeCpp); + + std::unique_ptr> space; + space.reset(similarity::SpaceFactoryRegistry::Instance().CreateSpace(spaceTypeCpp,similarity::AnyParams())); + + std::vector *inputVectors = reinterpret_cast*>(vectorAddressJ); + + // Get number of ids and vectors and dimension + //int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ); + int numVectors = inputVectors->size() / dim; + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + //int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ); + + // Read dataset + similarity::ObjectVector dataset; + dataset.reserve(numVectors); + int* idsCpp; + int topLevelPointer = 0; + try { + // Read in data set + idsCpp = jniUtil->GetIntArrayElements(env, idsJ, nullptr); + + //float* floatArrayCpp; + jfloatArray floatArrayJ; + size_t vectorSizeInBytes = dim*sizeof(float); + + // Allocate a large buffer that will contain all the vectors. Allocating the objects in one large buffer as + // opposed to individually will prevent heap fragmentation. We have observed that allocating individual + // objects causes RSS to rise throughout the lifetime of a process + // (see https://github.com/opensearch-project/k-NN/issues/772 and + // https://github.com/opensearch-project/k-NN/issues/72). This is because, in typical systems, small + // allocations will reside on some kind of heap managed by an allocator. Once freed, the allocator does not + // always return the memory to the OS. If the heap gets fragmented, this will cause the allocator + // to ask for more memory, causing RSS to grow. On large allocations (> 128 kb), most allocators will + // internally use mmap. Once freed, unmap will be called, which will immediately return memory to the OS + // which in turn prevents RSS from growing out of control. Wrap with a smart pointer so that buffer will be + // freed once variable goes out of scope. For reference, the code that specifies the layout of the buffer can be + // found: https://github.com/nmslib/nmslib/blob/v2.1.1/similarity_search/include/object.h#L61-L75 + std::unique_ptr objectBuffer(new char[(similarity::ID_SIZE + similarity::LABEL_SIZE + similarity::DATALENGTH_SIZE + vectorSizeInBytes) * numVectors]); + char* ptr = objectBuffer.get(); + for (int i = 0; i < numVectors; i++) { + dataset.push_back(new similarity::Object(ptr)); + + memcpy(ptr, &idsCpp[i], similarity::ID_SIZE); + ptr += similarity::ID_SIZE; + memcpy(ptr, &DEFAULT_LABEL, similarity::LABEL_SIZE); + ptr += similarity::LABEL_SIZE; + memcpy(ptr, &vectorSizeInBytes, similarity::DATALENGTH_SIZE); + ptr += similarity::DATALENGTH_SIZE; + +// floatArrayJ = (jfloatArray)jniUtil->GetObjectArrayElement(env, vectorsJ, i); +// if (dim != jniUtil->GetJavaFloatArrayLength(env, floatArrayJ)) { +// throw std::runtime_error("Dimension of vectors is inconsistent"); +// } + + //floatArrayCpp = jniUtil->GetFloatArrayElements(env, floatArrayJ, nullptr); + float floatArrayCpp[dim]; + for(int j = 0 ; j < dim; j ++) { + floatArrayCpp[j] = inputVectors->at(topLevelPointer); + topLevelPointer++; + } + + memcpy(ptr, &floatArrayCpp, vectorSizeInBytes); + //jniUtil->ReleaseFloatArrayElements(env, floatArrayJ, floatArrayCpp, JNI_ABORT); + ptr += vectorSizeInBytes; + } + jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT); + + std::unique_ptr> index; + index.reset(similarity::MethodFactoryRegistry::Instance().CreateMethod(false, "hnsw", spaceTypeCpp, *(space), dataset)); + index->CreateIndex(similarity::AnyParams(indexParameters)); + index->SaveIndex(indexPathCpp); + + for (auto & it : dataset) { + delete it; + } + delete inputVectors; + } catch (...) { + for (auto & it : dataset) { + delete it; + } + delete inputVectors; + + jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT); + throw; + } +} + + jlong knn_jni::nmslib_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ, jobject parametersJ) { diff --git a/jni/src/org_opensearch_knn_jni_NmslibService.cpp b/jni/src/org_opensearch_knn_jni_NmslibService.cpp index 11dd885b1..5eabdeb9c 100644 --- a/jni/src/org_opensearch_knn_jni_NmslibService.cpp +++ b/jni/src/org_opensearch_knn_jni_NmslibService.cpp @@ -48,6 +48,19 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex(JNI } } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndexWithMemoryAddress(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorAddressJ, jint dim, jstring indexPathJ, + jobject parametersJ) +{ + try { + knn_jni::nmslib_wrapper::CreateIndex_With_Memory_Address(&jniUtil, env, idsJ, vectorAddressJ, dim ,indexPathJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + + + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_NmslibService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ, jobject parametersJ) { diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 901328766..dbd18c8c6 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -109,15 +109,15 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, // Get values to be indexed BinaryDocValues values = valuesProducer.getBinary(field); KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values); - if (pair.vectors.length == 0 || pair.docs.length == 0) { + if (pair.vectorsAddress == 0 && (pair.vectors.length == 0 || pair.docs.length == 0)) { logger.info("Skipping engine index creation as there are no vectors or docs in the documents"); return; } - long arraySize = calculateArraySize(pair.vectors, pair.serializationMode); + //long arraySize = calculateArraySize(pair.vectors, pair.serializationMode); if (isMerge) { KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); + //KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); } // Increment counter for number of graph index requests KNNCounter.GRAPH_INDEX_REQUESTS.increment(); @@ -150,7 +150,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, } if (isMerge) { - recordMergeStats(pair.docs.length, arraySize); + //recordMergeStats(pair.docs.length, arraySize); } if (isRefresh) { @@ -223,7 +223,8 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa // Pass the path for the nms library to save the file AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine.getName()); + JNIService.createIndex(pair.docs, pair.vectors, pair.vectorsAddress, pair.dim, indexPath, parameters, + knnEngine.getName()); return null; }); } diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 02ab2d833..ccf2d217e 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -5,9 +5,12 @@ package org.opensearch.knn.index.codec.util; +import org.apache.commons.lang.ArrayUtils; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.jni.JNIService; +import org.apache.commons.lang.ArrayUtils; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -26,37 +29,59 @@ public class KNNCodecUtil { // Java rounds each array size up to multiples of 8 bytes public static final int JAVA_ROUNDING_NUMBER = 8; + private static final int MAX_SIZE_OF_VECTOR_ARRAY = 10000; + public static final class Pair { public Pair(int[] docs, float[][] vectors, SerializationMode serializationMode) { this.docs = docs; this.vectors = vectors; this.serializationMode = serializationMode; + this.vectorsAddress = 0; + this.dim = 0; } public int[] docs; public float[][] vectors; public SerializationMode serializationMode; + public long vectorsAddress; + public int dim; } public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException { ArrayList vectorList = new ArrayList<>(); ArrayList docIdList = new ArrayList<>(); SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS; + long vectorAddress = 0; + int dim = 0; for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { BytesRef bytesref = values.binaryValue(); try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) { serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream); final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); final float[] vector = vectorSerializer.byteToFloatArray(byteStream); + dim = vector.length; vectorList.add(vector); } + if(vectorList.size() == MAX_SIZE_OF_VECTOR_ARRAY) { + vectorAddress = JNIService.transferVectors(vectorAddress, vectorList.toArray(new float[][] {})); + vectorList = new ArrayList<>(); + } docIdList.add(doc); } - return new KNNCodecUtil.Pair( + + if(vectorList.size() > 0) { + vectorAddress = JNIService.transferVectors(vectorAddress, vectorList.toArray(new float[][] {})); + vectorList = new ArrayList<>(); + } + KNNCodecUtil.Pair pair = new KNNCodecUtil.Pair( docIdList.stream().mapToInt(Integer::intValue).toArray(), - vectorList.toArray(new float[][] {}), + new float[0][], serializationMode ); + pair.vectorsAddress = vectorAddress; + pair.dim = dim; + return pair; + } public static long calculateArraySize(float[][] vectors, SerializationMode serializationMode) { diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index f45fb0c73..75be03815 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -45,6 +45,21 @@ public static void createIndex(int[] ids, float[][] data, String indexPath, Map< throw new IllegalArgumentException("CreateIndex not supported for provided engine"); } + public static void createIndex(int[] ids, float[][] data, long vectorAddress, int dim, String indexPath, + Map parameters, String engineName) { + if (KNNEngine.NMSLIB.getName().equals(engineName)) { + NmslibService.createIndexWithMemoryAddress(ids, vectorAddress, dim, indexPath, parameters); + return; + } + + if (KNNEngine.FAISS.getName().equals(engineName)) { + FaissService.createIndex(ids, data, indexPath, parameters); + return; + } + + throw new IllegalArgumentException("CreateIndex not supported for provided engine"); + } + /** * Create an index for the native library with a provided template index * diff --git a/src/main/java/org/opensearch/knn/jni/NmslibService.java b/src/main/java/org/opensearch/knn/jni/NmslibService.java index 77896822a..682a14f76 100644 --- a/src/main/java/org/opensearch/knn/jni/NmslibService.java +++ b/src/main/java/org/opensearch/knn/jni/NmslibService.java @@ -48,6 +48,9 @@ class NmslibService { */ public static native void createIndex(int[] ids, float[][] data, String indexPath, Map parameters); + public static native void createIndexWithMemoryAddress(int[] ids, long address, int dim, String indexPath, + Map parameters); + /** * Load an index into memory *