diff --git a/duckdb b/duckdb index a5e12fe..d8a69cc 160000 --- a/duckdb +++ b/duckdb @@ -1 +1 @@ -Subproject commit a5e12fee059bfd374597f32a61986f7c2eaeb2e7 +Subproject commit d8a69cc6563f510a834a80e68d8afc9b81e6b2e3 diff --git a/src/hnsw/CMakeLists.txt b/src/hnsw/CMakeLists.txt index 38bc222..8bd40f2 100644 --- a/src/hnsw/CMakeLists.txt +++ b/src/hnsw/CMakeLists.txt @@ -3,6 +3,7 @@ set(EXTENSION_SOURCES ${EXTENSION_SOURCES} ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_logical_create.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_macros.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_physical_create.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_pragmas.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hnsw_index_scan.cpp diff --git a/src/hnsw/hnsw_index.cpp b/src/hnsw/hnsw_index.cpp index 012a8ea..a4b4923 100644 --- a/src/hnsw/hnsw_index.cpp +++ b/src/hnsw/hnsw_index.cpp @@ -14,7 +14,7 @@ namespace duckdb { class LinkedBlock { public: - static constexpr const idx_t BLOCK_SIZE = Storage::BLOCK_SIZE - sizeof(validity_t); + static constexpr const idx_t BLOCK_SIZE = Storage::DEFAULT_BLOCK_SIZE - sizeof(validity_t); static constexpr const idx_t BLOCK_DATA_SIZE = BLOCK_SIZE - sizeof(IndexPointer); static_assert(BLOCK_SIZE > sizeof(IndexPointer), "Block size must be larger than the size of an IndexPointer"); diff --git a/src/hnsw/hnsw_index_macros.cpp b/src/hnsw/hnsw_index_macros.cpp new file mode 100644 index 0000000..38243dd --- /dev/null +++ b/src/hnsw/hnsw_index_macros.cpp @@ -0,0 +1,119 @@ +#include "duckdb/function/table_macro_function.hpp" +#include "duckdb/main/extension_util.hpp" +#include "duckdb/optimizer/matcher/expression_matcher.hpp" +#include "hnsw/hnsw.hpp" +#include "hnsw/hnsw_index.hpp" +#include "duckdb/parser/parser.hpp" + +namespace duckdb { + +static constexpr auto VSS_JOIN_MACRO = R"( +SELECT + score, + left_tbl, + right_tbl, +FROM + (SELECT * FROM query_table(left_table::VARCHAR)) as left_tbl, + ( + SELECT + struct_pack(*columns([x for x in (matches.*) if x != 'score'])) as right_tbl, + matches.score as score + FROM ( + SELECT ( + unnest( + CASE WHEN metric = 'l2sq' OR metric = 'l2' + THEN min_by(tbl, tbl.score, k) + ELSE max_by(tbl, tbl.score, k) + END, + max_depth := 2 + ) + ) as result + FROM ( + SELECT + *, + CASE + WHEN metric = 'l2sq' OR metric = 'l2' + THEN array_distance(left_col, right_col) + WHEN metric = 'cosine' OR metric = 'cos' + THEN array_cosine_similarity(left_col, right_col) + WHEN metric = 'ip' + THEN array_inner_product(left_col, right_col) + ELSE error('Unknown metric') + END as score, + FROM query_table(right_table::VARCHAR) + ) as tbl + ) as matches + ) +)"; + +static constexpr auto VSS_MATCH_MACRO = R"( +SELECT + right_tbl as matches, +FROM + ( + SELECT ( + CASE WHEN metric = 'l2sq' OR metric = 'l2' + THEN min_by({'score': score, 'row': t}, score, k) + ELSE max_by({'score': score, 'row': t}, score, k) + END + ) as right_tbl + FROM ( + SELECT + CASE + WHEN metric = 'l2sq' OR metric = 'l2' + THEN array_distance(left_col, right_col) + WHEN metric = 'cosine' OR metric = 'cos' + THEN array_cosine_similarity(left_col, right_col) + WHEN metric = 'ip' + THEN array_inner_product(left_col, right_col) + ELSE error('Unknown metric') + END as score, + tbl as t, + FROM (SELECT * FROM query_table(right_table::VARCHAR)) as tbl + ) + ) +)"; + +//------------------------------------------------------------------------- +// Register +//------------------------------------------------------------------------- +static void RegisterTableMacro(DatabaseInstance &db, const string &name, const string &query, + const vector<string> ¶ms, const child_list_t<Value> &named_params) { + + Parser parser; + parser.ParseQuery(query); + const auto &stmt = parser.statements.back(); + auto &node = stmt->Cast<SelectStatement>().node; + + auto func = make_uniq<TableMacroFunction>(std::move(node)); + for (auto ¶m : params) { + func->parameters.push_back(make_uniq<ColumnRefExpression>(param)); + } + + for (auto ¶m : named_params) { + func->default_parameters[param.first] = make_uniq<ConstantExpression>(param.second); + } + + CreateMacroInfo info(CatalogType::TABLE_MACRO_ENTRY); + info.schema = DEFAULT_SCHEMA; + info.name = name; + info.temporary = true; + info.internal = true; + info.macros.push_back(std::move(func)); + + ExtensionUtil::RegisterFunction(db, info); +} + +void HNSWModule::RegisterMacros(DatabaseInstance &db) { + + RegisterTableMacro(db, "vss_join", VSS_JOIN_MACRO, + {"left_table", "right_table", "left_col", "right_col", "k"}, + {{"metric", Value("l2sq")}}); + + RegisterTableMacro(db, "vss_match", VSS_MATCH_MACRO, + {"right_table", "left_col", "right_col", "k"}, + {{"metric", Value("l2sq")}}); + +} + +} // namespace duckdb \ No newline at end of file diff --git a/src/hnsw/hnsw_index_physical_create.cpp b/src/hnsw/hnsw_index_physical_create.cpp index fa771c0..fbcb045 100644 --- a/src/hnsw/hnsw_index_physical_create.cpp +++ b/src/hnsw/hnsw_index_physical_create.cpp @@ -130,8 +130,8 @@ SinkCombineResultType PhysicalCreateHNSWIndex::Combine(ExecutionContext &context class HNSWIndexConstructTask final : public ExecutorTask { public: HNSWIndexConstructTask(shared_ptr<Event> event_p, ClientContext &context, CreateHNSWIndexGlobalState &gstate_p, - size_t thread_id_p) - : ExecutorTask(context, std::move(event_p)), gstate(gstate_p), thread_id(thread_id_p), local_scan_state() { + size_t thread_id_p, const PhysicalCreateHNSWIndex &op_p) + : ExecutorTask(context, std::move(event_p), op_p), gstate(gstate_p), thread_id(thread_id_p), local_scan_state() { // Initialize the scan chunk gstate.collection->InitializeScanChunk(scan_chunk); } @@ -209,11 +209,12 @@ class HNSWIndexConstructTask final : public ExecutorTask { class HNSWIndexConstructionEvent final : public BasePipelineEvent { public: - HNSWIndexConstructionEvent(CreateHNSWIndexGlobalState &gstate_p, Pipeline &pipeline_p, CreateIndexInfo &info_p, + HNSWIndexConstructionEvent(const PhysicalCreateHNSWIndex &op_p, CreateHNSWIndexGlobalState &gstate_p, Pipeline &pipeline_p, CreateIndexInfo &info_p, const vector<column_t> &storage_ids_p, DuckTableEntry &table_p) - : BasePipelineEvent(pipeline_p), gstate(gstate_p), info(info_p), storage_ids(storage_ids_p), table(table_p) { + : BasePipelineEvent(pipeline_p), op(op_p), gstate(gstate_p), info(info_p), storage_ids(storage_ids_p), table(table_p) { } + const PhysicalCreateHNSWIndex &op; CreateHNSWIndexGlobalState &gstate; CreateIndexInfo &info; const vector<column_t> &storage_ids; @@ -229,7 +230,7 @@ class HNSWIndexConstructionEvent final : public BasePipelineEvent { vector<shared_ptr<Task>> construct_tasks; for (size_t tnum = 0; tnum < num_threads; tnum++) { - construct_tasks.push_back(make_uniq<HNSWIndexConstructTask>(shared_from_this(), context, gstate, tnum)); + construct_tasks.push_back(make_uniq<HNSWIndexConstructTask>(shared_from_this(), context, gstate, tnum, op)); } SetTasks(std::move(construct_tasks)); } @@ -295,7 +296,7 @@ SinkFinalizeType PhysicalCreateHNSWIndex::Finalize(Pipeline &pipeline, Event &ev collection->InitializeScan(gstate.scan_state, ColumnDataScanProperties::ALLOW_ZERO_COPY); // Create a new event that will construct the index - auto new_event = make_shared_ptr<HNSWIndexConstructionEvent>(gstate, pipeline, *info, storage_ids, table); + auto new_event = make_shared_ptr<HNSWIndexConstructionEvent>(*this, gstate, pipeline, *info, storage_ids, table); event.InsertEvent(std::move(new_event)); return SinkFinalizeType::READY; diff --git a/src/include/hnsw/hnsw.hpp b/src/include/hnsw/hnsw.hpp index 765babf..6247e4f 100644 --- a/src/include/hnsw/hnsw.hpp +++ b/src/include/hnsw/hnsw.hpp @@ -12,6 +12,7 @@ struct HNSWModule { RegisterIndexPragmas(db); RegisterPlanIndexScan(db); RegisterPlanIndexCreate(db); + RegisterMacros(db); } private: @@ -20,6 +21,7 @@ struct HNSWModule { static void RegisterIndexPragmas(DatabaseInstance &db); static void RegisterPlanIndexScan(DatabaseInstance &db); static void RegisterPlanIndexCreate(DatabaseInstance &db); + static void RegisterMacros(DatabaseInstance &db); }; } // namespace duckdb \ No newline at end of file diff --git a/test/sql/hnsw/hnsw_join_macro.test b/test/sql/hnsw/hnsw_join_macro.test new file mode 100644 index 0000000..6b8d095 --- /dev/null +++ b/test/sql/hnsw/hnsw_join_macro.test @@ -0,0 +1,19 @@ +require vss + +statement ok +CREATE TABLE t1 (id int, vec FLOAT[3]); + +statement ok +INSERT INTO t1 SELECT row_number() over (), array_value(a,b,c) FROM range(1,10) ra(a), range(1,10) rb(b), range(1,10) rc(c); + +statement ok +CREATE TABLE s(s_vec FLOAT[3]); + +statement ok +INSERT INTO s VALUES ([5,5,5]), ([1,1,1]); + +query III +SELECT len(matches) = 3 FROM s, vss_match(t1, s_vec, vec, 3) as res; +---- +true +true \ No newline at end of file