-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Bruteforce API implementation Co-authored-by: Vivek Narang <[email protected]>
- Loading branch information
1 parent
6f708eb
commit b98d2bb
Showing
13 changed files
with
967 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,12 @@ | ||
export CMAKE_PREFIX_PATH=`pwd`/../cpp/build | ||
|
||
VERSION="25.02" | ||
GROUP_ID="com.nvidia.cuvs" | ||
SO_FILE_PATH="./internal" | ||
|
||
cd internal && cmake . && cmake --build . \ | ||
&& cd .. \ | ||
&& mvn install:install-file -DgroupId=com.nvidia.cuvs -DartifactId=cuvs-java-internal -Dversion=25.02 -Dpackaging=so -Dfile=./internal/libcuvs_java.so \ | ||
&& mvn install:install-file -DgroupId=$GROUP_ID -DartifactId=cuvs-java-internal -Dversion=$VERSION -Dpackaging=so -Dfile=$SO_FILE_PATH/libcuvs_java.so \ | ||
&& cd cuvs-java \ | ||
&& mvn package \ | ||
&& mvn install:install-file -Dfile=./target/cuvs-java-25.02.1-jar-with-dependencies.jar -DgroupId=com.nvidia.cuvs -DartifactId=cuvs-java -Dversion=54.02.1 -Dpackaging=jar | ||
&& mvn install:install-file -Dfile=./target/cuvs-java-$VERSION-jar-with-dependencies.jar -DgroupId=$GROUP_ID -DartifactId=cuvs-java -Dversion=$VERSION -Dpackaging=jar |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
255 changes: 255 additions & 0 deletions
255
java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceIndex.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.nvidia.cuvs; | ||
|
||
import java.io.IOException; | ||
import java.lang.foreign.FunctionDescriptor; | ||
import java.lang.foreign.MemoryLayout; | ||
import java.lang.foreign.MemorySegment; | ||
import java.lang.foreign.SequenceLayout; | ||
import java.lang.foreign.ValueLayout; | ||
import java.lang.invoke.MethodHandle; | ||
import java.lang.invoke.MethodHandles; | ||
|
||
import com.nvidia.cuvs.common.Util; | ||
|
||
/** | ||
* The BRUTEFORCE method is running the KNN algorithm. It performs an extensive | ||
* search, and in contrast to ANN methods produces an exact result. | ||
* | ||
* {@link BruteForceIndex} encapsulates a BRUTEFORCE index, along with methods | ||
* to interact with it. | ||
* | ||
* @since 25.02 | ||
*/ | ||
public class BruteForceIndex { | ||
|
||
private final float[][] dataset; | ||
private final long[] prefilterData; | ||
private final CuVSResources resources; | ||
private MethodHandle indexMethodHandle; | ||
private MethodHandle searchMethodHandle; | ||
private MethodHandle destroyIndexMethodHandle; | ||
private IndexReference bruteForceIndexReference; | ||
private BruteForceIndexParams bruteForceIndexParams; | ||
private MemoryLayout longMemoryLayout; | ||
private MemoryLayout intMemoryLayout; | ||
private MemoryLayout floatMemoryLayout; | ||
|
||
/** | ||
* Constructor for building the index using specified dataset | ||
* | ||
* @param dataset the dataset used for creating the BRUTEFORCE | ||
* index | ||
* @param resources an instance of {@link CuVSResources} | ||
* @param bruteForceIndexParams an instance of {@link BruteForceIndexParams} | ||
* holding the index parameters | ||
* @param prefilterData the prefilter data to use while searching the | ||
* BRUTEFORCE index | ||
*/ | ||
private BruteForceIndex(float[][] dataset, CuVSResources resources, BruteForceIndexParams bruteForceIndexParams, | ||
long[] prefilterData) throws Throwable { | ||
this.dataset = dataset; | ||
this.prefilterData = prefilterData; | ||
this.resources = resources; | ||
this.bruteForceIndexParams = bruteForceIndexParams; | ||
|
||
longMemoryLayout = resources.linker.canonicalLayouts().get("long"); | ||
intMemoryLayout = resources.linker.canonicalLayouts().get("int"); | ||
floatMemoryLayout = resources.linker.canonicalLayouts().get("float"); | ||
|
||
initializeMethodHandles(); | ||
this.bruteForceIndexReference = build(); | ||
} | ||
|
||
/** | ||
* Initializes the {@link MethodHandles} for invoking native methods. | ||
* | ||
* @throws IOException @{@link IOException} is unable to load the native library | ||
*/ | ||
private void initializeMethodHandles() throws IOException { | ||
indexMethodHandle = resources.linker.downcallHandle( | ||
resources.getSymbolLookup().find("build_brute_force_index").get(), | ||
FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS, longMemoryLayout, longMemoryLayout, | ||
ValueLayout.ADDRESS, ValueLayout.ADDRESS, intMemoryLayout)); | ||
|
||
searchMethodHandle = resources.linker.downcallHandle( | ||
resources.getSymbolLookup().find("search_brute_force_index").get(), | ||
FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, intMemoryLayout, longMemoryLayout, | ||
intMemoryLayout, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, | ||
ValueLayout.ADDRESS, longMemoryLayout)); | ||
|
||
destroyIndexMethodHandle = resources.linker.downcallHandle( | ||
resources.getSymbolLookup().find("destroy_brute_force_index").get(), | ||
FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS)); | ||
} | ||
|
||
/** | ||
* Invokes the native destroy_brute_force_index function to de-allocate | ||
* BRUTEFORCE index | ||
*/ | ||
public void destroyIndex() throws Throwable { | ||
MemoryLayout returnValueMemoryLayout = intMemoryLayout; | ||
MemorySegment returnValueMemorySegment = resources.arena.allocate(returnValueMemoryLayout); | ||
destroyIndexMethodHandle.invokeExact(bruteForceIndexReference.getMemorySegment(), returnValueMemorySegment); | ||
} | ||
|
||
/** | ||
* Invokes the native build_brute_force_index function via the Panama API to | ||
* build the {@link BruteForceIndex} | ||
* | ||
* @return an instance of {@link IndexReference} that holds the pointer to the | ||
* index | ||
*/ | ||
private IndexReference build() throws Throwable { | ||
long rows = dataset.length; | ||
long cols = rows > 0 ? dataset[0].length : 0; | ||
|
||
MemoryLayout returnValueMemoryLayout = intMemoryLayout; | ||
MemorySegment returnValueMemorySegment = resources.arena.allocate(returnValueMemoryLayout); | ||
|
||
IndexReference indexReference = new IndexReference((MemorySegment) indexMethodHandle.invokeExact( | ||
Util.buildMemorySegment(resources.linker, resources.arena, dataset), rows, cols, resources.getMemorySegment(), | ||
returnValueMemorySegment, bruteForceIndexParams.getNumWriterThreads())); | ||
|
||
return indexReference; | ||
} | ||
|
||
/** | ||
* Invokes the native search_brute_force_index via the Panama API for searching | ||
* a BRUTEFORCE index. | ||
* | ||
* @param cuvsQuery an instance of {@link BruteForceQuery} holding the query | ||
* vectors and other parameters | ||
* @return an instance of {@link BruteForceSearchResults} containing the results | ||
*/ | ||
public BruteForceSearchResults search(BruteForceQuery cuvsQuery) throws Throwable { | ||
long numQueries = cuvsQuery.getQueryVectors().length; | ||
long numBlocks = cuvsQuery.getTopK() * numQueries; | ||
int vectorDimension = numQueries > 0 ? cuvsQuery.getQueryVectors()[0].length : 0; | ||
long prefilterDataLength = prefilterData != null ? prefilterData.length : 0; | ||
|
||
SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, longMemoryLayout); | ||
SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, floatMemoryLayout); | ||
MemorySegment neighborsMemorySegment = resources.arena.allocate(neighborsSequenceLayout); | ||
MemorySegment distancesMemorySegment = resources.arena.allocate(distancesSequenceLayout); | ||
MemoryLayout returnValueMemoryLayout = intMemoryLayout; | ||
MemorySegment returnValueMemorySegment = resources.arena.allocate(returnValueMemoryLayout); | ||
MemorySegment prefilterDataMemorySegment = prefilterData != null | ||
? Util.buildMemorySegment(resources.linker, resources.arena, prefilterData) | ||
: MemorySegment.NULL; | ||
|
||
searchMethodHandle.invokeExact(bruteForceIndexReference.getMemorySegment(), | ||
Util.buildMemorySegment(resources.linker, resources.arena, cuvsQuery.getQueryVectors()), cuvsQuery.getTopK(), | ||
numQueries, vectorDimension, resources.getMemorySegment(), neighborsMemorySegment, distancesMemorySegment, | ||
returnValueMemorySegment, prefilterDataMemorySegment, prefilterDataLength); | ||
|
||
return new BruteForceSearchResults(neighborsSequenceLayout, distancesSequenceLayout, neighborsMemorySegment, | ||
distancesMemorySegment, cuvsQuery.getTopK(), cuvsQuery.getMapping(), numQueries); | ||
} | ||
|
||
/** | ||
* Builder helps configure and create an instance of {@link BruteForceIndex}. | ||
*/ | ||
public static class Builder { | ||
|
||
private float[][] dataset; | ||
private long[] prefilterData; | ||
private CuVSResources cuvsResources; | ||
private BruteForceIndexParams bruteForceIndexParams; | ||
|
||
/** | ||
* Constructs this Builder with an instance of {@link CuVSResources}. | ||
* | ||
* @param cuvsResources an instance of {@link CuVSResources} | ||
*/ | ||
public Builder(CuVSResources cuvsResources) { | ||
this.cuvsResources = cuvsResources; | ||
} | ||
|
||
/** | ||
* Registers an instance of configured {@link BruteForceIndexParams} with this | ||
* Builder. | ||
* | ||
* @param bruteForceIndexParams An instance of BruteForceIndexParams | ||
* @return An instance of this Builder | ||
*/ | ||
public Builder withIndexParams(BruteForceIndexParams bruteForceIndexParams) { | ||
this.bruteForceIndexParams = bruteForceIndexParams; | ||
return this; | ||
} | ||
|
||
/** | ||
* Sets the dataset for building the {@link BruteForceIndex}. | ||
* | ||
* @param dataset a two-dimensional float array | ||
* @return an instance of this Builder | ||
*/ | ||
public Builder withDataset(float[][] dataset) { | ||
this.dataset = dataset; | ||
return this; | ||
} | ||
|
||
/** | ||
* Sets the prefilter data for building the {@link BruteForceIndex}. | ||
* | ||
* @param prefilterData a one-dimensional long array | ||
* @return an instance of this Builder | ||
*/ | ||
public Builder withPrefilterData(long[] prefilterData) { | ||
this.prefilterData = prefilterData; | ||
return this; | ||
} | ||
|
||
/** | ||
* Builds and returns an instance of {@link BruteForceIndex}. | ||
* | ||
* @return an instance of {@link BruteForceIndex} | ||
*/ | ||
public BruteForceIndex build() throws Throwable { | ||
return new BruteForceIndex(dataset, cuvsResources, bruteForceIndexParams, prefilterData); | ||
} | ||
} | ||
|
||
/** | ||
* Holds the memory reference to a BRUTEFORCE index. | ||
*/ | ||
protected static class IndexReference { | ||
|
||
private final MemorySegment memorySegment; | ||
|
||
/** | ||
* Constructs BruteForceIndexReference with an instance of MemorySegment passed | ||
* as a parameter. | ||
* | ||
* @param indexMemorySegment the MemorySegment instance to use for containing | ||
* index reference | ||
*/ | ||
protected IndexReference(MemorySegment indexMemorySegment) { | ||
this.memorySegment = indexMemorySegment; | ||
} | ||
|
||
/** | ||
* Gets the instance of index MemorySegment. | ||
* | ||
* @return index MemorySegment | ||
*/ | ||
protected MemorySegment getMemorySegment() { | ||
return memorySegment; | ||
} | ||
} | ||
} |
71 changes: 71 additions & 0 deletions
71
java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceIndexParams.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.nvidia.cuvs; | ||
|
||
/** | ||
* Supplemental parameters to build BRUTEFORCE index. | ||
* | ||
* @since 25.02 | ||
*/ | ||
public class BruteForceIndexParams { | ||
|
||
private final int numWriterThreads; | ||
|
||
private BruteForceIndexParams(int writerThreads) { | ||
this.numWriterThreads = writerThreads; | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return "BruteForceIndexParams [numWriterThreads=" + numWriterThreads + "]"; | ||
} | ||
|
||
/** | ||
* Gets the number of threads used to build the index. | ||
*/ | ||
public int getNumWriterThreads() { | ||
return numWriterThreads; | ||
} | ||
|
||
/** | ||
* Builder configures and creates an instance of {@link BruteForceIndexParams}. | ||
*/ | ||
public static class Builder { | ||
|
||
private int numWriterThreads = 2; | ||
|
||
/** | ||
* Sets the number of writer threads to use for indexing. | ||
* | ||
* @param numWriterThreads number of writer threads to use | ||
* @return an instance of Builder | ||
*/ | ||
public Builder withNumWriterThreads(int numWriterThreads) { | ||
this.numWriterThreads = numWriterThreads; | ||
return this; | ||
} | ||
|
||
/** | ||
* Builds an instance of {@link BruteForceIndexParams}. | ||
* | ||
* @return an instance of {@link BruteForceIndexParams} | ||
*/ | ||
public BruteForceIndexParams build() { | ||
return new BruteForceIndexParams(numWriterThreads); | ||
} | ||
} | ||
} |
Oops, something went wrong.