Skip to content

Commit

Permalink
First non working commit for reducing memory
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Apr 11, 2024
1 parent 172dc84 commit c15724c
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
5 changes: 5 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ namespace knn_jni {
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ);

// Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ.
// The index is serialized to indexPathJ.
long long CreateIndexIteratively(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jlong indexAddress, jobject parametersJ);

// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
Expand Down
83 changes: 83 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,89 @@ bool isIndexIVFPQL2(faiss::Index * index);
// IndexIDMap which has member that will point to underlying index that stores the data
faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index);

long long CreateIndexIteratively(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jlong indexAddressJ, jobject parametersJ) {
if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (parametersJ == nullptr) {
throw std::runtime_error("Parameters cannot be null");
}

// parametersJ is a Java Map<String, Object>. ConvertJavaMapToCppMap converts it to a c++ map<string, jobject>
// so that it is easier to access.
auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);

// Get space type for this index
jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);

// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
faiss::IndexIDMap *idMap = nullptr;
if(indexAddressJ == 0) {
int dim = (int) dimJ;
// The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
if (numVectors == 0) {
throw std::runtime_error("Number of vectors cannot be 0");
}

int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

// Create faiss index
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));

std::unique_ptr<faiss::Index> indexWriter;
indexWriter.reset(faiss::index_factory(dim, indexDescriptionCpp.c_str(), metric));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if (parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env,
parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
}

// Add extra parameters that cant be configured with the index factory
if (parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) {
jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS];
auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ);
SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter.get());
jniUtil->DeleteLocalRef(env, subParametersJ);
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Check that the index does not need to be trained
if (!indexWriter->is_trained) {
throw std::runtime_error("Index is not trained");
}

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
idMap = new faiss::IndexIDMap(indexWriter.get());
} else {

}


idMap->add_with_ids(numVectors, inputVectors->data(), idVector.data());
return (long long)idMap;
}


void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ) {

Expand Down

0 comments on commit c15724c

Please sign in to comment.