From b98d2bb25d11b387d3159070eb7e5e40b991fcfc Mon Sep 17 00:00:00 2001 From: Vivek Narang <123010842+narangvivek10@users.noreply.github.com> Date: Sat, 21 Dec 2024 11:32:10 -0500 Subject: [PATCH] Bruteforce API implementation (#8) * Bruteforce API implementation Co-authored-by: Vivek Narang --- java/build.sh | 9 +- java/cuvs-java/pom.xml | 3 +- .../java/com/nvidia/cuvs/BruteForceIndex.java | 255 ++++++++++++++++++ .../nvidia/cuvs/BruteForceIndexParams.java | 71 +++++ .../java/com/nvidia/cuvs/BruteForceQuery.java | 131 +++++++++ .../nvidia/cuvs/BruteForceSearchResults.java | 93 +++++++ .../main/java/com/nvidia/cuvs/CagraIndex.java | 80 +++--- .../com/nvidia/cuvs/CagraIndexParams.java | 2 +- .../main/java/com/nvidia/cuvs/CagraQuery.java | 4 +- .../java/com/nvidia/cuvs/CuVSResources.java | 24 +- .../java/com/nvidia/cuvs/common/Util.java | 29 +- .../nvidia/cuvs/BruteForceAndSearchTest.java | 99 +++++++ java/internal/src/cuvs_java.c | 253 ++++++++++++++--- 13 files changed, 967 insertions(+), 86 deletions(-) create mode 100644 java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceIndex.java create mode 100644 java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceIndexParams.java create mode 100644 java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceQuery.java create mode 100644 java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceSearchResults.java create mode 100644 java/cuvs-java/src/test/java/com/nvidia/cuvs/BruteForceAndSearchTest.java diff --git a/java/build.sh b/java/build.sh index 3f088172b..3be9f35f0 100755 --- a/java/build.sh +++ b/java/build.sh @@ -1,7 +1,12 @@ export CMAKE_PREFIX_PATH=`pwd`/../cpp/build + +VERSION="25.02" +GROUP_ID="com.nvidia.cuvs" +SO_FILE_PATH="./internal" + cd internal && cmake . && cmake --build . \ && cd .. \ - && mvn install:install-file -DgroupId=com.nvidia.cuvs -DartifactId=cuvs-java-internal -Dversion=25.02 -Dpackaging=so -Dfile=./internal/libcuvs_java.so \ + && mvn install:install-file -DgroupId=$GROUP_ID -DartifactId=cuvs-java-internal -Dversion=$VERSION -Dpackaging=so -Dfile=$SO_FILE_PATH/libcuvs_java.so \ && cd cuvs-java \ && mvn package \ - && mvn install:install-file -Dfile=./target/cuvs-java-25.02.1-jar-with-dependencies.jar -DgroupId=com.nvidia.cuvs -DartifactId=cuvs-java -Dversion=54.02.1 -Dpackaging=jar + && mvn install:install-file -Dfile=./target/cuvs-java-$VERSION-jar-with-dependencies.jar -DgroupId=$GROUP_ID -DartifactId=cuvs-java -Dversion=$VERSION -Dpackaging=jar diff --git a/java/cuvs-java/pom.xml b/java/cuvs-java/pom.xml index f7f69a21a..9ac634dfb 100644 --- a/java/cuvs-java/pom.xml +++ b/java/cuvs-java/pom.xml @@ -21,7 +21,7 @@ 4.0.0 com.nvidia.cuvs cuvs-java - 25.02.1 + 25.02 cuvs-java jar @@ -114,7 +114,6 @@ - org.apache.maven.plugins maven-assembly-plugin diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceIndex.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceIndex.java new file mode 100644 index 000000000..443c8df65 --- /dev/null +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceIndex.java @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.cuvs; + +import java.io.IOException; +import java.lang.foreign.FunctionDescriptor; +import java.lang.foreign.MemoryLayout; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SequenceLayout; +import java.lang.foreign.ValueLayout; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; + +import com.nvidia.cuvs.common.Util; + +/** + * The BRUTEFORCE method is running the KNN algorithm. It performs an extensive + * search, and in contrast to ANN methods produces an exact result. + * + * {@link BruteForceIndex} encapsulates a BRUTEFORCE index, along with methods + * to interact with it. + * + * @since 25.02 + */ +public class BruteForceIndex { + + private final float[][] dataset; + private final long[] prefilterData; + private final CuVSResources resources; + private MethodHandle indexMethodHandle; + private MethodHandle searchMethodHandle; + private MethodHandle destroyIndexMethodHandle; + private IndexReference bruteForceIndexReference; + private BruteForceIndexParams bruteForceIndexParams; + private MemoryLayout longMemoryLayout; + private MemoryLayout intMemoryLayout; + private MemoryLayout floatMemoryLayout; + + /** + * Constructor for building the index using specified dataset + * + * @param dataset the dataset used for creating the BRUTEFORCE + * index + * @param resources an instance of {@link CuVSResources} + * @param bruteForceIndexParams an instance of {@link BruteForceIndexParams} + * holding the index parameters + * @param prefilterData the prefilter data to use while searching the + * BRUTEFORCE index + */ + private BruteForceIndex(float[][] dataset, CuVSResources resources, BruteForceIndexParams bruteForceIndexParams, + long[] prefilterData) throws Throwable { + this.dataset = dataset; + this.prefilterData = prefilterData; + this.resources = resources; + this.bruteForceIndexParams = bruteForceIndexParams; + + longMemoryLayout = resources.linker.canonicalLayouts().get("long"); + intMemoryLayout = resources.linker.canonicalLayouts().get("int"); + floatMemoryLayout = resources.linker.canonicalLayouts().get("float"); + + initializeMethodHandles(); + this.bruteForceIndexReference = build(); + } + + /** + * Initializes the {@link MethodHandles} for invoking native methods. + * + * @throws IOException @{@link IOException} is unable to load the native library + */ + private void initializeMethodHandles() throws IOException { + indexMethodHandle = resources.linker.downcallHandle( + resources.getSymbolLookup().find("build_brute_force_index").get(), + FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS, longMemoryLayout, longMemoryLayout, + ValueLayout.ADDRESS, ValueLayout.ADDRESS, intMemoryLayout)); + + searchMethodHandle = resources.linker.downcallHandle( + resources.getSymbolLookup().find("search_brute_force_index").get(), + FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, intMemoryLayout, longMemoryLayout, + intMemoryLayout, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, + ValueLayout.ADDRESS, longMemoryLayout)); + + destroyIndexMethodHandle = resources.linker.downcallHandle( + resources.getSymbolLookup().find("destroy_brute_force_index").get(), + FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS)); + } + + /** + * Invokes the native destroy_brute_force_index function to de-allocate + * BRUTEFORCE index + */ + public void destroyIndex() throws Throwable { + MemoryLayout returnValueMemoryLayout = intMemoryLayout; + MemorySegment returnValueMemorySegment = resources.arena.allocate(returnValueMemoryLayout); + destroyIndexMethodHandle.invokeExact(bruteForceIndexReference.getMemorySegment(), returnValueMemorySegment); + } + + /** + * Invokes the native build_brute_force_index function via the Panama API to + * build the {@link BruteForceIndex} + * + * @return an instance of {@link IndexReference} that holds the pointer to the + * index + */ + private IndexReference build() throws Throwable { + long rows = dataset.length; + long cols = rows > 0 ? dataset[0].length : 0; + + MemoryLayout returnValueMemoryLayout = intMemoryLayout; + MemorySegment returnValueMemorySegment = resources.arena.allocate(returnValueMemoryLayout); + + IndexReference indexReference = new IndexReference((MemorySegment) indexMethodHandle.invokeExact( + Util.buildMemorySegment(resources.linker, resources.arena, dataset), rows, cols, resources.getMemorySegment(), + returnValueMemorySegment, bruteForceIndexParams.getNumWriterThreads())); + + return indexReference; + } + + /** + * Invokes the native search_brute_force_index via the Panama API for searching + * a BRUTEFORCE index. + * + * @param cuvsQuery an instance of {@link BruteForceQuery} holding the query + * vectors and other parameters + * @return an instance of {@link BruteForceSearchResults} containing the results + */ + public BruteForceSearchResults search(BruteForceQuery cuvsQuery) throws Throwable { + long numQueries = cuvsQuery.getQueryVectors().length; + long numBlocks = cuvsQuery.getTopK() * numQueries; + int vectorDimension = numQueries > 0 ? cuvsQuery.getQueryVectors()[0].length : 0; + long prefilterDataLength = prefilterData != null ? prefilterData.length : 0; + + SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, longMemoryLayout); + SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, floatMemoryLayout); + MemorySegment neighborsMemorySegment = resources.arena.allocate(neighborsSequenceLayout); + MemorySegment distancesMemorySegment = resources.arena.allocate(distancesSequenceLayout); + MemoryLayout returnValueMemoryLayout = intMemoryLayout; + MemorySegment returnValueMemorySegment = resources.arena.allocate(returnValueMemoryLayout); + MemorySegment prefilterDataMemorySegment = prefilterData != null + ? Util.buildMemorySegment(resources.linker, resources.arena, prefilterData) + : MemorySegment.NULL; + + searchMethodHandle.invokeExact(bruteForceIndexReference.getMemorySegment(), + Util.buildMemorySegment(resources.linker, resources.arena, cuvsQuery.getQueryVectors()), cuvsQuery.getTopK(), + numQueries, vectorDimension, resources.getMemorySegment(), neighborsMemorySegment, distancesMemorySegment, + returnValueMemorySegment, prefilterDataMemorySegment, prefilterDataLength); + + return new BruteForceSearchResults(neighborsSequenceLayout, distancesSequenceLayout, neighborsMemorySegment, + distancesMemorySegment, cuvsQuery.getTopK(), cuvsQuery.getMapping(), numQueries); + } + + /** + * Builder helps configure and create an instance of {@link BruteForceIndex}. + */ + public static class Builder { + + private float[][] dataset; + private long[] prefilterData; + private CuVSResources cuvsResources; + private BruteForceIndexParams bruteForceIndexParams; + + /** + * Constructs this Builder with an instance of {@link CuVSResources}. + * + * @param cuvsResources an instance of {@link CuVSResources} + */ + public Builder(CuVSResources cuvsResources) { + this.cuvsResources = cuvsResources; + } + + /** + * Registers an instance of configured {@link BruteForceIndexParams} with this + * Builder. + * + * @param bruteForceIndexParams An instance of BruteForceIndexParams + * @return An instance of this Builder + */ + public Builder withIndexParams(BruteForceIndexParams bruteForceIndexParams) { + this.bruteForceIndexParams = bruteForceIndexParams; + return this; + } + + /** + * Sets the dataset for building the {@link BruteForceIndex}. + * + * @param dataset a two-dimensional float array + * @return an instance of this Builder + */ + public Builder withDataset(float[][] dataset) { + this.dataset = dataset; + return this; + } + + /** + * Sets the prefilter data for building the {@link BruteForceIndex}. + * + * @param prefilterData a one-dimensional long array + * @return an instance of this Builder + */ + public Builder withPrefilterData(long[] prefilterData) { + this.prefilterData = prefilterData; + return this; + } + + /** + * Builds and returns an instance of {@link BruteForceIndex}. + * + * @return an instance of {@link BruteForceIndex} + */ + public BruteForceIndex build() throws Throwable { + return new BruteForceIndex(dataset, cuvsResources, bruteForceIndexParams, prefilterData); + } + } + + /** + * Holds the memory reference to a BRUTEFORCE index. + */ + protected static class IndexReference { + + private final MemorySegment memorySegment; + + /** + * Constructs BruteForceIndexReference with an instance of MemorySegment passed + * as a parameter. + * + * @param indexMemorySegment the MemorySegment instance to use for containing + * index reference + */ + protected IndexReference(MemorySegment indexMemorySegment) { + this.memorySegment = indexMemorySegment; + } + + /** + * Gets the instance of index MemorySegment. + * + * @return index MemorySegment + */ + protected MemorySegment getMemorySegment() { + return memorySegment; + } + } +} diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceIndexParams.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceIndexParams.java new file mode 100644 index 000000000..46d8534c5 --- /dev/null +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceIndexParams.java @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.cuvs; + +/** + * Supplemental parameters to build BRUTEFORCE index. + * + * @since 25.02 + */ +public class BruteForceIndexParams { + + private final int numWriterThreads; + + private BruteForceIndexParams(int writerThreads) { + this.numWriterThreads = writerThreads; + } + + @Override + public String toString() { + return "BruteForceIndexParams [numWriterThreads=" + numWriterThreads + "]"; + } + + /** + * Gets the number of threads used to build the index. + */ + public int getNumWriterThreads() { + return numWriterThreads; + } + + /** + * Builder configures and creates an instance of {@link BruteForceIndexParams}. + */ + public static class Builder { + + private int numWriterThreads = 2; + + /** + * Sets the number of writer threads to use for indexing. + * + * @param numWriterThreads number of writer threads to use + * @return an instance of Builder + */ + public Builder withNumWriterThreads(int numWriterThreads) { + this.numWriterThreads = numWriterThreads; + return this; + } + + /** + * Builds an instance of {@link BruteForceIndexParams}. + * + * @return an instance of {@link BruteForceIndexParams} + */ + public BruteForceIndexParams build() { + return new BruteForceIndexParams(numWriterThreads); + } + } +} \ No newline at end of file diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceQuery.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceQuery.java new file mode 100644 index 000000000..5a9da0515 --- /dev/null +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceQuery.java @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.cuvs; + +import java.util.Arrays; +import java.util.Map; + +/** + * BruteForceQuery holds the query vectors to be used while invoking search. + * + * @since 25.02 + */ +public class BruteForceQuery { + + private Map mapping; + private float[][] queryVectors; + private int topK; + + /** + * Constructs an instance of {@link BruteForceQuery} using queryVectors, + * mapping, and topK. + * + * @param queryVectors 2D float query vector array + * @param mapping an instance of ID mapping + * @param topK the top k results to return + */ + public BruteForceQuery(float[][] queryVectors, Map mapping, int topK) { + this.queryVectors = queryVectors; + this.mapping = mapping; + this.topK = topK; + } + + /** + * Gets the query vector 2D float array. + * + * @return 2D float array + */ + public float[][] getQueryVectors() { + return queryVectors; + } + + /** + * Gets the passed map instance. + * + * @return a map of ID mappings + */ + public Map getMapping() { + return mapping; + } + + /** + * Gets the topK value. + * + * @return an integer + */ + public int getTopK() { + return topK; + } + + @Override + public String toString() { + return "BruteForceQuery [mapping=" + mapping + ", queryVectors=" + Arrays.toString(queryVectors) + ", topK=" + topK + + "]"; + } + + /** + * Builder helps configure and create an instance of BruteForceQuery. + */ + public static class Builder { + + private float[][] queryVectors; + private Map mapping; + private int topK = 2; + + /** + * Registers the query vectors to be passed in the search call. + * + * @param queryVectors 2D float query vector array + * @return an instance of this Builder + */ + public Builder withQueryVectors(float[][] queryVectors) { + this.queryVectors = queryVectors; + return this; + } + + /** + * Sets the instance of mapping to be used for ID mapping. + * + * @param mapping the ID mapping instance + * @return an instance of this Builder + */ + public Builder withMapping(Map mapping) { + this.mapping = mapping; + return this; + } + + /** + * Registers the topK value. + * + * @param topK the topK value used to retrieve the topK results + * @return an instance of this Builder + */ + public Builder withTopK(int topK) { + this.topK = topK; + return this; + } + + /** + * Builds an instance of {@link BruteForceQuery} + * + * @return an instance of {@link BruteForceQuery} + */ + public BruteForceQuery build() { + return new BruteForceQuery(queryVectors, mapping, topK); + } + } +} diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceSearchResults.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceSearchResults.java new file mode 100644 index 000000000..d53bce751 --- /dev/null +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceSearchResults.java @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.cuvs; + +import java.lang.foreign.MemoryLayout.PathElement; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.SequenceLayout; +import java.lang.invoke.VarHandle; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +import com.nvidia.cuvs.common.SearchResults; + +/** + * SearchResult encapsulates the logic for reading and holding search results. + * + * @since 25.02 + */ +public class BruteForceSearchResults implements SearchResults { + + private final List> results; + private final Map mapping; // TODO: Is this performant in a user application? + private final SequenceLayout neighboursSequenceLayout; + private final SequenceLayout distancesSequenceLayout; + private final MemorySegment neighboursMemorySegment; + private final MemorySegment distancesMemorySegment; + private final int topK; + private final long numberOfQueries; + + protected BruteForceSearchResults(SequenceLayout neighboursSequenceLayout, SequenceLayout distancesSequenceLayout, + MemorySegment neighboursMemorySegment, MemorySegment distancesMemorySegment, int topK, + Map mapping, long numberOfQueries) { + this.topK = topK; + this.numberOfQueries = numberOfQueries; + this.neighboursSequenceLayout = neighboursSequenceLayout; + this.distancesSequenceLayout = distancesSequenceLayout; + this.neighboursMemorySegment = neighboursMemorySegment; + this.distancesMemorySegment = distancesMemorySegment; + this.mapping = mapping; + results = new LinkedList>(); + + readResultMemorySegments(); + } + + /** + * Reads neighbors and distances {@link MemorySegment} and loads the values + * internally + */ + private void readResultMemorySegments() { + VarHandle neighboursVarHandle = neighboursSequenceLayout.varHandle(PathElement.sequenceElement()); + VarHandle distancesVarHandle = distancesSequenceLayout.varHandle(PathElement.sequenceElement()); + + Map intermediateResultMap = new LinkedHashMap(); + int count = 0; + for (long i = 0; i < topK * numberOfQueries; i++) { + long id = (long) neighboursVarHandle.get(neighboursMemorySegment, 0L, i); + float dst = (float) distancesVarHandle.get(distancesMemorySegment, 0L, i); + intermediateResultMap.put(mapping != null ? mapping.get((int) id) : (int) id, dst); + count += 1; + if (count == topK) { + results.add(intermediateResultMap); + intermediateResultMap = new LinkedHashMap(); + count = 0; + } + } + } + + /** + * Gets a list results as a map of neighbor IDs to distances. + * + * @return a list of results for each query as a map of neighbor IDs to distance + */ + @Override + public List> getResults() { + return results; + } +} diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java index 67394b53b..ca9df5e13 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java @@ -58,9 +58,19 @@ public class CagraIndex { private CagraIndexParams cagraIndexParameters; private CagraCompressionParams cagraCompressionParams; private IndexReference cagraIndexReference; + private MemoryLayout longMemoryLayout; + private MemoryLayout intMemoryLayout; + private MemoryLayout floatMemoryLayout; - /* + /** * Constructor for building the index using specified dataset + * + * @param indexParameters an instance of {@link CagraIndexParams} holding + * the index parameters + * @param cagraCompressionParams an instance of {@link CagraCompressionParams} + * holding the compression parameters + * @param dataset the dataset for indexing + * @param resources an instance of {@link CuVSResources} */ private CagraIndex(CagraIndexParams indexParameters, CagraCompressionParams cagraCompressionParams, float[][] dataset, CuVSResources resources) throws Throwable { @@ -69,12 +79,19 @@ private CagraIndex(CagraIndexParams indexParameters, CagraCompressionParams cagr this.dataset = dataset; this.resources = resources; + longMemoryLayout = resources.linker.canonicalLayouts().get("long"); + intMemoryLayout = resources.linker.canonicalLayouts().get("int"); + floatMemoryLayout = resources.linker.canonicalLayouts().get("float"); + initializeMethodHandles(); this.cagraIndexReference = build(); } /** * Constructor for loading the index from an {@link InputStream} + * + * @param inputStream an instance of stream to read the index bytes from + * @param resources an instance of {@link CuVSResources} */ private CagraIndex(InputStream inputStream, CuVSResources resources) throws Throwable { this.cagraIndexParameters = null; @@ -82,6 +99,10 @@ private CagraIndex(InputStream inputStream, CuVSResources resources) throws Thro this.dataset = null; this.resources = resources; + longMemoryLayout = resources.linker.canonicalLayouts().get("long"); + intMemoryLayout = resources.linker.canonicalLayouts().get("int"); + floatMemoryLayout = resources.linker.canonicalLayouts().get("float"); + initializeMethodHandles(); this.cagraIndexReference = deserialize(inputStream); } @@ -92,40 +113,39 @@ private CagraIndex(InputStream inputStream, CuVSResources resources) throws Thro * @throws IOException @{@link IOException} is unable to load the native library */ private void initializeMethodHandles() throws IOException { - indexMethodHandle = resources.linker.downcallHandle( - resources.getLibcuvsNativeLibrary().find("build_cagra_index").get(), - FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS, resources.linker.canonicalLayouts().get("long"), - resources.linker.canonicalLayouts().get("long"), ValueLayout.ADDRESS, ValueLayout.ADDRESS, - ValueLayout.ADDRESS, ValueLayout.ADDRESS, resources.linker.canonicalLayouts().get("int"))); - - searchMethodHandle = resources.linker.downcallHandle( - resources.getLibcuvsNativeLibrary().find("search_cagra_index").get(), - FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, - resources.linker.canonicalLayouts().get("int"), resources.linker.canonicalLayouts().get("long"), - resources.linker.canonicalLayouts().get("int"), ValueLayout.ADDRESS, ValueLayout.ADDRESS, - ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS)); + indexMethodHandle = resources.linker.downcallHandle(resources.getSymbolLookup().find("build_cagra_index").get(), + FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS, longMemoryLayout, longMemoryLayout, + ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, intMemoryLayout)); + + searchMethodHandle = resources.linker.downcallHandle(resources.getSymbolLookup().find("search_cagra_index").get(), + FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, intMemoryLayout, longMemoryLayout, + intMemoryLayout, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, + ValueLayout.ADDRESS)); serializeMethodHandle = resources.linker.downcallHandle( - resources.getLibcuvsNativeLibrary().find("serialize_cagra_index").get(), + resources.getSymbolLookup().find("serialize_cagra_index").get(), FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS)); deserializeMethodHandle = resources.linker.downcallHandle( - resources.getLibcuvsNativeLibrary().find("deserialize_cagra_index").get(), + resources.getSymbolLookup().find("deserialize_cagra_index").get(), FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS)); destroyIndexMethodHandle = resources.linker.downcallHandle( - resources.getLibcuvsNativeLibrary().find("destroy_cagra_index").get(), + resources.getSymbolLookup().find("destroy_cagra_index").get(), FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS)); } + /** + * Invokes the native destroy_cagra_index to de-allocate the CAGRA index + */ public void destroyIndex() throws Throwable { - MemoryLayout returnValueMemoryLayout = resources.linker.canonicalLayouts().get("int"); + MemoryLayout returnValueMemoryLayout = intMemoryLayout; MemorySegment returnValueMemorySegment = resources.arena.allocate(returnValueMemoryLayout); destroyIndexMethodHandle.invokeExact(cagraIndexReference.getMemorySegment(), returnValueMemorySegment); } /** - * Invokes the native build_index function via the Panama API to build the + * Invokes the native build_cagra_index function via the Panama API to build the * {@link CagraIndex} * * @return an instance of {@link IndexReference} that holds the pointer to the @@ -135,8 +155,8 @@ private IndexReference build() throws Throwable { long rows = dataset.length; long cols = rows > 0 ? dataset[0].length : 0; - MemoryLayout layout = resources.linker.canonicalLayouts().get("int"); - MemorySegment segment = resources.arena.allocate(layout); + MemoryLayout returnValueMemoryLayout = intMemoryLayout; + MemorySegment returnValueMemorySegment = resources.arena.allocate(returnValueMemoryLayout); MemorySegment indexParamsMemorySegment = cagraIndexParameters != null ? cagraIndexParameters.getMemorySegment() : MemorySegment.NULL; @@ -149,14 +169,14 @@ private IndexReference build() throws Throwable { IndexReference indexReference = new IndexReference((MemorySegment) indexMethodHandle.invokeExact( Util.buildMemorySegment(resources.linker, resources.arena, dataset), rows, cols, resources.getMemorySegment(), - segment, indexParamsMemorySegment, compressionParamsMemorySegment, numWriterThreads)); + returnValueMemorySegment, indexParamsMemorySegment, compressionParamsMemorySegment, numWriterThreads)); return indexReference; } /** - * Invokes the native search_index via the Panama API for searching a CAGRA - * index. + * Invokes the native search_cagra_index via the Panama API for searching a + * CAGRA index. * * @param query an instance of {@link CagraQuery} holding the query vectors and * other parameters @@ -167,13 +187,11 @@ public CagraSearchResults search(CagraQuery query) throws Throwable { long numBlocks = query.getTopK() * numQueries; int vectorDimension = numQueries > 0 ? query.getQueryVectors()[0].length : 0; - SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, - resources.linker.canonicalLayouts().get("int")); - SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, - resources.linker.canonicalLayouts().get("float")); + SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, intMemoryLayout); + SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, floatMemoryLayout); MemorySegment neighborsMemorySegment = resources.arena.allocate(neighborsSequenceLayout); MemorySegment distancesMemorySegment = resources.arena.allocate(distancesSequenceLayout); - MemoryLayout returnValueMemoryLayout = resources.linker.canonicalLayouts().get("int"); + MemoryLayout returnValueMemoryLayout = intMemoryLayout; MemorySegment returnValueMemorySegment = resources.arena.allocate(returnValueMemoryLayout); searchMethodHandle.invokeExact(cagraIndexReference.getMemorySegment(), @@ -206,7 +224,7 @@ public void serialize(OutputStream outputStream) throws Throwable { * temporarily */ public void serialize(OutputStream outputStream, File tempFile) throws Throwable { - MemoryLayout returnValueMemoryLayout = resources.linker.canonicalLayouts().get("int"); + MemoryLayout returnValueMemoryLayout = intMemoryLayout; MemorySegment returnValueMemorySegment = resources.arena.allocate(returnValueMemoryLayout); serializeMethodHandle.invokeExact(resources.getMemorySegment(), cagraIndexReference.getMemorySegment(), returnValueMemorySegment, @@ -229,7 +247,7 @@ public void serialize(OutputStream outputStream, File tempFile) throws Throwable * @return an instance of {@link IndexReference}. */ private IndexReference deserialize(InputStream inputStream) throws Throwable { - MemoryLayout returnValueMemoryLayout = resources.linker.canonicalLayouts().get("int"); + MemoryLayout returnValueMemoryLayout = intMemoryLayout; MemorySegment returnValueMemorySegment = resources.arena.allocate(returnValueMemoryLayout); String tmpIndexFile = "/tmp/" + UUID.randomUUID().toString() + ".cag"; IndexReference indexReference = new IndexReference(resources); @@ -351,7 +369,7 @@ public CagraIndex build() throws Throwable { } /** - * Holds the memory reference to an index. + * Holds the memory reference to a CAGRA index. */ protected static class IndexReference { diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndexParams.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndexParams.java index 66eb235e2..e194575a9 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndexParams.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndexParams.java @@ -144,7 +144,7 @@ public static class Builder { private int intermediateGraphDegree = 128; private int graphDegree = 64; private int nnDescentNumIterations = 20; - private int numWriterThreads = 1; + private int numWriterThreads = 2; public Builder(CuVSResources resources) { this.resources = resources; diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraQuery.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraQuery.java index 77cbf33f3..7fb0fdcf1 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraQuery.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraQuery.java @@ -81,7 +81,7 @@ public Map getMapping() { /** * Gets the topK value. * - * @return an integer + * @return the topK value */ public int getTopK() { return topK; @@ -94,7 +94,7 @@ public String toString() { } /** - * Builder helps configure and create an instance of CuVSQuery. + * Builder helps configure and create an instance of CagraQuery. */ public static class Builder { diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java index 60f68981b..effe81ef2 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java @@ -37,12 +37,12 @@ public class CuVSResources implements AutoCloseable { public final Arena arena; public final Linker linker; - public final SymbolLookup libcuvsNativeLibrary; + public final SymbolLookup symbolLookup; protected File nativeLibrary; - private final MethodHandle createResourcesMethodHandle; private final MethodHandle destroyResourcesMethodHandle; private MemorySegment resourcesMemorySegment; + private MemoryLayout intMemoryLayout; /** * Constructor that allocates the resources needed for cuVS @@ -52,14 +52,17 @@ public class CuVSResources implements AutoCloseable { public CuVSResources() throws Throwable { linker = Linker.nativeLinker(); arena = Arena.ofShared(); + nativeLibrary = Util.loadLibraryFromJar("/libcuvs_java.so"); - libcuvsNativeLibrary = SymbolLookup.libraryLookup(nativeLibrary.getAbsolutePath(), arena); + symbolLookup = SymbolLookup.libraryLookup(nativeLibrary.getAbsolutePath(), arena); + intMemoryLayout = linker.canonicalLayouts().get("int"); - createResourcesMethodHandle = linker.downcallHandle(libcuvsNativeLibrary.find("create_resources").get(), + createResourcesMethodHandle = linker.downcallHandle(symbolLookup.find("create_resources").get(), FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS)); - destroyResourcesMethodHandle = linker.downcallHandle(libcuvsNativeLibrary.find("destroy_resources").get(), + destroyResourcesMethodHandle = linker.downcallHandle(symbolLookup.find("destroy_resources").get(), FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS)); + createResources(); } @@ -69,14 +72,14 @@ public CuVSResources() throws Throwable { * @throws Throwable exception thrown when native function is invoked */ public void createResources() throws Throwable { - MemoryLayout returnValueMemoryLayout = linker.canonicalLayouts().get("int"); + MemoryLayout returnValueMemoryLayout = intMemoryLayout; MemorySegment returnValueMemorySegment = arena.allocate(returnValueMemoryLayout); resourcesMemorySegment = (MemorySegment) createResourcesMethodHandle.invokeExact(returnValueMemorySegment); } @Override public void close() { - MemoryLayout returnValueMemoryLayout = linker.canonicalLayouts().get("int"); + MemoryLayout returnValueMemoryLayout = intMemoryLayout; MemorySegment returnValueMemorySegment = arena.allocate(returnValueMemoryLayout); try { destroyResourcesMethodHandle.invokeExact(resourcesMemorySegment, returnValueMemorySegment); @@ -96,9 +99,10 @@ protected MemorySegment getMemorySegment() { } /** - * Returns the loaded libcuvs_java.so as a {@link SymbolLookup} + * Returns the loaded libcuvs_java_cagra.so as a {@link SymbolLookup} */ - protected SymbolLookup getLibcuvsNativeLibrary() { - return libcuvsNativeLibrary; + protected SymbolLookup getSymbolLookup() { + return symbolLookup; } + } \ No newline at end of file diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java index 750e49d64..ed2f74c4d 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java @@ -30,6 +30,7 @@ import org.apache.commons.io.IOUtils; public class Util { + /** * A utility method for getting an instance of {@link MemorySegment} for a * {@link String}. @@ -38,8 +39,9 @@ public class Util { * @return an instance of {@link MemorySegment} */ public static MemorySegment buildMemorySegment(Linker linker, Arena arena, String str) { + MemoryLayout charMemoryLayout = linker.canonicalLayouts().get("char"); StringBuilder sb = new StringBuilder(str).append('\0'); - MemoryLayout stringMemoryLayout = MemoryLayout.sequenceLayout(sb.length(), linker.canonicalLayouts().get("char")); + MemoryLayout stringMemoryLayout = MemoryLayout.sequenceLayout(sb.length(), charMemoryLayout); MemorySegment stringMemorySegment = arena.allocate(stringMemoryLayout); for (int i = 0; i < sb.length(); i++) { @@ -49,6 +51,21 @@ public static MemorySegment buildMemorySegment(Linker linker, Arena arena, Strin return stringMemorySegment; } + /** + * A utility method for building a {@link MemorySegment} for a 1D long array. + * + * @param data The 1D long array for which the {@link MemorySegment} is needed + * @return an instance of {@link MemorySegment} + */ + public static MemorySegment buildMemorySegment(Linker linker, Arena arena, long[] data) { + int cells = data.length; + MemoryLayout longMemoryLayout = linker.canonicalLayouts().get("long"); + MemoryLayout dataMemoryLayout = MemoryLayout.sequenceLayout(cells, longMemoryLayout); + MemorySegment dataMemorySegment = arena.allocate(dataMemoryLayout); + MemorySegment.copy(data, 0, dataMemorySegment, (ValueLayout) longMemoryLayout, 0, cells); + return dataMemorySegment; + } + /** * A utility method for building a {@link MemorySegment} for a 2D float array. * @@ -58,14 +75,14 @@ public static MemorySegment buildMemorySegment(Linker linker, Arena arena, Strin public static MemorySegment buildMemorySegment(Linker linker, Arena arena, float[][] data) { long rows = data.length; long cols = rows > 0 ? data[0].length : 0; - - MemoryLayout dataMemoryLayout = MemoryLayout.sequenceLayout(rows * cols, linker.canonicalLayouts().get("float")); + MemoryLayout floatMemoryLayout = linker.canonicalLayouts().get("float"); + MemoryLayout dataMemoryLayout = MemoryLayout.sequenceLayout(rows * cols, floatMemoryLayout); MemorySegment dataMemorySegment = arena.allocate(dataMemoryLayout); - long floatByteSize = linker.canonicalLayouts().get("float").byteSize(); + long floatByteSize = floatMemoryLayout.byteSize(); for (int r = 0; r < rows; r++) { - MemorySegment.copy(data[r], 0, dataMemorySegment, (ValueLayout) linker.canonicalLayouts().get("float"), - (r * cols * floatByteSize), (int) cols); + MemorySegment.copy(data[r], 0, dataMemorySegment, (ValueLayout) floatMemoryLayout, (r * cols * floatByteSize), + (int) cols); } return dataMemorySegment; diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/BruteForceAndSearchTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/BruteForceAndSearchTest.java new file mode 100644 index 000000000..00180c633 --- /dev/null +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/BruteForceAndSearchTest.java @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.cuvs; + +import static org.junit.Assert.assertEquals; + +import java.lang.invoke.MethodHandles; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.nvidia.cuvs.common.SearchResults; + +public class BruteForceAndSearchTest { + + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + /** + * A basic test that checks the whole flow - from indexing to search. + * + * @throws Throwable + */ + @Test + public void testIndexingAndSearchingFlow() throws Throwable { + + // Sample data and query + float[][] dataset = { + { 0.74021935f, 0.9209938f }, + { 0.03902049f, 0.9689629f }, + { 0.92514056f, 0.4463501f }, + { 0.6673192f, 0.10993068f } + }; + Map map = Map.of(0, 0, 1, 1, 2, 2, 3, 3); + float[][] queries = { + { 0.48216683f, 0.0428398f }, + { 0.5084142f, 0.6545497f }, + { 0.51260436f, 0.2643005f }, + { 0.05198065f, 0.5789965f } + }; + + // Expected search results + List> expectedResults = Arrays.asList( + Map.of(3, 0.59198487f, 1, 0.6283694f, 2, 0.77246666f), + Map.of(1, 0.2534914f, 3, 0.33350062f, 2, 0.8748074f), + Map.of(1, 0.4058035f, 3, 0.43066847f, 2, 0.72249544f), + Map.of(3, 0.11946076f, 1, 0.46753132f, 2, 1.0337032f) + ); + + for (int j = 0; j < 10; j++) { + + try (CuVSResources resources = new CuVSResources()) { + + BruteForceIndexParams indexParams = new BruteForceIndexParams.Builder() + .withNumWriterThreads(32) + .build(); + + // Create the index with the dataset + BruteForceIndex index = new BruteForceIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + // Create a query object with the query vectors + BruteForceQuery cuvsQuery = new BruteForceQuery.Builder() + .withTopK(3) + .withQueryVectors(queries) + .withMapping(map) + .build(); + + // Perform the search + SearchResults results = index.search(cuvsQuery); + + // Check results + log.info(results.getResults().toString()); + assertEquals(expectedResults, results.getResults()); + + index.destroyIndex(); + } + } + } +} diff --git a/java/internal/src/cuvs_java.c b/java/internal/src/cuvs_java.c index b6a078e24..ee25ab576 100644 --- a/java/internal/src/cuvs_java.c +++ b/java/internal/src/cuvs_java.c @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -26,24 +27,44 @@ #define catch(x) ExitJmp:if(__HadError) #define throw(x) {__HadError=true;goto ExitJmp;} -cuvsResources_t create_resources(int *returnValue) { - cuvsResources_t cuvsResources; - *returnValue = cuvsResourcesCreate(&cuvsResources); - return cuvsResources; +/** + * Create an Initialized opaque C handle + * + * @param return_value return value for cuvsResourcesCreate function call + */ +cuvsResources_t create_resources(int *return_value) { + cuvsResources_t cuvs_resources; + *return_value = cuvsResourcesCreate(&cuvs_resources); + return cuvs_resources; } -void destroy_resources(cuvsResources_t cuvsResources, int *returnValue) { - *returnValue = cuvsResourcesDestroy(cuvsResources); +/** + * Destroy and de-allocate opaque C handle + * + * @param cuvs_resources an opaque C handle + * @param return_value return value for cuvsResourcesDestroy function call + */ +void destroy_resources(cuvsResources_t cuvs_resources, int *return_value) { + *return_value = cuvsResourcesDestroy(cuvs_resources); } -DLManagedTensor prepare_tensor(void *data, int64_t shape[], DLDataTypeCode code) { +/** + * Helper function for creating DLManagedTensor instance + * + * @param data the data pointer points to the allocated data + * @param shape the shape of the tensor + * @param code the type code of base types + * @param bits the shape of the tensor + * @param ndim the number of dimensions + */ +DLManagedTensor prepare_tensor(void *data, int64_t shape[], DLDataTypeCode code, int bits, int ndim) { DLManagedTensor tensor; tensor.dl_tensor.data = data; tensor.dl_tensor.device.device_type = kDLCUDA; - tensor.dl_tensor.ndim = 2; + tensor.dl_tensor.ndim = ndim; tensor.dl_tensor.dtype.code = code; - tensor.dl_tensor.dtype.bits = 32; + tensor.dl_tensor.dtype.bits = bits; tensor.dl_tensor.dtype.lanes = 1; tensor.dl_tensor.shape = shape; tensor.dl_tensor.strides = NULL; @@ -51,65 +72,233 @@ DLManagedTensor prepare_tensor(void *data, int64_t shape[], DLDataTypeCode code) return tensor; } -cuvsCagraIndex_t build_cagra_index(float *dataset, long rows, long dimensions, cuvsResources_t cuvsResources, int *returnValue, - cuvsCagraIndexParams_t index_params, cuvsCagraCompressionParams_t compression_params, int numWriterThreads) { +/** + * Function for building CAGRA index + * + * @param dataset index dataset + * @param rows number of dataset rows + * @param dimensions vector dimension of the dataset + * @param cuvs_resources reference of the underlying opaque C handle + * @param return_value return value for cuvsCagraBuild function call + * @param index_params a reference to the index parameters + * @param compression_params a reference to the compression parameters + * @param n_writer_threads number of omp threads to use + */ +cuvsCagraIndex_t build_cagra_index(float *dataset, long rows, long dimensions, cuvsResources_t cuvs_resources, int *return_value, + cuvsCagraIndexParams_t index_params, cuvsCagraCompressionParams_t compression_params, int n_writer_threads) { + + cudaStream_t stream; + cuvsStreamGet(cuvs_resources, &stream); - omp_set_num_threads(numWriterThreads); - cuvsRMMPoolMemoryResourceEnable(95, 95, true); + omp_set_num_threads(n_writer_threads); + cuvsRMMPoolMemoryResourceEnable(95, 95, false); int64_t dataset_shape[2] = {rows, dimensions}; - DLManagedTensor dataset_tensor = prepare_tensor(dataset, dataset_shape, kDLFloat); + DLManagedTensor dataset_tensor = prepare_tensor(dataset, dataset_shape, kDLFloat, 32, 2); cuvsCagraIndex_t index; cuvsCagraIndexCreate(&index); index_params->compression = compression_params; - *returnValue = cuvsCagraBuild(cuvsResources, index_params, &dataset_tensor, index); + cuvsStreamSync(cuvs_resources); + *return_value = cuvsCagraBuild(cuvs_resources, index_params, &dataset_tensor, index); omp_set_num_threads(1); return index; } -void destroy_cagra_index(cuvsCagraIndex_t index, int *returnValue) { - *returnValue = cuvsCagraIndexDestroy(index); +/** + * A function to de-allocate CAGRA index + * + * @param index cuvsCagraIndex_t to de-allocate + * @param return_value return value for cuvsCagraIndexDestroy function call + */ +void destroy_cagra_index(cuvsCagraIndex_t index, int *return_value) { + *return_value = cuvsCagraIndexDestroy(index); } -void serialize_cagra_index(cuvsResources_t cuvsResources, cuvsCagraIndex_t index, int *returnValue, char* filename) { - *returnValue = cuvsCagraSerialize(cuvsResources, filename, index, true); +/** + * A function to serialize a CAGRA index + * + * @param cuvs_resources reference of the underlying opaque C handle + * @param index cuvsCagraIndex_t reference + * @param return_value return value for cuvsCagraSerialize function call + * @param filename the filename of the index file + */ +void serialize_cagra_index(cuvsResources_t cuvs_resources, cuvsCagraIndex_t index, int *return_value, char* filename) { + *return_value = cuvsCagraSerialize(cuvs_resources, filename, index, true); } -void deserialize_cagra_index(cuvsResources_t cuvsResources, cuvsCagraIndex_t index, int *returnValue, char* filename) { - *returnValue = cuvsCagraDeserialize(cuvsResources, filename, index); +/** + * A function to de-serialize a CAGRA index + * + * @param cuvs_resources reference to the underlying opaque C handle + * @param index cuvsCagraIndex_t reference + * @param return_value return value for cuvsCagraDeserialize function call + * @param filename the filename of the index file + */ +void deserialize_cagra_index(cuvsResources_t cuvs_resources, cuvsCagraIndex_t index, int *return_value, char* filename) { + *return_value = cuvsCagraDeserialize(cuvs_resources, filename, index); } +/** + * A function to search a CAGRA index and return results + * + * @param index reference to a CAGRA index to search on + * @param queries query vectors + * @param topk topK results to return + * @param n_queries number of queries + * @param dimensions vector dimension + * @param cuvs_resources reference to the underlying opaque C handle + * @param neighbors_h reference to the neighbor results on the host memory + * @param distances_h reference to the distance results on the host memory + * @param return_value return value for cuvsCagraSearch function call + * @param search_params reference to cuvsCagraSearchParams_t holding the search parameters + */ void search_cagra_index(cuvsCagraIndex_t index, float *queries, int topk, long n_queries, int dimensions, - cuvsResources_t cuvsResources, int *neighbors_h, float *distances_h, int *returnValue, cuvsCagraSearchParams_t search_params) { + cuvsResources_t cuvs_resources, int *neighbors_h, float *distances_h, int *return_value, cuvsCagraSearchParams_t search_params) { + + cudaStream_t stream; + cuvsStreamGet(cuvs_resources, &stream); uint32_t *neighbors; float *distances, *queries_d; - cuvsRMMAlloc(cuvsResources, (void**) &queries_d, sizeof(float) * n_queries * dimensions); - cuvsRMMAlloc(cuvsResources, (void**) &neighbors, sizeof(uint32_t) * n_queries * topk); - cuvsRMMAlloc(cuvsResources, (void**) &distances, sizeof(float) * n_queries * topk); + cuvsRMMAlloc(cuvs_resources, (void**) &queries_d, sizeof(float) * n_queries * dimensions); + cuvsRMMAlloc(cuvs_resources, (void**) &neighbors, sizeof(uint32_t) * n_queries * topk); + cuvsRMMAlloc(cuvs_resources, (void**) &distances, sizeof(float) * n_queries * topk); cudaMemcpy(queries_d, queries, sizeof(float) * n_queries * dimensions, cudaMemcpyDefault); int64_t queries_shape[2] = {n_queries, dimensions}; - DLManagedTensor queries_tensor = prepare_tensor(queries_d, queries_shape, kDLFloat); + DLManagedTensor queries_tensor = prepare_tensor(queries_d, queries_shape, kDLFloat, 32, 2); int64_t neighbors_shape[2] = {n_queries, topk}; - DLManagedTensor neighbors_tensor = prepare_tensor(neighbors, neighbors_shape, kDLUInt); + DLManagedTensor neighbors_tensor = prepare_tensor(neighbors, neighbors_shape, kDLUInt, 32, 2); int64_t distances_shape[2] = {n_queries, topk}; - DLManagedTensor distances_tensor = prepare_tensor(distances, distances_shape, kDLFloat); + DLManagedTensor distances_tensor = prepare_tensor(distances, distances_shape, kDLFloat, 32, 2); - *returnValue = cuvsCagraSearch(cuvsResources, search_params, index, &queries_tensor, &neighbors_tensor, + cuvsStreamSync(cuvs_resources); + *return_value = cuvsCagraSearch(cuvs_resources, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor); cudaMemcpy(neighbors_h, neighbors, sizeof(uint32_t) * n_queries * topk, cudaMemcpyDefault); cudaMemcpy(distances_h, distances, sizeof(float) * n_queries * topk, cudaMemcpyDefault); - cuvsRMMFree(cuvsResources, distances, sizeof(float) * n_queries * topk); - cuvsRMMFree(cuvsResources, neighbors, sizeof(uint32_t) * n_queries * topk); - cuvsRMMFree(cuvsResources, queries_d, sizeof(float) * n_queries * dimensions); + cuvsRMMFree(cuvs_resources, distances, sizeof(float) * n_queries * topk); + cuvsRMMFree(cuvs_resources, neighbors, sizeof(uint32_t) * n_queries * topk); + cuvsRMMFree(cuvs_resources, queries_d, sizeof(float) * n_queries * dimensions); +} + +/** + * De-allocate BRUTEFORCE index + * + * @param index reference to BRUTEFORCE index + * @param return_value return value for cuvsBruteForceIndexDestroy function call + */ +void destroy_brute_force_index(cuvsBruteForceIndex_t index, int *return_value) { + *return_value = cuvsBruteForceIndexDestroy(index); +} + +/** + * A function to build BRUTEFORCE index + * + * @param dataset the dataset to be indexed + * @param rows the number of rows in the dataset + * @param dimensions the vector dimension + * @param cuvs_resources reference to the underlying opaque C handle + * @param return_value return value for cuvsBruteForceBuild function call + * @param n_writer_threads number of threads to use while indexing + */ +cuvsBruteForceIndex_t build_brute_force_index(float *dataset, long rows, long dimensions, cuvsResources_t cuvs_resources, + int *return_value, int n_writer_threads) { + + omp_set_num_threads(n_writer_threads); + cuvsRMMPoolMemoryResourceEnable(95, 95, false); + + cudaStream_t stream; + cuvsStreamGet(cuvs_resources, &stream); + + float *dataset_d; + cuvsRMMAlloc(cuvs_resources, (void**) &dataset_d, sizeof(float) * rows * dimensions); + cudaMemcpy(dataset_d, dataset, sizeof(float) * rows * dimensions, cudaMemcpyDefault); + + int64_t dataset_shape[2] = {rows, dimensions}; + DLManagedTensor dataset_tensor = prepare_tensor(dataset_d, dataset_shape, kDLFloat, 32, 2); + + cuvsBruteForceIndex_t index; + cuvsError_t index_create_status = cuvsBruteForceIndexCreate(&index); + + cuvsStreamSync(cuvs_resources); + *return_value = cuvsBruteForceBuild(cuvs_resources, &dataset_tensor, L2Expanded, 0.f, index); + + cuvsRMMFree(cuvs_resources, dataset_d, sizeof(float) * rows * dimensions); + omp_set_num_threads(1); + + return index; +} + +/** + * A function to search the BRUTEFORCE index + * + * @param index reference to a BRUTEFORCE index to search on + * @param queries reference to query vectors + * @param topk the top k results to return + * @param n_queries number of queries + * @param dimensions vector dimension + * @param cuvs_resources reference to the underlying opaque C handle + * @param neighbors_h reference to the neighbor results on the host memory + * @param distances_h reference to the distance results on the host memory + * @param return_value return value for cuvsBruteForceSearch function call + * @param prefilter_data cuvsFilter input prefilter that can be used to filter queries and neighbors based on the given bitmap + * @param prefilter_data_length prefilter length input + */ +void search_brute_force_index(cuvsBruteForceIndex_t index, float *queries, int topk, long n_queries, int dimensions, + cuvsResources_t cuvs_resources, int64_t *neighbors_h, float *distances_h, int *return_value, long *prefilter_data, + long prefilter_data_length) { + + cudaStream_t stream; + cuvsStreamGet(cuvs_resources, &stream); + + int64_t *neighbors; + float *distances, *queries_d; + long *prefilter_data_d; + cuvsRMMAlloc(cuvs_resources, (void**) &queries_d, sizeof(float) * n_queries * dimensions); + cuvsRMMAlloc(cuvs_resources, (void**) &neighbors, sizeof(int64_t) * n_queries * topk); + cuvsRMMAlloc(cuvs_resources, (void**) &distances, sizeof(float) * n_queries * topk); + cuvsRMMAlloc(cuvs_resources, (void**) &prefilter_data_d, sizeof(long) * prefilter_data_length); + + cudaMemcpy(queries_d, queries, sizeof(float) * n_queries * dimensions, cudaMemcpyDefault); + cudaMemcpy(prefilter_data_d, prefilter_data, sizeof(long) * prefilter_data_length, cudaMemcpyDefault); + + int64_t queries_shape[2] = {n_queries, dimensions}; + DLManagedTensor queries_tensor = prepare_tensor(queries_d, queries_shape, kDLFloat, 32, 2); + + int64_t neighbors_shape[2] = {n_queries, topk}; + DLManagedTensor neighbors_tensor = prepare_tensor(neighbors, neighbors_shape, kDLInt, 64, 2); + + int64_t distances_shape[2] = {n_queries, topk}; + DLManagedTensor distances_tensor = prepare_tensor(distances, distances_shape, kDLFloat, 32, 2); + + cuvsFilter prefilter; + if (prefilter_data == NULL) { + prefilter.type = NO_FILTER; + prefilter.addr = (uintptr_t)NULL; + } else { + int64_t prefilter_shape[1] = {prefilter_data_length}; + DLManagedTensor prefilter_tensor = prepare_tensor(prefilter_data_d, prefilter_shape, kDLUInt, 32, 1); + prefilter.type = BITMAP; + prefilter.addr = (uintptr_t)&prefilter_tensor; + } + + cuvsStreamSync(cuvs_resources); + *return_value = cuvsBruteForceSearch(cuvs_resources, index, &queries_tensor, &neighbors_tensor, &distances_tensor, prefilter); + + cudaMemcpy(neighbors_h, neighbors, sizeof(int64_t) * n_queries * topk, cudaMemcpyDefault); + cudaMemcpy(distances_h, distances, sizeof(float) * n_queries * topk, cudaMemcpyDefault); + + cuvsRMMFree(cuvs_resources, neighbors, sizeof(int64_t) * n_queries * topk); + cuvsRMMFree(cuvs_resources, distances, sizeof(float) * n_queries * topk); + cuvsRMMFree(cuvs_resources, queries_d, sizeof(float) * n_queries * dimensions); }