Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce native Readers read flatValues directly from faiss file. #2267

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.knn.indices.ModelUtil;

import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.indices.ModelUtil.getModelMetadata;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.engine.qframe.QuantizationConfigParser;
Expand Down Expand Up @@ -103,4 +104,9 @@ public static SpaceType getSpaceType(final ModelDao modelDao, final FieldInfo fi
}
return modelMetadata.getSpaceType();
}

public static String getParameters(final FieldInfo fieldInfo) {
final String parameters = fieldInfo.getAttribute(PARAMETERS);
return parameters;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.index.engine.KNNEngine;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.knn.index.engine.KNNEngine.FAISS;

/**
* There is 3 Index in one faiss file |id|hnsw|Storage|
* File Structure like followings:
* |-typeIDMap-||-id_header-|
* |-typeHnsw-||-hnsw_header-||-hnswGraph-|
* |-typeStorage-||-storage_Header-||-storageVector-|
* |-idmap_vector-|
*
* header would like:
* |dim|ntotal|dummy|dummy|is_trained|metric_type|metric_arg|
*
* Example for HNSW32,Flat:
* |idMapType|idMapHeader|hnswType|hnswHeader|hnswGraph|flatType|flatHeader|Vectors|IdVector|FOOTER_MAGIC+CHECKSUM|
*/
@Getter
public class FaissEngineFlatKnnVectorsReader extends FaissEngineKnnVectorsReader {

// 1. A Footer magic number (int - 4 bytes)
// 2. A checksum algorithm id (int - 4 bytes)
// 3. A checksum (long - bytes)
// The checksum is computed on all the bytes written to the file up to that point.
// Logic where footer is written in Lucene can be found here:
// https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L390-L412
public static final int FOOT_MAGIC_SIZE = RamUsageEstimator.primitiveSizes.get(Integer.TYPE);
public static final int ALGORITHM_SIZE = RamUsageEstimator.primitiveSizes.get(Integer.TYPE);
public static final int CHECKSUM_SIZE = RamUsageEstimator.primitiveSizes.get(Long.TYPE);
public static final int FLOAT_SIZE = RamUsageEstimator.primitiveSizes.get(Float.TYPE);
public static final int SIZET_SIZE = RamUsageEstimator.primitiveSizes.get(Long.TYPE);
public static final int FOOTER_SIZE = FOOT_MAGIC_SIZE + ALGORITHM_SIZE + CHECKSUM_SIZE;

private Map<String, IndexInput> fieldFileMap;
private Map<String, MetaInfo> fieldMetaMap;

@Override
public void checkIntegrity() throws IOException {

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this intentionally kept no-op?

}

public FaissEngineFlatKnnVectorsReader(SegmentReadState state) throws IOException {
fieldFileMap = new HashMap<>();
fieldMetaMap = new HashMap<>();
boolean success = false;
try {
for (FieldInfo field : state.fieldInfos) {

KNNEngine knnEngine = KNNCodecUtil.getNativeKNNEngine(field);
if (knnEngine == null || FAISS != knnEngine) {
continue;
}
final String vectorIndexFileName = KNNCodecUtil.getNativeEngineFileFromFieldInfo(field, state.segmentInfo);
if (vectorIndexFileName == null) {
continue;
}
// TODO for fp16, pq
VectorDataType vectorDataType = FieldInfoExtractor.extractVectorDataType(field);
SpaceType spaceType = FieldInfoExtractor.getSpaceType(null, field);
if (vectorDataType != VectorDataType.FLOAT) {
continue;
}
String parameter = FieldInfoExtractor.getParameters(field);
if (parameter == null || parameter.contains("BHNSW")) {
continue;
}
// TODO if not exist file, change to lucene flatVector
IndexInput in = state.directory.openInput(vectorIndexFileName, state.context.withRandomAccess());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if we want to give access to entire file through flat vectors reader. Is it possible to limit the access to using IndexInput#slice and then storing it in the map?

if (in == null) {
continue;
}
fieldFileMap.put(field.getName(), in);
}
success = true;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(this);
}
}

for (Map.Entry<String, IndexInput> entry : fieldFileMap.entrySet()) {
IndexInput in = entry.getValue();
int h = in.readInt();
MetaInfo metaInfo = read_index_header(in);
fieldMetaMap.put(entry.getKey(), metaInfo);
}
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
MetaInfo metaInfo = fieldMetaMap.get(field);
IndexInput input = fieldFileMap.get(field);
FaissEngineFlatVectorValues vectorValues = new FaissEngineFlatVectorValues(metaInfo, input);
return vectorValues;
}

@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
return null;
}

@Override
public boolean isNativeVectors(String field) {
return fieldFileMap.containsKey(field) && fieldMetaMap.containsKey(field);
}

private MetaInfo read_index_header(IndexInput in) throws IOException {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: camelCase?


int d = in.readInt();
long ntotal = in.readLong();
long dummy;
dummy = in.readLong();
dummy = in.readLong();
byte is_trained = in.readByte();
//
int metric_type = in.readInt();
float metric_arg = 0;
if (metric_type > 1) {
metric_arg = Float.intBitsToFloat(in.readInt());
}
long filesize = in.length();
// There is (ntotal+1) * idx_t and FOOTER_SIZE
long idSeek = filesize - (ntotal + 1) * SIZET_SIZE - FOOTER_SIZE;
// in.seek(idSeek);
// long size = in.readLong();

// long[] ids = new long[(int) ntotal];
// in.readLongs(ids, 0, (int) ntotal);
long vectorSeek = idSeek - (FLOAT_SIZE * d) * ntotal - SIZET_SIZE;
// in.seek(vectorSeek);

// float[] v = new float[(int) (d * ntotal)];
// size = in.readLong();
// System.out.println("Vector Size: " + size + " d * ntotal" + d * ntotal);
// for(int i = 0; i < ntotal; i++) {
// in.readFloats(v, i * d, d);
// System.out.println("vector:");
// for (int j = 0; j < d; j++) {
// System.out.println(v[i*d + j]);
// }
// }
return new MetaInfo(d, ntotal, is_trained, metric_type, metric_arg, idSeek, vectorSeek);
}

@Override
public void close() throws IOException {
for (Map.Entry<String, IndexInput> entry : fieldFileMap.entrySet()) {
IndexInput in = entry.getValue();
IOUtils.close(in);
}
}

@AllArgsConstructor
@Getter
public class MetaInfo {
int d;
long ntotal;
byte isTrained;
int metricType;
float metricArg;
long idSeek;
long vectorSeek;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.opensearch.knn.index.KNNVectorSimilarityFunction;
import org.opensearch.knn.index.SpaceType;

import java.io.IOException;
import java.util.Arrays;

import static org.opensearch.knn.index.codec.KNN990Codec.FaissEngineFlatKnnVectorsReader.FLOAT_SIZE;
import static org.opensearch.knn.index.codec.KNN990Codec.FaissEngineFlatKnnVectorsReader.SIZET_SIZE;

public class FaissEngineFlatVectorValues extends FloatVectorValues {
private static final int BUCKET_VECTORS = 64; // every time read only bucket size vectors.
protected FaissEngineFlatKnnVectorsReader.MetaInfo metaInfo;
protected final IndexInput slice;
protected final VectorSimilarityFunction similarityFunction;
protected final FlatVectorsScorer flatVectorsScorer;
protected final float[] value;
protected final long[] ids;
protected final float[] buf;
protected int docId = -1;
protected int ord = -1;

public FaissEngineFlatVectorValues(FaissEngineFlatKnnVectorsReader.MetaInfo metaInfo, IndexInput input) throws IOException {
this.metaInfo = metaInfo;
this.slice = input.clone();
this.similarityFunction = getVectorSimilarityFunction(metaInfo.metricType).getVectorSimilarityFunction();
this.flatVectorsScorer = FlatVectorScorerUtil.getLucene99FlatVectorsScorer();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed?

this.value = new float[(int) (metaInfo.d * metaInfo.ntotal)];
this.ids = new long[(int) metaInfo.ntotal];
this.buf = new float[metaInfo.d];
readIds();
}

protected void readIds() throws IOException {
slice.seek(metaInfo.idSeek);
long size = slice.readLong();
assert size == metaInfo.ntotal;
slice.readLongs(ids, 0, (int) metaInfo.ntotal);
}

protected void readBucketVectors() throws IOException {
assert ord >= 0;
assert ord <= metaInfo.ntotal;
int bucketIndex = ord / BUCKET_VECTORS;
slice.seek(metaInfo.vectorSeek + SIZET_SIZE + bucketIndex * BUCKET_VECTORS * FLOAT_SIZE * metaInfo.d);

for (int i = 0, o = ord; i < BUCKET_VECTORS && o < metaInfo.ntotal; i++, o++) {
slice.readFloats(value, i * metaInfo.d, metaInfo.d);
}
}
// public void readInfo() throws IOException {
// slice.seek(metaInfo.idSeek);
// long size = slice.readLong();
// assert size == metaInfo.ntotal;
// slice.readLongs(ids, 0, (int) metaInfo.ntotal);
//
// slice.seek(metaInfo.vectorSeek);
// size = slice.readLong();
// for(int i = 0; i < metaInfo.ntotal; i++) {
// slice.readFloats(value, i * metaInfo.d, metaInfo.d);
// }
// }

@Override
public int dimension() {
return metaInfo.d;
}

@Override
public int size() {
return (int) metaInfo.ntotal;
}

@Override
public float[] vectorValue() throws IOException {
if (ord % BUCKET_VECTORS == 0) {
readBucketVectors();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Want to understand the thought process here, why do we need to read 64 vectors at once? we can seek to the position based on ordinal and read the required vector right?

}
int bucketOrder = ord % BUCKET_VECTORS;

System.arraycopy(value, bucketOrder * metaInfo.d, buf, 0, metaInfo.d);
return buf;
}

@Override
public VectorScorer scorer(float[] floats) throws IOException {
// TODO
return null;
}

@Override
public int docID() {
return docId;
}

@Override
public int nextDoc() throws IOException {
return advance(docId + 1);
}

@Override
public int advance(int target) throws IOException {
ord = Arrays.binarySearch(ids, ord + 1, ids.length, target);
if (ord < 0) {
ord = -(ord + 1);
}
assert ord <= ids.length;
if (ord == ids.length) {
docId = NO_MORE_DOCS;
} else {
docId = (int) ids[ord];
}
return docId;
}

KNNVectorSimilarityFunction getVectorSimilarityFunction(int metricType) {
// Ref from jni/external/faiss/c_api/Index_c.h
switch (metricType) {
case 0:
return SpaceType.INNER_PRODUCT.getKnnVectorSimilarityFunction();
case 1:
return SpaceType.L2.getKnnVectorSimilarityFunction();
case 2:
return SpaceType.L1.getKnnVectorSimilarityFunction();
case 3:
return SpaceType.LINF.getKnnVectorSimilarityFunction();
default:
return SpaceType.L2.getKnnVectorSimilarityFunction();
}
}
}
Loading
Loading