diff --git a/CHANGELOG.md b/CHANGELOG.md index 0623e6f986..4217a4e1ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,3 +24,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Introduce KNNVectorValues interface to iterate on different types of Vector values during indexing and search [#1897](https://github.com/opensearch-project/k-NN/pull/1897) * Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824) * Refactor engine package structure [#1913](https://github.com/opensearch-project/k-NN/pull/1913) +* Refactor train index and create index from template APIs in JNI layer [#1918](https://github.com/opensearch-project/k-NN/pull/1918) diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index 76e8cfdb97..5d4ec66506 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -88,6 +88,10 @@ class IndexService { std::unordered_map parameters, std::vector templateIndexData); + virtual std::vector trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map& parameters); + + virtual void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); + virtual ~IndexService() = default; protected: std::unique_ptr faissMethods; @@ -101,7 +105,7 @@ class BinaryIndexService : public IndexService { public: //TODO Remove dependency on JNIUtilInterface and JNIEnv //TODO Reduce the number of parameters - BinaryIndexService(std::unique_ptr faissMethods); + explicit BinaryIndexService(std::unique_ptr faissMethods); /** * Create binary index @@ -118,7 +122,7 @@ class BinaryIndexService : public IndexService { * @param indexPath path to write index * @param parameters parameters to be applied to faiss index */ - virtual void createIndex( + void createIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::MetricType metric, @@ -145,7 +149,7 @@ class BinaryIndexService : public IndexService { * @param parameters parameters to be applied to faiss index * @param templateIndexData vector containing the template index data */ - virtual void createIndexFromTemplate( + void createIndexFromTemplate( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, int dim, @@ -154,9 +158,14 @@ class BinaryIndexService : public IndexService { std::vector ids, std::string indexPath, std::unordered_map parameters, - std::vector templateIndexData); + std::vector templateIndexData) override; + + void InternalTrainIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x); + + std::vector trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map& parameters) override; + - virtual ~BinaryIndexService() = default; + ~BinaryIndexService() override = default; }; } diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 44882210d0..e2ea0e9553 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -94,14 +94,7 @@ namespace knn_jni { // // Return the serialized representation jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, - jlong trainVectorsPointerJ); - - // Create an empty binary index defined by the values in the Java map, parametersJ. Train the index with - // the vector of floats located at trainVectorsPointerJ. - // - // Return the serialized representation - jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, - jlong trainVectorsPointerJ); + jlong trainVectorsPointerJ, IndexService* indexService); /* * Perform a range search with filter against the index located in memory at indexPointerJ. diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 9432db6033..c683a56354 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -138,6 +138,37 @@ void IndexService::createIndexFromTemplate( faissMethods->writeIndex(idMap.get(), indexPath.c_str()); } +void IndexService::InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { + if (auto * indexIvf = dynamic_cast(index)) { + if (indexIvf->quantizer_trains_alone == 2) { + InternalTrainIndex(indexIvf->quantizer, n, x); + } + indexIvf->make_direct_map(); + } + + if (!index->is_trained) { + index->train(n, x); + } +} + +std::vector IndexService::trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map& parameters) { + // Create faiss index + std::unique_ptr index(faissMethods->indexFactory(dimension, indexDescription.c_str(), metric)); + + // Train index if needed + if (!index->is_trained) { + InternalTrainIndex(index.get(), numVectors, trainingVectors); + } + + // Write index to a vector + faiss::VectorIOWriter vectorIoWriter; + faiss::write_index(index.get(), &vectorIoWriter); + + return std::vector(vectorIoWriter.data.begin(), vectorIoWriter.data.end()); +} + + + BinaryIndexService::BinaryIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {} void BinaryIndexService::createIndex( @@ -223,5 +254,35 @@ void BinaryIndexService::createIndexFromTemplate( faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); } +void BinaryIndexService::InternalTrainIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) { + if (auto * indexIvf = dynamic_cast(index)) { + if (!indexIvf->is_trained) { + indexIvf->train(n, reinterpret_cast(x)); + } + } + if (!index->is_trained) { + index->train(n, reinterpret_cast(x)); + } +} + +std::vector BinaryIndexService::trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map& parameters) { + // Convert Java parameters to C++ parameters + std::unique_ptr indexWriter; + indexWriter.reset(faiss::index_binary_factory(dimension, indexDescription.c_str())); + + // Train the index if it is not already trained + if (!indexWriter->is_trained) { + InternalTrainIndex(indexWriter.get(), numVectors, trainingVectors); + } + + // Serialize the trained index to a byte array + faiss::VectorIOWriter vectorIoWriter; + faiss::write_index_binary(indexWriter.get(), &vectorIoWriter); + + // Convert the serialized data to a std::vector + std::vector trainedIndexData(vectorIoWriter.data.begin(), vectorIoWriter.data.end()); + + return trainedIndexData; +} } // namespace faiss_wrapper } // namespace knn_jni diff --git a/jni/src/faiss_methods.cpp b/jni/src/faiss_methods.cpp index abc70d4605..0c0924e74a 100644 --- a/jni/src/faiss_methods.cpp +++ b/jni/src/faiss_methods.cpp @@ -32,12 +32,15 @@ faiss::IndexIDMapTemplate* FaissMethods::indexBinaryIdMap(fa void FaissMethods::writeIndex(const faiss::Index* idx, const char* fname) { faiss::write_index(idx, fname); } + void FaissMethods::writeIndexBinary(const faiss::IndexBinary* idx, const char* fname) { faiss::write_index_binary(idx, fname); } + faiss::Index* FaissMethods::readIndex(faiss::IOReader* f, int io_flags) { return faiss::read_index(f, io_flags); } + faiss::IndexBinary* FaissMethods::readIndexBinary(faiss::IOReader* f, int io_flags) { return faiss::read_index_binary(f, io_flags); } diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 3878e4f6c4..7259643e78 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -556,7 +556,7 @@ void knn_jni::faiss_wrapper::InitLibrary() { } jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, - jint dimensionJ, jlong trainVectorsPointerJ) { + jint dimensionJ, jlong trainVectorsPointerJ, IndexService* indexService) { // First, we need to build the index if (parametersJ == nullptr) { throw std::runtime_error("Parameters cannot be null"); @@ -572,16 +572,6 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); - std::unique_ptr indexWriter; - indexWriter.reset(faiss::index_factory((int) dimensionJ, indexDescriptionCpp.c_str(), metric)); - - // Related to https://github.com/facebookresearch/faiss/issues/1621. HNSWPQ defaults to l2 even when metric is - // passed in. This updates it to the correct metric. - indexWriter->metric_type = metric; - if (auto * indexHnswPq = dynamic_cast(indexWriter.get())) { - indexHnswPq->storage->metric_type = metric; - } - // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); @@ -589,86 +579,19 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti } // Add extra parameters that can't be configured with the index factory + std::unordered_map subParametersCpp; if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { - jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS]; - auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ); - SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter.get()); - jniUtil->DeleteLocalRef(env, subParametersJ); - } - - // Train index if needed - auto *trainingVectorsPointerCpp = reinterpret_cast*>(trainVectorsPointerJ); - int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ; - if(!indexWriter->is_trained) { - InternalTrainIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data()); - } - jniUtil->DeleteLocalRef(env, parametersJ); - - // Now that indexWriter is trained, we just load the bytes into an array and return - faiss::VectorIOWriter vectorIoWriter; - faiss::write_index(indexWriter.get(), &vectorIoWriter); - - // Wrap in smart pointer - std::unique_ptr jbytesBuffer; - jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]); - int c = 0; - for (auto b : vectorIoWriter.data) { - jbytesBuffer[c++] = (jbyte) b; - } - - jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size()); - jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get()); - return ret; -} - -jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, - jint dimensionJ, jlong trainVectorsPointerJ) { - // First, we need to build the index - if (parametersJ == nullptr) { - throw std::runtime_error("Parameters cannot be null"); - } - - auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); - - jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); - std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); - faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); - - // Create faiss index - jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); - std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); - - std::unique_ptr indexWriter; - indexWriter.reset(faiss::index_binary_factory((int) dimensionJ, indexDescriptionCpp.c_str())); - - // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread - if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { - auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); - omp_set_num_threads(threadCount); + subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]); } - // Train index if needed + // Train index using IndexService auto *trainingVectorsPointerCpp = reinterpret_cast*>(trainVectorsPointerJ); - int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ; - if(!indexWriter->is_trained) { - InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data()); - } - jniUtil->DeleteLocalRef(env, parametersJ); + int numVectors = trainingVectorsPointerCpp->size() / (int) dimensionJ; + std::vector trainedIndexData = indexService->trainIndex(jniUtil, env, metric, indexDescriptionCpp, dimensionJ, numVectors, trainingVectorsPointerCpp->data(), subParametersCpp); // Now that indexWriter is trained, we just load the bytes into an array and return - faiss::VectorIOWriter vectorIoWriter; - faiss::write_index_binary(indexWriter.get(), &vectorIoWriter); - - // Wrap in smart pointer - std::unique_ptr jbytesBuffer; - jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]); - int c = 0; - for (auto b : vectorIoWriter.data) { - jbytesBuffer[c++] = (jbyte) b; - } - - jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size()); - jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get()); + jbyteArray ret = jniUtil->NewByteArray(env, trainedIndexData.size()); + jniUtil->SetByteArrayRegion(env, ret, 0, trainedIndexData.size(), reinterpret_cast(trainedIndexData.data())); return ret; } @@ -717,28 +640,6 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, } } -void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { - if (auto * indexIvf = dynamic_cast(index)) { - if (indexIvf->quantizer_trains_alone == 2) { - InternalTrainIndex(indexIvf->quantizer, n, x); - } - indexIvf->make_direct_map(); - } - - if (!index->is_trained) { - index->train(n, x); - } -} - -void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) { - if (auto * indexIvf = dynamic_cast(index)) { - indexIvf->make_direct_map(); - } - if (!index->is_trained) { - index->train(n, reinterpret_cast(x)); - } -} - std::unique_ptr buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector* bitmap) { int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr); int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index d0df5adac9..36b91b74e8 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -86,7 +86,7 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT try { std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); - CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &indexService); + knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &indexService); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -232,7 +232,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex jlong trainVectorsPointerJ) { try { - return knn_jni::faiss_wrapper::TrainIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ); + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::TrainIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ, &indexService); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -245,7 +247,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinar jlong trainVectorsPointerJ) { try { - return knn_jni::faiss_wrapper::TrainBinaryIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ); + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService indexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::TrainIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ, &indexService); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); }