Skip to content

Commit

Permalink
refactor jni train index
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Aug 1, 2024
1 parent 8123cfc commit 293f46d
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 123 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Introduce KNNVectorValues interface to iterate on different types of Vector values during indexing and search [#1897](https://github.com/opensearch-project/k-NN/pull/1897)
* Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824)
* Refactor engine package structure [#1913](https://github.com/opensearch-project/k-NN/pull/1913)
* Refactor train index and create index from template APIs in JNI layer [#1918](https://github.com/opensearch-project/k-NN/pull/1918)
19 changes: 14 additions & 5 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ class IndexService {
std::unordered_map<std::string, jobject> parameters,
std::vector<uint8_t> templateIndexData);

virtual std::vector<uint8_t> trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map<std::string, jobject>& parameters);

virtual void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x);

virtual ~IndexService() = default;
protected:
std::unique_ptr<FaissMethods> faissMethods;
Expand All @@ -101,7 +105,7 @@ class BinaryIndexService : public IndexService {
public:
//TODO Remove dependency on JNIUtilInterface and JNIEnv
//TODO Reduce the number of parameters
BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods);
explicit BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods);

/**
* Create binary index
Expand All @@ -118,7 +122,7 @@ class BinaryIndexService : public IndexService {
* @param indexPath path to write index
* @param parameters parameters to be applied to faiss index
*/
virtual void createIndex(
void createIndex(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
faiss::MetricType metric,
Expand All @@ -145,7 +149,7 @@ class BinaryIndexService : public IndexService {
* @param parameters parameters to be applied to faiss index
* @param templateIndexData vector containing the template index data
*/
virtual void createIndexFromTemplate(
void createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
Expand All @@ -154,9 +158,14 @@ class BinaryIndexService : public IndexService {
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters,
std::vector<uint8_t> templateIndexData);
std::vector<uint8_t> templateIndexData) override;

void InternalTrainIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x);

std::vector<uint8_t> trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map<std::string, jobject>& parameters) override;


virtual ~BinaryIndexService() = default;
~BinaryIndexService() override = default;
};

}
Expand Down
9 changes: 1 addition & 8 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,7 @@ namespace knn_jni {
//
// Return the serialized representation
jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

// Create an empty binary index defined by the values in the Java map, parametersJ. Train the index with
// the vector of floats located at trainVectorsPointerJ.
//
// Return the serialized representation
jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);
jlong trainVectorsPointerJ, IndexService* indexService);

/*
* Perform a range search with filter against the index located in memory at indexPointerJ.
Expand Down
61 changes: 61 additions & 0 deletions jni/src/faiss_index_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,37 @@ void IndexService::createIndexFromTemplate(
faissMethods->writeIndex(idMap.get(), indexPath.c_str());
}

void IndexService::InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
if (auto * indexIvf = dynamic_cast<faiss::IndexIVF*>(index)) {
if (indexIvf->quantizer_trains_alone == 2) {
InternalTrainIndex(indexIvf->quantizer, n, x);
}
indexIvf->make_direct_map();
}

if (!index->is_trained) {
index->train(n, x);
}
}

std::vector<uint8_t> IndexService::trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map<std::string, jobject>& parameters) {
// Create faiss index
std::unique_ptr<faiss::Index> index(faissMethods->indexFactory(dimension, indexDescription.c_str(), metric));

// Train index if needed
if (!index->is_trained) {
InternalTrainIndex(index.get(), numVectors, trainingVectors);
}

// Write index to a vector
faiss::VectorIOWriter vectorIoWriter;
faiss::write_index(index.get(), &vectorIoWriter);

return std::vector<uint8_t>(vectorIoWriter.data.begin(), vectorIoWriter.data.end());
}



BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods) : IndexService(std::move(faissMethods)) {}

void BinaryIndexService::createIndex(
Expand Down Expand Up @@ -223,5 +254,35 @@ void BinaryIndexService::createIndexFromTemplate(
faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str());
}

void BinaryIndexService::InternalTrainIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) {
if (auto * indexIvf = dynamic_cast<faiss::IndexBinaryIVF*>(index)) {
if (!indexIvf->is_trained) {
indexIvf->train(n, reinterpret_cast<const uint8_t*>(x));
}
}
if (!index->is_trained) {
index->train(n, reinterpret_cast<const uint8_t*>(x));
}
}

std::vector<uint8_t> BinaryIndexService::trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map<std::string, jobject>& parameters) {
// Convert Java parameters to C++ parameters
std::unique_ptr<faiss::IndexBinary> indexWriter;
indexWriter.reset(faiss::index_binary_factory(dimension, indexDescription.c_str()));

// Train the index if it is not already trained
if (!indexWriter->is_trained) {
InternalTrainIndex(indexWriter.get(), numVectors, trainingVectors);
}

// Serialize the trained index to a byte array
faiss::VectorIOWriter vectorIoWriter;
faiss::write_index_binary(indexWriter.get(), &vectorIoWriter);

// Convert the serialized data to a std::vector<uint8_t>
std::vector<uint8_t> trainedIndexData(vectorIoWriter.data.begin(), vectorIoWriter.data.end());

return trainedIndexData;
}
} // namespace faiss_wrapper
} // namespace knn_jni
3 changes: 3 additions & 0 deletions jni/src/faiss_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ faiss::IndexIDMapTemplate<faiss::IndexBinary>* FaissMethods::indexBinaryIdMap(fa
void FaissMethods::writeIndex(const faiss::Index* idx, const char* fname) {
faiss::write_index(idx, fname);
}

void FaissMethods::writeIndexBinary(const faiss::IndexBinary* idx, const char* fname) {
faiss::write_index_binary(idx, fname);
}

faiss::Index* FaissMethods::readIndex(faiss::IOReader* f, int io_flags) {
return faiss::read_index(f, io_flags);
}

faiss::IndexBinary* FaissMethods::readIndexBinary(faiss::IOReader* f, int io_flags) {
return faiss::read_index_binary(f, io_flags);
}
Expand Down
115 changes: 8 additions & 107 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ void knn_jni::faiss_wrapper::InitLibrary() {
}

jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ,
jint dimensionJ, jlong trainVectorsPointerJ) {
jint dimensionJ, jlong trainVectorsPointerJ, IndexService* indexService) {
// First, we need to build the index
if (parametersJ == nullptr) {
throw std::runtime_error("Parameters cannot be null");
Expand All @@ -572,103 +572,26 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));

std::unique_ptr<faiss::Index> indexWriter;
indexWriter.reset(faiss::index_factory((int) dimensionJ, indexDescriptionCpp.c_str(), metric));

// Related to https://github.com/facebookresearch/faiss/issues/1621. HNSWPQ defaults to l2 even when metric is
// passed in. This updates it to the correct metric.
indexWriter->metric_type = metric;
if (auto * indexHnswPq = dynamic_cast<faiss::IndexHNSWPQ*>(indexWriter.get())) {
indexHnswPq->storage->metric_type = metric;
}

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
}

// Add extra parameters that can't be configured with the index factory
std::unordered_map<std::string, jobject> subParametersCpp;
if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) {
jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS];
auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ);
SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter.get());
jniUtil->DeleteLocalRef(env, subParametersJ);
}

// Train index if needed
auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<float>*>(trainVectorsPointerJ);
int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ;
if(!indexWriter->is_trained) {
InternalTrainIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data());
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Now that indexWriter is trained, we just load the bytes into an array and return
faiss::VectorIOWriter vectorIoWriter;
faiss::write_index(indexWriter.get(), &vectorIoWriter);

// Wrap in smart pointer
std::unique_ptr<jbyte[]> jbytesBuffer;
jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]);
int c = 0;
for (auto b : vectorIoWriter.data) {
jbytesBuffer[c++] = (jbyte) b;
}

jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size());
jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get());
return ret;
}

jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ,
jint dimensionJ, jlong trainVectorsPointerJ) {
// First, we need to build the index
if (parametersJ == nullptr) {
throw std::runtime_error("Parameters cannot be null");
}

auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);

jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);

// Create faiss index
jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));

std::unique_ptr<faiss::IndexBinary> indexWriter;
indexWriter.reset(faiss::index_binary_factory((int) dimensionJ, indexDescriptionCpp.c_str()));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]);
}

// Train index if needed
// Train index using IndexService
auto *trainingVectorsPointerCpp = reinterpret_cast<std::vector<float>*>(trainVectorsPointerJ);
int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ;
if(!indexWriter->is_trained) {
InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data());
}
jniUtil->DeleteLocalRef(env, parametersJ);
int numVectors = trainingVectorsPointerCpp->size() / (int) dimensionJ;
std::vector<uint8_t> trainedIndexData = indexService->trainIndex(jniUtil, env, metric, indexDescriptionCpp, dimensionJ, numVectors, trainingVectorsPointerCpp->data(), subParametersCpp);

// Now that indexWriter is trained, we just load the bytes into an array and return
faiss::VectorIOWriter vectorIoWriter;
faiss::write_index_binary(indexWriter.get(), &vectorIoWriter);

// Wrap in smart pointer
std::unique_ptr<jbyte[]> jbytesBuffer;
jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]);
int c = 0;
for (auto b : vectorIoWriter.data) {
jbytesBuffer[c++] = (jbyte) b;
}

jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size());
jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get());
jbyteArray ret = jniUtil->NewByteArray(env, trainedIndexData.size());
jniUtil->SetByteArrayRegion(env, ret, 0, trainedIndexData.size(), reinterpret_cast<jbyte*>(trainedIndexData.data()));
return ret;
}

Expand Down Expand Up @@ -717,28 +640,6 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env,
}
}

void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) {
if (auto * indexIvf = dynamic_cast<faiss::IndexIVF*>(index)) {
if (indexIvf->quantizer_trains_alone == 2) {
InternalTrainIndex(indexIvf->quantizer, n, x);
}
indexIvf->make_direct_map();
}

if (!index->is_trained) {
index->train(n, x);
}
}

void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) {
if (auto * indexIvf = dynamic_cast<faiss::IndexBinaryIVF*>(index)) {
indexIvf->make_direct_map();
}
if (!index->is_trained) {
index->train(n, reinterpret_cast<const uint8_t*>(x));
}
}

std::unique_ptr<faiss::IDGrouperBitmap> buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector<uint64_t>* bitmap) {
int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr);
int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ);
Expand Down
10 changes: 7 additions & 3 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
try {
std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods));
CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &indexService);
knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &indexService);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down Expand Up @@ -232,7 +232,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex
jlong trainVectorsPointerJ)
{
try {
return knn_jni::faiss_wrapper::TrainIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ);
std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods));
return knn_jni::faiss_wrapper::TrainIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ, &indexService);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand All @@ -245,7 +247,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinar
jlong trainVectorsPointerJ)
{
try {
return knn_jni::faiss_wrapper::TrainBinaryIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ);
std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
knn_jni::faiss_wrapper::BinaryIndexService indexService(std::move(faissMethods));
return knn_jni::faiss_wrapper::TrainIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ, &indexService);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down

0 comments on commit 293f46d

Please sign in to comment.