Skip to content

Commit

Permalink
format, update test
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxxen committed Aug 12, 2024
1 parent 067cfb1 commit 5ba27de
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 16 deletions.
7 changes: 4 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ set(LOADABLE_EXTENSION_NAME ${TARGET_NAME}_loadable_extension)

project(${TARGET_NAME})

option(USE_SIMSIMD "Use SIMSIMD library to sacrifice portability for vectorized search" OFF)
option(USE_SIMSIMD
"Use SIMSIMD library to sacrifice portability for vectorized search" OFF)
if(USE_SIMSIMD)
add_definitions(-DDUCKDB_USEARCH_USE_SIMSIMD=1)
add_definitions(-DDUCKDB_USEARCH_USE_SIMSIMD=1)
else()
add_definitions(-DDUCKDB_USEARCH_USE_SIMSIMD=0)
add_definitions(-DDUCKDB_USEARCH_USE_SIMSIMD=0)
endif()

include_directories(src/include)
Expand Down
13 changes: 5 additions & 8 deletions src/hnsw/hnsw_index_macros.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ FROM
// Register
//-------------------------------------------------------------------------
static void RegisterTableMacro(DatabaseInstance &db, const string &name, const string &query,
const vector<string> &params, const child_list_t<Value> &named_params) {
const vector<string> &params, const child_list_t<Value> &named_params) {

Parser parser;
parser.ParseQuery(query);
Expand Down Expand Up @@ -106,14 +106,11 @@ static void RegisterTableMacro(DatabaseInstance &db, const string &name, const s

void HNSWModule::RegisterMacros(DatabaseInstance &db) {

RegisterTableMacro(db, "vss_join", VSS_JOIN_MACRO,
{"left_table", "right_table", "left_col", "right_col", "k"},
{{"metric", Value("l2sq")}});

RegisterTableMacro(db, "vss_match", VSS_MATCH_MACRO,
{"right_table", "left_col", "right_col", "k"},
{{"metric", Value("l2sq")}});
RegisterTableMacro(db, "vss_join", VSS_JOIN_MACRO, {"left_table", "right_table", "left_col", "right_col", "k"},
{{"metric", Value("l2sq")}});

RegisterTableMacro(db, "vss_match", VSS_MATCH_MACRO, {"right_table", "left_col", "right_col", "k"},
{{"metric", Value("l2sq")}});
}

} // namespace duckdb
11 changes: 7 additions & 4 deletions src/hnsw/hnsw_index_physical_create.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ class HNSWIndexConstructTask final : public ExecutorTask {
public:
HNSWIndexConstructTask(shared_ptr<Event> event_p, ClientContext &context, CreateHNSWIndexGlobalState &gstate_p,
size_t thread_id_p, const PhysicalCreateHNSWIndex &op_p)
: ExecutorTask(context, std::move(event_p), op_p), gstate(gstate_p), thread_id(thread_id_p), local_scan_state() {
: ExecutorTask(context, std::move(event_p), op_p), gstate(gstate_p), thread_id(thread_id_p),
local_scan_state() {
// Initialize the scan chunk
gstate.collection->InitializeScanChunk(scan_chunk);
}
Expand Down Expand Up @@ -209,9 +210,11 @@ class HNSWIndexConstructTask final : public ExecutorTask {

class HNSWIndexConstructionEvent final : public BasePipelineEvent {
public:
HNSWIndexConstructionEvent(const PhysicalCreateHNSWIndex &op_p, CreateHNSWIndexGlobalState &gstate_p, Pipeline &pipeline_p, CreateIndexInfo &info_p,
const vector<column_t> &storage_ids_p, DuckTableEntry &table_p)
: BasePipelineEvent(pipeline_p), op(op_p), gstate(gstate_p), info(info_p), storage_ids(storage_ids_p), table(table_p) {
HNSWIndexConstructionEvent(const PhysicalCreateHNSWIndex &op_p, CreateHNSWIndexGlobalState &gstate_p,
Pipeline &pipeline_p, CreateIndexInfo &info_p, const vector<column_t> &storage_ids_p,
DuckTableEntry &table_p)
: BasePipelineEvent(pipeline_p), op(op_p), gstate(gstate_p), info(info_p), storage_ids(storage_ids_p),
table(table_p) {
}

const PhysicalCreateHNSWIndex &op;
Expand Down
7 changes: 6 additions & 1 deletion test/sql/hnsw/hnsw_join_macro.test
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ CREATE TABLE s(s_vec FLOAT[3]);
statement ok
INSERT INTO s VALUES ([5,5,5]), ([1,1,1]);

query III
query I
SELECT bool_and(score <= 1.0) FROM vss_join(s, t1, s_vec, vec, 3) as res;
----
true

query I
SELECT len(matches) = 3 FROM s, vss_match(t1, s_vec, vec, 3) as res;
----
true
Expand Down

0 comments on commit 5ba27de

Please sign in to comment.